In [1]:
import torch
import torch.nn as nn
from torch.nn import init
import torchvision
import torchvision.transforms as T
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torchvision.datasets as dset

import numpy as np

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

def show_images(images):
    images = np.reshape(images, [images.shape[0], -1])  # images reshape to (batch_size, D)
    sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
    sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))

    fig = plt.figure(figsize=(sqrtn, sqrtn))
    gs = gridspec.GridSpec(sqrtn, sqrtn)
    gs.update(wspace=0.05, hspace=0.05)

    for i, img in enumerate(images):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(img.reshape([sqrtimg,sqrtimg]))
    return 

def preprocess_img(x):
    return 2 * x - 1.0

def deprocess_img(x):
    return (x + 1.0) / 2.0

def rel_error(x,y):
    return np.max(np.abs(x - y) / (np.maximum(1e-8, np.abs(x) + np.abs(y))))

def count_params(model):
    """Count the number of parameters in the current TensorFlow graph """
    param_count = np.sum([np.prod(p.size()) for p in model.parameters()])
    return param_count

#answers = dict(np.load('gan-checks-tf.npz'))

In [2]:
from skimage import io, transform, color
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torchvision.datasets as dset
import torchvision.transforms as T
from dataLoader import OrganoidDataset
from torch.utils import data
import numpy as np
import sys
import pandas as pd
from imageio import imread
from PIL import Image
import os
import math
import torchvision.models as models

from dataLoader import OrganoidDataset
#from conv_model import SimpleConvNet
import matplotlib.pyplot as plt
import copy

%matplotlib inline
plt.rcParams['figure.figsize'] = (20.0, 10.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

In [3]:
NOISE_DIM = 96
batch_size = 128

In [4]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
batch_size = 128
params = {'batch_size': batch_size, # low for testing
          'shuffle': True, 'num_workers' : 2}
max_epochs = 100
image_size = 193

In [5]:
device

device(type='cuda', index=0)

In [6]:
figure_path = '../FinalReport/figures/classification/'
#path = '../data/CS231n_Tim_Shan_example_data/'
path = '../data/'
label_path = '../data/well_summary_A1_e0891BSA_all.csv'

# Load data

In [7]:
class OrganoidMultipleDataset(data.Dataset):
    'dataset class for microwell organoid images'
    def __init__(self, path2files, image_names, Y, mean_sd_dict, transforms=None):
        for k, image_name in image_names.items():
            assert len(image_name) == len(Y)
        self.path = path2files
        self.image_names = image_names
        self.Y = Y
        self.mean_sd_dict = mean_sd_dict
        self.transforms = transforms
    def __len__(self):
        return len(self.Y)
    def getXimage(self, index):
        all_images_list = []
        for day,img_names in self.image_names.items():
            #print(day, "   ", index)
            
            img_name = img_names[index]
            img_loc = os.path.join(self.path, img_name)
            image = io.imread(img_loc)
            mean, sd = self.mean_sd_dict[day]
            image = color.rgb2gray(image)
            #image = np.true_divide(color.rgb2gray(image) - mean, sd)
            all_images_list.append(image)
        images = np.array(all_images_list)
        return torch.from_numpy(images).float()
    def getY(self, index):
        Y = self.Y[index]
        return torch.from_numpy(np.asarray(self.Y[index], dtype=float)).float()
    def __getitem__(self, index):
        X = self.getXimage(index)
        y = self.getY(index)
        if self.transforms is not None:
            X = self.transforms(X)
        return X, y

In [8]:
training_labels = pd.read_csv('../data_description/A1_A2_C1_filtered_train_v2.csv')
validation_labels = pd.read_csv('../data_description/A1_A2_C1_filtered_validation_v2.csv')
test_labels = pd.read_csv('../data_description/A1_A2_C1_filtered_test_v2.csv')

In [9]:
training_labels.shape,validation_labels.shape,test_labels.shape

((6514, 60), (814, 60), (815, 60))

In [10]:
training_image_names = {8:training_labels['image_name_8']}

In [11]:
# training_image_names = {2:training_labels['image_name_2'],8:training_labels['image_name_8'], 5:training_labels['image_name_5']}
# validation_image_names = {2:validation_labels['image_name_2'],8:validation_labels['image_name_8'],5:validation_labels['image_name_5']}

In [12]:
training_y = training_labels['has_cell_13']
validation_y = validation_labels['has_cell_13']

In [13]:
mean_sd_dict = {2: [0.49439774802337344, 0.16087996922691195],
 8: [0.5177020917650417, 0.15714445907773483],
 5: [0.5013496452715945, 0.1605951051365687],              }

In [14]:
train_set = OrganoidMultipleDataset(path2files = path, image_names = training_image_names, Y = training_labels['has_cell_13'],mean_sd_dict=mean_sd_dict)
#validation_set = OrganoidMultipleDataset(path2files = path, image_names = validation_image_names, Y = validation_labels['has_cell_13'],mean_sd_dict=mean_sd_dict)


In [15]:
training_generator = data.DataLoader(train_set, **params)

In [16]:
dtype = torch.FloatTensor

In [17]:
data = next(enumerate(training_generator))[-1][0].type(dtype)

In [18]:
data.shape

torch.Size([128, 1, 193, 193])

### Random noise

In [19]:
def sample_noise(batch_size, dim):
    """
    Generate a PyTorch Tensor of uniform random noise.

    Input:
    - batch_size: Integer giving the batch size of noise to generate.
    - dim: Integer giving the dimension of noise to generate.
    
    Output:
    - A PyTorch Tensor of shape (batch_size, dim) containing uniform
      random noise in the range (-1, 1).
    """
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

    return torch.rand(batch_size,dim)*2 -1

    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****


In [20]:
class Flatten(nn.Module):
    def forward(self, x):
        N, C, H, W = x.size() # read in N, C, H, W
        return x.view(N, -1)  # "flatten" the C * H * W values into a single vector per image
    
class Unflatten(nn.Module):
    """
    An Unflatten module receives an input of shape (N, C*H*W) and reshapes it
    to produce an output of shape (N, C, H, W).
    """
    def __init__(self, N=-1, C=128, H=7, W=7):
        super(Unflatten, self).__init__()
        self.N = N
        self.C = C
        self.H = H
        self.W = W
    def forward(self, x):
        return x.view(self.N, self.C, self.H, self.W)

def initialize_weights(m):
    if isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose2d):
        init.xavier_uniform_(m.weight.data)

