In [1]:
from skimage import io, transform, color
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torchvision.datasets as dset
import torchvision.transforms as T
from dataLoader import OrganoidDataset
from torch.utils import data
import numpy as np
import sys
import pandas as pd
from imageio import imread
from PIL import Image
import os
import math
import torchvision.models as models

from dataLoader import OrganoidDataset
#from conv_model import SimpleConvNet
import matplotlib.pyplot as plt
import copy

%matplotlib inline
plt.rcParams['figure.figsize'] = (20.0, 10.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

In [3]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

params = {'batch_size': 100, # low for testing
          'shuffle': True, 'num_workers' : 2}
max_epochs = 100

figure_path = '../milestoneReport/figures/'
#path = '../data/CS231n_Tim_Shan_example_data/'
path = '../data/'
label_path = '../data/well_summary_A1_e0891BSA_all.csv'

## load data 

In [49]:
class OrganoidMultipleDataset(data.Dataset):
    'dataset class for microwell organoid images'
    def __init__(self, path2files, image_names, Y, mean_sd_dict, transforms=None):
        for k, image_name in image_names.items():
            assert len(image_name) == len(Y)
        self.path = path2files
        self.image_names = image_names
        self.Y = Y
        self.mean_sd_dict = mean_sd_dict
        self.transforms = transforms
    def __len__(self):
        return len(self.Y)
    def getXimage(self, index):
        all_images_list = []
        for day,img_names in self.image_names.items():
            print(day, "   ", index)
            
            img_name = img_names[index]
            img_loc = os.path.join(self.path, img_name)
            image = io.imread(img_loc)
            mean, sd = self.mean_sd_dict[day]
            image = np.true_divide(color.rgb2gray(image) - mean, sd)
            all_images_list.append(image)
        images = np.array(all_images_list)
        return torch.from_numpy(images).float()
    def getY(self, index):
        Y = self.Y[index]
        return torch.from_numpy(np.asarray(self.Y[index], dtype=float)).float()
    def __getitem__(self, index):
        X = self.getXimage(index)
        y = self.getY(index)
        if self.transforms is not None:
            X = self.transforms(X)
        return X, y

In [30]:
training_labels = pd.read_csv('../data_description/A1_A2_C1_filtered_train_v2.csv')
validation_labels = pd.read_csv('../data_description/A1_A2_C1_filtered_validation_v2.csv')
test_labels = pd.read_csv('../data_description/A1_A2_C1_filtered_test_v2.csv')

#### filter for predict size

In [31]:
training_labels = training_labels.query('has_cell_13 == 1')
validation_labels = validation_labels.query('has_cell_13 == 1')
test_labels = test_labels.query('has_cell_13 == 1')

In [32]:
training_labels.shape,validation_labels.shape,test_labels.shape

((3952, 60), (476, 60), (478, 60))

In [52]:
training_labels.head(2)

Unnamed: 0.1,Unnamed: 0,condition,well_id,day_0,well_label,image_name_0,has_cell_0,hyst2_area_0,day_1,image_name_1,...,has_cell_11,hyst2_area_11,day_12,image_name_12,has_cell_12,hyst2_area_12,day_13,image_name_13,has_cell_13,hyst2_area_13
0,0,A1,64,0,64,well_A1/well0064_day00_well.png,0,0,1,well_A1/well0064_day01_well.png,...,0,0,12,well_A1/well0064_day12_well.png,0,0,13,well_A1/well0064_day13_well.png,1,172
2,2,A2,977,0,977,well_A2/well0977_day00_well.png,1,2568,1,well_A2/well0977_day01_well.png,...,1,3227,12,well_A2/well0977_day12_well.png,1,3063,13,well_A2/well0977_day13_well.png,1,3194


In [33]:
training_image_names = {2:training_labels['image_name_2'],8:training_labels['image_name_8'], 5:training_labels['image_name_5']}
validation_image_names = {2:validation_labels['image_name_2'],8:validation_labels['image_name_8'],5:validation_labels['image_name_5']}

In [35]:
training_y = training_labels['hyst2_area_13']
validation_y = validation_labels['hyst2_area_13']
#test_y = test_labels['hyst2_area_13']

In [36]:
mean_sd_dict = {2: [0.49439774802337344, 0.16087996922691195],
 8: [0.5177020917650417, 0.15714445907773483],
 5: [0.5013496452715945, 0.1605951051365687],              }

In [37]:
train_set = OrganoidMultipleDataset(path2files = path, image_names = training_image_names, Y = training_labels['has_cell_13'],mean_sd_dict=mean_sd_dict)
validation_set = OrganoidMultipleDataset(path2files = path, image_names = validation_image_names, Y = validation_labels['has_cell_13'],mean_sd_dict=mean_sd_dict)
#test_set = OrganoidMultipleDataset(path2files = path, image_names = test_image_names, Y = test_labels['has_cell_13'],mean_sd_dict=mean_sd_dict)



In [38]:
training_generator = data.DataLoader(train_set, **params)
validation_generator = data.DataLoader(validation_set, **params)
#test_generator = data.DataLoader(test_set, **params)

### Model

In [17]:
USE_GPU = True

dtype = torch.float32 # we will be using float throughout this tutorial

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Constant to control how frequently we print train loss
print_every = 100

print('using device:', device)

using device: cuda


In [18]:
def flatten(x):
    N = x.shape[0] # read in N, C, H, W
    return x.view(N, -1)  # "flatten" the C * H * W values into a single vector per image

def test_flatten():
    x = torch.arange(12).view(2, 1, 3, 2)
    print('Before flattening: ', x)
    print('After flattening: ', flatten(x))

test_flatten()
class Flatten(nn.Module):
    def forward(self, x):
        return flatten(x)

Before flattening:  tensor([[[[ 0,  1],
          [ 2,  3],
          [ 4,  5]]],


        [[[ 6,  7],
          [ 8,  9],
          [10, 11]]]])
After flattening:  tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11]])


