import numpy as np
from onnx import save_model
from onnx.checker import check_model
from onnx.numpy_helper import from_array
from onnx.shape_inference import infer_shapes
from onnxscript import script
from onnxscript.onnx_types import FLOAT
from onnxscript.onnx_opset import opset12 as op

# Shared random generator used to fill the Conv weight and bias tensors below.
# It is created without a seed, so the generated values differ between runs.
rng = np.random.default_rng()


def gen_with_stride(stride: int) -> None:
    """Build, validate and save a single-node ONNX Conv model for ``stride``.

    The model takes a float tensor of shape ``[1, 3, 512, 512]`` and applies
    one ``Conv`` with a ``64x3x7x7`` weight and a 64-element bias, both filled
    with float32 values drawn from the module-level ``rng``. The convolution
    uses ``auto_pad="SAME_UPPER"``, unit dilations and the same ``stride`` on
    both spatial axes. The output's spatial dimensions are declared as
    unknown (``None``).

    The resulting model is checked with ``check_model(..., full_check=True)``
    and written to ``repro_{stride}.onnx`` (a relative path, so it lands in
    the current working directory).

    Args:
        stride: Stride applied along both spatial axes; must be >= 1.

    Raises:
        ValueError: If ``stride`` is smaller than 1.
    """
    # Fail early with a clear message: a zero or negative stride is not a
    # meaningful Conv attribute and would otherwise only surface (if at all)
    # deep inside model construction or ONNX checking.
    if stride < 1:
        raise ValueError(f"stride must be a positive integer, got {stride!r}")

    @script()
    def model_script(
        x: FLOAT[1, 3, 512, 512],
    ) -> FLOAT[1, 64, None, None]:
        # Weight tensor: 64 output channels, 3 input channels, 7x7 kernel.
        W = op.Constant(
            value=from_array(rng.random((64, 3, 7, 7), dtype=np.float32)),
        )

        # One bias value per output channel.
        B = op.Constant(value=from_array(rng.random(64, dtype=np.float32)))

        # `stride` is captured from the enclosing function's scope.
        return op.Conv(
            x,
            W,
            B,
            dilations=[1, 1],
            strides=[stride, stride],
            auto_pad="SAME_UPPER",
        )

    model = model_script.to_model_proto()
    # full_check=True asks the checker for its full validation pass,
    # not just the basic structural checks.
    check_model(model, full_check=True)
    save_model(model, f"repro_{stride}.onnx")


# Emit one model file per stride value: repro_1.onnx and repro_2.onnx.
for stride in (1, 2):
    gen_with_stride(stride)
