Import

In [None]:
#Import
import argparse
import os
import numpy as np
import math
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split, Dataset
from random import randint
from torch.optim.lr_scheduler import StepLR
import random;
from skimage.metrics import structural_similarity as ssim
import warnings
warnings.filterwarnings('ignore')

Dataset

In [None]:
class BasicDataset(Dataset):
    """Dataset that yields several random, augmented crops of one image per item.

    For item ``i`` it draws ``n_crops`` windows of ``crop_size`` x ``crop_size``
    pixels from image ``i``. Each window is min-max normalised to [0, 1] and
    randomly flipped horizontally and/or vertically. The crops are returned
    stacked along axis 0, together with one copy of the image label per crop.

    Args:
        ccm_images: array indexed as ``ccm_images[i, :, :]``. It is assumed to
            be (N, H, W) with H, W >= 383, so every window position used below
            fits inside the image -- TODO confirm against the data.
        ccm_labels: array indexed as ``ccm_labels[i, 0]``, holding the class of
            image ``i`` (the notebook builds 0 = SNP, 1 = Epithelium,
            2 = Stroma).
    """

    # Window size and number of windows drawn per item. The window-origin
    # ranges in ``_window_origin`` were chosen for 192-pixel windows.
    crop_size = 192
    n_crops = 4

    def __init__(self, ccm_images, ccm_labels):
        # Image attributes.
        self.ccm_images = ccm_images
        self.ccm_labels = ccm_labels

    def __len__(self):
        return len(self.ccm_images)

    def _window_origin(self, label):
        """Return a random top-left corner ``(row, col)`` for one crop."""
        # Class-1 windows are drawn from a narrower, more central range.
        if label == 1:
            return randint(48, 143), randint(48, 143)
        return randint(0, 191), randint(0, 191)

    @staticmethod
    def _normalize(crop):
        """Min-max scale ``crop`` to [0, 1]; a constant crop becomes all zeros."""
        lo = np.min(crop)
        span = np.max(crop) - lo
        # A constant crop has span 0, and 0/0 would make every pixel NaN.
        # Dividing by 1 instead yields zeros with the same dtype as the
        # normal path.
        return (crop - lo) / (span if span > 0 else 1)

    def __getitem__(self, i):
        """Return ``{'ccm_images': (n_crops, crop, crop), 'ccm_labels': (n_crops,)}``."""
        # No copy needed: every step below builds a new array and never
        # mutates ``img`` in place.
        img = self.ccm_images[i, :, :].squeeze()
        label = self.ccm_labels[i, 0]
        size = self.crop_size

        crops = []
        for _ in range(self.n_crops):
            row, col = self._window_origin(label)
            crop = self._normalize(img[row:row + size, col:col + size])

            # Random horizontal / vertical flips for augmentation.
            if randint(0, 1) == 1:
                crop = np.fliplr(crop)
            if randint(0, 1) == 1:
                crop = np.flipud(crop)

            crops.append(crop)

        # np.stack copies the (possibly negatively strided) flipped views
        # into one contiguous array.
        images = np.stack(crops, axis=0)
        labels = np.full(self.n_crops, label)
        return {'ccm_images': images, 'ccm_labels': labels}

Sample Image

In [None]:

"""Saves a grid of generated digits ranging from 0 to n_classes"""

def sample_image(n_row, epochs,
                 path_template=r'/hpc/group/viplab/zzz3/SNP_Segmentation/Files/Images/GAN/%d.png'):
    """Generate an ``n_row`` x ``n_row`` grid of images and save it as a PNG.

    Every grid row holds one image per label ``0 .. n_row - 1``, wrapped
    modulo ``n_classes``, so column ``c`` shows class ``c % n_classes``.

    The generator is used in whatever mode it is currently in. Nothing in
    this notebook calls ``.eval()``, so during training it runs in training
    mode.

    Args:
        n_row: number of images per row and number of rows.
        epochs: integer substituted into ``path_template``.
        path_template: ``%``-style output path with one integer placeholder.

    Relies on the notebook globals ``generator``, ``latent_dim``,
    ``n_classes``, ``FloatTensor`` and ``LongTensor``.
    """
    # Noise for n_row**2 images, drawn from NumPy's global RNG.
    z = FloatTensor(np.random.normal(0, 1, (n_row ** 2, latent_dim)))

    # Row-major label grid [0, 1, ..., n_row-1] repeated n_row times. The
    # modulo keeps labels valid for the embedding when n_row > n_classes.
    labels = LongTensor(np.array([num % n_classes for _ in range(n_row) for num in range(n_row)]))

    # Sampling only: no autograd graph is needed.
    with torch.no_grad():
        gen_imgs = generator(z, labels)

    save_image(gen_imgs, path_template % epochs, nrow=n_row, normalize=True)

