In [None]:
# import libraries

import os
import sys
import datetime
import configparser
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import torchvision
from torchvision.transforms import RandomAffine
from torchsummary import summary

In [None]:
# set device: use the GPU when one is visible to torch, otherwise fall back to CPU

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('Device is:', device)

In [None]:
# timestamp (YYYYMMDDHHMM, local time) embedded in the names of saved files

stamp = f'{datetime.datetime.now():%Y%m%d%H%M}'

In [None]:
# custom functions for 'marginal entropy', 'conditional entropy' and 'KL divergence'
# (1e-8 is added inside every log so zero probabilities do not yield 0 * -inf = nan)

def get_entropy_1D(xxx): # marginal entropy
    """Entropy -sum(p * log p) taken over *all* elements of `xxx` (0-d tensor).

    In the training loop it is applied to the batch-averaged cluster
    distribution, giving the entropy of the marginal distribution.
    """
    log_p = torch.log(xxx + 1e-8)
    return -(xxx * log_p).sum()
def get_entropy_2D(xxx): # conditional entropy
    """Row-wise entropy -sum_j p_ij * log p_ij along dim 1 (one value per row).

    Averaging the result over the batch gives the conditional-entropy term.
    """
    plogp = xxx * torch.log(xxx + 1e-8)
    return -plogp.sum(dim = 1)
def get_KLD_1D(ppp, qqq, batch_mean = True): # KL divergence
    """KL(p || q) computed along dim 1, i.e. one divergence per row.

    Despite the name, the training loop passes 2-D (batch, nfc) probability
    tensors. Returns the mean over rows when `batch_mean` is truthy, otherwise
    the per-row values.
    """
    eps = 1e-8
    per_row = (ppp * torch.log(ppp + eps) - ppp * torch.log(qqq + eps)).sum(dim = 1)
    return per_row.mean() if batch_mean else per_row

In [None]:
# set up 'reconstruction loss': element-wise squared error averaged over all
# elements (nn.MSELoss uses reduction='mean' by default)

get_MSELoss = nn.MSELoss()

In [None]:
# custom functions for plotting the learning curve

class history():
    """Accumulates lists of scalar values per key (used for loss logging).

    `append` extends the list of each key present in the given dict; `mean`
    returns per-key means rounded to 3 decimals; `str(h)` renders those means
    as tab-separated `key:value` pairs in the order of `keys`.
    """
    def __init__(self, keys):
        self.keys = keys
        self.values = {kk: [] for kk in keys}
    def append(self, dict_hist):
        # raises KeyError for keys that were not declared in __init__
        for kk, vv in dict_hist.items():
            self.values[kk].append(vv)
    def mean(self, keys = None):
        selected = self.keys if keys is None else keys
        return {kk: np.round(np.mean(self.values[kk]), 3) for kk in selected}
    def __getitem__(self, key):
        return self.values[key]
    def __str__(self):
        means = self.mean(self.keys)
        return '\t'.join(':'.join((kk, str(means[kk]))) for kk in self.keys)

In [None]:
# set the hyperparameters
# (configparser stores every value as a string; the cells below convert them back)

config = configparser.ConfigParser()
config.read_dict({
    'CNN': {
        'rand_seed': 765, 'ks': 3,
        'nf0': 15, 'nf1': 45, 'nf2': 128, 'nf3': 196,
        'nf4': 128, 'nf5': 128, 'nf6': 128,
        'nfc': 16,
    },
    'IMSAT': {
        'lambda_affine': 0.03,
        'lambda_autoencoder': 1.0,
        'lambda_entropy_marginal': 0.1,
        'lambda_entropy_mean': 0.03,
        'learning_rate': 0.001,
    },
})

In [None]:
# read the CNN hyperparameters back as ints
# NOTE(review): `ks` is read here, but the layer definitions below hard-code
# their kernel sizes (4, 3 and 1) instead of using it
rand_seed = config.getint('CNN', 'rand_seed')
ks        = config.getint('CNN', 'ks')
nf0       = config.getint('CNN', 'nf0')
nf1       = config.getint('CNN', 'nf1')
nf2       = config.getint('CNN', 'nf2')
nf3       = config.getint('CNN', 'nf3')
nf4       = config.getint('CNN', 'nf4')
nf5       = config.getint('CNN', 'nf5')
nf6       = config.getint('CNN', 'nf6')
nfc       = config.getint('CNN', 'nfc')

