# `torch2jax` - demonstrating input/output shape specification

In [1]:
import torch
import jax

from torch2jax import tree_t2j, torch2jax_with_vjp
def count_format(i):
    """Format the integer ``i`` as an English ordinal string.

    Examples: ``1 -> "1st"``, ``2 -> "2nd"``, ``3 -> "3rd"``, ``4 -> "4th"``,
    ``11 -> "11th"``, ``21 -> "21st"``, ``0 -> "0th"``.

    Args:
        i: the integer count to format.

    Returns:
        ``str(i)`` followed by the matching ordinal suffix.
    """
    magnitude = abs(i)
    # 11, 12 and 13 (also 111, 112, ...) take "th" even though they end in 1, 2, 3.
    if 11 <= magnitude % 100 <= 13:
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(magnitude % 10, "th")
    return f"{i}{suffix}"

### Single-output function

In [2]:
# single output example ############################################################################
# Global tally of how many times `fn` has been evaluated so far.
call_count = 0


def fn(a, b):
    """Return ``a + b + 1`` elementwise and report each evaluation on stdout.

    Every call increments the module-level ``call_count`` and prints which
    evaluation this is (1st, 2nd, ...), so it is visible when the PyTorch
    function actually gets run.

    Args:
        a: torch tensor.
        b: torch tensor that can be added to ``a``.

    Returns:
        The tensor ``a + b + torch.ones_like(a)``.
    """
    global call_count
    call_count = call_count + 1
    print(f"I get evaluated {count_format(call_count)} time")
    pair_sum = a + b
    return pair_sum + torch.ones_like(a)


# Example inputs: two random length-10 torch vectors and their JAX counterparts (via tree_t2j).
a_torch, b_torch = torch.randn(10), torch.randn(10)
a_jax, b_jax = tree_t2j((a_torch, b_torch))

####################################################################################################

# Case 1. we do not specify output_shapes
# torch2jax will call the PyTorch function with the example arguments and infer the output_shapes
# from the output
#
# torch2jax DOES call the PyTorch function with example arguments
# (the recorded output shows two evaluations for this case, but only one for every other case)
print("Case 1.")
jax_fn = torch2jax_with_vjp(fn, a_torch, b_torch)

# this would not work because torch2jax would try to call the PyTorch function with jax arguments
# jax_fn = torch2jax_with_vjp(fn, a_jax, b_jax)

print(jax_fn(a_jax, b_jax))
# reset the counter so the next case reports its evaluations starting from "1st" again
call_count = 0

####################################################################################################

# Case 2. we specify output_shapes, but not dtypes
# torch2jax will guess that the dtype of every output is the same as the floating-point dtype of
# one of the inputs
#
# torch2jax DOES NOT call the PyTorch function with example arguments
#
# NOTE: unlike the other cases, the label below is printed after the wrapper is built; the
# recorded output shows nothing printed before "Case 2.", so the order is not observable here
output_shapes = torch.Size(a_jax.shape)
jax_fn = torch2jax_with_vjp(fn, a_torch, b_torch, output_shapes=output_shapes)
print("\n\nCase 2.")
print(jax_fn(a_jax, b_jax))
# reset the counter for the next case
call_count = 0

####################################################################################################
# Case 3. we specify both output_shapes and dtypes
# torch2jax does not have to guess anything; the example arguments only need to carry
# information about shapes and dtypes
#
# torch2jax DOES NOT call the PyTorch function with example arguments

# Case 3. (a) passing torch example arguments - works!
print("\n\nCase 3. (a)")
# a ShapeDtypeStruct carries both the shape and the dtype of the single output
output_shapes = jax.ShapeDtypeStruct(a_jax.shape, a_jax.dtype)
jax_fn = torch2jax_with_vjp(fn, a_torch, b_torch, output_shapes=output_shapes)
print(jax_fn(a_jax, b_jax))
# reset the counter for the next case
call_count = 0

# Case 3. (b) passing jax example arguments - works! (because jax arguments have
# both a dtype and a shape)
print("\n\nCase 3. (b)")
jax_fn = torch2jax_with_vjp(fn, a_jax, b_jax, output_shapes=output_shapes)
print(jax_fn(a_jax, b_jax))
# reset the counter for the next case
call_count = 0

# Case 3. (c) passing shape and dtype structs - works! (because shape and dtype
# structs have both a dtype and a shape)
print("\n\nCase 3. (c)")
# describe the inputs purely by shape and dtype; no concrete example data is passed
a_shape = jax.ShapeDtypeStruct(a_jax.shape, a_jax.dtype)
b_shape = jax.ShapeDtypeStruct(b_jax.shape, b_jax.dtype)
jax_fn = torch2jax_with_vjp(fn, a_shape, b_shape, output_shapes=output_shapes)
print(jax_fn(a_jax, b_jax))
# leave the counter at 0 once this cell finishes
call_count = 0

Case 1.
I get evaluated 1st time
I get evaluated 2nd time
[ 1.4628059  2.9096096 -0.5187819  4.7829256 -0.576247   2.7845862
  1.3225893 -1.0953598 -1.8378775  1.493313 ]


Case 2.
I get evaluated 1st time
[ 1.4628059  2.9096096 -0.5187819  4.7829256 -0.576247   2.7845862
  1.3225893 -1.0953598 -1.8378775  1.493313 ]


Case 3. (a)
I get evaluated 1st time
[ 1.4628059  2.9096096 -0.5187819  4.7829256 -0.576247   2.7845862
  1.3225893 -1.0953598 -1.8378775  1.493313 ]


Case 3. (b)
I get evaluated 1st time
[ 1.4628059  2.9096096 -0.5187819  4.7829256 -0.576247   2.7845862
  1.3225893 -1.0953598 -1.8378775  1.493313 ]


