Skip to content

Commit

Permalink
enhance: add sparse float vector support to restful v2 (milvus-io#33231)
Browse files Browse the repository at this point in the history
issue: milvus-io#29419
also re-enabled an e2e test using restful api, which is previously
disabled due to milvus-io#32214.

In restful api, the accepted json formats of sparse float vector are:

* `{"indices": [1, 100, 1000], "values": [0.1, 0.2, 0.3]}`
* {"1": 0.1, "100": 0.2, "1000": 0.3}

for accepted indice and value range, see
https://milvus.io/docs/sparse_vector.md#FAQ

Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>
  • Loading branch information
zhengbuqian authored and yellow-shine committed Jul 2, 2024
1 parent c33a611 commit 3bc6be0
Show file tree
Hide file tree
Showing 9 changed files with 422 additions and 108 deletions.
5 changes: 4 additions & 1 deletion internal/distributed/proxy/httpserver/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -848,7 +848,10 @@ func generatePlaceholderGroup(ctx context.Context, body string, collSchema *sche
if vectorField == nil {
return nil, errors.New("cannot find a vector field named: " + fieldName)
}
dim, _ := getDim(vectorField)
dim := int64(0)
if !typeutil.IsSparseFloatVectorType(vectorField.DataType) {
dim, _ = getDim(vectorField)
}
phv, err := convertVectors2Placeholder(body, vectorField.DataType, dim)
if err != nil {
return nil, err
Expand Down
39 changes: 35 additions & 4 deletions internal/distributed/proxy/httpserver/handler_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -471,11 +471,11 @@ func TestDatabaseWrapper(t *testing.T) {
func TestCreateCollection(t *testing.T) {
postTestCases := []requestBodyTestCase{}
mp := mocks.NewMockProxy(t)
mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(11)
mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(12)
mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(6)
mp.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(6)
mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonErrorStatus, nil).Twice()
mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonErrorStatus, nil).Once()
mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonErrorStatus, nil).Twice()
testEngine := initHTTPServerV2(mp, false)
path := versionalV2(CollectionCategory, CreateAction)
// quickly create collection
Expand Down Expand Up @@ -564,6 +564,18 @@ func TestCreateCollection(t *testing.T) {
]
}}`),
})
// dim should not be specified for SparseFloatVector field
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "schema": {
"fields": [
{"fieldName": "book_id", "dataType": "Int64", "isPrimary": true, "elementTypeParams": {}},
{"fieldName": "word_count", "dataType": "Int64", "isPartitionKey": false, "elementTypeParams": {}},
{"fieldName": "partition_field", "dataType": "VarChar", "isPartitionKey": true, "elementTypeParams": {"max_length": 256}},
{"fieldName": "book_intro", "dataType": "SparseFloatVector", "elementTypeParams": {}}
]
}, "params": {"partitionsNum": "32"}}`),
})
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "schema": {
Expand Down Expand Up @@ -612,6 +624,18 @@ func TestCreateCollection(t *testing.T) {
errMsg: "",
errCode: 65535,
})
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "schema": {
"fields": [
{"fieldName": "book_id", "dataType": "Int64", "isPrimary": true, "elementTypeParams": {}},
{"fieldName": "word_count", "dataType": "Int64", "elementTypeParams": {}},
{"fieldName": "book_intro", "dataType": "SparseFloatVector", "elementTypeParams": {"dim": 2}}
]
}, "indexParams": [{"fieldName": "book_intro", "indexName": "book_intro_vector", "metricType": "L2"}]}`),
errMsg: "",
errCode: 65535,
})

for _, testcase := range postTestCases {
t.Run("post"+testcase.path, func(t *testing.T) {
Expand Down Expand Up @@ -1240,16 +1264,19 @@ func TestSearchV2(t *testing.T) {
float16VectorField.Name = "float16Vector"
bfloat16VectorField := generateVectorFieldSchema(schemapb.DataType_BFloat16Vector)
bfloat16VectorField.Name = "bfloat16Vector"
sparseFloatVectorField := generateVectorFieldSchema(schemapb.DataType_SparseFloatVector)
sparseFloatVectorField.Name = "sparseFloatVector"
collSchema.Fields = append(collSchema.Fields, &binaryVectorField)
collSchema.Fields = append(collSchema.Fields, &float16VectorField)
collSchema.Fields = append(collSchema.Fields, &bfloat16VectorField)
collSchema.Fields = append(collSchema.Fields, &sparseFloatVectorField)
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
CollectionName: DefaultCollectionName,
Schema: collSchema,
ShardsNum: ShardNumDefault,
Status: &StatusSuccess,
}, nil).Times(9)
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Twice()
}, nil).Times(10)
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Times(3)
testEngine := initHTTPServerV2(mp, false)
queryTestCases := []requestBodyTestCase{}
queryTestCases = append(queryTestCases, requestBodyTestCase{
Expand Down Expand Up @@ -1377,6 +1404,10 @@ func TestSearchV2(t *testing.T) {
errMsg: "can only accept json format request, error: dimension: 2, bytesLen: 4, but length of []byte: 3: invalid parameter[expected=BFloat16Vector][actual=\x01\x02\x03]",
errCode: 1801,
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [{"1": 0.1}], "annsField": "sparseFloatVector", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`),
})

