In [None]:
import torch

# Report which accelerator (if any) this notebook will run on.
if not torch.cuda.is_available():
    print("CUDA is not available. Running on CPU.")
else:
    # Query device 0: its name and its properties (total_memory is in bytes).
    gpu_name = torch.cuda.get_device_name(0)
    gpu_properties = torch.cuda.get_device_properties(0)
    # Dividing by 1024 three times converts bytes -> GiB.
    total_memory_gb = gpu_properties.total_memory / 1024 / 1024 / 1024

    print("GPU detected:", gpu_name)
    print(f"Total GPU Memory: {total_memory_gb:.2f} GB")


GPU detected: Tesla P100-PCIE-16GB
Total GPU Memory: 15.89 GB


In [None]:
import torch
from transformers import AutoModel, AutoTokenizer, DynamicCache, Cache
from dataclasses import dataclass
from typing import List, Dict, Any, Callable
# import flash_attn

def _sample_top_p(logits, top_p=0.9):
    # First normalize the logits to prevent overflow/underflow
    logits = logits - torch.max(logits)
    
    # Convert to probabilities with softmax
    probs = torch.nn.functional.softmax(logits, dim=-1)
    
    # Apply a small epsilon to avoid numerical issues
    eps = 1e-10
    probs = torch.clamp(probs, min=eps)
    probs = probs / probs.sum()  # Re-normalize
    
    # Sort the probabilities
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    
    # Compute cumulative probabilities
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    
    # Remove tokens with cumulative probability above the threshold
    sorted_indices_to_remove = cumulative_probs > top_p
    # Keep the first token above threshold
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    
    # Scatter back to original indices
    indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
    probs = probs.masked_fill(indices_to_remove, 0.0)
    
    # Re-normalize after masking
    probs = probs / (probs.sum() + eps)
    
    # Check for invalid values before sampling
    if torch.isnan(probs).any() or torch.isinf(probs).any() or (probs < 0).any():
        # Fix invalid values
        probs = torch.nan_to_num(probs, nan=eps, posinf=1.0, neginf=eps)
        probs = torch.clamp(probs, min=eps)
        probs = probs / probs.sum()  # Re-normalize
    
    # Sample from the filtered distribution
    next_token = torch.multinomial(probs, num_samples=1)
    
    return next_token

