# Welcome to Modal notebooks!

Write Python code and collaborate in real time. Your code runs in Modal's
**serverless cloud**, and anyone in the same workspace can join.

This notebook comes with some common Python libraries installed. Run
cells with `Shift+Enter`.

In [1]:
!pip install numpy
!pip install torch torchvision --index-url https://download.pytorch.org/whl/cu126
!pip install tqdm
!pip install transformers
!pip install tensorboard

Defaulting to user installation because normal site-packages is not writeable

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m
Defaulting to user installation because normal site-packages is not writeable
Looking in indexes: https://download.pytorch.org/whl/cu126

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m
Defaulting to user installation because normal site-packages is not writeable

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;

In [2]:
!which python
!which python3
!which pip
!which pip3
import sys
print(sys.executable)
print(sys.path)


/usr/bin/python
/usr/bin/python3
/usr/local/bin/pip
/usr/local/bin/pip3
/usr/bin/python3
['/usr/lib/python310.zip', '/usr/lib/python3.10', '/usr/lib/python3.10/lib-dynload', '', '/home/admin/.local/lib/python3.10/site-packages', '/usr/local/lib/python3.10/dist-packages', '/usr/lib/python3/dist-packages']


In [3]:
import sys
sys.path.append("/home/admin/.local/lib/python3.10/site-packages")
import torch
print(torch.__version__)


2.9.1+cu126


In [4]:
import os
import json
import shutil
import warnings
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from torch.cuda.amp import autocast, GradScaler

from tqdm.autonotebook import tqdm
from transformers import AutoModel, AutoTokenizer, get_linear_schedule_with_warmup

warnings.filterwarnings("ignore")

  from tqdm.autonotebook import tqdm


In [5]:
class Config:
    # Paths
    DATA_DIR = "/workspace/Data Warehouse"
    CORPUS_FILE = os.path.join(DATA_DIR, "ReCDS_benchmark/PAR/corpus.jsonl")
    TRAIN_QUERIES = os.path.join(DATA_DIR, "ReCDS_benchmark/queries/train_queries.jsonl")
    TRAIN_QRELS = os.path.join(DATA_DIR, "ReCDS_benchmark/PAR/qrels_train.tsv")
    DEV_QUERIES = os.path.join(DATA_DIR, "ReCDS_benchmark/queries/dev_queries.jsonl")
    DEV_QRELS = os.path.join(DATA_DIR, "ReCDS_benchmark/PAR/qrels_dev.tsv")
    BM25_HARD_NEGS_FILE = os.path.join(DATA_DIR, "ReCDS_benchmark/PAR/gold/bm25_hard_negs.json")
    PAIRS_TRAIN_FILE = os.path.join(DATA_DIR, "ReCDS_benchmark/PAR/gold/pairs_train.jsonl")

    # Model
    MODEL_NAME = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
    MAX_LENGTH = 384
    EMBEDDING_DIM = 768
    POOLING = "mean"

    # Training
    BATCH_SIZE = 324
    NUM_EPOCHS = 10
    LEARNING_RATE = 2e-5
    WEIGHT_DECAY = 0.01
    WARMUP_RATIO = 0.1
    MAX_GRAD_NORM = 1.0
    TEMPERATURE = 0.05
    NUM_HARD_NEGATIVES = 2
    USE_MIXED_PRECISION = True
    # System
    NUM_WORKERS = 120
    CHECKPOINT_DIR = os.path.join(DATA_DIR, "ReCDS_benchmark/PAR/checkpoints")
    LOG_DIR = os.path.join(DATA_DIR, "ReCDS_benchmark/PAR/logs")

