In [3]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
from tqdm import tqdm
import math

# --- 1. TicTacToe 게임 환경 ---
class TicTacToe:
    def __init__(self):
        self.board = np.zeros((3, 3), dtype=int)
        self.current_player = 1  # 1: 'X', -1: 'O'

    def reset(self):
        self.board = np.zeros((3, 3), dtype=int)
        self.current_player = 1
        return self.get_state()

    def get_state(self):
        return self.board.flatten()

    def get_available_actions(self):
        return [i for i, val in enumerate(self.get_state()) if val == 0]

    def make_move(self, action):
        if self.get_state()[action] != 0:
            raise ValueError("Invalid move")
        
        row, col = action // 3, action % 3
        self.board[row, col] = self.current_player
        
        winner = self.check_winner()
        done = winner is not None
        reward = 0
        if done:
            if winner == 1: reward = 1
            elif winner == -1: reward = -1
            else: reward = 0 # Draw

        self.current_player *= -1
        return self.get_state(), reward, done

    def check_winner(self):
        # Rows and Columns
        for i in range(3):
            if abs(self.board[i, :].sum()) == 3: return self.board[i, 0]
            if abs(self.board[:, i].sum()) == 3: return self.board[0, i]
        
        # Diagonals
        if abs(np.diag(self.board).sum()) == 3: return self.board[0, 0]
        if abs(np.diag(np.fliplr(self.board)).sum()) == 3: return self.board[0, 2]
        
        # Draw
        if not np.any(self.board == 0):
            return 0
            
        return None # Game not over

# --- 2. 전문가 데이터 생성 (Minimax 알고리즘 사용) ---
def minimax(board, player):
    game = TicTacToe()
    game.board = board
    winner = game.check_winner()
    if winner is not None:
        return winner * player

    best_score = -math.inf
    for action in game.get_available_actions():
        new_board = board.copy()
        row, col = action // 3, action % 3
        new_board[row, col] = player
        score = -minimax(new_board, -player)
        if score > best_score:
            best_score = score
    return best_score if best_score != -math.inf else 0

def get_best_move(board, player):
    best_score = -math.inf
    best_move = -1
    game = TicTacToe()
    game.board = board
    
    available_actions = game.get_available_actions()
    if not available_actions:
        return -1

    for action in available_actions:
        new_board = board.copy()
        row, col = action // 3, action % 3
        new_board[row, col] = player
        score = -minimax(new_board, -player)
        if score > best_score:
            best_score = score
            best_move = action
    return best_move

def generate_trajectories(num_trajectories):
    print("Generating expert trajectories using Minimax...")
    trajectories = []
    env = TicTacToe()
    
    for _ in tqdm(range(num_trajectories)):
        states, actions, rewards = [], [], []
        state = env.reset()
        done = False
        
        while not done:
            current_player = env.current_player
            
            # Minimax 'X' (player 1) vs slightly random 'O' (player -1)
            if current_player == 1:
                move = get_best_move(env.board, current_player)
            else:
                # To add variety, player 'O' sometimes makes a random move
                if random.random() < 0.3:
                    move = random.choice(env.get_available_actions())
                else:
                    move = get_best_move(env.board, current_player)

            if move == -1: break # No more moves

            states.append(state)
            actions.append(move)
            
            state, reward, done = env.make_move(move)
            rewards.append(reward)
        
        # Calculate rewards-to-go
        rewards_to_go = np.zeros_like(rewards, dtype=float)
        running_rtg = 0
        for t in reversed(range(len(rewards))):
            # The final reward is the outcome of the game.
            # We assign it to the player who made the move at that step.
            if rewards[t] != 0: running_rtg = rewards[t]
            rewards_to_go[t] = running_rtg

        trajectories.append({
            'states': np.array(states),
            'actions': np.array(actions),
            'rewards_to_go': rewards_to_go
        })
    return trajectories

