# LSTM

Let's import the required modules and set up logging.

In [19]:
import logging
import random

import nltk
import torch
import xtagger
from sklearn.model_selection import train_test_split
from xtagger import LabelEncoder, WhiteSpaceTokenizer, RNNTagger, Accuracy, F1, ClasswiseF1
from xtagger.utils.logging_helpers import LoggingHandler


# Log line layout: "<timestamp> - <message>", with a second-resolution timestamp.
LOG_FORMAT = '%(asctime)s - %(message)s'
LOG_DATE_FORMAT = '%Y-%m-%d %H:%M:%S'

# Route INFO-and-above records through x-tagger's LoggingHandler.
logging.basicConfig(
    level=logging.INFO,
    format=LOG_FORMAT,
    datefmt=LOG_DATE_FORMAT,
    handlers=[LoggingHandler()],
)

We load our dataset from NLTK. x-tagger only accepts input of type `xtagger.DATASET_TYPE`, but it provides a rich collection of data converters. Then we set up our processors: the pretokenizer, the tokenizer, and the label encoder.

In [3]:
def pretokenizer(text):
    """Lowercase ``text`` and split it on whitespace into a list of tokens."""
    lowered = text.lower()
    pieces = lowered.split()
    return pieces

# Show the input type alias that x-tagger expects for datasets.
print(xtagger.DATASET_TYPE)

# Fixed seed for the train/test split. Without it every fresh run produces a
# different split, so the fitted vocabulary and all downstream metrics would
# not be reproducible under Restart & Run All.
SPLIT_SEED = 42

# Penn Treebank sample from NLTK, tagged with the universal tagset.
nltk_data = list(nltk.corpus.treebank.tagged_sents(tagset='universal'))
train_set, test_set = train_test_split(
    nltk_data,
    train_size=0.8,
    test_size=0.2,
    random_state=SPLIT_SEED,
)

# Both processors are built from the training split only.
label_encoder = LabelEncoder(train_set)

tokenizer = WhiteSpaceTokenizer()
tokenizer.fit(train_set, pretokenizer=pretokenizer)

typing.List[typing.List[typing.Tuple[str, str]]]
2023-10-12 14:27:04 - Vocab size: 10151


In [4]:
# Hyperparameters for a single-layer bidirectional LSTM tagger, kept together
# so they are easy to tweak.
tagger_config = {
    "rnn": "LSTM",
    "vocab_size": tokenizer.vocab_size,
    "embedding_dim": 100,
    "padding_idx": tokenizer.pad_token_id,
    "hidden_size": 128,
    "num_layers": 1,
    "bidirectional": True,
    # NOTE(review): the extra class presumably accounts for the padding tag - confirm.
    "n_classes": len(label_encoder.maps) + 1,
    "dropout": 0.1,
}

model = RNNTagger(**tagger_config)

In [6]:
# Adam with its default hyperparameters over all model parameters.
optimizer = torch.optim.Adam(model.parameters())

# Positions labelled with the padding tag id are excluded from the loss.
criterion = torch.nn.CrossEntropyLoss(ignore_index=label_encoder.pad_tag_id)

# Choose the best available accelerator instead of hard-coding "mps", which
# only exists on Apple-silicon builds of PyTorch. MPS is still preferred, so
# behaviour is unchanged on machines where it is available.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [7]:
# Arguments for model.fit, collected up front so the call itself stays short.
# Note that test_set is also passed as the dev set for per-epoch evaluation.
fit_options = {
    "train_set": train_set,
    "dev_set": test_set,
    "tokenizer": tokenizer,
    "label_encoder": label_encoder,
    "optimizer": optimizer,
    "criterion": criterion,
    "num_epochs": 15,
    "max_length": 128,
    "batch_size": 32,
    "device": device,
    "eval_metrics": [Accuracy, F1, ClasswiseF1],
    "use_amp": True,
}

# Train inside the model's autocast context, pinned to float32.
with model._autocast(dtype=torch.float32):
    results = model.fit(**fit_options)

2023-10-12 14:28:08 - Output path is set to ./out.
2023-10-12 14:28:08 - Evaluation results will be saved to ./out.
2023-10-12 14:28:08 - Checkpoints will be saved to ./out.




