# Outline
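1. **Configuring the environment**: install and import the required modules, then create the working directory.
2. **Dataset**: download the PLAsTiCC training data, split it into train/validation sets, and build the PyTorch datasets and dataloaders.
3. **Model**: define a GRU-based light-curve classifier, train it with the PLAsTiCC weighted log-loss, and plot the training metrics and confusion matrix.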

# Configuring the environment

## Module installation
We’ll begin by installing the necessary Python modules for this tutorial.

In [None]:
import os

# - Install modules from requirements.txt if present
if os.path.isfile("requirements_plasticc.txt"):
  print("Installing modules from local requirements_plasticc.txt file ...")
  %pip install -q -r requirements_plasticc.txt
else:
  print("Installing modules ...")  

  %pip install -q pandas                                         # Data analysis modules                     
  %pip install -q torch torchvision torchmetrics torchinfo    # ML modules
  ##%pip install -q torch torchvision torchmetrics torchsummary    # ML modules
  %pip install -q sh gdown matplotlib tqdm                          # Plot/util modules
    
  # - Create requirements file
  %pip freeze > requirements_plasticc.txt

## Import modules
Next, we import the essential modules needed throughout the tutorial.

In [None]:
###########################
##   STANDARD MODULES
###########################
import os
import math
from pathlib import Path
import shutil
import gdown
import tarfile
import numpy as np
import json
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from itertools import islice
import random
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
import urllib.request
from sh import gunzip

###########################
##   DATA/TORCH MODULES
###########################
# - Data analysis
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

# - Torch modules
import torch
from torch import Tensor
from torch.utils.data import Dataset, Subset, random_split, DataLoader
import torch.nn.functional as F
import torchvision
from torchvision.datasets.vision import VisionDataset
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torchvision.transforms import ToTensor
import torchmetrics
import torchinfo
from torchinfo import summary
#from torchsummary import summary

## Project folders
We create a working directory `rundir` to run the tutorial in.

In [None]:
# - Working directory of the tutorial: <current dir>/run-plasticc_classifier
#   (created if missing, reused if it already exists)
topdir= os.getcwd()
path= Path(topdir) / "run-plasticc_classifier"
rundir= str(path)
path.mkdir(exist_ok=True, parents=True)

# Dataset
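For this tutorial, we will use the [**PLAsTiCC dataset**](https://zenodo.org/records/2539456) (Photometric LSST Astronomical Time-series Classification Challenge).

It contains simulated light curves of astronomical transient and variable sources observed in the six LSST passbands (*u, g, r, i, z, y*). The training set is made of two tables:
- a **light-curve table** with one row per observation (`object_id`, `mjd`, `passband`, `flux`, `flux_err`, `detected_bool`);
- a **metadata table** with one row per object (host-galaxy photometric and spectroscopic redshift, Galactic extinction `mwebv`, the `ddf_bool` survey flag, and the class label `target`).

Objects are labelled with class ids such as 42 (SN-II), 64 (kilonova), 88 (AGN) or 90 (SN-Ia). Class 99 (`OTHER`) is reserved for objects that do not belong to any of the known classes.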

## Dataset Download
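Next, we download the gzipped training metadata and light-curve CSV files from Zenodo and unzip them in the working directory `rundir`. Files that are already present are not downloaded again.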

In [None]:
def download_data(url, data_path, destdir):
  """ Download a gzipped file into `destdir` and decompress it there.

    Parameters:
      url      : URL of the .gz file
      data_path: file name of the .gz archive (only its basename is used)
      destdir  : destination directory (created if missing)

    Returns:
      Full path of the decompressed file (archive path without the '.gz' suffix).
      As with the `gunzip` command, the .gz archive is removed after decompression.

    Raises:
      ValueError: if data_path does not end with '.gz'
  """
  import gzip

  if not data_path.endswith('.gz'):
    raise ValueError("Expected a .gz archive name, got %s" % (data_path))

  os.makedirs(destdir, exist_ok=True)
  data_fullpath= os.path.join(destdir, os.path.basename(data_path))
  unzipped_fullpath= data_fullpath[:-len('.gz')]

  # - Download directly into destdir: downloading to the cwd and then moving the file
  #   failed with shutil.Error whenever a stale archive already existed in destdir
  print("Downloading file from url %s ..." % (url))
  urllib.request.urlretrieve(url, data_fullpath)
  print("DONE!")

  # - Decompress with the standard library (no external `gunzip` binary needed)
  print("Unzipping dataset file %s ..." % (data_fullpath))
  with gzip.open(data_fullpath, 'rb') as f_in, open(unzipped_fullpath, 'wb') as f_out:
    shutil.copyfileobj(f_in, f_out)
  os.remove(data_fullpath)

  return unzipped_fullpath


# - Download train metadata (skipped if the unzipped CSV is already in rundir)
train_metadata_url= "https://zenodo.org/records/2539456/files/plasticc_train_metadata.csv.gz?download=1"
train_metadata_gz_path= 'plasticc_train_metadata.csv.gz'
train_metadata_gz_fullpath= os.path.join(rundir, train_metadata_gz_path)
train_metadata_fullpath= os.path.join(rundir, 'plasticc_train_metadata.csv')
if not os.path.isfile(train_metadata_fullpath):
  download_data(train_metadata_url, train_metadata_gz_path, rundir)

# - Download train data (light curves; skipped if the unzipped CSV is already in rundir)
train_data_url= "https://zenodo.org/records/2539456/files/plasticc_train_lightcurves.csv.gz?download=1"
train_data_gz_path= 'plasticc_train_lightcurves.csv.gz'
train_data_gz_fullpath= os.path.join(rundir, train_data_gz_path)
train_data_fullpath= os.path.join(rundir, 'plasticc_train_lightcurves.csv')
if not os.path.isfile(train_data_fullpath):
  download_data(train_data_url, train_data_gz_path, rundir)

# - Sanity check: both unzipped CSV files must now be available
assert os.path.isfile(train_metadata_fullpath), f"Missing file {train_metadata_fullpath}"
assert os.path.isfile(train_data_fullpath), f"Missing file {train_data_fullpath}"

## Dataset loading

Define class names and other variables.

In [None]:
# - PLAsTiCC class ids (99 = 'OTHER') and the matching 'class_<id>' names
classes = np.array([6, 15, 16, 42, 52, 53, 62, 64, 65, 67, 88, 90, 92, 95, 99], dtype='int32')
class_names = [f"class_{class_id}" for class_id in classes]

# - Per-class weights: 1 for every class, except 2 for classes 15 and 64
class_weight = dict.fromkeys(classes.tolist(), 1)
class_weight.update({15: 2, 64: 2})

# LSST passbands (nm)  u    g    r    i    z    y      
passbands = np.array([357, 477, 621, 754, 871, 1004], dtype='float32')

### Loading train metadata

In [None]:
# - Read train metadata as panda data frame
train_metadata= pd.read_csv(train_metadata_fullpath)
print("--> Train metadata")
print(train_metadata)

### Loading train data

In [None]:
# - Load data
print(f"Loading train data from file {train_data_fullpath} ...")
train_data = pd.read_csv(train_data_fullpath)
print("train_data")
print(train_data)

### Splitting train/val sets
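Let's reserve a small portion (10%) of the training dataset for validation. Below, we split the training metadata into train and validation data frames, stratified by class so that rare classes also appear in the validation set. We then select the light-curve observations of each split by `object_id`.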

In [None]:
# - Split metadata in train/val sets
random_state= 42
test_size= 0.1
meta_df_train, meta_df_val = train_test_split(
  train_metadata,
  test_size=test_size,
  random_state=random_state,
  shuffle=True
)

# - Get object_ids of train/test splits
train_object_ids = meta_df_train['object_id'].unique()
val_object_ids = meta_df_val['object_id'].unique()

# - Use object IDs to split data_df
data_df_train = train_data[train_data['object_id'].isin(train_object_ids)]
data_df_val  = train_data[train_data['object_id'].isin(val_object_ids)]

print(f"#{len(meta_df_train)}/{len(meta_df_val)} data entries in train/val sets ...")

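Compute the class frequency table (`wtable`) of the training split, with a fixed entry of 1 for the `OTHER` class. It is printed below next to the table computed by the PyTorch dataset, as a consistency check.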

In [None]:
def get_wtable(df, classes):
  """ Compute the fraction of rows of `df` belonging to each class.

    Parameters:
      df     : pandas DataFrame with an integer 'target' column (class ids)
      classes: sequence of class ids; the LAST one is the 'OTHER' class, whose
               entry is fixed to 1.0 (it is not counted from df)

    Returns:
      list of float with len(classes) entries, in the order of `classes`:
      count(class)/len(df) for each class (0.0 for classes absent from df, and
      for all classes when df is empty), followed by 1.0 for 'OTHER'.
  """
  all_y = np.asarray(df['target'], dtype='int32')
  nsamples = all_y.shape[0]

  # - Count each requested class explicitly, so that a class missing from df gets a
  #   0.0 entry instead of shifting every following entry (as np.unique counts would)
  counts = [int(np.count_nonzero(all_y == class_id)) for class_id in classes[:-1]]
  wtable = [(count / nsamples if nsamples > 0 else 0.0) for count in counts]
  wtable.append(1.0) # Add weights for 'OTHER' class

  return wtable

# - Class frequency table of the training split (last entry: 'OTHER' class)
wtable = get_wtable(df=meta_df_train, classes=classes)
print("--> wtable", wtable, sep="\n")

### Create PyTorch datasets

Define PyTorch dataset

In [None]:
def pad_sequences_numpy(sequences, maxlen=None, dtype='float32', padding='post', truncating='post', value=0.0):
  """
    Pads a list of 2D numpy arrays (sequence_len_i, n_features) to shape (N, maxlen, n_features).
    
    Parameters:
        sequences : list of np.ndarray of shape (Ti, D); all must share the same D
        maxlen    : int or None, length to pad/truncate to. If None, use max sequence length.
        dtype     : data type of output array
        padding   : 'pre' (pad at the start) or 'post' (pad at the end); any other value acts as 'post'
        truncating: 'pre' (drop leading steps) or 'post' (drop trailing steps); any other value acts as 'post'
        value     : value used for padding

    Returns:
        np.ndarray of shape (N, maxlen, D).
        An empty `sequences` list gives an array of shape (0, maxlen or 0, 0).
  """
  num_samples = len(sequences)
  if num_samples == 0:
    return np.full((0, maxlen or 0, 0), value, dtype=dtype)

  feature_dim = sequences[0].shape[1]
  if maxlen is None:
    maxlen = max(seq.shape[0] for seq in sequences)

  padded = np.full((num_samples, maxlen, feature_dim), value, dtype=dtype)

  for i, seq in enumerate(sequences):
    # - Truncate. The start index is computed explicitly because seq[-0:] would keep
    #   the whole sequence when maxlen == 0.
    if truncating == 'pre':
      trunc = seq[max(seq.shape[0] - maxlen, 0):]
    else:
      trunc = seq[:maxlen]

    # - Pad. The start index is maxlen - n rather than -n, because padded[i, -0:]
    #   would address the whole row when the (truncated) sequence is empty.
    n = trunc.shape[0]
    if padding == 'pre':
      padded[i, maxlen - n:] = trunc
    else:
      padded[i, :n] = trunc

  return padded

class PlasticcDataset(Dataset):
  """
    PyTorch dataset built from PLAsTiCC light-curve and metadata tables.

    Every selected object (object_id) becomes one sample (x_seq, x_meta, y):
      - x_seq : float32 tensor (seq_len, 4), one row per observation with
                [time since previous observation (rest frame, units of 100 days),
                 flux / (max flux - min flux),
                 flux_err / (max flux - min flux),
                 rest-frame passband wavelength (um)],
                zero-padded at the end to the smallest power of 2 >= the longest light curve
      - x_meta: float32 tensor (5,) = [ddf_bool, z, z_err, mwebv, log2(max flux - min flux) / 10]
      - y     : one-hot int64 tensor (num_classes,)

    Parameters:
      data_df      : pandas DataFrame of observations with columns object_id, mjd,
                     passband, flux, flux_err
      meta_df      : pandas DataFrame of per-object metadata with columns object_id,
                     hostgal_photoz, hostgal_photoz_err, hostgal_specz, ddf_bool, mwebv
                     and, optionally, target (without it every object gets class 99)
      augment      : stored in self.augment; no augmentation is implemented yet
      use_specz    : if True use hostgal_specz (with zero error) instead of hostgal_photoz
      extragalactic: True  -> skip objects with hostgal_photoz == 0,
                     False -> skip objects with hostgal_photoz > 0,
                     None  -> keep all objects
      nmax         : stop after nmax objects have been loaded (-1 = no limit)
  """

  def __init__(
    self,
    data_df,     # Pandas data frame
    meta_df, # Pandas data frame
    augment=False,
    use_specz=False,
    extragalactic=None,  
    nmax=-1
  ):
    # - Set options
    self.meta_df= meta_df
    self.data_df= data_df
    self.augment= augment  # stored only: augmentation is not implemented yet
    self.X_seq = None    # shape: (N, seq_len, 4)
    self.X_meta = None   # shape: (N, 5)
    self.Y= None  # one-hot labels: (N, num_classes)
    self.wtable= None
    self.class_weights= None
    self.nmax= nmax
    self.use_specz= use_specz
    self.extragalactic= extragalactic
    self.classes = np.array([6, 15, 16, 42, 52, 53, 62, 64, 65, 67, 88, 90, 92, 95, 99], dtype='int32')
    self.class_names = ['class_6','class_15','class_16','class_42','class_52','class_53','class_62','class_64','class_65','class_67','class_88','class_90','class_92','class_95','class_99']
    self.classid2label= {
      6: "PS-MULENS",# Point_source_mu-lensing
      15: "TDE", # Tidal disruption event
      16: "EBE", # Eclipsing binary event
      42: "SN-II", # Core-collapse supernova Type II
      52: "SN-Iax", # Supernova Type Ia-x
      53: "MIRA", # Mira variable
      62: "SN-Ibc", # Core-collapse supernova Type Ibc
      64: "KN", # Kilonova
      65: "M-DWARF", # M dwarf
      67: "SN-Ia-91bg", # Supernova Type Ia-91bg
      88: "AGN", # Active galactic nucleus
      90: "SN-Ia", # Supernova Type Ia
      92: "RR-LY", # RR Lyrae
      95: "SLSN", # Superluminous supernova
      99: "OTHER", # Other class
    } 
    self.nclasses= len(self.classes)
    # - Per-class loss weights, aligned with self.classes
    self.class_weight_factors= np.array([2,2,1,1,1,1,1,2,1,1,1,1,1,2,2], dtype='float32')
    
    # LSST passbands (nm)  u    g    r    i    z    y      
    self.passbands = np.array([357, 477, 621, 754, 871, 1004], dtype='float32')
    
    # - Load data
    self.__load_data()
    
  def __compute_wtable(self):
    """
      Compute:
        - wtable: class frequencies (N_class,), with the last ('OTHER') entry forced to 1
        - class_weights: inverse frequency, rescaled so that the weights average to 1

      Returns:
        wtable (torch.Tensor), class_weights (torch.Tensor)
    """
    class_counts = self.Y.sum(dim=0)  # sum over all samples (across rows)
    total_samples = self.Y.shape[0]
    wtable = class_counts / total_samples
    wtable[self.nclasses-1]= 1.0
    
    # Inverse frequency as weight (avoid divide-by-zero)
    class_weights = 1.0 / (wtable + 1e-8)
    class_weights = class_weights * (len(wtable) / class_weights.sum())  # normalize to mean 1

    return wtable, class_weights

  def __extract_features(self, obj_df, meta):
    """
      Build the features of a single object.

      Parameters:
        obj_df: rows of data_df belonging to one object_id
        meta  : metadata row of the same object (pandas Series indexed by column name)

      Returns:
        (features_seq, features_meta, target_id), or None if the object is rejected
        by the `extragalactic` selection.

      Raises:
        ValueError: if the object's target is not one of self.classes
    """
    z_photo= meta['hostgal_photoz']         # photometric host-redshift
    zerr_photo= meta['hostgal_photoz_err']  # uncertainty on photometric host-redshift
    ddf= meta['ddf_bool']                   # boolean flag: 1 for DDF, 0 for WFD
    mwebv= meta['mwebv']                    # Galactic E(B-V) extinction
    z_spec= meta['hostgal_specz']           # accurate spectroscopic-redshift for small subset
    z= z_spec if self.use_specz else z_photo
    z_err= 0.0 if self.use_specz else zerr_photo

    # - Skip source with invalid redshift?
    if self.extragalactic == True and z_photo == 0:
      return None
    if self.extragalactic == False and z_photo > 0:
      return None

    # - Set target id (index into self.classes); objects without a 'target' column
    #   are assigned to the last class (99)
    if 'target' in meta:
      class_id= int(meta['target'])
      matches= np.flatnonzero(self.classes == class_id)
      if matches.size == 0:
        raise ValueError(f"Unknown target class {class_id} (expected one of {self.classes.tolist()})")
      target_id= int(matches[0])
    else:
      target_id= self.nclasses - 1

    # - Time axis: subtract the first epoch in float64 (MJD values ~6e4 lose sub-hour
    #   precision in float32), then express it in units of 100 days in the source frame
    mjd= np.asarray(obj_df['mjd'], dtype='float64')
    mjd= ((mjd - mjd[0]) / 100. / (z + 1)).astype('float32')
    tdiff= np.ediff1d(mjd, to_begin=[0])

    band= np.asarray(obj_df['passband'], dtype='int32')
    flux= np.asarray(obj_df['flux'], dtype='float32')
    flux_err= np.asarray(obj_df['flux_err'], dtype='float32')

    # - Normalize fluxes by the light-curve dynamic range. A flat light curve (e.g. a
    #   single observation) has a zero range, which would make log2() fail and the
    #   division produce inf/NaN: fall back to a unit range in that case.
    flux_norm= float(np.max(flux) - np.min(flux))
    if not flux_norm > 0:
      flux_norm= 1.0
    flux_pow= math.log2(flux_norm)

    # - Rest-frame wavelength of each observation, nm -> um
    source_wavelength= self.passbands[band] / (z + 1) / 1000.

    features_seq= np.stack(
      [tdiff, flux / flux_norm, flux_err / flux_norm, source_wavelength],
      axis=1
    ).astype('float32')

    features_meta= np.array([ddf, z, z_err, mwebv, flux_pow / 10], dtype='float32')

    return features_seq, features_meta, target_id

  def __load_data(self):
    """ Build X_seq, X_meta, Y, wtable and class_weights from data_df and meta_df """

    # - Index metadata by object_id once (first row per object). A boolean mask over
    #   meta_df for every object made loading quadratic in the number of objects.
    meta_by_id= self.meta_df.drop_duplicates(subset='object_id').set_index('object_id')

    # - Group data by object_id
    groups = self.data_df.groupby('object_id')
    print(f"Reading {len(groups)} data entries ...")

    features_seq_all= []
    features_meta_all= []
    target_ids= []

    for obj_id, obj_df in groups:
      sample= self.__extract_features(obj_df, meta_by_id.loc[obj_id])
      if sample is None:
        continue

      features_seq, features_meta, target_id= sample
      features_seq_all.append(features_seq)
      features_meta_all.append(features_meta)
      target_ids.append(target_id)

      if len(features_seq_all) % 1000 == 0:
        print('Converting data {0}'.format(len(features_seq_all)), end='\r')

      if self.nmax!=-1 and len(features_seq_all) >= self.nmax:
        print(f'Reached data sample limit {self.nmax}...stop reading data')  
        break

    if not features_seq_all:
      raise ValueError("No data entries selected: cannot build the dataset (check extragalactic/nmax options)")

    seq_sizes= [seq.shape[0] for seq in features_seq_all]
    seq_min_size, seq_max_size= min(seq_sizes), max(seq_sizes)
    print(f"#{len(features_seq_all)} data added with seq range ({seq_min_size}, {seq_max_size})...")

    # - Find sequence truncation point (assuming a power of 2 larger than max seq size) 
    seq_size_opt= 2 ** math.ceil(math.log2(seq_max_size))
    print(f"Padding sequence data to a size of {seq_size_opt} ...")
    features_seq_padded= pad_sequences_numpy(features_seq_all, maxlen=seq_size_opt)
    print("features_seq_padded.shape")
    print(features_seq_padded.shape)

    # - Convert data to tensors (one_hot requires int64 class indices)
    self.X_seq= torch.from_numpy(features_seq_padded)
    self.X_meta= torch.from_numpy(np.array(features_meta_all))
    Y_target= torch.from_numpy(np.asarray(target_ids, dtype=np.int64))
    self.Y= F.one_hot(Y_target, num_classes=self.nclasses)

    print("self.X_seq.shape")
    print(self.X_seq.shape)
    print("self.X_meta.shape")
    print(self.X_meta.shape)
    print("self.Y.shape")
    print(self.Y.shape)

    # - Compute wtable
    self.wtable, self.class_weights= self.__compute_wtable()
    print("self.wtable")
    print(self.wtable)
    print("self.class_weights")
    print(self.class_weights)

  def __len__(self):
    return len(self.X_seq)

  def __getitem__(self, idx):
    return self.X_seq[idx], self.X_meta[idx], self.Y[idx]

Create train dataset

In [None]:
# - Build the training dataset: all objects, photometric redshifts, no sample limit
dataset_train= PlasticcDataset(
  data_df=data_df_train,
  meta_df=meta_df_train,
  augment=False,
  use_specz=False,
  extragalactic=None,
  nmax=-1,
)

# - Cross-check the class frequencies from get_wtable against those of the dataset
print("--> wtable", wtable, sep="\n")
print("--> wtable (dataset)", dataset_train.wtable, sep="\n")

Load validation dataset

In [None]:
# - Build the validation dataset with the same options as the training one
dataset_val= PlasticcDataset(
  data_df=data_df_val,
  meta_df=meta_df_val,
  augment=False,
  use_specz=False,
  extragalactic=None,
  nmax=-1,
)

### Augment training data

In [None]:
# ...
# ...

### Create dataloaders

In [None]:
batch_size= 64
# - Shuffle only the training batches; keep a fixed validation order
dl_train= DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=True)
dl_val= DataLoader(dataset=dataset_val, batch_size=batch_size, shuffle=False)

## Model

### Define the classifier model
Define the RNN model. It takes two inputs:

- `seq`: Time series, shape (batch_size, time_steps, n_features)
- `meta`: Meta data parameters, shape (batch_size, n_meta_features)

In [None]:
class PlasticcClassifier(torch.nn.Module):
  """
    GRU-based light-curve classifier with a metadata side input.

    The time series goes through a (optionally bidirectional) stacked GRU whose outputs
    are max-pooled over time. The pooled vector is concatenated with the metadata
    features and fed to a 2-layer MLP head, whose logits are turned into class
    probabilities with a softmax.

    Parameters:
      n_features     : number of features per time step of x_seq
      n_meta_features: number of metadata features of x_meta
      hidden_size    : GRU hidden size (per direction)
      num_layers     : number of stacked GRU layers
      num_classes    : number of output classes
      bidirectional  : use a bidirectional GRU
      dropout        : dropout probability, used between GRU layers (only when
                       num_layers > 1) and inside the MLP head
      fc_hidden_size : width of the hidden layer of the MLP head

    NOTE(review): forward() runs the GRU and the max-pool over every time step of x_seq;
    if x_seq is zero-padded, the padded steps also contribute to the pooled features.
  """

  def __init__(
    self, 
    n_features, 
    n_meta_features, 
    hidden_size=64, 
    num_layers=2, 
    num_classes=15, 
    bidirectional=False, 
    dropout=0.5,
    fc_hidden_size=128  
  ):
    super().__init__()

    self.num_directions = 2 if bidirectional else 1
    # GRU inter-layer dropout only makes sense with more than one layer
    gru_dropout = 0.0 if num_layers <= 1 else dropout
    rnn_out_size = hidden_size * self.num_directions

    self.gru = torch.nn.GRU(
      input_size=n_features,
      hidden_size=hidden_size,
      num_layers=num_layers,
      batch_first=True,
      bidirectional=bidirectional,
      dropout=gru_dropout,
    )

    # NOTE(review): registered but not applied in forward()
    self.dropout = torch.nn.Dropout(dropout)

    head_layers = [
      torch.nn.Linear(rnn_out_size + n_meta_features, fc_hidden_size),
      torch.nn.ReLU(),
      torch.nn.Dropout(dropout),
      torch.nn.Linear(fc_hidden_size, num_classes),
    ]
    self.classifier = torch.nn.Sequential(*head_layers)

  def forward(self, x_seq, x_meta):
    """
      x_seq : (batch, seq_len, n_features)
      x_meta: (batch, n_meta_features)

      Returns class probabilities (softmax over dim 1), shape (batch, num_classes).
    """
    gru_out = self.gru(x_seq)[0]              # (batch, seq_len, hidden * num_directions)
    pooled = gru_out.max(dim=1).values        # max over time: (batch, hidden * num_directions)
    features = torch.cat((pooled, x_meta), dim=1)
    return F.softmax(self.classifier(features), dim=1)

Create classifier instance

In [None]:
# - Create model
n_features= dataset_train.X_seq.shape[2]
n_meta_features= dataset_train.X_meta.shape[1]
num_classes= dataset_train.Y.shape[1]
seq_length= dataset_train.X_seq.shape[1]
print(f"seq_length={seq_length}")
print(f"n_features={n_features}")
print(f"n_meta_features={n_meta_features}")
print(f"num_classes={num_classes}")

model = PlasticcClassifier(
  n_features=n_features,
  n_meta_features=n_meta_features,
  num_classes=num_classes,
  fc_hidden_size=128,
  bidirectional=True  
)

# - Print model structure
summary(
  model, 
  #input_data=[
  #  torch.randn(batch_size, seq_length, n_features),  # x_seq: (batch, seq_len, n_features)
  #  torch.randn(batch_size, n_meta_features)       # x_meta: (batch, n_meta_features)
  #]
  input_size=[(batch_size, seq_length, n_features), (batch_size, n_meta_features)] 
)

### Model training

Define the loss function: the class-weighted multi-class logarithmic loss used as the PLAsTiCC challenge metric.

In [None]:
def multi_weighted_logloss(y_true, y_pred, class_weights):
  """
    Weighted multi-class log-loss for the PLAsTiCC challenge.
    Adapted from TF version: https://www.kaggle.com/ogrellier

      loss = - sum_j [ w_j / N_j * sum_{i in class j} log(p_ij) ] / sum_j w_j

    where N_j is the number of samples of class j in y_true.

    Parameters:
      y_true       : one-hot encoded targets, shape (N, M)
      y_pred       : predicted class probabilities, shape (N, M)
      class_weights: per-class weights w_j, shape (M,)
                     (e.g. [2,2,1,1,1,1,1,2,1,1,1,1,1,2,2], see challenge result paper)

    Returns:
      scalar tensor. Classes absent from y_true contribute 0 to the numerator, but
      their weight still counts in the normalization.
  """
  # - Limit y_pred to [eps, 1-eps] so that log() stays finite
  eps = 1e-15
  preds = torch.clamp(y_pred, min=eps, max=1 - eps)

  # - Transform the CLIPPED predictions to log (log of a 0 probability is -inf -> NaN loss)
  log_preds = torch.log(preds)  # (N, M)

  # (1) Per-class sum of the log-probabilities of the true class, across samples
  y_log_ones = torch.sum(y_true * log_preds, dim=0)              # (M,)
  # (2) Number of positives per class (set to 1 for absent classes, whose sum is 0)
  nb_pos = torch.sum(y_true, dim=0)                              # (M,)
  nb_pos = torch.where(nb_pos == 0, torch.ones_like(nb_pos), nb_pos)

  # (3) Per-class weighted mean log-probability
  y_w = y_log_ones * class_weights / nb_pos                       # (M,)

  # (4) Weighted average over classes. Each class term is already a per-sample mean
  #     (divided by nb_pos), so no extra division by the batch size is applied.
  loss = -torch.sum(y_w) / torch.sum(class_weights)
  return loss

Define a function to initialize weights before training.

In [None]:
def initialize_weights(model):
  """
    Re-initialize, in place, the parameters of every Linear and GRU submodule of `model`.

    - torch.nn.Linear: Xavier-uniform weights, zero bias
    - torch.nn.GRU   : Xavier-uniform input-to-hidden weights, orthogonal
                       hidden-to-hidden weights, zero biases
    Modules of any other type are left untouched.
  """
  def _init_linear(layer):
    # Xavier-uniform weights; zero bias when the layer has one
    torch.nn.init.xavier_uniform_(layer.weight)  # or kaiming_uniform_
    if layer.bias is not None:
      torch.nn.init.zeros_(layer.bias)

  def _init_gru(layer):
    # Parameter names look like weight_ih_l0, weight_hh_l0, bias_ih_l0, ... (+ '_reverse')
    for pname, tensor in layer.named_parameters():
      if 'weight_ih' in pname:
        torch.nn.init.xavier_uniform_(tensor.data)
      elif 'weight_hh' in pname:
        torch.nn.init.orthogonal_(tensor.data)
      elif 'bias' in pname:
        torch.nn.init.zeros_(tensor.data)

  for module in model.modules():
    if isinstance(module, torch.nn.Linear):
      _init_linear(module)
    elif isinstance(module, torch.nn.GRU):
      _init_gru(module)

Define the training loop

In [None]:
class AverageMeter:
  """ Running average of a scalar, where each update is weighted by a sample count """

  def __init__(self):
    self.reset()

  def reset(self):
    """ Forget every accumulated value """
    self.sum, self.count = 0, 0

  def update(self, value, n=1):
    """ Accumulate `value` as observed over `n` samples (e.g. a batch mean over n items) """
    self.count += n
    self.sum += value * n

  @property
  def avg(self):
    """ Weighted mean of the accumulated values (0 if nothing was accumulated) """
    if self.count <= 0:
      return 0
    return self.sum / self.count


def train_model(
    model,
    dl_train,
    dl_val=None,
    num_epochs=1,
    lr=1e-3,
    checkpoint_path="model_checkpoint.pth",
    class_weights=None,
    clip_grad=False,
    max_grad_norm=5,
    num_classes=None
):
  """
    Train `model` with Adam and the weighted multi-class log-loss (multi_weighted_logloss).

    Parameters:
      model          : torch.nn.Module called as model(x_seq, x_meta); its output is passed
                       as y_pred to multi_weighted_logloss (presumably class probabilities)
      dl_train       : DataLoader yielding (x_seq, x_meta, y) batches, y one-hot encoded
      dl_val         : optional validation DataLoader with the same batch layout
      num_epochs     : number of training epochs
      lr             : Adam learning rate
      checkpoint_path: file where the final model state_dict is saved
      class_weights  : 1D tensor with one loss weight per class (required)
      clip_grad      : if True, clip the global gradient norm to max_grad_norm
      max_grad_norm  : maximum gradient norm used when clip_grad is True
      num_classes    : number of classes for the metrics (default: len(class_weights))

    Returns:
      dict with per-epoch lists 'loss_train', 'acc_train', 'f1score_train' (and 'loss_val',
      'acc_val', 'f1score_val' when dl_val is given), plus the confusion matrix of the
      last epoch ('cm_train'/'cm_val', numpy arrays) and the corresponding torchmetrics
      objects ('cm_metric_train'/'cm_metric_val').

    Raises:
      ValueError: if class_weights is None
  """
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  model.to(device)

  # - Check class weights  
  if class_weights is None:
    raise ValueError("class_weights must be provided")

  class_weights = class_weights.to(device)

  # - Number of classes for the metrics: one per class weight unless given explicitly
  #   (previously read from a notebook-level `num_classes` variable)
  if num_classes is None:
    num_classes = int(class_weights.numel())

  optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
  # - Init metrics
  loss_meter = AverageMeter()
  acc_metric = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes).to(device)
  f1_metric = torchmetrics.F1Score(task="multiclass", num_classes=num_classes, average="macro").to(device)
  confusion_matrix_metric = torchmetrics.ConfusionMatrix(task="multiclass", num_classes=num_classes, normalize="true").to(device)
    
  if dl_val is not None:
    val_loss_meter = AverageMeter()    
    val_acc_metric = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes).to(device)
    val_f1_metric = torchmetrics.F1Score(task="multiclass", num_classes=num_classes, average="macro").to(device)
    val_confusion_matrix_metric = torchmetrics.ConfusionMatrix(task="multiclass", num_classes=num_classes, normalize="true").to(device)
    
  history = {
    "loss_train": [],
    "acc_train": [],
    "f1score_train": [],
    "cm_train": None,
    "cm_metric_train": None,  
    "loss_val": [],
    "acc_val": [],
    "f1score_val": [],
    "cm_val": None,
    "cm_metric_val": None,   
  }

  # - Start training loop
  for epoch in range(num_epochs):
    model.train()
    
    # - Reset avg metrics
    progress = tqdm(dl_train, desc=f"Epoch {epoch+1} [Train]", leave=False)
    loss_meter.reset()
    acc_metric.reset()
    f1_metric.reset()
    confusion_matrix_metric.reset()
    
    total_loss = 0.0
    n_batches = 0   # batches actually used for an optimizer step
    y_true_all = []
    y_pred_all = []

    # - Run batch loop
    for x_seq, x_meta, y in progress:    
      x_seq, x_meta, y = x_seq.to(device), x_meta.to(device), y.to(device)

      optimizer.zero_grad()
      y_pred = model(x_seq, x_meta)

      loss = multi_weighted_logloss(y, y_pred, class_weights)
      
      # ✅ Check 1: Loss is finite (not NaN or Inf). Back-propagating a non-finite loss
      #    would write NaN/Inf into the weights, so the batch is really skipped.
      if not torch.isfinite(loss):
        print("⚠️ Warning: loss is NaN or Inf. Skipping this batch.")
        continue
        
      loss.backward()
    
      # ✅ Check 2: report unusually large global (L2) gradient norms, for debugging
      total_norm = 0.0
      for p in model.parameters():
        if p.grad is not None:
          total_norm += p.grad.detach().norm(2).item() ** 2
      total_norm = total_norm ** 0.5
      if total_norm > 1e3:
        print(f"⚠️ High gradient norm: {total_norm:.2f}")

      # ✅ Check 3: Clip gradients to prevent explosion
      if clip_grad:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
    
      optimizer.step() 
        
      # - Update loss and metrics (one-hot targets / per-class outputs -> class indices)
      target_pred = y_pred.argmax(dim=1)
      target_true = y.argmax(dim=1)
      loss_meter.update(loss.item(), x_seq.size(0))
      acc_metric.update(target_pred, target_true)
      f1_metric.update(target_pred, target_true)
      confusion_matrix_metric.update(target_pred, target_true)  
        
      total_loss += loss.item()
      n_batches += 1
      y_true_all.append(target_true.cpu())
      y_pred_all.append(target_pred.cpu())  
    
      # - Update progress bar
      progress.set_postfix({
        "loss": f"{loss_meter.avg:.4f}",
        "acc": f"{acc_metric.compute().item():.4f}",
        "f1": f"{f1_metric.compute().item():.4f}"
      })  

    # - Compute average metrics (v1: mean batch loss and sklearn metrics on collected predictions)
    avg_train_loss = total_loss / max(n_batches, 1)
    if y_true_all:
      y_true_all = torch.cat(y_true_all)
      y_pred_all = torch.cat(y_pred_all)
      train_acc = (y_true_all == y_pred_all).float().mean().item()
      train_f1 = f1_score(y_true_all, y_pred_all, average='macro', zero_division=0)
    else:
      train_acc = train_f1 = 0.0

    # - Compute average metrics (v2: sample-weighted loss and torchmetrics)
    avg_train_loss_v2 = loss_meter.avg
    train_acc_v2 = acc_metric.compute().item()
    train_f1_v2 = f1_metric.compute().item()
    confusion_matrix= confusion_matrix_metric.compute().cpu().numpy()
    history["loss_train"].append(avg_train_loss_v2)
    history["acc_train"].append(train_acc_v2)
    history["f1score_train"].append(train_f1_v2) 
    history["cm_train"]= confusion_matrix
    history["cm_metric_train"]= confusion_matrix_metric
        
    print(f"Epoch [{epoch+1}/{num_epochs}]: loss={avg_train_loss:.4f}, {avg_train_loss_v2:.4f} | acc={train_acc:.4f}, {train_acc_v2:.4f} | f1={train_f1:.4f}, {train_f1_v2:.4f}", end='')

    if dl_val is not None:
      model.eval()
    
      # - Init val metrics
      val_loss_meter.reset()
      val_acc_metric.reset()
      val_f1_metric.reset()
      val_confusion_matrix_metric.reset()
      val_progress = tqdm(dl_val, desc=f"Epoch {epoch+1} [Val]", leave=False)
      val_loss = 0.0
      y_true_val = []
      y_pred_val = []

      with torch.no_grad():
        for x_seq, x_meta, y in val_progress:    
          x_seq, x_meta, y = x_seq.to(device), x_meta.to(device), y.to(device)
          y_pred = model(x_seq, x_meta)

          loss = multi_weighted_logloss(y, y_pred, class_weights)
        
          # - Update loss and accuracy  
          val_loss_meter.update(loss.item(), x_seq.size(0))
          target_pred = y_pred.argmax(dim=1)
          target_true = y.argmax(dim=1)  
          val_acc_metric.update(target_pred, target_true)
          val_f1_metric.update(target_pred, target_true)
          val_confusion_matrix_metric.update(target_pred, target_true) 
          val_loss += loss.item()

          y_true_val.append(target_true.cpu())
          y_pred_val.append(target_pred.cpu())
  
          # - Update progress bar
          val_progress.set_postfix({"loss": f"{val_loss_meter.avg:.4f}"})
        
      # - Compute average metrics (v1)    
      avg_val_loss = val_loss / max(len(dl_val), 1)
      if y_true_val:
        y_true_val = torch.cat(y_true_val)
        y_pred_val = torch.cat(y_pred_val)
        val_acc = (y_true_val == y_pred_val).float().mean().item()
        val_f1 = f1_score(y_true_val, y_pred_val, average='macro', zero_division=0)
      else:
        val_acc = val_f1 = 0.0

      # - Compute average metrics (v2)    
      avg_val_loss_v2 = val_loss_meter.avg
      val_acc_v2 = val_acc_metric.compute().item()
      val_f1_v2 = val_f1_metric.compute().item()
      val_confusion_matrix= val_confusion_matrix_metric.compute().cpu().numpy()  
      history["loss_val"].append(avg_val_loss_v2)
      history["acc_val"].append(val_acc_v2)
      history["f1score_val"].append(val_f1_v2)  
      history["cm_val"]= val_confusion_matrix
      history["cm_metric_val"]= val_confusion_matrix_metric
    
      print(f" | Val Loss: {avg_val_loss:.4f}, {avg_val_loss_v2:.4f} | Val Acc: {val_acc:.4f}, {val_acc_v2:.4f} | Val F1: {val_f1:.4f}, {val_f1_v2:.4f}")
    else:
      print()

  # - Save final model (report only after the file has actually been written)
  torch.save(model.state_dict(), checkpoint_path)  
  print(f"\n✅ Model checkpoint saved to: {checkpoint_path}")

  print("Training complete.")
  return history

Initialize weights and start training

In [None]:
# - Initialize weights
torch.manual_seed(10)
initialize_weights(model)

# - Run train
nepochs= 100
lr= 1e-4
class_weights= torch.from_numpy(dataset_train.class_weight_factors)
outfile_weights= os.path.join(rundir, "model_checkpoint.pth")

metric_hist= train_model(
  model, 
  dl_train=dl_train, 
  dl_val=dl_val, 
  num_epochs=nepochs,
  lr=lr,
  checkpoint_path=outfile_weights,
  class_weights=class_weights,
  clip_grad=False,
  max_grad_norm=5  
)

Let’s plot the training and validation metrics after the training run is complete.

In [None]:
def draw_metric_hist(metric_hist):
  """
    Plot the per-epoch history returned by train_model.

    Left panel : train (and validation) loss.
    Right panel: train (and validation) accuracy and macro F1-score, y-range [0, 1].

    Parameters:
      metric_hist: dict with per-epoch lists 'loss_train', 'acc_train', 'f1score_train'
                   and optionally 'loss_val', 'acc_val', 'f1score_val'. Validation curves
                   are drawn only when they have one entry per epoch (they are empty when
                   train_model was run without a validation DataLoader).
  """
  epochs = np.arange(1, len(metric_hist["loss_train"]) + 1)
  # - Plotting an empty validation list against `epochs` raises a shape-mismatch error
  has_val = len(epochs) > 0 and len(metric_hist.get("loss_val", [])) == len(epochs)

  fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

  # - Plot train/val loss
  ax1.plot(epochs, metric_hist["loss_train"], '-o', label='Train Loss')
  if has_val:
    ax1.plot(epochs, metric_hist["loss_val"], '--<', label='Validation Loss')
  ax1.set_title("Loss Over Epochs", fontsize=14)
  ax1.set_xlabel("Epoch", fontsize=12)
  ax1.set_ylabel("Loss", fontsize=12)
  ax1.legend(fontsize=11)
  ax1.grid(True) 
    
  # - Plot acc/f1score (same drawing order as train/val pairs, to keep colors stable)
  ax2.set_ylim(0, 1)
  ax2.plot(epochs, metric_hist["acc_train"], '-o', label='Train Accuracy')
  if has_val:
    ax2.plot(epochs, metric_hist["acc_val"], '--<', label='Validation Accuracy')
  ax2.plot(epochs, metric_hist["f1score_train"], '-*', label='Train F1-score')
  if has_val:
    ax2.plot(epochs, metric_hist["f1score_val"], '-->', label='Validation F1-score')
  ax2.set_title("Accuracy and F1-score", fontsize=14)
  ax2.set_xlabel("Epoch", fontsize=12)
  ax2.set_ylabel("Score", fontsize=12)
  ax2.legend(fontsize=11)
  ax2.grid(True)

  plt.tight_layout()
  plt.show()

# - Print & plot metrics
print("== Training Metrics ==")
#print(metric_hist)

draw_metric_hist(metric_hist)

# - Draw confusion matrix
fig, ax = plt.subplots(figsize=(20,20))
metric_hist["cm_metric_train"].plot(ax=ax)
#fig_, ax_ = metric_hist["cm_metric_train"][-1].plot()