# CS230 Project - Macrophage Classification

classify macrophage cells into 8 different classes using resnet50

the 8 classes are: Dividing, Early Phagocytosis, Fried egg, Intermediate Phagocytosis, Late Phagocytosis, Migrating, Quiescent, and Searching

We use CellSighter (https://github.com/KerenLab/CellSighter) as the starting point for our classification task. The model is implemented in PyTorch and trained using a ResNet-50 backbone.

In [1]:
# import packages
import sys
sys.path.append(".")
import os
import torch
import numpy as np
import json
import tifffile
import cv2
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import models
from torch import nn
import torchvision
from torchvision.transforms import Lambda
from torch.utils.tensorboard import SummaryWriter

## data preprocessing

first need to preprocess the raw labels and split into train/val sets. 

**note**: i already ran this part so you can skip these cells if train.json and val.json already exist

In [2]:
# paths to raw data
labels_path = '/scratch/groups/emmalu/samutiti/phage_labels.json'
images_dir = '/scratch/groups/emmalu/samutiti/raw_phage_crops'
output_dir = '/scratch/users/rchi/cs230/macrophage_data'

# class names in alphabetical order
class_names = [
    'Dividing',
    'Early Phagocytosis',
    'Fried egg',
    'Intermediate Phagocytosis',
    'Late Phagocytosis',
    'Migrating',
    'Quiescent',
    'Searching'
]

print(f"we have {len(class_names)} classes")

we have 8 classes


In [3]:
# load labels from json
with open(labels_path, 'r') as f:
    all_labels = json.load(f)

print(f"total samples in raw data: {len(all_labels)}")

# filter out uncertain samples
# some cells are labeled as "Uncertain" which we won't use for training
filtered_labels = {}
for filename, label_list in all_labels.items():
    if 'Uncertain' not in label_list:
        filtered_labels[filename] = label_list

print(f"after removing uncertain: {len(filtered_labels)}")

total samples in raw data: 4052
after removing uncertain: 2349


In [4]:
# handle multi-label samples
# some cells have multiple labels like ["Migrating", "Searching"]
# for simplicity just take the first label
label2idx = {name: idx for idx, name in enumerate(class_names)}

dataset = []
multi_label_count = 0

for filename, label_list in filtered_labels.items():
    if len(label_list) > 1:
        multi_label_count += 1
    
    # take first label as primary
    primary_label = label_list[0]
    label_idx = label2idx[primary_label]
    
    # fix filename (remove _V suffix if exists)
    actual_filename = filename.replace('_V.tiff', '.tiff')
    
    dataset.append({
        'filename': actual_filename,
        'label': label_idx,
        'original_labels': label_list
    })

print(f"samples with multiple labels: {multi_label_count}")
print(f"total dataset size: {len(dataset)}")

samples with multiple labels: 1078
total dataset size: 2349


In [5]:
# split into train and validation sets (80/20 split)
# use stratified split to keep class distribution same in both sets
np.random.seed(42)

# group samples by class
class_samples = {}
for sample in dataset:
    label = sample['label']
    if label not in class_samples:
        class_samples[label] = []
    class_samples[label].append(sample)

train_data = []
val_data = []

# split each class separately to maintain distribution
for label, samples in class_samples.items():
    np.random.shuffle(samples)
    n_val = int(len(samples) * 0.2)  # 20% for validation
    
    val_data.extend(samples[:n_val])
    train_data.extend(samples[n_val:])

# shuffle again
np.random.shuffle(train_data)
np.random.shuffle(val_data)

print(f"train samples: {len(train_data)}")
print(f"validation samples: {len(val_data)}")

# show class distribution
print("\ntrain set class distribution:")
for i, name in enumerate(class_names):
    count = sum(1 for s in train_data if s['label'] == i)
    print(f"  {name}: {count}")

print("\nval set class distribution:")
for i, name in enumerate(class_names):
    count = sum(1 for s in val_data if s['label'] == i)
    print(f"  {name}: {count}")

train samples: 1882
validation samples: 467

train set class distribution:
  Dividing: 196
  Early Phagocytosis: 75
  Fried egg: 432
  Intermediate Phagocytosis: 60
  Late Phagocytosis: 61
  Migrating: 651
  Quiescent: 57
  Searching: 350

val set class distribution:
  Dividing: 49
  Early Phagocytosis: 18
  Fried egg: 107
  Intermediate Phagocytosis: 15
  Late Phagocytosis: 15
  Migrating: 162
  Quiescent: 14
  Searching: 87


In [6]:
# save train and val splits to json files
os.makedirs(output_dir, exist_ok=True)

train_path = os.path.join(output_dir, 'train.json')
val_path = os.path.join(output_dir, 'val.json')

with open(train_path, 'w') as f:
    json.dump(train_data, f, indent=2)

with open(val_path, 'w') as f:
    json.dump(val_data, f, indent=2)

# also save class names for reference
with open(os.path.join(output_dir, 'class_names.txt'), 'w') as f:
    f.write('\n'.join(class_names))

## training

from here we load the preprocessed data and train the model

In [7]:
# load config
config_path = '/scratch/users/rchi/cs230/macrophage_experiment/config.json'
with open(config_path) as f:
    config = json.load(f)

print(config)

{'crop_input_size': 60, 'crop_size': 128, 'root_dir': '/scratch/users/rchi/cs230/macrophage_data', 'train_set': ['train'], 'val_set': ['val'], 'num_classes': 8, 'epoch_max': 50, 'lr': 0.001, 'blacklist': [], 'batch_size': 32, 'num_workers': 4, 'channels_path': '/scratch/users/rchi/cs230/macrophage_data/channels.txt', 'weight_to_eval': '', 'sample_batch': True, 'to_pad': False, 'hierarchy_match': None, 'size_data': None, 'aug': True}


## Data Loader for macrophage crops

The macrophage images have already been cropped into single-cell macrophage tiff images

data loader was adapted from cell_crop.py in the CellSighter repository

In [8]:
# class to load one crop image
class MacrophageCrop:
    def __init__(self, filename, label, images_dir):
        self._filename = filename
        self._label = label
        self._images_dir = images_dir
        self._image_id = filename
        self._cell_id = 0
    
    def sample(self, mask=False):
        # load the tiff image
        img_path = os.path.join(self._images_dir, self._filename)
        image = tifffile.imread(img_path).astype(np.float32)
        
        # make sure its H x W x C
        if len(image.shape) == 2:
            image = image[:, :, np.newaxis]
        
        result = {
            'cell_id': self._cell_id,
            'image_id': self._image_id,
            'image': image,
            'label': np.array(self._label, dtype=np.longlong),
        }
        
        if mask:
            # create mask - since each crop has one cell just use all ones
            h, w = image.shape[:2]
            result['mask'] = np.ones((h, w), dtype=np.float32)
            result['all_cells_mask'] = np.ones((h, w), dtype=np.float32)
        
        return result

# function to load all crops from json
def load_crops(data_json_path, images_dir):
    with open(data_json_path, 'r') as f:
        data = json.load(f)
    
    crops = []
    for item in data:
        crop = MacrophageCrop(
            filename=item['filename'],
            label=item['label'],
            images_dir=images_dir
        )
        crops.append(crop)
    
    return crops

## augmentation and transforms

we did some data augmentation because we only have 2000-3000 labeled crops. 

Since the crops are not uniform in size, smaller crops were padded to 60Ã—60

Most of the augmentation functions were adapted from the CellSighter

In [9]:
# Import shift augmentation from the working file
import sys
sys.path.insert(0, '/scratch/users/rchi/cs230')
from data.shift_augmentation import ShiftAugmentation

# poisson sampling augmentation
def poisson_sampling(x):
    # add some noise using poisson distribution
    blur = cv2.GaussianBlur(x[:, :, :-2], (5, 5), 0)
    x[:, :, :-2] = np.random.poisson(lam=blur, size=x[:, :, :-2].shape)
    return x

# augment cell mask shape
def cell_shape_aug(x):
    if np.random.random() < 0.5:
        cell_mask = x[:, :, -1]
        kernel_size = np.random.choice([2, 3, 5])
        kernel = np.ones(kernel_size, np.uint8)
        img_dilation = cv2.dilate(cell_mask, kernel, iterations=1)
        x[:, :, -1] = img_dilation
    return x

# augment environment mask
def env_shape_aug(x):
    if np.random.random() < 0.5:
        cell_mask = x[:, :, -2]
        kernel_size = np.random.choice([2, 3, 5])
        kernel = np.ones(kernel_size, np.uint8)
        img_dilation = cv2.dilate(cell_mask, kernel, iterations=1)
        x[:, :, -2] = img_dilation
    return x

# pad image to target size if its too small
def pad_to_size(x, target_size):
    c, h, w = x.shape
    pad_h = max(0, target_size - h)
    pad_w = max(0, target_size - w)
    
    if pad_h > 0 or pad_w > 0:
        pad_left = pad_w // 2
        pad_right = pad_w - pad_left
        pad_top = pad_h // 2
        pad_bottom = pad_h - pad_top
        
        padding = (pad_left, pad_right, pad_top, pad_bottom)
        x = torchvision.transforms.functional.pad(x, padding, fill=0, padding_mode='constant')
    
    return x

# validation transform - just resize and center crop
def val_transform(crop_size):
    return torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        Lambda(lambda x: pad_to_size(x, crop_size)),
        torchvision.transforms.CenterCrop((crop_size, crop_size))
    ])

