In [1]:
import torch
import os
import imageio
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from scipy.ndimage import binary_erosion
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
from torch.nn import functional as F
from torchvision import transforms
import pandas as pd
from torchvision import datasets
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data.sampler import WeightedRandomSampler


In [2]:
#images = "/home/bandyadkas/cellcyle/data/"
images = "/mnt/efs/woods_hole/bbbc_cellcycle/model_data/data/"

In [3]:
transform = {
        'train': transforms.Compose([
            #transforms.Resize([224,224]), # Resizing the image as the VGG only take 224 x 244 as input size
            #transforms.RandomHorizontalFlip(), # Flip the data horizontally
            #TODO if it is needed, add the random crop
            transforms.ToTensor(),
            #transforms.Normalize(mean=(0.5,0.5,0.5), std=(0.5,0.5,0.5))
        ]),
        'test': transforms.Compose([
            #transforms.Resize([224,224]),
            #transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            #transforms.Normalize(mean=(0), std=(1))
        ])
    }

In [4]:
all_images = datasets.ImageFolder(images,transform=transform['train'])
print(len(all_images))
#all_images = torchvision.datasets.DatasetFolder(
#    "/mnt/efs/woods_hole/bbbc_cellcycle/CellCycle/",
#    imageio.imread, ("merged.jpg"), transform=transform)
#print(len(all_images))

32266


In [5]:
train_size = int(0.7 * len(all_images))
val_size = int(0.15 * len(all_images))
test_size = len(all_images) - (train_size + val_size)
print(train_size, val_size, test_size)
assert train_size + val_size + test_size == len(all_images)

22586 4839 4841


In [6]:
train_set, val_set, test_set = torch.utils.data.random_split(all_images, [train_size, val_size, test_size])
#train_set, val_set, test_set = torch.utils.data.random_split(all_images, [22538, 4829, 4831])

In [7]:
def _get_weights(subset,full_dataset):
    ys = np.array([y for _, y in subset])
    counts = np.bincount(ys)
    label_weights = 1.0 / counts
    weights = label_weights[ys]

    print("Number of images per class:")
    for c, n, w in zip(full_dataset.classes, counts, label_weights):
        print(f"\t{c}:\tn={n}\tweight={w}")
        
    return weights


In [8]:

train_weights = _get_weights(train_set,all_images)
train_sampler = WeightedRandomSampler(train_weights, len(train_weights))


Number of images per class:
	Anaphase:	n=11	weight=0.09090909090909091
	G1:	n=9991	weight=0.00010009008107296567
	G2:	n=5991	weight=0.00016691704223001168
	Metaphase:	n=54	weight=0.018518518518518517
	Prophase:	n=424	weight=0.0023584905660377358
	S:	n=6097	weight=0.0001640150893882237
	Telophase:	n=18	weight=0.05555555555555555


In [9]:
train_loader = DataLoader(train_set, batch_size=8, drop_last=True, sampler=train_sampler)
val_loader = DataLoader(val_set, batch_size=8 , drop_last=True, shuffle=True)
test_loader = DataLoader(test_set, batch_size=8, drop_last=True, shuffle=True)

In [10]:
## Set up tensorboard
writer = SummaryWriter('/mnt/efs/woods_hole/bbbc_cellcycle/classify_cellCycle_bandyadka/runs/cellcycle_resnet50_discrete')

In [None]:
#!tensorboard --logdir=runs

In [11]:
resnet50_model = torchvision.models.resnet50(pretrained = False, progress  = True, num_classes=7)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(resnet50_model.parameters(), lr=0.001)

In [12]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
resnet50_model.to(device)
print(f"Will use device {device} for training")

Will use device cuda for training


In [13]:
from tqdm import tqdm

def train(model,loss,train_dataloader):
    model.train()
    epoch_loss = 0
    num_batches = 0
    for x, y in tqdm(train_dataloader):

        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()

        y_pred = model(x)
        l = loss(y_pred, y)
        l.backward()
        optimizer.step()

        epoch_loss += l
        num_batches += 1

    return epoch_loss/num_batches

def evaluate(model, loss, dataloader):
    
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in tqdm(dataloader):

            x, y = x.to(device), y.to(device)

            logits = model(x)
            val_loss = loss(logits,y)
            
            probs = torch.nn.Softmax(dim=1)(logits)
            predictions = torch.argmax(probs, dim=1)

            correct += int(torch.sum(predictions == y).cpu().detach().numpy())
            total += len(y)

        accuracy = correct/total

    return accuracy, val_loss

def validate(model,loss, validation_dataloader):
    '''Evaluate prediction accuracy on the validation dataset.'''
    
    model.eval()
    return evaluate(model,loss,validation_dataloader)

