# Preparing the Environment


In [None]:
import warnings

# Suppress every Python warning for the rest of the session.
# NOTE(review): this also hides deprecation notices from the libraries imported below.
warnings.filterwarnings('ignore')

In [None]:
import os
import shutil
import subprocess
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Literal

EnvModeType = Literal["colab", "remote", "local"]


@dataclass(frozen=True)
class Config:
    """Immutable set of dataset locations derived from the execution environment.

    Only ``env_mode`` and ``selected_subsets`` are supplied by the caller; every
    other field is computed in ``__post_init__`` relative to a per-environment
    data root.
    """

    env_mode: EnvModeType
    # NOTE(review): the default lists the same subset three times — confirm
    # whether three distinct subset names were intended.
    selected_subsets: List[str] = field(
        default_factory=lambda: ["SUDOUE-4", "SUDOUE-4", "SUDOUE-4"]
    )

    # Derived fields (init=False)
    data_root: Path = field(init=False)
    base_dir: Path = field(init=False)
    taco_raw_dir: Path = field(init=False)
    taco_file_paths: List[Path] = field(init=False)
    normalized_sets_dir: Path = field(init=False)
    finetune_dir: Path = field(init=False)
    train_dir: Path = field(init=False)
    val_dir: Path = field(init=False)
    test_dir: Path = field(init=False)

    def _derive(self, name: str, value: object) -> None:
        """Assign a derived field, bypassing the frozen-dataclass guard."""
        object.__setattr__(self, name, value)

    def __post_init__(self) -> None:
        """Compute all derived paths from ``env_mode`` and ``selected_subsets``."""
        data_root_map = {
            "local": Path("/mnt/shared"),
            "remote": Path.home(),
            "colab": Path("/content/drive/MyDrive"),
        }
        # Unknown modes fall back to the current working directory.
        self._derive("data_root", data_root_map.get(self.env_mode, Path.cwd()))

        self._derive("base_dir", self.data_root / "datasets/sen2venus")
        # taco_raw_dir must be set before taco_file_paths, which is built from it;
        # leaving it unset made every Config() construction raise AttributeError.
        self._derive("taco_raw_dir", self.base_dir / "TACO_raw_data")
        self._derive(
            "taco_file_paths",
            [self.taco_raw_dir / f"{subset}.taco" for subset in self.selected_subsets],
        )
        self._derive("normalized_sets_dir", self.base_dir / "normalized_sets")
        self._derive("finetune_dir", self.base_dir / "finetune")
        self._derive("train_dir", self.normalized_sets_dir / "train")
        self._derive("val_dir", self.normalized_sets_dir / "val")
        self._derive("test_dir", self.normalized_sets_dir / "test")

    def validate(self) -> None:
        """Check that every configured directory exists.

        The individual ``.taco`` files in ``taco_file_paths`` are not checked;
        only the directories are.

        Raises:
            ValueError: if one or more directories are missing. The message
                lists every missing path.
        """
        required_dirs = (
            "data_root",
            "base_dir",
            "taco_raw_dir",
            "normalized_sets_dir",
            "finetune_dir",
            "train_dir",
            "val_dir",
            "test_dir",
        )
        missing_paths = [
            str(path)
            for path in (getattr(self, attr) for attr in required_dirs)
            if not path.exists()
        ]
        if missing_paths:
            raise ValueError(f"Missing paths: {', '.join(missing_paths)}")