# training transform - all augmentations including shift
def train_transform(crop_size, shift):
    return torchvision.transforms.Compose([
        torchvision.transforms.Lambda(poisson_sampling),
        torchvision.transforms.Lambda(cell_shape_aug),
        torchvision.transforms.Lambda(env_shape_aug),
        torchvision.transforms.ToTensor(),
        Lambda(lambda x: pad_to_size(x, crop_size)),
        torchvision.transforms.RandomRotation(degrees=(0, 360)),
        Lambda(lambda x: ShiftAugmentation(n_size=crop_size, shift_max=shift)(x) if np.random.random() < 0.5 else x),
        torchvision.transforms.CenterCrop((crop_size, crop_size)),
        torchvision.transforms.RandomHorizontalFlip(p=0.75),
        torchvision.transforms.RandomVerticalFlip(p=0.75),
    ])

## Dataset class

pytorch dataset to load the images

CellCropsDataset was adapted from CellSighter data.py

Since our images are already cropped and CellSighter expects both image and mask inputs, we append a mask as the last channel

In [10]:
class CellCropsDataset(Dataset):
    def __init__(self, crops, mask=False, transform=None):
        super().__init__()
        self._crops = crops
        self._transform = transform
        self._mask = mask
    
    def __len__(self):
        return len(self._crops)
    
    def __getitem__(self, idx):
        sample = self._crops[idx].sample(self._mask)
        
        # stack image and mask together 
        # image has 4 channels, then all_cells_mask, then mask
        aug = self._transform(np.dstack(
            [sample['image'], sample['all_cells_mask'][:, :, np.newaxis], sample['mask'][:, :, np.newaxis]]
        )).float()
        
        # split back to image and mask
        sample['image'] = aug[:4, :, :]  # first 4 channels
        sample['mask'] = aug[[5], :, :]  # last channel
        
        # remove fields that cant be batched
        if 'all_cells_mask' in sample:
            del sample['all_cells_mask']
        if 'all_cells_mask_seperate' in sample:
            del sample['all_cells_mask_seperate']
        
        return sample