In [6]:
class PARDatasetOptimized(Dataset):
    def __init__(self, queries_file, qrels_file, corpus_file, tokenizer, max_length=512,
                 bm25_hard_negs_file=None, pairs_train_file=None, num_hard_negatives=2):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_hard_negatives = num_hard_negatives

        print("Loading queries...")
        self.queries = {}
        with open(queries_file, 'r', encoding='utf-8') as f:
            for line in f:
                query = json.loads(line.strip())
                self.queries[query['_id']] = query['text']
        print(f"Loaded {len(self.queries)} queries")

        # Load BM25 hard negatives
        # self.bm25_hard_negs = {}
        # if bm25_hard_negs_file:
        #     print(f"Loading BM25 hard negatives...")
        #     with open(bm25_hard_negs_file, 'r', encoding='utf-8') as f:
        #         self.bm25_hard_negs = json.load(f)
        #     print(f"Loaded BM25 hard negatives for {len(self.bm25_hard_negs)} queries")

        # Load training pairs
        # if pairs_train_file:
        print(f"Loading training pairs...")
        self.pairs = []
        with open(pairs_train_file, 'r', encoding='utf-8') as f:
            for line in f:
                pair = json.loads(line.strip())
                self.pairs.append({
                    'query_id': pair['query_id'],
                    'pos_id': pair['pos_id'],
                    'neg_ids': pair.get('neg_ids', [])
                })
        print(f"Loaded {len(self.pairs)} training pairs")
        # else:
        #     print("Loading qrels...")
        #     qrels = {}
        #     with open(qrels_file, 'r', encoding='utf-8') as f:
        #         for line in f:
        #             parts = line.strip().split('\t')
        #             if len(parts) == 3:
        #                 query_id, doc_id, relevance = parts
        #                 try:
        #                     if int(relevance) > 0:
        #                         if query_id not in qrels:
        #                             qrels[query_id] = []
        #                         qrels[query_id].append(doc_id)
        #                 except ValueError:
        #                     continue

            # print("Creating training pairs...")
            # self.pairs = []
            # for query_id, doc_ids in qrels.items():
            #     if query_id in self.queries:
            #         for doc_id in doc_ids:
            #             neg_ids = self.bm25_hard_negs.get(query_id, [])[:self.num_hard_negatives]
            #             self.pairs.append({
            #                 'query_id': query_id,
            #                 'pos_id': doc_id,
            #                 'neg_ids': neg_ids
            #             })
            # print(f"Created {len(self.pairs)} pairs")

        # Load required documents
        needed_doc_ids = set()
        for pair in self.pairs:
            needed_doc_ids.add(pair['pos_id'])
            needed_doc_ids.update(pair['neg_ids'])
        print(f"Need to load {len(needed_doc_ids)} documents")

        print(f"Loading documents...")
        self.corpus = {}
        loaded_count = 0
        with open(corpus_file, 'r', encoding='utf-8') as f:
            for line in f:
                doc = json.loads(line.strip())
                doc_id = str(doc.get('_id', ''))
                if doc_id in needed_doc_ids:
                    title = doc.get('title', '').strip()
                    abstract = doc.get('text', '').strip()
                    if title or abstract:
                        self.corpus[doc_id] = {'title': title, 'abstract': abstract}
                        loaded_count += 1
                # if loaded_count >= len(needed_doc_ids):
                #     break

        print(f"Loaded {len(self.corpus)} documents")

        # Filter pairs
        self.pairs = [
            {
                'query_id': p['query_id'],
                'pos_id': p['pos_id'],
                'neg_ids': [n for n in p['neg_ids'] if n in self.corpus]
            }
            for p in self.pairs
            if p['query_id'] in self.queries and p['pos_id'] in self.corpus
        ]
        print(f"Filtered to {len(self.pairs)} valid pairs")

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        pair = self.pairs[idx]
        query_id = pair['query_id']
        pos_id = pair['pos_id']
        neg_ids = pair['neg_ids']
    
        # âœ… Kiá»ƒm tra dá»¯ liá»‡u bá»‹ thiáº¿u
        if query_id not in self.queries or pos_id not in self.corpus:
            return None
    
        query_text = self.queries[query_id]
        pos_doc = self.corpus[pos_id]
        pos_text = f"{pos_doc['title']} {pos_doc['abstract']}".strip()
    
        query_enc = self.tokenizer(
            query_text, max_length=self.max_length, padding='max_length',
            truncation=True, return_tensors='pt'
        )
        pos_enc = self.tokenizer(
            pos_text, max_length=self.max_length, padding='max_length',
            truncation=True, return_tensors='pt'
        )
    
        result = {
            'query_input_ids': query_enc['input_ids'].squeeze(0),
            'query_attention_mask': query_enc['attention_mask'].squeeze(0),
            'pos_doc_input_ids': pos_enc['input_ids'].squeeze(0),
            'pos_doc_attention_mask': pos_enc['attention_mask'].squeeze(0),
        }
    
        # Encode hard negatives náº¿u cÃ³
        if neg_ids:
            neg_ids_list = []
            neg_masks_list = []
            for neg_id in neg_ids[:self.num_hard_negatives]:
                if neg_id in self.corpus:
                    neg_doc = self.corpus[neg_id]
                    neg_text = f"{neg_doc['title']} {neg_doc['abstract']}".strip()
                    neg_enc = self.tokenizer(
                        neg_text, max_length=self.max_length,
                        padding='max_length', truncation=True, return_tensors='pt'
                    )
                    neg_ids_list.append(neg_enc['input_ids'].squeeze(0))
                    neg_masks_list.append(neg_enc['attention_mask'].squeeze(0))
    
            if neg_ids_list:
                result['neg_doc_input_ids'] = torch.stack(neg_ids_list)
                result['neg_doc_attention_mask'] = torch.stack(neg_masks_list)
            else:
                result['neg_doc_input_ids'] = torch.empty(0, self.max_length, dtype=torch.long)
                result['neg_doc_attention_mask'] = torch.empty(0, self.max_length, dtype=torch.long)
        else:
            result['neg_doc_input_ids'] = torch.empty(0, self.max_length, dtype=torch.long)
            result['neg_doc_attention_mask'] = torch.empty(0, self.max_length, dtype=torch.long)
    
        return result