class WIMInference:
    """Prefill/decode helper that drives a causal LM through explicit KV caches.

    Two ``DynamicCache`` objects live on the instance: ``wim_kv_cache`` and
    ``classifier_kv_cache``. Every method receives the cache it should operate
    on, so the caller decides which one a given prefill or generation touches.
    All forward passes run under ``torch.no_grad()``.
    """

    def __init__(
        self, model, tokenizer
    ) -> None:
        """Store the model and tokenizer and create two empty KV caches.

        Args:
            model: Causal LM called with ``past_key_values`` / ``cache_position``
                and expected to return an object exposing ``logits``.
            tokenizer: Tokenizer matching ``model``.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.wim_kv_cache = DynamicCache()
        self.classifier_kv_cache = DynamicCache()

    def _prefill_tokens(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        cache_positions: torch.Tensor,
        kv_cache: Cache,
    ):
        """Run one no-grad forward pass over ``input_ids`` using ``kv_cache``.

        Args:
            input_ids: Token ids to process.
            attention_mask: Mask covering the cached tokens plus ``input_ids``.
            cache_positions: Absolute positions of ``input_ids`` in the cache.
            kv_cache: Cache passed as ``past_key_values`` (extended by the model).

        Returns:
            The raw model outputs (``outputs.logits`` seeds the next generation step).
        """
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                cache_position=cache_positions,
                use_cache=True,
                past_key_values=kv_cache,
            )
        return outputs

    def shrink_kv_cache_from_end(self, new_size: int, kv_cache: Cache):
        """Truncate ``kv_cache`` in place so only its first ``new_size`` positions remain.

        ``new_size=0`` empties the cache. The cache's own ``crop`` method is used
        when it exists, because per-layer storage may differ between
        ``transformers`` releases. Otherwise this falls back to slicing the
        ``key_cache`` / ``value_cache`` lists directly.
        NOTE(review): the fallback relies on private DynamicCache attributes.
        """
        if hasattr(kv_cache, "crop"):
            kv_cache.crop(new_size)
            return

        def resize_tensor_list(token_list):
            # Keep positions [0, new_size) along the sequence axis of each layer.
            for layer_idx in range(len(token_list)):
                token_list[layer_idx] = token_list[layer_idx][:, :, :new_size, :]

        resize_tensor_list(kv_cache.key_cache)
        resize_tensor_list(kv_cache.value_cache)
        kv_cache._seen_tokens = new_size

    def generate_text_with_kv_cache(
        self,
        max_new_tokens: int,
        previous_logits: torch.Tensor,
        do_sample: bool,
        top_p: float,
        temperature: float,
        early_stopping: bool,
        kv_cache: Cache,
    ) -> str:
        """Decode up to ``max_new_tokens`` tokens, one at a time, on top of ``kv_cache``.

        Args:
            max_new_tokens: Maximum number of tokens to generate. Values <= 0
                return an empty string.
            previous_logits: Logits from the preceding prefill. Only the last
                position (``[:, -1, :]``) is used to pick the first token.
            do_sample: Sample with temperature + top-p if True; otherwise decode greedily.
            top_p: Nucleus threshold passed to ``_sample_top_p``.
            temperature: Divisor applied to the logits before sampling.
            early_stopping: Stop as soon as ``tokenizer.eos_token_id`` is produced.
            kv_cache: Cache to extend. Each generated token is fed back through
                the model (and so appended), except an EOS token that stops
                generation early. Batch size 1 is assumed (asserted).

        Returns:
            The decoded text with special tokens removed.
        """
        generated_tokens = []

        # This is needed to create the cache_position tensor
        next_token_pos = kv_cache.get_seq_length()

        # Use the logits from the prefilling to generate the first token
        logits = previous_logits

        for _ in range(max_new_tokens):
            # Select the last token from the logits
            next_token_logits = logits[:, -1, :]
            if do_sample:
                next_token_logits = next_token_logits / temperature
                # Probabilities are computed here only for the diagnostics below;
                # sampling itself works from the (temperature-scaled) logits.
                next_token_probs = torch.nn.functional.softmax(
                    next_token_logits, dim=-1
                )

                if torch.isnan(next_token_probs).any() or torch.isinf(next_token_probs).any():
                    print("Invalid probabilities detected!")
                    print(f"Min prob: {next_token_probs.min()}, Max prob: {next_token_probs.max()}")
                    print(f"Contains NaN: {torch.isnan(next_token_probs).any()}")
                    print(f"Contains Inf: {torch.isinf(next_token_probs).any()}")

                next_token = _sample_top_p(next_token_logits, top_p)
            else:
                # Select the token with the highest probability
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            assert next_token.size() == (1, 1)
            # Remove the batch dimension
            next_token = next_token.squeeze(0)
            generated_tokens.append(next_token)
            # Stop if we reached the EOS token
            if next_token.item() == self.tokenizer.eos_token_id and early_stopping:
                break
            # Feed the generated token back so the next step can attend to it
            generation_input_ids = next_token.unsqueeze(-1)
            kv_cache_seq_len = kv_cache.get_seq_length()
            generation_attention_mask = torch.ones(
                (1, kv_cache_seq_len + 1), device=next_token.device, dtype=torch.long
            )
            generation_cache_position = torch.tensor(
                [next_token_pos], device=next_token.device
            )

            with torch.no_grad():
                outputs = self.model(
                    input_ids=generation_input_ids,
                    attention_mask=generation_attention_mask,
                    cache_position=generation_cache_position,
                    use_cache=True,
                    past_key_values=kv_cache,
                )
            logits = outputs.logits
            next_token_pos += 1

        if not generated_tokens:
            # max_new_tokens <= 0: torch.cat would raise on an empty list.
            return ""

        generated_tokens = torch.cat(generated_tokens, dim=-1)
        decoded = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
        return decoded

    def prefill_text_with_kv_cache(self, text: str, kv_cache: Cache):
        """Tokenize ``text`` and run it through the model on top of ``kv_cache``.

        NOTE(review): the tokenizer is called with its default special-token
        handling, so some tokenizers may insert a BOS token for every chunk.
        Confirm this is intended.

        Args:
            text: Text to append.
            kv_cache: Cache to extend.

        Returns:
            Tuple ``(cache_length_after_prefill, num_new_tokens, outputs)``.
            ``outputs`` are the raw model outputs.
        """
        inputs = self.tokenizer(text, return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.model.device)
        seq_len = input_ids.size(1)
        attention_mask = inputs["attention_mask"].to(self.model.device)

        past_len = kv_cache.get_seq_length()

        # Tokens already in the cache must stay visible to the new ones.
        if past_len > 0:
            attention_mask = torch.cat(
                [
                    torch.ones(
                        attention_mask.shape[0],
                        past_len,
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    ),
                    attention_mask,
                ],
                dim=1,
            )

        # Absolute positions of the new tokens, continuing after the cached ones.
        cache_positions = torch.arange(past_len, past_len + seq_len).to(self.model.device)
        outputs = self._prefill_tokens(input_ids, attention_mask, cache_positions, kv_cache)
        return kv_cache.get_seq_length(), seq_len, outputs
    


In [None]:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
import numpy as np
from dataclasses import dataclass
from typing import List, Dict, Any, Optional

@dataclass
class RLConfig:
    """Configuration for the RL-based margin generation.

    Some fields are read by ``RLMarginGenerator`` (noted below). The others
    are not referenced by the training code visible in this notebook section.
    """
    # Learning rate of the Adam optimizer over the policy model's parameters.
    learning_rate: float = 5e-6
    # Multiplier on the policy-vs-reference KL term added to the training loss.
    kl_coef: float = 0.1
    # NOTE(review): not referenced by the visible training code (no discounted return is computed).
    discount_factor: float = 0.99
    # Number of update passes over each episode's sampled batch.
    ppo_epochs: int = 4
    # Number of segments sampled per episode (capped at the number of available segments).
    ppo_mini_batch_size: int = 4
    # Max gradient norm passed to torch.nn.utils.clip_grad_norm_.
    max_grad_norm: float = 0.5
    # NOTE(review): not referenced by the visible training code (no PPO ratio clipping is implemented).
    clip_param: float = 0.2
    # NOTE(review): not referenced by the visible training code (no value loss is computed).
    value_loss_coef: float = 0.5
    # NOTE(review): not referenced by the visible training code (no entropy bonus is computed).
    entropy_coef: float = 0.01

class MarginRewardModel:
    """Score generated margins by prompting a causal LM, plus simple heuristics."""

    def __init__(self, model_id: str, device: str = "cuda"):
        """Initialize the reward model.

        Args:
            model_id: The ID of the model to use for reward computation.
            device: The device to use for computation (also used as ``device_map``).
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map=device,
            torch_dtype=torch.bfloat16,
        ).eval()
        self.device = device

    def compute_reward(self, margins: List[str], query: str, classification_results: List[bool]) -> torch.Tensor:
        """Compute rewards for a batch of margins.

        Each reward is ``0.6 * llm_score + 0.2 * length_term + 0.2 * classifier_term``:

        1. ``llm_score``: the first number in the model's reply to a 0-10 rating
           prompt, divided by 10 and clamped to [0, 1] (0.5 if no number is produced).
        2. ``length_term``: ``min(1, 100 / max(10, n_words))``, so margins of at
           most 100 words get 1.0.
        3. ``classifier_term``: 1.0 if the classifier judged the margin relevant,
           otherwise 0.2.

        Args:
            margins: List of generated margins.
            query: The original query.
            classification_results: Whether each margin was classified as relevant.

        Returns:
            1-D tensor of rewards (one per margin) on ``self.device``.
        """
        import re

        rewards = []

        for margin, is_relevant in zip(margins, classification_results):
            with torch.no_grad():
                # Construct prompt to evaluate margin quality
                prompt = f"Query: {query}\nMargin: {margin}\n\nRate the quality of this margin from 0 to 10 based on relevance and information density. A good margin should contain key information relevant to the query."

                inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=5,
                    return_dict_in_generate=True,
                )

                # Decode only the newly generated tokens. Taking the last 5
                # positions would pull prompt text (e.g. "0 to 10") into the
                # parse whenever generation stops early.
                prompt_length = inputs["input_ids"].shape[1]
                generated_text = self.tokenizer.decode(
                    outputs.sequences[0][prompt_length:], skip_special_tokens=True
                )

                # Use the first number in the reply. Concatenating every digit
                # would turn answers such as "8/10" into 810.
                match = re.search(r"\d+(?:\.\d+)?", generated_text)
                if match is None:
                    # Default score when the model produced no number
                    score = 0.5
                else:
                    score = min(max(float(match.group()) / 10.0, 0.0), 1.0)

            length_penalty = min(1.0, 100 / max(10, len(margin.split())))  # Prefer concise margins
            classifier_agreement = 1.0 if is_relevant else 0.2  # Reward if classifier agrees it's relevant

            reward = (0.6 * score) + (0.2 * length_penalty) + (0.2 * classifier_agreement)
            rewards.append(reward)

        return torch.tensor(rewards, device=self.device)

