### Init and define functions

In [1]:
import torch
from utils import *
from collections import defaultdict
import matplotlib.pyplot as plt
import time
from models.rendering import *
from models.nerf import *
import metrics
from datasets import dataset_dict
from datasets.llff import *
from torch.utils.data import DataLoader
from functools import partial
from datasets.srn_multi_ae import collate_lambda_train, collate_lambda_val
torch.backends.cudnn.benchmark = True
torch.manual_seed(0)
# torch.set_printoptions(edgeitems=20)

import plotly
import plotly.graph_objects as go

# print(obj_samples)
def contract(x, order):
    mag = LA.norm(x, order, dim=-1)[..., None]
    return torch.where(mag < 1, x, (2 - (1 / mag)) * (x / mag))

def contract_pts(pts, radius):
    mask = torch.norm(pts, dim=-1).unsqueeze(-1) > radius
    new_pts = pts.clone()/radius
    norm_pts = torch.norm(new_pts, dim=-1).unsqueeze(-1)
    contracted_points = ((1+0.2) - 0.2/(norm_pts))*(new_pts/norm_pts)*radius
    warped_points = mask*contracted_points + (~mask)*pts
    return warped_points

def cast_rays(t_vals, origins, directions):
    return origins[..., None, :] + t_vals[..., None] * directions[..., None, :]

def convert_pose(C2W):
    flip_yz = np.eye(4)
    flip_yz[1, 1] = -1
    flip_yz[2, 2] = -1
    C2W = np.matmul(C2W, flip_yz)
    return C2W

def sample_along_rays(
    rays_o,
    rays_d,
    num_samples,
    near,
    far,
    randomized,
    lindisp,
    in_sphere,
    far_uncontracted = 3.0
):
    bsz = rays_o.shape[0]
    print("bsz", bsz)
    t_vals = torch.linspace(0.0, 1.0, num_samples + 1, device=rays_o.device)

    if in_sphere:
        if lindisp:
            t_vals = 1.0 / (1.0 / near * (1.0 - t_vals) + 1.0 / far * t_vals)
        else:
            t_vals = near * (1.0 - t_vals) + far * t_vals
            
    else:
        t_vals = torch.broadcast_to(t_vals, (bsz, num_samples + 1))

    if randomized:
        mids = 0.5 * (t_vals[..., 1:] + t_vals[..., :-1])
        upper = torch.cat([mids, t_vals[..., -1:]], -1)
        lower = torch.cat([t_vals[..., :1], mids], -1)
        t_rand = torch.rand((bsz, num_samples + 1), device=rays_o.device)
        t_vals = lower + (upper - lower) * t_rand
    else:
        t_vals = torch.broadcast_to(t_vals, (bsz, num_samples + 1))

    if in_sphere:
        coords = cast_rays(t_vals, rays_o, rays_d)
        return t_vals, coords

    else:
        
        t_vals_linear = far * (1.0 - t_vals) + far_uncontracted * t_vals
        t_vals = torch.flip(
            t_vals,
            dims=[
                -1,
            ],
        )  # 1.0 -> 0.0
        
        t_vals_linear = torch.flip(
            t_vals_linear,
            dims=[
                -1,
            ],
        )  # 3.0 -> sphere 
        coords = depth2pts_outside(rays_o, rays_d, t_vals)
        coords_linear = cast_rays(t_vals_linear, rays_o, rays_d)
        return t_vals, coords, coords_linear


def depth2pts_outside(rays_o, rays_d, depth):
    """Compute the points along the ray that are outside of the unit sphere.
    Args:
        rays_o: [num_rays, 3]. Ray origins of the points.
        rays_d: [num_rays, 3]. Ray directions of the points.
        depth: [num_rays, num_samples along ray]. Inverse of distance to sphere origin.
    Returns:
        pts: [num_rays, 4]. Points outside of the unit sphere. (x', y', z', 1/r)
    """
    # note: d1 becomes negative if this mid point is behind camera
    rays_o = rays_o[..., None, :].expand(
        list(depth.shape) + [3]
    )  #  [N_rays, num_samples, 3]
    rays_d = rays_d[..., None, :].expand(
        list(depth.shape) + [3]
    )  #  [N_rays, num_samples, 3]
    d1 = -torch.sum(rays_d * rays_o, dim=-1, keepdim=True) / torch.sum(
        rays_d**2, dim=-1, keepdim=True
    )

    p_mid = rays_o + d1 * rays_d
    p_mid_norm = torch.norm(p_mid, dim=-1, keepdim=True)
    rays_d_cos = 1.0 / torch.norm(rays_d, dim=-1, keepdim=True)

    check_pos = 1.0 - p_mid_norm * p_mid_norm
    assert torch.all(check_pos >= 0), "1.0 - p_mid_norm * p_mid_norm should be greater than 0"

    d2 = torch.sqrt(1.0 - p_mid_norm * p_mid_norm) * rays_d_cos
    p_sphere = rays_o + (d1 + d2) * rays_d

    rot_axis = torch.cross(rays_o, p_sphere, dim=-1)
    rot_axis = rot_axis / torch.norm(rot_axis, dim=-1, keepdim=True)
    phi = torch.asin(p_mid_norm)
    theta = torch.asin(p_mid_norm * depth[..., None])  # depth is inside [0, 1]
    rot_angle = phi - theta  # [..., 1]

    # now rotate p_sphere
    # Rodrigues formula: https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
    p_sphere_new = (
        p_sphere * torch.cos(rot_angle)
        + torch.cross(rot_axis, p_sphere, dim=-1) * torch.sin(rot_angle)
        + rot_axis
        * torch.sum(rot_axis * p_sphere, dim=-1, keepdim=True)
        * (1.0 - torch.cos(rot_angle))
    )
    p_sphere_new = p_sphere_new / (
        torch.norm(p_sphere_new, dim=-1, keepdim=True) + 1e-10
    )
    pts = torch.cat((p_sphere_new, depth.unsqueeze(-1)), dim=-1)

    return pts

def get_image_coords(pixel_offset, image_height, image_width,
):
    """This gets the image coordinates of one of the cameras in this object.
    If no index is specified, it will return the maximum possible sized height / width image coordinate map,
    by looking at the maximum height and width of all the cameras in this object.
    Args:
        pixel_offset: Offset for each pixel. Defaults to center of pixel (0.5)
        index: Tuple of indices into the batch dimensions of the camera. Defaults to None, which returns the 0th
            flattened camera
    Returns:
        Grid of image coordinates.
    """
    image_coords = torch.meshgrid(torch.arange(image_height), torch.arange(image_width), indexing="ij")
    image_coords = torch.stack(image_coords, dim=-1) + pixel_offset  # stored as (y, x) coordinates
    image_coords = torch.cat([image_coords, torch.ones((*image_coords.shape[:-1], 1))], dim=-1)
    image_coords = image_coords.view(-1, 3)
    return image_coords

