In [1]:
import csv
import os
import random
import sys

import fire
import torch
from tqdm import tqdm

import warnings
warnings.filterwarnings("ignore")

%load_ext autoreload
%autoreload 2

In [2]:
sys.path.append("../")
from utils import (
    load_model_tokenzier,
    compute_role_description,
    NewLineStoppingCriteria,
    compute_final_roles,
    ask_mental_state_questions,
)

In [3]:
# Run configuration: compute device, the pool of names players are drawn from,
# the role list, and the shared rules text that prefixes every prompt.
SEED = 0
# Seed every RNG in play: random.sample() picks each game's players, and helpers in
# utils may also use `random`/torch RNGs (TODO confirm). Without this, re-runs differ.
random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
player_names = [
    "Alice",
    "Bob",
    "Charlie",
    "David",
    "Eve",
    "Frank",
    "Grace",
    "Heidi",
    "Ivan",
    "Judy",
]
player_roles = ["Villager", "Werewolf", "Troublemaker"]

GAME_DESCRIPTION_PATH = "/data/nikhil_prakash/mind/werewolf/game_description.txt"
with open(GAME_DESCRIPTION_PATH, "r", encoding="utf-8") as f:
    game_description = f.read()

In [4]:
def delete_exists(
    conversations_dir: str = "/data/nikhil_prakash/mind/werewolf/conversations",
) -> None:
    """Delete all .txt and .csv files left in the conversations folder by a previous run.

    The folder is created if it does not exist yet, because later code appends
    per-game files into it and would otherwise fail on a fresh setup.

    Args:
        conversations_dir: Folder holding the per-game ``.txt`` / ``.csv`` files.
    """
    # os.listdir raises FileNotFoundError on a missing folder; create it instead.
    os.makedirs(conversations_dir, exist_ok=True)
    for file_name in os.listdir(conversations_dir):
        path = os.path.join(conversations_dir, file_name)
        # Skip directories that happen to match the suffix: os.remove cannot delete them.
        if file_name.endswith((".txt", ".csv")) and os.path.isfile(path):
            os.remove(path)

In [5]:
def _deal_players(n_players: int) -> list:
    """Sample ``n_players`` distinct names, pair them with roles, then derive final roles/descriptions."""
    players = [
        {"name": name, "role": role}
        for name, role in zip(random.sample(player_names, n_players), player_roles * n_players)
    ]
    players = compute_final_roles(players)
    return compute_role_description(players)


def _generate_utterance(model, tokenizer, prompt: str, stopping_criteria) -> str:
    """Greedily generate one player's turn and return only the newly generated text, stripped."""
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        num_return_sequences=1,
        temperature=0.00,
        do_sample=False,
        stopping_criteria=[stopping_criteria],
        early_stopping=True,
    )
    # Decode only the continuation. Slicing the decoded full sequence by len(prompt)
    # breaks whenever decode() does not reproduce the prompt verbatim (e.g. tokenizer
    # space clean-up), shifting the cut into or before the generated text.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()


def _write_votes_csv(csv_path: str, players: list, own_vote, other_votes) -> None:
    """Write one CSV row per player: name, dealt role, final role and the two vote answers."""
    with open(csv_path, "w", newline="") as f:
        # csv.writer quotes fields that contain commas; the values in other_votes may be
        # containers whose str() has commas (TODO confirm type), which a plain
        # comma-join would split into extra columns. "\n" keeps the original line endings.
        writer = csv.writer(f, lineterminator="\n")
        writer.writerow(["Name", "Role", "Final Role", "Own Vote", "Other Votes"])
        for player in players:
            name = player["name"]
            writer.writerow(
                [name, player["role"], player["final_role"], own_vote[name], other_votes[name]]
            )


