In [213]:
import torch
from torch import nn
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import gymnasium as gym
from tqdm import tqdm
import numpy as np

## Model

In [214]:
# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(8, 30),
            nn.ReLU(),
            nn.Linear(30, 30),
            nn.ReLU(),
            nn.Linear(30, 4)
        )

    def forward(self, x):
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)
print(model)

Using cpu device
NeuralNetwork(
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=8, out_features=30, bias=True)
    (1): ReLU()
    (2): Linear(in_features=30, out_features=30, bias=True)
    (3): ReLU()
    (4): Linear(in_features=30, out_features=4, bias=True)
  )
)


### Define Loss Function and Optimizer

In [215]:
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

### Create Gymnasium Environment Wrapper Class

In [216]:
class LunarLander(gym.Wrapper):
   def __init__(self, env):
      super(LunarLander, self).__init__(env)
      self.observation, info = env.reset()

   def step_env(self,action:int):
      reset = False
      next_observation, reward, terminated, truncated, info = self.step(action)
      observation = self.observation
      self.observation = next_observation

      if terminated or truncated:
         self.observation, info = self.reset()
         reset = True

      return reward, observation, next_observation, reset
   
   @property
   def terminated(self):
      return self.terminated or self.truncated
   
env = gym.make("LunarLander-v2")
lander = LunarLander(env)

