# MediSync FL India - Data Processing and ML Notebook



In [None]:
import os
import json
import time
import random
import hashlib
import logging
from pathlib import Path
from datetime import datetime

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, confusion_matrix

In [None]:
LOG_FORMAT = '%(asctime)s | %(levelname)-7s | %(message)s'
DATE_FORMAT = '%Y-%m-%d %H:%M:%S'

# Dedicated, non-propagating logger so messages are not duplicated through the
# root logger.
logger = logging.getLogger('medisync')
logger.setLevel(logging.INFO)
# Re-running this cell must not stack duplicate handlers. Close the old ones
# (instead of only dropping them with handlers.clear()) so a FileHandler left
# by an earlier setup_file_logging() call releases its file.
for _old_handler in list(logger.handlers):
    logger.removeHandler(_old_handler)
    _old_handler.close()
logger.propagate = False

formatter = logging.Formatter(LOG_FORMAT, datefmt=DATE_FORMAT)

stream_handler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)

def setup_file_logging(log_dir: Path):
    """Additionally write ``logger`` output to ``<log_dir>/training.log``.

    The directory is created if needed. Any FileHandler already attached to
    ``logger`` is detached and closed first, so at most one log file is active;
    the console handler is kept.

    Args:
        log_dir: Directory that will contain ``training.log``.
    """
    log_dir.mkdir(parents=True, exist_ok=True)
    log_path = log_dir / 'training.log'
    # Explicit UTF-8 so non-ASCII paths/names do not trigger encoding errors
    # under a non-UTF-8 locale.
    file_handler = logging.FileHandler(log_path, encoding='utf-8')
    file_handler.setFormatter(formatter)
    for handler in list(logger.handlers):
        if isinstance(handler, logging.FileHandler):
            logger.removeHandler(handler)
            handler.close()
    logger.addHandler(file_handler)
    logger.info(f'Logging to {log_path}')


In [None]:
SEED = 42
# Seed the stdlib RNG, NumPy's global RNG and torch, in that order.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
logger.info(f'Random seed set to {SEED}')

In [None]:
# All locations are anchored at the current working directory.
PROJECT_ROOT = Path('.').resolve()
DATA_ROOT = PROJECT_ROOT / 'dataset'          # input hospital datasets
ARTIFACTS_DIR = PROJECT_ROOT / 'models'       # model weights and metadata (not per-run)
ARTIFACTS_ROOT = PROJECT_ROOT / 'artifacts'   # parent of the run-NNN folders
LOGS_ROOT = PROJECT_ROOT / 'logs'             # parent of the per-run log folders

for _output_dir in (ARTIFACTS_DIR, ARTIFACTS_ROOT, LOGS_ROOT):
    _output_dir.mkdir(parents=True, exist_ok=True)


def get_next_run_dir(root: Path):
    """Return the path of the next sequential ``run-NNN`` folder under *root*.

    Sub-directories named ``run-<decimal digits>`` are scanned and the highest
    number is incremented; every other entry is ignored. The folder itself is
    not created. A *root* that does not exist yet is treated as empty.

    Args:
        root: Parent directory holding earlier run folders.

    Returns:
        ``root / 'run-NNN'``, with NNN zero-padded to at least three digits.
    """
    if not root.exists():
        return root / 'run-001'
    prefix = 'run-'
    run_ids = []
    for item in root.iterdir():
        if not (item.is_dir() and item.name.startswith(prefix)):
            continue
        suffix = item.name[len(prefix):]
        # isdecimal() (not isdigit()) matches what int() accepts: isdigit() is
        # also true for characters such as '²', on which int() raises.
        if suffix.isdecimal():
            run_ids.append(int(suffix))
    next_id = max(run_ids, default=0) + 1
    return root / f'run-{next_id:03d}'

# Per-run outputs: artifacts go to a sequentially numbered run-NNN folder,
# logs to a folder named after the run's start time.
RUN_DIR = get_next_run_dir(ARTIFACTS_ROOT)
RUN_DIR.mkdir(parents=True, exist_ok=True)
CACHE_DIR = RUN_DIR  # alias used when writing per-run JSON outputs
RUN_TIMESTAMP = f'{datetime.now():%Y%m%d_%H%M%S}'
LOG_RUN_DIR = LOGS_ROOT.joinpath(RUN_TIMESTAMP)
setup_file_logging(LOG_RUN_DIR)

