In [1]:
from embedding_reader import EmbeddingReader
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import os
import math

# Load LAION CLIP Data

### Fetch if not already stored

In [2]:
# emebeddings are .npy
# metadata is .paraquet
embedding_reader = EmbeddingReader(
    embeddings_folder="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/",
    metadata_folder="https://mystic.the-eye.eu/public/AI/cah/laion5b/metadata/laion2B-en/",
    meta_columns=['SAMPLE_ID', 'TEXT'],
    file_format="parquet_npy"
)
print("embedding count", embedding_reader.count)
print("dimension", embedding_reader.dimension)
print("total size", embedding_reader.total_size)
print("byte per item", embedding_reader.byte_per_item)

100%|██████████████████████████████████████████████████████████████| 4611/4611 [03:19<00:00, 23.07it/s]

embedding count 116341562
dimension 768
total size 178700639232
byte per item 1536





In [3]:
for emb, meta in embedding_reader(batch_size=10 ** 4, start=0, end=10 ** 4, show_progress=True):
    print(emb.shape)
    print(meta.size)

100%|███████████████████████████████████████████████████████████████████| 1/1 [02:06<00:00, 126.01s/it]

(10000, 768)
30000





In [6]:
meta['emb'] = emb.tolist()

In [15]:
meta.to_parquet('data/laion2B-10000.parquet.gzip', compression='gzip')

### Create Dataset and Splits

In [2]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [3]:
torch.cuda.set_device(1)

In [4]:
meta = pd.read_parquet('data/laion2B-10000.parquet.gzip')

In [5]:
emb = torch.Tensor(meta['emb'])

  emb = torch.Tensor(meta['emb'])


In [6]:
train_idx = int(len(meta) * 0.8)
x_train, y_train = meta['TEXT'][:train_idx], emb[:train_idx]
x_test, y_test = meta['TEXT'][train_idx:], emb[train_idx:]

In [7]:
x_train.shape, y_train.size(), x_test.shape, y_test.size()

((8000,), torch.Size([8000, 768]), (2000,), torch.Size([2000, 768]))

# Fine-tune LM -> Predict CLIP image embeddings

In [8]:
from transformers import AutoModel, AutoTokenizer

In [9]:
epochs = 4
batch_size = 16
train_batches = math.ceil(x_train.shape[0]/batch_size)
test_batches = math.ceil(x_test.shape[0]/batch_size)

## BERT

In [90]:
class CLIPEmbBERT(nn.Module):
    def __init__(self):
        super(CLIPEmbBERT, self).__init__()
        self.model = AutoModel.from_pretrained("bert-base-uncased")
        for param in self.model.parameters():
            param.requires_grad = False
        ### New layers:
        self.linear1 = nn.Linear(768, 1024)
        self.linear2 = nn.Linear(1024, 768)
        

    def forward(self, tokens, mask):
        cls_embedding = self.model(tokens, attention_mask=mask).pooler_output
        out = self.linear1(cls_embedding)
        out = self.linear2(out)

        return out 

In [97]:
bert_model = CLIPEmbBERT()
bert_model.to(device)

bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(bert_model.parameters(), lr=5e-5, weight_decay=3e-6)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [98]:
x_train_bert = bert_tokenizer.batch_encode_plus(list(x_train), return_tensors='pt', padding=True, add_special_tokens=True).to(device)
y_train = y_train.to(device)

In [99]:
%%time
bert_model.train()
for epoch in range(epochs):
    
    rand_ids = torch.randperm(x_train_bert['input_ids'].size()[0])
    X = x_train_bert['input_ids'][rand_ids]
    masks = x_train_bert['attention_mask'][rand_ids]
    
    total_loss = 0.0
    
    for i in tqdm(range(0, X.size()[0], batch_size)):
        optimizer.zero_grad()
        
        outputs = bert_model(X[i:i+batch_size], mask=masks[i:i+batch_size])
        
        loss = torch.sqrt(criterion(outputs, y_train[i:i+batch_size]))
        total_loss += loss.item()

        loss.backward()
        optimizer.step()
    
    print(f"Loss Epoch {epoch+1} / {epochs}: {total_loss/train_batches}")

100%|██████████████████████████████████████████████| 500/500 [00:19<00:00, 25.69it/s]


Loss Epoch 1 / 4: 0.04026263348013163


100%|██████████████████████████████████████████████| 500/500 [00:19<00:00, 25.85it/s]


Loss Epoch 2 / 4: 0.03314060392603278


100%|██████████████████████████████████████████████| 500/500 [00:19<00:00, 25.84it/s]


