In [8]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '4' ## Change to empty GPU
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

!pip install -q optax
!pip install -q git+https://github.com/deepmind/dm-haiku

In [9]:
from functools import partial
import jax
from jax import random, grad, jit, vmap, flatten_util, nn
from jax.experimental import optix
from jax.config import config
import jax.numpy as np

import haiku as hk

from livelossplot import PlotLosses
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm as tqdm
import os
import cv2
import imageio
from jax.scipy import ndimage

from phantominator import shepp_logan, ct_shepp_logan, ct_modified_shepp_logan_params_2d

import pickle

ImportError: cannot import name 'optix' from 'jax.experimental' (/home/chung/anaconda3/lib/python3.8/site-packages/jax/experimental/__init__.py)

# Load Data

In [3]:
DATA_DIR = 'ct_256.pkl' ## Update

with open(DATA_DIR, 'rb') as file:
    dataset = pickle.load(file)
    
x1 = np.linspace(0, 1, dataset['data_test'][0].shape[0]+1)[:-1] # use full image resolution 
x_train = np.stack(np.meshgrid(x1,x1), axis=-1)
x_test = x_train

plt.figure(figsize=(15,8))
for i in range(4):
    plt.subplot(1,4,i+1)
    plt.imshow(dataset['data_test'][i,:,:])
plt.show()

FileNotFoundError: [Errno 2] No such file or directory: 'ct_256.pkl'

# Model

In [4]:
class Model(hk.Module):
    def __init__(self):
        super().__init__()
        self.rff = random.normal(jax.random.PRNGKey(0), shape=(2,256)) * 30
        self.width = 256
        self.depth = 5        
        
    def __call__(self, coords):
        sh = coords.shape
        coords = np.reshape(coords, [-1,2])
        
        x = np.concatenate([np.sin(coords @ self.rff), np.cos(coords @ self.rff)], axis=-1)

        for _ in range(self.depth-1):
            x = hk.Linear(output_size=self.width)(x)
            x = jax.nn.relu(x)
            
        out = hk.Linear(output_size=1)(x)
        out = jax.nn.sigmoid(out)
        out = np.reshape(out, list(sh[:-1]))
        return out

NameError: name 'hk' is not defined

# CT Projection

In [None]:
@jit
def ct_project(img, theta):
    y, x = np.meshgrid(np.arange(int(img.shape[0]), dtype=np.float32)/int(img.shape[0]) - 0.5, np.arange(int(img.shape[1]), dtype=np.float32)/int(img.shape[1]) - 0.5, indexing='ij')
    x_rot = x*np.cos(theta) - y*np.sin(theta)
    y_rot = x*np.sin(theta) + y*np.cos(theta)
    x_rot = (x_rot + 0.5)*img.shape[1]
    y_rot = (y_rot + 0.5)*img.shape[0]
    sample_coords = np.stack([y_rot, x_rot], axis=0)
    resampled = ndimage.map_coordinates(img, sample_coords, 0).reshape(img.shape)
    return resampled.mean(axis=0)[:,None,...]

ct_project_batch = vmap(ct_project, (None, 0), 0)
ct_project_double_batch = vmap(ct_project_batch, (0, 0), 0)

In [None]:
thetas = np.linspace(0,np.pi, 128)
test_img = dataset['data_test'][1]
projections = ct_project_batch(test_img, thetas)

plt.figure(figsize=(15,4))
plt.subplot(1,2,1)
plt.imshow(test_img)
plt.title('Phantom')
plt.subplot(1,2,2)
plt.imshow(projections[:,:,0])
plt.title('Sinogram')
plt.show()

# Train Model

In [None]:
# Rerun to reset plots
plt_groups = {'Train PSNR':[], 'Test PSNR':[]}
plotlosses_model = PlotLosses(groups=plt_groups)

In [None]:
CHECKPOINT_DIR = 'checkpoint/ct_checkpoints/' # Choose dir to save checkpoints

