In [1]:
# Download the two dataset archives from Google Drive with gdown.
# Per the cell output below they are saved as /content/OCT2017.tar.gz and /content/ChestXRay2017.zip.
!gdown --fuzzy "https://drive.google.com/file/d/1cIGCfx6CiVgEpq8PyKzmF1LBJiQGkxzc/view?usp=sharing"
!gdown --fuzzy "https://drive.google.com/file/d/1JobiELb-4mO_Gk3NY6eyIz-3oRw3U2zT/view?usp=sharing"

Downloading...
From (original): https://drive.google.com/uc?id=1cIGCfx6CiVgEpq8PyKzmF1LBJiQGkxzc
From (redirected): https://drive.google.com/uc?id=1cIGCfx6CiVgEpq8PyKzmF1LBJiQGkxzc&confirm=t&uuid=05b8ef28-56f5-4fef-84ad-1d7ceae88d0f
To: /content/OCT2017.tar.gz
100% 5.79G/5.79G [01:16<00:00, 75.6MB/s]
Downloading...
From (original): https://drive.google.com/uc?id=1JobiELb-4mO_Gk3NY6eyIz-3oRw3U2zT
From (redirected): https://drive.google.com/uc?id=1JobiELb-4mO_Gk3NY6eyIz-3oRw3U2zT&confirm=t&uuid=fd8ebfb2-2d7f-4f80-a55a-fc4994564589
To: /content/ChestXRay2017.zip
100% 1.24G/1.24G [00:19<00:00, 62.3MB/s]


In [6]:
#Extract zip
!tar -xzf "/content/OCT2017.tar.gz" -C /content/data/
!unzip -q /content/ChestXRay2017.zip -d /content/data

In [3]:
# Download the Phase 2 student checkpoint from Google Drive.
# Per the cell output below it is saved as /content/best_mobilenetv3_student_kd.pth,
# which is the path Config.PHASE2_MODEL_PATH points at.
!gdown --fuzzy "https://drive.google.com/file/d/1F_vX0fmLL0nKlhaQMhJWHGcFu9a3NxEs/view?usp=sharing"

Downloading...
From (original): https://drive.google.com/uc?id=1F_vX0fmLL0nKlhaQMhJWHGcFu9a3NxEs
From (redirected): https://drive.google.com/uc?id=1F_vX0fmLL0nKlhaQMhJWHGcFu9a3NxEs&confirm=t&uuid=f72c5ec5-b428-4ff1-b373-ca9cb3eb18e2
To: /content/best_mobilenetv3_student_kd.pth
100% 39.0M/39.0M [00:00<00:00, 120MB/s] 


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.models as models
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, f1_score, accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
from tqdm import tqdm
import numpy as np
from pathlib import Path
import copy
from PIL import Image

# ============================================================================
# CONFIGURATION
# ============================================================================
class Config:
    """Run-wide settings for Phase 3; every value is read as a class attribute."""

    # --- Filesystem locations -------------------------------------------
    TASK_A_DATA_PATH = "/content/data/OCT2017/train"        # Task A: OCT image folders
    TASK_B_DATA_PATH = "/content/data/chest_xray/train"     # Task B: chest X-ray image folders
    PHASE2_MODEL_PATH = "/content/best_mobilenetv3_student_kd.pth"  # checkpoint to start from
    SAVE_DIR = "/content/phase3_lwf_feature_kd_results"     # checkpoints and plots go here

    # --- Output sizes of the two classification heads -------------------
    TASK_A_CLASSES = 4   # OCT
    TASK_B_CLASSES = 2   # chest X-ray

    # --- Learning-without-Forgetting knobs ------------------------------
    LWF_ALPHA = 2.0        # multiplier on the distillation term (suggested tuning range 1.0-10.0)
    LWF_TEMPERATURE = 2.0  # softmax temperature used for distillation

    # --- Optimisation ---------------------------------------------------
    BATCH_SIZE = 128
    NUM_EPOCHS = 20
    PATIENCE = 5           # epochs without a validation-F1 improvement before stopping

    # --- Data pipeline --------------------------------------------------
    USE_AUGMENTATION = True  # enables the augmenting transform for Task B training data

    # --- Monitoring -----------------------------------------------------
    EVAL_TASK_A_EVERY = 2    # re-evaluate Task A every N epochs

    # Use the GPU when one is visible, otherwise run on the CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ============================================================================
