In [None]:
'''

Initial imports and API key

'''



### gpt3 function
import openai
import sys
from inference_utils import scrub_PX_PY, name_PX_PY, get_key, gpt3 ,complete_gpt3, few_shot_prompt
from interence_fs_templates import relation_dicts

# Read the OpenAI API key from a local file rather than embedding it in the
# notebook, stripping the trailing newline/whitespace before use.
with open('api.key') as key_file:
    openai.api_key = key_file.read().strip()


# Initialize Parameters and Results File

In [None]:
'''

Define generation parameters, record them in the params file, and create the results file

'''

import os
import json


# variable definitions
batch_name = 'curie_inference_generation_0'  # name of this batch (recorded in the params file)
engine = 'curie'  # the GPT-3 engine to use
top_p = 0.5  # top_p for nucleus sampling
n_gen = 10  # number of generations for each event/relation pair
PersonX = 'Alex'  # name to use for PersonX (to make sentences more natural)
PersonY = 'Chris'  # name to use for PersonY
length = 12  # maximum token length of generations (number of words per generation will be less)

stop_token = '\n\n'  # stop token during generation
n_examples_in = 5  # how many few-shot examples to use in the prompt

# names to use in few-shot examples. They are shuffled and paired up as
# (PersonX, PersonY) for each example, so every entry must be distinct
# (otherwise an example could use the same name for both people) and must
# differ from the names used for the target event.
names = ['Jean', 'Robin', 'Charlie', 'Ryan', 'Taylor', 'Jordan', 'Riley', 'Jamie', 'Leslie', 'Rowan',
         'Adrian', 'Ali', 'Wyatt', 'Sydney', 'Stevie', 'Shiloh', 'Sam', 'Pat', 'Noel', 'Nicky', 'Max',
         'Madison', 'Lindsay', 'Morgan', 'Lee', 'Jesse', 'Hunter', 'Glen', 'Devin', 'Avery']
if len(set(names)) != len(names):
    raise ValueError('few-shot names must be unique')
if PersonX in names or PersonY in names:
    raise ValueError('few-shot names must not include PersonX/PersonY')


heads_file = 'todo_events.txt'  # events (heads) to generate for, one per line
out_file = 'inference_gens.jsonl'  # generations, one JSON object per line
meta_file = 'params.json'  # generation parameters, kept for the record


### save params if this is not done already (for records)
# NOTE: the file is written only once; if parameters are changed and the
# notebook is re-run, params.json keeps the values from the first run.
if not os.path.isfile(meta_file):
    # 'function' is dropped: it is passed to few_shot_prompt as code, and is
    # presumably not JSON-serializable
    relation_param_dict = [{key: d[key] for key in d if key != 'function'} for d in relation_dicts]
    param_dict = {'batch_name': batch_name,
                  'engine': engine,
                  'top_p': top_p,
                  'n_gen': n_gen,
                  'PersonX': PersonX,
                  'PersonY': PersonY,
                  'n_examples_in': n_examples_in,
                  'length': length,
                  'stop_token': stop_token,
                  'names': names,
                  'relation_param_dict': relation_param_dict}
    with open(meta_file, 'w') as f:
        f.write(json.dumps(param_dict))


### create (empty) output file if it does not exist, so later cells can read it
if not os.path.isfile(out_file):
    with open(out_file, 'w') as f:
        pass
    




# Generate Inferences

In [None]:
'''

Generate inferences for each event/relation pair

'''
import random
import time

t_start = time.time()

print_step = 400  # print progress and sample generations every `print_step` heads


def _sequence_logprob(logprobs):
    """Return the summed token log-probability of one generated completion.

    ``logprobs`` is a completion's ``choice['logprobs']`` object. Token
    log-probs are summed up to and including the first token whose
    ``text_offset`` is the largest (the final token of the generated text).
    A completion with no tokens yields 0.0 instead of raising on ``max([])``,
    which would otherwise abort the whole (paid) generation run.
    """
    offsets = logprobs['text_offset']
    if not offsets:
        return 0.0
    end_ind = offsets.index(max(offsets))
    return sum(logprobs['token_logprobs'][:end_ind + 1])


