In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import random
from collections import deque
import time

# Seed every RNG this script draws from (Python stdlib, NumPy, TensorFlow)
# so data collection, weight init and dropout are reproducible across runs.
random.seed(42)
np.random.seed(42)
tf.random.set_seed(42)

class SnakeGame:
    """Snake game environment for Decision Transformer.

    The board is ``width`` x ``height`` cells, addressed as ``(x, y)``; action
    0 (up) decrements ``y``. ``self.snake`` is a list of cells, head first.
    """

    def __init__(self, width=12, height=12):
        self.width = width
        self.height = height
        self.reset()

    def reset(self):
        """Reset the game to its initial state and return the first observation.

        The snake starts as three horizontal cells centred on the board,
        heading right.
        """
        center_x, center_y = self.width // 2, self.height // 2
        self.snake = [(center_x, center_y), (center_x - 1, center_y), (center_x - 2, center_y)]
        self.direction = 1  # 0=up, 1=right, 2=down, 3=left
        self.food = self._place_food()
        self.score = 0
        self.done = False
        self.steps_without_food = 0
        self.max_steps_without_food = 100
        return self.get_state()

    def _place_food(self):
        """Return a random cell not occupied by the snake (rejection sampling).

        Callers must ensure at least one free cell exists, otherwise this
        never terminates; ``step`` ends the game instead of calling it when
        the board is full.
        """
        while True:
            food = (random.randint(0, self.width - 1), random.randint(0, self.height - 1))
            if food not in self.snake:
                return food

    def get_state(self):
        """Return the observation as a flat float vector.

        Layout: a (height, width, 3) grid flattened row-major (channel 0 =
        snake body, 1 = food, 2 = head), followed by 5 scalar features:
        direction / 3, food dx / width, food dy / height,
        snake length / board area, steps_without_food / max_steps_without_food.
        Total length is ``height * width * 3 + 5``.
        """
        state = np.zeros((self.height, self.width, 3))  # 3 channels: snake, food, head

        # Channel 0: Snake Body
        for x, y in self.snake:
            if 0 <= x < self.width and 0 <= y < self.height:
                state[y, x, 0] = 1.0

        # Channel 1: Food
        if 0 <= self.food[0] < self.width and 0 <= self.food[1] < self.height:
            state[self.food[1], self.food[0], 1] = 1.0

        # Channel 2: Snake head
        head_x, head_y = self.snake[0]
        if 0 <= head_x < self.width and 0 <= head_y < self.height:
            state[head_y, head_x, 2] = 1.0

        flat_state = state.flatten()

        # Additional normalized features
        food_x, food_y = self.food
        additional_features = [
            self.direction / 3.0,
            (food_x - head_x) / self.width,
            (food_y - head_y) / self.height,
            len(self.snake) / (self.width * self.height),
            self.steps_without_food / self.max_steps_without_food
        ]
        return np.concatenate([flat_state, additional_features])

    def step(self, action):
        """
        Execute one game step.

        Actions: 0 = up, 1 = right, 2 = down, 3 = left. An action that would
        reverse the snake directly onto itself is ignored (heading is kept).

        Returns: (next_state, reward, done). Rewards: -10 for hitting a wall
        or the body (episode ends), +10 for eating, -0.1 for an ordinary move,
        -5 (replacing the move reward) on starvation, plus a proximity bonus
        of 0.1 / (d + 1) where d is the Manhattan distance from the NEW head
        to the food. Calling step after the game is over returns reward 0.
        """
        if self.done:
            return self.get_state(), 0, True

        # Prevent direct reversal
        if action != (self.direction + 2) % 4:
            self.direction = action

        # Calculate new head position
        head_x, head_y = self.snake[0]
        if self.direction == 0:  # up
            new_head = (head_x, head_y - 1)
        elif self.direction == 1:  # right
            new_head = (head_x + 1, head_y)
        elif self.direction == 2:  # down
            new_head = (head_x, head_y + 1)
        else:  # left
            new_head = (head_x - 1, head_y)

        # Check wall collision
        if not (0 <= new_head[0] < self.width and 0 <= new_head[1] < self.height):
            self.done = True
            return self.get_state(), -10, self.done

        # Check self collision. Unless the snake grows this tick, its tail cell
        # is vacated as the head advances, so moving into it is legal.
        grows = new_head == self.food
        body = self.snake if grows else self.snake[:-1]
        if new_head in body:
            self.done = True
            return self.get_state(), -10, self.done

        # Move snake
        self.snake.insert(0, new_head)

        if grows:
            self.score += 1
            reward = 10
            self.steps_without_food = 0
            if len(self.snake) >= self.width * self.height:
                # No free cell left for new food: the board is full and the
                # game is won. Avoids _place_food looping forever.
                self.done = True
            else:
                self.food = self._place_food()
        else:
            self.snake.pop()
            reward = -0.1
            self.steps_without_food += 1

        # Starvation condition
        if self.steps_without_food >= self.max_steps_without_food:
            self.done = True
            reward = -5

        # Proximity bonus, measured from the head's position AFTER the move
        # (to the respawned food if it was just eaten).
        food_x, food_y = self.food
        new_x, new_y = new_head
        distance_to_food = abs(new_x - food_x) + abs(new_y - food_y)
        reward += 0.1 / (distance_to_food + 1)

        return self.get_state(), reward, self.done

    def render_console(self):
        """Print the score and an ASCII board (H = head, S = body, F = food)."""
        grid = [['.' for _ in range(self.width)] for _ in range(self.height)]

        # Place snake
        for i, (x, y) in enumerate(self.snake):
            if 0 <= x < self.width and 0 <= y < self.height:
                grid[y][x] = 'H' if i == 0 else 'S'

        # Place food
        food_x, food_y = self.food
        if 0 <= food_x < self.width and 0 <= food_y < self.height:
            grid[food_y][food_x] = 'F'

        # Print grid
        print(f"Score: {self.score}, Steps without food: {self.steps_without_food}")
        for row in grid:
            print(' '.join(row))
        print()


