In [21]:
import os
# Data directory of the project: "<parent of the current working directory>/data".
DATA_DIR = os.path.join(os.path.abspath(os.pardir), "data")

# Directory of the fine-tuned checkpoint that the following cells load from.
model_checkpoint = os.path.join(DATA_DIR, "models", "checkpoints", "checkpoint-2500/")


In [22]:
from transformers import AutoTokenizer
    
# Load the tokenizer from the checkpoint directory defined in `model_checkpoint`.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [23]:
# Display the tokenizer's repr (class, source path, vocab size, max length, special tokens).
tokenizer

PreTrainedTokenizerFast(name_or_path='/Users/max/Documents/projects/ai/sequence-labeling/data/models/checkpoints/checkpoint-2500/', vocab_size=30522, model_max_len=512, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

In [20]:
import transformers
# Sanity check: the rest of the notebook expects a fast tokenizer. The message
# names the actual class so a failure is diagnosable without re-inspecting state.
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast), (
    f"expected a PreTrainedTokenizerFast, got {type(tokenizer).__name__}"
)

In [5]:
# Tokenise a short sample text; the returned encoding carries `input_ids`
# and `attention_mask` (shown as the cell output).
tokenized_input = tokenizer("Hello, this is one sentence! Good night!")
tokenized_input

