In [None]:
# combat_rnn_game.py
# Requirements: torch, numpy
# Run: python combat_rnn_game.py

import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque, Counter
import os
import json
import time

# ---------------- Config ----------------
# Model shape
SEQ_LEN = 6  # how many previous player moves to feed the RNN
HIDDEN_SIZE = 64
NUM_LAYERS = 1

# Optimisation settings
LR = 0.01
PRETRAIN_SAMPLES = 1500  # synthetic pretrain size
ONLINE_LR = 0.005  # smaller LR for online updates

# Runtime: use the GPU when torch reports one, otherwise the CPU
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
MODEL_PATH = "combat_rnn_model.pt"

# Move vocabulary and the two lookup tables between move name and class index
MOVES = ['Attack', 'Block', 'Dodge']
move_to_idx = dict(zip(MOVES, range(len(MOVES))))
idx_to_move = dict(enumerate(MOVES))


# Utility: which move beats which
# We'll define: Attack beats Dodge (attack hits dodger),
# Block beats Attack (block defends), Dodge beats Block (dodge around block).
# Each key defeats the move it maps to.
WIN_MAP = dict(
    Attack='Dodge',  # Attack beats Dodge
    Block='Attack',  # Block beats Attack
    Dodge='Block',   # Dodge beats Block
)
# To choose bot move that counters predicted player move, pick the move that beats the player's predicted move:
# e.g. if player predicted to do Attack -> bot should do Block (because Block beats Attack)
def counter_move(pred_player_move):
    """Return the move that beats ``pred_player_move`` according to WIN_MAP.

    If no entry in WIN_MAP beats the given move (e.g. an unknown move name),
    a uniformly random move from MOVES is returned instead.
    """
    winners = [move for move, loser in WIN_MAP.items() if loser == pred_player_move]
    if winners:
        return winners[0]
    return random.choice(MOVES)


# ---------------- Synthetic opponent generators ----------------
def aggressive_player():
    """Endless generator of moves biased toward Attack.

    Each move uses one ``random.random()`` draw: below 0.7 yields 'Attack',
    below 0.85 yields 'Dodge', anything else yields 'Block'.
    """
    # (upper bound, move) pairs checked in order; 'Block' is the fallback
    thresholds = ((0.7, 'Attack'), (0.85, 'Dodge'))
    while True:
        roll = random.random()
        yield next((move for limit, move in thresholds if roll < limit), 'Block')

def defensive_player():
    """Endless generator of moves biased toward Block.

    Each move uses one ``random.random()`` draw: below 0.7 yields 'Block',
    below 0.85 yields 'Attack', anything else yields 'Dodge'.
    """
    # (upper bound, move) pairs checked in order; 'Dodge' is the fallback
    thresholds = ((0.7, 'Block'), (0.85, 'Attack'))
    while True:
        roll = random.random()
        yield next((move for limit, move in thresholds if roll < limit), 'Dodge')

def random_player():
    """Endless generator that picks every move uniformly at random from MOVES."""
    while True:
        move = random.choice(MOVES)
        yield move

def patterned_player(pattern=('Attack', 'Attack', 'Block', 'Dodge')):
    """Endless generator that replays ``pattern`` in order, over and over.

    Args:
        pattern: iterable of move names to cycle through. The default is an
            immutable tuple rather than a shared mutable list.

    Raises:
        ValueError: when the generator is first advanced, if ``pattern`` is
            empty (there would be nothing to yield).
    """
    # Snapshot the pattern so the cycle cannot change if the caller mutates
    # their list while the generator is still in use.
    moves = tuple(pattern)
    if not moves:
        raise ValueError("pattern must contain at least one move")
    while True:
        yield from moves

