In [1]:
!pip install openai

from openai import OpenAI  # Replace Groq with OpenAI
import networkx as nx

import numpy as np
from scipy import stats as sps
import seaborn as sns
from matplotlib import pyplot as plt
import pandas as pd
# Seed NumPy's legacy global random state so that any later np.random.* draws
# are repeatable across kernel restarts.
np.random.seed(42)

Defaulting to user installation because normal site-packages is not writeable


In [2]:
# Experiment dimensions: N is the number of agents (nodes of the network),
# top_n is the length of each belief vector (one entry per candidate movie).
N, top_n = 23, 3

In [3]:
import pickle
# Load the pickled movie-group and movie-tag objects from the data directory.
# NOTE: unpickling can execute arbitrary code, so only load files you trust.
with open("data/Moviegroups.pkl", "rb") as groups_file:
    movie_group = pickle.loads(groups_file.read())

with open("data/Movietags.pkl", "rb") as tags_file:
    movietags = pickle.loads(tags_file.read())

In [4]:
import pickle
# Load the pickled `sampled_split_tags` object.
# NOTE: unpickling can execute arbitrary code, so only load files you trust.
with open("data/sampled_split_tags.pkl", "rb") as split_tags_file:
    sampled_split_tags = pickle.loads(split_tags_file.read())

In [5]:
import pickle
# Load `extended_tags`; the run loop indexes it by clue id to obtain the clue
# text shown to each agent.
# NOTE: unpickling can execute arbitrary code, so only load files you trust.
with open("data/extended_tags.pkl", "rb") as extended_tags_file:
    extended_tags = pickle.loads(extended_tags_file.read())

In [6]:
import time
import numpy as np

class Agent:
    """Chat wrapper that asks an LLM for a belief vector and validates the reply.

    The agent keeps a message history whose first entry is the system prompt
    built from ``description``. Each call to :meth:`stateful_chat` temporarily
    appends the user prompt, queries the model (with retries) and removes the
    prompt again, so the history only grows by the optional ``history_message``.

    Args:
        api_key: Key passed to the OpenAI-compatible client.
        model: Model identifier sent with every completion request.
        description: Text used as the content of the system message. It is
            interpolated with an f-string, so pass a plain ``str``.
        api_type: Accepted for backward compatibility; it is never read. The
            client always uses the OpenRouter base URL below.
        max_retries: Maximum number of completion attempts per chat call.
        retry_delay: Seconds to sleep between two consecutive attempts.
    """

    def __init__(self, api_key, model, description=None, api_type="Groq", max_retries=5, retry_delay=50):
        self.client = OpenAI(api_key=api_key, base_url="https://openrouter.ai/api/v1")
        self.model = model
        self.description = description
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.history = [
            {"role": "system", "content": f"{self.description}"}
        ]

    def _generate_response(self, messages):
        """Send ``messages`` to the model once and return the reply text.

        Returns:
            The content of the first choice, or ``None`` when the request
            raised, the API reported an error, no choices came back, or the
            content was empty. Failures are printed, never raised.
        """
        print(f"# call with {messages}")
        response = None
        try:
            response = self.client.chat.completions.create(
                messages=messages,
                model=self.model,
                temperature=0,
                top_p=0.5
            )
        except Exception as e:
            print(f"# Error: {e}")
        print(f"# raw: {response}")

        # Check for API error. The error payload may be a dict or an object,
        # so read its fields without assuming either representation.
        error = getattr(response, "error", None)
        if error:
            if isinstance(error, dict):
                code, message = error.get("code"), error.get("message")
            else:
                code, message = getattr(error, "code", None), getattr(error, "message", None)
            print(f"# API error {code}: {message}")
            return None

        # Check if choices exist
        if not getattr(response, "choices", None):
            print("# No choices returned")
            return None

        content = getattr(response.choices[0].message, "content", None)
        if not content:
            print("# Empty content returned")
            return None

        return content

    @staticmethod
    def _is_valid_reply(response, expected_count):
        """Return True when ``response`` is exactly ``expected_count`` whitespace-separated floats."""
        if response is None:
            return False
        try:
            values = [float(x) for x in response.split()]
        except ValueError:
            return False
        return len(values) == expected_count

    def _request_valid_reply(self, expected_count):
        """Query the model up to ``max_retries`` times; return the first valid reply or ``None``."""
        for attempt in range(self.max_retries):
            response = self._generate_response(self.history)
            if self._is_valid_reply(response, expected_count):
                print(f"#{response}")
                return response
            if attempt + 1 < self.max_retries:
                # Only wait when another attempt will actually follow.
                print(f"# Retry {attempt+1}/{self.max_retries} after {self.retry_delay}s...")
                time.sleep(self.retry_delay)
            else:
                print(f"# Giving up after {self.max_retries} attempts")
        return None

    def stateful_chat(self, user_message, history_message = "", expected_count=3):
        """Ask the model for a belief vector and return the validated reply text.

        Args:
            user_message: Prompt sent as the latest user message. It is removed
                from the history again before this method returns.
            history_message: If non-empty and a valid reply was obtained, this
                text is appended to the history as a user message.
            expected_count: Number of floats a valid reply must contain.

        Returns:
            The reply string containing exactly ``expected_count`` numbers, or
            ``None`` if no attempt produced a valid reply. An invalid reply is
            never returned, so callers only need to check for ``None``.
        """
        self.history.append({"role": "user", "content": user_message})
        try:
            response = self._request_valid_reply(expected_count)
        finally:
            # Always drop the pending prompt, even if the call is interrupted,
            # so a later call does not send a stale user message.
            self.history.pop()

        if response is not None and history_message != "":
            self.history.append({"role": "user", "content": history_message})
        return response


    def reset_history(self):
        """Discard everything except the initial system message."""
        self.history = [self.history[0]]