In [39]:
def check_accuracy_part34(loader, model,dataset='validation'):
    num_correct = 0
    num_samples = 0
    model.eval()  # set model to evaluation mode
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)
            scores = model(x)
            _, preds = scores.max(1)
            val_loss = F.mse_loss(preds, y)
            val_error = val_loss.item()
            losses.append(val_error)
            totalbatchMSE = totalbatchMSE + params['batch_size']*val_error/NUM_VAL
        print('Got accuracy (%.2f)' % (100 * val_error))
        #print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))
    return totalbatchMSE,preds

In [40]:
def train_part34(model, optimizer, epochs=1,lr_scheduler=None):
    best_model = None
    best_validation_accuracy = -1
    
    losses = []
    validation_accuracy = []
    training_accuracy = []
    prediction = []
    model = model.to(device=device)  # move the model parameters to CPU/GPU

        
    for e in range(epochs):
        print('epoch:',e)
        
        if lr_scheduler is not None:
            lr_scheduler.step()
            print('LR:', lr_scheduler.get_lr())

        for t, (x, y) in enumerate(training_generator):
            model.train()  # put model to training mode
            x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)

            Y_hat = model(x)
            #loss = nn.MSELoss(Y_hat,y)
            loss = F.mse_loss(Y_hat, y)
            
            # Zero out all of the gradients for the variables which the optimizer
            # will update.
            optimizer.zero_grad()

            # This is the backwards pass: compute the gradient of the loss with
            # respect to each  parameter of the model.
            loss.backward()

            # Actually update the parameters of the model using the gradients
            # computed by the backwards pass.
            optimizer.step()
            
            train_error = loss.item()
            
            totalbatchMSE = totalbatchMSE + params['batch_size']*train_error/NUM_TRAIN
            
            #losses.append(train_error)
            losses[e] =train_error
            epoch_error[e] = totalbatchMSE

            if t % print_every == 0:

                print('Iteration %d, loss = %.4f' % (t, l))
                train_acc, _ = check_accuracy_part34(training_generator, model, dataset='training')
                training_accuracy.append(train_acc)
                validation_acc, preds = check_accuracy_part34(validation_generator, model,dataset='validation')
                validation_accuracy.append(validation_acc)
                prediction.append(preds)
                
                if validation_acc > best_validation_accuracy:
                    best_validation_accuracy = validation_acc
                    best_model = copy.deepcopy(model)


                print()

        
                        
