In [1]:
# Reload imported modules automatically before each cell runs, so edits to the
# project sources are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import os

# Switch the working directory; later cells load 'models/...' and 'logs/...'
# through relative paths.
# NOTE(review): presumably this is also what makes the `src` package importable
# below — confirm.
os.chdir('/data/core-rad/tobweber/bernoulli-mri')

import torch
from torch.utils.data import DataLoader
from src.datasets import ACDCDataset
from monai.metrics import DiceMetric, MeanIoU
from monai.networks.nets import UNet
from monai.networks import one_hot
from src.distribution import SoftBernoulliSampler
from src.utils import ifft2c
import matplotlib.pyplot as plt
from src.utils import get_top_k_mask

# ACDC

In [2]:
# Evaluation data: ACDC with train=False, read in fixed order (the DataLoader
# is not asked to shuffle) in batches of 64.
ds = ACDCDataset('/data/core-rad/data/ACDC', train=False)
dl = DataLoader(ds, batch_size=64, num_workers=8)

# 2D U-Net with one input channel and four output classes.
model = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=4,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).cuda()

# Deserialize onto the CPU: without map_location, torch.load tries to restore
# every tensor on the device it was saved from and fails when that device is
# not available here. load_state_dict copies the values into the CUDA
# parameters of `model` regardless of where the checkpoint tensors live.
sd = torch.load('models/acdc_base.pt', map_location='cpu')
model.load_state_dict(sd['model'])

# Pure inference: eval mode and no gradient tracking on any parameter.
model.eval()
model.requires_grad_(False)

# Both metrics ignore the background channel (class 0).
dice_metric = DiceMetric(include_background=False)
iou_metric = MeanIoU(include_background=False)

sampler = SoftBernoulliSampler()

# Collects one (dice, iou) tuple per evaluated configuration. Re-running this
# cell discards everything gathered so far.
r = {}

## Fully Sampled

In [3]:
# Reference run: segment the images exactly as the dataset provides them.
for batch in dl:
    logits = model(batch['img'].cuda())

    # Hard labels per pixel, then one-hot maps for prediction and ground truth.
    pred = one_hot(logits.argmax(dim=1, keepdim=True), num_classes=4)
    seg = one_hot(batch['seg'].cuda(), num_classes=4)

    dice_metric(y_pred=pred, y=seg)
    iou_metric(y_pred=pred, y=seg)

# Reduce over everything accumulated above, then clear the metric buffers so
# the next evaluation starts from scratch.
dice_res, iou_res = dice_metric.aggregate(), iou_metric.aggregate()
dice_metric.reset()
iou_metric.reset()

dice_score = float(dice_res.mean(dim=0).cpu())
iou_score = float(iou_res.mean(dim=0).cpu())
m = (dice_score, iou_score)
r['full'] = m

print('Dice Score:', m[0])
print('IoU Score:', m[1])


Dice Score: 0.8549136519432068
IoU Score: 0.7632980942726135


## Equispaced

In [4]:
from src.mask_patterns import EquiSpacedMaskFunc

acc_facs = [8, 16, 32]

for a in acc_facs:
    # One mask per acceleration factor, built from EquiSpacedMaskFunc with the
    # fraction argument fixed to 0.04.
    mask_fn = EquiSpacedMaskFunc([0.04], [a])
    mask = mask_fn((1, 256, 256))[0].cuda()
    # Drop dim 2, repeat along a leading dimension of size 256, then prepend
    # two singleton dimensions so the mask broadcasts over the batch.
    mask = mask.squeeze(2).expand(256, -1)[None, None]

    for batch in dl:
        kspace = batch['k_space'].cuda()

        # Undersample k-space, transform back and keep the magnitude image.
        recon = ifft2c(kspace * mask + 0.0)
        logits = model(torch.abs(recon))

        pred = one_hot(logits.argmax(dim=1, keepdim=True), num_classes=4)
        seg = one_hot(batch['seg'].cuda(), num_classes=4)

        dice_metric(y_pred=pred, y=seg)
        iou_metric(y_pred=pred, y=seg)

    # Aggregate this factor's scores, then clear the buffers for the next one.
    dice_res, iou_res = dice_metric.aggregate(), iou_metric.aggregate()
    dice_metric.reset()
    iou_metric.reset()

    dice_score = float(dice_res.mean(dim=0).cpu())
    iou_score = float(iou_res.mean(dim=0).cpu())
    m = (dice_score, iou_score)
    r['equi_' + str(a)] = m

    print('Acc. Factor:', a)
    print('Dice Score:', m[0])
    print('IoU Score:', m[1])
    

