In [2]:
import torch
import os
import imageio
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from scipy.ndimage import binary_erosion
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
from torch.nn import functional as F
from torchvision import transforms
import pandas as pd
from torchvision import datasets
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data.sampler import WeightedRandomSampler


In [3]:
#images = "/home/bandyadkas/cellcyle/data/"
images = "/mnt/efs/woods_hole/bbbc_cellcycle/model_data/data/"

In [4]:
transform = {
        'train': transforms.Compose([
            transforms.Resize([225,225]), # Resizing the image as the VGG only take 224 x 244 as input size
            #transforms.RandomHorizontalFlip(), # Flip the data horizontally
            #TODO if it is needed, add the random crop
            transforms.ToTensor(),
            #transforms.Normalize(mean=(0.5,0.5,0.5), std=(0.5,0.5,0.5))
        ]),
        'test': transforms.Compose([
            transforms.Resize([225,225]),
            #transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            #transforms.Normalize(mean=(0), std=(1))
        ])
    }

In [5]:
all_images = datasets.ImageFolder(images,transform=transform['train'])
print(len(all_images))

32266


In [6]:
train_size = int(0.7 * len(all_images))
val_size = int(0.15 * len(all_images))
test_size = len(all_images) - (train_size + val_size)
print(train_size, val_size, test_size)
assert train_size + val_size + test_size == len(all_images)

22586 4839 4841


In [7]:
train_set, val_set, test_set = torch.utils.data.random_split(all_images, [train_size, val_size, test_size])
#train_set, val_set, test_set = torch.utils.data.random_split(all_images, [22538, 4829, 4831])

In [8]:
def _get_weights(subset,full_dataset):
    ys = np.array([y for _, y in subset])
    counts = np.bincount(ys)
    label_weights = 1.0 / counts
    weights = label_weights[ys]

    print("Number of images per class:")
    for c, n, w in zip(full_dataset.classes, counts, label_weights):
        print(f"\t{c}:\tn={n}\tweight={w}")
        
    return weights


In [9]:

train_weights = _get_weights(train_set,all_images)
train_sampler = WeightedRandomSampler(train_weights, len(train_weights))


Number of images per class:
	Anaphase:	n=13	weight=0.07692307692307693
	G1:	n=10027	weight=9.97307270370001e-05
	G2:	n=5962	weight=0.0001677289500167729
	Metaphase:	n=42	weight=0.023809523809523808
	Prophase:	n=431	weight=0.002320185614849188
	S:	n=6095	weight=0.00016406890894175554
	Telophase:	n=16	weight=0.0625


In [10]:
train_loader = DataLoader(train_set, batch_size=8, drop_last=True, sampler=train_sampler)
val_loader = DataLoader(val_set, batch_size=8 , drop_last=True, shuffle=True)
test_loader = DataLoader(test_set, batch_size=8, drop_last=True, shuffle=True)

In [11]:
## Set up tensorboard
writer = SummaryWriter('/home/sivagurunathans/runs/cellcycle_densenet_discrete')

In [13]:
import torchvision.models as models

In [14]:
densenet = models.densenet161(pretrained=True)

In [15]:
densenet_model = torchvision.models.densenet161(pretrained = False, progress  = True, num_classes=7)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(densenet_model.parameters(), lr=0.001)

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
densenet_model.to(device)
print(f"Will use device {device} for training")

Will use device cuda for training


In [17]:
from tqdm import tqdm

def train(model,loss,train_dataloader):
    model.train()
    epoch_loss = 0
    num_batches = 0
    for x, y in tqdm(train_dataloader):

        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()

        y_pred = model(x)
        l = loss(y_pred, y)
        l.backward()
        optimizer.step()

        epoch_loss += l
        num_batches += 1

    return epoch_loss/num_batches

def evaluate(model, loss, dataloader):
    
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in tqdm(dataloader):

            x, y = x.to(device), y.to(device)

            logits = model(x)
            val_loss = loss(logits,y)
            
            probs = torch.nn.Softmax(dim=1)(logits)
            predictions = torch.argmax(probs, dim=1)

            correct += int(torch.sum(predictions == y).cpu().detach().numpy())
            total += len(y)

        accuracy = correct/total

    return accuracy, val_loss

def validate(model,loss, validation_dataloader):
    '''Evaluate prediction accuracy on the validation dataset.'''
    
    model.eval()
    return evaluate(model,loss,validation_dataloader)

