In [1]:
# @title Cell 1: Infrastructure Setup + Landing Page - ORANGE THEME

# File: MMLab_Project_Demo.ipynb - Cell 1
# Location: Thesis_MER_Project/MMLab_Project_Demo.ipynb
# Purpose: Combined infrastructure and landing page - ORANGE THEMED

import os
import sys
import warnings
warnings.filterwarnings('ignore')

print("=" * 80)
print("MULTIMEDIA LABORATORY - TELKOM UNIVERSITY")
print("Interactive Project Demonstrations")
print("=" * 80)

# ============================================================================
# SECTION 1: GOOGLE DRIVE MOUNTING
# ============================================================================
print("\n[1/4] Mounting Google Drive...")
from google.colab import drive
drive.mount('/content/drive', force_remount=False)
print("Google Drive mounted successfully")

# ============================================================================
# SECTION 2: DEPENDENCY INSTALLATION
# ============================================================================
print("\n[2/4] Installing dependencies...")
try:
    import gradio as gr
    from transformers import ViTModel, ViTImageProcessor
    import torch
    import torchvision
    import timm
    print("All packages available")
except ImportError:
    print("Installing missing packages...")
    import subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q",
                          "gradio", "transformers", "timm", "opencv-python"])
    import gradio as gr
    from transformers import ViTModel, ViTImageProcessor
    import torch
    import torchvision
    import timm
    print("Dependencies installed successfully")

# Core imports
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import cv2
import json
import matplotlib.pyplot as plt

# ============================================================================
# SECTION 3: PROJECT CONFIGURATIONS (ORIGINAL PATHS - UNCHANGED)
# ============================================================================
print("\n[3/4] Configuring project paths...")

# Base paths
PROJECT_BASE = "/content/drive/MyDrive"

# HER2 Project paths (ORIGINAL - DO NOT CHANGE)
HER2_PROJECT_ROOT = f"{PROJECT_BASE}/SOURCE_CODE"
HER2_MODELS_ROOT = f"{HER2_PROJECT_ROOT}/Models"

# MER Project paths
MER_PROJECT_ROOT = f"{PROJECT_BASE}/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project"
MER_MODELS_ROOT = f"{MER_PROJECT_ROOT}/models"

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    torch.backends.cudnn.benchmark = True
    print(f"Device: {device} ({gpu_name})")
else:
    print(f"Device: {device}")

# ============================================================================
# SECTION 4: PROJECT METADATA (DUAL STRUCTURE FOR COMPATIBILITY)
# ============================================================================

# Simple PROJECT_INFO for Cell 8 compatibility - UPDATED COLOR
PROJECT_INFO = {
    'lab_name': 'Multimedia Laboratory',
    'university': 'Telkom University',
    'location': 'Bandung, Indonesia',
    'primary_color': '#FF8C00'  # Changed from red to orange
}

# Detailed project information for landing page cards
PROJECT_DETAILS = {
    'her2': {
        'title': 'HER2 Status Classification',
        'subtitle': 'Deep Learning for Gastroesophageal Cancer Diagnosis',
        'description': '''
        Automated classification of HER2 status in gastroesophageal adenocarcinoma
        from tissue microarray images using fusion of CNN and Vision Transformer models.
        ''',
        'features': [
            '4 Model Architectures (MobileNetV3, ViT, Fusion Concat, Fusion Addition)',
            'Dual Preprocessing Variants (Standard & Medical Enhancement)',
            'Interactive Attention Visualization (Grad-CAM & Transformer Attention)',
            'IHC Score (0-3) and HER2 Status (Negative/Positive) Classification'
        ],
        'metrics': {
            'Dataset': 'TMA Images',
            'Best Model': 'Fusion (Concatenation)',
            'Accuracy': '~85% (IHC), ~90% (HER2)',
            'Status': 'Production Ready'
        },
        'thumbnail': 'https://via.placeholder.com/400x250/FF8C00/ffffff?text=HER2+Classification',
        'demo_available': True,
        'publication': 'Under Review - IEEE/ACM Conference'
    },
    'mer': {
        'title': 'Micro-Expression Recognition',
        'subtitle': 'CNN and Transformer Baselines for CASME II Dataset',
        'description': '''
        Comparative study of CNN and Transformer architectures for micro-expression recognition
        with investigation of preprocessing paradox phenomenon.
        ''',
        'features': [
            '6 Model Architectures (3 CNNs + 3 Transformers)',
            'Dual Methodologies (M1 Raw RGB, M2 Preprocessed)',
            '7-Category CASME II Classification (Extreme Class Imbalance)',
            'Real-time Webcam Classification'
        ],
        'metrics': {
            'Dataset': 'CASME II (255 videos)',
            'Best Model': 'MobileNetV3-Small M1',
            'Macro F1': '0.3880',
            'Status': 'Research Phase'
        },
        'thumbnail': 'https://via.placeholder.com/400x250/3b82f6/ffffff?text=Micro-Expression+Recognition',
        'demo_available': True,
        'publication': 'IEEE ICICyTA 2025 (Accepted)'
    }
}

print("Project metadata configured")

# ============================================================================
# SECTION 5: GLOBAL STYLING & THEME - ORANGE THEME
# ============================================================================

# Telkom University inspired theme - ORANGE COLOR SCHEME
MMLAB_THEME = gr.themes.Soft(
    primary_hue="orange",  # Changed from red
    secondary_hue="slate",
    neutral_hue="slate",
    font=gr.themes.GoogleFont("Inter"),
    font_mono=gr.themes.GoogleFont("JetBrains Mono")
).set(
    # Light mode - ORANGE COLORS
    body_background_fill="white",
    body_text_color="#1f2937",
    button_primary_background_fill="#FF8C00",  # Orange
    button_primary_background_fill_hover="#E67E00",  # Darker orange
    button_primary_text_color="white",
    block_background_fill="#f9fafb",
    block_border_color="#e5e7eb",
    block_border_width="1px",
    block_label_background_fill="#FFF4E6",  # Light orange/cream
    block_label_text_color="#CC7A00",  # Dark orange

    # Dark mode - ORANGE COLORS
    body_background_fill_dark="#111827",
    body_text_color_dark="#f3f4f6",
    button_primary_background_fill_dark="#FF8C00",  # Orange
    button_primary_background_fill_hover_dark="#E67E00",  # Darker orange
    block_background_fill_dark="#1f2937",
    block_border_color_dark="#374151",
    block_label_background_fill_dark="#CC7A00",  # Dark orange
    block_label_text_color_dark="#FFB347"  # Light orange
)

# Custom CSS - ORANGE COLOR SCHEME
CUSTOM_CSS = """
.project-card {
    border: 1px solid #e5e7eb;
    border-radius: 12px;
    padding: 24px;
    background: white;
    box-shadow: 0 1px 3px rgba(0,0,0,0.1);
    transition: all 0.3s ease;
}

.project-card:hover {
    box-shadow: 0 10px 25px rgba(255,140,0,0.2);
    transform: translateY(-4px);
}

.project-title {
    font-size: 24px;
    font-weight: 700;
    color: #1f2937;
    margin-bottom: 8px;
}

.project-subtitle {
    font-size: 14px;
    color: #6b7280;
    font-style: italic;
    margin-bottom: 16px;
}

.project-description {
    font-size: 15px;
    line-height: 1.6;
    color: #374151;
    margin-bottom: 20px;
}

.feature-list {
    list-style: none;
    padding: 0;
    margin: 16px 0;
}

.feature-item {
    padding: 8px 0;
    padding-left: 24px;
    position: relative;
    color: #4b5563;
}

.feature-item:before {
    content: "✓";
    position: absolute;
    left: 0;
    color: #FF8C00;
    font-weight: bold;
}

.metrics-grid {
    display: grid;
    grid-template-columns: repeat(2, 1fr);
    gap: 12px;
    margin: 16px 0;
    padding: 16px;
    background: #f9fafb;
    border-radius: 8px;
}

.metric-item {
    padding: 8px;
}

.metric-label {
    font-size: 12px;
    color: #6b7280;
    text-transform: uppercase;
    letter-spacing: 0.5px;
}

.metric-value {
    font-size: 16px;
    font-weight: 600;
    color: #1f2937;
    margin-top: 4px;
}

.status-badge {
    display: inline-block;
    padding: 4px 12px;
    border-radius: 12px;
    font-size: 12px;
    font-weight: 600;
    text-transform: uppercase;
}

.status-ready {
    background: #dcfce7;
    color: #166534;
}

.status-research {
    background: #dbeafe;
    color: #1e40af;
}

.hero-section {
    text-align: center;
    padding: 48px 24px;
    background: linear-gradient(135deg, #FF8C00 0%, #CC7A00 100%);
    color: white;
    border-radius: 16px;
    margin-bottom: 32px;
}

.hero-title {
    font-size: 42px;
    font-weight: 800;
    margin-bottom: 16px;
    text-shadow: 2px 2px 4px rgba(0,0,0,0.2);
}

.hero-subtitle {
    font-size: 18px;
    opacity: 0.95;
    margin-bottom: 8px;
}

.demo-button {
    background: #FF8C00 !important;
    color: white !important;
    border: none !important;
    padding: 12px 24px !important;
    font-weight: 600 !important;
    border-radius: 8px !important;
    cursor: pointer !important;
    transition: all 0.3s ease !important;
}

.demo-button:hover {
    background: #E67E00 !important;
    transform: scale(1.05);
}

.demo-button:disabled {
    background: #9ca3af !important;
    cursor: not-allowed !important;
    transform: none !important;
}

@media (max-width: 768px) {
    .hero-title {
        font-size: 32px;
    }
    .metrics-grid {
        grid-template-columns: 1fr;
    }
}
"""

print("Styling configured")

# ============================================================================
# SECTION 6: HELPER FUNCTIONS FOR PROJECT CARDS
# ============================================================================

def create_project_card_html(project_key):
    """Generate HTML for project card"""
    info = PROJECT_DETAILS[project_key]

    # Features list
    features_html = "".join([
        f'<li class="feature-item">{feature}</li>'
        for feature in info['features']
    ])

    # Metrics grid
    metrics_html = "".join([
        f'''
        <div class="metric-item">
            <div class="metric-label">{label}</div>
            <div class="metric-value">{value}</div>
        </div>
        '''
        for label, value in info['metrics'].items()
    ])

    # Status badge
    status_class = "status-ready" if info['demo_available'] else "status-research"
    status_text = info['metrics']['Status']

    card_html = f'''
    <div class="project-card">
        <div class="project-title">{info['title']}</div>
        <div class="project-subtitle">{info['subtitle']}</div>
        <div class="project-description">{info['description']}</div>

        <div style="margin: 16px 0;">
            <strong style="color: #374151;">Key Features:</strong>
            <ul class="feature-list">
                {features_html}
            </ul>
        </div>

        <div class="metrics-grid">
            {metrics_html}
        </div>

        <div style="margin-top: 16px; display: flex; justify-content: space-between; align-items: center;">
            <span class="status-badge {status_class}">{status_text}</span>
            <span style="font-size: 12px; color: #6b7280;">{info['publication']}</span>
        </div>
    </div>
    '''

    return card_html

# ============================================================================
# SECTION 7: TAB NAVIGATION STATE MANAGEMENT
# ============================================================================

current_tab_state = gr.State(value="landing")

def switch_to_her2_demo():
    """Switch to HER2 demo tab"""
    return 1

def switch_to_landing():
    """Switch back to landing page"""
    return 0

print("\n[4/4] Infrastructure setup complete")
print("=" * 80)
print("Ready to build interface")
print("=" * 80)

# ============================================================================
# SECTION 8: GRADIO INTERFACE STRUCTURE
# ============================================================================

print("\nBuilding Gradio interface...")

with gr.Blocks(
    title="Multimedia Laboratory - Telkom University",
    theme=MMLAB_THEME,
    css=CUSTOM_CSS
) as demo:

    with gr.Tabs() as main_tabs:

        # Tab 1: Landing Page
        with gr.Tab("Projects", id=0):
            gr.HTML("""
            <div class="hero-section">
                <div class="hero-title">Multimedia Laboratory</div>
                <div class="hero-subtitle">School of Computing - Telkom University</div>
                <div style="margin-top: 16px; font-size: 16px; opacity: 0.9;">
                    Advancing AI Research in Medical Imaging and Affective Computing
                </div>
            </div>
            """)

            gr.Markdown("## Featured Research Projects")

            gr.Markdown("""
            Explore our cutting-edge deep learning research projects.
            Each project demonstrates practical applications of AI in healthcare and human behavior analysis.
            """)

            with gr.Row(equal_height=True):
                with gr.Column(scale=1):
                    gr.HTML(create_project_card_html('her2'))

                    her2_demo_btn = gr.Button(
                        "Try HER2 Demo",
                        variant="primary",
                        size="lg",
                        elem_classes="demo-button"
                    )

                with gr.Column(scale=1):
                    gr.HTML(create_project_card_html('mer'))

                    mer_demo_btn = gr.Button(
                        "Try MER Demo",
                        variant="primary",
                        size="lg",
                        elem_classes="demo-button"
                    )

            gr.Markdown("---")

            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("""
                    ### About the Lab

                    The Multimedia Laboratory focuses on developing intelligent systems
                    for medical image analysis, affective computing, and human-computer interaction.

                    **Research Areas:**
                    - Medical Image Classification
                    - Micro-Expression Recognition
                    - Deep Learning Architectures
                    - Multimodal Fusion Systems
                    """)

                with gr.Column(scale=1):
                    gr.Markdown("""
                    ### Contact and Resources

                    **Location:**
                    School of Computing, Telkom University
                    Bandung, Indonesia

                    **Collaboration Inquiries:**
                    For research collaboration or dataset access,
                    please contact the lab coordinator.

                    **Publications:**
                    Visit our publications page for latest research outputs.
                    """)

        # Tab 2: HER2 Demo (Placeholder - will be filled by Cell 4)
        with gr.Tab("HER2 Classification Demo", id=1):
            gr.Markdown("""
            # HER2 Status Classification

            **Note:** This demo will be activated by Cell 4.
            Please run Cell 2, Cell 3, and Cell 4 to enable full functionality.

            The HER2 demo provides:
            - Medical image upload and preprocessing
            - Model selection (MobileNetV3, ViT, Fusion models)
            - Real-time classification with confidence scores
            - Attention visualization (Grad-CAM and Transformer attention)
            """)

            her2_placeholder = gr.Markdown("""
            ### Initializing Demo Components...

            Run the following cells to activate:
            - Cell 2: Model Loading and Preprocessing Functions
            - Cell 3: Interpretability Functions
            - Cell 4: Interactive Interface
            """)

        # Tab 3: MER Demo (Placeholder - will be filled by Cell 8)
        with gr.Tab("MER Analysis Demo", id=2):
            gr.Markdown("""
            # Micro-Expression Recognition Analysis

            **Note:** This demo will be activated by Cell 8.
            Please run Cells 5-8 to enable full functionality.

            The MER demo provides:
            - Webcam capture for real-time classification
            - 6 model architectures (3 CNNs + 3 Transformers)
            - Dual preprocessing methodologies (M1 Raw, M2 Preprocessed)
            - 7-class emotion probabilities
            """)

    # Event Handlers
    her2_demo_btn.click(
        fn=lambda: gr.Tabs.update(selected=1),
        inputs=None,
        outputs=main_tabs
    )

    mer_demo_btn.click(
        fn=lambda: gr.Tabs.update(selected=2),
        inputs=None,
        outputs=main_tabs
    )

print("Landing page interface built")
print("\n" + "=" * 80)
print("CELL 1 COMPLETE - ORANGE THEME")
print("=" * 80)
print("\nChanges from red theme:")
print("  ✓ Primary color: #dc2626 → #FF8C00 (Orange)")
print("  ✓ Dark color: #991b1b → #CC7A00 (Dark Orange)")
print("  ✓ Hover color: #b91c1c → #E67E00 (Orange hover)")
print("  ✓ Light backgrounds updated to cream/light orange")
print("  ✓ Gradient updated to orange tones")
print("  ✓ All CSS classes updated with orange colors")
print("\nNext steps:")
print("1. Run Cell 2: HER2 Model Loading and Preprocessing")
print("2. Run Cell 3: HER2 Interpretability Functions")
print("3. Run Cell 4: HER2 Interactive Interface")
print("4. Run Cell 5: MER Models (With Transformers)")
print("5. Run Cell 6: MER Preprocessing")
print("6. Run Cell 7: MER Interface")
print("7. Run Cell 8: Launch Integrated Demo")
print("=" * 80)

MULTIMEDIA LABORATORY - TELKOM UNIVERSITY
Interactive Project Demonstrations

[1/4] Mounting Google Drive...
Mounted at /content/drive
Google Drive mounted successfully

[2/4] Installing dependencies...




All packages available

[3/4] Configuring project paths...
Device: cpu
Project metadata configured
Styling configured

[4/4] Infrastructure setup complete
Ready to build interface

Building Gradio interface...
Landing page interface built

CELL 1 COMPLETE - ORANGE THEME

Changes from red theme:
  ✓ Primary color: #dc2626 → #FF8C00 (Orange)
  ✓ Dark color: #991b1b → #CC7A00 (Dark Orange)
  ✓ Hover color: #b91c1c → #E67E00 (Orange hover)
  ✓ Light backgrounds updated to cream/light orange
  ✓ Gradient updated to orange tones
  ✓ All CSS classes updated with orange colors

Next steps:
1. Run Cell 2: HER2 Model Loading and Preprocessing
2. Run Cell 3: HER2 Interpretability Functions
3. Run Cell 4: HER2 Interactive Interface
4. Run Cell 5: MER Models (With Transformers)
5. Run Cell 6: MER Preprocessing
6. Run Cell 7: MER Interface
7. Run Cell 8: Launch Integrated Demo


In [2]:
# @title Cell 2: HER2 Model Loading & Preprocessing Functions

# File: MMLab_Project_Demo.ipynb - Cell 2
# Location: Thesis_MER_Project/MMLab_Project_Demo.ipynb
# Purpose: HER2 model architectures, preprocessing, and loading functions

import pickle
from typing import Tuple, Optional, Dict, Any

print("=" * 80)
print("HER2 PROJECT - MODEL LOADING & PREPROCESSING")
print("=" * 80)

# ============================================================================
# SECTION 1: HER2 CHECKPOINT PATHS
# ============================================================================

print("\n[1/6] Configuring HER2 model checkpoints...")

