# Import libraries

In [22]:
from transformers import AutoFeatureExtractor, ResNetForImageClassification
import torch

In [30]:
from PIL import Image
import matplotlib.pyplot as plt
import os

# Import model

In [36]:
model = ResNetForImageClassification.from_pretrained('./resnet18')
model.eval()
feature_extractor = AutoFeatureExtractor.from_pretrained("./resnet18")

# Import test images and evaluate model

In [32]:
# import all image files in current directory
images = []
for file in os.listdir('./images'):
    if(".jpg" in file):
        images.append(file)

In [45]:
predictions = []
for i in images:
    image = Image.open(os.path.join('.','images',i))
    inputs = feature_extractor(image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_label = logits.argmax(-1).item()
    predictions.append(model.config.id2label[predicted_label])

In [48]:
images

['aeroplane.jpg', 'cat.jpg', 'flower.jpg', 'pigeon.jpg', 'football.jpg']

In [49]:
predictions

['airliner',
 'tabby, tabby cat',
 'rapeseed',
 'European gallinule, Porphyrio porphyrio',
 'soccer ball']

# Export model to ONNX format

In [50]:
# dummy input, required to establish approximate computational graph
dummy_input = torch.randn(1, 3, 224, 224)

In [51]:
torch.onnx.export(model, dummy_input, "resnet18.onnx")

  if num_channels != self.num_channels:


verbose: False, log level: Level.ERROR



In [53]:
inputs

{'pixel_values': tensor([[[[-0.6452, -0.6965, -0.7650,  ...,  0.2453,  0.2282,  0.2111],
          [-0.6452, -0.7137, -0.7822,  ...,  0.2111,  0.1939,  0.1768],
          [-0.6281, -0.7308, -0.7822,  ...,  0.1254,  0.1426,  0.1254],
          ...,
          [-1.0904, -1.3815, -1.5014,  ..., -0.7993, -0.8678, -0.9020],
          [-1.1075, -1.2959, -1.2274,  ..., -0.8335, -1.0733, -1.0390],
          [-1.3815, -1.1589, -0.8678,  ..., -1.0219, -1.0733, -1.0219]],

         [[ 0.3627,  0.3452,  0.3102,  ...,  0.7829,  0.7654,  0.7479],
          [ 0.3102,  0.2927,  0.2402,  ...,  0.7304,  0.7129,  0.7129],
          [ 0.2927,  0.2402,  0.2052,  ...,  0.6779,  0.6779,  0.6779],
          ...,
          [-0.6352, -0.8452, -0.9503,  ..., -0.0224, -0.1625, -0.3025],
          [-0.5826, -0.6527, -0.4951,  ..., -0.1800, -0.5126, -0.5301],
          [-0.7752, -0.3025,  0.0301,  ..., -0.5301, -0.6527, -0.5651]],

         [[-1.2816, -1.2641, -1.2816,  ..., -0.7238, -0.7064, -0.7413],
          [-1