<a href="https://colab.research.google.com/github/viniciusrpb/116319_estruturasdedados/blob/main/pytorch_ner_conll.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gensim
import gensim.downloader
from datasets import load_dataset
import pandas as pd

In [7]:
word_vectors = gensim.downloader.load('word2vec-google-news-300')
#word_vectors = gensim.downloader.load('glove-twitter-25')



In [19]:
class LSTM_NER(nn.Module):

    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size, embedding_matrix):
        super(LSTM_NER, self).__init__()
        self.hidden_dim = hidden_dim
        self.word_embeddings = nn.Embedding.from_pretrained(torch.FloatTensor(embedding_matrix), freeze=True)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)

    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)
        lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1))
        tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))
        tag_scores = nn.functional.log_softmax(tag_space, dim=1)
        return tag_scores

In [20]:
train_dataset = load_dataset("conll2003", split="train")
valid_dataset = load_dataset("conll2003", split="validation")
test_dataset  = load_dataset("conll2003", split="test")

In [11]:
df_train = pd.DataFrame(train_dataset)

In [12]:
df_train

Unnamed: 0,id,tokens,pos_tags,chunk_tags,ner_tags
0,0,"[EU, rejects, German, call, to, boycott, Briti...","[22, 42, 16, 21, 35, 37, 16, 21, 7]","[11, 21, 11, 12, 21, 22, 11, 12, 0]","[3, 0, 7, 0, 0, 0, 7, 0, 0]"
1,1,"[Peter, Blackburn]","[22, 22]","[11, 12]","[1, 2]"
2,2,"[BRUSSELS, 1996-08-22]","[22, 11]","[11, 12]","[5, 0]"
3,3,"[The, European, Commission, said, on, Thursday...","[12, 22, 22, 38, 15, 22, 28, 38, 15, 16, 21, 3...","[11, 12, 12, 21, 13, 11, 11, 21, 13, 11, 12, 1...","[0, 3, 4, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, ..."
4,4,"[Germany, 's, representative, to, the, Europea...","[22, 27, 21, 35, 12, 22, 22, 27, 16, 21, 22, 2...","[11, 11, 12, 13, 11, 12, 12, 11, 12, 12, 12, 1...","[5, 0, 0, 0, 0, 3, 4, 0, 0, 0, 1, 2, 0, 0, 0, ..."
...,...,...,...,...,...
14036,14036,"[on, Friday, :]","[15, 22, 8]","[13, 11, 0]","[0, 0, 0]"
14037,14037,"[Division, two]","[21, 11]","[11, 12]","[0, 0]"
14038,14038,"[Plymouth, 2, Preston, 1]","[21, 11, 22, 11]","[11, 12, 12, 12]","[3, 0, 3, 0]"
14039,14039,"[Division, three]","[21, 11]","[11, 12]","[0, 0]"


In [13]:
def label2int():
    iob_labels = ["B", "I"]
    ner_labels = ["PER", "ORG", "LOC", "MISC"]
    all_labels = [(label1, label2) for label2 in ner_labels for label1 in iob_labels]
    all_labels = ["-".join([a, b]) for a, b in all_labels]
    dic = dict(zip(range(1, len(all_labels) + 1), all_labels))
    dic[0] = 'O'
    return dic

In [14]:
int2tag = label2int()

int2tag

{1: 'B-PER',
 2: 'I-PER',
 3: 'B-ORG',
 4: 'I-ORG',
 5: 'B-LOC',
 6: 'I-LOC',
 7: 'B-MISC',
 8: 'I-MISC',
 0: 'O'}

In [15]:
tag2int = {}
for key in int2tag:
    value = int2tag[key]
    tag2int[value] = key
print(tag2int)

num_labels = len(tag2int)

{'B-PER': 1, 'I-PER': 2, 'B-ORG': 3, 'I-ORG': 4, 'B-LOC': 5, 'I-LOC': 6, 'B-MISC': 7, 'I-MISC': 8, 'O': 0}


In [16]:
words = train_dataset['tokens']
ner_tags = train_dataset['ner_tags']

for i in range(0,len(words[0])):
    print(f'Word: {words[0][i]} | text label: {int2tag[ner_tags[0][i]]}')

Word: EU | text label: B-ORG
Word: rejects | text label: O
Word: German | text label: B-MISC
Word: call | text label: O
Word: to | text label: O
Word: boycott | text label: O
Word: British | text label: B-MISC
Word: lamb | text label: O
Word: . | text label: O


In [17]:
word_to_ix = {'OOV' : 0}
for i in range(200):
    for j in range(len(train_dataset['tokens'][i])):
        word = train_dataset['tokens'][i][j]
        if word not in word_to_ix:
            word_to_ix[word] = len(word_to_ix)

In [22]:
len(word_to_ix)

1279

In [29]:
pretrained_embeddings = []
for word, index in word_to_ix.items():
    if word in word_vectors:
        pretrained_embeddings.append(word_vectors[word])
    else:
        pretrained_embeddings.append(torch.zeros(300))

pretrained_embeddings = np.array(pretrained_embeddings)

embedding_matrix = torch.from_numpy(pretrained_embeddings)

In [30]:
# Hiperparâmetros
HIDDEN_DIM = 50
OUTPUT_DIM = len(tag2int)
EMBEDDING_DIM = len(word_vectors[0])
VOCAB_SIZE = len(word_to_ix)

In [31]:
model = LSTM_NER(EMBEDDING_DIM, HIDDEN_DIM, VOCAB_SIZE, OUTPUT_DIM, embedding_matrix)

In [32]:
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

In [33]:
for epoch in range(15):
    total_loss = 0
    for i in range(200):#len(train_dataset)):

        sentence = train_dataset['tokens'][i]
        tags = train_dataset['ner_tags'][i]

        model.zero_grad()

        sentence_in = torch.tensor([word_to_ix[word] for word in sentence], dtype=torch.long)
        targets = torch.tensor([tag for tag in tags], dtype=torch.long)

        tag_scores = model(sentence_in)

        loss = loss_function(tag_scores, targets)
        loss.backward()

        optimizer.step()

        total_loss += loss.item()
    print(f'Epoch {epoch} ====== Loss: {total_loss:.4f}')



Exception ignored in: <function _xla_gc_callback at 0x7a0693ef39a0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/lib/__init__.py", line 98, in _xla_gc_callback
    def _xla_gc_callback(*args):
KeyboardInterrupt: 




In [34]:
torch.no_grad()

test_sentence = "The United States started to struggle with John Tiles"

test_input = []
for word in test_sentence.split():
    if word in word_to_ix:
        word_int = word_to_ix[word]
    else:
        word_int = word_to_ix['OOV']

    test_input.append(word_int)

In [35]:
test_input

[14, 868, 869, 0, 5, 0, 22, 223, 0]

In [36]:
inputs = torch.tensor(test_input, dtype=torch.long)
tag_scores = model(inputs)
_, predicted_tags = torch.max(tag_scores, 1)
predicted_tags = [list(int2tag.keys())[idx] for idx in predicted_tags.numpy()]
print(predicted_tags)

[1, 6, 7, 1, 1, 1, 1, 2, 1]