logger.info(f'Artifacts run dir: {RUN_DIR}')
logger.info(f'Log run dir: {LOG_RUN_DIR}')


In [None]:
# Canonical class names; each name's position is the integer class index used
# as the training target and as the model's output unit.
LABEL_MAP = {
    class_name: class_index
    for class_index, class_name in enumerate(('glioma', 'meningioma', 'pituitary', 'notumor'))
}

# Reverse lookup: class index -> canonical class name.
IDX_TO_LABEL = {class_index: class_name for class_name, class_index in LABEL_MAP.items()}

logger.info(f'Label map initialized with {len(LABEL_MAP)} classes')

In [None]:
# Display metadata per hospital, keyed by the same names used in DATASETS.
# 'location' appears to be [latitude, longitude] (values match the cities).
HOSPITAL_CONFIGS = {
    hospital: {'dataset_id': dataset_id, 'specialty': specialty, 'location': [lat, lon]}
    for hospital, dataset_id, specialty, lat, lon in (
        ('AIIMS Delhi', 'dataset-1', 'Adult Neuro-Oncology', 28.5672, 77.2100),
        ('NIMHANS Bengaluru', 'dataset-2', 'Neuro Specialty', 12.9442, 77.5966),
        ('Tata Memorial Mumbai', 'dataset-3', 'Oncology Referral', 19.0049, 72.8414),
    )
}

In [None]:
# Dataset sources, one per hospital. 'class_map' maps on-disk class-folder
# names (matched loosely by resolve_label) to canonical LABEL_MAP keys;
# dataset-2 and dataset-3 have no 'notumor' class.
DATASETS = [
    {
        'id': 'dataset-1',
        'hospital': 'AIIMS Delhi',
        'path': DATA_ROOT / 'dataset-1',
        'class_map': {
            'glioma': 'glioma',
            'meningioma': 'meningioma',
            'pituitary': 'pituitary',
            'notumor': 'notumor'
        }
    },
    {
        'id': 'dataset-2',
        'hospital': 'NIMHANS Bengaluru',
        'path': DATA_ROOT / 'dataset-2',
        'class_map': {
            'glioma': 'glioma',
            'meningioma': 'meningioma',
            'pituitary tumor': 'pituitary'
        }
    },
    {
        'id': 'dataset-3',
        'hospital': 'Tata Memorial Mumbai',
        'path': DATA_ROOT / 'dataset-3' / 'Brain_Cancer raw MRI data' / 'Brain_Cancer',
        'class_map': {
            'brain_glioma': 'glioma',
            'brain_menin': 'meningioma',
            'brain_tumor': 'pituitary'
        }
    }
]

logger.info(f'Configured {len(DATASETS)} hospital datasets')
for ds in DATASETS:
    logger.info(f"  - {ds['hospital']}: {ds['path']}")
    # Fail fast on a class_map target LABEL_MAP does not know; otherwise the
    # mistake only surfaces as a KeyError when BrainTumorDataset loads a sample.
    unknown_labels = set(ds['class_map'].values()) - set(LABEL_MAP)
    if unknown_labels:
        raise ValueError(
            f"{ds['id']}: class_map targets not in LABEL_MAP: {sorted(unknown_labels)}"
        )

# Per-run cache files inside the run-NNN artifacts folder.
CACHE_INDEX_FILE = RUN_DIR / 'dataset_index.json'
CACHE_SPLITS_FILE = RUN_DIR / 'dataset_splits.json'

# Training hyper-parameters.
BATCH_SIZE = 32
NUM_EPOCHS = 10
LEARNING_RATE = 1e-4
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f'Using device: {DEVICE}')

In [None]:
IMAGE_EXTS = {'.jpg', '.jpeg', '.png'}

