In [None]:
import glob
import random
import gc
import os
import time
from datetime import datetime, timedelta

import numpy as np
import polars as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import IterableDataset, DataLoader
import torch.nn.utils.rnn as rnn_utils

In [2]:
unique_actions = {'favorite',
 'page_view',
 'remove',
 'review_view',
 'to_cart',
 'unfavorite',
 'view_description'}

unique_widgets = {None,
 'addressBookMap.addressBookBar',
 'addressBookMap.addressChangeProcessor',
 'cart.cartSplit',
 'cart.cartSplitShort',
 'cart.controls',
 'cart.sharingCart',
 'cart.split',
 'cart.total',
 'catalog.searchResultsV2',
 'catalog.warlockShelfScroll',
 'club.articleV2',
 'cms.bannerCarousel',
 'cms.cellList',
 'cms.navigationSlider',
 'cms.separator',
 'cms.storyYearSummaries',
 'cms.tileDataSourceWidget',
 'cms.uWidgetObject',
 'common.annotation',
 'common.curtainNavBar',
 'common.islandSeparator',
 'common.text',
 'csma.orderActions',
 'csma.orderDoneButtonBar',
 'csma.reorderCanceledShipment',
 'csma.sellerProducts',
 'csma.shipmentWidget',
 'csma.textBlock',
 'express.cartButtonPopup',
 'express.deliveryWidget',
 'express.deliveryWidgetBigOzon',
 'express.navigationSlider',
 'express.orderItems',
 'express.orderItemsPopup',
 'favorites.listSelector',
 'favorites.searchResultsV2',
 'favorites.sharedListSearchResults',
 'layout.ghost',
 'marketing.bigPromoPDP',
 'marketing.hammers',
 'marketing.sellerProducts',
 'messenger.messenger',
 'messenger.webMessenger',
 'myProfile.sectionMenu',
 'pdp-widget',
 'pdp.apparelNavBar',
 'pdp.aspectTile',
 'pdp.aspectsApparelColor',
 'pdp.aspectsApparelOther',
 'pdp.aspectsApparelSize',
 'pdp.aspectsNoSize',
 'pdp.badgeList',
 'pdp.brand',
 'pdp.characteristics',
 'pdp.collections',
 'pdp.descriptionAccordion',
 'pdp.galleryPreview',
 'pdp.helpfulHints',
 'pdp.inStock',
 'pdp.installmentPurchase',
 'pdp.modelParams',
 'pdp.navBar',
 'pdp.navTitle',
 'pdp.outOfStock',
 'pdp.price',
 'pdp.priceCell',
 'pdp.richContent',
 'pdp.shareLink',
 'pdp.shareWithAddFavorite',
 'pdp.textBlock',
 'pdp.textDescription',
 'pdp.tiles',
 'pdp.title',
 'pdp.webAddToFavorite',
 'pdp.webBrand',
 'pdp.webCharacteristics',
 'pdp.webCompare',
 'pdp.webMobCompare',
 'pdp.webMobRichContent',
 'pdp.webMobTextDescription',
 'pdp.webOutOfStock',
 'pdp.webProductMini',
 'pdp.webSellerList',
 'rpProduct.glueReviewList',
 'rpProduct.listReviews',
 'rpProduct.pinnedReview',
 'rpProduct.singleReview',
 'rpProduct.tilesReviewsList',
 'rpProduct.ugcCounters',
 'rpProduct.userReviews',
 'rpProduct.userReviewsList',
 'rpProduct.webListReviews',
 'rtb.advBanner',
 'rtb.advPageStay',
 'rtb.advVideoBanner',
 'rtb.advVideoBannerMobile',
 'shelf.accessoriesShelf',
 'shelf.analogLookSimilar',
 'shelf.analogShelf',
 'shelf.analogShelfFavorites',
 'shelf.analogShelfPersonal',
 'shelf.analogShelfPersonalPrimary',
 'shelf.analogShelfReviews',
 'shelf.analogShelfSecondary',
 'shelf.analogsShelfReturns',
 'shelf.apparelCart',
 'shelf.apparelOrders',
 'shelf.apparelPersonalFemale',
 'shelf.apparelPersonalKids',
 'shelf.apparelPersonalMale',
 'shelf.apparelPersonalSuggest',
 'shelf.bestsellers',
 'shelf.buyTogether',
 'shelf.cartCheckout',
 'shelf.cartShelf',
 'shelf.filterInfiniteScroll',
 'shelf.freshPersonal',
 'shelf.infiniteScroll',
 'shelf.infiniteScrollSuggests',
 'shelf.kindlyReminder',
 'shelf.oosModelVariants',
 'shelf.ordersShelf',
 'shelf.pdpAccessories',
 'shelf.personalCategoryShelf',
 'shelf.sellerAnalogs',
 'shelf.userCart',
 'shelf.userHistory',
 'shelf.userOrders',
 'shell.promoNavBar',
 'sis.mallBrandProducts',
 'sis.mallSellerProducts',
 'tile.relatedProducts',
 'tile.tileGridMobile',
 'tile.tileScrollMobile',
 'tile.tileShelf'}

