In [1]:
#  FINAL NO-CACHING VERSION for GRID
import os, random, numpy as np, torch
import torch.nn as nn, torch.nn.functional as F, torch.optim as optim
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc, precision_recall_fscore_support
from tqdm import tqdm
import pandas as pd
import learn2learn as l2l
from datetime import datetime

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Root folder of the pre-computed optical-flow clips.
optical_flow_path = r"E:\\visual_speaker_auth\\data\\gridcorpus\\optical_flow"

# Speaker ids s1..s30 with s20 left out.
all_speakers = [f"s{n}" for n in range(1, 31)]
all_speakers.remove("s20")

# Speakers s1..s6 are held out for meta-testing.
test_speakers = [f"s{n}" for n in range(1, 7)]

# Everything else is shuffled, then split: first three for validation, the
# rest for meta-training.
# NOTE(review): the shuffle is not seeded, so the split differs on every run,
# and val_speakers is not referenced by any code visible in this file.
remaining = list(filter(lambda spk: spk not in test_speakers, all_speakers))
random.shuffle(remaining)
val_speakers, train_speakers = remaining[:3], remaining[3:]


# Hyper-parameters (Optuna-tuned).
dropout = 0.37656865990260097   # dropout probability of the classifier head
inner_lr = 0.02042052122941683  # learning rate handed to l2l.algorithms.MAML
meta_lr = 0.0006003264928298247  # learning rate of the Adam meta-optimizer
shots = 3                        # clips per class in each support/query set

# ======== MODEL ========
class OpticalFlowModel(nn.Module):
    """3D-CNN that maps a two-channel clip to two class logits.

    ``features`` stacks three Conv3d -> BatchNorm3d -> ReLU -> MaxPool3d
    stages (32, 64 and 128 output channels; pooling only over the last two
    axes) followed by an adaptive average pool to ``(1, 8, 8)``.
    ``classifier`` is a two-layer MLP with dropout that emits 2 logits.

    Args:
        dropout: Dropout probability used inside the classifier head.
    """

    def __init__(self, dropout):
        super().__init__()

        # Layers are appended in the same order as a literal nn.Sequential
        # would list them, so state-dict keys ("features.0", "features.1", ...)
        # stay compatible with previously saved checkpoints.
        stages = []
        for in_ch, out_ch in ((2, 32), (32, 64), (64, 128)):
            stages += [
                nn.Conv3d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm3d(out_ch),
                nn.ReLU(),
                nn.MaxPool3d((1, 2, 2)),
            ]
        stages.append(nn.AdaptiveAvgPool3d((1, 8, 8)))
        self.features = nn.Sequential(*stages)

        self.classifier = nn.Sequential(
            nn.Linear(128 * 8 * 8, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 2),
        )

    def forward(self, x):
        """Return logits of shape ``(batch, 2)`` for a 5-D input batch."""
        pooled = self.features(x)
        # Collapse everything except the batch axis before the MLP head.
        flat = pooled.view(pooled.size(0), -1)
        return self.classifier(flat)

