In [None]:
import copy
import random

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
class IteratedPrisonersDilemma:
    """
    Fixed-length iterated Prisoner's Dilemma between players A and B.

    Actions: 0 = Cooperate, 1 = Defect. The environment records every
    (actionA, actionB) pair in ``self.history`` and hands back that same list
    object from both ``reset`` and ``step``.
    """

    def __init__(self, max_steps):
        """
        max_steps: number of rounds after which an episode is done.
        """
        self.max_steps = max_steps
        self.reset()

    def reset(self):
        """Start a new episode; returns the (fresh, empty) history list."""
        self.history, self.t = [], 0
        return self.history

    @staticmethod
    def _payoff(actionA, actionB):
        """Return (rewardA, rewardB) for one round under the standard 5/3/1/0 payoffs."""
        if actionA == 0 and actionB == 0:
            return 3, 3  # mutual cooperation
        if actionA == 1 and actionB == 1:
            return 1, 1  # mutual defection
        if actionA == 0 and actionB == 1:
            return 0, 5  # A is exploited by B
        return 5, 0      # A defects, B cooperates (also the catch-all case)

    def step(self, actionA, actionB):
        """
        Play one round.

        actionA, actionB: each 0 (Cooperate) or 1 (Defect).
        Returns (history, rewardA, rewardB, done), where done becomes True once
        ``max_steps`` rounds have been played.
        """
        self.history.append((actionA, actionB))
        rewardA, rewardB = self._payoff(actionA, actionB)
        self.t += 1
        return self.history, rewardA, rewardB, self.t >= self.max_steps

In [None]:
# Token ids for (actionA, actionB) pairs: token = 2 * actionA + actionB,
# i.e. (C,C) -> 0, (C,D) -> 1, (D,C) -> 2, (D,D) -> 3.
pair_to_int = {(a, b): 2 * a + b for a in (0, 1) for b in (0, 1)}

In [None]:
def generate_square_subsequent_mask(sz, device=None):
    """
    Build an additive causal attention mask of shape (sz, sz).

    Entry (i, j) is 0.0 when j <= i (position i may attend to itself and to
    earlier positions) and -inf when j > i (future positions are masked out).

    Args:
        sz: sequence length.
        device: optional device to allocate the mask on directly, which avoids
            a separate host-to-device copy. ``None`` uses torch's default device,
            matching the previous behaviour.
    Returns:
        float32 tensor of shape (sz, sz).
    """
    # Fill with -inf, then triu(diagonal=1) keeps only the strictly-upper
    # triangle and zeroes the diagonal and everything below it.
    return torch.triu(torch.full((sz, sz), float('-inf'), device=device), diagonal=1)

In [None]:
class TransformerPolicy(nn.Module):
    """
    Causal Transformer policy for the iterated Prisoner's Dilemma.

    Input is a batch of pair-history token sequences (ids in [0, 3]); output is
    one pair of action logits (Cooperate, Defect) per sequence, read off the
    final position.

    Args:
        max_length: longest history supported by the learned positional table.
        d_model: embedding / hidden width (PyTorch requires d_model % nhead == 0).
        nhead: attention heads per encoder layer.
        num_layers: number of encoder layers.
        device: device label stored as ``self.device`` for callers; ``forward``
            itself uses the device of its input tensor.
    """

    def __init__(self,
                 max_length,
                 d_model=32,
                 nhead=4,
                 num_layers=2,
                 device='cpu'):
        super().__init__()
        self.max_length = max_length
        self.device = device

        # We have 4 possible tokens for the pair-history
        self.vocab_size = 4
        self.embedding = nn.Embedding(self.vocab_size, d_model)

        # Learned positional encoding: one vector per position < max_length
        self.pos_encoding = nn.Embedding(self.max_length, d_model)

        # Default batch_first=False: the encoder consumes [seq_len, batch, d_model]
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model,
                                                   nhead=nhead,
                                                   dim_feedforward=64)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer,
                                                         num_layers=num_layers)

        # Output layer: 2 actions (Cooperate or Defect)
        self.action_head = nn.Linear(d_model, 2)

    def forward(self, history_tokens):
        """
        Compute next-action logits from pair-history tokens.

        Args:
            history_tokens: LongTensor [batch_size, seq_len] with ids in [0, 3].
        Returns:
            Tensor [batch_size, 2] of logits for (Cooperate, Defect).
        Raises:
            ValueError: if seq_len is 0 (there is no final position to read) or
                exceeds ``max_length`` (no positional embedding exists for it).
        """
        _, seq_len = history_tokens.shape
        if seq_len == 0 or seq_len > self.max_length:
            raise ValueError(
                f"history length must be in [1, {self.max_length}], got {seq_len}")

        # Follow the input's device instead of the constructor argument so the
        # module keeps working after being moved with .to() / .cuda().
        device = history_tokens.device

        x = self.embedding(history_tokens)                              # [B, T, d_model]
        positions = torch.arange(seq_len, device=device).unsqueeze(0)  # [1, T]
        x = x + self.pos_encoding(positions)                            # broadcast over batch

        x = x.permute(1, 0, 2)                                          # [T, B, d_model]
        causal_mask = generate_square_subsequent_mask(seq_len).to(device)
        encoded = self.transformer_encoder(x, mask=causal_mask)         # [T, B, d_model]

        # Under the causal mask the last position has attended to the whole history.
        last_hidden = encoded[-1]                                       # [B, d_model]
        return self.action_head(last_hidden)                            # [B, 2]