def get_sphere(
    radius, center = None, color: str = "black", opacity: float = 1.0, resolution: int = 32
) -> go.Mesh3d:  # type: ignore
    """Returns a sphere object for plotting with plotly.
    Args:
        radius: radius of sphere.
        center: center of sphere. Defaults to origin.
        color: color of sphere. Defaults to "black".
        opacity: opacity of sphere. Defaults to 1.0.
        resolution: resolution of sphere. Defaults to 32.
    Returns:
        sphere object.
    """
    phi = torch.linspace(0, 2 * torch.pi, resolution)
    theta = torch.linspace(-torch.pi / 2, torch.pi / 2, resolution)
    phi, theta = torch.meshgrid(phi, theta, indexing="ij")

    x = torch.cos(theta) * torch.sin(phi)
    y = torch.cos(theta) * torch.cos(phi)
    z = torch.sin(theta)
    pts = torch.stack((x, y, z), dim=-1)

    pts *= radius
    if center is not None:
        pts += center

    return go.Mesh3d(
        {
            "x": pts[:, :, 0].flatten(),
            "y": pts[:, :, 1].flatten(),
            "z": pts[:, :, 2].flatten(),
            "alphahull": 0,
            "opacity": opacity,
            "color": color,
        }
    )
    
def vis_camera_rays(origins, directions, coords) -> go.Figure:  # type: ignore
    """Visualize camera rays.
    Args:
        camera: Camera to visualize.
    Returns:
        Plotly lines
    """
    lines = torch.empty((origins.shape[0] * 2, 3))
    lines[0::2] = origins
    lines[1::2] = origins + directions*3.0
    
    print("lines", lines.shape)

    colors = torch.empty((coords.shape[0] * 2, 3))
    colors[0::2] = coords
    colors[1::2] = coords

    data = []
    data.append(go.Scatter3d(
    x=lines[:, 0],
    y=lines[:, 2],
    z=lines[:, 1],
    marker=dict(
        size=4,
        color=colors)))
        
    data.append(get_sphere(radius=1.0, color="#111111", opacity=0.05))
#     data.append(get_sphere(radius=2.0, color="#111111", opacity=0.05))
    fig = go.Figure(data = data
        
    )
    fig.update_layout(
        scene=dict(
            xaxis=dict(title="x", showspikes=False),
            yaxis=dict(title="z", showspikes=False),
            zaxis=dict(title="y", showspikes=False),
        ),
        margin=dict(r=0, b=10, l=0, t=10),
        hovermode=False,
    )

    return fig

def vis_camera_samples(all_samples) -> go.Figure:  # type: ignore
    """Visualize camera rays.
    Args:
        camera: Camera to visualize.
    Returns:
        Plotly lines
    """
#     samples = samples.view(-1,3)

    data = []
    
    for i in range(all_samples.shape[0]):
        samples = all_samples[i]
        samples_init = samples[:10, :]
        samples_mid = samples[10:50, :]
        samples_final = samples[50:, :]

        data.append(go.Scatter3d(
        x=samples_init[:, 0],
        y=samples_init[:, 2],
        z=samples_init[:, 1],
        mode="markers",
        marker=dict(size=2, color="blue")
        ))

        data.append(go.Scatter3d(
        x=samples_mid[:, 0],
        y=samples_mid[:, 2],
        z=samples_mid[:, 1],
        mode="markers",
        marker=dict(size=2, color="black")
        ))

        data.append(go.Scatter3d(
        x=samples_final[:, 0],
        y=samples_final[:, 2],
        z=samples_final[:, 1],
        mode="markers",
        marker=dict(size=2, color="green")
        ))
        
    data.append(get_sphere(radius=1.0, color="#111111", opacity=0.05))
    data.append(get_sphere(radius=2.0, color="#111111", opacity=0.05))
    fig = go.Figure(data = data
        
    )
    fig.update_layout(
        scene=dict(
            xaxis=dict(title="x", showspikes=False),
            yaxis=dict(title="z", showspikes=False),
            zaxis=dict(title="y", showspikes=False),
        ),
        margin=dict(r=0, b=10, l=0, t=10),
        hovermode=False,
    )

    return fig


# def world2camera_viewdirs(w_viewdirs, cam2world, NS):
#     w_viewdirs = repeat_interleave(w_viewdirs, NS)  # (SB*NS, B, 3)
#     rot = torch.copy(cam2world[:, :3, :3]).transpose(1, 2)  # (B, 3, 3)
#     viewdirs = torch.matmul(rot[:, None, :3, :3], w_viewdirs.unsqueeze(-1))[..., 0]
#     return viewdirs

def world2camera_viewdirs(w_viewdirs, cam2world, NS):
    w_viewdirs = repeat_interleave(w_viewdirs, NS)  # (SB*NS, B, 3)
    rot = cam2world[:, :3, :3].transpose(1, 2)  # (B, 3, 3)
    viewdirs = torch.matmul(rot[:, None, :3, :3], w_viewdirs.unsqueeze(-1))[..., 0]
    return viewdirs


def world2camera(w_xyz, cam2world, NS=None):
    """Converts the points in world coordinates to camera view.
    :param xyz: points in world coordinates (SB*NV, NC, 3)
    :param poses: camera matrix (SB*NV, 4, 4)
    :output points in camera coordinates (SB*NV, NC, 3)
    : SB batch size
    : NV number of views in each scene
    : NC number of coordinate points
    """
    #print(w_xyz.shape, cam2world.shape)
    if NS is not None:
        w_xyz = repeat_interleave(w_xyz, NS)  # (SB*NS, B, 3)
    rot = cam2world[:, :3, :3].transpose(1, 2)  # (B, 3, 3)
    trans = -torch.bmm(rot, cam2world[:, :3, 3:])  # (B, 3, 1)
    #print(rot.shape, w_xyz.shape)
    cam_rot = torch.matmul(rot[:, None, :3, :3], w_xyz.unsqueeze(-1))[..., 0]
    cam_xyz = cam_rot + trans[:, None, :, 0]
    # cam_xyz = cam_xyz.reshape(-1, 3)  # (SB*B, 3)
    return cam_xyz

def repeat_interleave(input, repeats, dim=0):
    """
    Repeat interleave along axis 0
    torch.repeat_interleave is currently very slow
    https://github.com/pytorch/pytorch/issues/31980
    """
    output = input.unsqueeze(1).expand(-1, repeats, *input.shape[1:])
    return output.reshape(-1, *input.shape[1:])


