# PA4 Notebook
COSI-134A: StatNLP

Wanyue Xiao

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
%cd /content/drive/My Drive/PA4/

/content/drive/My Drive/PA4


In [4]:
!ls

 consts.py	   ptb			       requirements_with_versions.txt
 data_loader.py   'ptb.zip (Unzipped Files)'   seq2seq.py
 EVALB		   __pycache__		       train.py
 inference.py	   README.ipynb		       utils.py
 outputs	   README.md		       vocabs.py
 prepare_data.py   requirements.txt	       Wanyue_Xiao_PA4.ipynb


## Contents
0. Introduction
2. Prepare Data
3. Seq2Seq with Attention
4. Training
5. Inference

## 0. Introduction

* this notebook is designed primarily for those using Google Colab Notebooks
    * you don't need to use this if you'd rather work with .py files
    * it may help nevertheless to look at .py files because we import and use them here
* make sure to read [README](./README.ipynb) before you begin
* entire notebook trains Vanilla Seq2Seq out-of-the-box
* for questions: [jchun@brandeis.edu](mailto:jchun@brandeis.edu)

### Using GCP

* make sure to add all .py files in Colab's **Files** tab
* instead of importing from `seq2seq.py`, we will use the models defined in this notebook
    * write your implementations for attentional decoders here

In [5]:
# # uncomment to download required packages 
# # consider `requirements_with_versions.txt` if you'd like version constraints
!pip install -r requirements.txt

Collecting sacrebleu
  Downloading sacrebleu-2.0.0-py3-none-any.whl (90 kB)
