In [None]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.optim as optim
from torch.autograd import Variable
from pytorchtools import EarlyStopping
import skimage.io as io
import numpy as np
import matplotlib.pyplot as plt
from torchvision import models
import torchvision
import warnings

# Silence every Python warning for the rest of the kernel session.
# NOTE(review): this blanket filter also hides PyTorch UserWarnings/deprecation
# warnings that would otherwise point at real problems (e.g. loss input/target
# size mismatches) — consider narrowing it to specific categories/modules.
warnings.filterwarnings("ignore")
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Name of the result set to work on. It is used as the prefix of the
# train/test/val list files ('./<version>_*_list_whole.txt') and in the
# EarlyStopping checkpoint name ('rebuild_<version>').
# Other result sets that have been used here (swap in to switch):
#   'NSST_PAPCNN_result', 'my_result_post', 'GF_result',
#   'cnn_lp_result', 'NSCT-RPCNN_result'
version = 'my_result'

In [None]:
from unet_parts import *

class UNet(nn.Module):
    """U-Net with four down-sampling and four up-sampling stages and a sigmoid head.

    The encoder is ``inc`` followed by ``down1``..``down4``. The decoder is
    ``up1``..``up4`` followed by ``outc``. Each ``up`` stage also receives the
    encoder feature map of matching depth as a skip connection. All building
    blocks (``inconv``, ``down``, ``up``, ``outconv``) come from ``unet_parts``.

    Args:
        n_channels: number of input channels, passed to ``inconv``.
        n_classes: number of output channels, passed to ``outconv``.

    NOTE(review): the ``up`` input widths (512, 256, 128, 64) assume ``up``
    concatenates its two inputs along the channel axis (e.g. 256 + 256 for
    ``up1``). Confirm against ``unet_parts``.
    """

    def __init__(self, n_channels, n_classes):
        super(UNet, self).__init__()
        self.inc = inconv(n_channels, 32)
        self.down1 = down(32, 64)
        self.down2 = down(64, 128)
        self.down3 = down(128, 256)
        self.down4 = down(256, 256)
        self.up1 = up(512, 128)
        self.up2 = up(256, 64)
        self.up3 = up(128, 32)
        self.up4 = up(64, 32)
        self.outc = outconv(32, n_classes)

    def forward(self, x):
        """Run the encoder/decoder and squash the output into (0, 1) with a sigmoid.

        Args:
            x: input batch. Presumably (B, n_channels, H, W); TODO confirm
               against ``inconv``.

        Returns:
            ``torch.sigmoid`` of the ``outc`` output (n_classes channels).
        """
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        # Decoder: each stage fuses the upsampled features with the skip input.
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        # torch.sigmoid replaces the deprecated F.sigmoid alias (same values).
        return torch.sigmoid(x)

