In [4]:
%load_ext autoreload
%autoreload

import os
import pandas as pd
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px
import torch
import gc
import sys
import h5py
import numpy as np
from tqdm import tqdm

from omegaconf import OmegaConf
from time import time

from mridc.collections.reconstruction.data.subsample import RandomMaskFunc
from mridc.collections.reconstruction.parts import transforms
from mridc.collections.reconstruction.models.ccnn import CascadeNet
from mridc.collections.reconstruction.models.cirim import CIRIM
from mridc.collections.reconstruction.models.kikinet import KIKINet
from mridc.collections.reconstruction.models.lpd import LPDNet
from mridc.collections.reconstruction.models.unet import UNet
from mridc.collections.reconstruction.models.vn import VarNet
from tests.collections.reconstruction.fastmri.conftest import create_input
from mridc.collections.common.parts.utils import complex_conj, complex_mul, to_tensor
from mridc.collections.common.parts.fft import fft2, ifft2
from mridc.collections.reconstruction.parts.utils import apply_mask, center_crop, complex_center_crop

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
os.environ["TOOLBOX_PATH"] = "/scratch/dkarkalousos/apps/bart-0.6.00/"
sys.path.append("/scratch/dkarkalousos/apps/bart-0.6.00/python/")
import bart

In [6]:
device = "cuda"

In [7]:
output_csv = "/data/projects/recon/other/dkarkalousos/STAIRS/inference_times.csv"

In [8]:
data = h5py.File('/data/projects/recon/data/private/STAIRS/proc/Rothschild/20210803_1453_stroke/4RIM/proc/flair')
masked_kspace = to_tensor(data['kspace'][()]).unsqueeze(1)
sensitivity_map = to_tensor(data['sensitivity_map'][()]).unsqueeze(1)
eta = None
target = sensitivity_map
mask = torch.zeros_like(sensitivity_map)

In [9]:
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
gc.collect()

0

In [10]:
inf_times = []

In [15]:
cfg = {
    "recurrent_layer": "IndRNN",
    "conv_filters": [128, 128, 2],
    "conv_kernels": [5, 3, 3],
    "conv_dilations": [1, 2, 1],
    "conv_bias": [True, True, False],
    "recurrent_filters": [128, 128, 0],
    "recurrent_kernels": [1, 1, 0],
    "recurrent_dilations": [1, 1, 0],
    "recurrent_bias": [True, True, False],
    "depth": 2,
    "conv_dim": 2,
    "time_steps": 8,
    "num_cascades": 4,
    "accumulate_estimates": True,
    "no_dc": True,
    "keep_eta": True,
    "use_sens_net": False,
    "output_type": "SENSE",
    "dimensionality": 2,
    "coil_dim": 1,
    "fft_centered": False,
    "fft_normalization": "backward",
    "spatial_dims": [-2, -1],
}
cirim = CIRIM(OmegaConf.create(OmegaConf.to_container(OmegaConf.create(cfg), resolve=True))).to(device)
recon_times = []
for i in tqdm(range(masked_kspace.shape[0])):
    recon_times.append(get_ipython().run_line_magic(
        "timeit",
        "-n1 -r1 -c -q -o next(cirim.forward(masked_kspace[i].to(device), sensitivity_map[i].to(device), mask[i].to(device), None, target=target[i].to(device), phase_shift=torch.ones_like(target[i]).to(device)))"
    ).best)
inf_times.append([np.mean(recon_times), "IRIM", "212k"])
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
gc.collect()

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 432/432 [02:28<00:00,  2.91it/s]


8779

In [18]:
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
gc.collect()

0

In [19]:
masked_kspace = torch.view_as_complex(torch.fft.ifftshift(masked_kspace, dim=(-3, -2))).cpu().numpy()
sensitivity_map = torch.view_as_complex(torch.fft.ifftshift(sensitivity_map, dim=(-3, -2))).cpu().numpy()
recon_times = []
for i in tqdm(range(masked_kspace.shape[0])):
    recon_times.append(get_ipython().run_line_magic(
        "timeit",
        '-n1 -r1 -c -q -o bart.bart(1, f"pics -d0 -g -S -R W:7:0:0.05 -i 60", masked_kspace[i], sensitivity_map[i])'
    ).best)
inf_times.append([np.mean(recon_times), "PICS", "-"])
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
gc.collect()

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 432/432 [32:55<00:00,  4.57s/it]


768

In [52]:
inf_times[0][1] = 'CIRIM'
inf_times[0][0] = inf_times[0][0] * 432 # * slices
inf_times[1][0] = (inf_times[1][0] - 0.007) * 432 # substract I/O time * slices

In [53]:
try:
    os.remove(output_csv)
except OSError:
    pass

df = pd.DataFrame(inf_times, columns=["time", "Method", "Parameters"])
df.to_csv(output_csv, index=False, mode="w")

In [54]:
tmp = pd.read_csv(output_csv, header=0, index_col=False)
for m, x in zip(tmp['Method'], tmp['time']):
    print((m), "Total time : %.1f s"%(x))

CIRIM Total time : 154.3 s
PICS Total time : 674.2 s


In [58]:
df = pd.read_csv(output_csv, header=0, index_col=False)
df.sort_values('Parameters')

fig = px.scatter(
    df,
    x=df.Method,
    y=df["time"],
    color="Parameters",
    hover_data=[df.index],
    # color_discrete_sequence=["green", "red", "#FF97FF", "blue", "#AB63FA", "#FFA15A", "#B6E880", "#19D3F3", "#990099"],
)
fig.update_layout(
    title={
    'text': "Reconstruction times",
    'font': {'size': 24},
    'x': 0.45,
    'y': 0.97,
    'xanchor': 'center',
    'yanchor': 'top'},
    xaxis_title="",
    yaxis_title="Run time (sec)",
    legend_title="Parameters",
    font=dict(size=24),
)
# fig.update_traces(marker=dict(size=12))
# fig.update_yaxes(type="log", tickvals=[0.0001, 0.001, 0.01, 0.1, 1])
fig.update_yaxes(range=[100, 750], dtick=150)
fig.show()