# Parameters

In [1]:
DATA_PATH = 'data'

IMG_DATASET_NAME = 'images_OZ_geo_5500'
EMB_DATASET_DIR = 'embeddings_OZ_geo_5500'
TABLE_DATASET_DIR = 'tables_OZ_geo_5500'
TABLE_DATASET_FILE = 'tables_OZ_geo_5500/OZ_geo_5500.csv'
TABLE_DATASET_FILES= [
    'Ozon_Crawler_Latest_info2025-04-07-12-57-51.xlsx',
    'Карты мира_озон.xlsx'
]

USE_GDRIVE = False # HACK

In [2]:
QUERY_SELLER = 'ИНТЕРТРЕЙД'

# SUBSET_QUERY_SKU = 2
# SUBSET_NONQUERY_SKU = 6

SUBSET_QUERY_SKU = None
SUBSET_NONQUERY_SKU = None

TOP_K = 50

In [3]:
import torch

# CLIP_MODEL = 'ruclip-vit-base-patch32-384'
CLIP_MODEL = None

# CLIP_SIAMESE_CONTRASTIVE_CKPT = 'siamese_contrastive_7k.pt'
CLIP_SIAMESE_CONTRASTIVE_CKPT = 'siamese_contrastive.pt'

# CLIP_SIAMESE_CKPT = 'siamese_fitted_10epochs_bert_tiny.pt'
# CLIP_SIAMESE_CKPT = 'siamese_fitted_10epochs_bert_turbo.pt'

COMPUTE_FINAL_EMBEDDINGS = True

SBERT_MODEL = 'all-distilroberta-v1'

# SBERT_BATCH_SIZE = 768 if torch.cuda.is_available() else 8
SBERT_BATCH_SIZE = 512 if torch.cuda.is_available() else 8 # lesser for larger TOP_K

RUCLIP_BATCH_SIZE = 512 if torch.cuda.is_available() else 8

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Imports

In [4]:
import pandas as pd

import requests
import os

import joblib
import xgboost as xgb
from datetime import date, timedelta
import numpy as np

import torch
from sentence_transformers import SentenceTransformer, util
from typing import List, Tuple
from PIL import Image
from io import BytesIO
import math

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import balanced_accuracy_score

# import optuna

from pathlib import Path

In [5]:
if CLIP_MODEL is not None:
    try:
        import ruclip
    except ModuleNotFoundError:
        !pip install git+https://github.com/tony-pitchblack/ru-clip.git#egg=ruclip
        import ruclip

# Download data

In [6]:
try:
    import dotenv
except ImportError:
    !pip install python-dotenv

In [7]:
# Use tokens from .env

import os
from dotenv import load_dotenv

import huggingface_hub

load_dotenv()

HF_TOKEN = os.getenv("HF_TOKEN")
huggingface_hub.login(token=HF_TOKEN)


Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [8]:
# tables_url = ''
images_url = 'https://drive.google.com/file/d/17gqYhKkkbo7zroI6Q4xebXqIYAdv8ydc/view?usp=sharing'
ckpt_url = 'https://drive.google.com/file/d/1iRQAUSMktsCdLdiUvf8YLt8XypN3-x9r/view?usp=sharing'

if USE_GDRIVE:
    !gdown --fuzzy {tables_url} -O data/tmp.zip
    !gdown --fuzzy {ckpt_url} -O data/tmp_2.zip
    !gdown --fuzzy {images_url} -O data/{IMG_DATASET_NAME}.zip

    !unzip -o -q data/tmp.zip -d data/
    !unzip -o -q data/tmp_2.zip -d data/
    !unzip -o -q data/{IMG_DATASET_NAME}.zip -d data/

In [9]:
# Download models' weights & text/image datasets from HF

if not USE_GDRIVE:
    from huggingface_hub import snapshot_download
    from pathlib import Path

    REPO_ID = "INDEEPA/clip-siamese"
    LOCAL_DIR = Path("data/train_results")
    LOCAL_DIR.mkdir(parents=True, exist_ok=True)

    snapshot_download(
        repo_id=REPO_ID,
        repo_type='dataset',
        local_dir='data',
        allow_patterns=[
            'train_results/**',
            f"*{EMB_DATASET_DIR}/**",
            f"*{TABLE_DATASET_DIR}/**",
            f"{IMG_DATASET_NAME}.zip"
        ],
    )

    !unzip -o -q data/{IMG_DATASET_NAME}.zip -d data/

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Fetching 66 files:   0%|          | 0/66 [00:00<?, ?it/s]

# Prepare data

In [10]:
DATA_PATH = 'data'

file_path = (
    Path(DATA_PATH) /
    Path('tables_OZ_geo_5500') /
    'Ozon_Crawler_Latest_info2025-04-07-12-57-51.xlsx'
)

descr_source_df = pd.read_excel(file_path)
descr_source_df.columns.tolist()

['Sku (Sku)',
 'Фото (CoverImage)',
 'Название товара (ProductName)',
 'Продавец (SellerName)',
 'Бренд (BrandName)',
 'Название категории (CategoryName)',
 'Цена соинвест (DiscountPrice)',
 'Цена по карте (OzonCardPrice)',
 'Сток FBO (StockFbm)',
 'Сток FBS (StockFbs)',
 'Ошибка загрузки (CrawlerError)',
 'Валюта (Currency)',
 'Цена до скидки (BasePrice)',
 'Рейтинг товара (AvgRating)',
 'Количество отзывов (Reviews)',
 'Описание (Description)',
 'Rich-контент (RichContent)',
 'Ссылка на карточку',
 'SellerProductId (SellerProductId)']

In [11]:
import re

# Function to extract Latin name and convert to snake_case
def extract_and_convert(col_name):
    match = re.search(r'\(([^)]+)\)', col_name)
    if match:
        camel = match.group(1)
    else:
        camel = col_name
    # Convert CamelCase to snake_case
    snake = re.sub(r'(?<!^)(?=[A-Z])', '_', camel).lower()
    return snake

# Apply renaming
descr_source_df.rename(columns={col: extract_and_convert(col) for col in descr_source_df.columns}, inplace=True)

# Check the result
print("Renamed columns:")
descr_source_df.columns.tolist()

Renamed columns:


['sku',
 'cover_image',
 'product_name',
 'seller_name',
 'brand_name',
 'category_name',
 'discount_price',
 'ozon_card_price',
 'stock_fbm',
 'stock_fbs',
 'crawler_error',
 'currency',
 'base_price',
 'avg_rating',
 'reviews',
 'description',
 'rich_content',
 'ссылка на карточку',
 'seller_product_id']

In [12]:
DATA_PATH = 'data'

file_path = (
    Path(DATA_PATH) /
    Path('tables_OZ_geo_5500') /
    'Карты мира_озон.xlsx'
)

source_df = pd.read_excel(file_path)
source_df.columns.tolist()

['SKU',
 'Name',
 'Category',
 'Схема',
 'Brand',
 'Niche',
 'Seller',
 'Balance',
 'Balance FBS',
 'Warehouses count',
 'Comments',
 'Final price',
 'Max price',
 'Min price',
 'Average price',
 'Median price',
 'Цена с Ozon картой',
 'Sales',
 'Revenue',
 'Revenue potential',
 'Revenue average',
 'Lost profit',
 'Lost profit percent',
 'URL',
 'Thumb',
 'Pics Count',
 'Has Video',
 'First Date',
 'Days in website',
 'Days in stock',
 'Days with sales',
 'Average if in stock',
 'Rating',
 'FBS',
 'Base price',
 'Category Position',
 'Categories Last Count',
 'Sales Per Day Average',
 'Turnover',
 'Frozen stocks',
 'Frozen stocks cost',
 'Frozen stocks percent']

In [13]:
all_required_cols = [
    'balance_first',
    'sales_first',
    'rating_first',
    'final_price_first',
    'comments_first',
    'description_first',
    'name_first',
    'options_first',
    'sku_first',
    'has_video_first',
    'photo_count_first',

    'balance_second', # Balance
    'sales_second',
    'rating_second', # AvgRating
    'final_price_second', # DiscountPrice,
    'comments_second', # Reviews
    'description_second',
    'name_second', # ProductName
    'options_second',
    'sku_second',
    'has_video_second',
    'photo_count_second',

    # 'image_url_first',
    # 'image_url_second',

    'iseq_vendor', # 0
    'iseq_color', # 0
    'iseq_brand', # BrandName
    'iseq_supp', # 0
    'are_related', # 0

    'desc_sim',
    'opt_sim',
    'name_sim',
    'img_sim',

    'label'
]