## Model

using resnet50 as backbone, need to change first layer because we have 5 channels (4 image channels + 1 mask)

adapted from model.py from CellSighter

In [11]:
class Model(nn.Module):
    def __init__(self, input_len, num_classes):
        super(Model, self).__init__()
        # use resnet50
        self.model = models.resnet50(num_classes=num_classes)
        
        # change first conv layer to accept our number of input channels
        self.model.conv1 = torch.nn.Conv2d(input_len, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        
        # initialize weights
        nn.init.kaiming_normal_(self.model.conv1.weight, mode='fan_out', nonlinearity='relu')
        
        self.softmax = nn.Softmax(dim=1)
    
    def forward(self, x):
        x = self.model(x)
        # only apply softmax during inference
        if not self.training:
            x = self.softmax(x)
        return x

## weighted sampler for class imbalance

Some classes have fewer samples, so we apply class weighting to balance the training process

adapted from train.py in the CellSighter repository



In [12]:
def make_weighted_sampler(crops):
    labels = np.array([c._label for c in crops])
    unique_labels = np.unique(labels)
    
    # count how many samples in each class
    class_counts = {}
    for label in unique_labels:
        class_counts[label] = len(np.where(labels == label)[0])
    
    print("class distribution:")
    for label, count in class_counts.items():
        print(f"  class {label}: {count} samples")
    
    # weight = total / class_count
    # rare classes get higher weight
    total = sum(class_counts.values())
    weights = {}
    for label, count in class_counts.items():
        weights[label] = total / count
    
    # assign weight to each sample
    sample_weights = np.array([weights[label] for label in labels])
    sample_weights = torch.from_numpy(sample_weights)
    
    return WeightedRandomSampler(sample_weights.double(), len(sample_weights))

## training function

train for one epoch. also adapted from train.py from CellSighter

In [13]:
def train_one_epoch(model, dataloader, optimizer, criterion, epoch, device):
    model.train()
    
    total_loss = 0
    for i, batch in enumerate(dataloader):
        # get data
        x = batch['image']
        m = batch.get('mask', None)
        
        # concat mask to image
        if m is not None:
            x = torch.cat([x, m], dim=1)
        
        # move to gpu
        x = x.to(device=device)
        y = batch['label'].to(device=device)
        
        # forward pass
        optimizer.zero_grad()
        y_pred = model(x)
        loss = criterion(y_pred, y)
        
        # backprop
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        # print loss 
        if i % 50 == 0:
            print(f"epoch {epoch} batch {i}/{len(dataloader)} loss: {loss.item():.4f}")
    
    avg_loss = total_loss / len(dataloader)
    return avg_loss

## validation function

evaluate on validation set

In [14]:
def validate(model, dataloader, criterion, device):
    model.eval()
    
    total_loss = 0
    correct = 0
    total = 0
    
    # collect all predictions and labels for CSV
    all_preds = []
    all_labels = []
    all_probs = []
    all_cell_ids = []
    all_image_ids = []
    
    with torch.no_grad():
        for batch in dataloader:
            x = batch['image']
            m = batch.get('mask', None)
            
            if m is not None:
                x = torch.cat([x, m], dim=1)
            
            x = x.to(device=device)
            y = batch['label'].to(device=device)
            
            y_pred = model(x)
            loss = criterion(y_pred, y)
            
            total_loss += loss.item()
            
            # calculate accuracy
            _, predicted = torch.max(y_pred, 1)
            total += y.size(0)
            correct += (predicted == y).sum().item()
            
            # collect for CSV
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(y.cpu().numpy())
            all_probs.append(y_pred.cpu().numpy())
            all_cell_ids.extend(batch['cell_id'])
            all_image_ids.extend(batch['image_id'])
    
    avg_loss = total_loss / len(dataloader)
    accuracy = correct / total
    
    # combine all probabilities
    all_probs = np.concatenate(all_probs, axis=0)
    
    return avg_loss, accuracy, all_preds, all_labels, all_probs, all_cell_ids, all_image_ids

In [15]:
# function to save validation results to CSV
def save_val_results(save_path, preds, labels, probs, cell_ids, image_ids):
    import pandas as pd
    
    # create dataframe
    results = pd.DataFrame({
        'pred': preds,
        'label': labels,
        'pred_prob': probs.max(axis=1),  # max probability
        'cell_id': cell_ids,
        'image_id': image_ids,
        'prob_list': [prob.tolist() for prob in probs]  # all class probabilities
    })
    
    # save to CSV
    results.to_csv(save_path, index=False)
    print(f"saved validation results to: {save_path}")

## Load data

In [16]:
# paths
train_json = os.path.join(config['root_dir'], 'train.json')
val_json = os.path.join(config['root_dir'], 'val.json')
images_dir = '/scratch/groups/emmalu/samutiti/raw_phage_crops'

train_crops = load_crops(train_json, images_dir)
print(f"loaded {len(train_crops)} training samples")

val_crops = load_crops(val_json, images_dir)
print(f"loaded {len(val_crops)} validation samples")

# make weighted sampler
sampler = make_weighted_sampler(train_crops)

loaded 1882 training samples
loaded 467 validation samples
class distribution:
  class 0: 196 samples
  class 1: 75 samples
  class 2: 432 samples
  class 3: 60 samples
  class 4: 61 samples
  class 5: 651 samples
  class 6: 57 samples
  class 7: 350 samples


In [17]:
crop_size = config["crop_input_size"]
use_augmentation = config.get("aug", True)

# training dataset with augmentation
train_transforms = train_transform(crop_size, shift=5) if use_augmentation else val_transform(crop_size)
train_dataset = CellCropsDataset(train_crops, transform=train_transforms, mask=True)

# validation dataset without augmentation
val_dataset = CellCropsDataset(val_crops, transform=val_transform(crop_size), mask=True)

print(f"crop size: {crop_size}x{crop_size}")

crop size: 60x60


## create data loaders

In [18]:
batch_size = config["batch_size"]
num_workers = config["num_workers"]
use_sampler = config["sample_batch"]

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    sampler=sampler if use_sampler else None,
    shuffle=False if use_sampler else True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False
)

