In [21]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [20]:
import numpy as np
import torch 
import torch.nn.functional as F

In [33]:
H = 10
W = 10
focal = 138
near = 2
far = 6
n_sample = 5
# This pose is the camera facing the -x direction 
pose = torch.tensor([
    [0, 0, 1, 0],
    [0, 1, 0, 0],
    [-1, 0, 0, -1],
    [0, 0, 0, 1]
])
negative_z_pose = torch.tensor([
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [0, 0, 1, -1],
    [0, 0, 0, 1]
])

In [18]:
def yuzhen_get_rays(pose, H, W, focal):
    #the input is:
        # 1. the focal length (use the to calculate the relative direction of the ray according to the camera)
        # 2. the 4*4 camera position matreix:
            # contain the 3*3 rotation matrix
            # 3*1 transfomation matrix
        # 3. the size of the frame (pixel as the unit in order to define the ray number)

    #the output is:
        # 1. the original of the rays
        # 2. the direction of the rays


    R = torch.zeros([W])
    for i in range(int(-W*0.5),int(W*0.5)):
        R[int(i+W*0.5)] = i
    R = torch.unsqueeze(R,0) # create another row dimention
    R = R.expand(H,-1)


    C = torch.zeros([H])
    for i in range(int(-H/2), int(H/2)):
        C[int(i+int(H/2))] = -i
    C = torch.unsqueeze(C,1)
    # print(C)
    C = C.expand(-1,W)

    R = torch.unsqueeze(R,2)
    C = torch.unsqueeze(C,2)
    R_C = torch.cat((R, C), 2)
    z_dim = torch.ones(R.shape)
    z_dim = z_dim * -focal
    R_C_Z = torch.cat((R_C, z_dim),2)
    R_C_Z = R_C_Z / focal
    rotation_matrix = pose[:3, :3]

    R_C_Z = torch.unsqueeze(R_C_Z,2) # change the shape from [10,10,3] to [10,10,1,3] used for the later matiplication
    real_R_C_Z = R_C_Z * rotation_matrix # [10,10,1,3] * [3*3] --> [10,10,1,3] * [1,1,3,3] = [10,10,3,3]
    real_R_C_Z = torch.sum(real_R_C_Z,-1) # add the colum of the 3*3 matrix together

    rays_direction = real_R_C_Z
    rays_direction = F.normalize(rays_direction, p=2, dim=2)
    rays_origion = pose[:3, -1]
    rays_origion = rays_origion.expand(rays_direction.shape)

    return rays_origion, rays_direction




# R = torch.zeros([W])
# for i in range(int(-W*0.5),int(W*0.5)):
#     R[int(i+W*0.5)] = i
# # print(R)

# R = torch.unsqueeze(R,0) # create another row dimention
# # print(R)
# # print(R.shape)
# R = R.expand(H,-1)
# # print(R)
# # print(R.shape)


# C = torch.zeros([H])
# for i in range(int(-H/2), int(H/2)):
#     C[int(i+int(H/2))] = -i
# C = torch.unsqueeze(C,1)
# # print(C)
# C = C.expand(-1,W)
# # print(C)

# R = torch.unsqueeze(R,2)
# print(R.shape)
# C = torch.unsqueeze(C,2)
# print(C.shape)
# R_C = torch.cat((R, C), 2)
# print(R_C.shape)
# # print(R_C)

# z_dim = torch.ones(R.shape)
# z_dim = z_dim * -focal
# print(z_dim.shape)
# # print(z_dim)

# R_C_Z = torch.cat((R_C, z_dim),2)
# print(R_C_Z.shape)
# R_C_Z = R_C_Z / focal
# # print (R_C_Z)


# rotation_matrix = pose[:3, :3]

# ## one direction for test demo--------
# test_dir = R_C_Z[0][0]
# print(test_dir)


# real_test_dir = test_dir * rotation_matrix
# print(rotation_matrix)
# print (real_test_dir)
# real_test_dir = torch.sum(real_test_dir, -1)
# print(real_test_dir)
# ##--------------------------------------

