In [1]:

import torch
import os 
import json
# Load machine-specific settings (cache directory, Hub token) from a local JSON
# file so that no credential is hard-coded in the notebook itself.
json_path = 'env_config.json'
with open(json_path, 'r') as file:
    env_config = json.load(file)

hf_home = env_config['HF_HOME']
# Set the HF_HOME environment variable (done here, before the cells that
# import the Hugging Face based wrappers).
os.environ['HF_HOME'] = hf_home
# Set the access token to huggingface hub.
access_token = env_config['access_token']
# huggingface_hub documents HF_TOKEN (and the legacy HUGGING_FACE_HUB_TOKEN) as
# the token variables; the original spelling below is not one of them, so it
# is kept only for backward compatibility with anything that may read it.
os.environ['HF_TOKEN'] = access_token
os.environ['HUGGING_FACE_HUB_TOKEN'] = access_token
os.environ['HUGGINGFACE_HUB_TOKEN'] = access_token
# Restrict this process to a single physical GPU (index 2 on this machine).
os.environ['CUDA_VISIBLE_DEVICES'] = '2'


In [2]:
from llmexp.llm.smollm import LLMWrapper, Template
from accelerate import Accelerator

# Ask accelerate for the device this process should use.
accelerator = Accelerator()
device = accelerator.device

# Candidate checkpoints; only the uncommented assignment is active.
# checkpoint = "meta-llama/Llama-3.2-1B-Instruct"
# checkpoint = "meta-llama/Llama-3.2-3B-Instruct"
checkpoint = "meta-llama/Meta-Llama-3-8B-Instruct"
# checkpoint = "HuggingFaceTB/SmolLM-1.7B-Instruct"

# Build the project's LLM wrapper and a QA prompt template around its tokenizer.
llm = LLMWrapper(checkpoint, device=device, access_token=access_token)
tokenizer = llm.tokenizer
template = Template(tokenizer, task='qa')

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [3]:
# Explainer variants tried so far; the commented lines are alternatives.
# from llmexp.explainer.mab_explainer import MABExplainer
# from llmexp.explainer.mab_explainer_advantage import MABExplainerAdvantage
from llmexp.explainer.mab_explainer import MABExplainer

# Explainer used by the cells below, built from the model, tokenizer and template.
mab_explainer = MABExplainer(llm, tokenizer, template)
# mab_explainer = MABExplainerAdvantage(llm, tokenizer)



In [4]:
from llmexp.utils.data_utils import LLMDataset
import numpy as np

# Training split of HotpotQA, loaded through the project's dataset wrapper.
dataset = LLMDataset("hotpot_qa", split="train")

# HotpotSample wraps a single dataset row (used in the next cell).
from llmexp.utils.hotpot_helper import HotpotHelper, HotpotSample
# helper = HotpotHelper()


In [5]:
# Pick one example and split it into context sentences plus the question.
idx = 0
hpsample = HotpotSample(dataset[idx])
sentences = hpsample.flattened_contexts
question = hpsample.question
# `query` is the context sentences followed by the question as the last item.
query = sentences + [question]
# NOTE(review): a bare expression in the middle of a cell is not displayed;
# only the last expression of a cell is.
query

# Model answer for the full, unmasked context.
response = mab_explainer.get_response(sentences, question)
print(response)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


According to the context, Arthur's Magazine was started in 1844, while First for Women was started in 1989. Therefore, Arthur's Magazine was started first.<|eot_id|>


In [6]:
# Show the first five context sentences with their flat index, then the
# question and the flat indices of the annotated supporting facts.
# The comprehension's `idx` is local to it and does not overwrite the cell-level `idx`.
print("Context:\n", "\n".join([f"{idx} {x}" for idx, x in enumerate(hpsample.flattened_contexts[:5])]), "\n...")
print("Question:", hpsample.question)
print("supporting facts indices:", hpsample.flattened_supporting_facts_indices)

Context:
 0 Radio City is India's first private FM radio station and was started on 3 July 2001.