Acc. Factor: 8
Dice Score: 0.665573000907898
IoU Score: 0.5406851172447205
Acc. Factor: 16
Dice Score: 0.671248733997345
IoU Score: 0.5460855960845947
Acc. Factor: 32
Dice Score: 0.6418558359146118
IoU Score: 0.5142226219177246


## 2D Variable Density

In [3]:
from src.mask_patterns import get_2d_variable_density_mask

In [19]:
acc_facs = [8, 16, 32]

for a in acc_facs:
    for batch in dl:
        # The mask is rebuilt for every batch rather than once per factor.
        # NOTE(review): if get_2d_variable_density_mask draws a random pattern,
        # each batch is evaluated under a different mask and reruns will not
        # reproduce the same scores without a fixed seed — confirm.
        mask = get_2d_variable_density_mask(256, a, 25).cuda()[None, None]

        kspace = batch['k_space'].cuda()

        # Undersample k-space, transform back and keep the magnitude image.
        recon = ifft2c(kspace * mask + 0.0)
        logits = model(torch.abs(recon))

        pred = one_hot(logits.argmax(dim=1, keepdim=True), num_classes=4)
        seg = one_hot(batch['seg'].cuda(), num_classes=4)

        dice_metric(y_pred=pred, y=seg)
        iou_metric(y_pred=pred, y=seg)

    # Aggregate this factor's scores, then clear the buffers for the next one.
    dice_res, iou_res = dice_metric.aggregate(), iou_metric.aggregate()
    dice_metric.reset()
    iou_metric.reset()

    dice_score = float(dice_res.mean(dim=0).cpu())
    iou_score = float(iou_res.mean(dim=0).cpu())
    m = (dice_score, iou_score)
    r['var_dens_' + str(a)] = m

    print('Acc. Factor:', a)
    print('Dice Score:', m[0])
    print('IoU Score:', m[1])

Acc. Factor: 8
Dice Score: 0.8441641330718994
IoU Score: 0.7484228610992432
Acc. Factor: 16
Dice Score: 0.7152139544487
IoU Score: 0.5969336628913879
Acc. Factor: 32
Dice Score: 0.3585960865020752
IoU Score: 0.262611985206604


## BERM 2D PROXY

In [7]:
# Alternative run directories (disabled):
#result_dirs = ['acdc_8_proxy', 'acdc_16_proxy', 'acdc_32_proxy']
result_dirs = ['acdc_ensemble_proxy_8', 'acdc_ensemble_proxy_16', 'acdc_ensemble_proxy_32']
acc_facs = [8, 16, 32]
result_paths = [os.path.join('logs', d, 'results.pt') for d in result_dirs]
# Take the last entry of each run's 'scores'. Deserialize on the CPU first:
# without map_location, torch.load restores tensors on the device they were
# saved from and fails if that device does not exist on this machine.
scores = [torch.load(f, map_location='cpu')['scores'][-1].cuda() for f in result_paths]

