In [478]:
import pandas as pd
import re
from google import genai
import ast
from itertools import permutations, chain

In [479]:
def load_data(path="../data/connections.csv"):
    """Read the Connections puzzle CSV into a DataFrame.

    Args:
        path: Location of the CSV file. Defaults to the project's
            ``../data/connections.csv`` (relative to the notebook's working directory).

    Returns:
        The raw DataFrame as parsed by ``pd.read_csv``. List-valued columns are
        still strings at this point and are converted by the caller.
    """
    return pd.read_csv(path)

# Load the puzzles, then turn the list-valued columns (stored as Python-literal
# strings) back into real lists; non-string cells such as NaN are left as-is.
df = load_data()
for list_column in ("groups", "answers"):
    df[list_column] = df[list_column].apply(
        lambda cell: ast.literal_eval(cell) if isinstance(cell, str) else cell
    )
    

In [480]:
def prompt(question, api_key=None, model="gemini-2.0-flash"):
    """Send the Connections instructions plus ``question`` to Gemini and return the raw response.

    Args:
        question: Text appended after the fixed instructions (the puzzle words
            and any feedback from earlier attempts).
        api_key: Gemini API key. When omitted, it is read from the
            ``GEMINI_API_KEY`` environment variable, falling back to ``GOOGLE_API_KEY``.
        model: Name of the Gemini model to query.

    Returns:
        The response object returned by ``client.models.generate_content``.

    Raises:
        RuntimeError: If no API key is passed and none is set in the environment.
    """
    import os

    # Never hardcode credentials: a key that was committed to source must be
    # treated as leaked and revoked. Read it from the environment instead.
    if api_key is None:
        api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        raise RuntimeError(
            "No Gemini API key: set GEMINI_API_KEY (or GOOGLE_API_KEY) or pass api_key=..."
        )

    # The backslash continuations keep this a single string (including the
    # indentation of the continued lines), exactly as sent before.
    prompt_text = f"You are playing the NY Times Connections game. I will give you a set of 16 words, and I want you to provide 4 sets of exactly 4 words that are connected in some way. \
                I want you to group the words in such a way that each group has a common theme. Think about your answers carefully, as you will only have one chance to submit your answer. \
                Here is an example: If the words are: 'BUCKS, HAIL, JAZZ, SHIFT, LEVEL, MOM, SNOW, RACECAR, SLEET, TAB, KAYAK, RETURN, OPTION, NETS, RAIN, HEAT', \
                a possible answer could be: 'answer: [['HAIL', 'RAIN', 'SLEET', 'SNOW'], ['BUCKS', 'HEAT', 'JAZZ', 'NETS'], ['OPTION', 'RETURN', 'SHIFT', 'TAB'], ['KAYAK', 'LEVEL', 'MOM', 'RACECAR']] and groups: ['WET WEATHER', 'NBA TEAMS', 'KEYBOARD KEYS', 'PALINDROMES']. \
                Give your answer strictly in the format (no other words): \
                    'Answer: [[4 words of group1], [4 words of group2], [4 words of group3], [4 words of group4]]  \
                    Group: [group1, group2, group3, group4].' \
                {question}"

    client = genai.Client(api_key=api_key)
    return client.models.generate_content(model=model, contents=[prompt_text])


In [481]:
def _safe_literal(snippet):
    """Parse a Python list literal from model output, returning None if it is malformed."""
    # literal_eval (unlike eval) never executes code, so untrusted model text is safe here.
    try:
        return ast.literal_eval(snippet)
    except (ValueError, SyntaxError, TypeError):
        return None


def generate_answers(question):
    """Ask the model to solve a Connections puzzle and parse its reply.

    Args:
        question: Text appended to the instruction prompt (the puzzle words plus
            any feedback from earlier attempts).

    Returns:
        ``(answer, group)``: ``answer`` is the list of word groups found after
        ``Answer:`` and ``group`` the list of group names found after ``Group:``.
        Either is ``None`` when missing from the reply or not a valid literal.
    """
    response = prompt(question)
    # The text can be None (e.g. an empty candidate); treat that as "no answer".
    text = response.candidates[0].content.parts[0].text or ""

    # DOTALL lets a list that the model wrapped over several lines still match.
    answer_match = re.search(r"Answer:\s*(\[\[.*?\]\])", text, re.DOTALL)
    # Accept group names quoted with either single or double quotes.
    group_match = re.search(r"Group:\s*(\[\s*['\"].*?['\"]\s*\])", text, re.DOTALL)

    answer = _safe_literal(answer_match.group(1)) if answer_match else None
    group = _safe_literal(group_match.group(1)) if group_match else None

    return answer, group
    

