In [2]:
from gliner import GLiNER

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Available pretrained checkpoints are listed at https://huggingface.co/urchade

# Load the "gliner_medium" checkpoint by its hub identifier; `model` is reused
# by every later cell in this notebook.
model = GLiNER.from_pretrained("urchade/gliner_medium")
# NOTE(review): presumably switches the model to inference mode — confirm
# against the GLiNER / torch API.
model.eval()
# Simple marker that the load finished without raising.
print("ok")



ok


In [4]:
# Sample passage used by every prediction cell below.
text = """
Libretto by Marius Petipa, based on the 1822 novella ``Trilby, ou Le Lutin d'Argail`` by Charles Nodier, first presented by the Ballet of the Moscow Imperial Bolshoi Theatre on January 25/February 6 (Julian/Gregorian calendar dates), 1870, in Moscow with Polina Karpakova as Trilby and Ludiia Geiten as Miranda and restaged by Petipa for the Imperial Ballet at the Imperial Bolshoi Kamenny Theatre on January 17–29, 1871 in St. Petersburg with Adèle Grantzow as Trilby and Lev Ivanov as Count Leopold.
"""

# Entity types the model is asked to look for.
labels = [
    "person",
    "book",
    "location",
    "date",
    "actor",
    "character",
]

# Run the library's own predict_entities with an explicit score threshold.
entities = model.predict_entities(text, labels, threshold=0.4)

# One "surface text => label" line per predicted entity.
for entity in entities:
    print(entity["text"], entity["label"], sep=" => ")

Marius Petipa => actor
1822 => date
Trilby => character
Charles Nodier => person
Moscow => location
January 25/February 6 => date
Julian/Gregorian => date
1870 => date
Moscow => location
Polina Karpakova => actor
Trilby => character
Ludiia Geiten => actor
Miranda => character
Imperial Bolshoi Kamenny Theatre => location
January 17–29, 1871 => date
St. Petersburg => location
Adèle Grantzow => actor
Trilby => character
Lev Ivanov => actor
Count Leopold => character


In [16]:
import re

def predict_entities(self, text, labels, flat_ner=True, threshold=0.5):
    """Predict entities in one raw string and report them as character spans.

    The text is split with a regex into tokens (runs of word characters,
    optionally joined by ``-`` or ``_``, or any single non-whitespace
    character). The tokens are passed through ``self.collate_fn`` and
    ``self.predict``, and the token-level spans that come back are mapped to
    character offsets in ``text``.

    Args:
        self: Object providing ``collate_fn(batch, labels)`` and
            ``predict(x, flat_ner=..., threshold=...)``.
        text: The raw input string.
        labels: Entity labels, forwarded unchanged to ``self.collate_fn``.
        flat_ner: Forwarded unchanged to ``self.predict``.
        threshold: Forwarded unchanged to ``self.predict``.

    Returns:
        A list of dicts with keys ``"start"`` and ``"end"`` (character
        offsets such that ``text[start:end]`` is the entity), ``"text"``
        and ``"label"``.
    """
    # Materialise the matches once so tokens and both offset tables line up.
    matches = list(re.finditer(r'\w+(?:[-_]\w+)*|\S', text))
    words = [m.group() for m in matches]
    char_starts = [m.start() for m in matches]
    char_ends = [m.end() for m in matches]

    batch = self.collate_fn([{"tokenized_text": words, "ner": None}], labels)
    predictions = self.predict(batch, flat_ner=flat_ner, threshold=threshold)

    results = []
    # Only one text was submitted, so only the first prediction list is read.
    for first_tok, last_tok, label in predictions[0]:
        begin = char_starts[first_tok]
        # The end token index is inclusive: take the end offset of that token.
        finish = char_ends[last_tok]
        results.append({
            "start": begin,
            "end": finish,
            "text": text[begin:finish],
            "label": label,
        })
    return results


