In [2]:
import os
import sys
import json
import numpy as np
import pandas as pd
from datasets import Dataset
from tqdm import tqdm
import random
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

In [3]:
SEED = 42

# Seed every random number generator the notebook relies on (Python's `random`,
# NumPy's legacy global RNG, and PyTorch on CPU and on all CUDA devices).
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(SEED)

# Ask cuDNN to use deterministic algorithms.
torch.backends.cudnn.deterministic = True

In [4]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(_device_name)

In [None]:
MODEL_PATH = "google/gemma-2-9b-it"

# The Gemma checkpoints are gated on the Hugging Face Hub, so an unauthenticated
# download fails with a 401. Read an access token from the environment instead
# of hardcoding it; when HF_TOKEN is unset this is None and the library falls
# back to credentials cached by `huggingface-cli login`.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Load the tokenizer first: it is a small download, so an authentication
# problem surfaces before the multi-gigabyte weights are requested.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.float16,
    token=HF_TOKEN,
).to(DEVICE)

OSError: You are trying to access a gated repo.
Make sure to have access to it at https://huggingface.co/google/gemma-2-9b.
401 Client Error. (Request ID: Root=1-6752bd31-631fb6c74c1c57ac69ac8a06;cb1f47f0-90d2-40a3-bcdf-ad0104b5c219)

Cannot access gated repo for url https://huggingface.co/google/gemma-2-9b/resolve/main/config.json.
Access to model google/gemma-2-9b is restricted. You must have access to it and be authenticated to access it. Please log in.

In [None]:
# Input files, relative to the notebook's working directory.
SUBMISSION_CSV = "data/sample_submission.csv"
FEW_SHOTS_CSV = "data/few_shots.csv"

# `data` supplies the "id" / "text" columns read later on; `few_shots` supplies
# the "sentence" / "sentence_shuffled" columns used as demonstrations.
data = pd.read_csv(SUBMISSION_CSV)
few_shots = pd.read_csv(FEW_SHOTS_CSV)

In [None]:
# Prompt template for a single passage. `{scrambled_passage}` is filled in with
# str.format; the template starts with a blank line and ends right after the
# "Reordered Passage:" header so the model's continuation is the answer.
PROMPT = (
    "\n"
    "Scrambled Passage:\n"
    "{scrambled_passage}\n"
    "\n"
    "Reordered Passage:\n"
)

In [None]:
# Task instructions shown to the model. Gemma's chat template does not accept a
# "system" role and requires strictly alternating user/assistant turns, so this
# text is prepended to the first user turn instead of being sent as separate
# system messages.
INSTRUCTIONS = """
                    Someone has scrambled the words in classic tales! Your task is to put those words back in order, minimizing the perplexity of each passage.
                    
                    You will be given a passage with words scrambled. You need to unscramble the words in the passage to make it coherent. You can use the context of the passage to help you unscramble the words.

                    Here are the rules:
                    - You can only unscramble the words in the passage.
                    - You cannot add or remove words from the passage.
                    - You cannot use any external resources.
                    - You cannot use any punctuation marks.
                    - You cannot use any special characters.
                    
                    Here are the examples of scrambled and unscrambled passages:
                    """

# Sampling with a fixed random_state yields the same rows on every call, so the
# demonstrations are drawn once here rather than twice per input row.
few_shot_rows = few_shots.sample(5, random_state=SEED)

# Express each demonstration as a user/assistant exchange so that every
# scrambled passage is followed directly by its own unscrambled answer.
few_shot_turns = []
for scrambled, unscrambled in zip(
    few_shot_rows["sentence_shuffled"], few_shot_rows["sentence"]
):
    few_shot_turns.append(
        {"role": "user", "content": PROMPT.format(scrambled_passage=scrambled)}
    )
    few_shot_turns.append({"role": "assistant", "content": unscrambled})
few_shot_turns[0]["content"] = INSTRUCTIONS + few_shot_turns[0]["content"]

# One record per input row: the shared demonstrations followed by a final user
# turn holding the passage to unscramble.
processed = []

for i, row in tqdm(data.iterrows(), total=len(data)):
    prompt = PROMPT.format(scrambled_passage=row["text"])
    processed.append(
        {
            # Copy the shared turns so records do not alias the same dicts.
            "messages": [dict(turn) for turn in few_shot_turns]
            + [{"role": "user", "content": prompt}]
        }
    )

In [None]:
# Display the first record to sanity-check the assembled chat messages.
processed[0]

In [None]:
# Convert the list of prompt records into a `datasets.Dataset` (via a pandas
# frame) so it can be tokenized with `Dataset.map`. Note that this rebinds
# `processed` from a list to a Dataset.
processed_frame = pd.DataFrame(processed)
processed = Dataset.from_pandas(processed_frame)
processed

