In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel
from accelerate import Accelerator
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
import numpy as np
from torch.nn.utils.rnn import pad_sequence
from torch import optim
from tqdm import tqdm
import pickle
import json

In [2]:
textual_history = pd.read_parquet('textual_history.parquet')
id_history = pd.read_parquet('id_history.parquet')
user_descriptions = pd.read_parquet('user_descriptions.parquet')

In [3]:
model_dir = "/storage/kromanova/models/multilingual-e5-large"

In [4]:
tokenizer = AutoTokenizer.from_pretrained(model_dir)

# Устройство
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
class BuildTrainDataset(Dataset):
    """Dataset for training with unpackable outputs and user-description fallback."""
    def __init__(self, textual_history, user_descriptions, id_history, tokenizer, max_length=128):
        """
        Args:
            textual_history: DataFrame with goods descriptions.
            user_descriptions: DataFrame with user profiles (text descriptions).
            id_history: DataFrame with (user_id, goods_ids).
            tokenizer: Tokenizer for text inputs.
            max_length: Maximum length for tokenized sequences.
        """
        self.textual_history = textual_history
        self.user_descriptions = user_descriptions
        self.id_history = id_history
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Создаем маппинги для преобразования индексов
        self.user_id_map = {uid: idx for idx, uid in enumerate(id_history['viewer_uid'].unique())}
        self.item_id_map = {iid: idx for idx, iid in enumerate(id_history['clean_video_id'].explode().unique())}

        # Создаем обратные маппинги
        self.reverse_user_id_map = {idx: uid for uid, idx in self.user_id_map.items()}
        self.reverse_item_id_map = {idx: iid for iid, idx in self.item_id_map.items()}

    def __len__(self):
        return len(self.id_history)

    def __getitem__(self, idx):
        # Получаем viewer_uid и связанные данные
        viewer_uid = self.id_history.iloc[idx]['viewer_uid']
        item_text = self.textual_history.iloc[idx]['detailed_view']
        item_ids = self.id_history.iloc[idx]['clean_video_id']
    
        # Преобразование item_ids в числовой формат
        if isinstance(item_ids, (list, np.ndarray)):
            item_ids = [int(x) for x in item_ids]
        else:
            item_ids = [int(item_ids)]


         # Преобразуем абсолютные индексы в последовательные
        if isinstance(item_ids, (list, np.ndarray)):
            item_ids = [self.item_id_map[str(x)] for x in item_ids]
        else:
            item_ids = [self.item_id_map[str(item_ids)]]
            
        # Преобразуем user_id
        mapped_user_id = self.user_id_map[viewer_uid]

        # Сопоставляем viewer_uid с user_descriptions
        user_row = self.user_descriptions[self.user_descriptions['viewer_uid'] == viewer_uid]
        if not user_row.empty:
            user_text = user_row.iloc[0]['user_description']
        else:
            user_text = ""  # Пустое описание, если пользователь отсутствует
    
        # Комбинируем тексты товаров
        item_text = ' '.join(item_text)
    
        # Токенизация текстов
        item_encoding = self.tokenizer(
            item_text, padding='max_length', truncation=True, max_length=self.max_length, return_tensors="pt"
        )
        user_encoding = self.tokenizer(
            user_text, padding='max_length', truncation=True, max_length=self.max_length, return_tensors="pt"
        )
    
        item_text_inputs = {key: val.squeeze(0) for key, val in item_encoding.items()}
        user_text_inputs = {key: val.squeeze(0) for key, val in user_encoding.items()}

        item_ids = torch.tensor(item_ids, dtype=torch.int64)
        user_id = torch.tensor(mapped_user_id, dtype=torch.int64)
    
        return item_text_inputs, user_text_inputs, item_ids, user_id

In [6]:
from torch.nn.utils.rnn import pad_sequence