### Network

In [7]:
N=23
top_n=3

In [8]:
# Example network: a 23 x 23 matrix of 0/1 entries, one row and one column per
# agent. Entry [i][j] == 1 marks a link between agents i and j. The cell that
# builds the neighbour lists reads only the entries above the diagonal (j > i).
network = np.array([[0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ],
[0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, ],
[0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ],
[0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, ],
[1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, ],
[0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, ],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, ],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ],
[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, ],
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ],
[0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, ],
[1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, ],
[0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, ],
[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, ],
[0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, ],
[0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, ],
[0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, ],
[0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, ],
[0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, ],
[0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, ],
[0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, ],
[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ],
]
)

# One [0.333, 0.333, 0.334] row per agent (the same numbers the run loop uses
# for its initial `prev_guess`).
weights = [[0.333, 0.333, 0.334] for _ in range(N)]

# Neighbour list per node. A link between two nodes is read from the entry
# above the diagonal (row = smaller index, column = larger index). Neighbours
# appear in increasing index order and every list ends with the node itself.
edges = [
    [
        other
        for other in range(N)
        if other != node and network[min(node, other)][max(node, other)] == 1
    ] + [node]
    for node in range(N)
]
edges

[[4, 6, 7, 11, 22, 0],
 [3, 7, 10, 16, 20, 22, 1],
 [5, 7, 11, 12, 20, 2],
 [1, 9, 10, 14, 19, 21, 3],
 [0, 5, 12, 13, 15, 19, 4],
 [2, 4, 8, 14, 15, 17, 21, 5],
 [0, 10, 11, 16, 17, 6],
 [0, 1, 2, 16, 7],
 [5, 17, 21, 8],
 [3, 11, 9],
 [1, 3, 6, 12, 13, 19, 22, 10],
 [0, 2, 6, 9, 14, 17, 11],
 [2, 4, 10, 17, 12],
 [4, 10, 14, 19, 13],
 [3, 5, 11, 13, 18, 21, 14],
 [4, 5, 18, 20, 15],
 [1, 6, 7, 20, 21, 16],
 [5, 6, 8, 11, 12, 17],
 [14, 15, 18],
 [3, 4, 10, 13, 19],
 [1, 2, 15, 16, 20],
 [3, 5, 8, 14, 16, 21],
 [0, 1, 10, 22]]

### Run LLM

In [9]:
# System prompt for the clue-update agents (`agents1`).
# This must be a plain string: a trailing comma after the literal would turn it
# into a 1-tuple, and Agent would then send the tuple's repr as the system prompt.
description1 = "You are participating in a movie guessing game with three possible movies.  Your previous belief is given; as well as a clue related to the correct movie. Your task is to output your new belief probabilites. Return ONLY three space-separated decimal numbers. NO explanations, NO text, NO formatting."

