In this notebook, we train a model using more of Cellposes' pre- and post-processing functions

In [2]:
from import_images import getImages
import numpy as np
import torch
from cellpose import resnet_torch
from cellpose import transforms
import cv2
import time
from unet_architecture import UNet
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torchvision.transforms as T
from torchmetrics.classification import BinaryJaccardIndex
import torchvision
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import ezomero
from omero_data import connect, extract_channel, progressbar

In [42]:
def get_pre_activations(image,cpnet):
    x = torch.from_numpy(image)
    downsample = cpnet.downsample(x)
    style = cpnet.make_style(downsample[-1])
    upsample = cpnet.upsample(style, downsample, cpnet.mkldnn)
    
    output = cpnet.output(upsample).squeeze(0)
    output = output.cpu().detach().numpy()
    output = np.array(output)

    upsample = upsample.squeeze(0)
    upsample = upsample.cpu().detach().numpy().tolist()
    upsample = np.array(upsample)
    return upsample, output

class ImageDataset(Dataset):
    def __init__(self, image, upsample, cellprob):
        self.image = image
        self.upsample = upsample
        self.cellprob = cellprob

    def __len__(self):
        return len(self.image)

    def __getitem__(self, idx):
        img = self.image[idx]
        upsample = self.upsample[idx]
        cellprob = self.cellprob[idx]
        return img, upsample, cellprob
    
class KD_loss(torch.nn.Module):
    def __init__(self, alpha, beta):
        super(KD_loss, self).__init__()
        self.alpha = alpha
        self.beta = beta

    def forward(self, y_32_pred, y_32_true, y_3_pred, y_3_true):
        #32-channel loss
        y_32_pred = F.sigmoid(y_32_pred)
        y_32_true = F.sigmoid(y_32_true)
        y_32_loss = F.mse_loss(y_32_pred, y_32_true.float())
        #3-channel loss
        flow_loss = F.mse_loss(y_3_pred[:,:2], y_3_true[:,:2])
        flow_loss /= 2.
        map_loss = F.mse_loss(y_3_pred[:,2] , y_3_true[:,2])
        y_3_loss = flow_loss + map_loss
        return y_32_loss * self.alpha, y_3_loss * self.beta

def trainEpoch(unet, train_loader, test_loader, validation_loader, loss_fn, optimiser, scheduler, epoch_num, device):
    time_start = time.time()
    
    unet.train()

    train_y_32_loss, train_map_loss = 0, 0

    for image, upsample, cp_output in train_loader:

        
        image, upsample, cp_output = image.float(), upsample.float(), cp_output.float() #cast to float32 (important for mps)

        if device is not None:
            (image, upsample, cp_output) = (image.to(device),upsample.to(device),cp_output.to(device)) # sending the data to the device (cpu or GPU)

        #image = image.unsqueeze(1)
        y_16_pred, y_32_pred, map_pred = unet(image)
        y_32_pred = y_32_pred.squeeze(1)
        map_pred = map_pred.squeeze(1)
    
        loss_32, loss_map = loss_fn(y_32_pred,  upsample, map_pred, cp_output) # calculate the loss of that prediction
        train_y_32_loss += loss_32.item()
        train_map_loss += loss_map.item()

         # zero out the accumulated gradients

        #I want to get two losses, one for the 32-channel output and one for the 3-channel output
        #I then want to freeze certain channels before putting the losses backwards
        unet.encoder.requires_grad = True #repetitive but just to be clear
        unet.decoder.requires_grad = True
        unet.head.requires_grad = False
        loss_32.backward(retain_graph=True)

        
        unet.encoder.requires_grad = False
        unet.decoder.requires_grad = False
        unet.head.requires_grad = True
        loss_map.backward(retain_graph=True)

        optimiser.step() # update model parameters
        optimiser.zero_grad()

    if scheduler is not None:
        scheduler.step()

    train_y_32_loss, train_map_loss = train_y_32_loss/len(train_loader), train_map_loss/len(train_loader)


    val_y_32_loss, val_map_loss, val_IoU = 0, 0, 0
    for image, upsample, cp_output in validation_loader:
        
        image, upsample, cp_output = image.float(), upsample.float(), cp_output.float() #cast to float32 (important for mps)

        if device is not None:
            (image, upsample, cp_output) = (image.to(device),upsample.to(device),cp_output.to(device)) # sending the data to the device (cpu or GPU)

        #image = image.unsqueeze(1)
        y_16_pred, y_32_pred, map_pred = unet(image)
        y_32_pred = y_32_pred.squeeze(1)
        map_pred = map_pred.squeeze(1)
    
        loss_32, loss_map = loss_fn(y_32_pred,  upsample, map_pred, cp_output) # calculate the loss of that prediction
        val_y_32_loss += loss_32.item()
        val_map_loss += loss_map.item()

        #IoU score
        jaccard = BinaryJaccardIndex(threshold=0.5).to(device)
        map_pred = F.sigmoid(map_pred)
        cp_output = F.sigmoid(cp_output)
        cp_output = torch.where(cp_output > 0.5, 1.0, 0.0)
        iou = jaccard(map_pred, cp_output)
        if not torch.isnan(iou):
            val_IoU += iou
        else:
            val_IoU += 0
        

    val_y_32_loss, val_map_loss, val_IoU = val_y_32_loss/len(validation_loader), val_map_loss/len(validation_loader), val_IoU.item()/len(validation_loader)
    
    #we might add displaying later on
    
    if epoch_num is None:
        print('Train 32 loss: ', train_y_32_loss,'Train map loss', train_map_loss, 'Val 32 loss: ', val_y_32_loss, 'Val map loss: ', val_map_loss, 'Val IoU: ', val_IoU, 'Time: ', time.time()-time_start)
    else:
        print('Epoch: ', epoch_num, 'Train 32 loss: ', train_y_32_loss,'Train map loss', train_map_loss, 'Val 32 loss: ', val_y_32_loss, 'Val map loss: ', val_map_loss, 'Val IoU: ', val_IoU, 'Time: ', time.time()-time_start)

    return unet    

