In [1]:
# Re-import edited project modules (e.g. models/, utils/ imported below)
# automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import sys
import seaborn as sns

import torch

from pathlib import Path
from models.unet import *
from utils.plot import *
from utils.data_loader import *
from utils.OTFlowProblem import *

# Let cuDNN benchmark and cache the fastest convolution algorithm for each
# input shape it encounters.
torch.backends.cudnn.benchmark = True

# Plain white background for seaborn/matplotlib figures.
sns.set_style("white")

In [3]:
def compute_loss(net, x, nt):
    """Evaluate the OT-Flow objective of ``net`` on the batch ``x``.

    The flow is integrated over the time span [0, 1] with ``nt`` steps of the
    ``"rk4"`` stepper, and the individual cost terms are weighted by
    ``net.alph``.

    Returns:
        (Jc, costs): ``Jc`` is the weighted objective that the training loop
        back-propagates; ``costs`` are the individual terms exactly as
        returned by ``OTFlowProblem``. The training loop logs ``costs[0]`` as
        the negative log-likelihood.
        NOTE(review): the reference OT-Flow implementation orders the costs as
        [transport L, negloglik C, HJB R]; confirm the order used by
        utils.OTFlowProblem.
    """
    time_span = [0, 1]
    objective, cost_terms = OTFlowProblem(
        x, net, time_span, nt=nt, stepper="rk4", alph=net.alph
    )
    return objective, cost_terms

# Dataset

In [4]:
# Root of the MosMedData COVID19_1110 release. The raw string keeps Windows
# backslashes literal: in a plain string "\d", "\s", "\m", ... are invalid
# escape sequences (warned about by newer Pythons), and a folder starting with
# e.g. "n" or "t" would silently turn into a newline/tab. Set the
# MOSMED_DATA_DIR environment variable to use a different copy without
# editing this cell.
DATA_ROOT = Path(
    os.environ.get(
        "MOSMED_DATA_DIR",
        r"D:\data\covid\MosMedData Chest CT Scans with COVID-19 Related Findings COVID19_1110 1.0",
    )
)
CT_DIR = DATA_ROOT / "studies" / "CT-1"
MASK_DIR = DATA_ROOT / "masks"

In [5]:
# Dataset over the CT studies and their masks (project class CTSlices).
# The loops below unpack each batch as a pair -- presumably (slices, masks);
# TODO confirm against the dataset implementation.
ds = CTSlices(
    CT_DIR,
    MASK_DIR,
)

In [6]:
# Batches of 2, loaded in the main process (num_workers=0) without pinning.
# The batches are fed to the CUDA model without an explicit .cuda(), so the
# dataset presumably already yields GPU tensors -- TODO confirm.
# shuffle=False: every epoch visits batches in the same order.
# NOTE(review): consider shuffle=True, since this loader also drives the flow
# training loop below.
dataloader = DataLoader(
    ds,
    num_workers=0,
    pin_memory=False,
    shuffle=False,
    batch_size=2,
)

# Trained Segmentation Model

In [7]:
WEIGHT_DIR = Path("weights")

# Restore the trained UNet from its epoch-100 checkpoint and switch it to eval
# mode; below it is only used through `model.encode`.
seg_weight_fpath = WEIGHT_DIR.joinpath("weights_epoch_100.h5")
seg_chkpt = torch.load(seg_weight_fpath)

model = UNet(ds=1, kernel_size=3, n_channels=1)
model = model.cuda()
model.load_state_dict(seg_chkpt["model"])
_ = model.eval()

In [8]:
# Encode a single batch to discover the shape of the UNet encoder output,
# which determines the flow's input dimension. The encoder is only used for
# inference, so don't build an autograd graph for it.
xs, _ = next(iter(dataloader))
with torch.no_grad():
    ys_hat = model.encode(xs)

encoding_shape = ys_hat.shape

In [12]:
# Encoder output shape for one batch -- presumably (batch, channels, H, W).
# The flow's input dimension is the product of every dimension after dim 0.
encoding_shape

torch.Size([2, 64, 1, 1])

# Flow: OT-Flow density model on the UNet encodings

