Skip to content

Commit

Permalink
second part addressing #31
Browse files Browse the repository at this point in the history
  • Loading branch information
painebenjamin committed Jul 1, 2023
1 parent 7d3fab4 commit 3cb413c
Show file tree
Hide file tree
Showing 7 changed files with 190 additions and 22 deletions.
28 changes: 26 additions & 2 deletions src/css/08-enfugue.css
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ enfugue-application > enfugue-node-editor.image-editor {
}

enfugue-application > enfugue-node-editor.windows {
z-index: 4;
z-index: 5;
}

enfugue-menu {
Expand Down Expand Up @@ -769,7 +769,7 @@ form.model-picker div#tensorrt span.fraction {
border-radius: 20px;
position: absolute;
top: 0;
right: 13px;
right: 12.1px;
font-size: 10px;
text-align: center;
color: black;
Expand Down Expand Up @@ -813,6 +813,30 @@ ul.model-picker-list-input-view > li > span {
justify-content: space-between;
}

.additional-weights-form-view {
position: absolute;
top: 75px;
left: 35px;
width: 520px;
max-width: calc(100vw - 380px);
z-index: 4;
opacity: 0.5;
transition: opacity 0.25s ease-in-out;
}

.additional-weights-form-view:hover {
opacity: 1;
}

.additional-weights-form-view legend {
flex-direction: row-reverse;
justify-content: flex-end;
}

.additional-weights-form-view legend::after {
margin-right: 10px;
}

.page-buttons {
display: flex;
flex-flow: row nowrap;
Expand Down
5 changes: 5 additions & 0 deletions src/css/09-enfugue-nodes.css
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,11 @@ table.log-table-view tr td:nth-child(3) {
width: 80px;
}

table.installation-directory-summary-table-view tr th:nth-child(1),
table.installation-directory-summary-table-view tr td:nth-child(1) {
width: 250px;
}

input.new-user-input-view,
input.new-model-input-view,
input.upload-file-input-view {
Expand Down
34 changes: 34 additions & 0 deletions src/js/controller/common/invocation.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,40 @@ class InvocationController extends Controller {
this.kwargs.upscale_diffusion_controlnet = newControlnet;
}

/**
* @return array<string> Optional textual inversion when not using preconfigured models
*/
get inversion() {
return this.kwargs.inversion || [];
}

/**
* @param array<string> The new value of inversion for when not using preconfigured models
*/
set inversion(newInversion) {
if(!isEquivalent(this.inversion, newInversion)) {
this.publish("engineInversionChange", newInversion);
}
this.kwargs.inversion = newInversion;
}

/**
* @return array<object> Optional lora when not using preconfigured models
*/
get lora() {
return this.kwargs.lora || [];
}

/**
* @param array<object> The new value of lora for when not using preconfigured models
*/
set lora(newLora) {
if(!isEquivalent(this.lora, newLora)) {
this.publish("engineLoraChange", newLora);
}
this.kwargs.lora = newLora;
}

/**
* On initialization, create DOM elements related to invocations.
*/
Expand Down
6 changes: 5 additions & 1 deletion src/js/controller/common/model-manager.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -356,4 +356,8 @@ class ModelManagerController extends Controller {
}
}

export { ModelManagerController };
export {
ModelManagerController,
MultiLoraInputView,
MultiInversionInputView
};
95 changes: 81 additions & 14 deletions src/js/controller/common/model-picker.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,26 @@ import { Controller } from "../base.mjs";
import { TableView } from "../../view/table.mjs";
import { View } from "../../view/base.mjs";
import { FormView } from "../../view/forms/base.mjs";
import { SearchListInputView, StringInputView, SearchListInputListView } from "../../view/forms/input.mjs";
import {
SearchListInputView,
StringInputView,
SearchListInputListView
} from "../../view/forms/input.mjs";
import { MultiLoraInputView, MultiInversionInputView } from "./model-manager.mjs";
import { isEmpty, waitFor, createElementsFromString } from "../../base/helpers.mjs";
import { ElementBuilder } from "../../base/builder.mjs";

const E = new ElementBuilder();

/**
* Extend the SearchListInputListView to add additional classes
*/
class ModelPickerListInputView extends SearchListInputListView {
/**
* @var array<string> CSS classes
*/
static classList = SearchListInputListView.classList.concat(["model-picker-list-input-view"]);
}
};

/**
* Extend the StringInputView so we can strip HTML from the value
Expand All @@ -22,11 +33,15 @@ class ModelPickerStringInputView extends StringInputView {
*/
setValue(newValue, triggerChange) {
if(!isEmpty(newValue)) {
newValue = createElementsFromString(newValue)[0].innerText;
if (newValue.startsWith("<")) {
newValue = createElementsFromString(newValue)[0].innerText;
} else {
newValue = newValue.split("/")[1];
}
}
return super.setValue(newValue, triggerChange);
}
}
};

