In [1]:
import torch
import torchvision
import coremltools as ct

In [2]:
torch.__version__

'1.13.1+cu117'

In [3]:
animal_list = ['penguin',
 'whale',
 'bird',
 'monkey',
 'mosquito',
 'fish',
 'duck',
 'spider',
 'owl',
 'squirrel',
 'snail',
 'panda',
 'lobster',
 'elephant',
 'bear',
 'horse',
 'raccoon',
 'dolphin',
 'snake',
 'cat',
 'frog',
 'rhinoceros',
 'teddy-bear',
 'crocodile',
 'sheep',
 'flamingo',
 'mouse',
 'scorpion',
 'cow',
 'octopus',
 'tiger',
 'dog',
 'bat',
 'parrot',
 'mermaid',
 'butterfly',
 'dragon',
 'kangaroo',
 'rabbit',
 'giraffe',
 'zebra',
 'swan',
 'ant',
 'bee',
 'hedgehog',
 'lion',
 'camel',
 'shark',
 'crab',
 'sea_turtle',
 'pig'
]

In [4]:
# Load pre-trained model
torch_model = torchvision.models.efficientnet_b3(weights='DEFAULT')
torch_model.features[0][0].in_channels = 1
torch_model.features[0][0].weight.data = torch_model.features[0][0].weight.data.sum(dim=1)[:,None]
num_classes = len(animal_list)
fc_in_features = torch_model.classifier[1].in_features
torch_model.classifier[1] = torch.nn.Linear(in_features=fc_in_features, out_features=num_classes)
# Load weights from trained model
checkpoint_pth = '../models/checkpoint_v11_efficientnetb3.pth'
torch_model.load_state_dict(torch.load(checkpoint_pth))
torch_model = torch.nn.Sequential(
    torch_model,
    torch.nn.Softmax(dim=1)
)
# Shift model to GPU
device = 'cuda'
torch_model.to(device)

Sequential(
  (0): EfficientNet(
    (features): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(1, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): SiLU(inplace=True)
      )
      (1): Sequential(
        (0): MBConv(
          (block): Sequential(
            (0): Conv2dNormActivation(
              (0): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=40, bias=False)
              (1): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): SiLU(inplace=True)
            )
            (1): SqueezeExcitation(
              (avgpool): AdaptiveAvgPool2d(output_size=1)
              (fc1): Conv2d(40, 10, kernel_size=(1, 1), stride=(1, 1))
              (fc2): Conv2d(10, 40, kernel_size=(1, 1), stride=(1, 1))
              (activation): SiLU(inplace=True)
              (scale_activa

In [5]:
# Switch model to eval mode
torch_model.eval()

Sequential(
  (0): EfficientNet(
    (features): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(1, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): SiLU(inplace=True)
      )
      (1): Sequential(
        (0): MBConv(
          (block): Sequential(
            (0): Conv2dNormActivation(
              (0): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=40, bias=False)
              (1): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): SiLU(inplace=True)
            )
            (1): SqueezeExcitation(
              (avgpool): AdaptiveAvgPool2d(output_size=1)
              (fc1): Conv2d(40, 10, kernel_size=(1, 1), stride=(1, 1))
              (fc2): Conv2d(10, 40, kernel_size=(1, 1), stride=(1, 1))
              (activation): SiLU(inplace=True)
              (scale_activa

In [6]:
# Trace the model with random data.
example_input = torch.rand(1, 1, 224, 224)
example_input = example_input.to(device)
traced_model = torch.jit.trace(torch_model, example_input)
out = traced_model(example_input)

In [7]:
out

tensor([[0.0214, 0.0190, 0.0156, 0.0172, 0.0179, 0.0184, 0.0213, 0.0120, 0.0195,
         0.0171, 0.0128, 0.0283, 0.0223, 0.0209, 0.0174, 0.0214, 0.0184, 0.0202,
         0.0122, 0.0181, 0.0174, 0.0274, 0.0234, 0.0402, 0.0211, 0.0159, 0.0161,
         0.0226, 0.0267, 0.0112, 0.0202, 0.0159, 0.0218, 0.0186, 0.0204, 0.0173,
         0.0148, 0.0216, 0.0228, 0.0211, 0.0208, 0.0240, 0.0165, 0.0193, 0.0195,
         0.0184, 0.0142, 0.0215, 0.0165, 0.0198, 0.0214]], device='cuda:0',
       grad_fn=<SoftmaxBackward0>)

In [8]:
# Set the image scale and bias for input image preprocessing.
scale = 1/255
image_input = ct.ImageType(name="image_input",
                           shape=example_input.shape,
                           scale=scale,
                           color_layout=ct.colorlayout.GRAYSCALE,
                          channel_first=None)

In [9]:
# Convert to Core ML using the Unified Conversion API.
model = ct.convert(
    traced_model,
    inputs=[image_input],
    classifier_config = ct.ClassifierConfig(animal_list),
    compute_units=ct.ComputeUnit.ALL
)

Converting PyTorch Frontend ==> MIL Ops: 100%|█████████▉| 955/956 [00:00<00:00, 3076.57 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 120.38 passes/s]
Running MIL default pipeline: 100%|██████████| 56/56 [00:01<00:00, 42.09 passes/s] 
Running MIL backend_neuralnetwork pipeline: 100%|██████████| 8/8 [00:00<00:00, 405.99 passes/s]
Translating MIL ==> NeuralNetwork Ops: 100%|██████████| 1157/1157 [00:01<00:00, 1066.25 ops/s]


In [11]:
# Set the model metadata
model.author = "Shen Juin Lee"
model.short_description = "Predicts the animal depicted in a doodle."

In [12]:
# Save the converted model.
model.save("drawandtell_v1.mlmodel")
# Print a confirmation message.
print('Model converted and saved!')

Model converted and saved!


In [11]:
spec = ct.utils.load_spec("drawandtell_v1.mlmodel")
builder = ct.models.neural_network.NeuralNetworkBuilder(spec=spec)
builder.inspect_layers(last=6)

[Id: 385], Name: 1827 (Type: innerProduct)
          Updatable: False
          Input blobs: ['input.677']
          Output blobs: ['var_1827']
[Id: 384], Name: input.677 (Type: reshapeStatic)
          Updatable: False
          Input blobs: ['x']
          Output blobs: ['input.677']
[Id: 383], Name: x (Type: pooling)
          Updatable: False
          Input blobs: ['input.675']
          Output blobs: ['x']
[Id: 382], Name: input.675 (Type: multiply)
          Updatable: False
          Input blobs: ['input.673', 'input.675__silu_sigmoid__']
          Output blobs: ['input.675']
[Id: 381], Name: input.675__silu_sigmoid__ (Type: activation)
          Updatable: False
          Input blobs: ['input.673']
          Output blobs: ['input.675__silu_sigmoid__']
[Id: 380], Name: input.673 (Type: convolution)
          Updatable: False
          Input blobs: ['input.669']
          Output blobs: ['input.673']


In [12]:
builder.add_softmax(name='softmax_final', input_name='var_1827', output_name='class_probs')

name: "softmax_final"
input: "var_1827"
output: "class_probs"
softmax {
}

In [13]:
ct.models.utils.save_spec(builder.spec, 'drawandtell_v1.mlmodel')

In [16]:
spec = model.get_spec()
for out in spec.description.output:
    if out.type.WhichOneof('Type') == "dictionaryType":
        print(out)
        break

name: "var_1827"
type {
  dictionaryType {
    stringKeyType {
    }
  }
}