# DATA LOADING UTILITIES
# ============================================================================
def load_task_paths(data_path):
    """
    Collect image paths and integer labels from a class-per-folder directory.

    Expected layout: ``data_path/CLASS_NAME/<image>`` where images end in
    ``.jpeg``, ``.jpg`` or ``.png``. Only immediate sub-directories of
    ``data_path`` are treated as classes; files directly inside ``data_path``
    and files with other extensions are ignored.

    Args:
        data_path: str or Path of the directory that contains the class folders.

    Returns:
        Tuple ``(all_paths, all_labels, class_names)``:
            all_paths   - list of Path objects, grouped by class, sorted within each class
            all_labels  - list of int labels aligned with ``all_paths``; a label is the
                          index of the class folder in ``class_names``
            class_names - alphabetically sorted list of class folder names
    """
    data_path = Path(data_path)

    # Sorted folder names define the label indices
    class_names = sorted(d.name for d in data_path.iterdir() if d.is_dir())

    all_paths = []
    all_labels = []

    for idx, class_name in enumerate(class_names):
        class_dir = data_path / class_name
        # glob() yields entries in whatever order the filesystem returns them.
        # The callers feed this list to train_test_split(random_state=42), so the
        # list order must be deterministic for the splits to be reproducible.
        paths = sorted(
            path
            for pattern in ('*.jpeg', '*.jpg', '*.png')
            for path in class_dir.glob(pattern)
        )

        all_paths.extend(paths)
        all_labels.extend([idx] * len(paths))

    return all_paths, all_labels, class_names

# ============================================================================
# DATASET CLASS
# ============================================================================
class ImageDataset(Dataset):
    """
    Map-style dataset that pairs image file paths with their labels.

    Each item is ``(image, label)``: the file at ``paths[idx]`` is opened with
    PIL and converted to RGB; if ``transform`` is truthy it is applied to the
    image before it is returned.
    """

    def __init__(self, paths, labels, transform=None):
        self.paths, self.labels, self.transform = paths, labels, transform

    def __len__(self):
        # One sample per stored path
        return len(self.paths)

    def __getitem__(self, idx):
        image = Image.open(self.paths[idx]).convert('RGB')
        target = self.labels[idx]
        return (self.transform(image) if self.transform else image), target

# ============================================================================
# DATA SPLITS
# ============================================================================
def create_task_a_splits():
    """
    Build the Task A (OCT) test loader used for retention checks.

    Loads every image under Config.TASK_A_DATA_PATH, performs a stratified
    70/15/15 split with random_state=42, and wraps only the test portion in a
    non-shuffled DataLoader (the train and validation portions are computed
    but not returned).

    Returns:
        Tuple ``(test_loader, task_a_class_names)``.
    """
    print("\nüìä Creating Task A (OCT) evaluation splits...")

    all_paths, all_labels, task_a_class_names = load_task_paths(Config.TASK_A_DATA_PATH)
    print(f"   Total Task A samples: {len(all_paths):,}")
    print(f"   Classes: {task_a_class_names}")

    # Per-class sample counts
    class_counts = Counter(all_labels)
    print(f"   Class distribution: {dict(class_counts)}")

    # Stage 1: hold out 30% of the data
    train_paths, holdout_paths, _train_labels, holdout_labels = train_test_split(
        all_paths, all_labels, test_size=0.30, stratify=all_labels, random_state=42
    )
    # Stage 2: cut the hold-out in half -> validation / test
    val_paths, test_paths, _val_labels, test_labels = train_test_split(
        holdout_paths, holdout_labels, test_size=0.50, stratify=holdout_labels, random_state=42
    )

    print(f"   Train: {len(train_paths):,} | Val: {len(val_paths):,} | Test: {len(test_paths):,}")

    # Deterministic preprocessing only (no augmentation) for evaluation
    eval_steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    test_loader = DataLoader(
        ImageDataset(test_paths, test_labels, transforms.Compose(eval_steps)),
        batch_size=Config.BATCH_SIZE,
        shuffle=False,
        num_workers=2,
    )

    return test_loader, task_a_class_names

