In [394]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [395]:
import torch
from torch import nn
from torch.utils.data import DataLoader

import torchvision
from torchvision.transforms import ToTensor, Compose, Normalize

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from dataset import ImageDataset, DSubset, Label, get_channel_means_stdevs

import pickle

In [396]:
labels_map = {1: 'Airplane', 2: 'Automobile', 3: 'Bird', 4: 'Cat', 5: 'Deer', 6: 'Dog', 7: 'Frog', 8: 'Horse', 9: 'Ship', 10: 'Truck'}

with open('../results/channel_training_statistics.pkl', 'rb') as f:
    training_channel_means, training_channel_stdevs = pickle.load(f)
    
tf = Compose([
    Normalize(training_channel_means, training_channel_stdevs)
])

label_type = Label.REAL_OR_SYNTHETIC

train_dataset = ImageDataset(DSubset.TRAIN, label_type, transform = tf)
test_dataset = ImageDataset(DSubset.TEST, label_type, transform = tf)

train_dataloader = DataLoader(train_dataset, batch_size = 64, shuffle = True)
test_dataloader = DataLoader(test_dataset, batch_size = 64, shuffle = True)

In [397]:
# num_channels = 3
# train_channel_means, train_channel_stdevs = get_channel_means_stdevs(train_dataloader, num_channels = num_channels, verbose = False)

# # Verify successful standardization functionality: mean 0 and standard deviation 1 on training set
# assert np.allclose(np.array(train_channel_means), np.zeros(num_channels), atol = 1e-5)
# assert np.allclose(np.array(train_channel_stdevs), np.ones(num_channels), atol = 1e-5)

In [398]:
class SyntheticCNN(nn.Module):
    """
    Convolutional Neural Network to classify images of being either real 
    or synthetically (AI) generated.
    """

    def __init__(self):
        """
        SyntheticCNN initializer.
        """
        super().__init__()

        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 5, 5), # [BATCH_SIZE, 3, 32, 32] -> [BATCH_SIZE, 5, 28, 28]
            nn.BatchNorm2d(5),
            nn.LeakyReLU(),
            nn.MaxPool2d(2, 2), # [BATCH_SIZE, 5, 28, 28] -> [BATCH_SIZE, 5, 14, 14]
            nn.Conv2d(5, 15, 5), # [BATCH_SIZE, 5, 14, 14] -> [BATCH_SIZE, 15, 10, 10]
            nn.BatchNorm2d(15),
            nn.LeakyReLU(),
            nn.MaxPool2d(2, 2), # [BATCH_SIZE, 15, 10, 10] -> [BATCH_SIZE, 15, 5, 5]
        )

        self.linear_layers = nn.Sequential(
            nn.Linear(15 * 5 * 5, 64),
            nn.LeakyReLU(),
            nn.Linear(64, 32),
            nn.LeakyReLU(),
            nn.Linear(32, 16),
            nn.LeakyReLU(),
            nn.Linear(16, 1)
        )

    def forward(self, x):
        """
        SyntheticCNN forward method. Runs convolutional layers first, then
        converts to linear layers.

        Args:
            x -- The input to be passed through the network.

        Returns:
            x -- The output of the model.
        """
        
        x = self.conv_layers(x)
        x = torch.flatten(x, start_dim = 1) # Flatten all dimensions except batch (dim 0)
        x = self.linear_layers(x)

        return x # logit of class 1 (synthetic) likelihood
    
synthetic_model = SyntheticCNN()

In [399]:
binary_cross_entropy = nn.BCEWithLogitsLoss()

sgd = torch.optim.SGD(synthetic_model.parameters(), lr = 1e-3, momentum = 0.9)