def projection(c_xyz, focal, c):
    """Converts [x,y,z] in camera coordinates to image coordinates 
        for the given focal length focal and image center c.
    :param c_xyz: points in camera coordinates (SB*NV, NP, 3)
    :param focal: focal length (SB, 2)
    :c: image center (SB, 2)
    :output uv: pixel coordinates (SB, NV, NP, 2)
    """
    uv = -c_xyz[..., :2] / (c_xyz[..., 2:] + 1e-9)  # (SB*NV, NC, 2); NC: number of grid cells 
    uv *= repeat_interleave(
                focal.unsqueeze(1), NV if focal.shape[0] > 1 else 1
            )
    uv += repeat_interleave(
                c.unsqueeze(1), NV if c.shape[0] > 1 else 1
            )
    return uv


def pos_enc(x, min_deg=0, max_deg=10):
    scales = torch.tensor([2**i for i in range(min_deg, max_deg)]).type_as(x)
    xb = torch.reshape((x[..., None, :] * scales[:, None]), list(x.shape[:-1]) + [-1])
    four_feat = torch.sin(torch.cat([xb, xb + 0.5 * np.pi], dim=-1))
    return torch.cat([x] + [four_feat], dim=-1)

import open3d as o3d
def get_world_grid(side_lengths, grid_size):
    """ Returns a 3D grid of points in world coordinates.
    :param side_lengths: (min, max) for each axis (3, 2)
    :param grid_size: number of points along each dimension () or (3)
    :output grid: (1, grid_size**3, 3)
    """
    if len(grid_size) == 1:
        grid_size = [grid_size[0] for _ in range(3)]
        
    w_x = torch.linspace(side_lengths[0][0], side_lengths[0][1], grid_size[0])
    w_y = torch.linspace(side_lengths[1][0], side_lengths[1][1], grid_size[1])
    w_z = torch.linspace(side_lengths[2][0], side_lengths[2][1], grid_size[2])
    # Z, Y, X = torch.meshgrid(w_x, w_y, w_z)
    X, Y, Z = torch.meshgrid(w_x, w_y, w_z)
    w_xyz = torch.stack([X, Y, Z], axis=-1) # (gs, gs, gs, 3), gs = grid_size
    print(w_xyz.shape)
    w_xyz = w_xyz.reshape(-1, 3).unsqueeze(0) # (1, grid_size**3, 3)
    return w_xyz

def repeat_interleave(input, repeats, dim=0):
    """
    Repeat interleave along axis 0
    torch.repeat_interleave is currently very slow
    https://github.com/pytorch/pytorch/issues/31980
    """
    output = input.unsqueeze(1).expand(-1, repeats, *input.shape[1:])
    return output.reshape(-1, *input.shape[1:])


def intersect_sphere(rays_o, rays_d):
    """Compute the depth of the intersection point between this ray and unit sphere.
    Args:
        rays_o: [num_rays, 3]. Ray origins.
        rays_d: [num_rays, 3]. Ray directions.
    Returns:
        depth: [num_rays, 1]. Depth of the intersection point.
    """
    # note: d1 becomes negative if this mid point is behind camera

    d1 = -torch.sum(rays_d * rays_o, dim=-1, keepdim=True) / torch.sum(
        rays_d**2, dim=-1, keepdim=True
    )
    p = rays_o + d1 * rays_d
    # consider the case where the ray does not intersect the sphere
    rays_d_cos = 1.0 / torch.norm(rays_d, dim=-1, keepdim=True)
    p_norm_sq = torch.sum(p * p, dim=-1, keepdim=True)
    check_pos = 1.0 - p_norm_sq
    print("check pos", torch.max(p_norm_sq), torch.min(p_norm_sq))
    assert torch.all(check_pos >= 0), "1.0 - p_norm_sq should be greater than 0"
    d2 = torch.sqrt(1.0 - p_norm_sq) * rays_d_cos
    return d1 + d2

def w2i_projection(w_xyz, cam2world, intrinsics):
    """Converts the points in world coordinates to camera view.
    :param xyz: points in world coordinates (SB*NV, NC, 3)
    :param poses: camera matrix (SB*NV, 4, 4)
    :output points in camera coordinates (SB*NV, NC, 3)
    : SB batch size
    : NV number of views in each scene
    : NC number of coordinate points
    """
    w_xyz = torch.cat([w_xyz, torch.ones_like(w_xyz[..., :1])], dim=-1)  # [n_points, 4]
    cam_xyz = torch.inverse(cam2world).bmm(w_xyz.permute(0,2,1))
    camera_grids = cam_xyz.permute(0,2,1)[:,:,:3]
    projections = intrinsics[None, ...].repeat(cam2world.shape[0], 1, 1).bmm(cam_xyz[:,:3,:])
    projections = projections.permute(0,2,1)
    
    uv = projections[..., :2] / torch.clamp(projections[..., 2:3], min=1e-8)  # [n_views, n_points, 2]
    uv = torch.clamp(uv, min=-1e6, max=1e6)
    #uv = projections[..., :2] / projections[..., 2:3]  # [n_views, n_points, 2]
    mask = projections[..., 2] > 0
    return camera_grids, uv, mask

def projection_extrinsics_alldim(w_xyz, w2c, intrinsics):
    """Converts the points in world coordinates to camera view.
    :param xyz: points in world coordinates (SB*NV, NC, 3)
    :param poses: camera matrix (SB*NV, 4, 4)
    :output points in camera coordinates (SB*NV, NC, 3)
    : SB batch size
    : NV number of views in each scene
    : NC number of coordinate points
    """
    w_xyz = torch.cat([w_xyz, torch.ones_like(w_xyz[..., :1])], dim=-1)  # [n_points, 4]
    cam_xyz = w2c.bmm(w_xyz.permute(0,2,1))
    camera_grids = cam_xyz.permute(0,2,1)[:,:,:3]
    projections = intrinsics[None, ...].repeat(w2c.shape[0], 1, 1).bmm(cam_xyz[:,:3,:])
    projections = projections.permute(0,2,1)
    return projections

