In [1]:
from transformers import AutoConfig
from datasets import load_from_disk

import configs
from controller.memory_manager import MemoryManager
from data_processor.data_loader import GSM8KDataset
from generator.crv_generator import CRVGenerator
from generator.text_generator import TextGenerator

from retrieve.cosine_similarity import CRVRetriever
from retrieve.dnc import DNMemory
from utils import set_seed, logger
from utils.loading_model import CustomTransformerLoader

# from rich import print
from rich.console import Console


In [2]:
# Set up logging and console
console = Console()
logger = logger()

In [3]:
console = Console()
seed = 42
set_seed(seed)

model_urls = {
    "llama31": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "llama3": "meta-llama/Meta-Llama-3-8B-Instruct",
}
model_path = model_urls["llama31"]
tokenizer_path = model_path
hf_token = "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"

In [4]:
config = AutoConfig.from_pretrained(model_path, use_auth_token=hf_token)

console.rule("[bold red]Loading the Model")

loader = CustomTransformerLoader()



In [5]:
model, tokenizer = loader.load_model(
    model_path=model_path, tokenizer_path=tokenizer_path, hf_token=hf_token
)

crv_layers = configs.CRV_LAYERS

print(":warning: model type: ", type(model))
print("config.hidden_size: ", config.num_hidden_layers)
print("config._attn_implementation: ", config._attn_implementation)



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

config.hidden_size:  32
config._attn_implementation:  eager


In [6]:
import re

def extract_context_expansion(text):
    pattern = r'<context_generation>(.*?)</context_generation>'
    match = re.search(pattern, text, re.DOTALL)
    if match:
        return match.group(1).strip()
    else:
        return f"Context expansion section not found. The original text: {text}"

In [7]:
def extract_test_cases(text):
    # Pattern to match assert statements
    pattern = r'assert\s+[\w_]+\(.*?\).*?(?=[\n<]|$)'
    
    # Find all matches
    test_cases = re.findall(pattern, text)
    
    # Group test cases by task
    grouped_tests = []
    current_group = []
    
    for test in test_cases:
        if current_group and not test.startswith(current_group[-1].split('(')[0]):
            grouped_tests.append(current_group)
            current_group = []
        current_group.append(test)
    
    if current_group:
        grouped_tests.append(current_group)
        print("test cases len: ", len(grouped_tests))
    
    return grouped_tests

text = '''<|start_header_id|>user<|end_header_id|>\n\nYou are an expert Python programmer, and here is your task:\nWrite a function to find the similar elements from the given two tuple lists.\nYour code should pass the following tests:\nassert similar_elements((3, 4, 5, 6),(5, 7, 4, 10)) == (4, 5)\nassert similar_elements((1, 2, 3, 4),(5, 4, 3, 7)) == (3, 4)\nassert similar_elements((11, 12, 14, 13),(17, 15, 14, 13)) == (13, 14)<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n```python\ndef similar_elements(test_tup1, test_tup2):\n res = tuple(set(test_tup1) & set(test_tup2))\n return (res) \n```<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nYou are an expert Python programmer, and here is your task:\nWrite a python function to identify non-prime numbers.\nYour code should pass the following tests:\nassert is_not_prime(2) == False\nassert is_not_prime(10) == True\nassert is_not_prime(35) == True<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n```python\nimport math\ndef is_not_prime(n):\n result = False\n for i in range(2,int(math.sqrt(n)) + 1):\n if n % i == 0:\n result = True\n return result\n```<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nYou are an expert Python programmer, and here is your task:\nWrite a function to find the largest integers from a given list of numbers using heap queue algorithm.\nYour code should pass the following tests:\nassert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],3)==[85, 75, 65] \nassert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],2)==[85, 75] \nassert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],5)==[85, 75, 65, 58, 35]<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n```python\nimport heapq as hq\ndef heap_queue_largest(nums,n):\n largest_nums = hq.nlargest(n, nums)\n return largest_nums\n```<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nYou are an expert Python programmer, and here is your task:\nWrite a function to create the next bigger number by rearranging the digits of a given number.\nYour code should pass the following tests:\nassert rearrange_bigger(12)==21\nassert rearrange_bigger(10)==False\nassert rearrange_bigger(102)==120<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n```python"'''
out = extract_test_cases(text)
print(out[-1])

