In [None]:
import pandas as pd
import polars as pl
import torch

import src.env
import src.pytorch.core.dataset.kaggle
import src.pytorch.core.models
import src.pytorch.core.models.utils
import src.visualization.performance
from src.pytorch.core.dataset.loader import DatasetHandler
from src.schemas.climsim import INPUT_COLUMNS, OUTPUT_COLUMNS

In [None]:
# Single place to point the notebook at a different data location; every
# dataset path below is derived from it.
DATA_DIR = "/home/data"

# Test-set inputs and the sample-submission file (read further down as the
# per-prediction weights).
TESTSET_DATA_PATH = f"{DATA_DIR}/test.arrow"
TESTSET_PREDICTION_WEIGHTS_PATH = f"{DATA_DIR}/sample_submission.arrow"

# Where the trained model weights are written to and later reloaded from.
MODEL_WEIGHTS_PATH = "./model.pt"

# Training data: the subset file is used by default; switch to the
# commented-out line to use train.arrow instead.
TRAINSET_DATA_PATH = f"{DATA_DIR}/subset_train.arrow"
# TRAINSET_DATA_PATH = f"{DATA_DIR}/train.arrow"

In [None]:
# Open the training file as a polars lazy scan (no memory mapping).
lf = pl.scan_ipc(TRAINSET_DATA_PATH, memory_map=False)

# The handler owns the train/validation split of the scanned data.
# NOTE(review): batch_size=3072 is also passed to the DataLoaders built in
# the training cell — confirm whether DatasetHandler batches on its own.
handler = DatasetHandler(
    dataset=lf,
    # Feature / target column lists come from the ClimSim schema module.
    input_cols=INPUT_COLUMNS,
    target_cols=OUTPUT_COLUMNS,
    # presumably 80% of the rows go to training — TODO confirm
    train_fraction=0.8,
    shuffle=True,
    batch_size=3072,
)

dataset_train, dataset_val = handler.get_trainset(), handler.get_valset()

In [None]:
# Fresh, untrained network.
model = src.pytorch.core.models.MLP()

# Both loaders use identical settings; only the wrapped dataset differs.
loader_options = {
    "batch_size": 3072,
    # "num_workers": 4,
    # "prefetch_factor": 4,
    "pin_memory": True,
}
trainloader = torch.utils.data.DataLoader(dataset_train, **loader_options)
valloader = torch.utils.data.DataLoader(dataset_val, **loader_options)

# Loaders keyed by phase name, as passed to the project's train() helper.
dataloaders_by_phase = {"Training": trainloader, "Validation": valloader}

# train() hands back the model, a set of best weights and the loss history.
model, best_weights, loss = src.pytorch.core.models.utils.train(
    model=model, dataloaders=dataloaders_by_phase, num_epochs=10
)

# Plot the recorded losses; close=False is forwarded to the plotting helper.
src.visualization.performance.loss_curve(loss, close=False)

In [None]:
# Persist the trained model's weights to MODEL_WEIGHTS_PATH so the inference
# cells can reload them.
# NOTE(review): this saves `model` as returned by train(); `best_weights` is
# also returned there but never used — confirm train() leaves the best
# weights loaded in `model`.
src.pytorch.core.models.utils.save_weights(model, MODEL_WEIGHTS_PATH)

In [None]:
# Load the test-set inputs and the sample-submission file (used as the
# per-prediction weights) into pandas.
df_submission = pd.read_feather(TESTSET_DATA_PATH)
df_weights = pd.read_feather(TESTSET_PREDICTION_WEIGHTS_PATH)

# Rebuild the architecture and restore the weights saved after training.
model = src.pytorch.core.models.MLP()
src.pytorch.core.models.utils.load_model(model, MODEL_WEIGHTS_PATH)

# Move the model to the project's configured device and put it in evaluation
# mode: this model is only used for prediction from here on, so
# training-only layer behaviour (e.g. dropout, batch-norm statistic updates)
# must be switched off.
model.to(src.env.DEVICE)
model.eval()

In [None]:
# Uncomment the call below to generate the submission predictions.
# src.pytorch.core.dataset.kaggle.output_compressed_parquet(
#     model=model,
#     df=df_submission,
#     weights=df_weights,
# )