In [1]:
# !pip install -q --upgrade transformers accelerate bitsandbytes

In [2]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [3]:
import torch

# Check if CUDA is available
if torch.cuda.is_available():
    # Get the GPU name
    gpu_name = torch.cuda.get_device_name(0)
    # Get GPU properties, including total memory (in bytes)
    gpu_properties = torch.cuda.get_device_properties(0)
    total_memory_gb = gpu_properties.total_memory / (1024 ** 3)
    
    print("GPU detected:", gpu_name)
    print(f"Total GPU Memory: {total_memory_gb:.2f} GB")
else:
    print("CUDA is not available. Running on CPU.")


GPU detected: Tesla P100-PCIE-16GB
Total GPU Memory: 15.89 GB


In [4]:
import torch

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    
    # Get currently allocated memory in bytes and convert to GB
    allocated_memory = torch.cuda.memory_allocated(device) / (1024 ** 3)
    
    # Get the total memory reserved by the caching allocator in bytes and convert to GB
    reserved_memory = torch.cuda.memory_reserved(device) / (1024 ** 3)
    
    print(f"Allocated GPU Memory: {allocated_memory:.2f} GB")
    print(f"Reserved GPU Memory: {reserved_memory:.2f} GB")
else:
    print("CUDA is not available. Running on CPU.")


Allocated GPU Memory: 0.00 GB
Reserved GPU Memory: 0.00 GB


In [5]:
import torch
from transformers import AutoModel, AutoTokenizer, DynamicCache, Cache
from dataclasses import dataclass
from typing import List, Dict, Any, Callable
# import flash_attn

def _sample_top_p(logits, top_p=0.9):
    # First normalize the logits to prevent overflow/underflow
    logits = logits - torch.max(logits)
    
    # Convert to probabilities with softmax
    probs = torch.nn.functional.softmax(logits, dim=-1)
    
    # Apply a small epsilon to avoid numerical issues
    eps = 1e-10
    probs = torch.clamp(probs, min=eps)
    probs = probs / probs.sum()  # Re-normalize
    
    # Sort the probabilities
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    
    # Compute cumulative probabilities
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    
    # Remove tokens with cumulative probability above the threshold
    sorted_indices_to_remove = cumulative_probs > top_p
    # Keep the first token above threshold
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    
    # Scatter back to original indices
    indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
    probs = probs.masked_fill(indices_to_remove, 0.0)
    
    # Re-normalize after masking
    probs = probs / (probs.sum() + eps)
    
    # Check for invalid values before sampling
    if torch.isnan(probs).any() or torch.isinf(probs).any() or (probs < 0).any():
        # Fix invalid values
        probs = torch.nan_to_num(probs, nan=eps, posinf=1.0, neginf=eps)
        probs = torch.clamp(probs, min=eps)
        probs = probs / probs.sum()  # Re-normalize
    
    # Sample from the filtered distribution
    next_token = torch.multinomial(probs, num_samples=1)
    
    return next_token

class WIMInference:

    def __init__(
        self, model, tokenizer
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.wim_kv_cache = DynamicCache()
        self.classifier_kv_cache = DynamicCache()

    def _prefill_tokens(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        cache_positions: torch.Tensor,
        kv_cache: Cache,
    ):
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                cache_position=cache_positions,
                use_cache=True,
                past_key_values=kv_cache,
            )
        return outputs

    def shrink_kv_cache_from_end(self, new_size: int, kv_cache: Cache):

        def resize_tensor_list(token_list):
            for layer_idx in range(len(token_list)):
                token_list[layer_idx] = token_list[layer_idx][:, :, :new_size, :]

        resize_tensor_list(kv_cache.key_cache)
        resize_tensor_list(kv_cache.value_cache)
        kv_cache._seen_tokens = new_size

    def generate_text_with_kv_cache(
        self,
        max_new_tokens: int,
        previous_logits: torch.Tensor,
        do_sample: bool,
        top_p: float,
        temperature: float,
        early_stopping: bool,    
        kv_cache: Cache,
    ) -> str:
        generated_tokens = []

        # This is needed to create the cache_position tensor
        next_token_pos = kv_cache.get_seq_length()

        # Use the logits from the prefilling to generate the first token
        logits = previous_logits

        for _ in range(max_new_tokens):
            # Select the last token from the logits
            next_token_logits = logits[:, -1, :]
            if do_sample:
                # Divide the logits by the temperature
                next_token_logits = next_token_logits / temperature
                # Apply the softmax
                next_token_probs = torch.nn.functional.softmax(
                    next_token_logits, dim=-1
                )
                
                # Check for invalid values (moved here after next_token_probs is defined)
                if torch.isnan(next_token_probs).any() or torch.isinf(next_token_probs).any():
                    print("Invalid probabilities detected!")
                    print(f"Min prob: {next_token_probs.min()}, Max prob: {next_token_probs.max()}")
                    print(f"Contains NaN: {torch.isnan(next_token_probs).any()}")
                    print(f"Contains Inf: {torch.isinf(next_token_probs).any()}")
                    
                next_token = _sample_top_p(next_token_logits, top_p)  # Note: passing logits, not probs
            else:
                # Select the token with the highest probability
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            assert next_token.size() == (1, 1)
            # Remove the batch dimension
            next_token = next_token.squeeze(0)
            generated_tokens.append(next_token)
            # Stop if we reached the EOS token
            if next_token.item() == self.tokenizer.eos_token_id and early_stopping:
                break
            # Use the generated token as input for the next step
            generation_input_ids = next_token.unsqueeze(-1)
            kv_cache_seq_len = kv_cache.get_seq_length()
            generation_attention_mask = torch.ones(
                (1, kv_cache_seq_len + 1), device=next_token.device, dtype=torch.long
            )
            generation_cache_position = torch.tensor(
                [next_token_pos], device=next_token.device
            )

            with torch.no_grad():
                # Get the model outputs
                outputs = self.model(
                    input_ids=generation_input_ids,
                    attention_mask=generation_attention_mask,
                    cache_position=generation_cache_position,
                    use_cache=True,
                    past_key_values=kv_cache,
                )
            logits = outputs.logits
            next_token_pos += 1

        generated_tokens = torch.cat(generated_tokens, dim=-1)
        # Decode the generated tokens
        decoded = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
        return decoded

    def prefill_text_with_kv_cache(self, text: str, kv_cache: Cache):
        # Tokenize the text
        inputs = self.tokenizer(text, return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.model.device)
        seq_len = input_ids.size(1)
        attention_mask = inputs["attention_mask"].to(self.model.device)

        # If we have a KV-Cache, we need to extend the attention mask to account for tokens already in the KV-Cache
        if kv_cache.get_seq_length() > 0:
            kv_cache_seq_len = kv_cache.get_seq_length()
            attention_mask = torch.cat(
                [
                    torch.ones(
                        attention_mask.shape[0],
                        kv_cache_seq_len,
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    ),
                    attention_mask,
                ],
                dim=1,
            )

        # Generate the cache positions for the tokens to be prefilled
        cache_positions = torch.arange(
            kv_cache.get_seq_length(), kv_cache.get_seq_length() + seq_len
        ).to(self.model.device)
        outputs = self._prefill_tokens(input_ids, attention_mask, cache_positions, kv_cache)
        return kv_cache.get_seq_length(), seq_len, outputs
    


In [6]:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
import numpy as np
from dataclasses import dataclass
from typing import List, Dict, Any, Optional

