In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque

# --- Hyperparameters ---
EPISODES = 15_000          # number of training games played by train_dqn
GAMMA = 0.95               # discount applied to the bootstrapped next-state value
EPSILON = 1.0              # starting exploration probability (decayed in place during training)
EPSILON_DECAY = 0.9995     # multiplicative decay applied to EPSILON once per episode
EPSILON_MIN = 0.05         # EPSILON is no longer decayed once it is not above this value
LEARNING_RATE = 1e-3       # step size handed to the Adam optimizer
BATCH_SIZE = 64            # maximum number of transitions sampled per training step
MEMORY_SIZE = 10_000       # capacity of the replay buffer

# Track training stats
# One entry per gradient step.
losses = []
# One entry per episode.
epsilons, avg_q_values, episode_rewards = [], [], []
# One entry per 500-episode reporting window: outcome fractions ...
win_rates, draw_rates, loss_rates = [], [], []
# ... and the raw outcome counts for the same window.
rolling_win, rolling_draw, rolling_loss = [], [], []

# --- Tic Tac Toe Environment ---
class TicTacToeEnv:
    """Tic Tac Toe board stored as a flat array of 9 cells.

    Cell values are 0 (empty), 1 (player 1) and -1 (player -1). Rewards
    returned by ``step`` are always expressed from player 1's point of view.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the board and game flags; return a copy of the empty board."""
        self.winner = None
        self.done = False
        self.board = np.zeros(9, dtype=int)
        return self.board.copy()

    def get_valid_actions(self):
        """Return the indices (0-8) of all empty cells, in ascending order."""
        return [cell for cell, mark in enumerate(self.board) if mark == 0]

    def check_winner(self, board):
        """Return 1 or -1 if that player owns a full line on ``board``, else 0."""
        lines = (
            (0, 1, 2), (3, 4, 5), (6, 7, 8),   # rows
            (0, 3, 6), (1, 4, 7), (2, 5, 8),   # columns
            (0, 4, 8), (2, 4, 6),              # diagonals
        )
        for line in lines:
            # Cells are in {-1, 0, 1}, so only a uniform line sums to +/-3.
            line_sum = sum(board[cell] for cell in line)
            if line_sum == 3:
                return 1
            if line_sum == -3:
                return -1
        return 0

    def step(self, action, player=1):
        """Place ``player``'s mark on cell ``action``.

        Returns ``(board_copy, reward, done)``:
          * occupied cell or finished game -> reward -10, done True
            (board and game flags are left untouched);
          * winning move -> reward 1 if player 1 won, otherwise -1;
          * board filled without a winner -> reward 0.5;
          * otherwise -> reward 0, done False.
        """
        cell_taken = self.board[action] != 0
        if cell_taken or self.done:
            return self.board.copy(), -10, True

        self.board[action] = player

        outcome = self.check_winner(self.board)
        if outcome != 0:
            self.winner = outcome
            self.done = True
            reward = 1 if outcome == 1 else -1
            return self.board.copy(), reward, True

        if not np.any(self.board == 0):
            self.winner = 0
            self.done = True
            return self.board.copy(), 0.5, True

        return self.board.copy(), 0, False

