In [91]:
def create_entity_matrix(text, ade_list):
    """
    Create an entity matrix for a given sentence and ADE list.
    """
    words = text.split()
    n = len(words)
    
    # Initialize NxN matrix
    entity_matrix = np.zeros((n, n), dtype=int)

    # Mark entity spans
    for ade in ade_list:
        ade_tokens = ade.split()
        current_token = 0
        current_word = 0
        starting_word = -1
        # Locate entity indices
        if len(ade_tokens) == 1:
            for i in range(n):
                if words[i] == ade_tokens[0]:
                    entity_matrix[i][i] = 1
        else:
            while current_word < n and current_token < len(ade_tokens)-1:
                if words[current_word] == ade_tokens[current_token]:
                    if starting_word == -1:
                        starting_word = current_word
                    next_not_found = True
                    next_word = current_word
                    while next_word < n and next_not_found:
                        if words[next_word] == ade_tokens[current_token+1]:
                            entity_matrix[current_word][next_word] = 1
                            current_token += 1
                            if current_token == len(ade_tokens)-1:
                                entity_matrix[next_word][starting_word] = 1
                            next_not_found = False
                        next_word += 1
                current_word += 1
                
    return entity_matrix

import numpy as np

def extract_entities_from_matrix(text, entity_matrix):
    """
    Extract entities from a binary matrix by constructing a tree of next-word connections
    and using the lower triangle to detect and close entity loops.

    Args:
        text (str): The input sentence.
        entity_matrix (np.ndarray): The NxN entity matrix.

    Returns:
        List[str]: A list of extracted entities.
    """
    words = text.split()
    n = len(words)
    entities = set()

    # Step 1: Construct adjacency tree from upper triangle (forward connections)
    entity_tree = {i: [] for i in range(n)}

    for i in range(n):
        for j in range(i + 1, n):
            if entity_matrix[i, j] == 1:  # Upper triangle connection
                entity_tree[i].append(j)

    # Step 2: Identify entity loops from lower triangle (closing connections)
    closing_loops = [(j, i) for i in range(n) for j in range(i + 1, n) if entity_matrix[j, i] == 1]

    # Step 3: Traverse the tree and extract entities
    for end_word, start_word in closing_loops:
        stack = [(start_word, [start_word])]  # Stack to track traversal paths

        while stack:
            current_word, path = stack.pop()

            # If we reach the closing loop, finalize the entity
            if current_word == end_word:
                entity = " ".join(words[idx] for idx in path)
                entities.add(entity)
                continue  # Stop extending this path

            # Traverse only valid next-word connections
            for next_word in entity_tree[current_word]:
                if next_word not in path and next_word <= end_word:  # Ensure it stays within entity boundaries
                    stack.append((next_word, path + [next_word]))

    for i in range(n):
        if entity_matrix[i, i] == 1:
            entities.add(words[i])

    return list(entities)


In [116]:
import numpy as np
import pandas as pd
import re
from datasets import load_dataset, Dataset, DatasetDict

# Load the CADEC dataset
cadec = load_dataset("KevinSpaghetti/cadec")

# Convert dataset to DataFrame
df = pd.DataFrame(cadec["train"])

# Function to remove punctuation
def normalize_text(text):
    return re.sub(r'[^\w\s]', '', text)

# Apply normalization to both text and ADEs
grouped_df = df.groupby("text")["ade"].apply(list).reset_index()
grouped_df["text"] = grouped_df["text"].apply(normalize_text)
grouped_df["ade"] = grouped_df["ade"].apply(lambda ade_list: [normalize_text(ade) for ade in ade_list])

# Apply transformation to generate entity matrices
grouped_df["entity_matrix"] = grouped_df.apply(lambda row: create_entity_matrix(row["text"], row["ade"]), axis=1)

# Convert each matrix to a 1D list format
grouped_df["entity_matrix"] = grouped_df["entity_matrix"].apply(lambda matrix: matrix.flatten().tolist())

# Sanity check on 5 random samples
sample_df = grouped_df.sample(5, random_state=42)

for index, row in sample_df.iterrows():
    text = row["text"]
    flat_matrix = row["entity_matrix"]

    # Convert back to a 2D matrix
    words = text.split()
    size = len(words)
    matrix = np.array(flat_matrix).reshape((size, size))
    
    extracted_entities = extract_entities_from_matrix(text, matrix)

    print("\n--- Sample ---")
    print(f"Text: {text}")
    print(f"Actual ADEs: {row['ade']}")
    print(f"Extracted Entities: {extracted_entities}")
    print(f"Match: {set(extracted_entities) == set(row['ade'])}")
    print("----------------------")