In [None]:
# read the IMSAT loss weights and the learning rate back as floats
lambda_affine           = config.getfloat('IMSAT', 'lambda_affine')
lambda_autoencoder      = config.getfloat('IMSAT', 'lambda_autoencoder')
lambda_entropy_marginal = config.getfloat('IMSAT', 'lambda_entropy_marginal')
lambda_entropy_mean     = config.getfloat('IMSAT', 'lambda_entropy_mean')
learning_rate           = config.getfloat('IMSAT', 'learning_rate')

In [None]:
# set random seed
# numpy's global RNG drives batch shuffling; torch's CPU/CUDA RNGs drive weight
# initialisation and the torchvision random transforms

print('random seed:', rand_seed)

np.random.seed(rand_seed)
torch.manual_seed(rand_seed)
torch.cuda.manual_seed(rand_seed)

In [None]:
# load the dataset
# The location can be overridden with the IMSAT_DATA_PATH environment variable;
# the default is the original (cluster-specific) absolute path.

data_path = os.environ.get('IMSAT_DATA_PATH', '/project/dsc-is/nono/Documents/kpc/dat0/slice128_Block2_11K.npy')
data_src = np.load(data_path)
print(data_src.shape)

In [None]:
# mean pixel value per channel, reduced over every axis except the last
# (assuming data_src is (sample, stain slice, H, W, RGB) — confirm; the
# RandomAffine fill colour used later is presumably derived from this output)

np.round(data_src.mean(axis = (0, 1, 2, 3)))

In [None]:
# custom functions to extract batches of samples

def get_batch_index_iii(iii, batch_size = None, shuffle = True):
    """Split the indices `iii` into batches.

    Parameters
    ----------
    iii : array-like of int
        Indices to split. A copy is shuffled, so the caller's array is left unchanged.
    batch_size : int or None
        Target batch size. The indices are split into ``len(iii) // batch_size``
        nearly equal parts (batches may therefore be slightly larger than
        `batch_size`), with at least one batch. ``None`` returns a single batch.
    shuffle : bool
        Shuffle with numpy's global RNG before splitting.

    Returns
    -------
    list of np.ndarray
        The batches; an empty list when `iii` is empty.
    """
    iii = np.array(iii)  # copy: shuffling must not mutate the caller's array
    if len(iii) == 0:
        return []
    if (shuffle):
        np.random.shuffle(iii)
    if (batch_size is None):
        # previously left n_batch unbound (UnboundLocalError)
        n_batch = 1
    else:
        # previously 0 when len(iii) < batch_size, which made np.array_split raise
        n_batch = max(1, len(iii) // batch_size)
    return np.array_split(iii, n_batch)

def get_batch_index_nn(nn, batch_size = None, shuffle = True):
    """Batch the indices 0 .. nn-1; see `get_batch_index_iii` for the semantics."""
    iii = np.arange(nn)
    return get_batch_index_iii(iii, batch_size = batch_size, shuffle = shuffle)

In [None]:
ix, iy = 64, 64 # 128x128 patches are downscaled to ix x iy (PIL (width, height)) before entering the model

# indices along axis 1 of `src` that are stacked as model input, with the stain
# each one holds (names taken from the original per-image comments)
stain_slices = (0, 4, 5, 6, 7) # HE, CD31, CK19, Ki67, MT

# affine transformation used for augmentation; `fill` is the RGB colour given to
# pixels that the transform moves in from outside the image
add_random_affine = RandomAffine(degrees = 5, translate = (0.05, 0.05), scale = (0.95, 1.05), fill = (161, 138, 172))

def _load_stain(img, random_affine):
    """Return one stain image as an (iy, ix, C) array, optionally randomly affine-transformed.

    NOTE(review): assumes `img` is a uint8 RGB array accepted by PIL.Image.fromarray — confirm.
    """
    pil_img = Image.fromarray(img)
    if random_affine:
        # NOTE(review): each stain draws its own random transform, so the five
        # stains of one sample are no longer spatially aligned after
        # augmentation — confirm this is intended.
        pil_img = add_random_affine(pil_img)
    return np.asarray(pil_img.resize((ix, iy)))

def generate_batch(iii, src, device, random_affine = False):
    """Build the float32 input tensor for the samples with indices `iii`.

    Parameters
    ----------
    iii : sequence of int
        Indices into axis 0 of `src`.
    src : np.ndarray
        Dataset; `src[ii, k]` is image k of sample ii (see `stain_slices`).
    device : str
        Device the returned tensor is moved to.
    random_affine : bool
        If True, apply `add_random_affine` to every stain image before resizing.

    Returns
    -------
    torch.Tensor of shape (len(iii), nf0, ix, iy): the selected stains
    concatenated along the channel axis and divided by 256 (so 255 maps to
    ~0.996, not 1.0).
    """
    # PIL's resize((ix, iy)) yields arrays of shape (iy, ix, C); the buffer
    # previously used (ix, iy), which only worked because ix == iy
    tmp = np.empty((len(iii), iy, ix, nf0))
    for aa, ii in enumerate(iii):
        tmp[aa] = np.concatenate([_load_stain(src[ii, ss], random_affine) for ss in stain_slices], axis = 2)
    # permute(0, 3, 2, 1) moves channels first AND swaps the two spatial axes,
    # giving (N, C, ix, iy); kept so existing checkpoints see the same layout
    xxx = torch.tensor(tmp / 256.0, dtype = torch.float32).permute(0, 3, 2, 1)
    return xxx.to(device)

In [None]:
# build encoder architecture

class Encoder(nn.Module):
    """Convolutional encoder: four Conv2d(k=4, s=2, p=1) -> BatchNorm2d -> LeakyReLU(0.1) blocks.

    Channels go nf0 -> nf1 -> nf2 -> nf3 -> nf4; each block halves an even
    spatial size, so an (nf0, ix, iy) input becomes (nf4, ix/16, iy/16).
    """
    def __init__(self):
        super(Encoder, self).__init__()
        channels = [nf0, nf1, nf2, nf3, nf4]
        self.encoder = nn.Sequential()
        # submodule names (conv1/bnor1/relu1, ...) and registration order match
        # the original definition, keeping state_dict keys stable
        for idx, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:]), start = 1):
            self.encoder.add_module(f'conv{idx}', nn.Conv2d(c_in, c_out, kernel_size = 4, stride = 2, padding = 1))
            self.encoder.add_module(f'bnor{idx}', nn.BatchNorm2d(c_out, affine = True, track_running_stats = True))
            self.encoder.add_module(f'relu{idx}', nn.LeakyReLU(0.1, inplace = True))
    def forward(self, xxx):
        return self.encoder(xxx)

