# Generate Embeddings (Multi-GPU & Sharding Support)

This notebook generates embeddings for the corpus using a trained Bi-Encoder model. 
It supports:
- **Multi-GPU**: Automatically uses DataParallel if multiple GPUs are detected.
- **Sharding**: Allows splitting the workload across multiple machines/notebooks; each shard writes its own `corpus_embeddings_shard_<SHARD_ID>.jsonl` file.

In [None]:
import json
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer
from tqdm.notebook import tqdm

## 1. Configuration

In [None]:
class Config:
    """Static settings for the embedding run, read as plain class attributes."""

    # --- Distributed processing (sharding) ---
    # Change these two values to split the work across several machines.
    SHARD_ID = 0      # Index of this machine, from 0 to TOTAL_SHARDS-1
    TOTAL_SHARDS = 1  # Number of machines/notebooks sharing the corpus

    # --- Execution parameters ---
    BATCH_SIZE = 256
    MODEL_PATH = "best_model.pt"

    # --- Output ---
    OUTPUT_DIR = "embeddings_output"

    # --- Model ---
    POOLING = "mean"  # one of 'cls', 'mean', 'max'
    EMBEDDING_DIM = 768
    MAX_LENGTH = 512
    MODEL_NAME = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"

    # --- Input paths ---
    DATA_DIR = "C:/Users/tam/Desktop/Data/Data Warehouse"
    CORPUS_FILE = os.path.join(DATA_DIR, "ReCDS_benchmark/PAR/corpus.jsonl")

## 2. Model Definition (BiEncoder)

In [None]:
class BiEncoder(nn.Module):
    """Two-tower encoder with separate transformers for queries and documents.

    Both towers are created with ``AutoModel.from_pretrained(model_name)``.
    Token states are reduced to one vector per sequence with the configured
    pooling ('cls', 'mean' or 'max') and then L2-normalized.
    """

    def __init__(self, model_name='microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext',
                 embedding_dim=768, pooling='cls'):
        """Load both towers.

        Args:
            model_name: Identifier passed to ``AutoModel.from_pretrained``.
            embedding_dim: Stored on the instance; no layer here is sized from it.
            pooling: 'cls', 'mean' or 'max' (validated lazily in ``pool_embeddings``).
        """
        super().__init__()

        self.query_encoder = AutoModel.from_pretrained(model_name)
        self.doc_encoder = AutoModel.from_pretrained(model_name)

        self.pooling = pooling
        self.embedding_dim = embedding_dim

    def pool_embeddings(self, last_hidden_state, attention_mask):
        """Reduce per-token states to a single vector per sequence.

        Args:
            last_hidden_state: Token states; dimension 1 is the token axis.
            attention_mask: Mask with 0 at positions to ignore; it is
                unsqueezed on the last axis to line up with the token states.

        Returns:
            The pooled tensor (token axis removed).

        Raises:
            ValueError: If ``self.pooling`` is not 'cls', 'mean' or 'max'.
        """
        if self.pooling == 'cls':
            # Use the first ([CLS]) token embedding.
            return last_hidden_state[:, 0, :]

        if self.pooling == 'mean':
            # Masked mean: ignored positions contribute neither to the sum
            # nor to the count.
            mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
            summed = torch.sum(last_hidden_state * mask, 1)
            # Clamp so a fully masked row does not divide by zero.
            counts = torch.clamp(mask.sum(1), min=1e-9)
            return summed / counts

        if self.pooling == 'max':
            # Push ignored positions to a very low value so they never win
            # the max. masked_fill is out-of-place, so the encoder output
            # handed in by the caller is left untouched.
            ignored = attention_mask.unsqueeze(-1) == 0
            # Keep -1e9 where representable; fall back to the dtype minimum
            # for narrower float types (e.g. half precision).
            fill_value = max(-1e9, torch.finfo(last_hidden_state.dtype).min)
            masked_states = last_hidden_state.masked_fill(ignored, fill_value)
            return torch.max(masked_states, 1)[0]

        raise ValueError(f"Unknown pooling method: {self.pooling}")

    def _encode(self, encoder, input_ids, attention_mask):
        """Run one tower, pool its token states and L2-normalize the result."""
        outputs = encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        pooled = self.pool_embeddings(outputs.last_hidden_state, attention_mask)
        return F.normalize(pooled, p=2, dim=1)

    def encode_query(self, input_ids, attention_mask):
        """Encode a query (patient summary) with the query tower."""
        return self._encode(self.query_encoder, input_ids, attention_mask)

    def encode_doc(self, input_ids, attention_mask):
        """Encode a document (article title + abstract) with the document tower."""
        return self._encode(self.doc_encoder, input_ids, attention_mask)

    def forward(self, query_input_ids=None, query_attention_mask=None,
                doc_input_ids=None, doc_attention_mask=None, mode='dual'):
        """Dispatch to one or both towers.

        Args:
            query_input_ids, query_attention_mask: Inputs for the query tower
                (used when ``mode`` is 'dual' or 'query').
            doc_input_ids, doc_attention_mask: Inputs for the document tower
                (used when ``mode`` is 'dual' or 'doc').
            mode: 'dual', 'doc' or 'query'.

        Returns:
            ``(query_embeddings, doc_embeddings)`` for 'dual'; a single
            embedding tensor for 'doc' or 'query'.

        Raises:
            ValueError: If ``mode`` is not one of the three supported values.
        """
        if mode == 'dual':
            query_embeddings = self.encode_query(query_input_ids, query_attention_mask)
            doc_embeddings = self.encode_doc(doc_input_ids, doc_attention_mask)
            return query_embeddings, doc_embeddings
        if mode == 'doc':
            return self.encode_doc(doc_input_ids, doc_attention_mask)
        if mode == 'query':
            return self.encode_query(query_input_ids, query_attention_mask)
        raise ValueError(f"Unknown mode: {mode}")