1  It broadcasts on 91.1 (earlier 91.0 in most cities) megahertz from Mumbai (where it was started in 2004), Bengaluru (started first in 2001), Lucknow and New Delhi (since 2003).
2  It plays Hindi, English and regional songs.
3  It was launched in Hyderabad in March 2006, in Chennai on 7 July 2006 and in Visakhapatnam October 2007.
4  Radio City recently forayed into New Media in May 2008 with the launch of a music portal - PlanetRadiocity.com that offers music related news, videos, songs, and other music-related features. 
...
Question: Which magazine was started first Arthur's Magazine or First for Women?
supporting facts indices: [25, 33]


In [7]:
# Run the explainer's Thompson sampling for 20 iterations. The recorded output
# is a 46-element array and the next cell pairs `theta` with `sentences`, so it
# presumably holds one score per context sentence -- TODO confirm.
theta = mab_explainer.thompson_sampling(sentences, question, response, n_iter=20)
# Alternative estimators that were tried:
# theta = mab_explainer.thompson_sampling_gaussian(sentences, response, n_iter=100)
# theta = mab_explainer.policy_gradient_with_advantage(sentences, response, n_iter=30, lr=0.1)
theta

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


array([-0.30644426, -0.2688241 , -0.39993867, -0.39629516, -0.24347259,
       -0.46292529, -0.107263  , -0.48387897, -0.34414169, -0.20506722,
       -0.06440838, -0.49595651, -0.52092165, -0.36010492, -0.0055213 ,
       -0.28138241, -0.39494956, -0.43563432, -0.20169032, -0.27577958,
       -0.48041481, -0.36646467, -0.43796539, -0.39905867, -0.37742791,
       -0.02177167, -0.02649677, -0.23753779, -0.38597018, -0.32334548,
       -0.29833668, -0.64995503, -0.26415506, -0.40566549,  0.05760736,
       -0.18212007, -0.50258273, -0.41072154, -0.43376026, -0.28246805,
       -0.63809258, -0.396759  , -0.36103016, -0.61072016, -0.49916637,
       -0.31854135])

In [8]:
# combine theta and the context sentences into a pandas dataframe
# (one row per sentence; the row index is the flat sentence index)
data_dict = {
    "theta": theta,
    "sentences": sentences
}
import pandas as pd 
df = pd.DataFrame(data_dict)

# Set pandas display options to show full text.
# NOTE: these are global options and stay in effect for all later cells.
pd.set_option('display.max_colwidth', None)  # No limit on column width
pd.set_option('display.width', None)         # No limit on display width
pd.set_option('display.max_rows', None)      # Show all rows

# Sort the dataframe so the highest-scoring sentences come first
df_sorted = df.sort_values(by='theta', ascending=False)

import textwrap
from IPython.display import display, Markdown

# Maximum line width used when wrapping the question and the answer.
width = 100
# Format question and answer with proper wrapping
wrapped_question = textwrap.fill(question, width=width)
wrapped_response = textwrap.fill(response, width=width)

# Print with f-strings
print(f"Q: {wrapped_question}")
print(f"A: {wrapped_response}")
print(f"Supporting indices: {hpsample.flattened_supporting_facts_indices}")
print(f"Ground truth answer: {hpsample.answer}")

# # Alternatively, for better formatting in notebooks:
# display(Markdown(f"**Q:** {wrapped_question}"))
# display(Markdown(f"**A:** {wrapped_response}"))
# Display the sorted dataframe
df_sorted

Q: Which magazine was started first Arthur's Magazine or First for Women?
A: According to the context, Arthur's Magazine was started in 1844, while First for Women was started
in 1989. Therefore, Arthur's Magazine was started first.<|eot_id|>
Supporting indices: [25, 33]
Ground truth answer: Arthur's Magazine


Unnamed: 0,theta,sentences
34,0.057607,The magazine was started in 1989.
14,-0.005521,until they signed to Warner Bros.
25,-0.021772,Arthur's Magazine (1844–1846) was an American literary periodical published in Philadelphia in the 19th century.
26,-0.026497,"Edited by T.S. Arthur, it featured work by Edgar A. Poe, J.H. Ingraham, Sarah Josepha Hale, Thomas G. Spear, and others."
10,-0.064408,"In 1932, Albania joined FIFA (during the 12–16 June convention ) And in 1954 she was one of the founding members of UEFA."
6,-0.107263,Abraham Thomas is the CEO of the company.
35,-0.18212,"It is based in Englewood Cliffs, New Jersey."
18,-0.20169,"Records' fifth-biggest-selling-digital song of 2014, with 1.3 million downloads sold."
9,-0.205067,"Albanian National Team was founded on June 6, 1930, but Albania had to wait 16 years to play its first international match and then defeated Yugoslavia in 1946."
27,-0.237538,"In May 1846 it was merged into ""Godey's Lady's Book""."


