# Installs & tokens

In [1]:
%%capture
try:
    import dotenv
except ImportError:
    !pip install python-dotenv

In [2]:
# Log into huggingface via Kaggle Secrets or .env

import os
from dotenv import load_dotenv
import huggingface_hub

try:
    from kaggle_secrets import UserSecretsClient

    user_secrets = UserSecretsClient()
    HF_TOKEN = user_secrets.get_secret("HF_TOKEN")
except ModuleNotFoundError:
    print("Not Kaggle environment. Skipping Kaggle secrets.")
    print("Trying to load HF_TOKEN from .env.")
    load_dotenv()
    HF_TOKEN = os.getenv("HF_TOKEN")
    print("Success!")

huggingface_hub.login(token=HF_TOKEN)

Not Kaggle environment. Skipping Kaggle secrets.
Trying to load HF_TOKEN from .env.
Success!


Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


# Choose notebook parameters

In [53]:
import torch

## CHOOSE MODEL PARAMETERS #################################################

MODEL_NAME_POSTFIX='splitting-by-query'
NAME_MODEL_NAME = 'cointegrated/rubert-tiny' # 'DeepPavlov/distilrubert-tiny-cased-conversational-v1'
DESCRIPTION_MODEL_NAME = 'cointegrated/rubert-tiny'

DATA_PATH = 'data/'
RESULTS_DIR = 'train_results/'

# BATCH_SIZE=60 # uses 14.5GiB of 1 GPU
# NUM_WORKERS=2 # TODO: use multiple GPU, tune number of workers
# SMOKE_TEST_BATCHES=None
# EPOCHS=10 # epochs > 8 => overfit; NOTE: can train for longer since we take best validation checkpoint anyway

BATCH_SIZE=5
NUM_WORKERS=0
SMOKE_TEST_BATCHES=10
EPOCHS=2

PRELOAD_MODEL_NAME = 'cc12m_rubert_tiny_ep_1.pt' # preload ruclip
# PRELOAD_MODEL_NAME = None

POS_WEIGHT = 4.0 # TODO: infer from data

# USE_ALL_TRAIN_PAIRS = False
# MAX_SAMPLES_PER_EPOCH = None

USE_ALL_TRAIN_PAIRS = True
MAX_SAMPLES_PER_EPOCH = 2_500
# MAX_SAMPLES_PER_EPOCH = 2_500 * 12

DROPOUT = 0.5
# DROPOUT = None

# BEST_CKPT_METRIC = 'f1'
BEST_CKPT_METRIC = 'pos_acc'

VALIDATION_SPLIT=.05
TEST_SPLIT=.1
RANDOM_SEED=42
LR=9e-5
MOMENTUM=0.9
WEIGHT_DECAY=1e-2
CONTRASTIVE_MARGIN=1.5
CONTRASTIVE_THRESHOLD=0.3
SHEDULER_PATIENCE=3 # in epochs

DEVICE='cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
## CHOOSE DATA #########################################################

DATA_PATH=  'data/'
SOURCE_TABLE_NAME = 'tables_OZ_geo_5500/processed/OZ_geo_5500.csv'

# --- Load source_df and pairwise_mapping_df from Parquet ---
SOURCE_TABLE_NAME = 'tables_OZ_geo_5500/processed/OZ_geo_5500.csv'
PAIRWISE_TABLE_NAME = 'tables_OZ_geo_5500/processed/regex-pairwise-groups/regex-pairwise-groups_num-queries=20_patterns-dict-hash=6dbf9b3ef9568e60cd959f87be7e3b26.parquet'
IMG_DATASET_NAME = 'images_OZ_geo_5500'

In [5]:
## LOGGING PARAMS ######################################################################

# MLFLOW_URI = "http://176.56.185.96:5000"
# MLFLOW_URI = "http://localhost:5000"
MLFLOW_URI = None

MLFLOW_EXPERIMENT = "siamese/1fold"

TELEGRAM_TOKEN = None
# TELEGRAM_TOKEN = '' # set token to get notifications

# Definitions

In [6]:
# Imports
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from transformers import AutoModel, AutoTokenizer

import mlflow
from mlflow.models import infer_signature

from timm import create_model
import numpy as np
import pandas as pd
import os
import torch
from torch import nn
from torch import optim, Tensor
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import transforms
# from torchinfo import summary
# import transformers
# from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer,\
#         get_linear_schedule_with_warmup
from transformers import AutoModel, AutoTokenizer

import cv2

from PIL import Image
from tqdm.auto import tqdm

# import json
# from itertools import product

# import datasets
# from datasets import Dataset, concatenate_datasets
# import argparse
import requests

# from io import BytesIO
# from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import balanced_accuracy_score, roc_auc_score, f1_score
import matplotlib.pyplot as plt
from IPython.display import clear_output
# import more_itertools

from sklearn.model_selection import train_test_split
import torch
import torch.nn.functional as F
from pathlib import Path
from tqdm import tqdm
from sklearn.metrics import f1_score
import mlflow
from torch.optim.lr_scheduler import ReduceLROnPlateau
import matplotlib.pyplot as plt
import tempfile

In [7]:
def make_tg_report(text, token=None) -> None:
    method = 'sendMessage'
    chat_id = 324956476
    _ = requests.post(
            url='https://api.telegram.org/bot{0}/{1}'.format(token, method),
            data={'chat_id': chat_id, 'text': text} 
        ).json()

In [8]:
class RuCLIPtiny(nn.Module):
    def __init__(self, name_model_name):
        super().__init__()
        self.visual = create_model('convnext_tiny',
                                   pretrained=False, # TODO: берём претрейн
                                   num_classes=0,
                                   in_chans=3)  # out 768

        self.transformer = AutoModel.from_pretrained(name_model_name)
        name_model_output_shape = self.transformer.config.hidden_size  # dynamically get hidden size
        self.final_ln = torch.nn.Linear(name_model_output_shape, 768)  # now uses the transformer hidden size
        self.logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    @property
    def dtype(self):
        return self.visual.stem[0].weight.dtype

    def encode_image(self, image):
        return self.visual(image.type(self.dtype))

    def encode_text(self, input_ids, attention_mask):
        x = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        x = x.last_hidden_state[:, 0, :]
        x = self.final_ln(x)
        return x

    def forward(self, image, input_ids, attention_mask):
        image_features = self.encode_image(image)
        text_features = self.encode_text(input_ids, attention_mask)

        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        return logits_per_image, logits_per_text

In [9]:
def get_transform():
    return transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        _convert_image_to_rgb,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]), ])

def _convert_image_to_rgb(image):
    return image.convert("RGB")

class Tokenizers:
    def __init__(self):
        self.name_tokenizer = AutoTokenizer.from_pretrained(NAME_MODEL_NAME)
        self.desc_tokenizer = AutoTokenizer.from_pretrained(DESCRIPTION_MODEL_NAME)

    def tokenize_name(self, texts, max_len=77):
        tokenized = self.name_tokenizer.batch_encode_plus(texts,
                                                     truncation=True,
                                                     add_special_tokens=True,
                                                     max_length=max_len,
                                                     padding='max_length',
                                                     return_attention_mask=True,
                                                     return_tensors='pt')
        return torch.stack([tokenized["input_ids"], tokenized["attention_mask"]])
    
    def tokenize_description(self, texts, max_len=77):
        tokenized = self.desc_tokenizer(texts,
                                        truncation=True,
                                        add_special_tokens=True,
                                        max_length=max_len,
                                        padding='max_length',
                                        return_attention_mask=True,
                                        return_tensors='pt')
        return torch.stack([tokenized["input_ids"], tokenized["attention_mask"]])

class SiameseRuCLIPDataset(torch.utils.data.Dataset):
    def __init__(self, df=None, labels=None, df_path=None, images_dir=DATA_PATH+'images/'):
        # loads data either from path using `df_path` or directly from `df` argument
        self.df = pd.read_csv(df_path) if df_path is not None else df
        self.labels = labels
        self.images_dir = images_dir
        self.tokenizers = Tokenizers()
        self.transform = get_transform()
        # 
        self.max_len = 77
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        name_tokens = self.tokenizers.tokenize_name([str(row.name_first), 
                                               str(row.name_second)], max_len=self.max_len)
        name_first = name_tokens[:, 0, :] # [input_ids, attention_mask]
        name_second = name_tokens[:, 1, :]
        desc_tokens = self.tokenizers.tokenize_description([str(row.description_first), 
                                               str(row.description_second)])
        desc_first = desc_tokens[:, 0, :] # [input_ids, attention_mask]
        desc_second = desc_tokens[:, 1, :]
        im_first = cv2.imread(os.path.join(self.images_dir, row.image_name_first))
        im_first = cv2.cvtColor(im_first, cv2.COLOR_BGR2RGB)
        im_first = Image.fromarray(im_first)
        im_first = self.transform(im_first)
        im_second = cv2.imread(os.path.join(self.images_dir, row.image_name_second))
        im_second = cv2.cvtColor(im_second, cv2.COLOR_BGR2RGB)
        im_second = Image.fromarray(im_second)
        im_second = self.transform(im_second)
        label = self.labels[idx]
        return im_first, name_first, desc_first, im_second, name_second, desc_second, label

    def __len__(self,):
        return len(self.df)

In [10]:
def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

