In [None]:
import random

import pandas as pd
import polars as pl
import pyarrow.parquet as pq
import torch

import src.env
import src.pytorch.data.kaggle
import src.pytorch.models
import src.pytorch.models.utils
import src.visualization.performance
# from src.pytorch.data.polars_loader import DatasetHandler, Dataset
from src.pytorch.data.arrow_loader import Dataset
from src.schemas.climsim import INPUT_COLUMNS, OUTPUT_COLUMNS
from src.schemas.math import Domain

In [None]:
# Test-set data; read with pandas' Feather reader in the inference cell.
TESTSET_DATA_PATH = "/home/data/test.arrow"
# Sample-submission file; loaded as the prediction weights (`df_weights`) in the inference cell.
TESTSET_PREDICTION_WEIGHTS_PATH = "/home/data/sample_submission.arrow"
# File the model weights are loaded from (relative to the notebook's working directory).
MODEL_WEIGHTS_PATH = "./model.pt"

# Training data. The loading cell opens this path with polars.scan_parquet and
# pyarrow.parquet.ParquetFile, so it must point at a Parquet file.
# NOTE(review): the commented-out subset below has an .arrow extension and is presumably
# not Parquet; switching to it would need a different reader -- confirm before enabling.
# TRAINSET_DATA_PATH = "/home/data/subset_train.arrow"
TRAINSET_DATA_PATH = "/home/data/train.parquet"

In [None]:
# Total number of rows in the training file, computed lazily with polars so the
# data itself is never materialized.
lf = pl.scan_parquet(TRAINSET_DATA_PATH, low_memory=True)
n_samples = lf.select(pl.len()).collect().item()

# If dataset not large
# handler = DatasetHandler(
#     dataset=lf,
#     input_cols=INPUT_COLUMNS,
#     target_cols=OUTPUT_COLUMNS,
#     batch_size=3072,
#     shuffle=False,
#     train_fraction=0.8,
# )
# dataset_train = handler.get_trainset()
# dataset_val = handler.get_valset()

# From here on `lf` is rebound to a memory-mapped pyarrow ParquetFile; the split
# below is done at row-group granularity.
lf = pq.ParquetFile(TRAINSET_DATA_PATH, memory_map=True)
all_groups = list(range(lf.num_row_groups))
# Exact row count of every row group, taken from the Parquet footer metadata.
# Row groups are not guaranteed to be equally sized (the last one is usually
# shorter), so per-split sample counts are summed instead of estimated.
row_group_sizes = [lf.metadata.row_group(group).num_rows for group in all_groups]

# Fixed seed so that Restart & Run All reproduces the same train/validation split.
# A private Random instance is used to leave the global `random` state untouched.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.8
split_rng = random.Random(SPLIT_SEED)
train_groups = split_rng.sample(all_groups, int(TRAIN_FRACTION * len(all_groups)))
# Sorted so the validation order does not depend on set iteration order.
val_groups = sorted(set(all_groups) - set(train_groups))

n_train_samples = sum(row_group_sizes[group] for group in train_groups)
n_val_samples = sum(row_group_sizes[group] for group in val_groups)
# Every row must land in exactly one of the two splits.
assert n_train_samples + n_val_samples == n_samples, "train/val split does not cover the dataset"

dataset_train = Dataset(
    fparquet=lf,
    input_cols=INPUT_COLUMNS,
    target_cols=OUTPUT_COLUMNS,
    batch_size=3072,
    n_samples=n_train_samples,
    groups=train_groups,
)
dataset_val = Dataset(
    fparquet=lf,
    input_cols=INPUT_COLUMNS,
    target_cols=OUTPUT_COLUMNS,
    batch_size=3072,
    n_samples=n_val_samples,
    groups=val_groups,
)

In [None]:
# Directly calls a private method of the training dataset before the DataLoaders are built.
# NOTE(review): presumably this loads the first batch as a quick smoke test of the
# Parquet loader -- confirm it is still required, since it reaches into a private API.
dataset_train._update_batch()

In [None]:
# Fresh, untrained network for this run.
model = src.pytorch.models.MLP()

# Both splits use identical DataLoader settings, so they are declared once.
loader_options = {
    "batch_size": 3072,
    # "num_workers": 4,
    # "prefetch_factor": 4,
    "pin_memory": True,
}
trainloader = torch.utils.data.DataLoader(dataset_train, **loader_options)
valloader = torch.utils.data.DataLoader(dataset_val, **loader_options)

# Train for a fixed number of epochs; the helper returns the trained model,
# the best weights it saw, and the recorded loss history.
dataloaders = {"Training": trainloader, "Validation": valloader}
model, best_weights, loss = src.pytorch.models.utils.train(
    model=model,
    dataloaders=dataloaders,
    num_epochs=10,
)

# Plot the loss history; close=False keeps the figure open so it renders inline.
src.visualization.performance.loss_curve(loss, close=False)

In [None]:
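# Uncomment to persist the trained weights. The inference cell below loads its model
# from MODEL_WEIGHTS_PATH, so without this save it uses whatever file already exists there.
# src.pytorch.models.utils.save_weights(model, MODEL_WEIGHTS_PATH)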

In [None]:
# Load the test-set data and the sample-submission weights used when writing predictions.
df_submission = pd.read_feather(TESTSET_DATA_PATH)
df_weights = pd.read_feather(TESTSET_PREDICTION_WEIGHTS_PATH)

# Rebuild the architecture and restore the stored weights from MODEL_WEIGHTS_PATH.
# Note that this replaces the `model` object produced by the training cell.
model = src.pytorch.models.MLP()
src.pytorch.models.utils.load_model(model, MODEL_WEIGHTS_PATH)
# `src.env` is imported explicitly in the imports cell, so this lookup does not
# depend on another module having imported it as a side effect.
model.to(src.env.DEVICE)
# This model is only used for prediction: switch layers such as dropout and
# batch normalization to evaluation behaviour. eval() returns the module, so the
# cell still displays the model as before.
model.eval()

In [None]:
# If you want to submit the predictions
# src.pytorch.data.kaggle.output_compressed_parquet(
#     model=model,
#     df=df_submission,
#     weights=df_weights,
# )