In [None]:
# Download the OCT2017 (.tar.gz) and ChestXRay2017 (.zip) archives from Google Drive into /content.
!gdown --fuzzy "https://drive.google.com/file/d/1cIGCfx6CiVgEpq8PyKzmF1LBJiQGkxzc/view?usp=sharing"
!gdown --fuzzy "https://drive.google.com/file/d/1JobiELb-4mO_Gk3NY6eyIz-3oRw3U2zT/view?usp=sharing"

Downloading...
From (original): https://drive.google.com/uc?id=1cIGCfx6CiVgEpq8PyKzmF1LBJiQGkxzc
From (redirected): https://drive.google.com/uc?id=1cIGCfx6CiVgEpq8PyKzmF1LBJiQGkxzc&confirm=t&uuid=619f3af0-f77e-46a1-b7de-065d15186bd9
To: /content/OCT2017.tar.gz
100% 5.79G/5.79G [01:20<00:00, 72.3MB/s]
Downloading...
From (original): https://drive.google.com/uc?id=1JobiELb-4mO_Gk3NY6eyIz-3oRw3U2zT
From (redirected): https://drive.google.com/uc?id=1JobiELb-4mO_Gk3NY6eyIz-3oRw3U2zT&confirm=t&uuid=cc649910-d982-4eff-af58-e336164e22cc
To: /content/ChestXRay2017.zip
100% 1.24G/1.24G [00:12<00:00, 98.6MB/s]


In [None]:
# Extract both archives into /content/data.
# `tar -C` requires the target directory to exist, so create it first;
# `unzip -o` overwrites without the interactive prompt that would hang the cell on a re-run.
!mkdir -p /content/data
!tar -xzf "/content/OCT2017.tar.gz" -C /content/data/
!unzip -q -o /content/ChestXRay2017.zip -d /content/data

In [None]:
# Download the Phase 2 MobileNetV3 student checkpoint (read by load_phase2_model via Config.PHASE2_MODEL_PATH).
!gdown --fuzzy "https://drive.google.com/file/d/1F_vX0fmLL0nKlhaQMhJWHGcFu9a3NxEs/view?usp=sharing"

Downloading...
From (original): https://drive.google.com/uc?id=1F_vX0fmLL0nKlhaQMhJWHGcFu9a3NxEs
From (redirected): https://drive.google.com/uc?id=1F_vX0fmLL0nKlhaQMhJWHGcFu9a3NxEs&confirm=t&uuid=5cd6fbef-e6a6-4bc3-9976-90d9ff67ef51
To: /content/best_mobilenetv3_student_kd.pth
100% 39.0M/39.0M [00:00<00:00, 61.5MB/s]


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.models as models
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, f1_score, accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
from tqdm import tqdm
import numpy as np
from pathlib import Path
import copy

# ============================================================================
# CONFIGURATION
# ============================================================================
class Config:
    """Paths and hyperparameters for the Phase 3 LwF run, read as class attributes."""
    # Paths
    TASK_A_DATA_PATH = "/content/data/OCT2017/train"  # OCT images folder (one sub-folder per class)
    TASK_B_DATA_PATH = "/content/data/chest_xray/train"  # Chest X-ray images folder (one sub-folder per class)
    PHASE2_MODEL_PATH = "/content/best_mobilenetv3_student_kd.pth"  # Phase 2 checkpoint loaded as the starting model
    SAVE_DIR = "/content/phase3_lwf_results"  # best checkpoint and confusion-matrix PNGs are written here

    # Model settings
    TASK_A_CLASSES = 4  # OCT classes (output size of head_a)
    TASK_B_CLASSES = 2  # Chest X-ray classes (output size of head_b)

    # LwF hyperparameters
    LWF_ALPHA = 2.0  # Distillation loss weight: total = CE + LWF_ALPHA * distill
    LWF_TEMPERATURE = 2.0  # Softmax temperature used inside distillation_loss

    # Training hyperparameters
    BATCH_SIZE = 128  # used for every DataLoader (train, val, test)
    NUM_EPOCHS = 20
    PATIENCE = 5  # Early stopping: epochs without a Task B val-F1 improvement before stopping

    # Data augmentation
    USE_AUGMENTATION = True  # applies only to the Task B training transform

    # Evaluation
    EVAL_TASK_A_EVERY = 2  # Evaluate Task A retention every N epochs

    # Evaluated once when the class body executes.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ============================================================================
# DATA LOADING UTILITIES
# ============================================================================
def load_task_paths(data_path):
    """
    Universal data loader for both OCT and Chest X-ray.

    Expects ``data_path/CLASS_NAME/<image files>``. Class indices are assigned in
    sorted order of the (non-hidden) class sub-directory names.

    Args:
        data_path: Root directory (str or Path) with one sub-directory per class.

    Returns:
        (all_paths, all_labels, class_names):
            all_paths   - list[Path] of .jpeg/.jpg/.png files (any letter case),
                          grouped by class and sorted within each class.
            all_labels  - list[int] class index aligned with ``all_paths``.
            class_names - sorted list[str] of class directory names.

    Raises:
        FileNotFoundError: if ``data_path`` does not exist.
        ValueError: if ``data_path`` contains no class sub-directories.
    """
    image_suffixes = {'.jpeg', '.jpg', '.png'}
    data_path = Path(data_path)

    # Skip hidden directories (e.g. '.ipynb_checkpoints') so they never become a class.
    class_names = sorted(
        d.name for d in data_path.iterdir()
        if d.is_dir() and not d.name.startswith('.')
    )
    if not class_names:
        raise ValueError(f"No class sub-directories found in {data_path}")

    all_paths = []
    all_labels = []

    for idx, class_name in enumerate(class_names):
        class_dir = data_path / class_name
        # - suffix compared case-insensitively so '.JPEG' / '.PNG' are not dropped;
        # - hidden files (e.g. macOS '._*' AppleDouble metadata) are not images;
        # - sorted because directory listing order is filesystem-dependent, and
        #   train_test_split(random_state=...) is only reproducible on a fixed order.
        paths = sorted(
            p for p in class_dir.iterdir()
            if p.is_file()
            and not p.name.startswith('.')
            and p.suffix.lower() in image_suffixes
        )

        all_paths.extend(paths)
        all_labels.extend([idx] * len(paths))

    return all_paths, all_labels, class_names

# ============================================================================
# DATASET CLASS
# ============================================================================
class ImageDataset(Dataset):
    """Map-style dataset that reads images from disk on access.

    Each item is opened, converted to 3-channel RGB and, if given, passed
    through ``transform``.

    Args:
        paths: Sequence of image file paths (str or Path).
        labels: Sequence of class labels aligned index-by-index with ``paths``.
        transform: Optional callable applied to the PIL image.

    Raises:
        ValueError: if ``paths`` and ``labels`` differ in length.
    """

    def __init__(self, paths, labels, transform=None):
        # A length mismatch would otherwise only surface as an IndexError mid-epoch.
        if len(paths) != len(labels):
            raise ValueError(
                f"paths and labels must have the same length ({len(paths)} != {len(labels)})"
            )
        self.paths = paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Local import keeps PIL out of module import time.
        from PIL import Image

        # The context manager releases the file handle even if decoding fails;
        # convert() returns a new, fully loaded image independent of the file.
        with Image.open(self.paths[idx]) as raw_img:
            img = raw_img.convert('RGB')
        label = self.labels[idx]

        if self.transform:
            img = self.transform(img)

        return img, label

# ============================================================================
# DATA SPLITS
# ============================================================================
def create_task_a_splits():
    """Build the Task A (OCT) test DataLoader used to measure retention.

    The OCT training folder is split 70/15/15 with two stratified
    ``train_test_split`` calls (seed 42). Only the final 15% test portion is
    wrapped in a DataLoader; the train/val portions are used solely for the
    size printout.

    NOTE(review): the split is drawn from the OCT *train* folder. Confirm that
    Phase 2 held out exactly this split (same seed and same file order);
    otherwise this "test" set may overlap Phase 2's training data.

    Returns:
        (test_loader, class_names)
    """
    print("\nüìä Creating Task A (OCT) evaluation splits...")

    paths, labels, class_names = load_task_paths(Config.TASK_A_DATA_PATH)
    print(f"   Total Task A samples: {len(paths):,}")
    print(f"   Classes: {class_names}")
    print(f"   Class distribution: {dict(Counter(labels))}")

    # Stage 1: 70% train / 30% hold-out. Stage 2: hold-out halved into val / test.
    train_p, holdout_p, _train_y, holdout_y = train_test_split(
        paths, labels, test_size=0.30, stratify=labels, random_state=42
    )
    val_p, test_p, _val_y, test_y = train_test_split(
        holdout_p, holdout_y, test_size=0.50, stratify=holdout_y, random_state=42
    )
    print(f"   Train: {len(train_p):,} | Val: {len(val_p):,} | Test: {len(test_p):,}")

    # Deterministic evaluation preprocessing: fixed resize + ImageNet normalization.
    eval_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    test_loader = DataLoader(
        ImageDataset(test_p, test_y, eval_transform),
        batch_size=Config.BATCH_SIZE,
        shuffle=False,
        num_workers=2,
    )
    return test_loader, class_names

