In [None]:
# Download both datasets from Google Drive with gdown (--fuzzy accepts the share-link URL form):
#   1) OCT2017.tar.gz    - OCT images (Task A)
#   2) ChestXRay2017.zip - chest X-ray images (Task B)
!gdown --fuzzy "https://drive.google.com/file/d/1cIGCfx6CiVgEpq8PyKzmF1LBJiQGkxzc/view?usp=sharing"
!gdown --fuzzy "https://drive.google.com/file/d/1JobiELb-4mO_Gk3NY6eyIz-3oRw3U2zT/view?usp=sharing"


Downloading...
From (original): https://drive.google.com/uc?id=1cIGCfx6CiVgEpq8PyKzmF1LBJiQGkxzc
From (redirected): https://drive.google.com/uc?id=1cIGCfx6CiVgEpq8PyKzmF1LBJiQGkxzc&confirm=t&uuid=ad27cd07-e175-4b85-aa80-5586c716b440
To: /content/OCT2017.tar.gz
100% 5.79G/5.79G [01:16<00:00, 75.5MB/s]
Downloading...
From (original): https://drive.google.com/uc?id=1JobiELb-4mO_Gk3NY6eyIz-3oRw3U2zT
From (redirected): https://drive.google.com/uc?id=1JobiELb-4mO_Gk3NY6eyIz-3oRw3U2zT&confirm=t&uuid=1ea80a4e-3e73-4b23-bacd-621f3b24e99b
To: /content/ChestXRay2017.zip
100% 1.24G/1.24G [00:17<00:00, 69.6MB/s]


In [None]:
# Extract both archives into /content/data.
# The target directory must exist first: `tar -C <dir>` aborts when <dir> is missing.
!mkdir -p /content/data
!tar -xzf "/content/OCT2017.tar.gz" -C /content/data/
!unzip -q /content/ChestXRay2017.zip -d /content/data

In [None]:
# Download the distilled student model (feature-based knowledge distillation checkpoint)
!gdown --fuzzy "https://drive.google.com/file/d/1Gm741jjzGMLXcbDSBPYpUEtcDWRtcMFl/view?usp=drive_link"

Downloading...
From (original): https://drive.google.com/uc?id=1Gm741jjzGMLXcbDSBPYpUEtcDWRtcMFl
From (redirected): https://drive.google.com/uc?id=1Gm741jjzGMLXcbDSBPYpUEtcDWRtcMFl&confirm=t&uuid=baafdd3e-619f-4305-96c8-9065ba652646
To: /content/best_mobilenetv3_student_feature_kd.pth
100% 76.0M/76.0M [00:01<00:00, 39.3MB/s]


# **FISHER MATRIX**

In [None]:
# =================================================================================
# FISHER INFORMATION COMPUTATION [Kirkpatrick et al. (2017)]
# =================================================================================

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torchvision.models as models
from pathlib import Path
from tqdm.auto import tqdm
from PIL import Image
from sklearn.model_selection import train_test_split
import numpy as np

# =================================================================================
# CONFIGURATION
# =================================================================================
class FisherConfig:
    """Static settings for the Fisher-information pass (paths, loader options, device)."""

    # --- files -------------------------------------------------------------
    PHASE2_MODEL_PATH = '/content/best_mobilenetv3_student_feature_kd.pth'
    OCT_DATA_PATH = '/content/data/OCT2017'
    OUTPUT_PATH = '/content/fisher/fisher_phase2.pth'

    # --- data loading ------------------------------------------------------
    BATCH_SIZE = 256   # large batches give smoother gradient estimates
    NUM_WORKERS = 2
    RANDOM_SEED = 42   # meant to match the seed used for the Phase 2 split

    # --- hardware ----------------------------------------------------------
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


print(" Using device: {}".format(FisherConfig.device))

# =================================================================================
# DATASET CLASS
# =================================================================================
class MultiClassOCTDataset(Dataset):
    """Map-style dataset that yields ``(image, label)`` pairs from parallel lists.

    Args:
        image_paths: sequence of image file locations, one per sample.
        labels: sequence of labels aligned with ``image_paths``.
        class_names: class names; stored on the instance, not used for indexing.
        transform: optional callable applied to the RGB PIL image.
    """

    def __init__(self, image_paths, labels, class_names, transform=None):
        self.image_paths, self.labels = image_paths, labels
        self.class_names, self.transform = class_names, transform

    def __len__(self):
        # One sample per stored path.
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Always decode to 3-channel RGB so every sample has the same mode.
        rgb = Image.open(self.image_paths[idx]).convert('RGB')
        sample = self.transform(rgb) if self.transform else rgb
        return sample, self.labels[idx]

# =================================================================================
# DATA LOADING
# =================================================================================
def load_oct_paths(root_dir, split='train'):
    """Collect OCT image paths and integer labels for one dataset split.

    Looks under ``<root_dir>/<split>/<CLASS>/`` for ``*.jpeg`` and ``*.jpg``
    files. A class folder that does not exist is skipped silently.

    Args:
        root_dir: dataset root directory.
        split: sub-folder of ``root_dir`` to scan.

    Returns:
        ``(image_paths, labels, class_names)`` where paths are strings and each
        label is the index of the class in ``class_names``.

    NOTE(review): files are returned in glob order, which pathlib does not
    guarantee; seeded splits made from this list are only reproducible while
    that order is stable.
    """
    split_root = Path(root_dir) / split
    class_names = ['CNV', 'DME', 'DRUSEN', 'NORMAL']
    image_paths = []
    labels = []

    for label, name in enumerate(class_names):
        folder = split_root / name
        if not folder.exists():
            continue
        # Same pattern order as before: all *.jpeg first, then all *.jpg.
        found = [str(hit) for pattern in ('*.jpeg', '*.jpg') for hit in folder.glob(pattern)]
        image_paths += found
        labels += [label] * len(found)

    return image_paths, labels, class_names

def create_training_dataloader():
    """Build the DataLoader over the 70% training portion of the OCT train folder.

    The split is stratified and seeded with ``FisherConfig.RANDOM_SEED``; only
    the training part is kept. Images are resized to 224x224 and normalised,
    with no augmentation, and the loader does not shuffle.

    Returns:
        ``(train_loader, class_names)``.
    """
    print("\n Loading OCT training data...")

    # Every image under <OCT_DATA_PATH>/train
    all_paths, all_labels, class_names = load_oct_paths(FisherConfig.OCT_DATA_PATH, 'train')
    print(f"   Total samples: {len(all_paths):,}")

    # Stratified 70/30 split with the configured seed; the 30% part is discarded.
    split_parts = train_test_split(
        all_paths,
        all_labels,
        test_size=0.30,
        stratify=all_labels,
        random_state=FisherConfig.RANDOM_SEED,
    )
    train_paths, train_labels = split_parts[0], split_parts[2]

    print(f"Using 70% training split: {len(train_paths):,} samples")

    # Deterministic preprocessing only
    preprocessing = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    train_transform = transforms.Compose(preprocessing)

    train_dataset = MultiClassOCTDataset(train_paths, train_labels, class_names, train_transform)
    loader_options = dict(
        batch_size=FisherConfig.BATCH_SIZE,
        shuffle=False,
        num_workers=FisherConfig.NUM_WORKERS,
        pin_memory=True,
    )
    train_loader = DataLoader(train_dataset, **loader_options)

    return train_loader, class_names

# =================================================================================
# MODEL DEFINITION
# =================================================================================
class MobileNetV3Student(nn.Module):
    """MobileNetV3-Large student with a replaced classification head.

    The stock classifier is swapped for Linear -> Hardswish -> Dropout(0.2) ->
    Linear(num_classes). A ``feature_projector`` MLP is also registered; it is
    not used by ``forward`` (presumably kept so that Phase 2 checkpoints, which
    contain its weights, load with strict key matching - confirm against the
    Phase 2 training code).

    Args:
        num_classes: size of the output layer.
        pretrained: when True, start the backbone from ImageNet weights.
    """

    def __init__(self, num_classes=4, pretrained=True):
        super().__init__()
        # The boolean `pretrained=` argument of torchvision model builders is
        # deprecated in favour of `weights=`; IMAGENET1K_V1 is what the legacy
        # `pretrained=True` resolved to for this architecture.
        weights = models.MobileNet_V3_Large_Weights.IMAGENET1K_V1 if pretrained else None
        self.backbone = models.mobilenet_v3_large(weights=weights)

        # Input width of the original classifier = pooled backbone feature size
        in_features = self.backbone.classifier[0].in_features
        self.backbone.classifier = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.Hardswish(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes)
        )

        # (from Phase 2) projection head; registered but unused in forward()
        self.feature_dim = in_features
        self.feature_projector = nn.Sequential(
            nn.Linear(self.feature_dim, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 2048)
        )

    def forward(self, x):
        """Return the class logits produced by the backbone and its new head."""
        return self.backbone(x)

def load_phase2_model():
    """Restore the Phase 2 student from ``FisherConfig.PHASE2_MODEL_PATH``.

    Returns:
        ``(student, checkpoint)``: the model on ``FisherConfig.device`` with the
        checkpoint's ``'model_state_dict'`` loaded, and the raw checkpoint dict.
    """
    print("\n Loading Phase 2 model...")

    target_device = FisherConfig.device
    # weights_only=True restricts unpickling to tensors and plain containers.
    checkpoint = torch.load(
        FisherConfig.PHASE2_MODEL_PATH,
        map_location=target_device,
        weights_only=True,
    )

    # No ImageNet download needed: every weight comes from the checkpoint.
    student = MobileNetV3Student(num_classes=4, pretrained=False)
    student = student.to(target_device)
    student.load_state_dict(checkpoint['model_state_dict'])

    print(" Model loaded")
    reported_f1 = checkpoint.get('test_f1', 'N/A')
    print(f"   Test F1: {reported_f1}")

    return student, checkpoint