def setup_environment(env_mode: EnvModeType) -> None:
    """Perform environment-specific setup (side effects isolated here).

    For ``"colab"``: installs ``super-image`` if it cannot be imported, mounts
    Google Drive at ``/content/drive`` and mirrors the normalized dataset to
    ``/content/taco_normalized``. For ``"remote"`` and ``"local"`` it only
    prints a status line.

    Raises:
        RuntimeError: if the package install fails, the Colab module is
            unavailable, or mounting Drive fails.
    """
    if env_mode == "colab":
        try:
            import super_image  # noqa: F401
        except ImportError:
            print("Installing 'super-image'...")
            try:
                # Run pip through the current interpreter so the package is
                # installed into the environment this kernel actually uses.
                subprocess.run(
                    [sys.executable, "-m", "pip", "install", "--quiet", "super-image"],
                    check=True,
                )
            except subprocess.CalledProcessError as e:
                raise RuntimeError(f"Failed to install super-image: {e}") from e

        try:
            from google.colab import drive

            drive.mount("/content/drive", force_remount=True)
        except ImportError as e:
            raise RuntimeError(
                "Google Colab module not found. Are you in Colab?"
            ) from e
        except Exception as e:
            raise RuntimeError(f"Failed to mount Google Drive: {e}") from e

        # Optional: Copy data to local /content for faster I/O in Colab
        # NOTE(review): Config still derives its paths from the Drive mount, so
        # this local mirror is only used if a caller points at it explicitly.
        colab_vm_dir = Path("/content/taco_normalized")
        if not colab_vm_dir.exists():
            print("Copying normalized data to local Colab storage for performance...")
            # Copy into a staging directory and rename at the end, so an
            # interrupted copy is never mistaken for a complete one next run.
            staging_dir = colab_vm_dir.with_name(colab_vm_dir.name + ".partial")
            shutil.rmtree(staging_dir, ignore_errors=True)
            shutil.copytree(
                Path("/content/drive/MyDrive/datasets/sen2venus/normalized_sets"),
                staging_dir,
            )
            staging_dir.rename(colab_vm_dir)
            print("Copy complete.")
        # Avoid os.chdir; let users handle working dir if needed

    elif env_mode == "remote":
        print("Remote environment detected. No specific setup needed.")

    elif env_mode == "local":
        print("Local environment detected. Ensuring dependencies...")


def create_config(env_mode: EnvModeType | None = None) -> Config:
    """Build a validated ``Config``, auto-detecting the environment if needed.

    When ``env_mode`` is ``None`` the mode is inferred: ``"colab"`` if the
    ``google.colab`` module is already loaded, ``"remote"`` if the
    ``REMOTE_ENV_VAR`` environment variable is set, otherwise ``"local"``.
    Environment setup runs before the config is constructed and validated.
    """

    def _detect_mode():
        # Checked in priority order; the first match wins.
        if "google.colab" in sys.modules:
            return "colab"
        if "REMOTE_ENV_VAR" in os.environ:  # Example detection for remote
            return "remote"
        return "local"

    resolved_mode = _detect_mode() if env_mode is None else env_mode

    setup_environment(resolved_mode)
    cfg = Config(env_mode=resolved_mode)
    cfg.validate()
    return cfg


In [None]:
# Detect the environment, run its setup, and build the validated path configuration.
config = create_config()
# Show which data root was selected for this environment.
print(config.data_root)

## Check out Colab Instance Region

In [None]:
!curl ipinfo.io

{
  "ip": "34.143.181.73",
  "hostname": "73.181.143.34.bc.googleusercontent.com",
  "city": "Singapore",
  "region": "Singapore",
  "country": "SG",
  "loc": "1.2897,103.8501",
  "org": "AS396982 Google LLC",
  "postal": "018989",
  "timezone": "Asia/Singapore",
  "readme": "https://ipinfo.io/missingauth"
}

In [None]:
!nvidia-smi

Sat Aug  9 18:48:20 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   77C    P0             30W /   70W |    4382MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

## Imports

In [None]:
import torch
from super_image import PreTrainedModel, Trainer, TrainingArguments
from super_image.file_utils import WEIGHTS_NAME, WEIGHTS_NAME_SCALE
from super_image.models import EdsrModel
from super_image.trainer import Trainer, logger
from super_image.utils.metrics import AverageMeter
from torch.optim import Adam, lr_scheduler
from torch.utils.data import DataLoader, Dataset, Subset

In [None]:
import copy
import time
from dataclasses import dataclass
from typing import Dict, Union

import numpy as np
from tqdm.auto import tqdm

# Step 1: Define PyTorch Datasets & Dataloaders

