In [None]:
import os
from data.source.pg_experiment import get_pg_experiment_dataframe
import polars as pl

from models.SimplifiedLightweightCNN import SimplifiedLightweightCNN
# Autoreload in "explicit" mode: `%autoreload 1` only reloads modules registered with `%aimport`.
# NOTE(review): only models.SimplifiedLightweightCNN is registered here, but the model trained
# further down is models.FusionCNN.ContextFusionCNN (refreshed via reload_function instead) —
# confirm this registration is still wanted.
%load_ext autoreload
%autoreload 1
%aimport models.SimplifiedLightweightCNN
from models.SimpleCNN_v2 import train, evaluate
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from path import RESULT_DIRECTORY
import wandb

# NOTE(review): presumably a workaround for the Intel OpenMP "libiomp5 already initialized" abort
# when two libraries each ship an OpenMP runtime; it disables the check rather than fixing the
# duplicate runtime — confirm it is still needed on the machines this runs on.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

In [None]:
# Load both assessment versions and pair them per (student, word).
df_pronv1, _ = get_pg_experiment_dataframe(".ogg", assesment_version="v1")
df_pronv2, _ = get_pg_experiment_dataframe(".ogg", assesment_version="v2")

# Inner join keeps only rows present in both versions; clashing right-hand columns get "_v2".
dataframe = df_pronv1.join(df_pronv2, on=["id_student", "word_id"], how="inner", suffix="_v2")
dataframe = dataframe.rename({"value": "v1_value", "value_v2": "v2_value"})
# Filter to stage 1 BEFORE re-indexing the words. The dense rank must be computed on the rows that
# are actually used: ranking first leaves gaps for words that only occur in other stages, and then
# max(word_id) can exceed N_WORDS. Later, word_id - 1 becomes a torch.long tensor fed to a model
# built with num_words=N_WORDS, so the ids must be exactly 1..N_WORDS.
dataframe = dataframe.filter(pl.col("stage") == 1)
dataframe = dataframe.with_columns(word_id=pl.struct("word_id").rank("dense"))

# Rows where the two assessment versions disagree (outer) vs. agree (inner).
df_outer = dataframe.filter(pl.col("v1_value") != pl.col("v2_value"))
df_inner = dataframe.filter(pl.col("v1_value") == pl.col("v2_value"))

N_WORDS = dataframe.get_column("word_id").n_unique()
assert dataframe.is_empty() or dataframe.get_column("word_id").max() == N_WORDS, \
    "word_id must be a contiguous 1..N_WORDS index"
print(f"Number of unique words: {N_WORDS}")
print(f"Number of samples: {dataframe.shape[0]}")
print(f"Samples with v1_value != v2_value: {df_outer.shape[0]}")

In [None]:
import polars as pl
import numpy as np
from typing import Tuple

def _floor_fraction(frac: float, n: int) -> int:
    """Return floor(frac * n) without being tripped up by binary float noise.

    E.g. (1 - 0.6 - 0.2) * 100 == 19.999999999999996; rounding the product to 6 decimals first
    gives the intended 20 while still flooring genuinely fractional counts (20.6 -> 20).
    """
    return int(round(frac * n, 6))


def split(df_outer: pl.DataFrame, df_inner: pl.DataFrame, label_col: str, train_frac=0.8, val_frac=0.1, seed=42) -> Tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
    """Split rows into train/val/test, drawing the test set only from ``df_inner``.

    The test set is a random sample of ``df_inner`` with
    ``floor((1 - train_frac - val_frac) * df_outer.height)`` rows (capped at ``df_inner.height``).
    Every remaining row — all of ``df_outer`` plus the unsampled part of ``df_inner`` — is split
    per value of ``label_col`` into train and val in the ratio ``train_frac : val_frac``.

    Args:
        df_outer: Rows that may only go to train/val. Must have the same schema as ``df_inner``.
        df_inner: Rows the test set is sampled from; the leftovers join train/val.
        label_col: Column used to stratify the train/val split.
        train_frac: Train fraction (non-negative).
        val_frac: Validation fraction (non-negative); ``0 < train_frac + val_frac <= 1``.
        seed: Seed of the ``np.random.RandomState`` driving all permutations.

    Returns:
        ``(train_df, val_df, test_df)``.

    Raises:
        ValueError: If the fractions are negative or do not satisfy ``0 < sum <= 1``.
    """
    if train_frac < 0 or val_frac < 0 or not 0 < train_frac + val_frac <= 1:
        raise ValueError(
            "expected non-negative train_frac/val_frac with 0 < train_frac + val_frac <= 1, "
            f"got train_frac={train_frac}, val_frac={val_frac}"
        )
    rng = np.random.RandomState(seed)

    # Test set: random rows of df_inner, sized relative to df_outer. The slice caps it at
    # df_inner.height; the unsampled df_inner rows are reused for train/val below.
    n_test = _floor_fraction(1 - train_frac - val_frac, df_outer.height)
    indices_subset = rng.permutation(df_inner.height)
    test_df = df_inner[indices_subset[:n_test]]

    df = pl.concat([df_outer, df_inner[indices_subset[n_test:]]])

    # Share of the remaining rows that goes to train; the rest of each class goes to val.
    train_share = train_frac / (train_frac + val_frac)

    train_rows, val_rows = [], []
    # Classes come from the combined frame, so classes that only appear in df_inner are not
    # dropped. They are sorted so the iteration order — and therefore the sequence of RNG draws —
    # is the same on every run.
    for cls in df.get_column(label_col).unique().sort():
        class_df = df.filter(pl.col(label_col) == cls)
        indices = rng.permutation(class_df.height)
        train_end = _floor_fraction(train_share, class_df.height)
        train_rows.append(class_df[indices[:train_end]])
        # Take the whole remainder rather than a computed end index: the renormalized shares can
        # sum to 0.9999999999999999, which would floor away one row per class.
        val_rows.append(class_df[indices[train_end:]])

    # pl.concat rejects an empty list; fall back to an empty frame with df's schema.
    train_df = pl.concat(train_rows) if train_rows else df.head(0)
    val_df = pl.concat(val_rows) if val_rows else df.head(0)
    return train_df, val_df, test_df


