In [None]:
from google.colab import drive
# drive.mount('/content/drive', force_remount=True)

In [None]:
from torchvision.datasets import MNIST
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from PIL import Image
import time
import os
import shutil
import csv
import zipfile
import PIL
import math
from IPython.display import clear_output
import pandas as pd
from torch.utils.data import Dataset

In [None]:
# ---- Runtime ----
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# ---- Optimisation ----
batch_size = 512          # images per GAN training batch
learning_rate = 0.0001    # Adam learning rate shared by G and D
adam_beta1 = 0.1          # Adam beta1 (first-moment decay) for both optimisers
num_epochs = 100

# ---- Architecture ----
size_z = 100              # length of the latent vector z
num_color_channels = 1    # MNIST images are single-channel
num_feature_maps_g = 64   # base channel width of the generator
num_feature_maps_d = 64   # base channel width of the discriminator

# ---- Not referenced by the cells in this notebook section ----
num_classes = 1
test_size = 1

# ---- Latent-space mapping: which kinds of samples to invert ----
map_anomalies = True
map_normals = True

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator mapping latent vectors to images in [-1, 1] (Tanh output).

    For a (B, size_z, 1, 1) input the spatial size grows 1 -> 4 -> 7 -> 14 -> 28.
    All layers live in ``self.network`` with the same indices as before, so saved
    state dicts (keys ``network.<idx>.*``) remain loadable.
    """

    def __init__(self, size_z, num_feature_maps, num_color_channels):
        super().__init__()
        self.size_z = size_z
        width = num_feature_maps

        layers = []
        # (in_channels, out_channels, kernel, stride, padding) of each ConvT-BN-ReLU stage.
        stages = (
            (size_z, width * 4, 4, 1, 0),
            (width * 4, width * 2, 3, 2, 1),
            (width * 2, width, 4, 2, 1),
        )
        for in_ch, out_ch, kernel, stride, pad in stages:
            layers.extend([
                nn.ConvTranspose2d(in_ch, out_ch, kernel, stride, pad, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ])
        # Output stage: project to image channels and squash to [-1, 1].
        layers.extend([
            nn.ConvTranspose2d(width, num_color_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        ])
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)

    def gen_shifted(self, x, shift):
        """Generate from ``x + shift``, where ``shift`` gains two trailing singleton dims first."""
        return self.forward(x + shift[..., None, None])

In [None]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator.

    ``forward`` returns ``(prob, feature)``. ``prob`` is the flattened sigmoid output
    (one value per sample for 28x28 inputs). ``feature`` is the activation of
    ``layer3``, taken before the final convolution. The attribute names
    layer1/layer2/layer3/fc are kept for state-dict compatibility.
    """

    def __init__(self, num_feature_maps, num_color_channels):
        super().__init__()
        w = num_feature_maps
        self.layer1 = nn.Sequential(
            nn.Conv2d(num_color_channels, w, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.layer2 = self._conv_bn_lrelu(w, w * 2, kernel=4)
        self.layer3 = self._conv_bn_lrelu(w * 2, w * 4, kernel=3)
        self.fc = nn.Sequential(
            nn.Conv2d(w * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    @staticmethod
    def _conv_bn_lrelu(in_ch, out_ch, kernel):
        # Stride-2, padding-1 downsampling stage: Conv -> BatchNorm -> LeakyReLU(0.2).
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel, 2, 1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        feature = self.layer3(self.layer2(self.layer1(x)))
        prob = self.fc(feature)
        return prob.view(-1), feature

Generate and/or load Datasets

In [None]:
class AnoClassMNIST(Dataset):
    """Image dataset indexed by ``<root_dir>/AnoClassMNIST/ano_class_mnist_dataset.csv``.

    The CSV columns are ``filename, label, anomaly`` (the layout written by
    ``generate_ano_class_mnist_dataset``). Each item is
    ``(image, {"label": ..., "anomaly": ...})``, with ``image`` passed through
    ``transform`` when one is given.

    Raises:
        AssertionError: if the CSV file does not exist under ``root_dir``.
    """

    def __init__(self, root_dir, transform=None):
        root_dir = os.path.join(root_dir, "AnoClassMNIST")
        csv_file = os.path.join(root_dir, "ano_class_mnist_dataset.csv")
        assert os.path.exists(csv_file), "Invalid root directory"
        self.root_dir = root_dir
        self.transform = transform
        # Column 0 = filename, 1 = digit label, 2 = anomaly flag.
        self.label = pd.read_csv(csv_file)

    def __len__(self):
        return len(self.label)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.label.iloc[idx, 0])
        image_label = {"label": self.label.iloc[idx, 1], "anomaly": self.label.iloc[idx, 2]}
        # Image.open() is lazy and keeps the file handle open until the pixels are loaded.
        # Copy the pixel data and close the file immediately, whether or not a transform runs.
        with Image.open(img_name) as img:
            image = img.copy()

        if self.transform:
            image = self.transform(image)

        return image, image_label


class AnomalyExtendedMNIST(datasets.MNIST):
    """torchvision MNIST whose target is wrapped as ``{"label": <target>, "anomaly": False}``.

    This gives plain MNIST samples the same target layout as ``AnoClassMNIST``.
    """

    def __getitem__(self, idx):
        # Fetch the sample once. Calling the parent twice decoded (and transformed)
        # every image two times just to split off the target.
        image, target = super().__getitem__(idx)
        return image, {"label": target, "anomaly": False}

In [None]:
def generate_ano_class_mnist_dataset(root_dir, norm_class=9, ano_class=6, ano_fraction=0.2, copy_zip_to='/content/drive/MyDrive/Colab/data'):
    """Build the AnoClassMNIST image folder from the MNIST training split.

    All training images of ``norm_class`` are written as normal samples, followed by
    the first ``round(n * ano_fraction)`` training images of ``ano_class`` as anomalies.
    Images are stored as ``<root_dir>/AnoClassMNIST/img_<class>_<i>.png``. They are
    indexed in ``ano_class_mnist_dataset.csv`` with columns filename, label, anomaly.

    Args:
        root_dir: working directory. WARNING: it is deleted first if it already exists.
        norm_class: MNIST digit treated as normal.
        ano_class: MNIST digit treated as anomalous.
        ano_fraction: fraction of the ``ano_class`` training images to keep.
        copy_zip_to: if truthy, directory in which ``AnoClassMNIST.zip`` is written.
    """
    if os.path.exists(root_dir):
        shutil.rmtree(root_dir)

    drop_folder = os.path.join(root_dir, "AnoClassMNIST")
    csv_path = os.path.join(drop_folder, "ano_class_mnist_dataset.csv")
    os.makedirs(drop_folder, exist_ok=True)  # also creates root_dir

    mnist_dataset = MNIST(
        root=root_dir,
        train=True,
        download=True,
    )

    # Select samples by their target instead of decoding all 60k images once per class.
    targets = mnist_dataset.targets.tolist()
    norm_indices = [idx for idx, target in enumerate(targets) if target == norm_class]
    ano_indices = [idx for idx, target in enumerate(targets) if target == ano_class]
    ano_indices = ano_indices[:round(len(ano_indices) * ano_fraction)]

    # Keep one handle open for the whole index instead of reopening the CSV per image.
    with open(csv_path, 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["filename", "label", "anomaly"])
        _write_class_images(mnist_dataset, norm_indices, norm_class, "False", drop_folder, writer)
        _write_class_images(mnist_dataset, ano_indices, ano_class, "True", drop_folder, writer)

    if copy_zip_to:
        shutil.make_archive(os.path.join(copy_zip_to, "AnoClassMNIST"), 'zip', drop_folder)


def _write_class_images(dataset, indices, digit, anomaly_flag, folder, writer):
    """Save ``dataset[idx]`` images as ``img_<digit>_<i>.png`` and append one CSV row each."""
    for i, idx in enumerate(indices):
        filename = f"img_{digit}_{i}.png"
        dataset[idx][0].save(os.path.join(folder, filename))
        writer.writerow([filename, f"{digit}", anomaly_flag])


def get_ano_class_mnist_dataset(root_dir, train_size=0.9, batch_size=256):
    """Return a shuffling DataLoader over the AnoClassMNIST dataset under ``root_dir``.

    Images are converted with ``ToTensor`` and then normalised with mean 0.5 and
    std 0.5, which maps [0, 1] pixel values to [-1, 1].

    Note: ``train_size`` is accepted but not used; no train/test split is made.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(.5,), std=(.5,)),
    ])
    dataset = AnoClassMNIST(root_dir=root_dir, transform=preprocess)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Rebuild the 9-vs-6 anomaly dataset locally and store a zipped copy under the Drive data folder.
generate_ano_class_mnist_dataset(
    '/content/data',
    norm_class=9,
    ano_class=6,
    ano_fraction=0.2,
    copy_zip_to='/content/drive/MyDrive/Colab/data',
)

In [None]:
def load_ano_mnist_from_drive(drop_folder, zip_path='/content/drive/MyDrive/Colab/data/AnoMNIST.zip'):
    """Extract the AnoMNIST archive into ``drop_folder``.

    The destination is not cleared first; extracted files are added to whatever
    it already contains.

    Args:
        drop_folder: destination directory.
        zip_path: archive to extract (defaults to the Google Drive location used so far).
    """
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(drop_folder)

def load_ano_class_mnist_from_drive(drop_folder, zip_path='/content/drive/MyDrive/Colab/data/AnoClassMNIST.zip'):
    """Replace the contents of ``drop_folder`` with the AnoClassMNIST archive.

    The archive is opened before anything is deleted. If ``zip_path`` is missing or
    unreadable, the exception propagates and the existing ``drop_folder`` is left intact.

    Args:
        drop_folder: destination directory; removed and recreated before extraction.
        zip_path: archive to extract (defaults to the Google Drive location used so far).
    """
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        if os.path.exists(drop_folder):
            shutil.rmtree(drop_folder)
        os.makedirs(drop_folder, exist_ok=True)
        zip_ref.extractall(drop_folder)

GAN training

In [None]:
# Build fresh (untrained) networks and move them to the selected device.
generator = Generator(
    size_z=size_z,
    num_feature_maps=num_feature_maps_g,
    num_color_channels=num_color_channels,
).to(device)
discriminator = Discriminator(
    num_feature_maps=num_feature_maps_d,
    num_color_channels=num_color_channels,
).to(device)

In [None]:
# Restore the dataset archive from Drive, then wrap it in a shuffling DataLoader.
load_ano_class_mnist_from_drive(drop_folder='/content/data/AnoClassMNIST')
ano_class_mnist_dataset = get_ano_class_mnist_dataset(
    root_dir='/content/data',
    batch_size=batch_size,
)

In [None]:
# Folder on the mounted Drive that holds all model weights. These helpers previously
# wrote to '/content/drive/My Drive/...', while the rest of the notebook reads weights
# from '/content/drive/MyDrive/...'; use a single location everywhere.
MODELS_DIR = '/content/drive/MyDrive/Colab/saved_models'


def _model_path(filename, create_dir=False):
    """Return ``MODELS_DIR/filename``, optionally creating MODELS_DIR first (needed for saving)."""
    if create_dir:
        # torch.save does not create missing parent directories.
        os.makedirs(MODELS_DIR, exist_ok=True)
    return os.path.join(MODELS_DIR, filename)


def save_checkpoint(epoch):
    """Remount Drive and save both models' weights, tagged with ``epoch`` and a Unix timestamp.

    Uses the notebook-global ``generator`` and ``discriminator``.
    """
    print("Saving Checkpoint...")
    drive.mount('/content/drive', force_remount=True)
    timestamp = time.time()
    torch.save(generator.state_dict(), _model_path(f'generator_epoch_{epoch}_{timestamp}.pkl', create_dir=True))
    torch.save(discriminator.state_dict(), _model_path(f'discriminator_epoch_{epoch}_{timestamp}.pkl', create_dir=True))


def save_models():
    """Remount Drive and save both models as ``generator_latest.pkl`` / ``discriminator_latest.pkl``."""
    print("Saving Models...")
    drive.mount('/content/drive', force_remount=True)
    torch.save(generator.state_dict(), _model_path('generator_latest.pkl', create_dir=True))
    torch.save(discriminator.state_dict(), _model_path('discriminator_latest.pkl', create_dir=True))


def load_models():
    """Remount Drive and load ``generator.pkl`` / ``discriminator.pkl`` into the global models.

    NOTE(review): save_models() writes ``*_latest.pkl``, not these names. Presumably
    the files are renamed by hand; confirm before relying on this pairing.
    """
    print("Loading Models...")
    drive.mount('/content/drive', force_remount=True)
    map_location = torch.device(device)
    generator.load_state_dict(torch.load(_model_path("generator.pkl"), map_location=map_location))
    discriminator.load_state_dict(torch.load(_model_path('discriminator.pkl'), map_location=map_location))

In [None]:
criterion = nn.BCELoss()
# 64 fixed latent vectors used to snapshot the generator's progress during training.
fixed_noise = torch.randn(64, size_z, 1, 1, device=device)

# BCE targets for real and generated images.
real_label, fake_label = 1., 0.

adam_betas = (adam_beta1, 0.999)
optimizerG = optim.Adam(generator.parameters(), lr=learning_rate, betas=adam_betas)
optimizerD = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=adam_betas)

def train_gan(dataset):
    """Train the notebook-global ``generator`` / ``discriminator`` for ``num_epochs`` epochs.

    Each iteration takes one discriminator step on a real batch and a generated batch
    (BCE against ``real_label`` / ``fake_label``). It then takes one generator step
    that pushes D's output on the generated batch towards ``real_label``. The function
    relies on these globals: generator, discriminator, optimizerG, optimizerD,
    criterion, fixed_noise, real_label, fake_label, size_z, num_epochs, device.

    Args:
        dataset: an iterable with ``len()`` yielding ``(images, targets)`` batches
            (here the DataLoader from ``get_ano_class_mnist_dataset``); targets are ignored.

    Side effects: updates model and optimizer state in place, calls
    ``save_checkpoint`` every 50 epochs (epoch index > 0), and displays the loss
    curves plus an animation of the fixed-noise samples.
    """
    img_list = []
    G_losses = []
    D_losses = []
    iters = 0
    dataloader = dataset

    print("Starting Training Loop...")
    for epoch in range(num_epochs):
        for i, (real_images, _) in enumerate(dataloader):
            bs = real_images.shape[0]  # the last batch may be smaller than batch_size

            # ---- (1) Discriminator step: minimise BCE(D(x), 1) + BCE(D(G(z)), 0) ----
            discriminator.zero_grad()
            real_images = real_images.to(device)
            label = torch.full((bs,), real_label, dtype=torch.float, device=device)

            output, _ = discriminator(real_images)
            lossD_real = criterion(output, label)
            lossD_real.backward()
            D_x = output.mean().item()  # mean D(x) on real images

            noise = torch.randn(bs, size_z, 1, 1, device=device)
            fake_images = generator(noise)
            label.fill_(fake_label)
            # detach(): the discriminator step must not propagate gradients into G.
            output, _ = discriminator(fake_images.detach())
            lossD_fake = criterion(output, label)
            lossD_fake.backward()  # accumulates onto the real-batch gradients
            D_G_z1 = output.mean().item()  # mean D(G(z)) before the D update

            lossD = lossD_real + lossD_fake  # logging only; gradients were computed above
            optimizerD.step()

            # ---- (2) Generator step: minimise BCE(D(G(z)), 1) ----
            generator.zero_grad()
            label.fill_(real_label)  # G is rewarded when D labels its fakes as real
            output, _ = discriminator(fake_images)
            lossG = criterion(output, label)
            lossG.backward()
            D_G_z2 = output.mean().item()  # mean D(G(z)) after the D update
            optimizerG.step()

            G_losses.append(lossG.item())
            D_losses.append(lossD.item())

            # Snapshot G on the fixed noise every 500 iterations and on the very last batch.
            if (iters % 500 == 0) or ((epoch == num_epochs - 1) and (i == len(dataloader) - 1)):
                with torch.no_grad():
                    fake = generator(fixed_noise).detach().cpu()
                img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
            iters += 1

        print('[%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
              % (epoch+1, num_epochs, lossD.item(), lossG.item(), D_x, D_G_z1, D_G_z2))

        if epoch > 0 and epoch % 50 == 0:
            save_checkpoint(epoch)

    _plot_training_losses(G_losses, D_losses)
    _show_progress_animation(img_list)


def _plot_training_losses(G_losses, D_losses):
    """Plot the per-iteration generator and discriminator losses."""
    plt.figure(figsize=(10, 5))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(G_losses, label="G")
    plt.plot(D_losses, label="D")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()


def _show_progress_animation(img_list):
    """Render the fixed-noise snapshot grids as an HTML/JS animation in the cell output."""
    from IPython.display import display

    plt.rcParams['animation.embed_limit'] = 100
    fig = plt.figure(figsize=(8, 8))
    plt.axis("off")
    ims = [[plt.imshow(np.transpose(grid, (1, 2, 0)), animated=True)] for grid in img_list]
    ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
    # A bare HTML(...) expression inside a function is discarded; display it explicitly.
    display(HTML(ani.to_jshtml()))
    # Close the figure so the inline backend does not also render a static last frame.
    plt.close(fig)

In [None]:
# Train on the AnoClassMNIST loader, then persist the final weights to Drive.
train_gan(dataset=ano_class_mnist_dataset)
save_models()

Latent Space Exploration and generation of Direction-Matrix A

In [None]:
class LatentSpaceMapper:
    """Inverts the generator: searches for a latent ``z`` whose image ``G(z)`` matches a query.

    ``z`` is optimised with Adam on the L1 reconstruction error ``sum(|x - G(z)|)``.
    The discriminator is stored but only used by the disabled feature-matching
    variant described in ``__get_anomaly_score``.

    NOTE(review): the generator is used in whatever mode it is currently in. If it
    contains BatchNorm layers in train mode, every forward pass during mapping updates
    their running statistics. Consider ``generator.eval()`` if that is not intended.
    """

    def __init__(self, generator: "Generator", discriminator: "Discriminator", device):
        self.generator: "Generator" = generator
        self.generator.to(device)
        self.discriminator: "Discriminator" = discriminator
        self.discriminator.to(device)
        self.device = device

    def map_image_to_point_in_latent_space(self, image: torch.Tensor, batch_size=1, size_z=100, max_opt_iterations=30000, opt_threshold=140.0, plateu_threshold=3.0, check_every_n_iter=4000, learning_rate=0.4, print_every_n_iters=10000, ignore_rules_below_threshold=50, retry_after_n_iters=10000, immediate_retry_threshold=200):
        """Optimise a random latent vector until ``G(z)`` reconstructs ``image``.

        Args:
            image: a single image; a batch dim is added with ``unsqueeze(0)``
                (presumably (C, H, W) matching the generator output; TODO confirm).
            batch_size: number of latent vectors optimised jointly; the success
                threshold is scaled by it.
            size_z: latent dimension.
            max_opt_iterations: hard cap on optimisation steps.
            opt_threshold: stop successfully once ``loss < opt_threshold * batch_size``.
            plateu_threshold: at each check, request a retry if the loss moved less
                than this since the previous check.
            check_every_n_iter: interval of the plateau / too-high checks.
            learning_rate: Adam learning rate for ``z``.
            print_every_n_iters: progress print interval.
            ignore_rules_below_threshold: at iteration ``retry_after_n_iters`` a retry
                is requested only if the loss is above this value.
            retry_after_n_iters: iteration at which the rule above is applied.
            immediate_retry_threshold: at each check, request a retry if the loss is above this.

        Returns:
            A tuple ``(z, final_loss, retry)``:
            - ``z``: the optimised latent tensor, which still requires grad.
            - ``final_loss``: the float loss of the last evaluated step.
            - ``retry``: whether the caller should try again with a fresh random z.
            Exhausting ``max_opt_iterations`` does not by itself request a retry.
        """
        # The query never changes, so move it to the device once instead of every step.
        x_query = image.unsqueeze(0).to(self.device)
        z = torch.randn(batch_size, size_z, 1, 1, device=self.device, requires_grad=True)
        z_optimizer = torch.optim.Adam([z], lr=learning_rate)
        final_loss = 0
        latest_checkpoint_loss = 0
        retry = False  # defined even if the loop body never runs

        for i in range(max_opt_iterations):
            retry = False
            # Clear last step's gradient. Without this, z.grad is the running sum of all
            # past gradients rather than the gradient of the current loss.
            z_optimizer.zero_grad()
            loss = self.__get_anomaly_score(z, x_query)
            loss.backward()
            z_optimizer.step()
            loss_value = loss.item()
            final_loss = loss_value

            if i == 1:
                latest_checkpoint_loss = loss_value

            if loss_value < opt_threshold * batch_size:
                print(f"Iteration: {i} -- Reached Defined Optimum -- Final Loss: {loss_value}")
                break

            if (i % print_every_n_iters == 0 and i != 0) or (i == max_opt_iterations - 1):
                print(f"Iteration: {i} -- Current Loss: {loss_value} -- Current Learning-Rate: {z_optimizer.param_groups[0]['lr']}")

            if i % check_every_n_iter == 0 and i != 0:
                if abs(loss_value - latest_checkpoint_loss) < plateu_threshold:
                    print(f"Reached Plateu at Iteration {i} -- Loss: {loss_value}")
                    retry = True
                    break
                if loss_value > immediate_retry_threshold:
                    print(f"Loss at Iteration {i} too high -- Loss: {loss_value}")
                    retry = True
                    break
                latest_checkpoint_loss = loss_value

            if i == retry_after_n_iters and loss_value > ignore_rules_below_threshold:
                retry = True
                break

        return z, final_loss, retry

    def __get_anomaly_score(self, z, x_query):
        """Return the L1 reconstruction error ``sum(|x_query - G(z)|)`` over all elements.

        A feature-matching variant is currently disabled. With
        ``_, x_f = D(x_query)`` and ``_, g_f = D(G(z))``, it returned
        ``0.9 * loss_r + 0.1 * sum(|x_f - g_f|)``.
        """
        g_z = self.generator(z.to(self.device))
        return torch.sum(torch.abs(x_query - g_z))

In [None]:
# Restore the trained GAN weights from Drive onto the active device.
_map_location = torch.device(device)
generator.load_state_dict(
    torch.load("/content/drive/MyDrive/Colab/saved_models/generator.pkl", map_location=_map_location)
)
discriminator.load_state_dict(
    torch.load('/content/drive/MyDrive/Colab/saved_models/discriminator.pkl', map_location=_map_location)
)

In [None]:
def create_cp(iteration_number):
    """Remount Drive and zip the local latent-space mappings folder as checkpoint ``iteration_number``."""
    print("CREATING CHECKPOINT...")
    drive.mount('/content/drive', force_remount=True)
    shutil.make_archive(f"/content/drive/MyDrive/Colab/data/latent_space_mappings_cp/latent_space_mappings_cp{iteration_number}", 'zip', "/content/data/latent_space_mappings")


def save_to_drive(mapped_z, iteration_number, csv_path, drive_folder='/content/drive/MyDrive/Colab/data/latent_space_mappings'):
    """Copy one mapped latent vector and the current mapping index CSV to ``drive_folder``.

    Args:
        mapped_z: object saved with ``torch.save`` as ``mapped_z_<iteration_number>.pt``.
        iteration_number: suffix used in the file name.
        csv_path: local CSV, copied to ``<drive_folder>/latent_space_mappings.csv``.
        drive_folder: destination directory; created if missing, because neither
            ``torch.save`` nor ``shutil.copy`` creates parent directories.
    """
    os.makedirs(drive_folder, exist_ok=True)
    torch.save(mapped_z, os.path.join(drive_folder, f'mapped_z_{iteration_number}.pt'))
    shutil.copy(csv_path, os.path.join(drive_folder, "latent_space_mappings.csv"))

base_folder = "/content/data/latent_space_mappings"
csv_path = os.path.join(base_folder, "latent_space_mappings.csv")

# When enabled, start from an empty mappings folder that holds only the CSV header row.
prepare_target_folder = True
if prepare_target_folder:
    # Discard any previous run, then (re)create the folder.
    if os.path.exists(base_folder):
        shutil.rmtree(base_folder)
    os.makedirs(base_folder, exist_ok=True)

    header = ["filename", "label", "reconstruction_loss"]
    with open(csv_path, 'a', newline='') as file:
        csv.writer(file).writerow(header)

In [None]:
# Tensor -> PIL converter used to display images in the mapping loop.
t = transforms.ToPILImage()
lsm: LatentSpaceMapper = LatentSpaceMapper(generator=generator, discriminator=discriminator, device=device)
mapped_images = []
cp_counter = 0

# The mapping loop inverts ONE image at a time: it reads d[0][0] and d[1]['label'].item().
# ano_class_mnist_dataset, however, yields batches of `batch_size` (512) samples, so
# `.item()` would fail and only the first image of each batch would be visited.
# Iterate the same underlying dataset one sample at a time instead.
mapping_loader = DataLoader(ano_class_mnist_dataset.dataset, batch_size=1, shuffle=True)
counter = len(mapping_loader)  # number of samples still to process

i = 0
retry_counter = 0
iterator = iter(mapping_loader)
d = next(iterator)

In [None]:
# Inversion settings; they are constant across samples, so set them once.
max_retries = 5
opt_threshold = 60
ignore_rules_below_threshold = 75
immediate_retry_threshold = 110
max_opt_iterations = 20000


def _show_image(img_tensor):
    """Display a (C, H, W) image tensor with values in [-1, 1], upscaled to 128x128 (nearest)."""
    # Dataset images are Normalize(0.5, 0.5)-ed and generator outputs come from tanh, so
    # both are in [-1, 1]. ToPILImage scales floats by 255 into uint8, which cannot hold
    # the negative (background) pixels; map back to [0, 1] first.
    pixels = (img_tensor.detach().cpu() * 0.5 + 0.5).clamp(0, 1)
    display(t(pixels).resize((128, 128), PIL.Image.NEAREST))


while counter > 0:
    if i % 50 == 0 and i != 0:
        clear_output()
        print("Cleared Output...")

    print(f"{counter} images left")
    print(f"Label: {d[1]['label'].item()}")

    is_anomaly = bool(d[1]['anomaly'].item())
    if (is_anomaly and map_anomalies) or (not is_anomaly and map_normals):
        mapped_z, reconstruction_loss, retry = lsm.map_image_to_point_in_latent_space(d[0][0],
                                                                            batch_size=1,
                                                                            max_opt_iterations=max_opt_iterations,
                                                                            plateu_threshold=0.0005,
                                                                            check_every_n_iter=5000,
                                                                            learning_rate=0.01,
                                                                            print_every_n_iters=5000,
                                                                            retry_after_n_iters=30000,
                                                                            ignore_rules_below_threshold=ignore_rules_below_threshold,
                                                                            opt_threshold=opt_threshold,
                                                                            immediate_retry_threshold=immediate_retry_threshold)

        if retry:
            if retry_counter < max_retries:
                retry_counter += 1
                print(f"Could not find optimal region within the defined iteration count. Retry ({retry_counter}) with another random z...")
                continue  # same sample, fresh random z
            # Retry budget exhausted: give up on this sample and fall through to the next one.
            retry_counter = 0
            print("Retry Limit reached. Moving on to next sample")
            print('Original Image That Could Not Be Mapped')
            # Show the sample that failed. This previously displayed `original_img`, i.e. the
            # last successful reconstruction, or raised NameError if there was none yet.
            _show_image(d[0][0])
            print('-----------------------')
        else:
            retry_counter = 0
            mapped_images.append(mapped_z)
            with open(csv_path, 'a', newline='') as file:
                writer = csv.writer(file)
                writer.writerow([f'mapped_z_{counter}.pt', d[1]['label'].item(), math.floor(reconstruction_loss)])

            torch.save(mapped_z, os.path.join(base_folder, f'mapped_z_{counter}.pt'))
            save_to_drive(mapped_z, counter, csv_path)
            cp_counter += 1
            if cp_counter % 50 == 0:
                create_cp(counter)
                clear_output()  # previously a bare `clear_output` reference, which did nothing

            print('Original Image')
            _show_image(d[0][0])
            print('Mapped and Reconstructed Image')
            with torch.no_grad():  # display only; no autograd graph needed
                original_img = generator(mapped_z).cpu()
            _show_image(original_img[0])
            print('-----------------------')

    i += 1
    counter -= 1
    # Fetch the next sample only if one is left. Calling next() on the exhausted iterator
    # after the final sample raised StopIteration and skipped create_cp(0) below.
    if counter > 0:
        d = next(iterator)

create_cp(0)