In [1]:
from timm import create_model
import numpy as np
import pandas as pd
import os
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision import transforms
import transformers
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer,\
        get_linear_schedule_with_warmup

import cv2

from PIL import Image
from tqdm.auto import tqdm

In [None]:
class Tokenizer:
    """Wrapper around the DeepPavlov tiny Russian DistilBERT tokenizer.

    ``tokenize`` packs its output into one tensor whose first axis has size 2:
    index 0 holds the input ids and index 1 the attention mask, each of shape
    ``(len(texts), max_len)`` because of ``padding='max_length'`` + truncation.
    """

    def __init__(self):
        pretrained_name = "DeepPavlov/distilrubert-tiny-cased-conversational-v1"
        self.tokenizer = DistilBertTokenizer.from_pretrained(pretrained_name)

    def tokenize(self, texts, max_len=77):
        """Tokenize a list of strings to fixed length ``max_len``.

        Returns:
            ``torch.stack([input_ids, attention_mask])``.
        """
        encoding_options = dict(
            truncation=True,
            add_special_tokens=True,
            max_length=max_len,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )
        encoded = self.tokenizer.batch_encode_plus(texts, **encoding_options)
        ids_and_mask = (encoded["input_ids"], encoded["attention_mask"])
        return torch.stack(ids_and_mask)


