Run the cell below to fine-tune the pretrained model on your dataset.
<br>
<br>
Once training finishes, download best_model.pth and classifications.txt. best_model.pth is the model fine-tuned on your data, and classifications.txt lists the cell types that model has learnt.
<br>
<br>
Then proceed to the bottom of the notebook to use the cell identifier.

In [None]:
import shutil, torch, timm, os
from google.colab import drive
from tqdm import tqdm
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


def create_folder(foldername):
  """Create ``foldername``, including any missing parent directories.

  If the folder already exists, a message is printed and nothing else
  happens, so repeated calls are harmless.

  Args:
    foldername: Path of the directory to create.
  """
  try:
    # makedirs (unlike mkdir) also creates missing parents, e.g. the
    # "Training Dataset/Train" level when the dataset is being set up fresh.
    os.makedirs(foldername)
  except FileExistsError:
    print(f"{foldername} already exists.")


def split_new_data(source_directory="/content/drive/MyDrive/MNT/Dataset",
                   dataset_root="/content/drive/MyDrive/MNT/Training Dataset"):
  """Move newly added images into the train/validate splits.

  ``source_directory`` holds one sub-folder per cell type. For each
  sub-folder, the first 80% of its files (in ``os.listdir`` order) are moved
  to ``<dataset_root>/Train/<cell type>`` and the rest to
  ``<dataset_root>/Validate/<cell type>``. Destination folders, including
  missing parents, are created as needed.

  Args:
    source_directory: Folder containing the new, not-yet-split data.
    dataset_root: Folder that contains the ``Train`` and ``Validate`` splits.
  """
  try:
    folders = os.listdir(source_directory)  # access new data if there is any
  except FileNotFoundError:
    # Only a missing *source* folder means "no new data". Errors further down
    # are real problems and must not be reported with this message.
    print("No new data found.")
    return

  for folder in folders:  # each cell type added
    folder_path = os.path.join(source_directory, folder)
    if not os.path.isdir(folder_path):
      # Stray files next to the class folders (e.g. desktop.ini) would make
      # os.listdir raise NotADirectoryError, so skip them.
      continue

    files = [f for f in os.listdir(folder_path)
             if os.path.isfile(os.path.join(folder_path, f))]

    training_path = os.path.join(dataset_root, "Train", folder)
    validate_path = os.path.join(dataset_root, "Validate", folder)
    os.makedirs(training_path, exist_ok=True)
    os.makedirs(validate_path, exist_ok=True)

    # 80% of cell images are used for training, the remainder for validation
    train = (len(files) * 8) // 10

    for f in files[:train]:
      shutil.move(os.path.join(folder_path, f), os.path.join(training_path, f))

    for f in files[train:]:
      shutil.move(os.path.join(folder_path, f), os.path.join(validate_path, f))


def train_epoch(loader, model, criterion, optimizer, device):
  """Run one optimisation pass of ``model`` over every batch in ``loader``.

  Returns:
    A ``(loss, accuracy)`` tuple. The loss is the per-sample average and the
    accuracy is the fraction of correctly classified samples, both divided
    by ``len(loader.dataset)``.
  """
  model.train()
  loss_sum = 0.0
  n_correct = 0

  for batch_imgs, batch_labels in loader:
    batch_imgs = batch_imgs.to(device)
    batch_labels = batch_labels.to(device)

    optimizer.zero_grad()
    logits = model(batch_imgs)
    batch_loss = criterion(logits, batch_labels)
    batch_loss.backward()
    optimizer.step()

    # Assumes criterion returns a batch mean (nn.CrossEntropyLoss default),
    # so it is scaled by batch size to get a per-sample epoch average.
    loss_sum += batch_loss.item() * batch_imgs.size(0)
    predictions = logits.argmax(1)
    n_correct += (predictions == batch_labels).sum().item()

  n_samples = len(loader.dataset)
  return loss_sum / n_samples, n_correct / n_samples


