Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Add multi modal RAG #280

Merged
merged 5 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changeset/large-plums-drum.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"llamaindex": patch
---

Added support for multi-modal RAG (retriever and query engine), including an example
Fixed persisting and loading image vector stores
Binary file removed examples/data/multi_modal/1.jpg
Binary file not shown.
Binary file removed examples/data/multi_modal/2.jpg
Binary file not shown.
Binary file removed examples/data/multi_modal/3.jpg
Binary file not shown.
323 changes: 0 additions & 323 deletions examples/data/multi_modal/San Francisco.txt

This file was deleted.

File renamed without changes.
File renamed without changes
File renamed without changes
File renamed without changes
46 changes: 46 additions & 0 deletions examples/multimodal/load.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import {
ServiceContext,
serviceContextFromDefaults,
SimpleDirectoryReader,
storageContextFromDefaults,
VectorStoreIndex,
} from "llamaindex";
import * as path from "path";

/**
 * Measures the wall-clock runtime of `func` in milliseconds.
 *
 * @param func - Synchronous or async function to time; its result is ignored.
 * @returns Elapsed time in milliseconds.
 */
async function getRuntime(func: () => unknown): Promise<number> {
  const start = Date.now();
  // `await` handles both sync and async callbacks uniformly.
  await func();
  return Date.now() - start;
}

/**
 * Builds the vector store index from the documents in the example data
 * directory and persists it to the "storage" directory (including the
 * image vector store, via `storeImages: true`).
 */
async function generateDatasource(serviceContext: ServiceContext) {
  console.log(`Generating storage...`);
  // Split documents, create embeddings and store them in the storage context
  const elapsedMs = await getRuntime(async () => {
    const reader = new SimpleDirectoryReader();
    const documents = await reader.loadData({
      directoryPath: path.join("multimodal", "data"),
    });
    const storageContext = await storageContextFromDefaults({
      storeImages: true,
      persistDir: "storage",
    });
    await VectorStoreIndex.fromDocuments(documents, {
      storageContext,
      serviceContext,
    });
  });
  console.log(`Storage successfully generated in ${elapsedMs / 1000}s.`);
}

/** Entry point: configure the service context and build the datasource. */
async function main() {
  await generateDatasource(
    serviceContextFromDefaults({ chunkSize: 512, chunkOverlap: 20 }),
  );
  console.log("Finished generating storage.");
}

main().catch(console.error);
58 changes: 58 additions & 0 deletions examples/multimodal/rag.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import {
CallbackManager,
ImageDocument,
ImageType,
MultiModalResponseSynthesizer,
NodeWithScore,
OpenAI,
ServiceContext,
VectorStoreIndex,
serviceContextFromDefaults,
storageContextFromDefaults,
} from "llamaindex";

/**
 * Loads the persisted vector store index from the "storage" directory.
 * The storage context holds two vector stores — one for text nodes and one
 * for image nodes (`storeImages: true`).
 */
export async function createIndex(serviceContext: ServiceContext) {
  const storageContext = await storageContextFromDefaults({
    storeImages: true,
    persistDir: "storage",
  });
  return VectorStoreIndex.init({ nodes: [], serviceContext, storageContext });
}

async function main() {
  // Images retrieved for the query; captured by the onRetrieve callback so
  // they can be reported after the query completes.
  let retrievedImages: ImageType[] = [];
  const callbackManager = new CallbackManager({
    onRetrieve: ({ nodes }) => {
      retrievedImages = [];
      for (const { node } of nodes as NodeWithScore[]) {
        if (node instanceof ImageDocument) {
          retrievedImages.push(node.image);
        }
      }
    },
  });
  const serviceContext = serviceContextFromDefaults({
    llm: new OpenAI({ model: "gpt-4-vision-preview", maxTokens: 512 }),
    chunkSize: 512,
    chunkOverlap: 20,
    callbackManager,
  });
  const index = await createIndex(serviceContext);

  // Retrieve up to 3 text nodes and 1 image node, then synthesize a
  // multi-modal answer with the vision model.
  const queryEngine = index.asQueryEngine({
    responseSynthesizer: new MultiModalResponseSynthesizer({ serviceContext }),
    retriever: index.asRetriever({ similarityTopK: 3, imageSimilarityTopK: 1 }),
  });
  const result = await queryEngine.query(
    "Tell me more about Vincent van Gogh's famous paintings",
  );
  console.log(result.response, "\n");
  for (const image of retrievedImages) {
    console.log(`Image retrieved and used in inference: ${image.toString()}`);
  }
}

