In [1]:
import os

# Project root. The notebook chdir's here, presumably so the project-local
# imports below (`demo_utils`, `datasets.datasets`) resolve from the kernel's
# working directory. Set MAAF_WORK_DIR to point at a different checkout instead
# of editing this absolute, machine-specific default.
WORK_DIR = os.environ.get('MAAF_WORK_DIR', '/home/lishi/workspace/MAAF/')
os.chdir(WORK_DIR)

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl

from demo_utils import get_preserved_opt
from datasets.datasets import load_dataset


In [2]:
# Build the experiment options, then load every dataset split they describe
# (the captured log below shows fashioniq train/val/test being read).
opt = get_preserved_opt()
dataset_dict = load_dataset(opt)

Reading dataset  fashioniq
0  files not found in  train
0  files not found in  val
0  files not found in  test
train size 45429
val size 15415
test size 15417


In [4]:
# Number of training queries; note it differs from the "train size" printed
# by load_dataset above. NOTE(review): this cell ran as In[4] before In[3], so
# its output may reflect earlier kernel state — re-run in order to confirm.
len(dataset_dict['train'].queries)

18000

In [3]:
from transformers import CLIPModel, CLIPProcessor

# Pretrained CLIP checkpoint used by the exploratory cells below:
# `processor` prepares raw images/text, `model` turns them into embeddings.
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

ftfy or spacy is not installed using BERT BasicTokenizer instead of ftfy.


In [4]:
train_dataset = dataset_dict['train']

# Loader over the training split. Sample order is fixed (shuffle=False) and an
# incomplete final batch is dropped; presumably these flags are forwarded to a
# torch DataLoader inside get_loader — verify against the dataset class.
train_dataloader = train_dataset.get_loader(
    num_workers=opt.loader_num_workers,
    batch_size=opt.batch_size,
    drop_last=True,
    shuffle=False,
)

In [5]:
# Pull a single batch for interactive inspection. Unlike a `for ...: break`
# loop, `next(iter(...))` raises StopIteration on an empty loader instead of
# silently leaving `batch` unbound (or holding a stale value from an earlier run).
batch = next(iter(train_dataloader))

In [6]:
# Split the batch (iterated as per-sample, dict-like records) into three
# aligned lists. `str(...)` guards against mod['str'] not already being a string.
ref_imgs = [sample['source_img_data'] for sample in batch]
tgt_imgs = [sample['target_img_data'] for sample in batch]
mod_strs = [str(sample['mod']['str']) for sample in batch]

# Run the CLIP processor: images -> 'pixel_values'; text -> padded
# 'input_ids'/'attention_mask'. truncation=True clips captions longer than the
# tokenizer's maximum length, which CLIP's text encoder cannot embed otherwise.
# The names are rebound from the raw lists to the processor outputs.
ref_imgs = processor(images=ref_imgs, return_tensors='pt')
tgt_imgs = processor(images=tgt_imgs, return_tensors='pt')
mod_strs = processor(text=mod_strs, return_tensors='pt', padding=True, truncation=True)

In [15]:
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    """Cross-entropy where the correct class for row ``i`` is column ``i``.

    Args:
        logits: 2-D score matrix; row ``i`` holds the scores of query ``i``
            against every candidate. It must have at least as many columns as
            rows (typically square, with matching pairs on the diagonal).

    Returns:
        Scalar mean cross-entropy over the rows.
    """
    matching_index = torch.arange(len(logits), device=logits.device)
    return F.cross_entropy(logits, matching_index)


def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
    """Symmetric contrastive loss over a query-by-candidate similarity matrix.

    The loss is averaged over both directions: rows (each query must pick its
    own candidate) and columns (each candidate must pick its own query).
    """
    row_direction = contrastive_loss(similarity)
    column_direction = contrastive_loss(similarity.T)
    total = row_direction + column_direction
    return total / 2.0

In [16]:
# Encode the modification text and both image sets with the global CLIP
# `model`, then compute the batch loss by hand (the same steps that
# InteractiveCLIP.training_step performs).
text_features = model.get_text_features(input_ids=mod_strs['input_ids'], attention_mask=mod_strs['attention_mask'])
ref_imgs_features = model.get_image_features(pixel_values=ref_imgs['pixel_values'])
tgt_imgs_features = model.get_image_features(pixel_values=tgt_imgs['pixel_values'])

# normalize: scale every embedding to unit L2 norm along the last dimension
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
ref_imgs_features = ref_imgs_features / ref_imgs_features.norm(dim=-1, keepdim=True)
tgt_imgs_features = tgt_imgs_features / tgt_imgs_features.norm(dim=-1, keepdim=True)

# compose reference image feature and modification text feature as their
# element-wise mean (NOTE: the mean of two unit vectors is generally not unit-norm)
mod_imgs_features = (text_features + ref_imgs_features) / 2

# scaled dot products as logits: row i = composed query i, column j = target j.
# Because mod_imgs_features is not re-normalized these are not pure cosine
# similarities. logit_scale is exponentiated before use (stored in log space).
logit_scale = model.logit_scale.exp()
logits_per_ref = torch.matmul(mod_imgs_features, tgt_imgs_features.t()) * logit_scale

