# Initialize Generation parameters

In [None]:
'''

Set parameters for generation of events

Run this block only once, as it overwrites `batches` and
    `anon_heads`, which keep track of generations

'''
import openai


# this file should contain the OpenAI API key used for generation
#  GPT-3 access is REQUIRED for this script
#  (although GPT-3 calls could be replaced with calls
#  to a local language model--this would require code changes)
with open('api.key') as f:
    openai.api_key = f.read().strip()
if not openai.api_key:
    # fail here rather than with a confusing authentication error later
    raise ValueError("'api.key' is empty; it should contain an OpenAI API key")

engine = 'davinci'  # GPT-3 model version to use
length = 12  # maximum token length of generations (a word may span several tokens, so usually < 12 words)
top_p = 0.9  # top_p to use for nucleus sampling
n_gen = 100  # number of generations per batch
presence_penalty = 0.5  # penalize tokens that have already appeared in the text so far
frequency_penalty = 0.5  # penalize tokens in proportion to how often they have appeared so far
PersonX = 'PersonX'  # name to use for PersonX. Default is just "PersonX"
PersonY = 'PersonY'  # name to use for PersonY. Default is just "PersonY"


n_examples_in = 10  # how many few-shot examples to use per generation batch
n_batches = 2  # how many generation batches to run


# accumulated by every run of the generation cell; re-running this cell resets them
batches = []

anon_heads = []

# Definitions for Generation

In [None]:
'''

Load example events for few-shot generation, and define
functions to convert these into few-shot prompts

'''


# One seed event per line. Blank lines are skipped so they never become
# empty few-shot examples. The encoding is explicit because events may
# contain non-ASCII characters (e.g. curly apostrophes).
with open('seed_events.txt', encoding='utf-8') as f:
    seed_events = [line.strip() for line in f if line.strip()]


## converts an example into text for few-shot
def ex2text(ex, include_gen = True, number = None):
    """Render one event as an entry of a few-shot prompt.

    Args:
        ex: the event text.
        include_gen: if True, append ' <ex>' followed by three newlines, so
            consecutive entries are separated by blank lines. If False, the
            entry is left open for the model to complete.
        number: optional entry number. When not None, the entry starts with
            '<number>. Event:'.

    Returns:
        The rendered entry, e.g. '1. Event: PersonX eats' plus three newlines.
    """
    text = '' if number is None else '{}. Event:'.format(number)

    if include_gen:
        text += ' {}\n\n\n'.format(ex)
    return text


def few_shot_prompt(examples, ex2text, number = None, Person_list = None):
    """Concatenate rendered examples into a prompt whose final slot is left open.

    Every example except the last is rendered with ``include_gen=True`` (its
    text is included). The last one is rendered with ``include_gen=False`` so
    the model generates its content.

    Args:
        examples: non-empty sequence of examples. The last element is read
            unconditionally, so an empty sequence raises IndexError.
        ex2text: callable ``ex2text(example, include_gen=..., [number=...])``
            returning the string for one example.
        number: if truthy, a 1-based ``number`` keyword is passed to ex2text
            for every slot.
        Person_list: optional per-slot pairs. When given, slot k's text is
            passed through ``name_PX_PY`` with ``Person_list[k][0]`` and
            ``Person_list[k][1]``.

    Returns:
        The concatenated prompt string.
    """
    final_example = examples[-1]
    open_slot = len(examples) - 1

    def render(slot, example, closed):
        # render one slot; "closed" slots include their ground-truth text
        if number:
            piece = ex2text(example, include_gen = closed, number = slot + 1)
        else:
            piece = ex2text(example, include_gen = closed)
        if Person_list is not None:
            pair = Person_list[slot]
            piece = name_PX_PY(piece, pair[0], pair[1])
        return piece

    pieces = [render(slot, example, True) for slot, example in enumerate(examples[:-1])]
    pieces.append(render(open_slot, final_example, False))
    return ''.join(pieces)


def _replace_simultaneously(s, mapping):
    """Replace every key of *mapping* found in *s* by its value, in one pass.

    Longer keys are tried first, so a name that is a prefix of another
    (e.g. 'Ann' and 'Anna') is not split. Empty keys are ignored, because
    ``str.replace('', x)`` would insert *x* between every character. Since
    the substitution is a single pass, an inserted value is never rewritten
    again by another key (sequential ``str.replace`` calls could do that).
    """
    import re

    keys = sorted((k for k in mapping if k), key=len, reverse=True)
    if not keys:
        return s
    pattern = re.compile('|'.join(re.escape(k) for k in keys))
    # a function replacement is inserted literally (no backslash processing),
    # matching str.replace semantics
    return pattern.sub(lambda m: mapping[m.group(0)], s)


def scrub_PX_PY(s, PersonX, PersonY):
    """Replace the concrete names PersonX / PersonY in *s* with the placeholders.

    If both names are identical, occurrences become 'PersonX' (first wins).
    """
    mapping = {}
    mapping.setdefault(PersonX, 'PersonX')
    mapping.setdefault(PersonY, 'PersonY')
    return _replace_simultaneously(s, mapping)


def name_PX_PY(s, PersonX, PersonY):
    """Replace the placeholders 'PersonX' / 'PersonY' in *s* with concrete names."""
    return _replace_simultaneously(s, {'PersonX': PersonX, 'PersonY': PersonY})

# Generate Events

In [None]:
import json
import time
import random

# NOTE: this import rebinds few_shot_prompt, name_PX_PY and scrub_PX_PY, so
# the cell below uses the event_utils versions rather than the ones defined
# in the previous cell.
from event_utils import few_shot_prompt, name_PX_PY, scrub_PX_PY, complete_gpt3

