In [1]:
import sys; sys.path.insert(0,'..')
from bardbot import *

In [2]:
def generate(
        s,
        max_new_tokens=20,
        top_k=50,
        top_p=0.95,
        num_return_sequences=1,
        **kwargs
    ):
    """Sample continuations of the prompt `s` from the language model.

    Yields only the newly generated text of each sample; the prompt itself is
    not repeated in the yielded strings.

    Args:
        s: prompt string.
        max_new_tokens: maximum number of tokens sampled per continuation.
        top_k, top_p: sampling truncation settings passed to `model.generate`.
        num_return_sequences: how many independent samples to draw.
        **kwargs: extra keyword arguments forwarded to `model.generate`; they
            override the defaults set here (e.g. `no_repeat_ngram_size`).

    Yields:
        str: the decoded continuation of each sample, in order.
    """
    tokenizer = get_tokenizer()
    model = get_model()
    inputs = tokenizer(s, return_tensors="pt")
    # Each returned row starts with the prompt's token ids (decoder-only
    # generation; the output shows a GPT-2 model being loaded), so the
    # continuation begins at this offset.
    prompt_len = inputs["input_ids"].shape[-1]

    gen_kwargs = dict(
        do_sample=True,
        top_k=top_k,
        top_p=top_p,
        num_return_sequences=num_return_sequences,
        max_new_tokens=max_new_tokens,
        no_repeat_ngram_size=2,
    )
    # Caller-supplied options win over the defaults above.
    gen_kwargs.update(kwargs)
    generation_output = model.generate(**inputs, **gen_kwargs)

    for out in generation_output:
        # Decode only the new ids: slicing the fully decoded text by len(s)
        # breaks whenever decode() does not reproduce the prompt exactly.
        # skip_special_tokens drops end-of-text padding in finished samples.
        yield tokenizer.decode(out[prompt_len:], skip_special_tokens=True)

In [3]:
# s='From fairest creatures we desire increase'
# for x in generate(s,num_return_sequences=10):
#     print([x])

In [4]:
def get_num_sylls(w):
    """Syllable count for a single token `w`.

    Every non-letter character (punctuation, digits, apostrophes, ...) is
    removed first. A token with no letters left counts as 0 syllables;
    otherwise the count comes from `p.Word(...).num_syll`.
    """
    letters_only = ''.join(filter(str.isalpha, w.strip()))
    return p.Word(letters_only).num_syll if letters_only else 0

def get_lineobj(line):
    """Return a parsed line object for `line`.

    Text (any `str`, including subclasses) is parsed with `p.Text` and the
    first resulting line is returned. Anything else is assumed to already be
    a parsed line object and is returned unchanged.
    """
    if isinstance(line, str):
        return p.Text(line).lines()[0]
    return line

def rhymes(line1,line2,max_rhyme_dist=5):
    """Return True if the final words of two lines count as a rhyme.

    Args:
        line1, line2: a line as text (parsed via `get_lineobj`) or an
            already-parsed line object.
        max_rhyme_dist: largest `rime_distance` between the two lines that is
            still accepted as a rhyme.

    Returns:
        False when both lines end on the same word. Words are compared
        case-insensitively and ignoring non-letter characters, so "Day," and
        "day." count as the same word. Otherwise, whether
        `line1.rime_distance(line2) <= max_rhyme_dist`.
    """
    line1 = get_lineobj(line1)
    line2 = get_lineobj(line2)

    def _word_key(lineobj):
        # Letters only, lower-cased: capitalization or trailing punctuation
        # must not let a repeated word pass as a rhyme.
        return ''.join(ch for ch in lineobj.words()[-1].token if ch.isalpha()).lower()

    if _word_key(line1) == _word_key(line2):
        return False
    return line1.rime_distance(line2) <= max_rhyme_dist

def ends_on_stress(line):
    """Whether the final syllable of `line` is stressed.

    `line` may be text or an already-parsed line object (see get_lineobj).
    The test is that the last syllable's 'prom.stress' feature is positive.
    """
    last_syllable = get_lineobj(line).syllables()[-1]
    return last_syllable.feats['prom.stress'] > 0

In [5]:
def _fit_syllables(text, num_sylls):
    """Collect leading word tokens of `text` totalling exactly `num_sylls` syllables.

    Whitespace-split tokens with no letters are skipped. Tokens are taken in
    order until the next one would overshoot the budget. Returns the kept
    tokens, or None when their syllable total is not exactly `num_sylls`.
    """
    kept = []
    total = 0
    for tok in text.strip().split():
        if not any(ch.isalpha() for ch in tok):
            continue
        n = get_num_sylls(tok)
        if total + n > num_sylls:
            break
        kept.append(tok)
        total += n
    return kept if total == num_sylls else None


def _is_well_formed(oline, is_last):
    """Punctuation filter for a non-empty, syllable-fitted candidate line.

    A line is rejected if any of the following holds:
    - it contains a double quote or a bracket;
    - its final character is a letter;
    - it has '.' or '?' anywhere before the final character;
    - it is the poem's last line and does not end in '.', '!' or '?'.
    """
    if any(ch in oline for ch in '"()[]'):
        return False
    if oline[-1].isalpha():
        return False
    if '.' in oline[:-1] or '?' in oline[:-1]:
        return False
    if is_last and oline[-1] not in '.!?':
        return False
    return True


def _print_poem_line(text, letter, lineno, ntries):
    """Print one accepted line as an aligned row: text, rhyme letter, 1-based number, [tries]."""
    print(f'{text:<60} {letter:<10} {lineno:<10} [{ntries}]')