class SiameseRuCLIP(nn.Module):
    def __init__(self,
                 device: str,
                 name_model_name: str,
                 description_model_name: str,
                 preload_model_name: str = None,
                 models_dir: str = None,
                 dropout: float = None):
        """
        Initializes the SiameseRuCLIP model.
        Required parameters:
          - models_dir: directory containing saved checkpoints.
          - name_model_name: model name for text (name) branch.
          - description_model_name: model name for description branch.
        """
        super().__init__()
        device = torch.device(device)

        # Initialize RuCLIPtiny
        self.ruclip = RuCLIPtiny(name_model_name)
        if preload_model_name is not None:
            std = torch.load(
                os.path.join(models_dir, preload_model_name),
                weights_only=True,
                map_location=device
            )
            self.ruclip.load_state_dict(std)
            self.ruclip.eval()
        self.ruclip = self.ruclip.to(device)

        # Initialize the description transformer
        self.description_transformer = AutoModel.from_pretrained(description_model_name)
        self.description_transformer = self.description_transformer.to(device)

        # Determine dimensionality
        vision_dim = self.ruclip.visual.num_features
        name_dim = self.ruclip.final_ln.out_features
        desc_dim = self.description_transformer.config.hidden_size
        self.hidden_dim = vision_dim + name_dim + desc_dim
        self.dropout = dropout

        # Define MLP head with optional dropout
        layers = [
            nn.Linear(self.hidden_dim, self.hidden_dim // 2),
            nn.ReLU(),
            *( [nn.Dropout(self.dropout)] if self.dropout is not None else [] ),
            nn.Linear(self.hidden_dim // 2, self.hidden_dim // 4),
        ]
        self.head = nn.Sequential(*layers).to(device)


    def encode_image(self, image):
        return self.ruclip.encode_image(image)

    def encode_name(self, name):
        return self.ruclip.encode_text(name[:, 0, :], name[:, 1, :])

    def encode_description(self, desc):
        last_hidden_states = self.description_transformer(desc[:, 0, :], desc[:, 1, :]).last_hidden_state
        attention_mask = desc[:, 1, :]
        return average_pool(last_hidden_states, attention_mask)

    def get_final_embedding(self, im, name, desc):
        image_emb = self.encode_image(im)
        name_emb = self.encode_name(name)
        desc_emb = self.encode_description(desc)

        # Concatenate the embeddings and forward through the head
        combined_emb = torch.cat([image_emb, name_emb, desc_emb], dim=1)
        final_embedding = self.head(combined_emb)
        return final_embedding

    def forward(self, im1, name1, desc1, im2, name2, desc2):
        out1 = self.get_final_embedding(im1, name1, desc1)
        out2 = self.get_final_embedding(im2, name2, desc2)
        return out1, out2

In [11]:
from torch.utils.data import WeightedRandomSampler
# Remove pos_weight from ContrastiveLoss
class ContrastiveLoss(torch.nn.Module):
    def __init__(self, margin: float = 1.5):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        d = F.pairwise_distance(output1, output2)
        pos = (1 - label) * d.pow(2)
        neg = label * F.relu(self.margin - d).pow(2)
        return (pos + neg).mean()  # No weighting needed with balanced data

# Use WeightedRandomSampler for training
def create_balanced_train_loader(train_dataset, batch_size=32):
    """Create balanced training loader using weighted sampling."""
    labels = [pair['label'] for pair in train_dataset.pairs]
    class_counts = np.bincount(labels)
    class_weights = 1.0 / class_counts
    sample_weights = [class_weights[label] for label in labels]
    
    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )
    
    return DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=0,
        pin_memory=False
    )

In [12]:
def evaluate_pair(output1, output2, target, threshold):
    euclidean_distance = F.pairwise_distance(output1, output2)
    # меньше границы, там где будет True — конкуренты
    cond = euclidean_distance < threshold
    pos_sum = 0
    neg_sum = 0
    pos_acc = 0
    neg_acc = 0

    for i in range(len(cond)):
        # 1 значит не конкуренты
        if target[i]:
            neg_sum+=1
            # 0 в cond значит дальше друг от друга чем threshold
            if not cond[i]:
                neg_acc+=1
        elif not target[i]:
            pos_sum+=1
            if cond[i]:
                pos_acc+=1

    return pos_acc, pos_sum, neg_acc, neg_sum

def predict(out1, out2, threshold=CONTRASTIVE_THRESHOLD):
    # вернёт 1 если похожи
    return F.pairwise_distance(out1, out2) < threshold

# Prepare data

## Download data from HF

In [19]:
# Download models' weights & text/image datasets

from huggingface_hub import snapshot_download
from pathlib import Path

REPO_ID = "INDEEPA/clip-siamese"
LOCAL_DIR = Path("data/train_results")
LOCAL_DIR.mkdir(parents=True, exist_ok=True)

snapshot_download(
    repo_id=REPO_ID,
    repo_type='dataset',
    local_dir='data',
    allow_patterns=[
        "train_results/cc12m*.pt",
        SOURCE_TABLE_NAME, PAIRWISE_TABLE_NAME,
        f"{IMG_DATASET_NAME}.zip"
    ],
)

# The following shell command was removed for script compatibility:
# !unzip -n -q data/{IMG_DATASET_NAME}.zip -d data/

# If you need to unzip in Python, use:
# import zipfile
# with zipfile.ZipFile(f"data/{IMG_DATASET_NAME}.zip", 'r') as zip_ref:
#     zip_ref.extractall("data/")


Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

'/home/anton/marketplace/clip-siamese/data'

In [20]:
source_df = pd.read_csv(DATA_PATH + SOURCE_TABLE_NAME)
pairwise_mapping_df = pd.read_parquet(DATA_PATH + PAIRWISE_TABLE_NAME)
pairwise_mapping_df.sku_query.nunique()

20

# Cluster soft negatives

In [21]:
# Suggest the correct path to the embedding file based on the context and previous file saving logic
# CHOSEN_EMBEDDING_FILE = 'embeddings/OZ_geo_5500/OZ_geo_5500_name-and-description_embeddings_num-rows=2.parquet'
CHOSEN_EMBEDDING_FILE = 'embeddings/OZ_geo_5500/OZ_geo_5500_name-and-description_embeddings_num-rows=5562.parquet'

In [22]:
from huggingface_hub import hf_hub_download
import pandas as pd

# Download the chosen embedding file from HuggingFace Hub to DATA_PATH
from pathlib import Path

downloaded_emb_file = hf_hub_download(
    repo_id="INDEEPA/clip-siamese",
    repo_type="dataset",
    filename=CHOSEN_EMBEDDING_FILE,
    local_dir=DATA_PATH,
)

print(f"Downloaded embedding file to:\n{downloaded_emb_file}")
emb_table = pd.read_parquet(downloaded_emb_file)
emb_table.head()

Downloaded embedding file to:
data/embeddings/OZ_geo_5500/OZ_geo_5500_name-and-description_embeddings_num-rows=5562.parquet


Unnamed: 0,sku,name_desc_emb
0,1871769771,"[-0.020089346915483475, -0.05487045273184776, ..."
1,1679550303,"[-0.00418242160230875, -0.04088427498936653, 0..."
2,1200553001,"[-0.023978281766176224, -0.05447990819811821, ..."
3,922231521,"[-0.024106157943606377, -0.053567297756671906,..."
4,922230517,"[-0.02229023538529873, -0.05309479311108589, -..."


In [23]:
from sklearn.cluster import HDBSCAN
import numpy as np

# Prepare the embeddings as a numpy array
embeddings = np.stack(emb_table['name_desc_emb'].values)

# Run HDBSCAN clustering using sklearn's implementation
# Use coarser clustering: increase min_cluster_size and set min_samples for more robust, larger clusters
clusterer = HDBSCAN(
    min_samples=2,
    # min_cluster_size=20,
    # min_samples=10,
    metric='cosine'
)
cluster_labels = clusterer.fit_predict(embeddings)

# Add cluster labels to the emb_table and assign to cluster_emb_table
cluster_emb_table = emb_table.copy()
cluster_emb_table['cluster_id'] = cluster_labels

# Print cluster label counts
print("Cluster label counts:")
display(cluster_emb_table['cluster_id'].value_counts().to_frame().T)

Cluster label counts:


cluster,-1,396,1,98,216,389,383,159,495,269,...,429,418,12,303,138,424,373,490,428,459
count,1244,64,47,42,38,30,29,27,26,25,...,5,5,5,5,5,5,5,5,5,5


In [24]:
# Print cluster ids with size > N
N = 10  # You can change N to any desired threshold
cluster_counts = cluster_emb_table['cluster_id'].value_counts()
large_clusters = cluster_counts[cluster_counts > N].to_frame()
print(f"Cluster IDs with size > {N}:")
display(large_clusters.T)


Cluster IDs with size > 10:


cluster,-1,396,1,98,216,389,383,159,495,269,...,440,374,414,27,445,69,355,301,182,207
count,1244,64,47,42,38,30,29,27,26,25,...,11,11,11,11,11,11,11,11,11,11


In [25]:
# Print SKUs for a given CLUSTER_ID
CLUSTER_ID = 80  # Change this to the desired cluster id

skus_in_cluster = cluster_emb_table.loc[cluster_emb_table['cluster_id'] == CLUSTER_ID, 'sku']
print(f"SKUs in cluster {CLUSTER_ID}:")
display(skus_in_cluster.tolist()[:10])

SKUs in cluster 80:


[1707837031,
 1706808656,
 1706808406,
 1589325642,
 1589325623,
 1589325615,
 1589324360,
 1589323257,
 1589310326]