# Convert the dataset to Hugging Face format
dataset = Dataset.from_pandas(grouped_df)

# Save the dataset
dataset_dict = DatasetDict({"train": dataset})
dataset_dict.save_to_disk("./datasets/cadec")

print("\nDataset saved successfully in Hugging Face format!")



--- Sample ---
Text: i have vaginal bleeding and my vaginal skin burns a lot my cycle was back im menopausal  breast pain other symptoms i have erithema my skin is so dry i have little cuts on my hands i m sweating a lot usually i dont pruritus ani etc i took it 3 times the first for 3 days then i stopped because it didnt make great improvements and i had diarrhea my doctor told me the symptoms had to go away after taking few more second time i took other 3 pills in 3 days little improvement with the pain i had the first vaginal bleeding a lot of pruritus and my skin was all red but i didnt know it was arthrotec my doctor prescribed me a topical cream  did some vaginal test to exclude infections i had my cycle the last one for 4 days more problems and this time i had the idea to go on internet and i discovered i had a lot of the side effects by taking arthrotec
Actual ADEs: ['skin is so dry', 'vaginal skin burns', 'breast pain', 'pruritus', 'skin was all red', 'vaginal bleeding', 'swe

Saving the dataset (0/1 shards):   0%|          | 0/879 [00:00<?, ? examples/s]


Dataset saved successfully in Hugging Face format!


In [384]:
import numpy as np

def reachability(m,k):
    return np.diag(np.linalg.matrix_power(m,k))

def find_loops(adj_matrix):
    n = adj_matrix.shape[0]

    # Nodes that are NOT part of any cycle (diagonal elements are zero)
    m = [reachability(adj_matrix, i) for i in range(1, n+1)]

    return (np.ones((1,n)) @ m).flatten().tolist()

import numpy as np

def split_lower_triangular_ones(matrix):
    # Extract the upper triangular part (including diagonal)
    upper_triangle = np.triu(matrix)
    
    # Find positions of 1s in the lower triangle
    lower_positions = np.argwhere(np.tril(matrix, k=-1) == 1)
    
    # Generate new matrices, each keeping only one lower-triangle 1
    matrices = []
    for (i, j) in lower_positions:
        new_matrix = np.copy(upper_triangle)  # Copy upper triangle
        new_matrix[i, j] = 1  # Place one 1 in lower triangle
        matrices.append(new_matrix)
    
    return matrices

def nodes_with_arcs(adj_matrix):
    return np.any(adj_matrix, axis=0) | np.any(adj_matrix, axis=1)

# Example Usage
A = np.array([
    [0, 1, 0, 0, 1],  
    [0, 0, 1, 0, 0],  
    [1, 0, 0, 0, 0],  
    [0, 0, 0, 0, 0],
    [1, 0, 0, 0, 0]   
])

x = np.any([find_loops(a) for a in split_lower_triangular_ones(A)],axis=0)
x = ~(x ^ (~nodes_with_arcs(A)))

print(f"Nodes that are part of a loop: {x}")


Nodes that are part of a loop: [False False False False False]


# Reinforcement Learning

In [388]:
from torch.utils.data import DataLoader, Dataset
import torch

def collate_fn(batch, tokenizer, max_seq_length=128):
    """
    Collate function to pad words and entity matrices within a batch.

    Args:
        batch: List of tuples (words, entity_matrix).
        tokenizer: Tokenizer to encode and pad words.
        max_seq_length: Maximum sequence length for padding.

    Returns:
        - Encoded tokenized inputs (input_ids, attention_mask, word_ids).
        - Padded entity matrices tensor.
    """
    # Unpack tuples
    words_batch, matrices_batch, entities = zip(*batch)

    # Tokenize and pad words
    encoded = tokenizer(
        list(words_batch),
        is_split_into_words=True,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_seq_length
    )

    # Determine batch size and padded sequence length
    batch_size = len(batch)
    padded_seq_len = max_seq_length

    # Initialize a zero-padded entity matrix (batch_size, max_seq_length, max_seq_length)
    padded_matrices = torch.zeros((batch_size, padded_seq_len, padded_seq_len), dtype=torch.float32)

    for i, (matrix, words) in enumerate(zip(matrices_batch, words_batch)):
        size = len(words)
        matrix = matrix.reshape((size, size))
        seq_len = min(matrix.shape[0], padded_seq_len)  # Ensure we don't exceed max size
        padded_matrices[i, :seq_len, :seq_len] = matrix[:seq_len, :seq_len]
        
    return encoded, padded_matrices, entities

In [387]:
from torch.utils.data import Dataset
import torch
import os
import json

class EntityMatrixDataset(Dataset):
    def __init__(self, data_dir, tokenizer, max_seq_length=128):
        """
        Initialize the dataset by loading all JSON files from the directory.
        
        Args:
            data_dir: Path to the directory containing JSON files.
            tokenizer: Tokenizer to encode tokens.
            max_seq_length: Maximum sequence length for padding/truncation.
        """
        self.data = []
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length

        for file_name in os.listdir(data_dir):
            if file_name.endswith(".json"):
                with open(os.path.join(data_dir, file_name), "r") as f:
                    entry = json.load(f)
                    
                    # **Extract only the single-word entity labels**
                    entry["entity_matrix"] = entry["entity_matrix"]  

                    self.data.append(entry)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """
        Get the words and single-word entity vector for the given index.
        
        Returns:
            - words: List of words in the sentence.
            - single_word_labels: Tensor with 1s for single-word entities, 0s otherwise.
        """
        entry = self.data[idx]
        words = entry["words"]  # List of words
        entities = entry["entities"]
        entity_matrix = torch.tensor(entry["entity_matrix"], dtype=torch.float32)
        return words, entity_matrix, entities


In [389]:
import torch
import torch.nn as nn
from transformers import BertModel

class EntityMatrixPredictor(nn.Module):
    def __init__(self, bert_model_name="bert-base-cased", hidden_dim=768, num_heads=4, dropout=0.1):
        super(EntityMatrixPredictor, self).__init__()

        self.hidden_dim = hidden_dim
        self.bert = BertModel.from_pretrained(bert_model_name)

        # MLP to process word embeddings before span classification
        self.mlp_forward = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), 
            nn.ReLU()
        )

        # Learnable vector for dot product projection
        self.v_forward = nn.Parameter(torch.randn(hidden_dim))  # (hidden_dim,)

        # MLP to process word embeddings before span classification
        self.mlp_backward = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), 
            nn.ReLU()
        )

        # Learnable vector for dot product projection
        self.v_backward = nn.Parameter(torch.randn(hidden_dim))  # (hidden_dim,)

    def forward(self, input_ids, attention_mask, word_ids):
        batch_size, _ = input_ids.shape

        # Step 1: Get BERT Token Embeddings
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        token_embeddings = bert_output.last_hidden_state  # (batch, seq_len, hidden_dim)

        # Step 2: Aggregate Token Embeddings into Word Embeddings
        max_words = max([max([wid for wid in word_id if wid is not None], default=-1) + 1 for word_id in word_ids])
        word_embeddings = torch.zeros((batch_size, max_words, token_embeddings.shape[-1]), device=token_embeddings.device)

        for i in range(batch_size):
            word_counts = torch.zeros((max_words, 1), device=token_embeddings.device)
            for token_idx, word_idx in enumerate(word_ids[i]):
                if word_idx is not None:
                    word_embeddings[i, word_idx] += token_embeddings[i, token_idx]
                    word_counts[word_idx] += 1

            word_embeddings[i] /= word_counts.clamp(min=1)  # Avoid division by zero

        # Step 4: Construct Word Pair Matrix (Concatenation of i-th and j-th word)
        i_emb = word_embeddings.unsqueeze(2).expand(-1, -1, max_words, -1)  # (batch, max_words, max_words, hidden_dim)
        j_emb = word_embeddings.unsqueeze(1).expand(-1, max_words, -1, -1)  # (batch, max_words, max_words, hidden_dim)
        pair_matrix = torch.cat((i_emb, j_emb), dim=-1)  # (batch, max_words, max_words, hidden_dim * 2)

        logits_forward = torch.matmul(self.mlp_forward(pair_matrix), self.v_forward)  # (batch, max_words, max_words)
        logits_forward = torch.triu(logits_forward)

        logits_backward = torch.matmul(self.mlp_backward(pair_matrix), self.v_backward)  # (batch, max_words, max_words)
        logits_backward = torch.tril(logits_backward, diagonal=-1)

        logits = logits_forward + logits_backward

        return logits  # Raw logits (can be passed to BCEWithLogitsLoss)