# --- 3. PyTorch Dataset & Decision Transformer 모델 ---
class TicTacToeDataset(Dataset):
    def __init__(self, trajectories, context_length):
        self.trajectories = trajectories
        self.context_length = context_length

    def __len__(self):
        return len(self.trajectories)

    def __getitem__(self, idx):
        traj = self.trajectories[idx]
        traj_len = len(traj['states'])
        
        # Pick a random starting point in the trajectory
        start_idx = random.randint(0, traj_len - 1)
        
        states = traj['states'][start_idx : start_idx + self.context_length]
        actions = traj['actions'][start_idx : start_idx + self.context_length]
        rtgs = traj['rewards_to_go'][start_idx : start_idx + self.context_length]

        # Padding
        T = len(states)
        padding_len = self.context_length - T
        
        states = torch.tensor(np.pad(states, ((0, padding_len), (0, 0)), 'constant'), dtype=torch.float32)
        actions = torch.tensor(np.pad(actions, (0, padding_len), 'constant', constant_values=-1), dtype=torch.long)
        rtgs = torch.tensor(np.pad(rtgs, (0, padding_len), 'constant'), dtype=torch.float32).unsqueeze(1)
        
        mask = torch.cat([torch.ones(T), torch.zeros(padding_len)], dim=0)
        
        return states, actions, rtgs, mask

class DecisionTransformer(nn.Module):
    def __init__(self, state_dim, action_dim, n_head, n_layer, d_model, context_length):
        super().__init__()
        self.d_model = d_model
        self.context_length = context_length
        
        self.embed_state = nn.Linear(state_dim, d_model)
        self.embed_action = nn.Embedding(action_dim + 1, d_model) # +1 for padding action
        self.embed_rtg = nn.Linear(1, d_model)
        self.embed_timestep = nn.Embedding(context_length, d_model)

        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_head, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layer)

        self.predict_action = nn.Linear(d_model, action_dim)

    def forward(self, states, actions, rtgs, timesteps):
        # `actions` need to be handled for padding (-1)
        # We shift actions to be non-negative for embedding lookup
        action_embeddings = self.embed_action(actions + 1)
        state_embeddings = self.embed_state(states)
        rtg_embeddings = self.embed_rtg(rtgs)
        time_embeddings = self.embed_timestep(timesteps)

        # Interleave sequence: (RTG_1, s_1, a_1, RTG_2, s_2, a_2, ...)
        # A simpler approach is to sum embeddings with time embeddings
        # Here we sum them, which works well for simpler tasks
        
        state_embeddings += time_embeddings
        action_embeddings += time_embeddings
        rtg_embeddings += time_embeddings

        # This creates the sequence of length 3*K
        stacked_inputs = torch.stack(
            (rtg_embeddings, state_embeddings, action_embeddings), dim=1
        ).permute(0, 2, 1, 3).reshape(states.shape[0], 3 * self.context_length, self.d_model)

        # Causal mask to ensure predictions at t only depend on inputs before t
        causal_mask = nn.Transformer.generate_square_subsequent_mask(3 * self.context_length).to(states.device)
        
        encoder_output = self.transformer_encoder(stacked_inputs, mask=causal_mask)
        
        # We only want to predict actions, which are at indices 1, 4, 7, ...
        # These correspond to state embeddings
        x = encoder_output[:, 1::3, :] # Get embeddings for state positions
        
        action_preds = self.predict_action(x)
        return action_preds