def custom_collate_fn(batch):
    """
    Custom collate function to handle variable-sized tensors in a batch.
    Args:
        batch: List of tuples from __getitem__.
    Returns:
        Batched tensors with padding.
    """
    # Разделяем входные данные
    item_text_inputs, user_text_inputs, item_ids, user_ids = zip(*batch)
    
    # Паддинг для текстовых данных
    item_text_inputs = {key: pad_sequence([x[key] for x in item_text_inputs], batch_first=True) for key in item_text_inputs[0]}
    user_text_inputs = {key: pad_sequence([x[key] for x in user_text_inputs], batch_first=True) for key in user_text_inputs[0]}
    
    # Паддинг для item_ids
    item_ids = pad_sequence([torch.tensor(x, dtype=torch.int64) for x in item_ids], batch_first=True, padding_value=0)
    
    # Стекинг user_ids
    user_ids = torch.stack(user_ids)
    
    return item_text_inputs, user_text_inputs, item_ids, user_ids

In [7]:
dataset = BuildTrainDataset(textual_history, user_descriptions, id_history, tokenizer, max_length=128)
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=custom_collate_fn
)

In [8]:
# Сохранение маппингов
def save_mappings(dataset, path='mappings/'):
    import os
    os.makedirs(path, exist_ok=True)
    
    # Сохраняем через pickle (поддерживает любые типы данных Python)
    mappings = {
        'user_id_map': dataset.user_id_map,
        'item_id_map': dataset.item_id_map,
        'reverse_user_id_map': dataset.reverse_user_id_map,
        'reverse_item_id_map': dataset.reverse_item_id_map
    }
    
    with open(f'{path}id_mappings.pkl', 'wb') as f:
        pickle.dump(mappings, f)
    
    # Опционально: сохраняем через json (только если все ключи - строки)
    # Преобразуем ключи в строки
    json_mappings = {
        'user_id_map': {str(k): v for k, v in dataset.user_id_map.items()},
        'item_id_map': {str(k): v for k, v in dataset.item_id_map.items()},
        'reverse_user_id_map': {str(k): str(v) for k, v in dataset.reverse_user_id_map.items()},
        'reverse_item_id_map': {str(k): str(v) for k, v in dataset.reverse_item_id_map.items()}
    }
    
    with open(f'{path}id_mappings.json', 'w') as f:
        json.dump(json_mappings, f)

In [9]:
save_mappings(dataset)

In [10]:
class MultimodalRecommendationModel(nn.Module):
    """A multimodal model for text and recommendation tasks."""
    def __init__(self, text_model_name, user_vocab_size, items_vocab_size, id_embed_dim=32, text_embed_dim=768):
        super(MultimodalRecommendationModel, self).__init__()
        # Text-based model
        self.text_model = AutoModel.from_pretrained(text_model_name)

        # Embeddings for user IDs and items IDs
        self.user_id_embeddings = nn.Embedding(user_vocab_size, id_embed_dim)
        self.items_id_embeddings = nn.Embedding(items_vocab_size, id_embed_dim)

        # Fusion layers
        self.user_fusion = nn.Linear(text_embed_dim + id_embed_dim, text_embed_dim)
        self.items_fusion = nn.Linear(text_embed_dim + id_embed_dim, text_embed_dim)

    def forward(self, items_text_inputs, user_text_inputs, item_ids, user_id):
        # Text embeddings
        items_text_embeddings = self.text_model(**items_text_inputs).last_hidden_state.mean(dim=1)
        user_text_embeddings = self.text_model(**user_text_inputs).last_hidden_state.mean(dim=1)
    
        # ID embeddings with range check
        if torch.max(items_ids) >= self.items_id_embeddings.num_embeddings:
            raise ValueError(f"items_ids contains invalid index: {torch.max(items_ids).item()}")
        if torch.max(user_id) >= self.user_id_embeddings.num_embeddings:
            raise ValueError(f"user_id contains invalid index: {torch.max(user_id).item()}")
    
        items_id_embeddings = self.items_id_embeddings(items_ids).mean(dim=1)
        user_id_embedding = self.user_id_embeddings(user_id)
    
        # Fusion
        items_embeddings = self.items_fusion(torch.cat([items_text_embeddings, items_id_embeddings], dim=-1))
        user_embeddings = self.user_fusion(torch.cat([user_text_embeddings, user_id_embedding], dim=-1))
    
        return items_embeddings, user_embeddings

