In [None]:
import math
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook

# Model

In [None]:
class Attention(nn.Module):
    """Additive pointer attention that scores reference vectors against a query.

    For a query q and reference rows ref_j of the same batch element it computes
        u_j = v . tanh(Wq(q) + Wref(ref_j))
    """

    def __init__(self, n_hidden):
        super(Attention, self).__init__()
        # Set by forward(); kept as attributes for backward compatibility.
        self.size = 0
        self.batch_size = 0
        self.dim = n_hidden

        # v is created on the default device, like the nn.Linear layers below,
        # so model.cuda()/model.to(device) moves all parameters together.
        # (It used to be pinned to CUDA here, which broke CPU-only use.)
        bound = 1 / math.sqrt(n_hidden)
        self.v = nn.Parameter(torch.empty(n_hidden))
        self.v.data.uniform_(-bound, bound)

        # parameters for pointer attention
        self.Wref = nn.Linear(n_hidden, n_hidden)
        self.Wq = nn.Linear(n_hidden, n_hidden)

    def forward(self, q, ref):
        """Score every reference vector against the query of its batch element.

        Args:
            q: query, shape (B, dim).
            ref: reference vectors flattened to (B * size, dim); the rows of
                batch element b are assumed contiguous (they are reshaped with
                .view(B, size, dim)) -- TODO confirm with callers.

        Returns:
            u: unnormalised attention logits, shape (B, size).
            ref: projected references Wref(ref), shape (B, size, dim).
        """
        self.batch_size = q.size(0)
        self.size = ref.size(0) // self.batch_size

        q = self.Wq(q)                                                # (B, dim)
        ref = self.Wref(ref).view(self.batch_size, self.size, self.dim)  # (B, size, dim)

        # Broadcasting q over the size axis replaces an explicit repeat.
        scores = torch.tanh(q.unsqueeze(1) + ref)                     # (B, size, dim)
        u = torch.matmul(scores, self.v)                              # (B, size)

        return u, ref

In [None]:
class LSTM(nn.Module):
    """One-step LSTM cell with peephole connections.

    Every projection is a full nn.Linear(n_hidden, n_hidden), so x, h and c
    must all have n_hidden features in their last dimension. The input and
    forget gates peek at the previous cell state; the output gate peeks at the
    updated cell state.
    """

    def __init__(self, n_hidden):
        super(LSTM, self).__init__()

        def square_linear():
            return nn.Linear(n_hidden, n_hidden)

        # input gate: projections of x, h and (peephole) previous c
        self.Wxi = square_linear()
        self.Whi = square_linear()
        self.wci = square_linear()

        # forget gate: projections of x, h and (peephole) previous c
        self.Wxf = square_linear()
        self.Whf = square_linear()
        self.wcf = square_linear()

        # candidate cell input: projections of x and h
        self.Wxc = square_linear()
        self.Whc = square_linear()

        # output gate: projections of x, h and (peephole) updated c
        self.Wxo = square_linear()
        self.Who = square_linear()
        self.wco = square_linear()

    def forward(self, x, h, c):
        """Advance the cell by one step.

        Args:
            x: input features (..., n_hidden).
            h: previous hidden state (..., n_hidden).
            c: previous cell state (..., n_hidden).

        Returns:
            (h_next, c_next): the new hidden and cell states.
        """
        gate_in = torch.sigmoid(self.Wxi(x) + self.Whi(h) + self.wci(c))
        gate_forget = torch.sigmoid(self.Wxf(x) + self.Whf(h) + self.wcf(c))
        candidate = torch.tanh(self.Wxc(x) + self.Whc(h))
        c_next = gate_forget * c + gate_in * candidate

        # the output gate looks at the *updated* cell state
        gate_out = torch.sigmoid(self.Wxo(x) + self.Who(h) + self.wco(c_next))
        h_next = gate_out * torch.tanh(c_next)

        return h_next, c_next