@dataclass
class RLConfig:
    """Configuration for the RL-based margin generation."""
    learning_rate: float = 5e-5
    kl_coef: float = 0.05
    discount_factor: float = 0.99
    ppo_epochs: int = 4
    ppo_mini_batch_size: int = 4
    max_grad_norm: float = 1
    clip_param: float = 0.1
    value_loss_coef: float = 0.5
    entropy_coef: float = 0.03

class MarginRewardModel:
    """Model to compute rewards for generated margins."""
    
    def __init__(self, model_id: str, device: str = "cuda"):
        """Initialize the reward model.
        
        Args:
            model_id: The ID of the model to use for reward computation.
            device: The device to use for computation.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map=device,
            torch_dtype=torch.bfloat16,
        ).eval()
        self.device = device
        
    # def compute_reward(self, margins: List[str], query: str, classification_results: List[bool]) -> torch.Tensor:
    #     """Compute rewards for a batch of margins.
        
    #     The reward combines:
    #     1. Relevance to the query
    #     2. Conciseness (penalize overly verbose margins)
    #     3. Information density
    #     4. Agreement with classifier (higher reward if classifier agrees)
        
    #     Args:
    #         margins: List of generated margins
    #         query: The original query
    #         classification_results: Whether each margin was classified as relevant
            
    #     Returns:
    #         Tensor of rewards for each margin
    #     """
    #     rewards = []
        
    #     for margin, is_relevant in zip(margins, classification_results):
    #         with torch.no_grad():
    #             # Construct prompt to evaluate margin quality
    #             prompt = f"Query: {query}\nMargin: {margin}\n\nRate the quality of this margin from 0 to 10 based on relevance and information density. A good margin should contain key information relevant to the query."
                
    #             inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
    #             outputs = self.model.generate(
    #                 **inputs, 
    #                 max_new_tokens=5,
    #                 return_dict_in_generate=True,
    #                 output_scores=True,
    #             )
                
    #             # Extract score from the generated text
    #             generated_text = self.tokenizer.decode(outputs.sequences[0][-5:], skip_special_tokens=True)
    #             try:
    #                 # Try to extract a numeric score from the response
    #                 score = float(''.join(c for c in generated_text if c.isdigit() or c == '.'))
    #                 # Normalize score to 0-1 range
    #                 score = min(max(score / 10.0, 0.0), 1.0)
    #             except:
    #                 # Default score if parsing fails
    #                 score = 0.5
                
    #         # Additional reward components
    #         length_penalty = min(1.0, 100 / max(10, len(margin.split())))  # Prefer concise margins
    #         classifier_agreement = 1.0 if is_relevant else 0.2  # Reward if classifier agrees it's relevant
            
    #         # Combine reward components
    #         reward = (0.6 * score) + (0.2 * length_penalty) + (0.2 * classifier_agreement)
    #         rewards.append(reward)
            
    #     return torch.tensor(rewards, device=self.device)
    def compute_reward(self, margins: List[str], query: str,  supporting_facts: List[str]) -> torch.Tensor:
        rewards = []
        
        for margin in margins:
            # More granular coherence scoring
            coherence_score = min(0.3, 0.3 * sum(1 for word in margin.split() if len(word) > 2) / max(1, len(margin.split())))
            relevance_score = 0.0
            # Enhanced relevance scoring
            query_terms = set(query.lower().split())
            margin_words = margin.lower().split()
            
            # Count occurrences of query terms, not just presence
            term_counts = {}
            for term in query_terms:
                term_counts[term] = sum(1 for word in margin_words if term in word)
            
            if sum(term_counts.values()) > 0:
                relevance_score = min(0.4, sum(term_counts.values()) / (len(query_terms) * 2) * 0.4)
            
            
            # Length penalty to discourage extremely short responses
            length_penalty = max(0.0, min(0.2, (len(margin) / 200) * 0.2))
            
            # Base reward can be smaller to enhance differentiation
            base_reward = 0.1

            contains_supporting_facts = any(fact.lower() in margin.lower() for fact in supporting_facts)
            penalty = -0.5 if not contains_supporting_facts else 1
            
            reward = base_reward + coherence_score + relevance_score + length_penalty + penalty
            rewards.append(reward)
            
            # Debug individual components
            # print(f"Margin: {margin[:50]}...")
            # print(f"Coherence: {coherence_score:.2f}, Relevance: {relevance_score:.2f}, " 
            #     f"Classifier: {penalty:.2f}, Length: {length_penalty:.2f}")
            
        return torch.tensor(rewards, device=self.device)