In [None]:
# Player B strategies

import random

def always_cooperate(historyA, historyB):
    """
    Unconditional cooperator: ignores both histories and plays C every round.
    """
    cooperate = 0
    return cooperate


def always_defect(historyA, historyB):
    """
    Unconditional defector: ignores both histories and plays D every round.
    """
    defect = 1
    return defect


def tit_for_tat(historyA, historyB):
    """
    Tit-for-Tat: cooperate on the opening round, then repeat whatever
    Player A played in the previous round.
    """
    return historyA[-1] if historyB else 0


def suspicious_tit_for_tat(historyA, historyB):
    """
    Suspicious Tit-for-Tat: defect on the opening round, then repeat whatever
    Player A played in the previous round.
    """
    return historyA[-1] if historyB else 1


def grudger(historyA, historyB):
    """
    Grudger (Grim Trigger): cooperate until Player A has defected at least
    once, then defect for the rest of the game.
    """
    ever_betrayed = 1 in historyA
    return int(ever_betrayed)


def pavlov(historyA, historyB):
    """
    Pavlov (Win-Stay, Lose-Shift).

    - Opening round: cooperate.
    - Afterwards: cooperate if both players made the same move last round
      (CC or DD), defect if they differed (CD or DC).
    """
    if not historyB:
        return 0
    matched_last_round = historyA[-1] == historyB[-1]
    return 0 if matched_last_round else 1


def random_strategy(historyA, historyB):
    """
    Random play: a uniform coin flip between Cooperate (0) and Defect (1),
    independent of history.
    """
    coin = random.randint(0, 1)
    return coin


def bully(historyA, historyB):
    """
    Bully: open with a defection; keep defecting while Player A cooperates,
    and back off (cooperate) for a round after Player A defects.
    """
    if not historyB:
        return 1
    opponent_cooperated = historyA[-1] == 0
    return 1 if opponent_cooperated else 0


def random_tft(historyA, historyB):
    """
    Noisy Tit-for-Tat: cooperate on the opening round; afterwards copy
    Player A's previous move, but flip it with probability 0.1.
    """
    if not historyB:
        return 0
    echoed = historyA[-1]
    flip = random.random() < 0.1
    return 1 - echoed if flip else echoed


def detective(historyA, historyB):
    """
    Detective (simplified, playful variant).

    - Rounds 1-4 play a fixed probe: Cooperate, Defect, Cooperate, Cooperate.
    - From round 5 on, play Tit-for-Tat: copy Player A's previous move.

    Despite the name, this version keeps no grudge and does not adapt to how
    Player A responded to the probe; it simply switches to Tit-for-Tat.
    """
    opening = (0, 1, 0, 0)  # C, D, C, C
    round_index = len(historyB)
    if round_index < len(opening):
        return opening[round_index]
    # Tit-for-Tat; historyB is non-empty here, so there is no first-move case.
    return historyA[-1]


def tit_for_two_tats(historyA, historyB):
    """
    Tit-for-Two-Tats: cooperate for the first two rounds; afterwards defect
    only when Player A defected in both of the last two rounds, otherwise
    cooperate.
    """
    if len(historyB) < 2:
        return 0
    # Chained comparison checks historyA[-1] first, then historyA[-2].
    defected_twice = historyA[-1] == 1 == historyA[-2]
    return 1 if defected_twice else 0