In [None]:
class GPN(nn.Module):
    '''Graph Pointer Network policy that chooses the next city of a tour.

    Each forward() call is one decoding step:
      1. every city is embedded relative to the current city,
      2. a three-layer graph encoder refines those embeddings,
      3. the recurrent state (h, c) is advanced with the current city's
         embedding (initialised by lstm0 on the first step),
      4. a pointer attention scores all cities as the next one to visit.
    '''

    def __init__(self, n_feature, n_hidden):
        '''
        n_feature: features per city (2 = x/y coordinates in the training cell)
        n_hidden: hidden / embedding dimension
        '''
        super(GPN, self).__init__()
        self.city_size = 0    # set by forward() from X_all
        self.batch_size = 0   # set by forward() from X_all
        self.dim = n_hidden

        # lstm for first turn: initialises (h, c) by reading all cities
        self.lstm0 = nn.LSTM(n_hidden, n_hidden)

        # pointer layer
        self.pointer = Attention(n_hidden)

        # lstm encoder used on every step
        self.encoder = LSTM(n_hidden)

        # The parameters below are created on the default device, like the
        # submodules above; model.cuda()/model.to(device) moves them all.
        # (They used to be pinned to CUDA here, which broke CPU-only use.)
        bound = 1 / math.sqrt(n_hidden)

        # trainable first hidden input
        self.h0 = nn.Parameter(torch.empty(n_hidden).uniform_(-bound, bound))
        self.c0 = nn.Parameter(torch.empty(n_hidden).uniform_(-bound, bound))

        # trainable latent variable coefficient
        self.alpha = nn.Parameter(torch.ones(1))

        # trainable mixing gates of the three GNN layers (initialised to 1)
        self.r1 = nn.Parameter(torch.ones(1))
        self.r2 = nn.Parameter(torch.ones(1))
        self.r3 = nn.Parameter(torch.ones(1))

        # embedding (weights shared across all cities)
        self.embedding_x = nn.Linear(n_feature, n_hidden)
        self.embedding_all = nn.Linear(n_feature, n_hidden)

        # weights for GNN
        self.W1 = nn.Linear(n_hidden, n_hidden)
        self.W2 = nn.Linear(n_hidden, n_hidden)
        self.W3 = nn.Linear(n_hidden, n_hidden)

        # aggregation function for GNN
        self.agg_1 = nn.Linear(n_hidden, n_hidden)
        self.agg_2 = nn.Linear(n_hidden, n_hidden)
        self.agg_3 = nn.Linear(n_hidden, n_hidden)

    def forward(self, x, X_all, mask, h=None, c=None, latent=None):
        '''
        Inputs (B: batch size, size: city size, dim: hidden dimension)

        x: current city coordinates (B, n_feature)
        X_all: all cities' coordinates (B, size, n_feature)
        mask: additive mask on the logits (B, size); the training cell uses
              0 for unvisited and -inf for visited cities
        h: hidden state (B, dim); None on the first step
        c: cell state (B, dim); None on the first step
        latent: latent pointer logits from a previous layer (B, size), added
                with weight alpha

        Outputs

        softmax: probability distribution of the next city (B, size)
        h: hidden state (B, dim)
        c: cell state (B, dim)
        latent_u: raw pointer logits (B, size), before squashing and masking
        '''
        self.batch_size = X_all.size(0)
        self.city_size = X_all.size(1)

        # =============================
        # vector context
        # =============================

        # coordinates relative to the current city, (B, size, n_feature)
        X_rel = X_all - x.unsqueeze(1)

        # the weights share across all the cities
        x = self.embedding_x(x)                # (B, dim)
        context = self.embedding_all(X_rel)    # (B, size, dim)

        # =============================
        # process hidden variable
        # =============================

        if h is None or c is None:
            # first turn: (dim) -> (1, B, dim) initial state for lstm0
            h0 = self.h0.expand(1, self.batch_size, self.dim).contiguous()
            c0 = self.c0.expand(1, self.batch_size, self.dim).contiguous()

            # lstm0 reads the cities as a sequence: (size, B, dim)
            input_context = context.permute(1, 0, 2).contiguous()
            _, (h_enc, c_enc) = self.lstm0(input_context, (h0, c0))

            # its final state is the hidden variable of the first turn
            h = h_enc.squeeze(0)
            c = c_enc.squeeze(0)

        # =============================
        # graph neural network encoder
        # =============================

        # flatten to one row per city: (B * size, dim)
        context = context.view(-1, self.dim)

        # (size - 1) other cities per node; guard size == 1, where the old
        # division by zero turned the whole output into NaN.
        n_neighbours = max(self.city_size - 1, 1)

        # NOTE(review): despite the "agg" name, each layer only rescales a
        # node's *own* embedding by 1/(size-1); nothing is summed over the
        # other cities. Confirm this matches the intended GNN update.
        for r, W, agg in ((self.r1, self.W1, self.agg_1),
                          (self.r2, self.W2, self.agg_2),
                          (self.r3, self.W3, self.agg_3)):
            context = r * W(context) + (1 - r) * F.relu(agg(context / n_neighbours))

        # LSTM encoder; the new hidden state is the pointer query
        h, c = self.encoder(x, h, c)

        # pointer
        u, _ = self.pointer(h, context)    # (B, size)

        latent_u = u.clone()

        # squash logits into (-100, 100), then apply the visited-city mask
        u = 100 * torch.tanh(u) + mask

        if latent is not None:
            u = u + self.alpha * latent

        return F.softmax(u, dim=1), h, c, latent_u


# Training

In [None]:
# ----- hyperparameters -----
size = 50            # cities per training instance
learn_rate = 1e-3    # learning rate
B = 128              # training batch size
B_val = 32           # validation batch size (instances per validation run)
size_val = 500       # cities per validation instance
steps = 2500         # training steps per epoch
n_epoch = 10         # epochs
save_root = './model/gpn_tsp500.pt'
seed = 1234          # RNG seed for reproducible data and init; None disables seeding

