<a href="https://colab.research.google.com/github/tomrodenhagen/pycluster/blob/master/ImplicitGan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
!git clone https://github.com/tomrodenhagen/Implicit-Renderer.git

Cloning into 'Implicit-Renderer'...
fatal: could not read Username for 'https://github.com': No such device or address


In [None]:
def get_logits_from_prob(probs, eps=1e-4):
    ''' Converts probabilities into logits (inverse sigmoid).

    Args:
        probs: probability value(s); clipped into [eps, 1 - eps] before
            the logarithm is taken
        eps (float): clipping margin that keeps the logarithm finite

    Returns:
        log(p / (1 - p)) of the clipped input, computed with NumPy
    '''
    lower, upper = eps, 1 - eps
    clipped = np.clip(probs, lower, upper)
    odds = clipped / (1 - clipped)
    return np.log(odds)
class DepthFunction(torch.autograd.Function):
    ''' Depth Function class.
    It provides the function to march along given rays to detect the surface
    points for the OccupancyNetwork. The backward pass is implemented using
    the analytic gradient described in the publication.

    The ray marching in forward() runs under torch.no_grad(); gradients are
    produced only by the hand-written backward().
    '''
    @staticmethod
    def run_Bisection_method(d_low, d_high, n_secant_steps, ray0_masked,
                             ray_direction_masked, decoder, c, logit_tau):
        ''' Runs the bisection method for interval [d_low, d_high].
        Note: d_low and d_high are updated in place.
        Args:
            d_low (tensor): start values for the interval
            d_high (tensor): end values for the interval
            n_secant_steps (int): number of steps
            ray0_masked (tensor): masked ray start points
            ray_direction_masked (tensor): masked ray direction vectors
            decoder (nn.Module): decoder model to evaluate point occupancies
            c (tensor): latent conditioned code c
            logit_tau (float): threshold value in logits
        Returns:
            d_pred (tensor): interval midpoints after the last step
        '''
        d_pred = (d_low + d_high) / 2.
        for i in range(n_secant_steps):
            p_mid = ray0_masked + d_pred.unsqueeze(-1) * ray_direction_masked
            with torch.no_grad():
                f_mid = decoder(p_mid, c, batchwise=False,
                                only_occupancy=True) - logit_tau
            # Midpoints below the threshold replace the lower bound, all
            # other midpoints replace the upper bound.
            ind_low = f_mid < 0
            d_low[ind_low] = d_pred[ind_low]
            d_high[ind_low == 0] = d_pred[ind_low == 0]
            d_pred = 0.5 * (d_low + d_high)
        return d_pred

    @staticmethod
    def run_Secant_method(f_low, f_high, d_low, d_high, n_secant_steps,
                          ray0_masked, ray_direction_masked, decoder, c,
                          logit_tau):

        ''' Runs the secant method for interval [d_low, d_high].
        Note: f_low, f_high, d_low and d_high are updated in place.
        Args:
            f_low (tensor): function values (minus threshold) at d_low
            f_high (tensor): function values (minus threshold) at d_high
            d_low (tensor): start values for the interval
            d_high (tensor): end values for the interval
            n_secant_steps (int): number of steps
            ray0_masked (tensor): masked ray start points
            ray_direction_masked (tensor): masked ray direction vectors
            decoder (nn.Module): decoder model to evaluate point occupancies
            c (tensor): latent conditioned code c
            logit_tau (float): threshold value in logits
        Returns:
            d_pred (tensor): secant estimate of the root after the last step
        '''
        # Root of the line through (d_low, f_low) and (d_high, f_high).
        d_pred = - f_low * (d_high - d_low) / (f_high - f_low) + d_low
        for i in range(n_secant_steps):
            p_mid = ray0_masked + d_pred.unsqueeze(-1) * ray_direction_masked
            with torch.no_grad():
                f_mid = decoder(p_mid, c, batchwise=False,
                                only_occupancy=True) - logit_tau
            ind_low = f_mid < 0

            # Replace whichever interval end lies on the same side of the
            # threshold as the new estimate.
            if ind_low.sum() > 0:
                d_low[ind_low] = d_pred[ind_low]
                f_low[ind_low] = f_mid[ind_low]
            if (ind_low == 0).sum() > 0:

                d_high[ind_low == 0] = d_pred[ind_low == 0]
                f_high[ind_low == 0] = f_mid[ind_low == 0]

            d_pred = - f_low * (d_high - d_low) / (f_high - f_low) + d_low
        return d_pred

    @staticmethod
    def perform_ray_marching(ray0, ray_direction, decoder, c=None,
                             tau=0.5, n_steps=[128, 129], n_secant_steps=8,
                             depth_range=[0., 2.4], method='secant',
                             check_cube_intersection=True, max_points=3500000, debug=True):
        ''' Performs ray marching to detect surface points.
        The function returns the surface points as well as d_i of the formula
            ray(d_i) = ray0 + d_i * ray_direction
        which hit the surface points. In addition, masks are returned for
        illegal values.
        Args:
            ray0 (tensor): ray start points of dimension B x N x 3
            ray_direction (tensor):ray direction vectors of dim B x N x 3
            decoder (nn.Module): decoder model to evaluate point occupancies
            c (tensor): latent conditioned code
            tau (float): threshold value
            n_steps (tuple): interval from which the number of evaluation
                steps is sampled (passed to torch.randint as low and high)
            n_secant_steps (int): number of secant refinement steps
            depth_range (tuple): range of possible depth values (not relevant when
                using cube intersection)
            method (string): refinement method (default: secant)
            check_cube_intersection (bool): whether to intersect rays with
                unit cube for evaluation
            max_points (int): max number of points loaded to GPU memory
            debug (bool): not read anywhere in this function
        Returns:
            d_pred_out (tensor): depth per ray of dimension B x N; entries
                where mask is False are left at 1
            pt_pred (tensor): points ray0 + d * ray_direction of dimension
                B x N x 3; entries where mask is False are left at 1
            mask (tensor): B x N mask of rays with a valid depth value
            mask_0_not_occupied (tensor): B x N mask of rays whose first
                proposal point evaluates below the threshold
        '''
        # Shortcuts
        batch_size, n_pts, D = ray0.shape
        device = ray0.device
        logit_tau = get_logits_from_prob(tau)
        n_steps = torch.randint(n_steps[0], n_steps[1], (1,)).item()

        # Prepare d_proposal and p_proposal in form (b_size, n_pts, n_steps, 3)
        # d_proposal are "proposal" depth values and p_proposal the
        # corresponding "proposal" 3D points
        d_proposal = torch.linspace(
            depth_range[0], depth_range[1], steps=n_steps).view(
                1, 1, n_steps, 1).to(device)
        d_proposal = d_proposal.repeat(batch_size, n_pts, 1, 1)

        # NOTE(review): get_proposal_points_in_unit_cube is not defined in the
        # visible part of this file - confirm it is imported before enabling
        # check_cube_intersection.
        if check_cube_intersection:
            d_proposal_cube, mask_inside_cube = \
                get_proposal_points_in_unit_cube(ray0, ray_direction,
                                                 padding=0.1,
                                                 eps=1e-6, n_steps=n_steps)
            d_proposal[mask_inside_cube] = d_proposal_cube[mask_inside_cube]

        p_proposal = ray0.unsqueeze(2).repeat(1, 1, n_steps, 1) + \
            ray_direction.unsqueeze(2).repeat(1, 1, n_steps, 1) * d_proposal

        # Evaluate all proposal points in parallel, split along the point
        # dimension into chunks of int(max_points / batch_size) points
        with torch.no_grad():
            val = torch.cat([(
                decoder(p_split, c, only_occupancy=True) - logit_tau)
                for p_split in torch.split(
                    p_proposal.view(batch_size, -1, 3),
                    int(max_points / batch_size), dim=1)], dim=1).view(
                        batch_size, -1, n_steps)
        # Create mask for valid points where the first point is not occupied
        mask_0_not_occupied = val[:, :, 0] < 0

        # Calculate if sign change occurred and concat 1 (no sign change) in
        # last dimension
        sign_matrix = torch.cat([torch.sign(val[:, :, :-1] * val[:, :, 1:]),
                                 torch.ones(batch_size, n_pts, 1).to(device)],
                                dim=-1)
        # Weight the entries with n_steps, n_steps - 1, ..., 1 so that the
        # minimum over the last dimension selects the earliest sign change
        cost_matrix = sign_matrix * torch.arange(
            n_steps, 0, -1).float().to(device)
        # Get first sign change and masks for rays where a.) a sign change
        # occurred and b.) the value at the selected index (just before the
        # change) is negative
        values, indices = torch.min(cost_matrix, -1)
        mask_sign_change = values < 0
        mask_neg_to_pos = val[torch.arange(batch_size).unsqueeze(-1),
                              torch.arange(n_pts).unsqueeze(-0), indices] < 0

        # Define mask where a valid depth value is found
        mask = mask_sign_change & mask_neg_to_pos & mask_0_not_occupied

        # Get depth values and function values for the interval
        # to which we want to apply the Secant method
        n = batch_size * n_pts
        d_low = d_proposal.view(
            n, n_steps, 1)[torch.arange(n), indices.view(n)].view(
                batch_size, n_pts)[mask]
        f_low = val.view(n, n_steps, 1)[torch.arange(n), indices.view(n)].view(
            batch_size, n_pts)[mask]
        # Upper interval end is the next proposal step, clamped to the last one
        indices = torch.clamp(indices + 1, max=n_steps-1)
        d_high = d_proposal.view(
            n, n_steps, 1)[torch.arange(n), indices.view(n)].view(
                batch_size, n_pts)[mask]
        f_high = val.view(
            n, n_steps, 1)[torch.arange(n), indices.view(n)].view(
                batch_size, n_pts)[mask]

        ray0_masked = ray0[mask]
        ray_direction_masked = ray_direction[mask]


        # Repeat the latent code once per ray and keep the masked rays only
        if c is not None and c.shape[-1] != 0:
           c = c.unsqueeze(1).repeat(1, n_pts, 1)[mask]
        # Apply surface depth refinement step (e.g. Secant method)
        if method == 'secant' and mask.sum() > 0:
            d_pred = DepthFunction.run_Secant_method(
                f_low, f_high, d_low, d_high, n_secant_steps, ray0_masked,
                ray_direction_masked, decoder, c, logit_tau)
        elif method == 'bisection' and mask.sum() > 0:
            d_pred = DepthFunction.run_Bisection_method(
                d_low, d_high, n_secant_steps, ray0_masked,
                ray_direction_masked, decoder, c, logit_tau)
        else:
            # Any other method name, or no valid ray at all: depth of ones
            d_pred = torch.ones(ray_direction_masked.shape[0]).to(device)

        # for sanity
        pt_pred = torch.ones(batch_size, n_pts, 3).to(device)
        pt_pred[mask] = ray0_masked + \
            d_pred.unsqueeze(-1) * ray_direction_masked
        # for sanity
        d_pred_out = torch.ones(batch_size, n_pts).to(device)
        d_pred_out[mask] = d_pred

        return d_pred_out, pt_pred, mask, mask_0_not_occupied

    @staticmethod
    def forward(ctx, *input):
        ''' Performs a forward pass of the Depth function.
        Args:
            input (list): input to forward function; only the first 11
                entries (ray0, ray_direction, decoder, c, n_steps,
                n_secant_steps, tau, depth_range, method,
                check_cube_intersection, max_points) are read here
        Returns:
            d_pred (tensor): predicted depth per ray, set to 20 where no
                valid depth was found and to 0 where the first proposal
                point is already occupied
        '''
        (ray0, ray_direction, decoder, c, n_steps, n_secant_steps, tau,
         depth_range, method, check_cube_intersection, max_points) = input[:11]

        # Get depth values
        with torch.no_grad():
            d_pred, p_pred, mask, mask_0_not_occupied = \
                DepthFunction.perform_ray_marching(
                    ray0, ray_direction, decoder, c, tau, n_steps,
                    n_secant_steps, depth_range, method, check_cube_intersection,
                    max_points)

        # Insert appropriate values for points where no depth is predicted
        d_pred[mask == 0] = 20
        d_pred[mask_0_not_occupied == 0] = 0

        # Save values for backward pass
        ctx.save_for_backward(ray0, ray_direction, d_pred, p_pred, c)
        ctx.decoder = decoder
        ctx.mask = mask

        return d_pred

    @staticmethod
    def backward(ctx, grad_output):
        ''' Performs the backward pass of the Depth function.
        We use the analytic formula derived in the main publication for the
        gradients.
        Note: As for every input a gradient has to be returned, we return
        None for the elements which do no require gradients (e.g. decoder).
        Args:
            ctx (Pytorch Autograd Context): pytorch autograd context
            grad_output (tensor): gradient outputs
        Returns:
            tuple: 11 entries for the fixed inputs of forward (all None
                except the gradient for c at position 3), followed by one
                entry per decoder parameter
        '''
        ray0, ray_direction, d_pred, p_pred, c = ctx.saved_tensors
        decoder = ctx.decoder
        mask = ctx.mask
        eps = 1e-3

        with torch.enable_grad():
            # Re-evaluate the decoder at the predicted points with gradient
            # tracking switched on.
            p_pred.requires_grad = True
            f_p = decoder(p_pred, c, only_occupancy=True)
            f_p_sum = f_p.sum()
            grad_p = torch.autograd.grad(f_p_sum, p_pred, retain_graph=True)[0]
            grad_p_dot_v = (grad_p * ray_direction).sum(-1)

            if mask.sum() > 0:
                grad_p_dot_v[mask == 0] = 1.
                # Sanity: keep the denominator away from zero
                grad_p_dot_v[abs(grad_p_dot_v) < eps] = eps
                grad_outputs = -grad_output.squeeze(-1)
                grad_outputs = grad_outputs / grad_p_dot_v
                grad_outputs = grad_outputs * mask.float()

            # Gradients for latent code c
            if c is None or c.shape[-1] == 0 or mask.sum() == 0:
                gradc = None
            else:
                gradc = torch.autograd.grad(f_p, c, retain_graph=True,
                                            grad_outputs=grad_outputs)[0]

            # Gradients for network parameters phi
            if mask.sum() > 0:
                # Accumulates gradients weighted by grad_outputs variable
                grad_phi = torch.autograd.grad(
                    f_p, [k for k in decoder.parameters()],
                    grad_outputs=grad_outputs, retain_graph=True)
            else:
                grad_phi = [None for i in decoder.parameters()]

        # Return gradients for c and the network parameters and None
        # for all other inputs
        out = [None, None, None, gradc, None, None, None, None, None,
               None, None] + list(grad_phi)
        return tuple(out)

