In [0]:
from google.colab import drive

# Make Google Drive files visible to this Colab runtime under /content/drive/.
# force_remount=False (the default) keeps an existing mount instead of redoing it.
drive.mount('/content/drive/', force_remount=False)

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [2]:
import torch
print(torch.__version__)
! nvcc --version

1.0.1.post2
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2018 NVIDIA Corporation
Built on Sat_Aug_25_21:08:01_CDT_2018
Cuda compilation tools, release 10.0, V10.0.130


In [0]:
# Install PyTorch Geometric and the compiled extension packages it uses.
# torch-scatter is installed first and the others after it; keep this order.
# NOTE(review): none of these installs are version-pinned. This runtime reports
# torch 1.0.1 / CUDA 10.0 (cell above), and current releases of these packages
# likely no longer support that combination — TODO pin matching versions.
! pip install --verbose --no-cache-dir torch-scatter
! pip install --verbose --no-cache-dir torch-sparse
! pip install --verbose --no-cache-dir torch-cluster
# spline-conv is disabled; presumably not needed since only GCNConv is used below.
# ! pip install --verbose --no-cache-dir torch-spline-conv
! pip install torch-geometric

# Neural Net Model

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    """Pointer-style policy: a GCN city encoder plus a graph-conv LSTM decoder.

    Each ``forward`` call is one decoding step. It embeds and encodes all
    cities, advances the peephole-LSTM state from the current city, and returns
    a softmax over which city to pick next.

    Args:
        n_feature: features per city (2 for planar coordinates).
        n_hidden: embedding / hidden-state width (stored as ``self.dim``).
        n_class: unused; kept so existing constructor calls keep working.
        alpha: weight of the self-only encoder branch; the neighbourhood
            branch gets ``1 - alpha``. Defaults to 0.8, the value previously
            hard-coded in ``forward``.
    """

    def __init__(self, n_feature, n_hidden, n_class, alpha=0.8):
        super(GCN, self).__init__()
        self.size = 0
        self.batch_size = 0
        self.dim = n_hidden
        self.alpha = alpha

        # Attention vector and learned initial LSTM state, drawn from
        # U(-1/sqrt(d), 1/sqrt(d)). They are created on the CPU so the model can
        # be built without a GPU; model.cuda() moves them with every other parameter.
        bound = 1 / math.sqrt(n_hidden)
        self.v = nn.Parameter(torch.empty(n_hidden).uniform_(-bound, bound))
        self.h0 = nn.Parameter(torch.empty(n_hidden).uniform_(-bound, bound))
        self.c0 = nn.Parameter(torch.empty(n_hidden).uniform_(-bound, bound))

        # embedding
        self.embedding_x = nn.Linear(n_feature, n_hidden)     # current city
        self.embedding_all = nn.Linear(n_feature, n_hidden)   # all cities
        self.encoder_ori = GCNConv(n_hidden, n_hidden)  # encoder branch run on an empty edge list
        self.encoder = GCNConv(n_hidden, n_hidden)      # encoder branch run on edge_index

        # parameters for input gate
        self.Wxi = GCNConv(n_hidden, n_hidden)    # W(xt)
        self.Whi = GCNConv(n_hidden, n_hidden)    # W(ht)
        self.wci = nn.Linear(n_hidden, n_hidden)  # w(ct), peephole

        # parameters for forget gate
        self.Wxf = GCNConv(n_hidden, n_hidden)    # W(xt)
        self.Whf = GCNConv(n_hidden, n_hidden)    # W(ht)
        self.wcf = nn.Linear(n_hidden, n_hidden)  # w(ct), peephole

        # parameters for cell candidate
        self.Wxc = GCNConv(n_hidden, n_hidden)    # W(xt)
        self.Whc = GCNConv(n_hidden, n_hidden)    # W(ht)

        # parameters for output gate
        self.Wxo = GCNConv(n_hidden, n_hidden)    # W(xt)
        self.Who = GCNConv(n_hidden, n_hidden)    # W(ht)
        self.wco = nn.Linear(n_hidden, n_hidden)  # w(ct), peephole on the updated cell

        # parameters for pointer attention
        self.Wref = GCNConv(n_hidden, n_hidden)
        self.Wq = nn.Linear(n_hidden, n_hidden)

    def forward(self, x, edge_index, input, mask, h=None, c=None):
        """Run one decoding step.

        Args:
            x: current city, (B, n_feature).
            edge_index: encoder connectivity over all B*size cities, (2, E).
            input: all cities flattened, (B*size, n_feature).
            mask: additive logit mask, (B, size). The training loop keeps 0
                for unvisited cities and -inf for visited ones.
            h: hidden state, (B, dim); ``None`` uses the learned ``h0``.
            c: cell state, (B, dim); ``None`` uses the learned ``c0``.

        Returns:
            ``(probs, h, c)``. ``probs`` is a (B, size) softmax over the next
            city; ``h`` and ``c`` are the updated LSTM states.
        """
        self.batch_size = x.size(0)
        self.size = input.size(0) // self.batch_size

        if h is None:
            h = self.h0.unsqueeze(0).expand(self.batch_size, self.dim)
        if c is None:
            c = self.c0.unsqueeze(0).expand(self.batch_size, self.dim)

        # Empty edge list, shape (2, 0). Assuming GCNConv adds self-loops
        # (PyG default), convolving over it is a per-node linear map with no
        # message passing: the "I" term below.
        self_connect = edge_index[:, :0].clone()
        alpha = self.alpha

        # embedding
        x = self.embedding_x(x)
        context = self.embedding_all(input)

        # encode: alpha*I*H*W0 + (1-alpha)*DAD*H*W1
        context = (1 - alpha) * self.encoder(context, edge_index) \
            + alpha * self.encoder_ori(context, self_connect)

        # Peephole LSTM step; the decoder state has no graph structure, hence self_connect.
        i = torch.sigmoid(self.Wxi(x, self_connect) + self.Whi(h, self_connect) + self.wci(c))
        f = torch.sigmoid(self.Wxf(x, self_connect) + self.Whf(h, self_connect) + self.wcf(c))
        c = f * c + i * torch.tanh(self.Wxc(x, self_connect) + self.Whc(h, self_connect))
        o = torch.sigmoid(self.Wxo(x, self_connect) + self.Who(h, self_connect) + self.wco(c))
        h = o * torch.tanh(c)

        # Additive pointer attention: u_j = v . tanh(Wq h + Wref context_j)
        q = self.Wq(h)                                                        # (B, dim)
        ref = self.Wref(context, self_connect)
        ref = ref.view(self.batch_size, self.size, self.dim)                  # (B, size, dim)
        q_ex = q.unsqueeze(1).expand(self.batch_size, self.size, self.dim)    # (B, size, dim)
        v_view = self.v.unsqueeze(0).expand(self.batch_size, self.dim).unsqueeze(2)  # (B, dim, 1)

        # (B, size, dim) x (B, dim, 1) -> (B, size)
        u = torch.bmm(torch.tanh(q_ex + ref), v_view).squeeze(2)

        # Bound logits to [-10, 10], then apply the visited-city mask.
        u = 10 * torch.tanh(u) + mask

        return F.softmax(u, dim=1), h, c