def test(model,loss,test_dataloader):
    '''Evaluate prediction accuracy on the test dataset.'''
    
    model.eval() 
    return evaluate(model, loss,test_dataloader)

In [None]:
step = 0
epochs = 100
for epoch in range(epochs+1):
    while step < epoch: 
    
        epoch_loss = train(resnet50_model,loss_fn,train_loader)
        print(f"epoch {epoch}, training loss={epoch_loss}")
    
        validation_accuracy, validation_loss = validate(resnet50_model, loss_fn,val_loader)
        print(f"epoch {epoch}, validation accuracy={validation_accuracy}")
    
        writer.add_scalar('Loss/train', epoch_loss.cpu().detach().numpy(),step)      
        writer.add_scalar('Accuracy/validation', validation_accuracy,step)
        writer.add_scalar('Loss/validation', validation_loss.cpu().detach().numpy(),step)
        
        if step == 100:
            state = {
                'epoch': epoch,
                'state_dict': resnet50_model.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(state, "/mnt/efs/woods_hole/bbbc_cellcycle/classify_cellCycle_bandyadka/modelsave.pth")
        
        step += 1

In [14]:
load_model = torch.load("/mnt/efs/woods_hole/bbbc_cellcycle/classify_cellCycle_bandyadka/modelsave.pth")

In [16]:
test_accuracy, test_loss = test(resnet50_model,loss_fn,test_loader)
print(f"final test accuracy: {test_accuracy}")
writer.add_scalar('Accuracy/test', test_accuracy)
#writer.add_scalar('Loss/test', test_loss.cpu().detach().numpy(),step)
        


100%|███████████████████████████████████████████████████████████████████████████████████| 605/605 [00:23<00:00, 25.70it/s]

final test accuracy: 0.2721074380165289





In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sns
import numpy as np

# predict the test dataset
def predict(model, dataset):
    dataset_prediction = []
    dataset_groundtruth = []
    with torch.no_grad():
        for x, y_true in dataset:
            inp = x[None].cuda()
            y_pred = model(inp)
            dataset_prediction.append(y_pred.argmax().cpu().numpy())
            dataset_groundtruth.append(y_true)
    
    return np.array(dataset_prediction), np.array(dataset_groundtruth)
            
    # create seabvorn heatmap with required labels
    #sns.heatmap(flights_df, xticklabels=x_axis_labels, yticklabels=y_axis_labels)
    ax=sns.heatmap(cm, annot=annot, fmt='', vmax=30, xticklabels=x_axis_labels, yticklabels=y_axis_labels)
    ax.set_title(title)



In [None]:
import pandas as pd
# Plot confusion matrix 
# orginally from Runqi Yang; 
# see https://gist.github.com/hitvoice/36cf44689065ca9b927431546381a3f7
def cm_analysis(y_true, y_pred, title, figsize=(10,10)):
    """
    Generate matrix plot of confusion matrix with pretty annotations.
    The plot image is saved to disk.
    args: 
      y_true:    true label of the data, with shape (nsamples,)
      y_pred:    prediction of the data, with shape (nsamples,)
      filename:  filename of figure file to save
      labels:    string array, name the order of class labels in the confusion matrix.
                 use `clf.classes_` if using scikit-learn models.
                 with shape (nclass,).
      ymap:      dict: any -> string, length == nclass.
                 if not None, map the labels & ys to more understandable strings.
                 Caution: original y_true, y_pred and labels must align.
      figsize:   the size of the figure plotted.
    """
    labels = ["0", "1", "2", "3", "4", "5", "6"]
    cm = confusion_matrix(y_true, y_pred)
    cm_sum = np.sum(cm, axis=1, keepdims=True)
    cm_perc = cm / cm_sum.astype(float) * 100
    annot = np.empty_like(cm).astype(str)
    nrows, ncols = cm.shape
    for i in range(nrows):
        for j in range(ncols):
            c = cm[i, j]
            p = cm_perc[i, j]
            if i == j:
                s = cm_sum[i]
                annot[i, j] = '%.1f%%\n%d/%d' % (p, c, s)
            elif c == 0:
                annot[i, j] = ''
            else:
                annot[i, j] = '%.1f%%\n%d' % (p, c)
    cm = pd.DataFrame(cm, index=labels, columns=labels)
    cm.index.name = 'Actual'
    cm.columns.name = 'Predicted'
    fig, ax = plt.subplots(figsize=figsize)
    x_axis_labels = ['Anaphase', 'G1', 'G2', 'Metaphase', 'Prophase', 'S', 'Telophase'] # labels for x-axis
    y_axis_labels = ['Anaphase', 'G1', 'G2', 'Metaphase', 'Prophase', 'S', 'Telophase'] # labels for y-axis


In [None]:
y_pred, y_true = predict(resnet50_model, test_set)
cm_analysis(y_true, y_pred, "Confusion matrix")