In [None]:
import os# %matplotlib inline
# %matplotlib notebook
import os
import sys
import time
import json
import glob
import torch
import math
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from IPython import display
from tqdm.notebook import tqdm

# Data structures and functions for rendering
from pytorch3d.structures import Volumes
from pytorch3d.transforms import so3_exponential_map
from pytorch3d.renderer import (
    FoVPerspectiveCameras, 
    NDCGridRaysampler,
    MonteCarloRaysampler,
    EmissionAbsorptionRaymarcher,
    ImplicitRenderer,
    RayBundle,
    ray_bundle_to_ray_points,
)

# Pick the compute device: the first CUDA GPU when one is visible,
# otherwise fall back to the CPU and warn about the expected slowdown.
if not torch.cuda.is_available():
    print(
        'Please note that NeRF is a resource-demanding method.'
        ' Running this notebook on CPU will be extremely slow.'
        ' We recommend running the example on a GPU'
        ' with at least 10 GB of memory.'
    )
    device = torch.device("cpu")
else:
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)

ModuleNotFoundError: ignored

In [None]:
# render_size describes the size of both sides of the
# rendered images in pixels. Since an advantage of 
# Neural Radiance Fields are high quality renders
# with a significant amount of details, we render
# the implicit function at double the size of 
# target images.
# NOTE(review): the "double the size" remark above cannot be checked from
# this cell; render_size is simply fixed at 32 here.
render_size = 32
# Our rendered scene is centered around (0,0,0) 
# and is enclosed inside a bounding box
# whose side is roughly equal to 3.0 (world units).
volume_extent_world = 3.0

