Prerequisites:
- `./scl-2021-ds/test.csv`
- Run `02_DataPrep.ipynb` to get `./scl-2021-ds/parsed_train.csv`. This `parsed_train.csv` is a tokenized version with labels for every token.

In [1]:
import math
import time
import numpy as np
import pandas as pd
import ast
import torch
import torch.nn as nn
import torch.nn.functional as F
import io
import torch
import torch.optim as optim
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from collections import Counter
import torchtext
from torchtext.vocab import Vocab, Vectors, FastText
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm

## Read Input

In [4]:
# Load the tokenized training data, discard rows with missing values, and
# turn the string-encoded 'parsed' column back into Python objects.
df = pd.read_csv("./scl-2021-ds/parsed_train.csv").dropna()
df['parsed'] = df['parsed'].apply(ast.literal_eval)

## Split Train/Validation

In [5]:
# Randomly assign roughly 80% of the rows to training and the rest to
# validation.  A seeded generator keeps the split identical across runs, so
# results are reproducible and a checkpoint saved in an earlier session is
# still evaluated on rows it was not trained on.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.8
split_rng = np.random.default_rng(SPLIT_SEED)
mask = split_rng.random(len(df)) < TRAIN_FRACTION
train_df = df[mask].reset_index(drop=True)
valid_df = df[~mask].reset_index(drop=True)

## Define Dataset

In [6]:
# One Field per column of an example: the raw tokens, the per-token tag and
# the "transformed" tokens.  The tag field is created without an <unk> token.
TEXT, TRANSFORMED = torchtext.legacy.data.Field(), torchtext.legacy.data.Field()
LABEL_TAG = torchtext.legacy.data.Field(unk_token=None)
fields = tuple(zip(("input", "label", "transformed"), (TEXT, LABEL_TAG, TRANSFORMED)))

In [7]:
class DataObject:
    """Plain container bundling one example's tokens, tags and transformed tokens."""

    def __init__(self, inputs, labels, transformed):
        self.input, self.label, self.transformed = inputs, labels, transformed
        
class Dataset(torchtext.legacy.data.Dataset):
    """torchtext dataset built from a DataFrame that has a ``parsed`` column.

    Every ``parsed`` cell is unpacked as a sequence of
    ``(input_token, label, transformed_token)`` triples and becomes one
    example with the three parallel lists ``input``, ``label`` and
    ``transformed``.

    Args:
        df: DataFrame holding the ``parsed`` column.
        fields: ``(name, Field)`` pairs, in the order input / label /
            transformed, passed on to ``Example.fromlist``.
    """

    def __init__(self, df, fields):
        self.df = df

        examples = []
        # Walk the column itself rather than iterrows() + .at[i, ...]: it
        # avoids building a Series per row and does not depend on the index
        # labels being unique.
        for parsed in self.df['parsed']:
            inputs, labels, transformed = zip(*parsed)
            examples.append(
                torchtext.legacy.data.Example.fromlist(
                    [list(inputs), list(labels), list(transformed)], fields
                )
            )
        super().__init__(examples, fields)

    @staticmethod
    def sort_key(ex):
        """Sort/bucket examples by their number of input tokens."""
        return len(ex.input)

In [8]:
# Wrap both splits in datasets that share the same field definitions.
train_dataset, valid_dataset = (
    Dataset(frame, fields) for frame in (train_df, valid_df)
)

In [9]:
# sample data format
vars(train_dataset[0])

