In [8]:
import torch.nn as nn
import math
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from torch import optim
from torch.utils.data import DataLoader
import tqdm
from Datasets import Datasets
import math


dataset = Datasets("C:\Attention\data\\train.txt")

dataset.bulid_vocab(dataset.en_data,dataset.ch_data)

dataloader = DataLoader(dataset, batch_size=16, num_workers=0,collate_fn=dataset.collate_fn)


maxlen = 128
d_model = 512
units = 512
dropout_rate = 0.2
numofblock = 4
numofhead = 4
# encoder_vocab = len(dataset.ch_vocab)
vocab_size = len(dataset.en_vocab)
epochs = 20
latent_dim = 512
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def get_padding_mask(seq_q,seq_k):
    # print(seq_k.shape)
    # print(seq_q.shape)
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    padding_mask = seq_k.data.eq(1).unsqueeze(1)
    return padding_mask.expand(batch_size,len_q,len_k)


class TokenEmbedding(nn.Module):
    def __init__(self,vocab_size,emb_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size,emb_size)

    def forward(self,x):
        # print(x.shape)

        return self.embedding(x).to(DEVICE) # shape = (batch_size,input_seq_len,emb_dim)


class MultiHeadAttention(nn.Module):
    def __init__(self, num_units, num_heads, dropout_rate, mask=False):
        super().__init__()
        self.num_units = num_units
        self.num_head = num_heads
        self.dropout_rate = dropout_rate
        self.mask = mask
        self.linearQ = nn.Linear(self.num_units,self.num_units)
        self.linearK = nn.Linear(self.num_units,self.num_units)
        self.linearV = nn.Linear(self.num_units,self.num_units)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(self.dropout_rate)
        self.LayerNormalization = nn.LayerNorm(d_model)
        self.Q = nn.Sequential(self.linearQ,self.relu)
        self.K = nn.Sequential(self.linearK,self.relu)
        self.V = nn.Sequential(self.linearV,self.relu)


    def forward(self, queries, keys, values, self_padding_mask, enc_dec_padding_mask):
        '''
        :param queries: shape:(batch_size,input_seq_len,d_model)
        :param keys: shape:(batch_size,input_seq_len,d_model)
        :param values: shape:(batch_size,input_seq_len,d_model)
        :return: None
        '''
        q, k, v = self.Q(queries), self.K(keys), self.V(values)

        q_split, k_split, v_split = torch.chunk(q,self.num_head,dim=-1), torch.chunk(k,self.num_head,dim=-1), torch.chunk(v,self.num_head,dim=-1)
        q_, k_, v_ = torch.stack(q_split,dim=1), torch.stack(k_split,dim=1), torch.stack(v_split,dim=1)
        # shape : (batch_size, num_head, input_seq_len, depth = d_model/num_head)
        a = torch.matmul(q_,k_.permute(0,1,3,2)) # a = q * k^T(后两个维度)
        a = a / (k_.size()[-1] ** 0.5) # shape:(batch_size,num_head,seq_len,seq_len)
        batch_size_shape = a.shape[0]
        seq_len_shape = a.shape[2]
        if self.mask:
            self_padding_mask = self_padding_mask.unsqueeze(1).repeat(1, self.num_head, 1, 1)
            masked = torch.ones((batch_size_shape,1,seq_len_shape,seq_len_shape))
            masked = Variable((1 - torch.tril(masked, diagonal=0)) * (-2 ** 32 + 1)).to(DEVICE)

            assert masked.shape[-1] == self_padding_mask.shape[-1]
            a = a + masked
            a.masked_fill_(self_padding_mask,-1e9)
        else:
            enc_dec_padding_mask = enc_dec_padding_mask.unsqueeze(1).repeat(1, self.num_head, 1, 1)
            a.masked_fill_(enc_dec_padding_mask,-1e9)

        a = F.softmax(a,dim=-1)

        a = torch.matmul(a,v_)
        a = torch.reshape(a.permute(0, 2, 1, 3), shape=(q.shape[0],q.shape[1],q.shape[2]))
        a = self.dropout(a)
        a += queries
        a = self.LayerNormalization(a).to(DEVICE)
        return a