In [394]:
model_archietcture = "matrix_cadec"

dataset_dir = f"./datasets/cadec"
model_path = f"./models/{model_archietcture}_rl.pth"

In [422]:
from torch.utils.data import DataLoader
import torch.optim as optim
from transformers import BertTokenizer, AdamW
import json

from transformers import AutoTokenizer

def get_train_loader(dataset_dir=dataset_dir):
    # Ensure you are using a Fast Tokenizer
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased", use_fast=True)

    # Load dataset
    train_dir = f"{dataset_dir}/train"  # Path to your dataset
    train_dataset = EntityMatrixDataset(train_dir, tokenizer)

    # Use the subset in the DataLoader
    train_loader = DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        collate_fn=lambda batch: collate_fn(batch, tokenizer)
    )

    return train_loader


In [423]:
import torch
import torch.nn as nn

def training_loop(model = EntityMatrixPredictor(), device = torch.device("cuda" if torch.cuda.is_available() else "cpu"), epochs = 3, pos_weight= 20, model_path=model_path, verbose = False, lambda_penalty = 0.1):    

    model.to(device)
    loss_bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device))

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        train_loader = get_train_loader()

        for batch in train_loader:
            tokens = batch[0]
            target_matrix = batch[1].to(device)  # Ensure targets are on the correct device

            input_ids = tokens["input_ids"].to(device)
            attention_mask = tokens["attention_mask"].to(device)

            # Extract `word_ids` only once for the batch
            word_ids = [tokens.word_ids(batch_index=i) for i in range(len(input_ids))]
            
            optimizer.zero_grad()

            # Forward pass (now includes word_ids)
            predicted_matrix = model(input_ids=input_ids, attention_mask=attention_mask, word_ids=word_ids)

            # **Mask the target matrix to match valid words in predicted matrix**
            batch_size, max_words, _ = predicted_matrix.shape
            target_matrix = target_matrix[:, :max_words, :max_words]  # Trim to match predicted size

            rewards = []
            log_probs = []
            probs = torch.sigmoid(predicted_matrix)

            for i in range(probs.size(0)):  # Loop over batch
                sampled_matrix = torch.bernoulli(probs[i]).detach().cpu()
                reward = compute_loop_reward(sampled_matrix)
                if verbose:
                    print(f"Reward: {reward}")

                rewards.append(reward)

                log_p = (sampled_matrix * torch.log(probs[i] + 1e-6) +
                        (1 - sampled_matrix) * torch.log(1 - probs[i] + 1e-6)).mean()
                log_probs.append(log_p)

            rewards = torch.tensor(rewards, device=probs.device)
            rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-6)
            rl_penalty = -torch.stack(log_probs) @ rewards

            if verbose:
                print(f"RL Penalty: {rl_penalty}")

            # Compute loss directly over the valid portion
            loss = loss_bce(predicted_matrix, target_matrix) + lambda_penalty * rl_penalty
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Compute average loss for epoch
        epoch_loss = total_loss / len(train_loader)
        if verbose: print(f"Epoch {epoch+1}, Loss: {epoch_loss}")

    return model

