# Import Libraries

In [None]:
import os
import zipfile
import glob

import numpy as np
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

from PIL import Image
from tqdm.notebook import tqdm
from dataclasses import dataclass

## Extract Data

In [None]:
# Walk the Kaggle input mount and print the full path of every file found.
for folder, _subfolders, file_names in os.walk('/kaggle/input'):
    for file_name in file_names:
        print(os.path.join(folder, file_name))

In [None]:
# Unpack both archives into the same scratch directory, images archive first
# and annotation archive second.
for archive_path in ("../input/generative-dog-images/all-dogs.zip",
                     "../input/generative-dog-images/Annotation.zip"):
    with zipfile.ZipFile(archive_path, "r") as archive:
        archive.extractall("/kaggle/temp/")

## Listing Extracted Data

In [None]:
!ls /kaggle/temp/Annotation/

In [None]:
!ls /kaggle/temp/all-dogs

# Initializing Constants

In [None]:
# Scratch directory the archives were extracted into.
ROOT = '/kaggle/temp/'

# Sub-directories for the annotation tree and for the image files.
ANNOT_PATH, IMAGE_PATH = ROOT + 'Annotation/', ROOT + 'all-dogs/'

# Raw directory listings of both locations.
annotations, images = os.listdir(ANNOT_PATH), os.listdir(IMAGE_PATH)

In [None]:
@dataclass
class TrainConfig:
    """Hyper-parameters and runtime settings shared by the notebook cells."""

    # Data loading and training schedule.
    num_workers: int = 4
    epochs: int = 50
    batch_size: int = 64
    generate_size: int = 10_000  # not referenced by any other visible cell
    save_epoch: int = 5  # the training loop plots samples every `save_epoch` epochs
        
    # Scalar mean / std handed to transforms.Normalize.
    mean: float = 0.5
    std: float = 0.5
        
    # Model dimensions.
    num_channels: int = 3
    image_size: int = 64  # Generator and Discriminator both assert this equals 64
    feature_size: int = 64
    noise_size: int = 100
    embedding_dim: int = 256
    attention: bool = True

    # Default is computed once, when the class body runs: CUDA if available,
    # otherwise CPU. The annotation is a string; `typing` is not imported by
    # the notebook's import cell.
    device: 'typing.Any' = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Single configuration instance built entirely from the dataclass defaults.
train_config = TrainConfig()

# Preprocessing

Adapted with help from https://www.kaggle.com/tikutiku/gan-dogs-starter-biggan

In [None]:
# Report the size of both directory listings, one line each.
print(
    f"Number of breeds: {len(annotations)}",
    f"Number of images: {len(images)}",
    sep="\n",
)

In [None]:
# Gather every per-image annotation file into its own list.
#
# `annotations` must keep holding only the breed directory names: it is the
# sequence being iterated here and it later drives `breed_map` and
# `n_classes`, so it is no longer extended while being looped over.
# The glob pattern is anchored at ANNOT_PATH; the previous relative pattern
# was resolved against the current working directory.
annotation_files = []
for breed in annotations:
    annotation_files += glob.glob(os.path.join(ANNOT_PATH, breed, '*'))
print(f"Number of available annotations: {len(annotation_files)}")

In [None]:
# Map the text before the first "-" of each annotation directory name to the
# full "<prefix>-<rest>" name. Later entries overwrite earlier ones that
# share a prefix.
breed_map = {}
for entry in annotations:
    prefix, _, remainder = entry.partition("-")
    breed_map[prefix] = prefix + "-" + remainder
print(f"Number of breeds in breed_map: {len(breed_map)}")

In [None]:
# Show the prefix -> directory-name mapping as the cell output.
breed_map

