## Load Mojo Kernels

In [1]:
import os, torch
from pathlib import Path
from max.torch import CustomOpLibrary

# The add-one benchmark below runs on both "cpu" and "cuda", so fail fast
# when torch cannot see a CUDA device.
assert torch.cuda.is_available()
# Absolute path of the `operations` directory, resolved against the current
# working directory; later cells pass it as the location of the custom ops.
op_dir = os.path.abspath('operations')

## Simple `add_one` Operation

In [24]:
# Load the custom-op library from the `operations` directory.
op_lib = CustomOpLibrary(Path(op_dir))
# Pick the `my_add_constant` op with its `value` parameter set to 1.
# NOTE(review): presumably this specializes the Mojo kernel on `value`;
# confirm against the kernel source in `operations`.
add_one = op_lib.my_add_constant[{"value": 1}]

In [25]:
import time

def torch_add_one(inputs):
    """Baseline: return `inputs + 1` computed with the regular `+` operator."""
    incremented = inputs + 1
    return incremented

def mojo_add_one(inputs):
    """Run the custom `add_one` op on `inputs` and return the result.

    The op is called with a destination tensor as its first argument, so a
    zero-filled tensor matching `inputs` is allocated for it on every call.
    """
    result = torch.zeros_like(inputs)
    add_one(result, inputs)
    return result

# Number of timed calls per (device, op) pair.
NUM_TIMED_CALLS = 1000

# The loop variable is `device_name` rather than `device` so that this cell
# does not rebind the global `device` that the matmul cells use for the MAX
# driver device.
for device_name in ["cpu", "cuda"]:
    for op in [torch_add_one, mojo_add_one]:
        x = torch.zeros(1024, device=device_name)
        x = op(x)  # warm-up
        # CUDA work is queued asynchronously: without a synchronize the timer
        # would only measure how long it takes to *launch* the kernels.
        if device_name == "cuda":
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(NUM_TIMED_CALLS):
            x = op(x)
        # Wait for all queued GPU work to finish before stopping the clock.
        if device_name == "cuda":
            torch.cuda.synchronize()
        end = time.perf_counter()
        print(op.__name__, device_name, x, end - start)

torch_add_one cpu tensor([1001., 1001., 1001.,  ..., 1001., 1001., 1001.]) 0.00474576698616147
mojo_add_one cpu tensor([1001., 1001., 1001.,  ..., 1001., 1001., 1001.]) 0.2719612669898197
torch_add_one cuda tensor([1001., 1001., 1001.,  ..., 1001., 1001., 1001.], device='cuda:0') 0.008035784005187452
mojo_add_one cuda tensor([1001., 1001., 1001.,  ..., 1001., 1001., 1001.], device='cuda:0') 0.27242381998803467


## Different MatMul Operations

In [2]:
from max.driver import CPU, Accelerator, accelerator_count, Tensor
import torch
# Problem size: A is (M, K) and B is (K, N), so A @ B is (M, N).
M, K, N = 4096, 6144, 2048

# Use an accelerator when MAX reports at least one, otherwise the CPU.
device = Accelerator() if accelerator_count() != 0 else CPU()

# Random operands are drawn with torch (A first, then B) and kept around so a
# torch reference product can be computed from the same values.
torch_A = torch.randn(M, K)
torch_B = torch.randn(K, N)

# MAX driver tensors built from the same data (via numpy), moved to `device`.
A, B = (
    Tensor.from_numpy(operand.numpy()).to(device)
    for operand in (torch_A, torch_B)
)

