In [38]:
import os
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchinfo as torchinfo
import torchvision

from PIL import Image
from torch.optim import SGD, Adam
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

In [8]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Test on new dataset

In [17]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomCrop(480),
    transforms.Resize(28),
    transforms.Normalize((0.1307,), (0.3081,)),    
])

In [18]:
class OutlierDataset(Dataset):
    def __init__(self, root_dir, transform):
        self.root_dir = root_dir
        self.transform = transform
        self.imgs = sorted([img for img in os.listdir(self.root_dir) if img.endswith('.jpeg')])
    
    def __len__(self):
        return len(self.imgs)
    
    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.imgs[idx])
        img = Image.open(img_path).convert('L')
        
        img = self.transform(img)
        img = img.unsqueeze(0)
        return img

In [19]:
outlier_dataset = OutlierDataset("./data/images/", transform)

#### Network Definition

In [22]:
class Net(nn.Module):
    def __init__(self, inp_channels=1, num_classes=10, batch_norm=True):
        super(Net, self).__init__()
        self.batch_norm = batch_norm
        self.conv1 = nn.Conv2d(inp_channels, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        
        self.pool1 = nn.MaxPool2d(2)
        self.activation = nn.ReLU()
    
    def forward(self, x):
        if self.batch_norm:
            x = self.activation(self.bn1(self.conv1(x)))
            x = self.pool1(self.activation(self.bn2(self.conv2(x))))
        else:
            x = self.activation(self.conv1(x))
            x = self.pool1(self.activation(self.conv2(x)))
        x = self.dropout1(x).flatten(1)
        x = self.activation(self.fc1(x))
        x = self.fc2(self.dropout2(x))
        return x

In [24]:
model = Net()
model.load_state_dict(torch.load("./models/default.pt"))
model

Net(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=9216, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout1): Dropout(p=0.25, inplace=False)
  (dropout2): Dropout(p=0.5, inplace=False)
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (activation): ReLU()
)

### Trial 1

In [43]:
pred_list = []
img_list = []
for i, img in enumerate(outlier_dataset):
    img = img.to(device)
    out = model(img)
    pred = torch.argmax(out, dim=1)
    pred_list.append(pred.item())
    img_list.append(img)

In [101]:
print("Labels: {}".format(pred_list))

Labels: [7, 6, 6, 2, 7, 7, 1, 2, 2, 7]


In [100]:
imgs = torch.cat(img_list)

In [95]:
def imshow(img):
    img = img * 0.3081 + 0.1307
    npimg = img.numpy().squeeze()
    
    plt.imshow(np.uint8(npimg * 255.0), cmap='gray', vmin=0, vmax=255)
    plt.show()