Step 1: Define the Environment

In [1]:
import numpy as np

In [2]:
# define the environment
n_states = 16  # Number of states in the grid world
n_actions = 4  # Number of possible actions (up, down, left, right)
goal_state = 15  # Goal state

# Initialize Q-table with zeros
Q_table = np.zeros((n_states, n_actions))

Step 2: Set Hyperparameters

In [3]:
# define parameters
learning_rate = 0.85
discount_factor = 0.96
exploration_prob = 0.2
epochs = 1000

Step 3: Implement the Q-Learning Algorithm

In [4]:
# Q-learning algorithm
for epoch in range(epochs):
    current_state = np.random.randint(0, n_states)  # Start from a random state

    while current_state != goal_state:
        # Choose action with epsilon-greedy strategy
        if np.random.rand() < exploration_prob:
            action = np.random.randint(0, n_actions)  # Explore
        else:
            action = np.argmax(Q_table[current_state])  # Exploit

        # Simulate the environment (move to the next state)
        # For simplicity, move to the next state
        next_state = (current_state + 1) % n_states

        # Define a simple reward function (1 if the goal state is reached, 0 otherwise)
        reward = 1 if next_state == goal_state else 0

        # Update Q-value using the Q-learning update rule
        Q_table[current_state, action] += learning_rate * \
            (reward + discount_factor *
             np.max(Q_table[next_state]) - Q_table[current_state, action])

        current_state = next_state  # Move to the next state

Step 4: Output the Learned Q-Table

In [5]:
# After training, the Q-table represents the learned Q-values
print("Learned Q-table:")
print(Q_table)

Learned Q-table:
[[0.56467331 0.56276754 0.56463043 0.56276754]
 [0.58820095 0.58820036 0.58790359 0.58820137]
 [0.61270976 0.61270976 0.61270976 0.61270976]
 [0.63823933 0.63823933 0.63823933 0.63823933]
 [0.66483264 0.66483264 0.66483263 0.66483264]
 [0.692534   0.692534   0.692534   0.692534  ]
 [0.72138958 0.72138958 0.72138958 0.72138958]
 [0.75144748 0.75144748 0.75144748 0.75144748]
 [0.78275779 0.78275779 0.78275779 0.78275779]
 [0.8153727  0.8153727  0.8153727  0.8153727 ]
 [0.84934656 0.84934656 0.84934656 0.84934656]
 [0.884736   0.884736   0.884736   0.884736  ]
 [0.9216     0.9216     0.9216     0.9216    ]
 [0.96       0.96       0.96       0.96      ]
 [1.         1.         1.         1.        ]
 [0.         0.         0.         0.        ]]
