## Imports

In [1]:
import os
import pickle
import json
import random
import logging
import numpy as np
from itertools import chain
import torch
import torch.nn as nn
import torchvision
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchio
from tqdm import tqdm
import sys
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from torch.optim import Adam
from torch import nn, optim
from torch.optim.lr_scheduler import StepLR
import torch.distributions as dist
import math
import import_ipynb

If you use TorchIO for your research, please cite the following paper:
Pérez-García et al., TorchIO: a Python library for efficient loading,
preprocessing, augmentation and patch-based sampling of medical images
in deep learning. Link: https://arxiv.org/abs/2003.04696



## Data Location

In [2]:
# Dataset / output locations. The defaults are the original machine-specific
# W:\12cp paths; set the MOOD_* environment variables to run elsewhere without
# editing the notebook.
training_data = os.environ.get("MOOD_TRAIN_DATA", r"W:\12cp\data\mood.h5")
validation_data = os.environ.get("MOOD_VAL_DATA", r"W:\12cp\data\MOOD_toytest_brain.h5")
log_path = os.environ.get("MOOD_LOG_DIR", r"W:\12cp\log")
save_path = os.environ.get("MOOD_SAVE_DIR", r"W:\12cp\save_dir")

## Training Parameters

In [3]:
# NOTE(review): torch.cuda.set_device() selects the current CUDA device and returns
# None, so `device` is presumably None here and later `.to(device)` calls would not
# move tensors/modules to the GPU. Consider `torch.device("cuda:0")`, but first check
# that helpers which create tensors themselves (e.g. the random matrices in
# pca_error / proj_error) place them on the same device as their inputs.
device = torch.cuda.set_device(0)
# Optimisation settings.
batch_size = 2
num_workers = 0        # DataLoader worker processes (0 = load in the main process)
trainID = "RESAES"     # run name: TensorBoard sub-directory and checkpoint/log file stem
learning_rate = 1e-4
num_epochs = 500

In [4]:
preload_h5 = True              # forwarded as `preload` to MoodTrainSet / MoodValSet
indicesOfImgVols = [1, 2]      # supply a list of indices if only a subset of volumes is desired
patch_size = (256, 256, 1)     # patch shape for the torchio sampler; set to None to disable patching
patchQ_len = 512               # `max_length` of each torchio patch queue
patches_per_volume = 256       # `samples_per_volume` of each torchio patch queue

In [5]:
do_val = True  # build a validation DataLoader and run validation after every epoch

## Network Parameters

In [6]:
# Network / loss configuration.
input_shape = (256, 256, 256)   # overwritten by the non-singleton patch dims when patch_size is set
encode_features = 128
linear_op = True                # passed to Aes: apply the encoder's linear projection
normalize = True                # passed to Aes: L2-normalise the latent code
if_rsr = True                   # include pca_error / proj_error in the training loss
enforce_proj = True             # with all_alt, take an extra optimizer2 step on proj_error
all_alt = False                 # optimise the auxiliary errors in separate alternating steps
lambda1 = 0.1                   # weight of pca_error in the combined loss
lambda2 = 0.1                   # weight of proj_error in the combined loss

## Info (data region and CUDA flag)

In [7]:
mood_region = 'brain'  # region forwarded to MoodTrainSet
useCuda = True

## Data Loader

In [8]:
from Data import MoodTrainSet, MoodValSet

importing Jupyter notebook from Data.ipynb


In [9]:
# Patches are extracted lazily by the torchio queue whenever a patch size is configured.
lazy_patching = bool(patch_size)
trainset = MoodTrainSet(indices=indicesOfImgVols, region=mood_region, data_path=training_data, lazypatch=lazy_patching, preload=preload_h5)
valset = MoodValSet(data_path=validation_data, lazypatch=lazy_patching, preload=preload_h5)

def _patch_queue(subjects):
  """Wrap a subjects dataset in a torchio Queue that draws uniformly sampled patches."""
  return torchio.data.Queue(
      subjects_dataset=subjects,
      max_length=patchQ_len,
      samples_per_volume=patches_per_volume,
      sampler=torchio.data.UniformSampler(patch_size=patch_size),
  )

if patch_size:
  # The network sees 2-D patches: drop singleton axes from the patch shape.
  input_shape = tuple(dim for dim in patch_size if dim != 1)
  trainset = _patch_queue(trainset)
  valset = _patch_queue(valset)

