In [1]:
%load_ext autoreload
%autoreload 2

import os
from typing import Optional, Tuple

os.chdir('/data/core-rad/tobweber/bernoulli-mri')

import torch
from torch import Tensor
from torch.utils.data import DataLoader
from monai.metrics import DiceMetric, MeanIoU
from monai.networks.nets import UNet
from monai.networks import one_hot
import pandas as pd
from tqdm import tqdm

from src.datasets import ACDCDataset
from src.utils import ifft2c
from src.utils import get_top_k_mask

In the future `np.bool` will be defined as the corresponding NumPy scalar.  (This may have returned Python scalars in past versions.
In the future `np.bool` will be defined as the corresponding NumPy scalar.  (This may have returned Python scalars in past versions.
In the future `np.bool` will be defined as the corresponding NumPy scalar.  (This may have returned Python scalars in past versions.
In the future `np.bool` will be defined as the corresponding NumPy scalar.  (This may have returned Python scalars in past versions.


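# ACDC

Evaluate a pretrained 2D UNet on the held-out ACDC split (`train=False`) and report Dice / IoU for fully sampled images and for several k-space undersampling masks (equispaced, 2D variable density, PROM, IGS) at 8x, 16x and 32x acceleration.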

In [2]:
# Held-out ACDC split (train=False), batched for evaluation.
ds = ACDCDataset('/data/core-rad/data/ACDC', train=False)
dl = DataLoader(ds, batch_size=64, num_workers=16, persistent_workers=True)

# 2D UNet with 1 input channel and 4 output channels (one per class;
# get_metrics ignores channel 0 via include_background=False).
model = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=4,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).cuda()

# Read the checkpoint on CPU; load_state_dict copies the weights into the CUDA model.
sd = torch.load('models/acdc_unet.pt', map_location='cpu')
model.load_state_dict(sd['model'])

# Evaluation only: eval mode and all parameters frozen.
model.eval()
model.requires_grad_(False)

# Column-oriented store, one list per column; converted to a DataFrame at the end.
result = {column: [] for column in ('dataset', 'method', 'acc_fac', 'dice', 'iou')}

In [3]:
def get_metrics(
    dl: DataLoader,
    model: UNet,
    mask: Optional[Tensor] = None,
    num_classes: Optional[int] = None,
) -> Tuple[float, float]:
    """Evaluate a segmentation model over ``dl`` and return ``(dice, iou)``.

    With ``mask`` given, each batch's ``'k_space'`` is multiplied by the mask,
    passed through ``ifft2c`` (presumably a centered 2D inverse FFT) and reduced
    to its magnitude before being fed to the model. Otherwise ``batch['img']``
    is used directly.

    Args:
        dl: Loader yielding dicts with ``'seg'`` plus ``'img'`` or ``'k_space'``.
        model: Segmentation network whose output holds per-class scores on dim 1.
        mask: Optional k-space sampling mask, broadcastable against
            ``batch['k_space']``. Moved to the model's device once.
        num_classes: Classes used for one-hot encoding. Defaults to the number
            of output channels produced by ``model`` (dim 1 of its output).

    Returns:
        Mean Dice and mean IoU with the first (background) channel excluded,
        reduced with MONAI's default reduction and then averaged over dim 0.
    """
    # Follow the model's device instead of assuming CUDA.
    device = next(model.parameters()).device
    if mask is not None:
        mask = mask.to(device)

    dice_metric = DiceMetric(include_background=False)
    iou_metric = MeanIoU(include_background=False)

    # Pure evaluation: never build an autograd graph, even if the caller
    # did not freeze the model's parameters.
    with torch.no_grad():
        for batch in tqdm(dl, leave=False):
            if mask is not None:
                k_space = batch['k_space'].to(device)
                # `+ 0.0` kept from the original; presumably forces a floating/complex
                # result when the mask is boolean — TODO confirm.
                img = torch.abs(ifft2c(k_space * mask + 0.0))
            else:
                img = batch['img'].to(device)

            logits = model(img)
            n_cls = logits.shape[1] if num_classes is None else num_classes

            # Hard predictions -> one-hot, matching the one-hot ground truth.
            pred = one_hot(torch.argmax(logits, dim=1, keepdim=True), num_classes=n_cls)
            seg = one_hot(batch['seg'].to(device), num_classes=n_cls)

            dice_metric(y_pred=pred, y=seg)
            iou_metric(y_pred=pred, y=seg)

    dice_res = dice_metric.aggregate()
    iou_res = iou_metric.aggregate()

    return dice_res.mean(dim=0).item(), iou_res.mean(dim=0).item()

## Fully Sampled

In [4]:
# No mask: the model sees the fully sampled images (batch['img']).
dice, iou = get_metrics(dl, model)

row = {'dataset': 'acdc', 'method': 'full', 'acc_fac': 1, 'dice': dice, 'iou': iou}
for column, value in row.items():
    result[column].append(value)

print('Dice Score:', dice)
print('IoU Score:', iou)

                                                                                           

Dice Score: 0.8549136519432068
IoU Score: 0.7632980942726135




## Equispaced

In [5]:
from src.mask_patterns import EquiSpacedMaskFunc

num_runs = 10
acc_facs = [8, 16, 32]

# Visits factors in the same order as nested run/factor loops: 8, 16, 32, 8, 16, 32, ...
# Masks differ between calls (nonzero SD in the summary), hence the repeats.
# NOTE(review): no seed is set, so individual runs are not reproducible.
for a in acc_facs * num_runs:
    # Args presumably (center fractions, accelerations) on a (1, 256, 256) grid — TODO confirm.
    line_mask = EquiSpacedMaskFunc([0.04], [a])((1, 256, 256))[0].cuda()
    # Drop dim 2, replicate the pattern across 256 rows, then add two leading dims.
    mask = line_mask.squeeze(2).expand(256, -1)[None, None]

    dice, iou = get_metrics(dl, model, mask)

    row = {'dataset': 'acdc', 'method': 'equi', 'acc_fac': a, 'dice': dice, 'iou': iou}
    for column, value in row.items():
        result[column].append(value)

                                                                                           

## 2D Variable Density

In [6]:
from src.mask_patterns import get_2d_variable_density_mask

num_runs = 10
acc_facs = [8, 16, 32]

# Same visiting order as nested run/factor loops: 8, 16, 32, 8, 16, 32, ...
# NOTE(review): no seed is set; if the generator is random, runs are not reproducible.
for a in acc_facs * num_runs:
    # Args (256, a, 25): presumably image size, acceleration and a density parameter — TODO confirm.
    mask = get_2d_variable_density_mask(256, a, 25).cuda()[None, None]

    dice, iou = get_metrics(dl, model, mask)

    row = {'dataset': 'acdc', 'method': 'var_dens', 'acc_fac': a, 'dice': dice, 'iou': iou}
    for column, value in row.items():
        result[column].append(value)

                                                                                           

## PROM

In [7]:
num_runs = 10
dims = ['1d', '2d']
acc_facs = [8, 16, 32]

for run_idx in range(1, num_runs + 1):
    for d in dims:
        for a in acc_facs:
            path = f'logs/acdc_{d}_a{a}/acdc_r{run_idx}.pt'
            # Open on CPU so the file loads regardless of the device it was saved from,
            # then move the last entry of 'scores' (presumably the final learned scores) to the GPU.
            score = torch.load(path, map_location='cpu')['scores'][-1].cuda()

            # Presumably keeps the highest-scoring k-space locations for acceleration `a` — TODO confirm.
            mask = get_top_k_mask(score.squeeze(), a)

            dice, iou = get_metrics(dl, model, mask)

            result['dataset'].append('acdc')
            result['method'].append('prom_' + d)
            result['acc_fac'].append(a)
            result['dice'].append(dice)
            result['iou'].append(iou)

                                                                                           

## IGS Sampling

In [8]:
from src.igs import IGS

num_runs = 10
acc_facs = [8, 16, 32]

# Depends only on the acceleration factor, not on the run, so compute it once.
ns = [IGS.get_n(acc_fac=a, img_size=256) for a in acc_facs]

for run_idx in range(1, num_runs + 1):
    # Open on CPU so the file loads regardless of the device it was saved from.
    masks = torch.load(f'logs/IGS/igs_acdc_seg_{run_idx}.pt', map_location='cpu').cuda()

    for n, a in zip(ns, acc_facs):
        # NOTE(review): the `n - 2` offset presumably maps `n` to the row index
        # in `masks`; confirm against src.igs.
        mask = masks[n - 2].unsqueeze(0)
        # Replicate the 1D pattern across 256 rows, then add two leading dims.
        mask = mask.expand(256, -1)
        mask = mask.unsqueeze(0).unsqueeze(0)

        dice, iou = get_metrics(dl, model, mask)

        result['dataset'].append('acdc')
        result['method'].append('igs')
        result['acc_fac'].append(a)
        result['dice'].append(dice)
        result['iou'].append(iou)
    

                                                                                           

In [9]:
# One row per evaluated configuration/run.
df = pd.DataFrame.from_dict(result)

df.to_csv('logs/acdc.csv')

# Last expression of the cell, so Jupyter actually renders the preview.
df.head()

In [10]:
# Mean and sample SD of each metric per (method, acceleration factor).
# A group with a single row (e.g. 'full') has SD = nan.
for group_key, group_df in df.groupby(['method', 'acc_fac']):
    dice_col = group_df['dice']
    iou_col = group_df['iou']

    print('GROUP:', group_key)
    print('DICE MEAN {:.3f} | SD {:.3f}'.format(dice_col.mean(), dice_col.std()))
    print('IOU  MEAN {:.3f} | SD {:.3f}'.format(iou_col.mean(), iou_col.std()))
    print('------------------------------------')

GROUP: ('equi', 8)
DICE MEAN 0.671 | SD 0.015
IOU  MEAN 0.546 | SD 0.016
------------------------------------
GROUP: ('equi', 16)
DICE MEAN 0.645 | SD 0.011
IOU  MEAN 0.517 | SD 0.012
------------------------------------
GROUP: ('equi', 32)
DICE MEAN 0.644 | SD 0.011
IOU  MEAN 0.517 | SD 0.011
------------------------------------
GROUP: ('full', 1)
DICE MEAN 0.855 | SD nan
IOU  MEAN 0.763 | SD nan
------------------------------------
GROUP: ('igs', 8)
DICE MEAN 0.828 | SD 0.000
IOU  MEAN 0.726 | SD 0.000
------------------------------------
GROUP: ('igs', 16)
DICE MEAN 0.745 | SD 0.003
IOU  MEAN 0.627 | SD 0.003
------------------------------------
GROUP: ('igs', 32)
DICE MEAN 0.592 | SD 0.021
IOU  MEAN 0.466 | SD 0.021
------------------------------------
GROUP: ('prom_1d', 8)
DICE MEAN 0.762 | SD 0.012
IOU  MEAN 0.650 | SD 0.013
------------------------------------
GROUP: ('prom_1d', 16)
DICE MEAN 0.717 | SD 0.018
IOU  MEAN 0.599 | SD 0.020
------------------------------------
GROUP: