In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
import numpy as np
from dataclasses import dataclass
from typing import List, Dict, Any, Optional
import copy
from torch.distributions import Categorical, kl_divergence


In [2]:
import torch

# Check if CUDA is available
if torch.cuda.is_available():
    # Get the GPU name
    gpu_name = torch.cuda.get_device_name(0)
    # Get GPU properties, including total memory (in bytes)
    gpu_properties = torch.cuda.get_device_properties(0)
    total_memory_gb = gpu_properties.total_memory / (1024 ** 3)
    
    print("GPU detected:", gpu_name)
    print(f"Total GPU Memory: {total_memory_gb:.2f} GB")
else:
    print("CUDA is not available. Running on CPU.")

GPU detected: Tesla P100-PCIE-16GB
Total GPU Memory: 15.89 GB


In [3]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, Cache
from dataclasses import dataclass
from typing import List, Dict, Any, Callable
# import flash_attn

def _sample_top_p(logits, top_p=0.9):
    # First normalize the logits to prevent overflow/underflow
    logits = logits - torch.max(logits)
    
    # Convert to probabilities with softmax
    probs = torch.nn.functional.softmax(logits, dim=-1)
    
    # Apply a small epsilon to avoid numerical issues
    eps = 1e-10
    probs = torch.clamp(probs, min=eps)
    probs = probs / probs.sum()  # Re-normalize
    
    # Sort the probabilities
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    
    # Compute cumulative probabilities
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    
    # Remove tokens with cumulative probability above the threshold
    sorted_indices_to_remove = cumulative_probs > top_p
    # Keep the first token above threshold
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    
    # Scatter back to original indices
    indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
    probs = probs.masked_fill(indices_to_remove, 0.0)
    
    # Re-normalize after masking
    probs = probs / (probs.sum() + eps)
    
    # Check for invalid values before sampling
    if torch.isnan(probs).any() or torch.isinf(probs).any() or (probs < 0).any():
        # Fix invalid values
        probs = torch.nan_to_num(probs, nan=eps, posinf=1.0, neginf=eps)
        probs = torch.clamp(probs, min=eps)
        probs = probs / probs.sum()  # Re-normalize
    
    # Sample from the filtered distribution
    next_token = torch.multinomial(probs, num_samples=1)
    
    return next_token

class WIMInference:

    def __init__(self, model, tokenizer) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.wim_kv_cache = DynamicCache()
        self.classifier_kv_cache = DynamicCache()

    def _prefill_tokens(self,input_ids: torch.Tensor,attention_mask: torch.Tensor,cache_positions: torch.Tensor,kv_cache: Cache,):
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                cache_position=cache_positions,
                use_cache=True,
                past_key_values=kv_cache,
            )
        # print(outputs)
        return outputs

    def shrink_kv_cache_from_end(self, new_size: int, kv_cache: Cache):

        def resize_tensor_list(token_list):
            for layer_idx in range(len(token_list)):
                token_list[layer_idx] = token_list[layer_idx][:, :, :new_size, :]

        resize_tensor_list(kv_cache.key_cache)
        resize_tensor_list(kv_cache.value_cache)
        kv_cache._seen_tokens = new_size

    def generate_text_with_kv_cache(self,max_new_tokens: int,previous_logits: torch.Tensor,do_sample: bool,top_p: float,temperature: float,early_stopping: bool,kv_cache: Cache,) -> str:
        generated_tokens = []

        # This is needed to create the cache_position tensor
        next_token_pos = kv_cache.get_seq_length()

        # Use the logits from the prefilling to generate the first token
        logits = previous_logits

        for _ in range(max_new_tokens):
            # Select the last token from the logits
            next_token_logits = logits[:, -1, :]
            if do_sample:
                # Divide the logits by the temperature
                next_token_logits = next_token_logits / temperature
                # Apply the softmax
                next_token_probs = torch.nn.functional.softmax(
                    next_token_logits, dim=-1
                )
                
                # Check for invalid values (moved here after next_token_probs is defined)
                if torch.isnan(next_token_probs).any() or torch.isinf(next_token_probs).any():
                    print("Invalid probabilities detected!")
                    print(f"Min prob: {next_token_probs.min()}, Max prob: {next_token_probs.max()}")
                    print(f"Contains NaN: {torch.isnan(next_token_probs).any()}")
                    print(f"Contains Inf: {torch.isinf(next_token_probs).any()}")
                    
                next_token = _sample_top_p(next_token_logits, top_p)  # Note: passing logits, not probs
            else:
                # Select the token with the highest probability
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            assert next_token.size() == (1, 1)
            # Remove the batch dimension
            next_token = next_token.squeeze(0)
            generated_tokens.append(next_token)
            # Stop if we reached the EOS token
            if next_token.item() == self.tokenizer.eos_token_id and early_stopping:
                break
            # Use the generated token as input for the next step
            generation_input_ids = next_token.unsqueeze(-1)
            kv_cache_seq_len = kv_cache.get_seq_length()
            generation_attention_mask = torch.ones(
                (1, kv_cache_seq_len + 1), device=next_token.device, dtype=torch.long
            )
            generation_cache_position = torch.tensor(
                [next_token_pos], device=next_token.device
            )

            with torch.no_grad():
                # Get the model outputs
                outputs = self.model(
                    input_ids=generation_input_ids,
                    attention_mask=generation_attention_mask,
                    cache_position=generation_cache_position,
                    use_cache=True,
                    past_key_values=kv_cache,
                )
            logits = outputs.logits
            next_token_pos += 1

        generated_tokens = torch.cat(generated_tokens, dim=-1)
        # Decode the generated tokens
        decoded = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
        return decoded

    def prefill_text_with_kv_cache(self, text: str, kv_cache: Cache):
        # Tokenize the text
        inputs = self.tokenizer(text, return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.model.device)
        seq_len = input_ids.size(1)
        attention_mask = inputs["attention_mask"].to(self.model.device)

        # If we have a KV-Cache, we need to extend the attention mask to account for tokens already in the KV-Cache
        if kv_cache.get_seq_length() > 0:
            kv_cache_seq_len = kv_cache.get_seq_length()
            attention_mask = torch.cat(
                [
                    torch.ones(
                        attention_mask.shape[0],
                        kv_cache_seq_len,
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    ),
                    attention_mask,
                ],
                dim=1,
            )

        # Generate the cache positions for the tokens to be prefilled
        cache_positions = torch.arange(
            kv_cache.get_seq_length(), kv_cache.get_seq_length() + seq_len
        ).to(self.model.device)
        outputs = self._prefill_tokens(input_ids, attention_mask, cache_positions, kv_cache)
        return kv_cache.get_seq_length(), seq_len, outputs
    


