In [1]:
%load_ext autoreload
%autoreload 2

import os

# Switch to the project checkout before importing from `src` below; the relative
# paths used later in this notebook ('models/...', 'logs/...') are resolved
# against this working directory.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
os.chdir('/data/core-rad/tobweber/bernoulli-mri')

import torch
from src.run import run_dataset_optim

In [2]:
from src.losses import SegmentationProxyLoss
from torch import nn
from monai.networks.nets import UNet
from monai.losses import DiceCELoss

# 2-D U-Net (4 input channels, 4 output channels) that the proxy loss evaluates.
unet_kwargs = dict(
    spatial_dims=2,
    in_channels=4,
    out_channels=4,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)
model = UNet(**unet_kwargs).cuda()

# Restore the weights stored under the checkpoint's 'model' entry.
sd = torch.load('models/brain_base.pt')
model.load_state_dict(sd['model'])

# Inner segmentation criterion; it is wrapped together with the network into
# the proxy loss that is later handed to `run_dataset_optim`.
seg_loss_func = DiceCELoss(softmax=True, include_background=False, to_onehot_y=True)
loss_func = SegmentationProxyLoss(model=model, seg_loss_func=seg_loss_func).cuda()

In [3]:
# Base configuration passed to `run_dataset_optim`. The ensemble loop further
# down overrides 'dense_target' and 'log_dir' for every individual run.
cfg = dict(
    dataset='brain',
    dataset_root='/data/core-rad/data',
    batch_size=32,
    steps=2500,
    use_seg=True,
    learning_rate=1e-3,
    bern_samples=4,
    mask_style='h',
    num_workers=32,
    dense_target=1 / 8,
    dense_start=0.10,
    dense_end=0.90,
    # Falls back to the CPU when no CUDA device is available.
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    log_dir='logs/brats_ensemble_8',
    log_imgs=10,
    log_interval=100,
    seed=None,
)

In [4]:
# Run the mask optimisation `num_members` times for every acceleration factor.
num_members = 10
acc_facs = [8, 16, 32]

for acc_fac in acc_facs:
    for member in range(1, num_members + 1):
        # The target density is the reciprocal of the acceleration factor, and
        # every (factor, member) pair writes to its own log directory.
        cfg.update(
            dense_target=1 / acc_fac,
            log_dir=os.path.join('logs', f'brats_ensemble_proxy_1d_{acc_fac}_m{member}'),
        )

        run_dataset_optim(cfg, loss_func)

L: 8.40E-01 | D: 0.125: 100%|█████████████| 2500/2500 [05:44<00:00,  7.26it/s]
L: 8.38E-01 | D: 0.125: 100%|█████████████| 2500/2500 [05:33<00:00,  7.50it/s]
L: 9.39E-01 | D: 0.125: 100%|█████████████| 2500/2500 [05:30<00:00,  7.56it/s]
L: 8.17E-01 | D: 0.125: 100%|█████████████| 2500/2500 [05:33<00:00,  7.49it/s]
L: 7.09E-01 | D: 0.125: 100%|█████████████| 2500/2500 [05:33<00:00,  7.50it/s]
L: 7.67E-01 | D: 0.125: 100%|█████████████| 2500/2500 [05:35<00:00,  7.44it/s]
L: 7.13E-01 | D: 0.125: 100%|█████████████| 2500/2500 [05:32<00:00,  7.52it/s]
L: 8.00E-01 | D: 0.125: 100%|█████████████| 2500/2500 [05:37<00:00,  7.42it/s]
L: 8.97E-01 | D: 0.125: 100%|█████████████| 2500/2500 [05:31<00:00,  7.55it/s]
L: 8.32E-01 | D: 0.125: 100%|█████████████| 2500/2500 [05:29<00:00,  7.59it/s]
L: 9.85E-01 | D: 0.063: 100%|█████████████| 2500/2500 [05:36<00:00,  7.43it/s]
L: 1.02E+00 | D: 0.063: 100%|█████████████| 2500/2500 [05:40<00:00,  7.35it/s]
L: 9.67E-01 | D: 0.063: 100%|█████████████| 2500/250

In [5]:
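# Disabled variant of the ensemble loop above: it logs to
# 'logs/brats_ensemble_<acc_fac>_m<i>' and calls run_dataset_optim(cfg)
# without passing a custom loss function.
# num_members = 10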
# acc_facs = [8, 16, 32,]

# for acc_fac in acc_facs:
#     for i in range(1, 1 + num_members):
#         cfg['dense_target'] = 1 / acc_fac
#         cfg['log_dir'] = os.path.join('logs', 'brats_ensemble_' + str(acc_fac) + '_m' + str(i))
    
#         run_dataset_optim(cfg)

# Compute Expected Mask

In [6]:
import os
import matplotlib.pyplot as plt
from src.utils import get_top_k_mask


# For every acceleration factor: load the last entry of each ensemble member's
# 'scores' list, sum them into a single score map, display the top-k mask
# derived from it, and store the summed scores under a combined log directory.
#
# Loading goes through cfg['device'] rather than an unconditional `.cuda()`, so
# this cell honours the same CUDA/CPU fallback as the optimisation config and
# also works on hosts without a GPU.
device = cfg['device']

for acc_fac in acc_facs:
    path_stem = f'logs/brats_ensemble_proxy_1d_{acc_fac}_m'
    paths = [f'{path_stem}{i}/results.pt' for i in range(1, num_members + 1)]

    # map_location remaps tensors that were saved from a GPU onto `device`.
    scores = [
        torch.load(f, map_location=device)['scores'][-1].to(device)
        for f in paths
    ]

    # Concatenate the members along dim 0, then reduce over the first two dims.
    scores_sum = torch.sum(torch.cat(scores), dim=(0, 1))

    mask = get_top_k_mask(scores_sum, acc_fac)
    plt.imshow(mask.cpu(), cmap='gray')
    # Several figures are produced in this loop; label each one.
    plt.title(f'Expected mask (acceleration factor {acc_fac})')
    plt.show()

    new_path = f'logs/brats_ensemble_proxy_1d_{acc_fac}'
    os.makedirs(new_path, exist_ok=True)
    # Saved in the same {'scores': [...]} layout that is read above.
    torch.save({
        'scores': [scores_sum.cpu()]
    }, os.path.join(new_path, 'results.pt'))