In [1]:
import numpy as np
from matplotlib import pyplot as plt 
import torch 
import torch.nn as nn
from torchinfo import summary
from torch import optim
from matplotlib.patches import FancyArrow
import matplotlib.cm as cm
from tqdm import tqdm
from datetime import datetime
import time
import keyboard
from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
from data_loader_script import create_data_loader

mp.set_start_method('spawn', force=True)

In [2]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [3]:
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers):
        super(Encoder,self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True, dropout= 0.0)
        
    def forward(self, x):
        y = self.fc1(x)
        out , _ = self.lstm(y)
        return out

class Decoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, num_heads):
        super(Decoder,self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first= True, dropout= 0.0)
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads, dropout= 0.0, batch_first= True)

    def forward(self, x, enc_outs, h0, c0, indices_to_ignore):
        # LSTM output
        y, (hn, cn) = self.lstm(x, (h0, c0))  # y: (N, 1, d)

        # Create a mask for attention
        # enc_outs: (N, L, d), indices_to_ignore: (N, s)
        N, L, _ = enc_outs.shape
        mask = torch.zeros((N, L), dtype=torch.bool, device=enc_outs.device)  # Initialize mask to False

        # Set True for indices to ignore
        if not indices_to_ignore is None:
            for i in range(N):
                mask[i, indices_to_ignore[i]] = True

        # Apply attention with mask
        _, attn_weights = self.attn(query=y, key=enc_outs, value=enc_outs, key_padding_mask=mask)  # Masked attention
        attn_weights = attn_weights.squeeze(1)  # (N, L)

        return (hn, cn), attn_weights


In [4]:
class TSPNet(nn.Module):
    def __init__(self, input_dim, hidden_dim, Num_L_enc =3, 
                Num_L_dec = 3, num_heads = 2):
        super(TSPNet, self).__init__()
        self.Num_L_dec = Num_L_dec
        self.hidden_dim = hidden_dim
        self.encoder = Encoder(input_dim, hidden_dim, Num_L_enc)
        self.decoder = Decoder(hidden_dim, hidden_dim, Num_L_dec, num_heads)
    def forward(self, X, mod = 'train'):

        batch_size, seq_length, _ = X.size()

        encoded_cities = self.encoder(X) # output shape: (batch_size, num_cities, hidden_dim)

        h0,c0 = torch.zeros(self.Num_L_dec, batch_size, self.hidden_dim).to(device), torch.zeros(self.Num_L_dec, batch_size, self.hidden_dim).to(device)
        #indices_to_ignore = torch.cat((torch.zeros(batch_size,1),torch.zeros(batch_size,1)-1),dim=-1).long()
        
        start_token = nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
        x_star = start_token.expand(batch_size, -1, -1).to(device)# the first input to the decoder is a vector we have to learn

        outs = torch.zeros(batch_size, seq_length+1, seq_length).to(device)
        action_indices = torch.zeros(batch_size, seq_length+1, 1).to(device)

        indices_to_ignore = None # for the first input, we can visit all the cities.
        
        for t in range(seq_length+1):
            if t == seq_length:
                indices_to_ignore = indices_to_ignore[:,1:]
            (hn,cn), attn_weights = self.decoder(x_star, encoded_cities, h0,c0, indices_to_ignore)
            attn_weights = torch.clamp(attn_weights, min=1e-9)
            attn_weights = attn_weights / attn_weights.sum(dim=-1, keepdim=True)
            if mod == 'train':
                idx = torch.multinomial(attn_weights, num_samples=1).squeeze(-1)
            elif mod == 'eval':
                idx = torch.argmax(attn_weights, dim=-1)
            else:
                raise('wrong mode')
            x_star = encoded_cities[torch.arange(batch_size), idx, :].unsqueeze(1)
            outs[:,t,:] = attn_weights
            action_indices[:,t,0] = idx
            h0,c0 = hn,cn
            if t==0:
                indices_to_ignore = idx.unsqueeze(-1)
            else:
                indices_to_ignore = torch.cat((indices_to_ignore, idx.unsqueeze(-1)),dim=-1).long()
            
        return outs, action_indices
            