def mixed_player(switch_every=50):
    """Endless generator that hops between the other synthetic strategies.

    One generator of each kind (aggressive, defensive, patterned, random) is
    created up front. Every ``switch_every`` moves one of them is chosen at
    random (the same one may be chosen again) and moves are drawn from it
    until the next switch. Each underlying generator keeps its own position
    between switches.

    Args:
        switch_every: number of moves between strategy re-selections.
            Defaults to 50.

    Raises:
        ValueError: when the generator is first advanced, if
            ``switch_every`` is smaller than 1.
    """
    if switch_every < 1:
        raise ValueError("switch_every must be at least 1")
    gens = [aggressive_player(), defensive_player(), patterned_player(), random_player()]
    gen = None
    i = 0
    while True:
        # i == 0 always triggers a selection, so no separate initial pick is needed
        if i % switch_every == 0:
            gen = random.choice(gens)
        yield next(gen)
        i += 1

# Opponent name -> zero-argument factory returning a fresh move generator
SYNTH_OPS = dict(
    aggressive=aggressive_player,
    defensive=defensive_player,
    random=random_player,
    # the patterned opponent uses a fixed three-move cycle here
    patterned=lambda: patterned_player(['Attack', 'Dodge', 'Block']),
    mixed=mixed_player,
)

# ---------------- Dataset helpers ----------------
def moves_to_onehot(seq):
    """One-hot encode a sequence of move names.

    Returns a float32 array of shape (len(seq), len(MOVES)) in which row i
    has a 1.0 in the column given by move_to_idx for seq[i] and zeros
    elsewhere. A move name missing from move_to_idx raises KeyError.
    """
    n_rows = len(seq)
    columns = np.asarray([move_to_idx[move] for move in seq], dtype=np.intp)
    onehot = np.zeros((n_rows, len(MOVES)), dtype=np.float32)
    # set one cell per row in a single vectorised assignment
    onehot[np.arange(n_rows), columns] = 1.0
    return onehot

def seq_to_tensor(seq):
    """Encode ``seq`` as a float32 tensor on DEVICE with a leading batch axis.

    The result has shape (1, len(seq), len(MOVES)).
    """
    encoded = moves_to_onehot(seq)
    as_tensor = torch.tensor(encoded, dtype=torch.float32, device=DEVICE)
    return as_tensor.unsqueeze(0)


# ---------------- Model definition ----------------
class GRUPredictor(nn.Module):
    """GRU followed by a linear layer that scores each possible next move.

    Input is a batch-first tensor of shape (batch, seq_len, input_size); the
    output is one row of unnormalised scores (logits) per batch element,
    shape (batch, output_size).
    """

    def __init__(self, input_size=3, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS, output_size=3):
        super().__init__()
        # The attribute names `gru` and `fc` are the state_dict key prefixes,
        # so they must stay as they are for saved models to load.
        self.gru = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        # x: (batch, seq_len, input_size)
        step_outputs, _ = self.gru(x)        # (batch, seq_len, hidden)
        final_step = step_outputs[:, -1, :]  # (batch, hidden)
        return self.fc(final_step)           # (batch, output_size)

# ---------------- Train on synthetic data for a warm start ----------------
def generate_pretrain_data(gen_fn, n_samples=1000):
    """Build (window, next move) training pairs from a synthetic opponent.

    Args:
        gen_fn: zero-argument callable returning a move generator.
        n_samples: number of training pairs to produce.

    Returns:
        (X, Y): X is a list of SEQ_LEN-long lists of move names and Y is the
        list of class indices (via move_to_idx) of the move that followed
        each window.
    """
    stream = gen_fn()
    # prime the sliding window with the first SEQ_LEN moves
    window = deque(maxlen=SEQ_LEN)
    for _ in range(SEQ_LEN):
        window.append(next(stream))

    inputs, targets = [], []
    for _ in range(n_samples):
        snapshot = list(window)
        upcoming = next(stream)
        inputs.append(snapshot)
        targets.append(move_to_idx[upcoming])
        window.append(upcoming)  # the oldest move drops out automatically
    return inputs, targets