# Make pairwise dataset

In [26]:
# split_query_groups

def split_query_groups(
    mapping_df: pd.DataFrame,
    test_size: float = 0.2,
    val_size: float = 0.05,
    random_state: int = 42
):
    """
    For each query SKU group, splits:
      - all query SKUs into test split (always, not by %)
      - positives, hard negatives, soft negatives (excluding query SKU) into test/val/train
    Returns: dict with keys 'train', 'val', 'test', each a DataFrame with columns:
      ['sku_query', 'split', 'sku_pos', 'sku_hard_neg', 'sku_soft_neg']
    """
    rng = np.random.default_rng(random_state)
    split_rows = []

    for _, row in mapping_df.iterrows():
        q = row['sku_query']
        pos = set(row['sku_pos']) - {q}
        hard_neg = set(row['sku_hard_neg']) - {q}
        soft_neg = set(row['sku_soft_neg']) - {q}

        def split_list(lst, test_frac, val_frac):
            lst = np.array(list(lst))
            n = len(lst)
            n_test = int(np.ceil(test_frac * n))
            n_val = int(np.ceil(val_frac * n))
            idx = rng.permutation(n)
            test_idx = idx[:n_test]
            val_idx = idx[n_test:n_test+n_val]
            train_idx = idx[n_test+n_test+n_val:]
            return lst[train_idx].tolist(), lst[val_idx].tolist(), lst[test_idx].tolist()

        pos_train, pos_val, pos_test = split_list(pos, test_size, val_size)
        hard_train, hard_val, hard_test = split_list(hard_neg, test_size, val_size)
        soft_train, soft_val, soft_test = split_list(soft_neg, test_size, val_size)

        # Only add the actual splits, not the empty test row
        split_rows.append({
            'sku_query': q,
            'split': 'train',
            'sku_pos': pos_train,
            'sku_hard_neg': hard_train,
            'sku_soft_neg': soft_train
        })
        split_rows.append({
            'sku_query': q,
            'split': 'val',
            'sku_pos': pos_val,
            'sku_hard_neg': hard_val,
            'sku_soft_neg': soft_val
        })
        split_rows.append({
            'sku_query': q,
            'split': 'test',
            'sku_pos': pos_test,
            'sku_hard_neg': hard_test,
            'sku_soft_neg': soft_test
        })

    split_df = pd.DataFrame(split_rows)
    split_dict = {
        split: split_df[split_df['split'] == split].reset_index(drop=True)
        for split in ['train', 'val', 'test']
    }
    return split_dict

In [27]:
# # check positives intersections
# # TODO: merge sku having many intersections; disentangle common positives for skus having small common number of positives

# import numpy as np
# import pandas as pd

# # Build a list of queries and a mapping from query to set of positives
# queries = pairwise_mapping_df['sku_query'].tolist()
# sku_to_pos = {
#     row['sku_query']: set(row['sku_pos'])
#     for _, row in pairwise_mapping_df.iterrows()
# }

# num_queries = len(queries)
# intersection_matrix = np.zeros((num_queries, num_queries), dtype=int)

# for i, q1 in enumerate(queries):
#     pos1 = sku_to_pos[q1]
#     for j, q2 in enumerate(queries):
#         pos2 = sku_to_pos[q2]
#         intersection_matrix[i, j] = len(pos1 & pos2)

# # Mask diagonal and above with -1
# mask = np.triu(np.ones_like(intersection_matrix, dtype=bool))
# intersection_matrix[mask] = -1

# # Optionally, wrap as DataFrame for readability
# intersection_df = pd.DataFrame(
#     intersection_matrix, 
#     index=queries, 
#     columns=queries
# )
# intersection_df

In [28]:
splits_dataset = split_query_groups(
    pairwise_mapping_df,
    test_size=0.1,
    val_size=0.1,
    random_state=42
)

splits_dataset['test'].head(1)

Unnamed: 0,sku_query,split,sku_pos,sku_hard_neg,sku_soft_neg
0,1871769771,test,[467420540],"[1418084594, 1573142945, 1536520050, 1573135817]","[1899881468, 1290396077, 1597431764, 165269677..."


In [29]:
# Display stats for each split: number of unique SKUs in pos/hard/soft/total as a table
import pandas as pd
from IPython.display import display

split_stats = []
for split_name, split_df in splits_dataset.items():
    pos_skus = set()
    hard_neg_skus = set()
    soft_neg_skus = set()
    for _, row in split_df.iterrows():
        if isinstance(row['sku_pos'], list):
            pos_skus.update(row['sku_pos'])
        if isinstance(row['sku_hard_neg'], list):
            hard_neg_skus.update(row['sku_hard_neg'])
        if isinstance(row['sku_soft_neg'], list):
            soft_neg_skus.update(row['sku_soft_neg'])
    total_skus = pos_skus | hard_neg_skus | soft_neg_skus
    split_stats.append({
        "split": split_name,
        "unique_pos_skus": len(pos_skus),
        "unique_hard_neg_skus": len(hard_neg_skus),
        "unique_soft_neg_skus": len(soft_neg_skus),
        "unique_total_skus": len(total_skus)
    })

stats_df = pd.DataFrame(split_stats).set_index("split")
display(stats_df)


Unnamed: 0_level_0,unique_pos_skus,unique_hard_neg_skus,unique_soft_neg_skus,unique_total_skus
split,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
train,547,1550,5562,5562
val,121,449,4795,4898
test,124,444,4802,4917


In [30]:
# print_split_summary

def print_split_summary(split_dict):
    summary = []
    for split_name, split_df in split_dict.items():
        pos_counts = split_df['sku_pos'].apply(lambda x: len(x) if isinstance(x, list) else 0)
        hard_counts = split_df['sku_hard_neg'].apply(lambda x: len(x) if isinstance(x, list) else 0)
        soft_counts = split_df['sku_soft_neg'].apply(lambda x: len(x) if isinstance(x, list) else 0)
        total_per_query = pos_counts + hard_counts + soft_counts
        summary.append({
            'split': split_name,
            'avg_pos_per_query': pos_counts.mean(),
            'avg_hard_neg_per_query': hard_counts.mean(),
            'avg_soft_neg_per_query': soft_counts.mean(),
            'avg_total_per_query': total_per_query.mean(),
            'total_pos': pos_counts.sum(),
            'total_hard_neg': hard_counts.sum(),
            'total_soft_neg': soft_counts.sum(),
            'total': pos_counts.sum() + hard_counts.sum() + soft_counts.sum(),
            'num_queries': len(split_df)
        })
    import pandas as pd
    df = pd.DataFrame(summary).set_index('split')
    print("Average number of pos/hard/soft/total per query per split:")
    display(df[['avg_pos_per_query', 'avg_hard_neg_per_query', 'avg_soft_neg_per_query', 'avg_total_per_query']])
    print("\nTotal number of pos/hard/soft/total per split:")
    display(df[['total_pos', 'total_hard_neg', 'total_soft_neg', 'total']])

print_split_summary(splits_dataset)

Average number of pos/hard/soft/total per query per split:


Unnamed: 0_level_0,avg_pos_per_query,avg_hard_neg_per_query,avg_soft_neg_per_query,avg_total_per_query
split,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
train,40.15,174.6,3673.35,3888.1
val,6.5,25.7,525.45,557.65
test,6.5,25.75,525.45,557.7



Total number of pos/hard/soft/total per split:


Unnamed: 0_level_0,total_pos,total_hard_neg,total_soft_neg,total
split,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
train,803,3492,73467,77762
val,130,514,10509,11153
test,130,515,10509,11154


In [31]:
from torch.utils.data import Dataset

from torch.utils.data import Dataset

