In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import math
import random

class DiffusionModel(nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim + 1, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, state_dim)
    
    def forward(self, noisy_state, noise_level):
        # Dimension unification handling
        if noisy_state.dim() == 1:
            noisy_state = noisy_state.unsqueeze(0)
        if noise_level.dim() == 0:
            noise_level = noise_level.unsqueeze(0).unsqueeze(1)
        elif noise_level.dim() == 1:
            noise_level = noise_level.unsqueeze(1)
            
        x = torch.cat([noisy_state, noise_level], dim=-1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x).squeeze(0)

class SyntheticTrajectoryDataset(Dataset):
    def __init__(self, num_samples, state_dim):
        self.num_samples = num_samples
        self.state_dim = state_dim

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        clean_state = torch.rand(self.state_dim)
        noise_level = random.choice([0.5, 1.0])
        return (clean_state + torch.randn(self.state_dim) * noise_level,
                torch.tensor(noise_level, dtype=torch.float32),
                clean_state)

class Node:
    def __init__(self, trajectory, noise_level, parent=None, meta_action=0):
        self.trajectory = trajectory
        self.noise_level = noise_level
        self.parent = parent
        self.children = []
        self.visit_count = 0
        self.value = 0.0
        self.meta_action_weights = torch.tensor([1.0, 1.0])  # Initial weights for 2 meta-actions
        self.action_history = []

class MCTDPlanner:
    def __init__(self, diffusion_model, goal, planning_horizon, state_dim, device):
        self.diffusion_model = diffusion_model.to(device)
        self.goal = goal.to(device)
        self.planning_horizon = planning_horizon
        self.state_dim = state_dim
        self.device = device
        self.root = Node([], torch.tensor(1.0, device=device))

    def uct_score(self, parent, child, c_param=1.41):
        if child.visit_count == 0:
            return float('inf')
        return (child.value / child.visit_count) + \
               c_param * math.sqrt(math.log(parent.visit_count + 1) / child.visit_count)

    def best_uct_child(self, node):
        return node.children[max(enumerate(
            [self.uct_score(node, child) for child in node.children]),
            key=lambda x: x[1])[0]]

    def is_fully_expanded(self, node):
        return len(node.children) >= 2  # Two meta-actions available

    def is_leaf(self, node):
        return len(node.trajectory) >= self.planning_horizon

    def select_meta_action(self, node):
        # Softmax selection based on learned weights
        probs = F.softmax(node.meta_action_weights, dim=0)
        return torch.multinomial(probs, 1).item()

    def denoise_subplan(self, node, meta_action):
        # Generate subplan using meta-action guidance
        with torch.no_grad():
            noise_level = node.noise_level * 0.8  # Noise reduction
            return [self.diffusion_model(
                node.trajectory[-1].to(self.device) if node.trajectory 
                else torch.randn(self.state_dim, device=self.device),
                noise_level.to(self.device)
            )]

    def create_node(self, subplan):
        return Node(
            trajectory=subplan,
            noise_level=subplan[-1].mean(),
            parent=None
        )

    def update_meta_schedule(self, node, reward):
        # Update meta-action weights using reward signal
        learning_rate = 0.1
        if node.action_history:
            last_action = node.action_history[-1]
            node.meta_action_weights[last_action] += learning_rate * reward
            node.meta_action_weights = torch.clamp(node.meta_action_weights, 0.1, 10.0)

    def fast_denoising(self, partial_trajectory):
        if len(partial_trajectory) >= self.planning_horizon:
            return []
            
        num_steps = self.planning_horizon - len(partial_trajectory)
        current_state = partial_trajectory[-1].clone() if partial_trajectory \
            else torch.randn(self.state_dim, device=self.device)
            
        noise_schedule = torch.linspace(0.5, 0.01, 5, device=self.device)
        denoised = []
        
        for _ in range(num_steps):
            for t in noise_schedule:
                current_state = self.diffusion_model(
                    current_state.unsqueeze(0),
                    t.unsqueeze(0)
                ).squeeze(0)
            denoised.append(current_state.clone())
            
        return denoised

    def evaluate_plan(self, full_plan):
        final_state = full_plan[-1].to(self.device)
        return -torch.norm(final_state - self.goal).item()

    def search(self, iterations):
        for _ in range(iterations):
            node = self.root
            # Selection phase
            while self.is_fully_expanded(node) and not self.is_leaf(node):
                node = self.best_uct_child(node)

            # Expansion phase
            if not self.is_leaf(node):
                meta_action = self.select_meta_action(node)
                node.action_history.append(meta_action)
                subplan = self.denoise_subplan(node, meta_action)
                child = self.create_node(subplan)
                node.children.append(child)
                node = child

            # Simulation phase
            full_plan = node.trajectory + self.fast_denoising(node.trajectory)
            reward = self.evaluate_plan(full_plan)

            # Backpropagation
            current = node
            while current is not None:
                current.visit_count += 1
                current.value += reward
                self.update_meta_schedule(current, reward)
                current = current.parent

        # Return best trajectory
        best_child = max(self.root.children, key=lambda x: x.value/x.visit_count)
        return best_child.trajectory

def train_diffusion_model(model, device, dataloader, epochs=5, lr=1e-3):
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    for epoch in range(1, epochs + 1):
        total_loss = 0.0
        for noisy_state, noise_level, clean_state in dataloader:
            noisy_state = noisy_state.to(device)
            noise_level = noise_level.to(device)
            clean_state = clean_state.to(device)
            
            optimizer.zero_grad()
            output = model(noisy_state, noise_level)
            loss = loss_fn(output, clean_state)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            
        print(f"Epoch {epoch}/{epochs} - Loss: {total_loss/len(dataloader):.4f}")

def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    state_dim = 4
    hidden_dim = 32
    num_train_samples = 10000
    batch_size = 128
    train_epochs = 10
    planning_horizon = 10
    
    # Dataset and model setup
    dataset = SyntheticTrajectoryDataset(num_train_samples, state_dim)
    dataloader = DataLoader(dataset, batch_size, shuffle=True)
    
    diffusion_model = DiffusionModel(state_dim, hidden_dim).to(device)
    print("Training diffusion model...")
    train_diffusion_model(diffusion_model, device, dataloader, train_epochs)
    print("Diffusion model training completed.\n")

    # Planning setup
    goal = torch.ones(state_dim, device=device)
    planner = MCTDPlanner(diffusion_model, goal, planning_horizon, state_dim, device)
    
    print("Running planning search...")
    best_traj = planner.search(100)
    
    # Results display
    if not best_traj:
        print("No valid trajectory found")
    else:
        print(f"\nPlanned trajectory length: {len(best_traj)}")
        for t, state in enumerate(best_traj):
            print(f"Step {t}: {state.cpu().detach().numpy().round(3)}")
        final_reward = -torch.norm(best_traj[-1] - goal, p=2).item()
        print(f"Final reward (negative distance): {final_reward:.3f}")

if __name__ == "__main__":
    main()

Training diffusion model...
Epoch 1/10 - Loss: 0.1047
Epoch 2/10 - Loss: 0.0763
Epoch 3/10 - Loss: 0.0743
Epoch 4/10 - Loss: 0.0735
Epoch 5/10 - Loss: 0.0718
Epoch 6/10 - Loss: 0.0729
Epoch 7/10 - Loss: 0.0722
Epoch 8/10 - Loss: 0.0722
Epoch 9/10 - Loss: 0.0718
Epoch 10/10 - Loss: 0.0706
Diffusion model training completed.

Running planning search...

Planned trajectory length: 1
Step 0: [0. 0. 0. 0.]
Final reward (negative distance): -2.000
