In [2]:
#Imports from project
from import_images import import_images_from_path
from cellpose_data import get_cellpose_probability_maps
from random_crops import get_random_crops_from_multiple_images
from augment_data import rotate_images_and_cellprobs_return_merged
from u_net import UNet

#Import from other libraries
import numpy as np
from sklearn.model_selection import train_test_split
import torch
from torch.optim import Adam
import wandb

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Directory holding the raw input images (a relative path, so it resolves
# against the notebook's current working directory).
images_path = "data/"

In [4]:
# Load the input images from `images_path`, asking the loader to normalise them.
# NOTE(review): num_imgs=23 presumably caps how many images are read — confirm
# against import_images_from_path.
images = import_images_from_path(
    images_path,
    num_imgs=23,
    normalisation=True,
)

In [5]:
# Run Cellpose over every loaded image to obtain its cell-probability map.
# These maps are used below as the training targets ("ground truth").
cell_probabilities = get_cellpose_probability_maps(
    images
)

In [6]:
# Min-max rescale each Cellpose probability map independently to [0, 1] so all
# regression targets share a common range.
# A constant map (max == min) would otherwise compute 0/0 and turn the whole map
# into NaN, which then poisons the MSE loss; such maps become all zeros instead.
def _min_max_normalise(cellprob):
    """Return `cellprob` linearly rescaled to min 0 / max 1 (all zeros if the map is constant)."""
    lowest = np.min(cellprob)
    value_range = np.max(cellprob) - lowest
    if value_range == 0:
        # (cellprob - lowest) is already all zeros; dividing by 1 keeps the same
        # result dtype as the normal path while avoiding 0/0.
        value_range = 1
    return (cellprob - lowest) / value_range

cell_probabilities_norm = [_min_max_normalise(cellprob) for cellprob in cell_probabilities]

In [7]:
# Cut random crops from each image together with the corresponding region of its
# normalised probability map (presumably the same window, so inputs and targets
# stay aligned — confirm in random_crops). No crop size is passed here; the
# helper decides it (the original note says 128x128 tensor crops).
# NOTE(review): num_crops=10 is presumably per image — confirm.
image_crops, cellprob_crops = get_random_crops_from_multiple_images(
    images,
    cell_probabilities_norm,
    num_crops=10,
)

In [8]:
# Rotation augmentation: rotate every crop and its target by 90, 180 and 270
# degrees. Judging by the helper's name, the originals and their rotations are
# returned merged into a single collection — confirm in augment_data.
image_crops_augmented, cellprob_crops_augmented = rotate_images_and_cellprobs_return_merged(
    image_crops,
    cellprob_crops,
    angles=[90, 180, 270],
)

In [9]:
# Hold out a third of the (augmented) crops for evaluation; the fixed
# random_state keeps the split reproducible across runs.
# NOTE(review): the split is taken *after* rotation augmentation, so rotated
# copies of the same crop can likely land on both sides, making the test loss
# optimistic. Consider splitting before augmenting.
X_train, X_test, y_train, y_test = train_test_split(
    image_crops_augmented,
    cellprob_crops_augmented,
    test_size=0.33,
    random_state=42,
)

In [20]:
#Hyperparameters
learning_rate = 0.001
num_epochs = 300
batch_size = 2
loss_fn = torch.nn.MSELoss()  # TODO: check alternatives in WANDB
activation_fn = torch.nn.ReLU()  # clamps predictions to be non-negative
# Prefer the first GPU as before, but fall back to CPU so the notebook still runs without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Seed torch so weight initialisation and batch shuffling are reproducible
# (the train/test split already uses random_state=42).
torch.manual_seed(42)

#Initialising the model, we might need to do 256x256 crops in the end
model = UNet().to(device)

#Get the dataloaders
train_loader = torch.utils.data.DataLoader(list(zip(X_train, y_train)), batch_size=batch_size, shuffle=True)
# Evaluation order does not affect the mean loss, so the test set is not shuffled.
test_loader = torch.utils.data.DataLoader(list(zip(X_test, y_test)), batch_size=batch_size, shuffle=False)

#Optimiser
opt = torch.optim.RMSprop(model.parameters(), lr=learning_rate)


def _batch_loss(batch_images, batch_cellprobs):
    """Return `loss_fn` of the model's activated prediction on one batch against its targets.

    Both tensors are moved to `device`; a channel dimension is inserted at dim 1 of
    the images before the forward pass.
    """
    batch_images = torch.unsqueeze(batch_images, 1).to(device)
    batch_cellprobs = batch_cellprobs.to(device)
    predictions = activation_fn(model(batch_images))
    # If the model output keeps the channel dim that was added to the inputs, give the
    # targets the same dim; otherwise MSELoss would broadcast e.g. (B, 1, H, W) against
    # (B, H, W) to (B, B, H, W) and compare every prediction with every target.
    if batch_cellprobs.dim() == predictions.dim() - 1:
        batch_cellprobs = torch.unsqueeze(batch_cellprobs, 1)
    return loss_fn(predictions, batch_cellprobs)


def _run_epoch(loader, train):
    """Make one pass over `loader` and return the sample-weighted mean of the batch losses.

    With `train=True` the model is put in training mode and the optimiser steps after
    every batch. With `train=False` the model is put in eval mode and gradients are
    disabled, so layers such as dropout/batch-norm (if UNet has any) behave as at
    inference. Returns NaN for an empty loader.
    """
    model.train(train)
    total_loss = 0.0
    num_samples = 0
    with torch.set_grad_enabled(train):
        for batch_images, batch_cellprobs in loader:
            loss = _batch_loss(batch_images, batch_cellprobs)
            if train:
                opt.zero_grad()
                loss.backward()
                opt.step()
            # Weight each batch by its size so a smaller final batch doesn't skew the mean.
            batch_len = batch_images.size(0)
            total_loss += loss.item() * batch_len
            num_samples += batch_len
    return total_loss / num_samples if num_samples else float('nan')


train_losses = []
test_losses = []

# Loop variables deliberately avoid the name `images`, so the image list loaded
# earlier is not overwritten by a training batch.
for epoch in range(num_epochs):
    epoch_train_loss = _run_epoch(train_loader, train=True)
    epoch_test_loss = _run_epoch(test_loader, train=False)
    train_losses.append(epoch_train_loss)
    test_losses.append(epoch_test_loss)

    print('Epoch: {}/{} | Train Loss: {:.4f} | Test Loss: {:.4f}'.format(epoch+1, num_epochs, epoch_train_loss, epoch_test_loss))
    # TODO: it might be interesting to predict on a fixed test crop each epoch and visualise it to track progress

model.eval()  # leave the trained model in inference mode for later cells

Epoch: 1/300 | Train Loss: 0.0349 | Test Loss: 0.0250
Epoch: 2/300 | Train Loss: 0.4538 | Test Loss: 0.0258
Epoch: 3/300 | Train Loss: 0.0251 | Test Loss: 0.0248
Epoch: 4/300 | Train Loss: 0.0250 | Test Loss: 0.0240
Epoch: 5/300 | Train Loss: 0.0278 | Test Loss: 0.0286
Epoch: 6/300 | Train Loss: 0.0243 | Test Loss: 0.0391
Epoch: 7/300 | Train Loss: 0.0241 | Test Loss: 0.0234
Epoch: 8/300 | Train Loss: 0.0233 | Test Loss: 0.0231
Epoch: 9/300 | Train Loss: 0.0238 | Test Loss: 0.0237
Epoch: 10/300 | Train Loss: 0.0229 | Test Loss: 0.0224
Epoch: 11/300 | Train Loss: 0.0225 | Test Loss: 0.0258
Epoch: 12/300 | Train Loss: 0.0222 | Test Loss: 0.0222
Epoch: 13/300 | Train Loss: 0.0218 | Test Loss: 0.0226
Epoch: 14/300 | Train Loss: 0.0216 | Test Loss: 0.0219
Epoch: 15/300 | Train Loss: 0.0218 | Test Loss: 0.0220
Epoch: 16/300 | Train Loss: 0.0213 | Test Loss: 0.0218
Epoch: 17/300 | Train Loss: 0.0216 | Test Loss: 0.0227
Epoch: 18/300 | Train Loss: 0.0214 | Test Loss: 0.0216
Epoch: 19/300 | Tra

In [22]:
# Persist only the learned weights (the state_dict), not the pickled module.
# To reload: build a fresh UNet() and call load_state_dict(torch.load(...)) on it.
torch.save(obj=model.state_dict(), f='cellprob_model_0.pt')