class ClusterBasedPairwiseDataset(Dataset):
    def __init__(self, split_df, source_df, cluster_emb_table=None, images_dir=None,
                 n_pos=None, n_hard_neg=5, n_soft_neg=5, use_all_pos=True,
                 cluster_sampling=True, n_soft_neg_per_cluster=1, lazy_loading=True):
        """
        Memory-optimized dataset with lazy loading option.
        
        Args:
            lazy_loading: If True, generate pairs on-the-fly instead of pre-generating all
            n_pos: Number of positives to sample (None = use all if use_all_pos=True)
            n_hard_neg: Number of hard negatives to sample (None = use all available)
            n_soft_neg: Number of soft negatives to sample (None = use all available)
        """
        self._current_idx = 0  # Track current index for evaluation
        self.split_df = split_df.reset_index(drop=True)
        self.source_df = source_df.set_index('sku')
        self.images_dir = images_dir
        self.cluster_emb_table = cluster_emb_table
        self.n_pos = n_pos
        self.n_hard_neg = n_hard_neg
        self.n_soft_neg = n_soft_neg
        self.use_all_pos = use_all_pos
        self.cluster_sampling = cluster_sampling
        self.n_soft_neg_per_cluster = n_soft_neg_per_cluster
        self.lazy_loading = lazy_loading
        
        # Pre-compute cluster mappings if cluster sampling is enabled
        if self.cluster_sampling and cluster_emb_table is not None:
            self._build_cluster_mappings()
        
        if self.lazy_loading:
            # Only generate pair indices/metadata
            self._generate_pair_indices()
        else:
            # Pre-generate all pairs (original behavior)
            self._generate_all_pairs()

    def _build_cluster_mappings(self):
        """Build mappings from cluster to SKUs for efficient sampling."""
        self.sku_to_cluster = dict(zip(self.cluster_emb_table['sku'], self.cluster_emb_table['cluster_id']))
        self.cluster_to_skus = {}
        for _, row in self.cluster_emb_table.iterrows():
            cluster = row['cluster_id']
            sku = row['sku']
            if cluster not in self.cluster_to_skus:
                self.cluster_to_skus[cluster] = []
            self.cluster_to_skus[cluster].append(sku)

    def _generate_pair_indices(self):
        """Generate only the metadata needed to create pairs on-the-fly."""
        self.pair_indices = []
        
        for split_idx, row in self.split_df.iterrows():
            q = row['sku_query']
            pos = row['sku_pos'] if isinstance(row['sku_pos'], list) else []
            hard_neg = row['sku_hard_neg'] if isinstance(row['sku_hard_neg'], list) else []
            soft_neg = row['sku_soft_neg'] if isinstance(row['sku_soft_neg'], list) else []

            # Store metadata for on-the-fly generation
            self.pair_indices.append({
                'type': 'query_data',
                'split_idx': split_idx,
                'query_sku': q,
                'pos_skus': pos,
                'hard_neg_skus': hard_neg,
                'soft_neg_skus': soft_neg
            })

    def _sample_soft_negatives_by_cluster(self, soft_neg_skus):
        """Sample soft negatives from each cluster separately."""
        if not self.cluster_sampling or self.cluster_emb_table is None:
            # If n_soft_neg is None, use all available
            if self.n_soft_neg is None:
                return soft_neg_skus
            elif len(soft_neg_skus) <= self.n_soft_neg:
                return soft_neg_skus
            else:
                return np.random.choice(soft_neg_skus, size=self.n_soft_neg, replace=False).tolist()
        
        # Group soft negatives by cluster
        cluster_groups = {}
        for sku in soft_neg_skus:
            cluster = self.sku_to_cluster.get(sku, -1)
            if cluster not in cluster_groups:
                cluster_groups[cluster] = []
            cluster_groups[cluster].append(sku)
        
        # Sample from each cluster
        sampled_soft_negs = []
        for cluster, skus in cluster_groups.items():
            if len(skus) == 0:
                continue
            
            n_to_sample = min(self.n_soft_neg_per_cluster, len(skus))
            if n_to_sample > 0:
                if len(skus) <= n_to_sample:
                    sampled_soft_negs.extend(skus)
                else:
                    sampled_soft_negs.extend(np.random.choice(skus, size=n_to_sample, replace=False))
        
        # If n_soft_neg is None, return all sampled
        if self.n_soft_neg is None:
            return sampled_soft_negs
            
        # If we need more samples to reach n_soft_neg, sample randomly from all
        remaining_needed = self.n_soft_neg - len(sampled_soft_negs)
        if remaining_needed > 0 and soft_neg_skus:
            remaining_skus = [sku for sku in soft_neg_skus if sku not in sampled_soft_negs]
            if remaining_skus:
                additional_samples = min(remaining_needed, len(remaining_skus))
                if len(remaining_skus) <= additional_samples:
                    sampled_soft_negs.extend(remaining_skus)
                else:
                    sampled_soft_negs.extend(np.random.choice(remaining_skus, size=additional_samples, replace=False))
        
        return sampled_soft_negs[:self.n_soft_neg] if self.n_soft_neg is not None else sampled_soft_negs

    def _generate_all_pairs(self):
        """Pre-generate all pairs for the dataset (memory-intensive but faster access)."""
        self.pairs = []
        
        for _, row in self.split_df.iterrows():
            q = row['sku_query']
            pos = row['sku_pos'] if isinstance(row['sku_pos'], list) else []
            hard_neg = row['sku_hard_neg'] if isinstance(row['sku_hard_neg'], list) else []
            soft_neg = row['sku_soft_neg'] if isinstance(row['sku_soft_neg'], list) else []

            # Sample positives
            if self.use_all_pos and self.n_pos is None:
                pos_sample = pos
            else:
                n_pos_actual = self.n_pos if self.n_pos is not None else len(pos)
                if len(pos) <= n_pos_actual:
                    pos_sample = pos
                else:
                    pos_sample = np.random.choice(pos, size=n_pos_actual, replace=False).tolist()

            # Sample hard negatives - handle None case
            if self.n_hard_neg is None:
                # Use all available hard negatives
                hard_sample = hard_neg
            elif len(hard_neg) <= self.n_hard_neg:
                hard_sample = hard_neg
            else:
                hard_sample = np.random.choice(hard_neg, size=self.n_hard_neg, replace=False).tolist()

            # Sample soft negatives with cluster-based sampling - handle None case
            soft_sample = self._sample_soft_negatives_by_cluster(soft_neg)

            # Create pairs
            for pos_sku in pos_sample:
                self.pairs.append({
                    'sku_first': q,
                    'sku_second': pos_sku,
                    'label': 0
                })

            for hard_sku in hard_sample:
                self.pairs.append({
                    'sku_first': q,
                    'sku_second': hard_sku,
                    'label': 1
                })

            for soft_sku in soft_sample:
                self.pairs.append({
                    'sku_first': q,
                    'sku_second': soft_sku,
                    'label': 1
                })

        # Convert to DataFrame and join with source data
        if self.pairs:
            pairs_df = pd.DataFrame(self.pairs)
            
            # Join with source_df to get all required columns
            first_data = self.source_df.loc[pairs_df['sku_first']].reset_index()
            first_data.columns = [f"{col}_first" if col != 'sku' else 'sku_first' for col in first_data.columns]
            
            second_data = self.source_df.loc[pairs_df['sku_second']].reset_index()
            second_data.columns = [f"{col}_second" if col != 'sku' else 'sku_second' for col in second_data.columns]
            
            self.pairs_df = pd.concat([
                first_data.reset_index(drop=True),
                second_data.reset_index(drop=True),
                pairs_df[['label']].reset_index(drop=True)
            ], axis=1)
            
            self.siamese_dataset = SiameseRuCLIPDataset(
                self.pairs_df.drop(columns=['label']),
                self.pairs_df['label'].values,
                images_dir=self.images_dir
            )
        else:
            self.pairs_df = pd.DataFrame()
            self.siamese_dataset = None

    def _generate_pair_on_demand(self, idx):
        """Generate a single pair on-demand (for lazy loading)."""
        # This is a simplified version - you'd need to implement the logic
        # to map idx to a specific (query, positive/negative) combination
        # and generate the pair data on-the-fly
        # This is more complex to implement but saves memory
        raise NotImplementedError("Lazy loading not fully implemented yet")

    def __len__(self):
        if self.lazy_loading:
            # Calculate total pairs without generating them
            total_pairs = 0
            for _, row in self.split_df.iterrows():
                pos_count = len(row['sku_pos']) if isinstance(row['sku_pos'], list) else 0
                hard_count = len(row['sku_hard_neg']) if isinstance(row['sku_hard_neg'], list) else 0
                soft_count = len(row['sku_soft_neg']) if isinstance(row['sku_soft_neg'], list) else 0
                
                # Apply limits only if they're not None
                if self.n_hard_neg is not None:
                    hard_count = min(hard_count, self.n_hard_neg)
                if self.n_soft_neg is not None:
                    soft_count = min(soft_count, self.n_soft_neg)
                    
                total_pairs += pos_count + hard_count + soft_count
            return total_pairs
        else:
            return len(self.pairs) if self.pairs else 0

    def __getitem__(self, idx):
        self._current_idx = idx  # Store current index
        if self.lazy_loading:
            return self._generate_pair_on_demand(idx)
        else:
            if self.siamese_dataset is None:
                raise IndexError("No pairs available")
            return self.siamese_dataset[idx]
    def get_pair_info(self, idx):
        """Get query and target SKU information for a given index."""
        pair = self.pairs[idx]
        return {
            'query_sku': pair['sku_first'],
            'target_sku': pair['sku_second'],
            'label': pair['label']
        }

# Run training

