In [None]:
!pip install wandb
#!pip install bitsandbytes
#!pip install ruclip==0.0.2
!pip install transformers==4.27.4
#!pip install pycocotools
#!pip install git+https://github.com/openai/CLIP.git
#!pip install open_clip_torch

In [None]:
import torch
import torch.nn as nn
import os
import pickle
import sys
import argparse
import json

import random
import io

import wandb
import nltk
import numpy as np

from nltk.translate.bleu_score import corpus_bleu

from sklearn.model_selection import train_test_split

from datasets import load_dataset, load_metric

from torch.nn import functional as nnf
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast
from torch.utils.data import Subset


from transformers import GPT2Config, GPT2Model
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import AdamW, get_linear_schedule_with_warmup
from transformers.optimization import Adafactor, AdafactorSchedule

from typing import Tuple, Optional, Union
from tqdm import tqdm, trange
from enum import Enum

In [None]:
# Global seeding so a fresh-kernel "Run All" is reproducible.
manualSeed = 1337
# manualSeed = random.randint(1, 10000)  # use if you want new results
random.seed(manualSeed)
np.random.seed(manualSeed)  # numpy keeps its own global RNG; seed it too
torch.manual_seed(manualSeed)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
class MlpTransformer(nn.Module):
    """Two-layer feed-forward sub-block: ``dropout(fc2(dropout(act(fc1(x)))))``.

    Args:
        in_dim: size of the input's last dimension.
        h_dim: hidden width produced by ``fc1``.
        out_d: size of the output's last dimension; ``None`` means ``in_dim``.
        act: activation callable applied after ``fc1``.
        dropout: drop probability used for both dropout applications.
    """

    def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.):
        super().__init__()
        target_dim = in_dim if out_d is None else out_d
        # Attribute names are kept stable so previously saved state_dicts still load.
        self.fc1 = nn.Linear(in_dim, h_dim)
        self.act = act
        self.fc2 = nn.Linear(h_dim, target_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(hidden))

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention from ``x`` (queries) to ``y`` (keys/values).

    When ``y`` is omitted this is self-attention over ``x``.

    Args:
        dim_self: last-dim size of ``x``; also the width of queries, keys and values.
        dim_ref: last-dim size of ``y``.
        num_heads: number of heads; ``dim_self`` is split evenly between them.
        bias: whether the query and key/value projections carry a bias.
        dropout: drop probability applied to the attention weights.
    """

    def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim_self // num_heads
        self.scale = head_dim ** -0.5
        self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
        self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
        self.project = nn.Linear(dim_self, dim_self)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y=None, mask=None):
        """Attend from ``x`` to ``y``.

        Args:
            x: tensor of shape (b, n, c), c == dim_self.
            y: tensor of shape (b, m, dim_ref); defaults to ``x``.
            mask: boolean tensor, True where attention is blocked. A 2-D mask is read
                as (b, m); other ranks must broadcast against (b, n, m) -- TODO confirm.

        Returns:
            ``(out, attention)``: ``out`` is (b, n, c) and ``attention`` is (b, n, m, h),
            the softmax weights before dropout.
        """
        y = y if y is not None else x
        b, n, c = x.shape
        m = y.shape[1]
        head_dim = c // self.num_heads
        # b n h dh
        queries = self.to_queries(x).reshape(b, n, self.num_heads, head_dim)
        # b m 2 h dh
        keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, head_dim)
        keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
        attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(1)
            # the trailing axis broadcasts the mask across heads
            attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
        attention = attention.softmax(dim=2)
        # self.dropout used to be constructed but never applied, so `dropout` had no effect.
        out = torch.einsum('bnmh,bmhd->bnhd', self.dropout(attention), values).reshape(b, n, c)
        out = self.project(out)
        return out, attention

In [None]:
class TransformerLayer(nn.Module):
    """Pre-norm block: residual multi-head attention, then a residual feed-forward MLP.

    ``y`` (optional) is the tensor attended to and ``mask`` is passed straight to
    the attention sub-module.
    """

    def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu,
                 norm_layer: nn.Module = nn.LayerNorm):
        super().__init__()
        hidden_dim = int(dim_self * mlp_ratio)
        # Submodule names are kept stable so saved state_dicts still load.
        self.norm1 = norm_layer(dim_self)
        self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
        self.norm2 = norm_layer(dim_self)
        self.mlp = MlpTransformer(dim_self, hidden_dim, act=act, dropout=dropout)

    def forward_with_attention(self, x, y=None, mask=None):
        """Same as ``forward`` but also returns the attention weights."""
        attended, attention = self.attn(self.norm1(x), y, mask)
        residual = x + attended
        return residual + self.mlp(self.norm2(residual)), attention

    def forward(self, x, y=None, mask=None):
        """Return ``x`` after the attention and MLP residual updates."""
        attended, _ = self.attn(self.norm1(x), y, mask)
        residual = x + attended
        return residual + self.mlp(self.norm2(residual))


class Transformer(nn.Module):
    """Stack of ``TransformerLayer`` blocks.

    With ``enc_dec=False`` every layer receives ``(x, y, mask)``; this is
    self-attention when ``y`` is None.
    With ``enc_dec=True`` the stack holds ``2 * num_layers`` layers that alternate.
    Even indices cross-attend to ``y`` without a mask. Odd indices self-attend over
    the running ``x`` with ``mask``.
    """

    def _layer_inputs(self, index, x, y, mask):
        """Return the ``(ref, mask)`` pair that layer ``index`` should receive."""
        if not self.enc_dec:
            return y, mask  # self or cross
        if index % 2 == 0:
            return y, None  # cross
        return x, mask  # self

    def forward_with_attention(self, x, y=None, mask=None):
        """Run the stack and collect each layer's attention weights.

        Uses the same per-layer routing as ``forward``. This method previously fed
        ``(x, y, mask)`` to every layer, which disagreed with ``forward`` when
        ``enc_dec=True``.
        """
        attentions = []
        for i, layer in enumerate(self.layers):
            ref, layer_mask = self._layer_inputs(i, x, y, mask)
            x, att = layer.forward_with_attention(x, ref, layer_mask)
            attentions.append(att)
        return x, attentions

    def forward(self, x, y=None, mask=None):
        for i, layer in enumerate(self.layers):
            ref, layer_mask = self._layer_inputs(i, x, y, mask)
            x = layer(x, ref, layer_mask)
        return x

    def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
                 mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
        super(Transformer, self).__init__()
        dim_ref = dim_ref if dim_ref is not None else dim_self
        self.enc_dec = enc_dec
        if enc_dec:
            num_layers = num_layers * 2
        layers = []
        for i in range(num_layers):
            # In enc_dec mode odd layers self-attend (ref dim = dim_self); all others use dim_ref.
            ref_dim = dim_self if (enc_dec and i % 2 == 1) else dim_ref
            layers.append(TransformerLayer(dim_self, ref_dim, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
        self.layers = nn.ModuleList(layers)

In [None]:
class TransformerMapper(nn.Module):
    """Map a CLIP embedding to ``prefix_length`` vectors of size ``dim_embedding``.

    The input is linearly expanded into ``clip_length`` tokens and a learned
    ``prefix_const`` of ``prefix_length`` tokens is appended. Both pass through an
    8-head ``Transformer``, and only the outputs at the learned-constant positions
    are returned.
    """

    def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
        super(TransformerMapper, self).__init__()
        self.clip_length = clip_length
        self.transformer = Transformer(dim_embedding, 8, num_layers)
        self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
        self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)

    def forward(self, x):
        # presumably x is (batch, dim_clip) -- the view below uses x.shape[0] as batch
        batch = x.shape[0]
        clip_tokens = self.linear(x).view(batch, self.clip_length, -1)
        learned = self.prefix_const.unsqueeze(0).expand(batch, *self.prefix_const.shape)
        mixed = self.transformer(torch.cat((clip_tokens, learned), dim=1))
        return mixed[:, self.clip_length:]

In [None]:

def freeze(
    model,
    freeze_emb=False,
    freeze_ln=True,
    freeze_attn=True,
    freeze_ff=True,
    freeze_other=False,
):
    """Set ``requires_grad`` on every parameter of ``model`` by name category.

    Categories are tested in this order against the lower-cased parameter name,
    and the first match wins:

    1. contains ``'ln'`` or ``'norm'`` -> frozen iff ``freeze_ln``
    2. contains ``'embeddings'``       -> frozen iff ``freeze_emb``
    3. contains ``'mlp'``              -> frozen iff ``freeze_ff``
    4. contains ``'attn'``             -> frozen iff ``freeze_attn``
    5. anything else                   -> frozen iff ``freeze_other``

    With the defaults, norm, MLP and attention parameters are frozen. Parameters
    named ``embeddings`` and uncategorised ones stay trainable. The model is
    modified in place and also returned.
    """
    rules = (
        (("ln", "norm"), freeze_ln),
        (("embeddings",), freeze_emb),
        (("mlp",), freeze_ff),
        (("attn",), freeze_attn),
    )
    for param_name, param in model.named_parameters():
        lowered = param_name.lower()
        frozen = next(
            (flag for keys, flag in rules if any(key in lowered for key in keys)),
            freeze_other,
        )
        param.requires_grad = not frozen
    return model

In [None]:
from enum import Enum
class MappingType(Enum):
    """How ``ClipCaptionModel`` projects the CLIP embedding into the GPT prefix."""

    MLP = 'mlp'                  # MLP projection
    Transformer = 'transformer'  # TransformerMapper projection

In [None]:
gpt_model_name = 'sberbank-ai/rugpt3medium_based_on_gpt2'


class ClipCaptionModel(nn.Module):
    """GPT-2 captioner conditioned on a CLIP embedding projected into a token prefix.

    ``clip_project`` maps the CLIP vector to ``prefix_length`` vectors of the GPT
    embedding size. These are prepended to the caption token embeddings.

    Args:
        prefix_length: number of prefix positions produced from the CLIP vector.
        clip_length: number of intermediate CLIP tokens used by ``TransformerMapper``.
        prefix_size: size of the incoming CLIP embedding.
        num_layers: depth of the ``TransformerMapper``.
        mapping_type: ``MappingType.MLP`` or ``MappingType.Transformer``.
    """

    def __init__(
        self,
        prefix_length: int,
        clip_length: Optional[int] = 10,
        prefix_size: int = 512,
        num_layers: int = 8,
        mapping_type: MappingType = MappingType.Transformer
    ):
        super(ClipCaptionModel, self).__init__()
        self.prefix_length = prefix_length

        self.gpt = GPT2LMHeadModel.from_pretrained(gpt_model_name)
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]

        if mapping_type == MappingType.MLP:
            # NOTE(review): no `MLP` class is defined in this notebook, so this branch
            # raises NameError unless MLP is provided elsewhere.
            self.clip_project = MLP((
                prefix_size,
                self.gpt_embedding_size * prefix_length // 2,
                self.gpt_embedding_size * prefix_length
            ))
        else:
            self.clip_project = TransformerMapper(
                prefix_size,
                self.gpt_embedding_size,
                prefix_length,
                clip_length,
                num_layers
            )

    def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Return an all-zero int64 tensor of shape (batch_size, prefix_length)."""
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    @autocast()
    def forward(
        self,
        tokens: torch.Tensor,
        prefix: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None
    ):
        """Run GPT-2 on ``[projected prefix; caption embeddings]``.

        Args:
            tokens: caption token ids, presumably (batch, seq_len) -- TODO confirm.
            prefix: CLIP embeddings; cast to float32 before projection.
            mask: attention mask over prefix + caption positions (1 = attend).
            labels: only its presence matters. When given, loss targets are built
                from ``tokens``: prefix positions and positions where ``mask == 0``
                are set to -100 so the loss ignores them.

        Returns:
            The ``GPT2LMHeadModel`` output with hidden states. It includes ``loss``
            when ``labels`` was given.
        """
        embedding_text = self.gpt.transformer.wte(tokens)
        prefix_projections = self.clip_project(
            prefix.float()
        ).view(-1, self.prefix_length, self.gpt_embedding_size)

        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)

        if labels is not None:
            # Prefix slots have no target. -100 is the ignore index of the HF LM loss;
            # the previous zero-filled dummy trained the model to predict token 0 there.
            ignored_prefix = torch.full(
                (tokens.shape[0], self.prefix_length), -100, dtype=torch.int64, device=tokens.device
            )
            labels = torch.cat((ignored_prefix, tokens), dim=1)
            if mask is not None and mask.shape == labels.shape:
                # padded caption positions (mask == 0) are not targets either
                labels = labels.masked_fill(mask == 0, -100)
        out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask, output_hidden_states=True)

        return out