def normalize_folder_name(name):
    """Canonicalize a class-folder name for loose comparison.

    The name is lower-cased, '_' and '-' are treated as spaces, and every
    whitespace run collapses to one space with no leading/trailing blanks,
    e.g. ``'  Brain_Glioma-Tumor '`` -> ``'brain glioma tumor'``.
    """
    separators_to_space = str.maketrans('_-', '  ')
    return ' '.join(name.lower().translate(separators_to_space).split())

In [None]:
def resolve_label(folder_name, class_map):
    """Map an on-disk class-folder name to its canonical label.

    *folder_name* and the keys of *class_map* are compared after
    normalize_folder_name(). Strategies are tried in order:

    1. exact match of the normalized names;
    2. equality after removing all spaces (``'no tumor'`` vs ``'notumor'``);
    3. one normalized name contains the other.

    Strategies 2 and 3 only succeed when exactly one key matches; an
    ambiguous result falls through to the next strategy.

    Returns:
        (label, matched_key): the class_map value and the normalized key that
        matched, or (None, None) if no strategy yields a unique match.
    """
    lookup = {normalize_folder_name(raw): label for raw, label in class_map.items()}
    folder_key = normalize_folder_name(folder_name)

    if folder_key in lookup:
        return lookup[folder_key], folder_key

    squashed = folder_key.replace(' ', '')
    fallbacks = (
        lambda candidate: candidate.replace(' ', '') == squashed,
        lambda candidate: candidate in folder_key or folder_key in candidate,
    )
    for accepts in fallbacks:
        hits = [candidate for candidate in lookup if accepts(candidate)]
        if len(hits) == 1:
            (only,) = hits
            return lookup[only], only

    return None, None

In [None]:
def audit_dataset_structure(dataset_config):
    """Summarize the on-disk layout of one dataset before collection.

    Counts image files (extension in IMAGE_EXTS, searched recursively) under
    every immediate sub-directory of ``dataset_config['path']``, whether or
    not that folder is mapped in the dataset's class_map.

    Args:
        dataset_config: A DATASETS entry; reads 'id', 'hospital' and 'path'.

    Returns:
        dict with dataset_id, hospital, base_path (POSIX string), exists,
        subdirs (folder names in sorted order), total_images and
        images_by_dir (folder name -> image count). If the path is missing or
        has no sub-directories, an error is logged and the counts stay 0.
    """
    base_path = dataset_config['path']
    summary = {
        'dataset_id': dataset_config['id'],
        'hospital': dataset_config['hospital'],
        'base_path': base_path.as_posix(),
        'exists': base_path.exists(),
        'subdirs': [],
        'total_images': 0,
        'images_by_dir': {}
    }

    if not base_path.exists():
        logger.error(f"Dataset path missing: {base_path}")
        return summary

    # Sorted: iterdir() order is filesystem-dependent, and the report should
    # read the same on every machine.
    subdirs = sorted(d for d in base_path.iterdir() if d.is_dir())
    summary['subdirs'] = [d.name for d in subdirs]
    if not subdirs:
        logger.error(f"No class folders found in {base_path}")
        return summary

    for class_dir in subdirs:
        # is_file(): a directory whose name ends in e.g. '.png' is not an image.
        count = sum(
            1 for file_path in class_dir.rglob('*')
            if file_path.suffix.lower() in IMAGE_EXTS and file_path.is_file()
        )
        summary['images_by_dir'][class_dir.name] = count
        summary['total_images'] += count

    logger.info(f"Audit {summary['hospital']}: {summary['total_images']} images across {len(subdirs)} folders")
    logger.info(f"  Folders: {summary['subdirs']}")
    logger.info(f"  Images by folder: {summary['images_by_dir']}")
    return summary