# --- DQN Model ---
class DQN(nn.Module):
    """Fully connected Q-network: 9 board cells in, one Q-value per cell out."""

    def __init__(self):
        super().__init__()
        hidden_units = 128
        # Layers are created in input-to-output order and wrapped in a single
        # Sequential stored under ``self.model``.
        layers = [
            nn.Linear(9, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, 9),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the 9 action values for each board in ``x``."""
        q_values = self.model(x)
        return q_values

# --- Replay Memory ---
class ReplayMemory:
    """Bounded FIFO store of experiences with uniform random sampling."""

    def __init__(self, size):
        # A deque with maxlen silently drops the oldest entry once full.
        self.memory = deque(maxlen=size)

    def add(self, experience):
        """Append one experience to the buffer."""
        self.memory.append(experience)

    def sample(self, batch_size):
        """Return up to ``batch_size`` distinct stored experiences as a list.

        When fewer than ``batch_size`` experiences are stored, all of them
        are returned (in random order).
        """
        stored = len(self.memory)
        draw_count = batch_size if batch_size < stored else stored
        return random.sample(self.memory, draw_count)

# --- Training Function ---
def _select_action(model, state, valid_actions, device, epsilon):
    """Choose an epsilon-greedy action restricted to ``valid_actions``.

    Returns ``(action, greedy_q)``. ``greedy_q`` is the Q-value of the chosen
    action when the greedy branch was taken, and ``None`` for a random move.
    """
    if random.random() < epsilon:
        return random.choice(valid_actions), None

    state_tensor = torch.as_tensor(
        state, dtype=torch.float32, device=device
    ).unsqueeze(0)
    with torch.no_grad():
        valid_q = model(state_tensor)[0][valid_actions]
    best = valid_q.argmax().item()
    return valid_actions[best], valid_q[best].item()


def _optimize_model(model, optimizer, criterion, batch, device):
    """Run one gradient step on ``batch`` and return the scalar loss.

    ``batch`` is a sequence of ``(state, action, reward, next_state, done)``
    tuples as stored by ``train_dqn``.
    """
    states, actions, rewards, next_states, dones = zip(*batch)

    # Stack the per-transition boards into one ndarray first: building a
    # tensor directly from a tuple of ndarrays is very slow and makes PyTorch
    # emit a UserWarning.
    states = torch.as_tensor(np.array(states), dtype=torch.float32, device=device)
    next_states = torch.as_tensor(
        np.array(next_states), dtype=torch.float32, device=device
    )
    actions = torch.as_tensor(actions, dtype=torch.int64, device=device).unsqueeze(1)
    rewards = torch.as_tensor(rewards, dtype=torch.float32, device=device).unsqueeze(1)
    dones = torch.as_tensor(dones, dtype=torch.bool, device=device).unsqueeze(1)

    q_values = model(states).gather(1, actions)
    with torch.no_grad():
        next_q = model(next_states)
        # Only empty squares are legal in the next state. The outputs for
        # occupied squares are never trained, so they must not leak into the
        # bootstrap target.
        next_q = next_q.masked_fill(next_states != 0, float("-inf"))
        max_next_q = next_q.max(dim=1, keepdim=True)[0]
        # Terminal transitions do not bootstrap. torch.where (rather than a
        # multiply by ~dones) also avoids -inf * 0 = nan on a full board.
        max_next_q = torch.where(dones, torch.zeros_like(max_next_q), max_next_q)
        target_q = rewards + GAMMA * max_next_q

    loss = criterion(q_values, target_q)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


def train_dqn():
    """Train a DQN agent (player 1, moving first) against a random opponent.

    Runs ``EPISODES`` games. After every agent move (plus the opponent's
    reply, if the game is still running) the transition is stored in replay
    memory and one optimisation step is taken on a sampled batch.

    Side effects:
      * decays the module-level ``EPSILON`` in place, once per episode;
      * appends to the module-level metric lists (``losses``, ``epsilons``,
        ``avg_q_values``, ``episode_rewards`` and, every 500 episodes, the
        win/draw/loss series);
      * prints a win/loss/draw summary every 500 episodes.

    Returns:
        The trained ``DQN`` model, left on the device it was trained on.
    """
    global EPSILON
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    env = TicTacToeEnv()
    model = DQN().to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.MSELoss()
    memory = ReplayMemory(MEMORY_SIZE)
    stats = {"win": 0, "draw": 0, "loss": 0}

    for episode in range(1, EPISODES + 1):
        state = env.reset()
        done = False
        total_reward = 0
        q_total = 0
        q_count = 0

        while not done:
            action, greedy_q = _select_action(
                model, state, env.get_valid_actions(), device, EPSILON
            )
            if greedy_q is not None:
                q_total += greedy_q
                q_count += 1

            next_state, reward, done = env.step(action, player=1)
            total_reward += reward

            if not done:
                opp_action = random.choice(env.get_valid_actions())
                next_state, opp_reward, done = env.step(opp_action, player=-1)
                # An opponent win is credited to the agent's preceding move.
                if opp_reward == -1:
                    reward = -1
                total_reward += opp_reward

            memory.add((state, action, reward, next_state, done))
            state = next_state

            # Training step (skipped until at least two transitions exist).
            batch = memory.sample(BATCH_SIZE)
            if len(batch) > 1:
                losses.append(
                    _optimize_model(model, optimizer, criterion, batch, device)
                )

        if EPSILON > EPSILON_MIN:
            # Clamp so the final decay step cannot undershoot the floor.
            EPSILON = max(EPSILON_MIN, EPSILON * EPSILON_DECAY)

        epsilons.append(EPSILON)
        avg_q_values.append(q_total / q_count if q_count > 0 else 0)
        episode_rewards.append(total_reward)

        if env.winner == 1:
            stats["win"] += 1
        elif env.winner == -1:
            stats["loss"] += 1
        else:
            stats["draw"] += 1

        if episode % 500 == 0:
            total = stats["win"] + stats["draw"] + stats["loss"]
            win_rates.append(stats["win"] / total)
            draw_rates.append(stats["draw"] / total)
            loss_rates.append(stats["loss"] / total)
            rolling_win.append(stats["win"])
            rolling_draw.append(stats["draw"])
            rolling_loss.append(stats["loss"])
            print(f"Episode {episode} | W: {stats['win']} | L: {stats['loss']} | D: {stats['draw']}")
            stats = {"win": 0, "draw": 0, "loss": 0}

    return model

# --- Plotting Function ---
def plot_metrics():
    """Draw the collected training metrics on a 3x2 grid and show the figure.

    Reads the module-level metric lists; takes no arguments and returns None.
    """
    _, axs = plt.subplots(3, 2, figsize=(15, 10))

    # Panels that show a single unlabeled curve.
    single_curve_panels = (
        (axs[0, 0], losses, "Training Loss"),
        (axs[1, 0], epsilons, "Epsilon Decay"),
        (axs[1, 1], avg_q_values, "Average Q-values"),
        (axs[2, 0], episode_rewards, "Reward per Episode"),
    )
    for ax, series, title in single_curve_panels:
        ax.plot(series)
        ax.set_title(title)

    # Panels that overlay the three game outcomes with a legend.
    outcome_labels = ("Win", "Draw", "Loss")
    outcome_panels = (
        (axs[0, 1], (win_rates, draw_rates, loss_rates), "Win/Draw/Loss Rates"),
        (
            axs[2, 1],
            (rolling_win, rolling_draw, rolling_loss),
            "Rolling Win/Draw/Loss (500 ep)",
        ),
    )
    for ax, curves, title in outcome_panels:
        for series, label in zip(curves, outcome_labels):
            ax.plot(series, label=label)
        ax.set_title(title)
        ax.legend()

    plt.tight_layout()
    plt.show()

# --- Pretty Console Game ---
def play_against_ai(model, ai_first=True):
    """Play an interactive console game of Tic Tac Toe against ``model``.

    The human plays 'O' (-1) and the AI plays 'X' (1). Moves are read from
    stdin as ``row col`` (or two adjacent digits), each in 0-2.

    Args:
        model: Trained network mapping a batch of 9-cell boards to 9 Q-values.
        ai_first: When True (default) the AI opens the game. ``train_dqn``
            only ever shows the network positions in which X moves first, so
            letting the human open puts every position the AI sees outside
            what it was trained on. Pass False to let the human open.
    """
    env = TicTacToeEnv()
    env.reset()
    done = False

    # Run inference on whatever device the model lives on; a CPU input tensor
    # fed to a CUDA model raises a device-mismatch error.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    print("\nYou are 'O' and AI is 'X'. Enter your move as: row col (0-2)\n")

    def display_board(board):
        """Print the board as a 3x3 grid of X / O / blank."""
        symbols = {0: ' ', 1: 'X', -1: 'O'}
        for i in range(3):
            row = ' | '.join(symbols[board[j]] for j in range(i*3, i*3+3))
            print(f" {row} ")
            if i < 2:
                print("-----------")

    def parse_move(inp):
        """Convert user text to a cell index 0-8, or None if it is malformed."""
        text = inp.strip()
        try:
            if len(text) == 2 and text.isdigit():
                row, col = int(text[0]), int(text[1])
            else:
                row, col = map(int, text.split())
        except ValueError:
            # Non-numeric input or the wrong number of fields.
            return None
        if 0 <= row <= 2 and 0 <= col <= 2:
            return row * 3 + col
        return None

    def ai_turn():
        """Play the AI's highest-valued legal move; return the new done flag."""
        valid = env.get_valid_actions()
        state_tensor = torch.as_tensor(
            env.board, dtype=torch.float32, device=device
        ).unsqueeze(0)
        with torch.no_grad():
            valid_q = model(state_tensor)[0][valid]
        ai_move = valid[valid_q.argmax().item()]
        _, _, finished = env.step(ai_move, player=1)
        return finished

    if ai_first:
        done = ai_turn()

    while not done:
        display_board(env.board)
        move_str = input("\nYour move (row col): ")
        move = parse_move(move_str)
        if move is None or move not in env.get_valid_actions():
            print("Invalid move. Try again.")
            continue

        _, _, done = env.step(move, player=-1)
        if not done:
            done = ai_turn()

    display_board(env.board)
    print()
    if env.winner == 1:
        print("AI wins!")
    elif env.winner == -1:
        print("You win!")
    else:
        print("It's a draw!")

# --- Run Everything ---
# Train the agent for EPISODES games (progress is printed every 500 episodes)
# and keep the returned network in the module-level name ``model``.
model = train_dqn()

# Start an interactive console game against the freshly trained network.
play_against_ai(model)


  states = torch.FloatTensor(states).to(device)


Episode 500 | W: 287 | L: 137 | D: 76
Episode 1000 | W: 276 | L: 145 | D: 79
Episode 1500 | W: 278 | L: 122 | D: 100
Episode 2000 | W: 279 | L: 120 | D: 101
Episode 2500 | W: 259 | L: 140 | D: 101
Episode 3000 | W: 246 | L: 121 | D: 133
Episode 3500 | W: 255 | L: 140 | D: 105
Episode 4000 | W: 242 | L: 118 | D: 140
Episode 4500 | W: 232 | L: 127 | D: 141
Episode 5000 | W: 285 | L: 106 | D: 109
Episode 5500 | W: 260 | L: 102 | D: 138
Episode 6000 | W: 271 | L: 114 | D: 115
Episode 6500 | W: 276 | L: 85 | D: 139


In [None]:
# Plot the metrics accumulated in the module-level lists during training.
plot_metrics()