In [1]:
# Load IPython's autoreload extension so edits to the local modules imported below
# (dataset, common, models) can be picked up without restarting the kernel.
%load_ext autoreload

In [11]:
%autoreload 2

import torch
import numpy as np

from dataset import SliceDataModule, SliceGrappaDataModule
from common.utils import save_reconstructions
from collections import defaultdict
from tqdm import tqdm
from pathlib import Path

In [12]:
# Run identifier of the model under evaluation. It locates the checkpoint under
# mlartifacts/ and names both the output folder and the leaderboard run.
HASH = "4c7f8191cbce4bc9bb9ffc5fe25e5913"
SAVE_DIR = "/".join(("reconstructions", HASH))

In [15]:
def load_model(model_cls, ckpt_path: str | Path):
    checkpoint = torch.load(ckpt_path)

    hparams = checkpoint["hyper_parameters"]
    del hparams["_class_path"]
    del hparams["_instantiator"]
    model = model_cls(**hparams)
    model.load_state_dict(torch.load(ckpt_path)["state_dict"])
    model = model.cuda()

    return model

In [16]:
from models import VarNetToyRestormerOL as Model

# Epoch whose saved checkpoint is evaluated. Checkpoints live under
# mlartifacts/0/<HASH>/artifacts/checkpoints/epoch_<N>/checkpoint.pth.
best_epoch = 3
ckpt_dir = f"mlartifacts/0/{HASH}/artifacts/checkpoints"
ckpt_path = f"{ckpt_dir}/epoch_{best_epoch}/checkpoint.pth"
model = load_model(Model, ckpt_path)

In [17]:
# Dataset root on the evaluation machine; change here when running elsewhere.
DATA_ROOT = "/Data"
dm = SliceDataModule(root=DATA_ROOT)
# Prepare both the public ("test") and private ("predict") leaderboard splits.
for stage in ("test", "predict"):
    dm.setup(stage)

In [18]:
# Public leaderboard = test split, private leaderboard = predict split.
work = {"public": dm.test_dataloader(), "private": dm.predict_dataloader()}
# GRAPPA data modules yield an extra model input; decide once, not per batch.
use_grappa = isinstance(dm, SliceGrappaDataModule)


def _forward_batch(model, batch, use_grappa: bool):
    """Run one dataloader batch through ``model`` and crop the result to image space.

    The batch layout is taken from the original unpacking:
    ``(mask, kspace, [grappa,] target, maximum, fnames, slices)``. The ``grappa``
    tensor is present only when ``use_grappa`` is true. ``target`` and
    ``maximum`` are not needed for inference.

    Returns:
        ``(output, fnames, slices)``, where ``output`` is the cropped model output.
    """
    if use_grappa:
        mask, kspace, grappa, _target, _maximum, fnames, slices = batch
        extra_inputs = (grappa.cuda(non_blocking=True),)
    else:
        mask, kspace, _target, _maximum, fnames, slices = batch
        extra_inputs = ()
    output = model(
        kspace.cuda(non_blocking=True),
        mask.cuda(non_blocking=True),
        *extra_inputs,
    )
    return model.image_space_crop(output), fnames, slices


def _stack_volumes(per_file):
    """Stack each file's ``{slice_index: array}`` mapping into one array in slice order."""
    return {
        fname: np.stack([slice_map[idx] for idx in sorted(slice_map)])
        for fname, slice_map in per_file.items()
    }


model.eval()
with torch.no_grad():
    for phase, dataloader in work.items():
        reconstructions = defaultdict(dict)
        print(f"Reconstructing {phase} leaderboard...")

        for batch in tqdm(dataloader, desc=phase):
            output, fnames, slices = _forward_batch(model, batch, use_grappa)
            # One device->host copy per batch instead of one per slice.
            output = output.cpu().numpy()
            for fname, slice_idx, recon in zip(fnames, slices, output):
                # int() turns tensor-valued slice indices into plain int keys, if
                # collate produced a tensor. 0-d tensors hash by identity, so they
                # make poor dict keys.
                reconstructions[fname][int(slice_idx)] = recon

        reconstructions = _stack_volumes(reconstructions)
        print(f"Saving reconstructions of {phase} leaderboard...")
        save_reconstructions(reconstructions, Path(SAVE_DIR) / phase)

Reconstructing public leaderboard...


  0%|          | 0/984 [00:00<?, ?it/s]

100%|██████████| 984/984 [04:49<00:00,  3.39it/s]


Saving reconstructions of public leaderboard...
Reconstructing private leaderboard...


100%|██████████| 984/984 [04:49<00:00,  3.40it/s]


Saving reconstructions of private leaderboard...


In [19]:
# Run the leaderboard evaluation script for this run's HASH; it prints the SSIM scores.
# NOTE(review): presumably reads the files saved under reconstructions/{HASH}/ — confirm in leaderboard_eval.sh.
!sh leaderboard_eval.sh {HASH}

Model: 4c7f8191cbce4bc9bb9ffc5fe25e5913

Leaderboard SSIM : 0.9656
Leaderboard SSIM (public): 0.9789
Leaderboard SSIM (private): 0.9523