In [None]:
def collect_images(dataset_config):
    """Collect all images from the mapped class folders of one dataset.

    Each immediate sub-directory of ``dataset_config['path']`` is mapped to a
    canonical label with resolve_label(); unmapped folders are skipped with a
    warning. Images (extension in IMAGE_EXTS) are searched recursively.

    Folders and files are visited in sorted order. ``Path.iterdir``/``rglob``
    yield entries in arbitrary, filesystem-dependent order, and the seeded
    train/val/test split is built from positions in this list. Without
    sorting, the "reproducible" split would differ between machines.

    Args:
        dataset_config: A DATASETS entry; reads 'id', 'hospital', 'path' and
            'class_map'.

    Returns:
        (items, info): items is a list of dicts with 'path' (POSIX string),
        'label', 'dataset_id' and 'hospital'; info holds 'matched_folders'
        and 'unmatched_folders' (sorted names) and 'images_found'.
    """
    base_path = dataset_config['path']
    class_map = dataset_config['class_map']
    items = []
    matched_folders = set()
    unmatched_folders = []

    if not base_path.exists():
        logger.warning(f"Dataset path does not exist: {base_path}")
        return items, {
            'matched_folders': [],
            'unmatched_folders': [],
            'images_found': 0
        }

    class_dirs = sorted(d for d in base_path.iterdir() if d.is_dir())
    for class_dir in class_dirs:
        label, _matched_key = resolve_label(class_dir.name, class_map)
        if label is None:
            unmatched_folders.append(class_dir.name)
            logger.warning(f"Unmapped class folder: {class_dir.name}")
            continue
        matched_folders.add(class_dir.name)
        image_paths = sorted(
            file_path for file_path in class_dir.rglob('*')
            if file_path.suffix.lower() in IMAGE_EXTS and file_path.is_file()
        )
        items.extend(
            {
                'path': file_path.as_posix(),
                'label': label,
                'dataset_id': dataset_config['id'],
                'hospital': dataset_config['hospital']
            }
            for file_path in image_paths
        )
        logger.info(f"  {dataset_config['hospital']} - {label}: {len(image_paths)} images")

    return items, {
        'matched_folders': sorted(matched_folders),
        'unmatched_folders': sorted(unmatched_folders),
        'images_found': len(items)
    }

In [None]:
logger.info('Starting data collection from all datasets...')
all_samples = []
dataset_stats = {}
audit_results = []

for cfg in DATASETS:
    audit = audit_dataset_structure(cfg)
    audit_results.append(audit)
    # A dataset that is missing or empty is reported and skipped; the run only
    # aborts (below) when no dataset yields any image.
    if not audit['exists'] or audit['total_images'] == 0:
        logger.error(f"Dataset audit failed for {cfg['hospital']} ({cfg['id']})")
        continue

    logger.info(f"Collecting data from {cfg['hospital']} ({cfg['id']})...")
    samples, collection_info = collect_images(cfg)
    all_samples.extend(samples)

    dataset_stats[cfg['hospital']] = {
        'total_samples': len(samples),
        'dataset_id': cfg['id'],
        'location': HOSPITAL_CONFIGS[cfg['hospital']]['location'],
        'specialty': HOSPITAL_CONFIGS[cfg['hospital']]['specialty'],
        'class_distribution': {},
        'unmatched_folders': collection_info['unmatched_folders']
    }

    if collection_info['unmatched_folders']:
        logger.warning(
            f"Unmatched folders for {cfg['hospital']}: {collection_info['unmatched_folders']}"
        )

    # Per-label sample counts, in first-seen order.
    distribution = dataset_stats[cfg['hospital']]['class_distribution']
    for sample in samples:
        distribution[sample['label']] = distribution.get(sample['label'], 0) + 1

total_images = sum(a['total_images'] for a in audit_results)
if total_images == 0 or len(all_samples) == 0:
    raise ValueError(
        'No images were discovered. Check dataset paths and class folder names.'
    )

logger.info(f'Total samples discovered: {len(all_samples)}')
logger.info('Dataset statistics:')
for hospital, stats in dataset_stats.items():
    logger.info(f"  {hospital}: {stats['total_samples']} samples")
    logger.info(f"    Distribution: {stats['class_distribution']}")

# The stats go to the per-run folder (CACHE_DIR), so log the real path rather
# than the previous hard-coded 'artifacts/dataset_stats.json'.
stats_path = CACHE_DIR / 'dataset_stats.json'
with open(stats_path, 'w') as f:
    json.dump(dataset_stats, f, indent=2)
logger.info(f'Saved dataset statistics to {stats_path}')

