In [None]:
!pip install gymnasium

In [None]:
import random
import numpy as np
import gymnasium as gym
from collections import deque
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Sequential

In [None]:
# Build the deterministic (non-slippery) 4x4 FrozenLake environment.
# render_mode='rgb_array' makes env.render() return the frame as an image array
# instead of opening a window.
env = gym.make(
    'FrozenLake-v1',
    desc=None,
    map_name="4x4",
    is_slippery=False,
    render_mode='rgb_array',
)

# The environment must be reset once before it can be rendered or stepped.
env.reset()

# Last expression of the cell: the rendered start frame becomes the cell output.
env.render()

In [None]:
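# Scratch test: derive the model's input size from the observation space.
# For a Discrete space `.n` is a single integer (not a list), so the
# expression below always evaluates to 1.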
# number_of_par = env.observation_space.n
# # if type(number_of_par) == list:
# #   number_of_par = len(number_of_par)
# # else:
# #   number_of_par = 1

# number_of_par = len(number_of_par) if type(number_of_par) == list else 1
# number_of_par

In [None]:
class DqnAgent:
  """Deep Q-learning agent with a replay memory and a soft-updated target network.

  Two identical Keras networks are kept: ``model`` is trained on replayed
  transitions and used to pick actions, ``target_model`` supplies the
  bootstrap targets and slowly tracks ``model`` through ``target_train``.

  States are expected as arrays that reshape to ``(1, 1)``: the network takes
  a single scalar feature (the discrete state index) as input.
  """

  def __init__(self, env, gamma=0.3, epsilon=1.0, epsilon_min=0.01,
               epsilon_decay=0.995, learning_rate=0.005, tau=0.125,
               memory_size=2000):
    """Create the agent and both networks.

    Args:
      env: Environment with discrete ``observation_space`` / ``action_space``.
      gamma: Discount factor applied to the bootstrapped future value.
      epsilon: Initial probability of taking a random action.
      epsilon_min: Lower bound for ``epsilon``.
      epsilon_decay: Multiplicative decay applied to ``epsilon`` on every
        call to ``action``.
      learning_rate: Adam learning rate for the online network.
      tau: Interpolation factor of the soft target-network update.
      memory_size: Maximum number of transitions kept in the replay memory.
    """
    self.env = env
    self.memory = deque(maxlen=memory_size)
    self.gamma = gamma
    self.epsilon = epsilon
    self.epsilon_min = epsilon_min
    self.epsilon_decay = epsilon_decay
    self.learning_rate = learning_rate
    self.tau = tau
    self.model = self.create_model()
    self.target_model = self.create_model()

  def action(self, state):
    """Pick an action epsilon-greedily and decay epsilon.

    Args:
      state: Model input of shape ``(1, 1)``.

    Returns:
      A random action sampled from the action space with probability
      ``epsilon``, otherwise the index (plain ``int``) of the highest Q-value
      predicted by the online network.
    """
    self.epsilon = max(self.epsilon * self.epsilon_decay, self.epsilon_min)

    if np.random.random() < self.epsilon:
      return self.env.action_space.sample()

    # verbose=0: Keras would otherwise print a progress bar for every call.
    q_values = self.model.predict(state, verbose=0)
    return int(np.argmax(q_values[0]))

  def create_model(self):
    """Build and compile the Q-network.

    Returns:
      A compiled ``Sequential`` model mapping one scalar state feature to one
      Q-value per action.
    """
    # The previous size computation (`len(n) if type(n) == list else 1`)
    # always produced 1 because `observation_space.n` is an integer, so the
    # single input feature is stated explicitly.
    input_features = 1

    model = Sequential()
    model.add(Dense(24, input_dim=input_features, activation="relu"))
    model.add(Dense(48, activation="relu"))
    model.add(Dense(24, activation="relu"))
    # Linear output layer: one Q-value per available action.
    model.add(Dense(self.env.action_space.n))
    model.compile(loss="mean_squared_error",
                  optimizer=Adam(learning_rate=self.learning_rate))
    return model

  def remembr(self, state, action, reward, new_state, done):
    """Append one transition to the replay memory."""
    self.memory.append([state, action, reward, new_state, done])

  def target_train(self):
    """Soft-update the target network towards the online network.

    Every target weight becomes ``tau * online + (1 - tau) * target``.
    """
    online_weights = self.model.get_weights()
    target_weights = self.target_model.get_weights()

    blended = [
        self.tau * online + (1 - self.tau) * target
        for online, target in zip(online_weights, target_weights)
    ]
    self.target_model.set_weights(blended)

  def save_model(self, path):
    """Save the online network to ``path``."""
    self.model.save(path)

  def reply(self, batch_size=32):
    """Train the online network on a random sample of stored transitions.

    Does nothing until at least ``batch_size`` transitions are stored.

    The target network does not change during a replay, so all targets are
    computed with two batched predictions instead of two predictions per
    sample. The updates themselves are still applied one sample at a time in
    sampling order (``fit`` with ``batch_size=1`` and ``shuffle=False``), which
    matches a per-sample ``fit`` loop without its per-call overhead and
    console output.

    Args:
      batch_size: Number of transitions sampled from the replay memory.
    """
    if len(self.memory) < batch_size:
      return

    samples = random.sample(self.memory, batch_size)

    # Stack every stored state into one (batch_size, input_dim) array.
    states = np.concatenate(
        [np.array(sample[0]).reshape(1, -1) for sample in samples])
    new_states = np.concatenate(
        [np.array(sample[3]).reshape(1, -1) for sample in samples])
    actions = np.array([sample[1] for sample in samples], dtype=int)
    rewards = np.array([sample[2] for sample in samples], dtype=float)
    dones = np.array([sample[4] for sample in samples], dtype=bool)

    # Start from the target network's own estimates so only the Q-value of
    # the action actually taken is changed in each row.
    targets = np.array(self.target_model.predict(states, verbose=0))
    future_q = np.max(self.target_model.predict(new_states, verbose=0), axis=1)

    # Finished episodes have no future value; otherwise bootstrap from the
    # best Q-value of the next state.
    updated_q = np.where(dones, rewards, rewards + self.gamma * future_q)
    targets[np.arange(batch_size), actions] = updated_q

    self.model.fit(states, targets, batch_size=1, shuffle=False,
                   epochs=1, verbose=0)

In [None]:
# Training budget: `trails` is the number of episodes to attempt and
# `trail_len` the maximum number of steps taken within one episode.
trails, trail_len = 500, 500

# The agent builds its online and target networks for this environment.
dqn_agent = DqnAgent(env=env)

In [None]:
from IPython.display import clear_output

In [None]:
# Train until one episode reaches the goal or the episode budget is used up.
for trail in range(trails):
  # env.reset() returns (observation, info); the network expects a (1, 1) array.
  observation, _ = env.reset()
  current_state = np.array([[observation]])
  reached_goal = False

  for step in range(trail_len):
    print("#", step)
    action = dqn_agent.action(current_state)
    new_state, reward, terminated, truncated, _ = env.step(action)
    new_state = np.array([[new_state]])

    # Only a real terminal state is stored as `done`: a time-limit truncation
    # is not an end of the task, so its target should still bootstrap.
    dqn_agent.remembr(current_state, action, reward, new_state, terminated)
    dqn_agent.reply()
    dqn_agent.target_train()
    current_state = new_state
    clear_output(wait=True)

    # The environment must not be stepped again after either flag is set.
    if terminated or truncated:
      # FrozenLake pays a positive reward only on the goal tile; an episode
      # that ends by falling into a hole terminates with reward 0.
      reached_goal = bool(terminated and reward > 0)
      break

  if reached_goal:
    print("success")
    dqn_agent.save_model('result.h5')
    break
  print("failed")

In [None]:
# Roll out the greedy policy of the target network and record every frame.
done = False
observation, _ = env.reset()
state = np.array([[observation]])
frames = []
epoches = 0

while not done:
  # verbose=0 keeps Keras from printing a progress bar for every step.
  predict_action = dqn_agent.target_model.predict(state, verbose=0)
  action = int(np.argmax(predict_action[0]))
  new_state, reward, terminated, truncated, _ = env.step(action)
  # Stop on a terminal state and also when the environment's time limit cuts
  # the episode off; stepping again after either flag requires a reset first.
  done = terminated or truncated
  print(done)
  state = np.array([[new_state]])
  frames.append({'frame': env.render(),
                 'state': state,
                 'action': action,
                 'reward': reward})
  epoches += 1
  print(f"epoche : {epoches}")
  # Hard cap as a safety net in case the episode never ends.
  if epoches == 199:
    break

In [None]:
# Play back the recorded rollout, one rendered frame after another.
for frame_no in range(len(frames)):
    snapshot = frames[frame_no]
    plt.title(f"Frame : {frame_no}")
    plt.imshow(snapshot['frame'])
    # Hold the frame for 0.2 s, then clear the figure for the next one.
    plt.pause(0.2)
    plt.clf()