# ======================= Decision Transformer Model =======================
class DecisionTransformer(keras.Model):
    """Decision Transformer over interleaved (return-to-go, state, action) tokens.

    ``call`` takes a dict with keys ``'states'``, ``'actions'`` and
    ``'return_to_go'``. It embeds every timestep as three tokens in the order
    R, s, a, adds learned positional embeddings, runs ``num_layers`` post-norm
    transformer blocks with causal self-attention, and returns logits read
    from each timestep's action token. The output has shape
    (batch, seq_len, action_dim).

    Assumes states are (batch, seq_len, state_dim) and actions/return_to_go
    are (batch, seq_len) with seq_len <= max_length — TODO confirm at callers.
    What the logits are trained to predict is fixed by the targets passed to
    ``fit``, not by this class.
    """

    def __init__(self, state_dim, action_dim=4, max_length=50, embed_dim=256, num_heads=8, num_layers=4):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_length = max_length
        self.embed_dim = embed_dim

        # Embedding layers: each modality is projected to embed_dim.
        self.state_embed = keras.Sequential([
            layers.Dense(embed_dim, activation='relu'),
            layers.LayerNormalization(),
            layers.Dense(embed_dim)
        ])

        self.action_embed = layers.Embedding(action_dim, embed_dim)

        self.return_embed = keras.Sequential([
            layers.Dense(embed_dim, activation='relu'),
            layers.Dense(embed_dim)
        ])

        # One learned position per token: 3 tokens (R, s, a) per timestep.
        self.pos_embed = layers.Embedding(3 * max_length, embed_dim)

        # Transformer blocks (post-norm: residual add, then LayerNormalization)
        self.transformer_blocks = []
        for _ in range(num_layers):
            self.transformer_blocks.append({
                'attention': layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim // num_heads),
                'dropout1': layers.Dropout(0.1),
                'norm1': layers.LayerNormalization(),
                'ff1': layers.Dense(embed_dim * 2, activation='gelu'),
                'ff2': layers.Dense(embed_dim),
                'dropout2': layers.Dropout(0.1),
                'norm2': layers.LayerNormalization()
            })

        # Output head
        self.action_head = keras.Sequential([
            layers.Dense(embed_dim, activation='relu'),
            layers.Dropout(0.1),
            layers.Dense(action_dim)
        ])

    def call(self, inputs, training=False):
        states = inputs['states']
        actions = inputs['actions']
        return_to_go = inputs['return_to_go']

        batch_size = tf.shape(states)[0]
        seq_len = tf.shape(states)[1]

        # Embeddings, each (batch, seq_len, embed_dim)
        state_embeddings = self.state_embed(states)
        # Embedding lookups need integer indices.
        actions_int = tf.cast(actions, tf.int32)
        action_embeddings = self.action_embed(actions_int)
        return_embeddings = self.return_embed(tf.expand_dims(return_to_go, -1))

        # Interleave per-timestep tokens. Stacking on axis=2 gives
        # (batch, seq_len, 3, embed_dim), so the reshape yields the order
        # R_0, s_0, a_0, R_1, s_1, a_1, ... Stacking on the last axis would
        # instead mix embedding dimensions across the three tokens.
        sequence = tf.stack([return_embeddings, state_embeddings, action_embeddings], axis=2)
        sequence = tf.reshape(sequence, (batch_size, seq_len * 3, self.embed_dim))

        # Add positional embeddings
        positions = tf.range(seq_len * 3)
        pos_embeddings = self.pos_embed(positions)
        sequence += pos_embeddings

        # Apply transformer blocks
        x = sequence
        for block in self.transformer_blocks:
            # Causal self-attention: each token attends only to itself and
            # earlier tokens, so a prediction cannot peek at later timesteps
            # (which hold the actions being predicted).
            attn_output = block['attention'](x, x, use_causal_mask=True, training=training)
            attn_output = block['dropout1'](attn_output, training=training)
            x = block['norm1'](x + attn_output)

            # Feed forward
            ff_output = block['ff1'](x)
            ff_output = block['ff2'](ff_output)
            ff_output = block['dropout2'](ff_output, training=training)
            x = block['norm2'](x + ff_output)

        # Extract action tokens (every 3rd token starting from index 2)
        action_tokens = x[:, 2::3, :]
        action_logits = self.action_head(action_tokens, training=training)
        return action_logits