Initialize Weights

In [None]:

"""Initializes weights in networks"""

def weights_init_normal(m):
    """Re-initialise the parameters of a single module in place.

    Meant to be passed to ``nn.Module.apply``. Modules are matched by class
    name:

    * any class name containing "Conv": weights are drawn from N(0, 0.02),
      and the bias is left untouched;
    * any class name containing "BatchNorm2d": weights are drawn from
      N(1, 0.02) and the bias is set to 0.

    Every other module is left as it is.
    """
    layer_type = type(m).__name__

    if "Conv" in layer_type:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        return

    if "BatchNorm2d" in layer_type:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)




Generator and Discriminator

In [None]:

#######################################################
# Generator Network
#######################################################


class Generator(nn.Module):
    """Conditional generator mapping (noise, class label) to an image.

    Processing steps:

    1. The label embedding is multiplied element-wise with the noise vector.
    2. The product is linearly projected to a 128 x (image_size/4)^2 feature
       map.
    3. The map is upsampled twice (x2 each) through conv blocks.
    4. The result is an ``out_channels`` x image_size x image_size image,
       squashed to [-1, 1] by ``tanh``.

    Every constructor argument defaults to the matching notebook-level
    hyper-parameter (``n_classes``, ``latent_dim``, ``img_size``,
    ``channels``), so ``Generator()`` builds the same network as before.
    Pass the arguments explicitly to build it without those globals.

    Raises:
        ValueError: if ``image_size`` is not divisible by 4. The two x2
            upsamplings could not reproduce that size.
    """

    def __init__(self, num_classes=None, latent_size=None, image_size=None, out_channels=None):
        super().__init__()

        # Fall back to the notebook globals only for omitted arguments.
        num_classes = n_classes if num_classes is None else num_classes
        latent_size = latent_dim if latent_size is None else latent_size
        image_size = img_size if image_size is None else image_size
        out_channels = channels if out_channels is None else out_channels

        if image_size % 4 != 0:
            raise ValueError("image_size must be divisible by 4, got %d" % image_size)

        # Label embedding, same width as the noise vector it multiplies.
        self.label_emb = nn.Embedding(num_classes, latent_size)

        # Spatial size before the two x2 upsamplings.
        self.init_size = image_size // 4

        # Projection of the conditioned noise to the initial feature map.
        self.l1 = nn.Sequential(nn.Linear(latent_size, 128 * self.init_size ** 2))

        # NOTE(review): BatchNorm2d's second positional argument is ``eps``,
        # so the ``0.8`` below is eps=0.8 (default 1e-5). It was probably
        # meant as momentum. Kept unchanged so results stay comparable with
        # earlier runs.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, out_channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        """Generate images.

        Args:
            noise: float tensor, expected to be (batch, latent_size).
            labels: integer class ids, shape (batch,).

        Returns:
            Tensor of shape (batch, out_channels, image_size, image_size),
            with values in [-1, 1].
        """
        # Condition the noise on the class by element-wise multiplication.
        gen_input = self.label_emb(labels) * noise

        out = self.l1(gen_input)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)

        return self.conv_blocks(out)
    
    
#######################################################
# Discriminator Network
#######################################################

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()


        #Discrimator Blocks
        def discriminator_block(in_filters, out_filters, bn=True):
            
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block



        #Convolution Blocks
        self.conv_blocks = nn.Sequential(
            *discriminator_block(channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of downsampled image
        ds_size = img_size // 2 ** 4

        # Output layers
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())
        self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, n_classes), nn.Softmax())

    def forward(self, img):


        #Discrimination Convolution Blocks
        out = self.conv_blocks(img)

        #Reshape
        out = out.view(out.shape[0], -1)

        #Output Layers
        validity = self.adv_layer(out)
        label = self.aux_layer(out)


        #Output
        return validity, label


Load Data

In [None]:
# Input arrays (NumPy) are read from python_path; artefacts are written to torch_path.
python_path = r'/hpc/group/viplab/zzz3/SNP_Segmentation/Files/Python/AAR-Net/'
torch_path = r'/hpc/group/viplab/zzz3/SNP_Segmentation/Files/Torch/AAR-Net/'

# Pool the training and validation splits of each tissue class.
SNP_Train = np.load(python_path + 'AAR_Net_Images_Training_SNP.npy')
SNP_Val = np.load(python_path + 'AAR_Net_Images_Validation_SNP.npy')
SNP = np.concatenate([SNP_Train, SNP_Val], axis=0)

Epithelium_Train = np.load(python_path + 'AAR_Net_Images_Training_Epithelium.npy')
Epithelium_Val = np.load(python_path + 'AAR_Net_Images_Validation_Epithelium.npy')
Epithelium = np.concatenate([Epithelium_Train, Epithelium_Val], axis=0)

Stroma_Train = np.load(python_path + 'AAR_Net_Images_Training_Stroma.npy')
Stroma_Val = np.load(python_path + 'AAR_Net_Images_Validation_Stroma.npy')
Stroma = np.concatenate([Stroma_Train, Stroma_Val], axis=0)

# Report the pooled array shapes.
print(SNP.shape)
print(Epithelium.shape)
print(Stroma.shape)

# Stack every image and build the matching (N, 1) column of float class ids:
# 0 = SNP, 1 = Epithelium, 2 = Stroma.
ccm_images = np.concatenate([SNP, Epithelium, Stroma], axis=0)
ccm_labels = np.concatenate([
    np.full((len(SNP), 1), 0.0),
    np.full((len(Epithelium), 1), 1.0),
    np.full((len(Stroma), 1), 2.0),
], axis=0)

# Shuffle images and labels with one shared random permutation.
indices = np.random.permutation(ccm_images.shape[0])
ccm_images = ccm_images[indices]
ccm_labels = ccm_labels[indices]



In [None]:

# ---------------------------------------------------------------------------
# Hyper-parameters
# ---------------------------------------------------------------------------
n_epochs = 400
batch_size = 5        # dataset items per batch; each item holds several crops
lr = 0.002
b1 = 0.5              # Adam beta1
b2 = 0.999            # Adam beta2
latent_dim = 400
n_classes = 3
img_size = 192
channels = 1
experiment = 'Original'

# Adversarial (real/fake) objective and auxiliary (class) objective.
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()

# The networks are built here because their constructors read the
# hyper-parameter globals defined above.
generator = Generator()
discriminator = Discriminator()

# Move everything to the GPU when one is available, and pick the matching
# legacy tensor constructors.
cuda = torch.cuda.is_available()
if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
    auxiliary_loss.cuda()
FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
device = torch.device('cuda' if cuda else 'cpu')

# The custom initialisation runs after the move to the device.
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Shuffled loader over the random-crop dataset.
dataset = BasicDataset(ccm_images, ccm_labels)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        num_workers=16, pin_memory=True)

# One Adam optimizer per network, with shared settings.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

# Halve both learning rates every 100 scheduler steps.
scheduler_G = StepLR(optimizer_G, step_size=100, gamma=0.5)
scheduler_D = StepLR(optimizer_D, step_size=100, gamma=0.5)



# **Training**

In [None]:

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Returns ``None`` when ``optimizer.param_groups`` is empty.
    """
    first_group = next(iter(optimizer.param_groups), None)
    return None if first_group is None else first_group['lr']