In [424]:
model = training_loop(verbose=True)
torch.save(model.state_dict(), model_path)
print(f"model saved at {model_path}")

Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
RL Penalty: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.0
Reward: 0.017241379246115685
Reward: 0.008620689623057842
Reward: 0.0


KeyboardInterrupt: 

In [410]:
import torch

def reachability_torch(m, k):
    return torch.diag(torch.matrix_power(m, k))

def find_loops_torch(adj_matrix):
    n = adj_matrix.shape[0]
    loop_diag = [reachability_torch(adj_matrix, i) for i in range(1, n + 1)]
    return torch.stack(loop_diag).sum(dim=0).bool()  # shape: (n,)

def split_lower_triangular_ones_torch(matrix):
    upper_triangle = torch.triu(matrix)
    lower_positions = (torch.tril(matrix, diagonal=-1) == 1).nonzero(as_tuple=False)

    matrices = []
    for pos in lower_positions:
        i, j = pos
        new_matrix = upper_triangle.clone()
        new_matrix[i, j] = 1
        matrices.append(new_matrix)

    return matrices

def nodes_with_arcs_torch(adj_matrix):
    return (adj_matrix.sum(dim=0) > 0) | (adj_matrix.sum(dim=1) > 0)

def compute_loop_reward(pred_matrix):
    pred_binary = (pred_matrix > 0.5).int()

    split_matrices = split_lower_triangular_ones_torch(pred_binary)
    if not split_matrices:
        return torch.tensor(0.0, device=pred_matrix.device)

    loop_flags = [find_loops_torch(m) for m in split_matrices]
    nodes_with_loops = torch.stack(loop_flags).any(dim=0)
    nodes_with_edges = nodes_with_arcs_torch(pred_binary)

    # XNOR equivalent: ~(A ^ B)
    valid_nodes = ~(nodes_with_loops ^ (~nodes_with_edges))
    num_valid = valid_nodes.sum().float()

    return num_valid / pred_matrix.shape[0]