def projection_extrinsics(w_xyz, w2c, intrinsics):
    """Converts the points in world coordinates to camera view.
    :param xyz: points in world coordinates (SB*NV, NC, 3)
    :param poses: camera matrix (SB*NV, 4, 4)
    :output points in camera coordinates (SB*NV, NC, 3)
    : SB batch size
    : NV number of views in each scene
    : NC number of coordinate points
    """
    w_xyz = torch.cat([w_xyz, torch.ones_like(w_xyz[..., :1])], dim=-1)  # [n_points, 4]
    cam_xyz = w2c.bmm(w_xyz.permute(0,2,1))
    camera_grids = cam_xyz.permute(0,2,1)[:,:,:3]
    projections = intrinsics[None, ...].repeat(w2c.shape[0], 1, 1).bmm(cam_xyz[:,:3,:])
    projections = projections.permute(0,2,1)
    
    uv = projections[..., :2] / torch.clamp(projections[..., 2:3], min=1e-8)  # [n_views, n_points, 2]
    uv = torch.clamp(uv, min=-1e6, max=1e6)
    #uv = projections[..., :2] / projections[..., 2:3]  # [n_views, n_points, 2]
    mask = projections[..., 2] > 0
    return camera_grids, uv, mask

# def compute_projections(xyz, train_poses, train_intrinsics, NV):
#     '''
#     project 3D points into cameras
#     :param xyz: [..., 3]
#     :param train_cameras: [n_views, 34], 34 = img_size(2) + intrinsics(16) + extrinsics(16)
#     :return: pixel locations [..., 2], mask [...]
#     '''
#     original_shape = xyz.shape[:2]
#     xyz = xyz.reshape(-1, 3)
#     num_views = NV
#     xyz_h = torch.cat([xyz, torch.ones_like(xyz[..., :1])], dim=-1)  # [n_points, 4]
#     print("train_intrinsics[None, ...].repeat(train_poses.shape[0], 1, 1)", train_intrinsics[None, ...].repeat(train_poses.shape[0], 1, 1).shape)
#     projections = train_intrinsics[None, ...].repeat(train_poses.shape[0], 1, 1).bmm(torch.inverse(train_poses)) \
#         .bmm(xyz_h.t()[None, ...].repeat(num_views, 1, 1))  # [n_views, 4, n_points]
#     projections = projections.permute(0, 2, 1)  # [n_views, n_points, 4]
#     pixel_locations = projections[..., :2] / torch.clamp(projections[..., 2:3], min=1e-8)  # [n_views, n_points, 2]
#     pixel_locations = torch.clamp(pixel_locations, min=-1e6, max=1e6)
#     mask = projections[..., 2] > 0   # a point is invalid if behind the camera
#     return pixel_locations.reshape((num_views, ) + original_shape + (2, )), \
#             mask.reshape((num_views, ) + original_shape)

# def normalize(pixel_locations, h, w):
#     resize_factor = torch.tensor([w-1., h-1.]).to(pixel_locations.device)[None, None, :]
#     normalized_pixel_locations = 2 * pixel_locations / resize_factor - 1.  # [n_views, n_points, 2]
#     return normalized_pixel_locations
    
# def compute(xyz, poses, imgs, K, NV, H, W):

#     pixel_locations, mask_in_front = compute_projections(xyz, poses, K, NV)
#     normalized_pixel_locations = normalize(pixel_locations, H, W)   # [n_views, n_rays, n_samples, 2]

#     # rgb sampling
#     rgbs_sampled = F.grid_sample(imgs, normalized_pixel_locations, align_corners=True)
#     rgb_sampled = rgbs_sampled.permute(2, 3, 0, 1)  # [n_rays, n_samples, n_views, 3]
#     inbound = inbound(pixel_locations, H, W)

#     return rgb_sampled, inbound

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


  from collections import Mapping
  from collections import Mapping, Set, Iterable
  from collections import Mapping, Set, Iterable
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  (np.int, "int"), (np.int8, "int"),
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  _ft = np.array([False, True], dtype=np.bool)


In [2]:
far = 3.0
t_vals = torch.linspace(0.0, 1.0, 10 + 1)




mids = 0.5 * (t_vals[..., 1:] + t_vals[..., :-1])
upper = torch.cat([mids, t_vals[..., -1:]], -1)
lower = torch.cat([t_vals[..., :1], mids], -1)
t_rand = torch.rand((1, 10 + 1), device=rays_o.device)
t_vals = lower + (upper - lower) * t_rand

t_vals_linear = 1.0 * (1.0 - t_vals) + 3.0 * t_vals

print("t_vals", t_vals)
print("t_vals_linear", t_vals_linear)

NameError: name 'rays_o' is not defined

In [None]:
import random
import numpy as np
src_views_num = np.random.choice(100, 3, replace=False)
dest_view_num = random.choice(src_views_num)

print("src_views_num, dest_view_num", src_views_num, dest_view_num)



### Get data

In [6]:
# dataset = dataset_dict['pd_multi_obj_ae_nocs']
# dataset = dataset_dict['pd_multi_obj']
dataset = dataset_dict['pd_multi']
# root_dir = '/home/zubairirshad/pd-api-py/PD_v3_eval/test_novelobj'
# root_dir = '/home/zubairirshad/pd-api-py/PD_v3_eval/test_novelobj/SF_6thAndMission_medium0'
root_dir = '/home/zubairirshad/pd-api-py/PDStepv3/train'
img_wh = (320, 240)

kwargs = {'root_dir': root_dir,
          'img_wh': tuple(img_wh),
         'split': 'train',
#          'eval_inference':'sapien',
         'white_back': True}

# kwargs = {'root_dir': root_dir,
#           'img_wh': tuple(img_wh),
#          'model_type': 'nerfpp',
#          'split': 'test',
#          'eval_inference': '3view_testnovelobj'}

train_dataset = dataset(**kwargs)
dataloader =  DataLoader(train_dataset,
                  shuffle=False,
                  num_workers=0,
                  batch_size=1,
                  pin_memory=False)



print("len train dataset", len(train_dataset))
for i, data in enumerate(dataloader):
    if i>0:
        break
    for k,v in data.items():
        if k =='deg':
            continue
        print(k,v.shape)


    print("===============================\n\n\n")
    
# for k,v in data.items():
#     if k =='deg':
#         continue
#     data[k] = v.squeeze(0)

len train dataset 200
rays torch.Size([1, 4096, 3])
rgbs torch.Size([1, 4096, 3])
instance_mask torch.Size([1, 4096])
instance_ids torch.Size([1, 4096])





In [8]:
# w,h = data["img_wh"]
# print("deg", np.rad2deg(data['deg']))
# plt.imshow(data["src_imgs"].permute(1,2,0).numpy())
# plt.show()
print("data[instance ids]", data["instance_ids"].shape)
print("len train_dataset ids", len(train_dataset.ids))
print("data[instance_ids]", data["instance_ids"])
plt.imshow(data["rgbs"].reshape(240,320,3).numpy())
plt.show()
# for i, data in enumerate(dataloader):
#     print("i", i)
#     plt.imshow(data["target"].squeeze(0).reshape(240,320, 3).numpy())
#     plt.show()
# plt.show()
# a = (data["instance_mask"]>0).sum()
# print("data[instance mask]", a)

