In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms

import time
import numpy as np
import i8ie


class Net(nn.Module):    

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2)
        self.conv2 = nn.Conv2d(96, 256, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(256, 384, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(384, 384, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(384, 256, kernel_size=3, padding=1)
        self.fc1= nn.Linear(256 * 6 * 6, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 3, 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 3, 2)
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.max_pool2d(F.relu(self.conv5(x)), 3, 2)
        x = x.reshape(-1, 6*6*256)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    
class MyNet(i8ie.Module):    

    def __init__(self):
        super().__init__()
        self.conv1 = i8ie.Conv2d(3, 96, kernel_size=11, stride=4, padding=2)
        self.conv2 = i8ie.Conv2d(96, 256, kernel_size=5, padding=2)
        self.conv3 = i8ie.Conv2d(256, 384, kernel_size=3, padding=1)
        self.conv4 = i8ie.Conv2d(384, 384, kernel_size=3, padding=1)
        self.conv5 = i8ie.Conv2d(384, 256, kernel_size=3, padding=1)
        self.fc1= i8ie.Linear(256 * 6 * 6, 4096)
        self.fc2 = i8ie.Linear(4096, 4096)
        self.fc3 = i8ie.Linear(4096, 10)

    def forward(self, x):
        x = i8ie.relu(self.conv1(x))
        x = i8ie.max_pool2d(x, 3, 2)
        x = i8ie.max_pool2d(i8ie.relu(self.conv2(x)), 3, 2)
        x = i8ie.relu(self.conv3(x))
        x = i8ie.relu(self.conv4(x))
        x = i8ie.max_pool2d(i8ie.relu(self.conv5(x)), 3, 2)
        x = x.reshape(-1, 6*6*256)
        x = i8ie.relu(self.fc1(x))
        x = i8ie.relu(self.fc2(x))
        x = self.fc3(x)
        return x

state_dict = torch.load('alex_cifar10_224.pt')
torch_model = Net()
torch_model.load_state_dict(state_dict)
my_model = MyNet()
my_model.load(state_dict)

In [2]:
batch_size = 100
transform = torchvision.transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    ]
)

test_dataset = torchvision.datasets.CIFAR10('./data/cifar10', train=False, download=True,transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
cal_loader = torch.utils.data.DataLoader(test_dataset, batch_size=100, shuffle=True)

xs = []
ts = []
xss = []
tss = []
for batch_idx, (x, target) in enumerate(test_loader):
    xs.append(i8ie.tensor(x))
    ts.append(i8ie.tensor(target))
    xss.append(x)
    tss.append(target)
    if batch_idx >= 100:
        break

Files already downloaded and verified


In [3]:
%%time
correct_cnt = 0
for x,target in zip(xss,tss):
    out = torch_model(x)
    _, pred_label = torch.max(out.data, 1)
    correct_cnt += (pred_label == target.data).sum()
    
print(correct_cnt)

tensor(7775)
CPU times: user 4min 51s, sys: 4.75 s, total: 4min 56s
Wall time: 37.1 s


In [4]:
%%time
correct_cnt = 0
for x,target in zip(xs,ts):
    x = my_model(x)
    p = i8ie.argmax(x, axis = 1)
    correct_cnt += (p == target).sum()

print(correct_cnt)

7775.0
CPU times: user 6min 22s, sys: 3.9 s, total: 6min 26s
Wall time: 48.3 s


In [5]:
%%time
my_model.prepare()
for batch_idx, (x, target) in enumerate(cal_loader):
    x = my_model(i8ie.tensor(x))
    break
my_model.convert()

CPU times: user 10.4 s, sys: 196 ms, total: 10.6 s
Wall time: 1.46 s


In [6]:
%%time
correct_cnt = 0
for x,target in zip(xs,ts):
    x = my_model(x)
    p = i8ie.argmax(x, axis = 1)
    correct_cnt += (p == target).sum()
print(correct_cnt)

7642.0
CPU times: user 4min 51s, sys: 1.47 s, total: 4min 52s
Wall time: 36.6 s
