In [2]:
import glob
from PIL import Image
import sys
import torch
from matplotlib import pyplot as plt
import os
from torchvision.transforms import v2
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import Dataset, DataLoader
import numpy as np
from torchvision.transforms.functional import to_pil_image

In [3]:
# Directory of image files; passed as `data_path` when the dataset is built below.
PATH_TO_IMAGES = 'data/img_align_celeba'
# CSV of per-image attribute labels; passed as `label_path` when the dataset is built below.
PATH_TO_LABELS = 'data/list_attr_celeba.csv'

In [4]:
class ImageLoader(Dataset):
    """Dataset pairing the ``*.jpg`` files of a directory with attribute rows of a CSV.

    The CSV is read as follows: the first line is a header, the first column of
    every line is skipped, and the remaining columns are integer attributes in
    which the text ``-1`` is rewritten to ``0``.

    Args:
        data_path: Directory searched for ``*.jpg`` files.
        label_path: Path of the attribute CSV.
        img_size: Size handed to ``v2.Resize`` for every image.
        augment: When true, the image list comes from ``self.augment_images()``
            and the label of item ``idx`` is row ``idx // 4``.

    NOTE(review): ``augment_images`` is not defined in this class as shown, so
    ``augment=True`` (the default) fails unless it is supplied elsewhere - confirm.
    """

    def __init__(self, data_path, label_path, img_size=(234, 234), augment=True):
        self.data_path = data_path
        self.label_path = label_path
        self.augment = augment
        self.attr_names = self.get_attribute_names_from_csv()
        # Per-attribute running sum of label values, filled by get_labels_from_csv().
        self.distribution = np.zeros(len(self.attr_names))
        self.images = self.get_images_from_directory(augment)
        self.labels = self.get_labels_from_csv()

        self.transform = v2.Compose([
            v2.Resize(size=img_size),
            #v2.CenterCrop(224),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            # Normalization for pretrained mobilenet: mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]
            #v2.Normalize(mean=[0.5, 0.5, 0.5],std=[0.5, 0.5, 0.5])
        ])

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor)`` for item ``idx``."""
        image = Image.open(self.images[idx])
        image_tensor = self.transform(image)
        # Floor division keeps the index an int and maps negative indices onto
        # the matching label row (int(idx / 4) truncated them toward zero).
        # Presumably every source image yields four consecutive entries in
        # self.images - TODO confirm against augment_images().
        label_idx = idx // 4 if self.augment else idx
        label_tensor = torch.Tensor(self.labels[label_idx])
        return image_tensor, label_tensor

    def get_distribution(self):
        """Return the per-attribute sum of label values accumulated so far."""
        return self.distribution

    def __len__(self):
        """Return the number of image entries."""
        return len(self.images)

    def get_images_from_directory(self, augment):
        """Return the image list: augmented entries, or the sorted ``*.jpg`` paths."""
        if augment:
            return self.augment_images()
        return sorted(glob.glob(f'{self.data_path}/*.jpg'))

    def _read_csv_lines(self):
        """Return every line of the label CSV, closing the file afterwards."""
        with open(self.label_path) as csv_file:
            return csv_file.readlines()

    def get_labels_from_csv(self):
        """Parse the label rows and add each of them to ``self.distribution``.

        Returns:
            A list with one list of ints per non-blank data line.
        """
        data_label = []
        for line in self._read_csv_lines()[1:]:
            line = line.strip()
            if not line:
                # A blank (e.g. trailing) line would produce an empty row that
                # cannot be added to the distribution vector.
                continue
            row = [int(value.replace('-1', '0')) for value in line.split(',')[1:]]
            self.distribution += np.array(row)
            data_label.append(row)
        return data_label

    def get_attribute_names_from_csv(self):
        """Return the header's column names, excluding the first column."""
        # Strip first so the last name does not keep the line's trailing newline.
        return self._read_csv_lines()[0].strip().split(',')[1:]


In [5]:
# Build the dataset with 144x144 images and augmentation switched off.
dataset = ImageLoader(
    data_path=PATH_TO_IMAGES,
    label_path=PATH_TO_LABELS,
    img_size=(144, 144),
    augment=False,
)