def test(model,loss,test_dataloader):
    '''Evaluate prediction accuracy on the test dataset.'''
    
    model.eval() 
    return evaluate(model, loss,test_dataloader)

In [18]:
step = 0
epochs = 50
for epoch in range(epochs+1):
    while step < epoch: 
    
        epoch_loss = train(densenet_model,loss_fn,train_loader)
        print(f"epoch {epoch}, training loss={epoch_loss}")
    
        validation_accuracy, validation_loss = validate(densenet_model, loss_fn,val_loader)
        print(f"epoch {epoch}, validation accuracy={validation_accuracy}")
    
        writer.add_scalar('Loss/train', epoch_loss.cpu().detach().numpy(),step)      
        writer.add_scalar('Accuracy/validation', validation_accuracy,step)
        writer.add_scalar('Loss/validation', validation_loss.cpu().detach().numpy(),step)
        
        if step == 150:
            state = {
                'epoch': epoch,
                'state_dict': densenet_model.state_dict(),
                'optimizer': optimizer.state_dict()
            }
        torch.save(state, "/home/sivagurunathans/runs/savedmodel.pth")
        step += 1

100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 2823/2823 [11:24<00:00,  4.12it/s]


epoch 1, training loss=1.0217254161834717


100%|█████████████████████████████████████████| 604/604 [00:49<00:00, 12.20it/s]


epoch 1, validation accuracy=0.6268625827814569


100%|███████████████████████████████████████| 2823/2823 [11:19<00:00,  4.15it/s]


epoch 2, training loss=0.5667539834976196


100%|█████████████████████████████████████████| 604/604 [00:52<00:00, 11.45it/s]


epoch 2, validation accuracy=0.65625


100%|███████████████████████████████████████| 2823/2823 [11:29<00:00,  4.09it/s]


epoch 3, training loss=0.44884365797042847


100%|█████████████████████████████████████████| 604/604 [00:53<00:00, 11.37it/s]


epoch 3, validation accuracy=0.7162665562913907


100%|███████████████████████████████████████| 2823/2823 [11:37<00:00,  4.05it/s]


epoch 4, training loss=0.3935126066207886


100%|█████████████████████████████████████████| 604/604 [00:53<00:00, 11.36it/s]


epoch 4, validation accuracy=0.7131622516556292


100%|███████████████████████████████████████| 2823/2823 [11:31<00:00,  4.08it/s]


epoch 5, training loss=0.3538416624069214


100%|█████████████████████████████████████████| 604/604 [00:52<00:00, 11.61it/s]


epoch 5, validation accuracy=0.6692880794701986


100%|███████████████████████████████████████| 2823/2823 [11:25<00:00,  4.12it/s]


epoch 6, training loss=0.32331258058547974


100%|█████████████████████████████████████████| 604/604 [00:51<00:00, 11.68it/s]


epoch 6, validation accuracy=0.7446192052980133


100%|███████████████████████████████████████| 2823/2823 [11:27<00:00,  4.11it/s]


epoch 7, training loss=0.2866142690181732


100%|█████████████████████████████████████████| 604/604 [00:52<00:00, 11.57it/s]


epoch 7, validation accuracy=0.7437913907284768


100%|███████████████████████████████████████| 2823/2823 [11:27<00:00,  4.11it/s]


epoch 8, training loss=0.278890997171402


100%|█████████████████████████████████████████| 604/604 [00:50<00:00, 11.85it/s]


epoch 8, validation accuracy=0.7535182119205298


100%|███████████████████████████████████████| 2823/2823 [11:28<00:00,  4.10it/s]


epoch 9, training loss=0.26624250411987305


100%|█████████████████████████████████████████| 604/604 [00:51<00:00, 11.68it/s]


epoch 9, validation accuracy=0.7638658940397351


100%|███████████████████████████████████████| 2823/2823 [11:24<00:00,  4.12it/s]


epoch 10, training loss=0.2501029372215271


100%|█████████████████████████████████████████| 604/604 [00:51<00:00, 11.73it/s]


epoch 10, validation accuracy=0.765728476821192


100%|███████████████████████████████████████| 2823/2823 [11:26<00:00,  4.11it/s]


epoch 11, training loss=0.24601489305496216


100%|█████████████████████████████████████████| 604/604 [00:51<00:00, 11.67it/s]


epoch 11, validation accuracy=0.7620033112582781


100%|███████████████████████████████████████| 2823/2823 [11:24<00:00,  4.12it/s]