{'input': ['jl',
  'kapuk',
  'timur',
  'delta',
  'sili',
  'iii',
  'lippo',
  'cika',
  '11',
  'a',
  'cicau',
  'cikarang',
  'pusat'],
 'label': [1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
 'transformed': ['jl',
  'kapuk',
  'timur',
  'delta',
  'sili',
  'iii',
  'lippo',
  'cika',
  '11',
  'a',
  'cicau',
  'cikarang',
  'pusat']}

In [10]:
# Load the pretrained FastText word vectors for Indonesian (language code 'id').
fasttext_id_vectors = FastText(language='id')

In [11]:
# Build the vocabularies from the training split only.
# TEXT and TRANSFORMED require a token to occur at least twice (min_freq=2)
# and attach the FastText vectors; LABEL_TAG uses the default settings.
TEXT.build_vocab(train_dataset, min_freq=2, vectors=fasttext_id_vectors)
LABEL_TAG.build_vocab(train_dataset)
TRANSFORMED.build_vocab(train_dataset, min_freq=2, vectors=fasttext_id_vectors)

In [12]:
print(f"{len(TEXT.vocab)=}")
print(f"{len(LABEL_TAG.vocab)=}")
print(f"{len(TRANSFORMED.vocab)=}")

len(TEXT.vocab)=39896
len(LABEL_TAG.vocab)=4
len(TRANSFORMED.vocab)=41023


In [13]:
print(TEXT.vocab.freqs.most_common(20))

[(',', 164438), ('no', 39282), ('rt', 34294), ('raya', 28435), ('1', 19036), ('2', 17394), ('rw', 16224), ('3', 14546), ('4', 11858), ('timur', 11807), ('barat', 11745), ('5', 10862), ('utara', 10732), ('kel.', 10397), ('jaya', 9924), ('gg.', 9585), ('6', 9345), ('selatan', 8963), ('jl.', 8858), ('baru', 8579)]


In [14]:
print(LABEL_TAG.vocab.itos)

['<pad>', 2, 1, 0]


In [15]:
print(LABEL_TAG.vocab.freqs.most_common())

[(2, 1121710), (1, 411847), (0, 276062)]


In [16]:
print(TRANSFORMED.vocab.freqs.most_common())



## Create Data Iterator 

This iterator will attempt to pack inputs with similar length together (i.e., the number of tokens). See `sort_key` for the sorting scheme.

In [17]:
# Batch size shared by the training and validation iterators.  It was
# previously declared as 32 but never used, while the iterators were
# hard-coded to 16; the variable now carries the value that is actually used.
batch_size = 16
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Batches are shuffled across the dataset (sort=False, shuffle=True) and the
# examples inside each batch are ordered by token count (sort_within_batch).
train_iter, valid_iter = torchtext.legacy.data.BucketIterator.splits(
                                                  datasets=(train_dataset, valid_dataset),
                                                  batch_sizes=(batch_size, batch_size),
                                                  sort_key=lambda x: len(x.input),
                                                  device=device,
                                                  sort=False,
                                                  shuffle=True,
                                                  sort_within_batch=True,
)

In [18]:
# Test if iterator is working
for x in train_iter:
    print(x)
    break


[torchtext.legacy.data.batch.Batch of size 16]
	[.input]:[torch.cuda.LongTensor of size 12x16 (GPU 0)]
	[.label]:[torch.cuda.LongTensor of size 12x16 (GPU 0)]
	[.transformed]:[torch.cuda.LongTensor of size 12x16 (GPU 0)]


## Define Model

In [19]:
class BiLSTMTagger(nn.Module):
    """Per-token tagger: embedding -> (bi)LSTM -> linear layer over tags.

    Args:
        input_dim: number of rows in the embedding table.
        embedding_dim: size of each token embedding.
        hidden_dim: LSTM hidden size per direction.
        output_dim: number of scores produced for every token.
        n_layers: number of stacked LSTM layers.
        bidirectional: whether the LSTM runs in both directions.
        dropout: dropout probability applied to the embeddings and to the
            LSTM outputs, and inside the LSTM when ``n_layers > 1``.
        pad_idx: index passed to the embedding as ``padding_idx``.
    """

    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim,
                 n_layers, bidirectional, dropout, pad_idx):
        super().__init__()

        # The LSTM's own dropout is switched off for a single-layer network.
        lstm_dropout = dropout if n_layers > 1 else 0
        n_directions = 2 if bidirectional else 1

        self.embedding = nn.Embedding(input_dim, embedding_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embedding_dim,
                            hidden_dim,
                            num_layers=n_layers,
                            bidirectional=bidirectional,
                            dropout=lstm_dropout)
        self.fc = nn.Linear(hidden_dim * n_directions, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, text):
        """Score every tag for every token.

        ``text`` is ``[sent len, batch size]``; the result is
        ``[sent len, batch size, output dim]``.
        """
        # embedded = [sent len, batch size, emb dim]
        embedded = self.dropout(self.embedding(text))

        # outputs = [sent len, batch size, hid dim * n directions]
        # The final hidden/cell states are not needed for per-token tagging.
        outputs, _ = self.lstm(embedded)

        return self.fc(self.dropout(outputs))

