基本库导入

In [None]:
import math
import os
import time

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

模型定义

In [None]:
# Residual stream: basic 3D residual block with an identity shortcut.
class RestNetBasicBlock(nn.Module):
    """Two 3x3x3 conv + BatchNorm stages whose result is added back onto the input.

    The shortcut is the unmodified input, so ``x + residual`` only has matching
    shapes when ``in_channels == out_channels`` and ``stride == 1`` (every use in
    this file satisfies both).

    Args:
        in_channels: channels of the incoming volume.
        out_channels: channels produced by both convolutions.
        stride: stride applied by *both* convolutions.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        conv_kwargs = dict(kernel_size=[3, 3, 3], stride=stride, padding=1)
        self.conv1 = nn.Conv3d(in_channels, out_channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.conv2 = nn.Conv3d(out_channels, out_channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm3d(out_channels)

    def forward(self, x):
        # Main path: conv -> BN -> ReLU -> conv -> BN.
        residual = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        # Identity shortcut, then the final activation.
        return F.relu(x + residual)
    
class RestNetDownBlock(nn.Module):
    """3D residual block that changes channel count and can downsample.

    Main path: conv(stride[0]) -> BN -> ReLU -> conv(stride[1]) -> BN.
    Shortcut: 1x1x1 conv with ``stride[0]`` followed by BN, so that it matches
    the main path's channels and first downsampling. The two paths only line
    up spatially when ``stride[1] == 1`` (as in every use in this file).

    Args:
        in_channels: channels of the incoming volume.
        out_channels: channels produced by both paths.
        stride: two-element sequence ``(first_conv_stride, second_conv_stride)``.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        first_stride, second_stride = stride[0], stride[1]
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3,
                               stride=first_stride, padding=1)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3,
                               stride=second_stride, padding=1)
        self.bn2 = nn.BatchNorm3d(out_channels)
        # Projection shortcut.
        self.extra = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=1,
                      stride=first_stride, padding=0),
            nn.BatchNorm3d(out_channels),
        )

    def forward(self, x):
        shortcut = self.extra(x)
        main = F.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))
        return F.relu(shortcut + main)
    
class RestNet(nn.Module):
    """3D ResNet-18-style encoder: single-channel volume -> 512-d feature vector.

    Args:
        num_classes: accepted for interface compatibility but unused; the
            network stops at the pooled 512-d features (no classifier head).
    """

    def __init__(self, num_classes):
        super().__init__()

        # Stem.
        self.conv1 = nn.Conv3d(1, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm3d(64)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)

        # Four residual stages, two blocks each; stages 2-4 halve the resolution.
        self.layer1 = nn.Sequential(RestNetBasicBlock(64, 64, 1),
                                    RestNetBasicBlock(64, 64, 1))

        self.layer2 = nn.Sequential(RestNetDownBlock(64, 128, [2, 1]),
                                    RestNetBasicBlock(128, 128, 1))

        self.layer3 = nn.Sequential(RestNetDownBlock(128, 256, [2, 1]),
                                    RestNetBasicBlock(256, 256, 1))

        self.layer4 = nn.Sequential(RestNetDownBlock(256, 512, [2, 1]),
                                    RestNetBasicBlock(512, 512, 1))

        self.avgpool = nn.AdaptiveAvgPool3d(output_size=(1, 1, 1))

    def forward(self, x):
        """Encode ``x`` of shape (batch, 1, D, H, W) into (batch, 512)."""
        # Stem: conv -> BN -> ReLU -> max-pool. bn1 and maxpool were built in
        # __init__ but previously skipped, leaving the stem un-normalised.
        out = self.maxpool(F.relu(self.bn1(self.conv1(x))))
        # (batch, 64, ~D/4, ~H/4, ~W/4)

        out = self.layer1(out)
        # (batch, 64, ~D/4, ~H/4, ~W/4)

        out = self.layer2(out)
        # (batch, 128, ~D/8, ~H/8, ~W/8)

        out = self.layer3(out)
        # (batch, 256, ~D/16, ~H/16, ~W/16)

        out = self.layer4(out)
        # (batch, 512, ~D/32, ~H/32, ~W/32)

        out = self.avgpool(out)
        # (batch, 512, 1, 1, 1)

        return torch.flatten(out, 1)
        # (batch, 512)