# print("near obj", near_obj[0].shape)

# image = data["rgbs"].reshape(240,320,3).nuprint("len train dataset", len(train_dataset))
# for i, data in enumerate(dataloader):
#     print("i", i)
#     for k,v in data.items():
#         print(k,v.squeeze(0).shape)
#     if i>0:
#         break

#     print("===============================\n\n\n")mpy()

# plt.imshow(image)
# plt.show()
# print("data[rays]")

data[instance ids] torch.Size([1, 4096])
len train_dataset ids 150
data[instance_ids] tensor([[107, 107, 107,  ..., 107, 107, 107]])


RuntimeError: shape '[240, 320, 3]' is invalid for input of size 12288

In [None]:
import torch.nn.functional as F
def unprocess_images(normalized_images, shape = ()):
    inverse_transform = T.Compose([T.Normalize((-0.5/0.5, -0.5/0.5, -0.5/0.5), (1/0.5, 1/0.5, 1/0.5))])
    return inverse_transform(normalized_images)
NV=3

def convert_pose(C2W):
    flip_yz = torch.eye(4)
    flip_yz[1, 1] = -1
    flip_yz[2, 2] = -1
    C2W = torch.matmul(C2W, flip_yz)
    return C2W

for k,v in data.items():
    data[k] = v.squeeze(0)
    
new_src_imgs = unprocess_images(data["src_imgs"])
for i in range(NV):
    if i ==0 or i==1:
        plt.imshow(new_src_imgs[i].permute(1,2,0).numpy())
        plt.show()
    
img_1 = new_src_imgs[0].permute(1,2,0).numpy()
img_2 = new_src_imgs[1].permute(1,2,0).numpy()




# pose_1 = np.linalg.inv(convert_pose(data["src_poses"][0]).numpy())
# pose_2 = np.linalg.inv(convert_pose(data["src_poses"][1]).numpy())

pose_1 = data["w2cs"][0].numpy()
pose_2 = data["w2cs"][1].numpy()

from models.nerfplusplus.util import verify_data


K = torch.FloatTensor([
    [data["src_focal"][0], 0., data["src_c"][0][0]],
    [0., data["src_focal"][0], data["src_c"][0][1]],
    [0., 0., 1.],
])
im = verify_data(np.uint8(img_1*255.), np.uint8(img_2*255.),
                 K, pose_1,
                 K, pose_2)
im = im/255
plt.imshow(im)
plt.show()

### Get Samples and visualize

In [None]:
# import models.nerfplusplus.helper as helper


# for k,v in data.items():
#     data[k] = v.squeeze(0)
    
nerfplusplus = True
contract = False

if nerfplusplus:
    import models.nerfplusplus.helper as helper
    from torch import linalg as LA
    near = torch.full_like(data["rays_o"][..., -1:], 1e-4)
    far = intersect_sphere(data["rays_o"], data["rays_d"])
    
else:
    import models.vanilla_nerf.helper as helper
    from torch import linalg as LA
    near = 0.2
    far = 3.0

if nerfplusplus:
    obj_t_vals, obj_samples = sample_along_rays(
        rays_o=data["rays_o"],
        rays_d=data["rays_d"],
        num_samples=65,
        near = near,
        far = far,
        randomized=True,
        lindisp=False,
        in_sphere=True,
    )

    bg_t_vals, bg_samples, bg_samples_linear = sample_along_rays(
        rays_o=data["rays_o"],
        rays_d=data["rays_d"],
        num_samples=65,
        near=near,
        far=far,
        randomized=True,
        lindisp=False,
        in_sphere=False,
    )
    print("toch max min bg_samples linear", torch.max(bg_samples_linear), torch.min(bg_samples_linear))
else:
    all_t_vals, all_samples = helper.sample_along_rays(
        rays_o=data["rays_o"],
        rays_d=data["rays_d"],
        num_samples=65,
        near=near,
        far=far,
        randomized=True,
        lindisp=False,
    )
    

def reverse_contract_pts(pts, radius):
    norm_pts = torch.norm(pts, dim=-1).unsqueeze(-1)
    contracted_points = ((1 + 0.2) - 0.2 / norm_pts) * (pts / norm_pts) * radius
    return contracted_points

def contract_pts(pts, radius=3):
    mask = torch.norm(pts, dim=-1).unsqueeze(-1) > radius
    new_pts = pts.clone()/radius
    norm_pts = torch.norm(new_pts, dim=-1).unsqueeze(-1)
    contracted_points = ((1+0.5) - 0.5/(norm_pts))*(new_pts/norm_pts)*radius
    warped_points = torch.where(
        mask == False, pts, mask*contracted_points 
    )
    
    return warped_points

def uncontract_pts(pts, radius=1):
    mask = torch.norm(pts, dim=-1).unsqueeze(-1) > radius
    new_pts = pts.clone()/radius
    norm_pts = torch.norm(new_pts, dim=-1).unsqueeze(-1)
    contracted_points = ((1+0.6) - 0.6/(norm_pts))*(new_pts/norm_pts)*radius
    warped_points = torch.where(
        mask == False, pts, mask*contracted_points 
    )
    
    return warped_points

def contract_samples(x, order=2):
    mag = LA.norm(x, order, dim=-1)[..., None]
    return torch.where(mag < 1, x, (2 - (1 / mag)) * (x / mag))

# def inverse_contract_samples(x, order = order):
#     mag = torch.linalg.norm(x, ord=order, dim=-1)[..., None]
#     mask = mag < 1
#     expanded = x * mag*mag / (2 - (1 / mag))
#     return torch.where(mask, x, expanded)

def _contract(x):
    x_mag_sq = torch.sum(x**2, dim=-1, keepdim=True).clip(min=1e-32)
    z = torch.where(
        x_mag_sq <= 1, x, ((2 * torch.sqrt(x_mag_sq) - 1) / x_mag_sq) * x
    )
    return z

def _inverse_contract(x):
    x_mag_sq = torch.sum(x**2, dim=-1, keepdim=True).clip(min=1e-32)
    z = torch.where(
        x_mag_sq <= 1, x, x * (x_mag_sq / (2 * torch.sqrt(x_mag_sq) - 1))
    )
    return z



if nerfplusplus:
#     all_samples = torch.cat((obj_samples, bg_samples[:,:,:3]), dim=0)
#     all_samples = torch.cat((obj_samples, bg_samples_linear), dim=0)

    all_samples = obj_samples
    print("all_samples", all_samples.shape)
else:
    if contract:
        all_samples = contract_samples(all_samples, float('inf'))
        print("torch min max  all samples", torch.min(samples), torch.max(samples))
