In [None]:
!pip install weaviate-client
!pip install datasets

## Load data

Notes:
* collection name: `Articles`
* languages included: `en`, `de`, `fr`, `es`, `it`, `ja`, `ar`, `zh`, `ko`, `hi`
* source: [Cohere/wikipedia-22-12-(lang)-embeddings](https://huggingface.co/Cohere)

In [116]:
import os
import weaviate

# Bearer-token credentials; the token itself is read from the environment
# (NOTE(review): the variable name suggests a write-scoped token - confirm).
auth_config = weaviate.AuthBearerToken(
    access_token=os.environ.get("COHERE_AUTH_WRITE"),
    expires_in=36000,  # in seconds, by default 60s
)

# Connect to the hosted demo instance. The Cohere API key is forwarded as a
# request header on every call made through this client.
client = weaviate.Client(
    url="https://cohere-demo.weaviate.network/",
    auth_client_secret=auth_config,
    additional_headers={"X-Cohere-Api-Key": os.environ.get("COHERE_API_KEY")},
    timeout_config=(20, 240),  # presumably (connect, read) timeouts in seconds - verify
)

# Kept as the last expression so the readiness flag is shown as the cell output.
client.is_ready()

True

In [29]:
# delete existing schema, (note, this will delete the data in the Articles collection)
# client.schema.delete_class("Articles")

# Schema for the `Articles` collection.
# - Only `text` is marked for vectorization; every other property sets "skip": True.
# - Every property written by the import cell is declared here explicitly,
#   including `num_langs`, so the stored types do not depend on server-side
#   type inference.
article_schema = {
    "class": "Articles",
    "description": "Wiki Article",
    "vectorizer": "text2vec-cohere",
    "moduleConfig": {
        "text2vec-cohere": {
            "model": "multilingual-22-12",
            "truncate": "RIGHT"
        }
    },
    "vectorIndexConfig": {
        "distance": "dot"
    },
    "properties": [
    {
        "name": "text",
        "dataType": [ "text" ],
        "description": "Article body",
        "moduleConfig": {
            "text2vec-cohere": {
                "skip": False,
                "vectorizePropertyName": False
            }
        }
    },
    {
        "name": "title",
        "dataType": [ "text" ],
        "moduleConfig": { "text2vec-cohere": { "skip": True } }
    },
    {
        "name": "url",
        "dataType": [ "string" ],
        "moduleConfig": { "text2vec-cohere": { "skip": True } }
    },
    {
        "name": "wiki_id",
        "dataType": [ "int" ],
        "moduleConfig": { "text2vec-cohere": { "skip": True } }
    },
    {
        "name": "views",
        "dataType": [ "int" ],
        "moduleConfig": { "text2vec-cohere": { "skip": True } }
    },
    {
        "name": "lang",
        "dataType": [ "string" ],
        "moduleConfig": { "text2vec-cohere": { "skip": True } }
    },
    {
        "name": "lang_id",
        "dataType": [ "int" ],
        "moduleConfig": { "text2vec-cohere": { "skip": True } }
    },
    {
        # Filled from the dataset's `langs` field by the import cell.
        # Assumes that field is an integer count - TODO confirm against the dataset card.
        "name": "num_langs",
        "dataType": [ "int" ],
        "moduleConfig": { "text2vec-cohere": { "skip": True } }
    },
    ]
}

# add the schema
client.schema.create_class(article_schema)
print("The collection schema has been created")

The collection schema has been created


In [112]:
### Step 1 - configure Weaviate Batch, which optimizes CRUD operations in bulk
# - fixed batch size of 200
# - dynamic batch sizing is currently disabled (the `dynamic=True` line below is commented out)
# - retry up to 3 times when a batch request times out

client.batch.configure(
    batch_size=200,
    # dynamic=True,
    timeout_retries=3,
)

<weaviate.batch.crud_batch.Batch at 0x7f7cfbd31ac0>

In [113]:
def round_views(val):
    """Truncate a view count to a coarse bucket and return it as an int.

    Values below 10 are truncated to an integer. Values below 100_000 are
    truncated down to a multiple of their own order of magnitude
    (57 -> 50, 1234 -> 1000). Anything from 100_000 upwards is truncated
    down to a multiple of 100_000 (1_234_567 -> 1_200_000).
    """
    if val < 10:
        return int(val)

    # Each step handles the decade [step, step * 10).
    for step in (10, 100, 1_000, 10_000):
        if val < step * 10:
            return int(val / step) * step

    # Everything larger shares the coarsest bucket width.
    return int(val / 100_000) * 100_000

In [114]:
from datasets import load_dataset
from tqdm import tqdm

def import_wiki_data(lang, lang_id, max_rows, skip_rows=0):
    """Stream one language of the Cohere Wikipedia embeddings into `Articles`.

    Rows are read from the streaming dataset
    `Cohere/wikipedia-22-12-{lang}-embeddings` and added through the
    module-level Weaviate `client` batch, together with the embedding
    taken from each row's `emb` field.

    Args:
        lang: Language code; used in the dataset name and stored on each object.
        lang_id: Numeric language identifier stored on each object.
        max_rows: Absolute row position at which to stop. The counter starts
            at `skip_rows`, so the number of objects added by one call is
            `max_rows - skip_rows`.
        skip_rows: Number of leading dataset rows to skip (e.g. to resume an
            interrupted import).

    Returns:
        None. Progress is reported via tqdm and print.
    """
    # Without this guard the loop below would still insert one object before
    # noticing that the counter has already reached max_rows.
    if skip_rows >= max_rows:
        print(f"Nothing to import for {lang}: skip_rows={skip_rows} already reaches max_rows={max_rows}")
        return

    print(f"Importing {max_rows} data items for {lang}")

    dataset = load_dataset(f"Cohere/wikipedia-22-12-{lang}-embeddings", split='train', streaming=True)
    dataset = dataset.skip(skip_rows)

    # Counts absolute dataset position, i.e. skipped rows are included.
    counter = skip_rows

    # Leaving the `with` block hands control back to the batch context manager,
    # which is responsible for sending any objects still queued.
    with client.batch as batch:
        for item in tqdm(dataset, initial=skip_rows, total=max_rows):
            vector = item["emb"]
            data_to_insert = {
                'wiki_id': item['wiki_id'],
                'title': item['title'],
                'text': item['text'],
                'url': item['url'],
                'lang': lang,
                'lang_id': lang_id,
                'views': round_views(item['views']),
                'num_langs': item['langs'],
            }

            # Third argument is the object id (None here); the vector comes
            # straight from the dataset row.
            batch.add_data_object(data_to_insert, "Articles", None, vector)

            counter += 1
            if counter >= max_rows:
                break

    print(f"Imported {counter} items for {lang}")
    print( "-----------------------------------")

In [115]:
# import_per_country = 100_000
import_per_country = 1_000_000

# Call shape: import_wiki_data(lang, lang_id, max_rows, skip_rows).
# max_rows is an absolute stop position (the importer's counter starts at
# skip_rows), so each call adds the rows from skip_rows up to max_rows.
# The commented-out calls are presumably kept as a record of earlier or
# pending runs and their resume offsets - verify before re-enabling any of them.
# import_wiki_data('en', 0, import_per_country, 1_000_000)
# import_wiki_data('de', 1, import_per_country, 1_000_000)
# import_wiki_data('fr', 2, import_per_country, 1_000_000)
# import_wiki_data('es', 3, import_per_country, 1_000_000)
# import_wiki_data('it', 4, import_per_country, 1_000_000)
import_wiki_data('ja', 5, import_per_country, 985_000)
# import_wiki_data('ar', 6, import_per_country, 561_000)
# import_wiki_data('zh', 7, import_per_country, 100_000)
# import_wiki_data('ko', 8, import_per_country, 100_000)
# import_wiki_data('hi', 9, import_per_country, 100_000)

Importing 1000000 data items for ja


Using custom data configuration Cohere--wikipedia-22-12-ja-embeddings-ccaffb31b2ed5e09
 97%|█████████▋| 972400/1000000 [00:00<?, ?it/s]Got disconnected from remote data host. Retrying in 5sec [1/20]
Got disconnected from remote data host. Retrying in 5sec [1/20]
 98%|█████████▊| 983185/1000000 [47:27<03:43, 75.36it/s]        [ERROR] Batch ReadTimeout Exception occurred! Retrying in 2s. [1/3]
 98%|█████████▊| 984999/1000000 [1:00:06<1:11:33,  3.49it/s]


KeyboardInterrupt: 

### Show number of imported items

In [None]:
# Sanity check after the import: aggregate the objects stored for a single
# language. lang_id 5 is the id passed for 'ja' in the import call.
lang_filter = {
    "path": ["lang_id"],
    "operator": "Equal",
    "valueInt": 5
}

# Alternative kept from the original cell (disabled): group by language
# instead of filtering on one.
#   .with_group_by_filter(["lang_id"])
#   .with_fields("groupedBy {value}")
result = client.query.aggregate("Articles").with_where(lang_filter).with_meta_count().do()

# Prints the raw aggregate entry for the collection, which holds the meta count.
print("Object count: ", result["data"]["Aggregate"]["Articles"])