# HW 3: Neural Machine Translation

In this homework you will build a full neural machine translation system using an attention-based encoder-decoder network to translate from German to English. The encoder-decoder network with attention forms the backbone of many current text generation systems. See [Neural Machine Translation and Sequence-to-sequence Models: A Tutorial](https://arxiv.org/pdf/1703.01619.pdf) for an excellent tutorial that also contains many modern advances.

## Goals


1. Build a non-attentional baseline model (pure seq2seq as in [ref](https://papers.nips.cc/paper/5346-sequence-to-sequence-learning-with-neural-networks.pdf)). 
2. Incorporate attention into the baseline model ([ref](https://arxiv.org/abs/1409.0473) but with dot-product attention as in class notes).
3. Implement beam search: review/tutorial [here](http://www.phontron.com/slides/nlp-programming-en-13-search.pdf)
4. Visualize the attention distribution for a few examples. 

Consult the papers provided for hyperparameters, and the course notes for formal definitions.

This will be the most time-consuming assignment in terms of difficulty/training time, so we recommend that you get started early!

## Setup

This notebook provides a working definition of the setup of the problem itself. Feel free to construct your models inline, or use an external setup (preferred) to build your system.

In [1]:
# Text text processing library and methods for pretrained word embeddings
from torchtext import data
from torchtext import datasets
import torch
from torchtext.vocab import Vectors, GloVe
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
from torch.optim.lr_scheduler import MultiStepLR
from matplotlib import pyplot as plt
from torch.nn.utils import clip_grad_norm
import numpy as np
import torch.nn.init as weight_init
import time
import os
import heapq as hq

We first need to process the raw data using a tokenizer. We are going to be using spacy, which can be installed via:  
  `[sudo] pip install spacy`  
  
Tokenizers for English/German can be installed via:  
  `[sudo] python -m spacy download en`  
  `[sudo] python -m spacy download de`
  
This isn't *strictly* necessary, and you can use your own tokenization rules if you prefer (e.g. a simple `split()` in addition to some rules to acccount for punctuation), but we recommend sticking to the above.

In [2]:
import spacy
spacy_de = spacy.load('de')
spacy_en = spacy.load('en')

def tokenize_de(text):
    return [tok.text for tok in spacy_de.tokenizer(text)]

def tokenize_en(text):
    return [tok.text for tok in spacy_en.tokenizer(text)]


Note that we need to add the beginning-of-sentence token `<s>` and the end-of-sentence token `</s>` to the 
target so we know when to begin/end translating. We do not need to do this on the source side.

In [3]:
BOS_WORD = '<s>'
EOS_WORD = '</s>'
DE = data.Field(tokenize=tokenize_de)
EN = data.Field(tokenize=tokenize_en, init_token = BOS_WORD, eos_token = EOS_WORD) # only target needs BOS/EOS

Let's download the data. This may take a few minutes.

**While this dataset of 200K sentence pairs is relatively small compared to others, it will still take some time to train. So we are going to be only working with sentences of length at most 20 for this homework. Please train only on this reduced dataset for this homework.**

In [4]:
MAX_LEN = 20
train, val, test = datasets.IWSLT.splits(exts=('.de', '.en'), fields=(DE, EN), 
                                         filter_pred=lambda x: len(vars(x)['src']) <= MAX_LEN and 
                                         len(vars(x)['trg']) <= MAX_LEN)
print(train.fields)
print(len(train))
print(vars(train[0]))

{'src': <torchtext.data.field.Field object at 0x7f2ab40689e8>, 'trg': <torchtext.data.field.Field object at 0x7f2ab4068a90>}
119076
{'src': ['David', 'Gallo', ':', 'Das', 'ist', 'Bill', 'Lange', '.', 'Ich', 'bin', 'Dave', 'Gallo', '.'], 'trg': ['David', 'Gallo', ':', 'This', 'is', 'Bill', 'Lange', '.', 'I', "'m", 'Dave', 'Gallo', '.']}


Now we build the vocabulary and convert the text corpus into indices. We are going to be replacing tokens that occurred less than 5 times with `<unk>` tokens, and take the rest as our vocab.

In [5]:
MIN_FREQ = 5
DE.build_vocab(train.src, min_freq=MIN_FREQ)
EN.build_vocab(train.trg, min_freq=MIN_FREQ)
print(DE.vocab.freqs.most_common(10))
print("Size of German vocab", len(DE.vocab))
print(EN.vocab.freqs.most_common(10))
print("Size of English vocab", len(EN.vocab))
print(EN.vocab.stoi["<s>"], EN.vocab.stoi["</s>"]) #vocab index for <s>, </s>

[('.', 113253), (',', 67237), ('ist', 24189), ('die', 23778), ('das', 17102), ('der', 15727), ('und', 15622), ('Sie', 15085), ('es', 13197), ('ich', 12946)]
Size of German vocab 13353
[('.', 113433), (',', 59512), ('the', 46029), ('to', 29177), ('a', 27548), ('of', 26794), ('I', 24887), ('is', 21775), ("'s", 20630), ('that', 19814)]
Size of English vocab 11560
2 3


Now we split our data into batches as usual. Batching for MT is slightly tricky because source/target will be of different lengths. Fortunately, `torchtext` lets you do this by allowing you to pass in a `sort_key` function. This will minimizing the amount of padding on the source side, but since there is still some padding you will inadvertendly "attend" to these padding tokens. 

One way to get rid of padding is to pass a binary `mask` vector to your attention module so its attention score (before the softmax) is minus infinity for the padding token. Another way (which is how we do it for our projects, e.g. opennmt) is to manually sort data into batches so that each batch has exactly the same source length (this means that some batches will be less than the desired batch size, though).

However, for this homework padding won't matter too much, so it's fine to ignore it.

In [6]:
BATCH_SIZE = 32
train_iter, val_iter, test_iter = data.BucketIterator.splits((train, val, test), batch_size=BATCH_SIZE, device=-1,
                                                  repeat=False, sort_key=lambda x: len(x.src))

Let's check to see that the BOS/EOS token is indeed appended to the target (English) sentence.

In [7]:
batch = next(iter(train_iter))
print("Source")
print(batch.src)
print("Target")
print(batch.trg)


Source
Variable containing:

Columns 0 to 10 
    87     20   2242   2424   3020     26     26    790     26     40   1477
    75     67     13   4985     24     60    116     25      4    210   1580
     9   1361     34    334     23     22     18     82   2758     17    235
   205   6783    421     24      4   4520    389     33    895   4349      7
   304      2      0     17     22   1254    249    202     37    824   7849
  5564     20      5   1642    328      7      3      0      6   1484    195
  1449    867    125    427      0   3218     10      0     29     17    127
     3     80  10964     11    128   6083      4    189     30      0      0
   163     95      3      6    493   2283     36      3    551      4   1538
     9      3      5    119     52     65     46    118      0      3    128
    34     11     13     16      5     22    249    116      3    272   1794
   334    154     29     20   2027    304      3     10     43      6     90
   157    193   1178   1172   

Success! Now that we've processed the data, we are ready to begin modeling.

## Assignment

Now it is your turn to build the models described at the top of the assignment. 

When a model is trained, use the following test function to produce predictions, and then upload to the kaggle competition: https://www.kaggle.com/c/cs287-hw3-s18/

For the final Kaggle test, we will provide the source sentence, and you are to predict the **first three words of the target sentence**. The source sentence can be found under `source_test.txt`

In [8]:
!head source_test.txt

Als ich in meinen 20ern war , hatte ich meine erste Psychotherapie-Patientin .
Ich war Doktorandin und studierte Klinische Psychologie in Berkeley .
Sie war eine 26-jährige Frau namens Alex .
Und als ich das hörte , war ich erleichtert .
Meine Kommilitonin bekam nämlich einen Brandstifter als ersten Patienten .
Und ich bekam eine Frau in den 20ern , die über Jungs reden wollte .
Das kriege ich hin , dachte ich mir .
Aber ich habe es nicht hingekriegt .
Arbeit kam später , Heiraten kam später , Kinder kamen später , selbst der Tod kam später .
Leute in den 20ern wie Alex und ich hatten nichts als Zeit .


Similar to HW1, you are to predict the 100 most probable 3-gram that will begin the target sentence. The submission format will be as follows, where each word in the 3-gram will be separated by "|", and each 3-gram will be separated by space. For example, here is what an example submission might look like with 5 most-likely 3-grams (instead of 100).

```
id,word
1,Newspapers|talk|about When|I|was Researchers|call|the Twentysomethings|like|Alex But|before|long
2,That|'s|what Newspapers|talk|about You|have|robbed It|'s|realizing My|parents|wanted
3,We|forget|how We|think|about Proust|actually|links Does|any|other This|is|something
4,But|what|do And|it|'s They|'re|on My|name|is It|only|happens
```

When you print out your data, you will need to escape quotes and commas with the following command so that Kaggle does not complain. 

In [9]:
use_cuda = False #torch.cuda.is_available()

In [73]:
class Seq2SeqAttn(nn.Module):
    def __init__(self, encoder, decoder, use_true = True):
        super(Seq2SeqAttn, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.last_hidden_enc = (self.encoder.num_layers*(self.encoder.bidirectional*2)) - 1
        self.use_true = use_true # want to feed true last words in when training
        
    def forward(self, input_seqs, target_seqs):
        batch_size = input_seqs.size(1)
        target_length = target_seqs.size(0)
        vocab_size = self.decoder.output_size
        
        if use_cuda:
            outputs = Variable(torch.zeros(target_length, batch_size, vocab_size)).cuda()
        else:
            outputs = Variable(torch.zeros(target_length, batch_size, vocab_size))
        
        encoder_output, hidden = self.encoder(input_seqs, None)
        
        output = Variable(target_seqs.data[0, :])
        
        for t in range(1, target_length):
            
            
            #context = Variable(torch.zeros(batch_size,self.encoder.hidden_size)).cuda()
            
            #print(context.size())
            
            encoder_output_attn = encoder_output.transpose(0,1)
            hidden_attn = hidden[0][self.last_hidden_enc,:,:]
            
            hidden_attn=hidden_attn.unsqueeze(2)
            #print(hidden_attn.size(), encoder_output_attn.size())
            att_probs = torch.bmm(encoder_output_attn, hidden_attn)
            context = torch.bmm(att_probs.transpose(1,2), encoder_output_attn)
            
#             #loop over examples in batch
#             for i in range(batch_size):
                
#                 #print(encoder_output[:,i,:].size())
#                 #print(hidden[0][1,i,:].size())
                
#                 #calculate attention probabilities
                
#                 att_probs = torch.matmul(encoder_output[:,i,:],hidden[0][self.last_hidden_enc,i,:])
                
#                 #print(att_probs.size())
                
#                 #now get "expected" for each element in batch
#                 context[i,:] = torch.matmul(att_probs,encoder_output[:,i,:])
            
            
            output, hidden = self.decoder(output, hidden, context)
            
            outputs[t] = output
            
            #use true values only if training/validating model
            if use_cuda:
                output = Variable(target_seqs.data[t]).cuda()
            else:
                output = Variable(target_seqs.data[t])

        return outputs
    
    def batch_train(self, optimizer, train_iter, vocab_size, grad_clip=10):
        self.train()
        total_loss = 0
        pad = EN.vocab.stoi['<pad>']
        curr_time = time.time()
        loss_f = nn.NLLLoss()
        for b, batch in enumerate(train_iter):
            source = batch.src
            target = batch.trg
            if use_cuda:
                source, target = source.cuda(), target.cuda()
            optimizer.zero_grad()
            output = self.forward(source, target)
            loss = F.cross_entropy(output[1:].view(-1, vocab_size),
                                   target[1:].contiguous().view(-1),
                                   ignore_index=pad)
            loss.backward()
            clip_grad_norm(self.parameters(), grad_clip)
            optimizer.step()
            total_loss += loss.data[0]

            if b % 1000 == 0 and b != 0:
                total_loss = total_loss / 1000
                print("[%d][loss:%5.2f][pp:%5.2f][time:%5.2f]" %
                      (b, total_loss, math.exp(total_loss), time.time() - curr_time))
                total_loss = 0
                curr_time = time.time()
    
    def beam_search(self, input_seq, beam_size, search_time, target_vocab, n = 10):
        
        self.eval()
        
        encoder_output, hidden = self.encoder(input_seq, None)
        
        #here we track the strings still in our beam
        track = [None]*beam_size
        
        # in worst case, take beam_size candidates from one former state
        poss_next = min([beam_size, len(target_vocab)])
        
        # next states that we might want to keep around
        possible_next = [None]*(beam_size*poss_next)
        
        # list of complete strings to keep around
        final_candidates = []
        
        # first state we have in our beam
        track[0] = ([target_vocab.stoi["<s>"]], hidden, 0)
        num_tracked = 1
        
        #start time
        start = time.time()
        
        #track no. of steps we've taken
        steps = 0
        
        #continue looping until we've used search time
        while time.time() < start + search_time:
            
            steps += 1
            
            #track where we are in the list of possible next values
            poss_counter = 0
            
            for i in range(num_tracked):
                
                att_probs = torch.matmul(encoder_output[:,0,:],track[i][1][0][self.last_hidden_enc,0,:])
            
                #print(att_probs)
                context = torch.matmul(att_probs,encoder_output[:,0,:])
                
                
                #print(track[i][0][-1:])
                output, hidden = self.decoder(Variable(torch.LongTensor(track[i][0][-1:])), hidden, context.view(1,1,-1))
                
                #print(output)
                #log_output = torch.log(output)
                
                #print(log_output)
                top_next, idx = torch.topk(output.squeeze(), poss_next)
                
                #print(top_next)
                
                
                for j in range(len(top_next)):
                    
                    #print(idx[j])
                    if idx.data[j] != target_vocab.stoi["</s>"]:
                        
                        possible_next[poss_counter] = (top_next.data[j] + track[i][2],top_next[j], idx.data[j], i, steps)
                    
                        poss_counter += 1
                        
                    else:
                        
                        #we have a complete string
                        final_candidates += [(top_next.data[j] + track[i][2],track[i][0]+[idx.data[j]])]
                    
                    
            if poss_counter < beam_size:
                
                print("less possible next states than beam size. Seems unlikely this would ever happen.")
                
                break
            
            
            #sort our possible next states and choose
            #hq.heapify(possible_next[:poss_counter])
            largest = hq.nlargest(beam_size, possible_next[:poss_counter])
            
            num_tracked = 0
            
            #now loop over and put everything we want to keep in our track array
            for k in range(poss_next):
                parent = track[possible_next[k][-2]]
                #print(parent[0] + [possible_next[k][2]])
                track[k] = (parent[0] + [possible_next[k][2]], hidden, possible_next[k][0])
                num_tracked += 1
           
        print(len(final_candidates))
        
        #finally, sort our final candidates
#         hq.heapify(final_candidates)
#         final = hq.nlargest(n, final_candidates)
        
        #lets just try getting the best directly from tracked
        candidates = [(score, string) for string, _, score in track] + final_candidates
        final = hq.nlargest(n, candidates)
        self.train()
        
        return final
                
            
        
    def predict(self, val_iter, vocab_size):
        self.eval()
        pad = EN.vocab.stoi['<pad>']
        total_loss = 0
        for batch in val_iter:
            source = batch.src
            target = batch.trg
            if use_cuda:
                source = Variable(source.data.cuda(), volatile=True)
                target = Variable(target.data.cuda(), volatile=True)
            output = self.forward(source, target)
            loss = F.cross_entropy(output[1:].view(-1, vocab_size),
                                   target[1:].contiguous().view(-1),
                                   ignore_index=pad)
            total_loss += loss.data[0]
        return total_loss / len(val_iter)

In [74]:
class AttnDecoderLSTM(nn.Module):
    def __init__(self, hidden_size=200, context_size=200, output_size=11560, n_layers = 2, dropout= 0.3, bidirectional=False):
        super(AttnDecoderLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.context_size = context_size
        self.bidirectional = bidirectional
        self.dropout = nn.Dropout(dropout, inplace=True)
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.lstm = nn.LSTM(self.hidden_size+self.context_size, self.hidden_size, self.n_layers, bidirectional=bidirectional)
        for param in self.lstm.parameters():
            nn.init.uniform(param, -0.08, 0.08)
        self.out = nn.Linear(self.hidden_size+self.context_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)
        
    def forward(self, input_seq, hidden, context):
        #print(input_seq.size())
        embedded = self.embedding(input_seq).unsqueeze(0)
        embedded = self.dropout(embedded)
        #print(embedded.size(),context.size())
        combined = torch.cat([embedded, context.permute(1,0,2)],2)
        output, hidden = self.lstm(combined, hidden)
        output = output.squeeze(0)
        output = self.softmax(self.out(torch.cat([output,context.squeeze(1)],1)))
        return output, hidden

In [75]:
class EncoderLSTM(nn.Module):
    def __init__(self, input_size = 13352, hidden_size = 200, n_layers=2, dropout=0.3, bidirectional=False):
        super(EncoderLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = n_layers
        self.bidirectional = bidirectional
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, n_layers, bidirectional=bidirectional)
        for param in self.lstm.parameters():
            nn.init.uniform(param, -0.08, 0.08)

    def forward(self, input, hidden):
        embedded = self.embedding(input)
        output, hidden = self.lstm(embedded, hidden)
        return output, hidden
        
class DecoderLSTM(nn.Module):
    def __init__(self, hidden_size=200, output_size=11560, n_layers = 2, dropout= 0.3, bidirectional=None):
        super(DecoderLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = nn.Dropout(dropout, inplace=True)
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.lstm = nn.LSTM(self.hidden_size, self.hidden_size, self.n_layers, bidirectional=bidirectional)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input_seq, hidden):
        embedded = self.embedding(input_seq).unsqueeze(0)
        embedded = self.dropout(embedded)
        output, hidden = self.lstm(embedded, hidden)
        output = F.tanh(output)
        output = self.softmax(self.out(output))
        return output, hidden

In [76]:
encoder = EncoderLSTM(input_size = len(DE.vocab), hidden_size = 200, n_layers=2, dropout=0.3, bidirectional=False)
decoder = AttnDecoderLSTM(hidden_size=200, context_size = 200, output_size = len(EN.vocab), n_layers = 2, dropout= 0.3, bidirectional=None)
if use_cuda:
    seq2seq = Seq2SeqAttn(encoder, decoder).cuda()
else:
    seq2seq = Seq2SeqAttn(encoder, decoder)
epoch_num = 13
optimizer = optim.SGD(seq2seq.parameters(), lr=1)
scheduler = MultiStepLR(optimizer, milestones=range(9, epoch_num), gamma=0.5)

In [119]:
def escape(l):
    return l.replace("\"", "<quote>").replace(",", "<comma>")

In [109]:
best_val_loss = None
#seq2seq.load_state_dict(torch.load("./.save/seq2seq_4.pt"))
for i in range(epoch_num):
    seq2seq.batch_train(optimizer, train_iter,len(EN.vocab), grad_clip = 10)
    val_loss = seq2seq.predict(val_iter, len(EN.vocab))
    print("[Epoch:%d] val_loss:%5.3f | val_pp:%5.2fS"
          % (i, val_loss, math.exp(val_loss)))

    # Save the model if the validation loss is the best we've seen so far.
    if not best_val_loss or val_loss < best_val_loss:
        print("[!] saving model...")
        if not os.path.isdir(".save"):
            os.makedirs(".save")
        torch.save(seq2seq.state_dict(), './.save/seq2seq_a_%d.pt' % (i))
        best_val_loss = val_loss
test_loss = seq2seq.predict(test_iter, len(EN.vocab))
print("[TEST] loss:%5.2f" % test_loss)

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 200])
torch.Size([32])
torch.Size([1, 32, 200]) torch.Size([32, 1, 2

KeyboardInterrupt: 

In [84]:
batch = next(iter(train_iter))
if use_cuda:
    source = batch.src.cuda()
    if hasattr(batch, "trg"):
        target = batch.trg.cuda()
else:
    source = batch.src
    if hasattr(batch, "trg"):
        target = batch.trg
        
seq2seq.load_state_dict(torch.load("./.save/seq2seq_a_11.pt"))
sentences = seq2seq.beam_search(source[:,:1], 5, 100, EN.vocab, n=5)
print(len(sentences))
if hasattr(batch, "trg"):
        print(" ".join([EN.vocab.itos[word] for word in target[:,:1].squeeze().data]))
for sentence in sentences:
    print(" ".join([EN.vocab.itos[word] for word in sentence[1]]))

48456
5
<s> And it 's at the scale of the night sky . </s> <pad> <pad> <pad> <pad> <pad>
<s> It 's <unk> . </s>
<s> It 's <unk> . </s>
<s> It 's <unk> </s>
<s> It 's <unk> the </s>
<s> It 's <unk> . ? </s>


In [32]:
%set_env CUDA_LAUNCH_BLOCKING=1

torch.cuda.empty_cache()

env: CUDA_LAUNCH_BLOCKING=1


RuntimeError: cuda runtime error (59) : device-side assert triggered at torch/csrc/cuda/Module.cpp:321

In [26]:
A=next(iter(train_iter))
A.src

Variable containing:

Columns 0 to 10 
  4226     12      9     20    139     12    729     12     28     12     12
    33     42    897    143    631    184      5      7   4821      5     72
    13    283     32     82     80      3    264    357      3   3236    199
   324     46    887    699     22    210      3    215     31      0   6523
     0      3      3      3   1536      6     31      4      6     27      4
     0     13     37    224      0   1271     13      3   2202      3      6
     3    542     10     11      8     18     22     31   1374     17   2769
     5  12969    705      5   9038      0    496    128     91   1445     17
    32    336     14   3161      0      3    109    175      5      3      0
   161     81    446   2402     68     27      3      3   7060     21   5627
   815      0     35     67      0     13     35     11     72    314      0
   165     69      2     52  11869     14     36    197      0    216   4924
    14     81     26      5     65   

In [27]:
A.trg

Variable containing:

Columns 0 to 10 
     2      2      2      2      2      2      2      2      2      2      2
  5855     34     52    688     48     14    716     14     42     14     14
     5     86    461     10    706    141      6      6   4547     55     47
    19    368     85     90      7     82    203    428     11    144      9
    28     13    433      7     59    949      9     29     13      5     50
   339     19     23    181    405     11     94     20      6     22   3911
  2511    126      4     15      8     30     56     11    904    129     11
   758      7     97    162   1802    257      8    130    912      5    152
   349    547     23     10    569      5    402     63     73     31      6
  2436   1643    103   2218     18     15     23    169    566    177      0
  1363     99   2701      6      8     43   2217      5      6     54      7
     7     40      4   2502   8622   5735      7     44    455    684      0
     0     78      3      7   2193   

You should perform your hyperparameter search/early stopping/write-up based on perplexity, not the above metric. (In practice, people use a metric called [BLEU](https://www.aclweb.org/anthology/P02-1040.pdf), which is roughly a geometric average of 1-gram, 2-gram, 3-gram, 4-gram precision, with a brevity penalty for producing translations that are too short.)

Finally, as always please put up a (short) write-up following the template provided in the repository:  https://github.com/harvard-ml-courses/cs287-s18/blob/master/template/
