In [2]:
import os
import numpy as np
from pathlib import Path
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

# =========================
# CONFIG
# =========================
DATA_DIR = Path(r"E:\ASL_Citizen\NEW\preprocessed_approach1_shahd")
BATCH_SIZE = 32
EPOCHS = 350
LR = 2e-4
PATIENCE = 40
MIN_DELTA = 0.001
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42
MIXUP_ALPHA = 0.3

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# =========================
# BUILD DATA LIST FROM FILES
# =========================
files = sorted(DATA_DIR.glob("*.npy"))

samples = []
labels = []

for f in files:
    name = f.stem                # "ABOUT 4" or "ABOUT"
    word = name.split(" ")[0]    # take first token
    samples.append(str(f))
    labels.append(word)

# Encode labels
le = LabelEncoder()
encoded = le.fit_transform(labels)
num_classes = len(le.classes_)

print("Total samples:", len(samples))
print("Classes:", num_classes)

# Create dataframe-like lists
data = list(zip(samples, encoded))

# =========================
# SPLITS
# =========================
train_data, test_data = train_test_split(
    data, test_size=0.15, stratify=encoded, random_state=SEED
)

train_labels = [y for _,y in train_data]
train_data, val_data = train_test_split(
    train_data, test_size=0.15,
    stratify=train_labels,
    random_state=SEED
)

# =========================
# DATASET
# =========================
class ASLDataset(Dataset):
    def __init__(self, data, train=True):
        self.data = data
        self.train = train

    def augment(self, x):
        if random.random() < 0.5:
            shift = random.randint(-5,5)
            x = np.roll(x, shift, axis=0)

        if random.random() < 0.5:
            x += np.random.normal(0,0.015,x.shape)

        return x

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        path, label = self.data[idx]
        x = np.load(path)

        # flatten if (frames, landmarks, 2)
        if len(x.shape) == 3:
            x = x.reshape(x.shape[0], -1)

        x = x.astype(np.float32)

        if self.train:
            x = self.augment(x)

        return torch.tensor(x), torch.tensor(label)

train_loader = DataLoader(ASLDataset(train_data,True), batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(ASLDataset(val_data,False), batch_size=BATCH_SIZE)
test_loader  = DataLoader(ASLDataset(test_data,False), batch_size=BATCH_SIZE)

# =========================
# AUTO DETECT INPUT DIM
# =========================
sample = np.load(samples[0])
if len(sample.shape)==3:
    input_dim = sample.shape[1]*sample.shape[2]
else:
    input_dim = sample.shape[1]

print("Input dim:", input_dim)

# =========================
# CLASS WEIGHTS
# =========================
train_labels = [y for _,y in train_data]
counts = np.bincount(train_labels)
weights = torch.tensor(1.0/counts,dtype=torch.float32).to(DEVICE)

criterion = nn.CrossEntropyLoss(weight=weights,label_smoothing=0.1)

# =========================
# MODEL (UNCHANGED)
# =========================
class ASLModel(nn.Module):
    def __init__(self,input_dim=input_dim,num_classes=num_classes):
        super().__init__()

        self.conv1 = nn.Conv1d(input_dim,192,5,padding=2)
        self.conv2 = nn.Conv1d(192,192,3,padding=1)
        self.bn1 = nn.BatchNorm1d(192)
        self.bn2 = nn.BatchNorm1d(192)

        self.dropout = nn.Dropout(0.5)

        self.lstm = nn.LSTM(
            192,192,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )

        self.attention = nn.Sequential(
            nn.Linear(384,128),
            nn.Tanh(),
            nn.Linear(128,1)
        )

        self.fc = nn.Sequential(
            nn.Linear(384,192),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(192,num_classes)
        )

    def forward(self,x):

        x = x.transpose(1,2)
        x = self.dropout(F.relu(self.bn1(self.conv1(x))))
        x = self.dropout(F.relu(self.bn2(self.conv2(x))))
        x = x.transpose(1,2)

        lstm_out,_ = self.lstm(x)

        attn = self.attention(lstm_out)
        weights = torch.softmax(attn,dim=1)
        pooled = torch.sum(weights*lstm_out,dim=1)

        return self.fc(pooled)

model = ASLModel().to(DEVICE)

# =========================
# OPTIMIZER & SCHEDULER
# =========================
optimizer = torch.optim.AdamW(model.parameters(),lr=LR,weight_decay=1e-3)

scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=25,
    T_mult=2,
    eta_min=1e-6
)

