In [None]:
import random

import polars as pl
import pyarrow.parquet as pq
import torch

from src.pytorch.models.utils import get_default_trainer
from src.pytorch.models.mlp import DynamicMLP
from src.pytorch.data.loader import Loadermodule
from src.pytorch.data.parquet import Dataset
from src.schemas.climsim import INPUT_COLUMNS, OUTPUT_COLUMNS

In [None]:
# Trade some float32 matmul precision for speed: per the PyTorch docs, "medium"
# allows lower-precision (bfloat16) internal computation where the hardware supports it.
torch.set_float32_matmul_precision("medium")

In [None]:
# Input data location. Alternative input file: "/home/data/subset_train.arrow"
TRAINSET_DATA_PATH = "/home/data/train.parquet"

# ---- Settings shared by training and validation ----
N_EPOCHS = 10
BATCH_SIZE = 3_072
# Share of the data reserved for training: 0.7 means 70% is used for training.
TRAINING_SAMPLE_FRAC = 0.7

# ---- Training loader settings ----
# NOTE(review): the memory estimate in this cell multiplies buffer_size by
# batch_size, so the buffer is presumably counted in batches — confirm against Dataset.
TRAINING_N_GROUP_PER_SAMPLING = 3
TRAINING_N_BATCH_PER_SAMPLING = 120
TRAINING_BUFFER_SIZE = 480

# ---- Validation loader settings ----
VALIDATION_N_GROUP_PER_SAMPLING = 2
VALIDATION_N_BATCH_PER_SAMPLING = 80
VALIDATION_BUFFER_SIZE = 320

# Buffer size calculation
# Max GPU memory usage = buffer_size * batch_size * num_columns * 4 bytes per value
#                      = (480 + 320) * 3072 * 1000 * 4
#                      = 9830400000 bytes
#                      = 9.16 GiB

In [None]:
# Seed for the train/validation row-group split. A fixed seed keeps the split
# identical across kernel restarts, so validation groups cannot leak into the
# training set when a run is repeated or compared with an earlier one.
SPLIT_SEED = 42

# Total number of rows in the file (counted lazily, without loading the data).
lf = pl.scan_parquet(TRAINSET_DATA_PATH)
n_samples = lf.select(pl.len()).collect().item()

# NOTE(review): buffer_size=10 is a very small read buffer (bytes) — confirm it is intended.
parquet = pq.ParquetFile(TRAINSET_DATA_PATH, memory_map=True, buffer_size=10)
all_groups = list(range(parquet.num_row_groups))

# Split at row-group granularity. A dedicated RNG is used so the global
# `random` state is left untouched for the rest of the notebook.
split_rng = random.Random(SPLIT_SEED)
train_groups = split_rng.sample(all_groups, int(TRAINING_SAMPLE_FRAC * len(all_groups)))
# Sorted so the validation group order does not depend on set iteration order.
val_groups = sorted(set(all_groups) - set(train_groups))

# Exact per-split row counts from the parquet metadata. Row groups are not
# guaranteed to be equally sized (the last one is typically smaller), so
# scaling n_samples by the group ratio would only be an approximation.
rows_per_group = [parquet.metadata.row_group(i).num_rows for i in all_groups]
n_train_samples = sum(rows_per_group[i] for i in train_groups)
n_val_samples = sum(rows_per_group[i] for i in val_groups)
assert n_train_samples + n_val_samples == n_samples, "row-group counts disagree with total row count"

# Both datasets read from the same ParquetFile but from disjoint row groups.
dataset_train = Dataset(
    parquet=parquet,
    input_cols=INPUT_COLUMNS,
    target_cols=OUTPUT_COLUMNS,
    n_samples=n_train_samples,
    groups=train_groups,
    buffer_size=TRAINING_BUFFER_SIZE,
    batch_size=BATCH_SIZE,
    n_group_per_sampling=TRAINING_N_GROUP_PER_SAMPLING,
    n_batch_per_sampling=TRAINING_N_BATCH_PER_SAMPLING,
    to_tensor=True,
)
dataset_val = Dataset(
    parquet=parquet,
    input_cols=INPUT_COLUMNS,
    target_cols=OUTPUT_COLUMNS,
    n_samples=n_val_samples,
    groups=val_groups,
    buffer_size=VALIDATION_BUFFER_SIZE,
    batch_size=BATCH_SIZE,
    n_group_per_sampling=VALIDATION_N_GROUP_PER_SAMPLING,
    n_batch_per_sampling=VALIDATION_N_BATCH_PER_SAMPLING,
    to_tensor=True,
)

In [None]:
model = DynamicMLP()

try:
    # Wrap both datasets in dataloaders and hand them to the trainer as one datamodule.
    trainloader = dataset_train.to_dataloader()
    valloader = dataset_val.to_dataloader()
    datamodule = Loadermodule(trainloader, valloader)

    # precision=16 requests 16-bit precision from the project's default trainer.
    trainer = get_default_trainer(max_epochs=N_EPOCHS, precision=16)
    trainer.fit(
        model=model,
        datamodule=datamodule,
    )
finally:
    # Always shut the datasets down, even when training fails or is interrupted,
    # so whatever resources they hold are released
    # (presumably background loading workers/buffers — confirm against Dataset.shutdown).
    try:
        dataset_train.shutdown()
    finally:
        # Runs even if shutting down the training dataset raises.
        dataset_val.shutdown()