Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions gliner/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,20 +197,32 @@ def __init__(
rel_token: str = "<<REL>>",
adjacency_loss_coef=1.0,
relation_loss_coef=1.0,
augment_data_prob=0.5,
augment_ent_drop_prob=(0.0, 1.0),
augment_rel_drop_prob=(0.0, 0.3),
augment_add_other_prob=0.5,
**kwargs,
):
"""Initialize UniEncoderRelexConfig.

Args:
relations_layer (str, optional): Name of relations layer,
see gliner.modeling.multitask.relations_layers.py. Defaults to None.
Use "none" to enable single-step relation extraction that scores all
entity pair combinations directly without adjacency filtering.
triples_layer (str, optional): Name of triples layer,
see gliner.modeling.multitask.triples_layers.py. Defaults to None.
embed_rel_token (bool, optional): Whether to embed relation tokens. Defaults to True.
rel_token_index (int, optional): Index of relation token. Defaults to -1.
rel_token (str, optional): Relation marker token. Defaults to "<<REL>>".
adjacency_loss_coef (float, optional): Adjacency modeling loss coefficient. Defaults to 1.0.
relation_loss_coef (float, optional): Relation representation loss coefficient. Defaults to 1.0.
augment_data_prob (float, optional): Probability of applying data augmentation
to an example. Defaults to 0.5; set to 0.0 to disable.
augment_ent_drop_prob (tuple, optional): Range (min, max) from which to sample
the per-type entity drop probability. Defaults to (0.0, 1.0).
augment_rel_drop_prob (tuple, optional): Range (min, max) from which to sample
the per-type relation drop probability. Defaults to (0.0, 0.3).
augment_add_other_prob (float, optional): Probability that dropped entity
types are relabeled as "other" instead of being removed. Defaults to 0.5.
**kwargs: Additional keyword arguments passed to UniEncoderConfig.

Raises:
Expand All @@ -225,6 +237,10 @@ def __init__(
self.rel_token = rel_token
self.adjacency_loss_coef = adjacency_loss_coef
self.relation_loss_coef = relation_loss_coef
self.augment_data_prob = augment_data_prob
self.augment_ent_drop_prob = tuple(augment_ent_drop_prob)
self.augment_rel_drop_prob = tuple(augment_rel_drop_prob)
self.augment_add_other_prob = augment_add_other_prob


class UniEncoderSpanRelexConfig(UniEncoderRelexConfig):
Expand Down
210 changes: 199 additions & 11 deletions gliner/data_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1417,6 +1417,95 @@ def __init__(self, config, tokenizer, words_splitter):
super().__init__(config, tokenizer, words_splitter)
self.rel_token = config.rel_token

def augment_example(self, example, ner_negatives=None, other_keyword="other"):
    """Apply data augmentation by randomly dropping entity/relation types.

    For each call:
      - One per-type entity drop probability is sampled uniformly from the
        ``augment_ent_drop_prob`` range, and each entity type (except
        ``other_keyword``, which is exempt) is independently dropped with
        that probability. Likewise a per-type relation drop probability is
        sampled from ``augment_rel_drop_prob`` for relation types.
      - With probability ``augment_add_other_prob`` "add other" mode is
        enabled: an entity of a dropped type is then kept with its type
        rewritten to ``other_keyword`` — but only if it participates in at
        least one relation whose type survives. In every other case
        (add-other disabled, or no surviving relation touches the entity)
        the entity is removed entirely.
      - Relations whose type was dropped, or whose head/tail entity was
        removed, are discarded; surviving relations are re-indexed to the
        compacted entity list.

    Args:
        example: Dictionary with 'ner' and 'relations' keys.
        ner_negatives: Unused; accepted for API compatibility with callers.
        other_keyword: Replacement type label for dropped entity types.

    Returns:
        The original example object (untouched) when it has no entities or
        no types were sampled for dropping; otherwise a shallow copy with
        updated 'ner'/'relations' plus '_dropped_ent_types' and
        '_dropped_rel_types' metadata keys.
    """
    ner = example.get("ner", [])
    relations = example.get("relations", [])

    # Nothing to augment without entities.
    if not ner:
        return example

    # One shared drop probability per call, sampled from the configured ranges.
    ent_drop_prob = random.uniform(*self.config.augment_ent_drop_prob)
    rel_drop_prob = random.uniform(*self.config.augment_rel_drop_prob)
    add_other = random.random() < self.config.augment_add_other_prob

    all_ent_types = set(e[-1] for e in ner)
    all_rel_types = set(r[-1] for r in relations) if relations else set()

    # "other" is exempt from dropping since it's our replacement label
    dropped_ent_types = {t for t in all_ent_types if t != other_keyword and random.random() < ent_drop_prob}
    dropped_rel_types = {t for t in all_rel_types if random.random() < rel_drop_prob}

    if not dropped_ent_types and not dropped_rel_types:
        return example

    # Entity indices that appear in at least one relation of a surviving type.
    entity_has_active_rel = set()
    if relations:
        for head_idx, tail_idx, rel_type in relations:
            if rel_type not in dropped_rel_types:
                entity_has_active_rel.add(head_idx)
                entity_has_active_rel.add(tail_idx)

    # Process entities: replace with "other", keep, or drop.
    new_ner = []
    old_to_new_idx = {}  # original entity index -> index in new_ner

    for i, ent in enumerate(ner):
        ent_type = ent[-1]

        if ent_type in dropped_ent_types and i not in entity_has_active_rel:
            # Dropped type and no surviving relations -> remove the entity
            # (even in add-other mode it would carry no relation signal).
            continue
        if ent_type in dropped_ent_types and not add_other:
            # Dropped type and add-other disabled -> remove the entity.
            continue

        old_to_new_idx[i] = len(new_ner)
        if ent_type in dropped_ent_types and add_other:
            # Keep the span but relabel it so it can still anchor relations.
            new_ner.append(list(ent[:-1]) + [other_keyword])
        else:
            new_ner.append(ent)

    # Drop relations of dropped types and remap endpoints to new indices;
    # relations touching a removed entity are discarded implicitly here.
    new_relations = []
    for head_idx, tail_idx, rel_type in relations:
        if rel_type in dropped_rel_types:
            continue
        if head_idx in old_to_new_idx and tail_idx in old_to_new_idx:
            new_relations.append([old_to_new_idx[head_idx], old_to_new_idx[tail_idx], rel_type])

    # Shallow copy so the caller's original example is left untouched.
    result = dict(example)
    result["ner"] = new_ner
    result["relations"] = new_relations
    result["_dropped_ent_types"] = dropped_ent_types
    result["_dropped_rel_types"] = dropped_rel_types

    return result

def batch_generate_class_mappings(
self,
batch_list: List[Dict],
Expand Down Expand Up @@ -1456,6 +1545,10 @@ def batch_generate_class_mappings(
max_neg_type_ratio = int(self.config.max_neg_type_ratio)
neg_type_ratio = random.randint(0, max_neg_type_ratio) if max_neg_type_ratio else 0

# Augmentation metadata (set by augment_example)
dropped_ent_types = b.get("_dropped_ent_types", set())
dropped_rel_types = b.get("_dropped_rel_types", set())

# Process NER types
if "ner_negatives" in b:
negs_i = b["ner_negatives"]
Expand All @@ -1465,7 +1558,9 @@ def batch_generate_class_mappings(
if "ner_labels" in b:
types = b["ner_labels"]
else:
types = list(set([el[-1] for el in b["ner"]] + negs_i))
# Exclude dropped entity types ("other" replacements are already in ner)
ent_types = [el[-1] for el in b["ner"] if el[-1] not in dropped_ent_types]
types = list(set(ent_types + negs_i))
random.shuffle(types)
types = types[: int(self.config.max_types)]

Expand All @@ -1483,7 +1578,9 @@ def batch_generate_class_mappings(
if "rel_labels" in b:
rel_types = b["rel_labels"]
else:
rel_types = list(set([el[-1] for el in b.get("relations", [])] + rel_negs_i))
# Exclude dropped relation types
active_rel_types = [el[-1] for el in b.get("relations", []) if el[-1] not in dropped_rel_types]
rel_types = list(set(active_rel_types + rel_negs_i))
random.shuffle(rel_types)
rel_types = rel_types[: int(self.config.max_types)]

Expand Down Expand Up @@ -1525,6 +1622,15 @@ def collate_raw_batch(
Dictionary containing collated batch data for joint entity and
relation extraction.
"""
# Apply data augmentation if enabled (only during dynamic mapping generation)
augment_prob = getattr(self.config, 'augment_data_prob', 0.0)
if augment_prob > 0.0 and class_to_ids is None and entity_types is None:
if ner_negatives is None:
ner_negatives = get_negatives(batch_list, sampled_neg=100, key="ner")
batch_list = [
self.augment_example(b, ner_negatives) if random.random() < augment_prob else b
for b in batch_list
]
if class_to_ids is None and entity_types is None:
# Dynamically infer per-example mappings
class_to_ids, id_to_classes, rel_class_to_ids, rel_id_to_classes = self.batch_generate_class_mappings(
Expand Down Expand Up @@ -1616,11 +1722,16 @@ def preprocess_example(self, tokens, ner, classes_to_id, relations, rel_classes_
span_to_idx = {(s, e): i for i, (s, e) in enumerate(spans_idx.tolist())}

# Create entity index mapping (from original entity list to span indices)
# and compact indices (0, 1, 2, ...) matching target_span_rep ordering
entity_to_span_idx = {}
entity_to_compact_idx = {}
compact_idx = 0
if ner is not None:
for ent_idx, (start, end, _) in enumerate(ner): # (start, end, label)
if (start, end) in span_to_idx and end < num_tokens:
entity_to_span_idx[ent_idx] = span_to_idx[(start, end)]
entity_to_compact_idx[ent_idx] = compact_idx
compact_idx += 1

# Process relations
rel_idx_list = []
Expand All @@ -1630,9 +1741,9 @@ def preprocess_example(self, tokens, ner, classes_to_id, relations, rel_classes_
for rel in relations:
head_idx, tail_idx, rel_type = rel

# Check if both entities are valid and map to span indices
if head_idx in entity_to_span_idx and tail_idx in entity_to_span_idx and rel_type in rel_classes_to_id:
rel_idx_list.append([head_idx, tail_idx])
# Use compact indices so rel_idx aligns with target_span_rep positions
if head_idx in entity_to_compact_idx and tail_idx in entity_to_compact_idx and rel_type in rel_classes_to_id:
rel_idx_list.append([entity_to_compact_idx[head_idx], entity_to_compact_idx[tail_idx]])
rel_label_list.append(rel_classes_to_id[rel_type])

# Convert to tensors
Expand Down Expand Up @@ -1696,28 +1807,40 @@ def create_batch_dict(self, batch, class_to_ids, id_to_classes, rel_class_to_ids
"rel_id_to_classes": rel_id_to_classes,
}

def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_negatives=True, negative_ratio=2.0):
def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_negatives=True, negative_ratio=(1.0, 10.0)):
"""Create relation labels with negative pair sampling.

Overrides the span-based version to work with token-level entity representations.
Uses entities_id count instead of span_label for entity counting.

When relations_layer is "none", generates labels for ALL entity pair
combinations (no adjacency matrix), matching the order produced by
build_all_entity_pairs. Otherwise, uses adjacency-based pair sampling.

Args:
batch: Batch dictionary containing entities and relations.
add_reversed_negatives: If True, add reversed direction pairs as negatives.
add_random_negatives: If True, add random entity pairs as negatives.
negative_ratio: Ratio of negative to positive pairs.
negative_ratio: Ratio of negative to positive pairs. Can be a float for a fixed ratio
or a (min, max) tuple to sample a random ratio per example.

Returns:
Tuple containing:
- adj_matrix: Adjacency matrix (shape: [B, max_entities, max_entities])
- adj_matrix: Adjacency matrix (shape: [B, max_entities, max_entities]).
None when relations_layer is "none".
- rel_matrix: Multi-hot relation labels (shape: [B, max_pairs, num_relation_classes])
"""
B = len(batch["tokens"])
span_mask = batch["span_mask"]

# Count entities per sample (differs from span-based which uses span_label)
batch_ents = span_mask.long().squeeze(-1).sum(-1)
# For span-based models, span_mask covers all candidate spans (L*max_width),
# but rel_idx uses compact entity indices (0..num_annotated-1), so we must
# count annotated entities instead. For token-level models, span_mask is already
# sized by entity count, so the original formula works.
if "span_label" in batch and batch["span_label"] is not None:
batch_ents = (batch["span_label"] > 0).sum(-1)
else:
batch_ents = span_mask.long().squeeze(-1).sum(-1)
max_En = max(batch_ents.max().item(), 1)

rel_class_to_ids = batch["rel_class_to_ids"]
Expand All @@ -1726,9 +1849,16 @@ def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_
else:
C = len(rel_class_to_ids) if rel_class_to_ids else 0

single_step = getattr(self.config, "relations_layer", None) == "none"

if C == 0:
if single_step:
return None, torch.zeros(B, 1, 1, dtype=torch.float)
return torch.zeros(B, max_En, max_En, dtype=torch.float), torch.zeros(B, 1, 1, dtype=torch.float)

if single_step:
return self._create_single_step_relation_labels(batch, batch_ents, C)

adj_matrix = torch.zeros(B, max_En, max_En, dtype=torch.float)

all_pairs_info = []
Expand Down Expand Up @@ -1758,7 +1888,11 @@ def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_
# Generate negative pairs
negative_pairs = set()
num_positives = len(positive_pairs)
target_negatives = int(num_positives * negative_ratio)
if isinstance(negative_ratio, (tuple, list)):
ratio = random.uniform(negative_ratio[0], negative_ratio[1])
else:
ratio = negative_ratio
target_negatives = max(1, int(num_positives * ratio))

if add_reversed_negatives:
for e1, e2 in positive_pairs:
Expand Down Expand Up @@ -1811,6 +1945,60 @@ def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_

return adj_matrix, rel_matrix

def _create_single_step_relation_labels(self, batch, batch_ents, C):
"""Create relation labels for single-step mode (all entity pair combinations).

Generates labels for ALL directed pairs (i, j) where i != j among entities,
matching the order produced by build_all_entity_pairs.

Args:
batch: Batch dictionary containing entities and relations.
batch_ents: Tensor of entity counts per example.
C: Number of relation classes.

Returns:
Tuple of (None, rel_matrix) where rel_matrix has shape
(B, max_pairs, C) with max_pairs = max(N_i * (N_i - 1)).
"""
B = len(batch["tokens"])

# Build pair-to-index mapping and collect labels
max_total_pairs = 0
all_pair_maps = []

for i in range(B):
N = batch_ents[i].item()
# All (e1, e2) pairs where e1 != e2, ordered as build_all_entity_pairs produces:
# (0,1), (0,2), ..., (1,0), (1,2), ..., i.e., sorted by (e1, e2)
pair_to_idx = {}
idx = 0
for e1 in range(N):
for e2 in range(N):
if e1 != e2:
pair_to_idx[(e1, e2)] = idx
idx += 1
all_pair_maps.append(pair_to_idx)
max_total_pairs = max(max_total_pairs, idx)

max_total_pairs = max(max_total_pairs, 1)
rel_matrix = torch.zeros(B, max_total_pairs, C, dtype=torch.float)

for i in range(B):
N = batch_ents[i].item()
rel_idx_i = batch["rel_idx"][i]
rel_label_i = batch["rel_label"][i]
pair_to_idx = all_pair_maps[i]

for k in range(rel_label_i.shape[0]):
if rel_label_i[k] > 0:
e1 = rel_idx_i[k, 0].item()
e2 = rel_idx_i[k, 1].item()
pair_key = (e1, e2)
if pair_key in pair_to_idx:
rel_matrix[i, pair_to_idx[pair_key], rel_label_i[k].item() - 1] = 1.0

return None, rel_matrix

def prepare_inputs(
self,
texts: Sequence[Sequence[str]],
Expand Down
Loading