In [None]:
def bounding_box(annot_path, image):
    """Return every bounding box annotated for *image*.

    The annotation file is located at
    ``annot_path + breed_map[<prefix>] + "/" + <image name up to the first ".">``,
    where ``<prefix>`` is the part of *image* before its first ``"_"``.

    Returns:
        A list with one ``(xmin, ymin, xmax, ymax)`` tuple of ints per
        ``<object>`` element, in document order.
    """
    breed_dir = str(breed_map[image.split("_")[0]])
    xml_path = annot_path + breed_dir + "/" + str(image.split(".")[0])
    annotation_root = ET.parse(xml_path).getroot()

    corner_tags = ("xmin", "ymin", "xmax", "ymax")
    return [
        tuple(int(obj.find("bndbox").find(tag).text) for tag in corner_tags)
        for obj in annotation_root.findall("object")
    ]

In [None]:
def bounding_box_ratio(annot_path, image):
    """Return the size and aspect ratio of every bounding box of *image*.

    The annotation file is located at
    ``annot_path + breed_map[<prefix>] + "/" + <image name up to the first ".">``,
    where ``<prefix>`` is the part of *image* before its first ``"_"``.

    Returns:
        A list with one ``(x_len, y_len, ratio)`` tuple per ``<object>``
        element, in document order, where ``ratio = y_len / x_len``.
        A zero-width box yields ``ratio = float("inf")`` instead of raising
        ``ZeroDivisionError``.
    """
    bounding_path = annot_path + str(breed_map[image.split("_")[0]]) + "/" + str(image.split(".")[0])
    root = ET.parse(bounding_path).getroot()

    bbox_ratios = []
    for obj in root.findall("object"):
        bound_box = obj.find("bndbox")
        x_len = int(bound_box.find("xmax").text) - int(bound_box.find("xmin").text)
        y_len = int(bound_box.find("ymax").text) - int(bound_box.find("ymin").text)
        # The box must stay in the list: callers use the position in this list
        # as the bounding-box index, so a degenerate box cannot be dropped.
        ratio = y_len / x_len if x_len else float("inf")
        bbox_ratios.append((x_len, y_len, ratio))
    return bbox_ratios

In [None]:
%%time

# Keep only bounding boxes whose height/width ratio lies strictly between 0.2
# and 4.0. Each kept box becomes "<image name without its last 4 chars>_<box index>.jpg",
# so the entry also records which box of the image it refers to.
images_th = []

for image in tqdm(images):
    ratios = (ratio for _, _, ratio in bounding_box_ratio(ANNOT_PATH, image))
    images_th.extend(
        image[:-4] + '_' + str(box_idx) + '.jpg'
        for box_idx, ratio in enumerate(ratios)
        if 0.2 < ratio < 4.0
    )

print(f"Original Length: {len(images)}")
print(f"After Thresholding Length: {len(images_th)}")