# ======================= Data Collection =======================
def collect_random_data(env, num_episodes=500, max_steps=200):
    """Roll out a noisy food-seeking heuristic and record each episode.

    On every step, with probability 0.7 the policy steps along the larger of
    the x/y gaps toward the food (switching to a random sideways move if that
    step would leave the board); otherwise it picks an action uniformly at
    random.

    Returns:
        A list of dicts holding numpy arrays 'states', 'actions', 'rewards'
        and 'returns' (undiscounted returns-to-go: the sum of the rewards from
        that step to the end of the episode). Episodes that recorded fewer
        than two steps are dropped.
    """
    collected = []
    print("Collecting training data...")

    for ep in range(num_episodes):
        if not ep % 100:
            print(f"Episode {ep}/{num_episodes}")

        env.reset()
        ep_states, ep_actions, ep_rewards = [], [], []

        for _ in range(max_steps):
            obs = env.get_state()
            hx, hy = env.snake[0]
            fx, fy = env.food

            if random.random() >= 0.7:
                action = random.randint(0, 3)
            else:
                gap_x, gap_y = fx - hx, fy - hy
                if abs(gap_x) > abs(gap_y):
                    action = 1 if gap_x > 0 else 3  # right or left
                else:
                    action = 2 if gap_y > 0 else 0  # down or up

                hits_wall = (
                    (action == 0 and hy == 0)
                    or (action == 1 and hx == env.width - 1)
                    or (action == 2 and hy == env.height - 1)
                    or (action == 3 and hx == 0)
                )
                if hits_wall:
                    # Vertical moves fall back to horizontal ones and vice versa.
                    action = random.choice([1, 3] if action in (0, 2) else [0, 2])

            _, reward, done = env.step(action)
            ep_states.append(obs)
            ep_actions.append(action)
            ep_rewards.append(reward)

            if done:
                break

        if len(ep_states) >= 2:
            # Suffix sums, accumulated from the last reward backwards.
            rtg = np.cumsum(ep_rewards[::-1])[::-1].copy()
            collected.append({
                'states': np.array(ep_states),
                'actions': np.array(ep_actions),
                'rewards': np.array(ep_rewards),
                'returns': rtg,
            })

    print(f"Collected {len(collected)} trajectories")
    return collected


