Feat: store images in docstore #281

Draft. Wants to merge 6 commits into main.
Changes from all commits
Binary file removed examples/data/multi_modal/1.jpg
Binary file removed examples/data/multi_modal/2.jpg
Binary file removed examples/data/multi_modal/3.jpg
323 changes: 0 additions & 323 deletions examples/data/multi_modal/San Francisco.txt

This file was deleted.

File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
46 changes: 46 additions & 0 deletions examples/multimodal/load.ts
@@ -0,0 +1,46 @@
import {
  ServiceContext,
  serviceContextFromDefaults,
  SimpleDirectoryReader,
  storageContextFromDefaults,
  VectorStoreIndex,
} from "llamaindex";
import * as path from "path";

async function getRuntime(func: any) {
  const start = Date.now();
  await func();
  const end = Date.now();
  return end - start;
}

async function generateDatasource(serviceContext: ServiceContext) {
  console.log(`Generating storage...`);
  // Split documents, create embeddings and store them in the storage context
  const ms = await getRuntime(async () => {
    const documents = await new SimpleDirectoryReader().loadData({
      directoryPath: path.join("multimodal", "data"),
    });
    const storageContext = await storageContextFromDefaults({
      persistDir: "storage",
      storeImages: true,
    });
    await VectorStoreIndex.fromDocuments(documents, {
      serviceContext,
      storageContext,
    });
  });
  console.log(`Storage successfully generated in ${ms / 1000}s.`);
}

async function main() {
  const serviceContext = serviceContextFromDefaults({
    chunkSize: 512,
    chunkOverlap: 20,
  });

  await generateDatasource(serviceContext);
  console.log("Finished generating storage.");
}

main().catch(console.error);
43 changes: 43 additions & 0 deletions examples/multimodal/rag.ts
@@ -0,0 +1,43 @@
import {
  MultiModalResponseSynthesizer,
  OpenAI,
  ServiceContext,
  VectorStoreIndex,
  serviceContextFromDefaults,
  storageContextFromDefaults,
} from "llamaindex";

export async function createIndex(serviceContext: ServiceContext) {
  // set up vector store index with two vector stores, one for text, the other for images
  const storageContext = await storageContextFromDefaults({
    persistDir: "storage",
    storeImages: true,
  });
  return await VectorStoreIndex.init({
    nodes: [],
    storageContext,
    serviceContext,
  });
}

async function main() {
  const llm = new OpenAI({ model: "gpt-4-vision-preview", maxTokens: 512 });
  const serviceContext = serviceContextFromDefaults({
    llm,
    chunkSize: 512,
    chunkOverlap: 20,
  });
  const index = await createIndex(serviceContext);

  const queryEngine = index.asQueryEngine({
    responseSynthesizer: new MultiModalResponseSynthesizer({ serviceContext }),
    // TODO: set text similarity to a higher value than image similarity
    retriever: index.asRetriever({ similarityTopK: 1 }),
  });
  const result = await queryEngine.query(
    "what are Vincent van Gogh's famous paintings",
  );
  console.log(result.response);
}

main().catch(console.error);
38 changes: 19 additions & 19 deletions examples/multiModal.ts → examples/multimodal/retrieve.ts
@@ -1,32 +1,32 @@
import {
  ImageNode,
  serviceContextFromDefaults,
  SimpleDirectoryReader,
  SimpleVectorStore,
  storageContextFromDefaults,
  TextNode,
  VectorStoreIndex,
} from "llamaindex";
import * as path from "path";

async function main() {
  // read data into documents
  const reader = new SimpleDirectoryReader();
  const documents = await reader.loadData({
    directoryPath: "data/multi_modal",
  });
export async function createIndex() {
  // set up vector store index with two vector stores, one for text, the other for images
  const serviceContext = serviceContextFromDefaults({ chunkSize: 512 });
  const vectorStore = await SimpleVectorStore.fromPersistDir("./storage/text");
  const imageVectorStore =
    await SimpleVectorStore.fromPersistDir("./storage/images");
  const index = await VectorStoreIndex.fromDocuments(documents, {
  const serviceContext = serviceContextFromDefaults({
    chunkSize: 512,
    chunkOverlap: 20,
  });
  const storageContext = await storageContextFromDefaults({
    persistDir: "storage",
    storeImages: true,
  });
  return await VectorStoreIndex.init({
    nodes: [],
    storageContext,
    serviceContext,
    imageVectorStore,
    vectorStore,
  });
}

async function main() {
  // retrieve documents using the index
  const retriever = index.asRetriever();
  retriever.similarityTopK = 3;
  const index = await createIndex();
  const retriever = index.asRetriever({ similarityTopK: 3 });
  const results = await retriever.retrieve(
    "what are Vincent van Gogh's famous paintings",
  );
@@ -36,7 +36,7 @@ async function main() {
      continue;
    }
    if (node instanceof ImageNode) {
      console.log(`Image: ${path.join(__dirname, node.id_)}`);
      console.log(`Image: ${node.getUrl()}`);
    } else if (node instanceof TextNode) {
      console.log("Text:", (node as TextNode).text.substring(0, 128));
    }
1 change: 1 addition & 0 deletions packages/core/package.json
@@ -10,6 +10,7 @@
    "@xenova/transformers": "^2.10.0",
    "assemblyai": "^4.0.0",
    "crypto-js": "^4.2.0",
    "file-type": "^18.7.0",
    "js-tiktoken": "^1.0.8",
    "lodash": "^4.17.21",
    "mammoth": "^1.6.0",
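The new file-type dependency is presumably what the docstore uses to sniff the MIME type of stored image data. A minimal sketch of that kind of usage (an illustration only, not code from this PR; the helper name is made up):

import { fileTypeFromBuffer } from "file-type";
import { readFile } from "fs/promises";

// Hypothetical helper: detect the MIME type of an image file before persisting it
async function detectImageMime(filePath: string): Promise<string | undefined> {
  const buffer = await readFile(filePath);
  const result = await fileTypeFromBuffer(buffer);
  return result?.mime; // e.g. "image/jpeg"
}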
76 changes: 71 additions & 5 deletions packages/core/src/Node.ts
@@ -1,5 +1,7 @@
import CryptoJS from "crypto-js";
import path from "path";
import { v4 as uuidv4 } from "uuid";
import { imageToString, stringToImage } from "./embeddings";

export enum NodeRelationship {
SOURCE = "SOURCE",
Expand Down Expand Up @@ -146,6 +148,19 @@ export abstract class BaseNode<T extends Metadata = Metadata> {
toJSON(): Record<string, any> {
return { ...this, type: this.getType() };
}

/**
* Async version of toJSON (used to return larger objects like images)
* @returns
*/
async aToJSON(): Promise<Record<string, any>> {
return this.toJSON();
}

async clone() {
const json = await this.aToJSON();
return jsonToNode(json);
}
}

/**
@@ -265,24 +280,34 @@ export class Document<T extends Metadata = Metadata> extends TextNode<T> {
  }
}

export function jsonToNode(json: any, type?: ObjectType) {
export function jsonToNode(json: any, type?: ObjectType): BaseNode {
  if (!json.type && !type) {
    throw new Error("Node type not found");
  }
  const nodeType = type || json.type;

  let node;
  switch (nodeType) {
    case ObjectType.TEXT:
      return new TextNode(json);
      node = new TextNode(json);
      break;
    case ObjectType.INDEX:
      return new IndexNode(json);
      node = new IndexNode(json);
      break;
    case ObjectType.DOCUMENT:
      return new Document(json);
      node = new Document(json);
      break;
    case ObjectType.IMAGE_DOCUMENT:
      return new ImageDocument(json);
      node = ImageDocument.fromJSON(json);
      break;
    default:
      throw new Error(`Invalid node type: ${nodeType}`);
  }
  // XXX: Calling the constructor generates a new hash that we don't want when we're
  // deserializing the node. So we set the original hash here again. This is a hacky solution
  // and we have to clean that up when we refactor the hashing mechanism.
  node.hash = json.hash;
  return node;
}

export type ImageType = string | Blob | URL;
@@ -304,6 +329,19 @@ export class ImageNode<T extends Metadata = Metadata> extends TextNode<T> {
  getType(): ObjectType {
    return ObjectType.IMAGE;
  }

  async aToJSON(): Promise<Record<string, any>> {
    return {
      ...(await super.aToJSON()),
      image: await imageToString(this.image),
    };
  }

  getUrl(): URL {
    // id_ stores the relative path, convert it to the URL of the file
    const absPath = path.resolve(this.id_);
    return new URL(`file://${absPath}`);
  }
}

export class ImageDocument<T extends Metadata = Metadata> extends ImageNode<T> {
@@ -315,6 +353,14 @@ export class ImageDocument<T extends Metadata = Metadata> extends ImageNode<T> {
    }
  }

  static fromJSON(json: any) {
    if (typeof json.image !== "string") {
      return new ImageDocument(json);
    }
    const image = stringToImage(json.image);
    return new ImageDocument({ ...json, image });
  }

  getType() {
    return ObjectType.IMAGE_DOCUMENT;
  }
@@ -327,3 +373,23 @@ export interface NodeWithScore<T extends Metadata = Metadata> {
  node: BaseNode<T>;
  score?: number;
}

export function splitNodesByType(nodes: BaseNode[]): {
  imageNodes: ImageNode[];
  textNodes: TextNode[];
} {
  let imageNodes: ImageNode[] = [];
  let textNodes: TextNode[] = [];

  for (let node of nodes) {
    if (node instanceof ImageNode) {
      imageNodes.push(node);
    } else if (node instanceof TextNode) {
      textNodes.push(node);
    }
  }
  return {
    imageNodes,
    textNodes,
  };
}
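Taken together, the Node.ts changes let an image node be serialized asynchronously and restored with its original hash intact. A rough round-trip sketch (assuming ImageDocument, jsonToNode, and splitNodesByType are exported from the package entry point; the file path is made up):

import { ImageDocument, jsonToNode, splitNodesByType } from "llamaindex";

async function roundTrip() {
  // aToJSON encodes the image payload as a string so it can be persisted in the docstore
  const doc = new ImageDocument({ id_: "data/starry-night.jpg", image: "data/starry-night.jpg" });
  const json = await doc.aToJSON();

  // jsonToNode dispatches to ImageDocument.fromJSON and re-applies the original hash
  const restored = jsonToNode(json);

  // getUrl resolves the node id (a relative path) to a file:// URL
  console.log(doc.getUrl().href);

  // Mixed retrieval results can be split into image and text nodes
  const { imageNodes, textNodes } = splitNodesByType([restored]);
  console.log(imageNodes.length, textNodes.length); // 1 0
}

roundTrip().catch(console.error);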
17 changes: 9 additions & 8 deletions packages/core/src/QueryEngine.ts
@@ -1,17 +1,18 @@
import { v4 as uuidv4 } from "uuid";
import { Event } from "./callbacks/CallbackManager";
import { BaseNodePostprocessor } from "./indices/BaseNodePostprocessor";
import { NodeWithScore, TextNode } from "./Node";
import {
  BaseQuestionGenerator,
  LLMQuestionGenerator,
  SubQuestion,
} from "./QuestionGenerator";
import { Response } from "./Response";
import { CompactAndRefine, ResponseSynthesizer } from "./ResponseSynthesizer";
import { BaseRetriever } from "./Retriever";
import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
import { QueryEngineTool, ToolMetadata } from "./Tool";
import { Event } from "./callbacks/CallbackManager";
import { BaseNodePostprocessor } from "./indices/BaseNodePostprocessor";
import { CompactAndRefine, ResponseSynthesizer } from "./synthesizers";
import { BaseSynthesizer } from "./synthesizers/types";

/**
 * A query engine is a question answerer that can use one or more steps.
@@ -30,13 +31,13 @@ export interface BaseQueryEngine {
 */
export class RetrieverQueryEngine implements BaseQueryEngine {
  retriever: BaseRetriever;
  responseSynthesizer: ResponseSynthesizer;
  responseSynthesizer: BaseSynthesizer;
  nodePostprocessors: BaseNodePostprocessor[];
  preFilters?: unknown;

  constructor(
    retriever: BaseRetriever,
    responseSynthesizer?: ResponseSynthesizer,
    responseSynthesizer?: BaseSynthesizer,
    preFilters?: unknown,
    nodePostprocessors?: BaseNodePostprocessor[],
  ) {
@@ -81,14 +82,14 @@ export class RetrieverQueryEngine implements BaseQueryEngine {
 * SubQuestionQueryEngine decomposes a question into subquestions and then
 */
export class SubQuestionQueryEngine implements BaseQueryEngine {
  responseSynthesizer: ResponseSynthesizer;
  responseSynthesizer: BaseSynthesizer;
  questionGen: BaseQuestionGenerator;
  queryEngines: Record<string, BaseQueryEngine>;
  metadatas: ToolMetadata[];

  constructor(init: {
    questionGen: BaseQuestionGenerator;
    responseSynthesizer: ResponseSynthesizer;
    responseSynthesizer: BaseSynthesizer;
    queryEngineTools: QueryEngineTool[];
  }) {
    this.questionGen = init.questionGen;
@@ -106,7 +107,7 @@ export class SubQuestionQueryEngine implements BaseQueryEngine {
  static fromDefaults(init: {
    queryEngineTools: QueryEngineTool[];
    questionGen?: BaseQuestionGenerator;
    responseSynthesizer?: ResponseSynthesizer;
    responseSynthesizer?: BaseSynthesizer;
    serviceContext?: ServiceContext;
  }) {
    const serviceContext =
2 changes: 1 addition & 1 deletion packages/core/src/embeddings/ClipEmbedding.ts
@@ -62,7 +62,7 @@ export class ClipEmbedding extends MultiModalEmbedding {
    const loadedImage = await readImage(image);
    const imageInputs = await (await this.getProcessor())(loadedImage);
    const { image_embeds } = await (await this.getVisionModel())(imageInputs);
    return image_embeds.data;
    return Array.from(image_embeds.data);
  }

  async getTextEmbedding(text: string): Promise<number[]> {
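For context on the ClipEmbedding change: image_embeds.data coming back from @xenova/transformers is a typed array (typically a Float32Array), while the surrounding embedding method is declared to return a plain number[] (as getTextEmbedding's signature above shows); Array.from performs that conversion. A small standalone illustration:

// Typed arrays satisfy ArrayLike<number> but are not plain arrays
const data = new Float32Array([0.1, 0.2, 0.3]);
console.log(Array.isArray(data)); // false

// Array.from copies the values into a real number[]
const embedding: number[] = Array.from(data);
console.log(Array.isArray(embedding)); // true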