In [6]:
import numpy as np
import torch
import torch.nn.functional as F

In [None]:
class SignallingGame:

    def __init__(self, states: int, messages: int, actions: int, seed: int = 42):
        self.states = states
        self.messages = messages
        self.actions = actions
        self.message_weights = np.full((states, messages), 1e-6, dtype=float)
        self.action_weights = np.full((messages, actions), 1e-6, dtype=float)
        self.rng = np.random.RandomState(seed)
        self.stats = []


    def world_state(self):
        return self.rng.randint(self.states)


    def emit_message(self, state):
        w = self.message_weights[state, :]
        probs = w / np.sum(w)
        message = self.rng.choice(self.messages, p=probs)
        return message

    def perform_action(self, message):
        w = self.action_weights[message, :]
        probs = w / np.sum(w)
        action = self.rng.choice(self.actions, p=probs)
        return action


    def payoff(self, state, action):
        return 1 if action == state else 0


    def update_weights(self, state, message, action, payoff):
        self.message_weights[state, message] += payoff
        self.action_weights[message, action] += payoff


    def snapshot(self, state, message, action, payoff):
        self.stats.append({
            "s": state,
            "m": message,
            "a": action,
            "p": payoff,
            "mw": self.message_weights.copy(),
            "aw": self.action_weights.copy(),
        })


    def play(self, N: int):
        for _ in range(N):
            state = self.world_state()
            message = self.emit_message(state)
            action = self.perform_action(message)
            payoff = self.payoff(state, action)
            self.update_weights(state, message, action, payoff)
            self.snapshot(state, message, action, payoff)

In [48]:
class ReinforcedRothErev(SignallingGame):
    def __init__(self, states: int, messages: int, actions: int, l: float, seed: int = 42):
        super().__init__(states, messages, actions, seed)

        # Convert tabular weights into trainable torch parameters
        self.message_weights = torch.nn.Parameter(
            torch.tensor(self.message_weights, dtype=torch.float32)
        )
        self.action_weights = torch.nn.Parameter(
            torch.tensor(self.action_weights, dtype=torch.float32)
        )

        # SGD optimizer on both sender and receiver policies
        self.optimizer = torch.optim.SGD(
            [self.message_weights, self.action_weights],
            lr=l,
        )


    def emit_message(self, state):
        with torch.no_grad():
            logits = self.message_weights[state] # shape: (messages,)
            probs = torch.softmax(logits, dim=-1).cpu().numpy()
        return np.random.choice(self.messages, p=probs)


    def perform_action(self, message):
        with torch.no_grad():
            logits = self.action_weights[message] # shape: (actions,)
            probs = torch.softmax(logits, dim=-1).cpu().numpy()
        return np.random.choice(self.actions, p=probs)


    def update_weights(self, state, message, action, payoff):
        payoff = torch.tensor(float(payoff), dtype=torch.float32) # scalar

        # Sender loss
        msg_logits = self.message_weights[state] # (messages,)
        log_probs_msg = F.log_softmax(msg_logits, dim=-1)
        loss_msg = - payoff * log_probs_msg[message]

        # Receiver loss
        act_logits = self.action_weights[message] # (actions,)
        log_probs_act = F.log_softmax(act_logits, dim=-1)
        loss_act = - payoff * log_probs_act[action]

        # Combined loss
        loss = loss_msg + loss_act

        # PyTorch training step
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()


    def snapshot(self, state, message, action, payoff):
        """Snapshot that works with torch parameters."""
        self.stats.append({
            "s": state,
            "m": message,
            "a": action,
            "p": payoff,
            "mw": self.message_weights.detach().cpu().numpy().copy(),
            "aw": self.action_weights.detach().cpu().numpy().copy(),
        })

In [88]:
game = SignallingGame(states=3, messages=3, actions=3)
game.play(100)

# Payoffs over time
payoffs = [x["p"] for x in game.stats]
print("Number of rounds:", len(payoffs))
print("First 20 payoffs:", payoffs[:20])
print("Last 20 payoffs:", payoffs[-20:])
print("Total reward:", sum(payoffs))

Number of rounds: 100
First 20 payoffs: [0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0]
Last 20 payoffs: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Total reward: 82


In [None]:
game = ReinforcedRothErev(states=3, messages=3, actions=3, l=1)
game.play(100)

