RuntimeError: unwrapped_count > 0 INTERNAL ASSERT FAILED at "../aten/src/ATen/functorch/TensorWrapper.cpp":181, please report a bug to PyTorch. Should have at least one dead wrapper #98021

@richardrl

Description

functorch fails when the function being differentiated manipulates a scalar variable:

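import torch

def kkt(p1_tilde, p2_tilde, lambda_1_tilde, lambda_2_tilde, q_pos1, q_theta1):
        # A_1_canonical, b_1_canonical, A_2_canonical, b_2_canonical, num_1_points and
        # num_2_points are globals defined elsewhere in the script (not shown in the report)
        # 2x2 rotation matrix built from the angle argument q_theta1
        rot_mat = torch.Tensor([[torch.cos(q_theta1), -torch.sin(q_theta1)],
                                [torch.sin(q_theta1), torch.cos(q_theta1)]])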
        rot_mat_inv = torch.linalg.inv(rot_mat)
        box1_A = A_1_canonical @ rot_mat_inv
        box1_b = b_1_canonical + A_1_canonical @ rot_mat_inv @ q_pos1
        D = 2
        K = torch.zeros(2*D + num_1_points + num_2_points)
        # Lagrangian gradient w.r.t. p1
        K[:D] = 2 * (p1_tilde - p2_tilde) + box1_A.T @ lambda_1_tilde
        # Lagrangian gradient w.r.t. p2
        K[D:2*D] = -2 * (p1_tilde - p2_tilde) + A_2_canonical.T @ lambda_2_tilde
        # complementary slackness for the constraints of box 1 and box 2
        K[2*D:2*D+num_1_points] = lambda_1_tilde * (box1_A @ p1_tilde - box1_b)
        K[2*D+num_1_points:2*D+num_1_points+num_2_points] = lambda_2_tilde * (A_2_canonical @ p2_tilde - b_2_canonical)
        return K
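One plausible trigger, offered as an assumption since the report does not confirm it, is the legacy `torch.Tensor([...])` constructor: it copies the numeric values of `torch.cos(...)` and `torch.sin(...)` into a brand-new tensor instead of recording differentiable ops. The sketch below uses plain autograd (not functorch) and a hypothetical angle `theta` to show the difference from building the same matrix with `torch.stack`:

```python
import torch

# Hypothetical angle; requires_grad so we can see whether the graph survives.
theta = torch.tensor(0.3, requires_grad=True)
c, s = torch.cos(theta), torch.sin(theta)

# Legacy constructor: copies the values, so the result is detached from theta.
R_copied = torch.Tensor([[c, -s], [s, c]])

# Differentiable construction from the same 0-dim tensors.
R_stacked = torch.stack([torch.stack([c, -s]), torch.stack([s, c])])

print(R_copied.requires_grad)   # False
print(R_stacked.requires_grad)  # True
```

Both matrices hold the same numbers; only the `torch.stack` version remains connected to `theta`, which is what a Jacobian with respect to the angle needs.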

The error in the title comes from differentiating with respect to the scalar angle variable theta (`q_theta1`):

J_q_k_out = jacrev(lambda q_pos1, q_theta1: kkt(p1_tilde, p2_tilde, lambda_1_tilde, lambda_2_tilde, q_pos1, q_theta1), argnums=(0, 1))(q_pos, q_theta)
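Below is a sketch of a functional rewrite that may avoid the assert, under several assumptions: the rotation matrix is built with `torch.stack`, the preallocated `K` filled by in-place slice writes is replaced by `torch.cat`, and `jacrev` is imported from `torch.func` (PyTorch 2.0+; older releases expose it as `functorch.jacrev`). The box data (`A_1_canonical`, `b_1_canonical`, `A_2_canonical`, `b_2_canonical`) and all input values are hypothetical, because the report does not include them; this has not been checked against the reporter's actual setup.

```python
import torch
from torch.func import jacrev

# Hypothetical problem data: two unit squares written as A p <= b (4 constraints each).
A_1_canonical = torch.tensor([[1., 0.], [-1., 0.], [0., 1.], [0., -1.]])
b_1_canonical = torch.ones(4)
A_2_canonical = A_1_canonical.clone()
b_2_canonical = torch.ones(4)

def kkt(p1_tilde, p2_tilde, lambda_1_tilde, lambda_2_tilde, q_pos1, q_theta1):
    # Rotation matrix from differentiable ops only (no legacy constructor).
    c, s = torch.cos(q_theta1), torch.sin(q_theta1)
    rot_mat = torch.stack([torch.stack([c, -s]), torch.stack([s, c])])
    rot_mat_inv = torch.linalg.inv(rot_mat)
    box1_A = A_1_canonical @ rot_mat_inv
    box1_b = b_1_canonical + A_1_canonical @ rot_mat_inv @ q_pos1
    # Concatenate the KKT blocks instead of writing into a preallocated tensor.
    return torch.cat([
        2 * (p1_tilde - p2_tilde) + box1_A.T @ lambda_1_tilde,
        -2 * (p1_tilde - p2_tilde) + A_2_canonical.T @ lambda_2_tilde,
        lambda_1_tilde * (box1_A @ p1_tilde - box1_b),
        lambda_2_tilde * (A_2_canonical @ p2_tilde - b_2_canonical),
    ])

# Hypothetical inputs.
p1_tilde = torch.tensor([0.5, 0.0])
p2_tilde = torch.tensor([2.0, 0.0])
lambda_1_tilde = torch.full((4,), 0.5)
lambda_2_tilde = torch.full((4,), 0.5)
q_pos = torch.tensor([0.0, 0.0])
q_theta = torch.tensor(0.1)

J_q_pos, J_q_theta = jacrev(
    lambda q_pos1, q_theta1: kkt(p1_tilde, p2_tilde, lambda_1_tilde,
                                 lambda_2_tilde, q_pos1, q_theta1),
    argnums=(0, 1),
)(q_pos, q_theta)
print(J_q_pos.shape, J_q_theta.shape)  # torch.Size([12, 2]) torch.Size([12])
```

With `argnums=(0, 1)`, `jacrev` returns one Jacobian per argument; the one for the 0-dim `q_theta` has shape `(12,)` because a scalar input adds no trailing dimensions.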

cc @ezyang @gchanan @zou3519 @Chillee @samdow @soumith @kshitij12345 @janeyx99

Metadata

Labels

high priority, module: assert failure, module: functorch, needs reproduction, triaged