# --- 4. 학습 및 평가 ---
def train():
    # Hyperparameters
    CONTEXT_LENGTH = 5  # Max sequence length for the model
    N_EPOCHS = 50
    BATCH_SIZE = 128
    LR = 1e-4
    D_MODEL = 128
    N_HEAD = 4
    N_LAYER = 3

    # Setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    trajectories = generate_trajectories(num_trajectories=5000)
    dataset = TicTacToeDataset(trajectories, context_length=CONTEXT_LENGTH)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    model = DecisionTransformer(
        state_dim=9, 
        action_dim=9, 
        n_head=N_HEAD,
        n_layer=N_LAYER,
        d_model=D_MODEL,
        context_length=CONTEXT_LENGTH
    ).to(device)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    loss_fn = nn.CrossEntropyLoss()

    print("Starting training...")
    for epoch in range(N_EPOCHS):
        model.train()
        total_loss = 0
        for states, actions, rtgs, mask in dataloader:
            states, actions, rtgs, mask = states.to(device), actions.to(device), rtgs.to(device), mask.to(device)
            
            timesteps = torch.arange(CONTEXT_LENGTH, device=device).repeat(states.shape[0], 1)
            
            action_preds = model(states, actions, rtgs, timesteps)
            
            # We only calculate loss on the actions that were actually taken (not padded)
            # Reshape for loss function
            action_preds = action_preds.reshape(-1, 9)
            actions_target = actions.reshape(-1)
            mask = mask.reshape(-1).bool()
            
            # Filter out padded parts
            action_preds = action_preds[mask]
            actions_target = actions_target[mask]

            # We don't want to compute loss on the padding action value (-1)
            valid_targets = actions_target != -1
            loss = loss_fn(action_preds[valid_targets], actions_target[valid_targets])
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{N_EPOCHS}, Loss: {avg_loss:.4f}")
        
    print("Training finished.")
    return model

def play_game_with_model(model, context_length):
    model.eval()
    env = TicTacToe()
    
    states = torch.zeros((1, context_length, 9), dtype=torch.float32)
    actions = torch.full((1, context_length), -1, dtype=torch.long)
    
    # Target: Win the game! So, initial RTG is 1.0
    rtgs = torch.zeros((1, context_length, 1), dtype=torch.float32)
    rtgs[0, -1, 0] = 1.0
    
    timesteps = torch.arange(0, context_length).unsqueeze(0)

    print("\n--- New Game: You are 'O' ---")
    turn = 0
    done = False
    
    while not done:
        # Print board
        board_str = ""
        for i, cell in enumerate(env.get_state()):
            mark = 'X' if cell == 1 else 'O' if cell == -1 else str(i)
            board_str += f" {mark} "
            if (i+1) % 3 == 0:
                board_str += "\n" if i < 8 else ""
                if i < 8: board_str += "---+---+---\n"
        print(board_str)

        # Model's turn ('X')
        if env.current_player == 1:
            print("Model's turn ('X')...")
            with torch.no_grad():
                pred_actions = model(states, actions, rtgs, timesteps)
            
            # Get the prediction for the current step
            logits = pred_actions[0, turn, :]
            
            # Mask illegal moves
            available_actions = env.get_available_actions()
            mask = torch.full_like(logits, -float('inf'))
            mask[available_actions] = 0
            
            # Choose best legal move
            move = (logits + mask).argmax().item()
            
            print(f"Model chooses action: {move}")

        # Human's turn ('O')
        else:
            try:
                move = int(input("Your turn ('O'). Enter move (0-8): "))
                if move not in env.get_available_actions():
                    print("Invalid move. Try again.")
                    continue
            except ValueError:
                print("Invalid input. Enter a number between 0 and 8.")
                continue

        # Update sequence for the model
        if turn < context_length:
            states[0, turn] = torch.tensor(env.get_state(), dtype=torch.float32)
            actions[0, turn] = move
        
        # Make move in environment
        _, reward, done = env.make_move(move)

        # Update RTG for next prediction
        rtgs[0, turn, 0] = rtgs[0, turn, 0] - reward
        
        turn += 1
        print("-" * 20)

    winner = env.check_winner()
    print("--- Game Over ---")
    if winner == 1: print("Model (X) wins!")
    elif winner == -1: print("You (O) win!")
    else: print("It's a draw!")


if __name__ == '__main__':
    trained_model = train()
    
    # Save the model if you want
    # torch.save(trained_model.state_dict(), "tictactoe_dt_model.pth")
    
    # Load the model
    # model = DecisionTransformer(...)
    # model.load_state_dict(torch.load("tictactoe_dt_model.pth"))
    
    while True:
        play_game_with_model(trained_model, CONTEXT_LENGTH=5)
        if input("Play again? (y/n): ").lower() != 'y':
            break