## 3. Dataset & Dataloader

In [None]:
class CorpusDataset(Dataset):
    """A contiguous slice of a JSONL corpus, held in memory and tokenized per item.

    The raw lines in ``[start_idx, end_idx)`` are read eagerly in ``__init__``;
    JSON parsing and tokenization happen lazily in ``__getitem__``.
    """

    def __init__(self, corpus_file, tokenizer, max_length=384, start_idx=0, end_idx=None):
        """Read the requested line range of ``corpus_file`` into memory.

        Args:
            corpus_file: Path to a UTF-8 JSONL file, one document per line.
            tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
                padding='max_length', truncation=True, return_tensors='pt')``.
            max_length: Token length passed to the tokenizer.
            start_idx: Zero-based index of the first line to keep.
            end_idx: Zero-based index one past the last line to keep, or
                ``None`` to read to the end of the file. ``0`` is a valid
                bound and yields an empty dataset.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.lines = []

        print(f"Loading corpus from line {start_idx} to {end_idx if end_idx is not None else 'end'}...")
        with open(corpus_file, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f):
                # Compare against None explicitly so that end_idx == 0 is
                # honoured as a bound instead of meaning "no limit".
                if end_idx is not None and line_no >= end_idx:
                    break
                if line_no >= start_idx:
                    self.lines.append(line)
        print(f"Loaded {len(self.lines)} documents into memory.")

    def __len__(self):
        """Return the number of lines loaded for this slice."""
        return len(self.lines)

    def __getitem__(self, idx):
        """Parse and tokenize one document.

        Returns:
            A dict with 'doc_id' (str), 'input_ids' and 'attention_mask'
            (leading batch axis squeezed away), or ``None`` if the line could
            not be processed. ``idx`` in the error message is relative to
            this slice, not to the whole file.
        """
        line = self.lines[idx]
        # Best-effort by design: a bad line is reported and skipped rather
        # than aborting the run. collate_fn drops the resulting None entries.
        try:
            doc = json.loads(line)
            doc_id = str(doc['_id'])

            title = doc.get('title', '')
            body = doc.get('text', '')
            # A JSON null would otherwise be formatted as the literal "None"
            # and end up inside the encoded text.
            title = '' if title is None else title
            body = '' if body is None else body
            text = f"{title} {body}".strip()

            enc = self.tokenizer(
                text,
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )

            return {
                'doc_id': doc_id,
                'input_ids': enc['input_ids'].squeeze(0),
                'attention_mask': enc['attention_mask'].squeeze(0)
            }
        except Exception as e:
            print(f"Error processing line {idx}: {e}")
            return None

def collate_fn(batch):
    """Merge dataset items into one batch, discarding failed (``None``) items.

    Returns ``None`` when no usable item remains; otherwise a dict with the
    list of 'doc_id' values and the stacked 'input_ids' / 'attention_mask'
    tensors.
    """
    usable = list(filter(lambda item: item is not None, batch))
    if len(usable) == 0:
        return None

    ids = []
    token_rows = []
    mask_rows = []
    for item in usable:
        ids.append(item['doc_id'])
        token_rows.append(item['input_ids'])
        mask_rows.append(item['attention_mask'])

    collated = {'doc_id': ids}
    collated['input_ids'] = torch.stack(token_rows)
    collated['attention_mask'] = torch.stack(mask_rows)
    return collated

## 4. Main Execution

In [None]:
def _count_lines(path):
    """Return the number of lines in the UTF-8 text file at ``path``."""
    with open(path, 'r', encoding='utf-8') as f:
        return sum(1 for _ in f)


def _shard_bounds(total_docs, shard_id, total_shards):
    """Return the half-open ``(start, end)`` line range owned by one shard."""
    if total_shards < 1:
        raise ValueError(f"TOTAL_SHARDS must be >= 1, got {total_shards}")
    if shard_id < 0:
        raise ValueError(f"SHARD_ID must be >= 0, got {shard_id}")
    docs_per_shard = math.ceil(total_docs / total_shards)
    start = shard_id * docs_per_shard
    end = min(start + docs_per_shard, total_docs)
    return start, end


def _strip_dataparallel_prefix(state_dict):
    """Remove a leading 'module.' (added by nn.DataParallel) from every key."""
    prefix = 'module.'
    # Strip only a true prefix: str.replace would also delete 'module.'
    # appearing in the middle of a key (e.g. inside 'submodule.weight').
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def generate_embeddings():
    """Embed this shard's slice of the corpus and write it as JSONL.

    Steps: count the lines of ``Config.CORPUS_FILE``, derive the line range
    for ``Config.SHARD_ID``, load the ``BiEncoder`` (with the checkpoint at
    ``Config.MODEL_PATH`` when that file exists, otherwise the base weights),
    then encode every document with the document tower and write one
    ``{"pmid": ..., "embedding": [...]}`` record per line to
    ``<OUTPUT_DIR>/corpus_embeddings_shard_<SHARD_ID>.jsonl``.

    Returns early, writing nothing, when the shard's start index lies at or
    beyond the end of the corpus.

    Raises:
        ValueError: If ``Config.TOTAL_SHARDS`` is < 1 or ``Config.SHARD_ID`` is < 0.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    num_gpus = torch.cuda.device_count()
    print(f"Available GPUs: {num_gpus}")

    # 1. Calculate Shard Range
    print("Counting total documents in corpus...")
    total_docs = _count_lines(Config.CORPUS_FILE)
    print(f"Total documents: {total_docs}")

    start_idx, end_idx = _shard_bounds(total_docs, Config.SHARD_ID, Config.TOTAL_SHARDS)
    print(f"Shard {Config.SHARD_ID}/{Config.TOTAL_SHARDS}: Processing documents {start_idx} to {end_idx}")

    if start_idx >= total_docs:
        print("Shard ID out of range. Nothing to do.")
        return

    # 2. Load Model
    print(f"Loading model from {Config.MODEL_PATH}...")
    tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME)
    model = BiEncoder(Config.MODEL_NAME, Config.EMBEDDING_DIM, Config.POOLING)

    if os.path.exists(Config.MODEL_PATH):
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints that come from a trusted source.
        checkpoint = torch.load(Config.MODEL_PATH, map_location='cpu')
        state_dict = _strip_dataparallel_prefix(checkpoint['model_state_dict'])
        model.load_state_dict(state_dict)
        print("Model loaded successfully.")
    else:
        print("⚠️ Checkpoint not found! Using base model.")

    model = model.to(device)

    # Enable DataParallel if multiple GPUs are available
    if num_gpus > 1:
        print(f"Enabling DataParallel on {num_gpus} GPUs")
        model = nn.DataParallel(model)

    model.eval()

    # 3. Prepare Data
    # Pass the configured sequence length explicitly; otherwise the dataset
    # silently falls back to its own default and Config.MAX_LENGTH is ignored.
    dataset = CorpusDataset(
        Config.CORPUS_FILE,
        tokenizer,
        max_length=Config.MAX_LENGTH,
        start_idx=start_idx,
        end_idx=end_idx,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=Config.BATCH_SIZE,
        shuffle=False,  # keep output order equal to corpus order
        num_workers=4,
        collate_fn=collate_fn,
        # Pinned host memory only helps host-to-GPU copies.
        pin_memory=(device.type == "cuda"),
    )

    os.makedirs(Config.OUTPUT_DIR, exist_ok=True)
    output_file = os.path.join(Config.OUTPUT_DIR, f"corpus_embeddings_shard_{Config.SHARD_ID}.jsonl")

    print(f"Generating embeddings to {output_file}...")

    with open(output_file, 'w', encoding='utf-8') as f, torch.no_grad():
        for batch in tqdm(dataloader):
            # collate_fn returns None when every item of the batch failed.
            if batch is None:
                continue

            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            doc_ids = batch['doc_id']

            # Always go through forward(mode='doc'): it works for both the
            # bare BiEncoder and the DataParallel wrapper (which only
            # parallelizes forward), and equals encode_doc on a bare model.
            embeddings = model(
                doc_input_ids=input_ids,
                doc_attention_mask=attention_mask,
                mode='doc',
            )
            embeddings = embeddings.cpu().numpy()

            # Write to file
            for doc_id, emb in zip(doc_ids, embeddings):
                record = {
                    "pmid": doc_id,
                    "embedding": emb.tolist()
                }
                f.write(json.dumps(record) + "\n")

    print(f"✅ Done! Shard {Config.SHARD_ID} saved to {output_file}")

# Run the full pipeline for the shard selected in Config.
# This executes as soon as the cell/module is run.
generate_embeddings()