In [1]:
%load_ext autoreload
%autoreload 2

import os

os.chdir('/data/core-rad/tobweber/bernoulli-mri')

import torch
from src.run import run_dataset_optim

In [2]:
from src.losses import SegmentationProxyLoss
from monai.networks.nets import UNet
from monai.losses import DiceCELoss

# Same CUDA-with-CPU-fallback rule as cfg['device'] in the config cell, instead of
# an unconditional .cuda() that fails on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 2D UNet (1 input channel, 4 output channels); the architecture must match the
# checkpoint loaded below.
model = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=4,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2
).to(device)

# map_location places the checkpoint tensors on `device` regardless of the device
# they were saved from.
sd = torch.load('models/acdc_base.pt', map_location=device)
model.load_state_dict(sd['model'])

# Segmentation loss evaluated on the UNet's predictions.
seg_loss_func = DiceCELoss(
    softmax=True,
    include_background=False,
    to_onehot_y=True
)

# Proxy objective combining the pretrained UNet and the segmentation loss; this is
# the `loss_func` handed to run_dataset_optim in the sweep cell.
loss_func = SegmentationProxyLoss(model=model, seg_loss_func=seg_loss_func).to(device)

In [3]:
# Base settings shared by every ensemble run. The sweep cell sets 'dense_target'
# and 'log_dir' for each run, so the values given here for those two keys are
# only placeholders.
cfg = dict(
    dataset='acdc',
    dataset_root='/data/core-rad/data',
    batch_size=32,
    steps=2500,
    use_seg=True,
    learning_rate=1e-3,
    bern_samples=4,
    mask_style='h',
    num_workers=32,
    dense_target=1 / 8,
    dense_start=0.10,
    dense_end=0.90,
    # CUDA when available, otherwise CPU.
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    log_dir='logs/acdc_ensemble_8',
    log_imgs=10,
    log_interval=100,
    # No fixed seed. NOTE(review): runs are reproducible only if
    # run_dataset_optim seeds on its own — confirm.
    seed=None,
)

In [4]:
num_members = 10
acc_facs = [8, 16, 32]
# Set to True to skip members whose results.pt already exists, e.g. to resume an
# interrupted sweep. Default False re-trains every member.
skip_existing = False

for acc_fac in acc_facs:
    for i in range(1, 1 + num_members):
        log_dir = os.path.join('logs', f'acdc_ensemble_proxy_1d_{acc_fac}_m{i}')
        if skip_existing and os.path.isfile(os.path.join(log_dir, 'results.pt')):
            continue

        # Per-run copy with the overrides applied. The shared `cfg` from the config
        # cell is left untouched, so re-running this cell does not depend on state
        # left behind by a previous run.
        run_cfg = {**cfg, 'dense_target': 1 / acc_fac, 'log_dir': log_dir}
        run_dataset_optim(run_cfg, loss_func)

L: 3.00E-01 | D: 0.125: 100%|█████████████| 2500/2500 [03:56<00:00, 10.59it/s]
L: 2.84E-01 | D: 0.125: 100%|█████████████| 2500/2500 [05:26<00:00,  7.66it/s]
L: 2.73E-01 | D: 0.125: 100%|█████████████| 2500/2500 [05:29<00:00,  7.58it/s]
L: 2.89E-01 | D: 0.125: 100%|█████████████| 2500/2500 [05:31<00:00,  7.55it/s]
L: 2.38E-01 | D: 0.125: 100%|█████████████| 2500/2500 [04:28<00:00,  9.31it/s]
L: 3.11E-01 | D: 0.125: 100%|█████████████| 2500/2500 [04:54<00:00,  8.49it/s]
L: 2.63E-01 | D: 0.125: 100%|█████████████| 2500/2500 [05:32<00:00,  7.52it/s]
L: 2.79E-01 | D: 0.125: 100%|█████████████| 2500/2500 [05:41<00:00,  7.32it/s]
L: 4.58E-01 | D: 0.133:  73%|█████████▍   | 1823/2500 [04:02<01:32,  7.34it/s]

In [5]:
# Disabled cell: the same ensemble sweep, but calling run_dataset_optim(cfg)
# without the proxy `loss_func` and writing to logs/acdc_ensemble_{acc_fac}_m{i}.
# Uncomment to run it.
# NOTE(review): as written it overwrites cfg['dense_target'] and cfg['log_dir']
# in place, so cells run afterwards see the last run's values.
# num_members = 10
# acc_facs = [8, 16, 32,]

# for acc_fac in acc_facs:
#     for i in range(1, 1 + num_members):
#         cfg['dense_target'] = 1 / acc_fac
#         cfg['log_dir'] = os.path.join('logs', 'acdc_ensemble_' + str(acc_fac) + '_m' + str(i))

#         run_dataset_optim(cfg)

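# Compute Expected Mask

For each acceleration factor, sum the final scores of all ensemble members, derive a mask from the summed scores with `get_top_k_mask`, and save the summed scores as a single combined run under `logs/acdc_ensemble_proxy_1d_{acc_fac}`.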

In [6]:
import os
import matplotlib.pyplot as plt
from src.utils import get_top_k_mask


for acc_fac in acc_facs:
    paths = [
        os.path.join('logs', f'acdc_ensemble_proxy_1d_{acc_fac}_m{i}', 'results.pt')
        for i in range(1, num_members + 1)
    ]

    # Fail with a clear message if the sweep did not finish every member, instead
    # of an opaque error from torch.load partway through.
    missing = [path for path in paths if not os.path.isfile(path)]
    if missing:
        raise FileNotFoundError(
            f'{len(missing)} of {num_members} ensemble member results missing '
            f'for acc_fac={acc_fac}: {missing}'
        )

    # Last entry of each member's 'scores' list, moved to the GPU.
    scores = [torch.load(path)['scores'][-1].cuda() for path in paths]

    # Concatenate the members along dim 0, then sum over dims 0 and 1.
    scores_sum = torch.sum(torch.cat(scores), dim=(0, 1))

    mask = get_top_k_mask(scores_sum, acc_fac)
    fig, ax = plt.subplots()
    ax.imshow(mask.cpu(), cmap='gray')
    ax.set_title(f'Ensemble mask (proxy, 1D), acceleration factor {acc_fac}')
    plt.show()
    # Close the figure explicitly since one is created per acceleration factor.
    plt.close(fig)

    # Saved as {'scores': [summed scores]}, so torch.load(...)['scores'][-1]
    # reads it the same way as a member run.
    new_path = os.path.join('logs', f'acdc_ensemble_proxy_1d_{acc_fac}')
    os.makedirs(new_path, exist_ok=True)
    torch.save({
        'scores': [scores_sum.cpu()]
    }, os.path.join(new_path, 'results.pt'))