### Init and define functions

In [1]:
import torch
from utils import *
from collections import defaultdict
import matplotlib.pyplot as plt
import time
from models.rendering import *
from models.nerf import *
import metrics
from datasets import dataset_dict
from datasets.llff import *
from torch.utils.data import DataLoader
from functools import partial
from datasets.srn_multi_ae import collate_lambda_train, collate_lambda_val
torch.backends.cudnn.benchmark = True
torch.manual_seed(0)
# torch.set_printoptions(edgeitems=20)

import plotly
import plotly.graph_objects as go

# print(obj_samples)
def contract(x, order):
    """Contract points outside the unit ball into the shell of radius [1, 2).

    Points whose ``order``-norm is below 1 are returned unchanged; every other
    point ``x`` is mapped to ``(2 - 1/|x|) * x/|x|``.

    Args:
        x: tensor of points, coordinates on the last axis.
        order: norm order forwarded to ``torch.linalg.norm`` (e.g. 2 or inf).

    Returns:
        Tensor with the same shape as ``x``.
    """
    # Use torch.linalg directly: the ``LA`` alias is only imported in a later
    # notebook cell, so relying on it raised NameError when called earlier.
    mag = torch.linalg.norm(x, order, dim=-1)[..., None]
    return torch.where(mag < 1, x, (2 - (1 / mag)) * (x / mag))

def contract_pts(pts, radius):
    """Softly pull points lying outside a sphere of ``radius`` back towards it.

    Points with norm <= ``radius`` are returned unchanged. A point ``p`` with
    ``|p| > radius`` is mapped to ``(1.2 - 0.2 / r) * p/|p| * radius`` where
    ``r = |p| / radius``, so its new norm lies in ``(radius, 1.2 * radius)``.

    Args:
        pts: tensor of points, coordinates on the last axis.
        radius: radius of the sphere inside which points are left untouched.

    Returns:
        Tensor with the same shape as ``pts``.
    """
    mask = torch.norm(pts, dim=-1, keepdim=True) > radius
    new_pts = pts / radius
    norm_pts = torch.norm(new_pts, dim=-1, keepdim=True)
    contracted_points = ((1 + 0.2) - 0.2 / norm_pts) * (new_pts / norm_pts) * radius
    # torch.where instead of ``mask*a + (~mask)*b``: for points at the origin
    # ``contracted_points`` is NaN, and NaN * 0 is still NaN, which leaked
    # NaNs into untouched points.
    return torch.where(mask, contracted_points, pts)

def cast_rays(t_vals, origins, directions):
    """Evaluate ``origin + t * direction`` for every sample depth of every ray.

    Args:
        t_vals: sample depths, shape (..., num_samples).
        origins: ray origins, shape (..., D).
        directions: ray directions, shape (..., D).

    Returns:
        Sample positions, shape (..., num_samples, D).
    """
    start = origins.unsqueeze(-2)
    step = directions.unsqueeze(-2)
    return start + t_vals.unsqueeze(-1) * step


def sample_along_rays(
    rays_o,
    rays_d,
    num_samples,
    near,
    far,
    randomized,
    lindisp,
    in_sphere,
    far_uncontracted = 3.0
):
    """Sample ``num_samples + 1`` values per ray, inside or outside the unit sphere.

    Inside the sphere (``in_sphere=True``) depths are spaced between ``near``
    and ``far`` (linearly, or linearly in disparity when ``lindisp``) and the
    points are ``rays_o + t * rays_d``.

    Outside the sphere (``in_sphere=False``) the raw values in [0, 1] are used
    as inverse radii for ``depth2pts_outside``, which yields 4D points
    ``(x', y', z', 1/r)``. A second set of plain 3D points is cast at depths
    between ``far`` and ``far_uncontracted``. Both sequences are flipped
    before being returned.

    Args:
        rays_o: ray origins, (num_rays, 3).
        rays_d: ray directions, (num_rays, 3).
        num_samples: number of intervals; ``num_samples + 1`` samples are produced.
        near: scalar or (num_rays, 1) tensor, start depth (inside only).
        far: scalar or (num_rays, 1) tensor; end depth inside, start of the
            uncontracted background depths outside.
        randomized: jitter each sample uniformly within its bin.
        lindisp: sample linearly in inverse depth (inside only).
        in_sphere: choose the foreground (inside) or background (outside) scheme.
        far_uncontracted: end depth of the uncontracted background samples.

    Returns:
        ``(t_vals, coords)`` when ``in_sphere``; otherwise
        ``(t_vals, coords, coords_linear)`` with a trailing dimension of 4 for
        ``coords`` and 3 for ``coords_linear``.
    """
    bsz = rays_o.shape[0]
    device = rays_o.device
    t_vals = torch.linspace(0.0, 1.0, num_samples + 1, device=device)

    if in_sphere:
        if lindisp:
            # Linear in disparity (1/depth) between near and far.
            t_vals = 1.0 / (1.0 / near * (1.0 - t_vals) + 1.0 / far * t_vals)
        else:
            t_vals = near * (1.0 - t_vals) + far * t_vals
    else:
        # Background keeps the raw [0, 1] values; they are used as inverse radii.
        t_vals = torch.broadcast_to(t_vals, (bsz, num_samples + 1))

    if randomized:
        # Stratified sampling: one uniform draw inside each bin surrounding the
        # evenly spaced sample positions.
        mids = 0.5 * (t_vals[..., 1:] + t_vals[..., :-1])
        upper = torch.cat([mids, t_vals[..., -1:]], -1)
        lower = torch.cat([t_vals[..., :1], mids], -1)
        t_rand = torch.rand((bsz, num_samples + 1), device=device)
        t_vals = lower + (upper - lower) * t_rand
    else:
        t_vals = torch.broadcast_to(t_vals, (bsz, num_samples + 1))

    if in_sphere:
        coords = cast_rays(t_vals, rays_o, rays_d)
        return t_vals, coords

    # Uncontracted background depths: `far` at t=0 up to `far_uncontracted` at t=1.
    t_vals_linear = far * (1.0 - t_vals) + far_uncontracted * t_vals
    # After flipping, the inverse radii run ~1.0 -> 0.0 (sphere -> infinity)
    # while the linear depths run far_uncontracted -> far (outside -> sphere).
    # NOTE(review): the two returned sequences are ordered in opposite
    # directions — confirm this is what downstream compositing expects.
    t_vals = torch.flip(t_vals, dims=[-1])
    t_vals_linear = torch.flip(t_vals_linear, dims=[-1])
    coords = depth2pts_outside(rays_o, rays_d, t_vals)
    coords_linear = cast_rays(t_vals_linear, rays_o, rays_d)
    return t_vals, coords, coords_linear