In [49]:
def fast_comprehensive_evaluation_limited(model, limited_batches, device, k_values=[1, 3, 5, 10]):
    """
    Fast ranking evaluation on limited data for smoke testing.
    
    Args:
        model: The trained model
        limited_batches: List of limited batches for evaluation
        device: Device to run evaluation on
        k_values: List of k values for precision@k and recall@k
    
    Returns:
        Dictionary containing evaluation metrics
    """
    import pandas as pd
    import numpy as np
    from sklearn.metrics import precision_recall_curve, average_precision_score
    import torch.nn.functional as F
    
    model.eval()
    all_data = []
    
    with torch.no_grad():
        for batch_idx, batch in enumerate(limited_batches):
            # Handle batch unpacking (with or without SKUs)
            if len(batch) == 9:  # Enhanced dataset with SKUs
                im1, n1, d1, im2, n2, d2, labels, query_skus, target_skus = batch
                im1, n1, d1, im2, n2, d2, labels = [t.to(device) for t in [im1, n1, d1, im2, n2, d2, labels]]
                query_skus = [str(sku) for sku in query_skus]
                target_skus = [str(sku) for sku in target_skus]
            else:  # Original dataset without SKUs
                im1, n1, d1, im2, n2, d2, labels = [t.to(device) for t in batch]
                batch_size = labels.size(0)
                query_skus = [f"query_{batch_idx}_{i}" for i in range(batch_size)]
                target_skus = [f"target_{batch_idx}_{i}" for i in range(batch_size)]

            out1, out2 = model(im1, n1, d1, im2, n2, d2)
            distances = F.pairwise_distance(out1, out2)
            similarities = 1 / (1 + distances)
            
            # Convert to CPU for processing
            distances_cpu = distances.cpu().numpy()
            similarities_cpu = similarities.cpu().numpy()
            labels_cpu = labels.cpu().numpy()
            
            # Store batch results
            batch_size = len(labels_cpu)
            for i in range(batch_size):
                all_data.append({
                    'query_sku': query_skus[i],
                    'target_sku': target_skus[i],
                    'distance': distances_cpu[i],
                    'similarity': similarities_cpu[i],
                    'label': labels_cpu[i]
                })
    
    if not all_data:
        # Return empty metrics if no data
        return {
            'mean_average_precision': 0.0,
            'precision_at_k': {k: 0.0 for k in k_values},
            'recall_at_k': {k: 0.0 for k in k_values},
            'f1_score': 0.0,
            'balanced_accuracy': 0.0,
            'specificity': 0.0,
            'optimal_threshold': 0.5
        }
    
    # Convert to DataFrame for easier processing
    df = pd.DataFrame(all_data)
    
    # Compute global metrics
    all_distances = df['distance'].values
    all_labels = df['label'].values
    all_similarities = df['similarity'].values
    
    # Convert to binary classification format (0 = positive/similar, 1 = negative/dissimilar)
    y_true = (all_labels == 0).astype(int)
    y_scores = all_similarities
    
    # Compute global metrics
    if len(np.unique(y_true)) > 1:  # Check if we have both classes
        try:
            precision, recall, pr_thresholds = precision_recall_curve(y_true, y_scores)
            ap_score = average_precision_score(y_true, y_scores)
            
            # Find optimal threshold based on F1 score
            f1_scores = 2 * (precision * recall) / (precision + recall + 1e-8)
            best_f1_idx = np.argmax(f1_scores)
            
            if len(pr_thresholds) > best_f1_idx:
                best_threshold_sim = pr_thresholds[best_f1_idx]
                best_threshold_dist = 1 / (1 + best_threshold_sim) - 1 if best_threshold_sim > 0 else 0.5
            else:
                best_threshold_dist = 0.5
                
            best_f1 = f1_scores[best_f1_idx]
        except Exception as e:
            print(f"Warning: Error computing global metrics: {e}")
            ap_score = 0.0
            best_f1 = 0.0
            best_threshold_dist = 0.5
    else:
        ap_score = 0.0
        best_f1 = 0.0
        best_threshold_dist = 0.5
    
    # Compute confusion matrix at optimal threshold
    y_pred = (all_distances < best_threshold_dist).astype(int)
    
    # Calculate confusion matrix components
    tp = np.sum((y_true == 1) & (y_pred == 1))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    tn = np.sum((y_true == 0) & (y_pred == 0))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    
    # Calculate metrics
    precision_global = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall_global = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    balanced_accuracy = (recall_global + specificity) / 2
    
    # Compute ranking metrics using fast computation
    precision_at_k, recall_at_k, mean_avg_precision = fast_ranking_metrics_limited(df, k_values)
    
    # Combine all metrics
    metrics = {
        # Global metrics
        'average_precision': ap_score,
        'precision': precision_global,
        'recall': recall_global,
        'f1_score': best_f1,
        'specificity': specificity,
        'balanced_accuracy': balanced_accuracy,
        'optimal_threshold': best_threshold_dist,
        'confusion_matrix': {'TP': int(tp), 'FP': int(fp), 'TN': int(tn), 'FN': int(fn)},
        
        # Ranking metrics
        'mean_average_precision': mean_avg_precision,
        'precision_at_k': precision_at_k,
        'recall_at_k': recall_at_k,
        'k_values': k_values
    }
    
    return metrics


def fast_ranking_metrics_limited(df, k_values):
    """
    Compute ranking metrics efficiently from a DataFrame with limited data.
    
    Args:
        df: DataFrame with columns ['query_sku', 'similarity', 'label']
        k_values: List of k values for precision@k and recall@k
    
    Returns:
        Tuple of (precision_at_k_dict, recall_at_k_dict, mean_average_precision)
    """
    import numpy as np
    
    # Initialize containers
    precision_at_k = {k: [] for k in k_values}
    recall_at_k = {k: [] for k in k_values}
    average_precisions = []
    
    # Group by query
    for query_sku in df['query_sku'].unique():
        query_data = df[df['query_sku'] == query_sku].sort_values('similarity', ascending=False)
        
        # Get labels in ranked order (0 = positive, 1 = negative)
        is_positive = (query_data['label'] == 0).values
        total_positives = is_positive.sum()
        
        if total_positives == 0:
            continue
        
        # Compute metrics for all k values at once
        for k in k_values:
            if len(is_positive) >= k:
                top_k_positives = is_positive[:k].sum()
                precision_at_k[k].append(top_k_positives / k)
                recall_at_k[k].append(top_k_positives / total_positives)
        
        # Compute Average Precision efficiently
        precision_values = []
        num_positives = 0
        for i, is_pos in enumerate(is_positive):
            if is_pos:
                num_positives += 1
                precision_values.append(num_positives / (i + 1))
        
        if precision_values:
            average_precisions.append(np.mean(precision_values))
    
    # Average across queries
    avg_precision_at_k = {k: np.mean(v) if v else 0.0 for k, v in precision_at_k.items()}
    avg_recall_at_k = {k: np.mean(v) if v else 0.0 for k, v in recall_at_k.items()}
    mean_avg_precision = np.mean(average_precisions) if average_precisions else 0.0
    
    return avg_precision_at_k, avg_recall_at_k, mean_avg_precision


def validation_with_embedding_collection(model, criterion, data_loader, epoch, 
                                        device='cpu', split_name='validation',
                                        metric='f1', limit_batches=None, 
                                        collect_for_ranking=False, k_values=[1, 3, 5, 10]):
    """
    Validation function that collects embeddings during loss computation for ranking metrics.
    This REUSES embeddings computed for loss, avoiding duplicate forward passes.
    
    Args:
        model: The model to evaluate
        criterion: Loss function
        data_loader: DataLoader for validation data
        epoch: Current epoch number
        device: Device to run on
        split_name: Name of the split being evaluated
        metric: Optimization metric ('f1', 'pos_acc', 'map')
        limit_batches: Limit number of batches (for smoke testing)
        collect_for_ranking: Whether to collect embeddings for ranking metrics
        k_values: List of k values for ranking metrics
    
    Returns:
        Tuple of (pos_acc, neg_acc, avg_acc, f1, avg_loss, best_thr, eval_metrics)
    """
    import gc
    import numpy as np
    import torch.nn.functional as F
    from sklearn.metrics import f1_score
    from tqdm.auto import tqdm
    
    assert metric in ('f1', 'pos_acc', 'map'), "metric must be 'f1', 'pos_acc', or 'map'"
    
    model.eval()
    total_loss = 0.0
    all_distances, all_labels = [], []
    
    # For ranking metrics collection (when collect_for_ranking=True)
    ranking_data = [] if collect_for_ranking else None
    
    # Calculate total batches
    total_batches = limit_batches if limit_batches is not None else len(data_loader)
    
    with torch.no_grad():
        # Create data_iter BEFORE the loop
        data_iter = enumerate(tqdm(data_loader, 
                                 desc=f"{split_name.capitalize()}", 
                                 unit="batch", 
                                 total=total_batches))
        
        for batch_idx, batch in data_iter:  # Now data_iter is defined
            if limit_batches is not None and batch_idx >= limit_batches:
                break
                
            # Handle batch unpacking
            if len(batch) == 9:  # Enhanced dataset with SKUs
                im1, n1, d1, im2, n2, d2, lbl, query_skus, target_skus = batch
                im1, n1, d1, im2, n2, d2, lbl = [t.to(device) for t in [im1, n1, d1, im2, n2, d2, lbl]]
            else:  # Original dataset without SKUs
                im1, n1, d1, im2, n2, d2, lbl = [t.to(device) for t in batch]
                if collect_for_ranking:
                    # 🔥 FIX: Use consistent query grouping instead of per-sample queries
                    batch_size = lbl.size(0)
                    # Group all samples in this batch under a single query for smoke testing
                    query_skus = [f"smoke_test_query"] * batch_size  # ← All samples share same query
                    target_skus = [f"target_{batch_idx}_{i}" for i in range(batch_size)]
            
            # SINGLE FORWARD PASS - compute embeddings once
            out1, out2 = model(im1, n1, d1, im2, n2, d2)
            
            # Compute loss using the same embeddings
            loss = criterion(out1, out2, lbl)
            total_loss += loss.item()
            
            # REUSE the same embeddings for distance computation
            distances = F.pairwise_distance(out1, out2)
            
            # Store for threshold optimization
            all_distances.append(distances.detach().cpu())
            all_labels.append(lbl.detach().cpu())
            
            # Collect data for ranking metrics if requested
            if collect_for_ranking and ranking_data is not None:
                similarities = 1 / (1 + distances)
                distances_cpu = distances.cpu().numpy()
                similarities_cpu = similarities.cpu().numpy()
                labels_cpu = lbl.cpu().numpy()
                
                batch_size = len(labels_cpu)
                for i in range(batch_size):
                    ranking_data.append({
                        'query_sku': query_skus[i] if len(batch) == 9 else f"query_{batch_idx}_{i}",
                        'target_sku': target_skus[i] if len(batch) == 9 else f"target_{batch_idx}_{i}",
                        'distance': distances_cpu[i],
                        'similarity': similarities_cpu[i],
                        'label': labels_cpu[i]
                    })
            
            # Explicit cleanup
            del im1, n1, d1, im2, n2, d2, lbl, out1, out2, loss, distances
            if len(batch) == 9:
                del query_skus, target_skus
    
    # Calculate metrics using actual number of processed batches
    actual_batches = min(batch_idx + 1, limit_batches or len(data_loader))
    avg_loss = total_loss / actual_batches
    
    # Concatenate all distances and labels
    distances = torch.cat(all_distances)
    labels = torch.cat(all_labels)
    
    # Clear the lists immediately to free memory
    del all_distances, all_labels
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    # Threshold sweep for optimization
    margin = 1.5
    steps = 200
    grid = np.linspace(0.0, margin, steps)
    best_val, best_thr = -1.0, 0.0
    y_true = (labels.numpy() == 0).astype(int)
    
    for t in grid:
        y_pred = (distances.numpy() < t).astype(int)
        if metric == 'f1':
            val = f1_score(y_true, y_pred, zero_division=0)
        else:  # metric == 'pos_acc'
            pos_mask = (y_true == 1)
            val = (y_pred[pos_mask] == 1).mean() if pos_mask.sum() > 0 else 0.0
        if val > best_val:
            best_val, best_thr = val, t
    
    threshold = best_thr
    
    # Final metrics at chosen threshold
    preds = (distances < threshold).long()
    pos_mask = (labels == 0)
    neg_mask = (labels == 1)
    
    pos_acc = (preds[pos_mask] == 1).float().mean().item() if pos_mask.any() else 0.0
    neg_acc = (preds[neg_mask] == 0).float().mean().item() if neg_mask.any() else 0.0
    avg_acc = (pos_acc + neg_acc) / 2.0
    f1 = f1_score((labels.numpy() == 0).astype(int), preds.numpy(), zero_division=0)
    
    # Compute ranking metrics if requested
    if collect_for_ranking and ranking_data:
        import pandas as pd
        ranking_df = pd.DataFrame(ranking_data)
        precision_at_k, recall_at_k, mean_avg_precision = fast_ranking_metrics_limited(ranking_df, k_values)
        
        eval_metrics = {
            'mean_average_precision': mean_avg_precision,
            'precision_at_k': precision_at_k,
            'recall_at_k': recall_at_k,
            'f1_score': f1,
            'balanced_accuracy': avg_acc,
            'specificity': neg_acc
        }
    else:
        eval_metrics = {
            'mean_average_precision': 0.0,
            'precision_at_k': {k: 0.0 for k in k_values},
            'recall_at_k': {k: 0.0 for k in k_values},
            'f1_score': f1,
            'balanced_accuracy': avg_acc,
            'specificity': neg_acc
        }
    
    # Final cleanup
    del distances, labels, preds, pos_mask, neg_mask
    if ranking_data:
        del ranking_data
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    # Log to console
    report = (f"[{split_name}] Epoch {epoch} – "
              f"loss: {avg_loss:.4f}, "
              f"P Acc: {pos_acc:.3f}, "
              f"N Acc: {neg_acc:.3f}, "
              f"Avg Acc: {avg_acc:.3f}, "
              f"F1: {f1:.3f}, "
              f"thr*: {threshold:.3f} "
              f"(optimised: {metric})")
    print(report)
    
    return pos_acc, neg_acc, avg_acc, f1, avg_loss, threshold, eval_metrics

