In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import sys
import seaborn as sns

import torch

from pathlib import Path
from models.unet import *
from utils.plot import *
from utils.data_loader import *
from utils.OTFlowProblem import *

sns.set_style("white")

torch.backends.cudnn.benchmark = True


In [3]:
def compute_loss(net, x, nt): 
    Jc , costs = OTFlowProblem(x, net, [0,1], nt=nt, stepper="rk4", alph=net.alph)
    return Jc, costs

# Dataset

In [4]:
CT_DIR = Path("D:\data\covid\MosMedData Chest CT Scans with COVID-19 Related Findings COVID19_1110 1.0\studies\CT-1")
MASK_DIR = Path("D:\data\covid\MosMedData Chest CT Scans with COVID-19 Related Findings COVID19_1110 1.0\masks")

In [5]:
ds = CTSlices(CT_DIR, MASK_DIR)

In [6]:
dataloader = DataLoader(
    ds, 
    batch_size=1,                   
    shuffle=False,
    pin_memory=False,
    num_workers=0,
)

# Trained Segmentation Model

In [7]:
WEIGHT_DIR = Path("weights")

seg_weight_fpath = WEIGHT_DIR / "weights_epoch_500.h5"
seg_chkpt = torch.load(seg_weight_fpath)

model = UNet(n_channels=1, kernel_size=3, ds=1).cuda()
# model.load_state_dict(seg_chkpt['model'])
_ = model.eval()

In [8]:
for xs, _ in dataloader:
    break

In [9]:
encoding_shape = model.encode(xs).shape

torch.Size([1, 1024, 32, 32])
torch.Size([1, 512, 32, 32])
torch.Size([1, 512, 64, 64])
torch.Size([1, 512, 64, 64])
torch.Size([1, 256, 64, 64])
torch.Size([1, 256, 128, 128])
torch.Size([1, 256, 128, 128])
torch.Size([1, 128, 128, 128])
torch.Size([1, 128, 256, 256])
torch.Size([1, 128, 256, 256])
torch.Size([1, 64, 256, 256])
torch.Size([1, 64, 512, 512])


# Flow

In [11]:
alph = [1.0, 80.0, 500.0]
flow = Phi(nTh=2, m=128, d=np.prod(encoding_shape), alph=alph).cuda()
_ = flow.train()

In [12]:
N_EPOCHS = 100
N_STEPS = len(dataloader) * N_EPOCHS
learning_rate = 3e-4

opt = torch.optim.AdamW(flow.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer=opt,
    max_lr=learning_rate,
    total_steps=N_STEPS,
    cycle_momentum=True,
)

opt.zero_grad()
opt.step()

In [13]:
clampMax = 2.0
clampMin = -2.0

In [14]:
WEIGHT_DIR = Path("weights")

In [16]:
flow_losses = []

with tqdm(total=N_EPOCHS) as pbar:
    
    for cur_epoch in range(N_EPOCHS):
    
        epoch_losses = []
    
        for xs, ys in dataloader:
            # encode
            ys_hat = model.encode(xs)[:, :, 0, 0]

            # skipping normalizing encoding

            # reset opt
            opt.zero_grad()

            # clip parameters
            for p in flow.parameters():
                p.data = torch.clamp(p.data, clampMin, clampMax)

            # forward flow + loss
            loss, costs = compute_loss(flow, ys_hat, nt=2)
            loss.backward()
            opt.step()
            scheduler.step()
            
            epoch_losses.append(loss.detach().cpu().item())
            
        mean_epoch_loss = np.mean(epoch_losses)
        flow_losses.append(mean_epoch_loss)
        pbar.set_postfix(
            {
                'loss': '{:.4f}'.format(flow_losses[-1]),
            }
        )
        pbar.update(1)
            
        if cur_epoch % 10 == 0:
            WEIGHT_PATH = WEIGHT_DIR / "flow_weights_epoch_{}.h5".format(cur_epoch)
            torch.save(
                {
                    'model': flow.state_dict(),
                    'opt': opt.state_dict(),
                },
                str(WEIGHT_PATH),
            )
           
    # also save final epoch
WEIGHT_PATH = WEIGHT_DIR / "flow_weights_epoch_{}.h5".format(cur_epoch+1)
torch.save(
    {
        'model': flow.state_dict(),
        'opt': opt.state_dict(),
    },
    str(WEIGHT_PATH),
)

  1%|▋                                                              | 1/100 [00:52<1:27:10, 52.84s/it, loss=10159.3812]


KeyboardInterrupt: 