def depth2pts_outside(rays_o, rays_d, depth):
    """Compute the points along the ray that are outside of the unit sphere.

    Each sample is given by its inverse radius ``depth = 1/r``. Its position is
    encoded as (x', y', z', 1/r), where (x', y', z') is a unit vector. That
    vector is obtained by rotating ``p_sphere`` by
    ``asin(|p_mid|) - asin(|p_mid| * depth)`` about the axis
    ``rays_o x p_sphere``. Here ``p_sphere`` is the ray's far intersection with
    the unit sphere and ``p_mid`` is the ray point closest to the world origin.

    Args:
        rays_o: [num_rays, 3]. Ray origins of the points.
        rays_d: [num_rays, 3]. Ray directions of the points (need not be unit length).
        depth: [num_rays, num_samples along ray]. Inverse of distance to sphere
            origin (1/r), expected in [0, 1]: 1 lies on the sphere, 0 at infinity.
    Returns:
        pts: [num_rays, num_samples, 4]. Points outside of the unit sphere. (x', y', z', 1/r)

    Raises:
        AssertionError: if a ray does not intersect the unit sphere.
    """
    # note: d1 becomes negative if this mid point is behind camera
    # Broadcast origins / directions to one copy per sample.
    rays_o = rays_o[..., None, :].expand(
        list(depth.shape) + [3]
    )  #  [N_rays, num_samples, 3]
    rays_d = rays_d[..., None, :].expand(
        list(depth.shape) + [3]
    )  #  [N_rays, num_samples, 3]
    # Ray parameter of the point closest to the world origin.
    d1 = -torch.sum(rays_d * rays_o, dim=-1, keepdim=True) / torch.sum(
        rays_d**2, dim=-1, keepdim=True
    )

    p_mid = rays_o + d1 * rays_d
    p_mid_norm = torch.norm(p_mid, dim=-1, keepdim=True)
    rays_d_cos = 1.0 / torch.norm(rays_d, dim=-1, keepdim=True)

    # |p_mid| > 1 means the ray misses the unit sphere.
    # NOTE(review): assert checks are skipped under `python -O`.
    check_pos = 1.0 - p_mid_norm * p_mid_norm
    assert torch.all(check_pos >= 0), "1.0 - p_mid_norm * p_mid_norm should be greater than 0"

    # Half-chord length converted to ray-parameter units; d1 + d2 is the far
    # intersection with the unit sphere.
    d2 = torch.sqrt(1.0 - p_mid_norm * p_mid_norm) * rays_d_cos
    p_sphere = rays_o + (d1 + d2) * rays_d

    # Axis normal to the plane through the world origin and the ray.
    # NOTE(review): no epsilon — the norm is 0 (result NaN) when rays_o is at
    # the world origin or the ray passes through it.
    rot_axis = torch.cross(rays_o, p_sphere, dim=-1)
    rot_axis = rot_axis / torch.norm(rot_axis, dim=-1, keepdim=True)
    # sin(phi) = |p_mid| / 1 and sin(theta) = |p_mid| / r, with r = 1 / depth.
    phi = torch.asin(p_mid_norm)
    theta = torch.asin(p_mid_norm * depth[..., None])  # depth is inside [0, 1]
    rot_angle = phi - theta  # [..., 1]

    # now rotate p_sphere
    # Rodrigues formula: https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
    p_sphere_new = (
        p_sphere * torch.cos(rot_angle)
        + torch.cross(rot_axis, p_sphere, dim=-1) * torch.sin(rot_angle)
        + rot_axis
        * torch.sum(rot_axis * p_sphere, dim=-1, keepdim=True)
        * (1.0 - torch.cos(rot_angle))
    )
    # Normalise to a unit direction; 1e-10 guards against division by zero.
    p_sphere_new = p_sphere_new / (
        torch.norm(p_sphere_new, dim=-1, keepdim=True) + 1e-10
    )
    # Append the inverse radius as the 4th coordinate.
    pts = torch.cat((p_sphere_new, depth.unsqueeze(-1)), dim=-1)

    return pts

def get_image_coords(pixel_offset, image_height, image_width):
    """Return homogeneous pixel coordinates for an ``image_height x image_width`` grid.

    Args:
        pixel_offset: value added to every integer pixel index (0.5 = pixel centre).
        image_height: number of rows.
        image_width: number of columns.

    Returns:
        (image_height * image_width, 3) tensor of rows ``(y, x, 1)`` in
        row-major order, where ``y``/``x`` already include ``pixel_offset``.
    """
    rows = torch.arange(image_height)
    cols = torch.arange(image_width)
    grid_y, grid_x = torch.meshgrid(rows, cols, indexing="ij")
    yx = torch.stack((grid_y, grid_x), dim=-1) + pixel_offset  # (y, x) order
    ones = torch.ones((*yx.shape[:-1], 1))
    return torch.cat([yx, ones], dim=-1).view(-1, 3)

def get_sphere(
    radius, center = None, color: str = "black", opacity: float = 1.0, resolution: int = 32
) -> go.Mesh3d:  # type: ignore
    """Build a plotly ``Mesh3d`` sphere from a latitude/longitude vertex grid.

    The surface is derived by plotly from the vertices with ``alphahull=0``
    (convex hull).

    Args:
        radius: sphere radius.
        center: optional offset added to every vertex; ``None`` keeps the origin.
        color: mesh colour.
        opacity: mesh opacity.
        resolution: number of samples along each of the two angular axes.

    Returns:
        A ``go.Mesh3d`` trace.
    """
    azimuth_samples = torch.linspace(0, 2 * torch.pi, resolution)
    elevation_samples = torch.linspace(-torch.pi / 2, torch.pi / 2, resolution)
    azimuth, elevation = torch.meshgrid(azimuth_samples, elevation_samples, indexing="ij")

    ring = torch.cos(elevation)
    vertices = torch.stack(
        (ring * torch.sin(azimuth), ring * torch.cos(azimuth), torch.sin(elevation)),
        dim=-1,
    )

    vertices *= radius
    if center is not None:
        vertices += center

    mesh_spec = {
        "x": vertices[..., 0].flatten(),
        "y": vertices[..., 1].flatten(),
        "z": vertices[..., 2].flatten(),
        "alphahull": 0,
        "opacity": opacity,
        "color": color,
    }
    return go.Mesh3d(mesh_spec)
    
def vis_camera_rays(origins, directions, coords, ray_length=3.0) -> go.Figure:  # type: ignore
    """Plot camera rays as line segments together with a translucent unit sphere.

    Each ray is drawn from ``origin`` to ``origin + ray_length * direction``.
    Data z is shown on the plot's y axis and data y on the plot's z axis.

    Args:
        origins: ray origins, (N, 3).
        directions: ray directions, (N, 3).
        coords: per-ray values, (N, 3), used as marker colours; each value is
            repeated for both endpoints of its ray.
        ray_length: multiplier for the drawn segment length (defaults to the
            previously hard-coded 3.0).

    Returns:
        A plotly ``go.Figure``.
    """
    num_rays = origins.shape[0]
    lines = torch.empty((num_rays * 2, 3))
    lines[0::2] = origins
    lines[1::2] = origins + directions * ray_length

    # Both endpoints of a ray share that ray's colour value.
    colors = torch.empty((coords.shape[0] * 2, 3))
    colors[0::2] = coords
    colors[1::2] = coords

    # NOTE(review): all endpoints form a single trace, so plotly will also join
    # the end of one ray to the start of the next; and marker.color normally
    # takes a 1-D array — confirm the (2N, 3) colours render as intended.
    data = [
        go.Scatter3d(
            x=lines[:, 0],
            y=lines[:, 2],
            z=lines[:, 1],
            marker=dict(size=4, color=colors),
        ),
        get_sphere(radius=1.0, color="#111111", opacity=0.05),
    ]
    fig = go.Figure(data=data)
    fig.update_layout(
        scene=dict(
            xaxis=dict(title="x", showspikes=False),
            yaxis=dict(title="z", showspikes=False),
            zaxis=dict(title="y", showspikes=False),
        ),
        margin=dict(r=0, b=10, l=0, t=10),
        hovermode=False,
    )

    return fig

def vis_camera_samples(all_samples, split_indices=(10, 50)) -> go.Figure:  # type: ignore
    """Scatter-plot sample points of each ray, coloured by position along the ray.

    For every ray ``i``, ``all_samples[i]`` is split along the sample axis into
    ``[:a]``, ``[a:b]`` and ``[b:]`` with ``(a, b) = split_indices``. These
    segments are drawn in blue, black and green. Translucent reference spheres
    of radius 1 and 2 are added. Data z is shown on the plot's y axis.

    Args:
        all_samples: (num_rays, num_samples, >=3) tensor; only coordinates
            0, 1 and 2 are plotted.
        split_indices: the two segment boundaries along the sample axis
            (defaults to the previously hard-coded ``(10, 50)``).

    Returns:
        A plotly ``go.Figure``.
    """
    first_end, mid_end = split_indices
    segment_colors = ("blue", "black", "green")

    data = []
    for samples in all_samples:
        segments = (
            samples[:first_end, :],
            samples[first_end:mid_end, :],
            samples[mid_end:, :],
        )
        for segment, color in zip(segments, segment_colors):
            data.append(go.Scatter3d(
                x=segment[:, 0],
                y=segment[:, 2],
                z=segment[:, 1],
                mode="markers",
                marker=dict(size=2, color=color),
            ))

    data.append(get_sphere(radius=1.0, color="#111111", opacity=0.05))
    data.append(get_sphere(radius=2.0, color="#111111", opacity=0.05))
    fig = go.Figure(data=data)
    fig.update_layout(
        scene=dict(
            xaxis=dict(title="x", showspikes=False),
            yaxis=dict(title="z", showspikes=False),
            zaxis=dict(title="y", showspikes=False),
        ),
        margin=dict(r=0, b=10, l=0, t=10),
        hovermode=False,
    )

    return fig