In [9]:
# Stand-alone sanity check of BERTScore as a similarity measure between two
# answers; it does not use any state from the cells above.
from bert_score import score

# Example: original and perturbed responses
response1 = "The Oberoi family is part of a hotel company that has a head office in Delhi."
response2 = "The Oberoi Group is a hotel company with its head office in Delhi"

# BERTScore expects lists of strings; response2 is passed first and
# response1 second.
P, R, F1 = score([response2], [response1], lang="en", model_type="bert-base-uncased")

# Print the F1 similarity score
print(f"BERTScore F1: {F1.item():.4f}")




BERTScore F1: 0.7901


In [10]:
# NOTE(review): the recorded output of this cell is an AttributeError --
# this MABExplainer has no `get_baseline_tokens`. The cells below that use
# `output` therefore cannot run as written.
output = mab_explainer.get_baseline_tokens(query)
output.shape
# Rebinds `response` (previously the answer returned by get_response).
response = tokenizer.decode(output[0])
response

AttributeError: 'MABExplainer' object has no attribute 'get_baseline_tokens'

In [None]:
# Same as the previous cell, but for a masked query.
# NOTE(review): `masked_query` is not defined in any cell visible in this
# notebook, and `get_baseline_tokens` failed with AttributeError above.
masked_output = mab_explainer.get_baseline_tokens(masked_query)
masked_output.shape
masked_response = tokenizer.decode(masked_output[0])
masked_response

In [None]:
# Logits of the response given the full query, scored against the baseline
# tokens with the "cross_entropy" reward type.
# NOTE(review): depends on `output` from the cell that raised AttributeError.
response_logits = mab_explainer.get_response_logits(query, response)
response_logits.shape
reward = mab_explainer.get_reward(response_logits, baseline_tokens=output, reward_type="cross_entropy")
reward

In [None]:
# Same reward computation for the masked query / masked response pair.
# NOTE(review): depends on `masked_query` and `masked_output`, which are not
# successfully defined in the visible cells.
masked_response_logits = mab_explainer.get_response_logits(masked_query, masked_response)
masked_response_logits.shape
masked_reward = mab_explainer.get_reward(masked_response_logits, baseline_tokens=masked_output, reward_type="cross_entropy")
masked_reward

In [None]:
# find the argmax of the response logits over the last dimension
# (presumably the vocabulary axis, giving one token id per position -- TODO confirm)
response_logits_argmax = torch.argmax(response_logits, dim=-1)
print(response_logits_argmax)

In [None]:
# Inspect the baseline tokens for comparison with the argmax ids printed above.
output

In [None]:
# Text obtained by decoding the first row of the argmax token ids.
tokenizer.decode(response_logits_argmax[0])

In [None]:
# Text of the baseline tokens, for side-by-side comparison with the cell above.
tokenizer.decode(output[0])

In [None]:
# NOTE(review): identical to the reward computation in an earlier cell.
reward = mab_explainer.get_reward(response_logits, baseline_tokens=output, reward_type="cross_entropy")
reward

In [None]:
# NOTE(review): repeats the imports and dataset load of an earlier cell, and
# rebinds `dataset` and `hpsample`.
from llmexp.utils.data_utils import LLMDataset
import numpy as np

dataset = LLMDataset("hotpot_qa", split="train")

from llmexp.utils.hotpot_helper import HotpotHelper, HotpotSample
# helper = HotpotHelper()
hpsample = HotpotSample(dataset[0])

# Inspect the helper's view of the sample.
print(hpsample.flattened_supporting_facts_indices)
print(hpsample.sentence_mask)
print(hpsample.flattened_contexts)

# Random 0/1 mask (p=0.5) with the same shape as the sentence mask.
# No seed is set, so the mask differs on every run.
bernoulli_mask = np.random.binomial(1, 0.5, size=hpsample.sentence_mask.shape)