# 1) Instantiate the raysamplers.

# Here, NDCGridRaysampler generates a rectangular image
# grid of rays whose coordinates follow the PyTorch3d
# coordinate conventions.
raysampler_grid = NDCGridRaysampler(
    image_height=render_size,
    image_width=render_size,
    n_pts_per_ray=128,
    min_depth=0.1,
    max_depth=volume_extent_world,
)

# MonteCarloRaysampler generates a random subset 
# of `n_rays_per_image` rays emitted from the image plane.
raysampler_mc = MonteCarloRaysampler(
    min_x = -1.0,
    max_x = 1.0,
    min_y = -1.0,
    max_y = 1.0,
    n_rays_per_image=750,
    n_pts_per_ray=64,
    min_depth=0.1,
    max_depth=volume_extent_world,
)

# 2) Instantiate the raymarcher.
# Here, we use the standard EmissionAbsorptionRaymarcher 
# which marches along each ray in order to render
# the ray into a single 3D color vector 
# and an opacity scalar.
raymarcher = EmissionAbsorptionRaymarcher()

# Finally, instantiate the implicit renderers
# for both raysamplers; they share the same raymarcher instance.
renderer_grid = ImplicitRenderer(
    raysampler=raysampler_grid, raymarcher=raymarcher,
)
renderer_mc = ImplicitRenderer(
    raysampler=raysampler_mc, raymarcher=raymarcher,
)