In [None]:
# Linear stream: MLP encoder for the tabular/vector input.
class DenseNet(nn.Module):
    """Four-layer MLP ``in_features -> hidden_dims -> out_features`` with ReLUs.

    With the default sizes, ``fc1`` alone holds 60660 * 32768 ~= 2.0e9 weights
    (~8 GB in float32); pass smaller ``hidden_dims`` to shrink it.

    Args:
        in_features: width of the input vector (default 60660).
        hidden_dims: widths of the three hidden layers (default 32768, 8192, 2048).
        out_features: width of the output vector (default 512).
    """

    def __init__(self, in_features=60660, hidden_dims=(32768, 8192, 2048),
                 out_features=512):
        super().__init__()
        h1, h2, h3 = hidden_dims
        self.fc1 = nn.Linear(in_features, h1)
        self.fc2 = nn.Linear(h1, h2)
        self.fc3 = nn.Linear(h2, h3)
        self.fc4 = nn.Linear(h3, out_features)

    def forward(self, x):
        # A ReLU between layers is required: without it the four Linear
        # layers compose into a single affine map and the depth is wasted.
        out = F.relu(self.fc1(x))
        out = F.relu(self.fc2(out))
        out = F.relu(self.fc3(out))
        return F.relu(self.fc4(out))
    # (batch, out_features), i.e. (batch, 512) by default

In [None]:
class merge_Attention(nn.Module):
    """Single-head attention across feature positions, fusing two vectors.

    Each of the ``data_len`` features is treated as a token of width 1. The
    query and value projections come from ``x_dicom``; the key projection
    comes from ``x_data``.
    NOTE(review): cross-attention usually takes key *and* value from the same
    source; confirm that taking the value from ``x_dicom`` is intended.

    Args:
        data_len: length of both input vectors and of the output.
        drop: dropout probability applied to the attention weights.
    """

    def __init__(self, data_len, drop):
        super().__init__()
        self.dropout = nn.Dropout(drop)
        self.query = nn.Linear(data_len, data_len)
        self.key = nn.Linear(data_len, data_len)
        self.value = nn.Linear(data_len, data_len)
        self.c_proj = nn.Linear(data_len, data_len)

    def forward(self, x_dicom, x_data):
        # (batch, data_len) -> (batch, data_len, 1): one width-1 token per feature.
        q = self.query(x_dicom).unsqueeze(-1)
        k = self.key(x_data).unsqueeze(-1)
        v = self.value(x_dicom).unsqueeze(-1)

        # Token width is 1, so the sqrt(d_k) scale evaluates to 1.
        scale = math.sqrt(q.shape[-1])
        # (batch, data_len, data_len) attention map.
        attn = torch.softmax(torch.matmul(q, k.transpose(1, 2)) / scale, dim=-1)
        attn = self.dropout(attn)

        # (batch, data_len, 1) -> (batch, data_len), then output projection.
        ctx = torch.matmul(attn, v)
        return self.c_proj(ctx.reshape(ctx.size(0), -1))

In [None]:
# Hidden layer: position-wise feed-forward sub-layer.
class FeedForward(nn.Module):
    """Two-layer MLP ``data_len -> middle_dim -> data_len``.

    A ReLU follows the first layer, and dropout is applied to the hidden
    activations before the second layer.

    Args:
        data_len: input and output width.
        middle_dim: hidden width.
        drop: dropout probability on the hidden activations.
    """

    def __init__(self, data_len, middle_dim, drop):
        super().__init__()
        self.fc1 = nn.Linear(data_len, middle_dim)
        self.fc2 = nn.Linear(middle_dim, data_len)
        self.dropout = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(hidden)