# ======== TASK GENERATOR (DISK LOADING) ========
class TaskGenerator:
    """Sample few-shot real-vs-fake tasks from optical-flow clips on disk.

    Clips are looked up as ``<base_path>/<label>/<speaker>/*.npy`` where
    ``label`` is ``"real"`` (class 0) or ``"fake"`` (class 1). Every task
    draws its support set from one speaker and its query set from a
    different speaker.

    Args:
        base_path: Root folder containing the ``real`` and ``fake`` trees.
        speakers: Speaker folder names to sample from (at least two).
        shots: Number of clips per class in both the support and query set.
        max_frames: Clips are truncated or zero-padded to this many frames.
        noise_std: Standard deviation of the Gaussian noise added to every
            loaded clip; ``0`` disables the noise.
        device: Device the task tensors are moved to. ``None`` selects CUDA
            when available and the CPU otherwise.
    """

    def __init__(self, base_path, speakers, shots=5, max_frames=30,
                 noise_std=0.01, device=None):
        self.base_path = base_path
        self.speakers = speakers
        self.shots = shots
        self.max_frames = max_frames
        self.noise_std = noise_std
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

    def _load(self, path):
        """Load one clip and return a tensor with ``max_frames`` frames.

        The stored array is permuted with ``(3, 0, 1, 2)``, i.e. its last
        axis becomes the first one; axis 1 of the result is treated as time.
        Presumably the files are laid out as (frames, height, width,
        channels) - confirm against the preprocessing script.
        """
        # SECURITY NOTE: allow_pickle=True lets np.load unpickle arbitrary
        # objects. Only load files from a trusted source; switch to
        # allow_pickle=False once the clips are known to be plain arrays.
        data = np.load(path, allow_pickle=True)
        flow = torch.from_numpy(data).float().permute(3, 0, 1, 2)

        missing = self.max_frames - flow.size(1)
        if missing <= 0:
            flow = flow[:, :self.max_frames]
        else:
            # F.pad lists pads from the last axis backwards, so the final
            # pair appends zero frames at the end of the time axis.
            flow = F.pad(flow, (0, 0, 0, 0, 0, missing))

        if self.noise_std:
            flow = flow + self.noise_std * torch.randn_like(flow)
        return flow

    def _list_clips(self, speaker, label):
        """Return the ``.npy`` paths for one speaker/label, or ``[]``."""
        folder = os.path.join(self.base_path, label, speaker)
        if not os.path.isdir(folder):
            return []
        # Sorted so that sampling is reproducible for a fixed random seed
        # regardless of the directory order reported by the OS.
        return [os.path.join(folder, name)
                for name in sorted(os.listdir(folder))
                if name.endswith('.npy')]

    def create_task(self):
        """Build one task, retrying up to 10 speaker pairs.

        Returns:
            ``(s_x, s_y, q_x, q_y)`` on ``self.device``, where the inputs
            hold ``2 * shots`` clips (real first, then fake) and the label
            tensors hold 0 for real and 1 for fake. When no sampled pair
            has at least ``shots`` clips in every pool, returns
            ``(None, None, None, None)``.
        """
        for _ in range(10):
            support_spk, query_spk = random.sample(self.speakers, 2)
            pools = (
                self._list_clips(support_spk, 'real'),
                self._list_clips(support_spk, 'fake'),
                self._list_clips(query_spk, 'real'),
                self._list_clips(query_spk, 'fake'),
            )
            if min(len(pool) for pool in pools) < self.shots:
                continue

            # Use the instance's shot count (not a module-level global) so
            # the size check above and the sampling below always agree.
            s_real, s_fake, q_real, q_fake = (
                random.sample(pool, self.shots) for pool in pools
            )
            labels = [0] * self.shots + [1] * self.shots

            s_x = torch.stack([self._load(p) for p in s_real + s_fake])
            s_y = torch.tensor(labels)
            q_x = torch.stack([self._load(p) for p in q_real + q_fake])
            q_y = torch.tensor(labels)
            return (s_x.to(self.device), s_y.to(self.device),
                    q_x.to(self.device), q_y.to(self.device))
        return None, None, None, None

# ======== EVALUATE ========
def _save_eval_plots(phase, cm, fpr, tpr, roc_auc):
    """Write ``<phase>_cm.png`` and ``<phase>_roc.png`` (phase lower-cased)."""
    import matplotlib.pyplot as plt
    import seaborn as sns

    # Confusion matrix heat-map.
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=["Real", "Fake"], yticklabels=["Real", "Fake"])
    plt.title(f"{phase} Confusion Matrix")
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.tight_layout()
    plt.savefig(f"{phase.lower()}_cm.png")
    plt.close()

    # ROC curve with the chance diagonal for reference.
    plt.figure(figsize=(6, 5))
    plt.plot(fpr, tpr, label=f"AUC: {roc_auc:.4f}")
    plt.plot([0,1], [0,1], 'k--')
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"{phase} ROC Curve")
    plt.legend()
    plt.tight_layout()
    plt.savefig(f"{phase.lower()}_roc.png")
    plt.close()


