In [1]:
!pip install pytorch-lightning --upgrade -q

[0m

In [2]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.transforms import InterpolationMode
import matplotlib.pyplot as plt
from PIL import Image
from transformers import AutoModel, AutoTokenizer
from timm import create_model, list_models
from dataclasses import dataclass, asdict
from sklearn.model_selection import train_test_split
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import WandbLogger
from pathlib import Path
import wandb

In [3]:
%env TOKENIZERS_PARALLELISM = false

env: TOKENIZERS_PARALLELISM=false


In [4]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
key = user_secrets.get_secret("wandb-key")
wandb.login(key=key)

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [5]:
@dataclass
class Config:
    text_encoder = 'google/bert_uncased_L-4_H-256_A-4'
    image_encoder = 'resnet50d'
    
    # CLIP CONFIG
    proj_dim = 256
    dropout = 0.1
    max_length = 128
    
    save_path = Path('liteCLIP')

In [6]:
class ImageEncoder(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.model_name = model_name
        self.backbone = create_model(self.model_name, 
                                     pretrained=True, 
                                     num_classes=1,
                                    )
        self.embed_dim = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()
        
    def forward(self,x):
        return self.backbone(x)

In [7]:
class TextEncoder(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.model_name = model_name
        self.backbone = AutoModel.from_pretrained(self.model_name)
        self.embed_dim = self.backbone.config.hidden_size
        
    def mean_pooler(self, last_hidden_state, attention_mask):
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        mean_embeddings = sum_embeddings / sum_mask
        return mean_embeddings
    
    def forward(self, inputs):
        outputs = self.backbone(**inputs)
        pooled_output = self.mean_pooler(outputs['last_hidden_state'],inputs['attention_mask'])
        return pooled_output

In [8]:
class ProjectionHead(nn.Module):
    def __init__(self, embed_dim, Config):
        super().__init__()
        self.embed_dim = embed_dim
        self.proj_dim = Config.proj_dim
        self.dropout = Config.dropout
        
        self.proj = nn.Linear(self.embed_dim, self.proj_dim)
        self.act = nn.GELU()
        self.drop = nn.Dropout(self.dropout)
        self.ln = nn.LayerNorm(self.proj_dim)
        
    def forward(self, x):
        x = self.proj(x)
        out = self.act(x)
        out = self.drop(out)
        x = x + out
        x = self.ln(x)
        return x

In [9]:
class CLIP(nn.Module):
    def __init__(self, Config):
        super().__init__()
        
        self.image_encoder = ImageEncoder(Config.image_encoder)
        self.text_encoder = TextEncoder(Config.text_encoder)
        
        self.im_embed_dim = self.image_encoder.embed_dim
        self.txt_embed_dim = self.text_encoder.embed_dim 
        
        self.img_projection = ProjectionHead(self.im_embed_dim,Config)
        self.txt_projection = ProjectionHead(self.txt_embed_dim,Config)
        
        
    def forward(self,inputs):
        image, text = inputs
        
        image_embeddings = self.image_encoder(image)
        image_embeddings = self.img_projection(image_embeddings)
        
        text_embeddings = self.text_encoder(text)
        text_embeddings = self.txt_projection(text_embeddings)
        
        # logits will be in the shape batch_size X batch_size
        logits = (text_embeddings @ image_embeddings.T)
        
        return logits

In [10]:
tokenizer = AutoTokenizer.from_pretrained(Config.text_encoder)

Downloading (…)lve/main/config.json:   0%|          | 0.00/383 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

In [11]:
with open('/kaggle/input/flickr8k/captions.txt','r') as f:
    captions = f.readlines()

In [12]:
df = pd.read_csv('/kaggle/input/flickr8k/captions.txt')
base_path = '/kaggle/input/flickr8k/Images/'

In [13]:
df['image'] = df['image'].map(lambda x:base_path+x)

In [14]:
df

Unnamed: 0,image,caption
0,/kaggle/input/flickr8k/Images/1000268201_693b0...,A child in a pink dress is climbing up a set o...
1,/kaggle/input/flickr8k/Images/1000268201_693b0...,A girl going into a wooden building .
2,/kaggle/input/flickr8k/Images/1000268201_693b0...,A little girl climbing into a wooden playhouse .
3,/kaggle/input/flickr8k/Images/1000268201_693b0...,A little girl climbing the stairs to her playh...
4,/kaggle/input/flickr8k/Images/1000268201_693b0...,A little girl in a pink dress going into a woo...
...,...,...
40450,/kaggle/input/flickr8k/Images/997722733_0cb543...,A man in a pink shirt climbs a rock face
40451,/kaggle/input/flickr8k/Images/997722733_0cb543...,A man is rock climbing high in the air .
40452,/kaggle/input/flickr8k/Images/997722733_0cb543...,A person in a red shirt climbing up a rock fac...
40453,/kaggle/input/flickr8k/Images/997722733_0cb543...,A rock climber in a red shirt .


In [15]:
def showimage(idx):
    print(df['caption'][idx])
    plt.imshow(Image.open(df['image'][idx]))

In [16]:
train_df, val_df = train_test_split(df,test_size=0.2,shuffle=True)
train_df.reset_index(drop=True,inplace=True)
val_df.reset_index(drop=True,inplace=True)

In [17]:
len(train_df), len(val_df)

(32364, 8091)

In [18]:
class FlickrDataset:
    def __init__(self, df,tokenizer, Config):
        self.df = df
        self.tfms = T.Compose([
            T.Resize(224,interpolation=InterpolationMode.BICUBIC),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), 
                        std=(0.26862954, 0.26130258, 0.27577711))
        ])
        self.tokenizer = tokenizer
        self.max_len = Config.max_length
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self,idx):
        img = self.df['image'][idx]
        img = Image.open(img).convert('RGB')
        img = self.tfms(img)
        caption = self.df['caption'][idx]
        caption = self.tokenizer.encode_plus(caption,
                                             return_tensors='pt',
                                             padding='max_length',
                                             truncation=True,
                                             max_length = self.max_len
                                            )
        caption = {k:torch.squeeze(v,0) for k,v in caption.items()}
        return img, caption