In [38]:
def train_with_smart_ranking(model, optimizer, criterion, num_epochs, 
                           train_loader, valid_loader=None, device='cpu',
                           print_epoch=False, models_dir=None, metric='f1',
                           limit_batches=None, 
                           compute_ranking_every_n_epochs=5,  # Only compute ranking periodically
                           k_values=[1, 3, 5, 10]):
    """
    Memory-optimized training with smart ranking metrics computation that REUSES embeddings.
    
    Args:
        model: The model to train
        optimizer: Optimizer for training
        criterion: Loss function
        num_epochs: Number of epochs to train
        train_loader: Training data loader
        valid_loader: Validation data loader
        device: Device to use ('cpu' or 'cuda')
        print_epoch: Whether to print epoch information
        models_dir: Directory to save model checkpoints
        metric: Metric to optimize ('f1', 'pos_acc', 'map')
        limit_batches: Limit batches for smoke testing
        compute_ranking_every_n_epochs: How often to compute ranking metrics (0 = never, 1 = every epoch)
        k_values: List of k values for precision@k and recall@k
        
    Returns:
        Tuple of (train_losses, val_losses, best_valid_metric, best_weights, thr_history, ranking_history)
    """
    import gc
    import psutil
    from tqdm.auto import tqdm
    from torch.optim.lr_scheduler import ReduceLROnPlateau
    from pathlib import Path
    from copy import deepcopy
    
    # Ensure epochs_num is an integer and properly limited
    num_epochs = int(num_epochs)
    assert num_epochs > 0, f"epochs_num must be positive, got {num_epochs}"
    
    print(f"🔄 Starting training for {num_epochs} epochs")
    
    assert metric in ('f1', 'pos_acc', 'map'), "metric must be 'f1', 'pos_acc', or 'map'"

    def log_memory_usage(stage):
        process = psutil.Process()
        memory_mb = process.memory_info().rss / 1024 / 1024
        print(f"Memory usage at {stage}: {memory_mb:.1f} MB")

    model.to(device)
    train_losses, val_losses, thr_history = [], [], []
    ranking_history = []
    best_valid_metric, best_threshold = float('-inf'), None
    best_weights = None

    scheduler = ReduceLROnPlateau(
        optimizer,
        mode="max",
        factor=0.1,
        patience=3,
        threshold=1e-4,
        threshold_mode='rel'
    )

    if models_dir:
        Path(models_dir).mkdir(parents=True, exist_ok=True)

    log_memory_usage("training start")
    
    # Create epoch progress bar with explicit range
    epoch_range = list(range(1, num_epochs + 1))  # Convert to list to be explicit
    epoch_pbar = tqdm(epoch_range, desc="Epochs", unit="epoch")
    
    epoch_counter = 0  # Add explicit counter for debugging
    
    for epoch in epoch_pbar:
        epoch_counter += 1
        print(f"🔸 Processing epoch {epoch}/{num_epochs} (counter: {epoch_counter})")
        
        # Add safety check
        if epoch_counter > num_epochs:
            print(f"⚠️ Safety break: epoch_counter ({epoch_counter}) > epochs_num ({num_epochs})")
            break
        log_memory_usage(f"epoch {epoch} start")
        
        # ==================== TRAINING PHASE ====================
        model.train()
        total_train_loss = 0.0
        
        # Create batch progress bar for training
        train_total = limit_batches if limit_batches is not None else len(train_loader)
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch} - Training", 
                         leave=False, unit="batch", total=train_total)
        
        batch_count = 0
        for batch_idx, batch in enumerate(train_pbar):
            if limit_batches is not None and batch_idx >= limit_batches:
                break
                
            # Handle batch unpacking (with or without SKUs)
            if len(batch) == 9:  # Enhanced dataset with SKUs
                im1, n1, d1, im2, n2, d2, lbl, query_skus, target_skus = batch
                im1, n1, d1, im2, n2, d2, lbl = [t.to(device) for t in [im1, n1, d1, im2, n2, d2, lbl]]
            else:  # Original dataset without SKUs
                im1, n1, d1, im2, n2, d2, lbl = [t.to(device) for t in batch]
            
            optimizer.zero_grad()
            out1, out2 = model(im1, n1, d1, im2, n2, d2)
            loss = criterion(out1, out2, lbl)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
            batch_count += 1
            
            # Explicit cleanup of batch variables
            del im1, n1, d1, im2, n2, d2, lbl, out1, out2, loss
            if len(batch) == 9:
                del query_skus, target_skus
            
            train_pbar.set_postfix({'loss': f'{total_train_loss/batch_count:.4f}'})

        train_pbar.close()
        
        # Calculate average loss using actual number of processed batches
        avg_train_loss = total_train_loss / batch_count if batch_count > 0 else 0.0
        train_losses.append(avg_train_loss)

        # Memory cleanup after training phase
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        
        log_memory_usage(f"epoch {epoch} after training")

        # ==================== VALIDATION PHASE ====================
        if valid_loader is not None:
            # Decide whether to compute ranking metrics
            compute_ranking = (
                compute_ranking_every_n_epochs > 0 and  # Ranking computation enabled
                (epoch % compute_ranking_every_n_epochs == 0 or epoch == num_epochs)  # Periodic or final
            )
            
            print(f"Computing validation metrics (ranking: {compute_ranking})...")
            
            # 🔥 KEY FIX: Use validation_with_embedding_collection instead of validation_with_optional_ranking
            # This REUSES embeddings computed for loss computation!
            pos_acc, neg_acc, avg_acc, f1, avg_val_loss, best_thr, eval_metrics = validation_with_embedding_collection(
                model=model,
                criterion=criterion,
                data_loader=valid_loader,
                epoch=epoch,
                device=device,
                split_name='validation',
                metric=metric,
                limit_batches=limit_batches,
                collect_for_ranking=compute_ranking,  # ← Key parameter that enables embedding reuse!
                k_values=k_values
            )
            
            ranking_history.append(eval_metrics)
            val_losses.append(avg_val_loss)
            thr_history.append(best_thr)

            # Choose the metric value for optimization
            if metric == 'f1':
                current_metric = f1
            elif metric == 'pos_acc':
                current_metric = pos_acc
            elif metric == 'map':
                current_metric = eval_metrics.get('mean_average_precision', 0.0)
            else:
                current_metric = f1

            scheduler.step(current_metric)

            # Enhanced logging with ranking metrics when available
            if compute_ranking and eval_metrics.get('mean_average_precision', 0) > 0:
                print(f"Epoch {epoch}/{num_epochs} - Train Loss: {avg_train_loss:.4f}, "
                      f"Val Loss: {avg_val_loss:.4f}, Val {metric}: {current_metric:.4f}")
                print(f"  Ranking Metrics - MAP: {eval_metrics.get('mean_average_precision', 0):.4f}")
                
                # Print P@k and R@k for selected k values
                for k in [1, 3, 5]:  # Show only key metrics to avoid clutter
                    if k in k_values and k in eval_metrics.get('precision_at_k', {}):
                        p_k = eval_metrics['precision_at_k'].get(k, 0)
                        r_k = eval_metrics['recall_at_k'].get(k, 0)
                        print(f"  P@{k}: {p_k:.3f}, R@{k}: {r_k:.3f}")
            else:
                print(f"Epoch {epoch}/{num_epochs} - Train Loss: {avg_train_loss:.4f}, "
                      f"Val Loss: {avg_val_loss:.4f}, Val {metric}: {current_metric:.4f}, "
                      f"Pos Acc: {pos_acc:.3f}, Neg Acc: {neg_acc:.3f}, F1: {f1:.3f}, "
                      f"Threshold: {best_thr:.3f}")

            # Save checkpoint every epoch if requested
            if models_dir:
                checkpoint_path = Path(models_dir) / f"checkpoint_epoch_{epoch}.pt"
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_loss': avg_train_loss,
                    'val_loss': avg_val_loss,
                    'metric_value': current_metric,
                    'threshold': best_thr
                }, checkpoint_path)

            # Update best model if improved
            if current_metric > best_valid_metric:
                best_valid_metric = current_metric
                best_threshold = best_thr
                best_weights = deepcopy(model.state_dict())
                if models_dir:
                    best_model_path = Path(models_dir) / "best_model.pt"
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': best_weights,
                        'metric_value': best_valid_metric,
                        'threshold': best_threshold,
                        'metric_name': metric
                    }, best_model_path)

            # Memory cleanup after validation
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            
            log_memory_usage(f"epoch {epoch} end")

        # Update epoch progress bar
        if valid_loader is not None:
            epoch_pbar.set_postfix({
                'train_loss': f'{avg_train_loss:.4f}',
                'val_loss': f'{avg_val_loss:.4f}',
                f'val_{metric}': f'{current_metric:.4f}'
            })

    epoch_pbar.close()
    
    print(f"Training completed!")
    print(f"Best validation {metric}: {best_valid_metric:.4f} (threshold={best_threshold:.4f})")
    
    # Final memory cleanup
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    log_memory_usage("training complete")
    
    print(f"✅ Training completed! Processed {epoch_counter} epochs out of {num_epochs} requested")
    
    return train_losses, val_losses, best_valid_metric, best_weights, thr_history, ranking_history