epoch 12, training loss=0.23546931147575378


100%|█████████████████████████████████████████| 604/604 [00:50<00:00, 11.88it/s]


epoch 12, validation accuracy=0.7688327814569537


100%|███████████████████████████████████████| 2823/2823 [11:26<00:00,  4.11it/s]


epoch 13, training loss=0.23435111343860626


100%|█████████████████████████████████████████| 604/604 [00:50<00:00, 11.91it/s]


epoch 13, validation accuracy=0.7456539735099338


100%|███████████████████████████████████████| 2823/2823 [11:24<00:00,  4.13it/s]


epoch 14, training loss=0.22596365213394165


100%|█████████████████████████████████████████| 604/604 [00:51<00:00, 11.71it/s]


epoch 14, validation accuracy=0.7359271523178808


100%|███████████████████████████████████████| 2823/2823 [11:31<00:00,  4.08it/s]


epoch 15, training loss=0.22166258096694946


100%|█████████████████████████████████████████| 604/604 [00:51<00:00, 11.73it/s]


epoch 15, validation accuracy=0.7680049668874173


100%|███████████████████████████████████████| 2823/2823 [11:28<00:00,  4.10it/s]


epoch 16, training loss=0.2082272171974182


100%|█████████████████████████████████████████| 604/604 [00:50<00:00, 11.92it/s]


epoch 16, validation accuracy=0.7638658940397351


100%|███████████████████████████████████████| 2823/2823 [11:29<00:00,  4.09it/s]


epoch 17, training loss=0.20511971414089203


100%|█████████████████████████████████████████| 604/604 [00:51<00:00, 11.77it/s]


epoch 17, validation accuracy=0.7603476821192053


100%|███████████████████████████████████████| 2823/2823 [11:25<00:00,  4.12it/s]


epoch 18, training loss=0.19594909250736237


100%|█████████████████████████████████████████| 604/604 [00:49<00:00, 12.14it/s]


epoch 18, validation accuracy=0.7483443708609272


100%|███████████████████████████████████████| 2823/2823 [11:22<00:00,  4.14it/s]


epoch 19, training loss=0.1909351497888565


100%|█████████████████████████████████████████| 604/604 [00:50<00:00, 11.92it/s]


epoch 19, validation accuracy=0.7766970198675497


100%|███████████████████████████████████████| 2823/2823 [11:20<00:00,  4.15it/s]


epoch 20, training loss=0.1873047798871994


100%|█████████████████████████████████████████| 604/604 [00:49<00:00, 12.24it/s]


epoch 20, validation accuracy=0.7545529801324503


100%|███████████████████████████████████████| 2823/2823 [11:20<00:00,  4.15it/s]


epoch 21, training loss=0.1842007339000702


100%|█████████████████████████████████████████| 604/604 [00:50<00:00, 11.85it/s]


epoch 21, validation accuracy=0.7566225165562914


100%|███████████████████████████████████████| 2823/2823 [11:21<00:00,  4.14it/s]


epoch 22, training loss=0.1684114784002304


100%|█████████████████████████████████████████| 604/604 [00:50<00:00, 11.95it/s]


epoch 22, validation accuracy=0.7622102649006622


100%|███████████████████████████████████████| 2823/2823 [11:23<00:00,  4.13it/s]


epoch 23, training loss=0.1638963222503662


100%|█████████████████████████████████████████| 604/604 [00:51<00:00, 11.72it/s]


epoch 23, validation accuracy=0.7392384105960265


100%|███████████████████████████████████████| 2823/2823 [11:19<00:00,  4.15it/s]


epoch 24, training loss=0.153244286775589


100%|█████████████████████████████████████████| 604/604 [00:50<00:00, 11.85it/s]


epoch 24, validation accuracy=0.7570364238410596


100%|███████████████████████████████████████| 2823/2823 [11:20<00:00,  4.15it/s]


epoch 25, training loss=0.14408326148986816


100%|█████████████████████████████████████████| 604/604 [00:51<00:00, 11.84it/s]


epoch 25, validation accuracy=0.7518625827814569


100%|███████████████████████████████████████| 2823/2823 [11:25<00:00,  4.12it/s]


epoch 26, training loss=0.13937629759311676


100%|█████████████████████████████████████████| 604/604 [00:50<00:00, 11.91it/s]


epoch 26, validation accuracy=0.7429635761589404


100%|███████████████████████████████████████| 2823/2823 [11:24<00:00,  4.13it/s]


