In [2]:
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

# model, tokenizer

In [3]:
additional_tokens = {'additional_special_tokens': ['[learn1]', '[learn2]', '[learn3]', '[learn4]', '[learn5]', '[learn6]']}

In [4]:
from transformers import GPT2LMHeadModel,  GPT2Tokenizer, GPT2Config, GPT2LMHeadModel
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', bos_token='<|startoftext|>', eos_token='<|endoftext|>', pad_token='<|pad|>')

configuration = GPT2Config.from_pretrained('gpt2', output_hidden_states=False)

model = GPT2LMHeadModel.from_pretrained("gpt2", config=configuration)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
print_trainable_parameters(model)

trainable params: 124439808 || all params: 124439808 || trainable%: 100.0


In [6]:
num_added_toks = tokenizer.add_special_tokens(additional_tokens)

In [7]:
model.resize_token_embeddings(len(tokenizer))

Embedding(50265, 768)

# data

In [8]:
# load the json file from DocRED/data/test.json and DocRED/data/rel_info.json

import json

with open('DocRED/data/train_annotated.json') as f:
    train_set = json.load(f)


with open('DocRED/data/rel_info.json') as f:
    rel_info = json.load(f)

In [78]:
relation_dict = {
    'id': [],
    'text': [],
    'head': [],
    'tail': [],
    'head_first': [],
    'relation': [],
    'head_start_pos' : [],
    'tail_start_pos' : []
}

for i in range(len(train_set)):
    sents = ""
    for sent in train_set[i]['sents']:
        # flatten the sent list
        a = " ".join(sent)
        sents += a.lower() + " "
    # relation_dict['text'].append(sents)

    for relation_pair in train_set[i]['labels']:
        relation_dict['id'].append(i)
        relation_dict['text'].append(sents)
        head = []
        head_ = []
        head_start_pos = []
        head.append([[item['name'].lower()] for item in train_set[i]['vertexSet'][relation_pair['h']]])
        for j, item in enumerate(head[0]):
            if item not in head_:
                head_.append(item)
                head_start_pos.append(train_set[i]['vertexSet'][relation_pair['h']][j]['pos'][0])

        relation_dict['head'].append(head_)
        relation_dict['head_start_pos'].append(head_start_pos)

        tail = []
        tail_ = []
        tail_start_pos = []
        tail.append([[item['name'].lower()] for item in train_set[i]['vertexSet'][relation_pair['t']]])
        for j, item in enumerate(tail[0]):
            if item not in tail_:
                tail_.append(item)
                tail_start_pos.append(train_set[i]['vertexSet'][relation_pair['t']][j]['pos'][0])
        relation_dict['tail'].append(tail_)
        relation_dict['tail_start_pos'].append(tail_start_pos)

        
        if train_set[i]['vertexSet'][relation_pair['h']][0]['pos'][0] < train_set[i]['vertexSet'][relation_pair['t']][0]['pos'][0]:
            relation_dict['head_first'].append(1)
        else:
            relation_dict['head_first'].append(0)
        
        relation_dict['relation'].append(relation_pair['r'])
    # break


# save the relation_dict to a json file

with open('DocRED/data/DocRED_baseline_metadata/relation_dict.json', 'w') as f:
    json.dump(relation_dict, f)
        

In [36]:
from tqdm.notebook import trange, tqdm

import random

random.seed(42)

relation_dict = {
    # 'id': [],
    'text': [],
    'pair': [],
    # 'head_first': [],
    'relation': [],
    # 'head_start_pos' : [],
    # 'tail_start_pos' : []
}

for i in tqdm(range(len(train_set))):
    sents = ""
    for sent in train_set[i]['sents']:
        # flatten the sent list
        a = " ".join(sent)
        sents += a.lower() + " "
    # relation_dict['text'].append(sents)
    relation_lines = {k['r']: [] for k in train_set[i]['labels']}
    # print(relation_lines)
    for relation_pair in train_set[i]['labels']:
        # print(relation_pair)
        # relation_dict['text'].append(sents)
        heads = [item['name'].lower() for item in train_set[i]['vertexSet'][relation_pair['h']]]

        max_length, max_index = 0, 0
        for k, item in enumerate(heads):
            # head = the longest string in the example['head'], also record the index of the longest string
            if len(item) > max_length:
                max_length = len(item)
                max_index = k
            
        head = heads[max_index]

        # print(heads)
        # print(head)


        tails = [item['name'].lower() for item in train_set[i]['vertexSet'][relation_pair['t']]]
        max_length, max_index = 0, 0
        for k, item in enumerate(tails):
            # tail = the longest string in the example['tail'], also record the index of the longest string
            if len(item) > max_length:
                max_length = len(item)
                max_index = k

        tail = tails[max_index]

        # print(tails)
        # print(tail)

        relation_lines[relation_pair['r']].append((head, tail))
        
    # random choosing a relation in the rel_info that not be included in the relation_lines.keys()

    none_relation = random.choice(list(rel_info.keys()))
    while none_relation in relation_lines.keys():
        none_relation = random.choice(list(rel_info.keys()))
    relation_lines[none_relation] = [("none", "none")]

    for relation, pair in relation_lines.items():
        relation_dict['text'].append(sents)
        relation_dict['pair'].append(pair)
        relation_dict['relation'].append(rel_info[relation])
    # break