# =================================================================================
# FISHER INFORMATION COMPUTATION
# =================================================================================
def compute_fisher_information(model, dataloader, normalize=True):
    """
    Estimate the diagonal empirical Fisher Information of ``model``.
    Based on Kirkpatrick et al. (2017) - "Overcoming Catastrophic Forgetting"

    For every batch the mean cross-entropy loss against the true labels is
    back-propagated and the squared gradient of each trainable parameter is
    accumulated on the CPU. Parameters that receive no gradient keep an
    all-zero entry.

    Args:
        model: network to analyse; it is put in train mode, as in the original
            procedure.
        dataloader: yields ``(inputs, labels)`` batches.
        normalize: when True (default) the accumulated squares are divided by
            the number of batches, giving an average instead of a sum that grows
            with dataset size. Pass False to get the raw sums.

    Returns:
        dict mapping parameter name -> CPU tensor shaped like the parameter.

    Raises:
        ValueError: if ``dataloader`` yields no batches.

    Note:
        Normalised values are ``1 / len(dataloader)`` times the raw sums, so an
        EWC lambda tuned against raw sums must be rescaled accordingly.
    """
    print(f"\nüßÆ Computing Fisher Information Matrix...")
    print(f"   Samples: {len(dataloader.dataset):,}")
    print(f"   Batches: {len(dataloader)}")

    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("Cannot compute Fisher information: the dataloader yields no batches")

    model.train()

    # Train-mode forward passes can update module buffers (e.g. normalisation
    # running statistics). Snapshot them so the pass leaves the model, whose
    # state_dict is saved afterwards as the reference weights, unchanged.
    buffer_snapshot = {name: buf.detach().clone() for name, buf in model.named_buffers()}

    # One zero tensor per trainable parameter, kept on the CPU
    fisher = {
        name: torch.zeros_like(param, device='cpu')
        for name, param in model.named_parameters()
        if param.requires_grad
    }

    samples_processed = 0
    criterion = nn.CrossEntropyLoss()

    pbar = tqdm(dataloader, desc="Computing Fisher", total=num_batches)

    try:
        for inputs, labels in pbar:
            inputs, labels = inputs.to(FisherConfig.device), labels.to(FisherConfig.device)

            model.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()

            # Accumulate squared gradients (no batch size multiplication)
            for name, param in model.named_parameters():
                if param.requires_grad and param.grad is not None:
                    fisher[name] += param.grad.detach().cpu().pow(2)

            samples_processed += inputs.size(0)
            pbar.set_postfix({'samples': samples_processed})

            del outputs, loss
            torch.cuda.empty_cache()
    finally:
        # Put the buffers back even if the loop was interrupted.
        with torch.no_grad():
            for name, buf in model.named_buffers():
                buf.copy_(buffer_snapshot[name])

    # Normalize: average over batches so the scale does not depend on dataset size
    if normalize:
        for name in fisher:
            fisher[name] /= num_batches

    return fisher