class RLMarginGenerator:
    """Generate margins with a policy model fine-tuned from reward feedback.

    Holds three models:

    - a trainable policy, used for generation through a ``WIMInference`` wrapper;
    - a frozen reference copy, used for the KL penalty;
    - a ``MarginRewardModel``, which scores generated margins.
    """

    def __init__(
        self,
        model_id: str,
        reward_model_id: str = None,
        rl_config: RLConfig = None,
        device: str = "cuda"
    ):
        """Initialize the RL-based margin generator.

        Args:
            model_id: The ID of the model to use for generation.
            reward_model_id: The ID of the model to use for reward computation.
                If None, uses the same model as the generator.
            rl_config: Configuration for the RL algorithm (defaults to ``RLConfig()``).
            device: The device to use for computation.
        """
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)

        # Policy model: generates margins and is the only model that is updated.
        print('policy model', model_id)
        self.policy_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map=device,
            torch_dtype=torch.bfloat16,
        )

        print('ref model', model_id)
        # Reference model: frozen anchor for the KL penalty.
        self.ref_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map=device,
            torch_dtype=torch.bfloat16,
        )
        for param in self.ref_model.parameters():
            param.requires_grad = False

        for param in self.policy_model.parameters():
            param.requires_grad = True

        if reward_model_id is None:
            reward_model_id = model_id
        self.reward_model = MarginRewardModel(reward_model_id, device)

        self.rl_config = rl_config if rl_config is not None else RLConfig()

        self.optimizer = torch.optim.Adam(
            self.policy_model.parameters(),
            lr=self.rl_config.learning_rate
        )

        # Non-finite gradients are detected per step in train_rl_margin_generator.
        self.wim = WIMInference(self.policy_model, self.tokenizer)

    def _compute_logprobs(self, model, input_ids, attention_mask, labels):
        """Sum of per-token log-probabilities of ``labels`` under ``model``.

        Label positions equal to -100 (the masking convention used for training
        examples) and, when the tokenizer defines one, pad tokens contribute 0.
        Gradients are tracked only if ``model`` is in training mode.

        Returns:
            Tensor of shape (batch,).
        """
        with torch.set_grad_enabled(model.training):
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            log_probs = torch.nn.functional.log_softmax(outputs.logits, dim=-1)

            # Logits at position t predict token t + 1.
            targets = labels[:, 1:]
            mask = targets != -100
            pad_token_id = self.tokenizer.pad_token_id
            if pad_token_id is not None:
                mask = mask & (targets != pad_token_id)

            # gather cannot index with -100: point masked slots at token 0 and zero them below.
            safe_targets = targets.masked_fill(~mask, 0)
            token_log_probs = torch.gather(
                log_probs[:, :-1, :], 2, safe_targets.unsqueeze(-1)
            ).squeeze(-1)
            token_log_probs = token_log_probs * mask.to(token_log_probs.dtype)

            return token_log_probs.sum(dim=1)

    def _compute_kl_divergence(self, input_ids, attention_mask, labels, policy_logits: Optional[torch.Tensor] = None):
        """KL(policy || reference) per sequence, averaged over label positions != -100.

        Args:
            input_ids: Input token ids.
            attention_mask: Attention mask for ``input_ids``.
            labels: Labels aligned with ``input_ids``; positions equal to -100 are ignored.
            policy_logits: Optional policy logits for these inputs. When given,
                they are reused instead of running the policy again, and
                gradients flow through them if they carry a graph.

        Returns:
            Tensor of shape (batch,) with the mean token-level KL of each
            sequence. This is a plain mean over all positions if no position
            is valid.
        """
        # Small constant to avoid division by zero or log(0)
        epsilon = 1e-8

        if policy_logits is None:
            with torch.set_grad_enabled(self.policy_model.training):
                policy_logits = self.policy_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                ).logits

        with torch.no_grad():
            ref_logits = self.ref_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).logits

        seq_length = policy_logits.size(1)
        vocab_size = policy_logits.size(2)

        if labels.dim() == 2:
            labels = labels.unsqueeze(-1)

        # Logits at position t predict token t + 1, so compare against labels[:, 1:].
        valid_positions = (labels[:, 1:, 0] != -100)

        policy_logits = policy_logits[:, :-1].reshape(-1, vocab_size)
        ref_logits = ref_logits[:, :-1].reshape(-1, vocab_size)

        policy_probs = torch.nn.functional.softmax(policy_logits, dim=-1)
        ref_probs = torch.nn.functional.softmax(ref_logits, dim=-1)

        # Clip probabilities so the logs below stay finite.
        policy_probs = torch.clamp(policy_probs, min=epsilon, max=1.0 - epsilon)
        ref_probs = torch.clamp(ref_probs, min=epsilon, max=1.0 - epsilon)

        # KL divergence: sum_v p * log(p / q)
        kl_div = policy_probs * (torch.log(policy_probs) - torch.log(ref_probs))
        kl_div = kl_div.sum(dim=-1).reshape(-1, seq_length - 1)

        if valid_positions.any():
            kl_div = kl_div * valid_positions.float()
            kl_div = kl_div.sum(dim=1) / (valid_positions.sum(dim=1).float() + epsilon)
        else:
            kl_div = kl_div.mean(dim=1)

        return kl_div

    def generate_rl_margin(
        self,
        segment: str,
        query: str,
        extractive_summary_prompt: str,
        classification_prompt: str,
        max_new_tokens: int = 100,
        do_sample: bool = True,
        top_p: float = 0.9,
        temperature: float = 1.0,
        early_stopping: bool = True,
    ):
        """Generate a margin for ``segment`` with the policy model and classify it.

        ``segment`` is appended to ``self.wim.wim_kv_cache`` and left there.
        The extractive-summary prompt and the generated margin are cropped back
        out. The classifier cache is emptied before returning.

        Returns:
            Tuple ``(margin, is_relevant)``.
        """
        prefilled_tokens_before_extractive_summary, _, _ = self.wim.prefill_text_with_kv_cache(
            segment, self.wim.wim_kv_cache
        )

        _, _, extractive_summary_outputs = self.wim.prefill_text_with_kv_cache(
            extractive_summary_prompt.format(query=query), self.wim.wim_kv_cache
        )

        margin = self.wim.generate_text_with_kv_cache(
            max_new_tokens=max_new_tokens,
            previous_logits=extractive_summary_outputs["logits"],
            do_sample=do_sample,
            top_p=top_p,
            temperature=temperature,
            early_stopping=early_stopping,
            kv_cache=self.wim.wim_kv_cache,
        )

        # Drop the summary prompt and margin, keeping the segment.
        self.wim.shrink_kv_cache_from_end(
            new_size=prefilled_tokens_before_extractive_summary,
            kv_cache=self.wim.wim_kv_cache,
        )

        classification_input = classification_prompt.format(query=query, answer=margin)
        _, _, classification_prompt_logits = self.wim.prefill_text_with_kv_cache(
            classification_input, self.wim.classifier_kv_cache
        )

        classification_output = self.wim.generate_text_with_kv_cache(
            max_new_tokens=10,
            previous_logits=classification_prompt_logits["logits"],
            do_sample=False,
            top_p=0.9,
            temperature=1.0,
            early_stopping=early_stopping,
            kv_cache=self.wim.classifier_kv_cache,
        )

        is_relevant = self._parse_classifier_output(classification_output)

        self.wim.shrink_kv_cache_from_end(
            new_size=0, kv_cache=self.wim.classifier_kv_cache
        )

        return margin, is_relevant

    def _parse_classifier_output(self, output: str) -> bool:
        """Return True if the classifier output (before any '#') ends with "YES"."""
        output = output.replace("```", "").strip()
        output = output.split("#")[0]
        return output.endswith("YES")

    def _build_training_example(self, segment: str, margin: str, formatted_summary_prompt: str):
        """Tokenize ``segment + prompt + margin`` with labels only on the margin tokens.

        Returns:
            A dict of ``input_ids`` / ``attention_mask`` / ``labels`` on
            ``self.device``, or ``None`` when no margin token survives masking
            (e.g. an empty margin), which would otherwise give a NaN loss.
        """
        prompt_context = segment + formatted_summary_prompt
        # A single sequence needs no padding, and padding=True would raise for
        # tokenizers that define no pad token.
        input_tokens = self.tokenizer(prompt_context + margin, return_tensors="pt").to(self.device)

        labels = input_tokens.input_ids.clone()
        # Only margin tokens are supervised. The cut uses a separate tokenization
        # of the prompt, so merges across the boundary may shift it by a token.
        context_tokens = len(self.tokenizer(prompt_context, return_tensors="pt").input_ids[0])
        labels[:, :context_tokens] = -100

        # The LM loss shifts labels by one, so position 0 is never a target.
        if not (labels[:, 1:] != -100).any():
            return None

        return {
            "input_ids": input_tokens.input_ids,
            "attention_mask": input_tokens.attention_mask,
            "labels": labels,
        }

    def train_rl_margin_generator(
        self,
        segments: List[str],
        query: str,
        extractive_summary_prompt: str,
        classification_prompt: str,
        num_episodes: int = 10,
        max_new_tokens: int = 100,
    ):
        """Fine-tune the policy on its own margins, weighted by reward.

        Each episode:

        1. Samples up to ``ppo_mini_batch_size`` segments.
        2. Generates and classifies a margin for each.
        3. Scores the margins with the reward model.
        4. Runs ``ppo_epochs`` passes of a REINFORCE-style update::

               loss = (reward - mean_batch_reward) * NLL(margin) + kl_coef * KL(policy || ref)

        Margins above the batch average become more likely and worse ones less
        likely, while the KL term keeps the policy near the reference model.

        Notes:

        - Despite the config names, this is not clipped PPO: ``clip_param``,
          ``value_loss_coef``, ``entropy_coef`` and ``discount_factor`` are unused.
        - With a single sampled segment the advantage is zero, so only the KL
          term produces gradients.
        - Steps with a non-finite gradient norm are skipped.
        - The policy's train/eval mode is restored on exit.

        Raises:
            ValueError: If ``segments`` is empty.
        """
        if not segments:
            raise ValueError("segments must contain at least one segment")

        formatted_summary_prompt = extractive_summary_prompt.format(query=query)
        was_training = self.policy_model.training
        self.policy_model.train()

        try:
            for episode in range(num_episodes):
                print(f"Episode {episode + 1}/{num_episodes}")

                batch_size = min(self.rl_config.ppo_mini_batch_size, len(segments))
                segment_indices = np.random.choice(len(segments), batch_size, replace=False)
                batch_segments = [segments[i] for i in segment_indices]

                margins = []
                is_relevant_list = []
                for segment in tqdm(batch_segments, desc="Generating margins"):
                    # Each margin is generated from empty caches.
                    self.wim.shrink_kv_cache_from_end(0, self.wim.wim_kv_cache)
                    self.wim.shrink_kv_cache_from_end(0, self.wim.classifier_kv_cache)

                    margin, is_relevant = self.generate_rl_margin(
                        segment=segment,
                        query=query,
                        extractive_summary_prompt=extractive_summary_prompt,
                        classification_prompt=classification_prompt,
                        max_new_tokens=max_new_tokens,
                    )
                    margins.append(margin)
                    is_relevant_list.append(is_relevant)

                with torch.no_grad():
                    rewards = self.reward_model.compute_reward(margins, query, is_relevant_list)
                # Batch-mean baseline: centres the learning signal and reduces variance.
                advantages = rewards - rewards.mean()

                examples = [
                    self._build_training_example(segment, margin, formatted_summary_prompt)
                    for segment, margin in zip(batch_segments, margins)
                ]

                kl_values = []
                for _ in range(self.rl_config.ppo_epochs):
                    for example, advantage in zip(examples, advantages):
                        if example is None:
                            continue

                        self.optimizer.zero_grad()

                        outputs = self.policy_model(
                            input_ids=example["input_ids"],
                            attention_mask=example["attention_mask"],
                            labels=example["labels"],
                        )
                        # Causal-LM loss: mean cross-entropy over labels != -100,
                        # i.e. the mean negative log-likelihood of the margin.
                        margin_nll = outputs.loss

                        # Reuse the policy logits; the KL stays in the graph so it regularises.
                        kl_div = self._compute_kl_divergence(
                            example["input_ids"],
                            example["attention_mask"],
                            example["labels"],
                            policy_logits=outputs.logits,
                        ).mean()

                        # Minimising advantage * NLL maximises advantage * log-prob of the margin.
                        policy_loss = advantage.detach() * margin_nll + self.rl_config.kl_coef * kl_div
                        policy_loss.backward()

                        grad_norm = torch.nn.utils.clip_grad_norm_(
                            self.policy_model.parameters(),
                            self.rl_config.max_grad_norm,
                        )
                        if not torch.isfinite(grad_norm):
                            # Stepping would write NaN/Inf into the weights.
                            print("Warning: non-finite gradient norm, skipping optimizer step")
                            self.optimizer.zero_grad()
                            continue

                        self.optimizer.step()
                        kl_values.append(kl_div.item())

                avg_reward = rewards.mean().item()
                relevance_rate = sum(is_relevant_list) / len(is_relevant_list)
                avg_kl = sum(kl_values) / len(kl_values) if kl_values else float("nan")

                print(f"Average reward: {avg_reward:.4f}")
                print(f"Average KL divergence: {avg_kl:.4f}")
                print(f"Relevance rate: {relevance_rate:.4f}")
        finally:
            self.policy_model.train(was_training)

    def save_model(self, output_dir: str):
        """Save the policy model and tokenizer to ``output_dir``."""
        self.policy_model.save_pretrained(output_dir)
        self.tokenizer.save_pretrained(output_dir)

