Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
summerstay authored May 15, 2020
1 parent f813f3b commit 71d9024
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 80 deletions.
Binary file modified syllable_tokens.p
Binary file not shown.
155 changes: 75 additions & 80 deletions true_poetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ class Struct:
pass
# Search-tuning knobs, stored as attributes on a bare Struct instance.
params = Struct()
params.rhyme_set_size = 20 # a word that will be rhymed with later needs at least this many rhyming words to exist
# NOTE(review): probability_threshold and line_probability_threshold are each assigned twice below
# with the same value (old and new side of the change); the later assignment is the one in effect.
params.probability_threshold = .0005 # a token must have at least this probability of being the next token
params.line_probability_threshold = 0 # the multiplied-out probability of a whole line of tokens can be no lower than this
params.probability_threshold = .0005 # minimum next-token probability. Per the author: .05 = better quality but slower; .005 = worse quality but faster
params.line_probability_threshold = 0 # minimum multiplied-out probability of a whole line. 0 means this check is not being used.
params.ultimate_expansion = 30 # at most this many words are tried as the last syllable after any previous phrase
params.penultimate_expansion = 10 # at most this many words are tried for the next-to-last syllable after any previous phrase
params.other_expansion = 10 # at most this many words are tried for the second through next-to-last syllables of any phrase
Expand Down Expand Up @@ -147,6 +147,75 @@ def compare_meters(test_meter,target_meter):
# matchflag = False
return matchflag

def rhyme_and_meter_filter(this_text_sentence, target_rhyme, target_meter, probs, params):
    """Return the candidate next tokens that fit the upcoming meter and rhyme.

    This is a search speed-up, not a perfect filter: the tokens it returns are
    (usually) compatible with the remaining constraints of the line.

    Args:
        this_text_sentence: Decoded text of the line generated so far.
        target_rhyme: Text whose last word the line has to rhyme with. An
            empty string or "!" means there is no rhyme constraint.
        target_meter: Stress pattern of the complete line. Slices of it are
            used as keys into the module-level ``stress_tokens`` table.
        probs: Next-token probabilities indexed by token id. MODIFIED IN
            PLACE: every token ruled out by the filters is set to 0.
        params: Object providing ``ultimate_expansion``,
            ``penultimate_expansion`` and ``other_expansion``.

    Returns:
        A list of ``(token_id, probability)`` tuples sorted by decreasing
        probability with zero-probability entries removed. The list is empty
        when the text already covers the whole target meter.
    """
    # HACK (kept from the original): a random start offset into the ranked
    # list helps prevent repetition.
    # NOTE(review): when only one or two candidates survive the filters, a
    # non-zero offset can skip all of them and yield an empty result.
    offset = randint(0, 2)
    this_meter = text_to_meter(this_text_sentence, stress_dictionary)
    syllables_so_far = len(this_meter)
    line_length = len(target_meter)
    # Size the loops from the data instead of hard-coding the GPT-2 vocabulary.
    vocab_size = len(probs)

    # --- meter filter ---
    # Look at most three stresses ahead; slicing clamps at the end of the line.
    next_stresses = target_meter[syllables_so_far:syllables_so_far + 3]
    if len(next_stresses) == 0:
        return []
    # Everything EXCEPT tokens with an acceptable stress pattern or acceptable
    # punctuation is zeroed out.
    meter_safe = set(stress_tokens[next_stresses]).union(acceptable_punctuation)
    for token in range(vocab_size):
        if token not in meter_safe:
            probs[token] = 0

    # --- rhyme filter ---
    xprint("meter_length = ", end = "")
    xprint(syllables_so_far)

    # Extract the word to rhyme with. split() without arguments ignores
    # trailing whitespace, and rstrip() drops any run of closing punctuation,
    # so inputs such as "day " or "day..." no longer raise or miss the lookup.
    rhyme_word = ""
    if len(target_rhyme) > 0 and target_rhyme != "!":
        rhyme_words = target_rhyme.split()
        if rhyme_words:
            rhyme_word = rhyme_words[-1].lower().rstrip("!.,;:?-")

    if rhyme_word:
        xprint("target rhyme =",end="")
        xprint(rhyme_word)
        these_rhyming_tokens = rhyming_tokens[rhyme_word]
        xprint(tokenizer.decode(these_rhyming_tokens))
        rhyme_safe = None
        if syllables_so_far == line_length - 1:
            # One syllable left: only a rhyming token can finish the line.
            rhyme_safe = set(these_rhyming_tokens)
        elif syllables_so_far == line_length - 2:
            # Two syllables left: either a rhyming word, or a one-syllable
            # word which could be followed by a rhyming word, is okay.
            rhyme_safe = set(syllable_tokens[1]).union(these_rhyming_tokens)
        if rhyme_safe is not None:
            for token in range(vocab_size):
                if token not in rhyme_safe:
                    probs[token] = 0
        expansion_limit = params.ultimate_expansion
        part_label = "PART 1"
    elif syllables_so_far > line_length - 3:
        expansion_limit = params.penultimate_expansion
        part_label = "PART 2"
    elif syllables_so_far < 1:
        # First token of an unrhymed line: keep every candidate, no offset.
        expansion_limit = None
        part_label = "PART 3"
    else:
        expansion_limit = params.other_expansion
        part_label = "PART 4"

    # Rank once, after all zeroing is done, then cut to the branch's budget.
    sorted_probability_list = sorted(enumerate(probs), key=lambda x: x[1], reverse=True)
    if expansion_limit is None:
        short_probability_list = sorted_probability_list
    else:
        short_probability_list = sorted_probability_list[offset:expansion_limit + offset]
    xprint(part_label)

    # Tokens zeroed by the filters must never be offered as candidates.
    short_probability_list = [entry for entry in short_probability_list if entry[1] != 0]
    xprint("short prob list len = ", end =" ")
    xprint(len(short_probability_list))

    return short_probability_list

