diff --git a/gliner/config.py b/gliner/config.py index 202c8e0..31f55f1 100644 --- a/gliner/config.py +++ b/gliner/config.py @@ -197,6 +197,10 @@ def __init__( rel_token: str = "<>", adjacency_loss_coef=1.0, relation_loss_coef=1.0, + augment_data_prob=0.5, + augment_ent_drop_prob=(0.0, 1.0), + augment_rel_drop_prob=(0.0, 0.3), + augment_add_other_prob=0.5, **kwargs, ): """Initialize UniEncoderRelexConfig. @@ -204,6 +208,8 @@ def __init__( Args: relations_layer (str, optional): Name of relations layer, see gliner.modeling.multitask.relations_layers.py. Defaults to None. + Use "none" to enable single-step relation extraction that scores all + entity pair combinations directly without adjacency filtering. triples_layer (str, optional): Name of triples layer, see gliner.modeling.multitask.triples_layers.py. Defaults to None. embed_rel_token (bool, optional): Whether to embed relation tokens. Defaults to True. @@ -211,6 +217,12 @@ def __init__( rel_token (str, optional): Relation marker token. Defaults to "<>". adjacency_loss_coef (float, optional): Adjacency modeling loss coefficient. Defaults to 1.0. relation_loss_coef (float, optional): Relation representaton loss coefficient. Defaults to 1.0. + augment_data_prob (float, optional): Probability of applying data augmentation + to an example. Defaults to 0.5. + augment_ent_drop_prob (tuple, optional): Range (min, max) from which to sample + the per-type entity drop probability. Defaults to (0.0, 1.0). + augment_rel_drop_prob (tuple, optional): Range (min, max) from which to sample + the per-type relation drop probability. Defaults to (0.0, 0.3). **kwargs: Additional keyword arguments passed to UniEncoderConfig. 
Raises: @@ -225,6 +237,10 @@ def __init__( self.rel_token = rel_token self.adjacency_loss_coef = adjacency_loss_coef self.relation_loss_coef = relation_loss_coef + self.augment_data_prob = augment_data_prob + self.augment_ent_drop_prob = tuple(augment_ent_drop_prob) + self.augment_rel_drop_prob = tuple(augment_rel_drop_prob) + self.augment_add_other_prob = augment_add_other_prob class UniEncoderSpanRelexConfig(UniEncoderRelexConfig): diff --git a/gliner/data_processing/processor.py b/gliner/data_processing/processor.py index 830cdd1..e0d604d 100644 --- a/gliner/data_processing/processor.py +++ b/gliner/data_processing/processor.py @@ -1417,6 +1417,95 @@ def __init__(self, config, tokenizer, words_splitter): super().__init__(config, tokenizer, words_splitter) self.rel_token = config.rel_token + def augment_example(self, example, ner_negatives=None, other_keyword="other"): + """Apply data augmentation by randomly dropping entity/relation types. + + For each example (triggered by augment_data_prob): + - Sample a per-type drop probability from augment_ent_drop_prob range + and drop each entity type with that probability. + - Sample a per-type drop probability from augment_rel_drop_prob range + and drop each relation type with that probability. + - Entities of a dropped type are kept (with their type replaced by "other") + only when the augment_add_other_prob draw succeeds AND they participate + in at least one non-dropped relation; every other dropped-type entity + is removed entirely. + - Relations with dropped types are removed; relation indices are remapped + after entity removal. + + Args: + example: Dictionary with 'ner' and 'relations' keys. + ner_negatives: Pool of negative entity types (unused, kept for API compat). + other_keyword: Replacement type label for dropped entity types. + + Returns: + A (possibly modified) copy of the example with '_dropped_ent_types' + and '_dropped_rel_types' metadata keys. 
+ """ + ner = example.get("ner", []) + relations = example.get("relations", []) + + if not ner: + return example + + ent_drop_prob = random.uniform(*self.config.augment_ent_drop_prob) + rel_drop_prob = random.uniform(*self.config.augment_rel_drop_prob) + add_other = random.random() < self.config.augment_add_other_prob + + all_ent_types = set(e[-1] for e in ner) + all_rel_types = set(r[-1] for r in relations) if relations else set() + + # "other" is exempt from dropping since it's our replacement label + dropped_ent_types = {t for t in all_ent_types if t != other_keyword and random.random() < ent_drop_prob} + dropped_rel_types = {t for t in all_rel_types if random.random() < rel_drop_prob} + + if not dropped_ent_types and not dropped_rel_types: + return example + + # Determine which entities participate in non-dropped relations + entity_has_active_rel = set() + if relations: + for head_idx, tail_idx, rel_type in relations: + if rel_type not in dropped_rel_types: + entity_has_active_rel.add(head_idx) + entity_has_active_rel.add(tail_idx) + + # Process entities: replace with "other", keep, or drop + new_ner = [] + old_to_new_idx = {} + + for i, ent in enumerate(ner): + ent_type = ent[-1] + + if ent_type in dropped_ent_types and i not in entity_has_active_rel: + # Entity type dropped and no active relations → drop entity + continue + if ent_type in dropped_ent_types and not add_other: + # Entity type dropped and "other" not enabled → drop entity + continue + + old_to_new_idx[i] = len(new_ner) + if ent_type in dropped_ent_types and add_other: + # Replace dropped type with "other" + new_ner.append(list(ent[:-1]) + [other_keyword]) + else: + new_ner.append(ent) + + # Update relations: drop relations with dropped types, remap entity indices + new_relations = [] + for head_idx, tail_idx, rel_type in relations: + if rel_type in dropped_rel_types: + continue + if head_idx in old_to_new_idx and tail_idx in old_to_new_idx: + new_relations.append([old_to_new_idx[head_idx], 
old_to_new_idx[tail_idx], rel_type]) + + result = dict(example) + result["ner"] = new_ner + result["relations"] = new_relations + result["_dropped_ent_types"] = dropped_ent_types + result["_dropped_rel_types"] = dropped_rel_types + + return result + def batch_generate_class_mappings( self, batch_list: List[Dict], @@ -1456,6 +1545,10 @@ def batch_generate_class_mappings( max_neg_type_ratio = int(self.config.max_neg_type_ratio) neg_type_ratio = random.randint(0, max_neg_type_ratio) if max_neg_type_ratio else 0 + # Augmentation metadata (set by augment_example) + dropped_ent_types = b.get("_dropped_ent_types", set()) + dropped_rel_types = b.get("_dropped_rel_types", set()) + # Process NER types if "ner_negatives" in b: negs_i = b["ner_negatives"] @@ -1465,7 +1558,9 @@ def batch_generate_class_mappings( if "ner_labels" in b: types = b["ner_labels"] else: - types = list(set([el[-1] for el in b["ner"]] + negs_i)) + # Exclude dropped entity types ("other" replacements are already in ner) + ent_types = [el[-1] for el in b["ner"] if el[-1] not in dropped_ent_types] + types = list(set(ent_types + negs_i)) random.shuffle(types) types = types[: int(self.config.max_types)] @@ -1483,7 +1578,9 @@ def batch_generate_class_mappings( if "rel_labels" in b: rel_types = b["rel_labels"] else: - rel_types = list(set([el[-1] for el in b.get("relations", [])] + rel_negs_i)) + # Exclude dropped relation types + active_rel_types = [el[-1] for el in b.get("relations", []) if el[-1] not in dropped_rel_types] + rel_types = list(set(active_rel_types + rel_negs_i)) random.shuffle(rel_types) rel_types = rel_types[: int(self.config.max_types)] @@ -1525,6 +1622,15 @@ def collate_raw_batch( Dictionary containing collated batch data for joint entity and relation extraction. 
""" + # Apply data augmentation if enabled (only during dynamic mapping generation) + augment_prob = getattr(self.config, 'augment_data_prob', 0.0) + if augment_prob > 0.0 and class_to_ids is None and entity_types is None: + if ner_negatives is None: + ner_negatives = get_negatives(batch_list, sampled_neg=100, key="ner") + batch_list = [ + self.augment_example(b, ner_negatives) if random.random() < augment_prob else b + for b in batch_list + ] if class_to_ids is None and entity_types is None: # Dynamically infer per-example mappings class_to_ids, id_to_classes, rel_class_to_ids, rel_id_to_classes = self.batch_generate_class_mappings( @@ -1616,11 +1722,16 @@ def preprocess_example(self, tokens, ner, classes_to_id, relations, rel_classes_ span_to_idx = {(s, e): i for i, (s, e) in enumerate(spans_idx.tolist())} # Create entity index mapping (from original entity list to span indices) + # and compact indices (0, 1, 2, ...) matching target_span_rep ordering entity_to_span_idx = {} + entity_to_compact_idx = {} + compact_idx = 0 if ner is not None: for ent_idx, (start, end, _) in enumerate(ner): # (start, end, label) if (start, end) in span_to_idx and end < num_tokens: entity_to_span_idx[ent_idx] = span_to_idx[(start, end)] + entity_to_compact_idx[ent_idx] = compact_idx + compact_idx += 1 # Process relations rel_idx_list = [] @@ -1630,9 +1741,9 @@ def preprocess_example(self, tokens, ner, classes_to_id, relations, rel_classes_ for rel in relations: head_idx, tail_idx, rel_type = rel - # Check if both entities are valid and map to span indices - if head_idx in entity_to_span_idx and tail_idx in entity_to_span_idx and rel_type in rel_classes_to_id: - rel_idx_list.append([head_idx, tail_idx]) + # Use compact indices so rel_idx aligns with target_span_rep positions + if head_idx in entity_to_compact_idx and tail_idx in entity_to_compact_idx and rel_type in rel_classes_to_id: + rel_idx_list.append([entity_to_compact_idx[head_idx], entity_to_compact_idx[tail_idx]]) 
rel_label_list.append(rel_classes_to_id[rel_type]) # Convert to tensors @@ -1696,28 +1807,40 @@ def create_batch_dict(self, batch, class_to_ids, id_to_classes, rel_class_to_ids "rel_id_to_classes": rel_id_to_classes, } - def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_negatives=True, negative_ratio=2.0): + def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_negatives=True, negative_ratio=(1.0, 10.0)): """Create relation labels with negative pair sampling. Overrides the span-based version to work with token-level entity representations. Uses entities_id count instead of span_label for entity counting. + When relations_layer is "none", generates labels for ALL entity pair + combinations (no adjacency matrix), matching the order produced by + build_all_entity_pairs. Otherwise, uses adjacency-based pair sampling. + Args: batch: Batch dictionary containing entities and relations. add_reversed_negatives: If True, add reversed direction pairs as negatives. add_random_negatives: If True, add random entity pairs as negatives. - negative_ratio: Ratio of negative to positive pairs. + negative_ratio: Ratio of negative to positive pairs. Can be a float for a fixed ratio + or a (min, max) tuple to sample a random ratio per example. Returns: Tuple containing: - - adj_matrix: Adjacency matrix (shape: [B, max_entities, max_entities]) + - adj_matrix: Adjacency matrix (shape: [B, max_entities, max_entities]). + None when relations_layer is "none". - rel_matrix: Multi-hot relation labels (shape: [B, max_pairs, num_relation_classes]) """ B = len(batch["tokens"]) span_mask = batch["span_mask"] - # Count entities per sample (differs from span-based which uses span_label) - batch_ents = span_mask.long().squeeze(-1).sum(-1) + # For span-based models, span_mask covers all candidate spans (L*max_width), + # but rel_idx uses compact entity indices (0..num_annotated-1), so we must + # count annotated entities instead. 
For token-level models, span_mask is already + # sized by entity count, so the original formula works. + if "span_label" in batch and batch["span_label"] is not None: + batch_ents = (batch["span_label"] > 0).sum(-1) + else: + batch_ents = span_mask.long().squeeze(-1).sum(-1) max_En = max(batch_ents.max().item(), 1) rel_class_to_ids = batch["rel_class_to_ids"] @@ -1726,9 +1849,16 @@ def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_ else: C = len(rel_class_to_ids) if rel_class_to_ids else 0 + single_step = getattr(self.config, "relations_layer", None) == "none" + if C == 0: + if single_step: + return None, torch.zeros(B, 1, 1, dtype=torch.float) return torch.zeros(B, max_En, max_En, dtype=torch.float), torch.zeros(B, 1, 1, dtype=torch.float) + if single_step: + return self._create_single_step_relation_labels(batch, batch_ents, C) + adj_matrix = torch.zeros(B, max_En, max_En, dtype=torch.float) all_pairs_info = [] @@ -1758,7 +1888,11 @@ def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_ # Generate negative pairs negative_pairs = set() num_positives = len(positive_pairs) - target_negatives = int(num_positives * negative_ratio) + if isinstance(negative_ratio, (tuple, list)): + ratio = random.uniform(negative_ratio[0], negative_ratio[1]) + else: + ratio = negative_ratio + target_negatives = max(1, int(num_positives * ratio)) if add_reversed_negatives: for e1, e2 in positive_pairs: @@ -1811,6 +1945,60 @@ def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_ return adj_matrix, rel_matrix + def _create_single_step_relation_labels(self, batch, batch_ents, C): + """Create relation labels for single-step mode (all entity pair combinations). + + Generates labels for ALL directed pairs (i, j) where i != j among entities, + matching the order produced by build_all_entity_pairs. + + Args: + batch: Batch dictionary containing entities and relations. 
+ batch_ents: Tensor of entity counts per example. + C: Number of relation classes. + + Returns: + Tuple of (None, rel_matrix) where rel_matrix has shape + (B, max_pairs, C) with max_pairs = max(N_i * (N_i - 1)). + """ + B = len(batch["tokens"]) + + # Build pair-to-index mapping and collect labels + max_total_pairs = 0 + all_pair_maps = [] + + for i in range(B): + N = batch_ents[i].item() + # All (e1, e2) pairs where e1 != e2, ordered as build_all_entity_pairs produces: + # (0,1), (0,2), ..., (1,0), (1,2), ..., i.e., sorted by (e1, e2) + pair_to_idx = {} + idx = 0 + for e1 in range(N): + for e2 in range(N): + if e1 != e2: + pair_to_idx[(e1, e2)] = idx + idx += 1 + all_pair_maps.append(pair_to_idx) + max_total_pairs = max(max_total_pairs, idx) + + max_total_pairs = max(max_total_pairs, 1) + rel_matrix = torch.zeros(B, max_total_pairs, C, dtype=torch.float) + + for i in range(B): + N = batch_ents[i].item() + rel_idx_i = batch["rel_idx"][i] + rel_label_i = batch["rel_label"][i] + pair_to_idx = all_pair_maps[i] + + for k in range(rel_label_i.shape[0]): + if rel_label_i[k] > 0: + e1 = rel_idx_i[k, 0].item() + e2 = rel_idx_i[k, 1].item() + pair_key = (e1, e2) + if pair_key in pair_to_idx: + rel_matrix[i, pair_to_idx[pair_key], rel_label_i[k].item() - 1] = 1.0 + + return None, rel_matrix + def prepare_inputs( self, texts: Sequence[Sequence[str]], diff --git a/gliner/decoding/decoder.py b/gliner/decoding/decoder.py index ca0ef03..64a815a 100644 --- a/gliner/decoding/decoder.py +++ b/gliner/decoding/decoder.py @@ -962,6 +962,52 @@ def _build_span_tuple( class_probs=class_probs ) + def _build_entity_span_to_decoded_idx( + self, + spans: List[List[tuple]], + entity_spans: Optional[torch.Tensor], + batch_size: int, + ) -> List[Optional[dict]]: + """Build mapping from model entity indices to decoded span indices. + + Maps entity positions in the model's internal target_span_rep + (referenced by pair_idx) to positions in the decoded spans list. 
+ Uses span boundaries (start, end) for matching. + + Args: + spans: Decoded entity spans per sample, each as (start, end, type, score). + entity_spans: Tensor of shape (B, E, 2) with span boundaries for each + entity in target_span_rep. None if not available. + batch_size: Number of samples. + + Returns: + List of dicts mapping model entity index -> decoded span index, + or None entries if entity_spans is not available. + """ + if entity_spans is None: + return [None] * batch_size + + mappings = [] + for i in range(batch_size): + # Build reverse lookup: (start, end) -> index in decoded spans + span_boundary_to_idx = {} + for idx, span_tuple in enumerate(spans[i]): + key = (span_tuple[0], span_tuple[1]) + if key not in span_boundary_to_idx: + span_boundary_to_idx[key] = idx + + # Map each entity in target_span_rep to decoded span index + model_to_decoded = {} + for e in range(entity_spans.size(1)): + start = entity_spans[i, e, 0].item() + end = entity_spans[i, e, 1].item() + key = (start, end) + if key in span_boundary_to_idx: + model_to_decoded[e] = span_boundary_to_idx[key] + + mappings.append(model_to_decoded) + return mappings + def _decode_relations( self, model_output, @@ -972,43 +1018,36 @@ def _decode_relations( rel_id_to_classes: Union[Dict[int, str], List[Dict[int, str]]], threshold: float, batch_size: int, + entity_spans: Optional[torch.Tensor] = None, ) -> List[List[tuple]]: """Decode relations between detected entity spans. - Extracts relation predictions from model outputs and maps them to pairs - of detected entity spans. For each potential relation, checks if both - head and tail entities exist in the decoded spans and if the relation - confidence exceeds the threshold. + Uses entity_spans (model's internal entity boundaries) to correctly map + pair_idx values (which reference target_span_rep positions) to positions + in the decoded spans list. This is necessary because greedy search may + remove or reorder entities during decoding. 
Args: model_output: Model output object containing relation predictions. - Expected to have attributes rel_idx, rel_logits, and optionally rel_mask. spans: List of entity spans for each sample in the batch. Each sample contains a list of tuples: (start, end, entity_type, score). rel_idx: Tensor of shape (batch_size, num_relations, 2) containing indices of head and tail entities for each potential relation. - None if no relations to decode. rel_logits: Tensor of shape (batch_size, num_relations, num_relation_classes) - containing logits for relation classifications. None if no relations. + containing logits for relation classifications. rel_mask: Optional boolean tensor of shape (batch_size, num_relations) indicating which relations are valid (True) vs. padding (False). - If None, all relations are considered valid. rel_id_to_classes: Mapping from relation class IDs to relation names. - Can be either: - - Dict: Single mapping used for all samples - - List[Dict]: Per-sample mappings for different relation schemas Class IDs are 1-indexed (0 reserved for "no relation" or padding). - threshold: Minimum confidence score (after sigmoid) for a relation - to be included in the output. Must be in range [0, 1]. + threshold: Minimum confidence score (after sigmoid) for a relation. batch_size: Number of samples in the batch. + entity_spans: Optional tensor of shape (B, E, 2) with span boundaries + for entities in target_span_rep. Used to map pair_idx to decoded spans. Returns: List of relation lists, one per sample. 
Each relation is a tuple: - (head_idx, relation_label, tail_idx, score) where: - - head_idx: Index into the sample's spans list for the head entity - - relation_label: String name of the relation type - - tail_idx: Index into the sample's spans list for the tail entity - - score: Confidence score for this relation (float, 0-1 range) + (head_idx, relation_label, tail_idx, score) where head_idx and tail_idx + are indices into the corresponding sample's decoded spans list. """ if rel_idx is None or rel_logits is None: return [[] for _ in range(batch_size)] @@ -1017,15 +1056,61 @@ def _decode_relations( if rel_mask is None: rel_mask = torch.ones(rel_idx[..., 0].shape, dtype=torch.bool, device=rel_idx.device) - return _decode_relations_batch( - rel_idx=rel_idx, - rel_logits=rel_logits, - rel_mask=rel_mask, - rel_probs_threshold=threshold, - spans=spans, - rel_id_to_classes=rel_id_to_classes, - batch_size=batch_size, - ) + rel_probs = torch.sigmoid(rel_logits) + + # Build mapping from model entity indices to decoded span indices + idx_mappings = self._build_entity_span_to_decoded_idx(spans, entity_spans, batch_size) + + # Decode relations for each sample + for i in range(batch_size): + rel_id_to_class_i = rel_id_to_classes[i] if isinstance(rel_id_to_classes, list) else rel_id_to_classes + idx_map = idx_mappings[i] + + # Process each potential relation + for j in range(rel_idx.size(1)): + # Skip if masked out + if not rel_mask[i, j]: + continue + + model_head_idx = rel_idx[i, j, 0].item() + model_tail_idx = rel_idx[i, j, 1].item() + + # Skip invalid indices + if model_head_idx < 0 or model_tail_idx < 0: + continue + + # Map model entity indices to decoded span indices + if idx_map is not None: + head_idx = idx_map.get(model_head_idx) + tail_idx = idx_map.get(model_tail_idx) + if head_idx is None or tail_idx is None: + continue + else: + # Fallback: use model indices directly (legacy behavior) + head_idx = model_head_idx + tail_idx = model_tail_idx + if head_idx >= 
len(spans[i]) or tail_idx >= len(spans[i]): + continue + + # Check each relation class + for c, p in enumerate(rel_probs[i, j]): + prob = p.item() + + # Skip low confidence predictions + if prob <= threshold: + continue + + # Skip if class ID not in mapping + # (c + 1 because 0 may be "no-relation" or padding) + if (c + 1) not in rel_id_to_class_i: + continue + + rel_label = rel_id_to_class_i[c + 1] + + # Append relation: (head_idx, relation_label, tail_idx, score) + relations[i].append((head_idx, rel_label, tail_idx, prob)) + + return relations def decode( self, @@ -1040,7 +1125,7 @@ def decode( relation_threshold: float = 0.5, multi_label: bool = False, rel_id_to_classes: Optional[Union[Dict[int, str], List[Dict[int, str]]]] = None, - return_class_probs: bool = False, + entity_spans: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[List[List[tuple]], List[List[tuple]]]: """Decode model output to extract entities and relations. @@ -1051,59 +1136,23 @@ def decode( Args: tokens: Tokenized input text for each sample in the batch. - Each sample is a list of token strings. id_to_classes: Mapping from entity class IDs to entity type names. - Can be either: - - Dict: Single mapping used for all samples (global entity schema) - - List[Dict]: Per-sample mappings for different entity schemas - Class IDs are 1-indexed (0 is reserved for padding). model_output: Model output object containing both entity logits and - optionally relation predictions. Must have a logits attribute for - entity extraction. May have rel_idx, rel_logits, and rel_mask for - relation extraction. - rel_idx: Optional tensor of shape (batch_size, num_relations, 2) containing - head and tail entity indices for each potential relation. - rel_logits: Optional tensor of shape (batch_size, num_relations, num_relation_classes) - containing relation classification logits. - rel_mask: Optional boolean tensor of shape (batch_size, num_relations) - indicating valid relations. 
If None, all relations are considered valid. - flat_ner: If True, applies greedy filtering to ensure non-overlapping - entity spans. If False, allows overlapping entities. Defaults to False. - threshold: Minimum confidence score (0-1) for both entity - predictions to be included in the output. Defaults to 0.5. - relation_threshold: Minimum confidence score (0-1) for both relation - predictions to be included in the output. Defaults to 0.5. - multi_label: If True, allows multiple entity types per span. If False, - only the highest-scoring entity type per span is kept. Defaults to False. + optionally relation predictions. + rel_idx: Optional tensor of shape (batch_size, num_relations, 2). + rel_logits: Optional tensor of shape (batch_size, num_relations, num_relation_classes). + rel_mask: Optional boolean tensor of shape (batch_size, num_relations). + flat_ner: If True, applies greedy filtering for non-overlapping entities. + threshold: Minimum confidence score for entity predictions. + relation_threshold: Minimum confidence score for relation predictions. + multi_label: If True, allows multiple entity types per span. rel_id_to_classes: Optional mapping from relation class IDs to relation names. - If None, relation decoding is skipped and empty relation lists are returned. - Can be either a single Dict or List[Dict] for per-sample mappings. - Class IDs are 1-indexed. - return_class_probs: Whether to include class probabilities in entity output. + entity_spans: Optional tensor of shape (B, E, 2) with span boundaries + for entities in target_span_rep. Used for correct index mapping. **kwargs: Additional keyword arguments passed to the parent class decode method. Returns: - Tuple of (spans, relations) where: - - spans: List of entity span lists, one per sample. Each entity span is - a tuple: (start, end, entity_type, score) or (start, end, entity_type, score, class_probs) - - relations: List of relation lists, one per sample. 
Each relation is - a tuple: (head_idx, relation_label, tail_idx, score) where head_idx - and tail_idx are indices into the corresponding sample's spans list. - - Examples: - >>> decoder = SpanRelexDecoder() - >>> tokens = [["John", "works", "at", "Microsoft"]] - >>> id_to_classes = {1: "PERSON", 2: "ORG"} - >>> rel_id_to_classes = {1: "works_at"} - >>> spans, relations = decoder.decode( - ... tokens=tokens, - ... id_to_classes=id_to_classes, - ... model_output=output, - ... rel_id_to_classes=rel_id_to_classes, - ... threshold=0.5, - ... ) - >>> # spans[0] might be: [(0, 1, "PERSON", 0.9), (3, 4, "ORG", 0.85)] - >>> # relations[0] might be: [(0, "works_at", 1, 0.8)] + Tuple of (spans, relations). """ # Decode entity spans using base class logic spans = super().decode( @@ -1130,6 +1179,7 @@ def decode( rel_id_to_classes=rel_id_to_classes, threshold=relation_threshold, batch_size=len(tokens), + entity_spans=entity_spans, ) return spans, relations @@ -1412,6 +1462,47 @@ class TokenRelexDecoder(TokenDecoder): - Multi-label entity classification """ + def _build_entity_span_to_decoded_idx( + self, + spans: List[List[tuple]], + entity_spans: Optional[torch.Tensor], + batch_size: int, + ) -> List[Optional[dict]]: + """Build mapping from model entity indices to decoded span indices. + + Uses span boundaries (start, end) to match model entities to decoded spans. + + Args: + spans: Decoded entity spans per sample, each as (start, end, type, score). + entity_spans: Tensor of shape (B, E, 2) with span boundaries for each + entity in target_span_rep. None if not available. + batch_size: Number of samples. + + Returns: + List of dicts mapping model entity index -> decoded span index. 
+ """ + if entity_spans is None: + return [None] * batch_size + + mappings = [] + for i in range(batch_size): + span_boundary_to_idx = {} + for idx, span_tuple in enumerate(spans[i]): + key = (span_tuple[0], span_tuple[1]) + if key not in span_boundary_to_idx: + span_boundary_to_idx[key] = idx + + model_to_decoded = {} + for e in range(entity_spans.size(1)): + start = entity_spans[i, e, 0].item() + end = entity_spans[i, e, 1].item() + key = (start, end) + if key in span_boundary_to_idx: + model_to_decoded[e] = span_boundary_to_idx[key] + + mappings.append(model_to_decoded) + return mappings + def _decode_relations( self, spans: List[List[tuple]], @@ -1421,41 +1512,26 @@ def _decode_relations( rel_id_to_classes: Union[Dict[int, str], List[Dict[int, str]]], threshold: float, batch_size: int, + entity_spans: Optional[torch.Tensor] = None, ) -> List[List[tuple]]: """Decode relations between detected entity spans. - Extracts relation predictions from model outputs and maps them to pairs - of detected entity spans. For each potential relation, checks if both - head and tail entities exist in the decoded spans and if the relation - confidence exceeds the threshold. + Uses entity_spans (model's internal entity boundaries) to correctly map + pair_idx values to positions in the decoded spans list. Args: spans: List of entity spans for each sample in the batch. - Each sample contains a list of tuples: (start, end, entity_type, score). - rel_idx: Tensor of shape (batch_size, num_relations, 2) containing - indices of head and tail entities for each potential relation. - None if no relations to decode. - rel_logits: Tensor of shape (batch_size, num_relations, num_relation_classes) - containing logits for relation classifications. None if no relations. - rel_mask: Optional boolean tensor of shape (batch_size, num_relations) - indicating which relations are valid (True) vs. padding (False). - If None, all relations are considered valid. 
+ rel_idx: Tensor of shape (batch_size, num_relations, 2). + rel_logits: Tensor of shape (batch_size, num_relations, num_relation_classes). + rel_mask: Optional boolean tensor of shape (batch_size, num_relations). rel_id_to_classes: Mapping from relation class IDs to relation names. - Can be either: - - Dict: Single mapping used for all samples - - List[Dict]: Per-sample mappings for different relation schemas - Class IDs are 1-indexed (0 reserved for "no relation" or padding). - threshold: Minimum confidence score (after sigmoid) for a relation - to be included in the output. Must be in range [0, 1]. + threshold: Minimum confidence score (after sigmoid). batch_size: Number of samples in the batch. + entity_spans: Optional tensor of shape (B, E, 2) with span boundaries + for entities in target_span_rep. Returns: - List of relation lists, one per sample. Each relation is a tuple: - (head_idx, relation_label, tail_idx, score) where: - - head_idx: Index into the sample's spans list for the head entity - - relation_label: String name of the relation type - - tail_idx: Index into the sample's spans list for the tail entity - - score: Confidence score for this relation (float, 0-1 range) + List of relation lists, one per sample. 
""" if rel_idx is None or rel_logits is None: return [[] for _ in range(batch_size)] @@ -1464,15 +1540,57 @@ def _decode_relations( if rel_mask is None: rel_mask = torch.ones(rel_idx[..., 0].shape, dtype=torch.bool, device=rel_idx.device) - return _decode_relations_batch( - rel_idx=rel_idx, - rel_logits=rel_logits, - rel_mask=rel_mask, - rel_probs_threshold=threshold, - spans=spans, - rel_id_to_classes=rel_id_to_classes, - batch_size=batch_size, - ) + rel_probs = torch.sigmoid(rel_logits) + + # Build mapping from model entity indices to decoded span indices + idx_mappings = self._build_entity_span_to_decoded_idx(spans, entity_spans, batch_size) + + # Decode relations for each sample + for i in range(batch_size): + rel_id_to_class_i = rel_id_to_classes[i] if isinstance(rel_id_to_classes, list) else rel_id_to_classes + idx_map = idx_mappings[i] + + # Process each potential relation + for j in range(rel_idx.size(1)): + # Skip if masked out + if not rel_mask[i, j]: + continue + + model_head_idx = rel_idx[i, j, 0].item() + model_tail_idx = rel_idx[i, j, 1].item() + + # Skip invalid indices + if model_head_idx < 0 or model_tail_idx < 0: + continue + + # Map model entity indices to decoded span indices + if idx_map is not None: + head_idx = idx_map.get(model_head_idx) + tail_idx = idx_map.get(model_tail_idx) + if head_idx is None or tail_idx is None: + continue + else: + head_idx = model_head_idx + tail_idx = model_tail_idx + if head_idx >= len(spans[i]) or tail_idx >= len(spans[i]): + continue + + # Check each relation class + for c, p in enumerate(rel_probs[i, j]): + prob = p.item() + + # Skip low confidence predictions + if prob <= threshold: + continue + + # Skip if class ID not in mapping (c + 1 because 0 is padding) + if (c + 1) not in rel_id_to_class_i: + continue + + rel_label = rel_id_to_class_i[c + 1] + relations[i].append((head_idx, rel_label, tail_idx, prob)) + + return relations def decode( self, @@ -1487,6 +1605,7 @@ def decode( relation_threshold: float = 
0.5, multi_label: bool = False, rel_id_to_classes: Optional[Union[Dict[int, str], List[Dict[int, str]]]] = None, + entity_spans: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[List[List[tuple]], List[List[tuple]]]: """Decode model output to extract entities and relations. @@ -1571,6 +1690,7 @@ def decode( rel_id_to_classes=rel_id_to_classes, threshold=relation_threshold, batch_size=len(tokens), + entity_spans=entity_spans, ) return spans, relations diff --git a/gliner/model.py b/gliner/model.py index ef09ad0..b6fee65 100644 --- a/gliner/model.py +++ b/gliner/model.py @@ -2528,12 +2528,9 @@ def collate_fn(batch): if not isinstance(rel_mask, torch.Tensor): rel_mask = torch.from_numpy(rel_mask) - # Slice input_spans for this batch - batch_input_spans = None - if word_input_spans is not None: - current_batch_size = len(batch["tokens"]) - batch_input_spans = word_input_spans[batch_offset:batch_offset + current_batch_size] - batch_offset += current_batch_size + entity_spans = getattr(model_output, "entity_spans", None) + if entity_spans is not None and not isinstance(entity_spans, torch.Tensor): + entity_spans = torch.from_numpy(entity_spans) decoded_results = self.decoder.decode( batch["tokens"], @@ -2547,8 +2544,7 @@ def collate_fn(batch): relation_threshold=relation_threshold, multi_label=multi_label, rel_id_to_classes=batch["rel_id_to_classes"], - return_class_probs=return_class_probs, - input_spans=batch_input_spans, + entity_spans=entity_spans, ) if len(decoded_results) == 1: @@ -2779,6 +2775,10 @@ def evaluate( if not isinstance(rel_mask, torch.Tensor): rel_mask = torch.from_numpy(rel_mask) + entity_spans = getattr(model_output, "entity_spans", None) + if entity_spans is not None and not isinstance(entity_spans, torch.Tensor): + entity_spans = torch.from_numpy(entity_spans) + # Decode predictions decoded_results = self.decoder.decode( batch["tokens"], @@ -2792,6 +2792,7 @@ def evaluate( relation_threshold=relation_threshold, multi_label=multi_label, 
rel_id_to_classes=batch["rel_id_to_classes"], + entity_spans=entity_spans, ) # Unpack results diff --git a/gliner/modeling/base.py b/gliner/modeling/base.py index 8f00bf4..0bbbfc2 100644 --- a/gliner/modeling/base.py +++ b/gliner/modeling/base.py @@ -26,6 +26,7 @@ from torch.nn import functional as F from .utils import ( + build_all_entity_pairs, build_entity_pairs, extract_prompt_features, extract_word_embeddings, @@ -2014,9 +2015,10 @@ def __init__( super().__init__(config, from_pretrained, cache_dir) if config.relations_layer is not None: - self.relations_rep_layer = RelationsRepLayer( - in_dim=config.hidden_size, relation_mode=config.relations_layer - ) + if config.relations_layer != "none": + self.relations_rep_layer = RelationsRepLayer( + in_dim=config.hidden_size, relation_mode=config.relations_layer + ) if config.triples_layer is not None: self.triples_score_layer = TriplesScoreLayer(config.triples_layer) @@ -2033,7 +2035,8 @@ def select_span_target_embedding( span_labels: Optional[torch.FloatTensor] = None, threshold: float = 0.5, top_k: Optional[int] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + span_idx: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """Select entity spans for relation extraction. Filters spans based on entity classification scores or ground truth labels, @@ -2046,11 +2049,15 @@ def select_span_target_embedding( span_labels: Optional ground truth labels of shape (B, L, K, C). threshold: Confidence threshold for selecting spans. top_k: Optional limit on number of spans to select. + span_idx: Optional span boundary indices of shape (B, L*K, 2). + If provided, selected boundaries are returned for decoding. Returns: Tuple containing: - target_rep: Selected span representations of shape (B, E, D). - target_mask: Mask for selected spans of shape (B, E). + - target_span_idx: Selected span boundaries of shape (B, E, 2), + or None if span_idx was not provided. 
""" B = span_rep.size(0) D = span_rep.size(-1) @@ -2058,12 +2065,14 @@ def select_span_target_embedding( span_rep_flat = span_rep.view(B, -1, D) span_mask_flat = span_mask.view(B, -1) - if span_labels is not None: + if span_labels is not None and span_labels.size(-1) > 0: span_prob_flat = span_labels.max(dim=-1).values.view(B, -1) keep = (span_prob_flat == 1).bool() - else: + elif span_scores is not None and span_scores.size(-1) > 0: span_prob_flat = torch.sigmoid(span_scores).max(dim=-1).values.view(B, -1) keep = (span_prob_flat > threshold) & span_mask_flat.bool() + else: + keep = torch.zeros_like(span_mask_flat, dtype=torch.bool) if top_k is not None and top_k > 0: sel_scores = span_prob_flat.masked_fill(~keep, -1.0) @@ -2075,7 +2084,16 @@ def select_span_target_embedding( target_rep, target_mask = self.select_target_embedding(representations=span_rep_flat, rep_mask=rep_mask) - return target_rep, target_mask + # Also select span boundaries using the same mask + target_span_idx = None + if span_idx is not None: + span_idx_float = span_idx.float() + target_span_idx_float, _ = self.select_target_embedding( + representations=span_idx_float, rep_mask=rep_mask + ) + target_span_idx = target_span_idx_float.long() + + return target_rep, target_mask, target_span_idx def select_target_embedding( self, representations: Optional[torch.FloatTensor] = None, rep_mask: Optional[torch.LongTensor] = None @@ -2127,13 +2145,14 @@ def represent_spans( span_rep = self.span_rep_layer(words_embeddings, span_idx) scores = torch.einsum("BLKD,BCD->BLKC", span_rep, prompts_embeddings) - if hasattr(self, "relations_rep_layer"): - target_span_rep, target_span_mask = self.select_span_target_embedding( - span_rep, scores, span_mask, labels, threshold + has_relex = hasattr(self, "relations_rep_layer") or hasattr(self, "pair_rep_layer") or hasattr(self, "triples_score_layer") + if has_relex: + target_span_rep, target_span_mask, entity_spans = self.select_span_target_embedding( + span_rep, 
scores, span_mask, labels, threshold, span_idx=span_idx ) else: - target_span_rep, target_span_mask = None, None - return scores, target_span_rep, target_span_mask + target_span_rep, target_span_mask, entity_spans = None, None, None + return scores, target_span_rep, target_span_mask, entity_spans def forward( self, @@ -2212,7 +2231,7 @@ def forward( prompts_embedding = self.prompt_rep_layer(prompts_embedding) batch_size, _, embed_dim = prompts_embedding.shape - scores, target_span_rep, target_span_mask = self.represent_spans( + scores, target_span_rep, target_span_mask, entity_spans = self.represent_spans( words_embedding, mask, prompts_embedding, span_idx, span_mask, labels, threshold ) @@ -2220,8 +2239,11 @@ def forward( rel_prompts_embedding_mask = None pred_adj_matrix = None - if hasattr(self, "relations_rep_layer"): - pred_adj_matrix = self.relations_rep_layer(target_span_rep, target_span_mask) + has_relex = hasattr(self, "relations_rep_layer") or hasattr(self, "pair_rep_layer") or hasattr(self, "triples_score_layer") + + if has_relex: + if hasattr(self, "relations_rep_layer"): + pred_adj_matrix = self.relations_rep_layer(target_span_rep, target_span_mask) rel_prompts_embedding, rel_prompts_embedding_mask = extract_prompt_features( self.config.rel_token_index, @@ -2236,11 +2258,15 @@ def forward( B, _, D = target_span_rep.shape C_rel = rel_prompts_embedding.size(1) - adj_for_selection = adj_matrix if (labels is not None and adj_matrix is not None) else pred_adj_matrix - - pair_idx, pair_mask, head_rep_selected, tail_rep_selected = build_entity_pairs( - adj_for_selection, target_span_rep, threshold=adjacency_threshold - ) + if hasattr(self, "relations_rep_layer"): + adj_for_selection = adj_matrix if (labels is not None and adj_matrix is not None) else pred_adj_matrix + pair_idx, pair_mask, head_rep_selected, tail_rep_selected = build_entity_pairs( + adj_for_selection, target_span_rep, threshold=adjacency_threshold + ) + else: + pair_idx, pair_mask, 
head_rep_selected, tail_rep_selected = build_all_entity_pairs( + target_span_rep, target_span_mask + ) N = head_rep_selected.size(1) @@ -2263,24 +2289,57 @@ def forward( loss = None if labels is not None: - loss = self.loss(scores, labels, prompts_embedding_mask, span_mask=span_mask, word_mask=mask, **kwargs) - - if adj_matrix is not None and rel_matrix is not None and hasattr(self, "relations_rep_layer"): - adj_mask = target_span_mask.float().unsqueeze(1) * target_span_mask.float().unsqueeze(2) - adj_loss = self.adj_loss(pred_adj_matrix, adj_matrix, adj_mask, **kwargs) + num_ner_classes = prompts_embedding_mask.shape[-1] + if num_ner_classes > 0: + loss = self.loss(scores, labels, prompts_embedding_mask, span_mask=span_mask, word_mask=mask, **kwargs) + if has_relex and rel_matrix is not None and C_rel > 0: rel_labels_selected = rel_matrix + # Align rel_labels_selected to N (pairs from build_entity_pairs / build_all_entity_pairs). + # They should match by construction, but can differ by 1 in edge cases. + if rel_labels_selected.size(1) != N: + if rel_labels_selected.size(1) > N: + rel_labels_selected = rel_labels_selected[:, :N, :] + else: + pad = rel_labels_selected.new_zeros(B, N - rel_labels_selected.size(1), rel_labels_selected.size(2)) + rel_labels_selected = torch.cat([rel_labels_selected, pad], dim=1) + + # Align rel_labels_selected to C_rel (relation classes from prompt embeddings). + # rel_matrix uses C from rel_class_to_ids which can differ from C_rel. 
+ if rel_labels_selected.size(2) != C_rel: + if rel_labels_selected.size(2) > C_rel: + rel_labels_selected = rel_labels_selected[:, :, :C_rel] + else: + pad = rel_labels_selected.new_zeros(B, rel_labels_selected.size(1), C_rel - rel_labels_selected.size(2)) + rel_labels_selected = torch.cat([rel_labels_selected, pad], dim=2) + rel_mask_selected = pair_mask.unsqueeze(-1).expand(B, N, C_rel) class_mask = rel_prompts_embedding_mask.unsqueeze(1).expand(B, N, C_rel) rel_loss = self.rel_loss(pair_scores, rel_labels_selected, rel_mask_selected, class_mask, **kwargs) - loss = ( - loss * self.config.span_loss_coef - + adj_loss * self.config.adjacency_loss_coef - + rel_loss * self.config.relation_loss_coef - ) - + span_loss = loss * self.config.span_loss_coef if loss is not None else 0.0 + + if hasattr(self, "relations_rep_layer") and adj_matrix is not None: + adj_mask = target_span_mask.float().unsqueeze(1) * target_span_mask.float().unsqueeze(2) + adj_loss = self.adj_loss(pred_adj_matrix, adj_matrix, adj_mask, **kwargs) + + loss = ( + span_loss + + adj_loss * self.config.adjacency_loss_coef + + rel_loss * self.config.relation_loss_coef + ) + else: + loss = ( + span_loss + + rel_loss * self.config.relation_loss_coef + ) + + # During training, rel_logits/rel_idx/rel_mask/entity_spans can have + # variable sizes across batch splits (different C_rel or N per GPU), + # which causes DataParallel gather to fail. Only loss is needed for + # training, so skip these tensors when labels are provided. 
+ is_training = labels is not None output = GLiNERRelexOutput( logits=scores, loss=loss, @@ -2288,9 +2347,10 @@ def forward( prompts_embedding_mask=prompts_embedding_mask, words_embedding=words_embedding, mask=mask, - rel_idx=pair_idx, - rel_logits=pair_scores, - rel_mask=pair_mask, + rel_idx=None if is_training else pair_idx, + rel_logits=None if is_training else pair_scores, + rel_mask=None if is_training else pair_mask, + entity_spans=None if is_training else entity_spans, ) return output @@ -2514,4 +2574,5 @@ def represent_spans( span_idx = span_idx * span_mask.unsqueeze(-1).long() target_span_rep = self.span_rep_layer(words_embeddings, span_idx) - return scores, target_span_rep, span_mask + # span_idx directly corresponds to target_span_rep positions + return scores, target_span_rep, span_mask, span_idx diff --git a/gliner/modeling/outputs.py b/gliner/modeling/outputs.py index a48614d..ce53fe3 100644 --- a/gliner/modeling/outputs.py +++ b/gliner/modeling/outputs.py @@ -99,3 +99,4 @@ class GLiNERRelexOutput(GLiNERBaseOutput): rel_idx: Optional[torch.LongTensor] = None rel_logits: Optional[torch.FloatTensor] = None rel_mask: Optional[torch.FloatTensor] = None + entity_spans: Optional[torch.LongTensor] = None diff --git a/gliner/modeling/utils.py b/gliner/modeling/utils.py index 71c8ed9..620b5c9 100644 --- a/gliner/modeling/utils.py +++ b/gliner/modeling/utils.py @@ -291,6 +291,68 @@ def build_entity_pairs( return pair_idx, pair_mask, head_rep, tail_rep +def build_all_entity_pairs( + span_rep: torch.Tensor, + span_mask: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Build all possible entity pairs for single-step relation extraction. + + Generates all directed pairs (i, j) where i != j for valid entities + (those with span_mask == 1), without any adjacency filtering. + + Args: + span_rep: Entity/span embeddings. Shape: (batch_size, num_entities, embed_dim) + span_mask: Mask for valid entities. 
Shape: (batch_size, num_entities) + + Returns: + Tuple containing: + - pair_idx: Indices of (head, tail) entity pairs. Shape: (B, max_pairs, 2) + - pair_mask: Boolean mask for valid pairs. Shape: (B, max_pairs) + - head_rep: Head entity embeddings. Shape: (B, max_pairs, embed_dim) + - tail_rep: Tail entity embeddings. Shape: (B, max_pairs, embed_dim) + """ + B, E, D = span_rep.shape + device = span_rep.device + + # Count valid entities per example + entity_counts = span_mask.long().sum(dim=1) # (B,) + + # Build pairs per example + batch_pair_lists: list[torch.Tensor] = [] + for b in range(B): + n = entity_counts[b].item() + if n < 2: + batch_pair_lists.append(torch.zeros(0, 2, dtype=torch.long, device=device)) + continue + # All (i, j) pairs where i != j, both < n + idx = torch.arange(n, device=device) + row = idx.repeat_interleave(n - 1) + col = torch.cat([torch.cat([idx[:i], idx[i + 1:]]) for i in range(n)]) + batch_pair_lists.append(torch.stack([row, col], dim=-1)) + + N = max(p.shape[0] for p in batch_pair_lists) if batch_pair_lists else 0 + + if N == 0: + pair_idx = torch.full((B, 1, 2), -1, dtype=torch.long, device=device) + pair_mask = torch.zeros((B, 1), dtype=torch.bool, device=device) + head_rep = tail_rep = torch.zeros((B, 1, D), dtype=span_rep.dtype, device=device) + return pair_idx, pair_mask, head_rep, tail_rep + + pair_idx = torch.full((B, N, 2), -1, dtype=torch.long, device=device) + pair_mask = torch.zeros((B, N), dtype=torch.bool, device=device) + + for b, pairs in enumerate(batch_pair_lists): + m = pairs.shape[0] + pair_idx[b, :m] = pairs + pair_mask[b, :m] = True + + batch_idx = torch.arange(B, device=device).unsqueeze(1) + head_rep = span_rep[batch_idx, pair_idx[..., 0].clamp_min(0)] + tail_rep = span_rep[batch_idx, pair_idx[..., 1].clamp_min(0)] + + return pair_idx, pair_mask, head_rep, tail_rep + + def extract_spans_from_tokens( scores: torch.Tensor, labels: Optional[torch.Tensor] = None,