# Env Prep

In [None]:
import warnings
# Suppress every Python warning for the rest of the session to keep the
# notebook output compact.
# NOTE(review): the filter is global, so deprecation and numerical warnings
# raised by any library are hidden as well — consider narrowing it.
warnings.filterwarnings('ignore')

In [None]:
!nvidia-smi

Sun Aug 17 08:43:01 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   47C    P8              9W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

## Config

In [None]:
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Literal, Union, Dict
import os
import sys
import subprocess
import shutil

# Execution environments this notebook knows how to configure.
EnvModeType = Literal["colab", "remote", "local"]


@dataclass(frozen=True)
class Config:
    """Immutable bundle of dataset locations derived from the environment.

    Only ``env_mode`` and ``selected_subsets`` are constructor arguments;
    every path field is computed in ``__post_init__`` so that all fields are
    populated on every instance (the generated ``__repr__`` / ``__eq__`` read
    each declared field and would raise ``AttributeError`` otherwise).
    """

    # Which environment the paths should be resolved for.
    env_mode: EnvModeType
    # Subset names used to build the ``<subset>.taco`` file paths.
    selected_subsets: List[str] = field(
        default_factory=lambda: ["SUDOKU_4", "SUDOKU_5", "SUDOKU_6"]
    )

    # Derived fields (init=False): filled in by __post_init__.
    data_root: Path = field(init=False)
    base_dir: Path = field(init=False)
    taco_raw_dir: Path = field(init=False)
    taco_file_paths: List[Path] = field(init=False)
    normalized_sets_dir: Path = field(init=False)
    finetune_dir: Path = field(init=False)
    train_dir: Path = field(init=False)
    val_dir: Path = field(init=False)
    test_dir: Path = field(init=False)

    def __post_init__(self) -> None:
        """Compute every derived path from ``env_mode`` and ``selected_subsets``."""
        # Default data roots per environment; an unknown mode falls back to
        # the current working directory.
        data_root_map = {
            "local": Path("/mnt/shared"),
            "remote": Path.home(),
            "colab": Path("/content/drive/MyDrive"),
        }
        data_root = data_root_map.get(self.env_mode, Path.cwd())
        base_dir = data_root / "datasets/sen2venus"
        taco_raw_dir = base_dir / "TACO_raw_data"
        normalized_sets_dir = base_dir / "normalized_sets"

        derived = {
            "data_root": data_root,
            "base_dir": base_dir,
            "taco_raw_dir": taco_raw_dir,
            "taco_file_paths": [
                taco_raw_dir / f"{subset}.taco" for subset in self.selected_subsets
            ],
            "normalized_sets_dir": normalized_sets_dir,
            "finetune_dir": base_dir / "finetune",
            "train_dir": normalized_sets_dir / "train",
            "val_dir": normalized_sets_dir / "val",
            "test_dir": normalized_sets_dir / "test",
        }
        # The dataclass is frozen, so normal attribute assignment is blocked;
        # object.__setattr__ is the supported way to initialise derived fields.
        for name, value in derived.items():
            object.__setattr__(self, name, value)

    def validate(self) -> None:
        """Check that the directories this notebook reads from exist.

        ``taco_raw_dir`` and ``taco_file_paths`` are derived but deliberately
        not checked here, matching the previously disabled TACO checks.

        Raises:
            ValueError: If one or more of the checked directories is missing;
                the message lists every missing path.
        """
        checked_attrs = (
            "data_root",
            "base_dir",
            "normalized_sets_dir",
            "finetune_dir",
            "train_dir",
            "val_dir",
            "test_dir",
        )
        missing_paths = [
            str(getattr(self, attr))
            for attr in checked_attrs
            if not getattr(self, attr).exists()
        ]
        if missing_paths:
            raise ValueError(f"Missing paths: {', '.join(missing_paths)}")


