In [None]:
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from src.pytorch.dataset import BufferedParquetDataset, compute_dataset_statistics
from src.schemas.climsim import INPUT_COLUMNS, OUTPUT_COLUMNS

In [None]:
# Directory holding the training parquet and its cached statistics. Override it
# with the CLIMSIM_DATA_DIR environment variable instead of editing the notebook;
# the default preserves the original location.
DATA_DIR = os.environ.get("CLIMSIM_DATA_DIR", "/home/data")

TRAINSET_DATA_PATH = os.path.join(DATA_DIR, "train.parquet")
X_STATS_PATH = os.path.join(DATA_DIR, "x_stats.parquet")  # statistics of the input (x) columns
Y_STATS_PATH = os.path.join(DATA_DIR, "y_stats.parquet")  # statistics of the output (y) columns

# Number of rows in the exported subset; also encoded in the output filename.
N_SAMPLES_IN_SUBSET = 1000
OUTPUT_FILEPATH = f"./train_tiny_{N_SAMPLES_IN_SUBSET}.parquet"

In [None]:
# Statistics are computed from the training parquet and cached on disk next to
# it. If either cache file is missing, both are recomputed and overwritten.
stats_cached = all(os.path.exists(path) for path in (X_STATS_PATH, Y_STATS_PATH))
if not stats_cached:
    df_x_stats, df_y_stats = compute_dataset_statistics(
        TRAINSET_DATA_PATH, INPUT_COLUMNS, OUTPUT_COLUMNS
    )
    for stats_frame, stats_path in ((df_x_stats, X_STATS_PATH), (df_y_stats, Y_STATS_PATH)):
        stats_frame.to_parquet(stats_path)

In [None]:
x_stats = pd.read_parquet(X_STATS_PATH)

# Per-feature mean vs. standard deviation of the input (x) columns.
# Log axes cannot represent non-positive values, so count the features that
# will not appear and report that count in the title instead of hiding them.
x_mean, x_std = x_stats.loc["mean"], x_stats.loc["std"]
x_hidden = int(((x_mean <= 0) | (x_std <= 0)).sum())

fig, ax = plt.subplots()
sns.scatterplot(x=x_mean, y=x_std, marker=".", alpha=0.5, ax=ax)
ax.set(
    xscale="log",
    yscale="log",
    xlabel="Mean",
    ylabel="Standard Deviation",
    title=(
        "Input features: Mean vs Standard Deviation (Log-Log Scale)\n"
        f"{x_hidden} features with non-positive mean/std not shown"
    ),
)
plt.show()

In [None]:
y_stats = pd.read_parquet(Y_STATS_PATH)

# Per-feature mean vs. standard deviation of the output (y) columns.
# Log axes cannot represent non-positive values, so count the features that
# will not appear and report that count in the title instead of hiding them.
y_mean, y_std = y_stats.loc["mean"], y_stats.loc["std"]
y_hidden = int(((y_mean <= 0) | (y_std <= 0)).sum())

fig, ax = plt.subplots()
sns.scatterplot(x=y_mean, y=y_std, marker=".", alpha=0.5, ax=ax)
ax.set(
    xscale="log",
    yscale="log",
    xlabel="Mean",
    ylabel="Standard Deviation",
    title=(
        "Output features: Mean vs Standard Deviation (Log-Log Scale)\n"
        f"{y_hidden} features with non-positive mean/std not shown"
    ),
)
plt.show()

In [None]:
# Wrap the training parquet in the project's buffered dataset, pointing it at
# the cached input (x) and output (y) statistics files.
# NOTE(review): presumably the statistics are used for normalization — confirm
# against BufferedParquetDataset.
dataset_kwargs = {
    "source": TRAINSET_DATA_PATH,
    "x_stats": X_STATS_PATH,
    "y_stats": Y_STATS_PATH,
}
dataset = BufferedParquetDataset(**dataset_kwargs)

In [None]:
# Build a small subset of the training data and save it as a lightweight parquet
# file. NOTE(review): whether generate_tiny_dataset samples randomly (and so needs
# a seed to be reproducible) is not visible from this notebook — confirm.
df = dataset.generate_tiny_dataset(n_samples=N_SAMPLES_IN_SUBSET)
df.to_parquet(OUTPUT_FILEPATH)

# Round-trip check: the file on disk must hold the same rows and columns that
# were just generated.
df_roundtrip = pd.read_parquet(OUTPUT_FILEPATH)
assert df_roundtrip.shape == df.shape, (df_roundtrip.shape, df.shape)
assert list(df_roundtrip.columns) == list(df.columns), "column mismatch after parquet round-trip"

# Show a preview rather than the whole frame.
df_roundtrip.head()