In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader,random_split
from torchvision.transforms import transforms

from simple_nn_model import Net
from dicom_dataset import DicomDataset
from ssim_loss import SSIM

import time
import numpy as np

In [9]:
# Training hyperparameters, gathered in one place so they are easy to tune.
num_of_epochs = 100   # number of full passes over the training split
learning_rate = 0.001  # step size handed to the Adam optimizer
batch_size = 20       # samples per DataLoader batch
image_size = 64       # size argument passed to transforms.Resize

In [10]:
#Preparing dataset
# NOTE(review): transforms.Resize with an int only fixes the *shorter* edge; this assumes
# the DICOM slices are square so every sample ends up image_size x image_size — confirm.
# Normalize((0.5,), (0.5,)) computes (x - 0.5) / 0.5 on a single channel; that maps
# [0, 1] to [-1, 1] only if ToTensor already yields [0, 1] — verify DicomDataset's range.
transformations = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(image_size),
    transforms.Normalize((0.5,), (0.5,))
])

dataset = DicomDataset("../slike/", transform=transformations)

# 80% train / 15% validation / remainder test; the test split absorbs rounding.
train_size = int(0.8 * len(dataset))
validation_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - validation_size

# Fixed generator seed so the split is identical across kernel restarts.
train_set, validation_set, test_set = random_split(
    dataset,
    [train_size, validation_size, test_size],
    generator=torch.Generator().manual_seed(42),
)

# Only the training loader shuffles. Evaluation loaders keep a fixed order so the
# reported validation/test metrics are reproducible run to run.
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(dataset=validation_set, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)

In [11]:
# Pick the compute device: the first CUDA GPU when one is visible, otherwise the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
device

device(type='cuda', index=0)

In [12]:
# Build the network and move its parameters to the selected device.
# nn.Module.to returns the module itself, so the chained call yields the same object.
model = Net().to(device)
model