In [None]:
def create_splits(samples, seed=SEED, val_ratio=0.15, test_ratio=0.15):
    """Randomly partition the positions of *samples* into train/val/test.

    Positions 0..len(samples)-1 are shuffled with a NumPy Generator seeded by
    *seed*. The first ``int(n * test_ratio)`` go to test, the next
    ``int(n * val_ratio)`` to validation, and the rest to train. The split is
    purely random; it is not stratified by class or hospital.

    Returns:
        (train_idx, val_idx, test_idx) as lists of Python ints.
    """
    n_samples = len(samples)
    order = np.arange(n_samples)
    np.random.default_rng(seed).shuffle(order)
    n_test = int(n_samples * test_ratio)
    n_val = int(n_samples * val_ratio)
    test_part, val_part, train_part = np.split(order, [n_test, n_test + n_val])
    return train_part.tolist(), val_part.tolist(), test_part.tolist()

logger.info('Creating train/val/test splits...')
train_idx, val_idx, test_idx = create_splits(all_samples)
splits = {
    'train_idx': train_idx,
    'val_idx': val_idx,
    'test_idx': test_idx,
    'split_date': datetime.now().isoformat()
}
with open(CACHE_SPLITS_FILE, 'w') as f:
    json.dump(splits, f, indent=2)

# The split file stores positions into all_samples, which mean nothing without
# the sample list itself. Persist it to CACHE_INDEX_FILE so this run's exact
# train/val/test membership can be reconstructed later.
with open(CACHE_INDEX_FILE, 'w') as f:
    json.dump(all_samples, f, indent=2)
logger.info(f'Saved split indices to {CACHE_SPLITS_FILE} and sample index to {CACHE_INDEX_FILE}')

train_samples = [all_samples[i] for i in train_idx]
val_samples = [all_samples[i] for i in val_idx]
test_samples = [all_samples[i] for i in test_idx]

logger.info(f'Train: {len(train_samples)} | Val: {len(val_samples)} | Test: {len(test_samples)}')

In [None]:
# Training pipeline: resize, then light geometric augmentation (random
# horizontal flip and rotation up to 10 degrees), then tensor conversion and
# per-channel normalization with the commonly used ImageNet mean/std.
train_transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Validation/test pipeline: the same resize and normalization, with no random
# augmentation.
eval_transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
class BrainTumorDataset(Dataset):
    """Map-style dataset over the sample dicts produced by collect_images.

    Each item is ``(image, class_index)``. The image is read from
    ``sample['path']``, converted to RGB and passed through *transform* when
    one is given. The class index is ``LABEL_MAP[sample['label']]``.
    """

    def __init__(self, samples, transform=None):
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        item = self.samples[idx]
        # Use Image.open as a context manager so the file handle is closed
        # deterministically. convert() returns a new, fully loaded image, so
        # `img` stays valid after the with-block.
        with Image.open(item['path']) as raw:
            img = raw.convert('RGB')
        if self.transform:
            img = self.transform(img)
        label = LABEL_MAP[item['label']]
        return img, label

In [None]:
# Augmented transform for training only; val/test use the deterministic one.
train_dataset = BrainTumorDataset(samples=train_samples, transform=train_transform)
val_dataset = BrainTumorDataset(samples=val_samples, transform=eval_transform)
test_dataset = BrainTumorDataset(samples=test_samples, transform=eval_transform)

# Only the training loader reshuffles each epoch; num_workers=0 keeps loading
# in the main process.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

logger.info('Data loaders created successfully')

In [None]:
logger.info('Initializing ResNet18 model...')
# ResNet-18 without pretrained weights (weights=None); the final fully
# connected layer is replaced so it emits one logit per LABEL_MAP class.
model = models.resnet18(weights=None)
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=len(LABEL_MAP))
model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
logger.info(f'Model initialized on {DEVICE}')