Case 3. (c)
I get evaluated 1st time
[ 1.4628059  2.9096096 -0.5187819  4.7829256 -0.576247   2.7845862
  1.3225893 -1.0953598 -1.8378775  1.493313 ]


### Multiple-output function

In [3]:
# multiple outputs example #########################################################################
# Global tally of how many times `fn` has been evaluated so far.
call_count = 0


def fn(a, b):
    """Return the pair ``(a + b + 1, a - b)`` and report each evaluation on stdout.

    Every call increments the module-level ``call_count`` and prints which
    evaluation this is (1st, 2nd, ...), so it is visible when the PyTorch
    function actually gets run.

    Args:
        a: torch tensor.
        b: torch tensor that can be added to and subtracted from ``a``.

    Returns:
        A 2-tuple ``(a + b + torch.ones_like(a), a - b)``.
    """
    global call_count
    call_count = call_count + 1
    print(f"I get evaluated {count_format(call_count)} time")
    shifted_sum = a + b + torch.ones_like(a)
    difference = a - b
    return shifted_sum, difference


# Example inputs: two random length-10 torch vectors and their JAX counterparts (via tree_t2j).
a_torch, b_torch = torch.randn(10), torch.randn(10)
a_jax, b_jax = tree_t2j((a_torch, b_torch))

####################################################################################################

# Case 1. we do not specify output_shapes
# torch2jax will call the PyTorch function with the example arguments and infer the output_shapes
# from the output
#
# torch2jax DOES call the PyTorch function with example arguments
# (the recorded output shows two evaluations for this case)
print("Case 1.")
jax_fn = torch2jax_with_vjp(fn, a_torch, b_torch)

# this would not work because torch2jax would try to call the PyTorch function with jax arguments
# jax_fn = torch2jax_with_vjp(fn, a_jax, b_jax)

print(jax_fn(a_jax, b_jax))
# reset the counter so the next case reports its evaluations starting from "1st" again
call_count = 0

####################################################################################################

# Case 2. we specify output_shapes, but not dtypes
# torch2jax will guess that the dtype of every output is the same as the floating-point dtype of
# one of the inputs
#
# torch2jax DOES NOT call the PyTorch function with example arguments
#
# one shape per output: `fn` returns two tensors, so output_shapes is a 2-tuple
output_shapes = (torch.Size(a_jax.shape), torch.Size(a_jax.shape))
jax_fn = torch2jax_with_vjp(fn, a_torch, b_torch, output_shapes=output_shapes)
print("\n\nCase 2.")
print(jax_fn(a_jax, b_jax))
# reset the counter for the next case
call_count = 0

####################################################################################################
# Case 3. we specify both output_shapes and dtypes
# torch2jax does not have to guess anything; the example arguments only need to carry
# information about shapes and dtypes
#
# torch2jax DOES NOT call the PyTorch function with example arguments

# Case 3. (a) passing torch example arguments - works!
print("\n\nCase 3. (a)")
# one ShapeDtypeStruct (shape + dtype) per output of `fn`
output_shapes = (
    jax.ShapeDtypeStruct(a_jax.shape, a_jax.dtype),
    jax.ShapeDtypeStruct(a_jax.shape, a_jax.dtype),
)
jax_fn = torch2jax_with_vjp(fn, a_torch, b_torch, output_shapes=output_shapes)
print(jax_fn(a_jax, b_jax))
# reset the counter for the next case
call_count = 0

# Case 3. (b) passing jax example arguments - works! (because jax arguments have
# both a dtype and a shape)
print("\n\nCase 3. (b)")
jax_fn = torch2jax_with_vjp(fn, a_jax, b_jax, output_shapes=output_shapes)
print(jax_fn(a_jax, b_jax))
# reset the counter for the next case
call_count = 0

# Case 3. (c) passing shape and dtype structs - works! (because shape and dtype
# structs have both a dtype and a shape)
print("\n\nCase 3. (c)")
# describe the inputs purely by shape and dtype; no concrete example data is passed
a_shape = jax.ShapeDtypeStruct(a_jax.shape, a_jax.dtype)
b_shape = jax.ShapeDtypeStruct(b_jax.shape, b_jax.dtype)
jax_fn = torch2jax_with_vjp(fn, a_shape, b_shape, output_shapes=output_shapes)
print(jax_fn(a_jax, b_jax))
# leave the counter at 0 once this cell finishes
call_count = 0

Case 1.
I get evaluated 1st time
I get evaluated 2nd time
(Array([ 2.8916252 ,  0.8241199 ,  0.72123194,  0.98281103,  3.118568  ,
        0.38162696,  1.3556712 ,  0.570555  , -0.6736982 ,  2.4688735 ],      dtype=float32), Array([-0.84350806, -0.50517464, -0.5478717 , -0.21732308,  1.1992476 ,
        0.22803703, -0.4428231 ,  0.58957076, -0.9303674 , -1.063955  ],      dtype=float32))


Case 2.
I get evaluated 1st time
(Array([ 2.8916252 ,  0.8241199 ,  0.72123194,  0.98281103,  3.118568  ,
        0.38162696,  1.3556712 ,  0.570555  , -0.6736982 ,  2.4688735 ],      dtype=float32), Array([-0.84350806, -0.50517464, -0.5478717 , -0.21732308,  1.1992476 ,
        0.22803703, -0.4428231 ,  0.58957076, -0.9303674 , -1.063955  ],      dtype=float32))


Case 3. (a)
I get evaluated 1st time
(Array([ 2.8916252 ,  0.8241199 ,  0.72123194,  0.98281103,  3.118568  ,
        0.38162696,  1.3556712 ,  0.570555  , -0.6736982 ,  2.4688735 ],      dtype=float32), Array([-0.84350806, -0.50517464, -0