In [None]:
from typing import Callable

from polars import DataFrame
from dataset import Cast, TorchDataset
from develop import reload_function, reload_module
import pytorch_dataloader
reload_module(pytorch_dataloader)
from pytorch_dataloader import ReshapeCollate, build_collate_fn, PaddingCollate, DefaultCollate
from functools import partial

from transformation import Channels, RMSEnergy, TorchVadLogMelSpec, TorchVadMFCC, ZeroCrossingRate

reload_function(TorchVadMFCC)

TRAIN_SPLIT = 0.6
VAL_SPLIT = 0.2
# Rounded so float noise (1 - 0.6 - 0.2 == 0.19999999999999996) does not leak into the run config.
TEST_SPLIT = round(1 - TRAIN_SPLIT - VAL_SPLIT, 10)
train_pl, val_pl, test_pl = split(df_outer=df_outer, df_inner=df_inner, label_col="word_id", train_frac=TRAIN_SPLIT, val_frac=VAL_SPLIT)
# Validate/test only on rows where both assessment versions agree. test_pl is drawn from df_inner,
# so its filter is a safety net; val_pl can contain disagreeing df_outer rows and is really filtered.
val_pl = val_pl.filter(pl.col("v1_value") == pl.col("v2_value"))
test_pl = test_pl.filter(pl.col("v1_value") == pl.col("v2_value"))


def to_dataset(dataframe: DataFrame) -> TorchDataset:
    """Wrap the rows of one split as a ``TorchDataset`` with four fields, in this order.

    - ``rec_path`` passed through ``Channels("stack", "multiply")(TorchVadMFCC(delta=0))``
      (presumably audio features computed from each recording path).
    - ``word_id - 1`` as a ``torch.long`` tensor (word ids are 1-based after the dense rank).
    - ``v1_value`` as a float tensor.
    - ``v2_value`` as a float tensor.
    """
    return TorchDataset(
        Cast(dataframe.get_column("rec_path"), Channels("stack", "multiply")(
            TorchVadMFCC(delta=0),
        )),
        Cast(dataframe.get_column("word_id"), lambda x: torch.tensor(x - 1, dtype=torch.long)),
        Cast(dataframe.get_column("v1_value"), lambda x: torch.tensor(x).float()),
        Cast(dataframe.get_column("v2_value"), lambda x: torch.tensor(x).float()),
    )


# One collate per dataset field: pad the features to a fixed length of 80 along dim 2, default for the rest.
collate_fn = build_collate_fn(
    PaddingCollate(mode="SET_MAX_LEN", max_len=80, pad_dim=2),
    DefaultCollate(),
    DefaultCollate(),
    DefaultCollate(),
)
dataset_train = to_dataset(train_pl)
dataset_val = to_dataset(val_pl)
dataset_test = to_dataset(test_pl)

In [None]:
import multiprocessing

from pytorch_dataloader import MemoryLoadedDataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# DataLoader worker processes only work from a notebook when they are started with fork().
# With the "spawn"/"forkserver" start methods (default on Windows and macOS, and on Linux from
# Python 3.14), each worker must unpickle the dataset, and the lambdas used by `to_dataset` cannot
# be pickled. In that case load in the main process (num_workers=0).
num_workers = 4 if multiprocessing.get_start_method() == "fork" else 0
train_loader = DataLoader(dataset_train, batch_size=16, shuffle=True, collate_fn=collate_fn, num_workers=num_workers)
val_loader = DataLoader(dataset_val, batch_size=16, shuffle=False, collate_fn=collate_fn, num_workers=num_workers)
test_loader = DataLoader(dataset_test, batch_size=16, shuffle=False, collate_fn=collate_fn, num_workers=num_workers)

# Sanity check: shape of every field in the first training batch.
for batch_field in next(iter(train_loader)):
    print(batch_field.shape)

