In [1]:
import matplotlib
%matplotlib inline
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import vlkit.plt as vlplt

from scipy.stats import beta

import torch
from torch.distributions import Beta
import numpy as np
from mmcv.cnn import MODELS
from mmcv.utils import Config

from tqdm import tqdm
from einops import rearrange

import mmcv
from vlkit.image import norm01, norm255
from vlkit.medical.dicomio import read_dicom_array

from scipy.io import loadmat, savemat
import pydicom

import time, sys, os
import os.path as osp
sys.path.insert(0, '/data1/Code/vlkit/vlkit/medical')
sys.path.insert(0, '..')

from utils import write_dicom, write_dicom, inference
from aif import get_aif
from pharmacokinetic import fit_slice, process_patient, np2torch, evaluate_curve
import load_dce
from models import DCETransformer

from pharmacokinetic import calculate_reconstruction_loss



In [2]:
def remove_grad(x):
    x = x.detach()
    x.requires_grad = False
    x.grad = None
    return x

In [3]:
np.random.seed(0)
mmcv.runner.utils.set_random_seed(0)

In [4]:
device = torch.device('cuda')
dce_data = load_dce.load_dce_data('../../dicom/10042_1_003Tnq2B/20180212/t1_twist_tra_dyn_29/', device=device)
t2 = read_dicom_array('../../dicom/10042_1_003Tnq2B/20180212/t2_tse_tra_320_p2_12/')
data = loadmat('../../tmp/parker_aif/10042_1_003Tnq2B-20180212.mat')
h, w, c, _ = data['dce_ct'].shape

x_tl, y_tl, x_br, y_br = 53, 57, 107, 110
mask = torch.zeros(h, w, c, dtype=bool, device=device)
z_mask = torch.zeros(c, dtype=bool, device=device)
z_mask[10] = True
mask[y_tl:y_br, x_tl:x_br, z_mask] = 1
plt.imshow(mask[:, :, 10].cpu())
plt.close()

In [4]:
device = torch.device('cuda')
dce_data = load_dce.load_dce_data('../../dicom/10042_1_004D6Sy8/20160616/t1_twist_tra_dyn_35/', device=device)
t2 = read_dicom_array('../../dicom/10042_1_004D6Sy8/20160616/t2_tse_tra_320_p2_4/')
data = loadmat('../../tmp/parker_aif/10042_1_004D6Sy8-20160616.mat')
h, w, c, _ = data['dce_ct'].shape

In [5]:
weinmann_aif, aif_t = get_aif(aif='weinmann', max_base=6, acquisition_time=dce_data['acquisition_time'])
parker_aif, _ = get_aif(aif='parker', max_base=6, acquisition_time=dce_data['acquisition_time'])

In [6]:
fig, axes = plt.subplots(1, 3, figsize=(12, 3))
axes[0].plot(aif_t.cpu())
axes[0].grid()

axes[1].plot(weinmann_aif.cpu())
axes[1].grid()
axes[2].plot(parker_aif.cpu())
axes[2].grid()
plt.show()

In [7]:
work_dir = '../work_dirs/losses/loss_param-loss_ct/'

cfg = Config.fromfile(osp.join(work_dir, 'config.py'))
model = MODELS.build(cfg.model).to(torch.device('cuda'))
model.load_state_dict(torch.load(osp.join(work_dir, 'model-iter50000.pth')))

matlab_ktrans = data['ktrans']
matlab_kep = data['kep']
matlab_t0 = data['t0']

ct = dce_data['ct'].cuda()


torch.cuda.empty_cache()
tic = time.time()
ktrans, kep, t0 = inference(model, ct)
toc = time.time()
print(toc-tic, 'seconds')

ktrans_init = ktrans.cpu().clone()
kep_init = kep.cpu().clone()
t0_init = t0.cpu().clone()

curve_init = evaluate_curve(ktrans, kep, t0, t=dce_data['acquisition_time'], aif_t=aif_t, aif_cp=parker_aif).cpu()

loss_init = calculate_reconstruction_loss(
    ktrans,
    kep,
    t0,
    ct,
    t=dce_data['acquisition_time'],
    aif_t=aif_t,
    aif_cp=parker_aif
).cpu()

Start inference...


100%|███████████████████████████████████████████████████████████████| 160/160 [00:22<00:00,  7.25it/s]


Done, 22.065s elapsed.
22.078187942504883 seconds


In [8]:
torch.cuda.empty_cache()

ktrans.requires_grad = True
kep.requires_grad = True
t0.requires_grad = True

for i in range(50):
    ct1 = evaluate_curve(ktrans, kep, t0, t=dce_data['acquisition_time'], aif_t=aif_t, aif_cp=parker_aif)
    l = torch.nn.functional.l1_loss(ct1, ct, reduction='none').sum(dim=-1).mean()
    print(l)
    l.backward()
    ktrans.data -= ktrans.grad * 50
    kep.data -= kep.grad * 50
    t0.data -= t0.grad * 50
    ktrans.grad.zero_()