In [None]:
# Hand-picked "<image>_<bbox index>.jpg" entries to exclude from training.
# Entries use the same naming scheme as `images_th`; DogDataset skips any
# entry found in this list.
# Source: https://www.kaggle.com/korovai/dogs-images-intruders-extraction
intruders = [
    #n02088238-basset
    'n02088238_10870_0.jpg',
    
    #n02088466-bloodhound
    'n02088466_6901_1.jpg',
    'n02088466_6963_0.jpg',
    'n02088466_9167_0.jpg',
    'n02088466_9167_1.jpg',
    'n02088466_9167_2.jpg',
    
    #n02089867-Walker_hound
    'n02089867_2221_0.jpg',
    'n02089867_2227_1.jpg',
    
    #n02089973-English_foxhound # No details
    'n02089973_1132_3.jpg',
    'n02089973_1352_3.jpg',
    'n02089973_1458_1.jpg',
    'n02089973_1799_2.jpg',
    'n02089973_2791_3.jpg',
    'n02089973_4055_0.jpg',
    'n02089973_4185_1.jpg',
    'n02089973_4185_2.jpg',
    
    #n02090379-redbone
    'n02090379_4673_1.jpg',
    'n02090379_4875_1.jpg',
    
    #n02090622-borzoi # Confusing
    'n02090622_7705_1.jpg',
    'n02090622_9358_1.jpg',
    'n02090622_9883_1.jpg',
    
    #n02090721-Irish_wolfhound # very small
    'n02090721_209_1.jpg',
    'n02090721_1222_1.jpg',
    'n02090721_1534_1.jpg',
    'n02090721_1835_1.jpg',
    'n02090721_3999_1.jpg',
    'n02090721_4089_1.jpg',
    'n02090721_4276_2.jpg',
    
    #n02091032-Italian_greyhound
    'n02091032_722_1.jpg',
    'n02091032_745_1.jpg',
    'n02091032_1773_0.jpg',
    'n02091032_9592_0.jpg',
    
    #n02091134-whippet
    'n02091134_2349_1.jpg',
    'n02091134_14246_2.jpg',
    
    #n02091244-Ibizan_hound
    'n02091244_583_1.jpg',
    'n02091244_2407_0.jpg',
    'n02091244_3438_1.jpg',
    'n02091244_5639_1.jpg',
    'n02091244_5639_2.jpg',
    
    #n02091467-Norwegian_elkhound
    'n02091467_473_0.jpg',
    'n02091467_4386_1.jpg',
    'n02091467_4427_1.jpg',
    'n02091467_4558_1.jpg',
    'n02091467_4560_1.jpg',
    
    #n02091635-otterhound
    'n02091635_1192_1.jpg',
    'n02091635_4422_0.jpg',
    
    #n02091831-Saluki
    'n02091831_1594_1.jpg',
    'n02091831_2880_0.jpg',
    'n02091831_7237_1.jpg',
    
    #n02092002-Scottish_deerhound
    'n02092002_1551_1.jpg',
    'n02092002_1937_1.jpg',
    'n02092002_4218_0.jpg',
    'n02092002_4596_0.jpg',
    'n02092002_5246_1.jpg',
    'n02092002_6518_0.jpg',
    
    #n02093256-Staffordshire_bullterrier
    'n02093256_1826_1.jpg',
    'n02093256_4997_0.jpg',
    'n02093256_14914_0.jpg',
    
    #n02093428-American_Staffordshire_terrier
    'n02093428_5662_0.jpg',
    'n02093428_6949_1.jpg'
            ]

print(f"Number of intruders: {len(intruders)}")

In [None]:
class DogDataset(torch.utils.data.Dataset):
    """In-memory dataset with one sample per annotated bounding box.

    Every entry of ``image_list`` is expected to look like
    ``"<image name>_<bbox index>.jpg"``. During construction each entry is
    cropped to the referenced bounding box, optionally transformed, and kept
    in memory together with its integer label.

    Args:
        annot_path: Root directory of the annotation tree (passed to
            ``bounding_box``).
        image_path: Directory containing the image files.
        image_list: Iterable of ``"<image name>_<bbox index>.jpg"`` entries.
        label_map: Mapping from the part of the image name before its first
            ``"_"`` to an integer label.
        transform: Optional callable applied to every cropped PIL image.
        intruders: Optional iterable of ``image_list`` entries to skip.
            ``None`` (the default) skips nothing.
    """

    def __init__(self, annot_path, image_path, image_list, label_map, transform=None, intruders=None):
        self.ANNOT_PATH = annot_path
        self.IMAGE_PATH = image_path

        self.image_list = image_list
        self.transform = transform

        self.images = []
        self.labels = []

        # ``None`` replaces a shared mutable ``[]`` default; the frozenset
        # makes each membership test O(1) instead of a list scan.
        excluded = frozenset(intruders) if intruders is not None else frozenset()

        for entry in self.image_list:
            if entry in excluded:
                continue

            # "<image name>_<bbox index>.jpg" -> image name parts + bbox index.
            *name_parts, bbox_part = entry.split("_")
            file_name = "_".join(name_parts) + ".jpg"
            image = self._data_preprocessing(file_name, int(bbox_part.split(".")[0]))

            if self.transform:
                image = self.transform(image)

            self.images.append(image)
            self.labels.append(label_map[name_parts[0]])

    def _data_preprocessing(self, image_path, bbox_idx):
        """Open *image_path* and return the crop of bounding box *bbox_idx*.

        The image is converted to RGB first so every sample has three
        channels, matching ``num_channels = 3`` used to build the models and
        keeping tensors stackable into a batch.
        """
        bbox = bounding_box(self.ANNOT_PATH, image_path)[bbox_idx]
        # The context manager closes the file handle once the crop is made.
        with Image.open(os.path.join(self.IMAGE_PATH, image_path)) as image:
            return image.convert("RGB").crop(bbox)

    def __len__(self):
        """Return the number of samples held in memory."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``{"images": <image>, "labels": <label>}`` for sample *idx*."""
        return {"images": self.images[idx], "labels": self.labels[idx]}

