<a href="https://colab.research.google.com/github/sidijju/FastMRI/blob/master/GANv5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
%load_ext autoreload
%matplotlib inline

In [2]:
#Import required libraries
%autoreload 2
!pip3 install --upgrade pip
!pip3 install torch
!pip3 install torchvision
!pip3 install torchfusion
!pip3 install tensorboardx
!pip3 install pillow
!pip3 install pydicom
!pip3 install opencv-python

import os
import errno
import scipy
import pydicom as dicom
import scipy.misc
import numpy as np
import cv2

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow import nn, layers
from tensorflow.contrib import layers as clayers 

import torch
import torch.cuda as cuda
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.distributions import Normal

import torchvision.utils as vutils
from torchvision import transforms, utils, datasets

from torchfusion.gan.learners import *
from torchfusion.gan.applications import DCGANGenerator,DCGANDiscriminator
from torchfusion.datasets import mnist_loader

from tensorboardX import SummaryWriter
from IPython import display

from PIL import Image

from getpass import getpass


Requirement already up-to-date: pip in /usr/local/lib/python3.6/dist-packages (19.0.1)


In [0]:
class Logger:
    """Log GAN training progress to TensorBoard and to files under ./data.

    Scalars and images go to a tensorboardX ``SummaryWriter`` created with the
    comment ``<model_name>_<data_name>``; PNG renderings of image grids are written
    to ``./data/images/<model_name>/<data_name>`` and model state dicts to
    ``./data/models/<model_name>/<data_name>``.
    """

    def __init__(self, model_name, data_name):
        self.model_name = model_name
        self.data_name = data_name

        self.comment = '{}_{}'.format(model_name, data_name)
        self.data_subdir = '{}/{}'.format(model_name, data_name)

        # TensorBoard
        self.writer = SummaryWriter(comment=self.comment)

    def log(self, d_error, g_error, epoch, n_batch, num_batches):
        """Write D/G losses as TensorBoard scalars at the global step of (epoch, n_batch)."""
        # Tensors are moved to CPU and converted to numpy before being logged.
        if isinstance(d_error, torch.autograd.Variable):
            d_error = d_error.data.cpu().numpy()
        if isinstance(g_error, torch.autograd.Variable):
            g_error = g_error.data.cpu().numpy()

        step = Logger._step(epoch, n_batch, num_batches)
        self.writer.add_scalar(
            '{}/D_error'.format(self.comment), d_error, step)
        self.writer.add_scalar(
            '{}/G_error'.format(self.comment), g_error, step)

    def log_images(self, images, num_images, epoch, n_batch, num_batches, format='NCHW', normalize=True):
        """Log a batch of images to TensorBoard and save PNG renderings of it.

        Args:
            images: tensor or ndarray shaped (N, C, H, W), or (N, H, W, C) when
                ``format == 'NHWC'``.
            num_images: image count; the square grid uses int(sqrt(num_images))
                images per row (at least 1).
            epoch, n_batch, num_batches: training position, used for the global
                step and the saved file names.
            format: 'NCHW' (default) or 'NHWC'.
            normalize: forwarded to ``make_grid`` for both grids.
        """
        images = Logger._as_nchw(images, format)

        step = Logger._step(epoch, n_batch, num_batches)
        img_name = '{}/images'.format(self.comment)

        # Single-row grid of every image.
        horizontal_grid = vutils.make_grid(
            images, normalize=normalize, scale_each=True)
        # Roughly square grid; nrow must be >= 1 or make_grid divides by zero.
        nrows = max(1, int(np.sqrt(num_images)))
        grid = vutils.make_grid(
            images, nrow=nrows, normalize=normalize, scale_each=True)

        # Add horizontal images to tensorboard
        self.writer.add_image(img_name, horizontal_grid, step)

        # Save plots
        self.save_torch_images(horizontal_grid, grid, epoch, n_batch)

    def save_torch_images(self, horizontal_grid, grid, epoch, n_batch, plot_horizontal=True):
        """Render both (C, H, W) grids with matplotlib and save them as PNGs.

        The horizontal grid is saved with the 'hori' file prefix, the square grid
        with an empty prefix. ``plot_horizontal`` is accepted for compatibility;
        inline display is disabled.
        """
        # 16x16 in: a 128x128-inch figure renders a 12800-px-wide PNG at matplotlib's default 100 dpi.
        fig = plt.figure(figsize=(16, 16))
        plt.imshow(Logger._grid_to_hwc(horizontal_grid))
        plt.axis('off')
        self._save_images(fig, epoch, n_batch, 'hori')
        # Close every figure so repeated logging does not accumulate open figures.
        plt.close(fig)

        # Save squared
        fig = plt.figure()
        plt.imshow(Logger._grid_to_hwc(grid))
        plt.axis('off')
        self._save_images(fig, epoch, n_batch)
        plt.close(fig)

    def _save_images(self, fig, epoch, n_batch, comment=''):
        """Save fig as ./data/images/<subdir>/<comment>_epoch_<epoch>_batch_<n_batch>.png."""
        out_dir = './data/images/{}'.format(self.data_subdir)
        Logger._make_dir(out_dir)
        fig.savefig('{}/{}_epoch_{}_batch_{}.png'.format(out_dir,
                                                         comment, epoch, n_batch))

    def display_status(self, epoch, num_epochs, n_batch, num_batches, d_error, g_error, d_pred_real, d_pred_fake):
        """Print losses and mean discriminator outputs on real (D(x)) and fake (D(G(z))) data."""
        if isinstance(d_error, torch.autograd.Variable):
            d_error = d_error.data.cpu().numpy()
        if isinstance(g_error, torch.autograd.Variable):
            g_error = g_error.data.cpu().numpy()
        if isinstance(d_pred_real, torch.autograd.Variable):
            d_pred_real = d_pred_real.data
        if isinstance(d_pred_fake, torch.autograd.Variable):
            d_pred_fake = d_pred_fake.data

        print('Epoch: [{}/{}], Batch Num: [{}/{}]'.format(
            epoch, num_epochs, n_batch, num_batches))
        print('Discriminator Loss: {:.4f}, Generator Loss: {:.4f}'.format(d_error, g_error))
        print('D(x): {:.4f}, D(G(z)): {:.4f}'.format(d_pred_real.mean(), d_pred_fake.mean()))

    def save_models(self, generator, discriminator, epoch):
        """Save both models' state dicts as G_epoch_<epoch> / D_epoch_<epoch> under ./data/models."""
        out_dir = './data/models/{}'.format(self.data_subdir)
        Logger._make_dir(out_dir)
        torch.save(generator.state_dict(),
                   '{}/G_epoch_{}'.format(out_dir, epoch))
        torch.save(discriminator.state_dict(),
                   '{}/D_epoch_{}'.format(out_dir, epoch))

    def close(self):
        """Close the underlying TensorBoard writer."""
        self.writer.close()

    # Private Functionality

    @staticmethod
    def _step(epoch, n_batch, num_batches):
        """Global step index: batches completed across all previous epochs plus n_batch."""
        return epoch * num_batches + n_batch

    @staticmethod
    def _make_dir(directory):
        """Create directory (and parents); an existing directory is not an error."""
        os.makedirs(directory, exist_ok=True)

    @staticmethod
    def _as_nchw(images, format='NCHW'):
        """Return images as a torch tensor in (N, C, H, W) order."""
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(images)
        if format == 'NHWC':
            # (N, H, W, C) -> (N, C, H, W); a plain transpose(1, 3) would swap H and W.
            images = images.permute(0, 3, 1, 2)
        return images

    @staticmethod
    def _grid_to_hwc(grid):
        """Convert a (C, H, W) tensor, possibly on GPU, to an (H, W, C) numpy array for imshow."""
        return np.moveaxis(grid.detach().cpu().numpy(), 0, -1)
                

