# Environment Preparation

In [1]:
import warnings

warnings.filterwarnings("ignore")

In [2]:
!nvidia-smi

Fri Aug 15 17:54:36 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.144.03             Driver Version: 550.144.03     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3080 Ti     Off |   00000000:00:07.0 Off |                  N/A |
| 30%   44C    P0             N/A /  350W |       1MiB /  12288MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

## Config

In [2]:
import copy
import os
import shutil
import subprocess
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Literal, Union

EnvModeType = Literal["colab", "remote", "local"]


@dataclass(frozen=True)
class Config:
    """Immutable set of dataset paths for one execution environment.

    Only ``env_mode`` and ``selected_subsets`` are constructor arguments. Every
    path field is derived in ``__post_init__`` from an environment-specific data
    root (``Path.cwd()`` for an unrecognized ``env_mode``). Construction never
    touches the filesystem; call :meth:`validate` to check that paths exist.
    """

    env_mode: EnvModeType
    selected_subsets: List[str] = field(
        default_factory=lambda: ["SUDOKU_4", "SUDOKU_5", "SUDOKU_6"]
    )

    # Derived fields (init=False). Every one of them must be assigned in
    # __post_init__, otherwise the generated __repr__/__eq__ raise AttributeError.
    data_root: Path = field(init=False)
    base_dir: Path = field(init=False)
    taco_raw_dir: Path = field(init=False)
    taco_file_paths: List[Path] = field(init=False)
    normalized_sets_dir: Path = field(init=False)
    finetune_dir: Path = field(init=False)
    train_dir: Path = field(init=False)
    val_dir: Path = field(init=False)
    test_dir: Path = field(init=False)

    def __post_init__(self) -> None:
        # Default data roots per environment (override in the factory if needed).
        # Alternatives used previously: Path.home() / "mnt/shared" (local),
        # Path("/mnt/data") (remote).
        data_root_map = {
            "local": Path("/mnt/shared"),
            "remote": Path.home(),
            "colab": Path("/content/drive/MyDrive"),
        }
        data_root = data_root_map.get(self.env_mode, Path.cwd())
        base_dir = data_root / "datasets/sen2venus"
        taco_raw_dir = base_dir / "TACO_raw_data"
        normalized_sets_dir = base_dir / "normalized_sets"

        derived = {
            "data_root": data_root,
            "base_dir": base_dir,
            "taco_raw_dir": taco_raw_dir,
            "taco_file_paths": [
                taco_raw_dir / f"{subset}.taco" for subset in self.selected_subsets
            ],
            "normalized_sets_dir": normalized_sets_dir,
            "finetune_dir": base_dir / "finetune",
            "train_dir": normalized_sets_dir / "train",
            "val_dir": normalized_sets_dir / "val",
            "test_dir": normalized_sets_dir / "test",
        }
        # frozen=True forbids normal attribute assignment; object.__setattr__ is
        # the standard way to initialise derived fields of a frozen dataclass.
        for name, value in derived.items():
            object.__setattr__(self, name, value)

    def validate(self) -> None:
        """Check that every directory needed for training/evaluation exists.

        The raw TACO inputs (``taco_raw_dir`` / ``taco_file_paths``) are
        intentionally not checked here.

        Raises:
            ValueError: listing every required path that does not exist.
        """
        required_attrs = (
            "data_root",
            "base_dir",
            "normalized_sets_dir",
            "finetune_dir",
            "train_dir",
            "val_dir",
            "test_dir",
        )
        missing_paths = [
            str(getattr(self, attr))
            for attr in required_attrs
            if not getattr(self, attr).exists()
        ]
        if missing_paths:
            raise ValueError(f"Missing paths: {', '.join(missing_paths)}")


