In [None]:
import torch
import torch.nn as nn
import torchvision.datasets
import torchvision.transforms as transforms
from torch.utils.data import TensorDataset, DataLoader, Dataset
from torch.utils.data.dataset import random_split
from torchvision.transforms.transforms import ToPILImage
from skimage import io

import numpy as np
import matplotlib.pyplot as plt

import pandas as pd
import os

In [None]:
# Select the compute device. Without the CPU fallback `device` would be undefined
# on machines without CUDA and every later `.to(device)` would raise NameError.
if torch.cuda.is_available():
    device = "cuda"
    torch.cuda.empty_cache()
    print(torch.cuda.memory_summary(device=None, abbreviated=False))
else:
    device = "cpu"

In [None]:
# Index of samples: one row per file name found in the undersampled-image directory.
# os.listdir returns entries in arbitrary order, so sort them to make the row order
# (and therefore any seeded train/test split built on it) reproducible.
source_dir = 'D:\\heart_data\\undersampled_heart_images'
list_dir = sorted(os.listdir(source_dir))
list_pd = list(list_dir)
df = pd.DataFrame(list_pd)
df

In [None]:
class CustomDataTransform(Dataset):
    """Paired image dataset: undersampled input image -> full target image.

    Column 0 of ``df`` holds a file name. That name is looked up in both
    ``root_dir_x`` (input/feature image) and ``root_dir_y`` (target/label image).

    Args:
        df: DataFrame whose first column contains the image file names.
        features_transform: optional callable applied to the input image.
        label_transform: optional callable applied to the target image.
        root_dir_x: directory holding the input images.
        root_dir_y: directory holding the target images.
    """

    def __init__(self, df, features_transform=None, label_transform=None,
                 root_dir_x='D:\\heart_data\\undersampled_heart_images',
                 root_dir_y='D:\\heart_data\\heart_images'):
        self.df = df
        self.features_transform = features_transform
        self.label_transform = label_transform
        self.root_dir_x = root_dir_x
        self.root_dir_y = root_dir_y

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Return the ``(input_image, target_image)`` pair for row ``index``."""
        file_name = self.df.iloc[index, 0]
        image_x = io.imread(os.path.join(self.root_dir_x, file_name))
        image_y = io.imread(os.path.join(self.root_dir_y, file_name))

        # Random augmentations (e.g. flips) must be identical for input and target,
        # otherwise the pair no longer lines up pixel-for-pixel. Replay the same
        # torch RNG state for both transforms so their random draws match.
        # NOTE: assumes both transforms draw from torch's global RNG in the same
        # order, which holds for structurally identical pipelines.
        rng_state = torch.get_rng_state()
        if self.features_transform is not None:
            image_x = self.features_transform(image_x)

        torch.set_rng_state(rng_state)
        if self.label_transform is not None:
            image_y = self.label_transform(image_y)

        return (image_x, image_y)

def _make_image_transform():
    """Return a fresh pipeline: to PIL -> resize 128x128 -> random H/V flips (p=0.5) -> tensor."""
    return transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((128, 128)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.ToTensor(),
        # transforms.Normalize(mean=[0.5], std=[0.1]),  # currently disabled
    ])


# Inputs and targets use structurally identical (but separate) pipelines.
x_transform = _make_image_transform()
y_transform = _make_image_transform()

dataset = CustomDataTransform(df, features_transform=x_transform,
                                  label_transform=y_transform)

batch_size = 8
part = 0.8        # fraction of samples used for training
SPLIT_SEED = 42   # fixed seed so the train/test split is reproducible across runs
train_lenght = int(len(dataset) * part)
test_lenght = len(dataset) - train_lenght

train_set, test_set = random_split(
    dataset,
    [train_lenght, test_lenght],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
# NOTE(review): test_set shares `dataset`, so the random flips are also applied at evaluation time.
train_loader = DataLoader(train_set, batch_size=batch_size, drop_last=False, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, drop_last=False, shuffle=True)

In [None]:
# Sanity check: dataset size, then the target image (channel 0) of the first training sample.
print('Length of dataset is {}'.format(len(dataset)))
first_target = train_set[0][1]
plt.imshow(first_target[0], cmap='gray')

In [None]:
class Block(nn.Module):
    """One U-Net stage: (strided conv | transposed conv) -> BatchNorm -> activation [-> dropout].

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        down: True builds a kernel-3, stride-2 Conv2d with reflect padding
            (H -> ceil(H/2)). False builds a kernel-3, stride-2 ConvTranspose2d
            (H -> 2H-1, since no output_padding is used).
        act: "relu" selects ReLU; any other value selects LeakyReLU(0.2).
        use_dropout: when True, Dropout(0.5) is applied to the stage output.
    """

    def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False):
        super().__init__()
        shared = dict(kernel_size=3, stride=2, padding=1, bias=False)
        if down:
            resample = nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **shared)
        else:
            resample = nn.ConvTranspose2d(in_channels, out_channels, **shared)
        activation = nn.ReLU() if act == "relu" else nn.LeakyReLU(0.2)
        self.conv = nn.Sequential(resample, nn.BatchNorm2d(out_channels), activation)

        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(0.5)  # always registered; only applied when use_dropout is True
        self.down = down

    def forward(self, x):
        out = self.conv(x)
        if self.use_dropout:
            out = self.dropout(out)
        return out

class Generator(nn.Module):
    """U-Net style encoder/decoder for image-to-image reconstruction.

    Six stride-2 encoder Blocks feed a bottleneck. Seven decoder Blocks then run;
    each one after the first takes the previous decoder output concatenated
    (on dim 1) with the matching encoder feature map. For example, a
    (N, 1, 128, 128) input gives spatial sizes 129 -> 65 -> 33 -> 17 -> 9 -> 5 -> 3
    -> 2 (bottleneck) and back up to 129. final_up lifts that to 258 and
    ``reeucing_conv`` reduces it to a (N, 1, 128, 128) output.

    Args:
        in_channels: channels of the input image.
        features: base channel width; deeper stages use 2x, 4x and 8x this value.
    """

    def __init__(self, in_channels=1, features=64):
        super().__init__()

        def _down_block(c_in, c_out):
            return Block(c_in, c_out, down=True, act="leaky", use_dropout=False)

        def _up_block(c_in, c_out):
            return Block(c_in, c_out, down=False, act="relu", use_dropout=False)

        f1, f2, f4, f8 = features, features * 2, features * 4, features * 8

        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, f1, kernel_size=2, stride=1, padding=1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )
        self.down1 = _down_block(f1, f2)
        self.down2 = _down_block(f2, f4)
        self.down3 = _down_block(f4, f8)
        self.down4 = _down_block(f8, f8)
        self.down5 = _down_block(f8, f8)
        self.down6 = _down_block(f8, f8)
        self.bottleneck = nn.Sequential(
            nn.Conv2d(f8, f8, kernel_size=4, stride=1, padding=1),
            nn.ReLU(),
        )

        # From up2 onward the input width doubles because of the skip concatenation.
        self.up1 = _up_block(f8, f8)
        self.up2 = _up_block(f8 * 2, f8)
        self.up3 = _up_block(f8 * 2, f8)
        self.up4 = _up_block(f8 * 2, f8)
        self.up5 = _up_block(f8 * 2, f4)
        self.up6 = _up_block(f4 * 2, f2)
        self.up7 = _up_block(f2 * 2, f1)

        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(f2, f2, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
        )
        # Attribute name (typo included) is part of the state_dict keys; keep it for checkpoint compatibility.
        self.reeucing_conv = nn.Sequential(
            nn.Conv2d(f2, 1, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
        )

    def forward(self, x):
        # skips[0] is d1 (initial_down output); skips[1..6] are d2..d7 (encoder outputs).
        skips = [self.initial_down(x)]
        for stage in (self.down1, self.down2, self.down3, self.down4, self.down5, self.down6):
            skips.append(stage(skips[-1]))

        out = self.up1(self.bottleneck(skips[-1]))
        decoders = (self.up2, self.up3, self.up4, self.up5, self.up6, self.up7)
        # Pair up2..up7 with d7..d2 respectively.
        for stage, skip in zip(decoders, reversed(skips[1:])):
            out = stage(torch.cat([out, skip], 1))

        out = self.final_up(torch.cat([out, skips[0]], 1))
        return self.reeucing_conv(out)

In [None]:
def init_all(model, init_func, *params, **kwargs):
    """Call ``init_func(tensor, *params, **kwargs)`` on every parameter of ``model``, in place.

    Parameters are visited in ``model.named_parameters()`` order, which is the
    same order as ``model.parameters()``.
    """
    for _name, tensor in model.named_parameters():
        init_func(tensor, *params, **kwargs)

net = Generator(in_channels=1, features=64).to(device)
init_all(net, torch.nn.init.normal_, mean=0.0, std=0.001)

# The blanket N(0, 0.001) init above also overwrote the BatchNorm affine parameters,
# leaving gamma ~ 0 with random sign. That scales every normalized activation toward
# zero and flips some channels. Restore PyTorch's BatchNorm defaults
# (weight=1, bias=0, fresh running stats).
for module in net.modules():
    if isinstance(module, nn.BatchNorm2d):
        module.reset_parameters()

In [None]:
# NOTE(review): `optimizer` was used here and in the training loop but never defined
# in this notebook, so "Restart & Run All" failed with NameError (it only worked via
# leftover kernel state). Adam with LEARNING_RATE is a placeholder: confirm it matches
# the optimizer the model was actually trained with.
LEARNING_RATE = 1e-3
optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)

Loss_function = nn.MSELoss()  # nn.L1Loss() is a drop-in alternative

# Multiply the LR by 0.9 once the monitored loss has not improved (relative threshold 1e-4)
# for more than 200 scheduler steps. `verbose` was dropped: it was passed as the string
# 'True' and is deprecated in recent PyTorch. The training loop prints the LR every epoch.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                                       mode='min',
                                                       factor=0.9,
                                                       patience=200,
                                                       threshold=1e-4)

In [None]:
def Loss_1(input_prediction, input_target, Loss_function):
    """Evaluate ``Loss_function(prediction, target)`` and hand its result back unchanged."""
    loss_value = Loss_function(input_prediction, input_target)
    return loss_value

In [None]:
def train_step(train_loader, net, Loss_1, train_history, optimization_step, optimizer, Loss_function):
    """Run one training epoch over every batch of ``train_loader``.

    Puts ``net`` in train mode and moves each batch to the device of ``net``'s
    parameters (CPU if it has none). The loss is computed as
    ``Loss_1(prediction, target, Loss_function)`` and
    ``optimization_step(net, optimizer, loss)`` applies the update.

    Returns:
        0-dim tensor (detached): mean of the per-batch losses. Its float value is
        appended to ``train_history`` once per call, i.e. one entry per epoch.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    first_param = next(net.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    net.train()
    batch_losses = []
    for feature, target in train_loader:
        target = target.to(model_device)
        feature = feature.to(model_device)
        prediction = net(feature)
        Loss = Loss_1(prediction, target, Loss_function)

        optimization_step(net, optimizer, Loss)
        batch_losses.append(Loss.detach())

    # Return only after the loop: returning inside it would train on the first batch only.
    if not batch_losses:
        raise ValueError("train_loader yielded no batches")

    mean_loss = torch.stack(batch_losses).mean()
    train_history.append(mean_loss.item())
    return mean_loss

def valid_step(test_loader, net, Loss_1, valid_history, Loss_function):
    """Evaluate ``net`` on every batch of ``test_loader``.

    Runs in eval mode under ``torch.no_grad()``, so BatchNorm running statistics
    are not updated by validation data and no autograd graph is built. The
    model's previous train/eval mode is restored afterwards, even on error.
    Batches are moved to the device of ``net``'s parameters (CPU if it has none).

    Returns:
        (mean_loss, prediction, target): the mean of the per-batch losses as a
        0-dim tensor, plus the prediction and target tensors of the first batch
        (used for visualization). ``float(mean_loss)`` is appended to
        ``valid_history`` once per call.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    first_param = next(net.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = net.training
    net.eval()
    batch_losses = []
    first_prediction = first_target = None
    try:
        with torch.no_grad():
            for feature, target in test_loader:
                target = target.to(model_device)
                feature = feature.to(model_device)
                prediction = net(feature)
                batch_losses.append(Loss_1(prediction, target, Loss_function))
                if first_prediction is None:
                    first_prediction, first_target = prediction, target
    finally:
        net.train(was_training)

    if not batch_losses:
        raise ValueError("test_loader yielded no batches")

    mean_loss = torch.stack(batch_losses).mean()
    valid_history.append(mean_loss.item())
    return mean_loss, first_prediction, first_target

def optimization_step(net, optimizer, Loss_train_step):
    """Perform one update: clear ``net``'s gradients, backpropagate the loss, step the optimizer."""
    update_sequence = (
        net.zero_grad,              # drop gradients left over from the previous batch
        Loss_train_step.backward,   # populate .grad for this batch
        optimizer.step,             # apply the parameter update
    )
    for action in update_sequence:
        action()

#### SSIM Assessment
import skimage
from skimage.metrics import structural_similarity as ssim
#### PSNR Assessment
from skimage.metrics import peak_signal_noise_ratio

def SSIM_PSNR_Metrcis_step(prediction, target, list_psnr, list_ssim, iteration, aver_list_psnr, aver_list_ssim,
                           data_range=1.0):
    """Compute per-image SSIM and PSNR for a batch and record them.

    Channel 0 of each sample (``tensor[i][0]``) is compared.

    Args:
        prediction: batched network output tensor.
        target: batched ground-truth tensor, same layout as ``prediction``.
        list_psnr, list_ssim: receive one score per image.
        iteration: epoch index (currently unused; kept for interface compatibility).
        aver_list_psnr, aver_list_ssim: receive the batch-mean score.
        data_range: intensity range passed to both metrics. Defaults to 1.0, which
            assumes images are scaled to [0, 1] (as ``ToTensor`` does for 8-bit input).
            Recent scikit-image versions reject float images without an explicit
            data_range in SSIM.
    """
    batch_size = target.shape[0]
    if batch_size == 0:
        return  # nothing to score; avoids dividing by zero below

    # Move the whole batch to CPU/numpy once instead of once per image.
    pred_np = prediction.detach().cpu().numpy()
    target_np = target.detach().cpu().numpy()

    sum_ssim_score = 0.0
    sum_psnr_score = 0.0
    for i in range(batch_size):
        pred_img = pred_np[i][0]
        target_img = target_np[i][0]
        ssim_score = ssim(pred_img, target_img, data_range=data_range)
        psnr_score = peak_signal_noise_ratio(image_true=target_img, image_test=pred_img, data_range=data_range)
        list_psnr.append(psnr_score)
        list_ssim.append(ssim_score)
        sum_ssim_score += ssim_score
        sum_psnr_score += psnr_score
    aver_list_psnr.append(sum_psnr_score / batch_size)
    aver_list_ssim.append(sum_ssim_score / batch_size)

In [None]:
#it's useful for learning network
epoche = 1200

train_history = []
valid_history = []

list_psnr = []
list_ssim = []
aver_list_psnr = []
aver_list_ssim = []

for iteration in range(epoche):

    #train and valid step
    Loss_train_step = train_step(train_loader, net, Loss_1, train_history, optimization_step, optimizer, Loss_function)
    Loss_valid_step, prediction, target = valid_step(test_loader, net, Loss_1, valid_history, Loss_function)

    #Reducing step of lr
    scheduler.step(Loss_train_step)
    print(scheduler.optimizer.param_groups[0]['lr'])

    #image metrics
    #SSIM_PSNR_Metrcis_step(prediction, target, list_psnr, 
    #                       list_ssim, iteration,
    #                       aver_list_psnr,
    #                       aver_list_ssim)

    print(f'Epoche №: {iteration}')
    
    #Visualization
    img_test = prediction.cpu().detach()[0][0]
    img_target = target.cpu().detach()[0][0]

    fig = plt.figure(figsize=(10,10))
    fig.add_subplot(1, 2, 1)
    plt.imshow(img_test, cmap='gray')
    plt.title("Test image")

    fig.add_subplot(1, 2, 2)
    plt.imshow(img_target, cmap='gray')
    plt.title("Target image")
    plt.show()

    print(f'Loss_train_step per batch: {Loss_train_step}\nLoss_valid_step per batch: {Loss_valid_step}')

In [None]:
# Number of samples in each split: test first, then train.
for split_subset in (test_set, train_set):
    print(len(split_subset))

In [None]:
# Train and validation loss curves (one point per entry in each history list).
fig, ax = plt.subplots(figsize=(12, 6))

# Guard against empty histories (e.g. the training cell has not been run yet).
train_label = 'Train Loss: ' + str(round(train_history[-1], 3)) if train_history else 'Train Loss'
valid_label = 'Test Loss: ' + str(round(valid_history[-1], 3)) if valid_history else 'Test Loss'

ax.plot(train_history, label=train_label)
ax.plot(valid_history, label=valid_label)
ax.legend(fontsize=15)
ax.grid()
ax.set_title("Error History", fontsize=20, fontweight='bold')
ax.set_xlabel("Epoches", fontsize=20)
ax.set_ylabel("Error", fontsize=20)
# Larger tick labels for this figure only. A global plt.rcParams.update after drawing
# barely affects this figure but silently changes every later one (hidden state on re-run).
ax.tick_params(labelsize=18)
plt.show()

In [None]:
# Saving model
PATH = '/content/gdrive/MyDrive/Medical_Imaging/DiplomaReconstructionImage/LearnedNetwork/UNetRec_Mode_Extended_Loss_v1.pt'
torch.save(net.state_dict(), PATH)