def gen_poem(
        prompt="Let us",
        num_lines=14,
        num_sylls=10,
        rhyme_scheme='abab cdcd efef gg',
        max_tries=50,
        num_return_sequences=10,
        max_rhyme_dist=3,
        max_new_tokens=20):
    """Generate a rhymed, syllable-counted poem line by line, printing progress.

    Each pass samples `num_return_sequences` continuations of the poem so far
    (the bare `prompt` for the first line). It accepts the first candidate
    that passes all of these checks:
    - exactly `num_sylls` syllables;
    - the punctuation filters;
    - ends on a stressed syllable;
    - rhymes with its partner line, if the rhyme scheme gives it one.

    Each accepted line is printed as it is added. If a line still cannot
    rhyme after `max_tries` sampled candidates, the partner line and
    everything after it are discarded and regenerated.

    Args:
        prompt: text the first line must start with.
        num_lines: number of lines to produce.
        num_sylls: exact syllable count of every line.
        rhyme_scheme: one letter per line; spaces are ignored. Lines sharing a
            letter must rhyme with the first line carrying that letter.
        max_tries: candidates sampled for a line before a rhyme failure
            triggers backtracking.
        num_return_sequences: candidates sampled per `generate` call.
        max_rhyme_dist: threshold passed to `rhymes`.
        max_new_tokens: tokens sampled per candidate continuation.

    Returns:
        list[str]: the accepted lines of the finished poem.

    Raises:
        ValueError: if `num_sylls < 1`, or if the rhyme scheme (spaces
            removed) has fewer letters than `num_lines`.

    NOTE: backtracking is only triggered by rhyme failures. A line with no
    earlier rhyme partner, or whose candidates never pass the syllable,
    punctuation or stress filters, is retried without limit.
    """
    from collections import Counter

    scheme = rhyme_scheme.replace(' ', '')
    if num_sylls < 1:
        raise ValueError(f'num_sylls must be at least 1, got {num_sylls}')
    if len(scheme) < num_lines:
        raise ValueError(
            f'rhyme_scheme has {len(scheme)} letters but num_lines={num_lines}')

    lines = []
    tries = Counter()  # line index -> candidates sampled for that line so far
    while len(lines) < num_lines:
        line_i = len(lines)
        letter = scheme[line_i]
        is_last = line_i + 1 == num_lines
        # Partner = first line with the same letter; it constrains this line
        # only if it comes earlier. Parse it once per pass, not per candidate.
        partner_i = scheme.index(letter)
        partner_obj = p.Text(lines[partner_i]).lines()[0] if partner_i < line_i else None

        txt = '\n'.join(lines) if lines else prompt
        for cont in generate(txt, max_new_tokens=max_new_tokens,
                             num_return_sequences=num_return_sequences):
            tries[line_i] += 1
            # generate() yields only the continuation; the first line must
            # include the prompt itself.
            if not lines:
                cont = prompt + cont
            toks = _fit_syllables(cont, num_sylls)
            if toks is None:
                continue
            oline = ' '.join(toks)
            if not _is_well_formed(oline, is_last):
                continue
            oline = oline[0].upper() + oline[1:]
            lineobj = p.Text(oline).lines()[0]
            if not ends_on_stress(lineobj):
                continue
            if partner_obj is not None and not rhymes(lineobj, partner_obj, max_rhyme_dist=max_rhyme_dist):
                if tries[line_i] >= max_tries:
                    # Out of attempts: drop the partner line and everything
                    # after it, keep try counts only for the surviving lines,
                    # and let the while loop regenerate from the partner.
                    lastw = get_lineobj(lines[partner_i]).words()[-1].token
                    print(f'!! Ugh! I tried {tries[line_i]} times, but I can\'t rhyme on {lastw}! Going back to rewrite line {partner_i+1}')
                    print()
                    lines = lines[:partner_i]
                    tries = Counter({li: tries[li] for li in range(len(lines))})
                    for li, lx in enumerate(lines):
                        _print_poem_line(lx, scheme[li], li + 1, tries[li])
                    break
                continue
            _print_poem_line(oline, letter, line_i + 1, tries[line_i])
            lines.append(oline)
            break
    return lines

In [6]:
# Build a 14-line poem in the scheme abab cdcd efef gg, starting from the
# prompt below. With max_rhyme_dist=0 only candidates whose rime distance is
# <= 0 count as rhymes. Once 1000 candidates have been sampled for a line,
# the next rhyme failure makes gen_poem go back and rewrite the partner line.
gen_poem(
    prompt='Look in thy glass and',
    rhyme_scheme='abab cdcd efef gg',
    num_lines=14,
    num_return_sequences=10,
    max_tries=1000,
    max_rhyme_dist=0,
)

Loading GPT2 model
Look in thy glass and let me see the truth,                  a          1          [12]
And thou shalt see all truth that I have heard.              b          2          [20]
!! Ugh! I tried 1028 times, but I can't rhyme on truth! Going back to rewrite line 1

Look in thy glass and see if we can talk.                    a          1          [26]
Say, 'I'm going to pick up our guitar,                       b          2          [49]
!! Ugh! I tried 1002 times, but I can't rhyme on talk! Going back to rewrite line 1

Look in thy glass and ye will see the same,                  a          1          [12]
And in the midst of the wine is the sun.                     b          2          [17]
!! Ugh! I tried 1010 times, but I can't rhyme on same! Going back to rewrite line 1

Look in thy glass and it shall come to pass.                 a          1          [8]
But the world has no gods left, not by law,                  b          2          [90]
!! Ugh! I tried 1007 t