# symmetric loss; the matching (query, target) pairs lie on the diagonal
loss = clip_loss(logits_per_ref)

In [None]:
class InteractiveCLIP(pl.LightningModule):
    """Fine-tune CLIP for composed retrieval: (reference image + modification text) -> target image.

    The reference-image and modification-text embeddings are fused into a query
    embedding, which is trained with a symmetric cross-entropy loss against the
    target-image embeddings of the same batch (other samples act as negatives).

    Args:
        checkpoint: Hugging Face model id or path used for both the CLIPModel
            and the CLIPProcessor.
        compose_choice: How to fuse text and reference-image features. Only
            ``'mean'`` is implemented.
        lr: Adam learning rate.
    """

    _COMPOSE_CHOICES = ('mean',)

    def __init__(self, checkpoint: str = "openai/clip-vit-base-patch32",
                 compose_choice: str = 'mean', lr: float = 1e-4):
        super().__init__()
        if compose_choice not in self._COMPOSE_CHOICES:
            # fail fast, before downloading/loading any weights
            raise ValueError(
                f"Unsupported compose_choice {compose_choice!r}; "
                f"expected one of {self._COMPOSE_CHOICES}")
        self.processor = CLIPProcessor.from_pretrained(checkpoint)
        self.clip = CLIPModel.from_pretrained(checkpoint)
        self.compose_choice = compose_choice
        self.lr = lr

    def forward(self, x):
        """Inference entry point — not implemented yet; returns None.

        TODO: define the prediction output. Training uses ``training_step``,
        which goes through ``_compute_logits``.
        """
        return None

    @staticmethod
    def _contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
        """Cross-entropy where the correct class for row ``i`` is column ``i``."""
        targets = torch.arange(len(logits), device=logits.device)
        return F.cross_entropy(logits, targets)

    @classmethod
    def _clip_loss(cls, similarity: torch.Tensor) -> torch.Tensor:
        """Average of the row-wise (query->target) and column-wise (target->query) losses."""
        ref_loss = cls._contrastive_loss(similarity)
        tgt_loss = cls._contrastive_loss(similarity.T)
        return (ref_loss + tgt_loss) / 2.0

    @staticmethod
    def _normalize(features: torch.Tensor) -> torch.Tensor:
        """Scale each embedding to unit L2 norm along the last dimension."""
        return features / features.norm(dim=-1, keepdim=True)

    def _preprocess(self, batch):
        """Run the CLIP processor on a list of samples and move the tensors to this module's device.

        Returns:
            (ref_pixel_values, tgt_pixel_values, input_ids, attention_mask)
        """
        ref_imgs = [sample['source_img_data'] for sample in batch]
        tgt_imgs = [sample['target_img_data'] for sample in batch]
        mod_strs = [str(sample['mod']['str']) for sample in batch]
        ref_inputs = self.processor(images=ref_imgs, return_tensors='pt')
        tgt_inputs = self.processor(images=tgt_imgs, return_tensors='pt')
        # truncation=True: captions longer than the tokenizer's maximum length
        # cannot be embedded by CLIP's text encoder.
        text_inputs = self.processor(text=mod_strs, return_tensors='pt',
                                     padding=True, truncation=True)
        # The processor returns CPU tensors; Lightning may have moved us to a GPU.
        return (ref_inputs['pixel_values'].to(self.device),
                tgt_inputs['pixel_values'].to(self.device),
                text_inputs['input_ids'].to(self.device),
                text_inputs['attention_mask'].to(self.device))

    def _compose(self, text_features: torch.Tensor, ref_features: torch.Tensor) -> torch.Tensor:
        """Fuse modification-text and reference-image features into a query embedding."""
        if self.compose_choice == 'mean':
            # NOTE: the mean of two unit vectors is generally not unit-norm.
            return (text_features + ref_features) / 2
        raise ValueError(
            f"Unsupported compose_choice {self.compose_choice!r}; "
            f"expected one of {self._COMPOSE_CHOICES}")

    def _compute_logits(self, batch) -> torch.Tensor:
        """Return the (batch, batch) matrix of scaled query-vs-target similarities."""
        ref_pixels, tgt_pixels, input_ids, attention_mask = self._preprocess(batch)
        text_features = self._normalize(
            self.clip.get_text_features(input_ids=input_ids, attention_mask=attention_mask))
        ref_features = self._normalize(self.clip.get_image_features(pixel_values=ref_pixels))
        tgt_features = self._normalize(self.clip.get_image_features(pixel_values=tgt_pixels))

        query_features = self._compose(text_features, ref_features)

        # logit_scale is exponentiated before use (stored in log space)
        logit_scale = self.clip.logit_scale.exp()
        return torch.matmul(query_features, tgt_features.t()) * logit_scale

    def training_step(self, batch, batch_idx):
        """Compute the symmetric contrastive loss for one batch using this module's own CLIP."""
        logits_per_ref = self._compute_logits(batch)
        return self._clip_loss(logits_per_ref)

    def configure_optimizers(self):
        """Adam over all parameters (including the CLIP weights and logit_scale)."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)