def analyze_fisher_statistics(fisher_dict):
    """Print a summary of the Fisher values and return per-parameter statistics.

    Args:
        fisher_dict: mapping of parameter name -> tensor of Fisher values.

    Returns:
        dict mapping each parameter name to a dict with the keys
        ``'mean'``, ``'std'``, ``'min'``, ``'max'`` and ``'median'`` (floats).
    """
    rule = "=" * 70
    print("\n" + rule)
    print(" FISHER INFORMATION STATISTICS")
    print(rule)

    def _describe(values):
        # Scalar summary of a single Fisher tensor.
        return {
            'mean': values.mean().item(),
            'std': values.std().item(),
            'min': values.min().item(),
            'max': values.max().item(),
            'median': values.median().item(),
        }

    fisher_stats = {name: _describe(values) for name, values in fisher_dict.items()}

    def _show_mean_max(name):
        # Shared "Mean / Max" line used by both per-layer sections below.
        entry = fisher_stats[name]
        print(f"      Mean: {entry['mean']:.8f}, Max: {entry['max']:.6f}")

    # Aggregates across all tracked parameters
    all_means = [entry['mean'] for entry in fisher_stats.values()]
    all_maxs = [entry['max'] for entry in fisher_stats.values()]

    print(f"\nüìä OVERALL FISHER STATISTICS:")
    print(f"   Parameters tracked: {len(fisher_dict)}")
    print(f"   Mean of means:   {np.mean(all_means):.8f}")
    print(f"   Std of means:    {np.std(all_means):.8f}")
    print(f"   Mean of maxs:    {np.mean(all_maxs):.6f}")
    print(f"   Global max:      {max(all_maxs):.6f}")

    # Heuristic scale check that drives the suggested EWC lambda
    mean_of_means = np.mean(all_means)
    print(f"\n QUALITY CHECK:")
    print(f"   Target range: mean=0.0001-0.01, max=0.01-1.0")

    if mean_of_means < 0.00001:
        verdict = (
            f"    WARNING: Fisher values very small (mean={mean_of_means:.10f})",
            f"   ‚Üí This may require higher EWC lambda (>10000)",
        )
    elif mean_of_means > 0.1:
        verdict = (
            f"   WARNING: Fisher values very large (mean={mean_of_means:.6f})",
            f"   ‚Üí This may require lower EWC lambda (<1000)",
        )
    else:
        verdict = (
            f"   Fisher values in reasonable range",
            f"   ‚Üí Recommended EWC lambda: 1000-5000",
        )
    for line in verdict:
        print(line)

    print(f"\nüìà SAMPLE STATISTICS BY LAYER TYPE:")

    # Group parameter names by substring, in dictionary order
    backbone_keys = [name for name in fisher_dict if 'features' in name]
    classifier_keys = [name for name in fisher_dict if 'classifier' in name]

    print(f"\n  Backbone parameters: {len(backbone_keys)}")
    if backbone_keys:
        # Three representative entries: first, middle and last
        positions = (('First', 0), ('Middle', len(backbone_keys) // 2), ('Last', -1))
        for desc, position in positions:
            name = backbone_keys[position]
            print(f"    {desc} layer '{name}':")
            _show_mean_max(name)

    print(f"\n  Classifier parameters: {len(classifier_keys)}")
    for name in classifier_keys:
        print(f"    '{name}':")
        _show_mean_max(name)

    print(rule)

    return fisher_stats

def print_fisher_structure(fisher_dict, verbose=False):
    """Print Fisher parameter structure for Phase 3 mapping.

    Keys containing ``'features'`` are counted as backbone parameters, keys
    containing ``'classifier'`` as classifier parameters; other keys are only
    included in the total.

    Args:
        fisher_dict: mapping of parameter name -> Fisher tensor.
        verbose: when True, also list every key with its shape. Off by default
            because the full listing is long.

    Returns:
        None; everything is written to stdout.
    """
    print(f"\n" + "="*70)
    print("üîç FISHER KEY STRUCTURE (FOR PHASE 3 MAPPING)")
    print("="*70)

    print(f"\nüìã ALL FISHER PARAMETER KEYS ({len(fisher_dict)} total):")
    print("Format: 'layer_path.weight/bias' -> shape")

    backbone_keys = []
    classifier_keys = []

    # Classify every key; previously this loop was disabled together with its
    # print statement, which left both lists empty and the summary at zero.
    for key in sorted(fisher_dict.keys()):
        if verbose:
            print(f"  {key} -> shape: {fisher_dict[key].shape}")
        if 'features' in key:
            backbone_keys.append(key)
        elif 'classifier' in key:
            classifier_keys.append(key)

    print(f"\nüìä SUMMARY FOR PHASE 3:")
    print(f"  Backbone parameters (features): {len(backbone_keys)}")
    print(f"  Classifier parameters: {len(classifier_keys)}")

    print(f"\nüéØ CLASSIFIER LAYER INDICES (CRITICAL FOR PHASE 3):")
    for key in classifier_keys:
        # Expected key layout: '<module>.classifier.<layer index>.<weight|bias>'
        parts = key.split('.')
        layer_idx = parts[2] if len(parts) > 2 else "?"
        param_type = parts[3] if len(parts) > 3 else "?"
        print(f"  {key}")
        print(f"    -> Layer index: {layer_idx}, Type: {param_type}, Shape: {fisher_dict[key].shape}")

    print("="*70)

# =================================================================================
# SAVE FISHER INFORMATION
# =================================================================================
def save_fisher_checkpoint(model, fisher_dict, fisher_stats, class_names, checkpoint):
    """Save Fisher information with Phase 3 compatibility.

    Writes one file to ``FisherConfig.OUTPUT_PATH`` containing the model
    weights, the Fisher tensors, their statistics, the class names and a
    description of how the key names map to the Phase 3 model.

    Args:
        model: model whose ``state_dict()`` is stored.
        fisher_dict: parameter name -> Fisher tensor.
        fisher_stats: per-parameter statistics dictionary.
        class_names: class names stored alongside the weights.
        checkpoint: source checkpoint dict; only its optional ``'test_f1'``
            entry is copied.
    """
    print("\nüíæ Saving Fisher Information for Phase 3...")

    # Phase 3 mapping information
    # NOTE(review): layer sizes are hard-coded (960 -> 256 -> 4); they are not
    # read from `model`, so keep them in sync with the architecture.
    phase3_mapping_info = {
        'backbone_prefix': 'backbone.features',
        'classifier_prefix': 'backbone.classifier',
        'classifier_layers': {
            '0': {'type': 'Linear', 'in_features': 960, 'out_features': 256},
            '3': {'type': 'Linear', 'in_features': 256, 'out_features': 4}
        }
    }

    save_checkpoint = {
        'model_state_dict': model.state_dict(),
        'fisher_information': fisher_dict,
        'fisher_statistics': fisher_stats,
        'phase3_mapping': phase3_mapping_info,
        'class_names': class_names,
        'test_f1': checkpoint.get('test_f1', 'N/A'),
        'computation_method': 'academic_standard_train_mode'
    }

    # torch.save does not create missing parent directories; without this the
    # write fails on a fresh runtime where the output folder does not exist yet.
    Path(FisherConfig.OUTPUT_PATH).parent.mkdir(parents=True, exist_ok=True)
    torch.save(save_checkpoint, FisherConfig.OUTPUT_PATH)

    print(f"Saved to: {FisherConfig.OUTPUT_PATH}")
    print(f"\nüéØ FOR PHASE 3:")
    print(f"   Backbone params: Look for keys starting with 'backbone.features'")
    print(f"   Classifier params: Map 'backbone.classifier.x' to 'head_a.x'")
    print(f"   Layer 0 -> head_a.0 (Linear 960->256)")
    print(f"   Layer 3 -> head_a.3 (Linear 256->4)")

# =================================================================================
# MAIN EXECUTION
# =================================================================================
def main():
    """Run the whole Fisher pipeline: load, compute, report, save."""
    rule = "=" * 70

    print("\n" + rule)
    print("üìö COMPUTING FISHER INFORMATION (ACADEMIC STANDARD)")
    print(rule)

    # 1) Phase 2 student and its raw checkpoint
    student, phase2_checkpoint = load_phase2_model()

    # 2) Loader over the 70% training split
    loader, class_names = create_training_dataloader()

    # 3) Fisher estimate per parameter
    fisher = compute_fisher_information(student, loader)

    # 4) Summary statistics
    stats = analyze_fisher_statistics(fisher)

    # 5) Key layout needed for the Phase 3 mapping
    print_fisher_structure(fisher)

    # 6) Persist everything in one checkpoint
    save_fisher_checkpoint(student, fisher, stats, class_names, phase2_checkpoint)

    print("\n" + rule)
    print("‚úÖ FISHER COMPUTATION COMPLETE - READY FOR PHASE 3")
    print(rule)

# =================================================================================
# RUN
# =================================================================================
if __name__ == "__main__":
    main()


üöÄ Using device: cuda

üìö COMPUTING FISHER INFORMATION (ACADEMIC STANDARD)

üìÇ Loading Phase 2 model...




‚úÖ Model loaded
   Test F1: N/A

üìÇ Loading OCT training data...
   Total samples: 83,484
   Using 70% training split: 58,438 samples

üßÆ Computing Fisher Information Matrix...
   Samples: 58,438
   Batches: 229


Computing Fisher:   0%|          | 0/229 [00:00<?, ?it/s]


üîç FISHER INFORMATION STATISTICS

üìä OVERALL FISHER STATISTICS:
   Parameters tracked: 178
   Mean of means:   0.00190823
   Std of means:    0.00538226
   Mean of maxs:    0.031993
   Global max:      2.002934

üéØ QUALITY CHECK:
   Target range: mean=0.0001-0.01, max=0.01-1.0
   ‚úÖ Fisher values in reasonable range
   ‚Üí Recommended EWC lambda: 1000-5000

üìà SAMPLE STATISTICS BY LAYER TYPE:

  Backbone parameters: 170
    First layer 'backbone.features.0.0.weight':
      Mean: 0.01350438, Max: 0.080279
    Middle layer 'backbone.features.9.block.0.1.weight':
      Mean: 0.00030469, Max: 0.002669
    Last layer 'backbone.features.16.1.bias':
      Mean: 0.00000952, Max: 0.000213

  Classifier parameters: 4
    'backbone.classifier.0.weight':
      Mean: 0.00000144, Max: 0.000236
    'backbone.classifier.0.bias':
      Mean: 0.00004006, Max: 0.000221
    'backbone.classifier.3.weight':
      Mean: 0.01394666, Max: 0.218337
    'backbone.classifier.3.bias':
      Mean: 0.01685

# **Continual Learning using EWC**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.models as models
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, f1_score, accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
from tqdm import tqdm
import numpy as np
from pathlib import Path

# ============================================================================
# CONFIGURATION
# ============================================================================
class Config:
    """Settings for Phase 3 (continual learning on Task B with an EWC penalty)."""

    # Paths
    TASK_A_DATA_PATH = "/content/data/OCT2017/train"  # OCT images folder
    TASK_B_DATA_PATH = "/content/data/chest_xray/train"  # Chest X-ray images folder
    # Fixed: this must be the checkpoint file the download cell actually writes
    # (the same file FisherConfig.PHASE2_MODEL_PATH points to); the previous
    # name "best_mobilenetv3_student_kd.pth" is never downloaded.
    PHASE2_MODEL_PATH = "/content/best_mobilenetv3_student_feature_kd.pth"
    FISHER_PATH = "/content/fisher/fisher_phase2.pth"
    SAVE_DIR = "/content/phase3_results"

    # Model settings
    TASK_A_CLASSES = 4  # OCT classes
    TASK_B_CLASSES = 2  # Chest X-ray classes

    # EWC hyperparameters
    EWC_LAMBDA = 5000  # EWC regularization strength

    # Training hyperparameters
    BATCH_SIZE = 128
    LEARNING_RATE = 0.0001
    NUM_EPOCHS = 15
    PATIENCE = 5  # Early stopping

    # Data augmentation
    USE_AUGMENTATION = True

    # Evaluation
    EVAL_TASK_A_EVERY = 2  # Evaluate Task A retention every N epochs

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ============================================================================
# DATA LOADING UTILITIES
# ============================================================================
def load_task_paths(data_path):
    """Gather image paths and labels from a ``<data_path>/<CLASS_NAME>/`` layout.

    Every sub-directory of ``data_path`` is one class; classes are numbered in
    sorted name order. Files matching ``*.jpeg``, ``*.jpg`` and ``*.png`` are
    collected, in that pattern order, for each class.

    Args:
        data_path: directory that contains one folder per class.

    Returns:
        ``(all_paths, all_labels, class_names)``; paths are ``Path`` objects.
    """
    root = Path(data_path)

    # Class index = position in the alphabetically sorted folder names
    class_names = sorted(entry.name for entry in root.iterdir() if entry.is_dir())

    all_paths, all_labels = [], []
    for label, name in enumerate(class_names):
        folder = root / name
        matches = []
        for pattern in ('*.jpeg', '*.jpg', '*.png'):
            matches.extend(folder.glob(pattern))

        all_paths += matches
        all_labels += [label] * len(matches)

    return all_paths, all_labels, class_names

# ============================================================================
# DATASET CLASS
# ============================================================================
class ImageDataset(Dataset):
    """Map-style dataset returning ``(image, label)`` for parallel path/label lists.

    Args:
        paths: sequence of image file locations.
        labels: sequence of labels aligned with ``paths``.
        transform: optional callable applied to the RGB PIL image.
    """

    def __init__(self, paths, labels, transform=None):
        self.paths, self.labels, self.transform = paths, labels, transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # PIL is imported lazily here, exactly as in the original cell.
        from PIL import Image

        label = self.labels[idx]
        picture = Image.open(self.paths[idx]).convert('RGB')
        if self.transform:
            picture = self.transform(picture)

        return picture, label

# ============================================================================
# DATA SPLITS
# ============================================================================
def create_task_a_splits():
    """Build the Task A (OCT) test loader used to measure retention.

    The data under ``Config.TASK_A_DATA_PATH`` is split 70/15/15 (stratified,
    seed 42); only the final 15% test part is wrapped in a DataLoader.

    Returns:
        ``(test_loader, task_a_class_names)``.
    """
    print("\nüìä Creating Task A (OCT) evaluation splits...")

    paths, targets, task_a_class_names = load_task_paths(Config.TASK_A_DATA_PATH)
    print(f"   Total Task A samples: {len(paths):,}")
    print(f"   Classes: {task_a_class_names}")

    # How many samples each class index has
    print(f"   Class distribution: {dict(Counter(targets))}")

    # First cut: 70% train vs. 30% hold-out
    train_paths, holdout_paths, train_labels, holdout_labels = train_test_split(
        paths, targets, test_size=0.30, stratify=targets, random_state=42
    )
    # Second cut: hold-out halved into validation and test
    val_paths, test_paths, val_labels, test_labels = train_test_split(
        holdout_paths, holdout_labels, test_size=0.50, stratify=holdout_labels, random_state=42
    )

    print(f"   Train: {len(train_paths):,} | Val: {len(val_paths):,} | Test: {len(test_paths):,}")

    # Deterministic preprocessing only
    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    test_loader = DataLoader(
        ImageDataset(test_paths, test_labels, test_transform),
        batch_size=Config.BATCH_SIZE,
        shuffle=False,
        num_workers=2,
    )

    return test_loader, task_a_class_names

def create_task_b_splits():
    """Build train/val/test loaders and class weights for Task B (Chest X-ray).

    The data under ``Config.TASK_B_DATA_PATH`` is split 70/15/15 (stratified,
    seed 42). Class weights are ``n_train / (n_classes_present * n_class)``
    computed on the training part and moved to ``Config.device``. Augmentation
    is applied to the training set only, when ``Config.USE_AUGMENTATION`` is set.

    Returns:
        ``(train_loader, val_loader, test_loader, class_weights, task_b_class_names)``.
    """
    print("\nüìÇ Creating Task B (Chest X-ray) splits...")

    paths, targets, task_b_class_names = load_task_paths(Config.TASK_B_DATA_PATH)
    print(f"   Total samples: {len(paths):,}")
    print(f"   Classes: {task_b_class_names}")
    print(f"   Class distribution: {dict(Counter(targets))}")

    # 70% train, then the remaining 30% halved into validation and test
    train_paths, holdout_paths, train_labels, holdout_labels = train_test_split(
        paths, targets, test_size=0.30, stratify=targets, random_state=42
    )
    val_paths, test_paths, val_labels, test_labels = train_test_split(
        holdout_paths, holdout_labels, test_size=0.50, stratify=holdout_labels, random_state=42
    )

    print(f"   Train: {len(train_paths):,} | Val: {len(val_paths):,} | Test: {len(test_paths):,}")

    # Inverse-frequency weights to counter class imbalance
    per_class = Counter(train_labels)
    n_train = len(train_labels)
    weight_values = [
        n_train / (len(per_class) * per_class[class_idx])
        for class_idx in range(len(task_b_class_names))
    ]
    class_weights = torch.tensor(weight_values, dtype=torch.float32).to(Config.device)

    print(f"   Class weights: {class_weights.cpu().numpy()}")

    def _pipeline(augment):
        # Resize -> [flip, rotate, jitter] -> tensor -> normalise
        steps = [transforms.Resize((224, 224))]
        if augment:
            steps += [
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(10),
                transforms.ColorJitter(brightness=0.2, contrast=0.2),
            ]
        steps += [
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
        return transforms.Compose(steps)

    train_transform = _pipeline(Config.USE_AUGMENTATION)
    val_transform = _pipeline(False)

    def _loader(sample_paths, sample_labels, transform, shuffle):
        # All three loaders share batch size and worker count.
        return DataLoader(
            ImageDataset(sample_paths, sample_labels, transform),
            batch_size=Config.BATCH_SIZE,
            shuffle=shuffle,
            num_workers=2,
        )

    train_loader = _loader(train_paths, train_labels, train_transform, True)
    val_loader = _loader(val_paths, val_labels, val_transform, False)
    test_loader = _loader(test_paths, test_labels, val_transform, False)

    return train_loader, val_loader, test_loader, class_weights, task_b_class_names

# ============================================================================
# MULTI-HEAD MODEL
# ============================================================================
class MultiHeadMobileNet(nn.Module):
    """MobileNetV3-Large feature extractor shared by two classification heads.

    ``head_a`` serves Task A (OCT) and ``head_b`` the new Task B (Chest X-ray).
    A ``feature_projector`` from Phase 2 is registered but not used in
    ``forward``.

    Args:
        num_classes_a: output size of ``head_a``.
        num_classes_b: output size of ``head_b``.
    """

    @staticmethod
    def _make_head(num_classes):
        # Linear(960 -> 256) -> Hardswish -> Dropout(0.2) -> Linear(256 -> num_classes)
        return nn.Sequential(
            nn.Linear(960, 256),
            nn.Hardswish(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes),
        )

    def __init__(self, num_classes_a, num_classes_b):
        super().__init__()
        # Randomly initialised backbone; only its convolutional trunk is kept.
        self.features = models.mobilenet_v3_large(weights=None).features

        # Carried over from Phase 2
        self.feature_projector = nn.Sequential(
            nn.Linear(960, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 2048),
        )

        self.head_a = self._make_head(num_classes_a)   # Task A (OCT)
        self.head_b = self._make_head(num_classes_b)   # Task B (Chest X-ray), new

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten()

    def forward(self, x, task='b'):
        """Return logits from the head selected by ``task`` (``'a'`` or ``'b'``).

        Raises:
            ValueError: if ``task`` is neither ``'a'`` nor ``'b'``.
        """
        pooled = self.flatten(self.avgpool(self.features(x)))

        if task not in ('a', 'b'):
            raise ValueError(f"Unknown task: {task}")
        head = self.head_a if task == 'a' else self.head_b
        return head(pooled)

# ============================================================================
# EWC LOSS
# ============================================================================

def compute_ewc_loss(model, fisher_dict, optimal_params, lambda_ewc):
    """Return the EWC penalty ``(lambda_ewc / 2) * sum_i F_i * (theta_i - theta*_i)^2``.

    Only parameters whose name appears in ``fisher_dict`` contribute.

    Args:
        model: module whose current parameters are penalised.
        fisher_dict: parameter name -> Fisher tensor.
        optimal_params: parameter name -> reference value; must contain every
            name that is present in both ``model`` and ``fisher_dict``.
        lambda_ewc: regularisation strength.

    Returns:
        A scalar tensor, or the float ``0.0`` when no parameter name matches.
    """
    penalty = 0.0

    for name, param in model.named_parameters():
        if name not in fisher_dict:
            continue
        # Move both tensors next to the parameter before combining them.
        target = param.device
        importance = fisher_dict[name].to(target)
        anchor = optimal_params[name].to(target)
        drift = param - anchor
        penalty = penalty + torch.sum(importance * drift.pow(2))

    return (lambda_ewc / 2.0) * penalty


# ============================================================================
# LOAD PHASE 2 MODEL AND FISHER
# ============================================================================
def _map_phase2_key(key):
    """Translate a Phase 2 state-dict key to its MultiHeadMobileNet name.

    ``backbone.features.X`` -> ``features.X``, ``backbone.classifier.X`` ->
    ``head_a.X``, ``feature_projector.X`` is unchanged. Returns None for any
    other key.
    """
    if key.startswith('backbone.features'):
        return key[len('backbone.'):]
    if key.startswith('backbone.classifier'):
        return 'head_a' + key[len('backbone.classifier'):]
    if key.startswith('feature_projector'):
        return key
    return None


def load_phase2_assets():
    """Load Phase 2 model and Fisher information.

    Reads the checkpoint at ``Config.FISHER_PATH``, builds a
    ``MultiHeadMobileNet`` initialised from its ``'model_state_dict'`` and
    renames the Fisher tensors and reference parameters to the new key layout.

    Returns:
        ``(model, mapped_fisher, mapped_optimal)``; both dicts are keyed by the
        multi-head model's parameter names.

    Raises:
        RuntimeError: if the mapped checkpoint does not cover every weight
            outside ``head_b`` or contains keys the model does not have.
    """
    print("\n Loading Phase 2 assets...")

    # Load Fisher data (contains Phase 2 model weights).
    # weights_only=True, as in the Fisher cell, restricts unpickling to tensors
    # and plain containers.
    fisher_data = torch.load(Config.FISHER_PATH, map_location=Config.device,
                             weights_only=True)
    phase2_state = fisher_data['model_state_dict']

    # Create multi-head model
    model = MultiHeadMobileNet(Config.TASK_A_CLASSES, Config.TASK_B_CLASSES)

    # Rename every Phase 2 entry (parameters and buffers) to the new layout
    model_state = {}
    for key, value in phase2_state.items():
        new_key = _map_phase2_key(key)
        if new_key is not None:
            model_state[new_key] = value

    # strict=False is needed because head_b has no Phase 2 weights, but any
    # other mismatch would silently leave layers randomly initialised.
    load_result = model.load_state_dict(model_state, strict=False)
    missing = [k for k in load_result.missing_keys if not k.startswith('head_b.')]
    unexpected = list(load_result.unexpected_keys)
    if missing or unexpected:
        raise RuntimeError(
            "Phase 2 weights do not match MultiHeadMobileNet: "
            f"missing={missing}, unexpected={unexpected}"
        )
    model = model.to(Config.device)
    print("    Phase 2 model loaded")

    # Load Fisher information
    fisher_dict = fisher_data['fisher_information']
    optimal_params = phase2_state

    # Map Fisher keys to new model structure
    # IMPORTANT: Only iterate over Fisher keys (trainable params only)
    mapped_fisher = {}
    mapped_optimal = {}

    for key in fisher_dict.keys():
        new_key = _map_phase2_key(key)
        if new_key is None:
            continue

        # Map both Fisher and optimal params
        mapped_fisher[new_key] = fisher_dict[key]
        mapped_optimal[new_key] = optimal_params[key]

    print(f"   Fisher information loaded ({len(mapped_fisher)} parameters)")
    print(f"   Optimal parameters loaded ({len(mapped_optimal)} parameters)")

    # Verification
    print(f"\n   Model Structure Verification:")
    print(f"      features: {sum(p.numel() for p in model.features.parameters()):,} params")
    print(f"      head_a: {sum(p.numel() for p in model.head_a.parameters()):,} params")
    print(f"      head_b: {sum(p.numel() for p in model.head_b.parameters()):,} params")
    print(f"      feature_projector: {sum(p.numel() for p in model.feature_projector.parameters()):,} params")

    return model, mapped_fisher, mapped_optimal

# ============================================================================
# EVALUATION FUNCTIONS
# ============================================================================
def evaluate_task(model, dataloader, task, class_names):
    """Score ``model`` on one task's data using the matching head.

    Args:
        model: network called as ``model(images, task=task)``.
        dataloader: yields ``(images, labels)`` batches.
        task: head selector forwarded to the model.
        class_names: accepted for call-site symmetry; not used here.

    Returns:
        ``(accuracy, weighted_f1, predictions, labels)``.
    """
    model.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for batch_images, batch_labels in dataloader:
            logits = model(batch_images.to(Config.device), task=task)
            predicted = logits.argmax(dim=1).cpu().numpy()

            all_preds.extend(predicted)
            all_labels.extend(batch_labels.numpy())

    accuracy = accuracy_score(all_labels, all_preds)
    weighted_f1 = f1_score(all_labels, all_preds, average='weighted')

    return accuracy, weighted_f1, all_preds, all_labels

def print_evaluation_report(acc, f1, preds, labels, class_names, task_name):
    """Print accuracy, F1 and the per-class classification report for one task.

    Args:
        acc: accuracy as a fraction (printed as a percentage).
        f1: F1 score.
        preds: predicted class indices.
        labels: true class indices.
        class_names: display names passed to ``classification_report``.
        task_name: heading shown above the report.
    """
    rule = '=' * 70
    report = classification_report(labels, preds, target_names=class_names, digits=4)

    print(f"\n{rule}")
    print(f" {task_name} EVALUATION")
    print(f"{rule}")
    print(f"   Accuracy:  {acc*100:.2f}%")
    print(f"   F1-Score:  {f1:.4f}")
    print(f"\n Classification Report:")
    print(report)

# ============================================================================
# TRAINING FUNCTION
# ============================================================================
def _retention_pct(after, before):
    """Express `after` as a percentage of `before`; NaN when the baseline is zero."""
    return (after / before) * 100 if before else float('nan')


def train_phase3():
    """Phase 3: Continual Learning with EWC

    Loads the Phase 2 model together with its Fisher information, freezes the
    Task A head, records a Task A baseline, then trains the shared backbone
    (`features`) and the new Task B head with class-weighted cross-entropy
    plus the EWC penalty. The checkpoint with the best Task B validation F1
    is reloaded for the final Task B test and Task A retention reports.

    NOTE(review): `model.train()` puts the whole network in training mode, so
    any BatchNorm running statistics in the backbone are updated on Task B
    batches. Those are buffers rather than parameters; a penalty computed over
    parameters would not constrain them - likely a contributor to the Task A
    drop in the recorded run. TODO confirm.

    Returns:
        (model, history): the model carrying the best checkpoint's weights and
        a dict with the lists 'train_loss', 'val_f1' and 'task_a_f1'.
    """
    print("\n" + "="*70)
    print(" PHASE 3: CONTINUAL LEARNING WITH EWC")
    print("="*70)

    # parents=True so a nested SAVE_DIR is created in one go
    Path(Config.SAVE_DIR).mkdir(parents=True, exist_ok=True)
    best_ckpt_path = f"{Config.SAVE_DIR}/phase3_best.pth"

    # Load Phase 2 assets
    model, fisher_dict, optimal_params = load_phase2_assets()

    # Freeze Task A head so Task B training cannot alter it
    for p in model.head_a.parameters():
        p.requires_grad = False

    # verify
    for name, p in model.named_parameters():
        if "head_a" in name:
            assert p.requires_grad is False

    # Create Task A test loader for retention evaluation
    task_a_test_loader, task_a_classes = create_task_a_splits()

    # Evaluate Task A before fine-tuning (baseline)
    print("\nüß™ Evaluating Task A (OCT) BEFORE fine-tuning...")
    task_a_acc_before, task_a_f1_before, _, _ = evaluate_task(
        model, task_a_test_loader, task='a', class_names=task_a_classes
    )
    print(f"   Task A Accuracy: {task_a_acc_before*100:.2f}%")
    print(f"   Task A F1: {task_a_f1_before:.4f}")

    # Create Task B dataloaders
    train_loader, val_loader, test_loader, class_weights, task_b_classes = create_task_b_splits()

    # Setup training: only the shared backbone and the Task B head are optimised
    optimizer = optim.Adam([
      {'params': model.features.parameters(), 'lr': Config.LEARNING_RATE},
      {'params': model.head_b.parameters(), 'lr': Config.LEARNING_RATE}
    ])

    criterion = nn.CrossEntropyLoss(weight=class_weights)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5,
                                                     patience=3,)

    # Training loop
    # Start below any reachable F1 so the first epoch always writes a
    # checkpoint; otherwise the reload below could pick up a stale file left
    # in SAVE_DIR by an earlier run.
    best_val_f1 = float('-inf')
    checkpoint_written = False
    patience_counter = 0
    history = {'train_loss': [], 'val_f1': [], 'task_a_f1': []}

    print(f"\n Training Task B (Chest X-ray) with EWC (Œª={Config.EWC_LAMBDA})...")

    for epoch in range(Config.NUM_EPOCHS):
        # Training
        model.train()
        train_loss = 0.0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{Config.NUM_EPOCHS}")
        for images, labels in pbar:
            images, labels = images.to(Config.device), labels.to(Config.device)

            optimizer.zero_grad()
            outputs = model(images, task='b')

            # Task B loss (cross-entropy with class weights)
            ce_loss = criterion(outputs, labels)

            # EWC regularization loss
            ewc_loss = compute_ewc_loss(model, fisher_dict, optimal_params, Config.EWC_LAMBDA)

            # Total loss
            total_loss = ce_loss + ewc_loss

            total_loss.backward()
            optimizer.step()

            train_loss += total_loss.item()
            pbar.set_postfix({'loss': f'{total_loss.item():.4f}',
                            'ce': f'{ce_loss.item():.4f}',
                            'ewc': f'{ewc_loss.item():.4f}'})

        avg_train_loss = train_loss / len(train_loader)
        history['train_loss'].append(avg_train_loss)

        # Validation on Task B
        val_acc, val_f1, _, _ = evaluate_task(model, val_loader, task='b',
                                             class_names=task_b_classes)
        history['val_f1'].append(val_f1)

        print(f"\n   Epoch {epoch+1} - Task B Val F1: {val_f1:.4f} | Acc: {val_acc*100:.2f}%")

        # Evaluate Task A retention periodically
        if (epoch + 1) % Config.EVAL_TASK_A_EVERY == 0:
            _, task_a_f1, _, _ = evaluate_task(model, task_a_test_loader,
                                               task='a', class_names=task_a_classes)
            history['task_a_f1'].append(task_a_f1)
            retention = _retention_pct(task_a_f1, task_a_f1_before)
            print(f"    Task A Retention: F1={task_a_f1:.4f} ({retention:.2f}% of baseline)")

        # Learning rate scheduling
        scheduler.step(val_f1)

        # Early stopping and checkpointing
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            patience_counter = 0
            # Metrics are stored as plain floats so the checkpoint holds only
            # tensors and builtin types.
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_f1': float(val_f1),
                'task_a_f1_before': float(task_a_f1_before)
            }, best_ckpt_path)
            checkpoint_written = True
            print(f" Best model saved (Val F1: {val_f1:.4f})")
        else:
            patience_counter += 1
            if patience_counter >= Config.PATIENCE:
                print(f"\n  Early stopping triggered (patience={Config.PATIENCE})")
                break

    # Final evaluation
    print("\n" + "="*70)
    print("FINAL EVALUATION")
    print("="*70)

    # Load best model (only a checkpoint produced by this run)
    if checkpoint_written:
        checkpoint = torch.load(best_ckpt_path, map_location=Config.device)
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        print("   No checkpoint was written in this run; evaluating the current weights")

    # Task B (Chest X-ray) - Test set
    task_b_acc, task_b_f1, task_b_preds, task_b_labels = evaluate_task(
        model, test_loader, task='b', class_names=task_b_classes
    )
    print_evaluation_report(task_b_acc, task_b_f1, task_b_preds, task_b_labels,
                          task_b_classes, "TASK B (Chest X-ray)")

    # Task A (OCT) - Retention test
    task_a_acc_after, task_a_f1_after, task_a_preds, task_a_labels = evaluate_task(
        model, task_a_test_loader, task='a', class_names=task_a_classes
    )
    print_evaluation_report(task_a_acc_after, task_a_f1_after, task_a_preds, task_a_labels,
                          task_a_classes, "TASK A (OCT) - Retention Check")

    # Retention metrics
    retention_f1 = _retention_pct(task_a_f1_after, task_a_f1_before)
    retention_acc = _retention_pct(task_a_acc_after, task_a_acc_before)

    print("\n" + "="*70)
    print("CONTINUAL LEARNING SUMMARY")
    print("="*70)
    print(f" Task A (OCT) Retention:")
    print(f"   Before: F1={task_a_f1_before:.4f}, Acc={task_a_acc_before*100:.2f}%")
    print(f"   After:  F1={task_a_f1_after:.4f}, Acc={task_a_acc_after*100:.2f}%")
    print(f"   Retention: F1={retention_f1:.2f}%, Acc={retention_acc:.2f}%")
    print(f"\n Task B (Chest X-ray) Performance:")
    print(f"   Test F1: {task_b_f1:.4f}")
    print(f"   Test Acc: {task_b_acc*100:.2f}%")
    print("="*70)

    # Save confusion matrices
    save_confusion_matrix(task_a_labels, task_a_preds, task_a_classes,
                         "Task A (OCT) - After EWC", f"{Config.SAVE_DIR}/cm_task_a.png")
    save_confusion_matrix(task_b_labels, task_b_preds, task_b_classes,
                         "Task B (Chest X-ray)", f"{Config.SAVE_DIR}/cm_task_b.png")

    return model, history

