In [1]:
from torch.utils.data import DataLoader,Dataset
import matplotlib.pyplot as plt
import os
import numpy as np
import torch
from torch import nn, optim
from torchvision import transforms
import torch.nn.functional as F
import time
from tqdm import tqdm
import csv
from sklearn import metrics

batch_size = 16

In [2]:
train_img_path = 'nodule'
val_img_path = 'valNodule'
label_path = 'trainVal.csv'

In [3]:
def csvDictReader(path):
    with open(path) as rf:
        reader = csv.reader(rf)
        items = list(reader)
    dict = {}
    for line in items:
        dict[line[0] + '.npz'] = int(line[1])
    return dict

In [4]:
class myDataset(Dataset):
    def __init__(self, img_path, label_path, transform=None):
        self.img_path = img_path
        self.label_path = label_path
        self.transform = transform
        self.img = os.listdir(img_path)# img是一个list
        self.labelDict = csvDictReader(label_path)

    def __len__(self):
        return len(self.img)

    def __getitem__(self, index):
        img_index = self.img[index]
        img_path = os.path.join(self.img_path, img_index)
        img_voxel = np.load(img_path)['voxel']
        img_mask = np.load(img_path)['seg']
        #final_img = torch.from_numpy(img_voxel * img_mask * 0.8 + img_voxel * 0.2)[34:66, 34:66, 34:66] / 255
        final_img = torch.from_numpy(img_voxel * img_mask)[34:66, 34:66, 34:66] / 255
        final_img = torch.unsqueeze(final_img, 0)
        label = self.labelDict[img_index]
        
        return final_img.float(), label

In [6]:
train_dataset = myDataset(train_img_path, label_path)
val_dataset = myDataset(val_img_path, label_path)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [7]:
class FlattenLayer(nn.Module):
    def __init__(self):
        super(FlattenLayer, self).__init__()
    def forward(self, x): # x shape: (batch, *, *, ...)
        return x.view(x.shape[0], -1)
    
class GlobalAvgPool3d(nn.Module):
    def __init__(self):
        super(GlobalAvgPool3d, self).__init__()
    def forward(self, x):
        return F.avg_pool3d(x, kernel_size=x.size()[2:])

In [8]:
def conv_block(in_channels, out_channels):
    blk = nn.Sequential(nn.BatchNorm3d(in_channels),
                        nn.ReLU(),
                        nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1))
    return blk

class DenseBlock(nn.Module):
    def __init__(self, num_convs, in_channels, out_channels):
        super(DenseBlock, self).__init__()
        net = []
        for i in range(num_convs):
            in_c = in_channels + i * out_channels
            net.append(conv_block(in_c, out_channels))
        self.net = nn.ModuleList(net)
        self.out_channels = in_channels + num_convs * out_channels
    
    def forward(self, X):
        for blk in self.net:
            Y = blk(X)
            X = torch.cat((X, Y), dim=1)
        return X

def transition_block(in_channels, out_channels):
    blk = nn.Sequential(
        nn.BatchNorm3d(in_channels),
        nn.ReLU(),
        nn.Conv3d(in_channels, out_channels, kernel_size=1),
        nn.AvgPool3d(kernel_size=2, stride=2))
    return blk

In [10]:
net = nn.Sequential(
    nn.Conv3d(1, 16, kernel_size=3, stride=1, padding=1),
    nn.BatchNorm3d(16),
    nn.ReLU(),
    nn.MaxPool3d(kernel_size=3, stride=2, padding=1))

num_channels, growth_rate = 16, 32 # 当前的通道数
num_convs_in_dense_blocks = [4, 4, 4]