class ClipCaptionPrefix(ClipCaptionModel):
    """``ClipCaptionModel`` variant that exposes only ``clip_project`` for training.

    ``parameters()`` yields just the projection's parameters. ``train()`` always
    puts the GPT backbone back into eval mode.
    """

    def parameters(self, recurse: bool = True):
        # `recurse` is accepted for signature compatibility but not used.
        return self.clip_project.parameters()

    def train(self, mode: bool = True):
        super().train(mode)
        self.gpt.eval()  # backbone stays in eval mode regardless of `mode`
        return self

In [None]:
class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that loads pickled torch byte payloads onto the CPU.

    Lookups of ``torch.storage._load_from_bytes`` are redirected to a
    ``torch.load(..., map_location='cpu')`` call, presumably so files pickled on a
    GPU machine open without CUDA. Every other global resolves normally.

    NOTE(review): unpickling can execute arbitrary code -- only use on trusted files.
    """

    def find_class(self, module, name):
        is_storage_hook = (module, name) == ('torch.storage', '_load_from_bytes')
        if not is_storage_hook:
            return super().find_class(module, name)

        def _load_on_cpu(b):
            return torch.load(io.BytesIO(b), map_location='cpu')

        return _load_on_cpu

In [None]:
class ClipCocoDataset(Dataset):
    """Caption dataset pairing tokenized captions with precomputed CLIP embeddings.

    The pickle at ``data_path`` must contain the keys ``"clip_embedding"``,
    ``"captions"`` and ``"path_images"``. Index ``i`` of each describes one sample.

    Each item is ``(tokens, mask, prefix)``:
      * ``tokens``: caption ids, right-padded with 0 or truncated to ``self.max_seq_len``.
      * ``mask``: float; ``prefix_length`` leading ones for the prefix slots, then
        1 for real tokens and 0 for padding.
      * ``prefix``: the stored CLIP embedding, L2-normalised if ``normalize_prefix``.
    """

    def __init__(
        self,
        data_path: str,
        prefix_length=30,
        model_type = gpt_model_name,
        normalize_prefix=False,
        train=True,
    ):
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_type)
        self.prefix_length = prefix_length
        self.normalize_prefix = normalize_prefix
        # get_image() reads this flag; it was never stored before, so that method
        # raised AttributeError.
        self.train = train

        # Train and validation pickles were loaded identically, so one path serves both.
        # NOTE(review): unpickling can execute arbitrary code; only load trusted files.
        with open(data_path, 'rb') as f:
            all_data = CPU_Unpickler(f).load()
        print("Data size is %0d" % len(all_data["clip_embedding"]))
        sys.stdout.flush()

        self.prefixes = all_data["clip_embedding"]
        self.captions = all_data["captions"]
        self.image_id = all_data["path_images"]

        self.captions_tokens = []
        self.caption2embedding = []
        for i, caption in enumerate(tqdm(self.captions)):
            self.captions_tokens.append(
                torch.tensor(self.tokenizer.encode(caption), dtype=torch.int64)
            )
            self.caption2embedding.append(self.prefixes[i])

        self.max_seq_len = self._compute_max_seq_len()

    def _compute_max_seq_len(self) -> int:
        """Padding length: mean + 10 * std of caption lengths, capped at the longest caption."""
        if not self.captions_tokens:
            return 0
        all_len = torch.tensor([len(t) for t in self.captions_tokens]).float()
        if len(all_len) < 2:
            # std() of a single value is NaN and int(NaN) raises ValueError
            return int(all_len.max())
        return min(int(all_len.mean() + all_len.std() * 10), int(all_len.max()))

    def get_image(self, item):
        """Load the image for ``item`` as a PIL image scaled down to fit 196x196.

        NOTE(review): ``cv2`` and ``PIL.Image`` are used here but not imported in this
        notebook; import them before calling. Image roots are hard-coded Kaggle paths.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        if self.train:
            path_img = f"/kaggle/input/train2014/train2014/{self.image_id[item]}"
        else:
            path_img = f"/kaggle/input/val2014/val2014/{self.image_id[item]}"

        image = cv2.imread(path_img)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising
            raise FileNotFoundError(path_img)
        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        image.thumbnail((196, 196), Image.Resampling.LANCZOS)
        return image

    def pad_tokens(self, item: int):
        """Return ``(tokens, mask)`` for ``item`` at length ``self.max_seq_len``.

        Short captions are right-padded with 0 and long ones truncated. ``mask`` is a
        float tensor with ``prefix_length`` leading ones for the prefix slots.
        """
        tokens = self.captions_tokens[item]
        padding = self.max_seq_len - tokens.shape[0]
        if padding > 0:
            # -1 marks padding so it can be told apart from real ids below
            tokens = torch.cat((tokens, torch.full((padding,), -1, dtype=torch.int64)))
        elif padding < 0:
            tokens = tokens[:self.max_seq_len]
        mask = tokens.ge(0)  # False on padding positions
        # torch.where returns a new tensor, so the cached caption tokens are never written to
        tokens = torch.where(mask, tokens, torch.zeros_like(tokens))
        mask = torch.cat((torch.ones(self.prefix_length), mask.float()), dim=0)  # adding prefix mask
        return tokens, mask

    def __len__(self) -> int:
        return len(self.captions_tokens)

    def __getitem__(self, item):
        tokens, mask = self.pad_tokens(item)
        prefix = self.prefixes[item]
        if self.normalize_prefix:
            prefix = prefix.float()
            # keepdim lets this broadcast correctly for batched (2-D) embeddings as well
            prefix = prefix / prefix.norm(2, -1, keepdim=True)
        return tokens, mask, prefix