In [None]:
%%time

# One class per entry of the annotation directory listing.
n_classes = len(annotations)

# Consecutive integer labels for the breed prefixes, in breed_map order.
label_map = dict(zip(breed_map, range(len(breed_map))))

# Resize to a square, convert to a tensor, then normalise with the configured
# scalar mean / std.
transform = transforms.Compose([
    transforms.Resize((train_config.image_size,) * 2),
    transforms.ToTensor(),
    transforms.Normalize(mean=train_config.mean, std=train_config.std),
])

# Builds the whole dataset up front: every thresholded entry is cropped and
# transformed inside the constructor.
train_set = DogDataset(
    ANNOT_PATH,
    IMAGE_PATH,
    images_th,
    label_map,
    transform=transform,
    intruders=intruders,
)

train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=train_config.batch_size,
    shuffle=True,
    num_workers=train_config.num_workers,
    pin_memory=True,
)

In [None]:
def show_grid(image):
    """Display *image* with matplotlib.

    *image* must offer ``.numpy()``; axis 0 is moved to the end before
    plotting (axes reordered as ``(1, 2, 0)``).
    """
    channels_last = image.numpy().transpose(1, 2, 0)
    plt.imshow(channels_last)
    plt.show()

In [None]:
# Preview a single batch: show the images as a grid and print their labels.
iter_loader = iter(train_loader)
# Use the built-in next(): a `.next()` method is not part of the Python 3
# iterator protocol, so calling it is not portable across PyTorch versions.
data = next(iter_loader)
show_grid(torchvision.utils.make_grid(data["images"], normalize=True))
print(data["labels"])

# Creating Model

## Helper Functions

In [None]:
def initialize_weights(m):
    """Initialise ``nn.Linear`` modules in place; leave all others untouched.

    Intended for ``module.apply(initialize_weights)``. Linear weights get
    Xavier-uniform values and the bias, when the layer has one, is zeroed.

    Args:
        m: Any ``nn.Module``; only ``nn.Linear`` instances are modified.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight.data)
    # Layers built with ``bias=False`` have ``m.bias is None``.
    if m.bias is not None:
        m.bias.data.fill_(0)

In [None]:
def l2_normalize(v, eps=1e-12):
    """Return *v* divided by its norm.

    ``eps`` is added to the norm before dividing, so an all-zero input gives
    zeros rather than NaN.
    """
    magnitude = v.norm()
    return v / (magnitude + eps)

## Self Attention Layer

In [None]:
class SelfAttentionLayer(nn.Module):
    """Self-attention over the spatial positions of a 4-D feature map.

    Queries and keys come from 1x1 convolutions with ``in_dim // 2`` output
    channels, values from a 1x1 convolution with ``in_dim`` channels. The
    attended values are scaled by the learnable scalar ``gamma`` (initialised
    to zero) and added to the input.
    """

    def __init__(self, in_dim):
        super().__init__()

        half_dim = in_dim // 2
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=half_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=half_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)

        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return ``(gamma * attended + x, attention)``.

        ``attention`` has shape ``(batch, N, N)`` with ``N`` the product of
        the last two dimensions of *x*; each row is a softmax distribution.
        """
        batch_size, C, width, height = x.size()
        positions = width * height

        # (batch, N, C // 2) @ (batch, C // 2, N) -> (batch, N, N)
        queries = self.query_conv(x).view(batch_size, -1, positions).transpose(1, 2)
        keys = self.key_conv(x).view(batch_size, -1, positions)
        attn = self.softmax(torch.bmm(queries, keys))

        # (batch, C, N) @ (batch, N, N) -> (batch, C, N), then back to 4-D.
        values = self.value_conv(x).view(batch_size, -1, positions)
        attended = torch.bmm(values, attn.transpose(1, 2)).view(batch_size, C, width, height)

        return self.gamma * attended + x, attn