def world2camera_viewdirs(w_viewdirs, cam2world, NS):
    """Rotate world-space view directions into each camera's frame.

    Only the rotation part of ``cam2world`` is applied (directions are not
    translated). The world-to-camera rotation is the transpose of the
    camera-to-world rotation.

    Args:
        w_viewdirs: world-space directions; repeated ``NS`` times along dim 0
            before rotation.
        cam2world: (B, >=3, >=3) camera-to-world matrices, e.g. (B, 4, 4).
        NS: repeat count applied to ``w_viewdirs`` along dim 0.

    Returns:
        Camera-space view directions.
    """
    w_viewdirs = repeat_interleave(w_viewdirs, NS)  # (SB*NS, B, 3)
    # ``torch.copy`` does not exist (the old call raised AttributeError); the
    # transposed view is only read, so no copy is needed.
    rot = cam2world[:, :3, :3].transpose(1, 2)  # (B, 3, 3)
    viewdirs = torch.matmul(rot[:, None, :3, :3], w_viewdirs.unsqueeze(-1))[..., 0]
    return viewdirs


def world2camera(w_xyz, cam2world, NS=None):
    """Transform world-space points into camera space.

    With ``R, t`` the rotation and translation of ``cam2world``, each point
    ``x`` is mapped to ``R^T x - R^T t``.

    Args:
        w_xyz: world-space points, xyz on the last axis, e.g. (1, N, 3) or
            (B, N, 3); repeated ``NS`` times along dim 0 when ``NS`` is given.
        cam2world: camera-to-world poses, (B, 4, 4).
        NS: optional repeat count for ``w_xyz`` along dim 0.

    Returns:
        Camera-space points, (B, N, 3).
    """
    if NS is not None:
        w_xyz = repeat_interleave(w_xyz, NS)  # (SB*NS, B, 3)

    rot_w2c = cam2world[:, :3, :3].transpose(1, 2)  # R^T, (B, 3, 3)
    t_c2w = cam2world[:, :3, 3:]  # t, (B, 3, 1)
    t_w2c = -torch.bmm(rot_w2c, t_c2w)  # -R^T t, (B, 3, 1)

    rotated = torch.matmul(rot_w2c[:, None, :3, :3], w_xyz.unsqueeze(-1)).squeeze(-1)
    return rotated + t_w2c[:, None, :, 0]

def repeat_interleave(input, repeats, dim=0):
    """Repeat each slice of ``input`` ``repeats`` times along ``dim``.

    Equivalent to ``torch.repeat_interleave(input, repeats, dim=dim)`` for an
    integer ``repeats``, but built from ``expand`` + ``reshape`` because
    torch.repeat_interleave is currently very slow
    (https://github.com/pytorch/pytorch/issues/31980).
    ``dim`` used to be accepted but ignored; it is now honoured, and the
    default ``dim=0`` behaves as before.

    Args:
        input: tensor with at least one dimension.
        repeats: integer number of copies per slice.
        dim: axis to interleave along; negative values count from the end.

    Returns:
        Tensor whose size along ``dim`` is ``input.shape[dim] * repeats``.
    """
    dim = dim % input.dim()
    shape = input.shape
    # Add a broadcast axis right after ``dim`` (no copy), then merge it into
    # ``dim`` so the copies of each slice end up adjacent.
    expanded = input.unsqueeze(dim + 1).expand(*shape[: dim + 1], repeats, *shape[dim + 1:])
    return expanded.reshape(*shape[:dim], shape[dim] * repeats, *shape[dim + 1:])


def projection(c_xyz, focal, c, NV=None):
    """Project camera-space points to pixel coordinates with a pinhole model.

    ``uv = -xy / z * focal + c``. The sign flip means points in front of the
    camera are presumably expected to have negative z (OpenGL-style camera) —
    confirm against the pose convention.

    Args:
        c_xyz: camera-space points, (SB*NV, NP, 3).
        focal: focal lengths, (SB, 2), or (1, 2) shared by all views.
        c: principal points, (SB, 2), or (1, 2) shared by all views.
        NV: views per scene. When None it is inferred as
            ``c_xyz.shape[0] // SB``. Previously this silently read a
            module-level global ``NV``.

    Returns:
        Pixel coordinates, (SB*NV, NP, 2).
    """
    if NV is None:
        num_scenes = max(focal.shape[0], c.shape[0])
        NV = c_xyz.shape[0] // num_scenes
    uv = -c_xyz[..., :2] / (c_xyz[..., 2:] + 1e-9)  # (SB*NV, NP, 2)
    # Expand per-scene intrinsics to per-view rows; a single shared row stays
    # at batch size 1 and broadcasts. The tensors are tiny, so torch's own
    # repeat_interleave is fine here.
    uv *= torch.repeat_interleave(
        focal.unsqueeze(1), NV if focal.shape[0] > 1 else 1, dim=0
    )
    uv += torch.repeat_interleave(
        c.unsqueeze(1), NV if c.shape[0] > 1 else 1, dim=0
    )
    return uv


def pos_enc(x, min_deg=0, max_deg=10):
    """Sinusoidal positional encoding with frequencies ``2**k``, ``min_deg <= k < max_deg``.

    Args:
        x: (..., D) input tensor.
        min_deg: lowest frequency exponent (inclusive).
        max_deg: highest frequency exponent (exclusive).

    Returns:
        (..., D + 2 * D * (max_deg - min_deg)) tensor laid out as
        ``[x, sin(scaled x), sin(scaled x + pi/2)]`` (the last block is cos).
    """
    # Build the scales directly in x's dtype/device, and use torch.pi: numpy is
    # never imported explicitly in this file (np only came from a star import).
    scales = torch.tensor(
        [2.0 ** i for i in range(min_deg, max_deg)], dtype=x.dtype, device=x.device
    )
    xb = torch.reshape(x[..., None, :] * scales[:, None], list(x.shape[:-1]) + [-1])
    four_feat = torch.sin(torch.cat([xb, xb + 0.5 * torch.pi], dim=-1))
    return torch.cat([x, four_feat], dim=-1)

import open3d as o3d
def get_world_grid(side_lengths, grid_size):
    """Return a regular 3D grid of points in world coordinates.

    Args:
        side_lengths: (min, max) extent for each of x, y, z; shape (3, 2).
        grid_size: points per axis — an int, a 1-element sequence (same count
            on every axis), or a 3-element sequence.

    Returns:
        (1, gx * gy * gz, 3) tensor of xyz points, z varying fastest.
    """
    if isinstance(grid_size, int):
        grid_size = [grid_size]
    if len(grid_size) == 1:
        grid_size = [grid_size[0]] * 3

    axes = [
        torch.linspace(side_lengths[i][0], side_lengths[i][1], grid_size[i])
        for i in range(3)
    ]
    # "ij" is torch's historical default ordering; passing it explicitly keeps
    # the same point order and silences the meshgrid deprecation warning.
    X, Y, Z = torch.meshgrid(*axes, indexing="ij")
    w_xyz = torch.stack([X, Y, Z], dim=-1)  # (gx, gy, gz, 3)
    return w_xyz.reshape(-1, 3).unsqueeze(0)  # (1, gx*gy*gz, 3)

def repeat_interleave(input, repeats, dim=0):
    """Repeat each slice of ``input`` ``repeats`` times along ``dim``.

    Equivalent to ``torch.repeat_interleave(input, repeats, dim=dim)`` for an
    integer ``repeats``, but built from ``expand`` + ``reshape`` because
    torch.repeat_interleave is currently very slow
    (https://github.com/pytorch/pytorch/issues/31980).
    ``dim`` used to be accepted but ignored; it is now honoured, and the
    default ``dim=0`` behaves as before.

    Args:
        input: tensor with at least one dimension.
        repeats: integer number of copies per slice.
        dim: axis to interleave along; negative values count from the end.

    Returns:
        Tensor whose size along ``dim`` is ``input.shape[dim] * repeats``.
    """
    dim = dim % input.dim()
    shape = input.shape
    # Add a broadcast axis right after ``dim`` (no copy), then merge it into
    # ``dim`` so the copies of each slice end up adjacent.
    expanded = input.unsqueeze(dim + 1).expand(*shape[: dim + 1], repeats, *shape[dim + 1:])
    return expanded.reshape(*shape[:dim], shape[dim] * repeats, *shape[dim + 1:])

