In [16]:
import json
import os
import random
from typing import List, Union

import pandas as pd
from IPython.display import Image, display
from langchain_openai import ChatOpenAI
from langgraph.graph import (
    StateGraph,
    START,
    END
)
from pydantic import BaseModel, Field
from typing_extensions import TypedDict

from template2x2 import (
    get_temp0,
    get_temp1
)

In [None]:
class GridSize(TypedDict):
    """Requested puzzle dimensions, e.g. ``{'row': 2, 'column': 2}``."""
    row: int     # number of categories
    column: int  # number of values per category


class State(TypedDict):
    """Shared LangGraph state passed between the puzzle-generation nodes."""
    # Requested grid size; read by category_prompt and story_clue_prompt.
    initial_prompt: GridSize
    # JSON string with a "categories" list plus one value list per category.
    category: str
    # Story text returned by the LLM.
    story: str
    # The LLM may return a single clue string or a list of clue strings.
    clue: Union[str, List[str]]
    # Formatted hidden answer grid (one line per solution row, ' | '-separated).
    solution: str
    # Final question text written to the dataset.
    final_prompt: str

# Read the API key from the environment (OPENAI_API_KEY) instead of pasting it
# into the notebook, where it would be saved with the file and its outputs.
llm = ChatOpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
    model="gpt-4o"
)

In [18]:
def extract_json(msg: str):
    """Return the outermost ``{...}`` span of an LLM reply, or ``"{}"``.

    Keeps everything from the first ``{`` to the last ``}``, which drops
    markdown fences or chatter around a JSON object. The span is not
    validated; callers still pass it to ``json.loads``.
    """
    start = msg.find('{')
    end = msg.rfind('}')
    # No opening brace, no closing brace, or the only '}' precedes the first '{'.
    if start == -1 or end < start:
        return "{}"
    return msg[start:end + 1]

In [19]:
def _check_category_json(raw: str, n_categories: int, n_values: int) -> None:
    """Raise ValueError unless ``raw`` lists n_categories categories of n_values distinct values each."""
    try:
        data = json.loads(raw)
    except json.JSONDecodeError as exc:
        raise ValueError(f"category reply is not valid JSON: {raw!r}") from exc
    names = data.get('categories') if isinstance(data, dict) else None
    if (not isinstance(names, list)
            or len(names) != n_categories
            or not all(isinstance(name, str) for name in names)
            or len(set(names)) != n_categories):
        raise ValueError(f"expected {n_categories} distinct category names, got {names!r}")
    for name in names:
        values = data.get(name)
        # str() so numeric values (prices, years) can be compared for uniqueness too.
        if (not isinstance(values, list)
                or len(values) != n_values
                or len(set(map(str, values))) != n_values):
            raise ValueError(f"category {name!r} needs {n_values} distinct values, got {values!r}")


def category_prompt(state: State):
    """Ask the LLM for the puzzle's categories and their values.

    ``state['initial_prompt']`` supplies ``row`` (number of categories) and
    ``column`` (number of values per category), as spelled out in the prompt.

    Returns:
        ``{'category': <json string>}`` holding a ``"categories"`` list and
        one value list per category name.

    Raises:
        ValueError: if the reply is not valid JSON or does not contain the
            requested number of distinct categories and values. Without this
            check a malformed reply would either crash later nodes or yield a
            puzzle whose listed options disagree with its solution.
    """
    x, y = state['initial_prompt']['row'], state['initial_prompt']['column']
    msg = llm.invoke(f"""
    You are an AI assistant who is an expert at designing grid logic puzzles. 
    
    I want to create a logic grid puzzle of size {x}*{y}. 
    This means the puzzle should have {x} categories, and each category should contain {y} distinct values.
    One of the category could be a numerical category like money, price, depth or date time. 
    Make the first category in the list of categories as numerical. 
    Rules:
    1. Dont pick weird values for categories. For example orange fruit and orange color, in such cases clues for grid puzzles will become ambiguous.
    2. Some examples of category triplets.
    (Depth of pool, Diver Name, Competition Year),
    (Dog graduation year, Police officer, Dog name),
    (Price, Person, Pet),
    (Price, Person, Fruits),
    ('Distance', 'Exo-Planet', 'Star')
    etc.
    
    Your task:
    1. Generate {x} category names that are intuitive and distinct.
    2. For each category, provide {y} values that are realistic, non-overlapping, and suitable for a logic grid puzzle.
    3. Return the result strictly in **valid JSON** format, with the structure:
    
    {{
        "categories": ["c1", "c2", ..., "c{x}"],
        "c1": ["c1_1", "c1_2", ..., "c1_{y}"],
        "c2": ["c2_1", "c2_2", ..., "c2_{y}"],
        ...
    }}
    No extra commentary or explanation. Only output valid JSON.
    """)
    raw = extract_json(msg.content)
    _check_category_json(raw, x, y)
    return {'category': raw}


