Closed
Labels
high priority, module: assert failure, module: functorch, needs reproduction, triaged
Description
functorch fails when the function being differentiated manipulates a scalar (0-dim) tensor variable. For example:
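```python
import torch

# A_1_canonical, b_1_canonical, A_2_canonical, b_2_canonical, num_1_points and
# num_2_points are module-level globals defined elsewhere in the script.
def kkt(p1_tilde, p2_tilde, lambda_1_tilde, lambda_2_tilde, q_pos1, q_theta1):
    # 2x2 rotation matrix for the scalar angle q_theta1
    rot_mat = torch.Tensor([[torch.cos(q_theta1), -torch.sin(q_theta1)],
                            [torch.sin(q_theta1), torch.cos(q_theta1)]])
    rot_mat_inv = torch.linalg.inv(rot_mat)
    # box 1 constraints after rotating by q_theta1 and translating by q_pos1
    box1_A = A_1_canonical @ rot_mat_inv
    box1_b = b_1_canonical + A_1_canonical @ rot_mat_inv @ q_pos1
    D = 2
    K = torch.zeros(2*D + num_1_points + num_2_points)
    # Lagrangian gradient w.r.t. p1
    K[:D] = 2 * (p1_tilde - p2_tilde) + box1_A.T @ lambda_1_tilde
    # Lagrangian gradient w.r.t. p2
    K[D:2*D] = -2 * (p1_tilde - p2_tilde) + A_2_canonical.T @ lambda_2_tilde
    # complementary slackness for box 1 and box 2
    K[2*D:2*D+num_1_points] = lambda_1_tilde * (box1_A @ p1_tilde - box1_b)
    K[2*D+num_1_points:2*D+num_1_points+num_2_points] = lambda_2_tilde * (A_2_canonical @ p2_tilde - b_2_canonical)
    return K
```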
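A likely trigger (our assumption, not confirmed in this issue) is the `torch.Tensor([[...]])` call: the legacy constructor copies the values of `q_theta1` into a brand-new tensor, so the result is not connected to the input that `jacrev` is tracking. Building the matrix with `torch.stack` keeps every entry differentiable. The sketch below assumes PyTorch 2.x, where functorch's `jacrev` is available as `torch.func.jacrev`; `rotation_matrix` is a hypothetical helper name.

```python
import torch
from torch.func import jacrev  # functorch.jacrev in older releases


def rotation_matrix(theta: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: build the 2x2 rotation matrix with torch.stack so
    # every entry stays connected to `theta` under the transform.
    c, s = torch.cos(theta), torch.sin(theta)
    return torch.stack([torch.stack([c, -s]), torch.stack([s, c])])


theta = torch.tensor(0.3)
# Jacobian of a (2, 2) output w.r.t. a 0-dim input has shape (2, 2):
# dR/dtheta = [[-sin, -cos], [cos, -sin]]
J = jacrev(rotation_matrix)(theta)
```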
The error in the issue title is triggered by the scalar rotation angle `q_theta1` when the Jacobian is computed with:
```python
from functorch import jacrev  # torch.func.jacrev in PyTorch 2.x

J_q_k_out = jacrev(
    lambda q_pos1, q_theta1: kkt(p1_tilde, p2_tilde, lambda_1_tilde, lambda_2_tilde, q_pos1, q_theta1),
    argnums=(0, 1),
)(q_pos, q_theta)
```
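Since the issue is labelled as needing a reproduction, here is a self-contained sketch of the full call. All problem data (`A_1_canonical`, `b_1_canonical`, `A_2_canonical`, `b_2_canonical`, four constraints per box, random `p1_tilde` and so on) are hypothetical placeholders, not the reporter's values. It builds `rot_mat` with `torch.stack` and assembles `K` with `torch.cat` instead of in-place slice assignment; both are workaround choices on our part, not a confirmed fix.

```python
import torch
from torch.func import jacrev  # assumes PyTorch 2.x

torch.manual_seed(0)

# Hypothetical data: two axis-aligned boxes in 2-D, each A x <= b with 4 rows.
D = 2
num_1_points, num_2_points = 4, 4
A_1_canonical = torch.tensor([[1., 0.], [-1., 0.], [0., 1.], [0., -1.]])
b_1_canonical = torch.ones(num_1_points)
A_2_canonical = A_1_canonical.clone()
b_2_canonical = torch.ones(num_2_points)


def rotation_matrix(theta):
    # Differentiable 2x2 rotation matrix built from a 0-dim angle.
    c, s = torch.cos(theta), torch.sin(theta)
    return torch.stack([torch.stack([c, -s]), torch.stack([s, c])])


def kkt(p1_tilde, p2_tilde, lambda_1_tilde, lambda_2_tilde, q_pos1, q_theta1):
    rot_mat_inv = torch.linalg.inv(rotation_matrix(q_theta1))
    box1_A = A_1_canonical @ rot_mat_inv
    box1_b = b_1_canonical + A_1_canonical @ rot_mat_inv @ q_pos1
    diff = p1_tilde - p2_tilde
    # Concatenate the residual blocks in the same order as the original K.
    return torch.cat([
        2 * diff + box1_A.T @ lambda_1_tilde,                          # grad wrt p1
        -2 * diff + A_2_canonical.T @ lambda_2_tilde,                  # grad wrt p2
        lambda_1_tilde * (box1_A @ p1_tilde - box1_b),                 # slackness, box 1
        lambda_2_tilde * (A_2_canonical @ p2_tilde - b_2_canonical),   # slackness, box 2
    ])


p1_tilde, p2_tilde = torch.randn(D), torch.randn(D)
lambda_1_tilde, lambda_2_tilde = torch.rand(num_1_points), torch.rand(num_2_points)
q_pos, q_theta = torch.randn(D), torch.tensor(0.3)

J_q_pos, J_q_theta = jacrev(
    lambda q_pos1, q_theta1: kkt(p1_tilde, p2_tilde, lambda_1_tilde, lambda_2_tilde, q_pos1, q_theta1),
    argnums=(0, 1),
)(q_pos, q_theta)
```

With `argnums=(0, 1)`, `jacrev` returns a tuple: here the Jacobian with respect to `q_pos1` has shape `(12, 2)` and the one with respect to the scalar `q_theta1` has shape `(12,)`.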
cc @ezyang @gchanan @zou3519 @Chillee @samdow @soumith @kshitij12345 @janeyx99