In [None]:
def calc_bleu(y_pred, y_true):
    """Corpus BLEU-4 (uniform 1- to 4-gram weights) on a 0-100 scale.

    Texts are tokenised on whitespace. Each prediction is scored against the single
    reference at the same index of ``y_true``.
    """
    uniform_4gram = (0.25, 0.25, 0.25, 0.25)
    references = [[ref.split()] for ref in y_true]
    hypotheses = [pred.split() for pred in y_pred]
    return 100 * corpus_bleu(references, hypotheses, weights=uniform_4gram)

In [None]:
def _strip_generated(text: str, offset: int) -> str:
    """Drop the first ``offset`` chars and cut just after the first '.' in ``text``."""
    dot = text.find('.')
    return text[offset:dot + 1] if dot != -1 else text[offset:]


def _strip_reference(text: str, offset: int) -> str:
    """Drop the first ``offset`` chars and cut before the first '<pad>' in ``text``."""
    pad = text.find('<pad>')
    return text[offset:pad] if pad != -1 else text[offset:]


def validation(
    validation_dataset: ClipCocoDataset,
    validation_dataloader,
    model: ClipCaptionModel,
    args,
    output_dir: str = ".",
    output_prefix: str = "",
    text_offset: int = 35,
):
    """Score teacher-forced argmax predictions against reference captions with BLEU-4.

    For each batch the model is run on the ground-truth tokens. The argmax over
    the caption-aligned logits is decoded and compared with the decoded reference.
    Per-batch and running-mean BLEU are logged to wandb.

    Args:
        validation_dataset: provides ``prefix_length`` and ``tokenizer``.
        validation_dataloader: yields ``(tokens, mask, prefix)`` batches.
        model: captioning model, switched to eval mode here.
        args: run configuration (kept for interface compatibility; not read).
        output_dir: created if missing.
        output_prefix: label shown on the progress bar.
        text_offset: number of leading characters dropped from every decoded text.
            NOTE(review): 35 is inherited from the original code; presumably it
            strips a fixed leading string -- confirm.

    Returns:
        ``(all_real_text, all_generate_text)``: lists of reference and predicted
        strings, one entry per validation sample.
    """
    os.makedirs(output_dir, exist_ok=True)

    model.eval()
    tokenizer = validation_dataset.tokenizer
    # With inputs [prefix; caption], logits[:, P-1:-1] has one row per caption token
    # (assuming position p predicts token p+1, as in a causal LM).
    start = validation_dataset.prefix_length - 1

    all_real_text = []
    all_generate_text = []
    bleu_validation = []
    print(">>> Validation epoch")
    sys.stdout.flush()
    progress = tqdm(total=len(validation_dataloader), desc=output_prefix)
    with torch.no_grad():
        for tokens, mask, prefix in validation_dataloader:
            tokens = tokens.to(device)
            mask = mask.to(device)
            prefix = prefix.to(device, dtype=torch.bfloat16)
            outputs = model(tokens, prefix, mask)
            predicted_ids = outputs.logits[:, start:-1].argmax(dim=-1)

            # Every sample of the batch is scored; previously only sample 0 was decoded.
            generated_texts = [
                _strip_generated(tokenizer.decode(row.tolist()), text_offset) for row in predicted_ids
            ]
            real_text = [
                _strip_reference(tokenizer.decode(row.tolist()), text_offset) for row in tokens
            ]
            all_generate_text.extend(generated_texts)
            all_real_text.extend(real_text)

            bleu = calc_bleu(generated_texts, real_text)
            progress.set_postfix({"bleu": bleu})
            wandb.log({"bleu-4_validation": bleu})

            bleu_validation.append(bleu)
            wandb.log({"mean_bleu-4_valid": np.mean(bleu_validation)})
            progress.update()
    progress.close()

    return all_real_text, all_generate_text

