# In [None]:
#!/usr/bin/env python
# coding: utf-8

"""
LeSITA: Clean Paper-Compliant Implementation
===========================================

Based on Papers:
- Paper 3: Convolutional Sparse Coding with Side Information via Deep Unfolding
- Paper 4: Interpretable deep learning for multimodal super-resolution of medical images

Features:
- ✅ Complete ACSC-LMCSC architecture (Paper 3 & 4)
- ✅ Original vs Co-evolving forward modes
- ✅ Paper-compliant initialization and training
- ✅ Clean menu system for easy usage
- ✅ Comprehensive visualization and logging
- ✅ No code duplication, clean structure

Author: Based on LeSITA Papers - Clean Implementation
"""

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import h5py
import os
import json
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from datetime import datetime
import time

# Style for better plots
plt.style.use('default')

# =============================================================================
# DATA LOADING
# =============================================================================

class MRIDataset(Dataset):
    """In-memory dataset of (LR T2W, HR T2W, HR T1W guidance) image triplets.

    The three HDF5 datasets named by ``input_key``, ``target_key`` and
    ``si_key`` are read completely into memory when the object is created.
    Each item is returned as three ``float32`` tensors scaled by ``1/255``
    with a leading channel axis added.

    Args:
        h5_file: Path of the HDF5 file to read.
        input_key: HDF5 key of the low-resolution inputs.
        target_key: HDF5 key of the high-resolution targets.
        si_key: HDF5 key of the side-information (guidance) images.
        augment: When true, random flips/rotations are applied identically
            to all three images of a sample.

    Raises:
        ValueError: If the three datasets do not have the same length.
    """

    def __init__(self, h5_file, input_key='T2W/LRINPUT', target_key='T2W/TARGET',
                 si_key='T1W/TARGET', augment=False):
        self.h5_file = h5_file
        with h5py.File(h5_file, 'r') as file:
            self.inputs = file[input_key][:]      # LR T2W
            self.targets = file[target_key][:]    # HR T2W
            self.si = file[si_key][:]             # HR T1W (guidance)

        # Explicit raise instead of ``assert``: asserts vanish under ``python -O``.
        if not (len(self.inputs) == len(self.targets) == len(self.si)):
            raise ValueError(
                "Mismatch in dataset lengths. "
                f"(inputs={len(self.inputs)}, targets={len(self.targets)}, "
                f"si={len(self.si)})"
            )
        self.augment = augment

    def __len__(self):
        """Return the number of samples."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return ``(input_img, target_img, si_img)`` for sample ``idx``.

        Values are divided by 255.0 (assumes 8-bit intensity data -- confirm
        against the dataset) and a channel axis is inserted at position 0.
        """
        images = [
            torch.from_numpy(array[idx].astype(np.float32) / 255.0).unsqueeze(0)
            for array in (self.inputs, self.targets, self.si)
        ]

        if self.augment:
            images = self._augment(images)

        input_img, target_img, si_img = images
        return input_img, target_img, si_img

    @staticmethod
    def _augment(images):
        """Apply the same random flips / rotation to every tensor in ``images``."""
        if torch.rand(1) > 0.5:  # Flip along the last axis (horizontal)
            images = [torch.flip(img, [2]) for img in images]
        if torch.rand(1) > 0.5:  # Flip along axis 1 (vertical)
            images = [torch.flip(img, [1]) for img in images]
        if torch.rand(1) > 0.5:  # Rotation by k quarter turns
            k = torch.randint(0, 4, (1,)).item()
            # An odd number of quarter turns swaps the two spatial axes; on
            # non-square images that would give samples of different shapes
            # that can no longer be stacked into a batch, so fall back to the
            # nearest even turn (0 or 180 degrees) in that case.
            if any(img.shape[1] != img.shape[2] for img in images):
                k -= k % 2
            images = [torch.rot90(img, k, [1, 2]) for img in images]
        return images

def get_dataloader(h5_file, batch_size=32, shuffle=True, augment=False):
    """Build a ``DataLoader`` over an ``MRIDataset`` read from ``h5_file``.

    Args:
        h5_file: Path of the HDF5 file handed to ``MRIDataset``.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the samples.
        augment: Forwarded to ``MRIDataset`` to toggle augmentation.

    Returns:
        The configured ``DataLoader``.
    """
    loader_options = {'batch_size': batch_size, 'shuffle': shuffle}
    return DataLoader(MRIDataset(h5_file, augment=augment), **loader_options)

