In [13]:
!pip install gymnasium
!pip install torchrl
!pip install TensorDict



In [14]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchrl.modules import DecisionTransformer
from torchrl.objectives import DTLoss
import gymnasium as gym
from gymnasium import spaces
from tensordict import TensorDict
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F

Env

In [21]:
class ATSPEnv(gym.Env):
    def __init__(self, num_nodes, seed=None):
        super(ATSPEnv, self).__init__()
        self.num_nodes = num_nodes
        self.seed = seed
        self.action_space = spaces.Discrete(self.num_nodes)
        self.observation_space = spaces.Dict({
            "adjacency_matrix": spaces.Box(low=0, high=np.inf, shape=(self.num_nodes, self.num_nodes), dtype=np.float32),
            "current_node": spaces.Discrete(self.num_nodes),
            "visited": spaces.MultiBinary(self.num_nodes)
        })
        self.graph = self._generate_random_graph()
        self.reset()

    def _generate_random_graph(self):
        # if self.seed is not None:
        #     np.random.seed(self.seed)

        locations = np.random.rand(self.num_nodes, 2)
        graph = np.zeros((self.num_nodes, self.num_nodes))
        for i in range(self.num_nodes):
            for j in range(i + 1, self.num_nodes):
                distance = np.linalg.norm(locations[i] - locations[j])
                graph[i, j] = distance
                graph[j, i] = distance

        # np.fill_diagonal(graph, np.inf)
        np.fill_diagonal(graph, 1000)
        return graph

    def reset(self):
        self.graph = self._generate_random_graph()
        self.current_node = np.random.choice(self.num_nodes)
        self.visited = np.zeros(self.num_nodes, dtype=bool)
        # self.visited[self.current_node] = True
        self.path = self.optimal_solution()
        print(self.path)
        return self._get_obs(), {}

    def _get_obs(self):
        obs = {
            "adjacency_matrix": self.graph,
            "current_node": self.current_node,
            "visited": self.visited
        }
        return obs

    def step(self, action):
        done = False
        if self.visited[action]:
            reward = -20.0
            # done = True
        else:
            reward = -self._compute_distance(self.current_node, action)
            self.current_node = action
            self.visited[action] = True
            done = self.visited.all()

        reward = np.clip(reward, -20.0, 10.0)
        return self._get_obs(), reward, done, {}

    def _compute_distance(self, node1, node2):
        return self.graph[node1, node2]

    def render(self, mode='human'):
        print(f'Visited: {np.where(self.visited)[0]}')

    def optimal_solution(self):
        current_node = 0
        visited = np.zeros(self.num_nodes, dtype=bool)
        path = [current_node]
        visited[current_node] = True

        # Loop until all nodes are visited
        while not visited.all():
            # Find the nearest unvisited neighbor
            next_node = None
            min_distance = np.inf
            for i in range(self.num_nodes):
                if not visited[i] and self.graph[current_node, i] < min_distance:
                    min_distance = self.graph[current_node, i]
                    next_node = i

            # Move to the next node
            current_node = next_node
            visited[current_node] = True
            path.append(current_node)

        # Close the tour by returning to the starting node
        # path.append(path[0])

        # one hot encoding--idk if needed
        path_tensor = torch.zeros((self.num_nodes, self.num_nodes))
        for i in range(self.num_nodes):
          curr = path[i]
          path_tensor[i][curr] = 1
        return path

  and should_run_async(code)


Train

In [22]:
# make it so that it AR predcits the full path one step at a time
# this path is then compared to optimal path
# error is MSE between the two
def compute_tour_length(adj_matrix, tour):
    tour_length = 0
    for i in range(len(tour) - 1):
        tour_length += adj_matrix[tour[i], tour[i + 1]]
    tour_length += adj_matrix[tour[-1], tour[0]]  # Return to the starting point
    return tour_length

def train_online(model, env, num_episodes=100):
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    # criterion = DTLoss(actor_network=model, loss_function="cross_entropy")
    criterion = CrossEntropyLoss()

    for episode in range(num_episodes):
        state, _ = env.reset()
        done = False
        total_loss = 0
        total_reward = 0
        counter = 0
        episode_rewards = []

        while not done:
            curr_node = state["current_node"]
            visited = state["visited"]
            adj_matrix = state["adjacency_matrix"]
            num_nodes = adj_matrix.shape[0]

            # obs for each node
            node_observations = []
            for i in range(num_nodes):
                distance_to_current_node = adj_matrix[curr_node, i]
                visited_status = float(visited[i])
                node_observations.append([distance_to_current_node, visited_status])
            node_observations = torch.tensor(node_observations).float().unsqueeze(0)  # [1, num_nodes, 2]

            # action for each node
            actions = []
            for i in range(num_nodes):
                next_node_index = i
                travel_cost = adj_matrix[curr_node, next_node_index]
                actions.append([next_node_index, travel_cost])
            actions = torch.tensor(actions).float().unsqueeze(0) # [1, num_nodes, 2]

            # return_to_go = torch.tensor(np.random.rand(1, num_nodes, 1), dtype=torch.float)

            # action probabilities
            action_probs = model(node_observations, actions, return_to_go)  # [1, num_nodes, 128]
            action_probs_no_batch = action_probs.squeeze(0) # [num_nodes, 128]
            score_layer = nn.Linear(128, 1)

            # convert features to scores
            scores = score_layer(action_probs_no_batch)
            probabilities = F.softmax(scores.squeeze(1), dim=-1)

            mask = torch.tensor(visited, dtype=torch.bool)
            print(mask)
            masked_probabilities = probabilities.clone()
            masked_probabilities[mask] = 0
            distribution = torch.distributions.Categorical(probs=masked_probabilities)

            # Sample an action
            selected_action = distribution.sample().item()
            print(masked_probabilities)
            print(selected_action)

            optimal_path = env.path
            optimal_action = optimal_path[counter]
            optimal_action = torch.tensor([optimal_action])

            # Need to get scores in shape batches, classes
            scores = scores.transpose(0, 1)

            loss = criterion(scores, optimal_action.long())

            total_loss += loss.item()

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # next state
            next_state, reward, done, _ = env.step(selected_action)
            total_reward += reward
            episode_rewards.append(reward)
            state = next_state
            counter += 1
        print("finished!")
    print(f"Episode {episode + 1}/{num_episodes}, Total Reward: {total_reward}, Episode Loss: {total_loss}")

