In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
from tqdm import tqdm

import time


def get_canonical_form(board):
    """Return the canonical representative of a 3x3 board under its 8 symmetries.

    The four successive 90-degree rotations of ``board`` and the left-right
    mirror of each one are flattened to tuples. The lexicographically smallest
    tuple is selected and reshaped back to 3x3.

    Args:
        board: NumPy array holding the 3x3 board.

    Returns:
        A new ``(3, 3)`` NumPy array containing the canonical orientation.
    """
    flattened_variants = []
    rotated = board.copy()
    for _ in range(4):
        # Every rotation contributes itself and its horizontal mirror image.
        for variant in (rotated, np.fliplr(rotated)):
            flattened_variants.append(tuple(variant.flatten()))
        rotated = np.rot90(rotated)

    # Tuples compare lexicographically, which gives a total order over boards.
    smallest = min(flattened_variants)
    return np.array(smallest).reshape(3, 3)

# Cache of optimal moves, keyed by the flattened board exactly as it was
# encountered. The stored action indices are only meaningful in that
# orientation, so the key must not be a rotated/mirrored (canonical) form.
memoized_policy = {}


def build_optimal_policy_map(board, player):
    """Recursively record the optimal moves for states reachable by optimal play.

    For ``board`` with ``player`` to move, every available action is scored
    with ``-minimax(next_board, -player)`` and all actions that tie for the
    best score are stored in ``memoized_policy`` under ``tuple(board.flatten())``.
    The function then recurses into each of those best moves with the other
    player to move.

    Args:
        board: 3x3 NumPy array for the current position (not modified).
        player: Marker written into the board for the side to move; the
            opponent is ``-player``.

    Returns:
        None. Results are accumulated in the module-level ``memoized_policy``.
    """
    board_key = tuple(board.flatten())

    # Look up the very same key that is written below. Checking a different
    # (canonical) key would make the cache miss for most boards and would skip
    # some boards without ever storing a policy for them.
    if board_key in memoized_policy:
        return

    # Finished positions have no move to record.
    if TicTacToe()._check_winner_on_board(board) is not None:
        return

    game = TicTacToe()
    game.board = board
    available_actions = game.get_available_actions()

    # float("-inf") avoids depending on a module-level `math` import.
    best_score = float("-inf")
    best_moves = []

    for action in available_actions:
        new_board = board.copy()
        row, col = action // 3, action % 3
        new_board[row, col] = player
        # NOTE(review): presumably minimax scores the position for the side to
        # move, hence the negation to get the mover's view — confirm.
        score = -minimax(new_board, -player)

        if score > best_score:
            best_score = score
            best_moves = [action]
        elif score == best_score:
            best_moves.append(action)

    # Store before recursing so transpositions reached deeper in the tree
    # find this entry and stop immediately.
    memoized_policy[board_key] = best_moves

    for move in best_moves:
        next_board = board.copy()
        row, col = move // 3, move % 3
        next_board[row, col] = player
        build_optimal_policy_map(next_board, -player)

# Attach a helper to the TicTacToe class that evaluates an arbitrary board.
def _check_winner_on_board(self, board):
    """Return the result of ``TicTacToe.check_winner`` for the given ``board``.

    ``self`` is ignored: ``check_winner`` is invoked on a throwaway object
    whose only attribute is ``board``.
    NOTE(review): this presumably relies on ``check_winner`` reading nothing
    but ``self.board`` — confirm against the class definition.
    """
    board_holder = type('obj', (object,), {'board': board})()
    return TicTacToe.check_winner(board_holder)


TicTacToe._check_winner_on_board = _check_winner_on_board