def create_task_b_splits():
    """
    Build train/val/test loaders and class weights for Task B (Chest X-ray).

    Loads every image under Config.TASK_B_DATA_PATH and performs a stratified
    70/15/15 split with random_state=42. Class weights are computed from the
    training portion as ``n_train / (n_classes_present * n_class_i)`` and moved
    to Config.device. The training set is augmented when
    Config.USE_AUGMENTATION is true; validation and test never are.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader, class_weights, task_b_class_names)``.
    """
    print("\nüìÇ Creating Task B (Chest X-ray) splits...")

    all_paths, all_labels, task_b_class_names = load_task_paths(Config.TASK_B_DATA_PATH)
    print(f"   Total samples: {len(all_paths):,}")
    print(f"   Classes: {task_b_class_names}")

    # Per-class sample counts over the full dataset
    class_counts = Counter(all_labels)
    print(f"   Class distribution: {dict(class_counts)}")

    # Stage 1: hold out 30%; stage 2: halve the hold-out into validation / test
    train_paths, holdout_paths, train_labels, holdout_labels = train_test_split(
        all_paths, all_labels, test_size=0.30, stratify=all_labels, random_state=42
    )
    val_paths, test_paths, val_labels, test_labels = train_test_split(
        holdout_paths, holdout_labels, test_size=0.50, stratify=holdout_labels, random_state=42
    )

    print(f"   Train: {len(train_paths):,} | Val: {len(val_paths):,} | Test: {len(test_paths):,}")

    # Inverse-frequency weights from the training portion only
    per_class = Counter(train_labels)
    n_train = len(train_labels)
    weight_values = []
    for class_idx in range(len(task_b_class_names)):
        weight_values.append(n_train / (len(per_class) * per_class[class_idx]))
    class_weights = torch.tensor(weight_values, dtype=torch.float32).to(Config.device)

    print(f"   Class weights: {class_weights.cpu().numpy()}")

    def tensor_steps():
        # Shared tail of every pipeline: PIL image -> normalised tensor
        return [
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]

    # Random augmentations are inserted between the resize and the tensor conversion
    augment_steps = []
    if Config.USE_AUGMENTATION:
        augment_steps = [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
        ]

    train_transform = transforms.Compose(
        [transforms.Resize((224, 224)), *augment_steps, *tensor_steps()]
    )
    val_transform = transforms.Compose(
        [transforms.Resize((224, 224)), *tensor_steps()]
    )

    def make_loader(paths, labels, transform, shuffle):
        return DataLoader(
            ImageDataset(paths, labels, transform),
            batch_size=Config.BATCH_SIZE,
            shuffle=shuffle,
            num_workers=2,
        )

    # Only the training loader is shuffled
    train_loader = make_loader(train_paths, train_labels, train_transform, True)
    val_loader = make_loader(val_paths, val_labels, val_transform, False)
    test_loader = make_loader(test_paths, test_labels, val_transform, False)

    return train_loader, val_loader, test_loader, class_weights, task_b_class_names

# ============================================================================
# MULTI-HEAD MODEL (Based on Phase 2 Feature-Based KD Architecture)
# ============================================================================
class MultiHeadMobileNetV3WithFeatures(nn.Module):
    """
    MobileNetV3-Large backbone shared by two task-specific classifier heads.

    ``head_a`` serves Task A (OCT) and ``head_b`` serves Task B (Chest X-ray).
    A ``feature_projector`` module is also registered; ``forward`` does not
    use it.
    """

    @staticmethod
    def _make_head(in_dim, num_classes):
        """Two-layer classifier: Linear -> Hardswish -> Dropout(0.2) -> Linear."""
        return nn.Sequential(
            nn.Linear(in_dim, 256),
            nn.Hardswish(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes),
        )

    def __init__(self, num_classes_a=4, num_classes_b=2):
        super().__init__()

        # Untrained torchvision MobileNetV3-Large; only its conv trunk and pooling are kept
        base = models.mobilenet_v3_large(weights=None)
        self.features = base.features
        self.avgpool = base.avgpool

        # Width of the pooled feature vector consumed by the heads
        self.feature_dim = 960

        # Heads are created in a fixed order: Task A first, then Task B
        self.head_a = self._make_head(self.feature_dim, num_classes_a)
        self.head_b = self._make_head(self.feature_dim, num_classes_b)

        # Projection MLP 960 -> 1024 -> 2048 (registered so matching checkpoint keys can load)
        self.feature_projector = nn.Sequential(
            nn.Linear(self.feature_dim, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 2048)
        )

    def extract_features(self, x):
        """Run the backbone and pooling, then flatten everything after the batch dimension."""
        return torch.flatten(self.avgpool(self.features(x)), 1)

    def forward(self, x, task='b', return_features=False):
        """
        Compute logits for the selected task.

        Args:
            x: input tensor passed to the backbone
            task: 'a' selects head_a (OCT), 'b' selects head_b (Chest X-ray)
            return_features: when True, return ``(logits, features)`` instead of logits

        Raises:
            ValueError: if ``task`` is neither 'a' nor 'b'.
        """
        features = self.extract_features(x)

        if task == 'a':
            head = self.head_a
        elif task == 'b':
            head = self.head_b
        else:
            raise ValueError(f"Unknown task: {task}")

        logits = head(features)
        return (logits, features) if return_features else logits

# ============================================================================
# LWF DISTILLATION LOSS
# ============================================================================
def distillation_loss(student_logits, teacher_logits, temperature):
    """
    Temperature-scaled KL-divergence distillation loss.

    Both logit tensors are divided by ``temperature`` and normalised over
    dim 1; the KL divergence from the teacher distribution to the student
    distribution is reduced with 'batchmean' and multiplied by
    ``temperature ** 2``.

    Args:
        student_logits: raw logits from the model being trained
        teacher_logits: raw logits from the reference (teacher) model
        temperature: softening temperature

    Returns:
        Scalar tensor holding the scaled KL divergence.
    """
    F = nn.functional

    # kl_div expects log-probabilities as input and probabilities as target
    log_p_student = F.log_softmax(student_logits / temperature, dim=1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=1)

    divergence = F.kl_div(log_p_student, p_teacher, reduction='batchmean')
    return divergence * (temperature ** 2)

