## Load Mojo Kernels

In [23]:
import os, torch
from pathlib import Path
from max.torch import CustomOpLibrary

# The benchmark below runs on "cuda", so stop early with a readable message
# instead of a bare AssertionError when PyTorch cannot see a GPU.
assert torch.cuda.is_available(), "CUDA is required: PyTorch cannot see a CUDA device"
# Absolute path of the `operations` directory, resolved against the current
# working directory; it is later handed to the custom-op loaders.
op_dir = os.path.abspath('operations')

## Simple `add_one` Operation

In [24]:
# Load the custom-op library from the directory stored in `op_dir`.
op_lib = CustomOpLibrary(Path(op_dir))
# Look up the `my_add_constant` op and index it with {"value": 1}.
# NOTE(review): presumably this binds the op's `value` parameter to 1 —
# confirm against the kernel source in `operations/`.
add_one = op_lib.my_add_constant[{"value": 1}]

In [25]:
import time

def torch_add_one(inputs):
    """Return ``inputs + 1``; serves as the plain-PyTorch baseline in the benchmark."""
    incremented = inputs + 1
    return incremented

def mojo_add_one(inputs):
    """Apply the custom ``add_one`` op to ``inputs`` and return the result.

    The op is invoked with a destination tensor as its first argument and
    the source tensor as its second, so a zero-filled tensor with the same
    shape, dtype and device as ``inputs`` is allocated to receive the result.
    """
    result = torch.zeros_like(inputs)
    add_one(result, inputs)
    return result

# Time 1000 chained applications of each implementation on each device.
# CUDA kernels are launched asynchronously, so the device is synchronized
# before the timer starts and again before it stops; otherwise the "cuda"
# rows could measure launch overhead rather than the actual work.
# The loop variable is called `device_name` so it is not confused with the
# MAX `device` object created later in the notebook.
for device_name in ["cpu", "cuda"]:
    on_cuda = device_name == "cuda"
    for op in [torch_add_one, mojo_add_one]:
        x = torch.zeros(1024, device=device_name)
        x = op(x) # warm-up
        if on_cuda:
            torch.cuda.synchronize()  # make sure the warm-up is not timed
        start = time.perf_counter()
        for _ in range(1000):
            x = op(x)
        if on_cuda:
            torch.cuda.synchronize()  # wait for all queued kernels to finish
        end = time.perf_counter()
        print(op.__name__, device_name, x, end - start)

torch_add_one cpu tensor([1001., 1001., 1001.,  ..., 1001., 1001., 1001.]) 0.00474576698616147
mojo_add_one cpu tensor([1001., 1001., 1001.,  ..., 1001., 1001., 1001.]) 0.2719612669898197
torch_add_one cuda tensor([1001., 1001., 1001.,  ..., 1001., 1001., 1001.], device='cuda:0') 0.008035784005187452
mojo_add_one cuda tensor([1001., 1001., 1001.,  ..., 1001., 1001., 1001.], device='cuda:0') 0.27242381998803467


## Different MatMul Operations

In [26]:
from max.driver import CPU, Accelerator, accelerator_count, Tensor
import torch
# Problem size: A is (M, K) and B is (K, N), so A @ B is (M, N).
M = 4096
K = 6144
N = 2048
# Use the accelerator when one is present, otherwise fall back to the CPU.
device = CPU() if accelerator_count() == 0 else Accelerator()
# Seed PyTorch's generator so the random operands (and everything derived
# from them) are the same on every Restart & Run All.
torch.manual_seed(0)
torch_A = torch.randn(M, K)
torch_B = torch.randn(K, N)
# Wrap the host data in MAX driver tensors and move them to `device`.
A = Tensor.from_numpy(torch_A.numpy()).to(device)
B = Tensor.from_numpy(torch_B.numpy()).to(device)