def setup_environment(env_mode: EnvModeType) -> None:
    """Perform environment-specific setup (side effects isolated here).

    For ``"colab"`` this installs ``super-image`` if it cannot be imported,
    mounts Google Drive at ``/content/drive`` and, when
    ``/content/taco_normalized`` does not exist yet, copies the normalized
    dataset from Drive to that local directory. For ``"remote"`` and
    ``"local"`` it only prints a status line.

    Args:
        env_mode: Which environment to prepare.

    Raises:
        RuntimeError: If installing ``super-image`` fails, if the
            ``google.colab`` module is unavailable, or if mounting Drive fails.
    """
    if env_mode == "colab":
        try:
            # Imported only to probe availability; the name is not used here.
            import super_image  # noqa: F401
        except ImportError:
            print("Installing 'super-image'...")
            try:
                # Run pip through the current interpreter so the package is
                # installed into the environment this kernel actually uses.
                subprocess.run(
                    [sys.executable, "-m", "pip", "install", "--quiet", "super-image"],
                    check=True,
                )
            except subprocess.CalledProcessError as e:
                raise RuntimeError(f"Failed to install super-image: {e}") from e

        try:
            from google.colab import drive

            drive.mount("/content/drive", force_remount=True)
        except ImportError as e:
            raise RuntimeError(
                "Google Colab module not found. Are you in Colab?"
            ) from e
        except Exception as e:
            raise RuntimeError(f"Failed to mount Google Drive: {e}") from e

        # Optional: Copy data to local /content for faster I/O in Colab.
        # The copy is skipped whenever the target directory already exists.
        colab_vm_dir = Path("/content/taco_normalized")
        if not colab_vm_dir.exists():
            print("Copying normalized data to local Colab storage for performance...")
            shutil.copytree(
                Path(
                    "/content/drive/MyDrive/datasets/sen2venus/normalized_sets"
                ),
                colab_vm_dir,
            )
            print("Copy complete.")
        # Avoid os.chdir; let users handle working dir if needed

    elif env_mode == "remote":
        print("Remote environment detected. No specific setup needed.")

    elif env_mode == "local":
        print("Local environment detected. Ensuring dependencies...")


def create_config(env_mode: EnvModeType | None = None) -> Config:
    """Build a validated ``Config``, preparing the environment first.

    When ``env_mode`` is omitted it is inferred: ``"colab"`` if the
    ``google.colab`` module is already loaded, ``"remote"`` if the
    ``REMOTE_ENV_VAR`` environment variable is set, otherwise ``"local"``.
    ``setup_environment`` runs before the config is constructed, and
    ``Config.validate`` runs before the config is returned.
    """
    mode = env_mode
    if mode is None:
        if "google.colab" in sys.modules:
            mode = "colab"
        else:
            mode = "remote" if "REMOTE_ENV_VAR" in os.environ else "local"

    setup_environment(mode)

    cfg = Config(env_mode=mode)
    cfg.validate()
    return cfg


In [None]:
# The mode is passed explicitly, so create_config skips its auto-detection.
config = create_config(env_mode='colab')
print(config.data_root)

## Project imports

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

from super_image import Trainer, PreTrainedModel, TrainingArguments
from super_image.models import EdsrModel
from super_image.trainer import Trainer, logger
from super_image.utils.metrics import AverageMeter, compute_metrics
from super_image.data import EvalMetrics

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from tqdm.auto import tqdm
from collections import OrderedDict

# Data Loader