In [7]:
class BiEncoder(nn.Module):
    def __init__(self, model_name, embedding_dim, pooling='mean'):
        super().__init__()
        self.query_encoder = AutoModel.from_pretrained(model_name)
        self.doc_encoder = AutoModel.from_pretrained(model_name)
        self.pooling = pooling

    def pool_embeddings(self, last_hidden_state, attention_mask):
        if self.pooling == 'cls':
            return last_hidden_state[:, 0, :]
        elif self.pooling == 'mean':
            token_embeddings = last_hidden_state
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
            sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
            return sum_embeddings / sum_mask
        elif self.pooling == 'max':
            token_embeddings = last_hidden_state
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            token_embeddings[input_mask_expanded == 0] = -1e9
            return torch.max(token_embeddings, 1)[0]

    def encode_query(self, input_ids, attention_mask):
        outputs = self.query_encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        pooled = self.pool_embeddings(outputs.last_hidden_state, attention_mask)
        return F.normalize(pooled, p=2, dim=1)

    def encode_doc(self, input_ids, attention_mask):
        outputs = self.doc_encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        pooled = self.pool_embeddings(outputs.last_hidden_state, attention_mask)
        return F.normalize(pooled, p=2, dim=1)

    def forward(self, query_input_ids=None, query_attention_mask=None,
                doc_input_ids=None, doc_attention_mask=None, mode='dual'):
        if mode == 'dual':
            return self.encode_query(query_input_ids, query_attention_mask), \
                   self.encode_doc(doc_input_ids, doc_attention_mask)
        elif mode == 'doc':
            return self.encode_doc(doc_input_ids, doc_attention_mask)
        elif mode == 'query':
            return self.encode_query(query_input_ids, query_attention_mask)

In [8]:
class InfoNCELoss(nn.Module):
    def __init__(self, temperature=0.05):
        super().__init__()
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, query_embeddings, doc_embeddings, hard_neg_embeddings=None):
        batch_size = query_embeddings.size(0)
        similarity_matrix = torch.matmul(query_embeddings, doc_embeddings.T) / self.temperature

        if hard_neg_embeddings is not None and hard_neg_embeddings.size(1) > 0:
            hard_neg_sim = torch.bmm(
                query_embeddings.unsqueeze(1),
                hard_neg_embeddings.transpose(1, 2)
            ).squeeze(1) / self.temperature
            similarity_matrix = torch.cat([similarity_matrix, hard_neg_sim], dim=1)

        labels = torch.arange(batch_size, device=query_embeddings.device)
        loss_q2d = self.criterion(similarity_matrix, labels)
        
        doc2query_sim = torch.matmul(doc_embeddings, query_embeddings.T) / self.temperature
        loss_d2q = self.criterion(doc2query_sim, labels)
        
        return (loss_q2d + loss_d2q) / 2.0