def setup_environment(env_mode: EnvModeType) -> None:
    """Perform environment-specific setup (side effects isolated here).

    * ``"colab"``: installs ``super-image`` if it cannot be imported, mounts
      Google Drive at ``/content/drive``, and copies the normalized dataset to
      ``/content/taco_normalized`` once (skipped if that directory exists).
    * ``"remote"`` / ``"local"``: only print a status message.
    * any other value: does nothing.

    Raises:
        RuntimeError: if installing super-image fails, ``google.colab`` cannot be
            imported, or mounting Drive fails (the original error is chained).
    """
    if env_mode == "colab":
        try:
            import super_image  # noqa: F401  (presence check only)
        except ImportError:
            print("Installing 'super-image'...")
            try:
                # Run pip through the kernel's own interpreter so the package is
                # installed into the environment this notebook imports from; a
                # bare "pip" on PATH can belong to a different Python.
                subprocess.run(
                    [sys.executable, "-m", "pip", "install", "--quiet", "super-image"],
                    check=True,
                )
            except subprocess.CalledProcessError as e:
                raise RuntimeError(f"Failed to install super-image: {e}") from e

        try:
            from google.colab import drive
        except ImportError as e:
            raise RuntimeError("Google Colab module not found. Are you in Colab?") from e
        try:
            drive.mount("/content/drive", force_remount=True)
        except Exception as e:
            raise RuntimeError(f"Failed to mount Google Drive: {e}") from e

        # Optional: Copy data to local /content for faster I/O in Colab.
        # NOTE(review): Config derives its paths from /content/drive/MyDrive, so
        # nothing in Config points at this local copy -- confirm who reads it.
        colab_vm_dir = Path("/content/taco_normalized")
        if not colab_vm_dir.exists():
            print("Copying normalized data to local Colab storage for performance...")
            shutil.copytree(
                Path("/content/drive/MyDrive/datasets/sen2venus/normalized_sets"),
                colab_vm_dir,
            )
            print("Copy complete.")
        # Avoid os.chdir; let users handle working dir if needed

    elif env_mode == "remote":
        print("Remote environment detected. No specific setup needed.")

    elif env_mode == "local":
        # NOTE(review): despite the message, no dependency check is performed.
        print("Local environment detected. Ensuring dependencies...")


def create_config(env_mode: EnvModeType | None = None) -> Config:
    """Run environment setup, then build and validate a Config.

    When ``env_mode`` is None it is detected in this order: ``"colab"`` if
    ``google.colab`` is already imported, ``"remote"`` if the
    ``REMOTE_ENV_VAR`` environment variable is set, otherwise ``"local"``.

    Raises:
        RuntimeError: propagated from setup_environment.
        ValueError: propagated from Config.validate when paths are missing.
    """
    if env_mode is None:
        running_in_colab = "google.colab" in sys.modules
        remote_flag_set = "REMOTE_ENV_VAR" in os.environ
        if running_in_colab:
            env_mode = "colab"
        else:
            env_mode = "remote" if remote_flag_set else "local"

    setup_environment(env_mode)
    cfg = Config(env_mode=env_mode)
    cfg.validate()
    return cfg


In [3]:
# Force the "remote" profile explicitly instead of relying on auto-detection.
config = create_config(env_mode="remote")
print(config.data_root)

Remote environment detected. No specific setup needed.
/home/ubuntu


## Imports

In [4]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam, lr_scheduler
from torch.utils.data import Subset


from super_image import PreTrainedModel, TrainingArguments,EdsrConfig
from super_image.models import EdsrModel
from super_image.trainer import Trainer, logger
from super_image.utils.metrics import AverageMeter
from super_image.file_utils import WEIGHTS_NAME, WEIGHTS_NAME_SCALE

import numpy as np

from collections import OrderedDict
from tqdm.auto import tqdm

## Logger


In [6]:
import logging


def setup_logger(name, log_file, level=logging.INFO):
    """Create (or reset) a named logger that writes to a file and to the notebook.

    Calling this again with the same ``name`` (e.g. re-running the cell) is safe:
    the logger's existing handlers are closed and removed first, so messages are
    not duplicated and the previous log file handle is released.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        log_file: Destination file; parent directories are created. The file is
            opened in mode ``"w"``, so each call truncates the previous log.
        level: Level set on the logger.

    Returns:
        The configured ``logging.Logger``.
    """
    log_path = Path(log_file)
    # Create parent directory if it doesn't exist
    log_path.parent.mkdir(parents=True, exist_ok=True)

    logger = logging.getLogger(name)
    # Close before removing: clearing the handler list alone leaves the old
    # FileHandler's file open for the rest of the kernel session.
    for old_handler in list(logger.handlers):
        old_handler.close()
        logger.removeHandler(old_handler)

    # File handler to save logs to a file ('w' overwrites the old log)
    file_handler = logging.FileHandler(log_path, mode="w")
    file_handler.setFormatter(
        logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    )

    # Stream handler to show logs in the notebook output
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter("%(message)s"))

    logger.setLevel(level)
    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)

    return logger


