Skip to content

Commit

Permalink
Updating depricated torch.trtrs (#54)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattcleigh committed Dec 4, 2023
1 parent 8272cdc commit 27d4bf8
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions normflows/flows/mixing.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,9 +505,9 @@ def weight_inverse(self):
"""
lower, upper = self._create_lower_upper()
identity = torch.eye(self.features, self.features)
lower_inverse, _ = torch.trtrs(identity, lower, upper=False, unitriangular=True)
weight_inverse, _ = torch.trtrs(
lower_inverse, upper, upper=True, unitriangular=False
lower_inverse = torch.linalg.solve_triangular(lower, identity, upper=False, unitriangular=True)
weight_inverse = torch.linalg.solve_triangular(
upper, lower_inverse, upper=True, unitriangular=False
)
return weight_inverse

Expand Down

0 comments on commit 27d4bf8

Please sign in to comment.