In [None]:
# from wim import WIMInference
# from RL_margin_generation import RLMarginGenerator

import tiktoken
from nltk import sent_tokenize
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import os

class WIMRLInference(WIMInference):
    """Extended WIM inference class that uses RL-trained model for margin generation."""

    def __init__(self, model, tokenizer, rl_margin_generator=None):
        """Initialize the WIM inference with an optional RL margin generator.

        Args:
            model: The base model for generation.
            tokenizer: The tokenizer.
            rl_margin_generator: Optional RL-based margin generator.
        """
        super().__init__(model, tokenizer)
        self.rl_margin_generator = rl_margin_generator

    def process_with_rl_margins(
        self,
        context: str,
        query: str,
        system_message: str,
        extractive_summary_prompt: str,
        classification_prompt: str,
        final_answer_prompt: str,
        min_tokens_segment: int = 4096,
        max_new_tokens_extractive_summary: int = 100,
        max_new_tokens_final_answer: int = 100,
        max_new_tokens_classification: int = 10,
        do_sample: bool = True,
        top_p: float = 0.9,
        temperature: float = 1.0,
        early_stopping: bool = True,
        use_rl_generator: bool = True,
        print_step_summary: bool = False,
    ):
        """Process the context using the WIM approach with RL-enhanced margin generation.

        Args:
            context: The full document context.
            query: The user query.
            system_message: The system message prompt.
            extractive_summary_prompt: Template for extractive summary prompt.
            classification_prompt: Template for classification prompt.
            final_answer_prompt: Template for final answer prompt.
            min_tokens_segment: Token budget per segment (tiktoken count).
            max_new_tokens_extractive_summary: Maximum number of tokens to generate for margin.
            max_new_tokens_final_answer: Maximum number of tokens to generate for final answer.
            max_new_tokens_classification: Maximum number of tokens to generate for
                classification (standard path only; the RL generator uses its own limit).
            do_sample: Whether to use sampling for generation.
            top_p: Top-p sampling parameter.
            temperature: Temperature for generation.
            early_stopping: Whether to use early stopping.
            use_rl_generator: Whether to use the RL-based margin generator.
            print_step_summary: Whether to print a summary for each step.

        Returns:
            final_answer: The generated answer.
            positive_margins: List of relevant margins used.
        """
        self.shrink_kv_cache_from_end(0, self.wim_kv_cache)
        self.shrink_kv_cache_from_end(0, self.classifier_kv_cache)

        if use_rl_generator and self.rl_margin_generator is not None:
            # The RL generator decodes through its own WIMInference caches and
            # leaves each segment in them after generate_rl_margin returns.
            # Clear them so context from a previous document (or from training)
            # does not leak into this one.
            rl_wim = getattr(self.rl_margin_generator, "wim", None)
            if rl_wim is not None:
                rl_wim.shrink_kv_cache_from_end(0, rl_wim.wim_kv_cache)
                rl_wim.shrink_kv_cache_from_end(0, rl_wim.classifier_kv_cache)

        segments = self._chunk_text_to_segments(context, min_tokens_segment)

        _, _, _ = self.prefill_text_with_kv_cache(system_message, self.wim_kv_cache)

        positive_margins = []

        with torch.no_grad():
            for segment_index, segment in enumerate(segments):
                if use_rl_generator and self.rl_margin_generator is not None:
                    print("# Use RL-based margin generator")
                    margin, is_relevant = self.rl_margin_generator.generate_rl_margin(
                        segment=segment,
                        query=query,
                        extractive_summary_prompt=extractive_summary_prompt,
                        classification_prompt=classification_prompt,
                        max_new_tokens=max_new_tokens_extractive_summary,
                        do_sample=do_sample,
                        top_p=top_p,
                        temperature=temperature,
                        early_stopping=early_stopping,
                    )
                else:
                    print('# Use standard WIM approach')
                    prefilled_tokens_before_extractive_summary, _, _ = self.prefill_text_with_kv_cache(
                        segment, self.wim_kv_cache
                    )

                    formatted_extractive_summary = extractive_summary_prompt.format(query=query)
                    _, _, extractive_summary_outputs = self.prefill_text_with_kv_cache(
                        formatted_extractive_summary, self.wim_kv_cache
                    )

                    margin = self.generate_text_with_kv_cache(
                        max_new_tokens=max_new_tokens_extractive_summary,
                        previous_logits=extractive_summary_outputs["logits"],
                        do_sample=do_sample,
                        top_p=top_p,
                        temperature=temperature,
                        early_stopping=early_stopping,
                        kv_cache=self.wim_kv_cache,
                    )

                    # Keep the segment in context but drop the summary prompt and margin.
                    self.shrink_kv_cache_from_end(
                        new_size=prefilled_tokens_before_extractive_summary,
                        kv_cache=self.wim_kv_cache,
                    )

                    classification_input = classification_prompt.format(query=query, answer=margin)
                    _, _, classification_prompt_logits = self.prefill_text_with_kv_cache(
                        classification_input, self.classifier_kv_cache
                    )

                    classification_output = self.generate_text_with_kv_cache(
                        max_new_tokens=max_new_tokens_classification,
                        previous_logits=classification_prompt_logits["logits"],
                        do_sample=False,
                        top_p=top_p,
                        temperature=temperature,
                        early_stopping=early_stopping,
                        kv_cache=self.classifier_kv_cache,
                    )

                    is_relevant = self._parse_classifier_output(classification_output)

                    self.shrink_kv_cache_from_end(
                        new_size=0, kv_cache=self.classifier_kv_cache
                    )

                if is_relevant:
                    positive_margins.append(margin)

                if print_step_summary:
                    print({
                        "step": segment_index,
                        "prefilled_tokens_so_far": self.wim_kv_cache.get_seq_length(),
                        "margin": margin.strip(),
                        "classification_result": is_relevant,
                    })

            # Prefill the concatenated margins and the final-answer prompt.
            concatenated_margins = "".join(positive_margins)
            formatted_final_answer = final_answer_prompt.format(
                margins=concatenated_margins, query=query
            )

            _, _, final_answer_prefill_outputs = self.prefill_text_with_kv_cache(
                formatted_final_answer, self.wim_kv_cache
            )

            final_answer = self.generate_text_with_kv_cache(
                max_new_tokens=max_new_tokens_final_answer,
                previous_logits=final_answer_prefill_outputs["logits"],
                do_sample=do_sample,
                top_p=top_p,
                temperature=temperature,
                early_stopping=early_stopping,
                kv_cache=self.wim_kv_cache,
            )

            return final_answer, positive_margins

    def _chunk_text_to_segments(self, text, min_tokens_segment=4096):
        """Split ``text`` into sentence-aligned segments of about ``min_tokens_segment`` tokens.

        Sentences (from ``nltk.sent_tokenize``) are accumulated until adding the
        next one would exceed the budget (counted with tiktoken's gpt-4-turbo
        encoding). A single sentence longer than the budget becomes its own
        segment. No empty segment is ever emitted.
        """
        tokenizer = tiktoken.encoding_for_model("gpt-4-turbo")
        segments = []
        current_segment = ""
        sentences = sent_tokenize(text)
        curr_tokens = 0

        for line in sentences:
            tokens = len(tokenizer.encode(line))
            # Only flush a non-empty segment; otherwise an over-long first
            # sentence would produce an empty "" segment.
            if current_segment and curr_tokens + tokens > min_tokens_segment:
                segments.append(current_segment)
                current_segment = ""
                curr_tokens = 0

            current_segment += line + " "
            curr_tokens += tokens

        if current_segment:
            segments.append(current_segment)

        return segments

    def _parse_classifier_output(self, output: str) -> bool:
        """Return True if the classifier output (before any '#') ends with "YES"."""
        output = output.replace("```", "").strip()
        output = output.split("#")[0]
        return output.endswith("YES")