def w2i_projection(w_xyz, cam2world, intrinsics):
    """Transform world points into each camera and project them to pixels.

    Args:
        w_xyz: world-space points, (V, N, 3); the batch size must match
            ``cam2world`` because ``bmm`` is used.
        cam2world: camera-to-world poses, (V, 4, 4).
        intrinsics: pinhole intrinsics shared by all views, (3, 3).

    Returns:
        ``(camera_grids, uv)``: camera-space points (V, N, 3) and pixel
        coordinates (V, N, 2) obtained by dividing by the projected depth.
    """
    n_views = cam2world.shape[0]
    homog = torch.cat([w_xyz, torch.ones_like(w_xyz[..., :1])], dim=-1)  # (V, N, 4)
    world2cam = torch.inverse(cam2world)
    cam_h = world2cam.bmm(homog.transpose(1, 2))  # (V, 4, N)
    cam_pts = cam_h[:, :3, :]  # (V, 3, N)
    camera_grids = cam_pts.transpose(1, 2)

    K = intrinsics.unsqueeze(0).repeat(n_views, 1, 1)
    proj = K.bmm(cam_pts).transpose(1, 2)  # (V, N, 3)
    uv = proj[..., :2] / proj[..., 2:3]
    return camera_grids, uv

def intersect_sphere(rays_o, rays_d):
    """Return the ray parameter of each ray's far intersection with the unit sphere.

    The ray point closest to the origin is ``p = o + d1 * d`` with
    ``d1 = -(d . o) / (d . d)``. The far intersection lies a further
    ``sqrt(1 - |p|^2) / |d|`` along the ray. The result ``t`` satisfies
    ``|o + t * d| = 1``. It equals Euclidean distance only for unit ``d``.

    Args:
        rays_o: [num_rays, 3]. Ray origins.
        rays_d: [num_rays, 3]. Ray directions (need not be normalised).

    Returns:
        depth: [num_rays, 1]. Ray parameter of the far intersection point.

    Raises:
        AssertionError: if some ray misses the unit sphere.
    """
    # note: d1 becomes negative if this mid point is behind camera
    d1 = -torch.sum(rays_d * rays_o, dim=-1, keepdim=True) / torch.sum(
        rays_d**2, dim=-1, keepdim=True
    )
    p = rays_o + d1 * rays_d
    rays_d_cos = 1.0 / torch.norm(rays_d, dim=-1, keepdim=True)
    p_norm_sq = torch.sum(p * p, dim=-1, keepdim=True)
    # A closest approach outside the sphere means no intersection. (The former
    # debug print of the min/max over every ray was removed.)
    check_pos = 1.0 - p_norm_sq
    assert torch.all(check_pos >= 0), "1.0 - p_norm_sq should be greater than 0"
    d2 = torch.sqrt(check_pos) * rays_d_cos
    return d1 + d2

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  _ft = np.array([False, True], dtype=np.bool)


### Get data

In [2]:
# class FeatureNet(nn.Module):
#     """
#     output 3 levels of features using a FPN structure
#     """
#     def __init__(self, norm_act=nn.BatchNorm2d):
#         super(FeatureNet, self).__init__()

#         self.conv0 = nn.Sequential(
#                         ConvBnReLU(3, 8, 3, 1, 1, norm_act=norm_act),
#                         ConvBnReLU(8, 8, 3, 1, 1, norm_act=norm_act))

#         self.conv1 = nn.Sequential(
#                         ConvBnReLU(8, 16, 5, 2, 2, norm_act=norm_act),
#                         ConvBnReLU(16, 16, 3, 1, 1, norm_act=norm_act),
#                         ConvBnReLU(16, 16, 3, 1, 1, norm_act=norm_act))

#         self.conv2 = nn.Sequential(
#                         ConvBnReLU(16, 32, 5, 2, 2, norm_act=norm_act),
#                         ConvBnReLU(32, 32, 3, 1, 1, norm_act=norm_act),
#                         ConvBnReLU(32, 32, 3, 1, 1, norm_act=norm_act))

#         self.toplayer = nn.Conv2d(32, 32, 1)
#         self.latent_size = 32

#     def _upsample_add(self, x, y):
#         return F.interpolate(x, scale_factor=2,
#                              mode="bilinear", align_corners=True) + y

#     def forward(self, x):
#         # x: (B, 3, H, W)
#         x = self.conv0(x) # (B, 8, H, W)
#         x = self.conv1(x) # (B, 16, H//2, W//2)
#         x = self.conv2(x) # (B, 32, H//4, W//4)
#         x = self.toplayer(x) # (B, 32, H//4, W//4)

#         return x
    
# class ConvBnReLU(nn.Module):
#     def __init__(self, in_channels, out_channels,
#                  kernel_size=3, stride=1, pad=1,
#                  norm_act=nn.BatchNorm2d):
#         super(ConvBnReLU, self).__init__()
#         self.conv = nn.Conv2d(in_channels, out_channels,
#                               kernel_size, stride=stride, padding=pad, bias=False)
#         self.bn = norm_act(out_channels)
#         self.relu = nn.ReLU(inplace=True)


#     def forward(self, x):
#         return self.relu(self.bn(self.conv(x)))

# class CostRegNet(nn.Module):
#     def __init__(self, in_channels, norm_act=InPlaceABN):
#         super(CostRegNet, self).__init__()
#         self.conv0 = ConvBnReLU3D(in_channels, 8, norm_act=norm_act)

#         self.conv1 = ConvBnReLU3D(8, 16, stride=2, norm_act=norm_act)
#         self.conv2 = ConvBnReLU3D(16, 16, norm_act=norm_act)

#         self.conv3 = ConvBnReLU3D(16, 32, stride=2, norm_act=norm_act)
#         self.conv4 = ConvBnReLU3D(32, 32, norm_act=norm_act)

#         self.conv5 = ConvBnReLU3D(32, 64, stride=2, norm_act=norm_act)
#         self.conv6 = ConvBnReLU3D(64, 64, norm_act=norm_act)

#         self.conv7 = nn.Sequential(
#             nn.ConvTranspose3d(64, 32, 3, padding=1, output_padding=1,
#                                stride=2, bias=False),
#             norm_act(32))

#         self.conv9 = nn.Sequential(
#             nn.ConvTranspose3d(32, 16, 3, padding=1, output_padding=1,
#                                stride=2, bias=False),
#             norm_act(16))

#         self.conv11 = nn.Sequential(
#             nn.ConvTranspose3d(16, 8, 3, padding=1, output_padding=1,
#                                stride=2, bias=False),
#             norm_act(8))

#         # self.conv12 = nn.Conv3d(8, 8, 3, stride=1, padding=1, bias=True)

#     def forward(self, x):
#         conv0 = self.conv0(x)
#         conv2 = self.conv2(self.conv1(conv0))
#         conv4 = self.conv4(self.conv3(conv2))

#         x = self.conv6(self.conv5(conv4))
#         x = conv4 + self.conv7(x)
#         del conv4
#         x = conv2 + self.conv9(x)
#         del conv2
#         x = conv0 + self.conv11(x)
#         del conv0
#         # x = self.conv12(x)
#         return x
    
# model = FeatureNet()

# a = torch.randn((1,3,240,320))

# print(model(a).shape)
    


In [3]:
# Build the multi-object dataset (validation split) and fetch a batch into `data`.
dataset = dataset_dict['pd_multi_obj_ae_nocs']
# NOTE(review): machine-specific absolute path — adjust for your environment.
root_dir = '/home/zubairirshad/pd-api-py/single_scene_23'
img_wh = (320, 240)

kwargs = {'root_dir': root_dir,
          'img_wh': tuple(img_wh),
         'model_type': 'nerfpp',
         'split': 'val'}

# Despite the name, this is the 'val' split (see kwargs above).
train_dataset = dataset(**kwargs)
dataloader =  DataLoader(train_dataset,
                  shuffle=False,
                  num_workers=0,
                  batch_size=1,
                  pin_memory=False)

print("len train dataset", len(train_dataset))
# Print tensor shapes of the first batch (and of a second one, if any);
# `data` keeps the last batch fetched.
for i, data in enumerate(dataloader):
    print("i", i)
    for k,v in data.items():
        print(k,v.squeeze(0).shape)
    if i>0:
        break

    print("===============================\n\n\n")
    
# Drop the size-1 batch dimension in place so later cells index rays directly.
# NOTE(review): `data` is undefined here if the dataloader yields nothing.
for k,v in data.items():
    data[k] = v.squeeze(0)

