<h1><center style='color:red'>Information maximization-based clustering of histopathology images (128x128) using deep learning</center></h1>

In [None]:
# import required libraries

import os
import sys
import datetime
import configparser

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from PIL import Image
from torchsummary import summary

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
from torchvision.transforms import RandomAffine

In [None]:
# Select the compute device: the GPU when PyTorch can see one, otherwise the CPU.

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('Device is:', device)

In [None]:
# Run identifier embedded in saved file names: local time as YYYYMMDDHHMM.

stamp = f'{datetime.datetime.now():%Y%m%d%H%M}'

In [None]:
# custom functions for 'marginal entropy', 'conditional entropy' and 'KL divergence'

def get_marginal_entropy(x):
    """Entropy -sum(x * log(x)) summed over *all* elements of `x`.

    Intended for a single probability vector (e.g. the batch-mean cluster
    distribution of shape (1, nfc)). A 1e-8 offset inside the log keeps
    zero probabilities finite. Returns a 0-dim tensor.
    """
    return -(x * (x + 1e-8).log()).sum()

def get_conditional_entropy(x):
    """Per-row entropy -sum_j x[i, j] * log(x[i, j]) of a 2-D probability matrix.

    Reduces over dim 1, so the result has one entry per row (sample). A 1e-8
    offset inside the log keeps zero probabilities finite.
    """
    per_row = -(x * (x + 1e-8).log()).sum(dim=1)
    return per_row

def get_kl_divergence(p, q, batch_mean=True):
    """Row-wise KL divergence KL(p || q) = sum_j p * (log p - log q) over dim 1.

    Both logs use a 1e-8 offset so zero probabilities stay finite.

    Args:
        p, q: probability matrices of the same shape, one distribution per row.
        batch_mean: if True return the mean over rows (0-dim tensor),
            otherwise the per-row divergences.
    """
    eps = 1e-8
    per_row = torch.sum(p * torch.log(p + eps) - p * torch.log(q + eps), dim=1)
    return per_row.mean() if batch_mean else per_row

In [None]:
# custom functions for plotting the learning curve

class history():
    """Collects named series of scalar values and summarises them.

    `values` maps each key to the list of values appended so far; `keys` is
    the key list given at construction and is the default set summarised by
    mean() and __str__().
    """

    def __init__(self, keys):
        self.values = {name: [] for name in keys}
        self.keys = keys

    def append(self, dict_hist):
        """Append each value of `dict_hist` to the series of the same key (KeyError if unknown)."""
        for name, value in dict_hist.items():
            self.values[name].append(value)

    def mean(self, keys=None):
        """Return {key: mean of its series rounded to 3 decimals} for `keys` (default: all)."""
        selected = self.keys if keys is None else keys
        return {name: np.round(np.mean(self.values[name]), 3) for name in selected}

    def __getitem__(self, key):
        return self.values[key]

    def __str__(self):
        # Tab-separated "key: mean" pairs, in construction-key order.
        summary_means = self.mean(self.keys)
        return '\t'.join(name + ': ' + str(summary_means[name]) for name in self.keys)

In [None]:
# set the hyperparameters
#
# Everything is stored in a ConfigParser (values become strings) and echoed to
# stdout so the run's settings appear in the notebook output.

config = configparser.ConfigParser()

# CAE architecture: random seed, kernel size `ks` (not referenced by the layer
# definitions in this notebook), channel widths nf0 (input channels) .. nf6,
# and nfc = number of clusters.
config['CAE'] = dict(rand_seed=765, ks=3, nf0=15, nf1=45, nf2=128, nf3=196,
                     nf4=128, nf5=128, nf6=128, nfc=14)

# Information-maximisation loss weights and the optimiser learning rate.
config['IM'] = dict(lambda_affine=0.03, lambda_marginal_entropy=0.1,
                    lambda_conditional_entropy=0.03, learning_rate=0.003)

config.write(sys.stdout)

`nfc` denotes the number of clusters. Here we use 14 clusters; based on our experiments, it can also be set to 8, 9, 10, 11, 12, 13, 15, 16, 17 or 18.

