# 数据的预处理与导入

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy
from torch.autograd import Variable

import matplotlib.pyplot as plt

In [2]:
from torchtext import data,datasets
!python -m spacy download en
!python -m spacy download de

为 D:\Anaconda3\lib\site-packages\spacy\data\en <<===>> D:\Anaconda3\lib\site-packages\en_core_web_sm 创建的符号链接

    Linking successful
    D:\Anaconda3\lib\site-packages\en_core_web_sm -->
    D:\Anaconda3\lib\site-packages\spacy\data\en

    You can now load the model via spacy.load('en')

为 D:\Anaconda3\lib\site-packages\spacy\data\de <<===>> D:\Anaconda3\lib\site-packages\de_core_news_sm 创建的符号链接

    Linking successful
    D:\Anaconda3\lib\site-packages\de_core_news_sm -->
    D:\Anaconda3\lib\site-packages\spacy\data\de

    You can now load the model via spacy.load('de')



In [3]:
import spacy
spacy_de = spacy.load('de')
spacy_en = spacy.load('en')

In [4]:

def tokenize_de(text):
    return [tok.text for tok in spacy_de.tokenizer(text)]

def tokenize_en(text):
    return [tok.text for tok in spacy_en.tokenizer(text)]

BOS_WORD = '<s>'
EOS_WORD = '</s>'
BLANK_WORD = "<blank>"
SRC = data.Field(tokenize=tokenize_de, pad_token=BLANK_WORD)
#torchtext.data.Field : 用来定义字段的处理方法（文本字段，标签字段） 
TGT = data.Field(tokenize=tokenize_en, init_token = BOS_WORD, 
                 eos_token = EOS_WORD, pad_token=BLANK_WORD)

MAX_LEN = 100
train, val, test = datasets.IWSLT.splits(exts=('.de', '.en'), fields=(SRC, TGT), 
                                         filter_pred=lambda x: len(vars(x)['src']) <= MAX_LEN and 
                                         len(vars(x)['trg']) <= MAX_LEN)
#filter_pred（callable或None）仅使用filter_pred（example）为True的示例，或使用所有示例（如果为None）
MIN_FREQ = 1
SRC.build_vocab(train.src, min_freq=MIN_FREQ)
TGT.build_vocab(train.trg, min_freq=MIN_FREQ)

In [5]:
print(next(train.trg))
#train.src 是生成器
print(b for b in train)

['David', 'Gallo', ':', 'This', 'is', 'Bill', 'Lange', '.', 'I', "'m", 'Dave', 'Gallo', '.']
<generator object <genexpr> at 0x0000022A23977408>


In [6]:
SRC.vocab.freqs[',']

273475

###以上，实现了分词，建立词表

In [7]:
BATCH_SIZE = 4096
global max_src_in_batch, max_tgt_in_batch
def batch_size_fn(new, count, sofar):
    "Keep augmenting batch and calculate total number of tokens + padding."
    global max_src_in_batch, max_tgt_in_batch
    if count == 1:
        max_src_in_batch = 0
        max_tgt_in_batch = 0
    max_src_in_batch = max(max_src_in_batch,  len(new.src))
    max_tgt_in_batch = max(max_tgt_in_batch,  len(new.trg) + 2)
    src_elements = count * max_src_in_batch
    tgt_elements = count * max_tgt_in_batch
    return max(src_elements, tgt_elements)

In [8]:
#重写了Iterator的函数

class MyIterator(data.Iterator):
    def create_batches(self):
        if self.train:
            def pool(d, random_shuffler):
                for p in data.batch(d, self.batch_size * 100):
                    p_batch = data.batch(
                        sorted(p, key=self.sort_key),
                        self.batch_size, self.batch_size_fn)
                    for b in random_shuffler(list(p_batch)):
                        yield b
            self.batches = pool(self.data(), self.random_shuffler)
            
        else:
            self.batches = []
            for b in data.batch(self.data(), self.batch_size,
                                          self.batch_size_fn):
                self.batches.append(sorted(b, key=self.sort_key))

In [9]:
def rebatch(pad_idx, batch):
    "Fix order in torchtext to match ours"
    src, trg = batch.src.transpose(0, 1), batch.trg.transpose(0, 1)
    src_mask, trg_mask = make_std_mask(src, trg, pad_idx)
    return Batch(src, trg, src_mask, trg_mask, (trg[1:] != pad_idx).data.sum())

train_iter = MyIterator(train, batch_size=BATCH_SIZE, device=0,
                        repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
                        batch_size_fn=batch_size_fn, train=True)
valid_iter = MyIterator(val, batch_size=BATCH_SIZE, device=0,
                        repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
                        batch_size_fn=batch_size_fn, train=False)

The `device` argument should be set by using `torch.device` or passing a string as an argument. This behavior will be deprecated soon and currently defaults to cpu.
The `device` argument should be set by using `torch.device` or passing a string as an argument. This behavior will be deprecated soon and currently defaults to cpu.


In [10]:
#看一下train_iter里面都是什么
#for i, batch in enumerate(train_iter):
    #print(i,'-',batch)

0 - 
[torchtext.data.batch.Batch of size 341 from IWSLT]
	[.src]:[torch.LongTensor of size 10x341]
	[.trg]:[torch.LongTensor of size 12x341]
1 - 
[torchtext.data.batch.Batch of size 195 from IWSLT]
	[.src]:[torch.LongTensor of size 20x195]
	[.trg]:[torch.LongTensor of size 21x195]
2 - 
[torchtext.data.batch.Batch of size 163 from IWSLT]
	[.src]:[torch.LongTensor of size 16x163]
	[.trg]:[torch.LongTensor of size 25x163]