/**
* We extend the SearchListInputView to change some default config.
Expand All @@ -42,9 +57,54 @@ class ModelPickerInputView extends SearchListInputView {
*/
static stringInputClass = ModelPickerStringInputView;

/**
* @var class The class of the list input, override so we can add css classes
*/
static listInputClass = ModelPickerListInputView
};

/**
* This form allows additional pipeline weights when using a checkpoint
*/
class AdditionalWeightsFormView extends FormView {
/**
* @var string Custom CSS class
*/
static className = "additional-weights-form-view";

/**
* @var boolean no submit button
*/
static autoSubmit = true;

/**
* @var boolean Start hidden
*/
static collapseFieldSets = true;

/**
* @var object one fieldset describes all inputs
*/
static fieldSets = {
"Additional Weights": {
"lora": {
"class": MultiLoraInputView,
"label": "LoRA",
"config": {
"tooltip": "LoRA stands for <strong>Low Rank Adapation</strong>, it is a kind of fine-tuning that can perform very specific modifications to Stable Diffusion such as training an individual's appearance, new products that are not in Stable Diffusion's training set, etc."
}
},
"inversion": {
"class": MultiInversionInputView,
"label": "Textual Inversion",
"config": {
"tooltip": "Textual Inversion is another kind of fine-tuning that teaches novel concepts to Stable Diffusion in a small number of images, which can be used to positively or negatively affect the impact of various prompts."
}
}
}
};
};