In [10]:
# System prompt for the neighbour-aggregation agents (`agents2`).
# This must be a plain string: a trailing comma after the literal would turn it
# into a 1-tuple, and Agent would then send the tuple's repr as the system prompt.
description2 = "You are participating in a movie guessing game with three possible movies. Your previous belief and your neighbors' beliefs are given. Your neighbors have seen different clues than you. Your task is to output your new belief probabilites. Return ONLY three space-separated decimal numbers. NO explanations, NO text, NO formatting."

In [11]:
import os

# Read the API key from the environment instead of storing it in the notebook.
# The Agent client targets the OpenRouter endpoint, so OPENROUTER_API_KEY is
# checked first and OPENAI_API_KEY second. The "..." placeholder is kept as a
# last-resort fallback so the variable is always defined; requests made with
# the placeholder will not authenticate.
OPENAI_API_KEY = (
    os.environ.get("OPENROUTER_API_KEY")
    or os.environ.get("OPENAI_API_KEY")
    or "..."
)

In [13]:
# Restore previously stored belief arrays from disk: `intermediary_guess`
# (beliefs after the clue step) and `guess` (beliefs after the neighbour step).
# The run loop skips every slot whose first component is already non-zero.
with open("data/interguess", "rb") as interguess_file:
    intermediary_guess = np.load(interguess_file)

with open("data/guess", "rb") as guess_file:
    guess = np.load(guess_file)

In [14]:
# Build one pair of agents per network node: agents1[k] handles the clue
# update (system prompt `description1`), agents2[k] handles the neighbour
# update (system prompt `description2`). Both use the same key and model.
shared_agent_kwargs = dict(
    api_key=OPENAI_API_KEY,
    model="openai/gpt-oss-120b",
    api_type="OpenRouter",
)

agents1, agents2 = [], []
for _ in range(N):
    agents1.append(Agent(description=description1, **shared_agent_kwargs))
    agents2.append(Agent(description=description2, **shared_agent_kwargs))

In [None]:
# Start from empty (all-zero) result buffers indexed as
# [sample][time step][agent][movie].
# NOTE: running this cell replaces any arrays loaded from disk under the same
# names, so the run loop will query the model again for every slot.
guess, intermediary_guess = (np.zeros((1000, 100, N, top_n)) for _ in range(2))

In [15]:
# Run-loop bounds: number of independent samples (outer loop) and number of
# time steps simulated within each sample.
num_samples, T_max1 = 14, 15

In [16]:
# How clueids are generated in RunNBSL (kept for reference only):
# sim.reset_randomness()
# shuffled_times, clueids = sim.get_clueids(0, num_samples)  # the second shouldn't matter
# shuffled_times.shape, clueids.shape

In [18]:
# Load the clue-id table; the run loop reads it as clueids[time step][agent].
with open("data/clueids", "rb") as clueids_file:
    clueids = np.load(clueids_file)

In [20]:
def _belief_from_reply(response, context):
    """Turn an agent reply into a normalized belief vector of length ``top_n``.

    Args:
        response: Text returned by ``Agent.stateful_chat`` (may be ``None``
            when every attempt failed).
        context: Short description of the current sample/time/agent, used in
            the error message.

    Raises:
        RuntimeError: if the reply is missing, is not numeric, or does not
            contain exactly ``top_n`` numbers. Without this check a reply with
            a single number would be broadcast into all ``top_n`` entries of
            the belief slot.
    """
    try:
        values = [float(x) for x in response.split()]
    except (AttributeError, ValueError) as exc:
        # AttributeError: response is None; ValueError: a token is not a number.
        raise RuntimeError(
            f"{context}: reply {response!r} is not a list of numbers; "
            "re-run this cell to resume from the slots already filled in"
        ) from exc
    if len(values) != top_n:
        raise RuntimeError(
            f"{context}: expected {top_n} numbers but got {len(values)} in reply {response!r}; "
            "re-run this cell to resume from the slots already filled in"
        )
    # NOTE(review): `normalize` is not defined in any cell visible in this
    # notebook; it must exist in the kernel before this cell runs.
    return normalize(np.array(values))