len train dataset 1


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


near torch.Size([76800, 1]) torch.Size([76800, 1])
near torch.Size([76800, 1]) torch.Size([76800, 1])
near torch.Size([76800, 1]) torch.Size([76800, 1])
near torch.Size([76800, 1]) torch.Size([76800, 1])
i 0
src_imgs torch.Size([3, 3, 240, 320])
src_poses torch.Size([3, 4, 4])
src_focal torch.Size([3])
near_obj torch.Size([4, 76800, 1])
far_obj torch.Size([4, 76800, 1])
instance_mask torch.Size([76800, 1])
inst_seg_mask torch.Size([76800, 1])
src_c torch.Size([3, 2])
rays_o torch.Size([76800, 3])
rays_d torch.Size([76800, 3])
viewdirs torch.Size([76800, 3])
target torch.Size([76800, 3])
nocs_2d torch.Size([76800, 3])
radii torch.Size([76800])
multloss torch.Size([76800, 1])
normals torch.Size([76800, 3])





In [None]:
# Scratch check: linearly spaced foreground depths between `near` and the
# unit-sphere intersection, and the midpoints between consecutive depths.
from models.nerfplusplus import helper
near = torch.full_like(data["rays_o"][..., -1:], 1e-4)
far = helper.intersect_sphere(data["rays_o"], data["rays_d"])

t_vals = torch.linspace(0.0, 1.0, 20 + 1)

print("t_vals", t_vals)

# near is (num_rays, 1) (and far presumably too), so this broadcasts to (num_rays, 21).
t_vals = near * (1.0 - t_vals) + far * t_vals

# # print("t_vals", t_vals)

print("t_vals", t_vals[0,:])

# dists = t_vals[..., 1:] - t_vals[..., :-1]

# print("dists", dists[:,0])

# dists = torch.cat([dists, far - t_vals[..., -1:]], dim=-1)

# print("dists", dists.shape)


# Midpoints of consecutive depths: 20 per ray.
m = (t_vals[...,1:] + t_vals[...,:-1]) * 0.5
print("m", m[0,:])
# Extrapolate one extra midpoint from the last spacing so `m` has 21 entries,
# matching `t_vals`.
diff = m[:, -1] - m[:, -2]
last_val = m[:, -1] + diff
print("m", m.shape, last_val.unsqueeze(-1).shape)
m = torch.cat([m, last_val.unsqueeze(-1)], dim=-1)

print("m", m.shape, data["rays_d"].shape)
# Multiply by |rays_d| to turn ray parameters into Euclidean distances.
m *= torch.norm(data["rays_d"][:,None,:], dim=-1)

# # print("m", m)
print("m", m[0,:])

# print("t_vals", t_vals)
# print("t_vals[..., -1:]", t_vals[..., -1:])
# print("t_far - t_vals[..., -1:]", far - t_vals[..., -1:])


#bg

# t_vals = torch.linspace(0.0, 1.0, 20 + 1)

# t_vals = torch.broadcast_to(t_vals, (1000, 20 + 1))

# t_vals = torch.flip(
#     t_vals,
#     dims=[
#         -1,
#     ],
# )  # 1.0 -> 0.0

# print("t_vals", t_vals[0,:])
# m = (t_vals[...,1:] + t_vals[...,:-1]) * 0.5
# print("m", m[0,:])
# diff = m[:, -1] - m[:, -2]
# last_val = m[:, -1] + diff
# print("m", m.shape, last_val.unsqueeze(-1).shape)
# m = torch.cat([m, t_vals[...,-1].unsqueeze(-1)], dim=-1)
# # m = torch.cat([m, last_val.unsqueeze(-1)], dim=-1)

# # print("m", m)
# print("m", m[0,:])

# print("t_vals[..., :-1]", t_vals[..., :-1])

# print("t_vals[..., 1:]", t_vals[..., 1:])

# dists = t_vals[..., :-1] - t_vals[..., 1:]

# print("dists", dists)

# print("t_vals[..., :-1]", t_vals[..., :-1])

# print("t_vals[..., 1:]", t_vals[..., 1:])

### Inspect spatial encoders here

In [None]:
# Run the MVS grid encoder (models.nerfplusplus.encoder_tp_mvs).
# NOTE(review): this cell passes data["proj_mats"], which the dataset loaded
# earlier does not return (see the key list it prints); it needs the dataset
# variant that provides projection matrices (the "PD multi-obj AE CV" one).
import sys
# NOTE(review): machine-specific repository path.
sys.path.append('/home/zubairirshad/nerf_pl')
from models.nerfplusplus.encoder_tp_mvs import GridEncoder, index_grid, get_c
encoder = GridEncoder()
volume_feat, feats_l, depth_values = encoder(data["src_imgs"], data["proj_mats"])

In [None]:
# Inspect the volume features returned by the MVS encoder cell.
print("volume_feat", volume_feat.shape)

In [4]:
# Encode the source views with the fusion-conv grid encoder, which returns
# three feature grids named after planes (xz, xy, yz).
import sys
# NOTE(review): machine-specific repository path.
sys.path.append('/home/zubairirshad/nerf_pl')
from models.nerfplusplus.encoder_tp_fusion_conv import GridEncoder, index_grid, get_c
encoder = GridEncoder()
scene_grid_xz, scene_grid_xy, scene_grid_yz = encoder(data["src_imgs"], data["src_poses"], data["src_focal"], data["src_c"])

Using torchvision resnet34 encoder
self.pillar_aggregator_yz(latent_inp_x) torch.Size([3, 64, 64, 64, 1])
floorplans_yz torch.Size([3, 64, 64, 512])
grid_yz torch.Size([3, 512, 64, 64])



torch.Size([3, 256, 32, 32])
torch.Size([3, 256, 32, 32])
torch.Size([3, 256, 32, 32])
torch.Size([3, 128, 16, 16])
torch.Size([3, 128, 16, 16])
torch.Size([3, 128, 16, 16])
torch.Size([3, 128, 16, 16])
torch.Size([3, 128, 16, 16])
torch.Size([3, 128, 16, 16])
torch.Size([3, 128, 32, 32])
torch.Size([3, 128, 32, 32])
torch.Size([3, 128, 32, 32])
torch.Size([3, 128, 32, 32])
torch.Size([3, 128, 120, 160])
torch.Size([3, 128, 120, 160])


In [None]:
print("scene_grid_xz", scene_grid_xz.shape)
# (disabled) switch every BatchNorm2d in the encoder to eval mode:
# for module in encoder.modules():
#     if isinstance(module, nn.BatchNorm2d):
#         print("module", module)
#         module.eval()

In [None]:
# Prototype 2D decoder: (3, 512, 64, 64) -> (3, 128, 120, 160).
# Spatial path: stride-2 convs 64 -> 32 -> 16, 2x upsample 16 -> 32, then a
# resize to (120, 160) followed by a final conv.
import torch.nn as nn
import torch

model = nn.Sequential(
    nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=2, padding=1),
    nn.BatchNorm2d(256),
    nn.ReLU(inplace=True),
    nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=2, padding=1),
    nn.BatchNorm2d(128),
    nn.ReLU(inplace=True),
    nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
    nn.BatchNorm2d(128),
    nn.ReLU(inplace=True),
    nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
    nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
    nn.BatchNorm2d(128),
    nn.ReLU(inplace=True),
    nn.Upsample(size=(120, 160), mode='bilinear', align_corners=True),
    nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
)

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("num params", num_params)

# Random input with the shape of one set of plane features.
a = torch.randn((3, 512, 64, 64))

b = model(a)
print(b.shape)

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
resnet = models.resnet34(pretrained=True)

# Keep all child modules except the last three. For torchvision's ResNet these
# are layer4, avgpool and fc, so the result ends after layer3 (not "the first
# 3 layers").
new_model = nn.Sequential(*list(resnet.children())[:-3])

print(new_model)

In [None]:
# Visualise mean projections of per-view latent volumes.
# NOTE(review): `out_latent` and `out_latent2` are not created by any cell in
# this notebook, so this cell raises NameError as written; it also assumes
# `out_latent` reshapes to (3, 96, 96, 96, 3).
print()
print("data[src_imgs]", data["src_imgs"].shape)
print("out latent", out_latent.shape)
print("out_latent2", out_latent2.shape)
print("scene_grid_xz", scene_grid_xz.shape)
out_latent2 = out_latent2.detach()[0]
NV = 3
latent = out_latent.permute(0,2,1)
print("out latent", out_latent.shape)
out = latent.reshape(3, 96,96,96,3)

