<a href="https://colab.research.google.com/github/saitejakarre/Saiteja/blob/main/NER_end_to_end_code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np

In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForTokenClassification, DataCollatorForTokenClassification, TrainingArguments, Trainer
import numpy as np

In [None]:
pip install "datasets<4.0.0"

In [None]:
# NOTE(review): trust_remote_code=True allows the dataset repository's own loading code to
# execute locally; only keep it for sources you trust.
dataset = load_dataset(
    path="conll2003",
    trust_remote_code=True,
)

In [None]:
# Quick look at the loaded dataset; as the cell's last expression it is rendered below.
dataset

In [None]:
# Tag names for the NER classes; list position == integer tag id (later cells index it
# with label ids, e.g. compute_metrics and decode_labels).
label_list = (
    dataset["train"]
    .features["ner_tags"]
    .feature
    .names
)
num_labels = len(label_list)

In [None]:
# Hub checkpoint shared by the tokenizer and the model below.
# NOTE(review): an "uncased" checkpoint discards capitalisation, which is usually a strong
# NER cue; a cased checkpoint may be worth trying.
model_checkpoint = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_checkpoint)

In [None]:
# Token-classification head sized to the number of NER tags.
# NOTE(review): no id2label/label2id mapping is passed, so the saved config presumably uses
# generic "LABEL_<id>" names -- that is why decode_labels() further down parses the id
# suffix and looks it up in label_list. Passing id2label/label2id here would let the
# pipeline report tag names directly (decode_labels would then need adjusting).
model = AutoModelForTokenClassification.from_pretrained(
    pretrained_model_name_or_path=model_checkpoint,
    num_labels=num_labels,
)

In [None]:
# Align word-level NER tags with the sub-word tokens produced by the tokenizer.
def tokenize_and_align_labels(examples):
    """Tokenize a batch of pre-split sentences and align word-level NER tags to tokens.

    Only the first sub-token of each word keeps that word's tag. Continuation
    sub-tokens and special tokens (``word_id is None``) get -100 so they are
    excluded from scoring (compute_metrics drops -100 positions).

    Args:
        examples: batch dict with ``"tokens"`` (one word list per example) and
            ``"ner_tags"`` (one per-word tag-id list per example).

    Returns:
        The tokenizer's encoding with an added ``"labels"`` entry holding one
        aligned label list per example.
    """
    ignore_index = -100
    encoded = tokenizer(
        examples["tokens"], truncation=True, is_split_into_words=True
    )

    aligned = []
    for batch_idx, word_tags in enumerate(examples["ner_tags"]):
        row = []
        prev_word = None
        for word in encoded.word_ids(batch_index=batch_idx):
            # A token starts a new word when it maps to a word that differs from the previous token's.
            starts_new_word = word is not None and word != prev_word
            row.append(word_tags[word] if starts_new_word else ignore_index)
            prev_word = word
        aligned.append(row)

    encoded["labels"] = aligned
    return encoded

In [None]:
# Tokenize every split; batched=True passes a batch (dict of lists) per call, which is what
# the per-example loop inside tokenize_and_align_labels expects.
tokenized_datasets = dataset.map(function=tokenize_and_align_labels, batched=True)


In [None]:
# Pads each batch dynamically, including the "labels" lists built by tokenize_and_align_labels.
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

In [None]:
!pip install "evaluate==0.4.0"

import evaluate

In [None]:
!pip install seqeval

In [None]:
# Span-level NER metric consumed by compute_metrics below.
metric = evaluate.load(path="seqeval")