def eval_epoch(loader, model, criterion, device):
  """Score ``model`` on every batch of ``loader`` without updating weights.

  The model is switched to eval mode and gradients are disabled.

  Returns:
    A ``(loss, accuracy)`` tuple, both averaged over ``len(loader.dataset)``.
  """
  model.eval()
  weighted_loss = 0.0
  hits = 0

  with torch.no_grad():
    for inputs, targets in loader:
      inputs, targets = inputs.to(device), targets.to(device)
      scores = model(inputs)
      # Weight each batch's loss by its size so the average is per sample.
      weighted_loss += criterion(scores, targets).item() * inputs.size(0)
      hits += (scores.argmax(1) == targets).sum().item()

  dataset_size = len(loader.dataset)
  return weighted_loss / dataset_size, hits / dataset_size


def finetune(num_epochs,
             train_dir="/content/drive/MyDrive/MNT/Training Dataset/Train",
             val_dir="/content/drive/MyDrive/MNT/Training Dataset/Validate"):
  """Fine-tune an ImageNet-pretrained EfficientNet-Lite0 on the cell dataset.

  The weights with the best validation accuracy are saved to
  ``best_model.pth`` in the working directory. The class names, in
  label-index order, are written one per line to ``classifications.txt``.

  Args:
    num_epochs: Maximum number of epochs. Training stops early once
      validation accuracy has not improved for ``patience`` epochs.
    train_dir: ImageFolder-style directory with the training split.
    val_dir: ImageFolder-style directory with the validation split.

  Raises:
    ValueError: If the training and validation splits do not contain the
      same class folders, because their label indices would not line up.
  """
  # Imagenet normalisation statistics (match the pretrained weights)
  mean = [0.485, 0.456, 0.406]
  std = [0.229, 0.224, 0.225]

  # random crops/flips augment the training data
  train_transform = transforms.Compose([
      transforms.RandomResizedCrop(224),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      transforms.Normalize(mean, std),
  ])

  # deterministic resize + centre crop to 224x224px for evaluation
  val_transform = transforms.Compose([
      transforms.Resize(256),
      transforms.CenterCrop(224),
      transforms.ToTensor(),
      transforms.Normalize(mean, std),
  ])

  train_ds = datasets.ImageFolder(train_dir, transform=train_transform)
  val_ds = datasets.ImageFolder(val_dir, transform=val_transform)

  # Each ImageFolder numbers labels from its own sorted class folders, so a
  # class missing from one split would silently shift the validation labels.
  if train_ds.class_to_idx != val_ds.class_to_idx:
    raise ValueError(
        "Train and Validate folders must contain the same classes; "
        f"got {train_ds.classes} vs {val_ds.classes}")

  train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4)
  val_loader = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=4)

  num_classes = len(train_ds.classes)

  # load in EfficientNet-Lite0 pretrained model and replace its classifier head
  model = timm.create_model('efficientnet_lite0', pretrained=True)
  in_features = model.classifier.in_features
  model.classifier = nn.Linear(in_features, num_classes)

  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  model = model.to(device)

  criterion = nn.CrossEntropyLoss()
  optimizer = optim.Adam(model.parameters(), lr=1e-4)
  scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

  # Start below any reachable accuracy so the first epoch is always saved.
  # With 0.0, a run that never beats 0% would write no checkpoint, and a
  # stale best_model.pth from an earlier run could be left behind.
  best_val_acc = -1.0
  patience = 10
  no_improvement_epochs = 0

  for epoch in tqdm(range(num_epochs), desc="Training"):
    train_loss, train_acc = train_epoch(train_loader, model, criterion, optimizer, device)
    val_loss, val_acc = eval_epoch(val_loader, model, criterion, device)
    scheduler.step()

    # tqdm.write prints without breaking the progress bar
    tqdm.write(f"Epoch {epoch+1}: train loss {train_loss:.4f}, acc {train_acc:.2%} | "
               f"val loss {val_loss:.4f}, acc {val_acc:.2%}")

    if val_acc > best_val_acc:
      best_val_acc = val_acc
      no_improvement_epochs = 0
      torch.save(model.state_dict(), 'best_model.pth')  # save best model
    else:
      no_improvement_epochs += 1

    if no_improvement_epochs >= patience:  # early exit to prevent overfitting
      print(f"Early stopping at epoch {epoch+1}. No improvement in the last {patience} epochs.")
      break

  # one class name per line, in label-index order, no trailing newline
  with open("classifications.txt", "w") as file:
    file.write("\n".join(train_ds.classes))

  print("Finished finetuning.")

