In [None]:
!nvidia-smi

Thu Jun 24 16:46:48 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.27       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   74C    P8    12W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
!pip install einops

Collecting einops
  Downloading https://files.pythonhosted.org/packages/5d/a0/9935e030634bf60ecd572c775f64ace82ceddf2f504a5fd3902438f07090/einops-0.3.0-py2.py3-none-any.whl
Installing collected packages: einops
Successfully installed einops-0.3.0


In [3]:
import numpy as np
import matplotlib.pyplot as plt
import random
import gc
from math import exp
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable

In [4]:
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def gaussian(window_size, sigma):
    """Return a normalised 1-D Gaussian kernel of length `window_size`.

    The kernel is centred on index window_size // 2 with standard deviation
    `sigma`, and its entries sum to 1.
    """
    centre = window_size // 2
    denom = float(2 * sigma ** 2)
    weights = torch.Tensor([exp(-((idx - centre) ** 2) / denom) for idx in range(window_size)])
    return weights / weights.sum()

def create_window(window_size, channel, sigma=1.5):
    """Build a depthwise 2-D Gaussian window for F.conv2d(..., groups=channel).

    Args:
        window_size: side length of the square window.
        channel: number of channels; the same kernel is repeated per channel.
        sigma: Gaussian standard deviation (defaults to 1.5, the value that was
            previously hard-coded).

    Returns:
        A contiguous float tensor of shape (channel, 1, window_size, window_size)
        whose per-channel kernel is the outer product of gaussian(window_size, sigma).
    """
    _1D_window = gaussian(window_size, sigma).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # torch.autograd.Variable is a deprecated no-op wrapper since PyTorch 0.4;
    # returning the plain (grad-free) tensor is equivalent.
    return _2D_window.expand(channel, 1, window_size, window_size).contiguous()

def _ssim(img1, img2, mask, window, window_size, channel, size_average = True):
    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2

    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
    ssim_map = ssim_map*mask

    if size_average:
        return (ssim_map.mean(1).mean(1).mean(1)/mask.mean(1).mean(1).mean(1)).mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)/mask.mean(1).mean(1).mean(1)

class SSIM(torch.nn.Module):
    """Masked SSIM as an nn.Module that caches its Gaussian window.

    The cached window is rebuilt, and moved to img1's CUDA device and tensor
    type, whenever img1's channel count or tensor type differs from the
    cached one.
    """

    def __init__(self, window_size = 11, size_average = True):
        super().__init__()
        self.window_size = window_size
        self.size_average = size_average
        # Start with a single-channel window; forward() rebuilds it on demand.
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2, mask):
        """Return the masked SSIM of img1 vs img2 (reduction as in _ssim)."""
        _, n_channels, _, _ = img1.size()

        cache_ok = n_channels == self.channel and self.window.data.type() == img1.data.type()
        if not cache_ok:
            fresh = create_window(self.window_size, n_channels)
            if img1.is_cuda:
                fresh = fresh.cuda(img1.get_device())
            self.window = fresh.type_as(img1)
            self.channel = n_channels

        return _ssim(img1, img2, mask, self.window, self.window_size, n_channels, self.size_average)

def ssim(img1, img2, mask, window_size = 11, size_average = True):
    """Functional masked SSIM; builds a fresh Gaussian window on every call.

    Args:
        img1, img2: images to compare, unpacked as (N, C, H, W).
        mask: weights multiplied into the SSIM map; the score is normalised by
            the mask's mean, so it must broadcast against the map (assumed
            (N, C, H, W) — TODO confirm).
        window_size: side length of the Gaussian window.
        size_average: if True return the batch mean, else one score per image.
    """
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    # `mask` must be passed before `window` to match _ssim's signature.
    return _ssim(img1, img2, mask, window, window_size, channel, size_average)