In [20]:
# Hyper-parameters.  The vocabulary-dependent sizes and the padding index are
# read from the built TEXT / LABEL_TAG vocabularies.
INPUT_DIM = len(TEXT.vocab)
EMBEDDING_DIM = 300
HIDDEN_DIM = 1024
OUTPUT_DIM = len(LABEL_TAG.vocab)
N_LAYERS = 2
BIDIRECTIONAL = True
DROPOUT = 0.25
PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]

model = BiLSTMTagger(
    input_dim=INPUT_DIM,
    embedding_dim=EMBEDDING_DIM,
    hidden_dim=HIDDEN_DIM,
    output_dim=OUTPUT_DIM,
    n_layers=N_LAYERS,
    bidirectional=BIDIRECTIONAL,
    dropout=DROPOUT,
    pad_idx=PAD_IDX,
)

In [21]:
def init_weights(m):
    """Re-draw every parameter reachable from ``m`` from a normal(0, 0.1).

    The walk over parameters is recursive, so when this function is handed to
    ``Module.apply`` (which itself visits every sub-module) nested parameters
    are drawn more than once; the last draw is the one that sticks.
    """
    for param in m.parameters():
        nn.init.normal_(param.data, mean=0, std=0.1)
        
model.apply(init_weights)

BiLSTMTagger(
  (embedding): Embedding(39896, 300, padding_idx=1)
  (lstm): LSTM(300, 1024, num_layers=2, dropout=0.25, bidirectional=True)
  (fc): Linear(in_features=2048, out_features=4, bias=True)
  (dropout): Dropout(p=0.25, inplace=False)
)

