In [None]:
# Quantum MRI Reconstruction - Diagnostic Pipeline
# This is a standalone script for testing the quantum MRI reconstruction pipeline

# Install required packages
!pip install pennylane torch nibabel matplotlib scikit-image tqdm torchvision -q

# Import necessary libraries
import os
import torch
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, random_split
from skimage.metrics import structural_similarity as ssim, peak_signal_noise_ratio as psnr
import pennylane as qml
from tqdm import tqdm
import time
from pathlib import Path
from torchvision.transforms import functional as TF
import torch.nn.functional as F


# Decorator that reports how long each call to a function takes.
def timing_decorator(func):
    """
    Decorator that prints how long each call to ``func`` takes.

    The wrapped function's return value is passed through unchanged. If the
    call raises, the exception propagates and no timing line is printed.
    ``functools.wraps`` copies ``__name__``/``__doc__`` (etc.) onto the
    wrapper so the decorated function still introspects like the original.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic, so the measured duration cannot go
        # negative or jump if the wall clock is adjusted mid-call.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        execution_time = time.perf_counter() - start_time
        print(f"{func.__name__} executed in {execution_time:.2f} seconds")
        return result

    return wrapper

# Helper: convert NumPy values into native Python types so results can be JSON-serialized.
def convert_to_json_serializable(obj):
    """
    Recursively convert NumPy values into native Python types for ``json``.

    - NumPy bool / integer / floating scalars -> ``bool`` / ``int`` / ``float``
    - ``np.ndarray`` -> nested Python lists via ``tolist()``
    - ``dict`` -> new dict with converted values; NumPy-scalar keys are turned
      into Python scalars because ``json`` rejects e.g. ``np.int64`` keys
    - ``list`` / ``tuple`` -> list of converted items

    Anything else is returned unchanged.
    """
    # np.bool_ is not a subclass of np.integer, so it needs its own branch.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {
            (key.item() if isinstance(key, np.generic) else key):
                convert_to_json_serializable(value)
            for key, value in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [convert_to_json_serializable(item) for item in obj]
    return obj

# Set random seeds for reproducibility
# Fixed seed shared by PyTorch and NumPy's legacy global RNG.
_SEED = 42
torch.manual_seed(_SEED)
np.random.seed(_SEED)

#=================================================================
# IMPROVED K-SPACE UNDERSAMPLING
#=================================================================

def get_sampling_prob(ny, mode='gaussian', scale=None):
    """
    Return a normalised sampling density over the ``ny`` phase-encode lines.

    Args:
        ny: Number of phase-encode lines.
        mode: 'gaussian' (bell centred at ny/2 with sigma = ny / scale,
            scale defaulting to 6), 'linear' (triangle peaking at ny/2 and
            reaching 0 at line 0), or 'uniform'.
        scale: Divisor of ``ny`` giving the gaussian sigma; ignored otherwise.

    Returns:
        1-D float array of length ``ny`` summing to 1.

    Raises:
        ValueError: for an unrecognised ``mode``.
    """
    half = ny / 2
    offsets = np.arange(ny) - half

    if mode == 'uniform':
        weights = np.ones(ny)
    elif mode == 'linear':
        weights = 1 - np.abs(offsets) / half
    elif mode == 'gaussian':
        sigma = ny / (6 if scale is None else scale)
        weights = np.exp(-(offsets ** 2) / (2 * sigma ** 2))
    else:
        raise ValueError(f"Unknown mode: {mode}")

    return weights / weights.sum()

def generate_variable_density_mask(ny, acceleration, center_fraction=0.08, mode='gaussian', scale=None, seed=None, visualize=False):
    """
    Generate a 1-D variable-density undersampling mask for k-space.

    A contiguous block of ``int(ny * center_fraction)`` lines around the
    centre is always sampled. The rest of the budget of
    ``int(ny // acceleration)`` lines (never less than the centre block) is
    drawn without replacement from the remaining lines, weighted by
    ``get_sampling_prob(ny, mode, scale)``. If fewer remaining lines have
    non-zero probability than are needed, all of those are taken and the
    shortfall is filled uniformly from the zero-probability lines. The total
    never exceeds ``ny``.

    Args:
        ny: Number of phase-encode lines.
        acceleration: Undersampling factor.
        center_fraction: Fraction of lines always sampled at the centre.
        mode: Sampling density passed to ``get_sampling_prob``.
        scale: Gaussian scale passed to ``get_sampling_prob``.
        seed: If given, lines are drawn from a private
            ``np.random.RandomState(seed)``. This gives the same draws as
            seeding NumPy's global generator, but leaves the global RNG state
            untouched. If None, NumPy's global RNG is used.
        visualize: If True, print a summary and plot the density and mask.

    Returns:
        Boolean array of shape ``(ny, 1)``.
    """
    # Private generator for a fixed seed, so the caller's global RNG stream is not reset.
    rng = np.random if seed is None else np.random.RandomState(seed)

    mask = np.zeros(ny, dtype=bool)

    # Always sample center of k-space
    center_lines = int(ny * center_fraction)
    # int() keeps the line budget integral even for a float acceleration.
    total_lines = int(ny // acceleration)

    if total_lines <= center_lines:
        total_lines = center_lines

    center_start = (ny - center_lines) // 2
    mask[center_start:center_start + center_lines] = True

    # Sample remaining lines based on probability distribution
    remaining = total_lines - center_lines
    prob = get_sampling_prob(ny, mode=mode, scale=scale)

    if remaining > 0:
        outside_center = np.setdiff1d(np.arange(ny), np.arange(center_start, center_start + center_lines))
        prob_outside = prob[outside_center]
        n_weighted = np.count_nonzero(prob_outside)

        if remaining <= n_weighted:
            chosen = rng.choice(outside_center, size=remaining, replace=False,
                                p=prob_outside / prob_outside.sum())
        else:
            # Weighted sampling without replacement cannot pick zero-probability
            # lines (e.g. 'linear' gives line 0 p=0). Take every weighted line,
            # then fill the shortfall uniformly from the zero-probability ones.
            zero_prob_lines = outside_center[prob_outside == 0]
            n_extra = min(remaining - n_weighted, zero_prob_lines.size)
            extra = rng.choice(zero_prob_lines, size=n_extra, replace=False)
            chosen = np.concatenate([outside_center[prob_outside > 0], extra])
        mask[chosen] = True

    final_sampled = np.sum(mask)
    if visualize:
        print(f"[INFO] Accel={acceleration} | Center={center_lines} | Random={remaining} | Total={final_sampled} / {ny}")
        plt.figure(figsize=(12, 3))
        plt.subplot(1, 2, 1)
        plt.plot(prob)
        plt.title(f"Sampling Probability ({mode})")
        plt.xlabel("Phase-Encode Line")
        plt.ylabel("Probability")

        plt.subplot(1, 2, 2)
        plt.imshow(mask.reshape(ny, 1).T, cmap='gray', aspect='auto')
        plt.title("Mask (1D)")
        plt.yticks([])
        plt.xlabel("Phase-Encode Line")
        plt.tight_layout()
        plt.show()

    return mask.reshape(ny, 1)

#=================================================================
# DATASET CLASS
#=================================================================

class DiagnosticNiftiDataset(Dataset):
    """
    Slice-wise dataset over 3-D NIfTI volumes for undersampled-MRI experiments.

    Each item is one 2-D slice taken along the volume's last axis, returned as
    ``(zero_filled, fully_sampled, mask)`` float tensors of shape (1, H, W).
    (H, W) is ``target_size`` when set, otherwise the native slice size.
    A new random k-space mask is drawn for every access.
    """
    def __init__(self, nifti_files, slice_range=None, acceleration_factor=4,
                center_fraction=0.08, mask_mode='gaussian', scale=8.5,
                target_size=(64, 64)):
        """
        Args:
            nifti_files: List of paths to NIfTI files
            slice_range: Tuple (min_slice, max_slice), half-open and clipped to
                each volume, or None to use all slices
            acceleration_factor: Undersampling factor (e.g. 2 or 4)
            center_fraction: Fraction of center k-space lines always sampled
            mask_mode: Sampling density ('gaussian', 'linear', 'uniform')
            scale: Scale parameter for the gaussian sampling density
            target_size: (H, W) to resize slices to, or None to keep native size

        Files that cannot be opened, or that are not 3-D, are reported and skipped.
        """
        self.nifti_files = nifti_files
        self.slice_range = slice_range
        self.acceleration_factor = acceleration_factor
        self.center_fraction = center_fraction
        self.mask_mode = mask_mode
        self.scale = scale
        self.target_size = target_size

        # (file_idx, slice_idx) for every usable slice across all files.
        self.slices = []

        print(f"Loading {len(nifti_files)} NIfTI files...")
        for file_idx, file_path in enumerate(nifti_files):
            try:
                # The image shape comes from the header, so the voxel data does
                # not need to be decoded just to count slices.
                shape = nib.load(file_path).shape
                if len(shape) != 3:
                    # __getitem__ expects a 2-D slice; reject other ranks here
                    # instead of failing later inside the DataLoader.
                    raise ValueError(f"expected a 3-D volume, got shape {shape}")

                # Slices are taken along the last axis (same axis __getitem__ indexes).
                n_slices = shape[-1]

                min_slice = 0 if slice_range is None else max(0, slice_range[0])
                max_slice = n_slices if slice_range is None else min(n_slices, slice_range[1])

                self.slices.extend((file_idx, slice_idx) for slice_idx in range(min_slice, max_slice))

                print(f"  File {file_path}: {max(0, max_slice - min_slice)} slices")

            except Exception as e:
                # Best-effort loading: report the problem and skip this file.
                print(f"Error loading {file_path}: {e}")

        print(f"Total slices: {len(self.slices)}")

    def _undersample_kspace(self, image):
        """
        Simulate undersampled acquisition of a 2-D image.

        The image is transformed to centred k-space (fft2 + fftshift). Whole
        rows are zeroed according to a variable-density 1-D mask along axis 0,
        and the result is transformed back.

        Returns:
            (zero_filled, mask_2d): magnitude of the zero-filled inverse FFT,
            and the boolean row mask broadcast to the image shape.
        """
        ny, nx = image.shape

        # seed=None draws a different mask per call for augmentation diversity.
        mask_1d = generate_variable_density_mask(
            ny,
            self.acceleration_factor,
            center_fraction=self.center_fraction,
            mode=self.mask_mode,
            scale=self.scale,
            seed=None
        )

        # Same row selection for every column.
        mask_2d = np.repeat(mask_1d, nx, axis=1)

        kspace = np.fft.fftshift(np.fft.fft2(image))
        masked_kspace = kspace * mask_2d
        zero_filled = np.abs(np.fft.ifft2(np.fft.ifftshift(masked_kspace)))

        return zero_filled, mask_2d

    def __len__(self):
        return len(self.slices)

    def __getitem__(self, idx):
        """
        Return ``(zero_filled, fully_sampled, mask)`` for slice ``idx``.

        The slice is min-max normalised to [0, 1] (constant slices become all
        zeros), undersampled at native resolution, then resized to
        ``target_size`` when set.
        """
        file_idx, slice_idx = self.slices[idx]
        file_path = self.nifti_files[file_idx]

        # Slice through the image's array proxy so only the requested slice is
        # materialised, rather than decoding the whole volume on every access.
        img = nib.load(file_path)
        slice_data = np.asarray(img.dataobj[..., slice_idx], dtype=np.float64)

        slice_min = slice_data.min()
        slice_max = slice_data.max()
        if slice_max > slice_min:
            slice_data = (slice_data - slice_min) / (slice_max - slice_min)
        else:
            # A constant slice carries no contrast; map it to zeros so every
            # returned image lies in [0, 1].
            slice_data = np.zeros_like(slice_data)

        # Undersample before resizing so the mask acts on the true k-space grid.
        zero_filled, mask = self._undersample_kspace(slice_data)

        if self.target_size is not None:
            slice_tensor = torch.from_numpy(slice_data).float().unsqueeze(0)
            zero_filled_tensor = torch.from_numpy(zero_filled).float().unsqueeze(0)
            mask_tensor = torch.from_numpy(mask.astype(np.float32)).unsqueeze(0)

            slice_tensor = TF.resize(slice_tensor, self.target_size,
                                   interpolation=TF.InterpolationMode.BILINEAR)
            zero_filled_tensor = TF.resize(zero_filled_tensor, self.target_size,
                                         interpolation=TF.InterpolationMode.BILINEAR)
            # Nearest keeps the resized mask binary.
            mask_tensor = TF.resize(mask_tensor, self.target_size,
                                  interpolation=TF.InterpolationMode.NEAREST)

            slice_data = slice_tensor.squeeze(0).numpy()
            zero_filled = zero_filled_tensor.squeeze(0).numpy()
            mask = mask_tensor.squeeze(0).numpy()

        fully_sampled = torch.from_numpy(slice_data).float().unsqueeze(0)
        zero_filled = torch.from_numpy(zero_filled).float().unsqueeze(0)
        mask = torch.from_numpy(mask.astype(np.float32)).unsqueeze(0)

        return zero_filled, fully_sampled, mask


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.1/56.1 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m68.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m70.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m37.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m46.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:

#=================================================================
# QUANTUM CIRCUIT
#=================================================================
class FourQubitQuantumCircuit:
    """
    Fixed (parameter-free) 4-qubit PennyLane circuit applied per 4-pixel patch.

    Pixel ``x_k`` is encoded on wire ``k`` with ``RY(x_k)`` then
    ``RZ(x_k**2)``. The wires are entangled with a CNOT ring
    (0->1->2->3->0), and the Pauli-Z expectation value of each wire is
    returned.
    """
    def __init__(self):
        self.dev = qml.device("default.qubit", wires=4)

        @qml.qnode(self.dev, interface="auto")  # 'auto' lets the QNode accept NumPy or torch inputs
        def circuit(x0, x1, x2, x3):
            qml.RY(x0, wires=0)
            qml.RZ(x0**2, wires=0)

            qml.RY(x1, wires=1)
            qml.RZ(x1**2, wires=1)

            qml.RY(x2, wires=2)
            qml.RZ(x2**2, wires=2)

            qml.RY(x3, wires=3)
            qml.RZ(x3**2, wires=3)

            qml.CNOT(wires=[0, 1])
            qml.CNOT(wires=[1, 2])
            qml.CNOT(wires=[2, 3])
            qml.CNOT(wires=[3, 0])  # full ring entanglement

            return [qml.expval(qml.PauliZ(i)) for i in range(4)]

        self.circuit = circuit

    def __call__(self, pixels_batch):
        """
        Evaluate the circuit independently on every patch.

        Args:
            pixels_batch: Array-like of shape (N, 4), or a single patch of shape (4,).

        Returns:
            np.ndarray of shape (N, 4) holding the four <Z> values per patch,
            or an empty (0, 4) array for an empty batch.

        Raises:
            ValueError: if the patches do not each have exactly 4 pixels.
        """
        pixels_batch = np.atleast_2d(pixels_batch)
        if pixels_batch.size == 0:
            return np.empty((0, 4))
        # Raise rather than assert so validation survives `python -O`.
        if pixels_batch.ndim != 2 or pixels_batch.shape[1] != 4:
            raise ValueError(f"Each input patch must have 4 pixels (got shape {pixels_batch.shape})")

        outputs = [self.circuit(*pix.tolist()) for pix in pixels_batch]
        return np.stack(outputs)

#=================================================================
# MODEL DEFINITIONS
#=================================================================

class FourQubitQuantumConv(torch.nn.Module):
    """
    Non-trainable "quantum convolution" over 2x2 patches of the first channel.

    Every 2x2 patch at a stride of ``self.stride`` is flattened row-major to
    4 values and passed through ``FourQubitQuantumCircuit``. The 4 expectation
    values become 4 output channels at the patch's (i // stride, j // stride)
    position. The output grid is ``(H - 1) // stride + 1`` per axis; grid
    positions with no full 2x2 patch (possible for odd sizes or stride 1)
    stay zero. Inputs are detached, so no gradient flows through this layer.
    """
    def __init__(self, stride=2):
        super().__init__()
        self.stride = stride
        self.circuit = FourQubitQuantumCircuit()

    def forward(self, x):
        batch_size, _, height, width = x.shape
        out_h = (height - 1) // self.stride + 1
        out_w = (width - 1) // self.stride + 1
        out = torch.zeros((batch_size, 4, out_h, out_w), device=x.device)

        # No complete 2x2 patch fits: nothing to evaluate.
        if height < 2 or width < 2:
            return out

        # (B, n_h, n_w, 2, 2): patch (i, j) covers rows i*s..i*s+1, cols j*s..j*s+1.
        patches = x[:, 0].detach().unfold(1, 2, self.stride).unfold(2, 2, self.stride)
        n_h, n_w = patches.shape[1], patches.shape[2]

        # Row-major flattening of each patch: [x00, x01, x10, x11].
        flat = patches.reshape(-1, 4).cpu().numpy()
        outputs = torch.as_tensor(np.asarray(self.circuit(flat)), dtype=out.dtype, device=x.device)

        # Place each patch's 4 values at its true spatial position.
        out[:, :, :n_h, :n_w] = outputs.reshape(batch_size, n_h, n_w, 4).permute(0, 3, 1, 2)
        return out

## MORE COMPLEX WITH ADDITIONAL CLASSICAL COMPLEXITY (QUANTUM KEPT SAME)
class DiagnosticQuantumUNet(torch.nn.Module):
    """
    U-Net whose first layer is the fixed 4-qubit quantum convolution.

    The quantum layer turns the 1-channel input into 4 channels on a grid
    about 1/stride the input size. Three conv+max-pool encoder levels and a
    bottleneck follow, then three transposed-conv decoder levels with skip
    connections. A final 1x1 conv and a stride-``stride`` transposed conv
    restore a single-channel image near the input resolution.
    """
    def __init__(self, stride=2):
        super().__init__()

        # Quantum convolution layer: 4 output channels (one <Z> per qubit).
        self.quantum_conv = FourQubitQuantumConv(stride=stride)

        # Encoder: each level halves the resolution, so the decoder's three
        # 2x upsamplings line up with the enc3/enc2/enc1 skip connections.
        self.enc1 = self.conv_block(4, 64)
        self.pool1 = torch.nn.MaxPool2d(2)

        self.enc2 = self.conv_block(64, 128)
        self.pool2 = torch.nn.MaxPool2d(2)

        self.enc3 = self.conv_block(128, 256)
        # Parameter-free, so checkpoints keep the same state_dict keys.
        self.pool3 = torch.nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = self.conv_block(256, 512)

        # Decoder path (mirroring encoder)
        self.up1 = torch.nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec1 = self.conv_block(512, 256)  # 256 + 256

        self.up2 = torch.nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = self.conv_block(256, 128)  # 128 + 128

        self.up3 = torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec3 = self.conv_block(128, 64)   # 64 + 64

        # Final 1x1 conv, then undo the quantum layer's downsampling.
        self.final = torch.nn.Conv2d(64, 1, kernel_size=1)
        self.final_up = torch.nn.ConvTranspose2d(1, 1, kernel_size=stride, stride=stride)

    def conv_block(self, in_ch, out_ch):
        """Two (3x3 conv -> BatchNorm -> ReLU) layers, preserving spatial size."""
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(out_ch),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(out_ch),
            torch.nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.quantum_conv(x)

        # Encoder
        enc1 = self.enc1(x)
        x = self.pool1(enc1)

        enc2 = self.enc2(x)
        x = self.pool2(enc2)

        enc3 = self.enc3(x)
        x = self.pool3(enc3)

        x = self.bottleneck(x)

        # Decoder 1
        x = self.up1(x)
        x = self._pad_and_concat(x, enc3)
        x = self.dec1(x)

        # Decoder 2
        x = self.up2(x)
        x = self._pad_and_concat(x, enc2)
        x = self.dec2(x)

        # Decoder 3
        x = self.up3(x)
        x = self._pad_and_concat(x, enc1)
        x = self.dec3(x)

        x = self.final(x)
        x = self.final_up(x)

        return x

    def _pad_and_concat(self, upsampled, skip):
        """Pad ``upsampled`` to ``skip``'s spatial size (for odd sizes) and concatenate on channels."""
        diff_y = skip.size()[2] - upsampled.size()[2]
        diff_x = skip.size()[3] - upsampled.size()[3]
        upsampled = F.pad(upsampled, [diff_x // 2, diff_x - diff_x // 2,
                                      diff_y // 2, diff_y - diff_y // 2])
        return torch.cat([upsampled, skip], dim=1)

class DiagnosticClassicalUNet(torch.nn.Module):
    """
    Classical baseline with the same U-Net body as the quantum model.

    A learned 2x2 conv with stride ``stride`` and 4 output channels stands in
    for the quantum patch layer. Three conv+max-pool encoder levels and a
    bottleneck follow, then three transposed-conv decoder levels with skip
    connections. A final 1x1 conv and a stride-``stride`` transposed conv
    restore a single-channel image.
    """
    def __init__(self, stride=4):
        super().__init__()

        # Classical stand-in for the quantum layer: 2x2 patches -> 4 channels.
        self.conv0 = torch.nn.Conv2d(1, 4, kernel_size=(2, 2), stride=stride)
        self.relu0 = torch.nn.ReLU(inplace=True)

        # Encoder: each level halves the resolution, so the decoder's three
        # 2x upsamplings line up with the enc3/enc2/enc1 skip connections.
        self.enc1 = self.conv_block(4, 64)
        self.pool1 = torch.nn.MaxPool2d(2)

        self.enc2 = self.conv_block(64, 128)
        self.pool2 = torch.nn.MaxPool2d(2)

        self.enc3 = self.conv_block(128, 256)
        # Parameter-free, so checkpoints keep the same state_dict keys.
        self.pool3 = torch.nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = self.conv_block(256, 512)

        # Decoder path
        self.up1 = torch.nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec1 = self.conv_block(512, 256)

        self.up2 = torch.nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = self.conv_block(256, 128)

        self.up3 = torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec3 = self.conv_block(128, 64)

        # Final 1x1 conv, then undo conv0's downsampling.
        self.final = torch.nn.Conv2d(64, 1, kernel_size=1)
        self.final_up = torch.nn.ConvTranspose2d(1, 1, kernel_size=stride, stride=stride)

    def conv_block(self, in_ch, out_ch):
        """Two (3x3 conv -> BatchNorm -> ReLU) layers, preserving spatial size."""
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(out_ch),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(out_ch),
            torch.nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.relu0(self.conv0(x))

        enc1 = self.enc1(x)
        x = self.pool1(enc1)

        enc2 = self.enc2(x)
        x = self.pool2(enc2)

        enc3 = self.enc3(x)
        x = self.pool3(enc3)

        x = self.bottleneck(x)

        # Decoder
        x = self.up1(x)
        x = self._pad_and_concat(x, enc3)
        x = self.dec1(x)

        x = self.up2(x)
        x = self._pad_and_concat(x, enc2)
        x = self.dec2(x)

        x = self.up3(x)
        x = self._pad_and_concat(x, enc1)
        x = self.dec3(x)

        x = self.final(x)
        x = self.final_up(x)

        return x

    def _pad_and_concat(self, upsampled, skip):
        """Pad ``upsampled`` to ``skip``'s spatial size (for odd sizes) and concatenate on channels."""
        diff_y = skip.size()[2] - upsampled.size()[2]
        diff_x = skip.size()[3] - upsampled.size()[3]
        upsampled = F.pad(upsampled, [diff_x // 2, diff_x - diff_x // 2,
                                      diff_y // 2, diff_y - diff_y // 2])
        return torch.cat([upsampled, skip], dim=1)


In [None]:

#=================================================================
# TRAINING FUNCTION
#=================================================================

def _resize_target_to_output(outputs, target):
    """Bilinearly resize ``target`` to the spatial size of ``outputs`` when their shapes differ."""
    if outputs.shape != target.shape:
        target = TF.resize(target, outputs.shape[2:],
                           interpolation=TF.InterpolationMode.BILINEAR)
    return target


def _save_training_checkpoint(path, epoch, model, optimizer, val_loss, metrics):
    """Write model/optimizer state plus the epoch's validation loss and metrics to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': val_loss,
        'metrics': metrics
    }, path)


def train_model(model, train_loader, val_loader, num_epochs=2, lr=0.001, device='cpu',
               save_path=None, model_name="quantum_mri"):
    """
    Train the model for a specified number of epochs with MSE loss and Adam.

    Each loader must yield ``(zero_filled, fully_sampled, mask)`` batches. The
    target is bilinearly resized to the model output's spatial size when they
    differ.

    Args:
        model: The model to train
        train_loader: DataLoader for training
        val_loader: DataLoader for validation
        num_epochs: Number of epochs to train for
        lr: Learning rate
        device: Device to train on ('cpu' or 'cuda')
        save_path: Directory for checkpoints (``{model_name}_best.pth`` and
            ``{model_name}_epoch{N}.pth``); nothing is saved when None
        model_name: Name prefix for saved model files

    Returns:
        Dictionary with 'train_loss' and 'val_loss' (one float per epoch) and
        'val_metrics' (one dict of mean 'mse'/'psnr'/'ssim' per epoch). Epoch
        averages are built-in Python floats, so the history is JSON-friendly
        and checkpoints hold no NumPy scalars. An empty loader yields NaN
        averages instead of raising ZeroDivisionError.
    """
    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)

    model = model.to(device)

    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    history = {
        'train_loss': [],
        'val_loss': [],
        'val_metrics': []
    }

    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        # ---- Training phase ----
        model.train()
        train_loss = 0.0

        with tqdm(total=len(train_loader), desc=f'Epoch {epoch+1}/{num_epochs}') as pbar:
            for zero_filled, fully_sampled, _ in train_loader:
                zero_filled = zero_filled.to(device)
                fully_sampled = fully_sampled.to(device)

                outputs = model(zero_filled)
                fully_sampled = _resize_target_to_output(outputs, fully_sampled)
                loss = criterion(outputs, fully_sampled)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_loss += loss.item()
                pbar.update(1)
                pbar.set_postfix({'loss': loss.item()})

        train_loss = train_loss / len(train_loader) if len(train_loader) else float('nan')
        history['train_loss'].append(train_loss)

        # ---- Validation phase ----
        model.eval()
        val_loss = 0.0
        val_metrics = {'mse': [], 'psnr': [], 'ssim': []}

        with torch.no_grad():
            for zero_filled, fully_sampled, _ in val_loader:
                zero_filled = zero_filled.to(device)
                fully_sampled = fully_sampled.to(device)

                outputs = model(zero_filled)
                fully_sampled = _resize_target_to_output(outputs, fully_sampled)

                loss = criterion(outputs, fully_sampled)
                val_loss += loss.item()

                # Per-image metrics; data_range=1.0 assumes [0, 1]-normalised targets.
                for i in range(outputs.size(0)):
                    output_np = outputs[i, 0].cpu().numpy()
                    target_np = fully_sampled[i, 0].cpu().numpy()

                    max_val = 1.0
                    val_metrics['mse'].append(np.mean((output_np - target_np) ** 2))
                    val_metrics['psnr'].append(psnr(target_np, output_np, data_range=max_val))
                    val_metrics['ssim'].append(ssim(target_np, output_np, data_range=max_val))

        val_loss = val_loss / len(val_loader) if len(val_loader) else float('nan')
        history['val_loss'].append(val_loss)

        # float() strips NumPy scalar types. They would otherwise break JSON
        # export and torch.load(weights_only=True) on the saved checkpoints.
        avg_metrics = {
            key: float(np.mean(values)) if values else float('nan')
            for key, values in val_metrics.items()
        }
        history['val_metrics'].append(avg_metrics)

        print(f"Epoch {epoch+1}/{num_epochs}:")
        print(f"  Train Loss: {train_loss:.6f}")
        print(f"  Val Loss: {val_loss:.6f}")
        print(f"  MSE: {avg_metrics['mse']:.6f}")
        print(f"  PSNR: {avg_metrics['psnr']:.2f} dB")
        print(f"  SSIM: {avg_metrics['ssim']:.4f}")

        if save_path is not None:
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                _save_training_checkpoint(os.path.join(save_path, f"{model_name}_best.pth"),
                                          epoch, model, optimizer, val_loss, avg_metrics)
                print(f"  Saved best model (val_loss: {val_loss:.6f})")

            # Checkpoint every epoch regardless of improvement.
            _save_training_checkpoint(os.path.join(save_path, f"{model_name}_epoch{epoch+1}.pth"),
                                      epoch, model, optimizer, val_loss, avg_metrics)

    return history


In [None]:
!pip install torchsummary -q


#=================================================================
# VISUALIZATION AND EVALUATION FUNCTIONS
#=================================================================

def visualize_dataset_samples(dataset, num_samples=3):
    """
    Display some samples from the dataset to verify preprocessing.

    Each row shows the fully-sampled slice, the zero-filled reconstruction
    (titled with its PSNR), the absolute error map (titled with the SSIM) and
    the k-space mask. Items are expected as ``(zero_filled, fully_sampled,
    mask)`` tensors. Metrics use ``data_range=1.0``, i.e. [0, 1]-normalised
    images. Stops early if the dataset has fewer than ``num_samples`` items.
    """
    plt.figure(figsize=(15, 5*num_samples))

    for i in range(num_samples):
        # Guard only the indexing: an IndexError here means the dataset ran
        # out of items, while one raised by metrics/plotting is a real bug.
        try:
            zero_filled, fully_sampled, mask = dataset[i]
        except IndexError:
            break

        zero_filled = zero_filled.squeeze().numpy()
        fully_sampled = fully_sampled.squeeze().numpy()
        mask = mask.squeeze().numpy()

        error = np.abs(fully_sampled - zero_filled)

        ssim_val = ssim(fully_sampled, zero_filled, data_range=1.0)
        psnr_val = psnr(fully_sampled, zero_filled, data_range=1.0)

        plt.subplot(num_samples, 4, i*4 + 1)
        plt.imshow(fully_sampled, cmap='gray')
        plt.title('Fully Sampled')
        plt.axis('off')

        plt.subplot(num_samples, 4, i*4 + 2)
        plt.imshow(zero_filled, cmap='gray')
        plt.title(f'Zero-Filled (PSNR: {psnr_val:.2f})')
        plt.axis('off')

        plt.subplot(num_samples, 4, i*4 + 3)
        plt.imshow(error, cmap='hot')
        plt.title(f'Error (SSIM: {ssim_val:.4f})')
        plt.axis('off')

        plt.subplot(num_samples, 4, i*4 + 4)
        plt.imshow(mask, cmap='gray')
        plt.title('k-space Mask')
        plt.axis('off')

    plt.tight_layout()
    plt.show()

def _resize_2d(image, size):
    """Resize a 2-D NumPy image to ``size`` (h, w) with ``TF.resize``'s default interpolation; no-op if already that size."""
    if image.shape == size:
        return image
    return TF.resize(torch.from_numpy(image).unsqueeze(0), size).squeeze(0).numpy()


def _image_metrics(target, output, data_range=1.0):
    """Return ``(mse, psnr, ssim)`` of a 2-D ``output`` against ``target``."""
    mse = np.mean((output - target) ** 2)
    return (mse,
            psnr(target, output, data_range=data_range),
            ssim(target, output, data_range=data_range))


def _plot_comparison_samples(vis_samples):
    """Plot ground truth, input, quantum output and classical output side by side, one row per sample."""
    if not vis_samples:
        return
    n_rows = len(vis_samples)
    plt.figure(figsize=(15, 5 * n_rows))

    for i, sample in enumerate(vis_samples):
        plt.subplot(n_rows, 4, i*4 + 1)
        plt.imshow(sample['target'], cmap='gray')
        plt.title('Ground Truth')
        plt.axis('off')

        plt.subplot(n_rows, 4, i*4 + 2)
        plt.imshow(sample['input'], cmap='gray')
        plt.title('Zero-Filled Input')
        plt.axis('off')

        plt.subplot(n_rows, 4, i*4 + 3)
        plt.imshow(sample['q_output'], cmap='gray')
        plt.title(f'Quantum Output\nPSNR: {sample["q_metrics"]["psnr"]:.2f}, SSIM: {sample["q_metrics"]["ssim"]:.4f}')
        plt.axis('off')

        plt.subplot(n_rows, 4, i*4 + 4)
        plt.imshow(sample['c_output'], cmap='gray')
        plt.title(f'Classical Output\nPSNR: {sample["c_metrics"]["psnr"]:.2f}, SSIM: {sample["c_metrics"]["ssim"]:.4f}')
        plt.axis('off')

    plt.tight_layout()
    plt.show()


def _plot_metric_bars(avg_q, avg_c):
    """Bar charts of the averaged MSE, PSNR and SSIM for both models."""
    plt.figure(figsize=(15, 5))
    panels = (
        ('mse', 'MSE (lower is better)', 'Mean Squared Error'),
        ('psnr', 'PSNR (higher is better)', 'Peak Signal-to-Noise Ratio (dB)'),
        ('ssim', 'SSIM (higher is better)', 'Structural Similarity Index'),
    )
    for position, (key, title, ylabel) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.bar(['Quantum', 'Classical'], [avg_q[key], avg_c[key]])
        plt.title(title)
        plt.ylabel(ylabel)

    plt.tight_layout()
    plt.show()


def compare_models(quantum_model, classical_model, test_loader, device='cpu'):
    """
    Compare the quantum and classical models on the same test data.

    Each model is scored against the target resized to that model's own
    output size, so the two may be evaluated at different resolutions.
    Metrics use ``data_range=1.0``. Up to three samples (taken from
    even-numbered batches) are plotted, followed by a bar chart of the
    averages.

    Returns:
        ``{'quantum': {...}, 'classical': {...}, 'improvement': {...}}``, each
        with 'mse', 'psnr' and 'ssim'. 'improvement' is the quantum model's
        relative gain over the classical one, in percent (for MSE, a
        reduction counts as positive).
    """
    quantum_model.eval()
    classical_model.eval()

    metric_names = ('mse', 'psnr', 'ssim')
    q_metrics = {name: [] for name in metric_names}
    c_metrics = {name: [] for name in metric_names}

    vis_samples = []

    with torch.no_grad():
        for batch_idx, (zero_filled, fully_sampled, _) in enumerate(test_loader):
            zero_filled = zero_filled.to(device)
            fully_sampled = fully_sampled.to(device)

            q_outputs = quantum_model(zero_filled)
            c_outputs = classical_model(zero_filled)

            q_target = fully_sampled if q_outputs.shape == fully_sampled.shape else TF.resize(
                fully_sampled, q_outputs.shape[2:], interpolation=TF.InterpolationMode.BILINEAR)
            c_target = fully_sampled if c_outputs.shape == fully_sampled.shape else TF.resize(
                fully_sampled, c_outputs.shape[2:], interpolation=TF.InterpolationMode.BILINEAR)

            for i in range(zero_filled.size(0)):
                q_output_np = q_outputs[i, 0].cpu().numpy()
                c_output_np = c_outputs[i, 0].cpu().numpy()
                q_target_np = q_target[i, 0].cpu().numpy()
                c_target_np = c_target[i, 0].cpu().numpy()
                input_np = zero_filled[i, 0].cpu().numpy()

                q_vals = _image_metrics(q_target_np, q_output_np)
                c_vals = _image_metrics(c_target_np, c_output_np)
                for name, q_val, c_val in zip(metric_names, q_vals, c_vals):
                    q_metrics[name].append(q_val)
                    c_metrics[name].append(c_val)

                if len(vis_samples) < 3 and batch_idx % 2 == 0:
                    # Bring all panels to the smallest common size for display.
                    size = (min(q_output_np.shape[0], c_output_np.shape[0], input_np.shape[0]),
                            min(q_output_np.shape[1], c_output_np.shape[1], input_np.shape[1]))
                    vis_samples.append({
                        'input': _resize_2d(input_np, size),
                        'q_output': _resize_2d(q_output_np, size),
                        'c_output': _resize_2d(c_output_np, size),
                        'target': _resize_2d(q_target_np, size),
                        'q_metrics': dict(zip(metric_names, q_vals)),
                        'c_metrics': dict(zip(metric_names, c_vals)),
                    })

    avg_q = {name: np.mean(values) for name, values in q_metrics.items()}
    avg_c = {name: np.mean(values) for name, values in c_metrics.items()}

    # Positive = quantum better (lower MSE, higher PSNR/SSIM).
    mse_improvement = (avg_c['mse'] - avg_q['mse']) / avg_c['mse'] * 100
    psnr_improvement = (avg_q['psnr'] - avg_c['psnr']) / avg_c['psnr'] * 100
    ssim_improvement = (avg_q['ssim'] - avg_c['ssim']) / avg_c['ssim'] * 100

    print("\nModel Comparison:")
    print(f"  Quantum MSE: {avg_q['mse']:.6f}, Classical MSE: {avg_c['mse']:.6f} ({mse_improvement:.2f}% improvement)")
    print(f"  Quantum PSNR: {avg_q['psnr']:.2f} dB, Classical PSNR: {avg_c['psnr']:.2f} dB ({psnr_improvement:.2f}% improvement)")
    print(f"  Quantum SSIM: {avg_q['ssim']:.4f}, Classical SSIM: {avg_c['ssim']:.4f} ({ssim_improvement:.2f}% improvement)")

    _plot_comparison_samples(vis_samples)
    _plot_metric_bars(avg_q, avg_c)

    return {
        'quantum': {
            'mse': avg_q['mse'],
            'psnr': avg_q['psnr'],
            'ssim': avg_q['ssim']
        },
        'classical': {
            'mse': avg_c['mse'],
            'psnr': avg_c['psnr'],
            'ssim': avg_c['ssim']
        },
        'improvement': {
            'mse': mse_improvement,
            'psnr': psnr_improvement,
            'ssim': ssim_improvement
        }
    }

def print_model_summary(quantum_model, classical_model):
    from torchsummary import summary

    print("\n🔍 MODEL ARCHITECTURE SUMMARY")
    print("="*40)
    print("\n📦 Quantum Model Summary:")
    try:
        summary(quantum_model, (1, 64, 64))
    except:
        print("  (torchsummary not installed or error in quantum layer — skip)")

    print("\n📦 Classical Model Summary:")
    try:
        summary(classical_model, (1, 64, 64))
    except:
        print("  (torchsummary not installed — skip)")


def plot_training_history(history, save_path=None):
    """
    Draw a 2x2 dashboard of a training run.

    The top-left panel overlays the train and validation loss curves. The
    remaining panels show the per-epoch validation MSE, PSNR and SSIM.

    Args:
        history: Mapping with 'train_loss' and 'val_loss' sequences and a
            'val_metrics' sequence of dicts that each carry 'mse', 'psnr'
            and 'ssim' entries.
        save_path: Optional file path. When given, the figure is saved there
            before being shown.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    loss_ax, mse_ax, psnr_ax, ssim_ax = axes.flat

    # Loss panel: both curves plus a legend to tell them apart.
    loss_ax.plot(history['train_loss'], label='Train Loss')
    loss_ax.plot(history['val_loss'], label='Val Loss')
    loss_ax.set_title('Loss During Training')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()

    # Validation-metric panels share one layout, so drive them from a table.
    metric_panels = (
        (mse_ax, 'mse', 'MSE on Validation Set', 'MSE'),
        (psnr_ax, 'psnr', 'PSNR on Validation Set', 'PSNR (dB)'),
        (ssim_ax, 'ssim', 'SSIM on Validation Set', 'SSIM'),
    )
    for ax, key, title, ylabel in metric_panels:
        ax.plot([epoch_metrics[key] for epoch_metrics in history['val_metrics']])
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)

    fig.tight_layout()

    if save_path is not None:
        fig.savefig(save_path)

    plt.show()

In [None]:
#=================================================================
# MAIN EXECUTION
#=================================================================

# Wall-clock start of the whole run; the elapsed-time reports below are
# measured from it.
overall_start_time = time.time()

if __name__ == "__main__":
    import json

    # Specify paths to your NIfTI files
    nifti_files = [
        "/content/sub-ADNI011S0003_brain.nii.gz",
        "/content/sub-ADNI022S0004_brain.nii.gz",
        "/content/sub-ADNI100S5280_brain.nii.gz",
    ]

    # Run configuration shared by both models so they are trained identically.
    output_dir = 'output'
    batch_size = 2
    num_epochs = 10
    learning_rate = 0.001
    split_seed = 42  # dedicated seed for the train/validation split

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    print("\n=== RUNNING DIAGNOSTIC MODE ===\n")

    # Build the dataset from the first two volumes only (the third path is
    # listed but not used here). For a single-volume run, pass
    # nifti_files=[nifti_files[0]] instead.
    print("Creating dataset...")
    dataset = DiagnosticNiftiDataset(
        nifti_files=nifti_files[:2],
        slice_range=(50, 100),
        acceleration_factor=2,
        target_size=(64, 64),
    )

    dataset_time = time.time()
    print(f"Dataset creation time: {dataset_time - overall_start_time:.2f} seconds")

    # Visualize dataset samples
    print("\nVisualizing dataset samples...")
    visualize_dataset_samples(dataset, num_samples=3)

    # Smoke-test the 4-qubit circuit on one flattened 2x2 patch.
    # (Alternative circuit: FastQuantumCircuit())
    print("\nTesting quantum circuit...")
    circuit = FourQubitQuantumCircuit()
    test_pixels = np.array([0.1, 0.2, 0.3, 0.4])  # 2×2 patch flattened
    output = circuit(test_pixels)
    print(f"4-qubit circuit test - Input: {test_pixels}, Output: {output}")

    # 80/20 train/validation split. A dedicated seeded generator keeps the
    # split stable regardless of how much global RNG state earlier steps
    # (dataset creation, visualization) consumed.
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(
        dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(split_seed),
    )

    print(f"Training set: {len(train_dataset)} slices")
    print(f"Validation set: {len(val_dataset)} slices")

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Check for GPU availability
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Simplistic U-Net models. Alternatives that can be swapped in:
    #   DiagnosticQuantumMRINet(stride=4) / DiagnosticClassicalMRINet(stride=4)
    #   EnhancedDiagnosticQuantumUNet(stride=4) / EnhancedDiagnosticClassicalUNet(stride=4)
    quantum_model = DiagnosticQuantumUNet(stride=2)
    classical_model = DiagnosticClassicalUNet(stride=2)

    q_train_start = time.time()

    # Train quantum model
    print("\nTraining quantum model...")
    q_history = train_model(
        model=quantum_model,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=num_epochs,
        lr=learning_rate,
        device=device,
        save_path=output_dir,
        model_name="quantum_mri_diagnostic",
    )

    q_train_end = time.time()
    print(f"Quantum model training time: {q_train_end - q_train_start:.2f} seconds")

    # Plot training history
    print("\nPlotting quantum model training history...")
    plot_training_history(
        q_history, save_path=os.path.join(output_dir, 'quantum_training_history.png')
    )

    c_train_start = time.time()

    # Train classical model with the same configuration as the quantum one.
    print("\nTraining classical model...")
    c_history = train_model(
        model=classical_model,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=num_epochs,
        lr=learning_rate,
        device=device,
        save_path=output_dir,
        model_name="classical_mri_diagnostic",
    )

    c_train_end = time.time()
    print(f"Classical model training time: {c_train_end - c_train_start:.2f} seconds")

    # Compare models
    print("\nComparing models...")
    comparison = compare_models(quantum_model, classical_model, val_loader, device)

    # Plot classical model training history
    print("\nPlotting classical model training history...")
    plot_training_history(
        c_history, save_path=os.path.join(output_dir, 'classical_training_history.png')
    )

    # Save results. Convert before opening the file so a serialization error
    # cannot leave a truncated JSON file behind.
    results_path = os.path.join(output_dir, 'diagnostic_results.json')
    serializable_results = convert_to_json_serializable({
        'quantum_metrics': comparison['quantum'],
        'classical_metrics': comparison['classical'],
        'improvement': comparison['improvement'],
    })
    with open(results_path, 'w', encoding='utf-8') as f:
        json.dump(serializable_results, f, indent=2)

    print(f"\nDiagnostic run completed! Results saved to '{results_path}'.")

    overall_end_time = time.time()
    total_execution_time = overall_end_time - overall_start_time
    print(f"\nTotal execution time: {total_execution_time:.2f} seconds ({total_execution_time/60:.2f} minutes)")

    ### MODEL SUMMARY
    print_model_summary(quantum_model, classical_model)