In [19]:
def create_enc(in_dim, dim, head_num, head_dim, layer_num):
  """Strided 3-D conv encoder followed by a per-position projection.

  Each of the `layer_num` stages is Conv3d(kernel (1,4,4), stride (1,2,2),
  padding (0,1,1)) -> ReLU -> GroupNorm(1, ·), which halves even H and W and
  keeps D. Channel widths grow towards `dim`, and the last stage outputs `dim`.
  A final 1x1x1 Conv3d maps `dim` to head_num * head_dim channels.

  The module index layout ("i.0", "i.2", final "layer_num") must stay as-is
  because trained checkpoints are restored with load_state_dict.
  """
  stages = []
  for i in range(layer_num):
    out_ch = dim // 2**(layer_num - i - 1)
    in_ch = in_dim if i == 0 else dim // 2**(layer_num - i)
    stages.append(nn.Sequential(
        nn.Conv3d(in_ch, out_ch, (1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1)),
        nn.ReLU(),
        nn.GroupNorm(1, out_ch),
    ))
  projection = nn.Conv3d(dim, head_num * head_dim, 1)
  return nn.Sequential(*stages, projection)

In [None]:
# v_net = create_enc(3,128,8,64,2)
# k_net = create_enc(39,128,8,64,2)
# q_net = create_enc(39,256,8,64,3)
# values = v_net(torch.randn(1,3,10,64,64))
# keys = k_net(torch.randn(1,39,10,64,64))
# querys = q_net(torch.randn(2,39,64,64,64))
# values.shape, keys.shape, querys.shape

In [20]:
def mapping(q, k, v, head_num):
  """Multi-head cross-attention from query positions to key/value positions.

  q, k, v are 5-D (b, head_num * d, D, H, W). Each is split into heads and
  flattened over (D, H, W) into tokens. Attention is softmax(q·kᵀ / sqrt(c /
  head_num)), where c is q's channel count. The attended values are folded back
  onto q's (D, H, W) grid. Each head's value channels are split into two
  halves that are returned as (mu, logvar).
  """
  _, channels, D, H, W = q.shape
  scale = (channels / head_num) ** 0.5

  to_tokens = 'b (h d) D H W -> b h (D H W) d'
  q_tok, k_tok, v_tok = (rearrange(t, to_tokens, h=head_num) for t in (q, k, v))

  attn = torch.softmax(torch.einsum('bhLd,bhSd->bhLS', q_tok, k_tok) / scale, -1)
  attended = torch.einsum('bhLS,bhSd->bhLd', attn, v_tok)

  # Leading axis of size 2 holds the (mu, logvar) halves of every head.
  stacked = rearrange(attended, 'b h (D H W) (k d) -> k b (h d) D H W', D=D, H=H, W=W, k=2)
  mu, logvar = stacked
  return mu, logvar

In [21]:
def reparameterize(mu, logvar):
  """Sample z = mu + exp(0.5 * logvar) * eps with eps ~ N(0, 1) (reparameterization trick)."""
  noise = torch.randn_like(mu)
  return mu + torch.exp(0.5 * logvar) * noise

In [22]:
def kl_divergence(mu, logvar):
  """Element-wise mean of the KL(N(mu, exp(logvar)) || N(0, 1)) terms."""
  terms = 1 + logvar - mu.pow(2) - logvar.exp()
  return -0.5 * terms.mean()

In [None]:
# z = mapping(querys, keys, values, 8)
# z.shape