In [None]:
class PreNormalizedDataset(Dataset):
    """
    Efficiently reads pre-processed, sharded tensor files from disk.

    Every ``*.pt`` file in ``shard_dir`` is expected to hold a sequence of
    samples, each a mapping with ``"lr"`` and ``"hr"`` entries. All shards
    except the last are assumed to contain the same number of samples as the
    first one. Only the most recently used shard is kept in memory.
    """

    def __init__(self, shard_dir: Union[str, Path]):
        """Index the shards in ``shard_dir`` and compute the dataset length.

        Raises:
            ValueError: if ``shard_dir`` contains no ``*.pt`` files.
        """
        self.shard_dir = Path(shard_dir)
        self.shard_paths: List[Path] = sorted(self.shard_dir.glob("*.pt"))

        if not self.shard_paths:
            raise ValueError(f"No shard files ('*.pt') found in {self.shard_dir}")

        # To calculate length-> check the size of the first shard and assume
        # all but the last are the same size.
        self.shard_size = len(torch.load(self.shard_paths[0]))
        if len(self.shard_paths) == 1:
            # First and last shard are the same file; don't read it twice.
            last_shard_size = self.shard_size
        else:
            last_shard_size = len(torch.load(self.shard_paths[-1]))
        self.length = (len(self.shard_paths) - 1) * self.shard_size + last_shard_size

        # Simple cache to avoid re-loading the same shard consecutively
        self._cache = None
        self._cached_shard_index = -1
        print(
            f"Initialized dataset from {self.shard_dir} with {self.length} samples across {len(self.shard_paths)} shards."
        )

    def __len__(self):
        return self.length

    def __getitem__(self, idx) -> tuple:
        """Return the ``(lr, hr)`` pair stored at position ``idx``.

        Negative indices count from the end, as with a list.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        position = idx + self.length if idx < 0 else idx
        if not 0 <= position < self.length:
            raise IndexError(
                f"index {idx} is out of range for dataset of length {self.length}"
            )

        shard_index, index_in_shard = divmod(position, self.shard_size)

        if shard_index != self._cached_shard_index:
            self._cache = torch.load(self.shard_paths[shard_index])
            self._cached_shard_index = shard_index

        # coupled with TACORGBDataset dataset class
        # each item in the shard is a squeezed dictionary with keys lr and hr
        squeezed_sample = self._cache[index_in_shard]
        return squeezed_sample["lr"], squeezed_sample["hr"]


## Dataloader Instantiation

In [None]:
# Training split: read shards from config.train_dir and serve shuffled batches of 16.
train_dataset = PreNormalizedDataset(config.train_dir)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

Initialized dataset from /content/TACO_Normalized/train with 4436 samples across 5 shards.


In [None]:
print("---Verifying  batch shape:")

# Each dataset item is an (lr, hr) pair, so one batch unpacks into two objects.
lr_batch, hr_batch = next(iter(train_loader))

print("Verification successful!")
# Report shape and dtype of both halves of the batch.
print(f"LR batch shape: {lr_batch.shape}")
print(f"HR batch shape: {hr_batch.shape}")
print(f"LR batch dtype: {lr_batch.dtype}")
print(f"HR batch dtype: {hr_batch.dtype}")

---Verifying  dataset output format:
dict_keys(['pixel_values', 'labels'])
LR shape: torch.Size([3, 128, 128])
HR shape: torch.Size([3, 256, 256])
---Verifying  batch shape:
Verification successful!
LR batch shape: torch.Size([16, 3, 128, 128])
HR batch shape: torch.Size([16, 3, 256, 256])
LR batch dtype: torch.float32
HR batch dtype: torch.float32


In [None]:
# Validation split: fixed order (no shuffling), batches of 16.
val_dataset = PreNormalizedDataset(config.val_dir)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

Initialized dataset from /content/TACO_Normalized/val with 554 samples across 1 shards.


In [None]:
# Test split: fixed order (no shuffling), batches of 16.
test_dataset = PreNormalizedDataset(config.test_dir)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

Initialized dataset from /content/drive/MyDrive/datasets/sen2venus/normalized_sets/test with 556 samples across 1 shards.
Loaded 556 test samples.


# Step 2: Load the Pre-trained EDSR Model

**Objectives**:



1.   Loading a well-known, **pre-trained** architecture (edsr-base) specifically configured for **2x super-resolution**.
2.   Confirming that the model accepts data batches and produces outputs of the correct shape ([16, 3, 256, 256]).

## 2.1 Instantiate and Inspect the pre-trained EDSR model

In [None]:
# The 'from_pretrained' method downloads the model configuration and weights.
# LR patches are 128x128 and HR patches are 256x256, so the upscaling factor is 2.
scale = 2
model_id = 'eugenesiow/edsr-base'
model = EdsrModel.from_pretrained(model_id, scale=scale)

# Inspect the model architecture
print("Model architecture loaded successfully:")
print(model)

config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]

pytorch_model_2x.pt:   0%|          | 0.00/5.51M [00:00<?, ?B/s]

https://huggingface.co/eugenesiow/edsr-base/resolve/main/pytorch_model_2x.pt
Model architecture loaded successfully:
DataParallel(
  (module): EdsrModel(
    (sub_mean): MeanShift(3, 3, kernel_size=(1, 1), stride=(1, 1))
    (add_mean): MeanShift(3, 3, kernel_size=(1, 1), stride=(1, 1))
    (head): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (body): Sequential(
      (0): ResBlock(
        (body): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (1): ResBlock(
        (body): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (2): ResBlock(
        (body): Sequential(
       

## 2.2 Sanity Check: Pass one batch of data through the model

In [None]:
# Forward one training batch through the model to confirm that the output
# shape matches the HR target shape before any training is attempted.
print("\nPerforming a forward pass sanity check...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # Set to evaluation mode for this check

# No gradients are needed for a shape check.
with torch.no_grad():
    # Get a single batch from our dataloader
    lr_batch, hr_batch = next(iter(train_loader))

    # Move only the input to the model's device; hr_batch is used for its shape alone.
    lr_batch = lr_batch.to(device)

    # Perform a forward pass
    predictions = model(lr_batch)

    # NOTE(review): this success message is printed before the shape assertion below runs.
    print("Sanity check successful!")
    print(f"Running on device: {device}")
    print(f"Model Input Shape (LR): {lr_batch.shape}")
    print(f"Model Output Shape (Predictions): {predictions.shape}")
    print(f"Target Shape (HR): {hr_batch.shape}")

# Compare output shape with the target High-Resolution shape
assert predictions.shape == hr_batch.shape, "Model output shape does not match target HR shape!"
print("Output shape matches target shape. Ready for training.")


Performing a forward pass sanity check...
Sanity check successful!
Running on device: cuda
Model Input Shape (LR): torch.Size([16, 3, 128, 128])
Model Output Shape (Predictions): torch.Size([16, 3, 256, 256])
Target Shape (HR): torch.Size([16, 3, 256, 256])
Output shape matches target shape. Ready for training.


# Step 3: Configure and Launch the Trainer

## Custom Trainer with Checkpoints

In [None]:
class CustomResumableTrainer(Trainer):
    """``Trainer`` subclass with step-based checkpoints and mid-epoch resume.

    Two checkpoint files are maintained in ``args.output_dir``:

    * ``latest_step_checkpoint.pt`` — written every ``args.save_steps`` global
      steps and used by ``load_checkpoint`` to resume.
    * ``best_model_checkpoint.pt`` — written after an epoch whose evaluation
      improved the best metric.

    The learning rate is set manually on every step (linear warm-up followed by
    a step decay); the attached scheduler is a placeholder kept only so its
    state can be saved and restored.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Created by train() / load_checkpoint() via _create_optimizer_and_scheduler().
        self.optimizer = None
        self.scheduler = None

    def save_checkpoint(self, epoch, global_step, is_best=False):
        """Write a full training-state checkpoint to ``args.output_dir``.

        Args:
            epoch: epoch index the state belongs to.
            global_step: number of optimizer steps completed so far.
            is_best: write ``best_model_checkpoint.pt`` instead of
                ``latest_step_checkpoint.pt``.
        """
        output_dir = self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        filename = (
            "best_model_checkpoint.pt" if is_best else "latest_step_checkpoint.pt"
        )
        checkpoint_path = os.path.join(output_dir, filename)
        state = {
            "epoch": epoch,
            "global_step": global_step,
            "best_metric": self.best_metric,
            "best_epoch": self.best_epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
        }
        # Write to a temporary file and swap it in, so an interruption during
        # the write cannot leave a truncated file in place of the last good
        # checkpoint.
        tmp_path = f"{checkpoint_path}.tmp"
        torch.save(state, tmp_path)
        os.replace(tmp_path, checkpoint_path)
        logger.info(f"Saved checkpoint to {checkpoint_path} (step {global_step})")

    def load_checkpoint(self):
        """Restore training state from ``latest_step_checkpoint.pt`` if present.

        Returns:
            ``(start_epoch, global_step)``; ``(0, 0)`` when no checkpoint exists
            or loading fails.
        """
        output_dir = self.args.output_dir
        checkpoint_path = os.path.join(output_dir, "latest_step_checkpoint.pt")
        start_epoch, global_step = 0, 0
        if not os.path.exists(checkpoint_path):
            logger.warning("No checkpoint found. Starting from scratch.")
            return start_epoch, global_step
        try:
            state = torch.load(checkpoint_path, map_location=self.args.device)
            self.model.load_state_dict(state["model_state_dict"])
            if self.optimizer is None:
                self._create_optimizer_and_scheduler()
            self.optimizer.load_state_dict(state["optimizer_state_dict"])
            self.scheduler.load_state_dict(state["scheduler_state_dict"])
            self.best_metric = state.get("best_metric", 0.0)
            # Older checkpoints have no "best_epoch"; keep the current value then.
            self.best_epoch = state.get("best_epoch", self.best_epoch)
            start_epoch = state["epoch"]
            global_step = state["global_step"]
            logger.info(
                f"Successfully loaded checkpoint. Resuming from epoch {start_epoch}, step {global_step}."
            )
        except Exception as e:
            # Deliberate best-effort: an unreadable checkpoint must not block training.
            # NOTE(review): if the failure happens after the model weights were
            # loaded, the "from scratch" run starts from those weights.
            logger.error(f"Failed to load checkpoint: {e}. Starting from scratch.")
            start_epoch, global_step = 0, 0
        return start_epoch, global_step

    def save_model(self, output_dir: str = None):
        """
        Overrides the faulty base save_model method to correctly handle
        models wrapped in nn.DataParallel.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)

        # Determine if the model is the raw model or a wrapped one
        model_to_save = (
            self.model.module
            if isinstance(self.model, torch.nn.DataParallel)
            else self.model
        )

        # call save_pretrained on the actual model
        if isinstance(model_to_save, PreTrainedModel):
            model_to_save.save_pretrained(output_dir)
        else:
            # Fallback for non-PreTrainedModel, though EDSR is one.
            # This part is for full compatibility with the original's logic.
            logger.warning("Saving a model that is not a PreTrainedModel.")
            scale = model_to_save.config.scale
            if scale is not None:
                weights_name = WEIGHTS_NAME_SCALE.format(scale=scale)
            else:
                weights_name = WEIGHTS_NAME
            weights = copy.deepcopy(model_to_save.state_dict())
            torch.save(weights, os.path.join(output_dir, weights_name))

    def _create_optimizer_and_scheduler(self):
        """Create a fresh Adam optimizer and the placeholder scheduler."""
        self.optimizer = Adam(self.model.parameters(), lr=self.args.learning_rate)
        self.scheduler = lr_scheduler.StepLR(
            self.optimizer, step_size=999999, gamma=1.0
        )  # Dummy scheduler, we control LR manually

    def _learning_rate_for(self, epoch, global_step):
        """Return the learning rate to apply at ``global_step`` of ``epoch``.

        Linear warm-up from 0 over ``args.warmup_steps`` steps, then the base
        rate multiplied by 0.1 for every ``int(0.8 * num_train_epochs)`` epochs
        elapsed (at least 1).
        """
        base_lr = self.args.learning_rate
        warmup_steps = self.args.warmup_steps
        if global_step < warmup_steps:
            # global_step >= 0, so warmup_steps is strictly positive here.
            return base_lr * (float(global_step) / float(warmup_steps))
        divisor = max(1, int(self.args.num_train_epochs * 0.8))
        return base_lr * (0.1 ** (epoch // divisor))

    def train(self, **kwargs):
        """Run (or resume) training for ``args.num_train_epochs`` epochs.

        Steps already covered by the loaded checkpoint are skipped by iterating
        past them in the dataloader. After each epoch the model is evaluated
        and, if the best metric improved, a best-model checkpoint is written.
        """
        self._create_optimizer_and_scheduler()
        # Unpack the tuple into two separate integer variables.
        start_epoch, global_step = self.load_checkpoint()

        train_dataloader = self.get_train_dataloader()
        steps_per_epoch = len(train_dataloader)
        criterion = torch.nn.L1Loss()  # stateless; build once instead of per step
        save_steps = self.args.save_steps

        for epoch in range(start_epoch, self.args.num_train_epochs):
            self.model.train()
            epoch_losses = AverageMeter()

            # Using an enumerated dataloader to skip steps correctly
            with tqdm(
                total=steps_per_epoch,
                desc=f"Epoch {epoch}/{self.args.num_train_epochs - 1}",
            ) as t:
                for step, data in enumerate(train_dataloader):
                    #  resume mid-epoch
                    if epoch * steps_per_epoch + step < global_step:
                        t.update(1)
                        continue

                    # Learning Rate Scheduling (Warm-up + Decay)
                    new_lr = self._learning_rate_for(epoch, global_step)
                    for param_group in self.optimizer.param_groups:
                        param_group["lr"] = new_lr

                    # Standard training steps
                    inputs, labels = data
                    inputs, labels = (
                        inputs.to(self.args.device),
                        labels.to(self.args.device),
                    )
                    preds = self.model(inputs)
                    loss = criterion(preds, labels)
                    epoch_losses.update(loss.item(), len(inputs))
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                    global_step += 1
                    t.set_postfix(loss=f"{epoch_losses.avg:.6f}", lr=f"{new_lr:.2e}")
                    t.update(1)

                    # Step-Based Checkpointing (save_steps <= 0 disables it
                    # instead of dividing by zero)
                    if save_steps > 0 and global_step % save_steps == 0:
                        self.save_checkpoint(epoch, global_step, is_best=False)

            # Epoch-Based Evaluation and Best Model Saving
            previous_best_metric = self.best_metric
            self.eval(epoch)
            # best_epoch alone is not enough: a stale value (e.g. after a
            # resume) can equal the current epoch even when this evaluation did
            # not improve the metric. Require the metric to have changed too.
            improved = (
                self.best_epoch == epoch
                and self.best_metric != previous_best_metric
            )
            if improved:
                self.save_checkpoint(epoch, global_step, is_best=True)


## Trainer Config

In [None]:
@dataclass
class CustomTrainingArguments(TrainingArguments):
    """``TrainingArguments`` plus the extra settings read by ``CustomResumableTrainer.train``."""

    # Number of initial global steps over which the learning rate ramps up linearly.
    warmup_steps: int = 0
    # A 'latest_step_checkpoint.pt' is written every this many global steps.
    save_steps: int = 100


## Dry Run Test

**What this two-act test accomplishes:**

*   The training loop will run on only 64 images.
*   The `save_checkpoint` method will be called multiple times (for `is_best=False`).
*   The `eval` loop will run after the tiny epoch.
*   The overridden `save_model` will be called.
*   The second run resumes from the exact global step where the first run left off.

In [None]:
def run_dry_test(train_dataset, val_dataset):
    """
    Performs a comprehensive, two-part dry run to test trainer's
    step-based saving and mid-epoch resumption capabilities.

    Part 1 trains one epoch on a 64-sample subset and checks that both
    checkpoint files exist and that the step checkpoint was written at global
    step 5. Part 2 builds a fresh model and trainer on the same output
    directory and trains for two epochs so the resume path is exercised.

    Args:
        train_dataset: dataset with at least 64 samples.
        val_dataset: dataset with at least 32 samples.

    Returns:
        None. Progress and failures are reported via ``print``.
    """

    # dry_run_dir = '/content/edsr_v3_dry_run'
    dry_run_dir = config.base_dir / 'edsr_v3_dry_run'

    # Clean up any previous dry run artifacts to ensure a clean test.
    # This must happen BEFORE the directory is (re)created; otherwise the
    # freshly created directory is what gets deleted.
    if os.path.exists(dry_run_dir):
        print(f"Removing previous dry run directory: {dry_run_dir}")
        shutil.rmtree(dry_run_dir)
    os.makedirs(dry_run_dir, exist_ok=True)

    # --- ACT I: The Interrupted Run ---
    # Goal: Run for a few steps and save a mid-epoch checkpoint.
    # ----------------------------------------------------
    print("\n--- STARTING DRY RUN - PART 1: The 'Interrupted' Run ---")

    # 1. Create a tiny dataset (e.g., 64 samples)
    # With a batch size of 8, this will give us 8 steps per epoch.
    dry_run_train_subset = Subset(train_dataset, range(64))
    dry_run_val_subset = Subset(val_dataset, range(32))

    # 2. Define arguments for a short run
    args_part1 = CustomTrainingArguments(
        output_dir=dry_run_dir,
        num_train_epochs=1,
        per_device_train_batch_size=8,
        save_steps=5,  # Save a checkpoint at step 5
        warmup_steps=2
    )

    # 3. Instantiate a fresh model and trainer
    model_part1 = EdsrModel.from_pretrained('eugenesiow/edsr-base', scale=2)
    trainer_part1 = CustomResumableTrainer(
        model=model_part1,
        args=args_part1,
        train_dataset=dry_run_train_subset,
        eval_dataset=dry_run_val_subset
    )

    # 4. Run the training. This will run for 1 epoch (8 steps).
    print("Running Part 1... Expecting a step-based checkpoint at step 5.")
    trainer_part1.train()

    # 5. VERIFY Act I was successful
    step_checkpoint_path = os.path.join(dry_run_dir, 'latest_step_checkpoint.pt')
    best_checkpoint_path = os.path.join(dry_run_dir, 'best_model_checkpoint.pt')

    try:
        assert os.path.exists(step_checkpoint_path), "FAIL: Step-based checkpoint was not created!"
        assert os.path.exists(best_checkpoint_path), "FAIL: Best model checkpoint was not created!"
        print("\n✅ Verification for Part 1 Successful: Checkpoint files were created.")

        # We can even inspect the checkpoint to be sure.
        # Only metadata is read, so load onto the CPU regardless of where it was saved.
        checkpoint = torch.load(step_checkpoint_path, map_location="cpu")
        assert checkpoint['global_step'] == 5, f"FAIL: Step checkpoint saved at wrong step! Got {checkpoint['global_step']}"
        print(f"   - 'latest_step_checkpoint.pt' correctly saved at global step {checkpoint['global_step']}.")

    except AssertionError as e:
        print(f"\n❌ DRY RUN FAILED (Part 1): {e}")
        return

    # --- ACT II: The Resumed Run ---
    # Goal: Ensure a new trainer correctly loads the checkpoint and resumes.
    # ----------------------------------------------------------------
    print("\n--- STARTING DRY RUN - PART 2: The Resumed Run ---")

    # 1. Define arguments for the resumed run
    # We'll run for 2 total epochs to see if it continues correctly.
    args_part2 = CustomTrainingArguments(
        output_dir=dry_run_dir, # MUST be the same directory
        num_train_epochs=2,
        per_device_train_batch_size=8,
        save_steps=5,
        warmup_steps=2
    )

    # 2. Instantiate a NEW model and trainer to simulate a fresh session
    model_part2 = EdsrModel.from_pretrained('eugenesiow/edsr-base', scale=2)
    trainer_part2 = CustomResumableTrainer(
        model=model_part2,
        args=args_part2,
        train_dataset=dry_run_train_subset, # Use the same tiny dataset
        eval_dataset=dry_run_val_subset
    )

    # 3. Run the training. Critically, watch the log output here.
    # expected msg:"Successfully loaded checkpoint... Resuming from epoch 0, step 5."
    print("Running Part 2... Watch for the 'Successfully loaded checkpoint' message.")
    trainer_part2.train()

    # 4. Final verification
    # NOTE(review): these lines are informational; resumption itself is only
    # confirmed by the logger message emitted during trainer_part2.train().
    print("\n✅ DRY RUN COMPLETE: The trainer successfully resumed and finished training.")
    print("   - It loaded the 'latest_step_checkpoint.pt'.")
    print("   - It skipped the first 5 steps of epoch 0.")
    print("   - It completed the remaining epochs.")


run_dry_test(train_dataset, val_dataset)

## Start Training

In [None]:
training_args = CustomTrainingArguments(
    output_dir=config.finetune_dir/'edsr_base',

    # --- Parameters read directly by CustomResumableTrainer.train ---
    num_train_epochs=15,          # Training length; LR is multiplied by 0.1 once epoch reaches int(0.8 * epochs)
    learning_rate=1e-4,           # Base learning rate (reached at the end of warm-up)
    per_device_train_batch_size=24, # Training batch size (NOTE(review): consumed by the base Trainer's dataloader — confirm)

    # --- Parameters not referenced by the custom train() loop in this notebook ---
    # NOTE(review): `seed`, `fp16` and `dataloader_num_workers` only take effect
    # if the base Trainer / TrainingArguments apply them — confirm, especially fp16.
    seed=42,                      # For reproducibility
    fp16=torch.cuda.is_available(),
    dataloader_num_workers=2,
    save_steps=100,      # Save a recovery checkpoint every 100 steps
    warmup_steps=500     # Use a 500-step learning rate warm-up
)

In [None]:
# Build the resumable trainer around the pre-trained model and the full train/val splits.
trainer = CustomResumableTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

In [None]:
print("Starting model fine-tuning ")
# Resumes automatically if 'latest_step_checkpoint.pt' exists in the output directory.
trainer.train()

# Summarise the best epoch and metric tracked by the trainer.
print(f"\nTraining complete. The best model was found at epoch {trainer.best_epoch} "
      f"with a PSNR of {trainer.best_metric:.2f}.")
print(f"The best model has been saved in: {config.finetune_dir}/edsr_base")

Starting model fine-tuning 


Epoch 1/14:   0%|          | 0/185 [00:00<?, ?it/s]

scale:2      eval psnr: 46.36     ssim: 0.9892
best epoch: 1, psnr: 46.361565, ssim: 0.989188


Epoch 2/14:   0%|          | 0/185 [00:00<?, ?it/s]

scale:2      eval psnr: 46.31     ssim: 0.9892


Epoch 3/14:   0%|          | 0/185 [00:00<?, ?it/s]

# Step 4: Gathering Final Model Metrics and Parameters

In [None]:
print("--- Gathering Final Model Metrics and Parameters ---")

# --- 1. Load the Best Model ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint_path = config.finetune_dir / "edsr_base/best_model_checkpoint.pt"
model = EdsrModel.from_pretrained("eugenesiow/edsr-base", scale=2)

# Load the state dictionary from checkpoint
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()

print(f"Successfully loaded best model from epoch {checkpoint['epoch']}")

# --- 2. Get Final Metrics from Checkpoint ---
final_psnr = checkpoint["best_metric"]
print(f"\nFinal PSNR (from checkpoint): {final_psnr:.4f}")


# --- 3. Calculate Total Trainable Parameters ---
def count_parameters(m):
    """Return the number of elements across all parameters of ``m`` that require grad."""
    return sum(p.numel() for p in m.parameters() if p.requires_grad)


total_params = count_parameters(model)
print(f"Total Trainable Parameters: {total_params:,}")

# --- 4. Measure Inference Time ---
# run inference on a single image multiple times to get a stable average.
# Use a dummy tensor for this test.
dummy_input = torch.randn(1, 3, 128, 128, device=device)
# CUDA events only exist on a GPU; on the CPU fallback use a wall-clock timer
# so this cell runs on whichever device was selected above.
use_cuda_timing = device.type == "cuda"
timings = []

print("\nMeasuring inference time...")
with torch.no_grad():
    # Warm-up runs (to prime the GPU)
    for _ in range(20):
        _ = model(dummy_input)

    if use_cuda_timing:
        starter, ender = (
            torch.cuda.Event(enable_timing=True),
            torch.cuda.Event(enable_timing=True),
        )
        # Make sure the warm-up work has finished before timing starts.
        torch.cuda.synchronize()

        # Measurement runs
        for _ in range(100):
            starter.record()
            _ = model(dummy_input)
            ender.record()
            torch.cuda.synchronize()  # Wait for the GPU to finish
            timings.append(starter.elapsed_time(ender))
    else:
        # Measurement runs (CPU): perf_counter is in seconds, convert to ms
        for _ in range(100):
            run_start = time.perf_counter()
            _ = model(dummy_input)
            timings.append((time.perf_counter() - run_start) * 1000.0)

avg_inference_time_ms = sum(timings) / len(timings)
print(f"Average Inference Time: {avg_inference_time_ms:.4f} ms")