# save the relation_dict to a json file

with open('DocRED/data/DocRED_baseline_metadata/relation_dict_ner.json', 'w') as f:
    json.dump(relation_dict, f)
        

  0%|          | 0/3053 [00:00<?, ?it/s]

In [37]:
ner = 0

In [38]:
import json



from datasets import Dataset

if ner:
    with open('DocRED/data/DocRED_baseline_metadata/relation_dict.json') as f:
        relation_dict = json.load(f)

    dataset = Dataset.from_dict(
        {
            'text': relation_dict['text'],
            'head': relation_dict['head'],
            'tail': relation_dict['tail'],
            'head_first': relation_dict['head_first'],
            'relation': relation_dict['relation']
        }
    )

else:
    with open('DocRED/data/DocRED_baseline_metadata/relation_dict_ner.json') as f:
        relation_dict = json.load(f)

    dataset = Dataset.from_dict(
        {
            'text': relation_dict['text'],
            'pair': relation_dict['pair'],
            'relation': relation_dict['relation']
        }
    )

In [39]:
dataset

Dataset({
    features: ['text', 'pair', 'relation'],
    num_rows: 19431
})

In [52]:
def pro_processing_without_ner(example, tokenizer, padding=True):

    padding=True
    texts = example['text']

    # special_tokens = [50259, 50260, 50261, 50262, 50263, 50264]

    output_texts = []

    text_ids = tokenizer(texts, add_special_tokens=False)['input_ids']

    output_lines = []

    for i in range(len(example['pair'])):
        if example['pair'][i][0][0] != "none":
            output_line = f"for relation {example['relation'][i]} , " + f"[learn1] [learn2] [learn3] [learn4] [learn5] [learn6] "
            for pair in example['pair'][i]:
                output_line = output_line  + f"the source is {pair[0]} and the target is {pair[1]} ; "
            
            output_line = output_line[:-2] + ". " + tokenizer.eos_token
            output_lines.append(output_line)

        else:
            output_line = f"for relation {example['relation'][i]} , " + "[learn1] [learn2] [learn3] [learn4] [learn5] [learn6] the source is none . " + tokenizer.eos_token
            output_lines.append(output_line)

    # print(output_lines)


    output_ids = tokenizer(output_lines, add_special_tokens=False)['input_ids']

    # input_ids = []
    attention_mask = []

    count = 0
    for i, ids in enumerate(output_ids):
        if len(text_ids[i]) + len(ids) > 1024:
            text_ids[i] = text_ids[:1024 - len(ids)]
            count += 1
        text_ids[i] = text_ids[i] + output_ids[i]
        assert len(text_ids[i]) <= 1024
        attention_mask.append([1] * len(text_ids[i]) + [0] * (1024 - len(text_ids[i])))
    if count != 0:
        print(f"truncated {count} examples")

    if padding:
        for i, ids in enumerate(output_ids):
            output_ids[i] = ids + [tokenizer.pad_token_id] * (1024 - len(ids))
            text_ids[i] = text_ids[i] + [tokenizer.pad_token_id] * (1024 - len(text_ids[i]))

    return {
        'input_ids': text_ids,
        'attention_mask': attention_mask,
        }


In [9]:
def pro_processing_ner(example, tokenizer, padding=True):
    texts = example['text']

    # special_tokens = [50259, 50260, 50261, 50262, 50263, 50264]

    output_texts = []

    text_ids = tokenizer(texts, add_special_tokens=False)['input_ids']

    for i in range(len(example['head'])):
        head = ""
        for item in example['head'][i]:
            head += item[0] + " ; "
        head = head[:-2]
        head += ". "

        tail = ""
        for item in example['tail'][i]:
            tail += item[0] + " ; "
        tail = tail[:-2]
        tail += ". "
        
        if example['head_first'][i] == 1:
            output_line = " entity 1 : " + head + "entity 2 : " + tail + "[learn1] [learn2] [learn3] [learn4] [learn5] [learn6] "
            output_line = output_line + f"the relation between source entity 1 and target entity 2 is {rel_info[example['relation'][i]]} . " + tokenizer.eos_token

        else:
            output_line = " entity 1 : " + tail + "entity 2 : " + head + "[learn1] [learn2] [learn3] [learn4] [learn5] [learn6] "
            output_line = output_line + f"the relation between source entity 2 and target entity 1 is {rel_info[example['relation'][i]]} . " + tokenizer.eos_token
        
        output_texts.append(output_line)


    output_ids = tokenizer(output_texts, add_special_tokens=False)['input_ids']

    # input_ids = []
    attention_mask = []

    count = 0
    for i, ids in enumerate(output_ids):
        if len(text_ids[i]) + len(ids) > 1024:
            text_ids[i] = text_ids[:1024 - len(ids)]
            count += 1
        text_ids[i] = text_ids[i] + output_ids[i]
        assert len(text_ids[i]) <= 1024
        attention_mask.append([1] * len(text_ids[i]) + [0] * (1024 - len(text_ids[i])))
    if count != 0:
        print(f"truncated {count} examples")

    if padding:
        for i, ids in enumerate(output_ids):
            output_ids[i] = ids + [tokenizer.pad_token_id] * (1024 - len(ids))
            text_ids[i] = text_ids[i] + [tokenizer.pad_token_id] * (1024 - len(text_ids[i]))

    return {
        'input_ids': text_ids,
        'attention_mask': attention_mask,
        }

