In [1]:
%load_ext autoreload
%autoreload 2

import os
from typing import Optional, Tuple

# Work from the project root so the relative paths used later in this notebook
# ('models/acdc_base.pt', 'logs_final/...') resolve; presumably this is also what makes
# the `src.*` imports below importable from the notebook kernel — confirm.
# NOTE(review): hard-coded, machine-specific absolute path; consider reading it from an env var.
os.chdir('/data/core-rad/tobweber/bernoulli-mri')

import torch
from torch import Tensor
from torch.utils.data import DataLoader
from monai.metrics import DiceMetric, MeanIoU
from monai.networks.nets import UNet
from monai.networks import one_hot
import pandas as pd

from src.datasets import ACDCDataset
from src.utils import ifft2c
from src.utils import get_top_k_mask

# ACDC

In [9]:
# Held-out ACDC split (train=False), evaluated in batches of 64.
ds = ACDCDataset('/data/core-rad/data/ACDC', train=False)
dl = DataLoader(ds, batch_size=64, num_workers=8)

# 2D UNet with one input channel and four output channels; the four outputs must agree
# with the 4-class one-hot encoding used by get_metrics.
model = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=4,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)
model = model.cuda()

# Restore the pretrained weights; the checkpoint stores the state dict under 'model'.
sd = torch.load('models/acdc_base.pt')
model.load_state_dict(sd['model'])

# Inference only: switch to eval mode and freeze every parameter.
model.eval()
model.requires_grad_(False)

# Column-oriented result table; each evaluated configuration appends one value per column.
result = {column: [] for column in ('dataset', 'method', 'acc_fac', 'dice', 'iou')}

NameError: name 'ACDCDataset' is not defined

In [None]:
def get_metrics(
    dl: DataLoader,
    model: UNet,
    mask: Optional[Tensor] = None,
    num_classes: int = 4,
) -> Tuple[float, float]:
    """Score ``model`` on every batch of ``dl`` and return the mean (Dice, IoU).

    With ``mask`` given, the network input is reconstructed from undersampled
    k-space: ``batch['k_space']`` is multiplied by ``mask``, passed through
    ``ifft2c`` and reduced to its magnitude. Without a mask, ``batch['img']`` is
    fed to the model directly.

    Predictions are the per-pixel argmax over the model's output channels. Both
    the prediction and ``batch['seg']`` are one-hot encoded with ``num_classes``
    channels. The background channel is excluded from both metrics.

    Args:
        dl: Loader yielding dicts with key ``'seg'`` and either ``'img'`` or
            ``'k_space'`` (the latter only when ``mask`` is given).
        model: Segmentation network. It is NOT switched to eval mode here;
            callers are expected to have done so.
        mask: Optional sampling mask multiplied into the k-space batch. It must
            already live on the model's device.
        num_classes: Number of segmentation classes, i.e. model output channels.
            The default of 4 matches the ACDC UNet used in this notebook.

    Returns:
        ``(dice, iou)`` as Python floats, using MONAI's default mean reduction.
    """
    # Follow the model's device instead of hard-coding CUDA.
    device = next(model.parameters()).device
    dice_metric = DiceMetric(include_background=False)
    iou_metric = MeanIoU(include_background=False)

    # Pure evaluation: don't build autograd graphs even if the caller forgot to freeze the model.
    with torch.no_grad():
        for batch in dl:
            if mask is not None:
                k_space = batch['k_space'].to(device)
                # NOTE(review): the `+ 0.0` is kept from the original; its purpose is not evident here.
                img = torch.abs(ifft2c(k_space * mask + 0.0))
            else:
                img = batch['img'].to(device)

            seg = batch['seg'].to(device)
            logits = model(img)

            # keepdim=True keeps the channel axis that one_hot expects.
            pred = one_hot(torch.argmax(logits, dim=1, keepdim=True), num_classes=num_classes)
            seg = one_hot(seg, num_classes=num_classes)

            dice_metric(y_pred=pred, y=seg)
            iou_metric(y_pred=pred, y=seg)

    # aggregate() already applies the metrics' mean reduction; .mean() collapses any leftover dims.
    dice = dice_metric.aggregate().mean().item()
    iou = iou_metric.aggregate().mean().item()
    return dice, iou

## Fully Sampled

In [3]:
dice, iou = get_metrics(dl, model)

# Reference row: images are used as-is (no k-space mask), recorded as acceleration factor 1.
row = {'dataset': 'acdc', 'method': 'full', 'acc_fac': 1, 'dice': dice, 'iou': iou}
for column, value in row.items():
    result[column].append(value)

print('Dice Score:', dice)
print('IoU Score:', iou)

Dice Score: 0.8549136519432068
IoU Score: 0.7632980942726135


## Equispaced

In [4]:
from src.mask_patterns import EquiSpacedMaskFunc

num_runs = 10
acc_facs = [8, 16, 32]

# NOTE(review): no seed is passed to EquiSpacedMaskFunc; if it draws a random offset
# (as fastMRI's version does) the repeated runs are not reproducible — confirm.
for _ in range(num_runs):
    for a in acc_facs:
        # [0.04] / [a] are presumably the centre-fraction and acceleration lists; element 0
        # of the returned value is the mask, with shape (1, 256, 1) for this input shape (assumed).
        mask = EquiSpacedMaskFunc([0.04], [a])((1, 256, 256))[0].cuda()
        # Drop the trailing singleton dim, repeat the pattern over all 256 rows,
        # then add two leading singleton dims.
        mask = mask.squeeze(2).expand(256, -1)[None, None]

        dice, iou = get_metrics(dl, model, mask)

        row = {'dataset': 'acdc', 'method': 'equi', 'acc_fac': a, 'dice': dice, 'iou': iou}
        for column, value in row.items():
            result[column].append(value)

