In [1]:
import os

# Root of the competition dataset as mounted by Kaggle.
DATA_PATH = "/kaggle/input/www2025-mmctr-data"

# Sanity check: show what is directly under the dataset root.
print(f"Files in DATA_PATH: {os.listdir(DATA_PATH)}")


Files in DATA_PATH: ['MicroLens_1M_MMCTR']


In [None]:
!pip install protobuf==3.20.3


In [None]:
# ======================================================
# CTR_ItemEmb_Extraction_Kaggle.py
# Extraction des embeddings texte+image, fusion + projection
# ======================================================

import os
import torch
import torch.nn as nn
import polars as pl
import numpy as np
from transformers import AutoProcessor, AutoModel
from PIL import Image
from tqdm.auto import tqdm

# ---------------------------
# CONFIG
# ---------------------------
# Use the GPU when torch can see one, otherwise run on CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Hugging Face checkpoint used for both the text and the image embeddings.
MODEL_NAME = "google/siglip-base-patch16-384"

# Per-modality feature width; main() concatenates text and image features,
# so the projection layer's input width is 2 * SIGLIP_DIM.
SIGLIP_DIM = 768

# Width of the fused embedding written to the output file.
FINAL_DIM = 128

# Root directory of the dataset.
BASE_DIR = "/kaggle/input/www2025-mmctr-data/MicroLens_1M_MMCTR"

# item_info lives in the MicroLens_1M_x1 sub-folder.
ITEM_INFO_PATH = os.path.join(
    BASE_DIR,
    "MicroLens_1M_x1",
    "item_info.parquet",
)

# File that carries the item titles.
ITEM_FEATURE_PATH = os.path.join(
    BASE_DIR,
    "item_feature.parquet",
)

# Images are stored under item_images/item_images.
IMAGE_DIR = os.path.join(
    BASE_DIR,
    "item_images",
    "item_images",
)

# Where the resulting parquet file is written.
SAVE_PATH = "/kaggle/working/new_item_info.parquet"

# ---------------------------
# FONCTION PRINCIPALE
# ---------------------------
def main(seed: int = 42):
    """Build one fused text+image embedding per item and save them to parquet.

    For every row of item_info (left-joined with the titles from
    item_feature), the title and the item's JPEG are encoded with SigLIP,
    each feature vector is L2-normalised, the two are concatenated, passed
    through an untrained linear projection (2 * SIGLIP_DIM -> FINAL_DIM) and
    L2-normalised again.

    A zero vector of length FINAL_DIM is stored instead when:
      * the item id is 0 (treated as the padding row),
      * the image file ``<IMAGE_DIR>/<item_id>.jpg`` does not exist,
      * any error occurs while encoding the item (the error is printed).

    Args:
        seed: Seed used to initialise the projection layer. The layer is
            never trained, so without a fixed seed every run would write
            embeddings living in a different random space.

    Returns:
        None. Writes a parquet file with columns ``item_id``, ``item_tags``
        and ``item_emb_d128`` to SAVE_PATH.
    """
    print(f"Device: {DEVICE}")
    print("Chargement du modèle SigLIP...")
    processor = AutoProcessor.from_pretrained(MODEL_NAME)
    model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
    model.eval()

    # Seed inside a forked RNG scope so the projection weights are
    # reproducible without disturbing the caller's global RNG state.
    with torch.random.fork_rng():
        torch.manual_seed(seed)
        proj_layer = nn.Linear(SIGLIP_DIM * 2, FINAL_DIM)
    proj_layer = proj_layer.to(DEVICE)
    proj_layer.eval()

    # Load item_info and item_feature.
    item_info = pl.read_parquet(ITEM_INFO_PATH)
    item_feature = pl.read_parquet(ITEM_FEATURE_PATH)

    # Join to attach item_title to every item_info row.
    df = item_info.join(
        item_feature.select(["item_id", "item_title"]),
        on="item_id",
        how="left",
    )

    item_ids = df["item_id"].to_list()
    item_titles = df["item_title"].to_list()
    item_tags = df["item_tags"].to_list()
    print(f"Chargé {len(item_ids)} items (padding inclus)")

    # Shared placeholder for rows without a usable embedding; it is never
    # mutated and is copied when the output column is built.
    zero_emb = [0.0] * FINAL_DIM
    final_embs = []
    missing = 0

    # Encode the items one by one.
    for iid, raw_title in tqdm(
        zip(item_ids, item_titles),
        total=len(item_ids),
        desc="Extraction embeddings",
    ):
        if iid == 0:  # padding
            final_embs.append(zero_emb)
            continue

        # Fall back to a synthetic title when the join found none.
        title = str(raw_title) if raw_title is not None else f"item {iid}"

        img_path = os.path.join(IMAGE_DIR, f"{iid}.jpg")
        if not os.path.exists(img_path):
            missing += 1
            final_embs.append(zero_emb)
            continue

        try:
            with torch.no_grad():
                # Text embedding. SigLIP was trained on text padded to the
                # tokenizer's max length, so the same padding is used here
                # (see the Hugging Face SigLIP documentation).
                text_inputs = processor(
                    text=title,
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                ).to(DEVICE)
                text_emb = model.get_text_features(**text_inputs)
                text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)

                # Image embedding. The context manager closes the file
                # handle once the RGB copy has been made.
                with Image.open(img_path) as img_file:
                    image = img_file.convert("RGB")
                img_inputs = processor(images=image, return_tensors="pt").to(DEVICE)
                img_emb = model.get_image_features(**img_inputs)
                img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)

                # Fusion + projection.
                concat = torch.cat([text_emb, img_emb], dim=-1)
                fused = proj_layer(concat).squeeze(0).cpu().numpy()

            # L2-normalisation; a zero or non-finite norm would silently
            # write NaN/inf, so route it through the per-item error path.
            norm = np.linalg.norm(fused)
            if not np.isfinite(norm) or norm == 0:
                raise ValueError("fused embedding has a zero or non-finite norm")
            final_embs.append(fused / norm)

        except Exception as e:
            # Best effort: report the item and keep going with a zero vector.
            print(f"Erreur pour item {iid}: {e}")
            final_embs.append(zero_emb)

    print(f"\nExtraction terminée. Images manquantes : {missing}")

    # Save: embeddings are stored as lists of float64.
    out_df = pl.DataFrame({
        "item_id": item_ids,
        "item_tags": item_tags,
        "item_emb_d128": [np.array(v, dtype=np.float64).tolist() for v in final_embs],
    }, strict=False)

    out_df.write_parquet(SAVE_PATH)
    print(f"Fichier sauvegardé : {SAVE_PATH}")
    print("Colonnes : item_id | item_tags | item_emb_d128")

# ---------------------------
# EXECUTION
# ---------------------------
# Run the extraction only when this code is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