def get_omero_images_combined(num_images=None, channels=[0,0]):
    conn = connect(user='rz200',password='omeroreset')

    plate = 1237
    if num_images == None:
        image_ids = ezomero.get_image_ids(conn,plate=plate)
    else:
        image_ids = ezomero.get_image_ids(conn,plate=plate)[:num_images]

    print('In plate',plate,'we have',len(image_ids),'images')

   

    data_images = []
    for i in progressbar(range(len(image_ids)), "Computing: ", 40):
        data_images.append(ezomero.get_image(conn, image_ids[i])[1])

    if len(channels) == 2 and channels != [0,0]:
        data_images_one = extract_channel(data_images, channels[0])
        data_images_two = extract_channel(data_images, channels[1])

        combined_images = []
        for i in range(len(data_images_one)):
            combined_images.append(np.array([data_images_one[i] , data_images_two[i]]))
    
        return combined_images
    elif len(channels) == 1:
        data_images_one = extract_channel(data_images, channels[0])
        return data_images_one
    else:
        return data_images

def get_cellpose_data(cpnet, combined_images):
    images_tiled = []
    ys = []
    all_upsamples = []

    for i in range(len(combined_images)):
        print(i)
        image_t = combined_images[i]

        if len(image_t.shape) == 2:
            image_t = [image_t,image_t]
            image_t = np.array(image_t)

        IMG, ysub, xsub, Ly, Lx = transforms.make_tiles(image_t, bsize=224, 
                                                        augment=False, tile_overlap=0.1)
        ny, nx, nchan, ly, lx = IMG.shape
        IMG = np.reshape(IMG, (ny*nx, nchan, ly, lx))
        images_tiled.append(IMG)
        batch_size = 1
        niter = int(np.ceil(IMG.shape[0] / batch_size))
        nout = 3 + 32*False
        y = np.zeros((IMG.shape[0], nout, ly, lx))
        styles = []

        upsamples = []

        for k in range(niter):
            irange = np.arange(batch_size*k, min(IMG.shape[0], batch_size*k+batch_size))
            upsample, y0 = get_pre_activations(IMG[irange], cpnet)
            upsamples.append(upsample)
            y[irange] = y0.reshape(len(irange), y0.shape[-3], y0.shape[-2], y0.shape[-1])

        ys.append(y)

        upsamples = np.array(upsamples)
        all_upsamples.append(upsamples)

    images_tiled_np = np.array(images_tiled)
    ys_np = np.array(ys)
    all_upsamples_np = np.array(all_upsamples)

    images_tiled_np = images_tiled_np.reshape(images_tiled_np.shape[0]*images_tiled_np.shape[1], images_tiled_np.shape[2], images_tiled_np.shape[3], images_tiled_np.shape[4])
    ys_np = ys_np.reshape(ys_np.shape[0]*ys_np.shape[1], ys_np.shape[2], ys_np.shape[3], ys_np.shape[4])
    all_upsamples_np = all_upsamples_np.reshape(all_upsamples_np.shape[0]*all_upsamples_np.shape[1], all_upsamples_np.shape[2], all_upsamples_np.shape[3], all_upsamples_np.shape[4])
        
    return images_tiled_np, ys_np, all_upsamples_np