In [None]:
!wget https://raw.githubusercontent.com/facebookresearch/pytorch3d/master/docs/tutorials/utils/plot_image_grid.py
!wget https://raw.githubusercontent.com/facebookresearch/pytorch3d/master/docs/tutorials/utils/generate_cow_renders.py
from plot_image_grid import image_grid
from generate_cow_renders import generate_cow_renders

In [None]:
from tensorflow.keras.datasets import cifar10
!nvidia-smi
# Load the CIFAR-10 train/test splits through the Keras dataset helper
# imported above; trainX is the array later indexed by sample_images().
print("[INFO] loading CIFAR-10 data")
((trainX, trainY), (testX, testY)) = cifar10.load_data()

In [None]:
from torch import nn
class HarmonicEmbedding(torch.nn.Module):
    """
    Sinusoidal feature expansion of the last dimension of a tensor.

    Every feature `x[..., i]` is multiplied by each of the
    `n_harmonic_functions` frequencies
        omega0 * 2**k,   k = 0, ..., n_harmonic_functions - 1
    and the products are flattened feature by feature:
        [f_0*x_0, f_1*x_0, ..., f_0*x_1, f_1*x_1, ...]
    The embedding is the sine of that flattened vector followed by
    its cosine.
    """

    def __init__(self, n_harmonic_functions=60, omega0=0.1):
        """
        Args:
            n_harmonic_functions: number of frequencies per input feature
            omega0: factor applied to every frequency
        """
        super().__init__()
        octave_scales = 2.0 ** torch.arange(n_harmonic_functions)
        # Registered as a buffer so it is part of the module state without
        # being a trainable parameter.
        self.register_buffer('frequencies', omega0 * octave_scales)

    def forward(self, x):
        """
        Args:
            x: tensor of shape [..., dim]
        Returns:
            embedding: a harmonic embedding of `x`
                of shape [..., n_harmonic_functions * dim * 2]
        """
        scaled = x.unsqueeze(-1) * self.frequencies
        flat = scaled.view(*x.shape[:-1], -1)
        return torch.cat((torch.sin(flat), torch.cos(flat)), dim=-1)