In [416]:
def extract_spans_from_matrix(matrix):
    """
    Extracts and merges entity spans by collapsing linked words in the **upper triangular** part of the entity matrix.

    Args:
        matrix (torch.Tensor): Binary entity matrix (size: max_words x max_words).

    Returns:
        merged_spans (set of tuples): Extracted entity spans in (start, end) format.
    """
    max_words = matrix.shape[0]
    spans = []

    # **Step 1: Extract Raw Spans from Upper Triangle**
    for i in range(max_words):
        if matrix[i, i] == 1:
                spans.append([i, i])

    start = -1
    for i in range(max_words-1):
        if matrix[i,i+1] == 1:
            if start == -1:
                start = i
        elif start != -1:
            spans.append([start,i])
            start = -1
    if start != -1:
         spans.append([start,max_words-1])

    return spans  # Convert to set for unique values


In [417]:
import torch
from torch.utils.data import DataLoader, Subset
from sklearn.metrics import precision_score, recall_score, f1_score
from transformers import AutoTokenizer

def evaluation_loop(model_path):
    """
    Evaluates the model by directly comparing the predicted and target matrices.

    Args:
        model_path (str): Path to the trained model.
    
    Returns:
        Precision, Recall, F1 Score computed at the matrix level.
    """
    # **Initialize tokenizer**
    bert_model_name = "bert-base-cased"
    tokenizer = AutoTokenizer.from_pretrained(bert_model_name, use_fast=True)

    # **Create the evaluation dataset**
    eval_dir = f"{dataset_dir}/validation"  # Replace with your validation directory path
    eval_dataset = EntityMatrixDataset(eval_dir, tokenizer)

    # **Create the evaluation DataLoader**
    eval_loader = DataLoader(
        eval_dataset,
        batch_size=8,  
        shuffle=False,  
        collate_fn=lambda batch: collate_fn(batch, tokenizer),  
        pin_memory=True
    )

    # **Load the trained model**
    model = EntityMatrixPredictor(bert_model_name="bert-base-cased")
    model.load_state_dict(torch.load(model_path))
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # **Store results**
    all_preds = []
    all_targets = []

    # **Evaluation loop**
    with torch.no_grad():
        for batch in eval_loader:
            tokens, target_matrices, _ = batch  

            input_ids = tokens["input_ids"].to(device)
            attention_mask = tokens["attention_mask"].to(device)

            # **Extract word_ids for mapping token outputs to words**
            word_ids = [tokens.word_ids(batch_index=i) for i in range(len(input_ids))]

            # **Model prediction**
            predicted_matrix = model(input_ids=input_ids, attention_mask=attention_mask, word_ids=word_ids)

            # **Convert logits to binary predictions**
            binary_predictions = (predicted_matrix > 0.5).long()

            # **Flatten matrices for metric computation**
            for i in range(len(input_ids)):
                max_word_idx = max([wid for wid in word_ids[i] if wid is not None], default=-1) + 1

                # Extract the relevant part of the matrix
                pred_matrix = binary_predictions[i, :max_word_idx, :max_word_idx]
                target_matrix = target_matrices[i, :max_word_idx, :max_word_idx]

                # Flatten and store
                all_preds.extend(pred_matrix.cpu().numpy().flatten())
                all_targets.extend(target_matrix.cpu().numpy().flatten())

    # **Compute Metrics**
    precision = precision_score(all_targets, all_preds)
    recall = recall_score(all_targets, all_preds)
    f1 = f1_score(all_targets, all_preds)

    return precision, recall, f1


In [418]:
precision, recall, f1 = evaluation_loop(model_path=model_path)

print(f"Evaluation Results:")
print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}")

  model.load_state_dict(torch.load(model_path))


Evaluation Results:
Precision: 0.0000, Recall: 0.0000, F1 Score: 0.0000