In [4]:
# def compute_reward(margins, query, supporting_facts, device):
#     rewards = []
#     for margin in margins:
#         coherence_score = min(0.3, 0.3 * sum(1 for word in margin.split() if len(word)> 2) / max(1, len(margin.split())))
#         relevance_score = 0.0
#         query_terms = set(query.lower().split())
#         margin_words = set(margin.lower().split())
#         supporting_facts_terms = set(supporting_facts.lower().split())

#         term_counts = {}
#         for term in query_terms:
#             term_counts[term] = sum(1 for word in margin_words if term in word)
        
#         if sum(term_counts.values()) > 0:
#             relevance_score = min(0.4, sum(term_counts.values()) / (len(query_terms) * 2) * 0.4)

#         # Length penalty to discourage extremely short responses
#         length_penalty = max(0.0, min(0.2, (len(margin) / 200) * 0.2))
        
#         # Base reward can be smaller to enhance differentiation
#         base_reward = 0.1

#         contains_supporting_facts = any(fact.lower() in margin.lower() for fact in supporting_facts)
#         penalty = -0.5 if not contains_supporting_facts else 1
        
#         reward = base_reward + coherence_score + relevance_score + length_penalty + penalty
#         rewards.append(reward)
#         return torch.tensor(rewards, device=device)
    

import re
import torch

def compute_reward(margins, query, supporting_facts, device):
    """
    margins:      List[str]   — the generated margins to score
    query:        str         — the query string
    supporting_facts: List[str] — list of ground-truth supporting facts
    device:       torch.device

    Returns: FloatTensor of shape (len(margins),)
    """
    rewards = []
    # Pre-tokenize once
    query_terms = set(re.findall(r'\w+', query.lower()))
    # Flatten all supporting_facts into a token set
    sf_terms = set()
    for fact in supporting_facts:
        sf_terms |= set(re.findall(r'\w+', fact.lower()))
    num_query_terms = max(1, len(query_terms))
    num_sf_terms    = max(1, len(sf_terms))

    for margin in margins:
        # tokenize margin
        words = re.findall(r'\w+', margin.lower())
        num_words = len(words)

        # 1) Coherence: fraction of “real” words (length>2), capped at 0.3
        if num_words > 0:
            lam = sum(1 for w in words if len(w) > 2) / num_words
        else:
            lam = 0.0
        coherence_score = 0.3 * lam

        # 2) Query relevance: fraction of query terms present, capped at 0.3
        present_q = sum(1 for t in query_terms if t in words)
        relevance_score = 0.3 * min(1.0, present_q / num_query_terms)

        # 3) Supporting‐fact coverage: fraction of SF tokens present, capped at 0.3
        present_sf = sum(1 for t in sf_terms if t in words)
        sf_score   = 0.3 * min(1.0, present_sf / num_sf_terms)

        # 4) Length bonus: longer margins up to 200 tokens get up to 0.1
        length_bonus = min(0.1, 0.1 * num_words / 200)

        # 5) Base floor
        base_reward = 0.05

        # Final reward in [0.05, 1.0]
        reward = base_reward \
               + coherence_score \
               + relevance_score \
               + sf_score \
               + length_bonus

        rewards.append(reward)

    return torch.tensor(rewards, device=device, dtype=torch.float32)


In [5]:
import nltk
from nltk.tokenize import sent_tokenize

# Optionally download the 'punkt' tokenizer data if not done before
nltk.download('punkt')

def chunk_text_to_segments(text, min_tokens_segment, tokenizer=None):
        """Chunk text into segments of approximately min_tokens_segment tokens."""
        
        if tokenizer is None:
            raise ValueError("Tokenizer must be provided to _chunk_text_to_segments")

        tokenizer = tokenizer
        segments = []
        current_segment = ""
        sentences = sent_tokenize(text)
        curr_tokens = 0
        
        for sentence in sentences:
            sent_tokens = len(tokenizer.encode(sentence))
            
            # if this sentence would overflow the current segment:
            if curr_tokens + sent_tokens > min_tokens_segment:
                # only save a non-empty segment
                if current_segment.strip():
                    segments.append(current_segment.strip())
                # start a fresh segment with this sentence
                current_segment = sentence + " "
                curr_tokens = sent_tokens
            else:
                # otherwise, keep accumulating
                current_segment += sentence + " "
                curr_tokens += sent_tokens
        
        # catch the final segment
        if current_segment.strip():
            segments.append(current_segment.strip())
        
        return segments

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [6]:
@dataclass
class RLConfig:
    """Configuration for the RL-based margin generation."""
    learning_rate: float = 5e-5
    kl_coef: float = 0.05
    discount_factor: float = 0.99
    ppo_epochs: int = 5
    ppo_mini_batch_size: int = 4
    max_grad_norm: float = 1
    clip_param: float = 0.1
    value_loss_coef: float = 0.5
    entropy_coef: float = 0.03