for _, testcase := range queryTestCases {
t.Run("search", func(t *testing.T) {
Expand Down
51 changes: 51 additions & 0 deletions internal/distributed/proxy/httpserver/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,15 @@ func checkAndSetData(body string, collSchema *schemapb.CollectionSchema) (error,
return merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(fieldType)], dataString, err.Error()), reallyDataArray
}
reallyData[fieldName] = vectorArray
case schemapb.DataType_SparseFloatVector:
if dataString == "" {
return merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(fieldType)], "", "missing vector field: "+fieldName), reallyDataArray
}
sparseVec, err := typeutil.CreateSparseFloatRowFromJSON([]byte(dataString))
if err != nil {
return merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(fieldType)], dataString, err.Error()), reallyDataArray
}
reallyData[fieldName] = sparseVec
case schemapb.DataType_Float16Vector:
if dataString == "" {
return merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(fieldType)], "", "missing vector field: "+fieldName), reallyDataArray
Expand Down Expand Up @@ -638,6 +647,9 @@ func anyToColumns(rows []map[string]interface{}, sch *schemapb.CollectionSchema)
data = make([][]byte, 0, rowsLen)
dim, _ := getDim(field)
nameDims[field.Name] = dim
case schemapb.DataType_SparseFloatVector:
data = make([][]byte, 0, rowsLen)
nameDims[field.Name] = int64(0)
default:
return nil, fmt.Errorf("the type(%v) of field(%v) is not supported, use other sdk please", field.DataType, field.Name)
}
Expand Down Expand Up @@ -704,6 +716,13 @@ func anyToColumns(rows []map[string]interface{}, sch *schemapb.CollectionSchema)
nameColumns[field.Name] = append(nameColumns[field.Name].([][]byte), candi.v.Interface().([]byte))
case schemapb.DataType_BFloat16Vector:
nameColumns[field.Name] = append(nameColumns[field.Name].([][]byte), candi.v.Interface().([]byte))
case schemapb.DataType_SparseFloatVector:
content := candi.v.Interface().([]byte)
rowSparseDim := typeutil.SparseFloatRowDim(content)
if rowSparseDim > nameDims[field.Name] {
nameDims[field.Name] = rowSparseDim
}
nameColumns[field.Name] = append(nameColumns[field.Name].([][]byte), content)
default:
return nil, fmt.Errorf("the type(%v) of field(%v) is not supported, use other sdk please", field.DataType, field.Name)
}
Expand Down Expand Up @@ -895,6 +914,18 @@ func anyToColumns(rows []map[string]interface{}, sch *schemapb.CollectionSchema)
},
},
}
case schemapb.DataType_SparseFloatVector:
colData.Field = &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: nameDims[name],
Data: &schemapb.VectorField_SparseFloatVector{
SparseFloatVector: &schemapb.SparseFloatArray{
Dim: nameDims[name],
Contents: column.([][]byte),
},
},
},
}
default:
return nil, fmt.Errorf("the type(%v) of field(%v) is not supported, use other sdk please", colData.Type, name)
}
Expand Down Expand Up @@ -963,6 +994,19 @@ func serializeByteVectors(vectorStr string, dataType schemapb.DataType, dimensio
return values, nil
}

