In [5]:
# !pip install transformers datasets torch flask flask-cors  scikit-learn seqeval
# !pip install accelerate -U

Installing collected packages: accelerate
Successfully installed accelerate-0.30.1


In [1]:
import torch
from transformers import BertTokenizerFast, BertForTokenClassification, Trainer, TrainingArguments, DataCollatorForTokenClassification
from datasets import load_dataset, load_metric
import numpy as np
from flask import Flask, request, jsonify
from flask_cors import CORS

# Load dataset
dataset = load_dataset("conll2003")
label_list = dataset['train'].features['ner_tags'].feature.names

# Initialize tokenizer and model
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
model = BertForTokenClassification.from_pretrained("bert-base-cased", num_labels=len(label_list))

# Define label_all_tokens
label_all_tokens = True

def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(examples['tokens'], truncation=True, is_split_into_words=True, padding='max_length', max_length=128)
    labels = []
    for i, label in enumerate(examples[f'ner_tags']):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None:
                label_ids.append(-100)
            elif word_idx != previous_word_idx:
                label_ids.append(label[word_idx])
            else:
                label_ids.append(label[word_idx] if label_all_tokens else -100)
            previous_word_idx = word_idx
        labels.append(label_ids)
    tokenized_inputs["labels"] = labels
    return tokenized_inputs

# Tokenize and align labels for train and validation sets
tokenized_datasets = dataset.map(tokenize_and_align_labels, batched=True)

# Load metric
metric = load_metric("seqeval")

def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

# Training arguments
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
)

# Data collator for padding
data_collator = DataCollatorForTokenClassification(tokenizer)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# Train model
trainer.train()

# Save model
trainer.save_model("ner_model")

# Initialize Flask app
app = Flask(__name__)
CORS(app)

@app.route('/extract_entities', methods=['POST'])
def extract_entities():
    data = request.get_json()
    text = data['text']
    tokens = tokenizer.tokenize(tokenizer.decode(tokenizer.encode(text)))
    inputs = tokenizer.encode(text, return_tensors="pt")
    outputs = model(inputs).logits
    predictions = torch.argmax(outputs, dim=2)
    entities = []
    for token, label_id in zip(tokens, predictions[0].tolist()):
        entities.append({"word": token, "entity": label_list[label_id]})
    return jsonify(entities)

if __name__ == "__main__":
    app.run(host='0.0.0.0', port=5000)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  metric = load_metric("seqeval")
You can avoid this message in future by passing the argument `tr

Epoch,Training Loss,Validation Loss,Precision,Recall,F1,Accuracy
1,0.244,0.07127,0.906956,0.919091,0.912983,0.978981
2,0.051,0.063794,0.929798,0.933639,0.931714,0.983768
3,0.028,0.059403,0.936998,0.942888,0.939934,0.985212


 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:5000
 * Running on http://172.28.0.12:5000
INFO:werkzeug:[33mPress CTRL+C to quit[0m


In [5]:
import requests 
response = requests.post(url = 'http://127.0.0.1:5000/extract_entities', json = {"text":"Which country is highly speceializing in cognitive science and mind building?"})
response.status_code

200

In [6]:
response.json()

[{'entity': 'O', 'word': '[CLS]'},
 {'entity': 'O', 'word': 'Which'},
 {'entity': 'O', 'word': 'country'},
 {'entity': 'O', 'word': 'is'},
 {'entity': 'O', 'word': 'highly'},
 {'entity': 'O', 'word': 's'},
 {'entity': 'O', 'word': '##pec'},
 {'entity': 'O', 'word': '##ei'},
 {'entity': 'O', 'word': '##ali'},
 {'entity': 'O', 'word': '##zing'},
 {'entity': 'O', 'word': 'in'},
 {'entity': 'O', 'word': 'cognitive'},
 {'entity': 'O', 'word': 'science'},
 {'entity': 'O', 'word': 'and'},
 {'entity': 'O', 'word': 'mind'},
 {'entity': 'O', 'word': 'building'},
 {'entity': 'O', 'word': '?'},
 {'entity': 'O', 'word': '[SEP]'}]