# For each view, average the volume over one spatial axis at a time and show
# the resulting 2D image.
for i in range(NV):
    print("===============================\n\n\n,", i)
    data_im = out[i,...]
    print("data_im", data_im.shape)
    
    yz = torch.mean(data_im, dim=0)
    plt.imshow(yz.numpy())
    plt.show()

    xz = torch.mean(data_im, dim=1)
    plt.imshow(xz.numpy())
    plt.show()


    xy = torch.mean(data_im, dim=2)
    print("xy", xy.shape)
    plt.imshow(xy.numpy())
    plt.show()
    
print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$\n\n\n\n")
print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$\n\n\n\n")
print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$\n\n\n\n")
# Same three mean projections for the second latent.
for i in range(NV):
    print("===============================\n\n\n,", i)
    data_im = out_latent2[i,...]
    print("data_im", data_im.shape)
    
    yz = torch.mean(data_im, dim=0)
    plt.imshow(yz.numpy())
    plt.show()

    xz = torch.mean(data_im, dim=1)
    plt.imshow(xz.numpy())
    plt.show()


    xy = torch.mean(data_im, dim=2)
    print("xy", xy.shape)
    plt.imshow(xy.numpy())
    plt.show()
    
print("torch.min max", torch.min(scene_grid_xz), torch.max(scene_grid_xz))
print("torch.min max", torch.min(scene_grid_yz), torch.max(scene_grid_yz))
print("torch.min max", torch.min(scene_grid_xy), torch.max(scene_grid_xy))
    

In [None]:
# Show each scene grid as an image and print its value range.
# NOTE(review): squeeze(0) only drops a leading dim of size 1, and imshow needs
# an (H, W, 1|3|4) array after the permute — confirm the grid shapes printed
# by the encoder cell allow this.
plt.imshow(scene_grid_xz.squeeze(0).permute(1,2,0).detach().numpy())
plt.show()

plt.imshow(scene_grid_yz.squeeze(0).permute(1,2,0).detach().numpy())
plt.show()

plt.imshow(scene_grid_xy.squeeze(0).permute(1,2,0).detach().numpy())
plt.show()

print("torch.min max", torch.min(scene_grid_xz), torch.max(scene_grid_xz))
print("torch.min max", torch.min(scene_grid_yz), torch.max(scene_grid_yz))
print("torch.min max", torch.min(scene_grid_xy), torch.max(scene_grid_xy))

### Get Samples and visualize

In [None]:
# import models.nerfplusplus.helper as helper



    
# Choose between the NeRF++ path (samples inside and outside the unit sphere)
# and the vanilla NeRF path; `contract` is only consulted on the vanilla path.
# NOTE(review): `contract = False` rebinds the name of the `contract()`
# function defined in the first cell; calling contract(...) afterwards fails.
nerfplusplus = True
contract = False

if nerfplusplus:
    import models.nerfplusplus.helper as helper
    from torch import linalg as LA
    near = torch.full_like(data["rays_o"][..., -1:], 1e-4)
    # This is the notebook's own intersect_sphere (first cell), not helper's.
    far = intersect_sphere(data["rays_o"], data["rays_d"])
    
else:
    import models.vanilla_nerf.helper as helper
    from torch import linalg as LA
    near = 0.2
    far = 3.0

if nerfplusplus:
    # Foreground: 66 jittered depths per ray between near and the sphere exit.
    obj_t_vals, obj_samples = sample_along_rays(
        rays_o=data["rays_o"],
        rays_d=data["rays_d"],
        num_samples=65,
        near = near,
        far = far,
        randomized=True,
        lindisp=False,
        in_sphere=True,
    )

    # Background: 4D inverse-radius points outside the sphere, plus plain 3D
    # points between the sphere exit and the default far_uncontracted.
    bg_t_vals, bg_samples, bg_samples_linear = sample_along_rays(
        rays_o=data["rays_o"],
        rays_d=data["rays_d"],
        num_samples=65,
        near=near,
        far=far,
        randomized=True,
        lindisp=False,
        in_sphere=False,
    )
    print("toch max min bg_samples linear", torch.max(bg_samples_linear), torch.min(bg_samples_linear))
else:
    all_t_vals, all_samples = helper.sample_along_rays(
        rays_o=data["rays_o"],
        rays_d=data["rays_d"],
        num_samples=65,
        near=near,
        far=far,
        randomized=True,
        lindisp=False,
    )
    

def reverse_contract_pts(pts, radius):
    """Apply ``(1.2 - 0.2 / |p|) * p/|p| * radius`` to every point (no masking).

    Despite the name this is not the inverse of ``contract_pts``; it applies
    the contraction formula unconditionally to all points.
    """
    magnitude = pts.norm(dim=-1, keepdim=True)
    direction = pts / magnitude
    scale = (1 + 0.2) - 0.2 / magnitude
    return scale * direction * radius

def contract_pts(pts, radius=3):
    """Softly contract points lying outside a sphere of ``radius``.

    Points with norm <= ``radius`` are returned as-is. Otherwise, with
    ``r = |p| / radius``, a point becomes ``(1.5 - 0.5 / r) * p/|p| * radius``,
    which keeps its norm below ``1.5 * radius``.
    """
    outside = pts.norm(dim=-1, keepdim=True) > radius
    unit_scaled = pts.clone() / radius
    scaled_norm = unit_scaled.norm(dim=-1, keepdim=True)
    squashed = ((1 + 0.5) - 0.5 / scaled_norm) * (unit_scaled / scaled_norm) * radius
    return torch.where(~outside, pts, outside * squashed)

def uncontract_pts(pts, radius=1):
    """Contract points outside ``radius`` using a 0.6 strength factor.

    Points with norm <= ``radius`` are unchanged. Otherwise, with
    ``r = |p| / radius``, a point becomes ``(1.6 - 0.6 / r) * p/|p| * radius``.
    NOTE(review): despite the name this is another forward contraction, not
    the inverse of ``contract_pts``.
    """
    outside = pts.norm(dim=-1, keepdim=True) > radius
    unit_scaled = pts.clone() / radius
    scaled_norm = unit_scaled.norm(dim=-1, keepdim=True)
    squashed = ((1 + 0.6) - 0.6 / scaled_norm) * (unit_scaled / scaled_norm) * radius
    return torch.where(~outside, pts, outside * squashed)

def contract_samples(x, order=2):
    """Map points with ``order``-norm >= 1 to ``(2 - 1/|x|) * x/|x|``; keep the rest."""
    magnitude = torch.linalg.norm(x, order, dim=-1, keepdim=True)
    squashed = (2 - (1 / magnitude)) * (x / magnitude)
    return torch.where(magnitude < 1, x, squashed)

# def inverse_contract_samples(x, order = order):
#     mag = torch.linalg.norm(x, ord=order, dim=-1)[..., None]
#     mask = mag < 1
#     expanded = x * mag*mag / (2 - (1 / mag))
#     return torch.where(mask, x, expanded)

def _contract(x):
    x_mag_sq = torch.sum(x**2, dim=-1, keepdim=True).clip(min=1e-32)
    z = torch.where(
        x_mag_sq <= 1, x, ((2 * torch.sqrt(x_mag_sq) - 1) / x_mag_sq) * x
    )
    return z

def _inverse_contract(x):
    x_mag_sq = torch.sum(x**2, dim=-1, keepdim=True).clip(min=1e-32)
    z = torch.where(
        x_mag_sq <= 1, x, x * (x_mag_sq / (2 * torch.sqrt(x_mag_sq) - 1))
    )
    return z



if nerfplusplus:
#     all_samples = torch.cat((obj_samples, bg_samples[:,:,:3]), dim=0)
#     all_samples = torch.cat((obj_samples, bg_samples_linear), dim=0)

    # Keep only the foreground samples here; background samples are stacked
    # separately below for plotting.
    all_samples = obj_samples
    print("all_samples", all_samples.shape)
else:
    if contract:
        all_samples = contract_samples(all_samples, float('inf'))
        # Report the range of the contracted samples. (This used to print
        # `samples`, which is only defined further down -> NameError.)
        print("torch min max  all samples", torch.min(all_samples), torch.max(all_samples))
#         all_samples = _inverse_contract(all_samples)


