In [1]:
# ============================================================
# FULL TCN + ATTENTION TRAINING PIPELINE FOR ASL LANDMARKS
# ============================================================

import os
import numpy as np
from pathlib import Path
from collections import Counter
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# ============================================================
# CONFIG
# ============================================================

DATA_DIR = Path(r"E:\ASL_Citizen\NEW\Top_Classes_Landmarks_Preprocessed")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

SEQ_LEN = 157
POSE_SIZE = 132
FACE_SIZE = 180
HAND_SIZE = 63
FEATURE_DIM = POSE_SIZE + FACE_SIZE + 2*HAND_SIZE   # raw features per frame
BATCH_SIZE = 16
EPOCHS = 60
LR = 1e-3
DROPOUT = 0.3

MODEL_SAVE_PATH = DATA_DIR / "tcn_attention_asl_model.pth"
LABEL_ENCODER_PATH = DATA_DIR / "label_encoder.npy"

# ============================================================
# UTILS
# ============================================================

def compute_velocity(x):
    return torch.cat([x[:, :1], x[:, 1:] - x[:, :-1]], dim=1)

def compute_acceleration(v):
    return torch.cat([v[:, :1], v[:, 1:] - v[:, :-1]], dim=1)

# ============================================================
# LOAD FILES & LABELS
# ============================================================

files, labels = [], []

for f in DATA_DIR.glob("*.npy"):
    if f.name.endswith("_mask.npy"):
        continue
    label = f.stem.split("_")[0]
    files.append(str(f))
    labels.append(label)

# Filter classes with <2 samples
label_counts = Counter(labels)
filtered_indices = [i for i, y in enumerate(labels) if label_counts[y] >= 2]
files = [files[i] for i in filtered_indices]
labels = [labels[i] for i in filtered_indices]

# Encode labels
le = LabelEncoder()
y_encoded = le.fit_transform(labels)
num_classes = len(le.classes_)
np.save(LABEL_ENCODER_PATH, le.classes_)

print(f"Samples: {len(files)}, Classes: {num_classes}")

# ============================================================
# TRAIN / VAL / TEST SPLIT
# ============================================================

files_train, files_tmp, y_train, y_tmp = train_test_split(
    files, y_encoded, test_size=0.2, stratify=y_encoded, random_state=42
)
files_val, files_test, y_val, y_test = train_test_split(
    files_tmp, y_tmp, test_size=0.5, stratify=y_tmp, random_state=42
)

print(f"Train: {len(files_train)}, Val: {len(files_val)}, Test: {len(files_test)}")

# ============================================================
# DATASET
# ============================================================

class ASLDataset(Dataset):
    def __init__(self, files, labels):
        self.files = files
        self.labels = labels

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        x = np.load(self.files[idx]).astype(np.float32)
        y = self.labels[idx]

        # Add velocity and acceleration features
        x_t = torch.tensor(x)  # (T, D)
        v = compute_velocity(x_t)
        a = compute_acceleration(v)
        x_full = torch.cat([x_t, v, a], dim=1)  # (T, 3*D)

        # Transpose for Conv1d: (D, T)
        x_full = x_full.transpose(0, 1)
        return x_full, y

train_ds = ASLDataset(files_train, y_train)
val_ds   = ASLDataset(files_val, y_val)
test_ds  = ASLDataset(files_test, y_test)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE)
test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE)

# ============================================================
# MODEL COMPONENTS
# ============================================================

class TemporalBlock(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size=3, dilation=1, dropout=DROPOUT):
        super().__init__()
        padding = (kernel_size - 1) * dilation

        self.conv1 = nn.Conv1d(in_ch, out_ch, kernel_size, padding=padding, dilation=dilation)
        self.bn1 = nn.BatchNorm1d(out_ch)
        self.drop1 = nn.Dropout(dropout)
        self.relu = nn.ReLU()

        self.conv2 = nn.Conv1d(out_ch, out_ch, kernel_size, padding=padding, dilation=dilation)
        self.bn2 = nn.BatchNorm1d(out_ch)
        self.drop2 = nn.Dropout(dropout)

        self.downsample = nn.Conv1d(in_ch, out_ch, 1) if in_ch != out_ch else None

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.drop1(out)
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.drop2(out)

        # ðŸŸ¢ Crop out to match x length for residual
        if out.size(2) > x.size(2):
            out = out[:, :, :x.size(2)]
        elif out.size(2) < x.size(2):
            # Optional: pad if shorter
            pad_len = x.size(2) - out.size(2)
            out = F.pad(out, (0, pad_len))

        if self.downsample:
            x = self.downsample(x)
            # Downsample might also change length slightly
            if x.size(2) != out.size(2):
                x = x[:, :, :out.size(2)]

        return self.relu(out + x)