# Contexts selected by the random mask.
print(hpsample.get_contexts_from_mask(bernoulli_mask))

In [None]:
# Print the context sentences at flat indices 25 and 34.
print(hpsample.flattened_contexts[25])
print(hpsample.flattened_contexts[34])

In [10]:
# Unpack a raw dataset row (index 2) without the HotpotSample helper.
# Rebinds `question`, which earlier cells set from example 0.
example = dataset[2]
question = example["question"]
titles = example["context"]["title"]
# One list of sentences per title.
contexts = example["context"]["sentences"]
supporting_facts = example["supporting_facts"]
# top_res = [23, 33] # {'title': ["Arthur's Magazine", 'First for Women'], 'sent_id': [0, 0]}

In [None]:
# Display the raw supporting-facts annotation of the selected example.
supporting_facts

In [None]:
# Print every context sentence prefixed with (title index, sentence index).
# Rebinds the cell-level `idx`.
for idx, (title, context) in enumerate(zip(titles, contexts)):
    for j, sent in enumerate(context):
        print(f"({idx}, {j}) {title}: {sent}")
    

In [7]:
def get_supporting_facts_indices(supporting_facts, titles):
    """Convert supporting facts to (title position, sentence id) pairs.

    Args:
        supporting_facts: Mapping with parallel sequences under "title" and "sent_id".
        titles: List of titles; each supporting title is located with ``titles.index``.

    Returns:
        List of ``(index of the title in titles, sent_id)`` tuples, in the order
        of the supporting facts.

    Raises:
        ValueError: If a supporting title does not occur in ``titles``.
    """
    # Resolve all title positions first, then pair them with the sentence ids.
    title_positions = []
    for fact_title in supporting_facts["title"]:
        title_positions.append(titles.index(fact_title))

    pairs = [
        (position, sent_id)
        for position, sent_id in zip(title_positions, supporting_facts["sent_id"])
    ]
    return pairs

def get_iter_indices(contexts):
    """Map each (context index, sentence index) pair to its flat position.

    Args:
        contexts: Iterable of iterables of sentences.

    Returns:
        Dict whose keys are ``(a, b)`` tuples (b-th sentence of the a-th context)
        and whose values count the sentences in iteration order, starting at 0.
    """
    flat_positions = {}
    for context_idx, context_sentences in enumerate(contexts):
        for sentence_idx, _ in enumerate(context_sentences):
            # Every key is new, so the current size is the next flat index.
            flat_positions[(context_idx, sentence_idx)] = len(flat_positions)
    return flat_positions


class QASample:
    """Plain container bundling the fields of one QA example."""

    def __init__(self, question, contexts, supporting_facts, titles):
        # Every constructor argument is stored unchanged under the same name.
        self.question, self.contexts = question, contexts
        self.supporting_facts, self.titles = supporting_facts, titles
        


In [None]:
# Supporting facts as (title index, sentence id) pairs ...
supporting_facts_indices = get_supporting_facts_indices(supporting_facts, titles)

# ... and the lookup from such pairs to flat sentence positions.
ind_map = get_iter_indices(contexts)

# Flat positions of the supporting facts; a KeyError here means a supporting
# fact points at a sentence id that the context does not contain.
supporting_facts_indices_iter = [ind_map[tuple(fact)] for fact in supporting_facts_indices]
supporting_facts_indices_iter



In [None]:
# Display the raw supporting-facts annotation again.
supporting_facts

In [None]:
# Build the Context Sentences section.
# NOTE(review): the loop rebinds `sentences` (earlier the flattened context
# list of example 0) and does not use `title`.
contexts_str = ""
for title, sentences in zip(titles, contexts):
    # Join the list of sentences into one string.
    sentence_str = "\n".join(sentences)
    contexts_str += f"{sentence_str}\n"

context_str = f"Context Sentences:\n{contexts_str}"
question_str = f"Question: {question}"

# Rebinds `query`: from here on it is a single string, not a list of sentences.
query = context_str + "\n" + question_str + "\n"
print(query)



In [5]:
# NOTE(review): a class named MABRecorder is also defined later in this
# notebook; whichever of the import and the definition ran last wins.
from llmexp.explainer.mab_explainer import MABExplainer, MABTemplate, MABRecorder
instruction = "Answer the question based on the context provided."

