In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, TensorDataset

In [2]:
# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# ----------------------------
# Environment (Synthetic Data)
# ----------------------------
class SimpleEnv:
    def __init__(self):
        self.state_space = 4
        self.action_space = 2

    def reset(self):
        return np.random.rand(self.state_space)

    def step(self, action):
        next_state = np.random.rand(self.state_space)
        reward = np.random.rand()
        done = np.random.choice([False, True], p=[0.9, 0.1])
        return next_state, reward, done

env = SimpleEnv()

# -----------------
# Preprocessing
# -----------------
def preprocess_state(state):
    return torch.tensor(state, dtype=torch.float32)

# Generate synthetic dataset for demonstration
states = [preprocess_state(env.reset()) for _ in range(500)]
actions = [torch.tensor(np.random.randint(env.action_space), dtype=torch.long) for _ in range(500)]
rewards = [torch.tensor(np.random.rand(), dtype=torch.float32) for _ in range(500)]

dataset = TensorDataset(torch.stack(states), torch.stack(actions), torch.stack(rewards))

# -----------------
# Train-Test Split
# -----------------
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_set, val_set, test_set = random_split(dataset, [train_size, val_size, test_size])

train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32)
test_loader = DataLoader(test_set, batch_size=32)

# -------------------------------
# Base Model Networks Definition
# -------------------------------
class RepresentationNetwork(nn.Module):
    def __init__(self, state_dim, hidden_dim=64):
        super().__init__()
        self.fc = nn.Linear(state_dim, hidden_dim)

    def forward(self, state):
        return torch.relu(self.fc(state))

class DynamicsNetwork(nn.Module):
    def __init__(self, hidden_dim=64, action_dim=2):
        super().__init__()
        self.fc_state = nn.Linear(hidden_dim + action_dim, hidden_dim)
        self.fc_reward = nn.Linear(hidden_dim, 1)

    def forward(self, latent_state, action_onehot):
        x = torch.cat([latent_state, action_onehot], dim=-1)
        next_state = torch.relu(self.fc_state(x))
        reward = self.fc_reward(next_state)
        return next_state, reward

class PredictionNetwork(nn.Module):
    def __init__(self, hidden_dim=64, action_dim=2):
        super().__init__()
        self.fc_policy = nn.Linear(hidden_dim, action_dim)
        self.fc_value = nn.Linear(hidden_dim, 1)

    def forward(self, latent_state):
        policy = torch.softmax(self.fc_policy(latent_state), dim=-1)
        value = self.fc_value(latent_state)
        return policy, value

# Initialize networks
state_dim = env.state_space
action_dim = env.action_space
hidden_dim = 64

representation_net = RepresentationNetwork(state_dim, hidden_dim)
dynamics_net = DynamicsNetwork(hidden_dim, action_dim)
prediction_net = PredictionNetwork(hidden_dim, action_dim)

# Optimizers
optimizer = optim.Adam(list(representation_net.parameters()) + 
                       list(dynamics_net.parameters()) + 
                       list(prediction_net.parameters()), lr=0.001)

In [3]:
# Training Base Model
# -------------------------------
def one_hot_encode(action, action_dim):
    onehot = torch.zeros(action_dim)
    onehot[action] = 1.0
    return onehot

for epoch in range(3):  # Lightweight setup with fewer epochs
    for states, actions, rewards in train_loader:
        optimizer.zero_grad()
        
        # Representation step
        latent_states = representation_net(states)

        # Dynamics step
        actions_onehot = torch.stack([one_hot_encode(a.item(), action_dim) for a in actions])
        predicted_next_state, predicted_reward = dynamics_net(latent_states, actions_onehot)

        # Prediction step
        predicted_policy, predicted_value = prediction_net(latent_states)

        # Losses
        reward_loss = nn.MSELoss()(predicted_reward.squeeze(), rewards)
        value_loss = nn.MSELoss()(predicted_value.squeeze(), rewards)
        policy_loss = nn.CrossEntropyLoss()(predicted_policy, actions)

        total_loss = reward_loss + value_loss + policy_loss
        total_loss.backward()
        optimizer.step()
    
    print(f"Epoch {epoch+1}, Loss: {total_loss.item():.4f}")

# -------------------------------
# Planning (Simulated Experience)
# -------------------------------
def simulate_experience(latent_state, steps=3):
    total_reward = 0
    for _ in range(steps):
        policy, _ = prediction_net(latent_state)
        action = torch.multinomial(policy, 1).item()
        action_onehot = one_hot_encode(action, action_dim)
        latent_state, reward = dynamics_net(latent_state, action_onehot)
        total_reward += reward.item()
    return total_reward

# Simulate from random states
sample_state = preprocess_state(env.reset())
latent_state = representation_net(sample_state)
print("Simulated Experience Reward:", simulate_experience(latent_state))

Epoch 1, Loss: 1.0565
Epoch 2, Loss: 0.8920
Epoch 3, Loss: 0.8868
Simulated Experience Reward: 1.107329398393631


In [4]:
# Fine-Tuning Networks
# -------------------------------
for epoch in range(2):  # Quick fine-tuning
    for states, actions, rewards in val_loader:
        optimizer.zero_grad()
        latent_states = representation_net(states)
        actions_onehot = torch.stack([one_hot_encode(a.item(), action_dim) for a in actions])
        _, predicted_reward = dynamics_net(latent_states, actions_onehot)
        predicted_policy, predicted_value = prediction_net(latent_states)

        loss = nn.MSELoss()(predicted_reward.squeeze(), rewards)
        loss.backward()
        optimizer.step()

    print(f"Fine-Tuning Epoch {epoch+1}, Loss: {loss.item():.4f}")

# -------------------------------
# Evaluation
# -------------------------------
def evaluate(model_loader):
    total_loss = 0
    with torch.no_grad():
        for states, actions, rewards in model_loader:
            latent_states = representation_net(states)
            actions_onehot = torch.stack([one_hot_encode(a.item(), action_dim) for a in actions])
            _, predicted_reward = dynamics_net(latent_states, actions_onehot)
            total_loss += nn.MSELoss()(predicted_reward.squeeze(), rewards).item()
    return total_loss / len(model_loader)

test_loss = evaluate(test_loader)
print(f"Test Loss: {test_loss:.4f}")

# -------------------------------
# Deploy Policy
# -------------------------------
def deploy_policy(state):
    state = preprocess_state(state)
    latent_state = representation_net(state)
    policy, _ = prediction_net(latent_state)
    action = torch.argmax(policy).item()
    return action

# Deploy on new environment states
new_state = env.reset()
action = deploy_policy(new_state)
print(f"Deployed Action: {action}")

Fine-Tuning Epoch 1, Loss: 0.0838
Fine-Tuning Epoch 2, Loss: 0.0850
Test Loss: 0.0776
Deployed Action: 0