In [None]:
# Fusion layer: combines the image (DICOM) stream and the vector stream.
class attentionmerge(nn.Module):
    """End-to-end model: two encoders, an attention fusion block, and a classifier.

    ``data_len`` must be 512, because both ``RestNet`` and (default) ``DenseNet``
    emit 512-d vectors. ``num_class`` is also forwarded to ``RestNet``, which
    ignores it.
    NOTE(review): the same ``self.layernorm`` (shared parameters) normalises
    both residual sums; transformer blocks usually use two separate norms.

    Args:
        data_len: width of the fused feature vector (512).
        middle_dim: hidden width of the feed-forward sub-layer.
        drop: dropout probability for attention and feed-forward.
        num_class: number of output logits.
    """

    def __init__(self, data_len, middle_dim, drop, num_class):
        super().__init__()
        self.resnet = RestNet(num_class)
        self.densenet = DenseNet()
        self.attention = merge_Attention(data_len, drop)
        self.layernorm = nn.LayerNorm(data_len)
        self.feedforward = FeedForward(data_len, middle_dim, drop)
        self.fc = nn.Linear(data_len, num_class)

    def forward(self, x_dicom, x_data):
        img_feat = self.resnet(x_dicom)     # (batch, 512)
        vec_feat = self.densenet(x_data)    # (batch, 512)

        # Post-norm residual around the attention sub-layer.
        fused = self.layernorm(self.attention(img_feat, vec_feat) + img_feat)
        # Post-norm residual around the feed-forward sub-layer.
        fused = self.layernorm(self.feedforward(fused) + fused)  # (batch, data_len)

        return self.fc(fused)  # (batch, num_class)

训练函数（单个 epoch 的训练迭代）

In [None]:
def train(train_loader, model, criterion, epoch, optimizer=None, device=None,
          log_every=5, log_path="training_data_attention.txt"):
    """Run one training epoch and periodically log running loss and accuracy.

    Args:
        train_loader: sized iterable of ``(x_dicom, x_data, y)`` batches.
        model: module called as ``model(x_dicom, x_data)``; it is put in train
            mode and moved to ``device``.
        criterion: loss called as ``criterion(output, y)``.
        epoch: epoch index; used only in the log line.
        optimizer: optimizer to step. Defaults to the global ``optimizer``.
        device: target device. Defaults to the global ``device``.
        log_every: log after every ``log_every`` batches and after the last one.
        log_path: file to which each log line is appended.
    """
    # Backward-compatible fallback to the notebook-level globals the original used.
    if optimizer is None:
        optimizer = globals()["optimizer"]
    if device is None:
        device = globals()["device"]

    metric = Accumulator(3)  # (sum of loss * batch size, #correct, #samples)
    model.train()
    model = model.to(device)
    num_batches = len(train_loader)
    iter_time = time.time()

    # i is the 1-based count of batches processed so far. The old code logged
    # with (i + 1) after incrementing, which reported an inflated batch index,
    # fired one batch early, and never logged the final batch.
    for i, (x_dicom, x_data, y) in enumerate(train_loader, start=1):
        x_dicom = x_dicom.to(device)
        x_data = x_data.to(device)
        y = y.to(device)

        out = model(x_dicom, x_data)
        loss = criterion(out, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            metric.add(loss.item() * y.shape[0], accuracy(out, y), y.shape[0])

        if i % log_every == 0 or i == num_batches:
            train_l = metric[0] / metric[2]
            train_acc = metric[1] / metric[2]
            tnow = time.time()
            iter_dt = tnow - iter_time
            iter_time = tnow
            line = "Epoch [{}][{}/{}]  Loss: {:.5f}  accuracy: {:.5f}  time: {:.2f}s".format(
                epoch, i, num_batches, train_l, train_acc, iter_dt)
            print(line)
            with open(log_path, "a") as f:
                f.write(line + '\n')

In [None]:
# 参数初始化
def _init_weights(module):
    # isinstance用于判断第一个参数是否是第二个参数的实例
    # 对Linear层初始化
    if isinstance(module, nn.Linear):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.LayerNorm):
        # 对LayerNorm层初始化
        torch.nn.init.zeros_(module.bias)
        torch.nn.init.ones_(module.weight)

优化器

