In [6]:
import json, os, random
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility

# Connection settings. The endpoint and token can be overridden through the
# MILVUS_URI / MILVUS_TOKEN environment variables so that real credentials do
# not have to be written into the notebook; the literals below are only the
# fallbacks used when those variables are unset.
CLUSTER_ENDPOINT = os.environ.get("MILVUS_URI", "http://localhost:19530") # Set your cluster endpoint
TOKEN = os.environ.get("MILVUS_TOKEN", "root:Milvus") # Set your token
COLLECTION_NAME="medium_articles_2020" # Set your collection name
DATASET_PATH="../medium_articles_2020_dpr.json" # Set your dataset path

In [7]:
# 0. Connect to cluster
# uri   -> the cluster endpoint (for Zilliz Cloud, the public endpoint)
# token -> the username and password specified when the cluster was created
connections.connect(uri=CLUSTER_ENDPOINT, token=TOKEN)

In [8]:
# 1. Define fields
fields = [
    # Primary key. auto_id=True means the ids are generated on insert, so rows
    # must not carry their own "id". max_length is a VARCHAR attribute and is
    # therefore not set on this INT64 field.
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=512),
    FieldSchema(name="title_vector", dtype=DataType.FLOAT_VECTOR, dim=768),
    # The following field is a JSON field
    FieldSchema(name="article_meta", dtype=DataType.JSON)
]

# 2. Build the schema
# Dynamic fields are disabled, so every inserted row has to match the four
# fields declared above.
schema = CollectionSchema(
    fields,
    description="Schema of Medium articles",
    enable_dynamic_field=False
)

# 3. Create collection
collection = Collection(
    name=COLLECTION_NAME, 
    description="Medium articles published between Jan and August in 2020 in prominent publications",
    schema=schema
)

In [9]:
# 4. Index collection
# 'index_type' picks the index algorithm; AUTOINDEX is requested here.
# 'metric_type' picks how the distance between vectors is measured; L2 is
#    requested here (other metrics such as IP and cosine exist -- check the
#    Milvus documentation for the values your deployment accepts).
index_params = dict(
    index_type="AUTOINDEX",
    metric_type="L2",
    params={},
)

# The index is created on the vector field without an explicit name
# (no index_name argument is passed).
collection.create_index(field_name="title_vector", index_params=index_params)

# 5. Load collection
collection.load()

# Get loading progress
progress = utility.loading_progress(COLLECTION_NAME)
print(progress)

{'loading_progress': '100%'}


In [10]:
# 6. Prepare data

# Read the dataset. The encoding is spelled out because the article titles
# contain non-ASCII characters and the platform default is not always UTF-8.
# The file is only needed for parsing, so it is closed before the rows are
# transformed.
with open(DATASET_PATH, encoding="utf-8") as f:
    data = json.load(f)

list_of_rows = data['rows']

# Prepare a list of rows
rows = []
for row in list_of_rows:
    # Remove the id field because auto-id is enabled for the primary key.
    # pop() with a default tolerates rows that carry no id.
    row.pop('id', None)
    # Create the article_meta field and move the following keys into it
    row['article_meta'] = {
        key: row.pop(key)
        for key in ('link', 'reading_time', 'publication', 'claps', 'responses')
    }
    # Synthetic tags used by the JSON_CONTAINS* filters later on:
    # tags_1 is a flat list of 40 ints, tags_2 is 10 lists of 4 ints each.
    # NOTE: the values are random and unseeded, so they differ between runs.
    row['article_meta']['tags_1'] = [random.randint(0, 40) for _ in range(40)]
    row['article_meta']['tags_2'] = [[random.randint(0, 40) for _ in range(4)] for _ in range(10)]
    # Append this row to the rows list
    rows.append(row)

print(json.dumps(rows[:1], indent=4))