test cases len:  4
['assert rearrange_bigger(12)==21', 'assert rearrange_bigger(10)==False', 'assert rearrange_bigger(102)==120']


In [8]:
def extract_functions(text):
    # Extract imports
    import_pattern = r'^(?:from\s+[\w.]+\s+import\s+(?:[\w.]+(?:\s*,\s*[\w.]+)*|\*)|import\s+(?:[\w.]+(?:\s*,\s*[\w.]+)*))(?:\s+as\s+[\w.]+)?'
    imports = re.findall(import_pattern, text, re.MULTILINE)
    
    # Extract functions
    function_pattern = r"(def\s+\w+\s*\(.*?\):(?:\s*['\"][\s\S]*?['\"])?\s*(?:(?!def\s)[\s\S])*?(?=\ndef|\Z))"
    functions = re.findall(function_pattern, text, re.MULTILINE | re.DOTALL)
    
    def clean_code(code):
        # Remove docstrings
        code = re.sub(r'"""[\s\S]*?"""|\'\'\'[\s\S]*?\'\'\'', '', code)
        # Remove comments
        code = re.sub(r'#.*', '', code)
        # Remove empty lines and trailing whitespace
        code = '\n'.join(line for line in code.splitlines() if line.strip())
        return code
    
    cleaned_imports = [clean_code(imp) for imp in imports]
    cleaned_functions = [clean_code(func) for func in functions]
    
    # Combine imports and functions
    cleaned_code = '\n'.join(cleaned_imports)
    if cleaned_imports and cleaned_functions:
        cleaned_code += '\n\n'
    cleaned_code += '\n\n'.join(cleaned_functions)
    
    return cleaned_code