for _ in range(n_batches):
    # brief pause between API calls to avoid API timeouts
    time.sleep(0.05)
    # randomize which seed events are used as few-shot examples
    random.shuffle(seed_events)
    examples_in = seed_events[:n_examples_in]

    # use names defined in the parameters cell
    names = (PersonX, PersonY)
    # the trailing '' leaves the final numbered slot open for the model
    prompt = few_shot_prompt(examples_in + [''], ex2text, number=True)
    prompt = name_PX_PY(prompt, names[0], names[1])

    ## generate using the prompt
    print('=' * 20 + 'prompt' + '=' * 20)
    print(prompt)
    result = complete_gpt3(prompt, length, engine, top_p=top_p, num_log_probs=1, n=n_gen, stop='\n\n', echo=False,
                           frequency_penalty=frequency_penalty, presence_penalty=presence_penalty)

    ### score, anonymize and filter the outputs of this batch
    outputs = []
    for choice in result['choices']:
        try:
            text = choice['text']
        except KeyError:
            text = ''

        logprobs = choice['logprobs']
        offsets = logprobs['text_offset']
        if not offsets:
            # no tokens were returned, so there is nothing to score
            print('skip: empty generation')
            continue
        # first token at the largest offset, i.e. the last generated token
        end_ind = offsets.index(max(offsets))
        # NOTE: despite the name (kept because 'nll' is the key written to
        # generated_events.jsonl), this is the SUM of token log-probabilities:
        # higher means more likely.
        nll = sum(logprobs['token_logprobs'][:end_ind + 1])

        anon_text = scrub_PX_PY(text, names[0], names[1]).strip()

        print('=' * 20 + 'out' + '=' * 20)
        print(anon_text)
        print('nll: {}'.format(nll))

        ## keep only generations that hit the stop sequence and start with PersonX
        if choice['finish_reason'] != 'stop' or not anon_text.startswith('PersonX'):
            print('skip')
            continue

        outputs.append({'text': text,
                        'anon_text': anon_text,
                        'result': choice,
                        'nll': nll,
                        'prompt': prompt})

    batches.append({'prompt': prompt, 'events': outputs})

    anon_heads += [v['anon_text'] for v in outputs]

    print('{} unique new events'.format(len(set(anon_heads))))

# (Re)write every batch generated so far. This includes earlier runs of this
# cell, since `batches` is only reset by the parameters cell.
with open('generated_events.jsonl', 'w') as f:
    for batch in batches:
        f.write(json.dumps(batch) + '\n')

# Process Generated Events

In [None]:
'''

First, load the generated head events and rank them by the score stored
under 'nll' (the summed token log-probability recorded at generation time,
so higher means more likely).

The lowest-scoring 20% are dropped, as they are the most likely to be degenerate.

'''

## read back every generated batch (one JSON object per line)
generated_batches = []
with open('generated_events.jsonl','r') as f:
    generated_batches.extend(json.loads(line) for line in f)

# flatten the batches into (anonymized text, score) pairs
all_events = [(d['anon_text'], d['nll'])
              for batch in generated_batches
              for d in batch['events']]

# Deduplicate on text: the last score seen for a text wins, and the text
# keeps the position of its first appearance. Then rank best-first and
# keep the top 80%.
all_events = list({text: score for text, score in all_events}.items())
all_events.sort(key = lambda pair: -pair[1])
all_events = all_events[:int(0.8*len(all_events))]

all_events

In [None]:
''' 

Finally, remove any events with strange formatting or degenerate properties:
events containing characters other than letters/digits and a small set of
punctuation, or shorter than the PersonX name.

''' 

# NOTE(review): the events here are anonymized (the generation filter requires
# them to start with the literal 'PersonX'), yet the length threshold uses the
# configurable PersonX name from the parameters cell. Confirm this is intended.
todo_events = [
    event
    for event, _ in all_events
    if all(c.isalnum() or c in '\',".-’$ ' for c in event)
    and len(event) >= len(PersonX)
]

# Write these to a file for use in inference generation. The encoding is
# explicit because str.isalnum() admits non-ASCII letters that some platform
# default encodings cannot represent.
with open('todo_events.txt', 'w', encoding='utf-8') as f:
    f.writelines('{}\n'.format(event) for event in todo_events)

# [If using data release instead of generating]

In [None]:
'''

!!!
Only run this block if you are planning to use our released head events,
rather than generating your own. This block will produce the same format of 
'todo_events.txt' file as above, but will use those generated as part of 
the Symbolic Knowledge Distillation paper, rather than generating new ones

Set USE_RELEASED_EVENTS = True below to enable it. The guard prevents users
from accidentally running this block and overwriting their own generations.

'''

import json

USE_RELEASED_EVENTS = False
if not USE_RELEASED_EVENTS:
    # an explicit raise, unlike `assert`, is not stripped under `python -O`
    raise RuntimeError(
        'Refusing to overwrite todo_events.txt with the released events; '
        'set USE_RELEASED_EVENTS = True to run this block.')

events = []
with open('downloaded/ATOMIC10X.jsonl', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            # tolerate blank lines (e.g. a trailing empty line)
            continue
        events.append(json.loads(line)['head'])


# remove duplicates, keeping first-seen order so the output file is
# identical across runs (set() iteration order varies with string hashing)
events = list(dict.fromkeys(events))

# write these to a file for use in inference generation
with open('todo_events.txt', 'w', encoding='utf-8') as f:
    f.writelines('{}\n'.format(event) for event in events)