In [654]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model for NER
model_name = 'dslim/distilbert-NER'

# Custom keywords
# The pickle is looked up under keyword_matching/ first and, if that file is
# missing, in the current working directory.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# load directory.pkl from a trusted source.
try:
    with open('keyword_matching/directory.pkl', 'rb') as file:
        keywords = pickle.load(file)
except FileNotFoundError:
    # `FileNotFoundError or FileExistsError` evaluated to FileNotFoundError only,
    # so catching it directly is equivalent and says what is actually handled.
    with open('directory.pkl', 'rb') as file:
        keywords = pickle.load(file)

# The dataset cells pass `custom_keywords=custom_keywords`; expose the loaded
# object under that name as well so they work on a fresh kernel.
# NOTE(review): assumes `custom_keywords` is meant to be this same object - confirm.
custom_keywords = keywords

# Hyperparameters
input_size = 512
output_size = 768
num_layers = 4          # may require tuning
hidden_size = 256       # may require tuning
num_classes = 97        # 96 different relations plus '0' for no relation
learning_rate = 0.001   # may require tuning
batch_size = 32
num_epochs = 5
PAIR_EMBEDDING_WIDTH = 1540
PAIR_EMBEDDING_LENGTH = 3000

In [651]:
from torch.utils.data import DataLoader, Dataset
import importlib
import CustomDocREDDataset
importlib.reload(CustomDocREDDataset)

length = 100

# Constructor arguments that are the same for every split.
dataset_kwargs = dict(
    input_size=input_size,
    model_name=model_name,
    custom_keywords=custom_keywords,
    device=device,
)

# The training split is built with twice the base `length`.
train = CustomDocREDDataset.CustomDocREDDataset(
    dataset='train_annotated',
    length=length*2,
    **dataset_kwargs
)

# Disabled for now: test / validation splits, built with the base `length`.
# test = CustomDocREDDataset.CustomDocREDDataset(dataset='test', length=length, **dataset_kwargs)
# val = CustomDocREDDataset.CustomDocREDDataset(dataset='validation', length=length, **dataset_kwargs)

# Only the training loader shuffles; the disabled evaluation loaders keep order.
train_loader = DataLoader(
    train,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=CustomDocREDDataset.custom_collate_fn,
)
#test_loader = DataLoader(test, batch_size=batch_size, shuffle=False, collate_fn=CustomDocREDDataset.custom_collate_fn)
#val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, collate_fn=CustomDocREDDataset.custom_collate_fn)

Starting preprocessing...
0.00% finished
25.00% finished
50.00% finished
75.00% finished
Preprocessing took 0.81 minutes.


In [688]:
import importlib
import REBRNN
importlib.reload(REBRNN)
# Build the relation extractor; note that the hyperparameter `output_size`
# is what gets passed as the model's `input_size`.
model = REBRNN.RelationExtractorBRNN(
    input_size=output_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    num_classes=num_classes,
    pair_embedding_width=PAIR_EMBEDDING_WIDTH,
    pair_embedding_length=PAIR_EMBEDDING_LENGTH,
    model_name=model_name,
    device=device
).to(device)
custom_loss_fn = REBRNN.CustomLoss(threshold=0.8)
# Take the learning rate from the hyperparameter cell instead of repeating the literal.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
model.train()

# Forward pass only: no loss, backward pass or optimizer step happens here yet.
for batch in train_loader:
    text_embeddings = batch['text_embeddings']
    pair_embeddings = batch['pair_embeddings']
    # Not consumed yet; intended as the labels for the loss step in the TODO below.
    triplet_embeddings = batch['triplet_embeddings']

    # preds shape is [batch, 3000]: 32 for full batches, smaller for the final one
    # (the recorded run ends with [8, 3000]).
    preds = model(text_embeddings, pair_embeddings)

    # TODO: for each item in preds, compare to true label using triplets and custom loss function
    # TODO: get total loss across batch (alter custom loss function)




torch.Size([32, 3000])
torch.Size([32, 3000])
torch.Size([32, 3000])
torch.Size([32, 3000])
torch.Size([32, 3000])
torch.Size([32, 3000])
torch.Size([8, 3000])


In [702]:
'''
To do:
- sanity check ending position
'''
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
import importlib
import CustomDocREDDataset
importlib.reload(CustomDocREDDataset)

# Load every DocRED split; later cells index this object by split name
# ('train_annotated', 'train_distant', 'test', 'validation').
data = load_dataset(
    path='docred',
    trust_remote_code=True,
)

Starting preprocessing...