#     checkpoint = {'model': best_model,
#                 'state_dict': model.state_dict()
#                }

#     torch.save(checkpoint, 'classification_checkpoint.pth')
    #torch.save(best_model.state_dict(), 'classification_model.pth')
    return losses, validation_accuracy, training_accuracy, prediction,best_model

## original model

In [50]:
in_channels = 2 
channel_1 = 32
channel_2 = 16
channel_3 = 8
out_size = 1
image_size = int(193/2/2)
image_size = 193
model = nn.Sequential(
    nn.Conv2d(in_channels=in_channels,out_channels=32,kernel_size=5,padding=2,bias=True),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(in_channels=32,out_channels=64,kernel_size=3,padding=1,bias=True),
    nn.ReLU(), 
    nn.MaxPool2d(2),
    nn.Conv2d(in_channels=64,out_channels=128,kernel_size=3,padding=1,bias=True),
    nn.ReLU(),  
    nn.MaxPool2d(2),
    nn.Conv2d(in_channels=128,out_channels=256,kernel_size=3,padding=1,bias=True),
    nn.ReLU(),  
    nn.MaxPool2d(2),
    nn.Conv2d(in_channels=256,out_channels=256,kernel_size=3,padding=1,bias=True),
    nn.ReLU(),  
    nn.MaxPool2d(2),
    
    nn.Dropout(),
    Flatten(),
    nn.Linear(256*6*6, out_size),
)

## pretrained Resnet 

In [22]:
resnet18 = models.resnet18()
model = models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 1)

## Train models

In [41]:
model.cuda()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Co

In [51]:
print_every = 100
epoches = 10
learning_rate = 1e-4

optimizer =optim.Adam(model.parameters(), lr=learning_rate)

losses,validation_accuracy, training_accuracy,prediction,best_model = train_part34(model, optimizer,epochs=epoches)

epoch: 0


Exception ignored in: <function _DataLoaderIter.__del__ at 0x7f342366f7b8>
Traceback (most recent call last):
  File "/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 677, in __del__
    self._shutdown_workers()
  File "/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 659, in _shutdown_workers
    w.join()
  File "/opt/anaconda3/lib/python3.7/multiprocessing/process.py", line 138, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <function _DataLoaderIter.__del__ at 0x7f342366f7b8>
Traceback (most recent call last):
  File "/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 677, in __del__
    self._shutdown_workers()
  File "/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 659, in _shutdown_workers
    w.join()
  File "/opt/anaconda3/lib/python3.7/multiprocessi

Exception: KeyError:Traceback (most recent call last):
  File "/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 99, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 99, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "<ipython-input-29-042cd6044c57>", line 30, in __getitem__
    X = self.getXimage(index)
  File "<ipython-input-29-042cd6044c57>", line 18, in getXimage
    img_name = img_names[index]
  File "/opt/anaconda3/lib/python3.7/site-packages/pandas/core/series.py", line 868, in __getitem__
    result = self.index.get_value(self, key)
  File "/opt/anaconda3/lib/python3.7/site-packages/pandas/core/indexes/base.py", line 4375, in get_value
    tz=getattr(series.dtype, 'tz', None))
  File "pandas/_libs/index.pyx", line 81, in pandas._libs.index.IndexEngine.get_value
  File "pandas/_libs/index.pyx", line 89, in pandas._libs.index.IndexEngine.get_value
  File "pandas/_libs/index.pyx", line 132, in pandas._libs.index.IndexEngine.get_loc
  File "pandas/_libs/hashtable_class_helper.pxi", line 987, in pandas._libs.hashtable.Int64HashTable.get_item
  File "pandas/_libs/hashtable_class_helper.pxi", line 993, in pandas._libs.hashtable.Int64HashTable.get_item
KeyError: 2340