In [23]:
def create_dec(out_dim, dim, head_num, head_dim, layer_num):
  """Decoder from the (head_num * head_dim)-channel latent to `out_dim` channels.

  The stem is a grouped (per-head) (1,3,3) conv, then a 1x1x1 conv to `dim`,
  each followed by ReLU and GroupNorm(1, ·). Each of the `layer_num`
  ConvTranspose3d stages (kernel (1,4,4), stride (1,2,2), padding (0,1,1))
  doubles H and W, keeps D, and halves the channels. A final 1x1x1 conv maps
  to `out_dim`.

  The module index layout must stay as-is because trained checkpoints are
  restored with load_state_dict.
  """
  latent = head_num * head_dim
  stem = [
      nn.Conv3d(latent, latent, (1, 3, 3), padding=(0, 1, 1), groups=head_num),
      nn.ReLU(),
      nn.GroupNorm(1, latent),
      nn.Conv3d(latent, dim, 1),
      nn.ReLU(),
      nn.GroupNorm(1, dim),
  ]
  upsamplers = []
  for i in range(layer_num):
    c_in, c_out = dim // 2**i, dim // 2**(i + 1)
    upsamplers.append(nn.Sequential(
        nn.ConvTranspose3d(c_in, c_out, (1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1)),
        nn.ReLU(),
        nn.GroupNorm(1, c_out),
    ))
  head = nn.Conv3d(dim // 2**layer_num, out_dim, 1)
  return nn.Sequential(*stem, *upsamplers, head)

In [None]:
# dec_net = create_dec(4, 256, 8, 64, 3)
# dec_net(z).shape

In [24]:
def load_data(path):
  """Load a dataset bundle from an .npz file.

  Returns:
    (imgs, poses, render_poses, [H, W, focal], i_split), read from the arrays
    with the same names. "hwf" is unpacked into a 3-element list.
  """
  # NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
  # execute arbitrary code, so only load trusted files. It is kept because the bundle
  # may contain object arrays (e.g. a ragged i_split) — TODO confirm.
  # The context manager closes the file handle held by the NpzFile.
  with np.load(path, allow_pickle=True) as data:
    imgs = data["imgs"]
    poses = data["poses"]
    render_poses = data["render_poses"]
    [H, W, focal] = data["hwf"]
    i_split = data["i_split"]

  return imgs, poses, render_poses, [H, W, focal], i_split

In [25]:
def posenc(pts, L_embed):
  """Frequency positional encoding along dim 1.

  Concatenates pts with sin and cos of 2**i * pts for i in 0..L_embed-1 (sin
  before cos for each i). The output has (2 * L_embed + 1) times pts' channels.
  """
  bands = [fn(2.**i * pts) for i in range(L_embed) for fn in (torch.sin, torch.cos)]
  return torch.cat([pts, *bands], 1)

def sample_z(near, far, N_samples, H=1, W=1, bz=1, rand=False,):
  """Depth samples along rays, shaped (bz, 1, N_samples, H, W) on `device`.

  Depths are N_samples evenly spaced values in [near, far]. With rand=True,
  each depth is shifted by uniform noise in [0, (far - near) / N_samples).
  """
  depths = torch.linspace(near, far, N_samples,).to(device)
  depths = repeat(depths, 'n -> b 1 n h w', b=bz, h=H, w=W)
  if not rand:
    return depths
  # rand_like already allocates on depths' device.
  jitter = torch.rand_like(depths) * (far - near) / N_samples
  return depths + jitter

def sample_points(rays_o, rays_d, z_vals,):
  """3-D points o + t * d for every depth t in z_vals.

  A new axis is inserted at dim 2 of rays_o and rays_d (presumably (B, 3, H, W)
  from get_rays) so they broadcast against z_vals (e.g. (B, 1, N, H, W)),
  giving (B, 3, N, H, W).
  """
  origins = rays_o[:, :, None]
  directions = rays_d[:, :, None]
  return origins + directions * z_vals

def get_rays(H, W, focal, c2ws):
  """Per-pixel ray origins and directions for a batch of pinhole cameras.

  Args:
    H, W: image height and width in pixels.
    focal: focal length in pixels.
    c2ws: (B, >=3, >=4) camera-to-world matrices; only the top-left 3x3
      rotation and the translation column [:3, -1] are used.

  Returns:
    rays_o, rays_d: (B, 3, H, W) tensors on c2ws' device. Camera-space
    directions are ((col - W/2)/focal, -(row - H/2)/focal, -1), i.e. x to the
    right, y up, looking down -z. They are rotated into world space and are not
    normalised.
  """
  b = c2ws.shape[0]
  # y: row index, x: column index, both (H, W) (meshgrid's default 'ij' indexing).
  y, x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W),)

  # dirs: (3, H, W). Built on c2ws' device rather than the global `device`, so the
  # einsum below never mixes devices whatever device the caller uses.
  dirs = torch.stack([(x-W*.5)/focal, -(y-H*.5)/focal, -torch.ones(H,W)], 0).to(c2ws.device)

  # rays_d[b, u] = sum_v R[b, u, v] * dirs[v]: rotate camera-space dirs into world space.
  rays_d = torch.einsum("vhw,buv->buhw", dirs, c2ws[...,:3,:3])

  # Every ray of a camera starts at that camera's translation column.
  rays_o = torch.broadcast_to(c2ws[...,:3,-1].reshape(b,3,1,1), rays_d.shape)
  return rays_o, rays_d

In [48]:
# Depth range [near, far]: passed to sample_z as the bounds of the per-ray depth samples.
near = 2
far = 6
# Number of depth samples along each target ray of the rendered view.
N_samples = 64
# Number of reference ("base") views whose images and poses feed v_net and k_net.
base_num = 4