# Two-phase belief update, repeated for every sample and time step:
#   1. each agent updates its own belief from a new clue   -> intermediary_guess
#   2. each agent combines that with its neighbours' beliefs -> guess
# A slot whose first component is 0 is treated as "not computed yet"; slots that
# are already filled are skipped, so the cell can be re-run to resume.
for sample in range(num_samples):
    for k in range(N):
        agents1[k].reset_history()
        agents2[k].reset_history()
    prev_guess = [[0.333, 0.333, 0.334] for i in range(N)]

    for id_t in range(T_max1):
        # t = shuffled_times[id_t]

        # Phase 1: private update from this agent's clue.
        for k in range(N):
            clue = extended_tags[clueids[id_t][k]]
            if (intermediary_guess[sample][id_t][k][0] == 0):
                head_prompt = f"""New clue: \"{clue}\". Your own belief from the previous round over [Pulp Fiction, Forrest Gump, Fight Club] was: {prev_guess[k]}.
Based on the new clue and your previous belief, what is your belief over the classes [Pulp Fiction, Forrest Gump,  Fight Club]? Return ONLY EXACTLY three space-separated decimal numbers. NO explanations, NO text, NO formatting"""
                response = agents1[k].stateful_chat(head_prompt)
                intermediary_guess[sample][id_t][k] = _belief_from_reply(
                    response, f"sample {sample}, time {id_t}, agent {k} (clue update)"
                )
            print(f"Sample {sample}, time {id_t}, agent {k}.  clue_id={clueids[id_t][k]};  {clue}: {intermediary_guess[sample][id_t][k]}")

        # Phase 2: social update from the neighbours' phase-1 beliefs.
        for k in range(N):
            if (guess[sample][id_t][k][0] == 0):
                head_prompt = f"Your own belief from the previous round over [Pulp Fiction, Forrest Gump,  Fight Club] was: {intermediary_guess[sample][id_t][k]}.\n Your neighbors' beliefs for [Pulp Fiction, Forrest Gump, Fight Club] are\n"
                # edges[k] also contains k itself; the agent's own belief is
                # already stated above, so it is left out of the neighbour list.
                head_prompt += "".join(
                    f"Neighbor {neighbor}: {intermediary_guess[sample][id_t][neighbor]}\n"
                    for neighbor in edges[k]
                    if neighbor != k
                )
                head_prompt += "Based on the provided information, what is your current probability belief over [Pulp Fiction, Forrest Gump,  Fight Club]? Return ONLY EXACTLY three space-separated decimal numbers. NO explanations, NO text, NO formatting."

                response = agents2[k].stateful_chat(head_prompt)
                guess[sample][id_t][k] = _belief_from_reply(
                    response, f"sample {sample}, time {id_t}, agent {k} (neighbor update)"
                )
            print(f"\t agent {k}. Update: {guess[sample][id_t][k]}")
        # The post-neighbour beliefs become the "previous belief" of the next step.
        prev_guess = guess[sample][id_t]
#         plot_network(network, intermediary_guess[sample][id_t], f"Sample {sample}, time {id_t}", 0)

Sample 0, time 0, agent 0.  clue_id=372;  sexy: [0.59912289 0.20043855 0.20043855]
Sample 0, time 0, agent 1.  clue_id=1463;  airport: [0.3330011  0.3330011  0.33399781]
Sample 0, time 0, agent 2.  clue_id=870;  car accident: [0.39978072 0.30010964 0.30010964]
Sample 0, time 0, agent 3.  clue_id=369;  too long: [0.69879398 0.15060301 0.15060301]
Sample 0, time 0, agent 4.  clue_id=2079;  piercing: [0.3330011  0.3330011  0.33399781]
Sample 0, time 0, agent 5.  clue_id=2132;  AFI: [0.09179707 0.81640586 0.09179707]
Sample 0, time 0, agent 6.  clue_id=2133;  planning: [0.23133659 0.15458985 0.61407356]
Sample 0, time 0, agent 7.  clue_id=1725;  unreliable narration: [0.05093192 0.05093192 0.89813615]
Sample 0, time 0, agent 8.  clue_id=1009;  katana sword: [0.3330011  0.3330011  0.33399781]
Sample 0, time 0, agent 9.  clue_id=1135;  sex scene: [0.69879398 0.05093192 0.2502741 ]
Sample 0, time 0, agent 10.  clue_id=2154;  flashback: [0.30010964 0.59912289 0.10076747]
Sample 0, time 0, agen

	 agent 2. Update: [0.69398493 0.03757471 0.26844036]
	 agent 3. Update: [0.70986624 0.21460288 0.07553088]
	 agent 4. Update: [0.64129473 0.17874514 0.17996013]
	 agent 5. Update: [0.61912389 0.14427988 0.23659623]
	 agent 6. Update: [0.62448203 0.18396873 0.19154924]
	 agent 7. Update: [0.50361814 0.22436319 0.27201868]
	 agent 8. Update: [0.65947042 0.19927495 0.14125463]
	 agent 9. Update: [0.99780662 0.0010966  0.00109678]
	 agent 10. Update: [0.66426891 0.15480116 0.18092993]
	 agent 11. Update: [0.73670129 0.0471307  0.21616801]
	 agent 12. Update: [0.73157255 0.0466093  0.22181815]
	 agent 13. Update: [0.61865559 0.27548611 0.10585831]
	 agent 14. Update: [0.99780724 0.00109638 0.00109638]
	 agent 15. Update: [0.72202133 0.04933719 0.22864148]
	 agent 16. Update: [0.51309681 0.21955386 0.26734933]
	 agent 17. Update: [0.99779707 0.00109638 0.00110655]
	 agent 18. Update: [0.99586764 0.00109838 0.00303399]
	 agent 19. Update: [0.6648267  0.25811112 0.07706217]
	 agent 20. Update