# =========================
# TRAIN LOOP (UNCHANGED)
# =========================
best_val_acc = 0
patience_counter = 0

for epoch in range(EPOCHS):

    model.train()
    train_loss=0
    train_correct=0
    total=0

    for x,y in train_loader:
        x,y=x.to(DEVICE),y.to(DEVICE)

        lam = np.random.beta(MIXUP_ALPHA,MIXUP_ALPHA)
        perm = torch.randperm(x.size(0)).to(DEVICE)
        x_mix = lam*x + (1-lam)*x[perm]
        y_a,y_b=y,y[perm]

        optimizer.zero_grad()
        outputs=model(x_mix)

        loss = lam*criterion(outputs,y_a)+(1-lam)*criterion(outputs,y_b)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(),1.0)
        optimizer.step()

        train_loss+=loss.item()*y.size(0)
        preds = outputs.argmax(1)

        train_correct += (
            lam * (preds == y_a).sum().item() +
            (1 - lam) * (preds == y_b).sum().item()
        )
        total+=y.size(0)

    train_loss/=total
    train_acc=train_correct/total

    # VALIDATION
    model.eval()
    val_loss=0
    val_correct=0
    total_val=0

    with torch.no_grad():
        for x,y in val_loader:
            x,y=x.to(DEVICE),y.to(DEVICE)
            outputs=model(x)
            loss=criterion(outputs,y)

            val_loss+=loss.item()*y.size(0)
            preds=outputs.argmax(1)
            val_correct+=(preds==y).sum().item()
            total_val+=y.size(0)

    val_loss/=total_val
    val_acc=val_correct/total_val

    scheduler.step(epoch + val_loss)

    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    print(f"Train Loss {train_loss:.4f} | Acc {train_acc:.4f}")
    print(f"Val   Loss {val_loss:.4f} | Acc {val_acc:.4f}")

    if val_acc > best_val_acc + MIN_DELTA:
        best_val_acc = val_acc
        patience_counter = 0
        torch.save(model.state_dict(),"best_asl_model.pth")
        print("Saved best model")
    else:
        patience_counter+=1
        if patience_counter>=PATIENCE:
            print("Early stopping")
            break

# =========================
# TEST
# =========================
model.load_state_dict(torch.load("best_asl_model.pth"))
model.eval()

correct=0
total=0

with torch.no_grad():
    for x,y in test_loader:
        x,y=x.to(DEVICE),y.to(DEVICE)
        preds=model(x).argmax(1)
        correct+=(preds==y).sum().item()
        total+=y.size(0)

print("\nFINAL TEST ACC:",correct/total)

Total samples: 5568
Classes: 146
Input dim: 270

Epoch 1/350
Train Loss 5.0248 | Acc 0.0143
Val   Loss 4.9225 | Acc 0.0099
Saved best model

Epoch 2/350
Train Loss 4.7544 | Acc 0.0287
Val   Loss 4.6995 | Acc 0.0211
Saved best model

Epoch 3/350
Train Loss 4.5228 | Acc 0.0507
Val   Loss 4.2603 | Acc 0.0873
Saved best model

Epoch 4/350
Train Loss 4.3645 | Acc 0.0652
Val   Loss 4.2969 | Acc 0.0817

Epoch 5/350
Train Loss 4.1959 | Acc 0.0930
Val   Loss 3.8932 | Acc 0.1338
Saved best model

Epoch 6/350
Train Loss 4.0752 | Acc 0.1226
Val   Loss 3.8642 | Acc 0.1437
Saved best model

Epoch 7/350
Train Loss 4.0160 | Acc 0.1458
Val   Loss 3.6936 | Acc 0.2028
Saved best model

Epoch 8/350
Train Loss 3.9175 | Acc 0.1635
Val   Loss 3.5728 | Acc 0.2310
Saved best model

Epoch 9/350
Train Loss 3.8642 | Acc 0.1841
Val   Loss 3.6272 | Acc 0.2070

Epoch 10/350
Train Loss 3.7773 | Acc 0.2052
Val   Loss 3.4036 | Acc 0.2648
Saved best model

Epoch 11/350
Train Loss 3.7645 | Acc 0.1998
Val   Loss 3.4302 | 