Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,30 @@
import org.springaicommunity.mcp.annotation.McpProgress;
import org.springaicommunity.mcp.annotation.McpSampling;

import org.springframework.ai.mcp.annotation.spring.scan.AbstractAnnotatedMethodBeanFactoryInitializationAotProcessor;
import org.springframework.ai.mcp.annotation.spring.scan.AbstractAnnotatedMethodBeanPostProcessor;
import org.springframework.ai.mcp.annotation.spring.scan.AbstractMcpAnnotatedBeans;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.ImportRuntimeHints;

/**
* @author Christian Tzolov
* @author Josh Long
*/
@AutoConfiguration
@ConditionalOnClass(McpLogging.class)
@ConditionalOnProperty(prefix = McpClientAnnotationScannerProperties.CONFIG_PREFIX, name = "enabled",
havingValue = "true", matchIfMissing = true)
@EnableConfigurationProperties(McpClientAnnotationScannerProperties.class)
@ImportRuntimeHints(McpClientAnnotationScannerAutoConfiguration.AnnotationHints.class)
public class McpClientAnnotationScannerAutoConfiguration {

private static final Set<Class<? extends Annotation>> CLIENT_MCP_ANNOTATIONS = Set.of(McpLogging.class,
Expand All @@ -54,15 +61,30 @@ public ClientMcpAnnotatedBeans clientAnnotatedBeans() {

@Bean
@ConditionalOnMissingBean
public ClientAnnotatedMethodBeanPostProcessor clientAnnotatedMethodBeanPostProcessor(
public static ClientAnnotatedMethodBeanPostProcessor clientAnnotatedMethodBeanPostProcessor(
ClientMcpAnnotatedBeans clientMcpAnnotatedBeans, McpClientAnnotationScannerProperties properties) {
return new ClientAnnotatedMethodBeanPostProcessor(clientMcpAnnotatedBeans, CLIENT_MCP_ANNOTATIONS);
}

@Bean
static ClientAnnotatedBeanFactoryInitializationAotProcessor clientAnnotatedBeanFactoryInitializationAotProcessor() {
return new ClientAnnotatedBeanFactoryInitializationAotProcessor(CLIENT_MCP_ANNOTATIONS);
}

public static class ClientMcpAnnotatedBeans extends AbstractMcpAnnotatedBeans {

}

public static class ClientAnnotatedBeanFactoryInitializationAotProcessor
extends AbstractAnnotatedMethodBeanFactoryInitializationAotProcessor {

public ClientAnnotatedBeanFactoryInitializationAotProcessor(
Set<Class<? extends Annotation>> targetAnnotations) {
super(targetAnnotations);
}

}

public static class ClientAnnotatedMethodBeanPostProcessor extends AbstractAnnotatedMethodBeanPostProcessor {

public ClientAnnotatedMethodBeanPostProcessor(ClientMcpAnnotatedBeans clientMcpAnnotatedBeans,
Expand All @@ -72,4 +94,13 @@ public ClientAnnotatedMethodBeanPostProcessor(ClientMcpAnnotatedBeans clientMcpA

}

static class AnnotationHints implements RuntimeHintsRegistrar {

@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
CLIENT_MCP_ANNOTATIONS.forEach(an -> hints.reflection().registerType(an, MemberCategory.values()));
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,6 @@ void streamableHttpTest() {

mcpClient.ping();

System.out.println("mcpClient = " + mcpClient.getServerInfo());

ListToolsResult toolsResult = mcpClient.listTools();

assertThat(toolsResult).isNotNull();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,6 @@ void streamableHttpTest() {

mcpClient.ping();

System.out.println("mcpClient = " + mcpClient.getServerInfo());

ListToolsResult toolsResult = mcpClient.listTools();

assertThat(toolsResult).isNotNull();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,6 @@ void streamableHttpTest() {

mcpClient.ping();

System.out.println("mcpClient = " + mcpClient.getServerInfo());

ListToolsResult toolsResult = mcpClient.listTools();

assertThat(toolsResult).isNotNull();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ void streamableHttpTest() {

mcpClient.ping();

System.out.println("mcpClient = " + mcpClient.getServerInfo());

ListToolsResult toolsResult = mcpClient.listTools();

assertThat(toolsResult).isNotNull();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,30 @@
import org.springaicommunity.mcp.annotation.McpResource;
import org.springaicommunity.mcp.annotation.McpTool;

import org.springframework.ai.mcp.annotation.spring.scan.AbstractAnnotatedMethodBeanFactoryInitializationAotProcessor;
import org.springframework.ai.mcp.annotation.spring.scan.AbstractAnnotatedMethodBeanPostProcessor;
import org.springframework.ai.mcp.annotation.spring.scan.AbstractMcpAnnotatedBeans;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.ImportRuntimeHints;

/**
* @author Christian Tzolov
* @author Josh Long
*/
@AutoConfiguration
@ConditionalOnClass(McpTool.class)
@ConditionalOnProperty(prefix = McpServerAnnotationScannerProperties.CONFIG_PREFIX, name = "enabled",
havingValue = "true", matchIfMissing = true)
@EnableConfigurationProperties(McpServerAnnotationScannerProperties.class)
@ImportRuntimeHints(McpServerAnnotationScannerAutoConfiguration.AnnotationHints.class)
public class McpServerAnnotationScannerAutoConfiguration {

private static final Set<Class<? extends Annotation>> SERVER_MCP_ANNOTATIONS = Set.of(McpTool.class,
Expand All @@ -54,15 +61,30 @@ public ServerMcpAnnotatedBeans serverAnnotatedBeanRegistry() {

@Bean
@ConditionalOnMissingBean
public ServerAnnotatedMethodBeanPostProcessor serverAnnotatedMethodBeanPostProcessor(
public static ServerAnnotatedMethodBeanPostProcessor serverAnnotatedMethodBeanPostProcessor(
ServerMcpAnnotatedBeans serverMcpAnnotatedBeans, McpServerAnnotationScannerProperties properties) {
return new ServerAnnotatedMethodBeanPostProcessor(serverMcpAnnotatedBeans, SERVER_MCP_ANNOTATIONS);
}

@Bean
public static ServerAnnotatedBeanFactoryInitializationAotProcessor serverAnnotatedBeanFactoryInitializationAotProcessor() {
return new ServerAnnotatedBeanFactoryInitializationAotProcessor(SERVER_MCP_ANNOTATIONS);
}

public static class ServerMcpAnnotatedBeans extends AbstractMcpAnnotatedBeans {

}

public static class ServerAnnotatedBeanFactoryInitializationAotProcessor
extends AbstractAnnotatedMethodBeanFactoryInitializationAotProcessor {

public ServerAnnotatedBeanFactoryInitializationAotProcessor(
Set<Class<? extends Annotation>> targetAnnotations) {
super(targetAnnotations);
}

}

public static class ServerAnnotatedMethodBeanPostProcessor extends AbstractAnnotatedMethodBeanPostProcessor {

public ServerAnnotatedMethodBeanPostProcessor(ServerMcpAnnotatedBeans serverMcpAnnotatedBeans,
Expand All @@ -72,4 +94,13 @@ public ServerAnnotatedMethodBeanPostProcessor(ServerMcpAnnotatedBeans serverMcpA

}

static class AnnotationHints implements RuntimeHintsRegistrar {

@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
SERVER_MCP_ANNOTATIONS.forEach(an -> hints.reflection().registerType(an, MemberCategory.values()));
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@ static class SyncServerSpecificationConfiguration {
@Bean
public List<McpServerFeatures.SyncResourceSpecification> resourceSpecs(
ServerMcpAnnotatedBeans beansWithMcpMethodAnnotations) {
return SyncMcpAnnotationProviders

List<McpServerFeatures.SyncResourceSpecification> syncResourceSpecifications = SyncMcpAnnotationProviders
.resourceSpecifications(beansWithMcpMethodAnnotations.getBeansByAnnotation(McpResource.class));
return syncResourceSpecifications;
}

@Bean
Expand All @@ -75,8 +77,10 @@ public List<McpServerFeatures.SyncCompletionSpecification> completionSpecs(
@Bean
public List<McpServerFeatures.SyncToolSpecification> toolSpecs(
ServerMcpAnnotatedBeans beansWithMcpMethodAnnotations) {
return SyncMcpAnnotationProviders
.toolSpecifications(beansWithMcpMethodAnnotations.getBeansByAnnotation(McpTool.class));
List<Object> beansByAnnotation = beansWithMcpMethodAnnotations.getBeansByAnnotation(McpTool.class);
List<McpServerFeatures.SyncToolSpecification> syncToolSpecifications = SyncMcpAnnotationProviders
.toolSpecifications(beansByAnnotation);
return syncToolSpecifications;
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,10 @@ public List<McpStatelessServerFeatures.SyncCompletionSpecification> completionSp
@Bean
public List<McpStatelessServerFeatures.SyncToolSpecification> toolSpecs(
ServerMcpAnnotatedBeans beansWithMcpMethodAnnotations) {
return SyncMcpAnnotationProviders
.statelessToolSpecifications(beansWithMcpMethodAnnotations.getBeansByAnnotation(McpTool.class));
List<Object> beansByAnnotation = beansWithMcpMethodAnnotations.getBeansByAnnotation(McpTool.class);
List<McpStatelessServerFeatures.SyncToolSpecification> syncToolSpecifications = SyncMcpAnnotationProviders
.statelessToolSpecifications(beansByAnnotation);
return syncToolSpecifications;
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,11 @@

package org.springframework.ai.mcp.aot;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;

import io.modelcontextprotocol.spec.McpSchema;

import org.springframework.ai.aot.AiRuntimeHints;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar;
Expand Down Expand Up @@ -65,45 +62,10 @@ public class McpHints implements RuntimeHintsRegistrar {
public void registerHints(RuntimeHints hints, @Nullable ClassLoader classLoader) {
var mcs = MemberCategory.values();

for (var tr : innerClasses(McpSchema.class)) {
Set<TypeReference> typeReferences = AiRuntimeHints.findInnerClassesFor(McpSchema.class);
for (var tr : typeReferences) {
hints.reflection().registerType(tr, mcs);
}
}

/**
* Discovers all inner classes of a given class.
* <p>
* This method recursively finds all nested classes (both declared and inherited) of
* the provided class and converts them to type references.
* @param clazz the class to find inner classes for
* @return a set of type references for all discovered inner classes
*/
private Set<TypeReference> innerClasses(Class<?> clazz) {
var indent = new HashSet<String>();
this.findNestedClasses(clazz, indent);
return indent.stream().map(TypeReference::of).collect(Collectors.toSet());
}

/**
* Recursively finds all nested classes of a given class.
* <p>
* This method:
* <ol>
* <li>Collects both declared and inherited nested classes</li>
* <li>Recursively processes each nested class</li>
* <li>Adds the class names to the provided set</li>
* </ol>
* @param clazz the class to find nested classes for
* @param indent the set to collect class names in
*/
private void findNestedClasses(Class<?> clazz, Set<String> indent) {
var classes = new ArrayList<Class<?>>();
classes.addAll(Arrays.asList(clazz.getDeclaredClasses()));
classes.addAll(Arrays.asList(clazz.getClasses()));
for (var nestedClass : classes) {
this.findNestedClasses(nestedClass, indent);
}
indent.addAll(classes.stream().map(Class::getName).toList());
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright 2025-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.mcp.annotation.spring.scan;

import java.lang.annotation.Annotation;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.core.log.LogAccessor;

/**
* @author Josh Long
*/
public class AbstractAnnotatedMethodBeanFactoryInitializationAotProcessor extends AnnotatedMethodDiscovery
implements BeanFactoryInitializationAotProcessor {

private static final LogAccessor logger = new LogAccessor(AbstractAnnotatedMethodBeanPostProcessor.class);

public AbstractAnnotatedMethodBeanFactoryInitializationAotProcessor(
Set<Class<? extends Annotation>> targetAnnotations) {
super(targetAnnotations);
}

@Override
public BeanFactoryInitializationAotContribution processAheadOfTime(ConfigurableListableBeanFactory beanFactory) {
List<Class<?>> types = new ArrayList<>();
for (String beanName : beanFactory.getBeanDefinitionNames()) {
Class<?> beanClass = beanFactory.getType(beanName);
Set<Class<? extends Annotation>> classes = this.scan(beanClass);
if (!classes.isEmpty()) {
types.add(beanClass);
}
}
return (generationContext, beanFactoryInitializationCode) -> {
RuntimeHints runtimeHints = generationContext.getRuntimeHints();
for (Class<?> typeReference : types) {
runtimeHints.reflection().registerType(typeReference, MemberCategory.values());
logger.info("registering " + typeReference.getName() + " for reflection");
}
};
}

}
Loading