In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import math
import random

class DiffusionModel(nn.Module):
    """Small MLP denoiser conditioned on a scalar noise level.

    The noise level is appended to the state as one extra input feature, so
    ``fc1`` takes ``state_dim + 1`` inputs and ``fc3`` predicts a
    ``state_dim``-dimensional denoised state.
    """

    def __init__(self, state_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim + 1, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, state_dim)

    def forward(self, noisy_state, noise_level):
        """Predict the clean state from ``noisy_state`` at ``noise_level``.

        Args:
            noisy_state: tensor of shape ``(state_dim,)`` or
                ``(batch, state_dim)``.
            noise_level: Python number or tensor of shape ``()``, ``(1,)``,
                ``(batch,)`` or ``(batch, 1)``. A single level is broadcast
                to every row of the batch.

        Returns:
            Denoised states of shape ``(batch, state_dim)``. When
            ``batch == 1`` the leading dimension is squeezed away, giving
            ``(state_dim,)``. Callers in this file rely on that squeeze.
        """
        # Promote an unbatched state to a batch of one.
        if noisy_state.dim() == 1:
            noisy_state = noisy_state.unsqueeze(0)
        batch = noisy_state.shape[0]

        # Accept Python numbers as well as tensors. Match the state's
        # dtype/device so the concatenation below is well-typed.
        noise_level = torch.as_tensor(
            noise_level, dtype=noisy_state.dtype, device=noisy_state.device
        )
        # Normalize to a column of shape (N, 1).
        noise_level = noise_level.reshape(-1, 1)
        # Broadcast one shared noise level across the whole batch.
        if noise_level.shape[0] == 1 and batch != 1:
            noise_level = noise_level.expand(batch, 1)

        x = torch.cat([noisy_state, noise_level], dim=-1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        denoised = self.fc3(x)
        return denoised.squeeze(0)  # Remove batch dimension when batch == 1

class SyntheticTrajectoryDataset(Dataset):
    """Synthetic ``(noisy_state, noise_level, clean_state)`` training triples.

    ``clean_state`` is uniform in ``[0, 1)`` per dimension. ``noise_level``
    is drawn uniformly from ``noise_levels``.
    ``noisy_state = clean_state + noise_level * N(0, I)``.

    Samples are generated on every access, so the same index yields
    different data each time it is read.
    """

    def __init__(self, num_samples, state_dim, noise_levels=(0.5, 1.0)):
        """
        Args:
            num_samples: value reported by ``__len__``.
            state_dim: dimensionality of each state vector.
            noise_levels: candidate noise scales. The default matches the
                original hard-coded ``[0.5, 1.0]``.

        Raises:
            ValueError: if ``noise_levels`` is empty.
        """
        noise_levels = tuple(float(level) for level in noise_levels)
        if not noise_levels:
            raise ValueError("noise_levels must contain at least one value")
        self.num_samples = num_samples
        self.state_dim = state_dim
        self.noise_levels = noise_levels

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # idx is unused: every item is freshly sampled.
        clean_state = torch.rand(self.state_dim)
        # Use torch's RNG (not ``random``) so that one torch.manual_seed
        # call is enough to make samples reproducible.
        choice = int(torch.randint(len(self.noise_levels), (1,)).item())
        noise_level = self.noise_levels[choice]
        noise = torch.randn(self.state_dim) * noise_level
        noisy_state = clean_state + noise
        return noisy_state, torch.tensor(noise_level, dtype=torch.float32), clean_state

class Node:
    """A node in a planning search tree.

    Holds a partial trajectory plus the statistics a tree search needs:
    visit count and accumulated reward (``value``).
    """

    def __init__(self, trajectory, noise_level, parent=None, meta_action=0):
        # Partial plan held by this node (stored as given; the planner in
        # this file stores a list of state tensors).
        self.trajectory = trajectory
        # Noise level associated with this node (stored as given).
        self.noise_level = noise_level
        # Link to the parent node; None marks the root.
        self.parent = parent
        # Meta-action that produced this node.
        self.meta_action = meta_action
        self.children = []
        # Search statistics, updated during backpropagation.
        self.visit_count = 0
        self.value = 0.0
        self.is_terminal = False

    def add_child(self, child_node):
        """Attach ``child_node`` below this node.

        Also points ``child_node.parent`` at this node, so walking ``parent``
        links upward (e.g. during reward backpropagation) reaches the root.
        """
        child_node.parent = self
        self.children.append(child_node)

class MCTDPlanner:
    """Toy Monte Carlo Tree Diffusion planner.

    Each tree node holds a partial trajectory (a list of state tensors).
    One search iteration performs:

    1. Selection: descend from the root through fully expanded nodes
       using UCT.
    2. Expansion: attach one child whose trajectory is the parent's
       trajectory extended by ``denoise_subplan``.
    3. Simulation: complete the partial trajectory with
       ``fast_denoising`` and score it with ``evaluate_plan``.
    4. Backpropagation: add the reward to every node on the path back to
       the root.

    ``denoise_subplan`` is currently a placeholder that always returns a
    single all-zeros state.
    """

    # Each node is expanded once per meta-action. A node with this many
    # children is fully expanded.
    NUM_META_ACTIONS = 2

    def __init__(self, diffusion_model, goal, planning_horizon, state_dim, device):
        self.diffusion_model = diffusion_model.to(device)
        self.goal = goal.to(device)
        self.planning_horizon = planning_horizon
        self.state_dim = state_dim
        self.device = device
        # The root holds the empty plan; every real state comes from expansion.
        self.root = Node(trajectory=[], noise_level=torch.tensor(1.0, device=device), parent=None)

    def uct_score(self, parent, child, c_param=1.41):
        """UCT score of ``child``: mean reward plus an exploration bonus.

        Unvisited children score ``inf`` so they are tried first.
        """
        if child.visit_count == 0:
            return float('inf')
        exploitation = child.value / child.visit_count
        exploration = c_param * math.sqrt(math.log(parent.visit_count + 1) / child.visit_count)
        return exploitation + exploration

    def best_uct_child(self, node):
        """Return the child of ``node`` with the highest UCT score (first wins ties)."""
        scores = [self.uct_score(node, child) for child in node.children]
        return node.children[scores.index(max(scores))]

    def is_fully_expanded(self, node):
        return len(node.children) >= self.NUM_META_ACTIONS

    def is_leaf(self, node):
        return len(node.children) == 0

    def is_expandable(self, node):
        """A node can grow a new child unless it is terminal or fully expanded."""
        return not node.is_terminal and not self.is_fully_expanded(node)

    def select_meta_action(self, node):
        """Heuristic meta-action for ``node``.

        Returns 0 if the mean of the last state exceeds 0.5, else 1.
        Returns 0 for an empty trajectory.
        """
        if node.trajectory:
            last_state = node.trajectory[-1]
            return 0 if last_state.mean() > 0.5 else 1
        return 0

    def _next_meta_action(self, node):
        """Return the heuristic meta-action if untried at ``node``, else the lowest untried one."""
        tried = {child.meta_action for child in node.children}
        preferred = self.select_meta_action(node)
        if preferred not in tried:
            return preferred
        return next(a for a in range(self.NUM_META_ACTIONS) if a not in tried)

    def denoise_subplan(self, node, meta_action):
        """Placeholder subplan generator: one all-zeros state.

        ``node`` and ``meta_action`` are currently ignored.
        """
        return [torch.zeros(self.state_dim, device=self.device)]

    def create_node(self, subplan, parent=None, meta_action=0):
        """Build a node whose trajectory is ``parent``'s trajectory followed by ``subplan``.

        Args:
            subplan: non-empty list of state tensors.
            parent: node being extended. ``None`` starts from an empty
                trajectory, which matches the original one-argument call.
            meta_action: meta-action that produced ``subplan``.

        The node is marked terminal once its trajectory reaches
        ``planning_horizon``.
        """
        prefix = parent.trajectory if parent is not None else []
        trajectory = prefix + list(subplan)
        node = Node(trajectory=trajectory,
                    # NOTE(review): this stores the mean of the last state,
                    # not a true noise level; kept from the original design.
                    noise_level=subplan[-1].mean(),
                    parent=parent,
                    meta_action=meta_action)
        node.is_terminal = len(trajectory) >= self.planning_horizon
        return node

    def fast_denoising(self, partial_trajectory):
        """Complete ``partial_trajectory`` up to ``planning_horizon`` states.

        Each new state starts from the previous one. The diffusion model
        runs ``num_denoise_steps`` times on an exponentially decaying noise
        schedule (0.5 down to 1e-3). Each intermediate prediction is blended
        0.8/0.2 with the previous input. The whole schedule is scaled by
        ``1 - step / remaining_steps``, so noise shrinks as the rollout
        advances.

        Returns:
            The ``planning_horizon - len(partial_trajectory)`` newly generated
            states, without the partial trajectory itself. The list is empty
            if the partial trajectory is empty or already at the horizon.
        """
        if not partial_trajectory:
            return []

        num_denoise_steps = 5
        device = self.device
        remaining_steps = self.planning_horizon - len(partial_trajectory)
        # Detached copy of the last known state, with a batch dim: [1, state_dim].
        current_state = partial_trajectory[-1].detach().clone().to(device).unsqueeze(0)

        initial_noise_level = 0.5
        noise_schedule = torch.exp(
            torch.linspace(
                math.log(initial_noise_level),
                math.log(1e-3),
                num_denoise_steps,
                device=device,
            )
        )

        denoised_states = []
        with torch.no_grad():
            for step in range(remaining_steps):
                noisy_state = current_state
                for t in range(num_denoise_steps):
                    adjusted_noise = (noise_schedule[t] * (1 - step / remaining_steps)).reshape(1, 1)
                    # The model squeezes a batch of one to [state_dim];
                    # restore the batch dim to get [1, state_dim].
                    denoised = self.diffusion_model(noisy_state, adjusted_noise).unsqueeze(0)
                    if t < num_denoise_steps - 1:
                        # Momentum-style blend of prediction and previous input.
                        noisy_state = 0.8 * denoised + 0.2 * noisy_state
                    else:
                        current_state = denoised
                denoised_states.append(current_state.squeeze(0).clone())

        return denoised_states

    def evaluate_plan(self, full_plan):
        """Reward of a plan: negative L2 distance from its final state to ``goal``."""
        final_state = full_plan[-1].to(self.device)
        return -torch.norm(final_state - self.goal, p=2).item()

    def search(self, iterations):
        """Run ``iterations`` rounds of MCTS and return the best first subplan.

        Returns:
            The trajectory of the most-visited child of the root, or ``[]``
            if the root was never expanded (e.g. ``iterations == 0``).
        """
        for _ in range(iterations):
            # Selection.
            node = self.root
            while self.is_fully_expanded(node) and not self.is_leaf(node):
                node = self.best_uct_child(node)

            # Expansion: the child keeps a parent link so rewards reach the root.
            if self.is_expandable(node):
                meta_action = self._next_meta_action(node)
                subplan = self.denoise_subplan(node, meta_action)
                child = self.create_node(subplan, parent=node, meta_action=meta_action)
                node.add_child(child)
                node = child

            # Simulation.
            full_plan = node.trajectory + self.fast_denoising(node.trajectory)
            reward = self.evaluate_plan(full_plan)

            # Backpropagation.
            current_node = node
            while current_node is not None:
                current_node.visit_count += 1
                current_node.value += reward
                current_node = current_node.parent

        if not self.root.children:
            return []
        # Final choice by visit count ("robust child"); the UCT exploration
        # bonus is only meant for guiding the search itself.
        best_child = max(self.root.children, key=lambda child: child.visit_count)
        return best_child.trajectory

def train_diffusion_model(model, device, dataloader, epochs=5, lr=1e-3):
    """Train ``model`` to map noisy states back to clean states with MSE loss.

    Args:
        model: module called as ``model(noisy_state, noise_level)``; assumed
            to already live on ``device``.
        device: device each batch is moved to.
        dataloader: re-iterable source of
            ``(noisy_state, noise_level, clean_state)`` batches. ``len()`` is
            not required.
        epochs: number of passes over ``dataloader``.
        lr: Adam learning rate.

    Returns:
        Per-epoch mean training loss, weighted by batch size so a short final
        batch does not skew it. An epoch with no batches reports NaN.
    """
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    epoch_losses = []
    for epoch in range(1, epochs + 1):
        weighted_loss = 0.0
        num_samples = 0
        for noisy_state, noise_level, clean_state in dataloader:
            noisy_state = noisy_state.to(device)
            noise_level = noise_level.to(device)
            clean_state = clean_state.to(device)

            optimizer.zero_grad()
            output = model(noisy_state, noise_level)
            loss = loss_fn(output, clean_state)
            loss.backward()
            optimizer.step()

            # Treat an unbatched target as a single sample.
            batch_size = clean_state.shape[0] if clean_state.dim() > 1 else 1
            weighted_loss += loss.item() * batch_size
            num_samples += batch_size

        mean_loss = weighted_loss / num_samples if num_samples else float("nan")
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch}/{epochs} - Loss: {mean_loss:.4f}")
    return epoch_losses