# Per-epoch histories, one row per epoch. They are re-saved after every
# epoch, so an interrupted run still leaves the curves up to that point.
discriminator_loss_net = np.zeros((n_epochs, 1))
generator_loss_net = np.zeros((n_epochs, 1))
discriminator_accuracy_net = np.zeros((n_epochs, 1))
SSIM_results_net = np.zeros((n_epochs, n_classes))

# Number of generated image pairs per class for the end-of-epoch SSIM metric.
n_ssim_pairs = 500


def _discriminator_accuracy(validity_real, validity_fake):
    """Return the fraction of real/fake decisions the discriminator gets right.

    ``validity_real`` and ``validity_fake`` are the sigmoid outputs of the
    adversarial head for real and generated images. A score above 0.5 counts
    as a "real" prediction.
    """
    scores = torch.cat((validity_real, validity_fake)).detach().cpu().numpy().ravel()
    truth = np.concatenate((np.ones(validity_real.shape[0], dtype=bool),
                            np.zeros(validity_fake.shape[0], dtype=bool)))
    return float(np.mean((scores > 0.5) == truth))


def _mean_ssim_per_class(n_pairs):
    """Return, for each class, the mean SSIM between pairs of generated images.

    For every class ``j``, the generator is run ``n_pairs`` times on a batch
    of two noise vectors labelled ``j``, and the SSIM of the two outputs is
    recorded. Presumably this serves as a diversity / mode-collapse indicator
    (high similarity between independent samples) -- confirm with the author.
    """
    scores = np.zeros((n_pairs, n_classes))
    # Evaluation only, so no graph is built. Batches of two are kept because
    # the generator runs in training mode, where BatchNorm statistics depend
    # on the batch.
    with torch.no_grad():
        for i in range(n_pairs):
            for j in range(n_classes):
                z_pair = FloatTensor(np.random.normal(0, 1, (2, latent_dim)))
                label_pair = LongTensor(np.array([j, j]))
                imgs = generator(z_pair, label_pair).cpu().numpy()
                # The generator ends in tanh, so pixels lie in [-1, 1]. Pass
                # that range explicitly instead of letting SSIM guess it from
                # the float dtype.
                scores[i, j] = ssim(imgs[0].squeeze(), imgs[1].squeeze(), data_range=2.0)
    return np.nanmean(scores, axis=0)


