In [80]:
from embedding_reader import EmbeddingReader
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import os
import math

# Load LAION CLIP Data

### Fetch if not already stored

In [2]:
# emebeddings are .npy
# metadata is .paraquet
embedding_reader = EmbeddingReader(
    embeddings_folder="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/",
    metadata_folder="https://mystic.the-eye.eu/public/AI/cah/laion5b/metadata/laion2B-en/",
    meta_columns=['SAMPLE_ID', 'TEXT'],
    file_format="parquet_npy"
)
print("embedding count", embedding_reader.count)
print("dimension", embedding_reader.dimension)
print("total size", embedding_reader.total_size)
print("byte per item", embedding_reader.byte_per_item)

100%|██████████████████████████████████████████████████████████████| 4611/4611 [03:19<00:00, 23.07it/s]

embedding count 116341562
dimension 768
total size 178700639232
byte per item 1536





In [3]:
for emb, meta in embedding_reader(batch_size=10 ** 4, start=0, end=10 ** 4, show_progress=True):
    print(emb.shape)
    print(meta.size)

100%|███████████████████████████████████████████████████████████████████| 1/1 [02:06<00:00, 126.01s/it]

(10000, 768)
30000





In [6]:
meta['emb'] = emb.tolist()

In [15]:
meta.to_parquet('laion2B-10000.parquet.gzip', compression='gzip')

### Create Dataset and Splits

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
torch.cuda.set_device(1)

In [2]:
meta = pd.read_parquet('laion2B-10000.parquet.gzip')

In [3]:
emb = torch.Tensor(meta['emb'])

  emb = torch.Tensor(meta['emb'])


In [4]:
train_idx = int(len(meta) * 0.8)
x_train, y_train = meta['TEXT'][:train_idx], emb[:train_idx]
x_test, y_test = meta['TEXT'][train_idx:], emb[train_idx:]

In [5]:
x_train.shape, y_train.size(), x_test.shape, y_test.size()

((8000,), torch.Size([8000, 768]), (2000,), torch.Size([2000, 768]))

# Fine-tune LM -> Predict CLIP image embeddings

In [6]:
from transformers import AutoModel, AutoTokenizer

In [117]:
epochs = 4
batch_size = 64
train_batches = math.ceil(x_train.shape[0]/batch_size)
test_batches = math.ceil(x_test.shape[0]/batch_size)

## BERT

In [37]:
class CLIPEmbBERT(nn.Module):
    def __init__(self):
        super(CLIPEmbBERT, self).__init__()
        self.model = AutoModel.from_pretrained("bert-base-cased")
          ### New layers: None

    def forward(self, tokens, mask):
        cls_emebdding = self.model(tokens, attention_mask=mask).pooler_output

        return cls_emebdding 

In [38]:
bert_model = CLIPEmbBERT()
bert_model.to(device)

bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(bert_model.parameters(), lr=5e-5, weight_decay=1e-6)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [39]:
x_train_bert = distilbert_tokenizer.batch_encode_plus(list(x_train), return_tensors='pt', padding=True, add_special_tokens=True).to(device)
y_train = y_train.to(device)

In [None]:
bert_model.train()
for epoch in range(epochs):
    
    rand_ids = torch.randperm(x_train_bert['input_ids'].size()[0])
    X = x_train_bert['input_ids'][rand_ids]
    masks = x_train_bert['attention_mask'][rand_ids]
    
    for i in tqdm(range(0, X.size()[0], batch_size)):
        optimizer.zero_grad()
        
        outputs = bert_model(X[i:i+batch_size], mask=masks[i:i+batch_size])
        
        loss = torch.sqrt(criterion(outputs, y_train[i:i+batch_size]))

        loss.backward()
        optimizer.step()
    
    print(f"Loss Epoch {epoch}: {loss}")

100%|█████████████████████████████████████████| 125/125 [1:18:42<00:00, 37.78s/it]