In [3]:
success_predecessors = {
    ('to_cart', 'cart.controls'),
    ('to_cart', 'cart.total'),
    ('to_cart', 'express.cartButtonPopup'),
    ('to_cart', 'pdp.webAddToFavorite'),
    ('favorite', 'pdp.webAddToFavorite'),
    ('view_description', 'pdp.textDescription'),
    ('page_view', 'pdp.price')
}

failure_predecessors = {
    ('remove', 'cart.controls'),
    ('remove', 'cart.total'),
    ('unfavorite', 'favorites.listSelector'),
    ('unfavorite', 'pdp.webAddToFavorite'),
    ('remove', 'shelf.userCart')
}

state_map = {}
k = 0

for action in unique_actions:
    for widget in unique_widgets:
        state_map[(action, widget)] = k
        k += 1

In [4]:
succ_pred = set([state_map[pair] for pair in success_predecessors])
fail_pred = set([state_map[pair] for pair in failure_predecessors])

In [5]:
data = np.load("/kaggle/input/mochaaaa/item_embeddings.npz")
item_embeds_dict = dict(zip(data['keys'], data['embeddings']))

In [6]:
del data
gc.collect()

52

In [None]:
_HOURS_SIN = np.sin(2 * np.pi * np.arange(24) / 24, dtype=np.float32)
_HOURS_COS = np.cos(2 * np.pi * np.arange(24) / 24, dtype=np.float32)

