# ImageNet Model

In [1]:
import os
import sys
import torch
import numpy as np
from matplotlib import pyplot as plt
from omegaconf import OmegaConf
sys.path.append('../taming-transformers')
sys.path.append('../latent-diffusion')
from ldm.util import instantiate_from_config
from ldm.data.ct_rsna import CTSubset

In [2]:
# Prefer the first CUDA GPU; fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
# RSNA CT training data, read via CTSubset. The keyword meanings below are
# presumed from their names (output size, random-flip probability, number of
# items in the subset) — TODO confirm against ldm.data.ct_rsna.CTSubset.
ds = CTSubset(
    '../data/ct-rsna/train/',
    'train_set_dropped_nans.csv',
    size=256,
    flip_prob=0.5,
    subset_len=1024,
)

In [None]:
def load_model_from_config(config, ckpt, device=None):
    """Instantiate ``config.model`` and load its weights from a checkpoint.

    Args:
        config: OmegaConf config whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint file whose top-level dict contains a
            ``"state_dict"`` entry.
        device: Device to place the model on. Defaults to ``cuda`` when a GPU
            is available, otherwise ``cpu``.

    Returns:
        The instantiated model, switched to eval mode and moved to ``device``.
    """
    print(f"Loading model from {ckpt}")
    # Load tensors onto the CPU first so a checkpoint saved on GPU can also be
    # opened on a CPU-only host; the model is moved to ``device`` below.
    # NOTE(review): torch>=2.6 defaults to weights_only=True, which may reject
    # Lightning checkpoints holding non-tensor objects — confirm for this ckpt.
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False tolerates key mismatches; report them instead of silently
    # discarding them, so a partially-loaded model is noticed.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if missing:
        print(f"{len(missing)} missing keys")
    if unexpected:
        print(f"{len(unexpected)} unexpected keys")
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    return model


def get_model(
    config_path="../latent-diffusion/configs/latent-diffusion/cin256-v2.yaml",
    ckpt_path="../latent-diffusion/models/ldm/cin256-v2/model.ckpt",
):
    """Load the cin256-v2 latent-diffusion model used by this notebook.

    Args:
        config_path: Path to the OmegaConf YAML describing the model.
        ckpt_path: Path to the matching checkpoint file.

    Returns:
        The model returned by ``load_model_from_config`` (in eval mode).
    """
    config = OmegaConf.load(config_path)
    return load_model_from_config(config, ckpt_path)

In [None]:
# !mkdir -p ../latent-diffusion/models/ldm/cin256-v2
# !wget -O ../latent-diffusion/models/ldm/cin256-v2/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/cin/model.ckpt 

In [None]:
# Build the diffusion model, then make sure it lives on the notebook's device.
model = get_model()
model = model.to(device)

In [None]:
# Visual comparison on K random dataset items: row 0 shows the input; rows 1
# and 2 show the first element of the tuple returned by ``mri_vae``
# ("initial") and ``ct_vae`` ("trained") on that input.
# NOTE(review): mri_vae and ct_vae are not defined anywhere in this notebook;
# they must exist in the kernel before this cell runs — TODO confirm source.
np.random.seed(7)
K = 5
# replace=False: draw K *distinct* indices so no item is plotted twice.
k_samples = np.random.choice(len(ds), K, replace=False)

fig, ax = plt.subplots(3, K, figsize=(10, 6), sharex=True, sharey=True)
with torch.no_grad():
    for i, idx in enumerate(k_samples):
        x = ds[idx]['image'].to(device)
        xb = x.unsqueeze(0)  # add a leading batch dimension
        y1, _, _ = mri_vae(xb)
        y2, _, _ = ct_vae(xb)

        # squeeze() drops singleton dims; this assumes single-channel images
        # with values in [0, 1] (see vmin/vmax) — TODO confirm.
        for row, img in enumerate((x, y1, y2)):
            ax[row, i].imshow(img.squeeze().cpu().numpy(), vmin=0., vmax=1., cmap='gray')

for row, label in enumerate(('input', 'initial', 'trained')):
    ax[row, 0].set_ylabel(label)
# Axes are shared (sharex/sharey), so clearing ticks here clears them on all.
ax[2, 0].set_xticks([])
ax[2, 0].set_yticks([])
plt.show()