From aedb60420709fab9cf88cf5b1e074efa354e3dd5 Mon Sep 17 00:00:00 2001 From: Ingvar Date: Tue, 6 Jan 2026 15:40:32 +0200 Subject: [PATCH 1/8] implement token-level relex model --- gliner/config.py | 27 ++- gliner/data_processing/__init__.py | 1 + gliner/data_processing/collator.py | 10 + gliner/data_processing/processor.py | 350 +++++++++++++++++++++------- gliner/decoding/__init__.py | 2 +- gliner/decoding/decoder.py | 212 +++++++++++++++++ gliner/model.py | 102 +++++++- gliner/modeling/base.py | 186 +++++++++++++-- gliner/onnx/model.py | 48 ++++ 9 files changed, 823 insertions(+), 115 deletions(-) diff --git a/gliner/config.py b/gliner/config.py index 8cdd817..d368d68 100644 --- a/gliner/config.py +++ b/gliner/config.py @@ -172,9 +172,7 @@ def __init__( raise ValueError("UniEncoderSpanDecoderConfig requires span_mode != 'token_level'") -class UniEncoderSpanRelexConfig(UniEncoderConfig): - """Configuration for uni-encoder span model with relation extraction.""" - +class UniEncoderRelexConfig(UniEncoderConfig): def __init__( self, relations_layer: Optional[str] = None, @@ -187,7 +185,7 @@ def __init__( relation_loss_coef=1.0, **kwargs, ): - """Initialize UniEncoderSpanRelexConfig. + """Initialize UniEncoderRelexConfig. 
Args: relations_layer (str, optional): Name of relations layer, @@ -215,11 +213,26 @@ def __init__( self.span_loss_coef = span_loss_coef self.adjacency_loss_coef = adjacency_loss_coef self.relation_loss_coef = relation_loss_coef + +class UniEncoderSpanRelexConfig(UniEncoderRelexConfig): + """Configuration for uni-encoder span model with relation extraction.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) self.model_type = "gliner_uni_encoder_span_relex" if self.span_mode == "token_level": raise ValueError("UniEncoderSpanRelexConfig requires span_mode != 'token_level'") +class UniEncoderTokenRelexConfig(UniEncoderRelexConfig): + """Configuration for uni-encoder token-level model with relation extraction.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.model_type = "gliner_uni_encoder_token_relex" + self.span_mode = "token_level" + + class BiEncoderConfig(BaseGLiNERConfig): """Base configuration for bi-encoder GLiNER models.""" @@ -302,7 +315,10 @@ def model_type(self): elif self.labels_encoder: return "gliner_bi_encoder_span" if self.span_mode != "token-level" else "gliner_bi_encoder_token" elif self.relations_layer is not None: - return "gliner_uni_encoder_span_relex" + if self.span_mode == 'token-level': + return "gliner_uni_encoder_token_relex" + else: + return "gliner_uni_encoder_span_relex" elif self.span_mode == "token-level": return "gliner_uni_encoder_token" else: @@ -319,6 +335,7 @@ def model_type(self): "gliner_uni_encoder_token": UniEncoderTokenConfig, "gliner_uni_encoder_span_decoder": UniEncoderSpanDecoderConfig, "gliner_uni_encoder_span_relex": UniEncoderSpanRelexConfig, + "gliner_uni_encoder_token_relex": UniEncoderTokenRelexConfig, "gliner_bi_encoder": BiEncoderConfig, "gliner_bi_encoder_span": BiEncoderSpanConfig, "gliner_bi_encoder_token": BiEncoderTokenConfig, diff --git a/gliner/data_processing/__init__.py b/gliner/data_processing/__init__.py index 95be360..102ecd4 100644 --- 
def prepare_span_labels(self, ner, classes_to_id, num_tokens, spans_idx=None):
    """Build per-span class labels and the span index tensor for one example.

    Args:
        ner: Entity annotations for the example. A falsy value (``None`` or
            an empty list) means the example has no entities, in which case
            every span is labelled ``0``.
        classes_to_id: Mapping from entity class label to integer ID,
            forwarded to ``self.get_dict``.
        num_tokens: Number of tokens in the (possibly truncated) example.
        spans_idx: Candidate spans as a sequence of hashable ``(start, end)``
            pairs. When omitted, the candidates are enumerated with
            ``prepare_span_idx(num_tokens, self.config.max_width)``.

    Returns:
        Tuple ``(span_label, spans_idx)`` of ``torch.LongTensor``:
        ``span_label`` holds one label per candidate span (``-1`` for spans
        whose end index lies outside the sequence) and ``spans_idx`` holds
        the candidate spans as an ``(num_spans, 2)`` tensor.
    """
    if spans_idx is None:
        # The previous body read ``spans_idx`` before it was ever assigned,
        # which raised UnboundLocalError on every call. Enumerate the
        # candidates here so the existing three-argument callers work.
        # NOTE(review): assumes the callers' local ``max_width`` is
        # ``self.config.max_width`` -- confirm against preprocess_example.
        spans_idx = prepare_span_idx(num_tokens, self.config.max_width)

    # With no annotations every lookup must yield 0, hence defaultdict(int).
    dict_lab = self.get_dict(ner, classes_to_id) if ner else defaultdict(int)
    span_label = torch.LongTensor([dict_lab[span] for span in spans_idx])
    spans_idx = torch.LongTensor(spans_idx)

    # Spans ending past the last token cannot be real entities; mark them
    # with -1 so they can be told apart from ordinary negatives (label 0).
    out_of_range = spans_idx[:, 1] > num_tokens - 1
    span_label = span_label.masked_fill(out_of_range, -1)
    return span_label, spans_idx
def sort_entities_and_relations(self, ner, relations):
    """Order entities by position and re-index relations to match.

    Args:
        ner: Entity annotations whose first two items are the start and end
            positions, or ``None``.
        relations: ``(head_idx, tail_idx, rel_type)`` triples whose indices
            refer to positions in ``ner``, or ``None``.

    Returns:
        Tuple ``(ner, relations)``. When ``ner`` is ``None`` or empty both
        arguments are returned untouched. Otherwise ``ner`` is a new list
        sorted by ``(start, end)``; a non-empty ``relations`` is rewritten to
        the new entity indices, relations pointing at unknown entity indices
        are dropped, and the result is sorted by ``(head, tail)``.
    """
    if ner is None or len(ner) == 0:
        return ner, relations

    # Stable sort of the original positions by (start, end).
    order = sorted(range(len(ner)), key=lambda pos: (ner[pos][0], ner[pos][1]))
    remap = {old_pos: new_pos for new_pos, old_pos in enumerate(order)}

    if relations is not None and len(relations) > 0:
        remapped = [
            (remap[head], remap[tail], rel_type)
            for head, tail, rel_type in relations
            if head in remap and tail in remap
        ]
        remapped.sort(key=lambda triple: (triple[0], triple[1]))
        relations = remapped

    return [ner[pos] for pos in order], relations
@@ -1263,31 +1291,10 @@ def preprocess_example(self, tokens, ner, classes_to_id, relations, rel_classes_ num_tokens = len(tokens) spans_idx = prepare_span_idx(num_tokens, max_width) - if ner is not None and len(ner) > 0: - indexed_ner = list(enumerate(ner)) - indexed_ner_sorted = sorted(indexed_ner, key=lambda x: (x[1][0], x[1][1])) - - ner_sorted = [entity for _, entity in indexed_ner_sorted] - - old_to_new_idx = {old_idx: new_idx for new_idx, (old_idx, _) in enumerate(indexed_ner_sorted)} - - if relations is not None and len(relations) > 0: - updated_relations = [] - for head_idx, tail_idx, rel_type in relations: - if head_idx in old_to_new_idx and tail_idx in old_to_new_idx: - new_head_idx = old_to_new_idx[head_idx] - new_tail_idx = old_to_new_idx[tail_idx] - updated_relations.append((new_head_idx, new_tail_idx, rel_type)) - relations = sorted(updated_relations, key=lambda x: (x[0], x[1])) - - ner = ner_sorted + ner, relations = self.sort_entities_and_relations(ner, relations) # Process entity labels - dict_lab = self.get_dict(ner, classes_to_id) if ner else defaultdict(int) - span_label = torch.LongTensor([dict_lab[i] for i in spans_idx]) - spans_idx = torch.LongTensor(spans_idx) - valid_span_mask = spans_idx[:, 1] > num_tokens - 1 - span_label = span_label.masked_fill(valid_span_mask, -1) + span_label, spans_idx = self.prepare_span_labels(ner, classes_to_id, num_tokens) # Create entity span to index mapping span_to_idx = {(spans_idx[i, 0].item(), spans_idx[i, 1].item()): i for i in range(len(spans_idx))} @@ -1326,6 +1333,7 @@ def preprocess_example(self, tokens, ner, classes_to_id, relations, rel_classes_ "span_label": span_label, "seq_length": num_tokens, "entities": ner, + "entities_id": entity_to_span_idx, "relations": relations, "rel_idx": rel_idx, "rel_label": rel_label, @@ -1349,6 +1357,8 @@ def create_batch_dict(self, batch, class_to_ids, id_to_classes, rel_class_to_ids entities = [el["entities"] for el in batch] relations = [el["relations"] for el in 
batch] + entities_id = [el["entities_id"] for el in batch] + span_idx = pad_sequence([b["span_idx"] for b in batch], batch_first=True, padding_value=0) span_label = pad_sequence([el["span_label"] for el in batch], batch_first=True, padding_value=-1) rel_idx = pad_sequence([el["rel_idx"] for el in batch], batch_first=True, padding_value=0) @@ -1364,6 +1374,7 @@ def create_batch_dict(self, batch, class_to_ids, id_to_classes, rel_class_to_ids "span_mask": span_mask, "span_label": span_label, "entities": entities, + "entities_id": entities_id, "relations": relations, "rel_idx": rel_idx, "rel_label": rel_label, @@ -1373,44 +1384,42 @@ def create_batch_dict(self, batch, class_to_ids, id_to_classes, rel_class_to_ids "rel_id_to_classes": rel_id_to_classes, } + def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_negatives=True, negative_ratio=2.0): """Create relation labels with negative pair sampling. - Generates training labels for relation extraction including both positive - relation pairs and carefully sampled negative pairs for contrastive learning. + Overrides the span-based version to work with token-level entity representations. + Uses entities_id count instead of span_label for entity counting. Args: batch: Batch dictionary containing entities and relations. - add_reversed_negatives: If True, add reversed direction pairs as - negatives (h,t) -> (t,h). These are important hard negatives - for learning relation directionality. - add_random_negatives: If True, add random entity pairs as negatives - to provide additional training signal. - negative_ratio: Ratio of negative to positive pairs. For example, - 2.0 means twice as many negatives as positives. + add_reversed_negatives: If True, add reversed direction pairs as negatives. + add_random_negatives: If True, add random entity pairs as negatives. + negative_ratio: Ratio of negative to positive pairs. 
Returns: Tuple containing: - - adj_matrix: Adjacency matrix indicating which entity pairs - to consider (shape: [B, max_entities, max_entities]) - - rel_matrix: Multi-hot encoded relation labels for each pair - (shape: [B, max_pairs, num_relation_classes]) + - adj_matrix: Adjacency matrix (shape: [B, max_entities, max_entities]) + - rel_matrix: Multi-hot relation labels (shape: [B, max_pairs, num_relation_classes]) """ B = len(batch["tokens"]) - entity_label = batch["span_label"] + entities_id = batch["entities_id"] - batch_ents = torch.sum(entity_label > 0, dim=-1) - max_En = torch.max(batch_ents).item() + # Count entities per sample (differs from span-based which uses span_label) + batch_ents = torch.LongTensor([len(ent_list) for ent_list in entities_id]) + max_En = max(batch_ents.max().item(), 1) rel_class_to_ids = batch["rel_class_to_ids"] if isinstance(rel_class_to_ids, list): - C = max(len(r) for r in rel_class_to_ids) + C = max((len(r) for r in rel_class_to_ids), default=0) else: - C = len(rel_class_to_ids) + C = len(rel_class_to_ids) if rel_class_to_ids else 0 + + if C == 0: + return torch.zeros(B, max_En, max_En, dtype=torch.float), torch.zeros(B, 1, 1, dtype=torch.float) adj_matrix = torch.zeros(B, max_En, max_En, dtype=torch.float) - # Collect all pairs (positive + negative) and their relations all_pairs_info = [] max_total_pairs = 0 @@ -1419,7 +1428,6 @@ def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_ rel_idx_i = batch["rel_idx"][i] rel_label_i = batch["rel_label"][i] - # Dictionary to group relations by entity pair pair_to_relations = {} positive_pairs = set() @@ -1434,35 +1442,28 @@ def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_ positive_pairs.add(pair_key) if pair_key not in pair_to_relations: pair_to_relations[pair_key] = [] - class_id = rel_label_i[k].item() - pair_to_relations[pair_key].append(class_id) + pair_to_relations[pair_key].append(rel_label_i[k].item()) # Generate negative 
pairs negative_pairs = set() num_positives = len(positive_pairs) target_negatives = int(num_positives * negative_ratio) - # 1. Add reversed pairs as negatives (most important!) if add_reversed_negatives: for e1, e2 in positive_pairs: reversed_pair = (e2, e1) - # Only add if reversed pair is NOT also a positive relation if reversed_pair not in positive_pairs: negative_pairs.add(reversed_pair) - # 2. Add random negative pairs if needed - if add_random_negatives and len(negative_pairs) < target_negatives: - # Get entity span positions for proximity-based sampling - entities = batch["entities"][i] - entity_positions = [(ent[0], ent[1]) for ent in entities] if entities else [] + if add_random_negatives and N > 1 and len(negative_pairs) < target_negatives: + ent_id_list = entities_id[i] + entity_positions = [(ent[0], ent[1]) for ent in ent_id_list] if ent_id_list else [] attempts = 0 - max_attempts = target_negatives * 10 # Avoid infinite loop + max_attempts = target_negatives * 10 while len(negative_pairs) < target_negatives and attempts < max_attempts: attempts += 1 - - # Sample two different entities e1 = random.randint(0, N - 1) e2 = random.randint(0, N - 1) @@ -1470,62 +1471,42 @@ def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_ continue pair = (e1, e2) - - # Skip if already positive or already in negatives if pair in positive_pairs or pair in negative_pairs: continue - # Optional: bias towards nearby entities (hard negatives) - if entity_positions and len(entity_positions) > e1 and len(entity_positions) > e2: + if entity_positions and len(entity_positions) > max(e1, e2): pos1 = entity_positions[e1] pos2 = entity_positions[e2] - distance = abs(pos1[0] - pos2[1]) # Distance between entities - - # Sample with probability inversely proportional to distance - # (closer entities are harder negatives) + distance = abs(pos1[0] - pos2[1]) if distance > 10 and random.random() < 0.5: - continue # Skip some far pairs + continue 
negative_pairs.add(pair) - # Combine all pairs (positives + negatives) and sort all_pairs = sorted(list(positive_pairs) + list(negative_pairs)) - - # Store pair info: pair, is_positive, relations - pair_info = [] - for pair in all_pairs: - is_positive = pair in positive_pairs - relations = pair_to_relations.get(pair, []) - pair_info.append((pair, is_positive, relations)) + pair_info = [(pair, pair in positive_pairs, pair_to_relations.get(pair, [])) for pair in all_pairs] all_pairs_info.append(pair_info) max_total_pairs = max(max_total_pairs, len(all_pairs)) - # Create matrices + max_total_pairs = max(max_total_pairs, 1) + rel_matrix = torch.zeros(B, max_total_pairs, C, dtype=torch.float) - pair_type_mask = torch.zeros(B, max_total_pairs, dtype=torch.long) # 1=positive, 0=negative for i in range(B): N = batch_ents[i].item() pair_info = all_pairs_info[i] - - adj = torch.zeros(N, N) + adj = torch.zeros(max(N, 1), max(N, 1)) for pair_idx, (pair, is_positive, relations) in enumerate(pair_info): e1, e2 = pair - - # Set adjacency (1.0 for both positive and negative pairs) adj[e1, e2] = 1.0 - # Mark pair type - pair_type_mask[i, pair_idx] = 1 if is_positive else 0 - if is_positive: - # Create multi-hot vector for positive pairs for class_id in relations: rel_matrix[i, pair_idx, class_id - 1] = 1.0 - adj_matrix[i, :N, :N] = adj + adj_matrix[i, :N, :N] = adj[:N, :N] return adj_matrix, rel_matrix @@ -1630,3 +1611,196 @@ def tokenize_and_prepare_labels(self, batch, prepare_labels, *args, **kwargs): tokenized_input["rel_matrix"] = rel_matrix return tokenized_input + + +class RelationExtractionTokenProcessor(UniEncoderTokenProcessor, RelationExtractionSpanProcessor): + """Processor for joint entity and relation extraction using token-level NER. + + Extends token-based NER processing to additionally handle relation extraction + between entity pairs, supporting end-to-end joint training with BIO-style + entity tagging. 
def preprocess_example(self, tokens, ner, classes_to_id, relations=None, rel_classes_to_id=None):
    """Preprocess a single example for joint entity and relation extraction.

    Entities are sorted by position, filtered down to those with a known
    class that fit inside the (possibly truncated) sequence, and relations
    are re-indexed so that they point into the filtered entity list.

    Args:
        tokens: List of token strings.
        ner: List of entity annotations as (start, end, label) tuples, or None.
        classes_to_id: Mapping from entity class labels to integer IDs.
        relations: List of relation annotations as
            (head_idx, tail_idx, rel_type) tuples, or None.
        rel_classes_to_id: Mapping from relation class labels to integer IDs,
            or None to skip relation processing.

    Returns:
        Dictionary containing:
            - tokens: Token strings (``["[PAD]"]`` for empty input)
            - seq_length: Sequence length after truncation
            - entities: Entity annotations as returned by
              ``sort_entities_and_relations``
            - entities_id: ``[start, end, class_id]`` for each kept entity
            - relations: Relation annotations as returned by
              ``sort_entities_and_relations``
            - rel_idx: LongTensor of shape (num_relations, 2) with head/tail
              indices into ``entities_id``
            - rel_label: LongTensor of shape (num_relations,) with relation
              class IDs

    Warnings:
        UserWarning: If sequence length exceeds max_len (gets truncated).
    """
    # Handle empty token list
    if len(tokens) == 0:
        tokens = ["[PAD]"]

    # Truncate if necessary
    max_len = self.config.max_len
    if len(tokens) > max_len:
        warnings.warn(f"Sentence of length {len(tokens)} has been truncated to {max_len}", stacklevel=2)
        tokens = tokens[:max_len]

    num_tokens = len(tokens)

    ner, relations = self.sort_entities_and_relations(ner, relations)

    # Single pass over the sorted entities: keep those with a known class
    # that end inside the sequence, and remember where each kept entity
    # lands in ``entities_id``. Doing both in one loop keeps the filter and
    # the index mapping in sync, and an explicit None check replaces the
    # former ``except TypeError`` that doubled as a "ner is None" test.
    entities_id = []
    entity_idx_mapping = {}
    if ner is not None:
        for ent_idx, (start, end, label) in enumerate(ner):
            if label in classes_to_id and end < num_tokens:
                entity_idx_mapping[ent_idx] = len(entities_id)
                entities_id.append([start, end, classes_to_id[label]])

    # Keep only relations whose endpoints both survived the entity filter
    # and whose type is known.
    rel_idx_list = []
    rel_label_list = []
    if relations is not None and rel_classes_to_id is not None:
        for head_idx, tail_idx, rel_type in relations:
            if (
                head_idx in entity_idx_mapping
                and tail_idx in entity_idx_mapping
                and rel_type in rel_classes_to_id
            ):
                rel_idx_list.append([entity_idx_mapping[head_idx], entity_idx_mapping[tail_idx]])
                rel_label_list.append(rel_classes_to_id[rel_type])

    # Convert to tensors; explicit empty shapes keep batch padding working.
    if rel_idx_list:
        rel_idx = torch.LongTensor(rel_idx_list)
        rel_label = torch.LongTensor(rel_label_list)
    else:
        rel_idx = torch.zeros(0, 2, dtype=torch.long)
        rel_label = torch.zeros(0, dtype=torch.long)

    return {
        "tokens": tokens,
        "seq_length": num_tokens,
        "entities": ner,
        "entities_id": entities_id,
        "relations": relations,
        "rel_idx": rel_idx,
        "rel_label": rel_label,
    }
def create_batch_dict(self, batch, class_to_ids, id_to_classes, rel_class_to_ids=None, rel_id_to_classes=None):
    """Assemble preprocessed relation-extraction examples into one batch.

    Args:
        batch: List of example dictionaries, each providing the keys
            ``tokens``, ``seq_length``, ``entities``, ``entities_id``,
            ``relations``, ``rel_idx`` and ``rel_label``.
        class_to_ids: Entity class-to-ID mapping(s), passed through as-is.
        id_to_classes: Entity ID-to-class mapping(s), passed through as-is.
        rel_class_to_ids: Relation class-to-ID mapping(s), passed through.
        rel_id_to_classes: Relation ID-to-class mapping(s), passed through.

    Returns:
        Dictionary with per-example lists for the token and annotation
        fields, ``seq_length`` as a LongTensor of shape (batch, 1), and
        ``rel_idx`` / ``rel_label`` zero-padded to the longest example.
    """

    def column(key):
        # Gather one field across every example, preserving batch order.
        return [example[key] for example in batch]

    def zero_padded(key):
        return pad_sequence(column(key), batch_first=True, padding_value=0)

    lengths = torch.LongTensor(column("seq_length")).unsqueeze(-1)

    return {
        "tokens": column("tokens"),
        "seq_length": lengths,
        "entities": column("entities"),
        "entities_id": column("entities_id"),
        "relations": column("relations"),
        "rel_idx": zero_padded("rel_idx"),
        "rel_label": zero_padded("rel_label"),
        "classes_to_id": class_to_ids,
        "id_to_classes": id_to_classes,
        "rel_class_to_ids": rel_class_to_ids,
        "rel_id_to_classes": rel_id_to_classes,
    }
+ """ + batch_size = len(batch["tokens"]) + seq_len = batch["seq_length"].max().item() + num_classes = max([len(cid) for cid in batch["classes_to_id"]]) + + # Use relation-aware tokenize_inputs from RelationExtractionSpanProcessor + tokenized_input = self.tokenize_inputs( + batch["tokens"], + batch["classes_to_id"], + blank=None, + relations=batch["rel_class_to_ids"] + ) + + if prepare_labels: + # Create token-level BIO labels (from UniEncoderTokenProcessor) + labels = self.create_labels(batch["entities_id"], batch_size, seq_len, num_classes) + tokenized_input["labels"] = labels + + # Create relation labels (overridden method) + adj_matrix, rel_matrix = self.create_relation_labels(batch) + tokenized_input["adj_matrix"] = adj_matrix + tokenized_input["rel_matrix"] = rel_matrix + + return tokenized_input \ No newline at end of file diff --git a/gliner/decoding/__init__.py b/gliner/decoding/__init__.py index b8b2d0d..dee4e75 100644 --- a/gliner/decoding/__init__.py +++ b/gliner/decoding/__init__.py @@ -1 +1 @@ -from .decoder import SpanDecoder, TokenDecoder, SpanRelexDecoder, SpanGenerativeDecoder +from .decoder import SpanDecoder, TokenDecoder, SpanRelexDecoder, SpanGenerativeDecoder, TokenRelexDecoder diff --git a/gliner/decoding/decoder.py b/gliner/decoding/decoder.py index e077dab..ac55c4b 100644 --- a/gliner/decoding/decoder.py +++ b/gliner/decoding/decoder.py @@ -936,3 +936,215 @@ class IDs to class names. spans.append(span_i) return spans + +class TokenRelexDecoder(TokenDecoder): + """Token-based decoder with relation extraction support. + + Extends the token-based BIO decoder to decode both entity spans and the relations + between them. Entity spans are extracted first using BIO-style tagging logic, + then relations are decoded by identifying pairs of entities and their + relationship types based on model predictions. 
+ + The decoder supports: + - Entity span extraction via BIO tagging with confidence thresholding + - Relation extraction between detected entities + - Flexible entity and relation label mappings (per-sample or global) + - Optional flat NER (non-overlapping entities) + - Multi-label entity classification + """ + + def _decode_relations( + self, + spans: List[List[tuple]], + rel_idx: Optional[torch.Tensor], + rel_logits: Optional[torch.Tensor], + rel_mask: Optional[torch.Tensor], + rel_id_to_classes: Union[Dict[int, str], List[Dict[int, str]]], + threshold: float, + batch_size: int, + ) -> List[List[tuple]]: + """Decode relations between detected entity spans. + + Extracts relation predictions from model outputs and maps them to pairs + of detected entity spans. For each potential relation, checks if both + head and tail entities exist in the decoded spans and if the relation + confidence exceeds the threshold. + + Args: + spans: List of entity spans for each sample in the batch. + Each sample contains a list of tuples: (start, end, entity_type, score). + rel_idx: Tensor of shape (batch_size, num_relations, 2) containing + indices of head and tail entities for each potential relation. + None if no relations to decode. + rel_logits: Tensor of shape (batch_size, num_relations, num_relation_classes) + containing logits for relation classifications. None if no relations. + rel_mask: Optional boolean tensor of shape (batch_size, num_relations) + indicating which relations are valid (True) vs. padding (False). + If None, all relations are considered valid. + rel_id_to_classes: Mapping from relation class IDs to relation names. + Can be either: + - Dict: Single mapping used for all samples + - List[Dict]: Per-sample mappings for different relation schemas + Class IDs are 1-indexed (0 reserved for "no relation" or padding). + threshold: Minimum confidence score (after sigmoid) for a relation + to be included in the output. Must be in range [0, 1]. 
+ batch_size: Number of samples in the batch. + + Returns: + List of relation lists, one per sample. Each relation is a tuple: + (head_idx, relation_label, tail_idx, score) where: + - head_idx: Index into the sample's spans list for the head entity + - relation_label: String name of the relation type + - tail_idx: Index into the sample's spans list for the tail entity + - score: Confidence score for this relation (float, 0-1 range) + """ + relations = [[] for _ in range(batch_size)] + + # Check if relation outputs are available + if rel_idx is None or rel_logits is None: + return relations + + # Get or create relation mask + if rel_mask is None: + rel_mask = torch.ones(rel_idx[..., 0].shape, dtype=torch.bool, device=rel_idx.device) + + rel_probs = torch.sigmoid(rel_logits) + + # Decode relations for each sample + for i in range(batch_size): + rel_id_to_class_i = rel_id_to_classes[i] if isinstance(rel_id_to_classes, list) else rel_id_to_classes + + # Process each potential relation + for j in range(rel_idx.size(1)): + # Skip if masked out + if not rel_mask[i, j]: + continue + + head_idx = rel_idx[i, j, 0].item() + tail_idx = rel_idx[i, j, 1].item() + + # Skip invalid indices + if head_idx < 0 or tail_idx < 0: + continue + + # Skip if either span was removed by greedy search + if head_idx >= len(spans[i]) or tail_idx >= len(spans[i]): + continue + + # Check each relation class + for c, p in enumerate(rel_probs[i, j]): + prob = p.item() + + # Skip low confidence predictions + if prob <= threshold: + continue + + # Skip if class ID not in mapping (c + 1 because 0 is padding) + if (c + 1) not in rel_id_to_class_i: + continue + + rel_label = rel_id_to_class_i[c + 1] + relations[i].append((head_idx, rel_label, tail_idx, prob)) + + return relations + + def decode( + self, + tokens: List[List[str]], + id_to_classes: Union[Dict[int, str], List[Dict[int, str]]], + model_output: torch.Tensor, + rel_idx: Optional[torch.Tensor] = None, + rel_logits: Optional[torch.Tensor] = 
None, + rel_mask: Optional[torch.Tensor] = None, + flat_ner: bool = False, + threshold: float = 0.5, + relation_threshold: float = 0.5, + multi_label: bool = False, + rel_id_to_classes: Optional[Union[Dict[int, str], List[Dict[int, str]]]] = None, + **kwargs, + ) -> Tuple[List[List[tuple]], List[List[tuple]]]: + """Decode model output to extract entities and relations. + + Main decoding method that extracts both entity spans and relations from + model outputs. First decodes entity spans using BIO-style token tagging, + then decodes relations between the detected entities. + + Args: + tokens: Tokenized input text for each sample in the batch. + Each sample is a list of token strings. + id_to_classes: Mapping from entity class IDs to entity type names. + Can be either: + - Dict: Single mapping used for all samples (global entity schema) + - List[Dict]: Per-sample mappings for different entity schemas + Class IDs are 1-indexed (0 is reserved for padding). + model_output: Model output tensor with shape (B, L, C, 3) where the last + dimension contains [start, end, inside] predictions for BIO tagging. + rel_idx: Optional tensor of shape (batch_size, num_relations, 2) containing + head and tail entity indices for each potential relation. + rel_logits: Optional tensor of shape (batch_size, num_relations, num_relation_classes) + containing relation classification logits. + rel_mask: Optional boolean tensor of shape (batch_size, num_relations) + indicating valid relations. If None, all relations are considered valid. + flat_ner: If True, applies greedy filtering to ensure non-overlapping + entity spans. If False, allows overlapping entities. Defaults to False. + threshold: Minimum confidence score (0-1) for entity predictions + to be included in the output. Defaults to 0.5. + relation_threshold: Minimum confidence score (0-1) for relation + predictions to be included in the output. Defaults to 0.5. + multi_label: If True, allows multiple entity types per span. 
If False, + only the highest-scoring entity type per span is kept. Defaults to False. + rel_id_to_classes: Optional mapping from relation class IDs to relation names. + If None, relation decoding is skipped and empty relation lists are returned. + Can be either a single Dict or List[Dict] for per-sample mappings. + Class IDs are 1-indexed. + **kwargs: Additional keyword arguments passed to the parent class decode method. + + Returns: + Tuple of (spans, relations) where: + - spans: List of entity span lists, one per sample. Each entity span is + a tuple: (start, end, entity_type, score) + - relations: List of relation lists, one per sample. Each relation is + a tuple: (head_idx, relation_label, tail_idx, score) where head_idx + and tail_idx are indices into the corresponding sample's spans list. + + Examples: + >>> decoder = TokenRelexDecoder(config) + >>> tokens = [["John", "works", "at", "Microsoft"]] + >>> id_to_classes = {1: "PERSON", 2: "ORG"} + >>> rel_id_to_classes = {1: "works_at"} + >>> spans, relations = decoder.decode( + ... tokens=tokens, + ... id_to_classes=id_to_classes, + ... model_output=output, + ... rel_id_to_classes=rel_id_to_classes, + ... threshold=0.5, + ... 
) + >>> # spans[0] might be: [(0, 0, "PERSON", 0.9), (3, 3, "ORG", 0.85)] + >>> # relations[0] might be: [(0, "works_at", 1, 0.8)] + """ + # Decode entity spans using parent class BIO-style logic + spans = super().decode( + tokens=tokens, + id_to_classes=id_to_classes, + model_output=model_output, + flat_ner=flat_ner, + threshold=threshold, + multi_label=multi_label, + **kwargs, + ) + + # Decode relations if requested + relations = [[] for _ in range(len(tokens))] + + if rel_id_to_classes is not None: + relations = self._decode_relations( + spans=spans, + rel_idx=rel_idx, + rel_logits=rel_logits, + rel_mask=rel_mask, + rel_id_to_classes=rel_id_to_classes, + threshold=relation_threshold, + batch_size=len(tokens), + ) + + return spans, relations \ No newline at end of file diff --git a/gliner/model.py b/gliner/model.py index d2b1ac2..ad4becf 100644 --- a/gliner/model.py +++ b/gliner/model.py @@ -31,9 +31,10 @@ UniEncoderSpanConfig, UniEncoderTokenConfig, UniEncoderSpanRelexConfig, + UniEncoderTokenRelexConfig, UniEncoderSpanDecoderConfig, ) -from .decoding import SpanDecoder, TokenDecoder, SpanRelexDecoder, SpanGenerativeDecoder +from .decoding import SpanDecoder, TokenDecoder, SpanRelexDecoder, SpanGenerativeDecoder, TokenRelexDecoder from .training import Trainer, TrainingArguments from .evaluation import BaseNEREvaluator, BaseRelexEvaluator from .onnx.model import ( @@ -43,6 +44,7 @@ UniEncoderSpanORTModel, UniEncoderTokenORTModel, UniEncoderSpanRelexORTModel, + UniEncoderTokenRelexORTModel ) from .decoding.trie import LabelsTrie from .infer_packing import InferencePackingConfig @@ -53,6 +55,7 @@ UniEncoderSpanModel, UniEncoderTokenModel, UniEncoderSpanRelexModel, + UniEncoderTokenRelexModel, UniEncoderSpanDecoderModel, ) from .data_processing import ( @@ -63,6 +66,7 @@ UniEncoderTokenProcessor, UniEncoderSpanDecoderProcessor, RelationExtractionSpanProcessor, + RelationExtractionTokenProcessor ) from .data_processing.collator import ( BiEncoderSpanDataCollator, @@ 
-71,6 +75,7 @@ UniEncoderTokenDataCollator, UniEncoderSpanDecoderDataCollator, RelationExtractionSpanDataCollator, + RelationExtractionTokenDataCollator ) from .data_processing.tokenizer import WordsSplitter @@ -2101,8 +2106,7 @@ def export_to_onnx( "2. Use PyTorch for inference with this model\n" "3. Implement a custom ONNX pipeline with separate encoder/decoder exports" ) - - + class UniEncoderSpanRelexGLiNER(BaseEncoderGLiNER): """GLiNER model for both entity recognition and relation extraction. @@ -2594,7 +2598,88 @@ def forward( return UniEncoderSpanRelexWrapper(core_model) +class UniEncoderTokenRelexGLiNER(UniEncoderSpanRelexGLiNER): + """GLiNER model for both entity recognition and relation extraction. + + Performs joint entity and relation prediction, allowing the model to simultaneously + detect entities and the relationships between them in a single forward pass. + """ + + config_class = UniEncoderTokenRelexConfig + model_class = UniEncoderTokenRelexModel + ort_model_class: type = UniEncoderTokenRelexORTModel + data_processor_class = RelationExtractionTokenProcessor + data_collator_class = RelationExtractionTokenDataCollator + decoder_class = TokenRelexDecoder + + def _get_onnx_input_spec(self) -> dict[str, Any]: + """Define ONNX input specification for UniEncoderSpanRelex model.""" + return { + "input_names": [ + "input_ids", + "attention_mask", + "words_mask", + "text_lengths", + ], + "output_names": ["logits", "rel_idx", "rel_logits", "rel_mask"], + "dynamic_axes": { + "input_ids": {0: "batch_size", 1: "sequence_length"}, + "attention_mask": {0: "batch_size", 1: "sequence_length"}, + "words_mask": {0: "batch_size", 1: "sequence_length"}, + "text_lengths": {0: "batch_size", 1: "value"}, + "logits": { + 0: "batch_size", + 1: "sequence_length", + 2: "num_ent_classes", + 3: "num_idx_classes", + }, + "rel_idx": { + 0: "batch_size", + 1: "num_pairs", + 2: "pair_index", + }, + "rel_logits": { + 0: "batch_size", + 1: "num_pairs", + 2: "num_rel_classes", + }, + 
"rel_mask": { + 0: "batch_size", + 1: "num_pairs", + }, + }, + } + + def _get_onnx_export_kwargs(self) -> dict[str, Any]: + """Provide default labels for relation extraction ONNX export.""" + return {"labels": ["head", "tail"]} + def _create_onnx_wrapper(self, core_model: nn.Module) -> nn.Module: + """Create wrapper for UniEncoderSpanRelex ONNX export.""" + + class UniEncoderTokenRelexWrapper(nn.Module): + def __init__(self, core): + super().__init__() + self.core = core + + def forward( + self, + input_ids, + attention_mask, + words_mask, + text_lengths, + ): + out = self.core( + input_ids=input_ids, + attention_mask=attention_mask, + words_mask=words_mask, + text_lengths=text_lengths, + ) + # Return all outputs for relation extraction + return out.logits, out.rel_idx, out.rel_logits, out.rel_mask + + return UniEncoderTokenRelexWrapper(core_model) + class GLiNER(nn.Module, PyTorchModelHubMixin): """Meta GLiNER class that automatically instantiates the appropriate GLiNER variant. @@ -2687,7 +2772,10 @@ def _get_gliner_class(config: GLiNERConfig): # Priority order: relations > decoder > bi-encoder > token vs span if has_relations: - return UniEncoderSpanRelexGLiNER + if is_token_level: + return UniEncoderTokenRelexGLiNER + else: + return UniEncoderSpanRelexGLiNER if has_labels_decoder: if has_labels_encoder: @@ -2956,6 +3044,11 @@ def model_map(self) -> dict[str, dict[str, Any]]: "description": "Joint entity and relation extraction with single encoder", "config": {"span_mode": "span_level", "labels_encoder": None, "relations_layer": "required"}, }, + "gliner_uni_encoder_token_relex": { + "class": UniEncoderTokenRelexGLiNER, + "description": "Joint entity and relation extraction with single encoder using token-level architecture", + "config": {"span_mode": "token_level", "labels_encoder": None, "relations_layer": "required"}, + }, } def get_model_type(self) -> str: @@ -2974,6 +3067,7 @@ def get_model_type(self) -> str: "BiEncoderTokenGLiNER": 
"gliner_bi_encoder_token", "UniEncoderSpanDecoderGLiNER": "gliner_uni_encoder_span_decoder", "UniEncoderSpanRelexGLiNER": "gliner_uni_encoder_span_relex", + "UniEncoderTokenRelexGLiNER": "gliner_uni_encoder_token_relex" } return type_mapping.get(class_name, "unknown") diff --git a/gliner/modeling/base.py b/gliner/modeling/base.py index b194a4c..b5d57a7 100644 --- a/gliner/modeling/base.py +++ b/gliner/modeling/base.py @@ -560,6 +560,7 @@ def forward( prompts_embedding, prompts_embedding_mask, target_C ) + # Shape: (batch_size, seq_len, num_classes, 3), 3 - start, end, inside scores = self.scorer(words_embedding, prompts_embedding) loss = None @@ -1628,16 +1629,17 @@ def select_span_target_embedding( - target_rep: Selected span representations of shape (B, E, D). - target_mask: Mask for selected spans of shape (B, E). """ - B, L, K, D = span_rep.shape + B = span_rep.size(0) + D = span_rep.size(-1) - span_rep_flat = span_rep.view(B, L * K, D) - span_mask_flat = span_mask.view(B, L * K) + span_rep_flat = span_rep.view(B, -1, D) + span_mask_flat = span_mask.view(B, -1) if span_labels is not None: - span_prob_flat = span_labels.max(dim=-1).values.view(B, L * K) + span_prob_flat = span_labels.max(dim=-1).values.view(B, -1) keep = (span_prob_flat == 1).bool() else: - span_prob_flat = torch.sigmoid(span_scores).max(dim=-1).values.view(B, L * K) + span_prob_flat = torch.sigmoid(span_scores).max(dim=-1).values.view(B, -1) keep = (span_prob_flat > threshold) & span_mask_flat.bool() if top_k is not None and top_k > 0: @@ -1688,6 +1690,26 @@ def select_target_embedding( return target_rep, target_mask + + def represent_spans(self, words_embeddings, words_mask, prompts_embeddings, + span_idx: Optional[torch.Tensor]=None, + span_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + threshold: float = 0.5, + ): + + span_idx = span_idx * span_mask.unsqueeze(-1) + span_rep = self.span_rep_layer(words_embeddings, span_idx) + scores = 
torch.einsum("BLKD,BCD->BLKC", span_rep, prompts_embeddings) + + if hasattr(self, "relations_rep_layer"): + target_span_rep, target_span_mask = self.select_span_target_embedding( + span_rep, scores, span_mask, labels, threshold + ) + else: + target_span_rep, target_span_mask = None, None + return scores, target_span_rep, target_span_mask + def forward( self, input_ids: Optional[torch.FloatTensor] = None, @@ -1746,30 +1768,29 @@ def forward( target_W = span_idx.size(1) // self.config.max_width words_embedding, mask = self._fit_length(words_embedding, mask, target_W) - span_idx = span_idx * span_mask.unsqueeze(-1) - - span_rep = self.span_rep_layer(words_embedding, span_idx) - - target_C = prompts_embedding.size(1) - if labels is not None: - target_C = max(target_C, labels.size(-1)) - prompts_embedding, prompts_embedding_mask = self._fit_length( prompts_embedding, prompts_embedding_mask, target_C ) prompts_embedding = self.prompt_rep_layer(prompts_embedding) batch_size, _, embed_dim = prompts_embedding.shape - scores = torch.einsum("BLKD,BCD->BLKC", span_rep, prompts_embedding) + + scores, target_span_rep, target_span_mask = self.represent_spans(words_embedding, mask, + prompts_embedding, + span_idx, + span_mask, + labels, + threshold + ) + target_C = prompts_embedding.size(1) + if labels is not None: + target_C = max(target_C, labels.size(-1)) pair_idx, pair_mask, pair_scores = None, None, None rel_prompts_embedding_mask = None pred_adj_matrix = None if hasattr(self, "relations_rep_layer"): - target_span_rep, target_span_mask = self.select_span_target_embedding( - span_rep, scores, span_mask, labels, threshold - ) pred_adj_matrix = self.relations_rep_layer(target_span_rep, target_span_mask) rel_prompts_embedding, rel_prompts_embedding_mask = extract_prompt_features( @@ -1965,3 +1986,134 @@ def rel_loss( loss = masked_loss.sum() return loss + + +class UniEncoderTokenRelexModel(UniEncoderSpanRelexModel): + """Token-level NER model with relation extraction 
capabilities. + + This model extends token-based NER to also extract relations between + identified entities, predicting both entity types and relation types + in a joint model. + + Attributes: + relations_rep_layer (Optional[RelationsRepLayer]): Layer for computing + pairwise entity relations (adjacency matrix). + triples_score_layer (Optional[TriplesScoreLayer]): Layer for scoring + (head, relation, tail) triples. + pair_rep_layer (Optional[nn.Module]): Alternative layer for relation + scoring via concatenation. + """ + + def __init__( + self, config: Any, from_pretrained: bool = False, cache_dir: Optional[Union[str, Path]] = None + ) -> None: + """Initialize the span-based relation extraction model. + + Args: + config: Model configuration object. + from_pretrained: Whether to load from pretrained weights. + cache_dir: Directory for caching pretrained models. + """ + super().__init__(config, from_pretrained, cache_dir) + self.scorer = Scorer(config.hidden_size, config.dropout) + + + def extract_spans( + self, + scores: torch.Tensor, + labels: Optional[torch.Tensor] = None, + threshold: float = 0.5, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Extract entity spans from BIO-style token predictions. 
+ + Args: + scores: (B, W, C, 3) - logits for [start, end, inside] + labels: Optional (B, W, C, 3) - ground truth labels + threshold: Confidence threshold (used when labels is None) + + Returns: + span_idx: (B, N, 2) - [start, end] indices, padded + span_mask: (B, N) - validity mask + """ + B, W, C, _ = scores.shape + device = scores.device + + if labels is not None: + start_mask = labels[..., 0] > 0.5 + end_mask = labels[..., 1] > 0.5 + inside_mask = labels[..., 2] > 0.5 + else: + probs = torch.sigmoid(scores) + start_mask = probs[..., 0] > threshold + end_mask = probs[..., 1] > threshold + inside_mask = probs[..., 2] > threshold + + # Prepend zeros for cumsum indexing + inside_cumsum = torch.nn.functional.pad( + inside_mask.long().cumsum(dim=1), (0, 0, 1, 0) + ) # (B, W+1, C) + + spans_per_sample = [] + + for b in range(B): + starts = start_mask[b].nonzero(as_tuple=False) + ends = end_mask[b].nonzero(as_tuple=False) + + if starts.size(0) == 0 or ends.size(0) == 0: + spans_per_sample.append(torch.empty(0, 2, dtype=torch.long, device=device)) + continue + + s_pos, s_cls = starts.T + e_pos, e_cls = ends.T + + # Find valid (start, end) pairs: same class & end >= start + valid = (s_cls[:, None] == e_cls) & (s_pos[:, None] <= e_pos) + si, ei = valid.nonzero(as_tuple=True) + + if si.size(0) == 0: + spans_per_sample.append(torch.empty(0, 2, dtype=torch.long, device=device)) + continue + + cs, ce, cc = s_pos[si], e_pos[ei], s_cls[si] + + # Validate: all inside positions must be marked + inside_cnt = inside_cumsum[b, ce + 1, cc] - inside_cumsum[b, cs, cc] + valid = inside_cnt == (ce - cs + 1) + + cs, ce = cs[valid], ce[valid] + + if cs.size(0) == 0: + spans_per_sample.append(torch.empty(0, 2, dtype=torch.long, device=device)) + else: + spans_per_sample.append(torch.stack([cs, ce], dim=1)) + + # Pad to uniform size + max_spans = max(s.size(0) for s in spans_per_sample) if spans_per_sample else 0 + max_spans = max(max_spans, 1) # Ensure at least 1 to avoid empty tensor 
issues + + span_idx = torch.zeros(B, max_spans, 2, dtype=torch.long, device=device) + span_mask = torch.zeros(B, max_spans, dtype=torch.bool, device=device) + + for b, spans in enumerate(spans_per_sample): + n = spans.size(0) + if n > 0: + span_idx[b, :n] = spans + span_mask[b, :n] = True + + return span_idx, span_mask + + + def represent_spans(self, words_embeddings, words_mask, prompts_embeddings, + span_idx: Optional[torch.Tensor]=None, + span_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + threshold: float = 0.5, + ): + scores = self.scorer(words_embeddings, prompts_embeddings) + + span_idx, target_span_mask = self.extract_spans(scores, labels, threshold) + span_idx = span_idx * target_span_mask.unsqueeze(-1) + target_span_rep = self.span_rep_layer(words_embeddings, span_idx) + + return scores, target_span_rep, target_span_mask \ No newline at end of file diff --git a/gliner/onnx/model.py b/gliner/onnx/model.py index eb4e933..66f5085 100644 --- a/gliner/onnx/model.py +++ b/gliner/onnx/model.py @@ -369,3 +369,51 @@ def forward( rel_mask=inference_output["rel_mask"], ) return outputs + +class UniEncoderTokenRelexORTModel(BaseORTModel): + """ONNX Runtime model for uni-encoder token-level relation extraction. + + Uses a single encoder to process text and perform both entity recognition + and relation extraction at the token level. + """ + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + words_mask: torch.Tensor, + text_lengths: torch.Tensor, + **kwargs, + ) -> Dict[str, Any]: + """Forward pass for span relation extraction model using ONNX inference. + + Args: + input_ids: Tensor of shape (batch_size, seq_len) containing input token IDs. + attention_mask: Tensor of shape (batch_size, seq_len) with 1s for real + tokens and 0s for padding. + words_mask: Tensor of shape (batch_size, seq_len) indicating word boundaries. 
+ text_lengths: Tensor of shape (batch_size,) containing the actual length + of each text sequence. + span_idx: Tensor containing indices of spans to classify. + span_mask: Tensor indicating which spans are valid (not padding). + **kwargs: Additional arguments (ignored). + + Returns: + GLiNERRelexOutput containing logits for span classification, relation + indices, relation logits, and relation mask. + """ + inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "words_mask": words_mask, + "text_lengths": text_lengths, + } + prepared_inputs = self.prepare_inputs(inputs) + inference_output = self.run_inference(prepared_inputs) + outputs = GLiNERRelexOutput( + logits=inference_output["logits"], + rel_idx=inference_output["rel_idx"], + rel_logits=inference_output["rel_logits"], + rel_mask=inference_output["rel_mask"], + ) + return outputs \ No newline at end of file From 90715a0ca1c646bc81c7df428aae189dfa481d13 Mon Sep 17 00:00:00 2001 From: Ingvar Date: Wed, 7 Jan 2026 13:08:21 +0200 Subject: [PATCH 2/8] fix data processing for token-level architectures --- gliner/data_processing/processor.py | 100 +++++++++++++--------------- gliner/data_processing/utils.py | 4 +- gliner/model.py | 2 +- gliner/modeling/base.py | 91 ++++++++++++++++++++----- gliner/modeling/span_rep.py | 49 ++++++++++++++ 5 files changed, 173 insertions(+), 73 deletions(-) diff --git a/gliner/data_processing/processor.py b/gliner/data_processing/processor.py index 9ce4d12..1ce68c4 100644 --- a/gliner/data_processing/processor.py +++ b/gliner/data_processing/processor.py @@ -429,7 +429,7 @@ class UniEncoderSpanProcessor(BaseProcessor): predict entity types for all possible spans up to a maximum width. 
""" - def prepare_span_labels(self, ner, classes_to_id, num_tokens): + def prepare_span_labels(self, ner, classes_to_id, num_tokens, spans_idx): dict_lab = self.get_dict(ner, classes_to_id) if ner else defaultdict(int) span_label = torch.LongTensor([dict_lab[i] for i in spans_idx]) spans_idx = torch.LongTensor(spans_idx) @@ -470,7 +470,7 @@ def preprocess_example(self, tokens, ner, classes_to_id): num_tokens = len(tokens) spans_idx = prepare_span_idx(num_tokens, max_width) - span_label, spans_idx = self.prepare_span_labels(ner, classes_to_id, num_tokens) + span_label, spans_idx = self.prepare_span_labels(ner, classes_to_id, num_tokens, spans_idx) return { "tokens": tokens, @@ -654,27 +654,29 @@ def create_batch_dict(self, batch, class_to_ids, id_to_classes): return batch_dict - def create_labels(self, entities_id, batch_size, seq_len, num_classes): + def create_labels(self, batch): """Create token-level labels with begin/inside/end markers. Creates labels indicating which tokens are at the start, end, or inside of entity spans for each entity type. Args: - entities_id: List of entity annotations with class IDs for each example. - batch_size: Size of the batch. - seq_len: Maximum sequence length in batch. - num_classes: Number of entity classes. + batch: List[Any] batch of data Returns: Tensor of shape (batch_size, seq_len, num_classes, 3) where the last dimension contains [start_marker, end_marker, inside_marker]. 
""" + batch_size = len(batch["tokens"]) + seq_len = batch["seq_length"].max().item() + num_classes = max([len(cid) for cid in batch["classes_to_id"]]) + word_labels = torch.zeros(batch_size, seq_len, num_classes, 3, dtype=torch.float) - for i, sentence_entities in enumerate(entities_id): + for i, sentence_entities in enumerate(batch['entities']): for st, ed, sp_label in sentence_entities: - class_idx = sp_label - 1 # Convert to 0-indexed + lbl = batch['classes_to_id'][i][sp_label] + class_idx = lbl - 1 # Convert to 0-indexed # skip entities that point beyond sequence length if st >= seq_len or ed >= seq_len: @@ -698,14 +700,9 @@ def tokenize_and_prepare_labels(self, batch, prepare_labels, *args, **kwargs): Returns: Dictionary containing tokenized inputs and optionally labels. """ - batch_size = len(batch["tokens"]) - seq_len = batch["seq_length"].max() - num_classes = max([len(cid) for cid in batch["classes_to_id"]]) - tokenized_input = self.tokenize_inputs(batch["tokens"], batch["classes_to_id"]) - if prepare_labels: - labels = self.create_labels(batch["entities_id"], batch_size, seq_len, num_classes) + labels = self.create_labels(batch) tokenized_input["labels"] = labels return tokenized_input @@ -866,14 +863,11 @@ def tokenize_and_prepare_labels(self, batch, prepare_labels, prepare_entities=Tr entities = list(batch["classes_to_id"][0]) else: entities = None - batch_size = len(batch["tokens"]) - seq_len = batch["seq_length"].max() - num_classes = len(entities) tokenized_input = self.tokenize_inputs(batch["tokens"], entities) if prepare_labels: - labels = self.create_labels(batch["entities_id"], batch_size, seq_len, num_classes) + labels = self.create_labels(batch) tokenized_input["labels"] = labels return tokenized_input @@ -1294,7 +1288,7 @@ def preprocess_example(self, tokens, ner, classes_to_id, relations, rel_classes_ ner, relations = self.sort_entities_and_relations(ner, relations) # Process entity labels - span_label, spans_idx = 
self.prepare_span_labels(ner, classes_to_id, num_tokens) + span_label, spans_idx = self.prepare_span_labels(ner, classes_to_id, num_tokens, spans_idx) # Create entity span to index mapping span_to_idx = {(spans_idx[i, 0].item(), spans_idx[i, 1].item()): i for i in range(len(spans_idx))} @@ -1333,7 +1327,6 @@ def preprocess_example(self, tokens, ner, classes_to_id, relations, rel_classes_ "span_label": span_label, "seq_length": num_tokens, "entities": ner, - "entities_id": entity_to_span_idx, "relations": relations, "rel_idx": rel_idx, "rel_label": rel_label, @@ -1357,8 +1350,6 @@ def create_batch_dict(self, batch, class_to_ids, id_to_classes, rel_class_to_ids entities = [el["entities"] for el in batch] relations = [el["relations"] for el in batch] - entities_id = [el["entities_id"] for el in batch] - span_idx = pad_sequence([b["span_idx"] for b in batch], batch_first=True, padding_value=0) span_label = pad_sequence([el["span_label"] for el in batch], batch_first=True, padding_value=-1) rel_idx = pad_sequence([el["rel_idx"] for el in batch], batch_first=True, padding_value=0) @@ -1374,7 +1365,6 @@ def create_batch_dict(self, batch, class_to_ids, id_to_classes, rel_class_to_ids "span_mask": span_mask, "span_label": span_label, "entities": entities, - "entities_id": entities_id, "relations": relations, "rel_idx": rel_idx, "rel_label": rel_label, @@ -1403,10 +1393,10 @@ def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_ - rel_matrix: Multi-hot relation labels (shape: [B, max_pairs, num_relation_classes]) """ B = len(batch["tokens"]) - entities_id = batch["entities_id"] + span_mask = batch["span_mask"] # Count entities per sample (differs from span-based which uses span_label) - batch_ents = torch.LongTensor([len(ent_list) for ent_list in entities_id]) + batch_ents = span_mask.long().squeeze(-1).sum(-1) max_En = max(batch_ents.max().item(), 1) rel_class_to_ids = batch["rel_class_to_ids"] @@ -1456,9 +1446,6 @@ def 
create_relation_labels(self, batch, add_reversed_negatives=True, add_random_ negative_pairs.add(reversed_pair) if add_random_negatives and N > 1 and len(negative_pairs) < target_negatives: - ent_id_list = entities_id[i] - entity_positions = [(ent[0], ent[1]) for ent in ent_id_list] if ent_id_list else [] - attempts = 0 max_attempts = target_negatives * 10 @@ -1474,13 +1461,6 @@ def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_ if pair in positive_pairs or pair in negative_pairs: continue - if entity_positions and len(entity_positions) > max(e1, e2): - pos1 = entity_positions[e1] - pos2 = entity_positions[e2] - distance = abs(pos1[0] - pos2[1]) - if distance > 10 and random.random() < 0.5: - continue - negative_pairs.add(pair) all_pairs = sorted(list(positive_pairs) + list(negative_pairs)) @@ -1676,25 +1656,23 @@ def preprocess_example(self, tokens, ner, classes_to_id, relations=None, rel_cla ner, relations = self.sort_entities_and_relations(ner, relations) - # Generate entity IDs for token-level labeling (filter by valid class and position) - try: - entities_id = [ - [start, end, classes_to_id[label]] - for start, end, label in ner - if label in classes_to_id and end < num_tokens - ] - except TypeError: - entities_id = [] - # Create entity index mapping (from sorted entity list index to entities_id index) entity_idx_mapping = {} valid_entity_idx = 0 + if ner is not None: + span_idx_list = [] for ent_idx, (start, end, label) in enumerate(ner): if label in classes_to_id and end < num_tokens: + span_idx_list.append([start, end]) entity_idx_mapping[ent_idx] = valid_entity_idx valid_entity_idx += 1 - + if span_idx_list: + span_idx = torch.LongTensor(span_idx_list) + else: + span_idx = torch.zeros(0, 2, dtype=torch.long) + else: + span_idx = None # Process relations rel_idx_list = [] rel_label_list = [] @@ -1720,11 +1698,12 @@ def preprocess_example(self, tokens, ner, classes_to_id, relations=None, rel_cla rel_idx = torch.zeros(0, 2, 
dtype=torch.long) rel_label = torch.zeros(0, dtype=torch.long) + return { "tokens": tokens, "seq_length": num_tokens, "entities": ner, - "entities_id": entities_id, + "span_idx": span_idx, "relations": relations, "rel_idx": rel_idx, "rel_label": rel_label, @@ -1747,9 +1726,24 @@ def create_batch_dict(self, batch, class_to_ids, id_to_classes, rel_class_to_ids tokens = [el["tokens"] for el in batch] seq_length = torch.LongTensor([el["seq_length"] for el in batch]).unsqueeze(-1) entities = [el["entities"] for el in batch] - entities_id = [el["entities_id"] for el in batch] relations = [el["relations"] for el in batch] + if batch[0]['span_idx'] is not None: + span_idx_list = [el["span_idx"] for el in batch] + + batch_size = len(span_idx_list) + span_counts = [s.size(0) if s.numel() > 0 else 0 for s in span_idx_list] + max_spans = max(max(span_counts), 1) # Ensure at least 1 + + span_mask = torch.zeros(batch_size, max_spans, dtype=torch.bool) + for i, count in enumerate(span_counts): + if count > 0: + span_mask[i, :count] = True + + span_idx = pad_2d_tensor(span_idx_list, padding_value=0) + else: + span_idx, span_mask = None, None + rel_idx = pad_sequence([el["rel_idx"] for el in batch], batch_first=True, padding_value=0) rel_label = pad_sequence([el["rel_label"] for el in batch], batch_first=True, padding_value=0) @@ -1757,7 +1751,8 @@ def create_batch_dict(self, batch, class_to_ids, id_to_classes, rel_class_to_ids "tokens": tokens, "seq_length": seq_length, "entities": entities, - "entities_id": entities_id, + "span_idx": span_idx, + "span_mask": span_mask, "relations": relations, "rel_idx": rel_idx, "rel_label": rel_label, @@ -1781,9 +1776,6 @@ def tokenize_and_prepare_labels(self, batch, prepare_labels, *args, **kwargs): Dictionary containing tokenized inputs, token-level entity labels, relation adjacency matrix, and relation labels. 
""" - batch_size = len(batch["tokens"]) - seq_len = batch["seq_length"].max().item() - num_classes = max([len(cid) for cid in batch["classes_to_id"]]) # Use relation-aware tokenize_inputs from RelationExtractionSpanProcessor tokenized_input = self.tokenize_inputs( @@ -1795,7 +1787,7 @@ def tokenize_and_prepare_labels(self, batch, prepare_labels, *args, **kwargs): if prepare_labels: # Create token-level BIO labels (from UniEncoderTokenProcessor) - labels = self.create_labels(batch["entities_id"], batch_size, seq_len, num_classes) + labels = self.create_labels(batch) tokenized_input["labels"] = labels # Create relation labels (overridden method) diff --git a/gliner/data_processing/utils.py b/gliner/data_processing/utils.py index 460beca..1c9e19f 100644 --- a/gliner/data_processing/utils.py +++ b/gliner/data_processing/utils.py @@ -4,7 +4,7 @@ import torch -def pad_2d_tensor(key_data): +def pad_2d_tensor(key_data, padding_value=0.0): """Pad a list of 2D tensors to uniform dimensions. Takes a list of 2D tensors with potentially different shapes and pads them @@ -44,7 +44,7 @@ def pad_2d_tensor(key_data): col_padding = max_cols - cols # Pad the tensor along both dimensions - padded_tensor = torch.nn.functional.pad(tensor, (0, col_padding, 0, row_padding), mode="constant", value=0) + padded_tensor = torch.nn.functional.pad(tensor, (0, col_padding, 0, row_padding), mode="constant", value=padding_value) tensors.append(padded_tensor) # Stack the tensors into a single tensor along a new batch dimension diff --git a/gliner/model.py b/gliner/model.py index ad4becf..a98e54f 100644 --- a/gliner/model.py +++ b/gliner/model.py @@ -992,7 +992,7 @@ def create_training_args( save_total_limit: int = 10, logging_steps: int = 10, use_cpu: bool = False, - bf16: bool = True, + bf16: bool = False, dataloader_num_workers: int = 1, report_to: str = "none", **kwargs, diff --git a/gliner/modeling/base.py b/gliner/modeling/base.py index b5d57a7..a1a6e94 100644 --- a/gliner/modeling/base.py +++ 
b/gliner/modeling/base.py @@ -426,7 +426,7 @@ def loss( scores: torch.Tensor, labels: torch.Tensor, prompts_embedding_mask: torch.Tensor, - mask_label: torch.Tensor, + span_mask: torch.Tensor, alpha: float = -1.0, gamma: float = 0.0, prob_margin: float = 0.0, @@ -471,9 +471,9 @@ def loss( masked_loss = all_losses.view(batch_size, -1, num_classes) * prompts_embedding_mask.unsqueeze(1) all_losses = masked_loss.view(-1, num_classes) - mask_label = mask_label.view(-1, 1) + span_mask = span_mask.view(-1, 1) - all_losses = all_losses * mask_label.float() + all_losses = all_losses * span_mask.float() if reduction == "mean": loss = all_losses.mean() @@ -582,7 +582,7 @@ def loss( scores: torch.Tensor, labels: torch.Tensor, prompts_embedding_mask: torch.Tensor, - mask: torch.Tensor, + word_mask: torch.Tensor, alpha: float = -1.0, gamma: float = 0.0, prob_margin: float = 0.0, @@ -611,7 +611,7 @@ def loss( """ all_losses = self._loss(scores, labels, alpha, gamma, prob_margin, label_smoothing, negatives) - all_losses = all_losses * (mask.unsqueeze(-1) * prompts_embedding_mask.unsqueeze(1)).unsqueeze(-1) + all_losses = all_losses * (word_mask.unsqueeze(-1) * prompts_embedding_mask.unsqueeze(1)).unsqueeze(-1) if reduction == "mean": loss = all_losses.mean() @@ -1698,7 +1698,7 @@ def represent_spans(self, words_embeddings, words_mask, prompts_embeddings, threshold: float = 0.5, ): - span_idx = span_idx * span_mask.unsqueeze(-1) + span_idx = span_idx * span_mask.unsqueeze(-1).long() span_rep = self.span_rep_layer(words_embeddings, span_idx) scores = torch.einsum("BLKD,BCD->BLKC", span_rep, prompts_embeddings) @@ -1761,11 +1761,23 @@ def forward( token_embeds, input_ids, attention_mask, text_lengths, words_mask ) ) - + if hasattr(self, "rnn"): words_embedding = self.rnn(words_embedding, mask) - target_W = span_idx.size(1) // self.config.max_width + if self.config.span_mode=='token_level': + if labels is not None: + target_W = labels.shape[1] + target_C = 
max(prompts_embedding.size(1), labels.size(-2)) + else: + target_W = words_embedding.size(1) + target_C = prompts_embedding.size(1) + else: + target_W = span_idx.size(1) // self.config.max_width + target_C = prompts_embedding.size(1) + if labels is not None: + target_C = max(target_C, labels.size(-1)) + words_embedding, mask = self._fit_length(words_embedding, mask, target_W) prompts_embedding, prompts_embedding_mask = self._fit_length( @@ -1782,9 +1794,6 @@ def forward( labels, threshold ) - target_C = prompts_embedding.size(1) - if labels is not None: - target_C = max(target_C, labels.size(-1)) pair_idx, pair_mask, pair_scores = None, None, None rel_prompts_embedding_mask = None @@ -1833,7 +1842,7 @@ def forward( loss = None if labels is not None: - loss = self.loss(scores, labels, prompts_embedding_mask, span_mask, **kwargs) + loss = self.loss(scores, labels, prompts_embedding_mask, span_mask=span_mask, word_mask=mask, **kwargs) if adj_matrix is not None and rel_matrix is not None and hasattr(self, "relations_rep_layer"): adj_mask = target_span_mask.float().unsqueeze(1) * target_span_mask.float().unsqueeze(2) @@ -2104,6 +2113,55 @@ def extract_spans( return span_idx, span_mask + def loss( + self, + scores: torch.Tensor, + labels: torch.Tensor, + prompts_embedding_mask: torch.Tensor, + word_mask: torch.Tensor, + alpha: float = -1.0, + gamma: float = 0.0, + prob_margin: float = 0.0, + label_smoothing: float = 0.0, + reduction: str = "sum", + negatives: float = 1.0, + **kwargs: Any, + ) -> torch.Tensor: + """Compute token classification loss. + + Args: + scores: Predicted scores of shape (B, W, C). + labels: Ground truth labels of shape (B, W, C). + prompts_embedding_mask: Mask for valid entity types of shape (B, C). + mask: Mask for valid tokens of shape (B, W). + alpha: Focal loss alpha parameter. + gamma: Focal loss gamma parameter. + prob_margin: Margin for probability adjustment. + label_smoothing: Label smoothing factor. 
+ reduction: Loss reduction method ('sum' or 'mean'). + negatives: Negative sampling probability. + **kwargs: Additional arguments. + + Returns: + Scalar loss tensor. + """ + all_losses = self._loss(scores, labels, alpha, gamma, prob_margin, label_smoothing, negatives) + + all_losses = all_losses * (word_mask.unsqueeze(-1) * prompts_embedding_mask.unsqueeze(1)).unsqueeze(-1) + + if reduction == "mean": + loss = all_losses.mean() + elif reduction == "sum": + loss = all_losses.sum() + else: + warnings.warn( + f"Invalid Value for config 'loss_reduction': '{reduction}' \n Supported reduction modes:" + f" 'none', 'mean', 'sum'. It will be used 'sum' instead.", + stacklevel=2, + ) + loss = all_losses.sum() + return loss + def represent_spans(self, words_embeddings, words_mask, prompts_embeddings, span_idx: Optional[torch.Tensor]=None, span_mask: Optional[torch.Tensor] = None, @@ -2111,9 +2169,10 @@ def represent_spans(self, words_embeddings, words_mask, prompts_embeddings, threshold: float = 0.5, ): scores = self.scorer(words_embeddings, prompts_embeddings) - - span_idx, target_span_mask = self.extract_spans(scores, labels, threshold) - span_idx = span_idx * target_span_mask.unsqueeze(-1) + + if span_idx is None: + span_idx, span_mask = self.extract_spans(scores, labels, threshold) + span_idx = span_idx * span_mask.unsqueeze(-1).long() target_span_rep = self.span_rep_layer(words_embeddings, span_idx) - return scores, target_span_rep, target_span_mask \ No newline at end of file + return scores, target_span_rep, span_mask \ No newline at end of file diff --git a/gliner/modeling/span_rep.py b/gliner/modeling/span_rep.py index 677447e..7202aa8 100644 --- a/gliner/modeling/span_rep.py +++ b/gliner/modeling/span_rep.py @@ -633,7 +633,54 @@ def forward(self, x, *args): return out +class TokenMarker(nn.Module): + """Marks and projects span endpoints using an MLP. + + A cleaner version of SpanMarker using the create_projection_layer utility. 
+ + Attributes: + max_width (int): Maximum span width to represent. + project_start (nn.Module): MLP for projecting start positions. + project_end (nn.Module): MLP for projecting end positions. + out_project (nn.Module): Final projection layer. + """ + + def __init__(self, hidden_size: int, dropout: float = 0.4): + """Initialize the SpanMarkerV0 layer. + + Args: + hidden_size (int): Dimension of the hidden representations. + max_width (int): Maximum span width to represent. + dropout (float, optional): Dropout rate. Defaults to 0.4. + """ + super().__init__() + self.project_start = create_projection_layer(hidden_size, dropout) + self.project_end = create_projection_layer(hidden_size, dropout) + + self.out_project = create_projection_layer(hidden_size * 2, dropout, hidden_size) + + def forward(self, h: torch.Tensor, span_idx: torch.Tensor) -> torch.Tensor: + """Compute span representations using start and end markers. + + Args: + h (torch.Tensor): Token representations of shape [B, L, D]. + span_idx (torch.Tensor): Span indices of shape [B, *, 2]. + + Returns: + torch.Tensor: Span representations of shape [B, L, max_width, D]. + """ + B, L, D = h.size() + num_spans = span_idx.size(1) + start_rep = self.project_start(h) + end_rep = self.project_end(h) + + start_span_rep = extract_elements(start_rep, span_idx[:, :, 0]) + end_span_rep = extract_elements(end_rep, span_idx[:, :, 1]) + + cat = torch.cat([start_span_rep, end_span_rep], dim=-1).relu() + return self.out_project(cat) + class SpanRepLayer(nn.Module): """Factory class for various span representation approaches. 
@@ -691,6 +738,8 @@ def __init__(self, hidden_size, max_width, span_mode, **kwargs): self.span_rep_layer = SpanConv(hidden_size, max_width, span_mode="conv_sum") elif span_mode == "conv_share": self.span_rep_layer = ConvShare(hidden_size, max_width) + elif span_mode == 'token_level': + self.span_rep_layer = TokenMarker(hidden_size, **kwargs) else: raise ValueError(f"Unknown span mode {span_mode}") From 846a5cc48285014428cfb368df8266944e9e3977 Mon Sep 17 00:00:00 2001 From: Ingvar Date: Wed, 7 Jan 2026 18:10:36 +0200 Subject: [PATCH 3/8] implement post-span modeling for token-level architectures --- gliner/config.py | 22 ++- gliner/data_processing/processor.py | 234 +++++++++++++++++++++++----- gliner/decoding/decoder.py | 155 +++++++++++++++--- gliner/model.py | 6 +- gliner/modeling/base.py | 220 ++++++++------------------ gliner/modeling/outputs.py | 3 + gliner/modeling/utils.py | 86 +++++++++- 7 files changed, 512 insertions(+), 214 deletions(-) diff --git a/gliner/config.py b/gliner/config.py index d368d68..e25bc67 100644 --- a/gliner/config.py +++ b/gliner/config.py @@ -119,10 +119,19 @@ def __init__(self, **kwargs): class UniEncoderTokenConfig(UniEncoderConfig): """Configuration for uni-encoder token-based GLiNER model.""" - def __init__(self, **kwargs): + def __init__(self, + represent_spans: bool = False, + token_loss_coef=1.0, + span_loss_coef=1.0, + neg_spans_ratio=1.0, + **kwargs): super().__init__(**kwargs) self.span_mode = "token_level" self.model_type = "gliner_uni_encoder_token" + self.token_loss_coef = token_loss_coef + self.span_loss_coef = span_loss_coef + self.represent_spans = represent_spans + self.neg_spans_ratio=neg_spans_ratio class UniEncoderSpanDecoderConfig(UniEncoderConfig): @@ -268,10 +277,19 @@ def __init__(self, **kwargs): class BiEncoderTokenConfig(BiEncoderConfig): """Configuration for bi-encoder token-based GLiNER model.""" - def __init__(self, **kwargs): + def __init__(self, + represent_spans: bool = False, + token_loss_coef=1.0, + 
span_loss_coef=1.0, + neg_spans_ratio=1.0, + **kwargs): super().__init__(**kwargs) self.span_mode = "token_level" self.model_type = "gliner_bi_encoder_token" + self.token_loss_coef = token_loss_coef + self.span_loss_coef = span_loss_coef + self.represent_spans = represent_spans + self.neg_spans_ratio=neg_spans_ratio class GLiNERConfig(BaseGLiNERConfig): diff --git a/gliner/data_processing/processor.py b/gliner/data_processing/processor.py index 1ce68c4..95859a9 100644 --- a/gliner/data_processing/processor.py +++ b/gliner/data_processing/processor.py @@ -126,6 +126,29 @@ def tokenize_and_prepare_labels(self): """ pass + def sort_entities_and_relations(self, ner, relations=None): + if ner is not None and len(ner) > 0: + indexed_ner = list(enumerate(ner)) + indexed_ner_sorted = sorted(indexed_ner, key=lambda x: (x[1][0], x[1][1])) + + ner_sorted = [entity for _, entity in indexed_ner_sorted] + + # Create mapping from old entity indices to new sorted indices + old_to_new_idx = {old_idx: new_idx for new_idx, (old_idx, _) in enumerate(indexed_ner_sorted)} + + # Update relation indices to match new entity ordering + if relations is not None and len(relations) > 0: + updated_relations = [] + for head_idx, tail_idx, rel_type in relations: + if head_idx in old_to_new_idx and tail_idx in old_to_new_idx: + new_head_idx = old_to_new_idx[head_idx] + new_tail_idx = old_to_new_idx[tail_idx] + updated_relations.append((new_head_idx, new_tail_idx, rel_type)) + relations = sorted(updated_relations, key=lambda x: (x[0], x[1])) + + ner = ner_sorted + return ner, relations + def prepare_inputs( self, texts: Sequence[Sequence[str]], @@ -582,6 +605,52 @@ class UniEncoderTokenProcessor(BaseProcessor): labeled with BIO-style tags (Begin, Inside, Outside) for each entity type. """ + def _generate_negative_spans(self, positive_spans, num_tokens, num_negatives, max_width=None): + """Generate random negative spans that don't overlap with positive spans. 
+ + Args: + positive_spans: Set of (start, end) tuples representing positive entity spans. + num_tokens: Total number of tokens in the sequence. + num_negatives: Number of negative spans to generate. + max_width: Maximum width for negative spans. If None, uses config.max_width. + + Returns: + List of (start, end) tuples representing negative spans. + """ + if max_width is None: + max_width = getattr(self.config, 'max_width', 10) + + negative_spans = [] + attempts = 0 + max_attempts = num_negatives * 20 # Limit attempts to avoid infinite loops + + while len(negative_spans) < num_negatives and attempts < max_attempts: + attempts += 1 + + # Random start position + start = random.randint(0, num_tokens - 1) + + # Random width (1 to max_width) + width = random.randint(1, min(max_width, num_tokens - start)) + end = start + width - 1 + + # Check if this span overlaps with any positive span + span = (start, end) + if span in positive_spans: + continue + + # Check for overlap with positive spans + overlaps = False + for pos_start, pos_end in positive_spans: + if not (end < pos_start or start > pos_end): + overlaps = True + break + + if not overlaps and span not in negative_spans: + negative_spans.append(span) + + return negative_spans + def preprocess_example(self, tokens, ner, classes_to_id): """Preprocess a single example for token-based prediction. @@ -595,7 +664,8 @@ def preprocess_example(self, tokens, ner, classes_to_id): - tokens: Token strings - seq_length: Sequence length - entities: Original NER annotations - - entities_id: Entity annotations with class IDs + - span_idx: Tensor of entity span indices (if represent_spans=True) + - span_label: Tensor of entity class IDs (if represent_spans=True) Warnings: UserWarning: If sequence length exceeds max_len (gets truncated). 
@@ -610,13 +680,51 @@ def preprocess_example(self, tokens, ner, classes_to_id): warnings.warn(f"Sentence of length {len(tokens)} has been truncated to {max_len}", stacklevel=2) tokens = tokens[:max_len] - # Generate entity IDs based on the NER spans provided and their classes - try: # 'NoneType' object is not iterable - entities_id = [[i, j, classes_to_id[k]] for i, j, k in ner if k in classes_to_id] - except TypeError: - entities_id = [] + num_tokens = len(tokens) + + # Create span representations if configured + if ner is not None and self.config.represent_spans: + span_idx_list = [] + span_label_list = [] + positive_spans = set() + + # Add positive spans + for start, end, label in ner: + if label in classes_to_id and end < num_tokens: + span_idx_list.append([start, end]) + span_label_list.append(classes_to_id[label]) + positive_spans.add((start, end)) + + # Add negative spans + neg_spans_ratio = self.config.neg_spans_ratio + neg_spans_count = int(len(span_idx_list) * neg_spans_ratio) + + if neg_spans_count > 0 and num_tokens > 0: + max_width = getattr(self.config, 'max_width', 10) + negative_spans = self._generate_negative_spans( + positive_spans, num_tokens, neg_spans_count, max_width + ) + + for start, end in negative_spans: + span_idx_list.append([start, end]) + span_label_list.append(0) # 0 indicates negative/no entity + + if span_idx_list: + span_idx = torch.LongTensor(span_idx_list) + span_label = torch.LongTensor(span_label_list) + else: + span_idx = torch.zeros(0, 2, dtype=torch.long) + span_label = torch.zeros(0, dtype=torch.long) + else: + span_idx, span_label = None, None - example = {"tokens": tokens, "seq_length": len(tokens), "entities": ner, "entities_id": entities_id} + example = { + "tokens": tokens, + "seq_length": len(tokens), + "entities": ner, + "span_idx": span_idx, + "span_label": span_label, + } return example def create_batch_dict(self, batch, class_to_ids, id_to_classes): @@ -632,7 +740,9 @@ def create_batch_dict(self, batch, 
class_to_ids, id_to_classes): - tokens: Token strings - seq_length: Sequence lengths - entities: Original NER annotations - - entities_id: Entity annotations with class IDs + - span_idx: Padded span indices (if available) + - span_label: Padded span labels (if available) + - span_mask: Mask for valid spans (if available) - classes_to_id: Class mappings - id_to_classes: Reverse class mappings """ @@ -640,18 +750,39 @@ def create_batch_dict(self, batch, class_to_ids, id_to_classes): tokens = [el["tokens"] for el in batch] seq_length = torch.LongTensor([el["seq_length"] for el in batch]).unsqueeze(-1) entities = [el["entities"] for el in batch] - entities_id = [el["entities_id"] for el in batch] - # Assemble and return the batch dictionary + # Assemble the base batch dictionary batch_dict = { "tokens": tokens, "seq_length": seq_length, "entities": entities, - "entities_id": entities_id, "classes_to_id": class_to_ids, "id_to_classes": id_to_classes, } + # Handle span representations if present + if batch[0]['span_idx'] is not None: + span_idx_list = [el["span_idx"] for el in batch] + span_label_list = [el["span_label"] for el in batch] + + batch_size = len(span_idx_list) + span_counts = [s.size(0) if s.numel() > 0 else 0 for s in span_idx_list] + max_spans = max(max(span_counts), 1) # Ensure at least 1 + + # Create span mask indicating valid spans + span_mask = torch.zeros(batch_size, max_spans, dtype=torch.bool) + for i, count in enumerate(span_counts): + if count > 0: + span_mask[i, :count] = True + + # Pad span tensors + span_idx = pad_2d_tensor(span_idx_list, padding_value=0) + span_label = pad_sequence(span_label_list, batch_first=True, padding_value=-1) + + batch_dict["span_idx"] = span_idx + batch_dict["span_label"] = span_label + batch_dict["span_mask"] = span_mask + return batch_dict def create_labels(self, batch): @@ -688,6 +819,50 @@ def create_labels(self, batch): return word_labels + def create_span_labels(self, batch): + """Create one-hot encoded labels 
for spans with negative sampling. + + Creates one-hot encoded labels for entity spans, converting 1-indexed class IDs + to 0-indexed format. Labels with class ID 0 (negative spans) or -1 (invalid spans) + are represented as all zeros in the one-hot encoding. + + Args: + batch: Batch dictionary containing span_label, span_mask, and classes_to_id. + + Returns: + Tensor of shape (batch_size, max_spans, num_classes) containing one-hot + encoded labels where: + - Positive spans: one-hot vector at position (class_id - 1) + - Negative/invalid spans: all zeros + """ + batch_size = len(batch["tokens"]) + span_label = batch["span_label"] # (batch_size, max_spans) + span_mask = batch["span_mask"] # (batch_size, max_spans) + + # Get maximum number of classes across all examples + if isinstance(batch["classes_to_id"], list): + num_classes = max([len(cid) for cid in batch["classes_to_id"]]) + else: + num_classes = len(batch["classes_to_id"]) + + max_spans = span_label.size(1) + + # Initialize one-hot labels (batch_size, max_spans, num_classes) + labels_one_hot = torch.zeros(batch_size, max_spans, num_classes, dtype=torch.float) + + for i in range(batch_size): + for j in range(max_spans): + if span_mask[i, j]: # Valid span + class_id = span_label[i, j].item() + + if class_id > 0: + # Convert from 1-indexed to 0-indexed + class_idx = class_id - 1 + if class_idx < num_classes: + labels_one_hot[i, j, class_idx] = 1.0 + + return labels_one_hot + def tokenize_and_prepare_labels(self, batch, prepare_labels, *args, **kwargs): """Tokenize inputs and prepare token-level labels for a batch. 
@@ -704,9 +879,15 @@ def tokenize_and_prepare_labels(self, batch, prepare_labels, *args, **kwargs): if prepare_labels: labels = self.create_labels(batch) tokenized_input["labels"] = labels + + # Add span-level one-hot labels if spans are represented + if batch.get("span_idx") is not None: + span_labels = self.create_span_labels(batch) + tokenized_input["span_labels"] = span_labels + tokenized_input["span_idx"] = batch["span_idx"] + tokenized_input["span_mask"] = batch["span_mask"] return tokenized_input - class BaseBiEncoderProcessor(BaseProcessor): """Base processor for bi-encoder architectures. @@ -870,6 +1051,12 @@ def tokenize_and_prepare_labels(self, batch, prepare_labels, prepare_entities=Tr labels = self.create_labels(batch) tokenized_input["labels"] = labels + # Add span-level one-hot labels if spans are represented + if batch.get("span_idx") is not None: + span_labels = self.create_span_labels(batch) + tokenized_input["span_labels"] = span_labels + tokenized_input["span_idx"] = batch["span_idx"] + tokenized_input["span_mask"] = batch["span_mask"] return tokenized_input @@ -1222,29 +1409,6 @@ def collate_raw_batch( ] return self.create_batch_dict(batch, class_to_ids, id_to_classes, rel_class_to_ids, rel_id_to_classes) - - def sort_entities_and_relations(self, ner, relations): - if ner is not None and len(ner) > 0: - indexed_ner = list(enumerate(ner)) - indexed_ner_sorted = sorted(indexed_ner, key=lambda x: (x[1][0], x[1][1])) - - ner_sorted = [entity for _, entity in indexed_ner_sorted] - - # Create mapping from old entity indices to new sorted indices - old_to_new_idx = {old_idx: new_idx for new_idx, (old_idx, _) in enumerate(indexed_ner_sorted)} - - # Update relation indices to match new entity ordering - if relations is not None and len(relations) > 0: - updated_relations = [] - for head_idx, tail_idx, rel_type in relations: - if head_idx in old_to_new_idx and tail_idx in old_to_new_idx: - new_head_idx = old_to_new_idx[head_idx] - new_tail_idx = 
old_to_new_idx[tail_idx] - updated_relations.append((new_head_idx, new_tail_idx, rel_type)) - relations = sorted(updated_relations, key=lambda x: (x[0], x[1])) - - ner = ner_sorted - return ner, relations def preprocess_example(self, tokens, ner, classes_to_id, relations, rel_classes_to_id): """Preprocess a single example for joint entity and relation extraction. diff --git a/gliner/decoding/decoder.py b/gliner/decoding/decoder.py index ac55c4b..53d2ca4 100644 --- a/gliner/decoding/decoder.py +++ b/gliner/decoding/decoder.py @@ -827,7 +827,7 @@ class TokenDecoder(BaseDecoder): Token-based decoder for sequence labeling tasks. Uses BIO-style tagging with separate start, end, and inside predictions - to identify entity spans. + to identify entity spans. Can also decode from span-level predictions. """ def _get_indices_above_threshold(self, scores: torch.Tensor, threshold: float) -> List[torch.Tensor]: @@ -889,54 +889,167 @@ def _calculate_span_score( span_i.append((st, ed, id_to_classes[cls_st + 1], spn_score)) return span_i - def decode( + def _decode_from_spans( self, tokens: List[List[str]], id_to_classes: Union[Dict[int, str], List[Dict[int, str]]], - model_output: torch.Tensor, + span_logits: torch.Tensor, + span_idx: torch.Tensor, + span_mask: torch.Tensor, flat_ner: bool = False, threshold: float = 0.5, multi_label: bool = False, - **kwargs, ) -> List[List[tuple]]: """ - Decode token-level predictions to extract spans. + Decode from span-level predictions. Args: tokens (List[List[str]]): Tokenized input text for each sample in the batch. id_to_classes (Union[Dict[int, str], List[Dict[int, str]]]): Mapping from class IDs to class names. - model_output (torch.Tensor): Raw logits from the model with shape ( B, L, C, 3), - where the first dimension represents [start, end, inside] predictions. + span_logits (torch.Tensor): Span classification logits with shape (B, S, C), + where B is batch size, S is max spans, C is number of classes. 
+ span_idx (torch.Tensor): Span indices with shape (B, S, 2), containing + [start, end] positions for each span. + span_mask (torch.Tensor): Boolean mask with shape (B, S) indicating + valid spans. flat_ner (bool): Whether to enforce non-overlapping spans. threshold (float): Confidence threshold for predictions. multi_label (bool): Whether to allow multiple labels per span. - **kwargs: Additional keyword arguments (unused). Returns: List[List[tuple]]: For each sample, list of span tuples in format (start, end, entity_type, None, score). """ - model_output = model_output.permute(3, 0, 1, 2) - scores_start, scores_end, scores_inside = model_output + batch_size = span_logits.size(0) spans = [] - for i, _ in enumerate(tokens): + # Apply sigmoid to get probabilities + span_probs = torch.sigmoid(span_logits) + + for i in range(batch_size): id_to_class_i = self._get_id_to_class_for_sample(id_to_classes, i) - span_scores = self._calculate_span_score( - self._get_indices_above_threshold(scores_start[i], threshold), - self._get_indices_above_threshold(scores_end[i], threshold), - torch.sigmoid(scores_inside[i]), - torch.sigmoid(scores_start[i]), - torch.sigmoid(scores_end[i]), - id_to_class_i, - threshold, - ) + span_scores = [] + + # Get valid spans for this sample + valid_mask = span_mask[i] + valid_indices = torch.where(valid_mask)[0] + + for span_pos in valid_indices: + span_start = span_idx[i, span_pos, 0].item() + span_end = span_idx[i, span_pos, 1].item() + + # Get probabilities for all classes for this span + probs = span_probs[i, span_pos] + + # Find classes above threshold + class_indices = torch.where(probs > threshold)[0] + + for class_idx in class_indices: + class_id = class_idx.item() + 1 # Convert to 1-indexed + if class_id in id_to_class_i: + entity_type = id_to_class_i[class_id] + score = probs[class_idx].item() + span_scores.append((span_start, span_end, entity_type, score)) + + # Apply greedy search to handle overlapping spans if needed span_i = 
self.greedy_search(span_scores, flat_ner, multi_label) spans.append(span_i) - return spans + def decode( + self, + tokens: List[List[str]], + id_to_classes: Union[Dict[int, str], List[Dict[int, str]]], + model_output: Optional[torch.Tensor] = None, + flat_ner: bool = False, + threshold: float = 0.5, + multi_label: bool = False, + span_logits: Optional[torch.Tensor] = None, + span_idx: Optional[torch.Tensor] = None, + span_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> List[List[tuple]]: + """ + Decode predictions to extract spans. + + Supports two decoding modes: + 1. Token-level BIO decoding (default): Uses model_output with start/end/inside predictions + 2. Span-level decoding: Uses span_logits, span_idx, and span_mask + + Args: + tokens (List[List[str]]): Tokenized input text for each sample in the batch. + id_to_classes (Union[Dict[int, str], List[Dict[int, str]]]): Mapping from + class IDs to class names. + model_output (torch.Tensor, optional): Raw logits from the model with shape + (B, L, C, 3), where the last dimension represents [start, end, inside] + predictions. Used for token-level decoding. + flat_ner (bool): Whether to enforce non-overlapping spans. + threshold (float): Confidence threshold for predictions. + multi_label (bool): Whether to allow multiple labels per span. + span_logits (torch.Tensor, optional): Span classification logits with shape + (B, S, C). Used for span-level decoding. + span_idx (torch.Tensor, optional): Span indices with shape (B, S, 2). + Used for span-level decoding. + span_mask (torch.Tensor, optional): Boolean mask with shape (B, S). + Used for span-level decoding. + **kwargs: Additional keyword arguments (unused). + + Returns: + List[List[tuple]]: For each sample, list of span tuples in format + (start, end, entity_type, None, score). + + Raises: + ValueError: If neither model_output nor span-level inputs are provided, + or if span-level inputs are incomplete. 
+ """ + # Check if span-level decoding is requested + if span_logits is not None and span_idx is not None and span_mask is not None: + return self._decode_from_spans( + tokens=tokens, + id_to_classes=id_to_classes, + span_logits=span_logits, + span_idx=span_idx, + span_mask=span_mask, + flat_ner=flat_ner, + threshold=threshold, + multi_label=multi_label, + ) + + # Check if token-level decoding is requested + if model_output is not None: + model_output = model_output.permute(3, 0, 1, 2) + scores_start, scores_end, scores_inside = model_output + spans = [] + + for i, _ in enumerate(tokens): + id_to_class_i = self._get_id_to_class_for_sample(id_to_classes, i) + span_scores = self._calculate_span_score( + self._get_indices_above_threshold(scores_start[i], threshold), + self._get_indices_above_threshold(scores_end[i], threshold), + torch.sigmoid(scores_inside[i]), + torch.sigmoid(scores_start[i]), + torch.sigmoid(scores_end[i]), + id_to_class_i, + threshold, + ) + span_i = self.greedy_search(span_scores, flat_ner, multi_label) + spans.append(span_i) + return spans + + # Neither decoding mode has sufficient inputs + if span_logits is not None or span_idx is not None or span_mask is not None: + raise ValueError( + "For span-level decoding, all three parameters must be provided: " + "span_logits, span_idx, and span_mask" + ) + + raise ValueError( + "Either model_output (for token-level decoding) or " + "(span_logits, span_idx, span_mask) (for span-level decoding) must be provided" + ) + + class TokenRelexDecoder(TokenDecoder): """Token-based decoder with relation extraction support. 
diff --git a/gliner/model.py b/gliner/model.py index a98e54f..01613ff 100644 --- a/gliner/model.py +++ b/gliner/model.py @@ -1206,7 +1206,8 @@ def _process_batches(self, data_loader, threshold, flat_ner, multi_label, packin ) # Get predictions - model_logits = self.model(**model_inputs, threshold=threshold)[0] + model_output = self.model(**model_inputs, threshold=threshold) + model_logits = model_output[0] if not isinstance(model_logits, torch.Tensor): model_logits = torch.from_numpy(model_logits) @@ -1215,6 +1216,9 @@ def _process_batches(self, data_loader, threshold, flat_ner, multi_label, packin batch["tokens"], batch["id_to_classes"], model_logits, + span_idx=model_output.span_idx, + span_mask=model_output.span_mask, + span_logits=model_output.span_logits, flat_ner=flat_ner, threshold=threshold, multi_label=multi_label, diff --git a/gliner/modeling/base.py b/gliner/modeling/base.py index a1a6e94..572fbcb 100644 --- a/gliner/modeling/base.py +++ b/gliner/modeling/base.py @@ -29,6 +29,7 @@ build_entity_pairs, extract_prompt_features, extract_word_embeddings, + extract_spans_from_tokens, extract_prompt_features_and_word_embeddings, ) from .layers import CrossFuser, LstmSeq2SeqEncoder, create_projection_layer @@ -512,17 +513,42 @@ def __init__( super().__init__(config, from_pretrained, cache_dir) self.scorer = Scorer(config.hidden_size, config.dropout) + if getattr(config, 'represent_spans', False): + self.span_rep_layer = SpanRepLayer( + span_mode=config.span_mode, + hidden_size=config.hidden_size, + max_width=getattr(config, 'max_width', 12), + dropout=config.dropout, + ) + + def get_span_logits(self, scores, span_idx, span_mask, + words_embedding, prompts_embedding, + labels, threshold): + span_logits = None + if getattr(self.config, 'represent_spans', False): + if span_idx is None: + span_idx, span_mask = extract_spans_from_tokens(scores, labels, threshold) + span_idx = span_idx * span_mask.unsqueeze(-1).long() + + span_rep = 
self.span_rep_layer(words_embedding, span_idx) + span_logits = torch.einsum("BND,BCD->BNC", span_rep, prompts_embedding) + return span_logits, span_idx, span_mask + def forward( self, input_ids: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.LongTensor] = None, words_embedding: Optional[torch.FloatTensor] = None, mask: Optional[torch.LongTensor] = None, + span_idx: Optional[torch.Tensor]=None, + span_mask: Optional[torch.Tensor] = None, + span_labels: Optional[torch.Tensor] = None, prompts_embedding: Optional[torch.FloatTensor] = None, prompts_embedding_mask: Optional[torch.LongTensor] = None, words_mask: Optional[torch.LongTensor] = None, text_lengths: Optional[torch.Tensor] = None, labels: Optional[torch.FloatTensor] = None, + threshold: Optional[float] = 0.5, **kwargs: Any, ) -> GLiNERBaseOutput: """Forward pass through the token-based model. @@ -562,11 +588,18 @@ def forward( # Shape: (batch_size, seq_len, num_classes, 3), 3 - start, end, inside scores = self.scorer(words_embedding, prompts_embedding) - + + span_logits, span_idx, span_mask = self.get_span_logits(scores, span_idx, span_mask, words_embedding, + prompts_embedding, labels, threshold) + loss = None if labels is not None: loss = self.loss(scores, labels, prompts_embedding_mask, mask, **kwargs) + if span_labels is not None: + span_loss = self.loss(span_logits, span_labels, prompts_embedding_mask, span_mask, **kwargs) + loss = self.config.token_loss_coef*loss + self.config.span_loss_coef*span_loss + output = GLiNERBaseOutput( logits=scores, loss=loss, @@ -574,6 +607,9 @@ def forward( prompts_embedding_mask=prompts_embedding_mask, words_embedding=words_embedding, mask=mask, + span_idx=span_idx, + span_logits=span_logits, + span_mask=span_mask ) return output @@ -591,27 +627,25 @@ def loss( negatives: float = 1.0, **kwargs: Any, ) -> torch.Tensor: - """Compute token classification loss. + """Compute token/span classification loss. Args: - scores: Predicted scores of shape (B, W, C). 
- labels: Ground truth labels of shape (B, W, C). + scores: Predicted scores of shape (B, W, C, 3) for tokens or (B, N, C) for spans. + labels: Ground truth labels matching scores shape. prompts_embedding_mask: Mask for valid entity types of shape (B, C). - mask: Mask for valid tokens of shape (B, W). - alpha: Focal loss alpha parameter. - gamma: Focal loss gamma parameter. - prob_margin: Margin for probability adjustment. - label_smoothing: Label smoothing factor. - reduction: Loss reduction method ('sum' or 'mean'). - negatives: Negative sampling probability. - **kwargs: Additional arguments. - - Returns: - Scalar loss tensor. + word_mask: Mask for valid tokens/spans of shape (B, W) or (B, N). + ... """ all_losses = self._loss(scores, labels, alpha, gamma, prob_margin, label_smoothing, negatives) - all_losses = all_losses * (word_mask.unsqueeze(-1) * prompts_embedding_mask.unsqueeze(1)).unsqueeze(-1) + # Base mask: (B, W/N, C) + mask = word_mask.unsqueeze(-1) * prompts_embedding_mask.unsqueeze(1) + + # Only add extra dimension for 4D token-level scores (B, W, C, 3) + if all_losses.dim() == 4: + mask = mask.unsqueeze(-1) + + all_losses = all_losses * mask if reduction == "mean": loss = all_losses.mean() @@ -925,7 +959,7 @@ def loss( return loss -class BiEncoderTokenModel(BaseBiEncoderModel): +class BiEncoderTokenModel(BaseBiEncoderModel, UniEncoderTokenModel): """Token-based NER model using bi-encoder architecture. 
Attributes: @@ -954,11 +988,15 @@ def forward( labels_attention_mask: Optional[torch.LongTensor] = None, words_embedding: Optional[torch.FloatTensor] = None, mask: Optional[torch.LongTensor] = None, + span_idx: Optional[torch.Tensor]=None, + span_mask: Optional[torch.Tensor] = None, + span_labels: Optional[torch.Tensor] = None, prompts_embedding: Optional[torch.FloatTensor] = None, prompts_embedding_mask: Optional[torch.LongTensor] = None, words_mask: Optional[torch.LongTensor] = None, text_lengths: Optional[torch.Tensor] = None, labels: Optional[torch.FloatTensor] = None, + threshold: Optional[float] = 0.5, **kwargs: Any, ) -> GLiNERBaseOutput: """Forward pass through the bi-encoder token model. @@ -1008,10 +1046,17 @@ def forward( scores = self.scorer(words_embedding, prompts_embedding) + span_logits, span_idx, span_mask = self.get_span_logits(scores, span_idx, span_mask, words_embedding, + prompts_embedding, labels, threshold) + loss = None if labels is not None: loss = self.loss(scores, labels, prompts_embedding_mask, mask, **kwargs) + if span_labels is not None: + span_loss = self.loss(span_logits, span_labels, prompts_embedding_mask, span_mask, **kwargs) + loss = self.config.token_loss_coef*loss + self.config.span_loss_coef*span_loss + output = GLiNERBaseOutput( logits=scores, loss=loss, @@ -1019,58 +1064,12 @@ def forward( prompts_embedding_mask=prompts_embedding_mask, words_embedding=words_embedding, mask=mask, + span_idx=span_idx, + span_logits=span_logits, + span_mask=span_mask ) return output - def loss( - self, - scores: torch.Tensor, - labels: torch.Tensor, - prompts_embedding_mask: torch.Tensor, - mask: torch.Tensor, - alpha: float = -1.0, - gamma: float = 0.0, - prob_margin: float = 0.0, - label_smoothing: float = 0.0, - reduction: str = "sum", - negatives: float = 1.0, - **kwargs: Any, - ) -> torch.Tensor: - """Compute token classification loss for bi-encoder. - - Args: - scores: Predicted scores of shape (B, W, C). 
- labels: Ground truth labels of shape (B, W, C). - prompts_embedding_mask: Mask for valid entity types of shape (B, C). - mask: Mask for valid tokens of shape (B, W). - alpha: Focal loss alpha parameter. - gamma: Focal loss gamma parameter. - prob_margin: Margin for probability adjustment. - label_smoothing: Label smoothing factor. - reduction: Loss reduction method ('sum' or 'mean'). - negatives: Negative sampling probability. - **kwargs: Additional arguments. - - Returns: - Scalar loss tensor. - """ - all_losses = self._loss(scores, labels, alpha, gamma, prob_margin, label_smoothing, negatives) - - all_losses = all_losses * (mask.unsqueeze(-1) * prompts_embedding_mask.unsqueeze(1)).unsqueeze(-1) - - if reduction == "mean": - loss = all_losses.mean() - elif reduction == "sum": - loss = all_losses.sum() - else: - warnings.warn( - f"Invalid Value for config 'loss_reduction': '{reduction}' \n Supported reduction modes:" - f" 'none', 'mean', 'sum'. It will be used 'sum' instead.", - stacklevel=2, - ) - loss = all_losses.sum() - return loss - class UniEncoderSpanDecoderModel(UniEncoderSpanModel): """Span-based model with decoder for generating entity type labels. @@ -2026,93 +2025,6 @@ def __init__( super().__init__(config, from_pretrained, cache_dir) self.scorer = Scorer(config.hidden_size, config.dropout) - - def extract_spans( - self, - scores: torch.Tensor, - labels: Optional[torch.Tensor] = None, - threshold: float = 0.5, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Extract entity spans from BIO-style token predictions. 
- - Args: - scores: (B, W, C, 3) - logits for [start, end, inside] - labels: Optional (B, W, C, 3) - ground truth labels - threshold: Confidence threshold (used when labels is None) - - Returns: - span_idx: (B, N, 2) - [start, end] indices, padded - span_mask: (B, N) - validity mask - """ - B, W, C, _ = scores.shape - device = scores.device - - if labels is not None: - start_mask = labels[..., 0] > 0.5 - end_mask = labels[..., 1] > 0.5 - inside_mask = labels[..., 2] > 0.5 - else: - probs = torch.sigmoid(scores) - start_mask = probs[..., 0] > threshold - end_mask = probs[..., 1] > threshold - inside_mask = probs[..., 2] > threshold - - # Prepend zeros for cumsum indexing - inside_cumsum = torch.nn.functional.pad( - inside_mask.long().cumsum(dim=1), (0, 0, 1, 0) - ) # (B, W+1, C) - - spans_per_sample = [] - - for b in range(B): - starts = start_mask[b].nonzero(as_tuple=False) - ends = end_mask[b].nonzero(as_tuple=False) - - if starts.size(0) == 0 or ends.size(0) == 0: - spans_per_sample.append(torch.empty(0, 2, dtype=torch.long, device=device)) - continue - - s_pos, s_cls = starts.T - e_pos, e_cls = ends.T - - # Find valid (start, end) pairs: same class & end >= start - valid = (s_cls[:, None] == e_cls) & (s_pos[:, None] <= e_pos) - si, ei = valid.nonzero(as_tuple=True) - - if si.size(0) == 0: - spans_per_sample.append(torch.empty(0, 2, dtype=torch.long, device=device)) - continue - - cs, ce, cc = s_pos[si], e_pos[ei], s_cls[si] - - # Validate: all inside positions must be marked - inside_cnt = inside_cumsum[b, ce + 1, cc] - inside_cumsum[b, cs, cc] - valid = inside_cnt == (ce - cs + 1) - - cs, ce = cs[valid], ce[valid] - - if cs.size(0) == 0: - spans_per_sample.append(torch.empty(0, 2, dtype=torch.long, device=device)) - else: - spans_per_sample.append(torch.stack([cs, ce], dim=1)) - - # Pad to uniform size - max_spans = max(s.size(0) for s in spans_per_sample) if spans_per_sample else 0 - max_spans = max(max_spans, 1) # Ensure at least 1 to avoid empty tensor 
issues - - span_idx = torch.zeros(B, max_spans, 2, dtype=torch.long, device=device) - span_mask = torch.zeros(B, max_spans, dtype=torch.bool, device=device) - - for b, spans in enumerate(spans_per_sample): - n = spans.size(0) - if n > 0: - span_idx[b, :n] = spans - span_mask[b, :n] = True - - return span_idx, span_mask - - def loss( self, scores: torch.Tensor, @@ -2171,7 +2083,7 @@ def represent_spans(self, words_embeddings, words_mask, prompts_embeddings, scores = self.scorer(words_embeddings, prompts_embeddings) if span_idx is None: - span_idx, span_mask = self.extract_spans(scores, labels, threshold) + span_idx, span_mask = extract_spans_from_tokens(scores, labels, threshold) span_idx = span_idx * span_mask.unsqueeze(-1).long() target_span_rep = self.span_rep_layer(words_embeddings, span_idx) diff --git a/gliner/modeling/outputs.py b/gliner/modeling/outputs.py index cd275e2..a48614d 100644 --- a/gliner/modeling/outputs.py +++ b/gliner/modeling/outputs.py @@ -35,6 +35,9 @@ class GLiNERBaseOutput(ModelOutput): prompts_embedding_mask: Optional[torch.LongTensor] = None words_embedding: Optional[torch.FloatTensor] = None mask: Optional[torch.LongTensor] = None + span_idx: Optional[torch.LongTensor] = None + span_mask: Optional[torch.Tensor] = None + span_logits: Optional[torch.FloatTensor] = None @dataclass diff --git a/gliner/modeling/utils.py b/gliner/modeling/utils.py index 6ff0586..f28fa57 100644 --- a/gliner/modeling/utils.py +++ b/gliner/modeling/utils.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Tuple, Optional import torch @@ -294,3 +294,87 @@ def build_entity_pairs( tail_rep = span_rep[batch_idx, pair_idx[..., 1].clamp_min(0)] # (B, N, D) return pair_idx, pair_mask, head_rep, tail_rep + +def extract_spans_from_tokens( + scores: torch.Tensor, + labels: Optional[torch.Tensor] = None, + threshold: float = 0.5, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Extract entity spans from BIO-style token predictions. 
+ + Args: + scores: (B, W, C, 3) - logits for [start, end, inside] + labels: Optional (B, W, C, 3) - ground truth labels + threshold: Confidence threshold (used when labels is None) + + Returns: + span_idx: (B, N, 2) - [start, end] indices, padded + span_mask: (B, N) - validity mask + """ + B, W, C, _ = scores.shape + device = scores.device + + if labels is not None: + start_mask = labels[..., 0] > 0.5 + end_mask = labels[..., 1] > 0.5 + inside_mask = labels[..., 2] > 0.5 + else: + probs = torch.sigmoid(scores) + start_mask = probs[..., 0] > threshold + end_mask = probs[..., 1] > threshold + inside_mask = probs[..., 2] > threshold + + # Prepend zeros for cumsum indexing + inside_cumsum = torch.nn.functional.pad( + inside_mask.long().cumsum(dim=1), (0, 0, 1, 0) + ) # (B, W+1, C) + + spans_per_sample = [] + + for b in range(B): + starts = start_mask[b].nonzero(as_tuple=False) + ends = end_mask[b].nonzero(as_tuple=False) + + if starts.size(0) == 0 or ends.size(0) == 0: + spans_per_sample.append(torch.empty(0, 2, dtype=torch.long, device=device)) + continue + + s_pos, s_cls = starts.T + e_pos, e_cls = ends.T + + # Find valid (start, end) pairs: same class & end >= start + valid = (s_cls[:, None] == e_cls) & (s_pos[:, None] <= e_pos) + si, ei = valid.nonzero(as_tuple=True) + + if si.size(0) == 0: + spans_per_sample.append(torch.empty(0, 2, dtype=torch.long, device=device)) + continue + + cs, ce, cc = s_pos[si], e_pos[ei], s_cls[si] + + # Validate: all inside positions must be marked + inside_cnt = inside_cumsum[b, ce + 1, cc] - inside_cumsum[b, cs, cc] + valid = inside_cnt == (ce - cs + 1) + + cs, ce = cs[valid], ce[valid] + + if cs.size(0) == 0: + spans_per_sample.append(torch.empty(0, 2, dtype=torch.long, device=device)) + else: + spans_per_sample.append(torch.stack([cs, ce], dim=1)) + + # Pad to uniform size + max_spans = max(s.size(0) for s in spans_per_sample) if spans_per_sample else 0 + max_spans = max(max_spans, 1) # Ensure at least 1 to avoid empty tensor 
issues + + span_idx = torch.zeros(B, max_spans, 2, dtype=torch.long, device=device) + span_mask = torch.zeros(B, max_spans, dtype=torch.bool, device=device) + + for b, spans in enumerate(spans_per_sample): + n = spans.size(0) + if n > 0: + span_idx[b, :n] = spans + span_mask[b, :n] = True + + return span_idx, span_mask \ No newline at end of file From 6175adb603b1646e119ee3f380120cf2bcf4a411 Mon Sep 17 00:00:00 2001 From: Ingvar Date: Wed, 7 Jan 2026 21:59:54 +0200 Subject: [PATCH 4/8] format codebase with ruff --- gliner/config.py | 27 ++--- gliner/data_processing/__init__.py | 2 +- gliner/data_processing/collator.py | 2 + gliner/data_processing/processor.py | 124 ++++++++++------------ gliner/data_processing/utils.py | 6 +- gliner/decoding/__init__.py | 2 +- gliner/decoding/decoder.py | 25 +++-- gliner/evaluation/evaluator.py | 1 + gliner/model.py | 33 +++--- gliner/modeling/base.py | 157 ++++++++++++++++------------ gliner/modeling/span_rep.py | 6 +- gliner/modeling/utils.py | 41 ++++---- gliner/onnx/model.py | 3 +- gliner/training/trainer.py | 2 +- 14 files changed, 225 insertions(+), 206 deletions(-) diff --git a/gliner/config.py b/gliner/config.py index e25bc67..d6f06ec 100644 --- a/gliner/config.py +++ b/gliner/config.py @@ -119,19 +119,16 @@ def __init__(self, **kwargs): class UniEncoderTokenConfig(UniEncoderConfig): """Configuration for uni-encoder token-based GLiNER model.""" - def __init__(self, - represent_spans: bool = False, - token_loss_coef=1.0, - span_loss_coef=1.0, - neg_spans_ratio=1.0, - **kwargs): + def __init__( + self, represent_spans: bool = False, token_loss_coef=1.0, span_loss_coef=1.0, neg_spans_ratio=1.0, **kwargs + ): super().__init__(**kwargs) self.span_mode = "token_level" self.model_type = "gliner_uni_encoder_token" self.token_loss_coef = token_loss_coef self.span_loss_coef = span_loss_coef self.represent_spans = represent_spans - self.neg_spans_ratio=neg_spans_ratio + self.neg_spans_ratio = neg_spans_ratio class 
UniEncoderSpanDecoderConfig(UniEncoderConfig): @@ -223,6 +220,7 @@ def __init__( self.adjacency_loss_coef = adjacency_loss_coef self.relation_loss_coef = relation_loss_coef + class UniEncoderSpanRelexConfig(UniEncoderRelexConfig): """Configuration for uni-encoder span model with relation extraction.""" @@ -241,7 +239,7 @@ def __init__(self, **kwargs): self.model_type = "gliner_uni_encoder_token_relex" self.span_mode = "token_level" - + class BiEncoderConfig(BaseGLiNERConfig): """Base configuration for bi-encoder GLiNER models.""" @@ -277,19 +275,16 @@ def __init__(self, **kwargs): class BiEncoderTokenConfig(BiEncoderConfig): """Configuration for bi-encoder token-based GLiNER model.""" - def __init__(self, - represent_spans: bool = False, - token_loss_coef=1.0, - span_loss_coef=1.0, - neg_spans_ratio=1.0, - **kwargs): + def __init__( + self, represent_spans: bool = False, token_loss_coef=1.0, span_loss_coef=1.0, neg_spans_ratio=1.0, **kwargs + ): super().__init__(**kwargs) self.span_mode = "token_level" self.model_type = "gliner_bi_encoder_token" self.token_loss_coef = token_loss_coef self.span_loss_coef = span_loss_coef self.represent_spans = represent_spans - self.neg_spans_ratio=neg_spans_ratio + self.neg_spans_ratio = neg_spans_ratio class GLiNERConfig(BaseGLiNERConfig): @@ -333,7 +328,7 @@ def model_type(self): elif self.labels_encoder: return "gliner_bi_encoder_span" if self.span_mode != "token-level" else "gliner_bi_encoder_token" elif self.relations_layer is not None: - if self.span_mode == 'token-level': + if self.span_mode == "token-level": return "gliner_uni_encoder_token_relex" else: return "gliner_uni_encoder_span_relex" diff --git a/gliner/data_processing/__init__.py b/gliner/data_processing/__init__.py index 102ecd4..7aed30a 100644 --- a/gliner/data_processing/__init__.py +++ b/gliner/data_processing/__init__.py @@ -15,6 +15,6 @@ UniEncoderTokenProcessor, UniEncoderSpanDecoderProcessor, RelationExtractionSpanProcessor, - 
RelationExtractionTokenProcessor + RelationExtractionTokenProcessor, ) from .tokenizer import WordsSplitter diff --git a/gliner/data_processing/collator.py b/gliner/data_processing/collator.py index 3cffc3e..dabf6ea 100644 --- a/gliner/data_processing/collator.py +++ b/gliner/data_processing/collator.py @@ -483,8 +483,10 @@ class RelationExtractionTokenDataCollator(RelationExtractionSpanDataCollator): Required Processor: RelationExtractionTokenProcessor """ + pass + class UniEncoderSpanDataCollator(SpanDataCollator): """ Backward compatibility alias for SpanDataCollator with UniEncoderSpanProcessor. diff --git a/gliner/data_processing/processor.py b/gliner/data_processing/processor.py index 95859a9..472238f 100644 --- a/gliner/data_processing/processor.py +++ b/gliner/data_processing/processor.py @@ -148,7 +148,7 @@ def sort_entities_and_relations(self, ner, relations=None): ner = ner_sorted return ner, relations - + def prepare_inputs( self, texts: Sequence[Sequence[str]], @@ -459,7 +459,7 @@ def prepare_span_labels(self, ner, classes_to_id, num_tokens, spans_idx): valid_span_mask = spans_idx[:, 1] > num_tokens - 1 span_label = span_label.masked_fill(valid_span_mask, -1) return span_label, spans_idx - + def preprocess_example(self, tokens, ner, classes_to_id): """Preprocess a single example for span-based prediction. @@ -493,7 +493,7 @@ def preprocess_example(self, tokens, ner, classes_to_id): num_tokens = len(tokens) spans_idx = prepare_span_idx(num_tokens, max_width) - span_label, spans_idx = self.prepare_span_labels(ner, classes_to_id, num_tokens, spans_idx) + span_label, spans_idx = self.prepare_span_labels(ner, classes_to_id, num_tokens, spans_idx) return { "tokens": tokens, @@ -607,48 +607,48 @@ class UniEncoderTokenProcessor(BaseProcessor): def _generate_negative_spans(self, positive_spans, num_tokens, num_negatives, max_width=None): """Generate random negative spans that don't overlap with positive spans. 
- + Args: positive_spans: Set of (start, end) tuples representing positive entity spans. num_tokens: Total number of tokens in the sequence. num_negatives: Number of negative spans to generate. max_width: Maximum width for negative spans. If None, uses config.max_width. - + Returns: List of (start, end) tuples representing negative spans. """ if max_width is None: - max_width = getattr(self.config, 'max_width', 10) - + max_width = getattr(self.config, "max_width", 10) + negative_spans = [] attempts = 0 max_attempts = num_negatives * 20 # Limit attempts to avoid infinite loops - + while len(negative_spans) < num_negatives and attempts < max_attempts: attempts += 1 - + # Random start position start = random.randint(0, num_tokens - 1) - + # Random width (1 to max_width) width = random.randint(1, min(max_width, num_tokens - start)) end = start + width - 1 - + # Check if this span overlaps with any positive span span = (start, end) if span in positive_spans: continue - + # Check for overlap with positive spans overlaps = False for pos_start, pos_end in positive_spans: if not (end < pos_start or start > pos_end): overlaps = True break - + if not overlaps and span not in negative_spans: negative_spans.append(span) - + return negative_spans def preprocess_example(self, tokens, ner, classes_to_id): @@ -694,21 +694,19 @@ def preprocess_example(self, tokens, ner, classes_to_id): span_idx_list.append([start, end]) span_label_list.append(classes_to_id[label]) positive_spans.add((start, end)) - + # Add negative spans neg_spans_ratio = self.config.neg_spans_ratio neg_spans_count = int(len(span_idx_list) * neg_spans_ratio) - + if neg_spans_count > 0 and num_tokens > 0: - max_width = getattr(self.config, 'max_width', 10) - negative_spans = self._generate_negative_spans( - positive_spans, num_tokens, neg_spans_count, max_width - ) - + max_width = getattr(self.config, "max_width", 10) + negative_spans = self._generate_negative_spans(positive_spans, num_tokens, neg_spans_count, 
max_width) + for start, end in negative_spans: span_idx_list.append([start, end]) span_label_list.append(0) # 0 indicates negative/no entity - + if span_idx_list: span_idx = torch.LongTensor(span_idx_list) span_label = torch.LongTensor(span_label_list) @@ -719,8 +717,8 @@ def preprocess_example(self, tokens, ner, classes_to_id): span_idx, span_label = None, None example = { - "tokens": tokens, - "seq_length": len(tokens), + "tokens": tokens, + "seq_length": len(tokens), "entities": ner, "span_idx": span_idx, "span_label": span_label, @@ -761,24 +759,24 @@ def create_batch_dict(self, batch, class_to_ids, id_to_classes): } # Handle span representations if present - if batch[0]['span_idx'] is not None: + if batch[0]["span_idx"] is not None: span_idx_list = [el["span_idx"] for el in batch] span_label_list = [el["span_label"] for el in batch] - + batch_size = len(span_idx_list) span_counts = [s.size(0) if s.numel() > 0 else 0 for s in span_idx_list] - max_spans = max(max(span_counts), 1) # Ensure at least 1 - + max_spans = max(*span_counts, 1) # Ensure at least 1 + # Create span mask indicating valid spans span_mask = torch.zeros(batch_size, max_spans, dtype=torch.bool) for i, count in enumerate(span_counts): if count > 0: span_mask[i, :count] = True - + # Pad span tensors span_idx = pad_2d_tensor(span_idx_list, padding_value=0) span_label = pad_sequence(span_label_list, batch_first=True, padding_value=-1) - + batch_dict["span_idx"] = span_idx batch_dict["span_label"] = span_label batch_dict["span_mask"] = span_mask @@ -804,9 +802,9 @@ def create_labels(self, batch): word_labels = torch.zeros(batch_size, seq_len, num_classes, 3, dtype=torch.float) - for i, sentence_entities in enumerate(batch['entities']): + for i, sentence_entities in enumerate(batch["entities"]): for st, ed, sp_label in sentence_entities: - lbl = batch['classes_to_id'][i][sp_label] + lbl = batch["classes_to_id"][i][sp_label] class_idx = lbl - 1 # Convert to 0-indexed # skip entities that point beyond 
sequence length @@ -821,7 +819,7 @@ def create_labels(self, batch): def create_span_labels(self, batch): """Create one-hot encoded labels for spans with negative sampling. - + Creates one-hot encoded labels for entity spans, converting 1-indexed class IDs to 0-indexed format. Labels with class ID 0 (negative spans) or -1 (invalid spans) are represented as all zeros in the one-hot encoding. @@ -837,30 +835,30 @@ def create_span_labels(self, batch): """ batch_size = len(batch["tokens"]) span_label = batch["span_label"] # (batch_size, max_spans) - span_mask = batch["span_mask"] # (batch_size, max_spans) - + span_mask = batch["span_mask"] # (batch_size, max_spans) + # Get maximum number of classes across all examples if isinstance(batch["classes_to_id"], list): num_classes = max([len(cid) for cid in batch["classes_to_id"]]) else: num_classes = len(batch["classes_to_id"]) - + max_spans = span_label.size(1) - + # Initialize one-hot labels (batch_size, max_spans, num_classes) labels_one_hot = torch.zeros(batch_size, max_spans, num_classes, dtype=torch.float) - + for i in range(batch_size): for j in range(max_spans): if span_mask[i, j]: # Valid span class_id = span_label[i, j].item() - + if class_id > 0: # Convert from 1-indexed to 0-indexed class_idx = class_id - 1 if class_idx < num_classes: labels_one_hot[i, j, class_idx] = 1.0 - + return labels_one_hot def tokenize_and_prepare_labels(self, batch, prepare_labels, *args, **kwargs): @@ -879,7 +877,7 @@ def tokenize_and_prepare_labels(self, batch, prepare_labels, *args, **kwargs): if prepare_labels: labels = self.create_labels(batch) tokenized_input["labels"] = labels - + # Add span-level one-hot labels if spans are represented if batch.get("span_idx") is not None: span_labels = self.create_span_labels(batch) @@ -888,6 +886,7 @@ def tokenize_and_prepare_labels(self, batch, prepare_labels, *args, **kwargs): tokenized_input["span_mask"] = batch["span_mask"] return tokenized_input + class 
BaseBiEncoderProcessor(BaseProcessor): """Base processor for bi-encoder architectures. @@ -1098,7 +1097,7 @@ def tokenize_inputs(self, texts, entities, blank=None): Dictionary containing encoder and decoder tokenized inputs. """ add_entities = True - if self.config.decoder_mode == 'prompt': + if self.config.decoder_mode == "prompt": add_entities = False input_texts, prompt_lengths = self.prepare_inputs(texts, entities, blank=blank, add_entities=add_entities) @@ -1185,15 +1184,15 @@ def create_labels(self, batch, blank=None): labels_one_hot[valid_span_mask, :] = 0.0 labels_one_hot = labels_one_hot[:, 1:] labels_batch.append(labels_one_hot) - - if self.config.decoder_mode == 'span': + + if self.config.decoder_mode == "span": # Collect decoder label strings in order sorted_idxs = sorted(span_labels_dict.keys()) for idx in sorted_idxs: decoder_label_strings.append(span_labels_dict[idx]) - elif self.config.decoder_mode == 'prompt': + elif self.config.decoder_mode == "prompt": decoder_label_strings.extend(list(classes_to_id)) - + labels_batch = pad_2d_tensor(labels_batch) if len(labels_batch) > 1 else labels_batch[0].unsqueeze(0) decoder_tokenized_input = None @@ -1409,7 +1408,7 @@ def collate_raw_batch( ] return self.create_batch_dict(batch, class_to_ids, id_to_classes, rel_class_to_ids, rel_id_to_classes) - + def preprocess_example(self, tokens, ner, classes_to_id, relations, rel_classes_to_id): """Preprocess a single example for joint entity and relation extraction. 
@@ -1452,7 +1451,7 @@ def preprocess_example(self, tokens, ner, classes_to_id, relations, rel_classes_ ner, relations = self.sort_entities_and_relations(ner, relations) # Process entity labels - span_label, spans_idx = self.prepare_span_labels(ner, classes_to_id, num_tokens, spans_idx) + span_label, spans_idx = self.prepare_span_labels(ner, classes_to_id, num_tokens, spans_idx) # Create entity span to index mapping span_to_idx = {(spans_idx[i, 0].item(), spans_idx[i, 1].item()): i for i in range(len(spans_idx))} @@ -1538,7 +1537,6 @@ def create_batch_dict(self, batch, class_to_ids, id_to_classes, rel_class_to_ids "rel_id_to_classes": rel_id_to_classes, } - def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_negatives=True, negative_ratio=2.0): """Create relation labels with negative pair sampling. @@ -1783,7 +1781,7 @@ def __init__(self, config, tokenizer, words_splitter): def preprocess_example(self, tokens, ner, classes_to_id, relations=None, rel_classes_to_id=None): """Preprocess a single example for joint entity and relation extraction. - Processes both entity annotations (for token-level BIO tagging) and + Processes both entity annotations (for token-level BIO tagging) and relation triplets, ensuring consistent indexing when entities are reordered. 
Args: @@ -1823,7 +1821,7 @@ def preprocess_example(self, tokens, ner, classes_to_id, relations=None, rel_cla # Create entity index mapping (from sorted entity list index to entities_id index) entity_idx_mapping = {} valid_entity_idx = 0 - + if ner is not None: span_idx_list = [] for ent_idx, (start, end, label) in enumerate(ner): @@ -1846,9 +1844,7 @@ def preprocess_example(self, tokens, ner, classes_to_id, relations=None, rel_cla head_idx, tail_idx, rel_type = rel # Check if both entities are valid and relation type is known - if (head_idx in entity_idx_mapping and - tail_idx in entity_idx_mapping and - rel_type in rel_classes_to_id): + if head_idx in entity_idx_mapping and tail_idx in entity_idx_mapping and rel_type in rel_classes_to_id: mapped_head = entity_idx_mapping[head_idx] mapped_tail = entity_idx_mapping[tail_idx] rel_idx_list.append([mapped_head, mapped_tail]) @@ -1862,7 +1858,6 @@ def preprocess_example(self, tokens, ner, classes_to_id, relations=None, rel_cla rel_idx = torch.zeros(0, 2, dtype=torch.long) rel_label = torch.zeros(0, dtype=torch.long) - return { "tokens": tokens, "seq_length": num_tokens, @@ -1892,18 +1887,18 @@ def create_batch_dict(self, batch, class_to_ids, id_to_classes, rel_class_to_ids entities = [el["entities"] for el in batch] relations = [el["relations"] for el in batch] - if batch[0]['span_idx'] is not None: + if batch[0]["span_idx"] is not None: span_idx_list = [el["span_idx"] for el in batch] - + batch_size = len(span_idx_list) span_counts = [s.size(0) if s.numel() > 0 else 0 for s in span_idx_list] - max_spans = max(max(span_counts), 1) # Ensure at least 1 - + max_spans = max(*span_counts, 1) # Ensure at least 1 + span_mask = torch.zeros(batch_size, max_spans, dtype=torch.bool) for i, count in enumerate(span_counts): if count > 0: span_mask[i, :count] = True - + span_idx = pad_2d_tensor(span_idx_list, padding_value=0) else: span_idx, span_mask = None, None @@ -1926,7 +1921,6 @@ def create_batch_dict(self, batch, class_to_ids, 
id_to_classes, rel_class_to_ids "rel_id_to_classes": rel_id_to_classes, } - def tokenize_and_prepare_labels(self, batch, prepare_labels, *args, **kwargs): """Tokenize inputs and prepare labels for joint entity-relation extraction. @@ -1940,13 +1934,9 @@ def tokenize_and_prepare_labels(self, batch, prepare_labels, *args, **kwargs): Dictionary containing tokenized inputs, token-level entity labels, relation adjacency matrix, and relation labels. """ - # Use relation-aware tokenize_inputs from RelationExtractionSpanProcessor tokenized_input = self.tokenize_inputs( - batch["tokens"], - batch["classes_to_id"], - blank=None, - relations=batch["rel_class_to_ids"] + batch["tokens"], batch["classes_to_id"], blank=None, relations=batch["rel_class_to_ids"] ) if prepare_labels: @@ -1959,4 +1949,4 @@ def tokenize_and_prepare_labels(self, batch, prepare_labels, *args, **kwargs): tokenized_input["adj_matrix"] = adj_matrix tokenized_input["rel_matrix"] = rel_matrix - return tokenized_input \ No newline at end of file + return tokenized_input diff --git a/gliner/data_processing/utils.py b/gliner/data_processing/utils.py index 1c9e19f..7f50162 100644 --- a/gliner/data_processing/utils.py +++ b/gliner/data_processing/utils.py @@ -15,6 +15,7 @@ def pad_2d_tensor(key_data, padding_value=0.0): Args: key_data: List of 2D tensors to pad. Each tensor can have different dimensions, but all must be 2D. + padding_value: float, value used to fill pad elements. 
Returns: A 3D tensor of shape (batch_size, max_rows, max_cols) containing all @@ -44,7 +45,9 @@ def pad_2d_tensor(key_data, padding_value=0.0): col_padding = max_cols - cols # Pad the tensor along both dimensions - padded_tensor = torch.nn.functional.pad(tensor, (0, col_padding, 0, row_padding), mode="constant", value=padding_value) + padded_tensor = torch.nn.functional.pad( + tensor, (0, col_padding, 0, row_padding), mode="constant", value=padding_value + ) tensors.append(padded_tensor) # Stack the tensors into a single tensor along a new batch dimension @@ -88,6 +91,7 @@ def get_negatives(batch_list: List[Dict], sampled_neg: int = 5, key="ner") -> Li selected_elements = random.sample(element_types, k=min(sampled_neg, len(element_types))) return selected_elements + def prepare_word_mask( texts: Sequence[Sequence[str]], tokenized_inputs, diff --git a/gliner/decoding/__init__.py b/gliner/decoding/__init__.py index dee4e75..884f947 100644 --- a/gliner/decoding/__init__.py +++ b/gliner/decoding/__init__.py @@ -1 +1 @@ -from .decoder import SpanDecoder, TokenDecoder, SpanRelexDecoder, SpanGenerativeDecoder, TokenRelexDecoder +from .decoder import SpanDecoder, TokenDecoder, SpanRelexDecoder, TokenRelexDecoder, SpanGenerativeDecoder diff --git a/gliner/decoding/decoder.py b/gliner/decoding/decoder.py index 53d2ca4..7deb823 100644 --- a/gliner/decoding/decoder.py +++ b/gliner/decoding/decoder.py @@ -938,13 +938,13 @@ class IDs to class names. for span_pos in valid_indices: span_start = span_idx[i, span_pos, 0].item() span_end = span_idx[i, span_pos, 1].item() - + # Get probabilities for all classes for this span probs = span_probs[i, span_pos] - + # Find classes above threshold class_indices = torch.where(probs > threshold)[0] - + for class_idx in class_indices: class_id = class_idx.item() + 1 # Convert to 1-indexed if class_id in id_to_class_i: @@ -981,13 +981,13 @@ def decode( tokens (List[List[str]]): Tokenized input text for each sample in the batch. 
id_to_classes (Union[Dict[int, str], List[Dict[int, str]]]): Mapping from class IDs to class names. - model_output (torch.Tensor, optional): Raw logits from the model with shape - (B, L, C, 3), where the last dimension represents [start, end, inside] + model_output (torch.Tensor, optional): Raw logits from the model with shape + (B, L, C, 3), where the last dimension represents [start, end, inside] predictions. Used for token-level decoding. flat_ner (bool): Whether to enforce non-overlapping spans. threshold (float): Confidence threshold for predictions. multi_label (bool): Whether to allow multiple labels per span. - span_logits (torch.Tensor, optional): Span classification logits with shape + span_logits (torch.Tensor, optional): Span classification logits with shape (B, S, C). Used for span-level decoding. span_idx (torch.Tensor, optional): Span indices with shape (B, S, 2). Used for span-level decoding. @@ -1015,7 +1015,7 @@ class IDs to class names. threshold=threshold, multi_label=multi_label, ) - + # Check if token-level decoding is requested if model_output is not None: model_output = model_output.permute(3, 0, 1, 2) @@ -1036,19 +1036,18 @@ class IDs to class names. span_i = self.greedy_search(span_scores, flat_ner, multi_label) spans.append(span_i) return spans - + # Neither decoding mode has sufficient inputs if span_logits is not None or span_idx is not None or span_mask is not None: raise ValueError( - "For span-level decoding, all three parameters must be provided: " - "span_logits, span_idx, and span_mask" + "For span-level decoding, all three parameters must be provided: span_logits, span_idx, and span_mask" ) - + raise ValueError( "Either model_output (for token-level decoding) or " "(span_logits, span_idx, span_mask) (for span-level decoding) must be provided" ) - + class TokenRelexDecoder(TokenDecoder): """Token-based decoder with relation extraction support. 
@@ -1260,4 +1259,4 @@ def decode( batch_size=len(tokens), ) - return spans, relations \ No newline at end of file + return spans, relations diff --git a/gliner/evaluation/evaluator.py b/gliner/evaluation/evaluator.py index 27f66e3..cc5bd4c 100644 --- a/gliner/evaluation/evaluator.py +++ b/gliner/evaluation/evaluator.py @@ -203,6 +203,7 @@ class BaseRelexEvaluator(BaseEvaluator): The input format expects entity indices rather than entity spans directly. Entity spans are looked up from the entity list using these indices. """ + def get_ground_truth(self, ents, rels): """Extract ground truth relations in evaluation format. diff --git a/gliner/model.py b/gliner/model.py index 01613ff..3568169 100644 --- a/gliner/model.py +++ b/gliner/model.py @@ -34,7 +34,7 @@ UniEncoderTokenRelexConfig, UniEncoderSpanDecoderConfig, ) -from .decoding import SpanDecoder, TokenDecoder, SpanRelexDecoder, SpanGenerativeDecoder, TokenRelexDecoder +from .decoding import SpanDecoder, TokenDecoder, SpanRelexDecoder, TokenRelexDecoder, SpanGenerativeDecoder from .training import Trainer, TrainingArguments from .evaluation import BaseNEREvaluator, BaseRelexEvaluator from .onnx.model import ( @@ -44,7 +44,7 @@ UniEncoderSpanORTModel, UniEncoderTokenORTModel, UniEncoderSpanRelexORTModel, - UniEncoderTokenRelexORTModel + UniEncoderTokenRelexORTModel, ) from .decoding.trie import LabelsTrie from .infer_packing import InferencePackingConfig @@ -66,7 +66,7 @@ UniEncoderTokenProcessor, UniEncoderSpanDecoderProcessor, RelationExtractionSpanProcessor, - RelationExtractionTokenProcessor + RelationExtractionTokenProcessor, ) from .data_processing.collator import ( BiEncoderSpanDataCollator, @@ -75,7 +75,7 @@ UniEncoderTokenDataCollator, UniEncoderSpanDecoderDataCollator, RelationExtractionSpanDataCollator, - RelationExtractionTokenDataCollator + RelationExtractionTokenDataCollator, ) from .data_processing.tokenizer import WordsSplitter @@ -2110,7 +2110,8 @@ def export_to_onnx( "2. 
Use PyTorch for inference with this model\n" "3. Implement a custom ONNX pipeline with separate encoder/decoder exports" ) - + + class UniEncoderSpanRelexGLiNER(BaseEncoderGLiNER): """GLiNER model for both entity recognition and relation extraction. @@ -2427,13 +2428,13 @@ def evaluate( - rel_f1: Relation extraction F1 score """ self.eval() - + if relation_threshold is None: relation_threshold = threshold - + if adjacency_threshold is None: adjacency_threshold = threshold - + # Create the dataset and data loader dataset = test_data collator = self.data_collator_class( @@ -2446,9 +2447,7 @@ def evaluate( return_rel_id_to_classes=True, prepare_labels=False, ) - data_loader = torch.utils.data.DataLoader( - dataset, batch_size=batch_size, shuffle=False, collate_fn=collator - ) + data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collator) all_entity_preds = [] all_relation_preds = [] @@ -2462,9 +2461,7 @@ def evaluate( # Get model predictions model_inputs = batch.copy() - model_output = self.model( - **model_inputs, threshold=threshold, adjacency_threshold=adjacency_threshold - ) + model_output = self.model(**model_inputs, threshold=threshold, adjacency_threshold=adjacency_threshold) # Extract logits and relation outputs model_logits = model_output.logits @@ -2520,7 +2517,7 @@ def evaluate( # Format data for relation evaluator: list of (entities, relations) tuples all_true_rel_data = list(zip(all_true_entities, all_true_relations)) all_pred_rel_data = list(zip(all_entity_preds, all_relation_preds)) - + rel_evaluator = BaseRelexEvaluator(all_true_rel_data, all_pred_rel_data) rel_output, rel_f1 = rel_evaluator.evaluate() @@ -2602,6 +2599,7 @@ def forward( return UniEncoderSpanRelexWrapper(core_model) + class UniEncoderTokenRelexGLiNER(UniEncoderSpanRelexGLiNER): """GLiNER model for both entity recognition and relation extraction. 
@@ -2683,7 +2681,8 @@ def forward( return out.logits, out.rel_idx, out.rel_logits, out.rel_mask return UniEncoderTokenRelexWrapper(core_model) - + + class GLiNER(nn.Module, PyTorchModelHubMixin): """Meta GLiNER class that automatically instantiates the appropriate GLiNER variant. @@ -3071,7 +3070,7 @@ def get_model_type(self) -> str: "BiEncoderTokenGLiNER": "gliner_bi_encoder_token", "UniEncoderSpanDecoderGLiNER": "gliner_uni_encoder_span_decoder", "UniEncoderSpanRelexGLiNER": "gliner_uni_encoder_span_relex", - "UniEncoderTokenRelexGLiNER": "gliner_uni_encoder_token_relex" + "UniEncoderTokenRelexGLiNER": "gliner_uni_encoder_token_relex", } return type_mapping.get(class_name, "unknown") diff --git a/gliner/modeling/base.py b/gliner/modeling/base.py index 572fbcb..b85d9cd 100644 --- a/gliner/modeling/base.py +++ b/gliner/modeling/base.py @@ -233,7 +233,7 @@ def __init__( super().__init__(config, from_pretrained, cache_dir) self.token_rep_layer = Encoder(config, from_pretrained, cache_dir=cache_dir) - if self.config.num_rnn_layers>0: + if self.config.num_rnn_layers > 0: self.rnn = LstmSeq2SeqEncoder(config, num_layers=self.config.num_rnn_layers) if config.post_fusion_schema: @@ -443,7 +443,7 @@ def loss( scores: Predicted scores of shape (B, L, K, C). labels: Ground truth labels of shape (B, L, K, C). prompts_embedding_mask: Mask for valid entity types of shape (B, C). - mask_label: Mask for valid spans of shape (B, L, K). + span_mask: Mask for valid spans of shape (B, L, K). alpha: Focal loss alpha parameter. gamma: Focal loss gamma parameter. prob_margin: Margin for probability adjustment. 
def get_span_logits(self, scores, span_idx, span_mask, words_embedding, prompts_embedding, labels, threshold):
    """Score candidate spans against the prompt embeddings.

    When ``config.represent_spans`` is falsy (or absent) nothing is computed
    and the incoming ``span_idx`` / ``span_mask`` are handed back untouched
    together with ``None`` logits.

    Args:
        scores: Token-level scores; only consulted when spans must be derived.
        span_idx: Span boundaries, or ``None`` to derive them from ``scores``
            via ``extract_spans_from_tokens``.
        span_mask: Validity mask matching ``span_idx``.
        words_embedding: Word representations fed to ``span_rep_layer``.
        prompts_embedding: Prompt (entity type) representations.
        labels: Token-level labels forwarded to span extraction.
        threshold: Confidence threshold forwarded to span extraction.

    Returns:
        Tuple ``(span_logits, span_idx, span_mask)``.
    """
    # Span scoring is optional; bail out early when it is switched off.
    if not getattr(self.config, "represent_spans", False):
        return None, span_idx, span_mask

    if span_idx is None:
        span_idx, span_mask = extract_spans_from_tokens(scores, labels, threshold)

    # Zero out the indices of padded spans so they point at a valid position.
    keep = span_mask.unsqueeze(-1).long()
    span_idx = span_idx * keep

    span_vectors = self.span_rep_layer(words_embedding, span_idx)
    span_logits = torch.einsum("BND,BCD->BNC", span_vectors, prompts_embedding)
    return span_logits, span_idx, span_mask
+ span_mask: Boolean or integer mask indicating valid spans of shape (B, S). + span_labels: Ground truth span labels of shape (B, S, C). prompts_embedding: Pre-computed entity label embeddings of shape (B, C, D). - prompts_embedding_mask: Mask for prompts of shape (B, C). - words_mask: Word boundary mask. - text_lengths: Length of each text sequence. - labels: Ground truth labels of shape (B, W, C). - **kwargs: Additional arguments. + prompts_embedding_mask: Mask for prompts/entities of shape (B, C). + words_mask: Word boundary mask mapping tokens to words. + text_lengths: Length of each text sequence before padding, shape (B,). + labels: Ground truth token-level labels of shape (B, W, C). + threshold: Confidence threshold used for span selection. + **kwargs: Additional arguments passed to the encoder or loss functions + (e.g., ``packing_config``, ``pair_attention_mask``). Returns: - GLiNERBaseOutput containing logits, loss, and intermediate representations. + GLiNERBaseOutput containing logits, loss, embeddings, and span-level outputs. 
""" encoder_kwargs = {key: kwargs[key] for key in ("packing_config", "pair_attention_mask") if key in kwargs} @@ -588,17 +592,18 @@ def forward( # Shape: (batch_size, seq_len, num_classes, 3), 3 - start, end, inside scores = self.scorer(words_embedding, prompts_embedding) - - span_logits, span_idx, span_mask = self.get_span_logits(scores, span_idx, span_mask, words_embedding, - prompts_embedding, labels, threshold) - + + span_logits, span_idx, span_mask = self.get_span_logits( + scores, span_idx, span_mask, words_embedding, prompts_embedding, labels, threshold + ) + loss = None if labels is not None: loss = self.loss(scores, labels, prompts_embedding_mask, mask, **kwargs) if span_labels is not None: span_loss = self.loss(span_logits, span_labels, prompts_embedding_mask, span_mask, **kwargs) - loss = self.config.token_loss_coef*loss + self.config.span_loss_coef*span_loss + loss = self.config.token_loss_coef * loss + self.config.span_loss_coef * span_loss output = GLiNERBaseOutput( logits=scores, @@ -609,7 +614,7 @@ def forward( mask=mask, span_idx=span_idx, span_logits=span_logits, - span_mask=span_mask + span_mask=span_mask, ) return output @@ -627,24 +632,39 @@ def loss( negatives: float = 1.0, **kwargs: Any, ) -> torch.Tensor: - """Compute token/span classification loss. + """Compute token- or span-level classification loss. Args: - scores: Predicted scores of shape (B, W, C, 3) for tokens or (B, N, C) for spans. - labels: Ground truth labels matching scores shape. + scores: Predicted scores. Shape is (B, W, C, 3) for token-level + classification (start, end, inside) or (B, N, C) for span-level + classification, where B is batch size, W is number of words, + N is number of spans, and C is number of entity types. + labels: Ground truth labels matching ``scores`` shape. prompts_embedding_mask: Mask for valid entity types of shape (B, C). - word_mask: Mask for valid tokens/spans of shape (B, W) or (B, N). - ... 
+ word_mask: Mask for valid tokens or spans of shape (B, W) for + token-level loss or (B, N) for span-level loss. + alpha: Alpha parameter for focal loss. If negative, focal weighting + is disabled. + gamma: Gamma parameter for focal loss. + prob_margin: Margin applied to predicted probabilities. + label_smoothing: Amount of label smoothing applied to targets. + reduction: Reduction method applied to the final loss. One of + ``"none"``, ``"mean"``, or ``"sum"``. + negatives: Weighting factor for negative examples. + **kwargs: Additional unused keyword arguments for API compatibility. + + Returns: + A scalar tensor representing the aggregated loss value. """ all_losses = self._loss(scores, labels, alpha, gamma, prob_margin, label_smoothing, negatives) # Base mask: (B, W/N, C) mask = word_mask.unsqueeze(-1) * prompts_embedding_mask.unsqueeze(1) - + # Only add extra dimension for 4D token-level scores (B, W, C, 3) if all_losses.dim() == 4: mask = mask.unsqueeze(-1) - + all_losses = all_losses * mask if reduction == "mean": @@ -988,7 +1008,7 @@ def forward( labels_attention_mask: Optional[torch.LongTensor] = None, words_embedding: Optional[torch.FloatTensor] = None, mask: Optional[torch.LongTensor] = None, - span_idx: Optional[torch.Tensor]=None, + span_idx: Optional[torch.Tensor] = None, span_mask: Optional[torch.Tensor] = None, span_labels: Optional[torch.Tensor] = None, prompts_embedding: Optional[torch.FloatTensor] = None, @@ -1009,6 +1029,10 @@ def forward( labels_attention_mask: Attention mask for labels. words_embedding: Pre-computed word embeddings. mask: Mask for words. + span_idx: Tensor containing span start/end indices of shape (B, S, 2), + where S is the number of spans. + span_mask: Boolean or integer mask indicating valid spans of shape (B, S). + span_labels: Ground truth span labels of shape (B, S, C). prompts_embedding: Pre-computed entity label embeddings. prompts_embedding_mask: Mask for prompts. words_mask: Word boundary mask. 
@@ -1046,8 +1070,9 @@ def forward( scores = self.scorer(words_embedding, prompts_embedding) - span_logits, span_idx, span_mask = self.get_span_logits(scores, span_idx, span_mask, words_embedding, - prompts_embedding, labels, threshold) + span_logits, span_idx, span_mask = self.get_span_logits( + scores, span_idx, span_mask, words_embedding, prompts_embedding, labels, threshold + ) loss = None if labels is not None: @@ -1055,7 +1080,7 @@ def forward( if span_labels is not None: span_loss = self.loss(span_logits, span_labels, prompts_embedding_mask, span_mask, **kwargs) - loss = self.config.token_loss_coef*loss + self.config.span_loss_coef*span_loss + loss = self.config.token_loss_coef * loss + self.config.span_loss_coef * span_loss output = GLiNERBaseOutput( logits=scores, @@ -1066,7 +1091,7 @@ def forward( mask=mask, span_idx=span_idx, span_logits=span_logits, - span_mask=span_mask + span_mask=span_mask, ) return output @@ -1689,14 +1714,16 @@ def select_target_embedding( return target_rep, target_mask - - def represent_spans(self, words_embeddings, words_mask, prompts_embeddings, - span_idx: Optional[torch.Tensor]=None, - span_mask: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - threshold: float = 0.5, - ): - + def represent_spans( + self, + words_embeddings, + words_mask, + prompts_embeddings, + span_idx: Optional[torch.Tensor] = None, + span_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + threshold: float = 0.5, + ): span_idx = span_idx * span_mask.unsqueeze(-1).long() span_rep = self.span_rep_layer(words_embeddings, span_idx) scores = torch.einsum("BLKD,BCD->BLKC", span_rep, prompts_embeddings) @@ -1708,7 +1735,7 @@ def represent_spans(self, words_embeddings, words_mask, prompts_embeddings, else: target_span_rep, target_span_mask = None, None return scores, target_span_rep, target_span_mask - + def forward( self, input_ids: Optional[torch.FloatTensor] = None, @@ -1760,17 +1787,17 @@ def forward( 
token_embeds, input_ids, attention_mask, text_lengths, words_mask ) ) - + if hasattr(self, "rnn"): words_embedding = self.rnn(words_embedding, mask) - if self.config.span_mode=='token_level': + if self.config.span_mode == "token_level": if labels is not None: target_W = labels.shape[1] target_C = max(prompts_embedding.size(1), labels.size(-2)) else: target_W = words_embedding.size(1) - target_C = prompts_embedding.size(1) + target_C = prompts_embedding.size(1) else: target_W = span_idx.size(1) // self.config.max_width target_C = prompts_embedding.size(1) @@ -1786,13 +1813,9 @@ def forward( prompts_embedding = self.prompt_rep_layer(prompts_embedding) batch_size, _, embed_dim = prompts_embedding.shape - scores, target_span_rep, target_span_mask = self.represent_spans(words_embedding, mask, - prompts_embedding, - span_idx, - span_mask, - labels, - threshold - ) + scores, target_span_rep, target_span_mask = self.represent_spans( + words_embedding, mask, prompts_embedding, span_idx, span_mask, labels, threshold + ) pair_idx, pair_mask, pair_scores = None, None, None rel_prompts_embedding_mask = None @@ -2045,7 +2068,7 @@ def loss( scores: Predicted scores of shape (B, W, C). labels: Ground truth labels of shape (B, W, C). prompts_embedding_mask: Mask for valid entity types of shape (B, C). - mask: Mask for valid tokens of shape (B, W). + word_mask: Mask for valid tokens of shape (B, W). alpha: Focal loss alpha parameter. gamma: Focal loss gamma parameter. prob_margin: Margin for probability adjustment. 
def represent_spans(
    self,
    words_embeddings,
    words_mask,
    prompts_embeddings,
    span_idx: Optional[torch.Tensor] = None,
    span_mask: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    threshold: float = 0.5,
):
    """Produce token-level scores and representations of the entity spans.

    Args:
        words_embeddings: Word representations passed to the scorer and the
            span representation layer.
        words_mask: Accepted for signature parity; not read here.
        prompts_embeddings: Prompt (entity type) representations.
        span_idx: Span boundaries; derived from the token scores via
            ``extract_spans_from_tokens`` when ``None``.
        span_mask: Validity mask matching ``span_idx``.
        labels: Token-level labels forwarded to span extraction.
        threshold: Confidence threshold forwarded to span extraction.

    Returns:
        Tuple ``(scores, target_span_rep, span_mask)``.
    """
    token_scores = self.scorer(words_embeddings, prompts_embeddings)

    spans_missing = span_idx is None
    if spans_missing:
        span_idx, span_mask = extract_spans_from_tokens(token_scores, labels, threshold)

    # Padded spans get their indices zeroed before being looked up.
    valid = span_mask.unsqueeze(-1).long()
    masked_idx = span_idx * valid

    return token_scores, self.span_rep_layer(words_embeddings, masked_idx), span_mask
@@ -738,7 +740,7 @@ def __init__(self, hidden_size, max_width, span_mode, **kwargs): self.span_rep_layer = SpanConv(hidden_size, max_width, span_mode="conv_sum") elif span_mode == "conv_share": self.span_rep_layer = ConvShare(hidden_size, max_width) - elif span_mode == 'token_level': + elif span_mode == "token_level": self.span_rep_layer = TokenMarker(hidden_size, **kwargs) else: raise ValueError(f"Unknown span mode {span_mode}") diff --git a/gliner/modeling/utils.py b/gliner/modeling/utils.py index f28fa57..73b66f2 100644 --- a/gliner/modeling/utils.py +++ b/gliner/modeling/utils.py @@ -295,6 +295,7 @@ def build_entity_pairs( return pair_idx, pair_mask, head_rep, tail_rep + def extract_spans_from_tokens( scores: torch.Tensor, labels: Optional[torch.Tensor] = None, @@ -302,19 +303,19 @@ def extract_spans_from_tokens( ) -> Tuple[torch.Tensor, torch.Tensor]: """ Extract entity spans from BIO-style token predictions. - + Args: scores: (B, W, C, 3) - logits for [start, end, inside] labels: Optional (B, W, C, 3) - ground truth labels threshold: Confidence threshold (used when labels is None) - + Returns: span_idx: (B, N, 2) - [start, end] indices, padded span_mask: (B, N) - validity mask """ B, W, C, _ = scores.shape device = scores.device - + if labels is not None: start_mask = labels[..., 0] > 0.5 end_mask = labels[..., 1] > 0.5 @@ -326,55 +327,53 @@ def extract_spans_from_tokens( inside_mask = probs[..., 2] > threshold # Prepend zeros for cumsum indexing - inside_cumsum = torch.nn.functional.pad( - inside_mask.long().cumsum(dim=1), (0, 0, 1, 0) - ) # (B, W+1, C) - + inside_cumsum = torch.nn.functional.pad(inside_mask.long().cumsum(dim=1), (0, 0, 1, 0)) # (B, W+1, C) + spans_per_sample = [] - + for b in range(B): starts = start_mask[b].nonzero(as_tuple=False) ends = end_mask[b].nonzero(as_tuple=False) - + if starts.size(0) == 0 or ends.size(0) == 0: spans_per_sample.append(torch.empty(0, 2, dtype=torch.long, device=device)) continue - + s_pos, s_cls = starts.T e_pos, 
e_cls = ends.T - + # Find valid (start, end) pairs: same class & end >= start valid = (s_cls[:, None] == e_cls) & (s_pos[:, None] <= e_pos) si, ei = valid.nonzero(as_tuple=True) - + if si.size(0) == 0: spans_per_sample.append(torch.empty(0, 2, dtype=torch.long, device=device)) continue - + cs, ce, cc = s_pos[si], e_pos[ei], s_cls[si] - + # Validate: all inside positions must be marked inside_cnt = inside_cumsum[b, ce + 1, cc] - inside_cumsum[b, cs, cc] valid = inside_cnt == (ce - cs + 1) - + cs, ce = cs[valid], ce[valid] - + if cs.size(0) == 0: spans_per_sample.append(torch.empty(0, 2, dtype=torch.long, device=device)) else: spans_per_sample.append(torch.stack([cs, ce], dim=1)) - + # Pad to uniform size max_spans = max(s.size(0) for s in spans_per_sample) if spans_per_sample else 0 max_spans = max(max_spans, 1) # Ensure at least 1 to avoid empty tensor issues - + span_idx = torch.zeros(B, max_spans, 2, dtype=torch.long, device=device) span_mask = torch.zeros(B, max_spans, dtype=torch.bool, device=device) - + for b, spans in enumerate(spans_per_sample): n = spans.size(0) if n > 0: span_idx[b, :n] = spans span_mask[b, :n] = True - - return span_idx, span_mask \ No newline at end of file + + return span_idx, span_mask diff --git a/gliner/onnx/model.py b/gliner/onnx/model.py index 66f5085..1fc7b19 100644 --- a/gliner/onnx/model.py +++ b/gliner/onnx/model.py @@ -370,6 +370,7 @@ def forward( ) return outputs + class UniEncoderTokenRelexORTModel(BaseORTModel): """ONNX Runtime model for uni-encoder token-level relation extraction. 
class UniEncoderTokenDecoderConfig(UniEncoderSpanDecoderConfig):
    """Configuration for the uni-encoder token-level model with a label decoder."""

    def __init__(self, token_loss_coef=1.0, span_loss_coef=1.0, neg_spans_ratio=1.0, **kwargs):
        """Initialize UniEncoderTokenDecoderConfig.

        Args:
            token_loss_coef (float, optional): Weight applied to the token-level
                loss when it is combined with the span loss. Defaults to 1.0.
            span_loss_coef (float, optional): Weight applied to the span-level
                loss. Defaults to 1.0.
            neg_spans_ratio (float, optional): Stored on the config as-is;
                presumably the ratio of sampled negative spans — confirm
                against the processor that consumes it. Defaults to 1.0.
            **kwargs: Forwarded to ``UniEncoderSpanDecoderConfig``.
        """
        super().__init__(**kwargs)
        # This variant is always token-level and always builds span
        # representations, whatever the caller passed through kwargs.
        self.span_mode = "token_level"
        # Must match the "gliner_uni_encoder_token_decoder" key under which
        # this class is registered, otherwise a saved config cannot be
        # resolved back to this class.
        self.model_type = "gliner_uni_encoder_token_decoder"
        self.token_loss_coef = token_loss_coef
        self.span_loss_coef = span_loss_coef
        self.represent_spans = True
        self.neg_spans_ratio = neg_spans_ratio
diff --git a/gliner/data_processing/processor.py b/gliner/data_processing/processor.py index 472238f..f991ddf 100644 --- a/gliner/data_processing/processor.py +++ b/gliner/data_processing/processor.py @@ -35,8 +35,8 @@ def __init__(self, config, tokenizer, words_splitter): self.words_splitter = WordsSplitter(splitter_type=config.words_splitter_type) else: self.words_splitter = words_splitter - self.ent_token = config.ent_token - self.sep_token = config.sep_token + self.ent_token = getattr(config, 'ent_token', '[ENT]') + self.sep_token = getattr(config, 'sep_token', '[SEP]') # Check if the tokenizer has unk_token and pad_token self._check_and_set_special_tokens(self.transformer_tokenizer) @@ -651,38 +651,7 @@ def _generate_negative_spans(self, positive_spans, num_tokens, num_negatives, ma return negative_spans - def preprocess_example(self, tokens, ner, classes_to_id): - """Preprocess a single example for token-based prediction. - - Args: - tokens: List of token strings. - ner: List of NER annotations as (start, end, label) tuples. - classes_to_id: Mapping from class labels to integer IDs. - - Returns: - Dictionary containing: - - tokens: Token strings - - seq_length: Sequence length - - entities: Original NER annotations - - span_idx: Tensor of entity span indices (if represent_spans=True) - - span_label: Tensor of entity class IDs (if represent_spans=True) - - Warnings: - UserWarning: If sequence length exceeds max_len (gets truncated). 
- """ - # Ensure there is always a token list, even if it's empty - if len(tokens) == 0: - tokens = ["[PAD]"] - - # Limit the length of tokens based on configuration maximum length - max_len = self.config.max_len - if len(tokens) > max_len: - warnings.warn(f"Sentence of length {len(tokens)} has been truncated to {max_len}", stacklevel=2) - tokens = tokens[:max_len] - - num_tokens = len(tokens) - - # Create span representations if configured + def prepare_span_idx(self, ner, classes_to_id, num_tokens): if ner is not None and self.config.represent_spans: span_idx_list = [] span_label_list = [] @@ -715,6 +684,40 @@ def preprocess_example(self, tokens, ner, classes_to_id): span_label = torch.zeros(0, dtype=torch.long) else: span_idx, span_label = None, None + return span_idx, span_label + + def preprocess_example(self, tokens, ner, classes_to_id): + """Preprocess a single example for token-based prediction. + + Args: + tokens: List of token strings. + ner: List of NER annotations as (start, end, label) tuples. + classes_to_id: Mapping from class labels to integer IDs. + + Returns: + Dictionary containing: + - tokens: Token strings + - seq_length: Sequence length + - entities: Original NER annotations + - span_idx: Tensor of entity span indices (if represent_spans=True) + - span_label: Tensor of entity class IDs (if represent_spans=True) + + Warnings: + UserWarning: If sequence length exceeds max_len (gets truncated). 
def prepare_decoder_labels(self, decoder_label_strings):
    """Tokenize label strings for the decoder and attach training targets.

    Args:
        decoder_label_strings: Label texts to tokenize. When empty (or
            ``None``) the single placeholder label ``"other"`` is used so
            the tokenizer always receives at least one sequence.

    Returns:
        The tokenizer output with an extra ``"labels"`` entry: a copy of
        ``input_ids`` in which every padded position (attention mask == 0)
        is replaced by ``-100``.
    """
    if not decoder_label_strings:
        decoder_label_strings = ["other"]

    decoder_tokenized_input = self.decoder_tokenizer(
        decoder_label_strings, return_tensors="pt", truncation=True, padding="longest", add_special_tokens=True
    )
    decoder_input_ids = decoder_tokenized_input["input_ids"]
    decoder_attention_mask = decoder_tokenized_input["attention_mask"]

    # ``masked_fill`` is out-of-place: its result has to be kept, otherwise the
    # padding positions silently stay as real token ids in the targets.
    # ``input_ids`` itself is left untouched.
    decoder_labels = decoder_input_ids.masked_fill(~decoder_attention_mask.bool(), -100)
    decoder_tokenized_input["labels"] = decoder_labels
    return decoder_tokenized_input
class UniEncoderTokenDecoderProcessor(UniEncoderSpanDecoderProcessor, UniEncoderTokenProcessor):
    """Token-level NER processor for the encoder-decoder architecture.

    Token-level (BIO-style) example preparation and labels come from
    ``UniEncoderTokenProcessor``; prompt/decoder tokenization and the decoder
    target helper come from ``UniEncoderSpanDecoderProcessor``.
    """

    def __init__(self, config, tokenizer, words_splitter, decoder_tokenizer):
        """Set up the processor.

        Args:
            config: Configuration object.
            tokenizer: Transformer tokenizer used for the encoder input.
            words_splitter: Word-level tokenizer/splitter.
            decoder_tokenizer: Tokenizer used for the decoder (label generation).
        """
        super().__init__(config, tokenizer, words_splitter, decoder_tokenizer)

    def preprocess_example(self, tokens, ner, classes_to_id):
        """Preprocess one example with the token-level pipeline.

        Args:
            tokens: List of token strings.
            ner: NER annotations as (start, end, label) tuples.
            classes_to_id: Mapping from class labels to integer IDs.

        Returns:
            Whatever ``UniEncoderTokenProcessor.preprocess_example`` returns
            for these arguments.
        """
        # Bypass the MRO on purpose: the token processor's version is wanted.
        return UniEncoderTokenProcessor.preprocess_example(self, tokens, ner, classes_to_id)

    def create_batch_dict(self, batch, class_to_ids, id_to_classes):
        """Assemble the batch dictionary with the token-level pipeline.

        Args:
            batch: List of preprocessed example dictionaries.
            class_to_ids: List of class-to-ID mappings.
            id_to_classes: List of ID-to-class mappings.

        Returns:
            Whatever ``UniEncoderTokenProcessor.create_batch_dict`` returns
            for these arguments.
        """
        return UniEncoderTokenProcessor.create_batch_dict(self, batch, class_to_ids, id_to_classes)

    def create_labels(self, batch, blank=None):
        """Build token-level labels together with decoder targets.

        Args:
            batch: Batch dictionary with ``tokens``, ``entities`` and
                ``classes_to_id`` entries.
            blank: Accepted for signature compatibility; not read here.

        Returns:
            Tuple of (token-level labels, tokenized decoder targets).
        """
        token_labels = UniEncoderTokenProcessor.create_labels(self, batch)

        label_texts = []
        for i, sample_tokens in enumerate(batch["tokens"]):
            label_to_id = batch["classes_to_id"][i]
            annotations = batch["entities"][i]
            limit = len(sample_tokens)

            mode = self.config.decoder_mode
            if mode == "span":
                # One decoder target per kept entity, ordered by (start, end).
                ordered = sorted(annotations, key=lambda ent: (ent[0], ent[1])) if annotations else []
                label_texts.extend(
                    label for _, end, label in ordered if label in label_to_id and end < limit
                )
            elif mode == "prompt":
                # One decoder target per entity type in the prompt.
                label_texts.extend(label_to_id)

        return token_labels, self.prepare_decoder_labels(label_texts)

    def tokenize_and_prepare_labels(self, batch, prepare_labels, *args, **kwargs):
        """Tokenize the batch and, on request, attach every training target.

        Args:
            batch: Batch dictionary with tokens and class mappings.
            prepare_labels: Whether labels should be prepared.
            *args: Unused positional arguments.
            **kwargs: Unused keyword arguments.

        Returns:
            The tokenized encoder input, extended with token labels, optional
            span labels/indices/mask and decoder targets when
            ``prepare_labels`` is true.
        """
        # The random draw happens unconditionally, exactly once per call.
        use_blank = random.uniform(0, 1) < self.config.blank_entity_prob and prepare_labels
        blank = "entity" if use_blank else None

        tokenized_input = UniEncoderSpanDecoderProcessor.tokenize_inputs(
            self, batch["tokens"], batch["classes_to_id"], blank
        )
        if not prepare_labels:
            return tokenized_input

        token_labels, decoder_targets = self.create_labels(batch, blank=blank)
        tokenized_input["labels"] = token_labels

        # Span-level targets exist only when the batch carries span indices.
        if batch.get("span_idx") is not None:
            tokenized_input["span_labels"] = self.create_span_labels(batch)
            tokenized_input["span_idx"] = batch["span_idx"]
            tokenized_input["span_mask"] = batch["span_mask"]

        if decoder_targets is not None:
            tokenized_input["decoder_labels_ids"] = decoder_targets["input_ids"]
            tokenized_input["decoder_labels_mask"] = decoder_targets["attention_mask"]
            tokenized_input["decoder_labels"] = decoder_targets["labels"]

        return tokenized_input
def decode_generative(
    self,
    tokens: List[List[str]],
    id_to_classes: Union[Dict[int, str], List[Dict[int, str]]],
    model_output: torch.Tensor,
    gen_labels: List[str],
    sel_idx: Optional[torch.LongTensor] = None,
    num_gen_sequences: int = 1,
    flat_ner: bool = False,
    threshold: float = 0.5,
    multi_label: bool = False,
    span_logits: Optional[torch.Tensor] = None,
    span_idx: Optional[torch.Tensor] = None,
    span_mask: Optional[torch.Tensor] = None,
) -> List[List[tuple]]:
    """Decode span-level predictions and attach generated labels.

    Two decoder modes are handled:
    - ``"prompt"``: generated labels replace the class names in ``id_to_classes``.
    - ``"span"``: generated labels are looked up per span position and added
      to each span tuple.
    Any other mode yields spans without generated labels.

    Args:
        tokens: Tokenized input text for each sample (not read here).
        id_to_classes: Mapping from class IDs to class names, shared or per sample.
        model_output: Token-level logits; only its batch dimension is used.
            May be ``None``, in which case the batch size is taken from
            ``span_logits``.
        gen_labels: Generated labels from the decoder, flattened across the batch.
        sel_idx: Selected span indices; used in ``"span"`` mode only.
        num_gen_sequences: Number of label sequences generated per span.
        flat_ner: Whether to enforce non-overlapping spans.
        threshold: Probability threshold for keeping a (span, class) pair.
        multi_label: Whether to allow multiple labels per span.
        span_logits: Span classification logits, indexed (batch, span, class).
        span_idx: Span [start, end] indices, indexed (batch, span, 2).
        span_mask: Validity mask over spans, indexed (batch, span).

    Returns:
        For each sample, the result of ``greedy_search`` over tuples
        ``(start, end, entity_type, generated_label, score)``.

    Raises:
        ValueError: If ``span_logits``, ``span_idx`` or ``span_mask`` is missing.
    """
    # The span tensors default to None in the signature but are mandatory for
    # this decoding path; fail with a clear message rather than an opaque
    # AttributeError further down.
    if span_logits is None or span_idx is None or span_mask is None:
        raise ValueError("decode_generative requires span_logits, span_idx and span_mask")

    batch_size = span_logits.size(0)
    B = model_output.size(0) if model_output is not None else batch_size

    # Default: no generated label for any span.
    span_label_maps = None
    if self.config.decoder_mode == "prompt":
        id_to_classes = self._update_id_to_classes_with_generated(id_to_classes, gen_labels, B)
    elif self.config.decoder_mode == "span" and sel_idx is not None:
        span_label_maps = self._build_span_label_map_for_batch(sel_idx, gen_labels, num_gen_sequences)
    if span_label_maps is None:
        span_label_maps = [{} for _ in range(B)]

    span_probs = torch.sigmoid(span_logits)
    # One mask for every (span, class) pair that is both valid and confident;
    # avoids a host sync per element inside the loops below.
    keep = (span_probs > threshold) & span_mask.bool().unsqueeze(-1)

    spans = []
    for i in range(batch_size):
        id_to_class_i = self._get_id_to_class_for_sample(id_to_classes, i)
        span_label_map_i = span_label_maps[i]

        # nonzero() yields pairs ordered by span position, then class index,
        # which is the order the candidates are handed to greedy_search in.
        hits = keep[i].nonzero(as_tuple=False).tolist()
        bounds = span_idx[i].tolist()
        probs_i = span_probs[i].tolist()

        span_scores = []
        for span_pos, class_idx in hits:
            class_id = class_idx + 1  # class IDs are 1-based
            if class_id not in id_to_class_i:
                continue
            span_start, span_end = bounds[span_pos][0], bounds[span_pos][1]
            span_scores.append(
                (
                    span_start,
                    span_end,
                    id_to_class_i[class_id],
                    span_label_map_i.get(span_pos),
                    probs_i[span_pos][class_idx],
                )
            )

        spans.append(self.greedy_search(span_scores, flat_ner, multi_label))

    return spans
+ """ + # Use generative decoding if labels_decoder is configured and gen_labels provided + if self.config.labels_decoder is not None and gen_labels is not None: + return self.decode_generative( + tokens=tokens, + id_to_classes=id_to_classes, + model_output=model_output, + gen_labels=gen_labels, + sel_idx=sel_idx, + num_gen_sequences=num_gen_sequences, + flat_ner=flat_ner, + threshold=threshold, + multi_label=multi_label, + span_logits=span_logits, + span_idx=span_idx, + span_mask=span_mask, + ) + + # Fall back to standard decoding without generative labels + return super().decode( + tokens=tokens, + id_to_classes=id_to_classes, + model_output=model_output, + flat_ner=flat_ner, + threshold=threshold, + multi_label=multi_label, + span_logits=span_logits, + span_idx=span_idx, + span_mask=span_mask, + ) \ No newline at end of file diff --git a/gliner/model.py b/gliner/model.py index 3568169..a978ffb 100644 --- a/gliner/model.py +++ b/gliner/model.py @@ -33,8 +33,16 @@ UniEncoderSpanRelexConfig, UniEncoderTokenRelexConfig, UniEncoderSpanDecoderConfig, + UniEncoderTokenDecoderConfig ) -from .decoding import SpanDecoder, TokenDecoder, SpanRelexDecoder, TokenRelexDecoder, SpanGenerativeDecoder +from .decoding import ( + SpanDecoder, + TokenDecoder, + SpanRelexDecoder, + TokenRelexDecoder, + SpanGenerativeDecoder, + TokenGenerativeDecoder + ) from .training import Trainer, TrainingArguments from .evaluation import BaseNEREvaluator, BaseRelexEvaluator from .onnx.model import ( @@ -57,6 +65,7 @@ UniEncoderSpanRelexModel, UniEncoderTokenRelexModel, UniEncoderSpanDecoderModel, + UniEncoderTokenDecoderModel ) from .data_processing import ( BaseProcessor, @@ -65,6 +74,7 @@ UniEncoderSpanProcessor, UniEncoderTokenProcessor, UniEncoderSpanDecoderProcessor, + UniEncoderTokenDecoderProcessor, RelationExtractionSpanProcessor, RelationExtractionTokenProcessor, ) @@ -74,6 +84,7 @@ UniEncoderSpanDataCollator, UniEncoderTokenDataCollator, UniEncoderSpanDecoderDataCollator, + 
UniEncoderTokenDecoderDataCollator, RelationExtractionSpanDataCollator, RelationExtractionTokenDataCollator, ) @@ -2112,6 +2123,21 @@ def export_to_onnx( ) +class UniEncoderTokenDecoderGLiNER(UniEncoderSpanDecoderGLiNER): + """GLiNER model with token-based encoding and label decoding capabilities. + + Combines token-level BIO tagging with a decoder that generates entity type + labels autoregressively. + """ + + config_class = UniEncoderTokenDecoderConfig + model_class = UniEncoderTokenDecoderModel + ort_model_class = None + data_processor_class = UniEncoderTokenDecoderProcessor + data_collator_class = UniEncoderTokenDecoderDataCollator + decoder_class = TokenGenerativeDecoder + + class UniEncoderSpanRelexGLiNER(BaseEncoderGLiNER): """GLiNER model for both entity recognition and relation extraction. @@ -2759,21 +2785,12 @@ def __init__(self, config: Union[str, Path, GLiNERConfig], **kwargs): @staticmethod def _get_gliner_class(config: GLiNERConfig): - """Determine the appropriate GLiNER class based on configuration. - - Args: - config: GLiNER configuration object. - - Returns: - The appropriate GLiNER class type. 
- """ + """Determine the appropriate GLiNER class based on configuration.""" is_token_level = config.span_mode == "token_level" has_labels_encoder = config.labels_encoder is not None has_labels_decoder = config.labels_decoder is not None has_relations = config.relations_layer is not None - # Priority order: relations > decoder > bi-encoder > token vs span - if has_relations: if is_token_level: return UniEncoderTokenRelexGLiNER @@ -2787,6 +2804,8 @@ def _get_gliner_class(config: GLiNERConfig): "Using decoder model (labels_encoder will be ignored).", stacklevel=2, ) + if is_token_level: + return UniEncoderTokenDecoderGLiNER return UniEncoderSpanDecoderGLiNER if has_labels_encoder: @@ -2795,7 +2814,6 @@ def _get_gliner_class(config: GLiNERConfig): else: return BiEncoderSpanGLiNER - # Default: uni-encoder if is_token_level: return UniEncoderTokenGLiNER else: diff --git a/gliner/modeling/__init__.py b/gliner/modeling/__init__.py index 2e3095f..e0d5a22 100644 --- a/gliner/modeling/__init__.py +++ b/gliner/modeling/__init__.py @@ -7,5 +7,7 @@ UniEncoderSpanModel, UniEncoderTokenModel, UniEncoderSpanRelexModel, + UniEncoderTokenRelexModel, UniEncoderSpanDecoderModel, + UniEncoderTokenDecoderModel ) diff --git a/gliner/modeling/base.py b/gliner/modeling/base.py index b85d9cd..cb8ce36 100644 --- a/gliner/modeling/base.py +++ b/gliner/modeling/base.py @@ -521,16 +521,24 @@ def __init__( dropout=config.dropout, ) - def get_span_logits(self, scores, span_idx, span_mask, words_embedding, prompts_embedding, labels, threshold): - span_logits = None - if getattr(self.config, "represent_spans", False): - if span_idx is None: - span_idx, span_mask = extract_spans_from_tokens(scores, labels, threshold) - span_idx = span_idx * span_mask.unsqueeze(-1).long() + def get_span_representations(self, scores, span_idx, span_mask, words_embedding, labels, threshold): + if span_idx is None: + span_idx, span_mask = extract_spans_from_tokens(scores, labels, threshold) + span_idx = span_idx * 
span_mask.unsqueeze(-1).long() + if getattr(self.config, "represent_spans", False): span_rep = self.span_rep_layer(words_embedding, span_idx) - span_logits = torch.einsum("BND,BCD->BNC", span_rep, prompts_embedding) - return span_logits, span_idx, span_mask + else: + span_rep = words_embedding[span_idx[:,:, 0]] + words_embedding[span_idx[:,:, 1]] + + # span_rep = torch.zeros(B, S, D, device=words_embedding.device, dtype=words_embedding.dtype) + # for b in range(B): + # for s in range(S): + # if span_mask[b, s]: + # start, end = span_idx[b, s] + # span_rep[b, s] = words_embedding[b, start:end+1].mean(dim=0) + + return span_rep, span_idx, span_mask def forward( self, @@ -593,10 +601,13 @@ def forward( # Shape: (batch_size, seq_len, num_classes, 3), 3 - start, end, inside scores = self.scorer(words_embedding, prompts_embedding) - span_logits, span_idx, span_mask = self.get_span_logits( - scores, span_idx, span_mask, words_embedding, prompts_embedding, labels, threshold - ) - + if getattr(self.config, "represent_spans", False): + span_rep, span_idx, span_mask = self.get_span_representations( + scores, span_idx, span_mask, words_embedding, labels, threshold + ) + span_logits = torch.einsum("BND,BCD->BNC", span_rep, prompts_embedding) + else: + span_logits, span_idx, span_mask = None, None, None loss = None if labels is not None: loss = self.loss(scores, labels, prompts_embedding_mask, mask, **kwargs) @@ -1070,9 +1081,13 @@ def forward( scores = self.scorer(words_embedding, prompts_embedding) - span_logits, span_idx, span_mask = self.get_span_logits( - scores, span_idx, span_mask, words_embedding, prompts_embedding, labels, threshold - ) + if getattr(self.config, "represent_spans", False): + span_rep, span_idx, span_mask = self.get_span_representations( + scores, span_idx, span_mask, words_embedding, labels, threshold + ) + span_logits = torch.einsum("BND,BCD->BNC", span_rep, prompts_embedding) + else: + span_logits, span_idx, span_mask = None, None, None loss = None if 
labels is not None: @@ -1119,6 +1134,9 @@ def __init__( cache_dir: Directory for caching pretrained models. """ super().__init__(config, from_pretrained, cache_dir) + self.__init_decoder__(config, from_pretrained, cache_dir) + + def __init_decoder__(self, config, from_pretrained, cache_dir): self.decoder = Decoder(config, from_pretrained, cache_dir=cache_dir) if self.config.hidden_size != self.decoder.decoder_hidden_size: self._enc2dec_proj = create_projection_layer( @@ -1586,6 +1604,356 @@ def loss( return loss +class UniEncoderTokenDecoderModel(UniEncoderTokenModel, UniEncoderSpanDecoderModel): + """Token-based NER model with encoder-decoder architecture. + + This model combines token-level BIO-style classification with a decoder + that generates entity type labels autoregressively, enabling more flexible + prediction strategies for token-level NER tasks. + + Inherits from: + - UniEncoderTokenModel: Token-level BIO tagging for entities + - UniEncoderSpanDecoderModel: Encoder-decoder architecture and decoder utilities + + Attributes: + scorer (Scorer): Scoring layer for computing token-label compatibility. + decoder (Decoder): Decoder module for label generation. + span_rep_layer (SpanRepLayer): Layer for computing span representations (if represent_spans=True). + _enc2dec_proj (Optional[nn.Module]): Projection layer if encoder and decoder + dimensions differ. + """ + + def __init__( + self, config: Any, from_pretrained: bool = False, cache_dir: Optional[Union[str, Path]] = None + ) -> None: + """Initialize the token-level encoder-decoder model. + + Args: + config: Model configuration object. + from_pretrained: Whether to load from pretrained weights. + cache_dir: Directory for caching pretrained models. 
+ """ + # Initialize through UniEncoderTokenModel's chain to get scorer + super().__init__(config, from_pretrained, cache_dir) + self.__init_decoder__(config, from_pretrained, cache_dir) + + + def select_token_decoder_embedding( + self, + prompts_embedding: torch.Tensor, + prompts_embedding_mask: torch.Tensor, + span_logits: Optional[torch.Tensor] = None, + span_rep: Optional[torch.Tensor] = None, + span_idx: Optional[torch.Tensor] = None, + span_mask: Optional[torch.Tensor] = None, + span_labels: Optional[torch.Tensor] = None, + decoder_text_embeds: Optional[torch.Tensor] = None, + decoder_words_mask: Optional[torch.Tensor] = None, + top_k: Optional[int] = None, + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + """Select entity embeddings for decoder input based on token predictions or labels. + + This method extracts entity spans from token-level predictions and prepares + their representations for the decoder. It can operate in two modes: + 1. "prompt" mode: Uses entity type embeddings as decoder input. + 2. "span" mode: Uses contextualized tokens within each detected span as decoder input. + + Args: + prompts_embedding: Entity type embeddings of shape (B, C, D). + prompts_embedding_mask: Mask for prompts of shape (B, C). + words_embedding: Word-level embeddings of shape (B, W, D). + token_scores: Token classification scores of shape (B, W, C, 3). + word_mask: Mask for valid words of shape (B, W). + span_rep: Span representations (B, S, D). + span_idx: Pre-computed span indices of shape (B, S, 2). + span_mask: Pre-computed span mask of shape (B, S). + decoder_text_embeds: Text embeddings for span mode of shape (B, T, D). + decoder_words_mask: Word position mask of shape (B, T). + labels: Ground truth token-level labels of shape (B, W, C, 3). + threshold: Confidence threshold for selecting spans. + top_k: Optional limit on number of spans to select. 
+ + Returns: + Tuple containing: + - span_rep_kept: Selected span embeddings for decoder of shape (B, S, T, D) + or None if no valid spans. + - span_msk: Mask for selected spans of shape (B, S, T) or None. + - span_sel_idx: Original indices of selected spans of shape (B, S) or None. + """ + if self.config.decoder_mode == "prompt": + return self.select_decoder_embedding(prompts_embedding, prompts_embedding_mask)[:3] + + span_scores = torch.sigmoid(span_logits) + + # Flatten span representations for selection + B = span_rep.size(0) + + if span_labels is not None: + # During training: select spans where any class is positive (label == 1) + span_prob = span_labels.max(-1).values # (B, S) + keep = (span_prob == 1) & span_mask.bool() + else: + # During inference: use predicted scores + span_scores = torch.sigmoid(span_logits).max(-1).values # (B, S) + keep = (span_scores > 0.5) & span_mask.bool + + if top_k: + sel_scores = span_scores.masked_fill(~keep, -1.0) + top_idx = sel_scores.topk(k=min(top_k, sel_scores.size(1)), dim=1).indices + keep.zero_() + keep.scatter_(1, top_idx, True) + + # Pack valid spans + span_rep_kept, span_msk, span_sel_idx = self.select_decoder_embedding(span_rep, keep.long()) + + if hasattr(self, "_enc2dec_proj"): + span_rep_kept = self._enc2dec_proj(span_rep_kept) + + span_rep_kept = span_rep_kept.unsqueeze(2) + span_msk = span_msk.unsqueeze(-1) + + if decoder_text_embeds is None or decoder_words_mask is None: + return span_rep_kept, span_msk, span_sel_idx + + if span_rep_kept.numel() == 0: + return None, None, None + + decoder_text_embeds = decoder_text_embeds.to(dtype=span_rep_kept.dtype) + + # Build span representations with context tokens + S_kept = span_rep_kept.shape[1] + dec_D = span_rep_kept.shape[-1] + + # Get actual span boundaries from selected indices + batch_indices = torch.arange(B, device=span_idx.device).unsqueeze(1).expand(B, S_kept) + selected_spans = span_idx[batch_indices, span_sel_idx] # (B, S_kept, 2) + + span_start = 
selected_spans[:, :, 0] # (B, S_kept) + span_end = selected_spans[:, :, 1] # (B, S_kept) + + # Determine which decoder tokens belong to each span + token_in_span = (decoder_words_mask.unsqueeze(1) >= span_start.unsqueeze(-1)) & ( + decoder_words_mask.unsqueeze(1) <= span_end.unsqueeze(-1) + ) + + tokens_per_span = token_in_span.sum(-1) + max_tokens = int(tokens_per_span.max()) + + span_rep_new = span_rep_kept.new_zeros(B, S_kept, max_tokens + 1, dec_D) + span_rep_mask = torch.zeros(B, S_kept, max_tokens + 1, dtype=torch.bool, device=decoder_text_embeds.device) + + left_offset = (max_tokens + 1 - tokens_per_span).clamp(min=0) + pos_in_span = (token_in_span.cumsum(-1) - 1).masked_fill(~token_in_span, 0) + pos_in_span = pos_in_span + left_offset.unsqueeze(-1) + + b_idx, s_idx, tok_idx = torch.where(token_in_span) + span_rep_new[b_idx, s_idx, pos_in_span[b_idx, s_idx, tok_idx]] = decoder_text_embeds[b_idx, tok_idx] + span_rep_mask[b_idx, s_idx, pos_in_span[b_idx, s_idx, tok_idx]] = True + + kept_pos = (left_offset - 1).clamp(min=0) + + b_flat = torch.arange(B, device=decoder_text_embeds.device).view(-1, 1).expand(B, S_kept).reshape(-1) + s_flat = torch.arange(S_kept, device=decoder_text_embeds.device).view(1, -1).expand(B, S_kept).reshape(-1) + t_flat = kept_pos.reshape(-1) + + span_rep_new[b_flat, s_flat, t_flat] = span_rep_kept.reshape(-1, dec_D) + span_rep_mask[b_flat, s_flat, t_flat] = True + span_rep_mask = span_rep_mask & span_msk.bool() + + return span_rep_new, span_rep_mask, span_sel_idx + + def forward( + self, + input_ids: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.FloatTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + decoder_labels_ids: Optional[torch.FloatTensor] = None, + decoder_labels_mask: Optional[torch.LongTensor] = None, + decoder_words_mask: Optional[torch.LongTensor] = None, + words_embedding: Optional[torch.FloatTensor] = None, + mask: 
Optional[torch.LongTensor] = None, + span_idx: Optional[torch.Tensor] = None, + span_mask: Optional[torch.Tensor] = None, + span_labels: Optional[torch.Tensor] = None, + prompts_embedding: Optional[torch.FloatTensor] = None, + prompts_embedding_mask: Optional[torch.LongTensor] = None, + words_mask: Optional[torch.LongTensor] = None, + text_lengths: Optional[torch.Tensor] = None, + labels: Optional[torch.FloatTensor] = None, + decoder_labels: Optional[torch.FloatTensor] = None, + threshold: Optional[float] = 0.5, + **kwargs: Any, + ) -> GLiNERDecoderOutput: + """Forward pass through the token-level encoder-decoder model. + + Args: + input_ids: Input token IDs of shape (B, L). + attention_mask: Attention mask of shape (B, L). + decoder_input_ids: Decoder input IDs for span mode. + decoder_attention_mask: Decoder attention mask. + decoder_labels_ids: Label token IDs for decoding of shape (M, L). + decoder_labels_mask: Mask for decoder labels of shape (M, L). + decoder_words_mask: Word position mask for span mode. + words_embedding: Pre-computed word embeddings. + mask: Mask for words. + span_idx: Pre-computed span indices of shape (B, S, 2). + span_mask: Pre-computed span mask of shape (B, S). + span_labels: Ground truth span labels of shape (B, S, C). + prompts_embedding: Pre-computed entity type embeddings. + prompts_embedding_mask: Mask for prompts. + words_mask: Word boundary mask. + text_lengths: Length of each text sequence. + labels: Ground truth token-level labels of shape (B, W, C, 3). + decoder_labels: Ground truth decoder labels of shape (M, L). + threshold: Confidence threshold for span selection. + **kwargs: Additional arguments. + + Returns: + GLiNERDecoderOutput containing logits, losses, and decoder information. 
+ """ + encoder_kwargs = {key: kwargs[key] for key in ("packing_config", "pair_attention_mask") if key in kwargs} + + prompts_embedding, prompts_embedding_mask, words_embedding, mask = self.get_representations( + input_ids, attention_mask, text_lengths, words_mask, **encoder_kwargs + ) + + if labels is not None: + target_W = labels.shape[1] + words_embedding, mask = self._fit_length(words_embedding, mask, target_W) + + target_C = prompts_embedding.size(1) + target_C = max(target_C, labels.size(-2)) + + prompts_embedding, prompts_embedding_mask = self._fit_length( + prompts_embedding, prompts_embedding_mask, target_C + ) + + # Token-level classification: (B, W, C, 3) + scores = self.scorer(words_embedding, prompts_embedding) + + # Get span representations and logits if represent_spans is enabled + span_rep, span_idx, span_mask = self.get_span_representations( + scores, span_idx, span_mask, words_embedding, labels, threshold + ) + span_logits = torch.einsum("BND,BCD->BNC", span_rep, prompts_embedding) + + # Decoder processing + decoder_embedding = decoder_mask = decoder_loss = decoder_span_idx = None + if hasattr(self, "decoder"): + if self.config.decoder_mode == "span": + decoder_text_embeds = self.decoder.ids_to_embeds(decoder_input_ids) + else: + decoder_text_embeds = None + + decoder_embedding, decoder_mask, decoder_span_idx = self.select_token_decoder_embedding( + prompts_embedding, + prompts_embedding_mask, + span_logits=span_logits, + span_rep=span_rep, + span_idx=span_idx, + span_mask=span_mask, + span_labels=span_labels, + decoder_text_embeds=decoder_text_embeds, + decoder_words_mask=decoder_words_mask, + ) + + if decoder_labels is not None: + decoder_loss, _ = self.decode_labels( + decoder_embedding, decoder_mask, decoder_labels_ids, decoder_labels_mask, decoder_labels + ) + + # Compute loss + loss = None + if labels is not None: + loss = self.loss( + scores, labels, prompts_embedding_mask, mask, + span_logits=span_logits, span_labels=span_labels, 
span_mask=span_mask, + decoder_loss=decoder_loss, **kwargs + ) + + output = GLiNERDecoderOutput( + logits=scores, + loss=loss, + decoder_loss=decoder_loss, + prompts_embedding=prompts_embedding, + prompts_embedding_mask=prompts_embedding_mask, + decoder_embedding=decoder_embedding, + decoder_embedding_mask=decoder_mask, + decoder_span_idx=decoder_span_idx, + words_embedding=words_embedding, + mask=mask, + span_idx=span_idx, + span_logits=span_logits, + span_mask=span_mask, + ) + return output + + def loss( + self, + scores: torch.Tensor, + labels: torch.Tensor, + prompts_embedding_mask: torch.Tensor, + word_mask: torch.Tensor, + alpha: float = -1.0, + gamma: float = 0.0, + prob_margin: float = 0.0, + label_smoothing: float = 0.0, + reduction: str = "sum", + negatives: float = 1.0, + span_logits: Optional[torch.Tensor] = None, + span_labels: Optional[torch.Tensor] = None, + span_mask: Optional[torch.Tensor] = None, + decoder_loss: Optional[torch.Tensor] = None, + **kwargs: Any, + ) -> torch.Tensor: + """Compute combined loss for token classification, spans, and decoder. + + Args: + scores: Predicted token scores of shape (B, W, C, 3). + labels: Ground truth token labels of shape (B, W, C, 3). + prompts_embedding_mask: Mask for valid entity types of shape (B, C). + word_mask: Mask for valid words of shape (B, W). + alpha: Focal loss alpha parameter. + gamma: Focal loss gamma parameter. + prob_margin: Margin for probability adjustment. + label_smoothing: Label smoothing factor. + reduction: Loss reduction method ('sum' or 'mean'). + negatives: Negative sampling probability. + span_logits: Optional span logits of shape (B, S, C). + span_labels: Optional span labels of shape (B, S, C). + span_mask: Optional span mask of shape (B, S). + decoder_loss: Optional decoder loss to combine. + **kwargs: Additional arguments. + + Returns: + Scalar combined loss tensor. 
+ """ + # Token-level loss (use parent's loss function) + token_loss = UniEncoderTokenModel.loss( + self, scores, labels, prompts_embedding_mask, word_mask, + alpha, gamma, prob_margin, label_smoothing, reduction, negatives, **kwargs + ) + + # Combine with span loss if available + if span_logits is not None and span_labels is not None and span_mask is not None: + span_loss = UniEncoderTokenModel.loss( + self, span_logits, span_labels, prompts_embedding_mask, span_mask, + alpha, gamma, prob_margin, label_smoothing, reduction, negatives, **kwargs + ) + token_loss = self.config.token_loss_coef * token_loss + self.config.span_loss_coef * span_loss + + # Combine with decoder loss if available + if decoder_loss is not None: + total_loss = ( + decoder_loss * self.config.decoder_loss_coef + + token_loss * getattr(self.config, 'token_loss_coef', 1.0) + ) + return total_loss + + return token_loss + + class UniEncoderSpanRelexModel(UniEncoderSpanModel): """Span-based NER model with relation extraction capabilities. diff --git a/gliner/modeling/outputs.py b/gliner/modeling/outputs.py index a48614d..1bd8efc 100644 --- a/gliner/modeling/outputs.py +++ b/gliner/modeling/outputs.py @@ -71,7 +71,6 @@ class GLiNERDecoderOutput(GLiNERBaseOutput): decoder_embedding_mask: Optional[torch.LongTensor] = None decoder_span_idx: Optional[torch.LongTensor] = None - @dataclass class GLiNERRelexOutput(GLiNERBaseOutput): """Output class for GLiNER models with relation extraction. 
From 213bc6f50f1afbd490119f65e7d824103504c4a1 Mon Sep 17 00:00:00 2001 From: Ingvar Date: Fri, 9 Jan 2026 14:55:57 +0200 Subject: [PATCH 6/8] fix token generative architecture --- gliner/config.py | 4 +- gliner/data_processing/__init__.py | 6 +- gliner/data_processing/processor.py | 68 ++++++------- gliner/decoding/__init__.py | 9 +- gliner/decoding/decoder.py | 5 +- gliner/model.py | 24 ++--- gliner/modeling/__init__.py | 2 +- gliner/modeling/base.py | 151 ++++++++++++++++------------ gliner/modeling/outputs.py | 1 + 9 files changed, 151 insertions(+), 119 deletions(-) diff --git a/gliner/config.py b/gliner/config.py index 5fe59a7..e74a349 100644 --- a/gliner/config.py +++ b/gliner/config.py @@ -177,9 +177,7 @@ def __init__( class UniEncoderTokenDecoderConfig(UniEncoderSpanDecoderConfig): - def __init__( - self, token_loss_coef=1.0, span_loss_coef=1.0, neg_spans_ratio=1.0, **kwargs - ): + def __init__(self, token_loss_coef=1.0, span_loss_coef=1.0, neg_spans_ratio=1.0, **kwargs): super().__init__(**kwargs) self.span_mode = "token_level" self.model_type = "gliner_encoder_token_decoder" diff --git a/gliner/data_processing/__init__.py b/gliner/data_processing/__init__.py index fcd9e8b..4cd98f4 100644 --- a/gliner/data_processing/__init__.py +++ b/gliner/data_processing/__init__.py @@ -4,9 +4,9 @@ UniEncoderSpanDataCollator, UniEncoderTokenDataCollator, UniEncoderSpanDecoderDataCollator, - UniEncoderTokenDecoderDataCollator, RelationExtractionSpanDataCollator, - RelationExtractionTokenDataCollator + UniEncoderTokenDecoderDataCollator, + RelationExtractionTokenDataCollator, ) from .processor import ( BaseProcessor, @@ -16,8 +16,8 @@ UniEncoderSpanProcessor, UniEncoderTokenProcessor, UniEncoderSpanDecoderProcessor, - UniEncoderTokenDecoderProcessor, RelationExtractionSpanProcessor, + UniEncoderTokenDecoderProcessor, RelationExtractionTokenProcessor, ) from .tokenizer import WordsSplitter diff --git a/gliner/data_processing/processor.py 
b/gliner/data_processing/processor.py index f991ddf..3226bea 100644 --- a/gliner/data_processing/processor.py +++ b/gliner/data_processing/processor.py @@ -35,8 +35,8 @@ def __init__(self, config, tokenizer, words_splitter): self.words_splitter = WordsSplitter(splitter_type=config.words_splitter_type) else: self.words_splitter = words_splitter - self.ent_token = getattr(config, 'ent_token', '[ENT]') - self.sep_token = getattr(config, 'sep_token', '[SEP]') + self.ent_token = getattr(config, "ent_token", "[ENT]") + self.sep_token = getattr(config, "sep_token", "[SEP]") # Check if the tokenizer has unk_token and pad_token self._check_and_set_special_tokens(self.transformer_tokenizer) @@ -685,7 +685,7 @@ def prepare_span_idx(self, ner, classes_to_id, num_tokens): else: span_idx, span_label = None, None return span_idx, span_label - + def preprocess_example(self, tokens, ner, classes_to_id): """Preprocess a single example for token-based prediction. @@ -1149,7 +1149,7 @@ def prepare_decoder_labels(self, decoder_label_strings): decoder_labels.masked_fill(~decoder_attention_mask.bool(), -100) decoder_tokenized_input["labels"] = decoder_labels return decoder_tokenized_input - + def create_labels(self, batch, blank=None): """Create labels for both span classification and decoder generation. @@ -1247,19 +1247,19 @@ def tokenize_and_prepare_labels(self, batch, prepare_labels, *args, **kwargs): class UniEncoderTokenDecoderProcessor(UniEncoderSpanDecoderProcessor, UniEncoderTokenProcessor): """Processor for token-based NER with encoder-decoder architecture. - + This processor combines token-level BIO-style classification with a decoder that generates entity type labels autoregressively, enabling more flexible prediction strategies for token-level NER tasks. 
- + Inherits from: - UniEncoderSpanDecoderProcessor: Encoder-decoder architecture and decoder utilities - UniEncoderTokenProcessor: Token-level BIO tagging for entities """ - + def __init__(self, config, tokenizer, words_splitter, decoder_tokenizer): """Initialize the token-level encoder-decoder processor. - + Args: config: Configuration object. tokenizer: Transformer tokenizer for encoding. @@ -1271,15 +1271,15 @@ def __init__(self, config, tokenizer, words_splitter, decoder_tokenizer): def preprocess_example(self, tokens, ner, classes_to_id): """Preprocess a single example for token-level encoder-decoder prediction. - + Uses token-level preprocessing from UniEncoderTokenProcessor while preparing for decoder-based label generation. - + Args: tokens: List of token strings. ner: List of NER annotations as (start, end, label) tuples. classes_to_id: Mapping from class labels to integer IDs. - + Returns: Dictionary containing: - tokens: Token strings @@ -1287,38 +1287,38 @@ def preprocess_example(self, tokens, ner, classes_to_id): - entities: Original NER annotations - span_idx: Tensor of entity span indices (if represent_spans=True) - span_label: Tensor of entity class IDs (if represent_spans=True) - + Warnings: UserWarning: If sequence length exceeds max_len (gets truncated). """ # Use token processor's preprocessing return UniEncoderTokenProcessor.preprocess_example(self, tokens, ner, classes_to_id) - + def create_batch_dict(self, batch, class_to_ids, id_to_classes): """Create a batch dictionary from preprocessed token examples. - + Args: batch: List of preprocessed example dictionaries. class_to_ids: List of class-to-ID mappings. id_to_classes: List of ID-to-class mappings. - + Returns: Dictionary containing all batch data for token-level encoder-decoder processing. 
""" # Use token processor's batch dict creation return UniEncoderTokenProcessor.create_batch_dict(self, batch, class_to_ids, id_to_classes) - + def create_labels(self, batch, blank=None): """Create labels for both token classification and decoder generation. - + Creates both token-level BIO labels and decoder generation labels for entity types. - + Args: batch: Batch dictionary containing tokens, entities, and class mappings. blank: Optional blank entity token for zero-shot scenarios. - + Returns: Tuple containing: - Token-level labels (BIO-style, shape: [batch_size, seq_len, num_classes, 3]) @@ -1326,15 +1326,15 @@ def create_labels(self, batch, blank=None): """ # Create token-level labels token_labels = UniEncoderTokenProcessor.create_labels(self, batch) - + # Create decoder labels decoder_label_strings = [] - + for i in range(len(batch["tokens"])): tokens = batch["tokens"][i] classes_to_id = batch["classes_to_id"][i] ner = batch["entities"][i] - + num_tokens = len(tokens) if self.config.decoder_mode == "span": # Collect entity labels in order of appearance @@ -1345,23 +1345,23 @@ def create_labels(self, batch, blank=None): elif self.config.decoder_mode == "prompt": # Use all entity types as decoder labels decoder_label_strings.extend(list(classes_to_id)) - + decoder_tokenized_input = self.prepare_decoder_labels(decoder_label_strings) - + return token_labels, decoder_tokenized_input - + def tokenize_and_prepare_labels(self, batch, prepare_labels, *args, **kwargs): """Tokenize inputs and prepare labels for token-level encoder-decoder training. - + Combines token-level input processing with decoder inputs and prepares both token-level BIO labels and decoder generation labels. - + Args: batch: Batch dictionary with tokens and class mappings. prepare_labels: Whether to prepare labels. *args: Variable length argument list. **kwargs: Arbitrary keyword arguments. - + Returns: Dictionary containing encoder inputs, decoder inputs, token-level labels, and decoder labels. 
@@ -1369,32 +1369,32 @@ def tokenize_and_prepare_labels(self, batch, prepare_labels, *args, **kwargs): blank = None if random.uniform(0, 1) < self.config.blank_entity_prob and prepare_labels: blank = "entity" - + # Use span decoder's tokenize_inputs for encoder-decoder tokenization tokenized_input = UniEncoderSpanDecoderProcessor.tokenize_inputs( self, batch["tokens"], batch["classes_to_id"], blank ) - + if prepare_labels: # Create both token-level and decoder labels token_labels, decoder_tokenized_input = self.create_labels(batch, blank=blank) tokenized_input["labels"] = token_labels - + # Add span-level one-hot labels if spans are represented if batch.get("span_idx") is not None: span_labels = self.create_span_labels(batch) tokenized_input["span_labels"] = span_labels tokenized_input["span_idx"] = batch["span_idx"] tokenized_input["span_mask"] = batch["span_mask"] - + # Add decoder labels if decoder_tokenized_input is not None: tokenized_input["decoder_labels_ids"] = decoder_tokenized_input["input_ids"] tokenized_input["decoder_labels_mask"] = decoder_tokenized_input["attention_mask"] tokenized_input["decoder_labels"] = decoder_tokenized_input["labels"] - + return tokenized_input - + class RelationExtractionSpanProcessor(UniEncoderSpanProcessor): """Processor for joint entity and relation extraction. diff --git a/gliner/decoding/__init__.py b/gliner/decoding/__init__.py index 24e4a71..a0caaeb 100644 --- a/gliner/decoding/__init__.py +++ b/gliner/decoding/__init__.py @@ -1 +1,8 @@ -from .decoder import SpanDecoder, TokenDecoder, SpanRelexDecoder, TokenRelexDecoder, SpanGenerativeDecoder, TokenGenerativeDecoder +from .decoder import ( + SpanDecoder, + TokenDecoder, + SpanRelexDecoder, + TokenRelexDecoder, + SpanGenerativeDecoder, + TokenGenerativeDecoder, +) diff --git a/gliner/decoding/decoder.py b/gliner/decoding/decoder.py index fce0e95..f973026 100644 --- a/gliner/decoding/decoder.py +++ b/gliner/decoding/decoder.py @@ -560,7 +560,7 @@ class IDs to class names. 
multi_label=multi_label, ) - + class SpanRelexDecoder(BaseSpanDecoder): """Span decoder with relation extraction support. @@ -1329,7 +1329,6 @@ class IDs to class names. else: span_label_maps = [{} for _ in range(B)] - batch_size = span_logits.size(0) spans = [] @@ -1430,4 +1429,4 @@ def decode( span_logits=span_logits, span_idx=span_idx, span_mask=span_mask, - ) \ No newline at end of file + ) diff --git a/gliner/model.py b/gliner/model.py index a978ffb..939ff22 100644 --- a/gliner/model.py +++ b/gliner/model.py @@ -33,16 +33,16 @@ UniEncoderSpanRelexConfig, UniEncoderTokenRelexConfig, UniEncoderSpanDecoderConfig, - UniEncoderTokenDecoderConfig + UniEncoderTokenDecoderConfig, ) from .decoding import ( - SpanDecoder, - TokenDecoder, - SpanRelexDecoder, - TokenRelexDecoder, - SpanGenerativeDecoder, - TokenGenerativeDecoder - ) + SpanDecoder, + TokenDecoder, + SpanRelexDecoder, + TokenRelexDecoder, + SpanGenerativeDecoder, + TokenGenerativeDecoder, +) from .training import Trainer, TrainingArguments from .evaluation import BaseNEREvaluator, BaseRelexEvaluator from .onnx.model import ( @@ -65,7 +65,7 @@ UniEncoderSpanRelexModel, UniEncoderTokenRelexModel, UniEncoderSpanDecoderModel, - UniEncoderTokenDecoderModel + UniEncoderTokenDecoderModel, ) from .data_processing import ( BaseProcessor, @@ -74,8 +74,8 @@ UniEncoderSpanProcessor, UniEncoderTokenProcessor, UniEncoderSpanDecoderProcessor, - UniEncoderTokenDecoderProcessor, RelationExtractionSpanProcessor, + UniEncoderTokenDecoderProcessor, RelationExtractionTokenProcessor, ) from .data_processing.collator import ( @@ -84,8 +84,8 @@ UniEncoderSpanDataCollator, UniEncoderTokenDataCollator, UniEncoderSpanDecoderDataCollator, - UniEncoderTokenDecoderDataCollator, RelationExtractionSpanDataCollator, + UniEncoderTokenDecoderDataCollator, RelationExtractionTokenDataCollator, ) from .data_processing.tokenizer import WordsSplitter @@ -2137,7 +2137,7 @@ class UniEncoderTokenDecoderGLiNER(UniEncoderSpanDecoderGLiNER): 
data_collator_class = UniEncoderTokenDecoderDataCollator decoder_class = TokenGenerativeDecoder - + class UniEncoderSpanRelexGLiNER(BaseEncoderGLiNER): """GLiNER model for both entity recognition and relation extraction. diff --git a/gliner/modeling/__init__.py b/gliner/modeling/__init__.py index e0d5a22..8b45611 100644 --- a/gliner/modeling/__init__.py +++ b/gliner/modeling/__init__.py @@ -9,5 +9,5 @@ UniEncoderSpanRelexModel, UniEncoderTokenRelexModel, UniEncoderSpanDecoderModel, - UniEncoderTokenDecoderModel + UniEncoderTokenDecoderModel, ) diff --git a/gliner/modeling/base.py b/gliner/modeling/base.py index cb8ce36..57bc2f2 100644 --- a/gliner/modeling/base.py +++ b/gliner/modeling/base.py @@ -529,7 +529,7 @@ def get_span_representations(self, scores, span_idx, span_mask, words_embedding, if getattr(self.config, "represent_spans", False): span_rep = self.span_rep_layer(words_embedding, span_idx) else: - span_rep = words_embedding[span_idx[:,:, 0]] + words_embedding[span_idx[:,:, 1]] + span_rep = words_embedding[span_idx[:, :, 0]] + words_embedding[span_idx[:, :, 1]] # span_rep = torch.zeros(B, S, D, device=words_embedding.device, dtype=words_embedding.dtype) # for b in range(B): @@ -607,7 +607,7 @@ def forward( ) span_logits = torch.einsum("BND,BCD->BNC", span_rep, prompts_embedding) else: - span_logits, span_idx, span_mask = None, None, None + span_logits, span_idx, span_mask = None, None, None loss = None if labels is not None: loss = self.loss(scores, labels, prompts_embedding_mask, mask, **kwargs) @@ -1049,6 +1049,7 @@ def forward( words_mask: Word boundary mask. text_lengths: Length of each text sequence. labels: Ground truth labels of shape (B, W, C). + threshold: float value for filtering spans. **kwargs: Additional arguments. Returns: @@ -1606,15 +1607,15 @@ def loss( class UniEncoderTokenDecoderModel(UniEncoderTokenModel, UniEncoderSpanDecoderModel): """Token-based NER model with encoder-decoder architecture. 
- + This model combines token-level BIO-style classification with a decoder that generates entity type labels autoregressively, enabling more flexible prediction strategies for token-level NER tasks. - + Inherits from: - UniEncoderTokenModel: Token-level BIO tagging for entities - UniEncoderSpanDecoderModel: Encoder-decoder architecture and decoder utilities - + Attributes: scorer (Scorer): Scoring layer for computing token-label compatibility. decoder (Decoder): Decoder module for label generation. @@ -1622,12 +1623,12 @@ class UniEncoderTokenDecoderModel(UniEncoderTokenModel, UniEncoderSpanDecoderMod _enc2dec_proj (Optional[nn.Module]): Projection layer if encoder and decoder dimensions differ. """ - + def __init__( self, config: Any, from_pretrained: bool = False, cache_dir: Optional[Union[str, Path]] = None ) -> None: """Initialize the token-level encoder-decoder model. - + Args: config: Model configuration object. from_pretrained: Whether to load from pretrained weights. @@ -1637,7 +1638,6 @@ def __init__( super().__init__(config, from_pretrained, cache_dir) self.__init_decoder__(config, from_pretrained, cache_dir) - def select_token_decoder_embedding( self, prompts_embedding: torch.Tensor, @@ -1646,33 +1646,35 @@ def select_token_decoder_embedding( span_rep: Optional[torch.Tensor] = None, span_idx: Optional[torch.Tensor] = None, span_mask: Optional[torch.Tensor] = None, - span_labels: Optional[torch.Tensor] = None, + span_labels: Optional[torch.Tensor] = None, decoder_text_embeds: Optional[torch.Tensor] = None, decoder_words_mask: Optional[torch.Tensor] = None, top_k: Optional[int] = None, ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: """Select entity embeddings for decoder input based on token predictions or labels. - + This method extracts entity spans from token-level predictions and prepares their representations for the decoder. It can operate in two modes: 1. 
"prompt" mode: Uses entity type embeddings as decoder input. 2. "span" mode: Uses contextualized tokens within each detected span as decoder input. - + Args: prompts_embedding: Entity type embeddings of shape (B, C, D). prompts_embedding_mask: Mask for prompts of shape (B, C). words_embedding: Word-level embeddings of shape (B, W, D). token_scores: Token classification scores of shape (B, W, C, 3). word_mask: Mask for valid words of shape (B, W). + span_logits: Span-level classification logits of shape (B, S, C). span_rep: Span representations (B, S, D). span_idx: Pre-computed span indices of shape (B, S, 2). + span_labels: Ground truth span labels of shape (B, S, C). span_mask: Pre-computed span mask of shape (B, S). decoder_text_embeds: Text embeddings for span mode of shape (B, T, D). decoder_words_mask: Word position mask of shape (B, T). labels: Ground truth token-level labels of shape (B, W, C, 3). threshold: Confidence threshold for selecting spans. top_k: Optional limit on number of spans to select. 
- + Returns: Tuple containing: - span_rep_kept: Selected span embeddings for decoder of shape (B, S, T, D) @@ -1682,12 +1684,12 @@ def select_token_decoder_embedding( """ if self.config.decoder_mode == "prompt": return self.select_decoder_embedding(prompts_embedding, prompts_embedding_mask)[:3] - + span_scores = torch.sigmoid(span_logits) # Flatten span representations for selection B = span_rep.size(0) - + if span_labels is not None: # During training: select spans where any class is positive (label == 1) span_prob = span_labels.max(-1).values # (B, S) @@ -1705,63 +1707,63 @@ def select_token_decoder_embedding( # Pack valid spans span_rep_kept, span_msk, span_sel_idx = self.select_decoder_embedding(span_rep, keep.long()) - + if hasattr(self, "_enc2dec_proj"): span_rep_kept = self._enc2dec_proj(span_rep_kept) - + span_rep_kept = span_rep_kept.unsqueeze(2) span_msk = span_msk.unsqueeze(-1) - + if decoder_text_embeds is None or decoder_words_mask is None: return span_rep_kept, span_msk, span_sel_idx - + if span_rep_kept.numel() == 0: return None, None, None - + decoder_text_embeds = decoder_text_embeds.to(dtype=span_rep_kept.dtype) - + # Build span representations with context tokens S_kept = span_rep_kept.shape[1] dec_D = span_rep_kept.shape[-1] - + # Get actual span boundaries from selected indices batch_indices = torch.arange(B, device=span_idx.device).unsqueeze(1).expand(B, S_kept) selected_spans = span_idx[batch_indices, span_sel_idx] # (B, S_kept, 2) - + span_start = selected_spans[:, :, 0] # (B, S_kept) - span_end = selected_spans[:, :, 1] # (B, S_kept) - + span_end = selected_spans[:, :, 1] # (B, S_kept) + # Determine which decoder tokens belong to each span token_in_span = (decoder_words_mask.unsqueeze(1) >= span_start.unsqueeze(-1)) & ( decoder_words_mask.unsqueeze(1) <= span_end.unsqueeze(-1) ) - + tokens_per_span = token_in_span.sum(-1) max_tokens = int(tokens_per_span.max()) - + span_rep_new = span_rep_kept.new_zeros(B, S_kept, max_tokens + 1, dec_D) 
span_rep_mask = torch.zeros(B, S_kept, max_tokens + 1, dtype=torch.bool, device=decoder_text_embeds.device) - + left_offset = (max_tokens + 1 - tokens_per_span).clamp(min=0) pos_in_span = (token_in_span.cumsum(-1) - 1).masked_fill(~token_in_span, 0) pos_in_span = pos_in_span + left_offset.unsqueeze(-1) - + b_idx, s_idx, tok_idx = torch.where(token_in_span) span_rep_new[b_idx, s_idx, pos_in_span[b_idx, s_idx, tok_idx]] = decoder_text_embeds[b_idx, tok_idx] span_rep_mask[b_idx, s_idx, pos_in_span[b_idx, s_idx, tok_idx]] = True - + kept_pos = (left_offset - 1).clamp(min=0) - + b_flat = torch.arange(B, device=decoder_text_embeds.device).view(-1, 1).expand(B, S_kept).reshape(-1) s_flat = torch.arange(S_kept, device=decoder_text_embeds.device).view(1, -1).expand(B, S_kept).reshape(-1) t_flat = kept_pos.reshape(-1) - + span_rep_new[b_flat, s_flat, t_flat] = span_rep_kept.reshape(-1, dec_D) span_rep_mask[b_flat, s_flat, t_flat] = True span_rep_mask = span_rep_mask & span_msk.bool() - + return span_rep_new, span_rep_mask, span_sel_idx - + def forward( self, input_ids: Optional[torch.FloatTensor] = None, @@ -1786,7 +1788,7 @@ def forward( **kwargs: Any, ) -> GLiNERDecoderOutput: """Forward pass through the token-level encoder-decoder model. - + Args: input_ids: Input token IDs of shape (B, L). attention_mask: Attention mask of shape (B, L). @@ -1808,30 +1810,30 @@ def forward( decoder_labels: Ground truth decoder labels of shape (M, L). threshold: Confidence threshold for span selection. **kwargs: Additional arguments. - + Returns: GLiNERDecoderOutput containing logits, losses, and decoder information. 
""" encoder_kwargs = {key: kwargs[key] for key in ("packing_config", "pair_attention_mask") if key in kwargs} - + prompts_embedding, prompts_embedding_mask, words_embedding, mask = self.get_representations( input_ids, attention_mask, text_lengths, words_mask, **encoder_kwargs ) - + if labels is not None: target_W = labels.shape[1] words_embedding, mask = self._fit_length(words_embedding, mask, target_W) - + target_C = prompts_embedding.size(1) target_C = max(target_C, labels.size(-2)) - + prompts_embedding, prompts_embedding_mask = self._fit_length( prompts_embedding, prompts_embedding_mask, target_C ) - + # Token-level classification: (B, W, C, 3) scores = self.scorer(words_embedding, prompts_embedding) - + # Get span representations and logits if represent_spans is enabled span_rep, span_idx, span_mask = self.get_span_representations( scores, span_idx, span_mask, words_embedding, labels, threshold @@ -1845,7 +1847,7 @@ def forward( decoder_text_embeds = self.decoder.ids_to_embeds(decoder_input_ids) else: decoder_text_embeds = None - + decoder_embedding, decoder_mask, decoder_span_idx = self.select_token_decoder_embedding( prompts_embedding, prompts_embedding_mask, @@ -1857,21 +1859,27 @@ def forward( decoder_text_embeds=decoder_text_embeds, decoder_words_mask=decoder_words_mask, ) - + if decoder_labels is not None: decoder_loss, _ = self.decode_labels( decoder_embedding, decoder_mask, decoder_labels_ids, decoder_labels_mask, decoder_labels ) - + # Compute loss loss = None if labels is not None: loss = self.loss( - scores, labels, prompts_embedding_mask, mask, - span_logits=span_logits, span_labels=span_labels, span_mask=span_mask, - decoder_loss=decoder_loss, **kwargs + scores, + labels, + prompts_embedding_mask, + mask, + span_logits=span_logits, + span_labels=span_labels, + span_mask=span_mask, + decoder_loss=decoder_loss, + **kwargs, ) - + output = GLiNERDecoderOutput( logits=scores, loss=loss, @@ -1888,7 +1896,7 @@ def forward( span_mask=span_mask, ) return 
output - + def loss( self, scores: torch.Tensor, @@ -1908,7 +1916,7 @@ def loss( **kwargs: Any, ) -> torch.Tensor: """Compute combined loss for token classification, spans, and decoder. - + Args: scores: Predicted token scores of shape (B, W, C, 3). labels: Ground truth token labels of shape (B, W, C, 3). @@ -1925,34 +1933,53 @@ def loss( span_mask: Optional span mask of shape (B, S). decoder_loss: Optional decoder loss to combine. **kwargs: Additional arguments. - + Returns: Scalar combined loss tensor. """ # Token-level loss (use parent's loss function) token_loss = UniEncoderTokenModel.loss( - self, scores, labels, prompts_embedding_mask, word_mask, - alpha, gamma, prob_margin, label_smoothing, reduction, negatives, **kwargs + self, + scores, + labels, + prompts_embedding_mask, + word_mask, + alpha, + gamma, + prob_margin, + label_smoothing, + reduction, + negatives, + **kwargs, ) - + # Combine with span loss if available if span_logits is not None and span_labels is not None and span_mask is not None: span_loss = UniEncoderTokenModel.loss( - self, span_logits, span_labels, prompts_embedding_mask, span_mask, - alpha, gamma, prob_margin, label_smoothing, reduction, negatives, **kwargs + self, + span_logits, + span_labels, + prompts_embedding_mask, + span_mask, + alpha, + gamma, + prob_margin, + label_smoothing, + reduction, + negatives, + **kwargs, ) token_loss = self.config.token_loss_coef * token_loss + self.config.span_loss_coef * span_loss - + # Combine with decoder loss if available if decoder_loss is not None: - total_loss = ( - decoder_loss * self.config.decoder_loss_coef + - token_loss * getattr(self.config, 'token_loss_coef', 1.0) + total_loss = decoder_loss * self.config.decoder_loss_coef + token_loss * getattr( + self.config, "token_loss_coef", 1.0 ) return total_loss - + return token_loss - + class UniEncoderSpanRelexModel(UniEncoderSpanModel): """Span-based NER model with relation extraction capabilities. 
diff --git a/gliner/modeling/outputs.py b/gliner/modeling/outputs.py index 1bd8efc..a48614d 100644 --- a/gliner/modeling/outputs.py +++ b/gliner/modeling/outputs.py @@ -71,6 +71,7 @@ class GLiNERDecoderOutput(GLiNERBaseOutput): decoder_embedding_mask: Optional[torch.LongTensor] = None decoder_span_idx: Optional[torch.LongTensor] = None + @dataclass class GLiNERRelexOutput(GLiNERBaseOutput): """Output class for GLiNER models with relation extraction. From 3d61898bb8ffa153cb4c20324449eaa863b3adfa Mon Sep 17 00:00:00 2001 From: Ingvar Date: Fri, 9 Jan 2026 16:08:36 +0200 Subject: [PATCH 7/8] updates docs on architectures --- docs/architectures.md | 220 ++++++++++++++++++++++++++++++++++++++++++ train.py | 5 +- 2 files changed, 222 insertions(+), 3 deletions(-) diff --git a/docs/architectures.md b/docs/architectures.md index d5a29ca..4b57f18 100644 --- a/docs/architectures.md +++ b/docs/architectures.md @@ -21,7 +21,9 @@ The GLiNER framework now supports multiple architecture variants, each optimized | **BiEncoderSpan** | Separate encoders for text & labels | Span-level | Pre-compute label embeddings, handles 100+ entity types | Many entity types, production deployment | | **BiEncoderToken** | Separate encoders for text & labels | Token-level | Combines bi-encoder efficiency with token-level prediction | Long entities with many types | | **UniEncoderSpanDecoder** | Single encoder + generative decoder | Span-level with generation | Generates entity labels, open vocabulary | Open-domain NER, label discovery | +| **UniEncoderTokenDecoder** | Single encoder + generative decoder | Token-level with generation | Token-level detection + label generation | Long entities with open vocabulary | | **UniEncoderSpanRelex** | Single encoder + relation layers | Span-level + relations | Joint entity and relation extraction | Knowledge graph construction, IE | +| **UniEncoderTokenRelex** | Single encoder + relation layers | Token-level + relations | Token-level entities with relation 
extraction | Long entities with relations | The framework automatically selects the appropriate architecture based on your model configuration, providing a unified API across all variants. @@ -243,6 +245,95 @@ for entity in entities[0]: print(f" Generated: {entity['generated_labels']}") ``` +## GLiNER Token-Level with Generative Decoder (UniEncoderTokenDecoder) + +The UniEncoderTokenDecoder architecture combines token-level BIO tagging with a generative decoder, offering the best of both worlds: the ability to handle long entity spans (from token-level prediction) and open-vocabulary entity typing (from the generative decoder). + +### Architecture Overview + +The architecture extends UniEncoderToken with decoder capabilities: +1. **Token Encoder**: Standard GLiNER token-based encoder with BIO tagging for entity boundary detection +2. **Span Representation Layer**: Converts detected token sequences into span representations +3. **Generative Decoder**: GPT-2 or similar decoder that generates entity type labels for detected spans + +### Key Differences from UniEncoderSpanDecoder + +| Aspect | UniEncoderSpanDecoder | UniEncoderTokenDecoder | +|--------|----------------------|------------------------| +| Entity Detection | Span enumeration (max width 12) | Token-level BIO tagging | +| Long Entities | Limited by max span width | No length limitation | +| Computation | O(n × max_width) spans | O(n) tokens | +| Best For | Standard NER entities | Long-form extraction + label generation | + +### Architecture Details + +**Token-Level Detection**: +- Uses the Scorer module to compute token-label compatibility scores +- Produces three logits per token per class: start, inside, end (BIO-style) +- Entities are extracted by finding contiguous sequences of positive predictions + +**Span Representation**: +- Detected token sequences are converted to span representations +- Optional `represent_spans` mode uses a dedicated SpanRepLayer for richer representations +- Span 
representations are then fed to the decoder for label generation + +**Training Objective**: +- Multi-component loss: + - Token classification loss (`token_loss_coef`) + - Span classification loss (`span_loss_coef`) + - Decoder generation loss (`decoder_loss_coef`) + +### Configuration + +```python +from gliner import GLiNERConfig + +config = GLiNERConfig( + model_name="microsoft/deberta-v3-small", + span_mode="token_level", + labels_decoder="gpt2", # Enables decoder + decoder_mode="span", # or "prompt" + token_loss_coef=1.0, + span_loss_coef=1.0, + decoder_loss_coef=0.5, + represent_spans=True, # Use SpanRepLayer for span representations +) +``` + +### Use Cases + +- **Long entity extraction with open vocabulary**: Extract multi-sentence entities while generating appropriate labels +- **Extractive summarization with typing**: Identify summary-worthy spans and label their semantic type +- **Document-level entity typing**: Handle entities that span multiple clauses or sentences +- **Flexible entity annotation**: Combine precise boundary detection with descriptive type generation + +### Example Usage + +```python +from gliner import GLiNER + +# Load a token-level decoder model +model = GLiNER.from_pretrained("knowledgator/gliner-token-decoder-v1.0") + +text = """The Paris Agreement, adopted in December 2015 at the 21st Conference +of the Parties to the United Nations Framework Convention on Climate Change, +represents a landmark international accord on climate action.""" + +# Extract long entities with generated labels +entities = model.inference( + [text], + labels=["entity"], + gen_constraints=["agreement", "organization", "event", "topic"], + num_gen_sequences=1 +) + +for entity in entities[0]: + print(f"Entity: {entity['text'][:50]}...") + print(f" Label: {entity['label']}") + if 'generated_labels' in entity: + print(f" Generated: {entity['generated_labels']}") +``` + ## GLiNER for Relation Extraction (UniEncoderSpanRelex) The UniEncoderSpanRelex architecture 
extends GLiNER to perform joint entity and relation extraction, enabling the model to identify both entities and the relationships between them in a single forward pass. @@ -361,6 +452,108 @@ for relation in relations[0]: } ``` +## GLiNER Token-Level for Relation Extraction (UniEncoderTokenRelex) + +The UniEncoderTokenRelex architecture combines token-level entity detection with relation extraction capabilities, enabling joint extraction of long-form entities and their relationships. + +### Architecture Overview + +The architecture extends UniEncoderToken with relation extraction components: + +1. **Token Encoder**: Standard GLiNER token-based encoder with BIO tagging +2. **Span Representation Layer**: Converts detected token sequences into entity representations +3. **Relation Representation Layer**: Computes pairwise entity representations and adjacency predictions +4. **Relation Classification Layer**: Classifies relation types between entity pairs + +### Key Differences from UniEncoderSpanRelex + +| Aspect | UniEncoderSpanRelex | UniEncoderTokenRelex | +|--------|---------------------|----------------------| +| Entity Detection | Span enumeration | Token-level BIO tagging | +| Long Entities | Limited by max span width | No length limitation | +| Entity Boundaries | Explicit span indices | Derived from token predictions | +| Best For | Standard entity-relation extraction | Long entities with relations | + +### Architecture Details + +**Entity Detection**: +- Uses the Scorer module for token-level classification (start, inside, end) +- Entity spans are extracted from contiguous positive predictions +- No maximum span width limitation + +**Relation Extraction**: +- Same relation layers as UniEncoderSpanRelex +- Relations computed between entity representations derived from token sequences +- Supports both concatenation and triple scoring methods + +**Training Objective**: +- Multi-component loss: + - Token classification loss (for entity boundaries) + - Adjacency 
loss (for entity pair connectivity) + - Relation classification loss (for relation types) + +### Configuration + +```python +from gliner import GLiNERConfig + +config = GLiNERConfig( + model_name="microsoft/deberta-v3-small", + span_mode="token_level", + relations_layer="biaffine", # or "concat" + triples_layer="TransE", # Optional: for triple scoring + span_loss_coef=1.0, + adjacency_loss_coef=1.0, + relation_loss_coef=1.0, +) +``` + +### Use Cases + +- **Scientific IE**: Extract long entity mentions (chemical compounds, gene names) and their relationships +- **Legal Document Analysis**: Identify parties, clauses, and their legal relationships +- **Medical Record Processing**: Extract symptoms, treatments, and their clinical relationships +- **News Event Extraction**: Identify event participants and their roles across long descriptions + +### Example Usage + +```python +from gliner import GLiNER + +# Load a token-level relation extraction model +model = GLiNER.from_pretrained("knowledgator/gliner-token-relex-v1.0") + +text = """The Phase III clinical trial conducted by Pfizer and BioNTech +demonstrated that the BNT162b2 vaccine achieved 95% efficacy against +COVID-19 in participants without prior infection.""" + +# Define entity and relation types +entity_labels = ["organization", "vaccine", "disease", "metric"] +relation_labels = ["developed_by", "effective_against", "measured_as"] + +# Extract entities and relations +entities, relations = model.inference( + [text], + labels=entity_labels, + relations=relation_labels, + threshold=0.5, + relation_threshold=0.5 +) + +# Display results +print("Entities:") +for entity in entities[0]: + print(f" {entity['text']} ({entity['label']})") + +print("\nRelations:") +for relation in relations[0]: + print(f" {relation['head']['text']} --[{relation['relation']}]--> {relation['tail']['text']}") +``` + +### Output Format + +Same as UniEncoderSpanRelex - entities and relations are returned in identical formats, ensuring API 
consistency across architectures. + ## Choosing the Right Architecture Here's a quick guide to selecting the appropriate GLiNER architecture: @@ -372,9 +565,36 @@ Here's a quick guide to selecting the appropriate GLiNER architecture: | Many entity types (50-200+) | BiEncoderSpan or BiEncoderToken | Pre-compute labels, handles many types | | Production with fixed schema | BiEncoder variants | Cache label embeddings for speed | | Open-domain, unknown types | UniEncoderSpanDecoder | Generate labels on-the-fly | +| Long entities + open vocabulary | UniEncoderTokenDecoder | Token-level detection with label generation | | Knowledge graph extraction | UniEncoderSpanRelex | Joint entity and relation extraction | +| Long entities with relations | UniEncoderTokenRelex | Token-level entities with relation extraction | | Both long entities + many types | BiEncoderToken | Combines both advantages | +### Decision Flowchart + +``` +Start + │ + ├─ Need relation extraction? + │ ├─ Yes: Long entities expected? + │ │ ├─ Yes → UniEncoderTokenRelex + │ │ └─ No → UniEncoderSpanRelex + │ │ + │ └─ No: Need open vocabulary labels? + │ ├─ Yes: Long entities expected? + │ │ ├─ Yes → UniEncoderTokenDecoder + │ │ └─ No → UniEncoderSpanDecoder + │ │ + │ └─ No: Many entity types (>30)? + │ ├─ Yes: Long entities expected? + │ │ ├─ Yes → BiEncoderToken + │ │ └─ No → BiEncoderSpan + │ │ + │ └─ No: Long entities expected? + │ ├─ Yes → UniEncoderToken + │ └─ No → UniEncoderSpan +``` + ## References - The Intro section is based on the [Shahrukh Khan](https://www.linkedin.com/in/shahrukhx01/) article [Illustrated GLINER](https://medium.com/@shahrukhx01/illustrated-gliner-e6971e4c8c52) and placed into documentation with consent of the author. - [Urchade Zaratiana, Nadi Tomeh, Pierre Holat, and Thierry Charnois. 2023. 
Gliner: Generalist model for named entity recognition using bidirectional transformer](https://arxiv.org/abs/2311.08526) diff --git a/train.py b/train.py index 0e98d18..7e449aa 100644 --- a/train.py +++ b/train.py @@ -48,7 +48,7 @@ def main(cfg_path: str): # Build model model = build_model(model_cfg, train_cfg) print(f"Model type: {model.__class__.__name__}") - + # Get freeze components freeze_components = train_cfg.get("freeze_components", None) if freeze_components: @@ -56,7 +56,7 @@ def main(cfg_path: str): # Train print("\nStarting training...") - trainer = model.train_model( + model.train_model( train_dataset=train_dataset, eval_dataset=eval_dataset, output_dir="models", @@ -87,7 +87,6 @@ def main(cfg_path: str): freeze_components=freeze_components, ) - trainer.save_model() print(f"\n✓ Training complete! Model saved to {output_dir}") From 2f27ea6641ca915245fc89fe57b80f90d0eb63ee Mon Sep 17 00:00:00 2001 From: Ingvarstep Date: Sun, 11 Jan 2026 16:35:35 +0000 Subject: [PATCH 8/8] update training and model configs --- configs/config_biencoder.yaml | 1 + configs/config_decoder.yaml | 5 ++-- configs/config_relex.yaml | 1 + configs/config_token.yaml | 5 ++-- gliner/config.py | 48 +++++++++++++++-------------------- gliner/model.py | 6 +++++ 6 files changed, 35 insertions(+), 31 deletions(-) diff --git a/configs/config_biencoder.yaml b/configs/config_biencoder.yaml index 617c742..967ff02 100644 --- a/configs/config_biencoder.yaml +++ b/configs/config_biencoder.yaml @@ -7,6 +7,7 @@ model: hidden_size: 768 dropout: 0.3 fine_tune: true + represent_spans: false subtoken_pooling: first fuse_layers: false post_fusion_schema: null # e.g., "l2l-l2t-t2t" diff --git a/configs/config_decoder.yaml b/configs/config_decoder.yaml index a149201..887283b 100644 --- a/configs/config_decoder.yaml +++ b/configs/config_decoder.yaml @@ -10,6 +10,7 @@ model: fine_tune: true subtoken_pooling: first fuse_layers: false + represent_spans: true post_fusion_schema: null # e.g., "l2l-l2t-t2t" 
full_decoder_context: true span_mode: markerV0 # Options: token_level | markerV0 @@ -23,7 +24,7 @@ model: data: # Directory Paths root_dir: gliner_logs - train_data: "data.json" + train_data: "data/data.json" val_data_dir: "none" # Set to validation data path or "none" training: @@ -33,7 +34,7 @@ training: # Training Parameters num_steps: 15000 train_batch_size: 8 - eval_every: 500 + eval_every: 10 warmup_ratio: 0.05 scheduler_type: "cosine" # Options: linear, cosine, constant diff --git a/configs/config_relex.yaml b/configs/config_relex.yaml index 7ce7950..fc8aeb1 100644 --- a/configs/config_relex.yaml +++ b/configs/config_relex.yaml @@ -8,6 +8,7 @@ model: fine_tune: true subtoken_pooling: first fuse_layers: false + represent_spans: false post_fusion_schema: null # e.g., "l2l-l2t-t2t" span_mode: markerV0 # Options: token_level | markerV0 max_types: 100 diff --git a/configs/config_token.yaml b/configs/config_token.yaml index 8dd1b36..19ed07a 100644 --- a/configs/config_token.yaml +++ b/configs/config_token.yaml @@ -8,6 +8,7 @@ model: fine_tune: true subtoken_pooling: first fuse_layers: false + represent_spans: false post_fusion_schema: null # e.g., "l2l-l2t-t2t" span_mode: token_level # Options: token_level | markerV0 max_types: 100 @@ -17,7 +18,7 @@ model: data: # Directory Paths root_dir: gliner_logs - train_data: "data.json" + train_data: "data/data.json" val_data_dir: "none" # Set to validation data path or "none" training: @@ -27,7 +28,7 @@ training: # Training Parameters num_steps: 15000 train_batch_size: 8 - eval_every: 500 + eval_every: 10 warmup_ratio: 0.05 scheduler_type: "cosine" # Options: linear, cosine, constant diff --git a/gliner/config.py b/gliner/config.py index e74a349..d57e103 100644 --- a/gliner/config.py +++ b/gliner/config.py @@ -35,6 +35,10 @@ def __init__( ent_token: str = "<>", sep_token: str = "<>", _attn_implementation: Optional[str] = None, + token_loss_coef: float = 1.0, + span_loss_coef: float = 1.0, + represent_spans: bool = False, + 
neg_spans_ratio: float = 1.0, **kwargs, ): """Initialize BaseGLiNERConfig. @@ -64,6 +68,10 @@ def __init__( ent_token (str, optional): Entity marker token. Defaults to "<>". sep_token (str, optional): Separator token. Defaults to "<>". _attn_implementation (str, optional): Attention implementation. Defaults to None. + token_loss_coef (float, optional): Token loss coefficient. Defaults to 1.0. + span_loss_coef (float, optional): Span loss coefficient. Defaults to 1.0. + represent_spans (bool, optional): Whether to represent spans. Defaults to False. + neg_spans_ratio (float, optional): Ratio of negative spans. Defaults to 1.0. **kwargs: Additional keyword arguments passed to parent class. """ super().__init__(**kwargs) @@ -96,6 +104,10 @@ def __init__( self.ent_token = ent_token self.sep_token = sep_token self._attn_implementation = _attn_implementation + self.token_loss_coef = token_loss_coef + self.span_loss_coef = span_loss_coef + self.represent_spans = represent_spans + self.neg_spans_ratio = neg_spans_ratio class UniEncoderConfig(BaseGLiNERConfig): @@ -119,16 +131,10 @@ def __init__(self, **kwargs): class UniEncoderTokenConfig(UniEncoderConfig): """Configuration for uni-encoder token-based GLiNER model.""" - def __init__( - self, represent_spans: bool = False, token_loss_coef=1.0, span_loss_coef=1.0, neg_spans_ratio=1.0, **kwargs - ): + def __init__(self, **kwargs): super().__init__(**kwargs) self.span_mode = "token_level" self.model_type = "gliner_uni_encoder_token" - self.token_loss_coef = token_loss_coef - self.span_loss_coef = span_loss_coef - self.represent_spans = represent_spans - self.neg_spans_ratio = neg_spans_ratio class UniEncoderSpanDecoderConfig(UniEncoderConfig): @@ -142,7 +148,6 @@ def __init__( blank_entity_prob: float = 0.1, labels_decoder_config: Optional[dict] = None, decoder_loss_coef=0.5, - span_loss_coef=0.5, **kwargs, ): """Initialize UniEncoderSpanDecoderConfig. 
@@ -154,7 +159,6 @@ def __init__( blank_entity_prob (float, optional): Probability of blank entities. Defaults to 0.1. labels_decoder_config (dict, optional): Decoder config dict. Defaults to None. decoder_loss_coef (float, optional): Decoder loss coefficient. Defaults to 0.5. - span_loss_coef (float, optional): Span loss coefficient. Defaults to 0.5. **kwargs: Additional keyword arguments passed to UniEncoderConfig. Raises: @@ -172,19 +176,15 @@ def __init__( self.decoder_mode = decoder_mode # 'prompt' or 'span' self.full_decoder_context = full_decoder_context self.decoder_loss_coef = decoder_loss_coef - self.span_loss_coef = span_loss_coef self.model_type = "gliner_uni_encoder_span_decoder" class UniEncoderTokenDecoderConfig(UniEncoderSpanDecoderConfig): - def __init__(self, token_loss_coef=1.0, span_loss_coef=1.0, neg_spans_ratio=1.0, **kwargs): + def __init__(self, **kwargs): super().__init__(**kwargs) self.span_mode = "token_level" self.model_type = "gliner_encoder_token_decoder" - self.token_loss_coef = token_loss_coef - self.span_loss_coef = span_loss_coef - self.represent_spans = True - self.neg_spans_ratio = neg_spans_ratio + self.represent_spans = True # hardcoded to True for token decoder class UniEncoderRelexConfig(UniEncoderConfig): @@ -195,7 +195,6 @@ def __init__( embed_rel_token: bool = True, rel_token_index: int = -1, rel_token: str = "<>", - span_loss_coef=1.0, adjacency_loss_coef=1.0, relation_loss_coef=1.0, **kwargs, @@ -210,7 +209,6 @@ def __init__( embed_rel_token (bool, optional): Whether to embed relation tokens. Defaults to True. rel_token_index (int, optional): Index of relation token. Defaults to -1. rel_token (str, optional): Relation marker token. Defaults to "<>". - span_loss_coef (float, optional): Span representaton loss coefficient. Defaults to 1.0. adjacency_loss_coef (float, optional): Adjacency modeling loss coefficient. Defaults to 1.0. relation_loss_coef (float, optional): Relation representaton loss coefficient. 
Defaults to 1.0. **kwargs: Additional keyword arguments passed to UniEncoderConfig. @@ -225,7 +223,6 @@ def __init__( self.embed_rel_token = embed_rel_token self.rel_token_index = rel_token_index self.rel_token = rel_token - self.span_loss_coef = span_loss_coef self.adjacency_loss_coef = adjacency_loss_coef self.relation_loss_coef = relation_loss_coef @@ -284,16 +281,10 @@ def __init__(self, **kwargs): class BiEncoderTokenConfig(BiEncoderConfig): """Configuration for bi-encoder token-based GLiNER model.""" - def __init__( - self, represent_spans: bool = False, token_loss_coef=1.0, span_loss_coef=1.0, neg_spans_ratio=1.0, **kwargs - ): + def __init__(self, **kwargs): super().__init__(**kwargs) self.span_mode = "token_level" self.model_type = "gliner_bi_encoder_token" - self.token_loss_coef = token_loss_coef - self.span_loss_coef = span_loss_coef - self.represent_spans = represent_spans - self.neg_spans_ratio = neg_spans_ratio class GLiNERConfig(BaseGLiNERConfig): @@ -333,7 +324,10 @@ def __init__( def model_type(self): """Auto-detect model type based on configuration.""" if self.labels_decoder: - return "gliner_uni_encoder_span_decoder" + if self.span_mode == 'token-level': + return "gliner_uni_encoder_token_decoder" + else: + return "gliner_uni_encoder_span_decoder" elif self.labels_encoder: return "gliner_bi_encoder_span" if self.span_mode != "token-level" else "gliner_bi_encoder_token" elif self.relations_layer is not None: @@ -363,4 +357,4 @@ def model_type(self): "gliner_bi_encoder_span": BiEncoderSpanConfig, "gliner_bi_encoder_token": BiEncoderTokenConfig, } -) +) \ No newline at end of file diff --git a/gliner/model.py b/gliner/model.py index 939ff22..8a2e55e 100644 --- a/gliner/model.py +++ b/gliner/model.py @@ -3060,6 +3060,11 @@ def model_map(self) -> dict[str, dict[str, Any]]: "description": "Span-based NER with label generation decoder", "config": {"span_mode": "span_level", "labels_decoder": "required", "relations_layer": None}, }, + 
"gliner_uni_encoder_token_decoder": { + "class": UniEncoderTokenDecoderGLiNER, + "description": "Token-level NER with label generation decoder", + "config": {"span_mode": "token_level", "labels_decoder": "required", "relations_layer": None}, + }, "gliner_uni_encoder_span_relex": { "class": UniEncoderSpanRelexGLiNER, "description": "Joint entity and relation extraction with single encoder", @@ -3087,6 +3092,7 @@ def get_model_type(self) -> str: "BiEncoderSpanGLiNER": "gliner_bi_encoder_span", "BiEncoderTokenGLiNER": "gliner_bi_encoder_token", "UniEncoderSpanDecoderGLiNER": "gliner_uni_encoder_span_decoder", + "UniEncoderTokenDecoderGLiNER": "gliner_uni_encoder_token_decoder", "UniEncoderSpanRelexGLiNER": "gliner_uni_encoder_span_relex", "UniEncoderTokenRelexGLiNER": "gliner_uni_encoder_token_relex", }