# Train the network using U-Net

In [1]:
from DataLoader import ImageDataset
from UNet import UNet
import torch
import torch.nn as nn
import torch.nn.functional as f
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

### Load the data with the custom `ImageDataset` class

In [2]:
csv_path = "./DL_info.csv"
Image_slices_dir = "./Key_slices"
df = pd.read_csv(csv_path)


def _load_split(split_code):
    """Build an ImageDataset for one split code from the shared slice dir and CSV."""
    return ImageDataset(
        root_dir=Image_slices_dir,
        dataset_type=split_code,
        csv_path=csv_path,
        data_file=df,
    )


# Split codes 1, 2 and 3 are bound to the train, validation and test datasets respectively.
train_dataset, validation_dataset, test_dataset = (_load_split(code) for code in (1, 2, 3))
for split in (train_dataset, validation_dataset, test_dataset):
    print(len(split))

22919
4889
4927


In [3]:
# Peek at one training sample to inspect the keys and tensor dtypes the dataset yields.
first_sample = train_dataset[0]
print(first_sample)

{'image': tensor([[[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]], dtype=torch.uint8), 'lesions': tensor([[345.9880, 284.4280, 372.5360, 319.9120]], dtype=torch.float64)}


### Wrap the datasets in DataLoaders for batched training

In [4]:
batch_size = 2

# Shuffle the training split so each epoch sees the samples in a fresh order
# instead of the fixed CSV order; validation/test stay deterministic so their
# losses are comparable across evaluations. Shuffling does not change len().
dataloader_train = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
print(len(dataloader_train))

dataloader_validation = torch.utils.data.DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)
print(len(dataloader_validation))

dataloader_test = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
print(len(dataloader_test))

11460
2445
2464


### Build the U-Net model

In [11]:
# Use CUDA when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# 3 input channels; n_classes=4 is forwarded to the project's UNet constructor.
model = UNet(in_channels=3, n_classes=4)
model = model.to(device)

In [12]:
# Show the layer-by-layer architecture of the network.
architecture_summary = str(model)
print(architecture_summary)

UNet(
  (two_conv): TwoConv(
    (two_conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (Down1): Downsampling(
    (down_sampling): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): TwoConv(
        (two_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(128, eps=1e-05, mome

### Hyperparameter settings

In [13]:
learning_rate = 0.0001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# StepLR multiplies the learning rate by gamma=0.5 every 50 calls to scheduler.step().
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

# The training and validation code combine two losses:
#   criterion1(out1, lesions) + 1e-6 * criterion2(out2, image)
# so both names must exist before training starts, otherwise the first batch
# raises NameError. Both use MSE, the loss configured for this notebook.
# NOTE(review): confirm MSE is the intended loss for each output.
criterion = nn.MSELoss()  # kept for any code still referring to `criterion`
criterion1 = criterion
criterion2 = nn.MSELoss()

num_epochs = 1

# Running mean of the per-batch training loss, sampled every 50 batches.
Running_training_loss_list = []
# Mean validation loss, one entry per validation pass.
Validation_loss_list = []



def validation_loop():
    """Evaluate the model on the whole validation set and record the mean batch loss.

    Uses the module-level ``model``, ``dataloader_validation``, ``criterion1``,
    ``criterion2`` and ``device``. Each batch's loss is
    ``criterion1(out1, lesions) + 1e-6 * criterion2(out2, image)``, the same
    objective as the training loop. The mean over batches is printed and
    appended to ``Validation_loss_list``.

    Returns:
        The mean validation loss, or ``None`` if the validation loader yields
        no batches (nothing is appended in that case).
    """
    was_training = model.training
    model.eval()
    loss_validation = 0.0
    count = 0
    try:
        # Evaluation needs no gradients; skipping autograd saves memory and time.
        with torch.no_grad():
            for sample_batched in dataloader_validation:
                # Images are uint8 in the dataset output; convolutions need floats.
                images = sample_batched['image'].to(device).float()
                out1, out2 = model(images)
                # 'lesions' targets are float64; cast to the prediction dtype.
                lesions = sample_batched['lesions'].to(device, dtype=out1.dtype)
                loss = criterion1(out1, lesions) + (1e-6) * criterion2(out2, images)
                loss_validation += loss.item()
                count += 1
    finally:
        # Put the model back in whatever mode the caller had it in.
        model.train(was_training)
        torch.cuda.empty_cache()

    if count == 0:
        print("Validation loader is empty; no validation loss recorded.")
        return None

    final_loss = loss_validation / count
    print("===============================")
    print("Validation loss is ", final_loss)
    print("===============================")
    Validation_loss_list.append(final_loss)
    return final_loss

In [None]:
# Report the running loss and run validation every `log_interval` batches.
log_interval = 50

for epoch in range(num_epochs):
    Running_loss = 0
    count = 0

    for sample_batched in dataloader_train:
        # validation_loop() switches the model to eval mode, so re-enable
        # training mode (batch-norm statistics updates) before every batch.
        model.train()
        # Images are uint8 in the dataset output; convert to float both for the
        # convolutions and for use as the reconstruction target below.
        images = sample_batched['image'].to(device).float()
        out1, out2 = model(images)
        # 'lesions' targets are float64; cast to the prediction dtype so the
        # loss and its backward pass run in a single dtype.
        lesions = sample_batched['lesions'].to(device, dtype=out1.dtype)
        # Loss on out1 vs. lesions plus a 1e-6-weighted loss of out2 vs. the input.
        loss = criterion1(out1, lesions) + (1e-6) * criterion2(out2, images)

        Loss_val = loss.item()
        Running_loss += Loss_val
        count += 1

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        torch.cuda.empty_cache()

        # Periodically log and record the running mean of the per-batch
        # training loss, then run a full validation pass. (Printing every batch
        # would emit tens of thousands of lines per epoch.)
        if count % log_interval == 0:
            print("Batch Loss", Loss_val)
            print("Running Loss", Running_loss / count)
            Running_training_loss_list.append(Running_loss / count)
            validation_loop()

    # StepLR's step_size counts scheduler.step() calls; step once per epoch.
    scheduler.step()

    if count:
        print("Running Loss", Running_loss / count)
    else:
        print("Training loader is empty; nothing was trained this epoch.")