# ======================= Preprocess Training Data =======================
def create_training_data(trajectories, max_length=20):
    """Build fixed-length training windows for the Decision Transformer.

    One window is produced for every timestep ``t`` of every trajectory. It
    covers steps ``max(0, t - max_length + 1) .. t`` and is left-padded with
    zeros to ``max_length``. This mirrors the right-aligned context that
    ``play_with_dt`` builds at step ``t``, including the short, padded windows
    at the start of an episode; previously only full windows were kept, so
    episodes shorter than ``max_length + 1`` steps were dropped entirely.

    Alignment: each state is paired with the action taken at the PREVIOUS
    step (0 as a placeholder before the first step), and the target is the
    action actually taken at that step. With causal attention, the action
    token at step t then sees s_t and a_{t-1} but not the label a_t. This is
    the same pairing ``play_with_dt`` uses (``actions`` starts as ``[0]``).

    Padded positions have zero state, action 0, return-to-go 0 and target 0.
    No sample weights are returned, so padded targets contribute to the loss.

    Args:
        trajectories: list of dicts with arrays 'states' (T, state_dim),
            'actions' (T,) and 'returns' (T,).
        max_length: window length.

    Returns:
        (states, actions, returns_to_go, targets) with shapes
        (N, max_length, state_dim) float32, (N, max_length) int64,
        (N, max_length) float32, (N, max_length) int64, where N is the total
        number of timesteps across all trajectories.
    """
    total = sum(len(traj['states']) for traj in trajectories)
    state_dim = trajectories[0]['states'].shape[1] if trajectories else 0

    # Preallocate: one window per timestep can be large, so use float32.
    states_out = np.zeros((total, max_length, state_dim), dtype=np.float32)
    actions_out = np.zeros((total, max_length), dtype=np.int64)
    rtg_out = np.zeros((total, max_length), dtype=np.float32)
    targets_out = np.zeros((total, max_length), dtype=np.int64)

    row = 0
    for traj in trajectories:
        states = traj['states']
        actions = np.asarray(traj['actions'], dtype=np.int64)
        returns = traj['returns']
        # Action fed alongside each state = the one taken at the previous step.
        prev_actions = np.concatenate([np.zeros(1, dtype=np.int64), actions[:-1]])

        for end in range(len(states)):
            start = max(0, end - max_length + 1)
            pad = max_length - (end + 1 - start)  # left padding
            states_out[row, pad:] = states[start:end + 1]
            actions_out[row, pad:] = prev_actions[start:end + 1]
            rtg_out[row, pad:] = returns[start:end + 1]
            targets_out[row, pad:] = actions[start:end + 1]
            row += 1

    return states_out, actions_out, rtg_out, targets_out