In [9]:
def compute_metrics(query_embeddings, doc_embeddings, query_ids, doc_ids, qrels_dict, k_values=[10, 100, 1000]):
    similarity_matrix = torch.matmul(query_embeddings, doc_embeddings.T).cpu().numpy()
    metrics = defaultdict(list)

    for i, query_id in enumerate(query_ids):
        relevant_docs = qrels_dict.get(query_id, set())
        if not relevant_docs:
            continue

        scores = similarity_matrix[i]
        sorted_indices = np.argsort(-scores)
        sorted_doc_ids = [doc_ids[idx] for idx in sorted_indices]

        for k in k_values:
            top_k_docs = sorted_doc_ids[:k]
            num_relevant = len(set(top_k_docs) & relevant_docs)
            
            metrics[f'Recall@{k}'].append(num_relevant / len(relevant_docs))
            metrics[f'P@{k}'].append(num_relevant / k)
            
            dcg = sum(1.0 / np.log2(rank + 2) for rank, doc_id in enumerate(top_k_docs) if doc_id in relevant_docs)
            idcg = sum(1.0 / np.log2(rank + 2) for rank in range(min(len(relevant_docs), k)))
            metrics[f'nDCG@{k}'].append(dcg / idcg if idcg > 0 else 0.0)

        for rank, doc_id in enumerate(sorted_doc_ids, start=1):
            if doc_id in relevant_docs:
                metrics['MRR'].append(1.0 / rank)
                break
        else:
            metrics['MRR'].append(0.0)

    return {name: np.mean(values) for name, values in metrics.items()}


def evaluate_full_corpus(model, dataset, device, batch_size=32, use_amp=False):
    """
    ThÃªm tham sá»‘ use_amp Ä‘á»ƒ há»— trá»£ mixed precision trong quÃ¡ trÃ¬nh evaluation
    """
    from torch.cuda.amp import autocast
    
    model.eval()
    tokenizer = dataset.tokenizer
    max_length = dataset.max_length

    # Build qrels
    qrels_dict = defaultdict(set)
    for pair in dataset.pairs:
        qrels_dict[pair['query_id']].add(pair['pos_id'])

    # Encode queries
    query_ids = list(dataset.queries.keys())
    query_embeddings_list = []
    print("Encoding queries...")
    for i in tqdm(range(0, len(query_ids), batch_size)):
        batch_ids = query_ids[i:i + batch_size]
        batch_texts = [dataset.queries[qid] for qid in batch_ids]
        enc = tokenizer(batch_texts, max_length=max_length, padding='max_length', 
                       truncation=True, return_tensors='pt')
        with torch.no_grad():
            if use_amp:
                with autocast():
                    embs = model(query_input_ids=enc['input_ids'].to(device),
                                query_attention_mask=enc['attention_mask'].to(device), mode='query')
            else:
                embs = model(query_input_ids=enc['input_ids'].to(device),
                            query_attention_mask=enc['attention_mask'].to(device), mode='query')
            query_embeddings_list.append(embs.cpu())
    query_embeddings = torch.cat(query_embeddings_list, dim=0)

    # Encode documents
    doc_ids = list(dataset.corpus.keys())
    doc_embeddings_list = []
    print("Encoding documents...")
    for i in tqdm(range(0, len(doc_ids), batch_size)):
        batch_ids = doc_ids[i:i + batch_size]
        batch_texts = [f"{dataset.corpus[did]['title']} {dataset.corpus[did]['abstract']}".strip() 
                      for did in batch_ids]
        enc = tokenizer(batch_texts, max_length=max_length, padding='max_length',
                       truncation=True, return_tensors='pt')
        with torch.no_grad():
            if use_amp:
                with autocast():
                    embs = model(doc_input_ids=enc['input_ids'].to(device),
                                doc_attention_mask=enc['attention_mask'].to(device), mode='doc')
            else:
                embs = model(doc_input_ids=enc['input_ids'].to(device),
                            doc_attention_mask=enc['attention_mask'].to(device), mode='doc')
            doc_embeddings_list.append(embs.cpu())
    doc_embeddings = torch.cat(doc_embeddings_list, dim=0)

    return compute_metrics(query_embeddings, doc_embeddings, query_ids, doc_ids, qrels_dict)



