Skip to content

Commit

Permalink
sync stable diffusion params to latest version
Browse files Browse the repository at this point in the history
  • Loading branch information
Cyberhan123 committed Dec 11, 2023
1 parent f42ca6b commit 321d6f0
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 12 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
/out/
/cmake-build-debug/
/cmake-build-release/
/cmake-build-debug-visual-studio/
/cmake-build-release-visual-studio/
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ endif ()
project("stable-diffusion-build")

set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
#set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O2")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2")


if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
Expand Down
2 changes: 1 addition & 1 deletion cmake/sd.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ if(${CMAKE_VERSION} VERSION_LESS 3.14)
include(add_FetchContent_MakeAvailable.cmake)
endif()

set(SD_GIT_TAG 47dd704198f46ee75b11cbaf3aa2b8f644df0be9)
set(SD_GIT_TAG ac8f5a044c762eaa1181a89ef673fcb174acfb47)
set(SD_GIT_URL https://github.com/leejet/stable-diffusion.cpp)
set(BUILD_SHARED_LIBS OFF)

Expand Down
54 changes: 46 additions & 8 deletions stable-diffusion-abi.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "stable-diffusion-abi.h"

#include "stable-diffusion.h"
#include "util.h"
#include <string>
#include <cstring>
#include <map>
Expand Down Expand Up @@ -40,6 +41,17 @@ const static std::map<std::string, enum Schedule> ScheduleMap = {
{"N_SCHEDULES", N_SCHEDULES},
};

const static std::map<std::string, enum ggml_type> ggmlTypeMap = {
{"DEFAULT", GGML_TYPE_COUNT},
{"F32", GGML_TYPE_F32},
{"F16", GGML_TYPE_F16},
{"Q4_0", GGML_TYPE_Q4_0},
{"Q4_1", GGML_TYPE_Q4_1},
{"Q5_0", GGML_TYPE_Q5_0},
{"Q5_1", GGML_TYPE_Q5_1},
{"Q8_0", GGML_TYPE_Q8_0},
};

void stable_diffusion_full_params_set_negative_prompt(
struct stable_diffusion_full_params* params,
const char* negative_prompt
Expand Down Expand Up @@ -128,6 +140,7 @@ struct stable_diffusion_ctx {
struct stable_diffusion_ctx* stable_diffusion_init(
const int n_threads,
const bool vae_decode_only,
const char * taesd_path,
const bool free_params_immediately,
const char* lora_model_dir,
const char* rng_type
Expand All @@ -139,6 +152,7 @@ struct stable_diffusion_ctx* stable_diffusion_init(
const auto sd = new StableDiffusion(
n_threads,
vae_decode_only,
std::string(taesd_path),
free_params_immediately,
std::string(lora_model_dir),
it->second
Expand All @@ -152,11 +166,23 @@ struct stable_diffusion_ctx* stable_diffusion_init(
bool stable_diffusion_load_from_file(
const struct stable_diffusion_ctx* ctx,
const char* file_path,
const char* vae_path,
const char* wtype,
const char* schedule
) {
auto e_wtype=ggmlTypeMap.find(std::string(wtype));
if (e_wtype!=ggmlTypeMap.end()){
e_wtype=ggmlTypeMap.find("DEFAULT");
}

const auto e_schedule = ScheduleMap.find(std::string(schedule));
if (e_schedule != ScheduleMap.end()) {
return ctx->sd->load_from_file(std::string(file_path), e_schedule->second);
return ctx->sd->load_from_file(
std::string(file_path),
std::string(vae_path),
e_wtype->second ,
e_schedule->second
);
}
return false;
};
Expand Down Expand Up @@ -237,15 +263,27 @@ const char* stable_diffusion_get_system_info() {
return buffer;
};

void stable_diffusion_free(const struct stable_diffusion_ctx* ctx) {
delete ctx->sd;
delete ctx;
void stable_diffusion_free(struct stable_diffusion_ctx* ctx) {
if (ctx!= nullptr){
if (ctx->sd!= nullptr){
delete ctx->sd;
ctx->sd= nullptr;
}
delete ctx;
ctx = nullptr;
}
};

void stable_diffusion_free_full_params(const struct stable_diffusion_full_params* params) {
delete params;
void stable_diffusion_free_full_params( struct stable_diffusion_full_params* params) {
if (params!= nullptr){
delete params;
params = nullptr;
}
}

void stable_diffusion_free_buffer(const uint8_t* buffer) {
delete [] buffer;
void stable_diffusion_free_buffer(uint8_t* buffer) {
if (buffer!= nullptr){
delete [] buffer;
buffer= nullptr;
}
}
10 changes: 7 additions & 3 deletions stable-diffusion-abi.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef STABLE_DIFFUSION_ABI_H
#define STABLE_DIFFUSION_ABI_H

#include "ggml/ggml.h"
#include "stable-diffusion.h"

#ifdef STABLE_DIFFUSION_SHARED
Expand Down Expand Up @@ -90,6 +91,7 @@ STABLE_DIFFUSION_API void stable_diffusion_full_params_set_strength(
STABLE_DIFFUSION_API stable_diffusion_ctx* stable_diffusion_init(
int n_threads,
bool vae_decode_only,
const char *taesd_path,
bool free_params_immediately,
const char* lora_model_dir,
const char* rng_type
Expand All @@ -99,6 +101,8 @@ STABLE_DIFFUSION_API stable_diffusion_ctx* stable_diffusion_init(
STABLE_DIFFUSION_API bool stable_diffusion_load_from_file(
const struct stable_diffusion_ctx* ctx,
const char* file_path,
const char* vae_path,
const char* wtype,
const char* schedule
);

Expand All @@ -119,11 +123,11 @@ STABLE_DIFFUSION_API void stable_diffusion_set_log_level(const char* level);

STABLE_DIFFUSION_API const char* stable_diffusion_get_system_info();

STABLE_DIFFUSION_API void stable_diffusion_free(const struct stable_diffusion_ctx* ctx);
STABLE_DIFFUSION_API void stable_diffusion_free(struct stable_diffusion_ctx* ctx);

STABLE_DIFFUSION_API void stable_diffusion_free_full_params(const struct stable_diffusion_full_params* params);
STABLE_DIFFUSION_API void stable_diffusion_free_full_params(struct stable_diffusion_full_params* params);

STABLE_DIFFUSION_API void stable_diffusion_free_buffer(const uint8_t* buffer);
STABLE_DIFFUSION_API void stable_diffusion_free_buffer(uint8_t* buffer);

#ifdef __cplusplus
}
Expand Down

0 comments on commit 321d6f0

Please sign in to comment.