<a href="https://colab.research.google.com/github/submarinejuice/CP322-Final-Project-Group-9/blob/Michelle-Main/cp322_FINAL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Multimodal Physiological Representation Learning for Predicting Risky Financial Decisions



**Research Question:** Can we predict whether a participant will invest in a risky asset on a given trial from:
- Market context (expected return, volatility)
- Physiological arousal (anticipatory skin conductance responses, aSCRs)

**Secondary Research Question:**
Do we see comparable physiological signatures of stress/arousal in a real-world wearable dataset (WESAD), and can we learn a shared representation of physiological state that transfers across tasks?

**Motivation** - Real financial decisions are emotional, not purely driven by expected return and risk.

**Problem** - Predict a participant's investment choice on each trial from market and physiological data.



# 1. Setup & Reproducibility

## 1.1 Clone & Pull from our Repository

In [43]:

import os

REPO_URL = "https://github.com/submarinejuice/CP322-Final-Project-Group-9"
REPO_NAME = "CP322-Final-Project-Group-9"
# Absolute path, so re-running this cell from inside the repo doesn't clone a nested copy
REPO_DIR = f"/content/{REPO_NAME}"

if not os.path.exists(REPO_DIR):
    # First time in this Colab session: clone the repo into /content
    !git clone {REPO_URL} {REPO_DIR}
else:
    # Repo already there in this runtime: pull latest changes
    !git -C {REPO_DIR} pull

# Move into repo so relative paths work
%cd {REPO_DIR}


/content/CP322-Final-Project-Group-9/CP322-Final-Project-Group-9
remote: Enumerating objects: 5, done.
remote: Counting objects: 100% (5/5), done.
remote: Compressing objects: 100% (2/2), done.
remote: Total 3 (delta 1), reused 3 (delta 1), pack-reused 0 (from 0)
Unpacking objects: 100% (3/3), 1.46 KiB | 750.00 KiB/s, done.
From https://github.com/submarinejuice/CP322-Final-Project-Group-9
   36c779b..5c7cbd3  Michelle-Main -> origin/Michelle-Main
Already up to date.
/content
/content/CP322-Final-Project-Group-9


## 1.2 Kaggle setup: pull the WESAD dataset through the Kaggle API instead of downloading it manually

You must use your own Kaggle API token, as documented in the README. Upload your own kaggle.json file through the upload prompt that appears when this cell runs.

In [68]:
!pip install kaggle

import os
import shutil
from google.colab import files

print("CWD:", os.getcwd())
os.makedirs('/content/.kaggle', exist_ok=True)
os.makedirs('data', exist_ok=True)

# Upload kaggle.json from your local machine
uploaded = files.upload()   # select your kaggle.json / kaggle.JSON
fname = list(uploaded.keys())[0]
print("Uploaded:", fname)

# File is saved in the *current* directory, so src is just fname
src = fname
dst = '/content/.kaggle/kaggle.json'

shutil.move(src, dst)

# Restrict permissions so only the owner can read the API key
!chmod 600 /content/.kaggle/kaggle.json



CWD: /content/CP322-Final-Project-Group-9


Saving kaggle.JSON to kaggle.JSON
Uploaded: kaggle.JSON


## 1.2.5 Use Google Drive to temporarily store data so you don't have to re-run the download commands for such a large dataset

In [46]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Allow the notebook to temporarily store the WESAD dataset in your Drive.

Once the steps above are complete, download the WESAD dataset.
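As written, the download cells below save WESAD to `data/WESAD` inside the repo clone, which lives on the Colab VM's local disk and is lost when the runtime resets; mounting Drive alone does not change that. One way to actually reuse the Drive mount is to keep the data in a Drive folder and symlink it into `data/`. This is a minimal sketch: the helper `link_persistent_dir` and the Drive path `/content/drive/MyDrive/CP322_data/WESAD` are hypothetical, not part of our repo.

```python
import os


def link_persistent_dir(local_dir: str, persistent_dir: str) -> str:
    """Make `local_dir` a symlink to `persistent_dir` (hypothetical helper).

    Anything written through `local_dir` is stored in `persistent_dir`,
    so if that folder is on mounted Drive it survives a runtime reset.
    """
    os.makedirs(persistent_dir, exist_ok=True)
    parent = os.path.dirname(local_dir)
    if parent:
        os.makedirs(parent, exist_ok=True)

    if os.path.islink(local_dir):
        # Already linked (e.g. the cell was re-run): nothing to do
        return os.path.realpath(local_dir)
    if os.path.exists(local_dir):
        raise FileExistsError(
            f"{local_dir} is already a real folder; move its contents to "
            f"{persistent_dir} first"
        )

    # os.symlink(target, link_name): the link at local_dir points to persistent_dir
    os.symlink(os.path.abspath(persistent_dir), local_dir, target_is_directory=True)
    return os.path.realpath(local_dir)


# Usage in Colab after drive.mount("/content/drive"); the Drive path is an assumption:
# link_persistent_dir("data/WESAD", "/content/drive/MyDrive/CP322_data/WESAD")
```

With the link in place, the existing "already exists" check in the download cell would find the files on later sessions without re-downloading.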

## 1.3 WESAD download
The download is skipped if the files are already present in `data/WESAD`.

In [69]:
# Ensure Kaggle sees the right config
import os, shutil

print("Using config from /content/.kaggle/kaggle.json")

os.environ["KAGGLE_CONFIG_DIR"] = "/content/.kaggle"

!mkdir -p ~/.kaggle
!cp /content/.kaggle/kaggle.json ~/.kaggle/kaggle.json
!chmod 600 ~/.kaggle/kaggle.json

print("kaggle.json copied to ~/.kaggle (not printed for security).")


Using config from /content/.kaggle/kaggle.json
kaggle.json copied to ~/.kaggle (not printed for security).


The next cell checks whether the dataset is already downloaded, so you don't download it twice by accident.

In [53]:
import os

# Path to WESAD after unzipping
WESAD_DIR = "data/WESAD"

if os.path.exists(WESAD_DIR) and len(os.listdir(WESAD_DIR)) > 0:
    print("✔ WESAD dataset already exists. Skipping download.")
else:
    # The Kaggle CLI needs credentials, so check for them before trying to download
    if not os.path.exists("/content/.kaggle/kaggle.json"):
        raise FileNotFoundError(
            "kaggle.json missing: upload it via files.upload() before continuing."
        )
    print("⬇ Downloading WESAD dataset from Kaggle...")
    !kaggle datasets download -d orvile/wesad-wearable-stress-affect-detection-dataset -p data/
    !unzip -o "data/*.zip" -d data/
    print("✔ Download complete!")



✔ WESAD dataset already exists. Skipping download.


## 1.3.5 Implementing a loader file

In [66]:
%%writefile wesad_loader.py
import os
from typing import Dict, Any, List

import numpy as np
import pandas as pd
from pandas.errors import EmptyDataError, ParserError

WESAD_PATH = "data/WESAD"
E4_FILES = ["ACC", "BVP", "EDA", "HR", "IBI", "TEMP", "tags"]


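def _read_signal_csv(path: str) -> np.ndarray:
    """
    Read an E4 CSV into a numpy array (1D for single-channel files, 2D for ACC/IBI).
    Rows are returned as-is (header=None), so for E4 signal files the first two rows
    (session start timestamp, sampling rate) are included in the array.
    If the file is empty or malformed, return an empty array instead of crashing.
    """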
    print(f"[wesad_loader] reading {path}")
    try:
        # Quick check by file size
        if os.path.getsize(path) == 0:
            print(f"[wesad_loader] WARNING: {path} is 0 bytes (empty)")
            return np.array([])

        df = pd.read_csv(path, header=None)
        if df.size == 0:
            print(f"[wesad_loader] WARNING: {path} has no values")
            return np.array([])
        return df.values.squeeze()

    except (EmptyDataError, ParserError) as e:
        print(f"[wesad_loader] WARNING: {path} raised {type(e).__name__}: {e}")
        return np.array([])


def load_subject_e4(subject_id: str,
                    base_path: str = WESAD_PATH) -> Dict[str, Any]:
    """
    Load Empatica E4 wrist signals for one subject.
    """
    subj_dir = os.path.join(base_path, subject_id)
    e4_dir = os.path.join(subj_dir, f"{subject_id}_E4_Data")

    if not os.path.isdir(e4_dir):
        raise FileNotFoundError(f"E4 folder not found for {subject_id} at {e4_dir}")

    data: Dict[str, Any] = {}

    for name in E4_FILES:
        csv_path = os.path.join(e4_dir, f"{name}.csv")
        if os.path.exists(csv_path):
            data[name.lower()] = _read_signal_csv(csv_path)
        else:
            print(f"[wesad_loader] WARNING: {csv_path} not found")
            data[name.lower()] = np.array([])

    # Parse metadata from info.txt if present
    info_path = os.path.join(e4_dir, "info.txt")
    meta: Dict[str, str] = {}
    if os.path.exists(info_path):
        with open(info_path, "r") as f:
            for line in f:
                line = line.strip()
                if not line or ":" not in line:
                    continue
                key, val = [x.strip() for x in line.split(":", 1)]
                meta[key] = val
    data["meta"] = meta

    return data


def list_subjects(base_path: str = WESAD_PATH) -> List[str]:
    """
    List all subject folders like 'S2', 'S3', ...
    """
    if not os.path.isdir(base_path):
        raise FileNotFoundError(f"WESAD base path not found: {base_path}")
    return sorted(
        d for d in os.listdir(base_path)
        if d.startswith("S") and os.path.isdir(os.path.join(base_path, d))
    )


Overwriting wesad_loader.py


## 1.4 Importing the loader & listing subjects
---



In [67]:
import importlib
import wesad_loader
importlib.reload(wesad_loader)

from wesad_loader import load_subject_e4, list_subjects

print(os.listdir())  # sanity
print(list_subjects("data/WESAD"))

s2 = load_subject_e4("S2", base_path="data/WESAD")
for k, v in s2.items():
    if k == "meta":
        print(k, v)
    else:
        print(k, type(v), getattr(v, "shape", None))


['cp322_FINAL.ipynb', 'wesad_loader.py', '.ipynb_checkpoints', '.git', 'data', 'README.md', 'DATASET', '__pycache__', 'CP322-Final-Project-Group-9']
['S10', 'S11', 'S13', 'S14', 'S15', 'S16', 'S17', 'S2', 'S3', 'S4', 'S5', 'S6', 'S7', 'S8', 'S9']
[wesad_loader] reading data/WESAD/S2/S2_E4_Data/ACC.csv
[wesad_loader] reading data/WESAD/S2/S2_E4_Data/BVP.csv
[wesad_loader] reading data/WESAD/S2/S2_E4_Data/EDA.csv
[wesad_loader] reading data/WESAD/S2/S2_E4_Data/HR.csv
[wesad_loader] reading data/WESAD/S2/S2_E4_Data/IBI.csv
[wesad_loader] reading data/WESAD/S2/S2_E4_Data/TEMP.csv
[wesad_loader] reading data/WESAD/S2/S2_E4_Data/tags.csv
acc <class 'numpy.ndarray'> (251972, 3)
bvp <class 'numpy.ndarray'> (503945,)
eda <class 'numpy.ndarray'> (31496,)
hr <class 'numpy.ndarray'> (7867,)
ibi <class 'numpy.ndarray'> (3459, 2)
temp <class 'numpy.ndarray'> (31498,)
tags <class 'numpy.ndarray'> (0,)
meta {'.csv files in this archive are in the following format': ''}
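
The shapes above count every row of each file, because `_read_signal_csv` reads with `header=None`. In the standard Empatica E4 export, the first row of ACC, BVP, EDA, HR and TEMP holds the session start time (Unix timestamp, UTC) and the second row holds the sampling rate in Hz; IBI.csv and tags.csv follow different layouts, and tags.csv is empty for S2 here. `info.txt` also appears to describe this format rather than list key/value pairs, which is why `meta` ends up with a single uninformative entry. A minimal sketch of separating the two header rows from the samples, assuming that layout (`E4Signal` and `parse_e4_signal` are hypothetical, not part of `wesad_loader.py`):

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class E4Signal:
    start_time: float   # session start, Unix timestamp (UTC)
    fs: float           # sampling rate in Hz
    data: np.ndarray    # samples only, header rows removed

    @property
    def duration_s(self) -> float:
        return len(self.data) / self.fs


def parse_e4_signal(raw: np.ndarray) -> E4Signal:
    """Split a raw E4 array (read with header=None) into header values and samples.

    Assumes row 0 = start timestamp and row 1 = sampling rate, repeated across
    columns for multi-axis files such as ACC.csv. Not meant for IBI.csv or tags.csv.
    """
    raw = np.asarray(raw, dtype=float)
    if raw.shape[0] < 2:
        raise ValueError("expected at least two header rows")
    # For 2D files (ACC) the header value is repeated per axis; take the first column
    start_time = float(np.ravel(raw[0])[0])
    fs = float(np.ravel(raw[1])[0])
    return E4Signal(start_time=start_time, fs=fs, data=raw[2:])


# Usage: eda = parse_e4_signal(s2["eda"]); eda.fs should be 4.0 for the E4's EDA sensor
```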


The WESAD dataset is too large to commit to the GitHub repository, so for reproducibility each user supplies their own Kaggle API token (`kaggle.json`) and the notebook pulls the data from Kaggle at runtime.

We will go into more detail later; in short:

1. Go to your Kaggle account settings.
2. Create an API token (this downloads a `kaggle.json` file).
3. Use your own `kaggle.json`; each team member needs their own token.
4. Upload it to Colab when prompted by the upload cell in section 1.2.


Secondary dataset used:

The secondary dataset, WESAD, is over 3 GB, too large to upload to Colab by hand or push to Git. To keep the notebook reproducible, the team retrieves it at runtime from Kaggle's servers.

# 2. Affective Economics Dataset
---


In [71]:
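import os
import re
import pandas as pd

print("Current directory:", os.getcwd())
print("Repo contents:", os.listdir())
print("DATASET contents:", os.listdir("DATASET"))

df = pd.read_csv("DATASET/AE_investment_dataset.csv")

# Only a cell's last expression is rendered, so use display() for intermediate tables
display(df.head())
df.info()

# Fraction of missing values per column (the 20 most complete columns)
display(df.isna().mean().sort_values().head(20))

# Print every column name
for c in df.columns:
    print(c)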





Current directory: /content/CP322-Final-Project-Group-9
Repo contents: ['cp322_FINAL.ipynb', 'wesad_loader.py', '.ipynb_checkpoints', '.git', 'data', 'README.md', 'DATASET', '__pycache__', 'CP322-Final-Project-Group-9']
DATASET contents: ['AE_investment_dataset.csv']
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 30 entries, 0 to 29
Columns: 364 entries, Participant_code to SCR_AnticipatoryS4_T10
dtypes: float64(356), int64(5), object(3)
memory usage: 85.4+ KB
Participant_code
Age
Gender
Nationality
Ethnicity
Played_stock_market
Played_in_years
Played_how_often
Stock_amount_S1_T1
Stock_amount_S1_T2
Stock_amount_S1_T3
Stock_amount_S1_T4
Stock_amount_S1_T5
Stock_amount_S1_T6
Stock_amount_S1_T7
Stock_amount_S1_T8
Stock_amount_S1_T9
Stock_amount_S1_T10
Stock_amount_S2_T1
Stock_amount_S2_T2
Stock_amount_S2_T3
Stock_amount_S2_T4
Stock_amount_S2_T5
Stock_amount_S2_T6
Stock_amount_S2_T7
Stock_amount_S2_T8
Stock_amount_S2_T9
Stock_amount_S2_T10
Stock_amount_S3_T1
Stock_amount_S3_T2
Stock_amo

## Quick note: what PANAS means
- PANAS refers to the Positive and Negative Affect Schedule, a widely used psychological scale that measures an individual's mood by assessing both positive and negative emotions. Developed in 1988 by Watson, Clark, and Tellegen, it's a 20-item self-report measure used in research and clinical settings to gauge to what extent someone experiences emotions like interest, joy, enthusiasm (positive affect) versus feelings of distress, sadness, and nervousness (negative affect).
## How it works
- 20 items: The scale consists of 20 words that describe different feelings and emotions.
- Two dimensions: These items are separated into two subscales: one for positive affect (PA) and one for negative affect (NA).
- Rating scale: Participants rate how they felt about each item over a specific time frame (e.g., "right now," "today," "over the past few weeks") on a 5-point scale.
- Scoring: Each positive and negative item is scored individually. The total positive score and total negative score are then calculated. A higher positive score indicates more positive affect, while a higher negative score indicates more negative affect.
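
The scoring rule above can be sketched in a few lines. The item lists are the standard PANAS adjectives; the function name `score_panas` and the input format (a dict mapping each item word to its 1 to 5 rating) are hypothetical, since the PANAS column names in our dataset aren't shown here. With 10 items per subscale, each score ranges from 10 to 50.

```python
PA_ITEMS = ["interested", "excited", "strong", "enthusiastic", "proud",
            "alert", "inspired", "determined", "attentive", "active"]
NA_ITEMS = ["distressed", "upset", "guilty", "scared", "hostile",
            "irritable", "ashamed", "nervous", "jittery", "afraid"]


def score_panas(responses: dict) -> dict:
    """Sum the 1-5 ratings of the 10 positive and 10 negative items separately."""
    for item in PA_ITEMS + NA_ITEMS:
        rating = responses[item]  # raises KeyError if an item is missing
        if not 1 <= rating <= 5:
            raise ValueError(f"{item}: rating {rating} outside 1-5")
    pa = sum(responses[i] for i in PA_ITEMS)
    na = sum(responses[i] for i in NA_ITEMS)
    return {"PA": pa, "NA": na}  # each subscale falls in [10, 50]
```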

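Building a per-trial table with:
1. Per-trial columns:
  - mean_return
  - stock_fluctuation
  - scr_anticipatory
  - money_in_stocks (used only to derive the target, not as a model input)
2. Static inputs (participant-level, repeated on every trial): Participant_code, Age, Gender, Nationality, Ethnicity, Played_stock_market
3. Target
  - invested: whether they put money in the stock (1 if money_in_stocks > 0, else 0)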
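
The reshape from one row per participant (columns like `Money_in_stocks_S1_T3`) to one row per trial can also be written with `pandas.melt` plus a regex `str.extract`. A minimal sketch on a toy frame: the column names follow the dataset's pattern, but the values are made up.

```python
import pandas as pd

# Toy wide frame: one row per participant, one column per (session, trial)
wide = pd.DataFrame({
    "Participant_code": ["P01", "P02"],
    "Money_in_stocks_S1_T1": [50.0, 0.0],
    "Money_in_stocks_S1_T2": [20.0, 10.0],
})

long_toy = wide.melt(id_vars="Participant_code", var_name="column",
                     value_name="money_in_stocks")

# Pull session and trial numbers out of the original column name
parts = long_toy["column"].str.extract(r"_S(\d+)_T(\d+)$").astype(int)
long_toy["session"] = parts[0]
long_toy["trial_in_session"] = parts[1]
long_toy["global_trial"] = (long_toy["session"] - 1) * 10 + long_toy["trial_in_session"]
long_toy["invested"] = (long_toy["money_in_stocks"] > 0).astype(int)

long_toy = (long_toy.drop(columns="column")
                    .sort_values(["Participant_code", "session", "trial_in_session"])
                    .reset_index(drop=True))
print(long_toy)
```

Extending this to SCR, return, and fluctuation needs an extra step that pivots on the measure name; the explicit loop in section 2.2 handles all four measures directly.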

# Kaggle Dataset for WESAD

Note: reference documentation to be added later.

# Minimal WESAD pipeline
- Load the dataset, segment the signals into windows, compute basic per-window statistics, and label each window as baseline or stress.
- Store the windows in a DataFrame, constructed the same way as the per-trial table for the Bath dataset.
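
The windowing and labelling steps could look like the sketch below. It assumes a 1D signal (e.g. E4 EDA at 4 Hz) and a per-sample label array already aligned to the same rate. In WESAD's per-subject `.pkl` files, label 1 means baseline and 2 means stress, but those labels are recorded at 700 Hz, so aligning them with the wrist signals is a separate step not shown here. The function name `window_features`, the 60 s window with 30 s step, and the majority-label rule are illustrative choices, not our final pipeline.

```python
import numpy as np
import pandas as pd

BASELINE, STRESS = 1, 2  # WESAD label codes kept in this sketch


def window_features(signal, labels, fs, window_s=60.0, step_s=30.0):
    """Cut `signal` into overlapping windows; keep those whose majority label is baseline or stress."""
    signal = np.asarray(signal, dtype=float)
    labels = np.asarray(labels)
    win = int(window_s * fs)
    step = int(step_s * fs)
    t = np.arange(win) / fs  # time within a window, in seconds
    rows = []
    for start in range(0, len(signal) - win + 1, step):
        seg = signal[start:start + win]
        values, counts = np.unique(labels[start:start + win], return_counts=True)
        label = values[np.argmax(counts)]  # ties go to the smaller label code
        if label not in (BASELINE, STRESS):
            continue  # skip amusement, meditation, and transient segments
        rows.append({
            "start_s": start / fs,
            "mean": seg.mean(),
            "std": seg.std(),
            "slope": np.polyfit(t, seg, 1)[0],  # linear trend, units per second
            "stress": int(label == STRESS),
        })
    return pd.DataFrame(rows)
```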

## 2.1 Loading Affective Economics Dataset
---

In [74]:
import os
import re
import pandas as pd

print("Current directory:", os.getcwd())
print("Repo contents:", os.listdir())
print("DATASET contents:", os.listdir("DATASET"))

df = pd.read_csv("DATASET/AE_investment_dataset.csv")
df.info()

# 0. ID & static columns
id_cols = ["Participant_code", "Age", "Gender", "Nationality", "Ethnicity", "Played_stock_market"]

# 1. Grab trial-level columns by prefix
stock_cols = [c for c in df.columns if c.startswith("Money_in_stocks_S")]
scr_cols   = [c for c in df.columns if c.startswith("SCR_AnticipatoryS")]
ret_cols   = [c for c in df.columns if c.startswith("Mean_Return_S")]
fluc_cols  = [c for c in df.columns if c.startswith("stock_fluctuation_S")]

print("n_stock_cols:", len(stock_cols))
print("n_scr_cols:", len(scr_cols))
print("n_return_cols:", len(ret_cols))
print("n_fluctuation_cols:", len(fluc_cols))

# 2. Map (session, trial) -> column name
def build_lookup(cols, prefix):
    lookup = {}
    for c in cols:
        # e.g. Money_in_stocks_S1_T3  →  session=1, trial=3
        m = re.match(rf"{re.escape(prefix)}(\d+)_T(\d+)$", c)
        if m:
            s = int(m.group(1))   # session number
            t = int(m.group(2))   # trial number within session
            lookup[(s, t)] = c
    return lookup

stock_map = build_lookup(stock_cols, "Money_in_stocks_S")
scr_map   = build_lookup(scr_cols,   "SCR_AnticipatoryS")
ret_map   = build_lookup(ret_cols,   "Mean_Return_S")
fluc_map  = build_lookup(fluc_cols,  "stock_fluctuation_S")

print("number of keys in stock_map:", len(stock_map))
print("some keys from stock_map:", list(stock_map.items())[:5])


Current directory: /content/CP322-Final-Project-Group-9
Repo contents: ['cp322_FINAL.ipynb', 'wesad_loader.py', '.ipynb_checkpoints', '.git', 'data', 'README.md', 'DATASET', '__pycache__', 'CP322-Final-Project-Group-9']
DATASET contents: ['AE_investment_dataset.csv']
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 30 entries, 0 to 29
Columns: 364 entries, Participant_code to SCR_AnticipatoryS4_T10
dtypes: float64(356), int64(5), object(3)
memory usage: 85.4+ KB
n_stock_cols: 40
n_scr_cols: 40
n_return_cols: 36
n_fluctuation_cols: 36
number of keys in stock_map: 40
some keys from stock_map: [((1, 1), 'Money_in_stocks_S1_T1'), ((1, 2), 'Money_in_stocks_S1_T2'), ((1, 3), 'Money_in_stocks_S1_T3'), ((1, 4), 'Money_in_stocks_S1_T4'), ((1, 5), 'Money_in_stocks_S1_T5')]


## 2.2 Long-format trial table

In [75]:
rows = []

for _, row in df.iterrows():
    # carry participant-level info
    base = {col: row[col] for col in id_cols}

    for (s, t) in sorted(stock_map.keys()):
        rec = dict(base)
        rec["session"] = s
        rec["trial_in_session"] = t
        rec["global_trial"] = (s - 1) * 10 + t  # 1..40

        rec["money_in_stocks"] = row[stock_map[(s, t)]]
        rec["scr_anticipatory"] = row[scr_map[(s, t)]]

        # Some (session, trial) combos might not have return/fluctuation
        rec["mean_return"] = row[ret_map[(s, t)]] if (s, t) in ret_map else pd.NA
        rec["stock_fluctuation"] = row[fluc_map[(s, t)]] if (s, t) in fluc_map else pd.NA

        rows.append(rec)

long_df = pd.DataFrame(rows)

# Target: did they invest at all?
long_df["invested"] = (long_df["money_in_stocks"] > 0).astype(int)

print(long_df.shape)
long_df.head()
print(long_df["invested"].value_counts())


(1200, 14)
invested
1    1079
0     121
Name: count, dtype: int64


## 2.3 Cleaning + Basic Features


In [76]:


import pandas as pd
import numpy as np

# Drop rows where our key features are missing
key_features = ["scr_anticipatory", "mean_return", "stock_fluctuation", "money_in_stocks"]

clean_df = long_df.dropna(subset=key_features).copy()

# Convert types to numeric
for col in key_features:
    clean_df[col] = pd.to_numeric(clean_df[col], errors='coerce')

# Drop again if any become NA
clean_df = clean_df.dropna(subset=key_features)

# Target
y = clean_df["invested"].astype(int)

# Features: minimal baseline
X = clean_df[["scr_anticipatory", "mean_return", "stock_fluctuation"]]

print("Clean shape:", clean_df.shape)
X.head()


Clean shape: (1080, 14)


|   | scr_anticipatory | mean_return | stock_fluctuation |
|---|---|---|---|
| 1 | 0.028 | 0.045 | 2.0 |
| 2 | 0.232 | -0.046 | 1.0 |
| 3 | 0.954 | 0.37 | 2.0 |
| 4 | 0.0 | 0.046 | 2.0 |
| 5 | 0.858 | 0.001 | 1.0 |


In [None]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, classification_report

# Train-test split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Standardize features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled  = scaler.transform(X_test)

# Baseline model
clf = LogisticRegression()
clf.fit(X_train_scaled, y_train)

y_pred = clf.predict(X_test_scaled)

print("Accuracy:", accuracy_score(y_test, y_pred))
print("F1-score:", f1_score(y_test, y_pred))
print(classification_report(y_test, y_pred))
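
Most trials are investments (1079 of 1200 in the long table before cleaning, about 90%), so a model that always predicts "invested" already scores roughly 0.9 accuracy, and `f1_score` above is computed for the majority class (`pos_label=1` by default). A quick way to put these scores in context is scikit-learn's `DummyClassifier` plus balanced accuracy. A minimal sketch on synthetic labels with a similar 90/10 split (the data is made up; in the notebook you would pass `X_train_scaled`, `y_train`, `X_test_scaled`, `y_test`):

```python
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score

# Synthetic stand-ins for the real train/test split (~90% "invested")
rng = np.random.default_rng(0)
X_train_demo = rng.normal(size=(200, 3))
y_train_demo = np.array([1] * 180 + [0] * 20)
X_test_demo = rng.normal(size=(100, 3))
y_test_demo = np.array([1] * 90 + [0] * 10)

# Always predicts the most frequent training class ("invested")
dummy = DummyClassifier(strategy="most_frequent")
dummy.fit(X_train_demo, y_train_demo)
y_dummy = dummy.predict(X_test_demo)

print("Majority accuracy:", accuracy_score(y_test_demo, y_dummy))
print("Majority balanced accuracy:", balanced_accuracy_score(y_test_demo, y_dummy))
print("Majority F1 (class 1):", f1_score(y_test_demo, y_dummy))
print("F1 for 'not invested' (class 0):",
      f1_score(y_test_demo, y_dummy, pos_label=0, zero_division=0))
```

If the real models barely beat this baseline, `LogisticRegression(class_weight="balanced")` is one standard option to try.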


In [None]:
from sklearn.neural_network import MLPClassifier

# Re-fit the scaler on the training split (same result as the cell above; kept so this cell runs on its own)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled  = scaler.transform(X_test)

# Small MLP: two hidden layers with 32 and 16 units
mlp = MLPClassifier(
    hidden_layer_sizes=(32, 16),
    activation='relu',
    solver='adam',
    max_iter=500,
    random_state=42
)

mlp.fit(X_train_scaled, y_train)

y_pred_mlp = mlp.predict(X_test_scaled)

print("MLP Accuracy:", accuracy_score(y_test, y_pred_mlp))
print("MLP F1:", f1_score(y_test, y_pred_mlp))
print(classification_report(y_test, y_pred_mlp))




In [None]:
import numpy as np

# Work on a copy, sorted by participant + time
scr_df = clean_df.sort_values(["Participant_code", "session", "trial_in_session"]).copy()

grp = scr_df.groupby("Participant_code")

# Participant-level mean/std and z-score
scr_df["scr_mean_p"] = grp["scr_anticipatory"].transform("mean")
scr_df["scr_std_p"]  = grp["scr_anticipatory"].transform("std")
scr_df["scr_z"] = (scr_df["scr_anticipatory"] - scr_df["scr_mean_p"]) / scr_df["scr_std_p"]

# Lags within each participant
scr_df["scr_lag1"] = grp["scr_anticipatory"].shift(1)
scr_df["scr_lag2"] = grp["scr_anticipatory"].shift(2)

# Changes vs previous trials
scr_df["scr_delta1"] = scr_df["scr_anticipatory"] - scr_df["scr_lag1"]
scr_df["scr_delta2"] = scr_df["scr_anticipatory"] - scr_df["scr_lag2"]

# Short-term rolling window stats (window=3 trials)
scr_df["scr_roll_mean3"] = grp["scr_anticipatory"].transform(
    lambda x: x.rolling(window=3, min_periods=1).mean()
)
scr_df["scr_roll_std3"] = grp["scr_anticipatory"].transform(
    lambda x: x.rolling(window=3, min_periods=1).std()
)

# Fill NaNs with 0 for now (no previous trial for lags/deltas/rolling std; undefined z-score when a participant's SCR std is 0 or missing)
scr_feature_cols = [
    "scr_anticipatory",
    "scr_z",
    "scr_lag1", "scr_lag2",
    "scr_delta1", "scr_delta2",
    "scr_roll_mean3", "scr_roll_std3",
]

scr_df[scr_feature_cols] = scr_df[scr_feature_cols].fillna(0.0)

print("SCR feature shape:", scr_df[scr_feature_cols].shape)
scr_df[scr_feature_cols + ["invested"]].head()
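
The train/test splits in this notebook (the logistic-regression and MLP cells above and the PyTorch cell below) are random over trials, so trials from the same participant land in both the training and held-out sets, and `scr_mean_p`/`scr_std_p` are computed from all of a participant's trials, including held-out ones. If the goal is to generalize to new participants, splitting by participant avoids both issues: each participant's statistics then come only from trials in a single split. A minimal sketch with scikit-learn's `GroupShuffleSplit` on a toy frame (values are made up; in the notebook, `groups` would be `scr_df["Participant_code"]`):

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

# Toy trial table: 10 participants x 4 trials each
toy = pd.DataFrame({
    "Participant_code": np.repeat([f"P{i:02d}" for i in range(10)], 4),
    "scr_anticipatory": np.linspace(0.0, 1.0, 40),
    "invested": [1, 1, 0, 1] * 10,
})

# test_size is a fraction of *participants*, not of trials
splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_idx, test_idx = next(
    splitter.split(toy, toy["invested"], groups=toy["Participant_code"])
)

train_p = set(toy.iloc[train_idx]["Participant_code"])
test_p = set(toy.iloc[test_idx]["Participant_code"])
print("train participants:", sorted(train_p))
print("test participants:", sorted(test_p))
```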


In [None]:
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, f1_score

# ==== Prepare data ====
X_scr = scr_df[scr_feature_cols].values.astype("float32")
y_scr = scr_df["invested"].values.astype("int64")

X_train, X_val, y_train, y_val = train_test_split(
    X_scr, y_scr, test_size=0.2, random_state=42, stratify=y_scr
)

# Standardize SCR features
scaler_scr = StandardScaler()
X_train_scaled = scaler_scr.fit_transform(X_train)
X_val_scaled   = scaler_scr.transform(X_val)

X_train_tensor = torch.tensor(X_train_scaled, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)
X_val_tensor   = torch.tensor(X_val_scaled, dtype=torch.float32)
y_val_tensor   = torch.tensor(y_val, dtype=torch.long)

train_ds = TensorDataset(X_train_tensor, y_train_tensor)
val_ds   = TensorDataset(X_val_tensor, y_val_tensor)

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
val_loader   = DataLoader(val_ds, batch_size=64, shuffle=False)

input_dim = X_train_tensor.shape[1]
emb_dim   = 16

# ==== Physiology encoder ====
class PhysioEncoder(nn.Module):
    def __init__(self, in_dim, hidden_dim=32, emb_dim=16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, emb_dim),
            nn.ReLU()
        )

    def forward(self, x):
        return self.net(x)

# Full model: encoder + classifier head
class PhysioModel(nn.Module):
    def __init__(self, encoder, emb_dim):
        super().__init__()
        self.encoder = encoder
        self.classifier = nn.Linear(emb_dim, 2)

    def forward(self, x):
        z = self.encoder(x)        # [batch, emb_dim]
        logits = self.classifier(z)  # [batch, 2]
        return logits

encoder = PhysioEncoder(input_dim, hidden_dim=32, emb_dim=emb_dim)
model = PhysioModel(encoder, emb_dim)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# ==== Training loop ====
n_epochs = 30

for epoch in range(1, n_epochs + 1):
    model.train()
    train_losses = []

    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)

        optimizer.zero_grad()
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())

    # Validation
    model.eval()
    all_preds = []
    all_true  = []
    with torch.no_grad():
        for xb, yb in val_loader:
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb)
            preds = torch.argmax(logits, dim=1)
            all_preds.append(preds.cpu())
            all_true.append(yb.cpu())

    all_preds = torch.cat(all_preds).numpy()
    all_true  = torch.cat(all_true).numpy()

    acc = accuracy_score(all_true, all_preds)
    f1  = f1_score(all_true, all_preds)

    if epoch % 5 == 0 or epoch == 1:
        print(f"Epoch {epoch:02d} | "
              f"train_loss={np.mean(train_losses):.4f} | "
              f"val_acc={acc:.3f} | val_f1={f1:.3f}")
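
This network is trained with an unweighted `nn.CrossEntropyLoss`, so with roughly 90% positive trials it can reach a high `val_acc` by mostly predicting "invested". One common adjustment is inverse-frequency class weights, which `nn.CrossEntropyLoss` accepts through its `weight` argument (e.g. `nn.CrossEntropyLoss(weight=torch.tensor(w, dtype=torch.float32).to(device))`). A minimal sketch of computing those weights with scikit-learn, on made-up labels standing in for `y_train`:

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Made-up training labels with a 90/10 split, standing in for y_train
y_train_demo = np.array([1] * 90 + [0] * 10)

classes = np.array([0, 1])  # order must match the model's output indices
# "balanced": weight_c = n_samples / (n_classes * count_c)
w = compute_class_weight(class_weight="balanced", classes=classes, y=y_train_demo)
print(dict(zip(classes.tolist(), w.round(3).tolist())))
```

The rare "not invested" class then contributes about nine times as much per sample to the loss as the majority class.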