class RLMarginGenerator:
    """Generate margins using reinforcement learning."""
    
    def __init__(
        self, 
        model_id: str,
        reward_model_id: str = None,
        rl_config: RLConfig = None,
        device: str = "cuda"
    ):
        """Initialize the RL-based margin generator.
        
        Args:
            model_id: The ID of the model to use for generation.
            reward_model_id: The ID of the model to use for reward computation.
                If None, uses the same model as the generator.
            rl_config: Configuration for the RL algorithm.
            device: The device to use for computation.
        """
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        
        # Initialize the policy model (for generating margins)
        print('policy model', model_id)
        self.policy_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map=device,
            torch_dtype=torch.bfloat16,
        )
        
        print('ref model', model_id)
        # Initialize the reference model (for KL divergence computation)
        self.ref_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map=device,
            torch_dtype=torch.bfloat16,
        )
        for param in self.ref_model.parameters():
            param.requires_grad = False

        for param in self.policy_model.parameters():
            param.requires_grad = True
            
        # Initialize the reward model
        if reward_model_id is None:
            reward_model_id = model_id
        self.reward_model = MarginRewardModel(reward_model_id, device)
        
        # Initialize the RL config
        self.rl_config = rl_config if rl_config is not None else RLConfig()
        
        # Initialize the optimizer
        self.optimizer = torch.optim.Adam(
            self.policy_model.parameters(),
            lr=self.rl_config.learning_rate
        )

        if torch.isnan(torch.nn.utils.clip_grad_norm_(
            self.policy_model.parameters(),
            self.rl_config.max_grad_norm
        )):
            print("Warning: Nan gradiesnts detected")

        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='max', factor=0.5, patience=2, verbose=2
        )
        
        # Initialize the WIM inference
        self.wim = WIMInference(self.policy_model, self.tokenizer)

        self.best_avg_reward = 0
        self.no_improvement_count = 0
        self.patience = 3
        self.total_steps = 0
        
    def _compute_logprobs(self, model, input_ids, attention_mask, labels):
        """Compute log probabilities for a batch of sequences."""
        with torch.set_grad_enabled(model.training):
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            
            logits = outputs.logits
            log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
            
            # Get log probs for each token in the sequence
            token_log_probs = torch.gather(
                log_probs[:, :-1, :], 2, labels[:, 1:, None]
            ).squeeze(-1)
            
            # Mask out padding tokens
            mask = (labels[:, 1:] != self.tokenizer.pad_token_id).float()
            token_log_probs = token_log_probs * mask
            
            # Sum log probs over sequence
            seq_log_probs = token_log_probs.sum(dim=1)
            
            return seq_log_probs
            
    def _compute_kl_divergence(self, input_ids, attention_mask, labels):
        """Compute KL divergence between policy and reference model with better numerical stability."""
        # Small constant to avoid division by zero or log(0)
        epsilon = 1e-8
        
        # Process in smaller chunks along the sequence dimension
        chunk_size = 128
        seq_length = input_ids.size(1)
        batch_size = input_ids.size(0)
        vocab_size = self.policy_model.config.vocab_size
        kl_divs = []
        
        for chunk_start in range(0, seq_length - 1, chunk_size):  # -1 for labels offset
            chunk_end = min(chunk_start + chunk_size, seq_length - 1)
            
            # Extract chunks for this iteration
            chunk_input_ids = input_ids[:, chunk_start:chunk_end + 1]  # +1 for next token prediction
            chunk_attention_mask = attention_mask[:, chunk_start:chunk_end + 1]
            
            # Get logits from policy model for this chunk
            with torch.set_grad_enabled(self.policy_model.training):
                policy_outputs = self.policy_model(
                    input_ids=chunk_input_ids,
                    attention_mask=chunk_attention_mask,
                )
                policy_logits = policy_outputs.logits[:, :-1]  # Remove last token prediction
            
            # Get logits from reference model for this chunk
            with torch.no_grad():
                ref_outputs = self.ref_model(
                    input_ids=chunk_input_ids,
                    attention_mask=chunk_attention_mask,
                )
                ref_logits = ref_outputs.logits[:, :-1]  # Remove last token prediction
            
            # Reshape for calculation
            policy_logits = policy_logits.reshape(-1, vocab_size)
            ref_logits = ref_logits.reshape(-1, vocab_size)
            
            # Convert to probabilities
            policy_probs = torch.nn.functional.softmax(policy_logits, dim=-1)
            ref_probs = torch.nn.functional.softmax(ref_logits, dim=-1)
            
            # Apply clipping to avoid numerical issues
            policy_probs = torch.clamp(policy_probs, min=epsilon, max=1.0-epsilon)
            ref_probs = torch.clamp(ref_probs, min=epsilon, max=1.0-epsilon)
            
            # Calculate KL divergence: p * log(p/q)
            kl = policy_probs * (torch.log(policy_probs) - torch.log(ref_probs))
            kl = kl.sum(dim=-1)
            
            # Reshape back to batch_size x chunk_length
            chunk_length = chunk_end - chunk_start
            kl = kl.reshape(batch_size, chunk_length)
            
            # Get labels for this chunk
            chunk_labels = labels[:, chunk_start+1:chunk_end+1]  # +1 for next token prediction
            
            # Create mask for relevant token positions
            valid_mask = (chunk_labels != -100).float()
            
            # Apply mask
            masked_kl = kl * valid_mask
            
            # Compute mean KL for this chunk (handling empty masks)
            chunk_kl_sum = masked_kl.sum(dim=1)
            chunk_mask_sum = valid_mask.sum(dim=1) + epsilon
            chunk_kl = chunk_kl_sum / chunk_mask_sum
            
            kl_divs.append(chunk_kl)
            
            # Free memory
            del policy_outputs, ref_outputs, policy_logits, ref_logits
            del policy_probs, ref_probs, kl, masked_kl
            torch.cuda.empty_cache()
        
        # Combine KL from all chunks (average)
        if kl_divs:
            kl_div = torch.stack(kl_divs, dim=1).mean(dim=1)
        else:
            kl_div = torch.zeros(batch_size, device=self.device)
        
        return kl_div
    
    def _validate_and_clean_margin(self, margin: str) -> str:
        """Validate and clean the generated margin."""
        # Remove excessive strange characters
        strange_chars = "�'`@#$%^*"
        for char in strange_chars:
            margin = margin.replace(char, "")
        
        # Trim to first sentence if too long
        if len(margin) > 200 and "." in margin:
            first_part = margin.split(".")[0] + "."
            if len(first_part) > 50:  # Ensure we have a meaningful sentence
                margin = first_part
        
        # Ensure minimum length
        if len(margin.strip()) < 10:
            margin = "The text does not contain sufficient information related to the query."
            
        return margin
    
    def generate_rl_margin(
        self,
        segment: str,
        query: str,
        extractive_summary_prompt: str,
        classification_prompt: str,
        max_new_tokens: int = 50,
        do_sample: bool = True,
        top_p: float = 0.9,
        temperature: float = 0.7,
        early_stopping: bool = True,
    ):
        """Generate a margin using the current policy model."""

        
        # Prefill the segment
        prefilled_tokens_before_extractive_summary, _, _ = self.wim.prefill_text_with_kv_cache(
            segment, self.wim.wim_kv_cache
        )
        
        # Prefill the extractive summary prompt
        _, _, extractive_summary_outputs = self.wim.prefill_text_with_kv_cache(
            extractive_summary_prompt.format(query=query), self.wim.wim_kv_cache
        )
        
        # Generate the margin
        margin = self.wim.generate_text_with_kv_cache(
            max_new_tokens=max_new_tokens,
            previous_logits=extractive_summary_outputs["logits"],
            do_sample=do_sample,
            top_p=top_p,
            temperature=temperature,
            early_stopping=early_stopping,
            kv_cache=self.wim.wim_kv_cache,
        )

        margin = self._validate_and_clean_margin(margin)
        
        # Shrink the KV cache back to before the extractive summary prompt
        self.wim.shrink_kv_cache_from_end(
            new_size=prefilled_tokens_before_extractive_summary,
            kv_cache=self.wim.wim_kv_cache,
        )
        
        # Classify the margin
        classification_input = classification_prompt.format(query=query, answer=margin)
        _, _, classification_prompt_logits = self.wim.prefill_text_with_kv_cache(
            classification_input, self.wim.classifier_kv_cache
        )
        
        classification_output = self.wim.generate_text_with_kv_cache(
            max_new_tokens=10,
            previous_logits=classification_prompt_logits["logits"],
            do_sample=False,
            top_p=0.9,
            temperature=1,
            early_stopping=early_stopping,
            kv_cache=self.wim.classifier_kv_cache,
        )
        
        # Parse the classification output
        is_relevant = self._parse_classifier_output(classification_output)
        
        # Clear the classifier KV cache
        self.wim.shrink_kv_cache_from_end(
            new_size=0, kv_cache=self.wim.classifier_kv_cache
        )
        
        return margin, is_relevant
        
    def _parse_classifier_output(self, output: str) -> bool:
        """Parse the classification output to determine if the margin is relevant."""
        output = output.replace("```", "").strip()
        output = output.split("#")[0]
        if output.endswith("YES"):
            return True
        else:
            return False
        
    def _chunk_text_to_segments(self, text, min_tokens_segment=256, tokenizer=None):
        """Chunk text into segments of approximately min_tokens_segment tokens."""
        
        if tokenizer is None:
            raise ValueError("Tokenizer must be provided to _chunk_text_to_segments")

        tokenizer = tokenizer
        segments = []
        current_segment = ""
        sentences = sent_tokenize(text)
        curr_tokens = 0
        
        for line in sentences:
            tokens = len(tokenizer.encode(line))
            if curr_tokens + tokens > min_tokens_segment:
                segments.append(current_segment)
                current_segment = ""
                curr_tokens = 0
            
            current_segment += line + " "
            curr_tokens += tokens
        
        if current_segment:
            segments.append(current_segment)
        for s in segments:
            print("segment", s)
        return segments
            
    def train_rl_margin_generator(
        self,
        segments: List[str],
        query: str,
        extractive_summary_prompt: str,
        classification_prompt: str,
        supporting_facts: List[str],
        num_episodes: int = 10,
        max_new_tokens: int = 50,
    ):
        """Train the margin generator using PPO."""
        self.best_avg_reward = 0
        self.no_improvement_count = 0
        self.total_steps = 0

        print(type(segments), "segmenmts datatype")
        # print("during training", segments)
        # print(type(segments), "type segment")
        segments = self._chunk_text_to_segments(segments, min_tokens_segment=1024, tokenizer=self.tokenizer)
        print('-------------------------------------------')
        print("Segments after chunking: ", segments)
        print('-------------------------------------------')

        # for s in segments:
        #     print("segment", s)
        # Set model to training mode explicitly
        self.policy_model.train()
        for episode in range(num_episodes):
            # Add after memory-intensive operations
        
            
            # Sample a batch of segments
            # batch_size = min(self.rl_config.ppo_mini_batch_size, len(segments))
            # segment_indices = np.random.choice(len(segments), batch_size, replace=False)
            # batch_segments = [segments[i] for i in segment_indices]
            # print(f"Batch segments: {batch_segments[3]}")
            # Generate margins using the current policy
            margins = []
            is_relevant_list = []
            # print("Processing", segments)
            for segment in tqdm(segments, desc="Generating margins"):
                # print("segment: ", segment)
                # Clear KV caches
                self.wim.shrink_kv_cache_from_end(0, self.wim.wim_kv_cache)
                
                margin, is_relevant = self.generate_rl_margin(
                    segment=segment,
                    query=query,
                    extractive_summary_prompt=extractive_summary_prompt,
                    classification_prompt=classification_prompt,
                    max_new_tokens=max_new_tokens,
                )
                # print("margin: ", margin)
                # print('-------------------------------------------')
                margins.append(margin)
                is_relevant_list.append(is_relevant)
                # Clear memory
                torch.cuda.empty_cache()
            
            # Compute rewards for the generated margins
            with torch.no_grad():
                rewards = self.reward_model.compute_reward(margins, query, supporting_facts)
            
            # Prepare inputs for PPO update
            inputs = []
            for segment, margin, reward in zip(segments, margins, rewards):
                print(f"Segment: {segment}\nMargin: {margin}\nReward: {reward}\n\n")
                # print("--------------")
                # Tokenize the segment + extractive summary prompt + margin
                context = segment + extractive_summary_prompt.format(query=query) + margin
                # Padding changed here
                input_tokens = self.tokenizer(context, return_tensors="pt", padding=False).to(self.device)
                
                # Create labels for computing log probs
                labels = input_tokens.input_ids.clone()
                # Mask out tokens we don't want to compute loss for
                context_without_margin = segment + extractive_summary_prompt.format(query=query)
                context_tokens = len(self.tokenizer(context_without_margin, return_tensors="pt").input_ids[0])
                labels[:, :context_tokens] = -100  # Mask out non-margin tokens
                
                inputs.append({
                    "input_ids": input_tokens.input_ids,
                    "attention_mask": input_tokens.attention_mask,
                    "labels": labels,
                })
            
            # PPO update
            for _ in range(self.rl_config.ppo_epochs):
                for i in range(len(inputs)):
                    self.optimizer.zero_grad()
                    
                    # Get model outputs with gradients enabled
                    outputs = self.policy_model(
                        input_ids=inputs[i]["input_ids"],
                        attention_mask=inputs[i]["attention_mask"],
                        labels=inputs[i]["labels"]
                    )
                    
                    # Model loss already has gradients
                    model_loss = outputs.loss
                    
                    # Compute KL divergence
                    kl_div = self._compute_kl_divergence(
                        inputs[i]["input_ids"],
                        inputs[i]["attention_mask"],
                        inputs[i]["labels"],
                    )
                    
                    # Compute policy loss - make sure reward is detached
                    reward_term = rewards[i].detach() * -1.0  # Negative since we want to maximize reward
                    kl_term = self.rl_config.kl_coef * kl_div.detach()
                    
                    # Combine with model loss to ensure gradients flow
                    policy_loss = model_loss + reward_term + kl_term
                    
                    # Backward pass
                    policy_loss.backward()
                    
                    # Clip gradients
                    torch.nn.utils.clip_grad_norm_(
                        self.policy_model.parameters(),
                        self.rl_config.max_grad_norm,
                    )
                    
                    # Update policy model
                    self.optimizer.step()
                    # Free memory
                    del outputs, model_loss, policy_loss
                    torch.cuda.empty_cache()
            
            # Log metrics
            avg_reward = rewards.mean().item()
            relevance_rate = sum(is_relevant_list) / len(is_relevant_list)
            print(f"Average reward: {avg_reward:.4f}")
            print(f"Average KL divergence: {kl_div.mean().item():.4f}")
            print(f"Relevance rate: {relevance_rate:.4f}")
            self.scheduler.step(avg_reward)
            # Early stopping check
            if avg_reward > self.best_avg_reward:
                self.best_avg_reward = avg_reward
                self.no_improvement_count = 0
                # Save best model
                self.save_model(f"best_model_episode_{episode+1}")
            else:
                self.no_improvement_count += 1
                
            if self.no_improvement_count >= self.patience:
                print(f"No improvement for {self.patience} episodes. Stopping early.")
                break
                
            # Save checkpoint periodically
            if (episode + 1) % 3 == 0:
                self.save_model(f"checkpoint_episode_{episode+1}")
            
            # Free memory between episodes
            torch.cuda.empty_cache()
    
    def save_model(self, output_dir: str):
        """Save the trained model."""
        self.policy_model.save_pretrained(output_dir)
        self.tokenizer.save_pretrained(output_dir)