Using device: cuda
Generating expert trajectories using Minimax...


  6%|▌         | 284/5000 [30:05<8:19:36,  6.36s/it]


KeyboardInterrupt: 

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
from tqdm import tqdm
import math
import time

# --- 1. TicTacToe 게임 환경 (이전과 동일) ---
class TicTacToe:
    def __init__(self):
        self.board = np.zeros((3, 3), dtype=int)
        self.current_player = 1

    def reset(self):
        self.board = np.zeros((3, 3), dtype=int)
        self.current_player = 1
        return self.get_state()

    def get_state(self):
        return self.board.flatten()

    def get_available_actions(self):
        return [i for i, val in enumerate(self.get_state()) if val == 0]

    def make_move(self, action):
        if self.get_state()[action] != 0:
            raise ValueError("Invalid move")
        row, col = action // 3, action % 3
        self.board[row, col] = self.current_player
        winner = self.check_winner()
        done = winner is not None
        reward = 0
        if done:
            if winner == 1: reward = 1
            elif winner == -1: reward = -1
            else: reward = 0
        self.current_player *= -1
        return self.get_state(), reward, done

    def check_winner(self):
        for i in range(3):
            if abs(self.board[i, :].sum()) == 3: return self.board[i, 0]
            if abs(self.board[:, i].sum()) == 3: return self.board[0, i]
        if abs(np.diag(self.board).sum()) == 3: return self.board[0, 0]
        if abs(np.diag(np.fliplr(self.board)).sum()) == 3: return self.board[0, 2]
        if not np.any(self.board == 0): return 0
        return None

# --- 2. 최적화된 데이터 생성 ---

# Minimax 알고리즘 (이전과 동일)
def minimax(board, player):
    game = TicTacToe()
    game.board = board
    winner = game.check_winner()
    if winner is not None:
        return winner * player
    best_score = -math.inf
    for action in game.get_available_actions():
        new_board = board.copy()
        row, col = action // 3, action % 3
        new_board[row, col] = player
        score = -minimax(new_board, -player)
        if score > best_score: best_score = score
    return best_score if best_score != -math.inf else 0

def get_canonical_form(board):
    """보드의 8가지 대칭(회전, 대칭) 중 정규형(canonical form)을 찾습니다."""
    symmetries = []
    current_board = board.copy()
    for _ in range(4): # 4 rotations
        symmetries.append(current_board)
        symmetries.append(np.fliplr(current_board))
        current_board = np.rot90(current_board)

    # 튜플로 변환하여 정렬 가능하게 만듦
    symmetries_as_tuples = [tuple(b.flatten()) for b in symmetries]
    canonical_tuple = min(symmetries_as_tuples)
    return np.array(canonical_tuple).reshape(3, 3)

memoized_policy = {}
def build_optimal_policy_map(board, player):
    """모든 고유 상태에 대한 최적의 수를 미리 계산하여 맵에 저장 (재귀)"""
    canonical_board = get_canonical_form(board)
    canonical_tuple = tuple(canonical_board.flatten())

    if canonical_tuple in memoized_policy or TicTacToe()._check_winner_on_board(board) is not None:
        return

    game = TicTacToe()
    game.board = board
    available_actions = game.get_available_actions()
    
    best_score = -math.inf
    best_moves = []
    
    for action in available_actions:
        new_board = board.copy()
        row, col = action // 3, action % 3
        new_board[row, col] = player
        score = -minimax(new_board, -player)
        
        if score > best_score:
            best_score = score
            best_moves = [action]
        elif score == best_score:
            best_moves.append(action)

    memoized_policy[tuple(board.flatten())] = best_moves
    
    for move in best_moves:
        next_board = board.copy()
        row, col = move // 3, move % 3
        next_board[row, col] = player
        build_optimal_policy_map(next_board, -player)