Loss Epoch 3 / 4: 0.031828350599855185


100%|██████████████████████████████████████████████| 500/500 [00:19<00:00, 25.84it/s]

Loss Epoch 4 / 4: 0.031118932079523803
CPU times: user 1min 16s, sys: 1.03 s, total: 1min 17s
Wall time: 1min 17s





In [100]:
x_test_bert = bert_tokenizer.batch_encode_plus(list(x_test), return_tensors='pt', padding=True, add_special_tokens=True).to(device)
y_test = y_test.to(device)

In [101]:
predicted_embs = np.array([[0]*768])
bert_model.eval()
with torch.no_grad():
    X = x_test_bert['input_ids']
    masks = x_test_bert['attention_mask']
    total_loss = 0.0
    for i in tqdm(range(0, X.size()[0], batch_size)):
        outputs = bert_model(X[i:i+batch_size], mask=masks[i:i+batch_size])
        predicted_embs = np.concatenate([predicted_embs, outputs.cpu()])

        loss = torch.sqrt(criterion(outputs, y_test[i:i+batch_size]))
        total_loss += loss.item()

total_loss/test_batches

100%|██████████████████████████████████████████████| 125/125 [00:04<00:00, 28.50it/s]


0.030406540259718895

In [102]:
predicted_embs[1:]

array([[ 0.0165555 ,  0.03351349,  0.02410737, ...,  0.00684628,
        -0.00677852, -0.0115644 ],
       [ 0.03341193,  0.00343005, -0.00269516, ..., -0.01118771,
         0.0062623 , -0.01241021],
       [ 0.03579801,  0.02571776,  0.01572936, ...,  0.00420775,
         0.0018131 , -0.00699689],
       ...,
       [ 0.01789567,  0.01747628,  0.02883603, ...,  0.00054394,
         0.00277947, -0.00023974],
       [ 0.01404809,  0.04418744,  0.00210286, ...,  0.02715633,
         0.02589402, -0.02859953],
       [ 0.03262623,  0.025831  ,  0.01618475, ...,  0.010462  ,
         0.01251491, -0.01629806]])

### Save LAION-2B embeddings

In [103]:
np.save("embeds/BERT_test_preds.npy", predicted_embs[1:])
np.save("embeds/LAION_test_gt.npy", np.array(y_test.cpu()))

### Winoground embeddings

In [104]:
win_df = pd.read_json("data/examples.jsonl", lines=True)
win_bert = bert_tokenizer.batch_encode_plus(win_df['caption_0'].tolist() + win_df['caption_1'].tolist(), return_tensors='pt', padding=True, add_special_tokens=True).to(device)
win_bert['input_ids'].shape

torch.Size([800, 32])

In [105]:
win_embs = np.array([[0]*768])
with torch.no_grad():
    X = win_bert['input_ids']
    masks = win_bert['attention_mask']
    for i in tqdm(range(0, X.size()[0], batch_size)):
        outputs = bert_model(X[i:i+batch_size], mask=masks[i:i+batch_size])
        win_embs = np.concatenate([win_embs, outputs.cpu()])

100%|████████████████████████████████████████████████| 50/50 [00:00<00:00, 95.00it/s]


In [106]:
np.save("embeds/BERT_win_preds.npy", win_embs[1:])

## DistilBERT

In [10]:
from transformers import DistilBertForSequenceClassification
class CLIPEmbDistilBERT(nn.Module):
    def __init__(self):
        super(CLIPEmbDistilBERT, self).__init__()
        self.model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=768)
        for param in self.model.parameters():
            param.requires_grad = False
        ### New layers:
        self.linear1 = nn.Linear(768, 1024)
        self.linear2 = nn.Linear(1024, 768)
        

    def forward(self, tokens, mask):
        cls_embedding = self.model(tokens, attention_mask=mask).logits
        out = self.linear1(cls_embedding)
        out = self.linear2(out)

        return out 

In [30]:
distilbert_model = CLIPEmbDistilBERT()
distilbert_model.to(device)

distilbert_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(distilbert_model.parameters(), lr=3e-5, weight_decay=3e-6)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'pre_classifier.weight', 'classifier

In [31]:
x_train_distilbert = distilbert_tokenizer.batch_encode_plus(list(x_train), return_tensors='pt', padding=True, add_special_tokens=True).to(device)
y_train = y_train.to(device)