In [7]:
# from wim import WIMInference
# from RL_margin_generation import RLMarginGenerator

import tiktoken
from nltk import sent_tokenize
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import os

class WIMRLInference(WIMInference):
    """Extended WIM inference class that uses RL-trained model for margin generation."""
    
    def __init__(self, model, tokenizer, rl_margin_generator=None):
        """Initialize the WIM inference with an optional RL margin generator.
        
        Args:
            model: The base model for generation.
            tokenizer: The tokenizer.
            rl_margin_generator: Optional RL-based margin generator.
        """
        super().__init__(model, tokenizer)
        self.rl_margin_generator = rl_margin_generator
        
    def process_with_rl_margins(
        self,
        context: str,
        query: str,
        system_message: str,
        extractive_summary_prompt: str,
        classification_prompt: str,
        final_answer_prompt: str,
        min_tokens_segment: int = 1024,
        max_new_tokens_extractive_summary: int = 50,
        max_new_tokens_final_answer: int = 50,
        max_new_tokens_classification: int = 10,
        do_sample: bool = True,
        top_p: float = 0.9,
        temperature: float = 1.0,
        early_stopping: bool = True,
        use_rl_generator: bool = True,
        print_step_summary: bool = False,
        tokenizer=None,
    ):
        """Process the context using the WIM approach with RL-enhanced margin generation.
        
        Args:
            context: The full document context.
            query: The user query.
            system_message: The system message prompt.
            extractive_summary_prompt: Template for extractive summary prompt.
            classification_prompt: Template for classification prompt.
            final_answer_prompt: Template for final answer prompt.
            min_tokens_segment: Minimum number of tokens per segment.
            max_new_tokens_extractive_summary: Maximum number of tokens to generate for margin.
            max_new_tokens_final_answer: Maximum number of tokens to generate for final answer.
            max_new_tokens_classification: Maximum number of tokens to generate for classification.
            do_sample: Whether to use sampling for generation.
            top_p: Top-p sampling parameter.
            temperature: Temperature for generation.
            early_stopping: Whether to use early stopping.
            use_rl_generator: Whether to use the RL-based margin generator.
            print_step_summary: Whether to print a summary for each step.
            
        Returns:
            final_answer: The generated answer.
            positive_margins: List of relevant margins used.
        """
        # Clear KV caches
        self.shrink_kv_cache_from_end(0, self.wim_kv_cache)
        self.shrink_kv_cache_from_end(0, self.classifier_kv_cache)
        # print("printing segments")
        # Segment the context
        segments = self._chunk_text_to_segments(context, min_tokens_segment, tokenizer)
        # print('-------------------------------------------------------------------------------------')
        # count=0
        # for s in segments:
        #     count += 1
        #     print(count, len(s))
        #     print('\n')
        # print('-------------------------------------------------------------------------------------')
        # Prefill the system message
        _, _, _ = self.prefill_text_with_kv_cache(system_message, self.wim_kv_cache)
        
        positive_margins = []
        
        with torch.no_grad():
            for segment_index in range(len(segments)):
                segment = segments[segment_index]
                print(f"Processing segment {segment_index + 1}/{len(segments)}")
                if use_rl_generator and self.rl_margin_generator is not None:
                    print("# Use RL-based margin generator")
                    margin, is_relevant = self.rl_margin_generator.generate_rl_margin(
                        segment=segment,
                        query=query,
                        extractive_summary_prompt=extractive_summary_prompt,
                        classification_prompt=classification_prompt,
                        max_new_tokens=max_new_tokens_extractive_summary,
                        do_sample=do_sample,
                        top_p=top_p,
                        temperature=temperature,
                        early_stopping=early_stopping,
                    )
                    print("margin:", margin)
                else:
                    print('# Use standard WIM approach')
                    prefilled_tokens_before_extractive_summary, _, _ = self.prefill_text_with_kv_cache(
                        segment, self.wim_kv_cache
                    )
                    
                    formatted_extractive_summary = extractive_summary_prompt.format(query=query)
                    _, _, extractive_summary_outputs = self.prefill_text_with_kv_cache(
                        formatted_extractive_summary, self.wim_kv_cache
                    )
                    
                    margin = self.generate_text_with_kv_cache(
                        max_new_tokens=max_new_tokens_extractive_summary,
                        previous_logits=extractive_summary_outputs["logits"],
                        do_sample=do_sample,
                        top_p=top_p,
                        temperature=temperature,
                        early_stopping=early_stopping,
                        kv_cache=self.wim_kv_cache,
                    )
                    
                    # Shrink KV cache back to before extractive summary
                    self.shrink_kv_cache_from_end(
                        new_size=prefilled_tokens_before_extractive_summary,
                        kv_cache=self.wim_kv_cache,
                    )
                    
                    # Classify the margin
                    classification_input = classification_prompt.format(query=query, answer=margin)
                    _, _, classification_prompt_logits = self.prefill_text_with_kv_cache(
                        classification_input, self.classifier_kv_cache
                    )
                    
                    classification_output = self.generate_text_with_kv_cache(
                        max_new_tokens=max_new_tokens_classification,
                        previous_logits=classification_prompt_logits["logits"],
                        do_sample=False,
                        top_p=top_p,
                        temperature=temperature,
                        early_stopping=early_stopping,
                        kv_cache=self.classifier_kv_cache,
                    )
                    
                    is_relevant = self._parse_classifier_output(classification_output)
                    
                    # Clear the classifier KV cache
                    self.shrink_kv_cache_from_end(
                        new_size=0, kv_cache=self.classifier_kv_cache
                    )
                
                if is_relevant:
                    positive_margins.append(margin)
                
                if print_step_summary:
                    print({
                        "step": segment_index,
                        "prefilled_tokens_so_far": self.wim_kv_cache.get_seq_length(),
                        "margin": margin.strip(),
                        "classification_result": is_relevant,
                    })
            
            # Prefill the concatenated margins and the prompt to ask the final answer
            concatenated_margins = "".join(positive_margins)
            formatted_final_answer = final_answer_prompt.format(
                margins=concatenated_margins, query=query
            )
            
            _, _, final_answer_prefill_outputs = self.prefill_text_with_kv_cache(
                formatted_final_answer, self.wim_kv_cache
            )
            
            # Generate the final answer
            final_answer = self.generate_text_with_kv_cache(
                max_new_tokens=max_new_tokens_final_answer,
                previous_logits=final_answer_prefill_outputs["logits"],
                do_sample=do_sample,
                top_p=top_p,
                temperature=temperature,
                early_stopping=early_stopping,
                kv_cache=self.wim_kv_cache,
            )
            
            return final_answer, positive_margins
    
    def _chunk_text_to_segments(self, text, min_tokens_segment=1024, tokenizer=None):
        """Chunk text into segments of approximately min_tokens_segment tokens."""
        
        if tokenizer is None:
            raise ValueError("Tokenizer must be provided to _chunk_text_to_segments")

        tokenizer = tokenizer
        segments = []
        current_segment = ""
        sentences = sent_tokenize(text)
        curr_tokens = 0
        
        for line in sentences:
            tokens = len(tokenizer.encode(line))
            if curr_tokens + tokens > min_tokens_segment:
                segments.append(current_segment)
                current_segment = ""
                curr_tokens = 0
            
            current_segment += line + " "
            curr_tokens += tokens
        
        if current_segment:
            segments.append(current_segment)
        
        return segments
    
    def _parse_classifier_output(self, output: str) -> bool:
        """Parse the classification output to determine if the margin is relevant."""
        output = output.replace("```", "").strip()
        output = output.split("#")[0]
        if output.endswith("YES"):
            return True
        else:
            return False