# TicTacToe 클래스 내부에 헬퍼 함수 추가
TicTacToe._check_winner_on_board = lambda self, board: TicTacToe.check_winner(type('obj', (object,), {'board': board})())


def generate_all_optimal_trajectories():
    """사전 계산된 정책을 사용하여 모든 최적의 경로를 생성합니다."""
    print("Pre-calculating optimal policy for all unique states...")
    start_board = np.zeros((3, 3), dtype=int)
    build_optimal_policy_map(start_board, 1)
    
    print("Generating all unique optimal trajectories...")
    all_trajectories = []
    
    def find_paths(board, player, path):
        winner = TicTacToe()._check_winner_on_board(board)
        if winner is not None:
            states, actions = zip(*path) if path else ([], [])
            rewards = np.zeros(len(actions))
            
            # 승패에 따라 Reward-to-go 계산
            rewards_to_go = np.zeros_like(rewards, dtype=float)
            final_reward = winner # 1 for win, -1 for loss, 0 for draw
            for t in range(len(rewards)):
                rewards_to_go[t] = final_reward
                
            all_trajectories.append({
                'states': np.array(states),
                'actions': np.array(actions),
                'rewards_to_go': rewards_to_go
            })
            return

        optimal_moves = memoized_policy.get(tuple(board.flatten()))
        if not optimal_moves: return

        for move in optimal_moves:
            next_board = board.copy()
            row, col = move // 3, move % 3
            next_board[row, col] = player
            find_paths(next_board, -player, path + [(board.flatten(), move)])

    find_paths(start_board, 1, [])
    return all_trajectories


# --- 3. PyTorch Dataset & Decision Transformer 모델 (이전과 동일) ---
class TicTacToeDataset(Dataset):
    def __init__(self, trajectories, context_length):
        self.trajectories = trajectories
        self.context_length = context_length

    def __len__(self):
        return len(self.trajectories)

    def __getitem__(self, idx):
        traj = self.trajectories[idx]
        traj_len = len(traj['states'])
        start_idx = random.randint(0, traj_len - 1)
        states = traj['states'][start_idx : start_idx + self.context_length]
        actions = traj['actions'][start_idx : start_idx + self.context_length]
        rtgs = traj['rewards_to_go'][start_idx : start_idx + self.context_length]
        T = len(states)
        padding_len = self.context_length - T
        states = torch.tensor(np.pad(states, ((0, padding_len), (0, 0)), 'constant'), dtype=torch.float32)
        actions = torch.tensor(np.pad(actions, (0, padding_len), 'constant', constant_values=-1), dtype=torch.long)
        rtgs = torch.tensor(np.pad(rtgs, (0, padding_len), 'constant'), dtype=torch.float32).unsqueeze(1)
        mask = torch.cat([torch.ones(T), torch.zeros(padding_len)], dim=0)
        return states, actions, rtgs, mask

class DecisionTransformer(nn.Module):
    def __init__(self, state_dim, action_dim, n_head, n_layer, d_model, context_length):
        super().__init__()
        self.d_model = d_model
        self.context_length = context_length
        self.embed_state = nn.Linear(state_dim, d_model)
        self.embed_action = nn.Embedding(action_dim + 1, d_model)
        self.embed_rtg = nn.Linear(1, d_model)
        self.embed_timestep = nn.Embedding(context_length, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_head, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layer)
        self.predict_action = nn.Linear(d_model, action_dim)

    def forward(self, states, actions, rtgs, timesteps):
        action_embeddings = self.embed_action(actions + 1)
        state_embeddings = self.embed_state(states)
        rtg_embeddings = self.embed_rtg(rtgs)
        time_embeddings = self.embed_timestep(timesteps)
        state_embeddings += time_embeddings
        action_embeddings += time_embeddings
        rtg_embeddings += time_embeddings
        stacked_inputs = torch.stack(
            (rtg_embeddings, state_embeddings, action_embeddings), dim=1
        ).permute(0, 2, 1, 3).reshape(states.shape[0], 3 * self.context_length, self.d_model)
        causal_mask = nn.Transformer.generate_square_subsequent_mask(3 * self.context_length).to(states.device)
        encoder_output = self.transformer_encoder(stacked_inputs, mask=causal_mask)
        x = encoder_output[:, 1::3, :]
        action_preds = self.predict_action(x)
        return action_preds

