# MODEL EVALUATION WITH CATALAN HELLASWAG

In [None]:
from tqdm import tqdm
import torch
import json
from transformers import AutoTokenizer, AutoModelForCausalLM
import datasets

# Pick the compute device first: CUDA when a GPU is visible, CPU otherwise
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")

# Tokenizer and causal language model are both loaded from the same
# "baiges/CatGPT" checkpoint; the model is then placed on the chosen device
tokenizer = AutoTokenizer.from_pretrained("baiges/CatGPT")
model = AutoModelForCausalLM.from_pretrained("baiges/CatGPT")
model = model.to(device)

# Catalan translation of HellaSwag, hosted at "pauhidalgoo/hellaswag-val-ca".
# The examples are read from the split named "train".
hellaswag_splits = None  # placeholder removed below; kept local to this cell
del hellaswag_splits
dataset = datasets.load_dataset("pauhidalgoo/hellaswag-val-ca")["train"]

# Define the evaluation function
def _mean_ending_loss(model, context_ids, ending_ids, loss_fct):
    """Return the mean per-token cross-entropy of an ending given its context.

    Args:
        model: Callable returning an object with a ``logits`` attribute when
            called on a tensor of token ids.
        context_ids: ``(1, context_len)`` tensor of context token ids.
        ending_ids: ``(1, ending_len)`` tensor of ending token ids; must hold
            at least one token.
        loss_fct: Loss module built with ``reduction='none'``.

    Returns:
        The average loss over the ending tokens as a Python float.
    """
    num_ending_tokens = ending_ids.size(1)
    input_ids = torch.cat([context_ids, ending_ids], dim=-1)

    with torch.no_grad():
        logits = model(input_ids).logits

        # The logit at position p scores the token at position p + 1, so the
        # ending tokens are scored by the window that starts one step before
        # the ending and stops one step before the end of the sequence.
        ending_logits = logits[:, -(num_ending_tokens + 1):-1, :]

        token_losses = loss_fct(
            ending_logits.reshape(-1, ending_logits.size(-1)),
            ending_ids.reshape(-1),
        )
        return token_losses.mean().item()


def evaluate_model(model, tokenizer, dataset, device):
    """Score a causal language model on a multiple-choice completion dataset.

    For every item, each candidate ending is appended to the context, the
    average per-token loss of the ending is computed, and the ending with the
    lowest average loss is taken as the prediction. Progress for each item is
    printed as the evaluation runs.

    Args:
        model: Causal language model; called on a tensor of token ids and
            expected to return an object exposing ``logits``.
        tokenizer: Callable as ``tokenizer(text, return_tensors="pt")`` and
            returning an object exposing ``input_ids``.
        dataset: Iterable of items with the keys ``'ctx'`` (context text),
            ``'endings'`` (candidate endings) and ``'label'`` (index of the
            correct ending, convertible with ``int``).
        device: Device the token tensors are moved to before the forward pass.

    Returns:
        The fraction of successfully processed items whose prediction matches
        the label, or ``0.0`` when no item could be processed (empty dataset
        or every item failed).
    """
    total_correct = 0
    total_items = 0

    # The loss module is stateless, so build it once instead of per ending
    loss_fct = torch.nn.CrossEntropyLoss(reduction='none')

    # Wrap the dataset itself (not the enumerate object) so tqdm can read its
    # length and display a total / ETA when the dataset is sized
    for i, item in enumerate(tqdm(dataset)):
        try:
            context = item['ctx']  # Get the context
            endings = item['endings']  # Get the possible endings
            correct_ending_idx = int(item['label'])  # Index of the correct ending

            # Tokenize the context and move it to the device
            context_input = tokenizer(context, return_tensors="pt").input_ids.to(device).long()

            avg_loss_values = []
            for ending in endings:
                try:
                    # Tokenize the ending and move it to the device
                    ending_input = tokenizer(ending, return_tensors="pt").input_ids.to(device).long()

                    # NOTE(review): the first ending token is dropped, which
                    # presumably assumes the tokenizer prepends a special
                    # token (e.g. BOS) -- confirm for this tokenizer; if it
                    # does not, the first real token of the ending is lost.
                    ending_ids = ending_input[:, 1:]

                    # Nothing left to score: the mean of an empty loss would
                    # be NaN, and NaN wins torch.argmin. Rank it last instead.
                    if ending_ids.size(1) == 0:
                        avg_loss_values.append(float('inf'))
                        continue

                    avg_loss_values.append(
                        _mean_ending_loss(model, context_input, ending_ids, loss_fct)
                    )
                except RuntimeError as e:
                    print(f"Error processing ending: {e}")
                    avg_loss_values.append(float('inf'))  # Consider high loss if there's an error

            # Find the ending with the lowest average loss per token
            predicted_idx = torch.argmin(torch.tensor(avg_loss_values)).item()

            # Check if the prediction is correct and update the counters
            is_correct = predicted_idx == correct_ending_idx
            if is_correct:
                total_correct += 1
            total_items += 1

            # Running accuracy over all items processed so far
            item_accuracy = total_correct / total_items

            # Print context, prediction, correctness, and running accuracy
            print(f"Context: {context}")
            print(f"Predicted Ending: {endings[predicted_idx]}")
            print(f"Correct Ending: {endings[correct_ending_idx]}")
            print(f"Is Correct: {'Yes' if is_correct else 'No'}")
            print(f"Accuracy so far: {item_accuracy:.4f}\n")

        except Exception as e:
            # Best-effort evaluation: report the failing item and move on;
            # skipped items are not counted in the accuracy
            print(f"Error processing item {i}: {e}")

    # Calculate the overall accuracy; guard against an empty dataset or a run
    # in which every item was skipped
    if total_items == 0:
        return 0.0
    return total_correct / total_items

# Run the evaluation over the loaded dataset and report the final score
accuracy = evaluate_model(
    model=model,
    tokenizer=tokenizer,
    dataset=dataset,
    device=device,
)
print(f"Final Accuracy: {accuracy:.4f}")