In [5]:
def route_cost(cities, routes):
    B, N, _ = cities.shape
    routes = routes.squeeze(-1).long()  # Convert to long for indexing
    ordered_cities = cities[torch.arange(B).unsqueeze(1), routes]  # Reorder cities based on routes
    diffs = ordered_cities[:, :-1] - ordered_cities[:, 1:]  # Compute differences between consecutive cities
    distances = torch.norm(diffs, p=2, dim=2)  # Euclidean distances
    total_distances = distances.sum(dim=1)  # Sum distances for each batch
    return total_distances

In [6]:
import math
def generate_unit_circle_cities(B, N, d):
    """
    Generates a PyTorch tensor of size (B, N, d), representing B batches
    of N cities in d-dimensional space, where cities are randomly placed on the unit circle.
    
    Args:
        B (int): Number of batches.
        N (int): Number of cities in each batch.
        d (int): Number of dimensions (must be at least 2, higher dimensions will have zeros).
        
    Returns:
        torch.Tensor: A tensor of shape (B, N, d) with cities on the unit circle.
    """
    if d < 2:
        raise ValueError("Dimension 'd' must be at least 2.")

    # Generate random angles for each city
    angles = torch.rand(B, N) * 2 * math.pi  # Random angles in radians

    # Coordinates on the unit circle
    x_coords = torch.cos(angles)
    y_coords = torch.sin(angles)

    # Create a tensor of zeros for higher dimensions if d > 2
    higher_dims = torch.zeros(B, N, d - 2)

    # Combine x, y, and higher dimensions
    unit_circle_coords = torch.stack((x_coords, y_coords), dim=-1)
    result = torch.cat((unit_circle_coords, higher_dims), dim=-1)
    result[:,0,:] = result[:,-1,:]
    return result
cities = generate_unit_circle_cities(10,10,2)

In [7]:

input_dim = 2
hidden_dim = 128
num_layers = 2
num_heads = 1
model = TSPNet(input_dim, hidden_dim, num_layers, num_layers, num_heads).to(device)
summary(model)

Layer (type:depth-idx)                                  Param #
TSPNet                                                  --
├─Encoder: 1-1                                          --
│    └─Linear: 2-1                                      384
│    └─LSTM: 2-2                                        264,192
├─Decoder: 1-2                                          --
│    └─LSTM: 2-3                                        264,192
│    └─MultiheadAttention: 2-4                          49,536
│    │    └─NonDynamicallyQuantizableLinear: 3-1        16,512
Total params: 594,816
Trainable params: 594,816
Non-trainable params: 0

In [8]:

lr = 0.001
batch_size = 512
num_samples = 1000000
num_cities = 50
input_dim = 2
num_workers = 8  # Start with 1 worker and test scaling up
from torch.amp import GradScaler, autocast
data_loader = create_data_loader(batch_size, num_samples, num_cities, input_dim, num_workers=num_workers)

scaler = GradScaler('cuda')
optimizer = optim.Adam(model.parameters(), lr=lr)


alpha = 0.1
run_name = 'runs/TSP/' + str(batch_size) + '_' + str(num_cities) + '_' + str(num_samples) + '_' + '/ANN/'+datetime.now().strftime(("%Y_%m_%d %H_%M_%S"))
writer = SummaryWriter(log_dir=run_name)

