From ecbb6d0bd88551b8a16890c96ee1cd928cba7d62 Mon Sep 17 00:00:00 2001
From: Justin Chu <justinchuby@users.noreply.github.com>
Date: Tue, 14 Jan 2025 16:05:51 -0800
Subject: [PATCH 1/7] [torchlib] Implement type promotion

---
 .../torch_lib/_type_promotion.py              | 65 +++++++++++++++++++
 1 file changed, 65 insertions(+)
 create mode 100644 onnxscript/function_libs/torch_lib/_type_promotion.py

diff --git a/onnxscript/function_libs/torch_lib/_type_promotion.py b/onnxscript/function_libs/torch_lib/_type_promotion.py
new file mode 100644
index 0000000000..239617c13f
--- /dev/null
+++ b/onnxscript/function_libs/torch_lib/_type_promotion.py
@@ -0,0 +1,65 @@
+"""Type promotion functions for op implementations."""
+
+from typing import Sequence
+from onnxscript import ir
+
+def _get_higher_dtype(a: ir.DataType, b: ir.DataType) -> ir.DataType:
+    """Get the higher dtype of two dtypes."""
+    # Reference: https://github.com/pytorch/pytorch/blob/bdd942efd76e74baa5dd0a262f7c843ddfe2e11b/torch/_prims_common/__init__.py#L1160
+    if a == b:
+        return a
+
+    if a is None:
+        return b
+
+    if b is None:
+        return a
+
+    ordered_datatypes = (
+        (ir.DataType.BOOL,),
+        (ir.DataType.UINT8, ir.DataType.INT8),
+        (ir.DataType.INT16,),
+        (ir.DataType.INT32,),
+        (ir.DataType.INT64,),
+        (ir.DataType.FLOAT16, ir.DataType.BFLOAT16),
+        (ir.DataType.FLOAT,),
+        (ir.DataType.DOUBLE,),
+        (ir.DataType.COMPLEX64,),
+        (ir.DataType.COMPLEX128,),
+    )
+
+    for idx, dtypes in enumerate(ordered_datatypes):
+        if a in dtypes and b in dtypes:
+            return ordered_datatypes[idx + 1][0]
+        if a in dtypes:
+            return b
+        if b in dtypes:
+            return a
+
+    raise ValueError(f"Unexpected data types: {a}, {b}")
+
+
+def promote_types(op, values: Sequence[ir.Value]) -> Sequence[ir.Value]:
+    """Promote the types of the given values."""
+    if not values:
+        return ()
+
+    for value in values:
+        if value.dtype is None:
+            raise ValueError(f"Value {value} does not have dtype information and cannot be promoted.")
+
+    promoted = values[0].dtype
+    assert promoted is not None
+    for value in values[1:]:
+        dtype = value.dtype
+        assert dtype is not None
+        promoted = _get_higher_dtype(promoted, dtype)
+
+    results = []
+    for value in values:
+        if value.dtype != promoted:
+            results.append(op.Cast(value, to=promoted))
+        else:
+            results.append(value)
+
+    return results
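
A quick sanity check of the lattice above (a sketch assuming this new module is importable): dtypes that share a tier, like uint8/int8 or float16/bfloat16, promote to the first dtype of the next tier; otherwise the later tier wins.

    from onnxscript import ir
    from onnxscript.function_libs.torch_lib._type_promotion import _get_higher_dtype

    # Same tier, different dtypes: promote to the first dtype of the next tier.
    assert _get_higher_dtype(ir.DataType.UINT8, ir.DataType.INT8) == ir.DataType.INT16
    assert _get_higher_dtype(ir.DataType.FLOAT16, ir.DataType.BFLOAT16) == ir.DataType.FLOAT

    # Different tiers: the later (wider) tier wins.
    assert _get_higher_dtype(ir.DataType.BOOL, ir.DataType.UINT8) == ir.DataType.UINT8
    assert _get_higher_dtype(ir.DataType.INT64, ir.DataType.FLOAT16) == ir.DataType.FLOAT16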

From b507bf52a0f8eb127b9eb986e7b1fc8696c52698 Mon Sep 17 00:00:00 2001
From: Justin Chu <justinchuby@users.noreply.github.com>
Date: Tue, 14 Jan 2025 16:26:27 -0800
Subject: [PATCH 2/7] Apply promotion in add, addcdiv, addcmul, and addr

---
 .../function_libs/torch_lib/_type_promotion.py   |  5 ++++-
 onnxscript/function_libs/torch_lib/ops/core.py   | 16 +++++++++++-----
 2 files changed, 15 insertions(+), 6 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/_type_promotion.py b/onnxscript/function_libs/torch_lib/_type_promotion.py
index 239617c13f..28ab83e64e 100644
--- a/onnxscript/function_libs/torch_lib/_type_promotion.py
+++ b/onnxscript/function_libs/torch_lib/_type_promotion.py
@@ -58,7 +58,10 @@ def promote_types(op, values: Sequence[ir.Value]) -> Sequence[ir.Value]:
     results = []
     for value in values:
         if value.dtype != promoted:
-            results.append(op.Cast(value, to=promoted))
+            new_val = op.Cast(value, to=promoted)
+            new_val.dtype = promoted
+            new_val.shape = value.shape
+            results.append(new_val)
         else:
             results.append(value)
 
diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index a1793858e9..97cfd315b8 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -48,6 +48,7 @@
     TTensor2,
     TTensorOrString,
 )
+from onnxscript.function_libs.torch_lib import _type_promotion
 from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import TensorType
 
@@ -160,9 +161,9 @@ def aten_acosh(self: TFloat) -> TFloat:
 
 
 @torch_op(("aten::add.Tensor", "aten::add.Scalar", "_operator::add"), trace_only=True)
-def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
+def aten_add(self: TTensor, other: TTensor2, alpha: float = 1.0) -> TensorType:
     """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
-    # TODO(microsoft/onnxruntime#15977): Improve fp16 precision
+    self, other = _type_promotion.promote_types(op, [self, other])
     if alpha != 1.0:
         alpha = op.CastLike(alpha, other)
         other = op.Mul(other, alpha)
@@ -175,6 +176,7 @@ def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
 def aten_add_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
 
+    # TODO(justinchuby): Type promotion for complex numbers
     return aten_add(self, other, alpha=alpha)
 
 
@@ -203,12 +205,14 @@ def aten_addbmm(
 
 
 @torch_op("aten::addcdiv")
-def aten_addcdiv(self: TFloat, tensor1: TFloat, tensor2: TFloat, value: float = 1.0) -> TFloat:
+def aten_addcdiv(self: TensorType, tensor1: TensorType, tensor2: TensorType, value: float = 1.0) -> TensorType:
     """addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor
 
     Performs the element-wise division of tensor1 by tensor2, multiplies the result
     by the scalar value and adds it to self.
     """
+    # FIXME: Int to float
+    self, tensor1, tensor2 = _type_promotion.promote_types(op, [self, tensor1, tensor2])
 
     return op.Add(self, op.Mul(op.Div(tensor1, tensor2), value))
 
@@ -225,6 +229,7 @@ def aten_addcmul(
     Performs the element-wise multiplication of tensor1 by tensor2, multiplies the
     result by the scalar value and adds it to self.
     """
+    self, tensor1, tensor2 = _type_promotion.promote_types(op, [self, tensor1, tensor2])
 
     # Follow the order in https://github.com/pytorch/pytorch/blob/29e3fddb082b5a14262a7246bc62381a55199d45/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp#L47
     # TODO(#811): Understand fp16 accuracy issue
@@ -258,12 +263,13 @@ def aten_addmv(
 
 @torch_op("aten::addr", traceable=True)
 def aten_addr(
-    self: TReal, vec1: TReal, vec2: TReal, beta: float = 1.0, alpha: float = 1.0
-) -> TReal:
+    self: TensorType, vec1: TensorType, vec2: TensorType, beta: float = 1.0, alpha: float = 1.0
+) -> TensorType:
     """addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
 
     Performs the outer-product of vectors vec1 and vec2 and adds it to the matrix input.
     """
+    self, vec1, vec2 = _type_promotion.promote_types(op, [self, vec1, vec2])
     vec1_shape = op.Constant(value_ints=[-1, 1])
     vec2_shape = op.Constant(value_ints=[1, -1])
     vec1_reshaped = op.Reshape(vec1, vec1_shape)

From 1033ac50c1ce9a22991bd76dd840da34cd0ee310 Mon Sep 17 00:00:00 2001
From: Justin Chu <justinchuby@users.noreply.github.com>
Date: Wed, 15 Jan 2025 12:48:14 -0800
Subject: [PATCH 3/7] Mark addcdiv, addcmul, and addr trace_only

---
 onnxscript/function_libs/torch_lib/ops/core.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 97cfd315b8..afebde5ad2 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -204,7 +204,7 @@ def aten_addbmm(
     return op.Add(scaled_self, op.Mul(reduced_batches, alpha))
 
 
-@torch_op("aten::addcdiv")
+@torch_op("aten::addcdiv", trace_only=True)
 def aten_addcdiv(self: TensorType, tensor1: TensorType, tensor2: TensorType, value: float = 1.0) -> TensorType:
     """addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor
 
@@ -217,7 +217,7 @@ def aten_addcdiv(self: TensorType, tensor1: TensorType, tensor2: TensorType, val
     return op.Add(self, op.Mul(op.Div(tensor1, tensor2), value))
 
 
-@torch_op("aten::addcmul")
+@torch_op("aten::addcmul", trace_only=True)
 def aten_addcmul(
     self: TReal,
     tensor1: TReal,
@@ -261,7 +261,7 @@ def aten_addmv(
     return op.Add(op.Mul(self, beta), op.Mul(op.MatMul(mat, vec), alpha))
 
 
-@torch_op("aten::addr", traceable=True)
+@torch_op("aten::addr", trace_only=True)
 def aten_addr(
     self: TensorType, vec1: TensorType, vec2: TensorType, beta: float = 1.0, alpha: float = 1.0
 ) -> TensorType:
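
The switch to trace_only matters because promote_types inspects `.dtype` on the incoming ir.Value objects, which only works when the function body executes as ordinary Python during tracing. The ValueError added in patch 1 guards against inputs that arrive without type information; for example:

    from onnxscript import ir
    from onnxscript.function_libs.torch_lib import _type_promotion

    untyped = ir.Value(name="v")  # no dtype metadata attached
    try:
        _type_promotion.promote_types(None, [untyped])  # raises before op is used
    except ValueError as e:
        print(e)  # Value ... does not have dtype information and cannot be promoted.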

From eeeae08698ff72e501253df737293e3d1ee0fe53 Mon Sep 17 00:00:00 2001
From: Justin Chu <justinchuby@users.noreply.github.com>
Date: Wed, 15 Jan 2025 12:48:37 -0800
Subject: [PATCH 4/7] format

---
 onnxscript/function_libs/torch_lib/ops/core.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index afebde5ad2..f3130e82cc 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -205,7 +205,9 @@ def aten_addbmm(
 
 
 @torch_op("aten::addcdiv", trace_only=True)
-def aten_addcdiv(self: TensorType, tensor1: TensorType, tensor2: TensorType, value: float = 1.0) -> TensorType:
+def aten_addcdiv(
+    self: TensorType, tensor1: TensorType, tensor2: TensorType, value: float = 1.0
+) -> TensorType:
     """addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor
 
     Performs the element-wise division of tensor1 by tensor2, multiplies the result
@@ -218,12 +220,7 @@ def aten_addcdiv(self: TensorType, tensor1: TensorType, tensor2: TensorType, val
 
 
 @torch_op("aten::addcmul", trace_only=True)
-def aten_addcmul(
-    self: TReal,
-    tensor1: TReal,
-    tensor2: TReal,
-    value: float = 1.0,
-) -> TReal:
+def aten_addcmul(self: TReal, tensor1: TReal, tensor2: TReal, value: float = 1.0) -> TReal:
     """addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor
 
     Performs the element-wise multiplication of tensor1 by tensor2, multiplies the

From c6e300c00f943b71cae5f64e0878941bb9fecff9 Mon Sep 17 00:00:00 2001
From: Justin Chu <justinchuby@users.noreply.github.com>
Date: Wed, 15 Jan 2025 12:48:58 -0800
Subject: [PATCH 5/7] Use TensorType in the aten_addcmul signature

---
 onnxscript/function_libs/torch_lib/ops/core.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index f3130e82cc..14aa2dbd89 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -220,7 +220,9 @@ def aten_addcdiv(
 
 
 @torch_op("aten::addcmul", trace_only=True)
-def aten_addcmul(self: TReal, tensor1: TReal, tensor2: TReal, value: float = 1.0) -> TReal:
+def aten_addcmul(
+    self: TensorType, tensor1: TensorType, tensor2: TensorType, value: float = 1.0
+) -> TensorType:
     """addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor
 
     Performs the element-wise multiplication of tensor1 by tensor2, multiplies the

From fd915c1495600fb68f6a0f840b03b8a08ace03be Mon Sep 17 00:00:00 2001
From: Justin Chu <justinchuby@users.noreply.github.com>
Date: Wed, 15 Jan 2025 12:53:23 -0800
Subject: [PATCH 6/7] Skip scaling when value == 1.0 in addcdiv and addcmul

---
 onnxscript/function_libs/torch_lib/ops/core.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 14aa2dbd89..90f2d719de 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -213,10 +213,15 @@ def aten_addcdiv(
     Performs the element-wise division of tensor1 by tensor2, multiplies the result
     by the scalar value and adds it to self.
     """
-    # FIXME: Int to float
+    # FIXME(justinchuby): Int to float promotion
     self, tensor1, tensor2 = _type_promotion.promote_types(op, [self, tensor1, tensor2])
+    quotient = op.Div(tensor1, tensor2)
+    if value == 1.0:
+        quotient_scaled = quotient
+    else:
+        quotient_scaled = op.Mul(quotient, op.CastLike(value, tensor1))
 
-    return op.Add(self, op.Mul(op.Div(tensor1, tensor2), value))
+    return op.Add(self, quotient_scaled)
 
 
 @torch_op("aten::addcmul", trace_only=True)
@@ -231,8 +236,11 @@ def aten_addcmul(
     self, tensor1, tensor2 = _type_promotion.promote_types(op, [self, tensor1, tensor2])
 
     # Follow the order in https://github.com/pytorch/pytorch/blob/29e3fddb082b5a14262a7246bc62381a55199d45/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp#L47
-    # TODO(#811): Understand fp16 accuracy issue
-    return op.Add(self, op.Mul(op.Mul(value, tensor1), tensor2))
+    if value == 1.0:
+        tensor_1_scaled = tensor1
+    else:
+        tensor_1_scaled = op.Mul(op.CastLike(value, tensor1), tensor1)
+    return op.Add(self, op.Mul(tensor_1_scaled, tensor2))
 
 
 @torch_op("aten::addmm", trace_only=True)

From 5c45f70fa826fdc80006685907a332cceb1d283c Mon Sep 17 00:00:00 2001
From: Justin Chu <justinchuby@users.noreply.github.com>
Date: Tue, 28 Jan 2025 16:50:12 -0800
Subject: [PATCH 7/7] Use TensorType in aten_add and promote complex add

---
 onnxscript/function_libs/torch_lib/ops/core.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 5d67b67363..55731600bd 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -161,7 +161,7 @@ def aten_acosh(self: TFloat) -> TFloat:
 
 
 @torch_op(("aten::add.Tensor", "aten::add.Scalar", "_operator::add"), trace_only=True)
-def aten_add(self: TTensor, other: TTensor2, alpha: float = 1.0) -> TensorType:
+def aten_add(self: TensorType, other: TensorType, alpha: float = 1.0) -> TensorType:
     """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
     self, other = _type_promotion.promote_types(op, [self, other])
     if alpha != 1.0:
@@ -171,10 +171,10 @@ def aten_add(self: TTensor, other: TTensor2, alpha: float = 1.0) -> TensorType:
 
 
 @torch_op(("aten::add.Tensor", "aten::add.Scalar"), trace_only=True, complex=True)
-def aten_add_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
+def aten_add_complex(self: TensorType, other: TensorType, alpha: float = 1.0) -> TensorType:
     """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
 
-    # TODO(justinchuby): Type promotion for complex numbers
+    self, other = _type_promotion.promote_types(op, [self, other])
     return aten_add(self, other, alpha=alpha)
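
With the complex variant now calling promote_types as well, mixed real/complex dtypes resolve to the complex one under the lattice from patch 1 (a quick check of the helper, under the same import assumption as above):

    from onnxscript import ir
    from onnxscript.function_libs.torch_lib._type_promotion import _get_higher_dtype

    assert _get_higher_dtype(ir.DataType.FLOAT, ir.DataType.COMPLEX64) == ir.DataType.COMPLEX64
    assert _get_higher_dtype(ir.DataType.COMPLEX64, ir.DataType.COMPLEX128) == ir.DataType.COMPLEX128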