# --- 4. 학습 및 평가 (하이퍼파라미터 조정) ---
def train():
    # Hyperparameters (조정됨)
    CONTEXT_LENGTH = 9  # 최대 9수면 게임이 끝나므로
    N_EPOCHS = 30       # 데이터가 고품질이므로 에포크 감소
    BATCH_SIZE = 64     # 데이터셋 크기에 맞춰 배치 사이즈 조정
    LR = 1e-4
    D_MODEL = 128
    N_HEAD = 4
    N_LAYER = 3

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    start_time = time.time()
    trajectories = generate_all_optimal_trajectories()
    end_time = time.time()
    print(f"Generated {len(trajectories)} unique optimal trajectories in {end_time - start_time:.2f} seconds.")
    
    dataset = TicTacToeDataset(trajectories, context_length=CONTEXT_LENGTH)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    model = DecisionTransformer(
        state_dim=9, action_dim=9, n_head=N_HEAD, n_layer=N_LAYER,
        d_model=D_MODEL, context_length=CONTEXT_LENGTH
    ).to(device)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    loss_fn = nn.CrossEntropyLoss()

    print("Starting training...")
    for epoch in range(N_EPOCHS):
        model.train()
        total_loss = 0
        for states, actions, rtgs, mask in dataloader:
            states, actions, rtgs, mask = states.to(device), actions.to(device), rtgs.to(device), mask.to(device)
            timesteps = torch.arange(CONTEXT_LENGTH, device=device).repeat(states.shape[0], 1)
            action_preds = model(states, actions, rtgs, timesteps)
            action_preds = action_preds.reshape(-1, 9)
            actions_target = actions.reshape(-1)
            mask = mask.reshape(-1).bool()
            action_preds = action_preds[mask]
            actions_target = actions_target[mask]
            valid_targets = actions_target != -1
            loss = loss_fn(action_preds[valid_targets], actions_target[valid_targets])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{N_EPOCHS}, Loss: {avg_loss:.4f}")
        
    print("Training finished.")
    return model

def play_game_with_model(model, context_length):
    model.eval()
    
    # 모델이 현재 사용 중인 device를 가져옵니다.
    device = next(model.parameters()).device
    
    env = TicTacToe()
    
    # 모든 텐서를 생성할 때 .to(device)를 붙여줍니다.
    states = torch.zeros((1, context_length, 9), dtype=torch.float32, device=device)
    actions = torch.full((1, context_length), -1, dtype=torch.long, device=device)
    rtgs = torch.zeros((1, context_length, 1), dtype=torch.float32, device=device)
    rtgs[0, :, 0] = 1.0 # 목표는 승리(RTG=1)
    timesteps = torch.arange(0, context_length, device=device).unsqueeze(0)

    print("\n--- New Game: You are 'O' ---")
    turn = 0
    done = False
    
    while not done and turn < 9:
        board_str = ""
        for i, cell in enumerate(env.get_state()):
            mark = 'X' if cell == 1 else 'O' if cell == -1 else str(i)
            board_str += f" {mark} "
            if (i+1) % 3 == 0:
                board_str += "\n" if i < 8 else ""
                if i < 8: board_str += "---+---+---\n"
        print(board_str)

        if env.current_player == 1:
            print("Model's turn ('X')...")
            with torch.no_grad():
                # 이제 모델과 입력 텐서가 모두 같은 device에 있습니다.
                pred_actions = model(states, actions, rtgs, timesteps)
            logits = pred_actions[0, turn, :]
            available_actions = env.get_available_actions()
            mask = torch.full_like(logits, -float('inf'))
            mask[available_actions] = 0
            move = (logits + mask).argmax().item()
            print(f"Model chooses action: {move}")
        else:
            try:
                move = int(input("Your turn ('O'). Enter move (0-8): "))
                if move not in env.get_available_actions():
                    print("Invalid move. Try again.")
                    continue
            except (ValueError, IndexError):
                print("Invalid input. Enter a number between 0 and 8.")
                continue

        # 상태 업데이트 시에도 .to(device)가 필요합니다.
        if turn < context_length:
            current_state_tensor = torch.tensor(env.get_state(), dtype=torch.float32).to(device)
            states[0, turn] = current_state_tensor
            actions[0, turn] = move
        
        _, reward, done = env.make_move(move)
        
        # RTG 업데이트
        if turn + 1 < context_length:
            rtgs[0, turn+1:] = rtgs[0, turn:-1].clone() # clone()을 사용하여 인플레이스 수정 방지
            rtgs[0, turn, 0] -= reward
        
        turn += 1
        print("-" * 20)

    winner = env.check_winner()
    print("--- Game Over ---")
    if winner == 1: print("Model (X) wins!")
    elif winner == -1: print("You (O) win!")
    else: print("It's a draw!")