def pretrain_model(model, epochs=6, batch_size=64):
    """Warm-start ``model`` on moves produced by every opponent in SYNTH_OPS.

    PRETRAIN_SAMPLES is split evenly across the synthetic opponents. The
    combined samples are shuffled once, then the model is trained with Adam
    (learning rate LR) and cross-entropy loss. The average loss is printed
    after each epoch.

    Args:
        model: the network to train; it is updated in place and left in
            training mode.
        epochs: number of passes over the synthetic data.
        batch_size: number of samples per optimisation step.

    Returns:
        The same ``model`` object.
    """
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    loss_fn = nn.CrossEntropyLoss()

    # build mixed synthetic data: an equal share from each opponent
    X_all, Y_all = [], []
    per_gen = PRETRAIN_SAMPLES // len(SYNTH_OPS)
    for gen_fn in SYNTH_OPS.values():
        X, Y = generate_pretrain_data(gen_fn, per_gen)
        X_all.extend(X)
        Y_all.extend(Y)

    n_total = len(Y_all)
    if n_total == 0:
        # Nothing to train on (PRETRAIN_SAMPLES smaller than the number of
        # opponents); averaging the loss below would divide by zero.
        print("[Pretrain] No synthetic samples generated; skipping pretraining.")
        return model

    # shuffle
    perm = list(range(n_total))
    random.shuffle(perm)

    # Encode the whole shuffled dataset once; the encoding does not change
    # between epochs, so batches are just slices of these tensors.
    x_data = torch.tensor(
        np.stack([moves_to_onehot(X_all[i]) for i in perm]),
        dtype=torch.float32,
        device=DEVICE,
    )
    y_data = torch.tensor([Y_all[i] for i in perm], dtype=torch.long, device=DEVICE)

    # training loop
    for ep in range(epochs):
        total_loss = 0.0
        for start in range(0, n_total, batch_size):
            x_tensor = x_data[start:start + batch_size]
            y_tensor = y_data[start:start + batch_size]
            optimizer.zero_grad()
            logits = model(x_tensor)
            loss = loss_fn(logits, y_tensor)
            loss.backward()
            optimizer.step()
            # weight by batch size so a short final batch is averaged correctly
            total_loss += loss.item() * x_tensor.size(0)
        avg = total_loss / n_total
        print(f"[Pretrain] Epoch {ep+1}/{epochs}  loss={avg:.4f}")
    return model

# ---------------- Online update (one step) ----------------
def online_update(model, optimizer, criterion, seq_history, true_move):
    """Take one gradient step on a single (history -> observed move) sample.

    Args:
        model: network to update; it is switched to training mode.
        optimizer: optimizer used for the step.
        criterion: loss taking (logits, target indices).
        seq_history: iterable of the moves that preceded ``true_move``.
        true_move: the move name that was actually played.

    Returns:
        The loss for this sample as a Python float.
    """
    model.train()
    inputs = seq_to_tensor(list(seq_history))  # (1, seq_len, 3)
    target = torch.tensor([move_to_idx[true_move]], dtype=torch.long, device=DEVICE)

    optimizer.zero_grad()
    sample_loss = criterion(model(inputs), target)
    sample_loss.backward()
    optimizer.step()
    return sample_loss.item()

# ---------------- Save / Load helpers ----------------
def save_model(model, path=MODEL_PATH, meta=None):
    """Write the model's state_dict plus an optional metadata dict to ``path``.

    The file holds a dict with keys 'state_dict' and 'meta'; a missing or
    empty ``meta`` is stored as an empty dict.
    """
    checkpoint = {
        'state_dict': model.state_dict(),
        'meta': meta if meta else {},
    }
    torch.save(checkpoint, path)
    print("Saved model to", path)

def load_model(path=MODEL_PATH):
    """Load a GRUPredictor from a checkpoint written by save_model.

    Returns None when ``path`` does not exist; otherwise a new GRUPredictor
    on DEVICE with the stored 'state_dict' applied.
    """
    if os.path.exists(path):
        # NOTE: torch.load unpickles the file, so only open checkpoints you trust.
        checkpoint = torch.load(path, map_location=DEVICE)
        net = GRUPredictor().to(DEVICE)
        net.load_state_dict(checkpoint['state_dict'])
        print("Loaded model from", path)
        return net
    return None

