In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("../")
sys.path.append("../website")

In [None]:
import re
import stanza
from collections import Counter
import itertools
from title_maker_pro import datasets
import pickle
import torch
from transformers import AutoModelWithLMHead, AutoTokenizer
from words import WordIndex, Word
# stanza.download('en')  

In [None]:
def print_words(words, f):
    """Write a human-readable listing of word entries to the file-like object `f`.

    Each entry produces:
      * a header line: the word, then "/pos/" and "[topic]" when those attributes are truthy;
      * a tab-indented definition line, suffixed with " |n| " when the example is None;
      * a tab-indented, double-quoted example line when the example is truthy;
      * a blank separator line.

    Args:
        words: iterable of objects with `word`, `pos`, `topic`, `definition` and
            `example` attributes.
        f: writable text stream handed to `print(..., file=f)`.
    """
    for entry in words:
        # Optional header fragments collapse to empty lists when the attribute is falsy.
        pos_tag = [f"/{entry.pos}/"] if entry.pos else []
        topic_tag = [f"[{entry.topic}]"] if entry.topic else []
        print(" ".join([entry.word, *pos_tag, *topic_tag]), file=f)

        definition_line = f"\t{entry.definition}"
        if entry.example is None:
            # Marker flags entries that carry no example at all.
            definition_line += " |n| "
        print(definition_line, file=f)

        if entry.example:
            print(f'\t"{entry.example}"', file=f)

        print("", file=f)

In [None]:
# Locations of the trained artifacts, hoisted so they can be changed in one place.
# The resulting strings are identical to the paths previously written inline.
MODELS_DIR = "/mnt/evo/projects/title-maker-pro/models"
BLACKLIST_PATH = f"{MODELS_DIR}/blacklist2.pickle"
CHECKPOINT_DIR = f"{MODELS_DIR}/en_dictionary_parsed_lr_00001_creativity/checkpoint-120000/"
# Use the first GPU when one is available; otherwise fall back to CPU instead of failing in `.to(...)`.
# NOTE(review): downstream generation code may still assume a GPU — confirm before relying on the CPU path.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# Stanza pipeline restricted to tokenisation, multi-word-token expansion and POS tagging.
nlp = stanza.Pipeline(lang='en', processors='tokenize,mwt,pos')
# GPT-2 tokenizer extended with the project's special tokens.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens(datasets.SpecialTokens.special_tokens_dict())
# NOTE(review): the ".pickle" suffix suggests this is unpickled — only load blacklist files you trust.
blacklist = datasets.Blacklist.load(BLACKLIST_PATH)
model = AutoModelWithLMHead.from_pretrained(CHECKPOINT_DIR).to(DEVICE)

In [None]:
# Time evaluate_creativity on the loaded tokenizer/model/blacklist
# (positional arguments 100 and 50, generation capped at max_length=512).
%timeit datasets.ParsedDictionaryDefinitionDataset.evaluate_creativity(tokenizer, model, blacklist, 100, 50, max_length=512)

In [None]:
def no_weird(w):
    """Decide whether a generated entry is clean enough to keep.

    An entry passes only when all of the following hold:
      * the word is non-empty and does not end in "-";
      * neither the definition, the example, nor the part of speech (when set)
        contains the "<|" special-token prefix;
      * the word has at most 3 whitespace-separated tokens;
      * the definition and the example each have at least 3 whitespace-separated tokens.

    A missing (None) or empty word, definition or example fails the filter
    rather than raising, so one malformed entry cannot abort a long generation run.

    Args:
        w: object exposing `word`, `definition`, `example` and `pos` attributes.

    Returns:
        bool: True when the entry should be kept.
    """
    word, definition, example = w.word, w.definition, w.example
    # An empty definition/example could never reach the 3-token minimum anyway;
    # rejecting up front also covers None and the empty-word indexing case.
    if not word or not definition or not example:
        return False
    return (
        word[-1] != "-"
        and "<|" not in definition
        and "<|" not in example
        and (not w.pos or "<|" not in w.pos)
        and len(word.split()) <= 3
        and len(definition.split()) >= 3
        and len(example.split()) >= 3
    )
         
def go(**kwargs):
    """Generate dictionary entries with this notebook's default settings.

    Wraps `datasets.ParsedDictionaryDefinitionDataset.generate_words`, passing the
    notebook-level `tokenizer`, `model`, `blacklist`, `nlp` and `no_weird` objects.

    Any keyword argument overrides the matching default (for example `go(num=100)`
    for a quick run); unknown keywords are forwarded unchanged. A `generation_args`
    mapping is merged key by key into the default sampling settings rather than
    replacing them, so `go(generation_args={"top_k": 50})` keeps the other defaults.

    Args:
        **kwargs: overrides or extra keyword arguments for `generate_words`.

    Returns:
        Whatever `generate_words` returns (unpacked elsewhere in the notebook as
        `words, stats`).
    """
    generation_args = dict(
        top_k=200,
        num_return_sequences=500,
        max_length=250,
        do_sample=True,
    )
    # Pop so the merged dict, not the caller's partial one, is what gets forwarded.
    generation_args.update(kwargs.pop("generation_args", None) or {})

    options = dict(
        num=100000,
        max_iterations=50000,
        blacklist=blacklist,
        example_match_pos_pipeline=nlp,
        generation_args=generation_args,
        filter_proper_nouns=True,
        user_filter=no_weird,
        dedupe_titles=True,
        min_definition_words=3,
    )
    # Caller-supplied values win over the defaults instead of raising a duplicate-keyword TypeError.
    options.update(kwargs)
    return datasets.ParsedDictionaryDefinitionDataset.generate_words(tokenizer, model, **options)

