In [4]:
import os
import numpy as np
import pandas as pd
from pathlib import Path
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

# =========================
# CONFIG
# =========================
DATA_DIR = Path(r"E:\ASL_Citizen\NEW\Top_Classes_Landmarks_Preprocessed")
BATCH_SIZE = 32
EPOCHS = 350
LR = 2e-4
PATIENCE = 40              # increased
MIN_DELTA = 0.001          # require meaningful improvement
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42
MIXUP_ALPHA = 0.3

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# =========================
# LOAD METADATA
# =========================
metadata = pd.read_csv(DATA_DIR / "metadata.csv")

counts = metadata['word'].value_counts()
valid = counts[counts >= 2].index
metadata = metadata[metadata['word'].isin(valid)]

le = LabelEncoder()
metadata["label"] = le.fit_transform(metadata["word"])
num_classes = len(le.classes_)

print("Samples:", len(metadata))
print("Classes:", num_classes)

# =========================
# SPLITS
# =========================
train_df, test_df = train_test_split(
    metadata, test_size=0.15,
    stratify=metadata["label"],
    random_state=SEED
)

train_df, val_df = train_test_split(
    train_df, test_size=0.15,
    stratify=train_df["label"],
    random_state=SEED
)

# =========================
# DATASET
# =========================
class ASLDataset(Dataset):
    def __init__(self, df, train=True):
        self.df = df.reset_index(drop=True)
        self.train = train

    def augment(self, x):

        if random.random() < 0.5:
            shift = random.randint(-5,5)
            x = np.roll(x, shift, axis=0)

        if random.random() < 0.5:
            mask = np.random.rand(x.shape[0]) > 0.1
            x = x * mask[:,None]

        if random.random() < 0.5:
            x += np.random.normal(0,0.015,x.shape)

        return x

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        x = np.load(row["processed_file"])
        y = row["label"]

        if self.train:
            x = self.augment(x)

        return torch.tensor(x,dtype=torch.float32), torch.tensor(y)

train_loader = DataLoader(ASLDataset(train_df,True), batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(ASLDataset(val_df,False), batch_size=BATCH_SIZE)
test_loader  = DataLoader(ASLDataset(test_df,False), batch_size=BATCH_SIZE)

# =========================
# CLASS WEIGHTS
# =========================
class_counts = train_df["label"].value_counts().sort_index().values
weights = torch.tensor(1.0/class_counts,dtype=torch.float32).to(DEVICE)

criterion = nn.CrossEntropyLoss(weight=weights,label_smoothing=0.1)

# =========================
# MODEL
# =========================
class ASLModel(nn.Module):
    def __init__(self,input_dim=438,num_classes=147):
        super().__init__()

        self.conv1 = nn.Conv1d(input_dim,192,5,padding=2)
        self.conv2 = nn.Conv1d(192,192,3,padding=1)
        self.bn1 = nn.BatchNorm1d(192)
        self.bn2 = nn.BatchNorm1d(192)

        self.dropout = nn.Dropout(0.5)

        self.lstm = nn.LSTM(
            192,192,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )

        self.attention = nn.Sequential(
            nn.Linear(384,128),
            nn.Tanh(),
            nn.Linear(128,1)
        )

        self.fc = nn.Sequential(
            nn.Linear(384,192),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(192,num_classes)
        )

    def forward(self,x):

        x = x.transpose(1,2)
        x = self.dropout(F.relu(self.bn1(self.conv1(x))))
        x = self.dropout(F.relu(self.bn2(self.conv2(x))))
        x = x.transpose(1,2)

        lstm_out,_ = self.lstm(x)

        attn = self.attention(lstm_out)
        weights = torch.softmax(attn,dim=1)
        pooled = torch.sum(weights*lstm_out,dim=1)

        return self.fc(pooled)

model = ASLModel(num_classes=num_classes).to(DEVICE)

# =========================
# OPTIMIZER & SCHEDULER
# =========================
optimizer = torch.optim.AdamW(model.parameters(),lr=LR,weight_decay=1e-3)

scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=25,
    T_mult=2,
    eta_min=1e-6
)

