<a href="https://colab.research.google.com/github/yasharha/AD-with-GANs/blob/add-anogan-img-to-latent-space-mapping/src/ml/latent_space_mapping.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
!cp -r /content/drive/MyDrive/Colab/data /content/data

%cd /content/drive/MyDrive/Colab
%ls

Mounted at /content/drive
/content/drive/MyDrive/Colab
[0m[01;34mAdGAN[0m/  [01;34mdata[0m/  [01;34msaved_models[0m/  [01;34msrc[0m/


In [2]:
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms, datasets

In [3]:
class Generator(nn.Module):
    def __init__(self, size_z, num_feature_maps, num_color_channels):
        super(Generator, self).__init__()
        self.size_z = size_z
        self.network = nn.Sequential(
            nn.ConvTranspose2d(self.size_z, num_feature_maps * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(num_feature_maps * 4),
            nn.ReLU(True),

            nn.ConvTranspose2d(num_feature_maps * 4, num_feature_maps * 2, 3, 2, 1, bias=False),
            nn.BatchNorm2d(num_feature_maps * 2),
            nn.ReLU(True),

            nn.ConvTranspose2d(num_feature_maps * 2, num_feature_maps, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_feature_maps),
            nn.ReLU(True),

            nn.ConvTranspose2d(num_feature_maps, num_color_channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        output = self.network(x)
        return output

    def gen_shifted(self, x, shift):
        shift = torch.unsqueeze(shift, -1)
        shift = torch.unsqueeze(shift, -1)
        return self.forward(x + shift)

In [4]:
class Discriminator(nn.Module):
    def __init__(self, num_feature_maps, num_color_channels):
        super(Discriminator, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(num_color_channels, num_feature_maps, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(num_feature_maps, num_feature_maps * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_feature_maps * 2),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(num_feature_maps * 2, num_feature_maps * 4, 3, 2, 1, bias=False),
            nn.BatchNorm2d(num_feature_maps * 4),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.fc = nn.Sequential(
            nn.Conv2d(num_feature_maps * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        feature = out
        out = self.fc(out)
        return out.view(-1, 1).squeeze(1), feature

In [5]:
class LatentSpaceMapper:

    def __init__(self, generator: Generator, discriminator: Discriminator, device):
        self.generator: Generator = generator
        self.discriminator: Discriminator = discriminator
        self.device = device

    def map_image_to_point_in_latent_space(self, image: torch.Tensor, size_z=100, opt_iterations=10000):
        z = torch.randn(1, size_z, 1, 1, device=self.device, requires_grad=True)
        z_optimizer = torch.optim.Adam([z], lr=1e-4)
        losses = []

        for i in range(opt_iterations):
            loss = self.__get_anomaly_score(z, image.unsqueeze(0).to(self.device))
            loss.backward()
            z_optimizer.step()
            if i % 1000 == 0:
                # print(f"Iteration: {i} -- Loss: {loss.data.item()}")
                losses.append(loss.data.item())

        return z

    def __get_anomaly_score(self, z, x_query):
        lamda = 0.1
        g_z = self.generator(z.to(self.device))
        _, x_prop = self.discriminator(x_query)
        _, g_z_prop = self.discriminator(g_z)

        loss_r = torch.sum(torch.abs(x_query - g_z))
        loss_d = torch.sum(torch.abs(x_prop - g_z_prop))

        return (1 - lamda) * loss_r + lamda * loss_d

In [6]:
from torch.utils.data import Dataset
import pandas as pd
import os


class AnoMNIST(Dataset):
    def __init__(self, root_dir, transform=None):
        root_dir = os.path.join(root_dir, "AnoMNIST")
        assert os.path.exists(os.path.join(root_dir, "anomnist_dataset.csv")), "Invalid root directory"
        self.root_dir = root_dir
        self.transform = transform
        self.label = pd.read_csv(os.path.join(root_dir, "anomnist_dataset.csv"))

    def __len__(self):
        return len(self.label)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.label.iloc[idx, 0])
        image_label = self.label.iloc[idx, 1]
        image = Image.open(img_name)

        if self.transform:
            image = self.transform(image)

        return image, image_label

In [7]:
def get_ano_mnist_dataset(transform, root_dir, labels=[], train_size=0.9):
    ano_mnist_dataset = AnoMNIST(
        root_dir=root_dir,
        transform=transform
    )

    mnist_dataset = datasets.MNIST(
        root=root_dir,
        train=True,
        transform=transform,
        download=True,
    )

    dat = torch.utils.data.ConcatDataset([ano_mnist_dataset, mnist_dataset])

    if len(labels) > 0:
        dat = [d for d in dat if (d[1] in labels)]

    absolute_train_size = int(len(dat) * train_size)
    absolute_test_size = len(dat) - absolute_train_size
    return torch.utils.data.random_split(dat, [absolute_train_size, absolute_test_size])

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_color_channels = 1
num_feature_maps_g = 64
num_feature_maps_d = 64
size_z = 100

device

device(type='cuda')

In [11]:
generator = Generator(size_z=size_z,
                      num_feature_maps=num_feature_maps_g,
                      num_color_channels=num_color_channels).to(device)
discriminator = Discriminator(num_feature_maps=num_feature_maps_d,
                              num_color_channels=num_color_channels).to(device)

generator.load_state_dict(torch.load("./saved_models/generator.pkl", map_location=torch.device(device)))
discriminator.load_state_dict(torch.load('./saved_models/discriminator.pkl', map_location=torch.device(device)))

<All keys matched successfully>

In [12]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(.5,), std=(.5,))
])
ano_mnist_dataset, ano_mnist_dataset_test = get_ano_mnist_dataset(transform=transform, root_dir="./data", labels=[9])

In [None]:
# tpi = transforms.ToPILImage()
# test_img = ano_mnist_dataset[2][0]
# img = tpi(torch.squeeze(test_img))
# img.show()

In [None]:
lsm: LatentSpaceMapper = LatentSpaceMapper(generator=generator, discriminator=discriminator, device=device)
mapped_images = []
counter = len(ano_mnist_dataset)
for img in ano_mnist_dataset:
    print(f"{counter} images left")
    mapped_z = lsm.map_image_to_point_in_latent_space(img[0], opt_iterations=7000)
    mapped_images.append(mapped_z)
    torch.save(mapped_z, f'./latent_space_mappings/mapped_z_{counter}.pt')
    counter-=1

5758 images left
5757 images left
5756 images left
5755 images left
5754 images left
5753 images left
5752 images left
5751 images left
5750 images left
5749 images left
5748 images left
5747 images left
5746 images left
5745 images left
5744 images left
5743 images left
5742 images left
5741 images left
5740 images left
5739 images left
5738 images left
5737 images left
5736 images left
5735 images left
5734 images left
5733 images left
5732 images left
5731 images left
5730 images left
5729 images left
5728 images left
5727 images left
5726 images left
5725 images left
5724 images left
5723 images left
5722 images left
5721 images left
5720 images left
5719 images left
5718 images left
5717 images left
5716 images left
5715 images left
5714 images left
5713 images left
5712 images left
5711 images left
5710 images left
5709 images left
5708 images left
5707 images left
5706 images left
5705 images left
5704 images left
5703 images left
5702 images left
5701 images left
5700 images le

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-16-30d1d71889fc>", line 8, in <cell line: 4>
    torch.save(mapped_z, f'./latent_space_mappings/mapped_z_{counter}.pt')
  File "/usr/local/lib/python3.10/dist-packages/torch/serialization.py", line 440, in save
    with _open_zipfile_writer(f) as opened_zipfile:
  File "/usr/local/lib/python3.10/dist-packages/torch/serialization.py", line 315, in _open_zipfile_writer
    return container(name_or_buffer)
  File "/usr/local/lib/python3.10/dist-packages/torch/serialization.py", line 288, in __init__
    super().__init__(torch._C.PyTorchFileWriter(str(name)))
RuntimeError: Parent directory ./latent_space_mappings does not exist.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packa

In [None]:
# img = generator(mapped_z)
# img = tpi(torch.squeeze(img))
# img.show()