In [7]:
from embedding_reader import EmbeddingReader
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import os
import math

# Load LAION CLIP Data

### Fetch if not already stored

In [2]:
# emebeddings are .npy
# metadata is .paraquet
embedding_reader = EmbeddingReader(
    embeddings_folder="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/",
    metadata_folder="https://mystic.the-eye.eu/public/AI/cah/laion5b/metadata/laion2B-en/",
    meta_columns=['SAMPLE_ID', 'TEXT'],
    file_format="parquet_npy"
)
print("embedding count", embedding_reader.count)
print("dimension", embedding_reader.dimension)
print("total size", embedding_reader.total_size)
print("byte per item", embedding_reader.byte_per_item)

100%|██████████████████████████████████████████████████████████████| 4611/4611 [03:19<00:00, 23.07it/s]

embedding count 116341562
dimension 768
total size 178700639232
byte per item 1536





In [3]:
for emb, meta in embedding_reader(batch_size=10 ** 4, start=0, end=10 ** 4, show_progress=True):
    print(emb.shape)
    print(meta.size)

100%|███████████████████████████████████████████████████████████████████| 1/1 [02:06<00:00, 126.01s/it]

(10000, 768)
30000





In [6]:
meta['emb'] = emb.tolist()

In [15]:
meta.to_parquet('data/laion2B-10000.parquet.gzip', compression='gzip')

### Create Dataset and Splits

In [8]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [9]:
torch.cuda.set_device(1)

In [10]:
meta = pd.read_parquet('data/laion2B-10000.parquet.gzip')

In [11]:
emb = torch.Tensor(meta['emb'])

  emb = torch.Tensor(meta['emb'])


In [12]:
train_idx = int(len(meta) * 0.8)
x_train, y_train = meta['TEXT'][:train_idx], emb[:train_idx]
x_test, y_test = meta['TEXT'][train_idx:], emb[train_idx:]

In [13]:
x_train.shape, y_train.size(), x_test.shape, y_test.size()

((8000,), torch.Size([8000, 768]), (2000,), torch.Size([2000, 768]))

# Fine-tune LM -> Predict CLIP image embeddings

In [14]:
from transformers import AutoModel, AutoTokenizer

In [174]:
epochs = 4
batch_size = 16
train_batches = math.ceil(x_train.shape[0]/batch_size)
test_batches = math.ceil(x_test.shape[0]/batch_size)

## BERT

In [176]:
class CLIPEmbBERT(nn.Module):
    def __init__(self):
        super(CLIPEmbBERT, self).__init__()
        self.model = AutoModel.from_pretrained("bert-base-uncased")
        for param in self.model.parameters():
            param.requires_grad = False
        ### New layers:
        self.linear1 = nn.Linear(768, 1024)
        self.linear2 = nn.Linear(1024, 768)

    def forward(self, tokens, mask):
        cls_embedding = self.model(tokens, attention_mask=mask).pooler_output
        out = self.linear1(cls_embedding)
        out = self.linear2(out)

        return out 

In [177]:
bert_model = CLIPEmbBERT()
bert_model.to(device)

bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(bert_model.parameters(), lr=5e-5, weight_decay=3e-6)

Downloading:   0%|          | 0.00/420M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [178]:
x_train_bert = bert_tokenizer.batch_encode_plus(list(x_train), return_tensors='pt', padding=True, add_special_tokens=True).to(device)
y_train = y_train.to(device)

In [179]:
%%time
bert_model.train()
for epoch in range(epochs):
    
    rand_ids = torch.randperm(x_train_bert['input_ids'].size()[0])
    X = x_train_bert['input_ids'][rand_ids]
    masks = x_train_bert['attention_mask'][rand_ids]
    
    total_loss = 0.0
    
    for i in tqdm(range(0, X.size()[0], batch_size)):
        optimizer.zero_grad()
        
        outputs = bert_model(X[i:i+batch_size], mask=masks[i:i+batch_size])
        
        loss = torch.sqrt(criterion(torch.log(outputs + 1), torch.log(y_train[i:i+batch_size] + 1)))
        total_loss += loss.item()

        loss.backward()
        optimizer.step()
    
    print(f"Loss Epoch {epoch+1} / {epochs}: {total_loss/train_batches}")