In [20]:
class Story(BaseModel):
    # Schema for the story + clues reply of a grid puzzle.
    # NOTE(review): this field is `clue`, while the prompt templates ask the
    # model for a JSON key `clues`; story_clue_prompt reads the raw JSON reply
    # itself rather than parsing it through this model.
    story: str = Field(description="The story surrounding the logic grid puzzle.")
    clue: List[str] = Field(description="The clues to solve the grid puzzle, each entry in a list is an independent clue.")

def story_clue_prompt(state: State):
    '''Create the hidden solution of a (2, 2) grid puzzle and ask the LLM for its story and clue.

    ``state['initial_prompt']`` supplies ``row`` (number of categories) and
    ``column`` (number of values per category). ``state['category']`` is the
    JSON string produced by ``category_prompt``.

    Each category's values are shuffled independently to form a random
    solution. One of the two ``template2x2`` templates is then filled with a
    single clue derived from that solution and sent to the LLM.

    Returns:
        dict with ``solution`` (one line per solution row, category values
        joined by `` | ``), ``story`` and ``clue`` (the reply's ``clues`` value).

    Raises:
        ValueError: if the grid is smaller than 2x2, the category data does
            not match the requested size, or the LLM reply lacks
            ``story``/``clues``.
    '''
    def beautify(grid):
        # str() keeps numeric category values (e.g. prices) from breaking join().
        return ''.join(' | '.join(map(str, row)) + '\n' for row in grid)

    n_categories = state['initial_prompt']['row']
    n_values = state['initial_prompt']['column']
    st = json.loads(state['category'])
    keys = st['categories']

    # The templates name two categories, and template 1 needs two distinct
    # solution rows, so nothing smaller than 2x2 can be rendered.
    if n_categories < 2 or n_values < 2:
        raise ValueError(f"need at least a 2x2 grid, got {n_categories}x{n_values}")
    # Extra or missing categories/values would make the listed options
    # disagree with the generated solution.
    if len(keys) != n_categories or any(len(st[key]) != n_values for key in keys):
        raise ValueError(f"category data does not match a {n_categories}x{n_values} grid: {st!r}")

    # Shuffle every category independently; position j across all categories
    # then forms solution row j.
    shuffled = []
    for key in keys:
        values = st[key].copy()
        random.shuffle(values)
        shuffled.append(values)
    solution = [[shuffled[i][j] for i in range(n_categories)] for j in range(n_values)]

    # Random ordering of solution rows; the clue talks about soi[0] (and soi[1]).
    soi = list(range(n_values))
    random.shuffle(soi)

    problem_templates = [get_temp0(), get_temp1()]
    index = random.randint(0, len(problem_templates) - 1)
    # Judging by the generated prompts, template 0 states that the two values
    # are related (same solution row), template 1 that they are not (different rows).
    partner_row = soi[0] if index == 0 else soi[1]
    fin_prompt = problem_templates[index].format(
        keys,
        keys[0], st[keys[0]],
        keys[1], st[keys[1]],
        solution[soi[0]][0], solution[partner_row][1],
    )

    reply = llm.invoke(fin_prompt).content
    resp = json.loads(extract_json(reply))
    print(f'INDEX[{index}]; prompt> {fin_prompt}')
    missing = [field for field in ('story', 'clues') if field not in resp]
    if missing:
        raise ValueError(f"LLM reply is missing {missing}: {reply!r}")
    return {
        "solution": beautify(solution),
        "story": resp['story'],
        "clue": resp['clues'],
    }

# Smoke test: run story_clue_prompt on a hand-written 2x2 category set.
# Note that this makes a real LLM request every time the cell runs.
state = dict(
    initial_prompt=dict(row=2, column=2),
    category=json.dumps(
        dict(
            categories=["Month", "Dog"],
            Month=["March", "April"],
            Dog=["Pluto", "Donald"],
        ),
        indent=4,
    ),
)
story_clue_prompt(state)

INDEX[1]; prompt> Form a grid puzzle using the following template. 
A logic grid puzzle should have a story and clues.

For example a 3x4 grid puzzle should have 3 categories, and each category should contain 4 distinct values.
The solution to a grid puzzle 

Categories:
['Month', 'Dog']

Month: ['March', 'April']
Dog: ['Pluto', 'Donald']

Clues:
March is not related to Donald.