HER2_CHECKPOINT_PATHS = {
    # Original preprocessing models (8 models)
    'MobileNetV3_IHC_orig': f"{HER2_MODELS_ROOT}/MobileNetV3/ihc_mobilenetv3_orig_WOA_val_F1.pth",
    'MobileNetV3_HER2_orig': f"{HER2_MODELS_ROOT}/MobileNetV3/her2_mobilenetv3_orig_WOA_val_F1.pth",
    'ViT_IHC_orig': f"{HER2_MODELS_ROOT}/ViT/ihc_vit_orig_WOA_val_F1.pth",
    'ViT_HER2_orig': f"{HER2_MODELS_ROOT}/ViT/her2_vit_orig_WOA_val_F1.pth",
    'FusionConcat_IHC_orig': f"{HER2_MODELS_ROOT}/Fusion Concat/ihc_concat_mobilenetv3_vit_orig_WOA_val_F1.pth",
    'FusionConcat_HER2_orig': f"{HER2_MODELS_ROOT}/Fusion Concat/her2_concat_mobilenetv3_vit_orig_WOA_val_F1.pth",
    'FusionAddition_IHC_orig': f"{HER2_MODELS_ROOT}/Fusion Addition/ihc_addition_mobilenetv3_vit_orig_WOA_val_F1.pth",
    'FusionAddition_HER2_orig': f"{HER2_MODELS_ROOT}/Fusion Addition/her2_addition_mobilenetv3_vit_orig_WOA_val_F1.pth",

    # Medical preprocessing models (8 models)
    'MobileNetV3_IHC_prep': f"{HER2_MODELS_ROOT}/MobileNetV3/ihc_mobilenetv3_prep_WOA_val_F1.pth",
    'MobileNetV3_HER2_prep': f"{HER2_MODELS_ROOT}/MobileNetV3/her2_mobilenetv3_prep_WOA_val_F1.pth",
    'ViT_IHC_prep': f"{HER2_MODELS_ROOT}/ViT/ihc_vit_prep_WOA_val_F1.pth",
    'ViT_HER2_prep': f"{HER2_MODELS_ROOT}/ViT/her2_vit_prep_WOA_val_F1.pth",
    'FusionConcat_IHC_prep': f"{HER2_MODELS_ROOT}/Fusion Concat/ihc_concat_mobilenetv3_vit_prep_WOA_val_F1.pth",
    'FusionConcat_HER2_prep': f"{HER2_MODELS_ROOT}/Fusion Concat/her2_concat_mobilenetv3_vit_prep_WOA_val_F1.pth",
    'FusionAddition_IHC_prep': f"{HER2_MODELS_ROOT}/Fusion Addition/ihc_addition_mobilenetv3_vit_prep_WOA_val_F1.pth",
    'FusionAddition_HER2_prep': f"{HER2_MODELS_ROOT}/Fusion Addition/her2_addition_mobilenetv3_vit_prep_WOA_val_F1.pth"
}

print(f"✓ Configured {len(HER2_CHECKPOINT_PATHS)} model checkpoints")
print(f"  - Original preprocessing: 8 models")
print(f"  - Medical preprocessing: 8 models")

# ============================================================================
# SECTION 2: HER2 PREPROCESSING CONFIGURATION
# ============================================================================

print("\n[2/6] Setting up preprocessing pipelines...")

# Image settings
TARGET_SIZE = 1024
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Standard transform for model input
standard_transform = transforms.Compose([
    transforms.Resize((TARGET_SIZE, TARGET_SIZE), interpolation=transforms.InterpolationMode.LANCZOS),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
])

# ViT processor
vit_processor = ViTImageProcessor.from_pretrained(
    'google/vit-base-patch32-224-in21k',
    do_resize=False,
    do_normalize=True,
    do_rescale=True,
    do_center_crop=False
)

print("✓ Preprocessing pipelines configured")

# ============================================================================
# SECTION 3: HER2 TASK CONFIGURATIONS
# ============================================================================

HER2_TASK_CONFIGS = {
    'IHC': {
        'full_name': 'IHC Score Classification',
        'num_classes': 4,
        'class_names': ['IHC_0', 'IHC_1', 'IHC_2', 'IHC_3'],
        'description': 'Immunohistochemistry score (0-3)'
    },
    'HER2': {
        'full_name': 'HER2 Status Classification',
        'num_classes': 2,
        'class_names': ['HER2_negative', 'HER2_positive'],
        'description': 'HER2 receptor status'
    }
}

print(f"✓ Task configurations: {list(HER2_TASK_CONFIGS.keys())}")

# ============================================================================
# SECTION 4: HER2 MODEL ARCHITECTURES
# ============================================================================

print("\n[3/6] Defining model architectures...")

class MobileNetV3_Architecture(nn.Module):
    """MobileNetV3 baseline for medical image classification"""

    def __init__(self, num_classes, dropout_rate=0.2):
        super(MobileNetV3_Architecture, self).__init__()

        self.mobilenet = timm.create_model(
            'mobilenetv3_large_100',
            pretrained=True,
            num_classes=0,
            global_pool='avg'
        )

        for param in self.mobilenet.parameters():
            param.requires_grad = True

        self.mobilenet_feature_dim = 1280

        self.classifier_layers = nn.Sequential(
            nn.Linear(self.mobilenet_feature_dim, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Dropout(dropout_rate)
        )

        self.classifier = nn.Linear(128, num_classes)

    def forward(self, x):
        mobilenet_features = self.mobilenet(x)
        processed_features = self.classifier_layers(mobilenet_features)
        output = self.classifier(processed_features)
        return output


class ViT_Architecture(nn.Module):
    """Vision Transformer with attention extraction support"""

    def __init__(self, num_classes, dropout_rate=0.2):
        super(ViT_Architecture, self).__init__()

        self.vit = ViTModel.from_pretrained(
            'google/vit-base-patch32-224-in21k',
            add_pooling_layer=False,
            output_attentions=True
        )

        self.vit.config.output_attentions = True

        for param in self.vit.parameters():
            param.requires_grad = True

        self.vit_feature_dim = self.vit.config.hidden_size

        self.classifier_layers = nn.Sequential(
            nn.Linear(self.vit_feature_dim, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Dropout(dropout_rate)
        )

        self.classifier = nn.Linear(128, num_classes)

    def forward(self, pixel_values, output_attentions=False):
        vit_outputs = self.vit(
            pixel_values=pixel_values,
            interpolate_pos_encoding=True,
            output_attentions=output_attentions
        )
        vit_features = vit_outputs.last_hidden_state[:, 0]
        processed_features = self.classifier_layers(vit_features)
        output = self.classifier(processed_features)

        if output_attentions:
            return output, vit_outputs.attentions
        return output


class FusionConcat_Architecture(nn.Module):
    """Fusion model with concatenation strategy"""

    def __init__(self, num_classes, dropout_rate=0.2):
        super(FusionConcat_Architecture, self).__init__()

        self.mobilenet = timm.create_model(
            'mobilenetv3_large_100',
            pretrained=True,
            num_classes=0,
            global_pool='avg'
        )

        self.vit = ViTModel.from_pretrained(
            'google/vit-base-patch32-224-in21k',
            add_pooling_layer=False,
            output_attentions=True
        )

        self.vit.config.output_attentions = True

        for param in self.mobilenet.parameters():
            param.requires_grad = True
        for param in self.vit.parameters():
            param.requires_grad = True

        self.mobilenet_dim = 1280
        self.vit_dim = self.vit.config.hidden_size
        self.projection_dim = 512

        self.mobilenet_projector = nn.Sequential(
            nn.Linear(self.mobilenet_dim, self.projection_dim),
            nn.LayerNorm(self.projection_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate)
        )

        self.vit_projector = nn.Sequential(
            nn.Linear(self.vit_dim, self.projection_dim),
            nn.LayerNorm(self.projection_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate)
        )

        self.fusion_classifier = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Dropout(dropout_rate)
        )

        self.classifier = nn.Linear(128, num_classes)

    def forward(self, images, pixel_values, return_features=False):
        mobilenet_features = self.mobilenet(images)
        mobilenet_projected = self.mobilenet_projector(mobilenet_features)

        vit_outputs = self.vit(
            pixel_values=pixel_values,
            interpolate_pos_encoding=True,
            output_attentions=True
        )
        vit_features = vit_outputs.last_hidden_state[:, 0]
        vit_projected = self.vit_projector(vit_features)

        mobilenet_weighted = mobilenet_projected * 0.5
        vit_weighted = vit_projected * 0.5
        fused_features = torch.cat([mobilenet_weighted, vit_weighted], dim=1)

        processed_features = self.fusion_classifier(fused_features)
        output = self.classifier(processed_features)

        if return_features:
            feature_dict = {
                'mobilenet_features': mobilenet_features,
                'vit_features': vit_features,
                'mobilenet_projected': mobilenet_projected,
                'vit_projected': vit_projected,
                'mobilenet_weighted': mobilenet_weighted,
                'vit_weighted': vit_weighted,
                'fused_features': fused_features,
                'vit_attentions': vit_outputs.attentions
            }
            return output, feature_dict

        return output


class FusionAddition_Architecture(nn.Module):
    """Fusion model with addition strategy"""

    def __init__(self, num_classes, dropout_rate=0.2):
        super(FusionAddition_Architecture, self).__init__()

        self.mobilenet = timm.create_model(
            'mobilenetv3_large_100',
            pretrained=True,
            num_classes=0,
            global_pool='avg'
        )

        self.vit = ViTModel.from_pretrained(
            'google/vit-base-patch32-224-in21k',
            add_pooling_layer=False,
            output_attentions=True
        )

        self.vit.config.output_attentions = True

        for param in self.mobilenet.parameters():
            param.requires_grad = True
        for param in self.vit.parameters():
            param.requires_grad = True

        self.mobilenet_dim = 1280
        self.vit_dim = self.vit.config.hidden_size
        self.projection_dim = 512

        self.mobilenet_projector = nn.Sequential(
            nn.Linear(self.mobilenet_dim, self.projection_dim),
            nn.LayerNorm(self.projection_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate)
        )

        self.vit_projector = nn.Sequential(
            nn.Linear(self.vit_dim, self.projection_dim),
            nn.LayerNorm(self.projection_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate)
        )

        self.fusion_classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Dropout(dropout_rate)
        )

        self.classifier = nn.Linear(128, num_classes)

    def forward(self, images, pixel_values, return_features=False):
        mobilenet_features = self.mobilenet(images)
        mobilenet_projected = self.mobilenet_projector(mobilenet_features)

        vit_outputs = self.vit(
            pixel_values=pixel_values,
            interpolate_pos_encoding=True,
            output_attentions=True
        )
        vit_features = vit_outputs.last_hidden_state[:, 0]
        vit_projected = self.vit_projector(vit_features)

        mobilenet_weighted = mobilenet_projected * 0.5
        vit_weighted = vit_projected * 0.5
        fused_features = mobilenet_weighted + vit_weighted

        processed_features = self.fusion_classifier(fused_features)
        output = self.classifier(processed_features)

        if return_features:
            feature_dict = {
                'mobilenet_features': mobilenet_features,
                'vit_features': vit_features,
                'mobilenet_projected': mobilenet_projected,
                'vit_projected': vit_projected,
                'mobilenet_weighted': mobilenet_weighted,
                'vit_weighted': vit_weighted,
                'fused_features': fused_features,
                'vit_attentions': vit_outputs.attentions
            }
            return output, feature_dict

        return output


# Model registry
HER2_MODEL_REGISTRY = {
    'MobileNetV3': {
        'architecture_class': MobileNetV3_Architecture,
        'display_name': 'MobileNetV3-Large',
        'requires_dual_input': False,
        'supports_attention': True,
        'attention_method': 'grad_cam'
    },
    'ViT': {
        'architecture_class': ViT_Architecture,
        'display_name': 'Vision Transformer',
        'requires_dual_input': False,
        'supports_attention': True,
        'attention_method': 'native_transformer'
    },
    'FusionConcat': {
        'architecture_class': FusionConcat_Architecture,
        'display_name': 'Fusion (Concatenation)',
        'requires_dual_input': True,
        'supports_attention': True,
        'attention_method': 'dual_branch'
    },
    'FusionAddition': {
        'architecture_class': FusionAddition_Architecture,
        'display_name': 'Fusion (Addition)',
        'requires_dual_input': True,
        'supports_attention': True,
        'attention_method': 'dual_branch'
    }
}

print(f"✓ Defined {len(HER2_MODEL_REGISTRY)} model architectures")

# ============================================================================
# SECTION 5: PREPROCESSING FUNCTIONS
# ============================================================================

print("\n[4/6] Defining preprocessing functions...")

def apply_standard_preprocessing(image: Image.Image) -> Image.Image:
    """Apply standard preprocessing: resize to 1024px with LANCZOS"""
    if image.size != (TARGET_SIZE, TARGET_SIZE):
        image = image.resize(
            (TARGET_SIZE, TARGET_SIZE),
            Image.Resampling.LANCZOS
        )
    return image


def apply_medical_preprocessing(image: Image.Image) -> Image.Image:
    """Apply medical preprocessing: CLAHE + tissue mask + background removal"""
    img_array = np.array(image)

    if img_array.shape[:2] != (TARGET_SIZE, TARGET_SIZE):
        img_array = cv2.resize(
            img_array,
            (TARGET_SIZE, TARGET_SIZE),
            interpolation=cv2.INTER_LANCZOS4
        )

    hsv_image = cv2.cvtColor(img_array, cv2.COLOR_RGB2HSV)

    hsv_lower = np.array([0, 5, 5], dtype=np.uint8)
    hsv_upper = np.array([180, 255, 250], dtype=np.uint8)
    mask = cv2.inRange(hsv_image, hsv_lower, hsv_upper)

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

    enhanced_channels = []
    for channel in range(3):
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        enhanced = clahe.apply(img_array[:, :, channel].astype(np.uint8))
        enhanced_channels.append(enhanced)

    enhanced_img = np.stack(enhanced_channels, axis=2)
    enhanced_img[mask == 0] = [240, 240, 240]

    return Image.fromarray(enhanced_img.astype(np.uint8))