class UserHistoryDataset(IterableDataset):
    def __init__(
        self, 
        data_dirs,
        target_dirs,
        succ_set,
        fail_set,
        state_mapping,
        window_size=500,
        stride=250,
        history_len_thresh=10,
    ):
        self.data_dirs = data_dirs if isinstance(data_dirs, list) else [data_dirs]
        self.target_dirs_path = target_dirs
        self.window_size = window_size
        self.stride = stride
        self.history_len_thresh = history_len_thresh
        self.succ_set = frozenset(succ_set)
        self.fail_set = frozenset(fail_set)
        self.state_mapping = state_mapping
        self._precomputed_features = None

    def _precompute_user_data(self, int_file):
        try:
            int_lf = pl.scan_parquet(int_file).select([
                'user_id', 'action_widget', 'action_type', 'item_id', 'timestamp'
            ])
            
            user_ids = int_lf.select(pl.col('user_id').unique()).collect()['user_id']
            if len(user_ids) == 0:
                return

            orders = pl.scan_parquet(self.target_dirs_path).filter(
                pl.col('user_id').is_in(user_ids)
            ).select([
                'user_id', 'created_timestamp', 'last_status_timestamp', 
                'last_status', 'item_id'
            ]).collect()

            user_data = int_lf.collect()
            
            return user_data, orders, user_ids
            
        except Exception as e:
            print(f"Error preprocessing file {int_file}: {e}")
            return None, None, None

    def _process_user(self, user_id, user_data, orders):
        user_history = user_data.filter(pl.col('user_id') == user_id)
        if user_history.is_empty():
            return None

        user_orders = orders.filter(pl.col('user_id') == user_id)
        if user_orders.is_empty():
            return None

        history_len = len(user_history)
        if history_len < self.history_len_thresh:
            return None

        results = []
        if history_len < self.window_size:
            result = self._extract_target(user_history, user_orders)
            if result[0] is not None:
                results.append(result)
        else:
            max_start_idx = history_len - self.window_size
            for i in range(0, max_start_idx + 1, self.stride):
                window = user_history.slice(i, self.window_size)
                result = self._extract_target(window, user_orders)
                if result[0] is not None:
                    results.append(result)
        
        return results

    def _extract_target(self, history_window, user_orders):
        try:
            start_time = history_window['timestamp'].max()
            end_time = start_time + timedelta(days=14)
            
            target_window = user_orders.filter(
                (pl.col('created_timestamp') >= start_time) &
                (pl.col('last_status_timestamp') <= end_time)
            )
            
            if target_window.is_empty():
                return None, None, None, [], []

            delivered_mask = target_window['last_status'] == 'delivered_orders'
            canceled_mask = target_window['last_status'] == 'canceled_orders'
            
            positives = target_window.filter(delivered_mask)['item_id'].to_list()
            negatives = target_window.filter(canceled_mask)['item_id'].to_list()

            if not positives:
                return None, None, None, [], []

            features_result = self._build_user_features(history_window)
            if features_result is None:
                return None, None, None, [], []
                
            padded_features, attention_mask, failure_items = features_result
            all_negative_items = list(set(negatives + failure_items))

            history_item_ids = history_window['item_id'].to_list()
            
            return (padded_features, attention_mask, history_item_ids, positives, all_negative_items)
            
        except Exception as e:
            print(f"Error in _extract_target: {e}")
            return None, None, None, [], []

    def _build_user_features(self, user_df):
        try:
            user_history = user_df.sort('timestamp').select([
                pl.col('action_widget').fill_null("unknown"),
                pl.col('action_type').fill_null("unknown"),
                pl.col('item_id').fill_null(0),
                pl.col('timestamp')
            ])

            widgets = user_history['action_widget'].to_numpy()
            actions = user_history['action_type'].to_numpy()
            item_ids = user_history['item_id'].to_numpy()
            timestamps = user_history['timestamp'].to_numpy().astype(np.int64)
            
            n = len(user_history)
            if n == 0:
                return None

            states = np.array([self.state_mapping.get((a, w), -1) for a, w in zip(actions, widgets)], dtype=np.int32)
            is_success = np.isin(states, self.succ_set)
            is_failure = np.isin(states, self.fail_set)
            negative_items = np.unique(item_ids[is_failure & (item_ids > 0)]).tolist()
            
            ts_min = timestamps[0]
            time_from_start_min = np.clip((timestamps - ts_min) / 6e10, 0, 1e6).astype(np.float32)
            
            time_delta_min = np.zeros(n, dtype=np.float32)
            if n > 1:
                time_delta_min[1:] = np.clip(np.diff(timestamps) / 6e10, -1e6, 1e6)
            
            same_item_as_prev = np.zeros(n, dtype=np.float32)
            same_item_as_prev[1:] = (item_ids[1:] == item_ids[:-1]).astype(np.float32)
            
            same_state_as_prev = np.zeros(n, dtype=np.float32)
            same_state_as_prev[1:] = (states[1:] == states[:-1]).astype(np.float32)
            
            velocity = np.clip(1.0 / (np.abs(time_delta_min) + 1e-6), 0, 1e6).astype(np.float32)
            acceleration = np.zeros(n, dtype=np.float32)
            if n > 1:
                acceleration[1:] = np.clip(np.diff(velocity), -1e6, 1e6)
            
            seconds = timestamps // 1_000_000_000
            hours = (seconds // 3600) % 24
            hour_sin = _HOURS_SIN[hours]
            hour_cos = _HOURS_COS[hours]
            
            cum_success_count = np.cumsum(is_success).astype(np.float32)
            cum_failure_count = np.cumsum(is_failure).astype(np.float32)
            
            window_unique_widgets = np.ones(n, dtype=np.float32)
            for i in range(1, min(20, n)):
                window_unique_widgets[i] = len(np.unique(widgets[max(0, i-19):i+1]))
            
            interaction_counts = np.arange(1, n + 1, dtype=np.float32)
            success_rate = np.divide(cum_success_count, interaction_counts, 
                                   out=np.zeros_like(cum_success_count), 
                                   where=interaction_counts!=0)
            failure_rate = np.divide(cum_failure_count, interaction_counts,
                                   out=np.zeros_like(cum_failure_count),
                                   where=interaction_counts!=0)
            
            numerical_features = np.column_stack([
                time_from_start_min, time_delta_min, velocity, acceleration,
                same_item_as_prev, same_state_as_prev,
                hour_sin, hour_cos,
                cum_success_count, cum_failure_count,
                window_unique_widgets,
                success_rate, failure_rate
            ]).astype(np.float32)
            
            current_len = numerical_features.shape[0]
            target_len = self.window_size
            
            if current_len > target_len:
                numerical_features = numerical_features[:target_len]
                attention_mask = np.ones(target_len, dtype=np.float32)
            else:
                padding_len = target_len - current_len
                padding = np.zeros((padding_len, numerical_features.shape[1]), dtype=np.float32)
                numerical_features = np.concatenate([padding, numerical_features], axis=0)
                attention_mask = np.concatenate([
                    np.zeros(padding_len, dtype=np.float32),
                    np.ones(current_len, dtype=np.float32)
                ])
            
            return numerical_features, attention_mask, negative_items
            
        except Exception as e:
            print(f"Feature building error: {e}")
            return None

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        
        if worker_info is None:
            files_to_process = self.data_dirs
        else:
            num_workers = worker_info.num_workers
            worker_id = worker_info.id
            num_files = len(self.data_dirs)
            
            files_per_worker = (num_files + num_workers - 1) // num_workers
            start_idx = worker_id * files_per_worker
            end_idx = min(start_idx + files_per_worker, num_files)
            files_to_process = self.data_dirs[start_idx:end_idx]

        for int_file in files_to_process:
            try:
                user_data, orders, user_ids = self._precompute_user_data(int_file)
                if user_data is None:
                    continue
                
                random.shuffle(user_ids)
                
                for user_id in user_ids:
                    results = self._process_user(user_id, user_data, orders)
                    if results:
                        for result in results:
                            yield result
                    
                    if random.random() < 0.1:  # 10% chance
                        gc.collect()
                
                del user_data, orders, user_ids
                gc.collect()
                
            except Exception as e:
                print(f"Error processing file {int_file}: {e}")
                continue

In [None]:
def final_collate_fn(batch):
    global item_embeds_dict
    try:
        padded_features, attention_masks, history_item_ids, pos_item_ids, neg_item_ids = zip(*batch)
        dummy_emb = np.zeros_like(next(iter(item_embeds_dict.values())), dtype=np.float32)

        def convert_ids_to_embeddings(ids_list):
            return [item_embeds_dict.get(id_, dummy_emb) for id_ in ids_list]
        
        history_embs = [convert_ids_to_embeddings(ids) for ids in history_item_ids]  # embeddings истории
        pos_embs = [convert_ids_to_embeddings(ids) for ids in pos_item_ids]          # embeddings позитивов
        
        # Обработка негативов - добавляем случайные если их мало
        processed_neg_embs = []
        all_item_ids = list(item_embeds_dict.keys())
        
        for i, neg_ids in enumerate(neg_item_ids):
            neg_embeddings = convert_ids_to_embeddings(neg_ids)
            
            # Если негативов меньше 5, добавляем случайные
            if len(neg_embeddings) < 5:
                needed_more = 5 - len(neg_embeddings)
                # Берем случайные айтемы, которых нет среди позитивов
                pos_ids_set = set(pos_item_ids[i])
                available_ids = [item_id for item_id in all_item_ids if item_id not in pos_ids_set]
                
                if len(available_ids) > 0:
                    import random
                    random_ids = random.sample(available_ids, min(needed_more, len(available_ids)))
                    random_embeddings = convert_ids_to_embeddings(random_ids)
                    neg_embeddings.extend(random_embeddings)
            
            processed_neg_embs.append(neg_embeddings)
        
        neg_embs = processed_neg_embs
    
        features_tensor = torch.as_tensor(np.stack(padded_features), dtype=torch.float32)
        masks_tensor = torch.as_tensor(np.stack(attention_masks), dtype=torch.float32)
        
        return features_tensor, masks_tensor, history_embs, pos_embs, neg_embs
        
    except Exception as e:
        print(f"Error in collate_fn: {e}")
        raise

In [None]:
class ItemTower(nn.Module):
    def __init__(self, input_dim, output_dim=128, dropout_rate=0.2):
        super(ItemTower, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(128, output_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate)
        )

    def forward(self, items):
        return self.mlp(items)

class UserTower(nn.Module):
    def __init__(
        self,
        item_emb_dim,
        user_feature_dim,
        gru_hidden_dim=64,
        user_output_dim=128,
        mlp_hidden_dim=256,
        dropout_rate=0.2
    ):
        super(UserTower, self).__init__()
        
        self.gru_input_dim = item_emb_dim + user_feature_dim
        self.gru = nn.GRU(
            input_size=self.gru_input_dim,
            hidden_size=gru_hidden_dim,
            batch_first=True,
            dropout=dropout_rate if 1 > 1 else 0.0
        )
        
        self.final_mlp = nn.Sequential(
            nn.Linear(gru_hidden_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(mlp_hidden_dim, user_output_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate)
        )
        self.output_dim = user_output_dim

    def forward(self, history_item_embs, user_other_features, attention_mask):
        batch_size, max_seq_len, user_feat_dim = user_other_features.shape
        item_emb_dim = history_item_embs[0].shape[-1]
        
        padded_history_item_embs = []
        for i, item_emb_seq in enumerate(history_item_embs):
            seq_len = item_emb_seq.shape[0]
            if seq_len < max_seq_len:
                padding = torch.zeros((max_seq_len - seq_len, item_emb_dim), device=item_emb_seq.device)
                padded_seq = torch.cat([item_emb_seq, padding], dim=0)
            else:
                padded_seq = item_emb_seq[:max_seq_len]
            padded_history_item_embs.append(padded_seq)
        
        item_embs_tensor = torch.stack(padded_history_item_embs, dim=0)
        
        combined_features = torch.cat([item_embs_tensor, user_other_features], dim=-1)

        lengths = attention_mask.sum(dim=1).long()
        lengths = torch.clamp(lengths, min=1)

        lengths_sorted, sorted_idx = lengths.sort(descending=True)
        combined_features_sorted = combined_features[sorted_idx]
        
        packed = rnn_utils.pack_padded_sequence(
            combined_features_sorted,
            lengths_sorted.cpu(),
            batch_first=True,
            enforce_sorted=True
        )

        packed_gru_out, hidden = self.gru(packed)
        hidden_sorted = hidden.squeeze(0)

        _, original_idx = sorted_idx.sort()
        hidden_original = hidden_sorted[original_idx]

        user_embedding = self.final_mlp(hidden_original)
        
        return user_embedding

class TwoTowerModel(nn.Module):
    def __init__(
        self,
        item_input_dim=128,
        item_output_dim=128,
        user_feature_dim=13,
        gru_hidden_dim=64,
        user_output_dim=128,
        mlp_hidden_dim=256,
        dropout_rate=0.2
    ):
        super(TwoTowerModel, self).__init__()
        
        self.item_emb_dim = item_input_dim
        
        self.user_tower = UserTower(
            item_emb_dim=self.item_emb_dim,
            user_feature_dim=user_feature_dim,
            gru_hidden_dim=gru_hidden_dim,
            user_output_dim=user_output_dim,
            mlp_hidden_dim=mlp_hidden_dim,
            dropout_rate=dropout_rate
        )
        
        self.item_tower = ItemTower(
            input_dim=item_input_dim,
            output_dim=item_output_dim,
            dropout_rate=dropout_rate
        )

    def forward(self, padded_features_tensor, attention_masks_tensor, history_item_embs_tensors, pos_items_tensors, neg_items_tensors):
        batch_size = padded_features_tensor.shape[0]
        
        user_emb = self.user_tower(
            history_item_embs_tensors,
            padded_features_tensor,
            attention_masks_tensor
        )
        
        pos_scores = []
        neg_scores = []

        for i in range(batch_size):
            # Позитивные
            if pos_items_tensors[i].shape[0] > 0:
                pos_item_emb = self.item_tower(pos_items_tensors[i])
                scores_pos = torch.matmul(user_emb[i], pos_item_emb.T)
            else:
                scores_pos = torch.tensor([], dtype=torch.float32, device=user_emb.device)
            pos_scores.append(scores_pos)

            # Негативные
            if neg_items_tensors[i].shape[0] > 0:
                neg_item_emb = self.item_tower(neg_items_tensors[i])
                scores_neg = torch.matmul(user_emb[i], neg_item_emb.T)
            else:
                scores_neg = torch.tensor([], dtype=torch.float32, device=user_emb.device)
            neg_scores.append(scores_neg)
        
        return pos_scores, neg_scores

class PairwiseLogisticLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, pos_scores, neg_scores):
        total_loss = 0.0
        total_pairs = 0
        for p_scores, n_scores in zip(pos_scores, neg_scores):
            if p_scores.numel() == 0 or n_scores.numel() == 0:
                continue
            diff = p_scores.unsqueeze(1) - n_scores.unsqueeze(0)
            loss_per_pair = torch.log(1 + torch.exp(-diff.clamp(max=10))).clamp(max=10)
            total_loss += loss_per_pair.sum()
            total_pairs += diff.numel()
        return total_loss / (total_pairs + 1e-9) if total_pairs > 0 else torch.tensor(0.0, requires_grad=True)

In [10]:
BATCH_SIZE = 128
NUM_WORKERS = 4
LEARNING_RATE = 1e-3
NUM_EPOCHS = 50
SAVE_PATH = "./checkpoints"
BEST_MODEL_FILENAME = "best_model.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [11]:
print(f"Device: {DEVICE}")
print(f"Batch size: {BATCH_SIZE}")

Device: cuda
Batch size: 128


In [12]:
#paths
data_dirs_path1 = "/kaggle/input/user-chunks-pt2/ultra_mega_user_chunks_pt1"
data_dirs_path2 = "/kaggle/input/user-chunks-pt1"
target_dirs_path = "/kaggle/input/ozon-data/ml_ozon_recsys_train_final_apparel_orders_data/ml_ozon_recsys_train_final_apparel_orders_data"
item_embeds_path = "/kaggle/input/mochaaaa/item_embeddings.npz"
checkpoint_path = "/kaggle/input/modelka20/pytorch/default/1/checkpoints/checkpoint_batch_2460.pth"
checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
data_dirs = glob.glob(os.path.join(data_dirs_path1, "*.parquet")) +\
    glob.glob(os.path.join(data_dirs_path2, "*.parquet"))

target_dirs = glob.glob(os.path.join(target_dirs_path, "*.parquet"))

In [None]:
print(f"Starting fine-tuning at {datetime.now()}")
print(f"Using device: {DEVICE}")
print("-" * 40)

dataset = UserHistoryDataset(
    data_dirs, 
    target_dirs, 
    succ_pred,
    fail_pred,
    state_map,
    window_size=500,
    stride=250
)

dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    collate_fn=final_collate_fn,
    persistent_workers=True if NUM_WORKERS > 0 else False
)
print("DataLoader is ready.")
print("-" * 40)