Using device: cuda
Pre-calculating optimal policy for all unique states...
Generating all unique optimal trajectories...
Generated 1023 unique optimal trajectories in 14.43 seconds.
Starting training...
Epoch 1/30, Loss: 2.1757
Epoch 2/30, Loss: 2.0764
Epoch 3/30, Loss: 1.9939
Epoch 4/30, Loss: 1.8944
Epoch 5/30, Loss: 1.7810
Epoch 6/30, Loss: 1.6728
Epoch 7/30, Loss: 1.5787
Epoch 8/30, Loss: 1.5085
Epoch 9/30, Loss: 1.4357
Epoch 10/30, Loss: 1.3777
Epoch 11/30, Loss: 1.3128
Epoch 12/30, Loss: 1.2689
Epoch 13/30, Loss: 1.2266
Epoch 14/30, Loss: 1.1834
Epoch 15/30, Loss: 1.1446
Epoch 16/30, Loss: 1.1137
Epoch 17/30, Loss: 1.0857
Epoch 18/30, Loss: 1.0471
Epoch 19/30, Loss: 1.0220
Epoch 20/30, Loss: 1.0131
Epoch 21/30, Loss: 0.9743
Epoch 22/30, Loss: 0.9599
Epoch 23/30, Loss: 0.9460
Epoch 24/30, Loss: 0.9195
Epoch 25/30, Loss: 0.9271
Epoch 26/30, Loss: 0.8983
Epoch 27/30, Loss: 0.8835
Epoch 28/30, Loss: 0.8796
Epoch 29/30, Loss: 0.8797
Epoch 30/30, Loss: 0.8659
Training finished.

--- Ne

In [None]:
if __name__ == '__main__':
    CONTEXT_LENGTH = 9
    trained_model = train()
    while True:
        play_game_with_model(trained_model, CONTEXT_LENGTH)
        if input("Play again? (y/n): ").lower() != 'y':
            break

In [6]:
trained_model

