In [5]:
import nltk
from sklearn.model_selection import train_test_split
import torch

from xtagger import xtagger_dataset_to_df
from xtagger import df_to_torchtext_data
from transformers import AutoTokenizer

# One seed for both the train/test split and torch's global RNG, so the split,
# model initialisation and dropout are repeatable across Restart & Run All.
SEED = 2112
torch.manual_seed(SEED)

# NLTK Penn Treebank sample with universal POS tags.
# NOTE(review): presumably requires the NLTK "treebank" corpus and
# "universal_tagset" mapping to be downloaded beforehand — confirm in setup docs.
nltk_data = list(nltk.corpus.treebank.tagged_sents(tagset="universal"))
train_set, test_set = train_test_split(
    nltk_data, train_size=0.8, test_size=0.2, random_state=SEED
)

df_train = xtagger_dataset_to_df(train_set)
df_test = xtagger_dataset_to_df(test_set)

# Use the GPU when present, otherwise fall back to CPU instead of crashing
# on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

# NOTE(review): only df_train and df_test are passed in, yet three iterators
# come back. The per-epoch 'eval' metrics during fit and the final test metrics
# are nearly identical, which suggests valid_iterator is built from df_test.
# If so, the reported metrics are not truly held-out. Carve a separate
# validation split out of train_set. Confirm against xtagger's
# df_to_torchtext_data.
train_iterator, valid_iterator, test_iterator, TEXT, TAGS = df_to_torchtext_data(
    df_train,
    df_test,
    device,
    transformers=True,
    tokenizer=tokenizer,
    batch_size=32,
)

Number of training examples: 3131
Number of testing examples: 783
Unique tokens in TEXT vocabulary: 5444
Unique tokens in TAGS vocabulary: 13


In [7]:
from xtagger import BERTForTagging

model = BERTForTagging(
    # Size the output layer from the tag vocabulary rather than hard-coding 13
    # (the count printed by the data cell), so it tracks the data automatically.
    # NOTE(review): assumes TAGS is a torchtext Field exposing `.vocab` — confirm.
    output_dim=len(TAGS.vocab),
    TEXT=TEXT,
    TAGS=TAGS,
    dropout=0.2,
    device=device,
    # Keep the CUDA flag consistent with the selected device instead of forcing it on.
    cuda=device.type == "cuda",
)

The model has 9997 trainable parameters
The model has 108310272 non-trainable parameters


In [9]:
# Train for 10 epochs. After each epoch the model reports accuracy and
# averaged F1 (weighted / micro / macro) for both the 'train' and 'eval' data.
# NOTE(review): valid_iterator may be the test data (see the data cell); if so,
# these 'eval' numbers are not an independent validation signal.
model.fit(
    train_iterator,
    valid_iterator,
    epochs=10,
    eval_metrics=["acc", "avg_f1"],
)

  0%|          | 0/980 [00:00<?, ?it/s]

Evaluating...


  0%|          | 0/123 [00:00<?, ?it/s]

{'eval': {'acc': 82.09734069242349, 'avg_f1': {'weighted': 81.34879152884871, 'micro': 82.09734069242349, 'macro': 71.51932441708736}}, 'train': {'acc': 82.69140267010131, 'avg_f1': {'weighted': 82.10337663320242, 'micro': 82.69140267010131, 'macro': 72.33960207955383}}, 'eval_loss': 0.7153935766220093, 'train_loss': 0.7072247771584258}
Evaluating...


  0%|          | 0/123 [00:00<?, ?it/s]

{'eval': {'acc': 85.24335173105871, 'avg_f1': {'weighted': 84.90494929342174, 'micro': 85.24335173105871, 'macro': 76.20841853923078}}, 'train': {'acc': 85.79124662521983, 'avg_f1': {'weighted': 85.53556922675834, 'micro': 85.79124662521983, 'macro': 76.77499661798352}}, 'eval_loss': 0.5569747805595398, 'train_loss': 0.5482063551946562}
Evaluating...


  0%|          | 0/123 [00:00<?, ?it/s]

{'eval': {'acc': 86.58805820371299, 'avg_f1': {'weighted': 86.26767159407876, 'micro': 86.588058203713, 'macro': 77.87141364702664}}, 'train': {'acc': 87.22413494166894, 'avg_f1': {'weighted': 86.95909147129622, 'micro': 87.22413494166894, 'macro': 78.44841489628045}}, 'eval_loss': 0.4801012313365936, 'train_loss': 0.47021751531532835}
Evaluating...


  0%|          | 0/123 [00:00<?, ?it/s]