# Number of frequency bands in posenc; a 3-channel point becomes (2*L_embed+1)*3 channels.
L_embed = 12
# Channel width passed to create_enc / create_dec.
dim = 256
# Attention heads and channels per head shared by the encoders, mapping() and the decoder.
head_num = 32
head_dim = 32

In [27]:
# v_net encodes the RGB base images. It emits 2*head_dim channels per head so that
# mapping() can split each head's attended values into mu and logvar halves.
v_net = create_enc(3, dim, head_num, head_dim*2, 3).to(device)
# k_net / q_net encode posenc of 3-D ray points, which has (L_embed*2+1)*3 channels.
k_net = create_enc((L_embed*2+1)*3, dim, head_num, head_dim, 3).to(device)
q_net = create_enc((L_embed*2+1)*3, dim, head_num, head_dim, 3).to(device)
# dec_net outputs 4 channels: f() uses channels 0-2 as RGB logits and channel 3 as density.
dec_net = create_dec(4, dim, head_num, head_dim, 3).to(device)

In [45]:
# Restore trained weights for all four networks from one checkpoint dict.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a trusted source.
# map_location=device remaps the stored tensors onto `device`.
checkpoint = torch.load("./decoupling_nerf-vae_state.pt", map_location=device)

# Keys must match the module layout built by create_enc / create_dec.
v_net.load_state_dict(checkpoint['v_net_state_dict'])
k_net.load_state_dict(checkpoint['k_net_state_dict'])
q_net.load_state_dict(checkpoint['q_net_state_dict'])
dec_net.load_state_dict(checkpoint['dec_net_state_dict'])

<All keys matched successfully>

In [50]:
imgs, poses, render_poses, hwf, _ = load_data("./lego-64.npz")
imgs = rearrange(imgs, 'b H W c -> b c H W')
# Pick `base_num` random reference views. The index range comes from the data rather
# than a hard-coded 400, so a dataset of another size cannot index out of bounds.
# NOTE: no seed is set, so the chosen views differ between runs.
indices = np.arange(len(imgs))
np.random.shuffle(indices)
base_idx = indices[:base_num]

In [52]:
%matplotlib inline
from math import sin, cos
from ipywidgets import interactive, widgets


def trans_t(t):
  """4x4 homogeneous transform translating by t along z, created on `device`."""
  return torch.tensor([
      [1.,0,0,0],
      [0,1,0,0],
      [0,0,1,t],
      [0,0,0,1],
  ], device=device)


def rot_phi(phi):
  """4x4 homogeneous rotation by phi radians in the y-z plane (about x), on `device`."""
  c, s = cos(phi), sin(phi)
  return torch.tensor([
      [1.,0,0,0],
      [0,c,-s,0],
      [0,s, c,0],
      [0,0,0,1],
  ], device=device)


def rot_theta(th):
  """4x4 homogeneous rotation by th radians in the x-z plane (about y), on `device`."""
  c, s = cos(th), sin(th)
  return torch.tensor([
      [c,0,-s,0],
      [0,1,0,0],
      [s,0, c,0],
      [0,0,0,1],
  ], device=device)


def pose_spherical(theta, phi, radius):
  """Camera-to-world matrix for a camera on a sphere around the origin.

  theta and phi are in degrees (converted to radians below). radius is the
  z-translation applied before the phi and theta rotations. A fixed axis
  flip/swap (x -> -x, y <-> z) is applied last.
  """
  c2w = trans_t(radius)
  c2w = rot_phi(phi/180.*np.pi) @ c2w
  c2w = rot_theta(theta/180.*np.pi) @ c2w
  # Create the axis-swap matrix on c2w's device: multiplying a CPU tensor with a
  # CUDA c2w raises a device-mismatch error.
  axis_swap = torch.tensor([[-1.,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]], device=c2w.device)
  return axis_swap @ c2w