max_iters = 100000
test_steps = 10

batch_size = 1
num_projections = 20
num_test_projections = 10

inner_update_steps = 12
inner_lr = 10 #5e-3
lr = 5e-5

exp_name = f'ilr_{inner_lr}_olr_{lr}_ius_{inner_update_steps}_bs_{batch_size}'

test_thetas = np.linspace(0, np.pi, num_test_projections, endpoint=False)

coords = x_train

key = hk.PRNGSequence(42)
model = hk.without_apply_rng(hk.transform(lambda x: Model()(x)))
params = model.init(next(key), np.ones((1,2)))

opt = optix.adam(lr)#, b1=.5, b2=.9)
opt_state = opt.init(params)

opt_inner = optix.sgd(inner_lr)#, b1=.5, b2=.9)
# opt_inner = optix.adam(inner_lr)#, b1=.5, b2=.9)

mse_fn = jit(lambda x, y: np.mean((x - y)**2))
psnr_fn = jit(lambda x, y: -10 * np.log10(mse_fn(x, y)))

@partial(jit, static_argnums=[4])
def model_step(image_proj, coords, thetas, params, opt, opt_state):
    def loss_latent(params):
        g = model.apply(params, coords)
        g_proj = ct_project_batch(g, thetas)
        return mse_fn(g_proj, image_proj)

    loss, grad = jax.value_and_grad(loss_latent)(params)
    updates, opt_state = opt.update(grad, opt_state)
    params = optix.apply_updates(params, updates)
    return params, opt_state, loss

@partial(jit, static_argnums=[5])
def update_network_weights(rng, image_proj, coords, thetas, params, update_steps):
    opt_inner_state = opt_inner.init(params)
    for _ in range(update_steps):
        params, opt_inner_state, loss = model_step(image_proj, coords, thetas, params, opt_inner, opt_inner_state)
    return rng, params, loss

update_network_weights_batch = vmap(update_network_weights, in_axes=[0, 0, None, 0, None, None])

@jit
def update_model(rng, params, opt_state, image, coords, thetas):
    image_proj = ct_project_double_batch(image, thetas)
    rng = random.split(rng, batch_size)
    rng, new_params, loss = update_network_weights_batch(rng, image_proj, coords, thetas, params, inner_update_steps)
    rng, loss = rng[0], np.mean(loss)
    new_params = jax.tree_map(lambda x: np.mean(x, axis=0), new_params)
    def calc_grad(params, new_params):
        return params - new_params

    model_grad = jax.tree_multimap(calc_grad, params, new_params)

    updates, opt_state = opt.update(model_grad, opt_state)
    params = optix.apply_updates(params, updates)
    return rng, params, opt_state, loss

plt_groups['Train PSNR'].append(exp_name+f'_train')
plt_groups['Test PSNR'].append(exp_name+f'_test')

train_psnrs = []
step = 0
rng = random.PRNGKey(0)
rng_test = random.PRNGKey(42)
for step in tqdm(range(max_iters)):
    rng, rng_input = random.split(rng)
    train_idx = random.randint(rng_input, (batch_size,), 0, dataset['data_train'].shape[0])
    train_img = dataset['data_train'][train_idx]
    if batch_size == 1:
        train_img = train_img[None,...]
    
    rng, rng_input = random.split(rng)
    thetas = random.uniform(rng_input, (batch_size, num_projections), minval=0, maxval=np.pi)
    
    rng, params, opt_state, loss = update_model(rng, params, opt_state, train_img, coords, thetas)
    train_psnrs.append(-10 * np.log10(loss))

    if step % 500 == 0:
        plotlosses_model.update({exp_name+'_train':np.mean(np.array(train_psnrs))}, current_step=step)
        train_psnrs = []
    if step % 500 == 0 and step != 0:
        test_psnrs = []
        for i in range(5):
            test_img = dataset['data_test'][i]
            test_img_proj = ct_project_batch(test_img, test_thetas)
            _, params_test, loss = update_network_weights(rng_test, test_img_proj, coords, test_thetas, params, test_steps)
            img = model.apply(params_test, coords)

            test_psnrs.append(psnr_fn(img, test_img))
        plotlosses_model.update({exp_name+'_test':np.mean(np.array(test_psnrs))}, current_step=step)
        plotlosses_model.send()

        plt.figure(figsize=(15,5))        
        plt.subplot(1,2, 1)
        plt.imshow(img)
        plt.subplot(1,2, 2)
        plt.imshow(test_img)
        plt.show()
    if step % 5000 == 0 and step != 0:
        with open(f'{CHECKPOINT_DIR}/{exp_name}_{step}.pkl', 'wb') as file:
            pickle.dump(params, file)


