In [1]:
import numpy as np
import taichi as ti
import torch
from taichi.math import uvec3

# Fix the RNG seeds so Restart & Run All reproduces the same hash-table
# initialisation (drawn with ti.random) and the same sample positions (torch.rand).
SEED = 0
torch.manual_seed(SEED)

taichi_init_args = {"arch": ti.cuda, "device_memory_GB": 4.0, "random_seed": SEED}
ti.init(**taichi_init_args)


[Taichi] version 1.4.1, llvm 15.0.4, commit e67c674e, linux, python 3.9.16
[I 03/29/23 20:11:03.458 2437593] [shell.py:_shell_pop_print@23] Graphical python shell detected, using wrapped sys.stdout
[Taichi] Starting on arch=cuda


In [2]:
# Copy a torch tensor into a Taichi field element by element.
# The loop runs over the tensor's index space, so the field must have at least
# the tensor's extent along every axis.
@ti.kernel
def torch2ti(field: ti.template(), data: ti.types.ndarray()):
    for idx in ti.grouped(data):
        field[idx] = data[idx]


# Copy a Taichi field into a torch tensor element by element (inverse of torch2ti).
# Only the tensor's index space is visited, so the field may be larger.
@ti.kernel
def ti2torch(field: ti.template(), data: ti.types.ndarray()):
    for idx in ti.grouped(data):
        data[idx] = field[idx]


# Copy the gradient of a Taichi field (field.grad) into a torch tensor.
# The field must have been created with needs_grad=True.
@ti.kernel
def ti2torch_grad(field: ti.template(), grad: ti.types.ndarray()):
    for idx in ti.grouped(grad):
        grad[idx] = field.grad[idx]


# Seed the gradient of a Taichi field (field.grad) from a torch tensor,
# e.g. to load dL/d(output) before running a kernel's .grad().
@ti.kernel
def torch2ti_grad(field: ti.template(), grad: ti.types.ndarray()):
    for idx in ti.grouped(grad):
        field.grad[idx] = grad[idx]
        
# Overwrite every element of `data` in place with uniform noise in [-1e-4, 1e-4),
# drawn from Taichi's RNG (the write bypasses torch autograd).
@ti.kernel
def random_initialize(data: ti.types.ndarray()):
    for idx in ti.grouped(data):
        data[idx] = 1e-4 * (2.0 * ti.random() - 1.0)


