Skip to content

Commit

Permalink
feat: show additional networks in client
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Mar 19, 2023
1 parent e5862d1 commit 2d11210
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 37 deletions.
18 changes: 18 additions & 0 deletions api/onnx_web/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import Literal

NetworkType = Literal["inversion", "lora"]


class NetworkModel:
name: str
type: NetworkType

def __init__(self, name: str, type: NetworkType) -> None:
self.name = name
self.type = type

def tojson(self):
return {
"name": self.name,
"type": self.type,
}
4 changes: 2 additions & 2 deletions api/onnx_web/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@
get_correction_models,
get_diffusion_models,
get_extra_strings,
get_inversion_models,
get_mask_filters,
get_network_models,
get_noise_sources,
get_upscaling_models,
)
Expand Down Expand Up @@ -111,7 +111,7 @@ def list_models(context: ServerContext):
{
"correction": get_correction_models(),
"diffusion": get_diffusion_models(),
"inversion": get_inversion_models(),
"networks": get_network_models(),
"upscaling": get_upscaling_models(),
}
)
Expand Down
43 changes: 31 additions & 12 deletions api/onnx_web/server/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from glob import glob
from logging import getLogger
from os import path
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union

import torch
import yaml
Expand All @@ -20,6 +20,7 @@
noise_source_normal,
noise_source_uniform,
)
from ..models import NetworkModel
from ..params import DeviceParams
from ..torch_before_ort import get_available_providers
from ..utils import merge
Expand Down Expand Up @@ -58,7 +59,7 @@
# loaded from model_path
correction_models: List[str] = []
diffusion_models: List[str] = []
inversion_models: List[str] = []
network_models: List[NetworkModel] = []
upscaling_models: List[str] = []

# Loaded from extra_models
Expand All @@ -81,8 +82,8 @@ def get_diffusion_models():
return diffusion_models


def get_inversion_models():
return inversion_models
def get_network_models():
return network_models


def get_upscaling_models():
Expand Down Expand Up @@ -184,10 +185,12 @@ def load_extras(context: ServerContext):
extra_strings = strings


def list_model_globs(context: ServerContext, globs: List[str]) -> List[str]:
def list_model_globs(
context: ServerContext, globs: List[str], base_path: Optional[str] = None
) -> List[str]:
models = []
for pattern in globs:
pattern_path = path.join(context.model_path, pattern)
pattern_path = path.join(base_path or context.model_path, pattern)
logger.debug("loading models from %s", pattern_path)

models.extend([get_model_name(f) for f in glob(pattern_path)])
Expand All @@ -200,9 +203,10 @@ def list_model_globs(context: ServerContext, globs: List[str]) -> List[str]:
def load_models(context: ServerContext) -> None:
global correction_models
global diffusion_models
global inversion_models
global network_models
global upscaling_models

# main categories
diffusion_models = list_model_globs(
context,
[
Expand All @@ -220,21 +224,36 @@ def load_models(context: ServerContext) -> None:
)
logger.debug("loaded correction models from disk: %s", correction_models)

upscaling_models = list_model_globs(
context,
[
"upscaling-*",
],
)
logger.debug("loaded upscaling models from disk: %s", upscaling_models)

# additional networks
inversion_models = list_model_globs(
context,
[
"inversion-*",
"*",
],
base_path=path.join(context.model_path, "inversion"),
)
logger.debug("loaded Textual Inversion models from disk: %s", inversion_models)
network_models.extend(
[NetworkModel(model, "inversion") for model in inversion_models]
)
logger.debug("loaded inversion models from disk: %s", inversion_models)

upscaling_models = list_model_globs(
lora_models = list_model_globs(
context,
[
"upscaling-*",
"*",
],
base_path=path.join(context.model_path, "lora"),
)
logger.debug("loaded upscaling models from disk: %s", upscaling_models)
logger.debug("loaded LoRA models from disk: %s", lora_models)
network_models.extend([NetworkModel(model, "lora") for model in lora_models])


def load_params(context: ServerContext) -> None:
Expand Down
2 changes: 1 addition & 1 deletion api/params.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"version": "0.8.1",
"version": "0.9.0",
"batch": {
"default": 1,
"min": 1,
Expand Down
11 changes: 9 additions & 2 deletions gui/src/client/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -181,13 +181,20 @@ export interface ReadyResponse {
ready: boolean;
}

export interface NetworkModel {
name: string;
type: 'inversion' | 'lora';
// TODO: add token
// TODO: add layer/token count
}

/**
* List of available models.
*/
export interface ModelsResponse {
diffusion: Array<string>;
correction: Array<string>;
inversion: Array<string>;
diffusion: Array<string>;
networks: Array<NetworkModel>;
upscaling: Array<string>;
}

Expand Down
48 changes: 29 additions & 19 deletions gui/src/components/control/ModelControl.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ export function ModelControl() {
/>
<QueryList
id='diffusion'
labelKey='model'
labelKey='diffusion'
name={t('modelType.diffusion')}
query={{
result: models,
Expand All @@ -55,25 +55,9 @@ export function ModelControl() {
});
}}
/>
<QueryList
id='inversion'
labelKey='model'
name={t('modelType.inversion')}
query={{
result: models,
selector: (result) => result.inversion,
}}
showEmpty={true}
value={params.inversion}
onChange={(inversion) => {
setModel({
inversion,
});
}}
/>
<QueryList
id='upscaling'
labelKey='model'
labelKey='upscaling'
name={t('modelType.upscaling')}
query={{
result: models,
Expand All @@ -88,7 +72,7 @@ export function ModelControl() {
/>
<QueryList
id='correction'
labelKey='model'
labelKey='correction'
name={t('modelType.correction')}
query={{
result: models,
Expand All @@ -113,5 +97,31 @@ export function ModelControl() {
}}
/>}
/>
<QueryList
id='inversion'
labelKey='inversion'
name={t('modelType.inversion')}
query={{
result: models,
selector: (result) => result.networks.filter((network) => network.type === 'inversion').map((network) => network.name),
}}
value={params.correction}
onChange={(correction) => {
// noop
}}
/>
<QueryList
id='lora'
labelKey='lora'
name={t('modelType.lora')}
query={{
result: models,
selector: (result) => result.networks.filter((network) => network.type === 'lora').map((network) => network.name),
}}
value={params.correction}
onChange={(correction) => {
// noop
}}
/>
</Stack>;
}
2 changes: 1 addition & 1 deletion gui/src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ export interface Config<T = ClientParams> {
}

export const IMAGE_FILTER = '.bmp, .jpg, .jpeg, .png';
export const PARAM_VERSION = '>=0.4.0';
export const PARAM_VERSION = '>=0.9.0';

export const STALE_TIME = 300_000; // 5 minutes
export const POLL_TIME = 5_000; // 5 seconds
Expand Down

0 comments on commit 2d11210

Please sign in to comment.