In [None]:
import os
import pickle
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

import gradoptics as optics
from gradoptics.integrator import HierarchicalSamplingIntegrator
from ml.siren import Siren

## Same scene setup as in the training notebook

In [None]:
# Ground-truth source: an AtomCloud positioned at the origin, wrapped as a light source.
atom_cloud = optics.AtomCloud(phi=0.1, w0=0.01, k_fringe=2 * np.pi / 0.001, position=[0., 0., 0.])
light_source = optics.LightSourceFromDistribution(atom_cloud)

In [None]:
device = 'cuda'

# Network hyperparameters; these must match the saved checkpoints, since the
# state dicts are loaded into this model later.
in_features, hidden_features, hidden_layers, out_features = 3, 256, 3, 1

model = Siren(
    in_features,
    hidden_features,
    hidden_layers,
    out_features,
    outermost_linear=True,
    outermost_linear_activation=nn.ReLU(),
)
model = model.double().to(device)

## Set up grid to sample
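The grid is sampled on a cube of half-width $1/\sqrt{2}$ in normalized coordinates, where the bounding sphere has radius 1. At this half-width, the midpoints of the cube's edges lie exactly on the sphere. This keeps most of the grid inside the sphere and limits corner effects. The cube's corners, at radius $\sqrt{3/2}\approx 1.22$, still fall outside; a cube lying entirely inside the sphere would need a half-width of $1/\sqrt{3}$.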

In [None]:
n_side = 100
bound = 1 / np.sqrt(2.0)
# n_side**3 points on the regular cube [-bound, bound]^3 (normalized coordinates),
# in float64 on the GPU so they match the double-precision model.
grid = torch.cartesian_prod(*[torch.linspace(-bound, bound, n_side)] * 3).cuda().double()

## Load model and evaluate densities

In [None]:
# Directory holding the training checkpoints. It defaults to the original cluster
# location and can be overridden with the NW_MOT_MODEL_DIR environment variable.
# os.path.join(..., '') guarantees a trailing separator, because later cells build
# paths as `dir_name + fname`.
dir_name = os.path.join(os.environ.get('NW_MOT_MODEL_DIR', '/sdf/group/magis/sgaz/NW-MOT/models/'), '')
# Checkpoint filename pattern; '*' is replaced by the training iteration number.
f_pattern = 'model_*_NW_MOT_all_cameras_long.pt'

In [None]:
# Count the saved checkpoints. Later cells assume they are numbered 0, 500, 1000, ...
checkpoint_paths = glob(dir_name + f_pattern)
n_checkpoints = len(checkpoint_paths)
# Fail here with a clear message rather than later with last = -500 and a missing file.
assert n_checkpoints > 0, f"no checkpoints match {dir_name + f_pattern}"

In [None]:
# Iteration number of the newest checkpoint, assuming checkpoints are saved every
# 500 iterations starting from 0.
last = 500 * (n_checkpoints - 1)
print(last)

In [None]:
batch_size = 100**3
with torch.no_grad():
    # Filename of the newest checkpoint: f_pattern with '*' replaced by `last`.
    fname = f_pattern.replace("*", str(last))
    model.load_state_dict(torch.load(dir_name + fname))
    # Evaluate the network over the grid in chunks of `batch_size` points, keeping
    # the first element of each model output, moved to the CPU.
    density_chunks = [model(chunk)[0].cpu() for chunk in tqdm(grid.split(batch_size))]
    densities = torch.cat(density_chunks).reshape((n_side,) * 3)

In [None]:
rad = 0.03
n_side = 100
# The same cube as `grid`, scaled by `rad` into scene coordinates, on which the
# ground-truth light source's pdf is evaluated.
real_grid = torch.cartesian_prod(*[torch.linspace(-bound * rad, bound * rad, n_side)] * 3).cuda().double()
pdf_vals = light_source.pdf(real_grid).reshape((n_side,) * 3).cpu()

In [None]:
%matplotlib inline
fig, ax = plt.subplots(2, 4, figsize=(8,4))

ax[0, 0].text(0.9, 0.5, "Reconstructed", ha='right', fontsize=14)
ax[0, 0].axis('off')

ax[1, 0].text(0.9, 0.5, "True", ha='right', fontsize=14)
ax[1, 0].axis('off')

ax[0, 1].imshow(densities.sum(dim=0).T, origin="lower")
ax[0, 1].set_title('Sum x', fontsize=14)
ax[0, 1].axis('off')

ax[0, 2].imshow(densities.sum(dim=1).T, origin="lower")
ax[0, 2].set_title('Sum y', fontsize=14)
ax[0, 2].axis('off')

ax[0, 3].imshow(densities.sum(dim=2).T, origin="lower")
ax[0, 3].set_title('Sum z', fontsize=14)
ax[0, 3].axis('off')

ax[1, 1].imshow(pdf_vals.sum(dim=0).T, origin="lower")
ax[1, 1].axis('off')

ax[1, 2].imshow(pdf_vals.sum(dim=1).T, origin="lower")
ax[1, 2].axis('off')

ax[1, 3].imshow(pdf_vals.sum(dim=2).T, origin="lower")
ax[1, 3].axis('off')
plt.tight_layout()
save_name = fname.replace('.pt', '.png').replace('model', 'marginal')
plt.savefig(save_name, dpi=300)

## Further analysis (any analysis is possible; a few examples follow)

The marginal projections are shown above.

`mrcfile` writes the reconstructed density as an MRC volume, which can be opened in a 3D viewer such as ChimeraX.

In [None]:
import mrcfile

filename = fname.replace('.pt', '.mrc').replace('model', 'test_mrc')
with mrcfile.new(dir_name + filename, overwrite=True) as mrc:
    # `densities` is indexed [ix, iy, iz] (the grid is cartesian_prod(x, y, z)),
    # but MRC arrays are stored as [z, y, x]. Permute so viewers do not show
    # the volume with x and z swapped (a mirror image).
    mrc.set_data(densities.permute(2, 1, 0).contiguous().float().cpu().detach().numpy())
    # Adjacent samples of torch.linspace(-bound, bound, n) are 2*bound/(n - 1)
    # apart in normalized units; multiplying by rad gives scene units.
    # (grid.shape[0] is the total point count n**3, not the per-side count.)
    # NOTE(review): MRC viewers usually interpret voxel_size in Angstrom; this
    # writes scene units -- confirm the intended unit.
    mrc.voxel_size = 2 * rad * bound / (densities.shape[0] - 1)

We can also animate how the reconstruction evolves across the saved training checkpoints.

In [None]:
all_densities = []
with torch.no_grad():
    # Checkpoints are assumed to be saved every 500 iterations (0, 500, ...).
    for n_iter in tqdm(range(0, n_checkpoints * 500, 500)):
        # Checkpoints live in dir_name, just as when the newest one was loaded above.
        ckpt_path = dir_name + f_pattern.replace("*", str(n_iter))
        model.load_state_dict(torch.load(ckpt_path))
        # Use a separate name so the global `densities` (used for the MRC export)
        # is not overwritten by this loop.
        densities_iter = model(grid)[0].reshape((n_side, n_side, n_side)).cpu()
        all_densities.append(densities_iter)

In [None]:
%matplotlib notebook
from celluloid import Camera
fig, ax = plt.subplots(2, 4, figsize=(8,4))
camera = Camera(fig)

for i in range(len(all_densities)):
    ax[0, 0].text(0.9, 0.5, "Reconstructed", ha='right', fontsize=14)
    ax[0, 0].axis('off')

    ax[1, 0].text(0.9, 0.5, "True", ha='right', fontsize=14)
    ax[1, 0].axis('off')

    ax[0, 1].imshow(all_densities[i].sum(dim=0).T, origin="lower")
    ax[0, 1].set_title('Sum x', fontsize=14)
    ax[0, 1].axis('off')

    ax[0, 2].imshow(all_densities[i].detach().sum(dim=1).T, origin="lower")
    ax[0, 2].set_title('Sum y', fontsize=14)
    ax[0, 2].axis('off')

    ax[0, 3].imshow(all_densities[i].sum(dim=2).T, origin="lower")
    ax[0, 3].set_title('Sum z', fontsize=14)
    ax[0, 3].axis('off')
    

    ax[1, 1].imshow(pdf_vals.sum(dim=0).T, origin="lower")
    ax[1, 1].axis('off')

    ax[1, 2].imshow(pdf_vals.sum(dim=1).T, origin="lower")
    ax[1, 2].axis('off')

    ax[1, 3].imshow(pdf_vals.sum(dim=2).T, origin="lower")
    ax[1, 3].axis('off')

    plt.tight_layout()
    camera.snap()

animation = camera.animate()
save_name = fname.replace('.pt', '.mp4').replace('model', 'training')
animation.save(save_name)

#from IPython.display import HTML
#HTML(animation.to_html5_video())

Or we can render camera images of the trained model through the scene's lenses and sensors.

In [None]:
# Scene components (sensors, lenses, ...) and target images saved by the training notebook.
# NOTE: pickle.load can execute arbitrary code -- only load these files from a trusted source.
# Context managers close the file handles, which the bare open() calls never did.
with open("NW_mot_scene_components.pkl", "rb") as f:
    scene_objects = pickle.load(f)
with open("NW_mot_images.pkl", "rb") as f:
    targets = pickle.load(f)

In [None]:
# Keep only rows 250..1749 and columns 0..1499 of each camera image.
sel_mask = torch.zeros(targets.shape[1:], dtype=torch.bool)
sel_mask[250:1750, :1500] = True

In [None]:
rad = 0.03
obj_pos = (0, 0, 0)

# Light source driven by the trained network, bounded by an optics.BoundingSphere
# of radius `rad` centred on obj_pos.
# A distinct name keeps the ground-truth `light_source` (AtomCloud) intact. Reusing
# the name would make a re-run of the pdf_vals cell evaluate the network instead.
nn_light_source = optics.LightSourceFromNeuralNet(
    model,
    optics.BoundingSphere(radii=rad, xc=obj_pos[0], yc=obj_pos[1], zc=obj_pos[2]),
    rad=rad, x_pos=obj_pos[0], y_pos=obj_pos[1], z_pos=obj_pos[2])
scene_train = optics.Scene(nn_light_source)

for obj in scene_objects:
    scene_train.add_object(obj)

In [None]:
# Exact-type matches (subclasses are deliberately not included), in scene order.
sensor_list = list(filter(lambda o: type(o) == optics.Sensor, scene_train.objects))
lens_list = list(filter(lambda o: type(o) == optics.PerfectLens, scene_train.objects))

In [None]:
# HierarchicalSamplingIntegrator is already imported in the first cell.
# NOTE(review): the two positional 64s are presumably sample counts for the
# hierarchical sampling stages -- confirm against gradoptics.integrator.
integrator = HierarchicalSamplingIntegrator(64, 64, stratify=False)

In [None]:
with torch.no_grad():
    # Separate name so the `batch_size` used for density evaluation is not overwritten.
    render_batch_size = 200000 // 40

    # The selected pixels must form a rectangle; its extent gives the shape of
    # each rendered image instead of a hard-coded (1500, 1500).
    n_rows = int(sel_mask.any(dim=1).sum())
    n_cols = int(sel_mask.any(dim=0).sum())
    assert n_rows * n_cols == int(sel_mask.sum()), "sel_mask must select a rectangular region"

    im_all = []
    for data_id in tqdm(range(len(targets)), desc="cameras"):
        sensor_here = sensor_list[data_id]
        lens_here = lens_list[data_id]
        h_here, w_here = sensor_here.resolution

        # Pixel indices in row-major order (the same order as sel_mask.flatten()),
        # each axis running downward from +size//2.
        idxs_all = torch.cartesian_prod(torch.arange(h_here // 2, -h_here // 2, -1),
                                        torch.arange(w_here // 2, -w_here // 2, -1))
        idxs_all = idxs_all[sel_mask.flatten()]

        intensities_all = []
        for pixels_batch in tqdm(torch.arange(len(idxs_all)).split(render_batch_size), leave=False):
            batch_pix_x = idxs_all[pixels_batch, 0]
            batch_pix_y = idxs_all[pixels_batch, 1]
            # NOTE(review): the positional 1 and 5 are forwarded unchanged; their
            # meaning is defined by render_pixels -- confirm there.
            intensities_batch = optics.ray_tracing.ray_tracing.render_pixels(
                sensor_here, lens_here, scene_train, scene_train.light_source, 1, 5,
                batch_pix_x, batch_pix_y, integrator, device=device, max_iterations=6)
            intensities_all.append(intensities_batch.clone())
        im = torch.cat(intensities_all).reshape((n_rows, n_cols)).cpu()
        im_all.append(im)

In [None]:
# Show the image rendered for the last camera.
plt.imshow(im_all[-1].T)