# Training
This notebook allows interactive training of the Vision Transformer model for statistical downscaling using different loss functions.

Click the button below to run this notebook on Google Colab.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/relmonta/loss-bench/blob/main/training/train.ipynb)

In [None]:
# If you are using google colab clone the github repository

try:
    %cd /content
    import google.colab
    ! git clone https://github.com/relmonta/loss-bench.git
    %cd loss-bench
except:
    pass
# Assuming you are in the root directory of the repository
! pip install -r requirements.txt

In [None]:
try:
    import google.colab
    %cd /content/loss-bench
except:
    # Assuming you are in the root directory of the repository
    pass
import argparse
import os
import pickle
import shutil

import numpy as np
import torch
import yaml
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch import optim
from torch.utils.data import DataLoader

from data.data_module import DatasetSetup
from models.vision_transformer import VisionTransformer
from training.lightning_module import DownscalingModel
from models.losses import *
from training.utils import *

In [None]:
# Tell the user up front whether PyTorch can see a CUDA device.
print(
    "You are good to go !"
    if torch.cuda.is_available()
    else "You are not using a GPU. Please check your Google Colab execution setup"
)


### 1. User settings

In [None]:

# --- Experiment parameters: edit these before running the notebook ---

# Variable to downscale: "pr" or "uas".
var_name = "pr"

# Training loss, e.g. "mse", "mae", "ssim", or a combination such as "combo1".
criterion_name = "mse"

# Log-transform override: leave as None to keep the config value,
# or set to "true" / "false" to force it.
apply_log = None



### 2. Load configuration

In [None]:
def load_yaml(file_path):
    """Parse a YAML file with ``yaml.safe_load`` and return the result.

    Args:
        file_path: Path to the YAML file to read.

    Returns:
        The parsed Python object (a dict for mapping-style config files;
        ``None`` if the file is empty).
    """
    # Decode as UTF-8 explicitly so parsing does not depend on the platform's
    # locale encoding (e.g. cp1252 on Windows).
    with open(file_path, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)

MAIN_DIR = os.getcwd()  # assume notebook root is project root
exp_config = load_yaml(os.path.join(MAIN_DIR, 'configs', f'exp_config_{var_name}.yaml'))

# Training hyper-parameters come from the per-variable experiment config.
batch_size = exp_config['training']['batch_size']
accumulate_grad_batches = exp_config['training']['accumulate_grad_batches']
num_epochs = exp_config['training']['epochs']
learning_rate = exp_config['training']['learning_rate']

# os.cpu_count() returns None when the CPU count cannot be determined, which
# would make min() raise a TypeError; fall back to a single CPU in that case.
num_cpus = os.cpu_count() or 1
num_workers = min(exp_config['training']['num_workers'], num_cpus)

print(f"Training {var_name} with {criterion_name} for {num_epochs} epochs")


### 3. Dataset setup

In [None]:
# "asym" losses need gamma parameters computed by the dataset setup.
require_gamma_params = "asym" in criterion_name.lower()

# Alias of the config section that is passed to DatasetSetup below.
data_kwargs = exp_config['data']['kwargs_train_val']
if "nllbg" in criterion_name.lower():
    # presumably the Bernoulli-gamma NLL expects unscaled targets -- TODO confirm
    data_kwargs['normalize'] = False
    data_kwargs['standardize'] = False

if apply_log is not None:
    # Accept a bool or the strings "true"/"false" (case-insensitive); reject
    # anything else instead of silently treating it as False.
    if isinstance(apply_log, bool):
        log_override = apply_log
    else:
        log_flag = str(apply_log).strip().lower()
        if log_flag not in ("true", "false"):
            raise ValueError(
                f'apply_log must be None, True/False or "true"/"false", got {apply_log!r}')
        log_override = log_flag == "true"
    data_kwargs['apply_log'] = log_override

dss = DatasetSetup(exp_config, require_gamma_params=require_gamma_params)
dss.setup()

train_dataset = dss.get_train_ds()
val_dataset = dss.get_val_ds()

# `prefetch_factor` is only meaningful with worker processes; PyTorch raises a
# ValueError if it is given while num_workers == 0, so pass it conditionally.
worker_kwargs = {'prefetch_factor': 2} if num_workers > 0 else {}
train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          num_workers=num_workers, shuffle=True, **worker_kwargs)
val_loader = DataLoader(val_dataset, batch_size=batch_size,
                        num_workers=num_workers)



### 4. Model and loss

In [None]:
model_args = exp_config["model"]["params"]
# Enable the model's Bernoulli-gamma option only when training with an "nllbg" loss.
model_args["bernoulli_gamma"] = "nllbg" in criterion_name.lower()
# Fall back to the CPU when no GPU is visible instead of crashing in `.cuda()`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VisionTransformer(**model_args).to(device)

# Load YAML config containing loss details (losses_var.yaml)
loss_config = load_yaml(exp_config['training']['loss_config_path'])