In [21]:
dtype = torch.FloatTensor
#dtype = torch.cuda.FloatTensor ## UNCOMMENT THIS LINE IF YOU'RE ON A GPU!

# GANs

In [22]:
def bce_loss(input, target):
    """
    Numerically stable version of the binary cross-entropy loss function.

    As per https://github.com/pytorch/pytorch/issues/751
    See the TensorFlow docs for a derivation of this formula:
    https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits

    Inputs:
    - input: PyTorch Tensor of shape (N, ) giving scores.
    - target: PyTorch Tensor of shape (N,) containing 0 and 1 giving targets.

    Returns:
    - A PyTorch Tensor containing the mean BCE loss over the minibatch of input data.
    """
    neg_abs = - input.abs()
    loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
    return loss.mean()

In [23]:
def discriminator_loss(logits_real, logits_fake):
    """
    Computes the discriminator loss described above.
    
    Inputs:
    - logits_real: PyTorch Tensor of shape (N,) giving scores for the real data.
    - logits_fake: PyTorch Tensor of shape (N,) giving scores for the fake data.
    
    Returns:
    - loss: PyTorch Tensor containing (scalar) the loss for the discriminator.
    """
    N = logits_real.size()
    true_labels = torch.ones(logits_real.size()).type(dtype)
    return bce_loss(logits_real,true_labels) + bce_loss(logits_fake,1-true_labels)

def generator_loss(logits_fake):
    """
    Computes the generator loss described above.

    Inputs:
    - logits_fake: PyTorch Tensor of shape (N,) giving scores for the fake data.
    
    Returns:
    - loss: PyTorch Tensor containing the (scalar) loss for the generator.
    """
    true_labels = torch.ones(logits_fake.size()).type(dtype)
    loss = bce_loss(logits_fake,true_labels)
    return loss

