Drop support for PyTorch<2.0 (#3272)
ordabayevy committed Oct 4, 2023
1 parent 01ccf36 commit fa73d9c
Showing 31 changed files with 57 additions and 97 deletions.
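Nearly every file below applies the same migration: the pre-2.0 `torch.set_default_tensor_type` idiom, which coupled dtype and device in a single tensor-type string, is replaced by the PyTorch 2.0 pair `torch.set_default_dtype` and `torch.set_default_device`. A minimal before/after sketch of the pattern (the CUDA guard is illustrative, not taken from the diff):

```python
import torch

# Before (PyTorch < 2.0): one call selected dtype and device together.
#   torch.set_default_tensor_type(torch.cuda.DoubleTensor)

# After (PyTorch >= 2.0): the two concerns are configured independently.
torch.set_default_dtype(torch.float64)
if torch.cuda.is_available():
    torch.set_default_device("cuda")

x = torch.zeros(3)  # float64, placed on the GPU when one is available
```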
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -44,7 +44,7 @@ To run a single test from the command line

```sh
pytest -vs {path_to_test}::{test_name}
# or in cuda mode
-CUDA_TEST=1 PYRO_TENSOR_TYPE=torch.cuda.DoubleTensor pytest -vs {path_to_test}::{test_name}
+CUDA_TEST=1 PYRO_DTYPE=float64 PYRO_DEVICE=cuda pytest -vs {path_to_test}::{test_name}
```

To ensure documentation builds correctly, run
4 changes: 2 additions & 2 deletions Makefile
@@ -69,11 +69,11 @@ test-all: lint FORCE
	| xargs pytest -vx --nbval-lax

test-cuda: lint FORCE
-	CUDA_TEST=1 PYRO_TENSOR_TYPE=torch.cuda.DoubleTensor pytest -vx --stage unit
+	CUDA_TEST=1 PYRO_DTYPE=float64 PYRO_DEVICE=cuda pytest -vx --stage unit
	CUDA_TEST=1 pytest -vx tests/test_examples.py::test_cuda

test-cuda-lax: lint FORCE
-	CUDA_TEST=1 PYRO_TENSOR_TYPE=torch.cuda.DoubleTensor pytest -vx --stage unit --lax
+	CUDA_TEST=1 PYRO_DTYPE=float64 PYRO_DEVICE=cuda pytest -vx --stage unit --lax
	CUDA_TEST=1 pytest -vx tests/test_examples.py::test_cuda

test-jit: FORCE
2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -223,6 +223,6 @@ def setup(app):
if "READTHEDOCS" in os.environ:
    os.system("pip install numpy")
    os.system(
-        "pip install torch==1.11.0+cpu torchvision==0.12.0+cpu "
+        "pip install torch==2.0+cpu torchvision==0.15.0+cpu "
        "-f https://download.pytorch.org/whl/torch_stable.html"
    )
2 changes: 1 addition & 1 deletion examples/baseball.py
@@ -418,6 +418,6 @@ def main(args):
    torch.multiprocessing.set_sharing_strategy("file_system")

    if args.cuda:
-        torch.set_default_tensor_type(torch.cuda.FloatTensor)
+        torch.set_default_device("cuda")

    main(args)
2 changes: 1 addition & 1 deletion examples/contrib/cevae/synthetic.py
@@ -45,7 +45,7 @@ def generate_data(args):

def main(args):
    if args.cuda:
-        torch.set_default_tensor_type("torch.cuda.FloatTensor")
+        torch.set_default_device("cuda")

    # Generate synthetic data.
    pyro.set_rng_seed(args.seed)
9 changes: 3 additions & 6 deletions examples/contrib/epidemiology/regional.py
@@ -205,12 +205,9 @@ def main(args):
    if args.warmup_steps is None:
        args.warmup_steps = args.num_samples
    if args.double:
-        if args.cuda:
-            torch.set_default_tensor_type(torch.cuda.DoubleTensor)
-        else:
-            torch.set_default_dtype(torch.float64)
-    elif args.cuda:
-        torch.set_default_tensor_type(torch.cuda.FloatTensor)
+        torch.set_default_dtype(torch.float64)
+    if args.cuda:
+        torch.set_default_device("cuda")

    main(args)
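`sir.py` and `sir_hmc.py` below receive the identical rewrite, and the mue scripts apply the same split with float64 made unconditional: each CLI flag now maps to exactly one global setting. A condensed sketch of the new flag handling (the parser is illustrative):

```python
import argparse

import torch

parser = argparse.ArgumentParser()
parser.add_argument("--double", action="store_true")
parser.add_argument("--cuda", action="store_true")
args = parser.parse_args(["--double"])  # e.g. float64 on the CPU

if args.double:
    torch.set_default_dtype(torch.float64)  # dtype no longer depends on --cuda
if args.cuda:
    torch.set_default_device("cuda")  # device no longer depends on --double
```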
9 changes: 3 additions & 6 deletions examples/contrib/epidemiology/sir.py
@@ -391,12 +391,9 @@ def main(args):
    if args.warmup_steps is None:
        args.warmup_steps = args.num_samples
    if args.double:
-        if args.cuda:
-            torch.set_default_tensor_type(torch.cuda.DoubleTensor)
-        else:
-            torch.set_default_dtype(torch.float64)
-    elif args.cuda:
-        torch.set_default_tensor_type(torch.cuda.FloatTensor)
+        torch.set_default_dtype(torch.float64)
+    if args.cuda:
+        torch.set_default_device("cuda")

    main(args)
2 changes: 1 addition & 1 deletion examples/contrib/funsor/hmm.py
@@ -670,7 +670,7 @@ def model_7(sequences, lengths, args, batch_size=None, include_prior=True):

def main(args):
    if args.cuda:
-        torch.set_default_tensor_type("torch.cuda.FloatTensor")
+        torch.set_default_device("cuda")

    logging.info("Loading data")
    data = poly.load_data(poly.JSB_CHORALES)
5 changes: 2 additions & 3 deletions examples/contrib/mue/FactorMuE.py
@@ -427,9 +427,8 @@ def main(args):
    )
    args = parser.parse_args()

+    torch.set_default_dtype(torch.float64)
    if args.cuda:
-        torch.set_default_tensor_type(torch.cuda.DoubleTensor)
-    else:
-        torch.set_default_dtype(torch.float64)
+        torch.set_default_device("cuda")

    main(args)
5 changes: 2 additions & 3 deletions examples/contrib/mue/ProfileHMM.py
@@ -316,9 +316,8 @@ def main(args):
    )
    args = parser.parse_args()

+    torch.set_default_dtype(torch.float64)
    if args.cuda:
-        torch.set_default_tensor_type(torch.cuda.DoubleTensor)
-    else:
-        torch.set_default_dtype(torch.float64)
+        torch.set_default_device("cuda")

    main(args)
4 changes: 1 addition & 3 deletions examples/einsum.py
@@ -174,9 +174,7 @@ def time_fn(fn, equation, *operands, **kwargs):

def main(args):
    if args.cuda:
-        torch.set_default_tensor_type("torch.cuda.FloatTensor")
-    else:
-        torch.set_default_tensor_type("torch.FloatTensor")
+        torch.set_default_device("cuda")

    if args.method == "all":
        for method in ["prob", "logprob", "gradient", "marginal", "map", "sample"]:
2 changes: 1 addition & 1 deletion examples/hmm.py
@@ -620,7 +620,7 @@ def model_7(sequences, lengths, args, batch_size=None, include_prior=True):

def main(args):
    if args.cuda:
-        torch.set_default_tensor_type("torch.cuda.FloatTensor")
+        torch.set_default_device("cuda")

    logging.info("Loading data")
    data = poly.load_data(poly.JSB_CHORALES)
9 changes: 3 additions & 6 deletions examples/sir_hmc.py
@@ -663,12 +663,9 @@ def main(args):
    args = parser.parse_args()

    if args.double:
-        if args.cuda:
-            torch.set_default_tensor_type(torch.cuda.DoubleTensor)
-        else:
-            torch.set_default_tensor_type(torch.DoubleTensor)
-    elif args.cuda:
-        torch.set_default_tensor_type(torch.cuda.FloatTensor)
+        torch.set_default_dtype(torch.float64)
+    if args.cuda:
+        torch.set_default_device("cuda")

    main(args)
2 changes: 1 addition & 1 deletion examples/sparse_gamma_def.py
@@ -31,7 +31,7 @@
from pyro.infer import SVI, TraceMeanField_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal, init_to_feasible

-torch.set_default_tensor_type("torch.FloatTensor")
+torch.set_default_dtype(torch.float32)
pyro.util.set_rng_seed(0)
2 changes: 1 addition & 1 deletion examples/sparse_regression.py
@@ -41,7 +41,7 @@
"""

-torch.set_default_tensor_type("torch.FloatTensor")
+torch.set_default_dtype(torch.float32)

def dot(X, Z):
2 changes: 1 addition & 1 deletion examples/svi_horovod.py
@@ -78,7 +78,7 @@ def main(args):
        if args.cuda:
            torch.cuda.set_device(hvd.local_rank())
    if args.cuda:
-        torch.set_default_tensor_type("torch.cuda.FloatTensor")
+        torch.set_default_device("cuda")
    device = torch.tensor(0).device

    if args.horovod:
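The retained `device = torch.tensor(0).device` line shows why the new call suffices here: once a default device is set, any factory call reveals it, so the script can recover a concrete `torch.device` without branching. A small sketch of that probe (using the "meta" device so it runs without a GPU; the script itself uses "cuda"):

```python
import torch

# With no override, the probe reports PyTorch's built-in default.
assert torch.tensor(0).device.type == "cpu"

# After an override, the same probe reflects the new default.
torch.set_default_device("meta")
assert torch.tensor(0).device.type == "meta"
torch.set_default_device("cpu")  # restore for anything that runs later
```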
2 changes: 1 addition & 1 deletion examples/svi_lightning.py
@@ -15,7 +15,7 @@

import argparse

-import pytorch_lightning as pl
+import lightning.pytorch as pl
import torch

import pyro
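The `pytorch_lightning` distribution has been absorbed into the unified `lightning` package, and aliasing the new module path keeps the rest of the script untouched. A minimal sketch, assuming lightning >= 2.0 is installed:

```python
import lightning.pytorch as pl  # formerly: import pytorch_lightning as pl

# The training API is unchanged by the rename.
trainer = pl.Trainer(max_epochs=1, accelerator="cpu", logger=False)
```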
2 changes: 1 addition & 1 deletion profiler/gaussianhmm.py
@@ -21,7 +21,7 @@ def random_mvn(batch_shape, dim, requires_grad=False):

def main(args):
    if args.cuda:
-        torch.set_default_tensor_type("torch.cuda.FloatTensor")
+        torch.set_default_device("cuda")

    hidden_dim = args.hidden_dim
    obs_dim = args.obs_dim
2 changes: 1 addition & 1 deletion pyro/contrib/gp/parameterized.py
@@ -82,7 +82,7 @@ class Parameterized(PyroModule):
    >>> assert "b_scale_unconstrained" in dict(linear.named_parameters())

    Note that by default, data of a parameter is a float :class:`torch.Tensor`
-    (unless we use :func:`torch.set_default_tensor_type` to change default
+    (unless we use :func:`torch.set_default_dtype` to change default
    tensor type). To cast these parameters to a correct data type or GPU device,
    we can call methods such as :meth:`~torch.nn.Module.double` or
    :meth:`~torch.nn.Module.cuda`. See :class:`torch.nn.Module` for more
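As the corrected docstring notes, per-module casting still goes through the standard `torch.nn.Module` methods rather than the global defaults. A brief sketch (the `Linear` layer is illustrative):

```python
import torch

layer = torch.nn.Linear(3, 2)  # parameters default to float32
layer = layer.double()  # cast all parameters to float64
assert layer.weight.dtype == torch.float64
# layer = layer.cuda()  # likewise moves parameters to the GPU (CUDA build)
```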
6 changes: 4 additions & 2 deletions pyro/infer/mcmc/api.py
@@ -107,13 +107,15 @@ def __init__(
        self.rng_seed = (torch.initial_seed() + chain_id) % MAX_SEED
        self.log_queue = log_queue
        self.result_queue = result_queue
-        self.default_tensor_type = torch.Tensor().type()
+        self.default_dtype = torch.Tensor().dtype
+        self.default_device = torch.Tensor().device
        self.hook = hook
        self.event = event

    def run(self, *args, **kwargs):
        pyro.set_rng_seed(self.rng_seed)
-        torch.set_default_tensor_type(self.default_tensor_type)
+        torch.set_default_dtype(self.default_dtype)
+        torch.set_default_device(self.default_device)
        kwargs = kwargs
        logger = logging.getLogger("pyro.infer.mcmc")
        logger_id = "CHAIN:{}".format(self.chain_id)
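The pattern here is capture then restore: a throwaway `torch.Tensor()` snapshots the parent process's ambient defaults at construction time, and each spawned chain reinstates them before sampling, since child processes start from PyTorch's built-in defaults. A condensed sketch of the idea (the class is illustrative, not Pyro's actual worker):

```python
import torch

class ChainWorker:
    """Stand-in for a per-chain MCMC worker."""

    def __init__(self):
        # Runs in the parent: snapshot the defaults in effect right now.
        probe = torch.Tensor()
        self.default_dtype = probe.dtype
        self.default_device = probe.device

    def run(self):
        # Runs in the child: reinstate the parent's defaults explicitly.
        torch.set_default_dtype(self.default_dtype)
        torch.set_default_device(self.default_device)
```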
10 changes: 0 additions & 10 deletions pyro/ops/provenance.py
@@ -93,11 +93,6 @@ def _track_provenance_set(x, provenance: frozenset):
@track_provenance.register(tuple)
@track_provenance.register(dict)
def _track_provenance_pytree(x, provenance: frozenset):
-    # avoid max-recursion depth error for torch<=2.0
-    flat_args, _ = tree_flatten(x)
-    if not flat_args or flat_args[0] is x:
-        return x
-
    return tree_map(partial(track_provenance, provenance=provenance), x)

@@ -143,11 +138,6 @@ def _extract_provenance_set(x):
@extract_provenance.register(tuple)
@extract_provenance.register(dict)
def _extract_provenance_pytree(x):
-    # avoid max-recursion depth error for torch<=2.0
-    flat_args, _ = tree_flatten(x)
-    if not flat_args or flat_args[0] is x:
-        return x, frozenset()
-
    flat_args, spec = tree_flatten(x)
    xs = []
    provenance = frozenset()
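The deleted guard worked around a `tree_flatten` quirk in older torch where a container could come back as its own leaf and send `tree_map` into unbounded recursion; with torch >= 2.0 as the floor, tuples and dicts flatten reliably. A small sketch of the utilities involved (the private `torch.utils._pytree` module, which the source file itself uses):

```python
import torch
from torch.utils._pytree import tree_flatten, tree_map

tree = {"a": torch.zeros(2), "b": (torch.ones(1), torch.ones(3))}

leaves, spec = tree_flatten(tree)  # three tensor leaves plus a structure spec
doubled = tree_map(lambda t: t * 2, tree)  # same nesting, each leaf transformed
assert doubled["b"][0].item() == 2.0
```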
6 changes: 1 addition & 5 deletions pyro/optim/pytorch_optimizers.py
@@ -34,11 +34,7 @@
del _PyroOptim

# Load all schedulers from PyTorch
-# breaking change in torch >= 1.14: LRScheduler is new base class
-if hasattr(torch.optim.lr_scheduler, "LRScheduler"):
-    _torch_scheduler_base = torch.optim.lr_scheduler.LRScheduler  # type: ignore
-else:  # for torch < 1.13, _LRScheduler is base class
-    _torch_scheduler_base = torch.optim.lr_scheduler._LRScheduler  # type: ignore
+_torch_scheduler_base = torch.optim.lr_scheduler.LRScheduler  # type: ignore

for _name, _Optim in torch.optim.lr_scheduler.__dict__.items():
    if not isinstance(_Optim, type):
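With torch >= 2.0 guaranteed, the public `LRScheduler` base class always exists, so the `hasattr` fallback to the private `_LRScheduler` name is dead code. A quick sketch of the discovery loop this assignment feeds (a simplified version of the scan below it):

```python
import torch

base = torch.optim.lr_scheduler.LRScheduler  # public in torch >= 2.0
schedulers = [
    obj
    for obj in vars(torch.optim.lr_scheduler).values()
    if isinstance(obj, type) and issubclass(obj, base) and obj is not base
]
assert torch.optim.lr_scheduler.StepLR in schedulers
```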
8 changes: 4 additions & 4 deletions setup.py
@@ -68,10 +68,10 @@
    "jupyter>=1.0.0",
    "graphviz>=0.8",
    "matplotlib>=1.3",
-    "torchvision>=0.12.0",
+    "torchvision>=0.15.0",
    "visdom>=0.1.4,<0.2.2",  # FIXME visdom.utils is unavailable >=0.2.2
    "pandas",
-    "pillow==8.2.0",  # https://github.com/pytorch/pytorch/issues/61125
+    "pillow>=8.3.1",  # https://github.com/pytorch/pytorch/issues/61125
    "scikit-learn",
    "seaborn>=0.11.0",
    "wget",

@@ -102,7 +102,7 @@
        "numpy>=1.7",
        "opt_einsum>=2.3.2",
        "pyro-api>=0.1.1",
-        "torch>=1.11.0",
+        "torch>=2.0",
        "tqdm>=4.36",
    ],
    extras_require={

@@ -135,7 +135,7 @@
        "yapf",
    ],
    "horovod": ["horovod[pytorch]>=0.19"],
-    "lightning": ["pytorch_lightning"],
+    "lightning": ["lightning"],
    "funsor": [
        # This must be a released version when Pyro is released.
        # "funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@7bb52d0eae3046d08a20d1b288544e1a21b4f461",
23 changes: 3 additions & 20 deletions tests/common.py
@@ -69,11 +69,11 @@ def wrapper(*args, **kwargs):
)

try:
-    import pytorch_lightning
+    import lightning
except ImportError:
-    pytorch_lightning = None
+    lightning = None
requires_lightning = pytest.mark.skipif(
-    pytorch_lightning is None, reason="pytorch lightning is not available"
+    lightning is None, reason="pytorch lightning is not available"
)

try:

@@ -93,23 +93,6 @@ def get_gpu_type(t):
    return getattr(torch.cuda, t.__name__)

-@contextlib.contextmanager
-def tensors_default_to(host):
-    """
-    Context manager to temporarily use Cpu or Cuda tensors in PyTorch.
-
-    :param str host: Either "cuda" or "cpu".
-    """
-    assert host in ("cpu", "cuda"), host
-    old_module, name = torch.Tensor().type().rsplit(".", 1)
-    new_module = "torch.cuda" if host == "cuda" else "torch"
-    torch.set_default_tensor_type("{}.{}".format(new_module, name))
-    try:
-        yield
-    finally:
-        torch.set_default_tensor_type("{}.{}".format(old_module, name))

@contextlib.contextmanager
def default_dtype(dtype):
    """
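The removed `tensors_default_to` helper gets no one-for-one replacement; tests now rely on the defaults set once in `conftest.py`. Where a scoped switch is still wanted, torch >= 2.0 also lets a `torch.device` act as a context manager, sketched here with the "meta" device so it runs without a GPU (this usage is a suggestion, not part of the commit):

```python
import torch

# torch >= 2.0: a torch.device scopes the default device for a block.
with torch.device("meta"):
    inside = torch.zeros(3)
assert inside.device.type == "meta"

outside = torch.zeros(3)  # back to the ordinary default
assert outside.device.type == "cpu"
```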
4 changes: 3 additions & 1 deletion tests/conftest.py
@@ -9,7 +9,9 @@

import pyro

-torch.set_default_tensor_type(os.environ.get("PYRO_TENSOR_TYPE", "torch.DoubleTensor"))
+DTYPE = getattr(torch, os.environ.get("PYRO_DTYPE", "float64"))
+torch.set_default_dtype(DTYPE)
+torch.set_default_device(os.environ.get("PYRO_DEVICE", "cpu"))

def pytest_configure(config):
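This is the hook that makes the new `PYRO_DTYPE` and `PYRO_DEVICE` variables from CONTRIBUTING.md and the Makefile take effect: the dtype string is resolved against the `torch` namespace at collection time, while the device string is passed through verbatim. A sketch of the resolution (values shown are the documented defaults):

```python
import os

import torch

# "float64" -> torch.float64; an unknown name would raise AttributeError.
dtype = getattr(torch, os.environ.get("PYRO_DTYPE", "float64"))
assert dtype is torch.float64

torch.set_default_dtype(dtype)
torch.set_default_device(os.environ.get("PYRO_DEVICE", "cpu"))
```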
2 changes: 1 addition & 1 deletion tests/contrib/timeseries/test_gp.py
@@ -40,7 +40,7 @@
)
@pytest.mark.parametrize("T", [11, 37])
def test_timeseries_models(model, nu_statedim, obs_dim, T):
-    torch.set_default_tensor_type("torch.DoubleTensor")
+    torch.set_default_dtype(torch.float64)
    dt = 0.1 + torch.rand(1).item()

    if model == "lcmgp":
2 changes: 1 addition & 1 deletion tests/contrib/timeseries/test_lgssm.py
@@ -13,7 +13,7 @@
@pytest.mark.parametrize("obs_dim", [2, 4])
@pytest.mark.parametrize("T", [11, 17])
def test_generic_lgssm_forecast(model_class, state_dim, obs_dim, T):
-    torch.set_default_tensor_type("torch.DoubleTensor")
+    torch.set_default_dtype(torch.float64)

    if model_class == "lgssm":
        model = GenericLGSSM(
