In [13]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"


from timm import create_model
import numpy as np
import pandas as pd
import os
import torch
from torch import nn
from torch import optim, Tensor
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchinfo import summary
import transformers
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer,\
        get_linear_schedule_with_warmup
from transformers import AutoModel, AutoTokenizer

import cv2

from PIL import Image
from tqdm.auto import tqdm

import json
from itertools import product

import datasets
from datasets import Dataset, concatenate_datasets
import argparse
import requests

from io import BytesIO
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import balanced_accuracy_score, roc_auc_score, f1_score
import matplotlib.pyplot as plt
from IPython import display
import more_itertools

In [37]:
TABLE_DATASET_NAME = 'new_labeled.csv'
IMG_DATASET_NAME = 'images_7k'
BATCH_SIZE = 8

DATA_PATH = 'data/'
DEVICE='cuda' if torch.cuda.is_available() else 'cpu'

In [16]:
# MODEL_CKPT = 'siamese_fitted_10epochs_bert_turbo.pt'
# NAME_MODEL_NAME = 'DeepPavlov/distilrubert-tiny-cased-conversational-v1'
# DESCRIPTION_MODEL_NAME = 'sergeyzh/rubert-tiny-turbo'

MODEL_CKPT = 'siamese_fitted_10epochs_bert_tiny.pt'
NAME_MODEL_NAME = 'DeepPavlov/distilrubert-tiny-cased-conversational-v1'
DESCRIPTION_MODEL_NAME = 'cointegrated/rubert-tiny'

### RuCLIPtiny

In [28]:
class RuCLIPtiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.visual = create_model('convnext_tiny',
                                   pretrained=False, # TODO: берём претрейн
                                   num_classes=0,
                                   in_chans=3)  # out 768

        self.transformer = AutoModel.from_pretrained(NAME_MODEL_NAME)
        name_model_output_shape = self.transformer.config.hidden_size  # dynamically get hidden size
        self.final_ln = torch.nn.Linear(name_model_output_shape, 768)  # now uses the transformer hidden size
        self.logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    @property
    def dtype(self):
        return self.visual.stem[0].weight.dtype

    def encode_image(self, image):
        return self.visual(image.type(self.dtype))

    def encode_text(self, input_ids, attention_mask):
        x = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        x = x.last_hidden_state[:, 0, :]
        x = self.final_ln(x)
        return x

    def forward(self, image, input_ids, attention_mask):
        image_features = self.encode_image(image)
        text_features = self.encode_text(input_ids, attention_mask)

        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        return logits_per_image, logits_per_text

In [29]:
def get_transform():
    return transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        _convert_image_to_rgb,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]), ])

def _convert_image_to_rgb(image):
    return image.convert("RGB")

class Tokenizers:
    def __init__(self):
        self.name_tokenizer = AutoTokenizer.from_pretrained(NAME_MODEL_NAME)
        self.desc_tokenizer = AutoTokenizer.from_pretrained(DESCRIPTION_MODEL_NAME)

    def tokenize_name(self, texts, max_len=77):
        tokenized = self.name_tokenizer.batch_encode_plus(texts,
                                                     truncation=True,
                                                     add_special_tokens=True,
                                                     max_length=max_len,
                                                     padding='max_length',
                                                     return_attention_mask=True,
                                                     return_tensors='pt')
        return torch.stack([tokenized["input_ids"], tokenized["attention_mask"]])
    
    def tokenize_description(self, texts, max_len=77):
        tokenized = self.desc_tokenizer(texts,
                                        truncation=True,
                                        add_special_tokens=True,
                                        max_length=max_len,
                                        padding='max_length',
                                        return_attention_mask=True,
                                        return_tensors='pt')
        return torch.stack([tokenized["input_ids"], tokenized["attention_mask"]])

In [30]:
class SiameseRuCLIPDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        df=None, labels=None, df_path=None,
        images_dir=DATA_PATH+'images/',
    ):
        # loads data either from path using `df_path` or directly from `df` argument
        self.df = pd.read_csv(df_path) if df_path is not None else df
        self.labels = labels
        self.images_dir = images_dir
        self.tokenizers = Tokenizers()
        self.transform = get_transform()
        # 
        self.max_len = 77
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Process names and descriptions
        name_tokens = self.tokenizers.tokenize_name([str(row.name_first), 
                                               str(row.name_second)], max_len=self.max_len)
        name_first = name_tokens[:, 0, :] # [input_ids, attention_mask]
        name_second = name_tokens[:, 1, :]
        desc_tokens = self.tokenizers.tokenize_description([str(row.description_first), 
                                               str(row.description_second)])
        desc_first = desc_tokens[:, 0, :] # [input_ids, attention_mask]
        desc_second = desc_tokens[:, 1, :]

        # Process images
        im_first_path = os.path.join(self.images_dir, row.image_name_first)
        im_first = cv2.imread(im_first_path)
        if im_first is None:
            raise FileNotFoundError(f"Image not found at {im_first_path}")
        im_first = cv2.cvtColor(im_first, cv2.COLOR_BGR2RGB)
        im_first = Image.fromarray(im_first)
        im_first = self.transform(im_first)

        im_second_path = os.path.join(self.images_dir, row.image_name_second)
        im_second = cv2.imread(im_second_path)
        if im_second is None:
            raise FileNotFoundError(f"Image not found at {im_second_path}")
        im_second = cv2.cvtColor(im_second, cv2.COLOR_BGR2RGB)
        im_second = Image.fromarray(im_second)
        im_second = self.transform(im_second)
        label = self.labels[idx]
        return im_first, name_first, desc_first, im_second, name_second, desc_second, label

    def __len__(self,):
        return len(self.df)

