In [22]:
import json
import torch
from transformers import BertTokenizer
from torch.utils.data import Dataset, DataLoader

In [36]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def preprocess_pairs(context, entity_pairs):
    '''
    Preprocesses the already made entity_pairs
    Returns the input ids and attention masks for each context based on the pairs found.
    '''
    tokenized_data = []
    for pair in entity_pairs:
        e1 = pair['e1']
        e2 = pair['e2']
        
        # Find positions of entities in the context
        context_tokens = tokenizer.tokenize(context)
        e1_tokens = tokenizer.tokenize(e1)
        e2_tokens = tokenizer.tokenize(e2)
        
        try:
            e1_start = context_tokens.index(e1_tokens[0])
            e1_end = e1_start + len(e1_tokens) - 1
            e2_start = context_tokens.index(e2_tokens[0])
            e2_end = e2_start + len(e2_tokens) - 1
        except ValueError:
            #print(f"Entity {e1} or {e2} not found in context")
            continue

        # Mark entity positions with special tokens
        marked_tokens = context_tokens[:e1_start] + ['[E1]'] + e1_tokens + ['[/E1]'] + \
                        context_tokens[e1_end + 1:e2_start] + ['[E2]'] + e2_tokens + ['[/E2]'] + \
                        context_tokens[e2_end + 1:]
        
        # Convert tokens to IDs
        input_ids = tokenizer.convert_tokens_to_ids(marked_tokens)
        attention_mask = [1] * len(input_ids)
        
        tokenized_data.append((input_ids, attention_mask))
    return tokenized_data

In [99]:
# Function to preprocess context and entity pairs data
def preprocess_data(filepath, length=-1):
    with open(filepath, 'r') as file:
        data = json.load(file)

    tokenized_data = []
    
    if length == -1:
        length = len(data)
    
    name = filepath[20:].replace('_context_and_pairs.json','')

    for i in range(length):
        if i%50 == 0:
            print(f'{name}: {i/length * 100:.3f}%')

        elem = data[i]

        context = elem['context']
        entity_pairs = elem['pairs']
        preprocessed = preprocess_pairs(context, entity_pairs)
        
        tokenized_data.append(preprocessed)
    
    return tokenized_data

In [101]:
import pickle 
import time

start = time.time()

train = preprocess_data('../../ignore/docred_train_annotated_context_and_pairs.json')
validation = preprocess_data('../../ignore/docred_validation_context_and_pairs.json')

with open('../../ignore/train.pkl', 'wb') as file:
    pickle.dump(train, file)

with open('../../ignore/validation.pkl', 'wb') as file:
    pickle.dump(validation, file)

end = time.time()
print(f'Preprocessing took {(end-start)/60:.2f} minutes')

train_annotated: 0.000%
train_annotated: 1.638%
train_annotated: 3.275%
train_annotated: 4.913%
train_annotated: 6.551%
train_annotated: 8.189%
train_annotated: 9.826%
train_annotated: 11.464%
train_annotated: 13.102%
train_annotated: 14.740%
train_annotated: 16.377%
train_annotated: 18.015%
train_annotated: 19.653%
train_annotated: 21.291%
train_annotated: 22.928%
train_annotated: 24.566%
train_annotated: 26.204%
train_annotated: 27.841%
train_annotated: 29.479%
train_annotated: 31.117%
train_annotated: 32.755%
train_annotated: 34.392%
train_annotated: 36.030%
train_annotated: 37.668%
train_annotated: 39.306%
train_annotated: 40.943%
train_annotated: 42.581%
train_annotated: 44.219%
train_annotated: 45.857%
train_annotated: 47.494%
train_annotated: 49.132%
train_annotated: 50.770%
train_annotated: 52.407%
train_annotated: 54.045%
train_annotated: 55.683%
train_annotated: 57.321%
train_annotated: 58.958%
train_annotated: 60.596%
train_annotated: 62.234%
train_annotated: 63.872%
train_a

In [51]:
# Next, prepare the triplets
from datasets import load_dataset
import pandas as pd
from _RE import make_triplets

def preprocess_triplets(data, length=-1):
    annotated_triplets = []

    if length == -1:
        length = len(data)

    # Make gold labels
    for i in range(length):
        struct = {}
        elem = data.iloc[i]
        vertexSet = data['vertexSet'][i]
        labels = elem['labels']

        triplets = make_triplets(vertexSet, labels)
        struct['idx'] = i
        struct['triplets'] = triplets

        annotated_triplets.append(struct)

    triplets_df = pd.DataFrame(annotated_triplets)
    return triplets_df

In [58]:
dataset = load_dataset('docred', trust_remote_code=True)

train_docred = pd.DataFrame(dataset['train_annotated'])
validation_docred = pd.DataFrame(dataset['validation'])

train_gold = preprocess_triplets(train_docred, length=10)
validation_gold = preprocess_triplets(validation_docred, length=10)