In [14]:
new_source_df_all = source_df.rename(
    columns={
        col: col.lower().replace(" ", "_")
        for col in source_df.columns
    }
)

new_source_df_all.columns.tolist()

['sku',
 'name',
 'category',
 'схема',
 'brand',
 'niche',
 'seller',
 'balance',
 'balance_fbs',
 'warehouses_count',
 'comments',
 'final_price',
 'max_price',
 'min_price',
 'average_price',
 'median_price',
 'цена_с_ozon_картой',
 'sales',
 'revenue',
 'revenue_potential',
 'revenue_average',
 'lost_profit',
 'lost_profit_percent',
 'url',
 'thumb',
 'pics_count',
 'has_video',
 'first_date',
 'days_in_website',
 'days_in_stock',
 'days_with_sales',
 'average_if_in_stock',
 'rating',
 'fbs',
 'base_price',
 'category_position',
 'categories_last_count',
 'sales_per_day_average',
 'turnover',
 'frozen_stocks',
 'frozen_stocks_cost',
 'frozen_stocks_percent']

In [15]:
# Combine balance columns

new_source_df_all = new_source_df_all.rename(
    columns={
        'balance': 'balance_fbm'
    }
)

new_source_df_all['balance'] = (
    new_source_df_all['balance_fbm'] +
    new_source_df_all['balance_fbs']
)

new_source_df_all[['balance_fbm', 'balance_fbs', 'balance']].describe()

Unnamed: 0,balance_fbm,balance_fbs,balance
count,5703.0,5703.0,5703.0
mean,2.063826,267.155006,269.218832
std,21.677691,534.741939,534.336187
min,0.0,0.0,0.0
25%,0.0,6.0,8.0
50%,0.0,73.0,79.0
75%,0.0,104.0,124.0
max,853.0,11107.0,11108.0


In [16]:
required_cols = [
    'balance',
    'sales',
    'final_price',
    'rating',
    'comments',
    # 'description',
    'name',
    # 'options'
    'sku',
    'has_video',
    'pics_count',
    'seller',
    'brand',
    'url'
]

new_source_df_all = (
    new_source_df_all[required_cols]
    .rename(columns={'pics_count': 'photo_count'})
)

new_source_df_all.head(1)

Unnamed: 0,balance,sales,final_price,rating,comments,name,sku,has_video,photo_count,seller,brand,url
0,346,156,1811,4.8,5227,Карта мира географическая политическая интерак...,936454663,0,4,GooDaY,,https://www.ozon.ru/context/detail/id/936454663/


In [17]:
# Extract image id from URL

descr_source_df['image_id'] = descr_source_df['cover_image'].dropna().apply(
    lambda s: re.search(r'/(\d+)\.jpg$', str(s)).group(1)
)

descr_source_df.dropna(subset='image_id', inplace=True)
descr_source_df[['image_id', 'sku']]

Unnamed: 0,image_id,sku
0,7323783851,1871769771
1,7394308097,1679550303
2,7299023048,1200553001
3,7388534766,922231521
4,7295079927,922230517
...,...,...
5560,6008538837,166584090
5561,6008438667,166451882
5562,7439544697,154409524
5563,7098349497,147896031


In [18]:
descr_source_df.columns.tolist()

['sku',
 'cover_image',
 'product_name',
 'seller_name',
 'brand_name',
 'category_name',
 'discount_price',
 'ozon_card_price',
 'stock_fbm',
 'stock_fbs',
 'crawler_error',
 'currency',
 'base_price',
 'avg_rating',
 'reviews',
 'description',
 'rich_content',
 'ссылка на карточку',
 'seller_product_id',
 'image_id']

In [19]:
new_source_df_all = new_source_df_all.merge(
    descr_source_df[['sku', 'description', 'image_id']],
    on='sku'
)

new_source_df_all['options'] = new_source_df_all['name']
new_source_df_all.columns.tolist()

['balance',
 'sales',
 'final_price',
 'rating',
 'comments',
 'name',
 'sku',
 'has_video',
 'photo_count',
 'seller',
 'brand',
 'url',
 'description',
 'image_id',
 'options']

In [20]:
new_source_df_all['description'] = (
    new_source_df_all['description']
    .fillna(new_source_df_all['name'])
)

In [21]:
# Take a subset: all query sku and some non-query sku

query_df = new_source_df_all[new_source_df_all.seller == QUERY_SELLER]
if SUBSET_QUERY_SKU is not None:
    query_df = query_df.sample(n=SUBSET_QUERY_SKU)

nonquery_df = new_source_df_all[~(new_source_df_all.seller == QUERY_SELLER)]
if SUBSET_NONQUERY_SKU is not None:
    nonquery_df = nonquery_df.sample(n=SUBSET_NONQUERY_SKU)

new_source_df = pd.concat([
    query_df,
    nonquery_df
]).reset_index(drop=True)

len(new_source_df), len(new_source_df_all)

(5562, 5562)

# Find top-k embeddings

## Init CLIP model

In [22]:
if CLIP_MODEL is not None:
    import ruclip
    clip, processor = ruclip.load(CLIP_MODEL, device=DEVICE)

### RuCLIPtiny

In [23]:
import os
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"


from timm import create_model
import numpy as np
import pandas as pd
import os
import torch
from torch import nn
from torch import optim, Tensor
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
# from torchinfo import summary
import transformers
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer,\
        get_linear_schedule_with_warmup
from transformers import AutoModel, AutoTokenizer

import cv2

from PIL import Image
from tqdm.auto import tqdm

import json
from itertools import product

# import datasets
# from datasets import Dataset, concatenate_datasets
import argparse
import requests

from io import BytesIO
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import balanced_accuracy_score, roc_auc_score, f1_score
import matplotlib.pyplot as plt
import more_itertools