def run_wim_rl(
    model_id: str,
    model_id_rl: str,
    input_document: str,
    query: str,
    use_rl_generator: bool = True,
    train_rl_generator: bool = False,
    num_episodes: int = 10,
    output_model_dir: str = None,
    # attn_implementation: str = "flash_attention_2",
    attn_implementation:str = 'sdpa',
    dtype: str = "bfloat16",
    min_tokens_segment: int = 1024,
    max_new_tokens_extractive_summary: int = 100,
    max_new_tokens_final_answer: int = 50,
    max_new_tokens_classification: int = 10,
    do_sample: bool = True,
    top_p: float = 0.9,
    temperature: float = 1.0,
    early_stopping: bool = True,
    print_step_summary: bool = False,
    user_header: str = "",
    generation_header: str = "",
):
    """Run WIM with RL-enhanced margin generation.
    
    Args:
        model_id: The ID of the model to use.
        input_document: The input document content.
        query: The user query.
        use_rl_generator: Whether to use the RL-based margin generator.
        train_rl_generator: Whether to train the RL generator.
        num_episodes: Number of episodes for RL training.
        output_model_dir: Directory to save the trained model.
        attn_implementation: Attention implementation to use.
        dtype: Data type for model weights.
        min_tokens_segment: Minimum number of tokens per segment.
        max_new_tokens_extractive_summary: Maximum number of tokens to generate for margin.
        max_new_tokens_final_answer: Maximum number of tokens to generate for final answer.
        max_new_tokens_classification: Maximum number of tokens to generate for classification.
        do_sample: Whether to use sampling for generation.
        top_p: Top-p sampling parameter.
        temperature: Temperature for generation.
        early_stopping: Whether to use early stopping.
        print_step_summary: Whether to print a summary for each step.
        user_header: User header for the model.
        generation_header: Generation header for the model.
        
    Returns:
        final_answer: The generated answer.
    """
    
    
    # Define model dtype
    model_dtype = torch.float32
    if dtype == "float16":
        model_dtype = torch.float16
    elif dtype == "float32":
        model_dtype = torch.float32
    elif dtype == "bfloat16":
        model_dtype = torch.bfloat16
    
    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer_rl = AutoTokenizer.from_pretrained(model_id_rl)
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    tokenizer_rl.pad_token = tokenizer_rl.eos_token
    # quant_config = BitsAndBytesConfig(
    #     load_in_4bit=True,
    #     bnb_4bit_compute_dtype=torch.float16,  # could also try bfloat16
    #     bnb_4bit_use_double_quant=True,
    #     bnb_4bit_quant_type="nf4"  # best performance for LLaMA-like models
    # )
    
    # quantization_config = BitsAndBytesConfig(
    #     load_in_8bit=True,
    # )

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        attn_implementation=attn_implementation,
        torch_dtype=model_dtype,
        # quantization_config=quantization_config
    ).eval()

    model_rl = AutoModelForCausalLM.from_pretrained(
        model_id_rl,
        device_map="auto",
        attn_implementation=attn_implementation,
        torch_dtype=model_dtype,
        # quantization_config=quantization_config
    ).eval()

    # if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    
    # Load templates
    template_extractive_summary = """
    ```
    Given the above context, extract all information relevant to the query: "{query}". If the context is not relevant to the query, answer "I don't know."
    {generation_header}
    """.strip()
    
    template_classification = """
    {user_header}
    I asked an LLM assistant whether a piece of document is related to the query: "{query}". This is its answer: 
    ```text
    {answer}
    ```
    Should I save it for later? 
    Here are rules:
    - Answer YES if the answer contains information about the query. 
    - Answer NO if the answer says the piece isn't related to the query.

    Provide the answer in the format: <YES/NO>#<Explanation>. 
    Here is are example answers:

    YES#Yes, the information contains an excerpt from a book that is related to the question.
    NO#No, the LLM assistant concluded the information isn't relevant.

    Don't add any other comments, all your remarks should be included in the "Explanation" section.
    {generation_header}
    """.strip()
    
    template_system_message = """
    {user_header}
    ```
    """.strip()
    
    template_final_answer = """
    ```
    {margins}
    {query}
    {generation_header}
    """.strip()
    
    # Replace special tokens
    special_tokens = {
        "{user_header}": "<|start_header_id|>user<|end_header_id|>\n\n",
        "{generation_header}": "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
    }
    
    for token, replacement in special_tokens.items():
        template_extractive_summary = template_extractive_summary.replace(token, replacement)
        template_classification = template_classification.replace(token, replacement)
        template_system_message = template_system_message.replace(token, replacement)
        template_final_answer = template_final_answer.replace(token, replacement)
    
    # Create WIM inference
    # wim_inference = WIMInference(model, tokenizer)


    # Create WIM RL inference for training RL model
    wim_inference = WIMRLInference(model, tokenizer)
    
    # Create RL margin generator if needed
    rl_margin_generator = None
    if use_rl_generator or train_rl_generator:
        print("Use RL-based margin generator")
        rl_margin_generator = RLMarginGenerator(
            model_id=model_id_rl,
            reward_model_id=model_id_rl,  # Use same model for rewards
            device="cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available else "cpu",
        )
    else:
        print("Use standard WIM approach")
    # Train RL generator if requested
    if train_rl_generator and rl_margin_generator is not None:
        print("Training RL margin generator...")
        segments = wim_inference._chunk_text_to_segments(input_document, min_tokens_segment, tokenizer)
        rl_margin_generator.train_rl_margin_generator(
            segments=segments,
            query=query,
            extractive_summary_prompt=template_extractive_summary,
            classification_prompt=template_classification,
            num_episodes=num_episodes,
            max_new_tokens=max_new_tokens_extractive_summary,
            supporting_facts="Sandra"
        )
        
        # Save the trained model if requested
        if output_model_dir is not None:
            print(f"Saving trained model to {output_model_dir}...")
            rl_margin_generator.save_model(output_model_dir)
    
    # Create WIM RL inference
    wim_rl_inference = WIMRLInference(model, tokenizer, rl_margin_generator)
    
    # Process with WIM RL
    final_answer, positive_margins = wim_rl_inference.process_with_rl_margins(
        tokenizer=tokenizer,
        context=input_document,
        query=query,
        system_message=template_system_message,
        extractive_summary_prompt=template_extractive_summary,
        classification_prompt=template_classification,
        final_answer_prompt=template_final_answer,
        min_tokens_segment=min_tokens_segment,
        max_new_tokens_extractive_summary=max_new_tokens_extractive_summary,
        max_new_tokens_final_answer = max_new_tokens_final_answer,
        max_new_tokens_classification = max_new_tokens_classification,
        do_sample = do_sample,
        top_p = top_p,
        temperature = temperature,
        early_stopping = early_stopping,
        use_rl_generator = use_rl_generator,
        print_step_summary = True
    )

    return final_answer, positive_margins

