In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from fastai.imports import *
from fastai.plots import *
from torchtext.data import Field
from fastai.lm_rnn import seq2seq_reg
from quicknlp import SpacyTokenizer, print_batch, S2SModelData
from pathlib import Path

In [3]:
import matplotlib.pyplot as plt
import numpy as np

In [4]:
# Root directory of the translation dataset. Per the recorded `ls` output it
# contains `train/` and `validation/` subdirectories (plus `data/`, `models/`, `tmp/`).
DATAPATH = "dataset/translation"


In [5]:
!cd $DATAPATH; ls

data  models  tmp  train  validation


In [6]:
# Special start/end-of-sentence markers added to every tokenized sentence.
INIT_TOKEN = "<sos>"
EOS_TOKEN = "<eos>"

# One torchtext Field per language: spaCy tokenization, lower-casing and the
# start/end markers above. English is built first, then French.
fields = [
    (name, Field(init_token=INIT_TOKEN, eos_token=EOS_TOKEN,
                 tokenize=SpacyTokenizer(lang), lower=True))
    for name, lang in (("english", "en"), ("french", "fr"))
]

batch_size = 64

# Build the train/validation datasets and loaders from the `train` and
# `validation` subdirectories of DATAPATH. Both languages are listed as sources
# (the French side presumably feeds the decoder input — TODO confirm in quicknlp);
# French is the prediction target.
data = S2SModelData.from_text_files(
    path=DATAPATH,
    fields=fields,
    train="train",
    validation="validation",
    source_names=["english", "french"],
    target_names=["french"],
    bs=batch_size,
)




In [7]:
# Sanity-check dataset sizes: batches per loader and examples per dataset.
# NOTE(review): the recorded output shows identical train and validation sizes
# (201 each) — confirm the validation split is actually distinct from train.
print(f'num tr batches: {len(data.trn_dl)}, num tr samples: {len(data.trn_ds)}')
print(f'num val batches: {len(data.val_dl)}, num val samples: {len(data.val_ds)}')

num tr batches: 4, num tr samples: 201
num val batches: 4,num val samples: 201


In [8]:
# Model hyperparameters, kept as notebook globals so later cells can reuse them.
emb_size = 300  # token embedding size (passed as `emb_sz`)
nh = 1024       # RNN hidden size (passed as `nhid`)
nl = 3          # number of RNN layers
tnh = 512       # attention hidden size (passed as `att_nhid`)

# Bidirectional encoder + attention decoder. `max_iterations` presumably caps
# the number of decoding steps — TODO confirm against quicknlp.
learner = data.get_model(
    emb_sz=emb_size, nhid=nh, nlayers=nl, bidir=True,
    max_iterations=30, att_nhid=tnh, attention=True,
)

# Training-time regularisation and gradient clipping. alpha/beta are presumably
# the AR/TAR activation penalties of fastai's seq2seq_reg — confirm.
reg_fn = partial(seq2seq_reg, alpha=2, beta=1)
clip = 0.3
learner.reg_fn, learner.clip = reg_fn, clip

In [9]:
# Inspect the model architecture (encoder RNN stack and attention decoder; see output).
learner.summary()

Seq2SeqAttention(
  (encoder): EmbeddingRNNEncoder(
    (rnns): ModuleList(
      (0): Cell(
        (cell): WeightDrop(
          (module): LSTM(300, 512, dropout=0.3, bidirectional=True)
        )
      )
      (1): Cell(
        (cell): WeightDrop(
          (module): LSTM(1024, 512, dropout=0.3, bidirectional=True)
        )
      )
      (2): Cell(
        (cell): WeightDrop(
          (module): LSTM(1024, 150, dropout=0.3, bidirectional=True)
        )
      )
    )
    (dropouths): ModuleList(
      (0): LockedDropout(
      )
      (1): LockedDropout(
      )
      (2): LockedDropout(
      )
    )
    (encoder): Embedding(102, 300, padding_idx=1)
    (encoder_with_dropout): EmbeddingDropout(
      (embed): Embedding(102, 300, padding_idx=1)
    )
    (dropouti): LockedDropout(
    )
  )
  (decoder): RNNAttentionDecoder(
    (rnns): ModuleList(
      (0): Cell(
        (cell): WeightDrop(
          (module): LSTM(600, 1024, dropout=0.3)
        )
      )
      (1): Cell(
      

In [10]:
# Show input / target / prediction for a few validation sentences. The model is
# untrained at this point, so predictions are degenerate (repeated tokens in the
# recorded output).
print_batch(
    lr=learner,
    dt=data,
    input_field="english",
    output_field="french",
    num_sentences=4,
)

batch: 0 sample : 0
input: am i fat ?
target: - je grosse ?
prediction: ['on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on']


batch: 0 sample : 1
input: am i fat ?
target: - je gros ?
prediction: ['on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on']


batch: 0 sample : 2
input: ask tom .
target: à tom .
prediction: ['on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on on

In [12]:
# Learning-rate range test (fastai `Learner.lr_find`) to choose a learning rate for training.
# NOTE(review): execution count jumps from In [10] to In [12]; re-run the notebook
# top-to-bottom (Restart & Run All) before sharing to rule out hidden state.
learner.lr_find()