Sample 7, time 5, agent 17.  clue_id=2107;  crime: [0.78803947 0.01653743 0.1954231 ]
Sample 7, time 5, agent 18.  clue_id=524;  overated: [0.99538623 0.00153294 0.00308083]
Sample 7, time 5, agent 19.  clue_id=766;  american idiocy: [0.98804445 0.0021808  0.00977474]
Sample 7, time 5, agent 20.  clue_id=1031;  male with long hair: [0.98589156 0.00326124 0.0108472 ]
Sample 7, time 5, agent 21.  clue_id=899;  crime spree: [0.9765145 0.0063042 0.0171813]
Sample 7, time 5, agent 22.  clue_id=130;  bloody: [0.88667398 0.01116316 0.10216286]
	 agent 0. Update: [0.71182848 0.13225514 0.15591639]
	 agent 1. Update: [0.99780724 0.00109638 0.00109638]
	 agent 2. Update: [0.99780524 0.00109638 0.00109838]
	 agent 3. Update: [0.8093322  0.05115319 0.1395146 ]
	 agent 4. Update: [0.99780723 0.00109638 0.00109638]
	 agent 5. Update: [0.8746231  0.02463121 0.1007457 ]
	 agent 6. Update: [0.99779941 0.00109683 0.00110376]
	 agent 7. Update: [0.99780425 0.00109648 0.00109927]
	 agent 8. Update: [0.937

	 agent 10. Update: [0.68768533 0.1380799  0.17423478]
	 agent 11. Update: [0.99780724 0.00109638 0.00109638]
	 agent 12. Update: [0.80494781 0.13192071 0.06313148]
	 agent 13. Update: [0.99768863 0.00120502 0.00110635]
	 agent 14. Update: [0.72524936 0.16872069 0.10602995]
	 agent 15. Update: [0.99780718 0.00109644 0.00109638]
	 agent 16. Update: [0.91446204 0.02858948 0.05694848]
	 agent 17. Update: [0.99610884 0.00279478 0.00109638]
	 agent 18. Update: [0.52920887 0.43092221 0.03986892]
	 agent 19. Update: [0.99778728 0.00110636 0.00110636]
	 agent 20. Update: [0.77678451 0.19111965 0.03209584]
	 agent 21. Update: [0.7020452  0.22919357 0.06876123]
	 agent 22. Update: [0.70162731 0.10917138 0.1892013 ]
Sample 11, time 0, agent 0.  clue_id=372;  sexy: [0.69879398 0.15060301 0.15060301]
Sample 11, time 0, agent 1.  clue_id=1463;  airport: [0.3330011  0.3330011  0.33399781]
Sample 11, time 0, agent 2.  clue_id=870;  car accident: [0.3330011  0.3330011  0.33399781]
Sample 11, time 0, ag

In [None]:
# Reload the stored belief arrays from disk, replacing whatever `guess` and
# `intermediary_guess` currently hold in the kernel.
with open("data/interguess", "rb") as interguess_file:
    intermediary_guess = np.load(interguess_file)

with open("data/guess", "rb") as guess_file:
    guess = np.load(guess_file)