In [16]:
class RLMargin_Generator:
    def __init__(self, model_id, rl_config: RLConfig, device='cuda', tokenizer=None):
        self.device = device
        self.policy_model = model_id
        self.ref_model = model_id
        self.tokenizer = tokenizer
        self.avg_reward_computation = 0
        self.count = 0
        self.counter = 0
        # self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        # self.policy_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto",attn_implementation='eager', torch_dtype='bfloat16')
        # self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto",attn_implementation='eager', torch_dtype='bfloat16')
        self.rl_config = rl_config

        for param in self.ref_model.parameters():
            param.requires_grad = False

        for param in self.policy_model.parameters():
            param.requires_grad = True

        self.optimizer = torch.optim.Adam(
            self.policy_model.parameters(),
            lr = self.rl_config.learning_rate,
        )
        if torch.isnan(torch.nn.utils.clip_grad_norm_(
            self.policy_model.parameters(),
            self.rl_config.max_grad_norm
        )):
            print("Warning: Nan gradiesnts detected")

        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='max', factor=0.5, patience=2, verbose=2
        )
        
        # Initialize the WIM inference
        self.wim = WIMInference(self.policy_model, self.tokenizer)

        self.best_avg_reward = 0
        self.no_improvement_count = 0
        self.patience = 3
        self.total_steps = 0

    def _compute_kl_divergence(self, input_ids, attention_mask, labels):
        """Compute KL divergence between policy and reference model with better numerical stability."""
        # Small constant to avoid division by zero or log(0)
        batch_size, seq_len = input_ids.size()
        vocab_size = self.policy_model.config.vocab_size
        chunk_size = 128

        kl_divs = []
        for start in range(0, seq_len - 1, chunk_size):
            end = min(start + chunk_size, seq_len - 1)

            # +1 on slice for next-token prediction
            chunk_ids   = input_ids[:, start : end+1]
            chunk_mask  = attention_mask[:, start : end+1]

            # New policy logits
            policy_logits = self.policy_model(
                input_ids=chunk_ids,
                attention_mask=chunk_mask
            ).logits[:, :-1]                   # drop last prediction
            # Old policy logits (frozen)
            with torch.no_grad():
                ref_logits = self.ref_model(
                    input_ids=chunk_ids,
                    attention_mask=chunk_mask
                ).logits[:, :-1]

            # Flatten to (batch*chunk_len, vocab)
            flat_pol = policy_logits.reshape(-1, vocab_size)
            flat_ref = ref_logits.reshape(-1, vocab_size)

            # Compute KL per token in a numerically stable way
            dist_new = Categorical(logits=flat_pol)
            dist_old = Categorical(logits=flat_ref)
            kl_flat  = kl_divergence(dist_new, dist_old)    # shape: (batch*chunk_len,)

            # Reshape back to (batch, chunk_len)
            chunk_len = end - start
            kl_tokens = kl_flat.reshape(batch_size, chunk_len)

            # Mask out non-generated tokens
            # labels are offset by +1 compared to logits
            chunk_labels = labels[:, start+1 : end+1]      # (batch, chunk_len)
            valid_mask   = (chunk_labels != -100).float()

            # Sum & average only over valid positions
            sum_kl   = (kl_tokens * valid_mask).sum(dim=1)
            count    = valid_mask.sum(dim=1)
            # avoid div-by-zero
            count    = torch.where(count > 0, count, torch.ones_like(count))
            chunk_kl = sum_kl / count

            kl_divs.append(chunk_kl)

        # Final per-example KL is the mean over all chunks
        if kl_divs:
            kl_div = torch.stack(kl_divs, dim=1).mean(dim=1)  # (batch,)
        else:
            kl_div = torch.zeros(batch_size, device=self.device)

        return kl_div

    def _validate_and_clean_margin(self, margin: str) -> str:
        """Validate and clean the generated margin."""
        # Remove excessive strange characters
        strange_chars = "�'`@#$%^*"
        for char in strange_chars:
            margin = margin.replace(char, "")
        
        # Trim to first sentence if too long
        if len(margin) > 200 and "." in margin:
            first_part = margin.split(".")[0] + "."
            if len(first_part) > 50:  # Ensure we have a meaningful sentence
                margin = first_part
        
        # Ensure minimum length
        if len(margin.strip()) < 10:
            margin = "The text does not contain sufficient information related to the query."
            
        return margin
    
    def save_model_output(self, output_dir: str):
        """Save the trained model."""
        self.policy_model.save_pretrained(output_dir)
        self.tokenizer.save_pretrained(output_dir)
    
    def generate_rl_margin(
        self,
        segment,
        query,
        extractive_summary,
        supporting_facts,
        max_new_tokens=50,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
        early_stopping=True,
        remove_segment=True,
        train=False
    ):
        if train:
            segments = chunk_text_to_segments(text = segment, min_tokens_segment=512, tokenizer=self.tokenizer)
            for segment in segments:

                initial_kv_cache_size = len(self.wim.wim_kv_cache)

                prefilled_tokens_before_extractive_summary, _, _ = self.wim.prefill_text_with_kv_cache(
                    segment, self.wim.wim_kv_cache
                )
                _, _, extractive_summary_outputs = self.wim.prefill_text_with_kv_cache(
                    extractive_summary.format(query=query), self.wim.wim_kv_cache
                )
                margin = self.wim.generate_text_with_kv_cache(
                    max_new_tokens=max_new_tokens,
                    previous_logits=extractive_summary_outputs["logits"],
                    do_sample=do_sample,
                    top_p=top_p,
                    temperature=temperature,
                    early_stopping=early_stopping,
                    kv_cache=self.wim.wim_kv_cache,
                )
                margin = self._validate_and_clean_margin(margin)
                # Shrink the KV cache back to before the extractive summary prompt
                
                if(remove_segment):
                    self.wim.shrink_kv_cache_from_end(new_size=initial_kv_cache_size,kv_cache=self.wim.wim_kv_cache)
                else:
                    self.wim.shrink_kv_cache_from_end(new_size=prefilled_tokens_before_extractive_summary,kv_cache=self.wim.wim_kv_cache,)


                return margin
        else:
            initial_kv_cache_size = len(self.wim.wim_kv_cache)

            prefilled_tokens_before_extractive_summary, _, _ = self.wim.prefill_text_with_kv_cache(
                segment, self.wim.wim_kv_cache
            )
            _, _, extractive_summary_outputs = self.wim.prefill_text_with_kv_cache(
                extractive_summary.format(query=query), self.wim.wim_kv_cache
            )
            margin = self.wim.generate_text_with_kv_cache(
                max_new_tokens=max_new_tokens,
                previous_logits=extractive_summary_outputs["logits"],
                do_sample=do_sample,
                top_p=top_p,
                temperature=temperature,
                early_stopping=early_stopping,
                kv_cache=self.wim.wim_kv_cache,
            )
            margin = self._validate_and_clean_margin(margin)
            # Shrink the KV cache back to before the extractive summary prompt
            
            if(remove_segment):
                self.wim.shrink_kv_cache_from_end(new_size=initial_kv_cache_size,kv_cache=self.wim.wim_kv_cache)
            else:
                self.wim.shrink_kv_cache_from_end(new_size=prefilled_tokens_before_extractive_summary,kv_cache=self.wim.wim_kv_cache,)

            # reward = compute_reward(margin, query, supporting_facts, self.device)
            return margin


    def train(
            self,
            segments, 
            query, 
            supporting_facts,
            extractive_summary_prompt, 
            num_episodes, 
            max_new_tokens, 
            min_tokens_segment
        ):
        self.best_avg_reward = -float('inf')
        self.no_improvement_count = 0
        self.total_steps = 0

        self.policy_model.train()
        for episode in range(num_episodes):
            # if needed to change after certain episodes
            # if(episodes % 10 == 0):
            self.ref_model = copy.deepcopy(self.policy_model)
            print(f"\n--- Episode {episode+1}/{num_episodes} ---")
            # Sample a batch of segments
            chunked_segments = chunk_text_to_segments(text = segments, min_tokens_segment=min_tokens_segment, tokenizer=self.tokenizer)
            # print(f"Processing {len(chunked_segments)} segments...")
            # Generate margins using the current policy
            margins = []
            # is_relevant_list = []
            
            for segment in tqdm(chunked_segments, desc="Generating margins"):

                self.wim.shrink_kv_cache_from_end(new_size=0, kv_cache=self.wim.wim_kv_cache)

                margin = self.generate_rl_margin(
                    segment=segment,
                    query=query,
                    extractive_summary=extractive_summary_prompt,
                    max_new_tokens=max_new_tokens,
                    train=True,
                    supporting_facts=None
                )
                # print('segment:  ',segment)
                # print('margin: ',margin)
                margins.append(margin) 
                # torch.cuda.empty_cache()

            with torch.no_grad():
                self.avg_reward_computation += self.avg_reward_computation
                self.count += 1
                rewards = compute_reward(margins, query, supporting_facts, self.device)
                # Ensure rewards is a tensor on the correct device
                if not isinstance(rewards, torch.Tensor):
                    rewards = torch.tensor(rewards, device=self.device, dtype=torch.float32)
                elif rewards.device != self.device:
                    rewards = rewards.to(self.device)

            inputs = []
            for segment, margin, reward in zip(chunked_segments, margins, rewards):
                print(f"Query: {query}")
                print(f"Segment: {segment}")
                print(f"Margin: {margin}")
                print(f"Reward: {reward.item()}")
                context = segment + extractive_summary_prompt.format(query=query) + margin
                input_tokens = self.tokenizer(context, return_tensors="pt", padding=False).to(self.device)
                labels = input_tokens.input_ids.clone()
                context_without_margin = segment + extractive_summary_prompt.format(query=query)
                context_tokens = len(self.tokenizer(context_without_margin, return_tensors="pt").input_ids[0])
                labels[:, :context_tokens] = -100  
                
                inputs.append({
                    "input_ids": input_tokens.input_ids,
                    "attention_mask": input_tokens.attention_mask,
                    "labels": labels,
                })

            # PPO update
            # PPO update
            for ppo_epoch in range(self.rl_config.ppo_epochs):
                self.counter += 1
                if(self.counter >1000):
                    break
                print(f"ppo epoch: {ppo_epoch}")
                total_ppo_loss = 0.0
                total_kl_div = 0.0
                num_samples_processed = 0
                for i in range(len(inputs)):
                    self.optimizer.zero_grad()
                    
                    # Get model outputs with gradients enabled
                    outputs = self.policy_model(
                        input_ids=inputs[i]["input_ids"],
                        attention_mask=inputs[i]["attention_mask"],
                        labels=inputs[i]["labels"]
                    )
                    print('outputs.loss',outputs.loss)
                    # Model loss already has gradients
                    model_loss = outputs.loss
                    
                    # Compute KL divergence
                    kl_div = self._compute_kl_divergence(
                        inputs[i]["input_ids"],
                        inputs[i]["attention_mask"],
                        inputs[i]["labels"],
                    )
                    
                    # Compute policy loss - make sure reward is detached
                    reward_term = rewards[i].detach() * -1.0  # Negative since we want to maximize reward
                    kl_term = self.rl_config.kl_coef * kl_div.detach()
                    
                    # Combine with model loss to ensure gradients flow
                    policy_loss = model_loss + reward_term + kl_term
                    
                    # Backward pass
                    policy_loss.backward()
                    
                    # Clip gradients
                    torch.nn.utils.clip_grad_norm_(
                        self.policy_model.parameters(),
                        self.rl_config.max_grad_norm,
                    )
                    
                    # Update policy model
                    self.optimizer.step()
                    total_ppo_loss += policy_loss.item()
                    total_kl_div += kl_div.item() # Use .item() for logging scalars
                    num_samples_processed += 1
                    
                if num_samples_processed > 0:
                    avg_ppo_loss = total_ppo_loss / num_samples_processed
                    avg_kl_div = total_kl_div / num_samples_processed
                    # print(f"  PPO Epoch {ppo_epoch+1}/{self.rl_config.ppo_epochs} - Avg Loss: {avg_ppo_loss:.4f}, Avg KL: {avg_kl_div:.4f}")
                # else:
                #     print(f"  PPO Epoch {ppo_epoch+1}/{self.rl_config.ppo_epochs} - No samples processed.")

            # === Logging & Control Flow ===
            # Ensure rewards tensor is on CPU for numpy conversion if needed, and calculate mean
            avg_reward = rewards.mean().item() if rewards.numel() > 0 else 0.0
            # relevance_rate = sum(is_relevant_list) / len(is_relevant_list) # If is_relevant_list was computed

            print(f"Episode {episode+1} Summary:")
            print(f"  Average reward: {avg_reward:.4f}")
            # print(f"  Average KL divergence (last epoch): {avg_kl_div:.4f}") # KL reported per PPO epoch now
            # print(f"Relevance rate: {relevance_rate:.4f}") # If computed

            # Update LR Scheduler
            if self.scheduler:
                 # Check scheduler type - some step based on loss, some on metric
                 if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                     self.scheduler.step(avg_reward)
                 else:
                     # For other schedulers like StepLR, CosineAnnealingLR, etc.
                     # Check if they need a metric or just step()
                     try:
                         self.scheduler.step()
                     except TypeError:
                         print("Warning: Scheduler might need a metric but received none.")


            # Early stopping check
            if avg_reward > self.best_avg_reward:
                print(f"  New best average reward: {avg_reward:.4f} (Improvement)")
                self.best_avg_reward = avg_reward
                self.no_improvement_count = 0
                # Save best model
                self.save_model_output(f"best_model_episode_{episode+1}") # Assuming this method exists
            else:
                self.no_improvement_count += 1
                print(f"  No improvement in average reward for {self.no_improvement_count} episode(s).")

            if self.no_improvement_count >= self.patience:
                print(f"No improvement for {self.patience} episodes. Stopping early.")
                break

            # <<< Efficiency Recommendation >>>
            # Clearing cache once per episode is much more reasonable than in inner loops.
            # Do this only if memory accumulation across episodes is observed.
            torch.cuda.empty_cache()

        print("Training finished.")

    def compute_avg_reward(self):
        return self.avg_reward_computation / self.count if self.count > 0 else 0.0