def save_confusion_matrix(labels, preds, class_names, title, save_path):
    """Render the confusion matrix as an annotated heatmap and write it to disk.

    Args:
        labels: ground-truth class indices.
        preds: predicted class indices.
        class_names: tick labels for both axes.
        title: figure title.
        save_path: destination image path.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        confusion_matrix(labels, preds),
        annot=True, fmt='d', cmap='Blues',
        xticklabels=class_names, yticklabels=class_names,
        ax=ax,
    )
    ax.set_title(title)
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print(f"  Confusion matrix saved: {save_path}")

# ============================================================================
# MAIN EXECUTION
# ============================================================================
# Entry point: run the Phase 3 pipeline when this code executes as the main
# module (which includes running it directly as a notebook cell), and keep the
# trained model and its metric history in the session for later inspection.
if __name__ == "__main__":
    model, history = train_phase3()


üöÄ PHASE 3: CONTINUAL LEARNING WITH EWC

üìÇ Loading Phase 2 assets...
   ‚úÖ Phase 2 model loaded
   ‚úÖ Fisher information loaded (178 parameters)
   ‚úÖ Optimal parameters loaded (178 parameters)

   üîç Model Structure Verification:
      features: 2,971,952 params
      head_a: 247,044 params
      head_b: 246,530 params
      feature_projector: 3,083,264 params

üìä Creating Task A (OCT) evaluation splits...
   Total Task A samples: 83,484
   Classes: ['CNV', 'DME', 'DRUSEN', 'NORMAL']
   Class distribution: {0: 37205, 1: 11348, 2: 8616, 3: 26315}
   Train: 58,438 | Val: 12,523 | Test: 12,523

üß™ Evaluating Task A (OCT) BEFORE fine-tuning...
   Task A Accuracy: 96.92%
   Task A F1: 0.9693

üìÇ Creating Task B (Chest X-ray) splits...
   Total samples: 5,232
   Classes: ['NORMAL', 'PNEUMONIA']
   Class distribution: {0: 1349, 1: 3883}
   Train: 3,662 | Val: 785 | Test: 785
   Class weights: [1.9396186 0.6736571]

üéØ Training Task B (Chest X-ray) with EWC (Œª=5000)...


Epoch 1/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:02<00:00,  2.16s/it, loss=0.2752, ce=0.2340, ewc=0.0413]



   Epoch 1 - Task B Val F1: 0.6652 | Acc: 75.41%
   üíæ Best model saved (Val F1: 0.6652)


Epoch 2/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:00<00:00,  2.09s/it, loss=0.2004, ce=0.1586, ewc=0.0417]



   Epoch 2 - Task B Val F1: 0.6809 | Acc: 76.43%
   üìà Task A Retention: F1=0.9181 (94.72% of baseline)
   üíæ Best model saved (Val F1: 0.6809)


Epoch 3/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:00<00:00,  2.10s/it, loss=0.2265, ce=0.1867, ewc=0.0398]



   Epoch 3 - Task B Val F1: 0.7362 | Acc: 79.24%
   üíæ Best model saved (Val F1: 0.7362)


Epoch 4/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:00<00:00,  2.09s/it, loss=0.2235, ce=0.1849, ewc=0.0386]



   Epoch 4 - Task B Val F1: 0.8464 | Acc: 86.37%
   üìà Task A Retention: F1=0.8014 (82.68% of baseline)
   üíæ Best model saved (Val F1: 0.8464)


Epoch 5/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [00:59<00:00,  2.06s/it, loss=0.1591, ce=0.1211, ewc=0.0380]



   Epoch 5 - Task B Val F1: 0.9106 | Acc: 91.59%
   üíæ Best model saved (Val F1: 0.9106)


Epoch 6/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:01<00:00,  2.11s/it, loss=0.0656, ce=0.0260, ewc=0.0396]



   Epoch 6 - Task B Val F1: 0.9519 | Acc: 95.29%
   üìà Task A Retention: F1=0.6448 (66.53% of baseline)
   üíæ Best model saved (Val F1: 0.9519)


Epoch 7/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:00<00:00,  2.10s/it, loss=0.1261, ce=0.0862, ewc=0.0400]



   Epoch 7 - Task B Val F1: 0.9626 | Acc: 96.31%
   üíæ Best model saved (Val F1: 0.9626)


Epoch 8/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:00<00:00,  2.09s/it, loss=0.1836, ce=0.1450, ewc=0.0386]



   Epoch 8 - Task B Val F1: 0.9730 | Acc: 97.32%
   üìà Task A Retention: F1=0.4405 (45.44% of baseline)
   üíæ Best model saved (Val F1: 0.9730)


Epoch 9/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:00<00:00,  2.09s/it, loss=0.1616, ce=0.1249, ewc=0.0366]



   Epoch 9 - Task B Val F1: 0.9730 | Acc: 97.32%


Epoch 10/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:00<00:00,  2.08s/it, loss=0.0955, ce=0.0571, ewc=0.0383]



   Epoch 10 - Task B Val F1: 0.9650 | Acc: 96.56%
   üìà Task A Retention: F1=0.4194 (43.26% of baseline)


Epoch 11/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:01<00:00,  2.13s/it, loss=0.1580, ce=0.1180, ewc=0.0400]



   Epoch 11 - Task B Val F1: 0.9649 | Acc: 96.56%


Epoch 12/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:01<00:00,  2.11s/it, loss=0.1769, ce=0.1394, ewc=0.0375]



   Epoch 12 - Task B Val F1: 0.9809 | Acc: 98.09%
   üìà Task A Retention: F1=0.3966 (40.92% of baseline)
   üíæ Best model saved (Val F1: 0.9809)


Epoch 13/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:00<00:00,  2.09s/it, loss=0.0762, ce=0.0380, ewc=0.0382]



   Epoch 13 - Task B Val F1: 0.9835 | Acc: 98.34%
   üíæ Best model saved (Val F1: 0.9835)


Epoch 14/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:00<00:00,  2.10s/it, loss=0.0731, ce=0.0343, ewc=0.0389]



   Epoch 14 - Task B Val F1: 0.9673 | Acc: 96.69%
   üìà Task A Retention: F1=0.3659 (37.75% of baseline)


Epoch 15/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [00:59<00:00,  2.06s/it, loss=0.0778, ce=0.0408, ewc=0.0370]



   Epoch 15 - Task B Val F1: 0.9834 | Acc: 98.34%

üìä FINAL EVALUATION

üìä TASK B (Chest X-ray) EVALUATION
   Accuracy:  97.58%
   F1-Score:  0.9758

üìã Classification Report:
              precision    recall  f1-score   support

      NORMAL     0.9510    0.9557    0.9533       203
   PNEUMONIA     0.9845    0.9828    0.9837       582

    accuracy                         0.9758       785
   macro avg     0.9677    0.9692    0.9685       785
weighted avg     0.9758    0.9758    0.9758       785


üìä TASK A (OCT) - Retention Check EVALUATION
   Accuracy:  47.33%
   F1-Score:  0.3910

üìã Classification Report:
              precision    recall  f1-score   support

         CNV     0.9934    0.1885    0.3169      5581
         DME     0.5243    0.0317    0.0598      1702
      DRUSEN     0.4176    0.6783    0.5169      1293
      NORMAL     0.4259    0.9992    0.5972      3947

    accuracy                         0.4733     12523
   macro avg     0.5903    0.4744    0.3727  

# Adapting with **BatchNorm frozen**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.models as models
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, f1_score, accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
from tqdm import tqdm
import numpy as np
from pathlib import Path

# ============================================================================
# CONFIGURATION
# ============================================================================
class Config:
    """Central settings for the Phase 3 (EWC continual-learning) run."""

    # --- Filesystem locations -------------------------------------------
    TASK_A_DATA_PATH = "/content/data/OCT2017/train"     # one sub-folder per OCT class
    TASK_B_DATA_PATH = "/content/data/chest_xray/train"  # one sub-folder per X-ray class
    # NOTE(review): the student checkpoint downloaded at the top of the notebook
    # is named best_mobilenetv3_student_feature_kd.pth; this path differs and is
    # not read by the code in this cell - confirm which file is intended.
    PHASE2_MODEL_PATH = "/content/best_mobilenetv3_student_kd.pth"
    FISHER_PATH = "/content/fisher/fisher_phase2.pth"    # Fisher info + Phase 2 weights
    SAVE_DIR = "/content/phase3_results"                 # checkpoints and plots

    # --- Output sizes of the two heads ------------------------------------
    TASK_A_CLASSES = 4
    TASK_B_CLASSES = 2

    # --- Regularisation ---------------------------------------------------
    EWC_LAMBDA = 5_000  # weight of the EWC penalty relative to cross-entropy

    # --- Optimisation -----------------------------------------------------
    BATCH_SIZE = 128
    LEARNING_RATE = 1e-4
    NUM_EPOCHS = 15
    PATIENCE = 5  # epochs without a validation-F1 improvement before stopping

    # --- Input pipeline ---------------------------------------------------
    USE_AUGMENTATION = True

    # --- Monitoring -------------------------------------------------------
    EVAL_TASK_A_EVERY = 2  # Task A retention is measured every N epochs

    # Prefer the GPU when one is visible to PyTorch.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# ============================================================================
# DATA LOADING UTILITIES
# ============================================================================
def load_task_paths(data_path):
    """Collect image paths and integer labels from a class-per-folder tree.

    Expects ``data_path/CLASS_NAME/<image>``. Class indices follow the
    alphabetical order of the folder names. Within a class, files are gathered
    pattern by pattern (``*.jpeg``, then ``*.jpg``, then ``*.png``).

    NOTE(review): ``Path.glob`` does not promise a particular order, so the
    order inside each pattern group depends on the filesystem; any seeded
    split computed downstream inherits that ordering.

    Args:
        data_path: root directory (str or Path).

    Returns:
        ``(paths, labels, class_names)`` with ``paths`` and ``labels`` aligned.
    """
    root = Path(data_path)
    class_names = sorted(entry.name for entry in root.iterdir() if entry.is_dir())

    all_paths, all_labels = [], []
    for label, class_name in enumerate(class_names):
        folder = root / class_name
        matches = [
            image_path
            for pattern in ('*.jpeg', '*.jpg', '*.png')
            for image_path in folder.glob(pattern)
        ]
        all_paths += matches
        all_labels += [label] * len(matches)

    return all_paths, all_labels, class_names

# ============================================================================
# DATASET CLASS
# ============================================================================
class ImageDataset(Dataset):
    """Map-style dataset that reads images from disk on demand.

    Args:
        paths: sequence of image file paths.
        labels: sequence of labels aligned with ``paths``.
        transform: optional callable applied to the RGB ``PIL.Image``.

    Raises:
        ValueError: if ``paths`` and ``labels`` differ in length.
    """

    def __init__(self, paths, labels, transform=None):
        # A length mismatch would otherwise only surface as an IndexError deep
        # inside a DataLoader worker.
        if len(paths) != len(labels):
            raise ValueError(
                f"paths and labels must have the same length "
                f"(got {len(paths)} and {len(labels)})"
            )
        self.paths = paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``."""
        # Imported here because this cell's import block does not bring in PIL.
        from PIL import Image

        # The context manager closes the file handle deterministically;
        # convert() yields an independent RGB image that stays valid afterwards.
        with Image.open(self.paths[idx]) as raw:
            img = raw.convert('RGB')
        label = self.labels[idx]

        if self.transform:
            img = self.transform(img)

        return img, label

