In [2]:
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

# model, tokenizer

In [3]:
additional_tokens = {'additional_special_tokens': ['[learn1]', '[learn2]', '[learn3]', '[learn4]', '[learn5]', '[learn6]']}

In [1]:
from transformers import GPT2LMHeadModel,  GPT2Tokenizer, GPT2Config, GPT2LMHeadModel
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', bos_token='<|startoftext|>', eos_token='<|endoftext|>', pad_token='<|pad|>')

configuration = GPT2Config.from_pretrained('gpt2', output_hidden_states=False)

model = GPT2LMHeadModel.from_pretrained("gpt2", config=configuration)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
print_trainable_parameters(model)

trainable params: 124439808 || all params: 124439808 || trainable%: 100.0


In [5]:
num_added_toks = tokenizer.add_special_tokens(additional_tokens)

In [6]:
model.resize_token_embeddings(len(tokenizer))

Embedding(50265, 768)

In [7]:
# save the tokenizer

tokenizer.save_pretrained('DocRED/GPT_w_ner_short_relation/gpt2_tokenizer')

('DocRED/GPT_w_ner_short_relation/gpt2_tokenizer/tokenizer_config.json',
 'DocRED/GPT_w_ner_short_relation/gpt2_tokenizer/special_tokens_map.json',
 'DocRED/GPT_w_ner_short_relation/gpt2_tokenizer/vocab.json',
 'DocRED/GPT_w_ner_short_relation/gpt2_tokenizer/merges.txt',
 'DocRED/GPT_w_ner_short_relation/gpt2_tokenizer/added_tokens.json')

# data

In [8]:
# load the json file from DocRED/data/test.json and DocRED/data/rel_info.json

import json

with open('DocRED/data/train_annotated.json') as f:
    train_set = json.load(f)


with open('DocRED/data/rel_info.json') as f:
    rel_info = json.load(f)

In [9]:
train_set[0]