def create_task_b_splits():
    """Create stratified 70/15/15 splits and DataLoaders for Task B (Chest X-ray).

    Returns:
        (train_loader, val_loader, test_loader, class_weights, class_names), where
        ``class_weights`` is a float32 tensor on ``Config.device`` holding
        inverse-frequency weights ``N_train / (K * n_k)`` from the training split.

    Raises:
        ValueError: if the number of class folders differs from
            ``Config.TASK_B_CLASSES`` (the output size of ``head_b``).
    """
    print("\nüìÇ Creating Task B (Chest X-ray) splits...")

    all_paths, all_labels, task_b_class_names = load_task_paths(Config.TASK_B_DATA_PATH)
    print(f"   Total samples: {len(all_paths):,}")
    print(f"   Classes: {task_b_class_names}")

    # Class distribution
    class_counts = Counter(all_labels)
    print(f"   Class distribution: {dict(class_counts)}")

    num_classes = len(task_b_class_names)
    # Fail fast: labels >= Config.TASK_B_CLASSES would otherwise only surface later
    # as an opaque indexing error inside CrossEntropyLoss on head_b's outputs.
    if num_classes != Config.TASK_B_CLASSES:
        raise ValueError(
            f"Found {num_classes} Task B classes {task_b_class_names} in "
            f"{Config.TASK_B_DATA_PATH}, but Config.TASK_B_CLASSES = {Config.TASK_B_CLASSES}"
        )

    # 70/15/15 stratified split
    train_paths, temp_paths, train_labels, temp_labels = train_test_split(
        all_paths, all_labels, test_size=0.30, stratify=all_labels, random_state=42
    )

    val_paths, test_paths, val_labels, test_labels = train_test_split(
        temp_paths, temp_labels, test_size=0.50, stratify=temp_labels, random_state=42
    )

    print(f"   Train: {len(train_paths):,} | Val: {len(val_paths):,} | Test: {len(test_paths):,}")

    # Inverse-frequency class weights for the imbalanced training split
    train_class_counts = Counter(train_labels)
    total_samples = len(train_labels)
    class_weights = torch.tensor([
        total_samples / (num_classes * train_class_counts[i])
        for i in range(num_classes)
    ], dtype=torch.float32).to(Config.device)

    print(f"   Class weights: {class_weights.cpu().numpy()}")

    # Shared preprocessing: fixed 224x224 resize + ImageNet normalization.
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    if Config.USE_AUGMENTATION:
        train_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        # Without augmentation the training pipeline is identical to evaluation.
        train_transform = val_transform

    train_dataset = ImageDataset(train_paths, train_labels, train_transform)
    val_dataset = ImageDataset(val_paths, val_labels, val_transform)
    test_dataset = ImageDataset(test_paths, test_labels, val_transform)

    train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE,
                              shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=Config.BATCH_SIZE,
                            shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=Config.BATCH_SIZE,
                             shuffle=False, num_workers=2)

    return train_loader, val_loader, test_loader, class_weights, task_b_class_names

# ============================================================================
# MULTI-HEAD MODEL
# ============================================================================
class MultiHeadMobileNet(nn.Module):
    """MobileNetV3-Large convolutional trunk shared by two classification heads.

    ``head_a`` (OCT) and ``head_b`` (Chest X-ray) have the same layout,
    ``Linear(960, 256) -> Hardswish -> Dropout(0.2) -> Linear(256, K)``, so their
    state-dict keys are ``head_x.0.*`` and ``head_x.3.*``.
    """

    def __init__(self, num_classes_a, num_classes_b):
        super().__init__()
        # Build MobileNetV3-Large without pretrained weights and keep only its
        # feature extractor; torchvision's own classifier is discarded.
        trunk = models.mobilenet_v3_large(weights=None)
        self.features = trunk.features

        self.head_a = self._make_head(num_classes_a)  # Task A (OCT), filled from Phase 2
        self.head_b = self._make_head(num_classes_b)  # Task B (Chest X-ray), new

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten()

    @staticmethod
    def _make_head(num_classes):
        """Return a fresh 960 -> 256 -> ``num_classes`` MLP head."""
        return nn.Sequential(
            nn.Linear(960, 256),
            nn.Hardswish(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes),
        )

    def forward(self, x, task='b'):
        """Run the shared trunk, then the head selected by ``task`` ('a' or 'b')."""
        pooled = self.flatten(self.avgpool(self.features(x)))

        if task not in ('a', 'b'):
            raise ValueError(f"Unknown task: {task}")
        head = self.head_a if task == 'a' else self.head_b
        return head(pooled)

# ============================================================================
# LWF DISTILLATION LOSS
# ============================================================================
def distillation_loss(student_logits, teacher_logits, temperature):
    """
    Temperature-scaled KL divergence KL(teacher || student), averaged over the batch.

    Args:
        student_logits: Raw logits from the model being trained, shape (batch, classes).
        teacher_logits: Raw logits from the frozen reference model, same shape.
        temperature: Softening temperature T applied to both sets of logits.

    Returns:
        Scalar tensor: ``KL(softmax(t/T) || softmax(s/T))`` with 'batchmean'
        reduction, multiplied by ``T**2`` so gradient magnitude stays comparable
        across temperatures.
    """
    log_p_student = torch.log_softmax(student_logits / temperature, dim=1)
    p_teacher = torch.softmax(teacher_logits / temperature, dim=1)

    # kl_div expects log-probabilities as input and probabilities as target.
    divergence = nn.functional.kl_div(log_p_student, p_teacher, reduction='batchmean')
    return divergence * temperature ** 2

# ============================================================================
# LOAD PHASE 2 MODEL (SAME AS EWC)
# ============================================================================
def load_phase2_model():
    """Load the Phase 2 checkpoint into a fresh MultiHeadMobileNet.

    Checkpoint keys are remapped as
        ``backbone.features.*``   -> ``features.*``
        ``backbone.classifier.*`` -> ``head_a.*``
    and all other checkpoint keys are ignored. ``head_b`` keeps its fresh
    initialisation. The checkpoint may be a raw state dict or a dict with a
    ``'model_state_dict'`` entry.

    Returns:
        The model, moved to ``Config.device``.

    Raises:
        FileNotFoundError: if ``Config.PHASE2_MODEL_PATH`` does not exist.
        RuntimeError: if any backbone/head_a parameter or buffer is missing from the
            remapped checkpoint, or a remapped key does not exist in the model.
    """
    print("\nüìÇ Loading Phase 2 model...")

    if not Path(Config.PHASE2_MODEL_PATH).exists():
        raise FileNotFoundError(f"Phase 2 model not found at {Config.PHASE2_MODEL_PATH}")

    # NOTE(security): torch.load unpickles the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(Config.PHASE2_MODEL_PATH, map_location=Config.device)

    # Create multi-head model
    model = MultiHeadMobileNet(Config.TASK_A_CLASSES, Config.TASK_B_CLASSES)

    # Handle different checkpoint formats
    if 'model_state_dict' in checkpoint:
        phase2_state = checkpoint['model_state_dict']
    else:
        phase2_state = checkpoint

    model_state = {}
    for key, value in phase2_state.items():
        if key.startswith('backbone.features'):
            # backbone.features.X -> features.X
            model_state[key.replace('backbone.', '')] = value
        elif key.startswith('backbone.classifier'):
            # backbone.classifier.X -> head_a.X
            model_state[key.replace('backbone.classifier', 'head_a')] = value

    # strict=False is needed because head_b has no Phase 2 weights, but on its own it
    # would silently accept a checkpoint whose keys did not map at all (leaving a
    # randomly initialised backbone/head_a). Only head_b may legitimately be missing.
    incompatible = model.load_state_dict(model_state, strict=False)
    missing = [k for k in incompatible.missing_keys if not k.startswith('head_b.')]
    unexpected = list(incompatible.unexpected_keys)
    if missing or unexpected:
        raise RuntimeError(
            "Phase 2 checkpoint does not match MultiHeadMobileNet: "
            f"missing={missing[:10]} ({len(missing)} total), "
            f"unexpected={unexpected[:10]} ({len(unexpected)} total)"
        )

    model = model.to(Config.device)
    print("   ‚úÖ Phase 2 model loaded")
    print(f"   üîç Verification: head_a.3.weight shape = {model.head_a[3].weight.shape}")

    return model

# ============================================================================
# EVALUATION FUNCTIONS
# ============================================================================
def evaluate_task(model, dataloader, task, class_names):
    """Score one task head of ``model`` over every batch in ``dataloader``.

    The model is switched to eval mode and left there. ``class_names`` is accepted
    for call-site symmetry but is not used.

    Returns:
        (accuracy, weighted_f1, predictions, targets); ``predictions`` and
        ``targets`` are flat lists in dataloader order.
    """
    model.eval()
    predictions, targets = [], []

    with torch.no_grad():
        for batch_images, batch_labels in dataloader:
            logits = model(batch_images.to(Config.device), task=task)
            predictions.extend(logits.argmax(dim=1).cpu().numpy())
            targets.extend(batch_labels.numpy())

    accuracy = accuracy_score(targets, predictions)
    weighted_f1 = f1_score(targets, predictions, average='weighted')
    return accuracy, weighted_f1, predictions, targets

def print_evaluation_report(acc, f1, preds, labels, class_names, task_name):
    """Print accuracy, F1 and a per-class classification report for one task.

    Args:
        acc: Accuracy as a fraction (printed as a percentage).
        f1: F1 score to display.
        preds: Predicted integer class indices.
        labels: True integer class indices.
        class_names: Display names for class indices ``0 .. len(class_names)-1``.
        task_name: Heading text.
    """
    print(f"\n{'='*70}")
    print(f"üìä {task_name} EVALUATION")
    print(f"{'='*70}")
    print(f"   Accuracy:  {acc*100:.2f}%")
    print(f"   F1-Score:  {f1:.4f}")
    print(f"\nüìã Classification Report:")
    # Passing labels= keeps the report aligned with class_names even when a class is
    # absent from both labels and preds (without it sklearn raises because
    # target_names is longer than the set of observed classes). zero_division=0
    # reports 0.0 without a warning for classes that are never predicted.
    print(classification_report(
        labels, preds,
        labels=list(range(len(class_names))),
        target_names=class_names,
        digits=4,
        zero_division=0,
    ))

