<a href="https://colab.research.google.com/github/poppingary/name-entity-recognition/blob/main/ExtendedCode/flair.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Import packages and load the model

In [None]:
!pip install --upgrade git+https://github.com/flairNLP/flair.git

Collecting git+https://github.com/flairNLP/flair.git
  Cloning https://github.com/flairNLP/flair.git to /tmp/pip-req-build-i0315i3j
  Running command git clone -q https://github.com/flairNLP/flair.git /tmp/pip-req-build-i0315i3j
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Collecting gdown==3.12.2
  Downloading gdown-3.12.2.tar.gz (8.2 kB)
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Collecting transformers>=4.0.0
  Downloading transformers-4.13.0-py3-none-any.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 8.8 MB/s 
[?25hCollecting huggingface-hub
  Downloading huggingface_hub-0.2.1-py3-none-any.whl (61 kB)
[K     |████████████████████████████████| 61 kB 642 kB/s 
Collecting pptree
  Downloading pptree-3.1.tar.gz (3.0 kB)
Collecting bpemb

In [None]:
from tqdm import tqdm

In [None]:
from flair.data import Sentence
from flair.models import SequenceTagger
# Load flair's pre-trained "ner-pooled" sequence tagger once at module level;
# predict_sentences() uses this shared instance.
tagger = SequenceTagger.load("ner-pooled")

2021-12-12 20:43:14,129 https://nlp.informatik.hu-berlin.de/resources/models/ner-pooled/en-ner-conll03-pooled-v0.5.pt not found in cache, downloading to /tmp/tmp054igmst


100%|██████████| 1125470069/1125470069 [01:04<00:00, 17487334.73B/s]

2021-12-12 20:44:19,006 copying /tmp/tmp054igmst to cache at /root/.flair/models/en-ner-conll03-pooled-v0.5.pt





2021-12-12 20:44:22,955 removing temp file /tmp/tmp054igmst
2021-12-12 20:44:23,116 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt


## Load templates and names

In [None]:
from google.colab import drive
# Mount Google Drive so the template and name files under /content/drive can be opened.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Root folder of the sample data on the mounted Google Drive. Change this one
# constant to run the notebook against a copy of the data stored elsewhere.
SAMPLE_DIR = '/content/drive/MyDrive/IntroToMachineLearning/SentenceSample'

# Sentence templates; the name slot in each template is marked with '*'.
easy_sentence_templates_file = SAMPLE_DIR + '/SentenceTemplate/easy_sentence_templates.txt'
hard_sentence_templates_file = SAMPLE_DIR + '/SentenceTemplate/hard_sentence_templates.txt'

# Name lists, read one name per line.
common_names_file = SAMPLE_DIR + '/NameSample/common_names.txt'
rare_names_file = SAMPLE_DIR + '/NameSample/rare_names.txt'

In [None]:
def read_template(file):
  """Read sentence templates from a text file.

  The whole file is split on '.', so every period ends one template. Each
  template is whitespace-stripped and gets its final '.' restored.

  Args:
    file: Path of a UTF-8 text file (a leading BOM is tolerated).

  Returns:
    List of template strings in file order. Fragments that contain only
    whitespace (e.g. the newline after the last period) are dropped.
  """
  with open(file, 'r', encoding='utf-8-sig') as f:
    text = f.read()
  # Strip BEFORE testing for emptiness: the text after the final period is
  # usually just a newline and must not turn into a bogus '.' template.
  # NOTE(review): a period inside a template (e.g. an abbreviation) also
  # splits it here — confirm the template files contain none.
  fragments = (fragment.strip() for fragment in text.split('.'))
  return [fragment + '.' for fragment in fragments if fragment]

In [None]:
def read_name(file):
  """Read names from a text file, one name per line.

  Args:
    file: Path of a UTF-8 text file (a leading BOM is tolerated).

  Returns:
    List of whitespace-stripped names in file order. Blank lines are
    skipped so that no empty name is substituted into a template.
  """
  with open(file, 'r', encoding='utf-8-sig') as f:
    stripped_lines = (line.strip() for line in f)
    # Materialise inside the `with` block while the file is still open.
    return [name for name in stripped_lines if name]

In [None]:
# Sentence templates, grouped by difficulty.
easy_sentence_templates, hard_sentence_templates = (
  read_template(easy_sentence_templates_file),
  read_template(hard_sentence_templates_file),
)
# Names to substitute into the templates, grouped by frequency.
common_names, rare_names = (
  read_name(common_names_file),
  read_name(rare_names_file),
)

## Predict

In [None]:
def is_correct(entities, name):
  """Tell whether the inserted name was recognised as a person.

  Args:
    entities: Iterable of predicted spans exposing `.text` and `.tag`.
    name: The exact name string that was substituted into the sentence.

  Returns:
    True if the first span whose text equals `name` is tagged 'PER';
    False if that span carries another tag or if no span matches at all.
  """
  for entity in entities:
    if entity.text == name:
      # Only the first span matching the name is judged.
      return entity.tag == 'PER'
  # The name was not extracted as an entity: explicit False instead of an
  # implicit None.
  return False

def predict_sentences(templates, names):
  """Tag every template/name combination and report how often the name is PER.

  Each '*' in a template is replaced by a name, the resulting sentence is
  tagged with the module-level `tagger`, and the outcome is judged with
  `is_correct`. A short summary is printed.

  Args:
    templates: Iterable of sentence templates containing '*'.
    names: Iterable of names to substitute.

  Returns:
    Tuple `(correct, wrong)`; each is a list of `(sentence, name)` pairs
    where `sentence` is the tagged Sentence object.
  """
  correct = []
  wrong = []

  for template in tqdm(templates):
    for name in names:
      sentence = Sentence(template.replace('*', name))
      tagger.predict(sentence)
      entities = sentence.get_spans('ner')
      outcome = correct if is_correct(entities, name) else wrong
      outcome.append((sentence, name))

  total = len(correct) + len(wrong)
  # Nothing is evaluated when `templates` or `names` is empty; report nan
  # rather than raising ZeroDivisionError.
  precision = len(correct) / total if total else float('nan')

  print('')
  print('correct: ' + str(len(correct)))
  print('wrong: ' + str(len(wrong)))
  print('precision: ' + str(precision))

  return correct, wrong

In [None]:
# Easy templates x common names. Rebinds the notebook-level `correct`/`wrong`.
correct, wrong = predict_sentences(easy_sentence_templates, common_names)

100%|██████████| 15/15 [04:22<00:00, 17.52s/it]


correct: 297
wrong: 3
precision: 0.99





In [None]:
# Easy templates x rare names. Rebinds the notebook-level `correct`/`wrong`.
correct, wrong = predict_sentences(easy_sentence_templates, rare_names)

100%|██████████| 15/15 [04:15<00:00, 17.06s/it]


correct: 293
wrong: 7
precision: 0.9766666666666667





In [None]:
# Hard templates x common names. Rebinds the notebook-level `correct`/`wrong`.
correct, wrong = predict_sentences(hard_sentence_templates, common_names)

100%|██████████| 15/15 [07:11<00:00, 28.80s/it]


correct: 291
wrong: 9
precision: 0.97





In [None]:
# Hard templates x rare names. Rebinds the notebook-level `correct`/`wrong`.
correct, wrong = predict_sentences(hard_sentence_templates, rare_names)

100%|██████████| 15/15 [07:01<00:00, 28.12s/it]


correct: 252
wrong: 48
precision: 0.84





In [None]:
def get_per_prob(sent, name):
  """Return the 'S-PER' probability of the first token whose text equals `name`.

  Args:
    sent: A predicted sentence whose tokens expose
      `get_tags_proba_dist('ner')`.
    name: Token text to look for (exact match against a single token).

  Returns:
    The score of the 'S-PER' label for the first matching token, or None
    when no token's text equals `name`.

  Raises:
    ValueError: If the matching token's distribution has no 'S-PER' entry.
  """
  for token in sent.tokens:
    if token.text != name:
      continue
    # Find the label by value, not by a fixed position in the distribution,
    # so the lookup does not depend on the tag dictionary's ordering.
    for label in token.get_tags_proba_dist('ner'):
      if label.value == 'S-PER':
        return label.score
    raise ValueError("no 'S-PER' entry in the 'ner' tag distribution of token " + repr(name))
  return None

In [None]:
def before_memory(temp, names, correct, wrong):
  """Tag one template with every name using a freshly loaded tagger.

  Args:
    temp: Sentence template; each '*' is replaced by a name.
    names: Iterable of names to substitute.
    correct: List that receives `(sentence, name, S-PER score)` for names
      judged correct by `is_correct`. Mutated in place.
    wrong: List that receives the same triples for all other names.
      Mutated in place.

  Returns:
    The same `correct` and `wrong` list objects that were passed in.
  """
  # A new tagger is loaded on every call instead of using the module-level one.
  # NOTE(review): presumably this resets the pooled embeddings' memory between
  # templates — confirm before hoisting the load out of this function.
  fresh_tagger = SequenceTagger.load("ner-pooled")
  for person in names:
    tagged = Sentence(temp.replace('*', person))
    # all_tag_prob=True is requested so get_per_prob can read the distribution.
    fresh_tagger.predict(tagged, all_tag_prob=True)
    spans = tagged.get_spans('ner')
    per_score = get_per_prob(tagged, person)
    bucket = correct if is_correct(spans, person) else wrong
    bucket.append((tagged, person, per_score))
  return correct, wrong

In [None]:
def evaluate(correct, wrong):
  """Print counts, precision and mean S-PER probability of the two buckets.

  Args:
    correct: List of `(sentence, name, score)` triples judged correct.
    wrong: List of `(sentence, name, score)` triples judged wrong.

  Prints `nan` for any ratio or mean that has nothing to average, instead
  of raising ZeroDivisionError (e.g. when `wrong` is empty).
  """
  def mean_prob(outs):
    # A score of None (get_per_prob found no token equal to the name) is
    # left out of the mean rather than breaking the sum.
    scores = [score for _, _, score in outs if score is not None]
    return sum(scores) / len(scores) if scores else float('nan')

  total = len(correct) + len(wrong)
  precision = len(correct) / total if total else float('nan')

  print("correct: " + str(len(correct)))
  print("wrong: " + str(len(wrong)))
  print("precision: " + str(precision))
  print("---------------------")
  print("correct prob: " + str(mean_prob(correct)))
  print("wrong prob: " + str(mean_prob(wrong)))
  print("mean prob: " + str(mean_prob(correct+wrong)))

In [None]:
# Probability-tracking run: easy templates x common names.
print('Easy sentence with common name\n')
correct = []
wrong = []
# before_memory appends to the lists it is given and returns them, so the
# results accumulate across templates.
for template in easy_sentence_templates:
  correct, wrong = before_memory(template, common_names, correct, wrong)

evaluate(correct, wrong)

Easy sentence with common name

2021-12-12 21:14:12,534 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:14:30,335 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:14:54,779 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:15:35,175 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:15:55,935 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:16:18,147 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:16:40,702 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:16:59,292 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:17:27,017 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:17:46,897 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:18:05,075 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 

In [None]:
# Probability-tracking run: easy templates x rare names.
print('Easy sentence with rare name\n')
correct = []
wrong = []
# before_memory appends to the lists it is given and returns them, so the
# results accumulate across templates.
for template in easy_sentence_templates:
  correct, wrong = before_memory(template, rare_names, correct, wrong)

evaluate(correct, wrong)

Easy sentence with rare name

2021-12-12 21:19:54,097 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:20:11,126 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:20:35,437 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:21:15,797 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:21:37,146 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:22:00,031 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:22:22,620 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:22:41,719 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:23:08,922 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:23:29,410 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:23:46,751 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21

In [None]:
# Probability-tracking run: hard templates x common names.
print('Hard sentence with common name\n')
correct = []
wrong = []
# before_memory appends to the lists it is given and returns them, so the
# results accumulate across templates.
for template in hard_sentence_templates:
  correct, wrong = before_memory(template, common_names, correct, wrong)

evaluate(correct, wrong)

Hard sentence with common name

2021-12-12 21:25:36,461 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:25:57,159 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:26:21,049 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:26:42,441 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:27:01,186 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:27:19,821 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:27:43,120 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:28:30,957 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:29:05,251 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:29:33,726 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:30:14,781 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 

In [None]:
# Probability-tracking run: hard templates x rare names.
print('Hard sentence with rare name\n')
correct = []
wrong = []
# before_memory appends to the lists it is given and returns them, so the
# results accumulate across templates.
for template in hard_sentence_templates:
  correct, wrong = before_memory(template, rare_names, correct, wrong)

evaluate(correct, wrong)

Hard sentence with rare name

2021-12-12 21:34:19,872 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:34:40,882 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:35:04,538 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:35:25,534 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:35:43,873 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:36:02,127 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:36:25,408 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:37:13,079 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:37:47,186 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:38:15,621 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21:38:56,128 loading file /root/.flair/models/en-ner-conll03-pooled-v0.5.pt
2021-12-12 21