def generous_tit_for_tat(historyA, historyB):
    """
    Generous Tit-for-Tat.

    - Opening round: cooperate.
    - If Player A cooperated last round: cooperate.
    - If Player A defected last round: forgive (cooperate) with probability
      0.3, otherwise retaliate (defect).
    """
    forgiveness_probability = 0.3

    if not historyB:
        return 0
    if historyA[-1] != 1:
        return 0  # reciprocate cooperation; no randomness consumed
    forgive = random.random() < forgiveness_probability
    return 0 if forgive else 1


def adaptive_strategy(historyA, historyB):
    """
    Adaptive strategy: cooperate on the opening round; afterwards cooperate
    only if Player A cooperated in strictly more than half of its last
    (up to) 5 moves, otherwise defect.
    """
    window_size, threshold = 5, 0.5

    if not historyB:
        return 0
    window = historyA[-window_size:]
    cooperation_ratio = window.count(0) / len(window)
    return 0 if cooperation_ratio > threshold else 1


def lookback_defect_if_opponent_defected_x_rounds_ago(historyA, historyB, x=2):
    """
    Lookback strategy (parameterizable).

    Cooperate unless Player A defected exactly ``x`` rounds ago, in which case
    defect. With the default x=2 this "retaliates" two rounds after each
    defection. Cooperates while fewer than ``x`` rounds have been played.

    Args:
        historyA: Player A's past moves (0 = C, 1 = D), oldest first.
        historyB: Player B's past moves (unused; kept for the shared signature).
        x: how many rounds back to look; must be >= 1.
    Returns:
        0 (Cooperate) or 1 (Defect).
    Raises:
        ValueError: if ``x`` < 1 (``historyA[-0]`` would read A's *first*
            move rather than a past round).
    """
    if x < 1:
        raise ValueError(f"x must be a positive integer, got {x!r}")
    if len(historyA) < x:
        return 0  # not enough history yet, cooperate
    return 1 if historyA[-x] == 1 else 0

# Hand-written Player B strategies sampled by the opponent picker. Every entry
# has the signature f(historyA, historyB) -> action in {0, 1}. The order is
# kept fixed so seeded runs draw the same opponents.
manual_strategies = [
    always_cooperate, always_defect,
    tit_for_tat, suspicious_tit_for_tat,
    grudger, pavlov,
    random_strategy, bully, random_tft,
    detective, tit_for_two_tats, generous_tit_for_tat,
    adaptive_strategy, lookback_defect_if_opponent_defected_x_rounds_ago,
]


In [None]:
def snapshot_policy(policy):
    """
    Return a frozen, independent copy of ``policy`` for use as an opponent.

    ``copy.deepcopy`` duplicates the module together with every hyperparameter
    and plain attribute (``max_length``, ``device``, feed-forward width,
    dropout, ...). This avoids rebuilding the architecture from introspected
    attributes, which can silently miss settings the constructor does not
    expose. The copy is switched to eval mode (dropout off), and its parameters
    stop requiring gradients. It is therefore genuinely frozen: later updates to
    ``policy`` do not affect it, and it never accumulates gradients.

    Args:
        policy: the nn.Module to snapshot.
    Returns:
        A new module with identical weights, in eval mode, whose parameters live
        on the same device as ``policy``'s parameters.
    """
    frozen = copy.deepcopy(policy)
    frozen.eval()
    frozen.requires_grad_(False)
    frozen.zero_grad(set_to_none=True)  # drop any gradients copied from `policy`
    return frozen

def pick_opponent(policy_pool, manual_strategies, policy_prob=0.7):
    """
    Randomly pick Player B's opponent from a combined pool.

    When ``policy_pool`` is non-empty, a snapshot policy is drawn with
    probability ``policy_prob``; otherwise (or when the pool is empty) a manual
    strategy is drawn uniformly from ``manual_strategies``.

    Args:
        policy_pool: sequence of frozen policies (may be empty).
        manual_strategies: non-empty sequence of strategy callables.
        policy_prob: chance of choosing from ``policy_pool`` when it is
            non-empty; the default 0.7 keeps the original ratio.
    Returns:
        The chosen opponent (an nn.Module or a strategy function).
    """
    # Truthiness (not `!= []`) so an empty tuple pool is also treated as empty.
    if policy_pool and random.random() < policy_prob:
        return random.choice(policy_pool)
    return random.choice(manual_strategies)