for score, d, a in zip(scores, result_dirs, acc_facs):
    # One mask per run, derived from its score map and acceleration factor.
    mask = get_top_k_mask(score.squeeze(), a)

    for batch in dl:
        img_k = batch['k_space'].cuda()
        seg = batch['seg'].cuda()

        # Undersample k-space, transform back and keep the magnitude image.
        img_pred = ifft2c(img_k * mask + 0.0)
        img_mag = torch.abs(img_pred)

        pred = model(img_mag)

        # Hard labels per pixel, then one-hot maps for both metric inputs.
        pred = one_hot(torch.argmax(pred, dim=1).unsqueeze(1), num_classes=4)
        seg = one_hot(seg, num_classes=4)

        dice_metric(y_pred=pred, y=seg)
        iou_metric(y_pred=pred, y=seg)

    # Aggregate this run's scores, then clear the buffers for the next one.
    dice_res = dice_metric.aggregate()
    iou_res = iou_metric.aggregate()
    dice_metric.reset()
    iou_metric.reset()

    m = float(dice_res.mean(dim=0).cpu()), float(iou_res.mean(dim=0).cpu())
    r.update({d: m})

    print('Run:', d)
    print('Dice Score:', m[0])
    print('IoU Score:', m[1])

Run: acdc_ensemble_proxy_8
Dice Score: 0.8379948139190674
IoU Score: 0.7407156825065613
Run: acdc_ensemble_proxy_16
Dice Score: 0.786157488822937
IoU Score: 0.6751657128334045
Run: acdc_ensemble_proxy_32
Dice Score: 0.7242296934127808
IoU Score: 0.6035035252571106


## BERM 1D PROXY

In [8]:
# Alternative run directories (disabled):
#result_dirs = ['acdc_8_proxy_1d', 'acdc_16_proxy_1d', 'acdc_32_proxy_1d']
result_dirs = ['acdc_ensemble_proxy_1d_8', 'acdc_ensemble_proxy_1d_16', 'acdc_ensemble_proxy_1d_32']
acc_facs = [8, 16, 32]
result_paths = [os.path.join('logs', d, 'results.pt') for d in result_dirs]
# Take the last entry of each run's 'scores'. Deserialize on the CPU first:
# without map_location, torch.load restores tensors on the device they were
# saved from and fails if that device does not exist on this machine.
scores = [torch.load(f, map_location='cpu')['scores'][-1].cuda() for f in result_paths]

for score, d, a in zip(scores, result_dirs, acc_facs):
    # One mask per run, derived from its score map and acceleration factor.
    mask = get_top_k_mask(score.squeeze(), a)

    for batch in dl:
        img_k = batch['k_space'].cuda()
        seg = batch['seg'].cuda()

        # Undersample k-space, transform back and keep the magnitude image.
        img_pred = ifft2c(img_k * mask + 0.0)
        img_mag = torch.abs(img_pred)

        pred = model(img_mag)

        # Hard labels per pixel, then one-hot maps for both metric inputs.
        pred = one_hot(torch.argmax(pred, dim=1).unsqueeze(1), num_classes=4)
        seg = one_hot(seg, num_classes=4)

        dice_metric(y_pred=pred, y=seg)
        iou_metric(y_pred=pred, y=seg)

    # Aggregate this run's scores, then clear the buffers for the next one.
    dice_res = dice_metric.aggregate()
    iou_res = iou_metric.aggregate()
    dice_metric.reset()
    iou_metric.reset()

    m = float(dice_res.mean(dim=0).cpu()), float(iou_res.mean(dim=0).cpu())
    r.update({d: m})

    print('Run:', d)
    print('Dice Score:', m[0])
    print('IoU Score:', m[1])

Run: acdc_ensemble_proxy_1d_8
Dice Score: 0.8133999705314636
IoU Score: 0.7085737586021423
Run: acdc_ensemble_proxy_1d_16
Dice Score: 0.7340726256370544
IoU Score: 0.6145088076591492
Run: acdc_ensemble_proxy_1d_32
Dice Score: 0.5987941026687622
IoU Score: 0.47245970368385315


## IGS Sampling