print(f"batch size: {batch_size}")
print(f"train batches: {len(train_loader)}")
print(f"val batches: {len(val_loader)}")

batch size: 32
train batches: 59
val batches: 15


## setup model

create the model and move it to gpu

In [19]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"using device: {device}")

# 4 image channels + 1 mask channel = 5 input channels
num_input_channels = 5
num_classes = config["num_classes"]

# create model
model = Model(num_input_channels, num_classes)
model = model.to(device=device)

using device: cuda


## setup optimizer and loss

using adam optimizer and cross entropy loss

In [20]:
learning_rate = config["lr"]
num_epochs = config["epoch_max"]

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.85)
criterion = nn.CrossEntropyLoss()

print(f"learning rate: {learning_rate}")
print(f"number of epochs: {num_epochs}")

learning rate: 0.001
number of epochs: 50


## training loop

now train the model! takes about 1.5 hr

In [21]:
# where to save checkpoints
save_dir = '/scratch/users/rchi/cs230/macrophage_experiment'
os.makedirs(save_dir, exist_ok=True)

# keep track of losses
train_losses = []
val_losses = []
val_accuracies = []

print("starting training...\n")

for epoch in range(num_epochs):
    print(f"\n===== Epoch {epoch+1}/{num_epochs} =====")
    
    # train
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion, epoch, device)
    train_losses.append(train_loss)
    print(f"training loss: {train_loss:.4f}")
    
    # validate every 5 epochs
    if epoch % 5 == 0:
        val_loss, val_acc, preds, labels, probs, cell_ids, image_ids = validate(model, val_loader, criterion, device)
        val_losses.append(val_loss)
        val_accuracies.append(val_acc)
        print(f"validation loss: {val_loss:.4f}")
        print(f"validation accuracy: {val_acc:.4f}")
        
        # save validation CSV every 10 epochs
        if epoch % 10 == 0:
            val_csv_path = os.path.join(save_dir, f"val_results_{epoch}.csv")
            save_val_results(val_csv_path, preds, labels, probs, cell_ids, image_ids)
    
    # save checkpoint every 10 epochs
    if epoch % 10 == 0:
        checkpoint_path = os.path.join(save_dir, f"checkpoint_epoch_{epoch}.pth")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
        }, checkpoint_path)
        print(f"saved checkpoint: {checkpoint_path}")
    
    # update learning rate
    if epoch > 0 and epoch % 10 == 0:
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        print(f"learning rate updated to: {current_lr:.6f}")