In [22]:
def count_parameters(model):
    """Return the total number of elements in ``model``'s trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 48,021,796 trainable parameters


## Use pre-trained fasttext vectors as embedding layer's initialization

In [23]:
# Vector matrix attached to the TEXT vocabulary when it was built with the
# FastText vectors; printing the shape is a sanity check against the model's
# embedding dimensions.
pretrained_embeddings = TEXT.vocab.vectors

print(pretrained_embeddings.shape)

torch.Size([39896, 300])


In [24]:
# Overwrite the randomly initialised embedding weights, in place, with the pretrained vectors.
model.embedding.weight.data.copy_(pretrained_embeddings)

tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0996,  0.1834, -0.1307,  ..., -0.2033, -0.0849, -0.0017],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.1304,  0.5440, -1.1450,  ..., -0.0696, -0.1374, -0.0943]])

## Define optimizer, learning rate scheduler, and loss function

In [25]:
# Adam with its default settings.  StepLR(step_size=1, gamma=0.6) multiplies
# the learning rate by 0.6 on every lr_scheduler.step() call.
optimizer = optim.Adam(model.parameters())
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.6)

In [26]:
# Index of the padding tag in the label vocabulary; positions carrying it are
# excluded from the cross-entropy loss.
TAG_PAD_IDX = LABEL_TAG.vocab.stoi[LABEL_TAG.pad_token]
criterion = nn.CrossEntropyLoss(ignore_index = TAG_PAD_IDX)

In [27]:
# Move the model and the loss module to the selected device.
model = model.to(device)
criterion = criterion.to(device)

In [28]:
def categorical_accuracy(preds, y, tag_pad_idx):
    """
    Fraction of non-padding positions whose highest-scoring class equals the
    target, e.g. 8 correct out of 10 gives 0.8, NOT 8.

    ``preds`` holds one row of class scores per position and ``y`` the
    matching target indices; positions where ``y == tag_pad_idx`` are left out
    of both the numerator and the denominator.  The result is a float tensor
    with a single element.
    """
    keep = y != tag_pad_idx                 # True for every real (non-pad) position
    predicted = preds.argmax(dim = 1)       # index of the highest score per row
    hits = predicted[keep].eq(y[keep])
    denom = torch.FloatTensor([int(keep.sum())]).to(y.device)
    return hits.sum() / denom


## Train/Evaluate

In [29]:
def train(model, iterator, optimizer, criterion, tag_pad_idx, writer=None, log_interval=100, global_step=0):
    """Run one optimisation pass over ``iterator``.

    Args:
        model: network called as ``model(batch.input)``.
        iterator: yields batches exposing ``.input`` and ``.label``.
        optimizer: optimizer stepped once per batch.
        criterion: loss applied to the flattened predictions and tags.
        tag_pad_idx: tag index excluded from the accuracy computation.
        writer: optional TensorBoard writer; when given, the per-batch loss
            is logged under ``"step_train"``.
        log_interval: the progress-bar description is refreshed every this
            many batches.
        global_step: step offset for the per-batch loss, so that successive
            epochs can be logged on one continuous axis (pass the number of
            batches already trained).  The previous version logged every
            point without a step.

    Returns:
        ``(mean loss, mean accuracy)`` averaged over ``len(iterator)`` batches.
    """
    epoch_loss = 0
    epoch_acc = 0
    counter = 0

    model.train()

    pbar = tqdm(iterator, leave=False)
    for batch in pbar:

        text = batch.input
        tags = batch.label

        optimizer.zero_grad()

        #text = [sent len, batch size]

        predictions = model(text)

        #predictions = [sent len, batch size, output dim]
        #tags = [sent len, batch size]

        # Flatten so every token position is one row for the loss.
        predictions = predictions.view(-1, predictions.shape[-1])
        tags = tags.view(-1)

        #predictions = [sent len * batch size, output dim]
        #tags = [sent len * batch size]

        loss = criterion(predictions, tags)

        acc = categorical_accuracy(predictions, tags, tag_pad_idx)

        loss.backward()

        optimizer.step()

        counter += 1
        batch_loss = loss.item()

        if writer is not None:
            # An explicit step keeps the points from piling up on one x value.
            writer.add_scalar("step_train", batch_loss, global_step + counter)

        epoch_loss += batch_loss
        epoch_acc += acc.item()

        # Only format the running averages when they are actually displayed.
        if counter % log_interval == 0:
            pbar.set_description(
                f"[Train] Batch {counter}: Loss {epoch_loss / counter : .3f} | Acc {epoch_acc / counter : .3f}"
            )

    return epoch_loss / len(iterator), epoch_acc / len(iterator)

In [30]:
def evaluate(model, iterator, criterion, tag_pad_idx, writer=None, log_interval=100):
    """Compute the mean loss and accuracy of ``model`` over ``iterator``.

    The model is switched to eval mode and gradients are disabled.  ``writer``
    is accepted for signature parity with ``train`` but is not used here.
    The progress-bar description is refreshed every ``log_interval`` batches.

    Returns:
        ``(mean loss, mean accuracy)`` averaged over ``len(iterator)`` batches.
    """
    loss_total = 0
    acc_total = 0

    model.eval()

    with torch.no_grad():
        progress = tqdm(iterator, leave=False)
        for n_batches, batch in enumerate(progress, start=1):
            scores = model(batch.input)

            # Flatten so every token position is one row.
            flat_scores = scores.view(-1, scores.shape[-1])
            flat_tags = batch.label.view(-1)

            loss_total += criterion(flat_scores, flat_tags).item()
            acc_total += categorical_accuracy(flat_scores, flat_tags, tag_pad_idx).item()

            if n_batches % log_interval == 0:
                progress.set_description(
                    f"[Valid] Batch {n_batches}: Loss {loss_total / n_batches : .3f} | Acc {acc_total / n_batches : .3f}"
                )

    return loss_total / len(iterator), acc_total / len(iterator)

In [31]:
def epoch_time(start_time, end_time):
    """Split the span between two timestamps (in seconds) into whole minutes and seconds."""
    duration = end_time - start_time
    whole_minutes = int(duration / 60)
    leftover_seconds = int(duration - 60 * whole_minutes)
    return whole_minutes, leftover_seconds

In [32]:
N_EPOCHS = 10
writer = SummaryWriter()
best_valid_loss = float('inf')

# try/finally makes sure the TensorBoard writer is flushed and closed even if
# training is interrupted; it was previously left open.
try:
    for epoch in tqdm(range(N_EPOCHS)):

        start_time = time.time()

        train_loss, train_acc = train(model, train_iter, optimizer, criterion, TAG_PAD_IDX, writer)
        writer.add_scalar("loss/train", train_loss, epoch)
        writer.add_scalar("acc/train", train_acc, epoch)
        valid_loss, valid_acc = evaluate(model, valid_iter, criterion, TAG_PAD_IDX)
        writer.add_scalar("loss/valid", valid_loss, epoch)
        writer.add_scalar("acc/valid", valid_acc, epoch)
        # Decay the learning rate once per epoch.
        lr_scheduler.step()

        end_time = time.time()

        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        # Keep only the checkpoint with the lowest validation loss so far.
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), 'tut1-model.pt')

        print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
        print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
        print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')
finally:
    writer.close()


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/15009 [00:00<?, ?it/s]

  0%|          | 0/3741 [00:00<?, ?it/s]

Epoch: 01 | Epoch Time: 3m 53s
	Train Loss: 0.278 | Train Acc: 90.09%
	 Val. Loss: 0.239 |  Val. Acc: 91.77%


  0%|          | 0/15009 [00:00<?, ?it/s]

  0%|          | 0/3741 [00:00<?, ?it/s]

Epoch: 02 | Epoch Time: 3m 54s
	Train Loss: 0.210 | Train Acc: 92.82%
	 Val. Loss: 0.225 |  Val. Acc: 92.23%


  0%|          | 0/15009 [00:00<?, ?it/s]

  0%|          | 0/3741 [00:00<?, ?it/s]

Epoch: 03 | Epoch Time: 3m 53s
	Train Loss: 0.181 | Train Acc: 93.74%
	 Val. Loss: 0.233 |  Val. Acc: 92.24%


  0%|          | 0/15009 [00:00<?, ?it/s]

  0%|          | 0/3741 [00:00<?, ?it/s]

Epoch: 04 | Epoch Time: 3m 54s
	Train Loss: 0.153 | Train Acc: 94.58%
	 Val. Loss: 0.253 |  Val. Acc: 91.94%


  0%|          | 0/15009 [00:00<?, ?it/s]

  0%|          | 0/3741 [00:00<?, ?it/s]

Epoch: 05 | Epoch Time: 3m 53s
	Train Loss: 0.126 | Train Acc: 95.38%
	 Val. Loss: 0.294 |  Val. Acc: 91.64%


  0%|          | 0/15009 [00:00<?, ?it/s]

  0%|          | 0/3741 [00:00<?, ?it/s]

Epoch: 06 | Epoch Time: 3m 54s
	Train Loss: 0.106 | Train Acc: 96.06%
	 Val. Loss: 0.325 |  Val. Acc: 91.18%


  0%|          | 0/15009 [00:00<?, ?it/s]

  0%|          | 0/3741 [00:00<?, ?it/s]

Epoch: 07 | Epoch Time: 3m 54s
	Train Loss: 0.092 | Train Acc: 96.53%
	 Val. Loss: 0.357 |  Val. Acc: 90.76%


  0%|          | 0/15009 [00:00<?, ?it/s]

  0%|          | 0/3741 [00:00<?, ?it/s]

Epoch: 08 | Epoch Time: 3m 53s
	Train Loss: 0.083 | Train Acc: 96.83%
	 Val. Loss: 0.388 |  Val. Acc: 90.79%


  0%|          | 0/15009 [00:00<?, ?it/s]

  0%|          | 0/3741 [00:00<?, ?it/s]

Epoch: 09 | Epoch Time: 3m 54s
	Train Loss: 0.079 | Train Acc: 97.02%
	 Val. Loss: 0.403 |  Val. Acc: 90.56%


  0%|          | 0/15009 [00:00<?, ?it/s]

  0%|          | 0/3741 [00:00<?, ?it/s]

Epoch: 10 | Epoch Time: 3m 54s
	Train Loss: 0.075 | Train Acc: 97.13%
	 Val. Loss: 0.411 |  Val. Acc: 90.51%


In [33]:
# Reload the best checkpoint.  map_location places the tensors on the current
# device, so a checkpoint written on a GPU can also be restored on a CPU-only
# machine.
model.load_state_dict(torch.load('tut1-model.pt', map_location=device))

<All keys matched successfully>

In [34]:
def tag_sentence(model, device, tokens, text_field, tag_field):
    """Predict one tag per token for a single tokenized sentence.

    Args:
        model: tagger called with a ``[sent len, 1]`` tensor of token
            indices; its output is arg-maxed over the last dimension.
        device: device the index tensor is moved to before the forward pass.
        tokens: list of string tokens.
        text_field: field whose ``vocab.stoi`` maps tokens to indices and
            whose ``unk_token`` names the unknown-word entry.
        tag_field: field whose ``vocab.itos`` maps predicted indices to tags.

    Returns:
        ``(tokens, predicted_tags, unks)`` where ``unks`` lists the tokens
        that were mapped to the unknown-word index.
    """
    model.eval()

    numericalized_tokens = [text_field.vocab.stoi[t] for t in tokens]
    unk_idx = text_field.vocab.stoi[text_field.unk_token]
    unks = [t for t, n in zip(tokens, numericalized_tokens) if n == unk_idx]

    # Nothing to tag: return empty predictions instead of handing the model a
    # zero-length sequence.
    if not numericalized_tokens:
        return tokens, [], unks

    # Shape [sent len, 1]: a batch holding just this sentence.
    token_tensor = torch.LongTensor(numericalized_tokens).unsqueeze(-1).to(device)

    # Inference only; skipping autograd bookkeeping saves memory and time
    # when this is called once per row of a large frame.
    with torch.no_grad():
        predictions = model(token_tensor)

    top_predictions = predictions.argmax(-1)
    predicted_tags = [tag_field.vocab.itos[idx] for idx in top_predictions.view(-1).tolist()]

    return tokens, predicted_tags, unks

## Visualize output prediction

In [35]:
example_index = 0

tokens = vars(valid_dataset.examples[example_index])['input']
actual_tags = vars(valid_dataset.examples[example_index])['label']

print(tokens)
print(actual_tags)

['toko', 'dita', ',', 'kertosono']
[0, 0, 2, 2]


In [36]:
tokens, pred_tags, unks = tag_sentence(model, 
                                       device, 
                                       tokens, 
                                       TEXT, 
                                       LABEL_TAG)

print(unks)

[]


In [37]:
print("Pred. Tag\tActual Tag\tCorrect?\tToken\n")

for token, pred_tag, actual_tag in zip(tokens, pred_tags, actual_tags):
    correct = '✔' if pred_tag == actual_tag else '✘'
    print(f"{pred_tag}\t\t{actual_tag}\t\t{correct}\t\t{token}")


Pred. Tag	Actual Tag	Correct?	Token

0		0		✔		toko
0		0		✔		dita
2		2		✔		,
2		2		✔		kertosono


In [38]:
def tokens_to_output(tokens):
    """Join tokens with single spaces, attaching each "," to the token before it.

    No space is ever emitted in front of the first token.
    """
    pieces = []
    for position, token in enumerate(tokens):
        if position > 0 and token != ",":
            pieces.append(" ")
        pieces.append(token)
    return "".join(pieces)

## Actual Evaluation on Validation Split

In [39]:
valid_df

Unnamed: 0,id,raw_address,POI/street,parsed
0,3,"toko dita, kertosono",toko dita/,"[(toko, 0, toko), (dita, 0, dita), (,, 2, ,), ..."
1,6,"kem mel raya, no 4 bojong rawalumbu rt 1 36 ra...",/kem mel raya,"[(kem, 1, kem), (mel, 1, mel), (raya, 1, raya)..."
2,8,gg. i wates magersari,/gg. i,"[(gg., 1, gg.), (i, 1, i), (wates, 2, wates), ..."
3,15,"kampung.gudang areng,desa:anyer, kecamatan:any...",gudang areng/,"[(kampung.gudang, 0, gudang), (areng, 0, areng..."
4,19,"tam tama barat v, banyumanik",/tam tama barat v,"[(tam, 1, tam), (tama, 1, tama), (barat, 1, ba..."
...,...,...,...,...
59847,299975,"taman kanak kanak dha wan karanggayam sren,",/,"[(taman, 2, taman), (kanak, 2, kanak), (kanak,..."
59848,299982,kan desa tanjung beringin,/kan desa,"[(kan, 1, kan), (desa, 1, desa), (tanjung, 2, ..."
59849,299983,"la banda minima, cile raya, pesanggrahan",la banda minimarket/cile raya,"[(la, 0, la), (banda, 0, banda), (minima, 0, m..."
59850,299984,jend sudi 2 larangan kel. candi,/,"[(jend, 2, jend), (sudi, 2, sudi), (2, 2, 2), ..."


In [40]:
# Tag every validation address and rebuild its "POI/street" answer string.
# Predictions are collected in a list and written to the frame in a single
# column assignment, instead of growing the column one cell at a time.
valid_predictions = []
for sentence in tqdm(valid_df['raw_address'], total=len(valid_df)):
    # Surround commas with spaces so each one becomes its own token.
    tokens = sentence.replace(',', ' , ').split()

    tokens, tags, unks = tag_sentence(model,
                                      device,
                                      tokens,
                                      TEXT,
                                      LABEL_TAG)

    # Tag 0 marks POI tokens and tag 1 street tokens; all others are dropped.
    POI_tokens = [token for token, tag in zip(tokens, tags) if tag == 0]
    street_tokens = [token for token, tag in zip(tokens, tags) if tag == 1]

    POI = tokens_to_output(POI_tokens)
    street = tokens_to_output(street_tokens)
    valid_predictions.append(f"{POI}/{street}")

valid_df["pred"] = valid_predictions

  0%|          | 0/59852 [00:00<?, ?it/s]

In [41]:
valid_df

Unnamed: 0,id,raw_address,POI/street,parsed,pred
0,3,"toko dita, kertosono",toko dita/,"[(toko, 0, toko), (dita, 0, dita), (,, 2, ,), ...",toko dita/
1,6,"kem mel raya, no 4 bojong rawalumbu rt 1 36 ra...",/kem mel raya,"[(kem, 1, kem), (mel, 1, mel), (raya, 1, raya)...",/kem mel raya
2,8,gg. i wates magersari,/gg. i,"[(gg., 1, gg.), (i, 1, i), (wates, 2, wates), ...",/gg. i
3,15,"kampung.gudang areng,desa:anyer, kecamatan:any...",gudang areng/,"[(kampung.gudang, 0, gudang), (areng, 0, areng...",kampung.gudang areng/
4,19,"tam tama barat v, banyumanik",/tam tama barat v,"[(tam, 1, tam), (tama, 1, tama), (barat, 1, ba...",/tam tama barat v
...,...,...,...,...,...
59847,299975,"taman kanak kanak dha wan karanggayam sren,",/,"[(taman, 2, taman), (kanak, 2, kanak), (kanak,...",taman kanak kanak dha wan karanggayam sren/
59848,299982,kan desa tanjung beringin,/kan desa,"[(kan, 1, kan), (desa, 1, desa), (tanjung, 2, ...",kan desa tanjung/
59849,299983,"la banda minima, cile raya, pesanggrahan",la banda minimarket/cile raya,"[(la, 0, la), (banda, 0, banda), (minima, 0, m...",la banda minima/cile raya
59850,299984,jend sudi 2 larangan kel. candi,/,"[(jend, 2, jend), (sudi, 2, sudi), (2, 2, 2), ...",/jend sudi


In [42]:
# Exact-match accuracy: a row counts only if the whole "POI/street" string
# equals the reference.
exact_matches = valid_df['POI/street'].eq(valid_df['pred'])
n_correct = exact_matches.sum()
acc = n_correct / len(valid_df)
print(acc)

0.6186426518746241


## Actual Run on Test Split

In [43]:
test_df = pd.read_csv("./scl-2021-ds/test.csv")

In [44]:
# Tag every test address and rebuild its "POI/street" answer string.
# Predictions are collected in a list and written to the frame in a single
# column assignment, instead of growing the column one cell at a time.
test_predictions = []
for sentence in tqdm(test_df['raw_address'], total=len(test_df)):
    # Surround commas with spaces so each one becomes its own token.
    tokens = sentence.replace(',', ' , ').split()

    tokens, tags, unks = tag_sentence(model,
                                      device,
                                      tokens,
                                      TEXT,
                                      LABEL_TAG)

    # Tag 0 marks POI tokens and tag 1 street tokens; all others are dropped.
    POI_tokens = [token for token, tag in zip(tokens, tags) if tag == 0]
    street_tokens = [token for token, tag in zip(tokens, tags) if tag == 1]

    POI = tokens_to_output(POI_tokens)
    street = tokens_to_output(street_tokens)
    test_predictions.append(f"{POI}/{street}")

test_df["POI/street"] = test_predictions

  0%|          | 0/50000 [00:00<?, ?it/s]

In [45]:
del test_df['raw_address']

In [46]:
test_df.to_csv("output_bilstm.csv", index=False, index_label=False)

## TODO
- [ ] Compare SGD and Adam optimizer
- [ ] Try using vocabulary set with `min_freq=1`
- [ ] Try adding a layer to correct the token in the decoder section
- [ ] Try Transformer Model
- [ ] Try augmenting the training data by replacing POI and street

In [2]:
# class TransformerModel(nn.Module):

#     def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
#         super(TransformerModel, self).__init__()
#         from torch.nn import TransformerEncoder, TransformerEncoderLayer
#         self.model_type = 'Transformer'
#         self.pos_encoder = PositionalEncoding(ninp, dropout)
#         encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
#         self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
#         self.encoder = nn.Embedding(ntoken, ninp)
#         self.ninp = ninp
#         self.decoder = nn.Linear(ninp, ntoken)

#         self.init_weights()

#     def generate_square_subsequent_mask(self, sz):
#         mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
#         mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
#         return mask

#     def init_weights(self):
#         initrange = 0.1
#         self.encoder.weight.data.uniform_(-initrange, initrange)
#         self.decoder.bias.data.zero_()
#         self.decoder.weight.data.uniform_(-initrange, initrange)

#     def forward(self, src, src_mask):
#         src = self.encoder(src) * math.sqrt(self.ninp)
#         src = self.pos_encoder(src)
#         output = self.transformer_encoder(src, src_mask)
#         output = self.decoder(output)
#         return output


In [3]:
# class PositionalEncoding(nn.Module):

#     def __init__(self, d_model, dropout=0.1, max_len=5000):
#         super(PositionalEncoding, self).__init__()
#         self.dropout = nn.Dropout(p=dropout)

#         pe = torch.zeros(max_len, d_model)
#         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
#         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
#         pe[:, 0::2] = torch.sin(position * div_term)
#         pe[:, 1::2] = torch.cos(position * div_term)
#         pe = pe.unsqueeze(0).transpose(0, 1)
#         self.register_buffer('pe', pe)

#     def forward(self, x):
#         x = x + self.pe[:x.size(0), :]
#         return self.dropout(x)

In [47]:
# ntokens = len(vocab.stoi) # the size of vocabulary
# emsize = 200 # embedding dimension
# nhid = 200 # the dimension of the feedforward network model in nn.TransformerEncoder
# nlayers = 2 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
# nhead = 2 # the number of heads in the multiheadattention models
# dropout = 0.2 # the dropout value
# model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)

In [48]:
# criterion = nn.CrossEntropyLoss()
# lr = 1.0 # learning rate
# optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

In [49]:
# import time
# def train():
#     model.train() # Turn on the train mode
#     total_loss = 0.
#     start_time = time.time()
#     src_mask = model.generate_square_subsequent_mask(bptt).to(device)
#     for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
#         data, targets = get_batch(train_data, i)
#         optimizer.zero_grad()
#         if data.size(0) != bptt:
#             src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)
#         output = model(data, src_mask)
#         loss = criterion(output.view(-1, ntokens), targets)
#         loss.backward()
#         torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
#         optimizer.step()

#         total_loss += loss.item()
#         log_interval = 200
#         if batch % log_interval == 0 and batch > 0:
#             cur_loss = total_loss / log_interval
#             elapsed = time.time() - start_time
#             print('| epoch {:3d} | {:5d}/{:5d} batches | '
#                   'lr {:02.2f} | ms/batch {:5.2f} | '
#                   'loss {:5.2f} | ppl {:8.2f}'.format(
#                     epoch, batch, len(train_data) // bptt, scheduler.get_lr()[0],
#                     elapsed * 1000 / log_interval,
#                     cur_loss, math.exp(cur_loss)))
#             total_loss = 0
#             start_time = time.time()

# def evaluate(eval_model, data_source):
#     eval_model.eval() # Turn on the evaluation mode
#     total_loss = 0.
#     src_mask = model.generate_square_subsequent_mask(bptt).to(device)
#     with torch.no_grad():
#         for i in range(0, data_source.size(0) - 1, bptt):
#             data, targets = get_batch(data_source, i)
#             if data.size(0) != bptt:
#                 src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)
#             output = eval_model(data, src_mask)
#             output_flat = output.view(-1, ntokens)
#             total_loss += len(data) * criterion(output_flat, targets).item()
#     return total_loss / (len(data_source) - 1)

In [50]:
# best_val_loss = float("inf")
# epochs = 3 # The number of epochs
# best_model = None

# for epoch in range(1, epochs + 1):
#     epoch_start_time = time.time()
#     train()
#     val_loss = evaluate(model, val_data)
#     print('-' * 89)
#     print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
#           'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
#                                      val_loss, math.exp(val_loss)))
#     print('-' * 89)

#     if val_loss < best_val_loss:
#         best_val_loss = val_loss
#         best_model = model

#     scheduler.step()

In [51]:
# get_batch(train_data, 2)

In [52]:
# model.generate_square_subsequent_mask(bptt)

In [53]:
# test_data