## Spectral Normalization Layer

In [None]:
class SpectralNormLayer(nn.Module):
    """Wrap a module and rescale one of its parameters on every forward pass.

    The wrapped parameter (``name``, default ``"weight"``) is replaced by
    three parameters on the wrapped module: ``<name>_bar`` (the trainable raw
    weight) and ``<name>_u`` / ``<name>_v`` (non-trainable vectors used by the
    power iteration). Before each forward call, ``<name>`` is recomputed as
    ``<name>_bar / sigma``, where ``sigma = u . (W v)`` and ``W`` is
    ``<name>_bar`` viewed as a 2-D matrix.

    Args:
        module: The module whose parameter is normalised.
        name: Name of the parameter on ``module`` to normalise.
        power_iterations: Number of u/v update rounds per forward call.
    """

    def __init__(self, module, name="weight", power_iterations=1):
        super().__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        
        # Only create the helper parameters if the module lacks them.
        if not self._made_params():
            self._make_params()
            
    def _update_u_v(self):
        """Refresh u and v, then set ``module.<name>`` to the rescaled weight."""
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")
        
        height = w.data.shape[0]
        # Power iteration on the (height, -1) view of the weight. The updates
        # go through `.data`, so they are not recorded by autograd.
        for _ in range(self.power_iterations):
            v.data = l2_normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
            u.data = l2_normalize(torch.mv(w.view(height, -1).data, v.data))
            
        # sigma is computed from `w` itself (not `.data`), so the division
        # below stays differentiable with respect to `<name>_bar`.
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))
        
    def _made_params(self):
        """Return True if ``<name>_u``, ``<name>_v`` and ``<name>_bar`` all exist."""
        try:
            # The values are unused; the lookups only probe for existence.
            u = getattr(self.module, self.name + "_u")
            v = getattr(self.module, self.name + "_v")
            w = getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False
    
    def _make_params(self):
        """Replace ``module.<name>`` with the ``_u`` / ``_v`` / ``_bar`` parameters."""
        w = getattr(self.module, self.name)
        
        # Dimensions of the weight when viewed as a 2-D (height, width) matrix.
        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]
        
        # u and v start as normalised Gaussian samples and never receive gradients.
        u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad = False)
        v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad = False)
        u.data = l2_normalize(u.data)
        v.data = l2_normalize(v.data)
        w_bar = nn.Parameter(w.data)
        
        # Drop the original parameter entry; `_update_u_v` later assigns a
        # tensor under the same attribute name on every forward call.
        del self.module._parameters[self.name]
        
        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)
        
    def forward(self, *args):
        """Recompute the normalised weight, then run the wrapped module."""
        self._update_u_v()
        return self.module.forward(*args)

## Generator