def _render_rgb(rgba, z_vals, H, W):
  """Alpha-composite per-sample decoder outputs along each ray into an image.

  Args:
    rgba: decoder output. Channels 0-2 are RGB logits (sigmoid) and channel 3
      is raw density (ReLU). Presumably (1, 4, N_samples, H, W) — TODO confirm.
    z_vals: sample depths along dim 2, e.g. (1, 1, N_samples, 1, 1) as returned
      by sample_z(..., 1, 1, 1, ...).
    H, W: image height and width in pixels.

  Returns:
    The composited colour, summed over the sample dim (dim 2).
  """
  rgb = torch.sigmoid(rgba[:,:3,...])
  sigma_a = F.relu(rgba[:,3:,...])

  # Distance between consecutive samples; the last sample gets an effectively infinite interval.
  dists = torch.cat([z_vals[:,:,1:,...] - z_vals[:,:,:-1,...], torch.ones(1, 1, 1, 1, 1, device=device)*1e10], 2)
  alpha = 1.-torch.exp(-sigma_a * dists)
  # Exclusive cumulative product (transmittance): prepend ones and drop the last term.
  _alpha = (1.-alpha + 1e-10)[:,:,:-1,...]
  _alpha = torch.cat([torch.ones(1, 1, 1, H, W, device=device), _alpha],2)
  weights = alpha * torch.cumprod(_alpha, 2)
  return torch.sum(weights * rgb, 2)


def f(**kwargs):
  """Render the scene from a spherical camera pose and display the image.

  This is the callback of the ipywidgets `interactive` widget. `kwargs`
  (theta, phi, radius) are forwarded to pose_spherical. The reference views
  selected by `base_idx` supply keys (posenc of their ray points) and values
  (their images). Queries come from posenc of N_samples points along each
  target ray.
  """
  with torch.no_grad():
    # Reference images without their last channel (imgs is laid out b c H W).
    base_img = torch.tensor(imgs[base_idx,:-1,...]).type(torch.FloatTensor).to(device)
    base_img = rearrange(base_img, 'b c H W -> 1 c b H W')
    H, W, focal = hwf
    H, W = int(H), int(W)

    # float32 keeps the einsum in get_rays dtype-consistent with the float32 ray
    # directions even if the poses were stored as float64.
    base_c2ws = torch.tensor(poses[base_idx], device=device, dtype=torch.float32)
    base_rays_o, base_rays_d = get_rays(H, W, focal, base_c2ws)
    # A single depth sample per base ray (linspace with one step yields `near`).
    base_z_vals = sample_z(near, far, 1, 1, 1, 1)
    base_pts = sample_points(base_rays_o, base_rays_d, base_z_vals)
    base_pos = posenc(base_pts, L_embed)
    base_pos = rearrange(base_pos, 'b e 1 H W -> 1 e b H W')

    c2w = pose_spherical(**kwargs)
    c2ws = rearrange(c2w, 'u v -> 1 u v')
    target_rays_o, target_rays_d = get_rays(H, W, focal, c2ws)
    target_z_vals = sample_z(near, far, N_samples, 1, 1, 1, True)
    target_pts = sample_points(target_rays_o, target_rays_d, target_z_vals)
    target_pos = posenc(target_pts, L_embed)

    q = q_net(target_pos)
    k = k_net(base_pos)
    v = v_net(base_img)
    mu, logvar = mapping(q, k, v, head_num)
    z = reparameterize(mu, logvar)
    rgba = dec_net(z)

    rgb_map = _render_rgb(rgba, target_z_vals, H, W)
    # .numpy() only works on CPU tensors, so copy back from the GPU first.
    out_img = rearrange(rgb_map, '1 c H W -> H W c').cpu().numpy()

  plt.figure(2, figsize=(20,6))
  plt.imshow(out_img)
  plt.show()
    

def sldr(v, mi, ma):
  """FloatSlider starting at v over [mi, ma] with a 0.01 step."""
  return widgets.FloatSlider(
      value=v,
      min=mi,
      max=ma,
      step=.01,
  )

# (slider name, [initial, min, max]). Names match pose_spherical's parameters
# because f forwards the slider values to it as keyword arguments.
names = [
    ['theta', [100., 0., 360]],
    ['phi', [-30., -90, 0]],
    ['radius', [4., 3., 5.]],
]

interactive_plot = interactive(f, **{name: sldr(*spec) for name, spec in names})
output = interactive_plot.children[-1]  # the widget's output area
output.layout.height = '350px'
interactive_plot

interactive(children=(FloatSlider(value=100.0, description='theta', max=360.0, step=0.01), FloatSlider(value=-…