In [38]:
class AdvancedLLaMACRVFramework:
    def __init__(self, model, tokenizer, layer_idx = 10):
        self.model = model
        self.tokenizer = tokenizer
        self.text_generator = TextGenerator(model, tokenizer)
        self.crv_generator = CRVGenerator(model, tokenizer, max_length=configs.MAX_LENGTH)
        self.memory_manager = MemoryManager(model, max_memories=5)
        self.layer_idx = layer_idx


    def generate_thought_trajectories(self, input_query, test_cases=None, max_new_tokens=1000):
        prompt_template = f"""
        <|begin_of_text|><|start_header_id|>system<|end_header_id|>
        
        \n\nYou are an expert Python programmer designed to provide standard, accurate,and fully working codes, and here is your task:\n
        \nWrite a function to find the similar elements from the given two tuple lists.\nYour code should pass the following tests:\nassert similar_elements((3, 4, 5, 6),(5, 7, 4, 10)) == (4, 5)\nassert similar_elements((1, 2, 3, 4),(5, 4, 3, 7)) == (3, 4)\nassert similar_elements((11, 12, 14, 13),(17, 15, 14, 13)) == (13, 14)<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n```python\ndef similar_elements(test_tup1, test_tup2):\n res = tuple(set(test_tup1) & set(test_tup2))\n return (res) \n```<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nYou are an expert Python programmer, and here is your task:\nWrite a python function to identify non-prime numbers.\nYour code should pass the following tests:\nassert is_not_prime(2) == False\nassert is_not_prime(10) == True\nassert is_not_prime(35) == True<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n```python\nimport math\ndef is_not_prime(n):\n result = False\n for i in range(2,int(math.sqrt(n)) + 1):\n if n % i == 0:\n result = True\n return result\n```<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nYou are an expert Python programmer, and here is your task:\nWrite a function to find the largest integers from a given list of numbers using heap queue algorithm.\nYour code should pass the following tests:\nassert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],3)==[85, 75, 65] \nassert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],2)==[85, 75] \nassert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],5)==[85, 75, 65, 58, 35]<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n```python\nimport heapq as hq\ndef heap_queue_largest(nums,n):\n largest_nums = hq.nlargest(n, nums)\n return largest_nums\n```<|eot_id|>
        Enable code_interpreter tool.
        <|eot_id|><|start_header_id|>user<|end_header_id|>
        
        \n\nYou are an expert Python programmer, and here is your task:\n{input_query}.\nYour code must pass these test cases:{test_cases}
        
        \n\nYour outputs must follow this structure:

        Identify the core components of this problem.
        1. Identify potential edge cases and tricky parts.
        2. Write 2 short test cases for the edge cases and tricky parts.
        
        <chain_of_thoughts>
        1. you must consider the edge cases according to the problem statement.
        2. Begin with a <thinking> section.
        3. Inside the thinking section:
           a. Write the topic name of the query, the name of the algorithm if necessary.
           b. Draft an answer as an expert.
           b. Briefly analyze the question and outline your approach.
           c. Present a clear plan of steps to solve the problem.
           d. Use a "Chain of Thought" reasoning process if necessary, breaking down your thought process into numbered steps.
        4. Include a <reflection> section for each idea where you:
           a. Review your reasoning.
           b. Check for potential errors or oversights.
           c. Confirm or adjust your conclusion if necessary.
        5. Be sure to close all reflection sections.
        6. Close the thinking section with </thinking>.
        7. Provide your final answer in an <output> section.        
        </chain_of_thoughts>

        <chain_of_thought_selection>
        you must consider the edge cases according to the problem statement and select the most promising chain of thought that solves the edge cases (not necessarily the simplest nor the standard approach).
        </chain_of_thought_selection>

        <solution>
        1. As a Python expert, generate the Python code and make sure it solves the edge cases while keeping it efficient.
        2. the internal steps must produce the required output.
        </solution>

        Include a <reflection> section for the selected solution where if it is not correct, modify or if necessary, rewrite the solution and pay attention to the input problem.
           a. Review your reasoning.
           b. Check for potential errors or oversights according to the problem. you must consider the edge cases according to the problem. Make sure it is not overcomplicated.
           c. Confirm or adjust your conclusion if necessary.
        4. Be sure to close all reflection sections.
        
        <context_generation>
        1. Rewrite the problem.
        2. Rewrite the edge cases and tricky parts in one short sentence
        2. Generate a very accurate and minimal Python code/pseudocode for the final solution. Ensure that the final solution is minimal and accurate.
        </context_generation>
        <|eot_id|>
        <|start_header_id|>assistant<|end_header_id|>\n\n"
        """
        # <|eot_id|><|start_header_id|>user<|end_header_id|>\n\n
        # <|start_header_id|>system<|end_header_id|>
        
        # <|eot_id|>
        # ```python
        # print("prompt: ", prompt)
        generated_text = self.text_generator.generate_text(
            prompt_template,
            max_new_tokens=max_new_tokens,
            num_return_sequences = 1,
            output_file="data/results.csv",
            # stop_sequences=["The end", ".\n\n"],
        )
        return generated_text
    
    def extract_hidden_states(self, context):
        best_crv, seq_length = self.crv_generator.generate_crvs(
            context, crv_layers=crv_layers, max_length=configs.MAX_LENGTH
        )
        return best_crv, seq_length  # Return the hidden state and its len

    def generate_crv(self, hidden_states, seq_length):
        # return torch.mean(hidden_states, dim=1)
        return hidden_states, seq_length
        
    def final_generation(self, original_query, test_cases, crv, seq_length, max_new_tokens=250):

        query=f"""<|start_header_id|>user<|end_header_id|>\n\nYou are an expert Python programmer, and here is your task:\n{original_query}.\nYour code should pass the following tests:{test_cases}"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n```python"""
        # Combine original query and CRV
        self.memory_manager.add_memory(
        crv, seq_length, layer_idx=self.layer_idx, crv_layers=crv_layers
    )

        # model.model.set_post_concat_crv(True)
        self.memory_manager.set_concat_positions(0, start_pos=0, end_pos=seq_length)
        self.memory_manager.apply_memory_to_model(0)
        generated_text = self.text_generator.generate_text(
            query,
            max_new_tokens=max_new_tokens,
            num_return_sequences = 1,
            output_file="data/results.csv",
            # stop_sequences=["The end", ".\n\n"],
        )
        # print(generated_text)
        print('==' * 50)
        return generated_text