#         all_samples = _inverse_contract(all_samples)


if nerfplusplus:
    print("bg obj samples",bg_samples.shape, obj_samples.shape)
else:
    print("all_samples",all_samples.shape)


coords = get_image_coords(pixel_offset = 0.5,image_height = 120, image_width = 160)
print(coords.shape)

rays_o = data["rays_o"][:20,:]
rays_d = data["rays_d"][:20,:]
coords = coords[:20,:]

if nerfplusplus:
    num = np.random.choice(all_samples.shape[0], 5)
    samples = all_samples[num,:, :3]
    
#     num_bg = np.random.choice(bg_samples.shape[0], 20)
    samples_bg = bg_samples[num,:, :3]
    
    print("inverse radius ", bg_samples[:,:,3])
    
#     num_bg_linear = np.random.choice(bg_samples.shape[0], 20)
    samples_bg_linear = bg_samples_linear[num,:, :3]
#     print(samples.shape, samples_bg.shape)
    samples = torch.cat((samples, samples_bg, samples_bg_linear), dim=0)
#     samples = samples_bg_linear
else:
    num = np.random.choice(all_samples.shape[0], 20)
    samples = all_samples[num,:, :3]


fig = vis_camera_samples(samples)
fig.show()

print("torch min max samples", torch.min(samples), torch.max(samples))


### IBRNet feature sampler

In [None]:
w_xyz = all_samples[:,:,:3]
w_xyz = w_xyz.reshape(-1,3).unsqueeze(0)
poses = torch.clone(data["src_poses"])
all_c2w = [convert_pose(pose).unsqueeze(0) for pose in poses]
poses = torch.cat(all_c2w, dim=0).to(dtype = torch.float)

new_images = unprocess_images(data["src_imgs"])

NV = 3

w_xyz = repeat_interleave(w_xyz,NV) # (SB*NV, NC, 3) NC: number of grid cells

K = torch.FloatTensor([
    [data["src_focal"][0], 0., data["src_c"][0][0]],
    [0., data["src_focal"][0], data["src_c"][0][1]],
    [0., 0., 1.],
])

height, width = new_images.size()[2:]
print("poses", poses.shape, w_xyz.shape)

proj_mat = data["proj_mats"][0].unsqueeze(0)

w_xyz = all_samples[:,:,:3].reshape(-1,3).unsqueeze(0)
print("w_xyz", w_xyz.shape)

world_xyz = repeat_interleave(w_xyz, NV)  # (SB*NS, B, 3)
extrinsics = data["w2cs"]
print("extrinsics", extrinsics.shape, world_xyz.shape)
intrinsics = K
uv =  projection_extrinsics_alldim(world_xyz, extrinsics, intrinsics)

# projections  = torch.bmm(proj_mat, w_xyz.permute(0,2,1)).permute(0,2,1)

# print("projections", projections.shape)
# uv = projections[..., :2] / torch.clamp(projections[..., 2:3], min=1e-8)  # [n_views, n_points, 2]
# uv = torch.clamp(uv, min=-1e6, max=1e6)
# print("uv",uv.shape)

im_x = uv[:,:, 0]
im_y = uv[:,:, 1]
pixel_locations = torch.stack([im_x,im_y], dim=-1)
im_grid = torch.stack([2 * im_x / (width - 1) - 1, 2 * im_y / (height - 1) - 1], dim=-1)
print("im_grid", im_grid.shape)


    
imgs = unprocess_images(data["src_imgs"])
imgs = F.interpolate(imgs, size=(120,160), mode='bilinear', align_corners=False)

K = K/2
K[-1,1] = 1
plt.imshow(imgs[0].permute(1,2,0).numpy())
plt.show()

grid = im_grid
grid = grid.unsqueeze(2)

data_im = F.grid_sample(imgs, grid, align_corners=True, mode='bilinear', padding_mode='zeros')
print("data_im", data_im.shape)
all_imgs = data_im.squeeze(-1).permute(0, 2,1).reshape(NV, 240,320,66,3).numpy()
img_1 = all_imgs[2, :,:,65,:]
plt.imshow(img_1)
plt.show()
print("data_im", data_im.shape)

mask = (im_grid.abs() <= 1).float()
mask = (mask[...,0]*mask[...,1]).float()
print(mask.shape)



feats_c = data_im.squeeze(-1).permute(2,0,1)
print("feats_c", feats_c.shape)

# .view(-1,9)


print("feat_c", feats_c.shape)

In [None]:
def index_grid(samples, volume_features, w2cs, focal, c, near=0.2, far = 2.5):
    """
    Get pixel-aligned image features at 2D image coordinates
    :param uv (B, N, 2) image points (x,y)
    :param cam_z ignored (for compatibility)
    :param image_size image size, either (width, height) or single int.
    if not specified, assumes coords are in [-1, 1]
    :param z_bounds ignored (for compatibility)
    :return (B, L, N) L is latent size
    """ 

    w2c_ref = w2cs[0].unsqueeze(0)
    _,_,_, H, W = volume_features.shape
    inv_scale = torch.tensor([W-1, H-1]).to(w2c_ref.device)

    samples = samples.reshape(-1,3).unsqueeze(0)
    intrinsics_ref = torch.FloatTensor([
        [focal[0], 0., c[0][0]],
        [0., focal[0], c[0][1]],
        [0., 0., 1.],
        ]).to(w2c_ref.device)
    intrinsics_ref = intrinsics_ref/2
    intrinsics_ref[-1,-1] = 1

    if intrinsics_ref is not None:
        point_samples_pixel = projection_extrinsics_alldim(samples, w2c_ref, intrinsics_ref)
        point_samples_pixel = point_samples_pixel.squeeze(0)
        point_samples_pixel[:,:2] = (point_samples_pixel[:,:2] / point_samples_pixel[:,-1:] + 0.0) / inv_scale.reshape(1,2)  
        point_samples_pixel[:,2] = (point_samples_pixel[:,2] - near) / (far - near)  # normalize to 0~1
    
    grid = point_samples_pixel.view(1, 1, 1, -1, 3) * 2 - 1.0  # [1 1 H W 3] (x,y,z)
    features = F.grid_sample(volume_features, grid, align_corners=True, mode='bilinear')[:,:,0].permute(2,3,0,1).squeeze()#, padding_mode="border"

    return features

samples = torch.randn((2048,65,3))
volume_features = torch.randn((1, 8, 128, 120, 160))
imgs = torch.randn((3, 3, 240, 320))
w2cs = torch.randn((3,4,4))
focal = torch.randn((3))
c = torch.randn((3,2))
print("samples", samples.shape, volume_features.shape)
features = index_grid(samples, volume_features, w2cs, focal, c, 240, 320)