# ============================================================================
# TRAINING FUNCTION WITH LWF
# ============================================================================
def train_phase3_lwf():
    """Phase 3: continual learning on Task B (Chest X-ray) with Learning without Forgetting.

    Steps:
      1. Load the Phase 2 model as the student; keep a frozen deep copy as the teacher.
      2. Measure Task A (OCT) test performance before any Task B training.
      3. Train the backbone and ``head_b`` on Task B with
         ``CE(head_b) + Config.LWF_ALPHA * KD(student head_a, teacher head_a)``,
         where the distillation term is computed on the Task B images.
      4. Keep the checkpoint with the best Task B validation F1 (early stop after
         ``Config.PATIENCE`` epochs without improvement), reload it, and report
         Task B test metrics plus Task A retention.

    Returns:
        (student_model, history): ``history`` holds per-epoch ``train_loss`` and
        ``val_f1``, and ``task_a_f1`` every ``Config.EVAL_TASK_A_EVERY`` epochs.
    """
    print("\n" + "="*70)
    print("üöÄ PHASE 3: CONTINUAL LEARNING WITH LWF")
    print("="*70)

    # Seed torch (head_b init, shuffling, augmentation, dropout) so reruns are
    # comparable; 42 matches the seed used for the data splits.
    torch.manual_seed(42)

    # Create save directory (parents=True so a nested SAVE_DIR also works)
    Path(Config.SAVE_DIR).mkdir(parents=True, exist_ok=True)
    best_ckpt_path = f"{Config.SAVE_DIR}/phase3_lwf_best.pth"

    # Load Phase 2 model (student)
    student_model = load_phase2_model()

    # Create teacher model (frozen copy of Phase 2 model)
    print("\nüìö Creating teacher model (frozen copy)...")
    teacher_model = copy.deepcopy(student_model)
    teacher_model.eval()  # BatchNorm uses running stats, dropout disabled

    # Freeze all teacher parameters
    for param in teacher_model.parameters():
        param.requires_grad = False

    print("   Teacher model created and frozen")

    # head_a is not in the optimizer, so its weights never change. Disabling its
    # grads avoids accumulating unused .grad tensors; gradients still flow THROUGH
    # head_a into the backbone, which is what the distillation term needs.
    for param in student_model.head_a.parameters():
        param.requires_grad = False

    # Create Task A test loader for retention evaluation
    task_a_test_loader, task_a_classes = create_task_a_splits()

    # Evaluate Task A before fine-tuning (baseline)
    print("\nüß™ Evaluating Task A (OCT) BEFORE adapting for task B")
    task_a_acc_before, task_a_f1_before, _, _ = evaluate_task(
        student_model, task_a_test_loader, task='a', class_names=task_a_classes
    )
    print(f"   Task A Accuracy: {task_a_acc_before*100:.2f}%")
    print(f"   Task A F1: {task_a_f1_before:.4f}")

    # Create Task B dataloaders
    train_loader, val_loader, test_loader, class_weights, task_b_classes = create_task_b_splits()

    # Backbone gets a 10x smaller LR than the new head to limit drift of shared features.
    optimizer = optim.Adam([
        {'params': student_model.features.parameters(), 'lr': 1e-5},
        {'params': student_model.head_b.parameters(), 'lr': 1e-4}
    ])

    criterion = nn.CrossEntropyLoss(weight=class_weights)
    # mode='max': the monitored metric (Task B val F1) should increase.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5,
                                                     patience=3)

    # -inf guarantees the first epoch is checkpointed, so the reload after training
    # cannot fail even if validation F1 is 0.0.
    best_val_f1 = float('-inf')
    patience_counter = 0
    history = {'train_loss': [], 'val_f1': [], 'task_a_f1': []}

    print(f"\nüéØ Training Task B (Chest X-ray) with LwF (Œ±={Config.LWF_ALPHA}, T={Config.LWF_TEMPERATURE})...")

    for epoch in range(Config.NUM_EPOCHS):
        # train() also puts the backbone's BatchNorm layers in training mode, so their
        # running statistics are updated on Task B batches.
        student_model.train()
        # Only disables dropout inside head_a; its weights are frozen separately above.
        student_model.head_a.eval()
        train_loss = 0.0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{Config.NUM_EPOCHS}")
        for images, labels in pbar:
            images, labels = images.to(Config.device), labels.to(Config.device)

            optimizer.zero_grad()

            # New-task classification loss on head_b
            student_logits_b = student_model(images, task='b')
            ce_loss = criterion(student_logits_b, labels)

            # LwF: match the frozen teacher's Task A outputs on the *Task B* images
            with torch.no_grad():
                teacher_logits_a = teacher_model(images, task='a')

            # NOTE: this second student forward re-runs the backbone, so in train mode
            # BatchNorm running statistics are updated twice per step.
            student_logits_a = student_model(images, task='a')

            distill_loss = distillation_loss(
                student_logits_a,
                teacher_logits_a,
                Config.LWF_TEMPERATURE
            )

            total_loss = ce_loss + Config.LWF_ALPHA * distill_loss

            total_loss.backward()
            optimizer.step()

            train_loss += total_loss.item()
            pbar.set_postfix({
                'loss': f'{total_loss.item():.4f}',
                'ce': f'{ce_loss.item():.4f}',
                'distill': f'{distill_loss.item():.4f}'
            })

        avg_train_loss = train_loss / len(train_loader)
        history['train_loss'].append(avg_train_loss)

        # Validation on Task B
        val_acc, val_f1, _, _ = evaluate_task(student_model, val_loader, task='b',
                                              class_names=task_b_classes)
        history['val_f1'].append(val_f1)

        print(f"\n   Epoch {epoch+1} - Task B Val F1: {val_f1:.4f} | Acc: {val_acc*100:.2f}%")

        # Evaluate Task A retention periodically
        if (epoch + 1) % Config.EVAL_TASK_A_EVERY == 0:
            task_a_acc, task_a_f1, _, _ = evaluate_task(student_model, task_a_test_loader,
                                                        task='a', class_names=task_a_classes)
            history['task_a_f1'].append(task_a_f1)
            retention = (task_a_f1 / task_a_f1_before) * 100
            print(f"   üìà Task A Retention: F1={task_a_f1:.4f} ({retention:.2f}% of baseline)")

        # Learning rate scheduling
        scheduler.step(val_f1)

        # Early stopping and checkpointing (selection is on Task B val F1 only)
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            patience_counter = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': student_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_f1': val_f1,
                'task_a_f1_before': task_a_f1_before
            }, best_ckpt_path)
            print(f"   üíæ Best model saved (Val F1: {val_f1:.4f})")
        else:
            patience_counter += 1
            if patience_counter >= Config.PATIENCE:
                print(f"\n‚è∏Ô∏è  Early stopping triggered (patience={Config.PATIENCE})")
                break

    # Final evaluation
    print("\n" + "="*70)
    print("üìä FINAL EVALUATION")
    print("="*70)

    # Load best model (map_location keeps this working if the device changed)
    checkpoint = torch.load(best_ckpt_path, map_location=Config.device)
    student_model.load_state_dict(checkpoint['model_state_dict'])

    # Task B (Chest X-ray) - Test set
    task_b_acc, task_b_f1, task_b_preds, task_b_labels = evaluate_task(
        student_model, test_loader, task='b', class_names=task_b_classes
    )
    print_evaluation_report(task_b_acc, task_b_f1, task_b_preds, task_b_labels,
                            task_b_classes, "TASK B (Chest X-ray)")

    # Task A (OCT) - Retention test
    task_a_acc_after, task_a_f1_after, task_a_preds, task_a_labels = evaluate_task(
        student_model, task_a_test_loader, task='a', class_names=task_a_classes
    )
    print_evaluation_report(task_a_acc_after, task_a_f1_after, task_a_preds, task_a_labels,
                            task_a_classes, "TASK A (OCT) - Retention Check")

    # Retention metrics: after / before, as a percentage
    retention_f1 = (task_a_f1_after / task_a_f1_before) * 100
    retention_acc = (task_a_acc_after / task_a_acc_before) * 100

    print("\n" + "="*70)
    print("üéØ CONTINUAL LEARNING SUMMARY")
    print("="*70)
    print(f"üìä Task A (OCT) Retention:")
    print(f"   Before: F1={task_a_f1_before:.4f}, Acc={task_a_acc_before*100:.2f}%")
    print(f"   After:  F1={task_a_f1_after:.4f}, Acc={task_a_acc_after*100:.2f}%")
    print(f"   Retention: F1={retention_f1:.2f}%, Acc={retention_acc:.2f}%")
    print(f"\nüìä Task B (Chest X-ray) Performance:")
    print(f"   Test F1: {task_b_f1:.4f}")
    print(f"   Test Acc: {task_b_acc*100:.2f}%")
    print("="*70)

    # Save confusion matrices
    save_confusion_matrix(task_a_labels, task_a_preds, task_a_classes,
                          "Task A (OCT) - After LwF", f"{Config.SAVE_DIR}/cm_task_a.png")
    save_confusion_matrix(task_b_labels, task_b_preds, task_b_classes,
                          "Task B (Chest X-ray)", f"{Config.SAVE_DIR}/cm_task_b.png")

    return student_model, history