In [9]:
from src.igs import IGS

# Deserialize on the CPU before moving to the GPU: without map_location,
# torch.load restores tensors on the device they were saved from and fails if
# that device does not exist on this machine.
masks = torch.load('logs/IGS/igs_acdc_seg.pt', map_location='cpu').cuda()

acc_facs = [8, 16, 32]
ns = [IGS.get_n(acc_fac=a, img_size=256) for a in acc_facs]

for n, a in zip(ns, acc_facs):
    # NOTE(review): the `n - 2` offset presumably mirrors how the masks were
    # stored by the IGS run — confirm against src.igs.
    mask = masks[n - 2].unsqueeze(0)
    # Repeat along a leading dimension of size 256, then prepend two singleton
    # dimensions so the mask broadcasts over the batch.
    mask = mask.expand(256, -1)
    mask = mask.unsqueeze(0).unsqueeze(0)

    for batch in dl:
        img_k = batch['k_space'].cuda()
        seg = batch['seg'].cuda()

        # Undersample k-space, transform back and keep the magnitude image.
        img_pred = ifft2c(img_k * mask + 0.0)
        img_mag = torch.abs(img_pred)

        pred = model(img_mag)

        # Hard labels per pixel, then one-hot maps for both metric inputs.
        pred = one_hot(torch.argmax(pred, dim=1).unsqueeze(1), num_classes=4)
        seg = one_hot(seg, num_classes=4)

        dice_metric(y_pred=pred, y=seg)
        iou_metric(y_pred=pred, y=seg)

    # Aggregate this factor's scores, then clear the buffers for the next one.
    dice_res = dice_metric.aggregate()
    iou_res = iou_metric.aggregate()
    dice_metric.reset()
    iou_metric.reset()

    m = float(dice_res.mean(dim=0).cpu()), float(iou_res.mean(dim=0).cpu())
    # The key is written without an underscore before the factor ('igs_proxy8');
    # it is kept as is because stored results use this spelling.
    r.update({'igs_proxy' + str(a): m})

    print('Acc. Factor:', a)
    print('Dice Score:', m[0])
    print('IoU Score:', m[1])
    

Acc. Factor: 8
Dice Score: 0.8276612758636475
IoU Score: 0.7264233827590942
Acc. Factor: 16
Dice Score: 0.7434073090553284
IoU Score: 0.6251413822174072
Acc. Factor: 32
Dice Score: 0.5987941026687622
IoU Score: 0.47245970368385315


In [10]:
# Display the collected results: one (dice, iou) tuple per evaluated configuration.
r

{'full': (0.8549136519432068, 0.7632980942726135),
 'equi_8': (0.665573000907898, 0.5406851172447205),
 'equi_16': (0.671248733997345, 0.5460855960845947),
 'equi_32': (0.6418558359146118, 0.5142226219177246),
 'var_dens_8': (0.3991642892360687, 0.2990240454673767),
 'var_dens_16': (0.37987008690834045, 0.28197941184043884),
 'var_dens_32': (0.12038853019475937, 0.07972479611635208),
 'acdc_ensemble_proxy_8': (0.8379948139190674, 0.7407156825065613),
 'acdc_ensemble_proxy_16': (0.786157488822937, 0.6751657128334045),
 'acdc_ensemble_proxy_32': (0.7242296934127808, 0.6035035252571106),
 'acdc_ensemble_proxy_1d_8': (0.8133999705314636, 0.7085737586021423),
 'acdc_ensemble_proxy_1d_16': (0.7340726256370544, 0.6145088076591492),
 'acdc_ensemble_proxy_1d_32': (0.5987941026687622, 0.47245970368385315),
 'igs_proxy8': (0.8276612758636475, 0.7264233827590942),
 'igs_proxy16': (0.7434073090553284, 0.6251413822174072),
 'igs_proxy32': (0.5987941026687622, 0.47245970368385315)}