####
# For robustness, figure out if we have already generated for some inputs,
# (e.g. if the notebook crashed) and if so, pick up where we left off.
####
##
## first, load the heads we need to generate for; blank lines and duplicate
## heads are dropped so no event is sent to the API (and paid for) twice
with open(heads_file) as f:
    heads_all = list(dict.fromkeys(line.strip() for line in f if line.strip()))
## then, get all heads we have generated for so far; a set keeps the
## membership test below O(1) instead of scanning a list for every head
with open(out_file) as f:
    heads_done = {json.loads(line)['head'] for line in f if line.strip()}
## heads todo are those in heads_file that aren't yet in out_file
heads_todo = [head for head in heads_all if head not in heads_done]
print('{} heads todo'.format(len(heads_todo)))
##
####


####
# Next, generate for whichever inputs (event/relation) we still need
####

for i, head_to_test in enumerate(heads_todo):
    time.sleep(0.02)  # brief pause before each request (presumably for API rate limits)

    ## first, define output objects (empty for now)
    ex_result = {'head': head_to_test, 'tails_by_relation': {}}

    ## randomize the few-shot names (shuffles the global `names` list in place)
    ## and pair the first half with the second half as (PersonX, PersonY)
    random.shuffle(names)
    half = len(names) // 2
    names_list = list(zip(names[:half], names[half:]))

    prompts = []

    ## define inputs for GPT-3 (prompts) for each relation (defined by relation dict d)
    for d in relation_dicts:
        # get few-shot examples
        examples_in = d['examples'][:n_examples_in]

        # define prompt, and name PersonX/PersonY in each example; the target
        # event always uses the fixed PersonX/PersonY names
        prompt = few_shot_prompt(examples_in + [{'head': head_to_test}], d['function'], number=True,
                                 Person_list=names_list[:len(examples_in)] + [(PersonX, PersonY)])

        # add prefix to prompt, then fill in PersonX/PersonY names via name_PX_PY
        prompt = d['prefix'] + prompt
        prompt = name_PX_PY(prompt, PersonX, PersonY)

        # add prompt to the output object
        ex_result['tails_by_relation'][d['relation']] = {'relation': d['relation'], 'prompt': prompt}
        prompts.append(prompt)

    ## generate for all relations in one batched call (one prompt per relation)
    result = complete_gpt3(prompts, length, engine, top_p=top_p, num_log_probs=1, n=n_gen,
                           stop=stop_token, echo=False)

    # collect the generations for each prompt (one event/relation input).
    # assumes complete_gpt3 returns choices grouped per prompt, n_gen
    # consecutive choices each, in the same order as `prompts` — TODO confirm
    for rel_idx, d in enumerate(relation_dicts):
        start = rel_idx * n_gen
        outputs = []
        for choice in result['choices'][start:start + n_gen]:
            outputs.append({'text': choice['text'],
                            'result': choice,
                            # NOTE: despite the key name, this is the summed token
                            # log-probability (a log-likelihood, <= 0), not its negation
                            'nll': _sequence_logprob(choice['logprobs'])})
        ex_result['tails_by_relation'][d['relation']]['tails'] = outputs

    # append the generations to the output file (one JSON object per head)
    with open(out_file, 'a') as f:
        f.write(json.dumps(ex_result) + '\n')

    # print results
    if i % print_step == 0:
        elapsed_mins = (time.time() - t_start) / 60.
        print('=' * 50)
        # avg_rate is minutes per head processed in this run
        print('time: {} mins, avg_rate: {}'.format(elapsed_mins, elapsed_mins / (i + 1)))
        print('{}) {}'.format(i, ex_result['head']))

        for relation in ex_result['tails_by_relation']:
            print('=' * 10)
            print('{})'.format(relation))
            print('=' * 10)
            print('{}'.format([v['text'] for v in ex_result['tails_by_relation'][relation]['tails']]))

        print('=' * 50)