def save_confusion_matrix(labels, preds, class_names, title, save_path):
    """Render a confusion-matrix heatmap and save it to ``save_path`` (300 dpi).

    Args:
        labels: True integer class indices.
        preds: Predicted integer class indices.
        class_names: Names for class indices ``0 .. len(class_names)-1``; used as tick labels.
        title: Plot title.
        save_path: Output image path.
    """
    # Pin the label set so the matrix is always K x K and lines up with the tick
    # labels, even when some class never appears in labels or preds.
    cm = confusion_matrix(labels, preds, labels=list(range(len(class_names))))

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_title(title)
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)  # release the figure; this runs headless inside the training function
    print(f"   Confusion matrix saved: {save_path}")

# ============================================================================
# MAIN EXECUTION
# ============================================================================
# In a notebook kernel __name__ is "__main__", so executing this cell runs the full Phase 3 LwF experiment.
if __name__ == "__main__":
    model, history = train_phase3_lwf()


üöÄ PHASE 3: CONTINUAL LEARNING WITH LWF

üìÇ Loading Phase 2 model...
   ‚úÖ Phase 2 model loaded
   üîç Verification: head_a.3.weight shape = torch.Size([4, 256])

üìö Creating teacher model (frozen copy)...
   ‚úÖ Teacher model created and frozen

üìä Creating Task A (OCT) evaluation splits...
   Total Task A samples: 83,484
   Classes: ['CNV', 'DME', 'DRUSEN', 'NORMAL']
   Class distribution: {0: 37205, 1: 11348, 2: 8616, 3: 26315}
   Train: 58,438 | Val: 12,523 | Test: 12,523

üß™ Evaluating Task A (OCT) BEFORE adapting for task B
   Task A Accuracy: 97.06%
   Task A F1: 0.9706

üìÇ Creating Task B (Chest X-ray) splits...
   Total samples: 5,232
   Classes: ['NORMAL', 'PNEUMONIA']
   Class distribution: {0: 1349, 1: 3883}
   Train: 3,662 | Val: 785 | Test: 785
   Class weights: [1.9396186 0.6736571]

üéØ Training Task B (Chest X-ray) with LwF (Œ±=2.0, T=2.0)...


Epoch 1/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:04<00:00,  2.22s/it, loss=4.5041, ce=0.5927, distill=1.9557]



   Epoch 1 - Task B Val F1: 0.3764 | Acc: 41.15%
   üíæ Best model saved (Val F1: 0.3764)


Epoch 2/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:04<00:00,  2.22s/it, loss=3.7362, ce=0.3897, distill=1.6732]



   Epoch 2 - Task B Val F1: 0.7148 | Acc: 69.68%
   üìà Task A Retention: F1=0.8186 (84.34% of baseline)
   üíæ Best model saved (Val F1: 0.7148)


Epoch 3/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:05<00:00,  2.27s/it, loss=2.9508, ce=0.2848, distill=1.3330]



   Epoch 3 - Task B Val F1: 0.7791 | Acc: 76.43%
   üíæ Best model saved (Val F1: 0.7791)


Epoch 4/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:04<00:00,  2.22s/it, loss=3.5989, ce=0.3230, distill=1.6379]



   Epoch 4 - Task B Val F1: 0.8529 | Acc: 84.46%
   üìà Task A Retention: F1=0.6550 (67.48% of baseline)
   üíæ Best model saved (Val F1: 0.8529)


Epoch 5/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:03<00:00,  2.20s/it, loss=2.0307, ce=0.2445, distill=0.8931]



   Epoch 5 - Task B Val F1: 0.8555 | Acc: 84.71%
   üíæ Best model saved (Val F1: 0.8555)


Epoch 6/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:04<00:00,  2.23s/it, loss=1.9000, ce=0.3133, distill=0.7933]



   Epoch 6 - Task B Val F1: 0.8509 | Acc: 84.20%
   üìà Task A Retention: F1=0.5812 (59.88% of baseline)


Epoch 7/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:04<00:00,  2.22s/it, loss=1.7653, ce=0.2193, distill=0.7730]



   Epoch 7 - Task B Val F1: 0.8555 | Acc: 84.71%


Epoch 8/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:03<00:00,  2.20s/it, loss=1.5397, ce=0.1981, distill=0.6708]



   Epoch 8 - Task B Val F1: 0.8929 | Acc: 88.79%
   üìà Task A Retention: F1=0.5593 (57.63% of baseline)
   üíæ Best model saved (Val F1: 0.8929)


Epoch 9/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:04<00:00,  2.22s/it, loss=1.5463, ce=0.2056, distill=0.6703]



   Epoch 9 - Task B Val F1: 0.8871 | Acc: 88.15%


Epoch 10/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:04<00:00,  2.22s/it, loss=1.1974, ce=0.1712, distill=0.5131]



   Epoch 10 - Task B Val F1: 0.9094 | Acc: 90.57%
   üìà Task A Retention: F1=0.5596 (57.66% of baseline)
   üíæ Best model saved (Val F1: 0.9094)


Epoch 11/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:03<00:00,  2.21s/it, loss=1.2401, ce=0.2168, distill=0.5117]



   Epoch 11 - Task B Val F1: 0.9024 | Acc: 89.81%


Epoch 12/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:03<00:00,  2.20s/it, loss=0.9857, ce=0.1528, distill=0.4165]



   Epoch 12 - Task B Val F1: 0.9082 | Acc: 90.45%
   üìà Task A Retention: F1=0.5700 (58.72% of baseline)


Epoch 13/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:04<00:00,  2.21s/it, loss=0.7998, ce=0.0897, distill=0.3551]



   Epoch 13 - Task B Val F1: 0.9141 | Acc: 91.08%
   üíæ Best model saved (Val F1: 0.9141)


Epoch 14/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:03<00:00,  2.17s/it, loss=0.9941, ce=0.2416, distill=0.3762]



   Epoch 14 - Task B Val F1: 0.9176 | Acc: 91.46%
   üìà Task A Retention: F1=0.5757 (59.31% of baseline)
   üíæ Best model saved (Val F1: 0.9176)


Epoch 15/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:03<00:00,  2.18s/it, loss=0.8091, ce=0.1552, distill=0.3270]



   Epoch 15 - Task B Val F1: 0.9211 | Acc: 91.85%
   üíæ Best model saved (Val F1: 0.9211)


Epoch 16/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:03<00:00,  2.20s/it, loss=0.7195, ce=0.1572, distill=0.2812]



   Epoch 16 - Task B Val F1: 0.9106 | Acc: 90.70%
   üìà Task A Retention: F1=0.5888 (60.66% of baseline)


Epoch 17/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:03<00:00,  2.19s/it, loss=0.6696, ce=0.1610, distill=0.2543]



   Epoch 17 - Task B Val F1: 0.9258 | Acc: 92.36%
   üíæ Best model saved (Val F1: 0.9258)


Epoch 18/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:04<00:00,  2.22s/it, loss=0.6053, ce=0.1552, distill=0.2250]



   Epoch 18 - Task B Val F1: 0.9189 | Acc: 91.59%
   üìà Task A Retention: F1=0.5960 (61.40% of baseline)


Epoch 19/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:04<00:00,  2.22s/it, loss=0.5940, ce=0.1582, distill=0.2179]



   Epoch 19 - Task B Val F1: 0.9200 | Acc: 91.72%


Epoch 20/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:03<00:00,  2.19s/it, loss=0.6033, ce=0.1690, distill=0.2172]



   Epoch 20 - Task B Val F1: 0.9284 | Acc: 92.61%
   üìà Task A Retention: F1=0.6006 (61.88% of baseline)
   üíæ Best model saved (Val F1: 0.9284)

üìä FINAL EVALUATION

üìä TASK B (Chest X-ray) EVALUATION
   Accuracy:  94.90%
   F1-Score:  0.9500

üìã Classification Report:
              precision    recall  f1-score   support

      NORMAL     0.8559    0.9655    0.9074       203
   PNEUMONIA     0.9874    0.9433    0.9649       582

    accuracy                         0.9490       785
   macro avg     0.9217    0.9544    0.9361       785
weighted avg     0.9534    0.9490    0.9500       785


üìä TASK A (OCT) - Retention Check EVALUATION
   Accuracy:  61.31%
   F1-Score:  0.6006

üìã Classification Report:
              precision    recall  f1-score   support

         CNV     0.9370    0.4076    0.5681      5581
         DME     0.9586    0.2720    0.4238      1702
      DRUSEN     0.2935    0.8036    0.4300      1293
      NORMAL     0.6425    0.9883    0.7787      3947



Task A retention dropped to 61.88% of baseline F1 (0.9706 → 0.6006; accuracy 97.06% → 61.31%) despite LwF's knowledge distillation with α=2.0 and T=2.0. As with EWC, the most likely cause is BatchNorm statistics drift. In `train()` mode, the shared backbone's BatchNorm layers update their running mean and variance on chest X-ray batches. At inference time, OCT images are therefore normalized with statistics that no longer match their distribution. The distillation loss shown during training fell steadily (≈1.96 → ≈0.22). However, it is computed with the student in training mode (per-batch statistics) on Task B images. A low value therefore does not guarantee that Task A predictions are preserved at evaluation time, when the drifted running statistics are used. The forgetting pattern resembles EWC's:
- Recall collapsed for DME (97% → 27%) and CNV (95% → 41%). CNV is the largest OCT class, so the damage is not confined to minority classes.
- Predictions shifted towards NORMAL (98.8% recall) and DRUSEN (80.4% recall at only 29.4% precision).

