In [4]:
import torch as th

def forward(x: th.Tensor, levels: int, scale: float) -> th.Tensor:
    """
    Compute a sinusoidal positional encoding for every channel (column) of x.

    Each column x_i is expanded into `levels` arguments x_i * scale * 2^j for
    j in [0, levels-1]. The output holds the cosines of all arguments first,
    followed by the sines of all arguments.

    Args:
        x (th.Tensor): input whose columns (dim 1) are the channels to encode
        levels (int):  number of frequency octaves per channel
        scale (float): base multiplier applied before the powers of two

    Returns:
        th.Tensor: cos(args) and sin(args) stacked horizontally
    """
    # Powers of two 2^0 .. 2^(levels-1), tiled once per input channel
    octaves = (2**th.arange(levels, device=x.device)).repeat(x.shape[1])
    frequencies = scale * octaves

    # Duplicate every channel `levels` times so it lines up with its octaves
    expanded = x.repeat_interleave(levels, dim=1)
    args = expanded * frequencies
    # Show the raw sinusoid arguments (kept for inspection in the notebook)
    print(args)

    cosines, sines = th.cos(args), th.sin(args)
    return th.hstack((cosines, sines))

# Three single-channel sample inputs: 0, 1 and 2
test_tensor = th.tensor([[0.0], [1.0], [2.0]])
levels = 5
scale = th.pi * 2

# Encode the samples and show the stacked cos/sin features
print(forward(test_tensor, levels, scale))

tensor([[  0.0000,   0.0000,   0.0000,   0.0000,   0.0000],
        [  6.2832,  12.5664,  25.1327,  50.2655, 100.5310],
        [ 12.5664,  25.1327,  50.2655, 100.5310, 201.0619]])
tensor([[1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.7485e-07,
         3.4969e-07, 6.9938e-07, 1.3988e-06, 2.7975e-06],
        [1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 3.4969e-07,
         6.9938e-07, 1.3988e-06, 2.7975e-06, 5.5951e-06]])


In [None]:
# Create a mask that matches the layout of the positional encoding:
# (x, y, z) repeated `levels` times, then repeated once more for cos and sin

def calculate_mask(alpha, k):
    """
    Scalar coarse-to-fine weight of frequency level k at progress alpha.

    Args:
        alpha (float): progress value controlling how many levels are active
        k (float):     the level of the positional encoding

    Returns:
        0 if alpha < k, 1 if alpha - k >= 1, and the cosine interpolation
        (1 - cos((alpha - k) * pi)) / 2 in between.
    """
    # Local import: this cell does not import math at module level
    from math import cos, pi

    # How far alpha has advanced past the start of level k
    progress = alpha - k

    # Level not reached yet
    if progress < 0:
        return 0
    # Level is being faded in: smooth ramp from 0 (progress=0) to 1 (progress=1)
    if progress < 1:
        return (1 - cos(progress * pi)) / 2
    # Level fully active
    return 1

    

In [33]:

def _get_mask(alpha, levels: int) -> th.Tensor:
        """
        Get the mask from the barf paper that fits the positional encodings
        """
        def get_mask(alpha: th.Tensor, k: th.Tensor) -> th.Tensor:
            """
            Calculates the mask from the barf paper given alpha and k. 

            Args:
                alpha (float): value that is proporitional to the batch
                k (float):     the level of the positional encoding
            
            Returns:
                float: 0 if alpha < k, 1 if alpha - k >= 1 and a cosine interpolation otherwise
            """
            result = th.zeros_like(alpha)
            
            condition1 = alpha - k < 0
            condition2 = (0 <= alpha - k) & (alpha - k < 1)
            
            result[condition1] = 0
            result[condition2] = (1 - th.cos((alpha[condition2] - k[condition2]) * th.pi)) / 2
            result[~(condition1 | condition2)] = 1
            
            return result
        
        # Create a vector of alpha values
        # alpha = 1 # TODO fix this by getting the current epoch 
        alpha = th.ones((levels)) * alpha
        k = th.arange(levels)

        # Get the mask vector 
        mask = get_mask(alpha, k)

        # Reshape mask to take in (x,y,z) and repeat an extra time for sin/cos 
        mask = mask.repeat_interleave(1)
        mask = mask.repeat(2) 

        return mask

