In [1]:
from import_images import getImages
from import_model import getModel
from make_predictions import makePredictions
import numpy as np

import torch

from cellpose import resnet_torch
from cellpose import transforms
from cellpose import utils
import cv2

import time

from unet_architecture import UNet
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torchvision
import numpy as np

import matplotlib.pyplot as plt

In [2]:
# Folder of input images to distil on.
# NOTE(review): machine-specific absolute path; move to a configurable constant.
images_directory = "C:\\Users\\rz200\\Documents\\development\\distillCellSegTrack\\pipeline\\uploads\\"
images = getImages(images_directory)
# Fail fast: an empty folder would otherwise only surface later as an empty
# training loader / shape errors far from the cause.
assert len(images) > 0, f"no images found in {images_directory}"
# Stack into one array (all images must share a shape) and wrap it as a tensor.
images_torch = torch.from_numpy(np.array(images))

In [3]:
# Teacher network: a custom-trained Cellpose model.
# NOTE(review): machine-specific absolute path; move to a configurable constant.
directory = "C:\\Users\\rz200\\Documents\\development\\distillCellSegTrack\\datasets\\Fluo-C2DL-Huh7\\01\\models\\CP_20230601_101328"
cpnet = resnet_torch.CPnet(nbase=[2,32,64,128,256],nout=3,sz=3)
cpnet.load_model(directory)
# The teacher is only used for inference. CPnet is an nn.Module with
# BatchNorm layers, so without eval() it would normalise each image with its
# own batch statistics instead of the learned running statistics, and the
# distillation targets would differ from what Cellpose itself produces.
cpnet.eval()

In [4]:
def _resize_channels(maps, width, height):
    """Resize each channel of a (C, h, w) array to (C, height, width) using nearest-neighbour sampling."""
    # cv2.resize takes dsize as (width, height).
    return np.stack([
        cv2.resize(channel, dsize=(width, height), interpolation=cv2.INTER_NEAREST)
        for channel in maps
    ])


def get_pre_activations(image, cpnet):
    """Run the Cellpose teacher on one image and return its distillation targets.

    The image is rescaled by ``cpnet.diam_mean / cpnet.diam_labels`` and
    duplicated into two identical input channels. It is then passed through the
    network's downsample -> style -> upsample -> output stages.

    Both the final upsample feature maps and the network output are resized
    back to the input image's spatial size, using nearest-neighbour
    interpolation. This lets them line up pixel-for-pixel with ``image``. The
    previous version always resized to a fixed 512x512.

    Args:
        image: single input image; assumed 2-D (height, width), since it is
            resized with ``no_channels=True``.
        cpnet: a loaded ``resnet_torch.CPnet`` (ideally in ``eval()`` mode).

    Returns:
        (upsample, output): float32 CPU tensors of shape (C, height, width).
        C is the number of upsample feature channels and the number of
        network output channels respectively.
    """
    height, width = image.shape[0], image.shape[1]
    # float() collapses a one-element tensor ratio to a plain Python scale factor.
    rescale = float(cpnet.diam_mean / cpnet.diam_labels)

    x = transforms.resize_image(image, rsz=rescale, no_channels=True)
    x = x.astype(np.float32)
    # CPnet expects two input channels; feed the same grayscale image twice.
    x = np.stack((x, x), axis=0)
    x = torch.from_numpy(x).unsqueeze(0)  # add batch dimension

    # Teacher targets need no gradients; skip building the autograd graph.
    with torch.no_grad():
        # downsample returns a list of per-level feature maps (indexed below),
        # so it has no .shape -- the old debug print raised AttributeError.
        downsample = cpnet.downsample(x)
        style = cpnet.make_style(downsample[-1])
        upsample = cpnet.upsample(style, downsample, cpnet.mkldnn)
        output = cpnet.output(upsample)

    upsample = _resize_channels(upsample.squeeze(0).cpu().numpy(), width, height)
    output = _resize_channels(output.squeeze(0).cpu().numpy(), width, height)
    return torch.from_numpy(upsample), torch.from_numpy(output)

: 

: 

In [None]:
# Run the Cellpose teacher once per image and cache its intermediate feature
# maps and final outputs; these are the distillation targets for training.
cp_upsamples, cp_outputs = [], []
for img in images:
    teacher_features, teacher_logits = get_pre_activations(img, cpnet)
    cp_upsamples.append(teacher_features)
    cp_outputs.append(teacher_logits)