def batch_predict_entities(self, texts, labels, flat_ner=True, threshold=0.5):
    """Predict entities for several raw strings in one model call.

    Each text is split with a regex into tokens (runs of word characters,
    optionally joined by ``-`` or ``_``, or any single non-whitespace
    character). All token lists are passed together through
    ``self.collate_fn`` and ``self.predict``, and every predicted token span
    is mapped back to character offsets of the text it belongs to.

    Args:
        self: Object providing ``collate_fn(batch, labels)`` and
            ``predict(x, flat_ner=..., threshold=...)``.
        texts: Sequence of raw input strings.
        labels: Entity labels, forwarded unchanged to ``self.collate_fn``.
        flat_ner: Forwarded unchanged to ``self.predict``.
        threshold: Forwarded unchanged to ``self.predict``.

    Returns:
        One list per prediction returned by ``self.predict``, in order. Each
        inner list holds dicts with keys ``"start"`` and ``"end"`` (character
        offsets such that ``texts[i][start:end]`` is the entity), ``"text"``
        and ``"label"``. An empty ``texts`` yields ``[]``.
    """
    # Nothing to tokenise or predict: skip the model call entirely.
    if len(texts) == 0:
        return []

    all_tokens = []
    # One offset table PER TEXT. Sharing a single table across texts would
    # map every text's entities with the offsets of the last text only.
    all_start_token_idx_to_text_idx = []
    all_end_token_idx_to_text_idx = []

    for text in texts:
        tokens = []
        start_token_idx_to_text_idx = []
        end_token_idx_to_text_idx = []
        for match in re.finditer(r'\w+(?:[-_]\w+)*|\S', text):
            tokens.append(match.group())
            start_token_idx_to_text_idx.append(match.start())
            end_token_idx_to_text_idx.append(match.end())
        all_tokens.append(tokens)
        all_start_token_idx_to_text_idx.append(start_token_idx_to_text_idx)
        all_end_token_idx_to_text_idx.append(end_token_idx_to_text_idx)

    input_x = [{"tokenized_text": tk, "ner": None} for tk in all_tokens]
    x = self.collate_fn(input_x, labels)
    outputs = self.predict(x, flat_ner=flat_ner, threshold=threshold)

    all_entities = []
    for i, output in enumerate(outputs):
        # Look up the offset tables that belong to the i-th input text.
        start_token_idx_to_text_idx = all_start_token_idx_to_text_idx[i]
        end_token_idx_to_text_idx = all_end_token_idx_to_text_idx[i]
        entities = []
        for start_token_idx, end_token_idx, ent_type in output:
            start_text_idx = start_token_idx_to_text_idx[start_token_idx]
            # The end token index is inclusive: use that token's end offset.
            end_text_idx = end_token_idx_to_text_idx[end_token_idx]
            entities.append({
                "start": start_text_idx,
                "end": end_text_idx,
                "text": texts[i][start_text_idx:end_text_idx],
                "label": ent_type,
            })
        all_entities.append(entities)

    return all_entities

In [17]:
# Try the batched helper: `model` is passed explicitly as `self` because the
# helper is defined as a free function in this notebook.
# NOTE(review): both inputs are the same passage, so this call cannot expose
# per-text offset mix-ups; consider texts that differ.
batch_predict_entities(model, [text, text], labels)