Training:   0%|          | 0/1470 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/25 [00:00<?, ?it/s]

2023-10-12 14:28:18 - Epoch 1/15 - Train Loss: 1.3004552326640304
2023-10-12 14:28:18 - Epoch 1/15 - Evaluation Loss: 0.6933601450920105


Evaluating:   0%|          | 0/25 [00:00<?, ?it/s]

2023-10-12 14:28:21 - Epoch 2/15 - Train Loss: 0.5522085102845211
2023-10-12 14:28:21 - Epoch 2/15 - Evaluation Loss: 0.4504885971546173


Evaluating:   0%|          | 0/25 [00:00<?, ?it/s]

2023-10-12 14:28:24 - Epoch 3/15 - Train Loss: 0.39233980525513085
2023-10-12 14:28:24 - Epoch 3/15 - Evaluation Loss: 0.3570176923274994


Evaluating:   0%|          | 0/25 [00:00<?, ?it/s]

2023-10-12 14:28:27 - Epoch 4/15 - Train Loss: 0.3037042769850517
2023-10-12 14:28:27 - Epoch 4/15 - Evaluation Loss: 0.29874316036701204


Evaluating:   0%|          | 0/25 [00:00<?, ?it/s]

2023-10-12 14:28:30 - Epoch 5/15 - Train Loss: 0.2448711503221064
2023-10-12 14:28:30 - Epoch 5/15 - Evaluation Loss: 0.2620729994773865


Evaluating:   0%|          | 0/25 [00:00<?, ?it/s]

2023-10-12 14:28:33 - Epoch 6/15 - Train Loss: 0.19935344357271584
2023-10-12 14:28:33 - Epoch 6/15 - Evaluation Loss: 0.23694030523300172


Evaluating:   0%|          | 0/25 [00:00<?, ?it/s]

2023-10-12 14:28:36 - Epoch 7/15 - Train Loss: 0.16230678193423212
2023-10-12 14:28:36 - Epoch 7/15 - Evaluation Loss: 0.2235153341293335


Evaluating:   0%|          | 0/25 [00:00<?, ?it/s]

2023-10-12 14:28:39 - Epoch 8/15 - Train Loss: 0.13465633152090772
2023-10-12 14:28:39 - Epoch 8/15 - Evaluation Loss: 0.22235272645950319


Evaluating:   0%|          | 0/25 [00:00<?, ?it/s]

2023-10-12 14:28:42 - Epoch 9/15 - Train Loss: 0.11109754177076477
2023-10-12 14:28:42 - Epoch 9/15 - Evaluation Loss: 0.20667631506919862


Evaluating:   0%|          | 0/25 [00:00<?, ?it/s]

2023-10-12 14:28:45 - Epoch 10/15 - Train Loss: 0.09154008826887121
2023-10-12 14:28:45 - Epoch 10/15 - Evaluation Loss: 0.2017851948738098


Evaluating:   0%|          | 0/25 [00:00<?, ?it/s]

2023-10-12 14:28:48 - Epoch 11/15 - Train Loss: 0.07439479277450212
2023-10-12 14:28:48 - Epoch 11/15 - Evaluation Loss: 0.20238860785961152


Evaluating:   0%|          | 0/25 [00:00<?, ?it/s]

2023-10-12 14:28:51 - Epoch 12/15 - Train Loss: 0.062384373337334514
2023-10-12 14:28:51 - Epoch 12/15 - Evaluation Loss: 0.19868223667144774


Evaluating:   0%|          | 0/25 [00:00<?, ?it/s]

2023-10-12 14:28:54 - Epoch 13/15 - Train Loss: 0.052897779575111915
2023-10-12 14:28:54 - Epoch 13/15 - Evaluation Loss: 0.19688178718090057


Evaluating:   0%|          | 0/25 [00:00<?, ?it/s]

2023-10-12 14:28:57 - Epoch 14/15 - Train Loss: 0.044348734925139924
2023-10-12 14:28:57 - Epoch 14/15 - Evaluation Loss: 0.2053576412796974


Evaluating:   0%|          | 0/25 [00:00<?, ?it/s]