epoch 27, training loss=0.13990986347198486


100%|█████████████████████████████████████████| 604/604 [00:50<00:00, 11.95it/s]


epoch 27, validation accuracy=0.7475165562913907


100%|███████████████████████████████████████| 2823/2823 [11:25<00:00,  4.12it/s]


epoch 28, training loss=0.1255894899368286


100%|█████████████████████████████████████████| 604/604 [00:51<00:00, 11.69it/s]


epoch 28, validation accuracy=0.7541390728476821


 82%|███████████████████████████████▊       | 2304/2823 [09:24<02:07,  4.08it/s]


KeyboardInterrupt: 

In [20]:
state = {
                'epoch': epoch,
                'state_dict': densenet_model.state_dict(),
                'optimizer': optimizer.state_dict()
            }
torch.save(state, "/home/sivagurunathans/runs/savedmodel.pth")

In [None]:
load_model = torch.load("/home/sivagurunathans/runs/savedmodel.pth")

In [None]:
test_accuracy, test_loss = test(densenet_model,loss_fn,test_loader)
print(f"final test accuracy: {test_accuracy}")
writer.add_scalar('Accuracy/test', test_accuracy)
writer.add_scalar('Loss/test', test_loss.cpu().detach().numpy(),step)
        


In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sns
import numpy as np

# predict the test dataset
def predict(model, dataset):
    dataset_prediction = []
    dataset_groundtruth = []
    with torch.no_grad():
        for x, y_true in dataset:
            inp = x[None].cuda()
            y_pred = model(inp)
            dataset_prediction.append(y_pred.argmax().cpu().numpy())
            dataset_groundtruth.append(y_true)
    
    return np.array(dataset_prediction), np.array(dataset_groundtruth)
            
    # create seabvorn heatmap with required labels
    #sns.heatmap(flights_df, xticklabels=x_axis_labels, yticklabels=y_axis_labels)
    ax=sns.heatmap(cm, annot=annot, fmt='', vmax=30, xticklabels=x_axis_labels, yticklabels=y_axis_labels)
    ax.set_title(title)



In [None]:
import pandas as pd
# Plot confusion matrix 
# orginally from Runqi Yang; 
# see https://gist.github.com/hitvoice/36cf44689065ca9b927431546381a3f7
def cm_analysis(y_true, y_pred, title, figsize=(10,10)):
    """
    Generate matrix plot of confusion matrix with pretty annotations.
    The plot image is saved to disk.
    args: 
      y_true:    true label of the data, with shape (nsamples,)
      y_pred:    prediction of the data, with shape (nsamples,)
      filename:  filename of figure file to save
      labels:    string array, name the order of class labels in the confusion matrix.
                 use `clf.classes_` if using scikit-learn models.
                 with shape (nclass,).
      ymap:      dict: any -> string, length == nclass.
                 if not None, map the labels & ys to more understandable strings.
                 Caution: original y_true, y_pred and labels must align.
      figsize:   the size of the figure plotted.
    """
    labels = ["0", "1", "2", "3", "4", "5", "6"]
    cm = confusion_matrix(y_true, y_pred)
    cm_sum = np.sum(cm, axis=1, keepdims=True)
    cm_perc = cm / cm_sum.astype(float) * 100
    annot = np.empty_like(cm).astype(str)
    nrows, ncols = cm.shape
    for i in range(nrows):
        for j in range(ncols):
            c = cm[i, j]
            p = cm_perc[i, j]
            if i == j:
                s = cm_sum[i]
                annot[i, j] = '%.1f%%\n%d/%d' % (p, c, s)
            elif c == 0:
                annot[i, j] = ''
            else:
                annot[i, j] = '%.1f%%\n%d' % (p, c)
    cm = pd.DataFrame(cm, index=labels, columns=labels)
    cm.index.name = 'Actual'
    cm.columns.name = 'Predicted'
    fig, ax = plt.subplots(figsize=figsize)
    x_axis_labels = ['Anaphase', 'G1', 'G2', 'Metaphase', 'Prophase', 'S', 'Telophase'] # labels for x-axis
    y_axis_labels = ['Anaphase', 'G1', 'G2', 'Metaphase', 'Prophase', 'S', 'Telophase'] # labels for y-axis


In [None]:
y_pred, y_true = predict(densenet_modelet50_model, test_set)
cm_analysis(y_true, y_pred, "Confusion matrix")