In [13]:
# Weights for the three loss terms (intended as: negloglik, transport,
# HJB regulariser).
# NOTE(review): the reference OT-Flow code orders the terms as
# [transport L, negloglik C, HJB R] -- confirm against utils.OTFlowProblem.
alph = [1.0, 2000.0, 800.0]

# Every flow parameter is clamped into [clampMin, clampMax] during training.
clampMin, clampMax = -2.0, 2.0

nt = 14   # number of integration time steps
nTh = 2   # number of layers in Phi's internal ResNet
m = 64    # hidden width of Phi's internal ResNet

flow = Phi(
    d=np.prod(encoding_shape[1:]),  # per-sample encoding size (all dims after batch)
    nTh=nTh,
    m=m,
    alph=alph,
).cuda()
_ = flow.train()

In [14]:
# Optimisation schedule for the flow.
learning_rate = 1e-4
N_EPOCHS = 100
N_STEPS = N_EPOCHS * len(dataloader)  # total optimizer steps (epochs x batches)

opt = torch.optim.Adam(params=flow.parameters(), lr=learning_rate)

In [None]:
# Train the flow on the encodings produced by the pretrained UNet.
# `flow_losses` collects the per-epoch mean of costs[0], which this notebook
# treats as the negative log-likelihood.
# NOTE(review): the reference OT-Flow implementation orders costs as
# [transport L, negloglik C, HJB R]; confirm utils.OTFlowProblem's ordering
# before relying on this curve.
flow_losses = []

SAVE_EVERY = 10  # write a checkpoint every SAVE_EVERY *completed* epochs


def _save_flow_checkpoint(epochs_done):
    """Save flow + optimizer state after `epochs_done` completed epochs.

    Writes WEIGHT_DIR / "flow_weights_epoch_<epochs_done>.h5" and returns that
    path. Numbering by completed epochs means the final checkpoint is
    "flow_weights_epoch_<N_EPOCHS>.h5".
    """
    weight_path = WEIGHT_DIR / "flow_weights_epoch_{}.h5".format(epochs_done)
    torch.save(
        {
            'model': flow.state_dict(),
            'opt': opt.state_dict(),
        },
        str(weight_path),
    )
    return weight_path


epochs_done = 0
with tqdm(total=N_EPOCHS) as pbar:
    for cur_epoch in range(N_EPOCHS):
        epoch_losses = []

        for xs, ys in dataloader:
            # Encode with the pretrained UNet. Only the flow is optimised, so no
            # autograd graph is needed here. Without no_grad(), loss.backward()
            # would also back-propagate through the encoder and keep
            # accumulating .grad on its parameters every step.
            with torch.no_grad():
                # flatten(start_dim=1) equals [:, :, 0, 0] for 1x1 encodings,
                # and in general keeps the feature size equal to the flow's
                # d = prod(encoding_shape[1:]).
                ys_hat = model.encode(xs).flatten(start_dim=1)

            # skip normalize encoding

            # reset opt
            opt.zero_grad()

            # Clip every flow parameter into [clampMin, clampMax], in place and
            # outside autograd.
            with torch.no_grad():
                for p in flow.parameters():
                    p.clamp_(clampMin, clampMax)

            # forward flow + loss
            loss, costs = compute_loss(flow, ys_hat, nt=nt)
            loss.backward()
            opt.step()

            negloglik = costs[0]
            epoch_losses.append(negloglik.item())

        epochs_done = cur_epoch + 1
        mean_epoch_loss = np.mean(epoch_losses)
        flow_losses.append(mean_epoch_loss)
        pbar.set_postfix({'loss': '{:.4f}'.format(flow_losses[-1])})
        pbar.update(1)

        # Checkpoints are named by completed epochs, so the one taken after
        # 10 epochs is "flow_weights_epoch_10.h5".
        if epochs_done % SAVE_EVERY == 0:
            WEIGHT_PATH = _save_flow_checkpoint(epochs_done)

# Also save the final state, unless the periodic save above just wrote it
# (or no epoch was run at all).
if epochs_done > 0 and epochs_done % SAVE_EVERY != 0:
    WEIGHT_PATH = _save_flow_checkpoint(epochs_done)

  3%|█▉                                                               | 3/100 [02:00<1:06:17, 41.01s/it, loss=611.5996]