In [3]:
import random
import numpy as np
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim

In [11]:
def policy(env, observation):
    """Uniform-random baseline policy.

    ``observation`` is accepted only so the signature matches
    ``policy(env, observation)`` as called by ``run``; it is not used.

    Returns an action drawn from ``env.action_space.sample()``.
    """
    action_space = env.action_space
    return action_space.sample()

In [19]:
def run(policy, N=1000, env_id="LunarLander-v2", env_factory=None):
    """Roll out ``policy`` for ``N`` environment steps, grouping transitions by episode.

    Parameters
    ----------
    policy : callable
        Called as ``policy(env, observation)`` and must return an action.
    N : int
        Total number of ``env.step`` calls, spread over as many episodes as fit.
    env_id : str
        Gymnasium environment id passed to ``gym.make`` when ``env_factory`` is None.
    env_factory : callable, optional
        Zero-argument callable returning an environment; overrides ``env_id``
        (e.g. ``lambda: gym.make(env_id, render_mode="human")``).

    Returns
    -------
    list[list[tuple]]
        One list per episode of ``(observation, action, reward, next_observation, done)``
        tuples, where ``done`` is True on the step that terminated or truncated the
        episode. The last episode may be incomplete because the step budget ran out.
    """
    env = env_factory() if env_factory is not None else gym.make(env_id)
    episodes = []
    current = []
    try:
        observation, _info = env.reset()
        for _ in range(N):
            action = policy(env, observation)
            next_observation, reward, terminated, truncated, _ = env.step(action)
            done = bool(terminated or truncated)
            current.append((observation, action, reward, next_observation, done))
            if done:
                episodes.append(current)
                current = []
                observation, _info = env.reset()
            else:
                # Advance the state; otherwise every transition (and every policy
                # call) would see the episode's reset observation.
                observation = next_observation
    finally:
        # Release the environment even if the policy or env.step raises.
        env.close()
    if current:
        episodes.append(current)
    return episodes

In [20]:
# Collect 1000 steps with the uniform-random policy, grouped into episodes.
episodes = run(policy=policy, N=1000)

In [30]:
class DQN:
    """Minimal deep Q-learning agent for a Gymnasium env with discrete actions.

    ``model`` maps an observation vector to one Q-value per action. ``run`` plays
    episodes, stores transitions in ``pool``, and calls ``train`` after every step.
    ``train`` performs one SGD step on a uniformly sampled minibatch using the
    one-step TD target ``r + GAMMA * max_a' Q(s', a')``.
    """

    BATCH_SIZE = 10       # transitions per gradient step
    GAMMA = 0.99          # discount factor in the TD target
    LEARNING_RATE = 1e-3  # passed explicitly: older PyTorch SGD has no default lr
    EPSILON = 1.0         # exploration rate for policy(); 1.0 means always random

    def __init__(self, env_id="LunarLander-v2"):
        """Create the env, a linear Q-network sized from its spaces, and the optimizer.

        Assumes a 1-D ``Box`` observation space and a ``Discrete`` action space
        (for LunarLander: 8 observation values, 4 actions).
        """
        self.env = gym.make(env_id)
        n_obs = self.env.observation_space.shape[0]
        n_actions = self.env.action_space.n
        self.model = nn.Sequential(nn.Linear(n_obs, n_actions))
        self.loss = nn.MSELoss()
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.LEARNING_RATE)
        # Initialised here so train() is safe to call before run().
        self.pool = []

    def policy(self, observation):
        """Epsilon-greedy action selection.

        With probability ``EPSILON`` a random action is sampled from the action
        space. Otherwise the action with the highest predicted Q-value is returned.
        With the default ``EPSILON = 1.0`` the policy is purely random.
        """
        if random.random() < self.EPSILON:
            return self.env.action_space.sample()
        with torch.no_grad():
            q_values = self.model(torch.as_tensor(observation, dtype=torch.float32))
        return int(q_values.argmax().item())

    def run(self, Ne=10):
        """Play ``Ne`` episodes, storing transitions and training after every step.

        ``self.pool`` is emptied at the start of each call. Each stored transition is
        ``(observation, action, reward, next_observation, terminated)``. Only genuine
        termination is flagged, so time-limit truncation still bootstraps from
        ``next_observation`` in ``train``.
        """
        self.pool = []
        for _ in range(Ne):
            observation, _info = self.env.reset()
            done = False
            while not done:
                action = self.policy(observation)
                next_observation, reward, terminated, truncated, _ = self.env.step(action)
                self.pool.append(
                    (observation, action, reward, next_observation, bool(terminated))
                )
                self.train()
                # Advance to the next state; otherwise every transition would repeat
                # the observation returned by reset().
                observation = next_observation
                done = terminated or truncated

    def train(self):
        """Take one SGD step on a random minibatch drawn from ``self.pool``.

        Returns the minibatch loss as a float. Returns ``None`` without updating
        while the pool holds fewer than ``BATCH_SIZE`` transitions.
        """
        if len(self.pool) < self.BATCH_SIZE:
            return None
        minibatch = random.sample(self.pool, self.BATCH_SIZE)
        # Transpose the tuples into per-field sequences. torch.tensor cannot build
        # a single tensor from the mixed array/int/float/bool tuples directly.
        observations, actions, rewards, next_observations, terminals = zip(*minibatch)
        obs_t = torch.as_tensor(np.asarray(observations), dtype=torch.float32)
        actions_t = torch.as_tensor(np.asarray(actions), dtype=torch.int64)
        rewards_t = torch.as_tensor(np.asarray(rewards), dtype=torch.float32)
        next_obs_t = torch.as_tensor(np.asarray(next_observations), dtype=torch.float32)
        terminals_t = torch.as_tensor(np.asarray(terminals), dtype=torch.float32)

        # Q(s, a) for the actions actually taken: (batch, n_actions) -> (batch,)
        q_taken = self.model(obs_t).gather(1, actions_t.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            next_q = self.model(next_obs_t).max(dim=1).values
            # (1 - terminal) disables bootstrapping from terminal next states.
            targets = rewards_t + self.GAMMA * (1.0 - terminals_t) * next_q

        loss = self.loss(q_taken, targets)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def close(self):
        """Close the underlying environment."""
        self.env.close()


In [None]:
# Build the agent and play Ne=10 episodes; DQN.run calls agent.train() after each step.
agent = DQN()
agent.run(Ne=10)

In [26]:
# Number of transitions stored by the most recent agent.run() call
# (DQN.run empties agent.pool before collecting).
# NOTE(review): the saved output below comes from In [26], which predates the
# current DQN cell (In [30]); re-run the notebook top to bottom to refresh it.
len(agent.pool)

831