In [27]:
from max.graph import Graph, TensorType, DeviceRef, ops
def build_graph(session, algorithm):
    """Build a one-op graph around the custom ``my_matmul`` op and compile it.

    The input types (dtype and shape) are taken from the module-level
    tensors ``A`` and ``B``, the placement from the module-level ``device``,
    and the custom-op sources from ``op_dir``.

    Args:
        session: object whose ``load`` method compiles the graph.
        algorithm: value forwarded to the op as its ``algorithm`` parameter.

    Returns:
        Whatever ``session.load`` returns for the constructed graph.
    """
    device_ref = DeviceRef.from_device(device)
    lhs_type = TensorType(dtype=A.dtype, shape=A.shape, device=device_ref)
    rhs_type = TensorType(dtype=B.dtype, shape=B.shape, device=device_ref)

    with Graph(
        "matmul_graph",
        input_types=[lhs_type, rhs_type],
        custom_extensions=[Path(op_dir)],
    ) as graph:
        lhs, rhs = graph.inputs
        # Output shape: rows of the left operand by columns of the right one.
        result_type = TensorType(
            dtype=A.dtype,
            shape=[lhs.tensor.shape[0], rhs.tensor.shape[1]],
            device=device_ref,
        )
        results = ops.custom(
            name="my_matmul",
            device=device_ref,
            values=[lhs, rhs],
            out_types=[result_type],
            parameters={"algorithm": algorithm},
        )
        graph.output(results[0].tensor)

    # Compile the graph.
    return session.load(graph)

from max.engine import InferenceSession
# Create an inference session for the selected device, compile the graph with
# the "naive" algorithm, and display the first output of one run on A and B.
session = InferenceSession(devices=[device])
graph = build_graph(session, algorithm="naive")
graph.execute(A, B)[0].to_numpy()

algo: naive


array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], shape=(4096, 2048), dtype=float32)

In [28]:
# Draw a fresh pair of random operands and move them to the device.
torch_A = torch.randn(M, K)
torch_B = torch.randn(K, N)
A = Tensor.from_numpy(torch_A.numpy()).to(device)
B = Tensor.from_numpy(torch_B.numpy()).to(device)

# Custom-op result versus the PyTorch reference computed on the host.
result = graph.execute(A, B)[0].to_numpy()
reference = (torch_A @ torch_B).numpy()
print(result)
print("reference:\n", reference)

# Eyeballing two truncated arrays is easy to get wrong, so also report a
# number: the largest element-wise deviation, judged against a tolerance
# relative to the largest reference magnitude. This prints a verdict and
# deliberately does not raise, so the notebook keeps running.
max_abs_err = float(abs(result - reference).max())
tolerance = 1e-3 * float(abs(reference).max())
print("max abs error:", max_abs_err, "OK" if max_abs_err <= tolerance else "MISMATCH")

algo: naive
[[ 2.088285    2.1629362   0.01127979 ...  1.1015418  -1.3027214
   0.0937182 ]
 [-1.6135992  -0.21211405  0.14463045 ... -1.1236326  -0.9327972
   0.3458073 ]
 [-2.062809    0.02069045 -0.35954815 ...  0.4331466   1.4154134
   1.3926151 ]
 ...
 [-2.0157542   0.65817773  0.27584735 ... -1.0323421   0.18329106
  -0.07031224]
 [-0.89539456 -0.43150127 -1.7474142  ... -0.3353319   0.60521877
  -1.2522432 ]
 [-0.49523154 -1.2680957   0.8045344  ...  0.17800637 -1.5463879
   0.23116419]]
reference:
 [[ -48.72889     71.77749    -36.678432  ...  -42.68213     51.045113
    32.300026 ]
 [ 119.70815    -78.42561    105.39278   ...    1.873215     4.0641813
    98.63118  ]
 [  27.451885   142.09695      2.8380036 ... -118.677734   -18.013294
    -5.920418 ]
 ...
 [  17.899277     3.7735796  -24.88771   ... -129.08417    -98.912445
   116.08875  ]
 [ -68.0239      29.47119   -102.6938    ...  -92.22157    -35.40308
   -40.6087   ]
 [ -25.16601    -33.739555    56.680183  ...  -38.804