In [None]:
import datetime
import random

import pandas as pd
import pyarrow.parquet as pq
import torch
from lightning import LightningDataModule
from lightning.pytorch.callbacks import ModelCheckpoint

from src.utils import check_and_create_dir
from src.pytorch.models.utils import get_default_trainer
from src.pytorch.models.mlp import DynamicMLP
from src.pytorch.models.fastkan import FastKAN
from src.pytorch.data.parquet import Dataset
from src.pytorch.loss.r2 import r2_score_multivariate, r2_loss
from src.schemas.climsim import INPUT_COLUMNS, OUTPUT_COLUMNS
from src.visualization.performance import loss_curve

In [None]:
# Relax float32 matmul precision from PyTorch's default in exchange for speed
# (see the torch.set_float32_matmul_precision documentation for what "high" permits).
torch.set_float32_matmul_precision(precision="high")

In [None]:
# --- Data location ---------------------------------------------------------
# Parquet file used for training; the commented line is the alternative subset file.
TRAINSET_DATA_PATH = "/home/data/train.parquet"
# TRAINSET_DATA_PATH = "/home/data/subset_train.parquet"

# --- Outputs ---------------------------------------------------------------
MODEL_NAME = "climsim_best_model"
OUTPUT_DIR = "./results"
# Project helper; presumably creates OUTPUT_DIR when it does not exist yet.
check_and_create_dir(OUTPUT_DIR)

# --- Common parameters -----------------------------------------------------
N_EPOCHS = 100
BATCH_SIZE = 3072
# Share of the data used for training: 0.7 means 70% for training, the rest
# for validation. ClimSimDataModule.setup applies it to the parquet row groups.
TRAINING_SAMPLE_FRAC = 0.7

In [None]:
# Sampling parameters handed to the training / validation datasets.
#   *_BUFFER_SIZE          number of batches preloaded in memory
#   *_N_GROUP_PER_SAMPLING number of groups sampled in each iteration
#   *_N_BATCH_PER_SAMPLING number of batches sampled in each iteration

# Training
TRAINING_N_GROUP_PER_SAMPLING = 3
TRAINING_N_BATCH_PER_SAMPLING = 100
TRAINING_BUFFER_SIZE = 100

# Validation
VALIDATION_N_GROUP_PER_SAMPLING = 2
VALIDATION_N_BATCH_PER_SAMPLING = 100
VALIDATION_BUFFER_SIZE = 100

