In [1]:
! pip install gym[atari]==0.23 shimmy
! pip install autorom
! AutoROM --accept-license

Collecting gym==0.23 (from gym[atari]==0.23)
  Downloading gym-0.23.0.tar.gz (624 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m624.4/624.4 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25h  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Collecting ale-py~=0.7.4 (from gym[atari]==0.23)
  Downloading ale_py-0.7.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.1 kB)
Downloading ale_py-0.7.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m49.6 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: gym
  Building wheel for gym (pyproject.toml) ... [?25ldone
[?25h  Created wheel for gym: filename=gym-0.23.0-py3-none-any.whl size=697635 sha256=27c371cd14cc076d87c55c630c9dab02490d15784c5cbb6

In [None]:
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque
import random
from gym.spaces import Box
from gym.wrappers import FrameStack
import cv2
import matplotlib.pyplot as plt
from torch.distributions import Categorical

# Preprocessing wrapper (unchanged)
class PreprocessAtari(gym.Wrapper):
    """Gym wrapper that turns RGB frames into 84x84 single-channel grayscale.

    Every observation returned by ``reset`` and ``step`` goes through
    ``preprocess`` and therefore has shape ``(84, 84, 1)``, matching the
    ``observation_space`` declared in ``__init__``.
    """

    def __init__(self, env):
        super().__init__(env)
        self.observation_space = Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

    def preprocess(self, frame):
        """Convert one RGB frame to grayscale, resize to 84x84, add a channel axis."""
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = cv2.resize(frame, (84, 84), interpolation=cv2.INTER_AREA)
        return frame[:, :, None]  # Add channel dimension

    def reset(self, **kwargs):
        """Reset the wrapped env and return the preprocessed first observation.

        Keyword arguments (e.g. ``seed``, ``return_info``, ``options``) are
        forwarded to the wrapped environment so that outer wrappers which pass
        them through keep working. When ``return_info=True`` is requested the
        wrapped env returns ``(obs, info)``; only ``obs`` is preprocessed and
        the pair is returned in the same form.
        """
        result = self.env.reset(**kwargs)
        if kwargs.get("return_info", False):
            observation, info = result
            return self.preprocess(observation), info
        return self.preprocess(result)

    def step(self, action):
        """Step the wrapped env and preprocess the resulting observation."""
        next_state, reward, done, info = self.env.step(action)
        return self.preprocess(next_state), reward, done, info

# Improved Adaptive Vision Expert
class AdaptiveVisionExpert(nn.Module):
    """Three-stage conv + batch-norm encoder followed by one linear layer.

    Expects a 4-channel image batch and produces ``outputs`` features per
    sample. ``h`` and ``w`` are the input height/width used to size the
    linear layer.
    """

    def __init__(self, h, w, outputs):
        super().__init__()
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.bn3 = nn.BatchNorm2d(64)

        # Walk the spatial size through each (kernel, stride) stage; with no
        # padding a conv maps size -> (size - kernel) // stride + 1.
        feat_h, feat_w = h, w
        for kernel, stride in ((8, 4), (4, 2), (3, 1)):
            feat_h = (feat_h - kernel) // stride + 1
            feat_w = (feat_w - kernel) // stride + 1

        self.fc = nn.Linear(feat_w * feat_h * 64, outputs)

    def forward(self, x):
        """Encode ``x`` and return a ``(batch, outputs)`` feature tensor."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in stages:
            x = F.relu(norm(conv(x)))
        flattened = x.view(x.size(0), -1)
        return self.fc(flattened)

# Improved DQN Expert
class DQNExpert(nn.Module):
    """Two-hidden-layer MLP mapping a feature vector to one value per action."""

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden_units = 256
        self.fc1 = nn.Linear(state_dim, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, action_dim)

    def forward(self, x):
        """Return a ``(batch, action_dim)`` tensor of action values."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

# Improved PPO Expert
class PPOExpert(nn.Module):
    """Actor-critic pair of independent MLPs over the same feature vector.

    The actor produces a probability distribution over actions; the critic
    produces a single scalar value per sample.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()

        def build_mlp(out_features):
            # Linear -> ReLU -> Linear -> ReLU -> Linear, hidden width 256.
            return nn.Sequential(
                nn.Linear(state_dim, 256),
                nn.ReLU(),
                nn.Linear(256, 256),
                nn.ReLU(),
                nn.Linear(256, out_features),
            )

        # Actor is built before the critic so parameter creation order is stable.
        self.actor = build_mlp(action_dim)
        self.critic = build_mlp(1)

    def forward(self, state):
        """Return ``(action_probs, state_values)`` for a batch of features."""
        logits = self.actor(state)
        probabilities = F.softmax(logits, dim=-1)
        values = self.critic(state)
        return probabilities, values

# Improved Gating Network
class GatingNetwork(nn.Module):
    """Small MLP that outputs softmax mixing weights, one per expert."""

    def __init__(self, state_dim, num_experts):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, 128)
        self.fc2 = nn.Linear(128, num_experts)

    def forward(self, x):
        """Return ``(batch, num_experts)`` weights that sum to 1 per row."""
        hidden = F.relu(self.fc1(x))
        expert_logits = self.fc2(hidden)
        return F.softmax(expert_logits, dim=-1)

class MoE:
    def __init__(self, env, state_dim, action_dim):
        """Build the online/target networks, optimizers and hyperparameters.

        Args:
            env: Environment used by ``train`` and ``evaluate``.
            state_dim: Size of the feature vector produced by the vision
                expert and consumed by the DQN/PPO/gating networks.
            action_dim: Number of discrete actions.
        """
        self.env = env
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        device = self.device

        # Online networks (construction order is kept: vision, DQN, PPO, gating).
        self.vision_expert = AdaptiveVisionExpert(84, 84, state_dim).to(device)
        self.dqn_expert = DQNExpert(state_dim, action_dim).to(device)
        self.ppo_expert = PPOExpert(state_dim, action_dim).to(device)
        self.gating_network = GatingNetwork(state_dim, 2).to(device)

        # One Adam optimizer per online network, all with the same step size.
        learning_rate = 0.0001
        (
            self.vision_optimizer,
            self.dqn_optimizer,
            self.ppo_optimizer,
            self.gating_optimizer,
        ) = (
            optim.Adam(network.parameters(), lr=learning_rate)
            for network in (
                self.vision_expert,
                self.dqn_expert,
                self.ppo_expert,
                self.gating_network,
            )
        )

        # Replay buffer and value-learning / exploration hyperparameters.
        self.memory = deque(maxlen=100000)
        self.batch_size = 32
        self.gamma = 0.99
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.tau = 0.001

        # PPO specific parameters
        self.ppo_clip_epsilon = 0.2
        self.ppo_epochs = 4
        self.ppo_entropy_coef = 0.01

        # Bookkeeping for training progress and periodic evaluation.
        self.global_step = 0
        self.eval_scores = []
        self.eval_episodes = []

        # Target networks start as exact copies of the online ones (tau=1.0).
        self.target_vision_expert = AdaptiveVisionExpert(84, 84, state_dim).to(device)
        self.target_dqn_expert = DQNExpert(state_dim, action_dim).to(device)
        self.target_ppo_expert = PPOExpert(state_dim, action_dim).to(device)
        self.update_target_networks(tau=1.0)

    def update_target_networks(self, tau=None):
        if tau is None:
            tau = self.tau

        for target_param, param in zip(self.target_vision_expert.parameters(), self.vision_expert.parameters()):
            target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)

        for target_param, param in zip(self.target_dqn_expert.parameters(), self.dqn_expert.parameters()):
            target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)

        for target_param, param in zip(self.target_ppo_expert.parameters(), self.ppo_expert.parameters()):
            target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)

    def preprocess_state(self, state):
        state = np.array(state)

        if len(state.shape) == 5 and state.shape[-1] == 1:
            state = state.squeeze(-1)
            state = state.transpose(0, 2, 3, 1)
        elif len(state.shape) == 4 and state.shape[-1] == 1:
            state = state.squeeze(-1)
            state = state.transpose(1, 2, 0)
            state = np.expand_dims(state, axis=0)
        elif len(state.shape) == 3 and state.shape[0] == 4:
            state = state.transpose((1, 2, 0))
            state = np.expand_dims(state, axis=0)
        elif len(state.shape) == 3 and state.shape[-1] == 4:
            state = np.expand_dims(state, axis=0)
        else:
            raise ValueError(f"Unexpected state shape: {state.shape}")

        state = torch.FloatTensor(state).to(self.device) / 255.0
        state = state.permute(0, 3, 1, 2)

        return state

    def select_action(self, state):
        if random.random() < self.epsilon:
            return random.randrange(self.action_dim)

        with torch.no_grad():
            state = self.preprocess_state(state)
            structured_state = self.vision_expert(state)
            dqn_action = self.dqn_expert(structured_state)
            ppo_action, _ = self.ppo_expert(structured_state)
            expert_outputs = torch.stack([dqn_action, ppo_action], dim=1)
            expert_probs = self.gating_network(structured_state).unsqueeze(-1)
            final_action = torch.sum(expert_outputs * expert_probs, dim=1)

        return final_action.max(1)[1].item()

    def update(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))
        if len(self.memory) < self.batch_size:
            return

        batch = random.sample(self.memory, self.batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))

        state = self.preprocess_state(state)
        next_state = self.preprocess_state(next_state)
        action = torch.LongTensor(action).to(self.device)
        reward = torch.FloatTensor(reward).to(self.device)
        done = torch.FloatTensor(done).to(self.device)

        structured_state = self.vision_expert(state)
        structured_next_state = self.target_vision_expert(next_state)

        # DQN update
        current_q_values = self.dqn_expert(structured_state)
        next_q_values = self.target_dqn_expert(structured_next_state).max(1)[0].detach()
        target_q_values = reward + (1 - done) * self.gamma * next_q_values

        dqn_loss = F.smooth_l1_loss(current_q_values.gather(1, action.unsqueeze(1)), target_q_values.unsqueeze(1))

        # PPO update
        ppo_action_probs, state_values = self.ppo_expert(structured_state)
        old_action_probs = ppo_action_probs.detach()
        old_state_values = state_values.detach()

        advantages = target_q_values - old_state_values.squeeze()
        
        ppo_loss = 0
        for _ in range(self.ppo_epochs):
            new_action_probs, new_state_values = self.ppo_expert(structured_state)
            
            # Add a small epsilon to prevent division by zero
            ratio = (new_action_probs.gather(1, action.unsqueeze(1)) + 1e-8) / (old_action_probs.gather(1, action.unsqueeze(1)) + 1e-8)
            surr1 = ratio * advantages.unsqueeze(1)
            surr2 = torch.clamp(ratio, 1 - self.ppo_clip_epsilon, 1 + self.ppo_clip_epsilon) * advantages.unsqueeze(1)
            
            actor_loss = -torch.min(surr1, surr2).mean()
            critic_loss = F.mse_loss(new_state_values.squeeze(), target_q_values)
            
            # Clip action probabilities to prevent log(0)
            clipped_probs = torch.clamp(new_action_probs, 1e-10, 1.0)
            entropy = -(clipped_probs * torch.log(clipped_probs)).sum(dim=-1).mean()
            
            ppo_loss += actor_loss + 0.5 * critic_loss - self.ppo_entropy_coef * entropy

        ppo_loss /= self.ppo_epochs

        # Gating network update
        expert_outputs = torch.stack([current_q_values, ppo_action_probs], dim=1)
        expert_probs = self.gating_network(structured_state).unsqueeze(-1)
        gating_loss = F.mse_loss(torch.sum(expert_outputs * expert_probs, dim=1).gather(1, action.unsqueeze(1)), target_q_values.unsqueeze(1))

        # Vision expert update
        vision_loss = F.mse_loss(structured_state, structured_next_state)

        # Combine all losses
        total_loss = vision_loss + dqn_loss + ppo_loss + gating_loss

        # Optimize
        self.vision_optimizer.zero_grad()
        self.dqn_optimizer.zero_grad()
        self.ppo_optimizer.zero_grad()
        self.gating_optimizer.zero_grad()

        total_loss.backward()

        # Clip gradients to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(self.vision_expert.parameters(), max_norm=1.0)
        torch.nn.utils.clip_grad_norm_(self.dqn_expert.parameters(), max_norm=1.0)
        torch.nn.utils.clip_grad_norm_(self.ppo_expert.parameters(), max_norm=1.0)
        torch.nn.utils.clip_grad_norm_(self.gating_network.parameters(), max_norm=1.0)

        self.vision_optimizer.step()
        self.dqn_optimizer.step()
        self.ppo_optimizer.step()
        self.gating_optimizer.step()

        # Update target networks
        self.update_target_networks()

        # Decay epsilon
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

    def evaluate(self, num_episodes=10):
        total_rewards = []
        for _ in range(num_episodes):
            state = self.env.reset()
            total_reward = 0
            done = False
            while not done:
                action = self.select_action(state)
                next_state, reward, done, _ = self.env.step(action)
                total_reward += reward
                state = next_state
            total_rewards.append(total_reward)
        return np.mean(total_rewards), np.std(total_rewards)
    
    def train(self, num_episodes, max_steps_per_episode=1000, eval_frequency=100):
        for episode in range(num_episodes):
            state = self.env.reset()
            total_reward = 0
            highscore = 0
            counter = 0
            done = False
            steps = 0

            while not done and counter < max_steps_per_episode:
                action = self.select_action(state)
                next_state, reward, done, _ = self.env.step(action)
                total_reward += reward
                if total_reward > highscore:
                    highscore = total_reward
                    counter = 0
                else:
                    counter += 1
                self.update(state, action, reward, next_state, done)
                state = next_state
                steps += 1
                self.global_step += 1

            print(f"Episode {episode + 1}, Total Reward: {total_reward}, Steps: {steps}, Global Steps: {self.global_step}, Epsilon: {self.epsilon:.2f}")

            if (episode + 1) % eval_frequency == 0:
                eval_mean, eval_std = self.evaluate()
                self.eval_scores.append(eval_mean)
                self.eval_episodes.append(episode + 1)
                print(f"Evaluation at episode {episode + 1}: Mean reward: {eval_mean:.2f} (+/- {eval_std:.2f})")
                self.save_model(f"moe_model_episode_{episode + 1}.pth")
                self.plot_evaluation()

    def plot_evaluation(self):
        """Write the evaluation curve recorded so far to ``evaluation_plot.png``."""
        figure, axes = plt.subplots(figsize=(10, 5))
        axes.plot(self.eval_episodes, self.eval_scores)
        axes.set_title("Evaluation Scores During Training")
        axes.set_xlabel("Episode")
        axes.set_ylabel("Mean Evaluation Score")
        figure.savefig("evaluation_plot.png")
        # Release the figure so repeated calls do not pile up open figures.
        plt.close(figure)

    def save_model(self, path):
        torch.save({
            'vision_expert': self.vision_expert.state_dict(),
            'dqn_expert': self.dqn_expert.state_dict(),
            'ppo_expert': self.ppo_expert.state_dict(),
            'gating_network': self.gating_network.state_dict(),
            'epsilon': self.epsilon,
            'global_step': self.global_step,
            'eval_scores': self.eval_scores,
            'eval_episodes': self.eval_episodes
        }, path)

    def load_model(self, path):
        checkpoint = torch.load(path)
        self.vision_expert.load_state_dict(checkpoint['vision_expert'])
        self.dqn_expert.load_state_dict(checkpoint['dqn_expert'])
        self.ppo_expert.load_state_dict(checkpoint['ppo_expert'])
        self.gating_network.load_state_dict(checkpoint['gating_network'])
        self.epsilon = checkpoint['epsilon']
        self.global_step = checkpoint['global_step']
        self.eval_scores = checkpoint['eval_scores']
        self.eval_episodes = checkpoint['eval_episodes']

# Main training loop
if __name__ == "__main__":
    # Pong frames -> 84x84 grayscale -> stack of the 4 most recent frames.
    env = gym.make("PongNoFrameskip-v4")
    env = PreprocessAtari(env)
    env = FrameStack(env, 4)

    try:
        # NOTE(review): this value is the raw pixel count of the stacked
        # observation, but MoE uses `state_dim` as the width of the vision
        # expert's output features (and of the DQN/PPO/gating inputs).
        # Confirm that such a wide feature vector is intended.
        state_dim = 84 * 84 * 4  # 4 stacked frames, each 84x84
        action_dim = env.action_space.n

        moe = MoE(env, state_dim, action_dim)
        moe.train(5000, max_steps_per_episode=1000, eval_frequency=100)  # Train for 5000 episodes, evaluate every 100 episodes
    finally:
        # Release the emulator even when training raises or is interrupted
        # from the notebook.
        env.close()

Episode 1, Total Reward: -6.0, Steps: 1000, Global Steps: 1000, Epsilon: 0.01
Episode 2, Total Reward: -5.0, Steps: 1000, Global Steps: 2000, Epsilon: 0.01
Episode 3, Total Reward: -6.0, Steps: 1000, Global Steps: 3000, Epsilon: 0.01
Episode 4, Total Reward: -4.0, Steps: 1000, Global Steps: 4000, Epsilon: 0.01
Episode 5, Total Reward: 0.0, Steps: 1000, Global Steps: 5000, Epsilon: 0.01