class AttentionPooling(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.score = nn.Linear(dim, 1)

    def forward(self, x):
        # x: (batch, channels, seq_len) -> transpose for linear
        x_t = x.transpose(1, 2)  # (batch, seq_len, channels)
        attn = self.score(x_t).squeeze(-1)  # (batch, seq_len)
        w = F.softmax(attn, dim=1)
        out = torch.sum(x_t * w.unsqueeze(-1), dim=1)  # (batch, channels)
        return out

class TCNAttentionModel(nn.Module):
    def __init__(self, input_dim, num_classes, channels=[256, 256, 256, 256]):
        super().__init__()
        layers = []
        for i, ch in enumerate(channels):
            layers.append(
                TemporalBlock(
                    input_dim if i==0 else channels[i-1],
                    ch,
                    kernel_size=3,
                    dilation=2**i,
                    dropout=DROPOUT
                )
            )
        self.tcn = nn.Sequential(*layers)
        self.attn = AttentionPooling(channels[-1])
        self.fc = nn.Linear(channels[-1], num_classes)

    def forward(self, x):
        x = self.tcn(x)
        x = self.attn(x)
        return self.fc(x)

model = TCNAttentionModel(FEATURE_DIM*3, num_classes).to(DEVICE)

# ============================================================
# LOSS AND OPTIMIZER
# ============================================================

class_weights = compute_class_weight(
    class_weight="balanced",
    classes=np.unique(y_train),
    y=y_train
)
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(DEVICE)

criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

# ============================================================
# TRAIN / VALIDATION FUNCTION
# ============================================================

def run_epoch(loader, train=True):
    model.train() if train else model.eval()
    total_loss, correct, total = 0, 0, 0

    with torch.set_grad_enabled(train):
        for x, y in loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            logits = model(x)
            loss = criterion(logits, y)

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            total_loss += loss.item() * x.size(0)
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)

    return total_loss/total, correct/total

# ============================================================
# TRAINING LOOP
# ============================================================

best_val_acc = 0.0

for epoch in range(EPOCHS):
    train_loss, train_acc = run_epoch(train_loader, train=True)
    val_loss, val_acc = run_epoch(val_loader, train=False)

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), MODEL_SAVE_PATH)

    print(f"Epoch [{epoch+1}/{EPOCHS}] | "
          f"Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")

# ============================================================
# TEST
# ============================================================

model.load_state_dict(torch.load(MODEL_SAVE_PATH))
test_loss, test_acc = run_epoch(test_loader, train=False)
print(f"\nâœ… Test Accuracy: {test_acc:.4f}")
print(f"ðŸ’¾ Model saved to: {MODEL_SAVE_PATH}")


Samples: 5844, Classes: 146
Train: 4675, Val: 584, Test: 585
Epoch [1/60] | Train Acc: 0.0145 | Val Acc: 0.0308
Epoch [2/60] | Train Acc: 0.0603 | Val Acc: 0.1130
Epoch [3/60] | Train Acc: 0.1159 | Val Acc: 0.1318
Epoch [4/60] | Train Acc: 0.1855 | Val Acc: 0.1986
Epoch [5/60] | Train Acc: 0.2588 | Val Acc: 0.2620
Epoch [6/60] | Train Acc: 0.3320 | Val Acc: 0.3031
Epoch [7/60] | Train Acc: 0.4047 | Val Acc: 0.3390
Epoch [8/60] | Train Acc: 0.4849 | Val Acc: 0.4709
Epoch [9/60] | Train Acc: 0.5452 | Val Acc: 0.4349
Epoch [10/60] | Train Acc: 0.5968 | Val Acc: 0.5120
Epoch [11/60] | Train Acc: 0.6145 | Val Acc: 0.4983
Epoch [12/60] | Train Acc: 0.6558 | Val Acc: 0.5497
Epoch [13/60] | Train Acc: 0.6993 | Val Acc: 0.5497
Epoch [14/60] | Train Acc: 0.7356 | Val Acc: 0.5942
Epoch [15/60] | Train Acc: 0.7512 | Val Acc: 0.5411
Epoch [16/60] | Train Acc: 0.7679 | Val Acc: 0.5805
Epoch [17/60] | Train Acc: 0.7874 | Val Acc: 0.5788
Epoch [18/60] | Train Acc: 0.7987 | Val Acc: 0.6404
Epoch [19/60