# ======================= Train the Model =======================
def train_decision_transformer():
    """Collect heuristic rollouts, window them, and fit a Decision Transformer.

    Returns:
        (model, env, history): the trained model, the 10x10 SnakeGame used for
        data collection, and the Keras History object returned by ``fit``.
    """
    context_len = 20  # window length shared by the data and the model
    env = SnakeGame(width=10, height=10)
    trajectories = collect_random_data(env, num_episodes=1000, max_steps=150)
    states, actions, returns_to_go, targets = create_training_data(trajectories, max_length=context_len)

    print(
        f"Training data shape - States: {states.shape}, Actions: {actions.shape}, "
        f"RTG: {returns_to_go.shape}, Targets: {targets.shape}"
    )

    model = DecisionTransformer(
        state_dim=states.shape[-1],
        action_dim=4,
        max_length=context_len,
        embed_dim=128,
        num_heads=4,
        num_layers=4,
    )
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=1e-4),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )

    model_inputs = {"states": states, "actions": actions, "return_to_go": returns_to_go}
    print("Training Decision Transformer...")
    history = model.fit(
        model_inputs,
        targets,
        batch_size=32,
        epochs=50,
        validation_split=0.2,
        verbose=1,
    )
    return model, env, history


# ======================= Evaluate Model =======================
def play_with_dt(model, env, target_return=15, render=True, context_length=20, max_steps=200):
    """Play one episode greedily with a trained Decision Transformer.

    The model is conditioned on ``target_return``; after each step the reward
    just received is subtracted to form the next return-to-go. At every step
    the last ``context_length`` timesteps are fed to the model, each pairing a
    state with the action taken before it (a placeholder 0 precedes the first
    state), left-padded with zeros. The argmax of the logits at the final
    position is then played.

    Args:
        model: callable taking a dict of 'states' (1, context_length, state_dim),
            'actions' (1, context_length) and 'return_to_go' (1, context_length)
            and returning logits of shape (1, context_length, n_actions).
        env: environment with reset(), get_state(), step(action) returning
            (state, reward, done), attributes ``done`` and ``score``, and
            render_console() (used only when ``render`` is true).
        target_return: initial return-to-go to condition on.
        render: print progress and the board every 5 steps.
        context_length: timesteps fed to the model; should match the training
            window length (20 in this script).
        max_steps: hard cap on the number of steps played.

    Returns:
        (final score, total reward).
    """
    env.reset()
    states = [env.get_state()]
    actions = [0]
    returns_to_go = [target_return]
    total_reward = 0
    steps = 0

    while not env.done and steps < max_steps:
        window = min(context_length, len(states))
        s_seq = np.array(states[-window:])
        a_seq = np.array(actions[-window:])
        r_seq = np.array(returns_to_go[-window:])

        # Left-pad so the most recent step is always the last position.
        pad_len = context_length - window
        if pad_len > 0:
            s_seq = np.concatenate([np.zeros((pad_len, s_seq.shape[1])), s_seq])
            a_seq = np.concatenate([np.zeros(pad_len), a_seq])
            r_seq = np.concatenate([np.zeros(pad_len), r_seq])

        action_logits = model(
            {
                "states": s_seq[np.newaxis, ...],
                "actions": a_seq[np.newaxis, ...],
                "return_to_go": r_seq[np.newaxis, ...],
            },
            training=False,
        )
        # argmax of the logits equals argmax of their softmax.
        action = int(np.argmax(np.asarray(action_logits)[0, -1]))

        next_state, reward, done = env.step(action)
        total_reward += reward
        steps += 1

        states.append(next_state)
        actions.append(action)
        returns_to_go.append(returns_to_go[-1] - reward)

        if render and steps % 5 == 0:
            print(f"Step {steps}, Score: {env.score}, Total Reward: {total_reward:.2f}")
            env.render_console()
            time.sleep(0.2)

    print(f"Game Over! Final Score: {env.score}, Total Reward: {total_reward:.2f}")
    return env.score, total_reward