# =============================================================================
# CORE ACTIVATION FUNCTIONS (Paper 3)
# =============================================================================

class ShLU(nn.Module):
    """
    Learnable soft-thresholding activation (φγ in the papers).

    Computes ``sign(x) * max(0, |x| - γ)`` element-wise, where γ is a single
    scalar ``nn.Parameter`` stored as ``threshold``.

    From Paper 4: "The initial value of the parameters γ, μ is set to 0.1"
    """
    def __init__(self, threshold=0.1):
        super().__init__()
        initial_gamma = torch.tensor(threshold, dtype=torch.float32)
        self.threshold = nn.Parameter(initial_gamma)

    def forward(self, x):
        """Shrink the magnitude of ``x`` by the threshold, keeping its sign."""
        shrunk_magnitude = (x.abs() - self.threshold).clamp(min=0.0)
        return x.sign() * shrunk_magnitude

class LeSITA(nn.Module):
    """
    LeSITA proximal operator (ξμ) from Paper 3, Equations (13) and (14).

    A piecewise-linear, element-wise activation of ``u`` whose breakpoints
    depend on the side information ``z`` and on the learnable scalar ``mu``.
    The pieces differ for ``z >= 0`` (Eq. 13) and ``z < 0`` (Eq. 14).
    """
    def __init__(self, mu=0.1):
        super().__init__()
        self.mu = nn.Parameter(torch.tensor(mu, dtype=torch.float32))

    def forward(self, u, z):
        """
        u: input tensor (vi in equations)
        z: side information tensor (zi in equations)

        Elements matched by none of the pieces stay at zero.
        """
        shift = 2 * self.mu
        side_nonneg = z >= 0
        side_neg = z < 0
        nothing = torch.zeros_like(u)
        raised = u + shift
        lowered = u - shift

        # (mask, value) pairs applied in sequence; a later match overrides an
        # earlier one, exactly as a chain of torch.where calls would.
        pieces = (
            # Equation (13): zi >= 0
            (side_nonneg & (u < -shift), raised),
            (side_nonneg & (-shift <= u) & (u <= 0), nothing),
            (side_nonneg & (0 < u) & (u < z), u),
            (side_nonneg & (z <= u) & (u <= z + shift), z),
            (side_nonneg & (u > z + shift), lowered),
            # Equation (14): zi < 0
            (side_neg & (u < z - shift), raised),
            (side_neg & (z - shift <= u) & (u <= z), z),
            (side_neg & (z < u) & (u < 0), u),
            (side_neg & (0 <= u) & (u <= shift), nothing),
            (side_neg & (u > shift), lowered),
        )

        result = nothing
        for mask, value in pieces:
            result = torch.where(mask, value, result)
        return result

# =============================================================================
# CORE NETWORK MODULES
# =============================================================================