In [None]:
# build classifier architecture

class Classifier(nn.Module):
    """Cluster head on top of the encoder features.

    A 4x4 conv (no padding) followed by two 1x1 convs maps nf4 -> nf5 -> nf6 -> nfc
    channels; on the (ix//16, iy//16) = 4x4 feature maps used here the first conv
    collapses the grid to 1x1. The output is turned into cluster probabilities
    with a softmax in the training loop.
    NOTE(review): the last layer applies LeakyReLU to what are then used as
    softmax logits — confirm this is intended.
    """
    def __init__(self):
        super(Classifier, self).__init__()
        layers = [
            ('conv1', nn.Conv2d(nf4, nf5, kernel_size = 4, stride = 1, padding = 0)),
            ('bnor1', nn.BatchNorm2d(nf5, affine = True, track_running_stats = True)),
            ('relu1', nn.LeakyReLU(0.1, inplace = True)),
            ('conv2', nn.Conv2d(nf5, nf6, kernel_size = 1, stride = 1, padding = 0)),
            ('relu2', nn.LeakyReLU(0.1, inplace = True)),
            ('conv3', nn.Conv2d(nf6, nfc, kernel_size = 1, stride = 1, padding = 0)),
            ('relu3', nn.LeakyReLU(0.1, inplace = True)),
        ]
        self.classifier = nn.Sequential()
        for name, layer in layers:
            self.classifier.add_module(name, layer)
    def forward(self, hhh):
        return self.classifier(hhh)

In [None]:
# build decoder architecture