# ============================================================================
# DATA SPLITS
# ============================================================================
def create_task_a_splits():
    """Build the Task A (OCT) test loader used for retention checks.

    The images under ``Config.TASK_A_DATA_PATH`` are split 70/15/15 with
    stratification and a fixed seed; only the test portion is wrapped in a
    loader, while the train/val portions are merely counted for the log.

    Returns:
        ``(test_loader, class_names)``.
    """
    print("\nüìä Creating Task A (OCT) evaluation splits...")

    all_paths, all_labels, task_a_class_names = load_task_paths(Config.TASK_A_DATA_PATH)
    print(f"   Total Task A samples: {len(all_paths):,}")
    print(f"   Classes: {task_a_class_names}")

    class_counts = Counter(all_labels)
    print(f"   Class distribution: {dict(class_counts)}")

    # First cut keeps 70 % for training; the held-out 30 % is then halved
    # into validation and test.
    seed = {'random_state': 42}
    train_paths, holdout_paths, _, holdout_labels = train_test_split(
        all_paths, all_labels, test_size=0.30, stratify=all_labels, **seed
    )
    val_paths, test_paths, _, test_labels = train_test_split(
        holdout_paths, holdout_labels, test_size=0.50, stratify=holdout_labels, **seed
    )

    print(f"   Train: {len(train_paths):,} | Val: {len(val_paths):,} | Test: {len(test_paths):,}")

    # Deterministic preprocessing only - no augmentation at evaluation time.
    preprocessing = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    loader = DataLoader(
        ImageDataset(test_paths, test_labels, preprocessing),
        batch_size=Config.BATCH_SIZE,
        shuffle=False,
        num_workers=2,
    )
    return loader, task_a_class_names