This suggests that output-level distillation alone cannot prevent distribution-level forgetting. Both weight-based (EWC) and output-based (LwF) continual-learning methods appear to need explicit BatchNorm handling (for example, freezing BatchNorm statistics) in this cross-domain setting.

# **Using Frozen BatchNorm**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.models as models
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, f1_score, accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
from tqdm import tqdm
import numpy as np
from pathlib import Path
import copy

# ============================================================================
# CONFIGURATION
# ============================================================================
class Config:
    """Paths and hyperparameters for the LwF run with frozen BatchNorm, read as class attributes."""
    # Paths
    TASK_A_DATA_PATH = "/content/data/OCT2017/train"  # OCT images folder (one sub-folder per class)
    TASK_B_DATA_PATH = "/content/data/chest_xray/train"  # Chest X-ray images folder (one sub-folder per class)
    PHASE2_MODEL_PATH = "/content/best_mobilenetv3_student_kd.pth"  # Phase 2 checkpoint used as the starting model
    # NOTE(review): same SAVE_DIR as the earlier (non-frozen) LwF run; if this run also
    # writes phase3_lwf_best.pth / cm_*.png, it overwrites those results — consider a distinct dir.
    SAVE_DIR = "/content/phase3_lwf_results"

    # Model settings
    TASK_A_CLASSES = 4  # OCT classes (output size of head_a)
    TASK_B_CLASSES = 2  # Chest X-ray classes (output size of head_b)

    # LwF hyperparameters
    LWF_ALPHA = 2.0  # Distillation loss weight
    LWF_TEMPERATURE = 2.0  # Softmax temperature for distillation

    # Training hyperparameters
    BATCH_SIZE = 128
    NUM_EPOCHS = 30
    PATIENCE = 5  # Early stopping

    # Data augmentation
    USE_AUGMENTATION = True

    # Evaluation
    EVAL_TASK_A_EVERY = 2  # Evaluate Task A retention every N epochs
    # Presumably toggles keeping BatchNorm layers frozen during Task B training;
    # it is not read anywhere in the visible code — TODO confirm in the training function.
    FREEZE_BATCHNORM = True

    # Evaluated once when the class body executes.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ============================================================================
# DATA LOADING UTILITIES 
# ============================================================================
def load_task_paths(data_path):
    """
    Universal data loader for both OCT and Chest X-ray.
    Loads from: /path/to/train/CLASS_NAME/*.jpg (also .jpeg/.png, any letter case).

    Class indices are assigned in sorted order of the (non-hidden) class
    sub-directory names.

    Args:
        data_path: Root directory (str or Path) with one sub-directory per class.

    Returns:
        (all_paths, all_labels, class_names):
            all_paths   - list[Path] of image files, grouped by class and sorted
                          within each class.
            all_labels  - list[int] class index aligned with ``all_paths``.
            class_names - sorted list[str] of class directory names.

    Raises:
        FileNotFoundError: if ``data_path`` does not exist.
        ValueError: if ``data_path`` contains no class sub-directories.
    """
    image_suffixes = {'.jpeg', '.jpg', '.png'}
    data_path = Path(data_path)

    # Skip hidden directories (e.g. '.ipynb_checkpoints') so they never become a class.
    class_names = sorted(
        d.name for d in data_path.iterdir()
        if d.is_dir() and not d.name.startswith('.')
    )
    if not class_names:
        raise ValueError(f"No class sub-directories found in {data_path}")

    all_paths = []
    all_labels = []

    for idx, class_name in enumerate(class_names):
        class_dir = data_path / class_name
        # - suffix compared case-insensitively so '.JPEG' / '.PNG' are not dropped;
        # - hidden files (e.g. macOS '._*' AppleDouble metadata) are not images;
        # - sorted because directory listing order is filesystem-dependent, and
        #   train_test_split(random_state=...) is only reproducible on a fixed order.
        paths = sorted(
            p for p in class_dir.iterdir()
            if p.is_file()
            and not p.name.startswith('.')
            and p.suffix.lower() in image_suffixes
        )

        all_paths.extend(paths)
        all_labels.extend([idx] * len(paths))

    return all_paths, all_labels, class_names

# ============================================================================
# DATASET CLASS 
# ============================================================================
class ImageDataset(Dataset):
    """Map-style dataset over parallel sequences of image paths and labels.

    Each item is opened with PIL, converted to 3-channel RGB, and passed through
    ``transform`` when one was supplied. Items are ``(image, label)`` tuples.
    """

    def __init__(self, paths, labels, transform=None):
        # labels[i] is the class index of paths[i]
        self.paths = paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        from PIL import Image

        # convert() returns an independent copy, so the source handle can be closed
        with Image.open(self.paths[idx]) as source:
            image = source.convert('RGB')
        target = self.labels[idx]

        return (self.transform(image) if self.transform else image), target

# ============================================================================
# DATA SPLITS 
# ============================================================================
def create_task_a_splits():
    """Build the held-out Task A (OCT) test loader used to measure retention.

    Re-creates the 70/15/15 stratified split (``random_state=42``) over
    ``Config.TASK_A_DATA_PATH`` and keeps only the 15% test portion, with
    deterministic (non-augmented) preprocessing.

    Returns:
        (test_loader, task_a_class_names)
    """
    print("\nüìä Creating Task A (OCT) evaluation splits...")

    paths, labels, class_names = load_task_paths(Config.TASK_A_DATA_PATH)
    print(f"   Total Task A samples: {len(paths):,}")
    print(f"   Classes: {class_names}")
    print(f"   Class distribution: {dict(Counter(labels))}")

    # Step 1: carve off a 30% holdout; step 2: split that holdout evenly into val/test
    train_p, holdout_p, _train_y, holdout_y = train_test_split(
        paths, labels, test_size=0.30, stratify=labels, random_state=42
    )
    val_p, test_p, _val_y, test_y = train_test_split(
        holdout_p, holdout_y, test_size=0.50, stratify=holdout_y, random_state=42
    )
    print(f"   Train: {len(train_p):,} | Val: {len(val_p):,} | Test: {len(test_p):,}")

    # Resize + ImageNet mean/std normalisation, no augmentation
    eval_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    loader = DataLoader(
        ImageDataset(test_p, test_y, eval_transform),
        batch_size=Config.BATCH_SIZE,
        shuffle=False,
        num_workers=2,
    )
    return loader, class_names

def create_task_b_splits():
    """Build Task B (Chest X-ray) train/val/test loaders from a 70/15/15 stratified split.

    Returns:
        (train_loader, val_loader, test_loader, class_weights, task_b_class_names)
        ``class_weights`` is a float32 tensor on ``Config.device`` holding the
        inverse-frequency weights ``N / (K * n_c)`` computed on the training split.
        Only the training loader shuffles and (optionally) augments.
    """
    print("\n Creating Task B (Chest X-ray) splits...")

    paths, labels, class_names = load_task_paths(Config.TASK_B_DATA_PATH)
    print(f"   Total samples: {len(paths):,}")
    print(f"   Classes: {class_names}")
    print(f"   Class distribution: {dict(Counter(labels))}")

    # 30% holdout first, then split that holdout evenly into val and test
    train_p, holdout_p, train_y, holdout_y = train_test_split(
        paths, labels, test_size=0.30, stratify=labels, random_state=42
    )
    val_p, test_p, val_y, test_y = train_test_split(
        holdout_p, holdout_y, test_size=0.50, stratify=holdout_y, random_state=42
    )
    print(f"   Train: {len(train_p):,} | Val: {len(val_p):,} | Test: {len(test_p):,}")

    # Inverse-frequency weights from the training split only, to offset class imbalance
    per_class = Counter(train_y)
    n_train = len(train_y)
    n_present = len(per_class)
    weight_values = [n_train / (n_present * per_class[c]) for c in range(len(class_names))]
    class_weights = torch.tensor(weight_values, dtype=torch.float32).to(Config.device)
    print(f"   Class weights: {class_weights.cpu().numpy()}")

    def _tensor_tail():
        # ToTensor + ImageNet mean/std normalisation, freshly built for each pipeline
        return [
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]

    augmentations = [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
    ] if Config.USE_AUGMENTATION else []

    train_transform = transforms.Compose(
        [transforms.Resize((224, 224)), *augmentations, *_tensor_tail()]
    )
    eval_transform = transforms.Compose(
        [transforms.Resize((224, 224)), *_tensor_tail()]
    )

    def _make_loader(split_paths, split_labels, tfm, shuffle):
        return DataLoader(
            ImageDataset(split_paths, split_labels, tfm),
            batch_size=Config.BATCH_SIZE,
            shuffle=shuffle,
            num_workers=2,
        )

    train_loader = _make_loader(train_p, train_y, train_transform, True)
    val_loader = _make_loader(val_p, val_y, eval_transform, False)
    test_loader = _make_loader(test_p, test_y, eval_transform, False)

    return train_loader, val_loader, test_loader, class_weights, class_names