func serializeSparseFloatVectors(vectors []gjson.Result, dataType schemapb.DataType) ([][]byte, error) {
values := make([][]byte, 0)
for _, vector := range vectors {
vectorBytes := []byte(vector.String())
sparseVector, err := typeutil.CreateSparseFloatRowFromJSON(vectorBytes)
if err != nil {
return nil, merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(dataType)], vector.String(), err.Error())
}
values = append(values, sparseVector)
}
return values, nil
}

func convertVectors2Placeholder(body string, dataType schemapb.DataType, dimension int64) (*commonpb.PlaceholderValue, error) {
var valueType commonpb.PlaceholderType
var values [][]byte
Expand All @@ -980,6 +1024,9 @@ func convertVectors2Placeholder(body string, dataType schemapb.DataType, dimensi
case schemapb.DataType_BFloat16Vector:
valueType = commonpb.PlaceholderType_BFloat16Vector
values, err = serializeByteVectors(gjson.Get(body, HTTPRequestData).Raw, dataType, dimension, dimension*2)
case schemapb.DataType_SparseFloatVector:
valueType = commonpb.PlaceholderType_SparseFloatVector
values, err = serializeSparseFloatVectors(gjson.Get(body, HTTPRequestData).Array(), dataType)
}
if err != nil {
return nil, err
Expand Down Expand Up @@ -1070,6 +1117,8 @@ func buildQueryResp(rowsNum int64, needFields []string, fieldDataList []*schemap
rowsNum = int64(len(fieldDataList[0].GetVectors().GetFloat16Vector())/2) / fieldDataList[0].GetVectors().GetDim()
case schemapb.DataType_BFloat16Vector:
rowsNum = int64(len(fieldDataList[0].GetVectors().GetBfloat16Vector())/2) / fieldDataList[0].GetVectors().GetDim()
case schemapb.DataType_SparseFloatVector:
rowsNum = int64(len(fieldDataList[0].GetVectors().GetSparseFloatVector().Contents))
default:
return nil, fmt.Errorf("the type(%v) of field(%v) is not supported, use other sdk please", fieldDataList[0].Type, fieldDataList[0].FieldName)
}
Expand Down Expand Up @@ -1125,6 +1174,8 @@ func buildQueryResp(rowsNum int64, needFields []string, fieldDataList []*schemap
row[fieldDataList[j].FieldName] = fieldDataList[j].GetVectors().GetFloat16Vector()[i*(fieldDataList[j].GetVectors().GetDim()*2) : (i+1)*(fieldDataList[j].GetVectors().GetDim()*2)]
case schemapb.DataType_BFloat16Vector:
row[fieldDataList[j].FieldName] = fieldDataList[j].GetVectors().GetBfloat16Vector()[i*(fieldDataList[j].GetVectors().GetDim()*2) : (i+1)*(fieldDataList[j].GetVectors().GetDim()*2)]
case schemapb.DataType_SparseFloatVector:
row[fieldDataList[j].FieldName] = typeutil.SparseFloatBytesToMap(fieldDataList[j].GetVectors().GetSparseFloatVector().Contents[i])
case schemapb.DataType_Array:
row[fieldDataList[j].FieldName] = fieldDataList[j].GetScalars().GetArrayData().Data[i]
case schemapb.DataType_JSON:
Expand Down
Loading

0 comments on commit 3bc6be0

Please sign in to comment.