# ## do it for all the direction---------
# R_C_Z = torch.unsqueeze(R_C_Z,2) # change the shape from [10,10,3] to [10,10,1,3] used for the later matiplication
# print(R_C_Z.shape)
# real_R_C_Z = R_C_Z * rotation_matrix # [10,10,1,3] * [3*3] --> [10,10,1,3] * [1,1,3,3] = [10,10,3,3]
# real_R_C_Z = torch.sum(real_R_C_Z,-1) # add the colum of the 3*3 matrix together
# print(real_R_C_Z.shape)
# # print(real_R_C_Z)
# ##-------------------------------------


# rays_direction = real_R_C_Z

# ##-----------------------start to calculate the rays_original-------------------------------
# ## the origin of the rays is the camera position --> the transportation 3D vector
# rays_origion = pose[:3, -1] # the :3 means the row 0,1,2. the -1 means the last colum
# print(rays_origion.shape)
# # print(rays_origion)

# ## since later we need to use the rays_direction and rays_origion to do the calculation, so it's better
# ## to have them have the shape shape
# #rays_direction shape --> [10,10,3]
# #rays_origion shape --> [3]

# rays_origion = rays_origion.expand(rays_direction.shape)
# print(rays_origion.shape)
# print(rays_origion)


In [63]:
def yuzhen_position_encoder(original_position, embeded_level):
    #the input:
        #the original position 
        #max level for encoding
    #the output:
        #is a list, with (#embeded_level * 2 + 1) elements inside
    encoded_position = [original_position]
    for i in range (0,embeded_level):
        temp_cos = torch.cos(2**i * original_position)
        temp_sin = torch.sin(2**i * original_position)
        encoded_position.append(temp_sin)
        encoded_position.append(temp_cos)

    encoded_position_torch = torch.cat(encoded_position, -1)        
    return encoded_position_torch


x = torch.tensor([1, 2, 3])
result = yuzhen_position_encoder(x, 6)
print(result.shape)
print(result)


torch.Size([39])
tensor([ 1.0000,  2.0000,  3.0000,  0.8415,  0.9093,  0.1411,  0.5403, -0.4161,
        -0.9900,  0.9093, -0.7568, -0.2794, -0.4161, -0.6536,  0.9602, -0.7568,
         0.9894, -0.5366, -0.6536, -0.1455,  0.8439,  0.9894, -0.2879, -0.9056,
        -0.1455, -0.9577,  0.4242, -0.2879,  0.5514, -0.7683, -0.9577,  0.8342,
        -0.6401,  0.5514,  0.9200,  0.9836,  0.8342,  0.3919, -0.1804])


In [64]:
def positional_encoder(x, L_embed=6):
  """
  This function applies positional encoding to the input tensor. Positional encoding is used in NeRF
  to allow the model to learn high-frequency details more effectively. It applies sinusoidal functions
  at different frequencies to the input.

  Parameters:
  x (torch.Tensor): The input tensor to be positionally encoded.
  L_embed (int): The number of frequency levels to use in the encoding. Defaults to 6.

  Returns:
  torch.Tensor: The positionally encoded tensor.
  """

  # Initialize a list with the input tensor.
  rets = [x]

  # Loop over the number of frequency levels.
  for i in range(L_embed):
    #############################################################################
    #                                   TODO                                    #
    #############################################################################
    sin_encoding = torch.sin(2.0 ** i * x)
    cos_encoding = torch.cos(2.0 ** i * x)
    rets.extend([sin_encoding, cos_encoding])
    #############################################################################
    #                             END OF YOUR CODE                              #
    #############################################################################


  # Concatenate the original and encoded features along the last dimension.
  return torch.cat(rets, -1)


x = torch.tensor([1, 2, 3])
result = positional_encoder(x, 6)
print(result.shape)
print(result)

torch.Size([39])
tensor([ 1.0000,  2.0000,  3.0000,  0.8415,  0.9093,  0.1411,  0.5403, -0.4161,
        -0.9900,  0.9093, -0.7568, -0.2794, -0.4161, -0.6536,  0.9602, -0.7568,
         0.9894, -0.5366, -0.6536, -0.1455,  0.8439,  0.9894, -0.2879, -0.9056,
        -0.1455, -0.9577,  0.4242, -0.2879,  0.5514, -0.7683, -0.9577,  0.8342,
        -0.6401,  0.5514,  0.9200,  0.9836,  0.8342,  0.3919, -0.1804])