# ============================================================================
# MULTI-HEAD MODEL
# ============================================================================
class MultiHeadMobileNet(nn.Module):
    """MobileNetV3-Large feature extractor shared by two task-specific heads.

    ``head_a`` (Task A, OCT) has the same 4-layer layout that
    ``load_phase2_model`` maps the Phase 2 classifier weights into; ``head_b``
    (Task B, Chest X-ray) is newly initialised. Both heads take the 960-d
    globally average-pooled backbone embedding and return class logits.
    """

    def __init__(self, num_classes_a, num_classes_b):
        super().__init__()

        # Only the convolutional feature stack of torchvision's model is kept
        self.features = models.mobilenet_v3_large(weights=None).features
        self.head_a = self._make_head(num_classes_a)
        self.head_b = self._make_head(num_classes_b)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten()

    @staticmethod
    def _make_head(num_classes):
        """960 -> 256 -> num_classes MLP; Sequential indices 0..3 form the state-dict keys."""
        return nn.Sequential(
            nn.Linear(960, 256),
            nn.Hardswish(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes),
        )

    def forward(self, x, task='b'):
        """Return logits of the head selected by ``task`` ('a' or 'b')."""
        embedding = self.flatten(self.avgpool(self.features(x)))

        if task == 'a':
            head = self.head_a
        elif task == 'b':
            head = self.head_b
        else:
            raise ValueError(f"Unknown task: {task}")
        return head(embedding)

# ============================================================================
# LWF DISTILLATION LOSS
# ============================================================================
def distillation_loss(student_logits, teacher_logits, temperature):
    """
    Temperature-scaled knowledge-distillation loss.

    Args:
        student_logits: raw student logits; softmax is taken over dim 1.
        teacher_logits: raw teacher logits of the same shape, used as soft targets.
        temperature: softening temperature T applied to both sets of logits.

    Returns:
        Scalar tensor: KL(softmax(teacher/T) || softmax(student/T)) averaged over
        the batch (``reduction='batchmean'``) and multiplied by T**2, the usual
        Hinton-style scaling for distillation losses.
    """
    F = nn.functional
    log_p_student = F.log_softmax(student_logits / temperature, dim=1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=1)

    scale = temperature ** 2
    return scale * F.kl_div(log_p_student, p_teacher, reduction='batchmean')

# ============================================================================
# LOAD PHASE 2 MODEL 
# ============================================================================
def load_phase2_model():
    """Build a MultiHeadMobileNet initialised from the Phase 2 checkpoint.

    Checkpoint keys are remapped (a leading DataParallel ``module.`` prefix is
    stripped first):
        backbone.features.*   -> features.*
        backbone.classifier.* -> head_a.*
    All other keys are ignored; ``head_b`` keeps its fresh initialisation.

    Returns:
        MultiHeadMobileNet moved to ``Config.device``.

    Raises:
        FileNotFoundError: if ``Config.PHASE2_MODEL_PATH`` does not exist.
        RuntimeError: if any backbone / head_a tensor is absent from the remapped
            checkpoint. ``strict=False`` alone would silently leave those weights
            randomly initialised while still reporting success.
    """
    print("\n Loading Phase 2 model...")

    if not Path(Config.PHASE2_MODEL_PATH).exists():
        raise FileNotFoundError(f"Phase 2 model not found at {Config.PHASE2_MODEL_PATH}")

    checkpoint = torch.load(Config.PHASE2_MODEL_PATH, map_location=Config.device)

    # Create multi-head model
    model = MultiHeadMobileNet(Config.TASK_A_CLASSES, Config.TASK_B_CLASSES)

    # Handle both {'model_state_dict': ...} checkpoints and bare state dicts
    if 'model_state_dict' in checkpoint:
        phase2_state = checkpoint['model_state_dict']
    else:
        phase2_state = checkpoint

    prefix_map = (
        ('backbone.features.', 'features.'),
        ('backbone.classifier.', 'head_a.'),
    )
    model_state = {}
    for key, value in phase2_state.items():
        if key.startswith('module.'):
            key = key[len('module.'):]
        for old_prefix, new_prefix in prefix_map:
            if key.startswith(old_prefix):
                model_state[new_prefix + key[len(old_prefix):]] = value
                break

    # strict=False because head_b has no Phase 2 weights; verify everything else loaded.
    load_result = model.load_state_dict(model_state, strict=False)
    missing = [k for k in load_result.missing_keys if not k.startswith('head_b.')]
    if missing:
        raise RuntimeError(
            f"Phase 2 checkpoint did not provide {len(missing)} backbone/head_a "
            f"tensors (first few: {missing[:5]}); check the key mapping."
        )
    if load_result.unexpected_keys:
        print(f"   Warning: ignoring {len(load_result.unexpected_keys)} unexpected keys "
              f"(first few: {load_result.unexpected_keys[:5]})")

    model = model.to(Config.device)
    print("    Phase 2 model loaded")
    print(f"   Verification: head_a.3.weight shape = {model.head_a[3].weight.shape}")

    return model

# ============================================================================
# EVALUATION FUNCTIONS
# ============================================================================
def evaluate_task(model, dataloader, task, class_names):
    """Score ``model`` on every batch of ``dataloader`` through one task head.

    Args:
        model: multi-head model called as ``model(images, task=task)``.
        dataloader: yields ``(images, labels)`` batches.
        task: head selector forwarded to the model ('a' or 'b').
        class_names: unused here; kept so existing call sites keep working.

    Returns:
        (accuracy, weighted_f1, predictions, labels) where the last two are flat
        lists of per-sample class indices. Leaves the model in eval mode.
    """
    model.eval()
    predictions, targets = [], []

    with torch.no_grad():
        for batch_images, batch_labels in dataloader:
            logits = model(batch_images.to(Config.device), task=task)
            predictions.extend(logits.argmax(dim=1).cpu().numpy())
            targets.extend(batch_labels.numpy())

    accuracy = accuracy_score(targets, predictions)
    weighted_f1 = f1_score(targets, predictions, average='weighted')
    return accuracy, weighted_f1, predictions, targets

def print_evaluation_report(acc, f1, preds, labels, class_names, task_name):
    """Print a banner with accuracy / weighted F1, then sklearn's per-class report."""
    rule = '=' * 70
    header_lines = [
        f"\n{rule}",
        f"üìä {task_name} EVALUATION",
        rule,
        f"   Accuracy:  {acc*100:.2f}%",
        f"   F1-Score:  {f1:.4f}",
        "\nüìã Classification Report:",
    ]
    for line in header_lines:
        print(line)
    print(classification_report(labels, preds, target_names=class_names, digits=4))

# ============================================================================
# TRAINING FUNCTION WITH LWF
# ============================================================================
def _retention_pct(after, before):
    """Return ``after`` as a percentage of ``before`` (NaN when the baseline is 0)."""
    return (after / before) * 100 if before else float('nan')


def _pooled_embedding(model, images):
    """Run the shared backbone once and return the pooled, flattened embedding.

    Mirrors the shared part of ``MultiHeadMobileNet.forward``
    (features -> avgpool -> flatten) so both heads can reuse one backbone pass.
    """
    return model.flatten(model.avgpool(model.features(images)))


def _train_lwf_epoch(student_model, teacher_model, train_loader, optimizer, criterion, epoch):
    """Run one LwF training epoch on Task B and return the mean total loss per batch.

    total_loss = CE(head_b logits, labels)
                 + LWF_ALPHA * distillation_loss(student head_a, teacher head_a, T)
    """
    student_model.train()
    if Config.FREEZE_BATCHNORM:
        # eval() on the backbone makes BatchNorm use (and not update) the
        # running statistics loaded from Phase 2
        student_model.features.eval()
    # eval() only disables head_a's Dropout so it matches the eval-mode teacher
    student_model.head_a.eval()
    train_loss = 0.0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{Config.NUM_EPOCHS}")
    for images, labels in pbar:
        images, labels = images.to(Config.device), labels.to(Config.device)

        optimizer.zero_grad()

        # One backbone pass shared by both heads. Calling the model once per head
        # doubles backbone compute and, with unfrozen BN, updates running stats twice.
        embedding = _pooled_embedding(student_model, images)
        student_logits_b = student_model.head_b(embedding)   # new task
        student_logits_a = student_model.head_a(embedding)   # old task, should match teacher

        # Task B classification loss
        ce_loss = criterion(student_logits_b, labels)

        # Teacher's Task A predictions are the soft targets that preserve old knowledge
        with torch.no_grad():
            teacher_logits_a = teacher_model(images, task='a')

        distill_loss = distillation_loss(
            student_logits_a,
            teacher_logits_a,
            Config.LWF_TEMPERATURE
        )

        total_loss = ce_loss + Config.LWF_ALPHA * distill_loss

        total_loss.backward()
        optimizer.step()

        train_loss += total_loss.item()
        pbar.set_postfix({
            'loss': f'{total_loss.item():.4f}',
            'ce': f'{ce_loss.item():.4f}',
            'distill': f'{distill_loss.item():.4f}'
        })

    # max(..., 1) avoids ZeroDivisionError on an empty loader
    return train_loss / max(len(train_loader), 1)


