In [None]:
from typing import *
from datasets import load_dataset
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from typing import Literal
from torch.utils.data import DataLoader, Dataset
import torch
import open_clip as clip

class ClipWrap:

    """
    Wrapper class for OpenClip methods used to improve multi-modal rappresentation
    """
    @staticmethod
    def info():
        print(clip.list_pretrained())
        
    def __init__(self,photo_encoder:str, tokenizer:str, precision:str="float32", weights:str|None=None, device:Literal['cpu', 'cuda']="cpu"):
        self.precision = precision
        self.device = device
        
        self.encoder, train_arg, val_arg = clip.create_model_and_transforms(
            model_name=photo_encoder,
            pretrained=weights,
            load_weights=True,
            device=self.device,
            precision=self.precision,
            jit=True
        )

        self.train_arg = train_arg
        self.val_arg = val_arg

        self.tokenizer = clip.get_tokenizer(tokenizer)
        #print(self.encoder)
    
    def get_transformations(self) -> tuple[Callable, Callable]:
        return (self.train_arg, self.val_arg) # type: ignore
    
    def compute(self, images, prompts):
        self.encoder.eval()
        
        with torch.no_grad():
            
            # Preprocessing prompts
            tokens = self.tokenizer(prompts).to(self.device)
            prompts_emb = self.encoder.encode_text(tokens) # type: ignore
            prompts = prompts_emb.to(getattr(torch, self.precision)).to(self.device)
            
            # Preprocessing images
            #images = self.val_arg(images)
            images  = images.to(getattr(torch, self.precision)).to(self.device)
            images_emb = self.encoder.encode_image(images) # type: ignore
            
            
            img_std, img_mean = torch.std_mean(images_emb, dim=-1, keepdim=True)
            p_std, p_mean = torch.std_mean(prompts_emb, dim=-1, keepdim=True)

            #print(f"img_emb:{images.shape}")
            #print(f"prompt_emb:{prompts.shape}")

            images = (images_emb - img_mean)/img_std
            prompts = (prompts_emb - p_mean)/p_std

            coss = (100.0 * images @ prompts.T).softmax(dim=-1)

        return [images_emb, prompts_emb, coss]


In [None]:
class PixSet(Dataset):

    def __init__(self, src_dataset:Path|str,size:int, split:Literal['train','val','test'],offset:int=0, transformation:Callable|None = None):
        super().__init__()
        src_dataset = Path(src_dataset)
        self.offset = offset
        self.transformation = transformation
        if not src_dataset.exists():
            raise FileNotFoundError
        
        self.ds = load_dataset(path=str(src_dataset), split=f'{split}[{offset}:{size}]')
        
        if offset + size > self.ds.dataset_size:
            raise IndexError
        
    def __len__(self):
        return len(self.ds)
    
    def get_hf_dataset(self):
        return self.ds
    

    def __getitem__(self, index):
        record = self.ds[self.offset + index]
        
     
        original = record['original_image']
        edited   = record['edited_image']

        if self.transformation:
            original = self.transformation(original)
            edited   = self.transformation(edited)
        
        sample = {
            "original_image" : original,
            "original_prompt" : record['original_prompt'],
            "edit" : record['edit_prompt'],
            "edited_prompt": record['edited_prompt'],
            "edited_image" : edited
        }

        return sample

In [5]:
PixSet.info()

[('RN50', 'openai'), ('RN50', 'yfcc15m'), ('RN50', 'cc12m'), ('RN101', 'openai'), ('RN101', 'yfcc15m'), ('RN50x4', 'openai'), ('RN50x16', 'openai'), ('RN50x64', 'openai'), ('ViT-B-32', 'openai'), ('ViT-B-32', 'laion400m_e31'), ('ViT-B-32', 'laion400m_e32'), ('ViT-B-32', 'laion2b_e16'), ('ViT-B-32', 'laion2b_s34b_b79k'), ('ViT-B-32', 'datacomp_xl_s13b_b90k'), ('ViT-B-32', 'datacomp_m_s128m_b4k'), ('ViT-B-32', 'commonpool_m_clip_s128m_b4k'), ('ViT-B-32', 'commonpool_m_laion_s128m_b4k'), ('ViT-B-32', 'commonpool_m_image_s128m_b4k'), ('ViT-B-32', 'commonpool_m_text_s128m_b4k'), ('ViT-B-32', 'commonpool_m_basic_s128m_b4k'), ('ViT-B-32', 'commonpool_m_s128m_b4k'), ('ViT-B-32', 'datacomp_s_s13m_b4k'), ('ViT-B-32', 'commonpool_s_clip_s13m_b4k'), ('ViT-B-32', 'commonpool_s_laion_s13m_b4k'), ('ViT-B-32', 'commonpool_s_image_s13m_b4k'), ('ViT-B-32', 'commonpool_s_text_s13m_b4k'), ('ViT-B-32', 'commonpool_s_basic_s13m_b4k'), ('ViT-B-32', 'commonpool_s_s13m_b4k'), ('ViT-B-32', 'metaclip_400m'), ('V

In [6]:
embedder = ClipWrap("ViT-H-14-378-quickgelu", "ViT-H-14-378-quickgelu", weights="dfn5b", device="cuda")
loader = DataLoader( PixSet("../../dataset/pixset/", 16, "train", offset=0, transformation=embedder.get_transformations()[1]), batch_size=8, shuffle=True)


In [7]:
embeddings = []
for batch, data in enumerate(loader):
    print(f"===:(current batch {batch}):===")
    embeddings.append(embedder.compute(data['original_image'], data['original_prompt']))
   

===:(current batch 0):===


===:(current batch 1):===


In [10]:
torch.cosine_similarity(embeddings[0][0], embeddings[0][1]).norm()

tensor(0.9862, device='cuda:0')