if nerfplusplus:
    print("bg obj samples",bg_samples.shape, obj_samples.shape)
else:
    print("all_samples",all_samples.shape)


# NOTE(review): coords are built for a 120x160 image although the data is
# 320x240, and they are never passed to vis_camera_samples — confirm they are
# still needed.
coords = get_image_coords(pixel_offset = 0.5,image_height = 120, image_width = 160)
print(coords.shape)

rays_o = data["rays_o"][:20,:]
rays_d = data["rays_d"][:20,:]
coords = coords[:20,:]

# Pick 20 random rays (with replacement) and stack their foreground,
# background-on-sphere and linear-background samples for plotting.
if nerfplusplus:
    num = np.random.choice(all_samples.shape[0], 20)
    samples = all_samples[num,:, :3]
    
#     num_bg = np.random.choice(bg_samples.shape[0], 20)
    samples_bg = bg_samples[num,:, :3]
    
    print("inverse radius ", bg_samples[:,:,3])
    
#     num_bg_linear = np.random.choice(bg_samples.shape[0], 20)
    samples_bg_linear = bg_samples_linear[num,:, :3]
#     print(samples.shape, samples_bg.shape)
    samples = torch.cat((samples, samples_bg, samples_bg_linear), dim=0)
#     samples = samples_bg_linear
else:
    num = np.random.choice(all_samples.shape[0], 20)
    samples = all_samples[num,:, :3]


fig = vis_camera_samples(samples)
fig.show()

print("torch min max samples", torch.min(samples), torch.max(samples))


In [None]:
print("all_samples", all_samples.shape)

# Flatten all sample points to (1, N*S, 3), repeat them 3x along dim 0 and
# express each copy in the frame of the matching source pose.
samples_w = all_samples.reshape(-1,3).unsqueeze(0)
samples_cam = world2camera(samples_w, data["src_poses"], 3)

print("samples",samples_cam.shape)

print("all_samples", torch.min(all_samples), torch.max(all_samples))
# Camera-space (x, z) used directly as grid_sample coordinates for the xz plane.
# NOTE(review): grid_sample expects coordinates in [-1, 1]; with
# scale_factor = 1, points outside that range read zeros (padding_mode='zeros').
scale_factor = 1
uv_xz = samples_cam[:, :, [0, 2]].float()
uv_xz = (uv_xz/scale_factor).unsqueeze(2)

print("uv_xz", uv_xz.shape)

# Random stand-in for a (3, 128, 64, 64) feature plane.
grid = torch.randn((3, 128, 64, 64))

import torch.nn.functional as F
# Output: (3, 128, N*S, 1).
scene_latent_xz = F.grid_sample(grid, uv_xz, align_corners=True, mode='bilinear', padding_mode='zeros')

print("scene_latent_xz", scene_latent_xz.shape)

# Placeholders: the xz features are reused for the other two planes.
scene_latent_xy = scene_latent_xz

scene_latent_yz = scene_latent_xz

# Sum the plane features, then flatten to one 128-channel row per
# (view, sample): (3 * N*S, 128).
output = torch.sum(torch.stack([scene_latent_xz, scene_latent_xy, scene_latent_yz]), dim=0)
print(output[..., 0].shape)

print("output[..., 0].transpose(1,2)", output[..., 0].permute(0,2, 1).shape)
output = output[..., 0].permute(0,2, 1).reshape(-1, output.shape[1])
print(output.shape)
# a = torch.randn((2000,65,3))
# uv_xz = a[:, :, [0, 2]].reshape(-1,2).unsqueeze(0).float()

# print(uv_xz.shape)


# print("a[:,:,0].float().unsqueeze(-1).", a[:,:,0].float().unsqueeze(-1).shape)
# x = a[:,:,0].float().unsqueeze(-1).reshape(-1,1)
# y = a[:,:,1].float().unsqueeze(-1).reshape(-1,1)

# print("x before ", x)

# x = x*-1
# y=y*-1

# print("x after", x)

# print(x.shape, y.shape)

# uv_xz = torch.stack([x, y], dim=-1)
# print(uv_xz.shape)


### Encode for PN

In [None]:
from models.vanilla_nerf.encoder import *
# Number of source views handled by the cells below (they loop `for i in range(NV)`
# and pass NV to world2camera).
NV = 1

def unprocess_images(normalized_images, shape = ()):
    """Map normalised images back to display range by computing ``(x + 1) / 2``.

    This is exactly what ``T.Normalize(mean=(-1, -1, -1), std=(2, 2, 2))``
    computes, i.e. the inverse of a mean-0.5 / std-0.5 normalisation, so
    values in [-1, 1] come back to [0, 1]. The result is computed once; the
    input tensor is not modified.

    Args:
        normalized_images: Image tensor with channels at dim -3
            (e.g. ``(N, 3, H, W)``).
        shape: Unused; kept for backward compatibility with existing callers.

    Returns:
        A new tensor with the same shape as ``normalized_images``.
    """
    # (x - (-1)) / 2 — bit-for-bit the same arithmetic Normalize performs here.
    restored = (normalized_images + 1.0) / 2.0
    print("unnormalize dimgs", restored.shape)
    return restored

# Flatten every sample point (first three channels) into one batch and express
# it in the source camera frame(s).
w_xyz = all_samples[:, :, :3].reshape(-1, 3).unsqueeze(0)
poses = torch.clone(data["src_poses"])
cam_xyz = world2camera(w_xyz, poses, NV)
print("cam_xyz", cam_xyz.shape)

# Image feature extractor applied to the source images.
encoder = SpatialEncoder(
    backbone="resnet34",
    pretrained=True,
    num_layers=4,
    index_interp="bilinear",
    index_padding="zeros",
    upsample_interp="bilinear",
    feature_scale=1.0,
    use_first_pool=True,
    norm_type="batch",
)
latent = encoder(data["src_imgs"])
print(latent.shape)

# Show each source view as stored in `data`, then again after un-normalising.
for i in range(NV):
    plt.imshow(data["src_imgs"][i].permute(1, 2, 0).numpy())
    plt.show()

new_images = unprocess_images(data["src_imgs"])

for i in range(NV):
    plt.imshow(new_images[i].permute(1, 2, 0).numpy())
    plt.show()

height, width = new_images.size()[2:]

# Pinhole projection using entry 0 of src_focal on both axes (second axis
# negated) and entry 0 of src_c as the principal point.
focal = data["src_focal"][0].unsqueeze(-1).repeat((1, 2))
focal[..., 1] *= -1.0
c = data["src_c"][0].unsqueeze(0)
uv_pn = projection(cam_xyz, focal, c)
print("focal, c", focal, c)

# Pixel coordinates -> grid_sample's [-1, 1] range (align_corners=True convention).
im_x, im_y = uv_pn[:, :, 0], uv_pn[:, :, 1]
im_grid_pn = torch.stack([2 * im_x / (width - 1) - 1, 2 * im_y / (height - 1) - 1], dim=-1)
print("torch min max im_grid_pn", torch.min(im_grid_pn), torch.max(im_grid_pn))



### Plot sampled projected pixels

In [None]:
print("data[src_imgs]", torch.min(new_images[0]), torch.max(new_images[0]))
print("data[src_imgs]", new_images.shape)
for i in range(NV):
    plt.imshow(new_images[i].permute(1,2,0).numpy())
    plt.show()

imgs = new_images
# im_grid_pn is (V, N, 2); grid_sample wants (V, N, 1, 2).
grid = im_grid_pn.unsqueeze(2)

# grid_sample (padding_mode='zeros') returns zeros for a point if EITHER
# normalised coordinate leaves [-1, 1], so count points whose x and y are both
# in range, and report the percentage of points (not of individual coordinates).
mask = (grid.abs() <= 1).all(dim=-1)
za_inbound_mask = mask.sum()
print("za_inbound_mask", za_inbound_mask, (za_inbound_mask / mask.numel()) * 100)