def grow_branches(these_tokens, probs, input_probability,past,params, prompt_length,target_rhyme,target_meter):
#recursive function to find all sentence completions
global model
Expand All @@ -157,15 +226,14 @@ def grow_branches(these_tokens, probs, input_probability,past,params, prompt_len
global stuck_counter
global past_backup
stuck_counter = stuck_counter + 1
print(stuck_counter, end = "\t")
if stuck_counter > params.stuck_counter_limit:
stuck_counter = 0
past = past_backup
these_tokens = these_tokens[:prompt_length]
found = None
this_text_sentence = tokenizer.decode(these_tokens[prompt_length:])
if len(these_tokens[prompt_length:])<2:
probability_threshold = 0
probability_threshold = 0 # no restrictions on the first tokens in each line.
else:
probability_threshold = params.probability_threshold
short_probability_list = rhyme_and_meter_filter(this_text_sentence,target_rhyme,target_meter,probs,params)
Expand Down Expand Up @@ -247,11 +315,7 @@ def grow_branches(these_tokens, probs, input_probability,past,params, prompt_len
#this isn't long enough to be a complete line, but it could be the start of a complete line, so we expand it.
(next_probability_list,next_past) = expand_node(next_tokens,past)
found = grow_branches(next_tokens,next_probability_list, next_probability, next_past,params,prompt_length,target_rhyme,target_meter)
else:
pass
if found == False:
pass
else:
if found != False:
return found
continue
else:
Expand All @@ -261,77 +325,8 @@ def grow_branches(these_tokens, probs, input_probability,past,params, prompt_len
return False
return False

def rhyme_and_meter_filter(this_text_sentence, target_rhyme, target_meter, probs, params):
    """Return the candidate next tokens that fit the upcoming meter and rhyme.

    This is a search speed-up, not a perfect filter: the tokens it returns are
    (usually) compatible with the remaining constraints of the line.

    Args:
        this_text_sentence: Decoded text of the line generated so far.
        target_rhyme: Text whose last word the line has to rhyme with. An
            empty string or "!" means there is no rhyme constraint.
        target_meter: Stress pattern of the complete line. Slices of it are
            used as keys into the module-level ``stress_tokens`` table.
        probs: Next-token probabilities indexed by token id. MODIFIED IN
            PLACE: every token ruled out by the filters is set to 0.
        params: Object providing ``ultimate_expansion``,
            ``penultimate_expansion`` and ``other_expansion``.

    Returns:
        A list of ``(token_id, probability)`` tuples sorted by decreasing
        probability with zero-probability entries removed. The list is empty
        when the text already covers the whole target meter.
    """
    # HACK (kept from the original): a random start offset into the ranked
    # list helps prevent repetition.
    # NOTE(review): when only one or two candidates survive the filters, a
    # non-zero offset can skip all of them and yield an empty result.
    offset = randint(0, 2)
    this_meter = text_to_meter(this_text_sentence, stress_dictionary)
    syllables_so_far = len(this_meter)
    line_length = len(target_meter)
    # Size the loops from the data instead of hard-coding the GPT-2 vocabulary.
    vocab_size = len(probs)

    # --- meter filter ---
    # Look at most three stresses ahead; slicing clamps at the end of the line.
    next_stresses = target_meter[syllables_so_far:syllables_so_far + 3]
    if len(next_stresses) == 0:
        return []
    # Everything EXCEPT tokens with an acceptable stress pattern or acceptable
    # punctuation is zeroed out.
    meter_safe = set(stress_tokens[next_stresses]).union(acceptable_punctuation)
    for token in range(vocab_size):
        if token not in meter_safe:
            probs[token] = 0

    # --- rhyme filter ---
    xprint("meter_length = ", end = "")
    xprint(syllables_so_far)

    # Extract the word to rhyme with. split() without arguments ignores
    # trailing whitespace, and rstrip() drops any run of closing punctuation,
    # so inputs such as "day " or "day..." no longer raise or miss the lookup.
    rhyme_word = ""
    if len(target_rhyme) > 0 and target_rhyme != "!":
        rhyme_words = target_rhyme.split()
        if rhyme_words:
            rhyme_word = rhyme_words[-1].lower().rstrip("!.,;:?-")

    if rhyme_word:
        xprint("target rhyme =",end="")
        xprint(rhyme_word)
        these_rhyming_tokens = rhyming_tokens[rhyme_word]
        xprint(tokenizer.decode(these_rhyming_tokens))
        rhyme_safe = None
        if syllables_so_far == line_length - 1:
            # One syllable left: only a rhyming token can finish the line.
            rhyme_safe = set(these_rhyming_tokens)
        elif syllables_so_far == line_length - 2:
            # Two syllables left: either a rhyming word, or a one-syllable
            # word which could be followed by a rhyming word, is okay.
            rhyme_safe = set(syllable_tokens[1]).union(these_rhyming_tokens)
        if rhyme_safe is not None:
            for token in range(vocab_size):
                if token not in rhyme_safe:
                    probs[token] = 0
        expansion_limit = params.ultimate_expansion
        part_label = "PART 1"
    elif syllables_so_far > line_length - 3:
        expansion_limit = params.penultimate_expansion
        part_label = "PART 2"
    elif syllables_so_far < 1:
        # First token of an unrhymed line: keep every candidate, no offset.
        expansion_limit = None
        part_label = "PART 3"
    else:
        expansion_limit = params.other_expansion
        part_label = "PART 4"

    # Rank once, after all zeroing is done, then cut to the branch's budget.
    sorted_probability_list = sorted(enumerate(probs), key=lambda x: x[1], reverse=True)
    if expansion_limit is None:
        short_probability_list = sorted_probability_list
    else:
        short_probability_list = sorted_probability_list[offset:expansion_limit + offset]
    xprint(part_label)

    # Tokens zeroed by the filters must never be offered as candidates.
    short_probability_list = [entry for entry in short_probability_list if entry[1] != 0]
    xprint("short prob list len = ", end =" ")
    xprint(len(short_probability_list))

    return short_probability_list

def expand_node(sentence, past):
#finds probabilities for the next token using gpt-2
#finds probabilities for the next token using gpt-2. This is the only computationally expensive operation in the program.
global model
if past == None:
input_ids = torch.tensor(sentence).unsqueeze(0)
Expand All @@ -340,7 +335,7 @@ def expand_node(sentence, past):
inputs = {'input_ids': input_ids}
with torch.no_grad():
logits, past = model(**inputs, past=past)
logits[0][0][50256]=-math.inf
logits[0][0][50256]=-math.inf # no <end of text> token
logits = logits[:, -1, :]
probs = F.softmax(logits, dim=-1).tolist()[0]
return (probs, past)
Expand Down

0 comments on commit 71d9024

Please sign in to comment.