In [1]:
import time
import torch
from torch.utils.data import DataLoader
from torch import nn
from torch import optim
import torchvision
from torchvision import transforms

In [2]:
# Preprocessing applied to every image on access:
#   ToTensor  -> float tensor with values scaled to [0, 1]
#   Normalize -> (x - 0.5) / 0.5 per RGB channel, i.e. values in [-1, 1]
# NOTE(review): presumably identical to the training-time preprocessing — confirm.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

In [3]:
# CIFAR-10 test split stored under ./data (downloaded only if not already present);
# each sample is passed through `transform` when indexed.
test_set = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    transform=transform,
    download=True,
)

Files already downloaded and verified


# Device selection (GPU if available, otherwise CPU)

In [4]:
# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(device)

In [5]:
# Show which device was selected (the output below was recorded on a CPU-only machine).
DEVICE

device(type='cpu')

In [6]:
class ConvNet(nn.Module):
    """Small CNN classifier for 3-channel 32x32 images.

    Two conv -> tanh -> 2x2 max-pool stages (3 -> 32 -> 16 channels) reduce a
    32x32 image to a 16x8x8 feature map, followed by a 128-unit tanh hidden
    layer and a 10-way linear output layer.

    The attribute names (conv1, conv2, fc1, fc2, ...) are the keys of the state
    dict loaded from disk, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        # padding=1 keeps the spatial size; the pool halves it: 32x32 -> 16x16
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.act1 = nn.Tanh()
        self.pool1 = nn.MaxPool2d(2)
        # 16x16 -> 8x8, with 16 output channels
        self.conv2 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
        self.act2 = nn.Tanh()
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(16 * 8 * 8, 128)
        self.act3 = nn.Tanh()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return unnormalized class scores (logits).

        Args:
            x: image batch of shape (N, 3, 32, 32).

        Returns:
            Tensor of shape (N, 10).

        Raises:
            ValueError: if the input's spatial size does not produce the
                16x8x8 feature map the fully connected layers were built for.
        """
        out = self.pool1(self.act1(self.conv1(x)))
        out = self.pool2(self.act2(self.conv2(out)))
        # Without this check, reshape(-1, 1024) would silently fold extra spatial
        # positions into the batch dimension for non-32x32 inputs
        # (e.g. one 64x64 image would yield 4 rows of logits).
        if out.shape[-3:] != (16, 8, 8):
            raise ValueError(
                "ConvNet expects 32x32 input images; got feature map of shape "
                f"{tuple(out.shape[-3:])} instead of (16, 8, 8)"
            )
        out = out.reshape(-1, 16 * 8 * 8)
        out = self.act3(self.fc1(out))
        out = self.fc2(out)
        return out

In [7]:
# Rebuild the architecture, then copy the trained parameters into it.
# map_location remaps the checkpoint tensors onto DEVICE while loading;
# the model itself stays on the device it was constructed on.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (recent torch versions accept weights_only=True for plain state dicts).
loaded_model = ConvNet()
loaded_model.load_state_dict(
    torch.load("cifar10_model.pt", map_location=DEVICE)
)
# Switch to inference mode (this ConvNet has no dropout/batchnorm, so outputs
# are unaffected); as the last expression it also displays the model.
loaded_model.eval()

ConvNet(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (act1): Tanh()
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (act2): Tanh()
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=1024, out_features=128, bias=True)
  (act3): Tanh()
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)

In [8]:
# Take one test example: the transformed image tensor and its integer class label.
sample_idx = 0
img, label = test_set[sample_idx]

In [9]:
# Single-image forward pass. unsqueeze(0) adds the batch dimension; the input is
# moved to whichever device holds the model's parameters so the call works on
# CPU and GPU alike, and no_grad avoids building an autograd graph for inference.
model_device = next(loaded_model.parameters()).device
with torch.no_grad():
    out = loaded_model(img.unsqueeze(0).to(model_device))

In [10]:
# Predicted class = index of the highest logit in each row. argmax avoids
# binding `_`, which in IPython would shadow the "last output" variable.
pred = out.argmax(dim=1)

In [11]:
# Predicted class index for the sampled test image.
pred

tensor([3])

In [12]:
# Ground-truth class index for the same image; compare with the prediction above.
label

3