def save_checkpoint(filepath, epoch, step, model, optimizer, loss, ndcg_at_10=None, scaler=None):
    model_state = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
    checkpoint = {
        'epoch': epoch,
        'step': step,
        'loss': loss,
        'ndcg_at_10': ndcg_at_10 if ndcg_at_10 is not None else 0.0,
        'model_state_dict': model_state,
        'optimizer_state_dict': optimizer.state_dict(),
    }
    
    # Save scaler state if using mixed precision
    if scaler is not None:
        checkpoint['scaler_state_dict'] = scaler.state_dict()

    last_path = os.path.join(filepath, "last_model.pt")
    torch.save(checkpoint, last_path)

    best_path = os.path.join(filepath, "best_model.pt")
    if not os.path.exists(best_path):
        torch.save(checkpoint, best_path)
    else:
        best_checkpoint = torch.load(best_path, map_location='cpu')
        best_ndcg = best_checkpoint.get('ndcg_at_10', 0.0)
        if ndcg_at_10 is not None and ndcg_at_10 > best_ndcg:
            torch.save(checkpoint, best_path)
            print(f"Updated best model: nDCG@10 {best_ndcg:.4f} â†’ {ndcg_at_10:.4f}")


def load_checkpoint(filepath, model, optimizer, device, scaler=None):
    last_path = "last_model.pt"
    best_path = "best_model.pt"
    import numpy as np
    torch.serialization.add_safe_globals([np._core.multiarray.scalar])
    model_to_load = model.module if isinstance(model, nn.DataParallel) else model
    
    if os.path.isfile(last_path):
        checkpoint = torch.load(last_path, map_location=device, weights_only=False)
        model_to_load.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if scaler is not None and 'scaler_state_dict' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        print(f"Loaded from last_model.pt epoch {checkpoint['epoch']}")
        return checkpoint['epoch'] + 1, checkpoint.get('step', 0), \
               checkpoint.get('loss', float('inf')), checkpoint.get('ndcg_at_10', 0.0)
    
    if os.path.isfile(best_path):
        checkpoint = torch.load(best_path, map_location=device, weights_only=False)
        model_to_load.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if scaler is not None and 'scaler_state_dict' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        print(f"Loaded from best_model.pt epoch {checkpoint['epoch']}")
        return checkpoint['epoch'] + 1, checkpoint.get('step', 0), \
               checkpoint.get('loss', float('inf')), checkpoint.get('ndcg_at_10', 0.0)
    
    print("No checkpoint found, starting from scratch")
    return 0, 0, float('inf'), 0.0

def collate_skip_none(batch):
    batch = [b for b in batch if b is not None]
    if len(batch) == 0:
        return None
    return torch.utils.data.default_collate(batch)
    