Loss Epoch 0: 0.002549578435719013


 18%|███████▋                                  | 23/125 [14:20<1:00:06, 35.35s/it]

In [None]:
x_test_bert = bert_tokenizer.batch_encode_plus(list(x_test), return_tensors='pt', padding=True, add_special_tokens=True).to(device)
y_test = y_test.to(device)

In [None]:
bert_model.eval()
with torch.no_grad():
    X = x_test_bert['input_ids']
    masks = x_test_bert['attention_mask']
    total_loss = 0.0
    for i in tqdm(range(0, X.size()[0], batch_size)):
        outputs = bert_model(X[i:i+batch_size], mask=masks[i:i+batch_size])

        loss = torch.sqrt(criterion(outputs, y_test[i:i+batch_size]))
        total_loss += loss.item()

total_loss/test_batches

## DistilBERT

In [110]:
from transformers import DistilBertForSequenceClassification
class CLIPEmbDistilBERT(nn.Module):
    def __init__(self):
        super(CLIPEmbDistilBERT, self).__init__()
        self.model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=768)
          ### New layers: None

    def forward(self, tokens, mask):
        cls_emebdding = self.model(tokens, attention_mask=mask).logits

        return cls_emebdding

In [121]:
distilbert_model = CLIPEmbDistilBERT()
distilbert_model.to(device)

distilbert_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")

criterion = nn.MSELoss()

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.bias', 'classifier

In [122]:
optimizer = torch.optim.Adam(distilbert_model.parameters(), lr=3e-5, weight_decay=0.0)

In [123]:
x_train_distilbert = distilbert_tokenizer.batch_encode_plus(list(x_train), return_tensors='pt', padding=True, add_special_tokens=True).to(device)
y_train = y_train.to(device)

In [124]:
distilbert_model.train()
for epoch in range(epochs):
    
    rand_ids = torch.randperm(x_train_distilbert['input_ids'].size()[0])
    X = x_train_distilbert['input_ids'][rand_ids]
    masks = x_train_distilbert['attention_mask'][rand_ids]
    
    total_loss = 0.0
    
    for i in tqdm(range(0, X.size()[0], batch_size)):
        optimizer.zero_grad()
        
        outputs = distilbert_model(X[i:i+batch_size], mask=masks[i:i+batch_size])
        
        loss = torch.sqrt(criterion(outputs, y_train[i:i+batch_size]))
        total_loss += loss.item()

        loss.backward()
        optimizer.step()
    
    print(f"Loss Epoch {epoch+1} / {epochs}: {total_loss/train_batches}")

100%|████████████████████████████████████████████| 500/500 [00:29<00:00, 17.08it/s]


Loss Epoch 1 / 4: 0.03193383591994643


100%|████████████████████████████████████████████| 500/500 [00:29<00:00, 17.10it/s]


Loss Epoch 2 / 4: 0.029477865733206272


100%|████████████████████████████████████████████| 500/500 [00:29<00:00, 17.16it/s]


Loss Epoch 3 / 4: 0.029053160302340984


100%|████████████████████████████████████████████| 500/500 [00:29<00:00, 17.17it/s]

Loss Epoch 4 / 4: 0.02882398197427392





In [125]:
x_test_distilbert = distilbert_tokenizer.batch_encode_plus(list(x_test), return_tensors='pt', padding=True, add_special_tokens=True).to(device)
y_test = y_test.to(device)

In [126]:
distilbert_model.eval()
with torch.no_grad():
    X = x_test_distilbert['input_ids']
    masks = x_test_distilbert['attention_mask']
    total_loss = 0.0
    for i in tqdm(range(0, X.size()[0], batch_size)):
        outputs = distilbert_model(X[i:i+batch_size], mask=masks[i:i+batch_size])

        loss = torch.sqrt(criterion(outputs, y_test[i:i+batch_size]))
        total_loss += loss.item()

total_loss/test_batches

100%|████████████████████████████████████████████| 125/125 [00:02<00:00, 54.64it/s]


0.0281514692902565