In [8]:
template_extractive_summary = """
```
Given the above context, extract all information relevant to the query: "{query}". If the context is not relevant to the query, answer "I don't know."
{generation_header}
""".strip()

template_system_message = """
{user_header}
```
""".strip()

template_final_answer = """
```
{margins}
{query}
{generation_header}
""".strip()

# Replace special tokens
special_tokens = {
    "{user_header}": "<|start_header_id|>user<|end_header_id|>\n\n",
    "{generation_header}": "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
}

for token, replacement in special_tokens.items():
    template_extractive_summary = template_extractive_summary.replace(token, replacement)
    template_system_message = template_system_message.replace(token, replacement)
    template_final_answer = template_final_answer.replace(token, replacement)

In [9]:
from huggingface_hub import login

# Paste your token here
my_token = ""
login(token=my_token)

In [26]:
import json

def extract_data_from_json(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    extracted_samples = []
    for sample in data:

        question = sample.get('question')
        answer = sample.get('answer')
        supporting_facts = sample.get('supporting_facts', [])
        context = sample.get('context', [])

        # Convert each context article into a single line string.
        # We assume that each item in context is a tuple or list where the second element contains the text as a list of strings.
        single_line_context_list = [' '.join(article[1]) for article in context]
        # Join all context strings into one full segment.
        segment = ' '.join(single_line_context_list)
        
        # Convert each supporting fact (typically a list) into a string, and then join them into one sentence.
        single_line_supporting_facts = ' '.join(fact[0] for fact in supporting_facts)
        # Build the output with keys 'query', 'segment', and 'supporting_facts'
        extracted_samples.append({
            'query': question,
            'segment': segment,
            'supporting_facts': single_line_supporting_facts,
            'answer' : answer
        })

    return extracted_samples

# training_data = extract_data_from_json('/kaggle/input/win-data/hotpot_dev_distractor_v1.json')


In [11]:
# trainaing_data = training_data[:5924]
# test_data = training_data[5925:]

In [12]:
# # training_data[0]
# {'query': 'Were Scott Derrickson and Ed Wood of the same nationality?',
#  'segment': 'Ed Wood is a 1994 American biographical period comedy-drama film directed and produced by Tim Burton, and starring Johnny Depp as cult filmmaker Ed Wood.  The film concerns the period in Wood\'s life when he made his best-known films as well as his relationship with actor Bela Lugosi, played by Martin Landau.  Sarah Jessica Parker, Patricia Arquette, Jeffrey Jones, Lisa Marie, and Bill Murray are among the supporting cast. Scott Derrickson (born July 16, 1966) is an American director, screenwriter and producer.  He lives in Los Angeles, California.  He is best known for directing horror films such as "Sinister", "The Exorcism of Emily Rose", and "Deliver Us From Evil", as well as the 2016 Marvel Cinematic Universe installment, "Doctor Strange." Woodson is a census-designated place (CDP) in Pulaski County, Arkansas, in the United States.  Its population was 403 at the 2010 census.  It is part of the Little Rock–North Little Rock–Conway Metropolitan Statistical Area.  Woodson and its accompanying Woodson Lake and Wood Hollow are the namesake for Ed Wood Sr., a prominent plantation owner, trader, and businessman at the turn of the 20th century.  Woodson is adjacent to the Wood Plantation, the largest of the plantations own by Ed Wood Sr. Tyler Bates (born June 5, 1965) is an American musician, music producer, and composer for films, television, and video games.  Much of his work is in the action and horror film genres, with films like "Dawn of the Dead, 300, Sucker Punch," and "John Wick."  He has collaborated with directors like Zack Snyder, Rob Zombie, Neil Marshall, William Friedkin, Scott Derrickson, and James Gunn.  With Gunn, he has scored every one of the director\'s films; including "Guardians of the Galaxy", which became one of the highest grossing domestic movies of 2014, and its 2017 sequel.  In addition, he is also the lead guitarist of the American rock band Marilyn Manson, and produced its albums "The Pale Emperor" and "Heaven Upside Down". Edward Davis Wood Jr. (October 10, 1924 – December 10, 1978) was an American filmmaker, actor, writer, producer, and director. Deliver Us from Evil is a 2014 American supernatural horror film directed by Scott Derrickson and produced by Jerry Bruckheimer.  The film is officially based on a 2001 non-fiction book entitled "Beware the Night" by Ralph Sarchie and Lisa Collier Cool, and its marketing campaign highlighted that it was "inspired by actual accounts".  The film stars Eric Bana, Édgar Ramírez, Sean Harris, Olivia Munn, and Joel McHale in the main roles and was released on July 2, 2014. Adam Collis is an American filmmaker and actor.  He attended the Duke University from 1986 to 1990 and the University of California, Los Angeles from 2007 to 2010.  He also studied cinema at the University of Southern California from 1991 to 1997.  Collis first work was the assistant director for the Scott Derrickson\'s short "Love in the Ruins" (1995).  In 1998, he played "Crankshaft" in Eric Koyanagi\'s "Hundred Percent". Sinister is a 2012 supernatural horror film directed by Scott Derrickson and written by Derrickson and C. Robert Cargill.  It stars Ethan Hawke as fictional true-crime writer Ellison Oswalt who discovers a box of home movies in his attic that puts his family in danger. Conrad Brooks (born Conrad Biedrzycki on January 3, 1931 in Baltimore, Maryland) is an American actor.  He moved to Hollywood, California in 1948 to pursue a career in acting.  He got his start in movies appearing in Ed Wood films such as "Plan 9 from Outer Space", "Glen or Glenda", and "Jail Bait."  He took a break from acting during the 1960s and 1970s but due to the ongoing interest in the films of Ed Wood, he reemerged in the 1980s and has become a prolific actor.  He also has since gone on to write, produce and direct several films. Doctor Strange is a 2016 American superhero film based on the Marvel Comics character of the same name, produced by Marvel Studios and distributed by Walt Disney Studios Motion Pictures.  It is the fourteenth film of the Marvel Cinematic Universe (MCU).  The film was directed by Scott Derrickson, who wrote it with Jon Spaihts and C. Robert Cargill, and stars Benedict Cumberbatch as Stephen Strange, along with Chiwetel Ejiofor, Rachel McAdams, Benedict Wong, Michael Stuhlbarg, Benjamin Bratt, Scott Adkins, Mads Mikkelsen, and Tilda Swinton.  In "Doctor Strange", surgeon Strange learns the mystic arts after a career-ending car accident.',
#  'supporting_facts': 'Scott Derrickson 0 Ed Wood 0'}

In [27]:
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
import torch

all_samples = extract_data_from_json('/kaggle/input/wim-data/hotpot_dev_distractor_v1.json')
train_samples, val_samples = train_test_split(all_samples, test_size=0.2, random_state=42)

class HotpotDataset(Dataset):
    def __init__(self, samples):
        self.samples = samples
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, idx):
        return self.samples[idx]
    