def train_phase3_lwf():
    """Phase 3: continual learning of Task B (Chest X-ray) with Learning without Forgetting.

    Steps:
      1. Load the Phase 2 model as the student; deep-copy and freeze it as the teacher.
      2. Measure the Task A (OCT) baseline on its held-out test split.
      3. Train the backbone + head_b on Task B with CE + LWF_ALPHA * KD(head_a),
         early-stopping on Task B validation weighted F1.
      4. Reload the best checkpoint; report Task B test metrics, Task A retention,
         and save both confusion matrices.

    Returns:
        (student_model, history). ``history`` has per-epoch 'train_loss' and
        'val_f1', plus 'task_a_f1' with the matching 1-based epoch numbers in
        'task_a_epoch' (Task A is evaluated only every EVAL_TASK_A_EVERY epochs).
        The returned model's head_a parameters have requires_grad=False.

    Raises:
        RuntimeError: if no checkpoint was written during this run (e.g.
            NUM_EPOCHS == 0), instead of silently loading a stale file from an
            earlier run.
    """
    print("\n" + "="*70)
    print("üöÄ PHASE 3: CONTINUAL LEARNING WITH LWF")
    print("="*70)

    # parents=True so a nested SAVE_DIR does not fail
    Path(Config.SAVE_DIR).mkdir(parents=True, exist_ok=True)
    best_ckpt_path = f"{Config.SAVE_DIR}/phase3_lwf_best.pth"

    # Load Phase 2 model (student)
    student_model = load_phase2_model()

    # Create teacher model (frozen copy of Phase 2 model)
    print("\nüìö Creating teacher model (frozen copy)...")
    teacher_model = copy.deepcopy(student_model)
    teacher_model.eval()
    for param in teacher_model.parameters():
        param.requires_grad = False
    print("   Teacher model created and frozen")

    # head_a is not in the optimizer, so it is never updated; disabling its weight
    # gradients makes that explicit and skips computing them. Gradients still flow
    # *through* head_a into the backbone for the distillation loss.
    for param in student_model.head_a.parameters():
        param.requires_grad = False

    # Create Task A test loader for retention evaluation
    task_a_test_loader, task_a_classes = create_task_a_splits()

    # Evaluate Task A before fine-tuning (baseline)
    print("\nüß™ Evaluating Task A (OCT) BEFORE adapting for task B")
    task_a_acc_before, task_a_f1_before, _, _ = evaluate_task(
        student_model, task_a_test_loader, task='a', class_names=task_a_classes
    )
    print(f"   Task A Accuracy: {task_a_acc_before*100:.2f}%")
    print(f"   Task A F1: {task_a_f1_before:.4f}")

    # Create Task B dataloaders
    train_loader, val_loader, test_loader, class_weights, task_b_classes = create_task_b_splits()

    # Lower LR for the pretrained backbone than for the freshly initialised head_b
    optimizer = optim.Adam([
        {'params': student_model.features.parameters(), 'lr': 1e-5},
        {'params': student_model.head_b.parameters(), 'lr': 1e-4}
    ])

    criterion = nn.CrossEntropyLoss(weight=class_weights)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5,
                                                     patience=3)

    # -inf guarantees the first epoch is checkpointed even if its val F1 is 0
    best_val_f1 = float('-inf')
    best_epoch = None
    patience_counter = 0
    history = {'train_loss': [], 'val_f1': [], 'task_a_f1': [], 'task_a_epoch': []}

    print(f"\nüéØ Training Task B (Chest X-ray) with LwF (Œ±={Config.LWF_ALPHA}, T={Config.LWF_TEMPERATURE})...")

    for epoch in range(Config.NUM_EPOCHS):
        avg_train_loss = _train_lwf_epoch(student_model, teacher_model, train_loader,
                                          optimizer, criterion, epoch)
        history['train_loss'].append(avg_train_loss)

        # Validation on Task B
        val_acc, val_f1, _, _ = evaluate_task(student_model, val_loader, task='b',
                                             class_names=task_b_classes)
        history['val_f1'].append(val_f1)

        print(f"\n   Epoch {epoch+1} - Task B Val F1: {val_f1:.4f} | Acc: {val_acc*100:.2f}%")

        # Evaluate Task A retention periodically (a non-positive interval disables it)
        if Config.EVAL_TASK_A_EVERY > 0 and (epoch + 1) % Config.EVAL_TASK_A_EVERY == 0:
            task_a_acc, task_a_f1, _, _ = evaluate_task(student_model, task_a_test_loader,
                                                        task='a', class_names=task_a_classes)
            history['task_a_f1'].append(task_a_f1)
            history['task_a_epoch'].append(epoch + 1)
            retention = _retention_pct(task_a_f1, task_a_f1_before)
            print(f"   üìà Task A Retention: F1={task_a_f1:.4f} ({retention:.2f}% of baseline)")

        # Learning rate scheduling
        scheduler.step(val_f1)

        # Early stopping and checkpointing
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_epoch = epoch
            patience_counter = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': student_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Plain floats: sklearn returns NumPy scalars, which torch.load's
                # weights_only mode may refuse to unpickle.
                'val_f1': float(val_f1),
                'task_a_f1_before': float(task_a_f1_before)
            }, best_ckpt_path)
            print(f"   üíæ Best model saved (Val F1: {val_f1:.4f})")
        else:
            patience_counter += 1
            if patience_counter >= Config.PATIENCE:
                print(f"\n‚è∏Ô∏è  Early stopping triggered (patience={Config.PATIENCE})")
                break

    if best_epoch is None:
        raise RuntimeError(
            "No Phase 3 checkpoint was saved during this run (is NUM_EPOCHS >= 1?); "
            f"refusing to load a possibly stale {best_ckpt_path}"
        )

    # Final evaluation
    print("\n" + "="*70)
    print("üìä FINAL EVALUATION")
    print("="*70)

    # Load best model
    checkpoint = torch.load(best_ckpt_path, map_location=Config.device)
    student_model.load_state_dict(checkpoint['model_state_dict'])

    # Task B (Chest X-ray) - Test set
    task_b_acc, task_b_f1, task_b_preds, task_b_labels = evaluate_task(
        student_model, test_loader, task='b', class_names=task_b_classes
    )
    print_evaluation_report(task_b_acc, task_b_f1, task_b_preds, task_b_labels,
                          task_b_classes, "TASK B (Chest X-ray)")

    # Task A (OCT) - Retention test
    task_a_acc_after, task_a_f1_after, task_a_preds, task_a_labels = evaluate_task(
        student_model, task_a_test_loader, task='a', class_names=task_a_classes
    )
    print_evaluation_report(task_a_acc_after, task_a_f1_after, task_a_preds, task_a_labels,
                          task_a_classes, "TASK A (OCT) - Retention Check")

    # Retention metrics
    retention_f1 = _retention_pct(task_a_f1_after, task_a_f1_before)
    retention_acc = _retention_pct(task_a_acc_after, task_a_acc_before)

    print("\n" + "="*70)
    print("üéØ CONTINUAL LEARNING SUMMARY")
    print("="*70)
    print(f"üìä Task A (OCT) Retention:")
    print(f"   Before: F1={task_a_f1_before:.4f}, Acc={task_a_acc_before*100:.2f}%")
    print(f"   After:  F1={task_a_f1_after:.4f}, Acc={task_a_acc_after*100:.2f}%")
    print(f"   Retention: F1={retention_f1:.2f}%, Acc={retention_acc:.2f}%")
    print(f"\nüìä Task B (Chest X-ray) Performance:")
    print(f"   Test F1: {task_b_f1:.4f}")
    print(f"   Test Acc: {task_b_acc*100:.2f}%")
    print("="*70)

    # Save confusion matrices
    save_confusion_matrix(task_a_labels, task_a_preds, task_a_classes,
                         "Task A (OCT) - After LwF", f"{Config.SAVE_DIR}/cm_task_a.png")
    save_confusion_matrix(task_b_labels, task_b_preds, task_b_classes,
                         "Task B (Chest X-ray)", f"{Config.SAVE_DIR}/cm_task_b.png")

    return student_model, history

def save_confusion_matrix(labels, preds, class_names, title, save_path):
    """Render a labelled confusion-matrix heatmap and save it to ``save_path``.

    Args:
        labels: true class indices.
        preds: predicted class indices.
        class_names: display names for indices 0..len(class_names)-1, used as
            row (true) and column (predicted) tick labels in that order.
        title: figure title.
        save_path: output image path (written at 300 dpi).
    """
    # Pin the label set so the matrix is always len(class_names) square, even when
    # a class is absent from both labels and preds; otherwise sklearn returns a
    # smaller matrix and the tick labels no longer line up with rows/columns.
    cm = confusion_matrix(labels, preds, labels=list(range(len(class_names))))

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names, ax=ax)
        ax.set_title(title)
        ax.set_ylabel('True Label')
        ax.set_xlabel('Predicted Label')
        fig.tight_layout()
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Close this specific figure even if plotting/saving fails, so repeated
        # calls do not accumulate open figures.
        plt.close(fig)
    print(f"   ‚úÖ Confusion matrix saved: {save_path}")

# ============================================================================
# MAIN EXECUTION
# ============================================================================
if __name__ == "__main__":
    # Run the full Phase 3 LwF pipeline. Inside a notebook kernel __name__ is
    # "__main__", so this executes when the cell runs; the module-level names
    # `model` and `history` remain available to later cells.
    model, history = train_phase3_lwf()


üöÄ PHASE 3: CONTINUAL LEARNING WITH LWF

üìÇ Loading Phase 2 model...
   ‚úÖ Phase 2 model loaded
   üîç Verification: head_a.3.weight shape = torch.Size([4, 256])

üìö Creating teacher model (frozen copy)...
   ‚úÖ Teacher model created and frozen

üîí Freezing Task A head in student model...
   ‚úÖ Task A head frozen

üìä Creating Task A (OCT) evaluation splits...
   Total Task A samples: 83,484
   Classes: ['CNV', 'DME', 'DRUSEN', 'NORMAL']
   Class distribution: {0: 37205, 1: 11348, 2: 8616, 3: 26315}
   Train: 58,438 | Val: 12,523 | Test: 12,523

üß™ Evaluating Task A (OCT) BEFORE adapting for task B
   Task A Accuracy: 97.29%
   Task A F1: 0.9730

üìÇ Creating Task B (Chest X-ray) splits...
   Total samples: 5,232
   Classes: ['NORMAL', 'PNEUMONIA']
   Class distribution: {0: 1349, 1: 3883}
   Train: 3,662 | Val: 785 | Test: 785
   Class weights: [1.9396186 0.6736571]

üéØ Training Task B (Chest X-ray) with LwF (Œ±=2.0, T=2.0)...


