In [1]:
!pip install swig
!pip install gymnasium[box2d]
!pip install moviepy

Collecting swig
  Downloading swig-4.2.1-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl.metadata (3.6 kB)
Downloading swig-4.2.1-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m17.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: swig
Successfully installed swig-4.2.1
Collecting gymnasium[box2d]
  Downloading gymnasium-0.29.1-py3-none-any.whl.metadata (10 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium[box2d])
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)
Collecting box2d-py==2.3.5 (from gymnasium[box2d])
  Downloading box2d-py-2.3.5.tar.gz (374 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m374.4/374.4 kB[0m [31m8.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Downloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Downloading gymnasium-0.29.

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import gymnasium as gym
from moviepy.editor import ImageSequenceClip

In [3]:
class StateEncoder(nn.Module):
    def __init__(self, latent_dim):
        super(StateEncoder, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)


        self._initialize_weights()
        with torch.no_grad():
            self.flattened_size = self._compute_flattened_size((3, 96, 96))

        self.fc = nn.Linear(self.flattened_size, latent_dim)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def _compute_flattened_size(self, input_shape):
        x = torch.zeros(1, *input_shape)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        return x.numel()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

In [4]:
class TransitionModel(nn.Module):
    def __init__(self, latent_dim, action_dim):
        super(TransitionModel, self).__init__()
        self.fc1 = nn.Linear(latent_dim + action_dim, 128)
        self.fc2 = nn.Linear(128, latent_dim)

    def forward(self, state, action):
        x = torch.cat([state, action], dim=1)
        x = F.relu(self.fc1(x))
        next_state = self.fc2(x)
        return next_state

In [5]:
class RewardModel(nn.Module):
    def __init__(self, latent_dim, action_dim):
        super(RewardModel, self).__init__()
        self.fc1 = nn.Linear(latent_dim + action_dim, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], dim=1)
        x = F.relu(self.fc1(x))
        reward = self.fc2(x)
        return reward

In [6]:
class WorldModel(nn.Module):
    def __init__(self, latent_dim, action_dim):
        super(WorldModel, self).__init__()
        self.encoder = StateEncoder(latent_dim)
        self.transition = TransitionModel(latent_dim, action_dim)
        self.reward = RewardModel(latent_dim, action_dim)

    def forward(self, state, action):
        latent_state = self.encoder(state)
        next_latent_state = self.transition(latent_state, action)
        reward = self.reward(latent_state, action)
        return next_latent_state, reward

    def predict(self, latent_state, action):
        next_latent_state = self.transition(latent_state, action)
        reward = self.reward(latent_state, action)
        return next_latent_state, reward

In [7]:
def compute_loss(world_model, state, action, next_state, reward):
    latent_state = world_model.encoder(state)
    next_latent_state_pred, reward_pred = world_model(state, action)

    latent_next_state = world_model.encoder(next_state)
    state_loss = F.mse_loss(next_latent_state_pred, latent_next_state)
    reward_loss = F.mse_loss(reward_pred, reward)

    return state_loss + reward_loss

In [8]:

env = gym.make('CarRacing-v2', render_mode='rgb_array')
world_model = WorldModel(latent_dim=64, action_dim=3)


In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
world_model.to(device)

optimizer = Adam(world_model.parameters(), lr=1e-3)

num_episodes = 100
max_steps_per_episode = 1000

for episode in range(num_episodes):
    state, _ = env.reset()
    done = False
    total_reward = 0
    step_count = 0

    while not done and step_count < max_steps_per_episode:
        action = env.action_space.sample()
        next_state, reward, done, _, _ = env.step(action)


        state_tensor = torch.tensor(state, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(device)
        action_tensor = torch.tensor(action, dtype=torch.float32).unsqueeze(0).to(device)
        next_state_tensor = torch.tensor(next_state, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(device)
        reward_tensor = torch.tensor([reward], dtype=torch.float32).unsqueeze(0).to(device)


        loss = compute_loss(world_model, state_tensor, action_tensor, next_state_tensor, reward_tensor)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        state = next_state
        total_reward += reward
        step_count += 1

        print(f'Episode {episode + 1}, Step {step_count}, Reward: {reward}, Total Reward: {total_reward}')

    print(f'Episode {episode + 1} completed, Loss: {loss.item()}, Total Reward: {total_reward}, Steps: {step_count}')

    state, _ = env.reset()



[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Episode 96, Step 6, Reward: -0.09999999999999964, Total Reward: 6.620216606498197
Episode 96, Step 7, Reward: -0.09999999999999964, Total Reward: 6.520216606498197
Episode 96, Step 8, Reward: -0.09999999999999964, Total Reward: 6.420216606498197
Episode 96, Step 9, Reward: -0.09999999999999964, Total Reward: 6.320216606498198
Episode 96, Step 10, Reward: -0.09999999999999964, Total Reward: 6.220216606498198
Episode 96, Step 11, Reward: -0.09999999999999964, Total Reward: 6.120216606498198
Episode 96, Step 12, Reward: -0.09999999999999964, Total Reward: 6.020216606498199
Episode 96, Step 13, Reward: -0.09999999999999964, Total Reward: 5.920216606498199
Episode 96, Step 14, Reward: -0.09999999999999964, Total Reward: 5.8202166064981995
Episode 96, Step 15, Reward: -0.09999999999999964, Total Reward: 5.7202166064982
Episode 96, Step 16, Reward: -0.09999999999999964, Total Reward: 5.6202166064982
Episode 96, Step 17, Reward: 

In [10]:
def inference(env, world_model, device, frame_skip=2):
    state, _ = env.reset()
    done = False
    frames = []
    step_count = 0

    while not done:

        state_tensor = torch.tensor(state, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(device)


        latent_state = world_model.encoder(state_tensor)


        action = env.action_space.sample()
        action_tensor = torch.tensor(action, dtype=torch.float32).unsqueeze(0).to(device)


        with torch.no_grad():
            next_latent_state, reward = world_model.predict(latent_state, action_tensor)


        next_state, reward, done, _, _ = env.step(action)


        if step_count % frame_skip == 0:
            frame = env.render()
            if frame is not None:
                frames.append(frame)

        state = next_state
        step_count += 1

    return frames

In [11]:

frames = inference(env, world_model, device, frame_skip=2)
clip = ImageSequenceClip(frames, fps=30)
clip.write_videofile("car_racing_inference.mp4", codec="libx264")


Moviepy - Building video car_racing_inference.mp4.
Moviepy - Writing video car_racing_inference.mp4





Moviepy - Done !
Moviepy - video ready car_racing_inference.mp4