def collate_batch(batch):
    batch_queries = [sample['query'] for sample in batch]
    batch_segments = [sample['segment'] for sample in batch]
    batch_supporting_facts = [sample['supporting_facts'] for sample in batch]
    batch_answer = [sample['answer'] for sample in batch]
    return {
        'query': batch_queries,
        'segment': batch_segments,
        'supporting_facts': batch_supporting_facts,
        'answer':batch_answer
    }

train_ds = HotpotDataset(train_samples)
val_ds = HotpotDataset(val_samples)

train_dataloader = DataLoader(train_ds, batch_size=8, collate_fn=collate_batch)
val_dataloader = DataLoader(val_ds, batch_size=8, collate_fn=collate_batch)

In [14]:
template_extractive_summary = """
```
Given the above context, extract all information relevant to the query: "{query}". If the context is not relevant to the query, answer "I don't know."
""".strip()

In [17]:
model_id = "h2oai/h2o-danube3-500m-base"  # or any compatible model identifier
model_id_rl = 'h2oai/h2o-danube3-500m-base'
dir = '/kaggle/working/'

model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto",attn_implementation='eager', torch_dtype='bfloat16')
tokenizer = AutoTokenizer.from_pretrained(model_id, device_map="auto",attn_implementation='eager', torch_dtype='bfloat16')

rl_margin_generator = RLMargin_Generator(
            model_id=model,
            tokenizer=tokenizer,
            rl_config=RLConfig(),
            device="cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available else "cpu",
        )