print("Creating model, loss, and optimizer...")
ITEM_INPUT_DIM = 128       
USER_FEATURE_DIM = 13      
ITEM_OUTPUT_DIM = 128
USER_OUTPUT_DIM = 128

model = TwoTowerModel(
    item_input_dim=ITEM_INPUT_DIM,
    item_output_dim=ITEM_OUTPUT_DIM,
    user_feature_dim=USER_FEATURE_DIM,
    gru_hidden_dim=64,
    user_output_dim=USER_OUTPUT_DIM,
    mlp_hidden_dim=128,
    dropout_rate=0.2
).to(DEVICE)

FINE_TUNE_LR = 5e-5
criterion = PairwiseLogisticLoss()
optimizer = optim.Adam(model.parameters(), lr=FINE_TUNE_LR, weight_decay=1e-6)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=800, gamma=0.9)
 
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

print("Model, loss, and optimizer are ready.")
print("-" * 40)

print("Starting fine-tuning loop...")
best_loss = float('inf')
total_batches_processed = 0

for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0
    num_batches = 0
    epoch_start_time = time.time()
    
    print(f"\n Epoch {epoch+1}/{NUM_EPOCHS}")
    print("-" * 20)

    for batch_idx, batch in enumerate(dataloader):
        try:
            padded_features_tuple, masks_tuple, history_item_embs_tuple, pos_items_tuple, neg_items_tuple = batch
            
            padded_features_tensor = torch.tensor(np.stack(padded_features_tuple), dtype=torch.float32)
            attention_masks_tensor = torch.tensor(np.stack(masks_tuple), dtype=torch.float32)

            padded_features_tensor = padded_features_tensor.to(DEVICE)
            attention_masks_tensor = attention_masks_tensor.to(DEVICE)
            
            history_item_embs_tensors = []
            for emb_list in history_item_embs_tuple:
                if len(emb_list) > 0:
                    if isinstance(emb_list[0], np.ndarray):
                        emb_array = np.array(emb_list)
                    else:
                        emb_array = np.array([emb_list]) if np.isscalar(emb_list[0]) else np.array(emb_list)
                    emb_tensor = torch.tensor(emb_array, dtype=torch.float32).to(DEVICE)
                else:
                    emb_tensor = torch.zeros((1, ITEM_INPUT_DIM), dtype=torch.float32, device=DEVICE)
                history_item_embs_tensors.append(emb_tensor)
            
            pos_items_tensors = []
            for emb_list in pos_items_tuple:
                if len(emb_list) > 0:
                    if isinstance(emb_list[0], np.ndarray):
                        emb_array = np.array(emb_list)
                    else:
                        emb_array = np.array([emb_list]) if np.isscalar(emb_list[0]) else np.array(emb_list)
                    emb_tensor = torch.tensor(emb_array, dtype=torch.float32).to(DEVICE)
                else:
                    emb_tensor = torch.zeros((0, ITEM_INPUT_DIM), dtype=torch.float32, device=DEVICE)
                pos_items_tensors.append(emb_tensor)
            
            neg_items_tensors = []
            for emb_list in neg_items_tuple:
                if len(emb_list) > 0:
                    if isinstance(emb_list[0], np.ndarray):
                        emb_array = np.array(emb_list)
                    else:
                        emb_array = np.array([emb_list]) if np.isscalar(emb_list[0]) else np.array(emb_list)
                    emb_tensor = torch.tensor(emb_array, dtype=torch.float32).to(DEVICE)
                else:
                    emb_tensor = torch.zeros((0, ITEM_INPUT_DIM), dtype=torch.float32, device=DEVICE)
                neg_items_tensors.append(emb_tensor)
            
            optimizer.zero_grad()
            pos_scores, neg_scores = model(
                padded_features_tensor,
                attention_masks_tensor,
                history_item_embs_tensors,
                pos_items_tensors,
                neg_items_tensors
            )
            
            loss = criterion(pos_scores, neg_scores)
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
            optimizer.step()
            
            total_loss += loss.item()
            num_batches += 1
            total_batches_processed += 1

            LOG_INTERVAL = 50
            SAVE_INTERVAL = 20
            
            if batch_idx % LOG_INTERVAL == 0:
                print(f"Batch {batch_idx:04d} | Loss: {loss.item():.4f}")

            if total_batches_processed % SAVE_INTERVAL == 0:
                current_avg_loss = total_loss / num_batches if num_batches > 0 else float('inf')
                print(f"[Batch {total_batches_processed:06d}] Current Avg Loss: {current_avg_loss:.4f}")
                print(f"Current LR: {optimizer.param_groups[0]['lr']:.2e}")
                
                os.makedirs(SAVE_PATH, exist_ok=True)
                checkpoint_path = os.path.join(SAVE_PATH, f"finetune_checkpoint_batch_{total_batches_processed}.pth")
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'loss': current_avg_loss,
                    'batch_num': total_batches_processed,
                    'epoch': epoch,
                    'lr': optimizer.param_groups[0]['lr']
                }, checkpoint_path)
                print(f"Fine-tune checkpoint saved to {checkpoint_path}")

                del pos_scores, neg_scores, loss
                if DEVICE.type == 'cuda':
                    torch.cuda.empty_cache()
                print(f"Memory cleared after checkpoint.")

            MEMORY_CLEAR_INTERVAL = 100
            if batch_idx % MEMORY_CLEAR_INTERVAL == 0 and batch_idx > 0:
                if DEVICE.type == 'cuda':
                    torch.cuda.empty_cache()
                print(f"Periodic memory clear (batch {batch_idx}).")

        except Exception as e:
            print(f"Error in batch {batch_idx}: {e}")
            if DEVICE.type == 'cuda':
                torch.cuda.empty_cache()
            print(f"Memory cleared after error in batch {batch_idx}.")
            continue

    avg_loss = total_loss / num_batches if num_batches > 0 else float('inf')
    epoch_time = time.time() - epoch_start_time
    
    print(f"Epoch {epoch+1} Summary:")
    print(f"  Avg Loss: {avg_loss:.4f}")
    print(f"  Time: {epoch_time/60:.2f} min")
    print(f"  Current LR: {optimizer.param_groups[0]['lr']:.2e}")

    scheduler.step()

    if avg_loss < best_loss:
        best_loss = avg_loss
        os.makedirs(SAVE_PATH, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(SAVE_PATH, f"best_finetuned_model_epoch_{epoch+1}_loss_{avg_loss:.4f}.pth"))
        print(f"New best fine-tune loss: {best_loss:.4f}")

    if DEVICE.type == 'cuda':
        torch.cuda.empty_cache()
    print(f"Memory cleared at the end of epoch {epoch+1}.")

print("\nFine-tuning finished!")


Starting fine-tuning at 2025-08-30 23:41:40.751307
Using device: cuda
----------------------------------------
DataLoader is ready.
----------------------------------------
Creating model, loss, and optimizer...
Model, loss, and optimizer are ready.
----------------------------------------
Starting fine-tuning loop...

Epoch 1/50
--------------------
Batch 0000 | Loss: 0.7796
[Batch 000020] Current Avg Loss: 0.6322
Current LR: 1.00e-03
Fine-tune checkpoint saved to ./checkpoints/finetune_checkpoint_batch_20.pth
Memory cleared after checkpoint.
[Batch 000040] Current Avg Loss: 0.6009
Current LR: 1.00e-03
Fine-tune checkpoint saved to ./checkpoints/finetune_checkpoint_batch_40.pth
Memory cleared after checkpoint.
Batch 0050 | Loss: 0.5696
[Batch 000060] Current Avg Loss: 0.5855
Current LR: 1.00e-03
Fine-tune checkpoint saved to ./checkpoints/finetune_checkpoint_batch_60.pth
Memory cleared after checkpoint.
[Batch 000080] Current Avg Loss: 0.5693
Current LR: 1.00e-03
Fine-tune checkpoint