In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms

import time
import numpy as np
import i8ie

class Net(nn.Module):
    """Reference float CNN: three 5x5 convolutions followed by one linear layer.

    With a 3-channel 32x32 input the spatial size goes 32 -> 28 -> 24
    (two unpadded 5x5 convs) -> 12 (2x2 max-pool) -> 8 (third conv), so each
    sample reaches the classifier as 120 * 8 * 8 features and leaves it as
    10 logits.
    """

    # Per-sample feature count entering `fc` for a 32x32 input
    # (120 channels * 8 * 8 positions; the same value as the former `960*8`).
    FLAT_FEATURES = 120 * 8 * 8

    def __init__(self):
        super().__init__()
        # Plain attribute; nothing in this class reads it.
        self.quantized = False
        self.conv1 = nn.Conv2d(3, 20, kernel_size=5)
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5)
        self.conv3 = nn.Conv2d(50, 120, kernel_size=5)
        self.fc = nn.Linear(self.FLAT_FEATURES, 10)

    def forward(self, x):
        """Return the logits of shape (batch, 10) for a batch of images.

        Raises:
            RuntimeError: if the spatial size of `x` does not produce
                `FLAT_FEATURES` features per sample.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv3(x))
        # Flatten everything except the batch dimension. Unlike
        # `reshape(-1, FLAT_FEATURES)`, this can never redistribute a
        # wrong-sized input across extra batch rows; `fc` raises instead.
        x = torch.flatten(x, 1)
        return self.fc(x)

class MyNet(i8ie.Module):
    """The same layer stack as `Net`, built from `i8ie` layers and functions.

    Attribute names (conv1, conv2, conv3, fc) match those of `Net`.
    """

    def __init__(self):
        super().__init__()
        self.quantized = False
        self.conv1 = i8ie.Conv2d(3, 20, kernel_size=5)
        self.conv2 = i8ie.Conv2d(20, 50, kernel_size=5)
        self.conv3 = i8ie.Conv2d(50, 120, kernel_size=5)
        # 960*8 == 7680 features per sample are fed to the classifier.
        self.fc = i8ie.Linear(960*8, 10)

    def forward(self, x):
        """Run conv -> relu twice, max-pool, conv -> relu, then the linear layer."""
        h = self.conv1(x)
        h = i8ie.relu(h)
        h = self.conv2(h)
        h = i8ie.relu(h)
        h = i8ie.max_pool2d(h, 2, 2)
        h = self.conv3(h)
        h = i8ie.relu(h)
        flat = h.reshape(-1, 960*8)
        return self.fc(flat)
    
# Load the trained weights once and feed the same state dict to both models.
# NOTE: torch.load unpickles the file, so only open checkpoints from a trusted
# source. map_location='cpu' lets a checkpoint that was saved from a GPU still
# load on a CPU-only machine.
state_dict = torch.load('conv_cifar10_32.pt', map_location='cpu')
torch_model = Net()
torch_model.load_state_dict(state_dict)
torch_model.eval()  # the reference model is only used for inference below
my_model = MyNet()
my_model.load(state_dict)

In [2]:
batch_size = 100
max_batches = 100  # upper bound on the number of test batches cached below

# Convert images to tensors, then normalize each channel with a fixed mean/std.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]
)

test_dataset = datasets.CIFAR10('./data/cifar10', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# NOTE(review): the calibration loader draws shuffled batches from the test set itself.
cal_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

# Cache the evaluation batches twice: as i8ie tensors (xs/ts) and as the
# original torch tensors (xss/tss), so both models see identical data.
xs, ts = [], []
xss, tss = [], []
for batch_idx, (x, target) in enumerate(test_loader):
    # Check the cap before appending so that at most `max_batches` batches are
    # kept (checking after the append would admit one batch too many).
    if batch_idx >= max_batches:
        break
    xs.append(i8ie.tensor(x))
    ts.append(i8ie.tensor(target))
    xss.append(x)
    tss.append(target)

Files already downloaded and verified


In [3]:
%%time
correct_cnt = 0
for x,target in zip(xss,tss):
    out = torch_model(x)
    _, pred_label = torch.max(out.data, 1)
    correct_cnt += (pred_label == target.data).sum()
    
print(correct_cnt)

tensor(7033)
CPU times: user 10.2 s, sys: 43.5 ms, total: 10.3 s
Wall time: 1.29 s


In [4]:
%%time
correct_cnt = 0
for x,target in zip(xs,ts):
    x = my_model(x)
    p = i8ie.argmax(x, axis = 1)
    correct_cnt += (p == target).sum()

print(correct_cnt)

7033.0
CPU times: user 11.4 s, sys: 20 ms, total: 11.4 s
Wall time: 1.43 s


In [5]:
%%time
my_model.prepare()
for batch_idx, (x, target) in enumerate(cal_loader):
    x = my_model(i8ie.tensor(x))
    break
my_model.convert()

CPU times: user 678 ms, sys: 19.9 ms, total: 698 ms
Wall time: 87.1 ms


In [6]:
%%time
correct_cnt = 0
for x,target in zip(xs,ts):
    x = my_model(x)
    p = i8ie.argmax(x, axis = 1)
    correct_cnt += (p == target).sum()
print(correct_cnt)

7020.0
CPU times: user 11.1 s, sys: 20.1 ms, total: 11.1 s
Wall time: 1.39 s