count=0
for batch in train_dataloader:
    for segment, query, supporting_facts in zip(batch['segment'], batch['query'], batch['supporting_facts']):
        count+=1
        if count > 2:
            break
        rl_margin_generator.train(
            segments = segment,
            query = query,
            extractive_summary_prompt = template_extractive_summary,
            num_episodes = 5,
            max_new_tokens = 50,
            supporting_facts = supporting_facts,
            min_tokens_segment=256,
        )
print("avg reward for threshold")
rl_margin_generator.compute_avg_reward()
# for batch in val_dataloader:
#     for segment, query, supporting_facts in zip(batch['segment'], batch['query'], batch['supporting_facts']):
#         rl_margin_generator.train(
#             segments = segment,
#             query = query,
#             extractive_summary_prompt = template_extractive_summary,
#             num_episodes = 1,
#             max_new_tokens = 50,
#             supporting_facts = supporting_facts,
#             min_tokens_segment=256,
#         )

# for i in range(10):
#     query = training_data[i]['query']
#     context = training_data[i]['segment']
#     supporting_facts = training_data[i]['supporting_facts']
#     # print(supporting_facts)
#     # print(type(supporting_facts))
#     rl_margin_generator.train(
#         segments = context,
#         query = query,
#         extractive_summary_prompt = template_extractive_summary,
#         num_episodes = 1,
#         max_new_tokens = 50,
#         supporting_facts = supporting_facts,
#         min_tokens_segment=256,
#     )

# rl_margin_generator.save_model_output(dir)

# INCREASE THE EPOCHS PER MARGIN GENERATION


--- Episode 1/5 ---


Generating margins: 100%|██████████| 7/7 [00:06<00:00,  1.05it/s]


Query: Which animated film was released first, The Country Bears or The Wild?
Segment: Donkey Kong Country Returns is a side-scrolling platformer video game developed by Retro Studios and published by Nintendo for the Wii console. The game was released first in North America in November 2010, and in PAL regions and Japan the following month. A stereoscopic port of the game, titled Donkey Kong Country Returns 3D, was released for the Nintendo 3DS in May 2013, and in Japan the following month. The 34th Los Angeles Film Critics Association Awards, given by the Los Angeles Film Critics Association (LAFCA), honored the best in film for 2008. Pixar's animated film "WALL-E" won the Best Film award and became the first ever animated film to do so, however, the film lost the Best Animated Film award to "Waltz with Bashir". "Man! I Feel Like a Woman!" is a song recorded by Canadian singer-songwriter Shania Twain taken from her third studio album, "Come On Over" (1997).
Margin: 
Answer: The relev

Generating margins: 100%|██████████| 7/7 [00:08<00:00,  1.15s/it]


Query: Which animated film was released first, The Country Bears or The Wild?
Segment: Donkey Kong Country Returns is a side-scrolling platformer video game developed by Retro Studios and published by Nintendo for the Wii console. The game was released first in North America in November 2010, and in PAL regions and Japan the following month. A stereoscopic port of the game, titled Donkey Kong Country Returns 3D, was released for the Nintendo 3DS in May 2013, and in Japan the following month. The 34th Los Angeles Film Critics Association Awards, given by the Los Angeles Film Critics Association (LAFCA), honored the best in film for 2008. Pixar's animated film "WALL-E" won the Best Film award and became the first ever animated film to do so, however, the film lost the Best Animated Film award to "Waltz with Bashir". "Man! I Feel Like a Woman!" is a song recorded by Canadian singer-songwriter Shania Twain taken from her third studio album, "Come On Over" (1997).
Margin: 
Answer: The relev

Generating margins: 100%|██████████| 7/7 [00:07<00:00,  1.12s/it]


Query: Which animated film was released first, The Country Bears or The Wild?
Segment: Donkey Kong Country Returns is a side-scrolling platformer video game developed by Retro Studios and published by Nintendo for the Wii console. The game was released first in North America in November 2010, and in PAL regions and Japan the following month. A stereoscopic port of the game, titled Donkey Kong Country Returns 3D, was released for the Nintendo 3DS in May 2013, and in Japan the following month. The 34th Los Angeles Film Critics Association Awards, given by the Los Angeles Film Critics Association (LAFCA), honored the best in film for 2008. Pixar's animated film "WALL-E" won the Best Film award and became the first ever animated film to do so, however, the film lost the Best Animated Film award to "Waltz with Bashir". "Man! I Feel Like a Woman!" is a song recorded by Canadian singer-songwriter Shania Twain taken from her third studio album, "Come On Over" (1997).
Margin: 
Answer: The relev

Generating margins: 100%|██████████| 7/7 [00:07<00:00,  1.13s/it]


Query: Which animated film was released first, The Country Bears or The Wild?
Segment: Donkey Kong Country Returns is a side-scrolling platformer video game developed by Retro Studios and published by Nintendo for the Wii console. The game was released first in North America in November 2010, and in PAL regions and Japan the following month. A stereoscopic port of the game, titled Donkey Kong Country Returns 3D, was released for the Nintendo 3DS in May 2013, and in Japan the following month. The 34th Los Angeles Film Critics Association Awards, given by the Los Angeles Film Critics Association (LAFCA), honored the best in film for 2008. Pixar's animated film "WALL-E" won the Best Film award and became the first ever animated film to do so, however, the film lost the Best Animated Film award to "Waltz with Bashir". "Man! I Feel Like a Woman!" is a song recorded by Canadian singer-songwriter Shania Twain taken from her third studio album, "Come On Over" (1997).
Margin: 
Answer: The relev

Generating margins: 100%|██████████| 7/7 [00:07<00:00,  1.13s/it]


Query: Which animated film was released first, The Country Bears or The Wild?
Segment: Donkey Kong Country Returns is a side-scrolling platformer video game developed by Retro Studios and published by Nintendo for the Wii console. The game was released first in North America in November 2010, and in PAL regions and Japan the following month. A stereoscopic port of the game, titled Donkey Kong Country Returns 3D, was released for the Nintendo 3DS in May 2013, and in Japan the following month. The 34th Los Angeles Film Critics Association Awards, given by the Los Angeles Film Critics Association (LAFCA), honored the best in film for 2008. Pixar's animated film "WALL-E" won the Best Film award and became the first ever animated film to do so, however, the film lost the Best Animated Film award to "Waltz with Bashir". "Man! I Feel Like a Woman!" is a song recorded by Canadian singer-songwriter Shania Twain taken from her third studio album, "Come On Over" (1997).
Margin: 
Answer: The relev

Generating margins: 100%|██████████| 7/7 [00:07<00:00,  1.10s/it]


