In [1]:
# Mount Google Drive at /content/drive so later cells can read files stored there.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# Absolute path to the saved checkpoint on the Drive mounted at /content/drive.
# Using the absolute form means loading does not depend on the notebook's
# current working directory.
model_dir = '/content/drive/MyDrive/ml_class_group_project/Vijay/bert-fine-tuned-ner/checkpoint-5268'

In [3]:
from transformers import AutoTokenizer

# Load the tokenizer from the files stored under model_dir.
tokenizer = AutoTokenizer.from_pretrained(model_dir)


In [4]:
from transformers import AutoModelForTokenClassification

# Load the token-classification model stored under model_dir.
model = AutoModelForTokenClassification.from_pretrained(model_dir)

In [5]:
import torch

def predict(text: str):
  """Label every token of `text` with the class the model scores highest.

  Uses the notebook-level `tokenizer` and `model`.

  Args:
    text: The raw string to tokenize and classify.

  Returns:
    A 3-tuple `(tokens, input_ids, labels)`:
      tokens    -- the result of `tokens()` on the tokenizer output,
      input_ids -- the `input_ids` tensor from the tokenizer output,
      labels    -- one `model.config.id2label` entry per position of the
                   first sequence in the batch.

  NOTE(review): no truncation option is passed to the tokenizer; confirm how
  inputs longer than the model's maximum sequence length should be handled.
  """
  encoded = tokenizer(text, return_tensors="pt")

  # Pure inference, so gradient tracking is switched off for the forward pass.
  with torch.no_grad():
    output = model(**encoded)
    # Highest-scoring class index per position, for the first sequence in the batch.
    best_class_ids = output.logits.argmax(dim=2)[0].tolist()

  labels = [model.config.id2label[class_id] for class_id in best_class_ids]

  return (encoded.tokens(), encoded['input_ids'], labels)

In [11]:
def get_entities(text, entity):
  """Extract the parts of `text` that the model labels with a given entity type.

  Runs `predict(text)` and collects every maximal run of consecutive tokens
  whose predicted label contains `entity`; each run is decoded back to a
  string with the notebook-level `tokenizer`.

  Args:
    text: The raw string to analyse.
    entity: Entity code to look for, e.g. 'ORG', 'PER', 'LOC' or 'MISC'.
      Matching is by substring against the predicted label.

  Returns:
    A list of decoded entity strings in order of appearance. When nothing
    matches, a single-element list holding the message
    'No <entity name> inside text', where the name is the long form of the
    four known codes, or the code itself for any other value.

  NOTE(review): label prefixes are not inspected. If the model emits
  BIO-style labels (e.g. B-ORG / I-ORG -- not visible here), two entities of
  the same type with no other token between them are returned as one string;
  confirm whether that matters for the intended use.
  """
  _, input_ids, labels = predict(text)
  # predict() returns a batched tensor of ids; keep the first row as a plain list.
  token_ids = input_ids.tolist()[0]

  entities = []
  current_span = []

  for token_id, label in zip(token_ids, labels):
    if entity in label:
      current_span.append(token_id)
    elif current_span:
      # A non-matching label closes the run collected so far.
      entities.append(tokenizer.decode(current_span))
      current_span = []

  # Flush a run that extends to the last position.
  if current_span:
    entities.append(tokenizer.decode(current_span))

  if not entities:
    full_names = {
      'ORG': 'organization',
      'PER': 'person',
      'LOC': 'location',
      'MISC': 'miscellaneous entity',
    }
    # Fall back to the raw code so an unknown entity type yields a message
    # instead of a KeyError.
    entities.append(f'No {full_names.get(entity, entity)} inside text')

  return entities




In [21]:

# Smoke tests: each case pairs an input sentence with the entity code to extract.

# Single-word entity tests
single_word_cases = [
  ('Microsoft and Google are large tech companies', 'ORG'),
  ('Apple and Amazon are competing with each other', 'ORG'),
  ('Spring has sprung at Amazon', 'ORG'),
  ('George is a student at university', 'PER'),
  ('Johnny delivers pizza in New York.', 'PER'),
  ('German is a language spoken in Germany', 'MISC'),
  ('Amazon and Tesla are currently the best picks out there', 'ORG'),
]

# Multi-word entity tests
multi_word_cases = [
  ('Hewlett Packard Enterprise is a company.', 'ORG'),
  ('Mukesh Ambani is the chairman of Reliance Industries Limited', 'ORG'),
  ('British Airways is better than American Airlines.', 'ORG'),
  ('The United Nations and the World Bank are well known organizations.', 'ORG'),
  ('The World Trade Organization is an intergovernmental organization headquartered in Geneva, Switzerland.', 'ORG'),
  ('Johnny delivers pizza in New York.', 'LOC'),
]

# Run the cases in order and print one result list per case.
for sentence, entity_code in single_word_cases + multi_word_cases:
  entities = get_entities(sentence, entity_code)
  print(entities)


['Microsoft', 'Google']
['Apple', 'Amazon']
['No organization inside text']
['George']
['Johnny']
['German']
['Amazon', 'Tesla']
['Hewlett Packard Enterprise']
['Reliance Industries Limited']
['British Airways', 'American Airlines']
['United Nations', 'World Bank']
['World Trade Organization']
['New York']
