# Installs & tokens

In [2]:
%%capture
try:
    from hdbscan import HDBSCAN
except ImportError:
    !pip install hdbscan

In [3]:
%%capture
try:
    import mlflow
except ImportError:
    !pip install mlflow

In [4]:
%%capture
try:
    import dotenv
except ImportError:
    !pip install python-dotenv

In [1]:
# Log into huggingface via Kaggle Secrets or .env

import os
from dotenv import load_dotenv
import huggingface_hub

try:
    from kaggle_secrets import UserSecretsClient

    user_secrets = UserSecretsClient()
    HF_TOKEN = user_secrets.get_secret("HF_TOKEN")
except ModuleNotFoundError:
    print("Not Kaggle environment. Skipping Kaggle secrets.")
    print("Trying to load HF_TOKEN from .env.")
    load_dotenv()
    HF_TOKEN = os.getenv("HF_TOKEN")
    print("Success!")

huggingface_hub.login(token=HF_TOKEN)

Not Kaggle environment. Skipping Kaggle secrets.
Trying to load HF_TOKEN from .env.
Success!


Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


# Choose notebook parameters

In [2]:
# Logging / file management parameters

MODEL_NAME_POSTFIX = 'splitting-by-query'
DATA_PATH = 'data/'
RESULTS_DIR = 'train_results/'

In [4]:
import torch

## CHOOSE MODEL PARAMETERS #################################################

# NAME_MODEL_NAME = 'cointegrated/rubert-tiny' # 'DeepPavlov/distilrubert-tiny-cased-conversational-v1'
# DESCRIPTION_MODEL_NAME = 'cointegrated/rubert-tiny'
# PRELOAD_MODEL_NAME = None

NAME_MODEL_NAME = None
DESCRIPTION_MODEL_NAME = None
PRELOAD_MODEL_NAME = 'cc12m_rubert_tiny_ep_1.pt' # preload ruclip

DROPOUT = 0.5
# DROPOUT = None

BEST_CKPT_METRIC = 'f1'
# BEST_CKPT_METRIC = 'pos_acc'

MOMENTUM=0.9
WEIGHT_DECAY=1e-2
CONTRASTIVE_THRESHOLD=0.3

TEST_RATIO = 0.1
VAL_RATIO = 0.1

DEVICE='cuda' if torch.cuda.is_available() else 'cpu'
RANDOM_SEED=42

In [5]:
# Choose parameters for training

# # GPU config (large)
# TODO: introduce gradient accumulation
# BATCH_SIZE_PER_DEVICE=60
# EPOCHS=10
# POS_NEG_RATIO=1.0
# HARD_SOFT_RATIO=0.7 # TODO: 0.85
# LIMIT_TRAIN_POS_PAIRS_PER_QUERY=5000
# LIMIT_VAL_POS_PAIRS_PER_QUERY=None
# LIMIT_TEST_POS_PAIRS_PER_QUERY=None
# LIMIT_QUERIES = None
# SHEDULER_PATIENCE=1 # in epochs
# LR=3e-5
# CONTRASTIVE_MARGIN=1.0

# # GPU config (small-medium)
# BATCH_SIZE_PER_DEVICE=60
# EPOCHS=20
# POS_NEG_RATIO=1.0
# HARD_SOFT_RATIO=0.7
# LIMIT_TRAIN_POS_PAIRS_PER_QUERY=50 # 50 for small; 500 for medium
# LIMIT_VAL_POS_PAIRS_PER_QUERY=None
# LIMIT_TEST_POS_PAIRS_PER_QUERY=None
# LIMIT_QUERIES = None
# SHEDULER_PATIENCE=3 # in epochs
# LR=9e-5
# CONTRASTIVE_MARGIN=1.5

# CPU smoke test config
BATCH_SIZE_PER_DEVICE=1
EPOCHS=1
POS_NEG_RATIO=1.0
HARD_SOFT_RATIO=0.5
LIMIT_TRAIN_POS_PAIRS_PER_QUERY=2
LIMIT_VAL_POS_PAIRS_PER_QUERY=2
LIMIT_TEST_POS_PAIRS_PER_QUERY=2
LIMIT_QUERIES = 2
SHEDULER_PATIENCE=3 # in epochs
LR=9e-5
CONTRASTIVE_MARGIN=1.5

In [6]:
## CHOOSE DATA #########################################################

DATA_PATH=  'data/'
SOURCE_TABLE_NAME = 'tables_OZ_geo_5500/processed/OZ_geo_5500.csv'

# --- Load source_df and pairwise_mapping_df from Parquet ---
SOURCE_TABLE_NAME = 'tables_OZ_geo_5500/processed/OZ_geo_5500.csv' # TODO: OZ_geo_5500_manual-edit.csv
PAIRWISE_TABLE_NAME = 'tables_OZ_geo_5500/processed/regex-pairwise-groups/regex-pairwise-groups_num-queries=20_patterns-dict-hash=a6223255f273e52a893ba7235e3c19b3/mapping.parquet'
IMG_DATASET_NAME = 'images_OZ_geo_5500'

In [7]:
## LOGGING PARAMS ######################################################################

# MLFLOW_URI = "http://176.56.185.96:5000"
# MLFLOW_URI = "http://localhost:5000"
MLFLOW_URI = None

MLFLOW_EXPERIMENT = "siamese/1fold"

TELEGRAM_TOKEN = None
# TELEGRAM_TOKEN = '' # set token to get notifications

# Definitions

In [19]:
# Imports
import os

# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"

os.environ["TOKENIZERS_PARALLELISM"] = "false"

import mlflow
from mlflow.models import infer_signature

from timm import create_model
import numpy as np
import pandas as pd
import os
import torch
from torch import nn
from torch import optim, Tensor
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import transforms
from torchinfo import summary
# import transformers
# from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer,\
#         get_linear_schedule_with_warmup
from transformers import AutoModel, AutoTokenizer

import cv2

from PIL import Image
from tqdm.auto import tqdm

# import json
# from itertools import product

# import datasets
# from datasets import Dataset, concatenate_datasets
# import argparse
import requests

# from io import BytesIO
# from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import balanced_accuracy_score, roc_auc_score, f1_score
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
# import more_itertools

from sklearn.model_selection import train_test_split
import torch
import torch.nn.functional as F
from pathlib import Path
from tqdm import tqdm
from sklearn.metrics import f1_score
import mlflow
from torch.optim.lr_scheduler import ReduceLROnPlateau
import matplotlib.pyplot as plt
import tempfile

In [20]:
def make_tg_report(text, token=None) -> None:
    method = 'sendMessage'
    chat_id = 324956476
    _ = requests.post(
            url='https://api.telegram.org/bot{0}/{1}'.format(token, method),
            data={'chat_id': chat_id, 'text': text} 
        ).json()

In [21]:
class RuCLIPtiny(nn.Module):
    def __init__(self, name_model_name):
        super().__init__()
        self.visual = create_model('convnext_tiny',
                                   pretrained=False, # TODO: берём претрейн
                                   num_classes=0,
                                   in_chans=3)  # out 768

        self.transformer = AutoModel.from_pretrained(name_model_name)
        name_model_output_shape = self.transformer.config.hidden_size  # dynamically get hidden size
        self.final_ln = torch.nn.Linear(name_model_output_shape, 768)  # now uses the transformer hidden size
        self.logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    @property
    def dtype(self):
        return self.visual.stem[0].weight.dtype

    def encode_image(self, image):
        return self.visual(image.type(self.dtype))

    def encode_text(self, input_ids, attention_mask):
        x = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        x = x.last_hidden_state[:, 0, :]
        x = self.final_ln(x)
        return x

    def forward(self, image, input_ids, attention_mask):
        image_features = self.encode_image(image)
        text_features = self.encode_text(input_ids, attention_mask)

        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        return logits_per_image, logits_per_text

In [22]:
def get_transform():
    return transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        _convert_image_to_rgb,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]), ])

def _convert_image_to_rgb(image):
    return image.convert("RGB")

class Tokenizers:
    def __init__(self):
        self.name_tokenizer = AutoTokenizer.from_pretrained(NAME_MODEL_NAME)
        self.desc_tokenizer = AutoTokenizer.from_pretrained(DESCRIPTION_MODEL_NAME)

    def tokenize_name(self, texts, max_len=77):
        tokenized = self.name_tokenizer.batch_encode_plus(texts,
                                                     truncation=True,
                                                     add_special_tokens=True,
                                                     max_length=max_len,
                                                     padding='max_length',
                                                     return_attention_mask=True,
                                                     return_tensors='pt')
        return torch.stack([tokenized["input_ids"], tokenized["attention_mask"]], dim=1)
    
    def tokenize_description(self, texts, max_len=77):
        tokenized = self.desc_tokenizer(texts,
                                        truncation=True,
                                        add_special_tokens=True,
                                        max_length=max_len,
                                        padding='max_length',
                                        return_attention_mask=True,
                                        return_tensors='pt')
        return torch.stack([tokenized["input_ids"], tokenized["attention_mask"]], dim=1)

class SiameseRuCLIPDataset(torch.utils.data.Dataset):
    def __init__(self, df=None, labels=None, df_path=None, images_dir=DATA_PATH+'images/'):
        # loads data either from path using `df_path` or directly from `df` argument
        self.df = pd.read_csv(df_path) if df_path is not None else df
        self.labels = labels
        self.images_dir = images_dir
        self.tokenizers = Tokenizers()
        self.transform = get_transform()
        # 
        self.max_len = 77
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        name_tokens = self.tokenizers.tokenize_name([str(row.name_first), 
                                               str(row.name_second)], max_len=self.max_len)
        name_first = name_tokens[:, 0, :] # [input_ids, attention_mask]
        name_second = name_tokens[:, 1, :]
        desc_tokens = self.tokenizers.tokenize_description([str(row.description_first), 
                                               str(row.description_second)])
        desc_first = desc_tokens[:, 0, :] # [input_ids, attention_mask]
        desc_second = desc_tokens[:, 1, :]
        im_first = cv2.imread(os.path.join(self.images_dir, row.image_name_first))
        im_first = cv2.cvtColor(im_first, cv2.COLOR_BGR2RGB)
        im_first = Image.fromarray(im_first)
        im_first = self.transform(im_first)
        im_second = cv2.imread(os.path.join(self.images_dir, row.image_name_second))
        im_second = cv2.cvtColor(im_second, cv2.COLOR_BGR2RGB)
        im_second = Image.fromarray(im_second)
        im_second = self.transform(im_second)
        label = self.labels[idx]
        return im_first, name_first, desc_first, im_second, name_second, desc_second, label

    def __len__(self,):
        return len(self.df)

In [23]:
def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