def get_action_b(history, opponent, device='cpu'):
    """
    Choose Player B's next action given the shared game history.

    Args:
        history: list of (actionA, actionB) pairs played so far.
        opponent: either an nn.Module policy (e.g. a frozen snapshot of A) or a
            manual strategy callable ``f(historyA, historyB) -> {0, 1}``.
        device: device on which the token tensor is built for a module opponent.
    Returns:
        int action in {0, 1} (0 = Cooperate, 1 = Defect).
    """
    if isinstance(opponent, nn.Module):
        if not history:
            return 0  # always cooperate on the first move, mirroring Player A
        # Policies in this notebook are trained from Player A's seat, i.e. on
        # (own move, other move) tokens. From B's seat its own move is the
        # *second* element of each pair, so swap before tokenizing; otherwise
        # the snapshot would read A's moves as its own.
        pair_indices = [pair_to_int[(b, a)] for a, b in history]
        history_tokens = torch.tensor([pair_indices], dtype=torch.long, device=device)
        with torch.no_grad():
            logitsB = opponent(history_tokens)
            probsB = F.softmax(logitsB, dim=-1)  # shape [1, 2]
            actionB = torch.distributions.Categorical(probsB).sample()  # shape [1]
        return actionB.item()

    # Manual strategies receive the two players' move lists separately.
    historyA = [a for a, _ in history]
    historyB = [b for _, b in history]
    return opponent(historyA, historyB)

In [None]:
def play_one_episode(env, policyA, opponent, device='cpu'):
    """
    Roll out one full IPD episode with ``policyA`` as Player A versus ``opponent``.

    A's first move is hard-coded to Cooperate and recorded with a constant
    0.0 log-prob that carries no gradient; later moves are sampled from
    ``policyA`` given the tokenized history. B's move is chosen from the
    history *before* A's current move is applied.

    Returns:
      - actionA_logprobs: list of shape-[1] log-prob tensors, one per step
      - rewardsA: list of A's reward for each step
    """
    logprobs, rewards = [], []
    history = env.reset()  # empty at start
    finished = False

    while not finished:
        tokens = [pair_to_int[pair] for pair in history]

        if tokens:
            history_tokens = torch.tensor([tokens], dtype=torch.long, device=device)
            probs = F.softmax(policyA(history_tokens), dim=-1)  # [1, 2]
            dist = torch.distributions.Categorical(probs)
            actionA = dist.sample()  # shape [1]
            logprobA = dist.log_prob(actionA)
        else:
            # Scripted opening: cooperate, with a dummy (gradient-free) log-prob.
            actionA = torch.tensor([0], device=device)
            logprobA = torch.tensor([0.0], device=device)

        actionB = get_action_b(history, opponent, device=device)
        history, rA, _, finished = env.step(actionA.item(), actionB)

        logprobs.append(logprobA)
        rewards.append(rA)

    return logprobs, rewards

In [None]:
def compute_returns(rewards, gamma=1.0):
    """
    Compute discounted returns-to-go for REINFORCE.

    G_t = r_t + gamma * G_{t+1}, with G after the last step equal to 0.
    With gamma=1.0 each entry is simply the sum of the current and all future
    rewards.

    Args:
        rewards: sequence of per-step rewards, in time order.
        gamma: discount factor.
    Returns:
        list of returns, same length and order as ``rewards``.
    """
    returns = []
    running = 0
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    # Built back-to-front with O(1) appends; one reverse restores time order
    # (repeated insert(0, ...) would make this O(n^2)).
    returns.reverse()
    return returns