# Iterate through training epochs.
for epoch in range(n_epochs):

    discriminator_loss = []
    generator_loss = []
    discriminator_accuracy = []

    # Iterate through the batches in the dataloader.
    idx = 0
    for batch in dataloader:

        # Each dataset item holds several crops; fold them into the batch axis.
        real_imgs = torch.as_tensor(batch['ccm_images']).reshape(-1, channels, img_size, img_size).to(device=device, dtype=torch.float32)
        labels_real = torch.as_tensor(batch['ccm_labels']).reshape(-1).to(device=device, dtype=torch.long)

        # Number of crops in this batch. A separate name is used so the
        # ``batch_size`` hyper-parameter is not overwritten.
        n_batch = real_imgs.shape[0]

        # Adversarial targets: 1 = real, 0 = generated.
        real = torch.ones((n_batch, 1), device=device)
        fake = torch.zeros((n_batch, 1), device=device)

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()

        # Sample noise and class labels as generator input.
        z = FloatTensor(np.random.normal(0, 1, (n_batch, latent_dim)))
        labels_fake = LongTensor(np.random.randint(0, n_classes, n_batch))
        fake_imgs = generator(z, labels_fake)

        # The generator is rewarded when D scores its images as real and
        # assigns them the requested class.
        validity_fake, pred_labels_fake = discriminator(fake_imgs)
        g_loss = (adversarial_loss(validity_fake, real) + auxiliary_loss(pred_labels_fake, labels_fake)) / 2.0

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # Loss on real images.
        validity_real, pred_labels_real = discriminator(real_imgs)
        d_real_loss = (adversarial_loss(validity_real, real) + auxiliary_loss(pred_labels_real, labels_real)) / 2.0

        # Loss on generated images; detach() keeps this step from reaching G.
        validity_fake, pred_labels_fake = discriminator(fake_imgs.detach())
        d_fake_loss = (adversarial_loss(validity_fake, fake) + auxiliary_loss(pred_labels_fake, labels_fake)) / 2.0

        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Real/fake accuracy from thresholding the sigmoid scores at 0.5. An
        # argmax over the single-column output would always be 0.
        d_acc = _discriminator_accuracy(validity_real, validity_fake)

        # Log loss and accuracy.
        discriminator_loss.append(d_loss.item())
        generator_loss.append(g_loss.item())
        discriminator_accuracy.append(100 * d_acc)

        # Save a sample grid once per epoch, after the first batch.
        if idx == 0:
            sample_image(n_row=3, epochs=epoch)

        idx += 1

    mean_d_loss = np.nanmean(np.array(discriminator_loss))
    mean_g_loss = np.nanmean(np.array(generator_loss))
    mean_d_acc = np.nanmean(np.array(discriminator_accuracy))

    # Print progress.
    print(
        "[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"
        % (epoch, n_epochs, idx, len(dataloader), mean_d_loss, mean_d_acc, mean_g_loss)
    )

    discriminator_loss_net[epoch, 0] = mean_d_loss
    generator_loss_net[epoch, 0] = mean_g_loss
    discriminator_accuracy_net[epoch, 0] = mean_d_acc

    # Structural similarity between same-class generated samples.
    SSIM_results_net[epoch] = _mean_ssim_per_class(n_ssim_pairs)

    # Save the histories after every epoch.
    np.save(torch_path + 'Discriminator_Loss_' + experiment + '.npy', discriminator_loss_net)
    np.save(torch_path + 'Generator_Loss_' + experiment + '.npy', generator_loss_net)
    np.save(torch_path + 'Discriminator_Accuracy_' + experiment + '.npy', discriminator_accuracy_net)
    np.save(torch_path + 'SSIM_Results_' + experiment + '.npy', SSIM_results_net)

    # Learning-rate schedules advance once per epoch.
    scheduler_G.step()
    scheduler_D.step()

    # Checkpoint both networks (overwritten every epoch).
    torch.save(generator.state_dict(), torch_path + 'Generator_' + experiment + '.pth')
    torch.save(discriminator.state_dict(), torch_path + 'Discriminator_' + experiment + '.pth')