diff --git a/configs/config_biencoder.yaml b/configs/config_biencoder.yaml index 617c742..967ff02 100644 --- a/configs/config_biencoder.yaml +++ b/configs/config_biencoder.yaml @@ -7,6 +7,7 @@ model: hidden_size: 768 dropout: 0.3 fine_tune: true + represent_spans: false subtoken_pooling: first fuse_layers: false post_fusion_schema: null # e.g., "l2l-l2t-t2t" diff --git a/configs/config_decoder.yaml b/configs/config_decoder.yaml index a149201..887283b 100644 --- a/configs/config_decoder.yaml +++ b/configs/config_decoder.yaml @@ -10,6 +10,7 @@ model: fine_tune: true subtoken_pooling: first fuse_layers: false + represent_spans: true post_fusion_schema: null # e.g., "l2l-l2t-t2t" full_decoder_context: true span_mode: markerV0 # Options: token_level | markerV0 @@ -23,7 +24,7 @@ model: data: # Directory Paths root_dir: gliner_logs - train_data: "data.json" + train_data: "data/data.json" val_data_dir: "none" # Set to validation data path or "none" training: @@ -33,7 +34,7 @@ training: # Training Parameters num_steps: 15000 train_batch_size: 8 - eval_every: 500 + eval_every: 10 warmup_ratio: 0.05 scheduler_type: "cosine" # Options: linear, cosine, constant diff --git a/configs/config_relex.yaml b/configs/config_relex.yaml index 7ce7950..fc8aeb1 100644 --- a/configs/config_relex.yaml +++ b/configs/config_relex.yaml @@ -8,6 +8,7 @@ model: fine_tune: true subtoken_pooling: first fuse_layers: false + represent_spans: false post_fusion_schema: null # e.g., "l2l-l2t-t2t" span_mode: markerV0 # Options: token_level | markerV0 max_types: 100 diff --git a/configs/config_token.yaml b/configs/config_token.yaml index 8dd1b36..19ed07a 100644 --- a/configs/config_token.yaml +++ b/configs/config_token.yaml @@ -8,6 +8,7 @@ model: fine_tune: true subtoken_pooling: first fuse_layers: false + represent_spans: false post_fusion_schema: null # e.g., "l2l-l2t-t2t" span_mode: token_level # Options: token_level | markerV0 max_types: 100 @@ -17,7 +18,7 @@ model: data: # Directory Paths root_dir: gliner_logs - train_data: "data.json" + train_data: "data/data.json" val_data_dir: "none" # Set to validation data path or "none" training: @@ -27,7 +28,7 @@ training: # Training Parameters num_steps: 15000 train_batch_size: 8 - eval_every: 500 + eval_every: 10 warmup_ratio: 0.05 scheduler_type: "cosine" # Options: linear, cosine, constant diff --git a/docs/architectures.md b/docs/architectures.md index d5a29ca..4b57f18 100644 --- a/docs/architectures.md +++ b/docs/architectures.md @@ -21,7 +21,9 @@ The GLiNER framework now supports multiple architecture variants, each optimized | **BiEncoderSpan** | Separate encoders for text & labels | Span-level | Pre-compute label embeddings, handles 100+ entity types | Many entity types, production deployment | | **BiEncoderToken** | Separate encoders for text & labels | Token-level | Combines bi-encoder efficiency with token-level prediction | Long entities with many types | | **UniEncoderSpanDecoder** | Single encoder + generative decoder | Span-level with generation | Generates entity labels, open vocabulary | Open-domain NER, label discovery | +| **UniEncoderTokenDecoder** | Single encoder + generative decoder | Token-level with generation | Token-level detection + label generation | Long entities with open vocabulary | | **UniEncoderSpanRelex** | Single encoder + relation layers | Span-level + relations | Joint entity and relation extraction | Knowledge graph construction, IE | +| **UniEncoderTokenRelex** | Single encoder + relation layers | Token-level + relations | Token-level entities with relation extraction | Long entities with relations | The framework automatically selects the appropriate architecture based on your model configuration, providing a unified API across all variants. @@ -243,6 +245,95 @@ for entity in entities[0]: print(f" Generated: {entity['generated_labels']}") ``` +## GLiNER Token-Level with Generative Decoder (UniEncoderTokenDecoder) + +The UniEncoderTokenDecoder architecture combines token-level BIO tagging with a generative decoder, offering the best of both worlds: the ability to handle long entity spans (from token-level prediction) and open-vocabulary entity typing (from the generative decoder). + +### Architecture Overview + +The architecture extends UniEncoderToken with decoder capabilities: +1. **Token Encoder**: Standard GLiNER token-based encoder with BIO tagging for entity boundary detection +2. **Span Representation Layer**: Converts detected token sequences into span representations +3. **Generative Decoder**: GPT-2 or similar decoder that generates entity type labels for detected spans + +### Key Differences from UniEncoderSpanDecoder + +| Aspect | UniEncoderSpanDecoder | UniEncoderTokenDecoder | +|--------|----------------------|------------------------| +| Entity Detection | Span enumeration (max width 12) | Token-level BIO tagging | +| Long Entities | Limited by max span width | No length limitation | +| Computation | O(n × max_width) spans | O(n) tokens | +| Best For | Standard NER entities | Long-form extraction + label generation | + +### Architecture Details + +**Token-Level Detection**: +- Uses the Scorer module to compute token-label compatibility scores +- Produces three logits per token per class: start, inside, end (BIO-style) +- Entities are extracted by finding contiguous sequences of positive predictions + +**Span Representation**: +- Detected token sequences are converted to span representations +- Optional `represent_spans` mode uses a dedicated SpanRepLayer for richer representations +- Span representations are then fed to the decoder for label generation + +**Training Objective**: +- Multi-component loss: + - Token classification loss (`token_loss_coef`) + - Span classification loss (`span_loss_coef`) + - Decoder generation loss (`decoder_loss_coef`) + +### Configuration + +```python +from gliner import GLiNERConfig + +config = GLiNERConfig( + model_name="microsoft/deberta-v3-small", + span_mode="token_level", + labels_decoder="gpt2", # Enables decoder + decoder_mode="span", # or "prompt" + token_loss_coef=1.0, + span_loss_coef=1.0, + decoder_loss_coef=0.5, + represent_spans=True, # Use SpanRepLayer for span representations +) +``` + +### Use Cases + +- **Long entity extraction with open vocabulary**: Extract multi-sentence entities while generating appropriate labels +- **Extractive summarization with typing**: Identify summary-worthy spans and label their semantic type +- **Document-level entity typing**: Handle entities that span multiple clauses or sentences +- **Flexible entity annotation**: Combine precise boundary detection with descriptive type generation + +### Example Usage + +```python +from gliner import GLiNER + +# Load a token-level decoder model +model = GLiNER.from_pretrained("knowledgator/gliner-token-decoder-v1.0") + +text = """The Paris Agreement, adopted in December 2015 at the 21st Conference +of the Parties to the United Nations Framework Convention on Climate Change, +represents a landmark international accord on climate action.""" + +# Extract long entities with generated labels +entities = model.inference( + [text], + labels=["entity"], + gen_constraints=["agreement", "organization", "event", "topic"], + num_gen_sequences=1 +) + +for entity in entities[0]: + print(f"Entity: {entity['text'][:50]}...") + print(f" Label: {entity['label']}") + if 'generated_labels' in entity: + print(f" Generated: {entity['generated_labels']}") +``` + ## GLiNER for Relation Extraction (UniEncoderSpanRelex) The UniEncoderSpanRelex architecture extends GLiNER to perform joint entity and relation extraction, enabling the model to identify both entities and the relationships between them in a single forward pass. @@ -361,6 +452,108 @@ for relation in relations[0]: } ``` +## GLiNER Token-Level for Relation Extraction (UniEncoderTokenRelex) + +The UniEncoderTokenRelex architecture combines token-level entity detection with relation extraction capabilities, enabling joint extraction of long-form entities and their relationships. + +### Architecture Overview + +The architecture extends UniEncoderToken with relation extraction components: + +1. **Token Encoder**: Standard GLiNER token-based encoder with BIO tagging +2. **Span Representation Layer**: Converts detected token sequences into entity representations +3. **Relation Representation Layer**: Computes pairwise entity representations and adjacency predictions +4. **Relation Classification Layer**: Classifies relation types between entity pairs + +### Key Differences from UniEncoderSpanRelex + +| Aspect | UniEncoderSpanRelex | UniEncoderTokenRelex | +|--------|---------------------|----------------------| +| Entity Detection | Span enumeration | Token-level BIO tagging | +| Long Entities | Limited by max span width | No length limitation | +| Entity Boundaries | Explicit span indices | Derived from token predictions | +| Best For | Standard entity-relation extraction | Long entities with relations | + +### Architecture Details + +**Entity Detection**: +- Uses the Scorer module for token-level classification (start, inside, end) +- Entity spans are extracted from contiguous positive predictions +- No maximum span width limitation + +**Relation Extraction**: +- Same relation layers as UniEncoderSpanRelex +- Relations computed between entity representations derived from token sequences +- Supports both concatenation and triple scoring methods + +**Training Objective**: +- Multi-component loss: + - Token classification loss (for entity boundaries) + - Adjacency loss (for entity pair connectivity) + - Relation classification loss (for relation types) + +### Configuration + +```python +from gliner import GLiNERConfig + +config = GLiNERConfig( + model_name="microsoft/deberta-v3-small", + span_mode="token_level", + relations_layer="biaffine", # or "concat" + triples_layer="TransE", # Optional: for triple scoring + span_loss_coef=1.0, + adjacency_loss_coef=1.0, + relation_loss_coef=1.0, +) +``` + +### Use Cases + +- **Scientific IE**: Extract long entity mentions (chemical compounds, gene names) and their relationships +- **Legal Document Analysis**: Identify parties, clauses, and their legal relationships +- **Medical Record Processing**: Extract symptoms, treatments, and their clinical relationships +- **News Event Extraction**: Identify event participants and their roles across long descriptions + +### Example Usage + +```python +from gliner import GLiNER + +# Load a token-level relation extraction model +model = GLiNER.from_pretrained("knowledgator/gliner-token-relex-v1.0") + +text = """The Phase III clinical trial conducted by Pfizer and BioNTech +demonstrated that the BNT162b2 vaccine achieved 95% efficacy against +COVID-19 in participants without prior infection.""" + +# Define entity and relation types +entity_labels = ["organization", "vaccine", "disease", "metric"] +relation_labels = ["developed_by", "effective_against", "measured_as"] + +# Extract entities and relations +entities, relations = model.inference( + [text], + labels=entity_labels, + relations=relation_labels, + threshold=0.5, + relation_threshold=0.5 +) + +# Display results +print("Entities:") +for entity in entities[0]: + print(f" {entity['text']} ({entity['label']})") + +print("\nRelations:") +for relation in relations[0]: + print(f" {relation['head']['text']} --[{relation['relation']}]--> {relation['tail']['text']}") +``` + +### Output Format + +Same as UniEncoderSpanRelex - entities and relations are returned in identical formats, ensuring API consistency across architectures. + ## Choosing the Right Architecture Here's a quick guide to selecting the appropriate GLiNER architecture: @@ -372,9 +565,36 @@ Here's a quick guide to selecting the appropriate GLiNER architecture: | Many entity types (50-200+) | BiEncoderSpan or BiEncoderToken | Pre-compute labels, handles many types | | Production with fixed schema | BiEncoder variants | Cache label embeddings for speed | | Open-domain, unknown types | UniEncoderSpanDecoder | Generate labels on-the-fly | +| Long entities + open vocabulary | UniEncoderTokenDecoder | Token-level detection with label generation | | Knowledge graph extraction | UniEncoderSpanRelex | Joint entity and relation extraction | +| Long entities with relations | UniEncoderTokenRelex | Token-level entities with relation extraction | | Both long entities + many types | BiEncoderToken | Combines both advantages | +### Decision Flowchart + +``` +Start + │ + ├─ Need relation extraction? + │ ├─ Yes: Long entities expected? + │ │ ├─ Yes → UniEncoderTokenRelex + │ │ └─ No → UniEncoderSpanRelex + │ │ + │ └─ No: Need open vocabulary labels? + │ ├─ Yes: Long entities expected? + │ │ ├─ Yes → UniEncoderTokenDecoder + │ │ └─ No → UniEncoderSpanDecoder + │ │ + │ └─ No: Many entity types (>30)? + │ ├─ Yes: Long entities expected? + │ │ ├─ Yes → BiEncoderToken + │ │ └─ No → BiEncoderSpan + │ │ + │ └─ No: Long entities expected? + │ ├─ Yes → UniEncoderToken + │ └─ No → UniEncoderSpan +``` + ## References - The Intro section is based on the [Shahrukh Khan](https://www.linkedin.com/in/shahrukhx01/) article [Illustrated GLINER](https://medium.com/@shahrukhx01/illustrated-gliner-e6971e4c8c52) and placed into documentation with consent of the author. - [Urchade Zaratiana, Nadi Tomeh, Pierre Holat, and Thierry Charnois. 2023. Gliner: Generalist model for named entity recognition using bidirectional transformer](https://arxiv.org/abs/2311.08526) diff --git a/gliner/config.py b/gliner/config.py index 8cdd817..d57e103 100644 --- a/gliner/config.py +++ b/gliner/config.py @@ -35,6 +35,10 @@ def __init__( ent_token: str = "<>", sep_token: str = "<>", _attn_implementation: Optional[str] = None, + token_loss_coef: float = 1.0, + span_loss_coef: float = 1.0, + represent_spans: bool = False, + neg_spans_ratio: float = 1.0, **kwargs, ): """Initialize BaseGLiNERConfig. @@ -64,6 +68,10 @@ def __init__( ent_token (str, optional): Entity marker token. Defaults to "<>". sep_token (str, optional): Separator token. Defaults to "<>". _attn_implementation (str, optional): Attention implementation. Defaults to None. + token_loss_coef (float, optional): Token loss coefficient. Defaults to 1.0. + span_loss_coef (float, optional): Span loss coefficient. Defaults to 1.0. + represent_spans (bool, optional): Whether to represent spans. Defaults to False. + neg_spans_ratio (float, optional): Ratio of negative spans. Defaults to 1.0. **kwargs: Additional keyword arguments passed to parent class. """ super().__init__(**kwargs) @@ -96,6 +104,10 @@ def __init__( self.ent_token = ent_token self.sep_token = sep_token self._attn_implementation = _attn_implementation + self.token_loss_coef = token_loss_coef + self.span_loss_coef = span_loss_coef + self.represent_spans = represent_spans + self.neg_spans_ratio = neg_spans_ratio class UniEncoderConfig(BaseGLiNERConfig): @@ -136,7 +148,6 @@ def __init__( blank_entity_prob: float = 0.1, labels_decoder_config: Optional[dict] = None, decoder_loss_coef=0.5, - span_loss_coef=0.5, **kwargs, ): """Initialize UniEncoderSpanDecoderConfig. @@ -148,7 +159,6 @@ def __init__( blank_entity_prob (float, optional): Probability of blank entities. Defaults to 0.1. labels_decoder_config (dict, optional): Decoder config dict. Defaults to None. decoder_loss_coef (float, optional): Decoder loss coefficient. Defaults to 0.5. - span_loss_coef (float, optional): Span loss coefficient. Defaults to 0.5. **kwargs: Additional keyword arguments passed to UniEncoderConfig. Raises: @@ -166,15 +176,18 @@ def __init__( self.decoder_mode = decoder_mode # 'prompt' or 'span' self.full_decoder_context = full_decoder_context self.decoder_loss_coef = decoder_loss_coef - self.span_loss_coef = span_loss_coef self.model_type = "gliner_uni_encoder_span_decoder" - if self.span_mode == "token_level": - raise ValueError("UniEncoderSpanDecoderConfig requires span_mode != 'token_level'") -class UniEncoderSpanRelexConfig(UniEncoderConfig): - """Configuration for uni-encoder span model with relation extraction.""" +class UniEncoderTokenDecoderConfig(UniEncoderSpanDecoderConfig): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.span_mode = "token_level" + self.model_type = "gliner_encoder_token_decoder" + self.represent_spans = True # hardcoded to True for token decoder + +class UniEncoderRelexConfig(UniEncoderConfig): def __init__( self, relations_layer: Optional[str] = None, @@ -182,12 +195,11 @@ def __init__( embed_rel_token: bool = True, rel_token_index: int = -1, rel_token: str = "<>", - span_loss_coef=1.0, adjacency_loss_coef=1.0, relation_loss_coef=1.0, **kwargs, ): - """Initialize UniEncoderSpanRelexConfig. + """Initialize UniEncoderRelexConfig. Args: relations_layer (str, optional): Name of relations layer, @@ -197,7 +209,6 @@ def __init__( embed_rel_token (bool, optional): Whether to embed relation tokens. Defaults to True. rel_token_index (int, optional): Index of relation token. Defaults to -1. rel_token (str, optional): Relation marker token. Defaults to "<>". - span_loss_coef (float, optional): Span representaton loss coefficient. Defaults to 1.0. adjacency_loss_coef (float, optional): Adjacency modeling loss coefficient. Defaults to 1.0. relation_loss_coef (float, optional): Relation representaton loss coefficient. Defaults to 1.0. **kwargs: Additional keyword arguments passed to UniEncoderConfig. @@ -212,14 +223,29 @@ def __init__( self.embed_rel_token = embed_rel_token self.rel_token_index = rel_token_index self.rel_token = rel_token - self.span_loss_coef = span_loss_coef self.adjacency_loss_coef = adjacency_loss_coef self.relation_loss_coef = relation_loss_coef + + +class UniEncoderSpanRelexConfig(UniEncoderRelexConfig): + """Configuration for uni-encoder span model with relation extraction.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) self.model_type = "gliner_uni_encoder_span_relex" if self.span_mode == "token_level": raise ValueError("UniEncoderSpanRelexConfig requires span_mode != 'token_level'") +class UniEncoderTokenRelexConfig(UniEncoderRelexConfig): + """Configuration for uni-encoder token-level model with relation extraction.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.model_type = "gliner_uni_encoder_token_relex" + self.span_mode = "token_level" + + class BiEncoderConfig(BaseGLiNERConfig): """Base configuration for bi-encoder GLiNER models.""" @@ -298,11 +324,17 @@ def __init__( def model_type(self): """Auto-detect model type based on configuration.""" if self.labels_decoder: - return "gliner_uni_encoder_span_decoder" + if self.span_mode == 'token-level': + return "gliner_uni_encoder_token_decoder" + else: + return "gliner_uni_encoder_span_decoder" elif self.labels_encoder: return "gliner_bi_encoder_span" if self.span_mode != "token-level" else "gliner_bi_encoder_token" elif self.relations_layer is not None: - return "gliner_uni_encoder_span_relex" + if self.span_mode == "token-level": + return "gliner_uni_encoder_token_relex" + else: + return "gliner_uni_encoder_span_relex" elif self.span_mode == "token-level": return "gliner_uni_encoder_token" else: @@ -318,9 +350,11 @@ def model_type(self): "gliner_uni_encoder_span": UniEncoderSpanConfig, "gliner_uni_encoder_token": UniEncoderTokenConfig, "gliner_uni_encoder_span_decoder": UniEncoderSpanDecoderConfig, + "gliner_uni_encoder_token_decoder": UniEncoderTokenDecoderConfig, "gliner_uni_encoder_span_relex": UniEncoderSpanRelexConfig, + "gliner_uni_encoder_token_relex": UniEncoderTokenRelexConfig, "gliner_bi_encoder": BiEncoderConfig, "gliner_bi_encoder_span": BiEncoderSpanConfig, "gliner_bi_encoder_token": BiEncoderTokenConfig, } -) +) \ No newline at end of file diff --git a/gliner/data_processing/__init__.py b/gliner/data_processing/__init__.py index 95be360..4cd98f4 100644 --- a/gliner/data_processing/__init__.py +++ b/gliner/data_processing/__init__.py @@ -5,6 +5,8 @@ UniEncoderTokenDataCollator, UniEncoderSpanDecoderDataCollator, RelationExtractionSpanDataCollator, + UniEncoderTokenDecoderDataCollator, + RelationExtractionTokenDataCollator, ) from .processor import ( BaseProcessor, @@ -15,5 +17,7 @@ UniEncoderTokenProcessor, UniEncoderSpanDecoderProcessor, RelationExtractionSpanProcessor, + UniEncoderTokenDecoderProcessor, + RelationExtractionTokenProcessor, ) from .tokenizer import WordsSplitter diff --git a/gliner/data_processing/collator.py b/gliner/data_processing/collator.py index 95875f7..e3925f1 100644 --- a/gliner/data_processing/collator.py +++ b/gliner/data_processing/collator.py @@ -475,6 +475,18 @@ def _filter_none_values(self, batch_dict: Dict[str, Any]) -> Dict[str, Any]: return {k: v for k, v in batch_dict.items() if v is not None} +class RelationExtractionTokenDataCollator(RelationExtractionSpanDataCollator): + """Data collator for RelationExtractionTokenProcessor. + + Handles joint entity and relation extraction at token level. + Produces both entity labels and relation adjacency matrices. + + Required Processor: RelationExtractionTokenProcessor + """ + + pass + + class UniEncoderSpanDataCollator(SpanDataCollator): """ Backward compatibility alias for SpanDataCollator with UniEncoderSpanProcessor. @@ -505,6 +517,14 @@ class UniEncoderSpanDecoderDataCollator(SpanDataCollator): pass +class UniEncoderTokenDecoderDataCollator(UniEncoderSpanDecoderDataCollator): + """ + Backward compatibility alias for UniEncoderTokenDecoderDataCollator with UniEncoderSpanDecoderDataCollator. + """ + + pass + + class UniEncoderTokenDataCollator(TokenDataCollator): """ Backward compatibility alias for TokenDataCollator with UniEncoderTokenProcessor. diff --git a/gliner/data_processing/processor.py b/gliner/data_processing/processor.py index f2ad777..3226bea 100644 --- a/gliner/data_processing/processor.py +++ b/gliner/data_processing/processor.py @@ -35,8 +35,8 @@ def __init__(self, config, tokenizer, words_splitter): self.words_splitter = WordsSplitter(splitter_type=config.words_splitter_type) else: self.words_splitter = words_splitter - self.ent_token = config.ent_token - self.sep_token = config.sep_token + self.ent_token = getattr(config, "ent_token", "[ENT]") + self.sep_token = getattr(config, "sep_token", "[SEP]") # Check if the tokenizer has unk_token and pad_token self._check_and_set_special_tokens(self.transformer_tokenizer) @@ -126,6 +126,29 @@ def tokenize_and_prepare_labels(self): """ pass + def sort_entities_and_relations(self, ner, relations=None): + if ner is not None and len(ner) > 0: + indexed_ner = list(enumerate(ner)) + indexed_ner_sorted = sorted(indexed_ner, key=lambda x: (x[1][0], x[1][1])) + + ner_sorted = [entity for _, entity in indexed_ner_sorted] + + # Create mapping from old entity indices to new sorted indices + old_to_new_idx = {old_idx: new_idx for new_idx, (old_idx, _) in enumerate(indexed_ner_sorted)} + + # Update relation indices to match new entity ordering + if relations is not None and len(relations) > 0: + updated_relations = [] + for head_idx, tail_idx, rel_type in relations: + if head_idx in old_to_new_idx and tail_idx in old_to_new_idx: + new_head_idx = old_to_new_idx[head_idx] + new_tail_idx = old_to_new_idx[tail_idx] + updated_relations.append((new_head_idx, new_tail_idx, rel_type)) + relations = sorted(updated_relations, key=lambda x: (x[0], x[1])) + + ner = ner_sorted + return ner, relations + def prepare_inputs( self, texts: Sequence[Sequence[str]], @@ -429,6 +452,14 @@ class UniEncoderSpanProcessor(BaseProcessor): predict entity types for all possible spans up to a maximum width. """ + def prepare_span_labels(self, ner, classes_to_id, num_tokens, spans_idx): + dict_lab = self.get_dict(ner, classes_to_id) if ner else defaultdict(int) + span_label = torch.LongTensor([dict_lab[i] for i in spans_idx]) + spans_idx = torch.LongTensor(spans_idx) + valid_span_mask = spans_idx[:, 1] > num_tokens - 1 + span_label = span_label.masked_fill(valid_span_mask, -1) + return span_label, spans_idx + def preprocess_example(self, tokens, ner, classes_to_id): """Preprocess a single example for span-based prediction. @@ -461,11 +492,8 @@ def preprocess_example(self, tokens, ner, classes_to_id): tokens = tokens[:max_len] num_tokens = len(tokens) spans_idx = prepare_span_idx(num_tokens, max_width) - dict_lab = self.get_dict(ner, classes_to_id) if ner else defaultdict(int) - span_label = torch.LongTensor([dict_lab[i] for i in spans_idx]) - spans_idx = torch.LongTensor(spans_idx) - valid_span_mask = spans_idx[:, 1] > num_tokens - 1 - span_label = span_label.masked_fill(valid_span_mask, -1) + + span_label, spans_idx = self.prepare_span_labels(ner, classes_to_id, num_tokens, spans_idx) return { "tokens": tokens, @@ -577,6 +605,87 @@ class UniEncoderTokenProcessor(BaseProcessor): labeled with BIO-style tags (Begin, Inside, Outside) for each entity type. """ + def _generate_negative_spans(self, positive_spans, num_tokens, num_negatives, max_width=None): + """Generate random negative spans that don't overlap with positive spans. + + Args: + positive_spans: Set of (start, end) tuples representing positive entity spans. + num_tokens: Total number of tokens in the sequence. + num_negatives: Number of negative spans to generate. + max_width: Maximum width for negative spans. If None, uses config.max_width. + + Returns: + List of (start, end) tuples representing negative spans. + """ + if max_width is None: + max_width = getattr(self.config, "max_width", 10) + + negative_spans = [] + attempts = 0 + max_attempts = num_negatives * 20 # Limit attempts to avoid infinite loops + + while len(negative_spans) < num_negatives and attempts < max_attempts: + attempts += 1 + + # Random start position + start = random.randint(0, num_tokens - 1) + + # Random width (1 to max_width) + width = random.randint(1, min(max_width, num_tokens - start)) + end = start + width - 1 + + # Check if this span overlaps with any positive span + span = (start, end) + if span in positive_spans: + continue + + # Check for overlap with positive spans + overlaps = False + for pos_start, pos_end in positive_spans: + if not (end < pos_start or start > pos_end): + overlaps = True + break + + if not overlaps and span not in negative_spans: + negative_spans.append(span) + + return negative_spans + + def prepare_span_idx(self, ner, classes_to_id, num_tokens): + if ner is not None and self.config.represent_spans: + span_idx_list = [] + span_label_list = [] + positive_spans = set() + + # Add positive spans + for start, end, label in ner: + if label in classes_to_id and end < num_tokens: + span_idx_list.append([start, end]) + span_label_list.append(classes_to_id[label]) + positive_spans.add((start, end)) + + # Add negative spans + neg_spans_ratio = self.config.neg_spans_ratio + neg_spans_count = int(len(span_idx_list) * neg_spans_ratio) + + if neg_spans_count > 0 and num_tokens > 0: + max_width = getattr(self.config, "max_width", 10) + negative_spans = self._generate_negative_spans(positive_spans, num_tokens, neg_spans_count, max_width) + + for start, end in negative_spans: + span_idx_list.append([start, end]) + span_label_list.append(0) # 0 indicates negative/no entity + + if span_idx_list: + span_idx = torch.LongTensor(span_idx_list) + span_label = torch.LongTensor(span_label_list) + else: + span_idx = torch.zeros(0, 2, dtype=torch.long) + span_label = torch.zeros(0, dtype=torch.long) + else: + span_idx, span_label = None, None + return span_idx, span_label + def preprocess_example(self, tokens, ner, classes_to_id): """Preprocess a single example for token-based prediction. @@ -590,7 +699,8 @@ def preprocess_example(self, tokens, ner, classes_to_id): - tokens: Token strings - seq_length: Sequence length - entities: Original NER annotations - - entities_id: Entity annotations with class IDs + - span_idx: Tensor of entity span indices (if represent_spans=True) + - span_label: Tensor of entity class IDs (if represent_spans=True) Warnings: UserWarning: If sequence length exceeds max_len (gets truncated). @@ -605,13 +715,17 @@ def preprocess_example(self, tokens, ner, classes_to_id): warnings.warn(f"Sentence of length {len(tokens)} has been truncated to {max_len}", stacklevel=2) tokens = tokens[:max_len] - # Generate entity IDs based on the NER spans provided and their classes - try: # 'NoneType' object is not iterable - entities_id = [[i, j, classes_to_id[k]] for i, j, k in ner if k in classes_to_id] - except TypeError: - entities_id = [] + num_tokens = len(tokens) + + span_idx, span_label = self.prepare_span_idx(ner, classes_to_id, num_tokens) - example = {"tokens": tokens, "seq_length": len(tokens), "entities": ner, "entities_id": entities_id} + example = { + "tokens": tokens, + "seq_length": len(tokens), + "entities": ner, + "span_idx": span_idx, + "span_label": span_label, + } return example def create_batch_dict(self, batch, class_to_ids, id_to_classes): @@ -627,7 +741,9 @@ def create_batch_dict(self, batch, class_to_ids, id_to_classes): - tokens: Token strings - seq_length: Sequence lengths - entities: Original NER annotations - - entities_id: Entity annotations with class IDs + - span_idx: Padded span indices (if available) + - span_label: Padded span labels (if available) + - span_mask: Mask for valid spans (if available) - classes_to_id: Class mappings - id_to_classes: Reverse class mappings """ @@ -635,41 +751,64 @@ def create_batch_dict(self, batch, class_to_ids, id_to_classes): tokens = [el["tokens"] for el in batch] seq_length = torch.LongTensor([el["seq_length"] for el in batch]).unsqueeze(-1) entities = [el["entities"] for el in batch] - entities_id = [el["entities_id"] for el in batch] - # Assemble and return the batch dictionary + # Assemble the base batch dictionary batch_dict = { "tokens": tokens, "seq_length": seq_length, "entities": entities, - "entities_id": entities_id, "classes_to_id": class_to_ids, "id_to_classes": id_to_classes, } + # Handle span representations if present + if batch[0]["span_idx"] is not None: + span_idx_list = [el["span_idx"] for el in batch] + span_label_list = [el["span_label"] for el in batch] + + batch_size = len(span_idx_list) + span_counts = [s.size(0) if s.numel() > 0 else 0 for s in span_idx_list] + max_spans = max(*span_counts, 1) # Ensure at least 1 + + # Create span mask indicating valid spans + span_mask = torch.zeros(batch_size, max_spans, dtype=torch.bool) + for i, count in enumerate(span_counts): + if count > 0: + span_mask[i, :count] = True + + # Pad span tensors + span_idx = pad_2d_tensor(span_idx_list, padding_value=0) + span_label = pad_sequence(span_label_list, batch_first=True, padding_value=-1) + + batch_dict["span_idx"] = span_idx + batch_dict["span_label"] = span_label + batch_dict["span_mask"] = span_mask + return batch_dict - def create_labels(self, entities_id, batch_size, seq_len, num_classes): + def create_labels(self, batch): """Create token-level labels with begin/inside/end markers. Creates labels indicating which tokens are at the start, end, or inside of entity spans for each entity type. Args: - entities_id: List of entity annotations with class IDs for each example. - batch_size: Size of the batch. - seq_len: Maximum sequence length in batch. - num_classes: Number of entity classes. + batch: List[Any] batch of data Returns: Tensor of shape (batch_size, seq_len, num_classes, 3) where the last dimension contains [start_marker, end_marker, inside_marker]. """ + batch_size = len(batch["tokens"]) + seq_len = batch["seq_length"].max().item() + num_classes = max([len(cid) for cid in batch["classes_to_id"]]) + word_labels = torch.zeros(batch_size, seq_len, num_classes, 3, dtype=torch.float) - for i, sentence_entities in enumerate(entities_id): + for i, sentence_entities in enumerate(batch["entities"]): for st, ed, sp_label in sentence_entities: - class_idx = sp_label - 1 # Convert to 0-indexed + lbl = batch["classes_to_id"][i][sp_label] + class_idx = lbl - 1 # Convert to 0-indexed # skip entities that point beyond sequence length if st >= seq_len or ed >= seq_len: @@ -681,6 +820,50 @@ def create_labels(self, entities_id, batch_size, seq_len, num_classes): return word_labels + def create_span_labels(self, batch): + """Create one-hot encoded labels for spans with negative sampling. + + Creates one-hot encoded labels for entity spans, converting 1-indexed class IDs + to 0-indexed format. Labels with class ID 0 (negative spans) or -1 (invalid spans) + are represented as all zeros in the one-hot encoding. + + Args: + batch: Batch dictionary containing span_label, span_mask, and classes_to_id. + + Returns: + Tensor of shape (batch_size, max_spans, num_classes) containing one-hot + encoded labels where: + - Positive spans: one-hot vector at position (class_id - 1) + - Negative/invalid spans: all zeros + """ + batch_size = len(batch["tokens"]) + span_label = batch["span_label"] # (batch_size, max_spans) + span_mask = batch["span_mask"] # (batch_size, max_spans) + + # Get maximum number of classes across all examples + if isinstance(batch["classes_to_id"], list): + num_classes = max([len(cid) for cid in batch["classes_to_id"]]) + else: + num_classes = len(batch["classes_to_id"]) + + max_spans = span_label.size(1) + + # Initialize one-hot labels (batch_size, max_spans, num_classes) + labels_one_hot = torch.zeros(batch_size, max_spans, num_classes, dtype=torch.float) + + for i in range(batch_size): + for j in range(max_spans): + if span_mask[i, j]: # Valid span + class_id = span_label[i, j].item() + + if class_id > 0: + # Convert from 1-indexed to 0-indexed + class_idx = class_id - 1 + if class_idx < num_classes: + labels_one_hot[i, j, class_idx] = 1.0 + + return labels_one_hot + def tokenize_and_prepare_labels(self, batch, prepare_labels, *args, **kwargs): """Tokenize inputs and prepare token-level labels for a batch. @@ -693,15 +876,17 @@ def tokenize_and_prepare_labels(self, batch, prepare_labels, *args, **kwargs): Returns: Dictionary containing tokenized inputs and optionally labels. """ - batch_size = len(batch["tokens"]) - seq_len = batch["seq_length"].max() - num_classes = max([len(cid) for cid in batch["classes_to_id"]]) - tokenized_input = self.tokenize_inputs(batch["tokens"], batch["classes_to_id"]) - if prepare_labels: - labels = self.create_labels(batch["entities_id"], batch_size, seq_len, num_classes) + labels = self.create_labels(batch) tokenized_input["labels"] = labels + + # Add span-level one-hot labels if spans are represented + if batch.get("span_idx") is not None: + span_labels = self.create_span_labels(batch) + tokenized_input["span_labels"] = span_labels + tokenized_input["span_idx"] = batch["span_idx"] + tokenized_input["span_mask"] = batch["span_mask"] return tokenized_input @@ -861,16 +1046,19 @@ def tokenize_and_prepare_labels(self, batch, prepare_labels, prepare_entities=Tr entities = list(batch["classes_to_id"][0]) else: entities = None - batch_size = len(batch["tokens"]) - seq_len = batch["seq_length"].max() - num_classes = len(entities) tokenized_input = self.tokenize_inputs(batch["tokens"], entities) if prepare_labels: - labels = self.create_labels(batch["entities_id"], batch_size, seq_len, num_classes) + labels = self.create_labels(batch) tokenized_input["labels"] = labels + # Add span-level one-hot labels if spans are represented + if batch.get("span_idx") is not None: + span_labels = self.create_span_labels(batch) + tokenized_input["span_labels"] = span_labels + tokenized_input["span_idx"] = batch["span_idx"] + tokenized_input["span_mask"] = batch["span_mask"] return tokenized_input @@ -912,7 +1100,7 @@ def tokenize_inputs(self, texts, entities, blank=None): Dictionary containing encoder and decoder tokenized inputs. """ add_entities = True - if self.config.decoder_mode == 'prompt': + if self.config.decoder_mode == "prompt": add_entities = False input_texts, prompt_lengths = self.prepare_inputs(texts, entities, blank=blank, add_entities=add_entities) @@ -948,6 +1136,20 @@ def tokenize_inputs(self, texts, entities, blank=None): return tokenized_inputs + def prepare_decoder_labels(self, decoder_label_strings): + if not decoder_label_strings: + decoder_label_strings = ["other"] + + decoder_tokenized_input = self.decoder_tokenizer( + decoder_label_strings, return_tensors="pt", truncation=True, padding="longest", add_special_tokens=True + ) + decoder_input_ids = decoder_tokenized_input["input_ids"] + decoder_attention_mask = decoder_tokenized_input["attention_mask"] + decoder_labels = decoder_input_ids.clone() + decoder_labels.masked_fill(~decoder_attention_mask.bool(), -100) + decoder_tokenized_input["labels"] = decoder_labels + return decoder_tokenized_input + def create_labels(self, batch, blank=None): """Create labels for both span classification and decoder generation. @@ -999,30 +1201,18 @@ def create_labels(self, batch, blank=None): labels_one_hot[valid_span_mask, :] = 0.0 labels_one_hot = labels_one_hot[:, 1:] labels_batch.append(labels_one_hot) - - if self.config.decoder_mode == 'span': + + if self.config.decoder_mode == "span": # Collect decoder label strings in order sorted_idxs = sorted(span_labels_dict.keys()) for idx in sorted_idxs: decoder_label_strings.append(span_labels_dict[idx]) - elif self.config.decoder_mode == 'prompt': + elif self.config.decoder_mode == "prompt": decoder_label_strings.extend(list(classes_to_id)) - - labels_batch = pad_2d_tensor(labels_batch) if len(labels_batch) > 1 else labels_batch[0].unsqueeze(0) - - decoder_tokenized_input = None - if not decoder_label_strings: - decoder_label_strings = ["other"] + labels_batch = pad_2d_tensor(labels_batch) if len(labels_batch) > 1 else labels_batch[0].unsqueeze(0) - decoder_tokenized_input = self.decoder_tokenizer( - decoder_label_strings, return_tensors="pt", truncation=True, padding="longest", add_special_tokens=True - ) - decoder_input_ids = decoder_tokenized_input["input_ids"] - decoder_attention_mask = decoder_tokenized_input["attention_mask"] - decoder_labels = decoder_input_ids.clone() - decoder_labels.masked_fill(~decoder_attention_mask.bool(), -100) - decoder_tokenized_input["labels"] = decoder_labels + decoder_tokenized_input = self.prepare_decoder_labels(decoder_label_strings) return labels_batch, decoder_tokenized_input def tokenize_and_prepare_labels(self, batch, prepare_labels, *args, **kwargs): @@ -1055,6 +1245,157 @@ def tokenize_and_prepare_labels(self, batch, prepare_labels, *args, **kwargs): return tokenized_input +class UniEncoderTokenDecoderProcessor(UniEncoderSpanDecoderProcessor, UniEncoderTokenProcessor): + """Processor for token-based NER with encoder-decoder architecture. + + This processor combines token-level BIO-style classification with a decoder + that generates entity type labels autoregressively, enabling more flexible + prediction strategies for token-level NER tasks. + + Inherits from: + - UniEncoderSpanDecoderProcessor: Encoder-decoder architecture and decoder utilities + - UniEncoderTokenProcessor: Token-level BIO tagging for entities + """ + + def __init__(self, config, tokenizer, words_splitter, decoder_tokenizer): + """Initialize the token-level encoder-decoder processor. + + Args: + config: Configuration object. + tokenizer: Transformer tokenizer for encoding. + words_splitter: Word-level tokenizer/splitter. + decoder_tokenizer: Separate tokenizer for decoder (label generation). + """ + # Initialize BaseProcessor through UniEncoderSpanDecoderProcessor's chain + super().__init__(config, tokenizer, words_splitter, decoder_tokenizer) + + def preprocess_example(self, tokens, ner, classes_to_id): + """Preprocess a single example for token-level encoder-decoder prediction. + + Uses token-level preprocessing from UniEncoderTokenProcessor while + preparing for decoder-based label generation. + + Args: + tokens: List of token strings. + ner: List of NER annotations as (start, end, label) tuples. + classes_to_id: Mapping from class labels to integer IDs. + + Returns: + Dictionary containing: + - tokens: Token strings + - seq_length: Sequence length + - entities: Original NER annotations + - span_idx: Tensor of entity span indices (if represent_spans=True) + - span_label: Tensor of entity class IDs (if represent_spans=True) + + Warnings: + UserWarning: If sequence length exceeds max_len (gets truncated). + """ + # Use token processor's preprocessing + return UniEncoderTokenProcessor.preprocess_example(self, tokens, ner, classes_to_id) + + def create_batch_dict(self, batch, class_to_ids, id_to_classes): + """Create a batch dictionary from preprocessed token examples. + + Args: + batch: List of preprocessed example dictionaries. + class_to_ids: List of class-to-ID mappings. + id_to_classes: List of ID-to-class mappings. + + Returns: + Dictionary containing all batch data for token-level encoder-decoder + processing. + """ + # Use token processor's batch dict creation + return UniEncoderTokenProcessor.create_batch_dict(self, batch, class_to_ids, id_to_classes) + + def create_labels(self, batch, blank=None): + """Create labels for both token classification and decoder generation. + + Creates both token-level BIO labels and decoder generation labels for + entity types. + + Args: + batch: Batch dictionary containing tokens, entities, and class mappings. + blank: Optional blank entity token for zero-shot scenarios. + + Returns: + Tuple containing: + - Token-level labels (BIO-style, shape: [batch_size, seq_len, num_classes, 3]) + - Decoder generation labels (tokenized entity types) or None + """ + # Create token-level labels + token_labels = UniEncoderTokenProcessor.create_labels(self, batch) + + # Create decoder labels + decoder_label_strings = [] + + for i in range(len(batch["tokens"])): + tokens = batch["tokens"][i] + classes_to_id = batch["classes_to_id"][i] + ner = batch["entities"][i] + + num_tokens = len(tokens) + if self.config.decoder_mode == "span": + # Collect entity labels in order of appearance + sorted_entities = sorted(ner, key=lambda x: (x[0], x[1])) if ner else [] + for start, end, label in sorted_entities: + if label in classes_to_id and end < num_tokens: + decoder_label_strings.append(label) + elif self.config.decoder_mode == "prompt": + # Use all entity types as decoder labels + decoder_label_strings.extend(list(classes_to_id)) + + decoder_tokenized_input = self.prepare_decoder_labels(decoder_label_strings) + + return token_labels, decoder_tokenized_input + + def tokenize_and_prepare_labels(self, batch, prepare_labels, *args, **kwargs): + """Tokenize inputs and prepare labels for token-level encoder-decoder training. + + Combines token-level input processing with decoder inputs and prepares + both token-level BIO labels and decoder generation labels. + + Args: + batch: Batch dictionary with tokens and class mappings. + prepare_labels: Whether to prepare labels. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + Dictionary containing encoder inputs, decoder inputs, token-level labels, + and decoder labels. + """ + blank = None + if random.uniform(0, 1) < self.config.blank_entity_prob and prepare_labels: + blank = "entity" + + # Use span decoder's tokenize_inputs for encoder-decoder tokenization + tokenized_input = UniEncoderSpanDecoderProcessor.tokenize_inputs( + self, batch["tokens"], batch["classes_to_id"], blank + ) + + if prepare_labels: + # Create both token-level and decoder labels + token_labels, decoder_tokenized_input = self.create_labels(batch, blank=blank) + tokenized_input["labels"] = token_labels + + # Add span-level one-hot labels if spans are represented + if batch.get("span_idx") is not None: + span_labels = self.create_span_labels(batch) + tokenized_input["span_labels"] = span_labels + tokenized_input["span_idx"] = batch["span_idx"] + tokenized_input["span_mask"] = batch["span_mask"] + + # Add decoder labels + if decoder_tokenized_input is not None: + tokenized_input["decoder_labels_ids"] = decoder_tokenized_input["input_ids"] + tokenized_input["decoder_labels_mask"] = decoder_tokenized_input["attention_mask"] + tokenized_input["decoder_labels"] = decoder_tokenized_input["labels"] + + return tokenized_input + + class RelationExtractionSpanProcessor(UniEncoderSpanProcessor): """Processor for joint entity and relation extraction. @@ -1263,31 +1604,10 @@ def preprocess_example(self, tokens, ner, classes_to_id, relations, rel_classes_ num_tokens = len(tokens) spans_idx = prepare_span_idx(num_tokens, max_width) - if ner is not None and len(ner) > 0: - indexed_ner = list(enumerate(ner)) - indexed_ner_sorted = sorted(indexed_ner, key=lambda x: (x[1][0], x[1][1])) - - ner_sorted = [entity for _, entity in indexed_ner_sorted] - - old_to_new_idx = {old_idx: new_idx for new_idx, (old_idx, _) in enumerate(indexed_ner_sorted)} - - if relations is not None and len(relations) > 0: - updated_relations = [] - for head_idx, tail_idx, rel_type in relations: - if head_idx in old_to_new_idx and tail_idx in old_to_new_idx: - new_head_idx = old_to_new_idx[head_idx] - new_tail_idx = old_to_new_idx[tail_idx] - updated_relations.append((new_head_idx, new_tail_idx, rel_type)) - relations = sorted(updated_relations, key=lambda x: (x[0], x[1])) - - ner = ner_sorted + ner, relations = self.sort_entities_and_relations(ner, relations) # Process entity labels - dict_lab = self.get_dict(ner, classes_to_id) if ner else defaultdict(int) - span_label = torch.LongTensor([dict_lab[i] for i in spans_idx]) - spans_idx = torch.LongTensor(spans_idx) - valid_span_mask = spans_idx[:, 1] > num_tokens - 1 - span_label = span_label.masked_fill(valid_span_mask, -1) + span_label, spans_idx = self.prepare_span_labels(ner, classes_to_id, num_tokens, spans_idx) # Create entity span to index mapping span_to_idx = {(spans_idx[i, 0].item(), spans_idx[i, 1].item()): i for i in range(len(spans_idx))} @@ -1376,41 +1696,38 @@ def create_batch_dict(self, batch, class_to_ids, id_to_classes, rel_class_to_ids def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_negatives=True, negative_ratio=2.0): """Create relation labels with negative pair sampling. - Generates training labels for relation extraction including both positive - relation pairs and carefully sampled negative pairs for contrastive learning. + Overrides the span-based version to work with token-level entity representations. + Uses entities_id count instead of span_label for entity counting. Args: batch: Batch dictionary containing entities and relations. - add_reversed_negatives: If True, add reversed direction pairs as - negatives (h,t) -> (t,h). These are important hard negatives - for learning relation directionality. - add_random_negatives: If True, add random entity pairs as negatives - to provide additional training signal. - negative_ratio: Ratio of negative to positive pairs. For example, - 2.0 means twice as many negatives as positives. + add_reversed_negatives: If True, add reversed direction pairs as negatives. + add_random_negatives: If True, add random entity pairs as negatives. + negative_ratio: Ratio of negative to positive pairs. Returns: Tuple containing: - - adj_matrix: Adjacency matrix indicating which entity pairs - to consider (shape: [B, max_entities, max_entities]) - - rel_matrix: Multi-hot encoded relation labels for each pair - (shape: [B, max_pairs, num_relation_classes]) + - adj_matrix: Adjacency matrix (shape: [B, max_entities, max_entities]) + - rel_matrix: Multi-hot relation labels (shape: [B, max_pairs, num_relation_classes]) """ B = len(batch["tokens"]) - entity_label = batch["span_label"] + span_mask = batch["span_mask"] - batch_ents = torch.sum(entity_label > 0, dim=-1) - max_En = torch.max(batch_ents).item() + # Count entities per sample (differs from span-based which uses span_label) + batch_ents = span_mask.long().squeeze(-1).sum(-1) + max_En = max(batch_ents.max().item(), 1) rel_class_to_ids = batch["rel_class_to_ids"] if isinstance(rel_class_to_ids, list): - C = max(len(r) for r in rel_class_to_ids) + C = max((len(r) for r in rel_class_to_ids), default=0) else: - C = len(rel_class_to_ids) + C = len(rel_class_to_ids) if rel_class_to_ids else 0 + + if C == 0: + return torch.zeros(B, max_En, max_En, dtype=torch.float), torch.zeros(B, 1, 1, dtype=torch.float) adj_matrix = torch.zeros(B, max_En, max_En, dtype=torch.float) - # Collect all pairs (positive + negative) and their relations all_pairs_info = [] max_total_pairs = 0 @@ -1419,7 +1736,6 @@ def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_ rel_idx_i = batch["rel_idx"][i] rel_label_i = batch["rel_label"][i] - # Dictionary to group relations by entity pair pair_to_relations = {} positive_pairs = set() @@ -1434,35 +1750,25 @@ def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_ positive_pairs.add(pair_key) if pair_key not in pair_to_relations: pair_to_relations[pair_key] = [] - class_id = rel_label_i[k].item() - pair_to_relations[pair_key].append(class_id) + pair_to_relations[pair_key].append(rel_label_i[k].item()) # Generate negative pairs negative_pairs = set() num_positives = len(positive_pairs) target_negatives = int(num_positives * negative_ratio) - # 1. Add reversed pairs as negatives (most important!) if add_reversed_negatives: for e1, e2 in positive_pairs: reversed_pair = (e2, e1) - # Only add if reversed pair is NOT also a positive relation if reversed_pair not in positive_pairs: negative_pairs.add(reversed_pair) - # 2. Add random negative pairs if needed - if add_random_negatives and len(negative_pairs) < target_negatives: - # Get entity span positions for proximity-based sampling - entities = batch["entities"][i] - entity_positions = [(ent[0], ent[1]) for ent in entities] if entities else [] - + if add_random_negatives and N > 1 and len(negative_pairs) < target_negatives: attempts = 0 - max_attempts = target_negatives * 10 # Avoid infinite loop + max_attempts = target_negatives * 10 while len(negative_pairs) < target_negatives and attempts < max_attempts: attempts += 1 - - # Sample two different entities e1 = random.randint(0, N - 1) e2 = random.randint(0, N - 1) @@ -1470,62 +1776,35 @@ def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_ continue pair = (e1, e2) - - # Skip if already positive or already in negatives if pair in positive_pairs or pair in negative_pairs: continue - # Optional: bias towards nearby entities (hard negatives) - if entity_positions and len(entity_positions) > e1 and len(entity_positions) > e2: - pos1 = entity_positions[e1] - pos2 = entity_positions[e2] - distance = abs(pos1[0] - pos2[1]) # Distance between entities - - # Sample with probability inversely proportional to distance - # (closer entities are harder negatives) - if distance > 10 and random.random() < 0.5: - continue # Skip some far pairs - negative_pairs.add(pair) - # Combine all pairs (positives + negatives) and sort all_pairs = sorted(list(positive_pairs) + list(negative_pairs)) - - # Store pair info: pair, is_positive, relations - pair_info = [] - for pair in all_pairs: - is_positive = pair in positive_pairs - relations = pair_to_relations.get(pair, []) - pair_info.append((pair, is_positive, relations)) + pair_info = [(pair, pair in positive_pairs, pair_to_relations.get(pair, [])) for pair in all_pairs] all_pairs_info.append(pair_info) max_total_pairs = max(max_total_pairs, len(all_pairs)) - # Create matrices + max_total_pairs = max(max_total_pairs, 1) + rel_matrix = torch.zeros(B, max_total_pairs, C, dtype=torch.float) - pair_type_mask = torch.zeros(B, max_total_pairs, dtype=torch.long) # 1=positive, 0=negative for i in range(B): N = batch_ents[i].item() pair_info = all_pairs_info[i] - - adj = torch.zeros(N, N) + adj = torch.zeros(max(N, 1), max(N, 1)) for pair_idx, (pair, is_positive, relations) in enumerate(pair_info): e1, e2 = pair - - # Set adjacency (1.0 for both positive and negative pairs) adj[e1, e2] = 1.0 - # Mark pair type - pair_type_mask[i, pair_idx] = 1 if is_positive else 0 - if is_positive: - # Create multi-hot vector for positive pairs for class_id in relations: rel_matrix[i, pair_idx, class_id - 1] = 1.0 - adj_matrix[i, :N, :N] = adj + adj_matrix[i, :N, :N] = adj[:N, :N] return adj_matrix, rel_matrix @@ -1630,3 +1909,200 @@ def tokenize_and_prepare_labels(self, batch, prepare_labels, *args, **kwargs): tokenized_input["rel_matrix"] = rel_matrix return tokenized_input + + +class RelationExtractionTokenProcessor(UniEncoderTokenProcessor, RelationExtractionSpanProcessor): + """Processor for joint entity and relation extraction using token-level NER. + + Extends token-based NER processing to additionally handle relation extraction + between entity pairs, supporting end-to-end joint training with BIO-style + entity tagging. + + Inherits from: + - UniEncoderTokenProcessor: Token-level BIO tagging for entities + - RelationExtractionSpanProcessor: Relation extraction utilities + """ + + def __init__(self, config, tokenizer, words_splitter): + """Initialize the relation extraction token processor. + + Args: + config: Configuration object. + tokenizer: Transformer tokenizer. + words_splitter: Word-level tokenizer/splitter. + """ + UniEncoderTokenProcessor.__init__(self, config, tokenizer, words_splitter) + self.rel_token = config.rel_token + + def preprocess_example(self, tokens, ner, classes_to_id, relations=None, rel_classes_to_id=None): + """Preprocess a single example for joint entity and relation extraction. + + Processes both entity annotations (for token-level BIO tagging) and + relation triplets, ensuring consistent indexing when entities are reordered. + + Args: + tokens: List of token strings. + ner: List of entity annotations as (start, end, label) tuples. + classes_to_id: Mapping from entity class labels to integer IDs. + relations: List of relation annotations as (head_idx, tail_idx, rel_type) tuples. + rel_classes_to_id: Mapping from relation class labels to integer IDs. + + Returns: + Dictionary containing: + - tokens: Token strings + - seq_length: Sequence length + - entities: Original entity annotations + - entities_id: Entity annotations with class IDs + - relations: Original relation annotations + - rel_idx: Tensor of relation head/tail entity indices + - rel_label: Tensor of relation type labels + + Warnings: + UserWarning: If sequence length exceeds max_len (gets truncated). + """ + # Handle empty token list + if len(tokens) == 0: + tokens = ["[PAD]"] + + # Truncate if necessary + max_len = self.config.max_len + if len(tokens) > max_len: + warnings.warn(f"Sentence of length {len(tokens)} has been truncated to {max_len}", stacklevel=2) + tokens = tokens[:max_len] + + num_tokens = len(tokens) + + ner, relations = self.sort_entities_and_relations(ner, relations) + + # Create entity index mapping (from sorted entity list index to entities_id index) + entity_idx_mapping = {} + valid_entity_idx = 0 + + if ner is not None: + span_idx_list = [] + for ent_idx, (start, end, label) in enumerate(ner): + if label in classes_to_id and end < num_tokens: + span_idx_list.append([start, end]) + entity_idx_mapping[ent_idx] = valid_entity_idx + valid_entity_idx += 1 + if span_idx_list: + span_idx = torch.LongTensor(span_idx_list) + else: + span_idx = torch.zeros(0, 2, dtype=torch.long) + else: + span_idx = None + # Process relations + rel_idx_list = [] + rel_label_list = [] + + if relations is not None and rel_classes_to_id is not None: + for rel in relations: + head_idx, tail_idx, rel_type = rel + + # Check if both entities are valid and relation type is known + if head_idx in entity_idx_mapping and tail_idx in entity_idx_mapping and rel_type in rel_classes_to_id: + mapped_head = entity_idx_mapping[head_idx] + mapped_tail = entity_idx_mapping[tail_idx] + rel_idx_list.append([mapped_head, mapped_tail]) + rel_label_list.append(rel_classes_to_id[rel_type]) + + # Convert to tensors + if rel_idx_list: + rel_idx = torch.LongTensor(rel_idx_list) + rel_label = torch.LongTensor(rel_label_list) + else: + rel_idx = torch.zeros(0, 2, dtype=torch.long) + rel_label = torch.zeros(0, dtype=torch.long) + + return { + "tokens": tokens, + "seq_length": num_tokens, + "entities": ner, + "span_idx": span_idx, + "relations": relations, + "rel_idx": rel_idx, + "rel_label": rel_label, + } + + def create_batch_dict(self, batch, class_to_ids, id_to_classes, rel_class_to_ids=None, rel_id_to_classes=None): + """Create a batch dictionary from preprocessed relation extraction examples. + + Args: + batch: List of preprocessed example dictionaries. + class_to_ids: List of entity class-to-ID mappings. + id_to_classes: List of entity ID-to-class mappings. + rel_class_to_ids: List of relation class-to-ID mappings. + rel_id_to_classes: List of relation ID-to-class mappings. + + Returns: + Dictionary containing all batch data for joint entity and relation + extraction with token-level entity labels. + """ + tokens = [el["tokens"] for el in batch] + seq_length = torch.LongTensor([el["seq_length"] for el in batch]).unsqueeze(-1) + entities = [el["entities"] for el in batch] + relations = [el["relations"] for el in batch] + + if batch[0]["span_idx"] is not None: + span_idx_list = [el["span_idx"] for el in batch] + + batch_size = len(span_idx_list) + span_counts = [s.size(0) if s.numel() > 0 else 0 for s in span_idx_list] + max_spans = max(*span_counts, 1) # Ensure at least 1 + + span_mask = torch.zeros(batch_size, max_spans, dtype=torch.bool) + for i, count in enumerate(span_counts): + if count > 0: + span_mask[i, :count] = True + + span_idx = pad_2d_tensor(span_idx_list, padding_value=0) + else: + span_idx, span_mask = None, None + + rel_idx = pad_sequence([el["rel_idx"] for el in batch], batch_first=True, padding_value=0) + rel_label = pad_sequence([el["rel_label"] for el in batch], batch_first=True, padding_value=0) + + return { + "tokens": tokens, + "seq_length": seq_length, + "entities": entities, + "span_idx": span_idx, + "span_mask": span_mask, + "relations": relations, + "rel_idx": rel_idx, + "rel_label": rel_label, + "classes_to_id": class_to_ids, + "id_to_classes": id_to_classes, + "rel_class_to_ids": rel_class_to_ids, + "rel_id_to_classes": rel_id_to_classes, + } + + def tokenize_and_prepare_labels(self, batch, prepare_labels, *args, **kwargs): + """Tokenize inputs and prepare labels for joint entity-relation extraction. + + Args: + batch: Batch dictionary with tokens, entities, relations, and class mappings. + prepare_labels: Whether to prepare labels. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + Dictionary containing tokenized inputs, token-level entity labels, + relation adjacency matrix, and relation labels. + """ + # Use relation-aware tokenize_inputs from RelationExtractionSpanProcessor + tokenized_input = self.tokenize_inputs( + batch["tokens"], batch["classes_to_id"], blank=None, relations=batch["rel_class_to_ids"] + ) + + if prepare_labels: + # Create token-level BIO labels (from UniEncoderTokenProcessor) + labels = self.create_labels(batch) + tokenized_input["labels"] = labels + + # Create relation labels (overridden method) + adj_matrix, rel_matrix = self.create_relation_labels(batch) + tokenized_input["adj_matrix"] = adj_matrix + tokenized_input["rel_matrix"] = rel_matrix + + return tokenized_input diff --git a/gliner/data_processing/utils.py b/gliner/data_processing/utils.py index 460beca..7f50162 100644 --- a/gliner/data_processing/utils.py +++ b/gliner/data_processing/utils.py @@ -4,7 +4,7 @@ import torch -def pad_2d_tensor(key_data): +def pad_2d_tensor(key_data, padding_value=0.0): """Pad a list of 2D tensors to uniform dimensions. Takes a list of 2D tensors with potentially different shapes and pads them @@ -15,6 +15,7 @@ def pad_2d_tensor(key_data): Args: key_data: List of 2D tensors to pad. Each tensor can have different dimensions, but all must be 2D. + padding_value: float, value used to fill pad elements. Returns: A 3D tensor of shape (batch_size, max_rows, max_cols) containing all @@ -44,7 +45,9 @@ def pad_2d_tensor(key_data): col_padding = max_cols - cols # Pad the tensor along both dimensions - padded_tensor = torch.nn.functional.pad(tensor, (0, col_padding, 0, row_padding), mode="constant", value=0) + padded_tensor = torch.nn.functional.pad( + tensor, (0, col_padding, 0, row_padding), mode="constant", value=padding_value + ) tensors.append(padded_tensor) # Stack the tensors into a single tensor along a new batch dimension @@ -88,6 +91,7 @@ def get_negatives(batch_list: List[Dict], sampled_neg: int = 5, key="ner") -> Li selected_elements = random.sample(element_types, k=min(sampled_neg, len(element_types))) return selected_elements + def prepare_word_mask( texts: Sequence[Sequence[str]], tokenized_inputs, diff --git a/gliner/decoding/__init__.py b/gliner/decoding/__init__.py index b8b2d0d..a0caaeb 100644 --- a/gliner/decoding/__init__.py +++ b/gliner/decoding/__init__.py @@ -1 +1,8 @@ -from .decoder import SpanDecoder, TokenDecoder, SpanRelexDecoder, SpanGenerativeDecoder +from .decoder import ( + SpanDecoder, + TokenDecoder, + SpanRelexDecoder, + TokenRelexDecoder, + SpanGenerativeDecoder, + TokenGenerativeDecoder, +) diff --git a/gliner/decoding/decoder.py b/gliner/decoding/decoder.py index e077dab..f973026 100644 --- a/gliner/decoding/decoder.py +++ b/gliner/decoding/decoder.py @@ -827,7 +827,7 @@ class TokenDecoder(BaseDecoder): Token-based decoder for sequence labeling tasks. Uses BIO-style tagging with separate start, end, and inside predictions - to identify entity spans. + to identify entity spans. Can also decode from span-level predictions. """ def _get_indices_above_threshold(self, scores: torch.Tensor, threshold: float) -> List[torch.Tensor]: @@ -889,50 +889,544 @@ def _calculate_span_score( span_i.append((st, ed, id_to_classes[cls_st + 1], spn_score)) return span_i + def _decode_from_spans( + self, + tokens: List[List[str]], + id_to_classes: Union[Dict[int, str], List[Dict[int, str]]], + span_logits: torch.Tensor, + span_idx: torch.Tensor, + span_mask: torch.Tensor, + flat_ner: bool = False, + threshold: float = 0.5, + multi_label: bool = False, + ) -> List[List[tuple]]: + """ + Decode from span-level predictions. + + Args: + tokens (List[List[str]]): Tokenized input text for each sample in the batch. + id_to_classes (Union[Dict[int, str], List[Dict[int, str]]]): Mapping from + class IDs to class names. + span_logits (torch.Tensor): Span classification logits with shape (B, S, C), + where B is batch size, S is max spans, C is number of classes. + span_idx (torch.Tensor): Span indices with shape (B, S, 2), containing + [start, end] positions for each span. + span_mask (torch.Tensor): Boolean mask with shape (B, S) indicating + valid spans. + flat_ner (bool): Whether to enforce non-overlapping spans. + threshold (float): Confidence threshold for predictions. + multi_label (bool): Whether to allow multiple labels per span. + + Returns: + List[List[tuple]]: For each sample, list of span tuples in format + (start, end, entity_type, None, score). + """ + batch_size = span_logits.size(0) + spans = [] + + # Apply sigmoid to get probabilities + span_probs = torch.sigmoid(span_logits) + + for i in range(batch_size): + id_to_class_i = self._get_id_to_class_for_sample(id_to_classes, i) + span_scores = [] + + # Get valid spans for this sample + valid_mask = span_mask[i] + valid_indices = torch.where(valid_mask)[0] + + for span_pos in valid_indices: + span_start = span_idx[i, span_pos, 0].item() + span_end = span_idx[i, span_pos, 1].item() + + # Get probabilities for all classes for this span + probs = span_probs[i, span_pos] + + # Find classes above threshold + class_indices = torch.where(probs > threshold)[0] + + for class_idx in class_indices: + class_id = class_idx.item() + 1 # Convert to 1-indexed + if class_id in id_to_class_i: + entity_type = id_to_class_i[class_id] + score = probs[class_idx].item() + span_scores.append((span_start, span_end, entity_type, score)) + + # Apply greedy search to handle overlapping spans if needed + span_i = self.greedy_search(span_scores, flat_ner, multi_label) + spans.append(span_i) + return spans + def decode( self, tokens: List[List[str]], id_to_classes: Union[Dict[int, str], List[Dict[int, str]]], - model_output: torch.Tensor, + model_output: Optional[torch.Tensor] = None, flat_ner: bool = False, threshold: float = 0.5, multi_label: bool = False, + span_logits: Optional[torch.Tensor] = None, + span_idx: Optional[torch.Tensor] = None, + span_mask: Optional[torch.Tensor] = None, **kwargs, ) -> List[List[tuple]]: """ - Decode token-level predictions to extract spans. + Decode predictions to extract spans. + + Supports two decoding modes: + 1. Token-level BIO decoding (default): Uses model_output with start/end/inside predictions + 2. Span-level decoding: Uses span_logits, span_idx, and span_mask Args: tokens (List[List[str]]): Tokenized input text for each sample in the batch. id_to_classes (Union[Dict[int, str], List[Dict[int, str]]]): Mapping from class IDs to class names. - model_output (torch.Tensor): Raw logits from the model with shape ( B, L, C, 3), - where the first dimension represents [start, end, inside] predictions. + model_output (torch.Tensor, optional): Raw logits from the model with shape + (B, L, C, 3), where the last dimension represents [start, end, inside] + predictions. Used for token-level decoding. flat_ner (bool): Whether to enforce non-overlapping spans. threshold (float): Confidence threshold for predictions. multi_label (bool): Whether to allow multiple labels per span. + span_logits (torch.Tensor, optional): Span classification logits with shape + (B, S, C). Used for span-level decoding. + span_idx (torch.Tensor, optional): Span indices with shape (B, S, 2). + Used for span-level decoding. + span_mask (torch.Tensor, optional): Boolean mask with shape (B, S). + Used for span-level decoding. **kwargs: Additional keyword arguments (unused). Returns: List[List[tuple]]: For each sample, list of span tuples in format (start, end, entity_type, None, score). + + Raises: + ValueError: If neither model_output nor span-level inputs are provided, + or if span-level inputs are incomplete. + """ + # Check if span-level decoding is requested + if span_logits is not None and span_idx is not None and span_mask is not None: + return self._decode_from_spans( + tokens=tokens, + id_to_classes=id_to_classes, + span_logits=span_logits, + span_idx=span_idx, + span_mask=span_mask, + flat_ner=flat_ner, + threshold=threshold, + multi_label=multi_label, + ) + + # Check if token-level decoding is requested + if model_output is not None: + model_output = model_output.permute(3, 0, 1, 2) + scores_start, scores_end, scores_inside = model_output + spans = [] + + for i, _ in enumerate(tokens): + id_to_class_i = self._get_id_to_class_for_sample(id_to_classes, i) + span_scores = self._calculate_span_score( + self._get_indices_above_threshold(scores_start[i], threshold), + self._get_indices_above_threshold(scores_end[i], threshold), + torch.sigmoid(scores_inside[i]), + torch.sigmoid(scores_start[i]), + torch.sigmoid(scores_end[i]), + id_to_class_i, + threshold, + ) + span_i = self.greedy_search(span_scores, flat_ner, multi_label) + spans.append(span_i) + return spans + + # Neither decoding mode has sufficient inputs + if span_logits is not None or span_idx is not None or span_mask is not None: + raise ValueError( + "For span-level decoding, all three parameters must be provided: span_logits, span_idx, and span_mask" + ) + + raise ValueError( + "Either model_output (for token-level decoding) or " + "(span_logits, span_idx, span_mask) (for span-level decoding) must be provided" + ) + + +class TokenRelexDecoder(TokenDecoder): + """Token-based decoder with relation extraction support. + + Extends the token-based BIO decoder to decode both entity spans and the relations + between them. Entity spans are extracted first using BIO-style tagging logic, + then relations are decoded by identifying pairs of entities and their + relationship types based on model predictions. + + The decoder supports: + - Entity span extraction via BIO tagging with confidence thresholding + - Relation extraction between detected entities + - Flexible entity and relation label mappings (per-sample or global) + - Optional flat NER (non-overlapping entities) + - Multi-label entity classification + """ + + def _decode_relations( + self, + spans: List[List[tuple]], + rel_idx: Optional[torch.Tensor], + rel_logits: Optional[torch.Tensor], + rel_mask: Optional[torch.Tensor], + rel_id_to_classes: Union[Dict[int, str], List[Dict[int, str]]], + threshold: float, + batch_size: int, + ) -> List[List[tuple]]: + """Decode relations between detected entity spans. + + Extracts relation predictions from model outputs and maps them to pairs + of detected entity spans. For each potential relation, checks if both + head and tail entities exist in the decoded spans and if the relation + confidence exceeds the threshold. + + Args: + spans: List of entity spans for each sample in the batch. + Each sample contains a list of tuples: (start, end, entity_type, score). + rel_idx: Tensor of shape (batch_size, num_relations, 2) containing + indices of head and tail entities for each potential relation. + None if no relations to decode. + rel_logits: Tensor of shape (batch_size, num_relations, num_relation_classes) + containing logits for relation classifications. None if no relations. + rel_mask: Optional boolean tensor of shape (batch_size, num_relations) + indicating which relations are valid (True) vs. padding (False). + If None, all relations are considered valid. + rel_id_to_classes: Mapping from relation class IDs to relation names. + Can be either: + - Dict: Single mapping used for all samples + - List[Dict]: Per-sample mappings for different relation schemas + Class IDs are 1-indexed (0 reserved for "no relation" or padding). + threshold: Minimum confidence score (after sigmoid) for a relation + to be included in the output. Must be in range [0, 1]. + batch_size: Number of samples in the batch. + + Returns: + List of relation lists, one per sample. Each relation is a tuple: + (head_idx, relation_label, tail_idx, score) where: + - head_idx: Index into the sample's spans list for the head entity + - relation_label: String name of the relation type + - tail_idx: Index into the sample's spans list for the tail entity + - score: Confidence score for this relation (float, 0-1 range) + """ + relations = [[] for _ in range(batch_size)] + + # Check if relation outputs are available + if rel_idx is None or rel_logits is None: + return relations + + # Get or create relation mask + if rel_mask is None: + rel_mask = torch.ones(rel_idx[..., 0].shape, dtype=torch.bool, device=rel_idx.device) + + rel_probs = torch.sigmoid(rel_logits) + + # Decode relations for each sample + for i in range(batch_size): + rel_id_to_class_i = rel_id_to_classes[i] if isinstance(rel_id_to_classes, list) else rel_id_to_classes + + # Process each potential relation + for j in range(rel_idx.size(1)): + # Skip if masked out + if not rel_mask[i, j]: + continue + + head_idx = rel_idx[i, j, 0].item() + tail_idx = rel_idx[i, j, 1].item() + + # Skip invalid indices + if head_idx < 0 or tail_idx < 0: + continue + + # Skip if either span was removed by greedy search + if head_idx >= len(spans[i]) or tail_idx >= len(spans[i]): + continue + + # Check each relation class + for c, p in enumerate(rel_probs[i, j]): + prob = p.item() + + # Skip low confidence predictions + if prob <= threshold: + continue + + # Skip if class ID not in mapping (c + 1 because 0 is padding) + if (c + 1) not in rel_id_to_class_i: + continue + + rel_label = rel_id_to_class_i[c + 1] + relations[i].append((head_idx, rel_label, tail_idx, prob)) + + return relations + + def decode( + self, + tokens: List[List[str]], + id_to_classes: Union[Dict[int, str], List[Dict[int, str]]], + model_output: torch.Tensor, + rel_idx: Optional[torch.Tensor] = None, + rel_logits: Optional[torch.Tensor] = None, + rel_mask: Optional[torch.Tensor] = None, + flat_ner: bool = False, + threshold: float = 0.5, + relation_threshold: float = 0.5, + multi_label: bool = False, + rel_id_to_classes: Optional[Union[Dict[int, str], List[Dict[int, str]]]] = None, + **kwargs, + ) -> Tuple[List[List[tuple]], List[List[tuple]]]: + """Decode model output to extract entities and relations. + + Main decoding method that extracts both entity spans and relations from + model outputs. First decodes entity spans using BIO-style token tagging, + then decodes relations between the detected entities. + + Args: + tokens: Tokenized input text for each sample in the batch. + Each sample is a list of token strings. + id_to_classes: Mapping from entity class IDs to entity type names. + Can be either: + - Dict: Single mapping used for all samples (global entity schema) + - List[Dict]: Per-sample mappings for different entity schemas + Class IDs are 1-indexed (0 is reserved for padding). + model_output: Model output tensor with shape (B, L, C, 3) where the last + dimension contains [start, end, inside] predictions for BIO tagging. + rel_idx: Optional tensor of shape (batch_size, num_relations, 2) containing + head and tail entity indices for each potential relation. + rel_logits: Optional tensor of shape (batch_size, num_relations, num_relation_classes) + containing relation classification logits. + rel_mask: Optional boolean tensor of shape (batch_size, num_relations) + indicating valid relations. If None, all relations are considered valid. + flat_ner: If True, applies greedy filtering to ensure non-overlapping + entity spans. If False, allows overlapping entities. Defaults to False. + threshold: Minimum confidence score (0-1) for entity predictions + to be included in the output. Defaults to 0.5. + relation_threshold: Minimum confidence score (0-1) for relation + predictions to be included in the output. Defaults to 0.5. + multi_label: If True, allows multiple entity types per span. If False, + only the highest-scoring entity type per span is kept. Defaults to False. + rel_id_to_classes: Optional mapping from relation class IDs to relation names. + If None, relation decoding is skipped and empty relation lists are returned. + Can be either a single Dict or List[Dict] for per-sample mappings. + Class IDs are 1-indexed. + **kwargs: Additional keyword arguments passed to the parent class decode method. + + Returns: + Tuple of (spans, relations) where: + - spans: List of entity span lists, one per sample. Each entity span is + a tuple: (start, end, entity_type, score) + - relations: List of relation lists, one per sample. Each relation is + a tuple: (head_idx, relation_label, tail_idx, score) where head_idx + and tail_idx are indices into the corresponding sample's spans list. + + Examples: + >>> decoder = TokenRelexDecoder(config) + >>> tokens = [["John", "works", "at", "Microsoft"]] + >>> id_to_classes = {1: "PERSON", 2: "ORG"} + >>> rel_id_to_classes = {1: "works_at"} + >>> spans, relations = decoder.decode( + ... tokens=tokens, + ... id_to_classes=id_to_classes, + ... model_output=output, + ... rel_id_to_classes=rel_id_to_classes, + ... threshold=0.5, + ... ) + >>> # spans[0] might be: [(0, 0, "PERSON", 0.9), (3, 3, "ORG", 0.85)] + >>> # relations[0] might be: [(0, "works_at", 1, 0.8)] """ - model_output = model_output.permute(3, 0, 1, 2) - scores_start, scores_end, scores_inside = model_output + # Decode entity spans using parent class BIO-style logic + spans = super().decode( + tokens=tokens, + id_to_classes=id_to_classes, + model_output=model_output, + flat_ner=flat_ner, + threshold=threshold, + multi_label=multi_label, + **kwargs, + ) + + # Decode relations if requested + relations = [[] for _ in range(len(tokens))] + + if rel_id_to_classes is not None: + relations = self._decode_relations( + spans=spans, + rel_idx=rel_idx, + rel_logits=rel_logits, + rel_mask=rel_mask, + rel_id_to_classes=rel_id_to_classes, + threshold=relation_threshold, + batch_size=len(tokens), + ) + + return spans, relations + + +class TokenGenerativeDecoder(TokenDecoder, SpanGenerativeDecoder): + """Token-based decoder with generative label support. + + Extends the token decoder to support generated labels from an encoder-decoder + architecture. Supports two decoder modes: + - 'prompt': Generated labels replace the original class names + - 'span': Generated labels are added as additional fields to each span + + Returns spans in format: (start, end, entity_type, generated_entity_type, score) + """ + + def decode_generative( + self, + tokens: List[List[str]], + id_to_classes: Union[Dict[int, str], List[Dict[int, str]]], + model_output: torch.Tensor, + gen_labels: List[str], + sel_idx: Optional[torch.LongTensor] = None, + num_gen_sequences: int = 1, + flat_ner: bool = False, + threshold: float = 0.5, + multi_label: bool = False, + span_logits: Optional[torch.Tensor] = None, + span_idx: Optional[torch.Tensor] = None, + span_mask: Optional[torch.Tensor] = None, + ) -> List[List[tuple]]: + """Decode model output with generated labels. + + Handles both 'prompt' and 'span' decoder modes: + - prompt mode: Generated labels replace class names in id_to_classes + - span mode: Generated labels are added to span tuples via span_label_map + + Args: + tokens (List[List[str]]): Tokenized input text for each sample in the batch. + id_to_classes (Union[Dict[int, str], List[Dict[int, str]]]): Mapping from + class IDs to class names. + model_output (torch.Tensor): Raw logits from the model with shape (B, W, C, 3). + gen_labels (List[str]): Generated labels from the decoder, flattened across batch. + sel_idx (Optional[torch.LongTensor]): Tensor of shape (B, M) with selected + span indices. Required for span mode, unused for prompt mode. + num_gen_sequences (int): Number of label sequences generated per span. + flat_ner (bool): Whether to enforce non-overlapping spans. + threshold (float): Confidence threshold for span predictions. + multi_label (bool): Whether to allow multiple labels per span. + span_logits (torch.Tensor, optional): Span classification logits. + span_idx (torch.Tensor, optional): Span indices. + span_mask (torch.Tensor, optional): Span mask. + + Returns: + List[List[tuple]]: For each sample, list of span tuples with generated labels. + """ + B = model_output.size(0) + + # Handle prompt mode: update id_to_classes with generated labels + if self.config.decoder_mode == "prompt": + id_to_classes = self._update_id_to_classes_with_generated(id_to_classes, gen_labels, B) + span_label_maps = [{} for _ in range(B)] + + # Handle span mode: build span_label_map from sel_idx and gen_labels + elif self.config.decoder_mode == "span": + if sel_idx is not None: + span_label_maps = self._build_span_label_map_for_batch(sel_idx, gen_labels, num_gen_sequences) + else: + span_label_maps = [{} for _ in range(B)] + else: + span_label_maps = [{} for _ in range(B)] + + batch_size = span_logits.size(0) spans = [] - for i, _ in enumerate(tokens): + span_probs = torch.sigmoid(span_logits) + + for i in range(batch_size): id_to_class_i = self._get_id_to_class_for_sample(id_to_classes, i) - span_scores = self._calculate_span_score( - self._get_indices_above_threshold(scores_start[i], threshold), - self._get_indices_above_threshold(scores_end[i], threshold), - torch.sigmoid(scores_inside[i]), - torch.sigmoid(scores_start[i]), - torch.sigmoid(scores_end[i]), - id_to_class_i, - threshold, - ) + span_label_map_i = span_label_maps[i] + span_scores = [] + + valid_mask = span_mask[i] + valid_indices = torch.where(valid_mask)[0] + + for span_pos in valid_indices: + span_start = span_idx[i, span_pos, 0].item() + span_end = span_idx[i, span_pos, 1].item() + + probs = span_probs[i, span_pos] + class_indices = torch.where(probs > threshold)[0] + + for class_idx in class_indices: + class_id = class_idx.item() + 1 + if class_id in id_to_class_i: + entity_type = id_to_class_i[class_id] + score = probs[class_idx].item() + gen_label = span_label_map_i.get(span_pos.item()) + span_scores.append((span_start, span_end, entity_type, gen_label, score)) + span_i = self.greedy_search(span_scores, flat_ner, multi_label) spans.append(span_i) return spans + + def decode( + self, + tokens: List[List[str]], + id_to_classes: Union[Dict[int, str], List[Dict[int, str]]], + model_output: Optional[torch.Tensor] = None, + flat_ner: bool = False, + threshold: float = 0.5, + multi_label: bool = False, + gen_labels: Optional[List[str]] = None, + sel_idx: Optional[torch.LongTensor] = None, + num_gen_sequences: int = 1, + span_logits: Optional[torch.Tensor] = None, + span_idx: Optional[torch.Tensor] = None, + span_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> List[List[tuple]]: + """Decode model output, with optional generative label support. + + If gen_labels are provided and decoder has a labels_decoder, uses generative + decoding. Otherwise falls back to standard token decoding. + + Args: + tokens: Tokenized input text. + id_to_classes: Class ID to name mapping. + model_output: Token-level logits (B, W, C, 3). + flat_ner: Whether to enforce non-overlapping spans. + threshold: Confidence threshold. + multi_label: Allow multiple labels per span. + gen_labels: Generated labels from decoder. + sel_idx: Selected span indices for span mode. + num_gen_sequences: Number of sequences per span. + span_logits: Span classification logits. + span_idx: Span indices. + span_mask: Span mask. + **kwargs: Additional arguments. + + Returns: + List of span tuples, with generated labels if available. + """ + # Use generative decoding if labels_decoder is configured and gen_labels provided + if self.config.labels_decoder is not None and gen_labels is not None: + return self.decode_generative( + tokens=tokens, + id_to_classes=id_to_classes, + model_output=model_output, + gen_labels=gen_labels, + sel_idx=sel_idx, + num_gen_sequences=num_gen_sequences, + flat_ner=flat_ner, + threshold=threshold, + multi_label=multi_label, + span_logits=span_logits, + span_idx=span_idx, + span_mask=span_mask, + ) + + # Fall back to standard decoding without generative labels + return super().decode( + tokens=tokens, + id_to_classes=id_to_classes, + model_output=model_output, + flat_ner=flat_ner, + threshold=threshold, + multi_label=multi_label, + span_logits=span_logits, + span_idx=span_idx, + span_mask=span_mask, + ) diff --git a/gliner/evaluation/evaluator.py b/gliner/evaluation/evaluator.py index 27f66e3..cc5bd4c 100644 --- a/gliner/evaluation/evaluator.py +++ b/gliner/evaluation/evaluator.py @@ -203,6 +203,7 @@ class BaseRelexEvaluator(BaseEvaluator): The input format expects entity indices rather than entity spans directly. Entity spans are looked up from the entity list using these indices. """ + def get_ground_truth(self, ents, rels): """Extract ground truth relations in evaluation format. diff --git a/gliner/model.py b/gliner/model.py index d2b1ac2..8a2e55e 100644 --- a/gliner/model.py +++ b/gliner/model.py @@ -31,9 +31,18 @@ UniEncoderSpanConfig, UniEncoderTokenConfig, UniEncoderSpanRelexConfig, + UniEncoderTokenRelexConfig, UniEncoderSpanDecoderConfig, + UniEncoderTokenDecoderConfig, +) +from .decoding import ( + SpanDecoder, + TokenDecoder, + SpanRelexDecoder, + TokenRelexDecoder, + SpanGenerativeDecoder, + TokenGenerativeDecoder, ) -from .decoding import SpanDecoder, TokenDecoder, SpanRelexDecoder, SpanGenerativeDecoder from .training import Trainer, TrainingArguments from .evaluation import BaseNEREvaluator, BaseRelexEvaluator from .onnx.model import ( @@ -43,6 +52,7 @@ UniEncoderSpanORTModel, UniEncoderTokenORTModel, UniEncoderSpanRelexORTModel, + UniEncoderTokenRelexORTModel, ) from .decoding.trie import LabelsTrie from .infer_packing import InferencePackingConfig @@ -53,7 +63,9 @@ UniEncoderSpanModel, UniEncoderTokenModel, UniEncoderSpanRelexModel, + UniEncoderTokenRelexModel, UniEncoderSpanDecoderModel, + UniEncoderTokenDecoderModel, ) from .data_processing import ( BaseProcessor, @@ -63,6 +75,8 @@ UniEncoderTokenProcessor, UniEncoderSpanDecoderProcessor, RelationExtractionSpanProcessor, + UniEncoderTokenDecoderProcessor, + RelationExtractionTokenProcessor, ) from .data_processing.collator import ( BiEncoderSpanDataCollator, @@ -71,6 +85,8 @@ UniEncoderTokenDataCollator, UniEncoderSpanDecoderDataCollator, RelationExtractionSpanDataCollator, + UniEncoderTokenDecoderDataCollator, + RelationExtractionTokenDataCollator, ) from .data_processing.tokenizer import WordsSplitter @@ -987,7 +1003,7 @@ def create_training_args( save_total_limit: int = 10, logging_steps: int = 10, use_cpu: bool = False, - bf16: bool = True, + bf16: bool = False, dataloader_num_workers: int = 1, report_to: str = "none", **kwargs, @@ -1201,7 +1217,8 @@ def _process_batches(self, data_loader, threshold, flat_ner, multi_label, packin ) # Get predictions - model_logits = self.model(**model_inputs, threshold=threshold)[0] + model_output = self.model(**model_inputs, threshold=threshold) + model_logits = model_output[0] if not isinstance(model_logits, torch.Tensor): model_logits = torch.from_numpy(model_logits) @@ -1210,6 +1227,9 @@ def _process_batches(self, data_loader, threshold, flat_ner, multi_label, packin batch["tokens"], batch["id_to_classes"], model_logits, + span_idx=model_output.span_idx, + span_mask=model_output.span_mask, + span_logits=model_output.span_logits, flat_ner=flat_ner, threshold=threshold, multi_label=multi_label, @@ -2103,6 +2123,21 @@ def export_to_onnx( ) +class UniEncoderTokenDecoderGLiNER(UniEncoderSpanDecoderGLiNER): + """GLiNER model with token-based encoding and label decoding capabilities. + + Combines token-level BIO tagging with a decoder that generates entity type + labels autoregressively. + """ + + config_class = UniEncoderTokenDecoderConfig + model_class = UniEncoderTokenDecoderModel + ort_model_class = None + data_processor_class = UniEncoderTokenDecoderProcessor + data_collator_class = UniEncoderTokenDecoderDataCollator + decoder_class = TokenGenerativeDecoder + + class UniEncoderSpanRelexGLiNER(BaseEncoderGLiNER): """GLiNER model for both entity recognition and relation extraction. @@ -2419,13 +2454,13 @@ def evaluate( - rel_f1: Relation extraction F1 score """ self.eval() - + if relation_threshold is None: relation_threshold = threshold - + if adjacency_threshold is None: adjacency_threshold = threshold - + # Create the dataset and data loader dataset = test_data collator = self.data_collator_class( @@ -2438,9 +2473,7 @@ def evaluate( return_rel_id_to_classes=True, prepare_labels=False, ) - data_loader = torch.utils.data.DataLoader( - dataset, batch_size=batch_size, shuffle=False, collate_fn=collator - ) + data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collator) all_entity_preds = [] all_relation_preds = [] @@ -2454,9 +2487,7 @@ def evaluate( # Get model predictions model_inputs = batch.copy() - model_output = self.model( - **model_inputs, threshold=threshold, adjacency_threshold=adjacency_threshold - ) + model_output = self.model(**model_inputs, threshold=threshold, adjacency_threshold=adjacency_threshold) # Extract logits and relation outputs model_logits = model_output.logits @@ -2512,7 +2543,7 @@ def evaluate( # Format data for relation evaluator: list of (entities, relations) tuples all_true_rel_data = list(zip(all_true_entities, all_true_relations)) all_pred_rel_data = list(zip(all_entity_preds, all_relation_preds)) - + rel_evaluator = BaseRelexEvaluator(all_true_rel_data, all_pred_rel_data) rel_output, rel_f1 = rel_evaluator.evaluate() @@ -2595,6 +2626,89 @@ def forward( return UniEncoderSpanRelexWrapper(core_model) +class UniEncoderTokenRelexGLiNER(UniEncoderSpanRelexGLiNER): + """GLiNER model for both entity recognition and relation extraction. + + Performs joint entity and relation prediction, allowing the model to simultaneously + detect entities and the relationships between them in a single forward pass. + """ + + config_class = UniEncoderTokenRelexConfig + model_class = UniEncoderTokenRelexModel + ort_model_class: type = UniEncoderTokenRelexORTModel + data_processor_class = RelationExtractionTokenProcessor + data_collator_class = RelationExtractionTokenDataCollator + decoder_class = TokenRelexDecoder + + def _get_onnx_input_spec(self) -> dict[str, Any]: + """Define ONNX input specification for UniEncoderSpanRelex model.""" + return { + "input_names": [ + "input_ids", + "attention_mask", + "words_mask", + "text_lengths", + ], + "output_names": ["logits", "rel_idx", "rel_logits", "rel_mask"], + "dynamic_axes": { + "input_ids": {0: "batch_size", 1: "sequence_length"}, + "attention_mask": {0: "batch_size", 1: "sequence_length"}, + "words_mask": {0: "batch_size", 1: "sequence_length"}, + "text_lengths": {0: "batch_size", 1: "value"}, + "logits": { + 0: "batch_size", + 1: "sequence_length", + 2: "num_ent_classes", + 3: "num_idx_classes", + }, + "rel_idx": { + 0: "batch_size", + 1: "num_pairs", + 2: "pair_index", + }, + "rel_logits": { + 0: "batch_size", + 1: "num_pairs", + 2: "num_rel_classes", + }, + "rel_mask": { + 0: "batch_size", + 1: "num_pairs", + }, + }, + } + + def _get_onnx_export_kwargs(self) -> dict[str, Any]: + """Provide default labels for relation extraction ONNX export.""" + return {"labels": ["head", "tail"]} + + def _create_onnx_wrapper(self, core_model: nn.Module) -> nn.Module: + """Create wrapper for UniEncoderSpanRelex ONNX export.""" + + class UniEncoderTokenRelexWrapper(nn.Module): + def __init__(self, core): + super().__init__() + self.core = core + + def forward( + self, + input_ids, + attention_mask, + words_mask, + text_lengths, + ): + out = self.core( + input_ids=input_ids, + attention_mask=attention_mask, + words_mask=words_mask, + text_lengths=text_lengths, + ) + # Return all outputs for relation extraction + return out.logits, out.rel_idx, out.rel_logits, out.rel_mask + + return UniEncoderTokenRelexWrapper(core_model) + + class GLiNER(nn.Module, PyTorchModelHubMixin): """Meta GLiNER class that automatically instantiates the appropriate GLiNER variant. @@ -2671,23 +2785,17 @@ def __init__(self, config: Union[str, Path, GLiNERConfig], **kwargs): @staticmethod def _get_gliner_class(config: GLiNERConfig): - """Determine the appropriate GLiNER class based on configuration. - - Args: - config: GLiNER configuration object. - - Returns: - The appropriate GLiNER class type. - """ + """Determine the appropriate GLiNER class based on configuration.""" is_token_level = config.span_mode == "token_level" has_labels_encoder = config.labels_encoder is not None has_labels_decoder = config.labels_decoder is not None has_relations = config.relations_layer is not None - # Priority order: relations > decoder > bi-encoder > token vs span - if has_relations: - return UniEncoderSpanRelexGLiNER + if is_token_level: + return UniEncoderTokenRelexGLiNER + else: + return UniEncoderSpanRelexGLiNER if has_labels_decoder: if has_labels_encoder: @@ -2696,6 +2804,8 @@ def _get_gliner_class(config: GLiNERConfig): "Using decoder model (labels_encoder will be ignored).", stacklevel=2, ) + if is_token_level: + return UniEncoderTokenDecoderGLiNER return UniEncoderSpanDecoderGLiNER if has_labels_encoder: @@ -2704,7 +2814,6 @@ def _get_gliner_class(config: GLiNERConfig): else: return BiEncoderSpanGLiNER - # Default: uni-encoder if is_token_level: return UniEncoderTokenGLiNER else: @@ -2951,11 +3060,21 @@ def model_map(self) -> dict[str, dict[str, Any]]: "description": "Span-based NER with label generation decoder", "config": {"span_mode": "span_level", "labels_decoder": "required", "relations_layer": None}, }, + "gliner_uni_encoder_token_decoder": { + "class": UniEncoderTokenDecoderGLiNER, + "description": "Token-level NER with label generation decoder", + "config": {"span_mode": "token_level", "labels_decoder": "required", "relations_layer": None}, + }, "gliner_uni_encoder_span_relex": { "class": UniEncoderSpanRelexGLiNER, "description": "Joint entity and relation extraction with single encoder", "config": {"span_mode": "span_level", "labels_encoder": None, "relations_layer": "required"}, }, + "gliner_uni_encoder_token_relex": { + "class": UniEncoderTokenRelexGLiNER, + "description": "Joint entity and relation extraction with single encoder using token-level architecture", + "config": {"span_mode": "token_level", "labels_encoder": None, "relations_layer": "required"}, + }, } def get_model_type(self) -> str: @@ -2973,7 +3092,9 @@ def get_model_type(self) -> str: "BiEncoderSpanGLiNER": "gliner_bi_encoder_span", "BiEncoderTokenGLiNER": "gliner_bi_encoder_token", "UniEncoderSpanDecoderGLiNER": "gliner_uni_encoder_span_decoder", + "UniEncoderTokenDecoderGLiNER": "gliner_uni_encoder_token_decoder", "UniEncoderSpanRelexGLiNER": "gliner_uni_encoder_span_relex", + "UniEncoderTokenRelexGLiNER": "gliner_uni_encoder_token_relex", } return type_mapping.get(class_name, "unknown") diff --git a/gliner/modeling/__init__.py b/gliner/modeling/__init__.py index 2e3095f..8b45611 100644 --- a/gliner/modeling/__init__.py +++ b/gliner/modeling/__init__.py @@ -7,5 +7,7 @@ UniEncoderSpanModel, UniEncoderTokenModel, UniEncoderSpanRelexModel, + UniEncoderTokenRelexModel, UniEncoderSpanDecoderModel, + UniEncoderTokenDecoderModel, ) diff --git a/gliner/modeling/base.py b/gliner/modeling/base.py index b194a4c..57bc2f2 100644 --- a/gliner/modeling/base.py +++ b/gliner/modeling/base.py @@ -29,6 +29,7 @@ build_entity_pairs, extract_prompt_features, extract_word_embeddings, + extract_spans_from_tokens, extract_prompt_features_and_word_embeddings, ) from .layers import CrossFuser, LstmSeq2SeqEncoder, create_projection_layer @@ -232,7 +233,7 @@ def __init__( super().__init__(config, from_pretrained, cache_dir) self.token_rep_layer = Encoder(config, from_pretrained, cache_dir=cache_dir) - if self.config.num_rnn_layers>0: + if self.config.num_rnn_layers > 0: self.rnn = LstmSeq2SeqEncoder(config, num_layers=self.config.num_rnn_layers) if config.post_fusion_schema: @@ -426,7 +427,7 @@ def loss( scores: torch.Tensor, labels: torch.Tensor, prompts_embedding_mask: torch.Tensor, - mask_label: torch.Tensor, + span_mask: torch.Tensor, alpha: float = -1.0, gamma: float = 0.0, prob_margin: float = 0.0, @@ -442,7 +443,7 @@ def loss( scores: Predicted scores of shape (B, L, K, C). labels: Ground truth labels of shape (B, L, K, C). prompts_embedding_mask: Mask for valid entity types of shape (B, C). - mask_label: Mask for valid spans of shape (B, L, K). + span_mask: Mask for valid spans of shape (B, L, K). alpha: Focal loss alpha parameter. gamma: Focal loss gamma parameter. prob_margin: Margin for probability adjustment. @@ -471,9 +472,9 @@ def loss( masked_loss = all_losses.view(batch_size, -1, num_classes) * prompts_embedding_mask.unsqueeze(1) all_losses = masked_loss.view(-1, num_classes) - mask_label = mask_label.view(-1, 1) + span_mask = span_mask.view(-1, 1) - all_losses = all_losses * mask_label.float() + all_losses = all_losses * span_mask.float() if reduction == "mean": loss = all_losses.mean() @@ -512,17 +513,48 @@ def __init__( super().__init__(config, from_pretrained, cache_dir) self.scorer = Scorer(config.hidden_size, config.dropout) + if getattr(config, "represent_spans", False): + self.span_rep_layer = SpanRepLayer( + span_mode=config.span_mode, + hidden_size=config.hidden_size, + max_width=getattr(config, "max_width", 12), + dropout=config.dropout, + ) + + def get_span_representations(self, scores, span_idx, span_mask, words_embedding, labels, threshold): + if span_idx is None: + span_idx, span_mask = extract_spans_from_tokens(scores, labels, threshold) + span_idx = span_idx * span_mask.unsqueeze(-1).long() + + if getattr(self.config, "represent_spans", False): + span_rep = self.span_rep_layer(words_embedding, span_idx) + else: + span_rep = words_embedding[span_idx[:, :, 0]] + words_embedding[span_idx[:, :, 1]] + + # span_rep = torch.zeros(B, S, D, device=words_embedding.device, dtype=words_embedding.dtype) + # for b in range(B): + # for s in range(S): + # if span_mask[b, s]: + # start, end = span_idx[b, s] + # span_rep[b, s] = words_embedding[b, start:end+1].mean(dim=0) + + return span_rep, span_idx, span_mask + def forward( self, input_ids: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.LongTensor] = None, words_embedding: Optional[torch.FloatTensor] = None, mask: Optional[torch.LongTensor] = None, + span_idx: Optional[torch.Tensor] = None, + span_mask: Optional[torch.Tensor] = None, + span_labels: Optional[torch.Tensor] = None, prompts_embedding: Optional[torch.FloatTensor] = None, prompts_embedding_mask: Optional[torch.LongTensor] = None, words_mask: Optional[torch.LongTensor] = None, text_lengths: Optional[torch.Tensor] = None, labels: Optional[torch.FloatTensor] = None, + threshold: Optional[float] = 0.5, **kwargs: Any, ) -> GLiNERBaseOutput: """Forward pass through the token-based model. @@ -531,16 +563,22 @@ def forward( input_ids: Input token IDs of shape (B, L). attention_mask: Attention mask of shape (B, L). words_embedding: Pre-computed word embeddings of shape (B, W, D). - mask: Mask for words of shape (B, W). + mask: Mask for valid words of shape (B, W). + span_idx: Tensor containing span start/end indices of shape (B, S, 2), + where S is the number of spans. + span_mask: Boolean or integer mask indicating valid spans of shape (B, S). + span_labels: Ground truth span labels of shape (B, S, C). prompts_embedding: Pre-computed entity label embeddings of shape (B, C, D). - prompts_embedding_mask: Mask for prompts of shape (B, C). - words_mask: Word boundary mask. - text_lengths: Length of each text sequence. - labels: Ground truth labels of shape (B, W, C). - **kwargs: Additional arguments. + prompts_embedding_mask: Mask for prompts/entities of shape (B, C). + words_mask: Word boundary mask mapping tokens to words. + text_lengths: Length of each text sequence before padding, shape (B,). + labels: Ground truth token-level labels of shape (B, W, C). + threshold: Confidence threshold used for span selection. + **kwargs: Additional arguments passed to the encoder or loss functions + (e.g., ``packing_config``, ``pair_attention_mask``). Returns: - GLiNERBaseOutput containing logits, loss, and intermediate representations. + GLiNERBaseOutput containing logits, loss, embeddings, and span-level outputs. """ encoder_kwargs = {key: kwargs[key] for key in ("packing_config", "pair_attention_mask") if key in kwargs} @@ -560,12 +598,24 @@ def forward( prompts_embedding, prompts_embedding_mask, target_C ) + # Shape: (batch_size, seq_len, num_classes, 3), 3 - start, end, inside scores = self.scorer(words_embedding, prompts_embedding) + if getattr(self.config, "represent_spans", False): + span_rep, span_idx, span_mask = self.get_span_representations( + scores, span_idx, span_mask, words_embedding, labels, threshold + ) + span_logits = torch.einsum("BND,BCD->BNC", span_rep, prompts_embedding) + else: + span_logits, span_idx, span_mask = None, None, None loss = None if labels is not None: loss = self.loss(scores, labels, prompts_embedding_mask, mask, **kwargs) + if span_labels is not None: + span_loss = self.loss(span_logits, span_labels, prompts_embedding_mask, span_mask, **kwargs) + loss = self.config.token_loss_coef * loss + self.config.span_loss_coef * span_loss + output = GLiNERBaseOutput( logits=scores, loss=loss, @@ -573,6 +623,9 @@ def forward( prompts_embedding_mask=prompts_embedding_mask, words_embedding=words_embedding, mask=mask, + span_idx=span_idx, + span_logits=span_logits, + span_mask=span_mask, ) return output @@ -581,7 +634,7 @@ def loss( scores: torch.Tensor, labels: torch.Tensor, prompts_embedding_mask: torch.Tensor, - mask: torch.Tensor, + word_mask: torch.Tensor, alpha: float = -1.0, gamma: float = 0.0, prob_margin: float = 0.0, @@ -590,27 +643,40 @@ def loss( negatives: float = 1.0, **kwargs: Any, ) -> torch.Tensor: - """Compute token classification loss. + """Compute token- or span-level classification loss. Args: - scores: Predicted scores of shape (B, W, C). - labels: Ground truth labels of shape (B, W, C). + scores: Predicted scores. Shape is (B, W, C, 3) for token-level + classification (start, end, inside) or (B, N, C) for span-level + classification, where B is batch size, W is number of words, + N is number of spans, and C is number of entity types. + labels: Ground truth labels matching ``scores`` shape. prompts_embedding_mask: Mask for valid entity types of shape (B, C). - mask: Mask for valid tokens of shape (B, W). - alpha: Focal loss alpha parameter. - gamma: Focal loss gamma parameter. - prob_margin: Margin for probability adjustment. - label_smoothing: Label smoothing factor. - reduction: Loss reduction method ('sum' or 'mean'). - negatives: Negative sampling probability. - **kwargs: Additional arguments. + word_mask: Mask for valid tokens or spans of shape (B, W) for + token-level loss or (B, N) for span-level loss. + alpha: Alpha parameter for focal loss. If negative, focal weighting + is disabled. + gamma: Gamma parameter for focal loss. + prob_margin: Margin applied to predicted probabilities. + label_smoothing: Amount of label smoothing applied to targets. + reduction: Reduction method applied to the final loss. One of + ``"none"``, ``"mean"``, or ``"sum"``. + negatives: Weighting factor for negative examples. + **kwargs: Additional unused keyword arguments for API compatibility. Returns: - Scalar loss tensor. + A scalar tensor representing the aggregated loss value. """ all_losses = self._loss(scores, labels, alpha, gamma, prob_margin, label_smoothing, negatives) - all_losses = all_losses * (mask.unsqueeze(-1) * prompts_embedding_mask.unsqueeze(1)).unsqueeze(-1) + # Base mask: (B, W/N, C) + mask = word_mask.unsqueeze(-1) * prompts_embedding_mask.unsqueeze(1) + + # Only add extra dimension for 4D token-level scores (B, W, C, 3) + if all_losses.dim() == 4: + mask = mask.unsqueeze(-1) + + all_losses = all_losses * mask if reduction == "mean": loss = all_losses.mean() @@ -924,7 +990,7 @@ def loss( return loss -class BiEncoderTokenModel(BaseBiEncoderModel): +class BiEncoderTokenModel(BaseBiEncoderModel, UniEncoderTokenModel): """Token-based NER model using bi-encoder architecture. Attributes: @@ -953,11 +1019,15 @@ def forward( labels_attention_mask: Optional[torch.LongTensor] = None, words_embedding: Optional[torch.FloatTensor] = None, mask: Optional[torch.LongTensor] = None, + span_idx: Optional[torch.Tensor] = None, + span_mask: Optional[torch.Tensor] = None, + span_labels: Optional[torch.Tensor] = None, prompts_embedding: Optional[torch.FloatTensor] = None, prompts_embedding_mask: Optional[torch.LongTensor] = None, words_mask: Optional[torch.LongTensor] = None, text_lengths: Optional[torch.Tensor] = None, labels: Optional[torch.FloatTensor] = None, + threshold: Optional[float] = 0.5, **kwargs: Any, ) -> GLiNERBaseOutput: """Forward pass through the bi-encoder token model. @@ -970,11 +1040,16 @@ def forward( labels_attention_mask: Attention mask for labels. words_embedding: Pre-computed word embeddings. mask: Mask for words. + span_idx: Tensor containing span start/end indices of shape (B, S, 2), + where S is the number of spans. + span_mask: Boolean or integer mask indicating valid spans of shape (B, S). + span_labels: Ground truth span labels of shape (B, S, C). prompts_embedding: Pre-computed entity label embeddings. prompts_embedding_mask: Mask for prompts. words_mask: Word boundary mask. text_lengths: Length of each text sequence. labels: Ground truth labels of shape (B, W, C). + threshold: float value for filtering spans. **kwargs: Additional arguments. Returns: @@ -1007,10 +1082,22 @@ def forward( scores = self.scorer(words_embedding, prompts_embedding) + if getattr(self.config, "represent_spans", False): + span_rep, span_idx, span_mask = self.get_span_representations( + scores, span_idx, span_mask, words_embedding, labels, threshold + ) + span_logits = torch.einsum("BND,BCD->BNC", span_rep, prompts_embedding) + else: + span_logits, span_idx, span_mask = None, None, None + loss = None if labels is not None: loss = self.loss(scores, labels, prompts_embedding_mask, mask, **kwargs) + if span_labels is not None: + span_loss = self.loss(span_logits, span_labels, prompts_embedding_mask, span_mask, **kwargs) + loss = self.config.token_loss_coef * loss + self.config.span_loss_coef * span_loss + output = GLiNERBaseOutput( logits=scores, loss=loss, @@ -1018,58 +1105,12 @@ def forward( prompts_embedding_mask=prompts_embedding_mask, words_embedding=words_embedding, mask=mask, + span_idx=span_idx, + span_logits=span_logits, + span_mask=span_mask, ) return output - def loss( - self, - scores: torch.Tensor, - labels: torch.Tensor, - prompts_embedding_mask: torch.Tensor, - mask: torch.Tensor, - alpha: float = -1.0, - gamma: float = 0.0, - prob_margin: float = 0.0, - label_smoothing: float = 0.0, - reduction: str = "sum", - negatives: float = 1.0, - **kwargs: Any, - ) -> torch.Tensor: - """Compute token classification loss for bi-encoder. - - Args: - scores: Predicted scores of shape (B, W, C). - labels: Ground truth labels of shape (B, W, C). - prompts_embedding_mask: Mask for valid entity types of shape (B, C). - mask: Mask for valid tokens of shape (B, W). - alpha: Focal loss alpha parameter. - gamma: Focal loss gamma parameter. - prob_margin: Margin for probability adjustment. - label_smoothing: Label smoothing factor. - reduction: Loss reduction method ('sum' or 'mean'). - negatives: Negative sampling probability. - **kwargs: Additional arguments. - - Returns: - Scalar loss tensor. - """ - all_losses = self._loss(scores, labels, alpha, gamma, prob_margin, label_smoothing, negatives) - - all_losses = all_losses * (mask.unsqueeze(-1) * prompts_embedding_mask.unsqueeze(1)).unsqueeze(-1) - - if reduction == "mean": - loss = all_losses.mean() - elif reduction == "sum": - loss = all_losses.sum() - else: - warnings.warn( - f"Invalid Value for config 'loss_reduction': '{reduction}' \n Supported reduction modes:" - f" 'none', 'mean', 'sum'. It will be used 'sum' instead.", - stacklevel=2, - ) - loss = all_losses.sum() - return loss - class UniEncoderSpanDecoderModel(UniEncoderSpanModel): """Span-based model with decoder for generating entity type labels. @@ -1094,6 +1135,9 @@ def __init__( cache_dir: Directory for caching pretrained models. """ super().__init__(config, from_pretrained, cache_dir) + self.__init_decoder__(config, from_pretrained, cache_dir) + + def __init_decoder__(self, config, from_pretrained, cache_dir): self.decoder = Decoder(config, from_pretrained, cache_dir=cache_dir) if self.config.hidden_size != self.decoder.decoder_hidden_size: self._enc2dec_proj = create_projection_layer( @@ -1561,6 +1605,382 @@ def loss( return loss +class UniEncoderTokenDecoderModel(UniEncoderTokenModel, UniEncoderSpanDecoderModel): + """Token-based NER model with encoder-decoder architecture. + + This model combines token-level BIO-style classification with a decoder + that generates entity type labels autoregressively, enabling more flexible + prediction strategies for token-level NER tasks. + + Inherits from: + - UniEncoderTokenModel: Token-level BIO tagging for entities + - UniEncoderSpanDecoderModel: Encoder-decoder architecture and decoder utilities + + Attributes: + scorer (Scorer): Scoring layer for computing token-label compatibility. + decoder (Decoder): Decoder module for label generation. + span_rep_layer (SpanRepLayer): Layer for computing span representations (if represent_spans=True). + _enc2dec_proj (Optional[nn.Module]): Projection layer if encoder and decoder + dimensions differ. + """ + + def __init__( + self, config: Any, from_pretrained: bool = False, cache_dir: Optional[Union[str, Path]] = None + ) -> None: + """Initialize the token-level encoder-decoder model. + + Args: + config: Model configuration object. + from_pretrained: Whether to load from pretrained weights. + cache_dir: Directory for caching pretrained models. + """ + # Initialize through UniEncoderTokenModel's chain to get scorer + super().__init__(config, from_pretrained, cache_dir) + self.__init_decoder__(config, from_pretrained, cache_dir) + + def select_token_decoder_embedding( + self, + prompts_embedding: torch.Tensor, + prompts_embedding_mask: torch.Tensor, + span_logits: Optional[torch.Tensor] = None, + span_rep: Optional[torch.Tensor] = None, + span_idx: Optional[torch.Tensor] = None, + span_mask: Optional[torch.Tensor] = None, + span_labels: Optional[torch.Tensor] = None, + decoder_text_embeds: Optional[torch.Tensor] = None, + decoder_words_mask: Optional[torch.Tensor] = None, + top_k: Optional[int] = None, + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + """Select entity embeddings for decoder input based on token predictions or labels. + + This method extracts entity spans from token-level predictions and prepares + their representations for the decoder. It can operate in two modes: + 1. "prompt" mode: Uses entity type embeddings as decoder input. + 2. "span" mode: Uses contextualized tokens within each detected span as decoder input. + + Args: + prompts_embedding: Entity type embeddings of shape (B, C, D). + prompts_embedding_mask: Mask for prompts of shape (B, C). + words_embedding: Word-level embeddings of shape (B, W, D). + token_scores: Token classification scores of shape (B, W, C, 3). + word_mask: Mask for valid words of shape (B, W). + span_logits: Span-level classification logits of shape (B, S, C), + span_rep: Span representations (B, S, D). + span_idx: Pre-computed span indices of shape (B, S, 2). + span_labels: Ground truth span labels of shape (B, S, C) + span_mask: Pre-computed span mask of shape (B, S). + decoder_text_embeds: Text embeddings for span mode of shape (B, T, D). + decoder_words_mask: Word position mask of shape (B, T). + labels: Ground truth token-level labels of shape (B, W, C, 3). + threshold: Confidence threshold for selecting spans. + top_k: Optional limit on number of spans to select. + + Returns: + Tuple containing: + - span_rep_kept: Selected span embeddings for decoder of shape (B, S, T, D) + or None if no valid spans. + - span_msk: Mask for selected spans of shape (B, S, T) or None. + - span_sel_idx: Original indices of selected spans of shape (B, S) or None. + """ + if self.config.decoder_mode == "prompt": + return self.select_decoder_embedding(prompts_embedding, prompts_embedding_mask)[:3] + + span_scores = torch.sigmoid(span_logits) + + # Flatten span representations for selection + B = span_rep.size(0) + + if span_labels is not None: + # During training: select spans where any class is positive (label == 1) + span_prob = span_labels.max(-1).values # (B, S) + keep = (span_prob == 1) & span_mask.bool() + else: + # During inference: use predicted scores + span_scores = torch.sigmoid(span_logits).max(-1).values # (B, S) + keep = (span_scores > 0.5) & span_mask.bool + + if top_k: + sel_scores = span_scores.masked_fill(~keep, -1.0) + top_idx = sel_scores.topk(k=min(top_k, sel_scores.size(1)), dim=1).indices + keep.zero_() + keep.scatter_(1, top_idx, True) + + # Pack valid spans + span_rep_kept, span_msk, span_sel_idx = self.select_decoder_embedding(span_rep, keep.long()) + + if hasattr(self, "_enc2dec_proj"): + span_rep_kept = self._enc2dec_proj(span_rep_kept) + + span_rep_kept = span_rep_kept.unsqueeze(2) + span_msk = span_msk.unsqueeze(-1) + + if decoder_text_embeds is None or decoder_words_mask is None: + return span_rep_kept, span_msk, span_sel_idx + + if span_rep_kept.numel() == 0: + return None, None, None + + decoder_text_embeds = decoder_text_embeds.to(dtype=span_rep_kept.dtype) + + # Build span representations with context tokens + S_kept = span_rep_kept.shape[1] + dec_D = span_rep_kept.shape[-1] + + # Get actual span boundaries from selected indices + batch_indices = torch.arange(B, device=span_idx.device).unsqueeze(1).expand(B, S_kept) + selected_spans = span_idx[batch_indices, span_sel_idx] # (B, S_kept, 2) + + span_start = selected_spans[:, :, 0] # (B, S_kept) + span_end = selected_spans[:, :, 1] # (B, S_kept) + + # Determine which decoder tokens belong to each span + token_in_span = (decoder_words_mask.unsqueeze(1) >= span_start.unsqueeze(-1)) & ( + decoder_words_mask.unsqueeze(1) <= span_end.unsqueeze(-1) + ) + + tokens_per_span = token_in_span.sum(-1) + max_tokens = int(tokens_per_span.max()) + + span_rep_new = span_rep_kept.new_zeros(B, S_kept, max_tokens + 1, dec_D) + span_rep_mask = torch.zeros(B, S_kept, max_tokens + 1, dtype=torch.bool, device=decoder_text_embeds.device) + + left_offset = (max_tokens + 1 - tokens_per_span).clamp(min=0) + pos_in_span = (token_in_span.cumsum(-1) - 1).masked_fill(~token_in_span, 0) + pos_in_span = pos_in_span + left_offset.unsqueeze(-1) + + b_idx, s_idx, tok_idx = torch.where(token_in_span) + span_rep_new[b_idx, s_idx, pos_in_span[b_idx, s_idx, tok_idx]] = decoder_text_embeds[b_idx, tok_idx] + span_rep_mask[b_idx, s_idx, pos_in_span[b_idx, s_idx, tok_idx]] = True + + kept_pos = (left_offset - 1).clamp(min=0) + + b_flat = torch.arange(B, device=decoder_text_embeds.device).view(-1, 1).expand(B, S_kept).reshape(-1) + s_flat = torch.arange(S_kept, device=decoder_text_embeds.device).view(1, -1).expand(B, S_kept).reshape(-1) + t_flat = kept_pos.reshape(-1) + + span_rep_new[b_flat, s_flat, t_flat] = span_rep_kept.reshape(-1, dec_D) + span_rep_mask[b_flat, s_flat, t_flat] = True + span_rep_mask = span_rep_mask & span_msk.bool() + + return span_rep_new, span_rep_mask, span_sel_idx + + def forward( + self, + input_ids: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.FloatTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + decoder_labels_ids: Optional[torch.FloatTensor] = None, + decoder_labels_mask: Optional[torch.LongTensor] = None, + decoder_words_mask: Optional[torch.LongTensor] = None, + words_embedding: Optional[torch.FloatTensor] = None, + mask: Optional[torch.LongTensor] = None, + span_idx: Optional[torch.Tensor] = None, + span_mask: Optional[torch.Tensor] = None, + span_labels: Optional[torch.Tensor] = None, + prompts_embedding: Optional[torch.FloatTensor] = None, + prompts_embedding_mask: Optional[torch.LongTensor] = None, + words_mask: Optional[torch.LongTensor] = None, + text_lengths: Optional[torch.Tensor] = None, + labels: Optional[torch.FloatTensor] = None, + decoder_labels: Optional[torch.FloatTensor] = None, + threshold: Optional[float] = 0.5, + **kwargs: Any, + ) -> GLiNERDecoderOutput: + """Forward pass through the token-level encoder-decoder model. + + Args: + input_ids: Input token IDs of shape (B, L). + attention_mask: Attention mask of shape (B, L). + decoder_input_ids: Decoder input IDs for span mode. + decoder_attention_mask: Decoder attention mask. + decoder_labels_ids: Label token IDs for decoding of shape (M, L). + decoder_labels_mask: Mask for decoder labels of shape (M, L). + decoder_words_mask: Word position mask for span mode. + words_embedding: Pre-computed word embeddings. + mask: Mask for words. + span_idx: Pre-computed span indices of shape (B, S, 2). + span_mask: Pre-computed span mask of shape (B, S). + span_labels: Ground truth span labels of shape (B, S, C). + prompts_embedding: Pre-computed entity type embeddings. + prompts_embedding_mask: Mask for prompts. + words_mask: Word boundary mask. + text_lengths: Length of each text sequence. + labels: Ground truth token-level labels of shape (B, W, C, 3). + decoder_labels: Ground truth decoder labels of shape (M, L). + threshold: Confidence threshold for span selection. + **kwargs: Additional arguments. + + Returns: + GLiNERDecoderOutput containing logits, losses, and decoder information. + """ + encoder_kwargs = {key: kwargs[key] for key in ("packing_config", "pair_attention_mask") if key in kwargs} + + prompts_embedding, prompts_embedding_mask, words_embedding, mask = self.get_representations( + input_ids, attention_mask, text_lengths, words_mask, **encoder_kwargs + ) + + if labels is not None: + target_W = labels.shape[1] + words_embedding, mask = self._fit_length(words_embedding, mask, target_W) + + target_C = prompts_embedding.size(1) + target_C = max(target_C, labels.size(-2)) + + prompts_embedding, prompts_embedding_mask = self._fit_length( + prompts_embedding, prompts_embedding_mask, target_C + ) + + # Token-level classification: (B, W, C, 3) + scores = self.scorer(words_embedding, prompts_embedding) + + # Get span representations and logits if represent_spans is enabled + span_rep, span_idx, span_mask = self.get_span_representations( + scores, span_idx, span_mask, words_embedding, labels, threshold + ) + span_logits = torch.einsum("BND,BCD->BNC", span_rep, prompts_embedding) + + # Decoder processing + decoder_embedding = decoder_mask = decoder_loss = decoder_span_idx = None + if hasattr(self, "decoder"): + if self.config.decoder_mode == "span": + decoder_text_embeds = self.decoder.ids_to_embeds(decoder_input_ids) + else: + decoder_text_embeds = None + + decoder_embedding, decoder_mask, decoder_span_idx = self.select_token_decoder_embedding( + prompts_embedding, + prompts_embedding_mask, + span_logits=span_logits, + span_rep=span_rep, + span_idx=span_idx, + span_mask=span_mask, + span_labels=span_labels, + decoder_text_embeds=decoder_text_embeds, + decoder_words_mask=decoder_words_mask, + ) + + if decoder_labels is not None: + decoder_loss, _ = self.decode_labels( + decoder_embedding, decoder_mask, decoder_labels_ids, decoder_labels_mask, decoder_labels + ) + + # Compute loss + loss = None + if labels is not None: + loss = self.loss( + scores, + labels, + prompts_embedding_mask, + mask, + span_logits=span_logits, + span_labels=span_labels, + span_mask=span_mask, + decoder_loss=decoder_loss, + **kwargs, + ) + + output = GLiNERDecoderOutput( + logits=scores, + loss=loss, + decoder_loss=decoder_loss, + prompts_embedding=prompts_embedding, + prompts_embedding_mask=prompts_embedding_mask, + decoder_embedding=decoder_embedding, + decoder_embedding_mask=decoder_mask, + decoder_span_idx=decoder_span_idx, + words_embedding=words_embedding, + mask=mask, + span_idx=span_idx, + span_logits=span_logits, + span_mask=span_mask, + ) + return output + + def loss( + self, + scores: torch.Tensor, + labels: torch.Tensor, + prompts_embedding_mask: torch.Tensor, + word_mask: torch.Tensor, + alpha: float = -1.0, + gamma: float = 0.0, + prob_margin: float = 0.0, + label_smoothing: float = 0.0, + reduction: str = "sum", + negatives: float = 1.0, + span_logits: Optional[torch.Tensor] = None, + span_labels: Optional[torch.Tensor] = None, + span_mask: Optional[torch.Tensor] = None, + decoder_loss: Optional[torch.Tensor] = None, + **kwargs: Any, + ) -> torch.Tensor: + """Compute combined loss for token classification, spans, and decoder. + + Args: + scores: Predicted token scores of shape (B, W, C, 3). + labels: Ground truth token labels of shape (B, W, C, 3). + prompts_embedding_mask: Mask for valid entity types of shape (B, C). + word_mask: Mask for valid words of shape (B, W). + alpha: Focal loss alpha parameter. + gamma: Focal loss gamma parameter. + prob_margin: Margin for probability adjustment. + label_smoothing: Label smoothing factor. + reduction: Loss reduction method ('sum' or 'mean'). + negatives: Negative sampling probability. + span_logits: Optional span logits of shape (B, S, C). + span_labels: Optional span labels of shape (B, S, C). + span_mask: Optional span mask of shape (B, S). + decoder_loss: Optional decoder loss to combine. + **kwargs: Additional arguments. + + Returns: + Scalar combined loss tensor. + """ + # Token-level loss (use parent's loss function) + token_loss = UniEncoderTokenModel.loss( + self, + scores, + labels, + prompts_embedding_mask, + word_mask, + alpha, + gamma, + prob_margin, + label_smoothing, + reduction, + negatives, + **kwargs, + ) + + # Combine with span loss if available + if span_logits is not None and span_labels is not None and span_mask is not None: + span_loss = UniEncoderTokenModel.loss( + self, + span_logits, + span_labels, + prompts_embedding_mask, + span_mask, + alpha, + gamma, + prob_margin, + label_smoothing, + reduction, + negatives, + **kwargs, + ) + token_loss = self.config.token_loss_coef * token_loss + self.config.span_loss_coef * span_loss + + # Combine with decoder loss if available + if decoder_loss is not None: + total_loss = decoder_loss * self.config.decoder_loss_coef + token_loss * getattr( + self.config, "token_loss_coef", 1.0 + ) + return total_loss + + return token_loss + + class UniEncoderSpanRelexModel(UniEncoderSpanModel): """Span-based NER model with relation extraction capabilities. @@ -1628,16 +2048,17 @@ def select_span_target_embedding( - target_rep: Selected span representations of shape (B, E, D). - target_mask: Mask for selected spans of shape (B, E). """ - B, L, K, D = span_rep.shape + B = span_rep.size(0) + D = span_rep.size(-1) - span_rep_flat = span_rep.view(B, L * K, D) - span_mask_flat = span_mask.view(B, L * K) + span_rep_flat = span_rep.view(B, -1, D) + span_mask_flat = span_mask.view(B, -1) if span_labels is not None: - span_prob_flat = span_labels.max(dim=-1).values.view(B, L * K) + span_prob_flat = span_labels.max(dim=-1).values.view(B, -1) keep = (span_prob_flat == 1).bool() else: - span_prob_flat = torch.sigmoid(span_scores).max(dim=-1).values.view(B, L * K) + span_prob_flat = torch.sigmoid(span_scores).max(dim=-1).values.view(B, -1) keep = (span_prob_flat > threshold) & span_mask_flat.bool() if top_k is not None and top_k > 0: @@ -1688,6 +2109,28 @@ def select_target_embedding( return target_rep, target_mask + def represent_spans( + self, + words_embeddings, + words_mask, + prompts_embeddings, + span_idx: Optional[torch.Tensor] = None, + span_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + threshold: float = 0.5, + ): + span_idx = span_idx * span_mask.unsqueeze(-1).long() + span_rep = self.span_rep_layer(words_embeddings, span_idx) + scores = torch.einsum("BLKD,BCD->BLKC", span_rep, prompts_embeddings) + + if hasattr(self, "relations_rep_layer"): + target_span_rep, target_span_mask = self.select_span_target_embedding( + span_rep, scores, span_mask, labels, threshold + ) + else: + target_span_rep, target_span_mask = None, None + return scores, target_span_rep, target_span_mask + def forward( self, input_ids: Optional[torch.FloatTensor] = None, @@ -1743,16 +2186,20 @@ def forward( if hasattr(self, "rnn"): words_embedding = self.rnn(words_embedding, mask) - target_W = span_idx.size(1) // self.config.max_width - words_embedding, mask = self._fit_length(words_embedding, mask, target_W) - - span_idx = span_idx * span_mask.unsqueeze(-1) - - span_rep = self.span_rep_layer(words_embedding, span_idx) + if self.config.span_mode == "token_level": + if labels is not None: + target_W = labels.shape[1] + target_C = max(prompts_embedding.size(1), labels.size(-2)) + else: + target_W = words_embedding.size(1) + target_C = prompts_embedding.size(1) + else: + target_W = span_idx.size(1) // self.config.max_width + target_C = prompts_embedding.size(1) + if labels is not None: + target_C = max(target_C, labels.size(-1)) - target_C = prompts_embedding.size(1) - if labels is not None: - target_C = max(target_C, labels.size(-1)) + words_embedding, mask = self._fit_length(words_embedding, mask, target_W) prompts_embedding, prompts_embedding_mask = self._fit_length( prompts_embedding, prompts_embedding_mask, target_C @@ -1760,16 +2207,16 @@ def forward( prompts_embedding = self.prompt_rep_layer(prompts_embedding) batch_size, _, embed_dim = prompts_embedding.shape - scores = torch.einsum("BLKD,BCD->BLKC", span_rep, prompts_embedding) + + scores, target_span_rep, target_span_mask = self.represent_spans( + words_embedding, mask, prompts_embedding, span_idx, span_mask, labels, threshold + ) pair_idx, pair_mask, pair_scores = None, None, None rel_prompts_embedding_mask = None pred_adj_matrix = None if hasattr(self, "relations_rep_layer"): - target_span_rep, target_span_mask = self.select_span_target_embedding( - span_rep, scores, span_mask, labels, threshold - ) pred_adj_matrix = self.relations_rep_layer(target_span_rep, target_span_mask) rel_prompts_embedding, rel_prompts_embedding_mask = extract_prompt_features( @@ -1812,7 +2259,7 @@ def forward( loss = None if labels is not None: - loss = self.loss(scores, labels, prompts_embedding_mask, span_mask, **kwargs) + loss = self.loss(scores, labels, prompts_embedding_mask, span_mask=span_mask, word_mask=mask, **kwargs) if adj_matrix is not None and rel_matrix is not None and hasattr(self, "relations_rep_layer"): adj_mask = target_span_mask.float().unsqueeze(1) * target_span_mask.float().unsqueeze(2) @@ -1965,3 +2412,101 @@ def rel_loss( loss = masked_loss.sum() return loss + + +class UniEncoderTokenRelexModel(UniEncoderSpanRelexModel): + """Token-level NER model with relation extraction capabilities. + + This model extends token-based NER to also extract relations between + identified entities, predicting both entity types and relation types + in a joint model. + + Attributes: + relations_rep_layer (Optional[RelationsRepLayer]): Layer for computing + pairwise entity relations (adjacency matrix). + triples_score_layer (Optional[TriplesScoreLayer]): Layer for scoring + (head, relation, tail) triples. + pair_rep_layer (Optional[nn.Module]): Alternative layer for relation + scoring via concatenation. + """ + + def __init__( + self, config: Any, from_pretrained: bool = False, cache_dir: Optional[Union[str, Path]] = None + ) -> None: + """Initialize the span-based relation extraction model. + + Args: + config: Model configuration object. + from_pretrained: Whether to load from pretrained weights. + cache_dir: Directory for caching pretrained models. + """ + super().__init__(config, from_pretrained, cache_dir) + self.scorer = Scorer(config.hidden_size, config.dropout) + + def loss( + self, + scores: torch.Tensor, + labels: torch.Tensor, + prompts_embedding_mask: torch.Tensor, + word_mask: torch.Tensor, + alpha: float = -1.0, + gamma: float = 0.0, + prob_margin: float = 0.0, + label_smoothing: float = 0.0, + reduction: str = "sum", + negatives: float = 1.0, + **kwargs: Any, + ) -> torch.Tensor: + """Compute token classification loss. + + Args: + scores: Predicted scores of shape (B, W, C). + labels: Ground truth labels of shape (B, W, C). + prompts_embedding_mask: Mask for valid entity types of shape (B, C). + word_mask: Mask for valid tokens of shape (B, W). + alpha: Focal loss alpha parameter. + gamma: Focal loss gamma parameter. + prob_margin: Margin for probability adjustment. + label_smoothing: Label smoothing factor. + reduction: Loss reduction method ('sum' or 'mean'). + negatives: Negative sampling probability. + **kwargs: Additional arguments. + + Returns: + Scalar loss tensor. + """ + all_losses = self._loss(scores, labels, alpha, gamma, prob_margin, label_smoothing, negatives) + + all_losses = all_losses * (word_mask.unsqueeze(-1) * prompts_embedding_mask.unsqueeze(1)).unsqueeze(-1) + + if reduction == "mean": + loss = all_losses.mean() + elif reduction == "sum": + loss = all_losses.sum() + else: + warnings.warn( + f"Invalid Value for config 'loss_reduction': '{reduction}' \n Supported reduction modes:" + f" 'none', 'mean', 'sum'. It will be used 'sum' instead.", + stacklevel=2, + ) + loss = all_losses.sum() + return loss + + def represent_spans( + self, + words_embeddings, + words_mask, + prompts_embeddings, + span_idx: Optional[torch.Tensor] = None, + span_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + threshold: float = 0.5, + ): + scores = self.scorer(words_embeddings, prompts_embeddings) + + if span_idx is None: + span_idx, span_mask = extract_spans_from_tokens(scores, labels, threshold) + span_idx = span_idx * span_mask.unsqueeze(-1).long() + target_span_rep = self.span_rep_layer(words_embeddings, span_idx) + + return scores, target_span_rep, span_mask diff --git a/gliner/modeling/outputs.py b/gliner/modeling/outputs.py index cd275e2..a48614d 100644 --- a/gliner/modeling/outputs.py +++ b/gliner/modeling/outputs.py @@ -35,6 +35,9 @@ class GLiNERBaseOutput(ModelOutput): prompts_embedding_mask: Optional[torch.LongTensor] = None words_embedding: Optional[torch.FloatTensor] = None mask: Optional[torch.LongTensor] = None + span_idx: Optional[torch.LongTensor] = None + span_mask: Optional[torch.Tensor] = None + span_logits: Optional[torch.FloatTensor] = None @dataclass diff --git a/gliner/modeling/span_rep.py b/gliner/modeling/span_rep.py index 677447e..091c36d 100644 --- a/gliner/modeling/span_rep.py +++ b/gliner/modeling/span_rep.py @@ -634,6 +634,55 @@ def forward(self, x, *args): return out +class TokenMarker(nn.Module): + """Marks and projects span endpoints using an MLP. + + A cleaner version of SpanMarker using the create_projection_layer utility. + + Attributes: + max_width (int): Maximum span width to represent. + project_start (nn.Module): MLP for projecting start positions. + project_end (nn.Module): MLP for projecting end positions. + out_project (nn.Module): Final projection layer. + """ + + def __init__(self, hidden_size: int, dropout: float = 0.4): + """Initialize the SpanMarkerV0 layer. + + Args: + hidden_size (int): Dimension of the hidden representations. + max_width (int): Maximum span width to represent. + dropout (float, optional): Dropout rate. Defaults to 0.4. + """ + super().__init__() + self.project_start = create_projection_layer(hidden_size, dropout) + self.project_end = create_projection_layer(hidden_size, dropout) + + self.out_project = create_projection_layer(hidden_size * 2, dropout, hidden_size) + + def forward(self, h: torch.Tensor, span_idx: torch.Tensor) -> torch.Tensor: + """Compute span representations using start and end markers. + + Args: + h (torch.Tensor): Token representations of shape [B, L, D]. + span_idx (torch.Tensor): Span indices of shape [B, *, 2]. + + Returns: + torch.Tensor: Span representations of shape [B, L, max_width, D]. + """ + B, L, D = h.size() + num_spans = span_idx.size(1) + start_rep = self.project_start(h) + end_rep = self.project_end(h) + + start_span_rep = extract_elements(start_rep, span_idx[:, :, 0]) + end_span_rep = extract_elements(end_rep, span_idx[:, :, 1]) + + cat = torch.cat([start_span_rep, end_span_rep], dim=-1).relu() + + return self.out_project(cat) + + class SpanRepLayer(nn.Module): """Factory class for various span representation approaches. @@ -691,6 +740,8 @@ def __init__(self, hidden_size, max_width, span_mode, **kwargs): self.span_rep_layer = SpanConv(hidden_size, max_width, span_mode="conv_sum") elif span_mode == "conv_share": self.span_rep_layer = ConvShare(hidden_size, max_width) + elif span_mode == "token_level": + self.span_rep_layer = TokenMarker(hidden_size, **kwargs) else: raise ValueError(f"Unknown span mode {span_mode}") diff --git a/gliner/modeling/utils.py b/gliner/modeling/utils.py index 6ff0586..73b66f2 100644 --- a/gliner/modeling/utils.py +++ b/gliner/modeling/utils.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Tuple, Optional import torch @@ -294,3 +294,86 @@ def build_entity_pairs( tail_rep = span_rep[batch_idx, pair_idx[..., 1].clamp_min(0)] # (B, N, D) return pair_idx, pair_mask, head_rep, tail_rep + + +def extract_spans_from_tokens( + scores: torch.Tensor, + labels: Optional[torch.Tensor] = None, + threshold: float = 0.5, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Extract entity spans from BIO-style token predictions. + + Args: + scores: (B, W, C, 3) - logits for [start, end, inside] + labels: Optional (B, W, C, 3) - ground truth labels + threshold: Confidence threshold (used when labels is None) + + Returns: + span_idx: (B, N, 2) - [start, end] indices, padded + span_mask: (B, N) - validity mask + """ + B, W, C, _ = scores.shape + device = scores.device + + if labels is not None: + start_mask = labels[..., 0] > 0.5 + end_mask = labels[..., 1] > 0.5 + inside_mask = labels[..., 2] > 0.5 + else: + probs = torch.sigmoid(scores) + start_mask = probs[..., 0] > threshold + end_mask = probs[..., 1] > threshold + inside_mask = probs[..., 2] > threshold + + # Prepend zeros for cumsum indexing + inside_cumsum = torch.nn.functional.pad(inside_mask.long().cumsum(dim=1), (0, 0, 1, 0)) # (B, W+1, C) + + spans_per_sample = [] + + for b in range(B): + starts = start_mask[b].nonzero(as_tuple=False) + ends = end_mask[b].nonzero(as_tuple=False) + + if starts.size(0) == 0 or ends.size(0) == 0: + spans_per_sample.append(torch.empty(0, 2, dtype=torch.long, device=device)) + continue + + s_pos, s_cls = starts.T + e_pos, e_cls = ends.T + + # Find valid (start, end) pairs: same class & end >= start + valid = (s_cls[:, None] == e_cls) & (s_pos[:, None] <= e_pos) + si, ei = valid.nonzero(as_tuple=True) + + if si.size(0) == 0: + spans_per_sample.append(torch.empty(0, 2, dtype=torch.long, device=device)) + continue + + cs, ce, cc = s_pos[si], e_pos[ei], s_cls[si] + + # Validate: all inside positions must be marked + inside_cnt = inside_cumsum[b, ce + 1, cc] - inside_cumsum[b, cs, cc] + valid = inside_cnt == (ce - cs + 1) + + cs, ce = cs[valid], ce[valid] + + if cs.size(0) == 0: + spans_per_sample.append(torch.empty(0, 2, dtype=torch.long, device=device)) + else: + spans_per_sample.append(torch.stack([cs, ce], dim=1)) + + # Pad to uniform size + max_spans = max(s.size(0) for s in spans_per_sample) if spans_per_sample else 0 + max_spans = max(max_spans, 1) # Ensure at least 1 to avoid empty tensor issues + + span_idx = torch.zeros(B, max_spans, 2, dtype=torch.long, device=device) + span_mask = torch.zeros(B, max_spans, dtype=torch.bool, device=device) + + for b, spans in enumerate(spans_per_sample): + n = spans.size(0) + if n > 0: + span_idx[b, :n] = spans + span_mask[b, :n] = True + + return span_idx, span_mask diff --git a/gliner/onnx/model.py b/gliner/onnx/model.py index eb4e933..1fc7b19 100644 --- a/gliner/onnx/model.py +++ b/gliner/onnx/model.py @@ -369,3 +369,52 @@ def forward( rel_mask=inference_output["rel_mask"], ) return outputs + + +class UniEncoderTokenRelexORTModel(BaseORTModel): + """ONNX Runtime model for uni-encoder token-level relation extraction. + + Uses a single encoder to process text and perform both entity recognition + and relation extraction at the token level. + """ + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + words_mask: torch.Tensor, + text_lengths: torch.Tensor, + **kwargs, + ) -> Dict[str, Any]: + """Forward pass for span relation extraction model using ONNX inference. + + Args: + input_ids: Tensor of shape (batch_size, seq_len) containing input token IDs. + attention_mask: Tensor of shape (batch_size, seq_len) with 1s for real + tokens and 0s for padding. + words_mask: Tensor of shape (batch_size, seq_len) indicating word boundaries. + text_lengths: Tensor of shape (batch_size,) containing the actual length + of each text sequence. + span_idx: Tensor containing indices of spans to classify. + span_mask: Tensor indicating which spans are valid (not padding). + **kwargs: Additional arguments (ignored). + + Returns: + GLiNERRelexOutput containing logits for span classification, relation + indices, relation logits, and relation mask. + """ + inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "words_mask": words_mask, + "text_lengths": text_lengths, + } + prepared_inputs = self.prepare_inputs(inputs) + inference_output = self.run_inference(prepared_inputs) + outputs = GLiNERRelexOutput( + logits=inference_output["logits"], + rel_idx=inference_output["rel_idx"], + rel_logits=inference_output["rel_logits"], + rel_mask=inference_output["rel_mask"], + ) + return outputs diff --git a/gliner/training/trainer.py b/gliner/training/trainer.py index 1b17e3e..d516bea 100644 --- a/gliner/training/trainer.py +++ b/gliner/training/trainer.py @@ -141,7 +141,7 @@ def training_step(self, model, inputs, *args, **kwargs) -> torch.Tensor: if self.use_apex: from apex import amp - + with amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward() else: diff --git a/train.py b/train.py index 0e98d18..7e449aa 100644 --- a/train.py +++ b/train.py @@ -48,7 +48,7 @@ def main(cfg_path: str): # Build model model = build_model(model_cfg, train_cfg) print(f"Model type: {model.__class__.__name__}") - + # Get freeze components freeze_components = train_cfg.get("freeze_components", None) if freeze_components: @@ -56,7 +56,7 @@ def main(cfg_path: str): # Train print("\nStarting training...") - trainer = model.train_model( + model.train_model( train_dataset=train_dataset, eval_dataset=eval_dataset, output_dir="models", @@ -87,7 +87,6 @@ def main(cfg_path: str): freeze_components=freeze_components, ) - trainer.save_model() print(f"\n✓ Training complete! Model saved to {output_dir}")