/**
* Extend the TableView to disable sorting and add conditional buttons
*/
Expand Down Expand Up @@ -280,14 +340,14 @@ class ModelPickerController extends Controller {
* Get state from the model picker
*/
getState() {
return { "model": this.formView.values };
return { "model": this.formView.values, "weights": this.additionalWeightsFormView.values };
}

/**
* Gets default state
*/
getDefaultState() {
return { "model": null };
return { "model": null, "weights": null };
}

/**
Expand All @@ -297,9 +357,11 @@ class ModelPickerController extends Controller {
if (!isEmpty(newState.model)) {
this.formView.setValues(newState.model).then(() => this.formView.submit());
}
if (!isEmpty(newState.weights)) {
this.additionalWeightsFormView.setValues(newState.weights).then(() => this.additionalWeightsFormView.submit());
}
}


/**
* Issues the request to the engine to build a specific engine
*/
Expand All @@ -308,7 +370,6 @@ class ModelPickerController extends Controller {
this.notify("info", "Build Started", "The engine will be busy throughout this TensorRT build. You will see a notification when it is complete, and the status indicator in the top bar will show ready or idle.");
await waitFor(
() => {
console.log("Built engines are", this.builtEngines);
return !isEmpty(this.builtEngines[model]) && this.builtEngines[model].indexOf(engine) !== -1;
},
{
Expand Down Expand Up @@ -365,14 +426,17 @@ class ModelPickerController extends Controller {
}, {});
return modelOptions;
};

this.formView = new ModelPickerFormView(this.config);

this.additionalWeightsFormView = new AdditionalWeightsFormView(this.config);

this.formView.onSubmit(async (values) => {
if (values.model) {
let [selectedType, selectedName] = values.model.split("/");
this.engine.model = selectedName;
this.engine.modelType = selectedType;
if (selectedType === "model") {
this.additionalWeightsFormView.hide();
try {
let fullModel = await this.model.DiffusionModel.query({name: selectedName}),
tensorRTStatus = await fullModel.getTensorRTStatus();
Expand All @@ -386,14 +450,22 @@ class ModelPickerController extends Controller {
this.formView.setValues({"model": null});
}
} else {
this.additionalWeightsFormView.show();
this.formView.setTensorRTStatus({supported: false});
}
} else {
this.formView.setTensorRTStatus({supported: false});
}
});

this.additionalWeightsFormView.onSubmit(async (values) => {
this.engine.lora = values.lora;
this.engine.inversion = values.inversion;
});

this.application.container.appendChild(await this.formView.render());
this.application.container.appendChild(await this.additionalWeightsFormView.render());

this.subscribe("invocationError", (payload) => {
console.error(payload);
if (!isEmpty(payload.metadata) && !isEmpty(payload.metadata.tensorrt_build)) {
Expand All @@ -402,11 +474,6 @@ class ModelPickerController extends Controller {
model = payload.metadata.tensorrt_build.model;

this.notify("info", "TensorRT Engine Build Failed", `${model} ${networkName} TensorRT Engine failed to build. Please try again.`);

if (isEmpty(this.builtEngines[model])) {
this.builtEngines[model] = [];
}
this.builtEngines[model].push(network);
}
});
this.subscribe("invocationComplete", (payload) => {
Expand Down
19 changes: 15 additions & 4 deletions src/js/controller/system/03-installation.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,17 @@ class UploadFileButtonInputView extends ButtonInputView {
* This is the table view for the summary over all directories
*/
class InstallationDirectorySummaryTableView extends TableView {
/**
* @var string Custom class name
*/
static className = "installation-directory-summary-table-view";

/**
* @var object column names and labels
*/
static columns = {
"directory": "Directory",
"location": "Location",
"directory": "Directory",
"items": "Items",
"bytes": "Total File Size"
};
Expand All @@ -99,6 +104,11 @@ class InstallationDirectorySummaryTableView extends TableView {
* This is the table view for a single directory
*/
class InstallationDirectoryTableView extends TableView {
/**
* @var string Custom class name
*/
static className = "installation-directory-table-view";

/**
* @var object column names and labels
*/
Expand Down Expand Up @@ -197,7 +207,7 @@ class InstallationSummaryView extends View {
this.controller.showDirectoryManager(row.directory);
});
this.summaryTable.addButton("Change Directory", "fa-solid fa-edit", (row) => {
this.controller.showChangeDirectory(row.directory);
this.controller.showChangeDirectory(row.directory, row.location);
});
this.engineTable = new TensorRTEngineSummaryTableView(this.config);
this.engineTable.addButton("Manage", "fa-solid fa-list-check", () => {
Expand Down Expand Up @@ -421,8 +431,8 @@ class InstallationController extends MenuController {
/**
* Shows the 'change directory' dialogue for a directory
*/
async showChangeDirectory(directory) {
let changeDirectoryForm = new ChangeDirectoryForm(this.config),
async showChangeDirectory(directory, currentValue) {
let changeDirectoryForm = new ChangeDirectoryForm(this.config, {"directory": currentValue}),
changeDirectoryWindow = await this.spawnWindow(
`Change Filesystem Location for ${directory}`,
changeDirectoryForm,
Expand Down Expand Up @@ -484,6 +494,7 @@ class InstallationController extends MenuController {
}
});
let container = new ParentView(this.config);
container.addClass("installation-summary-view");
container.addChild(table);
if (this.constructor.uploadableDirectories.indexOf(directory) !== -1) {
let uploadView = new UploadFileButtonInputView(this.config);
Expand Down
25 changes: 24 additions & 1 deletion src/python/enfugue/api/controller/invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,30 @@ def invoke_engine(self, request: Request, response: Response) -> Dict[str, Any]:
os.path.join(self.engine_root, "checkpoint")
),
model_name
)
),
"lora": [
(
os.path.join(
self.configuration.get("enfugue.engine.lora",
os.path.join(self.engine_root, "lora")
),
lora["model"]
),
float(lora["weight"])
)
for lora
in request.parsed.pop("lora", [])
],
"inversion": [
os.path.join(
self.configuration.get("enfugue.engine.inversion",
os.path.join(self.engine_root, "inversion")
),
inversion
)
for inversion
in request.parsed.pop("inversion", [])
],
}
plan = DiffusionPlan.from_nodes(**{**plan_kwargs, **request.parsed})
return self.invoke(request.token.user.id, plan).format()
Expand Down

0 comments on commit 3cb413c

Please sign in to comment.