# ======================= Main =======================
if __name__ == "__main__":
    print("===== Decision Transformer Training =====")
    model, env, history = train_decision_transformer()

    print("\n=== Testing Trained Model ===")
    for target_return in (10, 20, 30):
        print(f"\nTarget return: {target_return}")
        score, reward = play_with_dt(model, env, target_return=target_return, render=True)
        print(f"Score: {score}, Reward: {reward:.2f}")

    # Plot training curves: one panel per metric, training vs. validation.
    # Each row: (train key, val key, train label, val label, title, y label).
    panels = [
        ('loss', 'val_loss', 'Training loss', 'Validation loss', 'Loss Curve', 'Loss'),
        ('accuracy', 'val_accuracy', 'Train Acc', 'Val Acc', 'Accuracy Curve', 'Accuracy'),
    ]
    plt.figure(figsize=(12, 4))
    for idx, (train_key, val_key, train_label, val_label, title, y_label) in enumerate(panels, start=1):
        plt.subplot(1, 2, idx)
        plt.plot(history.history[train_key], label=train_label)
        plt.plot(history.history[val_key], label=val_label)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()

    plt.tight_layout()
    plt.show()

    print("\n=== Demo Complete ===")

2025-07-30 01:18:08.175187: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-07-30 01:18:08.187312: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753831088.201375  128352 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1753831088.204656  128352 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1753831088.213183  128352 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

===== Decision Transformer Training =====
Collecting training data...
Episode 0/1000
Episode 100/1000
Episode 200/1000
Episode 300/1000
Episode 400/1000
Episode 500/1000
Episode 600/1000
Episode 700/1000
Episode 800/1000
Episode 900/1000
Collected 1000 trajectories
Training data shape - States: (1860, 20, 305), Actions: (1860, 20), RTG: (1860, 20), Targets: (1860, 20)
Training Decision Transformer...
Epoch 1/50


I0000 00:00:1753831089.912755  128352 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 1148 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3060, pci bus id: 0000:01:00.0, compute capability: 8.6
I0000 00:00:1753831095.682609  129490 service.cc:152] XLA service 0x77eb5000d4a0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1753831095.682624  129490 service.cc:160]   StreamExecutor device (0): NVIDIA GeForce RTX 3060, Compute Capability 8.6
2025-07-30 01:18:15.821255: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2025-07-30 01:18:16.217435: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:39] Ignoring Assert operator compile_loss/sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/assert_equal_1/Assert/Assert
I0000 00:00:1753831096.707779  129490 cuda_dnn.cc:529] Loaded c

[1m23/47[0m [32m━━━━━━━━━[0m[37m━━━━━━━━━━━[0m [1m0s[0m 7ms/step - accuracy: 0.2844 - loss: 1.6664   

I0000 00:00:1753831109.237667  129490 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m44/47[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 8ms/step - accuracy: 0.3039 - loss: 1.5653

2025-07-30 01:18:30.027108: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:39] Ignoring Assert operator compile_loss/sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/assert_equal_1/Assert/Assert

























[1m47/47[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 297ms/step - accuracy: 0.3066 - loss: 1.5551

2025-07-30 01:18:43.355356: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:39] Ignoring Assert operator compile_loss/sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/assert_equal_1/Assert/Assert
2025-07-30 01:18:44.822448: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:39] Ignoring Assert operator compile_loss/sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/assert_equal_1/Assert/Assert








[1m47/47[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m38s[0m 403ms/step - accuracy: 0.3074 - loss: 1.5520 - val_accuracy: 0.3629 - val_loss: 1.3177
Epoch 2/50
[1m47/47[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 11ms/step - accuracy: 0.4705 - loss: 1.1919 - val_accuracy: 0.4192 - val_loss: 1.2900
Epoch 3/50
[1m47/47[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 12ms/step - accuracy: 0.5594 - loss: 1.0485 - val_accuracy: 0.4528 - val_loss: 1.2818
Epoch 4/50
[1m47/47[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6223 - loss: 0.9348 - val_accuracy: 0.4710 - val_loss: 1.2895
Epoch 5/50
[1m47/47[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6724 - loss: 0.8359 - val_accuracy: 0.4746 - val_loss: 1.3241
Epoch 6/50
[1m47/47[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.7251 - loss: 0.7314 - val_accuracy: 0.4847 - val_loss: 1.3826
Epoch 7/50
[1m47/47[0m [32m━━━━━━━━━━━━━