def prepare_model_input(
    image: Image.Image,
    model_name: str
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Prepare model input tensors"""
    if image.mode != 'RGB':
        image = image.convert('RGB')

    processed_image = apply_standard_preprocessing(image)

    model_info = HER2_MODEL_REGISTRY[model_name]
    requires_dual = model_info['requires_dual_input']

    if requires_dual:
        standard_tensor = standard_transform(processed_image).unsqueeze(0).to(device)
        vit_inputs = vit_processor(processed_image, return_tensors="pt")
        vit_tensor = vit_inputs['pixel_values'].to(device)
        return standard_tensor, vit_tensor
    else:
        if model_name == 'ViT':
            vit_inputs = vit_processor(processed_image, return_tensors="pt")
            vit_tensor = vit_inputs['pixel_values'].to(device)
            return vit_tensor, None
        else:
            standard_tensor = standard_transform(processed_image).unsqueeze(0).to(device)
            return standard_tensor, None


print("✓ Preprocessing functions defined")

# ============================================================================
# SECTION 6: MODEL LOADING & INFERENCE
# ============================================================================

print("\n[5/6] Defining model loading functions...")

def load_her2_model_checkpoint(
    model_name: str,
    task: str,
    num_classes: int,
    preprocessing_variant: str = 'orig'
) -> Tuple[nn.Module, Dict[str, Any]]:
    """Load HER2 model checkpoint with variant support"""

    if preprocessing_variant not in ['orig', 'prep']:
        raise ValueError(f"Invalid preprocessing_variant: {preprocessing_variant}")

    checkpoint_key = f"{model_name}_{task}_{preprocessing_variant}"

    if checkpoint_key not in HER2_CHECKPOINT_PATHS:
        raise KeyError(f"Checkpoint not found: {checkpoint_key}")

    checkpoint_path = HER2_CHECKPOINT_PATHS[checkpoint_key]

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    print(f"Loading {model_name} for {task} task ({preprocessing_variant})...")

    checkpoint = None
    try:
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        loading_method = "standard"
    except Exception:
        try:
            with open(checkpoint_path, 'rb') as f:
                checkpoint = pickle.load(f)
            loading_method = "pickle"
        except Exception:
            checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
            loading_method = "weights_only_false"

    architecture_class = HER2_MODEL_REGISTRY[model_name]['architecture_class']
    model = architecture_class(num_classes=num_classes, dropout_rate=0.2).to(device)

    state_dict = checkpoint.get('model_state_dict', checkpoint)

    try:
        model.load_state_dict(state_dict, strict=True)
        load_status = "strict"
    except Exception:
        model.load_state_dict(state_dict, strict=False)
        load_status = "non-strict"

    model.eval()

    training_info = {
        'best_val_f1': float(checkpoint.get('best_f1', 0.0)),
        'best_epoch': int(checkpoint.get('epoch', 0)) + 1,
        'checkpoint_file': os.path.basename(checkpoint_path),
        'loading_method': loading_method,
        'load_status': load_status,
        'model_type': model_name,
        'task': task,
        'num_classes': num_classes,
        'preprocessing_variant': preprocessing_variant
    }

    print(f"  ✓ Loaded ({load_status}), Val F1: {training_info['best_val_f1']:.4f}")

    return model, training_info


def run_her2_inference(
    model: nn.Module,
    model_name: str,
    input_tensor: torch.Tensor,
    vit_tensor: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run HER2 model inference"""
    model.eval()

    with torch.no_grad():
        model_info = HER2_MODEL_REGISTRY[model_name]

        if model_info['requires_dual_input']:
            if vit_tensor is None:
                raise ValueError(f"{model_name} requires dual input")
            outputs = model(input_tensor, vit_tensor)
        elif model_name == 'ViT':
            outputs = model(input_tensor)
        else:
            outputs = model(input_tensor)

        if isinstance(outputs, torch.Tensor):
            logits = outputs
        elif isinstance(outputs, (tuple, list)):
            logits = outputs[0]
        else:
            logits = outputs

        probabilities = torch.softmax(logits, dim=1)
        predictions = torch.argmax(probabilities, dim=1)

    return predictions, probabilities


def format_her2_results(
    predictions: torch.Tensor,
    probabilities: torch.Tensor,
    task: str
) -> Dict[str, Any]:
    """Format HER2 prediction results"""
    pred_class_idx = predictions.item()
    class_probs = probabilities[0].cpu().numpy()

    task_config = HER2_TASK_CONFIGS[task]
    class_names = task_config['class_names']

    results = {
        'predicted_class': class_names[pred_class_idx],
        'predicted_index': pred_class_idx,
        'confidence': float(class_probs[pred_class_idx]),
        'all_probabilities': {
            class_names[i]: float(class_probs[i])
            for i in range(len(class_names))
        },
        'task': task,
        'task_description': task_config['description']
    }

    return results


# Model cache system
class HER2ModelCache:
    """Cache for loaded HER2 models"""

    def __init__(self):
        self.cache = {}
        self.stats = {'hits': 0, 'misses': 0}

    def get(self, model_name, task, num_classes, preprocessing_variant='orig'):
        cache_key = f"{model_name}_{task}_{num_classes}_{preprocessing_variant}"
        if cache_key in self.cache:
            self.stats['hits'] += 1
            print(f"  ✓ Cache hit: {cache_key}")
            return self.cache[cache_key]
        else:
            self.stats['misses'] += 1
            return None

    def set(self, model_name, task, model, info):
        num_classes = info.get('num_classes', 2)
        preprocessing_variant = info.get('preprocessing_variant', 'orig')
        cache_key = f"{model_name}_{task}_{num_classes}_{preprocessing_variant}"
        self.cache[cache_key] = (model, info)
        print(f"  ✓ Cached: {cache_key}")

    def clear(self):
        self.cache.clear()
        self.stats = {'hits': 0, 'misses': 0}
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

her2_model_cache = HER2ModelCache()

print("✓ Model loading functions defined")

# ============================================================================
# COMPLETION
# ============================================================================

print("\n[6/6] Verification...")
print(f"✓ Checkpoints: {len(HER2_CHECKPOINT_PATHS)}")
print(f"✓ Architectures: {len(HER2_MODEL_REGISTRY)}")
print(f"✓ Tasks: {len(HER2_TASK_CONFIGS)}")
print(f"✓ Preprocessing: standard + medical")

print("\n" + "=" * 80)
print("CELL 2 COMPLETE - HER2 MODEL SYSTEM READY")
print("=" * 80)
print("\nNext: Run Cell 3 for interpretability functions")
print("=" * 80)

HER2 PROJECT - MODEL LOADING & PREPROCESSING

[1/6] Configuring HER2 model checkpoints...
✓ Configured 16 model checkpoints
  - Original preprocessing: 8 models
  - Medical preprocessing: 8 models

[2/6] Setting up preprocessing pipelines...


preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

✓ Preprocessing pipelines configured
✓ Task configurations: ['IHC', 'HER2']

[3/6] Defining model architectures...
✓ Defined 4 model architectures

[4/6] Defining preprocessing functions...
✓ Preprocessing functions defined

[5/6] Defining model loading functions...
✓ Model loading functions defined

[6/6] Verification...
✓ Checkpoints: 16
✓ Architectures: 4
✓ Tasks: 2
✓ Preprocessing: standard + medical

CELL 2 COMPLETE - HER2 MODEL SYSTEM READY

Next: Run Cell 3 for interpretability functions


In [3]:
# @title Cell 3: HER2 Model Interpretability Functions

# File: MMLab_Project_Demo.ipynb - Cell 3
# Location: Thesis_MER_Project/MMLab_Project_Demo.ipynb
# Purpose: Attention visualization and interpretability for HER2 models

import matplotlib.cm as cm
import torch.nn.functional as F

print("=" * 80)
print("HER2 PROJECT - MODEL INTERPRETABILITY")
print("=" * 80)

# ============================================================================
# SECTION 1: GRAD-CAM EXTRACTOR (MobileNetV3)
# ============================================================================

print("\n[1/4] Defining Grad-CAM extractor...")

class GradCAMExtractor:
    """Grad-CAM implementation for MobileNetV3 attention visualization"""

    def __init__(self, model, device, target_layer_names=None, max_layers=3):
        self.model = model
        self.device = device if isinstance(device, torch.device) else torch.device(device)
        self.model.to(self.device).eval()

        self.hooks = []
        self.activations = {}
        self.gradients = {}

        if target_layer_names:
            self.target_layers = self._resolve_layer_names(target_layer_names)
        else:
            self.target_layers = self._auto_detect_layers(max_layers=max_layers)

        if len(self.target_layers) == 0:
            raise RuntimeError("No suitable target layers found")

        print(f"  ✓ Initialized with {len(self.target_layers)} target layers")

        for name, module in self.target_layers.items():
            h = module.register_forward_hook(self._make_forward_hook(name))
            self.hooks.append(h)

    def _resolve_layer_names(self, names):
        found = {}
        for n in names:
            layer = self._get_layer_by_path(n)
            if layer is not None:
                found[n] = layer
        return found

    def _get_layer_by_path(self, path):
        try:
            parts = path.split('.')
            cur = self.model
            for p in parts:
                if p.isdigit():
                    cur = cur[int(p)]
                else:
                    cur = getattr(cur, p)
            return cur
        except Exception:
            return None

    def _auto_detect_layers(self, max_layers=3):
        candidates = [
            'mobilenet.blocks.4', 'mobilenet.blocks.5', 'mobilenet.blocks.6',
            'mobilenet.conv_head', 'mobilenet.conv_stem'
        ]

        found = {}
        for name in candidates:
            layer = self._get_layer_by_path(name)
            if layer is not None:
                last_conv = self._find_last_conv_in_module(layer, prefix=name)
                if last_conv:
                    found[last_conv[0]] = last_conv[1]
                elif isinstance(layer, nn.Conv2d):
                    found[name] = layer
            if len(found) >= max_layers:
                break

        if len(found) == 0:
            last_conv = self._find_last_conv_in_module(self.model, prefix='')
            if last_conv:
                found[last_conv[0]] = last_conv[1]

        limited = {}
        for i, (k, v) in enumerate(found.items()):
            if i >= max_layers:
                break
            limited[k] = v

        return limited

    def _find_last_conv_in_module(self, module, prefix=''):
        last = None
        for name, child in module.named_modules():
            full_name = f"{prefix}.{name}" if prefix and name else (name or prefix)
            if isinstance(child, nn.Conv2d):
                last = (full_name.strip('.'), child)
        return last

    def _make_forward_hook(self, layer_name):
        def hook(module, input, output):
            act = output
            self.activations[layer_name] = act

            def _grad_hook(grad):
                self.gradients[layer_name] = grad

            try:
                if torch.is_tensor(act):
                    act.register_hook(_grad_hook)
                elif isinstance(act, (tuple, list)) and len(act) > 0 and torch.is_tensor(act[0]):
                    act[0].register_hook(_grad_hook)
            except Exception:
                pass

        return hook

    def _clear(self):
        self.activations.clear()
        self.gradients.clear()

    def generate_cam(self, input_tensor, target_class):
        self._clear()

        input_tensor = input_tensor.to(self.device)
        self.model.zero_grad()
        output = self.model(input_tensor)

        if isinstance(output, (tuple, list)):
            logits = output[0]
        else:
            logits = output

        score = logits[0, int(target_class)]
        score.backward(retain_graph=False)

        layer_cams = []
        layer_weights = []

        for name in self.target_layers.keys():
            if name in self.activations and name in self.gradients:
                act = self.activations[name]
                grad = self.gradients[name]

                if isinstance(act, (tuple, list)):
                    act = act[0]
                if isinstance(grad, (tuple, list)):
                    grad = grad[0]

                activations = act[0]
                gradients = grad[0]

                weights = torch.mean(gradients.view(gradients.size(0), -1), dim=1)

                cam = torch.zeros(activations.shape[1:], dtype=torch.float32, device=activations.device)
                for c, w in enumerate(weights):
                    cam += w * activations[c]

                cam = F.relu(cam)
                if cam.max() > 0:
                    cam = cam - cam.min()
                    cam = cam / (cam.max() + 1e-8)

                layer_cams.append(cam.detach().cpu().numpy())
                layer_weights.append(torch.mean(torch.abs(gradients)).item())

        if len(layer_cams) == 0:
            raise RuntimeError("No CAMs generated")

        if len(layer_cams) == 1:
            final_cam = layer_cams[0]
        else:
            weights = np.array(layer_weights)
            if weights.sum() == 0:
                weights = np.ones_like(weights)
            weights = weights / weights.sum()

            target_shape = max(cam.shape for cam in layer_cams)
            resized_cams = []
            for cam in layer_cams:
                if cam.shape != target_shape:
                    cam_resized = cv2.resize(cam, (target_shape[1], target_shape[0]),
                                           interpolation=cv2.INTER_LINEAR)
                    resized_cams.append(cam_resized)
                else:
                    resized_cams.append(cam)

            final_cam = np.zeros_like(resized_cams[0])
            for cam, w in zip(resized_cams, weights):
                final_cam += w * cam

        final_cam = np.clip(final_cam, 0, 1)
        if final_cam.max() > 0:
            final_cam = final_cam / final_cam.max()

        final_cam = cv2.GaussianBlur(final_cam, (3, 3), 0)

        if final_cam.max() > 0:
            final_cam = final_cam / final_cam.max()

        return final_cam

    def release(self):
        for h in self.hooks:
            try:
                h.remove()
            except Exception:
                pass
        self.hooks = []


print("✓ Grad-CAM extractor defined")

# ============================================================================
# SECTION 2: VIT ATTENTION EXTRACTOR
# ============================================================================

print("\n[2/4] Defining ViT attention extractor...")

class ViTAttentionExtractor:
    """Native ViT attention extraction and visualization"""

    def __init__(self, model, device, patch_size=32, image_size=1024):
        self.model = model
        self.device = device
        self.patch_size = patch_size
        self.image_size = image_size
        self.num_patches_per_side = image_size // patch_size
        self.num_patches = self.num_patches_per_side ** 2

        self.model.eval()
        print(f"  ✓ Initialized for {self.num_patches_per_side}x{self.num_patches_per_side} patches")

    def extract_attention(self, pixel_values):
        pixel_values = pixel_values.to(self.device)

        with torch.no_grad():
            try:
                if hasattr(self.model, 'vit'):
                    outputs = self.model.vit(
                        pixel_values=pixel_values,
                        interpolate_pos_encoding=True,
                        output_attentions=True,
                        return_dict=True
                    )
                    attentions = outputs.attentions
                else:
                    output, attentions = self.model(pixel_values, output_attentions=True)

                if attentions is None:
                    raise RuntimeError("Model did not return attention weights")

                return attentions

            except Exception as e:
                raise RuntimeError(f"Attention extraction failed: {e}")

    def compute_attention_map(self, attentions):
        if attentions is None:
            return None

        num_layers = min(4, len(attentions))
        selected_layers = attentions[-num_layers:]

        layer_attentions = []
        layer_weights = [0.4, 0.3, 0.2, 0.1][:num_layers]

        for layer_att in selected_layers:
            layer_att = layer_att[0]

            head_vars = torch.var(layer_att.view(layer_att.size(0), -1), dim=1)
            num_keep = max(1, layer_att.size(0) // 4)
            top_heads = torch.argsort(head_vars, descending=True)[:num_keep]

            selected_att = layer_att[top_heads].mean(dim=0)
            cls_att = selected_att[0, 1:]
            layer_attentions.append(cls_att)

        weights = np.array(layer_weights[:len(layer_attentions)])
        weights = weights / weights.sum()

        combined_att = torch.zeros_like(layer_attentions[0])
        for att, w in zip(layer_attentions, weights):
            if not torch.is_tensor(att):
                att = torch.tensor(att, device=combined_att.device, dtype=combined_att.dtype)
            combined_att += float(w) * att

        attention_np = combined_att.detach().cpu().numpy()

        threshold = np.percentile(attention_np, 15)
        attention_np = np.where(attention_np < threshold, 0, attention_np)

        if attention_np.max() > 0:
            attention_np = attention_np / attention_np.max()

        return attention_np

    def apply_tissue_filtering(self, attention_weights, original_image_np):
        tissue_mask = self.create_tissue_mask(original_image_np).astype(np.uint8)
        tissue_bool = (tissue_mask > 0).astype(bool)

        h, w = tissue_bool.shape
        patch_h = max(1, h // self.num_patches_per_side)
        patch_w = max(1, w // self.num_patches_per_side)

        patch_tissue = np.zeros((self.num_patches_per_side, self.num_patches_per_side), dtype=float)

        for r in range(self.num_patches_per_side):
            for c in range(self.num_patches_per_side):
                y0, y1 = r * patch_h, min(h, (r + 1) * patch_h)
                x0, x1 = c * patch_w, min(w, (c + 1) * patch_w)
                region = tissue_bool[y0:y1, x0:x1]
                if region.size > 0:
                    patch_tissue[r, c] = region.mean()

        patch_tissue_flat = patch_tissue.ravel()

        tissue_threshold = 0.10
        attention_filtered = attention_weights.copy()
        attention_filtered[patch_tissue_flat < tissue_threshold] = 0.0

        if attention_filtered.sum() > 0:
            attention_filtered = attention_filtered / attention_filtered.sum()
        else:
            attention_filtered = attention_weights

        return attention_filtered

    def attention_to_heatmap(self, attention_weights, original_image_np):
        if attention_weights is None:
            return None

        attention_filtered = self.apply_tissue_filtering(attention_weights, original_image_np)

        attention_2d = attention_filtered.reshape(
            self.num_patches_per_side,
            self.num_patches_per_side
        )

        padded = np.pad(attention_2d, pad_width=2, mode='edge')
        heatmap_large = cv2.resize(
            padded,
            (self.image_size + 64, self.image_size + 64),
            interpolation=cv2.INTER_CUBIC
        )

        heatmap = heatmap_large[32:-32, 32:-32]

        if heatmap.shape != (self.image_size, self.image_size):
            heatmap = cv2.resize(heatmap, (self.image_size, self.image_size))

        if heatmap.max() > heatmap.min():
            heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
        else:
            heatmap = np.ones_like(heatmap) * 0.5

        heatmap = cv2.GaussianBlur(heatmap, (3, 3), 0.5)

        if heatmap.max() > 0:
            heatmap = heatmap / heatmap.max()

        return heatmap

    def create_tissue_mask(self, image_rgb):
        hsv = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2HSV)

        saturation_thresh = 15
        value_thresh = 35

        mask = (hsv[:,:,1] > saturation_thresh) & (hsv[:,:,2] > value_thresh)
        mask = mask.astype(np.uint8) * 255

        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
        mask = cv2.GaussianBlur(mask, (5, 5), 0)

        return mask


print("✓ ViT attention extractor defined")

# ============================================================================
# SECTION 3: FUSION ATTENTION EXTRACTOR
# ============================================================================

print("\n[3/4] Defining Fusion attention extractor...")

class FusionAttentionExtractor:
    """Dual-branch attention extraction for Fusion models"""

    def __init__(self, model, device, fusion_type='concat', patch_size=32, image_size=1024):
        self.model = model
        self.device = device
        self.fusion_type = fusion_type
        self.patch_size = patch_size
        self.image_size = image_size
        self.num_patches_per_side = image_size // patch_size
        self.num_patches = self.num_patches_per_side ** 2

        self.gradients = {}
        self.activations = {}
        self.hooks = []

        self.model.eval()
        print(f"  ✓ Initialized for {fusion_type} fusion")

    def _register_hooks(self):
        def make_hook(name):
            def hook(module, input, output):
                self.activations[name] = output
                def grad_hook(grad):
                    self.gradients[name] = grad
                if torch.is_tensor(output):
                    output.register_hook(grad_hook)
                elif isinstance(output, (tuple, list)) and len(output) > 0:
                    if torch.is_tensor(output[0]):
                        output[0].register_hook(grad_hook)
            return hook

        h = self.model.mobilenet_projector.register_forward_hook(make_hook('mobilenet_projected'))
        self.hooks.append(h)

    def _clear_hooks(self):
        for h in self.hooks:
            try:
                h.remove()
            except:
                pass
        self.hooks = []
        self.gradients.clear()
        self.activations.clear()

    def create_tissue_mask(self, image_rgb):
        hsv = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2HSV)

        saturation_thresh = 15
        value_thresh = 35

        mask = (hsv[:,:,1] > saturation_thresh) & (hsv[:,:,2] > value_thresh)
        mask = mask.astype(np.uint8) * 255

        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
        mask = cv2.GaussianBlur(mask, (5, 5), 0)

        return mask

    def extract_mobilenet_attention(self, images, target_class, extracted_features):
        self._register_hooks()

        try:
            self.model.zero_grad()
            mobilenet_weighted = extracted_features['mobilenet_weighted']

            dummy_scores = torch.sum(mobilenet_weighted, dim=1)
            target_score = dummy_scores[0]
            target_score.backward(retain_graph=True)

            if 'mobilenet_projected' in self.gradients:
                gradients = self.gradients['mobilenet_projected'][0]
                activations = self.activations['mobilenet_projected'][0]

                weights = torch.abs(gradients)
                weighted_features = activations * weights

                spatial_size = 32
                if weighted_features.shape[0] >= spatial_size:
                    spatial_map = weighted_features[:spatial_size].view(int(np.sqrt(spatial_size)), int(np.sqrt(spatial_size)))
                else:
                    padded = F.pad(weighted_features, (0, spatial_size - weighted_features.shape[0]))
                    spatial_map = padded.view(int(np.sqrt(spatial_size)), int(np.sqrt(spatial_size)))

                spatial_np = spatial_map.detach().cpu().numpy()
                spatial_np = spatial_np - spatial_np.min()
                if spatial_np.max() > 0:
                    spatial_np = spatial_np / spatial_np.max()

                return spatial_np
            else:
                mobilenet_weighted_np = mobilenet_weighted[0].detach().cpu().numpy()
                spatial_size = int(np.sqrt(min(512, 256)))
                if mobilenet_weighted_np.shape[0] >= spatial_size * spatial_size:
                    spatial_map = mobilenet_weighted_np[:spatial_size*spatial_size].reshape(spatial_size, spatial_size)
                else:
                    needed = spatial_size * spatial_size
                    padded = np.pad(mobilenet_weighted_np, (0, max(0, needed - len(mobilenet_weighted_np))))
                    spatial_map = padded[:needed].reshape(spatial_size, spatial_size)

                spatial_map = spatial_map - spatial_map.min()
                if spatial_map.max() > 0:
                    spatial_map = spatial_map / spatial_map.max()

                return spatial_map

        except Exception as e:
            print(f"  Warning: MobileNet attention extraction failed: {e}")
            return np.ones((16, 16)) * 0.5

        finally:
            self._clear_hooks()

    def extract_vit_attention(self, extracted_features, original_image_np):
        try:
            vit_attentions = extracted_features['vit_attentions']

            if vit_attentions is None or len(vit_attentions) == 0:
                return None

            num_layers_to_use = min(4, len(vit_attentions))
            selected_layers = vit_attentions[-num_layers_to_use:]

            layer_attentions = []
            layer_weights = [0.4, 0.3, 0.2, 0.1]

            for layer_attention in selected_layers:
                layer_att = layer_attention[0]

                head_variances = torch.var(layer_att.view(layer_att.size(0), -1), dim=1)
                num_keep = max(1, layer_att.size(0) // 4)
                top_heads = torch.argsort(head_variances, descending=True)[:num_keep]

                selected_heads_att = layer_att[top_heads].mean(dim=0)
                cls_att = selected_heads_att[0, 1:]
                layer_attentions.append(cls_att)

            if len(layer_attentions) == 1:
                combined_attention = layer_attentions[0]
            else:
                weights = layer_weights[:len(layer_attentions)]
                weights = np.array(weights, dtype=float)
                weights = weights / weights.sum()

                combined_attention = torch.zeros_like(layer_attentions[0])
                for att, weight in zip(layer_attentions, weights):
                    if not torch.is_tensor(att):
                        att = torch.tensor(att, device=combined_attention.device, dtype=combined_attention.dtype)
                    combined_attention = combined_attention + float(weight) * att

            attention_np = combined_attention.detach().cpu().numpy()

            discard_ratio = 0.15
            p_discard = np.percentile(attention_np, discard_ratio * 100)
            attention_filtered = np.where(attention_np < p_discard, 0, attention_np)

            if attention_filtered.max() > 0:
                attention_filtered = attention_filtered / attention_filtered.max()

            attention_tissue_filtered = self.apply_tissue_filtering_vit(attention_filtered, original_image_np)

            return attention_tissue_filtered

        except Exception as e:
            print(f"  Warning: ViT attention extraction failed: {e}")
            return None

    def apply_tissue_filtering_vit(self, attention_weights, original_image_np):
        try:
            tissue_mask = self.create_tissue_mask(original_image_np).astype(np.uint8)
            tissue_bool = (tissue_mask > 0).astype(bool)

            h, w = tissue_bool.shape
            patch_h, patch_w = max(1, h // self.num_patches_per_side), max(1, w // self.num_patches_per_side)

            patch_tissue = np.zeros((self.num_patches_per_side, self.num_patches_per_side), dtype=float)

            for r in range(self.num_patches_per_side):
                for c in range(self.num_patches_per_side):
                    y0, y1 = r * patch_h, min(h, (r + 1) * patch_h)
                    x0, x1 = c * patch_w, min(w, (c + 1) * patch_w)
                    region = tissue_bool[y0:y1, x0:x1]
                    if region.size > 0:
                        patch_tissue[r, c] = region.mean()

            patch_tissue_flat = patch_tissue.ravel()

            tissue_threshold = 0.10
            attention_filtered = attention_weights.copy()
            attention_filtered[patch_tissue_flat < tissue_threshold] = 0.0

            if attention_filtered.sum() > 0:
                attention_filtered = attention_filtered / attention_filtered.sum()
            else:
                attention_filtered = attention_weights

            return attention_filtered

        except Exception:
            return attention_weights

    def attention_to_heatmap(self, attention_weights, target_size=(1024, 1024)):
        if attention_weights is None:
            return None

        try:
            if len(attention_weights.shape) == 1:
                if len(attention_weights) == self.num_patches:
                    attention_2d = attention_weights.reshape(self.num_patches_per_side, self.num_patches_per_side)
                else:
                    side = int(np.sqrt(len(attention_weights)))
                    if side * side == len(attention_weights):
                        attention_2d = attention_weights.reshape(side, side)
                    else:
                        needed = side * side if side > 0 else 16 * 16
                        padded = np.pad(attention_weights, (0, max(0, needed - len(attention_weights))))
                        side = int(np.sqrt(needed))
                        attention_2d = padded[:needed].reshape(side, side)
            else:
                attention_2d = attention_weights

            padded_attention = np.pad(attention_2d, pad_width=2, mode='edge')
            heatmap_large = cv2.resize(
                padded_attention,
                (target_size[0] + 64, target_size[1] + 64),
                interpolation=cv2.INTER_CUBIC
            )

            crop_size = 32
            heatmap = heatmap_large[crop_size:-crop_size, crop_size:-crop_size]

            if heatmap.shape != target_size:
                heatmap = cv2.resize(heatmap, target_size, interpolation=cv2.INTER_CUBIC)

            heatmap_min, heatmap_max = heatmap.min(), heatmap.max()
            if heatmap_max > heatmap_min:
                heatmap = (heatmap - heatmap_min) / (heatmap_max - heatmap_min)
            else:
                heatmap = np.ones_like(heatmap) * 0.5

            heatmap = cv2.GaussianBlur(heatmap, (3, 3), 0.5)

            if heatmap.max() > 0:
                heatmap = heatmap / heatmap.max()

            return heatmap

        except Exception as e:
            print(f"  Warning: Heatmap conversion failed: {e}")
            return None

    def generate_fusion_attention(self, images, pixel_values, target_class, original_image):
        images = images.to(self.device)
        pixel_values = pixel_values.to(self.device)
        original_image_np = np.array(original_image)

        try:
            output, extracted_features = self.model(images, pixel_values, return_features=True)

            mobilenet_attention = self.extract_mobilenet_attention(images, target_class, extracted_features)
            vit_attention = self.extract_vit_attention(extracted_features, original_image_np)

            mobilenet_heatmap = self.attention_to_heatmap(mobilenet_attention)
            vit_heatmap = self.attention_to_heatmap(vit_attention) if vit_attention is not None else None

            fusion_heatmap = None
            if mobilenet_heatmap is not None and vit_heatmap is not None:
                fusion_heatmap = 0.5 * mobilenet_heatmap + 0.5 * vit_heatmap

                if fusion_heatmap.max() > 0:
                    fusion_heatmap = fusion_heatmap / fusion_heatmap.max()
            elif mobilenet_heatmap is not None:
                fusion_heatmap = mobilenet_heatmap
            elif vit_heatmap is not None:
                fusion_heatmap = vit_heatmap

            return {
                'mobilenet_heatmap': mobilenet_heatmap,
                'vit_heatmap': vit_heatmap,
                'fusion_heatmap': fusion_heatmap
            }

        except Exception as e:
            print(f"  Warning: Fusion attention generation failed: {e}")
            return {
                'mobilenet_heatmap': None,
                'vit_heatmap': None,
                'fusion_heatmap': None
            }


print("✓ Fusion attention extractor defined")

# ============================================================================
# SECTION 4: OVERLAY & INTERPRETATION FUNCTIONS
# ============================================================================

print("\n[4/4] Defining visualization functions...")

def create_medical_overlay(
    original_image,
    heatmap,
    alpha: float = 0.4,
    colormap: str = 'magma',
    apply_tissue_mask: bool = True
) -> Image.Image:
    """Create medical overlay with robust size/channel handling"""

    # Convert original to RGB uint8
    if isinstance(original_image, Image.Image):
        original_np = np.array(original_image.convert('RGB'))
    else:
        original_np = np.array(original_image)
        if original_np.ndim == 2:
            original_np = cv2.cvtColor(original_np, cv2.COLOR_GRAY2RGB)
        elif original_np.shape[2] == 4:
            original_np = original_np[:, :, :3]

    if original_np.dtype != np.uint8:
        if original_np.max() <= 1.0:
            original_np = (original_np * 255).astype(np.uint8)
        else:
            original_np = original_np.astype(np.uint8)

    H, W = original_np.shape[:2]

    if heatmap is None:
        return Image.fromarray(original_np)

    if isinstance(heatmap, Image.Image):
        heat_np = np.array(heatmap)
    else:
        heat_np = np.array(heatmap)

    if heat_np.ndim == 3:
        if heat_np.shape[2] == 4:
            heat_rgb = heat_np[:, :, :3]
        elif heat_np.shape[2] == 3:
            heat_rgb = heat_np
        else:
            heat_rgb = heat_np[:, :, :3]

        if np.issubdtype(heat_rgb.dtype, np.floating):
            heat_rgb = np.clip(heat_rgb, 0.0, 1.0)
            heat_rgb = (heat_rgb * 255).astype(np.uint8)
        else:
            heat_rgb = heat_rgb.astype(np.uint8)

        if heat_rgb.shape[:2] != (H, W):
            heatmap_rgb = cv2.resize(heat_rgb, (W, H), interpolation=cv2.INTER_CUBIC)
        else:
            heatmap_rgb = heat_rgb

    elif heat_np.ndim == 2:
        heat_single = heat_np.astype(np.float32)

        if heat_single.max() > 1.1:
            heat_single = heat_single / 255.0
        heat_single = np.clip(heat_single, 0.0, 1.0)

        if heat_single.shape != (H, W):
            heat_single = cv2.resize(heat_single, (W, H), interpolation=cv2.INTER_CUBIC)

        if colormap == 'magma':
            cmap = cm.magma
        elif colormap == 'inferno':
            cmap = cm.inferno
        elif colormap == 'viridis':
            cmap = cm.viridis
        else:
            cmap = cm.jet

        heat_rgba = cmap(heat_single)
        heat_rgbf = heat_rgba[:, :, :3]
        heatmap_rgb = (heat_rgbf * 255).astype(np.uint8)

    else:
        return Image.fromarray(original_np)

    if heatmap_rgb.shape[:2] != (H, W):
        heatmap_rgb = cv2.resize(heatmap_rgb, (W, H), interpolation=cv2.INTER_CUBIC)

    if heatmap_rgb.shape[2] != 3:
        heatmap_rgb = heatmap_rgb[:, :, :3]

    if heatmap_rgb.dtype != np.uint8:
        heatmap_rgb = heatmap_rgb.astype(np.uint8)

    try:
        overlaid = cv2.addWeighted(original_np, 1 - alpha, heatmap_rgb, alpha, 0)
    except Exception as e:
        print(f"  Warning: Overlay blend failed: {e}")
        return Image.fromarray(original_np)

    return Image.fromarray(overlaid)


def generate_her2_interpretation(
    model,
    model_name: str,
    input_tensor: torch.Tensor,
    vit_tensor: Optional[torch.Tensor],
    target_class: int,
    original_image: Image.Image,
    device: torch.device
) -> Dict[str, Optional[Image.Image]]:
    """Generate HER2 model interpretability visualizations"""
    results = {}

    try:
        if model_name == 'MobileNetV3':
            gradcam = GradCAMExtractor(model, device)
            cam = gradcam.generate_cam(input_tensor, target_class)

            overlay = create_medical_overlay(original_image, cam, alpha=0.35, colormap='magma')
            results['attention'] = overlay

            gradcam.release()

        elif model_name == 'ViT':
            vit_extractor = ViTAttentionExtractor(model, device)

            attentions = vit_extractor.extract_attention(input_tensor)
            attention_map = vit_extractor.compute_attention_map(attentions)

            original_np = np.array(original_image)
            heatmap = vit_extractor.attention_to_heatmap(attention_map, original_np)

            if heatmap is not None:
                overlay = create_medical_overlay(original_image, heatmap, alpha=0.4, colormap='magma')
                results['attention'] = overlay

        elif model_name in ['FusionConcat', 'FusionAddition']:
            fusion_type = 'concat' if model_name == 'FusionConcat' else 'addition'
            fusion_extractor = FusionAttentionExtractor(model, device, fusion_type=fusion_type)

            fusion_results = fusion_extractor.generate_fusion_attention(
                input_tensor, vit_tensor, target_class, original_image
            )

            if fusion_results['mobilenet_heatmap'] is not None:
                results['mobilenet_attention'] = create_medical_overlay(
                    original_image,
                    fusion_results['mobilenet_heatmap'],
                    alpha=0.35,
                    colormap='magma'
                )

            if fusion_results['vit_heatmap'] is not None:
                results['vit_attention'] = create_medical_overlay(
                    original_image,
                    fusion_results['vit_heatmap'],
                    alpha=0.35,
                    colormap='magma'
                )

            if fusion_results['fusion_heatmap'] is not None:
                results['fusion_attention'] = create_medical_overlay(
                    original_image,
                    fusion_results['fusion_heatmap'],
                    alpha=0.35,
                    colormap='magma'
                )

    except Exception as e:
        print(f"  Warning: Interpretation generation failed: {e}")
        results['error'] = str(e)

    return results


print("✓ Visualization functions defined")

print("\n" + "=" * 80)
print("CELL 3 COMPLETE - HER2 INTERPRETABILITY READY")
print("=" * 80)
print("\nNext: Run Cell 4 for interactive interface")
print("=" * 80)

HER2 PROJECT - MODEL INTERPRETABILITY

[1/4] Defining Grad-CAM extractor...
✓ Grad-CAM extractor defined

[2/4] Defining ViT attention extractor...
✓ ViT attention extractor defined

[3/4] Defining Fusion attention extractor...
✓ Fusion attention extractor defined

[4/4] Defining visualization functions...
✓ Visualization functions defined

CELL 3 COMPLETE - HER2 INTERPRETABILITY READY

Next: Run Cell 4 for interactive interface


In [4]:
# @title Cell 4: HER2 Prediction Pipeline

# File: MMLab_Project_Demo.ipynb - Cell 4
# Location: Thesis_MER_Project/MMLab_Project_Demo.ipynb
# Purpose: HER2 prediction pipeline with flexible model name handling

import time

print("=" * 80)
print("HER2 PROJECT - PREDICTION PIPELINE (FIXED)")
print("=" * 80)

# ============================================================================
# SECTION 1: HER2 PREDICTION PIPELINE (FIXED MODEL HANDLING)
# ============================================================================

print("\n[1/2] Defining HER2 prediction pipeline...")

def her2_predict_pipeline(
    image,
    preprocessing_choice,
    model_choice,
    task_choice
):
    """
    Complete HER2 prediction pipeline with flexible model name handling

    Accepts both display names and internal keys for backward compatibility
    """
    try:
        # Step 1: Image validation
        if image is None:
            return (
                None,
                "### Error\n\nPlease upload an image first.",
                {},
                None, None, None,
                gr.update(visible=False),
                "", "", ""
            )

        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Step 2: Model name handling (FIXED)
        # Check if model_choice is already an internal key
        if model_choice in HER2_MODEL_REGISTRY:
            model_name = model_choice
            print(f"Using internal key: {model_name}")
        else:
            # Try mapping from display name
            model_map = {
                'MobileNetV3-Large': 'MobileNetV3',
                'Vision Transformer (ViT)': 'ViT',
                'Fusion (Concatenation)': 'FusionConcat',
                'Fusion (Addition)': 'FusionAddition'
            }

            if model_choice in model_map:
                model_name = model_map[model_choice]
                print(f"Mapped display name: {model_choice} -> {model_name}")
            else:
                raise ValueError(f"Unknown model: {model_choice}")

        # Step 3: Task handling (FIXED)
        if task_choice in HER2_TASK_CONFIGS:
            task = task_choice
        else:
            task_map = {
                'IHC Score (0-3)': 'IHC',
                'IHC Intensity': 'IHC',
                'HER2 Status (Neg/Pos)': 'HER2',
                'HER2 Status': 'HER2'
            }
            task = task_map.get(task_choice, task_choice)

        # Step 4: Preprocessing handling (FIXED)
        if preprocessing_choice in ['Standard Preprocessing', 'Medical Preprocessing']:
            preprocessing_type = 'medical' if 'Medical' in preprocessing_choice else 'standard'
            preprocessing_variant = 'prep' if 'Medical' in preprocessing_choice else 'orig'
        else:
            # Handle Cell 8 format
            preprocessing_type = 'medical' if 'Medical' in preprocessing_choice or 'Optimized' in preprocessing_choice else 'standard'
            preprocessing_variant = 'prep' if preprocessing_type == 'medical' else 'orig'

        # Step 5: Task configuration
        task_config = HER2_TASK_CONFIGS[task]
        num_classes = task_config['num_classes']

        # Step 6: Get model info
        model_info = HER2_MODEL_REGISTRY[model_name]
        is_fusion = model_info['requires_dual_input']

        # Step 7: Create preprocessing comparison
        if preprocessing_type == 'medical':
            medical_processed = apply_medical_preprocessing(image)
            standard_resized = apply_standard_preprocessing(image)

            comparison = Image.new('RGB', (TARGET_SIZE * 2, TARGET_SIZE))
            comparison.paste(standard_resized, (0, 0))
            comparison.paste(medical_processed, (TARGET_SIZE, 0))

            preprocessing_note = "Medical preprocessing applied (CLAHE enhancement + tissue detection)."
        else:
            standard_resized = apply_standard_preprocessing(image)
            comparison = standard_resized
            preprocessing_note = "Standard preprocessing (resize to 1024px + ImageNet normalization)."

        # Step 8: Model loading
        print(f"Loading {model_name} for {task} ({preprocessing_variant})...")
        cached = her2_model_cache.get(model_name, task, num_classes, preprocessing_variant)
        if cached:
            model, training_info = cached
            print("  Model from cache")
        else:
            model, training_info = load_her2_model_checkpoint(
                model_name,
                task,
                num_classes,
                preprocessing_variant
            )
            her2_model_cache.set(model_name, task, model, training_info)
            print("  Model loaded from checkpoint")

        # Step 9: Input preparation
        input_tensor, vit_tensor = prepare_model_input(image, model_name)

        # Step 10: Inference
        print("Running inference...")
        predictions, probabilities = run_her2_inference(model, model_name, input_tensor, vit_tensor)
        results = format_her2_results(predictions, probabilities, task)
        print(f"  Prediction: {results['predicted_class']} ({results['confidence']:.1%})")

        # Step 11: Interpretability generation
        print("Generating attention visualizations...")
        interpretations = {}
        interpretation_status = ""

        try:
            interpretations = generate_her2_interpretation(
                model,
                model_name,
                input_tensor,
                vit_tensor,
                int(predictions.item()),
                image,
                device
            )

            generated = []
            if 'attention' in interpretations and interpretations['attention'] is not None:
                generated.append("single attention")
            if 'mobilenet_attention' in interpretations and interpretations['mobilenet_attention'] is not None:
                generated.append("MobileNet")
            if 'vit_attention' in interpretations and interpretations['vit_attention'] is not None:
                generated.append("ViT")
            if 'fusion_attention' in interpretations and interpretations['fusion_attention'] is not None:
                generated.append("fusion")

            if generated:
                interpretation_status = f"Generated: {', '.join(generated)}"
                print(f"  {interpretation_status}")
            else:
                interpretation_status = "Warning: No visualizations generated"
                print(f"  {interpretation_status}")

        except Exception as e:
            interpretation_status = f"Visualization failed: {str(e)}"
            print(f"  {interpretation_status}")

        # Step 12: Format results
        variant_display = "Medical Preprocessing" if preprocessing_variant == 'prep' else "Standard Preprocessing"

        prediction_text = f"""
### Prediction Results

**Model:** {model_info['display_name']}
**Task:** {task_config['full_name']}
**Preprocessing:** {variant_display}

---

**Predicted Class:** `{results['predicted_class']}`
**Confidence:** {results['confidence']:.1%}

---

**Model Info:**
- Validation F1: {training_info['best_val_f1']:.4f}
- Best Epoch: {training_info['best_epoch']}
- {preprocessing_note}

---

{interpretation_status}
"""

        prob_dict = results['all_probabilities']

        # Step 13: Dynamic return based on model type
        if is_fusion:
            return (
                comparison,
                prediction_text,
                prob_dict,
                interpretations.get('mobilenet_attention'),
                interpretations.get('vit_attention'),
                interpretations.get('fusion_attention'),
                gr.update(visible=True),
                "MobileNet Branch Attention",
                "ViT Branch Attention",
                "Combined Fusion Attention"
            )
        else:
            attention_method = model_info['attention_method']
            if attention_method == 'grad_cam':
                label = "Grad-CAM Attention Map"
            elif attention_method == 'native_transformer':
                label = "Transformer Attention Map"
            else:
                label = "Attention Map"

            return (
                comparison,
                prediction_text,
                prob_dict,
                interpretations.get('attention'),
                None,
                None,
                gr.update(visible=True),
                label,
                "",
                ""
            )

    except Exception as e:
        error_msg = f"""
### Critical Error

**Error Type:** {type(e).__name__}
**Message:** {str(e)}

### Debug Info

**Model Choice Received:** {model_choice}
**Available Models:** {list(HER2_MODEL_REGISTRY.keys())}

### Troubleshooting

1. Verify Cell 2 executed successfully
2. Check HER2_MODEL_REGISTRY defined
3. Verify model name mapping
4. Check terminal for detailed traceback

"""
        print(f"\nERROR: {e}")
        import traceback
        traceback.print_exc()

        return (
            None, error_msg, {},
            None, None, None,
            gr.update(visible=False),
            "", "", ""
        )


print("HER2 prediction pipeline defined with flexible model handling")

# ============================================================================
# SECTION 2: COMPLETION
# ============================================================================

print("\n[2/2] Pipeline ready...")

print("\n" + "=" * 80)
print("CELL 4 COMPLETE - HER2 PIPELINE READY (FIXED)")
print("=" * 80)
print("\nFixed: Flexible model name handling")
print("  - Accepts internal keys directly")
print("  - Accepts display names with mapping")
print("  - Backward compatible with both Cell 7 and Cell 8")
print("\nNext: Run Cell 8 to launch demo")
print("=" * 80)

HER2 PROJECT - PREDICTION PIPELINE (FIXED)

[1/2] Defining HER2 prediction pipeline...
HER2 prediction pipeline defined with flexible model handling

[2/2] Pipeline ready...

CELL 4 COMPLETE - HER2 PIPELINE READY (FIXED)

Fixed: Flexible model name handling
  - Accepts internal keys directly
  - Accepts display names with mapping
  - Backward compatible with both Cell 7 and Cell 8

Next: Run Cell 8 to launch demo


In [5]:
# @title Cell 5: MER Models & Pipeline

# File: MMLab_Project_Demo.ipynb - Cell 5
# Location: Thesis_MER_Project/MMLab_Project_Demo.ipynb
# Purpose: MER models with correct channel handling per model type

print("=" * 80)
print("MER PROJECT - MODELS & PIPELINE")
print("=" * 80)

# ============================================================================
# SECTION 1: MER CHECKPOINT PATHS
# ============================================================================

print("\n[1/7] Configuring MER model checkpoints...")

MER_CHECKPOINT_PATHS = {
    'MobileNetV3_M1': f"{MER_MODELS_ROOT}/08_01_mobilenet_casme2_mfs/casme2_mobilenet_mfs_best_f1.pth",
    'EfficientNet_M1': f"{MER_MODELS_ROOT}/08_02_efficientnet_casme2_mfs/casme2_efficientnet_mfs_best_f1.pth",
    'ConvNeXt_M1': f"{MER_MODELS_ROOT}/08_03_convnext_casme2_mfs/casme2_convnext_mfs_best_f1.pth",

    'MobileNetV3_M2': f"{MER_MODELS_ROOT}/09_01_mobilenet_casme2_mfs_prep/casme2_mobilenet_mfs_prep_best_f1.pth",
    'EfficientNet_M2': f"{MER_MODELS_ROOT}/09_02_efficientnet_casme2_mfs_prep/casme2_efficientnet_mfs_prep_best_f1.pth",
    'ConvNeXt_M2': f"{MER_MODELS_ROOT}/09_03_convnext_casme2_mfs_prep/casme2_convnext_mfs_prep_best_f1.pth",

    'ViT_M1': f"{MER_MODELS_ROOT}/02_01_vit_casme2-af/casme2_vit_direct_best_f1.pth",
    'SwinTransformer_M1': f"{MER_MODELS_ROOT}/02_02_swint_casme2-af/casme2_swint_direct_best_f1.pth",
    'PoolFormer_M1': f"{MER_MODELS_ROOT}/04_03_poolformer_casme2_mfs/casme2_poolformer_multiframe_best_f1.pth",

    'ViT_M2': f"{MER_MODELS_ROOT}/05_01_vit_casme2_af_prep/casme2_vit_apex_frame_best_f1.pth",
    'SwinTransformer_M2': f"{MER_MODELS_ROOT}/07_02_swint_casme2_mfs_prep/casme2_swint_mfs_best_f1.pth",
    'PoolFormer_M2': f"{MER_MODELS_ROOT}/07_03_poolformer_casme2_mfs_prep/casme2_poolformer_mfs_best_f1.pth",
}

print(f"Configured {len(MER_CHECKPOINT_PATHS)} model checkpoints")

# ============================================================================
# SECTION 2: MER TASK CONFIGURATION
# ============================================================================

MER_TASK_CONFIG = {
    'num_classes': 7,
    'class_names': ['Others', 'Disgust', 'Happiness', 'Repression', 'Surprise', 'Sadness', 'Fear'],
    'dataset': 'CASME II',
    'description': 'Micro-expression recognition (7 emotions)',
    'evaluation_phase': 'AF'
}

MER_PREPROCESSING_CONFIGS = {
    'M1': {
        'name': 'M1 (Raw RGB)',
        'description_cnn': 'Minimal preprocessing - RGB 640×480',
        'description_transformer': 'Minimal preprocessing - RGB 384×384',
        'image_size_cnn': (640, 480),
        'image_size_transformer': 384,
        'channels': 3,
        'mode': 'RGB'
    },
    'M2': {
        'name': 'M2 (Preprocessed)',
        'description': 'Face-aware preprocessing - Grayscale 224×224',
        'image_size': 224,
        'channels': 1,
        'mode': 'Grayscale'
    }
}

print(f"Task configuration: {MER_TASK_CONFIG['dataset']} - {MER_TASK_CONFIG['num_classes']} classes")

# ============================================================================
# SECTION 3: CNN MODEL ARCHITECTURES (UNCHANGED)
# ============================================================================

print("\n[2/7] Defining CNN model architectures...")

class MER_MobileNetV3(nn.Module):
    def __init__(self, num_classes=7, dropout_rate=0.3, input_channels=3):
        super(MER_MobileNetV3, self).__init__()

        if input_channels == 1:
            self.mobilenet = timm.create_model(
                'mobilenetv3_small_100',
                pretrained=False,
                num_classes=0,
                global_pool='',
                in_chans=1
            )
        else:
            self.mobilenet = timm.create_model(
                'mobilenetv3_small_100',
                pretrained=True,
                num_classes=0,
                global_pool=''
            )

        self.conv_head = nn.Conv2d(576, 1024, kernel_size=1, bias=True)
        self.act_head = nn.ReLU(inplace=True)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten(1)

        self.classifier_layers = nn.Sequential(
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate)
        )

        self.classifier = nn.Linear(128, num_classes)

    def forward(self, x):
        features = self.mobilenet.forward_features(x)
        features = self.conv_head(features)
        features = self.act_head(features)
        features = self.global_pool(features)
        features = self.flatten(features)
        x = self.classifier_layers(features)
        output = self.classifier(x)
        return output


class MER_EfficientNet(nn.Module):
    def __init__(self, num_classes=7, dropout_rate=0.3, input_channels=3):
        super(MER_EfficientNet, self).__init__()

        if input_channels == 1:
            self.efficientnet = timm.create_model(
                'efficientnet_b0',
                pretrained=False,
                num_classes=0,
                in_chans=1
            )
        else:
            self.efficientnet = timm.create_model(
                'efficientnet_b0',
                pretrained=True,
                num_classes=0
            )

        self.feature_dim = 1280

        self.classifier_layers = nn.Sequential(
            nn.Linear(self.feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate)
        )

        self.classifier = nn.Linear(128, num_classes)

    def forward(self, x):
        features = self.efficientnet(x)
        x = self.classifier_layers(features)
        output = self.classifier(x)
        return output


class MER_ConvNeXt(nn.Module):
    def __init__(self, num_classes=7, dropout_rate=0.3, input_channels=3):
        super(MER_ConvNeXt, self).__init__()

        if input_channels == 1:
            self.convnext = timm.create_model(
                'convnext_tiny',
                pretrained=False,
                num_classes=0,
                in_chans=1
            )
        else:
            self.convnext = timm.create_model(
                'convnext_tiny',
                pretrained=True,
                num_classes=0
            )

        self.feature_dim = 768

        self.classifier_layers = nn.Sequential(
            nn.Linear(self.feature_dim, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Dropout(dropout_rate)
        )

        self.classifier = nn.Linear(128, num_classes)

    def forward(self, x):
        features = self.convnext(x)
        x = self.classifier_layers(features)
        output = self.classifier(x)
        return output


print("Defined 3 CNN architectures")

# ============================================================================
# SECTION 4: TRANSFORMER MODEL ARCHITECTURES (FIXED)
# ============================================================================

print("\n[3/7] Defining Transformer model architectures...")

try:
    from transformers import ViTModel, ViTConfig
    from transformers import SwinModel, SwinConfig
    from transformers import PoolFormerModel, PoolFormerConfig
    print("Transformers library imported")
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    print("Installing transformers...")
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "transformers"])
    from transformers import ViTModel, ViTConfig
    from transformers import SwinModel, SwinConfig
    from transformers import PoolFormerModel, PoolFormerConfig
    print("Transformers library installed")
    TRANSFORMERS_AVAILABLE = True


class MER_ViT(nn.Module):
    def __init__(self, num_classes=7, dropout_rate=0.3, input_channels=3, methodology='M1'):
        super(MER_ViT, self).__init__()

        # Both M1 and M2 use 224×224 RGB
        image_size = 224
        classifier_dim = 512 if methodology == 'M1' else 256

        config = ViTConfig(
            hidden_size=768,
            num_hidden_layers=12,
            num_attention_heads=12,
            intermediate_size=3072,
            hidden_dropout_prob=dropout_rate,
            attention_probs_dropout_prob=dropout_rate,
            num_channels=input_channels,
            image_size=image_size,
            patch_size=32,
            num_labels=num_classes
        )

        self.vit = ViTModel(config)
        self.feature_dim = 768

        self.classifier_layers = nn.Sequential(
            nn.Linear(self.feature_dim, classifier_dim),
            nn.LayerNorm(classifier_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(classifier_dim, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Dropout(dropout_rate)
        )

        self.classifier = nn.Linear(128, num_classes)

    def forward(self, x):
        outputs = self.vit(pixel_values=x)
        features = outputs.last_hidden_state[:, 0]
        x = self.classifier_layers(features)
        output = self.classifier(x)
        return output


class MER_SwinTransformer(nn.Module):
    def __init__(self, num_classes=7, dropout_rate=0.3, input_channels=3, methodology='M1'):
        super(MER_SwinTransformer, self).__init__()

        # Both M1 and M2 use 224×224
        image_size = 224

        config = SwinConfig(
            embed_dim=128,
            depths=[2, 2, 18, 2],
            num_heads=[4, 8, 16, 32],
            window_size=7,
            mlp_ratio=4.0,
            drop_rate=dropout_rate,
            attn_drop_rate=dropout_rate,
            num_channels=input_channels,
            image_size=image_size,
            patch_size=4,
            num_labels=num_classes
        )

        self.swin = SwinModel(config)

        # M1: 1024 feature dim, M2: 768 feature dim
        self.feature_dim = 1024 if methodology == 'M1' else 768

        self.classifier_layers = nn.Sequential(
            nn.Linear(self.feature_dim, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Dropout(dropout_rate)
        )

        self.classifier = nn.Linear(128, num_classes)

    def forward(self, x):
        outputs = self.swin(pixel_values=x)
        features = outputs.pooler_output
        x = self.classifier_layers(features)
        output = self.classifier(x)
        return output


class MER_PoolFormer(nn.Module):
    def __init__(self, num_classes=7, dropout_rate=0.3, input_channels=3, methodology='M1'):
        super(MER_PoolFormer, self).__init__()

        # Both M1 and M2 use 224×224
        image_size = 224

        config = PoolFormerConfig(
            hidden_sizes=[96, 192, 384, 768],
            depths=[2, 2, 18, 2],
            mlp_ratio=4.0,
            drop_rate=dropout_rate,
            num_channels=input_channels,
            image_size=image_size,
            patch_size=7,
            num_labels=num_classes
        )

        self.poolformer = PoolFormerModel(config)
        self.feature_dim = 768

        # FIXED: Add explicit pooling
        self.global_pool = nn.AdaptiveAvgPool1d(1)

        self.classifier_layers = nn.Sequential(
            nn.Linear(self.feature_dim, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Dropout(dropout_rate)
        )

        self.classifier = nn.Linear(128, num_classes)

    def forward(self, x):
        outputs = self.poolformer(pixel_values=x)

        # FIXED: Proper spatial pooling
        # outputs.last_hidden_state shape: (batch, seq_len, 768)
        features = outputs.last_hidden_state

        # Transpose for pooling: (batch, 768, seq_len)
        features = features.transpose(1, 2)

        # Global average pooling: (batch, 768, 1)
        features = self.global_pool(features)

        # Squeeze: (batch, 768)
        features = features.squeeze(-1)

        x = self.classifier_layers(features)
        output = self.classifier(x)
        return output


print("Defined 3 Transformer architectures")

# ============================================================================
# SECTION 5: MODEL REGISTRY
# ============================================================================

MER_MODEL_REGISTRY = {
    'MobileNetV3': {
        'architecture_class': MER_MobileNetV3,
        'display_name': 'MobileNetV3-Small',
        'description': 'Lightweight CNN - Best overall (F1: 0.3880 M1-AF)',
        'model_type': 'CNN'
    },
    'EfficientNet': {
        'architecture_class': MER_EfficientNet,
        'display_name': 'EfficientNet-B0',
        'description': 'Compound scaling - Balanced performance',
        'model_type': 'CNN'
    },
    'ConvNeXt': {
        'architecture_class': MER_ConvNeXt,
        'display_name': 'ConvNeXt-Tiny',
        'description': 'Modernized CNN - Benefits from M2 (41% improvement)',
        'model_type': 'CNN'
    },
    'ViT': {
        'architecture_class': MER_ViT,
        'display_name': 'Vision Transformer (ViT)',
        'description': 'Attention-based - Global pattern analysis (F1: 0.2298 M1)',
        'model_type': 'Transformer'
    },
    'SwinTransformer': {
        'architecture_class': MER_SwinTransformer,
        'display_name': 'Swin Transformer',
        'description': 'Window attention - Hierarchical features (F1: 0.2619 M1)',
        'model_type': 'Transformer'
    },
    'PoolFormer': {
        'architecture_class': MER_PoolFormer,
        'display_name': 'PoolFormer',
        'description': 'Pooling-based - Simple and efficient (F1: 0.1734 M1)',
        'model_type': 'Transformer'
    }
}

print(f"Model registry: {len(MER_MODEL_REGISTRY)} architectures")

# ============================================================================
# SECTION 6: MODEL LOADING & INFERENCE (FIXED CHANNEL LOGIC)
# ============================================================================

print("\n[4/7] Defining model loading and inference functions...")

def load_mer_model_checkpoint(model_name, methodology):
    checkpoint_key = f"{model_name}_{methodology}"

    if checkpoint_key not in MER_CHECKPOINT_PATHS:
        raise KeyError(f"Checkpoint not found: {checkpoint_key}")

    checkpoint_path = MER_CHECKPOINT_PATHS[checkpoint_key]

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    print(f"Loading {model_name} ({methodology})...")

    try:
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
    except Exception:
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    model_type = MER_MODEL_REGISTRY[model_name]['model_type']

    if methodology == 'M1':
        input_channels = 3
    else:
        input_channels = 1 if model_type == 'CNN' else 3

    print(f"  Model type: {model_type}")
    print(f"  Input channels: {input_channels} ({'RGB' if input_channels == 3 else 'Grayscale'})")

    architecture_class = MER_MODEL_REGISTRY[model_name]['architecture_class']

    if model_type == 'Transformer':
        model = architecture_class(
            num_classes=MER_TASK_CONFIG['num_classes'],
            dropout_rate=0.3 if methodology == 'M2' else 0.2,
            input_channels=input_channels,
            methodology=methodology
        ).to(device)
    else:
        model = architecture_class(
            num_classes=MER_TASK_CONFIG['num_classes'],
            dropout_rate=0.3 if methodology == 'M2' else 0.2,
            input_channels=input_channels
        ).to(device)

    state_dict = checkpoint.get('model_state_dict', checkpoint)

    try:
        model.load_state_dict(state_dict, strict=True)
        load_status = "strict"
    except Exception as e:
        print(f"  Warning: Strict loading failed: {e}")
        model.load_state_dict(state_dict, strict=False)
        load_status = "non-strict"

    model.eval()

    training_info = {
        'model_name': model_name,
        'methodology': methodology,
        'checkpoint_file': os.path.basename(checkpoint_path),
        'load_status': load_status,
        'input_channels': input_channels,
        'best_f1': checkpoint.get('best_f1', 'N/A')
    }

    print(f"  Loaded ({load_status})")
    if 'best_f1' in checkpoint:
        print(f"  Training F1: {checkpoint['best_f1']:.4f}")

    return model, training_info


def run_mer_inference(model, input_tensor):
    """Run MER inference"""
    model.eval()

    with torch.no_grad():
        outputs = model(input_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        predictions = torch.argmax(probabilities, dim=1)

    return predictions, probabilities


def format_mer_results(predictions, probabilities):
    """Format MER prediction results"""

    pred_class_idx = predictions.item()
    class_probs = probabilities[0].cpu().numpy()

    class_names = MER_TASK_CONFIG['class_names']

    results = {
        'predicted_emotion': class_names[pred_class_idx],
        'predicted_index': pred_class_idx,
        'confidence': float(class_probs[pred_class_idx]),
        'all_probabilities': {
            class_names[i]: float(class_probs[i])
            for i in range(len(class_names))
        }
    }

    return results


class MERModelCache:
    """Cache for loaded MER models"""

    def __init__(self):
        self.cache = {}
        self.stats = {'hits': 0, 'misses': 0}

    def get(self, model_name, methodology):
        cache_key = f"{model_name}_{methodology}"
        if cache_key in self.cache:
            self.stats['hits'] += 1
            print(f"  Cache hit: {cache_key}")
            return self.cache[cache_key]
        else:
            self.stats['misses'] += 1
            return None

    def set(self, model_name, methodology, model, info):
        cache_key = f"{model_name}_{methodology}"
        self.cache[cache_key] = (model, info)
        print(f"  Cached: {cache_key}")

    def clear(self):
        self.cache.clear()
        self.stats = {'hits': 0, 'misses': 0}
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# CRITICAL: Instantiate the cache
mer_model_cache = MERModelCache()

print("Model loading, inference, and cache functions defined")

# ============================================================================
# SECTION 7: PREDICTION PIPELINE (FIXED DESCRIPTION)
# ============================================================================

print("\n[5/7] Defining prediction pipeline...")

def mer_predict_pipeline(image, model_choice, methodology_choice):
    try:
        if image is None:
            return (None, "### Error\n\nPlease upload or capture an image first.", {})

        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        if model_choice in MER_MODEL_REGISTRY:
            model_name = model_choice
            print(f"Using internal key: {model_name}")
        else:
            model_map = {
                'MobileNetV3-Small': 'MobileNetV3',
                'EfficientNet-B0': 'EfficientNet',
                'ConvNeXt-Tiny': 'ConvNeXt',
                'Vision Transformer (ViT)': 'ViT',
                'Swin Transformer': 'SwinTransformer',
                'PoolFormer': 'PoolFormer'
            }

            if model_choice in model_map:
                model_name = model_map[model_choice]
                print(f"Mapped display name: {model_choice} -> {model_name}")
            else:
                raise ValueError(f"Unknown model: {model_choice}")

        if methodology_choice in ['M1', 'M2']:
            methodology = methodology_choice
        else:
            methodology = 'M1' if 'M1' in methodology_choice else 'M2'

        comparison = create_preprocessing_comparison_mer_updated(image, methodology, model_name)

        print(f"Loading {model_name} ({methodology})...")
        cached = mer_model_cache.get(model_name, methodology)
        if cached:
            model, training_info = cached
            print("  From cache")
        else:
            model, training_info = load_mer_model_checkpoint(model_name, methodology)
            mer_model_cache.set(model_name, methodology, model, training_info)
            print("  From checkpoint")

        print(f"Preprocessing with {methodology} method...")
        model_type = MER_MODEL_REGISTRY[model_name]['model_type']

        if methodology == 'M1':
            input_tensor = preprocess_mer_m1_webcam(image, model_type)
            print(f"  M1 preprocessing: shape {input_tensor.shape}")
        else:
            input_tensor = preprocess_mer_m2_webcam(image, model_type)
            print(f"  M2 preprocessing: shape {input_tensor.shape}")

        print("Running inference...")
        predictions, probabilities = run_mer_inference(model, input_tensor)
        results = format_mer_results(predictions, probabilities)
        print(f"  Prediction: {results['predicted_emotion']} ({results['confidence']:.1%})")

        model_info = MER_MODEL_REGISTRY[model_name]
        preprocessing_info = MER_PREPROCESSING_CONFIGS[methodology]

        # FIXED: Fallback to 'description' if 'description_cnn' missing
        if model_type == 'CNN':
            prep_desc = preprocessing_info.get('description_cnn', preprocessing_info.get('description', 'CNN preprocessing'))
        else:
            prep_desc = preprocessing_info.get('description_transformer', preprocessing_info.get('description', 'Transformer preprocessing'))

        prediction_text = f"""
### Micro-Expression Classification Results

**Model:** {model_info['display_name']}
**Type:** {model_type}
**Methodology:** {preprocessing_info['name']}
**Evaluation:** Apex Frame (Single Snapshot)

---

**Predicted Emotion:** `{results['predicted_emotion']}`
**Confidence:** {results['confidence']:.1%}

---

**Model Info:**
- {model_info['description']}
- Training F1: {training_info.get('best_f1', 'N/A')}
- Input: {prep_desc}
- Dataset: CASME II (7 emotion categories)
"""

        prob_dict = results['all_probabilities']

        return (comparison, prediction_text, prob_dict)

    except Exception as e:
        error_msg = f"""
### Error

**Type:** {type(e).__name__}
**Message:** {str(e)}

**Debug Info:**
- Model Choice: {model_choice}
- Available Models: {list(MER_MODEL_REGISTRY.keys())}

Check terminal for traceback.
"""
        print(f"\nERROR: {e}")
        import traceback
        traceback.print_exc()

        return (None, error_msg, {})


print("Prediction pipeline defined")

# ============================================================================
# SECTION 8: VERIFICATION
# ============================================================================

print("\n[6/7] Verifying dependencies...")

required_functions = [
    'preprocess_mer_m1_webcam',
    'preprocess_mer_m2_webcam',
    'create_preprocessing_comparison_mer_updated'
]

missing = [f for f in required_functions if f not in globals()]
if missing:
    print(f"WARNING: Missing preprocessing functions from Cell 6: {missing}")
else:
    print("All preprocessing functions available")

print("\n[7/7] System ready...")

print("\n" + "=" * 80)
print("CELL 5 COMPLETE - MER SYSTEM (FINAL FIX)")
print("=" * 80)
print("\nFixed:")
print("  - M1 All: RGB (3 channels)")
print("  - M2 CNN: Grayscale (1 channel)")
print("  - M2 Transformer: RGB (3 channels) ← CRITICAL FIX")
print("  - PoolFormer: Explicit spatial pooling")
print("  - Description fallback for M2")
print("\nNext: Update Cell 6 for conditional M2 preprocessing")
print("=" * 80)

MER PROJECT - MODELS & PIPELINE

[1/7] Configuring MER model checkpoints...
Configured 12 model checkpoints
Task configuration: CASME II - 7 classes

[2/7] Defining CNN model architectures...
Defined 3 CNN architectures

[3/7] Defining Transformer model architectures...
Transformers library imported
Defined 3 Transformer architectures
Model registry: 6 architectures

[4/7] Defining model loading and inference functions...
Model loading, inference, and cache functions defined

[5/7] Defining prediction pipeline...
Prediction pipeline defined

[6/7] Verifying dependencies...

[7/7] System ready...

CELL 5 COMPLETE - MER SYSTEM (FINAL FIX)

Fixed:
  - M1 All: RGB (3 channels)
  - M2 CNN: Grayscale (1 channel)
  - M2 Transformer: RGB (3 channels) ← CRITICAL FIX
  - PoolFormer: Explicit spatial pooling
  - Description fallback for M2

Next: Update Cell 6 for conditional M2 preprocessing


In [6]:
# @title Cell 6: MER Preprocessing Functions

# File: MMLab_Project_Demo.ipynb - Cell 6
# Location: Thesis_MER_Project/MMLab_Project_Demo.ipynb
# Purpose: Preprocessing with CORRECT grayscale conversion for M2

print("=" * 80)
print("MER PREPROCESSING FUNCTIONS - GRAYSCALE FIX")
print("=" * 80)

# ============================================================================
# SECTION 1: DLIB FACE DETECTOR
# ============================================================================

print("\n[1/4] Initializing Dlib face detector...")

try:
    import dlib
    mer_face_detector = dlib.get_frontal_face_detector()
    DLIB_AVAILABLE = True
    print("Dlib face detector loaded successfully")
except ImportError:
    DLIB_AVAILABLE = False
    print("Dlib not available - M2 will use center crop fallback")

# ============================================================================
# SECTION 2: M1 PREPROCESSING
# ============================================================================

print("\n[2/4] Defining M1 preprocessing...")

def preprocess_mer_m1_webcam(image, model_type='CNN'):
    """
    M1 preprocessing with model-specific sizing.

    CNN models: RGB 640×480
    Transformer models: RGB 224×224 (NOT 384×384!)

    Args:
        image: PIL Image or numpy array
        model_type: 'CNN' or 'Transformer'

    Returns:
        Torch tensor (1, 3, H, W) RGB
    """
    if isinstance(image, Image.Image):
        image = np.array(image)

    if len(image.shape) == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

    h, w = image.shape[:2]

    # FIXED: Transformer M1 uses 224×224, NOT 384×384!
    if model_type == 'Transformer':
        target_w, target_h = 224, 224
        aspect_target = 1.0
    else:  # CNN
        target_w, target_h = 640, 480
        aspect_target = target_w / target_h

    aspect_current = w / h

    if aspect_target == 1.0:
        crop_size = min(h, w)
        start_y = (h - crop_size) // 2
        start_x = (w - crop_size) // 2
        cropped = image[start_y:start_y+crop_size, start_x:start_x+crop_size]
    else:
        if aspect_current > aspect_target:
            new_w = int(h * aspect_target)
            x_start = (w - new_w) // 2
            cropped = image[:, x_start:x_start+new_w]
        elif aspect_current < aspect_target:
            new_h = int(w / aspect_target)
            y_start = (h - new_h) // 2
            cropped = image[y_start:y_start+new_h, :]
        else:
            cropped = image

    resized = cv2.resize(cropped, (target_w, target_h),
                        interpolation=cv2.INTER_LANCZOS4)

    pil_image = Image.fromarray(resized)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    tensor = transform(pil_image).unsqueeze(0).to(device)

    return tensor


print("M1 preprocessing defined: CNN 640×480, Transformer 224×224")

# ============================================================================
# SECTION 3: M2 PREPROCESSING (GRAYSCALE FIX)
# ============================================================================

print("\n[3/4] Defining M2 preprocessing with grayscale conversion...")

def detect_face_with_expansion_exact(image, expansion=20):
    if not DLIB_AVAILABLE:
        return None

    try:
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        else:
            gray = image

        faces = mer_face_detector(gray, 1)

        if len(faces) == 0:
            return None

        if len(faces) > 1:
            faces = sorted(faces, key=lambda r: r.width() * r.height(), reverse=True)

        face = faces[0]

        img_h, img_w = gray.shape[:2]

        x1 = max(0, face.left() - expansion)
        y1 = max(0, face.top() - expansion)
        x2 = min(img_w, face.right() + expansion)
        y2 = min(img_h, face.bottom() + expansion)

        return (x1, y1, x2, y2)

    except Exception:
        return None


def ensure_minimum_size_exact(image, min_size=224):
    h, w = image.shape[:2]

    if h >= min_size and w >= min_size:
        return image

    scale_factor = min_size / min(h, w)
    new_width = int(w * scale_factor)
    new_height = int(h * scale_factor)

    resized = cv2.resize(image, (new_width, new_height),
                        interpolation=cv2.INTER_LANCZOS4)

    return resized


def preprocess_mer_m2_webcam(image, model_type='CNN'):
    """
    M2 preprocessing: Conditional output based on model_type.

    CNN:         Grayscale 224x224 (1 channel), single-channel normalization
    Transformer: RGB 224x224 (3 channels), ImageNet RGB normalization

    Returns:
        Torch tensor (1, C, 224, 224) where C = 1 for CNN, 3 for Transformer.
    """
    # Accept PIL or numpy
    if isinstance(image, Image.Image):
        image = np.array(image)

    # Ensure RGB first (so face detector gets consistent input)
    if len(image.shape) == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

    # Ensure minimum size
    image = ensure_minimum_size_exact(image, min_size=224)

    h, w = image.shape[:2]

    # Face detection (may return None)
    face_bbox = detect_face_with_expansion_exact(image, expansion=20)

    if face_bbox is not None:
        x1, y1, x2, y2 = face_bbox
        cropped = image[y1:y2, x1:x2]
    else:
        # center-crop fallback
        center_size = min(h, w)
        start_y = (h - center_size) // 2
        start_x = (w - center_size) // 2
        cropped = image[start_y:start_y+center_size, start_x:start_x+center_size]

    # Resize to exact 224x224
    if cropped.shape[0] != 224 or cropped.shape[1] != 224:
        cropped_224 = cv2.resize(cropped, (224, 224), interpolation=cv2.INTER_LANCZOS4)
    else:
        cropped_224 = cropped

    # Decide branch: CNN -> grayscale (single channel). Otherwise treat as Transformer (RGB).
    is_cnn = str(model_type).upper().startswith('C')  # permissive check

    if is_cnn:
        # Convert RGB -> Grayscale and use single-channel normalization
        cropped_224_final = cv2.cvtColor(cropped_224, cv2.COLOR_RGB2GRAY)
        pil_image = Image.fromarray(cropped_224_final, mode='L')

        transform = transforms.Compose([
            transforms.ToTensor(),                     # (1, H, W)
            transforms.Normalize(mean=[0.485], std=[0.229])
        ])
    else:
        # Keep RGB and use ImageNet normalization
        pil_image = Image.fromarray(cropped_224)       # mode 'RGB'
        transform = transforms.Compose([
            transforms.ToTensor(),                     # (3, H, W)
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    tensor = transform(pil_image).unsqueeze(0).to(device)  # (1, C, 224, 224)

    return tensor



print("M2 preprocessing defined: Grayscale 224×224 (1 channel)")

# ============================================================================
# SECTION 4: PREPROCESSING COMPARISON VISUALIZATION
# ============================================================================

print("\n[4/4] Defining preprocessing comparison...")

def create_preprocessing_comparison_mer_updated(original_image, methodology, model_name='MobileNetV3'):
    """
    Create side-by-side comparison with correct processing.

    Args:
        original_image: PIL Image
        methodology: 'M1' or 'M2'
        model_name: Model name

    Returns:
        PIL Image with comparison
    """
    if isinstance(original_image, Image.Image):
        original_np = np.array(original_image)
    else:
        original_np = original_image

    if len(original_np.shape) == 2:
        original_np = cv2.cvtColor(original_np, cv2.COLOR_GRAY2RGB)
    elif original_np.shape[2] == 4:
        original_np = cv2.cvtColor(original_np, cv2.COLOR_RGBA2RGB)

    if methodology == 'M1':
        model_type = 'Transformer' if model_name in ['ViT', 'SwinTransformer', 'PoolFormer'] else 'CNN'

        if model_type == 'Transformer':
            # Square 224×224
            h, w = original_np.shape[:2]
            crop_size = min(h, w)
            start_y = (h - crop_size) // 2
            start_x = (w - crop_size) // 2
            cropped = original_np[start_y:start_y+crop_size, start_x:start_x+crop_size]
            resized = cv2.resize(cropped, (224, 224), interpolation=cv2.INTER_LANCZOS4)

            comparison = Image.new('RGB', (224 * 2, 224))
            orig_resized = cv2.resize(original_np, (224, 224), interpolation=cv2.INTER_LANCZOS4)
            comparison.paste(Image.fromarray(orig_resized), (0, 0))
            comparison.paste(Image.fromarray(resized), (224, 0))

        else:
            # Rectangular 640×480
            h, w = original_np.shape[:2]
            aspect_current = w / h
            aspect_target = 640 / 480

            if aspect_current > aspect_target:
                new_w = int(h * aspect_target)
                x_start = (w - new_w) // 2
                cropped = original_np[:, x_start:x_start+new_w]
            elif aspect_current < aspect_target:
                new_h = int(w / aspect_target)
                y_start = (h - new_h) // 2
                cropped = original_np[y_start:y_start+new_h, :]
            else:
                cropped = original_np

            resized = cv2.resize(cropped, (640, 480), interpolation=cv2.INTER_LANCZOS4)

            comparison = Image.new('RGB', (640 * 2, 480))
            orig_resized = cv2.resize(original_np, (640, 480), interpolation=cv2.INTER_LANCZOS4)
            comparison.paste(Image.fromarray(orig_resized), (0, 0))
            comparison.paste(Image.fromarray(resized), (640, 0))

    else:  # M2
        # Grayscale 224×224 with face detection
        image = ensure_minimum_size_exact(original_np, min_size=224)

        face_bbox = detect_face_with_expansion_exact(image, expansion=20)

        if face_bbox is not None:
            x1, y1, x2, y2 = face_bbox
            cropped = image[y1:y2, x1:x2]
        else:
            h, w = image.shape[:2]
            center_size = min(h, w)
            start_y = (h - center_size) // 2
            start_x = (w - center_size) // 2
            cropped = image[start_y:start_y+center_size, start_x:start_x+center_size]

        cropped_224 = cv2.resize(cropped, (224, 224), interpolation=cv2.INTER_LANCZOS4)

        # Convert to grayscale for display
        cropped_224_gray = cv2.cvtColor(cropped_224, cv2.COLOR_RGB2GRAY)
        cropped_224_gray_rgb = cv2.cvtColor(cropped_224_gray, cv2.COLOR_GRAY2RGB)

        # Create comparison
        comparison = Image.new('RGB', (224 * 2, 224))
        orig_224 = cv2.resize(original_np, (224, 224), interpolation=cv2.INTER_LANCZOS4)
        comparison.paste(Image.fromarray(orig_224), (0, 0))
        comparison.paste(Image.fromarray(cropped_224_gray_rgb), (224, 0))

    return comparison


print("Preprocessing comparison defined")

print("\n" + "=" * 80)
print("CELL 6 COMPLETE - MER PREPROCESSING (GRAYSCALE FIX)")
print("=" * 80)
print("\nM1 Preprocessing:")
print("  - CNN: RGB 640×480")
print("  - Transformer: RGB 224×224 (CORRECTED from 384×384)")
print("  - ImageNet RGB normalization")
print("\nM2 Preprocessing:")
print("  - Grayscale 224×224 (1 channel) ← FIXED")
print("  - Face detection + 20px expansion")
print("  - Single channel normalization")
print("\nDlib status:", "Available" if DLIB_AVAILABLE else "Fallback mode")
print("\nNext: Update Cell 5 Transformer image sizes")
print("=" * 80)

MER PREPROCESSING FUNCTIONS - GRAYSCALE FIX

[1/4] Initializing Dlib face detector...
Dlib face detector loaded successfully

[2/4] Defining M1 preprocessing...
M1 preprocessing defined: CNN 640×480, Transformer 224×224

[3/4] Defining M2 preprocessing with grayscale conversion...
M2 preprocessing defined: Grayscale 224×224 (1 channel)

[4/4] Defining preprocessing comparison...
Preprocessing comparison defined

CELL 6 COMPLETE - MER PREPROCESSING (GRAYSCALE FIX)

M1 Preprocessing:
  - CNN: RGB 640×480
  - Transformer: RGB 224×224 (CORRECTED from 384×384)
  - ImageNet RGB normalization

M2 Preprocessing:
  - Grayscale 224×224 (1 channel) ← FIXED
  - Face detection + 20px expansion
  - Single channel normalization

Dlib status: Available

Next: Update Cell 5 Transformer image sizes


In [7]:
# @title Cell 7: MER Interface

# File: MMLab_Project_Demo.ipynb - Cell 7
# Location: Thesis_MER_Project/MMLab_Project_Demo.ipynb
# Purpose: MER interface with CSS-based

print("=" * 80)
print("MER INTERFACE CONFIGURATION - CSS FIXED")
print("=" * 80)

# ============================================================================
# SECTION 1: ENHANCED MODEL DESCRIPTIONS WITH CARD STYLING
# ============================================================================

print("\n[1/4] Defining enhanced model descriptions...")

def get_mer_model_description(model_name):
    """Get styled model description with card formatting"""

    descriptions = {
        'MobileNetV3-Small': """
<div style='padding: 16px; background: #f9fafb; border-left: 4px solid #dc2626; border-radius: 8px; margin: 8px 0;'>
<h3 style='margin-top: 0; color: #1f2937;'>MobileNetV3-Small (CNN)</h3>

<div style='background: white; padding: 12px; border-radius: 6px; margin: 12px 0;'>
<strong>Architecture:</strong> Lightweight CNN with inverted residuals and efficient channel attention
</div>

<div style='background: white; padding: 12px; border-radius: 6px; margin: 12px 0;'>
<strong>Performance:</strong>
<ul style='margin: 8px 0; padding-left: 20px;'>
<li>Best Overall: Macro F1 = 0.3880 (M1-AF)</li>
<li>Parameters: 2.5M (lightweight)</li>
<li>Strength: Excellent single-frame classification</li>
<li>Weakness: Severe preprocessing degradation (-47.5% with M2)</li>
</ul>
</div>

<div style='background: #dcfce7; padding: 12px; border-radius: 6px; margin: 12px 0;'>
<strong style='color: #166534;'>Recommended:</strong> Use with M1 (Raw RGB) for best results
</div>

<div style='background: white; padding: 12px; border-radius: 6px; margin: 12px 0;'>
<strong>Technical Details:</strong>
<ul style='margin: 8px 0; padding-left: 20px;'>
<li>Input: RGB 640×480 (M1) or RGB 224×224 (M2)</li>
<li>Feature dimension: 1024</li>
<li>Classifier: 1024 → 512 → 128 → 7</li>
</ul>
</div>
</div>
""",

        'EfficientNet-B0': """
<div style='padding: 16px; background: #f9fafb; border-left: 4px solid #dc2626; border-radius: 8px; margin: 8px 0;'>
<h3 style='margin-top: 0; color: #1f2937;'>EfficientNet-B0 (CNN)</h3>

<div style='background: white; padding: 12px; border-radius: 6px; margin: 12px 0;'>
<strong>Architecture:</strong> Compound scaling CNN with efficient depth/width/resolution balance
</div>

<div style='background: white; padding: 12px; border-radius: 6px; margin: 12px 0;'>
<strong>Performance:</strong>
<ul style='margin: 8px 0; padding-left: 20px;'>
<li>Macro F1 = 0.3278 (M1-AF)</li>
<li>Parameters: 5.3M (balanced)</li>
<li>Strength: Consistent across methodologies</li>
<li>Weakness: Moderate preprocessing sensitivity (-33.7% with M2)</li>
</ul>
</div>

<div style='background: #dcfce7; padding: 12px; border-radius: 6px; margin: 12px 0;'>
<strong style='color: #166534;'>Recommended:</strong> Use with M1 for webcam snapshots
</div>

<div style='background: white; padding: 12px; border-radius: 6px; margin: 12px 0;'>
<strong>Technical Details:</strong>
<ul style='margin: 8px 0; padding-left: 20px;'>
<li>Input: RGB 640×480 (M1) or RGB 224×224 (M2)</li>
<li>Feature dimension: 1280</li>
<li>Classifier: 1280 → 512 → 128 → 7</li>
</ul>
</div>
</div>
""",

        'ConvNeXt-Tiny': """
<div style='padding: 16px; background: #f9fafb; border-left: 4px solid #dc2626; border-radius: 8px; margin: 8px 0;'>
<h3 style='margin-top: 0; color: #1f2937;'>ConvNeXt-Tiny (CNN)</h3>

<div style='background: white; padding: 12px; border-radius: 6px; margin: 12px 0;'>
<strong>Architecture:</strong> Modernized CNN inspired by Vision Transformers
</div>

<div style='background: white; padding: 12px; border-radius: 6px; margin: 12px 0;'>
<strong>Performance:</strong>
<ul style='margin: 8px 0; padding-left: 20px;'>
<li>M1: Macro F1 = 0.2437 (AF)</li>
<li>M2: Macro F1 = 0.3439 (AF, +41.1% improvement)</li>
<li>Parameters: 28.6M (large)</li>
<li>Strength: Unique benefit from preprocessing</li>
<li>Weakness: Poor M1 performance</li>
</ul>
</div>

<div style='background: #dcfce7; padding: 12px; border-radius: 6px; margin: 12px 0;'>
<strong style='color: #166534;'>Recommended:</strong> Use with M2 (Preprocessed) for optimal results
</div>

<div style='background: white; padding: 12px; border-radius: 6px; margin: 12px 0;'>
<strong>Technical Details:</strong>
<ul style='margin: 8px 0; padding-left: 20px;'>
<li>Input: RGB 640×480 (M1) or RGB 224×224 (M2)</li>
<li>Feature dimension: 768</li>
<li>Classifier: 768 → 512 → 128 → 7</li>
</ul>
</div>
</div>
""",

        'Vision Transformer (ViT)': """
<div style='padding: 16px; background: #f9fafb; border-left: 4px solid #3b82f6; border-radius: 8px; margin: 8px 0;'>
<h3 style='margin-top: 0; color: #1f2937;'>Vision Transformer (ViT)</h3>

<div style='background: white; padding: 12px; border-radius: 6px; margin: 12px 0;'>
<strong>Architecture:</strong> Pure attention-based model with patch embeddings
</div>

<div style='background: white; padding: 12px; border-radius: 6px; margin: 12px 0;'>
<strong>Performance:</strong>
<ul style='margin: 8px 0; padding-left: 20px;'>
<li>M1: Macro F1 = 0.2298 (direct)</li>
<li>M2: Macro F1 = 0.2347 (preprocessed)</li>
<li>Parameters: 88M (large)</li>
<li>Strength: Global pattern analysis via self-attention</li>
<li>Weakness: Requires large datasets for optimal performance</li>
</ul>
</div>

<div style='background: #dbeafe; padding: 12px; border-radius: 6px; margin: 12px 0;'>
<strong style='color: #1e40af;'>Research Note:</strong> Baseline for attention mechanisms
</div>

<div style='background: white; padding: 12px; border-radius: 6px; margin: 12px 0;'>
<strong>Technical Details:</strong>
<ul style='margin: 8px 0; padding-left: 20px;'>
<li>Input: RGB 384×384 (M1) or RGB 224×224 (M2)</li>
<li>Patch size: 32×32</li>
<li>Feature dimension: 768</li>
<li>Classifier: 768 → 512 → 128 → 7</li>
</ul>
</div>
</div>
""",

        'Swin Transformer': """
<div style='padding: 16px; background: #f9fafb; border-left: 4px solid #3b82f6; border-radius: 8px; margin: 8px 0;'>
<h3 style='margin-top: 0; color: #1f2937;'>Swin Transformer</h3>

<div style='background: white; padding: 12px; border-radius: 6px; margin: 12px 0;'>
<strong>Architecture:</strong> Hierarchical vision transformer with shifted windows
</div>

<div style='background: white; padding: 12px; border-radius: 6px; margin: 12px 0;'>
<strong>Performance:</strong>
<ul style='margin: 8px 0; padding-left: 20px;'>
<li>M1: Macro F1 = 0.2619 (direct, best Transformer)</li>
<li>Parameters: 87M (large)</li>
<li>Strength: Window-based attention for local patterns</li>
<li>Weakness: Computational complexity</li>
</ul>
</div>

<div style='background: #dcfce7; padding: 12px; border-radius: 6px; margin: 12px 0;'>
<strong style='color: #166534;'>Recommended:</strong> Best Transformer for micro-expressions
</div>

<div style='background: white; padding: 12px; border-radius: 6px; margin: 12px 0;'>
<strong>Technical Details:</strong>
<ul style='margin: 8px 0; padding-left: 20px;'>
<li>Input: RGB 384×384 (M1) or RGB 224×224 (M2)</li>
<li>Window size: 7×7</li>
<li>Feature dimension: 768</li>
<li>Classifier: 768 → 512 → 128 → 7</li>
</ul>
</div>
</div>
""",

        'PoolFormer': """
<div style='padding: 16px; background: #f9fafb; border-left: 4px solid #3b82f6; border-radius: 8px; margin: 8px 0;'>
<h3 style='margin-top: 0; color: #1f2937;'>PoolFormer</h3>

<div style='background: white; padding: 12px; border-radius: 6px; margin: 12px 0;'>
<strong>Architecture:</strong> MetaFormer with simple pooling instead of attention
</div>

<div style='background: white; padding: 12px; border-radius: 6px; margin: 12px 0;'>
<strong>Performance:</strong>
<ul style='margin: 8px 0; padding-left: 20px;'>
<li>M1: Macro F1 = 0.1734 (multiframe)</li>
<li>Parameters: 73M (efficient)</li>
<li>Strength: Simple pooling-based token mixing</li>
<li>Weakness: Lower accuracy compared to other Transformers</li>
</ul>
</div>

<div style='background: #dbeafe; padding: 12px; border-radius: 6px; margin: 12px 0;'>
<strong style='color: #1e40af;'>Research Note:</strong> Comparison for pooling vs attention
</div>

<div style='background: white; padding: 12px; border-radius: 6px; margin: 12px 0;'>
<strong>Technical Details:</strong>
<ul style='margin: 8px 0; padding-left: 20px;'>
<li>Input: RGB 384×384 (M1) or RGB 224×224 (M2)</li>
<li>Pooling: 3×3 average pooling</li>
<li>Feature dimension: 768</li>
<li>Classifier: 768 → 512 → 128 → 7</li>
</ul>
</div>
</div>
"""
    }

    return descriptions.get(model_name, "<p>No description available</p>")


print("Enhanced model descriptions defined for 6 architectures")

# ============================================================================
# SECTION 2: CUSTOM CSS FOR TEXT ALIGNMENT FIX
# ============================================================================

print("\n[2/4] Defining custom CSS...")

# CSS to remove extra left margin/padding from labels and info text
MER_CUSTOM_CSS = """
/* Fix text alignment - remove extra left margin */
.align-left .block-label,
.align-left .info {
    margin-left: 8px !important;
    padding-left: 0 !important;
}

/* Ensure radio buttons themselves stay properly indented */
.align-left .wrap {
    padding-left: 8px;
}
"""

print("Custom CSS defined for text alignment")

# ============================================================================
# SECTION 3: MER INTERFACE LAYOUT
# ============================================================================

print("\n[3/4] Building MER interface")

with gr.Blocks(css=MER_CUSTOM_CSS) as mer_interface:
    gr.Markdown("""
    ## Micro-Expression Recognition (MER) Analysis

    Capture your facial expression using webcam for real-time classification.
    """)

    with gr.Row():
        # Left column - Configuration
        with gr.Column(scale=1):
            gr.Markdown("### Configuration")

            # Model Selection (CSS class applied)
            with gr.Group():
                mer_model_select = gr.Radio(
                    choices=[
                        'MobileNetV3-Small',
                        'EfficientNet-B0',
                        'ConvNeXt-Tiny',
                        'Vision Transformer (ViT)',
                        'Swin Transformer',
                        'PoolFormer'
                    ],
                    value='MobileNetV3-Small',
                    label="Model Architecture",
                    info="Select CNN or Transformer model",
                    elem_classes="align-left"
                )

            # Methodology Selection (CSS class applied)
            with gr.Group():
                mer_methodology_select = gr.Radio(
                    choices=[
                        'M1 (Raw RGB 640×480)',
                        'M2 (Preprocessed Grayscale 224×224)'
                    ],
                    value='M1 (Raw RGB 640×480)',
                    label="Preprocessing Methodology",
                    info="M1 for MobileNet/EfficientNet, M2 for ConvNeXt",
                    elem_classes="align-left"
                )

            # Classify Button
            mer_classify_btn = gr.Button(
                "Classify Expression",
                variant="primary",
                size="lg"
            )

            # Evaluation Info
            with gr.Accordion("Evaluation", open=True):
                gr.Markdown("""
                **Evaluation Phase:** Apex Frame (single snapshot)

                **Processing:** Check terminal for detailed progress

                **Dataset:** CASME II (7 emotion categories)
                """)

            # Model Info Display (Open by default with styled cards)
            with gr.Accordion("Model Information", open=True):
                mer_model_info = gr.HTML(
                    value=get_mer_model_description('MobileNetV3-Small'),
                    elem_id="mer_model_info_display"
                )

            # Research Findings
            with gr.Accordion("Research Findings", open=False):
                gr.HTML("""
                <div style='padding: 12px;'>
                <h3 style='color: #1f2937; margin-top: 0;'>Key Research Findings</h3>

                <div style='background: #fef2f2; padding: 12px; border-radius: 8px; margin: 12px 0; border-left: 4px solid #dc2626;'>
                <strong style='color: #991b1b;'>Preprocessing Paradox Discovery:</strong>
                <ul style='margin: 8px 0; padding-left: 20px; color: #374151;'>
                <li>MobileNetV3 M1→M2: -47.5% degradation</li>
                <li>EfficientNet M1→M2: -33.7% degradation</li>
                <li>ConvNeXt M1→M2: +41.1% improvement (unique)</li>
                </ul>
                </div>

                <div style='background: #dcfce7; padding: 12px; border-radius: 8px; margin: 12px 0; border-left: 4px solid #059669;'>
                <strong style='color: #047857;'>Best Model Combinations:</strong>
                <ol style='margin: 8px 0; padding-left: 20px; color: #374151;'>
                <li>MobileNetV3 + M1: F1 = 0.3880 (Best Overall)</li>
                <li>ConvNeXt + M2: F1 = 0.3439 (Best M2)</li>
                <li>EfficientNet + M1: F1 = 0.3278</li>
                <li>Swin Transformer + M1: F1 = 0.2619 (Best Transformer)</li>
                <li>ViT + M2: F1 = 0.2347</li>
                </ol>
                </div>

                <div style='background: #dbeafe; padding: 12px; border-radius: 8px; margin: 12px 0; border-left: 4px solid #3b82f6;'>
                <strong style='color: #1e40af;'>Transformer Insights:</strong>
                <ul style='margin: 8px 0; padding-left: 20px; color: #374151;'>
                <li>Swin Transformer: Best Transformer performance (0.2619)</li>
                <li>ViT: Global attention patterns (0.2298)</li>
                <li>PoolFormer: Pooling-based alternative (0.1734)</li>
                <li>All Transformers lag behind best CNN (MobileNetV3)</li>
                </ul>
                </div>

                <div style='background: #fef3c7; padding: 12px; border-radius: 8px; margin: 12px 0; border-left: 4px solid #f59e0b;'>
                <strong style='color: #92400e;'>Dataset Challenges:</strong>
                <ul style='margin: 8px 0; padding-left: 20px; color: #374151;'>
                <li>Training samples: 201</li>
                <li>Test samples: 26</li>
                <li>Class imbalance: 49.5:1 (Others vs Fear)</li>
                <li>Rare emotions: Fear (0 test), Sadness (1 test)</li>
                </ul>
                </div>

                <div style='background: #f3f4f6; padding: 12px; border-radius: 8px; margin: 12px 0;'>
                <strong style='color: #1f2937;'>Publication Status:</strong>
                <p style='margin: 8px 0; color: #374151;'>
                IEEE ICICyTA 2025 (Accepted)<br>
                Paper: "Preprocessing Paradox in Vision Transformer-Based Micro-Expression Recognition"
                </p>
                </div>
                </div>
                """)

            # Methodology Comparison
            with gr.Accordion("Methodology Comparison", open=False):
                gr.HTML("""
                <div style='padding: 12px;'>
                <h3 style='color: #1f2937; margin-top: 0;'>Methodology Comparison</h3>

                <div style='background: white; padding: 16px; border-radius: 8px; margin: 12px 0; border: 2px solid #dc2626;'>
                <h4 style='color: #dc2626; margin-top: 0;'>M1 (Raw RGB)</h4>

                <div style='background: #f9fafb; padding: 12px; border-radius: 6px; margin: 8px 0;'>
                <strong>Pipeline:</strong>
                <ol style='margin: 8px 0; padding-left: 20px;'>
                <li>Minimal preprocessing</li>
                <li>CNN: Center crop to 640×480 (4:3 ratio)</li>
                <li>Transformer: Center crop to 384×384 (square)</li>
                <li>RGB normalization (ImageNet stats)</li>
                </ol>
                </div>

                <div style='background: #f9fafb; padding: 12px; border-radius: 6px; margin: 8px 0;'>
                <strong>Characteristics:</strong>
                <ul style='margin: 8px 0; padding-left: 20px;'>
                <li>Input pixels: CNN 307K, Transformer 147K</li>
                <li>Face centering: 84% approximate</li>
                <li>Best for: MobileNetV3, EfficientNet, Swin Transformer</li>
                </ul>
                </div>
                </div>

                <div style='background: white; padding: 16px; border-radius: 8px; margin: 12px 0; border: 2px solid #3b82f6;'>
                <h4 style='color: #3b82f6; margin-top: 0;'>M2 (Preprocessed)</h4>

                <div style='background: #f9fafb; padding: 12px; border-radius: 6px; margin: 8px 0;'>
                <strong>Pipeline:</strong>
                <ol style='margin: 8px 0; padding-left: 20px;'>
                <li>Face detection (Dlib)</li>
                <li>Bounding box expansion (+20px)</li>
                <li>Crop + resize to 224×224</li>
                <li>RGB normalization (ImageNet stats)</li>
                </ol>
                </div>

                <div style='background: #f9fafb; padding: 12px; border-radius: 6px; margin: 8px 0;'>
                <strong>Characteristics:</strong>
                <ul style='margin: 8px 0; padding-left: 20px;'>
                <li>Input pixels: 50K (-84% vs M1 CNN)</li>
                <li>Face centering: 100% (when detected)</li>
                <li>Best for: ConvNeXt</li>
                <li>Degrades: MobileNetV3, EfficientNet</li>
                </ul>
                </div>
                </div>
                </div>
                """)

        # Right column - Input and Results
        with gr.Column(scale=1):
            gr.Markdown("### Webcam Input")

            mer_webcam_input = gr.Image(
                sources=["webcam"],
                type="pil",
                label="Capture Your Expression",
                height=400
            )

            gr.Markdown("""
            **Instructions:**
            1. Click camera icon to activate webcam
            2. Position your face in frame
            3. Capture snapshot of your expression
            4. Select model and methodology
            5. Click "Classify Expression"
            """)

            gr.Markdown("### Classification Results")

            mer_preprocessing_comparison = gr.Image(
                label="Preprocessing Comparison",
                height=300
            )

            mer_prediction_output = gr.Markdown(
                value="### Micro-Expression Classification Results\n\nCapture an image and click 'Classify Expression' to see results."
            )

            mer_probabilities_output = gr.Label(
                label="Emotion Probabilities (7 Classes)",
                num_top_classes=7
            )

    # Event Handlers
    mer_model_select.change(
        fn=get_mer_model_description,
        inputs=[mer_model_select],
        outputs=[mer_model_info]
    )

    mer_classify_btn.click(
        fn=mer_predict_pipeline,
        inputs=[
            mer_webcam_input,
            mer_model_select,
            mer_methodology_select
        ],
        outputs=[
            mer_preprocessing_comparison,
            mer_prediction_output,
            mer_probabilities_output
        ]
    )

print("MER interface with CSS fix complete")

# ============================================================================
# SECTION 4: VERIFICATION
# ============================================================================

print("\n[4/4] Interface verification...")
print("Model options: 6 (3 CNN + 3 Transformer)")
print("Methodology options: 2 (M1, M2)")
print("Total configurations: 12")
print("CSS fixes applied:")
print("  - Custom CSS removes extra left margin from labels")
print("  - Text alignment: margin-left: 0, padding-left: 0")
print("  - Model Information accordion: open by default")

print("\n" + "=" * 80)
print("CELL 7 COMPLETE - MER INTERFACE READY (FINAL)")
print("=" * 80)
print("\nFixed issues:")
print("  - Text perfectly aligned (no left indent)")
print("  - Model Information visible and styled")
print("  - All accordions functional")
print("  - CSS-based solution for clean styling")
print("\nNext: Cell 8 launch demo")
print("=" * 80)

MER INTERFACE CONFIGURATION - CSS FIXED

[1/4] Defining enhanced model descriptions...
Enhanced model descriptions defined for 6 architectures

[2/4] Defining custom CSS...
Custom CSS defined for text alignment

[3/4] Building MER interface
MER interface with CSS fix complete

[4/4] Interface verification...
Model options: 6 (3 CNN + 3 Transformer)
Methodology options: 2 (M1, M2)
Total configurations: 12
CSS fixes applied:
  - Custom CSS removes extra left margin from labels
  - Text alignment: margin-left: 0, padding-left: 0
  - Model Information accordion: open by default

CELL 7 COMPLETE - MER INTERFACE READY (FINAL)

Fixed issues:
  - Text perfectly aligned (no left indent)
  - Model Information visible and styled
  - All accordions functional
  - CSS-based solution for clean styling

Next: Cell 8 launch demo


In [8]:
# @title Cell 8: Launch Integrated Demo

# File: MMLab_Project_Demo.ipynb - Cell 8
# Location: Thesis_MER_Project/MMLab_Project_Demo.ipynb
# Purpose: Launch demo with proper attention visualization

print("=" * 80)
print("LAUNCHING INTEGRATED DEMO")
print("=" * 80)

# ============================================================================
# SECTION 1: VERIFY AND DEFINE REQUIRED VARIABLES
# ============================================================================

print("\n[1/5] Verifying system variables...")

if 'PROJECT_INFO' not in globals():
    PROJECT_INFO = {
        'lab_name': 'Multimedia Laboratory',
        'university': 'Telkom University',
        'location': 'Bandung, Indonesia',
        'primary_color': '#C41E3A'
    }
    print("PROJECT_INFO defined")

if 'MMLAB_THEME' not in globals():
    MMLAB_THEME = gr.themes.Soft(
        primary_hue="red",
        secondary_hue="slate",
        neutral_hue="slate",
        font=gr.themes.GoogleFont("Inter")
    )
    print("MMLAB_THEME defined")

required_functions = ['her2_predict_pipeline', 'mer_predict_pipeline']
missing_functions = [f for f in required_functions if f not in globals()]

if missing_functions:
    print(f"WARNING: Missing functions: {', '.join(missing_functions)}")
    raise RuntimeError("Required functions not available")

print("System variables verified")

# ============================================================================
# SECTION 2: MAPPING FUNCTIONS
# ============================================================================

print("\n[2/5] Defining mapping functions...")

def map_her2_model_display_to_key(display_name):
    mapping = {
        'MobileNetV3-Large': 'MobileNetV3',
        'Vision Transformer (ViT)': 'ViT',
        'Fusion (Concatenation)': 'FusionConcat',
        'Fusion (Addition)': 'FusionAddition'
    }
    return mapping.get(display_name, display_name)

def map_her2_preprocessing_display_to_key(display_name):
    mapping = {
        'Standard': 'Standard Preprocessing',
        'Medical-Optimized': 'Medical Preprocessing'
    }
    return mapping.get(display_name, display_name)

def map_her2_task_display_to_key(display_name):
    mapping = {
        'IHC Score (0-3)': 'IHC',
        'HER2 Status (Neg/Pos)': 'HER2'
    }
    return mapping.get(display_name, display_name)

def map_mer_model_display_to_key(display_name):
    mapping = {
        'MobileNetV3-Small': 'MobileNetV3',
        'EfficientNet-B0': 'EfficientNet',
        'ConvNeXt-Tiny': 'ConvNeXt',
        'Vision Transformer (ViT)': 'ViT',
        'Swin Transformer': 'SwinTransformer',
        'PoolFormer': 'PoolFormer'
    }
    return mapping.get(display_name, display_name)

def map_mer_methodology_display_to_key(display_name):
    if 'M1' in display_name:
        return 'M1'
    elif 'M2' in display_name:
        return 'M2'
    return display_name

print("Mapping functions defined")

# ============================================================================
# SECTION 3: WRAPPER FUNCTIONS
# ============================================================================

print("\n[3/5] Defining wrapper functions...")

def her2_predict_with_mapping(image, preprocessing_display, model_display, task_display):
    preprocessing_key = map_her2_preprocessing_display_to_key(preprocessing_display)
    model_key = map_her2_model_display_to_key(model_display)
    task_key = map_her2_task_display_to_key(task_display)

    print(f"\nHER2 Mapping:")
    print(f"  Preprocessing: {preprocessing_display} -> {preprocessing_key}")
    print(f"  Model: {model_display} -> {model_key}")
    print(f"  Task: {task_display} -> {task_key}")

    return her2_predict_pipeline(image, preprocessing_key, model_key, task_key)

def mer_predict_with_mapping(image, model_display, methodology_display):
    model_key = map_mer_model_display_to_key(model_display)
    methodology_key = map_mer_methodology_display_to_key(methodology_display)

    print(f"\nMER Mapping:")
    print(f"  Model: {model_display} -> {model_key}")
    print(f"  Methodology: {methodology_display} -> {methodology_key}")

    return mer_predict_pipeline(image, model_key, methodology_key)

print("Wrapper functions defined")

# ============================================================================
# SECTION 4: INTEGRATED DEMO INTERFACE
# ============================================================================

print("\n[4/5] Building integrated interface...")

with gr.Blocks(
    theme=MMLAB_THEME,
    title="Multimedia Laboratory - Research Showcase"
) as demo:

    # ========================================================================
    # TAB 0: LANDING PAGE
    # ========================================================================

    with gr.Tab("Projects Overview", id=0):
        gr.HTML(f"""
        <div style='text-align: center; padding: 48px 24px; background: linear-gradient(135deg, {PROJECT_INFO['primary_color']} 0%, #991b1b 100%); border-radius: 16px; margin-bottom: 32px;'>
            <h1 style='font-size: 42px; font-weight: 800; color: white; margin-bottom: 16px; text-shadow: 2px 2px 4px rgba(0,0,0,0.2);'>
                {PROJECT_INFO['lab_name']}
            </h1>
            <p style='font-size: 18px; color: rgba(255,255,255,0.95); margin-bottom: 8px;'>
                {PROJECT_INFO['university']}
            </p>
            <p style='font-size: 16px; color: rgba(255,255,255,0.85);'>
                Advancing AI Research in Medical Imaging and Affective Computing
            </p>
        </div>
        """)

        gr.Markdown("## Featured Research Projects")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("""
                ### HER2 Status Classification

                **Domain:** Medical Imaging (Gastroesophageal Cancer)

                **Models:** 4 architectures with dual preprocessing

                **Innovation:** Medical-optimized preprocessing and fusion architectures

                **Status:** Under Review - International Conference
                """)

                her2_demo_btn = gr.Button("Launch HER2 Demo", variant="primary", size="lg")

            with gr.Column(scale=1):
                gr.Markdown("""
                ### Micro-Expression Recognition (MER)

                **Domain:** Affective Computing (Emotion Analysis)

                **Models:** 6 architectures with dual methodologies

                **Innovation:** Preprocessing paradox discovery

                **Status:** Accepted - IEEE ICICyTA 2025
                """)

                mer_demo_btn = gr.Button("Launch MER Demo", variant="primary", size="lg")

    # ========================================================================
    # TAB 1: HER2 DEMO (ATTENTION VISUALIZATION
    # ========================================================================

    with gr.Tab("HER2 Classification", id=1):
        gr.Markdown("""
        ## HER2 Status Classification for Gastroesophageal Cancer

        Upload histopathology tissue microarray (TMA) images for automated HER2 status prediction.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Configuration")

                with gr.Group():
                    her2_preprocessing_select = gr.Radio(
                        choices=['Standard', 'Medical-Optimized'],
                        value='Medical-Optimized',
                        label="Preprocessing Strategy",
                        info="Medical-Optimized includes CLAHE enhancement"
                    )

                with gr.Group():
                    her2_model_select = gr.Radio(
                        choices=[
                            'MobileNetV3-Large',
                            'Vision Transformer (ViT)',
                            'Fusion (Concatenation)',
                            'Fusion (Addition)'
                        ],
                        value='Fusion (Concatenation)',
                        label="Model Architecture",
                        info="Fusion models combine CNN and Transformer"
                    )

                with gr.Group():
                    her2_task_select = gr.Radio(
                        choices=['IHC Score (0-3)', 'HER2 Status (Neg/Pos)'],
                        value='HER2 Status (Neg/Pos)',
                        label="Classification Task",
                        info="IHC intensity or final HER2 status"
                    )

                her2_classify_btn = gr.Button(
                    "Classify Image",
                    variant="primary",
                    size="lg"
                )

                with gr.Accordion("System Status", open=True):
                    gr.Markdown("""
                    **Ready:** All models loaded and cached

                    **Processing:** Check terminal for progress

                    **Dataset:** TMA images from gastroesophageal adenocarcinoma
                    """)

            with gr.Column(scale=1):
                gr.Markdown("### Image Input")

                her2_image_input = gr.Image(
                    type="pil",
                    label="Upload Histopathology TMA Image",
                    height=400
                )

                gr.Markdown("### Results")

                her2_preprocessing_comparison = gr.Image(
                    label="Preprocessing Comparison (Original vs Processed)",
                    height=250
                )

                her2_prediction_output = gr.Markdown(
                    value="### Prediction Results\n\nUpload image and click 'Classify Image'."
                )

                her2_probabilities_output = gr.Label(
                    label="Class Probabilities",
                    num_top_classes=4
                )

                gr.Markdown("### Attention Visualization")

                # FIXED: Single row with conditional visibility
                with gr.Row() as her2_attention_row:
                    her2_attention_1 = gr.Image(
                        label="Attention Map",
                        height=300
                    )
                    her2_attention_2 = gr.Image(
                        label="Branch Attention 2",
                        visible=False,
                        height=300
                    )
                    her2_attention_3 = gr.Image(
                        label="Fusion Attention",
                        visible=False,
                        height=300
                    )

        # HER2 Event Handler (FIXED - proper output mapping)
        her2_classify_btn.click(
            fn=her2_predict_with_mapping,
            inputs=[
                her2_image_input,
                her2_preprocessing_select,
                her2_model_select,
                her2_task_select
            ],
            outputs=[
                her2_preprocessing_comparison,
                her2_prediction_output,
                her2_probabilities_output,
                her2_attention_1,
                her2_attention_2,
                her2_attention_3,
                her2_attention_row,
                gr.Markdown(),  # label_1 (not displayed, placeholder)
                gr.Markdown(),  # label_2 (not displayed, placeholder)
                gr.Markdown()   # label_3 (not displayed, placeholder)
            ]
        )

    # ========================================================================
    # TAB 2: MER DEMO
    # ========================================================================

    with gr.Tab("MER Analysis", id=2):
        gr.Markdown("""
        ## Micro-Expression Recognition Analysis

        Capture facial expression using webcam for real-time emotion classification.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Configuration")

                with gr.Group():
                    mer_model_select = gr.Radio(
                        choices=[
                            'MobileNetV3-Small',
                            'EfficientNet-B0',
                            'ConvNeXt-Tiny',
                            'Vision Transformer (ViT)',
                            'Swin Transformer',
                            'PoolFormer'
                        ],
                        value='MobileNetV3-Small',
                        label="Model Architecture"
                    )

                with gr.Group():
                    mer_methodology_select = gr.Radio(
                        choices=[
                            'M1 (Raw RGB 640x480)',
                            'M2 (Preprocessed Gray Scale 224x224)'
                        ],
                        value='M1 (Raw RGB 640x480)',
                        label="Preprocessing Methodology"
                    )

                mer_classify_btn = gr.Button(
                    "Classify Expression",
                    variant="primary",
                    size="lg"
                )

                with gr.Accordion("Evaluation Info", open=True):
                    gr.Markdown("""
                    **Evaluation:** Apex Frame (single snapshot)

                    **Dataset:** CASME II (7 emotion categories)
                    """)

            with gr.Column(scale=1):
                gr.Markdown("### Webcam Input")

                mer_image_input = gr.Image(
                    sources=["upload", "webcam"],
                    type="pil",
                    label="Capture Your Expression",
                    height=400
                )

                gr.Markdown("""
                **Instructions:**
                1. Activate webcam
                2. Position face in frame
                3. Capture snapshot
                4. Click 'Classify Expression'
                """)

                gr.Markdown("### Classification Results")

                mer_preprocessing_comparison = gr.Image(
                    label="Preprocessing Comparison",
                    height=250
                )

                mer_prediction_output = gr.Markdown(
                    value="### Results\n\nCapture image and classify."
                )

                mer_probabilities_output = gr.Label(
                    label="Emotion Probabilities (7 Classes)",
                    num_top_classes=7
                )

        # MER Event Handler
        mer_classify_btn.click(
            fn=mer_predict_with_mapping,
            inputs=[
                mer_image_input,
                mer_model_select,
                mer_methodology_select
            ],
            outputs=[
                mer_preprocessing_comparison,
                mer_prediction_output,
                mer_probabilities_output
            ]
        )

    # Navigation
    her2_demo_btn.click(fn=lambda: gr.Tabs(selected=1), inputs=None, outputs=None)
    mer_demo_btn.click(fn=lambda: gr.Tabs(selected=2), inputs=None, outputs=None)

print("Interface built")

# ============================================================================
# SECTION 5: LAUNCH
# ============================================================================

print("\n[5/5] Launching demo...")

print("\nSystem Configuration:")
print("  - HER2: 4 models with attention visualization")
print("  - MER: 6 models (3 CNN + 3 Transformer)")
print("  - Attention maps: Now properly rendered")

demo.launch(share=True, debug=False, show_error=True)

print("\n" + "=" * 80)
print("DEMO LAUNCHED - ATTENTION VISUALIZATION FIXED")
print("=" * 80)
print("\nFixed: Attention images now visible when generated")
print("Next: Test HER2 attention rendering, then MER system")
print("=" * 80)

LAUNCHING INTEGRATED DEMO

[1/5] Verifying system variables...
System variables verified

[2/5] Defining mapping functions...
Mapping functions defined

[3/5] Defining wrapper functions...
Wrapper functions defined

[4/5] Building integrated interface...
Interface built

[5/5] Launching demo...

System Configuration:
  - HER2: 4 models with attention visualization
  - MER: 6 models (3 CNN + 3 Transformer)
  - Attention maps: Now properly rendered
Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://6a3f07b95fffbaf1e7.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)



DEMO LAUNCHED - ATTENTION VISUALIZATION FIXED

Fixed: Attention images now visible when generated
Next: Test HER2 attention rendering, then MER system