In [None]:
class ClimSimDataModule(LightningDataModule):
    """Data module for the ClimSim dataset.

    The row groups of one parquet file are randomly partitioned into a
    training part (``TRAINING_SAMPLE_FRAC`` of the groups) and a validation
    part (all remaining groups). Each part is wrapped in a ``Dataset`` whose
    dataloader is handed to Lightning.
    """

    def __init__(self, data_path: str, batch_size: int, seed: "int | None" = None) -> None:
        """Store the configuration; the parquet file is first opened in ``setup``.

        Args:
            data_path: Path of the parquet file to read.
            batch_size: Batch size forwarded to both datasets.
            seed: Optional seed for the train/validation split of the row
                groups. With a seed, every ``setup`` call (and every process)
                produces the same split. ``None`` (default) keeps the previous
                behaviour of drawing from the global ``random`` state.
        """

        super().__init__()
        self._data_path = data_path
        self._batch_size = batch_size
        self._seed = seed

    def _make_dataset(self, groups, buffer_size, n_group_per_sampling, n_batch_per_sampling):
        """Build a ``Dataset`` over ``groups`` with the settings shared by both splits."""
        return Dataset(
            source=self._data_path,
            input_cols=INPUT_COLUMNS,
            target_cols=OUTPUT_COLUMNS,
            batch_size=self._batch_size,
            buffer_size=buffer_size,
            groups=groups,
            n_group_per_sampling=n_group_per_sampling,
            n_batch_per_sampling=n_batch_per_sampling,
            to_tensor=True,
            normalize=True,
        )

    def setup(self, stage: str) -> None:
        """Split the row groups, build both datasets and start their workers.

        Args:
            stage: Stage name passed by Lightning. It is not inspected; the
                same two datasets are built whatever the stage.
        """
        # Only the row-group count is needed, so close the file right away
        # instead of leaving the handle open.
        with pq.ParquetFile(self._data_path, memory_map=True, buffer_size=10) as parquet:
            n_groups = parquet.num_row_groups
        all_groups = list(range(n_groups))

        # A private RNG makes the split repeatable when a seed was given;
        # without one, fall back to the global `random` state as before.
        rng = random if self._seed is None else random.Random(self._seed)
        # int() truncates, so a file with very few row groups can end up with
        # an empty training split.
        train_groups = rng.sample(all_groups, int(TRAINING_SAMPLE_FRAC * n_groups))
        # Everything not drawn for training is used for validation.
        val_groups = sorted(set(all_groups) - set(train_groups))

        # Use one group for testing
        # train_groups = [0]
        # val_groups = [0]

        self.train = self._make_dataset(
            groups=train_groups,
            buffer_size=TRAINING_BUFFER_SIZE,
            n_group_per_sampling=TRAINING_N_GROUP_PER_SAMPLING,
            n_batch_per_sampling=TRAINING_N_BATCH_PER_SAMPLING,
        )
        self.val = self._make_dataset(
            groups=val_groups,
            buffer_size=VALIDATION_BUFFER_SIZE,
            n_group_per_sampling=VALIDATION_N_GROUP_PER_SAMPLING,
            n_batch_per_sampling=VALIDATION_N_BATCH_PER_SAMPLING,
        )

        self.train.start_workers()
        self.val.start_workers()

    def train_dataloader(self):
        """Return the dataloader produced by the training dataset."""
        return self.train.to_dataloader()

    def val_dataloader(self):
        """Return the dataloader produced by the validation dataset."""
        return self.val.to_dataloader()

    def teardown(self, stage: str) -> None:
        """Shut down and clean up both datasets.

        The validation dataset is still released when releasing the training
        dataset raises; the first error then propagates as usual.

        Args:
            stage: Stage name passed by Lightning; not inspected.
        """
        try:
            self.train.shutdown_workers()
            self.train.clean_up()
        finally:
            self.val.shutdown_workers()
            self.val.clean_up()

In [None]:
# Alternative model, kept for reference:
# model = DynamicMLP(
#     loss_train=r2_score_multivariate,
#     loss_val=r2_score_multivariate,
# )

# FastKAN using r2_loss for both the training and the validation loss.
# NOTE(review): 556 / 368 presumably match len(INPUT_COLUMNS) / len(OUTPUT_COLUMNS) — confirm.
model = FastKAN(
    loss_val=r2_loss,
    loss_train=r2_loss,
    layers_hidden=[556, 556, 368],
)

In [None]:
# Data module over the configured training file.
datamodule = ClimSimDataModule(data_path=TRAINSET_DATA_PATH, batch_size=BATCH_SIZE)

# Trainer limits passed to the project's get_default_trainer helper: at most
# N_EPOCHS epochs, a one-hour time budget, validation every 3rd epoch.
# NOTE(review): MODEL_NAME from the config cell is not used here — confirm the
# literal name below is intentional.
trainer = get_default_trainer(
    model_name="climsim-2024",
    check_val_every_n_epoch=3,
    max_time=datetime.timedelta(hours=1),
    max_epochs=N_EPOCHS,
)

trainer.fit(model=model, datamodule=datamodule)

In [None]:
# Per-epoch losses as reported by model.get_epoch_loss().
train_loss, val_loss = model.get_epoch_loss()

# Named series so each curve carries its own label.
ds_train_loss = pd.Series(data=train_loss, name="Train Loss")
ds_val_loss = pd.Series(data=val_loss, name="Validation Loss")

# NOTE(review): validation is requested only every 3rd epoch, so the two series
# may differ in length — confirm loss_curve handles that.
loss_curve(close=False, losses=[ds_train_loss, ds_val_loss])