In [8]:
from huggingface_hub import login

my_token = ""
login(token=my_token)

In [9]:
# import json

# file_path = "/kaggle/input/examples/babilong_8k.json"
# with open(file_path, "r") as file:
#     data = json.load(file)
# context = data.get("context", "")

In [10]:
import json

def extract_data_from_json(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    extracted_samples = []

    for sample in data:
        question = sample.get('question')
        supporting_facts = sample.get('supporting_facts', [])
        context = sample.get('context', [])

        # Convert each context article into a single line string.
        # We assume that each item in context is a tuple or list where the second element contains the text as a list of strings.
        single_line_context_list = [' '.join(article[1]) for article in context]
        # Join all context strings into one full segment.
        segment = ' '.join(single_line_context_list)
        
        # Convert each supporting fact (typically a list) into a string, and then join them into one sentence.
        single_line_supporting_facts = ' '.join([' '.join(map(str, fact)) for fact in supporting_facts])

        # Build the output with keys 'query', 'segment', and 'supporting_facts'
        extracted_samples.append({
            'query': question,
            'segment': segment,
            'supporting_facts': single_line_supporting_facts
        })

    return extracted_samples

training_data = extract_data_from_json('/kaggle/input/wim-data/hotpot_dev_distractor_v1.json')


In [11]:
# i=0
# for data in training_data[:3]:
#     i+=1
#     print("datapoint", i)
#     print("query:", data['query'])
#     print("Supporting Facts:", data['supporting_facts'])
#     print("Segment:", data['segment'])

In [12]:
# Load templates
template_extractive_summary = """
```
Given the above context, extract all information relevant to the query: "{query}". If the context is not relevant to the query, answer "I don't know."
{generation_header}
""".strip()

template_classification = """
{user_header}
I asked an LLM assistant whether a piece of document is related to the query: "{query}". This is its answer: 
```text
{answer}
```
Should I save it for later? 
Here are rules:
- Answer YES if the answer contains information about the query. 
- Answer NO if the answer says the piece isn't related to the query.

Provide the answer in the format: <YES/NO>#<Explanation>. 
Here is are example answers:

YES#Yes, the information contains an excerpt from a book that is related to the question.
NO#No, the LLM assistant concluded the information isn't relevant.

Don't add any other comments, all your remarks should be included in the "Explanation" section.
{generation_header}
""".strip()

template_system_message = """
{user_header}
```
""".strip()

template_final_answer = """
```
{margins}
{query}
{generation_header}
""".strip()

# Replace special tokens
special_tokens = {
    "{user_header}": "<|start_header_id|>user<|end_header_id|>\n\n",
    "{generation_header}": "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
}

for token, replacement in special_tokens.items():
    template_extractive_summary = template_extractive_summary.replace(token, replacement)
    template_classification = template_classification.replace(token, replacement)
    template_system_message = template_system_message.replace(token, replacement)
    template_final_answer = template_final_answer.replace(token, replacement)

In [13]:


model_id = "h2oai/h2o-danube3-500m-base"  # or any compatible model identifier
model_id_rl = 'h2oai/h2o-danube3-500m-base'

rl_margin_generator = RLMarginGenerator(
            model_id=model_id_rl,
            reward_model_id=model_id_rl,  # Use same model for rewards
            device="cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available else "cpu",
        )
# Load tokenizer and model
model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        # quantization_config=quantization_config
    )
