# Comparing `SurrealML` execution times with `ONNX` and `PyTorch` 

## Table of contents

1. [General dependencies and helpers](#general-dependencies-and-helpers)
2. [Some words about SurrealML](#some-words-about-surrealml)

## General dependencies and helpers

We start by importing the tools we need for timing and for working with SurrealDB/SurrealML.

```python
# subprocess lets us upload the model through the CLI programmatically
import subprocess

# time and wraps are used to define a chronometer decorator, which gives us a simple API for measuring execution times
import time
from functools import wraps

# The SurrealML classes we need, discussed in the next section
from surrealml import SurMlFile, Engine
```

## Some words about SurrealML

According to the [official docs](https://surrealdb.com/docs/surrealml):
> SurrealML is an engine that seeks to do one thing, and one thing well: store and execute trained ML models. SurrealML does not intrude on the training frameworks that are already out there, instead works with them to ease the storage, loading, and execution of models. Someone using SurrealML will be able to train their model in a chosen framework in Python, save their model, and load and execute the model in either Python or Rust.

We aim to time the three cases that may be encountered in practice, namely:

1. **SurrealML**: predicting _inside_ SurrealDB with the model stored in ONNX format, and then fetching the predictions from SurrealDB.
2. **PyTorch**: fetching the data from SurrealDB and _externally_ predicting with the PyTorch model.
3. **ONNX**: fetching the data from SurrealDB and _externally_ predicting with the ONNX model.
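All three cases share one measurement pattern: for each test size, run the case and record the elapsed time. A minimal sketch of such a harness follows, assuming each case can be wrapped as a callable that takes the number of datapoints. The `benchmark` function and the stand-in workloads are hypothetical placeholders, not the evaluation code used later in this notebook.

```python
import time
from typing import Callable, Dict, List


def benchmark(
    cases: Dict[str, Callable[[int], object]], sizes: List[int]
) -> Dict[str, List[float]]:
    """Time every case at every test size; returns elapsed seconds per case, in size order."""
    results: Dict[str, List[float]] = {name: [] for name in cases}
    for size in sizes:
        # Interleave the cases within each size so slow drifts are spread across all of them
        for name, run in cases.items():
            start = time.perf_counter()
            run(size)
            results[name].append(time.perf_counter() - start)
    return results


# Hypothetical stand-ins for the three real cases, which would query SurrealDB
cases = {
    "SurrealML": lambda n: sum(range(n)),
    "PyTorch": lambda n: sum(range(n)),
    "ONNX": lambda n: sum(range(n)),
}
sizes = [1_000, 2_000, 3_000]
times = benchmark(cases, sizes)
for name, durations in times.items():
    print(name, [f"{d:.6f}" for d in durations])
```

Running the cases in turn at each size, rather than finishing one case before starting the next, keeps background load or cache effects from systematically favouring one case.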

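To time all three cases consistently, we define the decorator `chronometer`, which calls the wrapped function and returns the elapsed wall-clock time in seconds instead of the function's result.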

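```python
def chronometer(foo):
    """Decorator: calls `foo` and returns the elapsed wall-clock time in seconds, discarding foo's result."""
    @wraps(foo)
    def wrapper(*args, **kwargs):
        # perf_counter is a monotonic, high-resolution clock intended for measuring intervals
        start = time.perf_counter()
        _ = foo(*args, **kwargs)
        end = time.perf_counter()
        return end - start

    return wrapper
```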
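A single timed call can be noisy (network latency, caching, first-call warm-up). As a sketch, assuming the workload can safely be repeated, a hypothetical `repeated_chronometer` variant runs the function several times and reports the median duration:

```python
import statistics
import time
from functools import wraps


def repeated_chronometer(repeats: int = 5):
    """Hypothetical variant of `chronometer`: calls the function `repeats` times
    and returns the median elapsed wall-clock time in seconds."""

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            durations = []
            for _ in range(repeats):
                start = time.perf_counter()
                func(*args, **kwargs)
                durations.append(time.perf_counter() - start)
            # The median is less sensitive than the mean to a single slow outlier run
            return statistics.median(durations)

        return wrapper

    return decorator


@repeated_chronometer(repeats=3)
def nap(seconds: float) -> None:
    time.sleep(seconds)


elapsed = nap(0.01)
print(f"median duration: {elapsed:.4f} s")
```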

```python
import torch
import torch.nn as nn


class Fnn(nn.Module):
    """Small feed-forward network: 10 inputs -> 5 hidden units (ReLU) -> 1 output."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def __str__(self) -> str:
        # The class name ("Fnn") is used as the model name inside SurrealDB
        return self.__class__.__name__
```
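To make the architecture concrete, the following pure-Python sketch performs the same computation as `Fnn.forward` for one sample: a 10-to-5 linear layer, ReLU, then a 5-to-1 linear layer. The weights are toy values chosen by hand so the result is easy to check; they are not the trained parameters loaded from `parameters.pth`.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def linear(x: Vector, weight: Matrix, bias: Vector) -> Vector:
    """y = W x + b, with W stored as (out_features, in_features) like nn.Linear."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def relu(x: Vector) -> Vector:
    return [max(0.0, v) for v in x]


def fnn_forward(x: Vector, w1: Matrix, b1: Vector, w2: Matrix, b2: Vector) -> Vector:
    """Same computation as Fnn.forward for a single sample: 10 -> 5 -> ReLU -> 1."""
    return linear(relu(linear(x, w1, b1)), w2, b2)


# Toy weights (not the trained ones): every hidden unit sums all inputs,
# and the biases push some hidden units below zero so ReLU clips them
w1 = [[1.0] * 10 for _ in range(5)]
b1 = [-2.0, -4.0, -6.0, -8.0, -10.0]
w2 = [[1.0, 1.0, 1.0, 1.0, 1.0]]
b2 = [0.5]

x = [0.5] * 10  # inputs sum to 5.0, so the hidden pre-activations are [3, 1, -1, -3, -5]
print(fnn_forward(x, w1, b1, w2, b2))
```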


```python
# Load the trained weights and switch the model to inference mode
model = Fnn()
model.load_state_dict(torch.load("parameters.pth")["model_state_dict"])
model.eval()
```

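```python
# SurMlFile needs an example input of the right shape (1 x 10) for the PyTorch engine
example_input = torch.rand(1, 10)
surml_file = SurMlFile(
    model=model, name=str(model), inputs=example_input, engine=Engine.PYTORCH
)

# Version the model and save it to disk as a .surml file
path_surml = "./model.surml"
surml_file.add_version("0.0.1")
surml_file.save(path_surml)
```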

```python
# The URL of the SurrealDB instance, the authentication details, and the namespace and database in scope
URL = "http://localhost:8000"
NS = "latency_test"
DB = "surreal_ml_vs_pytorch"
USR = "user"
PASS = "user"
CRD = (USR, PASS)
PATH = "file://surreal"
```

```python
# The CLI command that imports the model into the chosen SurrealDB database, invoked programmatically
command = [
    "surreal",
    "ml",
    "import",
    "--endpoint",
    URL,
    "--user",
    USR,
    "--pass",
    PASS,
    "--ns",
    NS,
    "--db",
    DB,
    path_surml,
]

# check_output runs the command once, waits for it to finish, and raises CalledProcessError on failure
output = subprocess.check_output(command, stderr=subprocess.STDOUT)
output_str = output.decode("utf-8")
print(output_str)
```
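The command construction and execution can also be packaged as small helpers. This is a sketch: `build_import_command` and `import_model` are hypothetical names, and the flags are exactly those used in the command above. `subprocess.run(..., check=True)` executes the CLI once and raises `CalledProcessError` on a non-zero exit code, with stdout and stderr captured separately.

```python
import subprocess
from typing import List


def build_import_command(
    url: str, user: str, password: str, ns: str, db: str, path: str
) -> List[str]:
    """Hypothetical helper: the same `surreal ml import` invocation as above."""
    return [
        "surreal", "ml", "import",
        "--endpoint", url,
        "--user", user,
        "--pass", password,
        "--ns", ns,
        "--db", db,
        path,
    ]


def import_model(command: List[str]) -> str:
    """Hypothetical helper: run the CLI once and return its stdout; raises on failure."""
    completed = subprocess.run(command, capture_output=True, text=True, check=True)
    return completed.stdout


cmd = build_import_command(
    "http://localhost:8000", "user", "user",
    "latency_test", "surreal_ml_vs_pytorch", "./model.surml",
)
print(" ".join(cmd))
```

Only the command is built here; calling `import_model(cmd)` requires the `surreal` binary on the `PATH` and a running SurrealDB instance.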

```python
%%capture
"""
###################################################################################################
--> The other way to upload a model to SurrealDB. We could not get it to work at the time of writing :(
###################################################################################################

surml_file.upload(
    path="./model.surml",
    url=URL,
    chunk_size=36864,
    namespace=NS,
    database=DB,
    username=USR,
    password=PASS)
"""
```

```python
# The code used to generate the test data and load it into SurrealDB
from surrealist import Surreal

surreal = Surreal(
    url=URL,
    namespace=NS,
    database=DB,
    credentials=CRD,
    log_level="ERROR",
    timeout=10**4,
)

max_test_size = 10**4
# The test inputs are created in memory one chunk at a time to avoid running out of memory,
# so chunk_size must divide max_test_size
chunk_size = 10**2
assert max_test_size % chunk_size == 0
number_chunks = max_test_size // chunk_size

# Likewise, test_step must divide max_test_size
test_step = 10**3
assert max_test_size % test_step == 0
number_steps = max_test_size // test_step

with surreal.connect() as connect:
    for _ in range(number_chunks):
        test_inputs = torch.rand(chunk_size, 10).tolist()
        # One CREATE statement per row; each record gets a ULID-based id
        for row in test_inputs:
            connect.query(f"CREATE inputs:ulid() SET value = {row};")
```
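Sending one query per row adds one client/server round trip per datapoint. A possible alternative, sketched below under the assumption that the client forwards multi-statement strings to SurrealDB unchanged (SurrealQL accepts several `;`-separated statements in one query), is to join each chunk's `CREATE` statements into a single query string. `chunk_query` is a hypothetical helper, and the rows are generated with the standard `random` module instead of `torch` so the sketch runs on its own.

```python
import random
from typing import List


def chunk_query(rows: List[List[float]]) -> str:
    """Hypothetical helper: one CREATE statement per row, joined into a single query string."""
    return "\n".join(f"CREATE inputs:ulid() SET value = {row};" for row in rows)


max_test_size = 10**4
chunk_size = 10**2
assert max_test_size % chunk_size == 0, "chunk_size must divide max_test_size"
number_chunks = max_test_size // chunk_size

random.seed(0)
# Stand-in for torch.rand(chunk_size, 10).tolist(): chunk_size rows of 10 floats in [0, 1)
rows = [[random.random() for _ in range(10)] for _ in range(chunk_size)]
query = chunk_query(rows)
# With a live connection, this string would be sent with one connect.query(query) call per chunk
print(number_chunks, query.count("CREATE"))
```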

```python
# Case 1 (SurrealML): run the model inside SurrealDB and fetch only the predictions
surreal_times = []
try:
    with surreal.connect() as connect:
        for increment in range(number_steps):
            test_size = (increment + 1) * test_step

            # In an earlier run, the query result was checked with assert query_result["status"] == "OK";
            # print(query_result) to inspect it further
            @chronometer
            def evaluate_with_surrealdb():
                _ = connect.query(
                    f"SELECT VALUE ml::Fnn<0.0.1>(value) FROM inputs LIMIT {test_size};"
                ).to_dict()["result"]

            elapsed_time = evaluate_with_surrealdb()
            print(f"For {test_size} datapoints, it took {elapsed_time} seconds")
            surreal_times.append(elapsed_time)
except Exception as e:
    print(e)
```
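The time measured for each test size can be viewed as a fixed cost (connection and query round trip) plus a cost that grows with the number of datapoints. A sketch for separating the two, assuming the relationship is roughly linear, is an ordinary least-squares fit of `seconds ≈ overhead + per_item * size`. `fit_overhead` is a hypothetical helper, and the timings below are synthetic, not measured.

```python
from typing import List, Tuple


def fit_overhead(sizes: List[int], seconds: List[float]) -> Tuple[float, float]:
    """Least-squares fit of seconds = overhead + per_item * size; returns (overhead, per_item)."""
    if len(sizes) != len(seconds) or len(set(sizes)) < 2:
        raise ValueError("need matching lists with at least two distinct sizes")
    n = len(sizes)
    mean_x = sum(sizes) / n
    mean_y = sum(seconds) / n
    sxx = sum((x - mean_x) ** 2 for x in sizes)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(sizes, seconds))
    per_item = sxy / sxx
    overhead = mean_y - per_item * mean_x
    return overhead, per_item


# Synthetic timings: 0.1 s fixed cost plus 0.5 ms per datapoint (not measured data)
sizes = [1_000, 2_000, 3_000, 4_000]
seconds = [0.1 + 0.0005 * s for s in sizes]
overhead, per_item = fit_overhead(sizes, seconds)
print(f"overhead: {overhead:.3f} s, per datapoint: {per_item * 1e3:.3f} ms")
```

Applied to `surreal_times` with the matching sizes `[test_step * k for k in range(1, number_steps + 1)]`, `per_item` would estimate the marginal cost of one in-database prediction.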

```python
# Case 2 (PyTorch): fetch the raw inputs from SurrealDB and predict outside the database
pytorch_times = []
try:
    with surreal.connect() as connect:
        # Use the same test sizes as the SurrealML case
        for increment in range(1, number_steps + 1):
            test_size = increment * test_step

            @chronometer
            def evaluate_pytorch():
                inputs = connect.query(
                    f"SELECT VALUE value FROM inputs LIMIT {test_size};"
                ).to_dict()["result"]

                # Calling the module (rather than .forward) is the idiomatic way to run it
                with torch.no_grad():
                    _ = model(torch.tensor(inputs))

            elapsed_time = evaluate_pytorch()
            print(f"For {test_size} datapoints, it took {elapsed_time} seconds")
            pytorch_times.append(elapsed_time)
except Exception as e:
    print(e)
```
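Once both timing lists are filled, the two cases can be compared size by size. A minimal sketch, with a hypothetical `speedups` helper and made-up numbers rather than results from this notebook:

```python
from typing import List


def speedups(baseline: List[float], candidate: List[float]) -> List[float]:
    """baseline / candidate for each test size; values above 1 mean the candidate was faster."""
    if len(baseline) != len(candidate):
        raise ValueError("both timing lists must cover the same test sizes")
    return [b / c for b, c in zip(baseline, candidate)]


# Made-up timings for illustration only
pytorch_example = [0.25, 0.5, 1.0]
surreal_example = [0.125, 0.5, 2.0]
print(speedups(pytorch_example, surreal_example))
```

With real data, `speedups(pytorch_times, surreal_times)` would give values above 1 at the test sizes where the in-database SurrealML prediction was faster than fetching the data and predicting with PyTorch.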