# In [25]:
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
from gym_pikachu_volleyball.envs.pikachu_volleyball import PikachuVolleyballMultiEnv

# Actor-critic network for A2C: a shared MLP trunk feeding two heads.
class ActorCritic(nn.Module):
    """Two ReLU hidden layers shared by a policy (actor) head and a value (critic) head.

    Args:
        num_inputs: size of the last dimension of the input tensor.
        num_actions: number of logits emitted by the actor head.
        hidden_size: width of both hidden layers.
    """

    def __init__(self, num_inputs, num_actions, hidden_size):
        super().__init__()
        # Creation order is significant: it fixes parameter-initialisation
        # order and the state_dict key order.
        self.fc1 = nn.Linear(num_inputs, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.actor = nn.Linear(hidden_size, num_actions)
        self.critic = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Return ``(action_logits, state_value)`` for input ``x``."""
        features = x
        for trunk_layer in (self.fc1, self.fc2):
            features = F.relu(trunk_layer(features))
        policy_logits = self.actor(features)
        state_value = self.critic(features)
        return policy_logits, state_value

# Set hyperparameters
lr = 0.0001
gamma = 0.99
num_steps = 5  # rollout length between updates (the "n" of the n-step return)
max_episodes = 10000

# Create the environment
env = PikachuVolleyballMultiEnv(render_mode='human')

# Initialize the actor-critic network and optimizer.
# Observations are flattened before they enter the MLP (see _obs_to_tensor),
# so the first layer's input size is the product of ALL observation
# dimensions. Using only shape[0] made the first Linear layer reject the
# (multi-dimensional) observation.
num_inputs = torch.Size(env.observation_space.shape).numel()
# NOTE(review): the Categorical head samples ONE index per step. Summing the
# action-space shape is kept from the original; confirm the env accepts a
# single discrete action index of this range.
num_actions = sum(env.action_space.shape)
hidden_size = 256
model = ActorCritic(num_inputs, num_actions, hidden_size)
optimizer = optim.Adam(model.parameters(), lr=lr)


def _obs_to_tensor(obs):
    """Return ``obs`` as a flat float32 tensor, the layout the MLP expects."""
    return torch.as_tensor(obs, dtype=torch.float32).reshape(-1)


def _reset_obs(reset_result):
    """Drop the ``info`` dict if ``env.reset`` returned ``(obs, info)``."""
    if (isinstance(reset_result, tuple) and len(reset_result) == 2
            and isinstance(reset_result[1], dict)):
        return reset_result[0]
    return reset_result


def _unpack_step(step_result):
    """Normalise ``env.step`` output to ``(obs, reward, done)``.

    Accepts both the 4-tuple ``(obs, reward, done, info)`` and the 5-tuple
    ``(obs, reward, terminated, truncated, info)`` conventions.
    """
    if len(step_result) == 5:
        obs, reward, terminated, truncated, _ = step_result
        return obs, reward, bool(terminated or truncated)
    obs, reward, done, _ = step_result
    return obs, reward, bool(done)


# Define the A2C update function
def update(model, optimizer, states, actions, rewards, next_states, masks, discount=None):
    """Run one A2C gradient step on a short rollout of T transitions.

    Args:
        model: network returning ``(actor_logits, critic_value)``.
        optimizer: optimizer over ``model``'s parameters.
        states: tensor/array holding the T observations (each is flattened).
        actions: the T sampled action indices.
        rewards: the T rewards.
        next_states: the single observation that followed the last step; its
            value bootstraps the n-step return.
        masks: per-step continuation flags (1.0 = episode continues after that
            step, 0.0 = it ended there). A 2-D ``(T, k)`` mask is accepted and
            its first column is used.
        discount: discount factor; defaults to the module-level ``gamma``.

    Returns:
        The scalar loss as a Python float (0.0 for an empty rollout).
    """
    if discount is None:
        discount = gamma
    rewards = torch.as_tensor(rewards, dtype=torch.float32).reshape(-1)
    steps_taken = rewards.shape[0]
    if steps_taken == 0:
        return 0.0  # nothing to learn from an empty rollout

    states = torch.as_tensor(states, dtype=torch.float32).reshape(steps_taken, -1)
    next_states = torch.as_tensor(next_states, dtype=torch.float32).reshape(1, -1)
    actions = torch.as_tensor(actions).reshape(-1).long()
    masks = torch.as_tensor(masks, dtype=torch.float32).reshape(steps_taken, -1)[:, 0]

    # n-step returns, computed backwards from V(s_T). The target carries no
    # gradient, and a 0 mask stops bootstrapping across an episode boundary.
    with torch.no_grad():
        _, bootstrap = model(next_states)
    running = bootstrap.reshape(())
    returns = torch.empty_like(rewards)
    for t in reversed(range(steps_taken)):
        running = rewards[t] + discount * running * masks[t]
        returns[t] = running

    # A single forward pass feeds both heads.
    actor_logits, critic_values = model(states)
    critic_values = critic_values.reshape(-1)
    advantages = returns - critic_values.detach()

    log_probs = Categorical(logits=actor_logits).log_prob(actions)
    actor_loss = -(log_probs * advantages).mean()
    critic_loss = F.smooth_l1_loss(critic_values, returns)

    loss = actor_loss + critic_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


# Train the A2C agent
episode_rewards = []
for episode in range(max_episodes):
    # NOTE(review): `option` is not defined in this cell; presumably it comes
    # from an earlier notebook cell -- confirm, or call env.reset() instead.
    state = _reset_obs(env.reset(option))
    done = False
    episode_reward = 0

    while not done:
        # Collect up to num_steps transitions (fewer if the episode ends).
        states, actions, rewards, masks = [], [], [], []
        for _ in range(num_steps):
            obs = _obs_to_tensor(state)
            # Acting needs no autograd graph; update() redoes the forward pass.
            with torch.no_grad():
                actor_logits, _ = model(obs)
            action = Categorical(logits=actor_logits).sample()
            # NOTE(review): passes a 0-d numpy array, as the original did via
            # .numpy(); confirm this is the action format env.step expects.
            state, reward, done = _unpack_step(env.step(action.numpy()))
            states.append(obs)
            actions.append(action)
            rewards.append(float(reward))
            masks.append(0.0 if done else 1.0)
            episode_reward += reward
            if done:
                break

        # Update the actor-critic network using the collected experience;
        # `state` is now the observation that followed the last step.
        update(model, optimizer, torch.stack(states), torch.stack(actions),
               torch.tensor(rewards), _obs_to_tensor(state), torch.tensor(masks))

    episode_rewards.append(episode_reward)
    print(f'Episode {episode} | Reward: {episode_reward:g}')

    # Save the model weights every 100 episodes
    if episode % 100 == 0:
        torch.save(model.state_dict(), 'a2c_weights.pth')

env.close()

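# Recorded output of an earlier run of this cell. The first nn.Linear received
# an unflattened (304, 432, 3) observation, viewed as 131328 rows of 3 features,
# while the layer had been built with in_features=304:
# RuntimeError: mat1 and mat2 shapes cannot be multiplied (131328x3 and 304x256)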