In [39]:
from datasets import load_from_disk, Dataset
from tqdm import tqdm
import pandas as pd
from typing import List, Dict, Any

def evaluate_model(model, tokenizer, dataset: Dataset, layer_indices: List[int], num_examples: int = -1) -> pd.DataFrame:
    results = []
    
    for layer_idx in tqdm(layer_indices, desc="Processing layer indices"):
        framework = AdvancedLLaMACRVFramework(model, tokenizer, layer_idx=layer_idx)
        
        for i, instance in enumerate(tqdm(dataset, desc=f"Processing instances for layer {layer_idx}")):
            if num_examples != -1 and i >= num_examples:
                break
            
            query = instance['query'][0]
            context = instance['context'][0]
            test_cases = '\n'.join(extract_test_cases(instance['input_final_prompts'][0])[-1])
            
            trajectories_and_context = framework.generate_thought_trajectories(query, test_cases, max_new_tokens=1000)
            context_expansion = extract_context_expansion(trajectories_and_context)
            
            hidden_states, seq_len = framework.extract_hidden_states(context_expansion)
            crv, seq_len = framework.generate_crv(hidden_states, seq_len)
            
            final_output = framework.final_generation(query, test_cases, crv, seq_len, max_new_tokens=250)
            extracted_functions = extract_functions(final_output)
            
            result = {
                'layer_idx': layer_idx,
                'instance_id': i,  # This should be the index in the dataset
                'query': query,
                'context': context,
                'test_cases': test_cases,
                'final_output': final_output,
                'extracted_functions': extracted_functions
            }
            results.append(result)
    
    return pd.DataFrame(results)



In [40]:
from datasets import Dataset, concatenate_datasets
def add_parsed_functions_to_dataset(dataset: Dataset, results_df: pd.DataFrame, layer_indices: List[int]) -> Dataset:
    # Convert results DataFrame to a dictionary
    results_dict = results_df.to_dict('records')
    
    # Create a dictionary to store new columns
    new_columns = {
        f'final_output_layer_{layer}': [None] * len(dataset) for layer in layer_indices
    }
    new_columns.update({
        f'extracted_functions_layer_{layer}': [None] * len(dataset) for layer in layer_indices
    })
    
    # Populate new columns
    for result in results_dict:
        instance_id = result['instance_id']
        layer = result['layer_idx']
        if 0 <= instance_id < len(dataset):
            new_columns[f'final_output_layer_{layer}'][instance_id] = result['final_output']
            new_columns[f'extracted_functions_layer_{layer}'][instance_id] = result['extracted_functions']
    
    # Create a new dataset with only the new columns
    new_dataset = Dataset.from_dict(new_columns)
    
    # Combine the original dataset with the new dataset
    updated_dataset = concatenate_datasets([dataset, new_dataset], axis=1)
    
    return updated_dataset

In [42]:
loaded_dataset = load_from_disk("data/processed_meta_llama_dataset")

# Define layer indices to evaluate
layer_indices = [15]
num_examples=1
# Evaluate model
results_df = evaluate_model(model, tokenizer, loaded_dataset, layer_indices, num_examples=num_examples)

# Add parsed functions to the dataset
updated_dataset = add_parsed_functions_to_dataset(loaded_dataset, results_df, layer_indices)

print(f"Type of updated_dataset: {type(updated_dataset)}")
print(f"Number of rows in updated_dataset: {len(updated_dataset)}")


# Save the updated dataset
updated_dataset.save_to_disk("data/processed_meta_llama_dataset_with_results")

