<a href="https://colab.research.google.com/github/peng-lab/idkidc/blob/manuel_sophia/Hackathon.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Set up an SSH connection to GitHub

In [None]:
# Generate an RSA key pair that git uses to authenticate to GitHub.
# ssh-keygen asks for a file path and a passphrase; add `-N "" -f /root/.ssh/id_rsa`
# to run it without prompts.
!ssh-keygen -t rsa -b 4096

In [None]:
# Record github.com's RSA host key so the ssh/git calls below do not stop at the
# host-key confirmation prompt.
# NOTE(review): this trusts whatever key the server presents (trust on first use);
# compare it with GitHub's published fingerprints if that matters.
!ssh-keyscan -t rsa github.com >> ~/.ssh/known_hosts

In [None]:
# Print the public key so it can be added to the GitHub account's SSH keys.
# Uses /root/.ssh explicitly — assumes the notebook runs as root, so that this is
# the same directory as the ~/.ssh used above.
!cat /root/.ssh/id_rsa.pub

In [None]:
# Check that GitHub accepts the key. No shell is provided, so a greeting with a
# non-zero exit status is the expected result.
!ssh -T git@github.com

In [None]:
# Commit identity for this runtime.
# NOTE(review): hard-coded to one contributor — replace with your own name/email
# before committing from the notebook.
!git config --global user.email "sophia.wagner@t-online.de"
!git config --global user.name "sophiajw"

In [None]:
# Clone over SSH (requires the public key above to be registered on GitHub).
!git clone git@github.com:peng-lab/idkidc.git

# Install packages

In [None]:
# Work from the repository root, presumably so the project modules imported later
# (options, data, classifier, utils) are found in the clone.
%cd /content/idkidc

In [None]:
# Switch to the working branch used by this notebook.
!git checkout manuel_sophia

In [None]:
# Install the packages not preinstalled on the runtime; dgl and einops are presumably
# needed by the project's model code.
# NOTE(review): versions are unpinned; the training cells below use Lightning 2.x
# APIs (e.g. precision='16-mixed'), so pin pytorch-lightning>=2.0 if needed.
!pip install pytorch-lightning
!pip install wandb
!pip install dgl
!pip install einops

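* Before running the cells below, add `default='/content/idkidc/config.yaml'` to the config-file argument (`args.config_file`) in `options.py`: the parser is called without any arguments here, so the path must come from that default.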

# Start training your histo classifier

In [None]:
# Reload edited project modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import numpy as np
import argparse
import pandas as pd
import yaml

import pytorch_lightning as pl
from sklearn.model_selection import StratifiedKFold, train_test_split
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger, CSVLogger
from torch.utils.data import DataLoader
import wandb

from options import Options
from data import MILDataset, MILDatasetIndices, get_multi_cohort_df
from classifier import ClassifierLightning
from utils import save_results

In [None]:
# A notebook has no command line, so parse an empty argument list: every option
# takes its parser default (including the path of the config file).
parser = Options()
args = parser.parser.parse_args('')

# Load the configuration from the YAML file. Each top-level key maps to an entry
# holding a 'value' (the flattening below reads v['value']).
# NOTE(review): FullLoader is acceptable for this trusted, repo-local file; use
# yaml.safe_load if the config could ever come from an untrusted source.
with open(args.config_file, 'r', encoding='utf-8') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# Parser values that are set (not None) override the YAML values. An option the
# YAML does not declare is added instead of raising a KeyError.
for arg_name, arg_value in vars(args).items():
    if arg_value is not None and arg_name != 'config_file':
        config.setdefault(arg_name, {})['value'] = arg_value

# Create a flat config without descriptions: {option name: value}
config = {k: v['value'] for k, v in config.items()}

print('\n--- load options ---')
for name, value in sorted(config.items()):
    print(f'{name}: {str(value)}')

cfg = argparse.Namespace(**config)

In [None]:
# --------------------------------------------------------
# setup
# --------------------------------------------------------
# Keep a configured seed; otherwise draw one so the run is still seeded.
if cfg.seed is None:
    cfg.seed = torch.randint(0, 1000, (1, )).item()
pl.seed_everything(cfg.seed, workers=True)

# Output layout: <save_dir>/<run name>/{models,folds,results}
logging_name = 'debug' if cfg.debug else f'{cfg.name}_{cfg.model}_{"-".join(cfg.cohorts)}_{cfg.norm}_{cfg.target}'
base_path = Path(cfg.save_dir) / logging_name  # adapt save_dir to own target path
base_path.mkdir(parents=True, exist_ok=True)
model_path = base_path / 'models'  # path only; not created in this cell
fold_path = base_path / 'folds'
result_path = base_path / 'results'
for out_dir in (fold_path, result_path):
    out_dir.mkdir(parents=True, exist_ok=True)

# For the HistAuGAN norms, the validation and test loaders use 'raw' instead of
# cfg.norm; for every other norm all splits use cfg.norm.
eval_norm = 'raw' if cfg.norm in ('histaugan', 'efficient_histaugan') else cfg.norm
norm_val = eval_norm
norm_test = eval_norm


In [None]:
# --------------------------------------------------------
# load data
# --------------------------------------------------------
print('\n--- load dataset ---')
# Label strings handed to the data loaders for the supported targets.
categories = [
    'Not mut.', 'Mutat.', 'nonMSIH', 'MSIH', 'WT', 'MUT',
    'wt', 'MT', 'female', 'male', 'left', 'right',
]
data = get_multi_cohort_df(cfg.cohorts, [cfg.target], categories, norm=cfg.norm, feats=cfg.feats)

# One evaluation loader per external cohort, in cfg.ext_cohorts order.
test_ext_dataloader = [
    DataLoader(
        dataset=MILDataset(
            [ext_cohort], [cfg.target],
            categories,
            norm=norm_test,
            feats=cfg.feats,
            clini_info=cfg.clini_info,
        ),
        batch_size=1,
        shuffle=False,
        num_workers=14,
        pin_memory=True,
    )
    for ext_cohort in cfg.ext_cohorts
]

# Result buckets: the pooled training cohorts first, then each external cohort.
train_cohorts = ", ".join(cfg.cohorts)
test_cohorts = [train_cohorts] + list(cfg.ext_cohorts)
results = {cohort: [] for cohort in test_cohorts}

In [None]:
# --------------------------------------------------------
# k-fold cross validation: run Lightning's learning-rate finder on every fold
# --------------------------------------------------------
from pytorch_lightning.tuner import Tuner  # used below; was never imported

skf = StratifiedKFold(n_splits=cfg.folds, shuffle=True, random_state=cfg.seed)
# One row per patient, so the folds are split by patient rather than by row.
patient_df = data.groupby('PATIENT').first().reset_index()
# With several targets configured, stratify on the first one.
target_stratisfy = cfg.target if isinstance(cfg.target, str) else cfg.target[0]
splits = skf.split(patient_df, patient_df[target_stratisfy])

# Suggested learning rate per fold, in fold order (later cells index into it).
lrs = []

for fold, (train_val_idxs, test_idxs) in enumerate(splits):
    # The held-out fold (test_idxs) is not needed for the learning-rate search;
    # split the remaining patients into train/val, again stratified on the target.
    train_idxs, val_idxs = train_test_split(
        train_val_idxs,
        stratify=patient_df.iloc[train_val_idxs][target_stratisfy],
        random_state=cfg.seed,
    )

    # training dataset
    train_dataset = MILDatasetIndices(
        data,
        train_idxs, [cfg.target],
        num_tiles=cfg.num_tiles,
        pad_tiles=cfg.pad_tiles,
        norm=cfg.norm
    )
    train_dataloader = DataLoader(
        dataset=train_dataset, batch_size=cfg.bs, shuffle=True, num_workers=14, pin_memory=True
    )

    # validation dataset
    val_dataset = MILDatasetIndices(data, val_idxs, [cfg.target], norm=norm_val)
    val_dataloader = DataLoader(
        dataset=val_dataset, batch_size=1, shuffle=False, num_workers=14, pin_memory=True
    )

    # Positive-class weight for BCEWithLogitsLoss: #negatives / #positives.
    # idx=2 since the output is feats, coords, labels.
    # NOTE(review): this loads every training sample just to read its label, and
    # torch.Tensor(...) needs the label sum to be a tensor (not a bare float) — confirm.
    num_pos = sum(train_dataset[i][2] for i in range(len(train_dataset)))
    cfg.pos_weight = torch.Tensor((len(train_dataset) - num_pos) / num_pos)
    cfg.criterion = "BCEWithLogitsLoss"

    # --------------------------------------------------------
    # model
    # --------------------------------------------------------
    model = ClassifierLightning(cfg)

    # --------------------------------------------------------
    # training setup
    # --------------------------------------------------------
    trainer = pl.Trainer(
        accelerator='auto',
        precision='16-mixed',
        accumulate_grad_batches=4,
        gradient_clip_val=1,
        max_epochs=cfg.num_epochs,
        num_sanity_val_steps=0,  # debug
        log_every_n_steps=1,  # debug
        enable_model_summary=False,  # debug
    )

    # --------------------------------------------------------
    # find learning rate
    # --------------------------------------------------------
    # NOTE(review): lr_find tunes an `lr`/`learning_rate` attribute on the model
    # (or its hparams) — assumed to be provided by ClassifierLightning.
    tuner = Tuner(trainer)
    lr_finder = tuner.lr_find(model, train_dataloader, val_dataloader, max_lr=0.1)

    fig = lr_finder.plot(suggest=True)
    fig.show()

    # suggestion() returns None when no suggestion can be computed. It is still
    # appended so that lrs[i] keeps corresponding to fold i.
    new_lr = lr_finder.suggestion()
    if new_lr is None:
        print(f'fold {fold}: no learning-rate suggestion')
    else:
        print(f'fold {fold}: suggested lr = {new_lr}')
    lrs.append(new_lr)

print(lrs)

In [None]:
# Mean of the per-fold learning-rate suggestions.
np.mean(lrs)


In [None]:
# Mean over folds 1, 3 and 4 only.
# NOTE(review): hand-picked fold subset — presumably to drop outlying suggestions;
# record the reason next to the chosen value.
np.mean([lrs[i] for i in (1, 3, 4)])

In [None]:
# Mean over every fold except fold 3.
# NOTE(review): hand-picked fold subset — presumably to drop an outlying
# suggestion; record the reason next to the chosen value.
np.mean([lrs[i] for i in (0, 1, 2, 4)])

## Tune hyperparameters with Optuna

In [None]:
import optuna
from optuna.integration import PyTorchLightningPruningCallback

In [None]:
# Training budget for every Optuna trial: maximum epochs and gradient-clip value.
EPOCHS, CLIP = 8, 1

In [None]:
# --------------------------------------------------------
# single stratified train/validation split shared by all Optuna trials
# --------------------------------------------------------
# One row per patient, so the split is by patient rather than by row.
patient_df = data.groupby('PATIENT').first().reset_index()
# With several targets configured, stratify on the first one.
target_stratisfy = cfg.target if isinstance(cfg.target, str) else cfg.target[0]
train_idxs, val_idxs = train_test_split(
    range(len(patient_df)),
    stratify=patient_df[target_stratisfy],
    random_state=cfg.seed,
)

# training dataset (tile count/padding and cfg.norm apply here only)
train_dataset = MILDatasetIndices(
    data, train_idxs, [cfg.target],
    num_tiles=cfg.num_tiles, pad_tiles=cfg.pad_tiles, norm=cfg.norm,
)
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=cfg.bs,
    shuffle=True,
    num_workers=14,
    pin_memory=True,
)

# validation dataset
val_dataset = MILDatasetIndices(data, val_idxs, [cfg.target], norm=norm_val)
val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=14,
    pin_memory=True,
)

# Positive-class weight = #negatives / #positives; labels are element 2 of each
# sample (feats, coords, labels).
num_pos = sum(train_dataset[i][2] for i in range(len(train_dataset)))
num_neg = len(train_dataset) - num_pos
cfg.pos_weight = torch.Tensor(num_neg / num_pos)
cfg.criterion = "BCEWithLogitsLoss"


In [None]:
def objective(trial: optuna.trial.Trial) -> float:
    """Train one model with Optuna-sampled hyperparameters and return its val AUROC.

    Samples a learning rate and a weight decay, both log-uniformly from
    [1e-6, 1e-1]. Then trains ``ClassifierLightning`` on the module-level
    ``train_dataloader`` / ``val_dataloader`` for at most ``EPOCHS`` epochs.
    ``PyTorchLightningPruningCallback`` may stop an unpromising trial early.

    Args:
        trial: The Optuna trial that supplies the hyperparameters.

    Returns:
        The final ``auroc/val`` entry of ``trainer.callback_metrics``.
    """
    # suggest_loguniform is deprecated since Optuna 3.0. suggest_float(log=True)
    # samples from the same distribution under the same parameter names.
    lr = trial.suggest_float("learning_rate", 1e-6, 1e-1, log=True)
    wd = trial.suggest_float("weight_decay", 1e-6, 1e-1, log=True)

    # The sampled values used to be logged only, so every trial trained with the same
    # cfg. Build the model from a per-trial copy, so the values actually reach the
    # model without leaking into the global cfg between trials.
    # NOTE(review): assumes ClassifierLightning reads them from cfg.lr / cfg.wd (the
    # names used for the logged hyperparameters below) — confirm.
    trial_cfg = argparse.Namespace(**vars(cfg))
    trial_cfg.lr = lr
    trial_cfg.wd = wd

    model = ClassifierLightning(trial_cfg)
    trainer = pl.Trainer(
        precision='16-mixed',
        accelerator='auto',
        max_epochs=EPOCHS,
        gradient_clip_val=CLIP,
        callbacks=[PyTorchLightningPruningCallback(trial, monitor="auroc/val")],
    )
    hyperparameters = dict(lr=lr, wd=wd)
    # trainer.logger is None when the Trainer runs without a logger.
    if trainer.logger is not None:
        trainer.logger.log_hyperparams(hyperparameters)
    trainer.fit(model, train_dataloader, val_dataloader)
    return trainer.callback_metrics["auroc/val"].item()

In [None]:
# --------------------------------------------------------
# tune hyperparameters
# --------------------------------------------------------
# Maximize validation AUROC with a multivariate TPE sampler and Hyperband pruning.
study = optuna.create_study(
    direction="maximize",
    sampler=optuna.samplers.TPESampler(multivariate=True),
    pruner=optuna.pruners.HyperbandPruner(),
)
study.optimize(objective, n_trials=50, timeout=None)

print(f"Number of finished trials: {len(study.trials)}")
print("Best trial:")
trial = study.best_trial
print(f"  Value: {trial.value}")
print("  Params: ")
for param_name, param_value in trial.params.items():
    print(f"    {param_name}: {param_value}")

In [None]:
# Inspect the Trainer's distributed/device strategy.
# `Trainer.training_type_plugin` was deprecated in Lightning 1.6 and later removed
# (this notebook uses the 2.x precision string '16-mixed'), so it raises
# AttributeError; `Trainer.strategy` is its replacement.
# NOTE: `trainer` here is the global left over from the learning-rate-finder loop.
# The Trainers built inside `objective` are local to that function.
trainer.strategy

# Push changes to GitHub

In [None]:
# Review what changed in the clone before committing.
!git status

In [None]:
# Commit all modified tracked files. `-a` does not add new, untracked files;
# `git add` those first.
!git commit -a -m "changes in colab"

In [None]:
# Push the checked-out branch over the SSH remote configured above.
!git push

In [None]:
# NOTE(review): `main` is not defined or imported in any cell of this notebook, so
# this raises NameError as written. It is presumably left over from a training
# script; import it from the module that defines it, or delete this cell.
main(cfg)

In [None]:
# Re-parse an empty argument list with the existing parser, so every option
# takes its default. Only `args` is rebound; `cfg` keeps the values built earlier.
args = parser.parser.parse_args('')

In [None]:
# Display the parsed argparse Namespace.
args

In [None]:
# Same parsed options as a plain dict.
vars(args)