In [70]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [71]:
import torch
from torch import nn
from torch.utils.data import DataLoader

import torchvision
from torchvision.transforms import ToTensor, Compose, Normalize

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from dataset import ImageDataset, DSubset, Label, get_channel_means_stdevs

import pickle

In [72]:
# Class-index -> human-readable name. Presumably only relevant for a multi-class label
# type; the REAL_OR_SYNTHETIC labels used below are binary. TODO confirm where this is used.
labels_map = {1: 'Airplane', 2: 'Automobile', 3: 'Bird', 4: 'Cat', 5: 'Deer', 6: 'Dog', 7: 'Frog', 8: 'Horse', 9: 'Ship', 10: 'Truck'}

BATCH_SIZE = 128
SEED = 0  # seeds the training-loader shuffle so Restart & Run All reproduces the batch order

# Per-channel training-set means/stdevs, produced elsewhere and cached on disk.
# NOTE: pickle.load can execute arbitrary code -- only load this file if it was generated locally.
with open('../results/channel_training_statistics.pkl', 'rb') as f:
    training_channel_means, training_channel_stdevs = pickle.load(f)

# Standardize each channel with the *training* statistics; both splits use the same
# transform so the test set is scaled exactly like the data the model was trained on.
tf = Compose([
    Normalize(training_channel_means, training_channel_stdevs)
])

label_type = Label.REAL_OR_SYNTHETIC

train_dataset = ImageDataset(DSubset.TRAIN, label_type, transform = tf)
test_dataset = ImageDataset(DSubset.TEST, label_type, transform = tf)

# Shuffle the training set: with the default shuffle=False every epoch sees the same
# batch order, and if the dataset is stored grouped by label each batch is single-class,
# which skews BatchNorm statistics and gradient updates. The seeded generator keeps the
# shuffle reproducible. The test set is evaluated in a fixed order.
train_dataloader = DataLoader(
    train_dataset,
    batch_size = BATCH_SIZE,
    shuffle = True,
    generator = torch.Generator().manual_seed(SEED),
)
test_dataloader = DataLoader(test_dataset, batch_size = BATCH_SIZE)

In [73]:
num_channels = 3
train_channel_means, train_channel_stdevs = get_channel_means_stdevs(train_dataloader, num_channels = num_channels, verbose = False)

# Sanity check: after Normalize with the pickled training statistics, the training set
# should have per-channel mean 0 and standard deviation 1.
# np.testing.assert_allclose applies the same test as np.allclose
# (|actual - desired| <= atol + rtol * |desired|, with rtol/atol matched to np.allclose's
# defaults used before), but on failure it reports the mismatching values instead of a bare
# AssertionError, which makes a stale statistics pickle much easier to diagnose.
np.testing.assert_allclose(
    np.array(train_channel_means), np.zeros(num_channels), rtol = 1e-5, atol = 1e-5,
    err_msg = 'Normalized training-set channel means are not ~0; is the statistics pickle stale?',
)
np.testing.assert_allclose(
    np.array(train_channel_stdevs), np.ones(num_channels), rtol = 1e-5, atol = 1e-5,
    err_msg = 'Normalized training-set channel stdevs are not ~1; is the statistics pickle stale?',
)

In [94]:
class SyntheticCNN(nn.Module):
    """Binary CNN classifier: real (class 0) vs. synthetic (class 1) images.

    Architecture: two Conv -> BatchNorm -> LeakyReLU -> MaxPool stages followed by a
    four-layer MLP that emits one score per image. The flattened feature size
    (15 * 5 * 5) follows from the layer arithmetic below, so inputs must be 3-channel
    32x32 images, i.e. tensors of shape [BATCH_SIZE, 3, 32, 32].

    ``forward`` returns the probability of class 1 (sigmoid of the score), which pairs
    with ``nn.BCELoss``. ``logits`` returns the raw pre-sigmoid score, which pairs with
    ``nn.BCEWithLogitsLoss`` -- the numerically more stable choice for training because
    it fuses the sigmoid into the loss computation.
    """

    def __init__(self):
        super().__init__()

        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 5, 5), # [BATCH_SIZE, 3, 32, 32] -> [BATCH_SIZE, 5, 28, 28]
            nn.BatchNorm2d(5),
            nn.LeakyReLU(),
            nn.MaxPool2d(2, 2), # [BATCH_SIZE, 5, 28, 28] -> [BATCH_SIZE, 5, 14, 14]
            nn.Conv2d(5, 15, 5), # [BATCH_SIZE, 5, 14, 14] -> [BATCH_SIZE, 15, 10, 10]
            nn.BatchNorm2d(15),
            nn.LeakyReLU(),
            nn.MaxPool2d(2, 2), # [BATCH_SIZE, 15, 10, 10] -> [BATCH_SIZE, 15, 5, 5]
        )

        # MLP head: 375 flattened conv features -> single score per image.
        self.linear_layers = nn.Sequential(
            nn.Linear(15 * 5 * 5, 64),
            nn.LeakyReLU(),
            nn.Linear(64, 32),
            nn.LeakyReLU(),
            nn.Linear(32, 16),
            nn.LeakyReLU(),
            nn.Linear(16, 1)
        )

        self.sigmoid = nn.Sigmoid()

    def logits(self, x):
        """Return the raw (pre-sigmoid) score for each image.

        Args:
            x: Tensor of shape [BATCH_SIZE, 3, 32, 32].

        Returns:
            Tensor of shape [BATCH_SIZE, 1]; positive values mean probability > 0.5
            for class 1 (synthetic). Use with ``nn.BCEWithLogitsLoss``.
        """
        x = self.conv_layers(x)
        x = torch.flatten(x, start_dim = 1) # Flatten all dimensions except batch (dim 0)
        return self.linear_layers(x)

    def forward(self, x):
        """Return the probability that each image is class 1 (synthetic).

        Args:
            x: Tensor of shape [BATCH_SIZE, 3, 32, 32].

        Returns:
            Tensor of shape [BATCH_SIZE, 1] with values in [0, 1] (sigmoid of ``logits``).
        """
        return self.sigmoid(self.logits(x)) # probability of class 1 (synthetic)