In [1]:
import json, os
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility

# Notebook configuration.
# The endpoint and token are taken from the environment when the variables
# are set, so real credentials do not have to be saved inside the notebook.
# The placeholder strings remain as fallbacks and can still be edited here.
CLUSTER_ENDPOINT = os.environ.get("ZILLIZ_CLUSTER_ENDPOINT", "YOUR_CLUSTER_ENDPOINT")  # Set your cluster endpoint
TOKEN = os.environ.get("ZILLIZ_CLUSTER_TOKEN", "YOUR_CLUSTER_TOKEN")  # Set your token
COLLECTION_NAME = "medium_articles_2020"  # Set your collection name
DATASET_PATH = "../medium_articles_2020_dpr.json"  # Set your dataset path

In [2]:
# 0. Open the connection to the cluster
#    uri:   public endpoint obtained from Zilliz Cloud
#    token: API key, or a colon-separated cluster username and password
connections.connect(uri=CLUSTER_ENDPOINT, token=TOKEN)

In [3]:
# 1. Define fields
# Each entry is (field name, data type, extra FieldSchema keyword arguments).
# The list order is the order in which the fields appear in the schema.
field_specs = [
    ("id", DataType.INT64, {"is_primary": True}),
    ("title", DataType.VARCHAR, {"max_length": 512}),
    ("title_vector", DataType.FLOAT_VECTOR, {"dim": 768}),
    ("link", DataType.VARCHAR, {"max_length": 512}),
    ("reading_time", DataType.INT64, {}),
    ("publication", DataType.VARCHAR, {"max_length": 512}),
    ("claps", DataType.INT64, {}),
    ("responses", DataType.INT64, {}),
]
fields = [
    FieldSchema(name=field_name, dtype=field_dtype, **field_options)
    for field_name, field_dtype, field_options in field_specs
]

# 2. Build the schema (dynamic fields are switched off)
schema = CollectionSchema(fields, description="Schema of Medium articles", enable_dynamic_field=False)

# 3. Create collection from the schema under the configured name
collection = Collection(
    name=COLLECTION_NAME,
    schema=schema,
    description="Medium articles published between Jan and August in 2020 in prominent publications",
)

In [4]:
# 4. Index collection
# 'index_type' selects the index algorithm; AUTOINDEX is the only option.
# 'metric_type' selects how the distance between vectors is measured.
#    Possible values are L2, IP, and Cosine (the default); L2 is used here.
index_params = dict(index_type="AUTOINDEX", metric_type="L2", params={})

# Build the index on the vector field
collection.create_index(field_name="title_vector", index_params=index_params)

# 5. Load collection
collection.load()

# Get loading progress and show it
progress = utility.loading_progress(COLLECTION_NAME)
print(progress)

{'loading_progress': '100%'}


In [6]:
# 6. Prepare data

# Prepare a list of rows.
# The encoding is passed explicitly: without it, open() decodes the file
# with the platform's locale encoding, which is not UTF-8 everywhere and
# can fail or corrupt non-ASCII text in the dataset.
with open(DATASET_PATH, encoding="utf-8") as f:
    data = json.load(f)

# The row dictionaries are stored under the top-level 'rows' key.
rows = data['rows']

# Preview the first three rows
print(json.dumps(rows[:3], indent=4))

[
    {
        "id": 0,
        "title": "The Reported Mortality Rate of Coronavirus Is Not Important",
        "title_vector": [
            0.041732933,
            0.013779674,
            -0.027564144,
            -0.013061441,
            0.009748648,
            0.00082446384,
            -0.00071647146,
            0.048612226,
            -0.04836573,
            -0.04567751,
            0.018008126,
            0.0063936645,
            -0.011913628,
            0.030776596,
            -0.018274948,
            0.019929802,
            0.020547243,
            0.032735646,
            -0.031652678,
            -0.033816382,
            -0.051087562,
            -0.033748355,
            0.0039493158,
            0.009246126,
            -0.060236514,
            -0.017136049,
            0.028754413,
            -0.008433934,
            0.011168004,
            -0.012391256,
            -0.011225835,
            0.031775184,
            0.002929508,
            -0.007448661

In [7]:
# Prepare a list of columns
# `rows` is the in-memory list of row dictionaries, so the dataset file
# does not need to be opened again to pivot it.
# The key order is taken from the first row.
# NOTE(review): presumably this order must match the collection's field
# order for a column-based upsert -- confirm against the schema.
keys = list(rows[0].keys())

# One list per key, in key order. Every row is expected to contain every
# key; a missing key raises KeyError.
columns = [[row[key] for row in rows] for key in keys]

# The same column layout limited to the first three rows, for display only.
columns_demo = [column[:3] for column in columns]

print(json.dumps(columns_demo, indent=4))

[
    [
        0,
        1,
        2
    ],
    [
        "The Reported Mortality Rate of Coronavirus Is Not Important",
        "Dashboards in Python: 3 Advanced Examples for Dash Beginners and Everyone Else",
        "How Can We Best Switch in Python?"
    ],
    [
        [
            0.041732933,
            0.013779674,
            -0.027564144,
            -0.013061441,
            0.009748648,
            0.00082446384,
            -0.00071647146,
            0.048612226,
            -0.04836573,
            -0.04567751,
            0.018008126,
            0.0063936645,
            -0.011913628,
            0.030776596,
            -0.018274948,
            0.019929802,
            0.020547243,
            0.032735646,
            -0.031652678,
            -0.033816382,
            -0.051087562,
            -0.033748355,
            0.0039493158,
            0.009246126,
            -0.060236514,
            -0.017136049,
            0.028754413,
            -0.008433934,
 

In [8]:
# 7. Upsert data in rows
# Only the first 1000 row dictionaries are sent in this row-based call.
first_rows = rows[:1000]
results = collection.upsert(first_rows)

# Report the count returned by the upsert call
print(f"Data upserted successfully! Upserted rows: {results.upsert_count}")

Data upserted successfully! Upserted rows: 1000


In [9]:
# 8. Upsert data in columns
# The full dataset is sent, laid out as one list per field.
results = collection.upsert(data=columns)

# Report the count returned by the upsert call
print(f"Data upserted successfully! Upserted rows: {results.upsert_count}")

Data upserted successfully! Upserted rows: 5979


In [10]:
# 9. Drop collection
# Removes the collection created earlier in this notebook.
utility.drop_collection(collection_name=COLLECTION_NAME)