In [5]:
def show_images(images, num_of_samples=10, cols=5):
    """Plot the first items of an iterable of ``(image, label)`` pairs in a grid.

    Args:
        images: Iterable (or indexable dataset) yielding ``(image, label)``
            pairs. The label is ignored; each image goes through
            ``to_pil_image`` before being shown.
        num_of_samples: Maximum number of images to draw. Fewer are drawn when
            ``images`` runs out first.
        cols: Number of grid columns.
    """
    # Ceiling division: just enough rows, with no spare empty row when
    # num_of_samples is a multiple of cols.
    rows = max(1, -(-num_of_samples // cols))
    plt.figure(figsize=(17, 17))
    # Iterate the argument itself (not a module-level dataset). zip() stops
    # cleanly if the source holds fewer than num_of_samples items.
    for i, (img, _) in zip(range(num_of_samples), images):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(to_pil_image(img))


In [None]:
show_images(dataset)

In [24]:
from torch import nn
from einops.layers.torch import Rearrange
from torch import Tensor


class PatchEmbedding(nn.Module):
    """Cut an image batch into square patches and linearly embed each patch.

    Args:
        in_channels: Number of channels of the input images.
        patch_size: Side length of each square patch; also stored as
            ``self.patch_size``.
        emb_size: Output feature size of the linear projection.
    """

    def __init__(self, in_channels=3, patch_size=8, emb_size=128):
        super().__init__()
        self.patch_size = patch_size

        # Every patch is flattened to patch_size * patch_size * in_channels values.
        flat_patch_dim = patch_size * patch_size * in_channels
        # Split the image into p1 x p2 patches and flatten each of them.
        to_flat_patches = Rearrange(
            'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
            p1=patch_size,
            p2=patch_size,
        )
        self.projection = nn.Sequential(
            to_flat_patches,
            nn.Linear(flat_patch_dim, emb_size),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Map ``(b, c, H, W)`` images to ``(b, num_patches, emb_size)`` embeddings."""
        return self.projection(x)

# Smoke test: embed the first dataset image and compare shapes before and after.
sample_datapoint = dataset[0][0].unsqueeze(0)  # prepend an axis of size 1
print("Initial shape: ", sample_datapoint.shape)
embedding = PatchEmbedding()(sample_datapoint)
print("Patches shape: ", embedding.shape)

Initial shape:  torch.Size([1, 3, 144, 144])
Patches shape:  torch.Size([1, 324, 128])


In [21]:
from einops import rearrange

class Attention(nn.Module):
    """Multi-head self-attention over batch-first ``(batch, seq, dim)`` input.

    The input is projected to queries, keys and values by three separate
    linear layers and then passed to ``torch.nn.MultiheadAttention``.

    Args:
        dim: Embedding size of the input and the output.
        n_heads: Number of attention heads.
        dropout: Dropout probability used inside the attention module.
    """

    def __init__(self, dim, n_heads, dropout):
        super().__init__()
        self.n_heads = n_heads
        # batch_first=True: the input is (batch, seq, dim). With the default
        # (seq, batch, dim) layout, attention would run across the batch axis
        # and mix unrelated samples.
        self.att = torch.nn.MultiheadAttention(embed_dim=dim,
                                               num_heads=n_heads,
                                               dropout=dropout,
                                               batch_first=True)
        self.q = torch.nn.Linear(dim, dim)
        self.k = torch.nn.Linear(dim, dim)
        self.v = torch.nn.Linear(dim, dim)

    def forward(self, x):
        """Return the attended sequence, same shape as ``x``."""
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        # The attention weights are never used, so skip computing them.
        attn_output, _ = self.att(q, k, v, need_weights=False)
        return attn_output

In [26]:
# Shape check: the output should have the same shape as the (1, 5, 128) input.
Attention(dim=128, n_heads=4, dropout=0.0)(torch.ones(1, 5, 128)).shape

torch.Size([1, 5, 128])