Query: Who was the captain of the only battleship to provide gunfire support during the Vietnam War?
Segment: USS "New Jersey" (BB-62) ("Big J" or "Black Dragon") is an "Iowa"-class battleship , and was the second ship of the United States Navy to be named after the US state of New Jersey. "New Jersey" earned more battle stars for combat actions than the other three completed "Iowa"-class battleships, and was the only US battleship providing gunfire support during the Vietnam War. Operation Market Time was the United States Navy and South Vietnam’s successful effort begun in 1965 to stop the flow of troops, war material, and supplies by sea, coast, and rivers, from North Vietnam into parts of South Vietnam during the Vietnam War. Also participating in Operation Market Time were United States Coast Guard Squadron One and Squadron Three. The Coast Guard provided heavily armed 82 ft patrol boats and large cutters that included 5" cannons used in battle and gunfire support. USS "Trippe" (F

Generating margins: 100%|██████████| 7/7 [00:07<00:00,  1.14s/it]


Query: Who was the captain of the only battleship to provide gunfire support during the Vietnam War?
Segment: USS "New Jersey" (BB-62) ("Big J" or "Black Dragon") is an "Iowa"-class battleship , and was the second ship of the United States Navy to be named after the US state of New Jersey. "New Jersey" earned more battle stars for combat actions than the other three completed "Iowa"-class battleships, and was the only US battleship providing gunfire support during the Vietnam War. Operation Market Time was the United States Navy and South Vietnam’s successful effort begun in 1965 to stop the flow of troops, war material, and supplies by sea, coast, and rivers, from North Vietnam into parts of South Vietnam during the Vietnam War. Also participating in Operation Market Time were United States Coast Guard Squadron One and Squadron Three. The Coast Guard provided heavily armed 82 ft patrol boats and large cutters that included 5" cannons used in battle and gunfire support. USS "Trippe" (F

Generating margins: 100%|██████████| 7/7 [00:07<00:00,  1.13s/it]


Query: Who was the captain of the only battleship to provide gunfire support during the Vietnam War?
Segment: USS "New Jersey" (BB-62) ("Big J" or "Black Dragon") is an "Iowa"-class battleship , and was the second ship of the United States Navy to be named after the US state of New Jersey. "New Jersey" earned more battle stars for combat actions than the other three completed "Iowa"-class battleships, and was the only US battleship providing gunfire support during the Vietnam War. Operation Market Time was the United States Navy and South Vietnam’s successful effort begun in 1965 to stop the flow of troops, war material, and supplies by sea, coast, and rivers, from North Vietnam into parts of South Vietnam during the Vietnam War. Also participating in Operation Market Time were United States Coast Guard Squadron One and Squadron Three. The Coast Guard provided heavily armed 82 ft patrol boats and large cutters that included 5" cannons used in battle and gunfire support. USS "Trippe" (F

Generating margins: 100%|██████████| 7/7 [00:07<00:00,  1.12s/it]


Query: Who was the captain of the only battleship to provide gunfire support during the Vietnam War?
Segment: USS "New Jersey" (BB-62) ("Big J" or "Black Dragon") is an "Iowa"-class battleship , and was the second ship of the United States Navy to be named after the US state of New Jersey. "New Jersey" earned more battle stars for combat actions than the other three completed "Iowa"-class battleships, and was the only US battleship providing gunfire support during the Vietnam War. Operation Market Time was the United States Navy and South Vietnam’s successful effort begun in 1965 to stop the flow of troops, war material, and supplies by sea, coast, and rivers, from North Vietnam into parts of South Vietnam during the Vietnam War. Also participating in Operation Market Time were United States Coast Guard Squadron One and Squadron Three. The Coast Guard provided heavily armed 82 ft patrol boats and large cutters that included 5" cannons used in battle and gunfire support. USS "Trippe" (F

Generating margins: 100%|██████████| 7/7 [00:07<00:00,  1.12s/it]


Query: Who was the captain of the only battleship to provide gunfire support during the Vietnam War?
Segment: USS "New Jersey" (BB-62) ("Big J" or "Black Dragon") is an "Iowa"-class battleship , and was the second ship of the United States Navy to be named after the US state of New Jersey. "New Jersey" earned more battle stars for combat actions than the other three completed "Iowa"-class battleships, and was the only US battleship providing gunfire support during the Vietnam War. Operation Market Time was the United States Navy and South Vietnam’s successful effort begun in 1965 to stop the flow of troops, war material, and supplies by sea, coast, and rivers, from North Vietnam into parts of South Vietnam during the Vietnam War. Also participating in Operation Market Time were United States Coast Guard Squadron One and Squadron Three. The Coast Guard provided heavily armed 82 ft patrol boats and large cutters that included 5" cannons used in battle and gunfire support. USS "Trippe" (F

0.0

In [None]:
# from transformers import AutoModel

# model = AutoModel.from_pretrained('h2oai/h2o-danube3-500m-base')

In [16]:
# model

In [30]:
template_final_answer = """
```
{margins}
{query}
""".strip()

# special_tokens = {
#     "{user_header}": "<|start_header_id|>user<|end_header_id|>\n\n",
#     "{generation_header}": "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
# }

# for token, replacement in special_tokens.items():
#     template_extractive_summary = template_extractive_summary.replace(token, replacement)
#     template_system_message = template_system_message.replace(token, replacement)
#     template_final_answer = template_final_answer.replace(token, replacement)



In [31]:
model_id = "h2oai/h2o-danube3-500m-base"  # or any compatible model identifier
model_id_rl = 'h2oai/h2o-danube3-500m-base'
dir = '/kaggle/working/'

model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto",attn_implementation='eager', torch_dtype='bfloat16')
tokenizer = AutoTokenizer.from_pretrained(model_id, device_map="auto",attn_implementation='eager', torch_dtype='bfloat16')

rl_margin_generator = RLMargin_Generator(
            model_id=model,
            tokenizer = tokenizer,
            rl_config=RLConfig(),
            device="cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available else "cpu",
        )

wim_inference = WIMInference(
    model = model,
    tokenizer = tokenizer,
)
final_output = []
count = 0
for batch in val_dataloader:
    for segment, query, supporting_facts, answer in zip(batch['segment'], batch['query'], batch['supporting_facts'], batch['answer']):
        count += 1
        if(count >100):
            break
        query = query
        context = segment
        # supporting_facts = test_data[i]['supporting_facts']
        print('context:', context)
        print('query:', query)
        # chunk segment
        # chunked prefill
        # margins .append with certain threshold
        # 
        margins = []
        segments = chunk_text_to_segments(text = context, min_tokens_segment=1024, tokenizer=tokenizer)
        for segment in segments:
            margin = rl_margin_generator.generate_rl_margin(
                segment = segment,
                query = query,
                extractive_summary = template_extractive_summary,
                supporting_facts = supporting_facts,
                max_new_tokens = 50,
                do_sample=True,
                top_p=0.9,
                temperature=0.7,
                early_stopping=True,
                remove_segment=False
            )

            print(f'margin: {margin}')
            print(f'segment: {segment}')
            margins.append(margin)

        # rewards = compute_reward(margins, query, supporting_facts, device="cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available else "cpu")
        # print('rewards:', rewards)

        formatted_final_answer = template_final_answer.format(
            margins =margins, query = query
        )

        _, _, final_answer_prefill_outputs = wim_inference.prefill_text_with_kv_cache(
            formatted_final_answer, wim_inference.wim_kv_cache
        )

        final_answer = wim_inference.generate_text_with_kv_cache(
            max_new_tokens=20,
            previous_logits=final_answer_prefill_outputs["logits"],
            do_sample=True,
            top_p=0.9,
            temperature=0.7,
            early_stopping=True,
            kv_cache=wim_inference.wim_kv_cache,
        )

        print('final_answer', final_answer)
        final_output.append({
            "query": query,
            "answer": answer,
            "final_answer": final_answer
        })
        print('----------------------------------------------------------------------------------------------------')

