From e622560106c1540ea5a5f3637f1dde4eeb87eefa Mon Sep 17 00:00:00 2001 From: jafraustro Date: Thu, 18 Sep 2025 15:15:43 -0700 Subject: [PATCH] Add Accelerator API Signed-off-by: jafraustro --- intermediate_source/neural_tangent_kernels.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/intermediate_source/neural_tangent_kernels.py b/intermediate_source/neural_tangent_kernels.py index 62a49794af5..d70d5c5dca3 100644 --- a/intermediate_source/neural_tangent_kernels.py +++ b/intermediate_source/neural_tangent_kernels.py @@ -13,7 +13,7 @@ .. note:: - This tutorial requires PyTorch 2.0.0 or later. + This tutorial requires PyTorch 2.6.0 or later. Setup ----- @@ -24,7 +24,12 @@ import torch import torch.nn as nn from torch.func import functional_call, vmap, vjp, jvp, jacrev -device = 'cuda' if torch.cuda.device_count() > 0 else 'cpu' + +if torch.accelerator.is_available() and torch.accelerator.device_count() > 0: + device = torch.accelerator.current_accelerator() +else: + device = torch.device("cpu") + class CNN(nn.Module): def __init__(self):