def train():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    n_gpus = torch.cuda.device_count()
    print(f"Available GPUs: {n_gpus}")
    
    use_amp = Config.USE_MIXED_PRECISION and device.type == 'cuda'
    if use_amp:
        print("âœ“ Mixed Precision Training ENABLED (FP16) - Expect ~2x speedup")
    else:
        print("âœ— Mixed Precision Training DISABLED")

    tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME, use_fast=True)
    model = BiEncoder(Config.MODEL_NAME, Config.EMBEDDING_DIM, Config.POOLING).to(device)
    
    if n_gpus > 1:
        model = nn.DataParallel(model)
        print(f"Using DataParallel on {n_gpus} GPUs")

    train_dataset = PARDatasetOptimized(
        queries_file=Config.TRAIN_QUERIES,
        qrels_file=Config.TRAIN_QRELS,
        corpus_file=Config.CORPUS_FILE,
        tokenizer=tokenizer,
        max_length=Config.MAX_LENGTH,
        bm25_hard_negs_file=Config.BM25_HARD_NEGS_FILE,
        pairs_train_file=Config.PAIRS_TRAIN_FILE,
        num_hard_negatives=Config.NUM_HARD_NEGATIVES
    )
    train_dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=Config.BATCH_SIZE,
        num_workers=Config.NUM_WORKERS,
        shuffle=True,
        drop_last=False,
        collate_fn=collate_skip_none,
        pin_memory=True if device.type == 'cuda' else False
    )

    dev_dataset = PARDatasetOptimized(
        queries_file=Config.DEV_QUERIES,
        qrels_file=Config.DEV_QRELS,
        corpus_file=Config.CORPUS_FILE,
        tokenizer=tokenizer,
        max_length=Config.MAX_LENGTH,
        bm25_hard_negs_file=None,  # Dev khÃ´ng cáº§n hard negatives
        pairs_train_file=Config.PAIRS_TRAIN_FILE,
        num_hard_negatives=0
    )
    dev_dataloader = DataLoader(
        dataset=dev_dataset,
        batch_size=Config.BATCH_SIZE,
        num_workers=Config.NUM_WORKERS,
        shuffle=False,
        drop_last=False,
        pin_memory=True if device.type == 'cuda' else False
    )

    criterion = InfoNCELoss(Config.TEMPERATURE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=Config.LEARNING_RATE, 
                                 weight_decay=Config.WEIGHT_DECAY)

    # Initialize GradScaler for mixed precision
    scaler = GradScaler() if use_amp else None

    total_steps = len(train_dataloader) * Config.NUM_EPOCHS
    warmup_steps = int(total_steps * Config.WARMUP_RATIO)
    scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps)

    if os.path.isdir(Config.LOG_DIR):
        shutil.rmtree(Config.LOG_DIR)
    os.makedirs(Config.LOG_DIR)
    os.makedirs(Config.CHECKPOINT_DIR, exist_ok=True)

    writer = SummaryWriter(Config.LOG_DIR)
    start_epoch, global_step, best_loss, best_ndcg = load_checkpoint(
        Config.CHECKPOINT_DIR, model, optimizer, device, scaler
    )
    print("load.....")
    print(start_epoch, Config.NUM_EPOCHS)
    for epoch in range(start_epoch, 10):
        model.train()
        all_losses = []

        progress_bar = tqdm(train_dataloader, colour="BLUE")
        for i, batch in enumerate(progress_bar):
            optimizer.zero_grad()
            
            # Mixed precision training
            if use_amp:
                with autocast():
                    query_embeddings, doc_embeddings = model(
                        batch['query_input_ids'].to(device),
                        batch['query_attention_mask'].to(device),
                        batch['pos_doc_input_ids'].to(device),
                        batch['pos_doc_attention_mask'].to(device)
                    )

                    hard_neg_embeddings = None
                    if 'neg_doc_input_ids' in batch and batch['neg_doc_input_ids'].size(1) > 0:
                        neg_ids = batch['neg_doc_input_ids'].to(device)
                        neg_mask = batch['neg_doc_attention_mask'].to(device)
                        batch_size, num_negs, max_len = neg_ids.size()
                        
                        neg_embs = model(doc_input_ids=neg_ids.view(-1, max_len),
                                        doc_attention_mask=neg_mask.view(-1, max_len), mode='doc')
                        hard_neg_embeddings = neg_embs.view(batch_size, num_negs, -1)

                    loss = criterion(query_embeddings, doc_embeddings, hard_neg_embeddings)
                
                # Backward with gradient scaling
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), Config.MAX_GRAD_NORM)
                scaler.step(optimizer)
                scaler.update()
            else:
                # Standard FP32 training
                query_embeddings, doc_embeddings = model(
                    batch['query_input_ids'].to(device),
                    batch['query_attention_mask'].to(device),
                    batch['pos_doc_input_ids'].to(device),
                    batch['pos_doc_attention_mask'].to(device)
                )

                hard_neg_embeddings = None
                if 'neg_doc_input_ids' in batch and batch['neg_doc_input_ids'].size(1) > 0:
                    neg_ids = batch['neg_doc_input_ids'].to(device)
                    neg_mask = batch['neg_doc_attention_mask'].to(device)
                    batch_size, num_negs, max_len = neg_ids.size()
                    
                    neg_embs = model(doc_input_ids=neg_ids.view(-1, max_len),
                                    doc_attention_mask=neg_mask.view(-1, max_len), mode='doc')
                    hard_neg_embeddings = neg_embs.view(batch_size, num_negs, -1)

                loss = criterion(query_embeddings, doc_embeddings, hard_neg_embeddings)
                
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), Config.MAX_GRAD_NORM)
                optimizer.step()
            
            scheduler.step()
            
            progress_bar.set_description(
                f"Epoch {epoch+1}/{Config.NUM_EPOCHS} Loss {loss.item():.4f} LR {scheduler.get_last_lr()[0]:.2e}"
            )

            all_losses.append(loss.item())
            writer.add_scalar("Train/loss", loss.item(), global_step)
            writer.add_scalar("Train/learning_rate", scheduler.get_last_lr()[0], global_step)
            global_step += 1
            if global_step % 1200 == 0:
                ckpt_path = os.path.join(Config.CHECKPOINT_DIR, f"step_{global_step}.pt")
                save_checkpoint(Config.CHECKPOINT_DIR, epoch, global_step, model, optimizer, 
                                loss.item(), None, scaler)
                print(f"ðŸ’¾ Saved checkpoint at step {global_step} â†’ {ckpt_path}")
        train_loss = np.mean(all_losses)
        print(f"\nEpoch {epoch+1} - Train Loss: {train_loss:.4f}")

        # Validation
        model.eval()
        dev_losses = []
        with torch.no_grad():
            for batch in tqdm(dev_dataloader, desc="Dev loss"):
                if use_amp:
                    with autocast():
                        query_embs, doc_embs = model(
                            batch['query_input_ids'].to(device),
                            batch['query_attention_mask'].to(device),
                            batch['pos_doc_input_ids'].to(device),
                            batch['pos_doc_attention_mask'].to(device)
                        )
                        loss = criterion(query_embs, doc_embs, None)
                else:
                    query_embs, doc_embs = model(
                        batch['query_input_ids'].to(device),
                        batch['query_attention_mask'].to(device),
                        batch['pos_doc_input_ids'].to(device),
                        batch['pos_doc_attention_mask'].to(device)
                    )
                    loss = criterion(query_embs, doc_embs, None)
                dev_losses.append(loss.item())

        dev_loss = np.mean(dev_losses)
        print(f"Epoch {epoch+1} - Dev Loss: {dev_loss:.4f}")

        # dev_metrics = evaluate_full_corpus(model, dev_dataset, device, use_amp=use_amp)
        # print(f"\nDev Metrics:")
        # print(f"  MRR:       {dev_metrics.get('MRR', 0):.4f}")
        # print(f"  nDCG@10:   {dev_metrics.get('nDCG@10', 0):.4f}")
        # print(f"  Recall@1k: {dev_metrics.get('Recall@1000', 0):.4f}")

        # writer.add_scalar("Dev/loss", dev_loss, epoch)
        # writer.add_scalar("Dev/MRR", dev_metrics.get('MRR', 0), epoch)
        # writer.add_scalar("Dev/nDCG@10", dev_metrics.get('nDCG@10', 0), epoch)
        # writer.add_scalar("Dev/Recall@1k", dev_metrics.get('Recall@1000', 0), epoch)

        save_checkpoint(Config.CHECKPOINT_DIR, epoch, global_step, model, optimizer, 
                       dev_loss, dev_metrics.get('nDCG@10', 0), scaler)

    writer.close()


