https://github.com/ccjjxx99/PointerNetworks-pytorch-tsp

https://github.com/aurelienbibaut/Actor_CriticPointer_Network-TSP

https://github.com/zifeiyu0531/PointerNetwork-RL-TSP

https://colab.research.google.com/drive/1lobspU9b7dTO_HuoX-3nibZspTwfa5aX?usp=sharing#scrollTo=UMntU4jUu_v1

## Pointer Networks

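A plain `RNN` constrains its output to have the same length as its input. The `seq2seq` approach removes this restriction by encoding with one `RNN` and decoding with another; attention-based `Transformer` networks are a typical example of this encoder-decoder design. One problem remains, however: the size of the output dictionary is fixed. For example, if an English vocabulary is fixed at `n=8000` words, the output of the `RNN` at each time step is a vector of length `8000`, and that length cannot change. This is fine for text processing, but in some problems, such as combinatorial optimization, the set of possible outputs depends on the input and varies from instance to instance.

What can be done when this length also changes? The pointer network proposes an architecture that addresses exactly this problem: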

<img src="../images/pointer_network.png" width="50%">

https://github.com/ast0414/pointer-networks-pytorch (solves an integer sorting problem)
 I expand upon earlier work by solving the planar convex hull problem and replace the LSTM-based encoder and decoder with transformers.

The core formulas of the `pointer network` are as follows:

$$
u_{j}^{i} = v^{T}\tanh(W_{1}e_{j} + W_{2}d_{i}), \quad j \in (1, \cdots, n)
$$

$$
p(C_{i} \mid C_{1},\cdots,C_{i-1}, P) = \mathrm{softmax}(u^{i})
$$

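Two learnable weights, $W_{1}$ and $W_{2}$, multiply the encoder output $e_{j}$ and the decoder output $d_{i}$, respectively. The sum then passes through a nonlinear transformation ($\tanh$) and is multiplied by another learnable weight $v$.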
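
To make the two formulas concrete, here is a minimal sketch that evaluates them for a single decoding step. The tensor sizes and the random values are assumptions chosen only for illustration; in the `Attention` class later in this notebook, `context_linear` plays the role of $W_{1}$, `input_linear` the role of $W_{2}$, and `V` the role of $v$.

```python
import torch

torch.manual_seed(0)

n, state_dim, hid = 5, 4, 3          # illustrative sizes (assumptions)
e = torch.randn(n, state_dim)        # encoder outputs e_1 .. e_n
d_i = torch.randn(state_dim)         # decoder output at step i
W1 = torch.randn(hid, state_dim)     # learnable in a real model
W2 = torch.randn(hid, state_dim)
v = torch.randn(hid)

# u_j^i = v^T tanh(W1 e_j + W2 d_i), computed for every j at once
u = torch.tanh(e @ W1.T + W2 @ d_i) @ v   # shape (n,)

# p(C_i | C_1..C_{i-1}, P) = softmax(u^i): one probability per input position
p = torch.softmax(u, dim=0)
print(p)
```

The output distribution has exactly `n` entries, one per input element, which is why its length follows the input length rather than a fixed vocabulary size.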

## Importing the Libraries

In [1]:
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torch
from torch.utils.data import Dataset
import itertools
import numpy as np
import argparse
from tqdm import tqdm


## Encoder

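When a pointer network solves the `TSP`, the given data must first be encoded and then passed to the `Decoder`, which decodes it into the output. As the code below shows, the `encoder` uses nothing more than a linear embedding layer followed by an `LSTM`.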

In [2]:
class Encoder(nn.Module):
    def __init__(self, feature_size, embedding_dim, hidden_dim, n_layers, dropout, bidir):

        super(Encoder, self).__init__()
        
        self.embedding_layer = nn.Linear(feature_size, embedding_dim)
        
        self.hidden_dim = hidden_dim // 2 if bidir else hidden_dim
        self.n_layers = n_layers * 2 if bidir else n_layers
        self.bidir = bidir
        self.lstm = nn.LSTM(embedding_dim, self.hidden_dim, n_layers, dropout=dropout, bidirectional=bidir, batch_first=True)
        
        self.h0 = nn.Parameter(torch.zeros(1), requires_grad=False)
        self.c0 = nn.Parameter(torch.zeros(1), requires_grad=False)

    def forward(self, input_datas, hidden):
        embedded_inputs = self.embedding_layer(input_datas)
        outputs, hidden = self.lstm(embedded_inputs, hidden)  # [8, 5, 512], [4, 8, 256], [4, 8, 256]

        return outputs, hidden

    def init_hidden(self, batch_size):

        h0 = self.h0.unsqueeze(0).unsqueeze(0).repeat(self.n_layers, batch_size, self.hidden_dim)  # [4, 8, 256]
        c0 = self.c0.unsqueeze(0).unsqueeze(0).repeat(self.n_layers, batch_size, self.hidden_dim)  # [4, 8, 256]

        return h0, c0

In [3]:
encoder = Encoder(feature_size = 2, embedding_dim = 128, hidden_dim = 512, n_layers = 2, dropout = 0, bidir=True)

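Assume `batch size = 8` and a sequence length of `5`. Each point in the `TSP` has two features, its x and y coordinates, so the input is two-dimensional and must first be embedded into the specified dimension.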

In [4]:
input_data = torch.randn(8, 5, 2)
h0, c0 = encoder.init_hidden(batch_size = 8)
outputs, hidden = encoder(input_data, (h0, c0))
print("outputs shape {}".format(outputs.size()))

outputs shape torch.Size([8, 5, 512])


As shown above, the main job of the `Encoder` is to encode the given input sequence. A `Transformer` could also be used for this encoding.
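
As a rough illustration of that remark, the sketch below swaps the `LSTM` for PyTorch's `nn.TransformerEncoder`. It is not part of the original notebook: the class name `TransformerEncoderSketch`, the head count, the feed-forward width, and the mean-pooled `summary` vector (a stand-in for the `LSTM` final state that the decoder would need) are all assumptions. Positional encodings are left out on purpose, on the assumption that the cities form an unordered set.

```python
import torch
import torch.nn as nn


class TransformerEncoderSketch(nn.Module):
    """Hypothetical drop-in encoder; names and sizes are illustrative."""

    def __init__(self, feature_size=2, hidden_dim=512, n_heads=8, n_layers=2, dropout=0.0):
        super().__init__()
        # Project the (x, y) coordinates to the model width
        self.embedding_layer = nn.Linear(feature_size, hidden_dim)
        layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=n_heads,
            dim_feedforward=2 * hidden_dim,
            dropout=dropout,
            batch_first=True,  # inputs are (batch, seq_len, features)
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)

    def forward(self, input_datas):
        outputs = self.encoder(self.embedding_layer(input_datas))  # (batch, seq_len, hidden_dim)
        summary = outputs.mean(dim=1)                               # (batch, hidden_dim)
        return outputs, summary


sketch = TransformerEncoderSketch()
outputs, summary = sketch(torch.randn(8, 5, 2))
print(outputs.shape, summary.shape)
```

The `outputs` tensor has the same `[8, 5, 512]` shape as the `LSTM` encoder output, so it could serve as the attention `context`; how best to initialise the decoder state from `summary` is left open here.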

## Attention Mechanism

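The `input` to `Attention` is derived from the `decoder`'s output and hidden state at the previous time step. It is combined with the encoded result from the `Encoder` (`context`) to produce the attention weights: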

In [5]:
class Attention(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(Attention, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.input_linear = nn.Linear(input_dim, hidden_dim)
        self.context_linear = nn.Conv1d(input_dim, hidden_dim, 1, 1)
        self.V = nn.Parameter(torch.FloatTensor(hidden_dim), requires_grad=True)
        self._inf = nn.Parameter(torch.FloatTensor([float('-inf')]), requires_grad=False)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=1)  # normalize over the sequence dimension

        # Initialize vector V
        nn.init.uniform_(self.V, -1, 1)

    def forward(self, input, context, mask):
        
        # (batch, hidden_dim, seq_len)
        inp = self.input_linear(input).unsqueeze(2).expand(-1, -1, context.size(1))

        # (batch, hidden_dim, seq_len)
        context = context.permute(0, 2, 1)
        ctx = self.context_linear(context)

        # (batch, 1, hidden_dim)
        V = self.V.unsqueeze(0).expand(context.size(0), -1).unsqueeze(1)

        # (batch, seq_len)
        att = torch.bmm(V, self.tanh(inp + ctx)).squeeze(1)
        if len(att[mask]) > 0:
            att[mask] = self.inf[mask]
        alpha = self.softmax(att)

        hidden_state = torch.bmm(ctx, alpha.unsqueeze(2)).squeeze(2)

        return hidden_state, alpha

    def init_inf(self, mask_size):
        self.inf = self._inf.unsqueeze(1).expand(*mask_size)

In [6]:
input_length, batch_size = 5, 8
attention_layer = Attention(input_dim = 512, hidden_dim = 512)
mask = nn.Parameter(torch.ones(1), requires_grad=False).repeat(input_length).unsqueeze(0).repeat(batch_size, 1)
attention_layer.init_inf(mask.size())
print("mask size {}".format(mask.size()))

mask size torch.Size([8, 5])



In [7]:
print(mask)

tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])


In [8]:
input_data = torch.randn(8, 512)
context = torch.randn(8, 5, 512)
mask_input = torch.eq(mask, 0)
print(mask_input)

tensor([[False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False]])


In [9]:
hidden_state, output = attention_layer(input_data, context, mask_input)
print("hidden_state size {}".format(hidden_state.size()))
print("output size {}".format(output.size()))

hidden_state size torch.Size([8, 512])
output size torch.Size([8, 5])
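
In the demo above the mask is all ones, so nothing is actually masked. The following plain-Python sketch shows what `Attention.forward` does once some positions have been selected: their scores are overwritten with `-inf` before the softmax, so they end up with probability zero. The helper name `masked_softmax` and the sample scores are illustrative assumptions, and the sketch assumes at least one position is still unvisited.

```python
import math


def masked_softmax(scores, visited):
    """Softmax over scores; visited positions are forced to probability 0."""
    masked = [float("-inf") if seen else s for s, seen in zip(scores, visited)]
    peak = max(masked)  # subtract the maximum for numerical stability
    exps = [math.exp(s - peak) for s in masked]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [x / total for x in exps]


scores = [2.0, 1.0, 0.5, 3.0, -1.0]
visited = [False, False, False, True, False]  # position 3 was chosen earlier
probs = masked_softmax(scores, visited)
print(probs)
```

Although position 3 has the highest raw score, it receives no probability mass, and the remaining probabilities still sum to one.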




## Decoder

In [10]:
class Decoder(nn.Module):
    def __init__(self, embedding_dim, hidden_dim):

        super(Decoder, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim

        self.input_to_hidden = nn.Linear(embedding_dim, 4 * hidden_dim)
        self.hidden_to_hidden = nn.Linear(hidden_dim, 4 * hidden_dim)
        self.hidden_out = nn.Linear(hidden_dim * 2, hidden_dim)
        self.att = Attention(hidden_dim, hidden_dim)

        # Used for propagating .cuda() command
        self.mask = nn.Parameter(torch.ones(1), requires_grad=False)
        self.runner = nn.Parameter(torch.zeros(1), requires_grad=False)

    def forward(self, embedded_inputs, decoder_input, decoder_hidden, context):

        batch_size, input_length, _ = embedded_inputs.size()  # [8, 5, 128]

        # (batch, seq_len)
        mask = self.mask.repeat(input_length).unsqueeze(0).repeat(batch_size, 1)  # [8, 5]
        self.att.init_inf(mask.size())

        # Generate arange(input_length), broadcast across batch_size
        runner = torch.arange(input_length, device=self.runner.device)
        runner = runner.unsqueeze(0).expand(batch_size, -1)

        outputs = []
        pointers = []

        def step(decoder_input, decoder_hidden):

            # Regular LSTM
            h, c = decoder_hidden  # []

            gates = self.input_to_hidden(decoder_input) + self.hidden_to_hidden(h)  # [8, 2048]
            input, forget, cell, out = gates.chunk(4, 1)

            input = torch.sigmoid(input)  # [8, 512]
            forget = torch.sigmoid(forget)  # [8, 512]
            cell = torch.tanh(cell)  # [8, 512]
            out = torch.sigmoid(out)  # [8, 512]

            c_t = (forget * c) + (input * cell)  # [8, 512]
            h_t = out * torch.tanh(c_t)  # [8, 512]

            # Attention section
            hidden_t, output = self.att(h_t, context, torch.eq(mask, 0))
            hidden_t = torch.tanh(self.hidden_out(torch.cat((hidden_t, h_t), 1)))

            return hidden_t, c_t, output

        # Recurrence loop
        for _ in range(input_length):
            h_t, c_t, outs = step(decoder_input, decoder_hidden)
            decoder_hidden = (h_t, c_t)  # update the hidden state

            # Masking selected inputs
            masked_outs = outs * mask

            # Get maximum probabilities and indices
            max_probs, indices = masked_outs.max(1)
            one_hot_pointers = (runner == indices.unsqueeze(1).expand(-1, outs.size()[1])).float()  # one-hot encoding of the selected index

            # Update mask to ignore seen indices
            mask = mask * (1 - one_hot_pointers)  # once masked, a selected index cannot be chosen in later steps

            # Get embedded inputs by max indices
            embedding_mask = one_hot_pointers.unsqueeze(2).expand(-1, -1, self.embedding_dim).bool()
            decoder_input = embedded_inputs[embedding_mask].view(batch_size, self.embedding_dim)

            outputs.append(outs.unsqueeze(0))
            pointers.append(indices.unsqueeze(1))

        outputs = torch.cat(outputs).permute(1, 0, 2)
        pointers = torch.cat(pointers, 1)

        return (outputs, pointers), decoder_hidden
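
The recurrence loop in `Decoder.forward` can be hard to follow inside the tensor code, so here is a stripped-down, single-example sketch of the same greedy selection logic in plain Python. The function name `greedy_decode` and the fixed probability table are assumptions for illustration; in the real decoder the probabilities are recomputed by the attention module at every step.

```python
def greedy_decode(step_probs):
    """step_probs[t][j]: probability of pointing at input j at step t."""
    n = len(step_probs[0])
    mask = [1.0] * n  # 1 = still available, 0 = already chosen
    tour = []
    for probs in step_probs:
        masked = [p * m for p, m in zip(probs, mask)]    # hide chosen inputs
        choice = max(range(n), key=lambda j: masked[j])  # greedy argmax
        mask[choice] = 0.0                               # never pick it again
        tour.append(choice)
    return tour


# The same distribution at every step would pick index 1 three times
# without the mask; with the mask every index is visited exactly once.
step_probs = [[0.1, 0.7, 0.2]] * 3
print(greedy_decode(step_probs))  # [1, 2, 0]
```

Because one index is removed from the mask at each of the `n` steps, the returned pointers always form a permutation of the inputs, which is what a valid `TSP` tour requires.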

## Pointer Network

In [11]:
class PointerNet(nn.Module):
    def __init__(self, feature_size, embedding_dim, hidden_dim, lstm_layers, dropout, bidir=False):

        super(PointerNet, self).__init__()
        self.embedding_dim = embedding_dim
        self.bidir = bidir
        self.embedding = nn.Linear(feature_size, embedding_dim)
        self.encoder = Encoder(feature_size, embedding_dim, hidden_dim, lstm_layers, dropout, bidir)
        self.decoder = Decoder(embedding_dim, hidden_dim)
        self.decoder_input0 = nn.Parameter(torch.FloatTensor(embedding_dim), requires_grad=False)

        # Initialize decoder_input0
        nn.init.uniform_(self.decoder_input0, -1, 1)

    def forward(self, inputs):

        batch_size, input_length, _ = inputs.size()  # [8, 5, 2]

        embedded_inputs = self.embedding(inputs)  # [8, 5, 128]
        encoder_hidden0 = self.encoder.init_hidden(batch_size)

        encoder_outputs, encoder_hidden = self.encoder(inputs, encoder_hidden0)  # [8, 5, 512] ([4, 8, 256], [4, 8, 256])

        if self.bidir:
            decoder_hidden_init = (torch.cat((encoder_hidden[0][-2:][0], encoder_hidden[0][-2:][1]), dim=-1),
                                   torch.cat((encoder_hidden[1][-2:][0], encoder_hidden[1][-2:][1]), dim=-1))
            # decoder_hidden0 = (torch.cat(encoder_hidden[0][-2:], dim=-1), torch.cat(encoder_hidden[1][-2:], dim=-1))
        else:
            decoder_hidden_init = (encoder_hidden[0][-1], encoder_hidden[1][-1])
        decoder_input_init = self.decoder_input0.unsqueeze(0).expand(batch_size, -1)  # [8, 128]
        (outputs, pointers), decoder_hidden = self.decoder(embedded_inputs, decoder_input_init, decoder_hidden_init, encoder_outputs)

        return outputs, pointers


## Creating the Dataset

In [12]:
def tsp_opt(points):
    """
    Dynamic programming solution for TSP - O(2^n*n^2)
    https://gist.github.com/mlalevic/6222750
    :param points: List of (x, y) points
    :return: Optimal solution
    """

    def length(x_coord, y_coord):
        return np.linalg.norm(np.asarray(x_coord) - np.asarray(y_coord))

    # Calculate all lengths
    all_distances = [[length(x, y) for y in points] for x in points]
    # Initial value - just distance from 0 to every other point + keep the track of edges
    A = {(frozenset([0, idx+1]), idx+1): (dist, [0, idx+1]) for idx, dist in enumerate(all_distances[0][1:])}
    cnt = len(points)
    for m in range(2, cnt):
        B = {}
        for S in [frozenset(C) | {0} for C in itertools.combinations(range(1, cnt), m)]:
            for j in S - {0}:
                # This will use 0th index of tuple for ordering, the same as if key=itemgetter(0) used
                B[(S, j)] = min([(A[(S-{j}, k)][0] + all_distances[k][j], A[(S-{j}, k)][1] + [j])
                                 for k in S if k != 0 and k != j])
        A = B
    res = min([(A[d][0] + all_distances[0][d[1]], A[d][1]) for d in iter(A)])
    return np.asarray(res[1])


class TSPDataset(Dataset):
    """
    Random TSP dataset
    """

    def __init__(self, data_size, seq_len, solver=tsp_opt, solve=True):
        self.data_size = data_size
        self.seq_len = seq_len
        self.solve = solve
        self.solver = solver
        self.data = self._generate_data()

    def __len__(self):
        return self.data_size

    def __getitem__(self, idx):
        tensor = torch.from_numpy(self.data['Points_List'][idx]).float()
        solution = torch.from_numpy(self.data['Solutions'][idx]).long() if self.solve else None

        sample = {'Points':tensor, 'Solution':solution}

        return sample

    def _generate_data(self):
        """
        :return: Set of points_list and their One-Hot vector solutions
        """
        points_list = []
        solutions = []
        data_iter = tqdm(range(self.data_size), unit='data')
        for i, _ in enumerate(data_iter):
            data_iter.set_description('Data points %i/%i' % (i+1, self.data_size))
            points_list.append(np.random.random((self.seq_len, 2)))
        solutions_iter = tqdm(points_list, unit='solve')
        if self.solve:
            for i, points in enumerate(solutions_iter):
                solutions_iter.set_description('Solved %i/%i' % (i+1, len(points_list)))
                solutions.append(self.solver(points))
        else:
            solutions = None

        return {'Points_List':points_list, 'Solutions':solutions}

    def _to1hotvec(self, points):
        """
        :param points: List of integers representing the points indexes
        :return: Matrix of One-Hot vectors
        """
        vec = np.zeros((len(points), self.seq_len))
        for i, v in enumerate(vec):
            v[points[i]] = 1

        return vec
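
For very small instances, the labels produced by a solver can be sanity-checked against an exhaustive search. The sketch below is an illustrative helper, not part of the original notebook: `tour_length` and `brute_force_tsp` are hypothetical names, and the tour is assumed to be closed (it returns to the starting city). It needs Python 3.8+ for `math.dist`.

```python
import itertools
import math


def tour_length(points, order):
    """Length of the closed tour that visits points in the given order."""
    return sum(
        math.dist(points[order[i]], points[order[(i + 1) % len(order)]])
        for i in range(len(order))
    )


def brute_force_tsp(points):
    """Try every tour starting at city 0; feasible only for very small n."""
    rest = range(1, len(points))
    best = min(
        ((0,) + perm for perm in itertools.permutations(rest)),
        key=lambda order: tour_length(points, order),
    )
    return list(best), tour_length(points, best)


# Four corners of the unit square, listed in a deliberately bad order
square = [(0.0, 0.0), (1.0, 1.0), (1.0, 0.0), (0.0, 1.0)]
order, length = brute_force_tsp(square)
print(order, length)
```

The shortest closed tour around the unit square is its perimeter, so `length` should be `4.0`. The search grows factorially with the number of cities, which is why the dataset code relies on dynamic programming instead.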

In [13]:
if __name__ == "__main__":
#     pass
    parser = argparse.ArgumentParser(description="Pytorch implementation of Pointer-Net")

#     Data
    parser.add_argument('--train_size', default=1000000, type=int, help='Training data size')
    parser.add_argument('--val_size', default=10000, type=int, help='Validation data size')
    parser.add_argument('--test_size', default=10000, type=int, help='Test data size')
    parser.add_argument('--batch_size', default=256, type=int, help='Batch size')
    # Train
    parser.add_argument('--epoch_nums', default=50000, type=int, help='Number of epochs')
    parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')
    # GPU
    parser.add_argument('--gpu', default=True, action='store_true', help='Enable gpu')
    # TSP
    parser.add_argument('--point_nums', type=int, default=5, help='Number of points in TSP')
    # Network
    parser.add_argument('--embedding_size', type=int, default=128, help='Embedding size')
    parser.add_argument('--hidden_size', type=int, default=512, help='Number of hidden units')
    parser.add_argument('--num_layers', type=int, default=2, help='Number of LSTM layers')
    parser.add_argument('--dropout', type=float, default=0., help='Dropout value')
    parser.add_argument('--bidir', default=True, action='store_true', help='Bidirectional')
    parser.add_argument('--feature_size', type=int, default=2, help='feature size of input')

    args = parser.parse_args(args=[])

    if args.gpu and torch.cuda.is_available():
        USE_CUDA = True
#         print('Using GPU, %i devices.' % torch.cuda.device_count())
    else:
        USE_CUDA = False

    model = PointerNet(args.feature_size, args.embedding_size, args.hidden_size, args.num_layers, args.dropout, args.bidir)

    dataset = TSPDataset(args.train_size, args.point_nums)

    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)

    if USE_CUDA:
        model.cuda()
        net = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
        cudnn.benchmark = True

    CCE = torch.nn.CrossEntropyLoss()
    model_optim = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)

    losses = []

    for epoch in range(args.epoch_nums):
        batch_loss = []
        iterator = tqdm(dataloader, unit='Batch')

        for i_batch, sample_batched in enumerate(iterator):
#             iterator.set_description('Batch %i/%i' % (epoch+1, args.point_nums))

            # Tensors track gradients directly, so no Variable wrapper is needed
            train_batch = sample_batched['Points']
            target_batch = sample_batched['Solution']

            if USE_CUDA:
                train_batch = train_batch.cuda()
                target_batch = target_batch.cuda()

            o, p = model(train_batch)
            o = o.contiguous().view(-1, o.size()[-1])

            target_batch = target_batch.view(-1)

            loss = CCE(o, target_batch)

            losses.append(loss.item())
            batch_loss.append(loss.item())

            model_optim.zero_grad()
            loss.backward()
            model_optim.step()

            iterator.set_postfix(loss='{}'.format(loss.item()))

        iterator.set_postfix(loss=np.average(batch_loss))

100%|█████████▉| 3901/3907 [36:23<00:03,  1.79Batch/s, loss=1.1597148180007935]  


KeyboardInterrupt: 
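
One detail worth checking when reading the logged loss: `torch.nn.CrossEntropyLoss` expects unnormalized scores and applies a log-softmax internally, whereas the `Attention` module in this notebook already returns softmax probabilities. The plain-Python sketch below (helper names are illustrative, not from the notebook) compares the two ways of scoring a perfectly confident, correct prediction over 5 points.

```python
import math


def cross_entropy_from_logits(logits, target):
    """What a logit-based cross-entropy computes: -log_softmax(logits)[target]."""
    log_norm = math.log(sum(math.exp(x) for x in logits))
    return -(logits[target] - log_norm)


def nll_from_probs(probs, target):
    """Negative log-likelihood taken directly from probabilities."""
    return -math.log(probs[target])


n = 5
perfect = [1.0] + [0.0] * (n - 1)  # all probability on the correct city

double_softmax_loss = cross_entropy_from_logits(perfect, 0)  # probabilities treated as logits
direct_loss = nll_from_probs(perfect, 0)

print(double_softmax_loss)  # about 0.905, i.e. log(e + 4) - 1
print(direct_loss)          # zero
```

When probabilities are fed in as if they were logits, the loss cannot drop below roughly 0.905 for 5 points, which may be part of the reason the logged value stays near 1.16. A possible alternative, offered here only as an assumption and not as the original author's method, is to apply `torch.nn.NLLLoss` to the logarithm of the probabilities, with a small epsilon added so that masked positions do not produce `log(0)`.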