# All imports

In [3]:
import numpy as np
import random
from tqdm import tqdm
import gymnasium as gym
from collections import deque, namedtuple
from collections import defaultdict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from gymnasium.envs.toy_text.frozen_lake import generate_random_map

# w2v required
from collections import Counter
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics.pairwise import cosine_similarity

# DQN required
import argparse
import numpy as np
import logging
from matplotlib import animation # will be needed for rendering

# All global variables needed
Some or all of these variables are needed by every cell below, so it is difficult to move them inside a main function.\
Perhaps passing them around through an args object (e.g. built with `argparse`) would work?

## Common Functions needed

### Function to generate custom map

In [None]:
# ============================================================================================
# Function to generate custom map
# ============================================================================================
def make_env(env_name, env_dim = 4, seed = 42, stochastic = False):
    """Create a Gymnasium environment laid out on a freshly generated random map.

    Args:
        env_name: Environment id handed to ``gym.make``.
        env_dim: Side length passed as ``size`` to ``generate_random_map``.
        seed: Seed passed to ``generate_random_map`` so the layout is reproducible.
        stochastic: Forwarded as ``is_slippery`` to the environment.

    Returns:
        The environment returned by ``gym.make``, created with
        ``render_mode='rgb_array'``.
    """
    # Build the layout first so the gym.make call below stays readable.
    random_layout = generate_random_map(size=env_dim, seed=seed)
    return gym.make(
        env_name,
        desc=random_layout,
        is_slippery=stochastic,
        render_mode='rgb_array',
    )

## Global Variables

In [4]:
'''
env_name: str
env_dim: int --> Dimension of the game: 4x4 or 8x8
seed
stochastic = boolean --> Whether we use is_slippery = True or False
 '''
# --- Environment selection -------------------------------------------------------------------
seed = 42
stochastic = False
env_dim = 4
env_name = "FrozenLake-v1"

# --- Tabular Q-learning hyperparameters ------------------------------------------------------
gamma, alpha = 0.99, 0.1  # discount factor in Q computation, learning rate in the table
num_episodes_q_table = 100_000
convergence_threshold = 1e-4
epsilon_start, epsilon_end = 1, 0.01
epsilon_decay_q_table = 0.99995
check_env_details = True

# --- Environment instance and the sizes of its discrete spaces -------------------------------
env = make_env(env_name=env_name, env_dim=env_dim, seed = seed, stochastic=stochastic)
state_dim, action_dim = env.observation_space.n, env.action_space.n
print("State space: ", state_dim)
print("Action space: ", action_dim)

# --- State related variables -----------------------------------------------------------------
num_episodes_trajectories = 10_000
num_states, num_actions = state_dim, action_dim

# NOTE(review): the name is spelled "esp" (not "eps"); it is kept unchanged because other
# cells may reference it -- confirm before renaming.
max_esp_len = 100

# Tabular Q learning

### Q-Learning Agent class