class FC(nn.Module):
    def __init__(self,input_channels,units=(2048,512)):
        super().__init__()
        self.input_channels = input_channels
        self.units = units
        self.layer1 = nn.Linear(self.input_channels,units[0])
        self.layer2 = nn.Linear(self.units[0],self.units[1])
        self.relu = nn.ReLU()
        self.LayerNormalization = nn.LayerNorm(d_model)


    def forward(self,x):
        outputs = self.layer1(x)
        outputs = self.relu(outputs)
        outputs = self.layer2(outputs)
        outputs += x
        outputs = self.LayerNormalization(outputs)
        return outputs.to(DEVICE)


class PositionalEncoding(nn.Module):
    def __init__(self, d_model=d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        x: [seq_len, batch_size, d_model]
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.mask_self_attention = MultiHeadAttention(units,numofhead,dropout_rate,True)
        self.fc = FC(d_model)


    def forward(self,inputs,padding_mask):
        outputs = self.mask_self_attention(inputs,inputs,inputs,padding_mask,None)
        outputs = self.fc(outputs)
        return outputs.to(DEVICE)



class DecoderLayer(nn.Module):
    def __init__(self):
        super(DecoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(units,numofhead,dropout_rate,mask=False)
        self.fc = FC(d_model)



    def forward(self,enc_outputs):
        outputs = self.fc(enc_outputs)
        return outputs.to(DEVICE)



class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(numofblock)])


    def forward(self,x):
        for layer in self.layers:
            x = layer(x)
        return x


class Encoder(nn.Module):
    def __init__(self,vocab_size):
        super(Encoder, self).__init__()
        self.pe = PositionalEncoding()
        self.embedding = TokenEmbedding(vocab_size,units)
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(numofblock)])


    def forward(self,inputs):
        outputs = self.embedding(inputs)
        outputs = self.pe(outputs.transpose(0, 1)).transpose(0, 1)

        padding_mask = get_padding_mask(inputs,inputs)
        for layer in self.layers:
            outputs = layer(outputs,padding_mask)
        return outputs



class CTG(nn.Module):
    def __init__(self,vocab_size):
        super(CTG, self).__init__()
        self.Encoder = Encoder(vocab_size)
        self.Decoder = Decoder()
        self.linear = nn.Linear(d_model,vocab_size)
        self.mean = nn.Linear(d_model,latent_dim)
        self.log_var = nn.Linear(d_model,latent_dim)

    def reparameterize(self,z_mean,z_log_var):
        std = torch.exp(0.5 * z_log_var)
        eps = torch.randn_like(std)
        return eps * std + z_mean


    def forward(self,x):
        enc_outputs= self.Encoder(x)

        z_mean = self.mean(enc_outputs)
        z_log_var = self.log_var(enc_outputs)

        z = self.reparameterize(z_mean,z_log_var)
        enc_outputs = self.Decoder(z)
        logits = self.linear(enc_outputs)

        logits = logits.view(-1, logits.size(-1))
        return logits,z_mean,z_log_var


model = CTG(vocab_size).to(DEVICE)
criterion = nn.CrossEntropyLoss(ignore_index=1)
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)


for epoch in tqdm.tqdm(range(epochs)):
    total = []
    for _,dec_inputs,dec_outputs in dataloader:

        dec_inputs,dec_outputs= dec_inputs.to(DEVICE),dec_outputs.to(DEVICE)

        outputs,z_mean,z_log_var = model(dec_inputs)

        normal_loss = criterion(outputs,dec_outputs.contiguous().view(-1))

        reconstruction_loss = F.cross_entropy(outputs,dec_outputs.contiguous().view(-1))
        kl_loss = torch.mean(0.5 * torch.sum(torch.exp(z_log_var) + z_mean ** 2 - 1. - z_log_var, 1))

        loss = normal_loss + reconstruction_loss + kl_loss
        optimizer.zero_grad()
        loss.backward()
        total.append(loss)
        optimizer.step()
    cur_loss = sum(total)/len(total)
    print(cur_loss)
    print(math.exp(cur_loss))


  5%|▌         | 1/20 [00:42<13:23, 42.31s/it]

tensor(10.6126, device='cuda:0', grad_fn=<DivBackward0>)
40644.60172941534


 10%|█         | 2/20 [01:24<12:35, 41.99s/it]