def create_task_b_splits():
    """Build train/val/test loaders and class weights for Task B (Chest X-ray).

    The images under ``Config.TASK_B_DATA_PATH`` are split 70/15/15 with
    stratification and a fixed seed. Class weights are computed on the
    training portion as ``total / (n_classes_present * class_count)`` and
    placed on ``Config.device``. Augmentation is applied to the training set
    only, and only when ``Config.USE_AUGMENTATION`` is set.

    Returns:
        ``(train_loader, val_loader, test_loader, class_weights, class_names)``.
    """
    print("\nüìÇ Creating Task B (Chest X-ray) splits...")

    all_paths, all_labels, task_b_class_names = load_task_paths(Config.TASK_B_DATA_PATH)
    print(f"   Total samples: {len(all_paths):,}")
    print(f"   Classes: {task_b_class_names}")

    class_counts = Counter(all_labels)
    print(f"   Class distribution: {dict(class_counts)}")

    # 70 % train, then the remaining 30 % halved into val/test.
    seed = {'random_state': 42}
    train_paths, holdout_paths, train_labels, holdout_labels = train_test_split(
        all_paths, all_labels, test_size=0.30, stratify=all_labels, **seed
    )
    val_paths, test_paths, val_labels, test_labels = train_test_split(
        holdout_paths, holdout_labels, test_size=0.50, stratify=holdout_labels, **seed
    )

    print(f"   Train: {len(train_paths):,} | Val: {len(val_paths):,} | Test: {len(test_paths):,}")

    # Inverse-frequency weights so the rarer class counts more in the loss.
    train_class_counts = Counter(train_labels)
    total_samples = len(train_labels)
    classes_present = len(train_class_counts)
    per_class = [
        total_samples / (classes_present * train_class_counts[i])
        for i in range(len(task_b_class_names))
    ]
    class_weights = torch.tensor(per_class, dtype=torch.float32).to(Config.device)

    print(f"   Class weights: {class_weights.cpu().numpy()}")

    def _resize():
        return transforms.Resize((224, 224))

    def _to_normalized_tensor():
        return [
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]

    # Augmentations sit between the resize and the tensor conversion.
    if Config.USE_AUGMENTATION:
        augmentations = [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
        ]
    else:
        augmentations = []

    train_transform = transforms.Compose([_resize(), *augmentations, *_to_normalized_tensor()])
    val_transform = transforms.Compose([_resize(), *_to_normalized_tensor()])

    def _loader(paths, labels, transform, shuffle):
        return DataLoader(
            ImageDataset(paths, labels, transform),
            batch_size=Config.BATCH_SIZE,
            shuffle=shuffle,
            num_workers=2,
        )

    train_loader = _loader(train_paths, train_labels, train_transform, True)
    val_loader = _loader(val_paths, val_labels, val_transform, False)
    test_loader = _loader(test_paths, test_labels, val_transform, False)

    return train_loader, val_loader, test_loader, class_weights, task_b_class_names

# ============================================================================
# MULTI-HEAD MODEL
# ============================================================================
class MultiHeadMobileNet(nn.Module):
    """MobileNetV3-Large feature extractor shared by two classification heads.

    ``head_a`` serves Task A (OCT) and ``head_b`` serves Task B (Chest X-ray).
    ``feature_projector`` is declared so Phase 2 weights with that prefix can
    be loaded; ``forward`` does not call it.

    Args:
        num_classes_a: number of outputs of ``head_a``.
        num_classes_b: number of outputs of ``head_b``.
    """

    @staticmethod
    def _make_head(num_classes):
        """Build the 960 -> 256 -> num_classes classifier used by both heads."""
        return nn.Sequential(
            nn.Linear(960, 256),
            nn.Hardswish(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes)
        )

    def __init__(self, num_classes_a, num_classes_b):
        super().__init__()

        # Only the convolutional trunk of an uninitialised MobileNetV3 is kept.
        self.features = models.mobilenet_v3_large(weights=None).features

        # Carried over from Phase 2 (960 -> 1024 -> 2048).
        self.feature_projector = nn.Sequential(
            nn.Linear(960, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 2048)
        )

        # Heads are created in a/b order, after the projector.
        self.head_a = self._make_head(num_classes_a)
        self.head_b = self._make_head(num_classes_b)

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten()

    def forward(self, x, task='b'):
        """Return the logits of the head selected by ``task`` ('a' or 'b').

        Raises:
            ValueError: for any other ``task`` value (after the backbone has
                already been evaluated).
        """
        pooled = self.flatten(self.avgpool(self.features(x)))

        if task == 'a':
            return self.head_a(pooled)
        if task == 'b':
            return self.head_b(pooled)
        raise ValueError(f"Unknown task: {task}")

# ============================================================================
# EWC LOSS
# ============================================================================

def compute_ewc_loss(model, fisher_dict, optimal_params, lambda_ewc):
    """Elastic Weight Consolidation penalty for the model's current parameters.

    Computes ``(lambda_ewc / 2) * sum_i F_i * (theta_i - theta*_i)^2`` over
    every named parameter of ``model`` whose name is a key of ``fisher_dict``.

    Args:
        model: module whose ``named_parameters()`` are penalised.
        fisher_dict: parameter name -> Fisher information tensor.
        optimal_params: parameter name -> reference tensor; must contain every
            name of ``fisher_dict`` that also exists in the model.
        lambda_ewc: regularisation strength.

    Returns:
        A scalar tensor. When no parameter name matches ``fisher_dict`` a zero
        tensor is returned (rather than a Python float) so callers can always
        use ``.item()`` and add the result to another loss tensor.
    """
    penalty = None

    for name, param in model.named_parameters():
        if name not in fisher_dict:
            continue
        # Move the stored tensors next to the live parameter (device safety).
        fisher = fisher_dict[name].to(param.device)
        optimal = optimal_params[name].to(param.device)
        term = (fisher * (param - optimal).pow(2)).sum()
        penalty = term if penalty is None else penalty + term

    if penalty is None:
        # Nothing to regularise: hand back a tensor on the model's device.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
        return torch.zeros((), device=device)

    return (lambda_ewc / 2.0) * penalty