class Decoder(nn.Module):
    """Pair of small MLPs predicting an occupancy value and an RGB colour
    for every 3D point.

    * `decoder_occ` maps the raw 3D point to one scalar.
    * `decoder_col` maps the harmonic embedding of the point to 3 channels.
      Its input width 360 matches the default `HarmonicEmbedding()`
      (3 coordinates * 60 frequencies * sin/cos).
    """

    def __init__(self):
        super().__init__()
        n_neurons = 128
        self.harmonic_embedding = HarmonicEmbedding()
        self.decoder_occ = nn.Sequential(
            nn.Linear(3, n_neurons),
            nn.ReLU(),
            nn.Linear(n_neurons, n_neurons),
            nn.ReLU(),
            nn.Linear(n_neurons, n_neurons),
            nn.ReLU(),
            nn.Linear(n_neurons, 1),
        )
        self.decoder_col = nn.Sequential(
            nn.Linear(360, n_neurons),
            nn.ReLU(),
            nn.Linear(n_neurons, n_neurons),
            nn.ReLU(),
            nn.Linear(n_neurons, n_neurons),
            nn.ReLU(),
            nn.Linear(n_neurons, 3),
        )

    def _param_device(self):
        """Return the device the module's parameters currently live on."""
        return next(self.parameters()).device

    def forward(self, p, c, only_occupancy=True, batchwise=False):
        """Evaluate occupancy (and optionally colour) at the points `p`.

        Args:
            p: tensor of points of shape [..., 3]
            c: latent code; accepted for interface compatibility but not
                read by this decoder
            only_occupancy: if True return only the occupancy tensor
            batchwise: accepted for interface compatibility, not read

        Returns:
            `occ` of shape [...] when `only_occupancy` is True, otherwise the
            tuple `(colour, occ)` with `colour` of shape [..., 3] in (0, 1).
        """
        # The colour branch is evaluated even in occupancy-only mode so that
        # its parameters stay part of the autograd graph of `occ`;
        # DepthFunction.backward differentiates w.r.t. *all* decoder
        # parameters and does not pass allow_unused.
        x_col = self.decoder_col(self.harmonic_embedding(p))
        x_occ = self.decoder_occ(p)
        x = torch.cat((x_occ, x_col), dim=-1)

        eps = 1
        norm_bound = 3
        occ = x[..., 0] - 1
        norm = torch.norm(p, dim=-1)
        # Soft indicator of |p| < norm_bound: ~1 inside the ball, ~0 outside.
        weight = torch.sigmoid((norm_bound - norm) * 10)
        # Inside the ball use the network output, outside blend towards a
        # fixed analytic profile of the distance to the origin.
        occ = occ * weight + (1 - weight) * (
            torch.sigmoid(-(norm_bound - norm + eps) * 10) - 0.5)

        if only_occupancy:
            return occ
        return torch.sigmoid(x[..., 1:]), occ

    def sample_points(self, n_points):
        """Evaluate the occupancy at `n_points` uniform samples of [-3, 3)^3.

        The samples are moved to the device of the module's own parameters
        (instead of a notebook-level global), so the method works wherever
        the module has been placed.

        Returns:
            tuple `(occupancy, points)`
        """
        points = torch.rand((n_points, 3)).to(self._param_device()) * 6 - 3
        return self(points, only_occupancy=True, c=None), points

    def pretrain(self, n_iter=1000):
        """Fit the occupancy to a sphere of radius 1.5 around the origin.

        Each step draws 1000 uniform points from [-3, 3)^3 and regresses the
        occupancy onto +1 inside / -1 outside the sphere with an MSE loss.
        The loss of the last step is printed; nothing is printed when
        `n_iter` is 0.
        """
        dev = self._param_device()
        optim = torch.optim.Adam(self.parameters(), lr=0.001)
        loss = None
        for _ in range(n_iter):
            optim.zero_grad()
            X = torch.rand((1000, 3), device=dev) * 6 - 3
            Y = torch.le(torch.norm(X, dim=-1), 1.5).float() * 2 - 1

            pred = self(X, only_occupancy=True, c=None)
            loss = torch.mean((pred - Y) ** 2)
            loss.backward()
            optim.step()
        if loss is not None:
            print(loss)


In [None]:
from pytorch3d.renderer import (
   
    NDCGridRaysampler,
   
)