context: Constantin Medien AG (formerly EM.Entertainment and EM.TV & Merchandising AG, then EM.TV AG, and finally em.sport media ag) is a German media group, based in Ismaning near Munich, active in the area of sports, film and event marketing to medium-sized media companies. VIVA Polska (earlier "VIVApolska!")  is a Polish 24h music and entertainment channel from Viacom International Media Networks Polska.  The channel was officially launched on June 10, 2000 by the German VIVA Media AG. Viva (stylised as VIVA) is a music television channel in the United Kingdom and Ireland, owned by VIVA Media and thereby Viacom International Media Networks Europe.  The channel launched on 26 October 2009, replacing TMF. Blic (Cyrillic: Блиц, ] ) is a daily middle-market tabloid newspaper in Serbia.  Founded in 1996, "Blic" is owned by Ringier Axel Springer Media AG, a joint venture between Ringier media corporation from Switzerland and Axel Springer AG from Germany. Qontis is a Switzerland based onl

OutOfMemoryError: CUDA out of memory. Tried to allocate 3.04 GiB. GPU 0 has a total capacity of 15.89 GiB of which 2.89 GiB is free. Process 19960 has 13.00 GiB memory in use. Of the allocated memory 9.47 GiB is allocated by PyTorch, and 3.23 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [33]:
len(final_output)

33

In [34]:
torch.cuda.empty_cache()

In [35]:
import json

# … your existing loop that populates final_output …

# after you've built final_output:
output_path = "final_output.json"
with open(output_path, "w", encoding="utf-8") as f:
    json.dump(final_output, f, ensure_ascii=False, indent=2)

print(f"Wrote {len(final_output)} records to {output_path}")

Wrote 33 records to final_output.json


In [39]:
import json
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# 1) Load your model & tokenizer
model_id = "h2oai/h2o-danube3-500m-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16
)
device = model.device

# 2) Prepare to collect outputs
final_output = []
max_items = 33  # optional: cap how many you process
count = 0

# 3) Iterate your validation set
with torch.no_grad():
    for batch in val_dataloader:
        # Assumes your dataloader yields `batch['query']` and `batch['answer']`
        for segment, query, answer in zip(batch['segment'], batch['query'],  batch['answer']):            
            count += 1
            if count > max_items:
                break
            query = segment+query
            # tokenize + move to device
            input_ids = tokenizer(query, return_tensors="pt").input_ids.to(device)

            # generate
            output_ids = model.generate(
                input_ids,
                # max_length=100,
                do_sample=True,
                top_p=0.9,
                temperature=0.7
            )
            generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

            # record
            final_output.append({
                "query": query,
                "answer": answer,
                "generated_text": generated_text
            })

        if count > max_items:
            break

# 4) Write to disk
output_path = "final_output_base_model.json"
with open(output_path, "w", encoding="utf-8") as f:
    json.dump(final_output, f, ensure_ascii=False, indent=2)

print(f"Wrote {len(final_output)} records to {output_path}")


Wrote 33 records to final_output_base_model.json


In [40]:
! pip install bert-score

Collecting bert-score
  Downloading bert_score-0.3.13-py3-none-any.whl.metadata (15 kB)
Downloading bert_score-0.3.13-py3-none-any.whl (61 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.1/61.1 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bert-score
Successfully installed bert-score-0.3.13


In [43]:
# First, install bert-score if you haven't already:
# pip install bert-score

import json
from bert_score import score

# 1) Load your data
with open("final_output_base_model.json", "r", encoding="utf-8") as f:
    data = json.load(f)

# 2) Extract lists of references and candidates
#    adjust the keys if yours are named differently
refs  = [item["answer"]             for item in data]
cands = [item.get("final_answer",  item.get("generated_text")) for item in data]

# 3) Compute BERTScore
#    lang="en" works for English; you can pick model_type="roberta-large" etc.
P, R, F1 = score(
    cands,
    refs,
    lang="en",
    model_type="bert-base-uncased",
    batch_size=32,
    verbose=True,
    rescale_with_baseline=True
)

# 4) Attach scores back to each record
for i, item in enumerate(data):
    item["bert_score_precision"] = P[i].item()
    item["bert_score_recall"]    = R[i].item()
    item["bert_score_f1"]        = F1[i].item()

# 5) Dump augmented data to a new file
out_path = "final_output_with_bertscore_base_model.json"
with open(out_path, "w", encoding="utf-8") as f:
    json.dump(data, f, ensure_ascii=False, indent=2)

print(f"Computed BERTScore for {len(data)} examples and wrote to {out_path}")


calculating scores...
computing bert embedding.


  0%|          | 0/3 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/2 [00:00<?, ?it/s]

done in 1.65 seconds, 19.95 sentences/sec
Computed BERTScore for 33 examples and wrote to final_output_with_bertscore_base_model.json


In [44]:
import json
import statistics

# 1) Load your augmented JSON
with open("final_output_with_bertscore.json", "r", encoding="utf-8") as f:
    data = json.load(f)

# 2) Gather all scores
precisions = [item["bert_score_precision"] for item in data]
recalls    = [item["bert_score_recall"]    for item in data]
f1s        = [item["bert_score_f1"]        for item in data]

# 3) Compute and print averages
print(f"Average Precision: {statistics.mean(precisions):.4f}")
print(f"Average Recall   : {statistics.mean(recalls):.4f}")
print(f"Average F1       : {statistics.mean(f1s):.4f}")


Average Precision: -0.1566
Average Recall   : -0.0192
Average F1       : -0.0986


In [45]:
import json
import statistics

# 1) Load your augmented JSON
with open("/kaggle/working/final_output_with_bertscore_base_model.json", "r", encoding="utf-8") as f:
    data = json.load(f)

# 2) Gather all scores
precisions = [item["bert_score_precision"] for item in data]
recalls    = [item["bert_score_recall"]    for item in data]
f1s        = [item["bert_score_f1"]        for item in data]

# 3) Compute and print averages
print(f"Average Precision: {statistics.mean(precisions):.4f}")
print(f"Average Recall   : {statistics.mean(recalls):.4f}")
print(f"Average F1       : {statistics.mean(f1s):.4f}")

Average Precision: -0.2754
Average Recall   : 0.2099
Average F1       : -0.1487