[[{'start': 13, 'end': 26, 'text': 'Marius Petipa', 'label': 'actor'},
  {'start': 41, 'end': 45, 'text': '1822', 'label': 'date'},
  {'start': 56, 'end': 62, 'text': 'Trilby', 'label': 'character'},
  {'start': 90, 'end': 104, 'text': 'Charles Nodier', 'label': 'person'},
  {'start': 143, 'end': 149, 'text': 'Moscow', 'label': 'location'},
  {'start': 178, 'end': 199, 'text': 'January 25/February 6', 'label': 'date'},
  {'start': 235, 'end': 239, 'text': '1870', 'label': 'date'},
  {'start': 244, 'end': 250, 'text': 'Moscow', 'label': 'location'},
  {'start': 256, 'end': 272, 'text': 'Polina Karpakova', 'label': 'actor'},
  {'start': 276, 'end': 282, 'text': 'Trilby', 'label': 'character'},
  {'start': 287, 'end': 300, 'text': 'Ludiia Geiten', 'label': 'actor'},
  {'start': 304, 'end': 311, 'text': 'Miranda', 'label': 'character'},
  {'start': 366,
   'end': 398,
   'text': 'Imperial Bolshoi Kamenny Theatre',
   'label': 'location'},
  {'start': 402, 'end': 421, 'text': 'January 17–29

In [8]:
# Single-text helper, called as a free function with `model` passed as `self`.
predict_entities(model, text, labels)

[{'start': 13, 'end': 26, 'text': 'Marius Petipa', 'label': 'actor'},
 {'start': 41, 'end': 45, 'text': '1822', 'label': 'date'},
 {'start': 56, 'end': 62, 'text': 'Trilby', 'label': 'character'},
 {'start': 90, 'end': 104, 'text': 'Charles Nodier', 'label': 'person'},
 {'start': 143, 'end': 149, 'text': 'Moscow', 'label': 'location'},
 {'start': 178, 'end': 199, 'text': 'January 25/February 6', 'label': 'date'},
 {'start': 235, 'end': 239, 'text': '1870', 'label': 'date'},
 {'start': 244, 'end': 250, 'text': 'Moscow', 'label': 'location'},
 {'start': 256, 'end': 272, 'text': 'Polina Karpakova', 'label': 'actor'},
 {'start': 276, 'end': 282, 'text': 'Trilby', 'label': 'character'},
 {'start': 287, 'end': 300, 'text': 'Ludiia Geiten', 'label': 'actor'},
 {'start': 304, 'end': 311, 'text': 'Miranda', 'label': 'character'},
 {'start': 366,
  'end': 398,
  'text': 'Imperial Bolshoi Kamenny Theatre',
  'label': 'location'},
 {'start': 402, 'end': 421, 'text': 'January 17–29, 1871', 'label':

In [None]:
    def predict_entities(self, text, labels, flat_ner=True, threshold=0.5):
        """Predict entities in one raw string and report them as character spans.

        The text is split with a regex into tokens (runs of word characters,
        optionally joined by ``-`` or ``_``, or any single non-whitespace
        character). The tokens are passed through ``self.collate_fn`` and
        ``self.predict``, and the token-level spans that come back are mapped
        to character offsets in ``text``.

        Args:
            text: The raw input string.
            labels: Entity labels, forwarded unchanged to ``self.collate_fn``.
            flat_ner: Forwarded unchanged to ``self.predict``.
            threshold: Forwarded unchanged to ``self.predict``.

        Returns:
            A list of dicts with keys ``"start"`` and ``"end"`` (character
            offsets such that ``text[start:end]`` is the entity), ``"text"``
            and ``"label"``.
        """
        # Collect (token, start, end) triples first, then split into columns.
        pieces = [
            (m.group(), m.start(), m.end())
            for m in re.finditer(r'\w+(?:[-_]\w+)*|\S', text)
        ]
        words = [piece[0] for piece in pieces]
        char_starts = [piece[1] for piece in pieces]
        char_ends = [piece[2] for piece in pieces]

        sample = {"tokenized_text": words, "ner": None}
        batch = self.collate_fn([sample], labels)
        predictions = self.predict(batch, flat_ner=flat_ner, threshold=threshold)

        found = []
        # A single text was submitted, so only the first prediction list is read.
        for first_tok, last_tok, label in predictions[0]:
            begin = char_starts[first_tok]
            # The end token index is inclusive: take that token's end offset.
            finish = char_ends[last_tok]
            found.append({
                "start": begin,
                "end": finish,
                "text": text[begin:finish],
                "label": label,
            })
        return found

    def evaluate(self, test_data, flat_ner=False, threshold=0.5, batch_size=12, entity_types=None):
        """Run the model over ``test_data`` and score predictions with ``Evaluator``.

        Calls ``self.eval()``, builds an unshuffled data loader through
        ``self.create_dataloader``, and moves every tensor value of each batch
        to the device of the model's first parameter. Predictions from
        ``self.predict`` are collected alongside each batch's ``"entities"``.

        Args:
            test_data: Forwarded to ``self.create_dataloader``.
            flat_ner: Forwarded positionally to ``self.predict``.
            threshold: Forwarded positionally to ``self.predict``.
            batch_size: Batch size given to ``self.create_dataloader``.
            entity_types: Forwarded to ``self.create_dataloader``.

        Returns:
            The pair ``(out, f1)`` produced by ``Evaluator(...).evaluate()``.
        """
        self.eval()
        loader = self.create_dataloader(
            test_data,
            batch_size=batch_size,
            entity_types=entity_types,
            shuffle=False,
        )
        # The device of the first parameter is taken as the model's device.
        target_device = next(self.parameters()).device

        predicted, gold = [], []
        for batch in loader:
            # Move tensors in place; non-tensor entries are left untouched.
            for key, value in batch.items():
                if not isinstance(value, torch.Tensor):
                    continue
                batch[key] = value.to(target_device)
            predicted.extend(self.predict(batch, flat_ner, threshold))
            gold.extend(batch["entities"])

        scorer = Evaluator(gold, predicted)
        out, f1 = scorer.evaluate()
        return out, f1