tokenizer = AutoTokenizer.from_pretrained(model_id)
wim_inference = WIMRLInference(model, tokenizer, rl_margin_generator)
for i in range(10):
    query = training_data[i]['query']
    context = training_data[i]['segment']
    supporting_facts = training_data[i]['supporting_facts']
    print("query", query)
    print("input segment", context)
    print(type(context), "contex type")
    rl_margin_generator.train_rl_margin_generator(
        segments = context,
        query = query,
        extractive_summary_prompt = template_extractive_summary,
        classification_prompt = template_classification,
        num_episodes = 1,
        max_new_tokens = 50,
        supporting_facts = supporting_facts,
    )


tokenizer_config.json:   0%|          | 0.00/968 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.79M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

policy model h2oai/h2o-danube3-500m-base


config.json:   0%|          | 0.00/685 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.03G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/132 [00:00<?, ?B/s]

ref model h2oai/h2o-danube3-500m-base
query Were Scott Derrickson and Ed Wood of the same nationality?
input segment Ed Wood is a 1994 American biographical period comedy-drama film directed and produced by Tim Burton, and starring Johnny Depp as cult filmmaker Ed Wood.  The film concerns the period in Wood's life when he made his best-known films as well as his relationship with actor Bela Lugosi, played by Martin Landau.  Sarah Jessica Parker, Patricia Arquette, Jeffrey Jones, Lisa Marie, and Bill Murray are among the supporting cast. Scott Derrickson (born July 16, 1966) is an American director, screenwriter and producer.  He lives in Los Angeles, California.  He is best known for directing horror films such as "Sinister", "The Exorcism of Emily Rose", and "Deliver Us From Evil", as well as the 2016 Marvel Cinematic Universe installment, "Doctor Strange." Woodson is a census-designated place (CDP) in Pulaski County, Arkansas, in the United States.  Its population was 403 at the 2010

Generating margins: 100%|██████████| 2/2 [00:03<00:00,  1.65s/it]


Segment: Ed Wood is a 1994 American biographical period comedy-drama film directed and produced by Tim Burton, and starring Johnny Depp as cult filmmaker Ed Wood. The film concerns the period in Wood's life when he made his best-known films as well as his relationship with actor Bela Lugosi, played by Martin Landau. Sarah Jessica Parker, Patricia Arquette, Jeffrey Jones, Lisa Marie, and Bill Murray are among the supporting cast. Scott Derrickson (born July 16, 1966) is an American director, screenwriter and producer. He lives in Los Angeles, California. He is best known for directing horror films such as "Sinister", "The Exorcism of Emily Rose", and "Deliver Us From Evil", as well as the 2016 Marvel Cinematic Universe installment, "Doctor Strange." Woodson is a census-designated place (CDP) in Pulaski County, Arkansas, in the United States. Its population was 403 at the 2010 census. It is part of the Little Rock–North Little Rock–Conway Metropolitan Statistical Area. Woodson and its ac

Generating margins: 100%|██████████| 2/2 [00:03<00:00,  1.57s/it]


Segment: Meet Corliss Archer, a program from radio's Golden Age, ran from January 7, 1943 to September 30, 1956. Although it was CBS's answer to NBC's popular "A Date with Judy", it was also broadcast by NBC in 1948 as a summer replacement for "The Bob Hope Show". From October 3, 1952 to June 26, 1953, it aired on ABC, finally returning to CBS. Despite the program's long run, fewer than 24 episodes are known to exist. Shirley Temple Black (April 23, 1928 – February 10, 2014) was an American actress, singer, dancer, businesswoman, and diplomat who was Hollywood's number one box-office draw as a child actress from 1935 to 1938. As an adult, she was named United States ambassador to Ghana and to Czechoslovakia and also served as Chief of Protocol of the United States. Janet Marie Waldo (February 4, 1920 – June 12, 2016) was an American radio and voice actress. She is best known in animation for voicing Judy Jetson, Nancy in "Shazzan", Penelope Pitstop, and Josie in "Josie and the Pussycat

Generating margins: 100%|██████████| 2/2 [00:03<00:00,  1.68s/it]


Segment: The Andre Norton Award for Young Adult Science Fiction and Fantasy is an annual award presented by the Science Fiction and Fantasy Writers of America (SFWA) to the author of the best young adult or middle grade science fiction or fantasy book published in the United States in the preceding year. It is named to honor prolific science fiction and fantasy author Andre Norton (1912–2005), and it was established by then SFWA president Catherine Asaro and the SFWA Young Adult Fiction committee and announced on February 20, 2005. Any published young adult or middle grade science fiction or fantasy novel is eligible for the prize, including graphic novels. There is no limit on word count. The award is presented along with the Nebula Awards and follows the same rules for nominations and voting; as the awards are separate, works may be simultaneously nominated for both the Andre Norton award and a Nebula Award. Victoria Hanley is an American young adult fantasy novelist. Her first three

Generating margins: 100%|██████████| 2/2 [00:03<00:00,  1.59s/it]


Segment: Esma Sultan (21 March 1873 – 7 May 1899) was an Ottoman princess, the daughter of Sultan Abdülaziz and his wife Gevheri Kadın, herself the daughter of Salih Bey Svatnba. She was the half-sister of Abdülmecid II, the last Caliph of the Muslim world. The Great Mosque of Algiers (Arabic: الجامع الكبير‎ ‎ , "Jemaa Kebir") or “Djama’a al-Kebir” (meaning Great Mosque) is a mosque in Algiers, Algeria, located very close to Algiers Harbor. An inscription on the minbar (منبر) or the pulpit testifies to fact that the mosque was built in 1097. It is also known by several other names such as Grand Mosque d'Alger, Djamaa al-Kebir, El Kebir Mosque and Jami Masjid. It is one of the few remaining examples of Almoravid architecture. It is the oldest mosque in Algiers and is said to be the oldest mosque in Algeria after Sidi Okba Mosque. It was built under sultan Ali ibn Yusuf. Its minaret dates from 1332 (1324 in some sources) and was built by the Ziyyanid Sultan of Tlemcen. The gallery at the

Generating margins: 100%|██████████| 2/2 [00:03<00:00,  1.55s/it]