# ============================================================================
# LOAD PHASE 2 MODEL
# ============================================================================
def load_phase2_model():
    """
    Build the multi-head model and initialise it from the Phase 2 checkpoint.

    The checkpoint at Config.PHASE2_MODEL_PATH may be either a dict holding a
    'model_state_dict' entry or a bare state dict. Its keys are remapped onto
    the multi-head model:

        backbone.features.*    -> features.*
        backbone.classifier.*  -> head_a.*
        feature_projector.{0,2}.{weight,bias} are copied unchanged

    All other checkpoint keys are ignored. ``head_b`` has no counterpart in
    the checkpoint and keeps its fresh initialisation.

    Returns:
        The model, moved to Config.device.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
        RuntimeError: if, after remapping, any ``features.*`` or ``head_a.*``
            parameter is absent from the checkpoint (the model would otherwise
            silently run Task A with randomly initialised weights).
    """
    print("\nüìÇ Loading Phase 2 Feature-Based KD model...")

    if not Path(Config.PHASE2_MODEL_PATH).exists():
        raise FileNotFoundError(f"Phase 2 model not found at {Config.PHASE2_MODEL_PATH}")

    checkpoint = torch.load(Config.PHASE2_MODEL_PATH, map_location=Config.device)

    # Create multi-head model
    model = MultiHeadMobileNetV3WithFeatures(
        num_classes_a=Config.TASK_A_CLASSES,
        num_classes_b=Config.TASK_B_CLASSES
    )

    # Accept both a wrapped checkpoint and a bare state dict
    if 'model_state_dict' in checkpoint:
        phase2_state = checkpoint['model_state_dict']
    else:
        phase2_state = checkpoint

    projector_keys = {
        'feature_projector.0.weight', 'feature_projector.0.bias',
        'feature_projector.2.weight', 'feature_projector.2.bias',
    }

    # Translate Phase 2 parameter names into this model's names
    model_state = {}
    for key, value in phase2_state.items():
        if key.startswith('backbone.features'):
            # backbone.features.X -> features.X
            model_state[key.replace('backbone.', '')] = value
        elif key.startswith('backbone.classifier'):
            # backbone.classifier.X -> head_a.X
            model_state[key.replace('backbone.classifier', 'head_a')] = value
        elif key in projector_keys:
            # Keep feature projector weights (compatibility)
            model_state[key] = value

    # strict=False because head_b is new and therefore expected to be missing
    missing_keys, unexpected_keys = model.load_state_dict(model_state, strict=False)

    # strict=False also hides genuine mapping failures, so check the parts that
    # Task A evaluation and the LwF teacher actually rely on.
    critical_missing = [k for k in missing_keys if k.startswith(('features.', 'head_a.'))]
    if critical_missing:
        preview = ', '.join(critical_missing[:5])
        raise RuntimeError(
            f"Phase 2 checkpoint did not provide {len(critical_missing)} backbone/head_a "
            f"parameter(s) after key remapping (e.g. {preview}); refusing to continue "
            f"with randomly initialised Task A weights"
        )
    if unexpected_keys:
        print(f"   Warning: {len(unexpected_keys)} remapped key(s) have no counterpart "
              f"in the model and were skipped (e.g. {list(unexpected_keys)[:5]})")

    model = model.to(Config.device)

    print("   Phase 2 model loaded")
    print(f"  Phase 2 Test F1: {checkpoint.get('test_f1', 'N/A')}")
    print(f"  Phase 2 Test Acc: {checkpoint.get('test_acc', 'N/A')}")
    print(f"  Verification: head_a.3.weight shape = {model.head_a[3].weight.shape}")
    print(f"  New Task B head initialized: head_b.3.weight shape = {model.head_b[3].weight.shape}")

    return model

# ============================================================================
# EVALUATION FUNCTIONS
# ============================================================================
def evaluate_task(model, dataloader, task, class_names):
    """
    Score ``model`` on one task's dataloader.

    The model is switched to eval mode (and left there) and run without
    gradient tracking; predictions are the argmax over dim 1 of the selected
    head's logits.

    Args:
        model: model whose call accepts ``(images, task=...)``
        dataloader: iterable yielding ``(images, labels)`` batches
        task: head selector forwarded to the model
        class_names: accepted for call-site symmetry; not used here

    Returns:
        Tuple ``(accuracy, weighted_f1, all_preds, all_labels)`` where the last
        two are flat lists accumulated over every batch.
    """
    model.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for batch_images, batch_labels in dataloader:
            logits = model(batch_images.to(Config.device), task=task)
            all_preds.extend(logits.argmax(dim=1).cpu().numpy())
            all_labels.extend(batch_labels.numpy())

    return (
        accuracy_score(all_labels, all_preds),
        f1_score(all_labels, all_preds, average='weighted'),
        all_preds,
        all_labels,
    )

