In [None]:
# PIP installations
# !pip freeze > packages.txt
# !pip uninstall -y -r packages.txt
!pip install \
numpy==1.26.4 \
opencv-python==4.9.0.80 \
pandas==2.2.2 \
matplotlib==3.8.3 \
tqdm \
scipy==1.12.0 \
scikit-image==0.22.0 \
PyWavelets==1.5.0 \
torch torchvision \
scikit-learn==1.2.1 \
xgboost \
lightgbm \
catboost \
imbalanced-learn==0.12.4 \
Pillow -q

In [7]:
#!/bin/bash
!curl -L -o /kaggle/working/synthesized_sdo.zip\
  https://www.kaggle.com/datasets/ef7b69fce738908b7ae67d0e4860d177203e61c5c55f4d3c19b5a3820c864e81

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  9088    0  9088    0     0  37635      0 --:--:-- --:--:-- --:--:-- 37709


In [None]:
# Imports
import os
import os.path as path
import json
import random
from io import BytesIO

# Use once at the beginning
import cv2
import numpy as np
import pandas as pd
import pickle
import warnings
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib
from tqdm import tqdm
from scipy import ndimage
from skimage.feature import canny, local_binary_pattern
from skimage.filters import gabor
import pywt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torch.optim.lr_scheduler import SequentialLR, ExponentialLR, CosineAnnealingLR

import torchvision.transforms as T
from torchvision.io import read_image

from torchvision.models import (
    Swin_V2_B_Weights,
    Swin_V2_S_Weights,
    EfficientNet_V2_M_Weights,
    EfficientNet_V2_S_Weights,
    swin_v2_s,
    swin_v2_b,
    efficientnet_v2_m,
    efficientnet_v2_s,
)
from torchvision.transforms import InterpolationMode
from torchvision.ops import sigmoid_focal_loss

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    confusion_matrix,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    matthews_corrcoef,
    roc_auc_score,
)
from sklearn.ensemble import (
    RandomForestClassifier,
    ExtraTreesClassifier,
    GradientBoostingClassifier,
    HistGradientBoostingClassifier,
    AdaBoostClassifier,
    BaggingClassifier,
)
from sklearn.svm import SVC
from sklearn.linear_model import SGDClassifier

from xgboost import XGBClassifier
from lightgbm import LGBMClassifier
from catboost import CatBoostClassifier

from imblearn.over_sampling import SMOTE

# from transformers import DistilBertConfig, DistilBertForSequenceClassification
# from transformers import RobertaConfig, RobertaForSequenceClassification

# Silence all Python warnings for the rest of the session to keep cell output readable.
# NOTE(review): this also hides warnings raised by the helpers below — re-enable when debugging.
warnings.filterwarnings("ignore")
# Non-interactive backend: the plotting helpers below render figures into PNG
# buffers and show them with display(), so no GUI/inline backend is needed.
matplotlib.use("Agg")

In [None]:
# metrics.py
def true_skill_statistic(y_true, y_pred):
    """True Skill Statistic: true-positive rate minus false-positive rate.

    Args:
        y_true: ground-truth binary labels (0/1).
        y_pred: predicted binary labels (0/1 or bool).

    Returns:
        float in [-1, 1]. When a rate's denominator is zero (``y_true`` has no
        positives, or no negatives), that rate is taken as 0.0 instead of
        producing NaN. This matches the zero-denominator guards of the other
        metrics in this cell.
    """
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    positives = tp + fn
    negatives = fp + tn
    tpr = tp / positives if positives > 0 else 0.0
    fpr = fp / negatives if negatives > 0 else 0.0
    return float(tpr - fpr)

def far_score(y_true, y_pred):
    """False Alarm Ratio: FP / (TP + FP).

    This is the share of predicted positives that were wrong. Returns 0.0 when
    nothing was predicted positive.
    """
    _, false_pos, _, true_pos = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    predicted_pos = true_pos + false_pos
    if predicted_pos <= 0:
        return 0.0
    return false_pos / predicted_pos

def csi_score(y_true, y_pred):
    """Critical Success Index: TP / (TP + FP + FN).

    True negatives are ignored. Returns 0.0 when the denominator is zero.
    """
    _, false_pos, false_neg, true_pos = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    relevant = true_pos + false_pos + false_neg
    return 0.0 if relevant <= 0 else true_pos / relevant

def hss_score(y_true, y_pred):
    """Heidke Skill Score.

        HSS = 2 (TP*TN - FN*FP) / [(TP+FN)(FN+TN) + (TP+FP)(FP+TN)]

    Returns 0.0 when the denominator is zero.
    """
    true_neg, false_pos, false_neg, true_pos = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    numerator = 2 * (true_pos * true_neg - false_neg * false_pos)
    denominator = (true_pos + false_neg) * (false_neg + true_neg) + (true_pos + false_pos) * (false_pos + true_neg)
    if denominator <= 0:
        return 0.0
    return numerator / denominator

In [None]:
# helpers.py
def process_meta_data(meta_data_path:str):
    """Load the metadata CSV and binarise its ``peak_flux`` column.

    Drops the ``start``/``end`` columns. ``peak_flux`` becomes 1 where the flux
    exceeds 1e-5 and 0 otherwise. NOTE(review): 1e-5 is presumably the GOES
    M-class flare threshold — confirm.
    """
    return (
        pd.read_csv(meta_data_path)
        .drop(columns=["start", "end"])
        .assign(peak_flux=lambda frame: (frame["peak_flux"] > 1e-5).astype(int))
    )

class ReadImgs(nn.Module):
    """Read one or more image files with ``read_image``, keeping index 0 of each.

    ``form`` selects the output container:
      * ``"stack"``  -> ``torch.stack`` of the images (new leading dim)
      * ``"concat"`` -> ``torch.cat`` along dim 0
      * ``"none"``   -> plain list of tensors
      * ``"dict"``   -> ``{path: tensor}`` for the paths that were read
    """

    _FORMS = ("concat", "stack", "none", "dict")

    def __init__(self, form="stack"):
        super().__init__()
        # "dict" used to be rejected here, which left its forward() branch unreachable.
        assert form in self._FORMS, f"form must be one of {self._FORMS}, got {form!r}"
        self.form = form

    def forward(self, x):
        """Read ``x``, which is a single path or a list/tuple of paths.

        Unreadable files are reported with a print and skipped, so the result may
        hold fewer images than paths given. ``"stack"`` and ``"concat"`` raise if
        no image could be read at all.
        """
        paths = list(x) if isinstance(x, (list, tuple)) else [x]
        read_paths, imgs = [], []
        for img_path in paths:
            try:
                # Index 0 keeps only the first entry along the leading dimension.
                img = read_image(img_path)[0]
            except Exception:  # not a bare except: let KeyboardInterrupt/SystemExit through
                print("error reading image, path:", img_path)
                continue
            read_paths.append(img_path)
            imgs.append(img)

        if self.form == "concat":
            return torch.cat(imgs)
        if self.form == "stack":
            return torch.stack(imgs)
        if self.form == "none":
            return imgs
        # form == "dict": pair each tensor with the path it actually came from, so a
        # skipped file cannot shift later images onto the wrong key.
        return dict(zip(read_paths, imgs))