In [822]:
import CustomDocREDDataset
importlib.reload(CustomDocREDDataset)
# Helper instance (despite the name, this is an object, not a class). The cells
# below only call its get_info method. With length=0 the recorded run reports
# "Preprocessing skipped."
helper_dataset_kwargs = {
    'dataset': 'train_annotated',
    'input_size': input_size,
    'model_name': model_name,
    'custom_keywords': custom_keywords,
    'device': device,
    'length': 0,
}
custom_dataset_class = CustomDocREDDataset.CustomDocREDDataset(**helper_dataset_kwargs)

Starting preprocessing...
Preprocessing skipped.


In [825]:
# One DataFrame per DocRED split, in the order: annotated train, distant train,
# test, validation.
train, distant, test, val = (
    pd.DataFrame(data[split_name])
    for split_name in ('train_annotated', 'train_distant', 'test', 'validation')
)

In [824]:
def check_indicing(data, get_info=None):
    """Return the row positions whose token count changes after processing.

    For every row of `data`, the number of tokens in the raw `'sents'` field
    (a sequence of token lists) is compared with the number of space-separated
    tokens in the sentences returned by `get_info`. Rows where the two counts
    differ are reported.

    Args:
        data: DataFrame with a `'sents'` column; rows are accessed by position.
        get_info: Callable taking one row and returning a 3-tuple whose first
            element is an iterable of sentence strings. Defaults to
            `custom_dataset_class.get_info` from the notebook namespace.

    Returns:
        List of integer row positions (0-based) with mismatching token counts.
    """
    if get_info is None:
        # Resolved at call time so the notebook-level helper is picked up.
        get_info = custom_dataset_class.get_info

    indices = []

    for i in range(len(data)):
        # Index the frame that was passed in. The previous version read the
        # global `train` here, so every split was checked against train rows.
        instance = data.iloc[i]
        raw_sents = instance['sents']

        sents, _, _ = get_info(instance)

        raw_token_count = sum(len(sent) for sent in raw_sents)
        processed_token_count = sum(len(sent.split(' ')) for sent in sents)

        if raw_token_count != processed_token_count:
            indices.append(i)

    return indices

# Report the mismatching row positions for each split, one line per split.
for split_label, split_frame in (
    ('train:', train),
    ('distant:', distant),
    ('test:', test),
    ('val:', val),
):
    print(split_label, check_indicing(split_frame))

train: []
test: []
val: []


In [823]:
# Inspect one document (positional row 2404 of the train frame) and compare
# sentence 4 before and after processing.
instance = train.iloc[2404]
raw_sents = instance['sents']

# get_info returns three values; only the processed sentences are used here.
sents, _, _ = custom_dataset_class.get_info(instance)
sents_split = list(map(lambda sentence: sentence.split(' '), sents))
print('\n')

# Raw token list first, then the processed tokens. In the recorded output the
# raw '\xa0 ' token shows up as '_' after processing.
print(raw_sents[4])
new = sents_split[4]
print(new)



['Then', 'I', 'Saw', 'The', 'Holy', 'City', '\xa0 ', 'was', 'produced', 'by', 'Brian', 'Coates', ',', 'engineer', '/', 'producer', 'for', 'The', 'Dandy', 'Warhols', ',', 'and', 'released', 'in', 'the', 'fall', 'of', '2004', 'on', 'The', 'Kora', 'Records', '.']
['Then', 'I', 'Saw', 'The', 'Holy', 'City', '_', 'was', 'produced', 'by', 'Brian', 'Coates', ',', 'engineer', '/', 'producer', 'for', 'The', 'Dandy', 'Warhols', ',', 'and', 'released', 'in', 'the', 'fall', 'of', '2004', 'on', 'The', 'Kora', 'Records', '.']


In [789]:
SPACE_TOKEN = '[SPACE]'

# Show the token list before and after swapping each placeholder token for a
# literal space.
print(new)

# Slice assignment keeps the update in place, so other references to this
# list see the change as well.
new[:] = [' ' if token == SPACE_TOKEN else token for token in new]

print(new)

['In', '2015/16', 'she', 'co', '-', 'wrote', 'for', '[SPACE]', 'Bridget', 'Jones', "'", 'Diary', '(', 'musical', ')', '[SPACE]', 'with', 'Lily', 'Allen', 'and', 'Greg', 'Kurstin']
['In', '2015/16', 'she', 'co', '-', 'wrote', 'for', ' ', 'Bridget', 'Jones', "'", 'Diary', '(', 'musical', ')', ' ', 'with', 'Lily', 'Allen', 'and', 'Greg', 'Kurstin']