## define the nerf model 
- we use the fully connection linear layers here (3 layers for the tiny model)
- use the relu function as the activate function 

In [83]:
class mini_nerf_model(torch.nn.Module):
    # def __init__(self, filter_size = 128, embeded_level = 6):
    #     super(mini_nerf_model, self).__init__()

    #     self.layer1 = torch.nn.Linear(3+3*2*embeded_level, filter_size)
    #     self.layer2 = torch.nn.Linear(filter_size, filter_size)
    #     self.layer3 = torch.nn.Linear(filter_size, 3+1)
    #     self.relu = F.relu

    # def forward(self, x):
    #     x = self.relu(self.layer1(x))
    #     x = self.relu(self.layer2(x))
    #     x = self.layer3(x)
    #     return x

    def __init__(self, filter_size=128, num_encoding_functions=6):
        super(mini_nerf_model, self).__init__()
        self.layer1 = torch.nn.Linear(3 + 3 * 2 * num_encoding_functions, filter_size)
        self.layer2 = torch.nn.Linear(filter_size, filter_size)
        self.layer3 = torch.nn.Linear(filter_size, 4)
        self.relu = torch.nn.functional.relu

    def forward(self, x):
        x = self.relu(self.layer1(x))
        x = self.relu(self.layer2(x))
        x = self.layer3(x)
        return x


In [73]:
# get the rays_origion and rays_direction 
rays_origion, rays_direction = yuzhen_get_rays(negative_z_pose,H, W, focal)
print(rays_direction.shape)

torch.Size([10, 10, 3])


In [84]:
def yuzhen_render(nerf_moel, rays_direction, rays_origion, near, far, n_sample):
    #the input is 
        # 1. the camera position
        # 2. ray direction 
        # 3. the near point 
        # 4. the far point
        # 5. the sample point number
    
    #the overfitted mode: imported

    #the output is 
        # 1. the density at each point
        # 2. the RGB color at each point 

    #algorithm: 
    #1. 
        # the nerf itself only need the encoded position as the input
        # the output of the nerf model will be the rgb value and density value of each point
    #2. 
        # once we have the rgb an density value of each point, we need to use the camera location and angle to render the frame
        # calculate weight and alpha 
    
    #step1: calculate the points coordinates 
    distance = torch.linspace(near, far, n_sample) 
    distance = distance.reshape(1,1, n_sample,1)
 
    distance = distance.expand(W,H,-1,1)
 
    rays_direction = torch.unsqueeze(rays_direction, 2)
 
    all_points = distance * rays_direction

    raw_outcome = nerf_moel()
    



distance = torch.linspace(near, far, n_sample) 
print(distance.shape)
distance = distance.reshape(1,1, n_sample,1)
print(distance.shape)

distance = distance.expand(W,H,-1,1)
print(distance.shape)

rays_direction = torch.unsqueeze(rays_direction, 2)
print(rays_direction.shape)

all_points = distance * rays_direction
print(all_points.shape)






all_points_encoded = yuzhen_position_encoder(all_points, embeded_level=6)
# print(all_points_encoded)
print(all_points_encoded.shape)

nerf_model = mini_nerf_model()
ckpt = torch.load("pretrained.pth", map_location=torch.device("cpu"))
print(ckpt)
nerf_model.load_state_dict(ckpt)
raw_result = nerf_model(all_points_encoded[0][0][0])
print(raw_result.shape)

torch.Size([5])
torch.Size([1, 1, 5, 1])
torch.Size([10, 10, 5, 1])
torch.Size([10, 10, 1, 1, 1, 1, 1, 1, 1, 1, 3])
torch.Size([10, 10, 1, 1, 1, 1, 1, 10, 10, 5, 3])
torch.Size([10, 10, 1, 1, 1, 1, 1, 10, 10, 5, 39])


RuntimeError: Error(s) in loading state_dict for mini_nerf_model:
	Missing key(s) in state_dict: "layer1.weight", "layer1.bias", "layer2.weight", "layer2.bias", "layer3.weight", "layer3.bias". 
	Unexpected key(s) in state_dict: "module.layer1.weight", "module.layer1.bias", "module.layer2.weight", "module.layer2.bias", "module.layer3.weight", "module.layer3.bias". 