In [None]:
def train_one_epoch(model, loader, opt=None, loss_fn=None, device=None):
    """Run one optimization pass over *loader*.

    Args:
        model: Network to train; ``model.train()`` is called first.
        loader: Iterable of ``(images, labels)`` batches, e.g. a DataLoader.
        opt: Optimizer to step. Defaults to the module-level ``optimizer``.
        loss_fn: Loss callable. Defaults to the module-level ``criterion``.
        device: Device each batch is moved to. Defaults to ``DEVICE``.

    Returns:
        ``(mean_loss, accuracy)`` over every sample seen this epoch, or
        ``(0.0, 0.0)`` if the loader yields no samples.
    """
    # The module-level objects are only looked up when no override is given,
    # so the function also works for other models/optimizers (e.g. per client).
    opt = optimizer if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn
    device = DEVICE if device is None else device

    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        opt.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        opt.step()

        batch_size = labels.size(0)
        # loss is the batch mean; re-weight by batch size for a per-sample mean.
        running_loss += loss.item() * batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch_size

    if total == 0:
        return 0.0, 0.0
    # Divide by the samples actually seen (equals len(loader.dataset) unless
    # the loader drops a last partial batch).
    return running_loss / total, correct / total

In [None]:
def evaluate(model, loader, desc='', loss_fn=None, device=None):
    """Score *model* on *loader* without gradient tracking.

    Args:
        model: Network to evaluate; ``model.eval()`` is called first.
        loader: DataLoader of ``(images, labels)`` batches (``len(loader.dataset)``
            is used to average the loss).
        desc: If non-empty, a one-line summary is logged with this prefix.
        loss_fn: Loss callable. Defaults to the module-level ``criterion``.
        device: Device each batch is moved to. Defaults to ``DEVICE``.

    Returns:
        dict with 'loss' (per-sample mean), 'accuracy', 'macro_f1',
        'macro_precision' and 'macro_recall' (sklearn macro averages over the
        classes present in y_true/y_pred), 'per_class' (precision/recall/f1/
        support for every LABEL_MAP class), 'confusion_matrix' (always
        len(LABEL_MAP) x len(LABEL_MAP), ordered by class index) and
        'report_text'.
    """
    loss_fn = criterion if loss_fn is None else loss_fn
    device = DEVICE if device is None else device

    model.eval()
    y_true = []
    y_pred = []
    running_loss = 0.0

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            running_loss += loss.item() * images.size(0)

            preds = torch.argmax(outputs, dim=1)
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())

    class_ids = list(range(len(LABEL_MAP)))
    class_labels = [IDX_TO_LABEL[i] for i in class_ids]
    # Pass labels= explicitly. Otherwise sklearn infers the class set from the
    # data, and classification_report raises ValueError whenever a split lacks
    # a class (e.g. no 'notumor' samples), because target_names has
    # len(LABEL_MAP) entries. confusion_matrix would also shrink.
    report_dict = classification_report(
        y_true,
        y_pred,
        labels=class_ids,
        target_names=class_labels,
        output_dict=True,
        zero_division=0
    )
    per_class = {}
    for label in class_labels:
        row = report_dict.get(label, {})
        per_class[label] = {
            'precision': row.get('precision', 0.0),
            'recall': row.get('recall', 0.0),
            'f1': row.get('f1-score', 0.0),
            'support': int(row.get('support', 0))
        }

    metrics = {
        'loss': running_loss / len(loader.dataset),
        'accuracy': accuracy_score(y_true, y_pred),
        'macro_f1': f1_score(y_true, y_pred, average='macro', zero_division=0),
        'macro_precision': precision_score(y_true, y_pred, average='macro', zero_division=0),
        'macro_recall': recall_score(y_true, y_pred, average='macro', zero_division=0),
        'per_class': per_class,
        'confusion_matrix': confusion_matrix(y_true, y_pred, labels=class_ids).tolist(),
        'report_text': classification_report(
            y_true,
            y_pred,
            labels=class_ids,
            target_names=class_labels,
            zero_division=0
        )
    }

    if desc:
        logger.info(
            f'{desc} - Loss: {metrics["loss"]:.4f}, '
            f'Acc: {metrics["accuracy"]:.4f}, '
            f'F1: {metrics["macro_f1"]:.4f}'
        )

    return metrics

In [None]:
logger.info('=' * 60)
logger.info('Starting federated training simulation...')
logger.info('=' * 60)