def print_evaluation_report(acc, f1, preds, labels, class_names, task_name):
    """
    Print a banner, the headline metrics and sklearn's per-class report.

    ``acc`` is shown as a percentage, ``f1`` with four decimals; ``preds`` and
    ``labels`` are passed to ``classification_report`` with ``class_names`` as
    the target names.
    """
    header_lines = (
        f"\n{'='*70}",
        f"üìä {task_name} EVALUATION",
        f"{'='*70}",
        f"   Accuracy:  {acc*100:.2f}%",
        f"   F1-Score:  {f1:.4f}",
        f"\nüìã Classification Report:",
    )
    for line in header_lines:
        print(line)

    report = classification_report(labels, preds, target_names=class_names, digits=4)
    print(report)

# ============================================================================
# TRAINING FUNCTION WITH LWF
# ============================================================================
def train_phase3_lwf():
    """
    Phase 3: adapt the Phase 2 student to Task B with Learning without Forgetting.

    Steps:
        1. Load the Phase 2 model and deep-copy it into a frozen teacher.
        2. Measure the Task A baseline on the Task A test split.
        3. Train the shared backbone and ``head_b`` on Task B with
           ``CE(head_b) + LWF_ALPHA * KD(student head_a || teacher head_a)``,
           where both distillation terms are computed on the Task B batch.
        4. Keep the checkpoint with the best Task B validation F1, with early
           stopping after Config.PATIENCE non-improving epochs.
        5. Reload the best checkpoint and report Task B test metrics and
           Task A retention; save both confusion matrices.

    Side effects:
        Creates Config.SAVE_DIR and writes ``phase3_lwf_best.pth``,
        ``cm_task_a.png`` and ``cm_task_b.png`` into it. The returned model's
        ``head_a`` parameters have ``requires_grad`` set to False.

    Returns:
        Tuple ``(student_model, history)``; ``history`` has the keys
        'train_loss' and 'val_f1' (one entry per epoch) and 'task_a_f1'
        (one entry per Task A evaluation).
    """
    print("\n" + "="*70)
    print("üöÄ PHASE 3: CONTINUAL LEARNING WITH LWF")
    print("   Model: MobileNetV3 (Feature-Based KD from Phase 2)")
    print("="*70)

    # parents=True so that a nested SAVE_DIR can be created in one go
    Path(Config.SAVE_DIR).mkdir(parents=True, exist_ok=True)
    best_ckpt_path = f"{Config.SAVE_DIR}/phase3_lwf_best.pth"

    def _percent_of(value, baseline):
        # Ratio as a percentage; NaN instead of ZeroDivisionError for a zero baseline
        return (value / baseline) * 100 if baseline else float('nan')

    # Load Phase 2 model (student)
    student_model = load_phase2_model()

    # Create teacher model (frozen copy of Phase 2 model)
    print("\nüìö Creating teacher model (frozen copy)...")
    teacher_model = copy.deepcopy(student_model)
    teacher_model.eval()  # Set to eval mode

    # Freeze all teacher parameters
    for param in teacher_model.parameters():
        param.requires_grad = False

    print("   ‚úÖ Teacher model created and frozen")

    # Really freeze the student's Task A head. It is not handed to the optimizer,
    # so it was never updated, but with requires_grad left on its .grad buffers
    # would be filled on every backward() and never cleared by optimizer.zero_grad().
    for param in student_model.head_a.parameters():
        param.requires_grad = False

    # Create Task A test loader for retention evaluation
    task_a_test_loader, task_a_classes = create_task_a_splits()

    # Evaluate Task A before fine-tuning (baseline)
    print("\nüß™ Evaluating Task A (OCT) BEFORE adapting to Task B...")
    task_a_acc_before, task_a_f1_before, _, _ = evaluate_task(
        student_model, task_a_test_loader, task='a', class_names=task_a_classes
    )
    print(f"   Task A Accuracy: {task_a_acc_before*100:.2f}%")
    print(f"   Task A F1: {task_a_f1_before:.4f}")

    # Create Task B dataloaders
    train_loader, val_loader, test_loader, class_weights, task_b_classes = create_task_b_splits()

    # Only the shared backbone and the Task B head are optimised
    optimizer = optim.Adam([
        {'params': student_model.features.parameters(), 'lr': 1e-5},  # Fine-tune backbone
        {'params': student_model.head_b.parameters(), 'lr': 1e-4}     # Train new head
    ])

    criterion = nn.CrossEntropyLoss(weight=class_weights)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5,
                                                     patience=3)

    # Start below any reachable F1 so the first epoch always writes a checkpoint;
    # otherwise the final torch.load could fail or pick up a file from an earlier run.
    best_val_f1 = float('-inf')
    patience_counter = 0
    history = {'train_loss': [], 'val_f1': [], 'task_a_f1': []}

    print(f"\n Training Task B (Chest X-ray) with LwF...")
    print(f"   Alpha (Œ±): {Config.LWF_ALPHA}")
    print(f"   Temperature (T): {Config.LWF_TEMPERATURE}")

    for epoch in range(Config.NUM_EPOCHS):
        # Training
        student_model.train()
        # Keep Task A head in eval mode (frozen), i.e. its dropout stays off
        student_model.head_a.eval()
        # NOTE(review): train() also puts every layer of the shared backbone in
        # training mode while the teacher runs in eval mode. If the backbone holds
        # normalisation layers with running statistics, those statistics would be
        # updated on Task B batches and change what head_a sees at evaluation time.
        # Possibly related to the Task A drop seen during this run - confirm before changing.

        train_loss = 0.0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{Config.NUM_EPOCHS}")
        for images, labels in pbar:
            images, labels = images.to(Config.device), labels.to(Config.device)

            optimizer.zero_grad()

            # One backbone pass feeds both heads (previously the same batch went
            # through the backbone twice, once per head)
            features = student_model.extract_features(images)

            # Task B classification loss on the new head
            student_logits_b = student_model.head_b(features)
            ce_loss = criterion(student_logits_b, labels)

            # Teacher's Task A outputs on the Task B batch (the knowledge to preserve)
            with torch.no_grad():
                teacher_logits_a = teacher_model(images, task='a')

            # Student's Task A outputs; gradients reach the backbone through `features`
            student_logits_a = student_model.head_a(features)

            # Distillation loss (preserve Task A knowledge)
            distill_loss = distillation_loss(
                student_logits_a,
                teacher_logits_a,
                Config.LWF_TEMPERATURE
            )

            # Total loss = Task B loss + alpha * distillation loss
            total_loss = ce_loss + Config.LWF_ALPHA * distill_loss

            total_loss.backward()
            optimizer.step()

            train_loss += total_loss.item()
            pbar.set_postfix({
                'loss': f'{total_loss.item():.4f}',
                'ce': f'{ce_loss.item():.4f}',
                'distill': f'{distill_loss.item():.4f}'
            })

        avg_train_loss = train_loss / len(train_loader)
        history['train_loss'].append(avg_train_loss)

        # Validation on Task B
        val_acc, val_f1, _, _ = evaluate_task(student_model, val_loader, task='b',
                                             class_names=task_b_classes)
        history['val_f1'].append(val_f1)

        print(f"\n   Epoch {epoch+1} - Task B Val F1: {val_f1:.4f} | Acc: {val_acc*100:.2f}%")

        # Evaluate Task A retention periodically
        if (epoch + 1) % Config.EVAL_TASK_A_EVERY == 0:
            task_a_acc, task_a_f1, _, _ = evaluate_task(student_model, task_a_test_loader,
                                                        task='a', class_names=task_a_classes)
            history['task_a_f1'].append(task_a_f1)
            retention = _percent_of(task_a_f1, task_a_f1_before)
            print(f"    Task A Retention: F1={task_a_f1:.4f} ({retention:.2f}% of baseline)")

        # Learning rate scheduling (mode='max': steps on stalled validation F1)
        scheduler.step(val_f1)

        # Early stopping and checkpointing
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            patience_counter = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': student_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_f1': val_f1,
                'task_a_f1_before': task_a_f1_before,
                'lwf_alpha': Config.LWF_ALPHA,
                'lwf_temperature': Config.LWF_TEMPERATURE
            }, best_ckpt_path)
            print(f"   üíæ Best model saved (Val F1: {val_f1:.4f})")
        else:
            patience_counter += 1
            if patience_counter >= Config.PATIENCE:
                print(f"\n‚è∏Ô∏è  Early stopping triggered (patience={Config.PATIENCE})")
                break

    # Final evaluation
    print("\n" + "="*70)
    print("üìä FINAL EVALUATION")
    print("="*70)

    # Load best model; map_location keeps this working if the device differs
    # from the one the checkpoint tensors were saved on
    checkpoint = torch.load(best_ckpt_path, map_location=Config.device)
    student_model.load_state_dict(checkpoint['model_state_dict'])

    # Task B (Chest X-ray) - Test set
    task_b_acc, task_b_f1, task_b_preds, task_b_labels = evaluate_task(
        student_model, test_loader, task='b', class_names=task_b_classes
    )
    print_evaluation_report(task_b_acc, task_b_f1, task_b_preds, task_b_labels,
                          task_b_classes, "TASK B (Chest X-ray)")

    # Task A (OCT) - Retention test
    task_a_acc_after, task_a_f1_after, task_a_preds, task_a_labels = evaluate_task(
        student_model, task_a_test_loader, task='a', class_names=task_a_classes
    )
    print_evaluation_report(task_a_acc_after, task_a_f1_after, task_a_preds, task_a_labels,
                          task_a_classes, "TASK A (OCT) - Retention Check")

    # Retention metrics: after-training score as a percentage of the baseline
    retention_f1 = _percent_of(task_a_f1_after, task_a_f1_before)
    retention_acc = _percent_of(task_a_acc_after, task_a_acc_before)

    print("\n" + "="*70)
    print("üéØ CONTINUAL LEARNING SUMMARY")
    print("="*70)
    print(f"üìä Task A (OCT) Retention:")
    print(f"   Before: F1={task_a_f1_before:.4f}, Acc={task_a_acc_before*100:.2f}%")
    print(f"   After:  F1={task_a_f1_after:.4f}, Acc={task_a_acc_after*100:.2f}%")
    print(f"   Retention: F1={retention_f1:.2f}%, Acc={retention_acc:.2f}%")
    print(f"\nüìä Task B (Chest X-ray) Performance:")
    print(f"   Test F1: {task_b_f1:.4f}")
    print(f"   Test Acc: {task_b_acc*100:.2f}%")
    print("="*70)

    # Save confusion matrices
    save_confusion_matrix(task_a_labels, task_a_preds, task_a_classes,
                         "Task A (OCT) - After LwF", f"{Config.SAVE_DIR}/cm_task_a.png")
    save_confusion_matrix(task_b_labels, task_b_preds, task_b_classes,
                         "Task B (Chest X-ray)", f"{Config.SAVE_DIR}/cm_task_b.png")

    return student_model, history