# The experiment's log file lives inside its own output directory, next to
# the checkpoints written there.
output_dir_8_block = config.finetune_dir.joinpath("edsr_base_8_block")
log_file_8_block = output_dir_8_block.joinpath("training_log_8_block.log")
logger_8_block = setup_logger("8_block_experiment", log_file_8_block)


# Data Loader

In [7]:
class PreNormalizedDataset(Dataset):
    """
    Efficiently reads pre-processed, sharded tensor files from disk.

    Every ``*.pt`` file in ``shard_dir`` (in sorted filename order) is read with
    ``torch.load`` and must be an indexable sequence of samples, each a mapping
    with ``"lr"`` and ``"hr"`` entries. Global index ``i`` maps to shard
    ``i // shard_size``, so every shard except the last must have the same
    length as the first one; the last shard may be shorter.

    Args:
        shard_dir: Directory containing the ``*.pt`` shard files.
        max_cached_shards: Number of decoded shards kept in memory, evicting the
            least recently used one. The default of 1 keeps memory at a single
            shard. With random sampling (``shuffle=True``) consecutive indices
            usually fall in different shards, so raising this avoids re-reading
            a shard from disk on most accesses.

    Raises:
        ValueError: if no shards are found, the first shard is empty, the last
            shard is longer than the first, or ``max_cached_shards < 1``.
    """

    def __init__(self, shard_dir: Union[str, Path], max_cached_shards: int = 1):
        if max_cached_shards < 1:
            raise ValueError(f"max_cached_shards must be >= 1, got {max_cached_shards}")
        self.shard_dir = Path(shard_dir)
        self.shard_paths: List[Path] = sorted(self.shard_dir.glob("*.pt"))

        if not self.shard_paths:
            raise ValueError(f"No shard files ('*.pt') found in {self.shard_dir}")

        first_shard = torch.load(self.shard_paths[0])
        self.shard_size = len(first_shard)
        if self.shard_size == 0:
            raise ValueError(f"First shard {self.shard_paths[0]} is empty")

        # Only the first and last shards are read up front. With a single shard
        # the already-loaded data is reused instead of reading the file twice.
        if len(self.shard_paths) == 1:
            last_shard = first_shard
        else:
            last_shard = torch.load(self.shard_paths[-1])
        if len(last_shard) > self.shard_size:
            raise ValueError(
                f"Last shard {self.shard_paths[-1]} has {len(last_shard)} samples, "
                f"more than the first shard's {self.shard_size}"
            )
        self.length = (len(self.shard_paths) - 1) * self.shard_size + len(last_shard)

        self.max_cached_shards = max_cached_shards
        # shard index -> decoded shard, ordered least- to most-recently used
        self._cache = OrderedDict()
        print(
            f"Initialized dataset from {self.shard_dir} with {self.length} samples across {len(self.shard_paths)} shards."
        )

    def __len__(self):
        return self.length

    def __getitem__(self, idx) -> tuple:
        """Return the ``(lr, hr)`` pair stored at global sample index ``idx``.

        Negative indices count from the end, as for a list.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
            ValueError: if a non-final shard's length differs from the first's.
        """
        position = idx + self.length if idx < 0 else idx
        if not 0 <= position < self.length:
            raise IndexError(
                f"Index {idx} is out of range for a dataset of {self.length} samples"
            )
        shard_index, index_in_shard = divmod(position, self.shard_size)
        shard = self._get_shard(shard_index)

        # coupled with TACORGBDataset dataset class
        # each item in the shard is a squeezed dictionary with keys lr and hr
        squeezed_sample = shard[index_in_shard]
        return squeezed_sample["lr"], squeezed_sample["hr"]

    def _get_shard(self, shard_index: int):
        """Return shard ``shard_index``, loading it from disk on a cache miss."""
        if shard_index in self._cache:
            self._cache.move_to_end(shard_index)
            return self._cache[shard_index]

        # Evict before loading so at most max_cached_shards shards are resident.
        while self._cache and len(self._cache) >= self.max_cached_shards:
            self._cache.popitem(last=False)

        shard = torch.load(self.shard_paths[shard_index])
        is_last = shard_index == len(self.shard_paths) - 1
        if not is_last and len(shard) != self.shard_size:
            # The index arithmetic assumes uniform shards; fail loudly rather
            # than silently returning the wrong sample.
            raise ValueError(
                f"Shard {self.shard_paths[shard_index]} has {len(shard)} samples; "
                f"expected {self.shard_size} (all shards but the last must match the first)"
            )
        self._cache[shard_index] = shard
        return shard


