# Colonoscopy Polyp Detection with Faster R-CNN and PyTorch Lightning

## Imports and setup

In [None]:
# %load_ext autoreload
# %autoreload 2

import os
from pathlib import Path
import torch

In [None]:
# Experiment configuration. This dict is passed to wandb as the run config and to
# LightningFasterModule, so it is the single source of truth for hyperparameters.
CONFIG = dict(
    project="hlc-polyp-detection",
    architecture="fasterrcnn_resnext50_32x4d",
    dataset_id="hlc-custom-polyp-detection",
    # Compared against "paperspace" in the setup cell below.
    infra="osx",
    num_classes=2,
    # Epoch limit handed to the Lightning Trainer (CONFIG["max_epochs"]).
    max_epochs=100,
    lr=0.01,
    min_lr=0.0000001,
    # NOTE(review): distinct from max_epochs; presumably consumed inside
    # LightningFasterModule (e.g. by an LR schedule) -- TODO confirm.
    epochs=15,
    batch_size=4,
    nesterov=True,
    momentum=0.9,
    weight_decay=0.0005,
    clip_limit=0.25,
    difference=False,
    name="fancy_walrus",
)

ROOT_DIR = os.path.abspath("./")
DATA_DIR = os.path.join(ROOT_DIR, "data")
MODEL_DIR = os.path.join(ROOT_DIR, "model")
LOG_DIR = os.path.join(ROOT_DIR, "log")

# Derived from CONFIG (instead of repeating the literals) so the values logged to
# wandb and the values actually used cannot drift apart.
NUM_CLASSES = CONFIG["num_classes"]

BATCH_SIZE = CONFIG["batch_size"]
INPUT_SIZE = 1024
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_WORKERS = 4

# Read the wandb API key from ./wandb.txt when that file exists. When it is missing
# (or empty), leave the environment untouched so an already-exported WANDB_API_KEY
# or a `wandb login` session can be used, instead of failing with FileNotFoundError.
wandb_key_path = Path(ROOT_DIR, "wandb.txt")
wandb_key = wandb_key_path.read_text().strip() if wandb_key_path.is_file() else None
if wandb_key:
    os.environ["WANDB_API_KEY"] = wandb_key
os.environ["WANDB_NOTEBOOK_NAME"] = os.path.join(ROOT_DIR, "LightningFastRCNN-polyps.ipynb")

In [None]:
# wandb is needed on every infrastructure, so import it once up front.
import wandb

if CONFIG["infra"] == "paperspace":
    # On a fresh machine, install the dependencies first if needed:
    # !pip install albumentations pytorch-lightning wandb --upgrade
    import subprocess

    # Report the GPU. subprocess replaces the IPython-only `!nvidia-smi` shell
    # escape so this cell is also valid plain Python (e.g. after nbconvert).
    try:
        smi = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=False)
    except FileNotFoundError:
        print("nvidia-smi not found on PATH")
    else:
        print(smi.stdout or smi.stderr)

In [None]:
from contextlib import contextmanager

@contextmanager
def wandb_context(configuration=CONFIG):
    """Open a wandb run for the duration of a ``with`` block.

    Args:
        configuration: run config passed to ``wandb.init(config=...)``. When it
            is a dict with a ``"project"`` entry, that entry selects the wandb
            project; otherwise ``CONFIG["project"]`` is used.

    Yields:
        The run object returned by ``wandb.init``.

    ``wandb.finish()`` is always called on exit, even if the body raises.
    """
    # Previously the project always came from the global CONFIG, ignoring a
    # caller-supplied configuration; prefer the passed config's own project.
    project = CONFIG["project"]
    if isinstance(configuration, dict):
        project = configuration.get("project", project)

    run = wandb.init(reinit=True, config=configuration, project=project)
    try:
        yield run
    finally:
        wandb.finish()

In [None]:
from src.PolypsPLDataModule import PolypsPLDataModule

