In [1]:
# Enable IPython's autoreload extension; together with `%autoreload 2` in the
# next cell, edits to local modules such as `dataset` and `models` are picked
# up without restarting the kernel.
%load_ext autoreload

In [2]:
%autoreload 2

import os
from collections import defaultdict
from pathlib import Path

import numpy as np
import torch
from tqdm import tqdm

from common.utils import save_reconstructions
from dataset import SliceDataModule

In [3]:
# Run id of the trained model: it names the `mlruns/0/<HASH>` run directory the
# checkpoint is read from, and the output directory below.
HASH = "98a504d6edb24e2e81ab5c09970abc08"
# Reconstructions for this run are written under this directory.
SAVE_DIR = "reconstructions/" + HASH

In [4]:
from models import VarNetLogisticSensOL as Model

# Latest checkpoint written by the training run identified by HASH.
CHECKPOINT_PATH = f"mlruns/0/{HASH}/artifacts/checkpoints/latest_checkpoint.pth"

# Deserialize the checkpoint once and reuse it for both the constructor
# arguments and the weights; reading a large file twice is pure overhead.
# map_location="cpu" keeps torch.load from restoring every tensor onto the
# device it was saved from (which may not exist / be free on this machine);
# the weights are moved to the GPU together with the model below.
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
model = Model(**checkpoint["hyper_parameters"]["init_args"])
model.load_state_dict(checkpoint["state_dict"])
model = model.cuda()

In [5]:
# Dataset root. The default is the original absolute path; set the DATA_ROOT
# environment variable to run on a machine with a different layout.
DATA_ROOT = os.environ.get("DATA_ROOT", "/home/Data")

dm = SliceDataModule(root=DATA_ROOT)
# Prepare both stages whose loaders are consumed below:
# test_dataloader() for the public and predict_dataloader() for the private split.
for stage in ("test", "predict"):
    dm.setup(stage)

In [6]:
def reconstruct_volumes(model, dataloader):
    """Run ``model`` over ``dataloader`` and assemble per-file volumes.

    Each batch is unpacked as ``(mask, kspace, target, maximum, fnames, slices)``;
    ``target`` and ``maximum`` are not needed for inference and are ignored.

    Args:
        model: network called as ``model(kspace, mask)`` that also exposes
            ``image_space_crop``; it is switched to eval mode here.
        dataloader: iterable of batches in the layout above.

    Returns:
        dict mapping each file name to ``np.stack`` of that file's cropped
        reconstructions, ordered by ascending slice index.
    """
    model.eval()
    # Follow the model's device instead of hard-coding .cuda().
    device = next(model.parameters()).device
    slices_by_file = defaultdict(dict)
    with torch.no_grad():
        for mask, kspace, _target, _maximum, fnames, slices in tqdm(dataloader):
            output = model(kspace.to(device, non_blocking=True), mask.to(device, non_blocking=True))
            # One device->host transfer per batch rather than one per slice.
            images = model.image_space_crop(output).cpu().numpy()
            for fname, slice_idx, image in zip(fnames, slices, images):
                # int() normalises a collated 0-d tensor to a plain int: tensors
                # hash by identity, so they are unreliable dict keys.
                slices_by_file[fname][int(slice_idx)] = image
    return {
        fname: np.stack([per_slice[idx] for idx in sorted(per_slice)])
        for fname, per_slice in slices_by_file.items()
    }


# Public leaderboard = test split, private leaderboard = predict split.
# Loaders are created lazily, right before each split is processed.
leaderboard_splits = (
    ("public", dm.test_dataloader),
    ("private", dm.predict_dataloader),
)
for split, make_loader in leaderboard_splits:
    print(f"Reconstructing {split} leaderboard...")
    reconstructions = reconstruct_volumes(model, make_loader())
    print(f"Saving reconstructions of {split} leaderboard...")
    save_reconstructions(reconstructions, Path(SAVE_DIR) / split)


Reconstructing public leaderboard...


100%|██████████| 984/984 [04:34<00:00,  3.58it/s]


Saving reconstructions of public leaderboard...
Reconstructing private leaderboard...


100%|██████████| 984/984 [04:32<00:00,  3.61it/s]


Saving reconstructions of private leaderboard...


In [7]:
# Score this run's reconstructions with the project's leaderboard script; its
# output reports overall, public and private SSIM.
# NOTE(review): presumably the script reads reconstructions/<HASH>/{public,private}
# written above — confirm in leaderboard_eval.sh.
!sh leaderboard_eval.sh {HASH}

Model: 98a504d6edb24e2e81ab5c09970abc08

Leaderboard SSIM : 0.8920
Leaderboard SSIM (public): 0.8993
Leaderboard SSIM (private): 0.8847