# Rebinds `template` (previously a Template built with task='qa').
template = MABTemplate(tokenizer, instruction)
explainer = MABExplainer(llm, tokenizer, template)

# Mapper owned by the explainer; the next cell feeds it the query text.
input_mapper = explainer.input_mapper 

In [6]:
# Let the mapper process the query text, then collect the contents of its
# `sentences` mapping in key iteration order.
input_mapper.process_text(query)
# may need to sort the keys by (first, second, third) elements of the tuple
all_sentences = [input_mapper.sentences[key] for key in input_mapper.sentences.keys()]


In [None]:
# List every mapped sentence with its position, separated by a rule.
for idx, sentence in enumerate(all_sentences):
    print(idx, sentence)
    print("-"*100)

In [None]:
# Rebuild `query` as the mapped sentences joined by single spaces. This
# replaces the newline-separated string built earlier.
query = " ".join(all_sentences)
print(query)

In [None]:
# Wrap the query in the prompt template and tokenize it for the model.
template_query = template.format(query)
print(template_query)

tokenized_template_query = tokenizer(template_query, return_tensors="pt").to(device)
input_ids = tokenized_template_query.input_ids 
attention_mask = tokenized_template_query.attention_mask 

# Generate up to 100 new tokens; the result is read back through its
# 'input_ids' and 'attention_mask' entries.
template_query_with_response = llm.generate(input_ids, attention_mask, max_new_tokens=100)
template_query_with_response_ids = template_query_with_response['input_ids']
template_query_with_response_attention_mask = template_query_with_response['attention_mask']

# Mask derived from the attention masks before and after generation
# (presumably 1 on the generated tokens -- TODO confirm).
response_mask = explainer.get_response_mask(attention_mask, template_query_with_response_attention_mask)

# print the response texts 
response_tokens = template_query_with_response_ids[response_mask == 1]
response_texts = tokenizer.decode(response_tokens, skip_special_tokens=False)
print(response_texts)

# print the resources 
# NOTE(review): `top_res` is not defined in any cell visible in this notebook
# (only a commented-out assignment exists), so this loop raises NameError on a
# fresh kernel.
for top_res_idx in top_res:
    print(top_res_idx)
    print(all_sentences[top_res_idx])
    print("-"*100)




In [10]:
# Random clips of the query paired with the generated response, 5 trials.
mab_pull_iter = explainer.random_clip_query_words(query, response_texts, num_trials=5)


In [21]:
import numpy as np
from scipy.stats import beta
import torch
import random
from typing import List, Tuple, Dict, Any
from llmexp.explainer.mab_explainer import MultiLevelInputMapper

