In [None]:
# Requires the onnx2torch package. Uncomment to install it from inside the notebook:
# !pip install onnx2torch

In [2]:
import torch
import torchvision
from onnx2torch import convert
import numpy as np
import os

In [2]:
# Location of the serialized ONNX network.
onnx_model_path = "./medium.onnx"
# onnx2torch's `convert` is handed the file path and returns the converted model.
torch_model = convert(onnx_model_path)

In [3]:
# ToTensor is the only preprocessing step applied to each image.
transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()]
)
# MNIST with train=False, stored under ./data (download=True fetches the files).
test_ds = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [7]:
# Iterate the test set in its stored order, 512 samples per batch.
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=512, shuffle=False)

In [12]:
# Run the converted model over the test loader and collect, per batch, one array of
# predicted class indices (`predicts`) and one array of ground-truth labels (`labels`).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch_model.to(device)
model.eval()  # put any train/eval-sensitive layers into inference behaviour

predicts = []
labels = []
# Evaluation never calls backward(), so skip building the autograd graph entirely.
with torch.no_grad():
    for images, targets in test_dl:
        scores = model(images.to(device)).cpu().numpy()
        # The highest score along axis 1 is taken as the predicted class.
        predicts.append(np.argmax(scores, axis=1))
        labels.append(targets.numpy())

In [18]:
# Merge the per-batch arrays into a single array each.
# These names are rebound from list -> ndarray, so concatenate only while they are
# still lists: calling np.concatenate again on the already-merged array would raise,
# which made this cell fail when it was run a second time.
if isinstance(predicts, list):
    predicts = np.concatenate(predicts, axis=0)
if isinstance(labels, list):
    labels = np.concatenate(labels, axis=0)

In [19]:
# Accuracy: fraction of positions where the predicted class equals the label.
np.mean(np.equal(predicts, labels))

0.9903

In [3]:
# Small CNN: three 3x3 convolutions (two of them followed by ReLU + 2x2 max-pool),
# then a two-layer fully connected head producing 10 scores.
# Layers are passed positionally, so parameters are keyed by layer index
# ("0.weight", "3.weight", "6.weight", "8.weight", "10.weight", plus biases).
model = torch.nn.Sequential(
    # Stage 1: 1 -> 16 channels; padding=1 keeps the spatial size, pooling halves it.
    torch.nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2),
    # Stage 2: 16 -> 16 channels, pooled again.
    torch.nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2),
    # Stage 3: convolution only; its output goes straight to Flatten.
    torch.nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1),
    # Head: 7*7*16 = 784 flattened features (for 28x28 inputs) -> 256 -> 10.
    torch.nn.Flatten(),
    torch.nn.Linear(7 * 7 * 16, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 10),
)

In [6]:
# Load the saved parameters from the checkpoint file into `model`.
p = os.path.join('datafree-model-extraction', 'dfme', 'checkpoint', 'student_debug', 'mnist-medium.pt')
# map_location='cpu': without it, a checkpoint whose tensors were saved on a GPU
# fails to deserialize on a machine that has no CUDA device.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
model.load_state_dict(torch.load(p, map_location='cpu'))

<All keys matched successfully>

In [7]:
# Export `model` to medium_ex.onnx, using one random 1x1x28x28 tensor as the example input.
# verbose=True prints the exported graph.
dummy_input = torch.randn(1, 1, 28, 28)
torch.onnx.export(
    model,
    dummy_input,
    "medium_ex.onnx",
    verbose=True,
)

graph(%input.1 : Float(1, 1, 28, 28, strides=[784, 784, 28, 1], requires_grad=0, device=cpu),
      %0.weight : Float(16, 1, 3, 3, strides=[9, 9, 3, 1], requires_grad=1, device=cpu),
      %0.bias : Float(16, strides=[1], requires_grad=1, device=cpu),
      %3.weight : Float(16, 16, 3, 3, strides=[144, 9, 3, 1], requires_grad=1, device=cpu),
      %3.bias : Float(16, strides=[1], requires_grad=1, device=cpu),
      %6.weight : Float(16, 16, 3, 3, strides=[144, 9, 3, 1], requires_grad=1, device=cpu),
      %6.bias : Float(16, strides=[1], requires_grad=1, device=cpu),
      %8.weight : Float(256, 784, strides=[784, 1], requires_grad=1, device=cpu),
      %8.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
      %10.weight : Float(10, 256, strides=[256, 1], requires_grad=1, device=cpu),
      %10.bias : Float(10, strides=[1], requires_grad=1, device=cpu)):
  %input : Float(1, 16, 28, 28, strides=[12544, 784, 28, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], 