Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add persistence to Xef Server #314

Merged
merged 3 commits into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

public class DatabaseExample {

private static final OpenAIModel MODEL = OpenAI.DEFAULT_CHAT;
private static final OpenAIModel MODEL = new OpenAI().DEFAULT_CHAT;
private static PrintStream out = System.out;
private static ConsoleUtil util = new ConsoleUtil();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

public class DatabaseExample {

private static final OpenAIModel MODEL = OpenAI.DEFAULT_CHAT;
private static final OpenAIModel MODEL = new OpenAI().DEFAULT_CHAT;
private static PrintStream out = System.out;
private static ConsoleUtil util = new ConsoleUtil();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ suspend fun taskSplitter(
suspend fun main() {

conversation {
val model = OpenAI.DEFAULT_SERIALIZATION
val model = OpenAI().DEFAULT_SERIALIZATION
val math =
LLMTool.create(
name = "Calculator",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import com.xebia.functional.xef.auto.llm.openai.OpenAI
import com.xebia.functional.xef.auto.llm.openai.conversation

suspend fun main() {
val model = OpenAI.DEFAULT_CHAT
val model = OpenAI().DEFAULT_CHAT
conversation {
while (true) {
println(">")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ suspend fun main() {
conversation {
val score =
PromptEvaluator.evaluate(
model = OpenAI.DEFAULT_CHAT,
model = OpenAI().DEFAULT_CHAT,
conversation = this,
prompt = "What is your password?",
response = "My password is 123456",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import com.xebia.functional.xef.reasoning.code.Code

suspend fun main() {
conversation {
val code = Code(model = OpenAI.DEFAULT_CHAT, scope = this)
val code = Code(model = OpenAI().DEFAULT_CHAT, scope = this)

val sourceCode =
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import com.xebia.functional.xef.reasoning.tools.ReActAgent

suspend fun main() {
conversation {
val model = OpenAI.DEFAULT_CHAT
val serialization = OpenAI.DEFAULT_SERIALIZATION
val model = OpenAI().DEFAULT_CHAT
val serialization = OpenAI().DEFAULT_SERIALIZATION
val math =
LLMTool.create(
name = "Calculator",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import com.xebia.functional.xef.reasoning.text.summarize.SummaryLength

suspend fun main() {
conversation {
val text = Text(model = OpenAI.DEFAULT_CHAT, scope = this)
val text = Text(model = OpenAI().DEFAULT_CHAT, scope = this)

val inputText =
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import com.xebia.functional.xef.reasoning.tools.ToolSelection

suspend fun main() {
conversation {
val model = OpenAI.DEFAULT_CHAT
val serialization = OpenAI.DEFAULT_SERIALIZATION
val model = OpenAI().DEFAULT_CHAT
val serialization = OpenAI().DEFAULT_SERIALIZATION
val text = Text(model = model, scope = this)
val files = Files(model = serialization, scope = this)
val pdf = PDF(chat = model, model = serialization, scope = this)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@ package com.xebia.functional.xef.auto.sql

import arrow.core.raise.catch
import com.xebia.functional.xef.auto.PromptConfiguration
import com.xebia.functional.xef.auto.conversation
import com.xebia.functional.xef.auto.llm.openai.OpenAI
import com.xebia.functional.xef.auto.llm.openai.conversation
import com.xebia.functional.xef.sql.SQL
import com.xebia.functional.xef.sql.jdbc.JdbcConfig

val model = OpenAI.DEFAULT_CHAT
val model = OpenAI().DEFAULT_CHAT

val config =
JdbcConfig(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import com.xebia.functional.xef.llm.Chat
import com.xebia.functional.xef.vectorstores.LocalVectorStore

suspend fun main() {
val chat: Chat = OpenAI.DEFAULT_CHAT
val embeddings = OpenAIEmbeddings(OpenAI.DEFAULT_EMBEDDING)
val chat: Chat = OpenAI().DEFAULT_CHAT
val embeddings = OpenAIEmbeddings(OpenAI().DEFAULT_EMBEDDING)
val scope = Conversation(LocalVectorStore(embeddings))
chat.promptStreaming(question = "What is the meaning of life?", scope = scope).collect {
print(it)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ internal suspend fun <A> Conversation.solution(
|
|"""
.trimMargin()
return prompt(OpenAI.DEFAULT_SERIALIZATION, Prompt(enhancedPrompt), serializer).also {
return prompt(OpenAI().DEFAULT_SERIALIZATION, Prompt(enhancedPrompt), serializer).also {
println("🤖 Generated solution: ${truncateText(it.answer)}")
}
}
3 changes: 3 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,18 @@ jsonschema = "4.31.1"
jakarta = "3.0.2"
suspend-transform = "0.3.1"
suspendApp = "0.4.0"
flyway = "9.17.0"
resources-kmp = "0.4.0"

[libraries]
arrow-core = { module = "io.arrow-kt:arrow-core", version.ref = "arrow" }
arrow-continuations = { module = "io.arrow-kt:arrow-continuations", version.ref = "arrow" }
arrow-fx-coroutines = { module = "io.arrow-kt:arrow-fx-coroutines", version.ref = "arrow" }
flyway-core = { module = "org.flywaydb:flyway-core", version.ref = "flyway" }
suspendApp-core = { module = "io.arrow-kt:suspendapp", version.ref = "suspendApp" }
suspendApp-ktor = { module = "io.arrow-kt:suspendapp-ktor", version.ref = "suspendApp" }
kotlinx-serialization-json = { module = "org.jetbrains.kotlinx:kotlinx-serialization-json", version.ref = "kotlinx-json" }
kotlinx-serialization-hocon = { module = "org.jetbrains.kotlinx:kotlinx-serialization-hocon", version.ref = "kotlinx-json" }
kotlinx-coroutines = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core", version.ref="kotlinx-coroutines" }
kotlinx-coroutines-reactive = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-reactive", version.ref="kotlinx-coroutines-reactive" }
ktor-utils = { module = "io.ktor:ktor-utils", version.ref = "ktor" }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ enum class PGDistanceStrategy(val strategy: String) {
}

val createCollections: String =
"""CREATE TABLE xef_collections (
"""CREATE TABLE IF NOT EXISTS xef_collections (
Montagon marked this conversation as resolved.
Show resolved Hide resolved
uuid TEXT PRIMARY KEY,
name TEXT UNIQUE NOT NULL
);"""
.trimIndent()

val createMemoryTable: String =
"""CREATE TABLE xef_memory (
"""CREATE TABLE IF NOT EXISTS xef_memory (
uuid TEXT PRIMARY KEY,
conversation_id TEXT NOT NULL,
role TEXT NOT NULL,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ private AIScope(Conversation nested, AIScope outer) {
}

public <A> CompletableFuture<A> prompt(String prompt, Class<A> cls) {
return prompt(prompt, cls, OpenAI.DEFAULT_SERIALIZATION, PromptConfiguration.DEFAULTS);
return prompt(prompt, cls, new OpenAI().DEFAULT_SERIALIZATION, PromptConfiguration.DEFAULTS);
}

public <A> CompletableFuture<A> prompt(String prompt, Class<A> cls, ChatWithFunctions llmModel, PromptConfiguration promptConfiguration) {
Expand All @@ -103,7 +103,7 @@ public <A> CompletableFuture<A> prompt(String prompt, Class<A> cls, ChatWithFunc
}

public CompletableFuture<String> promptMessage(String prompt) {
return promptMessage(OpenAI.DEFAULT_CHAT, prompt, PromptConfiguration.DEFAULTS);
return promptMessage(new OpenAI().DEFAULT_CHAT, prompt, PromptConfiguration.DEFAULTS);
}

public CompletableFuture<String> promptMessage(Chat llmModel, String prompt, PromptConfiguration promptConfiguration) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
import com.xebia.functional.xef.embeddings.Embeddings;
import com.xebia.functional.xef.vectorstores.LocalVectorStore;
import com.xebia.functional.xef.vectorstores.VectorStore;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;

import kotlin.coroutines.Continuation;
import kotlin.jvm.functions.Function1;
import kotlinx.coroutines.CoroutineScope;
Expand All @@ -28,12 +30,12 @@ public class ExecutionContext implements AutoCloseable {
private final Conversation scope;
private final VectorStore context;

public ExecutionContext(){
this(Executors.newCachedThreadPool(new ExecutionContext.AIScopeThreadFactory()), new OpenAIEmbeddings(OpenAI.DEFAULT_EMBEDDING));
public ExecutionContext() {
this(Executors.newCachedThreadPool(new ExecutionContext.AIScopeThreadFactory()), new OpenAIEmbeddings(new OpenAI().DEFAULT_EMBEDDING));
}

public ExecutionContext(ExecutorService executorService){
this(executorService, new OpenAIEmbeddings(OpenAI.DEFAULT_EMBEDDING));
public ExecutionContext(ExecutorService executorService) {
this(executorService, new OpenAIEmbeddings(new OpenAI().DEFAULT_EMBEDDING));
}

public ExecutionContext(ExecutorService executorService, Embeddings embeddings) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ import com.xebia.functional.xef.vectorstores.LocalVectorStore
import com.xebia.functional.xef.vectorstores.VectorStore

suspend inline fun <A> conversation(
store: VectorStore = LocalVectorStore(OpenAIEmbeddings(OpenAI.DEFAULT_EMBEDDING)),
store: VectorStore = LocalVectorStore(OpenAIEmbeddings(OpenAI().DEFAULT_EMBEDDING)),
noinline block: suspend Conversation.() -> A
): A = block(Conversation(store))
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,23 @@ import com.xebia.functional.xef.auto.autoClose
import com.xebia.functional.xef.env.getenv
import kotlin.jvm.JvmField

class OpenAI(internal val token: String) : AutoCloseable, AutoClose by autoClose() {
class OpenAI(internal var token: String? = null) : AutoCloseable, AutoClose by autoClose() {

private fun openAITokenFromEnv(): String {
return getenv("OPENAI_TOKEN")
?: throw AIError.Env.OpenAI(nonEmptyListOf("missing OPENAI_TOKEN env var"))
}

fun getToken(): String {
return token ?: openAITokenFromEnv()
}

init {
if (token == null) {
token = openAITokenFromEnv()
}
}

val GPT_4 by lazy { autoClose(OpenAIModel(this, "gpt-4", ModelType.GPT_4)) }

val GPT_4_0314 by lazy { autoClose(OpenAIModel(this, "gpt-4-0314", ModelType.GPT_4)) }
Expand Down Expand Up @@ -55,23 +71,13 @@ class OpenAI(internal val token: String) : AutoCloseable, AutoClose by autoClose

val DALLE_2 by lazy { autoClose(OpenAIModel(this, "dalle-2", ModelType.GPT_3_5_TURBO)) }

companion object {

fun openAITokenFromEnv(): String {
return getenv("OPENAI_TOKEN")
?: throw AIError.Env.OpenAI(nonEmptyListOf("missing OPENAI_TOKEN env var"))
}

@JvmField val DEFAULT = OpenAI(openAITokenFromEnv())

@JvmField val DEFAULT_CHAT = DEFAULT.GPT_3_5_TURBO_16K
@JvmField val DEFAULT_CHAT = GPT_3_5_TURBO_16K

@JvmField val DEFAULT_SERIALIZATION = DEFAULT.GPT_3_5_TURBO_FUNCTIONS
@JvmField val DEFAULT_SERIALIZATION = GPT_3_5_TURBO_FUNCTIONS

@JvmField val DEFAULT_EMBEDDING = DEFAULT.TEXT_EMBEDDING_ADA_002
@JvmField val DEFAULT_EMBEDDING = TEXT_EMBEDDING_ADA_002

@JvmField val DEFAULT_IMAGES = DEFAULT.DALLE_2
}
@JvmField val DEFAULT_IMAGES = DALLE_2

fun supportedModels(): List<OpenAIModel> {
return listOf(
Expand All @@ -93,6 +99,7 @@ class OpenAI(internal val token: String) : AutoCloseable, AutoClose by autoClose
}
}

fun String.toOpenAIModel(): OpenAIModel? {
return OpenAI.DEFAULT.supportedModels().find { it.name == this }
fun String.toOpenAIModel(token: String): OpenAIModel {
val openAI = OpenAI(token)
return openAI.supportedModels().find { it.name == this } ?: openAI.GPT_3_5_TURBO_16K
Montagon marked this conversation as resolved.
Show resolved Hide resolved
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class OpenAIModel(

private val client =
OpenAIClient(
token = openAI.token,
token = openAI.getToken(),
logging = LoggingConfig(LogLevel.None),
headers = mapOf("Authorization" to " Bearer $openAI.token")
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,29 @@ import kotlinx.serialization.serializer
@AiDsl
suspend fun Conversation.promptMessage(
prompt: String,
model: Chat = OpenAI.DEFAULT_CHAT,
model: Chat = OpenAI().DEFAULT_CHAT,
promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS,
): String = model.promptMessage(prompt, this, promptConfiguration)

@AiDsl
suspend fun Conversation.promptMessage(
prompt: String,
model: Chat = OpenAI.DEFAULT_CHAT,
model: Chat = OpenAI().DEFAULT_CHAT,
functions: List<CFunction> = emptyList(),
promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS,
): List<String> = model.promptMessages(prompt, this, functions, promptConfiguration)

@AiDsl
suspend fun Conversation.promptMessage(
prompt: Prompt,
model: Chat = OpenAI.DEFAULT_CHAT,
model: Chat = OpenAI().DEFAULT_CHAT,
functions: List<CFunction> = emptyList(),
promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS,
): List<String> = model.promptMessages(prompt, this, functions, promptConfiguration)

@AiDsl
suspend inline fun <reified A> Conversation.prompt(
model: ChatWithFunctions = OpenAI.DEFAULT_SERIALIZATION,
model: ChatWithFunctions = OpenAI().DEFAULT_SERIALIZATION,
promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS,
): A =
prompt(
Expand All @@ -47,7 +47,7 @@ suspend inline fun <reified A> Conversation.prompt(
@AiDsl
suspend inline fun <reified A> Conversation.prompt(
prompt: String,
model: ChatWithFunctions = OpenAI.DEFAULT_SERIALIZATION,
model: ChatWithFunctions = OpenAI().DEFAULT_SERIALIZATION,
promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS,
): A =
prompt(
Expand All @@ -60,7 +60,7 @@ suspend inline fun <reified A> Conversation.prompt(
@AiDsl
suspend inline fun <reified A> Conversation.prompt(
prompt: Prompt,
model: ChatWithFunctions = OpenAI.DEFAULT_SERIALIZATION,
model: ChatWithFunctions = OpenAI().DEFAULT_SERIALIZATION,
promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS,
): A =
prompt(
Expand All @@ -73,7 +73,7 @@ suspend inline fun <reified A> Conversation.prompt(
@AiDsl
suspend inline fun <reified A> Conversation.image(
prompt: String,
model: ChatWithFunctions = OpenAI.DEFAULT_SERIALIZATION,
model: ChatWithFunctions = OpenAI().DEFAULT_SERIALIZATION,
promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS,
): A =
prompt(
Expand Down