In [None]:
# ============================================================================================
# Q-Learning Agent class
# ============================================================================================
class QLearningAgent:
    """Tabular Q-learning agent with epsilon-greedy action selection.

    The agent keeps a ``(num_states, num_actions)`` table of action values,
    initialised to zero, and updates it with the one-step Q-learning rule.
    """

    def __init__(self, num_states, num_actions, gamma=0.99, epsilon=0.1, alpha=0.1):
        """Create an agent with an all-zero Q-table.

        Args:
            num_states: Number of discrete states (rows of the Q-table).
            num_actions: Number of discrete actions (columns of the Q-table).
            gamma: Discount factor applied to the bootstrapped next-state value.
            epsilon: Default exploration rate, used by ``choose_action`` when
                the caller does not pass one explicitly.
            alpha: Learning rate (step size) of the Q-table update.
        """
        self.num_states = num_states
        self.num_actions = num_actions
        self.alpha = alpha  # Learning rate
        self.gamma = gamma  # Discount factor
        self.epsilon = epsilon  # Default exploration rate
        self.q_table = np.zeros((num_states, num_actions))  # Initialize Q-table

    def choose_action(self, state, epsilon=None):
        """Epsilon-greedy action selection.

        Args:
            state: Index of the current state (row of the Q-table).
            epsilon: Probability of picking a uniformly random action. When
                omitted (``None``) the agent's own ``self.epsilon`` is used,
                so a schedule can either be passed in per call or stored on
                the agent.

        Returns:
            The chosen action as a plain Python ``int``.
        """
        if epsilon is None:
            epsilon = self.epsilon

        if random.uniform(0, 1) < epsilon:
            return random.randint(0, self.num_actions - 1)  # Explore

        # Exploit. np.argmax returns the first maximal index on ties; cast to
        # int so both branches return the same type.
        return int(np.argmax(self.q_table[state, :]))

    def update_q_value(self, state, action, reward, next_state):
        """Apply one Q-learning update for the transition (s, a, r, s').

        The target is ``reward + gamma * max_a' Q(next_state, a')`` and the
        entry ``Q(state, action)`` is moved towards it by a fraction ``alpha``.
        """
        best_next_value = np.max(self.q_table[next_state, :])  # Greedy value of next state
        td_target = reward + self.gamma * best_next_value
        td_error = td_target - self.q_table[state, action]
        self.q_table[state, action] += self.alpha * td_error  # Update Q-table

    def get_optimal_policy(self):
        """Return the greedy action for every state as a 1-D integer array."""
        return np.argmax(self.q_table, axis=1)

### Function to train the model using Q-learning for Frozen Lake

In [None]:
# ============================================================================================
# Function to train the model using Q-learning for Frozen Lake
# ============================================================================================
def run_tabular_q_frozen(env, agent, num_episodes=10, convergence_threshold=1e-4,
                         epsilon_start=1, epsilon_decay=0.995, epsilon_end=0.01, seed=42,
                         max_eps_len=100):
    """Train ``agent`` on ``env`` with tabular Q-learning.

    Each episode runs for at most ``max_eps_len`` steps with epsilon-greedy
    actions; epsilon is multiplied by ``epsilon_decay`` after every episode
    and floored at ``epsilon_end``.

    Args:
        env: Environment exposing the Gymnasium ``reset``/``step`` API.
        agent: Object providing ``choose_action(state, epsilon)``,
            ``update_q_value(state, action, reward, next_state)``,
            ``get_optimal_policy()`` and a ``q_table`` attribute.
        num_episodes: Number of training episodes.
        convergence_threshold: Currently unused; kept so existing callers
            that pass it keep working.
        epsilon_start: Exploration rate used in the first episode.
        epsilon_decay: Multiplicative per-episode decay of epsilon.
        epsilon_end: Lower bound for epsilon.
        seed: Seed passed to the first ``env.reset`` call only.
        max_eps_len: Maximum number of steps per episode.

    Returns:
        A tuple ``(q_table, policy, reward_curve)``: the agent's Q-table (the
        same array object, not a copy), the greedy policy from
        ``agent.get_optimal_policy()``, and a list holding, per episode, the
        mean reward over the most recent (up to 100) episodes.
    """
    reward_curve = []  # moving average of episode rewards
    moving_window = deque(maxlen=100)
    epsilon = epsilon_start

    for episode in tqdm(range(num_episodes)):
        # Seed the environment once. Re-seeding on every reset would restart
        # its RNG and replay the same noise sequence in each episode.
        if episode == 0:
            state, _ = env.reset(seed=seed)
        else:
            state, _ = env.reset()
        curr_reward = 0

        for _ in range(max_eps_len):
            action = agent.choose_action(state, epsilon)
            next_state, reward, terminated, truncated, _ = env.step(action)
            agent.update_q_value(state, action, reward, next_state)
            state = next_state
            curr_reward += reward

            # Stop on either end-of-episode signal; stepping past a truncated
            # episode without a reset is not supported by the step API.
            if terminated or truncated:
                break

        # Epsilon decay performed at the end of each episode
        epsilon = max(epsilon * epsilon_decay, epsilon_end)

        # Appending the smoothened reward
        moving_window.append(curr_reward)
        avg_reward = np.mean(moving_window)
        reward_curve.append(avg_reward)

        if episode % 1000 == 0:
            print(f"Tabular Q: Episode {episode}: epsilon = {epsilon}, avg reward = {avg_reward}")

    return agent.q_table, agent.get_optimal_policy(), reward_curve