In [0]:
class OASISDataset(Dataset):
    """Slices of one BRAINIX DICOM series from the private OASIS-Data repository.

    On construction the repository is cloned into ``REPO_DIR`` (prompting for
    BitBucket credentials) unless a clone already exists there. The files of each
    series directory are then collected, and the series chosen by ``arr_name`` is
    read into a (rows, cols, slices) array.

    Each item is ``(image, index)``. ``image`` is the (rows, cols, 1) uint8 slice,
    passed through ``transforms`` when one was given.

    Args:
        arr_name: "T1_DCM", "T2_DCM" or "FLAIR_DCM"; any other value selects ROI.
        transforms: optional callable applied to each uint8 slice.
    """

    # Where the data repository is cloned; DICOM paths are resolved against it.
    REPO_DIR = "/content/OASIS-Data"
    # Remote URL without credentials; they are injected only into the clone command.
    REPO_URL = "bitbucket.org/sidijju/OASIS-Data.git"
    # arr_name -> series sub-directory of BRAINIX/DICOM (anything else maps to "ROI").
    SERIES = {"T1_DCM": "T1", "T2_DCM": "T2", "FLAIR_DCM": "FLAIR"}

    def __init__(self, arr_name="T1_DCM", transforms=None):
        self.transforms = transforms
        self.FLAIR_DCM = []
        self.ROI_DCM = []
        self.T1_DCM = []
        self.T2_DCM = []

        self._fetch_repo()
        self._collect_files(os.path.join(self.REPO_DIR, "BRAINIX", "DICOM"))

        # Load only the requested series into <SERIES>_Ref / <SERIES>_Dicom
        # (T1_Ref / T1_Dicom for the default arr_name).
        series = self.SERIES.get(arr_name, "ROI")
        files = getattr(self, series + "_DCM")
        if not files:
            raise FileNotFoundError(
                "no DICOM files found for series {!r} under {}".format(series, self.REPO_DIR))
        ref = self.getInfo(files)
        volume = np.zeros(ref[1], dtype=ref[0].pixel_array.dtype)
        self.storeList(files, volume)
        setattr(self, series + "_Ref", ref)
        setattr(self, series + "_Dicom", volume)

        # Add a channel axis: (rows, cols, slices) -> (rows, cols, 1, slices).
        self.arr = volume[:, :, np.newaxis, :]

    def _fetch_repo(self):
        """Clone the private data repository into REPO_DIR unless a clone is already there."""
        # Local imports keep this notebook cell self-contained.
        import subprocess
        from urllib.parse import quote

        if os.path.isdir(os.path.join(self.REPO_DIR, ".git")):
            return

        user = getpass('BitBucket user')
        password = getpass('BitBucket password')
        # Percent-encode both parts so '@', ':', '/' etc. cannot break the URL.
        auth = quote(user, safe="") + ":" + quote(password, safe="")
        clone = subprocess.run(
            ["git", "clone", "https://{}@{}".format(auth, self.REPO_URL), self.REPO_DIR],
            stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
        if clone.returncode != 0:
            # Scrub the credentials from git's message before surfacing it.
            message = clone.stderr.decode(errors="replace").replace(auth, "***")
            raise RuntimeError("git clone of {} failed: {}".format(self.REPO_URL, message))
        # git stores the clone URL (credentials included) as origin; replace it with the bare URL.
        subprocess.run(
            ["git", "-C", self.REPO_DIR, "remote", "set-url", "origin", "https://" + self.REPO_URL],
            check=True)

    def _collect_files(self, root):
        """Append the path of every file directly inside root/<series> to self.<series>_DCM."""
        targets = {"FLAIR": self.FLAIR_DCM, "ROI": self.ROI_DCM,
                   "T1": self.T1_DCM, "T2": self.T2_DCM}
        for dir_name, _subdirs, file_names in os.walk(root):
            target = targets.get(os.path.relpath(dir_name, root))
            if target is not None:
                target.extend(os.path.join(dir_name, name) for name in file_names)

    def getInfo(self, ref):
        """Read geometry from the first file of a DICOM series.

        Args:
            ref: list of DICOM file paths of one series (must be non-empty).
        Returns:
            (RefDs, (rows, cols, len(ref)), (row spacing, col spacing, slice thickness)
            in mm, x, y, z) where x, y, z hold N+1 edge coordinates per axis.
        """
        RefDs = dicom.dcmread(ref[0])

        # Load dimensions based on the number of rows, columns, and slices (along the Z axis)
        ConstPixelDims = (int(RefDs.Rows), int(RefDs.Columns), len(ref))
        # Load spacing values (in mm)
        ConstPixelSpacing = (float(RefDs.PixelSpacing[0]), float(RefDs.PixelSpacing[1]), float(RefDs.SliceThickness))
        # Integer steps times spacing give exactly N+1 edges; np.arange with a float
        # step can emit an extra sample due to rounding.
        x, y, z = (np.arange(n + 1) * step for n, step in zip(ConstPixelDims, ConstPixelSpacing))

        return RefDs, ConstPixelDims, ConstPixelSpacing, x, y, z

    def storeList(self, directory, array):
        """Read each DICOM file in ``directory`` (a list of paths) into array[:, :, slice]."""
        for filenameDCM in directory:
            ds = dicom.dcmread(filenameDCM)
            # Slice number comes from the two characters before a 4-character extension
            # (e.g. "...07.dcm" -> index 6); assumes that file naming scheme.
            array[:, :, int(filenameDCM[-6:-4]) - 1] = ds.pixel_array

    def plotPicture(self, im, ref, title=""):
        """Draw a 2-D slice in grayscale on a new figure using the physical axes in ref.

        Args:
            im: 2-D array of shape (rows, cols).
            ref: tuple returned by getInfo; its row (x) and column (y) edges are used.
            title: figure title.
        Returns:
            The matplotlib Figure.
        """
        fig, ax = plt.subplots(dpi=50)
        ax.set_aspect('equal', 'datalim')
        ax.set_title(title)
        # pcolormesh(X, Y, C) takes column edges as X and row edges as Y; the two
        # orders only coincide for square images with equal spacing.
        ax.pcolormesh(ref[4], ref[3], im, cmap='gray')
        return fig

    def _intensity_range(self):
        """Return (min, max) of self.arr, computed once and cached."""
        cached = self.__dict__.get("_range")
        if cached is None:
            cached = self._range = (float(self.arr.min()), float(self.arr.max()))
        return cached

    def _to_uint8(self, image):
        """Convert a slice to uint8 without the modulo-256 wrap-around of a bare astype."""
        lo, hi = self._intensity_range()
        if lo >= 0 and hi <= 255:
            return image.astype('uint8')
        span = hi - lo
        if span == 0:
            return np.zeros(image.shape, dtype='uint8')
        # Map the whole-volume range onto [0, 255] so every slice shares one scale.
        return np.round((image.astype(np.float64) - lo) * 255.0 / span).astype('uint8')

    def __getitem__(self, index):
        """Return (slice, index); raises IndexError outside [0, len(self))."""
        if not 0 <= index < len(self):
            raise IndexError("OASISDataset index {} out of range [0, {})".format(index, len(self)))
        image = self._to_uint8(self.arr[:, :, :, index])
        if self.transforms is not None:
            image = self.transforms(image)
        return image, index

    def __len__(self):
        """Number of slices in the selected series."""
        return np.size(self.arr, 3)
      

In [0]:
# Load the custom OASIS dataset with torch.
DATA_FOLDER = './tf_data/DCGAN/OASIS'


def oasis_data():
    """Return the T1 OASISDataset, with each uint8 slice converted by transforms.ToTensor."""
    compose = transforms.Compose([
        transforms.ToTensor(),
        # Intensity normalization (transforms.Normalize) is intentionally left disabled.
    ])
    return OASISDataset(arr_name="T1_DCM", transforms=compose)

## Define Standard Generator and Discriminator Models

In [0]:
# Generator and discriminator sized for single-channel 512x512 images.
G = DCGANGenerator(output_size=(1, 512, 512), latent_size=(128, 0))
D = DCGANDiscriminator(input_size=(1, 512, 512), apply_sigmoid=False)

# With CUDA available, move both models to the GPU and wrap them in DataParallel;
# otherwise they stay on the CPU unchanged.
if cuda.is_available():
    G, D = (nn.DataParallel(model.cuda()) for model in (G, D))

In [0]:
# Generator and discriminator use identical Adam hyper-parameters.
_adam_settings = {"lr": 0.0002, "betas": (0.5, 0.999)}
g_optim = Adam(G.parameters(), **_adam_settings)
d_optim = Adam(D.parameters(), **_adam_settings)

In [0]:
# torchfusion learner wrapping generator G and discriminator D; training is started by learner.train(...) in the last cell.
learner = RAvgStandardGanLearner(G, D)

In [0]:
# Build the T1 dataset; constructing OASISDataset may prompt for BitBucket credentials to fetch the data.
dataset = oasis_data()

In [0]:
if __name__ == "__main__":
    # NOTE(review): torchfusion's train() appears to iterate over batches (its own
    # loaders such as mnist_loader return a DataLoader); a bare Dataset would yield
    # single unbatched samples. TODO confirm against the torchfusion version in use.
    # Assumption: a small batch, since the T1 volume holds only a couple dozen
    # 512x512 slices; tune as needed.
    train_loader = DataLoader(dataset, batch_size=4, shuffle=True)
    learner.train(train_loader, gen_optimizer=g_optim, disc_optimizer=d_optim, save_outputs_interval=500, model_dir="./OASIS-gan", latent_size=128, num_epochs=50, batch_log=False)