In [32]:
# suppress hf warnings
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [33]:
from sklearn.metrics import f1_score                 # ← new
import numpy as np
import torch.nn.functional as F

def best_threshold(distances: torch.Tensor,
                   labels:    torch.Tensor,
                   steps:     int = 200,
                   margin:    float = 1.5):
    """
    Sweep `steps` evenly-spaced thresholds between 0 and `margin`
    and return the one that maximises duplicate-class F1.
    Labels: 0 = duplicate (positive), 1 = different (negative).
    """
    d   = distances.detach().cpu().numpy()
    y   = labels.detach().cpu().numpy()
    thr = np.linspace(0.0, margin, steps)

    best_f1, best_thr = -1.0, 0.0
    for t in thr:
        y_pred = (d < t).astype(int)          # 1 = duplicate prediction
        f1     = f1_score(1 - y, y_pred)      # make 1 = positive for sklearn
        if f1 > best_f1:
            best_f1, best_thr = f1, t
    return best_thr, best_f1

In [34]:
def print_dataset_distribution(train_dataset, val_dataset, test_dataset):
    """Print the distribution of positives, hard negatives, and soft negatives across all splits."""
    import pandas as pd
    from IPython.display import display

    def get_split_stats(dataset, split_name):
        total_pairs = len(dataset)
        
        # Count actual pairs by label
        pos_count = sum(1 for pair in dataset.pairs if pair['label'] == 0)
        neg_count = sum(1 for pair in dataset.pairs if pair['label'] == 1)
        
        # Split negatives into hard and soft by counting in pairs
        hard_neg_count = 0
        soft_neg_count = 0
        for pair in dataset.pairs:
            if pair['label'] == 1:  # it's a negative pair
                sku_second = pair['sku_second']
                # Check if this negative is in hard negatives of the query
                query_row = dataset.split_df[dataset.split_df['sku_query'] == pair['sku_first']].iloc[0]
                if sku_second in query_row['sku_hard_neg']:
                    hard_neg_count += 1
                else:
                    soft_neg_count += 1
        
        return {
            'Split': split_name,
            'Total Pairs': f"{total_pairs:,}",
            'Positives': f"{pos_count:,} ({(pos_count/total_pairs*100):.1f}%)",
            'Hard Negatives': f"{hard_neg_count:,} ({(hard_neg_count/total_pairs*100):.1f}%)",
            'Soft Negatives': f"{soft_neg_count:,} ({(soft_neg_count/total_pairs*100):.1f}%)"
        }

    # Create distribution table
    stats = []
    stats.append(get_split_stats(train_dataset, "Training"))
    stats.append(get_split_stats(val_dataset, "Validation"))
    stats.append(get_split_stats(test_dataset, "Testing"))
    
    # Create and display DataFrame
    df = pd.DataFrame(stats)
    df = df.set_index('Split')
    
    print("\nDataset Distribution Statistics:")
    display(df)

In [44]:
import random

import random
import gc
import psutil

