In [27]:
import torch
from transformers import BertTokenizerFast, BertForTokenClassification


# Define the label mappings manually
id2label = {
    0: 'O',        # Outside of a named entity
    1: 'B-PER',    # Beginning of a person's name right after another person's name
    2: 'I-PER',    # Person's name
    3: 'B-LOC',    # Beginning of a location right after another location
    4: 'I-LOC',    # Location
    5: 'B-ORG',    # Beginning of an organization right after another organization
    6: 'I-ORG',    # Organization
    7: 'B-MISC',   # Beginning of a miscellaneous entity right after another miscellaneous entity
    8: 'I-MISC',   # Miscellaneous entity
}

def predict_entities(sentence, model, tokenizer, id2label):
    # Tokenize the input sentence and convert to tensor
    sentence = sentence.lower()
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True, is_split_into_words=False)

    # Move input tensors to the same device as model
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Get model predictions
    with torch.no_grad():
        outputs = model(**inputs)

    predictions = outputs.logits.argmax(dim=-1).squeeze().tolist()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"].squeeze().tolist())
    predicted_labels = [id2label[pred] for pred in predictions]

    token_label_pairs = list(zip(tokens, predicted_labels))
    token_label_pairs = [(token, label) for token, label in token_label_pairs if token not in (tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token)]

    return token_label_pairs

model_save_path = 'model_loc'
tokenizer_save_path = 'tokenizer_loc'
# Load the model
model = BertForTokenClassification.from_pretrained(model_save_path)

# Load the tokenizer
tokenizer = BertTokenizerFast.from_pretrained(tokenizer_save_path)


# Assuming the model and tokenizer are loaded and available




In [30]:
sentence = "I have to attend a meeting with Mohit at Hotel Tuli Imperial on 22nd Aprin, 4PM"
token_label_pairs = predict_entities(sentence, model, tokenizer, id2label)

# for token, label in token_label_pairs:
#     print(f"{token}: {label}")

person = []
location = []

for token, label in token_label_pairs:
    if label == "B-PER":
        if("##" in token):
          person.append(token.replace("##", ""))
        else:
          person.append(token)
    elif label == "I-PER":
        person.append(token.replace("##", ""))
    elif label == "B-ORG":
        location = [token.replace("##", "")]
    elif label == "I-ORG":
        location.append(token.replace("##", ""))

output = {
    "person": "".join(person),
    "location": " ".join(location)
}

print(output)


{'person': 'mohit', 'location': 'hotel tu li imperial'}
