Skip to content

Commit

Permalink
feat(api): add support for negative embeds (#348)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Apr 22, 2023
1 parent 9e9feb2 commit 7b0095a
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
9 changes: 9 additions & 0 deletions api/onnx_web/diffusers/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ def parse_prompt(params: ImageParams) -> Tuple[List[Tuple[str, float]], List[Tup
prompt, inversions = get_inversions_from_prompt(prompt)
params.prompt = prompt

if params.input_negative_prompt is not None:
neg_prompt, neg_loras = get_loras_from_prompt(params.input_negative_prompt)
neg_prompt, neg_inversions = get_inversions_from_prompt(neg_prompt)
params.negative_prompt = neg_prompt

# TODO: check whether these need to be * -1
loras.extend(neg_loras)
inversions.extend(neg_inversions)

return loras, inversions


Expand Down
5 changes: 5 additions & 0 deletions api/onnx_web/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ class ImageParams:
batch: int
control: Optional[NetworkModel]
input_prompt: str
input_negative_prompt: str

def __init__(
self,
Expand All @@ -189,6 +190,7 @@ def __init__(
batch: int = 1,
control: Optional[NetworkModel] = None,
input_prompt: Optional[str] = None,
input_negative_prompt: Optional[str] = None,
) -> None:
self.model = model
self.pipeline = pipeline
Expand All @@ -202,6 +204,7 @@ def __init__(
self.batch = batch
self.control = control
self.input_prompt = input_prompt or prompt
self.input_negative_prompt = input_negative_prompt or negative_prompt

def lpw(self):
return self.pipeline == "lpw"
Expand All @@ -220,6 +223,7 @@ def tojson(self) -> Dict[str, Optional[Param]]:
"batch": self.batch,
"control": self.control.name if self.control is not None else "",
"input_prompt": self.input_prompt,
"input_negative_prompt": self.input_negative_prompt,
}

def with_args(self, **kwargs):
Expand All @@ -236,6 +240,7 @@ def with_args(self, **kwargs):
kwargs.get("batch", self.batch),
kwargs.get("control", self.control),
kwargs.get("input_prompt", self.input_prompt),
kwargs.get("input_negative_prompt", self.input_negative_prompt),
)


Expand Down

0 comments on commit 7b0095a

Please sign in to comment.