In [1]:
!pip install transformers torch

Collecting transformers
  Downloading transformers-4.12.5-py3-none-any.whl (3.1 MB)
[K     |████████████████████████████████| 3.1 MB 20.4 MB/s eta 0:00:01
[?25hCollecting torch
  Downloading torch-1.10.0-cp36-cp36m-manylinux1_x86_64.whl (881.9 MB)
[K     |██████████████████████████████  | 827.9 MB 87.0 MB/s eta 0:00:01

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[K     |████████████████████████████████| 881.9 MB 621 bytes/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.1.2-py3-none-any.whl (59 kB)
[K     |████████████████████████████████| 59 kB 12.9 MB/s eta 0:00:01
Collecting sacremoses
  Downloading sacremoses-0.0.46-py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 58.2 MB/s eta 0:00:01
Collecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 64.3 MB/s eta 0:00:01
Installing collected packages: tokenizers, sacremoses, huggingface-hub, transformers, torch
Successfully installed huggingface-hub-0.1.2 sacremoses-0.0.46 tokenizers-0.10.3 torch-1.10.0 transformers-4.12.5
You should consider upgrading via the '/home/ec2-user/anaconda3/envs/python3/bin/python -m pip install --upgrade pip' command.[0m


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
from tqdm.notebook import tqdm
import os
import joblib
import json
import logging

logger = logging.getLogger(__name__)

In [108]:
class CLIPDataset(Dataset):
    
    def __init__(self,data_path, prefix_length, tokenizer):
        self.tokenizer = tokenizer
        self.prefix_length = prefix_length
        with open(data_path, 'rb') as file:
            data = joblib.load(file)
        logger.info(f"Length of Data is {len(data['clip_embedding'])}")
        self.prefixes = data['clip_embedding']
        self.captions = data['captions']
        self.caption_tokens = []
        self.caption_to_embedding = []
        self.max_seq_len = 128
        #self.captions = data['captions']
        for caption in captions:
            self.caption_tokens.append(torch.tensor(self.tokenizer.encode(caption), dtype=torch.int64))
            self.caption_to_embedding.append(self.prefixes)
            max_seq_len = 512
            
        with open(f"{data_path[:-6]}_tokens.joblib", 'wb') as f:
            joblib.dump([self.caption_tokens, self.caption_to_embedding], f)
            
        #all_len = torch.tensor([len(self.captions_tokens[i]) for i in range(len(self))]).float()
        #self.max_seq_len = min(int(all_len.mean() + all_len.std() * 10), int(all_len.max()))
        
    def __len__(self):
        return len(self.caption_tokens)
    
    def pad_tokens(self, item):
        tokens = self.caption_tokens[item]
        padding = self.max_seq_len - tokens.shape[0]
        if padding > 0:
            tokens = torch.cat((tokens, torch.zeros(padding, dtype=torch.int64) - 1))
            self.caption_tokens[item] = tokens
        elif padding < 0:
            tokens = tokens[:self.max_seq_len]
            self.captions_tokens[item] = tokens
        mask = tokens.ge(0)  # mask is zero where we out of sequence
        tokens[~mask] = 0
        mask = mask.float()
        mask = torch.cat((torch.ones(self.prefix_length), mask), dim=0)  # adding prefix mask
        return tokens, mask
    
    def __getitem__(self, item) :
        tokens, mask = self.pad_tokens(item)
        prefixes_ = self.caption_to_embedding[item]
        #return tokens, mask, self.prefixes[self.caption_to_embedding[item]], self.caption[item]
        return tokens, mask,  prefixes_
                #return tokens, mask, self.captions[item]
    

In [109]:
tokenizer = GPT2Tokenizer.from_pretrained("surajp/gpt2-hindi")
dataset = CLIPDataset("clip_embeddings.joblib",512,tokenizer)

In [110]:
dataset

<__main__.CLIPDataset at 0x7f8cee00d3c8>

In [120]:
encoding = dataset[0]
print(len(encoding))


3


In [122]:

class MLP(nn.Module):

    def __init__(self, sizes,bias=True, act=nn.Tanh):
        super(MLP, self).__init__()
        layers = []
        for i in range(len(sizes) - 1):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
            if i < len(sizes) - 2:
                layers.append(act())
        self.model = nn.Sequential(*layers)
        
    def forward(self, x):
        return self.model(x)


In [164]:
class ClipCaptionModel(nn.Module):
    
    def __init__(self, prefix_length, prefix_size=512):
        super().__init__()
        self.prefix_length = prefix_length
        self.gpt = GPT2LMHeadModel.from_pretrained('surajp/gpt2-hindi')
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        print(self.gpt_embedding_size)
        self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2,
                                 self.gpt_embedding_size * prefix_length))
        

    def get_dummy_token(self, batch_size, device):
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def forward(self, tokens, prefix, mask):
        embedding_text = self.gpt.transformer.wte(tokens)
        prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
        print(prefix_projections.shape)
        print(embedding_text.shape)
        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
        if labels is not None:
            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((dummy_token, tokens), dim=1)
        out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
        return out

In [168]:

def train(dataset, model,lr = 2e-5, warmup_steps=5000, output_dir="checkpoints", output_prefix="test"):

    device = torch.device('cuda:0')
    batch_size = 1
    epochs = 1
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    model = model.to(device)
    model.train()
    optimizer = AdamW(model.parameters(), lr=lr)
    train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=epochs * len(train_dataloader)
    )
    for epoch in range(epochs):
        print(f">>> Training epoch {epoch}")
        progress = tqdm(total=len(train_dataloader), desc=output_prefix)
        for idx, (tokens, mask, prefix) in enumerate(train_dataloader):
            model.zero_grad()
            tokens, mask, prefix = tokens.to(device), mask.to(device), prefix.to(device, dtype=torch.float32)
            outputs = model(tokens, prefix, mask)
            logits = outputs.logits[:, dataset.prefix_length - 1: -1]
            loss = nnf.cross_entropy(logits.reshape(-1, logits.shape[-1]), tokens.flatten(), ignore_index=0)
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            progress.set_postfix({"loss": loss.item()})
            progress.update()
            if (idx + 1) % 10000 == 0:
                torch.save(
                    model.state_dict(),
                    os.path.join(output_dir, f"{output_prefix}_latest.bin"),
                )
        if epoch % 100 == 0 or epoch == epochs - 1:
            torch.save(
                model.state_dict(),
                os.path.join(output_dir, f"{output_prefix}-{epoch:03d}.bin"),
            )
    return model

In [169]:
import torch
torch.cuda.empty_cache()

In [170]:
model = ClipCaptionModel(10)
train(dataset,model)

test:   0%|          | 0/20225 [01:44<?, ?it/s]
test:   0%|          | 0/20225 [00:50<?, ?it/s]


768
>>> Training epoch 0




test:   0%|          | 0/40450 [00:00<?, ?it/s][A[A

torch.Size([40450, 10, 768])
torch.Size([1, 128, 768])


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 40450 but got size 1 for tensor number 1 in the list.