In [None]:
class PreNormalizedDataset(Dataset):
    """
    Map-style dataset over pre-processed samples stored in ``*.pt`` shards.

    Every shard is read with ``torch.load`` and indexed like a sequence whose
    items are dicts with ``'lr'`` and ``'hr'`` entries. All shards except the
    last one (in sorted filename order) are assumed to contain as many
    samples as the first shard. The most recently used shard is kept in
    memory so consecutive reads from the same shard do not hit the disk again.
    """

    def __init__(self, shard_dir: Union[str, Path]):
        """Index the shards under ``shard_dir`` and compute the dataset length.

        Raises:
            ValueError: If no ``*.pt`` file is found or the first shard is empty.
        """
        self.shard_dir = Path(shard_dir)
        # NOTE(review): the sort is lexicographic, so unpadded numeric
        # suffixes ("shard_10" before "shard_2") would break the "only the
        # last shard may be shorter" assumption — confirm the naming scheme.
        self.shard_paths: List[Path] = sorted(self.shard_dir.glob("*.pt"))

        if not self.shard_paths:
            raise ValueError(f"No shard files ('*.pt') found in {self.shard_dir}")

        # To calculate length, we check the size of the first shard and assume
        # all but the last are the same size.
        first_shard = torch.load(self.shard_paths[0])
        self.shard_size = len(first_shard)
        if self.shard_size == 0:
            # A zero shard size would make the index arithmetic divide by zero.
            raise ValueError(
                f"First shard {self.shard_paths[0]} is empty; cannot infer shard size."
            )

        if len(self.shard_paths) == 1:
            # Single shard: first and last are the same file, no second load.
            last_shard_size = self.shard_size
        else:
            last_shard_size = len(torch.load(self.shard_paths[-1]))
        self.length = (len(self.shard_paths) - 1) * self.shard_size + last_shard_size

        # Simple cache to avoid re-loading the same shard consecutively
        self._cache = None
        self._cached_shard_index = -1
        print(f"Initialized dataset from {self.shard_dir} with {self.length} samples across {len(self.shard_paths)} shards.")

    def __len__(self):
        """Return the total number of samples across all shards."""
        return self.length

    def __getitem__(self, idx) -> tuple:
        """Return the ``(lr, hr)`` pair stored at position ``idx``.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        position = idx + self.length if idx < 0 else idx
        if not 0 <= position < self.length:
            raise IndexError(
                f"Index {idx} is out of range for dataset of length {self.length}"
            )

        shard_index, index_in_shard = divmod(position, self.shard_size)

        if shard_index != self._cached_shard_index:
            self._cache = torch.load(self.shard_paths[shard_index])
            self._cached_shard_index = shard_index

        # coupled with TACORGBDataset dataset class
        # each item in the shard is a squeezed dictionary with keys lr and hr
        squeezed_sample = self._cache[index_in_shard]
        return squeezed_sample['lr'], squeezed_sample['hr']


In [None]:
# In Colab, read from the on-VM copy that setup_environment creates under
# /content/taco_normalized; elsewhere use the configured test directory.
# Resolved outside the try so the error message below can always name it.
test_data_dir = Path("/content/taco_normalized/test") if config.env_mode == "colab" else config.test_dir

try:
    test_dataset = PreNormalizedDataset(test_data_dir)
    # For evaluation, batch size is typically 1 to measure per-image metrics.
    test_dataloader = DataLoader(test_dataset, batch_size=1)
    print(f"Loaded test dataset with {len(test_dataset)} samples.")
except Exception as e:
    print(f"❌ ERROR: Failed to load the test dataset from '{test_data_dir}': {e}")
    # Re-raise: every later cell needs test_dataset, so continuing would only
    # surface as an unrelated NameError further down.
    raise

# Load the Models

In [None]:
def _align_state_dict_keys(state_dict, model_is_parallel: bool):
    """Add or strip the DataParallel ``'module.'`` key prefix to match the model."""
    prefix = 'module.'
    # As before, the first key decides whether the checkpoint was saved from
    # a DataParallel-wrapped model; an empty state dict counts as "not parallel".
    checkpoint_is_parallel = next(iter(state_dict), '').startswith(prefix)

    if model_is_parallel and not checkpoint_is_parallel:
        print("Model is parallel, checkpoint is not. Adding 'module.' prefix to keys...")
        return OrderedDict((prefix + key, value) for key, value in state_dict.items())

    if checkpoint_is_parallel and not model_is_parallel:
        print("Checkpoint is parallel, model is not. Stripping 'module.' prefix from keys...")
        # removeprefix only touches keys that really carry the prefix.
        return OrderedDict(
            (key.removeprefix(prefix), value) for key, value in state_dict.items()
        )

    # If they are both parallel or both not parallel, the keys match already.
    print("Model and checkpoint states match. Loading directly.")
    return state_dict


def load_model(finetune_dir: Path|str, pre_trained : bool, scale: int = 2):
    """Build an EDSR model and load ``best_model_checkpoint.pt`` into it.

    Args:
        finetune_dir: Directory containing ``best_model_checkpoint.pt``.
        pre_trained: If true, start from the ``'eugenesiow/edsr-base'``
            pretrained weights; otherwise construct the model directly.
        scale: Upscaling factor forwarded to the model constructor.

    Returns:
        ``(model, device)``: the model in eval mode on ``device``, which is
        CUDA when available and CPU otherwise.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        Exception: Any other error raised while reading or applying the
            checkpoint is re-raised after being reported, so a model without
            the fine-tuned weights is never returned silently.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Accept both str and Path, as the signature advertises.
    checkpoint_path = Path(finetune_dir) / 'best_model_checkpoint.pt'

    if pre_trained:
        model = EdsrModel.from_pretrained('eugenesiow/edsr-base', scale=scale)
    else:
        # NOTE(review): super_image's documented usage builds the model from
        # an EdsrConfig object; confirm that a bare ``scale=`` keyword is
        # accepted here and that the resulting architecture (e.g. number of
        # residual blocks) matches the checkpoint being loaded.
        model = EdsrModel(scale=scale)

    # If multiple GPUs are available, wrap the model in DataParallel
    # This ensures the model's structure matches the training environment.
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs. Wrapping model in DataParallel.")
        model = torch.nn.DataParallel(model)

    model.to(device)

    # load the state dictionary from best checkpoint
    print(f"Loading best model checkpoint from: {checkpoint_path}")
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        final_state_dict = _align_state_dict_keys(
            checkpoint['model_state_dict'],
            model_is_parallel=isinstance(model, torch.nn.DataParallel),
        )
        # Load the correctly formatted state dictionary
        model.load_state_dict(final_state_dict)
    except FileNotFoundError:
        print(f"❌ ERROR: Checkpoint file not found at '{checkpoint_path}'. Please verify the path.")
        raise
    except Exception as e:
        print(f"❌ ERROR: An error occurred while loading the checkpoint: {e}")
        raise

    # Reporting the checkpoint metadata is informational only; missing or
    # oddly typed metadata must not invalidate weights that loaded correctly.
    try:
        print(f"\n✅Successfully loaded model from epoch {checkpoint['epoch']} with best validation PSNR {checkpoint['best_metric']:.4f}")
    except (KeyError, TypeError, ValueError):
        print("\n✅Successfully loaded model weights (checkpoint metadata unavailable)")

    model.eval()

    return model, device