In [10]:
# Training batches are drawn in random order; validation order is fixed so
# evaluation is repeatable.
# NOTE(review): with shuffle=True, consecutive indices rarely fall in the same
# shard, so PreNormalizedDataset may reload a shard from disk for most samples.
# NOTE(review): the trainer below is given train_dataset/val_dataset and builds
# its own loader; confirm these two loaders are still needed.
train_dataset = PreNormalizedDataset(config.train_dir)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=16)

val_dataset = PreNormalizedDataset(config.val_dir)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=16)

# test_dataset = PreNormalizedDataset(config.test_dir)
# test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

Initialized dataset from /home/ubuntu/datasets/sen2venus/normalized_sets/train with 4436 samples across 5 shards.
Initialized dataset from /home/ubuntu/datasets/sen2venus/normalized_sets/val with 554 samples across 1 shards.


# Configure an 8-block EDSR-baseline model (trained from scratch)

In [8]:
@dataclass
class CustomTrainingArguments(TrainingArguments):
    """super_image TrainingArguments plus the step-based knobs of the custom trainer.

    warmup_steps: optimizer steps of linear learning-rate warm-up (0 = none).
    save_steps: interval, in optimizer steps, between "latest" checkpoints
        (presumably overriding a same-named parent field's default -- confirm).
    """

    warmup_steps: int = field(default=0)
    save_steps: int = field(default=100)

In [9]:
logger_8_block.info("--- CONFIGURING 8-BLOCK ARCHITECTURE EXPERIMENT ---")

# x2 upscaling with 8 residual blocks; every other EdsrConfig option keeps its
# library default.
config_edsr_8block = EdsrConfig(n_resblocks=8, scale=2)

logger_8_block.info("Instantiating 8-block EDSR model (training from scratch).")
# Weights are randomly initialized here, NOT loaded from a pre-trained checkpoint.
# NOTE(review): no RNG seed is set before this point in the notebook, so the
# initialization differs between runs unless the library seeds it.
model_8_block = EdsrModel(config_edsr_8block)

--- CONFIGURING 8-BLOCK ARCHITECTURE EXPERIMENT ---
Instantiating 8-block EDSR model (training from scratch).


# Start Training