In [4]:
from max.graph import Graph, TensorType, DeviceRef, ops
def build_graph(session, algorithm):
    """Build and compile a one-op graph around the custom `my_matmul` kernel.

    The two graph inputs take their dtype and shape from the notebook-level
    tensors `A` and `B`; inputs, op and output are all placed on the
    notebook-level `device`. Custom ops are loaded from `op_dir`.

    Args:
        session: object whose `load(graph)` compiles the graph (an
            `InferenceSession` in this notebook).
        algorithm: value forwarded to the op as its `algorithm` parameter,
            e.g. "naive".

    Returns:
        Whatever `session.load` returns for the built graph.
    """
    device_ref = DeviceRef.from_device(device)
    input_types = [
        TensorType(dtype=operand.dtype, shape=operand.shape, device=device_ref)
        for operand in (A, B)
    ]

    with Graph(
        "matmul_graph",
        input_types=input_types,
        custom_extensions=[Path(op_dir)],
    ) as graph:
        lhs, rhs = graph.inputs
        # Output shape: first dim of the left input x second dim of the right.
        out_shape = [lhs.tensor.shape[0], rhs.tensor.shape[1]]
        results = ops.custom(
            name="my_matmul",
            device=device_ref,
            values=[lhs, rhs],
            out_types=[
                TensorType(dtype=A.dtype, shape=out_shape, device=device_ref)
            ],
            parameters={"algorithm": algorithm},
        )
        graph.output(results[0].tensor)

    return session.load(graph)  # compile the graph

from max.engine import InferenceSession
# Inference session bound to the device selected for the matmul inputs.
session = InferenceSession(devices=[device])
# Compile the graph with the "naive" algorithm. Note that `graph` holds the
# value returned by `session.load`, not the `Graph` built inside `build_graph`.
graph =  build_graph(session, "naive")
# Smoke test: execute once on A and B and display the first output as numpy.
graph.execute(A, B)[0].to_numpy() # test run

algo: naive


array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], shape=(4096, 2048), dtype=float32)

In [63]:
# Draw fresh operands and move copies of them to the selected device.
torch_A = torch.randn(M, K)
torch_B = torch.randn(K, N)
A = Tensor.from_numpy(torch_A.numpy()).to(device)
B = Tensor.from_numpy(torch_B.numpy()).to(device)

# Run the custom kernel, then compute the torch reference product once.
result = graph.execute(A, B)[0].to_numpy()
reference = (torch_A @ torch_B).numpy()
print(result)
print("reference:\n", reference)
# The printed arrays are truncated, so also report one number summarizing how
# far the kernel output is from the reference. This is printed rather than
# asserted so the cell still runs to completion when the results differ.
# NOTE(review): the output recorded with this notebook shows kernel values
# around 1 and reference values around 100, so the "naive" result does not
# appear to match — confirm against the kernel implementation.
print("max abs error:", abs(result - reference).max())

algo: naive
[[-0.8861555   1.8308187  -0.23671497 ...  0.817165   -0.8145387
   2.9653795 ]
 [ 0.20740214  0.07566676 -0.06206705 ...  0.5274629   0.252149
   2.0435982 ]
 [-0.4652769   0.5647216  -0.37028655 ... -0.70551145  0.11549457
   1.4300979 ]
 ...
 [-0.11095516  0.3688923   0.42027935 ...  0.32099453  0.87393135
  -0.634261  ]
 [-0.9586582  -0.10623673 -0.3575286  ... -0.14929608  0.45016223
  -0.716326  ]
 [ 1.4587893  -0.38653764 -1.7863083  ...  0.1596423   0.8203581
   0.57582057]]
reference:
 [[  16.080612   -167.2997      -16.580082   ...   92.61364
   -20.823242     43.591152  ]
 [  -0.99066925  -59.388092     -9.075029   ...   41.45944
   -85.95067     -75.04101   ]
 [  53.804382    -49.96623      -5.9724903  ...  -23.966118
   156.89017      31.602358  ]
 ...
 [  -2.109706    246.81204       0.7551545  ...  -12.55712
   -99.15286       1.9999676 ]
 [ -58.867798    -50.43374      77.19668    ...  -19.317852
    54.39126     -76.479935  ]
 [ -17.241022      2.6617918   

## Reference
[1] https://github.com/modular/modular/blob/main/examples/custom_ops/kernels/matrix_multiplication.mojo