# Model and Replay Buffer definition

In [4]:
from collections import deque, namedtuple
import torch.nn as nn
import random

# One recorded environment step: the state acted in, the action taken,
# the state that followed, and the reward received.
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'next_state', 'reward'],
)
class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random sampling."""

    def __init__(self, max_capacity):
        # A deque with ``maxlen`` discards its oldest entry when a new one
        # is appended to a full buffer.
        self.memory = deque(maxlen=max_capacity)

    def push(self, transition_):
        """Append one item (expected to be a ``Transition``) to the buffer."""
        self.memory.append(transition_)

    def can_sample(self, batch_size):
        """Return True when at least ``batch_size`` items are stored."""
        return batch_size <= len(self.memory)

    def sample(self, batch_size):
        """Return a list of ``batch_size`` stored items drawn without replacement."""
        return random.sample(self.memory, k=batch_size)

    def __len__(self):
        """Return the number of items currently stored."""
        return len(self.memory)

class DQN(nn.Module):
    """Fully connected network producing one output value per action.

    The network is ``Linear(n_observations, 128) -> ReLU ->
    Linear(128, 256) -> ReLU -> Linear(256, actions)``.

    Args:
        n_observations: Size of the input feature vector.
        actions: Number of output units.
    """

    def __init__(self, n_observations, actions):
        # nn.Module must be initialised before any sub-module is assigned;
        # without this call the first ``self.input = ...`` raises
        # AttributeError.
        super().__init__()
        self.input = nn.Linear(n_observations, 128)
        self.middle = nn.Linear(128, 256)
        self.fc = nn.Linear(256, actions)

    def forward(self, x):
        """Return the unactivated output of the final layer for input ``x``."""
        # Apply ReLU functionally. ``nn.ReLU(tensor)`` would construct a
        # module (treating the tensor as the ``inplace`` flag) rather than
        # activating the tensor.
        x = nn.functional.relu(self.input(x))
        x = nn.functional.relu(self.middle(x))
        return self.fc(x)
    


# Training

In [None]:
import torch
from torch.functional import F
from game.game_launch import GameLauncher
# Epsilon-greedy exploration settings, followed by the available moves.
epsilon = 1.0  # initial exploration rate: start out fully exploratory
epsilon_min = 0.01  # lower bound, leaving a reasonable amount of exploitation
epsilon_decay = 0.995
action_list = [
    'up',
    'down',
    'left',
    'right',
]

def policy(state, action_list, inference_model):
    """Pick an action index with an epsilon-greedy rule.

    With probability ``epsilon`` (module-level) a uniformly random index into
    ``action_list`` is returned and ``epsilon`` is decayed by
    ``epsilon_decay``, never dropping below ``epsilon_min``. Otherwise the
    index of the largest output of ``inference_model(state)`` is returned.

    Args:
        state: Current observation. A ``torch.Tensor`` is used as-is; any
            other array-like value is converted to a float32 tensor.
        action_list: Sequence of available actions; only its length is used.
        inference_model: Callable mapping a state tensor to per-action scores.

    Returns:
        int: Index of the chosen action.
    """
    global epsilon  # the exploration rate persists across calls

    if torch.rand(1).item() < epsilon:
        # NOTE(review): epsilon decays only on exploratory steps, so the decay
        # slows down as epsilon shrinks. Confirm this schedule is intended.
        epsilon = max(epsilon_min, epsilon * epsilon_decay)
        return random.randrange(len(action_list))

    # The network needs a tensor; environments commonly hand back arrays or
    # lists, so convert anything that is not already a tensor.
    if not isinstance(state, torch.Tensor):
        state = torch.as_tensor(state, dtype=torch.float32)

    # Greedy action selection needs no autograd graph.
    with torch.no_grad():
        return inference_model(state).argmax().item()
    
    
    
def training_model(
    policy_net: DQN,
    target_q_model: DQN,
    game_instance: GameLauncher,
    lr,
    batch_size,
    episodes=20,
    gamma=0.99,
):
    """Train ``policy_net`` with DQN using experience replay and a target network.

    Each environment step is stored in a replay buffer of capacity 300. Once
    the buffer holds ``batch_size`` transitions, every step also performs one
    optimisation step on a random mini-batch. At the end of each episode the
    target network is synchronised with the policy network and the policy
    weights are saved to ``policy_episode(<episode>).pth``.

    Terminal transitions are stored with ``next_state=None`` so that their
    TD target is the reward alone (no bootstrapping).

    Args:
        policy_net: Online network being optimised; also used to pick actions.
        target_q_model: Network used to compute bootstrap targets.
        game_instance: Environment exposing ``state_to_array()`` and
            ``step(action)`` returning ``(new_state, reward, done)``.
        lr: Learning rate for AdamW.
        batch_size: Number of transitions per optimisation step.
        episodes: Number of episodes to run.
        gamma: Discount factor.
    """
    optimizer = torch.optim.AdamW(policy_net.parameters(), lr=lr)
    criterion = nn.SmoothL1Loss()  # stateless, so build it once
    replay_buffer = ReplayBuffer(max_capacity=300)
    device = next(policy_net.parameters()).device

    def as_state_tensor(raw_state):
        # Assumes one observation flattens to the network's input size.
        # TODO confirm against GameLauncher.state_to_array().
        return torch.as_tensor(
            raw_state, dtype=torch.float32, device=device
        ).reshape(-1)

    for episode in range(1, episodes + 1):
        # NOTE(review): the environment is never reset between episodes here.
        # If GameLauncher does not reset itself once ``done`` is returned,
        # add the appropriate reset call before this loop.
        done = False
        total_loss = 0.0
        update_count = 0

        while not done:
            current_state = as_state_tensor(game_instance.state_to_array())
            # Act with the online network; the target network only supplies
            # bootstrap values.
            action = policy(
                state=current_state,
                action_list=action_list,
                inference_model=policy_net,
            )
            new_state, reward, done = game_instance.step(action)
            print(f"New state is {new_state} | Reward is {reward} | Done {done} ")
            replay_buffer.push(
                Transition(
                    state=current_state,
                    action=action,
                    # None marks a terminal transition.
                    next_state=None if done else as_state_tensor(new_state),
                    reward=reward,
                )
            )

            if not replay_buffer.can_sample(batch_size):
                continue

            # Turn a list of Transitions into a Transition of tuples.
            batch = Transition(*zip(*replay_buffer.sample(batch_size)))
            state_b = torch.stack(batch.state)
            action_b = torch.tensor(
                batch.action, dtype=torch.int64, device=device
            ).unsqueeze(1)
            reward_b = torch.tensor(
                batch.reward, dtype=torch.float32, device=device
            ).unsqueeze(1)

            # Q(s, a) for the actions that were actually taken.
            action_pred_b = policy_net(state_b).gather(1, action_b)

            # max_a' Q_target(s', a'), left at zero for terminal transitions.
            maximum_next_q_value = torch.zeros(batch_size, 1, device=device)
            has_next = torch.tensor(
                [s is not None for s in batch.next_state],
                dtype=torch.bool,
                device=device,
            )
            if has_next.any():
                next_state_b = torch.stack(
                    [s for s in batch.next_state if s is not None]
                )
                with torch.no_grad():
                    maximum_next_q_value[has_next] = torch.max(
                        target_q_model(next_state_b), dim=1, keepdim=True
                    )[0]
            target_q_value = reward_b + gamma * maximum_next_q_value

            loss = criterion(action_pred_b, target_q_value)

            # Clear stale gradients BEFORE backward; zeroing after backward
            # would make the optimiser step a no-op.
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
            optimizer.step()

            total_loss += loss.item()
            update_count += 1

        # Average over the optimisation steps actually performed.
        average_loss = total_loss / update_count if update_count else 0.0
        print(f"Episode {episode} done with average error | {average_loss}")
        target_q_model.load_state_dict(policy_net.state_dict())
        torch.save(policy_net.state_dict(), f"policy_episode({episode}).pth")
                
                
                
                
            