def main():
    """Train the toy denoiser, then plan toward an all-ones goal and report it."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Problem size and training / planning hyperparameters.
    state_dim, hidden_dim = 4, 32
    num_train_samples, batch_size, train_epochs = 10000, 128, 10
    planning_horizon = 10

    # --- Train the denoiser on synthetic data ---
    dataloader = DataLoader(
        SyntheticTrajectoryDataset(num_train_samples, state_dim),
        batch_size,
        shuffle=True,
    )
    diffusion_model = DiffusionModel(state_dim, hidden_dim).to(device)
    print("Training diffusion model...")
    train_diffusion_model(diffusion_model, device, dataloader, train_epochs)
    print("Diffusion model training completed.\n")

    # --- Plan toward the goal ---
    goal = torch.ones(state_dim, device=device)
    planner = MCTDPlanner(diffusion_model, goal, planning_horizon, state_dim, device)
    print("Running planning search...")
    best_traj = planner.search(100)

    # --- Report ---
    if not best_traj:
        print("No valid trajectory found")
        return
    print(f"\nPlanned trajectory length: {len(best_traj)}")
    for step_idx, state in enumerate(best_traj):
        print(f"Step {step_idx}: {state.cpu().detach().numpy().round(3)}")
    # Same score as the search uses: negative L2 distance of the last state to the goal.
    final_reward = planner.evaluate_plan(best_traj)
    print(f"Final reward (negative distance): {final_reward:.3f}")


if __name__ == "__main__":
    main()

Training diffusion model...
Epoch 1/10 - Loss: 0.1047
Epoch 2/10 - Loss: 0.0763
Epoch 3/10 - Loss: 0.0743
Epoch 4/10 - Loss: 0.0735
Epoch 5/10 - Loss: 0.0718
Epoch 6/10 - Loss: 0.0729
Epoch 7/10 - Loss: 0.0722
Epoch 8/10 - Loss: 0.0722
Epoch 9/10 - Loss: 0.0718
Epoch 10/10 - Loss: 0.0706
Diffusion model training completed.

Running planning search...

Planned trajectory length: 1
Step 0: [0. 0. 0. 0.]
Final reward (negative distance): -2.000