In [3]:
@ti.kernel
def hash_kernel(
        xyzs: ti.template(), table: ti.template(),
        xyzs_embedding: ti.template(), B: ti.i32):
    # Encode the first B rows of `xyzs` with a single-level hash grid and write
    # two trilinearly interpolated features per row into `xyzs_embedding`.
    #
    # xyzs:           field read as [row, 0..2]. Presumably the coordinates are
    #                 in [0, 1]; negative values would wrap when cast to uint32.
    #                 TODO confirm with callers.
    # table:          1-D feature field, 2 consecutive floats per hash slot.
    # xyzs_embedding: field written as [row, 0..1].
    # B:              number of rows to process.

    # get hash table embedding
    for i in ti.ndrange(B):
        xyz = ti.Vector([xyzs[i, 0], xyzs[i, 1], xyzs[i, 2]])
        resolution = 512

        pos = xyz * (resolution - 1) + 0.5
        pos_grid = ti.cast(ti.floor(pos), ti.uint32)
        pos -= pos_grid  # fractional offset inside the cell -> interpolation weights

        # Each slot stores 2 floats, so the slot count is half the table length.
        # Deriving it from the table's compile-time shape keeps index_table + 1
        # in bounds. The former hard-coded 2**19 slots addressed up to 2**20
        # floats in a 2**19-float table.
        map_size = ti.static(table.shape[0] // 2)

        local_feature_0 = 0.0
        local_feature_1 = 0.0

        # Visit the 8 corners of the enclosing cell; bit d of `corner` selects
        # the lower (0) or upper (1) vertex along axis d.
        for corner in ti.static(range(8)):
            w = 1.0
            pos_grid_local = uvec3(0)

            # Linear interpolation
            for d in ti.static(range(3)):
                if (corner & (1 << d)) == 0:
                    pos_grid_local[d] = pos_grid[d]
                    w *= 1 - pos[d]
                else:
                    pos_grid_local[d] = pos_grid[d] + 1
                    w *= pos[d]

            # Hash: XOR of per-axis coordinates times large primes (uint32 wraparound).
            # `axis` is deliberately distinct from the row index `i`.
            _hash_index = ti.uint32(0)
            primes = uvec3(ti.uint32(1), ti.uint32(2654435761), ti.uint32(805459861))
            for axis in ti.static(range(3)):
                _hash_index ^= ti.uint32(pos_grid_local[axis]) * primes[axis]
            index = _hash_index % map_size

            index_table = ti.cast(index * 2, ti.int32)  # each slot consists of two elements
            local_feature_0 += w * table[index_table]
            local_feature_1 += w * table[index_table + 1]

        xyzs_embedding[i, 0] = local_feature_0
        xyzs_embedding[i, 1] = local_feature_1
 

class ToyHashEncoder(torch.nn.Module):
    """Toy single-level hash-grid position encoder backed by Taichi kernels.

    The forward pass copies positions and the hash table into Taichi fields and
    runs ``hash_kernel``. The backward pass runs Taichi's reverse-mode
    ``hash_kernel.grad``. Both are wrapped in a custom ``torch.autograd.Function``.

    Args:
        batch_size: sizes the Taichi staging fields. At most
            ``batch_size * 1024`` positions can be encoded per call.
        out_dim: number of embedding columns allocated. NOTE(review): the
            visible ``hash_kernel`` always writes exactly two columns, so only
            ``out_dim=2`` is meaningful.

    Caveat: the Taichi fields are shared state. ``backward`` reuses whatever the
    most recent ``forward`` copied into them, so do not run a second forward
    before back-propagating through the first.
    """

    def __init__(self, batch_size=8192, out_dim=2):
        super().__init__()

        self.out_dim = out_dim
        self.total_hash_size = 2**19
        print("total_hash_size: ", self.total_hash_size)

        # Maximum number of positions per forward call (rows of the staging fields).
        self.max_num_points = batch_size * 1024

        self.hash_table = torch.nn.Parameter(
            torch.zeros(self.total_hash_size, dtype=torch.float32),
            requires_grad=True,
        )
        # In-place small uniform noise; the Taichi write bypasses autograd.
        random_initialize(self.hash_table)

        self.parameter_fields = ti.field(dtype=ti.f32,
                                         shape=(self.total_hash_size, ),
                                         needs_grad=True)

        self.input_fields = ti.field(dtype=ti.f32,
                                     shape=(self.max_num_points, 3),
                                     needs_grad=True)

        self.output_fields = ti.field(dtype=ti.f32,
                                      shape=(self.max_num_points, self.out_dim),
                                      needs_grad=True)

        self._hash_kernel = hash_kernel

        class _module_function(torch.autograd.Function):

            @staticmethod
            def forward(ctx, input_pos, params):
                # Guard the fixed-size Taichi fields: torch2ti loops over the
                # tensor's indices, so an oversized input would write out of bounds.
                if input_pos.dim() != 2 or input_pos.shape[1] != 3:
                    raise ValueError(
                        f"expected positions of shape (N, 3), got {tuple(input_pos.shape)}")
                if input_pos.shape[0] > self.max_num_points:
                    raise ValueError(
                        f"got {input_pos.shape[0]} positions, but the encoder was "
                        f"allocated for at most {self.max_num_points}")

                ctx.input_size = input_pos.shape
                # Parameter grads must match the parameter's dtype/device,
                # which may differ from the input's.
                ctx.param_dtype = params.dtype
                ctx.param_device = params.device

                output_embedding = torch.zeros(input_pos.shape[0], self.out_dim,
                                               dtype=input_pos.dtype,
                                               device=input_pos.device)

                torch2ti(self.input_fields, input_pos.contiguous())
                torch2ti(self.parameter_fields, params.contiguous())

                self._hash_kernel(
                    self.input_fields,
                    self.parameter_fields,
                    self.output_fields,  # output
                    input_pos.shape[0],
                )
                ti2torch(self.output_fields, output_embedding)

                return output_embedding

            @staticmethod
            def backward(ctx, doutput):
                # hash_kernel.grad adds into the fields' .grad, so start from zero.
                # Only the Taichi grads are cleared here, never the torch ones.
                self._zero_taichi_grads()

                torch2ti_grad(self.output_fields, doutput.contiguous())
                self._hash_kernel.grad(
                    self.input_fields,
                    self.parameter_fields,
                    self.output_fields,
                    doutput.shape[0],
                )

                # Skip copying gradients autograd does not need (e.g. raw inputs).
                input_grad = None
                if ctx.needs_input_grad[0]:
                    input_grad = torch.zeros(*ctx.input_size,
                                             dtype=doutput.dtype,
                                             device=doutput.device)
                    ti2torch_grad(self.input_fields, input_grad)

                hash_grad = None
                if ctx.needs_input_grad[1]:
                    hash_grad = torch.zeros(self.total_hash_size,
                                            dtype=ctx.param_dtype,
                                            device=ctx.param_device)
                    ti2torch_grad(self.parameter_fields, hash_grad)

                return input_grad, hash_grad

        self._module_function = _module_function

    def zero_grad(self, *args, **kwargs):
        """Reset gradients of the PyTorch parameters and of the Taichi fields.

        Extends ``nn.Module.zero_grad`` (arguments such as ``set_to_none`` are
        forwarded unchanged), so ``hash_table.grad`` is cleared as optimizers expect.
        """
        super().zero_grad(*args, **kwargs)
        self._zero_taichi_grads()

    def _zero_taichi_grads(self):
        """Fill the gradient buffers of the input and parameter Taichi fields with 0."""
        self.input_fields.grad.fill(0.)
        self.parameter_fields.grad.fill(0.)

    def forward(self, positions):
        """Encode ``positions`` (N, 3) into an (N, out_dim) embedding."""
        return self._module_function.apply(positions, self.hash_table)

In [4]:
pos_encoder = ToyHashEncoder(batch_size=8192, out_dim=2)  # defaults spelled out

total_hash_size:  524288


In [5]:
num_points = 8192
xyz = torch.rand(num_points, 3)  # uniform samples in [0, 1)^3
h = pos_encoder(xyz)

In [6]:
# Invariant: one embedding row per input point, out_dim columns each.
assert h.shape == (xyz.shape[0], pos_encoder.out_dim), h.shape
print("h shape: ", h.shape)

h shape:  torch.Size([8192, 2])


In [8]:
repeat = 10

# Warm-up iteration outside the profiler. The earlier cells only ran the forward
# pass, so one-off first-call costs of the gradient kernels (hash_kernel.grad,
# torch2ti_grad, ti2torch_grad) are kept out of the timed iterations.
h = pos_encoder(xyz)
((h * h) - torch.tanh(h)).sum().backward()

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    for _ in range(repeat):
        # Drop the previous gradient so every iteration does identical work
        # instead of accumulating into hash_table.grad.
        pos_encoder.hash_table.grad = None
        # check forward
        h = pos_encoder(xyz)
        loss = ((h * h) - torch.tanh(h)).sum()
        # Backward
        loss.backward()

print(
    'pytorch total\n',
    prof.key_averages(group_by_stack_n=5).table(sort_by='self_cuda_time_total', row_limit=5)
)


pytorch total
 -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                               _module_functionBackward        32.40%       9.339ms        56.30%      16.228ms       1.623ms      10.140ms        34.93%      16.274ms       1.627ms            10  
                                       _module_function        33.99%       9.796ms        34.69%       9.998ms     999.800us       9.825ms        33.84%      10.045ms       1.004ms           

STAGE:2023-03-29 20:15:14 2437593:2437593 ActivityProfilerController.cpp:300] Completed Stage: Warm Up
STAGE:2023-03-29 20:15:14 2437593:2437593 ActivityProfilerController.cpp:306] Completed Stage: Collection
STAGE:2023-03-29 20:15:14 2437593:2437593 ActivityProfilerController.cpp:310] Completed Stage: Post Processing