In [None]:
def get_transform():
    """Build the image preprocessing pipeline.

    The steps are applied in this order:
    1. ``Resize(224)``.
    2. 224x224 center crop.
    3. Force RGB mode.
    4. Convert to a tensor.
    5. Normalize with the standard ImageNet mean/std values.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        _convert_image_to_rgb,
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return transforms.Compose(steps)


def _convert_image_to_rgb(image):
    """Return ``image`` converted to RGB mode via its ``convert`` method."""
    rgb_image = image.convert("RGB")
    return rgb_image

In [None]:
class RuCLIPTinyDataset(Dataset):
    """Image-text pair dataset backed by a CSV table.

    Args:
        dir: directory that the ``image_name`` entries are relative to.
        df_path: path (or buffer) passed to ``pd.read_csv``; the table must
            contain ``image_name`` and ``text`` columns.
        max_text_len: length each text is padded/truncated to.

    Each item is ``(image, input_ids, attention_mask)``. ``image`` is the output
    of ``get_transform()``; the two token tensors are the single-text slices of
    ``Tokenizer.tokenize``.
    """

    def __init__(self, dir, df_path, max_text_len=77):
        self.df = pd.read_csv(df_path)
        self.dir = dir
        self.max_text_len = max_text_len
        self.tokenizer = Tokenizer()
        self.transform = get_transform()

    def __getitem__(self, idx):
        # Fetch the image file name and its caption for this row.
        image_name = self.df['image_name'].iloc[idx]
        text = self.df['text'].iloc[idx]

        path = os.path.join(self.dir, image_name)
        # cv2.imread reports a missing/unreadable file by returning None rather
        # than raising; fail with a clear message instead of an opaque
        # cv2.cvtColor assertion further down.
        image = cv2.imread(path)
        if image is None:
            raise FileNotFoundError(f"could not read image {path!r}")
        # OpenCV decodes to BGR; convert to RGB before handing to PIL.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transform(Image.fromarray(image))

        tokens = self.tokenizer.tokenize([text], max_len=self.max_text_len)
        # tokens stacks [input_ids, attention_mask] for a batch of one text;
        # index [k][0] drops that batch dimension.
        input_ids, attention_mask = tokens[0][0], tokens[1][0]
        return image, input_ids, attention_mask

    def __len__(self):
        return len(self.df)

In [None]:
class RuCLIPtiny(nn.Module):
    """Tiny CLIP model: ConvNeXt-Tiny image tower + 3-layer DistilBERT text tower.

    Submodules are registered as ``visual`` -> ``transformer`` -> ``final_ln``.
    The attribute names form the state-dict keys. ``Trainer`` freezes
    encoders by child index, so neither names nor order should change.
    """

    def __init__(self):
        super().__init__()
        # Image encoder with the classification head removed (num_classes=0),
        # so it returns pooled features (768-d for convnext_tiny).
        self.visual = create_model('convnext_tiny', pretrained=False,
                                   num_classes=0, in_chans=3)
        # NOTE(review): vocab_size presumably has to match the DeepPavlov
        # tokenizer used by ``Tokenizer`` -- verify against its vocab.
        bert_params = dict(
            vocab_size=30522,
            max_position_embeddings=512,
            n_layers=3,
            n_heads=12,
            dim=264,
            hidden_dim=792,
            model_type="distilbert",
        )
        self.transformer = DistilBertModel(DistilBertConfig(**bert_params))
        # Projects the 264-d text embedding into the 768-d image feature space.
        self.final_ln = torch.nn.Linear(264, 768)
        # Learnable temperature kept in log space, initialised to log(1 / 0.07).
        self.logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    @property
    def dtype(self):
        """Parameter dtype of the image tower, read from its first stem layer."""
        first_stem_layer = self.visual.stem[0]
        return first_stem_layer.weight.dtype

    def encode_image(self, image):
        """Return un-normalised image embeddings (input cast to ``self.dtype``)."""
        return self.visual(image.type(self.dtype))

    def encode_text(self, input_ids, attention_mask):
        """Return un-normalised text embeddings.

        Takes the hidden state at token position 0 (presumably the [CLS]
        token added by the tokenizer) and projects it with ``final_ln``.
        """
        hidden = self.transformer(input_ids=input_ids,
                                  attention_mask=attention_mask).last_hidden_state
        first_token = hidden[:, 0, :]
        return self.final_ln(first_token)

    def forward(self, image, input_ids, attention_mask):
        """Return ``(logits_per_image, logits_per_text)``.

        Both are scaled cosine similarities; the second is the transpose of the
        first.
        """
        img_emb = self.encode_image(image)
        txt_emb = self.encode_text(input_ids, attention_mask)

        # L2-normalise so the dot product below is a cosine similarity.
        img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)
        txt_emb = txt_emb / txt_emb.norm(dim=-1, keepdim=True)

        scale = self.logit_scale.exp()
        logits_per_image = scale * img_emb @ txt_emb.t()
        return logits_per_image, logits_per_image.t()

In [None]:
class Trainer:
    """CLIP-style contrastive fine-tuning loop for ``RuCLIPtiny``.

    Gradient accumulation happens at the *feature* level:
    - The L2-normalised image and text embeddings of ``grad_accum``
      consecutive mini-batches are kept, together with their autograd graphs.
    - A single contrastive loss is computed over their concatenation.

    The effective contrastive batch is therefore
    ``train_batch_size * grad_accum`` pairs. The cost is holding all those
    graphs in memory at once.

    Args:
        train_dataframe: path (or buffer) of the training CSV. It is handed to
            ``RuCLIPTinyDataset``, which reads it with ``pd.read_csv`` and
            expects ``image_name`` and ``text`` columns.
        train_dir: directory that the training image names are relative to.
        val_dataframe, val_dir: the same for validation. Validation runs only
            when both are given.
        learning_rate, weight_decay: AdamW hyper-parameters.
        freeze_image_encoder: freeze the model's first child module
            (``visual`` for ``RuCLIPtiny``).
        freeze_text_encoder: freeze the model's second child module
            (``transformer`` for ``RuCLIPtiny``).
        max_text_len: length texts are padded/truncated to.
        train_batch_size, val_batch_size, num_workers: DataLoader settings.
        grad_accum: number of mini-batches per optimizer step.
    """

    def __init__(self, train_dataframe, train_dir,
                 val_dataframe=None, val_dir=None, learning_rate=1e-4,
                 freeze_image_encoder=True, freeze_text_encoder=False, max_text_len=77,
                 train_batch_size=64, val_batch_size=64, num_workers=2,
                 weight_decay=1e-4, grad_accum=8):
        self.train_dataframe = train_dataframe
        self.train_dir = train_dir
        self.val_dataframe = val_dataframe
        self.val_dir = val_dir
        self.learning_rate = learning_rate
        self.freeze_image_encoder = freeze_image_encoder
        self.freeze_text_encoder = freeze_text_encoder
        self.max_text_len = max_text_len
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.num_workers = num_workers
        self.weight_decay = weight_decay
        self.grad_accum = grad_accum
        print(f"train batch size = {self.train_batch_size * self.grad_accum}")

    def train_model(self, model, epochs_num=1, device='cuda', verbose=10):
        """Train ``model`` for ``epochs_num`` epochs, validating after each if configured.

        Args:
            model: a ``RuCLIPtiny``-like module exposing ``encode_image``,
                ``encode_text`` and ``logit_scale``.
            epochs_num: number of passes over the training set.
            device: device for the model and every batch.
            verbose: print the training loss every ``verbose`` optimizer steps;
                0 disables this logging.

        Returns:
            The trained model (moved to ``device``).
        """
        is_val = self.val_dataframe is not None and self.val_dir is not None

        model.to(device)

        train_loader = self._make_loader(self.train_dir, self.train_dataframe,
                                         self.train_batch_size, shuffle=True)
        val_loader = None
        if is_val:
            val_loader = self._make_loader(self.val_dir, self.val_dataframe,
                                           self.val_batch_size, shuffle=False)

        self._freeze_encoders(model)

        loss_fn = torch.nn.CrossEntropyLoss()
        # Frozen params never get gradients, so only hand trainable ones to AdamW.
        trainable = [p for p in model.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(trainable, lr=self.learning_rate, betas=(0.9, 0.98), eps=1e-8,
                                      weight_decay=self.weight_decay)
        # The scheduler is stepped once per *optimizer* step (every grad_accum
        # batches), so size it by optimizer steps. Sizing it by mini-batches
        # would leave the LR at ~(1 - 1/grad_accum) of its start value.
        total_steps = self._optimizer_steps_per_epoch(len(train_loader)) * epochs_num
        scheduler = get_linear_schedule_with_warmup(optimizer,
                                                    num_warmup_steps=0,
                                                    num_training_steps=total_steps)
        for epoch in range(epochs_num):
            self._train_epoch(model, train_loader, optimizer, scheduler, loss_fn,
                              epoch, device, verbose)
            if val_loader is not None:
                self._validate(model, val_loader, loss_fn, epoch, device)
        return model

    def _optimizer_steps_per_epoch(self, num_batches):
        """Optimizer steps in one epoch: one per ``grad_accum`` batches, plus one for a trailing partial group."""
        return (num_batches + self.grad_accum - 1) // self.grad_accum

    def _freeze_encoders(self, model):
        """Set ``requires_grad = False`` on child 0 and/or child 1 as configured."""
        # Relies on the model registering its image encoder first and text encoder second.
        for i, child in enumerate(model.children()):
            if (i == 0 and self.freeze_image_encoder) or (i == 1 and self.freeze_text_encoder):
                for param in child.parameters():
                    param.requires_grad = False

    def _make_loader(self, data_dir, df_path, batch_size, shuffle):
        """Build a DataLoader over a ``RuCLIPTinyDataset``."""
        dataset = RuCLIPTinyDataset(data_dir, df_path, self.max_text_len)
        return torch.utils.data.DataLoader(dataset=dataset,
                                           batch_size=batch_size,
                                           shuffle=shuffle,
                                           pin_memory=True,
                                           num_workers=self.num_workers)

    @staticmethod
    def _contrastive_loss(logits_per_image, logits_per_text, loss_fn):
        """Symmetric CLIP loss: mean of image->text and text->image cross-entropy; pair i matches pair i."""
        ground_truth = torch.arange(logits_per_image.shape[0], dtype=torch.long,
                                    device=logits_per_image.device)
        return (loss_fn(logits_per_image, ground_truth) + loss_fn(logits_per_text, ground_truth)) / 2

    def _train_epoch(self, model, train_loader, optimizer, scheduler, loss_fn, epoch, device, verbose):
        """Run one training epoch with feature-level gradient accumulation."""
        model.train()
        print(f'start training epoch {epoch}')
        num_batches = len(train_loader)
        image_chunks, text_chunks = [], []
        curr_batch = 0  # optimizer steps taken so far in this epoch
        for i, batch in enumerate(tqdm(train_loader)):
            images = batch[0].to(device)
            input_ids = batch[1].to(device)
            attention_mask = batch[2].to(device)

            image_features = model.encode_image(images)
            text_features = model.encode_text(input_ids, attention_mask)
            image_chunks.append(image_features / image_features.norm(dim=-1, keepdim=True))
            text_chunks.append(text_features / text_features.norm(dim=-1, keepdim=True))

            if (i + 1) % self.grad_accum == 0 or i + 1 == num_batches:
                optimizer.zero_grad()
                image_batch = torch.cat(image_chunks, dim=0)
                text_batch = torch.cat(text_chunks, dim=0)
                logits_per_image = model.logit_scale.exp() * image_batch @ text_batch.t()
                total_loss = self._contrastive_loss(logits_per_image, logits_per_image.t(), loss_fn)
                if verbose and curr_batch % verbose == 0:
                    print(f'{i}/{num_batches} total_loss {total_loss.item()}')
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
                optimizer.step()
                scheduler.step()

                image_chunks, text_chunks = [], []
                curr_batch += 1

    def _validate(self, model, val_loader, loss_fn, epoch, device):
        """Print the mean per-batch contrastive loss over ``val_loader``."""
        print(f'start val epoch {epoch}')
        total_loss = 0
        model.eval()
        with torch.no_grad():
            for batch in tqdm(val_loader):
                images = batch[0].to(device)
                input_ids = batch[1].to(device)
                attention_mask = batch[2].to(device)
                logits_per_image, logits_per_text = model(images, input_ids, attention_mask)
                total_loss += self._contrastive_loss(logits_per_image, logits_per_text, loss_fn).item()
        print(f'val loss = {total_loss / max(len(val_loader), 1)}')


In [None]:
class Predictor:
    """Zero-shot classifier: scores images against class texts with a ``RuCLIPtiny``-like model."""

    def __init__(self):
        self.tokenizer = Tokenizer()
        self.transform = get_transform()

    def prepare_images_features(self, model, images_path, device='cpu'):
        """Encode images one at a time and return L2-normalised features on CPU.

        Args:
            model: object exposing ``encode_image``.
            images_path: iterable of image file paths.
            device: device the model lives on.

        Returns:
            Tensor with one row per image.
        """
        images_features = []
        for image_path in images_path:
            # The context manager closes the file handle PIL keeps open;
            # the transform fully materialises the pixels before we leave it.
            with Image.open(image_path) as image:
                pixels = self.transform(image)
            with torch.no_grad():
                features = model.encode_image(pixels.unsqueeze(0).to(device))
            images_features.append(features.float().cpu()[0])
        images_features = torch.stack(images_features, dim=0)
        return images_features / images_features.norm(dim=-1, keepdim=True)

    def prepare_text_features(self, model, texts, max_len=77, device='cpu'):
        """Encode all texts in one batch and return L2-normalised features on CPU.

        Tokenization pads every text to ``max_len``, so one batched forward
        yields the same per-text features as encoding texts individually.

        Returns:
            Tensor with one row per text.
        """
        tokens = self.tokenizer.tokenize(list(texts), max_len)
        with torch.no_grad():
            texts_features = model.encode_text(tokens[0].to(device), tokens[1].to(device)).float().cpu()
        return texts_features / texts_features.norm(dim=-1, keepdim=True)

    def __call__(self, model, images_path, classes, get_probs=False, max_len=77, device='cpu'):
        """Classify each image into one of ``classes`` (a list of class texts).

        Returns:
            With ``get_probs=True``: an ``(n_images, n_classes)`` softmax over
            similarities. Otherwise: the argmax class index per image.
        """
        model.eval().to(device)
        image_features = self.prepare_images_features(model, images_path, device)
        texts_features = self.prepare_text_features(model, classes, max_len, device)
        # Both sides are L2-normalised, so these are raw cosine similarities.
        # They are softmaxed with a fixed scale of 1 (not model.logit_scale), so
        # the probabilities are rather flat; the argmax does not depend on the scale.
        text_probs = (image_features @ texts_features.T).softmax(dim=-1)
        if get_probs:
            return text_probs
        return text_probs.argmax(-1)