### SiameseRuCLIP

In [None]:
def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        last_hidden = last_hidden_states.masked_fill(
            ~attention_mask[..., None].bool(), 0.0
        )
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

class SiameseRuCLIP(nn.Module):
    def __init__(self,
                 preload_ruclip=False,
                 preload_model_name='cc12m_rubert_tiny_ep_1.pt', # 'cc12m_ddp_4mill_ep_4.pt'
                 device='cpu',
                 models_dir=DATA_PATH + 'train_results/'):
        super().__init__()
        self.ruclip = RuCLIPtiny()

        if preload_ruclip:
            std = torch.load(
                models_dir + preload_model_name,
                weights_only=True,
                map_location=device
            )
            self.ruclip.load_state_dict(std)
            self.ruclip = self.ruclip.to(device)
            self.ruclip.eval()

        self.description_transformer = AutoModel.from_pretrained(DESCRIPTION_MODEL_NAME)

        # Automatically infer dimensions:
        # For the vision encoder, we assume the timm model provides 'num_features'
        vision_dim = self.ruclip.visual.num_features  # e.g. 768 for convnext_tiny
        
        # For the name branch, use the output dimension of the final linear layer.
        name_dim = self.ruclip.final_ln.out_features   # e.g. 768
        
        # For the description transformer, take the hidden size from its configuration.
        desc_dim = self.description_transformer.config.hidden_size  # e.g. 312 for cointegrated/rubert-tiny
        
        # Compute the per-product embedding as the concatenation of the three modalities.
        per_product_dim = vision_dim + name_dim + desc_dim  # e.g. 768 + 768 + 312 = 1848
        head_input_dim = 2 * per_product_dim  # for a pair of products
        
        self.hidden_dim = per_product_dim  # storing the per-product dimension
        
        # Build the MLP head that takes concatenated features from two products.
        self.head = nn.Sequential(
            nn.Linear(head_input_dim, head_input_dim // 2),
            nn.ReLU(),
            nn.Linear(head_input_dim // 2, head_input_dim // 4),
            nn.ReLU(),
            nn.Linear(head_input_dim // 4, 2)
        )
        
    def encode_description(self, desc):
        # desc is [input_ids, attention_mask]
        last_hidden_states = self.description_transformer(desc[:, 0, :], desc[:, 1, :]).last_hidden_state
        attention_mask = desc[:, 1, :]
        # TODO: нужно ли делать пулинг, посмотреть на результаты
        return average_pool(last_hidden_states, attention_mask)
    
    def forward(self, im1, name1, desc1, im2, name2, desc2):
        image_emb1 = self.ruclip.encode_image(im1)
        image_emb2 = self.ruclip.encode_image(im2)
        name_emb1 = self.ruclip.encode_text(name1[:, 0, :], name1[:, 1, :])
        name_emb2 = self.ruclip.encode_text(name2[:, 0, :], name2[:, 1, :])
        desc_emb1 = self.encode_description(desc1) 
        desc_emb2 = self.encode_description(desc2)
        first_emb = torch.cat([image_emb1, name_emb1, desc_emb1], dim=1)
        second_emb = torch.cat([image_emb2, name_emb2, desc_emb2], dim=1)
        x = torch.cat([first_emb, second_emb], dim=1)
        out = self.head(x)
        return out


# Evaluation loop

## Run evaluation

In [32]:
# Load model

from pathlib import Path

model_ckpt_path = Path(DATA_PATH) / 'train_results' / MODEL_CKPT
std = torch.load(model_ckpt_path, map_location=DEVICE)
model = SiameseRuCLIP()
model.load_state_dict(std)

<All keys matched successfully>

In [None]:
# Load data

labeled = pd.read_csv(DATA_PATH + TABLE_DATASET_NAME)
images_dir = DATA_PATH + IMG_DATASET_NAME

y = labeled.label.values
X = labeled.drop(columns='label').copy()

test_ds = SiameseRuCLIPDataset(X, y, images_dir=images_dir)
test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE)

In [None]:
def validation(model, valid_loader, score, device='cpu') -> float:
    correct_val = 0
    with torch.no_grad(): 
        model.eval()
        for data in tqdm(valid_loader):
            im1, name1, desc1, im2, name2, desc2, label = data 
            im1, name1, desc1, im2, name2, desc2 = im1.to(device), name1.to(device), desc1.to(device), im2.to(device), name2.to(device), desc2.to(device)
            out = model(im1, name1, desc1, im2, name2, desc2) 
            _, predicted = torch.max(out.data, -1)
            predicted = predicted.cpu().numpy()
            correct_val += score(label, predicted)
    return correct_val / len(valid_loader)

In [35]:
test_score = validation(model, test_dl, f1_score)

  0%|          | 0/621 [00:00<?, ?it/s]

In [43]:
from dotenv import load_dotenv
import wandb

load_dotenv()

wandb.init(
    project="product-matching",
    entity="overfit1010",
    name=f"test-{MODEL_CKPT}",
    config={
        "table_dataset_name": TABLE_DATASET_NAME,
        "img_dataset_name": IMG_DATASET_NAME,
        "model_ckpt": MODEL_CKPT,
        "name_model_name": NAME_MODEL_NAME,
        "description_model_name": DESCRIPTION_MODEL_NAME,
        "model_type": "siamese",
        "run_type": "test"
    }
)

# Optional: log summary metric
wandb.summary["test.f1_score"] = test_score
wandb.finish()

0,1
test.f1_score,0.00161
