Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-node@v3
with:
node-version: '18.x'
node-version: '22.x'
- name: "Run checks"
run: |
npm ci
Expand Down Expand Up @@ -81,7 +81,7 @@ jobs:
# Setup .npmrc file to publish to npm
- uses: actions/setup-node@v3
with:
node-version: '18.x'
node-version: '22.x'
registry-url: 'https://registry.npmjs.org'
- run: npm ci
- run: npm run build
Expand Down
26 changes: 16 additions & 10 deletions ci/run_dependencies.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,24 @@ function wait(){

echo "Waiting for $1"
while true; do
if curl -s $1 > /dev/null; then
# first check if weaviate already responds
if ! curl -s $1 > /dev/null; then
continue
fi

# endpoint available, check if it is ready
HTTP_STATUS=$(curl -s -o /dev/null -w "%{http_code}" "$1/v1/.well-known/ready")

if [ "$HTTP_STATUS" -eq 200 ]; then
break
else
if [ $? -eq 7 ]; then
echo "Weaviate is not up yet. (waited for ${ALREADY_WAITING}s)"
if [ $ALREADY_WAITING -gt $MAX_WAIT_SECONDS ]; then
echo "Weaviate did not start up in $MAX_WAIT_SECONDS."
exit 1
else
sleep 2
let ALREADY_WAITING=$ALREADY_WAITING+2
fi
echo "Weaviate is not up yet. (waited for ${ALREADY_WAITING}s)"
if [ $ALREADY_WAITING -gt $MAX_WAIT_SECONDS ]; then
echo "Weaviate did not start up in $MAX_WAIT_SECONDS."
exit 1
else
sleep 2
let ALREADY_WAITING=$ALREADY_WAITING+2
fi
fi
done
Expand Down
166 changes: 69 additions & 97 deletions src/collections/generate/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ import Connection from '../../connection/grpc.js';
import { ConsistencyLevel } from '../../data/index.js';
import { DbVersionSupport } from '../../utils/dbVersion.js';

import { WeaviateInvalidInputError, WeaviateUnsupportedFeatureError } from '../../errors.js';
import { WeaviateInvalidInputError } from '../../errors.js';
import { toBase64FromMedia } from '../../index.js';
import { SearchReply } from '../../proto/v1/search_get.js';
import { Deserialize } from '../deserialize/index.js';
import { Check } from '../query/check.js';
import {
BaseBm25Options,
BaseHybridOptions,
Expand Down Expand Up @@ -34,24 +35,10 @@ import {
import { Generate } from './types.js';

class GenerateManager<T> implements Generate<T> {
private connection: Connection;
private name: string;
private dbVersionSupport: DbVersionSupport;
private consistencyLevel?: ConsistencyLevel;
private tenant?: string;
private check: Check<T>;

private constructor(
connection: Connection,
name: string,
dbVersionSupport: DbVersionSupport,
consistencyLevel?: ConsistencyLevel,
tenant?: string
) {
this.connection = connection;
this.name = name;
this.dbVersionSupport = dbVersionSupport;
this.consistencyLevel = consistencyLevel;
this.tenant = tenant;
private constructor(check: Check<T>) {
this.check = check;
}

public static use<T>(
Expand All @@ -61,78 +48,29 @@ class GenerateManager<T> implements Generate<T> {
consistencyLevel?: ConsistencyLevel,
tenant?: string
): GenerateManager<T> {
return new GenerateManager<T>(connection, name, dbVersionSupport, consistencyLevel, tenant);
return new GenerateManager<T>(new Check<T>(connection, name, dbVersionSupport, consistencyLevel, tenant));
}

private checkSupportForNamedVectors = async (opts?: BaseNearOptions<T>) => {
if (!Serialize.isNamedVectors(opts)) return;
const check = await this.dbVersionSupport.supportsNamedVectors();
if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message);
};

private checkSupportForBm25AndHybridGroupByQueries = async (query: 'Bm25' | 'Hybrid', opts?: any) => {
if (!Serialize.isGroupBy(opts)) return;
const check = await this.dbVersionSupport.supportsBm25AndHybridGroupByQueries();
if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message(query));
};

private checkSupportForHybridNearTextAndNearVectorSubSearches = async (opts?: HybridOptions<T>) => {
if (opts?.vector === undefined || Array.isArray(opts.vector)) return;
const check = await this.dbVersionSupport.supportsHybridNearTextAndNearVectorSubsearchQueries();
if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message);
};

private checkSupportForMultiTargetVectorSearch = async (opts?: BaseNearOptions<T>) => {
if (!Serialize.isMultiTargetVector(opts)) return false;
const check = await this.dbVersionSupport.supportsMultiTargetVectorSearch();
if (!check.supports) throw new WeaviateUnsupportedFeatureError(check.message);
return check.supports;
};

private nearSearch = async (opts?: BaseNearOptions<T>) => {
const [_, supportsTargets] = await Promise.all([
this.checkSupportForNamedVectors(opts),
this.checkSupportForMultiTargetVectorSearch(opts),
]);
return {
search: await this.connection.search(this.name, this.consistencyLevel, this.tenant),
supportsTargets,
};
};

private hybridSearch = async (opts?: BaseHybridOptions<T>) => {
const [supportsTargets] = await Promise.all([
this.checkSupportForMultiTargetVectorSearch(opts),
this.checkSupportForNamedVectors(opts),
this.checkSupportForBm25AndHybridGroupByQueries('Hybrid', opts),
this.checkSupportForHybridNearTextAndNearVectorSubSearches(opts),
]);
return {
search: await this.connection.search(this.name, this.consistencyLevel, this.tenant),
supportsTargets,
};
};

private async parseReply(reply: SearchReply) {
const deserialize = await Deserialize.use(this.dbVersionSupport);
const deserialize = await Deserialize.use(this.check.dbVersionSupport);
return deserialize.generate<T>(reply);
}

private async parseGroupByReply(
opts: SearchOptions<T> | GroupByOptions<T> | undefined,
reply: SearchReply
) {
const deserialize = await Deserialize.use(this.dbVersionSupport);
const deserialize = await Deserialize.use(this.check.dbVersionSupport);
return Serialize.isGroupBy(opts) ? deserialize.generateGroupBy<T>(reply) : deserialize.generate<T>(reply);
}

public fetchObjects(
generate: GenerateOptions<T>,
opts?: FetchObjectsOptions<T>
): Promise<GenerativeReturn<T>> {
return this.checkSupportForNamedVectors(opts)
.then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant))
.then((search) =>
return this.check
.fetchObjects(opts)
.then(({ search }) =>
search.withFetch({
...Serialize.fetchObjects(opts),
generative: Serialize.generative(generate),
Expand All @@ -152,12 +90,9 @@ class GenerateManager<T> implements Generate<T> {
opts: GroupByBm25Options<T>
): Promise<GenerativeGroupByReturn<T>>;
public bm25(query: string, generate: GenerateOptions<T>, opts?: Bm25Options<T>): GenerateReturn<T> {
return Promise.all([
this.checkSupportForNamedVectors(opts),
this.checkSupportForBm25AndHybridGroupByQueries('Bm25', opts),
])
.then(() => this.connection.search(this.name, this.consistencyLevel, this.tenant))
.then((search) =>
return this.check
.bm25(opts)
.then(({ search }) =>
search.withBm25({
...Serialize.bm25({ query, ...opts }),
generative: Serialize.generative(generate),
Expand All @@ -180,10 +115,17 @@ class GenerateManager<T> implements Generate<T> {
opts: GroupByHybridOptions<T>
): Promise<GenerativeGroupByReturn<T>>;
public hybrid(query: string, generate: GenerateOptions<T>, opts?: HybridOptions<T>): GenerateReturn<T> {
return this.hybridSearch(opts)
.then(({ search, supportsTargets }) =>
return this.check
.hybridSearch(opts)
.then(({ search, supportsTargets, supportsVectorsForTargets, supportsWeightsForTargets }) =>
search.withHybrid({
...Serialize.hybrid({ query, supportsTargets, ...opts }),
...Serialize.hybrid({
query,
supportsTargets,
supportsVectorsForTargets,
supportsWeightsForTargets,
...opts,
}),
generative: Serialize.generative(generate),
groupBy: Serialize.isGroupBy<GroupByHybridOptions<T>>(opts)
? Serialize.groupBy(opts.groupBy)
Expand All @@ -208,11 +150,17 @@ class GenerateManager<T> implements Generate<T> {
generate: GenerateOptions<T>,
opts?: NearOptions<T>
): GenerateReturn<T> {
return this.nearSearch(opts)
.then(({ search, supportsTargets }) =>
return this.check
.nearSearch(opts)
.then(({ search, supportsTargets, supportsWeightsForTargets }) =>
toBase64FromMedia(image).then((image) =>
search.withNearImage({
...Serialize.nearImage({ image, supportsTargets, ...(opts ? opts : {}) }),
...Serialize.nearImage({
image,
supportsTargets,
supportsWeightsForTargets,
...(opts ? opts : {}),
}),
generative: Serialize.generative(generate),
groupBy: Serialize.isGroupBy<GroupByNearOptions<T>>(opts)
? Serialize.groupBy(opts.groupBy)
Expand All @@ -234,10 +182,16 @@ class GenerateManager<T> implements Generate<T> {
opts: GroupByNearOptions<T>
): Promise<GenerativeGroupByReturn<T>>;
public nearObject(id: string, generate: GenerateOptions<T>, opts?: NearOptions<T>): GenerateReturn<T> {
return this.nearSearch(opts)
.then(({ search, supportsTargets }) =>
return this.check
.nearSearch(opts)
.then(({ search, supportsTargets, supportsWeightsForTargets }) =>
search.withNearObject({
...Serialize.nearObject({ id, supportsTargets, ...(opts ? opts : {}) }),
...Serialize.nearObject({
id,
supportsTargets,
supportsWeightsForTargets,
...(opts ? opts : {}),
}),
generative: Serialize.generative(generate),
groupBy: Serialize.isGroupBy<GroupByNearOptions<T>>(opts)
? Serialize.groupBy(opts.groupBy)
Expand All @@ -262,10 +216,16 @@ class GenerateManager<T> implements Generate<T> {
generate: GenerateOptions<T>,
opts?: NearOptions<T>
): GenerateReturn<T> {
return this.nearSearch(opts)
.then(({ search, supportsTargets }) =>
return this.check
.nearSearch(opts)
.then(({ search, supportsTargets, supportsWeightsForTargets }) =>
search.withNearText({
...Serialize.nearText({ query, supportsTargets, ...(opts ? opts : {}) }),
...Serialize.nearText({
query,
supportsTargets,
supportsWeightsForTargets,
...(opts ? opts : {}),
}),
generative: Serialize.generative(generate),
groupBy: Serialize.isGroupBy<GroupByNearOptions<T>>(opts)
? Serialize.groupBy(opts.groupBy)
Expand All @@ -290,10 +250,17 @@ class GenerateManager<T> implements Generate<T> {
generate: GenerateOptions<T>,
opts?: NearOptions<T>
): GenerateReturn<T> {
return this.nearSearch(opts)
.then(({ search, supportsTargets }) =>
return this.check
.nearVector(vector, opts)
.then(({ search, supportsTargets, supportsVectorsForTargets, supportsWeightsForTargets }) =>
search.withNearVector({
...Serialize.nearVector({ vector, supportsTargets, ...(opts ? opts : {}) }),
...Serialize.nearVector({
vector,
supportsTargets,
supportsVectorsForTargets,
supportsWeightsForTargets,
...(opts ? opts : {}),
}),
generative: Serialize.generative(generate),
groupBy: Serialize.isGroupBy<GroupByNearOptions<T>>(opts)
? Serialize.groupBy(opts.groupBy)
Expand Down Expand Up @@ -321,10 +288,15 @@ class GenerateManager<T> implements Generate<T> {
generate: GenerateOptions<T>,
opts?: NearOptions<T>
): GenerateReturn<T> {
return this.nearSearch(opts)
.then(({ search, supportsTargets }) => {
return this.check
.nearSearch(opts)
.then(({ search, supportsTargets, supportsWeightsForTargets }) => {
let reply: Promise<SearchReply>;
const args = { supportsTargets, ...(opts ? opts : {}) };
const args = {
supportsTargets,
supportsWeightsForTargets,
...(opts ? opts : {}),
};
const generative = Serialize.generative(generate);
const groupBy = Serialize.isGroupBy<GroupByNearOptions<T>>(opts)
? Serialize.groupBy(opts.groupBy)
Expand Down
5 changes: 1 addition & 4 deletions src/collections/generate/integration.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -379,14 +379,11 @@ maybe('Testing of the collection.generate methods with a multi vector collection
it('should generate with a near vector search on multi vectors', async () => {
const query = () =>
collection.generate.nearVector(
[titleVector, title2Vector],
{ title: titleVector, title2: title2Vector },
{
groupedTask: 'What is the value of title here? {title}',
groupedProperties: ['title'],
singlePrompt: 'Write a haiku about ducks for {title}',
},
{
targetVector: ['title', 'title2'],
}
);
if (await client.getWeaviateVersion().then((ver) => ver.isLowerThan(1, 26, 0))) {
Expand Down
Loading
Loading