100%|██████████████████████████████████████████████████████████████████| 500/500 [00:37<00:00, 13.35it/s]


Loss Epoch 1 / 4: 0.040074531018733976


100%|██████████████████████████████████████████████████████████████████| 500/500 [00:37<00:00, 13.27it/s]


Loss Epoch 2 / 4: 0.03325400817021727


100%|██████████████████████████████████████████████████████████████████| 500/500 [00:38<00:00, 13.15it/s]


Loss Epoch 3 / 4: 0.032004800744354724


100%|██████████████████████████████████████████████████████████████████| 500/500 [00:37<00:00, 13.25it/s]

Loss Epoch 4 / 4: 0.0312991358526051
CPU times: user 2min 29s, sys: 1.51 s, total: 2min 31s
Wall time: 2min 30s





In [180]:
x_test_bert = bert_tokenizer.batch_encode_plus(list(x_test), return_tensors='pt', padding=True, add_special_tokens=True).to(device)
y_test = y_test.to(device)

In [181]:
predicted_embs = np.array([[0]*768])
bert_model.eval()
with torch.no_grad():
    X = x_test_bert['input_ids']
    masks = x_test_bert['attention_mask']
    total_loss = 0.0
    for i in tqdm(range(0, X.size()[0], batch_size)):
        outputs = bert_model(X[i:i+batch_size], mask=masks[i:i+batch_size])
        predicted_embs = np.concatenate([predicted_embs, outputs.cpu()])

        loss = torch.sqrt(criterion(torch.log(outputs + 1), torch.log(y_test[i:i+batch_size] + 1)))
        total_loss += loss.item()

total_loss/test_batches

100%|██████████████████████████████████████████████████████████████████| 125/125 [00:08<00:00, 14.65it/s]


0.030576227888464928

In [182]:
predicted_embs[1:]

array([[ 0.021662  ,  0.04147527,  0.01045391, ...,  0.00151252,
        -0.01788167,  0.0114178 ],
       [ 0.03370392,  0.04907173, -0.00467334, ...,  0.02452838,
        -0.01084931, -0.0040316 ],
       [ 0.02601591,  0.02909793,  0.00328831, ...,  0.00946003,
        -0.00404191, -0.00129716],
       ...,
       [ 0.0244996 ,  0.02940119,  0.00433417, ...,  0.01099559,
         0.00248771,  0.01305326],
       [ 0.04512278,  0.03372372,  0.01290418, ..., -0.0030854 ,
        -0.00783951, -0.00473635],
       [ 0.03254802,  0.04593226,  0.01455975, ..., -0.00737137,
        -0.01401397, -0.00299664]])

### Save LAION-2B embeddings

In [111]:
np.save("embeds/BERT_test_preds.npy", predicted_embs[1:])
np.save("embeds/LAION_test_gt.npy", np.array(y_test.cpu()))

### Winoground embeddings

In [112]:
win_df = pd.read_json("data/examples.jsonl", lines=True)
win_bert = bert_tokenizer.batch_encode_plus(win_df['caption_0'].tolist() + win_df['caption_1'].tolist(), return_tensors='pt', padding=True, add_special_tokens=True).to(device)
win_bert['input_ids'].shape

torch.Size([800, 32])

In [24]:
win_embs = np.array([[0]*768])
with torch.no_grad():
    X = win_bert['input_ids']
    masks = win_bert['attention_mask']
    for i in tqdm(range(0, X.size()[0], batch_size)):
        outputs = bert_model(X[i:i+batch_size], mask=masks[i:i+batch_size])
        win_embs = np.concatenate([win_embs, outputs.cpu()])

100%|███████████████████████████████████████████████| 50/50 [00:00<00:00, 79.26it/s]


In [26]:
np.save("embeds/BERT_win_preds.npy", win_embs[1:])

## DistilBERT

In [10]:
from transformers import DistilBertForSequenceClassification
class CLIPEmbDistilBERT(nn.Module):
    def __init__(self):
        super(CLIPEmbDistilBERT, self).__init__()
        self.model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=768)
          ### New layers: None

    def forward(self, tokens, mask):
        cls_emebdding = self.model(tokens, attention_mask=mask).logits

        return cls_emebdding