Epoch 1/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:01<00:00,  2.13s/it, loss=0.4316, ce=0.4228, distill=0.0044]



   Epoch 1 - Task B Val F1: 0.8944 | Acc: 89.04%
   üíæ Best model saved (Val F1: 0.8944)


Epoch 2/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:02<00:00,  2.16s/it, loss=0.2819, ce=0.2616, distill=0.0101]



   Epoch 2 - Task B Val F1: 0.9088 | Acc: 90.57%
   üìà Task A Retention: F1=0.9740 (100.10% of baseline)
   üíæ Best model saved (Val F1: 0.9088)


Epoch 3/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:01<00:00,  2.13s/it, loss=0.2652, ce=0.2468, distill=0.0092]



   Epoch 3 - Task B Val F1: 0.9253 | Acc: 92.36%
   üíæ Best model saved (Val F1: 0.9253)


Epoch 4/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:01<00:00,  2.14s/it, loss=0.3476, ce=0.3295, distill=0.0091]



   Epoch 4 - Task B Val F1: 0.9242 | Acc: 92.23%
   üìà Task A Retention: F1=0.9739 (100.09% of baseline)


Epoch 5/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:02<00:00,  2.14s/it, loss=0.2468, ce=0.2261, distill=0.0104]



   Epoch 5 - Task B Val F1: 0.9092 | Acc: 90.57%


Epoch 6/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:01<00:00,  2.12s/it, loss=0.2248, ce=0.2076, distill=0.0086]



   Epoch 6 - Task B Val F1: 0.9400 | Acc: 93.89%
   üìà Task A Retention: F1=0.9733 (100.04% of baseline)
   üíæ Best model saved (Val F1: 0.9400)


Epoch 7/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:02<00:00,  2.14s/it, loss=0.3250, ce=0.3046, distill=0.0102]



   Epoch 7 - Task B Val F1: 0.9400 | Acc: 93.89%


Epoch 8/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:02<00:00,  2.15s/it, loss=0.2245, ce=0.1966, distill=0.0139]



   Epoch 8 - Task B Val F1: 0.9365 | Acc: 93.50%
   üìà Task A Retention: F1=0.9724 (99.94% of baseline)


Epoch 9/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:01<00:00,  2.11s/it, loss=0.2080, ce=0.1876, distill=0.0102]



   Epoch 9 - Task B Val F1: 0.9413 | Acc: 94.01%
   üíæ Best model saved (Val F1: 0.9413)


Epoch 10/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:01<00:00,  2.13s/it, loss=0.1650, ce=0.1428, distill=0.0111]



   Epoch 10 - Task B Val F1: 0.9391 | Acc: 93.76%
   üìà Task A Retention: F1=0.9716 (99.85% of baseline)


Epoch 11/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:01<00:00,  2.11s/it, loss=0.2320, ce=0.2108, distill=0.0106]



   Epoch 11 - Task B Val F1: 0.9271 | Acc: 92.48%


Epoch 12/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:01<00:00,  2.13s/it, loss=0.1662, ce=0.1421, distill=0.0120]



   Epoch 12 - Task B Val F1: 0.9485 | Acc: 94.78%
   üìà Task A Retention: F1=0.9682 (99.51% of baseline)
   üíæ Best model saved (Val F1: 0.9485)


Epoch 13/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:01<00:00,  2.13s/it, loss=0.1147, ce=0.1006, distill=0.0070]



   Epoch 13 - Task B Val F1: 0.9523 | Acc: 95.16%
   üíæ Best model saved (Val F1: 0.9523)


Epoch 14/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:01<00:00,  2.12s/it, loss=0.1811, ce=0.1629, distill=0.0091]



   Epoch 14 - Task B Val F1: 0.9522 | Acc: 95.16%
   üìà Task A Retention: F1=0.9652 (99.20% of baseline)


Epoch 15/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:01<00:00,  2.13s/it, loss=0.1271, ce=0.1102, distill=0.0085]



   Epoch 15 - Task B Val F1: 0.9451 | Acc: 94.39%


Epoch 16/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:01<00:00,  2.12s/it, loss=0.1447, ce=0.1327, distill=0.0060]



   Epoch 16 - Task B Val F1: 0.9547 | Acc: 95.41%
   üìà Task A Retention: F1=0.9641 (99.09% of baseline)
   üíæ Best model saved (Val F1: 0.9547)


Epoch 17/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:01<00:00,  2.12s/it, loss=0.1334, ce=0.1207, distill=0.0063]



   Epoch 17 - Task B Val F1: 0.9557 | Acc: 95.54%
   üíæ Best model saved (Val F1: 0.9557)


Epoch 18/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:01<00:00,  2.13s/it, loss=0.1018, ce=0.0826, distill=0.0096]



   Epoch 18 - Task B Val F1: 0.9570 | Acc: 95.67%
   üìà Task A Retention: F1=0.9622 (98.90% of baseline)
   üíæ Best model saved (Val F1: 0.9570)


Epoch 19/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:01<00:00,  2.11s/it, loss=0.1352, ce=0.1242, distill=0.0055]



   Epoch 19 - Task B Val F1: 0.9631 | Acc: 96.31%
   üíæ Best model saved (Val F1: 0.9631)


Epoch 20/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:01<00:00,  2.12s/it, loss=0.1018, ce=0.0921, distill=0.0048]



   Epoch 20 - Task B Val F1: 0.9631 | Acc: 96.31%
   üìà Task A Retention: F1=0.9620 (98.88% of baseline)


Epoch 21/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:02<00:00,  2.15s/it, loss=0.0909, ce=0.0735, distill=0.0087]



   Epoch 21 - Task B Val F1: 0.9560 | Acc: 95.54%


Epoch 22/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:00<00:00,  2.10s/it, loss=0.1112, ce=0.1005, distill=0.0053]



   Epoch 22 - Task B Val F1: 0.9548 | Acc: 95.41%
   üìà Task A Retention: F1=0.9629 (98.96% of baseline)


Epoch 23/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:02<00:00,  2.15s/it, loss=0.0951, ce=0.0854, distill=0.0049]



   Epoch 23 - Task B Val F1: 0.9631 | Acc: 96.31%


Epoch 24/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:02<00:00,  2.17s/it, loss=0.0757, ce=0.0674, distill=0.0041]



   Epoch 24 - Task B Val F1: 0.9657 | Acc: 96.56%
   üìà Task A Retention: F1=0.9610 (98.77% of baseline)
   üíæ Best model saved (Val F1: 0.9657)


Epoch 25/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:01<00:00,  2.11s/it, loss=0.0992, ce=0.0895, distill=0.0049]



   Epoch 25 - Task B Val F1: 0.9631 | Acc: 96.31%


Epoch 26/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:01<00:00,  2.13s/it, loss=0.0932, ce=0.0794, distill=0.0069]



   Epoch 26 - Task B Val F1: 0.9644 | Acc: 96.43%
   üìà Task A Retention: F1=0.9618 (98.85% of baseline)


Epoch 27/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:00<00:00,  2.10s/it, loss=0.0748, ce=0.0593, distill=0.0077]



   Epoch 27 - Task B Val F1: 0.9644 | Acc: 96.43%


Epoch 28/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:01<00:00,  2.12s/it, loss=0.0515, ce=0.0424, distill=0.0046]



   Epoch 28 - Task B Val F1: 0.9644 | Acc: 96.43%
   üìà Task A Retention: F1=0.9625 (98.92% of baseline)


Epoch 29/30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:01<00:00,  2.13s/it, loss=0.0715, ce=0.0633, distill=0.0041]



   Epoch 29 - Task B Val F1: 0.9644 | Acc: 96.43%

‚è∏Ô∏è  Early stopping triggered (patience=5)

üìä FINAL EVALUATION

üìä TASK B (Chest X-ray) EVALUATION
   Accuracy:  97.07%
   F1-Score:  0.9709

üìã Classification Report:
              precision    recall  f1-score   support

      NORMAL     0.9206    0.9704    0.9448       203
   PNEUMONIA     0.9895    0.9708    0.9801       582

    accuracy                         0.9707       785
   macro avg     0.9550    0.9706    0.9624       785
weighted avg     0.9717    0.9707    0.9709       785


üìä TASK A (OCT) - Retention Check EVALUATION
   Accuracy:  96.06%
   F1-Score:  0.9610

üìã Classification Report:
              precision    recall  f1-score   support

         CNV     0.9896    0.9525    0.9707      5581
         DME     0.8858    0.9706    0.9263      1702
      DRUSEN     0.8980    0.9258    0.9117      1293
      NORMAL     0.9777    0.9792    0.9785      3947

    accuracy                         0.9606     1252

Freezing BatchNorm statistics during Task B training achieved 98.77% F1 retention on Task A (F1 0.9730 → 0.9610; accuracy 97.29% → 96.06%), nearly eliminating catastrophic forgetting while reaching 97.07% Task B test accuracy. All Task A classes retained balanced performance (93–98% recall), contrasting sharply with the severe class collapse observed without BN freezing. The minimal trade-off (Task B performance 2.55 percentage points lower than with non-frozen BN) demonstrates that LwF's knowledge distillation is highly effective when normalization drift is controlled. Compared with the EWC results, both methods achieve similar retention (~99%) with BN freezing, indicating that BatchNorm drift accounts for ~37 percentage points of forgetting regardless of the continual-learning algorithm. This supports the conclusion that modern architectures require dual protection: algorithmic mechanisms (EWC/LwF) against weight/prediction-level forgetting, and architectural modifications (BN freezing) against distribution-level forgetting.