Skip to content

Commit

Permalink
Merge pull request #16 from PhanTask/patch-2
Browse files Browse the repository at this point in the history
Fix torchvision version check
  • Loading branch information
timmeinhardt committed Sep 12, 2023
2 parents e1dbc25 + 19ae9b7 commit e468bf1
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/trackformer/util/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torch import Tensor
from visdom import Visdom

if float(torchvision.__version__[:3]) < 0.7:
if int(torchvision.__version__.split('.')[0]) <= 0 and int(torchvision.__version__.split('.')[1]) < 7:
from torchvision.ops import _new_empty_tensor
from torchvision.ops.misc import _output_size

Expand Down Expand Up @@ -470,7 +470,7 @@ def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corne
This will eventually be supported natively by PyTorch, and this
class can go away.
"""
if float(torchvision.__version__[:3]) < 0.7:
if int(torchvision.__version__.split('.')[0]) <= 0 and int(torchvision.__version__.split('.')[1]) < 7:
if input.numel() > 0:
return torch.nn.functional.interpolate(
input, size, scale_factor, mode, align_corners
Expand Down

0 comments on commit e468bf1

Please sign in to comment.