def train_against_mixture(policyA,
                          policy_pool,
                          manual_strategies,
                          episodes=5000,
                          batch_size=10,       # Number of episodes per update
                          gamma=1.0,
                          lr=1e-3,
                          max_steps=200,
                          snapshot_interval=100,
                          device='cpu',
                          max_pool_size=10,
                          log_interval=100):
    """
    Train Player A with REINFORCE against a mixture of opponents.

    Only A learns, and only from its own reward. For each episode, Player B is
    drawn by ``pick_opponent`` from ``policy_pool`` (frozen snapshots of A) or
    from ``manual_strategies``. After every ``batch_size`` episodes, one
    policy-gradient step is taken on ``-mean(log_prob * (G_t - mean(G)))``
    over all steps in the batch.

    Args:
        policyA: the policy being trained (moved to ``device`` in place).
        policy_pool: list of opponent snapshots; mutated in place as snapshots
            of A are appended and the oldest are evicted.
        manual_strategies: list of hand-written strategy callables.
        episodes: total number of episodes to play.
        batch_size: episodes per gradient update (the last batch may be shorter).
        gamma: discount factor for returns-to-go.
        lr: Adam learning rate.
        max_steps: rounds per episode.
        snapshot_interval: add a snapshot of A to the pool each time this many
            more episodes have been played.
        device: device for the model and training tensors.
        max_pool_size: maximum number of snapshots kept in ``policy_pool``.
        log_interval: print the mean undiscounted return of the most recent
            ``log_interval`` episodes each time this many more have elapsed.

    Returns:
        List of A's undiscounted return for every episode, in play order.

    Raises:
        ValueError: if batch_size, snapshot_interval, max_pool_size or
            log_interval is smaller than 1.
    """
    for name, value in (("batch_size", batch_size),
                        ("snapshot_interval", snapshot_interval),
                        ("max_pool_size", max_pool_size),
                        ("log_interval", log_interval)):
        if value < 1:
            raise ValueError(f"{name} must be >= 1, got {value!r}")

    env = IteratedPrisonersDilemma(max_steps=max_steps)
    # Move the model first so the optimizer is built over on-device parameters.
    policyA.to(device)
    optimizer = optim.Adam(policyA.parameters(), lr=lr)

    episode_returns = []
    episode_count = 0
    # Thresholds instead of `episode_count % k == 0`, so snapshots and logs
    # still fire when batch_size does not divide the interval.
    next_snapshot = snapshot_interval
    next_log = log_interval

    while episode_count < episodes:
        all_logprobs = []
        all_returns = []

        # Gather up to `batch_size` episodes without overshooting `episodes`.
        for _ in range(min(batch_size, episodes - episode_count)):
            opponent = pick_opponent(policy_pool, manual_strategies)
            actionA_logprobs, rewardsA = play_one_episode(env, policyA, opponent, device=device)

            episode_returns.append(sum(rewardsA))
            all_logprobs.extend(actionA_logprobs)
            all_returns.extend(compute_returns(rewardsA, gamma=gamma))
            episode_count += 1

        # REINFORCE loss for A with the batch-mean return as a constant baseline.
        logprobs_t = torch.cat(all_logprobs)  # shape [N], N = steps in batch
        returns_t = torch.tensor(all_returns, dtype=torch.float32, device=device)
        advantages_t = returns_t - returns_t.mean()
        loss = -(logprobs_t * advantages_t).mean()

        # If every step was A's scripted first move (e.g. max_steps == 1) the
        # loss has no graph and backward() would raise; skip the update then.
        if loss.requires_grad:
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(policyA.parameters(), max_norm=1.0)
            optimizer.step()

        if episode_count >= next_snapshot:
            policy_pool.append(snapshot_policy(policyA))
            del policy_pool[:-max_pool_size]  # evict the oldest beyond the cap
            while next_snapshot <= episode_count:
                next_snapshot += snapshot_interval

        if episode_count >= next_log:
            recent = episode_returns[-log_interval:]
            avg_return = sum(recent) / len(recent)  # undiscounted total average
            print(f"Episode {episode_count}: total A reward = {avg_return}, loss={loss.item():.3f}")
            while next_log <= episode_count:
                next_log += log_interval

    fig, ax = plt.subplots()
    ax.plot(episode_returns)
    ax.set(xlabel="Episode", ylabel="Return", title="Training Progress")
    plt.show()
    return episode_returns

In [None]:
# Seed every RNG the experiment draws from so Restart & Run All is repeatable:
# Python's `random` (manual strategies, opponent selection) and torch (weight
# init, action sampling). Bitwise-identical CUDA runs may additionally need
# torch.use_deterministic_algorithms(True).
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"

num_rounds = 200

# Create A's policy
policyA = TransformerPolicy(max_length=num_rounds, d_model=32, nhead=4, num_layers=2, device=device)

# Create an empty policy pool (or maybe with an initial snapshot)
policy_pool = []  # later we'll add snapshots of A

# Train A
train_against_mixture(policyA,
                      policy_pool,
                      manual_strategies,
                      episodes=5000,
                      batch_size=10,
                      gamma=1.0,
                      lr=1e-3,
                      max_steps=num_rounds,
                      snapshot_interval=100,
                      device=device)

# After training, you can run play_one_episode(...)
# and inspect the learned behavior.