In [24]:
class RuCLIPtiny(nn.Module):
    def __init__(self, name_model_name: str):
        """
        Initializes the RuCLIPtiny module using the provided name model.
        """
        super().__init__()
        self.visual = create_model('convnext_tiny',
                                   pretrained=False,  # set True if you want pretrained weights
                                   num_classes=0,
                                   in_chans=3)       # output: e.g. 768-dim features

        self.transformer = AutoModel.from_pretrained(name_model_name)
        name_model_output_size = self.transformer.config.hidden_size  # inferred dynamically
        self.final_ln = nn.Linear(name_model_output_size, 768)         # project to 768 dims
        self.logit_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1/0.07)))

    @property
    def dtype(self):
        return self.visual.stem[0].weight.dtype

    def encode_image(self, image: torch.Tensor) -> torch.Tensor:
        return self.visual(image.type(self.dtype))

    def encode_text(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        x = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        # use the CLS token (first token)
        x = x.last_hidden_state[:, 0, :]
        x = self.final_ln(x)
        return x

    def forward(self, image: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        image_features = self.encode_image(image)
        text_features = self.encode_text(input_ids, attention_mask)
        # Normalize features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        return logits_per_image, logits_per_text


In [25]:
def get_transform():
    return transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        lambda image: image.convert("RGB"),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

class Tokenizers:
    def __init__(self, name_model_name: str, description_model_name: str):
        self.name_tokenizer = AutoTokenizer.from_pretrained(name_model_name)
        self.desc_tokenizer = AutoTokenizer.from_pretrained(description_model_name)

    def tokenize_name(self, texts, max_len=77):
        tokenized = self.name_tokenizer.batch_encode_plus(
            texts,
            truncation=True,
            add_special_tokens=True,
            max_length=max_len,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt'
        )
        return torch.stack([tokenized["input_ids"], tokenized["attention_mask"]])

    def tokenize_description(self, texts, max_len=77):
        tokenized = self.desc_tokenizer(
            texts,
            truncation=True,
            add_special_tokens=True,
            max_length=max_len,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt'
        )
        return torch.stack([tokenized["input_ids"], tokenized["attention_mask"]])



In [26]:
from transformers import AutoTokenizer
import torch

class NameTokenizer:
    def __init__(self, model_name: str):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def tokenize(self, texts, max_len=77):
        tokenized = self.tokenizer.batch_encode_plus(
            texts,
            truncation=True,
            add_special_tokens=True,
            max_length=max_len,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt'
        )
        return torch.stack([tokenized["input_ids"], tokenized["attention_mask"]])


class DescriptionTokenizer:
    def __init__(self, model_name: str):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def tokenize(self, texts, max_len=77):
        tokenized = self.tokenizer(
            texts,
            truncation=True,
            add_special_tokens=True,
            max_length=max_len,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt'
        )
        return torch.stack([tokenized["input_ids"], tokenized["attention_mask"]])


In [27]:
class SiameseRuCLIPDataset(torch.utils.data.Dataset):
    def __init__(self, images_dir: str, name_model_name: str, description_model_name: str, df=None, labels=None, df_path=None):
        """
        Dataset requires the concrete models' names for tokenization.
        """
        assert os.path.isdir(images_dir), f"Image dir does not exist: '{self.images_dir}'"

        self.df = pd.read_csv(df_path) if df_path is not None else df
        self.labels = labels
        self.images_dir = images_dir
        self.tokenizers = Tokenizers(name_model_name, description_model_name)
        self.transform = get_transform()
        self.max_len = 77

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # Tokenize names
        name_tokens = self.tokenizers.tokenize_name([str(row.name_first), str(row.name_second)], max_len=self.max_len)
        name_first = name_tokens[:, 0, :]  # [input_ids, attention_mask]
        name_second = name_tokens[:, 1, :]
        # Tokenize descriptions
        desc_tokens = self.tokenizers.tokenize_description([str(row.description_first), str(row.description_second)])
        desc_first = desc_tokens[:, 0, :]
        desc_second = desc_tokens[:, 1, :]
        # Process images
        im_first_path = os.path.join(self.images_dir, row.image_name_first)
        im_first = cv2.imread(im_first_path)
        im_first = cv2.cvtColor(im_first, cv2.COLOR_BGR2RGB)
        im_first = Image.fromarray(im_first)
        im_first = self.transform(im_first)
        im_second_path = os.path.join(self.images_dir, row.image_name_first)
        im_second = cv2.imread(os.path.join(im_second_path))
        im_second = cv2.cvtColor(im_second, cv2.COLOR_BGR2RGB)
        im_second = Image.fromarray(im_second)
        im_second = self.transform(im_second)
        label = self.labels[idx]
        return im_first, name_first, desc_first, im_second, name_second, desc_second, label

    def __len__(self):
        return len(self.df)

In [28]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset
from PIL import Image
import cv2

class RuCLIPDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        images_dir: str,
        name_model_name: str, description_model_name: str,
        df=None, labels=None, df_path=None,
        names_as_descriptions=False,
    ):
        """
        Dataset requires the concrete models' names for tokenization.
        """
        assert os.path.isdir(images_dir), f"Image dir does not exist: '{images_dir}'"

        self.df = pd.read_csv(df_path) if df_path is not None else df
        self.labels = labels
        self.images_dir = images_dir
        self.tokenizers = Tokenizers(name_model_name, description_model_name)
        self.transform = get_transform()
        self.max_len = 77
        self.names_as_descriptions = names_as_descriptions

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Tokenize name
        name_tokens = self.tokenizers.tokenize_name([str(row['name'])], max_len=self.max_len)
        name = name_tokens[:, 0, :]  # [input_ids, attention_mask]
        # print('name', row['name'])

        # Tokenize description
        if self.names_as_descriptions:
            desc = name
        else:
            # print('description', row.description)
            desc_tokens = self.tokenizers.tokenize_description([str(row.description)])
            desc = desc_tokens[:, 0, :]

        # Process image
        im_path = os.path.join(self.images_dir, row.image_name)
        im = cv2.imread(im_path)
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        im = Image.fromarray(im)
        im = self.transform(im)
        return im, name, desc

    def __len__(self):
        return len(self.df)

### SiameseRuCLIP

In [29]:
def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

class SiameseContrastiveRuCLIP(nn.Module):
    def __init__(self,
                 device: str,
                 name_model_name: str,
                 description_model_name: str,
                 models_dir: str = None,
                 preload_ruclip: bool = False,
                 preload_model_name: str = None):
        """
        Initializes the SiameseContrastiveRuCLIP model.
        Required parameters:
          - models_dir: directory containing saved checkpoints.
          - name_model_name: model name for text (name) branch.
          - description_model_name: model name for description branch.
        """
        super().__init__()
        device = torch.device(device)

        # Initialize RuCLIPtiny
        self.ruclip = RuCLIPtiny(name_model_name)
        if preload_ruclip:
            std = torch.load(
                os.path.join(models_dir, preload_model_name),
                weights_only=True,
                map_location=device
            )
            self.ruclip.load_state_dict(std)
            self.ruclip.eval()
        self.ruclip = self.ruclip.to(device)

        # Initialize the description transformer
        self.description_transformer = AutoModel.from_pretrained(description_model_name)
        self.description_transformer = self.description_transformer.to(device)

        # Determine dimensionality
        vision_dim = self.ruclip.visual.num_features
        name_dim = self.ruclip.final_ln.out_features
        desc_dim = self.description_transformer.config.hidden_size
        self.hidden_dim = vision_dim + name_dim + desc_dim

        # Define MLP head
        self.head = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(self.hidden_dim // 2, self.hidden_dim // 4),
        ).to(device)

    def encode_image(self, image):
        return self.ruclip.encode_image(image)

    def encode_name(self, name):
        return self.ruclip.encode_text(name[:, 0, :], name[:, 1, :])

    def encode_description(self, desc):
        last_hidden_states = self.description_transformer(desc[:, 0, :], desc[:, 1, :]).last_hidden_state
        attention_mask = desc[:, 1, :]
        return average_pool(last_hidden_states, attention_mask)

    def get_final_embedding(self, im, name, desc):
        image_emb = self.encode_image(im)
        name_emb = self.encode_name(name)
        desc_emb = self.encode_description(desc)

        # Concatenate the embeddings and forward through the head
        combined_emb = torch.cat([image_emb, name_emb, desc_emb], dim=1)
        final_embedding = self.head(combined_emb)
        return final_embedding

    def forward(self, im1, name1, desc1, im2, name2, desc2):
        out1 = self.get_final_embedding(im1, name1, desc1)
        out2 = self.get_final_embedding(im2, name2, desc2)
        return out1, out2

In [30]:
siamese_contrastive_model_configs = [
    dict(
        MODEL_CKPT = 'siamese_contrastive.pt',

        NAME_MODEL_NAME = 'cointegrated/rubert-tiny',
        # NAME_MODEL_NAME = 'DeepPavlov/distilrubert-tiny-cased-conversational-v1',

        DESCRIPTION_MODEL_NAME = 'cointegrated/rubert-tiny',
        # DESCRIPTION_MODEL_NAME = 'sergeyzh/rubert-tiny-turbo',

        CONTRASTIVE_THRESHOLD=0.3,
    ),

    dict(
        MODEL_CKPT = 'siamese_contrastive_7k.pt',

        NAME_MODEL_NAME = 'cointegrated/rubert-tiny',
        # NAME_MODEL_NAME = 'DeepPavlov/distilrubert-tiny-cased-conversational-v1',

        DESCRIPTION_MODEL_NAME = 'cointegrated/rubert-tiny',

        CONTRASTIVE_THRESHOLD=0.3,
    ),


    # dict(
    #     MODEL_CKPT = 'siamese_contrastive_1gpu.pt',

    #     NAME_MODEL_NAME = 'cointegrated/rubert-tiny',
    #     # NAME_MODEL_NAME = 'DeepPavlov/distilrubert-tiny-cased-conversational-v1',

    #     # DESCRIPTION_MODEL_NAME = 'sergeyzh/rubert-tiny-turbo',
    #     DESCRIPTION_MODEL_NAME = 'cointegrated/rubert-tiny',

    #     CONTRASTIVE_THRESHOLD=0.3,
    # ),
]

# Map the model checkpoints to their configs
siamese_contrastive_model_configs = {
    config['MODEL_CKPT']: config
    for config in siamese_contrastive_model_configs
}

In [31]:
# Load siamese model

def load_model(model_config, is_contrastive=False):
    if is_contrastive:
        model_class = SiameseContrastiveRuCLIP
    else:
        raise NotImplementedError

    ckpt_name = model_config['MODEL_CKPT']
    model_ckpt_path = Path(DATA_PATH) / 'train_results' / ckpt_name
    std = torch.load(model_ckpt_path, map_location=DEVICE)

    # Initialize the model using the configuration.
    model = model_class(
        name_model_name=model_config["NAME_MODEL_NAME"],
        description_model_name=model_config["DESCRIPTION_MODEL_NAME"],
        device=DEVICE,
    )

    model.load_state_dict(std)
    return model

if CLIP_SIAMESE_CONTRASTIVE_CKPT is not None:
    model_config = siamese_contrastive_model_configs[CLIP_SIAMESE_CONTRASTIVE_CKPT]
    model = load_model(model_config, is_contrastive=True)

## Compute embeddings

In [32]:
from typing import Tuple, List, Any
from PIL import Image  # Assumes Pillow is installed

def get_sku_image_offline(sku_or_image_id, img_dataset_dir):
        """Load an image from disk for a given SKU.
           It first tries .jpg then .webp.
        """
        for ext in ['.jpg', '.webp']:
            img_path = os.path.join(img_dataset_dir, f"{sku_or_image_id}{ext}")
            if os.path.exists(img_path):
                try:
                    with open(img_path, 'rb') as f:
                        img_data = f.read()
                    image = Image.open(BytesIO(img_data))
                    image.load()
                    return image
                except Exception as e:
                    print(f"Error loading {img_path}: {e}")
        return None

def get_image_and_name(
    row,
    image_id_col: str,
    name_col: str,
    offline: bool = True,
    img_dataset_dir: str = '../data/images_7k'
) -> Tuple[Any, Any]:
    """
    Retrieves a single image and its corresponding name from a DataFrame row.

    Args:
        row: The DataFrame row.
        image_id_col: Column name for the image identifier.
        name_col: Column name for the product name.
        offline: Whether to use the offline image retrieval function.
        img_dataset_dir: Directory to load images from when offline.

    Returns:
        A tuple (image, name). If the image is not loaded, image will be None.
    """
    if offline:
        image = get_sku_image_offline(int(row[image_id_col]), img_dataset_dir)
    else:
        image = get_sku_image(int(row[image_id_col]))
    name = row[name_col]
    return image, name

def get_images_names(
    df,
    image_id_col: str = 'sku',
    name_col: str = 'name',
    offline: bool = True,
    img_dataset_dir: str = '../data/images_7k'
) -> Tuple[List[Image.Image], List[Any], List[int]]:
    """
    Iterates over the DataFrame and retrieves the image and name for each single product.

    Args:
        df: DataFrame containing one product per row.
        image_id_col: Column name for the image identifier.
        name_col: Column name for the product name.
        offline: Whether to load images using the offline function.
        img_dataset_dir: Directory for offline images.

    Returns:
        A tuple (images, names, problems) where:
         - images: list of loaded images,
         - names: list of corresponding names,
         - problems: list of row indices where the image failed to load.
    """
    images, names, problems = [], [], []
    for idx, row in df.iterrows():
        img, prod_name = get_image_and_name(row, image_id_col, name_col, offline, img_dataset_dir)
        if img is not None:
            images.append(img)
            names.append(prod_name)
        else:
            problems.append(idx)
    return images, names, problems

In [33]:
# Paths for caching embeddings
from pathlib import Path

emb_prefix = Path(DATA_PATH) / 'embeddings_OZ_geo_5500'
emb_prefix.mkdir(parents=True, exist_ok=True)  # Ensure directory exists

n_query = query_df.sku.nunique()
n_nonquery = nonquery_df.sku.nunique()

if CLIP_MODEL is not None:
    model_name = CLIP_MODEL
elif CLIP_SIAMESE_CONTRASTIVE_CKPT is not None:
    model_name = CLIP_SIAMESE_CONTRASTIVE_CKPT

images_embs_file_name = emb_prefix / f'{model_name}_images_latents_query-{n_query}_nonquery-{n_nonquery}.npy'
names_embs_file_name = emb_prefix / f'{model_name}_names_latents_query-{n_query}_nquery-{n_nonquery}.npy'
final_embs_file_name = emb_prefix / f'{model_name}_final_latents_query-{n_query}_nonquery-{n_nonquery}.npy'

In [34]:
def compute_embeddings_ruclip(
    new_source_df,
    images_embs_file_name,
    names_embs_file_name,
    RUCLIP_BATCH_SIZE,
    DEVICE,
    processor,
    clip,
    get_images_names,
):
    """
    Computes embeddings for images and names if not already cached.

    Args:
        new_source_df (pd.DataFrame): DataFrame containing the source data.
        images_embs_file_name (str): Path to save/load image embeddings.
        names_embs_file_name (str): Path to save/load name embeddings.
        RUCLIP_BATCH_SIZE (int): Batch size for processing.
        DEVICE (str): Device to use for computation (e.g., 'cuda' or 'cpu').
        processor (ruclip.processor.RuCLIPProcessor): Processor for RuCLIP.
        clip (ruclip.CLIP): RuCLIP model.

    Returns:
        tuple: (images_latents, names_latents, problems_ids)
    """
    from tqdm import tqdm
    import os
    import torch
    import numpy as np

    if not os.path.isfile(images_embs_file_name) or not os.path.isfile(names_embs_file_name):
        templates = ['{}', 'это {}', 'на картинке {}', 'товар {}']
        predictor = ruclip.Predictor(
            clip, processor, DEVICE,
            bs=RUCLIP_BATCH_SIZE,
            templates=templates
        )

        images_latents = []
        names_latents = []
        problems_ids = []

        def get_batches(df, batch_size):
            for start in range(0, len(df), batch_size):
                yield df.iloc[start:start+batch_size]

        total_batches = len(new_source_df) // RUCLIP_BATCH_SIZE
        with torch.no_grad():
            for batch_idx, df_batch in tqdm(enumerate(get_batches(new_source_df, batch_size=RUCLIP_BATCH_SIZE))):
                print(f'\nBatch {batch_idx+1} / {total_batches}')
                images_batch, names_batch, problems_ids_batch = get_images_names(
                    df=df_batch,
                    image_id_col='image_id',
                    name_col='name',
                    img_dataset_dir='data/images_OZ_geo_5500',
                    offline=True
                )
                # Process your batch here:
                images_latents_batch = predictor.get_image_latents(images_batch).detach().cpu()
                name_latents_batch = predictor.get_text_latents(names_batch).detach().cpu()

                images_latents.append(images_latents_batch)
                names_latents.append(name_latents_batch)

                problems_ids.extend(problems_ids_batch)

        images_latents = torch.cat(images_latents).numpy()
        names_latents = torch.cat(names_latents).numpy()

        # Save image latents
        np.save(images_embs_file_name, images_latents)

        # Save name latents
        np.save(names_embs_file_name, names_latents)
    else:
        # Load cached embeddings
        images_latents = np.load(images_embs_file_name)
        names_latents = np.load(names_embs_file_name)
        problems_ids = []
        print("Loaded embeddings from cache.")

    return images_latents, names_latents, problems_ids

In [35]:
if CLIP_MODEL is not None:
    # Ensure the required variables and objects are defined
    images_latents, names_latents, problems_ids = compute_embeddings_ruclip(
        new_source_df=new_source_df,
        images_embs_file_name=images_embs_file_name,
        names_embs_file_name=names_embs_file_name,
        RUCLIP_BATCH_SIZE=RUCLIP_BATCH_SIZE,
        DEVICE=DEVICE,
        processor=processor,
        clip=clip,
        get_images_names=get_images_names
    )

    # Print results
    print(f"Image embeddings shape: {images_latents.shape}")
    print(f"Name embeddings shape: {names_latents.shape}")
    print(f"Problematic IDs: {problems_ids}")

In [36]:
def compute_embeddings_ruclip_siamese_contrastive(
    new_source_df,
    images_embs_file_name,
    names_embs_file_name,
    RUCLIP_BATCH_SIZE,
    DEVICE,
    model,
    images_dir,
    name_model_name,
    description_model_name
):
    """
    Computes embeddings for images and names using a RuCLIPDataset if not already cached.
    """
    import os
    import torch
    import numpy as np
    from torch.utils.data import DataLoader
    from tqdm import tqdm

    # Construct the `image_name` column from `image_id`
    new_source_df['image_name'] = new_source_df['image_id'].astype(str).apply(
        lambda x: f"{x}.jpg"  # Adjust extension as needed
    )

    IMG_DATASET_DIR = os.path.join(DATA_PATH, IMG_DATASET_NAME)

    if not os.path.isfile(images_embs_file_name) or not os.path.isfile(names_embs_file_name):
        dataset = RuCLIPDataset(
            images_dir=IMG_DATASET_DIR,
            name_model_name=name_model_name,
            description_model_name=description_model_name,
            df=new_source_df
        )
        dataloader = DataLoader(dataset, batch_size=RUCLIP_BATCH_SIZE, shuffle=False)
        total_batches = len(dataloader)
        print(f"Total batches to process: {total_batches}")

        images_latents = []
        names_latents = []

        with torch.no_grad():
            for batch_idx, (images, names, descriptions) in tqdm(enumerate(dataloader), total=total_batches, desc="Processing Batches"):
                # Move data to the device
                images = images.to(DEVICE)
                names = names.to(DEVICE)

                # Get embeddings from the model
                images_latents_batch = model.encode_image(images).detach().cpu()
                name_latents_batch = model.encode_name(names).detach().cpu()

                images_latents.append(images_latents_batch)
                names_latents.append(name_latents_batch)

        images_latents = torch.cat(images_latents).numpy()
        names_latents = torch.cat(names_latents).numpy()

        # Save embeddings
        np.save(images_embs_file_name, images_latents)
        np.save(names_embs_file_name, names_latents)

    else:
        images_latents = np.load(images_embs_file_name)
        names_latents = np.load(names_embs_file_name)
        print("Loaded embeddings from cache.")

    return images_latents, names_latents, []


In [37]:
if CLIP_SIAMESE_CONTRASTIVE_CKPT is not None and not COMPUTE_FINAL_EMBEDDINGS:
    # Ensure the required variables and objects are defined
    images_latents, names_latents, problems_ids = compute_embeddings_ruclip_siamese_contrastive(
        new_source_df=new_source_df,
        images_embs_file_name=images_embs_file_name,
        names_embs_file_name=names_embs_file_name,
        RUCLIP_BATCH_SIZE=RUCLIP_BATCH_SIZE,
        DEVICE=DEVICE,
        model=model,
        images_dir='data/images_OZ_geo_5500',
        name_model_name=model_config["NAME_MODEL_NAME"],
        description_model_name=model_config["DESCRIPTION_MODEL_NAME"]
    )

    # Print results
    print(f"Image embeddings shape: {images_latents.shape}")
    print(f"Name embeddings shape: {names_latents.shape}")
    print(f"Problematic IDs: {problems_ids}")

In [38]:
def compute_embeddings_final_siamese_contrastive(
    new_source_df,
    final_embs_file_name,
    RUCLIP_BATCH_SIZE,
    DEVICE,
    model,
    images_dir,
    name_model_name,
    description_model_name
):
    import os
    import torch
    import numpy as np
    from torch.utils.data import DataLoader
    from tqdm import tqdm

    # Construct the `image_name` column from `image_id`
    new_source_df['image_name'] = new_source_df['image_id'].astype(str).apply(
        lambda x: f"{x}.jpg"  # Adjust extension as needed
    )

    # IMG_DATASET_DIR and DATA_PATH/IMG_DATASET_NAME are assumed to be defined globally
    IMG_DATASET_DIR = os.path.join(DATA_PATH, IMG_DATASET_NAME)

    # Initialize list to record indices with problems (if any are detected)
    problems_ids = []

    if not os.path.isfile(final_embs_file_name):
        dataset = RuCLIPDataset(
            images_dir=IMG_DATASET_DIR,
            name_model_name=name_model_name,
            description_model_name=description_model_name,
            df=new_source_df
        )
        dataloader = DataLoader(dataset, batch_size=RUCLIP_BATCH_SIZE, shuffle=False)
        total_batches = len(dataloader)
        print(f"Total batches to process: {total_batches}")

        embeddings = []

        with torch.no_grad():
            for batch_idx, (images, names, descriptions) in tqdm(enumerate(dataloader), total=total_batches, desc="Processing Batches"):
                # Move data to the device
                images = images.to(DEVICE)
                names = names.to(DEVICE)
                descriptions = descriptions.to(DEVICE)

                # Compute the final embeddings
                final_embeddings_batch = model.get_final_embedding(images, names, descriptions).detach().cpu()

                # (Optional) Insert logic here to check for problems and record batch indices to problems_ids
                embeddings.append(final_embeddings_batch)

        embeddings = torch.cat(embeddings).numpy()

        # Save embeddings to disk
        np.save(final_embs_file_name, embeddings)

    else:
        embeddings = np.load(final_embs_file_name)
        print("Loaded embeddings from cache.")

    return embeddings, problems_ids

In [39]:
if CLIP_SIAMESE_CONTRASTIVE_CKPT is not None and COMPUTE_FINAL_EMBEDDINGS:
    # Ensure the required variables and objects are defined
    final_latents, problems_ids = compute_embeddings_final_siamese_contrastive(
        new_source_df=new_source_df,
        final_embs_file_name=final_embs_file_name,
        RUCLIP_BATCH_SIZE=RUCLIP_BATCH_SIZE,
        DEVICE=DEVICE,
        model=model,
        images_dir='data/images_OZ_geo_5500',
        name_model_name=model_config["NAME_MODEL_NAME"],
        description_model_name=model_config["DESCRIPTION_MODEL_NAME"]
    )

    # Print results
    print(f"Final embeddings shape: {final_latents.shape}")

Loaded embeddings from cache.
Final embeddings shape: (5562, 462)


## Similarity search

In [40]:
#@title find_top_k_similar

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def find_top_k_similar(query_embeddings, embedding_matrix, k=5, metric='cosine', exclude_indices=None):
    """
    Find top-k similar items for each query embedding in a batch, with the option to exclude certain indices.
    If k is None, returns all indices sorted by similarity.

    Args:
        query_embeddings (np.ndarray): Array of query embeddings, shape (batch, D).
        embedding_matrix (np.ndarray): Array of all embeddings, shape (N, D).
        k (int or None): Number of top matches to return, or None to return all sorted candidates.
        metric (str): 'cosine' or 'euclidean'.
        exclude_indices (list, np.ndarray, or boolean mask, optional): Indices to exclude from search.
            If a list/array of indices is provided, it will be converted to a boolean mask.
            If a boolean mask is provided, it must have shape (N,).

    Returns:
        top_k (np.ndarray): Indices of similar embeddings for each query, shape (batch, M) where
                            M == k (or M == number of valid candidates if k is None).
        scores (np.ndarray): Corresponding similarity scores (or negated distances) for each query,
                             shape (batch, M).
    """
    # Convert exclude_indices to a boolean mask if needed
    if exclude_indices is not None:
        if isinstance(exclude_indices, (list, np.ndarray)):
            exclude_indices = np.array(exclude_indices)
            if exclude_indices.dtype != bool:
                mask = np.zeros(embedding_matrix.shape[0], dtype=bool)
                mask[exclude_indices] = True
            else:
                mask = exclude_indices
        else:
            raise ValueError("exclude_indices must be a list, np.ndarray, or boolean mask.")
    else:
        mask = None

    if metric == 'cosine':
        # Compute cosine similarities for the entire batch (shape: (batch, N))
        similarities = cosine_similarity(query_embeddings, embedding_matrix)
        # Set similarities for excluded indices to -infinity so they are not selected.
        if mask is not None:
            similarities[:, mask] = -np.inf
        # Sort indices in descending order of similarity.
        sorted_idx = np.argsort(-similarities, axis=1)
        if k is None:
            top_k = sorted_idx
            scores = np.take_along_axis(similarities, top_k, axis=1)
        else:
            top_k = sorted_idx[:, :k]
            scores = np.take_along_axis(similarities, top_k, axis=1)
    elif metric == 'euclidean':
        # Compute Euclidean distances: shape (batch, N)
        distances = np.linalg.norm(query_embeddings[:, None, :] - embedding_matrix[None, :, :], axis=2)
        # Set distances for excluded indices to +infinity so they are not selected.
        if mask is not None:
            distances[:, mask] = np.inf
        # Sort indices in ascending order of distance.
        sorted_idx = np.argsort(distances, axis=1)
        if k is None:
            top_k = sorted_idx
            # Negate distances so that higher scores correspond to closer matches.
            scores = -np.take_along_axis(distances, top_k, axis=1)
        else:
            top_k = sorted_idx[:, :k]
            scores = -np.take_along_axis(distances, top_k, axis=1)
    else:
        raise ValueError("Unsupported metric: choose 'cosine' or 'euclidean'")

    return top_k, scores


In [41]:
# Limit queries to computed batches

if COMPUTE_FINAL_EMBEDDINGS:
    max_emb_cnt = final_latents.shape[0]
else:
    max_emb_cnt = images_latents.shape[0]

truncated_df = new_source_df.iloc[:max_emb_cnt]
query_indices = truncated_df[
    truncated_df.sku.isin(query_df.sku)
].index.tolist()

In [42]:
# Load images/names queries

if not COMPUTE_FINAL_EMBEDDINGS:
    query_images_embs = images_latents[query_indices]  # e.g., shape (3, 768)
    query_names_embs = names_latents[query_indices]  # e.g., shape (3, 768)

    print(query_images_embs.shape, images_latents.shape)

In [43]:
# Load final/names queries

if COMPUTE_FINAL_EMBEDDINGS:
    query_final_embs = final_latents[query_indices]  # e.g., shape (3, 768)

    print(query_final_embs.shape)

(23, 462)


In [44]:
# Find top-k matches for images
if COMPUTE_FINAL_EMBEDDINGS:
    top_k_final, scores_final = find_top_k_similar(
        query_final_embs, final_latents,
        k=TOP_K,
        metric='cosine',
        exclude_indices=query_indices
    )

    print("Top-k image indices per query (shape):")
    print(top_k_final.shape)

    print("Corresponding similarity scores:")
    print(scores_final)

Top-k image indices per query (shape):
(23, 50)
Corresponding similarity scores:
[[0.9788321  0.976992   0.976992   ... 0.9547391  0.95449173 0.95449173]
 [0.9689189  0.9657544  0.95760274 ... 0.9223962  0.9214299  0.9213377 ]
 [0.98584026 0.9829995  0.98169225 ... 0.95928633 0.959041   0.95892525]
 ...
 [0.9161713  0.9158402  0.9153507  ... 0.87390023 0.87360716 0.8734062 ]
 [0.94727147 0.94716656 0.946903   ... 0.9063225  0.9059004  0.90561515]
 [0.97873294 0.97851026 0.9760494  ... 0.9621091  0.96182775 0.9618242 ]]


In [45]:
# Find top-k matches for images
if not COMPUTE_FINAL_EMBEDDINGS:
    top_k_images, scores_images = find_top_k_similar(
        query_images_embs, images_latents,
        k=TOP_K,
        metric='cosine',
        exclude_indices=query_indices
    )

    print("Top-k image indices per query (shape):")
    print(top_k_images.shape)

    print("Corresponding similarity scores:")
    print(scores_images)

In [46]:
# Find top-k matches for name
if not COMPUTE_FINAL_EMBEDDINGS:
    top_k_name, scores_name = find_top_k_similar(
        query_names_embs, names_latents,
        k=TOP_K,
        metric='cosine',
        exclude_indices=query_indices
    )

    print("Top-k image indices per query (shape):")
    print(top_k_name.shape)

    print("Corresponding similarity scores:")
    print(scores_name)

In [47]:
import numpy as np

def union_top_k_candidates(top_k_name, top_k_images, scores_name, scores_images, TOP_K_CANDIDATES=5):
    """
    Unites two top-k candidate lists (and their scores) while avoiding duplicates.
    For each query (row), candidates present in both name and image lists are merged,
    keeping the candidate's highest score. The merged list is then sorted in descending
    order by score and only the top k unique candidates are returned.

    Args:
        top_k_name (np.ndarray): Array of candidate IDs from the name modality, shape (batch_size, num_candidates)
        top_k_images (np.ndarray): Array of candidate IDs from the image modality, shape (batch_size, num_candidates)
        scores_name (np.ndarray): Array of scores for the name candidates, same shape as top_k_name.
        scores_images (np.ndarray): Array of scores for the image candidates, same shape as top_k_images.
        TOP_K_CANDIDATES (int): Number of top candidates to return after merging.

    Returns:
        unique_candidates (np.ndarray): Array of merged candidate IDs, shape (batch_size, TOP_K_CANDIDATES)
        unique_scores (np.ndarray): Array of merged candidate scores, shape (batch_size, TOP_K_CANDIDATES)
    """
    batch_size = top_k_name.shape[0]
    merged_candidates = []
    merged_scores = []

    for i in range(batch_size):
        # Concatenate candidates and scores from both modalities for the current query.
        candidates = np.concatenate([top_k_name[i], top_k_images[i]])
        scores = np.concatenate([scores_name[i], scores_images[i]])

        # Use a dict to store each candidate and its best (highest) score.
        cand_dict = {}
        for cand, score in zip(candidates, scores):
            if cand in cand_dict:
                if score > cand_dict[cand]:
                    cand_dict[cand] = score
            else:
                cand_dict[cand] = score

        # Sort the unique candidates by score in descending order.
        sorted_items = sorted(cand_dict.items(), key=lambda x: x[1], reverse=True)
        # Extract only the top-k candidates and scores.
        top_candidates = [item[0] for item in sorted_items][:TOP_K_CANDIDATES]
        top_scores = [item[1] for item in sorted_items][:TOP_K_CANDIDATES]

        # If there are less than TOP_K_CANDIDATES (for some reason), we can pad lists (or leave as is)
        # Here we assume every row has enough unique candidates.
        merged_candidates.append(top_candidates)
        merged_scores.append(top_scores)

    # Convert the list of lists to numpy arrays.
    unique_candidates = np.array(merged_candidates)
    unique_scores = np.array(merged_scores)

    return unique_candidates, unique_scores

In [48]:
if COMPUTE_FINAL_EMBEDDINGS:
    top_k_united_indices = top_k_final
    scores_united = scores_final
else:
    # Instead of doing this directly, it might be clearer to keep the two modalities separate.
    # Then call our function:
    top_k_united_indices, scores_united = union_top_k_candidates(
        top_k_name, top_k_images, scores_name, scores_images,
        TOP_K_CANDIDATES=TOP_K
    )

print("Merged top candidates shape:", top_k_united_indices.shape)
print("Merged scores shape:", scores_united.shape)
print(top_k_united_indices)
print(scores_united)

Merged top candidates shape: (23, 50)
Merged scores shape: (23, 50)
[[ 184 1851  677 ... 4867 1316 5461]
 [1935 2944  458 ...  174  838 4797]
 [1206  210 1090 ... 3293  726 1561]
 ...
 [5332 5142 4104 ... 4771 1747  838]
 [4104 2963 5142 ... 4360 2491 4627]
 [5004  294 3665 ...  112  953 1385]]
[[0.9788321  0.976992   0.976992   ... 0.9547391  0.95449173 0.95449173]
 [0.9689189  0.9657544  0.95760274 ... 0.9223962  0.9214299  0.9213377 ]
 [0.98584026 0.9829995  0.98169225 ... 0.95928633 0.959041   0.95892525]
 ...
 [0.9161713  0.9158402  0.9153507  ... 0.87390023 0.87360716 0.8734062 ]
 [0.94727147 0.94716656 0.946903   ... 0.9063225  0.9059004  0.90561515]
 [0.97873294 0.97851026 0.9760494  ... 0.9621091  0.96182775 0.9618242 ]]


In [49]:
CAND_IDX = 1

display(truncated_df.loc[query_indices[CAND_IDX]][['name', 'url']])
truncated_df.iloc[top_k_united_indices[CAND_IDX]][['name', 'url']].head()

Unnamed: 0,1
name,Географическая карта России настенная 102х160 ...
url,https://www.ozon.ru/context/detail/id/491270369/


Unnamed: 0,name,url
1935,Федеративное устройство России. Физическая кар...,https://www.ozon.ru/context/detail/id/963112517/
2944,"Карта мира ГеоДом ""Политическая"", Субъекты РФ,...",https://www.ozon.ru/context/detail/id/1798160349/
458,Политическая карта мира АСТ Масштаб 1:40 000 0...,https://www.ozon.ru/context/detail/id/1802254737/
4328,Карта России АСТ Складная. В новых границах. П...,https://www.ozon.ru/context/detail/id/1544611730/
161,Карта России настенная 122х79 см. Новые границ...,https://www.ozon.ru/context/detail/id/178733797/


# Compute paired dataset

## Make pairs for query sku

In [50]:
# get pairs

import pandas as pd

def get_pairs(df, sku, ignore_sku_list=[]):
    """
    Given a target SKU, return a paired DataFrame where:
      - *_first columns correspond to the target SKU row.
      - *_second columns correspond to all other SKU rows.
      - Equality columns (iseq_vendor, iseq_color, iseq_brand, iseq_supp, are_related) are added (all set to 0).

    Parameters:
        sku (int or str): SKU identifier for the target row.

    Returns:
        pd.DataFrame: DataFrame with paired rows.
    """
    # Ensure new_source_df is available in the global scope
    # Select the target row and the remaining rows
    target_df = df[df['sku'] == sku]
    if target_df.empty:
        raise ValueError(f"SKU {sku} not found in new_source_df")
    rest_df = df[~(df['sku'].isin([sku] + ignore_sku_list))]

    # Create a cross join (cartesian product) between the target row and all others
    paired_df_all = pd.merge(
        target_df.assign(key=1),
        rest_df.assign(key=1),
        on='key',
        suffixes=('_first', '_second')
    ).drop('key', axis=1)

    # Add equality columns and set them all to 0
    eq_cols = ['iseq_vendor', 'iseq_color', 'iseq_supp', 'are_related']
    for col in eq_cols:
        paired_df_all[col] = 0

    paired_df_all['iseq_brand'] = paired_df_all['brand_first'] == paired_df_all['brand_second']

    # Define desired final order of columns
    final_columns = [
        'balance_first', 'sales_first', 'rating_first', 'final_price_first',
        'comments_first', 'description_first', 'name_first', 'options_first',
        'sku_first', 'has_video_first', 'photo_count_first',

        'balance_second', 'sales_second', 'rating_second', 'final_price_second',
        'comments_second', 'description_second', 'name_second', 'options_second',
        'sku_second', 'has_video_second', 'photo_count_second',

        'iseq_vendor', 'iseq_color', 'iseq_brand', 'iseq_supp', 'are_related',

        'image_id_first', 'image_id_second',
        'url_first', 'url_second'
    ]

    paired_df_all = paired_df_all[final_columns]
    return paired_df_all

In [51]:
paired_df_all = pd.DataFrame()
for query_idx, top_k_idx in zip(query_indices, top_k_united_indices):
    paired_df = get_pairs(
        truncated_df.loc[top_k_idx.tolist() + [query_idx]], # TODO: fix this crime
        sku=truncated_df.loc[query_idx].sku,
        ignore_sku_list=query_df.sku.tolist(),
    )
    paired_df_all = pd.concat([paired_df_all, paired_df], ignore_index=True)
    # break

paired_df_all.shape

(1150, 31)

In [52]:
from pathlib import Path

n_query = query_df.sku.nunique()
n_nonquery = nonquery_df.sku.nunique()

tables_prefix = Path(DATA_PATH) / 'tables_OZ_geo_5500'
tables_prefix.mkdir(parents=True, exist_ok=True)

# Paired data CSV
file_path_pairs = (
    tables_prefix /
    f'tabular_OZ_geo_5500_top-{TOP_K}'
    f'_query-{n_query}_nonquery-{n_nonquery}_pairs'
    f'_sbert={SBERT_MODEL}_clip={model_name}'
    '.csv'
)

print(str(file_path_pairs))

paired_df_all.to_csv(file_path_pairs, index=None)

data/tables_OZ_geo_5500/tabular_OZ_geo_5500_top-50_query-23_nonquery-5539_pairs_sbert=all-distilroberta-v1_clip=siamese_contrastive.pt.csv


## Add embedding distances

In [53]:
sbert = SentenceTransformer(SBERT_MODEL, device=DEVICE)

In [54]:
from pathlib import Path

# Ensure cache directory exists
emb_prefix = Path(DATA_PATH) / 'embeddings_OZ_geo_5500'
emb_prefix.mkdir(parents=True, exist_ok=True)

n_pairs = len(paired_df_all)  # Key based on number of paired rows; change if needed.

if CLIP_MODEL is not None:
    model_name = CLIP_MODEL
elif CLIP_SIAMESE_CONTRASTIVE_CKPT is not None:
    model_name = CLIP_SIAMESE_CONTRASTIVE_CKPT

# Define cache filenames for description and options similarities
desc_sim_file = emb_prefix / f'{model_name}_desc_sim_pairs-{n_pairs}.npy'
opt_sim_file = emb_prefix / f'{model_name}_opt_sim_pairs-{n_pairs}.npy'

In [99]:
!rm -f {desc_sim_file} {opt_sim_file}

In [101]:
# Compute description and option similarities using filtered indices
# (similar to how image and name sims are computed)

if not os.path.isfile(desc_sim_file) or not os.path.isfile(opt_sim_file):
    desc_sim = []
    opt_sim = []
    # Caches to avoid recomputing SBERT encodings for the same rows
    desc_cache = {}
    opt_cache = {}

    # Loop over each query and its corresponding top-k candidate indices
    for query_idx, candidate_indices in zip(query_indices, top_k_united_indices):
        # Get (and cache) SBERT embedding for the query description
        if query_idx not in desc_cache:
            query_desc = truncated_df.loc[query_idx, 'description']
            query_desc_emb = sbert.encode(query_desc, convert_to_tensor=True, show_progress_bar=Truee)
            desc_cache[query_idx] = query_desc_emb
        else:
            query_desc_emb = desc_cache[query_idx]

        # Get (and cache) SBERT embedding for the query options
        if query_idx not in opt_cache:
            query_opt = truncated_df.loc[query_idx, 'options']
            query_opt_emb = sbert.encode(query_opt, convert_to_tensor=True, show_progress_bar=Truee)
            opt_cache[query_idx] = query_opt_emb
        else:
            query_opt_emb = opt_cache[query_idx]

        # Loop over each candidate index for the current query
        for candidate_idx in candidate_indices:
            # Get (and cache) SBERT embedding for candidate description
            if candidate_idx not in desc_cache:
                cand_desc = truncated_df.loc[candidate_idx, 'description']
                cand_desc_emb = sbert.encode(cand_desc, convert_to_tensor=True, show_progress_bar=Truee)
                desc_cache[candidate_idx] = cand_desc_emb
            else:
                cand_desc_emb = desc_cache[candidate_idx]

            # Get (and cache) SBERT embedding for candidate options
            if candidate_idx not in opt_cache:
                cand_opt = truncated_df.loc[candidate_idx, 'options']
                cand_opt_emb = sbert.encode(cand_opt, convert_to_tensor=True, show_progress_bar=Truee)
                opt_cache[candidate_idx] = cand_opt_emb
            else:
                cand_opt_emb = opt_cache[candidate_idx]

            # Compute cosine similarities for description and options
            sim_desc = util.cos_sim(query_desc_emb, cand_desc_emb).cpu().numpy().squeeze()
            sim_opt = util.cos_sim(query_opt_emb, cand_opt_emb).cpu().numpy().squeeze()

            desc_sim.append(sim_desc)
            opt_sim.append(sim_opt)

    # Convert the lists to numpy arrays
    desc_sim = np.array(desc_sim)
    opt_sim = np.array(opt_sim)

    # Cache the computed similarities
    np.save(desc_sim_file, desc_sim)
    np.save(opt_sim_file, opt_sim)
else:
    # Load cached similarities if available
    desc_sim = np.load(desc_sim_file)
    opt_sim = np.load(opt_sim_file)
    print("Loaded cached description and option similarities.")

print("Description similarities shape:", desc_sim.shape)
print("Option similarities shape:", opt_sim.shape)

Loaded cached description and option similarities.
Description similarities shape: (1150,)
Option similarities shape: (1150,)


In [102]:
def get_images_names_paired(
    df,
    image_id_col_first: str = 'sku_first',
    image_id_col_second: str = 'sku_second',
    name_col_first: str = 'name_first',
    name_col_second: str = 'name_second',
    offline: bool = True,
    img_dataset_dir: str = '../data/images_7k'
) -> Tuple[List[Image.Image], List[object], List[int]]:
    """
    Iterate over the DataFrame and extract image pairs and their corresponding names.

    For each row, two images and their associated names are extracted using get_image_and_name.
    If both images are successfully loaded, they are added to the lists. Otherwise, the row index is recorded as a problem.

    Args:
        df: A DataFrame containing data rows.
        image_id_col_first: Column name for the first image identifier.
        image_id_col_second: Column name for the second image identifier.
        name_col_first: Column name for the first name.
        name_col_second: Column name for the second name.
        offline: Whether to load images offline.
        img_dataset_dir: Directory to load images from when offline.

    Returns:
        A tuple (images, names, problems) where:
            - images is a list containing both images from each successful row,
            - names is a list of the corresponding names,
            - problems is a list of row indices where one or both images failed to load.
    """
    images, names, problems = [], [], []

    for idx, row in df.iterrows():
        img1, name1 = get_image_and_name(row, image_id_col_first, name_col_first, offline, img_dataset_dir)
        img2, name2 = get_image_and_name(row, image_id_col_second, name_col_second, offline, img_dataset_dir)

        if img1 is not None and img2 is not None:
            images.extend([img1, img2])
            names.extend([name1, name2])
        else:
            problems.append(idx)

    return images, names, problems

In [103]:
# # Example usage:
# images, names, problems_ids = get_images_names_paired(
#     paired_df,
#     image_id_col_first='image_id_first',
#     image_id_col_second='image_id_second',
#     name_col_first='name_first',  # Adjust these column names if needed
#     name_col_second='name_second',
#     img_dataset_dir='data/images_OZ_geo_5500'
# )

# print(f'Images loaded: {len(images)}')
# print(f'Images not loaded: {len(problems_ids)}')

In [104]:
desc_sim.shape

(1150,)

In [105]:
# Delete problematic ids
paired_df_all = paired_df_all[~paired_df_all.index.isin(problems_ids)]

desc_sim = np.delete(desc_sim, problems_ids, axis=0)
opt_sim = np.delete(opt_sim, problems_ids, axis=0)

In [108]:
if COMPUTE_FINAL_EMBEDDINGS:
    name_sim = []
    img_sim = []
    for query_idx, top_k_idx in zip(query_indices, top_k_united_indices):
        first = query_final_embs[query_idx]
        for candidate_idx in top_k_idx:
            second = final_latents[candidate_idx]
            name_sim.append(
                util.cos_sim(first, second).cpu().numpy().squeeze()
            )
            img_sim.append(
                util.cos_sim(first, second).cpu().numpy().squeeze()
            )

    print(len(name_sim))
    print(len(img_sim))

    scores = np.c_[desc_sim, opt_sim, name_sim, img_sim]

1150
1150


In [109]:
if not COMPUTE_FINAL_EMBEDDINGS:
    name_sim = []
    img_sim = []
    for query_idx, top_k_idx in zip(query_indices, top_k_united_indices):
        first_name = names_latents[query_idx]  # Use name embeddings for name_sim
        first_image = images_latents[query_idx]  # Use image embeddings for img_sim
        for candidate_idx in top_k_idx:
            second_name = names_latents[candidate_idx]
            second_image = images_latents[candidate_idx]
            name_sim.append(
                util.cos_sim(first_name, second_name).cpu().numpy().squeeze()
            )
            img_sim.append(
                util.cos_sim(first_image, second_image).cpu().numpy().squeeze()
            )

    print(len(name_sim))
    print(len(img_sim))

    scores = np.c_[desc_sim, opt_sim, name_sim, img_sim]

In [110]:
scores_df = pd.DataFrame(scores, columns=['desc_sim', 'opt_sim', 'name_sim', 'img_sim'])

paired_embedded_df_all = pd.concat(
    [
        paired_df_all.drop(columns=scores_df.columns, errors='ignore'),
        scores_df
    ],
    axis=1
)
paired_embedded_df_all.head(1)

Unnamed: 0,balance_first,sales_first,rating_first,final_price_first,comments_first,description_first,name_first,options_first,sku_first,has_video_first,...,iseq_supp,are_related,image_id_first,image_id_second,url_first,url_second,desc_sim,opt_sim,name_sim,img_sim
0,649,370,4.9,807,1719,Карта мира настенная — идеальный помощник для ...,"Карта МИРА настенная политическая,160х102 см, ...","Карта МИРА настенная политическая,160х102 см, ...",491279127,0,...,0,0,7295087132,6893671218,https://www.ozon.ru/context/detail/id/491279127/,https://www.ozon.ru/context/detail/id/1294181688/,0.825724,0.83766,0.978832,0.978832


In [111]:
paired_embedded_df_all.columns.tolist()

['balance_first',
 'sales_first',
 'rating_first',
 'final_price_first',
 'comments_first',
 'description_first',
 'name_first',
 'options_first',
 'sku_first',
 'has_video_first',
 'photo_count_first',
 'balance_second',
 'sales_second',
 'rating_second',
 'final_price_second',
 'comments_second',
 'description_second',
 'name_second',
 'options_second',
 'sku_second',
 'has_video_second',
 'photo_count_second',
 'iseq_vendor',
 'iseq_color',
 'iseq_brand',
 'iseq_supp',
 'are_related',
 'image_id_first',
 'image_id_second',
 'url_first',
 'url_second',
 'desc_sim',
 'opt_sim',
 'name_sim',
 'img_sim']

# Save all files to HF

In [112]:
from pathlib import Path

n_query = query_df.sku.nunique()
n_nonquery = nonquery_df.sku.nunique()

tables_prefix = Path(DATA_PATH) / 'tables_OZ_geo_5500'
tables_prefix.mkdir(parents=True, exist_ok=True)

# Embedded CSV
file_path_embedded = (
    tables_prefix /
    f'tabular_OZ_geo_5500_top-{TOP_K}'
    f'_query-{n_query}_nonquery-{n_nonquery}_embedded'
    f'_sbert={SBERT_MODEL}_clip={model_name}'
    f'{"_final-embs" if COMPUTE_FINAL_EMBEDDINGS else ""}'
    '.csv'
)

print(str(file_path_embedded))

paired_embedded_df_all.to_csv(file_path_embedded, index=None)

data/tables_OZ_geo_5500/tabular_OZ_geo_5500_top-50_query-23_nonquery-5539_embedded_sbert=all-distilroberta-v1_clip=siamese_contrastive.pt_final-embs.csv


In [None]:
from huggingface_hub import HfApi, login

api = HfApi()
api.upload_folder(
    folder_path=DATA_PATH,  # Path to the local directory
    repo_id="INDEEPA/clip-siamese",
    repo_type="dataset",
    ignore_patterns=['**/*.jpg', "**/*.webp"]
)

siamese_contrastive.pt_desc_sim_pairs-1150.npy:   0%|          | 0.00/4.73k [00:00<?, ?B/s]

siamese_contrastive.pt_final_latents_query-23_nonquery-5539.npy:   0%|          | 0.00/10.3M [00:00<?, ?B/s]

Upload 3 LFS files:   0%|          | 0/3 [00:00<?, ?it/s]

siamese_contrastive.pt_opt_sim_pairs-1150.npy:   0%|          | 0.00/4.73k [00:00<?, ?B/s]