In [6]:
import time
import torch
from torch.distributions import Categorical, kl
# from d2l.torch import Animator

from net import Net
from aco import ACO
from utils import gen_pyg_data, load_val_dataset

torch.manual_seed(1234)

lr = 3e-4  # AdamW learning rate used by train()
EPS = 1e-10  # small offset added to the predicted heuristic matrix (presumably to keep entries away from zero)
T = 5  # number of ACO iterations run per instance during validation
# Fall back to CPU so the notebook still runs on machines without a CUDA GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [7]:
def train_instance(model, optimizer, pyg_data, distances, n_ants):
    """Run one REINFORCE update of ``model`` on a single TSP instance.

    The model predicts a heuristic vector for ``pyg_data``. That vector is
    reshaped into a matrix, offset by ``EPS``, and used to build an ``ACO``
    colony of ``n_ants`` ants. The sampled tour costs, centred on their mean,
    weight the summed log-probabilities to form the policy-gradient loss.
    One optimizer step is then taken.
    """
    model.train()
    heuristic = model.reshape(pyg_data, model(pyg_data)) + EPS

    colony = ACO(n_ants=n_ants, heuristic=heuristic, distances=distances, device=device)
    costs, log_probs, _ = colony.sample()

    # The mean cost over the sampled ants serves as the baseline.
    advantage = costs - costs.mean()
    # log_probs is summed over dim 0; presumably the per-step axis, leaving one
    # log-probability per ant (TODO confirm against ACO.sample).
    loss = torch.sum(advantage * log_probs.sum(dim=0)) / colony.n_ants

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

@torch.no_grad()
def infer_instance(model, pyg_data, distances, n_ants, n_iterations=None):
    """Evaluate ``model`` on a single TSP instance without tracking gradients.

    Args:
        model: heuristic network; switched to eval mode here.
        pyg_data: graph input for ``model``.
        distances: distance matrix handed to ``ACO``.
        n_ants: number of ants in the colony.
        n_iterations: ACO iterations to run after the initial sampling round.
            Defaults to the module-level ``T`` (read at call time).

    Returns:
        ``(mean_sample_cost, best_sample_cost, best_aco_cost)`` as Python floats:
        the mean and minimum cost of the initially sampled tours, and
        ``aco.lowest_cost`` after ``aco.run``.
    """
    if n_iterations is None:
        n_iterations = T
    model.eval()
    heu_vec = model(pyg_data)
    heu_mat = model.reshape(pyg_data, heu_vec) + EPS
    aco = ACO(
        n_ants=n_ants,
        heuristic=heu_mat,
        distances=distances,
        device=device
        )
    # Log-probabilities are only needed for training; discard them here.
    costs, _, _ = aco.sample()
    aco.run(n_iterations=n_iterations)
    baseline = costs.mean().item()
    best_sample_cost = torch.min(costs).item()
    # float() yields a Python float whether lowest_cost is a float or a 0-d tensor.
    best_aco_cost = float(aco.lowest_cost)
    return baseline, best_sample_cost, best_aco_cost

In [8]:
def train_epoch(n_node, n_ants, k_sparse, epoch, steps_per_epoch, net, optimizer):
    """Train ``net`` for one epoch on freshly generated random TSP instances.

    Each of the ``steps_per_epoch`` steps does three things:

    1. Draws ``n_node`` random 2-D points (``torch.rand``, i.e. uniform in
       ``[0, 1)``) on ``device``.
    2. Converts them with ``gen_pyg_data(..., k_sparse=k_sparse)``.
    3. Performs one ``train_instance`` update with ``n_ants`` ants.

    ``epoch`` is currently unused.
    """
    for step in range(steps_per_epoch):
        coords = torch.rand(size=(n_node, 2), device=device)
        pyg_data, dist_mat = gen_pyg_data(coords, k_sparse=k_sparse)
        train_instance(net, optimizer, pyg_data, dist_mat, n_ants)


@torch.no_grad()
def validation(n_ants, epoch, net, val_dataset, animator=None):
    """Average ``infer_instance`` results over a validation set.

    Args:
        n_ants: ants per colony for each instance.
        epoch: current epoch index (``train`` passes ``-1`` before training).
            Only used to position the point plotted on ``animator``.
        net: heuristic network to evaluate.
        val_dataset: iterable of ``(pyg_data, distances)`` pairs.
        animator: optional object with an ``add(x, ys)`` method (e.g. a d2l
            ``Animator``). When given, the averages are plotted at ``epoch + 1``.

    Returns:
        ``(avg_mean_sample_cost, avg_best_sample_cost, avg_best_aco_cost)``.

    Raises:
        ValueError: if ``val_dataset`` yields no instances.
    """
    results = [infer_instance(net, data, distances, n_ants)
               for data, distances in val_dataset]
    if not results:
        raise ValueError("validation dataset is empty; cannot average results")

    n_val = len(results)
    # zip(*results) regroups the per-instance triples into one column per metric.
    avg_bl, avg_sample_best, avg_aco_best = (sum(column) / n_val for column in zip(*results))

    if animator is not None:
        animator.add(epoch + 1, (avg_bl, avg_sample_best, avg_aco_best))

    return avg_bl, avg_sample_best, avg_aco_best

In [9]:
def train(n_node, k_sparse, n_ants, steps_per_epoch, epochs, save_path=None):
    """Train a fresh heuristic network and print validation results per epoch.

    A new ``Net(gfn=False)`` is trained with AdamW (learning rate ``lr``) for
    ``epochs`` epochs of ``steps_per_epoch`` random ``n_node``-node instances.
    The validation set from ``load_val_dataset(n_node, k_sparse, device)`` is
    evaluated once before training (printed as epoch -1) and after every epoch.

    Note: the positional order here (``k_sparse`` before ``n_ants``) differs
    from ``train_epoch``. Prefer keyword arguments when calling.

    Args:
        save_path: if given, the trained ``state_dict`` is written there with
            ``torch.save`` after training (e.g. ``f'../pretrained/tsp/tsp{n_node}.pt'``).
    """
    net = Net(gfn=False).to(device)
    optimizer = torch.optim.AdamW(net.parameters(), lr=lr)
    val_list = load_val_dataset(n_node, k_sparse, device)
    animator = None # Animator(xlabel='epoch', xlim=[0, epochs], legend=["Avg. sample obj.", "Best sample obj.", "Best ACO obj."])

    val_results = [validation(n_ants, -1, net, val_list, animator)]
    is_cuda = torch.device(device).type == 'cuda'

    sum_time = 0
    for epoch in range(epochs):
        start = time.perf_counter()
        train_epoch(n_node, n_ants, k_sparse, epoch, steps_per_epoch, net, optimizer)
        # CUDA kernels run asynchronously; wait for them so the measured time
        # covers the epoch's GPU work rather than leaking into validation.
        if is_cuda:
            torch.cuda.synchronize(device)
        sum_time += time.perf_counter() - start
        val_results.append(validation(n_ants, epoch, net, val_list, animator))

    print('total training duration:', sum_time)

    # val_results[0] is the pre-training evaluation, hence start=-1.
    for epoch, result in enumerate(val_results, start=-1):
        print(f'epoch {epoch}:', result)

    if save_path is not None:
        torch.save(net.state_dict(), save_path)

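Learn a heuristic for TSP20 (20-node instances). Each `epoch k:` line in the output reports three validation-set averages: the mean sampled tour cost, the best sampled tour cost and the best ACO tour cost. `epoch -1` is the network before any training.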

In [10]:
# TSP20: 20 nodes, 20 ants per colony, k_sparse = 10, 5 epochs of 128 steps.
n_node = 20
n_ants = 20
k_sparse = 10
steps_per_epoch, epochs = 128, 5
# Keyword arguments: train() takes k_sparse before n_ants, unlike train_epoch().
train(
    n_node=n_node,
    k_sparse=k_sparse,
    n_ants=n_ants,
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
)

total training duration: 15.017478466033936
epoch -1: (7.513315134048462, 6.231752462387085, 3.787857837677002)
epoch 0: (4.83945333480835, 4.096140575408936, 3.7878578519821167)
epoch 1: (4.378843965530396, 3.8765585231781006, 3.787857856750488)
epoch 2: (4.27893536567688, 3.8807146167755127, 3.787857871055603)
epoch 3: (4.321380271911621, 3.8742052698135376, 3.78785786151886)
epoch 4: (4.3237948513031, 3.8907668447494506, 3.78785786151886)


Learn a heuristic for TSP100 (100-node instances):

In [11]:
# TSP100: 100 nodes, 20 ants per colony, k_sparse = 20, 5 epochs of 128 steps.
n_node, n_ants, k_sparse = 100, 20, 20
steps_per_epoch, epochs = 128, 5
# Keyword arguments: train() takes k_sparse before n_ants, unlike train_epoch().
train(
    n_node=n_node,
    k_sparse=k_sparse,
    n_ants=n_ants,
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
)

total training duration: 43.585044622421265
epoch -1: (21.311012954711913, 19.63759094238281, 7.8251377201080325)
epoch 0: (11.352936344146729, 10.460236339569091, 7.7966246414184575)
epoch 1: (9.84587589263916, 9.101685390472412, 7.78905686378479)
epoch 2: (9.694465713500977, 9.044547328948974, 7.796471109390259)
epoch 3: (9.640990810394287, 8.939566841125488, 7.793552856445313)
epoch 4: (9.665411338806152, 8.979426040649415, 7.786292972564698)


Learn a heuristic for TSP500 (500-node instances). This cell has not been executed yet, so it has no recorded output:

In [None]:
# TSP500: 500 nodes, 50 ants per colony, k_sparse = 50, 5 epochs of 128 steps.
n_node, n_ants, k_sparse = 500, 50, 50
steps_per_epoch, epochs = 128, 5
# Keyword arguments: train() takes k_sparse before n_ants, unlike train_epoch().
train(
    n_node=n_node,
    k_sparse=k_sparse,
    n_ants=n_ants,
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
)