def _wim_resolve_torch_dtype(dtype: str) -> torch.dtype:
    """Translate a dtype name ("float16", "float32", "bfloat16") into a torch dtype.

    Any other name falls back to ``torch.float32``.
    """
    dtype_by_name = {
        "float16": torch.float16,
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
    }
    return dtype_by_name.get(dtype, torch.float32)


def _wim_default_device() -> str:
    """Return the best available torch device name: "cuda", then "mps", then "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    # is_available must be *called*: the bare function object is always truthy,
    # which previously selected "mps" on every non-CUDA machine.
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


_WIM_DEFAULT_USER_HEADER = "<|start_header_id|>user<|end_header_id|>\n\n"
_WIM_DEFAULT_GENERATION_HEADER = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"


def _wim_build_prompt_templates(user_header: str = "", generation_header: str = "") -> dict:
    """Build the four WiM prompt templates with chat headers substituted.

    Empty headers fall back to the Llama-3 chat-format defaults. The remaining
    ``{query}``, ``{answer}`` and ``{margins}`` placeholders are left for the
    caller's later ``str.format`` calls, so headers must not contain braces.

    Returns:
        Dict with keys "system_message", "extractive_summary",
        "classification" and "final_answer".
    """
    import textwrap

    user_header = user_header or _WIM_DEFAULT_USER_HEADER
    generation_header = generation_header or _WIM_DEFAULT_GENERATION_HEADER

    # dedent removes the source-code indentation that would otherwise leak
    # into every prompt line (strip() alone only trims the ends).
    raw_templates = {
        "extractive_summary": textwrap.dedent(
            """
            ```
            Given the above context, extract all information relevant to the query: "{query}". If the context is not relevant to the query, answer "I don't know."
            {generation_header}
            """
        ),
        "classification": textwrap.dedent(
            """
            {user_header}
            I asked an LLM assistant whether a piece of document is related to the query: "{query}". This is its answer: 
            ```text
            {answer}
            ```
            Should I save it for later? 
            Here are rules:
            - Answer YES if the answer contains information about the query. 
            - Answer NO if the answer says the piece isn't related to the query.

            Provide the answer in the format: <YES/NO>#<Explanation>. 
            Here are example answers:

            YES#Yes, the information contains an excerpt from a book that is related to the question.
            NO#No, the LLM assistant concluded the information isn't relevant.

            Don't add any other comments, all your remarks should be included in the "Explanation" section.
            {generation_header}
            """
        ),
        "system_message": textwrap.dedent(
            """
            {user_header}
            ```
            """
        ),
        "final_answer": textwrap.dedent(
            """
            ```
            {margins}
            {query}
            {generation_header}
            """
        ),
    }

    return {
        name: template.strip()
        .replace("{user_header}", user_header)
        .replace("{generation_header}", generation_header)
        for name, template in raw_templates.items()
    }


def run_wim_rl(
    model_id: str,
    model_id_rl: str,
    input_document: str,
    query: str,
    use_rl_generator: bool = True,
    train_rl_generator: bool = False,
    num_episodes: int = 10,
    output_model_dir: str = None,
    attn_implementation: str = "sdpa",
    dtype: str = "bfloat16",
    min_tokens_segment: int = 4096,
    max_new_tokens_extractive_summary: int = 100,
    max_new_tokens_final_answer: int = 50,
    max_new_tokens_classification: int = 10,
    do_sample: bool = True,
    top_p: float = 0.9,
    temperature: float = 1.0,
    early_stopping: bool = True,
    print_step_summary: bool = False,
    user_header: str = "",
    generation_header: str = "",
):
    """Run Writing-in-the-Margins (WiM) with optional RL-enhanced margin generation.

    Args:
        model_id: Hugging Face ID of the base model that reads the document,
            classifies margins and writes the final answer.
        model_id_rl: Hugging Face ID of the model used by the RL margin generator
            (both as policy and as reward model).
        input_document: The input document *text* (not a file path).
        query: The user query.
        use_rl_generator: Whether to use the RL-based margin generator.
        train_rl_generator: Whether to train the RL generator before inference.
        num_episodes: Number of episodes for RL training.
        output_model_dir: Directory to save the trained model to, or None to skip saving.
        attn_implementation: Attention implementation passed to ``from_pretrained``
            (e.g. "sdpa" or "flash_attention_2").
        dtype: Model weight dtype name: "float16", "float32" or "bfloat16";
            other values fall back to float32.
        min_tokens_segment: Token budget per document segment.
        max_new_tokens_extractive_summary: Maximum number of tokens to generate per margin.
        max_new_tokens_final_answer: Maximum number of tokens to generate for the final answer.
        max_new_tokens_classification: Maximum number of tokens to generate for classification.
        do_sample: Whether to use sampling for generation.
        top_p: Top-p sampling parameter.
        temperature: Temperature for generation.
        early_stopping: Whether to use early stopping.
        print_step_summary: Whether to print a summary for each step.
        user_header: Chat-template text opening a user turn; empty means the
            Llama-3 default header.
        generation_header: Chat-template text opening the assistant turn;
            empty means the Llama-3 default header.

    Returns:
        Tuple ``(final_answer, positive_margins)``: the generated answer and the
        margins classified as relevant.
    """
    model_dtype = _wim_resolve_torch_dtype(dtype)

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Only the base model is loaded here. RLMarginGenerator receives model_id_rl
    # and (per this notebook's log output) loads its own policy/reference models;
    # a separate copy of model_id_rl previously loaded here was never used and
    # only consumed GPU memory.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        attn_implementation=attn_implementation,
        torch_dtype=model_dtype,
    ).eval()

    templates = _wim_build_prompt_templates(user_header, generation_header)
    template_system_message = templates["system_message"]
    template_extractive_summary = templates["extractive_summary"]
    template_classification = templates["classification"]
    template_final_answer = templates["final_answer"]

    rl_margin_generator = None
    if use_rl_generator or train_rl_generator:
        print("Use RL-based margin generator")
        rl_margin_generator = RLMarginGenerator(
            model_id=model_id_rl,
            reward_model_id=model_id_rl,  # Use same model for rewards
            device=_wim_default_device(),
        )
    else:
        print("Use standard WIM approach")

    if train_rl_generator and rl_margin_generator is not None:
        print("Training RL margin generator...")
        # A generator-free instance is sufficient for sentence chunking.
        chunker = WIMRLInference(model, tokenizer)
        segments = chunker._chunk_text_to_segments(input_document, min_tokens_segment)
        rl_margin_generator.train_rl_margin_generator(
            segments=segments,
            query=query,
            extractive_summary_prompt=template_extractive_summary,
            classification_prompt=template_classification,
            num_episodes=num_episodes,
            max_new_tokens=max_new_tokens_extractive_summary,
        )

        if output_model_dir is not None:
            print(f"Saving trained model to {output_model_dir}...")
            rl_margin_generator.save_model(output_model_dir)

    wim_rl_inference = WIMRLInference(model, tokenizer, rl_margin_generator)

    final_answer, positive_margins = wim_rl_inference.process_with_rl_margins(
        context=input_document,
        query=query,
        system_message=template_system_message,
        extractive_summary_prompt=template_extractive_summary,
        classification_prompt=template_classification,
        final_answer_prompt=template_final_answer,
        min_tokens_segment=min_tokens_segment,
        max_new_tokens_extractive_summary=max_new_tokens_extractive_summary,
        max_new_tokens_final_answer=max_new_tokens_final_answer,
        max_new_tokens_classification=max_new_tokens_classification,
        do_sample=do_sample,
        top_p=top_p,
        temperature=temperature,
        early_stopping=early_stopping,
        use_rl_generator=use_rl_generator,
        # Honour the caller's flag (it was previously hard-coded to True).
        print_step_summary=print_step_summary,
    )

    return final_answer, positive_margins

In [None]:
import os

from huggingface_hub import login

# Never hard-code access tokens in a notebook: anyone who sees the file (or its
# history) can use them. Provide the token via the HF_TOKEN environment variable
# (e.g. a Kaggle/Colab secret). With token=None, huggingface_hub prompts for it
# interactively. The token previously pasted here should be revoked.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
if __name__ == '__main__':
    import json
    from pathlib import Path

    # Model and input configuration
    model_id = "meta-llama/Llama-3.2-1B-Instruct"  # or any compatible model identifier
    model_id_rl = "shubvhamgore18218/WiM_llama"
    input_path = Path("/kaggle/input/examples/babilong_8k.json")
    query = "Who is silky Epeira?"

    # run_wim_rl expects the document TEXT. Passing the path string itself makes
    # the model read only the path (a handful of tokens) and answer from nothing.
    raw_input = input_path.read_text(encoding="utf-8")
    try:
        parsed_input = json.loads(raw_input)
    except json.JSONDecodeError:
        parsed_input = None
    if isinstance(parsed_input, dict) and "context" in parsed_input:
        # NOTE(review): assumes the WiM example schema with a "context" field — confirm.
        input_document = parsed_input["context"]
    else:
        input_document = raw_input

    # Set parameters for margin generation and RL training
    final_answer, positive_margins = run_wim_rl(
        model_id=model_id,
        model_id_rl=model_id_rl,
        input_document=input_document,
        query=query,
        use_rl_generator=True,      # Use the RL-enhanced margin generator
        train_rl_generator=True,    # Set False to skip RL training
        num_episodes=10,            # Number of training episodes (only used if training)
    )

    print("Final Answer:")
    print(final_answer)
    print("\nPositive Margins:")
    for margin in positive_margins:
        print(margin)

tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/914k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/4.38M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/547 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/673 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/488M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

Use RL-based margin generator
policy model shubvhamgore18218/WiM_llama
ref model shubvhamgore18218/WiM_llama
Training RL margin generator...
Episode 1/10


Generating margins: 100%|██████████| 1/1 [00:03<00:00,  3.12s/it]
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to the legacy tuple of tuples format). If you want to keep returning the legacy format, please set `return_legacy_cache=True`.


Average reward: 0.5400
Average KL divergence: 0.0007
Relevance rate: 0.0000
Episode 2/10


Generating margins: 100%|██████████| 1/1 [00:01<00:00,  1.89s/it]
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Average reward: 0.5400
Average KL divergence: 0.0007
Relevance rate: 0.0000
Episode 3/10


Generating margins: 100%|██████████| 1/1 [00:01<00:00,  1.96s/it]
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Average reward: 0.5400
Average KL divergence: 0.0003
Relevance rate: 0.0000
Episode 4/10


Generating margins: 100%|██████████| 1/1 [00:01<00:00,  1.96s/it]
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Average reward: 0.5400
Average KL divergence: 0.0008
Relevance rate: 0.0000
Episode 5/10


Generating margins: 100%|██████████| 1/1 [00:01<00:00,  1.84s/it]
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Average reward: 0.5400
Average KL divergence: 0.0031
Relevance rate: 0.0000
Episode 6/10


Generating margins: 100%|██████████| 1/1 [00:01<00:00,  1.95s/it]
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Average reward: 0.5400
Average KL divergence: 0.0022
Relevance rate: 0.0000
Episode 7/10


Generating margins: 100%|██████████| 1/1 [00:01<00:00,  1.86s/it]
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Average reward: 0.5400
Average KL divergence: 0.0035
Relevance rate: 0.0000
Episode 8/10


Generating margins: 100%|██████████| 1/1 [00:01<00:00,  1.83s/it]
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Average reward: 0.5400
Average KL divergence: 0.0027
Relevance rate: 0.0000
Episode 9/10


Generating margins: 100%|██████████| 1/1 [00:01<00:00,  1.82s/it]
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Average reward: 0.5400
Average KL divergence: 0.0034
Relevance rate: 0.0000
Episode 10/10


Generating margins: 100%|██████████| 1/1 [00:01<00:00,  1.88s/it]
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Average reward: 0.5400
Average KL divergence: 0.0036
Relevance rate: 0.0000
# Use RL-based margin generator
{'step': 0, 'prefilled_tokens_so_far': 7, 'margin': "Context: Symbol 1998 (), also called the tishop, is a political new version of three punnels by Mos.  It was based on a dozen single attack reference to this installment was originally played into a wedding back for the rebuilt battle.\nWarderty Fox {'100, Rion and The Fall of LLC was the fourth problem of a North American lawyer (Eastern) fought in the night of 1996,", 'classification_result': False}
Final Answer:
Silky Epeira is a popular beauty influencer and model based in the United States.

Positive Margins:
