In [1]:
import torch
import torch.nn as nn
from functools import partial
import clip
from einops import rearrange, repeat

from glob import glob
from PIL import Image
from torchvision import transforms as T
from tqdm import tqdm
import pickle
from pathlib import Path
import numpy as np

In [20]:
batch_size = 2
style_size = 8
avg_size = 8
local_emb_size = 257
local_emb_dim = 1024
emb_dim = 1024

# [2, 8, 257, 1024] to [2, 77, 1024]
# [2, 8, 257, 1024] -> [2, 8, 8, 1024]
# [257, 1024] to [8, 1024]
# 
x = torch.randn((batch_size, style_size, local_emb_size, local_emb_dim))
x.shape

torch.Size([2, 8, 257, 1024])

In [21]:

b, n, c, d = x.shape
print(batch.shape)
# reaarange [2, 8, 257, 1024] to [2, 8, 1024, 257]
batch = rearrange(x, 'b n c d -> b n d c')
print(batch.shape)
# linear (257, 8) to create [2, 8, 1024, 8]
batch = nn.Linear(c, avg_size)(batch)
print(batch.shape)
# reaarange [2, 8, 1024, 8] to [2, 8x8, 1024]
batch = rearrange(batch, 'b n c avg_size -> b (n avg_size) c')
print(batch.shape)

torch.Size([2, 8192, 257])
torch.Size([2, 8, 1024, 257])
torch.Size([2, 8, 1024, 8])
torch.Size([2, 64, 1024])


In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads: int, dim_in: int, dim_q: int, dim_k: int):
        super().__init__()
        self.heads = nn.ModuleList(
            [AttentionHead(dim_in, dim_q, dim_k) for _ in range(num_heads)]
        )
        self.linear = nn.Linear(num_heads * dim_k, dim_in)

    def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
        return self.linear(
            torch.cat([h(query, key, value) for h in self.heads], dim=-1)
        )

In [10]:
out = nn.Linear(1024, 1024)(x)
out.shape

torch.Size([2, 8, 257, 1024])

In [4]:
x.shape

torch.Size([2, 8, 257])

In [17]:
device = 'cuda:0'

clip_norm = T.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), 
                        std=(0.26862954, 0.26130258, 0.27577711))
clip_transform = T.Compose([T.ToTensor(),
                            clip_norm])


In [1]:

class ClipImageEncoder(nn.Module):
    """
        Uses the CLIP image encoder.
        """
    def __init__(
            self,
            model='ViT-L/14',
            context_dim=None,
            jit=False,
            device='cuda',
        ):
        super().__init__()
        self.model, _ = clip.load(name=model, device=device, jit=jit)

        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False
            
    @torch.no_grad()
    def extract_features(self, x):
        b, n, c, h, w = x.shape

        return ret
            
    def forward(self, x):
        b, n, c, h, w = x.shape
        batch = rearrange(x, 'b n c h w -> (b n) c h w ')
        ret = self.model.encode_image(batch)        
        return rearrange(ret, '(b n) w -> b n w ', b=b, n=n)


In [23]:
encoder = ClipImageEncoder()
encoder = encoder.to(device)

In [2]:
style_files = glob("/home/soon/datasets/deepfashion_inshop/styles/**/*.jpg", recursive=True)

In [39]:
for style_file in tqdm(style_files[:]):
    style_image = Image.open(style_file)
    x = clip_transform(style_image).unsqueeze(0).unsqueeze(0).to(device)    
    emb = encoder(x).detach().cpu().squeeze(0).numpy()
    emb_file = style_file.replace('.jpg','.p')
    with open(emb_file, 'wb') as file:
        pickle.dump(emb, file)    

100%|██████████████████████████████████████| 1/1 [00:00<00:00, 90.59it/s]


In [3]:
style_files[0]

'/home/soon/datasets/deepfashion_inshop/styles/MEN/Suiting/id_00005928/01/1_front/hair.jpg'

In [5]:
folder = '/home/soon/datasets/deepfashion_inshop/styles/MEN/Suiting/id_00005928/01/1_front/'
style_names = ['face', 'hair', 'headwear', 'background', 'top', 'outer', 'bottom', 'shoes', 'accesories']


In [37]:
style_embeddings = []
for style_name in style_names:
    f_path = Path(folder)/f'{style_name}.p'
    if f_path.exists():
        with open(f_path, 'rb') as file:
            style_emb = pickle.load(file)
    else:
        style_emb = np.zeros((1,768), dtype=np.float16)
    style_embeddings.append(style_emb)
styles = torch.tensor(style_embeddings).squeeze(-2)

In [39]:
layer = nn.Linear(768, 4)

In [41]:
x = np.zeros((3,9,768))

In [45]:
layer(torch.tensor(x.astype(np.float32))).shape

torch.Size([3, 9, 4])

In [None]:
            style_images = []
            for style_name in self.style_names:
                f_path = full_styles_path/f'{style_name}.jpg'
                if f_path.exists() and not drop_style:
                    style_image = self.clip_transform((Image.open(f_path)))
                else:
                    style_image = self.clip_norm(torch.zeros(3, 224, 224))
                style_images.append(style_image)
            style_images = torch.stack(style_images) 