Donot provide extra clues, if there is one clue under the sub topic Clues, provide only one clue.
Your job is to fill in the following values in the following JSON.
{
    "story": "",
    "clues": "",
}
No extra commentary or explanation. Only output valid JSON.]


{'solution': 'March | Pluto\nApril | Donald\n',
 'story': 'Four friends each adopted a dog in a different month. Based on the clues, determine which friend adopted which dog and in which month.',
 'clue': 'March is not related to Donald.'}

In [21]:
# Builds the solver-facing question text; unlike the other nodes, this one does not call the LLM.
def final_prompt(state: State):
    """Assemble the puzzle question that is saved to the dataset.

    Combines ``state['story']``, every category from ``state['category']``
    (a JSON string) and ``state['clue']`` (a string, or a list of strings
    joined one per line) with fixed answer-format instructions.

    Returns:
        ``{"final_prompt": <question text>}``
    """
    st = json.loads(state['category'])
    keys = st['categories']
    clues = state['clue']
    if isinstance(clues, list):
        clues = '\n'.join(map(str, clues))
    # One "<category>: [values]" line per category, however many there are.
    category_lines = '\n'.join(f'{key}: {st[key]}' for key in keys)
    # NOTE(review): the example answer table below is always 3 columns x 4 rows,
    # whatever the real grid size; kept verbatim so dataset rows stay consistent.
    post_fix = f'''{state['story']}
Categories:
{keys}

{category_lines}

Clues:
{clues}


While answering use the following format:
Step-by-step solution:
Your steps showing how you are solving the puzzle
Final Answer:
Create a table like this
c1_1 | c2_1 | c3_1
c1_2 | c2_2 | c3_2
c1_3 | c2_3 | c3_3
c1_4 | c2_4 | c3_4'''
    return {
        "final_prompt": post_fix
        }

In [14]:
# Linear pipeline: categories -> story & clue -> final question text.
PIPELINE_NODES = (
    ('c_prompt', category_prompt),
    ('sc_prompt', story_clue_prompt),
    ('f_prompt', final_prompt),
)

workflow = StateGraph(State)
for node_name, node_fn in PIPELINE_NODES:
    workflow.add_node(node_name, node_fn)

# Wire START -> each node in declaration order -> END.
stops = [START, *(name for name, _ in PIPELINE_NODES), END]
for src, dst in zip(stops, stops[1:]):
    workflow.add_edge(src, dst)

chain = workflow.compile()

In [15]:
# Generate EPOCHS puzzles and append each (question, answer) pair to the dataset CSV.
# `row` is the number of categories and `column` the number of values per
# category (this is how category_prompt uses them).
EPOCHS = 10
GRID_SIZE = {"row": 2, "column": 2}
DATASET_CSV = 'data/grid_puzzle_easy_x.csv'
DATASET_COLUMNS = ['question', 'answer']

os.makedirs(os.path.dirname(DATASET_CSV) or '.', exist_ok=True)
# Appending assumes the existing file has the same columns in the same order.
if os.path.exists(DATASET_CSV):
    existing_columns = list(pd.read_csv(DATASET_CSV, nrows=0).columns)
    assert existing_columns == DATASET_COLUMNS, f"{DATASET_CSV} has unexpected columns {existing_columns}"

for epoch in range(EPOCHS):
    state = chain.invoke({"initial_prompt": dict(GRID_SIZE)})
    print(state['final_prompt'], state['solution'], sep="\n-------------\n")
    new_row = pd.DataFrame({'question': [state['final_prompt']], 'answer': [state['solution']]})
    # Append one row per epoch instead of re-reading and rewriting the whole
    # file, so progress is still saved after every (slow, paid) LLM round-trip.
    # The header is written only when the file is first created.
    new_row.to_csv(DATASET_CSV, mode='a', header=not os.path.exists(DATASET_CSV), index=False)

INDEX[0]; prompt> Form a grid puzzle using the following template. 
A logic grid puzzle should have a story and clues.

For example a 3x4 grid puzzle should have 3 categories, and each category should contain 4 distinct values.
The solution to a grid puzzle 

Categories:
['Height in meters', 'Building Name']

Height in meters: ['100', '200']
Building Name: ['Skyscraper One', 'Tower Two']

Clues:
200 and Tower Two are directly related.

Donot provide extra clues, if there is one clue under the sub topic Clues, provide only one clue.
Your job is to fill in the following values in the following JSON.
{
    "story": "",
    "clues": "",
}
No extra commentary or explanation. Only output valid JSON.]
In a small city, there are two newly constructed buildings that have become the talk of the town. Each building has a distinct height, and the challenge is to match the buildings with their respective heights.
Categories:
['Height in meters', 'Building Name']

Height in meters: ['100', '200']
Bu