diff --git a/intermediate_source/neural_tangent_kernels.py b/intermediate_source/neural_tangent_kernels.py index 62a49794af..d70d5c5dca 100644 --- a/intermediate_source/neural_tangent_kernels.py +++ b/intermediate_source/neural_tangent_kernels.py @@ -13,7 +13,7 @@ .. note:: - This tutorial requires PyTorch 2.0.0 or later. + This tutorial requires PyTorch 2.6.0 or later. Setup ----- @@ -24,7 +24,12 @@ import torch import torch.nn as nn from torch.func import functional_call, vmap, vjp, jvp, jacrev -device = 'cuda' if torch.cuda.device_count() > 0 else 'cpu' + +if torch.accelerator.is_available() and torch.accelerator.device_count() > 0: + device = torch.accelerator.current_accelerator() +else: + device = torch.device("cpu") + class CNN(nn.Module): def __init__(self):