
Autodiff runtime Error #85

Open
Percyx0313 opened this issue Oct 7, 2023 · 0 comments
Percyx0313 commented Oct 7, 2023

I want to change the kernel function at the ray_marching step, because I need the gradient with respect to xyzs.
I reimplemented it the same way as the volume-rendering renderer, but I get the runtime error below.

[screenshot: runtime error traceback]
```python
class RayMarchingRenderer(torch.nn.Module):

    def __init__(self):
        super().__init__()

        # Taichi ray-marching kernel defined elsewhere in the repo
        self._raymarching_rendering_kernel = raymarching_train_kernel

        class _module_function(torch.autograd.Function):

            @staticmethod
            def forward(
                    ctx,
                    rays_o,
                    rays_d,
                    hits_t,
                    density_bitfield,
                    cascades,
                    scale,
                    exp_step_factor,
                    grid_size,
                    max_samples
                ):
                noise = torch.rand_like(rays_o[:, 0])
                counter = torch.zeros(
                    2,
                    device=rays_o.device,
                    dtype=torch.int32
                )
                rays_a = torch.empty(
                    rays_o.shape[0], 3,
                    device=rays_o.device,
                    dtype=torch.int32,
                )
                # torch_type is the float dtype used throughout the project
                xyzs = torch.empty(
                    rays_o.shape[0] * max_samples, 3,
                    device=rays_o.device,
                    dtype=torch_type,
                    requires_grad=True
                )
                dirs = torch.empty(
                    rays_o.shape[0] * max_samples, 3,
                    device=rays_o.device,
                    dtype=torch_type,
                    requires_grad=True
                )
                deltas = torch.empty(
                    rays_o.shape[0] * max_samples,
                    device=rays_o.device,
                    dtype=torch_type,
                )
                ts = torch.empty(
                    rays_o.shape[0] * max_samples,
                    device=rays_o.device,
                    dtype=torch_type,
                )

                raymarching_train_kernel(
                    rays_o,
                    rays_d,
                    hits_t,
                    density_bitfield,
                    noise,
                    counter,
                    rays_a,
                    xyzs,
                    dirs,
                    deltas,
                    ts,
                    cascades, grid_size, scale,
                    exp_step_factor, max_samples
                )

                # total samples for all rays
                total_samples = counter[0]
                # remove redundant output
                xyzs = xyzs[:total_samples]
                dirs = dirs[:total_samples]
                deltas = deltas[:total_samples]
                ts = ts[:total_samples]

                ctx.save_for_backward(
                    rays_o,
                    rays_d,
                    hits_t,
                    density_bitfield,
                    noise,
                    counter,
                    rays_a,
                    xyzs,
                    dirs,
                    deltas,
                    ts,
                )
                ctx.cascades = cascades
                ctx.grid_size = grid_size
                ctx.scale = scale
                ctx.exp_step_factor = exp_step_factor
                ctx.max_samples = max_samples
                return rays_a, xyzs, dirs, deltas, ts, total_samples

            @staticmethod
            def backward(
                    ctx,
                    dL_drays_a,
                    dL_dxyzs,
                    dL_ddirs,
                    dL_ddeltas,
                    dL_dts,
                    dL_dtotal_samples
                ):
                cascades = ctx.cascades
                grid_size = ctx.grid_size
                scale = ctx.scale
                exp_step_factor = ctx.exp_step_factor
                max_samples = ctx.max_samples
                (
                    rays_o,
                    rays_d,
                    hits_t,
                    density_bitfield,
                    noise,
                    counter,
                    rays_a,
                    xyzs,
                    dirs,
                    deltas,
                    ts,
                ) = ctx.saved_tensors
                # put the gradients into the tensors before calling the grad kernel
                rays_a.grad = dL_drays_a
                xyzs.grad = dL_dxyzs
                dirs.grad = dL_ddirs
                deltas.grad = dL_ddeltas
                ts.grad = dL_dts
                # total_samples.grad = dL_dtotal_samples

                # self is captured from the enclosing __init__ via closure
                self._raymarching_rendering_kernel.grad(
                    rays_o,
                    rays_d,
                    hits_t,
                    density_bitfield,
                    noise,
                    counter,
                    rays_a,
                    xyzs,
                    dirs,
                    deltas,
                    ts,
                    cascades, grid_size, scale,
                    exp_step_factor, max_samples
                )

                return rays_o.grad, rays_d.grad, None, None, None, None, None, xyzs.grad, dirs.grad, deltas.grad, ts.grad, None, None, None, None, None

        self._module_function = _module_function.apply

    def forward(
            self,
            rays_o,
            rays_d,
            hits_t,
            density_bitfield,
            cascades,
            scale,
            exp_step_factor,
            grid_size,
            max_samples
        ):
        return self._module_function(
            rays_o.contiguous(),
            rays_d.contiguous(),
            hits_t.contiguous(),
            density_bitfield,
            cascades,
            scale,
            exp_step_factor,
            grid_size,
            max_samples
        )

```
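
For reference, this is the `torch.autograd.Function` contract I am trying to follow, as a minimal self-contained sketch (a toy `y = a * x` function, not the actual ray-marching kernel; `_Scale` is a made-up name for illustration): `backward` has to return exactly one value per input of `forward` (excluding `ctx`), with `None` for the non-differentiable arguments.

```python
import torch

class _Scale(torch.autograd.Function):
    # Toy custom Function: y = a * x, with a passed as a plain float.

    @staticmethod
    def forward(ctx, x, a):
        ctx.a = a
        return a * x

    @staticmethod
    def backward(ctx, dL_dy):
        # forward() took two inputs (x, a), so backward() must return
        # exactly two values: a gradient tensor for x, and None for the
        # non-differentiable float a.
        return ctx.a * dL_dy, None

x = torch.ones(3, requires_grad=True)
y = _Scale.apply(x, 2.0)
y.sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
```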
