Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Search wikipedia tool #312

Merged
merged 11 commits into from
Aug 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package com.xebia.functional.xef.conversation.reasoning

import com.xebia.functional.xef.conversation.llm.openai.OpenAI
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.prompt.templates.user
import com.xebia.functional.xef.reasoning.tools.LLMTool
import com.xebia.functional.xef.reasoning.tools.ReActAgent
import com.xebia.functional.xef.reasoning.wikipedia.*

suspend fun main() {
OpenAI.conversation {
val model = OpenAI().DEFAULT_CHAT
val serialization = OpenAI().DEFAULT_SERIALIZATION
val math =
LLMTool.create(
name = "Calculator",
description =
"Perform math operations and calculations processing them with an LLM model. The tool input is a simple string containing the operation to solve expressed in numbers and math symbols.",
model = model,
scope = this
)
val search = SearchWikipedia(model = model, scope = this)
val searchByPageId = SearchWikipediaByPageId(model = model, scope = this)
val searchByTitle = SearchWikipediaByTitle(model = model, scope = this)

val reActAgent =
ReActAgent(
model = serialization,
scope = this,
tools = listOf(search, math, searchByPageId, searchByTitle),
)
val result =
reActAgent.run(
Prompt {
+user("Find and multiply the number of human bones by the number of Metallica albums")
}
)
println(result)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ suspend fun main() {
OpenAI.conversation {
val client = WikipediaClient()

val searchDataByPageId = WikipediaClient.SearchDataByParam(pageId = 5222)
val searchDataByPageId = WikipediaClient.SearchDataByPageId(5222)
val answerByPageId = client.searchByPageId(searchDataByPageId)

val searchDataByTitle = WikipediaClient.SearchDataByParam(title = "Departments of Colombia")
val searchDataByTitle = WikipediaClient.SearchDataByTitle("Departments of Colombia")
val answerByTitle = client.searchByTitle(searchDataByTitle)

println(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package com.xebia.functional.xef.reasoning.wikipedia

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

@Serializable
data class Page(
@SerialName("pageid") val pageId: Int,
val title: String,
@SerialName("extract") val document: String
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package com.xebia.functional.xef.reasoning.wikipedia

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

@Serializable
data class SearchByParamResult(@SerialName("query") val searchResults: SearchByParamResults)
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package com.xebia.functional.xef.reasoning.wikipedia

import kotlinx.serialization.Serializable

@Serializable data class SearchByParamResults(val pages: Map<String, Page>)
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.xebia.functional.xef.reasoning.wikipedia

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

@Serializable
data class SearchData(
val title: String,
@SerialName("pageid") val pageId: Int,
@SerialName("size") val size: Int,
@SerialName("wordcount") val wordCount: Int,
@SerialName("snippet") val document: String
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package com.xebia.functional.xef.reasoning.wikipedia

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

@Serializable data class SearchResult(@SerialName("query") val searchResults: SearchResults)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package com.xebia.functional.xef.reasoning.wikipedia

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

@Serializable data class SearchResults(@SerialName("search") val searches: List<SearchData>)
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package com.xebia.functional.xef.reasoning.wikipedia

import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.llm.Chat
import kotlin.jvm.JvmOverloads

expect class SearchWikipedia
@JvmOverloads
constructor(model: Chat, scope: Conversation, maxResultsInContext: Int = 3) : SearchWikipediaTool
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.xebia.functional.xef.reasoning.wikipedia

import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.llm.Chat
import kotlin.jvm.JvmOverloads

expect class SearchWikipediaByPageId @JvmOverloads constructor(model: Chat, scope: Conversation) :
SearchWikipediaByPageIdTool
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package com.xebia.functional.xef.reasoning.wikipedia

import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.llm.Chat
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.prompt.templates.assistant
import com.xebia.functional.xef.prompt.templates.system
import com.xebia.functional.xef.prompt.templates.user
import com.xebia.functional.xef.reasoning.tools.Tool
import com.xebia.functional.xef.reasoning.wikipedia.WikipediaClient.SearchDataByPageId

interface SearchWikipediaByPageIdTool : Tool {

val model: Chat
val scope: Conversation
val client: WikipediaClient

override val name: String
get() = "SearchWikipediaByPageId"

override val description: String
get() =
"Search secondary tool in Wikipedia for detail information. The tool input is the number of page id, this tool can only be used with valid Wikipedia page ids returned by the primary search tool"

override suspend fun invoke(input: String): String {
val docs = client.searchByPageId(SearchDataByPageId(input.toInt()))

return model
.promptMessages(
prompt =
Prompt {
+system("Search results:")
+system("Title: ${docs.title}")
+system("PageId: ${docs.pageId}")
+system("Content: ${docs.document}")
+user("input: $input")
+assistant(
"I will select the best search results and reply with information relevant to the `input`"
)
},
scope = scope,
)
.firstOrNull()
?: "No results found"
}

fun close() {
client.close()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.xebia.functional.xef.reasoning.wikipedia

import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.llm.Chat
import kotlin.jvm.JvmOverloads

expect class SearchWikipediaByTitle @JvmOverloads constructor(model: Chat, scope: Conversation) :
SearchWikipediaByTitleTool
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package com.xebia.functional.xef.reasoning.wikipedia

import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.llm.Chat
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.prompt.templates.assistant
import com.xebia.functional.xef.prompt.templates.system
import com.xebia.functional.xef.prompt.templates.user
import com.xebia.functional.xef.reasoning.tools.Tool
import com.xebia.functional.xef.reasoning.wikipedia.WikipediaClient.SearchDataByTitle

interface SearchWikipediaByTitleTool : Tool {

val model: Chat
val scope: Conversation
val client: WikipediaClient

override val name: String
get() = "SearchWikipediaByTitle"

override val description: String
get() =
"Search secondary tool in Wikipedia for detail information. The tool input is the title of the page, this tool can only be used with valid Wikipedia page titles returned by the primary search tool"

override suspend fun invoke(input: String): String {
val docs = client.searchByTitle(SearchDataByTitle(input))

return model
.promptMessages(
prompt =
Prompt {
+system("Search results:")
+system("Title: ${docs.title}")
+system("PageId: ${docs.pageId}")
+system("Content: ${docs.document}")
+user("input: $input")
+assistant(
"I will select the best search results and reply with information relevant to the `input`"
)
},
scope = scope,
)
.firstOrNull()
?: "No results found"
}

fun close() {
client.close()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package com.xebia.functional.xef.reasoning.wikipedia

import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.llm.Chat
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.prompt.templates.assistant
import com.xebia.functional.xef.prompt.templates.system
import com.xebia.functional.xef.prompt.templates.user
import com.xebia.functional.xef.reasoning.tools.Tool
import com.xebia.functional.xef.reasoning.wikipedia.WikipediaClient.SearchData

interface SearchWikipediaTool : Tool {

val model: Chat
val scope: Conversation
val maxResultsInContext: Int
val client: WikipediaClient

override val name: String
get() = "SearchWikipediaTool"

override val description: String
get() =
"Search primary tool in Wikipedia for information. The tool input is a simple one line string"

override suspend fun invoke(input: String): String {
val docs = client.search(SearchData(input))

return model
.promptMessages(
prompt =
Prompt {
+system("Search results:")
docs.searchResults.searches.take(maxResultsInContext).forEach {
+system("Title: ${it.title}")
+system("PageId: ${it.pageId}")
+system("Size: ${it.size}")
+system("WordCount: ${it.wordCount}")
+system("Content: ${it.document}")
}
+user("input: $input")
+assistant(
"I will select the best search results and reply with information relevant to the `input`"
)
},
scope = scope,
)
.firstOrNull()
?: "No results found"
}

fun close() {
client.close()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ import io.ktor.client.plugins.contentnegotiation.*
import io.ktor.client.request.*
import io.ktor.http.*
import io.ktor.serialization.kotlinx.json.*
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json

class WikipediaClient : AutoCloseable, AutoClose by autoClose() {
Expand Down Expand Up @@ -39,32 +37,9 @@ class WikipediaClient : AutoCloseable, AutoClose by autoClose() {

data class SearchData(val search: String)

data class SearchDataByParam(val pageId: Int? = null, val title: String? = null)
data class SearchDataByPageId(val pageId: Int)

@Serializable data class SearchResult(@SerialName("query") val searchResults: SearchResults)

@Serializable data class SearchResults(@SerialName("search") val searches: List<Search>)

@Serializable
data class Search(
val title: String,
@SerialName("pageid") val pageId: Int,
@SerialName("size") val size: Int,
@SerialName("wordcount") val wordCount: Int,
@SerialName("snippet") val document: String
)

@Serializable
data class SearchByParamResult(@SerialName("query") val searchResults: SearchByParamResults)

@Serializable data class SearchByParamResults(val pages: Map<String, Page>)

@Serializable
data class Page(
@SerialName("pageid") val pageId: Int,
val title: String,
@SerialName("extract") val document: String
)
data class SearchDataByTitle(val title: String)

suspend fun search(searchData: SearchData): SearchResult {
return http
Expand All @@ -76,7 +51,7 @@ class WikipediaClient : AutoCloseable, AutoClose by autoClose() {
.body<SearchResult>()
}

suspend fun searchByPageId(searchData: SearchDataByParam): Page {
suspend fun searchByPageId(searchData: SearchDataByPageId): Page {
return http
.get(
"https://en.wikipedia.org/w/api.php?action=query&format=json&prop=extracts&exintro&explaintext&redirects=1&pageids=${searchData.pageId}"
Expand All @@ -89,10 +64,10 @@ class WikipediaClient : AutoCloseable, AutoClose by autoClose() {
.firstNotNullOf { it.value }
}

suspend fun searchByTitle(searchData: SearchDataByParam): Page {
suspend fun searchByTitle(searchData: SearchDataByTitle): Page {
return http
.get(
"https://en.wikipedia.org/w/api.php?action=query&format=json&prop=extracts&exintro&explaintext&redirects=1&titles=${searchData.title?.encodeURLQueryComponent()}"
"https://en.wikipedia.org/w/api.php?action=query&format=json&prop=extracts&exintro&explaintext&redirects=1&titles=${searchData.title.encodeURLQueryComponent()}"
) {
contentType(ContentType.Application.Json)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package com.xebia.functional.xef.reasoning.wikipedia

import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.llm.Chat

actual class SearchWikipedia
actual constructor(
override val model: Chat,
override val scope: Conversation,
override val maxResultsInContext: Int,
) : SearchWikipediaTool, AutoCloseable {
override val client: WikipediaClient = WikipediaClient()

override fun close() {
client.close()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package com.xebia.functional.xef.reasoning.wikipedia

import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.llm.Chat

actual class SearchWikipediaByPageId
actual constructor(override val model: Chat, override val scope: Conversation) :
SearchWikipediaByPageIdTool, AutoCloseable {
override val client: WikipediaClient = WikipediaClient()

override fun close() {
client.close()
}
}