diff --git a/torch_xla/csrc/dtype.cpp b/torch_xla/csrc/dtype.cpp index 918c0ba6515d..103630bcec8f 100644 --- a/torch_xla/csrc/dtype.cpp +++ b/torch_xla/csrc/dtype.cpp @@ -80,7 +80,8 @@ bool IsTpuDevice(XlaDeviceType hw_type) { (hw_type == XlaDeviceType::SPMD) && // HACK: find a better way to decide if SPMD is actually a TPU without // accessing the runtime. - runtime::sys_util::GetEnvString("PJRT_DEVICE", "") == "TPU"; + runtime::sys_util::GetEnvString("PJRT_DEVICE", "").find("TPU") != + std::string::npos; return (hw_type == XlaDeviceType::TPU) || spmd_device_is_tpu; }