# Training

In [0]:
import numpy as np

import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.optim import lr_scheduler
from scipy.spatial.distance import squareform, pdist
from sklearn.neighbors import NearestNeighbors

In [0]:
# Problem / training hyper-parameters:
#   size       - cities per instance
#   learn_rate - Adam learning rate
#   beta       - decay factor of the moving-average reward baseline
#   B          - batch size (instances per training step)
size, learn_rate, beta, B = 50, 1e-3, 0.8, 128

In [0]:
# Policy network: 2-D city coordinates embedded into a 128-wide hidden state.
model = GCN(n_feature=2, n_hidden=128, n_class=1)

In [0]:
model.cuda()

# Adam, with the learning rate multiplied by lr_decay_rate at each milestone
# 3000, 4000, ..., 9000. opt_scheduler.step() is called once per training iteration.
# NOTE(review): the training cell runs 3000 iterations, so the decay milestones
# sit at or beyond the end of training — confirm this is intended.
lr_decay_rate = 0.96
lr_decay_step = 3000
learn_rate = 1e-3

optimizer = optim.Adam(model.parameters(), lr=learn_rate)

opt_scheduler = lr_scheduler.MultiStepLR(
    optimizer,
    milestones=range(lr_decay_step, 10000, 1000),
    gamma=lr_decay_rate,
)


In [0]:
# A dense edge list over the flattened batch would have (B*size)**2 entries
# (128*128*50*50 for B=128, size=50). Instead, connect each city only to its
# N nearest neighbours within its own instance, giving B*size*N edges.