[
    {
        "title": "The Reported Mortality Rate of Coronavirus Is Not Important",
        "title_vector": [
            0.041732933,
            0.013779674,
            -0.027564144,
            -0.013061441,
            0.009748648,
            0.00082446384,
            -0.00071647146,
            0.048612226,
            -0.04836573,
            -0.04567751,
            0.018008126,
            0.0063936645,
            -0.011913628,
            0.030776596,
            -0.018274948,
            0.019929802,
            0.020547243,
            0.032735646,
            -0.031652678,
            -0.033816382,
            -0.051087562,
            -0.033748355,
            0.0039493158,
            0.009246126,
            -0.060236514,
            -0.017136049,
            0.028754413,
            -0.008433934,
            0.011168004,
            -0.012391256,
            -0.011225835,
            0.031775184,
            0.002929508,
            -0.007448661,
            -0.

In [11]:
# 7. Insert data in rows
# Only the first 1000 prepared rows are sent to the collection.
results = collection.insert(rows[:1000])

# The call above is an insert and the figure printed is insert_count, so the
# message reports inserted (not upserted) rows.
print(f"Data inserted successfully! Inserted rows: {results.insert_count}")

Data inserted successfully! Upserted rows: 1000


In [12]:
# 8. Count entities

# An empty filter expression combined with the count(*) output field:
# no condition is applied to the entities being counted.
counts = collection.query(
    expr="",
    output_fields=["count(*)"],
)
print(counts)

[{'count(*)': 1000}]


In [13]:
# 9. Count entities with condition

# The filter addresses two keys stored inside the article_meta JSON field.
expr = 'article_meta["claps"] > 30 and article_meta["reading_time"] < 10'
counts = collection.query(
    expr=expr,
    output_fields=["count(*)"],
)
print(counts)

[{'count(*)': 729}]


In [15]:
# 10. Check if a specific element exists in a JSON field

# Search

# matches all articles with tags_1 having the member 16
expr_1 = 'JSON_CONTAINS(article_meta["tags_1"], 16)'

# matches all articles with tags_2 having the member [5, 3, 39, 8]
expr_2 = 'JSON_CONTAINS(article_meta["tags_2"], [5, 3, 39, 8])'

# matches all articles with tags_1 having a member from [5, 3, 39, 8]
expr_3 = 'JSON_CONTAINS_ANY(article_meta["tags_1"], [5, 3, 39, 8])'

# matches all articles with tags_1 having all members from [2, 4, 6]
expr_4 = 'JSON_CONTAINS_ALL(article_meta["tags_1"], [2, 4, 6])'

# Use the title vector of the first prepared row as the query vector
query_vector = rows[0]['title_vector']

# Define search parameters
search_params = {
    "metric_type": "L2",
    "params": {"nprobe": 10}
}

# Only expr_1 is applied to this search; the search parameters defined above
# are passed as-is instead of repeating the same dict inline.
res = collection.search(
    data=[query_vector],
    anns_field="title_vector",
    param=search_params,
    limit=5,
    expr=expr_1,
    output_fields=["title", "article_meta"]
)

def get_tags_1(value, target):
    """Report whether ``target`` is an element of ``value``.

    Despite its name, this helper returns a flag rather than the tags.

    Args:
        value: Sequence to look in (the callers pass a ``tags_1`` list).
        target: Element to look for.

    Returns:
        bool: True when ``target`` occurs in ``value``, otherwise False.
    """
    return target in value

def _summarize_hit(hit):
    """Reduce one search hit to its id, distance and a few entity fields."""
    meta = hit.entity.get("article_meta")
    return {
        "id": hit.id,
        "distance": hit.distance,
        "entity": {
            "title": hit.entity.get("title"),
            "link": meta['link'],
            # Flag telling whether 16 is among the entity's tags_1 values
            "tags_1": get_tags_1(meta["tags_1"], 16),
        },
    }

# Flatten the hits of every result set in res into one list of plain dicts
results = [_summarize_hit(hit) for hits in res for hit in hits]

# ids and distances, one inner list per result set in res
ids = [hits.ids for hits in res]
print(ids)

distances = [hits.distances for hits in res]
print(distances)

# get_tags_1 and results are already defined earlier in this cell, so they are
# not redefined here; only the formatted hits are printed.
print(json.dumps(results, indent=4))

[[445311585782931790, 445311585782931694, 445311585782930922, 445311585782931661, 445311585782931091]]
[[0.436093807220459, 0.49443864822387695, 0.4948429763317108, 0.5095388889312744, 0.5383652448654175]]
[
    {
        "id": 445311585782931790,
        "distance": 0.436093807220459,
        "entity": {
            "title": "Mortality Rate As an Indicator of an Epidemic Outbreak",
            "link": "https://towardsdatascience.com/mortality-rate-as-an-indicator-of-an-epidemic-outbreak-704592f3bb39",
            "tags_1": true
        }
    },
    {
        "id": 445311585782931694,
        "distance": 0.49443864822387695,
        "entity": {
            "title": "Choosing the right performance metrics can save lives against Coronavirus",
            "link": "https://towardsdatascience.com/choosing-the-right-performance-metrics-can-save-lives-against-coronavirus-2f27492f6638",
            "tags_1": true
        }
    },
    {
        "id": 445311585782930922,
        "distance": 0.49

In [16]:
# query

def _summarize_article(entity):
    """Copy the displayed fields of one query result into a plain dict."""
    meta = entity.get("article_meta")
    summary = {
        key: meta[key]
        for key in ('link', 'reading_time', 'publication', 'claps', 'responses')
    }
    # Flag telling whether 2 is among the entity's tags_1 values
    summary["tags_1"] = get_tags_1(meta['tags_1'], 2)
    return {
        "title": entity.get("title"),
        "article_meta": summary,
    }

# Scalar query filtered with expr_4, limited to five entities
res = collection.query(
    expr=expr_4,
    output_fields=["title", "article_meta"],
    limit=5,
)

# res is rebound to the formatted version of the raw query results
res = [_summarize_article(entity) for entity in res]

print(json.dumps(res, indent=4))

[
    {
        "title": "Building Comprehensible Customer Churn Prediction Models",
        "article_meta": {
            "link": "https://medium.com/swlh/building-comprehensible-customer-churn-prediction-models-ca61ecce529d",
            "reading_time": 13,
            "publication": "The Startup",
            "claps": 261,
            "responses": 4,
            "tags_1": true
        }
    },
    {
        "title": "Would you rather have 8% of $25 or 25% of $8?",
        "article_meta": {
            "link": "https://medium.com/swlh/would-you-rather-have-8-of-25-or-25-of-8-486b3bc48f28",
            "reading_time": 3,
            "publication": "The Startup",
            "claps": 208,
            "responses": 5,
            "tags_1": true
        }
    },
    {
        "title": "Blockchain, IoT and AI \u2014 A Perfect Fit",
        "article_meta": {
            "link": "https://medium.com/swlh/blockchain-iot-and-ai-a-perfect-fit-1-e04c6ad73fbc",
            "reading_time": 11,
    

In [17]:
# 11. Drop collection

# Clean up: drop the collection named in the configuration cell.
utility.drop_collection(collection_name=COLLECTION_NAME)