class ThompsonMABExplainer:
    """Thompson-sampling multi-armed bandit over the content units of a query.

    Every key of ``input_mapper.sentences`` is an arm with its own
    Beta(alpha, beta) posterior. Each iteration samples a subset of arms,
    builds a text clip from them, scores the clip through the base explainer
    and updates the posteriors of the sampled arms.
    """

    # Granularity addressed by an index tuple, keyed by the tuple's length.
    _UNIT_TYPES = {1: "sentence", 2: "phrase", 3: "word"}

    def __init__(self, explainer, input_mapper, initial_alpha=1.0, initial_beta=1.0):
        """
        Initialize the Thompson Sampling Multi-Armed Bandit explainer.

        Args:
            explainer: The base explainer object (its ``template`` and
                ``get_response_logits_mean_from_clips`` are used)
            input_mapper: The MultiLevelInputMapper for the text
            initial_alpha: Initial alpha parameter for Beta distributions
            initial_beta: Initial beta parameter for Beta distributions
        """
        self.explainer = explainer
        self.input_mapper = input_mapper
        self.recorder = MABRecorder(input_mapper)

        # Thompson sampling parameters for each content unit (Beta distribution).
        # Only sentence-level units are arms; phrase and word keys are not added.
        self.alphas = {}
        self.betas = {}
        for k in input_mapper.sentences.keys():
            self.alphas[k] = initial_alpha
            self.betas[k] = initial_beta

        # Track best observed clip and its reward, plus the full reward history
        self.best_clip = None
        self.best_reward = float('-inf')
        self.all_rewards = []

    def _describe_unit(self, idx):
        """Return ``(unit_type, content)`` for an index tuple, or None if its length is unsupported."""
        unit_type = self._UNIT_TYPES.get(len(idx))
        if unit_type is None:
            return None
        # get_content takes the index components as separate positional arguments.
        return unit_type, self.input_mapper.get_content(*idx)

    def sample_content_units(self, p_threshold=0.5) -> List[Tuple]:
        """
        Sample content units based on current Beta distributions.

        Args:
            p_threshold: Probability threshold for including a unit

        Returns:
            List of sampled content unit indices (possibly empty)
        """
        sampled_indices = []

        for key in self.alphas.keys():
            # One draw from this arm's Beta posterior
            p = beta.rvs(self.alphas[key], self.betas[key], size=1)[0]

            # Include this unit if its sampled probability exceeds the threshold
            if p > p_threshold:
                sampled_indices.append(key)

        return sampled_indices

    def construct_clip_from_indices(self, indices: List[Tuple]) -> str:
        """
        Construct a clip of text from the sampled indices.

        Args:
            indices: List of content unit indices to include

        Returns:
            The contents of the units, in index order, joined with single
            spaces. Indices whose length is not 1, 2 or 3 are skipped.
        """
        # Sort indices to maintain the original text order; missing components
        # of shorter indices compare as 0.
        sorted_indices = sorted(
            indices,
            key=lambda x: tuple(x[i] if len(x) > i else 0 for i in range(3)),
        )

        contents = []
        for idx in sorted_indices:
            described = self._describe_unit(idx)
            if described is not None:
                contents.append(described[1])

        return " ".join(contents)

    def update_parameters(self, sampled_indices: List[Tuple], reward: float, reward_threshold=28.0):
        """
        Update Beta distribution parameters based on observed reward.

        Args:
            sampled_indices: The indices of content units that were sampled
            reward: The observed reward
            reward_threshold: Threshold to consider a reward as positive
        """
        # A pull is a success only when the reward is strictly above the threshold.
        # Alternatives noted by the author: best reward so far, a running
        # average / advantage, or the maximum reward as threshold.
        is_success = reward > reward_threshold

        # Update parameters for sampled indices that are known arms
        for idx in sampled_indices:
            if idx in self.alphas:
                if is_success:
                    self.alphas[idx] += 1  # Increase alpha on success
                else:
                    self.betas[idx] += 1   # Increase beta on failure

        # Record this pull in the recorder
        self.recorder.record(sampled_indices, reward)

        # Update best clip if this is the best so far
        if reward > self.best_reward:
            self.best_reward = reward
            self.best_clip = sampled_indices

        self.all_rewards.append(reward)

    def run_iteration(self, query: str, response_text: str, p_threshold=0.5, reward_threshold=28.0):
        """
        Run a single iteration of the Thompson sampling algorithm.

        Args:
            query: The original query text (not used by this method; the clip
                is built from the input mapper's content)
            response_text: The model's response text
            p_threshold: Probability threshold for sampling
            reward_threshold: Reward above which the pull counts as a success

        Returns:
            The reward obtained in this iteration
        """
        sampled_indices = self.sample_content_units(p_threshold)

        if not sampled_indices:
            # If nothing was sampled, pick a random sentence to avoid empty clips
            sentence_keys = list(self.input_mapper.sentences.keys())
            sampled_indices = [random.choice(sentence_keys)]

        clip_text = self.construct_clip_from_indices(sampled_indices)

        # Templated clip alone, and the same text with the response appended
        template_clip_text = self.explainer.template.format(clip_text)
        clip_with_response = template_clip_text + response_text

        # Reward (logits mean, per the explainer method's name)
        reward = self.explainer.get_response_logits_mean_from_clips(template_clip_text, clip_with_response)
        reward_value = reward.item()

        self.update_parameters(sampled_indices, reward_value, reward_threshold)

        return reward_value

    def run(self, query: str, response_text: str, num_iterations=50, p_threshold=0.5, reward_threshold=28.0):
        """
        Run the Thompson sampling algorithm for multiple iterations.

        Args:
            query: The original query text
            response_text: The model's response text
            num_iterations: Number of iterations to run
            p_threshold: Probability threshold for sampling
            reward_threshold: Reward above which a pull counts as a success

        Returns:
            Tuple ``(best_clip, best_clip_text, best_reward)``. When no clip
            has been recorded (for example ``num_iterations=0``) this is
            ``(None, "", float('-inf'))``.
        """
        for i in range(num_iterations):
            reward = self.run_iteration(query, response_text, p_threshold, reward_threshold)

            # Optional: Print progress
            if (i+1) % 10 == 0:
                print(f"Iteration {i+1}/{num_iterations}, Reward: {reward:.4f}, Best: {self.best_reward:.4f}")

        # Nothing recorded yet: there is no clip to render.
        if self.best_clip is None:
            return None, "", self.best_reward

        best_clip_text = self.construct_clip_from_indices(self.best_clip)
        return self.best_clip, best_clip_text, self.best_reward

    def analyze_importance(self, normalize=True):
        """
        Analyze the importance of different content units based on Thompson sampling parameters.

        Args:
            normalize: Whether to min-max normalize importance scores to [0,1]

        Returns:
            Dictionary mapping content units to their importance scores
        """
        # Importance is the posterior mean alpha / (alpha + beta)
        importance_scores = {
            key: self.alphas[key] / (self.alphas[key] + self.betas[key])
            for key in self.alphas.keys()
        }

        if normalize and importance_scores:
            min_score = min(importance_scores.values())
            max_score = max(importance_scores.values())
            if max_score > min_score:  # Avoid division by zero
                for key in importance_scores:
                    importance_scores[key] = (importance_scores[key] - min_score) / (max_score - min_score)

        return importance_scores

    def get_top_content_units(self, n=5):
        """
        Get the top N most important content units.

        Args:
            n: Number of top units to return

        Returns:
            List of (content_unit, unit_type, content, importance_score) tuples

        Raises:
            ValueError: If a top-ranked key is not an index of length 1, 2 or 3.
        """
        importance_scores = self.analyze_importance()

        # Sort by importance score in descending order
        sorted_units = sorted(importance_scores.items(), key=lambda x: x[1], reverse=True)

        top_units = []
        for idx, score in sorted_units[:n]:
            described = self._describe_unit(idx)
            if described is None:
                # Previously this case left unit_type/content unbound or stale.
                raise ValueError(f"Unsupported content unit index: {idx!r}")
            unit_type, content = described
            top_units.append((idx, unit_type, content, score))

        return top_units