print("\ntraining finished!")

starting training...


===== Epoch 1/50 =====
epoch 0 batch 0/59 loss: 2.2552
epoch 0 batch 50/59 loss: 1.7576
training loss: 2.3620
validation loss: 2.0000
validation accuracy: 0.3405
saved validation results to: /scratch/users/rchi/cs230/macrophage_experiment/val_results_0.csv
saved checkpoint: /scratch/users/rchi/cs230/macrophage_experiment/checkpoint_epoch_0.pth

===== Epoch 2/50 =====
epoch 1 batch 0/59 loss: 1.8510
epoch 1 batch 50/59 loss: 2.0272
training loss: 1.8098

===== Epoch 3/50 =====
epoch 2 batch 0/59 loss: 1.8455
epoch 2 batch 50/59 loss: 1.7452
training loss: 1.8475

===== Epoch 4/50 =====
epoch 3 batch 0/59 loss: 1.6943
epoch 3 batch 50/59 loss: 1.5884
training loss: 1.7351

===== Epoch 5/50 =====
epoch 4 batch 0/59 loss: 1.5499
epoch 4 batch 50/59 loss: 1.7386
training loss: 1.6393

===== Epoch 6/50 =====
epoch 5 batch 0/59 loss: 1.4242
epoch 5 batch 50/59 loss: 1.6248
training loss: 1.6309
validation loss: 1.9666
validation accuracy: 0.2955