# ---------------- Gameplay ----------------
def play_game(model=None, opponent_mode='human', max_rounds=None):
    """Run the game loop: predict the player's move, counter it, learn from it.

    Each round the model predicts the next player move from the last SEQ_LEN
    moves, the bot plays the counter to that prediction, the player's actual
    move is read (typed, or drawn from a synthetic opponent), the result is
    scored, and the model takes one online SGD step on the observed move.

    Args:
        model: GRUPredictor to use and keep training. When None, a freshly
            initialised (untrained) GRUPredictor is created.
        opponent_mode: 'human' reads moves from stdin; any other value is
            looked up in SYNTH_OPS, falling back to random_player when the
            name is unknown.
        max_rounds: optional cap on the number of scored rounds. None (the
            default) plays until the player quits, stdin is closed, or the
            loop is interrupted with Ctrl-C.

    Returns:
        The model, including any online updates made during play.
    """
    if model is None:
        model = GRUPredictor().to(DEVICE)
    opt_online = torch.optim.SGD(model.parameters(), lr=ONLINE_LR)
    loss_fn = nn.CrossEntropyLoss()

    # initialize history with random moves
    history = deque([random.choice(MOVES) for _ in range(SEQ_LEN)], maxlen=SEQ_LEN)

    # stats
    total_rounds = 0
    correct_preds = 0
    bot_wins = 0
    player_wins = 0
    ties = 0
    pred_counts = Counter()

    # if opponent_mode is synthetic, create generator
    opp_gen = None
    if opponent_mode != 'human':
        gen_fn = SYNTH_OPS.get(opponent_mode, random_player)
        opp_gen = gen_fn()

    key_to_move = {'a': 'Attack', 'b': 'Block', 'd': 'Dodge'}

    print("\n--- Simple Combat Game (Attack, Block, Dodge) ---")
    print("Type: 'a' for Attack, 'b' for Block, 'd' for Dodge, 'q' to quit.")
    print("Bot predicts your next move and chooses a counter.\n")

    try:
        while max_rounds is None or total_rounds < max_rounds:
            # Display last few moves
            print("\nLast moves:", list(history))

            # online_update() leaves the model in training mode, so switch
            # back to eval mode before every prediction.
            model.eval()
            x = seq_to_tensor(list(history))
            with torch.no_grad():
                logits = model(x)
                probs = F.softmax(logits, dim=1).squeeze(0).cpu().numpy()
                pred_idx = int(logits.argmax(dim=1).item())
                pred_move = idx_to_move[pred_idx]

            # Bot chooses counter
            bot_move = counter_move(pred_move)

            # Ask for player move if human, otherwise sample from opp_gen
            if opponent_mode == 'human':
                try:
                    ans = input("Your move (a/b/d or q): ").strip().lower()
                except EOFError:
                    # stdin was closed: treat it like quitting
                    break
                if ans == 'q':
                    break
                if ans not in key_to_move:
                    print("Invalid. Use a/b/d or q.")
                    continue
                player_move = key_to_move[ans]
            else:
                player_move = next(opp_gen)
                print(f"[Simulated player] {player_move}")

            # Evaluate result
            total_rounds += 1
            pred_counts[pred_move] += 1
            if pred_move == player_move:
                correct_preds += 1
            # Decide winner
            if bot_move == player_move:
                ties += 1
                outcome = "Tie"
            elif WIN_MAP[bot_move] == player_move:
                bot_wins += 1
                outcome = "Bot wins"
            elif WIN_MAP[player_move] == bot_move:
                player_wins += 1
                outcome = "Player wins"
            else:
                # only reachable if WIN_MAP does not cover this pair of moves
                outcome = "Undetermined"

            print(f"Bot predicted you would do: {pred_move}  (conf {probs[pred_idx]:.2f})")
            print(f"Bot move: {bot_move} | Your move: {player_move} -> {outcome}")

            # Online learning: one gradient step on (history -> true next).
            # This must run before the move is appended, so the input window
            # is the one the prediction was made from.
            loss_val = online_update(model, opt_online, loss_fn, history, player_move)

            # Update history (append player's move)
            history.append(player_move)

            # Print stats
            acc = correct_preds / total_rounds if total_rounds else 0.0
            print(f"Rounds: {total_rounds} | Pred Acc: {acc:.3f} | Bot wins: {bot_wins} | Player wins: {player_wins} | Ties: {ties} | Last loss: {loss_val:.4f}")

    except KeyboardInterrupt:
        print("\nInterrupted by user.")

    # Summary
    print("\n--- Game Summary ---")
    print("Total rounds:", total_rounds)
    print("Prediction accuracy:", correct_preds, "/", total_rounds, f"({(correct_preds/total_rounds if total_rounds else 0):.3f})")
    print("Bot wins:", bot_wins, "Player wins:", player_wins, "Ties:", ties)
    print("Prediction distribution (top 3):", pred_counts.most_common(3))

    return model