In [None]:
# Authenticate with Weights & Biases; validation() logs its BLEU metrics to W&B.
wandb.login()

In [None]:
# Start the W&B run that the validation metrics are logged to.
wandb.init(name="valid-ruclip-transformer", project="ClipCap_NAS")

In [None]:
class Args():
    """Run configuration: data paths, checkpoint naming and optimisation settings.

    The defaults reproduce the original notebook settings. Any attribute can be
    overridden by keyword, e.g. ``Args(bs=8)``. Unknown keywords raise ``TypeError``
    so typos are not silently ignored. Several fields (epochs, lr, warmup_steps, ...)
    are not read by this validation notebook; presumably they are kept for parity
    with training.
    """

    def __init__(self, **overrides):
        self.backbone = gpt_model_name
        self.train_data = "/kaggle/input/coco2014-ru-clip-embeddings/embeddings_ru_clip_train.pkl"
        self.valid_data = "/kaggle/input/coco2014-ru-clip-embeddings/embeddings_ru_clip_valid.pkl"
        self.out_dir = 'checkpoints'
        self.prefix = 'valid-transformer_gpt'
        self.epochs = 3
        self.save_every = 1
        self.prefix_length = 30
        self.bs = 1
        self.only_prefix = False
        self.lr = 2e-5
        self.warmup_steps = 5000

        unknown = set(overrides) - set(vars(self))
        if unknown:
            raise TypeError(f"Args got unexpected keyword(s): {sorted(unknown)}")
        for key, value in overrides.items():
            setattr(self, key, value)