[ct1, ktrans, kep, t0, curve_init] = map(remove_grad, [ct1, ktrans, kep, t0, curve_init])

torch.cuda.empty_cache()
    
curve = evaluate_curve(ktrans, kep, t0, t=dce_data['acquisition_time'], aif_t=aif_t, aif_cp=parker_aif).cpu().detach()
loss = calculate_reconstruction_loss(
    ktrans,
    kep,
    t0,
    ct,
    t=dce_data['acquisition_time'],
    aif_t=aif_t,
    aif_cp=parker_aif
).cpu()

ktrans = ktrans.cpu()
kep = kep.cpu()
t0 = t0.cpu()
torch.cuda.empty_cache()

tensor(1.7624, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(1.5092, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(1.3224, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(1.1824, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(1.0854, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(1.0188, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.9771, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.9482, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.9291, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.9133, device='cuda:0', grad_fn=<MeanBackward0>)


In [10]:
mask = (dce_data['ct'].max(dim=-1).values >= 1/100).cpu()
y, x, z = torch.where(mask)

n = 50
inds = np.random.choice(x.numel(), min(n, x.numel()))
ct = dce_data['ct'].cpu().numpy()

ncol = 8
fig, axes = plt.subplots(n, ncol, figsize=(3*ncol, 3*n))

for idx, i in enumerate(inds):
    y1, x1, z1 = y[i].item(), x[i].item(), z[i].item()
    ct1 = ct[y1, x1, z1]

    params_1 = torch.tensor([ktrans_init[y1, x1, z1].item(), kep_init[y1, x1, z1].item(), t0_init[y1, x1, z1].item()])

    params_2 = torch.tensor([ktrans[y1, x1, z1].item(), kep[y1, x1, z1].item(), t0[y1, x1, z1].item()])

    axes[idx, 0].plot(ct1)
    axes[idx, 0].plot(curve_init[y1, x1, z1].flatten())
    axes[idx, 0].set_title('param: %.3f %.3f %.3f \n loss=%.3f'  % (params_1[0], params_1[1], params_1[2], loss_init[y1, x1, z1]))

    axes[idx, 1].plot(ct1)
    axes[idx, 1].plot(curve[y1, x1, z1].flatten())
    axes[idx, 1].set_title('param: %.3f %.3f %.3f \n loss=%.3f'  % (params_2[0], params_2[1], params_2[2], loss[y1, x1, z1]))

    t2im = mmcv.imresize(norm01(t2[z1]), (h, w))
    axes[idx, 2].imshow(t2im, cmap=cm.Greys_r)
    rect = patches.Rectangle((x_tl, y_tl), x_br-x_tl, y_br-y_tl, linewidth=1, edgecolor='black', facecolor='none')
    axes[idx, 2].add_patch(rect)
    axes[idx, 2].set_title('T2 (full) \n (%d, %d, %d)' % (x1, y1, z1))
    axes[idx, 2].scatter(x1, y1, marker='x', color='red')

    axes[idx, 3].imshow(t2im[y_tl:y_br, x_tl:x_br], cmap=cm.Greys_r)
    axes[idx, 3].set_title('T2 (ROI) \n (%d, %d, %d)' % (x1, y1, z1))
    
    axes[idx, 4].imshow(norm01(ktrans_init[:, :, z1].numpy())[y_tl:y_br, x_tl:x_br])
    axes[idx, 4].scatter(x1-x_tl, y1-y_tl, marker='x', color='red')
    axes[idx, 4].set_title('Weinmann Ktrans')
    
    axes[idx, 5].imshow(norm01(ktrans[:, :, z1].numpy())[y_tl:y_br, x_tl:x_br])
    axes[idx, 5].scatter(x1-x_tl, y1-y_tl, marker='x', color='red')
    axes[idx, 5].set_title('Parker Ktrans')
    
    axes[idx, 6].imshow(norm01(loss_init[:, :, z1].numpy())[y_tl:y_br, x_tl:x_br])
    axes[idx, 6].scatter(x1-x_tl, y1-y_tl, marker='x', color='red')
    axes[idx, 6].set_title('Weinmann loss %.3f' % loss_init[y1, x1, z1].item())
    
    axes[idx, 7].imshow(norm01(loss[:, :, z1].numpy())[y_tl:y_br, x_tl:x_br])
    axes[idx, 7].scatter(x1-x_tl, y1-y_tl, marker='x', color='red')
    axes[idx, 7].set_title('Parker loss %.3f' % loss[y1, x1, z1].item())

plt.tight_layout(h_pad=3)
plt.savefig('compare.pdf')
plt.close()