# Seed before building the model so the weight initialisation is reproducible
# too. GPU kernels may still introduce small run-to-run differences.
if seed is not None:
    np.random.seed(seed)
    torch.manual_seed(seed)

print('=========================')
print('prepare to train')
print('=========================')
print('Hyperparameters:')
for label, value in [('size', size),
                     ('learning rate', learn_rate),
                     ('batch size', B),
                     ('validation size', B_val),
                     ('steps', steps),
                     ('epoch', n_epoch),
                     ('save root:', save_root),
                     ('seed', seed)]:
    print(label, value)
print('=========================')

model = GPN(n_feature=2, n_hidden=128).cuda()

# load model (torch.load unpickles arbitrary objects: only load trusted files)
# model = torch.load(save_root).cuda()

optimizer = optim.Adam(model.parameters(), lr=learn_rate)

# Multiply the learning rate by lr_decay_rate every lr_decay_step scheduler
# steps; the training cell calls opt_scheduler.step() once per training step.
lr_decay_step = 2500
lr_decay_rate = 0.96
opt_scheduler = lr_scheduler.StepLR(optimizer, step_size=lr_decay_step,
                                    gamma=lr_decay_rate)


In [None]:
def _rollout(model, X, greedy):
    """Decode one closed tour per instance in X with the GPN policy `model`.

    Args:
        model: GPN policy, called once per city with the running (h, c) state.
        X: city coordinates, shape (batch, n_city, 2), on the model's device.
        greedy: if True take the arg-max city at every step; otherwise sample
            it from the policy's categorical distribution.

    Returns:
        tour_len: (batch,) length of the closed tour (back to the first city).
        logprobs: (batch,) summed log-probabilities of the chosen cities;
            differentiable only when called with gradients enabled.
    """
    batch, n_city = X.size(0), X.size(1)
    rows = torch.arange(batch, device=X.device)
    mask = torch.zeros(batch, n_city, device=X.device)
    TINY = 1e-15    # keeps log() finite if a chosen city had probability ~0

    # the first query position is city 0 (it is not masked and may be picked)
    x = X[:, 0, :]
    h = c = None
    first = prev = None
    tour_len = 0
    logprobs = 0

    for k in range(n_city):
        output, h, c, _ = model(x=x, X_all=X, h=h, c=c, mask=mask)

        if greedy:
            idx = torch.argmax(output, dim=1)
        else:
            idx = torch.distributions.Categorical(output).sample()

        cur = X[rows, idx]    # (batch, 2) coordinates of the chosen cities
        if k == 0:
            first = cur
        else:
            tour_len = tour_len + torch.norm(cur - prev, dim=1)
        prev = x = cur

        logprobs = logprobs + torch.log(output[rows, idx] + TINY)
        mask[rows, idx] += -np.inf    # forbid revisiting the chosen city

    # close the tour back to the first visited city
    tour_len = tour_len + torch.norm(prev - first, dim=1)
    return tour_len, logprobs


val_mean = []    # mean greedy validation tour length, one entry per 50 steps
val_std = []     # matching standard deviations
max_grad_norm = 1.0

for epoch in range(n_epoch):
    for i in range(steps):
        optimizer.zero_grad()

        X = torch.Tensor(np.random.rand(B, size, 2)).cuda()

        # sampled rollout: R is the tour length (a cost, printed as "reward");
        # logprobs keeps the autograd graph for the REINFORCE update.
        R, logprobs = _rollout(model, X, greedy=False)

        # self-critic baseline: greedy rollout on the same instances. It only
        # supplies a constant baseline, so no autograd graph is built.
        with torch.no_grad():
            C, _ = _rollout(model, X, greedy=True)

        gap = (R - C).mean()
        loss = ((R - C - gap) * logprobs).mean()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                       max_grad_norm, norm_type=2)
        optimizer.step()
        opt_scheduler.step()

        if i % 50 == 0:
            print("epoch:{}, batch:{}/{}, reward:{}"
                  .format(epoch, i, steps, R.mean().item()))

            # greedy validation on fresh, larger instances; kept in separate
            # variables so it does not clobber the training state above
            X_val = torch.Tensor(np.random.rand(B_val, size_val, 2)).cuda()
            with torch.no_grad():
                R_val, _ = _rollout(model, X_val, greedy=True)

            val_mean.append(R_val.mean().item())
            val_std.append(R_val.std().item())
            print('validation tour length:', R_val.mean().item())

    print('save model to: ', save_root)
    # torch.save does not create missing directories; without this the first
    # checkpoint after a full epoch of training could fail.
    os.makedirs(os.path.dirname(save_root) or '.', exist_ok=True)
    torch.save(model, save_root)