In [53]:
# tokenized_dataset = dataset.map(lambda example: pro_processing_ner(example, tokenizer), batched=True, remove_columns=['text', 'head', 'tail', 'head_first', 'relation'])

tokenized_dataset = dataset.map(lambda example: pro_processing_without_ner(example, tokenizer), batched=True, remove_columns=['text', 'pair', 'relation'])

Map:   0%|          | 0/19431 [00:00<?, ? examples/s]

In [54]:
# save the datasets type tokenized_dataset

tokenized_dataset.save_to_disk('DocRED/data/DocRED_baseline_metadata/tokenized_dataset_without_ner')


Saving the dataset (0/1 shards):   0%|          | 0/19431 [00:00<?, ? examples/s]

In [23]:
# load the tokenized_dataset

from datasets import load_from_disk

tokenized_dataset = load_from_disk('DocRED/data/DocRED_baseline_metadata/tokenized_dataset_without_ner')

In [24]:
tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])

# trainer

In [57]:
import wandb

wandb.init(
    # set the wandb project where this run will be logged
    project="GPT2-normal",
    # notes="PubmedBERT-FT-NER_w_NERin_10epochs",
    name="GPT2-DocRED-without-ner-5epochs"
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33m309439737[0m ([33mtian1995[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [67]:
import transformers
from transformers import DataCollatorForLanguageModeling

trainer = transformers.Trainer(
    model=model, 
    train_dataset=tokenized_dataset,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=2, 
        gradient_accumulation_steps=2,
        warmup_steps=1000, 
        num_train_epochs=5,
        learning_rate=2e-4, 
        # fp16=True,
        logging_steps=100, 
        report_to="wandb",
        save_strategy="epoch",
        output_dir='DocRED/GPT_without_ner'
    ),
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False)
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!


In [68]:
trainer.train()



  0%|          | 0/24290 [00:00<?, ?it/s]

{'loss': 28.43, 'learning_rate': 2e-05, 'epoch': 0.02}
{'loss': 3.6175, 'learning_rate': 4e-05, 'epoch': 0.04}
{'loss': 3.2075, 'learning_rate': 6e-05, 'epoch': 0.06}
{'loss': 3.0989, 'learning_rate': 8e-05, 'epoch': 0.08}
{'loss': 3.0115, 'learning_rate': 0.0001, 'epoch': 0.1}
{'loss': 2.8918, 'learning_rate': 0.00012, 'epoch': 0.12}
{'loss': 2.8831, 'learning_rate': 0.00014, 'epoch': 0.14}
{'loss': 2.8157, 'learning_rate': 0.00016, 'epoch': 0.16}
{'loss': 2.73, 'learning_rate': 0.00018, 'epoch': 0.19}
{'loss': 2.7119, 'learning_rate': 0.0002, 'epoch': 0.21}
{'loss': 2.6979, 'learning_rate': 0.00019914126234435382, 'epoch': 0.23}
{'loss': 2.6357, 'learning_rate': 0.0001982825246887076, 'epoch': 0.25}
{'loss': 2.6032, 'learning_rate': 0.00019742378703306143, 'epoch': 0.27}
{'loss': 2.5512, 'learning_rate': 0.0001965650493774152, 'epoch': 0.29}
{'loss': 2.4882, 'learning_rate': 0.000195706311721769, 'epoch': 0.31}
{'loss': 2.4568, 'learning_rate': 0.00019484757406612282, 'epoch': 0.33}


TrainOutput(global_step=24290, training_loss=0.8649780624047284, metrics={'train_runtime': 20677.7711, 'train_samples_per_second': 4.699, 'train_steps_per_second': 1.175, 'train_loss': 0.8649780624047284, 'epoch': 5.0})

In [69]:
wandb.finish()
trainer.save_model("DocRED/GPT_without_ner/model")

# save the tokenizer
# tokenizer.save_pretrained("DocRED/GPT_w_ner/tokenizer")

0,1
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/learning_rate,▂▇███▇▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▁▁▁
train/loss,█▆▆▅▅▅▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/total_flos,▁
train/train_loss,▁
train/train_runtime,▁
train/train_samples_per_second,▁
train/train_steps_per_second,▁

0,1
train/epoch,5.0
train/global_step,24290.0
train/learning_rate,0.0
train/loss,0.1358
train/total_flos,5.077165473792e+16
train/train_loss,0.86498
train/train_runtime,20677.7711
train/train_samples_per_second,4.699
train/train_steps_per_second,1.175


# Inference

In [1]:
from transformers import AutoModelForCausalLM

checkpoint = "DocRED/GPT_w_ner/model"

In [2]:
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("DocRED/GPT_w_ner/tokenizer")
model = AutoModelForCausalLM.from_pretrained(checkpoint)

In [3]:
import torch

model.eval()
model.to("cpu")
inputs = tokenizer("Tweet text : @HondaCustSvc Your customer service has been horrible during the recall process. I will never purchase a Honda again. Label :", return_tensors="pt")

with torch.no_grad():
    outputs = model.generate(input_ids=inputs["input_ids"], max_new_tokens=10, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)
    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0])



Tweet text : @HondaCustSvc Your customer service has been horrible during the recall process. I will never purchase a Honda again. Label : uk supreme motorsports   uk supreme motors


test data

In [53]:
# load the json file from DocRED/data/test.json and DocRED/data/rel_info.json

import json

with open('DocRED/data/dev.json') as f:
    test_set = json.load(f)


with open('DocRED/data/rel_info.json') as f:
    rel_info = json.load(f)

In [54]:
# with ner
from tqdm.notebook import trange, tqdm

relation_dict = {
    'id': [],
    'text': [],
    'head': [],
    'tail': [],
    'head_first': [],
    'relation': [],
    'head_start_pos' : [],
    'tail_start_pos' : []
}

for i in tqdm(range(len(test_set))):
    sents = ""
    for sent in test_set[i]['sents']:
        # flatten the sent list
        a = " ".join(sent)
        sents += a.lower() + " "
    # relation_dict['text'].append(sents)

    for relation_pair in test_set[i]['labels']:
        relation_dict['id'].append(i)
        relation_dict['text'].append(sents)
        head = []
        head_ = []
        head_start_pos = []
        head.append([[item['name'].lower()] for item in test_set[i]['vertexSet'][relation_pair['h']]])
        for j, item in enumerate(head[0]):
            if item not in head_:
                head_.append(item)
                head_start_pos.append(test_set[i]['vertexSet'][relation_pair['h']][j]['pos'][0])

        relation_dict['head'].append(head_)
        relation_dict['head_start_pos'].append(head_start_pos)

        tail = []
        tail_ = []
        tail_start_pos = []
        tail.append([[item['name'].lower()] for item in test_set[i]['vertexSet'][relation_pair['t']]])
        for j, item in enumerate(tail[0]):
            if item not in tail_:
                tail_.append(item)
                tail_start_pos.append(test_set[i]['vertexSet'][relation_pair['t']][j]['pos'][0])
        relation_dict['tail'].append(tail_)
        relation_dict['tail_start_pos'].append(tail_start_pos)

        
        if test_set[i]['vertexSet'][relation_pair['h']][0]['pos'][0] < test_set[i]['vertexSet'][relation_pair['t']][0]['pos'][0]:
            relation_dict['head_first'].append(1)
        else:
            relation_dict['head_first'].append(0)
        
        relation_dict['relation'].append(relation_pair['r'])
    # break


# save the relation_dict to a json file

with open('DocRED/data/DocRED_baseline_metadata/dev_relation_dict.json', 'w') as f:
    json.dump(relation_dict, f)

  0%|          | 0/998 [00:00<?, ?it/s]

In [82]:
# without ner
from tqdm.notebook import trange, tqdm

import random

random.seed(42)

relation_dict = {
    # 'id': [],
    'text': [],
    'pair': [],
    # 'head_first': [],
    'relation': [],
    # 'head_start_pos' : [],
    # 'tail_start_pos' : []
}

for i in tqdm(range(len(test_set))):
    sents = ""
    for sent in test_set[i]['sents']:
        # flatten the sent list
        a = " ".join(sent)
        sents += a.lower() + " "
    # relation_dict['text'].append(sents)
    relation_lines = {k['r']: [] for k in test_set[i]['labels']}
    # print(relation_lines)
    for relation_pair in test_set[i]['labels']:
        # print(relation_pair)
        # relation_dict['text'].append(sents)
        heads = [item['name'].lower() for item in test_set[i]['vertexSet'][relation_pair['h']]]

        max_length, max_index = 0, 0
        for k, item in enumerate(heads):
            # head = the longest string in the example['head'], also record the index of the longest string
            if len(item) > max_length:
                max_length = len(item)
                max_index = k
            
        head = heads[max_index]

        # print(heads)
        # print(head)


        tails = [item['name'].lower() for item in test_set[i]['vertexSet'][relation_pair['t']]]
        max_length, max_index = 0, 0
        for k, item in enumerate(tails):
            # tail = the longest string in the example['tail'], also record the index of the longest string
            if len(item) > max_length:
                max_length = len(item)
                max_index = k

        tail = tails[max_index]

        # print(tails)
        # print(tail)

        relation_lines[relation_pair['r']].append((head, tail))
        
    # random choosing a relation in the rel_info that not be included in the relation_lines.keys()

    none_relation = random.choice(list(rel_info.keys()))
    while none_relation in relation_lines.keys():
        none_relation = random.choice(list(rel_info.keys()))
    relation_lines[none_relation] = [("none", "none")]

    for relation, pair in relation_lines.items():
        relation_dict['text'].append(sents)
        relation_dict['pair'].append(pair)
        relation_dict['relation'].append(rel_info[relation])
    # break


# save the relation_dict to a json file

with open('DocRED/data/DocRED_baseline_metadata/dev_relation_dict_without_ner.json', 'w') as f:
    json.dump(relation_dict, f)
        

  0%|          | 0/998 [00:00<?, ?it/s]

In [11]:
ner = 1

In [12]:
import json



from datasets import Dataset

if ner:
    with open('DocRED/data/DocRED_baseline_metadata/dev_relation_dict.json') as f:
        relation_dict = json.load(f)

    dataset = Dataset.from_dict(
        {
            'text': relation_dict['text'],
            'head': relation_dict['head'],
            'tail': relation_dict['tail'],
            'head_first': relation_dict['head_first'],
            'relation': relation_dict['relation']
        }
    )

else:
    with open('DocRED/data/DocRED_baseline_metadata/dev_relation_dict_without_ner.json') as f:
        relation_dict = json.load(f)

    dataset = Dataset.from_dict(
        {
            'text': relation_dict['text'],
            'pair': relation_dict['pair'],
            'relation': relation_dict['relation']
        }
    )

In [57]:
dataset[7]

{'text': 'washington place ( william washington house ) is one of the first homes built by freed slaves after the emancipation proclamation of 1863 in hampshire county , west virginia , united states . washington place was built by william and annie washington in north romney between 1863 and 1874 on land given to annie by her former owner , susan blue parsons of wappocomo plantation . william washington later acquired other properties on the hills north of romney along west virginia route 28 and became the first african - american land developer in the state of west virginia . one of his subdivisions is the " blacks hill " neighborhood of romney , adjacent to the washington place homestead . washington place was bought and restored by ralph w. haines , a local attorney and historic preservationist . ',
 'head': [['emancipation proclamation']],
 'tail': [['united states']],
 'head_first': 1,
 'relation': 'P17'}

In [10]:
def pro_processing_without_ner_infer(example, tokenizer):

    padding=True
    texts = example['text']

    # special_tokens = [50259, 50260, 50261, 50262, 50263, 50264]

    output_texts = []

    text_ids = tokenizer(texts, add_special_tokens=False)['input_ids']

    output_lines = []

    for i in range(len(example['pair'])):
        # if example['pair'][i][0][0] != "none":
        output_line = f"for relation {example['relation'][i]} , [learn1] [learn2] [learn3][learn4] [learn5] [learn6] "
        # for pair in example['pair'][i]:
        #     output_line = output_line  + f"the source is {pair[0]} and the target is {pair[1]} ; "
        
        # output_line = output_line[:-2] + ". " + tokenizer.eos_token
        output_lines.append(output_line)
        if i == 5:
            print(output_line)

        # else:
        #     output_line = f"for relation {example['relation'][i]} , " + "[learn1] [learn2] [learn3] [learn4] [learn5] [learn6] the source is none . " + tokenizer.eos_token
        #     output_lines.append(output_line)

    # print(output_lines)


    output_ids = tokenizer(output_lines, add_special_tokens=False)['input_ids']

    # input_ids = []
    attention_mask = []

    count = 0
    for i, ids in enumerate(output_ids):
        if len(text_ids[i]) + len(ids) > 1024:
            text_ids[i] = text_ids[:1024 - len(ids)]
            count += 1
        text_ids[i] = text_ids[i] + output_ids[i]
        assert len(text_ids[i]) <= 1024
        # attention_mask.append([1] * len(text_ids[i]) + [0] * (1024 - len(text_ids[i])))
    if count != 0:
        print(f"truncated {count} examples")

    # if padding:
    #     for i, ids in enumerate(output_ids):
    #         output_ids[i] = ids + [tokenizer.pad_token_id] * (1024 - len(ids))
    #         text_ids[i] = text_ids[i] + [tokenizer.pad_token_id] * (1024 - len(text_ids[i]))

    return {
        'input_ids': text_ids,
        # 'attention_mask': attention_mask,
        }


In [67]:
def pro_processing_ner_infer(example, tokenizer):
    texts = example['text']

    # special_tokens = [50259, 50260, 50261, 50262, 50263, 50264]

    output_texts = []

    text_ids = tokenizer(texts, add_special_tokens=False)['input_ids']

    for i in range(len(example['head'])):
        head = ""
        for item in example['head'][i]:
            head += item[0] + " ; "
        head = head[:-2]
        head += ". "

        tail = ""
        for item in example['tail'][i]:
            tail += item[0] + " ; "
        tail = tail[:-2]
        tail += ". "
        
        if example['head_first'][i] == 1:
            output_line = "entity 1 : " + head + "entity 2 : " + tail + "[learn1] [learn2] [learn3] [learn4] [learn5] [learn6] "
            # output_line = output_line + f"the relation between source entity 1 and target entity 2 is {rel_info[example['relation'][i]]} . " + tokenizer.eos_token

        else:
            output_line = " entity 1 : " + tail + "entity 2 : " + head + "[learn1] [learn2] [learn3] [learn4] [learn5] [learn6] "
            # output_line = output_line + f"the relation between source entity 2 and target entity 1 is {rel_info[example['relation'][i]]} . " + tokenizer.eos_token
        
        output_texts.append(output_line)


    output_ids = tokenizer(output_texts, add_special_tokens=False)['input_ids']

    # input_ids = []
    # attention_mask = []

    count = 0
    for i, ids in enumerate(output_ids):
        if len(text_ids[i]) + len(ids) > 1024:
            text_ids[i] = text_ids[:1024 - len(ids)]
            count += 1
        text_ids[i] = text_ids[i] + output_ids[i]
        assert len(text_ids[i]) <= 1024
        # attention_mask.append([1] * len(text_ids[i]) + [0] * (1024 - len(text_ids[i])))
    if count != 0:
        print(f"truncated {count} examples")

    # if padding:
    #     for i, ids in enumerate(output_ids):
    #         output_ids[i] = ids + [tokenizer.pad_token_id] * (1024 - len(ids))
    #         text_ids[i] = text_ids[i] + [tokenizer.pad_token_id] * (1024 - len(text_ids[i]))

    return {
        'input_ids': text_ids,
        # 'attention_mask': attention_mask,
        }

In [60]:
dataset

Dataset({
    features: ['text', 'head', 'tail', 'head_first', 'relation'],
    num_rows: 12275
})

In [68]:
# tokenized_dataset = dataset.map(lambda example: pro_processing_without_ner_infer(example, tokenizer), batched=True, remove_columns=['text', 'pair', 'relation'])
tokenized_dataset = dataset.map(lambda example: pro_processing_ner_infer(example, tokenizer), batched=True, remove_columns=['text', 'head', 'tail', 'head_first', 'relation'])

Map:   0%|          | 0/12275 [00:00<?, ? examples/s]

In [69]:
tokenized_dataset.save_to_disk('DocRED/data/DocRED_baseline_metadata/dev_tokenized_dataset_w_ner')

Saving the dataset (0/1 shards):   0%|          | 0/12275 [00:00<?, ? examples/s]

In [63]:
len(tokenized_dataset)

12275

In [4]:
from datasets import load_from_disk

tokenized_dataset = load_from_disk('DocRED/data/DocRED_baseline_metadata/dev_tokenized_dataset_w_ner')

In [5]:
tokenized_dataset.set_format(type='torch', columns=['input_ids'])

In [6]:
tokenizer.decode(tokenized_dataset[0]['input_ids'])

"skai tv is a greek free - to - air television network based in piraeus. it is part of the skai group, one of the largest media groups in the country. it was relaunched in its present form on 1st of april 2006 in the athens metropolitan area, and gradually spread its coverage nationwide. besides digital terrestrial transmission, it is available on the subscription - based encrypted services of nova and cosmote tv. skai tv is also a member of digea, a consortium of private television networks introducing digital terrestrial transmission in greece. at launch, skai tv opted for dubbing all foreign language content into greek, instead of using subtitles. this is very uncommon in greece for anything except documentaries ( using voiceover dubbing ) and children's programmes ( using lip - synced dubbing ), so after intense criticism the station switched to using subtitles for almost all foreign shows. entity 1 : piraeus. entity 2 : greece. [learn1] [learn2] [learn3] [learn4] [learn5] [learn6]

inference

In [9]:
from tqdm.notebook import trange, tqdm


model.eval()
outputs = []
model.to("cuda")
with torch.no_grad():
    for i in tqdm(range(len(tokenized_dataset))):
    # for i in range(1):
        output = model.generate(input_ids=tokenized_dataset["input_ids"][i].unsqueeze(0).to("cuda"), max_new_tokens=50, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)
        output_text = tokenizer.batch_decode(output.detach().cpu().numpy(), skip_special_tokens=False)[0]
        try:
            outputs.append(output_text.split("[learn1] [learn2] [learn3] [learn4] [learn5] [learn6]")[1].strip())
        except:
            outputs.append(output_text.split("[learn1][learn2][learn3][learn4][learn5][learn6]")[1].strip())
        if i % 100 == 0:
            print(outputs[-1])

    # print(tokenizer.batch_decode(output.detach().cpu().numpy(), skip_special_tokens=False)[0])

  0%|          | 0/12275 [00:00<?, ?it/s]

the relation between source entity 1 and target entity 2 is country. <|endoftext|>
the relation between source entity 1 and target entity 2 is country. <|endoftext|>
the relation between source entity 2 and target entity 1 is country. <|endoftext|>
the relation between source entity 2 and target entity 1 is country of citizenship. <|endoftext|>
the relation between source entity 2 and target entity 1 is located in the administrative territorial entity. <|endoftext|>
the relation between source entity 1 and target entity 2 is inception. <|endoftext|>
the relation between source entity 2 and target entity 1 is country. <|endoftext|>
the relation between source entity 1 and target entity 2 is mother. <|endoftext|>
the relation between source entity 1 and target entity 2 is country. <|endoftext|>
the relation between source entity 1 and target entity 2 is inception. <|endoftext|>
the relation between source entity 2 and target entity 1 is country. <|endoftext|>
the relation between source 

In [14]:
dataset[0]

{'text': "skai tv is a greek free - to - air television network based in piraeus . it is part of the skai group , one of the largest media groups in the country . it was relaunched in its present form on 1st of april 2006 in the athens metropolitan area , and gradually spread its coverage nationwide . besides digital terrestrial transmission , it is available on the subscription - based encrypted services of nova and cosmote tv . skai tv is also a member of digea , a consortium of private television networks introducing digital terrestrial transmission in greece . at launch , skai tv opted for dubbing all foreign language content into greek , instead of using subtitles . this is very uncommon in greece for anything except documentaries ( using voiceover dubbing ) and children 's programmes ( using lip - synced dubbing ) , so after intense criticism the station switched to using subtitles for almost all foreign shows . ",
 'head': [['piraeus']],
 'tail': [['greece']],
 'head_first': 1,


In [18]:
# load the rel_info

with open('DocRED/data/rel_info.json') as f:
    rel_info = json.load(f)

rel2id = {v: k for k, v in rel_info.items()}

In [21]:
# post processing for the outputs w ner
# (source, target, relation)
# (2, 1, relation)
pairs = []
count = 0
for output in outputs:
    # if the output doesn't end with "<|endoftext|>", find the lastest ";" of the output and only take the previous part
    try:
        source = output.split("between source ")[1].strip()
        source = source.split(" and target ")[0].strip()
        source = source.split("entity")[1].strip()

        target = output.split(" and target ")[1].strip()
        target = target.split(" is ")[0].strip()
        target = target.split("entity")[1].strip()

        relation = output.split(" is ")[-1].strip()
        relation = relation.split(". <|endoftext|>")[0].strip()

        try:
            relation = rel2id[relation]
        except:
            count += 1
            pass

        pairs.append((source, target, relation))
    except:
        pairs.append(("none", "none", "none"))

print(f"{count} / {len(outputs)}")

3 / 12275


In [22]:
dataset[0]

{'text': "skai tv is a greek free - to - air television network based in piraeus . it is part of the skai group , one of the largest media groups in the country . it was relaunched in its present form on 1st of april 2006 in the athens metropolitan area , and gradually spread its coverage nationwide . besides digital terrestrial transmission , it is available on the subscription - based encrypted services of nova and cosmote tv . skai tv is also a member of digea , a consortium of private television networks introducing digital terrestrial transmission in greece . at launch , skai tv opted for dubbing all foreign language content into greek , instead of using subtitles . this is very uncommon in greece for anything except documentaries ( using voiceover dubbing ) and children 's programmes ( using lip - synced dubbing ) , so after intense criticism the station switched to using subtitles for almost all foreign shows . ",
 'head': [['piraeus']],
 'tail': [['greece']],
 'head_first': 1,


In [24]:
result = {
    "output": [],
    "label": []
}

for i, output in enumerate(pairs):
    result['output'].append(output)
    if dataset[i]['head_first'] == 1:
        result['label'].append(('1', '2', dataset[i]['relation']))
    else:
        result['label'].append(('2', '1', dataset[i]['relation']))

In [25]:
# save the result dictionary
import pickle
with open("DocRED/GPT_w_ner/result/epoch_5_result.pkl", "wb") as f:
    pickle.dump(result, f)

In [45]:
# post processing for the outputs without ner
pairs = []
for output in outputs:
    pair = []
    # if the output doesn't end with "<|endoftext|>", find the lastest ";" of the output and only take the previous part
    if output.endswith("<|endoftext|>"):
        string = output

        for line in string.split(";"):
            try:
                source = line.split("the source is")[1].strip()
                source = source.split("and the target is")[0].strip()
                if source.startswith("none"):
                    source = "none"
                    target = "none"
                    continue
                target = line.split("the target is")[1].strip()
                if target.endswith(". <|endoftext|>"):
                    target = target.split(". <|endoftext|>")[0].strip()
                    
                if (source, target) not in pair:
                    pair.append((source, target))
            except:
                continue

    else:
        string = output.split(";")[:-1]
        string = [line.strip() for line in string]
        for line in string:
            try:
                source = line.split("the source is")[1].strip()
                source = source.split("and the target is")[0].strip()

                target = line.split("the target is")[1].strip()
                if (source, target) not in pair:
                    pair.append((source, target))
            except:
                continue
    
    pairs.append(pair)
            



In [38]:
output

'the source is hampshire county and the target is united states ; the source is west virginia and the target is united states ; the source is virginia route 28 and the target is united states ; the source is william washington house and the target is united states ; the source is north romney and the target is united states ; the source is william washington house and the target is united states ; the source is west virginia route 28 and the target is united states ; the'

In [49]:
result = {
    "output": [],
    "label": []
}

for output, label in zip(pairs, dataset['pair']):
    result['output'].append(output)
    result['label'].append(label)

In [52]:
# save the result dictionary
import pickle
with open("DocRED/GPT_without_ner/result/epoch_5_result.pkl", "wb") as f:
    pickle.dump(result, f)

# Analysis

w ner

In [33]:
import pickle
with open("DocRED/GPT_w_ner/result/epoch_5_result.pkl", "rb") as f:
    result = pickle.load(f)

In [34]:
print(f'the length: {len(result["output"])}, {len(result["label"])}')
print(f'instance:\n{result["output"][0]}\n{result["label"][0]}')

the length: 12275, 12275
instance:
('1', '2', 'P17')
('1', '2', 'P17')


In [35]:
# source and target, relation
st_tp = 0
st_fp = 0
st_fn = 0
st_tn = 0

r_tp = 0
r_fp = 0
r_fn = 0
r_tn = 0

tuple_tp = 0
tuple_fp = 0  
tuple_fn = 0
tuple_tn = 0


for output, label in zip(result['output'], result['label']):
    pair = False
    relation = False
    if output[0] == label[0] and output[1] == label[1]:
        st_tp += 1
        pair = True
    else:
        st_fn += 1
        st_fp += 1
    
    if output[2] == label[2]:
        r_tp += 1
        relation = True
    else:
        r_fn += 1
        r_fp += 1

    if pair and relation:
        tuple_tp += 1
    else:
        tuple_fn += 1
        tuple_fp += 1

In [37]:
# calculate the precision, recall and f1 score

# for source and target
st_precision = st_tp / (st_tp + st_fp)
st_recall = st_tp / (st_tp + st_fn)
st_f1 = 2 * st_precision * st_recall / (st_precision + st_recall)
print(f"source and target precision: {st_precision}, recall: {st_recall}, f1: {st_f1}")

# for relation
r_precision = r_tp / (r_tp + r_fp)
r_recall = r_tp / (r_tp + r_fn)
r_f1 = 2 * r_precision * r_recall / (r_precision + r_recall)
print(f"relation precision: {r_precision}, recall: {r_recall}, f1: {r_f1}")

# for tuple
tuple_precision = tuple_tp / (tuple_tp + tuple_fp)
tuple_recall = tuple_tp / (tuple_tp + tuple_fn)
tuple_f1 = 2 * tuple_precision * tuple_recall / (tuple_precision + tuple_recall)
print(f"tuple precision: {tuple_precision}, recall: {tuple_recall}, f1: {tuple_f1}")

source and target precision: 0.8511608961303462, recall: 0.8511608961303462, f1: 0.8511608961303462
relation precision: 0.7165784114052953, recall: 0.7165784114052953, f1: 0.7165784114052953
tuple precision: 0.6974338085539715, recall: 0.6974338085539715, f1: 0.6974338085539715


without ner

In [38]:
import pickle
with open("DocRED/GPT_without_ner/result/epoch_5_result.pkl", "rb") as f:
    result = pickle.load(f)

In [39]:
print(f'the length: {len(result["output"])}, {len(result["label"])}')
print(f'instance:\n{result["output"][0]}\n{result["label"][0]}')

the length: 6254, 6254
instance:
[('piraeus', 'greece'), ('athens metropolitan area', 'greece'), ('athens', 'greece')]
[['piraeus', 'greece'], ['skai group', 'greece'], ['athens', 'greece'], ['skai tv', 'greece']]


In [47]:
tuple_tp = 0
tuple_fp = 0  
tuple_fn = 0
tuple_tn = 0

for output, label in zip(result['output'], result['label']):
    for pair in output:
        true_tuple = False
        for label_pair in label:
            if pair[0] == label_pair[0] and pair[1] == label_pair[1]:
                tuple_tp += 1
                true_tuple = True
                break
        if not true_tuple:
            tuple_fp += 1
    
    for label_pair in label:
        true_tuple = False
        for pair in output:
            if pair[0] == label_pair[0] and pair[1] == label_pair[1]:
                true_tuple = True
                break
        if not true_tuple:
            tuple_fn += 1

In [48]:
# calculate the precision, recall and f1 score

# for tuple
tuple_precision = tuple_tp / (tuple_tp + tuple_fp)
tuple_recall = tuple_tp / (tuple_tp + tuple_fn)
tuple_f1 = 2 * tuple_precision * tuple_recall / (tuple_precision + tuple_recall)
print(f"tuple precision: {tuple_precision}, recall: {tuple_recall}, f1: {tuple_f1}")

tuple precision: 0.27586206896551724, recall: 0.172809667673716, f1: 0.2125011609547692


In [51]:
# loosen the condition for the tp


tuple_tp = 0
tuple_fp = 0  
tuple_fn = 0
tuple_tn = 0

for output, label in zip(result['output'], result['label']):
    for pair in output:
        true_tuple = False
        for label_pair in label:
            if (pair[0] in label_pair[0]) and (pair[1] in label_pair[1]):
                tuple_tp += 1
                true_tuple = True
                break
        if not true_tuple:
            tuple_fp += 1
    
    for label_pair in label:
        true_tuple = False
        for pair in output:
            if (pair[0] in label_pair[0]) and (pair[1] in label_pair[1]):
                true_tuple = True
                break
        if not true_tuple:
            tuple_fn += 1

# calculate the precision, recall and f1 score

# for tuple
tuple_precision = tuple_tp / (tuple_tp + tuple_fp)
tuple_recall = tuple_tp / (tuple_tp + tuple_fn)
tuple_f1 = 2 * tuple_precision * tuple_recall / (tuple_precision + tuple_recall)
print(f"tuple precision: {tuple_precision}, recall: {tuple_recall}, f1: {tuple_f1}")

tuple precision: 0.329153605015674, recall: 0.20828564888990617, f1: 0.2551282650343442