In [76]:
print(train[0])
print(validation[0])
print(train_gold.head())
print(validation_gold.head())

[([100, 27838, 3367, 13095, 1010, 4297, 1012, 100, 3498, 2004, 2250, 15396, 27838, 3367, 1006, 3839, 100, 4004, 4382, 1998, 100, 27838, 3367, 2250, 1007, 1010, 2001, 1037, 2659, 1011, 3465, 8582, 2241, 2012, 1996, 20801, 2100, 24183, 2248, 3199, 1999, 14674, 4710, 2103, 1010, 6005, 9011, 1999, 1996, 5137, 1012, 2009, 3498, 5115, 4968, 1998, 2248, 7538, 2578, 1010, 3701, 21429, 2578, 11383, 9011, 1998, 23312, 2007, 2484, 4968, 14345, 1999, 2490, 1997, 1996, 8260, 2799, 3136, 1997, 2060, 7608, 1012, 1999, 2286, 1010, 1996, 8582, 2150, 2019, 8727, 1997, 5137, 2250, 15396, 4082, 2037, 4435, 10329, 1012, 2049, 2364, 2918, 2001, 20801, 2100, 24183, 2248, 3199, 1010, 9011, 1012, 1996, 8582, 2001, 2631, 2004, 4004, 4382, 1010, 1996, 2034, 8582, 1999, 1996, 5137, 2000, 2022, 2448, 2004, 1037, 10791, 1012, 2006, 2257, 2385, 1010, 2286, 1010, 1996, 2942, 5734, 3691, 1997, 1996, 5137, 1006, 6187, 9331, 1007, 1010, 1996, 21575, 2303, 1997, 1996, 2231, 1997, 1996, 3072, 1997, 1996, 5137, 2005, 2942,

In [80]:
token_lengths = [len(tokens) for tokens in train]
print("Token lengths:", token_lengths)
print("Max token length:", max(token_lengths))
print("Min token length:", min(token_lengths))

Token lengths: [311, 132, 133, 604, 184, 215, 242, 92, 132, 508]
Max token length: 604
Min token length: 92


In [77]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

class RelationExtractionDataset(Dataset):
    def __init__(self, tokenized_data, gold_labels, max_length=604):
        self.tokenized_data = tokenized_data
        self.gold_labels = gold_labels
        self.max_length = max_length

    def __len__(self):
        return len(self.tokenized_data)

    def __getitem__(self, idx):
        tokens = self.tokenized_data[idx]
        # Pad tokens to the maximum length
        if len(tokens) < self.max_length:
            tokens = tokens + [0] * (self.max_length - len(tokens))
        elif len(tokens) > self.max_length:
            tokens = tokens[:self.max_length]
        labels = self.gold_labels[self.gold_labels['idx'] == idx]['triplets'].values[0]
        return torch.tensor(tokens, dtype=torch.long), labels

# Creating datasets
train_dataset = RelationExtractionDataset(train, train_gold)
val_dataset = RelationExtractionDataset(validation, validation_gold)

# Creating data loaders
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)

# Training loop
model = BiLSTMRelationExtractor(vocab_size, embedding_dim=100, hidden_dim=256, output_dim=97)
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

model.train()
total_loss = 0
for tokens, labels in train_loader:
    optimizer.zero_grad()
    outputs = model(tokens)
    loss = criterion(outputs, labels.float())
    loss.backward()
    optimizer.step()
    total_loss += loss.item()

print("Training loss:", total_loss / len(train_loader))


In [78]:
import torch.nn as nn
import torch.nn.functional as F

class RelationExtractionModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(RelationExtractionModel, self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        _, (hidden, _) = self.lstm(x)
        out = self.fc(hidden[-1])
        return out

input_dim = 768  # Dimension of the BERT embeddings
hidden_dim = 128
output_dim = 97  # Number of relations

model = RelationExtractionModel(input_dim, hidden_dim, output_dim)


In [79]:
import torch.optim as optim

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
for epoch in range(3):  # Number of epochs
    model.train()
    total_loss = 0
    for tokens, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(tokens)
        labels_padded = pad_sequence(labels, batch_first=True, padding_value=0)  # Pad labels
        labels_padded = labels_padded.type_as(outputs)  # Ensure the labels are the same type as outputs
        loss = criterion(outputs, labels_padded)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader)}")

    # Validation step (optional)
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for tokens, labels in val_loader:
            outputs = model(tokens)
            labels_padded = pad_sequence(labels, batch_first=True, padding_value=0)  # Pad labels
            labels_padded = labels_padded.type_as(outputs)  # Ensure the labels are the same type as outputs
            loss = criterion(outputs, labels_padded)
            total_val_loss += loss.item()

    print(f"Validation Loss: {total_val_loss/len(val_loader)}")


ValueError: expected sequence of length 218 at dim 2 (got 314)