if __name__ == "__main__":
    env = ATSPEnv(num_nodes=5, seed=42)

    state_dim = 2
    action_dim = 2

    config = DecisionTransformer.DTConfig(n_embd=128)
    model = DecisionTransformer(state_dim=state_dim, action_dim=action_dim, config=config)

    train_online(model, env)
    torch.save(model.state_dict(), "decision_transformer_model.pth")

[0, 1, 4, 2, 3]
[0, 3, 1, 2, 4]
tensor([False, False, False, False, False])
tensor([0.2129, 0.2089, 0.1776, 0.1897, 0.2109], grad_fn=<IndexPutBackward0>)
4
tensor([False, False, False, False,  True])
tensor([0.1597, 0.1352, 0.1432, 0.1862, 0.0000], grad_fn=<IndexPutBackward0>)
1
tensor([False,  True, False, False,  True])
tensor([0.1467, 0.0000, 0.2510, 0.1968, 0.0000], grad_fn=<IndexPutBackward0>)
2
tensor([False,  True,  True, False,  True])
tensor([0.2232, 0.0000, 0.0000, 0.2053, 0.0000], grad_fn=<IndexPutBackward0>)
3
tensor([False,  True,  True,  True,  True])
tensor([0.1908, 0.0000, 0.0000, 0.0000, 0.0000], grad_fn=<IndexPutBackward0>)
0
finished!
[0, 3, 1, 2, 4]
tensor([False, False, False, False, False])
tensor([0.1961, 0.1900, 0.2467, 0.1718, 0.1954], grad_fn=<IndexPutBackward0>)
1
tensor([False,  True, False, False, False])
tensor([0.1336, 0.0000, 0.1776, 0.1426, 0.2322], grad_fn=<IndexPutBackward0>)
4
tensor([False,  True, False, False,  True])
tensor([0.1863, 0.0000, 0.1828

In [1]:
# other loss function:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np

def compute_tour_length(adj_matrix, tour):
    tour_length = 0
    for i in range(len(tour) - 1):
        tour_length += adj_matrix[tour[i], tour[i + 1]]
    tour_length += adj_matrix[tour[-1], tour[0]]  # Return to the starting point
    return tour_length

def train_online(model, env, num_episodes=100):
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    for episode in range(num_episodes):
        state, _ = env.reset()
        done = False
        total_loss = 0
        total_reward = 0
        counter = 0

        selected_actions = []
        log_probs = []

        while not done:
            curr_node = state["current_node"]
            visited = state["visited"]
            adj_matrix = state["adjacency_matrix"]
            num_nodes = adj_matrix.shape[0]

            # obs for each node
            node_observations = []
            for i in range(num_nodes):
                distance_to_current_node = adj_matrix[curr_node, i]
                visited_status = float(visited[i])
                node_observations.append([distance_to_current_node, visited_status])
            node_observations = torch.tensor(node_observations).float().unsqueeze(0)  # [1, num_nodes, 2]

            # action for each node
            actions = []
            for i in range(num_nodes):
                next_node_index = i
                travel_cost = adj_matrix[curr_node, next_node_index]
                actions.append([next_node_index, travel_cost])
            actions = torch.tensor(actions).float().unsqueeze(0) # [1, num_nodes, 2]

            return_to_go = torch.tensor(np.random.rand(1, num_nodes, 1), dtype=torch.float)

            # action probabilities
            action_probs = model(node_observations, actions, return_to_go)  # [1, num_nodes, 128]
            action_probs_no_batch = action_probs.squeeze(0) # [num_nodes, 128]
            score_layer = nn.Linear(128, 1)

            # convert features to scores
            scores = score_layer(action_probs_no_batch)
            probabilities = F.softmax(scores.squeeze(1), dim=-1)

            distribution = torch.distributions.Categorical(probs=probabilities)

            # sample an action
            selected_action = distribution.sample().item()
            log_prob = distribution.log_prob(torch.tensor(selected_action))
            log_probs.append(log_prob)

            selected_actions.append(selected_action)

            # next state
            next_state, reward, done, _ = env.step(selected_action)
            total_reward += reward
            state = next_state
            counter += 1

        # Compute tour lengths
        model_tour_length = compute_tour_length(state["adjacency_matrix"], selected_actions)
        optimal_path = env.optimal_solution()
        optimal_tour_length = compute_tour_length(state["adjacency_matrix"], optimal_path)

        # Compute loss
        tour_length_diff = model_tour_length - optimal_tour_length
        log_probs_sum = torch.stack(log_probs).sum()
        loss = tour_length_diff * log_probs_sum
        print(loss)

        total_loss += loss.item()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Episode {episode + 1}/{num_episodes}, Total Reward: {total_reward}, Episode Loss: {total_loss}")

if __name__ == "__main__":
    env = ATSPEnv(num_nodes=5, seed=42)

    state_dim = 2
    action_dim = 2

    config = DecisionTransformer.DTConfig(n_embd=128)
    model = DecisionTransformer(state_dim=state_dim, action_dim=action_dim, config=config)

    train_online(model, env)
    torch.save(model.state_dict(), "decision_transformer_model.pth")


Eval