In [None]:
# example_usage.py
"""
Example script showing how to use the flexible classifier architecture.
"""

import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np
from multimodal_stimulus_fmri_predict.core.classifier_factory import ClassifierFactory
from multimodal_stimulus_fmri_predict.utils.experiment_runner import ExperimentRunner
from multimodal_stimulus_fmri_predict.configs.experiment_configs import (
    get_vit_configs, get_resnet_configs, get_multimodal_configs
)

# Estimated runtime: 2-4 hours for neurodivergent/burned out individuals
# Estimated runtime: 1.5-2.5 hours for average person when regulated

class DummyDataset(Dataset):
    """Synthetic dataset of random tensors for smoke-testing the classifier architecture.

    Every ``__getitem__`` call draws fresh values from torch's global RNG, so the
    same index does not return the same sample twice.

    Args:
        num_samples: Number of samples reported by ``__len__``.
        image_size: Height and width of the generated 3-channel square images.
        fmri_dim: Length of the generated fMRI vector (``'multimodal'`` mode only).
        num_classes: Labels are drawn uniformly from ``[0, num_classes)``.
        mode: ``'image'`` yields ``(image, label)``; ``'multimodal'`` yields
            ``((image, fmri), label)``.

    Raises:
        ValueError: If ``mode`` is not one of ``DummyDataset.MODES``.
    """

    MODES = ('image', 'multimodal')

    def __init__(self, num_samples=1000, image_size=224, fmri_dim=1000, num_classes=2, mode='image'):
        # Fail fast: an unknown mode used to make __getitem__ return None, which
        # only surfaced later as an obscure DataLoader collation error.
        if mode not in self.MODES:
            raise ValueError(f"Unsupported mode {mode!r}; expected one of {self.MODES}")
        self.num_samples = num_samples
        self.image_size = image_size
        self.fmri_dim = fmri_dim
        self.num_classes = num_classes
        self.mode = mode

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Return one random sample; raises IndexError outside ``[-len, len)``."""
        if idx < 0:
            idx += self.num_samples
        # Without this bound, Python's fallback iteration protocol (used because
        # torch's Dataset defines no __iter__) would never terminate.
        if not 0 <= idx < self.num_samples:
            raise IndexError(
                f"index {idx} out of range for dataset of size {self.num_samples}"
            )

        # Random draws keep a fixed order: image, then fMRI (multimodal only), then label.
        image = torch.randn(3, self.image_size, self.image_size)
        if self.mode == 'multimodal':
            fmri = torch.randn(self.fmri_dim)
            label = torch.randint(0, self.num_classes, (1,)).item()
            return (image, fmri), label

        label = torch.randint(0, self.num_classes, (1,)).item()
        return image, label


def create_data_loaders(batch_size=32, num_samples=1000, mode='image', **dataset_kwargs):
    """Create train/val/test DataLoaders over synthetic ``DummyDataset`` splits.

    The training split has ``num_samples`` samples. The validation split has
    ``num_samples // 5`` samples and the test split ``num_samples // 10``
    samples, each floored at 1 so a small ``num_samples`` never yields an
    empty evaluation loader.

    Args:
        batch_size: Batch size used by all three loaders.
        num_samples: Size of the training split.
        mode: Dataset mode forwarded to ``DummyDataset`` (``'image'`` or
            ``'multimodal'``).
        **dataset_kwargs: Extra keyword arguments forwarded to every
            ``DummyDataset`` (e.g. ``image_size``, ``fmri_dim``, ``num_classes``).

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the training loader
        shuffles.
    """

    def _make_loader(size, shuffle):
        # One construction path keeps the three splits configured identically.
        dataset = DummyDataset(size, mode=mode, **dataset_kwargs)
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    train_loader = _make_loader(num_samples, shuffle=True)
    val_loader = _make_loader(max(1, num_samples // 5), shuffle=False)
    test_loader = _make_loader(max(1, num_samples // 10), shuffle=False)

    return train_loader, val_loader, test_loader


def example_single_classifier():
    """Train and evaluate a single ResNet-50 classifier on synthetic image data.

    Builds image-mode loaders, creates a ``'resnet'`` classifier through
    ``ClassifierFactory``, trains it for 3 epochs, and prints test accuracy.

    Returns:
        ``(classifier, history)``: the trained classifier and whatever its
        ``train`` method returned.
    """
    print("=== Single Classifier Example ===")

    train_dl, val_dl, test_dl = create_data_loaders(mode='image')

    resnet_config = dict(
        model_name='resnet50',
        pretrained=True,
        num_classes=2,
        learning_rate=1e-4,
    )
    model = ClassifierFactory.create_classifier('resnet', resnet_config)

    print("Training classifier...")
    history = model.train(train_dl, val_dl, epochs=3)

    # evaluate() yields a (loss, accuracy) pair; only the accuracy is reported.
    _, test_acc = model.evaluate(test_dl)
    print(f"Test accuracy: {test_acc:.4f}")

    return model, history


def example_multiple_experiments():
    """Run ResNet-18, EfficientNet-B0 and ViT experiments on synthetic images.

    Every experiment shares the same binary-classification settings
    (``num_classes=2``, ``learning_rate=1e-4``, 2 training epochs). Results are
    collected by ``ExperimentRunner`` into the ``experiment_results`` directory.

    Returns:
        The results table returned by ``ExperimentRunner.run_multiple_experiments``.
    """
    print("\n=== Multiple Experiments Example ===")

    # Image-only loaders shared by every experiment below.
    loaders = create_data_loaders(mode='image')

    def _experiment(classifier_type, **model_settings):
        # Model-specific keys come first, then the settings common to all runs.
        return {
            'classifier_type': classifier_type,
            'classifier_config': {
                **model_settings,
                'num_classes': 2,
                'learning_rate': 1e-4,
            },
            'training_config': {'epochs': 2},
        }

    experiment_configs = [
        _experiment('resnet', model_name='resnet18', pretrained=True),
        _experiment('efficientnet', model_name='efficientnet_b0', pretrained=True),
        _experiment('vit', pretrained=True, image_size=224, patch_size=16),
    ]

    runner = ExperimentRunner(results_dir="experiment_results")
    results_df = runner.run_multiple_experiments(experiment_configs, *loaders)

    print("\nExperiment Results Summary:")
    print(results_df.to_string(index=False))

    return results_df


def example_multimodal_experiment():
    """Example: Running multimodal experiments."""
    print("\n=== Multimodal Experiment Example ===")
    
    # Create multimodal data loaders
    train_loader, val_loader, test_loader = create_data_loaders(mode='multimodal')
    
    # Multimodal configuration
    config = {
        'image_backbone': 'resnet50',
        'fmri_input_dim': 1000,
        'fusion_dim': 256,
        '