In [None]:
def compute_metrics(eval_pred):
    """Score token-classification predictions with seqeval.

    Args:
        eval_pred: ``(logits, label_ids)`` pair from the Trainer. The argmax is
            taken over axis 2, i.e. logits are indexed ``[example, token, class]``.
            Positions whose label id is -100 are not scored.

    Returns:
        dict with the overall ``precision``, ``recall``, ``f1`` and ``accuracy``.
    """
    logits, label_ids = eval_pred
    pred_ids = np.argmax(logits, axis=2)

    # Gold tag names, skipping ignored (-100) positions.
    references = []
    for gold_row in label_ids:
        references.append([label_list[tag] for tag in gold_row if tag != -100])

    # Predicted tag names at exactly the positions kept above.
    hypotheses = []
    for pred_row, gold_row in zip(pred_ids, label_ids):
        hypotheses.append(
            [label_list[pred] for pred, tag in zip(pred_row, gold_row) if tag != -100]
        )

    scores = metric.compute(predictions=hypotheses, references=references)
    summary_keys = (
        ("precision", "overall_precision"),
        ("recall", "overall_recall"),
        ("f1", "overall_f1"),
        ("accuracy", "overall_accuracy"),
    )
    return {name: scores[key] for name, key in summary_keys}

In [None]:
training_args = TrainingArguments(
    # Checkpoints are written under this directory.
    output_dir="./ner-distilbert",
    # Optimisation schedule.
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # Evaluate (via compute_metrics) and checkpoint once per epoch.
    eval_strategy="epoch",
    save_strategy="epoch",
    # Logging; report_to="none" keeps experiment-tracker integrations off.
    logging_steps=10,
    logging_dir="./logs",
    report_to="none",
)

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    compute_metrics=compute_metrics,
    # Passing the tokenizer lets save_model() store it next to the weights, which the
    # pipeline cell relies on. NOTE(review): recent transformers releases deprecate
    # `tokenizer=` in favour of `processing_class=`; kept for older-version compatibility.
    tokenizer=tokenizer,
)

In [None]:
# Fine-tune; with eval_strategy="epoch", compute_metrics runs on the validation split each epoch.
trainer.train()

In [None]:
# Save the weights as they are at the end of training (load_best_model_at_end is not set),
# into the directory the pipeline cell loads both model and tokenizer from.
trainer.save_model(output_dir="./ner-distilbert")

In [None]:
from transformers import pipeline

In [None]:
# The fine-tuned model has no id2label mapping, so its "O" (outside) class is reported as a
# generic "LABEL_<id>" and the pipeline's default ignore_labels=["O"] never matches it.
# Without the extra entry, non-entity words would be returned -- and printed -- as entities.
# "O" is kept as well so this still works if the config later gains real tag names.
ner_pipeline = pipeline(
    "ner",
    model="./ner-distilbert",
    tokenizer="./ner-distilbert",
    aggregation_strategy="simple",
    ignore_labels=["O", f"LABEL_{label_list.index('O')}"],
)


In [None]:
def decode_labels(results):
    """Map the pipeline's generic ``LABEL_<id>`` entity groups back to dataset tag names.

    The model in this notebook is created without ``id2label``, so entity groups come
    back as ``LABEL_<id>``; the numeric suffix is looked up in ``label_list``.
    Groups that carry no numeric ``_<id>`` suffix (e.g. already-decoded names such as
    ``"B-PER"``) are passed through unchanged, so decoding is idempotent.

    Args:
        results: list of entity dicts from the NER pipeline, each with an
            ``"entity_group"`` key.

    Returns:
        A new list of shallow-copied dicts with ``"entity_group"`` decoded. The input
        dicts are NOT modified (the original mutated them in place, which made
        re-running the decode cell fail with ValueError on ``int("B-PER")``).
    """
    decoded = []
    for ent in results:
        group = ent["entity_group"]
        _, sep, suffix = group.rpartition("_")
        if sep and suffix.isdigit():
            group = label_list[int(suffix)]
        # Copy instead of mutating so the caller's raw results stay intact.
        decoded.append({**ent, "entity_group": group})
    return decoded

In [None]:
# Run the fine-tuned pipeline on a sample sentence; entity groups are expected to be generic
# LABEL_<id> names at this point and are decoded in the next cell.
sentence = "SaiTeja living in Tirumala with lord venkateshwara."
raw_results = ner_pipeline(sentence)

In [None]:
# Translate LABEL_<id> entity groups into CoNLL-2003 tag names via label_list.
final_results = decode_labels(raw_results)

In [None]:
# Print one line per entity span returned by the pipeline.
for entity in final_results:
    print(f"Entity: {entity['word']}, Label: {entity['entity_group']}")