Net(
  (pool): AvgPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0)
  (upsamp): Upsample(scale_factor=2.0, mode=bilinear)
  (leaky_relu): LeakyReLU(negative_slope=0.01)
  (conv64): Sequential(
    (0): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.01)
    (2): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): LeakyReLU(negative_slope=0.01)
    (4): Conv2d(2, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): LeakyReLU(negative_slope=0.01)
  )
  (conv128): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.01)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): LeakyReLU(negative_slope=0.01)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): LeakyReLU(negative_slope=0.01)
  )
  (upsample1): Sequential(
    (0): Upsample(scale_factor=2.0, mode=bilinear)
    (1)

In [13]:
# Loss and optimizer. The SSIM module produces a similarity score (higher is better);
# the training cell turns it into a loss by minimizing 1 - SSIM.
criterion_ssim = SSIM()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [14]:
#Training and validation
# Each epoch runs one optimization pass over train_loader, then one gradient-free pass
# over validation_loader. The loss is 1 - SSIM, so lower is better.
# Epoch averages are weighted by batch size. criterion_ssim must return a scalar per
# batch (loss.item() requires it), presumably a batch mean — TODO confirm. Weighting
# gives the true per-sample mean instead of letting a short final batch skew the result.
for epoch in range(num_of_epochs):
    start_time = time.perf_counter()  # monotonic clock, suited to measuring durations

    # --- Training pass ---
    model.train()
    training_losses = []
    training_batch_sizes = []
    for prev_img, next_img, expcted_img in train_loader:
        prev_img = prev_img.to(device)
        next_img = next_img.to(device)
        expcted_img = expcted_img.to(device)

        optimizer.zero_grad()
        output = model(prev_img, next_img)
        loss = 1 - criterion_ssim(output, expcted_img)
        loss.backward()
        optimizer.step()

        training_losses.append(loss.item())
        training_batch_sizes.append(expcted_img.size(0))
    # NaN (without a numpy warning) if the loader yielded no batches.
    avg_training_losses = (
        np.average(training_losses, weights=training_batch_sizes)
        if training_losses else float("nan")
    )
    print("Epoch {} Training Completed: Train Avg. SSIM Loss: {:.4f}".format(epoch+1, avg_training_losses))

    # --- Validation pass: eval mode, no gradient tracking ---
    model.eval()
    val_losses = []
    val_batch_sizes = []
    with torch.no_grad():
        for prev_img, next_img, expcted_img in validation_loader:
            prev_img = prev_img.to(device)
            next_img = next_img.to(device)
            expcted_img = expcted_img.to(device)

            output = model(prev_img, next_img)
            loss = 1 - criterion_ssim(output, expcted_img)

            val_losses.append(loss.item())
            val_batch_sizes.append(expcted_img.size(0))
    avg_val_losses = (
        np.average(val_losses, weights=val_batch_sizes)
        if val_losses else float("nan")
    )
    print("Epoch {} Validation Completed: Validation Avg. Loss: {:.4f}".format(epoch+1, avg_val_losses))

    epoch_seconds = time.perf_counter() - start_time
    print("This epoch took {:.2f} seconds to complete".format(epoch_seconds))


Epoch 1 Training Completed: Train Avg. SSIM Loss: 0.3629
Epoch 1 Validation Completed: Validation Avg. Loss: 0.2406
This epoch took 5.63 seconds to complete
Epoch 2 Training Completed: Train Avg. SSIM Loss: 0.1990
Epoch 2 Validation Completed: Validation Avg. Loss: 0.1587
This epoch took 5.67 seconds to complete
Epoch 3 Training Completed: Train Avg. SSIM Loss: 0.1386
Epoch 3 Validation Completed: Validation Avg. Loss: 0.1267
This epoch took 5.65 seconds to complete
Epoch 4 Training Completed: Train Avg. SSIM Loss: 0.1178
Epoch 4 Validation Completed: Validation Avg. Loss: 0.1145
This epoch took 5.65 seconds to complete
Epoch 5 Training Completed: Train Avg. SSIM Loss: 0.1037
Epoch 5 Validation Completed: Validation Avg. Loss: 0.1013
This epoch took 5.65 seconds to complete
Epoch 6 Training Completed: Train Avg. SSIM Loss: 0.0966
Epoch 6 Validation Completed: Validation Avg. Loss: 0.0900
This epoch took 5.64 seconds to complete
Epoch 7 Training Completed: Train Avg. SSIM Loss: 0.0876
E

Epoch 53 Training Completed: Train Avg. SSIM Loss: 0.0567
Epoch 53 Validation Completed: Validation Avg. Loss: 0.0565
This epoch took 5.70 seconds to complete
Epoch 54 Training Completed: Train Avg. SSIM Loss: 0.0573
Epoch 54 Validation Completed: Validation Avg. Loss: 0.0571
This epoch took 5.71 seconds to complete
Epoch 55 Training Completed: Train Avg. SSIM Loss: 0.0576
Epoch 55 Validation Completed: Validation Avg. Loss: 0.0570
This epoch took 5.71 seconds to complete
Epoch 56 Training Completed: Train Avg. SSIM Loss: 0.0580
Epoch 56 Validation Completed: Validation Avg. Loss: 0.0589
This epoch took 5.72 seconds to complete
Epoch 57 Training Completed: Train Avg. SSIM Loss: 0.0566
Epoch 57 Validation Completed: Validation Avg. Loss: 0.0569
This epoch took 5.71 seconds to complete
Epoch 58 Training Completed: Train Avg. SSIM Loss: 0.0564
Epoch 58 Validation Completed: Validation Avg. Loss: 0.0564
This epoch took 5.72 seconds to complete
Epoch 59 Training Completed: Train Avg. SSIM L

In [15]:
# Persist the trained weights. Only the state_dict (parameters and buffers) is written,
# so reloading means constructing Net() and calling load_state_dict on it.
# NOTE(review): these are the weights after the final epoch, not necessarily the epoch
# with the lowest validation loss.
torch.save(obj=model.state_dict(), f="model.pt")
print("Saving model")


Saving model