def save_confusion_matrix(labels, preds, class_names, title, save_path):
    """
    Render a confusion matrix as an annotated heatmap and write it to disk.

    Args:
        labels: true labels
        preds: predicted labels
        class_names: tick labels for both axes
        title: figure title
        save_path: output image path (saved at 300 dpi)
    """
    matrix = confusion_matrix(labels, preds)

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_title(title)
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    # Release the figure so repeated calls do not accumulate open figures
    plt.close(fig)
    print(f"   ‚úÖ Confusion matrix saved: {save_path}")

# ============================================================================
# MAIN EXECUTION
# ============================================================================
# Entry point: run the full Phase 3 pipeline and keep the trained model and the
# per-epoch history available as notebook-level variables.
if __name__ == "__main__":
    model, history = train_phase3_lwf()


üöÄ PHASE 3: CONTINUAL LEARNING WITH LWF
   Model: MobileNetV3 (Feature-Based KD from Phase 2)

üìÇ Loading Phase 2 Feature-Based KD model...
   Phase 2 model loaded
  Phase 2 Test F1: N/A
  Phase 2 Test Acc: N/A
  Verification: head_a.3.weight shape = torch.Size([4, 256])
  New Task B head initialized: head_b.3.weight shape = torch.Size([2, 256])

üìö Creating teacher model (frozen copy)...
   ‚úÖ Teacher model created and frozen

