Skip to content

Commit

Permalink
Update main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
congxuma committed Apr 17, 2023
1 parent f8692eb commit e9a8511
Showing 1 changed file with 65 additions and 39 deletions.
104 changes: 65 additions & 39 deletions plugins/plugin_replicate/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
from plugins import *
from common.log import logger
import replicate
from common.expired_dict import ExpiredDict

@plugins.register(name="replicate", desc="利用replicate api来画图", version="0.1", author="lanvent")
class SDWebUI(Plugin):
@plugins.register(name="replicate", desc="利用replicate api来画图", version="0.2", author="lanvent")
class Replicate(Plugin):
def __init__(self):
super().__init__()
curdir = os.path.dirname(__file__)
config_path = os.path.join(curdir, "config.json")
self.params_cache = ExpiredDict(60 * 60)
if not os.path.exists(config_path):
logger.info('[RP] 配置文件不存在,将使用config-template.json模板')
config_path = os.path.join(curdir, "config.json.template")
Expand All @@ -40,61 +42,86 @@ def __init__(self):
if isinstance(e, FileNotFoundError):
logger.warn(f"[RP] init failed, config.json not found.")
else:
logger.warn("[RP] init failed.")
logger.warn("[RP] init failed." + str(e))
raise e

def on_handle_context(self, e_context: EventContext):

if e_context['context'].type != ContextType.IMAGE_CREATE:
if e_context['context'].type not in [ContextType.IMAGE_CREATE, ContextType.IMAGE]:
return

logger.debug("[RP] on_handle_context. content: %s" %e_context['context'].content)

logger.info("[RP] image_query={}".format(e_context['context'].content))
reply = Reply()
try:
user_id = e_context['context']["session_id"]
content = e_context['context'].content[:]
# 解析用户输入 如"横版 高清 二次元:cat"
if ":" in content:
keywords, prompt = content.split(":", 1)
else:
keywords = content
prompt = ""
if e_context['context'].type == ContextType.IMAGE_CREATE:
# 解析用户输入 如"横版 高清 二次元:cat"
content = content.replace(",", ",").replace(":", ":")
if ":" in content:
keywords, prompt = content.split(":", 1)
else:
keywords = content
prompt = ""

keywords = keywords.split()

keywords = keywords.split()
if "help" in keywords or "帮助" in keywords:
reply.type = ReplyType.INFO
reply.content = self.get_help_text(verbose = True)
else:
rule_params = {}
for keyword in keywords:
matched = False
for rule in self.rules:
if keyword in rule["keywords"]:
for key in rule["params"]:
rule_params[key] = rule["params"][key]
matched = True
break # 一个关键词只匹配一个规则
if not matched:
logger.warning("[RP] keyword not matched: %s" % keyword)
params = {**self.default_params, **rule_params}
params["prompt"] = params.get("prompt", "")+f", {prompt}"
logger.info("[RP] params={}".format(params))

if "help" in keywords or "帮助" in keywords:
reply.type = ReplyType.INFO
reply.content = self.get_help_text(verbose = True)
if params.get("image",None):
self.params_cache[user_id] = params
reply.type = ReplyType.INFO
reply.content = "请发送一张图片给我"
else:
model = self.client.models.get(params.pop("model"))
version = model.versions.get(params.pop("version"))
result = version.predict(**params)
reply.type = ReplyType.IMAGE_URL
reply.content = result[0]
e_context.action = EventAction.BREAK_PASS # 事件结束后,跳过处理context的默认逻辑
e_context['reply'] = reply
else:
rule_params = {}
for keyword in keywords:
matched = False
for rule in self.rules:
if keyword in rule["keywords"]:
for key in rule["params"]:
rule_params[key] = rule["params"][key]
matched = True
break # 一个关键词只匹配一个规则
if not matched:
logger.warning("[RP] keyword not matched: %s" % keyword)

params = {**self.default_params, **rule_params}
params["prompt"] = params.get("prompt", "")+f", {prompt}"
logger.info("[RP] params={}".format(params))
model = self.client.models.get(params.pop("model"))
version = model.versions.get(params.pop("version"))
result = version.predict(**params)
reply.type = ReplyType.IMAGE_URL
reply.content = result[0]
e_context.action = EventAction.BREAK_PASS # 事件结束后,跳过处理context的默认逻辑
cmsg = e_context['context']['msg']
if user_id in self.params_cache:
params = self.params_cache[user_id]
del self.params_cache[user_id]
cmsg.prepare()
img_key = params.pop("image")
params[img_key]=open(content,"rb")
model = self.client.models.get(params.pop("model"))
version = model.versions.get(params.pop("version"))
result = version.predict(**params)
reply.type = ReplyType.IMAGE_URL
reply.content = result
logger.info("[RP] result={}".format(result))
e_context['reply'] = reply
e_context.action = EventAction.BREAK_PASS # 事件结束后,跳过处理context的默认逻辑

except Exception as e:
reply.type = ReplyType.ERROR
reply.content = "[RP] "+str(e)
logger.error("[RP] exception: %s" % e)
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑
finally:
e_context['reply'] = reply
logger.exception("[RP] exception: %s" % e)
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑

def get_help_text(self, verbose = False, **kwargs):
if not conf().get('image_create_prefix'):
Expand All @@ -115,4 +142,3 @@ def get_help_text(self, verbose = False, **kwargs):
else:
help_text += "\n"
return help_text

0 comments on commit e9a8511

Please sign in to comment.