In [24]:
def get_optimizer(model):
    """
    Construct and return an Adam optimizer for the model with learning rate 1e-3,
    beta1=0.5, and beta2=0.999.
    
    Input:
    - model: A PyTorch model that we want to optimize.
    
    Returns:
    - An Adam optimizer for the model with the desired hyperparameters.
    """
    optimizer = optim.Adam(model.parameters(),lr=1e-3,betas=(0.5,0.999))
    return optimizer

In [66]:
def run_a_gan(D, G, D_solver, G_solver, discriminator_loss, generator_loss, show_every=250, 
              batch_size=128, noise_size=96, num_epochs=10):
    """
    Train a GAN!
    
    Inputs:
    - D, G: PyTorch models for the discriminator and generator
    - D_solver, G_solver: torch.optim Optimizers to use for training the
      discriminator and generator.
    - discriminator_loss, generator_loss: Functions to use for computing the generator and
      discriminator loss, respectively.
    - show_every: Show samples after every show_every iterations.
    - batch_size: Batch size to use for training.
    - noise_size: Dimension of the noise to use as input to the generator.
    - num_epochs: Number of epochs over the training dataset to use for training.
    """
    iter_count = 0
    for epoch in range(num_epochs):
        for x, _ in training_generator:
            if len(x) != batch_size:
                continue
            D_solver.zero_grad()
            real_data = x.type(dtype)
            logits_real = D(2* (real_data - 0.5)).type(dtype)

            g_fake_seed = sample_noise(batch_size, noise_size).type(dtype)
            fake_images = G(g_fake_seed).detach()
            logits_fake = D(fake_images.view(batch_size, 1, 28, 28))

            d_total_error = discriminator_loss(logits_real, logits_fake)
            d_total_error.backward()        
            D_solver.step()

            G_solver.zero_grad()
            g_fake_seed = sample_noise(batch_size, noise_size).type(dtype)
            fake_images = G(g_fake_seed)

            gen_logits_fake = D(fake_images.view(batch_size, 1, image_size, image_size))
            g_error = generator_loss(gen_logits_fake)
            g_error.backward()
            G_solver.step()

            if (iter_count % show_every == 0):
                print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count,d_total_error.item(),g_error.item()))
                imgs_numpy = fake_images.data.cpu().numpy()
                show_images(imgs_numpy[0:16])
                plt.show()
                print()
            iter_count += 1

In [64]:
def get_out_channels(in_channels,kernal_size,padding,stride = 1,max_pool=False):
    if max_pool:
        denominator = 2*stride
    else:
        denominator = stride
    out = math.floor((in_channels + 2*padding - kernal_size)/denominator) + 1  
    return out

In [30]:
# dim1 = get_out_channels(in_channels=193,kernal_size=5,padding=0,stride = 1,max_pool=True)
# dim1

95

In [32]:
# dim2 = get_out_channels(in_channels=dim1,kernal_size=5,padding=0,stride = 1,max_pool=True)
# dim2

46

In [33]:
# dim3 = get_out_channels(in_channels=dim2,kernal_size=5,padding=0,stride = 1,max_pool=True)
# dim3

21