class GrayScaleToRGB(nn.Module):
    """Replicate a grayscale image into three channels.

    ``(H, W)`` and ``(C, H, W)`` inputs become ``(3*C, H, W)``, where ``C`` is 1
    for ``(H, W)``. ``(B, C, H, W)`` inputs become ``(B, 3*C, H, W)``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        if x.ndim == 2:
            # Bare (H, W) image: add a channel dim before repeating.
            return x.unsqueeze(0).repeat(3, 1, 1)
        if x.ndim == 3:
            return x.repeat(3, 1, 1)
        if x.ndim == 4:
            return x.repeat(1, 3, 1, 1)
        # Previously fell through and returned None, which failed later somewhere unrelated.
        raise ValueError(f"GrayScaleToRGB expects a 2-D, 3-D or 4-D tensor, got {x.ndim}-D")

class FrequencyChannelTransform(nn.Module):
    """
    Build a 3-channel image from one grayscale image.

    Channel 0 is the input unchanged. Channel 1 is the log(1 + |FFT|) magnitude
    (zero frequency centred). Channel 2 is the FFT phase. Channels 1 and 2 are
    min-max scaled to about [0, 1].
    """
    def __init__(self):
        super().__init__()

    @staticmethod
    def _minmax(arr):
        # Min-max scaling; the 1e-8 keeps constant arrays from dividing by zero.
        return (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)

    def forward(self, x):
        """
        Args:
            x: a torch tensor holding a (B, 1, H, W) batch (4-D) or a single
               image, or a 2-D numpy array (H, W).
        Returns:
            A float32 tensor of shape (B, 3, H, W), or (3, H, W) for a single image.
        """
        if not isinstance(x, torch.Tensor):
            return torch.from_numpy(self._transform_single(x)).float()

        batched_input = len(x.shape) == 4
        images = x if batched_input else x.unsqueeze(0)
        channels = [self._transform_single(img.squeeze().cpu().numpy()) for img in images]
        out = torch.from_numpy(np.stack(channels, axis=0)).float()
        return out if batched_input else out.squeeze(0)

    def _transform_single(self, gray_image):
        """Return a (3, H, W) numpy array for one 2-D grayscale image."""
        spectrum = np.fft.fftshift(np.fft.fft2(gray_image))
        log_magnitude = self._minmax(np.log1p(np.abs(spectrum)))
        phase = self._minmax(np.angle(spectrum))
        return np.stack([gray_image, log_magnitude, phase], axis=0)

class EdgeGradientChannelTransform(nn.Module):
    """
    Creates 3 channels: Original, Edges, Gradient Magnitude

    Channel 0 is the input unchanged. Channel 1 is skimage's Canny edge map
    (0/1 floats). Channel 2 is the Sobel gradient magnitude, min-max scaled to
    about [0, 1].
    """
    def __init__(self, sigma=2.0):
        super().__init__()
        # Gaussian smoothing sigma passed to canny().
        self.sigma = sigma

    def forward(self, x):
        """
        Args:
            x: a torch tensor holding a (B, 1, H, W) batch (4-D) or a single
               image, or a 2-D numpy array (H, W).
        Returns:
            A float32 tensor of shape (B, 3, H, W), or (3, H, W) for a single image.
        """
        if isinstance(x, torch.Tensor):
            is_batch = len(x.shape) == 4
            if not is_batch:
                x = x.unsqueeze(0)

            batch_results = []
            for img in x:
                # squeeze() drops singleton dims, e.g. (1, H, W) -> (H, W).
                img_np = img.squeeze().cpu().numpy()
                batch_results.append(self._transform_single(img_np))

            result = torch.from_numpy(np.stack(batch_results, axis=0)).float()
            return result if is_batch else result.squeeze(0)
        return torch.from_numpy(self._transform_single(x)).float()

    def _transform_single(self, gray_image):
        """Return a (3, H, W) numpy array for one 2-D grayscale image."""
        original = gray_image

        edges = canny(gray_image, sigma=self.sigma).astype(np.float32)

        # Compute Sobel in float64. ndimage.sobel's output has the input's dtype,
        # so for integer images (such as uint8 tensors from ReadImgs) negative and
        # >255 responses could not be represented, and squaring in uint8 overflowed
        # again. The result was a wrong gradient channel.
        as_float = np.asarray(gray_image, dtype=np.float64)
        sobel_x = ndimage.sobel(as_float, axis=0)
        sobel_y = ndimage.sobel(as_float, axis=1)
        gradient = np.sqrt(sobel_x**2 + sobel_y**2)
        gradient = (gradient - gradient.min()) / (gradient.max() - gradient.min() + 1e-8)

        # NOTE(review): `original` keeps the input's scale (e.g. 0..255 for uint8),
        # unlike the [0, 1] channels. Confirm the caller normalises beforehand.
        return np.stack([original, edges, gradient], axis=0)

class WaveletChannelTransform(nn.Module):
    """
    Build a 3-channel image from one grayscale image using a single-level 2-D DWT.

    Channel 0 is the input unchanged. Channel 1 is the approximation coefficients
    (cA). Channel 2 is the combined detail magnitude sqrt(cH^2 + cV^2 + cD^2).
    Channels 1 and 2 are resized back to (H, W) with cv2 and min-max scaled to
    about [0, 1].
    """
    def __init__(self, wavelet='db4'):
        super().__init__()
        # Wavelet name forwarded to pywt.dwt2.
        self.wavelet = wavelet

    def forward(self, x):
        """
        Args:
            x: a torch tensor holding a (B, 1, H, W) batch (4-D) or a single
               image, or a 2-D numpy array (H, W).
        Returns:
            A float32 tensor of shape (B, 3, H, W), or (3, H, W) for a single image.
        """
        if not isinstance(x, torch.Tensor):
            return torch.from_numpy(self._transform_single(x)).float()

        batched_input = len(x.shape) == 4
        images = x if batched_input else x.unsqueeze(0)
        channels = [self._transform_single(img.squeeze().cpu().numpy()) for img in images]
        out = torch.from_numpy(np.stack(channels, axis=0)).float()
        return out if batched_input else out.squeeze(0)

    def _transform_single(self, gray_image):
        """Return a (3, H, W) numpy array for one 2-D grayscale image."""
        target_size = (gray_image.shape[1], gray_image.shape[0])  # cv2 wants (width, height)

        def _resize_and_scale(coeffs):
            resized = cv2.resize(coeffs, target_size)
            return (resized - resized.min()) / (resized.max() - resized.min() + 1e-8)

        approx_coeffs, (horiz, vert, diag) = pywt.dwt2(gray_image, self.wavelet)
        approx = _resize_and_scale(approx_coeffs)
        detail = _resize_and_scale(np.sqrt(horiz**2 + vert**2 + diag**2))
        return np.stack([gray_image, approx, detail], axis=0)

class TextureChannelTransform(nn.Module):
    """
    Build a 3-channel image from one grayscale image.

    Channel 0 is the input unchanged. Channel 1 is a 'uniform' Local Binary
    Pattern. Channel 2 is the real part of a Gabor filter response at theta=0.
    Channels 1 and 2 are min-max scaled to about [0, 1].
    """
    def __init__(self, lbp_radius=1, lbp_points=8, gabor_frequency=0.2):
        super().__init__()
        self.lbp_radius = lbp_radius
        self.lbp_points = lbp_points
        self.gabor_frequency = gabor_frequency

    def forward(self, x):
        """
        Args:
            x: a torch tensor holding a (B, 1, H, W) batch (4-D) or a single
               image, or a 2-D numpy array (H, W).
        Returns:
            A float32 tensor of shape (B, 3, H, W), or (3, H, W) for a single image.
        """
        if not isinstance(x, torch.Tensor):
            return torch.from_numpy(self._transform_single(x)).float()

        batched_input = len(x.shape) == 4
        images = x if batched_input else x.unsqueeze(0)
        channels = [self._transform_single(img.squeeze().cpu().numpy()) for img in images]
        out = torch.from_numpy(np.stack(channels, axis=0)).float()
        return out if batched_input else out.squeeze(0)

    def _transform_single(self, gray_image):
        """Return a (3, H, W) numpy array for one 2-D grayscale image."""
        def _scale(arr):
            return (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)

        lbp = local_binary_pattern(gray_image, P=self.lbp_points,
                                   R=self.lbp_radius, method='uniform')
        gabor_real, _gabor_imag = gabor(gray_image, frequency=self.gabor_frequency, theta=0)
        return np.stack([gray_image, _scale(lbp), _scale(gabor_real)], axis=0)

class ComprehensiveChannelTransform(nn.Module):
    """
    Creates 5 channels: Original, FFT Magnitude, Edges, Gradient, Wavelet Detail

    Every channel except channel 0 (the input unchanged) is min-max scaled to
    about [0, 1]. The Canny edge map is already 0/1.
    """
    def __init__(self, sigma=2.0, wavelet='db4'):
        super().__init__()
        # Gaussian sigma for canny() and wavelet name for pywt.dwt2.
        self.sigma = sigma
        self.wavelet = wavelet

    def forward(self, x):
        """
        Args:
            x: a torch tensor holding a (B, 1, H, W) batch (4-D) or a single
               image, or a 2-D numpy array (H, W).
        Returns:
            A float32 tensor of shape (B, 5, H, W), or (5, H, W) for a single image.
        """
        if isinstance(x, torch.Tensor):
            is_batch = len(x.shape) == 4
            if not is_batch:
                x = x.unsqueeze(0)

            batch_results = []
            for img in x:
                # squeeze() drops singleton dims, e.g. (1, H, W) -> (H, W).
                img_np = img.squeeze().cpu().numpy()
                batch_results.append(self._transform_single(img_np))

            result = torch.from_numpy(np.stack(batch_results, axis=0)).float()
            return result if is_batch else result.squeeze(0)
        return torch.from_numpy(self._transform_single(x)).float()

    def _transform_single(self, gray_image):
        """Return a (5, H, W) numpy array for one 2-D grayscale image."""
        original = gray_image

        # FFT magnitude (log scale, zero frequency centred)
        fft = np.fft.fft2(gray_image)
        fft_mag = np.log1p(np.abs(np.fft.fftshift(fft)))
        fft_mag = (fft_mag - fft_mag.min()) / (fft_mag.max() - fft_mag.min() + 1e-8)

        # Edges
        edges = canny(gray_image, sigma=self.sigma).astype(np.float32)

        # Gradient, computed in float64. ndimage.sobel's output has the input's
        # dtype, so integer (e.g. uint8) images used to wrap/overflow here.
        as_float = np.asarray(gray_image, dtype=np.float64)
        sobel_x = ndimage.sobel(as_float, axis=0)
        sobel_y = ndimage.sobel(as_float, axis=1)
        gradient = np.sqrt(sobel_x**2 + sobel_y**2)
        gradient = (gradient - gradient.min()) / (gradient.max() - gradient.min() + 1e-8)

        # Wavelet detail magnitude, resized back to (H, W); cv2 takes (width, height).
        coeffs = pywt.dwt2(gray_image, self.wavelet)
        cA, (cH, cV, cD) = coeffs
        detail = np.sqrt(cH**2 + cV**2 + cD**2)
        detail = cv2.resize(detail, (gray_image.shape[1], gray_image.shape[0]))
        detail = (detail - detail.min()) / (detail.max() - detail.min() + 1e-8)

        # NOTE(review): `original` keeps the input's scale (e.g. 0..255 for uint8).
        return np.stack([original, fft_mag, edges, gradient, detail], axis=0)

class RepeatChannelTransform(nn.Module):
    """
    Baseline: Simply repeats grayscale channel ``num_repeats`` times.

    Tensor inputs keep their dtype:
      (H, W) or (1, H, W) -> (num_repeats, H, W)
      (B, 1, H, W)        -> (B, num_repeats, H, W)
    Non-tensor (numpy) inputs are stacked along a new leading axis and returned
    as a float32 tensor.
    """
    def __init__(self, num_repeats=3):
        super().__init__()
        self.num_repeats = num_repeats

    def forward(self, x):
        if not isinstance(x, torch.Tensor):
            return torch.from_numpy(np.stack([x] * self.num_repeats, axis=0)).float()
        if x.ndim == 2:  # (H, W)
            return x.unsqueeze(0).repeat(self.num_repeats, 1, 1)
        if x.ndim == 3:  # (1, H, W)
            return x.repeat(self.num_repeats, 1, 1)
        if x.ndim == 4:  # (B, 1, H, W)
            return x.repeat(1, self.num_repeats, 1, 1)
        # Previously any other rank silently returned None.
        raise ValueError(f"RepeatChannelTransform expects a 2-D, 3-D or 4-D tensor, got {x.ndim}-D")

class SDOCacheTransform(nn.Module):
    """Stack a mapping of per-wavelength tensors into a single tensor.

    Values are stacked in the mapping's key order, along dim ``int(chunks)``.
    That is dim 0 by default and dim 1 when ``chunks=True``.
    """
    def __init__(self, chunks=False):
        super().__init__()
        self.stacking_dim = int(chunks)

    def forward(self, x):
        return torch.stack([x[wavelength] for wavelength in x.keys()], dim=self.stacking_dim)

class RandomRotate90(nn.Module):
    """Rotate an image by an angle (degrees) picked uniformly at random from ``angles``.

    Accepts whatever ``T.functional.rotate`` accepts (PIL image or tensor).
    """
    def __init__(self, angles=(0, 90, 180, 270)):
        super().__init__()
        self.angles = angles

    def forward(self, img):
        # Implemented as forward() rather than overriding __call__, so nn.Module's
        # normal call path (including hooks) still applies.
        angle = random.choice(self.angles)
        return T.functional.rotate(img, angle)

    def __repr__(self):
        # The old __repr__ returned str(self). With no __str__ defined, str() falls
        # back to __repr__, so repr/print recursed until RecursionError.
        return f"RandomRotate90(angles={self.angles})"

    def str(self):
        # Kept for backward compatibility with callers of the old method name.
        return repr(self)

def write_json(dict, dir):
    """Serialise ``dict`` with default ``json`` settings into the file at ``dir``.

    The file is written as UTF-8 and overwritten if it exists.
    """
    with open(dir, "w", encoding="utf-8") as handle:
        serialized = json.dumps(dict)
        handle.write(serialized)

def write_note(txt, dir):
    """Write the string ``txt`` to the file at ``dir`` (UTF-8, overwriting)."""
    with open(dir, "w", encoding="utf-8") as note_file:
        note_file.write(txt)

def read_json(dir):
    """Parse and return the JSON document stored (UTF-8) in the file at ``dir``."""
    with open(dir, "r", encoding="utf-8") as source:
        return json.load(source)

def list_models():
    """Print the ``IMAGENET1K_V1`` entries of ``weights.json``, highest ``acc`` first.

    ``weights.json`` is read from the current working directory. Each record is
    expected to have ``acc``, ``params`` and ``weight`` keys.
    """
    by_accuracy = sorted(read_json("weights.json"), key=lambda record: -float(record["acc"]))
    imagenet_v1 = [record for record in by_accuracy if "IMAGENET1K_V1" in record["weight"]]
    for record in imagenet_v1:
        print("acc:", record["acc"], "||", "# of parameters", record["params"], "||", "weight:", record["weight"])

def create_dummy_data(wavelengths, n=32, feature_dim=1280):
    """Generate ``n`` random ``(features, label)`` pairs for smoke tests.

    Args:
        wavelengths: keys of the per-sample feature dict.
        n: number of samples.
        feature_dim: length of each random feature vector (default 1280).

    Returns:
        A list of ``({wavelength: randn(feature_dim)}, label)``, where ``label`` is
        a float32 tensor of shape (1,) holding 0.0 or 1.0.
    """
    def _sample():
        features = {wavelength: torch.randn(feature_dim) for wavelength in wavelengths}
        # randint's upper bound is exclusive, so (0, 2) draws 0 or 1. The old
        # (0, 1) always produced 0, which made every dummy label negative.
        label = torch.randint(0, 2, size=(1,), dtype=torch.float32)
        return features, label

    return [_sample() for _ in range(n)]

def print_run_summary(dict, threshold):
    """Print a ``Name | Score`` table of ``dict`` under a ``[Threshold: ...]`` header.

    Scores are printed with 4 decimals, so they must be numeric.
    """
    header = f"{'Name':<15} | {'Score':>8}"
    separator = "-" * 26
    print(f"\n[Threshold: {threshold}]")
    print(header)
    print(separator)
    for metric_name, metric_value in dict.items():
        print(f"{metric_name:<15} | {metric_value:>8.4f}")

def summarize_performance_table(performances, width=15):
    """Print a ``||``-bordered comparison table of several runs.

    Args:
        performances: mapping run name -> {metric name: score}. Rows follow the
            metric order of the run with the most metrics. Missing scores show
            as "Nan", and float scores are rounded to 4 decimals.
        width: width each cell is centred to.
    """
    header_cells = [run_name.center(width) for run_name in performances.keys()]
    border = ("  " + "_" * width) * (len(header_cells) + 1) + "  "
    separator = border.replace("  ", "||")

    def _row(cells):
        return "||" + "||".join(cells) + "||"

    print(border)
    print(_row(["metric".center(width), *header_cells]))
    print(separator)

    widest_run = max(performances, key=lambda run_name: len(performances[run_name].keys()))
    for metric in performances[widest_run].keys():
        cells = []
        for run_name in performances.keys():
            value = performances[run_name].get(metric, "Nan")
            if isinstance(value, float):
                value = round(value, 4)
            cells.append(str(value).center(width))
        print(_row([metric.center(width), *cells]))
        print(separator)

def evaluate(y_true, y_pred, metrics, threshold=0.5, steps=100):
    """Score thresholded predictions with every metric in ``metrics``.

    Args:
        y_true: ground-truth labels, passed to each metric unchanged.
        y_pred: scores supporting ``y_pred > t`` (numpy array or torch tensor).
        metrics: mapping name -> callable(y_true, y_pred_binary) -> score.
        threshold: either a numeric cut-off, or the *name* of a metric in
            ``metrics``. For a name, the cut-off is grid-searched over ``steps``
            evenly spaced values in [0, 1], and the first value that maximises
            that metric is kept.
        steps: grid size for the threshold search (default 100).

    Returns:
        ``(scores, threshold)``: a dict of name -> score, and the threshold used.
    """
    if isinstance(threshold, str):
        selector = metrics[threshold]
        best_threshold, best_score = 0, float("-inf")
        for candidate in torch.linspace(start=0, end=1, steps=steps).tolist():
            score = selector(y_true, y_pred > candidate)
            # Start from -inf, not 0. Metrics such as TSS/HSS can be negative,
            # and a 0 floor returned threshold 0 whenever no candidate beat 0.
            if score > best_score:
                best_score, best_threshold = score, candidate
        threshold = best_threshold

    scores = {name: metric(y_true, y_pred > threshold) for name, metric in metrics.items()}
    return scores, threshold

def plot_frequency_bars(columns, labels, show_img=False):
    """
    Draw one bar chart per column showing how many 0s and 1s it contains.

    Args:
        columns: non-empty sequence of pandas Series holding 0/1 values.
            Values other than 0 and 1 are not counted.
        labels: text drawn above the bars. ``labels[0]`` goes on the 0-bar and
            ``labels[1]`` on the 1-bar of every subplot.
        show_img: if True, render the resulting image with ``display``.

    Returns:
        tuple: (img, *counts)
            - img: PIL Image of the figure
            - counts: one ``(n_zeros, n_ones)`` tuple of ints per column

    Raises:
        ValueError: if ``columns`` is empty.
    """
    if not columns:
        raise ValueError("At least one column must be provided")

    # Counts of 0 and 1 for each column; a value that never occurs counts as 0.
    counts = [col.value_counts().reindex([0, 1], fill_value=0) for col in columns]

    n_plots = len(columns)
    n_cols = min(3, n_plots)  # at most 3 subplots per row
    n_rows = (n_plots + n_cols - 1) // n_cols

    # squeeze=False always yields a 2-D array of Axes, so ravel() works for any grid.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
    axes = axes.ravel()
    fig.suptitle('Peak flux distribution', fontsize=14, fontweight='bold')

    # Shared y-limit so the subplots are directly comparable.
    global_max = max(count.max() for count in counts)

    for idx, (col, count) in enumerate(zip(columns, counts)):
        ax = axes[idx]
        ax.bar([0, 1], count.values, color=['#e74c3c', '#2ecc71'],
               alpha=0.8, edgecolor='black', width=0.6)
        ax.set_xticks([0, 1])
        ax.set_xticklabels(['0', '1'])
        ax.set_title(col.name or f'Column {idx+1}', fontsize=12, fontweight='bold')
        ax.set_xlabel('Value')
        ax.set_ylabel('Frequency')
        ax.grid(axis='y', alpha=0.3)
        ax.set_ylim(0, global_max * 1.15)

        # Caller-supplied text above each bar.
        for i, (v, label) in enumerate(zip(count.values, labels)):
            ax.text(i, v, label, ha='center', va='bottom',
                    fontsize=10, fontweight='bold')

    # Hide grid cells that have no column to plot.
    for ax in axes[n_plots:]:
        ax.set_visible(False)

    fig.tight_layout()

    # Render to an in-memory PNG and hand it back as a PIL image.
    buf = BytesIO()
    fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    buf.seek(0)
    img = Image.open(buf)
    plt.close(fig)

    counts = [(int(count[0]), int(count[1])) for count in counts]

    if show_img:
        display(img)

    return (img, *counts)

def visualize_dict_list(data_list, save_dir=None, show_summary=True, title='Data Visualization', exclude=()):
    """
    Plot every key of a list of dicts (e.g. per-epoch metric records) in its own subplot.

    Each key's values are drawn against their 1-based position in ``data_list``
    (x-axis labelled "Epoch"). Entries where the key is missing or None are skipped.

    Args:
        data_list: non-empty list of dicts with numeric values.
        save_dir: optional file path. If given, the PNG is also saved there.
        show_summary: if True, render the image with ``display``.
        title: figure super-title.
        exclude: keys to leave out. Keys that are not present are ignored.

    Returns:
        PIL Image of the figure.

    Raises:
        ValueError: if ``data_list`` is empty or no keys remain after ``exclude``.
    """
    if not data_list:
        raise ValueError("data_list cannot be empty")

    # Union of keys across all records, minus excluded ones. Previously an
    # excluded key that was absent raised ValueError from list.remove.
    excluded = set(exclude)
    all_keys = sorted({key for record in data_list for key in record.keys()} - excluded)
    if not all_keys:
        raise ValueError("no keys left to plot after applying `exclude`")

    n_plots = len(all_keys)
    n_cols = min(3, n_plots)  # at most 3 subplots per row
    n_rows = (n_plots + n_cols - 1) // n_cols

    # squeeze=False always yields a 2-D array of Axes, so ravel() works for any grid.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows), squeeze=False)
    axes = axes.ravel()
    fig.suptitle(title, fontsize=16, fontweight='bold')

    for ax, key in zip(axes, all_keys):
        points = []
        for position, record in enumerate(data_list, start=1):
            value = record.get(key)
            if value is not None:
                points.append((position, value))

        if points:
            xs, ys = zip(*points)
            ax.plot(xs, ys, marker='o', linewidth=2, markersize=6)
            subplot_title = f"{key}\nMin: {min(ys):.2f} | Max: {max(ys):.2f}"
        else:
            subplot_title = f"{key}\n(No data)"

        # Own name so the `title` argument is not shadowed.
        ax.set_title(subplot_title, fontsize=12, fontweight='bold')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Scores')
        ax.grid(True, alpha=0.3)

    # Hide grid cells that have no key to plot.
    for ax in axes[n_plots:]:
        ax.set_visible(False)

    fig.tight_layout()

    # Render to an in-memory PNG and hand it back as a PIL image.
    buf = BytesIO()
    fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    buf.seek(0)
    img = Image.open(buf)
    plt.close(fig)

    if show_summary:
        display(img)
    if save_dir:
        img.save(save_dir)
        print("saving images")

    return img

def plot_metric_bars(results, names, ignore_keys=("epoch", "threshold", "threhsold"),
                    cols=4, save_dir=None, metric="accuracy"):
    """
    Draw one bar chart per metric comparing several runs, with runs ordered by ``metric``.

    Args:
        results: non-empty list of dicts (one per run) mapping metric name -> value.
            The plotted metrics are the keys of ``results[0]`` not in ``ignore_keys``.
        names: run labels, aligned with ``results``.
        ignore_keys: keys that are not plotted. The misspelled "threhsold" is
            presumably there to skip records written with that typo.
        cols: subplots per row.
        save_dir: optional file path. If given, the figure is also saved there at 300 dpi.
        metric: key used to order runs, highest first. Ties keep input order.

    Returns:
        PIL Image of the figure, which is also shown with ``display``.

    Raises:
        ValueError: if ``results`` is empty.
    """
    if not results:
        raise ValueError("Empty results list")

    plotted = [k for k in results[0].keys() if k not in ignore_keys]
    rows = (len(plotted) + cols - 1) // cols

    # Descending by the primary metric; a stable sort keeps input order for ties.
    order = np.argsort([-r[metric] for r in results], kind="stable")
    results = [results[i] for i in order]
    names = [names[i] for i in order]

    x = np.arange(len(results))

    fig, axes = plt.subplots(rows, cols, figsize=(8 * cols, 3.5 * rows), squeeze=False)
    axes = axes.ravel()

    # Loop variable named `key` so the `metric` argument (the sort key) stays intact.
    for ax, key in zip(axes, plotted):
        ax.bar(x, [r[key] for r in results])
        ax.set_title(f"{key} (sorted by {metric})")
        ax.set_xticks(x)
        ax.set_xticklabels(names, rotation=45, ha="right")
        ax.grid(axis="y", linestyle="--", alpha=0.4)

    # Hide grid cells that have no metric to plot.
    for ax in axes[len(plotted):]:
        ax.axis("off")

    fig.tight_layout()

    if save_dir:
        fig.savefig(save_dir, dpi=300)

    buf = BytesIO()
    fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    buf.seek(0)
    img = Image.open(buf)
    # The figure used to be left open, so repeated calls accumulated figures.
    plt.close(fig)

    display(img)
    return img

def check_incomplete(wavelengths_dir, num_epochs):
    """Return ``[name]`` for the first unfinished run directory, or ``[]`` if none.

    Only sub-directories of ``wavelengths_dir`` containing a ``config.json`` are
    considered. A run is unfinished when its ``training_epoch`` is below
    ``num_epochs - 1``. "First" follows ``os.listdir`` order, which is not sorted.
    Every candidate's config is read before the result is returned.
    """
    with_config = [
        entry for entry in os.listdir(wavelengths_dir)
        if path.exists(path.join(wavelengths_dir, entry, "config.json"))
    ]
    unfinished = [
        entry for entry in with_config
        if read_json(path.join(wavelengths_dir, entry, "config.json"))["training_epoch"] < (num_epochs - 1)
    ]
    return unfinished[:1]

def set_global_seed(seed: int = 42):
    """Seed every RNG used here and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG, and torch (CPU, current CUDA
    device and all CUDA devices).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Reproducibility over speed: deterministic kernels, no cuDNN autotuning.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def pick_best_model(wavelength_dir, metric, dynamic=True):
    """Select the validation record with the highest ``metric`` and name its checkpoint.

    Reads every JSON file in ``<wavelength_dir>/val_records/dynamic`` (or ``fixed``
    when ``dynamic`` is False).

    Returns:
        ``(checkpoint_filename, record)``, where the filename is
        ``ckpt_{record['epoch'] * 2}.pt``.

    NOTE(review): the factor 2 presumably reflects how checkpoints were numbered
    relative to validation epochs. Confirm against the training loop.
    """
    records_dir = path.join(wavelength_dir, "val_records", "dynamic" if dynamic else "fixed")
    records = [read_json(path.join(records_dir, fname)) for fname in os.listdir(records_dir)]
    best = max(records, key=lambda rec: rec[metric])
    checkpoint = f"ckpt_{best['epoch'] * 2}.pt"
    return checkpoint, best