# Mask for 4 levels at alpha = 0.5, with a leading axis so it broadcasts over rows
mask = _get_mask(0.5, 4)[None, :]

# Random stand-in for a batch of 5 encodings with 8 features each
test_tensor = th.randn(5, 8)
print(test_tensor)

# Features of levels that are not active yet are zeroed out
print(mask * test_tensor)

tensor([[ 0.3796,  0.3098, -1.2437,  0.0544,  0.1402, -1.5004,  0.3479, -0.6451],
        [-0.3724,  0.5560,  2.0618,  0.3816,  2.9674, -1.0613,  0.5431, -0.2598],
        [-1.1162, -1.1075, -0.4582, -0.5607,  0.6536, -0.0932,  1.0657, -0.5071],
        [-0.1248, -2.2781,  1.4051, -0.9471, -1.8571,  0.3086, -0.0688, -1.0713],
        [-0.8725, -1.4538, -0.5265,  0.4309, -1.1718,  0.6643,  1.3452, -0.5451]])
tensor([[ 0.1898,  0.0000, -0.0000,  0.0000,  0.0701, -0.0000,  0.0000, -0.0000],
        [-0.1862,  0.0000,  0.0000,  0.0000,  1.4837, -0.0000,  0.0000, -0.0000],
        [-0.5581, -0.0000, -0.0000, -0.0000,  0.3268, -0.0000,  0.0000, -0.0000],
        [-0.0624, -0.0000,  0.0000, -0.0000, -0.9286,  0.0000, -0.0000, -0.0000],
        [-0.4363, -0.0000, -0.0000,  0.0000, -0.5859,  0.0000,  0.0000, -0.0000]])


In [14]:



levels = 5

# Create a vector of alpha values
alpha = 1 # TODO fix this by getting the current epoch
alpha = th.ones((levels)) * alpha
k = th.arange(levels)

result = th.zeros_like(alpha)

# Piecewise weight per level k:
#  0 if alpha < k, cosine interpolation if 0 <= alpha - k < 1, 1 otherwise
condition1 = alpha - k < 0
condition2 = (0 <= alpha - k) & (alpha - k < 1)

result[condition1] = 0
# The ramp argument is (alpha - k) * pi so that it goes from 0 to 1 over one level
result[condition2] = (1 - th.cos((alpha[condition2] - k[condition2]) * th.pi)) / 2
result[~(condition1 | condition2)] = 1

# Reshape mask to take in (x,y,z) and repeat an extra time for sin/cos
# NOTE(review): repeat_interleave(3) places the three copies of one level next
#  to each other (level-major) - confirm this matches the encoding's column order
mask = result.repeat_interleave(3)
mask = mask.tile(2)

print(mask)

tensor([1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])


In [None]:
def subset_dataloader(self, images: list[int], dataset: ImagePoseDataset)-> DataLoader:
    """Get a dataloader that only yields pixels of the specified images.

    Args:
        images (list[int]): Indices of the images to keep.
        dataset (ImagePoseDataset): Dataset to draw the pixels from.

    Returns:
        DataLoader: Loader over the pixels of the selected images in random
            order, using the batch size and worker count stored in
            self.dataloader_kwargs.
    """
    # Slice the dataset
    # NOTE: ImagePoseDataset maps a flat index to image `index // image_batch_size`,
    #  so image i owns the contiguous range [i * image_batch_size, (i + 1) * image_batch_size)
    pixels_per_image = dataset.image_batch_size
    indices = [
        image_index * pixels_per_image + pixel_index
        for image_index in images
        for pixel_index in range(pixels_per_image)
    ]

    # Return subset of dataset
    # NOTE: shuffle must stay False because a sampler is given;
    #  the sampler itself already randomizes the order
    return DataLoader(
        dataset=dataset,
        batch_size=self.dataloader_kwargs["batch_size"],
        num_workers=self.dataloader_kwargs["num_workers"],
        shuffle=False,
        pin_memory=True,
        sampler=th.utils.data.SubsetRandomSampler(indices)
    )