def evaluate(maml, task_gen, phase="Test", num_tasks=20):
    """Evaluate a MAML learner on freshly sampled tasks and report metrics.

    For every task a clone of ``maml`` is adapted once on the support set
    (cross-entropy loss) and then scored on the query set. Predictions of all
    tasks are pooled to compute the metrics, which are printed, plotted to
    ``<phase>_cm.png`` / ``<phase>_roc.png`` and returned.

    Args:
        maml: Object providing ``clone()``; the clone must be callable and
            provide ``adapt(loss)`` (e.g. ``l2l.algorithms.MAML``).
        task_gen: Object whose ``create_task()`` returns
            ``(s_x, s_y, q_x, q_y)`` or four ``None`` values.
        phase: Label used in the progress bar, printouts and file names.
        num_tasks: Number of tasks to attempt; tasks that come back as
            ``None`` are skipped.

    Returns:
        Dict with keys ``"Accuracy (%)"``, ``"EER"``, ``"FAR"``, ``"FRR"``,
        ``"HTER"``, ``"Precision"``, ``"Recall"``, ``"F1-score"``, ``"AUC"``.

    Raises:
        ValueError: If no task could be generated, so there is nothing to
            score.
    """
    accs, preds, labels, probs = [], [], [], []
    for _ in tqdm(range(num_tasks), desc=f"{phase} Evaluation"):
        s_x, s_y, q_x, q_y = task_gen.create_task()
        if s_x is None:
            continue
        learner = maml.clone()
        # Adaptation needs gradients, so only the query pass is wrapped in
        # no_grad below.
        learner.adapt(F.cross_entropy(learner(s_x), s_y))
        with torch.no_grad():
            out = learner(q_x)
        pred = out.argmax(1)
        accs.append((pred == q_y).float().mean().item())
        preds += pred.cpu().tolist()
        labels += q_y.cpu().tolist()
        # Probability of class 1 ("Fake") is the score used for the ROC.
        probs += F.softmax(out, dim=1)[:, 1].cpu().tolist()

    if not labels:
        raise ValueError(
            f"{phase} evaluation produced no tasks; check the data path and "
            "that every speaker has enough real and fake clips."
        )

    acc = np.mean(accs)*100
    # Fix the label order so the matrix is always 2x2 and the unpacking
    # below cannot fail when a class is missing from the pooled results.
    cm = confusion_matrix(labels, preds, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    # NOTE(review): with 0 = Real and 1 = Fake, fp/(fp+tn) is the share of
    # real clips predicted fake and fn/(fn+tp) the share of fake clips
    # predicted real. If FAR/FRR are meant as false-acceptance/-rejection
    # rates the two names look swapped - confirm before publishing; values
    # are kept as originally computed. HTER is unaffected either way.
    far = fp / (fp + tn + 1e-6)
    frr = fn / (fn + tp + 1e-6)
    hter = (far + frr) / 2
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    fpr, tpr, _ = roc_curve(labels, probs)
    # EER is approximated by the ROC point where FPR and FNR (1 - TPR) are
    # closest; the FPR at that point is reported.
    eer = fpr[np.nanargmin(np.abs(fpr - (1 - tpr)))]
    roc_auc = auc(fpr, tpr)

    print(f"\n {phase} Accuracy: {acc:.2f}%")
    print(f"🔹 EER: {eer:.4f} | FAR: {far:.4f} | FRR: {frr:.4f} | HTER: {hter:.4f}")
    print(f"🔹 Precision: {precision:.4f} | Recall: {recall:.4f} | F1: {f1:.4f}")
    print(classification_report(labels, preds, target_names=["Real", "Fake"]))

    _save_eval_plots(phase, cm, fpr, tpr, roc_auc)

    return {
        "Accuracy (%)": acc, "EER": eer, "FAR": far, "FRR": frr, "HTER": hter,
        "Precision": precision, "Recall": recall, "F1-score": f1, "AUC": roc_auc
    }


# ======== TRAIN LOOP ========
def train(epochs=15, tasks_per_epoch=50, checkpoint_path="best_model3.pth",
          log_path="training_log.csv"):
    """Meta-train an ``OpticalFlowModel`` with MAML on the training speakers.

    Each iteration samples one task, adapts a clone of the MAML wrapper on
    the support set, and takes an Adam step on the query-set loss. The
    hyper-parameters, data path and speaker list are read from the
    module-level configuration.

    Args:
        epochs: Number of passes; each pass attempts ``tasks_per_epoch`` tasks.
        tasks_per_epoch: Tasks attempted per epoch (tasks returned as ``None``
            by the generator are skipped).
        checkpoint_path: Where the state dict of the best epoch is written.
        log_path: CSV file receiving the per-epoch loss/accuracy history.

    Returns:
        The model with the weights of the *last* epoch (the best epoch is
        only available through ``checkpoint_path``).
    """
    model = OpticalFlowModel(dropout).to(device)
    maml = l2l.algorithms.MAML(model, lr=inner_lr)
    optimizer = optim.Adam(maml.parameters(), lr=meta_lr)
    gen = TaskGenerator(optical_flow_path, train_speakers, shots)
    history = {'epoch':[], 'train_loss':[], 'train_acc':[]}
    # Start below any reachable accuracy so the first epoch that produced
    # tasks always writes a checkpoint, even at 0% accuracy.
    # NOTE(review): the checkpoint is selected on *training* accuracy;
    # consider selecting on held-out validation speakers instead.
    best = float("-inf")
    for epoch in range(epochs):
        losses, accs = [], []
        for _ in tqdm(range(tasks_per_epoch), desc=f"Epoch {epoch+1}"):
            s_x, s_y, q_x, q_y = gen.create_task()
            if s_x is None:
                continue
            learner = maml.clone()
            learner.adapt(F.cross_entropy(learner(s_x), s_y))
            out = learner(q_x)
            loss = F.cross_entropy(out, q_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
            accs.append((out.argmax(1)==q_y).float().mean().item())

        # An epoch in which every task was skipped has nothing to average;
        # record NaN explicitly instead of calling np.mean on an empty list.
        avg_loss = np.mean(losses) if losses else float("nan")
        avg_acc = np.mean(accs)*100 if accs else float("nan")
        print(f" Epoch {epoch+1}: Train Acc: {avg_acc:.2f}%")
        history['epoch'].append(epoch+1)
        history['train_loss'].append(avg_loss)
        history['train_acc'].append(avg_acc)
        # Comparisons with NaN are False, so empty epochs never overwrite
        # the checkpoint.
        if avg_acc > best:
            best = avg_acc
            torch.save(model.state_dict(), checkpoint_path)
    pd.DataFrame(history).to_csv(log_path, index=False)
    return model

# ======== MAIN ========
if __name__ == "__main__":
    # Train; the best-epoch weights are written to best_model3.pth by train().
    trained = train()

    # Rebuild the network and restore the best checkpoint. map_location keeps
    # the load working when the checkpoint was written on another device, and
    # weights_only=True restricts unpickling to tensors/plain containers
    # (this also silences torch.load's FutureWarning).
    model = OpticalFlowModel(dropout).to(device)
    state = torch.load("best_model3.pth", map_location=device, weights_only=True)
    model.load_state_dict(state)

    # Meta-test on the held-out speakers and persist the metrics.
    maml = l2l.algorithms.MAML(model, lr=inner_lr)
    test_gen = TaskGenerator(optical_flow_path, test_speakers, shots)
    results = evaluate(maml, test_gen, "Test", num_tasks=50)
    pd.DataFrame([results]).to_csv("grid_metrics.csv", index=False)

    # Keep a timestamped copy of the evaluated weights.
    torch.save(model.state_dict(), f"final_model_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pth")
    print(" Final model saved.")


✅ Using device: cuda


Epoch 1: 100%|█████████████████████████████████████████████████████████████████████████| 50/50 [01:07<00:00,  1.35s/it]


✅ Epoch 1: Train Acc: 92.33%


Epoch 2: 100%|█████████████████████████████████████████████████████████████████████████| 50/50 [01:06<00:00,  1.32s/it]


✅ Epoch 2: Train Acc: 100.00%


Epoch 3: 100%|█████████████████████████████████████████████████████████████████████████| 50/50 [01:09<00:00,  1.40s/it]


✅ Epoch 3: Train Acc: 99.67%


Epoch 4: 100%|█████████████████████████████████████████████████████████████████████████| 50/50 [01:06<00:00,  1.33s/it]


✅ Epoch 4: Train Acc: 96.33%


Epoch 5: 100%|█████████████████████████████████████████████████████████████████████████| 50/50 [01:06<00:00,  1.34s/it]


✅ Epoch 5: Train Acc: 98.33%


Epoch 6: 100%|█████████████████████████████████████████████████████████████████████████| 50/50 [01:06<00:00,  1.33s/it]


✅ Epoch 6: Train Acc: 100.00%


Epoch 7: 100%|█████████████████████████████████████████████████████████████████████████| 50/50 [01:06<00:00,  1.33s/it]


✅ Epoch 7: Train Acc: 99.67%


Epoch 8: 100%|█████████████████████████████████████████████████████████████████████████| 50/50 [01:06<00:00,  1.34s/it]


✅ Epoch 8: Train Acc: 99.67%


Epoch 9: 100%|█████████████████████████████████████████████████████████████████████████| 50/50 [01:06<00:00,  1.33s/it]


✅ Epoch 9: Train Acc: 98.67%


Epoch 10: 100%|████████████████████████████████████████████████████████████████████████| 50/50 [01:07<00:00,  1.34s/it]


✅ Epoch 10: Train Acc: 100.00%


Epoch 11: 100%|████████████████████████████████████████████████████████████████████████| 50/50 [01:06<00:00,  1.32s/it]


✅ Epoch 11: Train Acc: 100.00%


Epoch 12: 100%|████████████████████████████████████████████████████████████████████████| 50/50 [01:08<00:00,  1.37s/it]


✅ Epoch 12: Train Acc: 99.33%


Epoch 13: 100%|████████████████████████████████████████████████████████████████████████| 50/50 [01:10<00:00,  1.42s/it]


✅ Epoch 13: Train Acc: 99.33%


Epoch 14: 100%|████████████████████████████████████████████████████████████████████████| 50/50 [01:08<00:00,  1.38s/it]


✅ Epoch 14: Train Acc: 100.00%


Epoch 15: 100%|████████████████████████████████████████████████████████████████████████| 50/50 [01:08<00:00,  1.36s/it]
  model.load_state_dict(torch.load("best_model3.pth"))


✅ Epoch 15: Train Acc: 100.00%


Test Evaluation: 100%|█████████████████████████████████████████████████████████████████| 50/50 [00:44<00:00,  1.13it/s]



✅ Test Accuracy: 99.67%
🔹 EER: 0.0000 | FAR: 0.0000 | FRR: 0.0067 | HTER: 0.0033
🔹 Precision: 1.0000 | Recall: 0.9933 | F1: 0.9967
              precision    recall  f1-score   support

        Real       0.99      1.00      1.00       150
        Fake       1.00      0.99      1.00       150

    accuracy                           1.00       300
   macro avg       1.00      1.00      1.00       300
weighted avg       1.00      1.00      1.00       300

✅ Final model saved.


In [2]:
# ✅ FINAL EVALUATION SUMMARY CELL (Run After Training)

def print_final_results(results, num_tasks=50, shots=3, speakers_desc="s1 to s6"):
    """Print a human-readable summary of an evaluation run.

    Args:
        results: Mapping with the keys ``"Accuracy (%)"``, ``"EER"``,
            ``"HTER"``, ``"FAR"``, ``"FRR"``, ``"Precision"``, ``"Recall"``,
            ``"F1-score"`` and ``"AUC"`` (as returned by ``evaluate``).
        num_tasks: Number of meta-test tasks that were evaluated.
        shots: Clips per class in each support/query set.
        speakers_desc: Text describing the test speakers.
    """
    per_task = 2 * shots            # real + fake query clips in one task
    total = num_tasks * per_task    # query clips over all tasks

    print("\n==================== FINAL RESULTS SUMMARY ====================")
    print("🔍 Evaluation Setup:")
    print(f"- Test Speakers        : {speakers_desc}")
    print(f"- Meta-Test Tasks      : {num_tasks}")
    print(f"- Support Samples/task : {shots} real + {shots} fake (per class)")
    print(f"- Query Samples/task   : {shots} real + {shots} fake (per class)")
    # Show how the total is obtained instead of printing the same number
    # on both sides of the equals sign.
    print(f"- Total Query Samples  : {num_tasks} tasks x {per_task} = {total}")
    print("\n🎯 Performance Metrics:")
    print(f"- ✅ Test Accuracy      : {results['Accuracy (%)']:.2f}%")
    print(f"- 🔹 Equal Error Rate   : {results['EER']:.4f}")
    print(f"- 🔹 Half Total Error   : {results['HTER']:.4f}")
    print(f"- 🔹 FAR                : {results['FAR']:.4f}")
    print(f"- 🔹 FRR                : {results['FRR']:.4f}")
    print(f"- 🔹 Precision          : {results['Precision']:.4f}")
    print(f"- 🔹 Recall             : {results['Recall']:.4f}")
    print(f"- 🔹 F1-Score           : {results['F1-score']:.4f}")
    print(f"- 🔹 AUC                : {results['AUC']:.4f}")
    print("\n🖼️ Visual Results Saved:")
    print("- Confusion Matrix  ➜ test_cm.png")
    print("- ROC Curve         ➜ test_roc.png")
    print("===============================================================\n")

# 🔹 Run this in a separate cell after evaluate(): `results` is the metrics
# dict assigned by the main block of the training cell, so that cell must have
# been executed first. num_tasks and shots should match the values used for
# that evaluation run.
print_final_results(results, num_tasks=50, shots=3)



🔍 Evaluation Setup:
- Test Speakers        : s1 to s6
- Meta-Test Tasks      : 50
- Support Samples/task : 3 real + 3 fake (per class)
- Query Samples/task   : 3 real + 3 fake (per class)
- Total Query Samples  : 300 = 300

🎯 Performance Metrics:
- ✅ Test Accuracy      : 99.67%
- 🔹 Equal Error Rate   : 0.0000
- 🔹 Half Total Error   : 0.0033
- 🔹 FAR                : 0.0000
- 🔹 FRR                : 0.0067
- 🔹 Precision          : 1.0000
- 🔹 Recall             : 0.9933
- 🔹 F1-Score           : 0.9967
- 🔹 AUC                : 1.0000

🖼️ Visual Results Saved:
- Confusion Matrix  ➜ test_cm.png
- ROC Curve         ➜ test_roc.png

