In [None]:
import onnxruntime as ort
import pickle
import warnings
from pathlib import Path
import torch
from nnsmith.materialize import Model, Oracle
import numpy as np
from torch.onnx.verification import find_mismatch

In [None]:
# Directory of one nnsmith-generated test case (GIR pickle, oracle pickle, weights).
model_path = Path("nnsmith_constrained/symbolic-cinit/torch/20/model_81_1574862576")

# Get the paths for pickles and weights
gir_path: Path = model_path / "gir.pkl"
oracle_path: Path = model_path / "oracle.pkl"
weights_path: Path = model_path / "model.pth"

# Load the model from pickle.
# SECURITY: pickle.load can execute arbitrary code -- only open test cases you trust.
with gir_path.open("rb") as f:
    gir = pickle.load(f)
model_type = Model.init("torch", "cpu")
model = model_type.from_gir(gir)

# The model is instantiated for "cpu" above, so map the checkpoint to CPU as well;
# without map_location a checkpoint saved from a GPU run fails on a CPU-only host.
state_dict = torch.load(weights_path, map_location="cpu")
# strict=False tolerates key mismatches. Report them instead of dropping the result:
# any parameter that was not loaded keeps its freshly-initialised value, which would
# make every later torch-vs-ONNX comparison misleading.
load_result = model.torch_model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    warnings.warn(
        f"State dict mismatch for {weights_path}: "
        f"missing={load_result.missing_keys}, unexpected={load_result.unexpected_keys}"
    )

# Load oracle (reference inputs, presumably also reference outputs -- not used here).
oracle = Oracle.load(oracle_path)

# Positional model arguments as torch tensors, in the oracle's input order.
model_args = tuple(torch.from_numpy(val) for val in oracle.input.values())

In [None]:
# Localise numerical mismatches between the torch model and its ONNX export
# (opset 16, initializers exported as graph inputs), then write a reproducer
# (ONNX graph plus input/output data) to disk. The returned path is displayed.
# NOTE(review): "outpt" looks like a typo for "output"; kept as-is because other
# scripts may already read from this directory -- confirm before renaming.
mismatch_graph_info = find_mismatch(
    model.native_model,
    model_args,
    opset_version=16,
    keep_initializers_as_inputs=True,
)
mismatch_graph_info.export_repro("outpt")

In [None]:
# Export the model to a standalone ONNX file that the ONNX Runtime cells load.
# NOTE(review): unlike the find_mismatch cell, initializers are NOT exported as
# graph inputs here, so the graph signature differs from the one find_mismatch
# checked -- confirm this difference is intended.
torch.onnx.export(
    model.native_model,
    model_args,
    "./model1.onnx",
    opset_version=16,
    keep_initializers_as_inputs=False,
)

In [None]:
# Open the exported graph in ONNX Runtime, restricted to the CPU execution provider.
sess = ort.InferenceSession(
    "./model1.onnx",
    providers=["CPUExecutionProvider"],
)

In [None]:
# Quick float32-only random feed keyed by ONNX graph-input name.
# NOTE(review): every input is cast to float32 regardless of its declared ONNX
# type. This dict is never passed to `sess.run` (that cell feeds `inputs`), and the
# next cell rebinds `inp` as its loop variable -- consider deleting this cell.
inp = dict(
    (node_arg.name, np.random.randn(*node_arg.shape).astype(np.float32))
    for node_arg in sess.get_inputs()
)

In [None]:
# Build a random ONNX Runtime feed for every graph input, honouring each input's
# declared element type and shape. ORT rejects feeds whose dtype or shape does not
# match the graph, so float32-for-everything (or a collapsed (1,)-shaped bool)
# only works by accident.
ORT_TO_NUMPY_DTYPE = {
    "tensor(float16)": np.float16,
    "tensor(float)": np.float32,
    "tensor(double)": np.float64,
    "tensor(int8)": np.int8,
    "tensor(int16)": np.int16,
    "tensor(int32)": np.int32,
    "tensor(int64)": np.int64,
    "tensor(uint8)": np.uint8,
    "tensor(uint16)": np.uint16,
    "tensor(uint32)": np.uint32,
    "tensor(uint64)": np.uint64,
    "tensor(bool)": np.bool_,
}
FEED_SEED = 0  # fixed so that a mismatch found with this feed can be reproduced
FEED_INT_HIGH = 10  # ints drawn from [0, FEED_INT_HIGH); small, in case they index -- assumption
feed_rng = np.random.default_rng(FEED_SEED)

inputs = {}
for inp in sess.get_inputs():
    if inp.type not in ORT_TO_NUMPY_DTYPE:
        raise ValueError(f"Unsupported ONNX input type {inp.type!r} for input {inp.name!r}")
    dtype = ORT_TO_NUMPY_DTYPE[inp.type]
    # Dynamic dimensions are reported as strings/None rather than ints; use 1 for them.
    shape = tuple(dim if isinstance(dim, int) else 1 for dim in inp.shape)
    if dtype is np.bool_:
        # Random booleans of the declared shape (previously always a single True).
        value = feed_rng.random(shape) < 0.5
    elif np.issubdtype(dtype, np.integer):
        value = feed_rng.integers(0, FEED_INT_HIGH, size=shape, dtype=dtype)
    else:
        value = feed_rng.standard_normal(shape).astype(dtype)
    # np.asarray turns 0-d results (numpy scalars) into the ndarray ORT expects.
    inputs[inp.name] = np.asarray(value)

In [None]:
# Summarise the random feed as name -> (dtype, shape) instead of dumping every array.
{name: (arr.dtype, arr.shape) for name, arr in inputs.items()}

In [None]:
# Run ONNX Runtime on the random feed and display outputs keyed by output name.
# Output names come from the session itself: the export cell passed no
# `output_names`, so the graph's output names are chosen by the exporter and are
# not guaranteed to equal nnsmith's `model.output_like` names.
ort_output_names = [out.name for out in sess.get_outputs()]
ort_outputs = dict(
    zip(ort_output_names, sess.run(output_names=ort_output_names, input_feed=inputs))
)
ort_outputs