def generate_all_optimal_trajectories():
    """Enumerate every game that follows the pre-computed optimal policy.

    The policy map is first populated by ``build_optimal_policy_map`` starting
    from the empty board with player ``1`` to move. A depth-first walk then
    follows each stored optimal move; whenever a finished position is reached,
    one trajectory dict is appended to the result.

    Returns:
        A list of dicts with the keys ``'states'`` (flattened boards seen
        before each move), ``'actions'`` (the moves played) and
        ``'rewards_to_go'`` (float array in which every step carries the final
        result reported by ``_check_winner_on_board``).
    """
    print("Pre-calculating optimal policy for all unique states...")
    start_board = np.zeros((3, 3), dtype=int)
    build_optimal_policy_map(start_board, 1)

    print("Generating all unique optimal trajectories...")
    all_trajectories = []

    def find_paths(board, player, path):
        """Depth-first walk that emits a trajectory at each finished position."""
        outcome = TicTacToe()._check_winner_on_board(board)

        if outcome is None:
            # Game still running: branch on every recorded optimal move.
            # Positions without a recorded policy end the walk silently.
            candidate_moves = memoized_policy.get(tuple(board.flatten()))
            for move in candidate_moves or ():
                successor = board.copy()
                successor[move // 3, move % 3] = player
                find_paths(successor, -player, path + [(board.flatten(), move)])
            return

        if path:
            states, actions = zip(*path)
        else:
            states, actions = [], []

        # Every step of a finished game gets the final result as its
        # reward-to-go (per the original note: 1 win, -1 loss, 0 draw).
        all_trajectories.append({
            'states': np.array(states),
            'actions': np.array(actions),
            'rewards_to_go': np.full(len(actions), outcome, dtype=float),
        })

    find_paths(start_board, 1, [])
    return all_trajectories


# --- 3. PyTorch Dataset & Decision Transformer 모델 (이전과 동일) ---
class TicTacToeDataset(Dataset):
    """Dataset serving fixed-length, right-padded windows of stored trajectories.

    Each item is a window that starts at a random step of one trajectory and
    spans at most ``context_length`` steps; shorter windows are padded at the
    end.
    """

    def __init__(self, trajectories, context_length):
        """Keep a reference to the trajectory list and the window length."""
        self.context_length = context_length
        self.trajectories = trajectories

    def __len__(self):
        """Return the number of trajectories (one sampled window per item)."""
        return len(self.trajectories)

    def __getitem__(self, idx):
        """Sample a random window from trajectory ``idx``.

        Returns:
            Tuple ``(states, actions, rtgs, mask)``:
            ``states`` is float32 with ``context_length`` rows (zero padded),
            ``actions`` is long with padding value ``-1``,
            ``rtgs`` is float32 with a trailing singleton dimension,
            ``mask`` holds 1 for real steps and 0 for padded steps.
        """
        trajectory = self.trajectories[idx]
        first_step = random.randint(0, len(trajectory['states']) - 1)
        window = slice(first_step, first_step + self.context_length)

        state_window = trajectory['states'][window]
        action_window = trajectory['actions'][window]
        rtg_window = trajectory['rewards_to_go'][window]

        valid_steps = len(state_window)
        pad_steps = self.context_length - valid_steps

        # Pad only along the time axis; -1 marks "no action" for padded steps.
        padded_states = np.pad(state_window, ((0, pad_steps), (0, 0)), 'constant')
        padded_actions = np.pad(action_window, (0, pad_steps), 'constant', constant_values=-1)
        padded_rtgs = np.pad(rtg_window, (0, pad_steps), 'constant')

        states = torch.tensor(padded_states, dtype=torch.float32)
        actions = torch.tensor(padded_actions, dtype=torch.long)
        rtgs = torch.tensor(padded_rtgs, dtype=torch.float32).unsqueeze(1)
        mask = torch.cat([torch.ones(valid_steps), torch.zeros(pad_steps)], dim=0)
        return states, actions, rtgs, mask

class DecisionTransformer(nn.Module):
    """Causal Transformer over interleaved (return-to-go, state, action) tokens.

    Each step contributes three tokens in the order return-to-go, state,
    action. The action for a step is predicted from the encoder output at that
    step's state token, which under the causal mask can only see the current
    return-to-go and state plus all tokens of earlier steps.
    """

    def __init__(self, state_dim, action_dim, n_head, n_layer, d_model, context_length):
        """Create the embeddings, the encoder stack and the action head.

        Args:
            state_dim: Size of one flattened state vector.
            action_dim: Number of discrete actions.
            n_head: Attention heads per encoder layer.
            n_layer: Number of encoder layers.
            d_model: Embedding width.
            context_length: Number of distinct timestep indices that can be
                embedded (valid timesteps are ``0 .. context_length - 1``).
        """
        super().__init__()
        self.d_model = d_model
        self.context_length = context_length
        self.embed_state = nn.Linear(state_dim, d_model)
        # One extra row: action -1 (padding / not yet played) maps to index 0.
        self.embed_action = nn.Embedding(action_dim + 1, d_model)
        self.embed_rtg = nn.Linear(1, d_model)
        self.embed_timestep = nn.Embedding(context_length, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_head, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layer)
        self.predict_action = nn.Linear(d_model, action_dim)

    def forward(self, states, actions, rtgs, timesteps):
        """Predict action logits for every step of the input sequences.

        The sequence length is taken from ``states`` rather than from
        ``context_length``, so shorter sequences are accepted as well.

        Args:
            states: Float tensor ``(batch, seq_len, state_dim)``.
            actions: Long tensor ``(batch, seq_len)`` with values in
                ``-1 .. action_dim - 1`` (``-1`` meaning "no action").
            rtgs: Float tensor ``(batch, seq_len, 1)``.
            timesteps: Long tensor ``(batch, seq_len)`` with values below
                ``context_length``.

        Returns:
            Float tensor ``(batch, seq_len, action_dim)`` of action logits.
        """
        batch_size, seq_len = states.shape[0], states.shape[1]

        time_embeddings = self.embed_timestep(timesteps)
        rtg_embeddings = self.embed_rtg(rtgs) + time_embeddings
        state_embeddings = self.embed_state(states) + time_embeddings
        # Shift by one so the padding action -1 lands on embedding row 0.
        action_embeddings = self.embed_action(actions + 1) + time_embeddings

        # (batch, seq_len, 3, d_model) -> (batch, 3 * seq_len, d_model):
        # tokens are laid out as r_0, s_0, a_0, r_1, s_1, a_1, ...
        stacked_inputs = torch.stack(
            (rtg_embeddings, state_embeddings, action_embeddings), dim=2
        ).reshape(batch_size, 3 * seq_len, self.d_model)

        causal_mask = nn.Transformer.generate_square_subsequent_mask(3 * seq_len).to(states.device)
        encoder_output = self.transformer_encoder(stacked_inputs, mask=causal_mask)

        # Every third token starting at index 1 is a state token.
        state_token_output = encoder_output[:, 1::3, :]
        return self.predict_action(state_token_output)

# --- 4. 학습 및 평가 (하이퍼파라미터 조정) ---
def train():
    """Build the optimal-play dataset, fit a Decision Transformer and return it.

    Returns:
        The trained ``DecisionTransformer`` instance, left on the device that
        was used for training.
    """
    # Hyperparameters
    context_length = 9     # a game is over after at most 9 moves
    n_epochs = 30          # few epochs: the data consists of optimal play only
    batch_size = 64        # chosen to suit the size of the dataset
    learning_rate = 1e-4
    d_model = 128
    n_head = 4
    n_layer = 3

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    started = time.time()
    trajectories = generate_all_optimal_trajectories()
    elapsed = time.time() - started
    print(f"Generated {len(trajectories)} unique optimal trajectories in {elapsed:.2f} seconds.")

    dataset = TicTacToeDataset(trajectories, context_length=context_length)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = DecisionTransformer(
        state_dim=9,
        action_dim=9,
        n_head=n_head,
        n_layer=n_layer,
        d_model=d_model,
        context_length=context_length,
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    loss_fn = nn.CrossEntropyLoss()

    print("Starting training...")
    for epoch_number in range(1, n_epochs + 1):
        model.train()
        running_loss = 0
        for batch in dataloader:
            states, actions, rtgs, mask = (tensor.to(device) for tensor in batch)
            timesteps = torch.arange(context_length, device=device).repeat(states.shape[0], 1)

            logits = model(states, actions, rtgs, timesteps).reshape(-1, 9)
            targets = actions.reshape(-1)
            # Score only real (unpadded) steps that carry an actual action.
            keep = mask.reshape(-1).bool() & (targets != -1)
            loss = loss_fn(logits[keep], targets[keep])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        mean_loss = running_loss / len(dataloader)
        print(f"Epoch {epoch_number}/{n_epochs}, Loss: {mean_loss:.4f}")

    print("Training finished.")
    return model

def play_game_with_model(model, context_length):
    """Play one interactive console game: the model is 'X', the human is 'O'.

    The model's move is the highest-scoring available action predicted for the
    current step, conditioned on a return-to-go of 1 (aiming for a win). The
    human's move is read from standard input.

    Args:
        model: Trained model called as ``model(states, actions, rtgs, timesteps)``.
        context_length: Number of steps held in the model's input buffers.
            The prediction for move ``turn`` is read at index ``turn``, so a
            full 9-move game needs ``context_length >= 9``.

    Returns:
        None. The board, the moves and the result are printed.
    """
    model.eval()

    # Run on whatever device the model's parameters live on.
    device = next(model.parameters()).device

    env = TicTacToe()

    # History buffers; every tensor is created directly on the model's device.
    states = torch.zeros((1, context_length, 9), dtype=torch.float32, device=device)
    actions = torch.full((1, context_length), -1, dtype=torch.long, device=device)
    rtgs = torch.zeros((1, context_length, 1), dtype=torch.float32, device=device)
    rtgs[0, :, 0] = 1.0  # target return: a win (RTG = 1)
    timesteps = torch.arange(0, context_length, device=device).unsqueeze(0)

    print("\n--- New Game: You are 'O' ---")
    turn = 0
    done = False

    while not done and turn < 9:
        # Render the board: marks for occupied cells, the cell index otherwise.
        marks = [
            'X' if cell == 1 else 'O' if cell == -1 else str(i)
            for i, cell in enumerate(env.get_state())
        ]
        rows = ["".join(f" {mark} " for mark in marks[start:start + 3]) for start in range(0, 9, 3)]
        print("\n---+---+---\n".join(rows))

        # Record the current position BEFORE asking the model for a move.
        # Otherwise the state slot for this step is still all zeros and the
        # model would choose its move without seeing the actual board.
        if turn < context_length:
            current_state_tensor = torch.tensor(env.get_state(), dtype=torch.float32).to(device)
            states[0, turn] = current_state_tensor

        if env.current_player == 1:
            print("Model's turn ('X')...")
            with torch.no_grad():
                pred_actions = model(states, actions, rtgs, timesteps)
            logits = pred_actions[0, turn, :]
            available_actions = env.get_available_actions()
            # Occupied cells get -inf so argmax can only pick a legal move.
            mask = torch.full_like(logits, -float('inf'))
            mask[available_actions] = 0
            move = (logits + mask).argmax().item()
            print(f"Model chooses action: {move}")
        else:
            try:
                move = int(input("Your turn ('O'). Enter move (0-8): "))
                if move not in env.get_available_actions():
                    print("Invalid move. Try again.")
                    continue
            except (ValueError, IndexError):
                print("Invalid input. Enter a number between 0 and 8.")
                continue

        # The action slot is filled only after the move is known, so the
        # prediction above never sees the action it is about to choose.
        if turn < context_length:
            actions[0, turn] = move

        _, reward, done = env.make_move(move)

        # Return-to-go update.
        # NOTE(review): this shifts the RTG sequence one step to the right and
        # then subtracts the reward from the entry of the step just played.
        # The usual Decision Transformer rule sets the NEXT entry to
        # "current RTG - reward"; confirm which one is intended.
        if turn + 1 < context_length:
            rtgs[0, turn+1:] = rtgs[0, turn:-1].clone()  # clone avoids overlapping in-place copy
            rtgs[0, turn, 0] -= reward

        turn += 1
        print("-" * 20)

    winner = env.check_winner()
    print("--- Game Over ---")
    if winner == 1:
        print("Model (X) wins!")
    elif winner == -1:
        print("You (O) win!")
    else:
        print("It's a draw!")