def _run(train_df, val_df, test_df, source_df, cluster_emb_table, images_dir,
         # Model params 
         model_name='DeepPavlov/distilrubert-tiny-cased-conversational-v1',
         description_model_name='cointegrated/rubert-tiny',
         batch_size=1, num_epochs=10, learning_rate=1e-4,
         # Dataset sampling params
         n_pos=None,
         n_hard_neg=5,
         n_soft_neg=5,
         use_all_pos=True,
         cluster_sampling=True,
         n_soft_neg_per_cluster=1,
         # Other params
         mlflow_tracking_uri=None, 
         experiment_name=None, 
         SMOKE_TEST_BATCHES=None,
         optimize_for_ranking=False, k_values=[1, 3, 5, 10]
         ):
    """
    Memory-optimized training pipeline with comprehensive cleanup.
    """
    
    def log_memory_usage(stage):
        process = psutil.Process()
        memory_mb = process.memory_info().rss / 1024 / 1024
        print(f"Memory usage at {stage}: {memory_mb:.1f} MB")
    
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Set random seeds for reproducibility
    np.random.seed(42)
    random.seed(42)
    torch.manual_seed(42)

    if SMOKE_TEST_BATCHES is not None:
        print(f"🔥 SMOKE TEST MODE: Limited to {SMOKE_TEST_BATCHES} batches per phase")
    
    log_memory_usage("start")
    
    # ==================== DATASET CREATION ====================
    print("Creating datasets...")

    train_dataset = ClusterBasedPairwiseDataset(
        split_df=train_df,
        source_df=source_df,
        cluster_emb_table=cluster_emb_table,
        images_dir=images_dir,
        n_pos=n_pos,
        n_hard_neg=n_hard_neg,
        n_soft_neg=n_soft_neg,
        use_all_pos=use_all_pos,
        cluster_sampling=cluster_sampling,
        n_soft_neg_per_cluster=n_soft_neg_per_cluster,
        lazy_loading=False
    )

    val_dataset = ClusterBasedPairwiseDataset(
        split_df=val_df,
        source_df=source_df,
        cluster_emb_table=cluster_emb_table,
        images_dir=images_dir,
        n_pos=n_pos,
        n_hard_neg=n_hard_neg,
        n_soft_neg=n_soft_neg,
        use_all_pos=use_all_pos,
        cluster_sampling=cluster_sampling,
        n_soft_neg_per_cluster=n_soft_neg_per_cluster,
        lazy_loading=False
    )

    test_dataset = ClusterBasedPairwiseDataset(
        split_df=test_df,
        source_df=source_df,
        cluster_emb_table=cluster_emb_table,
        images_dir=images_dir,
        n_pos=n_pos,
        n_hard_neg=n_hard_neg,
        n_soft_neg=n_soft_neg,
        use_all_pos=use_all_pos,
        cluster_sampling=cluster_sampling,
        n_soft_neg_per_cluster=n_soft_neg_per_cluster,
        lazy_loading=False
    )

    # Print distribution statistics
    print_dataset_distribution(train_dataset, val_dataset, test_dataset)
    
    # ==================== DATALOADER SETUP ====================
    print("Creating dataloaders...")
    # Create balanced training loader
    train_loader = create_balanced_train_loader(
        train_dataset, 
        batch_size=batch_size
    )
    
    val_loader = DataLoader(
        val_dataset, 
        batch_size=batch_size, 
        shuffle=False, 
        num_workers=0,
        pin_memory=False
    )
    
    test_loader = DataLoader(
        test_dataset, 
        batch_size=batch_size, 
        shuffle=False, 
        num_workers=0,
        pin_memory=False
    )
    
    print(f"DataLoader info:")
    print(f"  Training batches: {len(train_loader)} (batch_size={batch_size})")
    print(f"  Validation batches: {len(val_loader)} (batch_size={batch_size})")
    print(f"  Testing batches: {len(test_loader)} (batch_size={batch_size})")
    
    # ==================== MODEL & OPTIMIZER ====================
    print("Initializing model and optimizer...")
    
    model = SiameseRuCLIP(
        device=device,
        name_model_name=model_name,
        description_model_name=description_model_name
    ).to(device)

    criterion = ContrastiveLoss(margin=1.5)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    if mlflow_tracking_uri:
        mlflow.set_tracking_uri(mlflow_tracking_uri)
        mlflow.set_experiment(experiment_name)
        mlflow.log_params({
            'batch_size': batch_size,
            'num_epochs': num_epochs,
            'learning_rate': learning_rate,
            'model_name': model_name,
            'description_model_name': description_model_name,
            'smoke_test_batches': SMOKE_TEST_BATCHES
        })

    log_memory_usage("after model creation")
    
    # ==================== TRAINING ====================
    print("Starting training...")
    # Print batch limitation message once after training
    if SMOKE_TEST_BATCHES is not None:
        print(f'🚨 Training was limited to {SMOKE_TEST_BATCHES} batches per epoch')

    # Choose optimization metric
    if optimize_for_ranking:
        optimization_metric = 'map'  # Optimize for Mean Average Precision
    else:
        optimization_metric = 'f1'
    
    # Train with ranking metrics every 5 epochs
    train_losses, val_losses, best_metric, best_weights, thr_history, ranking_history = train_with_smart_ranking(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        num_epochs=num_epochs,
        train_loader=train_loader,
        valid_loader=val_loader,
        device=device,
        print_epoch=True,
        metric='f1',  # or 'map' for ranking optimization
        limit_batches=SMOKE_TEST_BATCHES,
        compute_ranking_every_n_epochs=1,  # Compute ranking every 2 epochs
        k_values=[1, 3, 5, 10, 20]
    )
    
    # ==================== SMART TEST EVALUATION ====================
    print("\nRunning final test evaluation with fast ranking metrics...")
    if best_weights is not None:
        model.load_state_dict(best_weights)

    # 🔥 FIXED: Use the same approach as training/validation
    if SMOKE_TEST_BATCHES is not None:
        print(f"🔥 Test evaluation limited to {SMOKE_TEST_BATCHES} batches")
        # Use validation_with_embedding_collection for consistency
        pos_acc, neg_acc, avg_acc, test_f1, test_loss, threshold, test_metrics = validation_with_embedding_collection(
            model=model,
            criterion=criterion,
            data_loader=test_loader,
            epoch='test',
            device=device,
            split_name='test',
            metric='f1',
            limit_batches=SMOKE_TEST_BATCHES,  # ← This properly limits batches
            collect_for_ranking=True,  # Get ranking metrics
            k_values=k_values
        )
    else:
        # For full evaluation, use the comprehensive function
        test_metrics = fast_comprehensive_evaluation(model, test_loader, device, k_values=k_values)
        # Extract individual metrics from test_metrics for consistency
        pos_acc = test_metrics.get('precision', 0)
        neg_acc = test_metrics.get('specificity', 0) 
        avg_acc = test_metrics.get('balanced_accuracy', 0)
        test_f1 = test_metrics.get('f1_score', 0)
        test_loss = 0.0  # Not computed in comprehensive evaluation
        threshold = test_metrics.get('optimal_threshold', 0.5)

    # Print test results
    print("\nFinal Test Results:")
    print(f"  Test F1: {test_f1:.4f}")
    print(f"  Test Loss: {test_loss:.4f}")
    print(f"  Test Positive Accuracy: {pos_acc:.4f}")
    print(f"  Test Negative Accuracy: {neg_acc:.4f}")
    print(f"  Test Average Accuracy: {avg_acc:.4f}")
    print(f"  Best Threshold: {threshold:.4f}")
    
    # Print ranking metrics if available
    if test_metrics.get('mean_average_precision', 0) > 0:
        print(f"  Mean Average Precision: {test_metrics['mean_average_precision']:.4f}")
        for k in [1, 3, 5]:
            if k in test_metrics.get('precision_at_k', {}):
                p_k = test_metrics['precision_at_k'][k]
                r_k = test_metrics['recall_at_k'][k]
                print(f"  P@{k}: {p_k:.3f}, R@{k}: {r_k:.3f}")

    # Final cleanup
    del test_loader, model, optimizer, criterion
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    log_memory_usage("final")
    
    # Log final metrics to MLflow if enabled
    if mlflow_tracking_uri:
        mlflow.log_metrics({
            'test_loss': test_loss,
            'test_f1': test_f1,
            'test_pos_acc': pos_acc,
            'test_neg_acc': neg_acc,
            'test_avg_acc': avg_acc,
            'test_threshold': threshold,
            'test_map': test_metrics.get('mean_average_precision', 0.0)
        })
    
    return {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'best_valid_metric': best_metric,
        'test_f1': test_f1,
        'test_loss': test_loss,
        'test_pos_acc': pos_acc,
        'test_neg_acc': neg_acc,
        'test_avg_acc': avg_acc,
        'test_threshold': threshold,
        'test_metrics': test_metrics,
        'ranking_history': ranking_history,
        'k_values': k_values
    }

In [None]:
# Extract the splits
train_df = splits_dataset['train']
val_df = splits_dataset['val']
test_df = splits_dataset['test']

images_dir = DATA_PATH + IMG_DATASET_NAME
results = _run(
    # Required data params
    train_df=train_df,
    val_df=val_df,
    test_df=test_df,
    source_df=source_df,
    cluster_emb_table=cluster_emb_table,
    images_dir=images_dir,
    
    # Model params
    model_name=NAME_MODEL_NAME,
    description_model_name=DESCRIPTION_MODEL_NAME,
    batch_size=BATCH_SIZE,
    num_epochs=EPOCHS,
    
    # # Dataset sampling params - example balanced configuration
    # n_pos=100,                    # Sample 10 positives
    # n_hard_neg=100,              # 10 hard negatives
    # n_soft_neg=100,              # 10 soft negatives
    # use_all_pos=False,          # Use fixed number instead of all
    # cluster_sampling=True,       # Sample from clusters
    # n_soft_neg_per_cluster=1,   # Take 1 from each cluster

    # Let training use all available data, limit val/test
    n_pos=None,                 # Use all positives available
    n_hard_neg=None,           # Use all hard negatives available  
    n_soft_neg=None,           # Use all soft negatives available
    use_all_pos=True,          # Use all positives
    cluster_sampling=True,
    n_soft_neg_per_cluster=1,
    
    # Other params
    SMOKE_TEST_BATCHES=SMOKE_TEST_BATCHES
)

Using device: cpu
🔥 SMOKE TEST MODE: Limited to 10 batches per phase
Memory usage at start: 5940.1 MB
Creating datasets...

Dataset Distribution Statistics:


Unnamed: 0_level_0,Total Pairs,Positives,Hard Negatives,Soft Negatives
Split,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
Training,13903,803 (5.8%),"3,492 (25.1%)","9,608 (69.1%)"
Validation,6017,130 (2.2%),514 (8.5%),"5,373 (89.3%)"
Testing,5966,130 (2.2%),515 (8.6%),"5,321 (89.2%)"


Creating dataloaders...
DataLoader info:
  Training batches: 2781 (batch_size=5)
  Validation batches: 1204 (batch_size=5)
  Testing batches: 1194 (batch_size=5)
Initializing model and optimizer...
Memory usage at after model creation: 5941.6 MB
Starting training...
🚨 Training was limited to 10 batches per epoch
🔄 Starting training for 2 epochs
Memory usage at training start: 5941.6 MB


Epochs:   0%|          | 0/2 [00:00<?, ?epoch/s]

🔸 Processing epoch 1/2 (counter: 1)
Memory usage at epoch 1 start: 5941.6 MB


Epoch 1 - Training:   0%|          | 0/10 [00:00<?, ?batch/s]

Memory usage at epoch 1 after training: 6030.3 MB
Computing validation metrics (ranking: True)...


Validation:   0%|          | 0/10 [00:00<?, ?batch/s]

[validation] Epoch 1 – loss: nan, P Acc: 0.000, N Acc: 1.000, Avg Acc: 0.500, F1: 0.000, thr*: 0.000 (optimised: f1)
Epoch 1/2 - Train Loss: nan, Val Loss: nan, Val f1: 0.0000
  Ranking Metrics - MAP: 1.0000
  P@1: 1.000, R@1: 1.000
  P@3: 0.000, R@3: 0.000
  P@5: 0.000, R@5: 0.000
Memory usage at epoch 1 end: 6030.3 MB
🔸 Processing epoch 2/2 (counter: 2)
Memory usage at epoch 2 start: 6030.3 MB


Epoch 2 - Training:   0%|          | 0/10 [00:00<?, ?batch/s]