In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.video import r3d_18

class Simple3DCNN(nn.Module):
    """Small two-stage 3D CNN, used as the mask encoder.

    Each stage is Conv3d(kernel 3, padding 1) -> BatchNorm3d -> ReLU -> MaxPool3d(2),
    so an input of shape (B, in_channels, D, H, W) yields (B, 64, D//4, H//4, W//4).

    Args:
        in_channels: number of channels of the input volume (default 1, a
            single-channel mask).
    """

    def __init__(self, in_channels=1):
        super(Simple3DCNN, self).__init__()
        self.conv1 = nn.Conv3d(in_channels, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(32)
        # A single pooling module is shared by both stages (it has no parameters).
        self.pool = nn.MaxPool3d(2)
        self.conv2 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(64)

    def forward(self, x):
        """Encode ``x`` into a 64-channel feature map at 1/4 spatial resolution."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool(x)
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool(x)
        return x


class SEBlock(nn.Module):
    """Squeeze-and-excitation channel attention for 5-D feature maps (B, C, D, H, W).

    Global-average-pools each channel, passes the pooled vector through a
    bottleneck MLP ending in a sigmoid, and rescales the input channels by the
    resulting weights. The weights from the most recent forward pass are kept on
    ``saved_attention`` (shape (B, C, 1, 1, 1)) for inspection.

    Args:
        input_dim: number of input channels C.
        reduction: bottleneck reduction ratio; the hidden width is
            ``max(1, input_dim // reduction)``.
    """

    def __init__(self, input_dim, reduction=16):
        super(SEBlock, self).__init__()
        # Clamp to 1 so input_dim < reduction does not produce a zero-width bottleneck.
        hidden_dim = max(1, input_dim // reduction)
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, input_dim, bias=False),
            nn.Sigmoid()
        )
        # Attention weights of the latest forward pass; None until forward has run.
        self.saved_attention = None

    def forward(self, x):
        """Return ``x`` with each channel scaled by its learned attention weight."""
        b, c, _, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1, 1)
        self.saved_attention = y
        return x * y.expand_as(x)

class SEBasicBlock(nn.Module):
    """3D ResNet basic block (two 3x3x3 convs) with squeeze-and-excitation.

    Order: conv1 -> bn1 -> ReLU -> conv2 -> bn2 -> (+ optional mask features)
    -> SE attention -> + shortcut -> ReLU.

    Args:
        inplanes: input channels.
        planes: output channels of both convolutions.
        stride: stride of the first convolution (spatial downsampling).
        downsample: optional module applied to the input to produce the shortcut
            when the shape of the residual branch differs from the input.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(SEBasicBlock, self).__init__()
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(planes)
        self.se_block = SEBlock(planes * self.expansion)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x, mask_features=None):
        """Run the block.

        Args:
            x: input feature map.
            mask_features: optional tensor added to the residual branch before
                SE attention; must be broadcastable to the branch's output shape.
                ``nn.Sequential`` calls forward with ``x`` only, so it is None there.
        """
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if mask_features is not None:
            # Out-of-place so a broadcastable mask_features cannot trip an in-place shape error.
            out = out + mask_features
        # Apply SE block
        out = self.se_block(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

    def get_attention_weight(self):
        """Return the SE attention weights recorded by the most recent forward pass."""
        return self.se_block.saved_attention

# Load the pre-built torchvision 3D ResNet-18 with pretrained weights.
# NOTE(review): newer torchvision versions deprecate `pretrained=` in favour of `weights=`.
# NOTE(review): layer1-4 and fc are replaced further below by freshly initialised modules,
# so only the stem keeps pretrained parameters.
model = r3d_18(pretrained=True, progress=True)

# Replace the first convolution layer for single-channel input. Instead of discarding the
# pretrained stem filters, initialise the 1-channel kernel with their sum over the RGB input
# channels: since convolution is linear, a grey volume repeated across 3 channels would give
# the same response (ignoring per-channel input normalisation).
pretrained_stem_conv = model.stem[0]
single_channel_stem_conv = nn.Conv3d(1, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False)
with torch.no_grad():
    single_channel_stem_conv.weight.copy_(pretrained_stem_conv.weight.sum(dim=1, keepdim=True))
model.stem[0] = single_channel_stem_conv

# Replace existing blocks with SEBasicBlock
def make_layer(block, inplanes, planes, blocks, stride=1):
    """Build one residual stage as an ``nn.Sequential`` of ``blocks`` blocks.

    Only the first block receives ``stride`` and, when the stride or channel count
    changes, a 1x1x1 Conv3d + BatchNorm3d projection shortcut. The remaining
    blocks map ``planes * block.expansion`` channels to ``planes`` with stride 1.

    Args:
        block: block class exposing ``expansion`` and the constructor signature
            ``(inplanes, planes, stride, downsample)``.
        inplanes: channels entering the stage.
        planes: base channel width of the stage.
        blocks: number of blocks in the stage (the first is always created).
        stride: stride of the first block.

    Returns:
        nn.Sequential containing the stage's blocks.
    """
    out_channels = planes * block.expansion

    shortcut = None
    if stride != 1 or inplanes != out_channels:
        shortcut = nn.Sequential(
            nn.Conv3d(inplanes, out_channels, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm3d(out_channels),
        )

    first_block = block(inplanes, planes, stride, shortcut)
    following_blocks = [block(out_channels, planes) for _ in range(blocks - 1)]
    return nn.Sequential(first_block, *following_blocks)

# Replace the four residual stages of the backbone with SE blocks.
# BLOCKS_PER_STAGE sets the depth: (2, 2, 2, 2) reproduces the 18-layer layout of r3d_18,
# (1, 1, 1, 1) gives a shallower variant (stem conv + 8 block convs + fc = 10 weighted
# layers, not counting the 1x1x1 shortcut convs).
BLOCKS_PER_STAGE = (1, 1, 1, 1)
STAGE_PLANES = (64, 128, 256, 512)

stage_inplanes = 64  # channels produced by the stem
for stage_idx, (stage_planes, stage_blocks) in enumerate(zip(STAGE_PLANES, BLOCKS_PER_STAGE), start=1):
    # The first stage keeps the stem's resolution; each later stage halves it.
    stage_stride = 1 if stage_idx == 1 else 2
    setattr(model, f"layer{stage_idx}",
            make_layer(SEBasicBlock, stage_inplanes, stage_planes, stage_blocks, stride=stage_stride))
    stage_inplanes = stage_planes * SEBasicBlock.expansion

# Replace the fully connected layer for the 3-class output.
NUM_CLASSES = 3
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=NUM_CLASSES)

# Move the model to the appropriate device (GPU or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

print("Using device:", device)


In [None]:
import os

import nibabel as nib
import numpy as np

# Masks are assumed to live in their own folder and to share the image's file name.
mask_folder_path = r'D:\Juntao\Data\ANDI3-T1\mask'

# Load every selected subject's image together with its mask.
# NOTE(review): `nii_folder_path` and `filtered_df` must be defined in an earlier cell.
wanted_subjects = set(filtered_df['Subject'].values)
data_list = []
mask_list = []
subject_ids = []  # load order of subjects; use it to align labels with `data`
for file_name in os.listdir(nii_folder_path):
    if not file_name.endswith('.nii'):
        continue
    subject_id = file_name.split('.')[0]
    if subject_id not in wanted_subjects:
        continue

    # Load the image
    image_file_path = os.path.join(nii_folder_path, file_name)
    image_data = nib.load(image_file_path).get_fdata()
    if image_data.ndim == 3:
        image_data = image_data[np.newaxis, ...]  # Add single-channel dimension

    # Load the matching mask
    mask_file_path = os.path.join(mask_folder_path, file_name)
    if not os.path.exists(mask_file_path):
        raise FileNotFoundError(f'No mask found for {file_name}: expected {mask_file_path}')
    mask_volume = nib.load(mask_file_path).get_fdata()
    if mask_volume.ndim == 3:
        mask_volume = mask_volume[np.newaxis, ...]  # Add single-channel dimension
    if mask_volume.shape != image_data.shape:
        raise ValueError(f'Mask shape {mask_volume.shape} does not match image shape '
                         f'{image_data.shape} for {file_name}')

    data_list.append(image_data)
    mask_list.append(mask_volume)
    subject_ids.append(subject_id)

if not data_list:
    raise ValueError(f'No .nii files for the selected subjects were found in {nii_folder_path}')

# Stack to (N, D, H, W) and add the channel axis -> (N, 1, D, H, W) for 3-D volumes.
data = np.vstack(data_list)[:, np.newaxis, ...]
mask_data = np.vstack(mask_list)[:, np.newaxis, ...]

print(f'data shape: {data.shape}')
print(f'mask data shape: {mask_data.shape}')


In [None]:
# Two-branch model: the image backbone built above plus a small CNN on the mask,
# fused by concatenating their pooled features before the classifier.
class ModifiedModel(nn.Module):
    """Image backbone + mask encoder with late fusion by concatenation.

    Args:
        backbone: image network whose ``forward`` ends with a final ``fc`` Linear
            layer applied to pooled features (as torchvision's VideoResNet / r3d_18
            does). Its ``fc`` is replaced in place by ``nn.Identity`` so the backbone
            returns pooled features rather than logits.
        num_classes: number of output classes.
    """

    def __init__(self, backbone, num_classes=3):
        super(ModifiedModel, self).__init__()
        backbone_dim = backbone.fc.in_features
        backbone.fc = nn.Identity()
        self.backbone = backbone
        self.mask_model = Simple3DCNN()  # mask encoder
        self.mask_pool = nn.AdaptiveAvgPool3d(1)
        mask_dim = self.mask_model.conv2.out_channels
        self.classifier = nn.Linear(backbone_dim + mask_dim, num_classes)

    def forward(self, x, mask):
        """Return class logits for an image batch ``x`` and its mask batch ``mask``."""
        image_features = self.backbone(x)
        mask_features = self.mask_pool(self.mask_model(mask)).flatten(1)
        return self.classifier(torch.cat((image_features, mask_features), dim=1))


# Wrap the backbone built in the first cell. Re-running this cell without rebuilding the
# backbone fails loudly (a ModifiedModel has no `fc`) instead of nesting wrappers.
model = ModifiedModel(model).to(device)


In [None]:
# One training epoch over (image, mask, label) batches.
# NOTE(review): requires `train_loader` (built from MRIandMaskDataset in a later cell) and an
# `optimizer` created from the *current* model's parameters; neither is defined in this cell.
criterion = nn.CrossEntropyLoss()  # labels are integer class indices (torch.long)

model.train()
running_loss = 0.0
for inputs, masks, labels in train_loader:
    inputs, masks, labels = inputs.to(device), masks.to(device), labels.to(device)
    optimizer.zero_grad()
    outputs = model(inputs, masks)  # image and mask are passed together
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    running_loss += loss.item() * labels.size(0)

print(f'train loss: {running_loss / max(len(train_loader.dataset), 1):.4f}')


In [None]:
from torch.utils.data import Dataset

class MRIandMaskDataset(Dataset):
    """Map-style dataset returning aligned ``(image, mask, label)`` triples.

    Args:
        data: indexable collection of image volumes (e.g. an (N, 1, D, H, W) tensor).
        mask_data: indexable collection of masks, aligned with ``data`` by index.
        labels: indexable collection of labels, aligned with ``data`` by index.

    Raises:
        ValueError: if the three collections do not have the same length.
    """

    def __init__(self, data, mask_data, labels):
        lengths = (len(data), len(mask_data), len(labels))
        # Reject misaligned inputs up front instead of failing (or silently
        # mis-pairing samples) part-way through an epoch.
        if len(set(lengths)) != 1:
            raise ValueError(
                f'data, mask_data and labels must have the same length, got {lengths}')
        self.data = data
        self.mask_data = mask_data
        self.labels = labels

    def __len__(self):
        """Number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, mask, label)`` for sample ``idx``."""
        image = self.data[idx]
        mask = self.mask_data[idx]
        label = self.labels[idx]
        return image, mask, label


In [None]:
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset

# Convert images, masks and labels to tensors.
# NOTE(review): `int_labels_array` must be defined in an earlier cell and aligned with `data`.
data_tensor = torch.tensor(data, dtype=torch.float32)
mask_data_tensor = torch.tensor(mask_data, dtype=torch.float32)
labels_tensor = torch.tensor(int_labels_array, dtype=torch.long)

# Use the custom Dataset
dataset = MRIandMaskDataset(data_tensor, mask_data_tensor, labels_tensor)

# Split sample indices rather than the dataset itself; Subset then indexes `dataset`
# lazily. Stratifying on the labels keeps each class's proportion in both splits
# (every class needs at least 2 samples for this to work).
train_indices, val_indices = train_test_split(
    list(range(len(dataset))),
    test_size=0.2,
    random_state=42,
    stratify=labels_tensor.numpy(),
)
train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)

# Create the DataLoaders; seed the shuffle so batch order is reproducible.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True,
                          generator=torch.Generator().manual_seed(42))
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)
