# Train a model to classify mechanism of action from cellular images

In [None]:
# Automatically re-import edited modules (e.g. pytorch_hcs) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import cv2
import numpy as np
import pytorch_lightning as pl
from pytorch_hcs.datasets import BBBC021DataModule, BBBC021Dataset
from pytorch_hcs.models import ResNet18, ResNet101, ResNet18Embeddings
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger

# Disable OpenCV's internal multithreading.
# NOTE(review): presumably to avoid thread oversubscription / deadlocks when
# images are processed inside the multi-process DataLoader workers
# (num_workers=8 below) — confirm.
cv2.setNumThreads(0)

In [None]:
import wandb
# Authenticate with Weights & Biases (uses a cached API key if available,
# otherwise prompts). Note this also runs if you switch to TensorBoardLogger.
wandb.login()

In [None]:
# Root directory for data, TensorBoard logs and local weights (relative to the CWD).
data_path = Path(".") / "data"

# Choose model

In [None]:
# Architecture to train; swap in ResNet18 or ResNet101 to try the alternatives.
model_cls = ResNet18Embeddings

# Set up PyTorch-Lightning DataModule

In [None]:
# Train/validation batch size per model class name (the larger ResNet101 gets a smaller batch).
tv_batch_sizes = dict(
    ResNet18=10,
    ResNet101=5,
    ResNet18Embeddings=10,
)

In [None]:
# Batch size for the train/val loaders depends on the selected architecture;
# the test loader always uses a batch of 8.
train_val_batch_size = tv_batch_sizes[model_cls.__name__]

dm = BBBC021DataModule(
    t_batch_size=8,
    tv_batch_size=train_val_batch_size,
    num_workers=8,
)

# Set up the splits eagerly so the class weights below can be computed.
# NOTE(review): Lightning may call setup() again inside trainer.fit — confirm
# that repeated setup is cheap/idempotent for BBBC021DataModule.
dm.setup()

# Optional: compute class weights

Set `with_class_balance` to `True` to compute per-class weights for each split and pass them to the model, or `False` to disable class weighting.

In [None]:
with_class_balance = True

if with_class_balance:
    # Per-split class weights; these are passed to the model constructor below.
    train_class_weights = dm.train_dataset.compute_class_weights()
    val_class_weights = dm.val_dataset.compute_class_weights()
    test_class_weights = dm.test_dataset.compute_class_weights()

    # Show each split's weights keyed by class name.
    # NOTE(review): names and weights are paired positionally, which assumes
    # dm.label_to_class iterates in the same (label-index) order as the weight
    # vectors — confirm against BBBC021Dataset.compute_class_weights.
    class_names = list(dm.label_to_class.values())
    for split, weights in (
        ("training", train_class_weights),
        ("validation", val_class_weights),
        ("test", test_class_weights),
    ):
        display(f"{split} weights", dict(zip(class_names, weights)))

else:
    # No class balancing: the model receives None for every split.
    train_class_weights = val_class_weights = test_class_weights = None

# Initialize model

In [None]:
# Only the "Embeddings" model variants are given an explicit input-channel count.
if "Embeddings" in model_cls.__name__:
    extra_kwargs = {"num_channels": 3}
else:
    extra_kwargs = {}

model = model_cls(
    num_classes=dm.num_classes,
    pretrained=True,
    # presumably optimiser / LR-scheduler settings (see pytorch_hcs.models)
    learning_rate=1e-4,
    lambdalr_factor=1,
    plateau_patience=10,
    plateau_factor=0.1,
    # per-split class weights (all None when with_class_balance is False)
    train_class_weights=train_class_weights,
    val_class_weights=val_class_weights,
    test_class_weights=test_class_weights,
    **extra_kwargs,
)

# Set up training logger

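Model checkpoint artifacts from training will be accessible under `'model-{version}'`,
where `version` is set in the next cell to the class name of the PyTorch Lightning module. Because that name is reused as the W&B run ID, re-running the same model class targets the same run; change `version` if you want separate runs.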

In [None]:
name = model_cls.__name__

# Weights & Biases logger: the run version/id is pinned to the model class name
# and all checkpoints are logged as artifacts (log_model='all').
logger = WandbLogger(
    project='pytorch-hcs',
    name=name,
    version=name,
    log_model='all',
)
logger.watch(model)

## Alternatively, use local `TensorBoardLogger`

In [None]:
# logger = TensorBoardLogger(name=name, save_dir=f'{data_path}/tensorboard')

# Path(f'{data_path}/tensorboard/{name}').mkdir(exist_ok=True, parents=True)

# Directory for local copy of weights

Required when using `TensorBoardLogger`; redundant if W&B is already saving checkpoint artifacts to your account.

In [None]:
# Local checkpoint directory: <data_path>/weights/<model name>/version_<logger version>
weights_path = data_path / "weights" / name / f"version_{logger.version}"
weights_path.mkdir(parents=True, exist_ok=True)

# Early stopping

Stop training once the validation loss has not improved for a set number of consecutive epochs.

In [None]:
# Stop training once `val_loss` has not improved for `patience` epochs.
#
# The model above is built with `plateau_patience=10`. If that drives a
# ReduceLROnPlateau-style scheduler on the same metric (inferred from the
# argument name — confirm in pytorch_hcs.models), an early-stopping patience
# of 10 ends training on a flat val_loss before the scheduler fires, because
# the scheduler only reduces the LR after *more than* `patience` bad epochs.
# A larger patience gives the reduced learning rate a chance to help.
early_stop_callback = pl.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0.00,
    patience=20,
    verbose=False,
    mode="min",
)

# Model checkpointing

Save a new checkpoint only when the validation loss improves, keeping just the single model with the best validation loss.

In [None]:
# Keep only the single best checkpoint, judged by the lowest validation loss,
# and write it to the local weights directory.
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=1,
    dirpath=str(weights_path),
    verbose=True,
)

# Train model

Set `gpus` to `0` to train on the CPU (which will be much slower).

In [None]:
gpus = 1  # 0 will use CPU instead

# 16-bit mixed precision is a GPU feature: depending on the Lightning version,
# requesting precision=16 on CPU either raises or silently switches to bf16.
# Fall back to full 32-bit precision when training on the CPU.
# NOTE: `gpus` is deprecated in Lightning 1.7+ and removed in 2.0
# (use `accelerator=`/`devices=` there).
trainer = pl.Trainer(
    logger=logger,
    gpus=gpus,
    callbacks=[early_stop_callback, checkpoint_callback],
    benchmark=True,  # enable cuDNN autotuning of convolution algorithms
    precision=16 if gpus else 32,
)

trainer.fit(model, dm)

In [None]:
# Close the W&B run with a success exit code so all logged data is uploaded.
# Skipped when the TensorBoard logger is used instead.
if isinstance(logger, WandbLogger):
    wandb.finish(exit_code=0)