# Fresh, randomly initialised decoder for the depth-only experiment below;
# the sphere pre-training step is currently disabled.
decoder = Decoder().to(device)
#decoder.pretrain()
class renderer_implicit(nn.Module):
    """Surface renderer: finds one depth per pixel ray with DepthFunction and
    colours the resulting surface point with the decoder."""

    def __init__(self, res=(32, 32)):
        super().__init__()
        # Ray-marching settings handed to DepthFunction on every call.
        self.calc_depth = DepthFunction.apply
        self.n_steps = [256, 257]
        self.n_secant_steps = 10
        self.method = "secant"
        self.depth_range = [0., 10]
        self.tau = 0.5
        self.check_cube_intersection = False
        self.max_points = 100000
        # Stored but not read by this class.
        self.n_samples = 500
        self.schedule_ray_sampling = False

        # Only ray origins and directions of this sampler are used in
        # forward(); its per-ray depth samples are discarded.
        self.raysampler_grid = NDCGridRaysampler(
            image_height=res[0],
            image_width=res[1],
            n_pts_per_ray=1,
            min_depth=0.1,
            max_depth=0.5,
        ).to(device)

    def forward(self, decoder, c, cameras):
        """Render colour and depth images for every camera.

        Returns:
            tuple `(color, depth)`; `depth` has a trailing singleton channel.
        """
        ray_0, ray_direction, _, _ = self.raysampler_grid(cameras)
        bz, height, width = ray_0.shape[0], ray_0.shape[1], ray_0.shape[2]
        origins_flat = ray_0.view(bz, -1, 3)
        directions_flat = ray_direction.view(bz, -1, 3)

        # NOTE(review): the expansion below is only reached for c != None,
        # which no caller in this notebook uses - confirm the intended shape
        # of c before relying on it.
        c_expanded = (
            None if c is None
            else c.unsqueeze(-1).expand(-1, origins_flat.shape[1], -1)
        )

        depth = self.forward_rays(
            decoder, origins_flat, directions_flat, c=c_expanded)
        depth_transformed = depth.view(bz, height, width, 1)
        # Surface point of every pixel ray.
        p = depth_transformed * ray_direction + ray_0

        color, _ = decoder(p, c, only_occupancy=False)
        return color, depth_transformed

    def forward_rays(self, decoder, ray0, ray_direction, c=None):
        """Run the depth function on flat ray tensors.

        The decoder parameters are appended to the argument list so that the
        custom autograd function can return gradients for them.
        """
        settings = (
            self.n_steps, self.n_secant_steps, self.tau, self.depth_range,
            self.method, self.check_cube_intersection, self.max_points,
        )
        return self.calc_depth(
            ray0, ray_direction, decoder, c, *settings, *decoder.parameters())
   
  

def make_cameras(n_cameras):
    """Build `n_cameras` perspective cameras looking at the origin.

    The cameras are rotated about the y axis by axis-angle values spread
    evenly over [-0.15, 0.15] and translated by 2.0 along z. The projection
    settings (znear, zfar, aspect ratio, fov) are copied from the first
    camera returned by `generate_cow_renders`.
    """
    # Only the reference cameras of the cow renders are needed here.
    reference_cams, _, _ = generate_cow_renders(num_views=2)

    axis_angles = torch.zeros(n_cameras, 3, device=device)
    axis_angles[:, 1] = torch.linspace(-0.15, 0.15, n_cameras, device=device)

    translations = torch.zeros(n_cameras, 3, device=device)
    translations[:, 2] = 2.0

    return FoVPerspectiveCameras(
        R=so3_exponential_map(axis_angles),
        T=translations,
        znear=reference_cams.znear[0],
        zfar=reference_cams.zfar[0],
        aspect_ratio=reference_cams.aspect_ratio[0],
        fov=reference_cams.fov[0],
        device=device,
    )

target_cameras = make_cameras(2)

R = renderer_implicit()


from matplotlib import pyplot as plt
optim = torch.optim.Adam(decoder.parameters(),lr=0.001)


# Depth-only sanity experiment: push the rendered depth towards 2.7 while
# penalising the occupancy values at random points.
for k in range(300):
    optim.zero_grad()
    # R returns (color, depth); only the depth map `o` enters the loss.
    d, o = R(decoder, None, target_cameras)
    occ, points = decoder.sample_points(1000)

    loss = torch.mean((o - 2.7)**2) + torch.mean((torch.norm(occ, dim=-1))**2)
    # Update the weights on every iteration. Previously backward()/step()
    # sat inside the logging guard below, so only every 30th iteration
    # actually trained the decoder.
    loss.backward()
    optim.step()

    # Progress report every 30 iterations: loss, depth at the centre pixel
    # and the depth map of the first camera.
    if k % 30 == 0:
        print(loss, float(o[0,16,16]))
        plt.imshow(o.cpu().detach().numpy()[0,:,:,0])
        plt.show()








In [None]:
# Probe the decoder occupancy at the origin and at the point (0, 0, 2).
t = torch.tensor(([[0,0,0],[0,0,2]])).to(device).float()
print(decoder(t,c=None) )

In [None]:
import random
def preprocess(images):
    """Scale 8-bit pixel values from [0, 255] down to [0, 1]."""
    return images / 255.0