In [None]:
# Wrap every split's loader in MemoryLoadedDataLoader targeting `device`.
# NOTE(review): presumably this iterates the wrapped loader once and caches the batches on
# `device`, so later epochs skip decoding/feature extraction — confirm in pytorch_dataloader.
# Re-running this cell wraps the already-wrapped loaders again.
train_loader = MemoryLoadedDataLoader(train_loader, device=device)
print("Loaded train loader into memory")
val_loader = MemoryLoadedDataLoader(val_loader, device=device)
print("Loaded validation loader into memory")
test_loader = MemoryLoadedDataLoader(test_loader, device=device)
print("Loaded test loader into memory")

In [None]:
from models.FusionCNN import ContextFusionCNN

# Refresh the model class via the project's reload helper, then build the network.
# NOTE(review): the positional 1 is presumably the input-channel count (one stacked feature map
# from Channels("stack", ...)) — confirm against ContextFusionCNN.__init__.
reload_function(ContextFusionCNN)
model = ContextFusionCNN(
    1,
    num_words=N_WORDS,
)

In [None]:
# Model variables definition.

from typing import Literal


TRAIN_MODE : Literal["zerov1", "zerov2", "double_zerov1", "double_zerov2", "sequential", "interleave"] = "interleave"
model = model.to(device)
# Derive names from the model actually being trained, so the run config, checkpoint file and
# artifact are not mislabelled (and do not overwrite another architecture's checkpoint).
MODEL_NAME = type(model).__name__
pth = f"{MODEL_NAME}.pth"
lr = 1e-4
weight_decay = 1e-4  # L2 regularization through Adam's weight_decay
epochs = 140
reload_function(train)
reload_function(evaluate)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
# Halve the LR when the monitored validation loss has not improved for `patience` scheduler steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5
)
criterion = nn.BCELoss()

# Start a new wandb run to track this script.
run = wandb.init(
    # name of the run
    name=f"hypothesis 14 - {TRAIN_MODE}",
    config={
        "Name": MODEL_NAME,
        "learning_rate": lr,
        "weight_decay": weight_decay,
        "optimizer": "Adam",
        "criterion": "BCELoss",
        "architecture": MODEL_NAME,
        "architecture_details": str(model),
        "dataset": "Stage-I",
        "train_val_test(%)": f'{TRAIN_SPLIT}-{VAL_SPLIT}-{TEST_SPLIT}',
        "epochs": epochs,
        "train_mode": TRAIN_MODE,
    },
)

# Each mode maps to (label versions trained/validated on within one epoch, interleave_labels flag
# passed to train/evaluate). The plain zero* modes make one pass per epoch; the double_* variants
# make two passes over the same label version.
configs = {
    "interleave": (["v1", "v2"], True),
    "sequential": (["v1", "v2"], False),
    "zerov1": (["v1"], False),
    "zerov2": (["v2"], False),
    "double_zerov1": (["v1", "v1"], False),
    "double_zerov2": (["v2", "v2"], False),
}
label_versions, interleave_labels = configs[TRAIN_MODE]

try:
    # Training loop
    for epoch in range(epochs):
        per_version_metrics = {}
        for label_version in label_versions:
            train_loss, train_acc = train(model, train_loader, optimizer, criterion, device, label_version=label_version, interleave_labels=interleave_labels)
            val_loss, val_acc = evaluate(model, val_loader, criterion, device, label_version=label_version, interleave_labels=interleave_labels)
            # NOTE(review): the scheduler steps once per label version, so with two versions per
            # epoch `patience` counts half-epochs and compares v1 and v2 validation losses.
            scheduler.step(val_loss)
            # Keep every pass's metrics; for a repeated version (double_*) the last pass wins.
            per_version_metrics.update({
                f"{label_version}/train_acc": train_acc,
                f"{label_version}/train_loss": train_loss,
                f"{label_version}/val_acc": val_acc,
                f"{label_version}/val_loss": val_loss,
            })
        # The un-prefixed keys hold the epoch's last pass (comparable with earlier runs); the
        # prefixed keys hold each label version's metrics for the dashboard.
        run.log({"train_acc": train_acc, "train_loss": train_loss, "val_acc": val_acc, "val_loss": val_loss, **per_version_metrics})
        print(
            f"Epoch {epoch + 1}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    # Switch to inference mode for the final evaluation (nothing to log: eval() returns the module).
    model.eval()
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    run.log({"test_acc": test_acc, "test_loss": test_loss})
    print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")

    # Save the weights and attach them to the run as a model artifact (wandb storage quota: 5GB).
    checkpoint_path = os.path.join(RESULT_DIRECTORY, pth)
    torch.save(model.state_dict(), checkpoint_path)
    artifact = wandb.Artifact(f"{MODEL_NAME}-model", type="model")
    artifact.add_file(checkpoint_path)
    run.log_artifact(artifact)
except BaseException:
    # Mark the run as failed and close it, so a re-run of this cell starts a fresh run instead of
    # appending to a dangling one; then surface the original error.
    run.finish(exit_code=1)
    raise

# Finish the run so it gets sent to the remote. You can discover the run right after that on the dashboard.
run.finish()