In [220]:
def train(model, loss_fn, optimizer):
    model.train()
    size = 100_000
    gamma = 0.4
    epsillon = 0.4
    for step in tqdm(range(size)):
        # Compute prediction error
        pred = model(torch.from_numpy(lander.observation))
        pred = pred/pred.sum()
        
        #run environment
        action = torch.argmax(pred).item()
        action = action if np.random.rand() > epsillon else np.random.randint(0,4)
        reward, observation, next_observation, reset = lander.step_env(action)
        next_pred = model(torch.from_numpy(lander.observation))
        next_pred = next_pred/next_pred.sum()
        update_pred = pred.clone()
        update_pred[action] = reward + gamma*next_pred[action]
        update_pred = update_pred/update_pred.sum()

        loss = loss_fn(update_pred, pred)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if step % (size // 10) == 0:
            loss, current = loss.item(), step
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

In [221]:
epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train( model, loss_fn, optimizer)
print("Done!")

Epoch 1
-------------------------------


  0%|          | 60/100000 [00:00<03:01, 551.87it/s]

loss: 0.045858  [    0/100000]


 10%|█         | 10093/100000 [00:13<02:15, 662.56it/s]

loss: 0.336023  [10000/100000]


 20%|██        | 20091/100000 [00:33<02:51, 465.07it/s]

loss: 0.033938  [20000/100000]


 30%|███       | 30016/100000 [00:56<03:05, 378.28it/s]

loss: 0.037189  [30000/100000]


 40%|████      | 40035/100000 [01:18<02:18, 433.85it/s]

loss: 0.044390  [40000/100000]


 50%|█████     | 50047/100000 [01:41<01:45, 471.60it/s]

loss: 0.046345  [50000/100000]


 60%|██████    | 60068/100000 [02:04<01:26, 459.61it/s]

loss: 0.038948  [60000/100000]


 70%|███████   | 70047/100000 [02:27<01:18, 382.20it/s]

loss: 0.334749  [70000/100000]


 80%|████████  | 80063/100000 [02:50<00:40, 490.61it/s]

loss: 0.032067  [80000/100000]


 90%|█████████ | 90001/100000 [03:12<00:21, 469.24it/s]

loss: 0.327578  [90000/100000]


100%|██████████| 100000/100000 [03:34<00:00, 466.76it/s]


Epoch 2
-------------------------------


  0%|          | 81/100000 [00:00<02:17, 728.60it/s]

loss: 0.031174  [    0/100000]


 10%|█         | 10122/100000 [00:12<01:47, 835.88it/s]

loss: 0.038481  [10000/100000]


 20%|██        | 20074/100000 [00:25<01:43, 769.89it/s]

loss: 0.330814  [20000/100000]


 30%|███       | 30096/100000 [00:37<01:27, 801.19it/s]

loss: 0.033285  [30000/100000]


 40%|████      | 40100/100000 [00:50<01:14, 799.01it/s]

loss: 0.045559  [40000/100000]


 50%|█████     | 50100/100000 [01:02<01:02, 800.27it/s]

loss: 0.327339  [50000/100000]


 60%|██████    | 60102/100000 [01:15<00:51, 769.37it/s]

loss: 0.043239  [60000/100000]


 70%|███████   | 70123/100000 [01:28<00:38, 786.03it/s]

loss: 0.330781  [70000/100000]


 80%|████████  | 80066/100000 [01:41<00:25, 783.23it/s]

loss: 0.034922  [80000/100000]


 90%|█████████ | 90098/100000 [01:54<00:13, 743.13it/s]

loss: 0.043901  [90000/100000]


100%|██████████| 100000/100000 [02:07<00:00, 785.45it/s]


Epoch 3
-------------------------------


  0%|          | 77/100000 [00:00<02:18, 723.47it/s]

loss: 0.316395  [    0/100000]


 10%|█         | 10145/100000 [00:13<01:55, 778.95it/s]

loss: 0.040787  [10000/100000]


 20%|██        | 20097/100000 [00:26<01:43, 772.46it/s]

loss: 0.045585  [20000/100000]


 30%|███       | 30071/100000 [00:39<01:31, 765.63it/s]

loss: 0.032491  [30000/100000]


 40%|████      | 40119/100000 [00:52<01:16, 780.01it/s]

loss: 0.333065  [40000/100000]


 50%|█████     | 50159/100000 [01:06<01:04, 777.78it/s]

loss: 0.045575  [50000/100000]


 60%|██████    | 60063/100000 [01:18<00:52, 761.81it/s]

loss: 0.318828  [60000/100000]


 70%|███████   | 70128/100000 [01:32<00:38, 767.09it/s]

loss: 0.035471  [70000/100000]


 80%|████████  | 80114/100000 [01:45<00:26, 757.11it/s]

loss: 0.043638  [80000/100000]


 90%|█████████ | 90083/100000 [01:58<00:12, 770.31it/s]

loss: 0.040505  [90000/100000]


100%|██████████| 100000/100000 [02:11<00:00, 759.63it/s]


Epoch 4
-------------------------------


  0%|          | 43/100000 [00:00<03:53, 428.08it/s]

loss: 0.037365  [    0/100000]


 10%|█         | 10144/100000 [00:13<01:55, 779.06it/s]

loss: 0.032534  [10000/100000]


 20%|██        | 20081/100000 [00:26<01:58, 675.73it/s]

loss: 0.031345  [20000/100000]


 30%|███       | 30089/100000 [00:40<01:31, 767.68it/s]

loss: 0.043148  [30000/100000]


 40%|████      | 40084/100000 [00:53<01:17, 773.91it/s]

loss: 0.033228  [40000/100000]


 50%|█████     | 50123/100000 [01:06<01:04, 769.33it/s]

loss: 0.031223  [50000/100000]


 60%|██████    | 60088/100000 [01:19<00:52, 759.07it/s]

loss: 0.045313  [60000/100000]


 70%|███████   | 70075/100000 [01:32<00:38, 775.64it/s]

loss: 0.026735  [70000/100000]


 80%|████████  | 80100/100000 [01:45<00:26, 759.89it/s]

loss: 0.034863  [80000/100000]


 90%|█████████ | 90107/100000 [01:58<00:12, 783.13it/s]

loss: 0.043139  [90000/100000]


100%|██████████| 100000/100000 [02:12<00:00, 757.51it/s]


Epoch 5
-------------------------------


  0%|          | 71/100000 [00:00<02:21, 706.44it/s]

loss: 0.342900  [    0/100000]


 10%|█         | 10089/100000 [00:14<01:57, 764.80it/s]

loss: 0.325983  [10000/100000]


 20%|██        | 20126/100000 [00:27<01:42, 775.59it/s]

loss: 0.031523  [20000/100000]


 30%|███       | 30134/100000 [00:41<01:30, 775.11it/s]

loss: 0.043591  [30000/100000]


 40%|████      | 40135/100000 [00:54<01:19, 753.17it/s]

loss: 0.036888  [40000/100000]


 50%|█████     | 50060/100000 [01:13<01:52, 444.46it/s]

loss: 0.035248  [50000/100000]


 60%|██████    | 60069/100000 [01:34<01:25, 465.19it/s]

loss: 0.041193  [60000/100000]


 70%|███████   | 70045/100000 [01:56<01:05, 456.64it/s]

loss: 0.345478  [70000/100000]


 80%|████████  | 80063/100000 [02:18<00:44, 444.25it/s]

loss: 0.339441  [80000/100000]


 90%|█████████ | 90051/100000 [02:39<00:22, 437.83it/s]

loss: 0.043012  [90000/100000]


100%|██████████| 100000/100000 [03:01<00:00, 550.18it/s]

Done!





In [438]:
gamma = 0.55
epsillon = 0.1
pred = model(torch.from_numpy(lander.observation))

#run environment
action = torch.argmax(pred).item()
action = action if np.random.rand() > epsillon else np.random.randint(0,4)
reward, observation, next_observation, reset = lander.step_env(0)
next_pred = model(torch.from_numpy(lander.observation))
update_pred = pred.clone()
update_pred[action] = reward + gamma*next_pred[action]

loss = loss_fn(pred, update_pred)
print(action, reward)

2 -0.9421270524817942


In [222]:

env = gym.make("LunarLander-v2", render_mode="human")
observation, info = env.reset(seed=42)
for _ in range(1000):
   pred = model(torch.from_numpy(lander.observation))
   action = torch.argmax(pred).item()
   observation, reward, terminated, truncated, info = env.step(action)

   if terminated or truncated:
      observation, info = env.reset()

env.close()

KeyboardInterrupt: 

In [None]:
torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")

In [None]:
model = NeuralNetwork().to(device)
model.load_state_dict(torch.load("model.pth"))

In [None]:
model.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():
    x = x.to(device)
    pred = model(x)
    predicted, actual = classes[pred[0].argmax(0)], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')