# Evaluate Network

### Test time optimization

In [None]:
mse_fn = jit(lambda x, y: np.mean((x - y)**2))
psnr_fn = jit(lambda x, y: -10 * np.log10(mse_fn(x, y)))

@partial(jit, static_argnums=[0,5])
def test_model_step(model, image_proj, coords, thetas, params, opt, opt_state):
    def loss_fn(params):
        g = model.apply(params, coords)
        g_proj = ct_project_batch(g, thetas)
        return mse_fn(g_proj, image_proj), g

    (loss, img), grad = jax.value_and_grad(loss_fn, has_aux=True)(params)
    updates, opt_state = opt.update(grad, opt_state)
    params = optix.apply_updates(params, updates)
    return params, opt_state, loss, img

def train_model(image, lr, steps, num_views, adam=True, params=None):
    rng = random.PRNGKey(0)
    x1 = np.linspace(0, 1, RES+1)[:-1]
    coords = np.stack(np.meshgrid(x1,x1), axis=-1)
    
    thetas = np.linspace(0,np.pi, num_views, endpoint=False)
    image_proj = ct_project_batch(image, thetas)

    model = hk.without_apply_rng(hk.transform(lambda x: Model()(x)))
    if params is None:
        params = model.init(rng, coords)
        
    if adam:
        opt = optix.adam(lr)
    else:
        opt = optix.sgd(lr)
    opt_state = opt.init(params)
    
    train_psnrs = []
    test_psnrs = []
    for i in range(steps):
        params, opt_state, loss, img = test_model_step(model, image_proj, coords, thetas, params, opt, opt_state)
        train_psnrs.append(-10 * np.log10(loss))
        test_psnrs.append(psnr_fn(img, image))

    return train_psnrs, test_psnrs, img, params

### Load Checkpoint (optional)

In [None]:
chkpt_pkl = 'checkpoint/ct_checkpoints/ilr_10_olr_5e-05_ius_12_bs_1_100000.pkl'

with open(chkpt_pkl, 'rb') as file:
    params = pickle.load(file)

In [None]:
    'maml': {
        'params': params_maml,
        'lr': 10,
        'adam': False,
        'steps': [50,100,1000,1000]
    },
    'reptile': {
        'params': params_reptile,
        'lr': 10,
        'adam': False,
        'steps': [50,100,1000,1000]
    },

### Render Examples

In [None]:
num_views = 6
lr = 10
steps = 1000
adam = False

num_examples = 4

for image in dataset['data_test'][:num_examples]:
    train_psnrs, test_psnrs, rec_img, _ = train_model(image, lr, steps, num_views, adam, params)
    rec_img = (np.clip(rec_img, 0, 1)*255).astype(np.uint8)
    
    plt.figure(figsize=(15,4))
    plt.subplot(1,2,1)
    plt.imshow(image)
    plt.title('Target')
    plt.subplot(1,2,2)
    plt.imshow(rec_img)
    plt.title(f'Reconstruction ({num_views} views)')
    plt.show()