In [482]:
def _group_mistakes(generated_group, correct_group):
    """Count how many words of ``generated_group`` are misplaced relative to ``correct_group``."""
    # Words the model put in this group that do not belong to it.
    extra = sum(1 for word in generated_group if word not in correct_group)
    # Words of the correct group the model left out; catches short groups and
    # duplicated words, which the "extra" count alone would score as 0.
    missing = sum(1 for word in correct_group if word not in generated_group)
    # For 4 distinct generated words vs 4 distinct correct words both counts are equal.
    return max(extra, missing)


def evaluate_answers(gen_answers, answers):
    """Pair generated groups with correct groups and count the mistakes in each pair.

    Every permutation of ``answers`` is tried against ``gen_answers``, and the
    pairing with the fewest total mistakes is kept (the first one found on ties).

    Args:
        gen_answers: Word groups proposed by the model (list of lists of words),
            or ``None`` if the model's reply could not be parsed.
        answers: The correct word groups (list of lists of words).

    Returns:
        ``(best_match_pairs, best_mistakes)``: a list of
        ``(generated_group, correct_group)`` tuples and the mistake count for each
        pair, in the same order. A fully correct answer has all counts equal to 0.
    """
    # An unparseable answer (None) is treated as no groups at all, and a short
    # answer is padded with empty groups, so missing groups count as mistakes
    # instead of being silently dropped by zip().
    generated = list(gen_answers or [])
    generated += [[] for _ in range(len(answers) - len(generated))]

    best_match_pairs = []
    best_mistakes = []
    min_total_mistakes = float("inf")

    # Try every assignment of correct groups to generated groups.
    for perm in permutations(answers):
        temp_pairs = list(zip(generated, perm))
        temp_mistakes = [_group_mistakes(gen, correct) for gen, correct in temp_pairs]
        total_mistakes = sum(temp_mistakes)
        if total_mistakes < min_total_mistakes:
            min_total_mistakes = total_mistakes
            best_match_pairs = temp_pairs
            best_mistakes = temp_mistakes

    return best_match_pairs, best_mistakes


In [483]:
def reinforcement_prompt(row, max_attempts=5, reveal_correct_groups_after=1):
    """Repeatedly query the model on one puzzle, feeding back its mistakes, until it solves it.

    Args:
        row: A puzzle row providing ``"question"`` (the words as text),
            ``"answers"`` (the correct word groups) and ``"groups"`` (the
            correct group names).
        max_attempts: Maximum number of model calls before giving up.
        reveal_correct_groups_after: Once this many attempts have failed, the
            correct group names are added to the feedback as a hint.

    Returns:
        The model's word groups from the first fully correct attempt, or
        ``None`` if no attempt was correct within ``max_attempts``.
    """
    answers, groups = row["answers"], row["groups"]
    feedback = ""

    for attempt in range(max_attempts):
        question = "Here are the words" + " {" + row["question"] + "} " + feedback

        gen_answers, _gen_groups = generate_answers(question)

        # An unparseable or wrong-sized reply cannot be scored meaningfully
        # (zip() would silently drop missing groups), so ask again for the format.
        if not isinstance(gen_answers, list) or len(gen_answers) != len(answers):
            feedback = (
                f"\nYour last reply could not be parsed into exactly {len(answers)} groups. "
                "Follow the required 'Answer: [[...]] Group: [...]' format exactly.\n"
            )
            continue

        pairs, mistakes = evaluate_answers(gen_answers, answers)

        if sum(mistakes) == 0:
            print("Correct answer found!")
            return gen_answers

        # Always list the per-group mistakes; previously the hint branch replaced
        # them, leaving this header followed by no mistakes at all.
        feedback = "\nYour last attempt resulted in the following mistakes:\n"
        for i, ((gen_group, _correct_group), mistake_count) in enumerate(zip(pairs, mistakes), 1):
            feedback += f"Group {i}: {gen_group} → Mistakes: {mistake_count}\n"

        failed_attempts = attempt + 1
        if failed_attempts >= reveal_correct_groups_after:
            correct_group_names = [f"Group {i + 1}: {group}" for i, group in enumerate(groups)]
            feedback += "\nHint: The correct group names are:\n" + "\n".join(correct_group_names) + "\n"

        # The attempt limit is taken from max_attempts so the model is told the real budget.
        feedback += (
            "\nReview each set carefully and try to reduce the mistakes to 0. "
            "The mistake count is the number of words in your group that do not "
            "belong in the correct group. "
            f"You have a maximum of {max_attempts} attempts to reach the correct solution."
        )

    print("Failed to reach correct solution within max attempts.")
    return None

In [None]:
# Run the feedback loop on the puzzles at positions 1-14 (position 0 is skipped;
# NOTE(review): presumably because it is the puzzle used as the in-prompt example — confirm).
for _, puzzle_row in df.iloc[1:15].iterrows():
    reinforcement_prompt(puzzle_row)