main().catch(console.error);
38 changes: 19 additions & 19 deletions examples/multiModal.ts → examples/multimodal/retrieve.ts
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
import {
ImageNode,
serviceContextFromDefaults,
SimpleDirectoryReader,
SimpleVectorStore,
storageContextFromDefaults,
TextNode,
VectorStoreIndex,
} from "llamaindex";
import * as path from "path";

async function main() {
// read data into documents
const reader = new SimpleDirectoryReader();
const documents = await reader.loadData({
directoryPath: "data/multi_modal",
});
export async function createIndex() {
// set up vector store index with two vector stores, one for text, the other for images
const serviceContext = serviceContextFromDefaults({ chunkSize: 512 });
const vectorStore = await SimpleVectorStore.fromPersistDir("./storage/text");
const imageVectorStore =
await SimpleVectorStore.fromPersistDir("./storage/images");
const index = await VectorStoreIndex.fromDocuments(documents, {
const serviceContext = serviceContextFromDefaults({
chunkSize: 512,
chunkOverlap: 20,
});
const storageContext = await storageContextFromDefaults({
persistDir: "storage",
storeImages: true,
});
return await VectorStoreIndex.init({
nodes: [],
storageContext,
serviceContext,
imageVectorStore,
vectorStore,
});
}

async function main() {
// retrieve documents using the index
const retriever = index.asRetriever();
retriever.similarityTopK = 3;
const index = await createIndex();
const retriever = index.asRetriever({ similarityTopK: 3 });
const results = await retriever.retrieve(
"what are Vincent van Gogh's famous paintings",
);
Expand All @@ -36,7 +36,7 @@ async function main() {
continue;
}
if (node instanceof ImageNode) {
console.log(`Image: ${path.join(__dirname, node.id_)}`);
console.log(`Image: ${node.getUrl()}`);
} else if (node instanceof TextNode) {
console.log("Text:", (node as TextNode).text.substring(0, 128));
}
Expand Down
1 change: 1 addition & 0 deletions packages/core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"@xenova/transformers": "^2.10.0",
"assemblyai": "^4.0.0",
"crypto-js": "^4.2.0",
"file-type": "^18.7.0",
"js-tiktoken": "^1.0.8",
"lodash": "^4.17.21",
"mammoth": "^1.6.0",
Expand Down
27 changes: 27 additions & 0 deletions packages/core/src/Node.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import CryptoJS from "crypto-js";
import path from "path";
import { pathToFileURL } from "url";
import { v4 as uuidv4 } from "uuid";

export enum NodeRelationship {
Expand Down Expand Up @@ -304,6 +305,12 @@ export class ImageNode<T extends Metadata = Metadata> extends TextNode<T> {
getType(): ObjectType {
return ObjectType.IMAGE;
}

getUrl(): URL {
  // id_ stores the relative path; resolve it and convert it with
  // pathToFileURL, which percent-encodes special characters (#, ?, spaces)
  // and handles Windows drive/backslash paths — naive `file://${absPath}`
  // string concatenation does neither.
  return pathToFileURL(path.resolve(this.id_));
}
}

export class ImageDocument<T extends Metadata = Metadata> extends ImageNode<T> {
Expand All @@ -327,3 +334,23 @@ export interface NodeWithScore<T extends Metadata = Metadata> {
node: BaseNode<T>;
score?: number;
}

/**
 * Partitions nodes into image nodes and text nodes, preserving order.
 *
 * ImageNode extends TextNode, so the ImageNode check must run first; nodes
 * that are neither ImageNode nor TextNode are silently dropped.
 *
 * @param nodes - Nodes to partition.
 * @returns The image nodes and the text nodes.
 */
export function splitNodesByType(nodes: BaseNode[]): {
  imageNodes: ImageNode[];
  textNodes: TextNode[];
} {
  const imageNodes: ImageNode[] = [];
  const textNodes: TextNode[] = [];

  for (const node of nodes) {
    if (node instanceof ImageNode) {
      imageNodes.push(node);
    } else if (node instanceof TextNode) {
      textNodes.push(node);
    }
  }
  return { imageNodes, textNodes };
}
17 changes: 9 additions & 8 deletions packages/core/src/QueryEngine.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import { v4 as uuidv4 } from "uuid";
import { Event } from "./callbacks/CallbackManager";
import { BaseNodePostprocessor } from "./indices/BaseNodePostprocessor";
import { NodeWithScore, TextNode } from "./Node";
import {
BaseQuestionGenerator,
LLMQuestionGenerator,
SubQuestion,
} from "./QuestionGenerator";
import { Response } from "./Response";
import { CompactAndRefine, ResponseSynthesizer } from "./ResponseSynthesizer";
import { BaseRetriever } from "./Retriever";
import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
import { QueryEngineTool, ToolMetadata } from "./Tool";
import { Event } from "./callbacks/CallbackManager";
import { BaseNodePostprocessor } from "./indices/BaseNodePostprocessor";
import { CompactAndRefine, ResponseSynthesizer } from "./synthesizers";
import { BaseSynthesizer } from "./synthesizers/types";

/**
* A query engine is a question answerer that can use one or more steps.
Expand All @@ -30,13 +31,13 @@ export interface BaseQueryEngine {
*/
export class RetrieverQueryEngine implements BaseQueryEngine {
retriever: BaseRetriever;
responseSynthesizer: ResponseSynthesizer;
responseSynthesizer: BaseSynthesizer;
nodePostprocessors: BaseNodePostprocessor[];
preFilters?: unknown;

constructor(
retriever: BaseRetriever,
responseSynthesizer?: ResponseSynthesizer,
responseSynthesizer?: BaseSynthesizer,
preFilters?: unknown,
nodePostprocessors?: BaseNodePostprocessor[],
) {
Expand Down Expand Up @@ -81,14 +82,14 @@ export class RetrieverQueryEngine implements BaseQueryEngine {
* SubQuestionQueryEngine decomposes a question into subquestions and then
*/
export class SubQuestionQueryEngine implements BaseQueryEngine {
responseSynthesizer: ResponseSynthesizer;
responseSynthesizer: BaseSynthesizer;
questionGen: BaseQuestionGenerator;
queryEngines: Record<string, BaseQueryEngine>;
metadatas: ToolMetadata[];

constructor(init: {
questionGen: BaseQuestionGenerator;
responseSynthesizer: ResponseSynthesizer;
responseSynthesizer: BaseSynthesizer;
queryEngineTools: QueryEngineTool[];
}) {
this.questionGen = init.questionGen;
Expand All @@ -106,7 +107,7 @@ export class SubQuestionQueryEngine implements BaseQueryEngine {
static fromDefaults(init: {
queryEngineTools: QueryEngineTool[];
questionGen?: BaseQuestionGenerator;
responseSynthesizer?: ResponseSynthesizer;
responseSynthesizer?: BaseSynthesizer;
serviceContext?: ServiceContext;
}) {
const serviceContext =
Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/embeddings/ClipEmbedding.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ export class ClipEmbedding extends MultiModalEmbedding {
const loadedImage = await readImage(image);
const imageInputs = await (await this.getProcessor())(loadedImage);
const { image_embeds } = await (await this.getVisionModel())(imageInputs);
return image_embeds.data;
return Array.from(image_embeds.data);
}

async getTextEmbedding(text: string): Promise<number[]> {
Expand Down
62 changes: 61 additions & 1 deletion packages/core/src/embeddings/utils.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import _ from "lodash";
import { ImageType } from "../Node";
import { DEFAULT_SIMILARITY_TOP_K } from "../constants";
import { VectorStoreQueryMode } from "../storage";
import { DEFAULT_FS, VectorStoreQueryMode } from "../storage";
import { SimilarityType } from "./types";

/**
Expand Down Expand Up @@ -185,6 +185,16 @@ export function getTopKMMREmbeddings(
return [resultSimilarities, resultIds];
}

/**
 * Encodes a Blob as a base64 data URL, sniffing the MIME type from the
 * binary content via the file-type package.
 *
 * @throws Error if the MIME type cannot be detected.
 */
async function blobToDataUrl(input: Blob) {
  const { fileTypeFromBuffer } = await import("file-type");
  const buffer = Buffer.from(await input.arrayBuffer());
  const detected = await fileTypeFromBuffer(buffer);
  if (!detected) {
    throw new Error("Unsupported image type");
  }
  return `data:${detected.mime};base64,${buffer.toString("base64")}`;
}

export async function readImage(input: ImageType) {
const { RawImage } = await import("@xenova/transformers");
if (input instanceof Blob) {
Expand All @@ -195,3 +205,53 @@ export async function readImage(input: ImageType) {
throw new Error(`Unsupported input type: ${typeof input}`);
}
}

/**
 * Converts an image reference to its string form: plain string paths are
 * returned unchanged, URLs become their href, and Blobs are encoded as
 * base64 data URLs.
 */
export async function imageToString(input: ImageType): Promise<string> {
  if (_.isString(input)) {
    // a string path/identifier is already its own string form
    return input;
  }
  if (input instanceof URL) {
    return input.toString();
  }
  if (input instanceof Blob) {
    // if the image is a Blob, convert it to a base64 data URL
    return await blobToDataUrl(input);
  }
  throw new Error(`Unsupported input type: ${typeof input}`);
}

/**
 * Converts the string form produced by `imageToString` back to an ImageType:
 * data URLs become Blobs, http(s) URLs become URL objects, and anything else
 * is treated as a plain path/string.
 *
 * @param input - String form of an image reference.
 * @returns A Blob, a URL, or the original string.
 */
export function stringToImage(input: string): ImageType {
  if (input.startsWith("data:")) {
    // base64 data URL — decode the payload after the comma back into a Blob
    const base64Data = input.split(",")[1];
    return new Blob([Buffer.from(base64Data, "base64")]);
  }
  if (input.startsWith("http://") || input.startsWith("https://")) {
    return new URL(input);
  }
  // `input` is statically a string, so every remaining value is a plain
  // path/identifier (the old `_.isString` guard and its throw were unreachable).
  return input;
}

/**
 * Converts any ImageType to a base64 data URL. Strings and file:// URLs are
 * read from disk first; Blobs are encoded directly. Non-file URLs are
 * rejected.
 */
export async function imageToDataUrl(input: ImageType): Promise<string> {
  // Normalize the input to a Blob before encoding.
  if (
    _.isString(input) ||
    (input instanceof URL && input.protocol === "file:")
  ) {
    // string or file URL — read the file contents from disk
    const dataBuffer = await DEFAULT_FS.readFile(
      input instanceof URL ? input.pathname : input,
    );
    input = new Blob([dataBuffer]);
  } else if (!(input instanceof Blob)) {
    if (input instanceof URL) {
      throw new Error(`Unsupported URL with protocol: ${input.protocol}`);
    }
    throw new Error(`Unsupported input type: ${typeof input}`);
  }
  return blobToDataUrl(input);
}
2 changes: 1 addition & 1 deletion packages/core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ export * from "./PromptHelper";
export * from "./QueryEngine";
export * from "./QuestionGenerator";
export * from "./Response";
export * from "./ResponseSynthesizer";
export * from "./synthesizers";
export * from "./Retriever";
export * from "./ServiceContext";
export * from "./TextSplitter";
Expand Down
Loading
Loading