In [114]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets,transforms, models
import torchvision
import pandas as pd

import math
import random
import time
import numpy as np
from typing import Optional
import os
import imageio
import time
import warnings
import sys
import copy
import json
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
%matplotlib inline

In [115]:
RANDOM_SEED = 999

# Seed every RNG the notebook touches: torch (CPU), torch (current CUDA device),
# numpy's legacy global RNG and Python's `random`.
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed, random.seed):
    _seed_fn(RANDOM_SEED)

# Prefer reproducible cuDNN kernels over autotuned (possibly non-deterministic) ones.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Dedicated dataloader generators (Generator.manual_seed returns the generator itself):
#   general_generator -> validation, test and unsupervised dataloaders
#   train_generator   -> train dataloader only, so its shuffling is isolated from the others
general_generator = torch.Generator().manual_seed(RANDOM_SEED)
train_generator = torch.Generator().manual_seed(RANDOM_SEED)
def seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: reseed numpy and ``random`` from torch's seed.

    Inside a worker, ``torch.initial_seed()`` is that worker's own seed; it is
    reduced to 32 bits because ``np.random.seed`` only accepts values < 2**32.
    ``worker_id`` is required by the DataLoader signature but not used.
    """
    seed = torch.initial_seed() % (1 << 32)
    for reseed in (np.random.seed, random.seed):
        reseed(seed)

In [None]:
# Roots of the preprocessed (chunked .npy) train / validation datasets.
# Each can be overridden through an environment variable so the notebook also
# runs on machines without the X: drive; the defaults are the original paths.
train_dataset_dir = os.environ.get(
    'THERM_TRAIN_DATASET_DIR',
    r'X:\my_thermal_nose_dataset\RawData_CHUNK160_RESIZE128_littleROI_TRAIN\mydataset')
valid_dataset_dir = os.environ.get(
    'THERM_VALID_DATASET_DIR',
    r'X:\my_thermal_nose_dataset\RawData_CHUNK160_RESIZE128_littleROI_TEST\mydataset')

# Alternative (RESIZE96) dataset roots:
# train_dataset_dir = r'X:\my_thermal_nose_dataset\RawData_CHUNK160_RESIZE96_TRAIN\mydataset'
# valid_dataset_dir = r'X:\my_thermal_nose_dataset\RawData_CHUNK160_RESIZE96_TEST\mydataset'
class MConfig:
    """Run settings and hyper-parameters for the ThermRNet experiments.

    Every value is a plain class attribute; the notebook reads them through the
    module-level ``config`` instance created right after the class.
    """

    # ---- hardware / schedule ------------------------------------------------
    DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    EPOCHS = 80
    # ---- checkpoints / data -------------------------------------------------
    MODEL_DIR = 'PreTrainedModels'
    DATA_PATH = r'xx'  # NOTE(review): looks like a placeholder path — confirm before use
    MODEL_FILE_NAME = 'ThermRNet'
    # ---- batching -----------------------------------------------------------
    TEST_BATCH_SIZE = 1
    TRAIN_BATCH_SIZE = 4
    NUM_OF_GPU_TRAIN = 1
    # ---- clip preprocessing -------------------------------------------------
    DO_CHUNK = True
    CHUNK_LENGTH = 160  # presumably frames per clip (matches the CHUNK160 dataset folders)
    RESIZE_H = 96
    RESIZE_W = 96
    # ---- optimisation / run mode --------------------------------------------
    LR = 9e-5  # learning rate; original note: chosen slightly larger for our own training set
    TOOLBOX_MODE = "train_and_test"
    TEST_USE_LAST_EPOCH = True
    BEGIN = 0.0
    END = 1.0
    # Layout of the tensors coming out of the Dataset: N = number of clips (batch),
    # C = channels, D = frames per clip, H = height, W = width.
    DATA_FORMAT = "NCDHW"
    # ---- model --------------------------------------------------------------
    PATCH_SIZE = 4
    DIM = 96
    FF_DIM = 144
    NUM_HEADS = 4
    NUM_LAYERS = 4
    THETA = 0.7  # CDC_T mix between plain and central-difference convolution
    DROP_RATE = 0.2
    GRA_SHARP = 2.0  # attention temperature: logits are divided by this value


config = MConfig()

In [117]:
# Decide whether training runs on the GPU: first CUDA device if present, else the CPU.
train_on_gpu = torch.cuda.is_available()
device = torch.device("cuda:0") if train_on_gpu else torch.device("cpu")
print(device)
status_message = ('CUDA is available!  Training on GPU ...' if train_on_gpu
                  else 'CUDA is not available.  Training on CPU ...')
print(status_message)

cuda:0
CUDA is available!  Training on GPU ...


In [118]:
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze all parameters of ``model`` when ``feature_extracting`` is truthy.

    In feature-extraction mode none of the network's learnable parameters are
    updated. With a falsy flag the model is left untouched. Returns None.
    """
    if not feature_extracting:
        return
    for parameter in model.parameters():
        parameter.requires_grad = False

## Transformer Block

In [119]:
def as_tuple(x):
    """Return ``x`` unchanged if it is a tuple, otherwise the pair ``(x, x)``."""
    if isinstance(x, tuple):
        return x
    return (x, x)