args = Args()

In [None]:
# Validation split: captions + precomputed CLIP embeddings, loaded with train=False.
valid_dataset = ClipCocoDataset(args.valid_data, prefix_length=args.prefix_length, train=False)

In [None]:
# Record the run configuration on the active W&B run. The old
# `wandb.config = {...}` only rebound the module attribute to a plain dict,
# so the value never reached the run's config.
wandb.config.update({"batch_size": args.bs})

model = ClipCaptionModel(args.prefix_length)
model_path = "/kaggle/input/5epoch-clipcap-transfformers/checkpoints/transformer_gpt_latest_gpt2_medium.pt"
# Load on CPU first so the checkpoint opens regardless of the device it was saved from.
model.load_state_dict(torch.load(model_path, map_location='cpu'))

model = model.to(device)

In [None]:
# Deterministic full pass over the validation set: no shuffling, keep the last partial batch.
valid_dataloader = DataLoader(valid_dataset, batch_size=args.bs, shuffle=False, drop_last=False)

In [None]:
# `validation` returns (reference captions, predicted captions). Unpack them rather
# than rebinding `model`, which previously replaced the loaded model with that tuple.
all_real_text, all_generate_text = validation(
    valid_dataset,
    valid_dataloader,
    model,
    args,
    output_dir=args.out_dir,
    output_prefix=args.prefix,
)