3 - 
[torchtext.data.batch.Batch of size 273 from IWSLT]
	[.src]:[torch.LongTensor of size 12x273]
	[.trg]:[torch.LongTensor of size 15x273]
4 - 
[torchtext.data.batch.Batch of size 141 from IWSLT]
	[.src]:[torch.LongTensor of size 12x141]
	[.trg]:[torch.LongTensor of size 29x141]
5 - 
[torchtext.data.batch.Batch of size 204 from IWSLT]
	[.src]:[torch.LongTensor of size 18x204]
	[.trg]:[torch.LongTensor of size 20x204]
6 - 
[torchtext.data.batch.Batch of size 163 from IWSLT]
	[.src]:[torch.LongTensor of size 23x163]
	[.trg]:[torch.LongTensor of size 25x163]
7 - 
[torchte

	[.trg]:[torch.LongTensor of size 24x170]
59 - 
[torchtext.data.batch.Batch of size 117 from IWSLT]
	[.src]:[torch.LongTensor of size 33x117]
	[.trg]:[torch.LongTensor of size 35x117]
60 - 
[torchtext.data.batch.Batch of size 80 from IWSLT]
	[.src]:[torch.LongTensor of size 41x80]
	[.trg]:[torch.LongTensor of size 51x80]
61 - 
[torchtext.data.batch.Batch of size 52 from IWSLT]
	[.src]:[torch.LongTensor of size 52x52]
	[.trg]:[torch.LongTensor of size 78x52]
62 - 
[torchtext.data.batch.Batch of size 409 from IWSLT]
	[.src]:[torch.LongTensor of size 8x409]
	[.trg]:[torch.LongTensor of size 10x409]
63 - 
[torchtext.data.batch.Batch of size 240 from IWSLT]
	[.src]:[torch.LongTensor of size 8x240]
	[.trg]:[torch.LongTensor of size 17x240]
64 - 
[torchtext.data.batch.Batch of size 157 from IWSLT]
	[.src]:[torch.LongTensor of size 21x157]
	[.trg]:[torch.LongTensor of size 26x157]
65 - 
[torchtext.data.batch.Batch of size 110 from IWSLT]
	[.src]:[torch.LongTensor of size 25x110]
	[.trg]:[torch

	[.trg]:[torch.LongTensor of size 15x273]
117 - 
[torchtext.data.batch.Batch of size 68 from IWSLT]
	[.src]:[torch.LongTensor of size 53x68]
	[.trg]:[torch.LongTensor of size 60x68]
118 - 
[torchtext.data.batch.Batch of size 83 from IWSLT]
	[.src]:[torch.LongTensor of size 39x83]
	[.trg]:[torch.LongTensor of size 49x83]
119 - 
[torchtext.data.batch.Batch of size 315 from IWSLT]
	[.src]:[torch.LongTensor of size 9x315]
	[.trg]:[torch.LongTensor of size 13x315]
120 - 
[torchtext.data.batch.Batch of size 204 from IWSLT]
	[.src]:[torch.LongTensor of size 15x204]
	[.trg]:[torch.LongTensor of size 20x204]
121 - 
[torchtext.data.batch.Batch of size 186 from IWSLT]
	[.src]:[torch.LongTensor of size 14x186]
	[.trg]:[torch.LongTensor of size 22x186]
122 - 
[torchtext.data.batch.Batch of size 409 from IWSLT]
	[.src]:[torch.LongTensor of size 9x409]
	[.trg]:[torch.LongTensor of size 10x409]
123 - 
[torchtext.data.batch.Batch of size 70 from IWSLT]
	[.src]:[torch.LongTensor of size 48x70]
	[.trg]:[

	[.trg]:[torch.LongTensor of size 55x74]
175 - 
[torchtext.data.batch.Batch of size 178 from IWSLT]
	[.src]:[torch.LongTensor of size 18x178]
	[.trg]:[torch.LongTensor of size 23x178]
176 - 
[torchtext.data.batch.Batch of size 170 from IWSLT]
	[.src]:[torch.LongTensor of size 18x170]
	[.trg]:[torch.LongTensor of size 24x170]
177 - 
[torchtext.data.batch.Batch of size 81 from IWSLT]
	[.src]:[torch.LongTensor of size 48x81]
	[.trg]:[torch.LongTensor of size 50x81]
178 - 
[torchtext.data.batch.Batch of size 63 from IWSLT]
	[.src]:[torch.LongTensor of size 41x63]
	[.trg]:[torch.LongTensor of size 65x63]
179 - 
[torchtext.data.batch.Batch of size 240 from IWSLT]
	[.src]:[torch.LongTensor of size 14x240]
	[.trg]:[torch.LongTensor of size 17x240]
180 - 
[torchtext.data.batch.Batch of size 105 from IWSLT]
	[.src]:[torch.LongTensor of size 35x105]
	[.trg]:[torch.LongTensor of size 39x105]
181 - 
[torchtext.data.batch.Batch of size 157 from IWSLT]
	[.src]:[torch.LongTensor of size 18x157]
	[.trg

	[.trg]:[torch.LongTensor of size 42x97]
233 - 
[torchtext.data.batch.Batch of size 422 from IWSLT]
	[.src]:[torch.LongTensor of size 4x422]
	[.trg]:[torch.LongTensor of size 9x422]
234 - 
[torchtext.data.batch.Batch of size 409 from IWSLT]
	[.src]:[torch.LongTensor of size 7x409]
	[.trg]:[torch.LongTensor of size 10x409]
235 - 
[torchtext.data.batch.Batch of size 107 from IWSLT]
	[.src]:[torch.LongTensor of size 27x107]
	[.trg]:[torch.LongTensor of size 38x107]
236 - 
[torchtext.data.batch.Batch of size 69 from IWSLT]
	[.src]:[torch.LongTensor of size 56x69]
	[.trg]:[torch.LongTensor of size 59x69]
237 - 
[torchtext.data.batch.Batch of size 78 from IWSLT]
	[.src]:[torch.LongTensor of size 46x78]
	[.trg]:[torch.LongTensor of size 52x78]
238 - 
[torchtext.data.batch.Batch of size 186 from IWSLT]
	[.src]:[torch.LongTensor of size 22x186]
	[.trg]:[torch.LongTensor of size 20x186]
239 - 
[torchtext.data.batch.Batch of size 102 from IWSLT]
	[.src]:[torch.LongTensor of size 31x102]
	[.trg]:[

	[.trg]:[torch.LongTensor of size 42x97]
291 - 
[torchtext.data.batch.Batch of size 128 from IWSLT]
	[.src]:[torch.LongTensor of size 32x128]
	[.trg]:[torch.LongTensor of size 31x128]
292 - 
[torchtext.data.batch.Batch of size 107 from IWSLT]
	[.src]:[torch.LongTensor of size 37x107]
	[.trg]:[torch.LongTensor of size 38x107]
293 - 
[torchtext.data.batch.Batch of size 78 from IWSLT]
	[.src]:[torch.LongTensor of size 37x78]
	[.trg]:[torch.LongTensor of size 52x78]
294 - 
[torchtext.data.batch.Batch of size 83 from IWSLT]
	[.src]:[torch.LongTensor of size 37x83]
	[.trg]:[torch.LongTensor of size 49x83]
295 - 
[torchtext.data.batch.Batch of size 292 from IWSLT]
	[.src]:[torch.LongTensor of size 12x292]
	[.trg]:[torch.LongTensor of size 14x292]
296 - 
[torchtext.data.batch.Batch of size 204 from IWSLT]
	[.src]:[torch.LongTensor of size 18x204]
	[.trg]:[torch.LongTensor of size 20x204]
297 - 
[torchtext.data.batch.Batch of size 60 from IWSLT]
	[.src]:[torch.LongTensor of size 57x60]
	[.trg]:

	[.trg]:[torch.LongTensor of size 16x256]
349 - 
[torchtext.data.batch.Batch of size 315 from IWSLT]
	[.src]:[torch.LongTensor of size 11x315]
	[.trg]:[torch.LongTensor of size 13x315]
350 - 
[torchtext.data.batch.Batch of size 157 from IWSLT]
	[.src]:[torch.LongTensor of size 22x157]
	[.trg]:[torch.LongTensor of size 26x157]
351 - 
[torchtext.data.batch.Batch of size 292 from IWSLT]
	[.src]:[torch.LongTensor of size 11x292]
	[.trg]:[torch.LongTensor of size 14x292]
352 - 
[torchtext.data.batch.Batch of size 136 from IWSLT]
	[.src]:[torch.LongTensor of size 23x136]
	[.trg]:[torch.LongTensor of size 30x136]
353 - 
[torchtext.data.batch.Batch of size 59 from IWSLT]
	[.src]:[torch.LongTensor of size 62x59]
	[.trg]:[torch.LongTensor of size 69x59]
354 - 
[torchtext.data.batch.Batch of size 132 from IWSLT]
	[.src]:[torch.LongTensor of size 28x132]
	[.trg]:[torch.LongTensor of size 31x132]
355 - 
[torchtext.data.batch.Batch of size 204 from IWSLT]
	[.src]:[torch.LongTensor of size 17x204]
	[

	[.trg]:[torch.LongTensor of size 101x40]
407 - 
[torchtext.data.batch.Batch of size 186 from IWSLT]
	[.src]:[torch.LongTensor of size 8x186]
	[.trg]:[torch.LongTensor of size 22x186]
408 - 
[torchtext.data.batch.Batch of size 40 from IWSLT]
	[.src]:[torch.LongTensor of size 63x40]
	[.trg]:[torch.LongTensor of size 101x40]
409 - 
[torchtext.data.batch.Batch of size 91 from IWSLT]
	[.src]:[torch.LongTensor of size 45x91]
	[.trg]:[torch.LongTensor of size 43x91]
410 - 
[torchtext.data.batch.Batch of size 107 from IWSLT]
	[.src]:[torch.LongTensor of size 17x107]
	[.trg]:[torch.LongTensor of size 38x107]
411 - 
[torchtext.data.batch.Batch of size 83 from IWSLT]
	[.src]:[torch.LongTensor of size 46x83]
	[.trg]:[torch.LongTensor of size 49x83]
412 - 
[torchtext.data.batch.Batch of size 87 from IWSLT]
	[.src]:[torch.LongTensor of size 41x87]
	[.trg]:[torch.LongTensor of size 47x87]
413 - 
[torchtext.data.batch.Batch of size 372 from IWSLT]
	[.src]:[torch.LongTensor of size 8x372]
	[.trg]:[tor

	[.trg]:[torch.LongTensor of size 42x97]
465 - 
[torchtext.data.batch.Batch of size 85 from IWSLT]
	[.src]:[torch.LongTensor of size 36x85]
	[.trg]:[torch.LongTensor of size 48x85]
466 - 
[torchtext.data.batch.Batch of size 178 from IWSLT]
	[.src]:[torch.LongTensor of size 22x178]
	[.trg]:[torch.LongTensor of size 23x178]
467 - 
[torchtext.data.batch.Batch of size 178 from IWSLT]
	[.src]:[torch.LongTensor of size 17x178]
	[.trg]:[torch.LongTensor of size 23x178]
468 - 
[torchtext.data.batch.Batch of size 105 from IWSLT]
	[.src]:[torch.LongTensor of size 36x105]
	[.trg]:[torch.LongTensor of size 39x105]
469 - 
[torchtext.data.batch.Batch of size 163 from IWSLT]
	[.src]:[torch.LongTensor of size 19x163]
	[.trg]:[torch.LongTensor of size 25x163]
470 - 
[torchtext.data.batch.Batch of size 341 from IWSLT]
	[.src]:[torch.LongTensor of size 11x341]
	[.trg]:[torch.LongTensor of size 12x341]
471 - 
[torchtext.data.batch.Batch of size 97 from IWSLT]
	[.src]:[torch.LongTensor of size 33x97]
	[.tr

	[.trg]:[torch.LongTensor of size 8x512]
523 - 
[torchtext.data.batch.Batch of size 141 from IWSLT]
	[.src]:[torch.LongTensor of size 27x141]
	[.trg]:[torch.LongTensor of size 29x141]
524 - 
[torchtext.data.batch.Batch of size 157 from IWSLT]
	[.src]:[torch.LongTensor of size 22x157]
	[.trg]:[torch.LongTensor of size 26x157]
525 - 
[torchtext.data.batch.Batch of size 186 from IWSLT]
	[.src]:[torch.LongTensor of size 17x186]
	[.trg]:[torch.LongTensor of size 22x186]
526 - 
[torchtext.data.batch.Batch of size 91 from IWSLT]
	[.src]:[torch.LongTensor of size 42x91]
	[.trg]:[torch.LongTensor of size 45x91]
527 - 
[torchtext.data.batch.Batch of size 93 from IWSLT]
	[.src]:[torch.LongTensor of size 44x93]
	[.trg]:[torch.LongTensor of size 43x93]
528 - 
[torchtext.data.batch.Batch of size 341 from IWSLT]
	[.src]:[torch.LongTensor of size 8x341]
	[.trg]:[torch.LongTensor of size 12x341]
529 - 
[torchtext.data.batch.Batch of size 120 from IWSLT]
	[.src]:[torch.LongTensor of size 34x120]
	[.trg]

	[.trg]:[torch.LongTensor of size 44x93]
581 - 
[torchtext.data.batch.Batch of size 120 from IWSLT]
	[.src]:[torch.LongTensor of size 27x120]
	[.trg]:[torch.LongTensor of size 34x120]
582 - 
[torchtext.data.batch.Batch of size 47 from IWSLT]
	[.src]:[torch.LongTensor of size 57x47]
	[.trg]:[torch.LongTensor of size 86x47]
583 - 
[torchtext.data.batch.Batch of size 128 from IWSLT]
	[.src]:[torch.LongTensor of size 29x128]
	[.trg]:[torch.LongTensor of size 32x128]
584 - 
[torchtext.data.batch.Batch of size 113 from IWSLT]
	[.src]:[torch.LongTensor of size 36x113]
	[.trg]:[torch.LongTensor of size 35x113]
585 - 
[torchtext.data.batch.Batch of size 157 from IWSLT]
	[.src]:[torch.LongTensor of size 22x157]
	[.trg]:[torch.LongTensor of size 26x157]
586 - 
[torchtext.data.batch.Batch of size 372 from IWSLT]
	[.src]:[torch.LongTensor of size 9x372]
	[.trg]:[torch.LongTensor of size 11x372]
587 - 
[torchtext.data.batch.Batch of size 62 from IWSLT]
	[.src]:[torch.LongTensor of size 51x62]
	[.trg

	[.trg]:[torch.LongTensor of size 21x195]
639 - 
[torchtext.data.batch.Batch of size 124 from IWSLT]
	[.src]:[torch.LongTensor of size 24x124]
	[.trg]:[torch.LongTensor of size 33x124]
640 - 
[torchtext.data.batch.Batch of size 136 from IWSLT]
	[.src]:[torch.LongTensor of size 26x136]
	[.trg]:[torch.LongTensor of size 30x136]
641 - 
[torchtext.data.batch.Batch of size 341 from IWSLT]
	[.src]:[torch.LongTensor of size 6x341]
	[.trg]:[torch.LongTensor of size 12x341]
642 - 
[torchtext.data.batch.Batch of size 124 from IWSLT]
	[.src]:[torch.LongTensor of size 23x124]
	[.trg]:[torch.LongTensor of size 33x124]
643 - 
[torchtext.data.batch.Batch of size 170 from IWSLT]
	[.src]:[torch.LongTensor of size 17x170]
	[.trg]:[torch.LongTensor of size 24x170]
644 - 
[torchtext.data.batch.Batch of size 46 from IWSLT]
	[.src]:[torch.LongTensor of size 54x46]
	[.trg]:[torch.LongTensor of size 88x46]
645 - 
[torchtext.data.batch.Batch of size 120 from IWSLT]
	[.src]:[torch.LongTensor of size 33x120]
	[.

	[.trg]:[torch.LongTensor of size 12x341]
697 - 
[torchtext.data.batch.Batch of size 585 from IWSLT]
	[.src]:[torch.LongTensor of size 4x585]
	[.trg]:[torch.LongTensor of size 7x585]
698 - 
[torchtext.data.batch.Batch of size 163 from IWSLT]
	[.src]:[torch.LongTensor of size 15x163]
	[.trg]:[torch.LongTensor of size 24x163]
699 - 
[torchtext.data.batch.Batch of size 102 from IWSLT]
	[.src]:[torch.LongTensor of size 30x102]
	[.trg]:[torch.LongTensor of size 40x102]
700 - 
[torchtext.data.batch.Batch of size 273 from IWSLT]
	[.src]:[torch.LongTensor of size 14x273]
	[.trg]:[torch.LongTensor of size 15x273]
701 - 
[torchtext.data.batch.Batch of size 215 from IWSLT]
	[.src]:[torch.LongTensor of size 13x215]
	[.trg]:[torch.LongTensor of size 19x215]
702 - 
[torchtext.data.batch.Batch of size 372 from IWSLT]
	[.src]:[torch.LongTensor of size 10x372]
	[.trg]:[torch.LongTensor of size 11x372]
703 - 
[torchtext.data.batch.Batch of size 124 from IWSLT]
	[.src]:[torch.LongTensor of size 20x124]
	

	[.trg]:[torch.LongTensor of size 99x41]
755 - 
[torchtext.data.batch.Batch of size 256 from IWSLT]
	[.src]:[torch.LongTensor of size 15x256]
	[.trg]:[torch.LongTensor of size 16x256]
756 - 
[torchtext.data.batch.Batch of size 136 from IWSLT]
	[.src]:[torch.LongTensor of size 30x136]
	[.trg]:[torch.LongTensor of size 28x136]
757 - 
[torchtext.data.batch.Batch of size 110 from IWSLT]
	[.src]:[torch.LongTensor of size 35x110]
	[.trg]:[torch.LongTensor of size 37x110]
758 - 
[torchtext.data.batch.Batch of size 157 from IWSLT]
	[.src]:[torch.LongTensor of size 23x157]
	[.trg]:[torch.LongTensor of size 26x157]
759 - 
[torchtext.data.batch.Batch of size 227 from IWSLT]
	[.src]:[torch.LongTensor of size 11x227]
	[.trg]:[torch.LongTensor of size 18x227]
760 - 
[torchtext.data.batch.Batch of size 75 from IWSLT]
	[.src]:[torch.LongTensor of size 42x75]
	[.trg]:[torch.LongTensor of size 54x75]
761 - 
[torchtext.data.batch.Batch of size 409 from IWSLT]
	[.src]:[torch.LongTensor of size 10x409]
	[.

	[.trg]:[torch.LongTensor of size 76x53]
813 - 
[torchtext.data.batch.Batch of size 315 from IWSLT]
	[.src]:[torch.LongTensor of size 12x315]
	[.trg]:[torch.LongTensor of size 13x315]
814 - 
[torchtext.data.batch.Batch of size 124 from IWSLT]
	[.src]:[torch.LongTensor of size 32x124]
	[.trg]:[torch.LongTensor of size 33x124]
815 - 
[torchtext.data.batch.Batch of size 157 from IWSLT]
	[.src]:[torch.LongTensor of size 14x157]
	[.trg]:[torch.LongTensor of size 26x157]
816 - 
[torchtext.data.batch.Batch of size 183 from IWSLT]
	[.src]:[torch.LongTensor of size 13x183]
	[.trg]:[torch.LongTensor of size 22x183]
817 - 
[torchtext.data.batch.Batch of size 204 from IWSLT]
	[.src]:[torch.LongTensor of size 16x204]
	[.trg]:[torch.LongTensor of size 20x204]
818 - 
[torchtext.data.batch.Batch of size 124 from IWSLT]
	[.src]:[torch.LongTensor of size 33x124]
	[.trg]:[torch.LongTensor of size 31x124]
819 - 
[torchtext.data.batch.Batch of size 146 from IWSLT]
	[.src]:[torch.LongTensor of size 24x146]


	[.trg]:[torch.LongTensor of size 28x146]
871 - 
[torchtext.data.batch.Batch of size 341 from IWSLT]
	[.src]:[torch.LongTensor of size 10x341]
	[.trg]:[torch.LongTensor of size 12x341]
872 - 
[torchtext.data.batch.Batch of size 58 from IWSLT]
	[.src]:[torch.LongTensor of size 46x58]
	[.trg]:[torch.LongTensor of size 69x58]
873 - 
[torchtext.data.batch.Batch of size 133 from IWSLT]
	[.src]:[torch.LongTensor of size 19x133]
	[.trg]:[torch.LongTensor of size 30x133]
874 - 
[torchtext.data.batch.Batch of size 178 from IWSLT]
	[.src]:[torch.LongTensor of size 20x178]
	[.trg]:[torch.LongTensor of size 23x178]
875 - 
[torchtext.data.batch.Batch of size 99 from IWSLT]
	[.src]:[torch.LongTensor of size 39x99]
	[.trg]:[torch.LongTensor of size 41x99]
876 - 
[torchtext.data.batch.Batch of size 240 from IWSLT]
	[.src]:[torch.LongTensor of size 17x240]
	[.trg]:[torch.LongTensor of size 16x240]
877 - 
[torchtext.data.batch.Batch of size 227 from IWSLT]
	[.src]:[torch.LongTensor of size 17x227]
	[.tr

	[.trg]:[torch.LongTensor of size 24x170]
929 - 
[torchtext.data.batch.Batch of size 105 from IWSLT]
	[.src]:[torch.LongTensor of size 33x105]
	[.trg]:[torch.LongTensor of size 39x105]
930 - 
[torchtext.data.batch.Batch of size 178 from IWSLT]
	[.src]:[torch.LongTensor of size 21x178]
	[.trg]:[torch.LongTensor of size 23x178]
931 - 
[torchtext.data.batch.Batch of size 455 from IWSLT]
	[.src]:[torch.LongTensor of size 5x455]
	[.trg]:[torch.LongTensor of size 9x455]
932 - 
[torchtext.data.batch.Batch of size 120 from IWSLT]
	[.src]:[torch.LongTensor of size 30x120]
	[.trg]:[torch.LongTensor of size 34x120]
933 - 
[torchtext.data.batch.Batch of size 146 from IWSLT]
	[.src]:[torch.LongTensor of size 22x146]
	[.trg]:[torch.LongTensor of size 28x146]
934 - 
[torchtext.data.batch.Batch of size 141 from IWSLT]
	[.src]:[torch.LongTensor of size 18x141]
	[.trg]:[torch.LongTensor of size 29x141]
935 - 
[torchtext.data.batch.Batch of size 341 from IWSLT]
	[.src]:[torch.LongTensor of size 8x341]
	[

	[.trg]:[torch.LongTensor of size 30x136]
987 - 
[torchtext.data.batch.Batch of size 170 from IWSLT]
	[.src]:[torch.LongTensor of size 17x170]
	[.trg]:[torch.LongTensor of size 24x170]
988 - 
[torchtext.data.batch.Batch of size 315 from IWSLT]
	[.src]:[torch.LongTensor of size 11x315]
	[.trg]:[torch.LongTensor of size 13x315]
989 - 
[torchtext.data.batch.Batch of size 124 from IWSLT]
	[.src]:[torch.LongTensor of size 24x124]
	[.trg]:[torch.LongTensor of size 33x124]
990 - 
[torchtext.data.batch.Batch of size 292 from IWSLT]
	[.src]:[torch.LongTensor of size 11x292]
	[.trg]:[torch.LongTensor of size 14x292]
991 - 
[torchtext.data.batch.Batch of size 170 from IWSLT]
	[.src]:[torch.LongTensor of size 21x170]
	[.trg]:[torch.LongTensor of size 24x170]
992 - 
[torchtext.data.batch.Batch of size 93 from IWSLT]
	[.src]:[torch.LongTensor of size 32x93]
	[.trg]:[torch.LongTensor of size 44x93]
993 - 
[torchtext.data.batch.Batch of size 132 from IWSLT]
	[.src]:[torch.LongTensor of size 27x132]
	[

	[.trg]:[torch.LongTensor of size 17x240]
1045 - 
[torchtext.data.batch.Batch of size 227 from IWSLT]
	[.src]:[torch.LongTensor of size 14x227]
	[.trg]:[torch.LongTensor of size 18x227]
1046 - 
[torchtext.data.batch.Batch of size 163 from IWSLT]
	[.src]:[torch.LongTensor of size 24x163]
	[.trg]:[torch.LongTensor of size 25x163]
1047 - 
[torchtext.data.batch.Batch of size 178 from IWSLT]
	[.src]:[torch.LongTensor of size 23x178]
	[.trg]:[torch.LongTensor of size 23x178]
1048 - 
[torchtext.data.batch.Batch of size 117 from IWSLT]
	[.src]:[torch.LongTensor of size 29x117]
	[.trg]:[torch.LongTensor of size 35x117]
1049 - 
[torchtext.data.batch.Batch of size 292 from IWSLT]
	[.src]:[torch.LongTensor of size 8x292]
	[.trg]:[torch.LongTensor of size 14x292]
1050 - 
[torchtext.data.batch.Batch of size 292 from IWSLT]
	[.src]:[torch.LongTensor of size 9x292]
	[.trg]:[torch.LongTensor of size 14x292]
1051 - 
[torchtext.data.batch.Batch of size 186 from IWSLT]
	[.src]:[torch.LongTensor of size 17

In [11]:
class Batch:
    def __init__(self, src, trg, src_mask, trg_mask, ntokens):
        self.src = src
        self.trg = trg
        self.src_mask = src_mask
        self.trg_mask = trg_mask
        self.ntokens = ntokens
        
#print(next(b for b in train_iter)[0])

In [12]:
pad_idx = TGT.vocab.stoi["<blank>"]
def make_std_mask(src, tgt, pad):
    src_mask = (src != pad).unsqueeze(-2)
    tgt_mask = (tgt != pad).unsqueeze(-2)
    tgt_mask = tgt_mask & Variable(subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))
    #type_as 将张量转换为给定类型的张量
    return src_mask, tgt_mask
def subsequent_mask(size):
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k = 1).astype('uint8')
    #Return a copy of a matrix with the elements below the k-th diagonal zeroed.
    return torch.from_numpy(subsequent_mask) == 0

In [13]:
#探究掩码
for i, batch in enumerate(train_iter):
    src, trg = batch.src.transpose(0, 1), batch.trg.transpose(0, 1)
    src_mask = (src != pad_idx).unsqueeze(-2)
    
    #unsqueeze 增加一维
    tgt_mask = (trg != pad_idx).unsqueeze(-2)
   #print(Variable(subsequent_mask(trg.size(-1)).type_as(tgt_mask.data)).size())
    
    tgt_mask = tgt_mask & Variable(subsequent_mask(trg.size(-1)).type_as(tgt_mask.data))
    #print(tgt_mask)
    #print((trg[1:] != pad_idx).data.sum())
    #统计trg中除bos，以及blank外词的个数
    print(trg[1:].size())
    print(trg.size())
    print('--')
    #print(src_mask)

torch.Size([39, 102])
torch.Size([40, 102])
--
torch.Size([109, 37])
torch.Size([110, 37])
--
torch.Size([203, 20])
torch.Size([204, 20])
--
torch.Size([203, 20])
torch.Size([204, 20])
--
torch.Size([123, 33])
torch.Size([124, 33])
--
torch.Size([109, 37])
torch.Size([110, 37])
--
torch.Size([371, 11])
torch.Size([372, 11])
--
torch.Size([584, 7])
torch.Size([585, 7])
--
torch.Size([185, 22])
torch.Size([186, 22])
--
torch.Size([113, 35])
torch.Size([114, 35])
--
torch.Size([255, 16])
torch.Size([256, 16])
--
torch.Size([74, 54])
torch.Size([75, 54])
--
torch.Size([371, 11])
torch.Size([372, 11])
--
torch.Size([156, 26])
torch.Size([157, 26])
--
torch.Size([194, 21])
torch.Size([195, 21])
--
torch.Size([194, 21])
torch.Size([195, 21])
--
torch.Size([314, 13])
torch.Size([315, 13])
--
torch.Size([214, 19])
torch.Size([215, 19])
--
torch.Size([214, 18])
torch.Size([215, 18])
--
torch.Size([48, 81])
torch.Size([49, 81])
--
torch.Size([88, 46])
torch.Size([89, 46])
--
torch.Size([123, 33])

--
torch.Size([123, 33])
torch.Size([124, 33])
--
torch.Size([109, 37])
torch.Size([110, 37])
--
torch.Size([74, 54])
torch.Size([75, 54])
--
torch.Size([194, 21])
torch.Size([195, 21])
--
torch.Size([511, 8])
torch.Size([512, 8])
--
torch.Size([116, 35])
torch.Size([117, 35])
--
torch.Size([140, 29])
torch.Size([141, 29])
--
torch.Size([50, 80])
torch.Size([51, 80])
--
torch.Size([408, 10])
torch.Size([409, 10])
--
torch.Size([203, 20])
torch.Size([204, 20])
--
torch.Size([194, 21])
torch.Size([195, 21])
--
torch.Size([79, 51])
torch.Size([80, 51])
--
torch.Size([116, 35])
torch.Size([117, 35])
--
torch.Size([255, 16])
torch.Size([256, 16])
--
torch.Size([272, 15])
torch.Size([273, 15])
--
torch.Size([162, 25])
torch.Size([163, 25])
--
torch.Size([131, 31])
torch.Size([132, 31])
--
torch.Size([226, 18])
torch.Size([227, 18])
--
torch.Size([314, 12])
torch.Size([315, 12])
--
torch.Size([255, 16])
torch.Size([256, 16])
--
torch.Size([203, 20])
torch.Size([204, 20])
--
torch.Size([84, 48

torch.Size([340, 12])
torch.Size([341, 12])
--
torch.Size([169, 24])
torch.Size([170, 24])
--
torch.Size([42, 95])
torch.Size([43, 95])
--
torch.Size([169, 24])
torch.Size([170, 24])
--
torch.Size([408, 10])
torch.Size([409, 10])
--
torch.Size([214, 19])
torch.Size([215, 19])
--
torch.Size([86, 47])
torch.Size([87, 47])
--
torch.Size([119, 34])
torch.Size([120, 34])
--
torch.Size([272, 15])
torch.Size([273, 15])
--
torch.Size([340, 12])
torch.Size([341, 12])
--
torch.Size([127, 32])
torch.Size([128, 32])
--
torch.Size([272, 15])
torch.Size([273, 15])
--
torch.Size([80, 50])
torch.Size([81, 50])
--
torch.Size([156, 26])
torch.Size([157, 26])
--
torch.Size([94, 43])
torch.Size([95, 43])
--
torch.Size([454, 9])
torch.Size([455, 9])
--
torch.Size([226, 18])
torch.Size([227, 18])
--
torch.Size([194, 21])
torch.Size([195, 21])
--
torch.Size([255, 16])
torch.Size([256, 16])
--
torch.Size([150, 27])
torch.Size([151, 27])
--
torch.Size([127, 32])
torch.Size([128, 32])
--
torch.Size([214, 19])
t

torch.Size([104, 39])
torch.Size([105, 39])
--
torch.Size([185, 22])
torch.Size([186, 22])
--
torch.Size([96, 42])
torch.Size([97, 42])
--
torch.Size([90, 45])
torch.Size([91, 45])
--
torch.Size([311, 13])
torch.Size([312, 13])
--
torch.Size([203, 20])
torch.Size([204, 20])
--
torch.Size([291, 14])
torch.Size([292, 14])
--
torch.Size([169, 24])
torch.Size([170, 24])
--
torch.Size([255, 16])
torch.Size([256, 16])
--
torch.Size([109, 37])
torch.Size([110, 37])
--
torch.Size([185, 21])
torch.Size([186, 21])
--
torch.Size([84, 48])
torch.Size([85, 48])
--
torch.Size([76, 53])
torch.Size([77, 53])
--
torch.Size([135, 30])
torch.Size([136, 30])
--
torch.Size([106, 38])
torch.Size([107, 38])
--
torch.Size([371, 11])
torch.Size([372, 11])
--
torch.Size([255, 16])
torch.Size([256, 16])
--
torch.Size([314, 13])
torch.Size([315, 13])
--
torch.Size([140, 29])
torch.Size([141, 29])
--
torch.Size([255, 16])
torch.Size([256, 16])
--
torch.Size([131, 31])
torch.Size([132, 31])
--
torch.Size([112, 36])

In [36]:
print(trg[:, :-1])
print(trg)
#为什么model.forward中,带入trg的句子都把最后一个词给去掉
#删除的这个词必定是ENS 或者是 pad

tensor([[   2,  111,  112,  ..., 5116,    5,    3],
        [   2,   58,   35,  ...,  465,    5,    3],
        [   2,   12,   25,  ...,  102,    4,    3],
        ...,
        [   2,   19,   13,  ...,    6,  891,    5],
        [   2,  968,   42,  ...,  379,  215,    5],
        [   2,   19,   20,  ...,    8,  125,    5]])
tensor([[  2, 111, 112,  ...,   5,   3,   1],
        [  2,  58,  35,  ...,   5,   3,   1],
        [  2,  12,  25,  ...,   4,   3,   1],
        ...,
        [  2,  19,  13,  ..., 891,   5,   3],
        [  2, 968,  42,  ..., 215,   5,   3],
        [  2,  19,  20,  ..., 125,   5,   3]])


In [37]:
print(tgt_mask.size())
print(tgt_mask[:, :-1, :-1].size())

torch.Size([141, 29, 29])
torch.Size([141, 28, 28])


## 建立整个模型

In [14]:
#定义标准的编码器-解码器框架
class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator
        
    def forward(self, src, tgt, src_mask, tgt_mask):
        "take in and process masked src and tgt sequences"
        memory = self.encoder(self.src_embed(src), src_mask)
        output = self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
        return output

In [15]:
def clones(module, N):
    "Produce N identical layers"
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

In [16]:
class Encoder(nn.Module):
    "core encoder is a stack of N layers"
    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)
    
    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

In [17]:
class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps
        
    def forward(self, x):
        mean = x.mean(-1, keepdim=True)#求mean最后一个维度的均值，并保持维度不变
        std = x.std(-1, keepdim=True)
        return self.a_2*(x - mean)/(std + self.eps) + self.b_2
#归一化层

In [18]:
class SublayerConnection(nn.Module):
    """
    A residual connection followed by a layer norm.
    """
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))

In [19]:
class EncoderLayer(nn.Module):
    "Encoder is made up of two sublayers, self-attn and feed forward (defined below)"
    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        "Follow Figure 1 (left) for connections."
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[1](x, self.feed_forward)

In [20]:
class Decoder(nn.Module):
    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)
        
    def forward(self, x, memory, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask)
        return self.norm(x)

In [21]:
class DecoderLayer(nn.Module):
    def __init__(self, size,self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)
        
    def forward(self, x,memory, src_mask, tgt_mask):
        m = memory
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
        return self.sublayer[2](x, self.feed_forward)

In [22]:
def subsequent_mask(size):
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k = 1).astype('uint8')
    #Return a copy of a matrix with the elements below the k-th diagonal zeroed.
    return torch.from_numpy(subsequent_mask) == 0

In [23]:
def attention(query, key, value, mask = None, dropout = 0.0):
    "scaled dot product attention"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2,-1))/math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0,-1e9)
    p_attn = F.softmax(scores, dim = -1)
   #print('--')
    #print(p_attn)
    p_attn = F.dropout(p_attn, p = dropout)
    return torch.matmul(p_attn, value), p_attn


In [24]:
class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout = 0.1):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        self.d_k = d_model//h
        self.h = h
        self.p = dropout
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        
    def forward(self, query, key, value, mask = None):
        if mask is not None:
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)
        #1）do all the linear projections in batch from d_model => h x d_k
        query, key, value = [l(x).view(nbatches, -1, self.h,self.d_k).transpose(1, 2) for l, x in 
                             zip(self.linears, (query, key, value))]
        # 依次取出每一个数组的元素然后进行组合
        #2) Apply attention on all the projected vectors in batch
        x, self.attn = attention (query, key, value, mask = mask, dropout = self.p)
        #3)"Concat" using a view and apply a final linear
        x = x.transpose(1,2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)

In [25]:
#逐位置的前馈网络  编码器和解码器模块最后都包含一个全连接的前馈网络，独立相同的应用于每一个位置

class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout = 0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model,d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        return self.w_2(self.dropout(F.relu(self.w_1(x))))

In [26]:
class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        super(Embeddings,self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model
        
    def forward(self, x):
        return self.lut(x) * math.sqrt(self.d_model)

In [27]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len = 5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p = dropout)
        
        pe = torch.zeros(max_len, d_model)
        
        position = torch.arange(0.,max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0., d_model, 2) * 
                            -(math.log(10000.0)/d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
    def forward(self, x):
        x = x + Variable(self.pe[:, x.size(1)], requires_grad = False)
        return self.dropout(x)

In [28]:
class Generator(nn.Module):
    def __init__(self, d_model, vocab):
        super(Generator,self).__init__()
        self.proj = nn.Linear(d_model, vocab)
    def forward(self, x):
        return F.log_softmax(self.proj(x), dim = -1)

In [29]:
#定义模型整体,将以上模块组合
def make_model(src_vocab, tgt_vocab, N = 6, d_model = 512, d_ff = 2048, h = 8, dropout = 0.1):
    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model, dropout)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)
    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
        nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
        nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
        Generator(d_model, tgt_vocab)
    )
    for p in model.parameters():
        if p.dim()>1:
            nn.init.xavier_uniform(p)
    return model

In [30]:
tmp_model = make_model(10, 10, 1)

  app.launch_new_instance()


## 训练

In [31]:
def train_epoch(train_iter, model, criterion, opt, transpose=False):
    model.train()
    for i, batch in enumerate(train_iter):
        src, trg, src_mask, trg_mask = \
            batch.src, batch.trg, batch.src_mask, batch.trg_mask
        out = model.forward(src, trg[:, :-1], src_mask, trg_mask[:, :-1, :-1])
        loss = loss_backprop(model.generator, criterion, out, trg[:, 1:], batch.ntokens) 
                        
        model_opt.step()
        model_opt.optimizer.zero_grad()
        if i % 10 == 1:
            print(i, loss, model_opt._rate)

In [32]:
def valid_epoch(valid_iter, model, criterion, transpose=False):
    model.test()
    total = 0
    for batch in valid_iter:
        src, trg, src_mask, trg_mask = \
            batch.src, batch.trg, batch.src_mask, batch.trg_mask
        out = model.forward(src, trg[:, :-1], src_mask, trg_mask[:, :-1, :-1])
        loss = loss_backprop(model.generator, criterion, out, trg[:, 1:], batch.ntokens)