In [None]:
class ImageDataset(Dataset):
    """Map-style dataset yielding ``(image, target_a, target_b)`` triples.

    Note on naming: in this notebook the ``cellprob`` argument receives the
    Cellpose upsample feature maps and ``cellmask`` receives the Cellpose
    network outputs. The parameter names are kept for compatibility.

    Args:
        image: indexable sequence of input images.
        cellprob: indexable sequence of first targets, aligned with ``image``.
        cellmask: indexable sequence of second targets, aligned with ``image``.

    Raises:
        ValueError: if the three sequences do not have the same length.
    """

    def __init__(self, image, cellprob, cellmask):
        # __len__ only looks at `image`, so a shorter target list would
        # otherwise fail with an IndexError partway through an epoch.
        if not (len(image) == len(cellprob) == len(cellmask)):
            raise ValueError(
                f"length mismatch: image={len(image)}, "
                f"cellprob={len(cellprob)}, cellmask={len(cellmask)}"
            )
        self.image = image
        self.cellprob = cellprob
        self.cellmask = cellmask

    def __len__(self):
        return len(self.image)

    def __getitem__(self, idx):
        return self.image[idx], self.cellprob[idx], self.cellmask[idx]
    
# Only the first image is used (presumably a quick overfitting sanity check);
# widen the slice to distil on the full set.
n_train = 1
train_dataset = ImageDataset(
    images_torch[:n_train],
    cp_upsamples[:n_train],
    cp_outputs[:n_train],
)
train_loader = DataLoader(dataset=train_dataset, batch_size=1, shuffle=True)

In [None]:
class KDLoss(torch.nn.Module):
    """Distillation loss between the student UNet and the Cellpose teacher.

    Each map is divided by ``temperature`` and passed through a sigmoid. The
    student and teacher versions are then compared with mean-squared error:

        loss = alpha * MSE(feature maps) + beta * MSE(output maps)

    NOTE(review): the sigmoid is also applied to the teacher's flow channels
    and feature maps, which are unbounded regressions in Cellpose. This is a
    modelling choice worth revisiting.

    Args:
        alpha: weight of the feature-map term (``y_32_*``).
        beta: weight of the output-map term (``y_3_*``).
        temperature: divisor applied before the sigmoid. Values > 1 soften
            the maps; the default 1 leaves them unchanged.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """

    def __init__(self, alpha=1.0, beta=0.5, temperature=1):
        super().__init__()
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        self.alpha = alpha
        self.beta = beta
        self.temperature = temperature

    def forward(self, y_32_pred, y_3_pred, y_32_true, y_3_true):
        """Return the weighted scalar loss for one batch of predictions and teacher targets."""
        t = self.temperature
        # Targets may arrive as float64; cast so they match the student's dtype.
        y_32_loss = F.mse_loss(torch.sigmoid(y_32_pred / t),
                               torch.sigmoid(y_32_true.float() / t))
        y_3_loss = F.mse_loss(torch.sigmoid(y_3_pred / t),
                              torch.sigmoid(y_3_true.float() / t))
        return self.alpha * y_32_loss + self.beta * y_3_loss

In [None]:
def trainEpoch(unet, train_loader, loss_fn, optimiser, scheduler, epoch_num):
    """Run one distillation epoch of the student ``unet`` over ``train_loader``.

    Batches are moved to whatever device ``unet``'s parameters live on, so the
    model can be trained on GPU or CPU.

    Args:
        unet: student model. It is called on a (B, 1, H, W) input and must
            return ``(y_32_pred, y_3_pred)``.
        train_loader: iterable of ``(image, upsample, cp_output)`` batches.
        loss_fn: callable ``loss_fn(y_32_pred, y_3_pred, upsample, cp_output)``
            returning a scalar tensor.
        optimiser: optimiser over ``unet``'s parameters.
        scheduler: LR scheduler stepped once per batch, or ``None``.
        epoch_num: epoch index shown in the log line, or ``None``.

    Returns:
        The same ``unet`` instance, after one epoch of updates. The printed
        loss is the mean per-batch loss, or ``nan`` for an empty loader.
    """
    time_start = time.perf_counter()
    device = next(unet.parameters()).device

    unet.train()

    total_loss = 0.0
    n_batches = 0
    for image, upsample, cp_output in train_loader:
        # Add the channel dimension and make sure the input matches the
        # float32 weights (no-op if the images are already float32).
        image = image.to(device).float().unsqueeze(1)
        upsample = upsample.to(device)
        cp_output = cp_output.to(device)

        y_32_pred, y_3_pred = unet(image)
        y_32_pred = y_32_pred.squeeze(1)
        y_3_pred = y_3_pred.squeeze(1)

        loss = loss_fn(y_32_pred, y_3_pred, upsample, cp_output)

        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
        if scheduler is not None:
            scheduler.step()

        # Accumulate a Python float: summing the loss tensor itself would keep
        # every batch's autograd graph alive for the whole epoch.
        total_loss += loss.item()
        n_batches += 1

    train_loss = total_loss / n_batches if n_batches else float('nan')
    elapsed = time.perf_counter() - time_start

    if epoch_num is None:
        print('Training loss: ', train_loss, 'Time: ', elapsed)
    else:
        print('Epoch ', epoch_num, 'Training loss: ', train_loss, 'Time: ', elapsed)

    return unet