{'vertexSet': [[{'pos': [0, 4],
    'type': 'ORG',
    'sent_id': 0,
    'name': 'Zest Airways, Inc.'},
   {'sent_id': 0,
    'type': 'ORG',
    'pos': [10, 15],
    'name': 'Asian Spirit and Zest Air'},
   {'name': 'AirAsia Zest', 'pos': [6, 8], 'sent_id': 0, 'type': 'ORG'},
   {'name': 'AirAsia Zest', 'pos': [19, 21], 'sent_id': 6, 'type': 'ORG'}],
  [{'name': 'Ninoy Aquino International Airport',
    'pos': [4, 8],
    'sent_id': 3,
    'type': 'LOC'},
   {'name': 'Ninoy Aquino International Airport',
    'pos': [26, 30],
    'sent_id': 0,
    'type': 'LOC'}],
  [{'name': 'Pasay City', 'pos': [31, 33], 'sent_id': 0, 'type': 'LOC'}],
  [{'name': 'Metro Manila', 'pos': [34, 36], 'sent_id': 0, 'type': 'LOC'}],
  [{'name': 'Philippines', 'pos': [38, 39], 'sent_id': 0, 'type': 'LOC'},
   {'name': 'Philippines', 'pos': [13, 14], 'sent_id': 4, 'type': 'LOC'},
   {'sent_id': 5,
    'type': 'LOC',
    'pos': [25, 29],
    'name': 'Republic of the Philippines'}],
  [{'name': 'Manila', 'pos': 

In [10]:
"""
the names of vertextSet can be the same, but the pos should be different
structure:
'vertexSet': 
    [
        (for the same entity but in different synonyms and different sentences)
        [
            {
                'pos':[start, end],
                'type': 'NER',
                'sent_id': 0,
                'name': 'string',
            },
            {}
        ],
        [entity-2]
    ]
'labels':
    [
        {
            'r': 'Pxx',
            'h': 0,
            't': 1,
            'evidence': [2, 3, 4],
        },
        {}
    ]
'title': 'string',
'sents':
    [
        ['word0', 'word1',]
        ['word0', 'word1',]
    ]
"""

"\nthe names of vertextSet can be the same, but the pos should be different\nstructure:\n'vertexSet': \n    [\n        {\n            'pos':[start, end],\n            'type': 'NER',\n            'sent_id': 0,\n            'name': 'string',\n        },\n        {}\n    ]\n'labels':\n    [\n        {\n            'r': 'Pxx',\n            'h': 0,\n            't': 1,\n            'evidence': [2, 3, 4],\n        },\n        {}\n    ]\n'title': 'string',\n'sents':\n    [\n        ['word0', 'word1',]\n        ['word0', 'word1',]\n    ]\n"

In [89]:
with open('DocRED/data/ner_info.json') as f:
    ner_info = json.load(f)

with open('DocRED/data/rel_info.json') as f:
    relation_info = json.load(f)

In [12]:
"""# doc-level is too long for gpt-2, so we need to split the doc-level into bi-sent-level

relation_dict = {
    'id': [],
    'text':[],
    'entity': [],
    'relation': []
}

for i in range(len(train_set)):
    # id
    relation_dict['id'].append(i)

    # text
    sents = ""
    for sent in train_set[i]['sents']:
        # flatten the sent list
        a = " ".join(sent)
        sents += a.lower() + " "
    # if there are space, delete the first and last space of the sents
    sents = sents.strip()
    # delete double space in the sents
    sents = sents.replace("  ", " ")
    relation_dict['text'].append(sents)
    del sents

    # entity
    entity = []
    entity_list = []
    entity_flat = {}
    entity_count = 0
    for sent_item in train_set[i]['vertexSet']:
        for item in sent_item:
            entity_item = []
            if item['name'].lower() not in entity_list:
                entity_list.append(item['name'].lower().strip())
                entity_item.append(item['name'].lower().strip())
                entity_item.append(ner_info[item['type']])

                entity.append(entity_item)
            
            # add the entity_flat
            entity_flat[entity_count] = item['name'].lower().strip()
            entity_count += 1

    # release the entity_list and entity_item
    del entity_item
    del entity_count
        

    # relation pairs
    relation_pairs = {}
    for relation_item in train_set[i]['labels']:
        pair = []
        head = entity_flat[relation_item['h']]
        tail = entity_flat[relation_item['t']]
        pair.append(head)
        pair.append(tail)

        relation  = relation_info[relation_item['r']]
        if relation not in relation_pairs.keys():
            relation_pairs[relation] = []

        relation_pairs[relation].append(pair)
    del pair
    del head
    del tail

    # add the entity and relation pairs to the relation_dict
    relation_dict['entity'].append(entity)
    relation_dict['relation'].append(relation_pairs)
    break


# save the relation_dict to a json file

# with open('DocRED/data/DocRED_baseline_metadata/relation_dict.json', 'w') as f:
#     json.dump(relation_dict, f)"""

In [91]:
relation_dict = {
    'text':[],
    'entity': [],
    'relation': []
}
id_count = 0

for i in range(len(train_set)):
    

    # text
    sent_pairs = []
    for sent_index, sent in enumerate(train_set[i]['sents']):
        sents = ""

        # flatten the sent list
        a = " ".join(sent)
        sents += a.lower() + " "

        # and the next_sent if it exists
        try:
            next_sent = train_set[i]['sents'][sent_index + 1]
            b = " ".join(next_sent)
            sents += b.lower() + " "
        except:
            pass
        # post process the sents for some spaces
        sents = sents.strip()
        sents = sents.replace("  ", " ")

        relation_dict['text'].append(sents)
            
    del sents


    # entity
    entity = []
    for index in range(len(train_set[i]['sents'])):
        # focus on the current sent and the next sent if it exists
        if index + 1 < len(train_set[i]['sents']):
            next_index = index + 1
        else:
            next_index = index
        # group the entities for every 2 sents, no repeated entities in one group
        c_sent_entity_lists = []
        next_sent_entity_lists = []
        entity_for_each_2_sents = []

        for entity_spans in train_set[i]['vertexSet']:
            for item in entity_spans:
                entity_item = []
                # if neither in the current sent nor in the next sent, continue
                if item['sent_id'] != index and item['sent_id'] != next_index:
                    continue
                # also store the first pos of the entity in the entity_item
                # it will look like this: [[entity_name, sent_index, pos1, ner_type]]
                entity_item = [item['name'].lower().strip(), item['sent_id'], item['pos'][0], ner_info[item['type']]]

                if entity_item[1] == index:
                    c_sent_entity_lists.append(entity_item)
                else:
                    next_sent_entity_lists.append(entity_item)

        # sort the c_sent_entity_lists and next_sent_entity_lists by the pos in ascending order
        c_sent_entity_lists.sort(key=lambda x: x[2])
        if index != next_index:
            next_sent_entity_lists.sort(key=lambda x: x[2])
        
        entity_list = []
        for item in c_sent_entity_lists:
            if item[0] not in entity_list:
                entity_list.append(item[0])
                entity_for_each_2_sents.append([item[0], item[3]])
            
        if index != next_index:
            for item in next_sent_entity_lists:
                if item[0] not in entity_list:
                    entity_list.append(item[0])
                    entity_for_each_2_sents.append([item[0], item[3]])

        relation_dict['entity'].append(entity_for_each_2_sents)

    del entity_item
    del c_sent_entity_lists
    del next_sent_entity_lists
    del entity_for_each_2_sents
        

    # relation pairs
    relation_pairs = []

    for index in range(len(train_set[i]['sents'])):
        relation_pairs_for_each_2_sents = {}
        # focus on the current sent and the next sent if it exists
        if index + 1 < len(train_set[i]['vertexSet']):
            next_index = index + 1
        else:
            next_index = index

        # heads, tails: ['entity_name', start_pos]
        for relation_item in train_set[i]['labels']:
            heads = []
            tails = []
            
            # head
            head_exist = False
            for head_span in train_set[i]['vertexSet'][relation_item['h']]:
                if head_span['sent_id'] == index or head_span['sent_id'] == next_index:
                    heads.append([head_span['name'].lower().strip(), head_span['pos'][0]])
                    head_exist = True
            if not head_exist:
                continue
    
            # tail
            tail_exist = False
            for tail_span in train_set[i]['vertexSet'][relation_item['t']]:
                if tail_span['sent_id'] == index or tail_span['sent_id'] == next_index:
                    tails.append([tail_span['name'].lower().strip(), tail_span['pos'][0]])
                    tail_exist = True
            if not tail_exist:
                continue
            

            if relation_info[relation_item['r']] not in relation_pairs_for_each_2_sents.keys():
                relation_pairs_for_each_2_sents[relation_info[relation_item['r']]] = []
            for head in heads:
                for tail in tails:
                    relation_pairs_for_each_2_sents[relation_info[relation_item['r']]].append([head[0], tail[0]])

        relation_dict['relation'].append(relation_pairs_for_each_2_sents)


    
    # break


# save the relation_dict to a json file

# with open('DocRED/data/bi-sent-pre-process.json', 'w') as f:
    # json.dump(relation_dict, f)

In [65]:
len(relation_dict['text']) == len(relation_dict['entity']) == len(relation_dict['relation'])

True

In [98]:
relation_dict['relation'][0]

{'headquarters location': [['zest airways, inc.', 'pasay city'],
  ['asian spirit and zest air', 'pasay city'],
  ['airasia zest', 'pasay city']],
 'country': [['zest airways, inc.', 'philippines'],
  ['asian spirit and zest air', 'philippines'],
  ['airasia zest', 'philippines'],
  ['pasay city', 'philippines'],
  ['manila', 'philippines'],
  ['metro manila', 'philippines'],
  ['ninoy aquino international airport', 'philippines']],
 'located in the administrative territorial entity': [['pasay city',
   'metro manila'],
  ['metro manila', 'philippines'],
  ['ninoy aquino international airport', 'pasay city']],
 'contains administrative territorial entity': [['philippines',
   'metro manila'],
  ['metro manila', 'pasay city']]}

In [8]:
ner = 1

In [9]:
import json


from datasets import Dataset

relation_dict = {}
if ner:
    with open('DocRED/data/bi-sent-pre-process.json') as f:
        relation_dict = json.load(f)

    dataset = Dataset.from_dict(
        {
            'text': relation_dict['text'],
            'entity': relation_dict['entity'],
            'relation': relation_dict['relation']
        }
    )

else:
    pass

In [10]:
dataset

Dataset({
    features: ['text', 'entity', 'relation'],
    num_rows: 24256
})

In [11]:
dataset[0]

{'text': 'zest airways , inc. operated as airasia zest ( formerly asian spirit and zest air ) , was a low - cost airline based at the ninoy aquino international airport in pasay city , metro manila in the philippines . it operated scheduled domestic and international tourist services , mainly feeder services linking manila and cebu with 24 domestic destinations in support of the trunk route operations of other airlines .',
 'entity': [['zest airways, inc.', 'organization'],
  ['airasia zest', 'organization'],
  ['asian spirit and zest air', 'organization'],
  ['ninoy aquino international airport', 'location'],
  ['pasay city', 'location'],
  ['metro manila', 'location'],
  ['philippines', 'location'],
  ['manila', 'location'],
  ['cebu', 'location'],
  ['24', 'number']],
 'relation': {'applies to jurisdiction': None,
  'author': None,
  'award received': None,
  'basin country': None,
  'capital': None,
  'capital of': None,
  'cast member': None,
  'chairperson': None,
  'characters':

In [12]:
len(dataset['entity'][0])

10

In [13]:
"""relation_info_dict = {}
for id, relation in enumerate(dataset[0]['relation'].keys()):
    relation_info_dict[relation] = id

with open('DocRED/data/relation-index.json', 'w') as f:
    json.dump(relation_info_dict, f)"""

with open('DocRED/data/relation-index.json') as f:
    relation_info_dict = json.load(f)

In [16]:
def pro_processing_ner(example, tokenizer, padding=True):
    texts = example['text']
    input_texts = []
    for index in range(len(texts)):
        # entity extraction and NER
        text = texts[index].lower().strip() + " [learn1] [learn2]"
        for entity in example['entity'][index]:
            text = text + " entity : " + entity[0] + " , type : " + entity[1] + " ;"
        text = text[:-1] + "."
        # print("1")
        # add relation classificaiton
        text = text.lower().strip() + " [learn3] [learn4]"
        for relation_type, relation_pair in example['relation'][index].items():
            if relation_pair:
                text_w_relation = text + " for the relation " + relation_type + " : 1 ."
                text_w_relation = text_w_relation.lower().strip() + " [learn5] [learn6]"
                text_w_relation = text_w_relation + " and the entity for the relation " + relation_type + " are :"
                for pair in relation_pair:
                    text_w_relation = text_w_relation + " head entity: " + pair[0] + " , tail entity: " + pair[1] + ";"
                text_w_relation = text_w_relation[:-1] + "." + tokenizer.eos_token
                input_texts.append(text_w_relation)

            else:
                text_w_relation = text + " for the relation " + relation_type + " : 0 ." + tokenizer.eos_token
                input_texts.append(text_w_relation)

    return {
        'input_ids': input_texts
        }

In [15]:
tokenized_dataset = dataset.map(lambda example: pro_processing_ner(example, tokenizer), batched=True)

Map:   0%|          | 0/24256 [00:00<?, ? examples/s]

KeyboardInterrupt: 

In [27]:
# feed the dataset:dataset to the pro_processing_ner() function with tokenizer, at each time, we feed 30 examples to the function, and then save the output to a json file
# each time the return of the function is a dict, we need to save the dict to a list, and then save the list to a json file

import json
from tqdm import tqdm


output = {"input_texts": []}

for i in tqdm(range(0, len(dataset), 30)):
    result = pro_processing_ner(dataset[i:i+30], tokenizer)
    output["input_texts"].extend(result["input_ids"])

100%|██████████| 809/809 [00:21<00:00, 37.95it/s]


In [32]:
len(output['input_texts'])

2328576

In [8]:
# make the output["input_texts"] into a dataset
from datasets import Dataset

input_text_dataset = Dataset.from_dict(
    {
        'input_texts': output['input_texts'],
    }
)


NameError: name 'output' is not defined

In [None]:
tokenized_dataset = input_text_dataset.map(lambda example: tokenizer(example['input_texts'], padding='max_length', truncation=True, max_length=1024, pad_to_max_length=True), batched=True)

Map:   0%|          | 0/2328576 [00:00<?, ? examples/s]

In [37]:
# remove the column of input_texts in the tokenized_dataset
tokenized_dataset.remove_columns('input_texts')

Dataset({
    features: ['input_texts', 'input_ids', 'attention_mask'],
    num_rows: 2328576
})

In [39]:
# save the tokenized_dataset

tokenized_dataset.save_to_disk('DocRED/GPT_w_ner_short_relation/train_data_ner_short_relation')
# with open('DocRED/data/train_ner_short_relation.json', 'w') as f:
#     json.dump(tokenized_dataset, f)

Saving the dataset (0/27 shards):   0%|          | 0/2328576 [00:00<?, ? examples/s]

In [9]:


from datasets import Dataset

tokenized_dataset = Dataset.load_from_disk('DocRED/GPT_w_ner_short_relation/train_data_ner_short_relation')
tokenized_dataset = tokenized_dataset.remove_columns('input_texts')

In [10]:
# only take the first len(tokenized_dataset) // 50 examples to train the model

tokenized_dataset = tokenized_dataset.select(range(len(tokenized_dataset) // 50))

In [11]:
tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])

In [12]:
tokenizer.decode(tokenized_dataset[66]['input_ids'])

'zest airways, inc. operated as airasia zest ( formerly asian spirit and zest air ), was a low - cost airline based at the ninoy aquino international airport in pasay city, metro manila in the philippines. it operated scheduled domestic and international tourist services, mainly feeder services linking manila and cebu with 24 domestic destinations in support of the trunk route operations of other airlines. [learn1] [learn2] entity : zest airways, inc., type : organization ; entity : airasia zest, type : organization ; entity : asian spirit and zest air, type : organization ; entity : ninoy aquino international airport, type : location ; entity : pasay city, type : location ; entity : metro manila, type : location ; entity : philippines, type : location ; entity : manila, type : location ; entity : cebu, type : location ; entity : 24, type : number. [learn3] [learn4] for the relation participant of : 0.<|endoftext|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|>

In [13]:
tokenized_dataset.__getitems__([1,4])

[{'input_ids': tensor([   89,   395,  1633,  ..., 50258, 50258, 50258]),
  'attention_mask': tensor([1, 1, 1,  ..., 0, 0, 0])},
 {'input_ids': tensor([   89,   395,  1633,  ..., 50258, 50258, 50258]),
  'attention_mask': tensor([1, 1, 1,  ..., 0, 0, 0])}]

# trainer

In [14]:
import wandb

wandb.init(
    # set the wandb project where this run will be logged
    project="GPT2-intermediate",
    # notes="PubmedBERT-FT-NER_w_NERin_10epochs",
    name="GPT2-short_relation_DocRED-w-ner-5epochs"
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33m309439737[0m ([33mtian1995[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [15]:
import transformers
from transformers import DataCollatorForLanguageModeling

trainer = transformers.Trainer(
    model=model, 
    train_dataset=tokenized_dataset,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=2, 
        gradient_accumulation_steps=2,
        warmup_steps=1000, 
        num_train_epochs=5,
        learning_rate=2e-4, 
        # fp16=True,
        logging_steps=100, 
        report_to="wandb",
        save_strategy="epoch",
        output_dir='DocRED/GPT_w_ner_short_relation'
    ),
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False)
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!


In [16]:
trainer.train()



  0%|          | 0/58215 [00:00<?, ?it/s]

{'loss': 28.2435, 'learning_rate': 2e-05, 'epoch': 0.01}
{'loss': 2.7659, 'learning_rate': 4e-05, 'epoch': 0.02}
{'loss': 1.9937, 'learning_rate': 6e-05, 'epoch': 0.03}
{'loss': 1.6083, 'learning_rate': 8e-05, 'epoch': 0.03}
{'loss': 1.3042, 'learning_rate': 0.0001, 'epoch': 0.04}
{'loss': 1.0588, 'learning_rate': 0.00012, 'epoch': 0.05}
{'loss': 0.7999, 'learning_rate': 0.00014, 'epoch': 0.06}
{'loss': 0.6312, 'learning_rate': 0.00016, 'epoch': 0.07}
{'loss': 0.4804, 'learning_rate': 0.00018, 'epoch': 0.08}
{'loss': 0.4285, 'learning_rate': 0.0002, 'epoch': 0.09}
{'loss': 0.3386, 'learning_rate': 0.00019965044131783623, 'epoch': 0.09}
{'loss': 0.2831, 'learning_rate': 0.00019930088263567247, 'epoch': 0.1}
{'loss': 0.2668, 'learning_rate': 0.0001989513239535087, 'epoch': 0.11}
{'loss': 0.2264, 'learning_rate': 0.00019860176527134493, 'epoch': 0.12}
{'loss': 0.1955, 'learning_rate': 0.00019825220658918117, 'epoch': 0.13}
{'loss': 0.1855, 'learning_rate': 0.00019790264790701742, 'epoch':

TrainOutput(global_step=58215, training_loss=0.134495056195633, metrics={'train_runtime': 49310.7754, 'train_samples_per_second': 4.722, 'train_steps_per_second': 1.181, 'train_loss': 0.134495056195633, 'epoch': 5.0})

wandb: ERROR Error while calling W&B API: graphql: panic occurred: runtime error: invalid memory address or nil pointer dereference (<Response [500]>)
wandb: ERROR Error while calling W&B API: graphql: panic occurred: runtime error: invalid memory address or nil pointer dereference (<Response [500]>)
wandb: ERROR Error while calling W&B API: graphql: panic occurred: runtime error: invalid memory address or nil pointer dereference (<Response [500]>)
wandb: ERROR Error while calling W&B API: graphql: panic occurred: runtime error: invalid memory address or nil pointer dereference (<Response [500]>)
wandb: ERROR Error while calling W&B API: graphql: panic occurred: runtime error: invalid memory address or nil pointer dereference (<Response [500]>)
wandb: ERROR Error while calling W&B API: graphql: panic occurred: runtime error: invalid memory address or nil pointer dereference (<Response [500]>)
wandb: ERROR Error while calling W&B API: graphql: panic occurred: runtime error: invalid memo

In [17]:
wandb.finish()
trainer.save_model("DocRED/GPT_w_ner_short_relation")

# save the tokenizer
tokenizer.save_pretrained("DocRED/GPT_w_ner_short_relation/tokenizer")

0,1
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/learning_rate,▇███▇▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▁▁▁
train/loss,█▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/total_flos,▁
train/train_loss,▁
train/train_runtime,▁
train/train_samples_per_second,▁
train/train_steps_per_second,▁

0,1
train/epoch,5.0
train/global_step,58215.0
train/learning_rate,0.0
train/loss,0.0563
train/total_flos,1.2168631222272e+17
train/train_loss,0.1345
train/train_runtime,49310.7754
train/train_samples_per_second,4.722
train/train_steps_per_second,1.181


('DocRED/GPT_w_ner_short_relation/tokenizer/tokenizer_config.json',
 'DocRED/GPT_w_ner_short_relation/tokenizer/special_tokens_map.json',
 'DocRED/GPT_w_ner_short_relation/tokenizer/vocab.json',
 'DocRED/GPT_w_ner_short_relation/tokenizer/merges.txt',
 'DocRED/GPT_w_ner_short_relation/tokenizer/added_tokens.json')

# Inference

In [18]:
from transformers import AutoModelForCausalLM

checkpoint = "DocRED/GPT_w_ner_short_relation"

In [19]:
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("DocRED/GPT_w_ner_short_relation/tokenizer")
model = AutoModelForCausalLM.from_pretrained(checkpoint)

In [20]:
# output all of the special tokens in the tokenizer
tokenizer.all_special_tokens

['<|startoftext|>',
 '<|endoftext|>',
 '<|pad|>',
 '[learn1]',
 '[learn2]',
 '[learn3]',
 '[learn4]',
 '[learn5]',
 '[learn6]']

In [21]:
import torch

model.eval()
model.to("cpu")
# inputs = tokenizer("Tweet text : @HondaCustSvc Your customer service has been horrible during the recall process. I will never purchase a Honda again. [learn1] [learn2] entity :", return_tensors="pt")

inputs = tokenizer("Tweet text : @HondaCustSvc Your customer service has been horrible during the recall process. I will never purchase a Honda again.", return_tensors="pt", padding='max_length', max_length=1000)

with torch.no_grad():
    outputs = model.generate(input_ids=inputs["input_ids"], max_new_tokens=20, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)
    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=False)[0])

A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


Tweet text : @HondaCustSvc Your customer service has been horrible during the recall process. I will never purchase a Honda again. <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad

test data pre-processing

In [26]:
import json

from datasets import Dataset

ner = 1

test_relation_dict = {}
if ner:
    with open('DocRED/data/bi-sent-pre-process_test.json') as f:
        test_relation_dict = json.load(f)

    test_dataset = Dataset.from_dict(
        {
            'text': test_relation_dict['text'],
            'entity': test_relation_dict['entity'],
            'relation': test_relation_dict['relation']
        }
    )

else:
    pass


dataset = test_dataset

In [27]:
def pro_processing_ner(example, tokenizer, padding=True):
    texts = example['text']
    input_texts = []
    for index in range(len(texts)):
        # entity extraction and NER
        text = texts[index].lower().strip() + " [learn1] [learn2]"
        for entity in example['entity'][index]:
            text = text + " entity : " + entity[0] + " , type : " + entity[1] + " ;"
        text = text[:-1] + "."
        # print("1")
        # add relation classificaiton
        text = text.lower().strip() + " [learn3] [learn4]"
        for relation_type, relation_pair in example['relation'][index].items():
            if relation_pair:
                text_w_relation = text + " for the relation " + relation_type + " : 1 ."
                text_w_relation = text_w_relation.lower().strip() + " [learn5] [learn6]"
                text_w_relation = text_w_relation + " and the entity for the relation " + relation_type + " are :"
                for pair in relation_pair:
                    text_w_relation = text_w_relation + " head entity: " + pair[0] + " , tail entity: " + pair[1] + ";"
                text_w_relation = text_w_relation[:-1] + "." + tokenizer.eos_token
                input_texts.append(text_w_relation)

            else:
                text_w_relation = text + " for the relation " + relation_type + " : 0 ." + tokenizer.eos_token
                input_texts.append(text_w_relation)

    return {
        'input_ids': input_texts
        }

import json
from tqdm import tqdm


output = {"input_texts": []}

for i in tqdm(range(0, len(dataset), 30)):
    result = pro_processing_ner(dataset[i:i+30], tokenizer)
    output["input_texts"].extend(result["input_ids"])


from datasets import Dataset

input_text_dataset = Dataset.from_dict(
    {
        'input_texts': output['input_texts'],
    }
)


100%|██████████| 269/269 [00:04<00:00, 66.03it/s]


In [28]:
tokenized_dataset = input_text_dataset.map(lambda example: tokenizer(example['input_texts'], padding='max_length', truncation=True, max_length=1024, pad_to_max_length=True), batched=True)

tokenized_dataset.remove_columns('input_texts')

tokenized_dataset.save_to_disk('DocRED/GPT_w_ner_short_relation/test_data_ner_short_relation')
# with open('DocRED/data/train_ner_short_relation.json', 'w') as f:
#     json.dump(tokenized_dataset, f)



Map:   0%|          | 0/773472 [00:00<?, ? examples/s]

Saving the dataset (0/9 shards):   0%|          | 0/773472 [00:00<?, ? examples/s]

randomly select test data

In [30]:
from datasets import Dataset

tokenized_test_dataset = Dataset.load_from_disk('DocRED/GPT_w_ner_short_relation/test_data_ner_short_relation')
tokenized_test_dataset = tokenized_test_dataset.remove_columns('input_texts')

tokenized_test_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])

In [38]:
input_text_dataset[288]

{'input_texts': 'besides digital terrestrial transmission , it is available on the subscription - based encrypted services of nova and cosmote tv . skai tv is also a member of digea , a consortium of private television networks introducing digital terrestrial transmission in greece . [learn1] [learn2] entity : nova , type : organization ; entity : cosmote tv , type : organization ; entity : skai tv , type : organization ; entity : digea , type : organization ; entity : greece , type : location . [learn3] [learn4] for the relation applies to jurisdiction : 0 .<|endoftext|>'}

In [40]:
len(tokenized_dataset) / 96

8057.0

In [280]:
# random sample 1 example from the test_dataset
import random
# have a random seed
random.seed(60)

random_test_index = random.randint(0, 8056)
print("index: ", random_test_index * 96, " - ", random_test_index * 96 + 95)
print("index: ", random_test_index)
print(tokenizer.decode(tokenized_test_dataset[random_test_index * 96]['input_ids']))

# output the length of tokenized_test_dataset[index]['input_ids'] except the padding tokens. the tokenized_test_dataset[index]['input_ids'] is tensor

input_ids_lists = tokenized_test_dataset[random_test_index * 96 : random_test_index * 96 + 96]['input_ids'].tolist()

all_actually_inputs = {"learn2_index": [], "valid_length": []}
all_gold_truths = []

for input_ids_list in tqdm(input_ids_lists):
    
    valid_length = len(input_ids_list) - input_ids_list.count(tokenizer.pad_token_id)

    # print(tokenizer.decode(tokenized_test_dataset[index]['input_ids'][:valid_length]))

    # generate a lower triangle matrix of 1s with the shape is (valid_length, valid_length), using torch

    low_triangle_matrix = torch.tril(torch.ones((valid_length, valid_length), dtype=torch.long))

    # find the index of the token id of "[learn2]" in the tokenized_test_dataset[index]['input_ids'] tensor

    learn2_index = input_ids_list.index(tokenizer.convert_tokens_to_ids("[learn2]"))

    # have a vector to store the token of okenized_test_dataset[index]['input_ids'][learn2_index + 1:valid_length]

    gold_truth = input_ids_list[learn2_index + 1:valid_length]

    all_actually_inputs["learn2_index"].append(learn2_index)

    all_actually_inputs["valid_length"].append(valid_length)

    all_gold_truths.append(gold_truth)



index:  242016  -  242111
index:  2521
durgada is a rural village in gollaprolu mandal, east godavari district, andhra pradesh, india. the village was formerly known as durga ooda, durga vaahini. [learn1] [learn2] entity : durgada, type : location ; entity : gollaprolu, type : location ; entity : east godavari, type : location ; entity : andhra pradesh, type : location ; entity : india, type : location ; entity : durga ooda, type : location ; entity : durga vaahini, type : location. [learn3] [learn4] for the relation applies to jurisdiction : 0.<|endoftext|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|p

  0%|          | 0/96 [00:00<?, ?it/s]

quick inference

In [281]:
from tqdm.notebook import trange, tqdm
import torch
import numpy as np


model.eval()
outputs = []
model.to("cuda")

batch_output = []
all_batch_output = []
all_accuracy = []


with torch.no_grad():
    # feed the actually_input to the model by 10 examples each time
    for input_index, input_ids_list in enumerate(input_ids_lists):
        print(input_index + 1, " / 96")
        for i in range(all_actually_inputs['learn2_index'][input_index] + 1, all_actually_inputs['valid_length'][input_index]):
            output = model(input_ids=torch.tensor(input_ids_list[:i]).to("cuda"))
            current_output = np.array(output['logits'].cpu())
            max_index = np.argmax(current_output[-1, :], axis=0)
            batch_output.append(max_index)
            # break
        if len(batch_output) == len(all_gold_truths[input_index]):
            accuracy = sum(np.array(batch_output) == np.array(all_gold_truths[input_index])) / len(batch_output)
            all_accuracy.append(accuracy)
            all_batch_output.append(batch_output)
            batch_output = []
        else:
            print("error")
        
    # print(tokenizer.batch_decode(max_index, skip_special_tokens=False)[0])

1  / 96
2  / 96
3  / 96
4  / 96
5  / 96
6  / 96
7  / 96
8  / 96
9  / 96
10  / 96
11  / 96
12  / 96
13  / 96
14  / 96
15  / 96
16  / 96
17  / 96
18  / 96
19  / 96
20  / 96
21  / 96
22  / 96
23  / 96
24  / 96
25  / 96
26  / 96
27  / 96
28  / 96
29  / 96
30  / 96
31  / 96
32  / 96
33  / 96
34  / 96
35  / 96
36  / 96
37  / 96
38  / 96
39  / 96
40  / 96
41  / 96
42  / 96
43  / 96
44  / 96
45  / 96
46  / 96
47  / 96
48  / 96
49  / 96
50  / 96
51  / 96
52  / 96
53  / 96
54  / 96
55  / 96
56  / 96
57  / 96
58  / 96
59  / 96
60  / 96
61  / 96
62  / 96
63  / 96
64  / 96
65  / 96
66  / 96
67  / 96
68  / 96
69  / 96
70  / 96
71  / 96
72  / 96
73  / 96
74  / 96
75  / 96
76  / 96
77  / 96
78  / 96
79  / 96
80  / 96
81  / 96
82  / 96
83  / 96
84  / 96
85  / 96
86  / 96
87  / 96
88  / 96
89  / 96
90  / 96
91  / 96
92  / 96
93  / 96
94  / 96
95  / 96
96  / 96


In [282]:
# have the average accuracy
if len(all_accuracy) == 96:

    print(sum(all_accuracy) / len(all_accuracy))

0.768768807756679


In [283]:
# have the max accuracy and the index of the max accuracy

max_accuracy = max(all_accuracy)
max_accuracy_index = all_accuracy.index(max_accuracy)

In [284]:
# output the index of the correct prediction

for index, item in enumerate(zip(all_batch_output[max_accuracy_index], all_gold_truths[max_accuracy_index])):
    if item[0] == item[1]:
        print(index, ": ", tokenizer.decode(item[0]))

0 :  entity
1 :   :
4 :  ada
5 :  ,
6 :   type
7 :   :
8 :   location
9 :   ;
10 :   entity
11 :   :
13 :  ll
18 :   type
19 :   :
20 :   location
21 :   ;
22 :   entity
23 :   :
24 :   east
25 :   god
28 :  ,
29 :   type
30 :   :
31 :   location
32 :   ;
33 :   entity
34 :   :
36 :  hra
38 :  adesh
40 :   type
41 :   :
42 :   location
43 :   ;
44 :   entity
45 :   :
47 :  ia
48 :  ,
49 :   type
50 :   :
51 :   location
52 :   ;
53 :   entity
54 :   :
55 :   d
57 :   o
59 :  ,
60 :   type
61 :   :
62 :   location
64 :   entity
65 :   :
66 :   d
68 :   v
70 :  ini
71 :  ,
72 :   type
73 :   :
75 :  .
76 :  [learn3]
77 :  [learn4]
78 :  for
79 :   the
80 :   relation
81 :   member
82 :   of
83 :   sports
84 :   team
85 :   :
86 :   0
87 :  .
88 :  <|endoftext|>


In [285]:
tokenizer.decode(input_ids_lists[max_accuracy_index])

'durgada is a rural village in gollaprolu mandal, east godavari district, andhra pradesh, india. the village was formerly known as durga ooda, durga vaahini. [learn1] [learn2] entity : durgada, type : location ; entity : gollaprolu, type : location ; entity : east godavari, type : location ; entity : andhra pradesh, type : location ; entity : india, type : location ; entity : durga ooda, type : location ; entity : durga vaahini, type : location. [learn3] [learn4] for the relation member of sports team : 0.<|endoftext|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|p

whole inference

In [286]:
import json

def for_the_relation (index, input,tokenizer=tokenizer):
    rel_info = {}
    with open('DocRED/data/relation-index.json') as f:
        rel_info = json.load(f)
    rel_info_list = [relation for relation in rel_info.keys()]
    relation_input = "for the relation " + rel_info_list[index].lower() + " : "
    tokenized_relation_input = tokenizer.encode(relation_input, add_special_tokens=False, return_tensors="pt")
    # print("before relation classification: ", tokenizer.decode(input))
    return (torch.cat((input, tokenized_relation_input[0]), dim=0))


In [287]:
from tqdm.notebook import trange, tqdm
import torch
import numpy as np


model.eval()
outputs = []
model.to("cuda")
output_texts = []




with torch.no_grad():
    # feed the actually_input to the model by 10 examples each time
    for input_index, input_ids_list in enumerate(input_ids_lists):
        print(input_index + 1, " / 96")
        start = all_actually_inputs['learn2_index'][input_index] + 1
        input_ids = torch.tensor(input_ids_list[:start]).to("cuda")
        relation_classfication = False
        relation_extraction = False

        while((input_ids[-1].item() != tokenizer.eos_token_id) and (len(input_ids) < 1024)):
            output = model(input_ids=input_ids)
            current_output = np.array(output['logits'].cpu())
            max_index = np.argmax(current_output[-1, :], axis=0)

            if max_index == tokenizer.convert_tokens_to_ids("[learn4]") and (not relation_classfication):
                input_ids = torch.cat((input_ids, torch.tensor(max_index).unsqueeze(0).to("cuda")), dim=0)
                input_ids = for_the_relation(input_index, input_ids.to("cpu"))
                input_ids = input_ids.to("cuda")
                relation_classfication = True
                continue

            if relation_classfication and (not relation_extraction):

                # find the possibilities for token id tokenizer.conver_token_to_id("0") and tokenizer.conver_token_to_id("1") in the current_output

                zero_index = tokenizer.convert_tokens_to_ids("0")
                one_index = tokenizer.convert_tokens_to_ids("1")
                zero_possibility = current_output[-1, zero_index]
                print("zero_possibility: ", zero_possibility)
                one_possibility = current_output[-1, one_index]
                print("one_possibility: ", one_possibility)


                if zero_possibility > one_possibility:
                    relation_extraction = True
                    next_input = "0 . <|endoftext|>"
                    input_ids = torch.cat((input_ids, torch.tensor(tokenizer.encode(next_input, add_special_tokens=False, return_tensors="pt")[0]).to("cuda")), dim=0)
                    continue
                else:
                    relation_extraction = True
                    next_input = "1 . [learn5] [learn6]"
                    input_ids = torch.cat((input_ids, torch.tensor(tokenizer.encode(next_input, add_special_tokens=False, return_tensors="pt")[0]).to("cuda")), dim=0)
                    relation_extraction = True
                    # print(tokenizer.decode(input_ids[start:]))
                    continue

            # if relation_extraction:
            #     print(tokenizer.decode(input_ids[start:]))
                
            input_ids = torch.cat((input_ids, torch.tensor(max_index).unsqueeze(0).to("cuda")), dim=0)

        # break
        print(tokenizer.decode(input_ids[start:]))
        output_texts.append(tokenizer.decode(input_ids[start:]))
        
    # print(tokenizer.batch_decode(max_index, skip_special_tokens=False)[0])

1  / 96
zero_possibility:  14.324283
one_possibility:  20.368336


  input_ids = torch.cat((input_ids, torch.tensor(tokenizer.encode(next_input, add_special_tokens=False, return_tensors="pt")[0]).to("cuda")), dim=0)


entity : italy, type : location ; entity : greece, type : location ; entity : st. louis, type : location ; entity : india, type : location ; entity : daimaru, type : location ; entity : cowichan lake, type : location ; entity : kahlo, type : head of government. [learn3] [learn4] for the relation applies to jurisdiction : 1. [learn5] [learn6] and the entity for the relation applies to jurisdiction are : head entity: italy, tail entity: india; head entity: italy, tail entity: india; head entity: italy, tail entity: india; head entity: italy, tail entity: india; head entity: italy, tail entity: india; head entity: italy, tail entity: india; head entity: italy, tail entity: india; head entity: italy, tail entity: india; head entity: italy, tail entity: india; head entity: italy, tail entity: india; head entity: italy, tail entity: india.<|endoftext|>
2  / 96
zero_possibility:  -3.2559252
one_possibility:  3.9837303
entity : italy, type : location ; entity : greece, type : location ; entity

  input_ids = torch.cat((input_ids, torch.tensor(tokenizer.encode(next_input, add_special_tokens=False, return_tensors="pt")[0]).to("cuda")), dim=0)


zero_possibility:  19.808475
one_possibility:  19.078203
entity : italy, type : location ; entity : greece, type : location ; entity : st. louis, type : location ; entity : india, type : location ; entity : daimaru, type : location ; entity : cowichan lake, type : location ; entity : kahlo, type : head of government. [learn3] [learn4] for the relation date of death : 0. <|endoftext|>
21  / 96
zero_possibility:  1.3002384
one_possibility:  6.9767156
entity : italy, type : location ; entity : greece, type : location ; entity : st. louis, type : location ; entity : india, type : location ; entity : daimaru, type : location ; entity : cowichan lake, type : location ; entity : kahlo, type : head of government. [learn3] [learn4] for the relation developer : 1. [learn5] [learn6] and the entity for the relation developer are : head entity: cowichan lake, tail entity: kahlo; head entity: cowichan lake, tail entity: kahlo.<|endoftext|>
22  / 96
zero_possibility:  2.1231802
one_possibility:  7.56

In [288]:
import json

from datasets import Dataset

ner = 1

test_relation_dict = {}
if ner:
    with open('DocRED/data/bi-sent-pre-process_test.json') as f:
        test_relation_dict = json.load(f)

    test_dataset = Dataset.from_dict(
        {
            'text': test_relation_dict['text'],
            'entity': test_relation_dict['entity'],
            'relation': test_relation_dict['relation']
        }
    )

else:
    pass


In [289]:
# save the output_texts to a json file

with open(f'DocRED/GPT_w_ner_short_relation/test_output_texts_for_test_dataset_{random_test_index}.json', 'w') as f:
    json.dump(output_texts, f)

In [290]:
rel_info = {}
with open('DocRED/data/relation-index.json') as f:
    rel_info = json.load(f)
rel_info_list = [relation for relation in rel_info.keys()]

In [291]:
test_dataset[random_test_index]

{'text': 'durgada is a rural village in gollaprolu mandal , east godavari district , andhra pradesh , india . the village was formerly known as durga ooda , durga vaahini .',
 'entity': [['durgada', 'location'],
  ['gollaprolu', 'location'],
  ['east godavari', 'location'],
  ['andhra pradesh', 'location'],
  ['india', 'location'],
  ['durga ooda', 'location'],
  ['durga vaahini', 'location']],
 'relation': {'applies to jurisdiction': None,
  'author': None,
  'award received': None,
  'basin country': None,
  'capital': None,
  'capital of': None,
  'cast member': None,
  'chairperson': None,
  'characters': None,
  'child': None,
  'composer': None,
  'conflict': None,
  'contains administrative territorial entity': [['india', 'andhra pradesh']],
  'continent': None,
  'country': [['east godavari', 'india'],
   ['durga ooda', 'india'],
   ['durga vaahini', 'india'],
   ['andhra pradesh', 'india'],
   ['gollaprolu', 'india'],
   ['durgada', 'india']],
  'country of citizenship': None,

In [292]:
output_texts[0]

'entity : italy, type : location ; entity : greece, type : location ; entity : st. louis, type : location ; entity : india, type : location ; entity : daimaru, type : location ; entity : cowichan lake, type : location ; entity : kahlo, type : head of government. [learn3] [learn4] for the relation applies to jurisdiction : 1. [learn5] [learn6] and the entity for the relation applies to jurisdiction are : head entity: italy, tail entity: india; head entity: italy, tail entity: india; head entity: italy, tail entity: india; head entity: italy, tail entity: india; head entity: italy, tail entity: india; head entity: italy, tail entity: india; head entity: italy, tail entity: india; head entity: italy, tail entity: india; head entity: italy, tail entity: india; head entity: italy, tail entity: india; head entity: italy, tail entity: india.<|endoftext|>'

In [308]:
right_relation_1 = []
right_relation_0 = []
wrong_relation_1 = []
wrong_relation_0 = []
entity_pairs = []

# text = output_texts[0]

# rel_index = 0

for rel_index in range(len(rel_info_list)):
    text = output_texts[rel_index]
    relation_text = "for the relation " + rel_info_list[rel_index].lower() + " : "
    relation_classfication = text.split(relation_text)[1].split(".")[0].strip()

    if "1" in relation_classfication:
        if test_dataset[random_test_index]['relation'][rel_info_list[rel_index]]:
            print("relation : ", rel_info_list[rel_index])
            print("gold truth: ", test_dataset[random_test_index]['relation'][rel_info_list[rel_index]])
            right_relation_1.append(rel_index)

            entity_pair_for_this_relation = []
            remain_text = text
            while("tail entity" in remain_text):
                if ";" in remain_text.split("tail entity")[1]:
                    head_entity = remain_text.split("head entity: ")[1].split(", tail entity: ")[0].strip()
                    tail_entity = remain_text.split("tail entity: ")[1].split(";")[0].strip()
                    if [head_entity, tail_entity] not in entity_pair_for_this_relation:
                        entity_pair_for_this_relation.append([head_entity, tail_entity])
                        print("head entity: ", head_entity)
                        print("tail entity: ", tail_entity)
                    remain_text = remain_text.split(";")[1].strip()
                else:
                    head_entity = remain_text.split("head entity: ")[1].split(", tail entity: ")[0].strip()
                    tail_entity = remain_text.split("tail entity: ")[1].split(".<|endoftext|>")[0].strip()
                    if [head_entity, tail_entity] not in entity_pair_for_this_relation:
                        entity_pair_for_this_relation.append([head_entity, tail_entity])
                        print("head entity: ", head_entity)
                        print("tail entity: ", tail_entity)
                    break
            entity_pairs.append(entity_pair_for_this_relation)

        else:
            wrong_relation_1.append(rel_index)
    else:
        if test_dataset[random_test_index]['relation'][rel_info_list[rel_index]]:
            wrong_relation_0.append(rel_index)
        else:
            right_relation_0.append(rel_index)

relation :  contains administrative territorial entity
[['india', 'andhra pradesh']]
head entity:  india
tail entity:  kahlo
relation :  country
[['east godavari', 'india'], ['durga ooda', 'india'], ['durga vaahini', 'india'], ['andhra pradesh', 'india'], ['gollaprolu', 'india'], ['durgada', 'india']]
head entity:  bharatiya
tail entity:  india
relation :  located in the administrative territorial entity
[['east godavari', 'andhra pradesh'], ['andhra pradesh', 'india'], ['gollaprolu', 'east godavari'], ['gollaprolu', 'andhra pradesh'], ['durgada', 'gollaprolu']]
head entity:  greece
tail entity:  india


In [306]:
print("right_relation_1: ", len(right_relation_1))
print("right_relation_0: ", len(right_relation_0))
print("wrong_relation_1: ", len(wrong_relation_1))
print("wrong_relation_0: ", len(wrong_relation_0))

right_relation_1:  3
right_relation_0:  4
wrong_relation_1:  89
wrong_relation_0:  0


In [307]:
entity_pairs

[[['india', 'kahlo']], [['bharatiya', 'india']], [['greece', 'india']]]

play ground

In [315]:
tokenizer.convert_tokens_to_ids("[learn5]")

50263

In [316]:
# find the index of 50262 in input_ids_lists[12]

tokenizer.decode(input_ids_lists[12][:input_ids_lists[12].index(50263) - 2])

'durgada is a rural village in gollaprolu mandal, east godavari district, andhra pradesh, india. the village was formerly known as durga ooda, durga vaahini. [learn1] [learn2] entity : durgada, type : location ; entity : gollaprolu, type : location ; entity : east godavari, type : location ; entity : andhra pradesh, type : location ; entity : india, type : location ; entity : durga ooda, type : location ; entity : durga vaahini, type : location. [learn3] [learn4] for the relation contains administrative territorial entity :'

In [440]:
input_ids = torch.tensor(input_ids_lists[12][:input_ids_lists[12].index(50263) - 2]).to("cuda")

In [434]:
with torch.no_grad():
    output = model(input_ids=input_ids)
    current_output = np.array(output['logits'].cpu())
    max_index = np.argmax(current_output[-1, :], axis=0)
    sorted_index = np.argsort(current_output[-1, :])[::-1]
    print(tokenizer.decode(max_index))

    

atar


In [439]:
tokenizer.decode(sorted_index[5])

'ito'

In [433]:
with torch.no_grad():
    input_ids = torch.cat((input_ids, torch.tensor(sorted_index [0]).unsqueeze(0).to("cuda")), dim=0)
    print(tokenizer.decode(input_ids))

durgada is a rural village in gollaprolu mandal, east godavari district, andhra pradesh, india. the village was formerly known as durga ooda, durga vaahini. [learn1] [learn2] entity : durgada, type : location ; entity : gollaprolu, type : location ; entity : east godavari, type : location ; entity : andhra pradesh, type : location ; entity : india, type : location ; entity : durga ooda, type : location ; entity : durga vaahini, type : location. [learn3] [learn4] for the relation contains administrative territorial entity : 1. [learn5] [learn6] and the entity for the relation contains administrative administrative territorial entity are : head entity: india, tail entity: godav