### Running the Tabular Q learning

In [None]:
# env = make_env(env_name=env_name, env_dim=env_dim, seed = seed, stochastic=stochastic)

# Seed the stdlib and NumPy global random generators before training
random.seed(seed)
np.random.seed(seed)

if check_env_details:
    # Grid representation of the lake, as exposed by the unwrapped environment
    lake_grid = env.unwrapped.desc

    # Show the layout, one text row per grid row
    print("Frozen Lake Grid Layout:")
    for row in lake_grid:
        print(" ".join(row.astype(str)))

    # Scan the grid for the goal tile and convert its (row, col) to a state number
    goal_state = None
    rows, cols = lake_grid.shape
    for i, row in enumerate(lake_grid):
        for j, tile in enumerate(row):
            if tile == b'G':  # tiles are stored as byte-strings
                goal_state = i * cols + j
                break
    print(f"Goal State: {goal_state}")

# Sizes of the discrete state and action spaces
state_dim, action_dim = env.observation_space.n, env.action_space.n
print("State space: ", state_dim)
print("Action space: ", action_dim)

# Creating the learning Agent
learner = QLearningAgent(
    num_states=state_dim,
    num_actions=action_dim,
    gamma=gamma,
    epsilon=epsilon_start,
    alpha=alpha,
)

# Train and collect the learned table, the greedy policy and the smoothed reward history
final_q_table, final_policy, reward_curve = run_tabular_q_frozen(
    env,
    learner,
    num_episodes=num_episodes_q_table,
    convergence_threshold=convergence_threshold,
    epsilon_start=epsilon_start,
    epsilon_decay=epsilon_decay_q_table,
    epsilon_end=epsilon_end,
    seed=seed,
)

# State values: the best action value in each row of the Q-table
Val_f = final_q_table.max(axis=1)

# Human-readable names for the four action indices
action_map = dict(enumerate(("Left", "Down", "Right", "Up")))

# Per-state report: tile type, value and greedy action
print("State: Type -    V(s),    action taken")
lake_grid = env.unwrapped.desc  # Gets the grid representation
state = 0
for row in lake_grid:
    for cell in row:
        print(f"     {state}:   {cell.decode('utf-8')} - {Val_f[state]:.2f}, {final_policy[state]}-->{action_map[final_policy[state]]}")  # Convert byte to string
        state += 1
# assert False, "c1"

# Print the final table and policy
print("Final Q function: ", final_q_table)
# print("Final Policy: ", final_policy)
# print("Final Value function: ", Val_f)

# # Plot heatmap of the Value function
# plt.figure(figsize=(5,5))
# plt.imshow(Val_f.reshape(4,4), cmap="coolwarm", interpolation="nearest")
# for i in range(4):
#     for j in range(4):
#         plt.text(j, i, f"{Val_f[i*4+j]:.2f}", ha='center', va='center', color='black')


# Plot the reward curve
# TODO: reward_curve is collected above but not plotted yet.



# Save the current Q-function
file_name = f"Q_table_{env_name}_map_size_{env_dim}_stochastic_{stochastic}_seed_{seed}.npy"
np.save(file_name, final_q_table)

# Collecting Trajectories
Here we may need to ensure that the collected trajectories explore every state sufficiently.\
To achieve this, we may want to start from specific states when resetting the environment during trajectory collection.

# Performing w2v

# Now we need to ensure that the same states are sampled, so that the performance of the w2v-imbued DQN and the vanilla DQN can be compared fairly.
Thus we reset the `np`, `random` and `torch` seeds (via `torch.manual_seed`) before running each of the following two cells.

# DQN with w2v

# Vanilla DQN (without w2v)