# Use the configured DATA_DIR rather than a second, hard-coded relative path so the
# datamodule and the rest of the notebook always agree on where the data lives.
polyp_dm = PolypsPLDataModule(data_dir=DATA_DIR, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping

def train_model(model, run_name, dm=None, run=None):
    """Fit ``model`` with Lightning and return the best checkpoint by ``val_recall``.

    Args:
        model: the LightningModule to train.
        run_name: label used in the checkpoint filename; also stored in the wandb
            run config as ``train_run_name`` when ``run`` is given.
        dm: optional datamodule passed to ``Trainer.fit(datamodule=...)``.
        run: optional active wandb run; when given, metrics are logged through a
            ``WandbLogger``.

    Returns:
        A model of the same class reloaded from the best checkpoint, or ``model``
        itself when no checkpoint was written.
    """
    monitor = "val_recall"

    wandb_logger = None
    if run is not None:
        run.config["train_run_name"] = run_name

        from pytorch_lightning.loggers import WandbLogger
        wandb_logger = WandbLogger()

    # Keep the checkpoint with the highest validation recall.
    chkpt = ModelCheckpoint(
        dirpath=os.path.join(MODEL_DIR, "chkpts"),
        filename=f"{CONFIG['name']}-chkpt-{run_name}",
        monitor=monitor,
        mode="max")

    lrnrate = LearningRateMonitor(logging_interval="step", log_momentum=True)

    earlystop = EarlyStopping(
        monitor=monitor,
        patience=50,
        verbose=True,
        mode="max")

    trainer = Trainer(
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        logger=wandb_logger,
        callbacks=[chkpt, lrnrate, earlystop],
        log_every_n_steps=1,
        max_epochs=CONFIG["max_epochs"])

    trainer.fit(
        model,
        datamodule=dm)

    # best_model_path is empty when no checkpoint was saved (e.g. the monitored
    # metric was never logged); loading "" would fail, so keep the trained model.
    if not chkpt.best_model_path:
        return model

    # Reload through the model's own class rather than the global
    # LightningFasterModule, which is only imported in a later cell.
    return type(model).load_from_checkpoint(chkpt.best_model_path)

# NOTE(review): stray scratch line, apparently copied from dataset-loading code.
# `glob`, `root_dir` and `stage` are not defined in this notebook, so executing it
# raised NameError at the end of this cell. Kept here, disabled, for reference:
# list(sorted(glob.glob(os.path.join(root_dir, "*", stage, "images", "*.[jp][pn]g"))))

In [None]:
# %load_ext autoreload
# %autoreload 2

In [None]:
from src.LightningFasterModule import LightningFasterModule
# Full-network training inside a wandb run configured from CONFIG.
# A head-only pass could be run first with:
#     train_model(model, "head", dm=polyp_dm, run=run)
with wandb_context(CONFIG) as run:
    model = LightningFasterModule(CONFIG)

    if run is not None:
        # Hook the model into wandb (wandb `watch` logs gradients by default).
        run.watch(model)

    # Switch the module to its full-training setup before fitting
    # (presumably unfreezes the backbone -- TODO confirm in LightningFasterModule).
    model.full_train()
    model = train_model(model, "full", dm=polyp_dm, run=run)

In [None]:
from src.LightningFasterModule import LightningFasterModule

# Scratch cell: opens a wandb run with a freshly constructed model; the reload and
# evaluation steps below are not enabled yet.
with wandb_context(CONFIG) as run:
    model = LightningFasterModule(CONFIG)

    if run is not None:
        run.watch(model)

    # To start from the best saved checkpoint instead, reload it. Note that
    # `load_from_checkpoint` is a classmethod returning a NEW model, so assign it
    # (Lightning appends ".ckpt" to the ModelCheckpoint filename):
    # model = LightningFasterModule.load_from_checkpoint(
    #     os.path.join(MODEL_DIR, "chkpts", f"{CONFIG['name']}-chkpt-full.ckpt"))

    # trained_model = train_model(model, "head", dm=polyp_dm, run=run)
    # `test_model` is not defined in this notebook yet:
    # model = test_model(model, run_name="test_run", dm=polyp_dm, run=run)