In [None]:
# dataset.py
class WaveLenghtDataset(Dataset):
    """One image per metadata row, for a single wavelength.

    Rows of ``meta_labels`` are unpacked as ``(id, label)``, so the frame must have
    exactly two columns. ``get_labels`` reads the ``peak_flux`` column.

    Images are discovered under ``base_route/<fileid>/<filename>/`` and keyed by
    ``"<fileid>_<filename>"``. Top-level ``*.csv`` entries are skipped. A file is
    kept when the text after its last ``"__"`` and before the first ``"."`` equals
    ``wavelengths``, which is compared to a single string. If several files match
    under one key, the last one listed wins.
    """
    def __init__(self, base_route:str, meta_labels:pd.DataFrame, wavelengths, transform=None, augmentation=None, augmentation_class="both"):

        assert augmentation_class in ["positive", "negative", "both"]

        self._metadata = meta_labels
        self._wavelengths = wavelengths
        self._transform = transform
        self._augmentation = augmentation
        self.augmentation_class = augmentation_class

        self._paths = {
            (fileid + "_" + filename) : path.join(base_route, fileid, filename, image)
            for fileid in os.listdir(base_route) if not fileid.endswith(".csv")
            for filename in os.listdir(path.join(base_route, fileid))
            for image in os.listdir(path.join(base_route, fileid, filename))
            if image.split("__")[-1].split(".")[0] == self._wavelengths
        }

    def __len__(self):
        return len(self._metadata)

    def get_labels(self):
        return self._metadata["peak_flux"]

    def _should_augment(self, label):
        """Whether the augmentation applies to a sample with this label."""
        if label == 1:
            return self.augmentation_class in ("positive", "both")
        if label == 0:
            return self.augmentation_class in ("negative", "both")
        return False

    def _load(self, idx):
        """Build ``(image, label_tensor)`` for row ``idx``; raises on any failure."""
        image_id, label = self._metadata.iloc[idx]
        image = self._paths[image_id]

        if self._transform:
            image = self._transform(image)
        if self._augmentation and self._should_augment(label):
            image = self._augmentation(image)

        return image, torch.tensor(label, dtype=torch.float32)

    def __getitem__(self, idx):
        # Best effort: a sample that fails (missing image, unreadable file, failing
        # transform) is replaced by the next row, wrapping around. This used to
        # recurse from a bare `except:`, which also swallowed KeyboardInterrupt and
        # hit RecursionError when every sample failed. Now each row is tried once.
        n = len(self)
        candidates = [idx] + [(idx + k) % n for k in range(1, n)]
        last_error = None
        for candidate in candidates:
            try:
                return self._load(candidate)
            except Exception as err:
                last_error = err
        raise RuntimeError(f"no loadable sample found starting from index {idx}") from last_error

    def set_transform(self, t):
        self._transform = t

class WaveLenghtDatasetV2(Dataset):
    """Like ``WaveLenghtDataset`` but yields every matching image, not one per row.

    For each metadata row ``(id, label)``, every file under
    ``base_route/<fileid>/<filename>/`` whose name (text after the last ``"__"``,
    before the first ``"."``) equals ``wavelengths`` becomes its own sample.
    ``meta_labels`` must have an ``id`` column and exactly two columns, and every
    id must exist on disk, otherwise construction raises KeyError.
    """
    def __init__(self, base_route:str, meta_labels:pd.DataFrame, wavelengths, transform=None, augmentation=None, augmentation_class="both"):

        assert augmentation_class in ["positive", "negative", "both"]

        self._metadata = meta_labels
        self._wavelengths = wavelengths
        self._transform = transform
        self._augmentation = augmentation
        self.augmentation_class = augmentation_class
        self._paths = {
            (fileid + "_" + filename) :[
                path.join(base_route, fileid, filename, image)
                for image in filter(
                        lambda image : image.split("__")[-1].split(".")[0] == self._wavelengths,
                        os.listdir(path.join(base_route, fileid, filename))
                        )
                ]
            for fileid in os.listdir(base_route) if not fileid.endswith(".csv")
            for filename in os.listdir(path.join(base_route, fileid))
        }

        # Flat index: (metadata row, position within that row's image list).
        self._indices = [(i, j) for i, _id in enumerate(self._metadata["id"]) for j in range(len(self._paths[_id]))]

    def __len__(self):
        return len(self._indices)

    def _should_augment(self, label):
        """Whether the augmentation applies to a sample with this label."""
        if label == 1:
            return self.augmentation_class in ("positive", "both")
        if label == 0:
            return self.augmentation_class in ("negative", "both")
        return False

    def _load(self, idx):
        """Build ``(image, label_tensor)`` for flat index ``idx``; raises on any failure."""
        file_idx, image_idx = self._indices[idx]
        image_id, label = self._metadata.iloc[file_idx]
        image = self._paths[image_id][image_idx]

        if self._transform:
            image = self._transform(image)
        if self._augmentation and self._should_augment(label):
            image = self._augmentation(image)

        return image, torch.tensor(label, dtype=torch.float32)

    def __getitem__(self, idx):
        # Best effort: a failing sample is replaced by the next index, wrapping
        # around. This used to recurse from a bare `except:`, which swallowed
        # KeyboardInterrupt and hit RecursionError when every sample failed. Now
        # each index is tried once before giving up.
        n = len(self)
        candidates = [idx] + [(idx + k) % n for k in range(1, n)]
        last_error = None
        for candidate in candidates:
            try:
                return self._load(candidate)
            except Exception as err:
                last_error = err
        raise RuntimeError(f"no loadable sample found starting from index {idx}") from last_error

    def set_transform(self, t):
        self._transform = t

class SDODataset(Dataset):
    """Multi-wavelength samples: all images of each requested wavelength for one event.

    A directory ``base_route/<fileid>/<filename>`` is used only if it holds exactly
    ``expected_files`` entries (40 by default). Metadata rows whose ``id``
    (``"<fileid>_<filename>"``) has no such directory are dropped. Rows are unpacked
    as ``(id, label)``, so ``metadata`` must have exactly two columns.

    ``__getitem__`` returns ``({wavelength: images}, label_tensor)``. Before any
    transform, ``images`` is the list of file paths whose name (text after the last
    ``"__"``, before the first ``"."``) equals that wavelength.
    """
    def __init__(self, base_route:str, metadata:pd.DataFrame, wavelengths, transform=None, augmentation=None, augmentation_class="both", expected_files=40):
        assert augmentation_class in ["positive", "negative", "both"]

        self._wavelengths = wavelengths
        self._transform = transform
        self._augmentation = augmentation
        self.augmentation_class = augmentation_class

        self._paths = {
            (fileid + "_" + filename) : path.join(base_route, fileid, filename)
            for fileid in os.listdir(base_route) if not fileid.endswith(".csv")
            for filename in os.listdir(path.join(base_route, fileid))
            if len(os.listdir(path.join(base_route, fileid, filename))) == expected_files
        }

        self._metadata = metadata[metadata["id"].isin(list(self._paths.keys()))]

    def __len__(self):
        return len(self._metadata)

    @staticmethod
    def _wavelength_of(img_path):
        # Parse the file name only, with the same rule as the other datasets. The
        # old code split the *full path* on "__", which picked the wrong segment
        # whenever a parent directory name contained "__".
        return path.basename(img_path).split("__")[-1].split(".")[0]

    def _should_augment(self, label):
        """Whether the augmentation applies to a sample with this label."""
        if label == 1:
            return self.augmentation_class in ("positive", "both")
        if label == 0:
            return self.augmentation_class in ("negative", "both")
        return False

    def __getitem__(self, idx):
        image_id, label = self._metadata.iloc[idx]
        images_dir = self._paths[image_id]
        all_images = [path.join(images_dir, image) for image in os.listdir(images_dir)]

        images_dict = {}
        for wavelength in self._wavelengths:
            images = [img for img in all_images if self._wavelength_of(img) == wavelength]
            if self._transform:
                images = self._transform(images)
            if self._augmentation and self._should_augment(label):
                images = self._augmentation(images)
            images_dict[wavelength] = images

        label = torch.tensor(label, dtype=torch.float32)
        return images_dict, label

    def set_transform(self, t):
        self._transform = t

    # def collate_fn(self, lst):
    #     # d = 
    #     pass

class CacheDataset(Dataset):
    """Dataset over cached samples laid out as ``<cache_dir>/<label>/<N>_<...>``.

    Each file is read with ``torch.load``, and its label is the integer name of
    its parent directory. Samples are ordered by the leading integer ``N`` of the
    file name; ties keep directory-listing order.
    """
    def __init__(self, cache_dir, transforms=None, augmentation=None, load_device=None):
        self.load_device = load_device
        self.base_dir = cache_dir
        self._transform = transforms
        self._augmentation = augmentation

        def _sample_number(rel_path):
            return int(rel_path.split("/")[-1].split("_")[0])

        entries = []
        for label_dir in os.listdir(self.base_dir):
            for fname in os.listdir(path.join(self.base_dir, label_dir)):
                entries.append(path.join(label_dir, fname))
        self.ordered_dirs = sorted(entries, key=_sample_number)

    def __getitem__(self, idx):
        rel_path = self.ordered_dirs[idx]
        # weights_only=False unpickles arbitrary objects: only point this at trusted caches.
        sample = torch.load(path.join(self.base_dir, rel_path), weights_only=False, map_location=self.load_device)

        if self._transform:
            sample = self._transform(sample)
        if self._augmentation:
            sample = self._augmentation(sample)

        label_value = int(rel_path.split("/")[0])
        return sample, torch.tensor(label_value, dtype=torch.float32)

    def get_labels(self):
        return [int(rel_path.split("/")[0]) for rel_path in self.ordered_dirs]

    def __len__(self):
        return len(self.ordered_dirs)

class ListDataset(Dataset):
    """Dataset view over an in-memory sequence of ``(data, label)`` pairs.

    Args:
        list: Indexable sequence of ``(data, label)`` tuples.
        transforms: Optional callable applied to ``data`` first.
        augmentation: Optional callable applied to ``data`` after ``transforms``.
        load_device: Stored for interface parity with the other datasets; unused here.
    """
    def __init__(self, list, transforms=None, augmentation=None, load_device=None):
        self.list = list
        self._transform = transforms
        self._augmentation = augmentation
        self.load_device = load_device

    def __len__(self):
        return len(self.list)

    def __getitem__(self, idx):
        sample, target = self.list[idx]
        # Apply transform then augmentation, skipping whichever is unset.
        for step in (self._transform, self._augmentation):
            if step:
                sample = step(sample)
        return sample, target

class SynthesizedDataset(Dataset):
    """Dataset of synthesized images for one wavelength, all sharing a single class.

    Items are the image *file paths* (``base_dir/wavelength/<file>``) passed
    through ``transform``/``augmentation``, paired with ``synthesized_class`` as a
    float32 tensor.

    Args:
        base_dir: Root directory containing one sub-directory per wavelength.
        wavelength: Name of the sub-directory to read.
        synthesized_class: Label attached to every sample.
        transform: Optional callable applied to the path first.
        augmentation: Optional callable applied after ``transform``.
        num_samples: ``"all"`` (or None) for every file, a fraction in (0, 1) of
            the files, or an integer count used as a slice bound.
    """
    def __init__(self, base_dir, wavelength, synthesized_class, transform=None, augmentation=None, num_samples="all"):
        self.base_dir = base_dir
        self.wavelength = wavelength
        self._transform = transform
        self._augmentation = augmentation
        self.synthesized_class = synthesized_class

        # Sorted so that a partial selection is reproducible across file systems.
        imgs_filename = sorted(os.listdir(path.join(base_dir, wavelength)))

        # Check the string sentinel first: comparing "all" with numbers raises TypeError.
        if num_samples is None or num_samples == "all":
            num_samples = len(imgs_filename)
        elif 0 < num_samples < 1:
            num_samples = int(num_samples * len(imgs_filename))
        else:
            num_samples = int(num_samples)

        self.imgs_filename = imgs_filename[:num_samples]

    def __getitem__(self, idx):
        data = path.join(self.base_dir, self.wavelength, self.imgs_filename[idx])

        if self._transform:
            data = self._transform(data)
        if self._augmentation:
            data = self._augmentation(data)
        return data, torch.tensor(self.synthesized_class, dtype=torch.float32)

    def __len__(self):
        return len(self.imgs_filename)

