In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm
from dataset import *
from model import *
from loss import *
import os
import SimpleITK as sitk
from itkwidgets import view
%matplotlib widget

In [2]:
mode = 'gpu'  # 'gpu' selects the CUDA path in the device-setup cell; any other value selects CPU

In [3]:
# Select the compute device and the global default tensor type (float64 on both paths).
# The default tensor type is global state: after switching `mode` or the CUDA device,
# you need to restart the kernel.
if mode == 'gpu' and torch.cuda.is_available():
    device = torch.device('cuda')
    # To pin a specific GPU, uncomment the next line (then restart the kernel):
    # torch.cuda.set_device(1)
    torch.set_default_tensor_type('torch.cuda.DoubleTensor')
else:
    # CPU was requested, or no CUDA device is available. Fall back to CPU consistently,
    # instead of requesting a CUDA default tensor type on a machine without CUDA.
    device = torch.device('cpu')
    torch.set_default_dtype(torch.float64)

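1. For classification (segmentation is voxel-wise classification), ending the model with `F.softmax(output, dim=1)` is necessary when the loss expects probabilities (e.g. dice loss). Softmax constrains each voxel's output to a probability distribution over the classes; without it, the output can contain negative values with no probabilistic meaning. Losses such as `BCEWithLogitsLoss` or `CrossEntropyLoss` apply the sigmoid/log-softmax internally and expect raw logits instead, so do not apply softmax before them.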
2. The numerator of the dice loss for each class closely resembles cross-entropy. It is the inner product of the softmax output vector with the one-hot target vector, so only the predicted value at the position of the 1 (the true class) contributes.
3. For segmentation, use dice loss.

## Training
### initialization

In [4]:
# Run configuration.
# - resume: load the checkpoint of epoch `start_epoch_num-1` before training.
# - save_model: whether checkpoints are written during training.
resume = False
save_model = True
print(f'resume:{resume}, save_model:{save_model}')
output_dir = 'Models/FNet'
# makedirs also creates the parent 'Models' directory; os.mkdir would raise
# FileNotFoundError on a fresh checkout. exist_ok=True keeps the cell idempotent.
os.makedirs(output_dir, exist_ok=True)

resume:False, save_model:True


In [5]:
# Hyper-parameters for this training session.
epoch_loss_list = []
epoch_num = 1001        # number of epochs to run in this session
start_epoch_num = 7     # used only when resume=True: the checkpoint of epoch start_epoch_num-1 is loaded
batch_size = 6
learning_rate = 8e0
gpu_ids = [0, 1]        # GPUs for DataParallel; ids not present on this machine are dropped
data_root = '/home/sci/hdai/Projects/Dataset/LymphNodes'  # TODO: avoid hard-coding an absolute local path

model = FNet()
model.train()
# `.to(device)` follows the device chosen in the device-setup cell (a no-op on CPU),
# instead of unconditionally calling `.cuda()`.
model.to(device)
if device.type == 'cuda':
    # Keep only GPU ids that exist, so a machine with fewer GPUs does not crash.
    # If none of the requested ids exist, DataParallel's default (all visible GPUs) is used.
    visible_gpu_ids = [i for i in gpu_ids if i < torch.cuda.device_count()]
    net = torch.nn.DataParallel(model, device_ids=visible_gpu_ids or None)
else:
    # On CPU the model is used directly. DataParallel with GPU ids would reject a CPU-resident model.
    net = model
# NOTE(review): BCEWithLogitsLoss applies the sigmoid itself and expects raw logits;
# confirm FNet does not already end in a sigmoid/softmax.
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adadelta(model.parameters(), lr=learning_rate)

dataset = FnetDataset(root_dir=data_root)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

if resume:
    # map_location lets a checkpoint written on GPU be loaded on the current device.
    checkpoint = torch.load(f'{output_dir}/epoch_{start_epoch_num-1}_checkpoint.pth.tar',
                            map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Restore Adadelta's running averages as well. Otherwise the resumed run restarts
    # the optimizer from scratch, even though its state was checkpointed.
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # load_state_dict also restores the old lr; re-apply the configured one so that
    # the value logged below is the value actually used.
    for param_group in optimizer.param_groups:
        param_group['lr'] = learning_rate

    with open(f'{output_dir}/loss.txt', 'a') as f:
        f.write(f'From {start_epoch_num} to {epoch_num+start_epoch_num}\n')
        f.write(f'BCE; Adadelta, lr={learning_rate}; batch size: {batch_size}\n')
else:
    start_epoch_num = 0

    # Fresh run: truncate the loss log.
    with open(f'{output_dir}/loss.txt', 'w+') as f:
        f.write(f'From {start_epoch_num} to {epoch_num+start_epoch_num}\n')
        f.write(f'BCE; Adadelta: lr={learning_rate}; batch size: {batch_size}\n')

print(f'Starting from iteration {start_epoch_num} to iteration {epoch_num+start_epoch_num}')

# params 176118, # conv layers 40
Starting from iteration 0 to iteration 1001


### process

In [None]:
# Train for `epoch_num` epochs starting at `start_epoch_num`.
# After every epoch:
# - the summed batch loss is appended to loss.txt and epoch_loss_list;
# - if save_model is set, a model+optimizer checkpoint is written every `save_every` epochs.
save_every = 1  # checkpoint frequency in epochs

for epoch in tqdm(range(start_epoch_num, start_epoch_num+epoch_num)):
    net.train()  # in case an evaluation cell switched the model to eval mode
    epoch_loss = 0

    # Wrapping the DataLoader itself (not enumerate(...)) lets tqdm show the number of
    # batches. leave=False clears the finished inner bar instead of stacking one per epoch.
    for batched_sample in tqdm(dataloader, leave=False):
        # Four input images ('img0'..'img3'), converted to float64.
        input0, input1, input2, input3 = (batched_sample[f'img{k}'].double() for k in range(4))
        # NOTE(review): input gradients are not needed for the parameter update unless
        # FNet differentiates w.r.t. its inputs; if it does not, this can be dropped to save memory.
        for input_tensor in (input0, input1, input2, input3):
            input_tensor.requires_grad = True
        output_pred = net(input0, input1, input2, input3)
        output_true = batched_sample['mask']

        optimizer.zero_grad()
        loss = criterion(output_pred, output_true.double())
        loss.backward()
        epoch_loss += loss.item()
        optimizer.step()

    with open(f'{output_dir}/loss.txt', 'a') as f:
        f.write(f'{epoch_loss}\n')

    print(f'epoch {epoch} loss: {epoch_loss}')
    epoch_loss_list.append(epoch_loss)
    if save_model and epoch % save_every == 0:
        # Save the unwrapped `model`, so the state-dict keys carry no DataParallel 'module.' prefix.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': epoch_loss,
        }, f'{output_dir}/epoch_{epoch}_checkpoint.pth.tar')

  0%|          | 0/1001 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
1it [00:51, 51.85s/it][A
2it [01:35, 46.84s/it][A
3it [02:27, 49.34s/it][A
4it [03:11, 47.32s/it][A
5it [03:55, 46.06s/it][A
6it [04:40, 45.61s/it][A
7it [05:38, 49.59s/it][A
8it [06:27, 49.68s/it][A
9it [07:15, 48.96s/it][A
10it [08:06, 49.58s/it][A
11it [08:47, 46.89s/it][A
12it [09:15, 41.27s/it][A
13it [09:55, 41.03s/it][A
14it [11:03, 49.04s/it][A
15it [11:51, 47.45s/it][A
  0%|          | 1/1001 [11:51<197:44:15, 711.86s/it]
0it [00:00, ?it/s][A

epoch 0 innerdomain loss: 1.5806993324481435



1it [00:29, 29.42s/it][A
2it [00:59, 30.05s/it][A
3it [01:31, 30.61s/it][A
4it [02:02, 31.00s/it][A
5it [02:35, 31.63s/it][A
6it [03:08, 32.14s/it][A
7it [03:41, 32.33s/it][A
8it [04:13, 32.22s/it][A
9it [04:44, 31.78s/it][A
10it [05:16, 31.94s/it][A
11it [05:48, 32.11s/it][A
12it [06:18, 31.23s/it][A
13it [06:50, 31.68s/it][A
14it [07:21, 31.47s/it][A

In [None]:
# Shapes of the four inputs from the last training batch. The training loop defines
# only input0..input3; the former `input4` raised NameError.
[input_tensor.shape for input_tensor in (input0, input1, input2, input3)]

In [None]:
# Interactive 3-D view of channel 0 of the second mask in the last training batch.
# NOTE(review): assumes that batch holds at least two samples.
sample_mask = output_true[1, 0].detach().cpu()
view(sample_mask)

In [None]:
# Per-epoch training loss curve (each value is the BCE summed over the epoch's batches),
# saved next to the checkpoints with the learning rate in the file name.
fig, ax = plt.subplots(figsize=(7, 5))
ax.plot(epoch_loss_list)
ax.set_title('Innerdomain loss')
ax.set_xlabel('epoch')
ax.set_ylabel('BCE loss')
fig.savefig(f'{output_dir}/adadelta_loss_{learning_rate}.png')