DecisionTransformer(
  (embed_state): Linear(in_features=9, out_features=128, bias=True)
  (embed_action): Embedding(10, 128)
  (embed_rtg): Linear(in_features=1, out_features=128, bias=True)
  (embed_timestep): Embedding(9, 128)
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-2): 3 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (predict_action): Linear(in_features=128, out_fea

In [7]:
import torch.nn.functional as F

def analyze_model_decision(model, board_state, context_length=9):
    """
    주어진 보드 상태에서 모델이 각 행동에 대해 어떻게 평가하는지 분석합니다.
    """
    model.eval()
    device = next(model.parameters()).device
    
    # 모델 입력을 위한 빈 시퀀스 준비
    states = torch.zeros((1, context_length, 9), dtype=torch.float32, device=device)
    actions = torch.full((1, context_length), -1, dtype=torch.long, device=device)
    rtgs = torch.zeros((1, context_length, 1), dtype=torch.float32, device=device)
    rtgs[0, :, 0] = 1.0  # 항상 '승리'를 목표로 결정 분석
    timesteps = torch.arange(0, context_length, device=device).unsqueeze(0)

    # 현재 상태를 시퀀스의 첫 번째(turn=0)에 배치
    states[0, 0] = torch.tensor(board_state, dtype=torch.float32).to(device)
    
    with torch.no_grad():
        pred_actions_logits = model(states, actions, rtgs, timesteps)
        
    # 첫 번째 턴에 대한 결정(logits)을 가져옴
    logits = pred_actions_logits[0, 0, :]
    
    # Softmax를 이용해 확률로 변환
    probabilities = F.softmax(logits, dim=0)
    
    print("===== Model Decision Analysis =====")
    print("Board State (1: Model, -1: Opponent):", board_state)
    print("\nAction Probabilities:")
    
    # 가능한 수와 불가능한 수 구분하여 출력
    available_actions = [i for i, val in enumerate(board_state) if val == 0]
    for i, prob in enumerate(probabilities):
        move_type = "Legal" if i in available_actions else "Illegal"
        print(f"  Move {i}: {prob.item():.4f} ({prob.item()*100:.2f}%) [{move_type}]")
        
    best_move = probabilities.argmax().item()
    print(f"\n=> Model's Best Choice: Move {best_move} ({probabilities[best_move].item()*100:.2f}%)")
    print("===================================\n")
    return probabilities

# trained_model이 학습 완료된 모델이라고 가정
# 예제 분석 실행
# 1. 빈 보드 (첫 수 분석)
empty_board = [0, 0, 0, 0, 0, 0, 0, 0, 0]
analyze_model_decision(trained_model, empty_board)

# 2. 상대방(O)이 2개를 놓아 반드시 막아야 하는 상황
must_block_board = [ -1, -1, 0,  # O, O, _
                      0, 1, 0, 
                      0, 0, 0]
analyze_model_decision(trained_model, must_block_board)

# 3. 내가(X) 2개를 놓아 바로 이길 수 있는 상황
can_win_board = [ 1, 1, 0,  # X, X, _
                 -1, 0, 0, 
                 -1, 0, 0]
analyze_model_decision(trained_model, can_win_board)

===== Model Decision Analysis =====
Board State (1: Model, -1: Opponent): [0, 0, 0, 0, 0, 0, 0, 0, 0]

Action Probabilities:
  Move 0: 0.0799 (7.99%) [Legal]
  Move 1: 0.6264 (62.64%) [Legal]
  Move 2: 0.0046 (0.46%) [Legal]
  Move 3: 0.1470 (14.70%) [Legal]
  Move 4: 0.0363 (3.63%) [Legal]
  Move 5: 0.0314 (3.14%) [Legal]
  Move 6: 0.0161 (1.61%) [Legal]
  Move 7: 0.0049 (0.49%) [Legal]
  Move 8: 0.0533 (5.33%) [Legal]

=> Model's Best Choice: Move 1 (62.64%)

===== Model Decision Analysis =====
Board State (1: Model, -1: Opponent): [-1, -1, 0, 0, 1, 0, 0, 0, 0]

Action Probabilities:
  Move 0: 0.0204 (2.04%) [Illegal]
  Move 1: 0.4443 (44.43%) [Illegal]
  Move 2: 0.0428 (4.28%) [Legal]
  Move 3: 0.1606 (16.06%) [Legal]
  Move 4: 0.0007 (0.07%) [Illegal]
  Move 5: 0.0475 (4.75%) [Legal]
  Move 6: 0.0505 (5.05%) [Legal]
  Move 7: 0.0026 (0.26%) [Legal]
  Move 8: 0.2307 (23.07%) [Legal]

=> Model's Best Choice: Move 1 (44.43%)

===== Model Decision Analysis =====
Board State (1: Model, 

tensor([0.0175, 0.1842, 0.0290, 0.0449, 0.0891, 0.1057, 0.1907, 0.2156, 0.1234],
       device='cuda:0')