In [4]:
directory = "/Users/rz200/Documents/development/distillCellSegTrack/pipeline/CellPose_models/U2OS_Tub_Hoechst"
cpnet = resnet_torch.CPnet(nbase=[2,32,64,128,256],nout=3,sz=3)
cpnet.load_model(directory)

In [13]:
combined_images = get_omero_images_combined(num_images=2, channels=[0])


Connection successful
In plate 1237 we have 2 images
Computing: [########################################] 2/2

0


ValueError: not enough values to unpack (expected 3, got 2)

In [41]:
print(combined_images[0].shape)

(1080, 1080)


In [43]:
images_tiled, ys, all_upsamples = get_cellpose_data(cpnet, combined_images)

0
1


In [44]:
print(images_tiled.shape)

(72, 2, 224, 224)


In [45]:
train_images, test_images, train_upsamples, test_upsamples, train_cellprob, test_cellprob = train_test_split(images_tiled[:100], all_upsamples[:100], ys[:100], test_size=0.1, random_state=42)
train_images, val_images, train_upsamples, val_upsamples, train_cellprob, val_cellprob = train_test_split(train_images, train_upsamples, train_cellprob, test_size=0.1, random_state=42)

num_train_images = 10
train_dataset = ImageDataset(train_images, train_upsamples, train_cellprob)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

validation_dataset = ImageDataset(val_images, val_upsamples, val_cellprob)
validation_loader = DataLoader(validation_dataset, batch_size=8, shuffle=True)

test_dataset = ImageDataset(test_images, test_upsamples, test_cellprob)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)

In [46]:
unet = UNet(encChannels=(2,32,64,128,256),decChannels=(256,128,64,32),nbClasses=3) #it's not a problem to train it with 2 channels as anyway it is just the same channel repeated twice for the nuclei model but that may make more parameters to train so we may want to chang ethat
unet = unet.to('cuda:0')

In [47]:
loss_fn = KD_loss(alpha=1, beta=1)
optimiser = torch.optim.SGD(unet.parameters(), lr=0.01, momentum=0.1)

#I think a decaying scheduler is best, not a cyclic one
scheduler = torch.optim.lr_scheduler.StepLR(optimiser, step_size=100, gamma=0.1)

for epoch in range(500):
    #print(scheduler.get_last_lr())
    unet = trainEpoch(unet, train_loader, test_loader, validation_loader, loss_fn, optimiser, scheduler=scheduler, epoch_num=epoch, device='cuda:0')



Epoch:  0 Train 32 loss:  0.2979394495487213 Train map loss 12.23060953617096 Val 32 loss:  0.2917259633541107 Val map loss:  9.416550636291504 Val IoU:  0.3035811483860016 Time:  6.260850429534912
Epoch:  1 Train 32 loss:  0.2890694625675678 Train map loss 8.242403268814087 Val 32 loss:  0.29161548614501953 Val map loss:  7.107572078704834 Val IoU:  0.1025366485118866 Time:  5.551206111907959
Epoch:  2 Train 32 loss:  0.2872977592051029 Train map loss 6.266357958316803 Val 32 loss:  0.2915608584880829 Val map loss:  6.081027984619141 Val IoU:  0.09498772770166397 Time:  5.622990846633911
Epoch:  3 Train 32 loss:  0.2941582016646862 Train map loss 6.587083220481873 Val 32 loss:  0.2837829887866974 Val map loss:  5.528446197509766 Val IoU:  0.10075341165065765 Time:  5.659926414489746
Epoch:  4 Train 32 loss:  0.3036434054374695 Train map loss 10.10155564546585 Val 32 loss:  0.28052452206611633 Val map loss:  6.998315334320068 Val IoU:  0.07141703367233276 Time:  5.698857069015503


KeyboardInterrupt: 