In [None]:
# Distillation run configuration.
SEED = 0
N_EPOCHS = 1000
DEVICE = torch.device('cuda:0')
torch.manual_seed(SEED)  # reproducible weight init and loader shuffling

unet = UNet(nbClasses=3).to(DEVICE)
loss_fn = KDLoss()
optimiser = torch.optim.SGD(unet.parameters(), lr=0.1, momentum=0.1)
# CyclicLR defaults to cycle_momentum=True. That mode overwrites the
# optimiser's momentum, cycling it between 0.8 and 0.9, so momentum=0.1
# above was silently ignored. Disable it so the configured momentum is used.
scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimiser, base_lr=0.00001, max_lr=0.1, cycle_momentum=False
)

for epoch in range(N_EPOCHS):
    unet = trainEpoch(unet, train_loader, loss_fn, optimiser, scheduler=scheduler, epoch_num=epoch)

In [None]:
# Compare the student's output with the teacher's on the training image(s).
# Both are passed through a sigmoid and drawn on a shared [0, 1] colour
# scale, matching how KDLoss compares them.
unet = unet.to('cuda:0')
unet.eval()  # inference mode for the visual check
with torch.no_grad():
    for image, upsample, cp_output in train_loader:
        image = image.to('cuda:0').float().unsqueeze(1)
        cp_output = cp_output.to('cuda:0')

        y_32_pred, y_3_pred = unet(image)

        # Later cells read y_3_pred: a NumPy array of sigmoid probabilities,
        # presumably (3, H, W) after dropping the batch dimension.
        y_3_pred = torch.sigmoid(y_3_pred).cpu().numpy().squeeze(0)
        teacher_prob = torch.sigmoid(cp_output[0].float()).cpu().numpy()

        fig, (ax_student, ax_teacher) = plt.subplots(1, 2)
        ax_student.imshow(y_3_pred[0], vmin=0, vmax=1)
        ax_student.set_title('Student (UNet), channel 0')
        ax_teacher.imshow(teacher_prob[0], vmin=0, vmax=1)
        ax_teacher.set_title('Teacher (Cellpose), channel 0')
        plt.show()

In [None]:
from cellpose import dynamics

# dynamics.get_masks expects the converged pixel positions obtained by
# following the flow field, not raw network outputs. compute_masks runs
# that whole post-processing from the flows and the cell-probability map.
# The Cellpose output channels are presumably (flow-y, flow-x, cellprob
# logit); verify against the installed cellpose version.
# y_3_pred holds sigmoid probabilities, so invert the sigmoid to recover logits.
eps = 1e-6
y_3_prob = np.clip(y_3_pred, eps, 1 - eps)
y_3_logits = np.log(y_3_prob / (1 - y_3_prob))
# compute_masks returns (masks, p, ...) in cellpose 2.x; keep the label image.
masks = dynamics.compute_masks(y_3_logits[:2], y_3_logits[2])[0]
masks

In [None]:
# y_3_pred already holds sigmoid probabilities; the visualisation cell applies
# a sigmoid before converting to NumPy. Applying a second sigmoid would squash
# every value into [0.5, 0.73], so any threshold below 0.5 would select the
# whole image.
CELLPROB_THRESHOLD = 0.1
y_3_pred_2_sig = y_3_pred[2]
print('cell probability range:', y_3_pred_2_sig.min(), y_3_pred_2_sig.max())
y_3_pred_2_sig_binary = np.where(y_3_pred_2_sig > CELLPROB_THRESHOLD, 1, 0)
print(np.unique(y_3_pred_2_sig_binary))
fig, ax = plt.subplots()
ax.imshow(y_3_pred_2_sig_binary)
ax.set_title(f'Cell probability > {CELLPROB_THRESHOLD}')
plt.show()

In [None]:
# Shape smoke test on a fresh, untrained UNet. It uses its own name so the
# trained `unet` from the training cell is not overwritten.
unet_check = UNet(nbClasses=3)
sample = torch.from_numpy(images[0]).float().unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)
with torch.no_grad():
    decfeatures, pred = unet_check(sample)
print(decfeatures.shape, pred.shape)