![Redis](https://redis.io/wp-content/uploads/2024/04/Logotype.svg?auto=webp&quality=85,75&width=120)

# RAG with Spring AI and Redis

This notebook demonstrates how to build a Retrieval-Augmented Generation (RAG) system using Spring AI and Redis. The example focuses on creating a beer recommendation chatbot that can answer questions about beers by retrieving relevant information from a database.

## Maven Dependencies

The notebook requires several dependencies:

- Spring AI OpenAI: To interact with OpenAI's language models
- Spring AI Transformers: For embedding generation using local models
- Spring AI Redis Store: To use Redis as a vector database
- SLF4J: For logging
- Jedis: Redis client for Java

In [1]:
%mavenRepo spring_milestones https://repo.spring.io/milestone/   
%maven "org.springframework.ai:spring-ai-openai:1.0.0-M6"
%maven "org.springframework.ai:spring-ai-transformers:1.0.0-M6"
%maven "org.springframework.ai:spring-ai-redis-store:1.0.0-M6"
%maven "org.slf4j:slf4j-simple:2.0.17"    
%maven "redis.clients:jedis:5.2.0"

## Setting up the OpenAI Chat Model

To run the code below, your OpenAI API key must be available in the `OPENAI_API_KEY` environment variable.

In [2]:
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;

var openAiApi = new OpenAiApi(System.getenv("OPENAI_API_KEY"));

var openAiChatOptions = OpenAiChatOptions.builder()
    .model("gpt-3.5-turbo")
    .temperature(0.4)
    .maxTokens(200)
    .build();

var chatModel = OpenAiChatModel.builder()
    .openAiApi(openAiApi)
    .defaultOptions(openAiChatOptions)
    .build();
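
If the environment variable is missing, `System.getenv` returns `null` and the failure only surfaces later, when the first request is made. A minimal sketch of an early check follows; `EnvCheck` and `require` are hypothetical helper names, not part of Spring AI.

```java
import java.util.Map;

// Hypothetical helper (not part of Spring AI): fail fast when a required variable is missing.
final class EnvCheck {

    /** Returns the value of `name` from `env`, or throws if it is absent or blank. */
    static String require(Map<String, String> env, String name) {
        String value = env.get(name);
        if (value == null || value.isBlank()) {
            throw new IllegalStateException("Environment variable " + name + " is not set");
        }
        return value;
    }
}
```

With this in place, the key could be read as `EnvCheck.require(System.getenv(), "OPENAI_API_KEY")` before constructing `OpenAiApi`.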

## Setting up the Embedding Model

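This cell initializes the transformer-based embedding model. Unlike the chat model, which calls OpenAI's API, the embedding model runs locally using a Hugging Face transformer model. As the log output below shows, the `all-MiniLM-L6-v2` ONNX model and its tokenizer are downloaded and cached on first use.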

In [3]:
import org.springframework.ai.transformers.TransformersEmbeddingModel;

var embeddingModel = new TransformersEmbeddingModel();
embeddingModel.afterPropertiesSet();

[JJava-executor-0] INFO org.springframework.ai.transformers.ResourceCacheService - Create cache root directory: /tmp/spring-ai-onnx-generative
[JJava-executor-0] INFO org.springframework.ai.transformers.ResourceCacheService - Caching the URL [https://raw.githubusercontent.com/spring-projects/spring-ai/main/models/spring-ai-transformers/src/main/resources/onnx/all-MiniLM-L6-v2/tokenizer.json] resource to: /tmp/spring-ai-onnx-generative/4d42ba07-cb22-352f-bb44-beccc8c8c0b7/tokenizer.json
[JJava-executor-0] INFO ai.djl.util.Platform - Found matching platform from: jar:file:/home/jovyan/.ivy2/cache/ai.djl.huggingface/tokenizers/jars/tokenizers-0.30.0.jar!/native/lib/tokenizers.properties
[JJava-executor-0] INFO org.springframework.ai.transformers.ResourceCacheService - Caching the URL [https://github.com/spring-projects/spring-ai/raw/main/models/spring-ai-transformers/src/main/resources/onnx/all-MiniLM-L6-v2/model.onnx] resource to: /tmp/spring-ai-onnx-generative/eb4e1bd7-63c5-301b-8383-5d

## Testing the Embedding Model

Generate vector embeddings for two sample phrases.

In [4]:
List<float[]> embeddings = embeddingModel.embed(List.of("Hello world", "World is big"));

[JJava-executor-0] INFO ai.djl.pytorch.engine.PtEngine - PyTorch graph executor optimizer is enabled, this may impact your inference latency and throughput. See: https://docs.djl.ai/master/docs/development/inference_performance_optimization.html#graph-executor-optimization
[JJava-executor-0] INFO ai.djl.pytorch.engine.PtEngine - Number of inter-op threads is 12
[JJava-executor-0] INFO ai.djl.pytorch.engine.PtEngine - Number of intra-op threads is 12


In [5]:
embeddings.size()

2

In [6]:
float[] e0 = embeddings.get(0);
Arrays.toString(e0);

[-0.19744644, 0.17766532, 0.03857004, 0.1495222, -0.22542009, -0.918028, 0.38326377, -0.03688945, -0.271742, 0.084521994, 0.40589252, 0.31799775, 0.10991715, -0.15033704, -0.0578956, -0.1542844, 0.1277511, -0.12728858, -0.85726726, -0.100180045, 0.043960992, 0.31126785, 0.018637724, 0.18169005, -0.4846143, -0.16840324, 0.29548055, 0.27559924, -0.01898329, -0.33375576, 0.24035157, 0.12719727, 0.7341182, -0.12793198, -0.06675415, 0.3603812, -0.18827778, -0.52243793, -0.17853652, 0.301802, 0.2693615, -0.48221794, -0.17212732, -0.11880259, 0.054506138, -0.021313868, 0.042054005, 0.22520447, 0.53416646, -0.02169647, -0.30204588, -0.3324908, -0.039310955, 0.030255951, 0.47471577, 0.11088768, 0.03599049, -0.059162557, 0.05172684, -0.21580887, -0.2588888, 0.13753763, -0.03976778, 0.077264294, 0.5730004, -0.41052252, -0.12424426, 0.18107419, -0.29570377, -0.47102028, -0.3762157, -0.0566694, 0.03330949, 0.42123562, -0.19500081, 0.14251879, 0.08297111, 0.15151738, 0.055302583, 0.17305022, 0.30240
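
The raw numbers are hard to interpret on their own. A common way to compare two embeddings is cosine similarity, where values closer to 1 indicate closer meaning. The sketch below is an illustrative helper (the class name `EmbeddingMath` is an assumption, not a Spring AI API), written in plain Java so it can be pasted into a cell.

```java
// Illustrative helper, not part of Spring AI: cosine similarity of two embedding vectors.
final class EmbeddingMath {

    /** Returns the cosine of the angle between a and b, in the range [-1, 1]. */
    static double cosineSimilarity(float[] a, float[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("Vectors must have the same length");
        }
        double dot = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        for (int i = 0; i < a.length; i++) {
            dot += (double) a[i] * b[i];
            normA += (double) a[i] * a[i];
            normB += (double) b[i] * b[i];
        }
        if (normA == 0.0 || normB == 0.0) {
            throw new IllegalArgumentException("A zero vector has no direction");
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }
}
```

For the two phrases embedded above, the call would be `EmbeddingMath.cosineSimilarity(embeddings.get(0), embeddings.get(1))`.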

## Configuring Redis Vector Store

This cell connects to a Redis server at hostname "redis-java" on port 6379 and creates a vector store for storing and retrieving embeddings, with:

- A Redis index named "beers"
- A prefix of "beer:" for all keys
- Automatic schema initialization

In [7]:
import redis.clients.jedis.JedisPooled;
import org.springframework.ai.vectorstore.redis.RedisVectorStore;

var jedisPooled = new JedisPooled("redis-java", 6379);

var vectorStore = RedisVectorStore.builder(jedisPooled, embeddingModel)
        .indexName("beers")              
        .prefix("beer:") 
        .initializeSchema(true) 
        .build();

vectorStore.afterPropertiesSet();

## Loading Beer Data into Redis

- Defines the relevant fields to extract from the beer JSON data
- Checks if embeddings are already loaded in Redis by querying the index information
- If not loaded:
  - Opens the compressed beer data file
  - Creates a JSON reader to parse the file and extract the specified fields
  - Adds the documents to the vector store, which automatically:
    - Creates embeddings for each document
    - Stores both the documents and their embeddings in Redis

In [8]:
import java.io.File;
import java.io.FileInputStream;
import java.util.Map;
import java.util.zip.GZIPInputStream;

import org.springframework.ai.reader.JsonReader;
import org.springframework.core.io.InputStreamResource;
import org.springframework.core.io.FileSystemResource;

// Define the keys we want to extract from the JSON
String[] KEYS = { "name", "abv", "ibu", "description" };

// Data path
String filePath = "../resources/beers.json.gz";

// Check if embeddings are already loaded
Map<String, Object> indexInfo = vectorStore.getJedis().ftInfo("beers");
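// Parse via String so that numeric and string values (and the "0" default) are all handled
long numDocs = Long.parseLong(String.valueOf(indexInfo.getOrDefault("num_docs", "0")));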
if (numDocs > 20000) {
    System.out.println("Embeddings already loaded. Skipping");
} else {
    System.out.println("Creating Embeddings...");
    
    // Open the data file (the path is relative to the notebook's working directory)
    File file = new File(filePath);
    
    // Create a GZIPInputStream
    GZIPInputStream inputStream = new GZIPInputStream(new FileInputStream(file));
    InputStreamResource resource = new InputStreamResource(inputStream);
    
    // Create a JSON reader with fields relevant to our use case
    JsonReader loader = new JsonReader(resource, KEYS);
    
    // Use the VectorStore to insert the documents into Redis
    vectorStore.add(loader.get());
    
    System.out.println("Embeddings created.");
}

Embeddings already loaded. Skipping
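
The loading cell opens a `GZIPInputStream` without closing it explicitly. When reading gzip data outside of a resource wrapper, a try-with-resources block makes sure the stream is released even if reading fails. The sketch below uses in-memory data and a hypothetical `GzipText` helper purely for illustration; it is not how the notebook loads the beer file.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

// Hypothetical helper for illustration: gzip round trip with try-with-resources.
final class GzipText {

    /** Compresses UTF-8 text into gzip bytes (used here only to create sample data). */
    static byte[] compress(String text) {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        try (GZIPOutputStream gzip = new GZIPOutputStream(buffer)) {
            gzip.write(text.getBytes(StandardCharsets.UTF_8));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        // The gzip stream is closed at this point, so the trailer has been written.
        return buffer.toByteArray();
    }

    /** Reads a gzip stream fully as UTF-8 text; the stream is closed even on failure. */
    static String decompress(InputStream compressed) {
        try (GZIPInputStream gzip = new GZIPInputStream(compressed)) {
            return new String(gzip.readAllBytes(), StandardCharsets.UTF_8);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

For a file on disk, `decompress(new FileInputStream(file))` would follow the same pattern.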


## Define the System Prompt

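The system prompt constrains the LLM's behavior: it tells the model to answer from the retrieved documents, which are substituted for the `{documents}` placeholder at query time, and to admit when it does not know.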

In [9]:
String systemPrompt = """
    You're assisting with questions about products in a beer catalog.
    Use the information from the DOCUMENTS section to provide accurate answers.
    If the answer involves referring to the ABV or IBU of the beer, include the beer name in the response.
    If unsure, simply state that you don't know.
    
    DOCUMENTS:
    {documents}
    """;

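Conceptually, filling the `{documents}` placeholder is a string substitution. The sketch below shows the idea with plain `String.replace`; `PromptSketch` is a hypothetical name, and the template class used later in this notebook has its own, more capable rendering, so treat this only as a mental model.

```java
import java.util.Map;

// Hypothetical illustration of placeholder substitution; not the Spring AI implementation.
final class PromptSketch {

    /** Replaces each {key} placeholder with its value; unknown placeholders are left as-is. */
    static String render(String template, Map<String, String> values) {
        String result = template;
        for (Map.Entry<String, String> entry : values.entrySet()) {
            result = result.replace("{" + entry.getKey() + "}", entry.getValue());
        }
        return result;
    }
}
```

For example, `PromptSketch.render(systemPrompt, Map.of("documents", "..."))` would return the prompt text with the retrieved documents in place of the placeholder.
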
## Setting up the Chat Client with the created ChatModel

In [10]:
import org.springframework.ai.chat.client.ChatClient;

ChatClient chatClient = ChatClient.builder(chatModel)
    .build();

## Creating a Query Function

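Encapsulate the RAG logic in a single method: retrieve the 10 documents most similar to the query from Redis, insert their text into the system prompt, and send that prompt together with the user's question to the chat model.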

In [11]:
import java.util.stream.Collectors;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;

void ask(String query) {
    SearchRequest request = SearchRequest.builder().query(query).topK(10).build();

    // Query Redis for the top K documents most relevant to the input message
    List<Document> docs = vectorStore.similaritySearch(request);
    
    String documents = docs.stream() //
        .map(Document::getText) //
        .collect(Collectors.joining("\n"));
    
    SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemPrompt);
    Message systemMessage = systemPromptTemplate.createMessage(Map.of("documents", documents));
    
    UserMessage userMessage = new UserMessage(query);
    // Combine the system and user messages into the complete prompt
    Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
    // Call the chat client with the prompt
    ChatResponse chatResponse = chatClient.prompt(prompt).call().chatResponse();
    
    System.out.println(chatResponse.getResult().getOutput().getText());
}
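
To make the retrieval step less of a black box, here is a brute-force, in-memory sketch of what a top-K similarity search does: score every stored vector against the query vector and keep the K best. `TopKSketch` and its methods are hypothetical names. Redis answers the same kind of query through its search index rather than by scanning every vector, so this is a conceptual model, not the store's implementation.

```java
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

// Hypothetical illustration of top-K retrieval by cosine similarity.
final class TopKSketch {

    /** Cosine similarity; returns 0 when either vector has zero length. */
    static double cosine(float[] a, float[] b) {
        double dot = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        for (int i = 0; i < a.length; i++) {
            dot += (double) a[i] * b[i];
            normA += (double) a[i] * a[i];
            normB += (double) b[i] * b[i];
        }
        if (normA == 0.0 || normB == 0.0) {
            return 0.0;
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    /** Returns the keys of the k stored vectors most similar to the query, best match first. */
    static List<String> topK(float[] query, Map<String, float[]> store, int k) {
        return store.entrySet().stream()
            .sorted(Comparator
                .comparingDouble((Map.Entry<String, float[]> e) -> cosine(query, e.getValue()))
                .reversed())
            .limit(k)
            .map(Map.Entry::getKey)
            .collect(Collectors.toList());
    }
}
```

In `ask`, the embedding of `query` plays the role of the query vector and `topK(10)` plays the role of `k`.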

## 🍺 Now let's talk about Beers!

In [12]:
ask("What beer pairs well with smoked meats?");

A beer that pairs well with smoked meats is the "Oak Smoker," with an ABV of 11.5%. This Smoked Wee Heavy has a wonderfully subtle smoky background and rich malty flavors, making it a perfect pairing for BBQ or enjoying on its own.


In [13]:
ask("What beer would make me lose weight?");

Beer does not typically aid in weight loss as it contains calories. However, lower alcohol content beers like the Airship Cream Ale with an ABV of 4.5 might be a lighter option compared to higher ABV beers.