{'eval': {'acc': 87.6668339187155, 'avg_f1': {'weighted': 87.46329466096165, 'micro': 87.6668339187155, 'macro': 79.23183083936223}}, 'train': {'acc': 88.1393505560647, 'avg_f1': {'weighted': 87.97434511893405, 'micro': 88.1393505560647, 'macro': 79.63319339183685}}, 'eval_loss': 0.4359757947921753, 'train_loss': 0.42670554196348}
Evaluating...


  0%|          | 0/123 [00:00<?, ?it/s]

{'eval': {'acc': 88.3492222779729, 'avg_f1': {'weighted': 88.21786241760054, 'micro': 88.3492222779729, 'macro': 80.12611430328977}}, 'train': {'acc': 88.9009981918609, 'avg_f1': {'weighted': 88.8130114590012, 'micro': 88.90099819186091, 'macro': 80.62307653769545}}, 'eval_loss': 0.4064346432685852, 'train_loss': 0.39726588920671113}
Evaluating...


  0%|          | 0/123 [00:00<?, ?it/s]

{'eval': {'acc': 88.66532865027597, 'avg_f1': {'weighted': 88.49774540363401, 'micro': 88.66532865027597, 'macro': 80.38526747121644}}, 'train': {'acc': 89.31464097292745, 'avg_f1': {'weighted': 89.20331263693815, 'micro': 89.31464097292745, 'macro': 81.05375049783525}}, 'eval_loss': 0.38456093668937685, 'train_loss': 0.37361516849118837}
Evaluating...


  0%|          | 0/123 [00:00<?, ?it/s]

{'eval': {'acc': 88.96136477671851, 'avg_f1': {'weighted': 88.84846805573173, 'micro': 88.96136477671851, 'macro': 80.7721702756681}}, 'train': {'acc': 89.5970079013202, 'avg_f1': {'weighted': 89.53644383963858, 'micro': 89.5970079013202, 'macro': 81.46808716754805}}, 'eval_loss': 0.36782260656356813, 'train_loss': 0.35705909984452383}
Evaluating...


  0%|          | 0/123 [00:00<?, ?it/s]

{'eval': {'acc': 89.18715504264927, 'avg_f1': {'weighted': 89.07293548483352, 'micro': 89.18715504264927, 'macro': 81.06129476005373}}, 'train': {'acc': 89.8124984519357, 'avg_f1': {'weighted': 89.7536116943709, 'micro': 89.8124984519357, 'macro': 81.79154058685216}}, 'eval_loss': 0.35759310722351073, 'train_loss': 0.3455424868330664}
Evaluating...


  0%|          | 0/123 [00:00<?, ?it/s]

{'eval': {'acc': 89.33266432513798, 'avg_f1': {'weighted': 89.17648653578746, 'micro': 89.33266432513798, 'macro': 81.155844533212}}, 'train': {'acc': 90.00693532806578, 'avg_f1': {'weighted': 89.91018870940354, 'micro': 90.00693532806578, 'macro': 81.92647391325453}}, 'eval_loss': 0.34750911593437195, 'train_loss': 0.33492926401751383}
Evaluating...


  0%|          | 0/123 [00:00<?, ?it/s]

{'eval': {'acc': 89.50827897641747, 'avg_f1': {'weighted': 89.34650556168636, 'micro': 89.50827897641747, 'macro': 81.3793538412704}}, 'train': {'acc': 90.22118742724098, 'avg_f1': {'weighted': 90.11575875347711, 'micro': 90.22118742724096, 'macro': 82.22700783168506}}, 'eval_loss': 0.3388308537006378, 'train_loss': 0.326385920905337}


In [10]:
# Final evaluation on the test iterator, using the same metrics as training.
model.evaluate(
    test_iterator,
    eval_metrics=["acc", "avg_f1"],
)

  0%|          | 0/25 [00:00<?, ?it/s]

{'acc': 89.50775252145115,
 'avg_f1': {'macro': 81.37881674895928,
  'micro': 89.50775252145115,
  'weighted': 89.345972741001}}

In [11]:
# Tag an ad-hoc sentence supplied as a pre-split list of lowercase tokens.
# The tokenizer is passed through to predict, presumably so it can map the
# words onto the model's subword vocabulary.
s = "there are no two words in the english language more harmful than good job".split()
model.predict(s, tokenizer)

([('there', 'DET'),
  ('are', 'VERB'),
  ('no', 'ADV'),
  ('two', 'NUM'),
  ('words', 'NOUN'),
  ('in', 'ADP'),
  ('the', 'DET'),
  ('english', 'NOUN'),
  ('language', 'NOUN'),
  ('more', 'ADJ'),
  ('harmful', 'ADJ'),
  ('than', 'ADP'),
  ('good', 'ADJ'),
  ('job', 'NOUN')],
 ['language'])