Segment: Just Another Romantic Wrestling Comedy is a 2006 film starring April Hunter and Joanie Laurer. This Romantic comedy film was premiered at New Jersey and New York City on December 1, 2006 and was released on DVD in the United States and the United Kingdom on April 17, 2007. After the film's DVD release "Just Another Romantic Wrestling Comedy" won an "Honorable Mention" award at the New Jersey International Festival awards. The release is being handled by "Victory Multimedia". Kingston Morning is Dave Eggar's 4th solo release recorded in Brooklyn, New York; Kingston, Jamaica; and Big Stone Gap, Virginia; and released by Domo Records. "Itsbynne Reel" was nominated at the 53rd Grammy Awards for "Best Instrumental Arrangement". Nola is a 2003 American romantic comedy film written and directed by Alan Hruska. It depicts the struggle of a young woman trying to survive in New York City while looking for her birth father. It premiered in New York City on July 23, 2004. Adriana Trigiani

Generating margins: 100%|██████████| 2/2 [00:03<00:00,  1.59s/it]


Segment: South Korean boy group Shinee have received several awards and nominations for their music work. The group was formed by S.M. Entertainment in 2008 and released their first full-length album, "The Shinee World", on August 28, 2008, which won the Newcomer Album of the Year at the 23rd Golden Disk Awards. The first single released from the album was "Sanso Gateun Neo (Love Like Oxygen)" and won first place on "M Countdown" on September 18, 2008 making it the group's first win on Korean music shows since debut. Their second album "Lucifer" (2010) produced two singles, "Lucifer" and "Hello". For their outstanding choreography the group was nominated for the Best Dance Performance Award at the Mnet Asian Music Awards in 2010. " Lucifer" also won the Disk Bonsang Award at the 25th Golden Disk Awards as well as the Popularity Award. On March 21, 2012 the group released their fourth EP "Sherlock" for which the group was awarded another Disk Bonsang Award at the 27th Golden Disc Awards

Generating margins: 100%|██████████| 2/2 [00:03<00:00,  1.70s/it]


Segment: James P. Comer (born James Pierpont Comer, September 25, 1934 in East Chicago, Indiana) is currently the Maurice Falk Professor of Child Psychiatry at the Yale Child Study Center and has been since 1976. He is also an associate dean at the Yale School of Medicine. As one of the world's leading child psychiatrists, he is best known for his efforts to improve the scholastic performance of children from lower-income and minority backgrounds which led to the founding of the Comer School Development Program in 1968. His program has been used in more than 600 schools in eighty-two school districts. He is the author of ten books, including the autobiographical "Maggie’s American Dream: The Life and Times of a Black Family", 1988; "Leave No Child Behind: Preparing Today's Youth for Tomorrow's World", 2004; and his most recent book, "What I Learned in School: Reflections on Race, Child Development, and School Reform", 2009. He has also written more than 150 articles for Parents (magazi

Generating margins: 100%|██████████| 4/4 [00:06<00:00,  1.61s/it]


Segment: The Billings Bulls were a junior ice hockey organization based in Billings, Montana. They most recently played home games at the 550-seat Centennial Ice Arena and due to the arena's small size, the Bulls frequently sold out games. They previously played their home games in the Metrapark which had a max capacity of 9,000 for hockey games. However, a negotiating dispute with arena officials and local county commissioners resulted in the team losing its lease. The Robins Center is a 7,201-seat multi-purpose arena in Richmond, Virginia. Opened in 1972, the arena is home to the University of Richmond Spiders basketball. It hosted the ECAC South (now known as the Colonial Athletic Association) men's basketball tournament in 1983. It is named for E. Claiborne Robins Sr, class of 1931, who, along with his family, have been leading benefactors for the school. The opening of the Robins Center returning Spider basketball to an on-campus facility for the first time since the mid-1940s whe

Generating margins: 100%|██████████| 2/2 [00:03<00:00,  1.63s/it]


Segment: Annie Morton (born October 8, 1970) is an American model born in Pennsylvania. She has appeared on the covers of "British Vogue", "ID", "Marie Claire", and other magazines. She has been photographed by Helmut Newton; Peter Lindbergh; Annie Leibovitz; Richard Avedon; Juergen Teller; Paul Jasmin, Mary Ellen Mark and Terry Richardson, and modeled for Donna Karan, Givenchy, Guerlain, Chanel, "Harper's Bazaar", "Sports Illustrated" and Victoria's Secret. A long time vegetarian, an advocate for organic lifestyle choices and natural healthcare. She co-founded Tsi-La Organics, a "Green Luxury" company that creates and sells vegan, organic perfume and skin care products. Madonna is a biography by English author Andrew Morton, chronicling the life of American recording artist Madonna. The book was released in November 2001 by St. Martin's Press in the United States and in April 2002 by Michael O'Mara Books in the United Kingdom. Morton decided to write a biography on Madonna in 2000. Th

Generating margins: 100%|██████████| 2/2 [00:03<00:00,  1.58s/it]


Segment: Mendocino County, California, was the first jurisdiction in the United States to ban the cultivation, production or distribution of genetically modified organisms (GMOs). The ordinance, entitled Measure H, was passed by referendum on March 2, 2004. Initiated by the group "GMO Free Mendocino", the campaign was a highly publicized grassroots effort by local farmers and environmental groups who contend that the potential risks of GMOs to human health and the ecosystem have not yet been fully understood. The measure was met with opposition by several interest groups representing the biotechnology industry, The California Plant Health Association (now the Western Plant Health Association) and CropLife America, a Washington-based consortium whose clients represent some of the largest food distributors in the nation, including Monsanto, DuPont and Dow Chemical. Since the enactment of the ordinance, Mendocino County has been added to an international list of "GMO free zones." Pre-empt

In [14]:

# # Define your model and input parameters
# # model_id = "HachiML/TinyLlama2-jp-122M-FlashAttention2"  # or any compatible model identifier
# model_id = "h2oai/h2o-danube3-500m-base"  # or any compatible model identifier
# model_id_rl = 'h2oai/h2o-danube3-500m-base'
# # input_document = "Your long document text goes here..."
# query = "Where is Sandra?"


# # Set parameters for margin generation and RL training
# final_answer, positive_margins = run_wim_rl(
#     model_id=model_id,
#     model_id_rl=model_id_rl,
#     input_document=context,
#     query=query,
#     use_rl_generator=True,      # Use the RL-enhanced margin generator
#     train_rl_generator=True,   # Set True if you want to train the RL generator
#     num_episodes=5             # Number of training episodes (only used if training)
# )

# print("Final Answer:")
# print(final_answer)
# print("\nPositive Margins:")
# for margin in positive_margins:
#     print(margin)

In [15]:

# model_id = "shubvhamgore18218/WiM_llama_full_dataset"  # or any compatible model identifier
# model_id_rl = 'shubvhamgore18218/WiM_llama_full_dataset'

# rl_margin_generator = RLMarginGenerator(
#             model_id=model_id_rl,
#             reward_model_id=model_id_rl,  # Use same model for rewards
#             device="cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available else "cpu",
#         )
# for i in range(len(traning_data)):
#     query = traning_data[i]['question']
#     context = traning_data[i]['context']
#     supporting_facts = traning_data[i]['supporting_facts']

#     rl_margin_generator.train_rl_margin_generator(
#         segments = context,
#         query = query,
#         extractive_summary_prompt = template_extractive_summary,
#         classification_prompt = template_classification,
#         num_episodes = 10,
#         max_new_tokens = 100,
#         supporting_facts = supporting_facts,
#     )

In [16]:
print("donme")

donme