for i, num_convs in enumerate(num_convs_in_dense_blocks):
    DB = DenseBlock(num_convs, num_channels, growth_rate)
    net.add_module("DenseBlock_%d" % i, DB)
    # 上⼀个DenseBlock的输出通道数
    num_channels = DB.out_channels
    # 在DenseBlock之间加⼊过渡层，通道数减半
    if i != len(num_convs_in_dense_blocks) - 1:
        net.add_module("transition_block_%d" % i, transition_block(num_channels, num_channels // 2))
        num_channels = num_channels // 2

In [11]:
net.add_module("BN", nn.BatchNorm3d(num_channels))
net.add_module("relu", nn.ReLU())
net.add_module("global_avg_pool", GlobalAvgPool3d()) 
net.add_module("fc", nn.Sequential(FlattenLayer(), nn.Linear(num_channels, 2)))

In [12]:
# Find total parameters and trainable parameters
total_params = sum(p.numel() for p in net.parameters())
print(f'{total_params:,} total parameters.')
total_trainable_params = sum(
    p.numel() for p in net.parameters() if p.requires_grad)
print(f'{total_trainable_params:,} training parameters.')

1,183,054 total parameters.
1,183,054 training parameters.


In [13]:
lr1 = 0.005
lr2 = 0.001

criterion = nn.CrossEntropyLoss()

In [14]:
def mixup_data(x, y, alpha=1.0):
    '''Returns mixed inputs, pairs of targets, and lambda'''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    index = torch.randperm(batch_size)
    
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    # 对loss函数进行混合，criterion是crossEntropy函数
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

In [15]:
num_epochs = 20

for epoch in range(num_epochs):
    train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
    for images, labels in tqdm(train_loader):
        
        
        images, targets_a, targets_b, lam = mixup_data(images, labels, 1)    # 对数据集进行mixup操作
        
        labels_hat = net(images)
        #l =  criterion(labels_hat, labels).sum()
        l = mixup_criterion(criterion, labels_hat, targets_a, targets_b, lam)    #对loss#函数进行mixup操作
        
        if epoch<=4:
            optimizer = torch.optim.Adam(net.parameters(), lr=lr1)
        else:
            optimizer = torch.optim.Adam(net.parameters(), lr=lr2)
        
        optimizer.zero_grad()# reset gradient
        l.backward()
        optimizer.step()# update parameters of net
        
        train_l_sum += l.item()
        train_acc_sum += (labels_hat.argmax(dim=1) == labels).sum().item()
        n += labels.shape[0]


100%|██████████| 30/30 [03:05<00:00,  6.19s/it]
100%|██████████| 30/30 [03:09<00:00,  6.33s/it]
100%|██████████| 30/30 [03:04<00:00,  6.14s/it]
100%|██████████| 30/30 [03:04<00:00,  6.14s/it]
100%|██████████| 30/30 [03:13<00:00,  6.44s/it]
100%|██████████| 30/30 [03:11<00:00,  6.40s/it]
100%|██████████| 30/30 [03:10<00:00,  6.34s/it]
100%|██████████| 30/30 [03:07<00:00,  6.26s/it]
100%|██████████| 30/30 [03:06<00:00,  6.20s/it]
100%|██████████| 30/30 [03:07<00:00,  6.24s/it]
100%|██████████| 30/30 [03:05<00:00,  6.19s/it]
100%|██████████| 30/30 [03:04<00:00,  6.13s/it]
100%|██████████| 30/30 [03:02<00:00,  6.07s/it]
100%|██████████| 30/30 [03:03<00:00,  6.13s/it]
100%|██████████| 30/30 [03:06<00:00,  6.22s/it]
100%|██████████| 30/30 [03:05<00:00,  6.19s/it]
100%|██████████| 30/30 [03:07<00:00,  6.26s/it]
100%|██████████| 30/30 [03:07<00:00,  6.26s/it]
100%|██████████| 30/30 [03:13<00:00,  6.47s/it]
100%|██████████| 30/30 [03:18<00:00,  6.62s/it]


In [17]:
prediction = np.zeros(117)
i = 0
for times in tqdm(range(0, 117)):
    while not os.path.exists( 'test/candidate'+ str(i+1) + '.npz'):
        i = i + 1
    tmp = np.load('test/candidate' + str(i+1) + '.npz')
    i = i + 1
    img_voxel = tmp['voxel']
    img_mask = tmp['seg']
    x = torch.from_numpy(img_voxel * img_mask * 0.8 + img_voxel * 0.2)[34:66, 34:66, 34:66] / 255
    x = torch.unsqueeze(x, 0)
    x = torch.unsqueeze(x, 0)
    x = x.float()

    net.eval()
    y_hat = net(x)
    arr = y_hat.detach().numpy()[0]
    prediction[times] = np.exp(arr[1])/(np.exp(arr[0]) + np.exp(arr[1]))
    net.train()


  0%|          | 0/117 [00:00<?, ?it/s][A

  2%|▏         | 2/117 [00:00<00:14,  8.10it/s][A
  3%|▎         | 3/117 [00:00<00:13,  8.28it/s][A
  3%|▎         | 4/117 [00:00<00:14,  7.92it/s][A
  4%|▍         | 5/117 [00:00<00:13,  8.08it/s][A
  5%|▌         | 6/117 [00:00<00:14,  7.83it/s][A
  6%|▌         | 7/117 [00:00<00:13,  8.13it/s][A
  7%|▋         | 8/117 [00:01<00:14,  7.57it/s][A
  8%|▊         | 9/117 [00:01<00:14,  7.34it/s][A
  9%|▊         | 10/117 [00:01<00:14,  7.32it/s][A
  9%|▉         | 11/117 [00:01<00:14,  7.23it/s][A
 10%|█         | 12/117 [00:01<00:13,  7.58it/s][A
 11%|█         | 13/117 [00:01<00:13,  7.53it/s][A
 12%|█▏        | 14/117 [00:01<00:13,  7.64it/s][A
 13%|█▎        | 15/117 [00:02<00:14,  6.97it/s][A
 14%|█▎        | 16/117 [00:02<00:14,  7.03it/s][A
 15%|█▍        | 17/117 [00:02<00:14,  6.92it/s][A
 15%|█▌        | 18/117 [00:02<00:13,  7.24it/s][A
 16%|█▌        | 19/117 [00:02<00:13,  7.26it/s][A
 17%|█▋        | 20/117 [00