# Set XRT TPU Config

In [1]:
import os
# Publish the XRT TPU endpoint setting through the process environment.
# NOTE(review): presumably this must run before torch_xla first touches the
# device — confirm against the torch_xla version in use.
os.environ.update({'XRT_TPU_CONFIG': "localservice;0;localhost:51011"})

# Recap the model class to load saved model

In [2]:
import numpy as np
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu
import torchvision
from torchvision import datasets, transforms

class BasicBlock(nn.Module):
  """Residual block made of two 3x3 conv + batch-norm layers.

  The input is added back onto the output of the second batch-norm through
  ``self.shortcut``. The shortcut is an empty ``nn.Sequential`` (identity)
  when the block keeps both resolution and channel count; otherwise it is a
  strided 1x1 conv followed by batch-norm so the two branches have matching
  shapes.

  Args:
    in_planes: number of input channels.
    planes: number of channels produced by both 3x3 convolutions.
    stride: stride of the first 3x3 convolution and of the 1x1 shortcut conv.
  """
  # Output channels are ``expansion * planes``.
  expansion = 1

  def __init__(self, in_planes, planes, stride=1):
    super().__init__()
    out_planes = self.expansion * planes

    # Main branch. Submodules are created in a fixed order (conv1, bn1, conv2,
    # bn2, shortcut) so parameter names and initialisation order stay stable.
    self.conv1 = nn.Conv2d(
        in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(planes)
    self.conv2 = nn.Conv2d(
        planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn2 = nn.BatchNorm2d(planes)

    # Skip branch: identity when shapes already agree, projection otherwise.
    shape_preserved = stride == 1 and in_planes == out_planes
    if shape_preserved:
      projection = []
    else:
      projection = [
          nn.Conv2d(
              in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
          nn.BatchNorm2d(out_planes),
      ]
    self.shortcut = nn.Sequential(*projection)

  def forward(self, x):
    """Return ``relu(main_branch(x) + shortcut(x))``."""
    main = self.conv1(x)
    main = F.relu(self.bn1(main))
    main = self.conv2(main)
    main = self.bn2(main)
    # In-place accumulate of the skip branch onto the main branch.
    main += self.shortcut(x)
    return F.relu(main)


class ResNet(nn.Module):
  """ResNet with a 3x3 stem, four residual stages and a linear classifier.

  The stem takes 3-channel input. The stages produce 64, 128, 256 and 512
  channels (times ``block.expansion``); stages 2-4 start with a stride-2
  block. ``forward`` returns log-probabilities (``log_softmax`` over dim 1).

  Args:
    block: callable ``block(in_planes, planes, stride)`` returning a module;
      it must expose an ``expansion`` attribute.
    num_blocks: indexable with at least four entries, one block count per
      stage.
    num_classes: size of the classifier output.
  """

  def __init__(self, block, num_blocks, num_classes=10):
    super().__init__()
    # Running input-channel count; _make_layer reads and advances it.
    self.in_planes = 64

    # Stem.
    self.conv1 = nn.Conv2d(
        3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(64)

    # Stages layer1..layer4, registered in order.
    stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
    for idx, (width, first_stride) in enumerate(stage_specs):
      stage = self._make_layer(block, width, num_blocks[idx], stride=first_stride)
      setattr(self, 'layer%d' % (idx + 1), stage)

    self.linear = nn.Linear(512 * block.expansion, num_classes)

  def _make_layer(self, block, planes, num_blocks, stride):
    """Build one stage: first block uses ``stride``, the rest use stride 1."""
    per_block_strides = [1] * (num_blocks - 1)
    per_block_strides.insert(0, stride)

    stage_blocks = []
    for block_stride in per_block_strides:
      stage_blocks.append(block(self.in_planes, planes, block_stride))
      # The next block consumes what this one produced.
      self.in_planes = planes * block.expansion
    return nn.Sequential(*stage_blocks)

  def forward(self, x):
    """Run stem, stages, 4x4 average pool and classifier on ``x``."""
    feats = F.relu(self.bn1(self.conv1(x)))
    for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
      feats = stage(feats)
    # NOTE(review): a fixed 4x4 pool presumably assumes 32x32 inputs so that
    # the final feature map is 4x4 — confirm against the training data.
    pooled = F.avg_pool2d(feats, 4)
    logits = self.linear(torch.flatten(pooled, 1))
    return F.log_softmax(logits, dim=1)


def ResNet18():
  """Build a ResNet of BasicBlocks with two blocks in each of the four stages."""
  blocks_per_stage = [2, 2, 2, 2]
  return ResNet(BasicBlock, blocks_per_stage)

# Create Model Handler

The following model handler is derived from TorchServe's vision handler class.
We add the initialize, inference and postprocess methods to match our ResNet-18 model trained on the CIFAR-10 dataset.

In [107]:
import os
import json

import torch
from ts.torch_handler.vision_handler import VisionHandler

import copy
from pathlib import Path
import warnings

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch


import logging
logger = logging.getLogger(__name__)
import torch_xla.core.xla_model as xm

class xla_image_classifier(VisionHandler):
    """TorchServe vision handler that classifies images on an XLA device.

    ``initialize`` loads the serialized model named in the manifest and moves
    it to the XLA device, ``inference`` returns the arg-max class index for
    each input, and ``postprocess`` maps those indices to label names.
    """

    # Class index -> label name; the position in the tuple is the class index
    # returned by ``inference``.
    CIFAR10_LABELS = ('airplane', 'automobile', 'bird', 'cat', 'deer',
                      'dog', 'frog', 'horse', 'ship', 'truck')

    # Per-channel normalisation applied after ToTensor.
    # NOTE(review): presumably the statistics used when the model was trained
    # on CIFAR-10 — confirm against the training script.
    norm = transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))

    image_processing = transforms.Compose([
        transforms.ToTensor(),
        norm,
    ])

    def __init__(self):
        super().__init__()
        self.initialized = False

    def initialize(self, ctx):
        """Load the serialized model and place it on the XLA device.

        Reads ``model_dir`` from ``ctx.system_properties`` and the serialized
        file name from ``ctx.manifest["model"]["serializedFile"]``, loads that
        file with ``torch.load``, moves the model to ``xm.xla_device()`` and
        switches it to eval mode.

        Args:
            ctx: handler context exposing ``manifest`` and
                ``system_properties``.

        Raises:
            RuntimeError: if the serialized model file does not exist.
        """
        self.manifest = ctx.manifest

        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        self.device = xm.xla_device()

        # Resolve the serialized model file listed in the manifest.
        serialized_file = self.manifest["model"]["serializedFile"]
        model_pt_path = os.path.join(model_dir, serialized_file)
        if not os.path.isfile(model_pt_path):
            raise RuntimeError("Missing the model.pt or pytorch_model.bin file")

        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code; only serve model archives from a trusted source.
        # NOTE(review): the whole model object is loaded, so the model's class
        # definitions presumably must be importable here — confirm how the
        # archive was produced.
        self.model = torch.load(model_pt_path)
        self.model.to(self.device)
        self.model.eval()
        logger.debug('model from path %s loaded successfully', model_dir)

        self.initialized = True

    def inference(self, inputs):
        """Return the predicted class index for every item in ``inputs``.

        Args:
            inputs: batched tensor accepted by the loaded model.

        Returns:
            Tensor of arg-max indices taken over the last dimension of the
            model output.
        """
        # Serving never needs gradients; skipping autograd bookkeeping saves
        # memory. Moving the batch is a no-op when it already lives on the
        # model's device.
        with torch.no_grad():
            scores = self.model(inputs.to(self.device))
        return scores.argmax(dim=-1)

    def postprocess(self, inputs):
        """Translate a tensor of class indices into a list of label names."""
        return [self.CIFAR10_LABELS[index] for index in inputs.tolist()]

# Test the Handler
Here we will use the MockContext object to invoke the handler as it would be invoked within TorchServe.

In [112]:
# Test Handler
from ts.torch_handler.unit_tests.test_utils.mock_context import MockContext
def model_setup(model_dir='/home/sivaibhav/img_cls',
                image_path='/home/sivaibhav/kitten_small.jpg',
                model_name="img_cls"):
    """Build the (context, image_bytes) pair used by the handler tests.

    The defaults reproduce the original hard-coded locations; pass other
    values to run the same checks against a different model directory or
    sample image.

    Args:
        model_dir: directory handed to ``MockContext`` as ``model_dir``.
        image_path: path of the image file whose raw bytes are returned.
        model_name: name handed to ``MockContext`` as ``model_name``.

    Returns:
        Tuple ``(context, image_bytes)``.
    """
    context = MockContext(model_name=model_name, model_dir=model_dir)
    with open(image_path, 'rb') as fin:
        image_bytes = fin.read()
    return (context, image_bytes)

def test_initialize(model_setup):
    """Create an ``xla_image_classifier`` and initialize it with the context.

    Args:
        model_setup: two-item ``(context, image_bytes)`` pair; only the
            context is used here.

    Returns:
        The initialized handler.
    """
    ctx, _unused_image = model_setup
    classifier = xla_image_classifier()
    classifier.initialize(ctx)
    return classifier

def test_handle(model_setup):
    """Send the sample image twice through an initialized handler.

    Args:
        model_setup: two-item ``(context, image_bytes)`` pair.

    Returns:
        Whatever ``handler.handle`` returns for the two-request batch.
    """
    ctx, image_bytes = model_setup
    classifier = test_initialize(model_setup)
    # The same request dict is referenced twice to form a batch of two.
    request = {'data': image_bytes}
    batch = [request, request]
    return classifier.handle(batch, ctx)

In [113]:
# Build the (context, image_bytes) pair once so later cells can reuse it.
_model_setup = model_setup()
# Check that the handler initializes; the returned handler is the cell's
# displayed value.
test_initialize(_model_setup)

<__main__.xla_image_classifier at 0x7f64ffa756a0>

In [111]:
# Run the handler end to end on the shared setup; the cell displays the
# returned labels.
test_handle(_model_setup)

['cat', 'cat']