In [None]:
def configure_optimizers(model, lr, weight_decay):
    """Build an AdamW optimizer that applies weight decay only to Linear weights.

    Parameters are split into two groups:
      * decay: the ``weight`` of every ``nn.Linear``;
      * no decay: every ``bias``, plus the ``weight`` of LayerNorm / Conv /
        BatchNorm layers.

    Args:
        model: module whose parameters are optimised.
        lr: learning rate.
        weight_decay: decoupled weight decay for the decay group.

    Returns:
        ``torch.optim.AdamW`` over the non-empty parameter groups.

    Raises:
        ValueError: if a parameter lands in both groups or in neither.
    """
    decay = set()
    no_decay = set()
    # Usually only linear layers get weight decay (mainly the attention projections).
    whitelist_weight_modules = (torch.nn.Linear,)
    blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Conv3d,
                                torch.nn.BatchNorm3d, torch.nn.Conv2d,
                                torch.nn.MaxPool2d, torch.nn.AdaptiveAvgPool2d,
                                torch.nn.BatchNorm2d)
    # mn is the module name, m the module; pn is the parameter name within m.
    for mn, m in model.named_modules():
        for pn, _ in m.named_parameters():
            fpn = f"{mn}.{pn}" if mn else pn  # fully-qualified parameter name
            if pn.endswith('bias'):
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                no_decay.add(fpn)

    # Every parameter must be in exactly one group.
    param_dict = dict(model.named_parameters())
    inter_params = decay & no_decay
    if inter_params:
        raise ValueError("parameters %s made it into both decay/no_decay sets!"
                         % (str(inter_params),))
    missing = param_dict.keys() - (decay | no_decay)
    if missing:
        raise ValueError("parameters %s were not separated into either decay/no_decay set!"
                         % (str(missing),))

    optim_groups = [
        {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": weight_decay},
        {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
    ]
    optim_groups = [g for g in optim_groups if g["params"]]
    # Previously these groups were built and then discarded: Adam was created
    # on model.parameters() with no weight decay, so weight_decay had no effect.
    return torch.optim.AdamW(optim_groups, lr=lr)

准确率计算（计数器、准确率与评估函数）

In [None]:
# Counter: running sums of several quantities.
class Accumulator:
    """Keep running float sums for ``n`` separate quantities.

    ``add`` adds its positional arguments element-wise. As with ``zip``, only
    as many sums as there are arguments are kept. ``acc[i]`` reads a sum.
    """

    def __init__(self, n):
        self.data = [0.0 for _ in range(n)]

    def add(self, *args):
        updated = []
        for total, value in zip(self.data, args):
            updated.append(total + float(value))
        self.data = updated

    def reset(self):
        self.data = [0.0 for _ in self.data]

    def __getitem__(self, idx):
        return self.data[idx]

In [None]:
def accuracy(y_hat, y):
    """Return the number of correct predictions as a float.

    If ``y_hat`` has at least two dims and more than one column, it is treated
    as per-class scores and reduced with argmax over dim 1. Otherwise it is
    compared to ``y`` directly after casting to ``y``'s dtype.
    """
    has_class_dim = y_hat.dim() > 1 and y_hat.size(1) > 1
    preds = y_hat.argmax(dim=1) if has_class_dim else y_hat
    matches = preds.type(y.dtype) == y
    return float(matches.type(y.dtype).sum())

In [None]:
def evaluate_accuracy_gpu(model, data_iter, device=None):
    """Compute the classification accuracy of ``model`` over ``data_iter``.

    Args:
        model: called as ``model(x_dicom, x_data)``. If it is an ``nn.Module``,
            it is switched to eval mode (and left there). When ``device`` is
            None, the device of its first parameter is used.
        data_iter: iterable of ``(x_dicom, x_data, y)`` batches.
        device: device to move inputs to.

    Returns:
        float: number of correct predictions / number of label elements.
    """
    if isinstance(model, nn.Module):
        model.eval()
        # `is None` rather than truthiness, so an explicit device=0 is honoured.
        if device is None:
            device = next(iter(model.parameters())).device
    metric = Accumulator(2)  # (#correct, #labels)
    with torch.no_grad():
        for x_dicom, x_data, y in data_iter:
            # Inputs may be a tensor or a list of tensors. Move each element;
            # the old code called .to() on the list itself.
            if isinstance(x_dicom, list):
                x_dicom = [x.to(device) for x in x_dicom]
            else:
                x_dicom = x_dicom.to(device)
            if isinstance(x_data, list):
                x_data = [x.to(device) for x in x_data]
            else:
                x_data = x_data.to(device)
            y = y.to(device)
            metric.add(accuracy(model(x_dicom, x_data), y), y.numel())
    return metric[0] / metric[1]

数据集载入

In [None]:
class my_Dataset(Dataset):
    """In-memory dataset built from three aligned ``.npy`` files.

    Item ``i`` is ``(feature1[i], feature2[i], label[i])``. In this notebook
    these are consumed as ``(x_dicom, x_data, y)``.

    Args:
        data1_path: ``.npy`` file of first-input features (loaded as float32).
        data2_path: ``.npy`` file of second-input features (loaded as float32).
        label_path: ``.npy`` file of labels (loaded as int64).
        transform: optional callable applied to *both* feature tensors.

    Raises:
        ValueError: if the three arrays differ in their first dimension.
    """

    def __init__(self, data1_path, data2_path, label_path, transform=None):
        # Was super(Dataset, self), which skips Dataset in the MRO.
        super().__init__()
        self.feature1 = torch.from_numpy(np.load(data1_path)).float()
        self.feature2 = torch.from_numpy(np.load(data2_path)).float()
        self.label = torch.from_numpy(np.load(label_path)).long()
        self.transform = transform

        # Fail early on misaligned inputs instead of an IndexError mid-epoch.
        counts = (self.feature1.shape[0], self.feature2.shape[0], self.label.shape[0])
        if len(set(counts)) != 1:
            raise ValueError(
                f"feature1/feature2/label lengths differ: {counts}")

    def __len__(self):
        return self.feature1.shape[0]

    def __getitem__(self, item):
        X1 = self.feature1[item]
        X2 = self.feature2[item]
        y = self.label[item]
        if self.transform:
            X1 = self.transform(X1)
            X2 = self.transform(X2)
        return X1, X2, y

参数设置

In [None]:
# Width of the fused feature vector; must be 512 to match the outputs of
# RestNet and DenseNet.
data_len = 512
middle_dim = 2048   # hidden width of the FeedForward sub-layer
num_class = 2       # number of output classes
drop = 0.1          # dropout probability
lr = 1e-4           # learning rate
num_epochs = 100
batch_size = 5

In [None]:
# Device, data, model, optimizer and loss for the training run.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = my_Dataset(
    './data_ROI/feature1.npy',
    './data_ROI/feature2.npy',
    './data_ROI/label.npy',
)
train_iter = DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True)