if __name__ == '__main__':
    print("run")
    train()
    print("xong")

run
Using device: cuda
Available GPUs: 3
âœ“ Mixed Precision Training ENABLED (FP16) - Expect ~2x speedup
Using DataParallel on 3 GPUs
Loading queries...
Loaded 155151 queries
Loading training pairs...
Loaded 1978118 training pairs
Need to load 993196 documents
Loading documents...
Loaded 993196 documents
Filtered to 1978118 valid pairs
Loading queries...
Loaded 5924 queries
Loading training pairs...
Loaded 1978118 training pairs
Need to load 993196 documents
Loading documents...
Loaded 993196 documents
Filtered to 0 valid pairs
Loaded from last_model.pt epoch 0
load.....
1 10


  0%|          | 0/6106 [00:00<?, ?it/s]

ðŸ’¾ Saved checkpoint at step 7200 â†’ /workspace/Data Warehouse/ReCDS_benchmark/PAR/checkpoints/step_7200.pt
ðŸ’¾ Saved checkpoint at step 8400 â†’ /workspace/Data Warehouse/ReCDS_benchmark/PAR/checkpoints/step_8400.pt
ðŸ’¾ Saved checkpoint at step 9600 â†’ /workspace/Data Warehouse/ReCDS_benchmark/PAR/checkpoints/step_9600.pt


KeyboardInterrupt: 

In [None]:
print(1)

In [None]:
import torch
torch.cuda.empty_cache()