In [None]:
# Starts from the pretrained 'eugenesiow/edsr-base' weights, then loads the
# best checkpoint stored under finetune/edsr_base.
model_16_block, device_16_block = load_model(finetune_dir=config.finetune_dir / 'edsr_base', pre_trained=True)

In [None]:
# Built without pretrained weights, then loads the best checkpoint stored
# under finetune/edsr_base_8_block.
model_8_block, device_8_block = load_model(finetune_dir=config.finetune_dir / 'edsr_base_8_block', pre_trained=False)

# Step 1: Quantitative Evaluation on the Test Set

### Trained 16-block and 8-block models

In [None]:
# Run super_image's EvalMetrics over the whole test dataset with the model
# loaded from finetune/edsr_base, then report how many samples it contains.
# NOTE(review): presumably evaluate() prints its metrics itself, since its
# return value is discarded here — confirm.
EvalMetrics().evaluate(model_16_block, test_dataset)
print(f"Total images evaluated: {len(test_dataset)}")

In [None]:
# Same evaluation for the model loaded from finetune/edsr_base_8_block, on
# the same test dataset, so the two runs are directly comparable.
EvalMetrics().evaluate(model_8_block, test_dataset)
print(f"Total images evaluated: {len(test_dataset)}")

# Step 2: Qualitative Visual Analysis

## convert tensors to displayable images

In [None]:
# --- Helper function to convert tensors to displayable images ---
def tensor_to_image(tensor):
    """Convert a channel-first PyTorch tensor to a NumPy image for plotting.

    Args:
        tensor: A 3-D tensor laid out as (C, H, W). It may live on any device
            and may require grad.

    Returns:
        A NumPy array laid out as (H, W, C) with values clipped to [0, 1].
    """
    # detach() drops any autograd history (numpy() refuses tensors that
    # require grad); cpu() is needed before conversion for GPU tensors.
    image = tensor.detach().cpu().numpy()
    # Transpose from (C, H, W) to (H, W, C)
    image = np.transpose(image, (1, 2, 0))
    # Clip values to be in the valid [0, 1] range for display
    return np.clip(image, 0, 1)