In [None]:
# Read the CAE settings back from the config as integers: seed, kernel size,
# channel widths nf0..nf6 and the number of clusters nfc.
rand_seed, ks, nf0, nf1, nf2, nf3, nf4, nf5, nf6, nfc = (
    int(config['CAE'][option])
    for option in ('rand_seed', 'ks', 'nf0', 'nf1', 'nf2', 'nf3', 'nf4', 'nf5', 'nf6', 'nfc')
)

In [None]:
# Read the information-maximisation loss weights and learning rate as floats.
lambda_affine, lambda_marginal_entropy, lambda_conditional_entropy, learning_rate = (
    float(config['IM'][option])
    for option in ('lambda_affine', 'lambda_marginal_entropy', 'lambda_conditional_entropy', 'learning_rate')
)

In [None]:
# set random seed
#
# Seed torch (CPU), torch CUDA, and NumPy's global RNG (used by
# np.random.shuffle when batching), in that order.

print('random seed:', rand_seed)

for seed_rng in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
    seed_rng(rand_seed)

In [None]:
# load the dataset
#
# The location can be overridden with the HISTO_DATA_PATH environment variable;
# by default the original cluster path is used, so existing runs are unchanged.

data_path = os.environ.get('HISTO_DATA_PATH', '/project/dsc-is/mahfujul-r/M/slice128_Block2_20K.npy')
data_src = np.load(data_path)
print(data_src.shape)

# Fail early if the array does not have the layout the rest of the notebook
# relies on: generate_batch reads stain slots up to index 7 and concatenates
# five 3-channel images into nf0 = 15 channels, and training uses the first
# 15000 samples.
assert data_src.ndim == 5, f'expected (patch, stain, H, W, RGB), got {data_src.shape}'
assert data_src.shape[1] >= 8 and data_src.shape[-1] == 3, f'unexpected layout {data_src.shape}'
assert data_src.shape[0] >= 15000, f'need at least 15000 patches, got {data_src.shape[0]}'

`/project/dsc-is/mahfujul-r/M/slice128_Block2_20K.npy` is the location of the dataset of randomly sampled 128x128 patches; `slice128_Block2_20K.npy` is the file name. Its shape shows that it contains 20000 patches, of which the first 15000 are used to train the model. The second dimension, 8, is the number of stainings. Of these, the first, fifth, sixth, seventh and eighth correspond to HE, CD31, CK19, Ki67 and MT, respectively. These five show the same tissue features, so only they are used. 128x128 is the spatial size, and 3 is the number of RGB channels.

In [None]:
# Per-channel (R, G, B) mean of the whole dataset: the average is taken over
# every axis except the last (RGB) one, assuming data_src is laid out as
# (patch, stain, H, W, RGB), and then rounded.
# NOTE(review): presumably the source of the fill colour (161, 138, 172) used
# by the random affine augmentation below — confirm.

rgb_channel_mean = data_src.mean(axis=(0, 1, 2, 3))
np.round(rgb_channel_mean)

In [None]:
# custom functions to extract batches of samples