üìä Creating Task A (OCT) evaluation splits...
   Total Task A samples: 83,484
   Classes: ['CNV', 'DME', 'DRUSEN', 'NORMAL']
   Class distribution: {0: 37205, 1: 11348, 2: 8616, 3: 26315}
   Train: 58,438 | Val: 12,523 | Test: 12,523

üß™ Evaluating Task A (OCT) BEFORE adapting to Task B...
   Task A Accuracy: 97.06%
   Task A F1: 0.9706

üìÇ Creating Task B (Chest X-ray) splits...
   Total samples: 5,232
   Classes: ['NORMAL', 'PNEUMONIA']
   Class distribution: {0: 1349, 1: 3883}
   Train: 3,662 | Val: 785 | Test: 785
   Class weights: [1.9396186 0

Epoch 1/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:12<00:00,  2.50s/it, loss=4.8426, ce=0.4788, distill=2.1819]



   Epoch 1 - Task B Val F1: 0.4665 | Acc: 47.52%
   üíæ Best model saved (Val F1: 0.4665)


Epoch 2/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:08<00:00,  2.37s/it, loss=3.1166, ce=0.4570, distill=1.3298]



   Epoch 2 - Task B Val F1: 0.6729 | Acc: 65.48%
   üìà Task A Retention: F1=0.8181 (84.28% of baseline)
   üíæ Best model saved (Val F1: 0.6729)


Epoch 3/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:08<00:00,  2.37s/it, loss=3.4063, ce=0.4165, distill=1.4949]



   Epoch 3 - Task B Val F1: 0.8121 | Acc: 80.00%
   üíæ Best model saved (Val F1: 0.8121)


Epoch 4/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:08<00:00,  2.37s/it, loss=3.1142, ce=0.3867, distill=1.3637]



   Epoch 4 - Task B Val F1: 0.8518 | Acc: 84.33%
   üìà Task A Retention: F1=0.6550 (67.48% of baseline)
   üíæ Best model saved (Val F1: 0.8518)


Epoch 5/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:09<00:00,  2.38s/it, loss=2.4044, ce=0.3242, distill=1.0401]



   Epoch 5 - Task B Val F1: 0.8520 | Acc: 84.33%
   üíæ Best model saved (Val F1: 0.8520)