# Process Generations once they are done

In [None]:
'''

ONLY RUN once you have generated all inferences

this block takes the full set of generated inferences,
and puts them in a standardized format, including removing inferences
that repeat for a given event/relation input.

It also automatically divides generations into a train/val/test split
based on the event (i.e. all inferences for a given event are sorted into the 
same split)

generations are saved as a jsonl file where each entry has the following keys:

head: the event for generation
relation: relation for generation
inference: the generated inference (post-processed, with names replaced by PersonX/PersonY)
split: the dataset split

'''

import json
    
    
# output path for the processed, de-duplicated dataset (one JSON object per line)
full_dataset_file = 'unique_dataset.jsonl'
    

def name_PX_PY(s, PersonX, PersonY):
    """Fill the 'PersonX'/'PersonY' placeholders in ``s`` with concrete names.

    Every literal 'PersonX' is substituted first, then every literal
    'PersonY' in the resulting string.
    """
    with_x = s.replace('PersonX', PersonX)
    return with_x.replace('PersonY', PersonY)
def scrub_PX_PY(s, PersonX, PersonY):
    """Replace the concrete names ``PersonX``/``PersonY`` in ``s`` with placeholders.

    Inverse of ``name_PX_PY``: whole-word occurrences of ``PersonX`` become
    'PersonX', then those of ``PersonY`` become 'PersonY'. Matching whole
    words only keeps longer words that merely contain a name intact
    (e.g. 'Christmas' is not turned into 'PersonYtmas' when PersonY is
    'Chris'). Empty names are ignored rather than matching everywhere.
    """
    import re

    for name, placeholder in ((PersonX, 'PersonX'), (PersonY, 'PersonY')):
        if not name:
            continue
        # lookarounds instead of \b so names with non-word edge characters still match
        pattern = r'(?<!\w)' + re.escape(name) + r'(?!\w)'
        s = re.sub(pattern, placeholder, s)
    return s
    
def process_tail(tail):
    """Normalize one generated inference for the dataset.

    Removes a single trailing '.', then a single leading space, and finally
    replaces the module-level ``PersonX``/``PersonY`` names with placeholders
    via ``scrub_PX_PY``. Short or empty strings (e.g. '.' which becomes ''
    after the dot is removed) are handled instead of raising IndexError.
    """
    if tail.endswith('.'):
        tail = tail[:-1]
    if tail.startswith(' '):
        tail = tail[1:]
    return scrub_PX_PY(tail, PersonX, PersonY)




import json
import random
split_seed = 0  # seed for the train/val/test assignment, so splits are reproducible


def _draw_split(rng):
    """Return 'train' (80%), 'val' (10%) or 'test' (10%) using ``rng``."""
    r = rng.random()
    if r < 0.8:
        return 'train'
    if r < 0.9:
        return 'val'
    return 'test'


split_rng = random.Random(split_seed)
split_by_head = {}  # head -> split; a head repeated on several lines keeps one split
seen = set()  # (head, relation, inference) triples already emitted
data_out = []

# read everything first so a malformed input line does not leave a truncated output file
with open(out_file) as f:
    for line in f:
        if not line.strip():
            continue
        d = json.loads(line)
        head = d['head']

        # all inferences for a given event go to the same split
        if head not in split_by_head:
            split_by_head[head] = _draw_split(split_rng)
        split = split_by_head[head]

        for relation, rel_out in d['tails_by_relation'].items():
            for v in rel_out['tails']:
                # remove short inferences (degenerate)
                if len(v['text']) <= 2:
                    continue
                inference = process_tail(v['text'])
                # do not include repeats; first occurrence wins, so output order is deterministic
                key = (head, relation, inference)
                if key in seen:
                    continue
                seen.add(key)
                data_out.append({'split': split, 'head': head, 'relation': relation, 'inference': inference})

with open(full_dataset_file, 'w') as f_out:
    for d in data_out:
        f_out.write(json.dumps(d) + '\n')

print('total of {} unique generated examples written to {}'.format(len(data_out), full_dataset_file))