In [11]:
distilbert_model = CLIPEmbDistilBERT()
distilbert_model.to(device)

distilbert_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(distilbert_model.parameters(), lr=3e-3, weight_decay=1e-5)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'classifier.bias', 'classifier.w

In [12]:
x_train_distilbert = distilbert_tokenizer.batch_encode_plus(list(x_train), return_tensors='pt', padding=True, add_special_tokens=True).to(device)
y_train = y_train.to(device)

In [13]:
distilbert_model.train()
for epoch in range(epochs):
    
    rand_ids = torch.randperm(x_train_distilbert['input_ids'].size()[0])
    X = x_train_distilbert['input_ids'][rand_ids]
    masks = x_train_distilbert['attention_mask'][rand_ids]
    
    total_loss = 0.0
    
    for i in tqdm(range(0, X.size()[0], batch_size)):
        optimizer.zero_grad()
        
        outputs = distilbert_model(X[i:i+batch_size], mask=masks[i:i+batch_size])
        
        loss = torch.sqrt(criterion(outputs, y_train[i:i+batch_size]))
        total_loss += loss.item()

        loss.backward()
        optimizer.step()
    
    print(f"Loss Epoch {epoch+1} / {epochs}: {total_loss/train_batches}")

100%|█████████████████████████████████████████████| 500/500 [00:59<00:00,  8.46it/s]


Loss Epoch 1 / 4: 0.03901317881047726


100%|█████████████████████████████████████████████| 500/500 [00:59<00:00,  8.40it/s]


Loss Epoch 2 / 4: 0.028321496821939945


100%|█████████████████████████████████████████████| 500/500 [00:59<00:00,  8.39it/s]


Loss Epoch 3 / 4: 0.028325878888368607


100%|█████████████████████████████████████████████| 500/500 [00:59<00:00,  8.42it/s]

Loss Epoch 4 / 4: 0.028327112585306166





In [14]:
x_test_distilbert = distilbert_tokenizer.batch_encode_plus(list(x_test), return_tensors='pt', padding=True, add_special_tokens=True).to(device)
y_test = y_test.to(device)

In [15]:
predicted_embs = np.array([[0]*768])
distilbert_model.eval()
with torch.no_grad():
    X = x_test_distilbert['input_ids']
    masks = x_test_distilbert['attention_mask']
    total_loss = 0.0
    for i in tqdm(range(0, X.size()[0], batch_size)):
        outputs = distilbert_model(X[i:i+batch_size], mask=masks[i:i+batch_size])
        predicted_embs = np.concatenate([predicted_embs, outputs.cpu()])

        loss = torch.sqrt(criterion(outputs, y_test[i:i+batch_size]))
        total_loss += loss.item()

total_loss/test_batches

100%|█████████████████████████████████████████████| 125/125 [00:05<00:00, 24.08it/s]


0.028233408972620964

### Save embeddings

In [16]:
np.save("embeds/DistilBERT_test_preds.npy", predicted_embs[1:])

### Winoground embeddings

In [16]:
win_df = pd.read_json("data/examples.jsonl", lines=True)
win_distilbert = distilbert_tokenizer.batch_encode_plus(win_df['caption_0'].tolist() + win_df['caption_1'].tolist(), return_tensors='pt', padding=True, add_special_tokens=True).to(device)
win_distilbert['input_ids'].shape

torch.Size([800, 32])

In [17]:
win_embs = np.array([[0]*768])
with torch.no_grad():
    X = win_distilbert['input_ids']
    masks = win_distilbert['attention_mask']
    for i in tqdm(range(0, X.size()[0], batch_size)):
        outputs = distilbert_model(X[i:i+batch_size], mask=masks[i:i+batch_size])
        win_embs = np.concatenate([win_embs, outputs.cpu()])

100%|██████████████████████████████████████████████| 50/50 [00:00<00:00, 119.67it/s]


In [18]:
np.save("embeds/DistilBERT_win_preds.npy", win_embs[1:])