# Inference with Intel Extension for PyTorch


In [None]:
import os
import sys
import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms

# Show the build configuration string reported by this PyTorch install.
print(torch.__config__.show())

# Device handle used for the stock PyTorch run.
device = torch.device("cpu")

In [None]:
# Load the Intel extension and report which version is installed.
import intel_pytorch_extension as ipex
print(ipex.__version__)
# Device handle used for the extension-backed run.
# NOTE(review): presumably the "xpu" device type only becomes available once
# the extension has been imported, and the device name may differ between
# extension releases -- confirm against the installed version printed above.
device_ipex = torch.device("xpu")

#### Import datasets

In [None]:
import medmnist
from medmnist.models import ResNet18
from medmnist.dataset import PathMNIST, ChestMNIST, DermaMNIST, OCTMNIST, PneumoniaMNIST, RetinaMNIST, BreastMNIST, OrganMNISTAxial, OrganMNISTCoronal, OrganMNISTSagittal
from medmnist.info import INFO

#### Environment settings 

In [None]:
# ---- Run configuration --------------------------------------------------
data_flag = 'retinamnist'   # key selecting the dataset class and its INFO entry
download = True             # forwarded to the dataset constructor
input_root = 'tmp_data/'    # forwarded to the dataset constructor as ``root``

# ---- Dataset flag -> dataset class --------------------------------------
flag_to_class = dict(
    pathmnist=PathMNIST,
    chestmnist=ChestMNIST,
    dermamnist=DermaMNIST,
    octmnist=OCTMNIST,
    pneumoniamnist=PneumoniaMNIST,
    retinamnist=RetinaMNIST,
    breastmnist=BreastMNIST,
    organmnist_axial=OrganMNISTAxial,
    organmnist_coronal=OrganMNISTCoronal,
    organmnist_sagittal=OrganMNISTSagittal,
)
DataClass = flag_to_class[data_flag]

# ---- Dataset metadata looked up from INFO --------------------------------
info = INFO[data_flag]
task, n_channels = info['task'], info['n_channels']
n_classes = len(info['label'])
n_samples = info['n_samples']['train']

# ---- Preprocessing: tensor conversion followed by normalization ----------
data_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# ---- Training split and its loader ---------------------------------------
train_dataset = DataClass(
    root=input_root,
    split='train',
    transform=data_transform,
    download=download,
)

# The batch size is the training-sample count taken from INFO, so the loader
# presumably yields the whole split as a single batch (assumes the count
# matches the dataset length -- TODO confirm).
# NOTE(review): shuffling is not needed for an inference-only pass.
train_loader = data.DataLoader(
    dataset=train_dataset,
    batch_size=n_samples,
    shuffle=True,
)

## Model definition

In [None]:
# Baseline network, placed on the stock PyTorch device.
model = ResNet18(
    in_channels=n_channels,
    num_classes=n_classes,
).to(device)

# Identically configured network, placed on the Intel-extension device.
model_ipex = ResNet18(
    in_channels=n_channels,
    num_classes=n_classes,
).to(device_ipex)

In [None]:
def inference(model, data_loader, device):
    """Run a forward pass of ``model`` over every batch in ``data_loader``.

    The model is switched to evaluation mode and gradient tracking is
    disabled while the batches are processed. The predictions are thrown
    away: the function only exercises the forward path and returns ``None``.

    Args:
        model: Module to evaluate; must provide ``eval()`` and be callable.
        data_loader: Iterable yielding ``(inputs, targets)`` pairs. The
            targets are ignored.
        device: Device each input batch is moved to before the forward pass.
    """
    model.eval()
    with torch.no_grad():
        for features, _labels in data_loader:
            model(features.to(device))

In [8]:
import time

print('==> Inference ...')

restore_model_path = './output/export_inference/model/'+data_flag+'.pth'

# Read the checkpoint from disk once and reuse the weights for both models.
# map_location='cpu' keeps the load independent of the device the checkpoint
# was saved from; load_state_dict then copies the values into each model's
# own parameters.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
state_dict = torch.load(restore_model_path, map_location='cpu')['net']

# time.perf_counter() is a monotonic, high-resolution clock meant for
# measuring intervals; time.time() can jump if the system clock is adjusted.
# Each timed region covers a full pass of inference() over train_loader, so
# it includes iterating the loader as well as the forward passes.
model.load_state_dict(state_dict)
tsb = time.perf_counter()
inference(model, train_loader, device)
tsf = time.perf_counter() - tsb
print("time for Stock PyTorch", tsf)

model_ipex.load_state_dict(state_dict)
tipexb = time.perf_counter()
inference(model_ipex, train_loader, device_ipex)
tipexf = time.perf_counter() - tipexb
print("time for Intel Extension for PyTorch", tipexf)


==> Inference ...
time for Stock PyTorch 2.312333345413208
time for Intel Extension for PyTorch 0.9268012046813965