Processing layer indices:   0%|                           | 0/1 [00:00<?, ?it/s]
Processing instances for layer 15:   0%|                | 0/500 [00:00<?, ?it/s][AThe attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Both `max_new_tokens` (=1000) and `max_length`(=50) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


test cases len:  4
 # This function checks if a given tuple is distinct or not.
def check_distinct(tup):
    return len(tup) == len(set(tup))
"

    assert check_distinct((1, 4, 5, 6, 1, 4)) == False
assert check_distinct((1, 4, 5, 6)) == True
assert check_distinct((2, 3, 4, 5, 6)) == True


INFO:utils.utils:The input received is a query
INFO:controller.memory_manager:Added new memory. Current number of memories: 1
INFO:controller.memory_manager:Set concat positions for memory 0: start=0, end=tensor([111], device='cuda:0')


INFO:controller.memory_manager:Applied memory 0 to model
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Both `max_new_tokens` (=250) and `max_length`(=50) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)



def check_distinct(tup):
    """
    This function checks if a given tuple is distinct or not.
    
    Args:
        tup (tuple): The input tuple to be checked.
    
    Returns:
        bool: True if the tuple is distinct, False otherwise.
    """
    # Convert the tuple to a set. A set in Python is an unordered collection of unique elements.
    # If the set has the same number of elements as the original tuple, it means all elements were unique.
    return len(tup) == len(set(tup))