Epoch 6/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:07<00:00,  2.34s/it, loss=2.3552, ce=0.2649, distill=1.0451]



   Epoch 6 - Task B Val F1: 0.8660 | Acc: 85.86%
   üìà Task A Retention: F1=0.5788 (59.63% of baseline)
   üíæ Best model saved (Val F1: 0.8660)


Epoch 7/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:07<00:00,  2.34s/it, loss=1.7539, ce=0.3324, distill=0.7108]



   Epoch 7 - Task B Val F1: 0.8824 | Acc: 87.64%
   üíæ Best model saved (Val F1: 0.8824)


Epoch 8/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:08<00:00,  2.36s/it, loss=1.5656, ce=0.2191, distill=0.6732]



   Epoch 8 - Task B Val F1: 0.9070 | Acc: 90.32%
   üìà Task A Retention: F1=0.5577 (57.46% of baseline)
   üíæ Best model saved (Val F1: 0.9070)


Epoch 9/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:08<00:00,  2.37s/it, loss=1.5973, ce=0.2829, distill=0.6572]



   Epoch 9 - Task B Val F1: 0.8883 | Acc: 88.28%


Epoch 10/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:08<00:00,  2.37s/it, loss=1.4294, ce=0.2555, distill=0.5869]



   Epoch 10 - Task B Val F1: 0.9011 | Acc: 89.68%
   üìà Task A Retention: F1=0.5603 (57.73% of baseline)


Epoch 11/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:09<00:00,  2.40s/it, loss=1.3337, ce=0.2240, distill=0.5549]



   Epoch 11 - Task B Val F1: 0.9128 | Acc: 90.96%
   üíæ Best model saved (Val F1: 0.9128)


Epoch 12/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:08<00:00,  2.37s/it, loss=1.1249, ce=0.1827, distill=0.4711]



   Epoch 12 - Task B Val F1: 0.9117 | Acc: 90.83%
   üìà Task A Retention: F1=0.5687 (58.59% of baseline)


Epoch 13/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:09<00:00,  2.38s/it, loss=1.1468, ce=0.3129, distill=0.4170]



   Epoch 13 - Task B Val F1: 0.9141 | Acc: 91.08%
   üíæ Best model saved (Val F1: 0.9141)


Epoch 14/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:09<00:00,  2.40s/it, loss=0.9299, ce=0.0598, distill=0.4351]



   Epoch 14 - Task B Val F1: 0.9117 | Acc: 90.83%
   üìà Task A Retention: F1=0.5752 (59.26% of baseline)


Epoch 15/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:08<00:00,  2.37s/it, loss=0.6769, ce=0.1569, distill=0.2600]



   Epoch 15 - Task B Val F1: 0.9106 | Acc: 90.70%


Epoch 16/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:09<00:00,  2.40s/it, loss=0.7077, ce=0.2324, distill=0.2376]



   Epoch 16 - Task B Val F1: 0.9211 | Acc: 91.85%
   üìà Task A Retention: F1=0.5851 (60.29% of baseline)
   üíæ Best model saved (Val F1: 0.9211)


Epoch 17/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:08<00:00,  2.36s/it, loss=0.7431, ce=0.1455, distill=0.2988]



   Epoch 17 - Task B Val F1: 0.9223 | Acc: 91.97%
   üíæ Best model saved (Val F1: 0.9223)


Epoch 18/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:08<00:00,  2.36s/it, loss=0.7627, ce=0.1968, distill=0.2830]



   Epoch 18 - Task B Val F1: 0.9165 | Acc: 91.34%
   üìà Task A Retention: F1=0.5951 (61.31% of baseline)


Epoch 19/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:08<00:00,  2.37s/it, loss=0.5391, ce=0.1027, distill=0.2182]



   Epoch 19 - Task B Val F1: 0.9212 | Acc: 91.85%


Epoch 20/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 29/29 [01:08<00:00,  2.37s/it, loss=0.6072, ce=0.0684, distill=0.2694]



   Epoch 20 - Task B Val F1: 0.9166 | Acc: 91.34%
   üìà Task A Retention: F1=0.6048 (62.31% of baseline)

üìä FINAL EVALUATION

üìä TASK B (Chest X-ray) EVALUATION
   Accuracy:  94.27%
   F1-Score:  0.9439

üìã Classification Report:
              precision    recall  f1-score   support

      NORMAL     0.8405    0.9606    0.8966       203
   PNEUMONIA     0.9855    0.9364    0.9604       582

    accuracy                         0.9427       785
   macro avg     0.9130    0.9485    0.9285       785
weighted avg     0.9480    0.9427    0.9439       785


üìä TASK A (OCT) - Retention Check EVALUATION
   Accuracy:  60.58%
   F1-Score:  0.5902

üìã Classification Report:
              precision    recall  f1-score   support

         CNV     0.9329    0.3962    0.5562      5581
         DME     0.9554    0.2515    0.3981      1702
      DRUSEN     0.2911    0.7989    0.4267      1293
      NORMAL     0.6358    0.9916    0.7748      3947

    accuracy                         0.605