@@ -7,16 +7,17 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_device_type import (
     dtypes,
     dtypesIfCUDA,
+    dtypesIfXPU,
     instantiate_device_type_tests,
     largeTensorTest,
-    onlyCUDA,
     onlyNativeDeviceTypes,
+    onlyOn,
     skipCUDAIf,
     skipMeta,
+    skipXPUIf,
     TEST_WITH_ROCM,
 )
 from torch.testing._internal.common_nn import NNTestCase
@@ -29,14 +30,22 @@
     run_tests,
     set_default_dtype,
     skipIfTorchDynamo,
+    TEST_CUDA,
+    TEST_XPU,
+)
+
+
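+# Resolve the test device once: the active accelerator backend ("cuda" or "xpu") if one is available, else "cpu".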
+device_type = (
+    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
 )
 
 
 class TestEmbeddingNN(NNTestCase):
     _do_cuda_memory_leak_check = True
     _do_cuda_non_default_stream = True
 
-    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "CUDA/XPU unavailable")
     def test_embedding_max_norm_unsorted_repeating_indices(self):
         def create_embedding(device):
             # Seed RNG so we get the same Embedding each time
@@ -48,8 +56,8 @@ def create_embedding(device):
         ix = torch.arange(2, device="cpu", dtype=torch.long).repeat(2000)
         out_cpu = create_embedding("cpu")(ix)
 
-        ix = ix.to("cuda")
-        out = create_embedding("cuda")(ix)
+        ix = ix.to(device_type)
+        out = create_embedding(device_type)(ix)
         self.assertEqual(out.cpu(), out_cpu)
 
     def test_embedding_sparse_basic(self):
@@ -81,9 +89,9 @@ def test_move_sparse_half_embedding(self):
         self.assertEqual(embedding.embedding_dim, 3)
         self.assertEqual(embedding.num_embeddings, 10)
 
-        if torch.cuda.is_available():
-            embedding.to("cuda")
-            self.assertEqual(embedding.weight.device.type, "cuda")
+        if torch.accelerator.is_available():
+            embedding.to(device_type)
+            self.assertEqual(embedding.weight.device.type, device_type)
         embedding.to("cpu")
         self.assertEqual(embedding.weight.device.type, "cpu")
 
@@ -182,11 +190,11 @@ def test_embedding_functional(self):
         self.assertEqual(res_old, res_F)
 
     # https://github.com/pytorch/pytorch/issues/130806
-    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
-    @largeTensorTest("40GB", device="cuda")
+    @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "CUDA/XPU not available")
+    @largeTensorTest("40GB", device=device_type)
     def test_large_tensors(self):
-        input = torch.randint(low=0, high=16032, size=[131072], device="cuda")
-        w = torch.randn([16032, 16384], device="cuda")
+        input = torch.randint(low=0, high=16032, size=[131072], device=device_type)
+        w = torch.randn([16032, 16384], device=device_type)
         out = torch.nn.functional.embedding(input, w)
         self.assertEqual(out.dim(), 2)
         self.assertEqual(out.numel(), 2147483648)
@@ -308,6 +316,7 @@ def test_embedding_scalar_weight_error(self, device):
             torch.nn.functional.embedding(indices, weight)
 
     @dtypesIfCUDA(torch.float16, torch.float64)
+    @dtypesIfXPU(torch.float16, torch.float64)
     @dtypes(torch.float64)
     def test_embedding_backward(self, device, dtype):
         embedding = nn.Embedding(10, 3, sparse=True)
@@ -348,6 +357,7 @@ def test_embedding_backward(self, device, dtype):
             else (torch.float, torch.double, torch.half)
         )
     )
+    @dtypesIfXPU(torch.float32, torch.double, torch.half)
     @dtypes(torch.float32)
     def test_embedding_max_norm_backward(self, device, dtype):
         # can't use gradcheck since in place renorm makes analytical gradients different from produced ones
@@ -372,6 +382,7 @@ def test_embedding_max_norm_backward(self, device, dtype):
             else (torch.float, torch.double, torch.half)
         )
     )
+    @dtypesIfXPU(torch.float32, torch.double, torch.half)
     @dtypes(torch.float32)
     def test_embedding_max_norm_fwd_AD(self, device, dtype):
         if torch.device(device).type == "xla":
@@ -396,6 +407,7 @@ def test_embedding_max_norm_fwd_AD(self, device, dtype):
             else (torch.float, torch.double, torch.half)
         )
     )
+    @dtypesIfXPU(torch.float32, torch.double, torch.half)
     @dtypes(torch.float32)
     def test_embedding_padding_idx(self, device, dtype):
         embedding = nn.Embedding(10, 20, padding_idx=0).to(device, dtype)
@@ -488,6 +500,7 @@ def test_embedding_padding_idx(self, device, dtype):
     @onlyNativeDeviceTypes
     @dtypes(torch.float32, torch.float64)
     @dtypesIfCUDA(torch.half, torch.bfloat16)
+    @dtypesIfXPU(torch.half, torch.bfloat16)
     def test_embedding_bag_1D_padding_idx(self, device, dtype):
         num_features = 3
         max_indices_per_bag = 10
@@ -632,11 +645,13 @@ def gen_2D_indices_from_1D(
             weights.grad, weights_check.grad, msg=msg, atol=atol, rtol=rtol
         )
 
-    @onlyCUDA
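+    # Run this test only on the CUDA and XPU backends.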
+    @onlyOn(["cuda", "xpu"])
     @dtypes(
         torch.bfloat16,
     )
     @largeTensorTest("80GB", device="cuda")
+    @largeTensorTest("80GB", device="xpu")
     def test_embedding_backward_large_batch_overflow(self, device, dtype):
         """
         Test that embedding_dense_backward handles large batches that exceed INT32_MAX thread IDs.
@@ -708,6 +722,7 @@ def test_embedding_backward_large_batch_overflow(self, device, dtype):
     @onlyNativeDeviceTypes
     @dtypes(torch.float32, torch.float64)
     @dtypesIfCUDA(torch.half, torch.bfloat16)
+    @dtypesIfXPU(torch.half, torch.bfloat16)
     def test_embedding_bag_2D_padding_idx(self, device, dtype):
         # Use a Python implementation of embedding_bag with padding_idx support
         # to check torch.nn.functional.embedding_bag correctness
@@ -818,7 +833,7 @@ def embedding_bag_check(indices, weights, mode, sparse, padding_idx):
             rtol = None
         self.assertEqual(grad, grad_check, msg=msg, atol=atol, rtol=rtol)
 
-    @onlyCUDA
+    @onlyOn(["cuda", "xpu"])
     @dtypes(
         *(
             (torch.float, torch.double, torch.bfloat16, torch.half)
@@ -854,6 +869,7 @@ def test_embedding_bag_empty_input(self, device, dtypes):
         self.assertEqual(output, torch.zeros_like(output))
 
     @skipCUDAIf(True, "no out-of-bounds check on CUDA for perf.")
+    @skipXPUIf(True, "no out-of-bounds check on XPU for perf.")
     @dtypes(*itertools.product((torch.float, torch.double), (torch.int, torch.long)))
     @parametrize_test("padding_idx", [None, 0])
     @parametrize_test("mode", ["sum", "mean", "max"])
@@ -1066,6 +1082,13 @@ def _embedding_bag_reference_impl(
             (torch.float, torch.double, torch.half),
         )
     )
+    @dtypesIfXPU(
+        *itertools.product(
+            (torch.int, torch.long),
+            (torch.int, torch.long),
+            (torch.float32, torch.double, torch.half),
+        )
+    )
     def test_EmbeddingBag_empty_per_sample_weights_and_offsets(self, device, dtypes):
         # Test empty input and per sample weight, and backward pass. There was a CUDA
         # invalid configuration bug (more context in #46572)
@@ -1132,6 +1155,13 @@ def test_per_sample_weights(mode, trainable_scale):
             (torch.float, torch.double, torch.half),
         )
     )
+    @dtypesIfXPU(
+        *itertools.product(
+            (torch.int, torch.long),
+            (torch.int, torch.long),
+            (torch.float32, torch.double, torch.half),
+        )
+    )
     def test_EmbeddingBag_per_sample_weights_and_offsets(self, device, dtypes):
         def test_per_sample_weights(mode, trainable_scale):
             es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtypes[2], device=device)
@@ -1193,6 +1223,13 @@ def test_per_sample_weights(mode, trainable_scale):
             (torch.float, torch.double, torch.half),
         )
     )
+    @dtypesIfXPU(
+        *itertools.product(
+            (torch.int, torch.long),
+            (torch.int, torch.long),
+            (torch.float32, torch.double, torch.half),
+        )
+    )
     def test_EmbeddingBag_per_sample_weights_and_new_offsets(self, device, dtypes):
         def test_per_sample_weights_new_offsets(
             mode, trainable_scale, include_last_offset, has_weight=True
@@ -1357,6 +1394,11 @@ def _test_EmbeddingBag_vs_Embedding(
             (torch.int, torch.long), (torch.half, torch.float, torch.double)
         )
     )
+    @dtypesIfXPU(
+        *itertools.product(
+            (torch.int, torch.long), (torch.half, torch.float32, torch.double)
+        )
+    )
     @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double)))
     def test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device, dtypes):
         def run_tests(mode, sparse, trainable_per_sample_weights):
@@ -1390,8 +1432,8 @@ def run_tests(mode, sparse, trainable_per_sample_weights):
         ):
             run_tests(mode, sparse, trainable_per_sample_weights)
 
-        # Test CUDA Dense on half precision
-        if device == "cuda":
+        # Test CUDA/XPU Dense on half precision
+        if device != "cpu":
             modes = ("sum",)
             sparsity = (False,)
             trainable_scale = (True, False)
@@ -1552,9 +1594,19 @@ def _test_EmbeddingBag(
             (torch.float, torch.double, torch.half),
         )
     )
+    @dtypesIfXPU(
+        *itertools.product(
+            (torch.int, torch.long),
+            (torch.int, torch.long),
+            (torch.float32, torch.double, torch.half),
+        )
+    )
     def test_embedding_bag_device(self, device, dtypes):
         if IS_JETSON and torch.bfloat16 in dtypes and device == "cpu":
             self.skipTest("bfloat16 not supported with Jetson cpu")
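+        # float64 is skipped on XPU pending the torch-xpu-ops issue linked below.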
+        if dtypes[2] == torch.float64 and "xpu" in device:
+            self.skipTest("https://github.com/intel/torch-xpu-ops/issues/2295")
         with set_default_dtype(torch.double):
             self._test_EmbeddingBag(
                 device,
@@ -1582,10 +1633,10 @@ def test_embedding_bag_device(self, device, dtypes):
             )
 
         test_backward = False
-        if self.device_type == "cuda":
+        if self.device_type != "cpu":
             # see 'todo' in test_embedding_bag.
             test_backward = dtypes[2] is not torch.float16
-        elif self.device_type == "cpu":
+        else:
             # TODO: figure out why precision on sparse embeddings isn't the
             # same as for dense.
             test_backward = (
@@ -1626,6 +1677,13 @@ def test_embedding_bag_device(self, device, dtypes):
             (torch.float, torch.double, torch.half),
         )
     )
+    @dtypesIfXPU(
+        *itertools.product(
+            (torch.int, torch.long),
+            (torch.int, torch.long),
+            (torch.float32, torch.double, torch.half),
+        )
+    )
     def test_embedding_bag_non_contiguous_weight(self, device, dtypes):
         weight_tensor = torch.randn(3, 4, dtype=dtypes[2], device=device)
 
@@ -1703,7 +1761,8 @@ def test_embedding_bag_per_sample_weights_grad(
             bag(x, per_sample_weights=F.softmax(w, dim=-1))
 
 
-instantiate_device_type_tests(TestEmbeddingNNDeviceType, globals())
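+# allow_xpu=True also instantiates the XPU variants of these device-generic tests.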
+instantiate_device_type_tests(TestEmbeddingNNDeviceType, globals(), allow_xpu=True)
 instantiate_parametrized_tests(TestEmbeddingNN)
 
 if __name__ == "__main__":