diff --git a/src/main/kotlin/cc/unitmesh/cf/domains/DomainDispatcher.kt b/src/main/kotlin/cc/unitmesh/cf/domains/DomainDispatcher.kt index f9bfe981..b90e0e1a 100644 --- a/src/main/kotlin/cc/unitmesh/cf/domains/DomainDispatcher.kt +++ b/src/main/kotlin/cc/unitmesh/cf/domains/DomainDispatcher.kt @@ -1,27 +1,31 @@ package cc.unitmesh.cf.domains -import cc.unitmesh.cf.core.Domain import cc.unitmesh.cf.factory.process.DomainDetector import cc.unitmesh.cf.factory.process.DomainDetectorPlaceholder import cc.unitmesh.cf.infrastructure.cache.CachedEmbedding import cc.unitmesh.cf.infrastructure.llms.embedding.Embedding -import org.springframework.stereotype.Component import org.reflections.Reflections -import org.reflections.scanners.SubTypesScanner +import org.springframework.stereotype.Component @Component class DomainDispatcher( private val cachedEmbedding: CachedEmbedding, ) { + val cachedDomains: MutableList> = mutableListOf() fun dispatch(question: String): DomainDetector { val question: Embedding = cachedEmbedding.createEmbedding(question) return DomainDetectorPlaceholder() } fun lookupDomains(): List> { - val reflections = Reflections(DomainDispatcher::class.java.`package`.name, SubTypesScanner(false)) + if (cachedDomains.isNotEmpty()) { + return cachedDomains + } - return reflections.getSubTypesOf(DomainDetector::class.java) + val domains = Reflections(DomainDispatcher::class.java.`package`.name).getSubTypesOf(DomainDetector::class.java) .toList() + + this.cachedDomains.addAll(domains) + return domains } } \ No newline at end of file diff --git a/src/test/kotlin/cc/unitmesh/cf/domains/DomainDispatcherTest.kt b/src/test/kotlin/cc/unitmesh/cf/domains/DomainDispatcherTest.kt index 59d47826..54daae23 100644 --- a/src/test/kotlin/cc/unitmesh/cf/domains/DomainDispatcherTest.kt +++ b/src/test/kotlin/cc/unitmesh/cf/domains/DomainDispatcherTest.kt @@ -1,13 +1,11 @@ package cc.unitmesh.cf.domains import cc.unitmesh.cf.infrastructure.cache.CachedEmbedding -import io.kotest.matchers.ints.shouldBeGreaterThan import io.kotest.matchers.ints.shouldBeGreaterThanOrEqual import io.mockk.MockKAnnotations import io.mockk.impl.annotations.MockK import io.mockk.junit5.MockKExtension import org.junit.jupiter.api.BeforeEach -import org.junit.jupiter.api.Disabled import org.junit.jupiter.api.Test import org.junit.jupiter.api.extension.ExtendWith