def knn(X, B, size, N):
    """Build a k-nearest-neighbour edge list for a batch of point sets.

    Args:
        X: B*size points, shape (B*size, d). Rows k*size .. (k+1)*size - 1
            belong to instance k.
        B: number of instances in the batch.
        size: number of points per instance.
        N: neighbours per point. When querying its own training points,
            NearestNeighbors normally returns the point itself first. N should
            not exceed ``size``.

    Returns:
        List of ``[src, dst]`` pairs in global (flattened) node indices, ordered
        by instance, then source point, then neighbour rank. ``np.array(E).T``
        gives a (2, B*size*N) edge index.
    """
    points = np.asarray(X).reshape(B, size, -1)
    # Local source ids for one instance: each point repeated N times (0,..,0, 1,..,1, ...).
    local_src = np.repeat(np.arange(size), N)

    E = []
    for k in range(B):
        nbrs = NearestNeighbors(n_neighbors=N, algorithm='ball_tree').fit(points[k])
        _, indices = nbrs.kneighbors(points[k])   # (size, N) local indices
        offset = k * size                          # shift into global node ids
        pairs = np.stack([local_src + offset, indices.reshape(-1) + offset], axis=1)
        E.extend(pairs.tolist())

    return E

In [0]:
def sample_tours(model, X, edge_index, batch_size, n_cities, eps=1e-15):
    """Sample one closed tour per instance from the pointer policy.

    Args:
        model: the GCN policy; called once per decoding step.
        X: city coordinates flattened to (batch_size*n_cities, d), on the
            model's device.
        edge_index: (2, E) encoder connectivity for the flattened batch.
        batch_size: number of instances.
        n_cities: cities per instance.
        eps: added inside log() so a zero probability does not give -inf.

    Returns:
        ``(tour_len, logprob)``, both (batch_size,). ``tour_len`` is the length
        of the closed tour, including the edge back to the first city.
        ``logprob`` is the summed log-probability of the sampled choices.
    """
    Y = X.view(batch_size, n_cities, -1)
    device = X.device
    rows = torch.arange(batch_size, device=device)            # per-instance row selector
    mask = torch.zeros(batch_size, n_cities, device=device)   # -inf marks visited cities

    x = Y[:, 0, :]          # initial "current city" fed to the decoder
    h = c = None
    tour_len = torch.zeros(batch_size, device=device)
    logprob = torch.zeros(batch_size, device=device)
    first = prev = None

    for _ in range(n_cities):
        probs, h, c = model(x=x, edge_index=edge_index, input=X, h=h, c=c, mask=mask)
        idx = torch.distributions.Categorical(probs).sample()   # (batch_size,)
        city = Y[rows, idx]

        if prev is None:
            # The first pick has no incoming edge, so no step length is added.
            first = city
        else:
            tour_len = tour_len + torch.norm(city - prev, dim=1)

        logprob = logprob + torch.log(probs[rows, idx] + eps)
        mask[rows, idx] += -np.inf
        prev = x = city

    # close the tour
    tour_len = tour_len + torch.norm(prev - first, dim=1)
    return tour_len, logprob


B = 128    # batch_size
C = 0      # baseline: moving average of the mean tour length
R = 0      # reward: per-instance tour length (a cost; the loss pushes it down)

TINY = 1e-15          # keeps log() finite
max_grad_norm = 4.0   # L2 gradient-clipping threshold

for i in range(3000):
    optimizer.zero_grad()

    # fresh batch: B instances of `size` uniformly random cities
    X = np.random.rand(B*size, 2)
    K = knn(X, B, size, 5)

    X = torch.Tensor(X).cuda()

    # k-nn edge
    E = np.array(K).T
    edge_index = torch.LongTensor(E).cuda()

    # Each episode starts from a zero tour length.
    R, logprobs = sample_tours(model, X, edge_index, B, size, eps=TINY)

    # REINFORCE with an exponential-moving-average baseline
    if i == 0:
        C = R.mean()
    else:
        C = (C * beta) + ((1. - beta) * R.mean())

    loss = ((R-C)*logprobs).mean()
    loss.backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm, norm_type=2)
    optimizer.step()
    opt_scheduler.step()

    if i % 50 == 0:
        print("epoch:{}, loss:{}, reward:{}"
              .format(i, loss.item(), R.mean().item()))

epoch:0, loss:-3.925157308578491, reward:22.78050994873047
epoch:50, loss:1.8348097801208496, reward:10.116782188415527
epoch:100, loss:-0.9211317300796509, reward:8.541533470153809
epoch:150, loss:-4.340392589569092, reward:8.422378540039062
epoch:200, loss:0.31372666358947754, reward:8.087350845336914
epoch:250, loss:-4.57519006729126, reward:8.372838020324707
epoch:300, loss:-8.675603866577148, reward:8.368376731872559
epoch:350, loss:-7.94425106048584, reward:8.186714172363281
epoch:400, loss:0.43562081456184387, reward:7.808413028717041
epoch:450, loss:-1.272653579711914, reward:8.072643280029297
epoch:500, loss:-1.4694222211837769, reward:7.883309841156006