In [None]:
class Generator(nn.Module):
    """Class-conditional generator producing 64x64 images from noise.

    The label is embedded, projected to ``noise_size`` and multiplied
    element-wise with the noise vector. The product is reshaped to a 1x1
    feature map and upsampled by transposed convolutions
    (1 -> 4 -> 8 -> 16 -> 32 -> 64), with optional self-attention after the
    16x16 and 32x32 stages.

    Args:
        noise_size: Length of the input noise vector.
        embedding_dim: Size of the label embedding.
        num_classes: Number of distinct labels.
        num_channels: Number of channels of the generated image.
        image_size: Output side length; must be 64.
        attn: Whether ``forward`` applies the self-attention layers. The
            layers are always created, so the parameter set does not depend
            on this flag.
        feature_size: Base channel width of the convolutional stack.

    The module also creates its own Adam optimiser, exposed as
    ``self.optimizer``.
    """

    def __init__(self, noise_size, embedding_dim, num_classes, num_channels, image_size, attn = True, feature_size = 64):
        super().__init__()
        self.attn = attn
        self.noise_dim = noise_size

        assert image_size == 64, "Cannot produce images other than size 64x64!"
        self.image_size = image_size

        # Embedding Layer
        self.label_embedding = nn.Embedding(num_classes, embedding_dim)

        # Linear Layer: projects the label embedding to the noise dimension.
        self.linear_layer = nn.Sequential(
            nn.Linear(embedding_dim, noise_size),
            nn.BatchNorm1d(noise_size),
            nn.ReLU()
        )

        # Size 1 -> 4
        self.layer_1 = nn.Sequential(
            SpectralNormLayer(nn.ConvTranspose2d(noise_size, feature_size * 8, kernel_size = 4)),
            nn.BatchNorm2d(feature_size * 8),
            nn.ReLU()
        )

        # Size 4 -> 8
        self.layer_2 = nn.Sequential(
            SpectralNormLayer(nn.ConvTranspose2d(feature_size * 8, feature_size * 4, kernel_size = 4, stride = 2, padding = 1)),
            nn.BatchNorm2d(feature_size * 4),
            nn.ReLU()
        )

        # Size 8 -> 16
        self.layer_3 = nn.Sequential(
            SpectralNormLayer(nn.ConvTranspose2d(feature_size * 4, feature_size * 2, kernel_size = 4, stride = 2, padding = 1)),
            nn.BatchNorm2d(feature_size * 2),
            nn.ReLU()
        )

        # Attention Layer 1
        self.attention_1 = SelfAttentionLayer(feature_size * 2)

        # Size 16 -> 32
        self.layer_4 = nn.Sequential(
            SpectralNormLayer(nn.ConvTranspose2d(feature_size * 2, feature_size * 2, kernel_size = 4, stride = 2, padding = 1)),
            nn.BatchNorm2d(feature_size * 2),
            nn.ReLU()
        )

        # Attention Layer 2
        self.attention_2 = SelfAttentionLayer(feature_size * 2)

        # Output layer 32 -> 64. The channel count follows `num_channels`
        # rather than a hard-coded 3, so the argument is honoured.
        self.output_layer = nn.Sequential(
            nn.ConvTranspose2d(feature_size * 2, num_channels, kernel_size = 4, stride = 2, padding = 1),
            nn.Tanh()
        )

        self.optimizer = optim.Adam(self.parameters(), lr = 0.0002, betas = (0.5, 0.999))

    def forward(self, noise, labels):
        """Generate one image per ``(noise, label)`` pair.

        Args:
            noise: Tensor whose rows are multiplied element-wise with the
                projected label embedding, so it needs ``noise_size`` columns.
            labels: Integer label tensor fed to the embedding layer.

        Returns:
            The output of the final ``Tanh`` layer.
        """
        label_embed = self.label_embedding(labels)
        linear_out = self.linear_layer(label_embed)

        # Condition the noise on the label, then reshape to a 1x1 feature map.
        x = torch.mul(linear_out, noise)
        x = x.view(x.shape[0], -1, 1, 1)
        x = self.layer_1(x)
        x = self.layer_2(x)
        x = self.layer_3(x)

        if self.attn:
            x, _ = self.attention_1(x)

        x = self.layer_4(x)

        if self.attn:
            x, _ = self.attention_2(x)

        outputs = self.output_layer(x)
        return outputs

