In [1]:
import ast
import json
import os
import random
import sys
from collections import defaultdict

import fire
import torch
from tqdm import tqdm

import warnings
warnings.filterwarnings("ignore")

%load_ext autoreload
%autoreload 2

In [2]:
sys.path.append("../")
from utils import (
    load_model_tokenzier,
    compute_role_description,
    NewLineStoppingCriteria,
    compute_final_roles,
    ask_mental_state_questions,
)

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pool of display names; each game samples players from it without replacement.
player_names = [
    "Alice",
    "Bob",
    "Charlie",
    "David",
    "Eve",
    "Frank",
    "Grace",
    "Heidi",
    "Ivan",
    "Judy",
]
# Roles dealt to the players of a game (paired one-to-one with sampled names).
player_roles = ["Villager", "Werewolf", "Troublemaker"]

# Root of the experiment data. Defaults to the original machine's location;
# set the WEREWOLF_DATA_DIR environment variable to run elsewhere.
DATA_DIR = os.environ.get("WEREWOLF_DATA_DIR", "/data/nikhil_prakash/mind/werewolf")

# Rules text that is prepended to every player's prompt.
with open(os.path.join(DATA_DIR, "game_description.txt"), "r", encoding="utf-8") as f:
    game_description = f.read()

In [4]:
# Delete all .txt and .csv files in the conversations folder
def delete_exists(directory: str = "/data/nikhil_prakash/mind/werewolf/conversations") -> None:
    """Remove leftover ``.txt`` and ``.csv`` files from ``directory``.

    ``play_werewolf`` calls this before playing so that transcripts, which are
    opened in append mode, never continue a file left by an earlier run.

    Args:
        directory: Folder to clean. Defaults to the conversations folder used by
            ``play_werewolf`` for transcripts. It is created if it does not exist
            yet, instead of failing with ``FileNotFoundError``.
    """
    os.makedirs(directory, exist_ok=True)
    for entry in os.listdir(directory):
        entry_path = os.path.join(directory, entry)
        # Only regular files: os.remove raises on a directory that merely
        # happens to have a ".txt"/".csv" suffix.
        if entry.endswith((".txt", ".csv")) and os.path.isfile(entry_path):
            os.remove(entry_path)

In [5]:
def _generate_utterance(model, tokenizer, prompt, stopping_criteria):
    """Run one generation for ``prompt`` and return only the newly generated text.

    Only the token ids produced after the prompt are decoded. Decoding the full
    sequence and slicing off ``len(prompt)`` characters mis-cuts whenever
    ``tokenizer.decode`` does not reproduce the prompt character-for-character.

    Args:
        model: Model returned by ``load_model_tokenzier``.
        tokenizer: The matching tokenizer.
        prompt: Full prompt text, ending in ``"<name> talks:"``.
        stopping_criteria: Criterion forwarded to ``model.generate``.

    Returns:
        The generated continuation, stripped of surrounding whitespace.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        num_return_sequences=1,
        temperature=0.01,  # very low temperature: close to greedy decoding
        do_sample=True,
        stopping_criteria=[stopping_criteria],
        early_stopping=True,
    )
    n_prompt_tokens = inputs["input_ids"].shape[1]
    new_tokens = outputs[0, n_prompt_tokens:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def play_werewolf(
    model_name: str = "daryl149/llama-2-70b-chat-hf",
    precision: str = "int4",
    n_games: int = 10,
    n_players: int = 3,
    n_rounds: int = 5,
    seed: "int | None" = None,
):
    """Play ``n_games`` LLM-driven Werewolf games and record transcripts and votes.

    For each game, ``n_players`` names are sampled from ``player_names`` and paired
    with a random permutation of ``player_roles``. Each player then speaks once
    per round for ``n_rounds`` rounds. After the discussion,
    ``ask_mental_state_questions`` collects each player's own vote and their
    predictions of the others' votes.

    Side effects:
        * Transcript appended to ``<conversations_dir>/{game_idx}.txt``.
        * Vote table written to ``./conversations/{game_idx}.csv``. The last
          column is a dict repr that contains commas; the analysis cell relies on
          it being the final column.

    Args:
        model_name: Checkpoint passed to ``load_model_tokenzier``.
        precision: Precision string passed to ``load_model_tokenzier``.
        n_games: Number of games to play.
        n_players: Players per game; at most ``len(player_roles)``.
        n_rounds: Discussion rounds; every player speaks once per round.
        seed: If given, seeds ``random`` and ``torch`` for reproducible games.

    Raises:
        ValueError: If ``n_players`` exceeds the number of available roles.
    """
    if n_players > len(player_roles):
        # zip() would silently drop players and the turn loop would fail later;
        # fail before the expensive model load instead.
        raise ValueError(
            f"n_players={n_players} exceeds the {len(player_roles)} available roles {player_roles}"
        )

    conversations_dir = "/data/nikhil_prakash/mind/werewolf/conversations"
    os.makedirs(conversations_dir, exist_ok=True)
    os.makedirs("./conversations", exist_ok=True)

    model, tokenizer = load_model_tokenzier(model_name, precision, device)
    delete_exists()

    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)

    for game_idx in range(n_games):
        # Draw a role permutation without mutating the module-level list.
        roles = random.sample(player_roles, len(player_roles))
        names = random.sample(player_names, n_players)
        players = [{"name": name, "role": role} for name, role in zip(names, roles)]
        players = compute_final_roles(players)
        players = compute_role_description(players)

        stopping_criteria = NewLineStoppingCriteria(tokenizer)

        transcript_path = f"{conversations_dir}/{game_idx}.txt"
        conversation = ""
        if os.path.exists(transcript_path):
            with open(transcript_path, "r") as f:
                conversation = f.read()

        with torch.no_grad():
            for _ in tqdm(range(n_rounds), desc=f"game {game_idx}"):
                for player in players:
                    prompt = (
                        f"{game_description}\n\n{player['role_description']}\n\n"
                        f"DAY PHASE:\n{conversation}\n{player['name']} talks:"
                    )
                    utterance = _generate_utterance(model, tokenizer, prompt, stopping_criteria)
                    line = f"{player['name']} talks: {utterance}\n"
                    with open(transcript_path, "a") as f:
                        f.write(line)
                    # Keep the in-memory transcript in sync with the file.
                    conversation += line

            # Pass the complete transcript, including the final player's last turn
            # (previously a stale snapshot taken before that turn was passed).
            own_vote, other_votes = ask_mental_state_questions(
                model, tokenizer, device, players, game_description, game_idx, conversation
            )

            with open(f"./conversations/{game_idx}.csv", "w") as f:
                f.write("Name,Role,Final Role,Own Vote,Other Votes\n")
                for player in players:
                    f.write(
                        f"{player['name']},{player['role']},{player['final_role']},"
                        f"{own_vote[player['name']]},{other_votes[player['name']]}\n"
                    )

In [None]:
# Full experiment: 50 three-player games of 5 discussion rounds each,
# with the Llama-2-70B chat checkpoint loaded at "int4" precision.
run_config = dict(
    model_name="daryl149/llama-2-70b-chat-hf",
    precision="int4",
    n_games=50,
    n_players=3,
    n_rounds=5,
)
play_werewolf(**run_config)

Loading checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]

100%|██████████| 5/5 [03:31<00:00, 42.32s/it]
100%|██████████| 5/5 [03:29<00:00, 41.84s/it]
100%|██████████| 5/5 [02:54<00:00, 34.84s/it]
100%|██████████| 5/5 [04:19<00:00, 51.88s/it]
100%|██████████| 5/5 [03:30<00:00, 42.18s/it]
100%|██████████| 5/5 [03:57<00:00, 47.41s/it]
100%|██████████| 5/5 [03:23<00:00, 40.66s/it]
100%|██████████| 5/5 [04:47<00:00, 57.50s/it]
100%|██████████| 5/5 [03:58<00:00, 47.62s/it]
100%|██████████| 5/5 [04:13<00:00, 50.67s/it]
100%|██████████| 5/5 [03:33<00:00, 42.79s/it]
100%|██████████| 5/5 [04:02<00:00, 48.45s/it]
100%|██████████| 5/5 [04:00<00:00, 48.03s/it]
100%|██████████| 5/5 [04:20<00:00, 52.08s/it]
100%|██████████| 5/5 [04:47<00:00, 57.45s/it]
100%|██████████| 5/5 [02:45<00:00, 33.14s/it]
100%|██████████| 5/5 [04:16<00:00, 51.33s/it]
100%|██████████| 5/5 [04:30<00:00, 54.13s/it]
 40%|████      | 2/5 [01:48<02:45, 55.05s/it]

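## Analysing Voting Results

The next cells compute two metrics over the saved games. The first is *vote-prediction accuracy*: how often one player's guess of another player's vote matches that player's actual vote. The second is the *win rate*: the fraction of games in which the plurality vote lands on the player whose final role is Werewolf.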

In [3]:
# Score every saved game. Each CSV is written by `play_werewolf` with header
#   Name,Role,Final Role,Own Vote,Other Votes
# where "Other Votes" is a dict repr that itself contains commas, so rows are
# split with maxsplit=4 to keep that last field intact.
path = "./conversations"
N_GAMES = 50  # should match the `n_games` the games were played with

correct, won, total, lost = 0, 0, 0, 0
for game_idx in range(N_GAMES):
    with open(f"{path}/{game_idx}.csv", "r") as f:
        # Drop the header row and any blank (e.g. trailing) lines.
        rows = [row for row in f.read().splitlines()[1:] if row]

    votes = {}
    werewolf = None
    for row in rows:
        name, _role, final_role, own_vote, other_votes = row.split(",", 4)
        # The Werewolf is identified by the "Final Role" column; first match wins.
        if werewolf is None and final_role == "Werewolf":
            werewolf = name
        votes[name] = {
            "own_votes": own_vote,
            # literal_eval parses the dict repr without executing arbitrary code.
            "other_votes": ast.literal_eval(other_votes),
        }
    if werewolf is None:
        # Previously the last game's werewolf was silently reused here.
        raise ValueError(f"game {game_idx}: no player has final role 'Werewolf'")

    # Metric 1: does A's prediction of B's vote equal B's actual vote?
    for ballot in votes.values():
        for other_player, predicted_vote in ballot["other_votes"].items():
            if votes[other_player]["own_votes"] == predicted_vote:
                correct += 1
            total += 1

    # Metric 2: did the plurality vote land on the Werewolf?
    vote_count = defaultdict(int)
    for ballot in votes.values():
        vote_count[ballot["own_votes"]] += 1

    # max() returns the first key with the highest count, so ties go to the
    # candidate who received a vote earliest in row order (same as before).
    max_voted = max(vote_count, key=vote_count.get)

    if max_voted == werewolf:
        won += 1
    else:
        print(f"Werewolf: {werewolf}; Voted: {max_voted}")
        lost += 1

Werewolf: Judy; Voted: Alice
Werewolf: Judy; Voted: Bob
Werewolf: Frank; Voted: Heidi
Werewolf: Frank; Voted: Eve
Werewolf: Frank; Voted: Charlie
Werewolf: Grace; Voted: Eve
Werewolf: Charlie; Voted: Alice
Werewolf: Grace; Voted: Eve
Werewolf: Frank; Voted: Eve
Werewolf: Grace; Voted: Charlie
Werewolf: Judy; Voted: Bob
Werewolf: Eve; Voted: Judy
Werewolf: Charlie; Voted: Alice
Werewolf: Eve; Voted: Judy
Werewolf: Grace; Voted: Judy
Werewolf: Heidi; Voted: Grace
Werewolf: Ivan; Voted: Heidi
Werewolf: Heidi; Voted: Eve
Werewolf: Ivan; Voted: Charlie
Werewolf: Frank; Voted: Grace
Werewolf: Alice; Voted: Charlie
Werewolf: Eve; Voted: Charlie
Werewolf: David; Voted: Grace
Werewolf: Ivan; Voted: Frank
Werewolf: Charlie; Voted: Grace
Werewolf: Ivan; Voted: Frank
Werewolf: Judy; Voted: Charlie
Werewolf: Eve; Voted: Alice
Werewolf: Charlie; Voted: Ivan
Werewolf: Heidi; Voted: Ivan
Werewolf: Frank; Voted: Grace
Werewolf: David; Voted: Heidi
Werewolf: Ivan; Voted: Judy
Werewolf: Grace; Voted: Fra

In [4]:
# Fraction of vote predictions that matched the actual vote.
prediction_accuracy = round(correct / total, 2)
prediction_accuracy, correct, total

(0.36, 108, 300)

In [5]:
# Fraction of games in which the plurality vote hit the Werewolf.
win_rate = round(won / (won + lost), 2)
win_rate, won, lost

(0.28, 14, 36)