# NOTE(review): despite the log text above, this loop trains one model on the
# pooled train split of all hospitals (train_loader). This cell does no
# per-hospital client training and no weight aggregation.
training_history = []
best_val_acc = 0.0
best_epoch = 0

for epoch in range(NUM_EPOCHS):
    logger.info(f'Epoch {epoch + 1}/{NUM_EPOCHS}')

    train_loss, train_acc = train_one_epoch(model, train_loader)
    logger.info(f'  Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}')

    val_metrics = evaluate(model, val_loader, desc='  Val')

    training_history.append({
        'epoch': epoch + 1,
        'train_loss': train_loss,
        'train_accuracy': train_acc,
        'val_loss': val_metrics['loss'],
        'val_accuracy': val_metrics['accuracy']
    })

    # Always checkpoint the first epoch (best_epoch == 0). With only a strict
    # '>' against the 0.0 starting value, a run whose validation accuracy never
    # exceeds 0 would save nothing. The final evaluation would then load a
    # global_model.pth left in the shared ARTIFACTS_DIR by an earlier run.
    if best_epoch == 0 or val_metrics['accuracy'] > best_val_acc:
        best_val_acc = val_metrics['accuracy']
        best_epoch = epoch + 1
        torch.save(model.state_dict(), ARTIFACTS_DIR / 'global_model.pth')
        logger.info(f'  *** New best model saved (Acc: {best_val_acc:.4f}) ***')

logger.info('=' * 60)
logger.info(f'Training completed. Best validation accuracy: {best_val_acc:.4f} at epoch {best_epoch}')
logger.info('=' * 60)

with open(CACHE_DIR / 'training_history.json', 'w') as f:
    json.dump(training_history, f, indent=2)
logger.info('Saved training history')

In [None]:
logger.info('Loading best model for final evaluation...')
checkpoint_path = ARTIFACTS_DIR / 'global_model.pth'
# best_epoch stays 0 when the training loop saved nothing this run. Any
# global_model.pth on disk then belongs to an earlier run (ARTIFACTS_DIR is not
# per-run), so evaluating it would report the wrong model's metrics.
if best_epoch == 0:
    raise RuntimeError(
        f'No checkpoint was saved during this run; refusing to load a possibly stale {checkpoint_path}'
    )
model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))

test_metrics = evaluate(model, test_loader, desc='Test')

logger.info('=' * 60)
logger.info('FINAL TEST RESULTS:')
logger.info(f"  Accuracy: {test_metrics['accuracy']:.4f}")
logger.info(f"  Macro F1: {test_metrics['macro_f1']:.4f}")
logger.info(f"  Macro Precision: {test_metrics['macro_precision']:.4f}")
logger.info(f"  Macro Recall: {test_metrics['macro_recall']:.4f}")
logger.info('=' * 60)
print('\nPer-class Performance:')
print(test_metrics['report_text'])

with open(ARTIFACTS_DIR / 'label_map.json', 'w') as f:
    json.dump(LABEL_MAP, f, indent=2)

# Run summary written next to the model weights.
meta = {
    'trained_at': datetime.now().isoformat(),
    'num_classes': len(LABEL_MAP),
    'num_epochs': NUM_EPOCHS,
    'best_epoch': best_epoch,
    'device': str(DEVICE),
    'datasets': dataset_stats,
    'total_samples': len(all_samples),
    'train_samples': len(train_samples),
    'val_samples': len(val_samples),
    'test_samples': len(test_samples),
    'metrics': {
        'test_accuracy': test_metrics['accuracy'],
        'avg_f1': test_metrics['macro_f1'],
        'avg_precision': test_metrics['macro_precision'],
        'avg_recall': test_metrics['macro_recall'],
        'per_class': test_metrics['per_class'],
        'confusion_matrix': test_metrics['confusion_matrix'],
        'best_val_accuracy': best_val_acc
    }
}

with open(ARTIFACTS_DIR / 'model_meta.json', 'w') as f:
    json.dump(meta, f, indent=2)

logger.info(f'All artifacts saved to {ARTIFACTS_DIR}')
logger.info('Training pipeline completed successfully!')