tensor(9.4371, device='cuda:0', grad_fn=<DivBackward0>)
12545.012760173635


 15%|█▌        | 3/20 [02:05<11:50, 41.79s/it]

tensor(9.0147, device='cuda:0', grad_fn=<DivBackward0>)
8223.415756655515


 20%|██        | 4/20 [02:45<10:58, 41.18s/it]

tensor(8.6417, device='cuda:0', grad_fn=<DivBackward0>)
5662.8448644029195


 25%|██▌       | 5/20 [03:29<10:29, 41.95s/it]

tensor(8.3797, device='cuda:0', grad_fn=<DivBackward0>)
4357.550535796576


 30%|███       | 6/20 [04:12<09:53, 42.43s/it]

tensor(8.2274, device='cuda:0', grad_fn=<DivBackward0>)
3741.9375511069543


 35%|███▌      | 7/20 [04:52<09:00, 41.61s/it]

tensor(8.0681, device='cuda:0', grad_fn=<DivBackward0>)
3191.1182357241064


 40%|████      | 8/20 [05:32<08:13, 41.15s/it]

tensor(7.9492, device='cuda:0', grad_fn=<DivBackward0>)
2833.2240697442367


 45%|████▌     | 9/20 [06:13<07:30, 40.96s/it]

tensor(7.8584, device='cuda:0', grad_fn=<DivBackward0>)
2587.425033673666


 50%|█████     | 10/20 [06:53<06:48, 40.82s/it]

tensor(7.7860, device='cuda:0', grad_fn=<DivBackward0>)
2406.7261599557783


 55%|█████▌    | 11/20 [07:33<06:05, 40.66s/it]

tensor(7.7132, device='cuda:0', grad_fn=<DivBackward0>)
2237.5998711865595


 60%|██████    | 12/20 [08:13<05:22, 40.37s/it]

tensor(7.6146, device='cuda:0', grad_fn=<DivBackward0>)
2027.5897322684073


 65%|██████▌   | 13/20 [08:54<04:43, 40.53s/it]

tensor(7.6103, device='cuda:0', grad_fn=<DivBackward0>)
2018.7981247652624


 70%|███████   | 14/20 [09:35<04:04, 40.74s/it]

tensor(7.5357, device='cuda:0', grad_fn=<DivBackward0>)
1873.814207384793


 75%|███████▌  | 15/20 [10:16<03:24, 40.81s/it]

tensor(7.5048, device='cuda:0', grad_fn=<DivBackward0>)
1816.6835540346658


 80%|████████  | 16/20 [10:57<02:43, 40.88s/it]

tensor(7.4524, device='cuda:0', grad_fn=<DivBackward0>)
1724.0503890296986


 85%|████████▌ | 17/20 [11:38<02:02, 40.73s/it]

tensor(7.4053, device='cuda:0', grad_fn=<DivBackward0>)
1644.6203633895284


 90%|█████████ | 18/20 [12:18<01:21, 40.65s/it]

tensor(7.3619, device='cuda:0', grad_fn=<DivBackward0>)
1574.8827316251268


 95%|█████████▌| 19/20 [13:03<00:41, 41.80s/it]

tensor(7.3122, device='cuda:0', grad_fn=<DivBackward0>)
1498.4302910866338


100%|██████████| 20/20 [13:43<00:00, 41.20s/it]

tensor(7.3045, device='cuda:0', grad_fn=<DivBackward0>)
1486.9537516699113





In [None]:

def greedy_decoder(model, start_symbol):
    """贪心编码
    For simplicity, a Greedy Decoder is Beam search when K=1. This is necessary for inference as we don't know the
    target sequence input. Therefore we try to generate the target input word by word, then feed it into the transformer.
    Starting Reference: http://nlp.seas.harvard.edu/2018/04/03/attention.html#greedy-decoding
    :param model: Transformer Model
    :param enc_input: The encoder input
    :param start_symbol: The start symbol. In this example it is 'S' which corresponds to index 4
    :return: The target input
    """
    inputs = torch.zeros(1, 0).long()
    terminal = False
    next_symbol = start_symbol
    while not terminal:
        # 预测阶段：inputs序列会一点点变长（每次添加一个新预测出来的单词）
        inputs = torch.cat([inputs.to(DEVICE), torch.tensor([[next_symbol]], dtype=inputs.dtype).to(DEVICE)],
                              -1)
        # print("inputs:")
        # print(inputs)
        dec_outputs = model.Encoder(inputs)
        dec_outputs = model.Decoder(dec_outputs)
        dec_outputs = model.linear(dec_outputs)

        prob = dec_outputs.squeeze(0).max(dim=-1, keepdim=False)[1]
        # print("prob:")
        # print(dataset.idx2enwords(prob))
        # 增量更新（我们希望重复单词预测结果是一样的）
        # 我们在预测是会选择性忽略重复的预测的词，只摘取最新预测的单词拼接到输入序列中
        next_word = prob.data[-1]  # 拿出当前预测的单词(数字)。我们用x'_t对应的输出z_t去预测下一个单词的概率，不用z_1,z_2..z_{t-1}
        # print(next_word)

        next_symbol = next_word
        # print(dataset.idx2en(next_word))
        if next_symbol == dataset.en_vocab["<eos>"]:
            terminal = True
        # print(next_word)

    # greedy_dec_predict = torch.cat(
    #     [inputs.to(device), torch.tensor([[next_symbol]], dtype=enc_input.dtype).to(device)],
    #     -1)
    greedy_dec_predict = inputs[:, 1:]
    return greedy_dec_predict

for i in range(20):
    greedy_dec_predict = greedy_decoder(model, start_symbol=dataset.en_vocab["<bos>"])
    # print(input[i], '->', greedy_dec_predict.squeeze())
    print(" ".join([dataset.idx2en(n.item()) for n in greedy_dec_predict.squeeze()]))


In [10]:

def get_sequence():
    s = dataset.words2idx("<bos>".split(),"en")
    s = s.unsqueeze(0).to(DEVICE)

    # print(s.shape)
    flag = True
    data = torch.tensor([]).long().to(DEVICE).unsqueeze(0)
    count = 0
    while flag:
        # print(s,data)
        s = torch.cat((s,data),dim=-1)
        dec_outputs,z_mean,z_log_var = model(s.to(DEVICE))
        prob = F.softmax(dec_outputs, dim=-1)
        # print(prob)
        prob = torch.multinomial(prob, num_samples=1)
        # print(prob)
        data = prob[-1].unsqueeze(0)
        # print(prob)
        # data = prob
#         print(data)
        count += 1
        if data == 3:
            flag = False
        if count == 20:
            flag = False
    # print()

    # print(prob)
#     print(s)
    # for i in prob:
    print(dataset.idx2enwords(s[-1]))
for i in range(10):
    get_sequence()

<bos> i'm very greene after him with the leg.
<bos> she doesn't have to stay with me here?
<bos> you're good yesterday.
<bos> can was that?
<bos> the list in tennis decisions
<bos> they work more and the dress.
<bos> he's roll in this light.
<bos> see your way for sleep is a skirt.
<bos> the book was often all first visit two to the secret.
<bos> we'd just long novel.


In [20]:

def get_sequence():
    s = dataset.words2idx("<bos> he".split(),"en")
    s = s.unsqueeze(0).to(DEVICE)

    # print(s.shape)
    flag = True
    data = torch.tensor([]).long().to(DEVICE).unsqueeze(0)
    inputs = torch.zeros(1, 0).long()
    count = 0
    while flag:

        s = torch.cat((s,data),dim=-1)
        dec_outputs,_,_ = model(s.to(DEVICE))

        prob = dec_outputs.squeeze(0).max(dim=-1, keepdim=False)[1]

        data = prob[-1].unsqueeze(0).unsqueeze(0)


#         print(data)
        count += 1
        if data == 3:
            flag = False
        if count == 20:
            flag = False
    # print()

    # print(prob)
    # for i in prob:
    print(dataset.idx2enwords(s[-1]))
for i in range(10):
    get_sequence()

<bos> he was a good books of the morning.
<bos> he was a good time.
<bos> he was a good books of the party.
<bos> he was a good time.
<bos> he was a good time.
<bos> he was a good time of the house.
<bos> he was a good time.
<bos> he was a good time for the house.
<bos> he was a good more and in the morning.
<bos> he was a good more years than the house.