# ---------------- CLI entrypoint ----------------
def main():
    """Interactive entry point: get a model, pick an opponent, play, save.

    Offers to load a saved model when MODEL_PATH exists; otherwise (or if
    declined) creates a new model and offers synthetic pretraining. After
    the game it offers to save the model to disk.
    """
    def ask(prompt):
        # read one answer, normalised the same way for every y/n question
        return input(prompt).strip().lower()

    print("Simple Combat Game with RNN predictor\n")

    # Offer to load existing model
    model = None
    if os.path.exists(MODEL_PATH) and ask(f"Load saved model from {MODEL_PATH}? (y/n): ") == 'y':
        model = load_model(MODEL_PATH)

    if model is None:
        model = GRUPredictor().to(DEVICE)
        # Offer to pretrain on synthetic opponents for a warm start
        if ask("Pretrain model on synthetic opponents for a better start? (y/n): ") == 'y':
            pretrain_model(model, epochs=6)

    # Choose playing mode
    menu = (
        "\nChoose opponent mode:",
        "1) human (you play)",
        "2) aggressive",
        "3) defensive",
        "4) random",
        "5) patterned",
        "6) mixed",
    )
    for line in menu:
        print(line)
    choice = input("Select 1-6: ").strip()
    mode_names = ('human', 'aggressive', 'defensive', 'random', 'patterned', 'mixed')
    modes = {str(number): name for number, name in enumerate(mode_names, start=1)}
    # anything that is not 1-6 falls back to human play
    mode = modes.get(choice, 'human')

    model = play_game(model=model, opponent_mode=mode)

    # After game, offer to save model
    if ask("Save trained model to disk? (y/n): ") == 'y':
        save_model(model, meta={'saved_at': time.time()})

# Start the interactive session only when this module is run as the main
# program, not when it is imported.
if __name__ == "__main__":
    main()

Simple Combat Game with RNN predictor

Pretrain model on synthetic opponents for a better start? (y/n): y
[Pretrain] Epoch 1/6  loss=1.0578
[Pretrain] Epoch 2/6  loss=0.9563
[Pretrain] Epoch 3/6  loss=0.8755
[Pretrain] Epoch 4/6  loss=0.8342
[Pretrain] Epoch 5/6  loss=0.8179
[Pretrain] Epoch 6/6  loss=0.7874

Choose opponent mode:
1) human (you play)
2) aggressive
3) defensive
4) random
5) patterned
6) mixed
Select 1-6: 1

--- Simple Combat Game (Attack, Block, Dodge) ---
Type: 'a' for Attack, 'b' for Block, 'd' for Dodge, 'q' to quit.
Bot predicts your next move and chooses a counter.


Last moves: ['Dodge', 'Dodge', 'Dodge', 'Dodge', 'Attack', 'Dodge']
Your move (a/b/d or q): a
Bot predicted you would do: Dodge  (conf 0.51)
Bot move: Attack | Your move: Attack -> Tie
Rounds: 1 | Pred Acc: 0.000 | Bot wins: 0 | Player wins: 0 | Ties: 1 | Last loss: 1.3675

Last moves: ['Dodge', 'Dodge', 'Dodge', 'Attack', 'Dodge', 'Attack']
Your move (a/b/d or q): d
Bot predicted you would do: Dodge  