In [4]:
# ============================================================
# CONTRASTIVE PRETRAIN + SUPERVISED FINETUNE (CPU SAFE)
# ============================================================

import numpy as np
from pathlib import Path
from collections import Counter
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# ============================================================
# CONFIG
# ============================================================

DATA_DIR = Path(r"E:\ASL_Citizen\NEW\Top_Classes_Landmarks_Preprocessed")
DEVICE = "cpu"

FEATURE_DIM = 438
BATCH_SIZE = 8
PRETRAIN_EPOCHS = 40
FINETUNE_EPOCHS = 50
LR = 3e-4
WEIGHT_DECAY = 1e-4
PATIENCE = 7

PRETRAIN_MODEL_PATH = DATA_DIR / "tcn_pretrained.pth"
FINETUNE_MODEL_PATH = DATA_DIR / "tcn_finetuned.pth"

# ============================================================
# LOAD FILES
# ============================================================

files, masks, labels = [], [], []

for f in DATA_DIR.glob("*.npy"):
    if f.name.endswith("_mask.npy"):
        continue
    files.append(str(f))
    masks.append(str(f).replace(".npy", "_mask.npy"))
    labels.append(f.stem.split("_")[0])

# Keep classes with at least 2 samples
counts = Counter(labels)
idx = [i for i, y in enumerate(labels) if counts[y] >= 2]
files  = [files[i] for i in idx]
masks  = [masks[i] for i in idx]
labels = [labels[i] for i in idx]

le = LabelEncoder()
y = le.fit_transform(labels)
num_classes = len(le.classes_)

# ============================================================
# SPLIT DATA
# ============================================================

X_tr, X_tmp, y_tr, y_tmp, m_tr, m_tmp = train_test_split(
    files, y, masks, test_size=0.2, stratify=y, random_state=42
)
X_val, X_te, y_val, y_te, m_val, m_te = train_test_split(
    X_tmp, y_tmp, m_tmp, test_size=0.5, stratify=y_tmp, random_state=42
)

# ============================================================
# AUGMENTATIONS
# ============================================================

class TemporalDropout:
    def __init__(self, p=0.6, drop_ratio=0.15):
        self.p = p
        self.drop_ratio = drop_ratio

    def __call__(self, x, mask):
        if np.random.rand() > self.p:
            return x, mask
        T = mask.shape[0]
        keep = (np.random.rand(T) > self.drop_ratio).astype(np.float32)
        return x, mask * keep

def temporal_jitter(x):
    shift = np.random.randint(-3, 4)
    return np.roll(x, shift, axis=0)

# ============================================================
# DATASETS
# ============================================================

class ContrastiveDataset(Dataset):
    def __init__(self, files, masks):
        self.files = files
        self.masks = masks
        self.aug = TemporalDropout()

    def __len__(self):
        return len(self.files)

    def view(self, x, mask):
        x, mask = self.aug(x, mask)
        x = temporal_jitter(x)
        return torch.from_numpy(x).transpose(0, 1), torch.from_numpy(mask)

    def __getitem__(self, idx):
        x = np.load(self.files[idx]).astype(np.float32)
        mask = np.load(self.masks[idx]).astype(np.float32)
        x1, m1 = self.view(x.copy(), mask.copy())
        x2, m2 = self.view(x.copy(), mask.copy())
        return x1, m1, x2, m2

class ASLDataset(Dataset):
    def __init__(self, files, masks, labels):
        self.files = files
        self.masks = masks
        self.labels = labels

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        x = np.load(self.files[idx]).astype(np.float32)
        mask = np.load(self.masks[idx]).astype(np.float32)
        return (
            torch.from_numpy(x).transpose(0, 1),
            torch.from_numpy(mask),
            self.labels[idx]
        )

# ============================================================
# MODEL
# ============================================================