===== Epoch 7/50 =====


## final evaluation

test on validation set one more time

In [22]:
print("running final evaluation...")
val_loss, val_acc, preds, labels, probs, cell_ids, image_ids = validate(model, val_loader, criterion, device)

print(f"\nfinal validation loss: {val_loss:.4f}")
print(f"final validation accuracy: {val_acc:.4f}")

# save final model
final_model_path = os.path.join(save_dir, "final_model.pth")
torch.save(model.state_dict(), final_model_path)
print(f"\nsaved final model to: {final_model_path}")

# save final validation results CSV
final_csv_path = os.path.join(save_dir, "val_results_final.csv")
save_val_results(final_csv_path, preds, labels, probs, cell_ids, image_ids)

# store for later use
final_val_loss = val_loss
final_val_acc = val_acc

running final evaluation...

final validation loss: 1.7902
final validation accuracy: 0.5054

saved final model to: /scratch/users/rchi/cs230/macrophage_experiment/final_model.pth
saved validation results to: /scratch/users/rchi/cs230/macrophage_experiment/val_results_final.csv


## plot training curves

visualize the training progress

## confusion matrix

lets see which classes get confused with each other

In [23]:
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# get predictions for all validation data
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for batch in val_loader:
        x = batch['image']
        m = batch.get('mask', None)
        
        if m is not None:
            x = torch.cat([x, m], dim=1)
        
        x = x.to(device=device)
        y = batch['label'].to(device=device)
        
        y_pred = model(x)
        _, predicted = torch.max(y_pred, 1)
        
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(y.cpu().numpy())

# make confusion matrix
cm = confusion_matrix(all_labels, all_preds)

# plot it
class_names = ['Dividing', 'Early Phago', 'Fried egg', 'Inter Phago', 'Late Phago', 'Migrating', 'Quiescent', 'Searching']

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.xlabel('predicted')
plt.ylabel('true')
plt.title('confusion matrix')
plt.tight_layout()
plt.savefig(os.path.join(save_dir, 'confusion_matrix.png'))
plt.show()

print("confusion matrix saved")

NameError: name 'plt' is not defined

## per class accuracy

see how well each class does