class CDC_T(nn.Module):
    """
    Temporal Center-difference based Convolutional layer (3D version).

    Output = conv(x) - theta * conv1x1x1(x, W_diff), where W_diff (C_out, C_in, 1, 1, 1)
    is the sum of the spatial taps of temporal kernel slices 0 and 2. ``theta``
    controls the percentage of original convolution (theta = 0 gives a plain
    Conv3d) and central-difference convolution.

    The difference term is applied only when the temporal kernel size is > 1, and
    it reads temporal slice index 2, so it is written for a temporal kernel size
    of 3. NOTE: the wrapped conv's bias (if any) is also passed to the difference
    convolution, so with ``bias=True`` the bias ends up scaled by (1 - theta).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 padding=1, dilation=1, groups=1, bias=False, theta=0.6):
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels, out_channels,
            kernel_size=kernel_size, stride=stride, padding=padding,
            dilation=dilation, groups=groups, bias=bias,
        )
        self.theta = theta

    def forward(self, x):
        conv = self.conv
        out_normal = conv(x)
        weight = conv.weight

        # theta == 0, or a kernel that does not span time: nothing to subtract.
        if math.fabs(self.theta - 0.0) < 1e-8 or weight.shape[2] <= 1:
            return out_normal

        kernel_diff = weight[:, :, 0, :, :].sum(2).sum(2) + weight[:, :, 2, :, :].sum(2).sum(2)
        out_diff = F.conv3d(
            input=x,
            weight=kernel_diff[:, :, None, None, None],
            bias=conv.bias,
            stride=conv.stride,
            padding=0,
            dilation=conv.dilation,
            groups=conv.groups,
        )
        return out_normal - self.theta * out_diff


def split_last(x, shape):
    """Reshape the last dimension of ``x`` into ``shape``.

    At most one entry of ``shape`` may be -1; it is inferred from the size of
    the last dimension. Returns a view of ``x``.
    """
    target = list(shape)
    assert target.count(-1) <= 1
    if -1 in target:
        # The single -1 makes the product negative; negate it to get the known size.
        known = -np.prod(target)
        target[target.index(-1)] = int(x.size(-1) / known)
    return x.view(*x.size()[:-1], *target)


def merge_last(x, n_dims):
    """Collapse the trailing ``n_dims`` dimensions of ``x`` into one (returns a view)."""
    dims = x.size()
    assert 1 < n_dims < len(dims)
    return x.view(*dims[:-n_dims], -1)


class MultiHeadedSelfAttention_TDC_gra_sharp(nn.Module):
    """
    Multi-head dot-product self-attention with convolutional projections.

    Q and K come from ``CDC_T`` (3x3x3 temporal central-difference conv) + BatchNorm3d,
    V from a plain 1x1x1 Conv3d (all with groups=1, i.e. not depth-wise). For these
    convolutions the tokens are folded back into a (C, P // 9, 3, 3) volume, so the
    token count P must be a multiple of 9 (a T x 3 x 3 grid). Attention logits are
    divided by ``gra_sharp`` (a sharpness temperature) instead of sqrt(head_dim).

    Args:
        dim: token width C; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
        theta: CDC_T mixing factor used by the Q/K projections.
    """

    def __init__(self, dim, num_heads, dropout, theta):
        super().__init__()

        def tdc_projection():
            return nn.Sequential(
                CDC_T(dim, dim, 3, stride=1, padding=1, groups=1, bias=False, theta=theta),
                nn.BatchNorm3d(dim),
            )

        # Creation order (q, k, v) is preserved so seeded initialisation is unchanged.
        self.proj_q = tdc_projection()
        self.proj_k = tdc_projection()
        self.proj_v = nn.Sequential(
            nn.Conv3d(dim, dim, 1, stride=1, padding=0, groups=1, bias=False),
        )

        self.drop = nn.Dropout(dropout)
        self.n_heads = num_heads
        self.scores = None  # last attention map, kept for visualization

    def forward(self, x, gra_sharp):
        """
        Args:
            x: tokens of shape (B, P, C), P a multiple of 9.
            gra_sharp: scalar divisor applied to the attention logits.

        Returns:
            (h, scores): h has shape (B, P, C); scores has shape (B, heads, P, P)
            and holds the (dropout-applied) attention weights.
        """
        batch, num_tokens, channels = x.shape
        # (B, P, C) -> (B, C, P//9, 3, 3): rebuild the spatio-temporal grid.
        volume = x.transpose(1, 2).view(batch, channels, num_tokens // 9, 3, 3)

        def to_heads(projected):
            # (B, C, T, 3, 3) -> (B, P, C) -> (B, P, H, W) -> (B, H, P, W)
            tokens = projected.flatten(2).transpose(1, 2)
            return split_last(tokens, (self.n_heads, -1)).transpose(1, 2)

        query = to_heads(self.proj_q(volume))
        key = to_heads(self.proj_k(volume))
        value = to_heads(self.proj_v(volume))

        # (B, H, P, W) @ (B, H, W, P) -> (B, H, P, P), softmax over the last axis.
        attention = self.drop(F.softmax(query @ key.transpose(-2, -1) / gra_sharp, dim=-1))
        # (B, H, P, P) @ (B, H, P, W) -> (B, H, P, W) -> (B, P, H, W) -> (B, P, C)
        merged = merge_last((attention @ value).transpose(1, 2).contiguous(), 2)
        self.scores = attention
        return merged, attention


class PositionWiseFeedForward_ST(nn.Module):
    """
    Spatio-temporal feed-forward network applied to every token.

    Tokens (B, P, C) are folded into a (B, C, P // 9, 3, 3) volume (P must be a
    multiple of 9), then passed through:
        fc1    : 1x1x1 conv dim -> ff_dim, BatchNorm3d, ELU
        STConv : depth-wise (groups=ff_dim) 3x3x3 conv, BatchNorm3d, ELU
        fc2    : 1x1x1 conv ff_dim -> dim, BatchNorm3d
    and flattened back to (B, P, C).
    """

    def __init__(self, dim, ff_dim):
        super().__init__()
        # Creation order (fc1, STConv, fc2) is preserved so seeded init is unchanged.
        self.fc1 = nn.Sequential(
            nn.Conv3d(dim, ff_dim, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm3d(ff_dim),
            nn.ELU(),
        )
        self.STConv = nn.Sequential(
            nn.Conv3d(ff_dim, ff_dim, 3, stride=1, padding=1, groups=ff_dim, bias=False),
            nn.BatchNorm3d(ff_dim),
            nn.ELU(),
        )
        self.fc2 = nn.Sequential(
            nn.Conv3d(ff_dim, dim, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm3d(dim),
        )

    def forward(self, x):
        batch, num_tokens, channels = x.shape
        grid = x.transpose(1, 2).view(batch, channels, num_tokens // 9, 3, 3)
        for stage in (self.fc1, self.STConv, self.fc2):
            grid = stage(grid)
        return grid.flatten(2).transpose(1, 2)


class Block_ST_TDC_gra_sharp(nn.Module):
    """
    Pre-norm transformer block.

    x -> x + Dropout(Linear(Attention(LayerNorm(x))))
      -> x + Dropout(FeedForward(LayerNorm(x)))

    LayerNorm normalises each token independently of the batch, unlike BatchNorm,
    whose statistics depend on the batch (and become unstable for small batches).
    Returns the updated tokens and the attention scores of this block.
    """

    def __init__(self, dim, num_heads, ff_dim, dropout, theta):
        super().__init__()
        # Creation order is preserved so seeded initialisation stays identical.
        self.attn = MultiHeadedSelfAttention_TDC_gra_sharp(dim, num_heads, dropout, theta)
        self.proj = nn.Linear(dim, dim)
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.pwff = PositionWiseFeedForward_ST(dim, ff_dim)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, gra_sharp):
        attended, score = self.attn(self.norm1(x), gra_sharp)
        # Residual branches, each with dropout applied before the addition.
        x = x + self.drop(self.proj(attended))
        x = x + self.drop(self.pwff(self.norm2(x)))
        return x, score


class Transformer_ST_TDC_gra_sharp(nn.Module):
    """
    Transformer with Self-Attentive Blocks: a stack of ``Block_ST_TDC_gra_sharp``.

    Each block applies temporal-difference self-attention and a spatio-temporal
    feed-forward network to the token sequence, which lets the model relate
    distant parts of a clip; multiple heads attend in different subspaces.

    Args:
        num_layers: number of blocks in the stack.
        dim: token width (input and output of every block).
        num_heads: attention heads per block.
        ff_dim: hidden width of each position-wise feed-forward network.
        dropout: dropout probability.
        theta: mix between plain and central-difference convolution in CDC_T.
    """

    def __init__(self, num_layers, dim, num_heads, ff_dim, dropout, theta):
        super().__init__()
        self.blocks = nn.ModuleList([
            Block_ST_TDC_gra_sharp(dim, num_heads, ff_dim, dropout, theta) for _ in range(num_layers)])

    def forward(self, x, gra_sharp):
        """Run every block in order.

        Returns:
            (x, score): the final tokens and the attention scores of the LAST
            block. With an empty stack (num_layers == 0) the input is returned
            unchanged together with ``None`` instead of raising UnboundLocalError.
        """
        score = None
        for block in self.blocks:
            x, score = block(x, gra_sharp)
        return x, score


## ThermRNet

In [120]:
class ThermRNet(nn.Module):
    """
    ThermRNet: 3D-CNN stem -> two temporal-difference transformer stacks (with a
    stem-to-transformer residual) -> temporal upsampling decoder -> per-frame classifier.

    Shapes (B = batch):
        input   (B, in_channels, 160, 96, 96)
        stem    (B, dim, 40, 3, 3)       -> 40*3*3 = 360 tokens of width ``dim``
        decoder (B, dim // 2, 160, 3, 3)
        output  (B, 2, 160)              per-frame logits for 2 classes

    The token count (positional embedding of 40*3*3) and the 3x3 grid assumed by the
    transformer blocks are hard-coded, so only 160-frame 96x96 clips fit. ``frame`` and
    ``image_size`` are stored but do not change the architecture. Transformer depth,
    heads, FF width, theta and the attention temperature are read from the
    notebook-level ``config`` object.
    """

    def __init__(
            self,
            dim: int = 768,
            dropout_rate: float = 0.2,
            in_channels: int = 3,
            frame: int = 160,
            image_size: Optional[int] = 96,
    ):
        """
        Args:
            dim: token width; the stem widens dim//16 -> dim//8 -> dim//4 -> dim//2 -> dim.
            dropout_rate: probability for every Dropout3d and for the transformer dropout.
            in_channels: channels of the input clip (used by the first stem convolution).
            frame: stored only.
            image_size: stored only.
        """
        super().__init__()
        self.dropout_rate = dropout_rate
        self.image_size = image_size
        self.frame = frame
        self.dim = dim
        self.gra_sharp = config.GRA_SHARP  # attention temperature for both transformers

        # Stem: five Conv3d-BN-ReLU-Dropout3d-MaxPool3d stages.
        # (B, in_channels, 160, 96, 96) -> (B, dim, 40, 3, 3)
        self.Stem0 = nn.Sequential(
            nn.Conv3d(in_channels, dim // 16, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(dim // 16),
            nn.ReLU(inplace=True),
            nn.Dropout3d(self.dropout_rate),
            nn.MaxPool3d((2, 2, 2), stride=(2, 2, 2)),  # halves time and space
        )  # (b, dim//16, 80, 48, 48)

        self.Stem1 = nn.Sequential(
            nn.Conv3d(dim // 16, dim // 8, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(dim // 8),
            nn.ReLU(inplace=True),
            nn.Dropout3d(self.dropout_rate),
            nn.MaxPool3d((2, 2, 2), stride=(2, 2, 2)),
        )  # (b, dim//8, 40, 24, 24)
        self.Stem2 = nn.Sequential(
            nn.Conv3d(dim // 8, dim // 4, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(dim // 4),
            nn.ReLU(inplace=True),
            nn.Dropout3d(self.dropout_rate),
            nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),  # spatial only, time kept
        )  # (b, dim//4, 40, 12, 12)
        self.Stem3 = nn.Sequential(
            nn.Conv3d(dim // 4, dim // 2, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(dim // 2),
            nn.ReLU(inplace=True),
            nn.Dropout3d(self.dropout_rate),
            nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
        )  # (b, dim//2, 40, 6, 6)

        self.Stem4 = nn.Sequential(
            nn.Conv3d(dim // 2, dim, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(dim),
            nn.ReLU(inplace=True),
            nn.Dropout3d(self.dropout_rate),
            nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
        )  # (b, dim, 40, 3, 3)

        # Learnable positional embedding: one vector per token of the 40x3x3 grid.
        self.pos_embedding = nn.Parameter(torch.randn(1, 40 * 3 * 3, dim))
        # The CNN stem acts as the token embedding for the transformers.
        self.transformer1 = Transformer_ST_TDC_gra_sharp(num_layers=config.NUM_LAYERS, dim=dim, num_heads=config.NUM_HEADS,
                                                         ff_dim=config.FF_DIM, dropout=self.dropout_rate, theta=config.THETA)  # (b, 3*3*40, dim)
        self.transformer2 = Transformer_ST_TDC_gra_sharp(num_layers=config.NUM_LAYERS, dim=dim, num_heads=config.NUM_HEADS,
                                                         ff_dim=config.FF_DIM, dropout=self.dropout_rate, theta=config.THETA)  # (b, 3*3*40, dim)
        # Normalisation applied after the stem -> transformer residual connection.
        self.post_res_norm = nn.InstanceNorm3d(dim, affine=True)

        # Decoder: restore the temporal resolution. Each stage upsamples the time axis
        # by 2 (nearest neighbour, space unchanged) and applies a temporal 3x1x1 conv.
        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=(2, 1, 1)),
            nn.Conv3d(dim, dim, [3, 1, 1], stride=1, padding=(1, 0, 0)),
            nn.BatchNorm3d(dim),
            nn.ELU(),
            nn.Dropout3d(self.dropout_rate),
        )  # (b, dim, 80, 3, 3)
        self.upsample2 = nn.Sequential(
            nn.Upsample(scale_factor=(2, 1, 1)),
            nn.Conv3d(dim, dim // 2, [3, 1, 1], stride=1, padding=(1, 0, 0)),
            nn.BatchNorm3d(dim // 2),
            nn.ELU(),
            nn.Dropout3d(self.dropout_rate),
        )  # (b, dim // 2, 160, 3, 3)
        # A kernel-size-1 Conv1d is a per-frame linear classifier: 2 logits per frame
        # (no softmax is applied here).
        self.ConvBlockLast = nn.Conv1d(dim // 2, 2, 1, stride=1, padding=0)

        self.init_weights()

    @torch.no_grad()
    def init_weights(self):
        """Xavier-uniform init for every nn.Linear weight and a near-zero normal bias.

        Other layer types keep PyTorch's default initialisation.
        """
        def _init(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)

        self.apply(_init)

    def forward(self, x):
        """Map a (B, in_channels, 160, 96, 96) clip to (B, 2, 160) per-frame logits."""
        if x.dim() != 5:
            raise ValueError(f"expected a 5-D (B, C, T, H, W) input, got shape {tuple(x.shape)}")
        b = x.shape[0]
        x = self.Stem0(x)  # (B, dim//16, 80, 48, 48)
        x = self.Stem1(x)  # (B, dim//8, 40, 24, 24)
        x = self.Stem2(x)  # (B, dim//4, 40, 12, 12)
        x = self.Stem3(x)  # (B, dim//2, 40, 6, 6)
        x_stem4 = self.Stem4(x)  # (B, dim, 40, 3, 3)

        # (B, dim, 40, 3, 3) -> (B, 360, dim) tokens. The positional embedding is added
        # out of place: flatten/transpose return views of x_stem4, so an in-place `+=`
        # would also write the embedding into x_stem4 and leak it into the residual below.
        x = x_stem4.flatten(2).transpose(1, 2) + self.pos_embedding
        x, _ = self.transformer1(x, self.gra_sharp)  # (B, 360, dim)
        x, _ = self.transformer2(x, self.gra_sharp)  # (B, 360, dim)
        x = x.transpose(1, 2).view(b, self.dim, 40, 3, 3)  # (B, dim, 40, 3, 3)

        # Residual connection from the raw Stem4 output, then normalise.
        x = x + x_stem4
        x = self.post_res_norm(x)

        x = self.upsample(x)   # (B, dim, 80, 3, 3)
        x = self.upsample2(x)  # (B, dim // 2, 160, 3, 3)

        # Average over the 3x3 spatial grid.
        x = torch.mean(x, 3)  # (B, dim // 2, 160, 3)
        x = torch.mean(x, 3)  # (B, dim // 2, 160)
        logits = self.ConvBlockLast(x)  # (B, 2, 160)
        return logits


## ThermRNet w/o skip connection (ablation: no residual from the Stem4 output to the transformer output)

In [None]:
class ThermRNet_WO_SkipConnection(nn.Module):
    """
    Ablation of ThermRNet without the stem-to-transformer skip connection
    (and therefore without the post-residual InstanceNorm).

    Shapes (B = batch):
        input   (B, in_channels, 160, 96, 96)
        stem    (B, dim, 40, 3, 3)       -> 360 tokens of width ``dim``
        decoder (B, dim // 2, 160, 3, 3)
        output  (B, 2, 160)              per-frame logits for 2 classes

    The 40x3x3 token grid is hard-coded; ``frame`` and ``image_size`` are stored only.
    Transformer settings and the attention temperature come from the notebook-level
    ``config`` object.
    """

    def __init__(
            self,
            dim: int = 768,
            dropout_rate: float = 0.2,
            in_channels: int = 3,
            frame: int = 160,
            image_size: Optional[int] = 96,
    ):
        """
        Args:
            dim: token width; the stem widens dim//16 -> dim//8 -> dim//4 -> dim//2 -> dim.
            dropout_rate: probability for every Dropout3d and for the transformer dropout.
            in_channels: channels of the input clip (used by the first stem convolution).
            frame: stored only.
            image_size: stored only.
        """
        super().__init__()
        self.dropout_rate = dropout_rate
        self.image_size = image_size
        self.frame = frame
        self.dim = dim
        self.gra_sharp = config.GRA_SHARP  # attention temperature for both transformers

        # Stem: (B, in_channels, 160, 96, 96) -> (B, dim, 40, 3, 3)
        self.Stem0 = nn.Sequential(
            nn.Conv3d(in_channels, dim // 16, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(dim // 16),
            nn.ReLU(inplace=True),
            nn.Dropout3d(self.dropout_rate),
            nn.MaxPool3d((2, 2, 2), stride=(2, 2, 2)),  # halves time and space
        )  # (b, dim//16, 80, 48, 48)

        self.Stem1 = nn.Sequential(
            nn.Conv3d(dim // 16, dim // 8, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(dim // 8),
            nn.ReLU(inplace=True),
            nn.Dropout3d(self.dropout_rate),
            nn.MaxPool3d((2, 2, 2), stride=(2, 2, 2)),
        )  # (b, dim//8, 40, 24, 24)
        self.Stem2 = nn.Sequential(
            nn.Conv3d(dim // 8, dim // 4, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(dim // 4),
            nn.ReLU(inplace=True),
            nn.Dropout3d(self.dropout_rate),
            nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),  # spatial only, time kept
        )  # (b, dim//4, 40, 12, 12)
        self.Stem3 = nn.Sequential(
            nn.Conv3d(dim // 4, dim // 2, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(dim // 2),
            nn.ReLU(inplace=True),
            nn.Dropout3d(self.dropout_rate),
            nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
        )  # (b, dim//2, 40, 6, 6)

        self.Stem4 = nn.Sequential(
            nn.Conv3d(dim // 2, dim, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(dim),
            nn.ReLU(inplace=True),
            nn.Dropout3d(self.dropout_rate),
            nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
        )  # (b, dim, 40, 3, 3)

        # Learnable positional embedding: one vector per token of the 40x3x3 grid.
        self.pos_embedding = nn.Parameter(torch.randn(1, 40 * 3 * 3, dim))
        # The CNN stem acts as the token embedding for the transformers.
        self.transformer1 = Transformer_ST_TDC_gra_sharp(num_layers=config.NUM_LAYERS, dim=dim, num_heads=config.NUM_HEADS,
                                                         ff_dim=config.FF_DIM, dropout=self.dropout_rate, theta=config.THETA)  # (b, 3*3*40, dim)
        self.transformer2 = Transformer_ST_TDC_gra_sharp(num_layers=config.NUM_LAYERS, dim=dim, num_heads=config.NUM_HEADS,
                                                         ff_dim=config.FF_DIM, dropout=self.dropout_rate, theta=config.THETA)  # (b, 3*3*40, dim)

        # Decoder: each stage upsamples the time axis by 2 (nearest neighbour, space
        # unchanged) and applies a temporal 3x1x1 conv.
        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=(2, 1, 1)),
            nn.Conv3d(dim, dim, [3, 1, 1], stride=1, padding=(1, 0, 0)),
            nn.BatchNorm3d(dim),
            nn.ELU(),
            nn.Dropout3d(self.dropout_rate),
        )  # (b, dim, 80, 3, 3)
        self.upsample2 = nn.Sequential(
            nn.Upsample(scale_factor=(2, 1, 1)),
            nn.Conv3d(dim, dim // 2, [3, 1, 1], stride=1, padding=(1, 0, 0)),
            nn.BatchNorm3d(dim // 2),
            nn.ELU(),
            nn.Dropout3d(self.dropout_rate),
        )  # (b, dim // 2, 160, 3, 3)
        # Kernel-size-1 Conv1d = per-frame linear classifier producing 2 logits per frame.
        self.ConvBlockLast = nn.Conv1d(dim // 2, 2, 1, stride=1, padding=0)

        self.init_weights()

    @torch.no_grad()
    def init_weights(self):
        """Xavier-uniform init for every nn.Linear weight and a near-zero normal bias.

        Other layer types keep PyTorch's default initialisation.
        """
        def _init(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)

        self.apply(_init)

    def forward(self, x):
        """Map a (B, in_channels, 160, 96, 96) clip to (B, 2, 160) per-frame logits."""
        if x.dim() != 5:
            raise ValueError(f"expected a 5-D (B, C, T, H, W) input, got shape {tuple(x.shape)}")
        b = x.shape[0]
        x = self.Stem0(x)  # (B, dim//16, 80, 48, 48)
        x = self.Stem1(x)  # (B, dim//8, 40, 24, 24)
        x = self.Stem2(x)  # (B, dim//4, 40, 12, 12)
        x = self.Stem3(x)  # (B, dim//2, 40, 6, 6)
        x = self.Stem4(x)  # (B, dim, 40, 3, 3)

        # Tokens (B, 360, dim); out-of-place add so the Stem4 output is not mutated.
        x = x.flatten(2).transpose(1, 2) + self.pos_embedding
        x, _ = self.transformer1(x, self.gra_sharp)  # (B, 360, dim)
        x, _ = self.transformer2(x, self.gra_sharp)  # (B, 360, dim)
        x = x.transpose(1, 2).view(b, self.dim, 40, 3, 3)  # (B, dim, 40, 3, 3)

        x = self.upsample(x)   # (B, dim, 80, 3, 3)
        x = self.upsample2(x)  # (B, dim // 2, 160, 3, 3)

        # Average over the 3x3 spatial grid.
        x = torch.mean(x, 3)  # (B, dim // 2, 160, 3)
        x = torch.mean(x, 3)  # (B, dim // 2, 160)
        logits = self.ConvBlockLast(x)  # (B, 2, 160)

        return logits

## ThermRNet w/o Stem (ablation: a single strided-conv patch embedding replaces the five-stage CNN stem)

In [122]:
class ThermRNet_WO_Stem(nn.Module):
    """
    Ablation of ThermRNet in which the five-stage CNN stem is replaced by a single
    strided patch-embedding convolution followed by max-pooling.

    Shapes (B = batch):
        input   (B, in_channels, 160, 96, 96)
        stem    (B, dim, 40, 3, 3)       -> 360 tokens of width ``dim``
        decoder (B, dim // 2, 160, 3, 3)
        output  (B, 2, 160)              per-frame logits for 2 classes

    The 40x3x3 token grid is hard-coded; ``frame`` and ``image_size`` are stored only.
    Transformer settings and the attention temperature come from the notebook-level
    ``config`` object.
    """

    def __init__(
            self,
            dim: int = 768,
            dropout_rate: float = 0.2,
            in_channels: int = 3,
            frame: int = 160,
            image_size: Optional[int] = 96,
    ):
        """
        Args:
            dim: token width produced by the patch embedding.
            dropout_rate: probability for every Dropout3d and for the transformer dropout.
            in_channels: channels of the input clip (used by the patch-embedding conv).
            frame: stored only.
            image_size: stored only.
        """
        super().__init__()
        self.dropout_rate = dropout_rate
        self.image_size = image_size
        self.frame = frame
        self.dim = dim
        self.gra_sharp = config.GRA_SHARP  # attention temperature for both transformers

        # Patch embedding: the (2, 8, 8) strided conv gives (B, dim, 80, 12, 12); the
        # (2, 4, 4) max-pool then halves time and divides each spatial side by 4.
        self.Stem0 = nn.Sequential(
            nn.Conv3d(in_channels, dim, [2, 8, 8], stride=(2, 8, 8), padding=0),  # (b, dim, 80, 12, 12)
            nn.BatchNorm3d(dim),
            nn.ReLU(inplace=True),
            nn.Dropout3d(self.dropout_rate),
            nn.MaxPool3d((2, 4, 4), stride=(2, 4, 4)),
        )  # (b, dim, 40, 3, 3)

        # Learnable positional embedding: one vector per token of the 40x3x3 grid.
        self.pos_embedding = nn.Parameter(torch.randn(1, 40 * 3 * 3, dim))
        self.transformer1 = Transformer_ST_TDC_gra_sharp(num_layers=config.NUM_LAYERS, dim=dim, num_heads=config.NUM_HEADS,
                                                         ff_dim=config.FF_DIM, dropout=self.dropout_rate, theta=config.THETA)  # (b, 3*3*40, dim)
        self.transformer2 = Transformer_ST_TDC_gra_sharp(num_layers=config.NUM_LAYERS, dim=dim, num_heads=config.NUM_HEADS,
                                                         ff_dim=config.FF_DIM, dropout=self.dropout_rate, theta=config.THETA)  # (b, 3*3*40, dim)
        # Normalisation applied after the stem -> transformer residual connection.
        self.post_res_norm = nn.InstanceNorm3d(dim, affine=True)

        # Decoder: each stage upsamples the time axis by 2 (nearest neighbour, space
        # unchanged) and applies a temporal 3x1x1 conv.
        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=(2, 1, 1)),
            nn.Conv3d(dim, dim, [3, 1, 1], stride=1, padding=(1, 0, 0)),
            nn.BatchNorm3d(dim),
            nn.ELU(),
            nn.Dropout3d(self.dropout_rate),
        )  # (b, dim, 80, 3, 3)
        self.upsample2 = nn.Sequential(
            nn.Upsample(scale_factor=(2, 1, 1)),
            nn.Conv3d(dim, dim // 2, [3, 1, 1], stride=1, padding=(1, 0, 0)),
            nn.BatchNorm3d(dim // 2),
            nn.ELU(),
            nn.Dropout3d(self.dropout_rate),
        )  # (b, dim // 2, 160, 3, 3)
        # Kernel-size-1 Conv1d = per-frame linear classifier producing 2 logits per frame.
        self.ConvBlockLast = nn.Conv1d(dim // 2, 2, 1, stride=1, padding=0)

        self.init_weights()

    @torch.no_grad()
    def init_weights(self):
        """Xavier-uniform init for every nn.Linear weight and a near-zero normal bias.

        Other layer types keep PyTorch's default initialisation.
        """
        def _init(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)

        self.apply(_init)

    def forward(self, x):
        """Map a (B, in_channels, 160, 96, 96) clip to (B, 2, 160) per-frame logits."""
        if x.dim() != 5:
            raise ValueError(f"expected a 5-D (B, C, T, H, W) input, got shape {tuple(x.shape)}")
        b = x.shape[0]
        x_stem0 = self.Stem0(x)  # (B, dim, 40, 3, 3)

        # Tokens (B, 360, dim). Out-of-place add: flatten/transpose are views of x_stem0,
        # so `+=` would write the positional embedding into x_stem0 and leak it into
        # the residual connection below.
        x = x_stem0.flatten(2).transpose(1, 2) + self.pos_embedding
        x, _ = self.transformer1(x, self.gra_sharp)  # (B, 360, dim)
        x, _ = self.transformer2(x, self.gra_sharp)  # (B, 360, dim)
        x = x.transpose(1, 2).view(b, self.dim, 40, 3, 3)  # (B, dim, 40, 3, 3)

        # Residual connection from the raw Stem0 output, then normalise.
        x = x + x_stem0
        x = self.post_res_norm(x)

        x = self.upsample(x)   # (B, dim, 80, 3, 3)
        x = self.upsample2(x)  # (B, dim // 2, 160, 3, 3)

        # Average over the 3x3 spatial grid.
        x = torch.mean(x, 3)  # (B, dim // 2, 160, 3)
        x = torch.mean(x, 3)  # (B, dim // 2, 160)
        logits = self.ConvBlockLast(x)  # (B, 2, 160)
        return logits


## ThermRNet w/o TDT (ablation: CNN stem and temporal decoder only, no temporal-difference transformer)

In [123]:
class ThermRNet_WO_TDT(nn.Module):
    """
    Ablation of ThermRNet without the temporal-difference transformers: the CNN stem
    feeds the temporal decoder directly.

    Shapes (B = batch):
        input   (B, in_channels, 160, 96, 96)
        stem    (B, dim, 40, 3, 3)
        decoder (B, dim // 2, 160, 3, 3)
        output  (B, 2, 160)              per-frame logits for 2 classes

    The stem is identical to the other variants (padding=1 everywhere), so the only
    difference from ThermRNet is the missing transformer path. ``gra_sharp`` is still
    read from ``config`` for attribute compatibility but is not used.
    """

    def __init__(
            self,
            dim: int = 768,
            dropout_rate: float = 0.2,
            in_channels: int = 3,
            frame: int = 160,
            image_size: Optional[int] = 96,
    ):
        """
        Args:
            dim: width of the last stem stage; the stem widens dim//16 -> ... -> dim.
            dropout_rate: probability for every Dropout3d.
            in_channels: channels of the input clip (used by the first stem convolution).
            frame: stored only.
            image_size: stored only.
        """
        super().__init__()
        self.dropout_rate = dropout_rate
        self.image_size = image_size
        self.frame = frame
        self.dim = dim
        self.gra_sharp = config.GRA_SHARP  # kept for parity with the other variants

        # Stem: (B, in_channels, 160, 96, 96) -> (B, dim, 40, 3, 3)
        self.Stem0 = nn.Sequential(
            # padding=1 keeps (T, H, W) before pooling, matching the other variants
            # (padding=2 gave an 81x49x49 map instead of 80x48x48).
            nn.Conv3d(in_channels, dim // 16, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(dim // 16),
            nn.ReLU(inplace=True),
            nn.Dropout3d(self.dropout_rate),
            nn.MaxPool3d((2, 2, 2), stride=(2, 2, 2)),  # halves time and space
        )  # (b, dim//16, 80, 48, 48)

        self.Stem1 = nn.Sequential(
            nn.Conv3d(dim // 16, dim // 8, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(dim // 8),
            nn.ReLU(inplace=True),
            nn.Dropout3d(self.dropout_rate),
            nn.MaxPool3d((2, 2, 2), stride=(2, 2, 2)),
        )  # (b, dim//8, 40, 24, 24)
        self.Stem2 = nn.Sequential(
            nn.Conv3d(dim // 8, dim // 4, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(dim // 4),
            nn.ReLU(inplace=True),
            nn.Dropout3d(self.dropout_rate),
            nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),  # spatial only, time kept
        )  # (b, dim//4, 40, 12, 12)
        self.Stem3 = nn.Sequential(
            nn.Conv3d(dim // 4, dim // 2, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(dim // 2),
            nn.ReLU(inplace=True),
            nn.Dropout3d(self.dropout_rate),
            nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
        )  # (b, dim//2, 40, 6, 6)

        self.Stem4 = nn.Sequential(
            nn.Conv3d(dim // 2, dim, [3, 3, 3], stride=1, padding=1),
            nn.BatchNorm3d(dim),
            nn.ReLU(inplace=True),
            nn.Dropout3d(self.dropout_rate),
            nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
        )  # (b, dim, 40, 3, 3)

        # Decoder: each stage upsamples the time axis by 2 (nearest neighbour, space
        # unchanged) and applies a temporal 3x1x1 conv.
        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=(2, 1, 1)),
            nn.Conv3d(dim, dim, [3, 1, 1], stride=1, padding=(1, 0, 0)),
            nn.BatchNorm3d(dim),
            nn.ELU(),
            nn.Dropout3d(self.dropout_rate),
        )  # (b, dim, 80, 3, 3)
        self.upsample2 = nn.Sequential(
            nn.Upsample(scale_factor=(2, 1, 1)),
            nn.Conv3d(dim, dim // 2, [3, 1, 1], stride=1, padding=(1, 0, 0)),
            nn.BatchNorm3d(dim // 2),
            nn.ELU(),
            nn.Dropout3d(self.dropout_rate),
        )  # (b, dim // 2, 160, 3, 3)
        # Kernel-size-1 Conv1d = per-frame linear classifier producing 2 logits per frame.
        self.ConvBlockLast = nn.Conv1d(dim // 2, 2, 1, stride=1, padding=0)

        self.init_weights()

    @torch.no_grad()
    def init_weights(self):
        """Xavier-uniform init for every nn.Linear weight and a near-zero normal bias.

        Other layer types keep PyTorch's default initialisation.
        """
        def _init(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)

        self.apply(_init)

    def forward(self, x):
        """Map a (B, in_channels, 160, 96, 96) clip to (B, 2, 160) per-frame logits."""
        if x.dim() != 5:
            raise ValueError(f"expected a 5-D (B, C, T, H, W) input, got shape {tuple(x.shape)}")
        x = self.Stem0(x)  # (B, dim//16, 80, 48, 48)
        x = self.Stem1(x)  # (B, dim//8, 40, 24, 24)
        x = self.Stem2(x)  # (B, dim//4, 40, 12, 12)
        x = self.Stem3(x)  # (B, dim//2, 40, 6, 6)
        x = self.Stem4(x)  # (B, dim, 40, 3, 3)

        x = self.upsample(x)   # (B, dim, 80, 3, 3)
        x = self.upsample2(x)  # (B, dim // 2, 160, 3, 3)

        # Average over the 3x3 spatial grid.
        x = torch.mean(x, 3)  # (B, dim // 2, 160, 3)
        x = torch.mean(x, 3)  # (B, dim // 2, 160)
        logits = self.ConvBlockLast(x)  # (B, 2, 160)

        return logits

## Dataset

In [None]:
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image, to_tensor


# Per-channel normalisation constants: (x - 0.5) / 0.5 maps [0, 1] onto [-1, 1].
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]


def _scale_to_unit_range(frames):
    """Scale raw 0-255 pixel values to [0, 1] (stands in for ToTensor's scaling on tensors)."""
    return frames / 255.0


# Training augmentation. The pipeline is applied once to a whole (T, C, H, W) clip, so
# a single random rotation / crop / flip is drawn per clip and shared by all frames.
# (An earlier PIL-based per-frame pipeline with ColorJitter was dropped.)
transform = transforms.Compose([
    transforms.Lambda(_scale_to_unit_range),
    transforms.RandomRotation(10),             # random angle in [-10, 10] degrees
    transforms.Resize(108, antialias=True),
    transforms.RandomCrop(96),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize(mean=mean, std=std),
])

# Validation / test pipeline: deterministic resize and normalisation only.
transform_valid = transforms.Compose([
    transforms.Lambda(_scale_to_unit_range),
    transforms.Resize(96, antialias=True),
    transforms.Normalize(mean=mean, std=std),
])


def standardized_data(data):
    """Z-score standardization for video data, using the global mean and std.

    A constant input (std == 0) yields 0/0; the resulting NaNs are replaced by 0.
    If the input itself contains NaN, the mean is NaN and every element becomes 0.
    The floating-point warnings that division would emit in these cases are
    suppressed because they are handled explicitly.

    Args:
        data (np.ndarray): array of any shape.

    Returns:
        np.ndarray: float array with the same shape as ``data``.
    """
    with np.errstate(invalid='ignore', divide='ignore'):
        data = data - np.mean(data)
        data = data / np.std(data)
    data[np.isnan(data)] = 0
    return data


def apply_transform_to_video(video, transform):
    """
    Apply one transform call to a whole clip, so every frame gets identical
    augmentation parameters (rotation angle, crop window, flip, ...).

    Args:
        video (np.ndarray): clip laid out as (T, H, W, C); any numeric dtype, cast to float32.
        transform (callable or None): takes a float tensor of shape (T, C, H, W) and
            returns a tensor with the same (T, C, H', W') layout. ``None`` applies no transform.

    Returns:
        torch.Tensor: transformed clip with shape (C, T, H', W').
    """
    assert isinstance(video, np.ndarray)

    # torch.from_numpy rejects negative strides (e.g. a time-reversed view video[::-1]),
    # so make the array contiguous first; this is a no-op for already-contiguous input.
    video_tensor = torch.from_numpy(np.ascontiguousarray(video)).permute(0, 3, 1, 2).float()  # (T, C, H, W)

    # One call for the whole clip: random parameters are shared by all frames.
    transformed = video_tensor if transform is None else transform(video_tensor)

    return transformed.permute(1, 0, 2, 3)  # (C, T, H, W)
    
    

class MyCustomDataset(Dataset):
    """Training dataset of pre-chunked clips stored as ``*input*.npy`` / ``*label*.npy`` files.

    Each item is ``(data, label, filename, chunk_id)``:
      * ``data``  - the clip after ``apply_transform_to_video`` (channels first, then time);
      * ``label`` - the label array loaded from the matching file, cast to int64;
      * ``filename`` / ``chunk_id`` - placeholders (always 1); parsing them from the
        file path is skipped to save time.
    """

    def __init__(self, dataset_name, dataset_dir, file_path_list_path=None, transform=None, shuffle_prob=0.1):
        """
        Args:
            dataset_name (str): Free-form name stored on the instance.
            dataset_dir (str): Directory holding the ``input`` / ``label`` ``.npy`` files.
            file_path_list_path (str, optional): Stored on the instance; not used here.
            transform (callable): Passed to ``apply_transform_to_video`` for every clip.
            shuffle_prob (float): Probability that a clip's frame order (and its labels, with
                the same permutation) is randomly shuffled as augmentation. Defaults to 0.1.

        Raises:
            ValueError: If the number of input files differs from the number of label files.
        """
        self.dataset_name = dataset_name
        self.dataset_dir = dataset_dir
        self.file_path_list_path = file_path_list_path
        self.transform = transform
        self.shuffle_prob = shuffle_prob
        self.data_list = os.listdir(dataset_dir)
        # Inputs and labels are paired by position after sorting. This is only correct if the
        # two file families share names apart from the 'input'/'label' tag
        # (e.g. s1_T1_1_input0.npy <-> s1_T1_1_label0.npy) - assumed from the dataset layout.
        self.inputs = sorted(os.path.join(dataset_dir, x) for x in self.data_list if 'input' in x)
        self.labels = sorted(os.path.join(dataset_dir, x) for x in self.data_list if 'label' in x)
        if len(self.inputs) != len(self.labels):
            raise ValueError(
                f"Found {len(self.inputs)} input files but {len(self.labels)} label files in {dataset_dir}"
            )

    def __len__(self):
        """Return the number of clips (input files) in the dataset."""
        return len(self.inputs)

    def __getitem__(self, index):
        """Load clip ``index``, optionally shuffle its frames, and apply the transform.

        Returns:
            tuple: ``(data, label, filename, chunk_id)`` as described on the class.
        """
        data = np.load(self.inputs[index])
        label = np.load(self.labels[index])

        # Augmentation: with probability ``shuffle_prob``, permute the temporal order of the
        # clip and apply the same permutation to the label array.
        if random.random() < self.shuffle_prob:
            order = list(range(len(data)))
            random.shuffle(order)
            data = data[order]
            label = label[order]

        data = apply_transform_to_video(data, self.transform)
        # nn.CrossEntropyLoss expects integer (int64) class indices as targets.
        label = label.astype(np.int64)

        # Placeholders: recovering the subject file name / chunk index from the path is skipped.
        filename = chunk_id = 1
        return data, label, filename, chunk_id

class MyCustomTestDataset(Dataset):
    """Evaluation dataset of pre-chunked clips stored as ``*input*.npy`` / ``*label*.npy`` files.

    Unlike the training dataset there is no random frame shuffling. Each item is
    ``(data, label, filename, chunk_id)``:
      * ``data``  - the clip after ``apply_transform_to_video`` (channels first, then time);
      * ``label`` - the label array loaded from the matching file, cast to int64;
      * ``filename`` / ``chunk_id`` - placeholders (always 1).
    """

    def __init__(self, dataset_name, dataset_dir, file_path_list_path=None, transform=None):
        """
        Args:
            dataset_name (str): Free-form name stored on the instance.
            dataset_dir (str): Directory holding the ``input`` / ``label`` ``.npy`` files.
            file_path_list_path (str, optional): Stored on the instance; not used here.
            transform (callable): Passed to ``apply_transform_to_video`` for every clip.

        Raises:
            ValueError: If the number of input files differs from the number of label files.
        """
        self.dataset_name = dataset_name
        self.dataset_dir = dataset_dir
        self.file_path_list_path = file_path_list_path
        self.transform = transform
        self.data_list = os.listdir(dataset_dir)
        # Inputs and labels are paired by position after sorting. This is only correct if the
        # two file families share names apart from the 'input'/'label' tag - assumed from the
        # dataset layout.
        self.inputs = sorted(os.path.join(dataset_dir, x) for x in self.data_list if 'input' in x)
        self.labels = sorted(os.path.join(dataset_dir, x) for x in self.data_list if 'label' in x)
        if len(self.inputs) != len(self.labels):
            raise ValueError(
                f"Found {len(self.inputs)} input files but {len(self.labels)} label files in {dataset_dir}"
            )

    def __len__(self):
        """Return the number of clips (input files) in the dataset."""
        return len(self.inputs)

    def __getitem__(self, index):
        """Load clip ``index`` and apply the transform.

        Returns:
            tuple: ``(data, label, filename, chunk_id)`` as described on the class.
        """
        data = np.load(self.inputs[index])
        label = np.load(self.labels[index])

        data = apply_transform_to_video(data, self.transform)
        # nn.CrossEntropyLoss expects integer (int64) class indices as targets.
        label = label.astype(np.int64)

        # Placeholders: recovering the subject file name / chunk index from the path is skipped.
        filename = chunk_id = 1
        return data, label, filename, chunk_id


In [130]:
# Instantiate every ThermRNet variant with the shared config; print its structure and size.
for model_cls in (ThermRNet, ThermRNet_WO_SkipConnection, ThermRNet_WO_Stem, ThermRNet_WO_TDT):
    net = model_cls(dim=config.DIM, dropout_rate=config.DROP_RATE, frame=config.CHUNK_LENGTH, image_size=config.RESIZE_W)
    print(net)
    total_params = sum(p.numel() for p in net.parameters())
    # Label the count so the four numbers can be told apart in the output.
    print(f"{model_cls.__name__} total parameters: {total_params}")

ThermRNet(
  (Stem0): Sequential(
    (0): Conv3d(3, 6, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): BatchNorm3d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Dropout3d(p=0.2, inplace=False)
    (4): MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (Stem1): Sequential(
    (0): Conv3d(6, 12, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): BatchNorm3d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Dropout3d(p=0.2, inplace=False)
    (4): MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (Stem2): Sequential(
    (0): Conv3d(12, 24, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): BatchNorm3d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Dropout3d(p=0.2, inpla

## Train ThermRNet

In [131]:
class MyTrainer:
    """Train / validate a classifier with CrossEntropyLoss, AdamW and a per-batch OneCycleLR.

    The model's output must have the class dimension at index 1, as ``nn.CrossEntropyLoss``
    requires. For example, ``(N, 2, T)`` logits for ``(N, T)`` integer labels.
    """

    # Sub-modules that stay trainable when ``frozen_feature_extractor(True)`` is called.
    _HEAD_MODULE_NAMES = ("ConvBlockLast", "upsample2", "upsample", "transformer2", "transformer1")

    def __init__(self, config, data_loader, model_file_name, model, save_file):
        """
        Args:
            config: Object exposing DEVICE, EPOCHS, MODEL_DIR, TRAIN_BATCH_SIZE, LR and TOOLBOX_MODE.
            data_loader (dict): Maps "train" / "valid" to data loaders (or None).
            model_file_name (str): Prefix used for checkpoint file names.
            model (nn.Module): Network to train; moved to ``config.DEVICE``.
            save_file (str): CSV path for the per-epoch metrics written by ``train``.
        """
        self.device = torch.device(config.DEVICE)
        self.max_epoch_num = config.EPOCHS
        self.model_dir = config.MODEL_DIR
        self.model_file_name = model_file_name
        self.batch_size = config.TRAIN_BATCH_SIZE
        self.config = config
        self.min_valid_loss = None
        self.best_epoch = 0
        self.data_loader = data_loader
        self.save_file = save_file
        # Keep the model in every mode, not only "train_and_test". Otherwise
        # save_model / valid / load_pretrain_model fail with AttributeError in other modes.
        self.model = model.to(self.device)
        self.criterion = nn.CrossEntropyLoss()  # takes raw logits; applies softmax internally
        self.optimizer = None
        self.scheduler = None
        self.num_train_batches = 0
        if config.TOOLBOX_MODE == "train_and_test":
            self._build_optimization()

    def _build_optimization(self):
        """(Re)create the AdamW optimizer and the per-batch OneCycleLR schedule for ``self.model``."""
        self.num_train_batches = len(self.data_loader["train"])
        self.optimizer = optim.AdamW(self.model.parameters(), lr=self.config.LR, weight_decay=0)
        # OneCycleLR is stepped once per batch, so it needs the number of batches per epoch.
        # https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html
        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=self.config.LR,
            epochs=self.config.EPOCHS,
            steps_per_epoch=self.num_train_batches,
        )

    def train(self, data_loader):
        """Train for ``config.EPOCHS`` epochs, validating after each epoch.

        Checkpoints are written with ``save_model``:
          * ``best``    - whenever the mean training loss is <= the best seen so far;
          * ``valBest`` - whenever the mean validation loss is <= the best seen so far;
          * ``config.EPOCHS`` - once, after the last epoch.
        Per-epoch accuracy and loss are written to ``self.save_file`` as CSV.

        If no optimizer exists yet (trainer built outside "train_and_test" mode), one is
        created from ``self.data_loader["train"]``.

        Raises:
            ValueError: If ``data_loader["train"]`` (or ``data_loader["valid"]``) is None.
        """
        if data_loader["train"] is None:
            raise ValueError("No data for train")
        if self.optimizer is None:
            self._build_optimization()

        min_val_loss = float("inf")
        min_running_loss = float("inf")
        history = []  # one dict per epoch, turned into a DataFrame once at the end

        for epoch in range(self.max_epoch_num):
            print('')
            print(f"====Training Epoch: {epoch}====")
            train_loss, train_acc = self._train_one_epoch(data_loader["train"], epoch)
            val_loss, val_acc = self.valid(data_loader)

            if min_running_loss >= train_loss:
                print(f"Train Loss improved from {min_running_loss:.3f} to {train_loss:.3f}")
                min_running_loss = train_loss
                self.save_model('best')
            if min_val_loss >= val_loss:
                print(f"Val Loss improved from {min_val_loss} to {val_loss}")
                min_val_loss = val_loss
                self.save_model('valBest')

            history.append({
                "epoch": epoch,
                "train_acc": train_acc,
                "train_loss": train_loss,
                "val_acc": val_acc,
                "val_loss": val_loss,
            })

        self.save_model(self.max_epoch_num)
        df = pd.DataFrame(history, columns=["epoch", "train_acc", "train_loss", "val_acc", "val_loss"])
        df.to_csv(self.save_file, index=False)

    def _train_one_epoch(self, train_loader, epoch):
        """Run one optimization pass over ``train_loader``; return (mean batch loss, accuracy in %)."""
        self.model.train()
        batch_losses = []
        correct = 0
        total_size = 0
        for idx, batch in enumerate(train_loader):
            data, labels = batch[0].to(self.device), batch[1].to(self.device)

            self.optimizer.zero_grad()
            pred = self.model(data)  # e.g. (b, 2, 160) logits for (b, 160) labels
            loss = self.criterion(pred, labels)
            loss.backward()
            self.optimizer.step()
            if self.scheduler is not None:
                self.scheduler.step()  # OneCycleLR advances once per batch

            # numel() counts every labelled element (b * T for (b, T) labels).
            total_size += labels.numel()
            _, predicted = torch.max(pred, 1)
            correct += (predicted == labels).sum().item()

            loss_value = loss.item()
            batch_losses.append(loss_value)
            if idx % 30 == 29:  # log the running mean loss every 30 mini-batches
                print(f'[{epoch}, {idx + 1:5d}] loss: {sum(batch_losses) / (idx + 1):.3f}')

        train_acc = 100 * correct / total_size
        print(f"Train Avg Accuracy: {train_acc:>0.1f}%")
        train_loss = np.mean(np.asarray(batch_losses))
        print(f"Train Avg loss: {train_loss}")
        return train_loss, train_acc

    def valid(self, data_loader):
        """Evaluate on ``data_loader["valid"]`` without gradients.

        Returns:
            tuple: (mean batch loss, accuracy in %).

        Raises:
            ValueError: If ``data_loader["valid"]`` is None or yields no batches.
        """
        if data_loader["valid"] is None:
            raise ValueError("No data for valid")

        print('')
        print("===Validating===")
        self.model.eval()
        batch_losses = []
        total_size = 0
        correct = 0
        with torch.no_grad():
            for valid_batch in data_loader["valid"]:
                data_valid = valid_batch[0].to(self.device)
                labels_valid = valid_batch[1].to(self.device)
                pred_valid = self.model(data_valid)
                batch_losses.append(self.criterion(pred_valid, labels_valid).item())
                total_size += labels_valid.numel()
                _, predicted = torch.max(pred_valid, 1)
                correct += (predicted == labels_valid).sum().item()

        if not batch_losses:
            raise ValueError("Validation data loader yielded no batches")
        valid_loss = np.mean(np.asarray(batch_losses))
        valid_acc = 100 * correct / total_size
        print(f"valid Avg Accuracy: {valid_acc:>0.1f}%")
        print(f"valid Avg loss: {valid_loss}")
        return valid_loss, valid_acc

    def save_model(self, index):
        """Save the model's state_dict to ``<model_dir>/<model_file_name>_Epoch<index>.pth``."""
        os.makedirs(self.model_dir, exist_ok=True)
        model_path = os.path.join(self.model_dir, self.model_file_name + '_Epoch' + str(index) + '.pth')
        torch.save(self.model.state_dict(), model_path)
        print('Saved Model Path: ', model_path)

    def load_pretrain_model(self, model_path):
        """Load weights from ``model_path`` (non-strict) and rebuild the optimizer and scheduler.

        Loading is non-strict, so mismatched keys do not raise. They are printed instead
        so that silently skipped layers are visible.
        """
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine (and vice versa).
        state_dict = torch.load(model_path, map_location=self.device)
        incompatible = self.model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            print(f"Missing keys: {incompatible.missing_keys}")
            print(f"Unexpected keys: {incompatible.unexpected_keys}")
        self.model.to(self.device)
        self._build_optimization()

    def frozen_feature_extractor(self, is_frozen=True):
        """Freeze all parameters except the head sub-modules, or unfreeze everything.

        Args:
            is_frozen (bool): True -> only the modules named in ``_HEAD_MODULE_NAMES`` stay
                trainable; False -> every parameter becomes trainable.
        """
        self.model.requires_grad_(not is_frozen)
        if is_frozen:
            for name in self._HEAD_MODULE_NAMES:
                getattr(self.model, name).requires_grad_(True)

In [134]:
# Build the train / validation datasets and their loaders.
data_loader_dict = dict()
train_dataset = MyCustomDataset(
    dataset_name='train',
    dataset_dir=train_dataset_dir,
    transform=transform
)
valid_dataset = MyCustomTestDataset(
    dataset_name='valid',
    dataset_dir=valid_dataset_dir,
    transform=transform_valid,
)
data_loader_dict['train'] = DataLoader(
    dataset=train_dataset,
    batch_size=config.TRAIN_BATCH_SIZE,
    shuffle=True,
    worker_init_fn=seed_worker,
    generator=train_generator
)
# Use the setup cell's `general_generator`, which was created for the validation loader.
# PyTorch draws a base seed from the loader's generator each time an iterator is created,
# and falls back to the global torch RNG when no generator is given. Passing it keeps
# per-epoch validation from shifting the random stream the training run sees.
data_loader_dict['valid'] = DataLoader(
    dataset=valid_dataset,
    batch_size=config.TEST_BATCH_SIZE,
    shuffle=False,
    worker_init_fn=seed_worker,
    generator=general_generator
)

# Sanity check: inspect one training batch.
# NOTE: this advances `train_generator` and Python's `random` (frame-shuffle augmentation),
# so the training run below starts from a different RNG state than if this cell were skipped.
data, label, filename, chunk_id = next(iter(data_loader_dict['train']))
print(data.size())
print(label.size())
print(filename)
print(chunk_id)
print(data_loader_dict['train'].num_workers)  # expected 0: batches are loaded in the main process
print(data_loader_dict['valid'].num_workers)  # expected 0: batches are loaded in the main process

['X:\\my_thermal_nose_dataset\\RawData_CHUNK160_RESIZE128_littleROI_TRAIN\\mydataset\\s10_T1_1_input0.npy', 'X:\\my_thermal_nose_dataset\\RawData_CHUNK160_RESIZE128_littleROI_TRAIN\\mydataset\\s10_T1_1_input1.npy', 'X:\\my_thermal_nose_dataset\\RawData_CHUNK160_RESIZE128_littleROI_TRAIN\\mydataset\\s10_T1_1_input10.npy', 'X:\\my_thermal_nose_dataset\\RawData_CHUNK160_RESIZE128_littleROI_TRAIN\\mydataset\\s10_T1_1_input11.npy', 'X:\\my_thermal_nose_dataset\\RawData_CHUNK160_RESIZE128_littleROI_TRAIN\\mydataset\\s10_T1_1_input12.npy', 'X:\\my_thermal_nose_dataset\\RawData_CHUNK160_RESIZE128_littleROI_TRAIN\\mydataset\\s10_T1_1_input13.npy', 'X:\\my_thermal_nose_dataset\\RawData_CHUNK160_RESIZE128_littleROI_TRAIN\\mydataset\\s10_T1_1_input14.npy', 'X:\\my_thermal_nose_dataset\\RawData_CHUNK160_RESIZE128_littleROI_TRAIN\\mydataset\\s10_T1_1_input15.npy', 'X:\\my_thermal_nose_dataset\\RawData_CHUNK160_RESIZE128_littleROI_TRAIN\\mydataset\\s10_T1_1_input16.npy', 'X:\\my_thermal_nose_dataset\

In [None]:
model_file_name = 'ThermRNet'
# Same location MyTrainer.save_model writes to: <MODEL_DIR>/<name>_Epoch<index>.pth
pretrain_model_path = os.path.join(config.MODEL_DIR, f'{model_file_name}_Epochbest.pth')
save_file_path = f"{model_file_name}.csv"  # per-epoch metrics, relative to the working directory
# Place the model on config.DEVICE, the device MyTrainer moves every batch to.
model = ThermRNet(dim=config.DIM, dropout_rate=config.DROP_RATE, frame=config.CHUNK_LENGTH, image_size=config.RESIZE_W).to(config.DEVICE)
model_trainer = MyTrainer(config, data_loader_dict, model_file_name, model, save_file_path)
# To fine-tune from the saved checkpoint instead of training from scratch, uncomment:
# model_trainer.load_pretrain_model(pretrain_model_path)
# model_trainer.frozen_feature_extractor(True)  # train only the classifier head
model_trainer.train(data_loader_dict)

## ThermRNet w/o skip connection

In [None]:
model_file_name = 'ThermRNet_WO_SkipConnection'
# Same location MyTrainer.save_model writes to: <MODEL_DIR>/<name>_Epoch<index>.pth
pretrain_model_path = os.path.join(config.MODEL_DIR, f'{model_file_name}_Epochbest.pth')
save_file_path = f"{model_file_name}.csv"  # per-epoch metrics, relative to the working directory
# Place the model on config.DEVICE, the device MyTrainer moves every batch to.
model = ThermRNet_WO_SkipConnection(dim=config.DIM, dropout_rate=config.DROP_RATE, frame=config.CHUNK_LENGTH, image_size=config.RESIZE_W).to(config.DEVICE)
model_trainer = MyTrainer(config, data_loader_dict, model_file_name, model, save_file_path)
# To fine-tune from the saved checkpoint instead of training from scratch, uncomment:
# model_trainer.load_pretrain_model(pretrain_model_path)
# model_trainer.frozen_feature_extractor(True)  # train only the classifier head
model_trainer.train(data_loader_dict)

## ThermRNet w/o Stem

In [None]:
model_file_name = 'ThermRNet_WO_Stem'
# Same location MyTrainer.save_model writes to: <MODEL_DIR>/<name>_Epoch<index>.pth
pretrain_model_path = os.path.join(config.MODEL_DIR, f'{model_file_name}_Epochbest.pth')
save_file_path = f"{model_file_name}.csv"  # per-epoch metrics, relative to the working directory
# Place the model on config.DEVICE, the device MyTrainer moves every batch to.
model = ThermRNet_WO_Stem(dim=config.DIM, dropout_rate=config.DROP_RATE, frame=config.CHUNK_LENGTH, image_size=config.RESIZE_W).to(config.DEVICE)
model_trainer = MyTrainer(config, data_loader_dict, model_file_name, model, save_file_path)
# To fine-tune from the saved checkpoint instead of training from scratch, uncomment:
# model_trainer.load_pretrain_model(pretrain_model_path)
# model_trainer.frozen_feature_extractor(True)  # train only the classifier head
model_trainer.train(data_loader_dict)

## ThermRNet w/o TDT

In [None]:
model_file_name = 'ThermRNet_WO_TDT'
# Same location MyTrainer.save_model writes to: <MODEL_DIR>/<name>_Epoch<index>.pth
pretrain_model_path = os.path.join(config.MODEL_DIR, f'{model_file_name}_Epochbest.pth')
save_file_path = f"{model_file_name}.csv"  # per-epoch metrics, relative to the working directory
# Place the model on config.DEVICE, the device MyTrainer moves every batch to.
model = ThermRNet_WO_TDT(dim=config.DIM, dropout_rate=config.DROP_RATE, frame=config.CHUNK_LENGTH, image_size=config.RESIZE_W).to(config.DEVICE)
model_trainer = MyTrainer(config, data_loader_dict, model_file_name, model, save_file_path)
# To fine-tune from the saved checkpoint instead of training from scratch, uncomment:
# model_trainer.load_pretrain_model(pretrain_model_path)
# model_trainer.frozen_feature_extractor(True)  # train only the classifier head
model_trainer.train(data_loader_dict)