In [15]:
class CustomResumableTrainer(Trainer):
    """super_image Trainer with a manually scheduled, resumable training loop.

    * Adam optimizer whose learning rate is set every step: linear warm-up over
      ``args.warmup_steps``, then ``learning_rate * 0.1 ** (epoch // d)`` with
      ``d = max(1, int(0.8 * num_train_epochs))``.
    * ``latest_step_checkpoint.pt`` is written every ``args.save_steps`` steps;
      ``best_model_checkpoint.pt`` when evaluation reports a new best epoch.
    * Training resumes from ``latest_step_checkpoint.pt`` when it exists.

    Relies on the base class for ``get_train_dataloader``, ``eval`` and the
    ``best_metric`` / ``best_epoch`` attributes.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Created by _create_optimizer_and_scheduler (from train/load_checkpoint).
        self.optimizer = None
        self.scheduler = None

    def save_checkpoint(self, epoch, global_step, is_best=False):
        """Save model/optimizer/scheduler state and progress counters.

        ``is_best=True`` writes ``best_model_checkpoint.pt``; otherwise
        ``latest_step_checkpoint.pt`` (the file load_checkpoint resumes from).
        The file is written to a temporary name and then moved into place, so an
        interruption mid-write cannot corrupt the previous checkpoint.
        """
        output_dir = self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        filename = (
            "best_model_checkpoint.pt" if is_best else "latest_step_checkpoint.pt"
        )
        checkpoint_path = os.path.join(output_dir, filename)
        state = {
            "epoch": epoch,
            "global_step": global_step,
            "best_metric": self.best_metric,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
        }
        tmp_path = checkpoint_path + ".tmp"
        torch.save(state, tmp_path)
        os.replace(tmp_path, checkpoint_path)
        logger.info(f"Saved checkpoint to {checkpoint_path} (step {global_step})")

    def load_checkpoint(self):
        """Restore training state from ``latest_step_checkpoint.pt`` if present.

        Returns:
            ``(start_epoch, global_step)``; ``(0, 0)`` when there is no
            checkpoint or it cannot be loaded (the error is logged, not raised).
        """
        output_dir = self.args.output_dir
        checkpoint_path = os.path.join(output_dir, "latest_step_checkpoint.pt")
        start_epoch, global_step = 0, 0
        if not os.path.exists(checkpoint_path):
            logger.warning("No checkpoint found. Starting from scratch.")
            return start_epoch, global_step
        try:
            state = torch.load(checkpoint_path, map_location=self.args.device)
            self.model.load_state_dict(state["model_state_dict"])
            if self.optimizer is None:
                self._create_optimizer_and_scheduler()
            self.optimizer.load_state_dict(state["optimizer_state_dict"])
            self.scheduler.load_state_dict(state["scheduler_state_dict"])
            self.best_metric = state.get("best_metric", 0.0)
            start_epoch = state["epoch"]
            global_step = state["global_step"]
            logger.info(
                f"Successfully loaded checkpoint. Resuming from epoch {start_epoch}, step {global_step}."
            )
        except Exception as e:
            logger.error(f"Failed to load checkpoint: {e}. Starting from scratch.")
            start_epoch, global_step = 0, 0
            # Discard any optimizer/scheduler state applied before the failure.
            # NOTE(review): model weights may already have been overwritten by
            # the checkpoint at this point; they are not rolled back.
            self._create_optimizer_and_scheduler()
        return start_epoch, global_step

    def save_model(self, output_dir: str = None):
        """
        Overrides the faulty base save_model method to correctly handle
        models wrapped in nn.DataParallel.

        Writes to ``output_dir`` (default ``args.output_dir``), creating it.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)

        # Unwrap nn.DataParallel so the saved weights carry no "module." prefix.
        model_to_save = (
            self.model.module
            if isinstance(self.model, torch.nn.DataParallel)
            else self.model
        )

        if isinstance(model_to_save, PreTrainedModel):
            model_to_save.save_pretrained(output_dir)
        else:
            # Fallback for non-PreTrainedModel, mirroring the base class logic.
            logger.warning("Saving a model that is not a PreTrainedModel.")
            scale = model_to_save.config.scale
            if scale is not None:
                weights_name = WEIGHTS_NAME_SCALE.format(scale=scale)
            else:
                weights_name = WEIGHTS_NAME
            weights = copy.deepcopy(model_to_save.state_dict())
            torch.save(weights, os.path.join(output_dir, weights_name))

    def _create_optimizer_and_scheduler(self):
        """Build a fresh Adam optimizer and a no-op scheduler."""
        self.optimizer = Adam(self.model.parameters(), lr=self.args.learning_rate)
        # Dummy scheduler (gamma=1.0, never stepped): the LR is set manually in
        # train(); it exists only so checkpoints keep a scheduler_state_dict.
        self.scheduler = lr_scheduler.StepLR(
            self.optimizer, step_size=999999, gamma=1.0
        )

    def _learning_rate_for(self, epoch, global_step):
        """Learning rate for the next step: linear warm-up, then epoch step decay."""
        base_lr = self.args.learning_rate
        warmup_steps = self.args.warmup_steps
        if global_step < warmup_steps:
            # global_step >= 0 here, so warmup_steps > 0: no division by zero.
            return base_lr * (float(global_step) / float(warmup_steps))
        decay_every = max(1, int(self.args.num_train_epochs * 0.8))
        return base_lr * (0.1 ** (epoch // decay_every))

    def train(self, **kwargs):
        """Run (or resume) training; ``kwargs`` are accepted and ignored."""
        self._create_optimizer_and_scheduler()
        start_epoch, global_step = self.load_checkpoint()

        train_dataloader = self.get_train_dataloader()
        steps_per_epoch = len(train_dataloader)
        criterion = torch.nn.L1Loss()  # stateless; no need to rebuild every step

        for epoch in range(start_epoch, self.args.num_train_epochs):
            self.model.train()
            epoch_losses = AverageMeter()

            with tqdm(
                total=steps_per_epoch,
                desc=f"Epoch {epoch}/{self.args.num_train_epochs - 1}",
            ) as t:
                for step, data in enumerate(train_dataloader):
                    # Mid-epoch resume: fast-forward past steps already covered
                    # by the checkpoint. NOTE(review): skipped batches are still
                    # loaded, and if the loader shuffles they are not the same
                    # samples that were seen before the interruption.
                    if epoch * steps_per_epoch + step < global_step:
                        t.update(1)
                        continue

                    new_lr = self._learning_rate_for(epoch, global_step)
                    for param_group in self.optimizer.param_groups:
                        param_group["lr"] = new_lr

                    inputs, labels = data
                    inputs = inputs.to(self.args.device)
                    labels = labels.to(self.args.device)
                    preds = self.model(inputs)
                    loss = criterion(preds, labels)
                    epoch_losses.update(loss.item(), len(inputs))
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                    global_step += 1
                    # refresh=False lets tqdm's mininterval throttle redraws;
                    # forcing a redraw every step can flood Jupyter's IOPub channel.
                    t.set_postfix(
                        loss=f"{epoch_losses.avg:.6f}", lr=f"{new_lr:.2e}", refresh=False
                    )
                    t.update(1)

                    # Step-based checkpointing (save_steps <= 0 disables it).
                    if self.args.save_steps > 0 and global_step % self.args.save_steps == 0:
                        self.save_checkpoint(epoch, global_step, is_best=False)

            # Epoch-based evaluation; the base eval updates best_epoch/best_metric.
            self.eval(epoch)
            if self.best_epoch == epoch:
                self.save_checkpoint(epoch, global_step, is_best=True)


class CustomResumableTrainerLogger(CustomResumableTrainer):
    """CustomResumableTrainer that mirrors checkpoint/eval events to a chosen logger.

    Args:
        logger: keyword-only logger that receives the extra messages; defaults
            to the ``logger`` object exported by ``super_image.trainer``.
    """

    def __init__(self, *args, logger=None, **kwargs):
        super().__init__(*args, **kwargs)
        if logger is None:
            # The parameter shadows the module-level `logger` imported from
            # super_image.trainer, so fetch that default under another name.
            from super_image.trainer import logger as super_image_logger

            logger = super_image_logger
        self.log = logger

    def save_checkpoint(self, epoch, global_step, is_best=False):
        super().save_checkpoint(epoch, global_step, is_best)
        self.log.info(f"Saved checkpoint for epoch {epoch} at step {global_step}.")

    def load_checkpoint(self):
        self.log.info("Attempting to load checkpoint...")
        start_epoch, global_step = super().load_checkpoint()
        self.log.info(f"Load result: Resuming from epoch {start_epoch}, step {global_step}.")
        return start_epoch, global_step

    def eval(self, epoch):
        """Run the base evaluation, then log the best PSNR recorded so far.

        ``self.best_metric`` is the running best, not this epoch's score, so the
        message labels it as such.
        """
        result = super().eval(epoch)
        self.log.info(
            f"EVALUATION Epoch {epoch}: best PSNR so far={self.best_metric:.4f} "
            f"(best epoch {self.best_epoch})"
        )
        return result


In [17]:
# Hyper-parameters for the 8-block run. Checkpoints go to output_dir, the same
# directory that holds this experiment's log file.
args_8_block = CustomTrainingArguments(
    output_dir=output_dir_8_block,
    learning_rate=1e-4,
    num_train_epochs=15,
    per_device_train_batch_size=16,
    warmup_steps=500,
    save_steps=500,
)

logger_8_block.info(f"Training arguments: {args_8_block}")

Training arguments: CustomTrainingArguments(output_dir=PosixPath('/home/ubuntu/datasets/sen2venus/finetune/edsr_base_8_block'), overwrite_output_dir=False, learning_rate=0.0001, gamma=0.5, num_train_epochs=15, save_steps=500, save_total_limit=None, no_cuda=False, seed=123, fp16=False, per_device_train_batch_size=16, local_rank=-1, dataloader_num_workers=0, dataloader_pin_memory=True, warmup_steps=500)


In [20]:
logger_8_block.info("Instantiating Resumable Trainer with logs.")
trainer_8_block = CustomResumableTrainerLogger(
    model=model_8_block,
    args=args_8_block,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    logger=logger_8_block,
)

logger_8_block.info("--- LAUNCHING TRAINING ---")
try:
    trainer_8_block.train()
    logger_8_block.info("--- TRAINING COMPLETE ---")

    # The trainer tracks the best validation result it has seen.
    final_best_epoch = trainer_8_block.best_epoch
    final_best_psnr = trainer_8_block.best_metric
    logger_8_block.info(
        "Best model for 8-block architecture found at epoch "
        f"{final_best_epoch} with PSNR: {final_best_psnr:.4f}"
    )
except Exception as e:
    # Failures are logged with their traceback rather than re-raised, so the
    # cell finishes and the log file records what went wrong.
    logger_8_block.error(f"--- TRAINING FAILED ---: {e}", exc_info=True)

Instantiating Resumable Trainer with logs.
--- LAUNCHING TRAINING ---
Attempting to load checkpoint...
No checkpoint found. Starting from scratch.
Load result: Resuming from epoch 0, step 0.


Epoch 0/14:   0%|          | 0/278 [00:00<?, ?it/s]

EVALUATION Epoch 0: PSNR=45.4047
Saved checkpoint for epoch 0 at step 278.


scale:2      eval psnr: 45.40     ssim: 0.9855
best epoch: 0, psnr: 45.404720, ssim: 0.985482


Epoch 1/14:   0%|          | 0/278 [00:00<?, ?it/s]

Saved checkpoint for epoch 1 at step 500.
EVALUATION Epoch 1: PSNR=45.4047


scale:2      eval psnr: 44.54     ssim: 0.9872


Epoch 2/14:   0%|          | 0/278 [00:00<?, ?it/s]

Saved checkpoint for epoch 3 at step 1000.
EVALUATION Epoch 3: PSNR=46.1685
Saved checkpoint for epoch 3 at step 1112.


scale:2      eval psnr: 46.17     ssim: 0.9884
best epoch: 3, psnr: 46.168545, ssim: 0.988413


Epoch 4/14:   0%|          | 0/278 [00:00<?, ?it/s]

EVALUATION Epoch 4: PSNR=46.1685


scale:2      eval psnr: 45.91     ssim: 0.9885


Epoch 5/14:   0%|          | 0/278 [00:00<?, ?it/s]

Saved checkpoint for epoch 5 at step 1500.
EVALUATION Epoch 5: PSNR=46.1685


scale:2      eval psnr: 46.11     ssim: 0.9886


Epoch 6/14:   0%|          | 0/278 [00:00<?, ?it/s]

EVALUATION Epoch 6: PSNR=46.1685


scale:2      eval psnr: 46.05     ssim: 0.9887


Epoch 7/14:   0%|          | 0/278 [00:00<?, ?it/s]

Saved checkpoint for epoch 7 at step 2000.
EVALUATION Epoch 7: PSNR=46.3109
Saved checkpoint for epoch 7 at step 2224.


scale:2      eval psnr: 46.31     ssim: 0.9887
best epoch: 7, psnr: 46.310905, ssim: 0.988743


Epoch 8/14:   0%|          | 0/278 [00:00<?, ?it/s]

Saved checkpoint for epoch 8 at step 2500.
EVALUATION Epoch 8: PSNR=46.3109


scale:2      eval psnr: 45.59     ssim: 0.9886


Epoch 9/14:   0%|          | 0/278 [00:00<?, ?it/s]

EVALUATION Epoch 9: PSNR=46.3109


scale:2      eval psnr: 46.29     ssim: 0.9889


Epoch 10/14:   0%|          | 0/278 [00:00<?, ?it/s]

Saved checkpoint for epoch 10 at step 3000.
EVALUATION Epoch 10: PSNR=46.3109


scale:2      eval psnr: 46.18     ssim: 0.9888


Epoch 11/14:   0%|          | 0/278 [00:00<?, ?it/s]

EVALUATION Epoch 11: PSNR=46.3109


scale:2      eval psnr: 46.17     ssim: 0.9889


Epoch 12/14:   0%|          | 0/278 [00:00<?, ?it/s]

Saved checkpoint for epoch 12 at step 3500.
EVALUATION Epoch 12: PSNR=46.4834
Saved checkpoint for epoch 12 at step 3614.


scale:2      eval psnr: 46.48     ssim: 0.9891
best epoch: 12, psnr: 46.483364, ssim: 0.989063


Epoch 13/14:   0%|          | 0/278 [00:00<?, ?it/s]

EVALUATION Epoch 13: PSNR=46.4834


scale:2      eval psnr: 46.47     ssim: 0.9891


Epoch 14/14:   0%|          | 0/278 [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

EVALUATION Epoch 14: PSNR=46.5171
Saved checkpoint for epoch 14 at step 4170.
--- TRAINING COMPLETE ---
Best model for 8-block architecture found at epoch 14 with PSNR: 46.5171


scale:2      eval psnr: 46.52     ssim: 0.9891
best epoch: 14, psnr: 46.517124, ssim: 0.989099


# Gathering Final Model Metrics and Parameters

In [12]:
print("--- Gathering Final Model Metrics and Parameters ---")

# --- 1. Load the Best Model ---
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
checkpoint_path = config.finetune_dir / "edsr_base_8_block" / "best_model_checkpoint.pt"
model_8_block_final = EdsrModel(config_edsr_8block)

# The checkpoint also bundles optimizer/scheduler state; only the weights,
# epoch and best metric are used here.
checkpoint = torch.load(checkpoint_path, map_location=device)
model_8_block_final.load_state_dict(checkpoint["model_state_dict"])
model_8_block_final.to(device).eval()

print(f"Successfully loaded best model from epoch {checkpoint['epoch']}")

# --- 2. Get Final Metrics from Checkpoint ---
final_psnr = checkpoint["best_metric"]
print(f"\nFinal PSNR (from checkpoint): {final_psnr:.4f}")

# --- 3. Calculate Total Trainable Parameters ---
def count_parameters(m):
    """Return the total element count of ``m``'s parameters that require gradients."""
    total = 0
    for param in m.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

total_params = count_parameters(model_8_block_final)
print(f"Total Trainable Parameters: {total_params:,}")

# --- 4. Measure Inference Time ---
# Average latency of one forward pass on a dummy input, measured after warm-up.
# NOTE(review): (1, 3, 128, 128) is assumed to resemble the LR inputs used in
# training -- confirm against the dataset.
N_WARMUP_RUNS, N_TIMED_RUNS = 20, 100
dummy_input = torch.randn(1, 3, 128, 128, device=device)
timings = []

print("\nMeasuring inference time...")
with torch.no_grad():
    # Warm-up runs (prime kernels/caches before timing)
    for _ in range(N_WARMUP_RUNS):
        _ = model_8_block_final(dummy_input)

    if device.type == "cuda":
        # CUDA work is asynchronous: time with events and synchronize each run.
        starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
        for _ in range(N_TIMED_RUNS):
            starter.record()
            _ = model_8_block_final(dummy_input)
            ender.record()
            torch.cuda.synchronize()  # Wait for the GPU to finish
            timings.append(starter.elapsed_time(ender))  # milliseconds
    else:
        # CUDA events need a CUDA device; on CPU use a wall-clock timer instead.
        for _ in range(N_TIMED_RUNS):
            t0 = time.perf_counter()
            _ = model_8_block_final(dummy_input)
            timings.append((time.perf_counter() - t0) * 1000.0)

avg_inference_time_ms = sum(timings) / len(timings)
print(f"Average Inference Time: {avg_inference_time_ms:.4f} ms")

--- Gathering Final Model Metrics and Parameters ---
Successfully loaded best model from epoch 14

Final PSNR (from checkpoint): 46.5171
Total Trainable Parameters: 779,035

Measuring inference time...
Average Inference Time: 1.8959 ms