{'input_ids': [101, 7592, 1010, 2023, 2003, 2028, 6251, 999, 2204, 2305, 999, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [6]:

# Map the sample's input ids back to token strings to inspect what the
# tokenizer produced (e.g. the added [CLS]/[SEP] markers).
tokens = tokenizer.convert_ids_to_tokens(tokenized_input["input_ids"])
print(tokens)

['[CLS]', 'hello', ',', 'this', 'is', 'one', 'sentence', '!', 'good', 'night', '!', '[SEP]']


In [7]:
from transformers import AutoModelForTokenClassification

# Load the token-classification model from the same checkpoint directory as the tokenizer.
model = AutoModelForTokenClassification.from_pretrained(model_checkpoint)

In [8]:
# Display the model's module tree to inspect its architecture.
model

DistilBertForTokenClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
          

In [9]:
# Tag inventory in B-/I- notation. The position of a tag in this list is the
# class index used to decode predictions (label_list[index]).
# NOTE(review): presumably this order must match the label ids used when the
# checkpoint was trained - confirm against the training setup.
label_list = [
    "O",
    "B-PER", "I-PER",
    "B-ORG", "I-ORG",
    "B-LOC", "I-LOC",
    "B-MISC", "I-MISC",
]

In [10]:
# Join a pre-split word list into one string with single spaces to get a
# readable example sentence (punctuation and clitics stay space-separated).
' '.join(['Germany', "'s", 'representative', 'to', 'the', 'European', 'Union', "'s", 'veterinary', 'committee', 'Werner', 'Zwingmann', 'said', 'on', 'Wednesday', 'consumers', 'should', 'buy', 'sheepmeat', 'from', 'countries', 'other', 'than', 'Britain', 'until', 'the', 'scientific', 'advice', 'was', 'clearer', '.'])

"Germany 's representative to the European Union 's veterinary committee Werner Zwingmann said on Wednesday consumers should buy sheepmeat from countries other than Britain until the scientific advice was clearer ."

In [11]:
import torch

# Sentence to tag.
s = "Germany's representative to the European Union's veterinary committee Werner Zwingmann said on Wednesday consumers should buy sheepmeat from countries other than Britain until the scientific advice was clearer."

# Tokenise once and reuse the result: the full encoding is printed for
# inspection, and `ti` keeps the plain list of input ids used by later cells.
encoding = tokenizer(s)
print(encoding)
ti = encoding["input_ids"]
print(ti)

# Wrap the id list in another list to add a leading batch dimension of size 1.
input_ids = torch.tensor([ti])

# Inference only: disable autograd so no graph is recorded for the forward pass.
with torch.no_grad():
    outputs = model(input_ids)
outputs

{'input_ids': [101, 2762, 1005, 1055, 4387, 2000, 1996, 2647, 2586, 1005, 1055, 15651, 2837, 14121, 1062, 9328, 5804, 2056, 2006, 9317, 10390, 2323, 4965, 8351, 4168, 4017, 2013, 3032, 2060, 2084, 3725, 2127, 1996, 4045, 6040, 2001, 24509, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
[101, 2762, 1005, 1055, 4387, 2000, 1996, 2647, 2586, 1005, 1055, 15651, 2837, 14121, 1062, 9328, 5804, 2056, 2006, 9317, 10390, 2323, 4965, 8351, 4168, 4017, 2013, 3032, 2060, 2084, 3725, 2127, 1996, 4045, 6040, 2001, 24509, 1012, 102]


TokenClassifierOutput(loss=None, logits=tensor([[[ 8.9136, -0.9371, -1.5237, -1.0358, -1.6040, -1.4502, -1.6999,
          -0.7906, -1.5180],
         [-1.4842, -0.5072, -1.3988, -0.0443, -1.8029,  7.9319, -0.2244,
           0.3503, -2.2295],
         [ 8.8620, -1.2049, -1.7902, -0.8684, -1.3035, -1.5098, -2.0356,
          -0.6383, -1.3549],
         [ 8.8042, -1.1878, -1.8321, -0.7357, -1.4156, -1.3873, -2.1791,
          -0.4464, -1.4726],
         [ 8.6329, -1.0893, -1.9809, -0.3864, -1.3250, -1.6509, -2.4413,
          -0.2689, -1.5393],
         [ 8.8453, -1.1129, -1.5121, -0.8787, -1.3618, -1.5475, -1.9312,
          -0.8868, -1.6168],
         [ 8.7882, -1.1661, -1.7364, -0.3707, -1.4714, -0.9157, -2.1007,
          -0.6204, -1.8341],
         [-0.3362, -1.2733, -2.2390,  6.4149,  1.3414,  0.2284, -2.5698,
           0.5114, -1.4814],
         [ 1.1663, -1.6495, -1.1086,  0.6043,  6.1599, -2.7057,  0.0407,
          -2.1928,  0.7981],
         [ 8.7525, -1.4281, -1.5460, -1.02

In [12]:
import numpy as np
# Inference only: run the forward pass without recording an autograd graph.
with torch.no_grad():
    preds = model(input_ids)

# preds.logits[0] holds one row of class scores per token of the single input
# sequence; argmax over axis 1 picks the best-scoring class index per token.
labels = np.argmax(preds.logits[0].detach().numpy(), axis=1)
print(f'{len(labels)}\n{labels}')

39
[0 5 0 0 0 0 0 3 4 0 0 0 0 1 2 2 2 0 0 0 0 0 0 0 0 0 0 0 0 0 5 0 0 0 0 0 0
 0 7]


In [17]:
# Turn the ids of the tagged sentence back into token strings, then show how
# many there are followed by the tokens themselves (one item per line).
tokens = tokenizer.convert_ids_to_tokens(ti)
print(len(tokens), tokens, sep="\n")

39
['[CLS]', 'germany', "'", 's', 'representative', 'to', 'the', 'european', 'union', "'", 's', 'veterinary', 'committee', 'werner', 'z', '##wing', '##mann', 'said', 'on', 'wednesday', 'consumers', 'should', 'buy', 'sheep', '##me', '##at', 'from', 'countries', 'other', 'than', 'britain', 'until', 'the', 'scientific', 'advice', 'was', 'clearer', '.', '[SEP]']


In [18]:
# Guard against stale notebook state: `tokens` is assigned in more than one
# cell, and zip() would silently stop at the shorter input on a mismatch.
if len(tokens) != len(labels):
    raise ValueError(
        f"token/label length mismatch: {len(tokens)} tokens vs {len(labels)} labels"
    )

# Print every token next to its predicted tag.
# NOTE(review): special tokens are printed too (the recorded output tags
# [SEP] as B-MISC); presumably those positions carry no meaningful
# prediction - confirm, and consider filtering them out.
for token, label_id in zip(tokens, labels):
    print(f'{token}\t\t{label_list[label_id]}')

[CLS]		O
germany		B-LOC
'		O
s		O
representative		O
to		O
the		O
european		B-ORG
union		I-ORG
'		O
s		O
veterinary		O
committee		O
werner		B-PER
z		I-PER
##wing		I-PER
##mann		I-PER
said		O
on		O
wednesday		O
consumers		O
should		O
buy		O
sheep		O
##me		O
##at		O
from		O
countries		O
other		O
than		O
britain		B-LOC
until		O
the		O
scientific		O
advice		O
was		O
clearer		O
.		O
[SEP]		B-MISC