class TemporalBlock(nn.Module):
    def __init__(self, in_c, out_c, dilation):
        super().__init__()
        pad = dilation
        self.net = nn.Sequential(
            nn.Conv1d(in_c, out_c, 3, padding=pad, dilation=dilation),
            nn.BatchNorm1d(out_c),
            nn.ReLU(),
            nn.Conv1d(out_c, out_c, 3, padding=pad, dilation=dilation),
            nn.BatchNorm1d(out_c),
            nn.ReLU()
        )
        self.res = nn.Conv1d(in_c, out_c, 1) if in_c != out_c else nn.Identity()

    def forward(self, x):
        y = self.net(x)
        return y[..., :x.size(2)] + self.res(x)

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        chans = [128, 128, 128]
        blocks = []
        for i, c in enumerate(chans):
            blocks.append(
                TemporalBlock(
                    FEATURE_DIM if i == 0 else chans[i-1],
                    c,
                    dilation=2**i
                )
            )
        self.tcn = nn.Sequential(*blocks)

    def forward(self, x, mask):
        x = self.tcn(x)
        mask = mask.unsqueeze(1)
        return (x * mask).sum(dim=2) / (mask.sum(dim=2) + 1e-6)

class ProjectionHead(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, 128)
        )

    def forward(self, x):
        return self.net(x)

class Classifier(nn.Module):
    def __init__(self, encoder, num_classes):
        super().__init__()
        self.encoder = encoder
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x, mask):
        z = self.encoder(x, mask)
        return self.fc(z)

# ============================================================
# CONTRASTIVE LOSS
# ============================================================

def nt_xent(z1, z2, temp=0.2):
    z1 = nn.functional.normalize(z1, dim=1)
    z2 = nn.functional.normalize(z2, dim=1)
    logits = z1 @ z2.T / temp
    labels = torch.arange(z1.size(0), device=z1.device)
    return nn.CrossEntropyLoss()(logits, labels)

# ============================================================
# CONTRASTIVE PRETRAIN
# ============================================================

encoder = Encoder().to(DEVICE)
projector = ProjectionHead(128).to(DEVICE)

contrastive_loader = DataLoader(
    ContrastiveDataset(X_tr, m_tr),
    batch_size=BATCH_SIZE, shuffle=True, num_workers=0
)

opt = torch.optim.AdamW(
    list(encoder.parameters()) + list(projector.parameters()),
    lr=LR, weight_decay=WEIGHT_DECAY
)

print("\nðŸ”µ Starting contrastive pretraining")
for epoch in range(PRETRAIN_EPOCHS):
    encoder.train()
    epoch_loss = 0
    for x1, m1, x2, m2 in tqdm(contrastive_loader, leave=False):
        z1 = projector(encoder(x1, m1))
        z2 = projector(encoder(x2, m2))
        loss = nt_xent(z1, z2)
        opt.zero_grad()
        loss.backward()
        opt.step()
        epoch_loss += loss.item()
    print(f"Pretrain Epoch {epoch+1:02d} | Loss: {epoch_loss:.4f}")

torch.save(encoder.state_dict(), PRETRAIN_MODEL_PATH)

# ============================================================
# SUPERVISED FINETUNE
# ============================================================

classifier = Classifier(encoder, num_classes).to(DEVICE)
classifier.encoder.load_state_dict(torch.load(PRETRAIN_MODEL_PATH))

# Class weights for imbalanced data
weights = compute_class_weight("balanced", classes=np.unique(y_tr), y=y_tr)
weights = torch.tensor(weights, dtype=torch.float32, device=DEVICE)
criterion = nn.CrossEntropyLoss(weight=weights)