class MergedDatasets(Dataset):
    """Concatenate several datasets into one indexable dataset.

    Items are served in order: all of the first dataset, then the second, etc.
    Each wrapped dataset must expose ``_transform`` and ``_augmentation``
    attributes; they are collected into lists for introspection or logging.
    """
    def __init__(self, dataset, *datasets):
        self.datasets = [dataset, *datasets]
        starts = torch.cumsum(torch.tensor([len(d) for d in self.datasets]), dim=0).tolist()
        self._starts = [0, *starts[:-1]]  # global offset of each dataset's first item
        self._transform = [d._transform for d in self.datasets]
        self._augmentation = [d._augmentation for d in self.datasets]

    def index_router(self, idx):
        """Map a global index to ``(dataset_index, local_index)``.

        Negative indices count from the end, as for a Python sequence.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        requested = idx
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError(f"index {requested} out of range for MergedDatasets of length {total}")
        for d_idx, d in enumerate(self.datasets):
            size = len(d)
            if idx < size:
                return d_idx, idx
            idx -= size

    def __getitem__(self, idx):
        d_idx, local_idx = self.index_router(idx)
        return self.datasets[d_idx][local_idx]

    def __len__(self):
        return sum(len(d) for d in self.datasets)


In [None]:
# model.py
class Conv1DClassifier(nn.Module):
    """Plain 1-D CNN classifier: strided stem, four conv/BN/ReLU stages, global average pool, linear head.

    Args:
        in_channel: Number of input channels.
        num_classes: Width of the output logits.
    """
    def __init__(self, in_channel, num_classes=2):
        super().__init__()

        # Stem: conv(stride 2) + BN + ReLU + max-pool(stride 2).
        self.stem = nn.Sequential(
            nn.Conv1d(in_channel, 256, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=3, stride=2, padding=1),
        )

        # Body: length-preserving conv stages that widen 256 -> 512 -> 1024.
        stage_widths = ((256, 512), (512, 512), (512, 1024), (1024, 1024))
        body = []
        for width_in, width_out in stage_widths:
            body.extend([
                nn.Conv1d(width_in, width_out, kernel_size=3, padding=1),
                nn.BatchNorm1d(width_out),
                nn.ReLU(inplace=True),
            ])
        self.blocks = nn.Sequential(*body)

        # Head.
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Linear(1024, num_classes)

    def forward(self, x):
        pooled = self.pool(self.blocks(self.stem(x)))
        return self.classifier(pooled.squeeze(-1))

    @property
    def num_parameters(self):
        """Total number of parameter elements (trainable or not)."""
        return sum(map(torch.numel, self.parameters()))

class ResidualBlock1D(nn.Module):
    """Basic two-convolution 1-D residual block: ``relu(bn2(conv2(relu(bn1(conv1(x))))) + skip(x))``.

    Args:
        in_ch: Input channels.
        out_ch: Output channels.
        kernel_size: Convolution kernel size (padding ``kernel_size // 2``).
        stride: Stride of the first convolution (and of the projection shortcut).
    """
    def __init__(self, in_ch, out_ch, kernel_size=3, stride=1):
        super().__init__()
        self.conv1 = nn.Conv1d(in_ch, out_ch, kernel_size, stride=stride, padding=kernel_size//2, bias=False)
        self.bn1 = nn.BatchNorm1d(out_ch)
        self.conv2 = nn.Conv1d(out_ch, out_ch, kernel_size, stride=1, padding=kernel_size//2, bias=False)
        self.bn2 = nn.BatchNorm1d(out_ch)
        # The shortcut must match the main path in channels AND length. A strided
        # main path therefore needs a strided 1x1 projection, even when in_ch == out_ch.
        if in_ch != out_ch or stride != 1:
            self.skip = nn.Conv1d(in_ch, out_ch, 1, stride=stride)
        else:
            self.skip = nn.Identity()

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.skip(x)
        return F.relu(out)

class WideResNet1D(nn.Module):
    """Chain of ResidualBlock1D stages followed by global average pooling and a linear head.

    Args:
        in_channels: Channels of the input sequence.
        num_channels: Output width of each successive residual stage; the first
            entry is the width of the input block. The default list is only read.
        num_classes: Width of the output logits.
    """
    def __init__(self, in_channels=10, num_channels=[4096, 4096, 1024], num_classes=1):
        super().__init__()
        self.in_block = ResidualBlock1D(in_channels, num_channels[0])
        # One block per consecutive (width_in, width_out) pair.
        self.blocks = nn.ModuleList(
            ResidualBlock1D(width_in, width_out)
            for width_in, width_out in zip(num_channels[:-1], num_channels[1:])
        )
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(num_channels[-1], num_classes)

    def forward(self, x):
        hidden = self.in_block(x)
        for stage in self.blocks:
            hidden = stage(hidden)
        return self.fc(self.global_pool(hidden).squeeze(-1))

class Classifier:
    """Tabular classifier wrapper with a torch-tensor interface.

    Builds one of the estimators named in ``list_models()`` with fixed, fairly
    large hyper-parameters. ``fit`` can optionally rebalance the training set
    with SMOTE. Inputs are converted to NumPy on the way in and predictions are
    returned as torch tensors.

    Args:
        model_name: Which estimator to build (see ``list_models()``).
        use_smote: If True, ``fit`` oversamples the training data with SMOTE.
        verbose: Verbosity forwarded to estimators that accept one.
        random_state: Seed forwarded to the estimator and to SMOTE.

    Raises:
        ValueError: If ``model_name`` is not a supported name.
    """
    def __init__(self, model_name="random_forest", use_smote=True, verbose=0, random_state=42):
        self.model_name = model_name
        self.use_smote = use_smote
        self.verbose = verbose
        self.random_state = random_state

        # Instantiate model based on name with larger configurations
        if model_name == "random_forest":
            self.model = RandomForestClassifier(
                n_estimators=500,
                max_depth=30,
                min_samples_split=2,
                min_samples_leaf=1,
                max_features='sqrt',
                n_jobs=-1,
                random_state=self.random_state
            )
        elif model_name == "extra_trees":
            self.model = ExtraTreesClassifier(
                n_estimators=500,
                max_depth=30,
                min_samples_split=2,
                min_samples_leaf=1,
                max_features='sqrt',
                n_jobs=-1,
                random_state=self.random_state
            )
        elif model_name == "gradient_boosting":
            self.model = GradientBoostingClassifier(
                n_estimators=500,
                learning_rate=0.05,
                max_depth=7,
                min_samples_split=2,
                min_samples_leaf=1,
                subsample=0.8,
                random_state=self.random_state
            )
        elif model_name == "hist_gradient_boosting":
            self.model = HistGradientBoostingClassifier(
                max_iter=500,
                learning_rate=0.05,
                max_depth=15,
                min_samples_leaf=10,
                random_state=self.random_state,
                verbose=self.verbose
            )
        elif model_name == "ada_boost":
            self.model = AdaBoostClassifier(
                n_estimators=500,
                learning_rate=0.5,
                random_state=self.random_state
            )
        elif model_name == "bagging":
            self.model = BaggingClassifier(
                n_estimators=500,
                max_samples=0.8,
                max_features=0.8,
                n_jobs=-1,
                random_state=self.random_state
            )
        elif model_name == "xgboost":
            self.model = XGBClassifier(
                n_estimators=500,
                learning_rate=0.05,
                max_depth=10,
                min_child_weight=1,
                subsample=0.8,
                colsample_bytree=0.8,
                gamma=0,
                reg_alpha=0.1,
                reg_lambda=1,
                # NOTE(review): deprecated in newer XGBoost releases; confirm it is still accepted.
                use_label_encoder=False,
                eval_metric='logloss',
                tree_method='hist',
                n_jobs=-1,
                random_state=self.random_state,
                verbosity=self.verbose
            )
        elif model_name == "linear_svm":
            self.model = SVC(
                kernel='linear',
                C=1.0,
                probability=True,
                cache_size=1000,
                random_state=self.random_state,
                verbose=(self.verbose>0)
            )
        elif model_name == "lightgbm":
            self.model = LGBMClassifier(
                n_estimators=500,
                learning_rate=0.05,
                max_depth=15,
                num_leaves=63,
                min_child_samples=10,
                subsample=0.8,
                colsample_bytree=0.8,
                reg_alpha=0.1,
                reg_lambda=1,
                n_jobs=-1,
                random_state=self.random_state,
                verbose=self.verbose
            )
        elif model_name == "catboost":
            self.model = CatBoostClassifier(
                iterations=500,
                learning_rate=0.05,
                depth=10,
                l2_leaf_reg=3,
                subsample=0.8,
                random_strength=1,
                thread_count=-1,
                verbose=self.verbose,
                random_state=self.random_state
            )
        else:
            raise ValueError(f"Unknown model name: {model_name}")

    @staticmethod
    def _to_numpy(x):
        """Return ``x`` as a NumPy array; torch tensors are detached and moved to CPU first."""
        if torch.is_tensor(x):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    def fit(self, X, y, X_val=None, y_val=None):
        """Fit the wrapped estimator, optionally after SMOTE oversampling.

        Args:
            X, y: Training features and labels (torch tensors on any device,
                with or without grad, or array-likes).
            X_val, y_val: Optional validation data. It is passed as ``eval_set``
                for XGBoost / LightGBM / CatBoost models, and only when both are given.
        """
        X = self._to_numpy(X)
        y = self._to_numpy(y)
        if X_val is not None:
            X_val = self._to_numpy(X_val)
        if y_val is not None:
            y_val = self._to_numpy(y_val)

        if self.use_smote:
            X_res, y_res = SMOTE(random_state=self.random_state).fit_resample(X, y)
        else:
            X_res, y_res = X, y

        # Check for validation data first, so the boosting classes are only
        # looked up when an eval set is actually used.
        has_eval = X_val is not None and y_val is not None
        if has_eval and isinstance(self.model, (XGBClassifier, LGBMClassifier, CatBoostClassifier)):
            fit_kwargs = {"eval_set": [(X_val, y_val)]}
            # LightGBM >= 4.0 removed `verbose` from fit(); its verbosity is
            # already set through the constructor.
            if not isinstance(self.model, LGBMClassifier):
                fit_kwargs["verbose"] = self.verbose
            self.model.fit(X_res, y_res, **fit_kwargs)
        else:
            self.model.fit(X_res, y_res)

    def __call__(self, X, return_prob=True):
        """Predict for ``X``.

        Args:
            X: Features (torch tensor on any device, or array-like).
            return_prob: If True, return the positive-class probability
                (column 1 of ``predict_proba``); otherwise return ``predict`` output.

        Returns:
            torch.Tensor built from the estimator's NumPy output.

        Raises:
            ValueError: If probabilities are requested but the model lacks ``predict_proba``.
        """
        X = self._to_numpy(X)
        if not return_prob:
            return torch.from_numpy(np.asarray(self.model.predict(X)))
        if not hasattr(self.model, "predict_proba"):
            raise ValueError("Model does not support predict_proba")
        return torch.from_numpy(self.model.predict_proba(X)[:, 1])

    def save(self, path):
        """Pickle the fitted estimator to ``path``."""
        with open(path, "wb") as f:
            pickle.dump(self.model, f)

    def load(self, path):
        """Load an estimator pickled by ``save``.

        pickle.load can execute arbitrary code: only load files you trust.
        """
        with open(path, "rb") as f:
            self.model = pickle.load(f)

    @staticmethod
    def list_models():
        """Names accepted by the constructor's ``model_name``."""
        return ["random_forest", "extra_trees", "gradient_boosting", "hist_gradient_boosting", 
                "ada_boost", "bagging", "xgboost", "linear_svm", "lightgbm", "catboost"]

In [None]:
# losses.py
class HybridLossFunction(nn.Module):
    """Weighted sum of several loss callables.

    Args:
        function, *functions: Loss callables taking ``(predictions, label)``.
        weights: One multiplier per function; defaults to equal weights ``1/n``.
    """
    def __init__(self, function, *functions, weights=None):
        super().__init__()
        functions = [function, *functions]
        if weights is None:
            weights = [1 / len(functions)] * len(functions)
        assert len(functions) == len(weights), "need exactly one weight per loss function"
        self.functions = functions
        self.weights = weights
        # Register nn.Module components as submodules, so that .to()/.cuda(),
        # .train()/.eval() and state_dict() reach them (e.g. a pos_weight buffer).
        # Plain callables remain only in `self.functions`.
        self._loss_modules = nn.ModuleList(fn for fn in functions if isinstance(fn, nn.Module))

    def forward(self, predictions, label):
        total_loss = 0
        for function, weight in zip(self.functions, self.weights):
            total_loss += function(predictions, label) * weight
        return total_loss

    def __str__(self):
        fns = ",\n".join([str(fn) + " * " + str(w) for w, fn in zip(self.weights, self.functions)])
        return f'HybridLossFunction(\n{fns}\n)'
    def __repr__(self):
        return str(self)

class InverseFreqWeightedBCE(nn.Module):
    """``BCEWithLogitsLoss`` whose positive term is scaled by ``weights[1] / weights[0]``.

    Args:
        weights: Class weights ``[w_neg, w_pos]`` (list, tuple or tensor).
        *args, **kwargs: Forwarded to ``nn.BCEWithLogitsLoss``; must not include ``pos_weight``.
    """
    def __init__(self, weights, *args, **kwargs):
        super().__init__()

        # as_tensor + clone avoids the copy-construct warning when `weights` is already a tensor.
        self.class_weights = torch.as_tensor(weights, dtype=torch.float32).detach().clone()

        self.bce = nn.BCEWithLogitsLoss(*args, **kwargs, pos_weight=self.class_weights[1] / self.class_weights[0])

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        return self.bce(logits, targets)

    def __str__(self):
        return (f"InverseFreqWeightedBCE(weights={self.class_weights.tolist()}, "
                f"pos_weight={self.bce.pos_weight.item()}, reduction='{self.bce.reduction}')")
    def __repr__(self):
        return str(self)

class AsymmetricLoss(nn.Module):
    """Binary cross-entropy with separate focusing exponents for positives and negatives.

    Per element: ``-(t * (1-p)^gamma_pos * log p + (1-t) * p^gamma_neg * log(1-p))``,
    where ``p = sigmoid(logits)`` clamped to ``[eps, 1 - eps]``.

    Args:
        gamma_pos: Focusing exponent for positive targets.
        gamma_neg: Focusing exponent for negative targets.
        eps: Clamp bound for probabilities and for ``1 - p`` inside the log.
        reduction: 'mean', 'sum', or 'none'.
    """
    def __init__(self, gamma_pos=1.0, gamma_neg=4.0, eps=1e-8, reduction='mean'):
        super().__init__()
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.eps = eps
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        probs = torch.sigmoid(logits).clamp(self.eps, 1 - self.eps)

        pos_loss = targets * (1 - probs) ** self.gamma_pos * torch.log(probs)
        # In float32, 1 - 1e-8 rounds to 1.0, so `probs` can still be exactly 1.
        # Clamp 1 - probs too; otherwise log(0) = -inf, and 0 * -inf gives NaN for confident positives.
        neg_loss = (1 - targets) * probs ** self.gamma_neg * torch.log((1 - probs).clamp(min=self.eps))

        loss = - (pos_loss + neg_loss)
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

    def __str__(self):
        return f"AsymmetricLoss(gamma_pos={self.gamma_pos}, gamma_neg={self.gamma_neg})"
    def __repr__(self):
        return str(self)

class FocalLoss(nn.Module):
    """
    Focal Loss for addressing class imbalance.

    Reference: Lin et al. "Focal Loss for Dense Object Detection" (2017)

    Args:
        weights: Class weights ``[w_neg, w_pos]``. Only used when ``alpha`` is
                 None, in which case ``alpha = w_neg / (w_neg + w_pos)``.
        alpha: One of:
               - a number: weight for positive targets (negatives get ``1 - alpha``);
               - a ``[alpha_neg, alpha_pos]`` pair of per-class weights;
               - a tensor broadcast against the per-element loss.
               None derives the number from ``weights``. Single-element tensors
               and NumPy scalars are treated as numbers.
        gamma: Focusing parameter for modulating loss (default: 2.0)
        reduction: 'mean', 'sum', or 'none'
    """
    def __init__(self, weights=(0.25, 0.75), alpha=None, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        # Explicit None check: `alpha or ...` silently discarded a legitimate alpha=0.
        if alpha is None:
            alpha = weights[0] / sum(weights)
        # Collapse scalar tensors / NumPy scalars to Python numbers, so they take the
        # per-class (alpha, 1 - alpha) branch instead of acting as a uniform multiplier.
        if torch.is_tensor(alpha) and alpha.numel() == 1:
            alpha = alpha.item()
        elif isinstance(alpha, np.generic):
            alpha = alpha.item()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        Args:
            inputs: Predicted logits of shape (N, *) 
            targets: Ground truth labels of shape (N, *)
        """
        # Apply sigmoid to get probabilities
        probs = torch.sigmoid(inputs)

        # Calculate binary cross entropy
        bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')

        # Calculate focal term: (1 - p_t)^gamma
        p_t = probs * targets + (1 - probs) * (1 - targets)
        focal_weight = (1 - p_t) ** self.gamma

        # Apply alpha weighting
        if self.alpha is None:
            focal_loss = focal_weight * bce_loss
        else:
            if isinstance(self.alpha, (float, int)):
                alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
            elif isinstance(self.alpha, (list, tuple)) and len(self.alpha) == 2:
                alpha_neg, alpha_pos = self.alpha
                alpha_t = alpha_pos * targets + alpha_neg * (1 - targets)
            else:
                alpha_t = torch.as_tensor(self.alpha, dtype=bce_loss.dtype, device=bce_loss.device)
            focal_loss = alpha_t * focal_weight * bce_loss

        # Apply reduction
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

    def __str__(self):
        return f'FocalLoss(alpha={self.alpha}, gamma={self.gamma}, reduction=\'{self.reduction}\')'
    def __repr__(self):
        return str(self)

class HardNegativeMiningLoss(nn.Module):
    """
    Binary cross-entropy with online hard-negative mining.

    Every positive sample contributes, scaled by ``pos_weight``. Of the negatives,
    only the ``int(num_pos * neg_ratio)`` with the largest per-sample loss are
    averaged. If that count is 0, or the batch has no positives, all negatives
    are averaged instead.
    """
    def __init__(self, neg_ratio=3.0, pos_weight=16.0):
        """
        Args:
            neg_ratio: Hard negatives kept per positive (3.0 keeps a 3:1 ratio).
            pos_weight: Multiplier applied to each positive sample's loss.
        """
        super().__init__()
        self.neg_ratio = neg_ratio
        self.pos_weight = pos_weight

    def forward(self, logits, targets):
        """
        Args:
            logits: Raw model outputs, ``[batch]`` or ``[batch, 1]``.
            targets: Labels in {0, 1} with a matching shape.

        Returns:
            Scalar loss tensor: mean weighted positive loss + mean kept-negative loss.
        """
        logits = logits.squeeze()
        targets = targets.squeeze().float()
        device = logits.device

        per_sample = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        pos_losses = per_sample[targets == 1]
        neg_losses = per_sample[targets == 0]
        n_pos = pos_losses.numel()
        n_neg = neg_losses.numel()

        # Nothing to learn from in this batch.
        if n_pos == 0 and n_neg == 0:
            return torch.tensor(0.0, device=device, requires_grad=True)

        if n_pos:
            pos_term = (pos_losses * self.pos_weight).mean()
        else:
            pos_term = torch.tensor(0.0, device=device)

        if n_neg == 0:
            neg_term = torch.tensor(0.0, device=device)
        else:
            # Hard-negative budget; zero when there are no positives.
            k = min(int(n_pos * self.neg_ratio), n_neg) if n_pos else 0
            neg_term = torch.topk(neg_losses, k).values.mean() if k > 0 else neg_losses.mean()

        return pos_term + neg_term

    def __str__(self):
        return f'HardNegativeMiningLoss(neg_ratio={self.neg_ratio}, pos_weight={self.pos_weight})'

    def __repr__(self):
        return str(self)

class AsymmetricFocalLoss(nn.Module):
    """Focal loss with separate focusing exponents and an alpha balance for positives vs. negatives.

    Per element: ``w_alpha * w_focal * BCE``, where
    - ``w_focal = (1-p)^gamma_pos`` for targets == 1 and ``p^gamma_neg`` otherwise;
    - ``w_alpha = alpha`` for targets == 1 and ``1 - alpha`` otherwise;
    - ``p = sigmoid(logits)`` clamped to ``[epsilon, 1 - epsilon]``.

    Args:
        gamma_pos: Focusing exponent for positive targets.
        gamma_neg: Focusing exponent for negative targets.
        alpha: Weight of positive targets (negatives get ``1 - alpha``).
        reduction: 'mean', 'sum', or 'none'.
        epsilon: Probability clamp bound.
    """
    def __init__(self, gamma_pos=1.0, gamma_neg=4.0, alpha=0.94, reduction='mean', epsilon=1e-7):
        super().__init__()
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.alpha = alpha
        self.reduction = reduction
        self.epsilon = epsilon

    def forward(self, logits, targets):
        logits = logits.squeeze()
        targets = targets.squeeze().float()
        probs = torch.sigmoid(logits)

        probs = torch.clamp(probs, self.epsilon, 1 - self.epsilon)

        bce_loss = -(targets * torch.log(probs) + (1 - targets) * torch.log(1 - probs))

        focal_weight = torch.where(
            targets == 1,
            (1 - probs) ** self.gamma_pos,  # Positive sample modulation
            probs ** self.gamma_neg          # Negative sample modulation
        )

        alpha_weight = torch.where(
            targets == 1,
            torch.tensor(self.alpha, device=targets.device, dtype=targets.dtype),
            torch.tensor(1 - self.alpha, device=targets.device, dtype=targets.dtype)
        )

        focal_loss = alpha_weight * focal_weight * bce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:  # 'none'
            return focal_loss

    def __str__(self):
        # Include the class name, so the logged component string identifies the loss.
        return (f"AsymmetricFocalLoss(gamma_pos={self.gamma_pos}, gamma_neg={self.gamma_neg}, "
                f"alpha={self.alpha}, epsilon={self.epsilon}, reduction=\'{self.reduction}\')")
    def __repr__(self):
        return str(self)

class LogitAdjustedBCE(nn.Module):
    """
    Logit-Adjusted Binary Cross Entropy Loss for long-tailed recognition.

    Adds ``tau * log(pos_prior / neg_prior)`` to every logit before computing
    ``binary_cross_entropy_with_logits``.

    Reference: Menon et al. "Long-tail learning via logit adjustment" (2021)

    Args:
        weights: Class frequencies or priors ``[neg, pos]``; normalised to sum
                 to 1 to obtain ``neg_prior`` and ``pos_prior``.
        tau: Temperature parameter for adjustment (default: 1.0)
        reduction: 'mean', 'sum', or 'none'
    """
    def __init__(self, weights=(0.5, 0.5), tau=1.0, reduction='mean'):
        super(LogitAdjustedBCE, self).__init__()
        total = sum(weights)
        self.neg_prior = weights[0] / total
        self.pos_prior = weights[1] / total
        self.tau = tau
        self.reduction = reduction

        # Log prior ratio. A non-persistent buffer lets .to(device) move it without
        # adding anything to state_dict().
        self.register_buffer(
            "logit_adjustment",
            torch.log(torch.tensor(self.pos_prior / self.neg_prior)),
            persistent=False,
        )

    def forward(self, inputs, targets):
        """
        Args:
            inputs: Predicted logits of shape (N, *)
            targets: Ground truth labels of shape (N, *)
        """
        # Covers the case where the loss module itself was never moved to the inputs' device.
        if self.logit_adjustment.device != inputs.device:
            self.logit_adjustment = self.logit_adjustment.to(inputs.device)

        # Adjust logits based on class priors
        adjusted_logits = inputs + self.tau * self.logit_adjustment

        return F.binary_cross_entropy_with_logits(
            adjusted_logits, targets, reduction=self.reduction
        )

    def __str__(self):
        return (f'LogitAdjustedBCE(neg_prior={self.neg_prior}, pos_prior={self.pos_prior}, '
                f'tau={self.tau}, reduction=\'{self.reduction}\')')
    def __repr__(self):
        return str(self)

In [None]:
# train.py
class Trainer:
    def __init__(
        self,
        title,
        model,
        optim,
        loss_fn,
        eval_metrics,
        train_loader,
        val_loader,
        test_loader=None,
        checkpointing=5,
        lr_scheduler=None,
        lr_step_frequency=200,
        save_model_style="instance", #state, instance
        accumulate_gradient=1,
        threshold=0.5,
        dynamic_thresholding=True,
        dynamic_thresholding_metric="accuracy",
        progress_bar_update="mean",
        device="cpu"):

        self.title = title
        self.model = model.to(device)
        self.optim = optim
        self.loss_fn = loss_fn
        self.eval_metrics = eval_metrics
        self.accumulate_gradient = accumulate_gradient
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.threshold = threshold
        self.dynamic_thresholding = dynamic_thresholding
        self.dynamic_thresholding_metric = dynamic_thresholding_metric
        self.progress_bar_update = progress_bar_update # laval, sum, avg
        self.checkpointing = checkpointing
        self.lr_scheduler = lr_scheduler
        self.lr_step_frequency = lr_step_frequency
        self.save_model_style = save_model_style
        self.device = device
        self.training_epoch = 1
        self.validating_epoch = 1
        self.lr_scheduler_step_counter = 1
        self._last_dynamic_threshold = None

        # initiate directories
        self.title_dir = title
        self.checkpoint_dir = os.path.join(title, "checkpoints")
        self.train_records_dir = os.path.join(title, "train_records")
        self.val_records_dir = os.path.join(title, "val_records")
        self.test_records_dir = os.path.join(title, "test_records")
        self.prediction_records_dir = os.path.join(title, "prediction_records")

        os.makedirs(self.title_dir, exist_ok=True)
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        os.makedirs(self.train_records_dir, exist_ok=True)
        os.makedirs(self.val_records_dir, exist_ok=True)
        os.makedirs(self.test_records_dir, exist_ok=True)
        os.makedirs(self.prediction_records_dir, exist_ok=True)

        os.makedirs(path.join(self.prediction_records_dir, "train"), exist_ok=True)
        os.makedirs(path.join(self.prediction_records_dir, "val"), exist_ok=True)
        os.makedirs(path.join(self.prediction_records_dir, "test"), exist_ok=True)
        os.makedirs(path.join(self.train_records_dir, "fixed"), exist_ok=True)
        os.makedirs(path.join(self.train_records_dir, "dynamic"), exist_ok=True)
        os.makedirs(path.join(self.val_records_dir, "fixed"), exist_ok=True)
        os.makedirs(path.join(self.val_records_dir, "dynamic"), exist_ok=True)

        write_json({
            "model" : str(model),
            "optimizer" : str(optim),
            "loss_fn" : str(loss_fn),
            "preprocessor" : str(train_loader.dataset._transform),
            "augmentation" : str(train_loader.dataset._augmentation),
        }, path.join(self.title_dir, "components.json"))

    def fit(self, epochs, train_verbose=False, val_verbose=False, summary_verbose=False, validating_frequency=5):
        print("Training initiated...")
        for i in range(epochs - self.training_epoch):
            self.train(train_verbose)

            if (i + 1) % validating_frequency == 0:
                self.val(val_verbose)

            if (i + 1) % self.checkpointing == 0:
                self.save([True, True], True)
                self.save_performance(summary_verbose)

    def train(self, verbose=False):
        self.model.zero_grad()
        self.model.train()
        progress_bar = tqdm(enumerate(self.train_loader), total=len(self.train_loader), desc=f"Training Epoch {self.training_epoch}...")
        
        labels_list = []
        predictions = []
        loss_list = []
        records = {}

        for i, (images, labels) in progress_bar:
            images = images.to(self.device)
            labels = labels.to(self.device)

            out = self.model(images).flatten()
            labels = labels.flatten()
            loss = self.loss_fn(out, labels)
            loss.backward()

            labels_list.append(labels.detach())
            predictions.append(out.detach().sigmoid())
            loss_list.append(loss.detach())

            if (i + 1) % self.accumulate_gradient == 0 or i == len(self.train_loader) - 1 :
                self.optim.step()
                self.model.zero_grad()
                self.lr_scheduler_step_counter += 1

            if self.lr_scheduler and self.lr_scheduler_step_counter % self.lr_step_frequency == 0:
                self.lr_scheduler.step()

            if self.progress_bar_update == "mean":
                update = torch.stack(loss_list).mean()
            elif self.progress_bar_update == "sum":
                update = torch.stack(loss_list).sum()
            elif self.progress_bar_update == "laval":
                update = loss_list[-1].item()

            if self.lr_scheduler:
                progress_bar.set_postfix({"loss":update.item(), "lr": self.lr_scheduler.get_last_lr()[0]})
            else:
                progress_bar.set_postfix({"loss":update.item()})

        labels = torch.cat(labels_list).detach().cpu()
        predictions = torch.cat(predictions).detach().cpu()
        losses = torch.stack(loss_list)

        fixed_scores, fixed_threshold = evaluate(labels, predictions, self.eval_metrics, self.threshold)
        fixed_scores["loss"] = losses.mean().item()
        fixed_scores["threshold"] = fixed_threshold
        fixed_scores["epoch"] = self.training_epoch
        if self.lr_scheduler:
            fixed_scores["learning_rate"] = self.lr_scheduler.get_last_lr()[0]
        records["fixed_threshold"] = fixed_scores
        write_json(fixed_scores, os.path.join(self.train_records_dir, "fixed", f"records_{self.training_epoch}.json"))

        if self.dynamic_thresholding:
            dynamic_scores, dynamic_threshold =  evaluate(labels, predictions, self.eval_metrics, self.dynamic_thresholding_metric)
            dynamic_scores["loss"] = losses.mean().item()
            dynamic_scores["threshold"] = dynamic_threshold
            dynamic_scores["epoch"] = self.training_epoch
            if self.lr_scheduler:
                dynamic_scores["learning_rate"] = self.lr_scheduler.get_last_lr()[0]
            records["dynamic_threshold"] = dynamic_scores
            write_json(dynamic_scores, os.path.join(self.train_records_dir, "dynamic", f"records_{self.training_epoch}.json"))

        if verbose:
            summarize_performance_table(records, width=20)

        prediction_records = {
            "predictions" : predictions.tolist(),
            "labels" : labels.tolist()
        }
        write_json(prediction_records, path.join(self.prediction_records_dir, "train", f"epoch_{self.training_epoch}_predictions.json"))

        self.training_epoch += 1

    def eval_model(self, loader, verbose, threshold):
        validating_epoch = self.validating_epoch
        val_loader = self.val_loader
        dynamic_thresholding = self.dynamic_thresholding
        _threshold = self.threshold
        self.threshold = threshold
        self.dynamic_thresholding = False
        self.validating_epoch = -1
        self.val_loader = loader
        try:
            self.val(verbose)
        except:
            self.val_loader = val_loader
            self.validating_epoch = validating_epoch
            self.dynamic_thresholding = dynamic_thresholding
            self.threshold = _threshold
        finally:
            self.val_loader = val_loader
            self.validating_epoch = validating_epoch
            self.dynamic_thresholding = dynamic_thresholding
            self.threshold = _threshold
            
        
    
    @torch.no_grad()
    def val(self, verbose=False):
        """Run one validation pass and persist its score and prediction records.

        Iterates ``self.val_loader``, collecting sigmoid probabilities, labels and
        per-batch losses. Scores are computed with ``evaluate`` at the fixed
        ``self.threshold`` and, when ``self.dynamic_thresholding`` is set, with
        ``self.dynamic_thresholding_metric`` passed in place of a threshold (the
        threshold ``evaluate`` returns is cached in ``self._last_dynamic_threshold``).
        Score records are written to ``<val_records_dir>/{fixed,dynamic}/`` and the
        raw predictions to ``<prediction_records_dir>/val/``; finally
        ``self.validating_epoch`` is incremented.

        Args:
            verbose: if True, print a summary table of the computed scores.

        Raises:
            ValueError: if ``self.progress_bar_update`` is not ``"mean"``, ``"sum"``
                or ``"laval"`` (previously this surfaced as an UnboundLocalError on
                the first batch).
        """
        mode = self.progress_bar_update
        if mode not in ("mean", "sum", "laval"):
            raise ValueError(
                f"progress_bar_update must be 'mean', 'sum' or 'laval', got {mode!r}"
            )

        self.model.eval()
        progress_bar = tqdm(self.val_loader, total=len(self.val_loader), desc=f"validating Epoch {self.validating_epoch}...")

        labels_list = []
        predictions = []
        loss_list = []
        records = {}
        running_loss = 0.0  # python-side sum of batch losses, used only for the progress bar

        for images, labels in progress_bar:
            images = images.to(self.device)
            labels = labels.to(self.device)

            out = self.model(images).flatten()
            labels = labels.flatten()
            loss = self.loss_fn(out, labels)

            labels_list.append(labels.detach())
            predictions.append(out.detach().sigmoid())
            loss_list.append(loss.detach())

            # Keep a running total instead of re-stacking every past loss each batch (was O(n^2)).
            batch_loss = loss.item()
            running_loss += batch_loss
            if mode == "mean":
                update = running_loss / len(loss_list)
            elif mode == "sum":
                update = running_loss
            else:  # "laval": show only the most recent batch loss
                update = batch_loss

            progress_bar.set_postfix({"loss": update})

        labels = torch.cat(labels_list).detach().cpu()
        predictions = torch.cat(predictions).detach().cpu()
        mean_loss = torch.stack(loss_list).mean().item()

        fixed_scores, fixed_threshold = evaluate(labels, predictions, self.eval_metrics, self.threshold)
        fixed_scores["loss"] = mean_loss
        fixed_scores["threshold"] = fixed_threshold
        fixed_scores["epoch"] = self.validating_epoch
        records["fixed_threshold"] = fixed_scores
        write_json(fixed_scores, os.path.join(self.val_records_dir, "fixed", f"records_{self.validating_epoch}.json"))

        if self.dynamic_thresholding:
            dynamic_scores, dynamic_threshold = evaluate(labels, predictions, self.eval_metrics, self.dynamic_thresholding_metric)
            dynamic_scores["loss"] = mean_loss
            dynamic_scores["threshold"] = dynamic_threshold
            dynamic_scores["epoch"] = self.validating_epoch
            records["dynamic_threshold"] = dynamic_scores
            self._last_dynamic_threshold = dynamic_threshold
            write_json(dynamic_scores, os.path.join(self.val_records_dir, "dynamic", f"records_{self.validating_epoch}.json"))

        if verbose:
            summarize_performance_table(records, width=20)

        prediction_records = {
            "predictions": predictions.tolist(),
            "labels": labels.tolist()
        }
        write_json(prediction_records, path.join(self.prediction_records_dir, "val", f"epoch_{self.validating_epoch}_predictions.json"))

        try:
            self.validating_epoch += 1
        except TypeError:
            # Best-effort: a non-numeric counter is reported rather than aborting the run.
            print("validating_epoch is set to non-integer")

    @torch.no_grad()
    def test(self, verbose=False):
        self.model.eval()
        progress_bar = tqdm(self.test_loader, total=len(self.test_loader), desc=f"Testing...")

        labels_list = []
        predictions = []
        loss_list = []
        records = {}

        for images, labels in progress_bar:
            images = images.to(self.device)
            labels = labels.to(self.device)

            out = self.model(images).flatten()
            labels = labels.flatten()
            loss = self.loss_fn(out, labels)

            labels_list.append(labels.detach())
            predictions.append(out.detach().sigmoid())
            loss_list.append(loss.detach())

            if self.progress_bar_update == "mean":
                update = torch.stack(loss_list).mean()
            elif self.progress_bar_update == "sum":
                update = torch.stack(loss_list).sum()
            elif self.progress_bar_update == "laval":
                update = loss_list[-1]

            progress_bar.set_postfix({"loss":update.item()})

        labels = torch.cat(labels_list).detach().cpu()
        predictions = torch.cat(predictions).detach().cpu()
        losses = torch.stack(loss_list)

        fixed_scores, fixed_threshold = evaluate(labels, predictions, self.eval_metrics, self.threshold)
        fixed_scores["loss"] = losses.mean().item()
        fixed_scores["threshold"] = fixed_threshold
        fixed_scores["epoch"] = self.validating_epoch
        records["fixed_threshold"] = fixed_scores
        write_json(fixed_scores, os.path.join(self.test_records_dir, f"fixed_record.json"))
            
        if verbose:
            summarize_performance_table(records, width=20)

        prediction_records = {
            "predictions" : predictions.tolist(),
            "labels" : labels.tolist()
        }

        write_json(prediction_records, path.join(self.prediction_records_dir, "test", f"predictions.json"))

    def save_weights(self, step_back=False):
        """Write a checkpoint of the model and optimizer to ``self.checkpoint_dir``.

        The file is named ``ckpt_<epoch>`` where ``<epoch>`` is
        ``self.training_epoch``, minus one when ``step_back`` is truthy.

        * ``save_model_style == "state"``: the two ``state_dict()``s are saved to
          ``ckpt_<epoch>.pth``.
        * ``save_model_style == "instance"``: the model and optimizer objects
          themselves are pickled to ``ckpt_<epoch>.pt``.

        Raises:
            ValueError: for any other ``save_model_style`` (previously nothing was
                saved and no error was reported).
        """
        epoch = self.training_epoch - int(step_back)
        ckpt = {
            "model": self.model,
            "optim": self.optim
        }
        if self.save_model_style == "state":
            torch.save({k: v.state_dict() for (k, v) in ckpt.items()}, path.join(self.checkpoint_dir, f"ckpt_{epoch}.pth"))
        elif self.save_model_style == "instance":
            torch.save(ckpt, path.join(self.checkpoint_dir, f"ckpt_{epoch}.pt"))
        else:
            raise ValueError(
                f"save_model_style must be 'state' or 'instance', got {self.save_model_style!r}"
            )

    def save_config(self, step_back=(False, False)):
        """Write the trainer's settings and resume state to ``<title_dir>/config.json``.

        ``"weights"`` points at the highest-numbered ``ckpt_<epoch>.*`` file in
        ``self.checkpoint_dir``, or is ``None`` when there is none. ``step_back`` is
        a pair of flags: a truthy first/second entry subtracts one from the stored
        ``training_epoch`` / ``validating_epoch`` respectively.
        """
        def checkpoint_epoch(name):
            # "ckpt_12.pth" -> 12
            return int(name.split("_")[1].split(".")[0])

        # List the directory once and consider only checkpoint files, so stray entries
        # (e.g. ".ipynb_checkpoints") no longer crash the epoch parse.
        checkpoints = [name for name in os.listdir(self.checkpoint_dir) if name.startswith("ckpt_")]
        latest_weights = path.join(self.checkpoint_dir, max(checkpoints, key=checkpoint_epoch)) if checkpoints else None

        # torch.device is not JSON-serialisable; store its string form instead.
        device = str(self.device) if isinstance(self.device, torch.device) else self.device

        config = {
            "title": self.title,
            "weights": latest_weights,
            "accumulate_gradient": self.accumulate_gradient,
            "threshold": self.threshold,
            "dynamic_thresholding": self.dynamic_thresholding,
            "dynamic_thresholding_metric": self.dynamic_thresholding_metric,
            "progress_bar_update": self.progress_bar_update,
            "checkpointing": self.checkpointing,
            "device": device,
            "training_epoch": self.training_epoch - int(step_back[0]),
            "validating_epoch": self.validating_epoch - int(step_back[1]),
            "save_model_style": self.save_model_style
        }
        write_json(config, path.join(self.title_dir, "config.json"))

    @staticmethod
    def from_config(directory, instances=None, back_step=False):
        """Rebuild a ``Trainer`` from a ``config.json`` written by ``save_config``.

        Args:
            directory: path of the ``config.json`` file itself.
            instances: constructor arguments that are not stored in the config.
                Must contain ``"weights"``: a dict holding the objects the
                checkpoint entries belong to (keys as saved, e.g. ``"model"`` and
                ``"optim"``). With ``save_model_style == "state"`` each object has
                ``load_state_dict`` called on it; with ``"instance"`` it is
                replaced by the pickled object. Every other entry is passed to
                ``Trainer`` unchanged. The caller's dicts are not modified
                (previously ``"weights"`` was deleted from them).
            back_step: if True, resume from ``training_epoch - 1``.

        Returns:
            The reconstructed ``Trainer``.
        """
        config = read_json(directory)

        # Shallow copies so popping keys / swapping instances never touches the caller's dicts.
        instances = dict(instances or {})
        weights_dict = dict(instances.pop("weights", {}))
        checkpoint_path = config.pop("weights")
        training_epoch = config.pop("training_epoch")
        validating_epoch = config.pop("validating_epoch")

        if checkpoint_path is not None:
            # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
            checkpoint = torch.load(checkpoint_path, weights_only=False)
            for key, weight in checkpoint.items():
                if config["save_model_style"] == "state":
                    weights_dict[key].load_state_dict(weight)
                elif config["save_model_style"] == "instance":
                    weights_dict[key] = weight

        trainer = Trainer(**weights_dict, **config, **instances)
        trainer.training_epoch = training_epoch - int(back_step)
        trainer.validating_epoch = validating_epoch

        return trainer

    def save(self, config_step_back=[False, False], weights_step_back=False):
        """Persist the current checkpoint, then the config, and report success.

        Args:
            config_step_back: forwarded to ``save_config`` as its ``step_back``.
            weights_step_back: forwarded to ``save_weights`` as its ``step_back``.
        """
        steps = (
            (self.save_weights, weights_step_back),
            (self.save_config, config_step_back),
        )
        for persist, step_back in steps:
            persist(step_back)
        print("model weights and config has been saved successfully")

    def read_performance(self):
        """Load every per-epoch score record written during training and validation.

        Returns:
            dict with keys ``train_fixed``, ``train_dynamic``, ``val_fixed`` and
            ``val_dynamic`` (in that order). Each value is the list of records read
            from ``<records dir>/<mode>/``, sorted by their ``"epoch"`` field.
            Errors from ``os.listdir`` (e.g. a missing directory) propagate.
        """
        def load_sorted(records_dir, mode):
            folder = path.join(records_dir, mode)
            loaded = [read_json(path.join(folder, name)) for name in os.listdir(folder)]
            loaded.sort(key=lambda record: record["epoch"])
            return loaded

        performance = {}
        for split, records_dir in (("train", self.train_records_dir), ("val", self.val_records_dir)):
            for mode in ("fixed", "dynamic"):
                performance[f"{split}_{mode}"] = load_sorted(records_dir, mode)
        return performance

    def save_performance(self, verbose=False):
        """Render every series from ``read_performance`` to ``<title_dir>/<key>_records_summary.png``.

        Plotting is best-effort: a series that fails to visualize is reported (with
        the underlying error) and skipped so the remaining plots are still saved.
        Errors raised by ``read_performance`` itself propagate.

        Args:
            verbose: forwarded to ``visualize_dict_list`` as ``show_summary``.
        """
        for key, performance in self.read_performance().items():
            label = key.replace("_", " ")
            try:
                visualize_dict_list(
                    performance,
                    path.join(self.title_dir, f"{key}_records_summary.png"),
                    show_summary=verbose,
                    title=f"{label} threshold summary",
                    exclude=["epoch"]
                    )
            except Exception as exc:
                # Was a bare `except:`, which also swallowed KeyboardInterrupt and hid
                # genuine plotting bugs behind the "no records" message.
                print(f"no records for {label} ({type(exc).__name__}: {exc})")

    def show_performance(self, threshold_mode="both", iteration_mode="both"):
        """Display (without saving) the record series selected by the two filters.

        Args:
            threshold_mode: ``"fixed"``, ``"dynamic"`` or ``"both"``.
            iteration_mode: ``"train"``, ``"val"`` or ``"both"``.
        """
        assert threshold_mode in ["dynamic", "fixed", "both"]
        assert iteration_mode in ["train", "val", "both"]

        # An empty pattern is a substring of every key, so "both" disables that filter.
        patterns = ["" if mode == "both" else mode for mode in (threshold_mode, iteration_mode)]

        for key, performance in self.read_performance().items():
            if not all(pattern in key for pattern in patterns):
                continue
            visualize_dict_list(performance, None, show_summary=True, exclude=["epoch"])

    def to(self, device):
        """Record ``device`` as the device later batches are moved to.

        Only the attribute is updated. NOTE(review): this method does not move
        ``self.model`` itself; confirm that happens elsewhere before training.
        """
        setattr(self, "device", device)

class ClassifierTrainer:
    """Fit, validate and test a non-neural classifier on in-memory tensors.

    From how it is used here, ``model`` must provide ``fit(x, y)``, be callable on
    a 2-D feature tensor returning one score per sample, and expose
    ``save(path)`` / ``load(path)``. NOTE(review): the wrapper type is defined
    elsewhere; confirm. Every split is flattened to ``(num_samples, -1)``.

    Files written under ``title``:
        train_records/, val_records/, test_records/   fixed/dynamic score JSONs
        prediction_records/                           raw predictions and labels
        model.pk, config.json, components.json
    """

    def __init__(
        self,
        title,
        model,
        eval_metrics,
        train_dataset,
        val_dataset,
        test_dataset=None,
        threshold=0.5,
        dynamic_thresholding=True,
        dynamic_thresholding_metric="accuracy"
        ):
        """
        Args:
            title: output directory for every file written by this trainer.
            model: classifier wrapper (see class docstring).
            eval_metrics: mapping of metric name -> scoring function, passed to ``evaluate``.
            train_dataset, val_dataset, test_dataset: either a path to a
                ``torch.save``'d ``(x, y)`` pair or an iterable of ``(sample, label)``
                pairs. ``test_dataset`` may be None, in which case ``test()`` raises
                and ``fit()`` skips testing.
            threshold: fixed decision threshold.
            dynamic_thresholding: also score with ``dynamic_thresholding_metric``.
            dynamic_thresholding_metric: metric name passed to ``evaluate`` in place
                of a threshold when dynamic thresholding is enabled.
        """
        self.title = title
        self.model = model
        self.eval_metrics = eval_metrics
        self.threshold = threshold
        self.dynamic_thresholding = dynamic_thresholding
        self.dynamic_thresholding_metric = dynamic_thresholding_metric

        # initiate directories
        self.title_dir = title
        self.save_model_dir = os.path.join(title, "model.pk")
        self.train_records_dir = os.path.join(title, "train_records")
        self.val_records_dir = os.path.join(title, "val_records")
        self.test_records_dir = os.path.join(title, "test_records")
        self.prediction_records_dir = os.path.join(title, "prediction_records")

        for directory in (
            self.title_dir,
            self.train_records_dir,
            self.val_records_dir,
            self.test_records_dir,
            self.prediction_records_dir,
        ):
            os.makedirs(directory, exist_ok=True)

        self.x_train, self.y_train = self._load_split(train_dataset)
        self.x_val, self.y_val = self._load_split(val_dataset)
        if test_dataset is None:
            self.x_test, self.y_test = None, None
        else:
            self.x_test, self.y_test = self._load_split(test_dataset)

        write_json({
            "model" : str(model)
        }, path.join(self.title_dir, "components.json"))

    def _load_split(self, dataset):
        """Return ``(x, y)`` for one split, with ``x`` flattened to ``(num_samples, -1)``."""
        if isinstance(dataset, (str, os.PathLike)):
            # weights_only=False unpickles arbitrary objects: only load trusted files.
            x, y = torch.load(dataset, weights_only=False, map_location="cpu")
        else:
            x, y = self.tensorize(dataset)
        return x.reshape(x.shape[0], -1), y

    def _validation_threshold(self):
        """Threshold tuned on the validation split, or ``self.threshold`` when none was recorded."""
        record_path = os.path.join(self.val_records_dir, "dynamic_record.json")
        if self.dynamic_thresholding and os.path.exists(record_path):
            return read_json(record_path)["threshold"]
        return self.threshold

    def fit(self, train_verbose=False, val_verbose=False, test_verbose=False):
        """Train, validate, test (at the validation-tuned threshold), save and print a summary."""
        print("Training initiated...")
        self.train(train_verbose)
        print("validating initiated...")
        self.val(val_verbose)
        if self.x_test is not None:
            print("testing initiated...")
            # Keyword arguments: test's first positional parameter is `threshold`, so the
            # old call `self.test(test_verbose, threshold=...)` raised a TypeError.
            self.test(threshold=self._validation_threshold(), verbose=test_verbose)
        else:
            print("no test dataset provided, skipping testing...")

        self.save()
        self.show_performance(20)

    def _record_split(self, split, predictions, labels, records_dir, verbose, threshold=None, dynamic=True):
        """Score one split, write its record JSONs and raw predictions, optionally print summaries.

        ``threshold`` overrides ``self.threshold`` when not None (0.0 is honoured).
        The dynamic record is produced only when ``dynamic`` and
        ``self.dynamic_thresholding`` are both true.
        """
        fixed_scores, fixed_threshold = evaluate(
            labels, predictions, self.eval_metrics,
            self.threshold if threshold is None else threshold,
        )
        fixed_scores["threshold"] = fixed_threshold
        write_json(fixed_scores, os.path.join(records_dir, "fixed_record.json"))

        run_dynamic = dynamic and self.dynamic_thresholding
        if run_dynamic:
            dynamic_scores, dynamic_threshold = evaluate(labels, predictions, self.eval_metrics, self.dynamic_thresholding_metric)
            dynamic_scores["threshold"] = dynamic_threshold
            write_json(dynamic_scores, os.path.join(records_dir, "dynamic_record.json"))

        if verbose:
            print_run_summary(fixed_scores, fixed_threshold)
            if run_dynamic:
                print_run_summary(dynamic_scores, dynamic_threshold)

        prediction_records = {
            "predictions" : predictions.tolist(),
            "labels" : labels.tolist()
        }
        write_json(prediction_records, path.join(self.prediction_records_dir, f"{split}_predictions.json"))

    def train(self, verbose=False):
        """Fit the model on the training split and record its training-set scores."""
        self.model.fit(self.x_train, self.y_train)
        predictions = self.model(self.x_train).flatten()
        labels = self.y_train.detach().flatten()
        self._record_split("train", predictions, labels, self.train_records_dir, verbose)

    def tensorize(self, dataset):
        """Stack an iterable of ``(sample, label)`` pairs into two CPU tensors."""
        samples = []
        targets = []

        for data, label in tqdm(dataset, desc="tensorizing..."):
            # as_tensor also accepts plain Python / NumPy labels, which have no .cpu().
            samples.append(torch.as_tensor(data).cpu())
            targets.append(torch.as_tensor(label).cpu())

        return torch.stack(samples), torch.stack(targets)

    @torch.no_grad()
    def val(self, verbose=False):
        """Score the fitted model on the validation split and record the results."""
        predictions = self.model(self.x_val).flatten()
        labels = self.y_val.detach().flatten()
        self._record_split("val", predictions, labels, self.val_records_dir, verbose)

    @torch.no_grad()
    def test(self, threshold=None, verbose=False):
        """Score the fitted model on the test split at ``threshold`` (default ``self.threshold``).

        Raises:
            ValueError: if the trainer was built without a test dataset.
        """
        if self.x_test is None:
            raise ValueError("no test dataset was provided to this ClassifierTrainer")
        predictions = self.model(self.x_test).flatten()
        labels = self.y_test.detach().flatten()
        # Test uses a single (fixed) threshold; the old verbose branch referenced
        # undefined dynamic scores and raised NameError.
        self._record_split("test", predictions, labels, self.test_records_dir, verbose,
                           threshold=threshold, dynamic=False)

    def save_weights(self):
        """Delegate weight persistence to the wrapped model."""
        self.model.save(self.save_model_dir)

    def save_config(self):
        """Write the constructor settings needed by ``from_config`` to ``<title>/config.json``."""
        config = {
        "title" : self.title,
        "weights" : self.save_model_dir,
        "threshold" : self.threshold,
        "dynamic_thresholding" : self.dynamic_thresholding,
        "dynamic_thresholding_metric" : self.dynamic_thresholding_metric,
        }
        write_json(config, path.join(self.title_dir, "config.json"))

    @staticmethod
    def from_config(
        directory,
        model,
        eval_metrics,
        train_dataset,
        val_dataset,
        test_dataset=None
        ):
        """Rebuild a trainer from a ``config.json`` written by ``save_config``.

        ``directory`` is the path of the JSON file itself. ``model.load`` is called
        with the saved weights path and its return value becomes the trainer's
        model. Datasets are loaded exactly as in ``__init__``; ``test_dataset`` was
        previously not forwarded at all, which made this raise a TypeError.
        """
        config = read_json(directory)

        config["model"] = model.load(config.pop("weights"))
        config["eval_metrics"] = eval_metrics
        config["train_dataset"] = train_dataset
        config["val_dataset"] = val_dataset
        config["test_dataset"] = test_dataset

        return ClassifierTrainer(**config)

    def save(self):
        """Persist the model weights and the config."""
        self.save_weights()
        self.save_config()
        print("model weights and config has been saved successfully")

    def read_performance(self):
        """Return the paths of the train/val fixed and dynamic score records (they may not exist)."""
        train_fixed = path.join(self.train_records_dir, "fixed_record.json")
        train_dynamic  = path.join(self.train_records_dir, "dynamic_record.json")
        val_fixed  = path.join(self.val_records_dir, "fixed_record.json")
        val_dynamic  = path.join(self.val_records_dir, "dynamic_record.json")

        return {
            "train_fixed": train_fixed,
            "train_dynamic": train_dynamic,
            "val_fixed": val_fixed,
            "val_dynamic": val_dynamic
        }

    def show_performance(self, width=15):
        """Print a fixed-width table with one row per metric and one column per existing record."""
        # Skip records that were never written (e.g. dynamic ones when thresholding is off).
        performances = {key: read_json(p) for key, p in self.read_performance().items() if os.path.exists(p)}
        if not performances:
            print("no performance records found")
            return

        titles = [col_name.center(width) for col_name in performances.keys()]

        # One column for the metric names plus one per record.
        s = ("  " + "_" * (width)) * (len(performances) + 1) + "  "
        print(s)
        print("||" + "||".join(["metric".center(width), *titles]) + "||")
        print(s.replace("  ", "||"))

        # Rows follow the metric order of the record with the most entries.
        for metric in max(performances.values(), key=len).keys():
            cells = []
            for scores in performances.values():
                score = scores.get(metric, "Nan")
                score = round(score, 4) if isinstance(score, float) else score
                cells.append(str(score).center(width))
            print("||" + "||".join([metric.center(width), *cells]) + "||")
            print(s.replace("  ", "||"))

# Deprecated helpers (train_wavelength, get_instance), kept commented out for reference only — do not use them.
# def train_wavelength(
#     config,
#     wavelength,
#     preprocessor,
#     augmentation,
#     eval_metrics,
#     loss_fn,
#     augmentation_class="positive",
#     epochs=100,
#     verboses=[False, False, False],
#     validating_frequency=10
#     ):
#     train_dataset = WaveLenghtDataset(
#         config.train_data_dir,
#         process_meta_data(config.train_meta_dir),
#         wavelength,
#         preprocessor,
#         augmentation,
#         augmentation_class=augmentation_class
#         )

#     val_dataset = WaveLenghtDataset(
#         config.val_data_dir,
#         process_meta_data(config.val_meta_dir),
#         wavelength,
#         preprocessor
#         )

#     sub_title = path.join(config.title, wavelength)
#     model = efficientnet_v2_s(EfficientNet_V2_S_Weights.IMAGENET1K_V1)
#     model.classifier = nn.Sequential(
#         nn.Dropout(0.5),  # Increase dropout
#         nn.Linear(model.classifier[1].in_features, 1)
#     )
#     optim = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
#     train_loader = DataLoader(train_dataset, batch_size=config.batch_size)
#     val_loader = DataLoader(val_dataset, batch_size=config.batch_size)

#     trainer = Trainer(
#             sub_title,
#             model,
#             optim,
#             loss_fn,
#             eval_metrics,
#             train_loader,
#             val_loader,
#             checkpointing=validating_frequency,
#             accumulate_gradient=max(1, int(config.batch_size / config.effective_batch_size)),
#             threshold=0.5,
#             dynamic_thresholding=True,
#             dynamic_thresholding_metric=config.target_metric,
#             progress_bar_update="mean",
#             device="cuda"
#             )

#     trainer.fit(
#         epochs=epochs,
#         train_verbose=verboses[0],
#         val_verbose=verboses[1],
#         summary_verbose=verboses[2],
#         validating_frequency=validating_frequency
#         )

# def get_instance(
#     wavelength,
#     config,
#     preprocessor,
#     augmentation,
#     eval_metrics
#     ):
#     train_dataset = WaveLenghtDataset(
#         config.train_data_dir,
#         process_meta_data(config.train_meta_dir),
#         wavelength,
#         preprocessor,
#         augmentation
#         )

#     val_dataset = WaveLenghtDataset(
#         config.train_data_dir,
#         process_meta_data(config.val_meta_dir),
#         wavelength,
#         preprocessor
#         )
#     model = efficientnet_v2_s(EfficientNet_V2_S_Weights.IMAGENET1K_V1)
#     model.classifier = nn.Sequential(
#         nn.Dropout(0.3),  # Increase dropout
#         nn.Linear(model.classifier[1].in_features, 1)
#     )
#     optim = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
#     loss_fn = torch.nn.BCEWithLogitsLoss()
#     train_loader = DataLoader(train_dataset, batch_size=config.batch_size)
#     val_loader = DataLoader(val_dataset, batch_size=config.batch_size)
    
#     return {
#         "weights" : {
#             "model" : model,
#             "optim" : optim
#             },
#         "loss_fn" : loss_fn,
#         "eval_metrics" : eval_metrics,
#         "train_loader" : train_loader,
#         "val_loader" : val_loader
#         }


In [None]:
# config.py
class Config:
    """Global settings for the SDOBenchmark training notebook.

    Every setting is a plain attribute assigned in ``__init__``; later cells read
    them directly (e.g. ``config.batch_size``). ``to_json`` writes a snapshot to
    ``<title>/global_config.json`` so a run can be traced to its settings.
    """

    def __init__(self):
        # -----------------------------
        # Paths
        # -----------------------------
        self.title = "/kaggle/working/pure_original_dataset"  # output root for this run
        self.base_route = "/kaggle/input/sdobenchmark/SDOBenchmark_full"
        self.seed = 42

        self.train_data_dir = path.join(self.base_route, "training")
        self.test_data_dir = path.join(self.base_route, "test")

        self.train_meta_dir = path.join(self.train_data_dir, "meta_data.csv")
        self.test_meta_dir = path.join(self.test_data_dir, "meta_data.csv")

        self.cache_dir = path.join(self.title, "cache")

        self.synthesized_images_dir = None  # provided externally
        # Number of synthesized positive samples to add (None disables them).
        # NOTE(review): an earlier comment said a value in (0, 1) is read as a ratio
        # and anything else as an absolute count; confirm in SynthesizedDataset.
        # The original dataset has 510 positive examples.
        self.synthesized_sample_size = None  # e.g. int(510 * 0.50)

        # -----------------------------
        # Data Settings
        # -----------------------------
        self.wavelengths = [
            "94", "131", "171", "193", "211",
            "304", "335", "1700", "continuum", "magnetogram"
        ]
        self.progress_bar_update = "mean"

        self.dynamic_thresholding = True
        self.checkpointing = 2
        self.threshold = 0.5

        self.image_size = 128
        self.batch_size = 16
        self.effective_batch_size = 32
        # Per-label factors used when building the sampler weights (index = label).
        # NOTE(review): previously described as a "minority/majority ratio"; confirm the order.
        self.oversampling_ratio = [0.9, 0.1]
        self.embed_dim = 1280

        # -----------------------------
        # Fusion Model Training
        # -----------------------------
        self.fusion_num_epochs = 20
        self.fusion_batch_size = 16
        self.fusion_learning_rate = 1e-3

        # -----------------------------
        # Classifier Training
        # -----------------------------
        self.use_smote = True
        self.classifier = "xgboost"
        self.learning_rate = 5e-5
        self.weight_decay = 0.05
        self.num_epochs = 24
        self.target_metric = "tss"

        # -----------------------------
        # Execution Flags
        # -----------------------------
        # (flag names kept as-is, including the PURNE spelling, because other cells read them)
        self.PURNE_CHECKPOINTS = False  # delete all but the best checkpoint after testing
        self.TRAIN_CLASSIFIERS = False  # run the per-wavelength training cell
        self.ENCODE = False
        self.TRAIN_FUSION = True
        self.TRAIN_CLASSIFIER = True

        # -----------------------------
        # System
        # -----------------------------
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    # -----------------------------------------
    # JSON export
    # -----------------------------------------
    def to_json(self):
        """Write the settings to ``<title>/global_config.json``, creating ``title`` if needed.

        The snapshot now also includes ``seed``, ``weight_decay``, ``embed_dim`` and
        ``progress_bar_update``, which are needed to reproduce a run but were
        previously omitted.
        """
        config = {
            "base_route": self.base_route,
            "title": self.title,
            "seed": self.seed,
            "train_data_dir": self.train_data_dir,
            "test_data_dir": self.test_data_dir,
            "train_meta_dir": self.train_meta_dir,
            "test_meta_dir": self.test_meta_dir,
            "cache_dir": self.cache_dir,

            "wavelengths": self.wavelengths,
            "image_size": self.image_size,
            "batch_size": self.batch_size,
            "effective_batch_size": self.effective_batch_size,
            "oversampling_ratio": self.oversampling_ratio,
            "embed_dim": self.embed_dim,
            "progress_bar_update": self.progress_bar_update,

            "dynamic_thresholding" : self.dynamic_thresholding,
            "checkpointing" : self.checkpointing,
            "threshold" : self.threshold,

            "fusion_num_epochs": self.fusion_num_epochs,
            "fusion_batch_size": self.fusion_batch_size,
            "fusion_learning_rate": self.fusion_learning_rate,

            "classifier": self.classifier,
            "use_smote": self.use_smote,
            "learning_rate": self.learning_rate,
            "weight_decay": self.weight_decay,
            "num_epochs": self.num_epochs,
            "target_metric": self.target_metric,

            "TRAIN_CLASSIFIERS": self.TRAIN_CLASSIFIERS,
            "TRAIN_CLASSIFIER": self.TRAIN_CLASSIFIER,
            "TRAIN_FUSION": self.TRAIN_FUSION,
            "ENCODE": self.ENCODE,
            "PURNE_CHECKPOINTS": self.PURNE_CHECKPOINTS,

            "device": self.device,
            "synthesized_images_dir": self.synthesized_images_dir,
            "synthesized_sample_size": self.synthesized_sample_size
        }

        os.makedirs(self.title, exist_ok=True)
        write_json(config, path.join(self.title, "global_config.json"))

# Single shared configuration object read by every later cell.
config = Config()
# Seed through the project helper before any data split or model initialisation below.
set_global_seed(config.seed)

In [None]:
# Data Distribution
# Hold out 10% of the official training metadata for validation, stratified on
# peak_flux so both parts keep its class balance; the official test split is used as-is.
meta = process_meta_data(config.train_meta_dir)
test_meta = process_meta_data(config.test_meta_dir)

train_meta, val_meta = train_test_split(
    meta,
    stratify=meta["peak_flux"],
    test_size=0.1,
    random_state=config.seed,
)

# Plot per-split class frequencies; train_freq is also used later to weight the loss.
split_labels = [frame["peak_flux"] for frame in (train_meta, val_meta, test_meta)]
img, train_freq, val_freq, test_freq = plot_frequency_bars(split_labels, ["non-severe", "severe"], show_img=True)

for split_name, split_freq in (("training", train_freq), ("validation", val_freq), ("testing", test_freq)):
    print(f"{split_name} dataset distribution...", split_freq)

os.makedirs(config.title, exist_ok=True)
img.save(path.join(config.title, "data_distribution.png"))
config.to_json()

In [None]:
# Preprocessor & Augmentation & Loss Function & Evaluation Metrics

# Deterministic transform applied to every split.
preprocessor = T.Compose(
    [
        ReadImgs("stack"),
        T.Resize((config.image_size, config.image_size)),
        T.ConvertImageDtype(dtype=torch.float32),
        FrequencyChannelTransform(),
        # ImageNet mean/std, matching the IMAGENET1K_V1 backbone weights used for training.
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Random transforms; passed as `augmentation` to the training datasets only.
augmentation = T.Compose(
    [
        T.RandomHorizontalFlip(p=0.5),
        T.RandomVerticalFlip(p=0.5),
        # NOTE: T.RandomRotation expects a number or a (min, max) pair, so the list
        # below would raise if re-enabled; discrete 90° steps need a custom transform.
        # T.RandomRotation(degrees=[0, 90, 180, 270]),  # 90° increments only
        T.RandomChoice(  # apply exactly one of these
            [
                T.ColorJitter(brightness=0.02, contrast=0.02),  # subtle intensity changes
                T.GaussianBlur(kernel_size=3, sigma=(0.01, 0.03)),  # mild smoothing
            ]
        ),
    ]
)

# Earlier loss configurations, kept for the record:
# loss_fn = HybridLossFunction(
#     FocalLoss(weights=[train_freq[0], train_freq[1] + (config.synthesized_sample_size or 0)], gamma=4),
#     LogitAdjustedBCE(weights=[train_freq[0], train_freq[1] + (config.synthesized_sample_size or 0)], tau=1.0),
#     weights=[0.7, 0.3]
#     )
# loss_fn = InverseFreqWeightedBCE([train_freq[0], train_freq[1] + (config.synthesized_sample_size or 0)])
# loss_fn = FocalLoss(alpha=0.939, gamma=2.0)
# loss_fn = AsymmetricFocalLoss(gamma_pos=1.4, gamma_neg=4.0, alpha=0.90)

# Active loss: asymmetric focal loss combined with a hard-negative-mining term
# (presumably mixed 0.9 / 0.1). pos_weight is train_freq[0] / train_freq[1], i.e.
# presumably the non-severe / severe count ratio of the training split.
loss_fn = HybridLossFunction(
    AsymmetricFocalLoss(gamma_pos=1.4, gamma_neg=4.0, alpha=0.93),
    HardNegativeMiningLoss(neg_ratio=2.0, pos_weight=train_freq[0] / train_freq[1]),
    weights=[.9, .1],
)

# Metric name -> scoring function consumed by `evaluate`.
eval_metrics = dict(
    accuracy=accuracy_score,
    precision=precision_score,
    recall=recall_score,
    f1=f1_score,
    mcc=matthews_corrcoef,
    tss=true_skill_statistic,
    roc_auc=roc_auc_score,
    far=far_score,
    csi=csi_score,
    hss=hss_score,
)

write_note(
"""
used WeightedRandomSampler where each class weight is it's frequency
utilized full timesteps instead of picking one and ditching the rest
"""[1:], path.join(config.title, "dev_note.txt"))

In [None]:
# Training each classifier
# Trains one EfficientNetV2-S per wavelength, reloads its best checkpoint by
# `config.target_metric`, and evaluates it on the test split.
if config.TRAIN_CLASSIFIERS:
    pick_up_progress = False

    # don't change manually: set below when an interrupted wavelength must be resumed
    resume_wavelength = False

    if pick_up_progress:
        # List the output root once instead of once per wavelength.
        existing_entries = set(os.listdir(config.title))
        wavelengths = [wavelength for wavelength in config.wavelengths if wavelength not in existing_entries]
        incompleted_wavelength = check_incomplete(config.title, config.num_epochs)
        wavelengths = incompleted_wavelength + wavelengths
        resume_wavelength = len(incompleted_wavelength) >= 1
    else:
        wavelengths = config.wavelengths

    # Mini-batches to accumulate so that batch_size * steps ≈ effective_batch_size.
    # (The old `batch_size / effective_batch_size` was inverted and always gave 1.)
    accumulation_steps = max(1, config.effective_batch_size // config.batch_size)

    for wavelength in wavelengths:
        print(f"current wavelength: {wavelength}")

        real_dataset = WaveLenghtDatasetV2(
            config.train_data_dir,
            train_meta,
            wavelength,
            preprocessor,
            augmentation
            )
        real_labels = torch.tensor(real_dataset._metadata["peak_flux"].tolist()).int()

        if config.synthesized_sample_size is not None \
            and config.synthesized_images_dir is not None:
            synthesized_dataset = SynthesizedDataset(
                config.synthesized_images_dir,
                wavelength,
                synthesized_class=1,
                transform=preprocessor,
                augmentation=augmentation,
                num_samples=config.synthesized_sample_size
            )
            train_dataset = MergedDatasets(real_dataset, synthesized_dataset)
            # Synthesized samples are all positives (synthesized_class=1).
            weights = torch.cat([
                real_labels,
                torch.ones(len(synthesized_dataset), dtype=torch.int)
                ])
        else:
            train_dataset = real_dataset
            weights = real_labels

        val_dataset = WaveLenghtDatasetV2(
            config.train_data_dir,
            val_meta,
            wavelength,
            preprocessor
            )

        test_dataset = WaveLenghtDatasetV2(
            config.test_data_dir,
            test_meta,
            wavelength,
            preprocessor
        )

        # Per-sample weight = (count of its class) * oversampling_ratio[label].
        # NOTE(review): a class's total sampling mass therefore scales with
        # count**2 * ratio, so the majority class is not down-weighted unless the
        # ratio compensates. This matches the dev note ("each class weight is its
        # frequency"); confirm it is intended.
        class_factors = torch.tensor(config.oversampling_ratio) if config.oversampling_ratio is not None else 1
        sampler = WeightedRandomSampler(
            weights=(weights.bincount() * class_factors)[weights],
            num_samples=len(weights),
            replacement=True
        )

        os.makedirs(config.cache_dir, exist_ok=True)

        sub_title = path.join(config.title, wavelength)
        model = efficientnet_v2_s(EfficientNet_V2_S_Weights.IMAGENET1K_V1)
        model.classifier = nn.Sequential(
            nn.Dropout(0.5),  # Increase dropout
            nn.Linear(model.classifier[1].in_features, 1)
        )

        optim = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
        train_loader = DataLoader(train_dataset, batch_size=config.batch_size, sampler=sampler)
        val_loader = DataLoader(val_dataset, batch_size=config.batch_size)
        test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

        if resume_wavelength:
            resume_wavelength = False
            # NOTE(review): in this section `get_instance` only appears in a commented-out
            # cell; restore it before enabling pick_up_progress, or this raises NameError.
            trainer = Trainer.from_config(
                path.join(config.title, wavelength, "config.json"),
                get_instance(
                    wavelength,
                    config,
                    preprocessor,
                    augmentation,
                    eval_metrics,
                    ),
                back_step=True
            )
        else:
            trainer = Trainer(
                sub_title,
                model,
                optim,
                loss_fn,
                eval_metrics,
                train_loader,
                val_loader,
                test_loader,  # comma was missing here (SyntaxError)
                checkpointing=config.checkpointing,
                accumulate_gradient=accumulation_steps,
                threshold=config.threshold,
                dynamic_thresholding=config.dynamic_thresholding,
                dynamic_thresholding_metric=config.target_metric,
                progress_bar_update=config.progress_bar_update,
                device=config.device
                )

        trainer.fit(
            epochs=config.num_epochs,
            train_verbose=False,
            val_verbose=True,
            summary_verbose=False,
            validating_frequency=config.checkpointing
            )

        model_name, record = pick_best_model(
            path.join(config.title, wavelength),
            config.target_metric
            )

        # weights_only=False unpickles arbitrary objects: only load our own checkpoints.
        checkpoint = torch.load(
            path.join(config.title, wavelength, "checkpoints", model_name),
            weights_only=False, map_location=config.device)
        best_model = checkpoint["model"]
        if isinstance(best_model, dict):
            # "state"-style checkpoints store a state_dict rather than the module itself.
            trainer.model.load_state_dict(best_model)
        else:
            trainer.model = best_model
        trainer.threshold = record["threshold"]

        trainer.test(verbose=True)

        if config.PURNE_CHECKPOINTS:
            ckpts_dir = path.join(config.title, wavelength, "checkpoints")
            for ckpt in os.listdir(ckpts_dir):
                if ckpt != model_name:
                    os.remove(path.join(ckpts_dir, ckpt))

In [None]:
# Re-test each per-wavelength classifier from its best checkpoint (commented out; kept for reference)

# for wavelength in config.wavelengths:
#     print(f"current wavelength: {wavelength}")

#     real_dataset = WaveLenghtDatasetV2(
#         config.train_data_dir,
#         train_meta,
#         wavelength,
#         preprocessor,
#         augmentation
#         )

#     if config.synthesized_sample_size is not None \
#         and config.synthesized_images_dir is not None:
#         synthesized_dataset = SynthesizedDataset(
#             config.synthesized_images_dir,
#             wavelength,
#             synthesized_class=1,
#             transform=preprocessor,
#             augmentation=augmentation,
#             num_samples=config.synthesized_sample_size
#         )
#         train_dataset = MergedDatasets(real_dataset, synthesized_dataset)
#         weights = torch.cat([
#             torch.tensor(list(real_dataset._metadata["peak_flux"].iloc())).int(),
#             torch.ones(len(synthesized_dataset))
#             ]).int()
#     else:
#         train_dataset = real_dataset
#         weights = torch.tensor(list(real_dataset._metadata["peak_flux"].iloc())).int()

#     val_dataset = WaveLenghtDatasetV2(
#         config.train_data_dir,
#         val_meta,
#         wavelength,
#         preprocessor
#         )

#     test_dataset = WaveLenghtDatasetV2(
#         config.test_data_dir,
#         test_meta,
#         wavelength,
#         preprocessor
#     )

#     sampler = WeightedRandomSampler(
#         weights=(weights.bincount() * (torch.tensor(config.oversampling_ratio) if config.oversampling_ratio is not None else 1))[weights],
#         num_samples=len(weights),
#         replacement=True
#     )

#     os.makedirs(config.cache_dir, exist_ok=True)

#     sub_title = path.join(config.title, wavelength)

#     model_name, record = pick_best_model(
#         path.join(config.title, wavelength),
#         config.target_metric
#         )

#     model = torch.load(
#         path.join(config.title, wavelength, "checkpoints", model_name),
#         weights_only=False, map_location=config.device)["model"]

#     optim = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
#     train_loader = DataLoader(train_dataset, batch_size=config.batch_size, sampler=sampler)
#     val_loader = DataLoader(val_dataset, batch_size=config.batch_size)
#     test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

#     trainer = Trainer(
#             sub_title,
#             model,
#             optim,
#             loss_fn,
#             eval_metrics,
#             train_loader,
#             val_loader,
#             test_loader,
#             checkpointing=config.checkpointing,
#             accumulate_gradient=max(1, int(config.batch_size / config.effective_batch_size)),
#             threshold=record["threshold"],
#             dynamic_thresholding=config.dynamic_thresholding,
#             dynamic_thresholding_metric=config.target_metric,
#             progress_bar_update=config.progress_bar_update,
#             device=config.device
#             )

#     trainer.test(verbose=True)


In [None]:
# Train classical ML classifiers (via `Classifier`) on the cached embeddings, as an
# alternative classification head. This reads the embedding cache written by the ENCODE
# cell further down, so on a fresh kernel that cell must have been run first.
if config.TRAIN_CLASSIFIER:
    train_dataset = CacheDataset(path.join(config.cache_dir, "train"), SDOCacheTransform())
    val_dataset = CacheDataset(path.join(config.cache_dir, "val"), SDOCacheTransform())
    test_dataset = CacheDataset(path.join(config.cache_dir, "test"), SDOCacheTransform())

    # torch.save does not create parent directories, so make sure the output folder exists.
    classifier_cache_dir = path.join(config.cache_dir, "classifier")
    os.makedirs(classifier_cache_dir, exist_ok=True)

    split_paths = {}
    for split, dataset in (("train", train_dataset), ("val", val_dataset), ("test", test_dataset)):
        split_paths[split] = path.join(classifier_cache_dir, f"{split}.pt")
        # `tensorize` is called unbound with self=None, so it must not rely on instance state.
        torch.save(ClassifierTrainer.tensorize(None, dataset), split_paths[split])

    model_pick = Classifier.list_models() if config.classifier == "try_all" else [config.classifier]
    for model_name in model_pick:
        print(f"current model: {model_name}")
        # NOTE(review): `model_name` only selects the output directory here; it is not passed to
        # `Classifier(...)`, so every iteration constructs an identically configured classifier.
        # Confirm how Classifier chooses which model to train.
        classifier_trainer = ClassifierTrainer(
            path.join(config.title, model_name),
            Classifier(use_smote=config.use_smote, random_state=config.seed),
            eval_metrics=eval_metrics,
            train_dataset=split_paths["train"],
            val_dataset=split_paths["val"],
            test_dataset=split_paths["test"],
            dynamic_thresholding=config.dynamic_thresholding,
            dynamic_thresholding_metric=config.target_metric,
        )

        classifier_trainer.fit()

In [None]:
# Compare the per-wavelength models on the test split, using each one's saved
# `test_records/fixed_record.json`, and save the bar chart under config.title.
records = [
    read_json(path.join(config.title, wl, "test_records", "fixed_record.json"))
    for wl in config.wavelengths
]

plot_metric_bars(
    records,
    config.wavelengths,
    cols=2,
    save_dir=path.join(config.title, "models_test_comparison_plot.png"),
    metric=config.target_metric,
)

In [None]:
# Encode every split with the best per-wavelength backbones and cache the pooled embeddings
# (one file per sample and timestep under cache_dir/<split>/<label>/), so downstream heads
# can train without re-running the CNNs.
if config.ENCODE:
    models = {}
    records = []
    for wavelength in config.wavelengths:
        model_name, record = pick_best_model(
            path.join(config.title, wavelength),
            config.target_metric
        )
        records.append(record)

        # weights_only=False unpickles a full model object; only load checkpoints this notebook wrote.
        model = torch.load(
            path.join(config.title, wavelength, "checkpoints", model_name),
            weights_only=False,
            map_location=config.device,
        )["model"]
        models[wavelength] = model.eval()

    plot_metric_bars(
        records,
        config.wavelengths,
        cols=2,
        save_dir=path.join(config.title, "models_validation_comparison_plot.png"),
        metric=config.target_metric,
    )

    sdo_preprocessor = T.Compose([
        ReadImgs("stack"),
        T.Lambda(lambda x: x.unsqueeze(1)),
        T.Resize((config.image_size, config.image_size)),
        T.ConvertImageDtype(dtype=torch.float32),
        FrequencyChannelTransform(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    train_dataset = SDODataset(config.train_data_dir, train_meta, config.wavelengths, sdo_preprocessor)
    val_dataset = SDODataset(config.train_data_dir, val_meta, config.wavelengths, sdo_preprocessor)
    test_dataset = SDODataset(config.test_data_dir, test_meta, config.wavelengths, sdo_preprocessor)

    def encode_split_to_cache(dataset, split, num_timesteps=4):
        """Encode every sample of `dataset` and save per-timestep embedding dicts to the cache.

        Args:
            dataset: iterable of (imgs_dict, label) pairs; imgs_dict maps wavelength -> image tensor.
            split: cache sub-directory name ("train", "val" or "test").
            num_timesteps: how many leading rows of each wavelength's embedding to cache.

        Each file is saved as cache_dir/<split>/<label>/<counter>_<sample>_<timestep>.pt and holds
        {wavelength: embedding row}. Must be called under torch.no_grad().
        """
        id_counter = 0
        for i, (imgs_dict, label) in enumerate(tqdm(dataset, desc="encoding...")):
            encodings = {}
            for wl, img in imgs_dict.items():
                features = models[wl].features(img.to(config.device))
                # avgpool(...).flatten(1) gives a 2-D tensor; presumably one row per timestep.
                encodings[wl] = models[wl].avgpool(features).flatten(1)

            label_dir = path.join(config.cache_dir, split, str(int(label.item())))
            for timestep in range(num_timesteps):
                # Check the number of rows BEFORE indexing. The previous guard indexed first and
                # then compared the feature dimension, so it never filtered anything and a short
                # sequence raised IndexError instead of being skipped.
                timestep_dict = {
                    wl: encoding[timestep]
                    for wl, encoding in encodings.items()
                    if encoding.size(0) > timestep
                }
                if not timestep_dict:
                    continue  # no wavelength reaches this timestep; nothing to cache
                torch.save(timestep_dict, path.join(label_dir, f"{id_counter}_{i}_{timestep}.pt"))
                id_counter += 1

    splits = (("train", train_dataset), ("val", val_dataset), ("test", test_dataset))

    # makedirs creates the split directory together with each label sub-directory.
    for split, _ in splits:
        for label_name in ("1", "0"):
            os.makedirs(path.join(config.cache_dir, split, label_name), exist_ok=True)

    with torch.no_grad():
        for split, dataset in splits:
            encode_split_to_cache(dataset, split)

In [None]:
# Train the fusion model(s), which classify from the cached embeddings of all the
# single-wavelength models (written by the ENCODE cell).
if config.TRAIN_FUSION:
    # (run name, zero-argument model factory) pairs.
    fusion_models = [
        ("resnet_1D", lambda: WideResNet1D(in_channels=len(config.wavelengths), num_channels=[1024, 512, 256, 128])),
        # ("robert", lambda : RobertaSeqClassifier(embed_dim=config.embed_dim)),
        # ("distil_bert", lambda : DistilBertSeqClassifier(embed_dim=config.embed_dim)),
    ]

    # The cached datasets and class statistics do not depend on the fusion model: build them once.
    train_dataset = CacheDataset(path.join(config.cache_dir, "train"), SDOCacheTransform())
    val_dataset = CacheDataset(path.join(config.cache_dir, "val"), SDOCacheTransform())
    test_dataset = CacheDataset(path.join(config.cache_dir, "test"), SDOCacheTransform())

    labels = train_dataset.get_labels()
    # [count of label 0, count of label 1]; assumes binary 0/1 labels.
    freq_count = [len(labels) - sum(labels), sum(labels)]

    # NOTE(review): each sample's sampling weight is its class *frequency* (optionally scaled by
    # config.oversampling_ratio). Without a ratio this favours the majority class even more;
    # inverse-frequency weights are the usual choice for balancing. Confirm this is intended.
    weights = (
        freq_count
        if config.oversampling_ratio is None
        else [w * f for w, f in zip(config.oversampling_ratio, freq_count)]
    )

    for run_name, build_model in fusion_models:
        sampler = WeightedRandomSampler(
            weights=[weights[label] for label in labels],
            num_samples=len(labels),
            replacement=True
        )

        sub_title = path.join(config.title, run_name)
        model = build_model()
        # Base lr is 1 so LinearLR's factors act as absolute learning rates: linear decay from
        # config.fusion_learning_rate to 1e-5 over 30 scheduler steps.
        optim = torch.optim.AdamW(model.parameters(), lr=1)
        scheduler = torch.optim.lr_scheduler.LinearLR(optim, start_factor=config.fusion_learning_rate, end_factor=1e-5, total_iters=30)
        loss_fn = HybridLossFunction(
            FocalLoss(weights=freq_count, gamma=3),
            LogitAdjustedBCE(weights=freq_count, tau=1.0),
            weights=[0.7, 0.3]
        )

        trainer = Trainer(
            sub_title,
            model,
            optim,
            loss_fn,
            eval_metrics,
            DataLoader(train_dataset, batch_size=config.fusion_batch_size, sampler=sampler),
            DataLoader(val_dataset, batch_size=config.fusion_batch_size),
            DataLoader(test_dataset, batch_size=config.fusion_batch_size),
            checkpointing=config.checkpointing,
            lr_scheduler=scheduler,
            lr_step_frequency=int(len(train_dataset) / 3),
            accumulate_gradient=max(1, int(config.fusion_batch_size / config.effective_batch_size)),
            threshold=config.threshold,
            dynamic_thresholding=config.dynamic_thresholding,
            dynamic_thresholding_metric=config.target_metric,
            progress_bar_update=config.progress_bar_update,
            device=config.device
        )

        trainer.fit(
            config.fusion_num_epochs,
            validating_frequency=config.checkpointing
        )

        # As for the per-wavelength models, test the checkpoint chosen by pick_best_model (with its
        # stored threshold) instead of whatever weights the last epoch left in memory.
        model_name, record = pick_best_model(sub_title, config.target_metric)
        trainer.model = torch.load(
            path.join(sub_title, "checkpoints", model_name),
            weights_only=False,
            map_location=config.device,
        )["model"]
        trainer.threshold = record["threshold"]

        trainer.test(verbose=True)