In [19]:
train_ds = FlickrDataset(train_df,tokenizer, Config)
val_ds = FlickrDataset(val_df,tokenizer, Config)

In [20]:
class CLIPLoss(nn.Module):
    def __init__(self,):
        super().__init__()
        
    def forward(self, inputs):
        labels = torch.arange(inputs.size(0)).to(inputs.device)
        txt_loss = F.cross_entropy(inputs,labels,reduction='none')
        img_loss = F.cross_entropy(inputs.T,labels.T,reduction='none')
        loss = (txt_loss + img_loss).mean()
        return loss

In [21]:
class CLIPLightning(pl.LightningModule):
    def __init__(self,model,config):
        super(CLIPLightning,self).__init__()
        
        self.config = config
        self.model = model(self.config)
        
        self.loss_fn = CLIPLoss()
        
    def forward(self,x):
        return self.model(x)
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4)
        return [optimizer]
    
    def training_step(self, batch, batch_idx):
        
        logits = self(batch)
        loss = self.loss_fn(logits)
        
        self.log('train_loss_step',loss.item(),sync_dist=True)
        self.log('train_loss', loss.item(), on_epoch=True, sync_dist=True)
        
        return loss
    
    
    def validation_step(self, batch, batch_idx):
        
        logits = self(batch)
        loss = self.loss_fn(logits)
        
        self.log('val_loss_step',loss.item(),sync_dist=True)
        self.log('val_loss', loss.item(), on_epoch=True, sync_dist=True)

In [22]:
model = CLIPLightning(CLIP,Config)

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth" to /root/.cache/torch/hub/checkpoints/resnet50d_ra2-464e36ba.pth


Downloading (…)"pytorch_model.bin";:   0%|          | 0.00/45.1M [00:00<?, ?B/s]

Some weights of the model checkpoint at google/bert_uncased_L-4_H-256_A-4 were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [23]:
train_dl = torch.utils.data.DataLoader(train_ds,
                                       batch_size=64,
                                       shuffle=True,
                                       num_workers=2,
                                       pin_memory=True,
                                       persistent_workers=True
                                      )
valid_dl = torch.utils.data.DataLoader(val_ds,
                                       batch_size=64,
                                       num_workers=2,
                                       pin_memory=True,
                                       persistent_workers=True
                                      )

In [24]:
logger = WandbLogger(name='lite_clip_bert_medium_resnet50d',project='liteCLIP')
logger.experiment.config.update({
    'projection_embedding_size': Config.proj_dim,
    'projection_dropout': Config.dropout,
    'max_tokens': Config.max_length
})

[34m[1mwandb[0m: Currently logged in as: [33mshreydan[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [25]:
early_stop = EarlyStopping(monitor='val_loss',patience=2,mode='min')
trainer = pl.Trainer(accelerator='gpu',
                     max_epochs=5,
                     devices=2, 
                     logger=logger,
                     callbacks=[early_stop]
                    )

In [26]:
logger.watch(model)
trainer.fit(model, train_dl, valid_dl)
metrics = trainer.logged_metrics
print(metrics)

[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
  
  


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

{}


In [27]:
Config.save_path.mkdir(exist_ok=True)

In [28]:
torch.save(trainer.model.model.state_dict(),Config.save_path/'clip_model.pt')
tokenizer.save_pretrained(Config.save_path)

('liteCLIP/tokenizer_config.json',
 'liteCLIP/special_tokens_map.json',
 'liteCLIP/vocab.txt',
 'liteCLIP/added_tokens.json',
 'liteCLIP/tokenizer.json')

In [29]:
inference_model = CLIP(Config)

Some weights of the model checkpoint at google/bert_uncased_L-4_H-256_A-4 were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [30]:
inference_model.load_state_dict(torch.load(Config.save_path/'clip_model.pt'),)

<All keys matched successfully>

In [31]:
inference_model.eval()

CLIP(
  (image_encoder): ImageEncoder(
    (backbone): ResNet(
      (conv1): Sequential(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
        (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d

In [32]:
test = torch.utils.data.DataLoader(val_ds,batch_size=2)
x = next(iter(test))

In [33]:
img,text = x
img.shape

torch.Size([2, 3, 224, 224])

In [34]:
with torch.no_grad():
    print(inference_model(x))

tensor([[51.4102, 25.6722],
        [21.7080, 39.4997]])


In [35]:
with torch.no_grad():
    img_encodings = inference_model.image_encoder(img)
    img_encodings = inference_model.img_projection(img_encodings)

In [36]:
from transformers import AutoConfig

In [37]:
te = AutoConfig.from_pretrained(Config.text_encoder)

In [38]:
te.save_pretrained(Config.save_path)

In [39]:
!ls liteCLIP

1wow53um       config.json		tokenizer.json	       vocab.txt
clip_model.pt  special_tokens_map.json	tokenizer_config.json


In [40]:
config = AutoConfig.from_pretrained(Config.save_path)
new_text_encoder = AutoModel.from_config(config)