Skip to content

Commit

Permalink
fix torch version check for resolve_conj
Browse files Browse the repository at this point in the history
  • Loading branch information
albertfgu committed Mar 18, 2022
1 parent 07a0c61 commit 83a9f13
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/models/sequence/ss/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
_c2r = torch.view_as_real
_r2c = torch.view_as_complex

if torch.__version__.startswith('1.10'):
if tuple(map(int, torch.__version__.split('.')[:2])) >= (1, 10):
_resolve_conj = lambda x: x.conj().resolve_conj()
else:
_resolve_conj = lambda x: x.conj()
Expand Down
2 changes: 1 addition & 1 deletion src/models/sequence/ss/standalone/s4.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def _broadcast_dims(*tensors):
_c2r = torch.view_as_real
_r2c = torch.view_as_complex
_conj = lambda x: torch.cat([x, x.conj()], dim=-1)
if torch.__version__.startswith('1.10'):
if tuple(map(int, torch.__version__.split('.')[:2])) >= (1, 10):
_resolve_conj = lambda x: x.conj().resolve_conj()
else:
_resolve_conj = lambda x: x.conj()
Expand Down

1 comment on commit 83a9f13

@DorotheaKolossa
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you so much for the quick support! This fixed the issue :)

Please sign in to comment.