class Decoder(nn.Module):
    """Reconstructs the encoder input from encoder features plus the classifier output.

    forward(hhh, vvv) tiles `vvv` over an (ix//16, iy//16) grid, concatenates it
    to `hhh` along the channel axis, and runs four stages of
    upsample(x2, bilinear) -> 3x3 conv -> BatchNorm2d -> activation
    (channels nf4+nfc -> nf3 -> nf2 -> nf1 -> nf0). The final activation is a
    Sigmoid, so outputs lie in (0, 1).
    """
    def __init__(self):
        super(Decoder, self).__init__()
        self.decoder = nn.Sequential()
        # (stage index, in channels, out channels); stages are numbered 4 -> 1 so the
        # submodule names and registration order match the original definition
        stages = ((4, nf4 + nfc, nf3), (3, nf3, nf2), (2, nf2, nf1), (1, nf1, nf0))
        for stage, c_in, c_out in stages:
            self.decoder.add_module(f'upsm{stage}', nn.UpsamplingBilinear2d(scale_factor = 2))
            self.decoder.add_module(f'dcov{stage}', nn.Conv2d(c_in, c_out, kernel_size = 3, stride = 1, padding = 1))
            self.decoder.add_module(f'norm{stage}', nn.BatchNorm2d(c_out, affine = True, track_running_stats = True))
            # the last stage keeps the name 'relu1' but is a Sigmoid
            act = nn.Sigmoid() if stage == 1 else nn.LeakyReLU(0.1, inplace = True)
            self.decoder.add_module(f'relu{stage}', act)
    def forward(self, hhh, vvv):
        # assumes vvv is (N, nfc, 1, 1) as produced by Classifier — tiled to the feature grid
        tiled = vvv.repeat((1, 1, ix // 16, iy // 16))
        return self.decoder(torch.cat((hhh, tiled), dim = 1))

In [None]:
# visualize encoder architecture
# (input_size excludes the batch dimension; device is passed explicitly instead
# of relying on torchsummary's "cuda" default)

summary(Encoder().to(device), input_size = (nf0, ix, iy), device = device)

In [None]:
# visualize classifier architecture (its input is the encoder's feature map)

summary(Classifier().to(device), input_size = (nf4, ix // 16, iy // 16), device = device)

In [None]:
# visualize decoder architecture
# Decoder.forward takes two inputs, so only its inner Sequential is summarised,
# fed with the already-concatenated (features + tiled cluster output) shape

summary(Decoder().to(device).decoder, input_size = (nf4 + nfc, ix // 16, iy // 16), device = device)

In [None]:
# set up 'models' and 'optimizer' (one Adadelta optimizer per sub-network; all
# three are zeroed and stepped together in the training loop)
# NOTE(review): torch.optim.Adadelta's default lr is 1.0 and lr scales every
# update, so learning_rate = 0.001 makes steps ~1000x smaller than standard
# Adadelta — confirm Adadelta (rather than e.g. Adam) is intended.

model_en, model_cl, model_de = Encoder().to(device), Classifier().to(device), Decoder().to(device)
optim_en, optim_cl, optim_de = (optim.Adadelta(mm.parameters(), lr = learning_rate) for mm in (model_en, model_cl, model_de))

In [None]:
# prepare to train the model: per-epoch loss keys, the epoch offset, and the
# long-term history (epoch index 'tt' plus every loss key)

key_loss = ['loss', 'loss_rec', 'entropy_marginal', 'entropy_mean', 'loss_affine']
loss_hist = history(['tt', *key_loss])
t0 = 0

In [None]:
# number of epochs, and how often (in epochs) to print / record the mean losses
t_epoch, t_print, t_log = 5000, 10, 10

In [None]:
# training loop
# Each epoch shuffles the first `n_train` sample indices into batches of ~100
# and, for every batch, minimises
#     lambda_autoencoder      * MSE(reconstruction, input)
#   - lambda_entropy_marginal * H(batch-mean cluster distribution)
#   + lambda_entropy_mean     * mean per-sample H(cluster distribution)
#   + lambda_affine           * KL(p(clean) || p(affine-augmented))
# The last-batch tensors (xxx_tmp, hhh_tmp, vvv_tmp, yyy_tmp) and `tt` stay
# global on purpose: later cells read them.

# the original hard-coded 10000 is kept as the upper bound, clamped so a smaller
# dataset does not index out of range inside generate_batch
n_train = min(10000, len(data_src))

for tt in range(t0, t0 + t_epoch):
    loss_tt = history(key_loss)
    iii_batch = get_batch_index_nn(n_train, batch_size = 100, shuffle = True)
    # nothing in the loop switches to eval mode, so once per epoch is enough
    model_en.train()
    model_cl.train()
    model_de.train()
    for iii in iii_batch:
        xxx_tmp = generate_batch(iii, data_src, device)
        xxa_tmp = generate_batch(iii, data_src, device, random_affine = True)
        # clean view: encode, cluster, reconstruct
        hhh_tmp = model_en(xxx_tmp)
        vvv_tmp = model_cl(hhh_tmp)
        yyy_tmp = model_de(hhh_tmp, vvv_tmp)
        # (prediction, target) order; MSE is symmetric, so the value is unchanged
        loss_rec = get_MSELoss(yyy_tmp, xxx_tmp)
        ppp_tmp = F.softmax(vvv_tmp.reshape((-1, nfc)), dim = 1)
        ppp_mean = torch.mean(ppp_tmp, dim = 0, keepdim = True)
        entropy_marginal = get_entropy_1D(ppp_mean)
        entropy_mean = torch.mean(get_entropy_2D(ppp_tmp))
        # augmented view: its cluster probabilities should match the clean ones
        hha_tmp = model_en(xxa_tmp)
        vva_tmp = model_cl(hha_tmp)
        ppa_tmp = F.softmax(vva_tmp.reshape((-1, nfc)), dim = 1)
        # NOTE(review): gradients flow through both arguments here; IMSAT's
        # self-augmented training usually treats the clean prediction as a
        # fixed target (ppp_tmp.detach()) — confirm which is intended
        loss_affine = get_KLD_1D(ppp_tmp, ppa_tmp)
        loss_tmp = lambda_autoencoder * loss_rec - lambda_entropy_marginal * entropy_marginal + \
                   lambda_entropy_mean * entropy_mean + lambda_affine * loss_affine
        optim_en.zero_grad()
        optim_cl.zero_grad()
        optim_de.zero_grad()
        loss_tmp.backward()
        optim_en.step()
        optim_cl.step()
        optim_de.step()
        loss_tt.append({'loss': loss_tmp.item(), 'loss_rec': loss_rec.item(), 'entropy_marginal': entropy_marginal.item(),
                        'entropy_mean': entropy_mean.item(), 'loss_affine': loss_affine.item()})
    if (tt + 1) % t_log == 0:
        # 'tt' is recorded 0-based while the printout below is 1-based
        loss_hist.append({'tt': tt})
        loss_hist.append(loss_tt.mean())
    if (tt + 1) % t_print == 0:
        print(tt + 1, '/', t0 + t_epoch, '\t', str(loss_tt))

In [None]:
# plot the learning curve
# x-axis: the 0-based epoch index 'tt' recorded every t_log epochs.
# The blue curve is lambda_entropy_marginal * H(Y) - lambda_entropy_mean * H(Y|X);
# it is a weighted surrogate and equals a scaled mutual information only when
# both weights are equal.

mi_term = lambda_entropy_marginal * np.array(loss_hist['entropy_marginal']) - \
          lambda_entropy_mean * np.array(loss_hist['entropy_mean'])

fig, ax = plt.subplots(figsize = (8, 5))
ax.plot(loss_hist['tt'], loss_hist['loss'], 'r', label = 'Loss function', linewidth = 3)
ax.plot(loss_hist['tt'], loss_hist['loss_rec'], 'g', label = 'Reconstruction loss', linewidth = 3)
ax.plot(loss_hist['tt'], mi_term, 'b', label = 'Weighted mutual information (MI) term', linewidth = 3)
ax.set(xlabel = 'epoch', ylabel = 'value', title = 'Learning curve')
ax.legend();

In [None]:
# check the tensor shapes from the last training batch

for shape_label, shape_tensor in (('Encoder input shape\t:', xxx_tmp),
                                  ('\nEncoder output shape\t:', hhh_tmp),
                                  ('\nClassifier output shape\t:', vvv_tmp),
                                  ('\nDecoder output shape\t:', yyy_tmp)):
    print(shape_label, shape_tensor.shape)
del shape_label, shape_tensor

In [None]:
# folder to save models (relative to the current working directory); it is
# created here so that writing the history/checkpoints does not fail with
# FileNotFoundError on a fresh checkout

dir_save = 'models16'
os.makedirs(dir_save, exist_ok = True)

In [None]:
# save training history as a tab-separated table (one column per tracked key);
# the file name carries the run stamp and the final (1-based) epoch
path_hist = os.path.join(dir_save, f'hist_modelS_{stamp}_{tt + 1}.tsv')
tmp = pd.DataFrame(loss_hist.values)
print('saving', path_hist)
tmp.to_csv(path_hist, sep = '\t')

In [None]:
# save models: one state_dict checkpoint per sub-network, named with the run
# stamp and the final (1-based) epoch

path_model_en, path_model_cl, path_model_de = [
    os.path.join(dir_save, f'model_{part}_{stamp}_{tt + 1}.ckpt') for part in ('en', 'cl', 'de')
]
for ckpt_path, ckpt_model in ((path_model_en, model_en), (path_model_cl, model_cl), (path_model_de, model_de)):
    print('saving', ckpt_path)
    torch.save(ckpt_model.state_dict(), ckpt_path)
del ckpt_path, ckpt_model