# NOTE(review): DenseNet's fc1 alone is 60660 x 32768 ~= 2.0e9 weights
# (~8 GB in float32); check that this fits in memory.
model = attentionmerge(data_len, middle_dim, drop, num_class)
optimizer = configure_optimizers(model, lr=lr, weight_decay=0.1)
criterion = nn.CrossEntropyLoss()

模型预览

In [None]:
# Report the model size: total parameter count, in millions.
n_params = 0
for param in model.parameters():
    n_params += param.numel()
print(f"number of parameters: {n_params / 1e6:.2f}M")

训练

In [None]:
# Initialise weights ONCE, before training. This used to run at the top of
# every epoch, which wiped all learned Linear/LayerNorm weights each epoch.
model.apply(_init_weights)
for pn, p in model.named_parameters():
    if pn.endswith('c_proj.weight'):
        # Smaller init for the attention output projection (GPT-2-style scaling).
        torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2))

# torch.save does not create missing directories.
os.makedirs('./checkpoint_attention', exist_ok=True)

for epoch in range(num_epochs):
    train(train_iter, model, criterion, epoch)
    # NOTE: evaluated on the *training* loader, so this is training-set
    # accuracy, not a held-out test score.
    test_acc = evaluate_accuracy_gpu(model, train_iter)
    line = "Epoch [{}]  accuracy: {:.5f}".format(epoch, test_acc)
    print(line)
    with open("training_data_attention.txt", "a") as f:
        f.write(line + '\n')
    torch.save(model, './checkpoint_attention/checkpoint_' + str(epoch) + '.pt')