def sample_images(batch_size):
    """Draw `batch_size` distinct images from the global `trainX` array.

    Returns:
        tensor on the global `device` holding the stacked images, with the
        dtype of `trainX` (no rescaling is applied here).
    """
    picks = random.sample(range(trainX.shape[0]), batch_size)
    batch = np.stack([trainX[idx] for idx in picks])
    return torch.tensor(batch, device=device)
def sample_cams(target_cameras, batch_idx):
    """Return a new camera batch holding the entries `batch_idx` of
    `target_cameras` (rotation, translation and projection settings)."""
    per_camera_fields = ("R", "T", "znear", "zfar", "aspect_ratio", "fov")
    selected = {
        field: getattr(target_cameras, field)[batch_idx]
        for field in per_camera_fields
    }
    return FoVPerspectiveCameras(device=device, **selected)

In [None]:
def get_cameras(n_cams):
    """Create `n_cams` perspective cameras at distance 2.7 along z.

    Rotations are axis-angle vectors about the y axis, linearly spaced in
    [-0.15, 0.15]. znear, zfar, aspect ratio and fov are taken from the first
    camera produced by `generate_cow_renders`.
    """
    cow_cams, _, _ = generate_cow_renders(num_views=1, azimuth_range=180)

    log_rotations = torch.zeros(n_cams, 3, device=device)
    log_rotations[:, 1] = torch.linspace(-0.15, 0.15, n_cams, device=device)
    rotations = so3_exponential_map(log_rotations)

    offsets = torch.zeros(n_cams, 3, device=device)
    offsets[:, 2] = 2.7

    projection = dict(
        znear=cow_cams.znear[0],
        zfar=cow_cams.zfar[0],
        aspect_ratio=cow_cams.aspect_ratio[0],
        fov=cow_cams.fov[0],
    )
    return FoVPerspectiveCameras(
        R=rotations, T=offsets, device=device, **projection)



def volumetric_function(ray_bundle: RayBundle, **kwargs):
    """Analytic test volume: a soft sphere of radius 0.5 around the origin.

    Besides computing the volume, the function prints and plots debugging
    information for the ray through pixel (16, 16).

    Returns:
        tuple `(rays_densities, rays_features)`
    """
    # World-space locations of all samples along all rays.
    pts = ray_bundle_to_ray_points(ray_bundle)

    # Debug: distance to the origin along the ray of pixel (16, 16).
    probe = pts.norm(dim=-1)[:, 16, 16, :]
    print(probe.shape)
    plt.plot(probe.cpu().detach().T)
    plt.show()

    # Density: steep sigmoid that is ~1 inside radius 0.5 and ~0 outside.
    dist = pts.norm(dim=-1, keepdim=True)
    rays_densities = torch.sigmoid(-100.0 * (-0.5 + dist))

    # Debug: density profile along the same ray.
    plt.plot(rays_densities.cpu().detach()[:, 16, 16, 0])
    plt.show()

    # Colour: unit direction of each sample mapped from [-1, 1] to [0, 1].
    unit_dirs = torch.nn.functional.normalize(pts, dim=-1)
    rays_features = unit_dirs * 0.5 + 0.5

    return rays_densities, rays_features


In [None]:
# load the models
!git clone https://github.com/csinva/gan-vae-pretrained-pytorch
sys.path.append("./gan-vae-pretrained-pytorch/cifar10_dcgan")
from dcgan import Discriminator, Generator

D = Discriminator(ngpu=1).eval()
G = Generator(ngpu=1).eval()

# load weights
# NOTE(security): torch.load unpickles the file; these checkpoints come from
# a third-party repository cloned above, so only run this with a source you
# trust.
# map_location="cpu" lets the checkpoints load on hosts without a GPU (the
# notebook supports a CPU fallback); the models are moved to CUDA afterwards
# when it is available.
D.load_state_dict(torch.load(
    './gan-vae-pretrained-pytorch/cifar10_dcgan/weights/netD_epoch_199.pth',
    map_location="cpu"))
G.load_state_dict(torch.load(
    './gan-vae-pretrained-pytorch/cifar10_dcgan/weights/netG_epoch_199.pth',
    map_location="cpu"))
if torch.cuda.is_available():
    D = D.cuda()
    G = G.cuda()

In [None]:
# Latent vector fed to the generator G below to produce the front-view image.
# NOTE(review): z is drawn BEFORE the seed is fixed, so the seed makes only
# the random draws after this point repeatable, not z itself - confirm that
# this order is intended.
z = torch.randn( (1,100,1,1), device=device)


torch.manual_seed(4)
def transform(image):
    """Min-max rescale `image` to the range [0, 1].

    The minimum and maximum are taken over the whole tensor (all batch
    entries and channels together). A constant tensor maps to all zeros
    instead of producing NaN from a 0 / 0 division.
    """
    shifted = image - torch.min(image)
    span = torch.max(shifted)
    # Guard the degenerate case span == 0; every other input is divided by
    # exactly the same value as before.
    safe_span = torch.where(span > 0, span, torch.ones_like(span))
    return shifted / safe_span