def get_batch_index_tr(tr, batch_size=None, shuffle=True):
    """Split an array of sample indices into (optionally shuffled) batches.

    Args:
        tr: 1-D array of sample indices. When `shuffle` is True it is shuffled
            in place with NumPy's global RNG.
        batch_size: target batch size. The indices are split into
            len(tr) // batch_size nearly equal parts with np.array_split, so
            leftover indices are spread across the batches (a batch may hold
            slightly more than batch_size items). At least one batch is always
            produced. If None, all indices form a single batch.
        shuffle: whether to shuffle `tr` before splitting.

    Returns:
        List of 1-D index arrays that together contain every element of `tr`
        exactly once.
    """
    if shuffle:
        np.random.shuffle(tr)
    if batch_size is None:
        # No batch size requested: everything goes into one batch.
        n_batch = 1
    else:
        # max(1, ...) covers len(tr) < batch_size; np.array_split rejects 0 sections.
        n_batch = max(1, len(tr) // batch_size)
    return np.array_split(tr, n_batch)

def get_batch_index_ae(ae, batch_size=None, shuffle=True):
    """Batch the indices 0 .. ae-1, where `ae` is the number of samples to use.

    Thin wrapper around get_batch_index_tr on np.arange(ae); see it for the
    meaning of `batch_size` and `shuffle`.
    """
    tr = np.arange(ae)
    return get_batch_index_tr(tr, batch_size=batch_size, shuffle=shuffle)

In [None]:
# The 128x128 patches are downsampled to ix x iy before entering the network.
ix, iy = 64, 64

# Random geometric augmentation: up to +/-5 degrees rotation, up to 5 %
# translation per axis, 95-105 % scaling; uncovered pixels get this RGB fill.
add_random_affine = RandomAffine(degrees=5, translate=(0.05, 0.05), scale=(0.95, 1.05), fill=(161, 138, 172))

# Stain slots (second axis of `src`) stacked into the network input, in order.
_STAIN_SLOTS = (0, 4, 5, 6, 7)  # HE, CD31, CK19, Ki67, MT


def generate_batch(iii, src, device, random_affine=False):
    """Build a float32 batch of stacked stain images on `device`.

    For every sample index in `iii`, the images src[i, 0], src[i, 4],
    src[i, 5], src[i, 6] and src[i, 7] (HE, CD31, CK19, Ki67, MT) are turned
    into PIL images. Each is optionally passed through `add_random_affine`,
    resized to (ix, iy) and concatenated along the channel axis into nf0
    channels. Pixel values are divided by 255.

    Args:
        iii: iterable of sample indices into `src`.
        src: array indexed as src[sample, stain] yielding images accepted by
            PIL.Image.fromarray (presumably uint8 RGB — TODO confirm).
        device: torch device the batch is moved to.
        random_affine: apply the random affine augmentation before resizing.
            add_random_affine is called separately for every stain image, so
            each stain gets its own independently sampled transform.
            NOTE(review): the five stains are therefore no longer spatially
            aligned in augmented batches — confirm this is intended.

    Returns:
        float32 tensor of shape (len(iii), nf0, 64, 64) for ix == iy == 64.
    """
    stacked = np.empty((len(iii), ix, iy, nf0))
    for row, sample in enumerate(iii):
        channels = []
        for slot in _STAIN_SLOTS:
            picture = Image.fromarray(src[sample, slot])
            if random_affine:
                picture = add_random_affine(picture)
            channels.append(picture.resize((ix, iy)))
        stacked[row] = np.concatenate(channels, axis=2)

    # NOTE(review): permute(0, 3, 2, 1) maps (N, H, W, C) to (N, C, W, H), so the
    # two spatial axes are also transposed. The shapes still line up because
    # ix == iy, and every consumer sees the same layout, but saved models were
    # trained on transposed patches — keep this if they must stay compatible.
    batch = torch.tensor(stacked / 255.0, dtype=torch.float32).permute(0, 3, 2, 1)
    return batch.to(device)

Construct `CAE` (Convolutional AutoEncoder) architecture

- Build an `Encoder` class

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder made of four downsampling stages.

    Each stage is Conv2d(kernel 4, stride 2, padding 1) -> BatchNorm2d ->
    LeakyReLU(0.1). The channels go nf0 -> nf1 -> nf2 -> nf3 -> nf4, and each
    spatial size is halved (for even sizes). An (N, nf0, ix, iy) input
    therefore becomes (N, nf4, ix // 16, iy // 16).

    Layers are registered in `self.encoder` as conv{k}, bnor{k}, lrel{k}
    (k = 1..4). These names are the state_dict keys of saved checkpoints.
    """

    def __init__(self):
        super(Encoder, self).__init__()

        self.encoder = nn.Sequential()
        widths = (nf0, nf1, nf2, nf3, nf4)
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            self.encoder.add_module(f'conv{stage}', nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=4, stride=2, padding=1))
            self.encoder.add_module(f'bnor{stage}', nn.BatchNorm2d(num_features=c_out, affine=True, track_running_stats=True))
            self.encoder.add_module(f'lrel{stage}', nn.LeakyReLU(0.1, inplace=True))

    def forward(self, xxx):
        """Map an input batch to its latent feature map."""
        return self.encoder(xxx)

- Visualize `Encoder` architecture

In [None]:
# Layer-by-layer summary of a freshly built Encoder for one (nf0, ix, iy) input.
encoder_preview = Encoder().to(device)
summary(encoder_preview, input_size=(nf0, ix, iy))

- Build a `Classifier` class

In [None]:
class Classifier(nn.Module):
    """Cluster head mapping encoder features to nfc per-cluster scores.

    conv1 (kernel 4, stride 1, no padding) collapses the encoder's 4 x 4 map
    (ix // 16 for ix = 64) to 1 x 1. It is followed by BatchNorm2d and
    LeakyReLU. conv2 and conv3 are 1x1 convolutions (nf5 -> nf6 -> nfc), each
    followed by LeakyReLU(0.1). The output is (N, nfc, 1, 1); callers in this
    notebook apply softmax to get cluster probabilities.

    Layers are registered in `self.classifier` under the names listed below.
    These names are the state_dict keys of saved checkpoints.
    """

    def __init__(self):
        super(Classifier, self).__init__()

        # Note: conv2 has no BatchNorm, and the scores leave through a LeakyReLU.
        layers = [
            ('conv1', nn.Conv2d(in_channels=nf4, out_channels=nf5, kernel_size=4, stride=1, padding=0)),
            ('bnor1', nn.BatchNorm2d(num_features=nf5, affine=True, track_running_stats=True)),
            ('lrel1', nn.LeakyReLU(0.1, inplace=True)),
            ('conv2', nn.Conv2d(in_channels=nf5, out_channels=nf6, kernel_size=1, stride=1, padding=0)),
            ('lrel2', nn.LeakyReLU(0.1, inplace=True)),
            ('conv3', nn.Conv2d(in_channels=nf6, out_channels=nfc, kernel_size=1, stride=1, padding=0)),
            ('lrel3', nn.LeakyReLU(0.1, inplace=True)),
        ]
        self.classifier = nn.Sequential()
        for layer_name, layer in layers:
            self.classifier.add_module(layer_name, layer)

    def forward(self, hhh):
        """Return unnormalised cluster scores for encoder features `hhh`."""
        return self.classifier(hhh)

- Visualize `Classifier` architecture

In [None]:
# Summary of a fresh Classifier, fed an encoder-sized (nf4, ix // 16, iy // 16) map.
classifier_preview = Classifier().to(device)
summary(classifier_preview, input_size=(nf4, ix // 16, iy // 16))

- Build a `Decoder` class

In [None]:
class Decoder(nn.Module):
    """Convolutional decoder conditioned on the cluster scores.

    forward(hhh, vvv) does the following:
      1. Tiles the classifier output `vvv` over the bottleneck grid.
      2. Concatenates it to the encoder features `hhh` along the channel axis,
         giving nf4 + nfc channels.
      3. Runs four stages of UpsamplingBilinear2d(x2) -> Conv2d(3x3, padding 1)
         -> BatchNorm2d -> activation. The channels go
         nf4 + nfc -> nf3 -> nf2 -> nf1 -> nf0.

    The first three stages use LeakyReLU(0.1). The last uses Sigmoid, so
    outputs lie in (0, 1), like the /255-scaled inputs from generate_batch.

    Layers are registered in `self.decoder` as upsm{k}, dcov{k}, norm{k}, then
    either lrel{k} (k = 4, 3, 2) or sgmd1 (k = 1). These names are the
    state_dict keys of saved checkpoints.
    """

    def __init__(self):
        super(Decoder, self).__init__()

        self.decoder = nn.Sequential()
        # (stage, in_channels, out_channels), from the bottleneck outwards.
        stages = ((4, nf4 + nfc, nf3), (3, nf3, nf2), (2, nf2, nf1), (1, nf1, nf0))
        for stage, c_in, c_out in stages:
            self.decoder.add_module(f'upsm{stage}', nn.UpsamplingBilinear2d(scale_factor=2))
            self.decoder.add_module(f'dcov{stage}', nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, stride=1, padding=1))
            self.decoder.add_module(f'norm{stage}', nn.BatchNorm2d(num_features=c_out, affine=True, track_running_stats=True))
            if stage == 1:
                # The final stage squashes the reconstruction into (0, 1).
                self.decoder.add_module('sgmd1', nn.Sigmoid())
            else:
                self.decoder.add_module(f'lrel{stage}', nn.LeakyReLU(0.1, inplace=True))

    def forward(self, hhh, vvv):
        """Reconstruct the input from features `hhh` and cluster scores `vvv`.

        Assumes `vvv` is (N, nfc, 1, 1), which is the classifier output for
        ix = iy = 64. Repeating it (ix // 16, iy // 16) times then matches the
        spatial size of `hhh`.
        """
        tiled_scores = vvv.repeat((1, 1, ix // 16, iy // 16))
        joined = torch.cat((hhh, tiled_scores), dim=1)
        return self.decoder(joined)

- Visualize `Decoder` architecture

In [None]:
# Decoder.forward() takes two inputs, so only its inner Sequential is summarised.
# The input already has the concatenated nf4 + nfc channels at the bottleneck size.
decoder_preview = Decoder().to(device)
summary(decoder_preview.decoder, input_size=(nf4 + nfc, ix // 16, iy // 16))

In [None]:
# set up 'reconstruction loss', 'CAE' and 'optimizer'

# Reconstruction term: element-wise squared error, averaged over all elements.
get_MSELoss = nn.MSELoss(reduction='mean')

# Build the three sub-networks on `device`. The construction order
# (encoder, classifier, decoder) decides which random initial weights each draws.
model_en, model_cl, model_de = (net_cls().to(device) for net_cls in (Encoder, Classifier, Decoder))

# One Adadelta optimiser per sub-network, all sharing the configured learning rate.
# NOTE(review): PyTorch's Adadelta default lr is 1.0; 0.003 scales every update
# down considerably — confirm this is intended.
optim_en, optim_cl, optim_de = (
    optim.Adadelta(net.parameters(), lr=learning_rate) for net in (model_en, model_cl, model_de)
)

In [None]:
# prepare to train the model

# Start of the epoch counter range for this run.
t0 = 0
# Per-batch quantities recorded during training.
key_loss = ['loss', 'loss_rec', 'entropy_marginal', 'entropy_mean', 'loss_affine']
# Logged history: an epoch-counter column 'tt' plus the averaged quantities above.
loss_hist = history(['tt', *key_loss])

In [None]:
# Number of epochs, plus how often (in epochs) to print progress and to log averages.
t_epoch, t_print, t_log = 4000, 10, 10

In [None]:
# training loop for the 'CAE_IM' model
#
# Each epoch visits the first `n_train` patches once in shuffled mini-batches.
# The objective minimised for every batch is:
#     rec_loss                                          (MSE reconstruction)
#   - lambda_marginal_entropy    * marginal_entropy     (entropy of the batch-mean cluster distribution)
#   + lambda_conditional_entropy * conditional_entropy  (mean per-sample cluster entropy)
#   + lambda_affine              * affine_loss          (KL between clean and augmented cluster distributions)

n_train = 15000          # first 15000 samples are used to train the model
train_batch_size = 100

for tt in range(t0, t0 + t_epoch):
    loss_tt = history(key_loss)
    iii_batch = get_batch_index_ae(n_train, batch_size=train_batch_size, shuffle=True)

    # Nothing below switches to eval mode, so setting training mode once per
    # epoch is enough (BatchNorm uses batch statistics and updates running ones).
    model_en.train()
    model_cl.train()
    model_de.train()

    for iii in iii_batch:
        xxx_tmp = generate_batch(iii, data_src, device)
        xxa_tmp = generate_batch(iii, data_src, device, random_affine=True)

        # Clean view: encode, score clusters, reconstruct.
        hhh_tmp = model_en(xxx_tmp)
        vvv_tmp = model_cl(hhh_tmp)
        yyy_tmp = model_de(hhh_tmp, vvv_tmp)

        rec_loss = get_MSELoss(xxx_tmp, yyy_tmp)

        # Per-sample cluster probabilities. The reshape assumes the classifier's
        # spatial output is 1 x 1 (true for ix = iy = 64).
        ppp_tmp = F.softmax(vvv_tmp.reshape((-1, nfc)), dim=1)
        ppp_mean = torch.mean(ppp_tmp, dim=0, keepdim=True)

        marginal_entropy = get_marginal_entropy(ppp_mean)
        conditional_entropy = torch.mean(get_conditional_entropy(ppp_tmp))

        # Augmented view: only its cluster distribution enters the loss. The
        # decoder is deliberately not run on it, because its reconstruction
        # would be unused. That forward pass would only cost compute and shift
        # the decoder's BatchNorm running statistics.
        hha_tmp = model_en(xxa_tmp)
        vva_tmp = model_cl(hha_tmp)

        ppa_tmp = F.softmax(vva_tmp.reshape((-1, nfc)), dim=1)
        affine_loss = get_kl_divergence(ppp_tmp, ppa_tmp)

        loss_tmp = (rec_loss
                    - lambda_marginal_entropy * marginal_entropy
                    + lambda_conditional_entropy * conditional_entropy
                    + lambda_affine * affine_loss)

        optim_en.zero_grad()
        optim_cl.zero_grad()
        optim_de.zero_grad()
        loss_tmp.backward()
        optim_en.step()
        optim_cl.step()
        optim_de.step()

        loss_tt.append({'loss': loss_tmp.item(), 'loss_rec': rec_loss.item(), 'entropy_marginal': marginal_entropy.item(),
                        'entropy_mean': conditional_entropy.item(), 'loss_affine': affine_loss.item()})

    if (tt + 1) % t_log == 0:
        # Log the number of completed epochs (tt + 1). This matches the printed
        # counter, the "Number of epochs" plot axis and the file-name suffix.
        loss_hist.append({'tt': tt + 1})
        loss_hist.append(loss_tt.mean())

    if (tt + 1) % t_print == 0:
        print(tt + 1, '/', t0 + t_epoch, '\t', str(loss_tt))

In [None]:
# plot the learning curve
#
# Curves: total loss, reconstruction loss, and the information-maximisation
# term lambda_m * H(marginal) - lambda_c * H(conditional), i.e. the negated
# entropy part of the loss. That term is a *weighted* form of the mutual
# information H(Y) - H(Y|X); it is proportional to the MI only when the two
# lambdas are equal, and it is labelled accordingly.

fig, ax = plt.subplots(figsize=(8, 5))

epochs = loss_hist['tt']
im_term = (lambda_marginal_entropy * np.array(loss_hist['entropy_marginal'])
           - lambda_conditional_entropy * np.array(loss_hist['entropy_mean']))

ax.plot(epochs, loss_hist['loss'], 'r', label='Loss function', linewidth=3)
ax.plot(epochs, loss_hist['loss_rec'], 'g', label='Reconstruction loss', linewidth=3)
ax.plot(epochs, im_term, 'b', label='Weighted MI term', linewidth=3)

ax.set_xlabel('Number of epochs', fontsize=22)
ax.set_ylabel('Value', fontsize=22)
ax.set_title('The learning curve', fontsize=30)
ax.legend(loc='best', fontsize=15, bbox_to_anchor=(0.15, 0.15, 0.755, 1.45));

In [None]:
# Shapes of the tensors from the last training batch at each stage of the model.

for caption, stage_tensor in (('Encoder input shape\t:', xxx_tmp),
                              ('\nEncoder output shape\t:', hhh_tmp),
                              ('\nClassifier output shape\t:', vvv_tmp),
                              ('\nDecoder output shape\t:', yyy_tmp)):
    print(caption, stage_tensor.shape)

Notice that the encoder input and the decoder output have the same shape, as expected for an autoencoder.

In [None]:
# Folder for the training history and model checkpoints, named after the
# number of clusters (nfc = 14 gives 'models14'). It is created if missing,
# so the to_csv / torch.save calls that write into it do not fail.
dir_save = f'models{nfc}'
os.makedirs(dir_save, exist_ok=True)

In [None]:
# Save the logged training history as a tab-separated file: one column per
# history key and one row per logged epoch. The name is tagged with the run
# stamp and the final epoch count.
path_hist = os.path.join(dir_save, f'hist_modelS_{stamp}_{tt + 1}.tsv')
print('saving', path_hist)

zzz = pd.DataFrame.from_dict(loss_hist.values)
zzz.to_csv(path_hist, sep='\t')

In [None]:
# save models
#
# Each sub-network's state_dict goes to its own checkpoint file, tagged with
# the run stamp and the final epoch count.

path_model_en = os.path.join(dir_save, f'model_en_{stamp}_{tt + 1}.ckpt')
path_model_cl = os.path.join(dir_save, f'model_cl_{stamp}_{tt + 1}.ckpt')
path_model_de = os.path.join(dir_save, f'model_de_{stamp}_{tt + 1}.ckpt')

for ckpt_path, net in ((path_model_en, model_en), (path_model_cl, model_cl), (path_model_de, model_de)):
    print('saving', ckpt_path)
    torch.save(net.state_dict(), ckpt_path)