def play_werewolf(
    model_name: str = "daryl149/llama-2-70b-chat-hf",
    precision: str = "int4",
    n_games: int = 10,
    n_players: int = 3,
    n_rounds: int = 5,
):
    """Simulate ``n_games`` LLM-played Werewolf day phases and record the players' votes.

    Each game samples ``n_players`` names from ``player_names`` and pairs them with
    ``player_roles``. In each of ``n_rounds`` rounds every player speaks once, given
    the rules, their role description and the transcript so far. Each utterance is
    appended to ``<conversations_dir>/{game_idx}.txt``. After the last round, the full
    transcript goes to ``ask_mental_state_questions``, and its answers are written to
    ``<conversations_dir>/{game_idx}.csv`` (same folder as the transcripts).

    Uses the notebook globals ``device``, ``player_names``, ``player_roles`` and
    ``game_description``. Existing .txt/.csv files in the conversations folder are
    deleted first via ``delete_exists()``.

    Args:
        model_name: Model id passed to ``load_model_tokenzier``.
        precision: Precision string passed to ``load_model_tokenzier``.
        n_games: Number of independent games to simulate.
        n_players: Players per game; at most ``len(player_names)``.
        n_rounds: Speaking rounds per game.

    Raises:
        ValueError: If ``n_players`` exceeds the number of available names.
    """
    # Fail before loading a (possibly very large) model; random.sample would raise anyway.
    if n_players > len(player_names):
        raise ValueError(
            f"n_players={n_players} exceeds the {len(player_names)} available player names"
        )

    # The folder delete_exists() cleans; transcripts and vote tables both live here.
    conversations_dir = "/data/nikhil_prakash/mind/werewolf/conversations"
    os.makedirs(conversations_dir, exist_ok=True)
    delete_exists()

    model, tokenizer = load_model_tokenzier(model_name, precision, device)

    for game_idx in range(n_games):
        players = _deal_players(n_players)
        stopping_criteria = NewLineStoppingCriteria(tokenizer)
        transcript_path = os.path.join(conversations_dir, f"{game_idx}.txt")

        with torch.no_grad():
            # Kept in memory so the transcript passed to the vote questions below
            # includes the final utterance.
            conversation = ""
            for _ in tqdm(range(n_rounds), total=n_rounds):
                for player_idx in range(n_players):
                    player = players[player_idx]
                    prompt = (
                        f"{game_description}\n\n{player['role_description']}\n\n"
                        f"DAY PHASE:\n{conversation}\n{player['name']} talks:"
                    )
                    utterance = _generate_utterance(model, tokenizer, prompt, stopping_criteria)

                    line = f"{player['name']} talks: {utterance}\n"
                    conversation += line
                    with open(transcript_path, "a+") as f:
                        f.write(line)

            own_vote, other_votes = ask_mental_state_questions(
                model, tokenizer, device, players, game_description, game_idx, conversation
            )

        _write_votes_csv(
            os.path.join(conversations_dir, f"{game_idx}.csv"), players, own_vote, other_votes
        )

In [None]:
# Full experiment: 10 games, 3 players each, 5 speaking rounds, using the
# Llama-2-70B chat checkpoint loaded at precision "int4".
run_config = dict(
    model_name="daryl149/llama-2-70b-chat-hf",
    precision="int4",
    n_games=10,
    n_players=3,
    n_rounds=5,
)
play_werewolf(**run_config)

Loading checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]

100%|██████████| 5/5 [04:09<00:00, 49.84s/it]


QUESTION: Who would you vote for?
Options: A: Alice, B: Bob, C: Eve, D: Not sure
Answer:
QUESTION: Who would Bob vote for?
Options: A: Alice, B: Bob, C: Eve, D: Not sure
Answer:
QUESTION: Who would Alice vote for?
Options: A: Alice, B: Bob, C: Eve, D: Not sure
Answer:
QUESTION: Who would you vote for?
Options: A: Bob, B: Eve, C: Alice, D: Not sure
Answer:
QUESTION: Who would Eve vote for?
Options: A: Bob, B: Eve, C: Alice, D: Not sure
Answer:
QUESTION: Who would Alice vote for?
Options: A: Bob, B: Eve, C: Alice, D: Not sure
Answer:
QUESTION: Who would you vote for?
Options: A: Eve, B: Bob, C: Alice, D: Not sure
Answer:
QUESTION: Who would Eve vote for?
Options: A: Eve, B: Bob, C: Alice, D: Not sure
Answer:
QUESTION: Who would Bob vote for?
Options: A: Eve, B: Bob, C: Alice, D: Not sure
Answer:


100%|██████████| 5/5 [03:23<00:00, 40.66s/it]


QUESTION: Who would you vote for?
Options: A: Ivan, B: Alice, C: Frank, D: Not sure
Answer:
QUESTION: Who would Frank vote for?
Options: A: Ivan, B: Alice, C: Frank, D: Not sure
Answer:
QUESTION: Who would Ivan vote for?
Options: A: Ivan, B: Alice, C: Frank, D: Not sure
Answer:
QUESTION: Who would you vote for?
Options: A: Frank, B: Alice, C: Ivan, D: Not sure
Answer:
QUESTION: Who would Alice vote for?
Options: A: Frank, B: Alice, C: Ivan, D: Not sure
Answer:
