diff --git a/fbgemm_gpu/codegen/training/python/split_embedding_optimizer_codegen.template b/fbgemm_gpu/codegen/training/python/split_embedding_optimizer_codegen.template
index 1450d99a82..dc59c8b966 100644
--- a/fbgemm_gpu/codegen/training/python/split_embedding_optimizer_codegen.template
+++ b/fbgemm_gpu/codegen/training/python/split_embedding_optimizer_codegen.template
@@ -21,23 +21,38 @@ from torch.optim.optimizer import Optimizer
 import logging
 
 {%- if is_fbcode %}
+
+def _load_library(unified_path: str, cuda_path: str, hip_path: str) -> None:
+    try:
+        torch.ops.load_library(unified_path)
+    except Exception:
+        # Load the old paths for backwards compatibility
+        if torch.version.hip:
+            torch.ops.load_library(hip_path)
+        else:
+            torch.ops.load_library(cuda_path)
+
 torch.ops.load_library(
     "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cpu_training"
 )
+_load_library(
+    "//deeplearning/fbgemm/fbgemm_gpu/codegen:optimizer_ops",
+    "//deeplearning/fbgemm/fbgemm_gpu/codegen:optimizer_ops_cuda",
+    "//deeplearning/fbgemm/fbgemm_gpu/codegen:optimizer_ops_hip",
+)
+
 
 if torch.version.hip:
     torch.ops.load_library(
         "//deeplearning/fbgemm/fbgemm_gpu:split_table_batched_embeddings_hip"
     )
-    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:optimizer_ops_hip")
     torch.ops.load_library(
         "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_hip_training"
     )
 else:
     torch.ops.load_library(
         "//deeplearning/fbgemm/fbgemm_gpu:split_table_batched_embeddings"
-    )
-    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:optimizer_ops")
+    )
     torch.ops.load_library(
         "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cuda_training"
     )