for episode, data_batch in enumerate(tqdm(data_loader, colour='green')):
    data_batch = data_batch.to(device, non_blocking=True)

    with autocast('cuda'):  # Mixed precision for speed-up
        outs, actions = model(data_batch)
        sum_log_prob = torch.sum(torch.log(torch.cat([
            outs[i][torch.arange(len(outs[i])), actions[i].cpu().numpy().astype(int).flatten()].unsqueeze(0)
            for i in range(len(outs))
        ], axis=0)), axis=1).to(device)
        costs = route_cost(data_batch, actions).to(device)
        
        policy_loss = torch.sum(sum_log_prob * costs) / batch_size

    optimizer.zero_grad()
    scaler.scale(policy_loss).backward()  # Scale loss for mixed precision
    scaler.step(optimizer)  # Use scaler to handle optimizer step
    scaler.update()  # Update scaler for next iteration

    # Logging and monitoring
    if episode % 100 == 0:
        mean_cost = costs.mean().item()
        print(f"Episode: {episode} Mean cost: {mean_cost:.2f}")
        writer.add_scalar('Mean cost', mean_cost, episode)

writer.close()

  0%|[32m          [0m| 1/1954 [00:03<1:47:44,  3.31s/it]

Episode: 0 Mean cost: 26.12


  5%|[32m▌         [0m| 101/1954 [02:21<52:00,  1.68s/it]

Episode: 100 Mean cost: 26.03


 10%|[32m█         [0m| 201/1954 [04:47<32:40,  1.12s/it]

Episode: 200 Mean cost: 25.89


 15%|[32m█▌        [0m| 301/1954 [07:09<46:35,  1.69s/it]

Episode: 300 Mean cost: 26.11


 17%|[32m█▋        [0m| 334/1954 [07:53<38:18,  1.42s/it]


KeyboardInterrupt: 

In [None]:
class SimpleDataset(Dataset):
    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        return torch.tensor(index)  # Minimal implementation

dataset = SimpleDataset(1000)
data_loader = DataLoader(dataset, batch_size=32, num_workers=1)
for data in data_loader:
    print(data)

In [None]:
torch.save(model.state_dict(), 'model.pth')

In [None]:
model.load_state_dict(torch.load('model.pth'))

In [None]:

def plot_routes(cities, routes, batch_index=0):
    """
    Plots the route for a given batch of cities and routes.
    
    Args:
        cities (torch.Tensor): Tensor of shape (B, N, 2) representing city coordinates.
        routes (torch.Tensor): Tensor of shape (B, N) representing routes.
        batch_index (int): Index of the batch to plot.
    """
    cities = cities[batch_index].numpy()
    route = routes[batch_index].long().squeeze().numpy()
    print(route)
    # Get coordinates of cities in the order of the route
    ordered_cities = cities[route]
    
    # Plot cities
    plt.figure(figsize=(8, 6))
    plt.scatter(cities[:, 0], cities[:, 1], color='blue', zorder=2, label='Cities')
    for i, (x, y) in enumerate(cities):
        plt.text(x, y, f'{i}', fontsize=12, ha='right', color='black')
    
    # Plot the route
    plt.plot(ordered_cities[:, 0], ordered_cities[:, 1], color='red', linestyle='--', zorder=1, label='Route')
    
    # Highlight start and end points
    plt.scatter(ordered_cities[0, 0], ordered_cities[0, 1], color='green', s=100, label='Start', zorder=3)
    plt.scatter(ordered_cities[-1, 0], ordered_cities[-1, 1], color='purple', s=100, label='End', zorder=3)
    
    plt.title(f"Route for Batch {batch_index}")
    plt.xlabel("X Coordinate")
    plt.ylabel("Y Coordinate")

    plt.axis('off')
    plt.show()

In [None]:
lr = 0.001
batch_size = 128
num_cities = 10
N_episodes = 25000
data = generate_unit_circle_cities(batch_size, num_cities, input_dim).to(device)
_, actions = model(data,mod='eval')
plot_routes(data.cpu(),actions.cpu(),11)

In [None]:
test_city = torch.rand(1, num_cities, input_dim).to(device)
_ , actions = model(test_city,mod='eval')
plot_routes(test_city.cpu(),actions.cpu(),0)