Skip to content

Commit

Permalink
feat(api): add range syntax to expand numbered tokens (#179)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Mar 8, 2023
1 parent 66c4248 commit 0a4f83a
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions api/onnx_web/diffusers/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,25 @@
from logging import getLogger
from math import ceil
from re import compile
from typing import List, Optional

import numpy as np
from diffusers import OnnxStableDiffusionPipeline

logger = getLogger(__name__)


MAX_TOKENS_PER_GROUP = 77
# Matches numbered-token ranges such as ``token-{0,4}`` or ``token-{0,8,2}``.
# Group 1 is the base token, groups 2 and 3 are the start and end numbers, and
# group 4 is the optional step. Written as a raw string so the regex escapes
# are not parsed as (invalid) string escapes; the literal braces are escaped.
PATTERN_RANGE = compile(r"(\w+)-\{(\d+),(\d+)(?:,(\d+))?\}")


def expand_prompt_ranges(prompt: str) -> str:
    """Expand ``base-{start,end[,step]}`` ranges in *prompt* into numbered tokens.

    Each range is replaced by the space-separated tokens ``base-<i>`` for every
    ``i`` in ``range(start, end, step)``; the end value is exclusive and the
    step defaults to 1. For example, ``"cat-{0,3}"`` becomes
    ``"cat-0 cat-1 cat-2"``. Text outside of ranges is returned unchanged.

    A range with a step of 0 cannot be expanded and is left in the prompt
    exactly as written, rather than raising ``ValueError``.
    """

    def expand_range(match):
        # only the step group is optional, so the default only ever fills it in
        base_token, start, end, step = match.groups(default="1")
        stride = int(step)
        if stride == 0:
            # range() rejects a zero step; keep the original text instead of crashing
            return match.group(0)

        return " ".join(
            f"{base_token}-{i}" for i in range(int(start), int(end), stride)
        )

    return PATTERN_RANGE.sub(expand_range, prompt)


def expand_prompt(
Expand All @@ -22,6 +34,7 @@ def expand_prompt(
# encoder: OnnxRuntimeModel

batch_size = len(prompt) if isinstance(prompt, list) else 1
prompt = expand_prompt_ranges(prompt)

# split prompt into 75 token chunks
tokens = self.tokenizer(
Expand Down

0 comments on commit 0a4f83a

Please sign in to comment.