In [2]:
!pip install wandb
!pip install bitsandbytes
!pip install ruclip==0.0.2
!pip install transformers==4.27.4
!pip install pycocotools
!pip install git+https://github.com/openai/CLIP.git
!pip install open_clip_torch

Collecting bitsandbytes
  Obtaining dependency information for bitsandbytes from https://files.pythonhosted.org/packages/d9/8d/b62d4fb02587e293e5b91b68bbcaa2d88c6a0360b622e9521d4bd07a20cd/bitsandbytes-0.41.3.post2-py3-none-any.whl.metadata
  Downloading bitsandbytes-0.41.3.post2-py3-none-any.whl.metadata (9.8 kB)
Downloading bitsandbytes-0.41.3.post2-py3-none-any.whl (92.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m92.6/92.6 MB[0m [31m13.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.41.3.post2
Collecting ruclip==0.0.2
  Downloading ruclip-0.0.2-py3-none-any.whl (14 kB)
Collecting huggingface-hub==0.2.1 (from ruclip==0.0.2)
  Downloading huggingface_hub-0.2.1-py3-none-any.whl (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.9/61.9 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting youtokentome~=1.0.6 (from ruclip==0.0.2)
  Downloadin

In [3]:
import torch
import torch.nn as nn
import os
import pickle
import sys
import argparse
import json
import ruclip
import clip, open_clip
import random
import io
import bitsandbytes as bnb
import wandb
import nltk
import numpy as np

from nltk.translate.bleu_score import corpus_bleu

from sklearn.model_selection import train_test_split

from datasets import load_dataset, load_metric

from torch.nn import functional as nnf
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast
from torch.utils.data import Subset


from transformers import GPT2Config, GPT2Model
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import AdamW, get_linear_schedule_with_warmup
from transformers.optimization import Adafactor, AdafactorSchedule

from typing import Tuple, Optional, Union
from tqdm import tqdm, trange
from enum import Enum



In [4]:
# Global seed for every RNG this notebook can touch.
manualSeed = 1337
# manualSeed = random.randint(1, 10000)  # use if you want new results
random.seed(manualSeed)
np.random.seed(manualSeed)  # numpy is imported above; seed it too
torch.manual_seed(manualSeed)  # seeds the RNG on all devices (CPU and CUDA)

# First GPU when available, otherwise CPU; used by the training loop and model cells.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
# Selects the module ClipCaptionModel uses to project a CLIP embedding into the
# GPT prefix: MappingType.MLP -> MLP, any other member -> TransformerMapper.
MappingType = Enum('MappingType', [('MLP', 'mlp'), ('Transformer', 'transformer')])

class MLP(nn.Module):
    """Stack of Linear layers joining each consecutive pair in ``sizes``.

    An ``act()`` activation follows every Linear layer except the last one.
    The layers live in ``self.model`` (an ``nn.Sequential``).

    Args:
        sizes: layer widths, e.g. ``(in, hidden, out)``.
        bias: passed to every ``nn.Linear``.
        act: activation class, instantiated once per hidden layer.
    """

    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super().__init__()
        pairs = list(zip(sizes[:-1], sizes[1:]))
        modules = []
        for position, (fan_in, fan_out) in enumerate(pairs):
            modules.append(nn.Linear(fan_in, fan_out, bias=bias))
            is_last = position == len(pairs) - 1
            if not is_last:
                modules.append(act())
        self.model = nn.Sequential(*modules)

    @autocast()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer stack to ``x`` (inside torch.cuda.amp.autocast)."""
        return self.model(x)


In [6]:
def freeze(
    model,
    freeze_emb=False,
    freeze_ln=False,
    freeze_attn=True,
    freeze_ff=True,
    freeze_other=False,
):
    """Set ``requires_grad`` on every parameter of ``model`` by parameter group.

    Each parameter is classified by its lower-cased name. The first matching
    rule wins:

    1. layer norm: the name contains ``ln`` or ``norm``;
    2. embeddings: the name contains ``embeddings``, or one of its dotted
       components is ``wte``/``wpe`` (GPT-2's token/position embeddings);
    3. feed-forward: the name contains ``mlp``;
    4. attention: the name contains ``attn``;
    5. everything else.

    Args:
        model: any ``nn.Module``.
        freeze_emb, freeze_ln, freeze_attn, freeze_ff, freeze_other: if True,
            parameters in that group get ``requires_grad=False``; otherwise
            they get ``requires_grad=True``.

    Returns:
        The same ``model``, modified in place.
    """
    for name, p in model.named_parameters():
        name = name.lower()
        parts = name.split('.')
        if 'ln' in name or 'norm' in name:
            p.requires_grad = not freeze_ln
        elif 'embeddings' in name or 'wte' in parts or 'wpe' in parts:
            # GPT-2 does not use the word "embeddings" in its parameter names.
            p.requires_grad = not freeze_emb
        elif 'mlp' in name:
            p.requires_grad = not freeze_ff
        elif 'attn' in name:
            p.requires_grad = not freeze_attn
        else:
            p.requires_grad = not freeze_other

    return model

In [7]:
gpt_model_name = 'sberbank-ai/rugpt3medium_based_on_gpt2'
class ClipCaptionModel(nn.Module):
    """Captioner that turns a CLIP embedding into a prefix for a GPT-2 LM.

    ``clip_project`` maps one CLIP vector to ``prefix_length`` vectors of the
    GPT-2 embedding width. These are concatenated in front of the caption's
    token embeddings and passed to ``GPT2LMHeadModel`` via ``inputs_embeds``.

    Args:
        prefix_length: number of prefix embeddings produced per CLIP vector.
        clip_length: forwarded to ``TransformerMapper`` (unused by the MLP mapper).
        prefix_size: dimensionality of the CLIP embedding.
        num_layers: forwarded to ``TransformerMapper``.
        mapping_type: ``MappingType.MLP`` selects ``MLP``; any other value
            selects ``TransformerMapper``.
    """

    def __init__(
        self,
        prefix_length: int,
        clip_length: Optional[int] = None,
        prefix_size: int = 512,
        num_layers: int = 8,
        mapping_type: MappingType = MappingType.MLP
    ):
        super(ClipCaptionModel, self).__init__()
        self.prefix_length = prefix_length

        self.gpt = GPT2LMHeadModel.from_pretrained(gpt_model_name)
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]

        if mapping_type == MappingType.MLP:
            # Two-layer MLP: prefix_size -> (emb * prefix_length) / 2 -> emb * prefix_length.
            self.clip_project = MLP((
                prefix_size,
                self.gpt_embedding_size * prefix_length // 2,
                self.gpt_embedding_size * prefix_length
            ))
        else:
            # NOTE(review): TransformerMapper is not defined in any visible cell.
            self.clip_project = TransformerMapper(
                prefix_size,
                self.gpt_embedding_size,
                prefix_length,
                clip_length,
                num_layers
            )

    def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Return a (batch_size, prefix_length) int64 tensor of zeros on ``device``."""
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    @autocast()
    def forward(
        self,
        tokens: torch.Tensor,
        prefix: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None
    ):
        """Run GPT-2 on ``[projected prefix ; token embeddings]``.

        Args:
            tokens: caption token ids; concatenated with the prefix along dim 1,
                so presumably (batch, seq).
            prefix: CLIP embeddings; projected and reshaped to
                (batch, prefix_length, gpt_embedding_size).
            mask: optional attention mask over prefix + caption positions.
            labels: if not None, GPT-2's built-in loss is requested. The label
                values are rebuilt from ``tokens`` (the argument's contents are
                not used). Prefix positions, and positions where ``mask`` is 0,
                get -100 so that the loss ignores them.

        Returns:
            The ``GPT2LMHeadModel`` output, with hidden states included.
        """
        embedding_text = self.gpt.transformer.wte(tokens)
        prefix_projections = self.clip_project(
            prefix.float()
        ).view(-1, self.prefix_length, self.gpt_embedding_size)

        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)

        if labels is not None:
            # -100 is the ignore index of GPT2LMHeadModel's loss: prefix positions
            # have no target token and must not contribute to the loss.
            ignored = torch.full(
                (tokens.shape[0], self.prefix_length),
                -100,
                dtype=torch.int64,
                device=tokens.device,
            )
            labels = torch.cat((ignored, tokens), dim=1)
            if mask is not None:
                # Padding positions (mask 0) are not caption tokens either.
                labels = labels.masked_fill(mask == 0, -100)
        out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask, output_hidden_states=True)

        return out

class ClipCaptionPrefix(ClipCaptionModel):
    """ClipCaptionModel variant that trains only the prefix projection ``clip_project``.

    The GPT-2 weights are frozen (``requires_grad=False``). Gradients still flow
    *through* GPT-2 back to the prefix embeddings, but no gradient buffers are
    allocated or accumulated for GPT-2's own parameters.
    """

    def __init__(self, *args, **kwargs):
        super(ClipCaptionPrefix, self).__init__(*args, **kwargs)
        # parameters() below hides GPT-2 from optimizers and from zero_grad().
        # Without freezing, every backward would fill GPT-2's .grad tensors and
        # nothing would ever clear them.
        self.gpt.requires_grad_(False)

    def parameters(self, recurse: bool = True):
        """Return only ``clip_project``'s parameters; ``recurse`` is accepted but ignored."""
        return self.clip_project.parameters()

    def train(self, mode: bool = True):
        """Set train/eval ``mode`` on the module, but always keep ``gpt`` in eval mode."""
        super(ClipCaptionPrefix, self).train(mode)
        self.gpt.eval()
        return self

In [8]:
class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that loads pickled torch storages onto the CPU.

    Every global other than ``torch.storage._load_from_bytes`` is resolved
    normally. That one is swapped for a loader calling
    ``torch.load(..., map_location='cpu')``.

    NOTE(review): unpickling can execute arbitrary code; only use on trusted files.
    """

    def find_class(self, module, name):
        is_storage_loader = module == 'torch.storage' and name == '_load_from_bytes'
        if not is_storage_loader:
            return super().find_class(module, name)

        def _load_on_cpu(b):
            return torch.load(io.BytesIO(b), map_location='cpu')

        return _load_on_cpu

In [9]:
class ClipCocoDataset(Dataset):
    """Captions paired with precomputed CLIP embeddings, tokenized for GPT-2.

    The pickle at ``data_path`` must contain ``"clip_embedding"``,
    ``"captions"`` and ``"path_images"`` entries. Every caption is tokenized
    once up front. Items are ``(tokens, mask, prefix)`` with tokens
    padded or truncated to ``max_seq_len``.

    Args:
        data_path: path to the pickled embeddings file.
        prefix_length: number of leading ones added to each attention mask.
        model_type: tokenizer name passed to ``GPT2Tokenizer.from_pretrained``.
        normalize_prefix: if True, L2-normalize each prefix in ``__getitem__``.
        train: selects the image folder used by ``get_image``.
    """

    def __init__(
        self,
        data_path: str,
        prefix_length=30,
        model_type = gpt_model_name,
        normalize_prefix=False,
        train=True,
    ):
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_type)
        self.prefix_length = prefix_length
        self.normalize_prefix = normalize_prefix
        # Needed by get_image() to pick the train/val image folder.
        self.train = train

        # NOTE(review): pickle can execute arbitrary code; only load trusted files.
        with open(data_path, 'rb') as f:
            all_data = CPU_Unpickler(f).load()
        print("Data size is %0d" % len(all_data["clip_embedding"]))
        sys.stdout.flush()

        self.prefixes = all_data["clip_embedding"]
        self.captions = all_data["captions"]
        self.image_id = all_data["path_images"]

        self.captions_tokens = []
        self.caption2embedding = []
        for i, caption in enumerate(tqdm(self.captions)):
            self.captions_tokens.append(
                torch.tensor(self.tokenizer.encode(caption), dtype=torch.int64)
            )
            self.caption2embedding.append(self.prefixes[i])

        # Cap the sequence length at mean + 10 std, never above the longest caption.
        lengths = torch.tensor([len(t) for t in self.captions_tokens], dtype=torch.float)
        if lengths.numel() == 0:
            self.max_seq_len = 0
        else:
            # std() of a single sample is NaN; treat it as zero spread.
            spread = lengths.std() if lengths.numel() > 1 else lengths.new_zeros(())
            self.max_seq_len = min(int(lengths.mean() + spread * 10), int(lengths.max()))

    def get_image(self, item):
        """Load image ``item`` as an RGB PIL image, thumbnailed to fit 196x196.

        NOTE(review): ``cv2`` and ``Image`` (PIL) are not imported in any visible
        cell; they must be imported before this method is called.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        if self.train:
            path_img = f"/kaggle/input/train2014/train2014/{self.image_id[item]}"
        else:
            path_img = f"/kaggle/input/val2014/val2014/{self.image_id[item]}"

        image = cv2.imread(path_img)
        if image is None:
            # cv2.imread reports a missing or unreadable file by returning None.
            raise FileNotFoundError(f"Could not read image: {path_img}")
        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        image.thumbnail((196, 196), Image.Resampling.LANCZOS)
        return image

    def pad_tokens(self, item: int):
        """Return ``(tokens, mask)`` for caption ``item``, padded or truncated to ``max_seq_len``.

        Padding positions hold token 0 and mask 0. The float mask is prefixed
        with ``prefix_length`` ones for the prefix embeddings.
        """
        tokens = self.captions_tokens[item]
        padding = self.max_seq_len - tokens.shape[0]
        if padding > 0:
            # Pad with -1 so padding is distinguishable from real ids when building the mask.
            tokens = torch.cat((tokens, torch.zeros(padding, dtype=torch.int64) - 1))
        elif padding < 0:
            tokens = tokens[:self.max_seq_len]
        mask = tokens.ge(0)  # False on padding positions
        # masked_fill returns a new tensor, so the cached caption tokens are never modified.
        tokens = tokens.masked_fill(~mask, 0)
        mask = torch.cat((torch.ones(self.prefix_length), mask.float()), dim=0)
        return tokens, mask

    def __len__(self) -> int:
        return len(self.captions_tokens)

    def __getitem__(self, item):
        """Return ``(tokens, mask, prefix)`` for index ``item``."""
        tokens, mask = self.pad_tokens(item)
        prefix = self.prefixes[item]
        if self.normalize_prefix:
            prefix = prefix.float()
            # keepdim keeps the norm broadcastable when a prefix has more than one dim.
            prefix = prefix / prefix.norm(2, -1, keepdim=True)
        return tokens, mask, prefix

In [10]:
def calc_bleu(y_pred, y_true, smoothing_function=None):
    """Corpus-level BLEU-4 (uniform 1- to 4-gram weights), scaled to 0-100.

    Texts are whitespace-tokenized, and each prediction has exactly one reference.

    Args:
        y_pred: predicted strings.
        y_true: reference strings, aligned with ``y_pred``.
        smoothing_function: optional NLTK smoothing method, e.g.
            ``SmoothingFunction().method1``. The default (None) keeps NLTK's
            unsmoothed BLEU, which is effectively 0 whenever some n-gram order
            has no match (common for a handful of short sentences).

    Returns:
        The BLEU-4 score multiplied by 100.
    """
    references = [[reference.split()] for reference in y_true]
    hypotheses = [hypothesis.split() for hypothesis in y_pred]
    # Compute BLEU-4.
    bleu_score = corpus_bleu(
        references,
        hypotheses,
        weights=(0.25, 0.25, 0.25, 0.25),
        smoothing_function=smoothing_function,
    )
    return bleu_score * 100

## **TRAIN LOOP**

In [11]:
def _decode_for_bleu(tokenizer, logits, tokens):
    """Decode one batch into parallel (generated, reference) text lists for BLEU.

    ``logits`` and ``tokens`` must be aligned position by position (in ``train``
    both cover only the caption positions). Generated text is the greedy
    (argmax) decoding, cut after its first '.'. Reference text is the decoded
    ``tokens``, cut before its first '<pad>'. One pair is produced per row of
    ``tokens``, so a short final batch is handled.
    """
    predicted_ids = logits.argmax(dim=-1).cpu()
    target_ids = tokens.cpu()
    generated_texts, real_texts = [], []
    for row in range(target_ids.shape[0]):
        generated = tokenizer.decode(predicted_ids[row].tolist())
        dot = generated.find('.')
        generated_texts.append(generated[:dot + 1] if dot != -1 else generated)

        real = tokenizer.decode(target_ids[row].tolist())
        pad = real.find('<pad>')
        real_texts.append(real[:pad] if pad != -1 else real)
    return generated_texts, real_texts


def _mean_or_nan(values):
    """Mean of ``values``, or NaN (without a numpy warning) when ``values`` is empty."""
    return float(np.mean(values)) if values else float("nan")


def train(
    train_dataset: ClipCocoDataset,
    train_dataloader,
    model: ClipCaptionModel,
    optimizer,
    scheduler,
    args,
    warmup_steps: int = 5000,
    output_dir: str = ".",
    output_prefix: str = "",
):
    """Train ``model`` for ``args.epochs`` epochs with token-level cross-entropy.

    For each batch the loop does the following:
    1. Run the model on ``(tokens, prefix, mask)``.
    2. Keep the logits aligned with the caption tokens.
    3. Compute cross-entropy that ignores token id 0, which
       ``ClipCocoDataset.pad_tokens`` writes on padding positions.
    4. Step the optimizer and the scheduler.

    W&B logging: the loss every 500 steps, BLEU-4 of greedy decodings on the
    current batch every 1000 steps, and per-epoch means.

    Checkpoints:
    - ``{output_prefix}_latest_gpt2_medium.pt`` every 7000 steps;
    - ``{output_prefix}-{epoch:03d}_gpt2_medium.pt`` every ``args.save_every``
      epochs and after the final epoch.

    Args:
        train_dataset: provides ``prefix_length`` and ``tokenizer``.
        train_dataloader: yields ``(tokens, mask, prefix)`` batches.
        model: the captioning model (uses the global ``device``).
        optimizer, scheduler: stepped once per batch.
        args: must provide ``epochs`` and ``save_every``.
        warmup_steps: unused here; warm-up is configured on ``scheduler``.
        output_dir: checkpoint directory (created if missing).
        output_prefix: checkpoint file prefix and progress-bar label.

    Returns:
        The trained ``model``.
    """
    epochs = args.epochs
    os.makedirs(output_dir, exist_ok=True)

    model.train()

    mean_epoch_train_loss = []
    mean_bleu_train_epoch = []

    for epoch in range(epochs):
        loss_train_epoch = []
        bleu_train_epoch = []
        print(f">>> Training epoch {epoch+1}")
        sys.stdout.flush()
        progress = tqdm(total=len(train_dataloader), desc=output_prefix)
        for idx, (tokens, mask, prefix) in enumerate(train_dataloader):
            step = idx + 1
            model.zero_grad()
            tokens, mask, prefix = tokens.to(device), mask.to(device), prefix.to(device, dtype=torch.bfloat16)

            outputs = model(tokens, prefix, mask)
            # The logit at position t predicts token t+1: drop the last position and
            # start one before the first caption token.
            logits = outputs.logits[:, train_dataset.prefix_length - 1: -1]

            loss = nnf.cross_entropy(
                logits.reshape(-1, logits.shape[-1]),
                tokens.flatten().to(torch.int64),
                ignore_index=0,
            )

            loss.backward()
            optimizer.step()
            scheduler.step()

            loss_value = loss.item()
            loss_train_epoch.append(loss_value)
            optimizer.zero_grad()

            progress.set_postfix({"supervised_loss_train": loss_value})

            if step % 500 == 0:
                wandb.log({"supervised_loss_train": loss_value})
                wandb.log({"mean_supervised_loss_train": np.mean(loss_train_epoch)})

            if step % 1000 == 0:
                with torch.no_grad():
                    # BLEU-4 on this batch's greedy predictions.
                    generated_texts, real_texts = _decode_for_bleu(
                        train_dataset.tokenizer, logits, tokens
                    )
                    bleu = calc_bleu(generated_texts, real_texts)
                wandb.log({"supervised_bleu-4_train": bleu})
                bleu_train_epoch.append(bleu)
                wandb.log({"mean_supervised_bleu-4_train": np.mean(bleu_train_epoch)})

            progress.update()
            if step % 7000 == 0:
                torch.save(
                    model.state_dict(),
                    os.path.join(output_dir, f"{output_prefix}_latest_gpt2_medium.pt"),
                )
        progress.close()
        # Also save after the last epoch, which save_every > 1 could otherwise skip.
        if epoch % args.save_every == 0 or epoch == epochs - 1:
            torch.save(
                model.state_dict(),
                os.path.join(output_dir, f"{output_prefix}-{(epoch+1):03d}_gpt2_medium.pt"),
            )
        mean_epoch_train_loss.append(_mean_or_nan(loss_train_epoch))
        mean_bleu_train_epoch.append(_mean_or_nan(bleu_train_epoch))

        wandb.log({"mean_epoch_sup_train_loss": mean_epoch_train_loss[-1]})
        wandb.log({"mean_bleu_sup_train_epoch": mean_bleu_train_epoch[-1]})

    return model

In [12]:
# Authenticate with Weights & Biases. If no key is stored, this prompts for an API key
# (see the output below). Never paste the key into the notebook source.
wandb.login()

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [13]:
# Start the W&B run that every later wandb.log(...) call (in train()) writes to.
wandb.init(project="ClipCap_NAS", name="ruclip-prefixmlp-train")

[34m[1mwandb[0m: Currently logged in as: [33mrbeketov[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [14]:
class Args:
    """Hyper-parameters and file locations for this training run.

    NOTE(review): ``backbone``, ``valid_data`` and ``only_prefix`` are not read
    by any visible cell. The model and tokenizer use ``gpt_model_name``
    directly, and the model cell always builds ``ClipCaptionPrefix``.
    """

    def __init__(self):
        settings = (
            ("backbone", gpt_model_name),
            ("train_data", "/kaggle/input/coco2014-ru-clip-embeddings/embeddings_ru_clip_train.pkl"),
            ("valid_data", "/kaggle/input/coco2014-ru-clip-embeddings/embeddings_ru_clip_valid.pkl"),
            ("out_dir", 'checkpoints'),
            ("prefix", 'only_prefix'),
            ("epochs", 1),
            ("save_every", 1),
            ("prefix_length", 30),
            ("bs", 3),
            ("only_prefix", False),
            ("lr", 2e-5),
            ("warmup_steps", 5000),
        )
        for attr, value in settings:
            setattr(self, attr, value)
args = Args()

In [15]:
# Training split: ClipCocoDataset loads the embeddings pickle and tokenizes all captions.
train_dataset = ClipCocoDataset(
    args.train_data,
    prefix_length=args.prefix_length,
    train=True,
)

vocab.json:   0%|          | 0.00/1.61M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/1.27M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/574 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.25k [00:00<?, ?B/s]

Data size is 414113


100%|██████████| 414113/414113 [02:12<00:00, 3118.45it/s]


In [16]:
# Record hyper-parameters in the active W&B run. Assigning a plain dict to
# `wandb.config` only rebinds the module attribute; update() writes into the run config.
wandb.config.update({
    "learning_rate": args.lr,
    "epochs": args.epochs,
    "batch_size": args.bs,
})

# Prefix-only model: only the CLIP -> GPT prefix projection is exposed for training.
model = ClipCaptionPrefix(args.prefix_length)
model = model.to(device)

config.json:   0%|          | 0.00/761 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

In [17]:
# Reshuffle every epoch. drop_last keeps every batch at exactly args.bs samples.
train_dataloader = DataLoader(train_dataset, batch_size=args.bs, shuffle=True, drop_last=True)

In [18]:
# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and
# weight_decay are set to transformers.AdamW's defaults (1e-6 and 0.0); torch's
# own defaults are 1e-8 and 1e-2. model.parameters() yields only the trainable
# prefix projection for ClipCaptionPrefix.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=args.lr,
    eps=1e-6,
    weight_decay=0.0,
)

# Linear warm-up for args.warmup_steps, then linear decay to 0 over the whole run.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=args.warmup_steps,
    num_training_steps=args.epochs * len(train_dataloader)
)



In [None]:
# Run training; train() returns the same model object once all epochs finish.
model = train(
    train_dataset=train_dataset,
    train_dataloader=train_dataloader,
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    args=args,
    warmup_steps=args.warmup_steps,
    output_dir=args.out_dir,
    output_prefix=args.prefix,
)

>>> Training epoch 1


Corpus/Sentence contains 0 counts of 2-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().
only_prefix:   2%|▏         | 2529/138037 [10:50<9:38:24,  3.90it/s, supervised_loss_train=1.67] 