# ============================================================================
# LOAD PHASE 2 MODEL AND FISHER
# ============================================================================
def _map_phase2_key(key):
    """Translate a Phase 2 state-dict key to this model's naming.

    ``backbone.features.X`` -> ``features.X``, ``backbone.classifier.X`` ->
    ``head_a.X``, ``feature_projector.X`` is kept; anything else yields None.
    Only the leading prefix is rewritten.
    """
    if key.startswith('backbone.features'):
        return key[len('backbone.'):]
    if key.startswith('backbone.classifier'):
        return 'head_a' + key[len('backbone.classifier'):]
    if key.startswith('feature_projector'):
        return key
    return None


def load_phase2_assets():
    """Load Phase 2 model and Fisher information

    Reads ``Config.FISHER_PATH``, which is expected to hold the Phase 2
    weights under 'model_state_dict' and the Fisher tensors under
    'fisher_information'. The weights are renamed to this model's layout and
    loaded into a fresh ``MultiHeadMobileNet``; the Fisher tensors and the
    matching reference weights are renamed the same way.

    Returns:
        ``(model, mapped_fisher, mapped_optimal)`` - the model on
        ``Config.device`` plus two dicts keyed by the model's parameter names.
    """
    print("\n Loading Phase 2 assets...")

    # Load Fisher data (contains Phase 2 model weights)
    fisher_data = torch.load(Config.FISHER_PATH, map_location=Config.device)
    phase2_state = fisher_data['model_state_dict']

    # Create multi-head model
    model = MultiHeadMobileNet(Config.TASK_A_CLASSES, Config.TASK_B_CLASSES)

    # Rename the Phase 2 weights; keys without a counterpart are dropped.
    model_state = {}
    for key, value in phase2_state.items():
        new_key = _map_phase2_key(key)
        if new_key is not None:
            model_state[new_key] = value

    # strict=False is required because head_b has no Phase 2 weights, but any
    # other mismatch points at a broken key mapping, so report it instead of
    # letting it pass silently.
    load_result = model.load_state_dict(model_state, strict=False)
    missing_elsewhere = [k for k in load_result.missing_keys if not k.startswith('head_b.')]
    if missing_elsewhere or load_result.unexpected_keys:
        print(f"   WARNING: state-dict mismatch - missing: {missing_elsewhere}, "
              f"unexpected: {list(load_result.unexpected_keys)}")

    model = model.to(Config.device)
    print("   ‚úÖ Phase 2 model loaded")

    # Load Fisher information
    fisher_dict = fisher_data['fisher_information']
    optimal_params = phase2_state

    # Map Fisher keys to the new model structure. Only the Fisher keys are
    # iterated, so the reference weights are restricted to the same entries.
    mapped_fisher = {}
    mapped_optimal = {}

    for key in fisher_dict.keys():
        new_key = _map_phase2_key(key)
        if new_key is None:
            continue

        # Map both Fisher and optimal params
        mapped_fisher[new_key] = fisher_dict[key]
        mapped_optimal[new_key] = optimal_params[key]

    print(f"   ‚úÖ Fisher information loaded ({len(mapped_fisher)} parameters)")
    print(f"   ‚úÖ Optimal parameters loaded ({len(mapped_optimal)} parameters)")

    # Verification
    print(f"\n   üîç Model Structure Verification:")
    print(f"      features: {sum(p.numel() for p in model.features.parameters()):,} params")
    print(f"      head_a: {sum(p.numel() for p in model.head_a.parameters()):,} params")
    print(f"      head_b: {sum(p.numel() for p in model.head_b.parameters()):,} params")
    print(f"      feature_projector: {sum(p.numel() for p in model.feature_projector.parameters()):,} params")

    return model, mapped_fisher, mapped_optimal

# ============================================================================
# EVALUATION FUNCTIONS
# ============================================================================
def evaluate_task(model, dataloader, task, class_names):
    """Run inference for one task head and score the predictions.

    Args:
        model: network called as ``model(images, task=task)``.
        dataloader: iterable yielding ``(images, labels)`` batches.
        task: head selector forwarded to the model.
        class_names: accepted for interface symmetry; not used here.

    Returns:
        ``(accuracy, weighted_f1, predictions, targets)`` where the last two
        are flat Python lists accumulated over every batch.
    """
    model.eval()  # NOTE: the model is left in eval mode on return
    predicted, expected = [], []

    with torch.no_grad():
        for batch_images, batch_labels in dataloader:
            logits = model(batch_images.to(Config.device), task=task)
            predicted.extend(logits.argmax(dim=1).cpu().numpy())
            # Labels never leave the CPU, so they convert directly.
            expected.extend(batch_labels.numpy())

    return (
        accuracy_score(expected, predicted),
        f1_score(expected, predicted, average='weighted'),
        predicted,
        expected,
    )

def print_evaluation_report(acc, f1, preds, labels, class_names, task_name):
    """Print a banner, the headline metrics and the per-class report.

    Args:
        acc: accuracy as a fraction (printed as a percentage).
        f1: F1 score, printed with four decimals.
        preds: predicted class indices.
        labels: ground-truth class indices.
        class_names: display names passed to ``classification_report``.
        task_name: label shown in the banner.
    """
    summary_lines = (
        f"\n{'='*70}",
        f"üìä {task_name} EVALUATION",
        f"{'='*70}",
        f"   Accuracy:  {acc*100:.2f}%",
        f"   F1-Score:  {f1:.4f}",
        f"\nüìã Classification Report:",
    )
    for line in summary_lines:
        print(line)

    print(classification_report(labels, preds, target_names=class_names, digits=4))

# ============================================================================
# TRAINING FUNCTION
# ============================================================================
def _retention_pct(after, before):
    """Express `after` as a percentage of `before`; NaN when the baseline is zero."""
    return (after / before) * 100 if before else float('nan')


def train_phase3():
    """Phase 3: Continual Learning with EWC

    Loads the Phase 2 model together with its Fisher information, freezes the
    Task A head, records a Task A baseline, then trains the shared backbone
    (`features`) and the new Task B head with class-weighted cross-entropy
    plus the EWC penalty. During training the backbone is kept in eval mode,
    so its BatchNorm layers use their stored running statistics instead of
    updating them on Task B batches; the backbone weights themselves are still
    optimised. The checkpoint with the best Task B validation F1 is reloaded
    for the final Task B test and Task A retention reports.

    Returns:
        (model, history): the model carrying the best checkpoint's weights and
        a dict with the lists 'train_loss', 'val_f1' and 'task_a_f1'.
    """
    print("\n" + "="*70)
    print("üöÄ PHASE 3: CONTINUAL LEARNING WITH EWC")
    print("="*70)

    # parents=True so a nested SAVE_DIR is created in one go
    Path(Config.SAVE_DIR).mkdir(parents=True, exist_ok=True)
    best_ckpt_path = f"{Config.SAVE_DIR}/phase3_best.pth"

    # Load Phase 2 assets
    model, fisher_dict, optimal_params = load_phase2_assets()

    # Freeze Task A head so Task B training cannot alter it
    for p in model.head_a.parameters():
        p.requires_grad = False

    # verify
    for name, p in model.named_parameters():
        if "head_a" in name:
            assert p.requires_grad is False

    # Create Task A test loader for retention evaluation
    task_a_test_loader, task_a_classes = create_task_a_splits()

    # Evaluate Task A before fine-tuning (baseline)
    print("\nüß™ Evaluating Task A (OCT) BEFORE fine-tuning...")
    task_a_acc_before, task_a_f1_before, _, _ = evaluate_task(
        model, task_a_test_loader, task='a', class_names=task_a_classes
    )
    print(f"   Task A Accuracy: {task_a_acc_before*100:.2f}%")
    print(f"   Task A F1: {task_a_f1_before:.4f}")

    # Create Task B dataloaders
    train_loader, val_loader, test_loader, class_weights, task_b_classes = create_task_b_splits()

    # Setup training: only the shared backbone and the Task B head are optimised
    optimizer = optim.Adam([
      {'params': model.features.parameters(), 'lr': Config.LEARNING_RATE},
      {'params': model.head_b.parameters(), 'lr': Config.LEARNING_RATE}
    ])

    criterion = nn.CrossEntropyLoss(weight=class_weights)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5,
                                                     patience=3,)

    # Training loop
    # Start below any reachable F1 so the first epoch always writes a
    # checkpoint; otherwise the reload below could pick up a stale file left
    # in SAVE_DIR by an earlier run.
    best_val_f1 = float('-inf')
    checkpoint_written = False
    patience_counter = 0
    history = {'train_loss': [], 'val_f1': [], 'task_a_f1': []}

    print(f"\nüéØ Training Task B (Chest X-ray) with EWC (Œª={Config.EWC_LAMBDA})...")

    for epoch in range(Config.NUM_EPOCHS):
        # Training
        model.train()
        # Must follow model.train(): evaluation switches everything to eval
        # mode, and train() would otherwise re-enable BatchNorm updates.
        model.features.eval()  # Freeze backbone BatchNorm
        train_loss = 0.0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{Config.NUM_EPOCHS}")
        for images, labels in pbar:
            images, labels = images.to(Config.device), labels.to(Config.device)

            optimizer.zero_grad()
            outputs = model(images, task='b')

            # Task B loss (cross-entropy with class weights)
            ce_loss = criterion(outputs, labels)

            # EWC regularization loss
            ewc_loss = compute_ewc_loss(model, fisher_dict, optimal_params, Config.EWC_LAMBDA)

            # Total loss
            total_loss = ce_loss + ewc_loss

            total_loss.backward()
            optimizer.step()

            train_loss += total_loss.item()
            pbar.set_postfix({'loss': f'{total_loss.item():.4f}',
                            'ce': f'{ce_loss.item():.4f}',
                            'ewc': f'{ewc_loss.item():.4f}'})

        avg_train_loss = train_loss / len(train_loader)
        history['train_loss'].append(avg_train_loss)

        # Validation on Task B
        val_acc, val_f1, _, _ = evaluate_task(model, val_loader, task='b',
                                             class_names=task_b_classes)
        history['val_f1'].append(val_f1)

        print(f"\n   Epoch {epoch+1} - Task B Val F1: {val_f1:.4f} | Acc: {val_acc*100:.2f}%")

        # Evaluate Task A retention periodically
        if (epoch + 1) % Config.EVAL_TASK_A_EVERY == 0:
            _, task_a_f1, _, _ = evaluate_task(model, task_a_test_loader,
                                               task='a', class_names=task_a_classes)
            history['task_a_f1'].append(task_a_f1)
            retention = _retention_pct(task_a_f1, task_a_f1_before)
            print(f"   üìà Task A Retention: F1={task_a_f1:.4f} ({retention:.2f}% of baseline)")

        # Learning rate scheduling
        scheduler.step(val_f1)

        # Early stopping and checkpointing
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            patience_counter = 0
            # Metrics are stored as plain floats so the checkpoint holds only
            # tensors and builtin types.
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_f1': float(val_f1),
                'task_a_f1_before': float(task_a_f1_before)
            }, best_ckpt_path)
            checkpoint_written = True
            print(f"   üíæ Best model saved (Val F1: {val_f1:.4f})")
        else:
            patience_counter += 1
            if patience_counter >= Config.PATIENCE:
                print(f"\n‚è∏Ô∏è  Early stopping triggered (patience={Config.PATIENCE})")
                break

    # Final evaluation
    print("\n" + "="*70)
    print("üìä FINAL EVALUATION")
    print("="*70)

    # Load best model (only a checkpoint produced by this run)
    if checkpoint_written:
        checkpoint = torch.load(best_ckpt_path, map_location=Config.device)
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        print("   No checkpoint was written in this run; evaluating the current weights")

    # Task B (Chest X-ray) - Test set
    task_b_acc, task_b_f1, task_b_preds, task_b_labels = evaluate_task(
        model, test_loader, task='b', class_names=task_b_classes
    )
    print_evaluation_report(task_b_acc, task_b_f1, task_b_preds, task_b_labels,
                          task_b_classes, "TASK B (Chest X-ray)")

    # Task A (OCT) - Retention test
    task_a_acc_after, task_a_f1_after, task_a_preds, task_a_labels = evaluate_task(
        model, task_a_test_loader, task='a', class_names=task_a_classes
    )
    print_evaluation_report(task_a_acc_after, task_a_f1_after, task_a_preds, task_a_labels,
                          task_a_classes, "TASK A (OCT) - Retention Check")

    # Retention metrics
    retention_f1 = _retention_pct(task_a_f1_after, task_a_f1_before)
    retention_acc = _retention_pct(task_a_acc_after, task_a_acc_before)

    print("\n" + "="*70)
    print("üéØ CONTINUAL LEARNING SUMMARY")
    print("="*70)
    print(f"üìä Task A (OCT) Retention:")
    print(f"   Before: F1={task_a_f1_before:.4f}, Acc={task_a_acc_before*100:.2f}%")
    print(f"   After:  F1={task_a_f1_after:.4f}, Acc={task_a_acc_after*100:.2f}%")
    print(f"   Retention: F1={retention_f1:.2f}%, Acc={retention_acc:.2f}%")
    print(f"\nüìä Task B (Chest X-ray) Performance:")
    print(f"   Test F1: {task_b_f1:.4f}")
    print(f"   Test Acc: {task_b_acc*100:.2f}%")
    print("="*70)

    # Save confusion matrices
    save_confusion_matrix(task_a_labels, task_a_preds, task_a_classes,
                         "Task A (OCT) - After EWC", f"{Config.SAVE_DIR}/cm_task_a.png")
    save_confusion_matrix(task_b_labels, task_b_preds, task_b_classes,
                         "Task B (Chest X-ray)", f"{Config.SAVE_DIR}/cm_task_b.png")

    return model, history