2023-10-12 14:29:00 - Epoch 15/15 - Train Loss: 0.0367054227853612
2023-10-12 14:29:00 - Epoch 15/15 - Evaluation Loss: 0.20481922388076781


In [9]:
# Arguments for the final evaluation on the held-out split; batch size and
# max length match the values used for training.
evaluation_options = {
    "test_set": test_set,
    "tokenizer": tokenizer,
    "device": device,
    "label_encoder": label_encoder,
    "batch_size": 32,
    "max_length": 128,
    "criterion": criterion,
    "eval_metrics": [Accuracy, ClasswiseF1],
}

# Left as the last expression so the notebook displays the returned metrics.
model.evaluate(**evaluation_options)

Testing:   0%|          | 0/25 [00:00<?, ?it/s]

defaultdict(<function xtagger.models.rnn.rnn.RNNTagger.evaluate.<locals>.<lambda>()>,
            {'test': {'accuracy': 0.9417411621292158,
              'classwise_f1': {'.': 0.999780268072951,
               'NOUN': 0.7610921501706484,
               'DET': 0.9789311408016443,
               'ADP': 0.837696335078534,
               'X': 0.9964994165694282,
               'PRON': 0.9893617021276595,
               'PRT': 0.9328894340283862,
               'ADJ': 0.9223300970873786,
               'VERB': 0.9909747292418772,
               'NUM': 0.9811616954474097,
               'CONJ': 0.9160333642261352,
               'ADV': 0.9762553522771507},
              'loss': 0.19694156142381522}})

In [20]:
# Pick one training sentence at random and rebuild its raw text from the
# first element (the token) of each (token, tag) pair.
last_index = len(train_set) - 1
idx = random.randint(0, last_index)
sentence = " ".join(token_and_tag[0] for token_and_tag in train_set[idx])

# Tag the reconstructed sentence with the trained model.
predict_options = {
    "sentence": sentence,
    "tokenizer": tokenizer,
    "label_encoder": label_encoder,
    "device": device,
}
preds = model.predict(**predict_options)

Prediction:   0%|          | 0/1 [00:00<?, ?it/s]

In [21]:
# Display the (token, predicted tag) pairs returned by model.predict for the sampled sentence.
preds

[[('[START]', 'VERB'),
  ('the', 'DET'),
  ('patent', 'NOUN'),
  ('for', 'ADP'),
  ('interleukin-3', 'NOUN'),
  ('covers', 'VERB'),
  ('materials', 'NOUN'),
  ('and', 'CONJ'),
  ('methods', 'NOUN'),
  ('used', 'VERB'),
  ('*', 'X'),
  ('*', 'X'),
  ('to', 'PRT'),
  ('make', 'VERB'),
  ('the', 'DET'),
  ('human', 'ADJ'),
  ('blood', 'NOUN'),
  ('cell', 'NOUN'),
  ('growth', 'NOUN'),
  ('factor', 'NOUN'),
  ('via', 'ADP'),
  ('recombinant', 'ADJ'),
  ('dna', 'NOUN'),
  ('technology', 'NOUN'),
  ('.', '.'),
  ('[END]', 'NOUN')]]

In [28]:
# Print the last 30 lines of the evaluation results JSON stored under ./out.
!tail -n 30 "out/eval/results.json" 

    },
    "15": {
        "train": {
            "loss": 0.0367054227853612
        },
        "eval": {
            "accuracy": 0.9417411621292158,
            "f1": {
                "weighted": 0.9412477503544329,
                "micro": 0.9417411621292159,
                "macro": 0.940250473760767
            },
            "classwise_f1": {
                ".": 0.999780268072951,
                "NOUN": 0.7610921501706484,
                "DET": 0.9789311408016443,
                "ADP": 0.837696335078534,
                "X": 0.9964994165694282,
                "PRON": 0.9893617021276595,
                "PRT": 0.9328894340283862,
                "ADJ": 0.9223300970873786,
                "VERB": 0.9909747292418772,
                "NUM": 0.9811616954474097,
                "CONJ": 0.9160333642261352,
                "ADV": 0.9762553522771507
            },
            "loss": 0.20481922388076781
        }
    }
}