In [24]:
import torch
import torch.nn as nn
import torchvision.transforms
from torch.utils.data import DataLoader
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

In [25]:
from models.generator import Generator
from models.discriminator import Discriminator
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [26]:
batch_size = 128
num_classes = 10
learning_rate = 0.002
num_epochs = 100
num_color_channels = 1
num_feature_maps_g = 32
num_feature_maps_d = 32
size_z = 100
adam_beta1 = 0.2
num_gpu = 0

In [27]:
generator = Generator(size_z=size_z,
                      num_feature_maps=num_feature_maps_g,
                      num_color_channels=num_color_channels).to(device)
discriminator = Discriminator(num_feature_maps=num_feature_maps_d,
                              num_color_channels=num_color_channels).to(device)

In [28]:
generator.load_state_dict(torch.load('./saved_models/generator_100_epochs_mnist_excluded_3s.pkl', map_location=torch.device(device)))
discriminator.load_state_dict(torch.load('./saved_models/discriminator_100_epochs_mnist_excluded_3s.pkl', map_location=torch.device(device)))

<All keys matched successfully>

In [29]:
data = datasets.MNIST(
    root = "data",
    transform = transforms.ToTensor()
)

dataloader = torch.utils.data.DataLoader(
    data,
    batch_size=64,
    shuffle=True
)

In [30]:
def get_anomaly_score(x_query, g_z):
    x_prop = discriminator(x_query)
    g_z_prop = discriminator(g_z)

    loss_residual = torch.sum(torch.abs(x_query - g_z))
    loss_discriminative = torch.sum(torch.abs(x_prop - g_z_prop))

    return (1-.1)*loss_residual + .1* loss_discriminative

In [31]:
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(dataloader), size=(1,)).item()
    img, label = dataloader[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()


TypeError: 'DataLoader' object is not subscriptable

<Figure size 800x800 with 0 Axes>