In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import math
import random

class MetaPolicy(nn.Module):
    def __init__(self, state_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_dim + 1, 64),
            nn.ReLU(),
            nn.Linear(64, 2)
        )
    
    def forward(self, state, noise_level):
        x = torch.cat([state, noise_level.unsqueeze(-1)], dim=-1)
        return self.fc(x)

class DiffusionModel(nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim + 1, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, state_dim)
    
    def forward(self, noisy_state, noise_level):
        if noisy_state.dim() == 1:
            noisy_state = noisy_state.unsqueeze(0)
        noise_level = noise_level.view(-1, 1)
            
        x = torch.cat([noisy_state, noise_level], dim=-1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x).squeeze(0)

class SyntheticTrajectoryDataset(Dataset):
    def __init__(self, num_samples, state_dim):
        self.num_samples = num_samples
        self.state_dim = state_dim

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        clean_state = torch.rand(self.state_dim)
        noise_level = random.choice([0.5, 1.0])
        return (clean_state + torch.randn(self.state_dim) * noise_level,
                torch.tensor(noise_level, dtype=torch.float32),
                clean_state)

class Node:
    def __init__(self, trajectory, noise_level, parent=None, device='cuda'):
        self.device = device
        self.trajectory = [t.detach().requires_grad_(False) for t in trajectory] if parent else trajectory
        self.noise_level = noise_level.detach().to(self.device)
        self.parent = parent
        self.children = []
        self.visit_count = 0
        self.value = 0.0
        self.meta_policy = MetaPolicy(len(self.trajectory[-1])).to(device)
        self.optimizer = optim.Adam(self.meta_policy.parameters(), lr=0.001)

    def update_policy(self, reward):
        if not self.trajectory:
            return
            
        self.optimizer.zero_grad()
        policy_logits = self.meta_policy(self.trajectory[-1], self.noise_level)
        loss = -reward * F.log_softmax(policy_logits, dim=-1)[0]
        loss.backward(retain_graph=False)
        self.optimizer.step()

class MCTDPlanner:
    def __init__(self, diffusion_model, goal, planning_horizon, state_dim, device):
        self.diffusion_model = diffusion_model.to(device)
        self.goal = goal.to(device)
        self.planning_horizon = planning_horizon
        self.state_dim = state_dim
        self.device = device
        self.root = Node([torch.randn(state_dim, device=device)], 
                        torch.tensor(1.0, device=device),
                        device=device)

    def cosine_schedule(self, t, T, s=0.008):
        return torch.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2

    def uct_score(self, parent, child, c_param=1.41):
        exploration = c_param * math.sqrt(math.log(parent.visit_count + 1) / (child.visit_count + 1e-8))
        noise_penalty = child.noise_level * math.log(self.planning_horizon + 1)
        return (child.value / (child.visit_count + 1e-8)) + exploration - noise_penalty

    def best_uct_child(self, node):
        scores = [self.uct_score(node, child) for child in node.children]
        return node.children[torch.argmax(torch.tensor(scores)).item()]

    def is_fully_expanded(self, node):
        return len(node.children) >= 2

    def is_leaf(self, node):
        return len(node.trajectory) >= self.planning_horizon

    def select_meta_action(self, node):
        with torch.no_grad():
            logits = node.meta_policy(node.trajectory[-1], node.noise_level)
            return torch.distributions.Categorical(logits=logits).sample().item()

    def denoise_subplan(self, node, meta_action):
        device = node.device
        subplan = []
        num_steps = meta_action + 1
        noise_decay = 0.5 + 0.3 * meta_action
        
        current_state = node.trajectory[-1].clone()
        current_noise = node.noise_level * noise_decay
        
        for step in range(num_steps):
            t = torch.tensor(step / num_steps, device=device)
            noise_level = self.cosine_schedule(t, num_steps)
            current_state = self.diffusion_model(
                current_state.unsqueeze(0),
                noise_level.unsqueeze(0)
            ).squeeze(0)
            subplan.append(current_state.detach())
        
        return subplan

    def create_node(self, subplan, parent):
        device = parent.device if parent else self.device
        subplan = [s.to(device) for s in subplan]
        return Node(
            trajectory=subplan,
            noise_level=subplan[-1].mean().detach(),
            parent=parent,
            device=device
        )

    def fast_denoising(self, partial_trajectory):
        if len(partial_trajectory) >= self.planning_horizon:
            return []
            
        remaining_steps = self.planning_horizon - len(partial_trajectory)
        current_state = partial_trajectory[-1].clone() if partial_trajectory else torch.randn(self.state_dim, device=self.device)
        denoised = []
        
        T = 5 * remaining_steps
        for step in range(remaining_steps):
            t = torch.tensor(step / remaining_steps, device=self.device)
            noise_level = self.cosine_schedule(t, T)
            current_state = self.diffusion_model(
                current_state.unsqueeze(0),
                noise_level.unsqueeze(0)
            ).squeeze(0)
            denoised.append(current_state.detach())
            
        return denoised

    def evaluate_plan(self, full_plan):
        final_state = full_plan[-1].to(self.device)
        return -torch.norm(final_state - self.goal).item()

    def search(self, iterations):
        for _ in range(iterations):
            node = self.root
            
            # Selection phase
            while self.is_fully_expanded(node) and not self.is_leaf(node):
                node = self.best_uct_child(node)
            
            # Expansion phase
            if not self.is_leaf(node):
                meta_action = self.select_meta_action(node)
                subplan = self.denoise_subplan(node, meta_action)
                child = self.create_node(subplan, parent=node)
                node.children.append(child)
                node = child
            
            # Simulation phase
            full_plan = node.trajectory + self.fast_denoising(node.trajectory)
            reward = self.evaluate_plan(full_plan)
            
            # Backpropagation
            current_node = node
            while current_node is not None:
                current_node.visit_count += 1
                current_node.value += reward
                current_node.update_policy(reward)
                current_node = current_node.parent
        
        best_child = max(self.root.children, key=lambda x: x.value/(x.visit_count + 1e-8))
        return best_child.trajectory