V = NV
C = 3
# One RGB triple per sample point and per view.
colors = torch.empty((grid.shape[1], V*C), device=imgs.device, dtype=torch.float)
print("colors", colors.shape)
for i in range(imgs.shape[0]):
    view_img = imgs[i].unsqueeze(0)
    view_grid = grid[i].unsqueeze(0)
    print("imgs[idx, :, :, :].unsqueeze(0)", view_img.shape)
    print("grid[idx, :, :].unsqueeze(0)", view_grid.shape)
    data_im = F.grid_sample(view_img, view_grid, align_corners=True, mode='bilinear', padding_mode='zeros')
    print("data", data_im.shape)

    # (1, 3, N, 1) -> (N, 3): one sampled colour per point.
    per_point = data_im[0].squeeze(-1).permute(1, 0)
    print("data_im[0].permute(1, 2, 0)", per_point.shape)
    colors[..., i*C:i*C+C] = per_point

    # Rays are assumed to form a 240x320 image; -1 infers samples per ray.
    all_imgs = per_point.reshape(240, 320, -1, 3).numpy()
    img_1 = all_imgs[:,:,10,:]
    plt.imshow(img_1)
    plt.show()

### Get world grid and define transforms

In [None]:
side_length = 1.0
grid_size = [256, 256, 256]
sfactor = 8

# Per-axis bounds in get_world_grid's axis order; the last axis covers only
# [0, side_length].
bounds = [
    [-side_length, side_length],
    [-side_length, side_length],
    [0, side_length],
]
# Down-sampled resolution per axis (256 / 8 -> 32 cells).
cells_per_axis = [int(n / sfactor) for n in grid_size]
world_grid = get_world_grid(bounds, cells_per_axis)  # presumably (1, n_cells, 3)
print(world_grid.shape)

# Quick interactive look at the grid as a point cloud.
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(world_grid.squeeze(0).numpy())
o3d.visualization.draw_plotly([pcd])

# One copy of the grid per view (3 copies); original note: (SB*NV, NC, 3).
world_grids = repeat_interleave(world_grid.clone(), 3)

In [None]:
print(world_grids.size())  # torch.Size, same as .shape

### Get uv projection and K[R|T] projection

In [None]:
# Project the coarse world grid into the source views in two ways and compare:
#   1) K [R|T] via w2i_projection        -> cam_xyz, uv
#   2) world2camera + projection (PN)    -> camera_grids_w2c, uv_projection
height, width = latent.size()[2:]
focal = data["src_focal"]/2
c = data["src_c"]/2
# NOTE(review): `focal` / `c` above are halved but K below is built from the
# unhalved values, while in-bounds checks use the latent's size — confirm which
# resolution is intended. Also fy reads data["src_focal"][1] whereas the PN path
# uses [0] for both axes — confirm src_focal's layout.
K = torch.FloatTensor([
    [data["src_focal"][0], 0., data["src_c"][0][0]],
    [0., data["src_focal"][1], data["src_c"][0][1]],
    [0., 0., 1.],
])

poses = data["src_poses"]


def count_inbound(points, width, height):
    """Count rows of ``points`` whose (u, v) lie strictly inside (0, width) x (0, height).

    Uses torch ops only, so it also works for tensors on non-CPU devices.
    """
    u, v = points[:, 0], points[:, 1]
    return ((u > 0) & (u < width) & (v > 0) & (v < height)).sum()


cam_xyz, uv = w2i_projection(world_grids, poses, K)
image_points = uv[1,:,:]  # second view
print(height, width)
print("inbound uv", count_inbound(image_points, width, height))

camera_grids_w2c = world2camera(world_grids, poses)
# Same intrinsics construction as the PN projection cell (second axis negated).
focal_uv = data["src_focal"][0].unsqueeze(-1).repeat((1, 2))
focal_uv[..., 1] *= -1.0
c_uv = data["src_c"][0].unsqueeze(0)
# uv_projection is used below; compute it the same way uv_pn is computed.
uv_projection = projection(camera_grids_w2c, focal_uv, c_uv)

image_points = uv_projection[1,:,:]
print("inbound projection", count_inbound(image_points, width, height))

diff = torch.norm(camera_grids_w2c- cam_xyz)
print("diff",diff)

diff_uv = torch.norm(uv_projection- uv)

print("uv min", torch.min(uv), torch.max(uv))
print("uv proj min", torch.min(uv_projection), torch.max(uv_projection))

print("diff uv", diff_uv)

print("uv", uv)
print("uv_projection", uv_projection)

In [None]:
import torch
import torch.nn.functional as F

# Shape check: average a multi-copy 3-D latent volume over dim 1, then move
# channels first: (1, 3, 96, 96, 96, 128) -> (1, 96, 96, 96, 128) -> (1, 128, 96, 96, 96).
latent = torch.randn((1,3, 96,96,96,128))
print(latent.shape)

latent = latent.mean(1)

print(latent.shape)

latent = latent.permute(0,4,1,2,3)

print(latent.shape)

feature = torch.randn((1, 8, 96, 96, 96))

print(feature.shape)

# Query points must be uniform in [0, 1) for `* 2 - 1` to land them in
# grid_sample's [-1, 1) range; torch.randn would push most points outside the
# volume, where padding_mode='zeros' (the default) returns zeros.
samples = torch.rand((2000,65,3))

samples = samples.view(-1,3)

samples = samples.view(1, 1, 1, -1, 3) * 2 - 1.0  # [1 1 1 N 3] (x,y,z)
print(samples.shape)

# 5-D input + bilinear mode = trilinear interpolation -> (1, 8, 1, 1, N).
data_im = F.grid_sample(feature, samples, align_corners=True, mode='bilinear')

print("data im", data_im.shape)

out = data_im.squeeze().permute(1,0)  # (N, 8)

print(out.shape)


In [None]:

import torch

# Depth bounds, and 65 evenly spaced parameters (64 intervals) in [0, 1].
near, far = 0.3, 3.0
t_vals = torch.linspace(0.0, 1.0, 64 + 1)
print("t_vals", t_vals)

# Rescale onto [near, far] with the same convex combination sample_along_rays uses.
t_vals = near * (1.0 - t_vals) + far * t_vals
print("t_vals", t_vals)

In [None]:
import torch

# Random tensor shaped (3, 3, 240, 320); `a` holds its last two sizes as (width, height).
img = torch.randn((3,3,240,320))
a = torch.Tensor([img.size(-1), img.size(-2)])
print(a)

In [None]:
# Halve the (width, height) tensor from the previous cell.
a = a.div(2)
print(a)

In [None]:
# Zero-pad a frame index to three digits, matching the "-NNN.png" names used below.
num_str = str(12)
print(num_str.rjust(3, "0"))

In [None]:
path = '/home/zubairirshad/pd-api-py/PDMultiObjv6/train'
import os
import shutil


def fill_missing_frames(mask_dir, num_frames=199):
    """Fill gaps in a ``<prefix>-NNN.png`` frame sequence by copying the previous frame.

    Frames ``000`` .. ``num_frames - 1`` are checked in order. Each missing one
    is created as a copy of the frame just before it, which may itself be a
    fresh copy, so a run of missing frames repeats the last real frame. Missing
    leading frames (before any existing one) cannot be filled; they are
    reported and skipped instead of crashing.

    Args:
        mask_dir: Directory holding the ``.png`` frames.
        num_frames: Number of expected frames (indices ``0`` .. ``num_frames - 1``).

    Returns:
        Names of the files that were created, in index order.
    """
    pngs = [f for f in os.listdir(mask_dir) if f.endswith('.png')]
    if not pngs:
        return []
    # All frames share the text before the final '-'; rsplit keeps prefixes
    # that themselves contain '-' intact.
    prefix = pngs[0].rsplit('-', 1)[0]

    created = []
    previous = None  # path of the latest frame that exists on disk
    for i in range(num_frames):
        name = f'{prefix}-{i:03d}.png'
        target = os.path.join(mask_dir, name)
        if os.path.exists(target):
            previous = target
            continue
        if previous is None:
            print(f'Cannot fill {name} in {mask_dir}: no earlier frame to copy')
            continue
        # Duplicate the previous frame under the missing index.
        shutil.copy2(previous, target)
        previous = target
        created.append(name)
    return created


# Repair every scene's instance_masks_2d folder under `path`.
if not os.path.isdir(path):
    print(f'Dataset root not found: {path}')
else:
    for foldername in os.listdir(path):
        nocs_path = os.path.join(path, foldername, 'train', 'instance_masks_2d')
        if not os.path.isdir(nocs_path):
            continue
        for name in fill_missing_frames(nocs_path):
            print(f'Missing file in folder {foldername}: {name}')

In [None]:
a = '3view_LPIPS'

# Substring check on a label such as a metric/run name.
if a.find('LPIPS') != -1:
    print("TRUE")