# Build each tracked metric; its constructor kwargs come from the loss config,
# looked up by lower-cased name (an absent or empty entry means no kwargs).
metrics = {}
for metric_name in exp_config['training']['metrics']:
    metric_args = loss_config['losses'].get(metric_name.lower(), {}) or {}
    metrics[metric_name] = get_criterion(metric_name, **metric_args)

optimizer = optim.Adam(model.parameters(), lr=learning_rate,
                       weight_decay=exp_config['training']['weight_decay'])


In [None]:

# Load loss configuration and build the training criterion from it.
loss_config = load_yaml(exp_config['training']['loss_config_path'])

# Use one normalised key for every config lookup so mixed-case names such as
# "Combo1" resolve consistently.
criterion_key = criterion_name.lower()
# Copy the entry so that replacing its "losses" list below does not mutate the
# parsed config.
loss_args = dict(loss_config['losses'].get(criterion_key, {}) or {})
# Human-readable names are only used for printing; fall back to the raw name.
display_names = loss_config.get('display', {}) or {}

if criterion_key.startswith('combo'):
    # A combination entry lists its component loss names under "losses";
    # replace that list with {component name: its constructor kwargs}.
    wargs_dict = {loss: loss_config['losses'].get(loss, {}) or {}
                  for loss in loss_args['losses']}
    loss_args['losses'] = wargs_dict
    print(f"Losses args: {wargs_dict}")
    criterion = get_criterion("combination", **loss_args)
    print(
        f"Training using a combination of : {[display_names.get(loss, loss) for loss in wargs_dict]} losses")
else:
    # For single losses, use args directly
    criterion = get_criterion(criterion_name, **loss_args)
    print(f"Training using {display_names.get(criterion_key, criterion_name)} loss")

if require_gamma_params:
    # Get asym params from train dataset
    set_asym_params(criterion, train_dataset)


### 5. Experiment setup


In [None]:

# Run tag used in checkpoint and log names: the loss name, prefixed with
# "log_" when the training dataset applies a log transform.
description = ("log_" if train_dataset.apply_log_flag else "") + criterion_name

# Checkpoints are grouped per experiment under the configured weights directory.
weight_path = os.path.join(exp_config['training']['weights_path'], exp_config['name'])
os.makedirs(weight_path, exist_ok=True)

filename = f"weights-{description}"
ckpt_path = os.path.join(weight_path, f"{filename}.ckpt")

# Remove a checkpoint left by an earlier run with the same tag before training.
if not os.path.exists(ckpt_path):
    pass
else:
    os.remove(ckpt_path)
    print(f"Deleted old checkpoint: {ckpt_path}")


### 6. Model wrapper

In [None]:


# Bundle network, loss, optimizer and learning rate (positional, in that order)
# with the metric dict into the project's Lightning wrapper.
wrapper_parts = (model, criterion, optimizer, learning_rate)
downscaling_model = DownscalingModel(*wrapper_parts, metrics=metrics)



### 7. Logging & callbacks


In [None]:


# TensorBoard logs live under <repo>/training/logs/tensorboard/<experiment>/<run tag>.
tb_root = os.path.join(MAIN_DIR, 'training', 'logs', 'tensorboard')
tb_path = os.path.join(tb_root, exp_config['name'], description)
if os.path.isdir(tb_path):
    # Clear logs from a previous run with the same tag so curves are not mixed.
    # shutil avoids shelling out, which breaks on Windows and on paths containing
    # spaces or shell metacharacters.
    shutil.rmtree(tb_path)
    print(f"Deleted old logs: {tb_path}")

logger = TensorBoardLogger(
    save_dir=tb_root,
    name=exp_config['name'],
    version=description
)

# Keep only the single best checkpoint by validation loss.
checkpoint_callback = ModelCheckpoint(
    dirpath=weight_path,
    filename=filename,
    save_top_k=1,
    monitor='val_loss',
    mode='min'
)



### 8. Training


In [None]:

# Let float32 matmuls use faster reduced-precision internal math where supported.
torch.set_float32_matmul_precision('medium')

# The "ddp_find_unused_parameters_true" strategy relaunches the script in
# subprocesses, which Lightning rejects inside interactive sessions such as
# Jupyter/Colab. Use the default strategy on a single device, and the fork-based,
# notebook-compatible DDP variant (keeping find_unused_parameters=True) when
# several GPUs are visible.
# NOTE(review): fork-based DDP cannot start once CUDA is initialised in this
# kernel (e.g. by moving the model to the GPU earlier) -- verify on multi-GPU.
train_strategy = ('ddp_notebook_find_unused_parameters_true'
                  if torch.cuda.device_count() > 1 else 'auto')

trainer = Trainer(
    max_epochs=num_epochs,
    devices='auto',
    accelerator='auto',
    precision=32,
    callbacks=[checkpoint_callback],
    logger=logger,
    log_every_n_steps=1,
    strategy=train_strategy,
    accumulate_grad_batches=accumulate_grad_batches,
    num_sanity_val_steps=0,
    detect_anomaly=False
)

trainer.fit(downscaling_model, train_dataloaders=train_loader,
            val_dataloaders=[val_loader])