# =========================
# TRAIN LOOP
# =========================
best_val_acc = 0
patience_counter = 0

for epoch in range(EPOCHS):

    model.train()
    train_loss=0
    train_correct=0
    total=0

    for x,y in train_loader:

        x,y=x.to(DEVICE),y.to(DEVICE)

        lam = np.random.beta(MIXUP_ALPHA,MIXUP_ALPHA)
        perm = torch.randperm(x.size(0)).to(DEVICE)
        x_mix = lam*x + (1-lam)*x[perm]
        y_a,y_b=y,y[perm]

        optimizer.zero_grad()
        outputs=model(x_mix)

        loss = lam*criterion(outputs,y_a)+(1-lam)*criterion(outputs,y_b)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(),1.0)
        optimizer.step()

        train_loss+=loss.item()*y.size(0)
        preds = outputs.argmax(1)

        train_correct += (
            lam * (preds == y_a).sum().item() +
            (1 - lam) * (preds == y_b).sum().item()
        )

        total+=y.size(0)

    train_loss/=total
    train_acc=train_correct/total

    # VALIDATION
    model.eval()
    val_loss=0
    val_correct=0
    total_val=0

    with torch.no_grad():
        for x,y in val_loader:
            x,y=x.to(DEVICE),y.to(DEVICE)
            outputs=model(x)
            loss=criterion(outputs,y)

            val_loss+=loss.item()*y.size(0)
            preds=outputs.argmax(1)
            val_correct+=(preds==y).sum().item()
            total_val+=y.size(0)

    val_loss/=total_val
    val_acc=val_correct/total_val

    scheduler.step(epoch + val_loss)

    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    print(f"Train Loss {train_loss:.4f} | Acc {train_acc:.4f}")
    print(f"Val   Loss {val_loss:.4f} | Acc {val_acc:.4f}")

    if val_acc > best_val_acc + MIN_DELTA:
        best_val_acc = val_acc
        patience_counter = 0

        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, "best_asl_model.pth")

        print("Saved best model")
    else:
        patience_counter+=1
        if patience_counter>=PATIENCE:
            print("Early stopping")
            break

# =========================
# TEST
# =========================
checkpoint = torch.load("best_asl_model.pth")
model.load_state_dict(checkpoint['model_state_dict'])

model.eval()
correct=0
total=0

with torch.no_grad():
    for x,y in test_loader:
        x,y=x.to(DEVICE),y.to(DEVICE)
        preds=model(x).argmax(1)
        correct+=(preds==y).sum().item()
        total+=y.size(0)

print("\nFINAL TEST ACC:",correct/total)


Samples: 5844
Classes: 146

Epoch 1/350
Train Loss 5.0428 | Acc 0.0074
Val   Loss 4.9434 | Acc 0.0161
Saved best model

Epoch 2/350
Train Loss 4.8120 | Acc 0.0258
Val   Loss 4.4160 | Acc 0.0590
Saved best model

Epoch 3/350
Train Loss 4.5291 | Acc 0.0465
Val   Loss 4.1186 | Acc 0.0898
Saved best model

Epoch 4/350
Train Loss 4.3436 | Acc 0.0606
Val   Loss 3.9826 | Acc 0.1032
Saved best model

Epoch 5/350
Train Loss 4.2697 | Acc 0.0777
Val   Loss 3.8651 | Acc 0.1367
Saved best model

Epoch 6/350
Train Loss 4.1983 | Acc 0.0888
Val   Loss 3.7266 | Acc 0.1930
Saved best model

Epoch 7/350
Train Loss 4.1154 | Acc 0.1146
Val   Loss 3.6173 | Acc 0.2306
Saved best model

Epoch 8/350
Train Loss 4.0494 | Acc 0.1249
Val   Loss 3.5445 | Acc 0.2359
Saved best model

Epoch 9/350
Train Loss 3.9392 | Acc 0.1467
Val   Loss 3.4746 | Acc 0.2520
Saved best model

Epoch 10/350
Train Loss 3.8810 | Acc 0.1590
Val   Loss 3.3728 | Acc 0.2748
Saved best model

Epoch 11/350
Train Loss 3.8742 | Acc 0.1738
Val   L