In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

import matplotlib.pyplot as plt

from ale_py import ALEInterface
from ale_py.roms import LostLuggage
import gymnasium as gym
from tqdm import tqdm




In [2]:
ale = ALEInterface()
ale.loadROM(LostLuggage)

In [3]:
env = gym.make('ALE/LostLuggage-v5', render_mode='human')

In [4]:
class PolicyNetwork(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.softmax = nn.Softmax(dim=1)
    
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return self.softmax(x)
    

In [7]:
def reinforce():
    input_size = env.action_space.n
    hidden_size = input_size * 4
    output_size = input_size
    learning_rate = 0.01
    num_episodes = 50

    poilicy_network = PolicyNetwork(input_size, hidden_size, output_size)
    optimizer = optim.Adam(poilicy_network.parameters(), lr=learning_rate)

    for _ in range(num_episodes):
        env.reset()
        actions = np.zeros((env.action_space.n,))
        log_probs = []
        rewards = []
        is_done = False
        while not is_done:
            action_tensor = torch.from_numpy(actions).float().unsqueeze(0)
            action_probs = poilicy_network(action_tensor)
            action = torch.multinomial(action_probs, 1).item()
            log_probs.append(torch.log(action_probs[0][action]))
            _, reward, is_done, _, _ = env.step(action)
            rewards.append(reward)

        returns = []
        G = 0
        for reward in rewards[::-1]:
            G = reward + 0.9 * G
            returns.insert(0, G)

        policy_loss = []

        for log_prob, G in zip(log_probs, returns):
            policy_loss.append(-log_prob * G)
        
        optimizer.zero_grad()
        policy_loss = torch.cat(policy_loss).sum() # <--
        policy_loss.backward()
        optimizer.step()

        print(f'{_ + 1}, Loss {policy_loss.item()}')

In [8]:
reinforce()

RuntimeError: zero-dimensional tensor (at position 0) cannot be concatenated