# Enhanced MABRecorder class
class MABRecorder:
    """Per-unit bookkeeping of how often a content unit was pulled and with what reward."""

    def __init__(self, input_mapper: MultiLevelInputMapper):
        # One statistics entry per key, registered level by level:
        # sentences first, then phrases, then words.
        self.content_recorder = {}
        for level_name in ("sentences", "phrases", "words"):
            for unit_key in getattr(input_mapper, level_name).keys():
                self.content_recorder[unit_key] = {'pulls': 0, 'rewards': [], 'avg_reward': 0}

        self.input_mapper = input_mapper

    def record(self, clip_indices, reward):
        """
        Register one pull for every known unit of a clip.

        Args:
            clip_indices: Iterable of unit keys that made up the clip
            reward: Reward observed for the clip; credited to each unit

        Keys that were not registered at construction time are ignored.
        """
        for unit_key in clip_indices:
            stats = self.content_recorder.get(unit_key)
            if stats is None:
                continue
            stats['pulls'] += 1
            stats['rewards'].append(reward)
            # Average is recomputed from the full reward history.
            stats['avg_reward'] = sum(stats['rewards']) / len(stats['rewards'])

    def get_top_units_by_reward(self, n=5):
        """
        Rank the units that were pulled at least once by average reward.

        Args:
            n: Maximum number of units to return

        Returns:
            List of (content_unit, avg_reward, pull_count) tuples, best first
        """
        pulled = [
            (unit_key, stats)
            for unit_key, stats in self.content_recorder.items()
            if stats['pulls'] > 0
        ]
        ranked = sorted(pulled, key=lambda entry: entry[1]['avg_reward'], reverse=True)

        top = []
        for unit_key, stats in ranked[:n]:
            top.append((unit_key, stats['avg_reward'], stats['pulls']))
        return top