class SiameseRuCLIP(nn.Module):
    def __init__(self,
                 device: str,
                 name_model_name: str,
                 description_model_name: str,
                 preload_model_name: str = None,
                 models_dir: str = None,
                 dropout: float = None):
        """
        Initializes the SiameseRuCLIP model.
        Required parameters:
          - models_dir: directory containing saved checkpoints.
          - name_model_name: model name for text (name) branch.
          - description_model_name: model name for description branch.
        """
        super().__init__()
        device = torch.device(device)

        # Initialize RuCLIPtiny
        self.ruclip = RuCLIPtiny(name_model_name)
        if preload_model_name is not None:
            std = torch.load(
                os.path.join(models_dir, preload_model_name),
                weights_only=True,
                map_location=device
            )
            self.ruclip.load_state_dict(std)
            self.ruclip.eval()
        self.ruclip = self.ruclip.to(device)

        # Initialize the description transformer
        self.description_transformer = AutoModel.from_pretrained(description_model_name)
        self.description_transformer = self.description_transformer.to(device)

        # Determine dimensionality
        vision_dim = self.ruclip.visual.num_features
        name_dim = self.ruclip.final_ln.out_features
        desc_dim = self.description_transformer.config.hidden_size
        self.hidden_dim = vision_dim + name_dim + desc_dim
        self.dropout = dropout

        # Define MLP head with optional dropout
        layers = [
            nn.Linear(self.hidden_dim, self.hidden_dim // 2),
            nn.ReLU(),
            *( [nn.Dropout(self.dropout)] if self.dropout is not None else [] ),
            nn.Linear(self.hidden_dim // 2, self.hidden_dim // 4),
        ]
        self.head = nn.Sequential(*layers).to(device)


    def encode_image(self, image):
        return self.ruclip.encode_image(image)

    def encode_name(self, name):
        return self.ruclip.encode_text(name[:, 0, :], name[:, 1, :])

    def encode_description(self, desc):
        last_hidden_states = self.description_transformer(desc[:, 0, :], desc[:, 1, :]).last_hidden_state
        attention_mask = desc[:, 1, :]
        return average_pool(last_hidden_states, attention_mask)

    def get_final_embedding(self, im, name, desc):
        image_emb = self.encode_image(im)
        name_emb = self.encode_name(name)
        desc_emb = self.encode_description(desc)

        # Concatenate the embeddings and forward through the head
        combined_emb = torch.cat([image_emb, name_emb, desc_emb], dim=1)
        final_embedding = self.head(combined_emb)
        return final_embedding

    def forward(self, im1, name1, desc1, im2, name2, desc2):
        out1 = self.get_final_embedding(im1, name1, desc1)
        out2 = self.get_final_embedding(im2, name2, desc2)
        return out1, out2

In [24]:
class ContrastiveLoss(torch.nn.Module):
    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        
    def __name__(self,):
        return 'ContrastiveLoss'

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2)
        pos = (1-label) * torch.pow(euclidean_distance, 2)
        neg = label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
        loss_contrastive = torch.mean( pos + neg )
        return loss_contrastive

In [25]:
# TODO: plot epoch after each train epoch in `train()`

from pathlib import Path

def plot_epoch(loss_history, filename="data/runs_artifacts/epoch_loss.png") -> None:
    Path(filename).parent.mkdir(parents=True, exist_ok=True)
    clear_output(wait=True)
    plt.figure(figsize=(6, 4))
    plt.title("Training loss")
    plt.xlabel("Iteration number")
    plt.ylabel("Loss")
    plt.plot(loss_history, 'b')
    plt.tight_layout()
    plt.savefig(filename)  # Save the plot to a file
    plt.show()

In [26]:
def evaluate_pair(output1, output2, target, threshold):
    euclidean_distance = F.pairwise_distance(output1, output2)
    # меньше границы, там где будет True — конкуренты
    cond = euclidean_distance < threshold
    pos_sum = 0
    neg_sum = 0
    pos_acc = 0
    neg_acc = 0

    for i in range(len(cond)):
        # 1 значит не конкуренты
        if target[i]:
            neg_sum+=1
            # 0 в cond значит дальше друг от друга чем threshold
            if not cond[i]:
                neg_acc+=1
        elif not target[i]:
            pos_sum+=1
            if cond[i]:
                pos_acc+=1

    return pos_acc, pos_sum, neg_acc, neg_sum

def predict(out1, out2, threshold=CONTRASTIVE_THRESHOLD):
    # вернёт 1 если похожи
    return F.pairwise_distance(out1, out2) < threshold

In [27]:
def sku_to_model_inputs(sku_list, source_df, images_dir, tokenizers, transform):
    """
    Convert list of SKUs to model inputs (images, names, descriptions).
    This is the key function that bridges SKU IDs to actual model data.
    """
    # Get data from source_df
    sku_data = source_df.loc[source_df['sku'].isin(sku_list)].set_index('sku')
    
    images = []
    names = []
    descriptions = []
    
    for sku in sku_list:
        if sku in sku_data.index:
            row = sku_data.loc[sku]
            
            # Load and transform image
            import cv2
            from PIL import Image
            img_path = os.path.join(images_dir, row['image_name'])
            img = cv2.imread(img_path)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = Image.fromarray(img)
            img = transform(img)
            images.append(img)
            
            # Get text data
            names.append(str(row['name']))
            descriptions.append(str(row['description']))
        else:
            # Handle missing SKU - create dummy data
            images.append(torch.zeros(3, 224, 224))  # Dummy image
            names.append("Unknown Product")
            descriptions.append("No description available")
    
    # Stack images and tokenize text
    images = torch.stack(images)
    name_tokens = tokenizers.tokenize_name(names)
    desc_tokens = tokenizers.tokenize_description(descriptions)
    
    return images, name_tokens, desc_tokens


def train_with_new_dataset(model, optimizer, criterion, epochs_num, train_loader, 
                          valid_loader=None, device='cpu', print_epoch=False, 
                          models_dir=None, metric='f1', source_df=None, images_dir=None):
    """
    Updated training function that handles PairGenerationDataset batch format.
    """
    from copy import deepcopy
    from timm import create_model
    from transformers import AutoModel, AutoTokenizer
    
    assert metric in ('f1', 'pos_acc'), "metric must be 'f1' or 'pos_acc'"

    model.to(device)
    train_losses, val_losses, thr_history = [], [], []
    best_valid_metric, best_threshold = float('-inf'), None
    best_weights = None

    # Initialize tokenizers and transform for SKU-to-input conversion
    tokenizers = Tokenizers()  # Your existing tokenizer class
    transform = get_transform()   # Your existing transform function

    scheduler = ReduceLROnPlateau(
        optimizer,
        mode="max",
        factor=0.1,
        patience=SHEDULER_PATIENCE,
        threshold=1e-4,
        threshold_mode='rel'
    )

    if models_dir:
        Path(models_dir).mkdir(parents=True, exist_ok=True)

    for epoch in range(1, epochs_num + 1):
        # ---- training ----
        model.train()
        total_train_loss = 0.0
        
        for batch in tqdm(train_loader, desc=f"Training Epoch {epoch}"):
            # Handle PairGenerationDataset format (dict) vs old format (tuple)
            if isinstance(batch, dict):
                # New format: batch is a dict with keys like 'sku_first', 'sku_second', 'label'
                sku_first = batch['sku_first']
                sku_second = batch['sku_second'] 
                labels = batch['label'].to(device)
                
                # Convert SKUs to model inputs
                im1, n1, d1 = sku_to_model_inputs(sku_first.tolist(), source_df, images_dir, tokenizers, transform)
                im2, n2, d2 = sku_to_model_inputs(sku_second.tolist(), source_df, images_dir, tokenizers, transform)
                
                # Move to device
                im1, n1, d1 = im1.to(device), n1.to(device), d1.to(device)
                im2, n2, d2 = im2.to(device), n2.to(device), d2.to(device)
                
            else:
                # Old format: tuple of tensors
                im1, n1, d1, im2, n2, d2, labels = [t.to(device) for t in batch]
            
            optimizer.zero_grad()
            out1, out2 = model(im1, n1, d1, im2, n2, d2)
            loss = criterion(out1, out2, labels)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
            
        train_losses.append(total_train_loss / len(train_loader))

        # ---- evaluation & checkpointing ----
        if print_epoch and valid_loader is not None:
            pos_acc, neg_acc, avg_acc, f1_val, val_loss, val_thr = evaluation(
                model, criterion, valid_loader, epoch, device=device,
                split_name='val', threshold=None, margin=CONTRASTIVE_MARGIN,
                steps=200, metric=metric, source_df=source_df, images_dir=images_dir
            )
            val_losses.append(val_loss)
            thr_history.append(val_thr)

            # pick the metric value to step & compare
            cur_metric = pos_acc if metric == 'pos_acc' else f1_val
            scheduler.step(cur_metric)

            # save checkpoint every epoch if requested
            if models_dir:
                torch.save(model.state_dict(),
                           Path(models_dir) / f"checkpoint_epoch_{epoch}.pt")

            # update best if improved
            if cur_metric > best_valid_metric:
                best_valid_metric = cur_metric
                best_threshold     = val_thr
                best_weights       = deepcopy(model.state_dict())

        print(f'Epoch {epoch} done.')

    print(f"Best evaluation {metric}: {best_valid_metric:.3f}  (thr={best_threshold:.3f})")
    return train_losses, val_losses, best_valid_metric, best_weights, thr_history

In [28]:
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import f1_score
import mlflow
from copy import deepcopy
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pathlib import Path

def evaluation(model, criterion, data_loader, epoch, device='cpu',
              split_name='val', threshold=None, margin=1.5,
              steps=200, metric='f1', source_df=None, images_dir=None,
              precompute_pairs=False):
    """
    Evaluation function that handles both precomputed and on-the-fly data formats.
    
    Args:
        precompute_pairs: If True, expects PrecomputedPairDataset format with pre-loaded tensors.
                         If False, expects PairGenerationDataset format and converts SKUs to tensors.
    """
    assert metric in ('f1', 'pos_acc'), "metric must be 'f1' or 'pos_acc'"
    assert split_name in ('val', 'test'), "split_name must be 'val' or 'test'"

    model.eval()
    total_loss = 0.0
    all_d, all_lbl = [], []

    # Only initialize tokenizers if not using precomputed pairs
    if not precompute_pairs:
        tokenizers = Tokenizers()
        transform = get_transform()

    with torch.no_grad():
        for batch in tqdm(data_loader, desc=f"Evaluation on {split_name}"):
            if precompute_pairs:
                # PrecomputedPairDataset format - data is already tensors
                im1 = batch['image_first'].to(device)
                n1 = batch['name_first'].to(device)
                d1 = batch['desc_first'].to(device)
                im2 = batch['image_second'].to(device)
                n2 = batch['name_second'].to(device)
                d2 = batch['desc_second'].to(device)
                labels = batch['label'].to(device)
            else:
                # PairGenerationDataset format - convert SKUs to tensors
                sku_first = batch['sku_first']
                sku_second = batch['sku_second']
                labels = batch['label'].to(device)
                
                # Convert SKUs to model inputs
                im1, n1, d1 = sku_to_model_inputs(sku_first.tolist(), source_df, images_dir, tokenizers, transform)
                im2, n2, d2 = sku_to_model_inputs(sku_second.tolist(), source_df, images_dir, tokenizers, transform)
                
                # Move to device
                im1, n1, d1 = im1.to(device), n1.to(device), d1.to(device)
                im2, n2, d2 = im2.to(device), n2.to(device), d2.to(device)
            
            out1, out2 = model(im1, n1, d1, im2, n2, d2)
            total_loss += criterion(out1, out2, labels).item()
            all_d.append(F.pairwise_distance(out1, out2).cpu())
            all_lbl.append(labels.cpu())

    distances = torch.cat(all_d)
    labels = torch.cat(all_lbl)
    avg_loss = total_loss / len(data_loader)

    # === threshold sweep ===
    if threshold is None:
        grid = np.linspace(0.0, margin, steps)
        best_val, best_thr = -1.0, 0.0
        y_true = (labels.numpy() == 0).astype(int)
        for t in grid:
            y_pred = (distances.numpy() < t).astype(int)
            if metric == 'f1':
                val = f1_score(y_true, y_pred, zero_division=0)
            else:  # metric == 'pos_acc'
                pos_mask = (y_true == 1)
                val = (y_pred[pos_mask] == 1).mean() if pos_mask.sum() > 0 else 0.0
            if val > best_val:
                best_val, best_thr = val, t
        threshold = best_thr
    else:
        best_thr = threshold

    # === final metrics at chosen threshold ===
    preds = (distances < threshold).long()
    pos_mask = (labels == 0)
    neg_mask = (labels == 1)

    pos_acc = (preds[pos_mask] == 1).float().mean().item() if pos_mask.any() else 0.0
    neg_acc = (preds[neg_mask] == 0).float().mean().item() if neg_mask.any() else 0.0
    avg_acc = (pos_acc + neg_acc) / 2.0
    f1 = f1_score((labels.numpy() == 0).astype(int),
                  preds.numpy(), zero_division=0)

    # log to console
    report = (f"[{split_name}] Epoch {epoch} – "
              f"loss: {avg_loss:.4f}, "
              f"P Acc: {pos_acc:.3f}, "
              f"N Acc: {neg_acc:.3f}, "
              f"Avg Acc: {avg_acc:.3f}, "
              f"F1: {f1:.3f}, "
              f"thr*: {threshold:.3f} "
              f"(optimised: {metric})")
    print(report)
    make_tg_report(report, TELEGRAM_TOKEN)

    # log to MLflow
    if MLFLOW_URI and split_name == 'val':
        if metric == 'f1':
            mlflow.log_metric("valid_f1_score", f1, step=epoch)
        else:
            mlflow.log_metric("valid_pos_accuracy", pos_acc, step=epoch)

    return pos_acc, neg_acc, avg_acc, f1, avg_loss, threshold

In [29]:
from time import perf_counter
from datetime import timedelta

def train(model,
          optimizer,
          criterion,
          epochs_num,
          train_loader,
          valid_loader=None,
          device='cpu',
          print_epoch=False,
          models_dir=None,
          metric='f1'):
    """
    Trains for `epochs_num` epochs, using `evaluation(..., metric=metric)` each epoch.
    Uses the same `metric` to step the LR scheduler and to pick the best checkpoint.

    Returns:
      train_losses, val_losses, best_valid_metric, best_weights, thr_history
    """
    assert metric in ('f1', 'pos_acc'), "metric must be 'f1' or 'pos_acc'"

    model.to(device)
    train_losses, val_losses, thr_history = [], [], []
    best_valid_metric, best_threshold = float('-inf'), None
    best_weights = None

    scheduler = ReduceLROnPlateau(
        optimizer,
        mode="max",
        factor=0.1,
        patience=SHEDULER_PATIENCE,
        threshold=1e-4,
        threshold_mode='rel'
    )

    if models_dir:
        Path(models_dir).mkdir(parents=True, exist_ok=True)

    for epoch in range(1, epochs_num + 1):
        # ---- training ----
        model.train()
        total_train_loss = 0.0
        for batch in train_loader:
            im1, n1, d1, im2, n2, d2, lbl = [t.to(device) for t in batch]
            optimizer.zero_grad()
            out1, out2 = model(im1, n1, d1, im2, n2, d2)
            loss = criterion(out1, out2, lbl)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
        train_losses.append(total_train_loss / len(train_loader))

        # ---- evaluation & checkpointing ----
        if print_epoch and valid_loader is not None:
            pos_acc, neg_acc, avg_acc, f1_val, val_loss, val_thr = evaluation(
                model,
                criterion,
                valid_loader,
                epoch,
                device=device,
                split_name='val',
                threshold=None,
                margin=CONTRASTIVE_MARGIN,
                steps=200,
                metric=metric
            )
            val_losses.append(val_loss)
            thr_history.append(val_thr)

            # pick the metric value to step & compare
            cur_metric = pos_acc if metric == 'pos_acc' else f1_val
            scheduler.step(cur_metric)

            # save checkpoint every epoch if requested
            if models_dir:
                torch.save(model.state_dict(),
                           Path(models_dir) / f"checkpoint_epoch_{epoch}.pt")

            # update best if improved
            if cur_metric > best_valid_metric:
                best_valid_metric = cur_metric
                best_threshold     = val_thr
                best_weights       = deepcopy(model.state_dict())

        print(f'Epoch {epoch} done.')

    print(f"Best evaluation {metric}: {best_valid_metric:.3f}  (thr={best_threshold:.3f})")
    return train_losses, val_losses, best_valid_metric, best_weights, thr_history

# Prepare data

## Download data from HF

In [47]:
# Download models' weights & text/image datasets

from huggingface_hub import snapshot_download
from pathlib import Path

REPO_ID = "INDEEPA/clip-siamese"
LOCAL_DIR = Path("data/train_results")
LOCAL_DIR.mkdir(parents=True, exist_ok=True)

snapshot_download(
    repo_id=REPO_ID,
    repo_type='dataset',
    local_dir='data',
    allow_patterns=[
        f"train_results/{PRELOAD_MODEL_NAME}",
        SOURCE_TABLE_NAME,
        PAIRWISE_TABLE_NAME,
        f"{IMG_DATASET_NAME}.zip"
    ],
)

# Unzip the image dataset
import zipfile
with zipfile.ZipFile(f"data/{IMG_DATASET_NAME}.zip", 'r') as zip_ref:
    zip_ref.extractall("data/")

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

In [84]:
source_df = pd.read_csv(DATA_PATH + SOURCE_TABLE_NAME)
source_df.columns.tolist()

['sku',
 'description',
 'image_url',
 'name',
 'category',
 'схема',
 'brand',
 'niche',
 'seller',
 'balance_fbo',
 'balance_fbs',
 'warehouses_count',
 'comments',
 'final_price',
 'max_price',
 'min_price',
 'average_price',
 'median_price',
 'membership_card_price',
 'sales',
 'revenue',
 'revenue_potential',
 'revenue_average',
 'lost_profit',
 'lost_profit_percent',
 'url',
 'thumb',
 'pics_count',
 'has_video',
 'first_date',
 'days_in_website',
 'days_in_stock',
 'days_with_sales',
 'average_if_in_stock',
 'rating',
 'fbs',
 'base_price',
 'category_position',
 'categories_last_count',
 'sales_per_day_average',
 'sales.1',
 'frozen_stocks',
 'frozen_stocks_cost',
 'frozen_stocks_percent',
 'balance',
 'image_name']

In [85]:
pairwise_mapping_df = pd.read_parquet(DATA_PATH + PAIRWISE_TABLE_NAME)
pairwise_mapping_df.columns.tolist()

['sku_query', 'sku_pos', 'sku_hard_neg', 'sku_soft_neg']

# Cluster soft negatives

In [86]:
# List embeddings files in repo
FILTER_STRING = 'name-and-description_embeddings'

from huggingface_hub import list_repo_files

emb_files = [name for name in list_repo_files("INDEEPA/clip-siamese", repo_type="dataset") if FILTER_STRING in name and "OZ_geo_5500" in name]
for file in emb_files:
    display(file)

'embeddings/OZ_geo_5500/OZ_geo_5500_name-and-description_embeddings_num-rows=2.parquet'

'embeddings/OZ_geo_5500/OZ_geo_5500_name-and-description_embeddings_num-rows=5562.parquet'

In [87]:
# Suggest the correct path to the embedding file based on the context and previous file saving logic
CHOSEN_EMBEDDING_FILE = 'OZ_geo_5500_name-and-description_embeddings_num-rows=5562.parquet'

In [88]:
from huggingface_hub import hf_hub_download
import pandas as pd

# Download the chosen embedding file from HuggingFace Hub to DATA_PATH
from pathlib import Path

downloaded_emb_file = hf_hub_download(
    repo_id="INDEEPA/clip-siamese",
    repo_type="dataset",
    filename=f'embeddings/OZ_geo_5500/{CHOSEN_EMBEDDING_FILE}',
    local_dir=DATA_PATH,
)

print(f"Downloaded embedding file to:\n{downloaded_emb_file}")
emb_table = pd.read_parquet(downloaded_emb_file)
emb_table.head()

Downloaded embedding file to:
data/embeddings/OZ_geo_5500/OZ_geo_5500_name-and-description_embeddings_num-rows=5562.parquet


Unnamed: 0,sku,name_desc_emb
0,1871769771,"[-0.020089346915483475, -0.05487045273184776, ..."
1,1679550303,"[-0.00418242160230875, -0.04088427498936653, 0..."
2,1200553001,"[-0.023978281766176224, -0.05447990819811821, ..."
3,922231521,"[-0.024106157943606377, -0.053567297756671906,..."
4,922230517,"[-0.02229023538529873, -0.05309479311108589, -..."


In [89]:
from hdbscan import HDBSCAN
import numpy as np

# Prepare the embeddings as a numpy array
embeddings = np.stack(emb_table['name_desc_emb'].values)

# Run HDBSCAN clustering using sklearn's implementation
clusterer = HDBSCAN(
    min_samples=2,
    # metric='cosine',
)

cluster_labels = clusterer.fit_predict(embeddings)

# Add cluster labels to the emb_table and assign to cluster_emb_table
cluster_emb_table = emb_table.copy()
cluster_emb_table['cluster_id'] = cluster_labels

# Print cluster label counts
print("Cluster label counts:")
display(cluster_emb_table['cluster_id'].value_counts().to_frame().T)



Cluster label counts:


cluster_id,-1,357,1,301,92,235,355,345,163,193,...,35,287,282,295,33,272,336,381,192,435
count,1285,62,47,45,42,37,31,30,30,30,...,5,5,5,5,5,5,5,5,5,5


In [90]:
# Print cluster ids with size > N
N = 4  # You can change N to any desired threshold
cluster_counts = cluster_emb_table['cluster_id'].value_counts()
large_clusters = cluster_counts[cluster_counts > N].to_frame()
print(f"Cluster IDs with size > {N}:")
display(large_clusters.T)


Cluster IDs with size > 4:


cluster_id,-1,357,1,301,92,235,355,345,163,193,...,35,287,282,295,33,272,336,381,192,435
count,1285,62,47,45,42,37,31,30,30,30,...,5,5,5,5,5,5,5,5,5,5


In [91]:
# Print SKUs for a given CLUSTER_ID
CLUSTER_ID = 272  # Change this to the desired cluster id

skus_in_cluster = cluster_emb_table.loc[cluster_emb_table['cluster_id'] == CLUSTER_ID, 'sku']
print(f"SKUs in cluster {CLUSTER_ID}:")
display(skus_in_cluster.tolist()[:10])

SKUs in cluster 272:


[1913636945, 1893883403, 1857098848, 1853440559, 1783442338]

# Make train/val/test splits

In [92]:
def split_query_groups(
    mapping_df: pd.DataFrame,
    test_size: float = 0.2,
    val_size: float = 0.05,
    random_state: int = 42,
    min_positives_for_3way: int = 6  # minimum positives (including query) for 3-way split
):
    """
    Adaptive splitting based on number of positives (not including query):
    - 6+ positives: normal train/val/test split with query in test
    - 5 positives: pos1,pos2 to train; pos3,pos4 to val; q,pos5 to test
    - 4 positives: pos1,pos2 to train; pos3,pos4 to val; q,pos3,pos4 to test
    - 3 positives: pos1,pos2 to train; pos2,pos3 to val; q,pos3 to test
    - 2 positives: pos1,pos2 to train; q,pos1 to val; q,pos2 to test
    - 1 positive: copy q,pos to all splits
    - 0 positives: skip
    """
    rng = np.random.default_rng(random_state)
    split_rows = []

    for _, row in mapping_df.iterrows():
        q = row['sku_query']
        pos_without_query = list(set(row['sku_pos']) - {q})  # Convert to list for indexing
        hard_neg = set(row['sku_hard_neg']) - {q}
        soft_neg = set(row['sku_soft_neg']) - {q}
        
        total_positives = len(pos_without_query)  # Not including query

        def split_list(lst, test_frac, val_frac=None):
            lst = np.array(list(lst))
            n = len(lst)
            if val_frac is None:  # 2-way split
                n_test = int(np.ceil(test_frac * n))
                idx = rng.permutation(n)
                test_idx = idx[:n_test]
                train_idx = idx[n_test:]
                return lst[train_idx].tolist(), [], lst[test_idx].tolist()
            else:  # 3-way split
                n_test = int(np.ceil(test_frac * n))
                n_val = int(np.ceil(val_frac * n))
                idx = rng.permutation(n)
                test_idx = idx[:n_test]
                val_idx = idx[n_test:n_test+n_val]
                train_idx = idx[n_test+n_val:]
                return lst[train_idx].tolist(), lst[val_idx].tolist(), lst[test_idx].tolist()

        if total_positives >= 6:
            # 6+ positives: normal train/val/test split
            pos_train, pos_val, pos_test = split_list(pos_without_query, test_size, val_size)
            hard_train, hard_val, hard_test = split_list(hard_neg, test_size, val_size)
            soft_train, soft_val, soft_test = split_list(soft_neg, test_size, val_size)
            pos_test.append(q)  # query goes to test
            
            splits_to_create = ['train', 'val', 'test']
            pos_lists = [pos_train, pos_val, pos_test]
            hard_lists = [hard_train, hard_val, hard_test]
            soft_lists = [soft_train, soft_val, soft_test]
            
        elif total_positives == 5:
            # 5 pos: pos1,pos2 to train; pos3,pos4 to val; q,pos5 to test
            rng.shuffle(pos_without_query)  # randomize order
            pos1, pos2, pos3, pos4, pos5 = pos_without_query[:5]
            
            # Split negatives proportionally
            hard_train, hard_val, hard_test = split_list(hard_neg, test_size, val_size)
            soft_train, soft_val, soft_test = split_list(soft_neg, test_size, val_size)
            
            splits_to_create = ['train', 'val', 'test']
            pos_lists = [
                [pos1, pos2],      # train
                [pos3, pos4],      # val
                [q, pos5]          # test
            ]
            hard_lists = [hard_train, hard_val, hard_test]
            soft_lists = [soft_train, soft_val, soft_test]
            
        elif total_positives == 4:
            # 4 pos: pos1,pos2 to train; pos3,pos4 to val; q,pos3,pos4 to test
            rng.shuffle(pos_without_query)  # randomize order
            pos1, pos2, pos3, pos4 = pos_without_query[:4]
            
            # Split negatives proportionally
            hard_train, hard_val, hard_test = split_list(hard_neg, test_size, val_size)
            soft_train, soft_val, soft_test = split_list(soft_neg, test_size, val_size)
            
            splits_to_create = ['train', 'val', 'test']
            pos_lists = [
                [pos1, pos2],          # train
                [pos3, pos4],          # val
                [q, pos3, pos4]        # test (overlaps with val)
            ]
            hard_lists = [hard_train, hard_val, hard_test]
            soft_lists = [soft_train, soft_val, soft_test]
            
        elif total_positives == 3:
            # 3 pos: pos1,pos2 to train; pos2,pos3 to val; q,pos3 to test
            rng.shuffle(pos_without_query)  # randomize order
            pos1, pos2, pos3 = pos_without_query[:3]
            
            # Split negatives proportionally
            hard_train, hard_val, hard_test = split_list(hard_neg, test_size, val_size)
            soft_train, soft_val, soft_test = split_list(soft_neg, test_size, val_size)
            
            splits_to_create = ['train', 'val', 'test']
            pos_lists = [
                [pos1, pos2],      # train
                [pos2, pos3],      # val (overlaps with train)
                [q, pos3]          # test (overlaps with val)
            ]
            hard_lists = [hard_train, hard_val, hard_test]
            soft_lists = [soft_train, soft_val, soft_test]
            
        elif total_positives == 2:
            # 2 pos: pos1,pos2 to train; q,pos1 to val; q,pos2 to test
            rng.shuffle(pos_without_query)  # randomize order
            pos1, pos2 = pos_without_query[:2]
            
            # Split negatives proportionally
            hard_train, hard_val, hard_test = split_list(hard_neg, test_size, val_size)
            soft_train, soft_val, soft_test = split_list(soft_neg, test_size, val_size)
            
            splits_to_create = ['train', 'val', 'test']
            pos_lists = [
                [pos1, pos2],      # train: both positives
                [q, pos1],         # val: query + first positive
                [q, pos2]          # test: query + second positive
            ]
            hard_lists = [hard_train, hard_val, hard_test]
            soft_lists = [soft_train, soft_val, soft_test]
            
        elif total_positives == 1:
            # 1 pos: copy q,pos to all splits
            pos = pos_without_query[0]
            
            # Split negatives proportionally
            hard_train, hard_val, hard_test = split_list(hard_neg, test_size, val_size)
            soft_train, soft_val, soft_test = split_list(soft_neg, test_size, val_size)
            
            splits_to_create = ['train', 'val', 'test']
            pos_lists = [
                [q, pos],  # train: query + positive
                [q, pos],  # val: same pair
                [q, pos]   # test: same pair
            ]
            hard_lists = [hard_train, hard_val, hard_test]
            soft_lists = [soft_train, soft_val, soft_test]
            
        else:
            # Skip queries with no positives
            continue

        # Create the split rows
        for split_name, pos_list, hard_list, soft_list in zip(
            splits_to_create, pos_lists, hard_lists, soft_lists
        ):
            split_rows.append({
                'sku_query': q,
                'split': split_name,
                'sku_pos': pos_list,
                'sku_hard_neg': hard_list,
                'sku_soft_neg': soft_list
            })

    split_df = pd.DataFrame(split_rows)
    split_dict = {
        split: split_df[split_df['split'] == split].reset_index(drop=True)
        for split in ['train', 'val', 'test'] if split in split_df['split'].values
    }
    return split_dict

In [93]:
splits_dataset = split_query_groups(
    pairwise_mapping_df,
    test_size=0.1,
    val_size=0.1,
    random_state=42
)

pd.reset_option('display.max_colwidth')
splits_dataset['test'].head()

Unnamed: 0,sku_query,split,sku_pos,sku_hard_neg,sku_soft_neg
0,1871769771,test,"[467396304, 1871769771]","[1873027006, 1166886051, 601557370]","[1899881468, 1290396077, 1597431764, 165269677..."
1,1200553001,test,"[945075396, 1436509994, 1436449707, 1438364324...","[1499532091, 963112482, 1422204647, 1122827873...","[1878150702, 1901123430, 1595672507, 679265327..."
2,922231521,test,"[1436509994, 1158222448, 1081199697, 490461399...","[1001260979, 1802254834, 1252814277, 805782980...","[1032263980, 879403681, 1816716304, 1630407222..."
3,922230517,test,"[600803111, 1125093440, 1726148392, 974286048,...","[564434635, 1449544071, 1333611366, 1294181877...","[1807617650, 1113350792, 1245721824, 620961901..."
4,922230183,test,"[1819952117, 1679157969, 914654189, 922230183]","[959054273, 601557360, 1705669581, 950215375, ...","[1634447035, 1706808534, 1438798026, 181995203..."


In [94]:
# Compute avg, min, max count of soft negatives per query (per split) and output as table
import pandas as pd

soft_neg_stats = []
for split_name, df in splits_dataset.items():
    soft_neg_counts = df['sku_soft_neg'].apply(lambda x: len(x) if isinstance(x, list) else 0)
    avg_soft_neg = soft_neg_counts.mean() if not soft_neg_counts.empty else 0
    min_soft_neg = soft_neg_counts.min() if not soft_neg_counts.empty else 0
    max_soft_neg = soft_neg_counts.max() if not soft_neg_counts.empty else 0
    soft_neg_stats.append({
        'split': split_name,
        'avg_soft_neg': avg_soft_neg,
        'min_soft_neg': min_soft_neg,
        'max_soft_neg': max_soft_neg
    })

soft_neg_stats_df = pd.DataFrame(soft_neg_stats)
print(soft_neg_stats_df)

   split  avg_soft_neg  min_soft_neg  max_soft_neg
0  train       4198.80          3583          4446
1    val        525.45           449           556
2   test        525.45           449           556


In [95]:
# Prepare a summary table for each split
import pandas as pd

summary_rows = []
per_query_stats = {}

for split_name, df in splits_dataset.items():
    num_rows = len(df)
    num_pos = df['sku_pos'].apply(lambda x: len(x) if isinstance(x, list) else 0).sum()
    num_hard = df['sku_hard_neg'].apply(lambda x: len(x) if isinstance(x, list) else 0).sum()
    num_soft = df['sku_soft_neg'].apply(lambda x: len(x) if isinstance(x, list) else 0).sum()
    unique_skus = set(df['sku_query'])
    for col in ['sku_pos', 'sku_hard_neg', 'sku_soft_neg']:
        unique_skus.update([sku for sublist in df[col] for sku in (sublist if isinstance(sublist, list) else [])])
    summary_rows.append({
        'split': split_name,
        '#queries': num_rows,
        '#pos': num_pos,
        '#hard_neg': num_hard,
        '#soft_neg': num_soft,
        '#total_sku': len(unique_skus)
    })

    # Per-query stats for each type
    per_query = pd.DataFrame({
        'pos': df['sku_pos'].apply(lambda x: len(x) if isinstance(x, list) else 0),
        'hard_neg': df['sku_hard_neg'].apply(lambda x: len(x) if isinstance(x, list) else 0),
        'soft_neg': df['sku_soft_neg'].apply(lambda x: len(x) if isinstance(x, list) else 0),
    })
    agg = per_query.agg(['mean', 'std', 'min', 'max']).T
    agg.index.name = 'type'
    agg.columns.name = 'agg'
    per_query_stats[split_name] = agg

summary_df = pd.DataFrame(summary_rows)
display(summary_df)

Unnamed: 0,split,#queries,#pos,#hard_neg,#soft_neg,#total_sku
0,train,20,939,4004,83976,5562
1,val,20,134,513,10509,4899
2,test,20,150,514,10509,4918


In [96]:
# Display per-query stats for each split
multiindex_tuples = []
values = []
for split_name, agg in per_query_stats.items():
    for t in agg.index:
        for a in agg.columns:
            multiindex_tuples.append((split_name, t, a))
            values.append(agg.loc[t, a])
multiindex = pd.MultiIndex.from_tuples(multiindex_tuples, names=['split', 'type', 'agg'])
per_query_multi_df = pd.Series(values, index=multiindex).unstack(['type', 'agg']).swaplevel(axis=1)
# The above gives columns as (agg, type), swap to (type, agg)
per_query_multi_df.columns = per_query_multi_df.columns.swaplevel(0,1)
per_query_multi_df = per_query_multi_df.sort_index(axis=1, level=0)

# Cast all columns which are numerical to int (if possible)
for col in per_query_multi_df.columns:
    # Only cast if dtype is numeric and all values are close to integer (to avoid ValueError)
    if pd.api.types.is_numeric_dtype(per_query_multi_df[col]):
        if np.allclose(per_query_multi_df[col].dropna() % 1, 0):
            per_query_multi_df[col] = per_query_multi_df[col].astype(int)

# Reorder columns so that for each type (pos, hard_neg, soft_neg), columns are in order: mean, std, min, max
ordered_types = ['pos', 'hard_neg', 'soft_neg']
ordered_aggs = ['mean', 'std', 'min', 'max']
per_query_multi_df = per_query_multi_df.loc[:, [(t, a) for t in ordered_types for a in ordered_aggs]]
# Sort per_query_multi_df by split: train, val, test
split_order = ['train', 'val', 'test']
per_query_multi_df = per_query_multi_df.reindex(split_order)

display(per_query_multi_df.astype(int))

type,pos,pos,pos,pos,hard_neg,hard_neg,hard_neg,hard_neg,soft_neg,soft_neg,soft_neg,soft_neg
agg,mean,std,min,max,mean,std,min,max,mean,std,min,max
split,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2
train,46,64,2,216,200,249,0,755,4198,306,3583,4446
val,6,8,1,28,25,31,0,95,525,38,449,556
test,7,8,2,29,25,31,1,95,525,38,449,556


In [97]:
# Save splits

import os

split_info = f"test={TEST_RATIO}_val={VAL_RATIO}"
save_dir = os.path.join(
    "data/tables_OZ_geo_5500/processed/pairwise-mapping-splits", split_info
)
os.makedirs(save_dir, exist_ok=True)

# Save each split as a parquet file
for split_name, df in splits_dataset.items():
    out_path = os.path.join(save_dir, f"{split_name}.parquet")
    df.to_parquet(out_path, index=False)
    print(f"Saved {split_name} split to\n{out_path}")
    print()

Saved train split to
data/tables_OZ_geo_5500/processed/pairwise-mapping-splits/test=0.1_val=0.1/train.parquet

Saved val split to
data/tables_OZ_geo_5500/processed/pairwise-mapping-splits/test=0.1_val=0.1/val.parquet

Saved test split to
data/tables_OZ_geo_5500/processed/pairwise-mapping-splits/test=0.1_val=0.1/test.parquet



### Data leak sources

In [48]:
# # DATA LEAK #1 (major). Common positives
# # Some queries have common positives (see positive intersection matrix below)
# # Potential solution (2 steps): 
# #   1. merge sku w/ many intersections, 
# #   2. disentangle common positives w/ few intersections
# # NOTE: probably common hard negatives as well; certainly a lot of common soft negatives

# import numpy as np
# import pandas as pd

# # Build a list of queries and a mapping from query to set of positives
# queries = pairwise_mapping_df['sku_query'].tolist()
# sku_to_pos = {
#     row['sku_query']: set(row['sku_pos'])
#     for _, row in pairwise_mapping_df.iterrows()
# }

# num_queries = len(queries)
# intersection_matrix = np.zeros((num_queries, num_queries), dtype=int)

# for i, q1 in enumerate(queries):
#     pos1 = sku_to_pos[q1]
#     for j, q2 in enumerate(queries):
#         pos2 = sku_to_pos[q2]
#         intersection_matrix[i, j] = len(pos1 & pos2)

# # Mask diagonal and above with -1
# mask = np.triu(np.ones_like(intersection_matrix, dtype=bool))
# intersection_matrix[mask] = -1

# # Optionally, wrap as DataFrame for readability
# intersection_df = pd.DataFrame(
#     intersection_matrix, 
#     index=queries, 
#     columns=queries
# )
# intersection_df

In [99]:
# # DATA LEAK #2 (major). Lack of global disjointness w.r.t. sku over splits
# # Many skus are present in one split as positives and in another split as negatives

# def analyze_global_disjointness(splits_dataset):
#     """
#     Analyze global SKU disjointness across train/val/test splits.
#     Count positives in train that appear as negatives in val/test.
#     """
    
#     # Collect all SKUs by split and role
#     train_positives = set()
#     val_negatives = set()
#     test_negatives = set()
    
#     # Extract all positive SKUs from train split
#     for _, row in splits_dataset['train'].iterrows():
#         train_positives.update(row['sku_pos'])
    
#     # Extract all negative SKUs from val split
#     for _, row in splits_dataset['val'].iterrows():
#         val_negatives.update(row['sku_hard_neg'])
#         val_negatives.update(row['sku_soft_neg'])
    
#     # Extract all negative SKUs from test split  
#     for _, row in splits_dataset['test'].iterrows():
#         test_negatives.update(row['sku_hard_neg'])
#         test_negatives.update(row['sku_soft_neg'])
    
#     # Find overlaps (data leaks)
#     train_pos_in_val_neg = train_positives & val_negatives
#     train_pos_in_test_neg = train_positives & test_negatives
    
#     # Summary statistics
#     print("=== GLOBAL DISJOINTNESS ANALYSIS ===")
#     print(f"Train positives: {len(train_positives):,}")
#     print(f"Val negatives: {len(val_negatives):,}")
#     print(f"Test negatives: {len(test_negatives):,}")
#     print()
#     print("=== DATA LEAKS DETECTED ===")
#     print(f"Train positives appearing as Val negatives: {len(train_pos_in_val_neg):,}")
#     print(f"Train positives appearing as Test negatives: {len(train_pos_in_test_neg):,}")
#     print()
#     print("=== LEAK PERCENTAGES ===")
#     print(f"% of train positives leaked to val: {len(train_pos_in_val_neg)/len(train_positives)*100:.1f}%")
#     print(f"% of train positives leaked to test: {len(train_pos_in_test_neg)/len(train_positives)*100:.1f}%")
    
#     return {
#         'train_positives': train_positives,
#         'val_negatives': val_negatives, 
#         'test_negatives': test_negatives,
#         'train_pos_in_val_neg': train_pos_in_val_neg,
#         'train_pos_in_test_neg': train_pos_in_test_neg
#     }

# # Run the analysis
# leak_analysis = analyze_global_disjointness(splits_dataset)

In [100]:
# # DATA LEAK #3 (minor). Reused train sku in test for low-resource query sku
# # Some queries do not have enough positives for train/test or train/val/test splits
# # train/test split requires at least 3 positives per query (4 including query) => 2 train / 2 test
# # train/val/test split requires at least 5 positives per query (6 including query) => 2 train / 2 val / 2 test

# pairwise_mapping_df.set_index('sku_query')['sku_pos'].map(lambda s: len(s)).sort_values(ascending=True)

In [None]:
# DATA ISSUE (tiny): some query sku have bad descriptions (related to SEO-optimization or crawling issues)
# OZ_geo_5500_manual-edit.csv contains manual edits to fix the descriptions

## Make pairwise dataset

In [123]:
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
from typing import Optional, List, Dict, Any
import random
import torch

class PairGenerationDataset(Dataset):
    """
    Dataset for constructing training pairs using a hybrid sampling strategy.
    """
    
    def __init__(self, 
                 split_df: pd.DataFrame,
                 cluster_emb_table: pd.DataFrame,
                 source_df: pd.DataFrame = None,
                 images_dir: str = None,
                 max_pos_pairs_per_query: int = 3,
                 pos_neg_ratio: float = 2.0,
                 hard_soft_ratio: float = 0.7,
                 random_seed: int = 42):
        
        self.split_df = split_df.reset_index(drop=True)
        self.cluster_emb_table = cluster_emb_table
        self.source_df = source_df.set_index('sku') if source_df is not None else None
        self.images_dir = images_dir
        self.max_pos_pairs_per_query = max_pos_pairs_per_query
        self.pos_neg_ratio = pos_neg_ratio
        self.hard_soft_ratio = hard_soft_ratio
        self.random_seed = random_seed
        
        # Set random seeds for reproducibility
        random.seed(random_seed)
        np.random.seed(random_seed)
        
        # Build cluster mappings for soft negatives
        self._build_cluster_mappings()
        
        # Generate all pairs
        self.pairs = self._generate_pairs()
        
    def _build_cluster_mappings(self):
        """Build mappings from cluster to SKUs for efficient sampling."""
        if self.cluster_emb_table is not None:
            self.sku_to_cluster = dict(zip(self.cluster_emb_table['sku'], self.cluster_emb_table['cluster_id']))
            self.cluster_to_skus = {}
            for _, row in self.cluster_emb_table.iterrows():
                cluster = row['cluster_id']
                sku = row['sku']
                if cluster not in self.cluster_to_skus:
                    self.cluster_to_skus[cluster] = []
                self.cluster_to_skus[cluster].append(sku)
        else:
            self.sku_to_cluster = {}
            self.cluster_to_skus = {}
    
    def _generate_pairs(self) -> List[Dict[str, Any]]:
        """Generate all training pairs using hybrid approach."""
        all_pairs = []
        
        for _, row in self.split_df.iterrows():
            query_sku = row['sku_query']
            pos_skus = row['sku_pos'] if isinstance(row['sku_pos'], list) else []
            hard_neg_skus = row['sku_hard_neg'] if isinstance(row['sku_hard_neg'], list) else []
            soft_neg_skus = row['sku_soft_neg'] if isinstance(row['sku_soft_neg'], list) else []
            
            # Generate positive pairs
            positive_pairs = self._generate_positive_pairs_fixed(query_sku, pos_skus)
            
            # Calculate negatives based on actual positive count and ratio
            actual_pos_count = len(positive_pairs)
            total_negatives_needed = int(actual_pos_count * self.pos_neg_ratio)
            
            # Split negatives into hard and soft based on hard_soft_ratio
            num_hard_neg = int(total_negatives_needed * self.hard_soft_ratio)
            num_soft_neg = total_negatives_needed - num_hard_neg
            
            # Generate negative pairs
            hard_negative_pairs = self._generate_hard_negative_pairs(
                query_sku, pos_skus, hard_neg_skus, num_hard_neg
            )
            soft_negative_pairs = self._generate_soft_negative_pairs(
                query_sku, pos_skus, soft_neg_skus, num_soft_neg
            )
            
            # Combine all pairs for this query
            all_pairs.extend(positive_pairs)
            all_pairs.extend(hard_negative_pairs)
            all_pairs.extend(soft_negative_pairs)
            
        return all_pairs
    
    def _generate_positive_pairs_fixed(self, query_sku: int, pos_skus: List[int]) -> List[Dict[str, Any]]:
        """Generate exactly max_pos_pairs_per_query positive pairs."""
        # Fix: Handle None case
        if not pos_skus or (self.max_pos_pairs_per_query is not None and self.max_pos_pairs_per_query <= 0):
            return []
        
        # Get all possible positive pairs (Cartesian product excluding self-pairs)
        all_possible_pairs = []
        for pos1 in pos_skus:
            for pos2 in pos_skus:
                if pos1 != pos2:
                    all_possible_pairs.append((pos1, pos2))
        
        # Handle case where no pairs can be formed
        if not all_possible_pairs:
            if len(set(pos_skus)) == 1:
                single_sku = pos_skus[0]
                all_possible_pairs = [(single_sku, single_sku)]
            else:
                return []
        
        # Handle None case for max_pos_pairs_per_query
        if self.max_pos_pairs_per_query is None:
            # Use all available pairs if no limit set
            selected_pairs = all_possible_pairs
        elif len(all_possible_pairs) >= self.max_pos_pairs_per_query:
            selected_pairs = random.sample(all_possible_pairs, self.max_pos_pairs_per_query)
        else:
            # If not enough unique pairs, sample with replacement
            selected_pairs = random.choices(all_possible_pairs, k=self.max_pos_pairs_per_query)
        
        pairs = []
        for pos1, pos2 in selected_pairs:
            pairs.append({
                'sku_first': pos1,
                'sku_second': pos2,
                'label': 0,  # 0 = positive (similar)
                'query_sku': query_sku,
                'pair_type': 'positive'
            })
        
        return pairs
    
    def _generate_hard_negative_pairs(self, query_sku: int, pos_skus: List[int], 
                                    hard_neg_skus: List[int], num_hard_neg: int) -> List[Dict[str, Any]]:
        """Generate hard negative pairs by pairing positives with hard negatives."""
        if num_hard_neg <= 0 or not pos_skus or not hard_neg_skus:
            return []
        
        pairs = []
        for _ in range(num_hard_neg):
            # Randomly select a positive and a hard negative
            pos_sku = random.choice(pos_skus)
            hard_sku = random.choice(hard_neg_skus)
            
            pairs.append({
                'sku_first': pos_sku,
                'sku_second': hard_sku,
                'label': 1,  # 1 = negative (different)
                'query_sku': query_sku,
                'pair_type': 'hard_negative'
            })
        
        return pairs
    
    def _generate_soft_negative_pairs(self, query_sku: int, pos_skus: List[int], 
                                    soft_neg_skus: List[int], num_soft_neg: int) -> List[Dict[str, Any]]:
        """Generate soft negative pairs by pairing positives with soft negatives."""
        if num_soft_neg <= 0 or not pos_skus or not soft_neg_skus:
            return []
        
        pairs = []
        for _ in range(num_soft_neg):
            # Randomly select a positive and a soft negative
            pos_sku = random.choice(pos_skus)
            soft_sku = random.choice(soft_neg_skus)
            
            pairs.append({
                'sku_first': pos_sku,
                'sku_second': soft_sku,
                'label': 1,  # 1 = negative (different)
                'query_sku': query_sku,
                'pair_type': 'soft_negative'
            })
        
        return pairs
    
    def get_batch_stats(self) -> Dict[str, Any]:
        """Get detailed statistics about the generated pairs."""
        stats = {
            'total_pairs': len(self.pairs),
            'positives': sum(1 for p in self.pairs if p['pair_type'] == 'positive'),
            'hard_negatives': sum(1 for p in self.pairs if p['pair_type'] == 'hard_negative'),
            'soft_negatives': sum(1 for p in self.pairs if p['pair_type'] == 'soft_negative'),
            'queries': len(self.split_df),
        }
        
        stats['total_negatives'] = stats['hard_negatives'] + stats['soft_negatives']
        
        # Calculate actual ratios achieved
        if stats['positives'] > 0:
            stats['actual_neg_pos_ratio'] = stats['total_negatives'] / stats['positives']
        else:
            stats['actual_neg_pos_ratio'] = 0.0
            
        if stats['total_negatives'] > 0:
            stats['actual_hard_soft_ratio'] = stats['hard_negatives'] / stats['total_negatives']
        else:
            stats['actual_hard_soft_ratio'] = 0.0
        
        # Calculate per-query averages
        stats['avg_pos_per_query'] = stats['positives'] / max(stats['queries'], 1)
        stats['avg_neg_per_query'] = stats['total_negatives'] / max(stats['queries'], 1)
        stats['avg_pairs_per_query'] = stats['total_pairs'] / max(stats['queries'], 1)
        
        return stats
    
    def __len__(self):
        return len(self.pairs)
    
    def __getitem__(self, idx):
        pair = self.pairs[idx]
        return {
            'sku_first': torch.tensor(pair['sku_first']),
            'sku_second': torch.tensor(pair['sku_second']),
            'label': torch.tensor(pair['label'], dtype=torch.float32),
            'query_sku': pair['query_sku'],
            'pair_type': pair['pair_type']
        }
    
    def get_batch_stats(self) -> Dict[str, Any]:
        """Get detailed statistics about the generated pairs."""
        stats = {
            'total_pairs': len(self.pairs),
            'positives': sum(1 for p in self.pairs if p['pair_type'] == 'positive'),
            'hard_negatives': sum(1 for p in self.pairs if p['pair_type'] == 'hard_negative'),
            'soft_negatives': sum(1 for p in self.pairs if p['pair_type'] == 'soft_negative'),
            'queries': len(self.split_df),
        }
        
        stats['total_negatives'] = stats['hard_negatives'] + stats['soft_negatives']
        
        # Calculate actual ratios achieved
        if stats['positives'] > 0:
            stats['actual_neg_pos_ratio'] = stats['total_negatives'] / stats['positives']
        else:
            stats['actual_neg_pos_ratio'] = 0.0
            
        if stats['total_negatives'] > 0:
            stats['actual_hard_soft_ratio'] = stats['hard_negatives'] / stats['total_negatives']
        else:
            stats['actual_hard_soft_ratio'] = 0.0
        
        # Calculate per-query averages
        stats['avg_pos_per_query'] = stats['positives'] / max(stats['queries'], 1)
        stats['avg_neg_per_query'] = stats['total_negatives'] / max(stats['queries'], 1)
        stats['avg_pairs_per_query'] = stats['total_pairs'] / max(stats['queries'], 1)
        
        return stats
    
    def print_detailed_stats(self):
        """Print comprehensive dataset statistics."""
        stats = self.get_batch_stats()
        
        print("=== Hybrid PairGenerationDataset Statistics ===")
        print(f"📊 Dataset Overview:")
        print(f"  Queries: {stats['queries']}")
        print(f"  Total pairs: {stats['total_pairs']}")
        print(f"  Avg pairs per query: {stats['avg_pairs_per_query']:.1f}")
        
        print(f"\n🎯 Pair Breakdown:")
        print(f"  Positives: {stats['positives']} ({stats['positives']/max(stats['total_pairs'],1)*100:.1f}%)")
        print(f"  Hard negatives: {stats['hard_negatives']} ({stats['hard_negatives']/max(stats['total_pairs'],1)*100:.1f}%)")
        print(f"  Soft negatives: {stats['soft_negatives']} ({stats['soft_negatives']/max(stats['total_pairs'],1)*100:.1f}%)")
        print(f"  Total negatives: {stats['total_negatives']}")
        
        print(f"\n⚖️ Ratios Achieved:")
        print(f"  Target neg:pos ratio: {self.pos_neg_ratio:.2f}")
        print(f"  Actual neg:pos ratio: {stats['actual_neg_pos_ratio']:.2f}")
        print(f"  Target hard:soft ratio: {self.hard_soft_ratio:.2f}")
        print(f"  Actual hard:soft ratio: {stats['actual_hard_soft_ratio']:.2f}")
        
        print(f"\n📈 Per-Query Averages:")
        print(f"  Avg positives per query: {stats['avg_pos_per_query']:.1f}")
        print(f"  Avg negatives per query: {stats['avg_neg_per_query']:.1f}")
        print(f"  Max pos pairs setting: {self.max_pos_pairs_per_query}")

    def to_pairwise_dataframe(self):
        """
        Generate a full pairwise dataframe containing all pairs from the dataset.

        Returns:
            pd.DataFrame: Pairwise dataframe with columns from source_df duplicated 
                        with '_first' and '_second' suffixes, plus 'pair_type', 'sku_query',
                        'sku_first', 'sku_second', and 'label'
        """
        if self.source_df is None:
            raise ValueError("source_df is required to generate pairwise dataframe")

        pairwise_rows = []

        # Map pair_type to desired format and label
        pair_type_mapping = {
            'positive': 'pos',
            'hard_negative': 'hard_neg', 
            'soft_negative': 'soft_neg'
        }
        
        # Map pair_type to binary labels for training
        pair_type_to_label = {
            'positive': 0,      # Similar pairs get label 0 (for contrastive loss)
            'hard_negative': 1, # Dissimilar pairs get label 1 (for contrastive loss)
            'soft_negative': 1  # Dissimilar pairs get label 1 (for contrastive loss)
        }

        # Add progress bar for pair processing
        for pair in tqdm(self.pairs, desc="Converting to pairwise dataframe"):
            sku_first = pair['sku_first']
            sku_second = pair['sku_second']
            query_sku = pair['query_sku']
            pair_type = pair['pair_type']

            mapped_pair_type = pair_type_mapping.get(pair_type, pair_type)
            label = pair_type_to_label.get(pair_type, 1)  # Default to 1 (dissimilar) if unknown

            # Check if both SKUs exist in source_df
            if sku_first in self.source_df.index and sku_second in self.source_df.index:
                row_first = self.source_df.loc[sku_first]
                row_second = self.source_df.loc[sku_second]

                # Create pairwise row
                pairwise_row = {}

                # Add columns from source_df with _first suffix
                for col in self.source_df.columns:
                    pairwise_row[f"{col}_first"] = row_first[col]

                # Add columns from source_df with _second suffix  
                for col in self.source_df.columns:
                    pairwise_row[f"{col}_second"] = row_second[col]

                # Add pair metadata
                pairwise_row['pair_type'] = mapped_pair_type
                pairwise_row['sku_query'] = query_sku
                pairwise_row['label'] = 1 - label  # Add the inverted label column (NOT for contrastive loss)

                # Add sku_first and sku_second columns explicitly
                pairwise_row['sku_first'] = sku_first
                pairwise_row['sku_second'] = sku_second

                pairwise_rows.append(pairwise_row)
            else:
                # Log warning for missing SKUs but continue
                missing_skus = []
                if sku_first not in self.source_df.index:
                    missing_skus.append(sku_first)
                if sku_second not in self.source_df.index:
                    missing_skus.append(sku_second)
                print(f"Warning: SKU(s) {missing_skus} not found in source_df, skipping pair")

        return pd.DataFrame(pairwise_rows)

In [145]:
LIMIT_VAL_POS_FOR_SAVING = None

val_dataset = PairGenerationDataset(
    split_df=splits_dataset['val'],
    cluster_emb_table=cluster_emb_table,
    source_df=source_df,
    images_dir = os.path.join(DATA_PATH, IMG_DATASET_NAME),
    max_pos_pairs_per_query=LIMIT_VAL_POS_FOR_SAVING,
    pos_neg_ratio=POS_NEG_RATIO,
    hard_soft_ratio=HARD_SOFT_RATIO,
    random_seed=RANDOM_SEED
)
val_dataset.print_detailed_stats()

=== Hybrid PairGenerationDataset Statistics ===
📊 Dataset Overview:
  Queries: 20
  Total pairs: 3977
  Avg pairs per query: 198.8

🎯 Pair Breakdown:
  Positives: 1989 (50.0%)
  Hard negatives: 992 (24.9%)
  Soft negatives: 996 (25.0%)
  Total negatives: 1988

⚖️ Ratios Achieved:
  Target neg:pos ratio: 1.00
  Actual neg:pos ratio: 1.00
  Target hard:soft ratio: 0.50
  Actual hard:soft ratio: 0.50

📈 Per-Query Averages:
  Avg positives per query: 99.5
  Avg negatives per query: 99.4
  Max pos pairs setting: None


In [146]:
val_pairwise_df = val_dataset.to_pairwise_dataframe()
val_pairwise_df.shape

Converting to pairwise dataframe: 100%|██████████| 3977/3977 [00:01<00:00, 2926.44it/s]


(3977, 95)

In [147]:
LIMIT_TEST_POS_FOR_SAVING = None

test_dataset = PairGenerationDataset(
    split_df=splits_dataset['test'],
    cluster_emb_table=cluster_emb_table,
    source_df=source_df,
    images_dir = os.path.join(DATA_PATH, IMG_DATASET_NAME),
    max_pos_pairs_per_query=LIMIT_TEST_POS_FOR_SAVING,
    pos_neg_ratio=POS_NEG_RATIO,
    hard_soft_ratio=HARD_SOFT_RATIO,
    random_seed=RANDOM_SEED
)
test_dataset.print_detailed_stats()

=== Hybrid PairGenerationDataset Statistics ===
📊 Dataset Overview:
  Queries: 20
  Total pairs: 4476
  Avg pairs per query: 223.8

🎯 Pair Breakdown:
  Positives: 2238 (50.0%)
  Hard negatives: 1119 (25.0%)
  Soft negatives: 1119 (25.0%)
  Total negatives: 2238

⚖️ Ratios Achieved:
  Target neg:pos ratio: 1.00
  Actual neg:pos ratio: 1.00
  Target hard:soft ratio: 0.50
  Actual hard:soft ratio: 0.50

📈 Per-Query Averages:
  Avg positives per query: 111.9
  Avg negatives per query: 111.9
  Max pos pairs setting: None


In [148]:
test_pairwise_df = test_dataset.to_pairwise_dataframe()
test_pairwise_df.shape

Converting to pairwise dataframe: 100%|██████████| 4476/4476 [00:01<00:00, 2746.31it/s]


(4476, 95)

In [154]:
# Save pairwise_df with a name including all sampling parameters

# Prepare parameter names with dashes instead of underscores
def param_name(param):
    return param.replace('_', '-')

for split, pairwise_df, pos_limit in [
    ('val',  val_pairwise_df,  LIMIT_VAL_POS_FOR_SAVING),
    ('test', test_pairwise_df, LIMIT_TEST_POS_FOR_SAVING), 
]:
    # Collect parameters and their values
    params = [
        ('num-rows', len(pairwise_df)),
        ('limit-pos', pos_limit),
        ('pos-neg', POS_NEG_RATIO),
        ('hard-soft', HARD_SOFT_RATIO),
        ('seed', RANDOM_SEED)
    ]

    # Build folder_name
    param_str = '_'.join([f"{param}={value}" for param, value in params])
    # Use DATA_PATH as the root directory to ensure the full path exists
    filepath = (
        Path(DATA_PATH) / 'tables_OZ_geo_5500' 
        / 'processed' / 'pairwise-rendered' / split / param_str / 'pairs.parquet'
    )
    filepath.parent.mkdir(parents=True, exist_ok=True)

    # Save DataFrame
    pairwise_df.to_parquet(filepath, index=False)
    # Print path without DATA_PATH at the start
    rel_path = filepath.relative_to(DATA_PATH)
    print(f"Saved {split} pairwise_df to\n{rel_path}")
    print()

Saved val pairwise_df to
tables_OZ_geo_5500/processed/pairwise-rendered/val/num-rows=3977_limit-pos=None_pos-neg=1.0_hard-soft=0.5_seed=42/pairs.parquet

Saved test pairwise_df to
tables_OZ_geo_5500/processed/pairwise-rendered/test/num-rows=4476_limit-pos=None_pos-neg=1.0_hard-soft=0.5_seed=42/pairs.parquet



In [33]:
from typing import Dict, Any, List
import random
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import cv2
from PIL import Image
import os

class PrecomputedPairDataset(Dataset):
    """
    Pre-compute all pairs and cache tokenized data for faster loading.
    """
    def __init__(self, split_df, source_df, images_dir, **kwargs):
        print("\n=== Initializing PrecomputedPairDataset ===")
        self.images_dir = images_dir
        
        # Store parameters for pair generation
        self.max_pos_pairs_per_query = kwargs.get('max_pos_pairs_per_query')
        self.pos_neg_ratio = kwargs.get('pos_neg_ratio', 2.0)
        self.hard_soft_ratio = kwargs.get('hard_soft_ratio', 0.5)
        
        # Pre-compute all pairs (without actual data loading)
        print("\n1. Generating pairs metadata...")
        self.pairs = self._generate_pairs_metadata(split_df, **kwargs)
        print(f"Generated {len(self.pairs)} pairs")
        
        # Pre-load and cache all unique SKUs data
        print("\n2. Starting data precomputation...")
        self.sku_cache = self._preload_sku_data(source_df)
        
        print(f"\n=== Dataset Ready ===")
        print(f"Total pairs: {len(self.pairs)}")
        print(f"Unique SKUs cached: {len(self.sku_cache)}")
    
    def get_batch_stats(self) -> Dict[str, Any]:
        """Get detailed statistics about the generated pairs."""
        stats = {
            'total_pairs': len(self.pairs),
            'positives': sum(1 for p in self.pairs if p['pair_type'] == 'positive'),
            'hard_negatives': sum(1 for p in self.pairs if p['pair_type'] == 'hard_negative'),
            'soft_negatives': sum(1 for p in self.pairs if p['pair_type'] == 'soft_negative'),
            'queries': len(set(p['query_sku'] for p in self.pairs)),  # Count unique queries
        }
        
        stats['total_negatives'] = stats['hard_negatives'] + stats['soft_negatives']
        
        # Calculate actual ratios achieved
        if stats['positives'] > 0:
            stats['actual_neg_pos_ratio'] = stats['total_negatives'] / stats['positives']
        else:
            stats['actual_neg_pos_ratio'] = 0.0
            
        if stats['total_negatives'] > 0:
            stats['actual_hard_soft_ratio'] = stats['hard_negatives'] / stats['total_negatives']
        else:
            stats['actual_hard_soft_ratio'] = 0.0
        
        # Calculate per-query averages
        stats['avg_pos_per_query'] = stats['positives'] / max(stats['queries'], 1)
        stats['avg_neg_per_query'] = stats['total_negatives'] / max(stats['queries'], 1)
        stats['avg_pairs_per_query'] = stats['total_pairs'] / max(stats['queries'], 1)
        
        return stats
    
    def _generate_positive_pairs_fixed(self, query_sku: int, pos_skus: List[int], max_pairs: int = None) -> List[Dict[str, Any]]:
        """Generate positive pairs with progress tracking."""
        if max_pairs is None:
            max_pairs = self.max_pos_pairs_per_query
            
        # Fix: Handle None case
        if not pos_skus or (max_pairs is not None and max_pairs <= 0):
            return []
        
        # Get all possible positive pairs (Cartesian product excluding self-pairs)
        all_possible_pairs = []
        for pos1 in pos_skus:
            for pos2 in pos_skus:
                if pos1 != pos2:
                    all_possible_pairs.append((pos1, pos2))
        
        # Handle case where no pairs can be formed
        if not all_possible_pairs:
            if len(set(pos_skus)) == 1:
                single_sku = pos_skus[0]
                all_possible_pairs = [(single_sku, single_sku)]
            else:
                return []
        
        # Handle None case for max_pos_pairs_per_query
        if max_pairs is None:
            # Use all available pairs if no limit set
            selected_pairs = all_possible_pairs
        elif len(all_possible_pairs) >= max_pairs:
            selected_pairs = random.sample(all_possible_pairs, max_pairs)
        else:
            # If not enough unique pairs, sample with replacement
            selected_pairs = random.choices(all_possible_pairs, k=max_pairs)
        
        pairs = []
        for pos1, pos2 in selected_pairs:
            pairs.append({
                'sku_first': pos1,
                'sku_second': pos2,
                'label': 0,  # 0 = positive (similar)
                'query_sku': query_sku,
                'pair_type': 'positive'
            })
        
        return pairs
    
    def _generate_hard_negative_pairs(self, query_sku: int, pos_skus: List[int], 
                                    hard_neg_skus: List[int], num_hard_neg: int) -> List[Dict[str, Any]]:
        """Generate hard negative pairs by pairing positives with hard negatives."""
        if num_hard_neg <= 0 or not pos_skus or not hard_neg_skus:
            return []
        
        pairs = []
        for _ in range(num_hard_neg):
            # Randomly select a positive and a hard negative
            pos_sku = random.choice(pos_skus)
            hard_sku = random.choice(hard_neg_skus)
            
            pairs.append({
                'sku_first': pos_sku,
                'sku_second': hard_sku,
                'label': 1,  # 1 = negative (different)
                'query_sku': query_sku,
                'pair_type': 'hard_negative'
            })
        
        return pairs
    
    def _generate_soft_negative_pairs(self, query_sku: int, pos_skus: List[int], 
                                    soft_neg_skus: List[int], num_soft_neg: int) -> List[Dict[str, Any]]:
        """Generate soft negative pairs by pairing positives with soft negatives."""
        if num_soft_neg <= 0 or not pos_skus or not soft_neg_skus:
            return []
        
        pairs = []
        for _ in range(num_soft_neg):
            # Randomly select a positive and a soft negative
            pos_sku = random.choice(pos_skus)
            soft_sku = random.choice(soft_neg_skus)
            
            pairs.append({
                'sku_first': pos_sku,
                'sku_second': soft_sku,
                'label': 1,  # 1 = negative (different)
                'query_sku': query_sku,
                'pair_type': 'soft_negative'
            })
        
        return pairs
    
    def _preload_sku_data(self, source_df):
        """Pre-load and cache all SKU data including tokenized text."""
        tokenizers = Tokenizers()
        transform = get_transform()
        
        sku_cache = {}
        unique_skus = set()
        
        # Collect all unique SKUs from pairs
        for pair in self.pairs:
            unique_skus.add(pair['sku_first'])
            unique_skus.add(pair['sku_second'])
        
        print(f"Found {len(unique_skus)} unique SKUs")
        source_indexed = source_df.set_index('sku')
        
        # Single progress bar for all operations
        pbar = tqdm(total=len(unique_skus), desc="Overall progress")
        
        for sku in unique_skus:
            if sku in source_indexed.index:
                row = source_indexed.loc[sku]
                
                # Pre-tokenize text
                name_tokens = tokenizers.tokenize_name([str(row['name'])])
                desc_tokens = tokenizers.tokenize_description([str(row['description'])])
                
                # Pre-load and transform image
                img_path = os.path.join(self.images_dir, row['image_name'])
                try:
                    img = cv2.imread(img_path)
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    img = Image.fromarray(img)
                    img = transform(img)
                except Exception as e:
                    print(f"\nWarning: Failed to load image for SKU {sku}: {e}")
                    img = torch.zeros(3, 224, 224)
                
                sku_cache[sku] = {
                    'image': img,
                    'name_tokens': name_tokens[0],  # Remove batch dimension
                    'desc_tokens': desc_tokens[0],  # Remove batch dimension
                }
            
            pbar.update(1)
        
        pbar.close()
        return sku_cache
    
    def _generate_pairs_metadata(self, split_df, **kwargs):
        """Generate pair metadata with progress tracking."""
        print("Generating pairs with parameters:")
        for k, v in kwargs.items():
            print(f"  {k}: {v}")
        
        # Set random seed for reproducibility
        random.seed(kwargs.get('random_seed', 42))
            
        pairs = []
        
        with tqdm(total=len(split_df), desc="Generating pairs") as pbar:
            for _, row in split_df.iterrows():
                query_sku = row['sku_query']
                pos_skus = row['sku_pos'] if isinstance(row['sku_pos'], list) else []
                hard_neg_skus = row['sku_hard_neg'] if isinstance(row['sku_hard_neg'], list) else []
                soft_neg_skus = row['sku_soft_neg'] if isinstance(row['sku_soft_neg'], list) else []
                
                # Generate positive pairs
                pos_pairs = self._generate_positive_pairs_fixed(
                    query_sku, pos_skus, 
                    max_pairs=kwargs.get('max_pos_pairs_per_query')
                )
                pairs.extend(pos_pairs)
                
                # Calculate negatives based on actual positive count
                actual_pos_count = len(pos_pairs)
                total_negatives_needed = int(actual_pos_count * kwargs.get('pos_neg_ratio', 1.0))
                num_hard_neg = int(total_negatives_needed * kwargs.get('hard_soft_ratio', 0.5))
                num_soft_neg = total_negatives_needed - num_hard_neg
                
                # Generate negative pairs
                pairs.extend(self._generate_hard_negative_pairs(
                    query_sku, pos_skus, hard_neg_skus, num_hard_neg
                ))
                pairs.extend(self._generate_soft_negative_pairs(
                    query_sku, pos_skus, soft_neg_skus, num_soft_neg
                ))
                
                pbar.update(1)
                pbar.set_postfix({
                    'pos': len(pos_pairs),
                    'hard': num_hard_neg,
                    'soft': num_soft_neg
                })
        
        return pairs
    
    def __getitem__(self, idx):
        pair = self.pairs[idx]
        sku_first = pair['sku_first']
        sku_second = pair['sku_second']
        
        # Get cached data (no I/O operations!)
        data_first = self.sku_cache[sku_first]
        data_second = self.sku_cache[sku_second]
        
        return {
            'image_first': data_first['image'],
            'name_first': data_first['name_tokens'],
            'desc_first': data_first['desc_tokens'],
            'image_second': data_second['image'],
            'name_second': data_second['name_tokens'],
            'desc_second': data_second['desc_tokens'],
            'label': torch.tensor(pair['label'], dtype=torch.float32),
            'pair_type': pair['pair_type']
        }
    
    def __len__(self):
        return len(self.pairs)
    
    def get_stats(self):
        """Get detailed dataset statistics (alternative interface)."""
        batch_stats = self.get_batch_stats()
        
        stats = {
            'total_pairs': batch_stats['total_pairs'],
            'unique_skus': len(self.sku_cache),
            'pair_types': {
                'positive': batch_stats['positives'],
                'hard_negative': batch_stats['hard_negatives'],
                'soft_negative': batch_stats['soft_negatives']
            },
            'memory_usage': {
                'images': sum(img['image'].nelement() * img['image'].element_size() 
                            for img in self.sku_cache.values()) / 1024**2,  # MB
                'tokens': sum((t['name_tokens'].nelement() + t['desc_tokens'].nelement()) * 
                            t['name_tokens'].element_size() 
                            for t in self.sku_cache.values()) / 1024**2  # MB
            }
        }
        
        return stats
    
    def print_detailed_stats(self):
        """Print comprehensive dataset statistics."""
        stats = self.get_batch_stats()
        
        print("=== PrecomputedPairDataset Statistics ===")
        print(f"📊 Dataset Overview:")
        print(f"  Queries: {stats['queries']}")
        print(f"  Total pairs: {stats['total_pairs']}")
        print(f"  Avg pairs per query: {stats['avg_pairs_per_query']:.1f}")
        
        print(f"\n🎯 Pair Breakdown:")
        print(f"  Positives: {stats['positives']} ({stats['positives']/max(stats['total_pairs'],1)*100:.1f}%)")
        print(f"  Hard negatives: {stats['hard_negatives']} ({stats['hard_negatives']/max(stats['total_pairs'],1)*100:.1f}%)")
        print(f"  Soft negatives: {stats['soft_negatives']} ({stats['soft_negatives']/max(stats['total_pairs'],1)*100:.1f}%)")
        print(f"  Total negatives: {stats['total_negatives']}")
        
        print(f"\n⚖️ Ratios Achieved:")
        print(f"  Target neg:pos ratio: {self.pos_neg_ratio:.2f}")
        print(f"  Actual neg:pos ratio: {stats['actual_neg_pos_ratio']:.2f}")
        print(f"  Target hard:soft ratio: {self.hard_soft_ratio:.2f}")
        print(f"  Actual hard:soft ratio: {stats['actual_hard_soft_ratio']:.2f}")
        
        print(f"\n📈 Per-Query Averages:")
        print(f"  Avg positives per query: {stats['avg_pos_per_query']:.1f}")
        print(f"  Avg negatives per query: {stats['avg_neg_per_query']:.1f}")
        print(f"  Max pos pairs setting: {self.max_pos_pairs_per_query}")

# Run training

In [38]:
# # Check that each image_name is present in the images_dir
# images_dir = os.path.join(DATA_PATH, IMG_DATASET_NAME)
# missing_images = []
# for img_name in source_df['image_name']:
#     img_path = os.path.join(images_dir, img_name)
#     if not os.path.isfile(img_path):
#         missing_images.append(img_name)

# if missing_images:
#     print(f"Missing {len(missing_images)} images:")
#     print(missing_images[:10])  # Show up to 10 missing images
# else:
#     print("All images are present in the images_dir.")

In [None]:
def train_with_threshold_tracking(model, optimizer, criterion, epochs_num, train_loader, 
                                valid_loader=None, device='cpu', print_epoch=False, 
                                models_dir=None, metric='f1', source_df=None, images_dir=None,
                                precompute_pairs=False):
    """
    Training function that handles both precomputed and on-the-fly data formats.
    
    Args:
        precompute_pairs: If True, expects PrecomputedPairDataset format with pre-loaded tensors.
                         If False, expects PairGenerationDataset format and converts SKUs to tensors.
    """
    from copy import deepcopy
    
    assert metric in ('f1', 'pos_acc'), "metric must be 'f1' or 'pos_acc'"

    model.to(device)
    train_losses, val_losses, thr_history = [], [], []
    best_valid_metric, best_threshold = float('-inf'), None
    best_weights = None

    # Only initialize tokenizers if not using precomputed pairs
    if not precompute_pairs:
        tokenizers = Tokenizers()
        transform = get_transform()

    scheduler = ReduceLROnPlateau(
        optimizer,
        mode="max",
        factor=0.1,
        patience=SHEDULER_PATIENCE,
        threshold=1e-4,
        threshold_mode='rel'
    )

    if models_dir:
        Path(models_dir).mkdir(parents=True, exist_ok=True)

    for epoch in range(1, epochs_num + 1):
        # ---- training ----
        model.train()
        total_train_loss = 0.0
        
        for batch in tqdm(train_loader, desc=f"Training Epoch {epoch}"):
            if precompute_pairs:
                # PrecomputedPairDataset format - data is already tensors
                im1 = batch['image_first'].to(device)
                n1 = batch['name_first'].to(device)
                d1 = batch['desc_first'].to(device)
                im2 = batch['image_second'].to(device)
                n2 = batch['name_second'].to(device)
                d2 = batch['desc_second'].to(device)
                labels = batch['label'].to(device)
            else:
                # PairGenerationDataset format - convert SKUs to tensors
                sku_first = batch['sku_first']
                sku_second = batch['sku_second'] 
                labels = batch['label'].to(device)
                
                # Convert SKUs to model inputs
                im1, n1, d1 = sku_to_model_inputs(sku_first.tolist(), source_df, images_dir, tokenizers, transform)
                im2, n2, d2 = sku_to_model_inputs(sku_second.tolist(), source_df, images_dir, tokenizers, transform)
                
                # Move to device
                im1, n1, d1 = im1.to(device), n1.to(device), d1.to(device)
                im2, n2, d2 = im2.to(device), n2.to(device), d2.to(device)
            
            optimizer.zero_grad()
            out1, out2 = model(im1, n1, d1, im2, n2, d2)
            loss = criterion(out1, out2, labels)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
            
        train_losses.append(total_train_loss / len(train_loader))

        # ---- evaluation & checkpointing ----
        if print_epoch and valid_loader is not None:
            pos_acc, neg_acc, avg_acc, f1_val, val_loss, val_thr = evaluation(
                model, criterion, valid_loader, epoch, device=device,
                split_name='val', threshold=None, margin=CONTRASTIVE_MARGIN,
                steps=200, metric=metric, source_df=source_df, images_dir=images_dir,
                precompute_pairs=precompute_pairs
            )
            val_losses.append(val_loss)
            thr_history.append(val_thr)

            # pick the metric value to step & compare
            cur_metric = pos_acc if metric == 'pos_acc' else f1_val
            scheduler.step(cur_metric)

            # save checkpoint every epoch if requested
            if models_dir:
                # Create detailed filename with validation metrics
                checkpoint_filename = (
                    f"siamese_contrastive_soft-neg_epoch={epoch}_"
                    f"val-f1={f1_val:.3f}_val-pos-acc={pos_acc:.3f}_val-neg-acc={neg_acc:.3f}_"
                    f"{'_' + MODEL_NAME_POSTFIX if MODEL_NAME_POSTFIX else ''}_"
                    f"{'_' + PRELOAD_MODEL_NAME if PRELOAD_MODEL_NAME else ''}_"
                    f"best-{metric:.3f}-threshold={val_thr:.3f}.pt"
                )
                
                if isinstance(model, torch.nn.DataParallel):
                    torch.save(model.module.state_dict(),
                            Path(models_dir) / checkpoint_filename)
                else:
                    torch.save(model.state_dict(),
                            Path(models_dir) / checkpoint_filename)

            # update best if improved
            if cur_metric > best_valid_metric:
                best_valid_metric = cur_metric
                best_threshold = val_thr
                if isinstance(model, torch.nn.DataParallel):
                    best_weights = deepcopy(model.module.state_dict())
                else:
                    best_weights = deepcopy(model.state_dict())

        print(f'Epoch {epoch} done.')

    print(f"Best evaluation {metric}: {best_valid_metric:.3f}  (thr={best_threshold:.3f})")
    return train_losses, val_losses, best_valid_metric, best_weights, best_threshold


def get_optimal_num_workers():
    """Calculate optimal number of workers based on system specs."""
    import psutil
    
    num_gpus = torch.cuda.device_count()
    num_cpus = psutil.cpu_count(logical=False)  # Physical cores
    
    if num_gpus > 1:
        # For multi-GPU: fewer workers per GPU to avoid context switching
        workers = min(4, num_cpus // num_gpus)
    else:
        # For single GPU: can use more workers
        workers = min(8, num_cpus)
    
    return max(2, workers)

In [42]:
def _run(
    limit_train_pos_pairs_per_query=None,
    limit_val_pos_pairs_per_query=None,
    limit_test_pos_pairs_per_query=None,
    batch_size_per_device=None,
    precompute_pairs=False,
    images_dir=None
):
    """
    Updated _run() function that tracks best thresholds per epoch to avoid duplicate evaluation.
    """
    assert images_dir is not None

    # Handle batch size scaling for multi-GPU
    if batch_size_per_device is None:
        batch_size_per_device = batch_size_per_device
    
    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        print(f"Using {num_gpus} GPUs with DataParallel")
        original_batch_size = batch_size_per_device
        batch_size = batch_size_per_device * num_gpus
        print(f"Scaled batch size from {original_batch_size} to {batch_size}")
    else:
        batch_size = batch_size_per_device

    # ---------- 1) build DataLoaders with PairGenerationDataset ----------
    
    if LIMIT_QUERIES:
        sampled_queries = pairwise_mapping_df['sku_query'].drop_duplicates().sample(n=LIMIT_QUERIES, random_state=RANDOM_SEED)
        actual_pairwise_mapping_df = pairwise_mapping_df[pairwise_mapping_df['sku_query'].isin(sampled_queries)]
    else:
        actual_pairwise_mapping_df = pairwise_mapping_df

    splits_dataset = split_query_groups(
        actual_pairwise_mapping_df,
        test_size=TEST_RATIO,
        val_size=VAL_RATIO,
        random_state=RANDOM_SEED,
    )

    splits = {'train': splits_dataset['train'],
            'val': splits_dataset['val'], 
            'test': splits_dataset['test']}
    loaders = {}

    SPLIT_LIMITS = {
        'train': limit_train_pos_pairs_per_query,
        'val': limit_val_pos_pairs_per_query,
        'test': limit_test_pos_pairs_per_query
    }

    optimal_workers = get_optimal_num_workers()
    print(f"Using {optimal_workers} workers for data loading")

    for split_name, split_df in splits.items():
        if precompute_pairs:
            dataset = PrecomputedPairDataset(
                split_df=splits_dataset['train'],
                source_df=source_df,
                images_dir=images_dir,
                max_pos_pairs_per_query=LIMIT_TRAIN_POS_PAIRS_PER_QUERY,
                pos_neg_ratio=POS_NEG_RATIO,
                hard_soft_ratio=HARD_SOFT_RATIO
            )

            # Print dataset statistics
            stats = dataset.get_stats()
            print("\nDataset Statistics:")
            print(f"Total pairs: {stats['total_pairs']}")
            print(f"Unique SKUs: {stats['unique_skus']}")
            print("\nPair Types:")
            for pair_type, count in stats['pair_types'].items():
                print(f"  {pair_type}: {count}")
            print("\nMemory Usage:")
            print(f"  Images: {stats['memory_usage']['images']:.1f} MB")
            print(f"  Tokens: {stats['memory_usage']['tokens']:.1f} MB")
        else:
            dataset = PairGenerationDataset(
                split_df=split_df,
                cluster_emb_table=cluster_emb_table,
                source_df=source_df,
                images_dir=images_dir,
                max_pos_pairs_per_query=SPLIT_LIMITS[split_name],
                pos_neg_ratio=POS_NEG_RATIO,
                hard_soft_ratio=HARD_SOFT_RATIO,
                random_seed=RANDOM_SEED
            )
        
        stats = dataset.get_batch_stats()
        print(f"{split_name.upper()} Dataset - Total: {stats['total_pairs']}, "
            f"Pos: {stats['positives']}, Hard Neg: {stats['hard_negatives']}, "
            f"Soft Neg: {stats['soft_negatives']}")

        loaders[split_name] = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=(split_name == 'train'),
            num_workers=optimal_workers,
            pin_memory=True,
            persistent_workers=True,
            prefetch_factor=4,  # Pre-fetch more batches
            drop_last=True,     # Avoid small last batches
        )

    train_loader = loaders['train']
    valid_loader = loaders['val']
    test_loader = loaders['test']

    # ---------- 2) model / optimizer / criterion ----------
    print('\n===== Starting training =====')
    print("Loading model and optimizer…")
    model = SiameseRuCLIP(
        DEVICE,
        NAME_MODEL_NAME,
        DESCRIPTION_MODEL_NAME,
        PRELOAD_MODEL_NAME,
        models_dir=DATA_PATH + RESULTS_DIR,
        dropout=DROPOUT
    )
    print("Loaded model and optimizer.")
    
    # ADD MULTI-GPU SUPPORT
    if num_gpus > 1:
        print(f"Using {num_gpus} GPUs with DataParallel")
        model = torch.nn.DataParallel(model)
    
    model = model.to(DEVICE)

    # ADD MISSING CRITERION AND OPTIMIZER
    criterion = ContrastiveLoss(margin=CONTRASTIVE_MARGIN).to(DEVICE)
    
    optimizer = torch.optim.AdamW(
        model.parameters(), 
        lr=LR, 
        weight_decay=WEIGHT_DECAY
    )

    # ---------- 3) training with threshold tracking ----------
    with tempfile.TemporaryDirectory() as tmp_ckpt_dir:
        (train_losses, val_losses,
        best_metric_val, best_weights, 
        best_threshold) = train_with_threshold_tracking(
            model, optimizer, criterion,
            EPOCHS, train_loader, valid_loader,
            print_epoch=True, device=DEVICE,
            models_dir=tmp_ckpt_dir,
            metric=BEST_CKPT_METRIC,
            source_df=source_df,
            images_dir=images_dir,
            precompute_pairs=precompute_pairs  # Set this based on your dataset type
        )

    print(f"→ Best evaluation {BEST_CKPT_METRIC}: {best_metric_val:.3f} (threshold: {best_threshold:.3f})")

    # ---------- 4) loss curves ----------
    if len(train_losses) >= 2:
        epochs_ax = list(range(2, len(train_losses) + 1))
        fig, ax = plt.subplots()
        ax.plot(epochs_ax, train_losses[1:], label='Train Loss')
        ax.plot(epochs_ax, val_losses[1:],   label='Val   Loss')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title('Training & evaluation Loss by Epoch')
        ax.legend()
        if MLFLOW_URI:
            mlflow.log_figure(fig, 'loss_by_epoch.png')
        display.clear_output(wait=True)
        display.display(fig)
        plt.close(fig)

    # ---------- 5) final TEST (using best threshold) ----------
    model.load_state_dict(best_weights)
    print(f"Using best threshold: {best_threshold:.3f}")
    
    (test_pos_acc, test_neg_acc,
     test_acc, test_f1,
     test_loss, _) = evaluation(
        model, criterion, test_loader,
        epoch='test', device=DEVICE,
        split_name='test',
        threshold=best_threshold,
        metric=BEST_CKPT_METRIC,
        source_df=source_df,
        images_dir=images_dir,
        precompute_pairs=precompute_pairs
    )

    test_metric = test_pos_acc if BEST_CKPT_METRIC == 'pos_acc' else test_f1
    print(f"Test {BEST_CKPT_METRIC}: {test_metric:.3f}")

    # ---------- 6) save checkpoint ----------
    filename = (
        f"siamese_contrastive_soft-neg_epoch={EPOCHS}_"
        f"test-f1={test_f1:.3f}_test-pos-acc={test_pos_acc:.3f}_test-neg-acc={test_neg_acc:.3f}_"
        f"{'_' + MODEL_NAME_POSTFIX if MODEL_NAME_POSTFIX else ''}_"
        f"{'_' + PRELOAD_MODEL_NAME if PRELOAD_MODEL_NAME else ''}_"
        f"best-{BEST_CKPT_METRIC:.3f}-threshold={best_threshold:.3f}.pt"
    )
    final_path = Path(DATA_PATH + RESULTS_DIR) / filename
    final_path.parent.mkdir(parents=True, exist_ok=True)
    
    # Handle DataParallel state dict saving
    if isinstance(model, torch.nn.DataParallel):
        torch.save(best_weights, final_path)
    else:
        torch.save(best_weights, final_path)
    
    print(f"Saved best‐{BEST_CKPT_METRIC} checkpoint to:\n{final_path}")

    if MLFLOW_URI:
        mlflow.log_metric("test_pos_accuracy", test_pos_acc)
        mlflow.log_metric("test_neg_accuracy", test_neg_acc)
        mlflow.log_metric("test_accuracy",     test_acc)
        mlflow.log_metric("test_f1_score",     test_f1)
        mlflow.end_run()


In [40]:
_run(
    limit_train_pos_pairs_per_query=LIMIT_TRAIN_POS_PAIRS_PER_QUERY,
    limit_val_pos_pairs_per_query=LIMIT_VAL_POS_PAIRS_PER_QUERY,
    limit_test_pos_pairs_per_query=LIMIT_TEST_POS_PAIRS_PER_QUERY,
    batch_size_per_device=BATCH_SIZE_PER_DEVICE,
    precompute_pairs=True,
    images_dir = os.path.join(DATA_PATH, IMG_DATASET_NAME)

)

Using 4 workers for data loading

=== Initializing PrecomputedPairDataset ===

1. Generating pairs metadata...
Generating pairs with parameters:
  max_pos_pairs_per_query: 2
  pos_neg_ratio: 1.0
  hard_soft_ratio: 0.5


Generating pairs: 100%|██████████| 2/2 [00:00<00:00, 819.52it/s, pos=2, hard=1, soft=1] 

Generated 8 pairs

2. Starting data precomputation...





Found 15 unique SKUs


Overall progress: 100%|██████████| 15/15 [00:00<00:00, 26.43it/s]



=== Dataset Ready ===
Total pairs: 8
Unique SKUs cached: 15

Dataset Statistics:
Total pairs: 8
Unique SKUs: 15

Pair Types:
  positive: 4
  hard_negative: 2
  soft_negative: 2

Memory Usage:
  Images: 8.6 MB
  Tokens: 0.0 MB
TRAIN Dataset - Total: 8, Pos: 4, Hard Neg: 2, Soft Neg: 2

=== Initializing PrecomputedPairDataset ===

1. Generating pairs metadata...
Generating pairs with parameters:
  max_pos_pairs_per_query: 2
  pos_neg_ratio: 1.0
  hard_soft_ratio: 0.5


Generating pairs: 100%|██████████| 2/2 [00:00<00:00, 395.20it/s, pos=2, hard=1, soft=1] 


Generated 8 pairs

2. Starting data precomputation...
Found 15 unique SKUs


Overall progress: 100%|██████████| 15/15 [00:00<00:00, 31.72it/s]



=== Dataset Ready ===
Total pairs: 8
Unique SKUs cached: 15

Dataset Statistics:
Total pairs: 8
Unique SKUs: 15

Pair Types:
  positive: 4
  hard_negative: 2
  soft_negative: 2

Memory Usage:
  Images: 8.6 MB
  Tokens: 0.0 MB
VAL Dataset - Total: 8, Pos: 4, Hard Neg: 2, Soft Neg: 2

=== Initializing PrecomputedPairDataset ===

1. Generating pairs metadata...
Generating pairs with parameters:
  max_pos_pairs_per_query: 2
  pos_neg_ratio: 1.0
  hard_soft_ratio: 0.5


Generating pairs: 100%|██████████| 2/2 [00:00<00:00, 421.64it/s, pos=2, hard=1, soft=1] 


Generated 8 pairs

2. Starting data precomputation...
Found 15 unique SKUs


Overall progress: 100%|██████████| 15/15 [00:00<00:00, 30.86it/s]



=== Dataset Ready ===
Total pairs: 8
Unique SKUs cached: 15

Dataset Statistics:
Total pairs: 8
Unique SKUs: 15

Pair Types:
  positive: 4
  hard_negative: 2
  soft_negative: 2

Memory Usage:
  Images: 8.6 MB
  Tokens: 0.0 MB
TEST Dataset - Total: 8, Pos: 4, Hard Neg: 2, Soft Neg: 2

===== Starting training =====
Loading model and optimizer…
Loaded model and optimizer.


Training Epoch 1: 100%|██████████| 8/8 [00:11<00:00,  1.40s/it]
Evaluation on val: 100%|██████████| 8/8 [00:02<00:00,  3.30it/s]


[val] Epoch 1 – loss: 1.1622, P Acc: 0.250, N Acc: 1.000, Avg Acc: 0.625, F1: 0.400, thr*: 0.641 (optimised: f1)
Epoch 1 done.
Best evaluation f1: 0.400  (thr=0.641)
→ Best evaluation f1: 0.400 (threshold: 0.641)
Using best threshold: 0.641


Evaluation on test: 100%|██████████| 8/8 [00:02<00:00,  2.71it/s]


[test] Epoch test – loss: 1.1622, P Acc: 0.250, N Acc: 1.000, Avg Acc: 0.625, F1: 0.400, thr*: 0.641 (optimised: f1)
Test f1: 0.400
Saved best‐f1 checkpoint to:
data/train_results/siamese_contrastive_test-f1=0.400_test-pos-acc=0.250_test-neg-acc=1.000_splitting-by-query_cc12m_rubert_tiny_ep_1.pt_best-0.4-threshold=0.641.pt