def save_confusion_matrix(labels, preds, class_names, title, save_path):
    """Render a confusion matrix as an annotated heatmap and write it to disk.

    Args:
        labels: Ground-truth class labels (passed to ``confusion_matrix`` as
            the true labels).
        preds: Predicted class labels, same length as ``labels``.
        class_names: Tick labels used on both axes, in class order.
            NOTE(review): ``confusion_matrix`` is called without an explicit
            label list, so the matrix only covers classes that occur in
            ``labels`` or ``preds``; this assumes every class in
            ``class_names`` is present - confirm for small evaluation sets.
        title: Title drawn above the heatmap.
        save_path: Destination of the image. When it is a filesystem path
            (``str`` or ``os.PathLike``), missing parent directories are
            created first.

    Returns:
        None. Saves the figure at 300 dpi and prints a confirmation line.
    """
    # Function-scope import so the helper does not depend on another cell
    # having imported ``os``.
    import os

    cm = confusion_matrix(labels, preds)

    # Make sure the output directory exists so ``savefig`` cannot fail with
    # FileNotFoundError; non-path targets (e.g. file objects) are left alone.
    if isinstance(save_path, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(save_path))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    # Use explicit figure/axes handles instead of the implicit pyplot
    # "current figure" so exactly this figure is drawn on, saved and closed.
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names, ax=ax)
        ax.set_title(title)
        ax.set_ylabel('True Label')
        ax.set_xlabel('Predicted Label')
        fig.tight_layout()
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving raised;
        # otherwise figures accumulate in the kernel's memory.
        plt.close(fig)
    print(f"   ‚úÖ Confusion matrix saved: {save_path}")

# ============================================================================
# MAIN EXECUTION
# ============================================================================
if __name__ == "__main__":
    # Entry point: run the Phase 3 pipeline and keep the two returned objects
    # bound at module level as `model` and `history`.
    model, history = train_phase3()


üöÄ PHASE 3: CONTINUAL LEARNING WITH EWC

üìÇ Loading Phase 2 assets...
   ‚úÖ Phase 2 model loaded
   ‚úÖ Fisher information loaded (178 parameters)
   ‚úÖ Optimal parameters loaded (178 parameters)

   üîç Model Structure Verification:
      features: 2,971,952 params
      head_a: 247,044 params
      head_b: 246,530 params
      feature_projector: 3,083,264 params

üìä Creating Task A (OCT) evaluation splits...
   Total Task A samples: 83,484
   Classes: ['CNV', 'DME', 'DRUSEN', 'NORMAL']
   Class distribution: {0: 37205, 1: 11348, 2: 8616, 3: 26315}
   Train: 58,438 | Val: 12,523 | Test: 12,523

üß™ Evaluating Task A (OCT) BEFORE fine-tuning...
   Task A Accuracy: 96.92%
   Task A F1: 0.9693

üìÇ Creating Task B (Chest X-ray) splits...
   Total samples: 5,232
   Classes: ['NORMAL', 'PNEUMONIA']
   Class distribution: {0: 1349, 1: 3883}
   Train: 3,662 | Val: 785 | Test: 785
   Class weights: [1.9396186 0.6736571]

üéØ Training Task B (Chest X-ray) with EWC (Œª=5000)...


Epoch 1/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:01<00:00,  2.13s/it, loss=0.2077, ce=0.1694, ewc=0.0383]



   Epoch 1 - Task B Val F1: 0.9050 | Acc: 90.06%
   üíæ Best model saved (Val F1: 0.9050)


Epoch 2/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:03<00:00,  2.18s/it, loss=0.1440, ce=0.1059, ewc=0.0382]



   Epoch 2 - Task B Val F1: 0.9441 | Acc: 94.27%
   üìà Task A Retention: F1=0.9638 (99.44% of baseline)
   üíæ Best model saved (Val F1: 0.9441)


Epoch 3/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:00<00:00,  2.10s/it, loss=0.1596, ce=0.1238, ewc=0.0357]



   Epoch 3 - Task B Val F1: 0.9455 | Acc: 94.39%
   üíæ Best model saved (Val F1: 0.9455)


Epoch 4/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:00<00:00,  2.08s/it, loss=0.1440, ce=0.1084, ewc=0.0356]



   Epoch 4 - Task B Val F1: 0.9577 | Acc: 95.67%
   üìà Task A Retention: F1=0.9545 (98.47% of baseline)
   üíæ Best model saved (Val F1: 0.9577)


Epoch 5/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:01<00:00,  2.12s/it, loss=0.2259, ce=0.1921, ewc=0.0338]



   Epoch 5 - Task B Val F1: 0.9480 | Acc: 94.65%


Epoch 6/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:00<00:00,  2.09s/it, loss=0.1394, ce=0.1060, ewc=0.0334]



   Epoch 6 - Task B Val F1: 0.9613 | Acc: 96.05%
   üìà Task A Retention: F1=0.9498 (97.99% of baseline)
   üíæ Best model saved (Val F1: 0.9613)


Epoch 7/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:00<00:00,  2.10s/it, loss=0.1144, ce=0.0809, ewc=0.0335]



   Epoch 7 - Task B Val F1: 0.9626 | Acc: 96.18%
   üíæ Best model saved (Val F1: 0.9626)


Epoch 8/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:00<00:00,  2.09s/it, loss=0.2373, ce=0.2049, ewc=0.0325]



   Epoch 8 - Task B Val F1: 0.9733 | Acc: 97.32%
   üìà Task A Retention: F1=0.9601 (99.06% of baseline)
   üíæ Best model saved (Val F1: 0.9733)


Epoch 9/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:01<00:00,  2.11s/it, loss=0.0885, ce=0.0554, ewc=0.0330]



   Epoch 9 - Task B Val F1: 0.9732 | Acc: 97.32%


Epoch 10/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:00<00:00,  2.10s/it, loss=0.0867, ce=0.0528, ewc=0.0339]



   Epoch 10 - Task B Val F1: 0.9771 | Acc: 97.71%
   üìà Task A Retention: F1=0.9410 (97.08% of baseline)
   üíæ Best model saved (Val F1: 0.9771)


Epoch 11/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:01<00:00,  2.13s/it, loss=0.1621, ce=0.1309, ewc=0.0312]



   Epoch 11 - Task B Val F1: 0.9718 | Acc: 97.20%


Epoch 12/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:00<00:00,  2.10s/it, loss=0.1891, ce=0.1573, ewc=0.0318]



   Epoch 12 - Task B Val F1: 0.9664 | Acc: 96.69%
   üìà Task A Retention: F1=0.9468 (97.68% of baseline)


Epoch 13/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:01<00:00,  2.11s/it, loss=0.0524, ce=0.0196, ewc=0.0329]



   Epoch 13 - Task B Val F1: 0.9797 | Acc: 97.96%
   üíæ Best model saved (Val F1: 0.9797)


Epoch 14/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:01<00:00,  2.13s/it, loss=0.2375, ce=0.2058, ewc=0.0317]



   Epoch 14 - Task B Val F1: 0.9383 | Acc: 93.63%
   üìà Task A Retention: F1=0.9601 (99.05% of baseline)


Epoch 15/15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:00<00:00,  2.08s/it, loss=0.1870, ce=0.1545, ewc=0.0325]



   Epoch 15 - Task B Val F1: 0.9601 | Acc: 95.92%

üìä FINAL EVALUATION

üìä TASK B (Chest X-ray) EVALUATION
   Accuracy:  97.07%
   F1-Score:  0.9707

üìã Classification Report:
              precision    recall  f1-score   support

      NORMAL     0.9412    0.9458    0.9435       203
   PNEUMONIA     0.9811    0.9794    0.9802       582

    accuracy                         0.9707       785
   macro avg     0.9611    0.9626    0.9619       785
weighted avg     0.9708    0.9707    0.9707       785


üìä TASK A (OCT) - Retention Check EVALUATION
   Accuracy:  95.38%
   F1-Score:  0.9527

üìã Classification Report:
              precision    recall  f1-score   support

         CNV     0.9688    0.9780    0.9733      5581
         DME     0.9160    0.9424    0.9290      1702
      DRUSEN     0.9522    0.7703    0.8516      1293
      NORMAL     0.9497    0.9845    0.9668      3947

    accuracy                         0.9538     12523
   macro avg     0.9467    0.9188    0.9302  