# Connections

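[Connections](https://www.nytimes.com/games/connections) is a word puzzle from The New York Times: you are given 16 words and must sort them into 4 groups of 4, where each group shares a hidden category. This notebook asks an OpenAI chat model (`gpt-4o`) to solve the puzzles in `connections_prompts.jsonl` and scores its groups against the reference solutions.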

In [6]:
# load a jsonl file
import json

def load_jsonl(file_path, encoding="utf-8"):
    """Read a JSON Lines file into a list of parsed records.

    Args:
        file_path: path to a file containing one JSON document per line.
        encoding: text encoding of the file. Defaults to UTF-8 so non-ASCII
            characters (e.g. curly quotes in category names) decode the same
            on every platform instead of depending on the locale default.

    Returns:
        List of decoded JSON values, in file order. Blank lines are skipped.
    """
    with open(file_path, "r", encoding=encoding) as file:
        # Skip whitespace-only lines (e.g. an extra trailing newline), which
        # json.loads would otherwise reject.
        return [json.loads(line) for line in file if line.strip()]

# Puzzle records. Later cells read each record's "words" (the 16 words) and
# "categories" (reference category -> its 4 words).
ds = load_jsonl('connections_prompts.jsonl')

In [9]:
# Peek at the words of the first puzzle.
print(ds[0]["words"])



['schmaltz', 'knuckles', 'corn', 'sap', 'loose', 'smile', 'chump', 'egg', 'duct', 'pipe', 'climate', 'sea', 'cheese', 'window', 'drain', 'sewer']


## Naive approach

In [25]:
from openai import OpenAI

# Shared client; `call_openai` sends its chat requests through it.
# NOTE(review): credentials are presumably read from the environment
# (e.g. OPENAI_API_KEY) — confirm; never hardcode a key here.
client = OpenAI()

In [91]:

# Instructions sent as the system message. Adjacent string literals are
# concatenated with no separator, so each sentence ends with a trailing space.
# The JSON example must itself be valid JSON and show all four groups.
system_prompt = (
    """The game "Connections" is a word game where you start with 16 words and need to group """
    """them into 4 groups of 4. Each grouping has a category that unambiguously groups the four words together. """
    """Each puzzle has exactly one solution. Watch out for words that seem to belong to multiple categories. """
    """You will be given 16 words. Output 4 groups of 4 words and the categories to which they belong. """
    """The results should be in JSON format as follows:
    {"category1": ["word1", "word2", "word3", "word4"], "category2": ["word5", "word6", "word7", "word8"], "category3": ["word9", "word10", "word11", "word12"], "category4": ["word13", "word14", "word15", "word16"]}
    """
)

# Template for the user message; `{words}` is filled with a puzzle's word list.
user_prompt = "Here are the 16 words: {words}"

In [133]:
def call_openai(system_prompt, user_prompt, model="gpt-4o", temperature=0.7, openai_client=None):
    """Send one system + user message pair to the chat completions API.

    Args:
        system_prompt: content of the system message.
        user_prompt: content of the user message.
        model: chat model name.
        temperature: sampling temperature passed to the API.
        openai_client: client to send the request through. When None, the
            notebook-level `client` is used; pass a stub to test offline.

    Returns:
        The text content of the first choice. The request sets
        `response_format={"type": "json_object"}`, so this is expected to be
        a JSON string.
    """
    api = openai_client if openai_client is not None else client
    response = api.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=temperature,
        response_format={"type": "json_object"},
    )
    return response.choices[0].message.content

In [93]:
# Solve the first puzzle once and decode the model's JSON answer.
res = call_openai(
    system_prompt=system_prompt,
    user_prompt=user_prompt.format(words=ds[0]["words"]),
)
generation = json.loads(res)

In [99]:
# Show each (category, words) pair proposed by the model.
for group in generation.items():
    print(group)

('Body Parts', ['knuckles', 'smile', 'corn', 'sap'])
('Food', ['egg', 'cheese', 'corn', 'sap'])
('Plumbing', ['duct', 'pipe', 'drain', 'sewer'])
('Miscellaneous', ['schmaltz', 'loose', 'chump', 'climate'])


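Let's create a function that scores the generated groups against the reference solution. Note that the naive answer above is not even a valid partition: `corn` and `sap` each appear in two groups, while `window` and `sea` are never used.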

In [105]:
# Keep only the model's word groups, dropping its category names.
flat_generation = [*generation.values()]
flat_generation

[['knuckles', 'smile', 'corn', 'sap'],
 ['egg', 'cheese', 'corn', 'sap'],
 ['duct', 'pipe', 'drain', 'sewer'],
 ['schmaltz', 'loose', 'chump', 'climate']]

In [106]:
# Reference word groups for the first puzzle, without category names.
flat_solution = [*ds[0]["categories"].values()]
flat_solution

[['drain', 'duct', 'pipe', 'sewer'],
 ['cheese', 'corn', 'sap', 'schmaltz'],
 ['egg', 'knuckles', 'smile', 'window'],
 ['chump', 'climate', 'loose', 'sea']]

In [128]:
def check_solution(generation, solution):
    """Score a generated grouping against the reference solution.

    A reference group counts as found when some generated group contains
    exactly the same words (order and category names are ignored). Each
    match is printed.

    Args:
        generation: mapping of category name -> list of words from the model.
        solution: mapping of category name -> list of words in the reference.

    Returns:
        Fraction (0.0 to 1.0) of reference groups reproduced by the
        generation; 0.0 when `solution` is empty.
    """
    if not solution:
        return 0.
    matched = 0
    for sol_cat, sol_group in solution.items():
        sol_words = set(sol_group)
        for gen_cat, gen_group in generation.items():
            if set(gen_group) == sol_words:
                print(f"{gen_cat} ~ {sol_cat}: {gen_group} == {sol_group}")
                matched += 1
                # Count each reference group at most once, even if the model
                # emitted the same group under several category names.
                break
    # Normalize by this puzzle's own group count, not a global.
    return matched / len(solution)

In [129]:
check_solution(generation=generation, solution=ds[0]["categories"])

Plumbing ~ conduits for water removal: ['duct', 'pipe', 'drain', 'sewer'] == ['drain', 'duct', 'pipe', 'sewer']


0.25

## Refactor

In [134]:
def generate_solution(sample, system=None):
    """Ask the model to solve one puzzle and decode its JSON answer.

    Args:
        sample: puzzle record; its "words" entry is inserted into the user prompt.
        system: system prompt to send. When None, the global `system_prompt`
            is used, so prompt variants can be tried without redefining
            this function.

    Returns:
        The model's answer parsed from JSON (requested format: category ->
        list of words).
    """
    if system is None:
        system = system_prompt
    res = call_openai(system, user_prompt.format(words=sample["words"]))
    return json.loads(res)

In [135]:
# Solve the first 10 puzzles.
generations = list(map(generate_solution, ds[0:10]))

In [140]:
def check_solutions(generations, ds):
    """Score a batch of generations against their reference puzzles.

    Generations and puzzles are paired by position. Per-puzzle matches are
    printed by `check_solution`, followed by a summary.

    Args:
        generations: model outputs (category -> words mappings), one per puzzle.
        ds: puzzle records, each with a "categories" mapping.

    Returns:
        (mean_accuracy, perfect_matches): the mean per-puzzle score, and the
        number of puzzles whose groups were all reproduced.

    Raises:
        ValueError: if `generations` and `ds` differ in length (zip would
            otherwise silently drop the unmatched items).
    """
    generations = list(generations)
    ds = list(ds)
    if len(generations) != len(ds):
        raise ValueError(f"got {len(generations)} generations for {len(ds)} puzzles")
    total_acc = 0.
    perfect_match = 0
    for gen, sol in zip(generations, ds):
        acc = check_solution(gen, sol["categories"])
        if acc == 1.0:
            perfect_match += 1
        total_acc += acc
    mean_acc = total_acc / len(ds) if ds else 0.
    print(f"\nTotal accuracy: {total_acc}")
    print(f"Mean accuracy: {mean_acc:.2%}")
    print(f"Perfect match: {perfect_match}/{len(ds)}")
    return mean_acc, perfect_match

check_solutions(generations=generations, ds=ds[:10]);



category2 ~ conduits for water removal: ['duct', 'pipe', 'drain', 'sewer'] == ['drain', 'duct', 'pipe', 'sewer']
category3 ~ food products associated with sentimentality: ['schmaltz', 'corn', 'sap', 'cheese'] == ['cheese', 'corn', 'sap', 'schmaltz']
bands ~ classic rock bands: ['rush', 'genesis', 'kansas', 'yes'] == ['genesis', 'kansas', 'rush', 'yes']
Fruit/Vegetable ~ vegetables that are also fruits: ['eggplant', 'cucumber', 'tomato', 'pepper'] == ['cucumber', 'eggplant', 'pepper', 'tomato']
Geometric Shape ~ 3-d shapes: ['cube', 'cone', 'pyramid', 'sphere'] == ['cone', 'cube', 'pyramid', 'sphere']
Apple Products ~ words with “i”: ['pad', 'pod', 'mac', 'phone'] == ['mac', 'pad', 'phone', 'pod']
Verbs ~ words with two pronunciations: ['mobile', 'lima', 'job', 'polish'] == ['job', 'lima', 'mobile', 'polish']
Musical Terms ~ musical sections: ['brass', 'string', 'wind', 'rhythm'] == ['brass', 'rhythm', 'string', 'wind']
awards ~ awards: ['cup', 'trophy', 'ribbon', 'medal'] == ['cup', 'm

In [141]:
# Appended to `system_prompt` to ask the model to verify its answer. The naive
# run reused words across groups and left others out, so the checklist
# targets a valid partition of the 16 words.
extra_system_prompt = """
Check your solution before submitting it. Be sure that:
- you have exactly 4 groups of 4 words each
- every one of the 16 given words is used exactly once
- no word appears in more than one group
- the four words in each group all belong to that group's category
"""

In [142]:
# NOTE(review): this redefinition shadows the earlier `generate_solution`;
# any later cell calling it gets this self-check variant.
def generate_solution(sample):
    """Solve one puzzle using the system prompt plus the self-check instructions.

    Returns the model's answer decoded from JSON.
    """
    prompt = user_prompt.format(words=sample["words"])
    answer = call_openai(system_prompt + extra_system_prompt, prompt)
    return json.loads(answer)

generations_2 = list(map(generate_solution, ds[:10]))
check_solutions(generations_2, ds[:10]);

category3 ~ conduits for water removal: ['duct', 'pipe', 'drain', 'sewer'] == ['drain', 'duct', 'pipe', 'sewer']
category3 ~ breadth: ['extent', 'scope', 'reach', 'range'] == ['extent', 'range', 'reach', 'scope']
US States ~ u.s. mountain states: ['utah', 'arizona', 'nevada', 'colorado'] == ['arizona', 'colorado', 'nevada', 'utah']
Carbonated Drinks ~ soda brands: ['sprite', 'mug', 'crush', 'squirt'] == ['crush', 'mug', 'sprite', 'squirt']
Bands ~ classic rock bands: ['rush', 'genesis', 'kansas', 'yes'] == ['genesis', 'kansas', 'rush', 'yes']
Fictional Characters ~ tony ___: ['hawk', 'soprano', 'stark', 'montana'] == ['hawk', 'montana', 'soprano', 'stark']
Fruits/Vegetables ~ vegetables that are also fruits: ['eggplant', 'cucumber', 'tomato', 'pepper'] == ['cucumber', 'eggplant', 'pepper', 'tomato']
Shapes ~ 3-d shapes: ['cube', 'cone', 'pyramid', 'sphere'] == ['cone', 'cube', 'pyramid', 'sphere']
Apple Products ~ words with “i”: ['pad', 'pod', 'mac', 'phone'] == ['mac', 'pad', 'phone'