In [None]:
def formatting_prompts(examples):
    """Render each conversation in a batch into a single prompt string.

    Args:
        examples: Batch mapping whose "messages" entry holds one list of chat
            messages (dicts with "role" and "content") per example.

    Returns:
        A list with one rendered prompt string per conversation, in input order.

    Uses the module-level `tokenizer` to apply the model's chat template.
    """
    return [
        tokenizer.apply_chat_template(
            # Render one conversation at a time; passing the whole batch on
            # every iteration would yield a list of lists instead of strings.
            conversation,
            tokenize=False,
            # The conversations end on a user turn, so append the header that
            # opens the model's reply; otherwise generation continues the
            # user's text.
            add_generation_prompt=True,
        )
        for conversation in examples["messages"]
    ]

def tokenize(element):
    """Tokenize the rendered prompts of a batch.

    The batch is first turned into prompt strings with `formatting_prompts`,
    then encoded by the module-level `tokenizer` without truncation or padding.

    Returns:
        A dict holding only the "input_ids" and "attention_mask" entries of the
        tokenizer output.
    """
    prompt_texts = formatting_prompts(element)
    encoded = tokenizer(
        prompt_texts,
        truncation=False,
        padding=False,
        return_overflowing_tokens=False,
        return_length=False,
    )
    kept_fields = ("input_ids", "attention_mask")
    return {field: encoded[field] for field in kept_fields}

# Tokenize every prompt in batches across 4 worker processes. The original
# columns are dropped so only the fields returned by `tokenize` remain.
map_options = {
    "remove_columns": list(processed.column_names),
    "batched": True,
    "num_proc": 4,
    "load_from_cache_file": True,
    "desc": "Running tokenizer on prompts",
}
tokenized_data = processed.map(tokenize, **map_options)

In [None]:
# Switch the model to evaluation mode before running generation.
model.eval()

# Keyword arguments forwarded to `model.generate`.
generation_config = {
    # Upper bound on the number of tokens produced per passage.
    "max_new_tokens": 256,
    # Greedy decoding keeps the output deterministic. Sampling knobs
    # (`temperature`, `top_p`) are deliberately omitted: they have no effect
    # when do_sample is False and only trigger warnings.
    "do_sample": False,
    # `repetition_penalty` and `no_repeat_ngram_size` are deliberately omitted:
    # both also take the prompt tokens into account, so they discourage the
    # model from reusing the words (and token pairs) of the scrambled input,
    # which is exactly what this task requires.
}

def generate_unscrambled_passage(model, tokenizer, input_text, device):
    """Generate the model's reordering for one prompt.

    Args:
        model: Causal language model exposing `generate`.
        tokenizer: Tokenizer matching `model`; used to encode the prompt and
            decode the generated tokens.
        input_text: Fully formatted prompt text.
        device: Device the encoded prompt is moved to before generation.

    Returns:
        The newly generated text, stripped of surrounding whitespace. The
        prompt itself is never part of the result.

    Generation options come from the module-level `generation_config`.
    """
    marker = "Reordered Passage:"

    inputs = tokenizer(input_text, return_tensors="pt").to(device)
    prompt_length = inputs["input_ids"].shape[-1]

    outputs = model.generate(
        **inputs,
        **generation_config
    )

    # The returned sequence starts with the prompt tokens. Drop them by
    # position instead of searching the decoded text for the marker: a text
    # search picks the wrong segment when the prompt contains the marker more
    # than once, and returns the whole prompt when it is absent.
    new_tokens = outputs[0][prompt_length:]
    unscrambled_passage = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    # If the model echoes the header before answering, remove the echo.
    if unscrambled_passage.startswith(marker):
        unscrambled_passage = unscrambled_passage[len(marker):].strip()

    return unscrambled_passage


In [None]:
# Run generation for every input row. A failure on one row is reported and the
# row's original (scrambled) text is kept, so the output always has one entry
# per input id.
results = []
for i, record in tqdm(data.iterrows(), total=len(data)):
    input_text = PROMPT.format(scrambled_passage=record["text"])

    try:
        answer = generate_unscrambled_passage(model, tokenizer, input_text, DEVICE)
    except Exception as e:
        print(f"Error processing row {i}: {e}")
        answer = record["text"]

    results.append({"id": record["id"], "text": answer})

In [None]:
# Write the collected results to CSV in the working directory, without the
# DataFrame index column.
OUTPUT_CSV = "unscrambled_passages.csv"

output_df = pd.DataFrame(results)
output_df.to_csv(OUTPUT_CSV, index=False)

print("Unscrambling complete. Results saved to unscrambled_passages.csv")