In [32]:
distilbert_model.train()
for epoch in range(epochs):
    
    rand_ids = torch.randperm(x_train_distilbert['input_ids'].size()[0])
    X = x_train_distilbert['input_ids'][rand_ids]
    masks = x_train_distilbert['attention_mask'][rand_ids]
    
    total_loss = 0.0
    
    for i in tqdm(range(0, X.size()[0], batch_size)):
        optimizer.zero_grad()
        
        outputs = distilbert_model(X[i:i+batch_size], mask=masks[i:i+batch_size])
        
        loss = torch.sqrt(criterion(outputs, y_train[i:i+batch_size]))
        total_loss += loss.item()

        loss.backward()
        optimizer.step()
    
    print(f"Loss Epoch {epoch+1} / {epochs}: {total_loss/train_batches}")

100%|██████████████████████████████████████████████| 500/500 [00:10<00:00, 49.17it/s]


Loss Epoch 1 / 4: 0.03265494028106332


100%|██████████████████████████████████████████████| 500/500 [00:10<00:00, 49.36it/s]


Loss Epoch 2 / 4: 0.03028310514241457


100%|██████████████████████████████████████████████| 500/500 [00:09<00:00, 50.39it/s]


Loss Epoch 3 / 4: 0.029588653806596994


100%|██████████████████████████████████████████████| 500/500 [00:09<00:00, 50.53it/s]

Loss Epoch 4 / 4: 0.029224651232361794





In [33]:
x_test_distilbert = distilbert_tokenizer.batch_encode_plus(list(x_test), return_tensors='pt', padding=True, add_special_tokens=True).to(device)
y_test = y_test.to(device)

In [34]:
predicted_embs = np.array([[0]*768])
distilbert_model.eval()
with torch.no_grad():
    X = x_test_distilbert['input_ids']
    masks = x_test_distilbert['attention_mask']
    total_loss = 0.0
    for i in tqdm(range(0, X.size()[0], batch_size)):
        outputs = distilbert_model(X[i:i+batch_size], mask=masks[i:i+batch_size])
        predicted_embs = np.concatenate([predicted_embs, outputs.cpu()])

        loss = torch.sqrt(criterion(outputs, y_test[i:i+batch_size]))
        total_loss += loss.item()

total_loss/test_batches

100%|██████████████████████████████████████████████| 125/125 [00:02<00:00, 51.60it/s]


0.028273208543658255

In [36]:
predicted_embs[1:]

array([[ 0.02629502,  0.0218821 ,  0.00022918, ...,  0.01273302,
        -0.00452789, -0.00140381],
       [ 0.02734128,  0.0214195 ,  0.00020625, ...,  0.00506776,
        -0.00234368,  0.00178314],
       [ 0.021662  ,  0.01842918, -0.00124166, ...,  0.01000047,
        -0.00620675, -0.0024468 ],
       ...,
       [ 0.02207613,  0.02050962,  0.00670609, ...,  0.0037114 ,
        -0.00767948, -0.00495844],
       [ 0.02228279,  0.02043567,  0.00370518, ...,  0.0058035 ,
        -0.00285789,  0.00020745],
       [ 0.02511614,  0.02209706,  0.00242683, ...,  0.00734688,
        -0.00405447,  0.0002732 ]])

### Save embeddings

In [37]:
np.save("embeds/DistilBERT_test_preds.npy", predicted_embs[1:])

### Winoground embeddings

In [38]:
win_df = pd.read_json("data/examples.jsonl", lines=True)
win_distilbert = distilbert_tokenizer.batch_encode_plus(win_df['caption_0'].tolist() + win_df['caption_1'].tolist(), return_tensors='pt', padding=True, add_special_tokens=True).to(device)
win_distilbert['input_ids'].shape

torch.Size([800, 32])

In [39]:
win_embs = np.array([[0]*768])
with torch.no_grad():
    X = win_distilbert['input_ids']
    masks = win_distilbert['attention_mask']
    for i in tqdm(range(0, X.size()[0], batch_size)):
        outputs = distilbert_model(X[i:i+batch_size], mask=masks[i:i+batch_size])
        win_embs = np.concatenate([win_embs, outputs.cpu()])

100%|███████████████████████████████████████████████| 50/50 [00:00<00:00, 160.69it/s]


In [40]:
np.save("embeds/DistilBERT_win_preds.npy", win_embs[1:])

In [45]:
win_embs[1] @ win_embs[401].T 

0.41961190452122643

In [46]:
X[0], X[400]

(tensor([ 101, 1126, 1385, 1825, 8514,  170, 1685, 1825,  102,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0], device='cuda:1'),
 tensor([ 101,  170, 1685, 1825, 8514, 1126, 1385, 1825,  102,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0], device='cuda:1'))