```



Processing instances for layer 15:   0%|      | 1/500 [00:16<2:13:11, 16.01s/it][A
Processing layer indices: 100%|███████████████████| 1/1 [00:16<00:00, 16.02s/it]

Type of updated_dataset: <class 'datasets.arrow_dataset.Dataset'>
Number of rows in updated_dataset: 500





Saving the dataset (0/1 shards):   0%|          | 0/500 [00:00<?, ? examples/s]

In [50]:
from datasets import load_dataset

def main():    
    # Print column names
    print("\nColumns in the updated dataset:")
    print(updated_dataset.column_names)
    # loaded_dataset2 = load_from_disk("data/processed_meta_llama_dataset_with_results")
    loaded_dataset2 = load_dataset(
    "meta-llama/Meta-Llama-3.1-8B-Instruct-evals",
    name="Meta-Llama-3.1-8B-Instruct-evals__mbpp__details",
    split="latest"
    )



    # Print selected columns from the first 5 rows of the updated dataset
    print("First 5 rows of the updated dataset (selected columns):")
    for i, example in enumerate(loaded_dataset2.select(range(num_examples))):
        print(f"\nExample {i + 1}:")
        # Print original columns
        print(f"input_correct_responses: {example['input_correct_responses']}...")
        # print(f"extracted_functions_layer_15: {example['extracted_functions_layer_15']}...")
        
        # Print new columns for each layer
        for layer in layer_indices:
            print(f"\nLayer {layer}:")
            final_output = example.get(f'final_output_layer_{layer}')
            extracted_functions = example.get(f'extracted_functions_layer_{layer}')
            
            if final_output:
                print(f"Final Output: {final_output[:100]}...")
            else:
                print("Final Output: None")
            
            if extracted_functions:
                print(f"Extracted Functions: {extracted_functions[:100]}...")
            else:
                print("Extracted Functions: None")


if __name__ == "__main__":
    main()



Columns in the updated dataset:
['task_type', 'task_name', 'subtask_name', 'input_question', 'input_choice_list', 'input_final_prompts', 'input_correct_responses', 'output_prediction_text', 'output_parsed_answer', 'output_choice_completions', 'output_choice_negative_log_likelihoods', 'output_metrics', 'is_correct', 'input_question_hash', 'input_final_prompts_hash', 'benchmark_label', 'eval_config', 'context', 'query', 'final_output_layer_15', 'extracted_functions_layer_15']


Using the latest cached version of the dataset since meta-llama/Meta-Llama-3.1-8B-Instruct-evals couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'Meta-Llama-3.1-8B-Instruct-evals__mbpp__details' at /home/sg23454/.cache/huggingface/datasets/meta-llama___meta-llama-3.1-8_b-instruct-evals/Meta-Llama-3.1-8B-Instruct-evals__mbpp__details/0.0.0/0f783b11d6240fc4f669dd95a842173b036e6799 (last modified on Sat Sep  7 05:25:57 2024).


First 5 rows of the updated dataset (selected columns):

Example 1:
input_correct_responses: None...

Layer 15:
Final Output: None
Extracted Functions: None


In [34]:
import pandas as pd
from tqdm import tqdm
from typing import List, Dict, Any

def load_from_disk(path: str) -> pd.DataFrame:
    return pd.read_csv(path)  # Adjust if your data is in a different format

def evaluate_model(model, tokenizer, dataset: pd.DataFrame, layer_indices: List[int], num_examples: int = -1) -> pd.DataFrame:
    results = []
    
    for layer_idx in tqdm(layer_indices, desc="Processing layer indices"):
        framework = AdvancedLLaMACRVFramework(model, tokenizer, layer_idx=layer_idx)
        
        for i, row in dataset.iterrows():
            if num_examples != -1 and i >= num_examples:
                break
            
            query = row['query']
            context = row['context']
            test_cases = '\n'.join(extract_test_cases(row['input_final_prompts'])[-1])
            
            trajectories_and_context = framework.generate_thought_trajectories(query, test_cases, max_new_tokens=1000)
            context_expansion = extract_context_expansion(trajectories_and_context)
            
            hidden_states, seq_len = framework.extract_hidden_states(context_expansion)
            crv, seq_len = framework.generate_crv(hidden_states, seq_len)
            
            final_output = framework.final_generation(query, test_cases, crv, seq_len, max_new_tokens=250)
            extracted_functions = extract_functions(final_output)
            
            results.append({
                'layer_idx': layer_idx,
                'instance_id': i,
                'final_output': final_output,
                'extracted_functions': extracted_functions
            })
    
    return pd.DataFrame(results)

def add_parsed_functions_to_dataset(dataset: pd.DataFrame, results_df: pd.DataFrame, layer_indices: List[int]) -> pd.DataFrame:
    for layer in layer_indices:
        dataset[f'final_output_layer_{layer}'] = None
        dataset[f'extracted_functions_layer_{layer}'] = None
    
    for _, row in results_df.iterrows():
        instance_id = row['instance_id']
        layer = row['layer_idx']
        dataset.loc[instance_id, f'final_output_layer_{layer}'] = row['final_output']
        dataset.loc[instance_id, f'extracted_functions_layer_{layer}'] = row['extracted_functions']
    
    return dataset



In [51]:
def main():
    # Load dataset
    loaded_dataset = load_from_disk("data/processed_meta_llama_dataset")
    
    # Define layer indices to evaluate
    layer_indices = [15]
    
    # Evaluate model on a subset
    num_examples = 2  # Adjust as needed
    subset_dataset = loaded_dataset.head(num_examples)
    results_df = evaluate_model(model, tokenizer, subset_dataset, layer_indices, num_examples=num_examples)
    
    # Add parsed functions to the dataset
    updated_dataset = add_parsed_functions_to_dataset(subset_dataset, results_df, layer_indices)
    
    print(f"Type of updated_dataset: {type(updated_dataset)}")
    print(f"Number of rows in updated_dataset: {len(updated_dataset)}")
    
    # Save the updated dataset
    updated_dataset.to_csv("data/processed_meta_llama_dataset_with_results.csv", index=False)
    
    # Print column names
    print("\nColumns in the updated dataset:")
    print(updated_dataset.columns.tolist())
    
    # Print selected columns from the first 5 rows of the updated dataset
    print("First 5 rows of the updated dataset (selected columns):")
    for i, row in updated_dataset.iterrows():
        print(f"\nExample {i + 1}:")
        print(f"input_correct_responses: {row['input_correct_responses']}")
        print(f"extracted_functions_layer_15: {row['extracted_functions_layer_15']}")
        
        for layer in layer_indices:
            print(f"\nLayer {layer}:")
            final_output = row[f'final_output_layer_{layer}']
            extracted_functions = row[f'extracted_functions_layer_{layer}']
            
            print(f"Final Output: {final_output[:100]}..." if final_output else "Final Output: None")
            print(f"Extracted Functions: {extracted_functions[:100]}..." if extracted_functions else "Extracted Functions: None")

if __name__ == "__main__":
    main()


AttributeError: 'Dataset' object has no attribute 'head'