In [400]:
def run_train_loop(dataloader: DataLoader, model: nn.Module, loss_fn: nn.Module, optimizer: torch.optim.Optimizer):
    """
    Runs one full epoch of training on model.

    Args:
        dataloader -- The DataLoader through which to produce instances.
        model -- The model to be used for label prediction on instances.
        loss_fn -- The loss function, for backpropagation
        optimizer -- The optimizer, for reducing loss

    Returns:
        average_epoch_loss -- The model loss this epoch, averaged by the number of instances in the dataset
        epoch_accuracy -- The model accuracy this epoch, averaged by the number of instances in the dataset
    """
    
    model.train()

    num_correct_total = 0
    epoch_loss = 0.0

    for i, (X, y) in enumerate(dataloader):

        pred = model(X)
        batch_loss = loss_fn(pred.squeeze(), y.float())

        batch_loss.backward()
        
        epoch_loss += batch_loss.item()

        num_correct_in_batch = torch.sum((torch.sigmoid(pred.detach().squeeze()) > 0.5).float() == y.float()).item()
        num_correct_total += num_correct_in_batch

        # print(f'Batch {i+1} | Loss: {batch_loss.item():>7f} | Accuracy: {num_correct_in_batch / len(y):>7}')

        optimizer.step()
        optimizer.zero_grad()

    average_epoch_loss = epoch_loss / len(dataloader.dataset)
    epoch_accuracy = num_correct_total / len(dataloader.dataset)

    return average_epoch_loss, epoch_accuracy


def run_test_loop(dataloader: DataLoader, model: nn.Module, loss_fn: nn.Module):
    """
    Runs one full dataset-worth of testing on model.

    Args:
        dataloader -- The DataLoader through which to produce instances.
        model -- The model to be used for label prediction on instances.
        loss_fn -- The loss function, for improvement checking

    Returns:
        average_epoch_loss -- The model loss this epoch, averaged by the number of instances in the dataset
        epoch_accuracy -- The model accuracy this epoch, averaged by the number of instances in the dataset
    """
    
    model.eval()

    num_correct = 0
    epoch_loss = 0.0

    with torch.no_grad():

        for X, y in dataloader:

            pred = model(X)
            batch_loss = loss_fn(pred.squeeze(), y.float())

            epoch_loss += batch_loss.item()
            num_correct += torch.sum((torch.sigmoid(pred.detach().squeeze()) > 0.5).float() == y.float()).item()

    average_epoch_loss = epoch_loss / len(dataloader.dataset)
    epoch_accuracy = num_correct / len(dataloader.dataset)

    return average_epoch_loss, epoch_accuracy

In [401]:
EPOCHS = 10

for i in range(EPOCHS):
    train_loss, train_accuracy, = run_train_loop(train_dataloader, synthetic_model, binary_cross_entropy, sgd)
    test_loss, test_accuracy = run_test_loop(test_dataloader, synthetic_model, binary_cross_entropy)

    print(f'Epoch {i+1:>3} | Train Loss: {train_loss:>10f} | Train Accuracy: {train_accuracy:>10f}')
    print(f'Epoch {i+1:>3} | Test Loss: {test_loss:>10f} | Test Accuracy: {test_accuracy:>10f}')
    print('-' * len(f'Epoch {i+1} | Train Loss: {train_loss:>10f} | Train Accuracy: {train_accuracy:>10f}'))

Epoch 1 | Train Loss: 0.008483 | Train Accuracy: 0.702440
Epoch 1 | Test Loss: 0.006170 | Test Accuracy: 0.823850
---------------------------------------------------------
Epoch 2 | Train Loss: 0.004899 | Train Accuracy: 0.867770
Epoch 2 | Test Loss: 0.004879 | Test Accuracy: 0.866500
---------------------------------------------------------
Epoch 3 | Train Loss: 0.004121 | Train Accuracy: 0.891700
Epoch 3 | Test Loss: 0.003963 | Test Accuracy: 0.895900
---------------------------------------------------------
Epoch 4 | Train Loss: 0.003741 | Train Accuracy: 0.902800
Epoch 4 | Test Loss: 0.003759 | Test Accuracy: 0.900300
---------------------------------------------------------
Epoch 5 | Train Loss: 0.003520 | Train Accuracy: 0.908880
Epoch 5 | Test Loss: 0.003273 | Test Accuracy: 0.918300
---------------------------------------------------------
Epoch 6 | Train Loss: 0.003359 | Train Accuracy: 0.913900
Epoch 6 | Test Loss: 0.003196 | Test Accuracy: 0.917950
--------------------------