def inv_transform(image):
    """Map values from [0, 1] to [-1, 1] (centre at 0.5, then double)."""
    centred = image - 0.5
    return centred * 2
# Generate the fixed front-view target image once, without tracking
# gradients through the generator.
with torch.no_grad():
  image = G(z) 
# Pool of 40 cameras from which the front and side views are picked later.
target_cameras = get_cameras(40)

In [None]:
from torch.autograd import Variable

# First move all relevant variables to the correct device.
renderer_grid = renderer_grid.to(device)
renderer_mc = renderer_mc.to(device)


# Instantiate the decoder and pre-fit its occupancy to a sphere.
decoder = Decoder().to(device)
decoder.pretrain()
n_perspectives = 8
# Instantiate the Adam optimizer. We set its master learning rate to 1e-4.
lr = 0.0001
# One GAN latent per side view. It must be a leaf tensor that requires
# gradients, otherwise Adam silently skips it (its .grad stays None) even
# though it is listed among the optimized parameters.
latent = torch.randn(
    size=(n_perspectives, 100, 1, 1), device=device, requires_grad=True)
opt = torch.optim.Adam(list(decoder.parameters()) + [latent] , lr=lr)



def get_color(m):
    """Return the 3 colour channels of a renderer output.

    `m` is a pair whose first entry carries 4 channels in its last dimension
    (3 colour channels followed by 1 further channel); the second entry of the
    pair and the 4th channel are ignored.
    """
    features, _ = m
    rgb, _ = features.split([3, 1], dim=-1)
    return rgb

# Training setup: camera_front is camera 0 of target_cameras; other_cams are
# n_perspectives cameras drawn without replacement from the remaining
# indices 1 .. len(target_cameras) - 1.
# NOTE(review): batch_size is assigned but not read in this cell.

epochs = 20
batch_size = 1
n_iter = 10
camera_front = sample_cams(target_cameras, [0])
other_cams = sample_cams(target_cameras, random.sample([i+1 for i in range(len(target_cameras) - 1) ], n_perspectives) )
renderer_implicit_ = renderer_implicit()
#other_cams = sample_cams(target_cameras, [0,0] )
for e in range(epochs):
  # Running sum of the total loss over this epoch (accumulated, not printed).
  loss = 0
  for k in range(n_iter):
        opt.zero_grad()



        # Render colour and depth for the front camera and the side cameras.
        rendered_image_front,o_front = renderer_implicit_(cameras=camera_front, decoder=decoder,c=None )

        rendered_images_side,o_side = renderer_implicit_(cameras=other_cams, decoder=decoder,c=None )

        # Side-view targets come from the generator applied to `latent`.
        images_side = G(latent)

        # Image losses compare renders with min-max normalised targets
        # (permuted from channel-first to channel-last layout).
        loss_side = torch.mean((rendered_images_side - transform(images_side.permute(0,2,3,1) ) ) ** 2 )
        loss_front = torch.mean((rendered_image_front - transform(image.permute(0,2,3,1) ) ) ** 2 )
        # Depth prior: pull all rendered depths towards 2.7.
        loss_depth = torch.mean((o_front- 2.7)**2) + torch.mean((o_side- 2.7)**2)
       # loss_front = torch.mean((rendered_image_front - transform(image.permute(0,2,3,1) ) ) ** 2 )

        # The depth term is weighted 100x against the two image terms.
        loss_complete = loss_front  + loss_side + loss_depth * 100
        loss += float(loss_complete)
        loss_complete.backward()
        opt.step()
  # Per-epoch report, based on the values of the last inner iteration.
  print(float(loss_front), float(loss_side), loss_depth)
  print(rendered_image_front.shape )
  print(o_front.mean())
  print(o_side.mean())
  # Figure 1: front render, front target, first two side renders.
  f, axarr = plt.subplots(ncols=4)
  axarr[0].imshow(rendered_image_front.cpu().detach().numpy()[0])
  axarr[1].imshow(transform(image.permute(0,2,3,1)).cpu().detach().numpy()[0] ) 
  axarr[2].imshow(rendered_images_side.cpu().detach().numpy()[0] ) 
  axarr[3].imshow(rendered_images_side.cpu().detach().numpy()[1] ) 

  plt.show()
  # Figure 2: depth maps of the front view and the first two side views
  # (the second panel is left empty).
  f, axarr = plt.subplots(ncols=4)
  axarr[0].imshow(o_front.cpu().detach().numpy()[0][:,:,0])
 # axarr[1].imshow(transform(image.permute(0,2,3,1)).cpu().detach().numpy()[0] ) 
  axarr[2].imshow(o_side.cpu().detach().numpy()[0][:,:,0] ) 
  axarr[3].imshow(o_side.cpu().detach().numpy()[1][:,:,0] ) 

  plt.show()
        
   