In [69]:
def get_rays(H, W, focal, pose):
  """
  This function generates camera rays for each pixel in an image. It calculates the origin and direction of rays
  based on the camera's intrinsic parameters (focal length) and extrinsic parameters (pose).
  The rays are generated in world coordinates, which is crucial for the NeRF rendering process.

  Parameters:
  H (int): Height of the image in pixels.
  W (int): Width of the image in pixels.
  focal (float): Focal length of the camera.
  pose (torch.Tensor): Camera pose matrix of size 4x4.

  Returns:
  tuple: A tuple containing two elements:
      rays_o (torch.Tensor): Origins of the rays in world coordinates.
      rays_d (torch.Tensor): Directions of the rays in world coordinates.
  """
  # Create a meshgrid of image coordinates (i, j) for each pixel in the image.
  i, j = torch.meshgrid(
      torch.arange(W, dtype=torch.float32),
      torch.arange(H, dtype=torch.float32)
  )
  # print("i is:", i)
  # print("j is:", j)
  i = i.t()
  j = j.t()
  # print("i is:", i)
  # print("j is:", j)

  # Calculate the direction vectors for each ray originating from the camera center.
  # We assume the camera looks towards -z.
  # The coordinates are normalized with respect to the focal length.
  dirs = torch.stack(
      [(i - W * 0.5) / focal,
        -(j - H * 0.5) / focal,
        -torch.ones_like(i)], -1
      )

  # Transform the direction vectors (dirs) from camera coordinates to world coordinates.
  # This is done using the rotation part (first 3 columns) of the pose matrix.
  rays_d = torch.sum(dirs[..., np.newaxis, :] * pose[:3, :3], -1) #[100,100,1,3] * [3,3] = [100,100,3,3] after sum(,-1) --> [100,100,3]

  # The ray origins (rays_o) are set to the camera position, given by the translation part (last column) of the pose matrix.
  # The position is expanded to match the shape of rays_d for broadcasting.
  rays_o = pose[:3, -1].expand(rays_d.shape)

  # Return the origins and directions of the rays.
  return rays_o, rays_d




i, j = torch.meshgrid(
    torch.arange(W, dtype=torch.float32),
    torch.arange(H, dtype=torch.float32)
)
# print("i is:", i)
# print("j is:", j)
i = i.t()
j = j.t()
# print("i is:", i)
# print("j is:", j)

# Calculate the direction vectors for each ray originating from the camera center.
# We assume the camera looks towards -z.
# The coordinates are normalized with respect to the focal length.
dirs = torch.stack(
    [(i - W * 0.5) / focal,
      -(j - H * 0.5) / focal,
      -torch.ones_like(i)], -1
    )
# print(dirs.shape)
# print(dirs)
rays_d = torch.sum(dirs[..., np.newaxis, :] * pose[:3, :3], -1)  
print(rays_d.shape)
print(rays_d)

 



torch.Size([10, 10, 3])
tensor([[[-1.0000,  0.0362,  0.0362],
         [-1.0000,  0.0362,  0.0290],
         [-1.0000,  0.0362,  0.0217],
         [-1.0000,  0.0362,  0.0145],
         [-1.0000,  0.0362,  0.0072],
         [-1.0000,  0.0362,  0.0000],
         [-1.0000,  0.0362, -0.0072],
         [-1.0000,  0.0362, -0.0145],
         [-1.0000,  0.0362, -0.0217],
         [-1.0000,  0.0362, -0.0290]],

        [[-1.0000,  0.0290,  0.0362],
         [-1.0000,  0.0290,  0.0290],
         [-1.0000,  0.0290,  0.0217],
         [-1.0000,  0.0290,  0.0145],
         [-1.0000,  0.0290,  0.0072],
         [-1.0000,  0.0290,  0.0000],
         [-1.0000,  0.0290, -0.0072],
         [-1.0000,  0.0290, -0.0145],
         [-1.0000,  0.0290, -0.0217],
         [-1.0000,  0.0290, -0.0290]],

        [[-1.0000,  0.0217,  0.0362],
         [-1.0000,  0.0217,  0.0290],
         [-1.0000,  0.0217,  0.0217],
         [-1.0000,  0.0217,  0.0145],
         [-1.0000,  0.0217,  0.0072],
         [-1.0000,  0.