# words, stats = go()
# print(stats)
# print()
# print_words(words, sys.stdout)

In [None]:
# Run the generation pipeline; use_custom_generate=True is forwarded through go()'s **kwargs.
words, stats = go(use_custom_generate=True)
print(stats)
#print_words(words, sys.stdout)

In [None]:
# How many entries the generation run returned.
print(len(words))

In [None]:
# Spot-check a single candidate word against the loaded blacklist.
blacklist.contains("foolage")

In [None]:
# Number of items in the blacklist's `blacklist_set` collection.
len(blacklist.blacklist_set)

In [None]:
# Preview only the first 50 generated entries rather than dumping the whole list.
print_words(words[:50], sys.stdout)

In [None]:
import math
from transformers import activations
import transformers

def gelu_new(x):
    """Tanh-based GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).

    Args:
        x: a torch tensor.

    Returns:
        A tensor with the activation applied element-wise.
    """
    sqrt_two_over_pi = math.sqrt(2 / math.pi)
    cubic_term = 0.044715 * torch.pow(x, 3.0)
    tanh_arg = sqrt_two_over_pi * (x + cubic_term)
    return 0.5 * x * (1 + torch.tanh(tanh_arg))

# Register the local gelu_new under the 'gelu_new' key of transformers' activation table
# before the model below is loaded.
activations.ACT2FN['gelu_new'] = gelu_new

# NOTE: this rebinds `model` to a CPU copy of a different checkpoint; any later call to go()
# will pick up this model instead of the one loaded earlier in the notebook.
model = AutoModelWithLMHead.from_pretrained("../build/forward-dictionary-model-v1").to("cpu")
# Dynamic int8 quantisation of the listed module types.
# NOTE(review): depending on the installed PyTorch version, Embedding and transformers' Conv1D
# may be ignored or rejected by quantize_dynamic with dtype=qint8 — confirm which layers are
# actually converted by inspecting quantized_model.
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear, torch.nn.Embedding, transformers.modeling_utils.Conv1D}, dtype=torch.qint8
)

In [None]:
# NOTE(review): `go2` is not defined in any visible cell of this notebook, so on a fresh kernel
# this raises NameError unless it is defined elsewhere — TODO restore its definition or remove.
a = go2()

In [None]:
# Decode the first element of `a` (the result of go2()) back into text with the tokenizer.
print(tokenizer.decode(a[0]))

In [None]:
# Time a go2() call. NOTE(review): go2 has no visible definition in this notebook — TODO confirm.
%timeit go2()

In [None]:
from words import WordIndex, Word
def clean_example(w, example):
    """Normalise the casing of every occurrence of `w` inside `example`.

    Each case-insensitive match of the literal text `w` is replaced by `w` exactly
    as given. Matching is on raw substrings, not whole words.

    Args:
        w: the word whose spelling/casing should be used.
        example: the example sentence, or None when the entry has no example.

    Returns:
        The example with occurrences rewritten, or None when `example` is None.
    """
    if example is None:
        # Entries without an example are passed through instead of raising in re.sub.
        return None
    # A callable replacement is inserted verbatim; a plain string would be parsed as a
    # regex template, so a backslash in `w` would be reinterpreted or raise.
    return re.sub(re.escape(w), lambda _match: w, example, flags=re.IGNORECASE)

In [None]:
from hyphen import Hyphenator
# US-English hyphenator; its syllables() method is called on every word below.
h_en = Hyphenator('en_US')

# Re-wrap each generated entry as a website `Word`: the example is passed through
# clean_example, syllables come from the hyphenator, and every entry is marked
# probably_exists=False.
wi = WordIndex(
    [
        Word(
            word=w.word,
            definition=w.definition,
            pos=w.pos,
            topic=w.topic,
            example=clean_example(w.word, w.example),
            syllables=h_en.syllables(w.word),
            probably_exists=False,
        ) for w in words
        
    ]
)
# Written to words3.json, a different file from the words.json handled further down.
wi.dump("../website/data/words3.json")

In [None]:

# NOTE(review): the result of this call is discarded (it is not the cell's last expression),
# so it looks like leftover debugging — TODO confirm and remove.
h_en.syllables('fancccwe')
# Reload the existing website index and rebuild every entry with a cleaned example and
# freshly computed syllables.
wi2 = WordIndex.load("../website/data/words.json")
# NOTE(review): probably_exists is not passed here (the words3.json cell sets it to False),
# so each rebuilt Word gets whatever default the class declares — confirm that is intended.
wi_p = WordIndex(
    [
        Word(
            word=w.word,
            definition=w.definition,
            pos=w.pos,
            topic=w.topic,
            example=clean_example(w.word, w.example),
            syllables=h_en.syllables(w.word)
        )
        for w in wi2.words
    ]
)

In [None]:
# Overwrites the same words.json that wi2 was loaded from; keep a backup if the original matters.
wi_p.dump("../website/data/words.json")