In [11]:
# Пример данных для размерностей
user_vocab_size = len(id_history['viewer_uid'].unique())  # Количество уникальных пользователей
items_vocab_size = len(id_history['clean_video_id'].explode().unique())  # Количество уникальных товаров

# Размеры текстовых эмбеддингов и ID-эмбеддингов
text_embed_dim = 1024  # Обычно совпадает с размерностью текстовой модели
id_embed_dim = 32

In [12]:
# Инициализация модели
model = MultimodalRecommendationModel(
    text_model_name=model_dir,
    user_vocab_size=user_vocab_size,
    items_vocab_size=items_vocab_size,
    id_embed_dim=id_embed_dim,
    text_embed_dim=text_embed_dim
)

In [13]:
# Loss functions
contrastive_loss_fn = nn.CosineEmbeddingLoss()
recommendation_loss_fn = nn.CrossEntropyLoss()

In [14]:
optimizer = optim.AdamW(model.parameters(), lr=5e-5)
model.to(device)

MultimodalRecommendationModel(
  (text_model): XLMRobertaModel(
    (embeddings): XLMRobertaEmbeddings(
      (word_embeddings): Embedding(250002, 1024, padding_idx=1)
      (position_embeddings): Embedding(514, 1024, padding_idx=1)
      (token_type_embeddings): Embedding(1, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): XLMRobertaEncoder(
      (layer): ModuleList(
        (0-23): 24 x XLMRobertaLayer(
          (attention): XLMRobertaAttention(
            (self): XLMRobertaSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): XLMRobertaSelfOutput(
              (dense): Linear(in_features=1024, out_featu

In [15]:
model_device = next(model.parameters()).device
print(f"Model is on device: {model_device}")

Model is on device: cuda:0


In [16]:
epochs = 1
lambda_rec = 0.2

In [17]:
for epoch in range(epochs):
    model.train()
    total_loss = 0
    total_contrastive_loss = 0
    total_recommendation_loss = 0

    # Iterate through batches
    for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}"):
        items_text_inputs, user_text_inputs, item_ids, user_ids = batch
        # Move to device
        items_text_inputs = {key: val.to(device) for key, val in items_text_inputs.items()}
        user_text_inputs = {key: val.to(device) for key, val in user_text_inputs.items()}
        items_ids = item_ids.to(device)
        user_ids = user_ids.to(device)

        # Forward pass
        items_embeddings, user_embeddings = model(items_text_inputs, user_text_inputs, item_ids, user_ids)

        # Recommendation loss (user-to-items matching)
        logits = torch.matmul(user_embeddings, items_embeddings.T)  # Cosine similarity
        labels = torch.arange(len(user_embeddings)).to(device)  # Positive examples are diagonal
        recommendation_loss = recommendation_loss_fn(logits, labels)

        # Contrastive loss (items-to-items and user-to-user)
        # Example: Assuming items_pairs and user_pairs are available
        # You need to provide these pairs from your dataset or dataloader.
        positive_labels = torch.ones(items_embeddings.size(0), device=device)
        negative_labels = -torch.ones(items_embeddings.size(0), device=device)
        
        # Contrastive losses (dummy placeholders, replace with your real logic)
        contrastive_goods_loss = contrastive_loss_fn(items_embeddings, items_embeddings, positive_labels)
        contrastive_users_loss = contrastive_loss_fn(user_embeddings, user_embeddings, negative_labels)
        contrastive_loss = contrastive_goods_loss + contrastive_users_loss

        # Total loss
        loss = contrastive_loss + lambda_rec * recommendation_loss

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate loss
        total_loss += loss.item()
        total_contrastive_loss += contrastive_loss.item()
        total_recommendation_loss += recommendation_loss.item()

  item_ids = pad_sequence([torch.tensor(x, dtype=torch.int64) for x in item_ids], batch_first=True, padding_value=0)
Epoch 1/1:   0%|          | 13/28375 [00:32<19:43:31,  2.50s/it]


KeyboardInterrupt: 