In [15]:
%load_ext autoreload
%autoreload 2

import os

os.chdir('/data/core-rad/tobweber/bernoulli-mri')
import torch
from src.igs import IGS
from src.datasets import ACDCDataset, BrainDataset, KneeDataset

log_dir = 'logs/IGS'
os.makedirs(log_dir, exist_ok=True)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Run IGS Reconstruction Training

In [2]:
ds = ACDCDataset('/data/core-rad/data/ACDC', train=True)
igs = IGS()
w_list = igs.run(ds, acc_fac=4)

w_full = torch.stack(w_list).cpu()
torch.save(w_full, os.path.join(log_dir, 'igs_acdc_v2.pt'))

select: 163, loss: 0.008995: 100%|████████████| 64/64 [00:48<00:00,  1.32it/s]


In [3]:
ds = KneeDataset('/data/knee_fastmri', train=True)
igs = IGS()
w_list = igs.run(ds, acc_fac=4)

w_full = torch.stack(w_list).cpu() 
torch.save(w_full, os.path.join(log_dir, 'igs_knee.pt'))

select: 113, loss: 0.002925: 100%|██████████| 80/80 [12:54<00:00,  9.69s/it]


In [3]:
ds = BrainDataset('/data/core-rad/data/Task01_BrainTumour', train=True)
igs = IGS()
w_list = igs.run(ds, acc_fac=4)

w_full = torch.stack(w_list).cpu()
torch.save(w_full, os.path.join(log_dir, 'igs_brats.pt'))

select: 97, loss: 0.006519: 100%|███████████| 64/64 [1:44:41<00:00, 98.16s/it]


## Run IGS Segmentation Training

For this, the base class is updated to work with the `SegmentationProxyLoss`.

In [16]:
from monai.networks.nets import UNet
from monai.losses import DiceCELoss
from src.losses import SegmentationProxyLoss

model = UNet(
    spatial_dims=2,
    in_channels=4,
    out_channels=4,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2
)

sd = torch.load('models/brain_base.pt')
model.load_state_dict(sd['model'])

loss_func = DiceCELoss(
    softmax=True,
    include_background=False,
    to_onehot_y=True
)

loss_func = SegmentationProxyLoss(model=model, seg_loss_func=loss_func).to('cuda')

In [17]:
from typing import Optional, List

import torch
from torch import nn, Tensor
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from src.utils import ifft2c


class IGS:
    """Implements Iterative Gradients Sampling from Razumov et al"""

    def __init__(
            self,
            loss_func: Optional[nn.Module] = None,
            device: str = 'cuda',
            num_workers: int = 8,
            batch_size: int = 32
    ) -> None:
        if loss_func is None:
            self.loss_func = nn.L1Loss()
        else:
            self.loss_func = loss_func

        self.device = torch.device(device)
        self.num_workers = num_workers
        self.batch_size = batch_size

    @staticmethod
    def get_n(acc_fac: int, img_size: int) -> int:
        return img_size // acc_fac

    def run(self, ds: Dataset, acc_fac: int) -> List[Tensor]:
        dl = DataLoader(
            dataset=ds,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers
        )

        img_size = ds[0]['img'].shape[-1]
        n = IGS.get_n(acc_fac, img_size)
        w = torch.zeros(
            img_size,
            dtype=torch.float,
            device=self.device,
        )
        w[img_size//2] = 1

        w_list = []

        pbar = tqdm(range(n))
        for _ in pbar:

            w.grad = None
            w.requires_grad = True
            for batch in dl:
                img_k = batch['k_space'].to(self.device)

                img_pred = ifft2c(img_k * w + 0.0)
                img_mag = torch.abs(img_pred)

                loss = self.loss_func(img_mag, batch['seg'].to(self.device))
                loss.backward()

            for i in torch.topk(w.grad, img_size, largest=False).indices:
                if w[i] == 0:
                    w = w.detach()
                    w[i] = 1.
                    w_list.append(w.clone())
                    pbar.set_description('select: %d, loss: %.6f' % (i.item(), loss.item()))
                    break
        return w_list


In [11]:
ds = ACDCDataset('/data/core-rad/data/ACDC', train=True)
igs = IGS(loss_func=loss_func)
w_list = igs.run(ds, acc_fac=4)

w_full = torch.stack(w_list).cpu()
torch.save(w_full, os.path.join(log_dir, 'igs_acdc_seg.pt'))

select: 94, loss: 0.129039: 100%|█████████████| 64/64 [01:05<00:00,  1.02s/it]


In [9]:
ds = KneeDataset('/data/knee_fastmri', train=True)
igs = IGS(loss_func=loss_func)
w_list = igs.run(ds, acc_fac=4)

w_full = torch.stack(w_list).cpu() 
torch.save(w_full, os.path.join(log_dir, 'igs_knee_seg.pt'))

select: 68, loss: 1.085384: 100%|██████████| 80/80 [29:17<00:00, 21.97s/it] 


In [18]:
ds = BrainDataset('/data/core-rad/data/Task01_BrainTumour', train=True)
igs = IGS(loss_func=loss_func)
w_list = igs.run(ds, acc_fac=8)

w_full = torch.stack(w_list).cpu()
torch.save(w_full, os.path.join(log_dir, 'igs_brats_seg.pt'))

select: 143, loss: 0.567427: 100%|████████████| 32/32 [42:31<00:00, 79.74s/it]