if __name__ == "__main__":
  # force_remount=True remounts Drive even if it is already mounted.
  drive.mount('/content/drive', force_remount=True)

  # First move any newly added images into the train/validate splits,
  # then fine-tune for at most 50 epochs.
  split_new_data()
  finetune(num_epochs=50)

Cell Identifier.
<br>
<br>
Ensure classifications.txt and best_model.pth have been uploaded to the Files panel, under /content.

In [None]:
!pip install cellpose

In [None]:
import torch, timm, cv2
import cellpose
import numpy as np
from PIL import Image
from cellpose import models
from torchvision import transforms


if __name__ == "__main__":
  # Class names in label-index order, one per line. Blank lines (e.g. a
  # trailing newline added by an editor) are skipped so the classifier head
  # size still matches the checkpoint.
  with open("classifications.txt", "r") as file:
    classes = [line.strip() for line in file if line.strip()]
  num_classes = len(classes)

  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  # gpu takes a bool; the previous string "True" only worked because any
  # non-empty string is truthy.
  cellpose_model = models.CellposeModel(gpu=device.type == "cuda")

  # Rebuild the architecture used during fine-tuning, then load its weights.
  identifier_model = timm.create_model('efficientnet_lite0', pretrained=False)
  in_features = identifier_model.classifier.in_features
  identifier_model.classifier = torch.nn.Linear(in_features, num_classes)

  # The checkpoint is a plain state_dict, so refuse arbitrary pickled objects.
  finetuned_model = torch.load('best_model.pth', map_location=device, weights_only=True)
  identifier_model.load_state_dict(finetuned_model)
  identifier_model = identifier_model.to(device)
  identifier_model.eval()

  # Same ImageNet statistics and eval transform as used for validation.
  mean = [0.485, 0.456, 0.406]
  std = [0.229, 0.224, 0.225]

  val_transform = transforms.Compose([
      transforms.Resize(256),
      transforms.CenterCrop(224),
      transforms.ToTensor(),
      transforms.Normalize(mean, std),
  ])

  # Keep prompting until the path opens as an image. Image.open must sit
  # inside the try; otherwise a bad path crashes instead of re-prompting.
  while True:
    image_path = input("Enter image path here: ")
    try:
      image = Image.open(image_path).convert("RGB")  # pytorch reads PIL images
    except OSError as err:  # FileNotFoundError, IsADirectoryError, PIL.UnidentifiedImageError
      print(f"Could not open image: {err}")
      continue
    break

  input_tensor = val_transform(image).unsqueeze(0).to(device)

  with torch.no_grad():
    logits = identifier_model(input_tensor)
    probs = torch.softmax(logits, dim=1)
    pred_class = logits.argmax(dim=1).item()
    pred_conf = probs[0, pred_class].item()

  masks, flows, styles = cellpose_model.eval(np.array(image))  # cellpose reads numpy arrays
  # Count distinct non-zero labels (cellpose uses 0 for background). This is
  # correct whether or not any background pixel is present.
  cell_count = int(np.count_nonzero(np.unique(masks)))

  print(f"Predicted classification: {classes[pred_class]}  (confidence: {pred_conf:.2%}) (cell count: {cell_count})")