In [27]:
def build_dc_classifier():
    """
    Build and return a PyTorch model for the DCGAN discriminator implementing
    the architecture above.
    """
    return nn.Sequential(
        # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
        Unflatten(batch_size, 1, image_size, image_size),
        nn.Conv2d(1,32,kernel_size=5,stride=1),
        nn.LeakyReLU(0.01),
        nn.MaxPool2d(kernel_size=2,stride=2),
        nn.Conv2d(32,64,kernel_size=5,stride=1),
        nn.LeakyReLU(0.01),
        nn.MaxPool2d(kernel_size=2,stride=2),
        nn.Conv2d(64,64,kernel_size=5,stride=1),
        nn.LeakyReLU(0.01),
        nn.MaxPool2d(kernel_size=2,stride=2),
        Flatten(),
        nn.Linear(20*20*64,20*20*64),
        nn.LeakyReLU(0.01),
        nn.Linear(20*20*64, 1)

        # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    )



In [26]:
b = build_dc_classifier().type(dtype)

In [27]:
b

Sequential(
  (0): Unflatten()
  (1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (2): LeakyReLU(negative_slope=0.01)
  (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (5): LeakyReLU(negative_slope=0.01)
  (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (7): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1))
  (8): LeakyReLU(negative_slope=0.01)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Flatten()
  (11): Linear(in_features=25600, out_features=25600, bias=True)
  (12): LeakyReLU(negative_slope=0.01)
  (13): Linear(in_features=25600, out_features=1, bias=True)
)

In [28]:
out = b(data)
print(out.size())

torch.Size([128, 1])


In [29]:
1024/28

36.57142857142857

In [93]:
def build_dc_generator(noise_dim=NOISE_DIM):
    """
    Build and return a PyTorch model implementing the DCGAN generator using
    the architecture described above.
    """
    return nn.Sequential(
        # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

        nn.Linear(noise_dim,1024),
        nn.ReLU(),
        nn.BatchNorm1d(1024),
        nn.Linear(1024, 19*19*128),
        nn.ReLU(),
        nn.BatchNorm1d(19*19*128),
        Unflatten(batch_size,128,19,19),
        nn.ConvTranspose2d(128,64,kernel_size=4,stride=2,padding=1),
        nn.ReLU(),
        nn.BatchNorm2d(64),
        nn.ConvTranspose2d(64,1,kernel_size=4,stride=2,padding=1),
        nn.Tanh(),
        Flatten()

        # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    )



In [94]:
noise = NOISE_DIM

In [95]:
test_g_gan = build_dc_generator(noise).type(dtype)

In [96]:
test_g_gan

Sequential(
  (0): Linear(in_features=96, out_features=1024, bias=True)
  (1): ReLU()
  (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Linear(in_features=1024, out_features=46208, bias=True)
  (4): ReLU()
  (5): BatchNorm1d(46208, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (6): Unflatten()
  (7): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (8): ReLU()
  (9): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (10): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (11): Tanh()
  (12): Flatten()
)

In [97]:
test_g_gan.apply(initialize_weights)

Sequential(
  (0): Linear(in_features=96, out_features=1024, bias=True)
  (1): ReLU()
  (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Linear(in_features=1024, out_features=46208, bias=True)
  (4): ReLU()
  (5): BatchNorm1d(46208, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (6): Unflatten()
  (7): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (8): ReLU()
  (9): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (10): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (11): Tanh()
  (12): Flatten()
)

In [91]:
get_out_channels(in_channels=386,kernal_size=4,padding=1,stride = 2,max_pool=False)

193

In [89]:
get_out_channels(in_channels=48,kernal_size=4,padding=1,stride = 2,max_pool=False)

24

In [34]:
96*1024

98304

In [35]:
128*784

100352

In [49]:
28 * 28

784

In [92]:
import math 
math.sqrt(386)

19.6468827043885

In [90]:
193*2

386

In [98]:
fake_seed = torch.randn(batch_size, noise).type(dtype)

In [99]:
fake_seed.shape

torch.Size([128, 96])

In [100]:
fake_images = test_g_gan.forward(fake_seed)

In [101]:
fake_images.size()

torch.Size([128, 5776])

In [103]:
193*193

37249

In [102]:
128*5776

739328

In [86]:
193*193

37249

In [None]:
torch.randn(batch_size, 128).type(dtype)

In [None]:
test_g_gan = build_dc_generator().type(dtype)
test_g_gan.apply(initialize_weights)

fake_seed = torch.randn(batch_size, NOISE_DIM).type(dtype)
fake_images = test_g_gan.forward(fake_seed)
fake_images.size()

In [29]:
D_DC = build_dc_classifier().type(dtype) 
D_DC.apply(initialize_weights)
G_DC = build_dc_generator().type(dtype)
G_DC.apply(initialize_weights)


Sequential(
  (0): Linear(in_features=96, out_features=1024, bias=True)
  (1): ReLU()
  (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Linear(in_features=1024, out_features=6272, bias=True)
  (4): ReLU()
  (5): BatchNorm1d(6272, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (6): Unflatten()
  (7): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (8): ReLU()
  (9): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (10): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (11): Tanh()
  (12): Flatten()
)

In [30]:
D_DC_solver = get_optimizer(D_DC)
G_DC_solver = get_optimizer(G_DC)

In [67]:
run_a_gan(D_DC, G_DC, D_DC_solver, G_DC_solver, discriminator_loss, generator_loss, num_epochs=5)

RuntimeError: shape '[128, 1, 193, 193]' is invalid for input of size 100352