class ACSC(nn.Module):
    """
    ACSC Model from Papers
    Equation (8): Z^t = φγ(Z^{t-1} - T * V * Z^{t-1} + G * Ω)

    ``G`` maps the single-channel guidance image to ``num_filters`` feature
    maps; every unfolding stage ``t`` owns its own ``T_layers[t]`` and
    ``V_layers[t]`` convolutions, while the soft-threshold ``shlu`` is shared
    by all stages.

    Args:
        num_stages: Number of unfolding stages.
        num_filters: Number of feature maps of the sparse code.
        kernel_size: Side length of every convolution kernel (padding is
            ``kernel_size // 2``).
        threshold: Initial value of the learnable soft-threshold.
    """
    def __init__(self, num_stages=3, num_filters=85, kernel_size=7, threshold=0.1):
        super(ACSC, self).__init__()
        self.num_stages = num_stages

        self.shlu = ShLU(threshold)
        self.G = nn.Conv2d(1, num_filters, kernel_size, padding=kernel_size//2)

        self.T_layers = nn.ModuleList([
            nn.Conv2d(num_filters, num_filters, kernel_size, padding=kernel_size//2)
            for _ in range(num_stages)
        ])

        self.V_layers = nn.ModuleList([
            nn.Conv2d(num_filters, num_filters, kernel_size, padding=kernel_size//2)
            for _ in range(num_stages)
        ])

    def _stage(self, Z, G_omega, stage_idx):
        """Apply one unfolding stage given the precomputed ``G(omega)``."""
        conv_V = self.V_layers[stage_idx](Z)
        conv_TV = self.T_layers[stage_idx](conv_V)
        return self.shlu(Z - conv_TV + G_omega)

    def forward(self, omega):
        """Run all stages starting from a zero code and return the final Z."""
        batch_size, _, height, width = omega.shape
        Z = torch.zeros(batch_size, self.G.out_channels, height, width,
                        device=omega.device, dtype=omega.dtype)

        # G(omega) does not change between stages, so compute it once.
        G_omega = self.G(omega)

        for t in range(self.num_stages):
            Z = self._stage(Z, G_omega, t)

        return Z

    def forward_single_stage(self, Z, omega, stage_idx):
        """Advance ``Z`` by the single stage ``stage_idx`` (co-evolving mode).

        ``G(omega)`` is recomputed on every call rather than cached on the
        module. A cache refreshed only when ``stage_idx == 0`` would silently
        reuse the projection of an earlier ``omega`` whenever a later stage is
        called with a different input, and it would keep a graph-attached
        tensor alive on the module between calls. The recomputation is a
        single 1 -> num_filters convolution, small next to the two
        num_filters -> num_filters convolutions of the stage itself.
        """
        return self._stage(Z, self.G(omega), stage_idx)

class LMCSC(nn.Module):
    """
    LMCSC Model from Papers
    Equation (7): U^t = ξμ(U^{t-1} - Q * R * U^{t-1} + P * Y ; Z)

    ``P`` maps the single-channel input image to ``num_filters`` feature
    maps; every unfolding stage ``t`` owns its own ``Q_layers[t]`` and
    ``R_layers[t]`` convolutions, while the ``lesita`` activation (which
    receives the side-information code ``Z``) is shared by all stages.

    Args:
        num_stages: Number of unfolding stages.
        num_filters: Number of feature maps of the sparse code.
        kernel_size: Side length of every convolution kernel (padding is
            ``kernel_size // 2``).
        mu: Initial value of the learnable LeSITA parameter.
    """
    def __init__(self, num_stages=3, num_filters=85, kernel_size=7, mu=0.1):
        super(LMCSC, self).__init__()
        self.num_stages = num_stages

        self.lesita = LeSITA(mu)
        self.P = nn.Conv2d(1, num_filters, kernel_size, padding=kernel_size//2)

        self.Q_layers = nn.ModuleList([
            nn.Conv2d(num_filters, num_filters, kernel_size, padding=kernel_size//2)
            for _ in range(num_stages)
        ])

        self.R_layers = nn.ModuleList([
            nn.Conv2d(num_filters, num_filters, kernel_size, padding=kernel_size//2)
            for _ in range(num_stages)
        ])

    def _stage(self, U, P_Y, Z, stage_idx):
        """Apply one unfolding stage given the precomputed ``P(Y)``."""
        conv_R = self.R_layers[stage_idx](U)
        conv_QR = self.Q_layers[stage_idx](conv_R)
        return self.lesita(U - conv_QR + P_Y, Z)

    def forward(self, Y, Z):
        """Run all stages starting from a zero code and return the final U."""
        batch_size, _, height, width = Y.shape
        U = torch.zeros(batch_size, self.P.out_channels, height, width,
                        device=Y.device, dtype=Y.dtype)

        # P(Y) does not change between stages, so compute it once.
        P_Y = self.P(Y)

        for t in range(self.num_stages):
            U = self._stage(U, P_Y, Z, t)

        return U

    def forward_single_stage(self, U, Y, Z, stage_idx):
        """Advance ``U`` by the single stage ``stage_idx`` (co-evolving mode).

        ``P(Y)`` is recomputed on every call rather than cached on the
        module. A cache refreshed only when ``stage_idx == 0`` would silently
        reuse the projection of an earlier ``Y`` whenever a later stage is
        called with a different input, and it would keep a graph-attached
        tensor alive on the module between calls. The recomputation is a
        single 1 -> num_filters convolution, small next to the two
        num_filters -> num_filters convolutions of the stage itself.
        """
        return self._stage(U, self.P(Y), Z, stage_idx)

class ReconstructionModule(nn.Module):
    """
    Reconstruction Module from Papers

    A single convolution (``dictionary``) that maps ``num_filters`` sparse
    feature maps to a one-channel image. Before every forward pass the
    kernel tensor is rescaled in place so that the weights belonging to each
    output channel have unit L2 norm.
    """
    def __init__(self, num_filters=85, kernel_size=7):
        super().__init__()
        self.dictionary = nn.Conv2d(
            num_filters, 1, kernel_size, padding=kernel_size // 2
        )

    def _renormalize_dictionary(self):
        """Rescale the kernels of each output channel to unit L2 norm.

        Runs without gradient tracking and overwrites ``weight.data``, so the
        rescaling itself is not part of the autograd graph. The norm is
        clamped from below at 1e-8 to avoid dividing by zero.
        """
        with torch.no_grad():
            kernels = self.dictionary.weight
            flat = kernels.view(kernels.size(0), -1)
            scale = flat.norm(dim=1, keepdim=True).view(-1, 1, 1, 1)
            scale = scale.clamp(min=1e-8)
            self.dictionary.weight.data = kernels / scale

    def forward(self, u):
        """u: sparse coefficients from LMCSC; returns the reconstructed image."""
        self._renormalize_dictionary()
        return self.dictionary(u)

# =============================================================================
# MAIN NETWORK
# =============================================================================

class LMCSCNetwork(nn.Module):
    """
    Complete LMCSC Network from Papers
    Supports both Original and Co-evolving forward modes

    The network chains three sub-modules: ``acsc`` (codes the guidance
    image), ``lmcsc`` (codes the input image using the guidance code as side
    information) and ``reconstruction`` (maps the code back to an image).

    Args:
        num_stages_acsc: Number of unfolding stages in the ACSC branch.
        num_stages_lmcsc: Number of unfolding stages in the LMCSC branch.
        num_filters: Number of feature maps of both sparse codes.
        kernel_size: Kernel side length of every convolution.
        threshold: Initial soft-threshold of the ACSC branch.
        mu: Initial LeSITA parameter of the LMCSC branch.
    """
    def __init__(self,
                 num_stages_acsc=3,
                 num_stages_lmcsc=3,
                 num_filters=85,
                 kernel_size=7,
                 threshold=0.1,
                 mu=0.1):
        super(LMCSCNetwork, self).__init__()

        # Paper 4 parameters
        self.acsc = ACSC(
            num_stages=num_stages_acsc,
            num_filters=num_filters,
            kernel_size=kernel_size,
            threshold=threshold
        )

        self.lmcsc = LMCSC(
            num_stages=num_stages_lmcsc,
            num_filters=num_filters,
            kernel_size=kernel_size,
            mu=mu
        )

        self.reconstruction = ReconstructionModule(
            num_filters=num_filters,
            kernel_size=kernel_size
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialization according to Paper 4.

        Every convolution gets N(0, 0.01) weights and zero bias. The
        activation parameters (threshold / mu) are left at the values the
        sub-modules were constructed with, so the ``threshold`` and ``mu``
        constructor arguments are honoured (both default to 0.1).
        """
        print("🔧 Initializing LMCSC Network according to Paper 4...")

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.normal_(module.weight, mean=0.0, std=0.01)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

        print("✅ Initialization completed according to Paper 4!")

    def forward(self, y, omega, mode='original'):
        """
        Forward pass with mode selection

        Args:
            y: LR target image (T2W)
            omega: HR guidance image (T1W)
            mode: 'original' or 'coevolving'

        Returns:
            Tuple ``(out, u, z)``: reconstructed image, LMCSC code and ACSC
            code.

        Raises:
            ValueError: If ``mode`` is neither 'original' nor 'coevolving'.
        """
        if mode == 'original':
            return self._forward_original(y, omega)
        elif mode == 'coevolving':
            return self._forward_coevolving(y, omega)
        else:
            raise ValueError(f"Unknown mode: {mode}. Use 'original' or 'coevolving'")

    def _forward_original(self, y, omega):
        """Original forward pass (sequential ACSC then LMCSC)"""
        z = self.acsc(omega)
        u = self.lmcsc(y, z)
        out = self.reconstruction(u)
        return out, u, z

    def _forward_coevolving(self, y, omega):
        """Co-evolving forward pass (interleaved ACSC and LMCSC stages).

        At each step the guidance code ``Z`` is advanced by one ACSC stage
        and the image code ``U`` by one LMCSC stage that sees the updated
        ``Z``. Code sizes and stage counts are taken from the sub-modules
        instead of being hard-coded, so non-default configurations work. If
        the two branches have different stage counts, the shorter one simply
        stops updating once its stages are used up.
        """
        # Initialize sparse codes with the channel count of each branch.
        Z = torch.zeros(omega.shape[0], self.acsc.G.out_channels, *omega.shape[2:],
                        device=omega.device, dtype=omega.dtype)
        U = torch.zeros(y.shape[0], self.lmcsc.P.out_channels, *y.shape[2:],
                        device=y.device, dtype=y.dtype)

        acsc_stages = self.acsc.num_stages
        lmcsc_stages = self.lmcsc.num_stages

        # Co-evolving iterations
        for stage in range(max(acsc_stages, lmcsc_stages)):
            if stage < acsc_stages:
                Z = self.acsc.forward_single_stage(Z, omega, stage)
            if stage < lmcsc_stages:
                U = self.lmcsc.forward_single_stage(U, y, Z, stage)

        out = self.reconstruction(U)
        return out, U, Z

# =============================================================================
# METRICS
# =============================================================================

def psnr(y_true, y_pred, max_val=1.0):
    """PSNR metric as used in papers.

    Computes ``20 * log10(max_val / sqrt(MSE))`` with the mean squared error
    taken over every element of the two tensors.

    Args:
        y_true: Reference tensor.
        y_pred: Predicted tensor, broadcastable against ``y_true``.
        max_val: Peak signal value.

    Returns:
        A 0-dim tensor holding the PSNR in dB. For identical inputs the
        result is ``inf`` -- still a tensor, so callers can use ``.item()``
        and ``torch.isfinite`` uniformly.
    """
    mse = torch.mean((y_true - y_pred) ** 2)
    # No special case for mse == 0: the division yields inf and log10(inf)
    # is inf, which keeps the return type a tensor in every case.
    return 20 * torch.log10(max_val / torch.sqrt(mse))

# =============================================================================
# TRAINING SYSTEM
# =============================================================================

class LMCSCTrainer:
    """Training setup according to Paper 4.

    Wraps a model together with an MSE loss and an Adam optimizer
    (lr=1e-4) and provides per-epoch training, validation and a full
    training loop that checkpoints the best model by validation PSNR.

    Args:
        model: Network called as ``model(x_lr, x_si, mode=...)`` and
            returning ``(prediction, sparse_u, sparse_z)``.
        device: Device the model and every batch are moved to.
        mode: Forward mode passed to the model ('original' or 'coevolving').
    """

    def __init__(self, model, device, mode='original'):
        self.model = model.to(device)
        self.device = device
        self.mode = mode  # 'original' or 'coevolving'

        # Paper 4 training setup
        self.criterion = nn.MSELoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-4)

        print(f"🚀 LMCSCTrainer initialized")
        print(f"   Mode: {mode}")
        print(f"   Loss: MSE | Optimizer: Adam (lr=1e-4)")
        print(f"   Device: {device}")

    @staticmethod
    def _finite_psnr(target, prediction):
        """Return the batch PSNR as a float, or 0.0 when it is not finite.

        ``float(...)`` accepts both a 0-dim tensor and a plain Python float,
        so this works whichever of the two ``psnr`` returns.
        """
        value = float(psnr(target, prediction))
        return value if np.isfinite(value) else 0.0

    def _model_config(self):
        """Describe the wrapped model's architecture for the checkpoint.

        Stage counts, filter count and kernel size are read from the model
        when it exposes ``acsc`` / ``lmcsc`` sub-modules; otherwise the
        defaults (3, 3, 85, 7) are used. ``threshold`` and ``mu`` are
        recorded as 0.1: they only seed a freshly built network, and the
        learned values are restored from the saved state dict.
        """
        config = {
            'num_stages_acsc': 3,
            'num_stages_lmcsc': 3,
            'num_filters': 85,
            'kernel_size': 7,
            'threshold': 0.1,
            'mu': 0.1
        }
        acsc = getattr(self.model, 'acsc', None)
        lmcsc = getattr(self.model, 'lmcsc', None)
        if hasattr(acsc, 'num_stages'):
            config['num_stages_acsc'] = acsc.num_stages
        if hasattr(lmcsc, 'num_stages'):
            config['num_stages_lmcsc'] = lmcsc.num_stages
        first_conv = getattr(acsc, 'G', None)
        if isinstance(first_conv, nn.Conv2d):
            config['num_filters'] = first_conv.out_channels
            config['kernel_size'] = first_conv.kernel_size[0]
        return config

    def train_epoch(self, dataloader):
        """Train for one epoch.

        Returns:
            ``(avg_loss, avg_psnr, epoch_time)`` where both averages are
            taken over batches and ``epoch_time`` is in seconds.
        """
        start_time = time.time()

        self.model.train()
        total_loss = 0.0
        total_psnr = 0.0

        for x_lr, x_hr, x_si in dataloader:
            x_lr = x_lr.to(self.device)
            x_hr = x_hr.to(self.device)
            x_si = x_si.to(self.device)

            self.optimizer.zero_grad()

            # Forward pass with selected mode
            pred_hr, _, _ = self.model(x_lr, x_si, mode=self.mode)

            loss = self.criterion(pred_hr, x_hr)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()

            # The metric is for reporting only; keep it out of the graph.
            with torch.no_grad():
                total_psnr += self._finite_psnr(x_hr, pred_hr)

        epoch_time = time.time() - start_time
        avg_loss = total_loss / len(dataloader)
        avg_psnr = total_psnr / len(dataloader)

        return avg_loss, avg_psnr, epoch_time

    def validate(self, dataloader):
        """Evaluate on ``dataloader`` without updating the model.

        Returns:
            ``(avg_loss, avg_psnr)`` averaged over batches.
        """
        self.model.eval()
        total_loss = 0.0
        total_psnr = 0.0

        with torch.no_grad():
            for x_lr, x_hr, x_si in dataloader:
                x_lr = x_lr.to(self.device)
                x_hr = x_hr.to(self.device)
                x_si = x_si.to(self.device)

                pred_hr, _, _ = self.model(x_lr, x_si, mode=self.mode)

                total_loss += self.criterion(pred_hr, x_hr).item()
                total_psnr += self._finite_psnr(x_hr, pred_hr)

        avg_loss = total_loss / len(dataloader)
        avg_psnr = total_psnr / len(dataloader)

        return avg_loss, avg_psnr

    def train(self, train_loader, val_loader, epochs=50, save_path="lesita_best.pt"):
        """Complete training loop.

        After every epoch the model is validated; whenever the validation
        PSNR improves, a checkpoint (model and optimizer state, best PSNR,
        mode and architecture config) is written to ``save_path``.

        Returns:
            ``(best_psnr, training_log)``. ``best_psnr`` is 0.0 when no
            epoch was run; ``training_log`` holds one dict per epoch.
        """
        print(f"\n🚀 Starting LeSITA Training ({self.mode.upper()} mode)")
        print(f"📊 Configuration: {epochs} epochs | Mode: {self.mode}")
        print("-" * 60)

        # None (rather than 0.0) so the first epoch always yields a
        # checkpoint, even when the validation PSNR is zero or negative.
        best_psnr = None
        training_log = []

        for epoch in range(epochs):
            # Training
            train_loss, train_psnr, epoch_time = self.train_epoch(train_loader)

            # Validation
            val_loss, val_psnr = self.validate(val_loader)

            # Logging
            is_best = best_psnr is None or val_psnr > best_psnr
            if is_best:
                best_psnr = val_psnr
                # Save best model
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'best_psnr': best_psnr,
                    'mode': self.mode,
                    'config': self._model_config()
                }, save_path)

            # Store training log
            training_log.append({
                'epoch': epoch + 1,
                'train_loss': train_loss,
                'train_psnr': train_psnr,
                'val_loss': val_loss,
                'val_psnr': val_psnr,
                'epoch_time': epoch_time,
                'is_best': is_best
            })

            # Print progress
            mode_icon = "🔄" if self.mode == "coevolving" else "📘"
            print(f"{mode_icon} Epoch [{epoch+1:3d}/{epochs}] | "
                  f"Train: Loss={train_loss:.6f}, PSNR={train_psnr:.2f}dB | "
                  f"Val: Loss={val_loss:.6f}, PSNR={val_psnr:.2f}dB | "
                  f"Time: {epoch_time:.1f}s"
                  + (f" 🌟 NEW BEST!" if is_best else ""))

        print(f"\n✅ Training completed!")
        if best_psnr is None:
            # No epoch ran, so nothing was written to save_path.
            best_psnr = 0.0
            print("⚠️ No epochs were run; no model was saved.")
        else:
            print(f"🏆 Best validation PSNR: {best_psnr:.2f}dB")
            print(f"💾 Model saved: {save_path}")

        return best_psnr, training_log

# =============================================================================
# TESTING & EVALUATION
# =============================================================================

def test_model(model_path, test_loader, device):
    """Evaluate a saved checkpoint on ``test_loader``.

    The network is rebuilt from the checkpoint's ``config`` entry (falling
    back to the default architecture), its weights are restored, and every
    test image is scored individually with PSNR.

    Args:
        model_path: Path of a checkpoint holding at least
            ``model_state_dict``; ``config``, ``mode`` and ``best_psnr`` are
            optional.
        test_loader: Iterable of ``(x_lr, x_hr, x_si)`` batches.
        device: Device used for the model and the batches.

    Returns:
        ``(results, avg_psnr)``: ``results`` holds CPU copies of the first
        five samples (inputs, prediction, sparse codes and PSNR) for
        visualization; ``avg_psnr`` is the mean per-image PSNR, or ``nan``
        when the loader yields no samples.
    """
    print("🧪 Starting Model Testing...")

    # Load model
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(model_path, map_location=device)
    config = checkpoint.get('config', {
        'num_stages_acsc': 3, 'num_stages_lmcsc': 3, 'num_filters': 85,
        'kernel_size': 7, 'threshold': 0.1, 'mu': 0.1
    })

    model = LMCSCNetwork(**config).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    mode = checkpoint.get('mode', 'original')
    best_psnr = checkpoint.get('best_psnr', 0.0)

    print(f"📂 Loaded model: {model_path}")
    print(f"🎯 Mode: {mode} | Best training PSNR: {best_psnr:.2f}dB")
    print("-" * 50)

    # Testing
    total_psnr = 0.0
    total_samples = 0
    results = []

    with torch.no_grad():
        for batch_idx, (x_lr, x_hr, x_si) in enumerate(test_loader):
            x_lr = x_lr.to(device)
            x_hr = x_hr.to(device)
            x_si = x_si.to(device)

            # Inference with the mode the checkpoint was trained in
            pred_hr, sparse_u, sparse_z = model(x_lr, x_si, mode=mode)

            # Calculate PSNR for each image
            for i in range(x_lr.shape[0]):
                # float() accepts a 0-dim tensor as well as a plain float,
                # so a perfect reconstruction (infinite PSNR) cannot crash.
                img_psnr = float(psnr(x_hr[i:i+1], pred_hr[i:i+1]))
                total_psnr += img_psnr
                total_samples += 1

                # Store first 5 for visualization
                if len(results) < 5:
                    results.append({
                        'lr': x_lr[i].cpu(),
                        'hr_gt': x_hr[i].cpu(),
                        'hr_pred': pred_hr[i].cpu(),
                        'guidance': x_si[i].cpu(),
                        'sparse_u': sparse_u[i].cpu(),
                        'sparse_z': sparse_z[i].cpu(),
                        'psnr': img_psnr
                    })

            if batch_idx % 5 == 0:
                print(f"   Processed batch {batch_idx+1}/{len(test_loader)}")

    if total_samples == 0:
        # Nothing to average; report that instead of dividing by zero.
        print("⚠️ Test loader produced no samples.")
        return results, float('nan')

    avg_psnr = total_psnr / total_samples
    print(f"\n📊 RESULTS:")
    print(f"   Average PSNR: {avg_psnr:.2f}dB")
    print(f"   Total samples: {total_samples}")

    return results, avg_psnr

# =============================================================================
# VISUALIZATION
# =============================================================================

def visualize_results(results, save_path="lesita_results.png"):
    """Plot one row of six panels per result and save the figure.

    Panels per row: LR input, guidance, ground truth, prediction (with
    PSNR), absolute error map, and the absolute difference between the
    channel-averaged sparse codes U and Z.

    Args:
        results: Sequence of dicts with the keys ``lr``, ``hr_gt``,
            ``hr_pred``, ``guidance``, ``sparse_u``, ``sparse_z`` (tensors)
            and ``psnr`` (float). When empty, nothing is plotted or saved.
        save_path: File the figure is written to.
    """
    if not results:
        # plt.subplots cannot build a grid with zero rows.
        print("⚠️ No results to visualize.")
        return

    print("🎨 Creating visualization...")

    n_rows = len(results)
    # squeeze=False keeps ``axes`` two-dimensional even for a single row.
    fig, axes = plt.subplots(n_rows, 6, figsize=(20, 4*n_rows), squeeze=False)

    for row_axes, result in zip(axes, results):
        lr = result['lr'][0].numpy()
        hr_gt = result['hr_gt'][0].numpy()
        hr_pred = result['hr_pred'][0].detach().numpy()
        guidance = result['guidance'][0].numpy()
        sparse_u = result['sparse_u'].mean(dim=0).numpy()
        sparse_z = result['sparse_z'].mean(dim=0).numpy()

        error_map = np.abs(hr_gt - hr_pred)
        sparse_diff = np.abs(sparse_u - sparse_z)

        # The first four panels share the same grayscale rendering.
        gray_panels = [
            (lr, 'LR Input (T2W)'),
            (guidance, 'Guidance (T1W)'),
            (hr_gt, 'Ground Truth'),
            (hr_pred, f'LeSITA Output\n(PSNR: {result["psnr"]:.2f}dB)'),
        ]
        for ax, (image, title) in zip(row_axes, gray_panels):
            ax.imshow(image, cmap='gray', vmin=0, vmax=1)
            ax.set_title(title, fontsize=12)
            ax.axis('off')

        row_axes[4].imshow(error_map, cmap='hot', vmin=0, vmax=0.1)
        row_axes[4].set_title('Error Map', fontsize=12)
        row_axes[4].axis('off')

        im = row_axes[5].imshow(sparse_diff, cmap='viridis')
        row_axes[5].set_title('|Sparse U - Z|', fontsize=12)
        row_axes[5].axis('off')
        plt.colorbar(im, ax=row_axes[5], fraction=0.046, pad=0.04)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"💾 Visualization saved: {save_path}")

# =============================================================================
# MAIN PIPELINE FUNCTIONS
# =============================================================================

def setup_paths(train_path=None, test_path=None):
    """Resolve and validate the training / test data file paths.

    Args:
        train_path: Path of the training file; ``None`` selects the built-in
            default location.
        test_path: Path of the test file; ``None`` selects the built-in
            default location.

    Returns:
        ``(train_path, test_path)`` when both paths exist, otherwise
        ``(None, None)`` after printing which file is missing.
    """
    # Default paths - modify these for your setup
    if train_path is None:
        train_path = r"D:\Diploma Thesis\MICCAI DATASETS\LRx4_MStrain_flair_t1w_t2w_44_unnormalized.h5py"
    if test_path is None:
        test_path = r"D:\Diploma Thesis\MICCAI DATASETS\LRx4_MStest_flair_t1w_t2w_unnormalized.h5py"

    labelled_paths = (("Training", train_path), ("Test", test_path))

    # Validate paths; stop at the first missing file.
    for label, path in labelled_paths:
        if not os.path.exists(path):
            print(f"❌ {label} file not found: {path}")
            return None, None

    for label, path in labelled_paths:
        print(f"✅ {label} file: {os.path.basename(path)} ({os.path.getsize(path)/1e6:.1f}MB)")

    return train_path, test_path

def create_data_loaders(train_path, test_path, batch_size=32):
    """Create training and testing data loaders"""
    print("📊 Creating data loaders...")
    
    train_loader = get_dataloader(train_path, batch_size=batch_size, shuffle=True, augment