Skip to content

Commit

Permalink
Improve conversation (#609)
Browse files Browse the repository at this point in the history
* fix conversation form

* refactor conversation list
  • Loading branch information
an-lee committed May 14, 2024
1 parent e090cca commit 838ed1e
Show file tree
Hide file tree
Showing 9 changed files with 204 additions and 238 deletions.
3 changes: 3 additions & 0 deletions enjoy/src/i18n/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,9 @@
"selectScenario": "Select scenario",
"selectAiEngine": "Select AI engine",
"selectAiModel": "Select AI model",
"selectTtsEngine": "Select TTS engine",
"selectTtsModel": "Select TTS model",
"selectTtsVoice": "Select TTS voice",
"youNeedToSetupApiKeyBeforeUsingOpenAI": "You need to setup API key before using OpenAI",
"ensureYouHaveOllamaRunningLocallyAndHasAtLeastOneModel": "Ensure you have Ollama running locally and has at least one model",
"creatingSpeech": "Speech is creating",
Expand Down
6 changes: 5 additions & 1 deletion enjoy/src/i18n/zh-CN.json
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@
"confirm": "确认",
"continue": "继续",
"save": "保存",
"delete": "删除",
"edit": "修改",
"retry": "重试",
"failedToLogin": "登录失败",
Expand All @@ -200,7 +201,6 @@
"inputMixinId": "请输入您的 Mixin ID",
"dontHaveMixinAccount": "没有 Mixin 账号?",
"youCanAlsoLoginWith": "您也可以使用以下方式登录",
"delete": "删除",
"transcribe": "语音转文本",
"stillTranscribing": "语音转文本仍在进行中,请耐心等候。",
"unableToSetLibraryPath": "无法设置资源库保存路径 {{path}}",
Expand Down Expand Up @@ -332,6 +332,7 @@
"accountSettings": "账户设置",
"advancedSettingsShort": "高级设置",
"advancedSettings": "高级设置",
"advanced": "高级设置",
"language": "语言",
"editEmail": "修改邮箱地址",
"editUserName": "修改用户名",
Expand Down Expand Up @@ -400,6 +401,9 @@
"selectScenario": "选择场景",
"selectAiEngine": "选择 AI 引擎",
"selectAiModel": "选择 AI 模型",
"selectTtsEngine": "选择 TTS 引擎",
"selectTtsModel": "选择 TTS 模型",
"selectTtsVoice": "选择 TTS 角色",
"youNeedToSetupApiKeyBeforeUsingOpenAI": "在使用 OpenAI 之前您需要设置 API 密钥",
"ensureYouHaveOllamaRunningLocallyAndHasAtLeastOneModel": "确保您已经在本地运行 Ollama 并且至少有一个模型",
"creatingSpeech": "正在生成语音",
Expand Down
36 changes: 36 additions & 0 deletions enjoy/src/renderer/components/conversations/conversation-card.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import { MessageCircleIcon, SpeechIcon } from "lucide-react";
import dayjs from "dayjs";

export const ConversationCard = (props: { conversation: ConversationType }) => {
const { conversation } = props;

return (
<div
className="bg-background hover:bg-muted hover:text-muted-foreground border rounded-full w-full mb-2 px-4 py-2 cursor-pointer flex items-center"
style={{
borderLeftColor: `#${conversation.id.replaceAll("-", "").slice(0, 6)}`,
borderLeftWidth: 3,
}}
>
<div className="">
{conversation.type === "gpt" && <MessageCircleIcon className="mr-2" />}

{conversation.type === "tts" && <SpeechIcon className="mr-2" />}
</div>
<div className="flex-1 flex items-center justify-between space-x-4">
<div className="">
<div className="line-clamp-1 text-sm">{conversation.name}</div>
<div className="text-xs text-muted-foreground">
{conversation.engine} /{" "}
{conversation.type === "tts"
? conversation.configuration?.tts?.model
: conversation.model}
</div>
</div>
<span className="min-w-fit text-sm text-muted-foreground">
{dayjs(conversation.createdAt).format("HH:mm l")}
</span>
</div>
</div>
);
};
138 changes: 84 additions & 54 deletions enjoy/src/renderer/components/conversations/conversation-form.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import {
SelectContent,
SelectItem,
Textarea,
toast,
} from "@renderer/components/ui";
import { useState, useEffect, useContext } from "react";
import {
Expand All @@ -47,28 +48,24 @@ const conversationFormSchema = z.object({
engine: z
.enum(["enjoyai", "openai", "ollama", "googleGenerativeAi"])
.default("openai"),
configuration: z
.object({
type: z.enum(["gpt", "tts"]),
model: z.string().optional(),
configuration: z.object({
type: z.enum(["gpt", "tts"]),
model: z.string().optional(),
baseUrl: z.string().optional(),
roleDefinition: z.string().optional(),
temperature: z.number().min(0).max(1).default(0.2),
numberOfChoices: z.number().min(1).default(1),
maxTokens: z.number().min(-1).default(2000),
presencePenalty: z.number().min(-2).max(2).default(0),
frequencyPenalty: z.number().min(-2).max(2).default(0),
historyBufferSize: z.number().min(0).default(10),
tts: z.object({
engine: z.enum(["openai", "enjoyai"]).default("enjoyai"),
model: z.string().default("tts-1"),
voice: z.string(),
baseUrl: z.string().optional(),
roleDefinition: z.string().optional(),
temperature: z.number().min(0).max(1).default(0.2),
numberOfChoices: z.number().min(1).default(1),
maxTokens: z.number().min(-1).default(2000),
presencePenalty: z.number().min(-2).max(2).default(0),
frequencyPenalty: z.number().min(-2).max(2).default(0),
historyBufferSize: z.number().min(0).default(10),
tts: z
.object({
engine: z.enum(["openai", "enjoyai"]).default("openai"),
model: z.string().default("tts-1"),
voice: z.string().optional(),
baseUrl: z.string().optional(),
})
.optional(),
})
.optional(),
}),
}),
});

export const ConversationForm = (props: {
Expand All @@ -77,8 +74,8 @@ export const ConversationForm = (props: {
}) => {
const { conversation, onFinish } = props;
const [submitting, setSubmitting] = useState<boolean>(false);
const [gptProviders, setGptProviders] = useState<any>(GPT_PROVIDERS);
const [ttsProviders, setTtsProviders] = useState<any>(TTS_PROVIDERS);
const [gptProviders, setGptProviders] = useState<any>([]);
const [ttsProviders, setTtsProviders] = useState<any>([]);
const { EnjoyApp, webApi } = useContext(AppSettingsProviderContext);
const { openai } = useContext(AISettingsProviderContext);
const navigate = useNavigate();
Expand Down Expand Up @@ -132,6 +129,7 @@ export const ConversationForm = (props: {
}, []);

const defaultConfig = JSON.parse(JSON.stringify(conversation || {}));

if (defaultConfig.engine === "openai" && openai) {
if (!defaultConfig.configuration) {
defaultConfig.configuration = {};
Expand Down Expand Up @@ -172,31 +170,15 @@ export const ConversationForm = (props: {
});

const onSubmit = async (data: z.infer<typeof conversationFormSchema>) => {
const { name, engine, configuration } = data;
let { name, engine, configuration } = data;
setSubmitting(true);

Object.keys(configuration).forEach((key) => {
if (key === "type") return;

if (!GPT_PROVIDERS[engine]?.configurable.includes(key)) {
// @ts-ignore
delete configuration[key];
}
});

if (configuration.type === "tts") {
conversation.model = configuration.tts.model;
}

// use default base url if not set
if (!configuration.baseUrl) {
configuration.baseUrl = GPT_PROVIDERS[engine]?.baseUrl;
}

// use default base url if not set
if (!configuration?.tts?.baseUrl) {
configuration.tts ||= {};
configuration.tts.baseUrl = GPT_PROVIDERS[engine]?.baseUrl;
try {
configuration = validateConfiguration(data);
} catch (e) {
toast.error(e.message);
setSubmitting(false);
return;
}

if (conversation?.id) {
Expand Down Expand Up @@ -227,6 +209,54 @@ export const ConversationForm = (props: {
}
};

const validateConfiguration = (
data: z.infer<typeof conversationFormSchema>
) => {
const { engine, configuration } = data;

Object.keys(configuration).forEach((key) => {
if (key === "type") return;

if (
configuration.type === "gpt" &&
!gptProviders[engine]?.configurable.includes(key)
) {
// @ts-ignore
delete configuration[key];
}

if (
configuration.type === "tts" &&
!ttsProviders[engine]?.configurable.includes(key)
) {
// @ts-ignore
delete configuration.tts[key];
}
});

if (configuration.type === "tts") {
if (!configuration.tts?.engine) {
throw new Error(t("models.conversation.ttsEngineRequired"));
}
if (!configuration.tts?.model) {
throw new Error(t("models.conversation.ttsModelRequired"));
}
}

// use default base url if not set
if (!configuration.baseUrl) {
configuration.baseUrl = gptProviders[engine]?.baseUrl;
}

// use default base url if not set
if (!configuration?.tts?.baseUrl) {
configuration.tts ||= {};
configuration.tts.baseUrl = gptProviders[engine]?.baseUrl;
}

return configuration;
};

return (
<Form {...form}>
<form
Expand Down Expand Up @@ -367,7 +397,7 @@ export const ConversationForm = (props: {
)}
/>

{GPT_PROVIDERS[form.watch("engine")]?.configurable.includes(
{gptProviders[form.watch("engine")]?.configurable.includes(
"temperature"
) && (
<FormField
Expand Down Expand Up @@ -401,7 +431,7 @@ export const ConversationForm = (props: {
/>
)}

{GPT_PROVIDERS[form.watch("engine")]?.configurable.includes(
{gptProviders[form.watch("engine")]?.configurable.includes(
"maxTokens"
) && (
<FormField
Expand Down Expand Up @@ -430,7 +460,7 @@ export const ConversationForm = (props: {
/>
)}

{GPT_PROVIDERS[form.watch("engine")]?.configurable.includes(
{gptProviders[form.watch("engine")]?.configurable.includes(
"presencePenalty"
) && (
<FormField
Expand Down Expand Up @@ -461,7 +491,7 @@ export const ConversationForm = (props: {
/>
)}

{GPT_PROVIDERS[form.watch("engine")]?.configurable.includes(
{gptProviders[form.watch("engine")]?.configurable.includes(
"frequencyPenalty"
) && (
<FormField
Expand Down Expand Up @@ -492,7 +522,7 @@ export const ConversationForm = (props: {
/>
)}

{GPT_PROVIDERS[form.watch("engine")]?.configurable.includes(
{gptProviders[form.watch("engine")]?.configurable.includes(
"numberOfChoices"
) && (
<FormField
Expand Down Expand Up @@ -555,7 +585,7 @@ export const ConversationForm = (props: {
)}
/>

{GPT_PROVIDERS[form.watch("engine")]?.configurable.includes(
{gptProviders[form.watch("engine")]?.configurable.includes(
"baseUrl"
) && (
<FormField
Expand Down Expand Up @@ -611,7 +641,7 @@ export const ConversationForm = (props: {

{ttsProviders[
form.watch("configuration.tts.engine")
]?.configurable.includes("model") && (
]?.configurable?.includes("model") && (
<FormField
control={form.control}
name="configuration.tts.model"
Expand Down Expand Up @@ -647,7 +677,7 @@ export const ConversationForm = (props: {

{ttsProviders[
form.watch("configuration.tts.engine")
]?.configurable.includes("voice") && (
]?.configurable?.includes("voice") && (
<FormField
control={form.control}
name="configuration.tts.voice"
Expand Down
Loading

0 comments on commit 838ed1e

Please sign in to comment.