In [2]:
# ============================================================
# FULL TCN TRAINING PIPELINE FOR ASL LANDMARKS
# ============================================================

import os
import numpy as np
from pathlib import Path
from collections import Counter
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# ============================================================
# CONFIG
# ============================================================

# Directory scanned for the per-clip landmark arrays (``*.npy``); the trained
# weights and the label-encoder classes are also written here.
DATA_DIR = Path(r"E:\ASL_Citizen\NEW\Top_Classes_Landmarks_Preprocessed")

if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Data shape. SEQ_LEN is presumably the number of frames per clip (it is not
# read by the code below); FEATURE_DIM is the TCN's input channel count.
SEQ_LEN = 157
FEATURE_DIM = 438

# Optimisation hyper-parameters.
BATCH_SIZE = 32
EPOCHS = 65
LR = 1e-3

# Output artefacts.
MODEL_SAVE_PATH = DATA_DIR.joinpath("tcn_asl_model.pth")
LABEL_ENCODER_PATH = DATA_DIR.joinpath("label_encoder.npy")

# ============================================================
# LOAD FILES & LABELS
# ============================================================

# One sample per ``.npy`` file in DATA_DIR; the label (gloss) is the file-stem
# prefix before the first "_". Skipped:
#   * ``*_mask.npy`` companion files;
#   * LABEL_ENCODER_PATH, which this script itself saves into DATA_DIR -- on a
#     re-run it would otherwise be collected as a sample labelled "label".
# The listing is sorted because glob order is filesystem-dependent, and the
# seeded split below is only reproducible for a fixed input order.
files = []
labels = []

for f in sorted(DATA_DIR.glob("*.npy")):
    if f.name.endswith("_mask.npy") or f == LABEL_ENCODER_PATH:
        continue

    label = f.stem.split("_")[0]  # word prefix
    files.append(str(f))
    labels.append(label)

print(f"Total samples found: {len(files)}")

# ============================================================
# FILTER CLASSES WITH <2 SAMPLES
# ============================================================

# stratify= in train_test_split needs at least 2 samples per class.
# NOTE(review): the val/test split below is stratified on the 20% hold-out
# as well, so a class left with a single hold-out sample makes it raise
# ValueError; raise MIN_SAMPLES_PER_CLASS if that happens with rarer glosses.
MIN_SAMPLES_PER_CLASS = 2

label_counts = Counter(labels)
filtered_indices = [
    i for i, y in enumerate(labels) if label_counts[y] >= MIN_SAMPLES_PER_CLASS
]

files = [files[i] for i in filtered_indices]
labels = [labels[i] for i in filtered_indices]

# Re-encode the surviving glosses as integer ids 0..num_classes-1.
le = LabelEncoder()
y_encoded = le.fit_transform(labels)
num_classes = len(le.classes_)

# Persist the id -> gloss table (le.classes_[i] is the gloss for id i).
np.save(LABEL_ENCODER_PATH, le.classes_)

print(f"Number of gloss classes after filtering: {num_classes}")

# ============================================================
# TRAIN / VAL / TEST SPLIT (80 / 10 / 10, stratified)
# ============================================================

# 80% train; the remaining 20% is split evenly into validation and test.
files_train, files_tmp, y_train, y_tmp = train_test_split(
    files, y_encoded, test_size=0.2, stratify=y_encoded, random_state=42
)

files_val, files_test, y_val, y_test = train_test_split(
    files_tmp, y_tmp, test_size=0.5, stratify=y_tmp, random_state=42
)

print(f"Train: {len(files_train)}, Val: {len(files_val)}, Test: {len(files_test)}")

# ============================================================
# DATASET
# ============================================================

class ASLDataset(Dataset):
    """Map-style dataset that loads one landmark sequence per ``.npy`` file.

    Args:
        files: Paths to ``.npy`` arrays, one sample per path.
        labels: Class ids aligned index-by-index with ``files``.

    Each item is ``(features, label)``. ``features`` is a float32 tensor with
    the stored array's two axes swapped, so a ``(T, D)`` array becomes
    ``(D, T)`` -- channels first, as ``nn.Conv1d`` expects.
    """

    def __init__(self, files, labels):
        self.files, self.labels = files, labels

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # (T, D) on disk -> (D, T) for the TCN.
        arr = np.load(self.files[idx]).astype(np.float32)
        features = torch.tensor(arr).transpose(0, 1)
        return features, self.labels[idx]

# One dataset per split; only the training loader reshuffles each epoch.
train_ds = ASLDataset(files_train, y_train)
val_ds = ASLDataset(files_val, y_val)
test_ds = ASLDataset(files_test, y_test)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader, test_loader = (
    DataLoader(split_ds, batch_size=BATCH_SIZE) for split_ds in (val_ds, test_ds)
)

# ============================================================
# TCN MODEL
# ============================================================

class TemporalBlock(nn.Module):
    """Residual block of two dilated causal 1-D convolutions.

    Each conv is followed by BatchNorm, ReLU and Dropout. The block returns
    ``ReLU(branch(x) + residual(x))``; the residual is ``x`` itself, or a 1x1
    conv when the channel count changes. Input and output are
    ``(batch, channels, time)``, and the time length is preserved.

    Args:
        in_ch: Input channels.
        out_ch: Output channels.
        kernel_size: Temporal kernel size of both convs.
        dilation: Dilation of both convs.
        dropout: Dropout probability after each conv stage.
    """

    def __init__(self, in_ch, out_ch, kernel_size, dilation, dropout):
        super().__init__()
        # Conv1d pads both ends by ``padding``. Trimming the extra right-hand
        # outputs after each conv (see forward) leaves left-only padding, so
        # output step t depends only on inputs at steps <= t.
        padding = (kernel_size - 1) * dilation

        # Layer order (and therefore state_dict keys net.0/1/4/5) is kept
        # unchanged so previously saved checkpoints still load.
        self.net = nn.Sequential(
            nn.Conv1d(in_ch, out_ch, kernel_size,
                      padding=padding, dilation=dilation),
            nn.BatchNorm1d(out_ch),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Conv1d(out_ch, out_ch, kernel_size,
                      padding=padding, dilation=dilation),
            nn.BatchNorm1d(out_ch),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        self.downsample = nn.Conv1d(in_ch, out_ch, 1) if in_ch != out_ch else None
        self.relu = nn.ReLU()

    def forward(self, x):
        seq_len = x.size(2)
        out = x
        for layer in self.net:
            out = layer(out)
            # Causal chomp after *each* conv, not only at the end. Otherwise
            # BatchNorm statistics would include the padded tail positions
            # (p, then 2p of them) that are discarded anyway, and the later
            # layers would waste compute on them.
            if isinstance(layer, nn.Conv1d):
                out = out[..., :seq_len]
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)


class TCN(nn.Module):
    """Stack of TemporalBlocks, global average pooling and a linear head.

    Input is ``(batch, input_dim, time)``; output is ``(batch, num_classes)``
    logits. Block ``i`` uses dilation ``2 ** i``.

    Args:
        input_dim: Channels of the input sequence.
        num_classes: Number of output logits.
        channels: Output channels of each TemporalBlock (one block per entry).
            Defaults to the original four blocks of width 256.
        kernel_size: Temporal kernel size used by every block.
        dropout: Dropout probability used by every block.

    Raises:
        ValueError: If ``channels`` is empty.
    """

    def __init__(self, input_dim, num_classes, *,
                 channels=(256, 256, 256, 256), kernel_size=3, dropout=0.3):
        super().__init__()
        channels = list(channels)
        if not channels:
            raise ValueError("channels must contain at least one block width")

        # Block i maps in_chs[i] -> channels[i]; the first block consumes the input.
        in_chs = [input_dim] + channels[:-1]
        self.tcn = nn.Sequential(*(
            TemporalBlock(
                in_ch,
                out_ch,
                kernel_size=kernel_size,
                dilation=2 ** i,
                dropout=dropout,
            )
            for i, (in_ch, out_ch) in enumerate(zip(in_chs, channels))
        ))
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(channels[-1], num_classes)

    def forward(self, x):
        x = self.tcn(x)
        # (batch, C, time) -> (batch, C, 1) -> (batch, C)
        x = self.pool(x).squeeze(-1)
        return self.fc(x)

model = TCN(FEATURE_DIM, num_classes).to(DEVICE)

# ============================================================
# LOSS (CLASS IMBALANCE HANDLED)
# ============================================================

# sklearn's "balanced" weights (n_samples / (n_classes * count_c)), computed on
# the training split only. Weight i is applied to class id i, which lines up
# because np.unique returns the ids sorted -- assuming every id occurs in y_train.
train_classes = np.unique(y_train)
class_weights = torch.tensor(
    compute_class_weight(class_weight="balanced", classes=train_classes, y=y_train),
    dtype=torch.float32,
).to(DEVICE)

criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

# ============================================================
# TRAINING LOOP
# ============================================================

def run_epoch(loader, train=True, *, net=None, loss_fn=None, opt=None, device=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    With ``train=True`` the model is put in training mode and one optimizer
    step is taken per batch. Otherwise it runs in eval mode with gradients
    disabled.

    Args:
        loader: Iterable of ``(inputs, integer_targets)`` batches.
        train: Whether to update the model.
        net, loss_fn, opt, device: Optional overrides. They default to the
            module-level ``model``, ``criterion``, ``optimizer`` and ``DEVICE``.
            ``opt`` is only needed when ``train`` is true.

    Returns:
        ``(mean_loss, accuracy)``: the batch losses averaged with weights equal
        to the batch sizes, and the fraction of correct argmax predictions.

    Raises:
        ValueError: If the loader yields no samples (the original divided by zero).
    """
    net = model if net is None else net
    loss_fn = criterion if loss_fn is None else loss_fn
    device = DEVICE if device is None else device
    if train:
        opt = optimizer if opt is None else opt

    net.train(bool(train))  # train(False) is exactly eval()
    total_loss, correct, total = 0.0, 0, 0

    with torch.set_grad_enabled(bool(train)):
        for x, y in loader:
            x, y = x.to(device), y.to(device)

            logits = net(x)
            loss = loss_fn(logits, y)

            if train:
                opt.zero_grad()
                loss.backward()
                opt.step()

            batch_size = y.size(0)
            total_loss += loss.item() * batch_size
            correct += (logits.argmax(dim=1) == y).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("run_epoch: loader yielded no samples")
    return total_loss / total, correct / total


# Start below any achievable accuracy so the first epoch always writes a
# checkpoint. With 0.0, a run whose val accuracy never exceeded 0 would save
# nothing, and the TEST step would then silently load a stale checkpoint from
# an earlier run (or fail if none exists).
best_val_acc = float("-inf")

for epoch in range(EPOCHS):
    train_loss, train_acc = run_epoch(train_loader, train=True)
    val_loss, val_acc = run_epoch(val_loader, train=False)

    # Keep only the weights with the best validation accuracy so far
    # (strict '>' means ties keep the earlier epoch).
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), MODEL_SAVE_PATH)

    print(
        f"Epoch [{epoch+1}/{EPOCHS}] | "
        f"Train Acc: {train_acc:.4f} | "
        f"Val Acc: {val_acc:.4f}"
    )

# ============================================================
# TEST
# ============================================================

# Evaluate the best-validation checkpoint, not the last-epoch weights.
# map_location lets the checkpoint load when it was written on a different
# device (e.g. saved from CUDA, re-loaded on a CPU-only machine).
model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
test_loss, test_acc = run_epoch(test_loader, train=False)

print(f"\n✅ Test Accuracy: {test_acc:.4f}")
print(f"💾 Model saved to: {MODEL_SAVE_PATH}")


Total samples found: 5846
Number of gloss classes after filtering: 146
Train: 4675, Val: 584, Test: 585
Epoch [1/65] | Train Acc: 0.0261 | Val Acc: 0.0462
Epoch [2/65] | Train Acc: 0.0710 | Val Acc: 0.0599
Epoch [3/65] | Train Acc: 0.1292 | Val Acc: 0.1627
Epoch [4/65] | Train Acc: 0.2006 | Val Acc: 0.2072
Epoch [5/65] | Train Acc: 0.2676 | Val Acc: 0.2654
Epoch [6/65] | Train Acc: 0.3341 | Val Acc: 0.3048
Epoch [7/65] | Train Acc: 0.4186 | Val Acc: 0.3733
Epoch [8/65] | Train Acc: 0.4824 | Val Acc: 0.4110
Epoch [9/65] | Train Acc: 0.5393 | Val Acc: 0.4469
Epoch [10/65] | Train Acc: 0.5728 | Val Acc: 0.4812
Epoch [11/65] | Train Acc: 0.6141 | Val Acc: 0.5034
Epoch [12/65] | Train Acc: 0.6468 | Val Acc: 0.5445
Epoch [13/65] | Train Acc: 0.6642 | Val Acc: 0.5788
Epoch [14/65] | Train Acc: 0.6490 | Val Acc: 0.5257
Epoch [15/65] | Train Acc: 0.7200 | Val Acc: 0.6027
Epoch [16/65] | Train Acc: 0.7298 | Val Acc: 0.5839
Epoch [17/65] | Train Acc: 0.7656 | Val Acc: 0.5942
Epoch [18/65] | Train