train_loader = DataLoader(dataset=trainset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
if do_val and valset is not None:
  val_loader = DataLoader(dataset=valset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
else:
  val_loader = None

Preloading MoodTrainSet
Preloading MoodValSet


## Model Blocks

In [10]:
def conv_block(input_channels, output_channels, kernel_size, stride):
  """Unpadded, bias-free Conv2d followed by BatchNorm2d and LeakyReLU(0.2)."""
  layers = [
      nn.Conv2d(input_channels, output_channels, kernel_size, stride, bias=False),
      nn.BatchNorm2d(output_channels),
      nn.LeakyReLU(0.2, inplace=True),
  ]
  return nn.Sequential(*layers)
  
def linear_enc(input_channels, output_channels):
  """Flatten -> Linear -> BatchNorm1d -> LeakyReLU(0.2): project feature maps to vectors.

  ``input_channels`` is the flattened per-sample feature size, ``output_channels``
  the size of the produced vector.
  """
  flatten = nn.Flatten()
  project = nn.Linear(in_features=input_channels, out_features=output_channels)
  norm = nn.BatchNorm1d(num_features=output_channels)
  activation = nn.LeakyReLU(0.2, inplace=True)
  return nn.Sequential(flatten, project, norm, activation)

class Encoder(nn.Module):
    """Three strided conv blocks followed by an optional linear projection.

    Args:
        no_channels: number of input image channels.
        filter_size: channel count of the first conv block; doubled by each later block.
        latent_size: length of the vector produced by the linear projection.
        linear_op: if False the linear projection is skipped and ``None`` is
            returned in its place.
        input_size: side length of the (square) input images, used to size the
            linear layer. The default of 256 gives the 29x29 feature map the
            rest of the notebook assumes.

    forward returns ``(x, x_rsr)``: the last conv feature map and the projected
    vector (or ``None`` when ``linear_op`` is False).
    """

    def __init__(self, no_channels, filter_size, latent_size, linear_op=True, input_size=256):
        super(Encoder, self).__init__()

        self.linear_op = linear_op
        self.conv1 = conv_block(no_channels, filter_size, kernel_size=5, stride=2)
        self.conv2 = conv_block(filter_size, filter_size*2, kernel_size=5, stride=2)
        self.conv3 = conv_block(filter_size*2, filter_size*4, kernel_size=5, stride=2)
        # conv_block uses unpadded 5x5 kernels with stride 2, mapping a side length
        # s to (s - 5) // 2 + 1; three blocks turn 256 into 29.
        spatial = input_size
        for _ in range(3):
            spatial = (spatial - 5) // 2 + 1
        self.spatial_size = spatial
        self.linear = linear_enc(spatial*spatial*filter_size*4, latent_size)

    def forward(self, input):
        x = self.conv1(input)
        x = self.conv2(x)
        x = self.conv3(x)
        # Previously x_rsr was unbound (UnboundLocalError) when linear_op was False.
        x_rsr = self.linear(x) if self.linear_op else None
        return x, x_rsr

def deconv(input_channels, output_channels, kernel_size, stride):
  """Unpadded, bias-free ConvTranspose2d followed by BatchNorm2d and LeakyReLU(0.2)."""
  upsample = nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride, bias=False)
  norm = nn.BatchNorm2d(output_channels)
  activation = nn.LeakyReLU(0.2, inplace=True)
  return nn.Sequential(upsample, norm, activation)
  
def linear_dec(input_channels, output_channels, unflatten_shape=(128, 29, 29)):
  """Linear -> BatchNorm1d -> LeakyReLU(0.2), then reshape each sample to a feature map.

  Args:
    input_channels: length of the input vectors.
    output_channels: length of the Linear output; must equal the product of
      ``unflatten_shape`` for the final reshape to succeed.
    unflatten_shape: per-sample target shape of the trailing Unflatten. The default
      (128, 29, 29) matches a decoder with filter_size=32 producing 29x29 maps.
  """
  return nn.Sequential(
      nn.Linear(in_features=input_channels, out_features=output_channels),
      nn.BatchNorm1d(num_features=output_channels),
      nn.LeakyReLU(0.2, inplace=True),
      Unflatten(unflatten_shape))

class Unflatten(nn.Module):
    """Reshape a batch of flat vectors to ``(batch, *shape)``.

    Args:
        shape: per-sample target shape; any number of dimensions is accepted.
    """
    def __init__(self, shape):
        super(Unflatten, self).__init__()
        self.shape = tuple(shape)

    def forward(self, input):
        # reshape (rather than view) also accepts non-contiguous inputs; for
        # contiguous inputs it returns a view just like before.
        return input.reshape(input.size(0), *self.shape)

class Decoder(nn.Module):
    """Map latent vectors back to images: linear layer, unflatten, three transposed convs.

    Args:
        no_channels: number of output image channels.
        filter_size: base channel count; the linear layer produces ``filter_size*4``
            feature maps, halved by each of the first two transposed convolutions.
        latent_size: length of the input latent vectors.
        spatial_size: side length of the feature maps produced by the linear layer.
            With the default 29, the unpadded transposed convolutions
            (out = (in - 1) * stride + kernel) upsample 29 -> 62 -> 127 -> 256.
    """
    def __init__(self, no_channels, filter_size, latent_size, spatial_size=29):
        super(Decoder, self).__init__()

        channels = filter_size*4
        self.linear = linear_dec(latent_size, spatial_size*spatial_size*channels)
        # linear_dec ends with an Unflatten targeting (128, 29, 29), which only fits
        # filter_size=32 / spatial_size=29; swap in this decoder's actual shape.
        # Unflatten has no parameters, so state_dict keys are unchanged.
        self.linear[-1] = Unflatten((channels, spatial_size, spatial_size))
        self.deconv1 = deconv(channels, filter_size*2, 6, 2)
        self.deconv2 = deconv(filter_size*2, filter_size, 5, 2)
        self.deconv3 = deconv(filter_size, no_channels, 4, 2)

    def forward(self, input):
        x = self.linear(input)
        x = self.deconv1(x)
        x = self.deconv2(x)
        x = self.deconv3(x)
        return x

## Model

In [11]:
import torch.nn.functional as f

In [12]:
class Aes(nn.Module):
    """Autoencoder combining Encoder and Decoder, with optional latent L2 normalisation.

    forward returns ``(y, y_rsr, z, x_tilde)``: the encoder's conv feature map, its
    linear projection, the latent code fed to the decoder (the projection,
    L2-normalised along the last dimension when ``normalize`` is True) and the
    decoder output.
    """
    def __init__(self, no_channels, filter_size, latent_size, normalize=True, linear_op=True):
        super(Aes, self).__init__()
        self.Encoder = Encoder(no_channels=no_channels, filter_size=filter_size,
                               latent_size=latent_size, linear_op=linear_op)
        self.Decoder = Decoder(no_channels=no_channels, filter_size=filter_size,
                               latent_size=latent_size)
        self.normalize = normalize

    def forward(self, input):
        features, projection = self.Encoder(input)
        code = f.normalize(projection, dim=-1, p=2) if self.normalize else projection
        reconstruction = self.Decoder(code)
        return features, projection, code, reconstruction

## Errors (loss terms)

In [13]:
def recon_error(x, xtilde):
  """Reconstruction error between an input batch and its reconstruction.

  Takes the 2-norm of ``x - xtilde`` along dimension 1 and averages the result
  over all remaining positions.

  NOTE(review): if inputs are (batch, 1, H, W), dimension 1 is a single channel,
  so this reduces to the mean absolute per-pixel error. Confirm that a per-sample
  norm over the flattened image was not intended.

  Args:
    x: input batch.
    xtilde: reconstruction with the same shape as ``x``.

  Returns:
    Scalar tensor.
  """
  # Use the parameter; the body previously read a global named `x_tilde`.
  return torch.mean(torch.norm(x - xtilde, dim=1))

def pca_error(y, z, A=None):
  """Residual between flattened features ``y`` and their linear reconstruction ``z @ A.T``.

  Args:
    y: feature tensor; flattened to (batch, d).
    z: latent codes of shape (batch, k).
    A: optional (d, k) matrix. When None, a fresh standard-normal matrix is drawn
      on every call, on ``y``'s device and dtype, with k taken from ``z``.

  Returns:
    Scalar tensor: the batch mean of per-sample 2-norms of ``y - z @ A.T``.
  """
  y = y.reshape(y.shape[0], -1)
  if A is None:
    # NOTE(review): a matrix sampled per call is not learned and differs between
    # calls; an RSR-style objective presumably wants a trainable A. Pass it in.
    A = torch.randn(y.shape[-1], z.shape[-1], device=y.device, dtype=y.dtype)
  recon = torch.matmul(z, A.transpose(0, 1))
  return torch.mean(torch.norm(y - recon, dim=1))  # it's 2 or 'fro' by default

def proj_error(y, z, A=None):
  """Orthogonality penalty ``mean((A.T @ A - I) ** 2)`` on a (d, k) projection matrix.

  Args:
    y: feature tensor; only its flattened per-sample size d is used, and only
      when ``A`` is None.
    z: latent codes of shape (batch, k); only k is used, and only when ``A`` is None.
    A: optional (d, k) matrix. When None, a fresh standard-normal matrix is drawn
      on ``y``'s device and dtype.

  Returns:
    Scalar tensor.

  NOTE(review): with a randomly drawn A the result does not depend on ``y``, ``z``
  or any trainable parameter, so calling ``backward()`` on it raises. A learned A
  presumably needs to be passed in for this penalty to have an effect.
  """
  if A is None:
    flat = y.reshape(y.shape[0], -1)
    A = torch.randn(flat.shape[-1], z.shape[-1], device=flat.device, dtype=flat.dtype)
  gram = torch.matmul(A.transpose(0, 1), A)
  eye = torch.eye(gram.shape[0], device=gram.device, dtype=gram.dtype)
  return torch.mean(torch.square(gram - eye))

## Logging

In [14]:
log_freq = 10  # iterations between 'Train/Loss' TensorBoard points
tb_writer = SummaryWriter(log_dir=os.path.join(log_path, trainID))

# Text log next to the checkpoints; appended to across runs.
os.makedirs(save_path, exist_ok=True)
logname = os.path.join(save_path, 'log_' + trainID + '.txt')
logging.basicConfig(
    filename=logname,
    filemode='a',
    format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
    datefmt='%H:%M:%S',
    level=logging.DEBUG,
)

## Load Model

In [15]:
checkpoint2load = None  # path of a checkpoint to resume from; None starts a fresh run

In [16]:
model = Aes(
    no_channels=1,
    filter_size=32,
    latent_size=128,
    normalize=normalize,
    linear_op=linear_op,
)
model.to(device)

# Main optimiser, plus two at 10x the learning rate for the alternating
# auxiliary steps (optimizer2: proj_error, optimizer3: pca_error).
optimizer, optimizer2, optimizer3 = (
    Adam(model.parameters(), lr=lr)
    for lr in (learning_rate, 10 * learning_rate, 10 * learning_rate)
)

In [17]:
criterion_rec = nn.MSELoss()
# Running-loss accumulators (also reset at the start of every epoch).
runningLoss = runningLossCounter = train_loss = 0.0

In [18]:
# Optionally resume from a checkpoint; otherwise start a fresh run.
if checkpoint2load:
    # map_location="cpu" lets a checkpoint saved from a GPU run be opened on any
    # machine; load_state_dict then copies the tensors into the existing model and
    # optimiser state.
    # NOTE(review): the training loop's checkpoint also pickles the whole `model`
    # object; torch versions whose torch.load defaults to weights_only=True may
    # refuse it. Confirm the torch version, or pass weights_only=False for trusted files.
    chk = torch.load(checkpoint2load, map_location="cpu")
    model.load_state_dict(chk['state_dict'])
    optimizer.load_state_dict(chk['optimizer'])
    # Mixed-precision (`amp`) state is not restored: `amp` is not imported in this
    # notebook and the training loop's checkpoint has no 'amp' entry.
    # 'epoch' / 'loss' may be absent (checkpoints storing only model/optimizer
    # state), so fall back to the fresh-run defaults.
    start_epoch = chk.get('epoch', -1) + 1
    best_loss = chk.get('loss', float('inf'))
else:
    start_epoch = 0
    best_loss = float('inf')

## Training Model

In [None]:
def _alternating_step(net, step_optimizer, loss_fn, batch):
    """Take one extra optimiser step on a single auxiliary error term (all_alt mode).

    The main ``loss.backward()`` frees the autograd graph and ``optimizer.step()``
    updates parameters in place, so the main pass's ``y`` / ``y_rsr`` cannot be
    back-propagated again. This helper therefore clears gradients (so the step does
    not reuse the main loss's gradients) and recomputes the forward pass. Note that
    in train mode this extra forward also updates BatchNorm running statistics.

    Args:
        net: the model being trained.
        step_optimizer: optimiser to step with the auxiliary gradients.
        loss_fn: callable ``(y, y_rsr) -> scalar tensor`` such as ``pca_error``.
        batch: input batch already placed on the training device.

    Returns:
        The auxiliary loss tensor.
    """
    net.zero_grad()
    y_aux, y_rsr_aux, _, _ = net(batch)
    aux_loss = loss_fn(y_aux, y_rsr_aux)
    # If the loss does not depend on any trainable parameter (e.g. a penalty
    # computed only from a freshly sampled random matrix), backward() would
    # raise, so skip the update.
    if aux_loss.requires_grad:
        aux_loss.backward()
        step_optimizer.step()
    return aux_loss


for epoch in range(start_epoch, num_epochs):
    model.train().to(device)
    runningLoss = 0.0
    runningLossCounter = 0.0
    train_loss = 0.0
    train_steps = 0  # iterations with a finite loss; denominator of the epoch average
    print('Epoch '+ str(epoch)+ ': Training')

    with tqdm(total=len(train_loader)) as pbar:
        for i, data in enumerate(train_loader):
            try:
                img = data['img']['data'].squeeze(-1)
                data = img.to(device)
                model.zero_grad()
                y, y_rsr, z, x_tilde = model(data)

                if if_rsr and not all_alt:
                    loss = recon_error(data, x_tilde) + lambda1 * pca_error(y, y_rsr) + lambda2 * proj_error(y, y_rsr)
                else:
                    loss = recon_error(data, x_tilde)
                if not torch.isfinite(loss):
                    logging.error('Loss is not finite. Skipping the iteration.')
                    continue

                loss.backward()
                optimizer.step()
                if enforce_proj and all_alt:
                    _alternating_step(model, optimizer2, proj_error, data)
                if all_alt:
                    _alternating_step(model, optimizer3, pca_error, data)

                loss_value = loss.item()
                train_loss += loss_value
                train_steps += 1
                runningLoss += loss_value
                runningLossCounter += 1
                logging.info('[%d/%d][%d/%d] Train Loss: %.4f' % ((epoch+1), num_epochs, i, len(train_loader), loss_value))

                if i % log_freq == 0:
                    niter = epoch*len(train_loader)+i
                    tb_writer.add_scalar('Train/Loss', runningLoss/runningLossCounter, niter)
                    runningLoss = 0.0
                    runningLossCounter = 0.0
            except Exception as e:
                print("Caught exception: ", e)
                logging.error(str(e))
            finally:
                # Exactly one tick per batch, including skipped (`continue`) and failed ones.
                pbar.update(1)

    # Average only over iterations that actually produced a loss.
    avg_train_loss = train_loss / train_steps if train_steps else float('nan')
    checkpoint = {
        'model': model,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        # Stored so a resumed run can continue from the next epoch.
        'epoch': epoch,
        'loss': avg_train_loss,
    }

    torch.save(checkpoint, os.path.join(save_path, trainID+".pth.tar"))
    tb_writer.add_scalar('Train/AvgLossEpoch', avg_train_loss, epoch)

    if val_loader:
        model.eval()
        with torch.no_grad():
            print('Epoch '+ str(epoch)+ ': Validation')
            with tqdm(total=len(val_loader)) as pbar:
                for i, data in enumerate(val_loader):
                    try:
                        img = data['img']['data'].squeeze(-1)
                        images = img.to(device)
                        y, y_rsr, z, x_tilde = model(images)
                        # Compare against the image tensor, not the batch dict `data`.
                        loss_rec = recon_error(images, x_tilde)
                        loss_pca = pca_error(y, y_rsr)
                        loss_proj = proj_error(y, y_rsr)
                        loss = loss_rec + loss_pca + loss_proj
                        logging.info('[%d/%d][%d/%d] Val Loss: %.4f' % ((epoch+1), num_epochs, i, len(val_loader), loss.item()))
                        niter = epoch*len(val_loader)+i
                        # Every error term is a batch mean (a 0-dim tensor), so log one
                        # point per batch instead of indexing per sample.
                        tb_writer.add_scalar('Val/Loss', loss.item(), niter)
                        tb_writer.add_scalar('Val/ReconLoss', loss_rec.item(), niter)
                        tb_writer.add_scalar('Val/PCALoss', loss_pca.item(), niter)
                        tb_writer.add_scalar('Val/ProjLoss', loss_proj.item(), niter)
                    except Exception as ex:
                        logging.error(ex)
                    pbar.update(1)