# --- The main plotting function ---
def plot_comparison(dataset, model, index, zoom_rect, device):
    """
    Selects an image by index, runs inference, and plots a detailed
    side-by-side comparison with a zoomed-in patch.

    The top row shows the full LR, SR and HR images with the zoom area
    outlined; the bottom row shows the corresponding crops. Because the LR
    image has a different pixel grid than the SR/HR images, the zoom area is
    rescaled per image so all three panels show the same scene region.

    Args:
        dataset: The dataset to draw an image from; ``dataset[index]`` must
            yield an ``(lr, hr)`` pair of (C, H, W) tensors.
        model: The trained model to use for inference.
        index: The index of the image in the dataset.
        zoom_rect: A tuple (x, y, width, height) defining the zoom area, in
            pixel coordinates of the high-resolution image.
        device: The device the LR batch is moved to before inference.
    """
    # Get the LR and HR images from the dataset
    lr_tensor, hr_tensor = dataset[index]

    # Run inference to get the Super-Resolved image
    with torch.no_grad():
        # Add a batch dimension, send to device, then drop the batch dimension
        lr_batch = lr_tensor.unsqueeze(0).to(device)
        sr_tensor = model(lr_batch).squeeze(0)

    # Convert all tensors to NumPy images for plotting
    lr_image = tensor_to_image(lr_tensor)
    sr_image = tensor_to_image(sr_tensor)
    hr_image = tensor_to_image(hr_tensor)

    x, y, w, h = zoom_rect
    ref_height, ref_width = hr_image.shape[:2]

    def rect_for(image):
        """Map zoom_rect from HR pixel coordinates onto ``image``'s own grid."""
        factor_y = image.shape[0] / ref_height
        factor_x = image.shape[1] / ref_width
        return (
            int(round(x * factor_x)),
            int(round(y * factor_y)),
            max(1, int(round(w * factor_x))),
            max(1, int(round(h * factor_y))),
        )

    panels = [
        ('Low-Resolution Input', 'Zoomed LR', lr_image),
        ('Super-Resolved Output', 'Zoomed SR', sr_image),
        ('Ground Truth High-Resolution', 'Zoomed HR', hr_image),
    ]

    # --- Plotting ---
    fig, axes = plt.subplots(2, 3, figsize=(20, 14))

    for column, (full_title, zoom_title, image) in enumerate(panels):
        rx, ry, rw, rh = rect_for(image)

        # Full image with the zoom area outlined
        full_ax = axes[0, column]
        full_ax.imshow(image)
        full_ax.set_title(full_title, fontsize=16)
        full_ax.add_patch(
            patches.Rectangle((rx, ry), rw, rh, linewidth=2, edgecolor='r', facecolor='none')
        )
        full_ax.axis('off')

        # Zoomed-in patch of the same scene region
        zoom_ax = axes[1, column]
        zoom_ax.imshow(image[ry:ry + rh, rx:rx + rw, :])
        zoom_ax.set_title(zoom_title, fontsize=16)
        zoom_ax.axis('off')

    plt.tight_layout()
    plt.show()

## Generate the Comparisons

In [None]:
# Example 1: A general area
plot_comparison(test_dataset, model_16_block, index=50, zoom_rect=(40, 50, 48, 48), device=device_16_block)

# Example 2: Try another image index
plot_comparison(test_dataset, model_16_block, index=120, zoom_rect=(60, 20, 48, 48), device=device_16_block)

# Example 3: Find an image with a distinct feature, like a road or building
plot_comparison(test_dataset, model_16_block, index=250, zoom_rect=(30, 70, 48, 48), device=device_16_block)

In [None]:
# Example 1: A general area
plot_comparison(test_dataset, model_8_block, index=50, zoom_rect=(40, 50, 48, 48), device=device_8_block)

# Example 2: Try another image index
plot_comparison(test_dataset, model_8_block, index=120, zoom_rect=(60, 20, 48, 48), device=device_8_block)

# Example 3: Find an image with a distinct feature, like a road or building
plot_comparison(test_dataset, model_8_block, index=250, zoom_rect=(30, 70, 48, 48), device=device_8_block)