[K     |████████████████████████████████| 90 kB 5.0 MB/s 
Collecting portalocker
  Downloading portalocker-2.3.2-py2.py3-none-any.whl (15 kB)
Collecting colorama
  Downloading colorama-0.4.4-py2.py3-none-any.whl (16 kB)
Installing collected packages: portalocker, colorama, sacrebleu
Successfully installed colorama-0.4.4 portalocker-2.3.2 sacrebleu-2.0.0


In [6]:
import nltk
nltk.__version__

'3.2.5'

In [7]:
import torch
torch.__version__

'1.10.0+cu111'

## 1. Prepare data

In [8]:
import os
import random

In [9]:
import vocabs
from prepare_data import linearize_parse_tree, prepare_data, process_sent

In [10]:
# prepare_data path hyperparams
PTB_DIR = './ptb'
DATA_DIR = './outputs/ptb'

In [11]:
os.path.abspath(PTB_DIR)

'/content/drive/MyDrive/PA4/ptb'

In [12]:
os.path.exists(PTB_DIR)

True

In [13]:
os.path.abspath(DATA_DIR)

'/content/drive/MyDrive/PA4/outputs/ptb'

In [14]:
os.makedirs(DATA_DIR, exist_ok=True)

In [15]:
# prepare_data hyperparams
LOWER = False
REVERSE_SENT = False
PRUNE = False
XX_NORM = False
CLOSING_TAG = False
KEEP_INDEX = False

### Prepare Data Playground

In [16]:
reader = nltk.corpus.BracketParseCorpusReader(f'{PTB_DIR}/dev', r'.*/wsj_.*\.mrg')

#### Original Sentence

In [17]:
sample_sent = reader.sents()[0]
" ".join(sample_sent)

"Influential members of the House Ways and Means Committee introduced legislation that *T*-1 would restrict how the new savings-and-loan bailout agency can raise capital *T*-2 , *-3 creating another potential obstacle to the government 's sale of sick thrifts ."

#### Preprocessed Sentence

In [18]:
# current setting
process_sent(sample_sent, lower=LOWER, reverse=REVERSE_SENT, keep_index=KEEP_INDEX)

"Influential members of the House Ways and Means Committee introduced legislation that *T* would restrict how the new savings-and-loan bailout agency can raise capital *T* , * creating another potential obstacle to the government 's sale of sick thrifts ."

In [19]:
process_sent(sample_sent, lower=False, reverse=False, keep_index=True)

"Influential members of the House Ways and Means Committee introduced legislation that *T*-1 would restrict how the new savings-and-loan bailout agency can raise capital *T*-2 , *-3 creating another potential obstacle to the government 's sale of sick thrifts ."

In [20]:
process_sent(sample_sent, lower=False, reverse=True, keep_index=False)

". thrifts sick of sale 's government the to obstacle potential another creating * , *T* capital raise can agency bailout savings-and-loan new the how restrict would *T* that legislation introduced Committee Means and Ways House the of members Influential"

In [21]:
process_sent(sample_sent, lower=True, reverse=False, keep_index=False)

"influential members of the house ways and means committee introduced legislation that *T* would restrict how the new savings-and-loan bailout agency can raise capital *T* , * creating another potential obstacle to the government 's sale of sick thrifts ."

In [22]:
process_sent(sample_sent, lower=True, reverse=True, keep_index=False)

". thrifts sick of sale 's government the to obstacle potential another creating * , *T* capital raise can agency bailout savings-and-loan new the how restrict would *T* that legislation introduced committee means and ways house the of members influential"

#### Original Parse Tree

In [23]:
sample_tree = reader.parsed_sents()[0]

In [24]:
" ".join(str(sample_tree).strip().split())

"(S (NP-SBJ (NP (JJ Influential) (NNS members)) (PP (IN of) (NP (DT the) (NNP House) (NNP Ways) (CC and) (NNP Means) (NNP Committee)))) (VP (VBD introduced) (NP (NP (NN legislation)) (SBAR (WHNP-1 (WDT that)) (S (NP-SBJ-3 (-NONE- *T*-1)) (VP (MD would) (VP (VB restrict) (SBAR (WHADVP-2 (WRB how)) (S (NP-SBJ (DT the) (JJ new) (NN savings-and-loan) (NN bailout) (NN agency)) (VP (MD can) (VP (VB raise) (NP (NN capital)) (ADVP-MNR (-NONE- *T*-2)))))) (, ,) (S-ADV (NP-SBJ (-NONE- *-3)) (VP (VBG creating) (NP (NP (DT another) (JJ potential) (NN obstacle)) (PP (TO to) (NP (NP (NP (DT the) (NN government) (POS 's)) (NN sale)) (PP (IN of) (NP (JJ sick) (NNS thrifts)))))))))))))) (. .))"

#### Preprocessed Parse Tree

In [25]:
# current setting
linearize_parse_tree(sample_tree, prune_leaf_brackets=PRUNE, XX_norm=XX_NORM, closing_tag=CLOSING_TAG, keep_index=KEEP_INDEX)

'(S (NP-SBJ (NP (JJ ) (NNS ) ) (PP (IN ) (NP (DT ) (NNP ) (NNP ) (CC ) (NNP ) (NNP ) ) ) ) (VP (VBD ) (NP (NP (NN ) ) (SBAR (WHNP (WDT ) ) (S (NP-SBJ (-NONE- ) ) (VP (MD ) (VP (VB ) (SBAR (WHADVP (WRB ) ) (S (NP-SBJ (DT ) (JJ ) (NN ) (NN ) (NN ) ) (VP (MD ) (VP (VB ) (NP (NN ) ) (ADVP-MNR (-NONE- ) ) ) ) ) ) (, ) (S-ADV (NP-SBJ (-NONE- ) ) (VP (VBG ) (NP (NP (DT ) (JJ ) (NN ) ) (PP (TO ) (NP (NP (NP (DT ) (NN ) (POS ) ) (NN ) ) (PP (IN ) (NP (JJ ) (NNS ) ) ) ) ) ) ) ) ) ) ) ) ) ) (. ) )'

In [26]:
linearize_parse_tree(
    sample_tree, prune_leaf_brackets=False, XX_norm=False, closing_tag=False, keep_index=True)

'(S (NP-SBJ (NP (JJ ) (NNS ) ) (PP (IN ) (NP (DT ) (NNP ) (NNP ) (CC ) (NNP ) (NNP ) ) ) ) (VP (VBD ) (NP (NP (NN ) ) (SBAR (WHNP-1 (WDT ) ) (S (NP-SBJ-3 (-NONE- ) ) (VP (MD ) (VP (VB ) (SBAR (WHADVP-2 (WRB ) ) (S (NP-SBJ (DT ) (JJ ) (NN ) (NN ) (NN ) ) (VP (MD ) (VP (VB ) (NP (NN ) ) (ADVP-MNR (-NONE- ) ) ) ) ) ) (, ) (S-ADV (NP-SBJ (-NONE- ) ) (VP (VBG ) (NP (NP (DT ) (JJ ) (NN ) ) (PP (TO ) (NP (NP (NP (DT ) (NN ) (POS ) ) (NN ) ) (PP (IN ) (NP (JJ ) (NNS ) ) ) ) ) ) ) ) ) ) ) ) ) ) (. ) )'

In [27]:
linearize_parse_tree(
    sample_tree, prune_leaf_brackets=False, XX_norm=False, closing_tag=True, keep_index=False)

'(S (NP-SBJ (NP (JJ )JJ (NNS )NNS )NP (PP (IN )IN (NP (DT )DT (NNP )NNP (NNP )NNP (CC )CC (NNP )NNP (NNP )NNP )NP )PP )NP-SBJ (VP (VBD )VBD (NP (NP (NN )NN )NP (SBAR (WHNP (WDT )WDT )WHNP (S (NP-SBJ (-NONE- )-NONE- )NP-SBJ (VP (MD )MD (VP (VB )VB (SBAR (WHADVP (WRB )WRB )WHADVP (S (NP-SBJ (DT )DT (JJ )JJ (NN )NN (NN )NN (NN )NN )NP-SBJ (VP (MD )MD (VP (VB )VB (NP (NN )NN )NP (ADVP-MNR (-NONE- )-NONE- )ADVP-MNR )VP )VP )S )SBAR (, ), (S-ADV (NP-SBJ (-NONE- )-NONE- )NP-SBJ (VP (VBG )VBG (NP (NP (DT )DT (JJ )JJ (NN )NN )NP (PP (TO )TO (NP (NP (NP (DT )DT (NN )NN (POS )POS )NP (NN )NN )NP (PP (IN )IN (NP (JJ )JJ (NNS )NNS )NP )PP )NP )PP )NP )VP )S-ADV )VP )VP )S )SBAR )NP )VP (. ). )S'

In [28]:
linearize_parse_tree(
    sample_tree, prune_leaf_brackets=False, XX_norm=True, closing_tag=False, keep_index=False)

'(XX (XX (XX (XX ) (XX ) ) (XX (XX ) (XX (XX ) (XX ) (XX ) (XX ) (XX ) (XX ) ) ) ) (XX (XX ) (XX (XX (XX ) ) (XX (XX (XX ) ) (XX (XX (XX ) ) (XX (XX ) (XX (XX ) (XX (XX (XX ) ) (XX (XX (XX ) (XX ) (XX ) (XX ) (XX ) ) (XX (XX ) (XX (XX ) (XX (XX ) ) (XX (XX ) ) ) ) ) ) (XX ) (XX (XX (XX ) ) (XX (XX ) (XX (XX (XX ) (XX ) (XX ) ) (XX (XX ) (XX (XX (XX (XX ) (XX ) (XX ) ) (XX ) ) (XX (XX ) (XX (XX ) (XX ) ) ) ) ) ) ) ) ) ) ) ) ) ) (XX ) )'

In [29]:
linearize_parse_tree(
    sample_tree, prune_leaf_brackets=True, XX_norm=False, closing_tag=False, keep_index=False)

'(S (NP-SBJ (NP JJ NNS ) (PP IN (NP DT NNP NNP CC NNP NNP ) ) ) (VP VBD (NP (NP NN ) (SBAR (WHNP WDT ) (S (NP-SBJ -NONE- ) (VP MD (VP VB (SBAR (WHADVP WRB ) (S (NP-SBJ DT JJ NN NN NN ) (VP MD (VP VB (NP NN ) (ADVP-MNR -NONE- ) ) ) ) ) , (S-ADV (NP-SBJ -NONE- ) (VP VBG (NP (NP DT JJ NN ) (PP TO (NP (NP (NP DT NN POS ) NN ) (PP IN (NP JJ NNS ) ) ) ) ) ) ) ) ) ) ) ) ) . )'

In [30]:
linearize_parse_tree(
    sample_tree, prune_leaf_brackets=True, XX_norm=True, closing_tag=False, keep_index=False)

'(XX (XX (XX XX XX ) (XX XX (XX XX XX XX XX XX XX ) ) ) (XX XX (XX (XX XX ) (XX (XX XX ) (XX (XX XX ) (XX XX (XX XX (XX (XX XX ) (XX (XX XX XX XX XX XX ) (XX XX (XX XX (XX XX ) (XX XX ) ) ) ) ) XX (XX (XX XX ) (XX XX (XX (XX XX XX XX ) (XX XX (XX (XX (XX XX XX XX ) XX ) (XX XX (XX XX XX ) ) ) ) ) ) ) ) ) ) ) ) ) XX )'

### Run Prepare Data
* preprocess both sentence and parse tree data
* compile vocab counters

In [31]:
datasets, vocab_counters = prepare_data(ptb_dir=PTB_DIR, out_dir=DATA_DIR, lower=LOWER, reverse_sent=REVERSE_SENT, prune_leaf_brackets=PRUNE, XX_norm=XX_NORM, closing_tag=CLOSING_TAG, keep_index=KEEP_INDEX)


Begin loading and processing PTB..

Loading dev..


[Processing Sents]: 100%|██████████| 1700/1700 [00:00<00:00, 2215.62it/s]
[Processing Trees]: 100%|██████████| 1700/1700 [00:02<00:00, 632.45it/s]


Sample ptb from dev
  Sent: Influential members of the House Ways and Means Committee introduced legislation that *T* would restrict how the new savings-and-loan bailout agency can raise capital *T* , * creating another potential obstacle to the government 's sale of sick thrifts .
  Tree: (S (NP-SBJ (NP (JJ ) (NNS ) ) (PP (IN ) (NP (DT ) (NNP ) (NNP ) (CC ) (NNP ) (NNP ) ) ) ) (VP (VBD ) (NP (NP (NN ) ) (SBAR (WHNP (WDT ) ) (S (NP-SBJ (-NONE- ) ) (VP (MD ) (VP (VB ) (SBAR (WHADVP (WRB ) ) (S (NP-SBJ (DT ) (JJ ) (NN ) (NN ) (NN ) ) (VP (MD ) (VP (VB ) (NP (NN ) ) (ADVP-MNR (-NONE- ) ) ) ) ) ) (, ) (S-ADV (NP-SBJ (-NONE- ) ) (VP (VBG ) (NP (NP (DT ) (JJ ) (NN ) ) (PP (TO ) (NP (NP (NP (DT ) (NN ) (POS ) ) (NN ) ) (PP (IN ) (NP (JJ ) (NNS ) ) ) ) ) ) ) ) ) ) ) ) ) ) (. ) )

Loading train..


[Processing Sents]: 100%|██████████| 39832/39832 [00:16<00:00, 2344.65it/s]
[Processing Trees]: 100%|██████████| 39832/39832 [00:44<00:00, 899.31it/s]


Sample ptb from train
  Sent: In an Oct. !DIGITS review of `` The Misanthrope '' at Chicago 's Goodman Theatre -LRB- `` Revitalized Classics Take the Stage in Windy City , '' Leisure & Arts -RRB- , the role of Celimene , played * by Kim Cattrall , was mistakenly attributed * to Christina Haag .
  Tree: (S (PP-LOC (IN ) (NP (NP (DT ) (NNP ) (CD ) (NN ) ) (PP (IN ) (NP (`` ) (NP-TTL (DT ) (NN ) ) ('' ) (PP-LOC (IN ) (NP (NP (NNP ) (POS ) ) (NNP ) (NNP ) ) ) ) ) (PRN (-LRB- ) (`` ) (S-HLN (NP-SBJ (VBN ) (NNS ) ) (VP (VBP ) (NP (DT ) (NN ) ) (PP-LOC (IN ) (NP (NNP ) (NNP ) ) ) ) ) (, ) ('' ) (NP-TMP (NN ) (CC ) (NNS ) ) (-RRB- ) ) ) ) (, ) (NP-SBJ (NP (NP (DT ) (NN ) ) (PP (IN ) (NP (NNP ) ) ) ) (, ) (VP (VBN ) (NP (-NONE- ) ) (PP (IN ) (NP-LGS (NNP ) (NNP ) ) ) ) (, ) ) (VP (VBD ) (VP (ADVP-MNR (RB ) ) (VBN ) (NP (-NONE- ) ) (PP-CLR (TO ) (NP (NNP ) (NNP ) ) ) ) ) (. ) )

Compiling Vocab..

Vocab Info:
  Sent (39175) => ,, the, ., !DIGITS, *, of, to, a, and, *T*, in, 's, that, for, *U*, $

[Processing Sents]: 100%|██████████| 2416/2416 [00:00<00:00, 2417.46it/s]
[Processing Trees]: 100%|██████████| 2416/2416 [00:02<00:00, 912.57it/s]


Sample ptb from test
  Sent: No , it was n't Black Monday .
  Tree: (S (INTJ (RB ) ) (, ) (NP-SBJ (PRP ) ) (VP (VBD ) (RB ) (NP-PRD (NNP ) (NNP ) ) ) (. ) )


In [32]:
dev_raw, train_raw, test_raw = datasets

## 2. Seq2Seq with Attention

In [33]:
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import consts as C

### A. Encoder

In [34]:
class Encoder(nn.Module):
    """Recurrent Encoder"""
    def __init__(self, sent_stoi, embed_dim, hidden_dim, num_layers, dropout, rnn_type):
        """configs and layers for Encoder

        Args:
          tree_stoi: tree str-to-int vocab
          embed_dim: embedding feature dimension
          hidden_dim: RNN hidden dimension
          num_layers: number of RNN layers
          dropout: dropout probability
          rnn_type: RNN, GRU or LSTM
        """
        super().__init__()
        ### configs
        self.rnn_type = rnn_type

        ### layers
        self.embedding = nn.Embedding(len(sent_stoi), embed_dim, padding_idx=sent_stoi[C.PAD])
        self.dropout = nn.Dropout(dropout)
        if rnn_type == C.GRU:
            self.rnn = nn.GRU(embed_dim, hidden_dim, num_layers, batch_first=True)
        elif rnn_type == C.LSTM:
            self.rnn = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        else:
            self.rnn = nn.RNN(embed_dim, hidden_dim, num_layers, batch_first=True)

        self.reset_parameters()

    def reset_parameters(self):
        for name, param in self.rnn.named_parameters():
            if 'weight' in name:
                nn.init.xavier_uniform_(param)
            elif 'bias' in name:
                nn.init.constant_(param, 0.)

    def init_pretrained_embedding(self, weights, finetune=False):
        """initializes nn.Embedding layer with pre-trained embedding weights

        Args:
          weights: GloVe embedding vector
          finetune: whether to finetune the embedding matrix during training
        """
        self.embedding = nn.Embedding.from_pretrained(weights, freeze=not finetune)

    def forward(self, x, lengths):
        """encodes source sentences

        Args:
          x: source tensor, (batch_size, sent_seq_len)
          lengths: valid source length list, (batch_size)

        Returns:
          outputs: RNN hidden states for all time-steps, (batch_size, sent_seq_len, hidden_dim)
          state: last RNN hidden state
            if RNN or GRU:  (num_layers, batch_size, hidden_dim)
            if LSTM: Tuple((num_layers, batch_size, hidden_dim),
                           (num_layers, batch_size, hidden_dim))
        """
        # (batch_size, src_seq_len, embed_size)
        x = self.dropout(self.embedding(x))
        x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        outputs, state = self.rnn(x)
        # output: (batch_size, src_seq_len, hidden_dim)
        outputs, _ = pad_packed_sequence(outputs, batch_first=True)
        return outputs, state

### B. Vanilla Decoder

In [35]:
class Decoder(nn.Module):
    """Vanilla Recurrent Decoder"""
    def __init__(self, tree_stoi, embed_dim, hidden_dim, num_layers, dropout, rnn_type):
        """configs and layers for Vanilla Decoder

        Args:
          tree_stoi: tree str-to-int vocab
          embed_dim: embedding feature dimension
          hidden_dim: RNN hidden dimension
          num_layers: number of RNN layers
          dropout: dropout probability
          rnn_type: RNN, GRU or LSTM
        """
        super().__init__()
        ### configs
        vocab_size = len(tree_stoi)

        ### layers
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=tree_stoi[C.PAD])
        self.dropout = nn.Dropout(dropout)
        if rnn_type == C.GRU:
            self.rnn = nn.GRU(embed_dim, hidden_dim, num_layers, batch_first=True)
        elif rnn_type == C.LSTM:
            self.rnn = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        else:
            self.rnn = nn.RNN(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.dense = nn.Linear(hidden_dim, vocab_size)

        self.reset_parameters()

    def reset_parameters(self):
        for name, param in self.rnn.named_parameters():
            if 'weight' in name:
                nn.init.xavier_uniform_(param)
            elif 'bias' in name:
                nn.init.constant_(param, 0.)
        nn.init.xavier_uniform_(self.dense.weight)
        nn.init.constant_(self.dense.bias, 0.)

    def forward(self, x, state, *args):
        """decodes for `tgt_seq_len` number of steps

        when `tgt_seq_len` == 1, we decode for one time-step. This is useful
          during inference, when not using teacher forcing

        when `tgt_seq_len` > 1, we decode for more than one-step. This can happen
          only when using teacher forcing, i.e. during training

        Args:
          x: target tensor, (batch_size, tgt_seq_len)
          state: previous hidden state
            if RNN or GRU:  (num_layers, batch_size, hidden_dim)
            if LSTM: Tuple((num_layers, batch_size, hidden_dim),
                           (num_layers, batch_size, hidden_dim))
          *args: captures other unused arguments

        Returns:
          output: token-level logits, (batch_size, tgt_seq_len, vocab_size)
          state: decoder's last RNN hidden state. Similar shape as `state`
        """
        # (batch_size, src_seq_len, embed_size)
        x = self.dropout(self.embedding(x))

        # output: (batch_size, tgt_seq_len, hidden_dim)
        output, state = self.rnn(x, state)

        # (batch_size, tgt_seq_len, vocab_size)
        output = self.dense(output)
        return output, state

### C. Bahdanau Attentional Decoder

In [36]:
class BahdanauAttentionDecoder(nn.Module):
    """Bahdanau (Additive) Attentional Decoder
    score = v^T \cdot \tanh(W_h \cdot H_h + W_e \cdot H_e)
    where H_e: encoder outputs, and H_h: previous decoder hidden state
    """
    def __init__(self, tree_stoi, embed_dim, hidden_dim, num_layers, dropout, rnn_type):
        """configs and layers for Decoder with Bahdanau Attention

        Args:
          tree_stoi: tree str-to-int vocab
          embed_dim: embedding feature dimension
          hidden_dim: RNN hidden dimension
          num_layers: number of RNN layers
          dropout: dropout probability
          rnn_type: RNN, GRU or LSTM
        """
        super().__init__()
        ### configs
        vocab_size = len(tree_stoi)

        ### layers
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=tree_stoi[C.PAD])
        self.dropout = nn.Dropout(dropout)
        if rnn_type == C.GRU:
            self.rnn = nn.GRU(embed_dim, hidden_dim, num_layers, batch_first=True)
        elif rnn_type == C.LSTM:
            self.rnn = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        else:
            self.rnn = nn.RNN(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.dense = nn.Linear(hidden_dim, vocab_size)
        
        self.query_compute = nn.Linear(hidden_dim, hidden_dim)
        self.key_compute = nn.Linear(hidden_dim, hidden_dim)
        self.score_compute = nn.Linear(hidden_dim, 1)

        self.reset_parameters()

    def reset_parameters(self):
        for name, param in self.rnn.named_parameters():
            if 'weight' in name:
                nn.init.xavier_uniform_(param)
            elif 'bias' in name:
                nn.init.constant_(param, 0.)
        nn.init.xavier_uniform_(self.dense.weight)
        nn.init.constant_(self.dense.bias, 0.)

    def forward(self, x, state, encoder_outputs, mask=None):
        """single decoder step with Bahdanau attention. Additive attention takes place when the outptus of `fc_hidden` and `fc_encoder` are summed. There are other interpretations that prefer concat over addition.
        `x` may be a gold decoder input (with teacher forcing) or a previous decoder's prediction (without teacher forcing).

        Args:
          x: decoder input, (batch_size, 1)
          state: decoder's previous RNN hidden state
            if RNN or GRU:  (num_layers, batch_size, hidden_dim)
            if LSTM: Tuple(
              (num_layers, batch_size, hidden_dim),
              (num_layers, batch_size, hidden_dim))
          encoder_outputs: RNN hidden states for all time-steps, (batch_size, src_seq_len, hidden_dim)
          mask: boolean tensor, (batch_size, 1)

        Returns:
          output: logits, (batch_size, 1, vocab_size)
          state: decoder's last RNN hidden state. Similar shape as `state`
        """
        x = self.embedding(x);
        output, state = self.rnn(x, state)
        score = self.score_compute(torch.tanh(self.query_compute(output) + self.key_compute(encoder_outputs))).squeeze(-1)
        attn = F.softmax(score, dim=-1)
        output = torch.bmm(attn.unsqueeze(1), encoder_outputs)
        output = self.dense(output)
        return output, state

### D. Luong Attentional Decoder

In [37]:
class LuongAttentionDecoder(nn.Module):
    """Luong (Multiplicative) Attention

    Dot:
        score = H_e \cdot H_h
    General:
        score = H_e \cdot W \cdot H_h

    where H_e: encoder outputs, and H_h: previous decoder hidden state

    There also exists Luong Concat, but you are not asked to implement it.
    """
    def __init__(self, tree_stoi, embed_dim, hidden_dim, num_layers, dropout, rnn_type, mode):
        """configs and layers for Decoder with Luong Attention

        Args:
          tree_stoi: tree str-to-int vocab
          embed_dim: embedding feature dimension
          hidden_dim: RNN hidden dimension
          num_layers: number of RNN layers
          dropout: dropout probability
          rnn_type: RNN, GRU or LSTM
          mode: `dot` or `general`
        """
        super().__init__()
        ### configs
        vocab_size = len(tree_stoi)
        self.mode = mode;

        ### layers
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=tree_stoi[C.PAD])
        self.dropout = nn.Dropout(dropout)
        if rnn_type == C.GRU:
            self.rnn = nn.GRU(embed_dim, hidden_dim, num_layers, batch_first=True)
        elif rnn_type == C.LSTM:
            self.rnn = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        else:
            self.rnn = nn.RNN(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.dense = nn.Linear(hidden_dim, vocab_size)
        # general mode
        self.attention = nn.Linear(hidden_dim, hidden_dim)

        self.reset_parameters()

    def reset_parameters(self):
        for name, param in self.rnn.named_parameters():
            if 'weight' in name:
                nn.init.xavier_uniform_(param)
            elif 'bias' in name:
                nn.init.constant_(param, 0.)
        nn.init.xavier_uniform_(self.dense.weight)
        nn.init.constant_(self.dense.bias, 0.)

    def forward(self, x, state, encoder_outputs, mask=None):
        """single decoder step with Luong attention

        `x` may be a gold decoder input (with teacher forcing) or a previous decoder's
        prediction (without teacher forcing).

        Args:
          x: decoder input, (batch_size, 1)
          state: decoder's previous RNN hidden state
            if RNN or GRU:  (num_layers, batch_size, hidden_dim)
            if LSTM: Tuple(
              (num_layers, batch_size, hidden_dim),
              (num_layers, batch_size, hidden_dim)
            )
          encoder_outputs: RNN hidden states for all time-steps, (batch_size, src_seq_len, hidden_dim)
          mask: boolean tensor, (batch_size, 1)

        Returns:
          output: logits, (batch_size, 1, vocab_size)
          state: decoder's last RNN hidden state. Similar shape as `state`
        """

        if self.mode == 'general':
            x = self.embedding(x);
            output, state = self.rnn(x, state)
            attn_weights = self.attention(output)
            score = torch.bmm(attn_weights, encoder_outputs.transpose(-2, -1))

            attn = F.softmax(score, dim=-1)
            output = torch.bmm(attn, encoder_outputs)
            output = self.dense(output)
            return output, state

        elif self.mode == 'dot':
            x = self.embedding(x);
            output, state = self.rnn(x, state)
            score = torch.bmm(output, encoder_outputs.transpose(-2, -1))
            attn = F.softmax(score, dim=-1)
            output = torch.bmm(attn, encoder_outputs)
            output = self.dense(output)
            return output, state

### E. Seq2Seq

In [38]:
class Seq2Seq(nn.Module):
    """Seq2Seq"""
    def __init__(self, *, model, sent_stoi, tree_stoi, embed_dim, hidden_dim,
                 num_layers, dropout, rnn_type, glove, finetune_glove, device):
        super().__init__()
        print(f"\n{model.capitalize()} Seq2Seq init")

        ### configs
        self.model = model

        self.tree_vocab_size = len(tree_stoi)
        self.sent_pad_idx = sent_stoi[C.PAD]
        self.tree_bos_idx = tree_stoi[C.BOS]
        self.tree_eos_idx = tree_stoi[C.EOS]

        self.dropout = dropout
        self.rnn_type = rnn_type

        ### modules
        self.encoder = Encoder(sent_stoi=sent_stoi, embed_dim=embed_dim, hidden_dim=hidden_dim, num_layers=num_layers, dropout=dropout, rnn_type=rnn_type)
        if glove is not None:
            self.encoder.init_pretrained_embedding(glove, finetune=finetune_glove)

        decoder_kwargs = {'tree_stoi': tree_stoi, 'embed_dim': embed_dim, 'hidden_dim': hidden_dim, 'num_layers': num_layers, 'dropout': dropout, 'rnn_type': rnn_type}
        if model == C.BAHDANAU:
            self.decoder = BahdanauAttentionDecoder(**decoder_kwargs)
        elif model in [C.LUONG_DOT, C.LUONG_GENERAL]:
            mode = model.split('_')[-1]
            print(mode)
            self.decoder = LuongAttentionDecoder(**decoder_kwargs, mode=mode)
        else:
            self.decoder = Decoder(**decoder_kwargs)

        self.device = device

    def parse(self, x, x_lens):
        """forward computation for inference

        during inference we keep it simple by assuming `batch_size` == 1

        Args:
          x: (1, src_seq_len)
          x_lens: (1)

        Returns:
          predictions as list of ints
        """
        # encode step
        encoder_outputs, state = self.encoder(x, x_lens)

        # setup for decoding loop
        max_decode_step = x.size(1) * 3

        # padding mask necessary for attentional decoders; not used by vanilla decoder
        # True if valid token, False otherwise (i.e. padding)
        padding_mask = x==self.sent_pad_idx

        # (1, 1)
        yt = torch.tensor([[self.tree_bos_idx]], dtype=torch.long, device=self.device)

        # decoding loop
        preds = []
        for i in range(max_decode_step):
            # decode step
            output, state = self.decoder(yt, state, encoder_outputs, padding_mask)

            # current time-step prediction
            yt = output.argmax(-1)
            yt_int = yt.item()

            # when parsing, no need to keep EOS in our predictions but simply terminate
            if yt_int == self.tree_eos_idx:
                break
            preds.append(yt_int)

        return preds

    def forward(self, x, y, x_lens, teacher_forcing_ratio=0.0):
        """forward computation for training

        `x` and `y` are padded tensors where:
        `x`: each row has valid tokens + EOS + possibly PADs
        `y`: except for row with longest valid length, has BOS + valid tokens + EOS + possibly PADs
        row with longest valid length has BOS + valid tokens, because `y` == `trees[:,:-1]`
        from the training loop in `train.py`

        See Recitation Week 12 & 13 slides for details

        Args:
          x: (batch_size, src_seq_len)
          y: (batch_size, tgt_seq_len)
          x_lens: (batch_size)
          teacher_forcing_ratio: float to determine whether to use teacher forcing
            at each decoding step

        Returns:
          token-level logits, (batch_size, tgt_seq_len, tree_vocab_size)
        """
        # encode step
        encoder_outputs, state = self.encoder(x, x_lens)

        if teacher_forcing_ratio == 1.:
            assert self.model == C.VANILLA, \
                f'full teacher forcing only supports Vanilla Seq2Seq, but your model is {self.model}'

            # decode step with teacher forcing; let PyTorch iterate through `tgt_seq_len` dim internally
            outputs, _ = self.decoder(y, state)
        else:
            batch_size, tgt_seq_len = y.shape

            # padding mask necessary for attentional decoders; not used by vanilla decoder
            # True if valid token, False otherwise (i.e. padding)
            padding_mask = x==self.sent_pad_idx

            # decoding initial inputs as BOS, (batch_size, 1)
            yt = y[:, 0].unsqueeze(-1)

            # the first two dimensions are swapped to make storing easier
            outputs = torch.zeros([tgt_seq_len, batch_size, self.tree_vocab_size], device=self.device)

            # manual iteration through `tgt_seq_len` dimension
            for i in range(tgt_seq_len):
                # output: (batch_size, 1, tree_vocab_size)
                output, state = self.decoder(yt, state, encoder_outputs, padding_mask)

                # save the model output: (1, batch_size, tree_vocab_size)
                outputs[i] = output.transpose(0, 1)

                # decoding input: (batch_size, 1)
                if random.random() < teacher_forcing_ratio:
                    # without teacher forcing, use current prediction
                    yt = output.argmax(-1)
                else:
                    # with teacher forcing, fetch the next step's gold input
                    try:
                        yt = y[:,i+1].reshape([batch_size, 1])
                    except IndexError:
                        pass # last step, will terminate

            # (batch_size, tgt_seq_len, tree_vocab_size)
            outputs = outputs.transpose(0, 1)

        return outputs

## 3. Training

Here we do not import from `seq2seq.py`
* instead we use our models defined above

In [39]:
import utils
from data_loader import PTB, init_data_loader, load_ptb_dataset
from inference import predict
import torch.optim as optim
import tqdm

In [40]:
# model path hyperparams
MODEL_DIR = './outputs/model'
GLOVE_DIR = None # Change this to your GloVe dir if you wish to use GloVe embedding

In [48]:
# model data hyperparams
GLOVE_NAME = C.GLOVE_6B
GLOVE_STRATEGY = C.KEEP_OVERLAP
WITH_TORCHTEXT = False
FINETUNE_GLOVE = False
SENT_THRESHOLD = 5
TREE_THRESHOLD = 5

In [49]:
# model hyperparams
EMBED_DIM = 300   # 100, 200, 300
RNN = C.LSTM      # C.RNN, C.GRU, C.LSTM
NUM_LAYERS = 2    # 1, 2, 3
HIDDEN_DIM = 256  # 64, 128, 256
DROPOUT = 0.2     # 0.2, 0.5, 0.7

# model experiment hyperparams
EPOCHS = 10
EVAL_EVERY = 10
BATCH_SIZE = 68              # 68, 128, 200
LEARNING_RATE = 0.001        # 0.001, 0.005, 0.01
TEACHER_FORCING_RATIO = 0.75 # 0.5, 0.75
CHECKPOINT = None
SEED = 1334

### A. Setup

In [50]:
os.path.abspath(MODEL_DIR)
os.makedirs(MODEL_DIR, exist_ok=True)
utils.set_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# init model state dict filename
model_path_template = os.path.join(MODEL_DIR, C.MODEL_PT_TEMPLATE)
device, model_path_template

(device(type='cuda'), './outputs/model/epoch_{}.pt')

### B. Load Vocab
initialize vocab itos and stoi from counter
###### Optionally load glove, which may modify `sent_vocab`

In [60]:
sent_itos, sent_stoi = vocabs.init_vocab(vocab_counters[0], SENT_THRESHOLD, special_symbols=[C.PAD, C.UNK, C.EOS])
tree_itos, tree_stoi = vocabs.init_vocab(vocab_counters[1], TREE_THRESHOLD, special_symbols=[C.PAD, C.UNK, C.BOS, C.EOS])

vocabs.display_vocabs(sent_itos, tree_itos)

glove = None
if GLOVE_DIR is not None:
    glove, sent_itos, sent_stoi = vocabs.init_glove(
        glove_dir=GLOVE_DIR, name=GLOVE_NAME, embed_dim=EMBED_DIM, 
        sent_itos=sent_itos, strategy=GLOVE_STRATEGY, with_torchtext=WITH_TORCHTEXT)


Vocab Info:
  Sent (11377) => <pad>, <unk>, <eos>, ,, the, ., !DIGITS, *, of, to, a, and, *T*, in, 's, that, for, *U*, $, ``, is, The, '', said, on, %, it, by, from, million, at, as, with, Mr., was, be, are, its, has, n't, an, will, have, !YEAR, he, or, company, year, which, would, about, --, says, they, were, this, market, more, billion, had, But, In, his, up, their, but, than, U.S., been, who, share, also, new, one, other, :, not, some, Corp., stock, I, years, New, shares, -RRB-, It, -LRB-, ;, could, all, Inc., last, two, out, &, trading, *ICH*, because, when, sales, ...
  Tree (195) => <pad>, <unk>, <bos>, <eos>, ), (NP, (VP, (NN, (IN, (NP-SBJ, (NNP, (S, (DT, (-NONE-, (JJ, (NNS, (PP, (,, (., (CD, (RB, (VBD, (VB, (CC, (SBAR, (TO, (VBZ, (VBN, (PRP, (VBG, (PP-LOC, (VBP, (PP-CLR, (ADVP, (MD, (QP, (WHNP, (POS, (PRP$, (PP-TMP, (ADJP, ($, (``, ('', (ADJP-PRD, (ADVP-TMP, (NP-PRD, (:, (PP-DIR, (NP-TMP, (WDT, (S-TPC, (JJR, (S-NOM, (SBAR-ADV, (NNPS, (RP, (PRT, (WHADVP, (NP-LGS, (ADVP-MNR, (PR

### C. Load PTB

In [61]:
dev_raw = load_ptb_dataset(DATA_DIR, C.DEV)
dev = PTB(C.DEV, dev_raw, sent_stoi, tree_stoi)


Dev PTB init
Sample vector from Dev
  Sent: Mrs. Hills said !DIGITS the U.S. wo n't accept any delays after Nov. !DIGITS because U.S. fish-processing firms enter into contracts *ICH* in the fall * to purchase the next season 's catch . <eos>
  Sent Vector: [983, 3529, 23, 6, 4, 67, 432, 39, 1534, 125, 2693, 102, 443, 6, 97, 67, 1, 414, 3002, 104, 663, 96, 13, 4, 676, 7, 9, 533, 4, 151, 1606, 14, 3081, 5, 2]
  Tree: <bos> (S (NP-SBJ (NNP ) (NNP ) ) (VP (VBD ) (SBAR (-NONE- ) (S (NP-SBJ (DT ) (NNP ) ) (VP (MD ) (RB ) (VP (VB ) (NP (DT ) (NNS ) ) (PP-TMP (IN ) (NP (NNP ) (CD ) ) ) (SBAR-PRP (IN ) (S (NP-SBJ (NNP ) (JJ ) (NNS ) ) (VP (VBP ) (PP-CLR (IN ) (NP (NNS ) (S (-NONE- ) ) ) ) (PP-TMP (IN ) (NP (DT ) (NN ) ) ) (S (NP-SBJ (-NONE- ) ) (VP (TO ) (VP (VB ) (NP (NP (DT ) (JJ ) (NN ) (POS ) ) (NN ) ) ) ) ) ) ) ) ) ) ) ) ) (. ) ) <eos>
  Tree Vector: [2, 11, 9, 10, 4, 10, 4, 4, 6, 21, 4, 24, 13, 4, 11, 9, 12, 4, 10, 4, 4, 6, 34, 4, 20, 4, 6, 22, 4, 5, 12, 4, 15, 4, 4, 39, 8, 4, 5, 10, 4, 

In [62]:
training_raw = load_ptb_dataset(DATA_DIR, C.TRAIN)
training = PTB(C.TRAIN, training_raw, sent_stoi, tree_stoi)


Train PTB init
Sample vector from Train
  Sent: Morgan Stanley is expected * to price another junk bond deal , $ !DIGITS million *U* of senior subordinated debentures by Continental Cablevision Inc. , next Tuesday . <eos>
  Sent Vector: [915, 1415, 20, 174, 7, 9, 131, 265, 500, 408, 505, 3, 18, 6, 29, 17, 8, 462, 1510, 1392, 27, 1828, 10165, 90, 3, 151, 583, 5, 2]
  Tree: <bos> (S (NP-SBJ (NNP ) (NNP ) ) (VP (VBZ ) (VP (VBN ) (S (NP-SBJ (-NONE- ) ) (VP (TO ) (VP (VB ) (NP (NP (DT ) (NN ) (NN ) (NN ) ) (, ) (NP (NP (QP ($ ) (CD ) (CD ) ) (-NONE- ) ) (PP (IN ) (NP (JJ ) (JJ ) (NNS ) ) ) (PP (IN ) (NP (NNP ) (NNP ) (NNP ) ) ) ) (, ) ) (NP-TMP (JJ ) (NNP ) ) ) ) ) ) ) (. ) ) <eos>
  Tree Vector: [2, 11, 9, 10, 4, 10, 4, 4, 6, 26, 4, 6, 27, 4, 11, 9, 13, 4, 4, 6, 25, 4, 6, 22, 4, 5, 5, 12, 4, 7, 4, 7, 4, 7, 4, 4, 17, 4, 5, 5, 35, 41, 4, 19, 4, 19, 4, 4, 13, 4, 4, 16, 8, 4, 5, 14, 4, 14, 4, 15, 4, 4, 4, 16, 8, 4, 5, 10, 4, 10, 4, 10, 4, 4, 4, 4, 17, 4, 4, 49, 14, 4, 10, 4, 4, 4, 4, 4, 4, 4,

### D. Training Data Loader init

In [63]:
train_dataloader = init_data_loader(training, sent_stoi, tree_stoi, BATCH_SIZE)

### E. Model init

In [65]:
model = Seq2Seq(
    model=C.BAHDANAU, sent_stoi=sent_stoi, tree_stoi=tree_stoi, embed_dim=EMBED_DIM, hidden_dim=HIDDEN_DIM, 
    num_layers=NUM_LAYERS, dropout=DROPOUT, rnn_type=RNN, glove=glove, finetune_glove=FINETUNE_GLOVE, device=device)

num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
f'{num_trainable_params:,}'


Bahdanau Seq2Seq init


'5,849,012'

load from checkpoint and resume training if `CHECKPOINT` provided

In [66]:
epoch = 0
if CHECKPOINT is not None:
    ckpt = torch.load(CHECKPOINT, map_location='cpu') # always load initially to RAM
    model.load_state_dict(ckpt['model'])
    epoch = ckpt['epoch']
    print("Resume training with Token-Level ACC {:.3f} | BLEU {:.2f} at epoch {}".format(
      ckpt['acc'], ckpt['bleu'], epoch))

# move all model parameters to `device`
model.to(device)

Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(11377, 300, padding_idx=0)
    (dropout): Dropout(p=0.2, inplace=False)
    (rnn): LSTM(300, 256, num_layers=2, batch_first=True)
  )
  (decoder): BahdanauAttentionDecoder(
    (embedding): Embedding(195, 300, padding_idx=0)
    (dropout): Dropout(p=0.2, inplace=False)
    (rnn): LSTM(300, 256, num_layers=2, batch_first=True)
    (dense): Linear(in_features=256, out_features=195, bias=True)
    (query_compute): Linear(in_features=256, out_features=256, bias=True)
    (key_compute): Linear(in_features=256, out_features=256, bias=True)
    (score_compute): Linear(in_features=256, out_features=1, bias=True)
  )
)

### F. Loss and Optimizer init

In [67]:
criterion = nn.CrossEntropyLoss(ignore_index=tree_stoi[C.PAD])
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

### G. Training Loop

first, define eval function to be used inside training loop

In [74]:
def evaluate(model, dev, dev_raw, epoch, tree_itos, device):
    print("Begin Inference on Dev..")
    preds, dev_acc = predict(model, dev, tree_itos, device)

    # bleu
    bleu = utils.raw_corpus_bleu(preds, dev_raw[1])
    print('  Dev Token-level Accuracy: {:.3f} | BLEU: {:.2f}'.format(dev_acc, bleu))

    # export model params and other misc info
    model_fpath = model_path_template.format(epoch+1)
    print("Exporting model params at", model_fpath)
    torch.save({'model': model.state_dict(), 'epoch': epoch+1, 'acc': dev_acc, 'bleu': bleu},  model_fpath)

    # export dev predictions
    dev_preds_path = os.path.join(MODEL_DIR, C.TREE_PRED_TEMPLATE.format('dev', f'_{epoch+1}'))
    print("Exporting dev predictions at", dev_preds_path)
    utils.export_txt(preds, dev_preds_path)

    # sample prediction display
    print("Sample Dev Prediction")
    sample_idx = random.randint(0, len(dev)-1)
    print("  [SENT]", dev_raw[0][sample_idx])
    sample_gold = dev_raw[1][sample_idx]
    print(f"  [GOLD: {len(sample_gold.split())} toks]", sample_gold)
    sample_pred = preds[sample_idx]
    print(f"  [PRED: {len(sample_pred.split())} toks]", sample_pred)

# adjust target epoch value
target_epochs = 10
epoch, target_epochs

(9, 10)

In [75]:
for epoch in range(epoch, target_epochs):
    epoch_loss = 0
    num_correct = num_tokens = 0

    for batch in tqdm.tqdm(train_dataloader, desc=f'[Training {epoch+1}/{target_epochs}]'):
        optimizer.zero_grad()
        sents, trees, sent_lens = utils.to_device(batch, device)
        # since last index of `trees` is always PAD or EOS, there is no need to predict a token with PAD or EOS as inputs when decoding; hence we omit it output logits: (batch_size, tgt_seq_len-1, vocab_size)
        outputs = model(sents, trees[:,:-1], sent_lens, teacher_forcing_ratio=TEACHER_FORCING_RATIO)
        # decoder output (gold target) that drops BOS at the beginning since BOS is never part of our predictions
        trees_target = trees[:,1:]
        # for the reason why `outputs` are transposed, see `Shape:` in https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
        loss = criterion(outputs.transpose(1,2), trees_target)
        # token-level accuracy: (batch_size, tgt_seq_len-1) model predictions as ints: (batch_size, tgt_seq_len-1)
        preds = outputs.argmax(-1)
        trees_mask = trees_target == tree_stoi[C.PAD]
        num_correct += (preds==trees_target).masked_fill_(trees_mask, False).sum()
        num_tokens += (~trees_mask).sum() # includes <eos> which needs to be predicted correctly
        # optimizer step
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()
        epoch_loss += loss.item()

    acc = (num_correct.item() / num_tokens.item())
    print(' Training Loss: {:.2f} Token-level Accuracy: {:.3f}'.format(epoch_loss, acc))
    ### 7. eval on dev
    if (epoch+1) % EVAL_EVERY == 0:
        evaluate(model, dev, dev_raw, epoch, tree_itos, device)

[Training 10/10]: 100%|██████████| 586/586 [03:45<00:00,  2.60it/s]


 Training Loss: 399.27 Token-level Accuracy: 0.796
Begin Inference on Dev..


100%|██████████| 1700/1700 [01:19<00:00, 21.36it/s]


  Dev Token-level Accuracy: 0.498 | BLEU: 67.14
Exporting model params at ./outputs/model/epoch_10.pt
Exporting dev predictions at ./outputs/model/dev_pred_10.txt
Sample Dev Prediction
  [SENT] Clients `` are all staying out '' of the market , one Merrill trader says !DIGITS *T* .
  [GOLD: 62 toks] (S (S-TPC (NP-SBJ (NP (NNS ) ) ) (`` ) (VP (VBP ) (DT ) (VP (VBG ) (PP-LOC (IN ) ('' ) (PP (IN ) (NP (DT ) (NN ) ) ) ) ) ) ) (, ) (NP-SBJ (CD ) (NNP ) (NN ) ) (VP (VBZ ) (SBAR (-NONE- ) (S (-NONE- ) ) ) ) (. ) )
  [PRED: 57 toks] (S (S-TPC (NP-SBJ ) ) (`` ) (VP (VBP ) (NP-PRD (DT ) ) ) ) (PRT (RP ) ) ('' ) (PP (IN ) (NP (DT ) (NN ) ) ) ) ) ) (NP-SBJ (NP-SBJ (NP-SBJ ) ) (NNP ) (NN ) ) (VP (VBZ ) (SBAR (-NONE- ) (S (-NONE- ) ) ) )


## 4. Inference

In [76]:
# we already loaded `test_raw` when preparing data
test = PTB(C.TEST, test_raw, sent_stoi, tree_stoi)
INFERENCE_CKPT = None # if None, will use current `model` params


Test PTB init
Sample vector from Test
  Sent: This was an October massacre '' like those that *T* occurred in !YEAR and !YEAR . <eos>
  Sent Vector: [231, 34, 40, 688, 6183, 22, 187, 175, 15, 12, 3041, 13, 43, 11, 43, 5, 2]
  Tree: <bos> (S (NP-SBJ (DT ) ) (VP (VBD ) (NP-PRD (NP (DT ) (NNP ) (NN ) ) ('' ) (PP (IN ) (NP (NP (DT ) ) (SBAR (WHNP (WDT ) ) (S (NP-SBJ (-NONE- ) ) (VP (VBD ) (PP-TMP (IN ) (NP (CD ) (CC ) (CD ) ) ) ) ) ) ) ) ) ) (. ) ) <eos>
  Tree Vector: [2, 11, 9, 12, 4, 4, 6, 21, 4, 46, 5, 12, 4, 10, 4, 7, 4, 4, 43, 4, 16, 8, 4, 5, 5, 12, 4, 4, 24, 36, 50, 4, 4, 11, 9, 13, 4, 4, 6, 21, 4, 39, 8, 4, 5, 19, 4, 23, 4, 19, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 18, 4, 4, 3]


In [77]:
if INFERENCE_CKPT is not None:
    ckpt = torch.load(args.checkpoint, map_location='cpu')  # always load initially to RAM
    model.load_state_dict(ckpt['model'])
    print("Successfully loaded checkpoint from {}: Dev Token-Level ACC {:.3f} | BLEU {:.2f}".format(
        args.checkpoint, ckpt['acc'], ckpt['bleu']))

In [78]:
# run prediction step
preds, acc = predict(model, test, tree_itos, device)
# export predictions (this will overwrite existing prediction file)
preds_path = os.path.join(MODEL_DIR, C.TREE_PRED_TEMPLATE.format('test', ''))
utils.export_txt(preds, preds_path)

100%|██████████| 2416/2416 [01:52<00:00, 21.45it/s]


In [79]:
# bleu
bleu = utils.raw_corpus_bleu(preds, test_raw[1])
"Test Token-level Accuracy: {:.3f} | BLEU: {:.2f}".format(acc, bleu)

'Test Token-level Accuracy: 0.491 | BLEU: 66.68'