diff --git a/test/test_operations.py b/test/test_operations.py index f95c387f4357..6adeee09aacf 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -1374,6 +1374,19 @@ def test_fn(t, c): ), dtype=torch.int64) self.runAtenTest([token_type_ids, cat_ids], test_fn) + def test_one_hot_no_fallback(self): + + def test_fn(t): + met.clear_all() + res = F.one_hot(t, num_classes=5) + # make sure there is no graph break + assert 'aten::' not in met.short_metrics_report() + return res + + t1 = torch.arange(0, 5) % 3 + + self.runAtenTest([t1], test_fn) + @skipIfFunctionalizationEnabled("views do not exist") def test_save_view_alias_check(self):