In [None]:
class MyDataset(Dataset):
    """Dataset of (fusion result, CT, MR) image triples listed in a text file.

    Each non-blank line of ``txt`` must hold at least three whitespace-separated
    paths, in order: fusion result, CT image, MR image. Extra fields are
    ignored and blank lines are skipped.

    Args:
        txt: path to the list file.
        transform: optional callable applied to each of the three images
            returned by ``__getitem__``.
        target_transform: stored on the instance; not used by this class.

    Raises:
        ValueError: if a non-blank line has fewer than three fields.
    """

    def __init__(self, txt, transform = None, target_transform = None):
        imgs = []
        # ``with`` closes the list file once it has been read.
        with open(txt, 'r') as lists:
            for lineno, line in enumerate(lists, 1):
                # split() with no argument already discards surrounding whitespace/newlines.
                words = line.split()
                if not words:
                    # Blank lines (e.g. a trailing empty line) carry no sample.
                    continue
                if len(words) < 3:
                    raise ValueError(
                        '{}:{}: expected 3 paths (fusion, ct, mr), got {}'.format(
                            txt, lineno, len(words)))
                imgs.append((words[0], words[1], words[2]))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Load the triple at ``index``.

        Returns ``(fusion_result, ct, mr)``. Each image is read with
        ``io.imread`` and passed through ``transform`` when one is set.
        """
        fusion_result, ct, mr = self.imgs[index]
        img_fr = io.imread(fusion_result)
        img_ct = io.imread(ct)
        img_mr = io.imread(mr)
        if self.transform is not None:
            img_fr = self.transform(img_fr)
            img_ct = self.transform(img_ct)
            img_mr = self.transform(img_mr)
        return img_fr, img_ct, img_mr

    def __len__(self):
        """Number of triples read from the list file."""
        return len(self.imgs)

In [None]:
# Training split: one sample per batch, order reshuffled every epoch.
train_data = MyDataset(
    txt='./{}_train_list_whole.txt'.format(version),
    transform=transforms.ToTensor(),
)
trainloader = DataLoader(train_data, batch_size=1, shuffle=True, num_workers=2)

In [None]:
# Test split, one sample per batch. Evaluation only needs per-sample losses,
# so shuffling is unnecessary; shuffle=False keeps the iteration order
# deterministic and reproducible.
test_data = MyDataset(txt = './{}_test_list_whole.txt'.format(version), transform = transforms.ToTensor())
testloader = torch.utils.data.DataLoader(test_data, batch_size = 1,
                                          shuffle = False, num_workers = 2)

In [None]:
# Validation split used for early stopping, one sample per batch.
# Only the average loss is used, so shuffling adds nothing; shuffle=False keeps
# the iteration order deterministic.
val_data = MyDataset(txt = './{}_val_list_whole.txt'.format(version), transform = transforms.ToTensor())
valloader = torch.utils.data.DataLoader(val_data, batch_size = 1,
                                          shuffle = False, num_workers = 2)

In [None]:
# One input channel in; two output channels out. Channel 0 is compared
# against CT and channel 1 against MR in the loss. Module.to returns the
# module itself, so the device move can be chained.
net = UNet(1, 2).to(device)
# Pixel-wise mean-squared reconstruction error.
criterion = nn.MSELoss()
# Plain SGD with momentum.
optimizer = optim.SGD(net.parameters(), momentum = 0.9, lr = 0.05)

In [None]:
def _reconstruction_loss(outputs, ct, mr):
    """Return MSE(output channel 0, ct) + MSE(output channel 1, mr), using ``criterion``.

    The ``0:1`` / ``1:2`` slices keep the channel axis, so each prediction has
    the same rank as its target. Integer indexing (``outputs[:, 0]``) drops
    that axis. Assuming ``ct``/``mr`` are single-channel (B, 1, H, W) tensors
    from ``ToTensor`` (TODO confirm the images are grayscale), ``MSELoss`` would
    then broadcast (B, H, W) against (B, 1, H, W). That is harmless at
    batch_size=1 but pairs the wrong samples for larger batches.
    """
    return criterion(outputs[:, 0:1], ct) + criterion(outputs[:, 1:2], mr)


# Per-sample losses for the current epoch (reset after each epoch) and
# per-epoch averages accumulated over the whole run.
train_losses = []
valid_losses = []
avg_train_losses = []
avg_valid_losses = []
early_stopping = EarlyStopping(save_name = 'rebuild_' + version, patience = 8, verbose = True)
n_epochs = 200
# Width used to right-align epoch numbers in the progress line.
epoch_len = len(str(n_epochs))
for epoch in range(n_epochs):
    # ---- training pass ----
    net.train()
    for inputs, ct, mr in trainloader:
        inputs, ct, mr = inputs.to(device), ct.to(device), mr.to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = _reconstruction_loss(outputs, ct, mr)
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())

    # ---- validation pass ----
    # no_grad: validation only reads the loss value, so skip building autograd graphs.
    net.eval()
    with torch.no_grad():
        for inputs, ct, mr in valloader:
            inputs, ct, mr = inputs.to(device), ct.to(device), mr.to(device)
            outputs = net(inputs)
            loss = _reconstruction_loss(outputs, ct, mr)
            valid_losses.append(loss.item())

    train_loss = np.average(train_losses)
    valid_loss = np.average(valid_losses)
    avg_train_losses.append(train_loss)
    avg_valid_losses.append(valid_loss)

    # Report 1-based epoch numbers so the last epoch prints as [n_epochs/n_epochs].
    print_msg = (f'[{epoch + 1:>{epoch_len}}/{n_epochs:>{epoch_len}}] ' +
                 f'train_loss: {train_loss:.5f} ' +
                 f'valid_loss: {valid_loss:.5f}')
    print(print_msg)

    # Start the next epoch with empty per-sample loss lists.
    train_losses = []
    valid_losses = []

    # EarlyStopping tracks the validation loss, checkpoints the model and sets
    # early_stop once it stops improving for `patience` epochs.
    # NOTE(review): behaviour inferred from pytorchtools usage; confirm there.
    early_stopping(valid_loss, net)

    if early_stopping.early_stop:
        print("Early stopping")
        break

print('Finished Training')

In [None]:
# Restore the weights checkpointed during training.
# NOTE(review): assumes EarlyStopping saved them as 'rebuild_<version>_checkpoint.pt';
# confirm the naming in pytorchtools.
# map_location remaps tensors onto the current device, so a GPU-saved
# checkpoint also loads on a CPU-only kernel.
net.load_state_dict(torch.load('rebuild_{}_checkpoint.pt'.format(version), map_location = device))

In [None]:
# Evaluate the restored model on the test split and report the mean and
# standard deviation of the per-sample reconstruction loss.
test_losses = []
net.eval()
# Inference only: no_grad avoids building autograd graphs.
with torch.no_grad():
    for inputs, ct, mr in testloader:
        inputs, ct, mr = inputs.to(device), ct.to(device), mr.to(device)
        outputs = net(inputs)
        # 0:1 / 1:2 keep the channel axis so each prediction matches its
        # target's shape instead of relying on MSELoss broadcasting.
        loss = criterion(outputs[:, 0:1], ct) + criterion(outputs[:, 1:2], mr)
        test_losses.append(loss.item())

test_loss = np.average(test_losses)

print(test_loss, np.std(test_losses))