# Payoffs over time
payoffs = [x["p"] for x in game.stats]
print("Number of rounds:", len(payoffs))
print("First 20 payoffs:", payoffs[:20])
print("Last 20 payoffs:", payoffs[-20:])
print("Total reward:", sum(payoffs))

Number of rounds: 200
First 20 payoffs: [1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1]
Last 20 payoffs: [1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1]
Total reward: 111


In [99]:
class ReinforcedRothErev(SignallingGame):
    def __init__(self, states: int, messages: int, actions: int, l: float, seed: int = 42):
        super().__init__(states, messages, actions, seed)

        # Convert tabular weights into trainable torch parameters
        self.message_weights = torch.nn.Parameter(
            torch.tensor(self.message_weights, dtype=torch.float32)
        )
        self.action_weights = torch.nn.Parameter(
            torch.tensor(self.action_weights, dtype=torch.float32)
        )

        # SGD optimizer on both sender and receiver policies
        self.optimizer = torch.optim.SGD(
            [self.message_weights, self.action_weights],
            lr=l,
        )


    def emit_message(self, state):
        with torch.no_grad():
            logits = self.message_weights[state] # shape: (messages,)
            probs = torch.softmax(logits, dim=-1).cpu().numpy()
        return np.random.choice(self.messages, p=probs)


    def perform_action(self, message):
        with torch.no_grad():
            logits = self.action_weights[message] # shape: (actions,)
            probs = torch.softmax(logits, dim=-1).cpu().numpy()
        return np.random.choice(self.actions, p=probs)


    def update_weights(self, state, message, action, payoff):
        payoff = torch.tensor(float(payoff), dtype=torch.float32) # scalar

        # Sender loss
        msg_logits = self.message_weights[state] # (messages,)
        log_probs_msg = F.log_softmax(msg_logits, dim=-1)
        loss_msg = - payoff * log_probs_msg[message]

        # Receiver loss
        act_logits = self.action_weights[message] # (actions,)
        log_probs_act = F.log_softmax(act_logits, dim=-1)
        loss_act = - payoff * log_probs_act[action]

        # Combined loss
        loss = loss_msg + loss_act

        # PyTorch training step
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()


    def snapshot(self, state, message, action, payoff):
        """Snapshot that works with torch parameters."""
        self.stats.append({
            "s": state,
            "m": message,
            "a": action,
            "p": payoff,
            "mw": self.message_weights.detach().cpu().numpy().copy(),
            "aw": self.action_weights.detach().cpu().numpy().copy(),
        })

    def play_batch(self, batch_size):
        logprobs_msg = []
        logprobs_act = []
        rewards = []

        for _ in range(batch_size):
            state = self.world_state()
            message = self.emit_message(state)
            action = self.perform_action(message)
            payoff = self.payoff(state, action)

            # store reward
            rewards.append(payoff)

            # get log probs for reinforce
            msg_logits = self.message_weights[state]
            log_probs_msg = F.log_softmax(msg_logits, dim=-1)
            logprobs_msg.append(log_probs_msg[message])

            act_logits = self.action_weights[message]
            log_probs_act = F.log_softmax(act_logits, dim=-1)
            logprobs_act.append(log_probs_act[action])

            # store stats for user visibility
            self.snapshot(state, message, action, payoff)

        # Convert lists to tensors
        rewards = torch.tensor(rewards, dtype=torch.float32)
        logprobs_msg = torch.stack(logprobs_msg)
        logprobs_act = torch.stack(logprobs_act)

        # REINFORCE loss for whole batch
        loss = - torch.mean(rewards * (logprobs_msg + logprobs_act))

        # update parameters
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return float(rewards.mean())

In [100]:
game = ReinforcedRothErev(3, 3, 3, l=1)

for t in range(100):     # 100 gradient updates
    avg_r = game.play_batch(batch_size=100)   # 20 samples per update
    # print(f"Update {t}, average reward = {avg_r:.3f}")

payoffs = [x["p"] for x in game.stats]
print("Total snapshots:", len(payoffs))
print("First 20 payoffs:", payoffs[:20])
print("Last 20 payoffs:", payoffs[-20:])
print("Total reward:", sum(payoffs))

Total snapshots: 10000
First 20 payoffs: [1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0]
Last 20 payoffs: [1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1]
Total reward: 3893