In [1]:
import json
import os
import pathlib
import math
import sys
from typing import Callable, Optional, cast

import torch as th
from torch.utils.data import Dataset
import torchvision as tv


# Type alias for dataset output
#  (origin_raw, origin_noisy, direction_raw, direction_noisy, pixel_color_raw, pixel_color_blur, pixel_relative_blur, image_index)
DatasetOutput = tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]


class ImagePoseDataset(Dataset[DatasetOutput]):
    """Dataset with one item per pixel of every image.

    Each item holds the raw and the artificially perturbed ray (origin and
    direction) of the pixel, its raw and Gaussian-blurred color, the current
    relative blur sigma and the index of the image the pixel belongs to.
    """

    def __init__(
        self,
        image_width: int,
        image_height: int,
        images_path: str,
        camera_info_path: str,
        space_transform_scale: Optional[float]=None,
        space_transform_translate: Optional[th.Tensor]=None,
        rotation_noise_sigma: float=1.0,
        translation_noise_sigma: float=1.0,
        noise_seed: Optional[int]=None,
        gaussian_blur_kernel_size: int=41,
        gaussian_blur_relative_sigma_start: float=0.,
        gaussian_blur_relative_sigma_decay: float=1.
    ) -> None:
        """Loads images, camera info, and generates rays for each pixel in each image.
        
        Args:
            image_width (int): Width to resize images to.
            image_height (int): Height to resize images to.
            images_path (str): Path to the image directory.
            camera_info_path (str): Path to the camera info file.
            space_transform_scale (Optional[float], optional): Scale parameter for the space transform. Defaults to None, which auto-calculates based on max distance.
            space_transform_translate (Optional[th.Tensor], optional): Translation parameter for the space transform. Defaults to None, which auto-calculates the mean.
            rotation_noise_sigma (float, optional): Sigma parameter for the rotation noise in radians. Defaults to 1.0.
            translation_noise_sigma (float, optional): Sigma parameter for the translation noise. Defaults to 1.0.
            noise_seed (Optional[int], optional): Seed for the noise generator. Defaults to None.
            gaussian_blur_kernel_size (int, optional): Size of the gaussian blur kernel. Must be odd. Defaults to 41.
            gaussian_blur_relative_sigma_start (float, optional): Initial sigma parameter for the gaussian blur, relative to the longest image side. Set to 0 to disable. Defaults to 0..
            gaussian_blur_relative_sigma_decay (float, optional): Decay factor for the gaussian blur sigma. Defaults to 1..

        Raises:
            ValueError: If gaussian_blur_kernel_size is even.
        """
        super().__init__()
        # Verify parameters
        # NOTE: The kernel needs a center element, hence the odd size
        if gaussian_blur_kernel_size % 2 == 0:
            raise ValueError("Gaussian blur kernel size must be odd.")


        # Store image dimensions
        self.image_height, self.image_width = image_height, image_width
        self.image_batch_size = self.image_width * self.image_height

        # Store paths
        self.images_path = images_path
        self.camera_info_path = camera_info_path

        # Store noise parameters
        # NOTE: Rotation is measured in radians
        self.rotation_noise_sigma = rotation_noise_sigma
        self.translation_noise_sigma = translation_noise_sigma
        self.noise_seed = noise_seed

        # Store gaussian blur parameters
        self.gaussian_blur_kernel_size = gaussian_blur_kernel_size
        self.gaussian_blur_relative_sigma_start = gaussian_blur_relative_sigma_start
        self.gaussian_blur_relative_sigma_decay = gaussian_blur_relative_sigma_decay
        self.gaussian_blur_relative_sigma_current = self.gaussian_blur_relative_sigma_start


        # Load images
        self.images = self._load_images(
            self.images_path, 
            self.image_width, 
            self.image_height
        )

        # Load camera info
        self.focal_length, self.camera_to_world = self._load_camera_info(
            self.camera_info_path, 
            self.image_width
        )

        # Transform camera to world matrices
        (
            self.camera_to_world, 
            self.space_transform_scale, 
            self.space_transform_translate
        ) = self._transform_camera_to_world(
            self.camera_to_world, 
            space_transform_scale,
            space_transform_translate
        )


        # Get gaussian blur kernel
        self.gaussian_blur_kernel = self._get_gaussian_blur_kernel(
            self.gaussian_blur_kernel_size,
            self.gaussian_blur_relative_sigma_current,
            max(self.image_height, self.image_width)
        )


        # Get raw rays for each pixel in each image
        self.origins_raw, self.directions_raw = self._get_raw_rays(
            self.camera_to_world, 
            self.image_width, 
            self.image_height, 
            self.focal_length
        )
        
        # Get (artificially) noisy rays for each pixel in each image
        self.origins_noisy, self.directions_noisy = self._get_noisy_rays(
            self.origins_raw,
            self.directions_raw,
            self.rotation_noise_sigma,
            self.translation_noise_sigma,
            self.noise_seed
        )


        # Store dataset output
        # NOTE: One entry per image, ordered like the camera to world matrices.
        #  An image without a matching camera entry is ignored,
        #  a camera entry without a matching image raises a KeyError.
        self.dataset = [
            (
                camera_to_world, 
                self.origins_raw[image_name],
                self.origins_noisy[image_name],
                self.directions_raw[image_name], 
                self.directions_noisy[image_name],
                self.images[image_name]
            ) 
            for image_name, camera_to_world in self.camera_to_world.items()
        ]

    def _load_images(self, image_dir_path: str, image_width: int, image_height: int) -> dict[str, th.Tensor]:
        """Reads every file in the image directory as RGBA and converts it to a (H, W, 3) float tensor.

        Args:
            image_dir_path (str): Directory containing the images.
            image_width (int): Width to resize images to.
            image_height (int): Height to resize images to.

        Returns:
            dict[str, th.Tensor]: Images keyed by file name without extension.
        """
        # Transform image to correct format
        transform = cast(
            Callable[[th.Tensor], th.Tensor], 
            tv.transforms.Compose([
                # Convert to float
                tv.transforms.Lambda(lambda img: img.float() / 255.),
                # Resize image
                tv.transforms.Resize(
                    (image_height, image_width), 
                    interpolation=tv.transforms.InterpolationMode.BICUBIC, 
                    antialias=True # type: ignore
                ),
                # Transform alpha to white background (removes alpha too)
                tv.transforms.Lambda(lambda img: img[-1] * img[:3] + (1 - img[-1])),
                # Permute channels to (H, W, C)
                # WARN: This is against the convention of PyTorch.
                #  Doing it to enable easier batching of rays.
                tv.transforms.Lambda(lambda img: img.permute(1, 2, 0))
            ])
        )

        # Open RGBA image
        def read(path: str) -> th.Tensor:
            return tv.io.read_image(
                os.path.join(image_dir_path, path), 
                tv.io.ImageReadMode.RGB_ALPHA
            )

        # Load each image, transform, and store
        return {
            pathlib.PurePath(path).stem: transform(read(path))
            for path in os.listdir(image_dir_path) 
        }

    def _load_camera_info(self, camera_info_path: str, image_width: int) -> tuple[float, dict[str, th.Tensor]]:
        """Reads the camera info file and derives the focal length and the camera to world matrices.

        Args:
            camera_info_path (str): Path to the JSON camera info file.
            image_width (int): Width of the (resized) images in pixels.

        Returns:
            tuple[float, dict[str, th.Tensor]]: Focal length in pixels and the
                camera to world matrices keyed by file name without extension.
        """
        # Read info file
        # NOTE: Context manager makes sure the file handle is closed again
        with open(camera_info_path) as camera_info_file:
            camera_data = json.load(camera_info_file)
        
        # Calculate focal length from camera horizontal angle
        focal_length = image_width / 2 / math.tan(camera_data["camera_angle_x"] / 2)

        # Get camera to world matrices
        # NOTE: Projections are scaled to have scale 1
        # NOTE(review): Each frame is unpacked by position, i.e. it must hold exactly
        #  three values in the order (path, rotation, matrix) - confirm against the data
        camera_to_world: dict[str, th.Tensor] = {}
        for frame in camera_data["frames"]:
            path, _rotation, matrix = frame.values()
            camera_to_world[pathlib.PurePath(path).stem] = th.tensor(matrix) / matrix[-1][-1]

        # Return focal length and camera to world matrices
        return focal_length, camera_to_world

    def _transform_camera_to_world(self, camera_to_world: dict[str, th.Tensor], space_transform_scale: Optional[float], space_transform_translate: Optional[th.Tensor]) -> tuple[dict[str, th.Tensor], float, th.Tensor]:
        """Translates and scales the camera positions of all camera to world matrices.

        Args:
            camera_to_world (dict[str, th.Tensor]): Camera to world matrices keyed by image.
            space_transform_scale (Optional[float]): Divisor for the camera positions. None initializes it to 3 times the maximum distance of any two cameras.
            space_transform_translate (Optional[th.Tensor]): Offset subtracted from the camera positions. None initializes it to the mean camera position.

        Returns:
            tuple[dict[str, th.Tensor], float, th.Tensor]: Transformed matrices
                keyed by image, and the scale and translation that were used.
        """
        # If space transform is not given, initialize transform parameters from data
        # NOTE: Assuming camera_to_world has scale 1
        camera_positions = th.stack(tuple(camera_to_world.values()))[:, :3, -1] 

        # If no scale is given, initialize to 3*the maximum distance of any two cameras
        if space_transform_scale is None:
            space_transform_scale = 3*th.cdist(camera_positions, camera_positions, compute_mode="donot_use_mm_for_euclid_dist").max().item()

        # If no translation is given, initialize to mean
        if space_transform_translate is None:
            space_transform_translate = camera_positions.mean(dim=0)


        # Only move the offset
        # NOTE: Matrix is zero everywhere except for the first three rows of the last column
        translate_matrix = th.cat((space_transform_translate, th.zeros(1))).view(4, 1)
        translate_matrix = th.hstack((th.zeros((4,3)), translate_matrix))

        # Scale the camera distances
        # NOTE: Matrix is one everywhere except for the first three rows of the last column,
        #  so the element-wise division below leaves the rotation untouched
        scale_matrix = th.ones((4, 4))
        scale_matrix[:-1, -1] = space_transform_scale

        # Move origin to average position of all cameras and scale world coordinates by the 3*the maximum distance of any two cameras
        return (
            { 
                image_name: (matrix - translate_matrix)/scale_matrix
                for image_name, matrix 
                in camera_to_world.items() 
            },
            space_transform_scale,
            space_transform_translate
        )

    def _get_raw_rays(self, camera_to_world: dict[str, th.Tensor], image_width: int, image_height: int, focal_length: float) -> tuple[dict[str, th.Tensor], dict[str, th.Tensor]]:
        """Creates one ray per pixel for every camera.

        Args:
            camera_to_world (dict[str, th.Tensor]): Camera to world matrices keyed by image.
            image_width (int): Width of the images in pixels.
            image_height (int): Height of the images in pixels.
            focal_length (float): Focal length in pixels.

        Returns:
            tuple[dict[str, th.Tensor], dict[str, th.Tensor]]: Ray origins and
                unit ray directions keyed by image, each of shape (H, W, 3).
        """
        # Create unit directions (H, W, 3) in camera space
        # NOTE: Initially normalized such that z=-1 via the focal length.
        #  Camera is looking in the negative z direction.
        #  y-axis is also flipped.
        y, x = th.meshgrid(
            -th.linspace(-(image_height-1)/2, (image_height-1)/2, image_height) / focal_length,
            th.linspace(-(image_width-1)/2, (image_width-1)/2, image_width) / focal_length,
            indexing="ij"
        )
        directions = th.stack((x, y, -th.ones_like(x)), dim=-1)
        directions /= th.norm(directions, p=2, dim=-1, keepdim=True)

        # Return rays keyed by image
        return (
            # Origins: Key by image and get focal points directly from camera to world projection
            {
                image_name: matrix[:3, 3].expand_as(directions)
                for image_name, matrix in camera_to_world.items()
            },
            # Directions: Key by image and get directions directly from camera to world projection
            # Rotate directions (H, W, 3) to world via R (3, 3).
            #  Denote dir (row vector) as one of the directions in the directions tensor.
            #  Then R @ dir.T = (dir @ R.T).T. 
            #  This would yield a column vector as output. 
            #  To get a row vector as output again, simply omit the last transpose.
            #  The inside of the parenthesis on the right side 
            #  is conformant for matrix multiplication with the directions tensor.
            # NOTE: Assuming scale of projection matrix is 1
            { 
                image_name: directions @ matrix[:3, :3].T
                for image_name, matrix in camera_to_world.items()
            }
        )

    def _get_noisy_rays(self, origins: dict[str, th.Tensor], directions: dict[str, th.Tensor], rotation_noise_sigma: Optional[float], translation_noise_sigma: Optional[float], noise_seed: Optional[int]) -> tuple[dict[str, th.Tensor], dict[str, th.Tensor]]:
        """Perturbs the rays of every camera with one random translation and one random rotation.

        Args:
            origins (dict[str, th.Tensor]): Ray origins keyed by image.
            directions (dict[str, th.Tensor]): Ray directions keyed by image.
            rotation_noise_sigma (Optional[float]): Sigma of the rotation angle in radians.
            translation_noise_sigma (Optional[float]): Sigma of the translation.
            noise_seed (Optional[int]): Seed for the noise generator. None leaves the generator unseeded.

        Returns:
            tuple[dict[str, th.Tensor], dict[str, th.Tensor]]: Translated origins
                and rotated directions keyed by image.
        """
        # Get amount of cameras
        n_cameras = len(origins)
        
        # Instantiate random number generator
        rng = th.Generator()
        if noise_seed is not None:
            rng.manual_seed(noise_seed)


        # Get random rotation amount in radians
        thetas = th.randn((n_cameras, 1, 1), generator=rng) * rotation_noise_sigma

        # Get random rotation axis
        # NOTE: This is a random point on the unit sphere
        # WARN: Technically a division by zero can happen.
        #  This is however mitigated by applying the Ostrich algorithm :)
        axes = th.randn((n_cameras, 3, 1), generator=rng)
        axes /= th.norm(axes, p=2, dim=1, keepdim=True)
        
        # Get rotation matrices via exponential map from lie algebra so(3) -> SO(3)
        # NOTE: The cross product with the columns of -I builds
        #  the skew-symmetric matrix of each scaled axis
        so3 = th.cross(
            -th.eye(3).view(1, 3, 3), 
            thetas * axes,
            dim=1
        )

        rotations = th.matrix_exp(so3)


        # Get random translation amount
        translations = th.randn((n_cameras, 3), generator=rng) * translation_noise_sigma


        # Return rays keyed by image
        return (
            # Origins: Key by image and move focal point
            {
                image_name: image_origins + trans
                for (image_name, image_origins), trans in zip(origins.items(), translations)
            },
            # Directions: Key by image and get directions directly from camera to world projection
            # Rotate directions (H, W, 3) via R (3, 3).
            #  Denote dir (row vector) as one of the directions in the directions tensor.
            #  Then R @ dir.T = (dir @ R.T).T. 
            #  This would yield a column vector as output. 
            #  To get a row vector as output again, simply omit the last transpose.
            #  The inside of the parenthesis on the right side 
            #  is conformant for matrix multiplication with the directions tensor.
            { 
                image_name: image_directions @ rot.T
                for (image_name, image_directions), rot in zip(directions.items(), rotations)
            }
        )

    def _get_gaussian_blur_kernel(self, kernel_size: int, relative_sigma: float, max_side_length: int) -> th.Tensor:
        """Creates a normalized 1D Gaussian kernel.

        Args:
            kernel_size (int): Number of kernel elements.
            relative_sigma (float): Sigma relative to max_side_length.
            max_side_length (int): Side length in pixels that relative_sigma refers to.

        Returns:
            th.Tensor: Kernel of shape (kernel_size,) that sums to 1.
                A Dirac delta kernel if relative_sigma is (numerically) 0.
        """
        # If sigma is 0, return a Dirac delta kernel
        if relative_sigma <= sys.float_info.epsilon:
            kernel = th.zeros(kernel_size)
            kernel[kernel_size//2] = 1
        # Else, create 1D Gaussian kernel
        # NOTE: Gaussian blur is separable, so 1D kernel can simply be applied twice
        else:
            # Pixel offsets from the kernel center
            # NOTE: The end points are (kernel_size-1)/2 so that
            #  neighboring kernel elements are exactly one pixel apart
            kernel_radius = (kernel_size - 1) / 2
            kernel = th.linspace(-kernel_radius, kernel_radius, kernel_size)
            # Calculate inplace exp(-x^2 / (2 * (relative_sigma*max_side_length)^2))
            kernel.square_().divide_(-2 * (relative_sigma * max_side_length)**2).exp_()
            # Normalize the kernel
            kernel.divide_(kernel.sum())


        return kernel

    def _get_blurred_pixel(self, img: th.Tensor, x: int, y: int, gaussian_blur_kernel: th.Tensor) -> th.Tensor:
        """Calculates the Gaussian-blurred color of a single pixel.

        Args:
            img (th.Tensor): Image of shape (H, W, C).
            x (int): Column of the pixel.
            y (int): Row of the pixel.
            gaussian_blur_kernel (th.Tensor): 1D kernel of odd length.

        Returns:
            th.Tensor: Blurred color of shape (C,).
        """
        # NOTE: Assuming x and y are within bounds of img

        # Retrieve kernel dimensions
        kernel_size = gaussian_blur_kernel.shape[0]
        kernel_half = kernel_size//2

        # Retrieve image dimensions
        img_height, img_width = img.shape[:2]

        # Calculate padding
        # NOTE: Only pad as much as the kernel reaches over the respective border
        left = max(kernel_half - x, 0)
        top = max(kernel_half - y, 0)
        right = max(kernel_half + x - (img_width-1), 0)
        bottom = max(kernel_half + y - (img_height-1), 0)

        pad = tv.transforms.Pad(
            padding=(left, top, right, bottom), 
            padding_mode="reflect"
        )

        # Pad image and retrieve pixel and neighbors
        # NOTE: Padding shifts the pixel by (left, top) in the padded image
        neighborhood: th.Tensor = pad(img.permute(2, 0, 1))[
            :,
            (top+y-kernel_half):(top+y+kernel_half)+1, 
            (left+x-kernel_half):(left+x+kernel_half)+1,
        ].permute(1, 2, 0)


        # Blur y-direction and store y-column of pixel
        # (H, W, C) -> (W, C)
        blurred_y = (neighborhood * gaussian_blur_kernel.view(-1, 1, 1)).sum(dim=0)
        # Blur x-direction and store pixel
        # (W, C) -> (C)
        blurred_pixel = (blurred_y * gaussian_blur_kernel.view(-1, 1)).sum(dim=0)

        # Return blurred pixel
        return blurred_pixel

    def gaussian_blur_step(self) -> None:
        """Decays the current relative blur sigma once and rebuilds the blur kernel."""
        # Update current sigma
        self.gaussian_blur_relative_sigma_current *= self.gaussian_blur_relative_sigma_decay
        # Get new kernel
        self.gaussian_blur_kernel = self._get_gaussian_blur_kernel(
            self.gaussian_blur_kernel_size,
            self.gaussian_blur_relative_sigma_current,
            max(self.image_height, self.image_width)
        )

    def __getitem__(self, index: int) -> DatasetOutput:
        """Gets the rays and colors of a single pixel.

        Args:
            index (int): Flat pixel index. Image `index // image_batch_size`,
                pixel `index % image_batch_size` within that image (row-major).

        Returns:
            DatasetOutput: (origin_raw, origin_noisy, direction_raw, direction_noisy,
                pixel_color_raw, pixel_color_blur, pixel_relative_blur, image_index)
        """
        # Get image index
        img_idx = index // self.image_batch_size

        # Get dataset via image index
        P, o_r, o_n, d_r, d_n, img = self.dataset[img_idx]
        # Get pixel index
        i = index % self.image_batch_size
        
        # Get raw pixel color
        c_r = img.view(-1, 3)[i]

        # If no blur, set color to current pixel
        if self.gaussian_blur_relative_sigma_current <= sys.float_info.epsilon:
            c_b = c_r
        # Else, calculate color via gaussian blur
        else:
            c_b = self._get_blurred_pixel(
                img, 
                i % self.image_width, 
                i // self.image_width, 
                self.gaussian_blur_kernel
            )


        return (
            o_r.view(-1, 3)[i], 
            o_n.view(-1, 3)[i], 
            d_r.view(-1, 3)[i], 
            d_n.view(-1, 3)[i], 
            c_r,
            c_b, 
            th.tensor(self.gaussian_blur_relative_sigma_current), 
            th.tensor(img_idx)
        )

    def __len__(self) -> int:
        """Gets the total amount of pixels over all images."""
        return len(self.dataset) * self.image_batch_size

In [2]:
# Instantiate dataset
# NOTE: The data loader settings are plain constants. The former stray lines of the
#  form `batch_size=BATCH_SIZE,` only created unused 1-tuples and were removed.
BATCH_SIZE = 200
NUM_WORKERS = 1

# Paths use forward slashes regardless of the operating system
DATA_DIR = "../data/lego"
images_path = f"{DATA_DIR}/train"
camera_info_path = f"{DATA_DIR}/transforms_train.json"

dataset = ImagePoseDataset(
    image_width=400,
    image_height=400,
    images_path=images_path,
    camera_info_path=camera_info_path,
    # None lets the dataset derive the space transform from the camera positions
    space_transform_scale=None,
    space_transform_translate=None,
    # Zero sigma: no artificial pose noise
    rotation_noise_sigma=0,
    translation_noise_sigma=0,
    noise_seed=54,
    # Kernel size must be odd
    gaussian_blur_kernel_size=81,
    # NOTE: A start sigma of 0 disables the blur, so the decay has no effect here
    gaussian_blur_relative_sigma_start=0.,
    gaussian_blur_relative_sigma_decay=0.99
)

In [3]:
from torch.utils.data import DataLoader
# Loader settings collected in one place, then handed over in a single call
loader_options = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=True,
    pin_memory=True,
)

# Shuffled loader over every pixel of the dataset
new_ds = DataLoader(dataset, **loader_options)

In [4]:
# Smoke test: pull a single batch from the loader and show it
for batch in new_ds:
    print(batch)
    # One batch is enough to inspect the output, so stop right away
    break