print("features", features.shape)


### Encode for PN

In [None]:
from models.vanilla_nerf.encoder import *
NV = 3
pn = True
def unprocess_images(normalized_images, shape = ()):
    inverse_transform = T.Compose([T.Normalize((-0.5/0.5, -0.5/0.5, -0.5/0.5), (1/0.5, 1/0.5, 1/0.5))])
    print("unnormalize dimgs", inverse_transform(normalized_images).shape)
    return inverse_transform(normalized_images)

w_xyz = all_samples[:,:,:3]
print("w_xyz", w_xyz.shape)

B, N_samples, _ = w_xyz.shape
w_xyz = w_xyz.reshape(-1,3).unsqueeze(0)

# print("wxyz, poses", w_xyz.shape, poses.shape)
poses = torch.clone(data["src_poses"])
cam_xyz = world2camera(w_xyz, poses, NV)

print("cam_xyz", cam_xyz.shape)

def inbound(pixel_locations, h, w):
    '''
    check if the pixel locations are in valid range
    :param pixel_locations: [..., 2]
    :param h: height
    :param w: weight
    :return: mask, bool, [...]
    '''
    return (pixel_locations[..., 0] <= w - 1.) & \
           (pixel_locations[..., 0] >= 0) & \
           (pixel_locations[..., 1] <= h - 1.) &\
           (pixel_locations[..., 1] >= 0)



encoder = SpatialEncoder(backbone="resnet34",
                                              pretrained=True,
                                              num_layers=4,
                                              index_interp="bilinear",
                                              index_padding="zeros",
                                              # index_padding="border",
                                              upsample_interp="bilinear",
                                              feature_scale=1.0,
                                              use_first_pool=True,
                                              norm_type="batch")

latent = encoder(data["src_imgs"])
print(latent.shape)
# height, width = latent.size()[2:]



# after encoder unnormalize images

for i in range(NV):
    plt.imshow(data["src_imgs"][i].permute(1,2,0).numpy())
    plt.show()
    
new_images = unprocess_images(data["src_imgs"])

# new_images = F.interpolate(new_images, size=(120,160), mode='bilinear', align_corners=False)

for i in range(NV):
    plt.imshow(new_images[i].permute(1,2,0).numpy())
    plt.show()

print("new_images", new_images.shape)
height, width = new_images.size()[2:]

print("height, width", height, width)

if pn:
    #PN projection
    focal = data["src_focal"][0].unsqueeze(-1).repeat((1, 2))
    focal[..., 1] *= -1.0
    c = data["src_c"][0].unsqueeze(0)
    uv_pn = projection(cam_xyz, focal, c)

    im_x = uv_pn[:,:, 0]
    im_y = uv_pn[:,:, 1]
    pixel_locations = torch.stack([im_x,im_y], dim=-1)
#     mask_inbound = inbound(pixel_locations, height, width)

    im_grid = torch.stack([2 * im_x / (width - 1) - 1, 2 * im_y / (height - 1) - 1], dim=-1)

    mask_z = cam_xyz[:,:,2]<1e-3

    mask = im_grid.abs() <= 1
    print("MASKKKK, MASKKKK_ZZZZZZZ", mask.shape, mask_z.shape)
    in_mask = (mask[...,0]*mask[...,1]).float()
    print("in_mask!!!!!!!!!!!!!!!!!!", in_mask.shape)
    print("mask", mask.shape)
    mask = (mask.sum(dim=-1) == 2) & (mask_z)
    #mask = mask_inbound & mask_z
    print("mask", mask.shape)
    

    print("mask", mask.shape)
    
    print("torch min max im_grid_pn", torch.min(im_grid), torch.max(im_grid))
    min_deg_point = 0
    max_deg_point = 10
    samples_cam = cam_xyz
    samples_enc = helper.pos_enc(
        samples_cam,
        min_deg_point,
        max_deg_point,
    )

    viewdirs = world2camera_viewdirs(data["viewdirs"].unsqueeze(0), poses, NV)
    viewdirs_enc = helper.pos_enc(viewdirs, 0, 4)
    a = torch.tile(viewdirs_enc[:, None, :], (1, N_samples, 1))
    viewdirs_enc = torch.tile(viewdirs_enc[:, None, :], (1, N_samples, 1)).reshape(
            NV, -1, viewdirs_enc.shape[-1]
        )
    print("viewdirs_enc", viewdirs_enc.shape)
    
else:
    focal = data["src_focal"]
    c = data["src_c"]

    print("data src focl, src c","data[src_focal]", data["src_focal"], data["src_c"])
    K = torch.FloatTensor([
        [data["src_focal"][0], 0., data["src_c"][0][0]],
        [0., data["src_focal"][0], data["src_c"][0][1]],
        [0., 0., 1.],
    ])
    poses = data["src_poses"]
    all_c2w = [convert_pose(pose).unsqueeze(0) for pose in poses]
    poses = torch.cat(all_c2w, dim=0).to(dtype = torch.float)

    world_xyz = repeat_interleave(w_xyz, NV)  # (SB*NS, B, 3)
    cam_xyz, uv_rt, mask= w2i_projection(world_xyz, poses, K)
    im_x = uv_rt[:,:, 0]
    im_y = uv_rt[:,:, 1]
    pixel_locations = torch.stack([im_x,im_y], dim=-1)
    mask_inbound = inbound(pixel_locations, height, width)
    
    print("mask_inbound", mask_inbound.shape)
    print("mask_inbound", ((mask_inbound[0]>0)==True).sum())
    
    im_grid = torch.stack([2 * im_x / (width - 1) - 1, 2 * im_y / (height - 1) - 1], dim=-1)
    mask_z = cam_xyz[:,:,2]>0
    mask = im_grid.abs() <= 1
    
    print("mask", mask.shape)
    mask = (mask.sum(dim=-1) == 2) & (mask_z) & mask_inbound

    
    print("torch min max im_grid", torch.min(im_grid), torch.max(im_grid))



In [None]:
a = torch.randn((3, 262144))

b = torch.randn((3, 262144, 2))

c = a&b

### Plot sampled projected pixels

In [None]:
print("data[src_imgs]", torch.min(new_images[0]), torch.max(new_images[0]))
print("data[src_imgs]", new_images.shape)
# for i in range(NV):
#     plt.imshow(new_images[i].permute(1,2,0).numpy())
#     plt.show()

imgs = new_images
grid = im_grid
grid = grid.unsqueeze(2)



