Skip to content

Commit

Permalink
UD in word service, bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
turtlesoupy committed May 17, 2020
1 parent 3d4aa3e commit 95eb2da
Show file tree
Hide file tree
Showing 14 changed files with 217 additions and 88 deletions.
123 changes: 72 additions & 51 deletions notebooks/urban_dictionary.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
"from collections import OrderedDict\n",
"import torch\n",
"from transformers import AutoModelWithLMHead, AutoTokenizer\n",
"import copy"
"import copy\n",
"from word_generator import WordGenerator"
]
},
{
Expand Down Expand Up @@ -118,37 +119,22 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# nlp = stanza.Pipeline(lang='en', processors='tokenize,mwt,pos', use)\n",
"tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n",
"tokenizer.add_special_tokens(datasets.SpecialTokens.special_tokens_dict())\n",
"blacklist = datasets.Blacklist.load(\"/mnt/evo/projects/title-maker-pro/models/blacklist_urban_dictionary.pickle\")\n",
"model = AutoModelWithLMHead.from_pretrained(\"/mnt/evo/projects/title-maker-pro/models/urban_dictionary_250_cleaned_lr_00005_b9_seed4/checkpoint-140000\").to(\"cuda:0\")"
"# model = AutoModelWithLMHead.from_pretrained(\"/mnt/evo/projects/title-maker-pro/models/urban_dictionary_250_cleaned_lr_00005_b9_seed4/checkpoint-140000\").to(\"cuda:0\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2727cb2f87b84946a498bf5075aa84ae",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"outputs": [],
"source": [
"words, stats = datasets.UrbanDictionaryDataset.generate_words(\n",
" tokenizer, model,\n",
Expand All @@ -169,33 +155,9 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Filter 'chnk' removed 0.08%\n",
"Filter 'cnt' removed 1.31%\n",
"Filter 'fg' removed 0.80%\n",
"Filter 'fggot' removed 0.75%\n",
"Filter 'ghetto' removed 0.57%\n",
"Filter 'indian' removed 0.18%\n",
"Filter 'mex' removed 0.15%\n",
"Filter 'ngga' removed 1.25%\n",
"Filter 'nig' removed 0.42%\n",
"Filter 'pki' removed 0.00%\n",
"Filter 'rape' removed 0.11%\n",
"Filter 'sknk' removed 0.32%\n",
"Filter 'slap' removed 0.21%\n",
"Filter 'ultra_bad_def' removed 0.24%\n",
"Filter 'ultra_bad_example' removed 0.03%\n",
"Filter 'ultra_bad_word' removed 0.14%\n",
"Total removed 6.58%\n"
]
}
],
"outputs": [],
"source": [
"import os\n",
"from title_maker_pro.bad_words import ULTRA_BAD_REGEX\n",
Expand Down Expand Up @@ -283,29 +245,88 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"\n",
"\n",
"Downloading https://raw.githubusercontent.com/stanfordnlp/stanza-resources/master/resources_1.0.0.json: 116kB [00:00, 9.88MB/s] \u001b[A\u001b[A\u001b[A\n",
"2020-05-16 20:56:47 INFO: Downloading default packages for language: en (English)...\n",
"2020-05-16 20:56:48 INFO: File exists: /home/tdimson/stanza_resources/en/default.zip.\n",
"2020-05-16 20:56:51 INFO: Finished downloading models and saved to /home/tdimson/stanza_resources.\n",
"2020-05-16 20:56:51 WARNING: Can not find mwt: default from official model list. Ignoring it.\n",
"2020-05-16 20:56:51 INFO: Loading these models for language: en (English):\n",
"=======================\n",
"| Processor | Package |\n",
"-----------------------\n",
"| tokenize | ewt |\n",
"| pos | ewt |\n",
"=======================\n",
"\n",
"2020-05-16 20:56:51 INFO: Use device: gpu\n",
"2020-05-16 20:56:51 INFO: Loading: tokenize\n",
"2020-05-16 20:56:51 INFO: Loading: pos\n",
"2020-05-16 20:56:52 INFO: Done loading processors!\n"
]
}
],
"source": [
"wg = WordGenerator(\n",
" device=\"cuda:0\",\n",
" forward_model_path=\"/mnt/evo/projects/title-maker-pro/models/urban_dictionary_250_cleaned_lr_00005_b9_seed4/checkpoint-140000\",\n",
" inverse_model_path=None,\n",
" blacklist_path=\"/mnt/evo/projects/title-maker-pro/models/blacklist.pickle\",\n",
" quantize=False,\n",
" is_urban=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "460b19e3d1284415b1a1545225cf406f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"2"
"HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"GeneratedWord(word='cummy', pos=None, topic=None, definition='n: very big; enormous', example='That lady had a cummy penis!!', decoded='<|bod|> cummy <|bd|> n: very big; enormous <|be|> That lady had a cummy penis!! <|eod|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|>', decoded_tokens=[50257, 66, 13513, 50260, 77, 25, 845, 1263, 26, 9812, 50261, 2504, 10846, 550, 257, 10973, 1820, 16360, 3228, 50258, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": []
"source": [
"wg.generate_definition(\"cummy\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from word_service.word_service_proto import wordservice_pb"
"from word_service.word_service_proto import wordservice_pb"
]
}
],
Expand Down
1 change: 1 addition & 0 deletions scripts/start_word_service.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ export ASSET_PATH=/mnt/evo/projects/title-maker-pro
python word_service/wordservice_server.py \
--forward-model-path $ASSET_PATH/models/forward-dictionary-model-v1 \
--inverse-model-path $ASSET_PATH/models/inverse-dictionary-model-v1 \
--forward-urban-model-path $ASSET_PATH/models/forward-urban-dictionary-model-v1 \
--blacklist-path $ASSET_PATH/models/blacklist.pickle \
--quantize \
--device cpu
48 changes: 44 additions & 4 deletions title_maker_pro/word_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@


class WordGenerator:
def __init__(self, forward_model_path, inverse_model_path, blacklist_path, quantize=False, device=None):
def __init__(self, forward_model_path, inverse_model_path, blacklist_path, quantize=False, device=None, is_urban=False):
if not device:
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
else:
self.device = torch.device(device)

self.is_urban = is_urban

stanza.download("en")
self.stanza_pos_pipeline = stanza.Pipeline(
lang="en", processors="tokenize,mwt,pos", use_gpu=("cpu" not in self.device.type)
Expand All @@ -40,13 +42,20 @@ def __init__(self, forward_model_path, inverse_model_path, blacklist_path, quant
self.forward_model = ml(AutoModelWithLMHead, forward_model_path).to(self.device)
logger.info("Loaded forward model")

logger.info(f"Loading inverse model from {inverse_model_path}")
self.inverse_model = ml(AutoModelWithLMHead, inverse_model_path).to(self.device)
logger.info("Loaded inverse model")
if inverse_model_path:
logger.info(f"Loading inverse model from {inverse_model_path}")
self.inverse_model = ml(AutoModelWithLMHead, inverse_model_path).to(self.device)
logger.info("Loaded inverse model")
else:
self.inverse_model = None
logger.info(f"Skipping inverse model")

self.approx_max_length = 250

def generate_word(self, user_filter=None):
if self.is_urban:
raise RuntimeError("Urban dataset not supported yet")

expanded, _ = datasets.ParsedDictionaryDefinitionDataset.generate_words(
self.tokenizer,
self.forward_model,
Expand All @@ -68,7 +77,35 @@ def generate_word(self, user_filter=None):
def probably_real_word(self, word):
    # Membership in the blacklist is used as a proxy for "this word already
    # exists" — NOTE(review): assumes the blacklist was built from real
    # dictionary/Urban Dictionary entries; confirm against Blacklist.load.
    return self.blacklist.contains(word)

def generate_urban_definition(self, word, user_filter=None):
    """Generate one Urban-Dictionary-style definition for ``word``.

    Prompts the forward model with the word and the definition separator,
    samples a handful of candidates in a single pass, and returns:
    the first candidate that passed all filters; otherwise the
    highest-scoring "viable" candidate from the generation stats;
    otherwise ``None``.
    """
    # Prompt shape: <BOS> + word + definition-separator token.
    prompt = f"{datasets.SpecialTokens.BOS_TOKEN}{word}{datasets.SpecialTokens.DEFINITION_SEP}"
    sample_args = dict(top_k=50, num_return_sequences=5, max_length=self.approx_max_length, do_sample=True,)

    hits, stats = datasets.UrbanDictionaryDataset.generate_words(
        self.tokenizer,
        self.forward_model,
        num=1,
        prefix=prompt,
        max_iterations=1,
        generation_args=sample_args,
        dedupe_titles=False,
        user_filter=user_filter,
        filter_proper_nouns=False,
        use_custom_generate=True,
    )

    logger.info(f"Urban generation stats: {stats} (found {len(hits)} true and {len(stats.viable_candidates)} viable)")

    if hits:
        return hits[0]
    if stats.viable_candidates:
        # Nothing passed every filter: fall back to the best-scoring
        # candidate that was at least viable.
        return max(stats.viable_candidates, key=lambda c: c.score).candidate
    return None

def generate_definition(self, word, user_filter=None):
if self.is_urban:
return self.generate_urban_definition(word, user_filter)

prefix = f"{datasets.SpecialTokens.BOS_TOKEN}{word}{datasets.SpecialTokens.POS_SEP}"
expanded, stats = datasets.ParsedDictionaryDefinitionDataset.generate_words(
self.tokenizer,
Expand Down Expand Up @@ -127,6 +164,9 @@ def generate_definition(self, word, user_filter=None):
return None

def generate_word_from_definition(self, definition, user_filter=None):
if self.is_urban:
raise RuntimeError("Urban dataset not supported yet")

# Data peculiarity: definitions ending in a period are out of domain
prefix = f"{datasets.SpecialTokens.BOS_TOKEN}{definition.rstrip('. ')}{datasets.SpecialTokens.DEFINITION_SEP}"
expanded, stats = datasets.InverseParsedDictionaryDefinitionDataset.generate_words(
Expand Down
Binary file added website/data/words_ud_filtered.enc.gz
Binary file not shown.
Binary file added website/data/words_ud_unfiltered.enc.gz
Binary file not shown.
21 changes: 19 additions & 2 deletions website/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
import os
import json
import jinja2
Expand All @@ -21,6 +22,9 @@
from async_lru import alru_cache
from title_maker_pro.bad_words import grawlix
from pathlib import Path
from jinja2 import evalcontextfilter
from markupsafe import Markup, escape


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -67,6 +71,18 @@ def _grpc_nonretriable(e: GRPCError):
)


# Two or more consecutive newlines (any platform convention) mark a
# paragraph break.
_paragraph_re = re.compile(r'(?:\r\n|\r|\n){2,}')


@evalcontextfilter
def nl2br(eval_ctx, value):
    """Jinja2 filter: HTML-escape ``value`` and render single newlines as <br>.

    Paragraph breaks (blank lines) are kept as ``\\n\\n`` separators; when the
    evaluation context has autoescaping enabled the result is wrapped in
    ``Markup`` so the inserted ``<br>`` tags are not re-escaped.
    """
    paragraphs = _paragraph_re.split(escape(value))
    converted = u'\n\n'.join(p.replace('\n', Markup('<br>\n')) for p in paragraphs)
    if eval_ctx.autoescape:
        return Markup(converted)
    return converted


class Handlers:
def __init__(
self,
Expand Down Expand Up @@ -156,15 +172,15 @@ def _word_from_url(self, request):
word_dict = json.loads(base64.urlsafe_b64decode(payload).decode("utf-8"))
w = words.Word.from_dict(word_dict)

if w.dataset_type and w.dataset_type != request.dataset_type:
if w.dataset_type and w.dataset_type != request.dataset:
raise _json_error(web.HTTPBadRequest, "Mismatched word dataset")

return w

@aiohttp_jinja2.template("index.jinja2")
async def word(self, request):
w = self._word_from_url(request)
return self._index_response(w, word_in_title=True)
return self._index_response(request, w, word_in_title=True)

async def shorten_word_url(self, request):
w = self._word_from_url(request)
Expand Down Expand Up @@ -272,6 +288,7 @@ def app(handlers=None):
"remove_period": lambda x: x.rstrip("."),
"escape_double": lambda x: x.replace('"', r'\"'),
"strip_quotes": lambda x: x.strip('"'),
"nl2br": nl2br,
},
)
return app
Expand Down
2 changes: 1 addition & 1 deletion website/static/bundle.js

Large diffs are not rendered by default.

11 changes: 9 additions & 2 deletions website/static_src/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@ const copy = require('clipboard-copy');

/**
 * Build the shareable URL for a generated word.
 *
 * @param {string} word - The word to link to (URI-encoded into the path).
 * @param {string} permalink - Optional permalink suffix appended as a path segment.
 * @param {boolean} relative - When true, omit the https://www.thisworddoesnotexist.com origin.
 * @returns {string} The word URL, carrying over dataset/secret query params when present.
 */
function wordURL(word, permalink, relative) {
  const base = relative ? "" : "https://www.thisworddoesnotexist.com";
  let query_params = "";
  // Feature-detect URLSearchParams for old browsers. BUG FIX: `typeof x`
  // always returns a string, so the original `!== undefined` comparison was
  // always true and the guard never protected anything; compare against the
  // string "undefined" instead.
  if (typeof URLSearchParams !== "undefined") {
    const params = new URLSearchParams(window.location.search);
    // Propagate the hidden-dataset access params so navigation stays in
    // the same dataset view.
    if (params.has("dataset") && params.has("secret")) {
      query_params = `?dataset=${params.get("dataset")}&secret=${params.get("secret")}`;
    }
  }
  return `${base}/w/${encodeURIComponent(word)}` + (permalink ? `/${permalink}` : "") + query_params;
}


Expand All @@ -28,7 +35,7 @@ function syncToWord(word, permalink, pushHistory) {
let wordExistsLinkEl = document.getElementById("word-exists-link");

posEl.innerHTML = word.pos;
if (!posEl.innerHTML.endsWith("]")) {
if (word.pos && !word.pos.endsWith("]")) {
posEl.innerHTML += ".";
}

Expand Down
Loading

0 comments on commit 95eb2da

Please sign in to comment.