In [None]:
# Process the text with the input mapper
input_mapper.process_text(query)

# Create the Thompson MAB explainer (arms are the mapper's sentence keys)
thompson_explainer = ThompsonMABExplainer(explainer, input_mapper)

# Run the algorithm; progress is printed every 10 iterations
best_indices, best_clip, best_reward = thompson_explainer.run(
    query=query, 
    response_text=response_texts, 
    num_iterations=50,
    p_threshold=0.5
)

# Analyze results: best single clip seen during the run
print(f"Best reward: {best_reward}")
print(f"Best clip: {best_clip}")

# Get the most important content units (ranked by normalized importance)
top_units = thompson_explainer.get_top_content_units(n=5)
for idx, unit_type, content, score in top_units:
    print(f"{unit_type} {idx}: '{content}' (importance: {score:.4f})")

In [None]:
# Response text the bandit run above was scored against.
print(response_texts)

In [None]:
# user_input = "Hello world! This is a test, or not a test. How are you?"
user_input = "The service at this restaurant was fantastic, and the staff were so friendly."
template_input = template.format(user_input) 
tokenized_template_input = tokenizer(template_input, return_tensors="pt").to(device)
# get input_ids and attention_mask
input_ids = tokenized_template_input.input_ids 
attention_mask = tokenized_template_input.attention_mask 

# get the output_ids and output_attention_mask
output = llm.generate(input_ids, attention_mask, max_new_tokens=100)
output_ids = output['input_ids']
output_attention_mask = output['attention_mask']

# get the response mask
response_mask = explainer.get_response_mask(attention_mask, output_attention_mask)

# Tokens selected by the response mask. Indexing with a boolean mask that
# covers every dimension yields a 1-D tensor, in which case `response_ids[0]`
# is a single token. Flattening and decoding the whole selection gives the
# full response, as in the earlier cell that decodes `response_tokens`; for a
# (1, n) selection it decodes the same tokens as before.
response_ids = output_ids[response_mask == 1]
response_texts = tokenizer.decode(response_ids.reshape(-1), skip_special_tokens=False)



In [None]:
# Display the decoded response for the restaurant-review input.
response_texts

In [None]:
# NOTE(review): this import rebinds MABRecorder to the library class, replacing
# the class of the same name defined earlier in this notebook.
from llmexp.explainer.mab_explainer import MABRecorder
# Random clips of the input paired with the fixed response 'positive'.
random_clips = explainer.random_clip_query_words(user_input, 'positive')

recorder = MABRecorder(explainer.input_mapper)

# Score every clip and credit the score to the units the clip contains.
for clip, clip_with_response, clipped_query_indices in random_clips:
    # print(clip)
    print(clipped_query_indices)
    # print(clip)
    # print(clip_with_response)
    logits_mean = explainer.get_response_logits_mean_from_clips(clip, clip_with_response)
    print(logits_mean)
    # logits = explainer.get_response_logits_from_clips(clip, clip_with_response)
    # The value is recorded as returned (elsewhere `.item()` is called on it,
    # so it is presumably a tensor -- TODO confirm).
    recorder.record(clipped_query_indices, logits_mean)

    
    # explainer.get_response_logits_mean(input_ids, attention_mask, response_mask)
    print("-"*100)

recorder

In [None]:
# NOTE(review): in the cells above `template` is a Template / MABTemplate
# object rather than a string, so `tokenizer(template)` looks like a leftover
# from a version where `template` was the prompt text -- confirm.
inputs = tokenizer(template)
# Add a leading batch dimension and move the tensors to the model device.
input_ids = torch.tensor(inputs["input_ids"]).unsqueeze(0).to(device)
attention_mask = torch.tensor(inputs["attention_mask"]).unsqueeze(0).to(device)
output = llm.generate(input_ids, attention_mask, max_new_tokens=100)
print(output)

In [None]:
# Unpack the generation result from the previous cell.
output_ids = output['input_ids']
output_attention_mask = output['attention_mask']

# Decode the first row in full, keeping special tokens visible.
output_texts = tokenizer.decode(output_ids[0], skip_special_tokens=False)
print(output_texts)