def train_diffusion_model(model, device, dataloader, epochs=5, lr=1e-3):
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    for epoch in range(1, epochs + 1):
        total_loss = 0.0
        for noisy_state, noise_level, clean_state in dataloader:
            noisy_state = noisy_state.to(device)
            noise_level = noise_level.to(device)
            clean_state = clean_state.to(device)
            
            optimizer.zero_grad()
            output = model(noisy_state, noise_level) 
            loss = loss_fn(output, clean_state)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            
        print(f"Epoch {epoch}/{epochs} - Loss: {total_loss/len(dataloader):.4f}")

def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    state_dim = 4
    hidden_dim = 32
    num_train_samples = 10000
    batch_size = 128
    train_epochs = 10
    planning_horizon = 10
    
    dataset = SyntheticTrajectoryDataset(num_train_samples, state_dim)
    dataloader = DataLoader(dataset, batch_size, shuffle=True)
    
    diffusion_model = DiffusionModel(state_dim, hidden_dim).to(device)
    print("Training diffusion model...")
    train_diffusion_model(diffusion_model, device, dataloader, train_epochs)
    print("Diffusion model training completed.\n")
    
    goal = torch.ones(state_dim, device=device)
    planner = MCTDPlanner(diffusion_model, goal, planning_horizon, state_dim, device)
    
    print("Running planning search...")
    best_traj = planner.search(100)
    
    if not best_traj:
        print("No valid trajectory found")
    else:
        print(f"\nPlanned trajectory length: {len(best_traj)}")
        for t, state in enumerate(best_traj):
            print(f"Step {t}: {state.cpu().detach().numpy().round(3)}")
        final_reward = -torch.norm(best_traj[-1] - goal, p=2).item()
        print(f"Final reward (negative distance): {final_reward:.3f}")

if __name__ == "__main__":
    main()

Training diffusion model...
Epoch 1/10 - Loss: 0.1501
Epoch 2/10 - Loss: 0.0769
Epoch 3/10 - Loss: 0.0734
Epoch 4/10 - Loss: 0.0729
Epoch 5/10 - Loss: 0.0733
Epoch 6/10 - Loss: 0.0736
Epoch 7/10 - Loss: 0.0729
Epoch 8/10 - Loss: 0.0731
Epoch 9/10 - Loss: 0.0718
Epoch 10/10 - Loss: 0.0731
Diffusion model training completed.

Running planning search...

Planned trajectory length: 2
Step 0: [0.427 0.466 0.284 0.551]
Step 1: [0.479 0.487 0.47  0.497]
Final reward (negative distance): -1.034