optimizer = torch.optim.AdamW(classifier.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

train_loader = DataLoader(ASLDataset(X_tr, m_tr, y_tr), batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(ASLDataset(X_val, m_val, y_val), batch_size=BATCH_SIZE)
test_loader  = DataLoader(ASLDataset(X_te, m_te, y_te), batch_size=BATCH_SIZE)

best_val = 0
patience = 0

print("\nðŸŸ¢ Starting supervised fine-tuning")
for epoch in range(FINETUNE_EPOCHS):
    # -------- TRAIN --------
    classifier.train()
    train_loss_sum, train_correct, train_total = 0, 0, 0
    for x, m, yb in tqdm(train_loader, leave=False):
        logits = classifier(x, m)
        loss = criterion(logits, yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss_sum += loss.item() * yb.size(0)
        train_correct += (logits.argmax(1) == yb).sum().item()
        train_total += yb.size(0)

    train_loss = train_loss_sum / train_total
    train_acc = train_correct / train_total

    # -------- VALIDATION --------
    classifier.eval()
    val_loss_sum, val_correct, val_total = 0, 0, 0
    with torch.no_grad():
        for x, m, yb in val_loader:
            logits = classifier(x, m)
            loss = criterion(logits, yb)
            val_loss_sum += loss.item() * yb.size(0)
            val_correct += (logits.argmax(1) == yb).sum().item()
            val_total += yb.size(0)

    val_loss = val_loss_sum / val_total
    val_acc = val_correct / val_total

    # -------- PRINT METRICS --------
    print(f"Epoch {epoch+1:02d} | "
          f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
          f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

    # -------- EARLY STOPPING --------
    if val_acc > best_val:
        best_val = val_acc
        patience = 0
        torch.save(classifier.state_dict(), FINETUNE_MODEL_PATH)
    else:
        patience += 1
        if patience >= PATIENCE:
            print("ðŸ›‘ Early stopping")
            break

# ============================================================
# TEST EVALUATION
# ============================================================

classifier.load_state_dict(torch.load(FINETUNE_MODEL_PATH))
classifier.eval()

test_loss_sum, test_correct, test_total = 0, 0, 0
with torch.no_grad():
    for x, m, yb in test_loader:
        logits = classifier(x, m)
        loss = criterion(logits, yb)
        test_loss_sum += loss.item() * yb.size(0)
        test_correct += (logits.argmax(1) == yb).sum().item()
        test_total += yb.size(0)

test_loss = test_loss_sum / test_total
test_acc = test_correct / test_total

print(f"\nâœ… Best Val Accuracy: {best_val:.4f}")
print(f"ðŸ’¾ Saved model: {FINETUNE_MODEL_PATH}")
print(f"ðŸŽ¯ Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f}")



ðŸ”µ Starting contrastive pretraining


                                                                                                                       

Pretrain Epoch 01 | Loss: 48.0354


                                                                                                                       

Pretrain Epoch 02 | Loss: 32.2525


                                                                                                                       

Pretrain Epoch 03 | Loss: 31.0726


                                                                                                                       

Pretrain Epoch 04 | Loss: 29.3889


                                                                                                                       

Pretrain Epoch 05 | Loss: 28.4041


                                                                                                                       

Pretrain Epoch 06 | Loss: 28.2175


                                                                                                                       

Pretrain Epoch 07 | Loss: 26.9141


                                                                                                                       

Pretrain Epoch 08 | Loss: 25.9298


                                                                                                                       

Pretrain Epoch 09 | Loss: 26.5101


                                                                                                                       

Pretrain Epoch 10 | Loss: 25.7654


                                                                                                                       

Pretrain Epoch 11 | Loss: 25.7158


                                                                                                                       

Pretrain Epoch 12 | Loss: 24.9983


                                                                                                                       

Pretrain Epoch 13 | Loss: 24.1145


                                                                                                                       

Pretrain Epoch 14 | Loss: 23.2176


                                                                                                                       

Pretrain Epoch 15 | Loss: 25.1415


                                                                                                                       

Pretrain Epoch 16 | Loss: 21.8460


                                                                                                                       

Pretrain Epoch 17 | Loss: 23.5932


                                                                                                                       

Pretrain Epoch 18 | Loss: 22.6764


                                                                                                                       

Pretrain Epoch 19 | Loss: 21.6540


                                                                                                                       

Pretrain Epoch 20 | Loss: 21.2355


                                                                                                                       

Pretrain Epoch 21 | Loss: 21.2883


                                                                                                                       

Pretrain Epoch 22 | Loss: 21.1720


                                                                                                                       

Pretrain Epoch 23 | Loss: 21.6306


                                                                                                                       

Pretrain Epoch 24 | Loss: 21.2396


                                                                                                                       

Pretrain Epoch 25 | Loss: 21.6471


                                                                                                                       

Pretrain Epoch 26 | Loss: 21.0988


                                                                                                                       

Pretrain Epoch 27 | Loss: 19.7905


                                                                                                                       

Pretrain Epoch 28 | Loss: 20.0394


                                                                                                                       

Pretrain Epoch 29 | Loss: 19.1268


                                                                                                                       

Pretrain Epoch 30 | Loss: 19.6129


                                                                                                                       

Pretrain Epoch 31 | Loss: 19.6254


                                                                                                                       

Pretrain Epoch 32 | Loss: 19.6545


                                                                                                                       

Pretrain Epoch 33 | Loss: 19.0673


                                                                                                                       

Pretrain Epoch 34 | Loss: 19.2461


                                                                                                                       

Pretrain Epoch 35 | Loss: 19.7468


                                                                                                                       

Pretrain Epoch 36 | Loss: 18.9170


                                                                                                                       

Pretrain Epoch 37 | Loss: 18.7410


                                                                                                                       

Pretrain Epoch 38 | Loss: 19.0343


                                                                                                                       

Pretrain Epoch 39 | Loss: 18.5183


                                                                                                                       

Pretrain Epoch 40 | Loss: 18.5830

ðŸŸ¢ Starting supervised fine-tuning


                                                                                                                       

Epoch 01 | Train Loss: 4.8338 | Train Acc: 0.0252 | Val Loss: 4.3488 | Val Acc: 0.0634


                                                                                                                       

Epoch 02 | Train Loss: 4.0608 | Train Acc: 0.0742 | Val Loss: 3.8343 | Val Acc: 0.1164


                                                                                                                       

Epoch 03 | Train Loss: 3.5750 | Train Acc: 0.1393 | Val Loss: 3.2314 | Val Acc: 0.1781


                                                                                                                       

Epoch 04 | Train Loss: 3.1271 | Train Acc: 0.2141 | Val Loss: 2.9306 | Val Acc: 0.2483


                                                                                                                       

Epoch 05 | Train Loss: 2.7976 | Train Acc: 0.2774 | Val Loss: 2.5526 | Val Acc: 0.3185


                                                                                                                       

Epoch 06 | Train Loss: 2.4958 | Train Acc: 0.3313 | Val Loss: 2.3402 | Val Acc: 0.3545


                                                                                                                       

Epoch 07 | Train Loss: 2.2432 | Train Acc: 0.3912 | Val Loss: 2.1998 | Val Acc: 0.4007


                                                                                                                       

Epoch 08 | Train Loss: 2.0775 | Train Acc: 0.4357 | Val Loss: 2.0590 | Val Acc: 0.4418


                                                                                                                       

Epoch 09 | Train Loss: 1.8992 | Train Acc: 0.4770 | Val Loss: 1.8714 | Val Acc: 0.4572


                                                                                                                       

Epoch 10 | Train Loss: 1.7313 | Train Acc: 0.5153 | Val Loss: 1.7821 | Val Acc: 0.5017


                                                                                                                       

Epoch 11 | Train Loss: 1.5635 | Train Acc: 0.5576 | Val Loss: 1.6619 | Val Acc: 0.5445


                                                                                                                       

Epoch 12 | Train Loss: 1.4432 | Train Acc: 0.5940 | Val Loss: 1.5317 | Val Acc: 0.5325


                                                                                                                       

Epoch 13 | Train Loss: 1.3366 | Train Acc: 0.6244 | Val Loss: 1.5046 | Val Acc: 0.5702


                                                                                                                       

Epoch 14 | Train Loss: 1.2396 | Train Acc: 0.6507 | Val Loss: 1.4651 | Val Acc: 0.5753


                                                                                                                       

Epoch 15 | Train Loss: 1.1266 | Train Acc: 0.6721 | Val Loss: 1.3633 | Val Acc: 0.6079


                                                                                                                       

Epoch 16 | Train Loss: 1.0652 | Train Acc: 0.6960 | Val Loss: 1.3149 | Val Acc: 0.6062


                                                                                                                       

Epoch 17 | Train Loss: 0.9724 | Train Acc: 0.7200 | Val Loss: 1.2271 | Val Acc: 0.6610


                                                                                                                       

Epoch 18 | Train Loss: 0.9103 | Train Acc: 0.7337 | Val Loss: 1.1816 | Val Acc: 0.6353


                                                                                                                       

Epoch 19 | Train Loss: 0.8545 | Train Acc: 0.7555 | Val Loss: 1.1663 | Val Acc: 0.6661


                                                                                                                       

Epoch 20 | Train Loss: 0.7935 | Train Acc: 0.7668 | Val Loss: 1.0901 | Val Acc: 0.6832


                                                                                                                       

Epoch 21 | Train Loss: 0.7342 | Train Acc: 0.7878 | Val Loss: 1.0860 | Val Acc: 0.6832


                                                                                                                       

Epoch 22 | Train Loss: 0.6891 | Train Acc: 0.7891 | Val Loss: 1.1133 | Val Acc: 0.6729


                                                                                                                       

Epoch 23 | Train Loss: 0.6291 | Train Acc: 0.8126 | Val Loss: 1.0916 | Val Acc: 0.6901


                                                                                                                       

Epoch 24 | Train Loss: 0.6024 | Train Acc: 0.8193 | Val Loss: 1.0158 | Val Acc: 0.7106


                                                                                                                       

Epoch 25 | Train Loss: 0.5655 | Train Acc: 0.8312 | Val Loss: 1.0250 | Val Acc: 0.7106


                                                                                                                       

Epoch 26 | Train Loss: 0.5465 | Train Acc: 0.8432 | Val Loss: 1.0047 | Val Acc: 0.7055


                                                                                                                       

Epoch 27 | Train Loss: 0.5051 | Train Acc: 0.8507 | Val Loss: 0.9347 | Val Acc: 0.7243


                                                                                                                       

Epoch 28 | Train Loss: 0.4788 | Train Acc: 0.8548 | Val Loss: 0.9629 | Val Acc: 0.7346


                                                                                                                       

Epoch 29 | Train Loss: 0.4212 | Train Acc: 0.8768 | Val Loss: 0.8930 | Val Acc: 0.7414


                                                                                                                       

Epoch 30 | Train Loss: 0.4248 | Train Acc: 0.8710 | Val Loss: 1.0000 | Val Acc: 0.7226


                                                                                                                       

Epoch 31 | Train Loss: 0.4229 | Train Acc: 0.8695 | Val Loss: 1.0027 | Val Acc: 0.7106


                                                                                                                       

Epoch 32 | Train Loss: 0.3601 | Train Acc: 0.8888 | Val Loss: 0.9078 | Val Acc: 0.7603


                                                                                                                       

Epoch 33 | Train Loss: 0.3599 | Train Acc: 0.8883 | Val Loss: 0.9466 | Val Acc: 0.7329


                                                                                                                       

Epoch 34 | Train Loss: 0.3488 | Train Acc: 0.8881 | Val Loss: 0.8361 | Val Acc: 0.7500


                                                                                                                       

Epoch 35 | Train Loss: 0.3114 | Train Acc: 0.9093 | Val Loss: 1.0143 | Val Acc: 0.7021


                                                                                                                       

Epoch 36 | Train Loss: 0.3004 | Train Acc: 0.9074 | Val Loss: 0.9037 | Val Acc: 0.7329


                                                                                                                       

Epoch 37 | Train Loss: 0.2796 | Train Acc: 0.9166 | Val Loss: 0.8980 | Val Acc: 0.7637


                                                                                                                       

Epoch 38 | Train Loss: 0.2756 | Train Acc: 0.9155 | Val Loss: 0.8588 | Val Acc: 0.7688


                                                                                                                       

Epoch 39 | Train Loss: 0.2814 | Train Acc: 0.9108 | Val Loss: 0.9179 | Val Acc: 0.7483


                                                                                                                       

Epoch 40 | Train Loss: 0.2341 | Train Acc: 0.9292 | Val Loss: 0.8286 | Val Acc: 0.7637


                                                                                                                       

Epoch 41 | Train Loss: 0.2279 | Train Acc: 0.9318 | Val Loss: 0.8435 | Val Acc: 0.7671


                                                                                                                       

Epoch 42 | Train Loss: 0.2320 | Train Acc: 0.9251 | Val Loss: 0.8530 | Val Acc: 0.7568


                                                                                                                       

Epoch 43 | Train Loss: 0.2131 | Train Acc: 0.9363 | Val Loss: 0.9700 | Val Acc: 0.7432


                                                                                                                       

Epoch 44 | Train Loss: 0.2149 | Train Acc: 0.9294 | Val Loss: 0.9204 | Val Acc: 0.7637


                                                                                                                       

Epoch 45 | Train Loss: 0.2170 | Train Acc: 0.9324 | Val Loss: 0.8241 | Val Acc: 0.7688
ðŸ›‘ Early stopping

âœ… Best Val Accuracy: 0.7688
ðŸ’¾ Saved model: E:\ASL_Citizen\NEW\Top_Classes_Landmarks_Preprocessed\tcn_finetuned.pth
ðŸŽ¯ Test Loss: 0.8055 | Test Acc: 0.7880