Acc. Factor: 8
Dice Score: 0.665573000907898
IoU Score: 0.5406851172447205
Acc. Factor: 16
Dice Score: 0.671248733997345
IoU Score: 0.5460855960845947
Acc. Factor: 32
Dice Score: 0.6418558359146118
IoU Score: 0.5142226219177246


## 2D Variable Density

In [3]:
from src.mask_patterns import get_2d_variable_density_mask

num_runs = 10
acc_facs = [8, 16, 32]

# NOTE(review): 256 is presumably the image size; the meaning of the third argument (30)
# is not visible here. No seed is set, so if the generator is random the runs differ.
for _ in range(num_runs):
    for a in acc_facs:
        # Two leading singleton dims so the mask broadcasts against the k-space batch.
        mask = get_2d_variable_density_mask(256, a, 30).cuda()[None, None]

        dice, iou = get_metrics(dl, model, mask)

        row = {'dataset': 'acdc', 'method': 'var_dens', 'acc_fac': a, 'dice': dice, 'iou': iou}
        for column, value in row.items():
            result[column].append(value)

## PROM

In [10]:
num_runs = 10
dims = ['1d', '2d']
acc_facs = [8, 16, 32]

# Learned masks: for every run, mask dimensionality and acceleration factor, load the
# checkpoint, take the last entry of its 'scores' and convert it with get_top_k_mask.
for run_idx in range(1, num_runs + 1):
    for d in dims:
        for a in acc_facs:
            path = f'logs_final/acdc_{d}_a{a}/acdc_r{run_idx}.pt'
            score = torch.load(path)['scores'][-1].cuda()
            mask = get_top_k_mask(score.squeeze(), a)

            dice, iou = get_metrics(dl, model, mask)

            row = {'dataset': 'acdc', 'method': f'prom_{d}', 'acc_fac': a, 'dice': dice, 'iou': iou}
            for column, value in row.items():
                result[column].append(value)

NameError: name 'prefix' is not defined

## IGS Sampling

In [9]:
from src.igs import IGS

num_runs = 10
acc_facs = [8, 16, 32]

# Per-acceleration line counts from IGS.get_n (256 = image size). They do not depend on
# the run, so compute them once instead of once per run.
ns = [IGS.get_n(acc_fac=a, img_size=256) for a in acc_facs]

for run_idx in range(1, num_runs + 1):
    # Precomputed IGS mask stack for this run.
    masks = torch.load(f'logs_final/IGS/igs_acdc_seg_{run_idx}.pt').cuda()

    for n, a in zip(ns, acc_facs):
        # NOTE(review): the `n - 2` offset into the mask stack is not explained here —
        # confirm against how the IGS logs are indexed.
        # Repeat the 1D pattern over all 256 rows, then add two leading singleton dims.
        mask = masks[n - 2].unsqueeze(0).expand(256, -1)[None, None]

        dice, iou = get_metrics(dl, model, mask)

        row = {'dataset': 'acdc', 'method': 'igs', 'acc_fac': a, 'dice': dice, 'iou': iou}
        for column, value in row.items():
            result[column].append(value)
    

Acc. Factor: 8
Dice Score: 0.8276612758636475
IoU Score: 0.7264233827590942
Acc. Factor: 16
Dice Score: 0.7434073090553284
IoU Score: 0.6251413822174072
Acc. Factor: 32
Dice Score: 0.5987941026687622
IoU Score: 0.47245970368385315


In [10]:
# Long-format table: one row per (run, method, acceleration factor) evaluation.
df = pd.DataFrame(result)

# Showing every run would print ~150 rows; summarise instead. Shows mean and std over runs per
# configuration, in the order the configurations were evaluated (sort=False).
# The single fully-sampled row has NaN std. Per-run rows remain available in `df`.
summary = (
    df.groupby(['dataset', 'method', 'acc_fac'], sort=False)[['dice', 'iou']]
    .agg(['mean', 'std'])
)
summary

{'full': (0.8549136519432068, 0.7632980942726135),
 'equi_8': (0.665573000907898, 0.5406851172447205),
 'equi_16': (0.671248733997345, 0.5460855960845947),
 'equi_32': (0.6418558359146118, 0.5142226219177246),
 'var_dens_8': (0.3991642892360687, 0.2990240454673767),
 'var_dens_16': (0.37987008690834045, 0.28197941184043884),
 'var_dens_32': (0.12038853019475937, 0.07972479611635208),
 'acdc_ensemble_proxy_8': (0.8379948139190674, 0.7407156825065613),
 'acdc_ensemble_proxy_16': (0.786157488822937, 0.6751657128334045),
 'acdc_ensemble_proxy_32': (0.7242296934127808, 0.6035035252571106),
 'acdc_ensemble_proxy_1d_8': (0.8133999705314636, 0.7085737586021423),
 'acdc_ensemble_proxy_1d_16': (0.7340726256370544, 0.6145088076591492),
 'acdc_ensemble_proxy_1d_32': (0.5987941026687622, 0.47245970368385315),
 'igs_proxy8': (0.8276612758636475, 0.7264233827590942),
 'igs_proxy16': (0.7434073090553284, 0.6251413822174072),
 'igs_proxy32': (0.5987941026687622, 0.47245970368385315)}