In [None]:
class Discriminator(nn.Module):
    """Class-conditional discriminator for 64x64 images.

    Four stride-2, spectrally-normalised convolutions reduce the image to a
    4x4 feature map (64 -> 32 -> 16 -> 8 -> 4), with optional self-attention
    after the third and fourth stages. The label embedding is tiled to 4x4,
    concatenated on the channel axis, mixed by a 1x1 convolution and reduced
    to a single sigmoid score per sample.

    The attention layers are always created, whatever the value of ``attn``.
    The module also creates its own Adam optimiser as ``self.optimizer``.
    """

    def __init__(self, embedding_dim, num_classes, num_channels, image_size, attn = True, feature_size = 64):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.attn = attn

        assert image_size == 64, "Cannot create model for images other than size 64x64!"
        self.image_size = image_size

        def downsample(in_channels, out_channels):
            # 4x4 convolution, stride 2, padding 1: halves each spatial side.
            conv = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
            return nn.Sequential(SpectralNormLayer(conv), nn.LeakyReLU(0.1))

        narrow, wide, joint = feature_size * 2, feature_size * 4, feature_size * 8

        self.layer_1 = downsample(num_channels, feature_size)  # 64 -> 32
        self.layer_2 = downsample(feature_size, narrow)        # 32 -> 16
        self.layer_3 = downsample(narrow, narrow)              # 16 -> 8
        self.attention_1 = SelfAttentionLayer(narrow)
        self.layer_4 = downsample(narrow, wide)                # 8 -> 4
        self.attention_2 = SelfAttentionLayer(wide)

        self.label_embedding = nn.Embedding(num_classes, embedding_dim)

        # 1x1 convolution mixing image features with the tiled label embedding.
        self.image_label_layer = nn.Sequential(
            nn.Conv2d(embedding_dim + wide, joint, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(joint),
            nn.LeakyReLU(0.2),
        )

        # 4x4 convolution collapsing the 4x4 map to one score in (0, 1).
        self.output_layer = nn.Sequential(
            nn.Conv2d(joint, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        )

        self.optimizer = optim.Adam(self.parameters(), lr=0.0002, betas=(0.5, 0.999))

    def forward(self, inputs, labels):
        """Score each ``(image, label)`` pair; returns a ``(-1, 1)``-shaped tensor."""
        features = self.layer_3(self.layer_2(self.layer_1(inputs)))
        if self.attn:
            features, _ = self.attention_1(features)

        features = self.layer_4(features)
        if self.attn:
            features, _ = self.attention_2(features)

        # (batch, embedding_dim) -> (batch, embedding_dim, 4, 4)
        tiled_labels = self.label_embedding(labels)[:, :, None, None].repeat(1, 1, 4, 4)

        mixed = self.image_label_layer(torch.cat([features, tiled_labels], dim=1))
        return self.output_layer(mixed).view(-1, 1)

In [None]:
# Build the generator from the shared configuration. `attention` and
# `feature_size` are forwarded explicitly so that editing TrainConfig takes
# effect; their configured defaults equal the constructor defaults.
generator = Generator(
    train_config.noise_size,
    train_config.embedding_dim,
    n_classes,
    train_config.num_channels,
    train_config.image_size,
    attn=train_config.attention,
    feature_size=train_config.feature_size,
)
# initialize_weights only touches nn.Linear layers.
generator.apply(initialize_weights)
generator.to(train_config.device)

In [None]:
# Build the discriminator from the shared configuration. Its first argument
# is the label-embedding size, so `embedding_dim` is passed (previously
# `noise_size` was passed in that slot). `attention` and `feature_size` are
# forwarded so that editing TrainConfig takes effect.
discriminator = Discriminator(
    train_config.embedding_dim,
    n_classes,
    train_config.num_channels,
    train_config.image_size,
    attn=train_config.attention,
    feature_size=train_config.feature_size,
)
# initialize_weights only touches nn.Linear layers.
discriminator.apply(initialize_weights)
discriminator.to(train_config.device)

# Train Model

In [None]:
# Binary cross-entropy on probabilities: the Discriminator ends in a Sigmoid,
# so its outputs can be fed to BCELoss directly.
adversarial_loss = nn.BCELoss().to(train_config.device)

In [None]:
# One label per class plus one fixed noise vector per label, reused whenever
# samples are rendered so that outputs stay comparable.
fixed_labels = torch.arange(n_classes, dtype=torch.long).to(train_config.device)
fixed_noise = torch.randn(n_classes, train_config.noise_size).to(train_config.device)

def plot_output(epoch):
    """Show a grid of generator samples for the fixed noise / label pairs.

    The generator is switched to eval mode for the forward pass and back to
    train mode afterwards. ``epoch`` is accepted but not used.
    """
    plt.clf()

    generator.eval()
    with torch.no_grad():
        samples = generator(fixed_noise, fixed_labels)
    generator.train()

    show_grid(torchvision.utils.make_grid(samples.cpu(), normalize=True))

In [None]:
pbar = tqdm()

device = train_config.device

for epoch in range(train_config.epochs):
    print(f"Epoch: {epoch + 1} / {train_config.epochs}")
    pbar.reset(total=len(train_loader))

    # Per-epoch loss history, stored as plain Python floats. Keeping the loss
    # tensors themselves would hold one device tensor (and its autograd
    # history) per batch until the end of the epoch.
    discriminator_losses = []
    generator_losses = []

    for data in train_loader:

        # Bring to device
        real_images = data["images"].to(device)
        real_labels = data["labels"].to(device)

        # The last batch may be smaller than the configured batch size.
        current_batch_size = real_images.size()[0]

        # Targets for "real" (1) and "fake" (0) predictions.
        real_valid = torch.ones(current_batch_size, 1).to(device)
        fake_valid = torch.zeros(current_batch_size, 1).to(device)

        # ---- Train Generator: fakes should be scored as real ----
        generator.zero_grad()
        input_noise = torch.randn(size=(current_batch_size, train_config.noise_size)).to(device)
        fake_images = generator(input_noise, real_labels)
        disc_fake_valid = discriminator(fake_images, real_labels)

        generator_loss = adversarial_loss(disc_fake_valid, real_valid)
        generator_loss.backward()
        generator.optimizer.step()
        generator_losses.append(generator_loss.item())

        # ---- Train Discriminator ----
        # Also clears the gradients the generator step left on the
        # discriminator's parameters.
        discriminator.zero_grad()

        ## Real images with their true labels -> real
        disc_real_valid = discriminator(real_images, real_labels)
        disc_real_loss = adversarial_loss(disc_real_valid, real_valid)

        ## Real images with randomly drawn labels -> fake
        wrong_labels = torch.randint(0, n_classes, (current_batch_size, )).to(device)
        disc_wrong_valid = discriminator(real_images, wrong_labels)
        disc_wrong_loss = adversarial_loss(disc_wrong_valid, fake_valid)

        ## Generated images (detached from the generator graph) -> fake
        disc_fake_valid = discriminator(fake_images.detach(), real_labels)
        disc_fake_loss = adversarial_loss(disc_fake_valid, fake_valid)

        ## Total loss; the discriminator is only updated while it exceeds 0.5.
        discriminator_loss = disc_real_loss + disc_wrong_loss + disc_fake_loss
        if discriminator_loss > 0.5:
            discriminator_loss.backward()
            discriminator.optimizer.step()
        discriminator_losses.append(discriminator_loss.item())

        pbar.update()

    print(f"Discriminator Loss: {torch.mean(torch.FloatTensor(discriminator_losses)):.3f}")
    print(f"Generator Loss: {torch.mean(torch.FloatTensor(generator_losses)):.3f}")

    if (epoch + 1) % train_config.save_epoch == 0:
        plot_output(epoch + 1)


pbar.refresh()

In [None]:
# Render the fixed noise / label pairs one last time and write the grid to disk.
generator.eval()
with torch.no_grad():
    save_images = generator(fixed_noise, fixed_labels)
generator.train()

grid = torchvision.utils.make_grid(save_images.cpu(), normalize=True)
torchvision.utils.save_image(grid, '/kaggle/working/output_images.png')