# mask = grid.abs() <= 1
# za_inbound_mask = (mask ==True).sum()
# print("za_inbound_mask", za_inbound_mask, (za_inbound_mask/(grid.shape[0]*grid.shape[1]*grid.shape[-1]))*100)
#sampled_images = F.grid_sample(imgs, grid, align_corners=True, mode='bilinear', padding_mode="zeros")

V = NV
C=3
# colors = torch.empty((grid.shape[1], V*C), device=imgs.device, dtype=torch.float)
# print("colors", colors.shape)
colors = []
for i, idx in enumerate(range(imgs.shape[0])):
    print("imgs[idx, :, :, :].unsqueeze(0)", imgs[idx, :, :, :].unsqueeze(0).shape)
    print("grid[idx, :, :].unsqueeze(0)", grid[idx, :, :].unsqueeze(0).shape)
    
    print("grid", grid.shape, imgs.shape)
    data_im = F.grid_sample(imgs[idx, :, :, :].unsqueeze(0), grid[idx, :, :].unsqueeze(0), align_corners=True, mode='bilinear', padding_mode='zeros')
    print("mask, data_im", mask.shape, data_im.shape)
#     data_im[mask.unsqueeze(0).unsqueeze(-1)==False] = 0

    # Vis
    print("data_im[0].permute(1, 2, 0)", data_im[0].squeeze(-1).permute(1,0).shape)
    colors.append(data_im.squeeze(-1).permute(0,1,2))
    #colors[...,i*C:i*C+C] = data_im[0].squeeze(-1).permute(1,0)
    all_imgs = data_im.squeeze(-1).squeeze(0).permute(1,0).reshape(240,320,66,3).numpy()
    img_1 = all_imgs[:,:,65,:]
    plt.imshow(img_1)
    plt.show()
    
colors = torch.cat(colors, dim=0)

colors = colors.permute(0,2,1)
print("colors", colors.shape)

colors[mask.unsqueeze(-1).repeat(1,1,colors.shape[-1])==False] = 0

for i in range(colors.shape[0]):
    
    all_imgs = colors[i].reshape(240,320,66,3).numpy()
    img_1 = all_imgs[:,:,65,:]
    plt.imshow(img_1)
    plt.show()

print(colors.shape)

# print("samples image >0", ((sampled_images>0) ==True).sum())
# print("sampled_images", sampled_images.shape)

# for i in range(sampled_images.shape[0]):
#     if i!=1:
#         continue
#     all_imgs = sampled_images[i].squeeze(-1).permute(1,0).reshape(240,320,65, 3).numpy()
#     img_1 = all_imgs[:,:,50,:]
#     plt.imshow(img_1)

### Get world grid and define transforms

In [None]:
side_length = 1.0
grid_size=[256, 256, 256]
sfactor=8

world_grid = get_world_grid([[-side_length, side_length],
                                       [-side_length, side_length],
                                       [0, side_length],
                                       ], [int(grid_size[0]/sfactor), int(grid_size[1]/sfactor), int(grid_size[2]/sfactor)] )  # (1, grid_size**3, 3)

print(world_grid.shape)
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(world_grid.squeeze(0).numpy())

o3d.visualization.draw_plotly([pcd])

world_grids = repeat_interleave(world_grid.clone(),
                                          1 * 1 * 3) # (SB*NV, NC, 3) NC: number of grid cells

In [None]:
print(world_grids.shape)

### Get uv projection and K*[R\T] projection

In [None]:
height, width = latent.size()[2:]
focal = data["src_focal"]/2
# focal[..., 1] *= -1.0
c = data["src_c"]/2
K = torch.FloatTensor([
    [data["src_focal"][0], 0., data["src_c"][0][0]],
    [0., data["src_focal"][1], data["src_c"][0][1]],
    [0., 0., 1.],
])

poses = data["src_poses"]


cam_xyz, uv = w2i_projection(world_grids, poses, K)
image_points = uv[1,:,:]
height, width = latent.size()[2:]
print(height, width)
inbound = torch.logical_and(np.logical_and(image_points[:, 0] > 0, image_points[:, 0] < width),
                      np.logical_and(image_points[:, 1] > 0, image_points[:, 1] < height))
print("inbound uv",(inbound ==True).sum()) 


camera_grids_w2c = world2camera(world_grids, poses)
focal_uv = data["src_focal"][0].unsqueeze(-1).repeat((1, 2))
focal_uv[..., 1] *= -1.0
c_uv = data["src_c"][0].unsqueeze(0)




image_points = uv_projection[1,:,:]
inbound = torch.logical_and(np.logical_and(image_points[:, 0] > 0, image_points[:, 0] < width),
                      np.logical_and(image_points[:, 1] > 0, image_points[:, 1] < height))

print("inbound projection",(inbound ==True).sum()) 

diff = torch.norm(camera_grids_w2c- cam_xyz)
print("diff",diff)

diff_uv = torch.norm(uv_projection- uv)

print("uv min", torch.min(uv), torch.max(uv))
print("uv proj min", torch.min(uv_projection), torch.max(uv_projection))

print("diff uv", diff_uv)

print("uv", uv)
print("uv_projection", uv_projection)

In [None]:
import torch
latent = torch.randn((1,3, 96,96,96,128))
print(latent.shape)

latent = latent.mean(1)

print(latent.shape)

latent = latent.permute(0,4,1,2,3)

print(latent.shape)

feature = torch.randn((1, 8, 96, 96, 96))

print(feature.shape)

samples = torch.randn((2000,65,3))

samples = samples.view(-1,3)

samples = samples.view(1, 1, 1, -1, 3) * 2 - 1.0  # [1 1 H W 3] (x,y,z)
print(samples.shape)

import torch.nn.functional as F
data_im = F.grid_sample(feature, samples, align_corners=True, mode='bilinear')

print("data im", data_im.shape)

out = data_im.squeeze().permute(1,0)

print(out.shape)


In [None]:
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

seg_mask = Image.open('/home/zubairirshad/SAPIEN/renders_balanced/laptop/11586/train/80_degree/seg/r_0.png')
seg_mask =  np.array(seg_mask)

plt.imshow(seg_mask)
plt.show()

print(np.unique(seg_mask))
# seg_mask[seg_mask>1] =1
# plt.imshow(seg_mask)
# plt.show()
instance_mask = seg_mask >0

plt.imshow(instance_mask)
plt.show()
print(instance_mask.shape)

In [None]:
import torch.nn as nn
import torch

embedding = nn.Embedding(91, 32)  # 91 because 0 to 90 inclusive

input_tensor = torch.tensor([0])

# Pass the input tensor through the embedding module to obtain the embedded representation
embedded = embedding(input_tensor)

# Check the shape of the embedded tensor
print(embedded.shape)  # Output: torch.Size([4, 32])