In [None]:
!pip install weaviate-client
!pip install datasets

## Load data

Notes:
* collection name: `Articles`
* languages included: `en`, `de`, `fr`, `es`, `it`, `ja`, `ar`, `zh`, `ko`, `hi`
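* source: [Cohere/wikipedia-22-12-{lang}-embeddings](https://huggingface.co/Cohere) — each passage ships with a precomputed Cohere embedding in its `emb` field, which is used as the object's vector on import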

In [None]:
import os
import weaviate

# Connect to the hosted demo instance. The Cohere API key is read from the
# environment and sent as an extra request header (presumably consumed by the
# text2vec-cohere module on the server side).
cohere_headers = {"X-Cohere-Api-Key": os.getenv("COHERE_API_KEY")}

client = weaviate.Client(
    url="https://cohere-wiki-demo.weaviate.network",
    additional_headers=cohere_headers,
)

# Last expression: displays whether the instance is reachable and ready.
client.is_ready()

In [None]:
# Drop the existing Articles class so the schema below can be recreated from scratch.
# WARNING: this deletes every object stored in the Articles collection.
# Only delete when the class exists: on a fresh instance, deleting a missing class
# may be rejected by the server, which would stop Restart & Run All here.
existing_classes = {c["class"] for c in client.schema.get().get("classes", [])}
if "Articles" in existing_classes:
    client.schema.delete_class("Articles")

# Class definition for Articles. Only `text` is vectorized by text2vec-cohere;
# every other property sets `skip: True`. The vector index uses dot-product distance.
article_schema = {
    "class": "Articles",
    "description": "Wiki Article",
    "vectorizer": "text2vec-cohere",
    "moduleConfig": {
        "text2vec-cohere": {
            "model": "multilingual-22-12",
            "truncate": "RIGHT"
        }
    },
    "vectorIndexConfig": {
        "distance": "dot"
    },
    "properties": [
    {
        "name": "text",
        "dataType": [ "text" ],
        "description": "Article body",
        "moduleConfig": {
            "text2vec-cohere": {
                "skip": False,
                "vectorizePropertyName": False
            }
        }
    },
    {
        "name": "title",
        "dataType": [ "string" ],
        "moduleConfig": { "text2vec-cohere": { "skip": True } }
    },
    {
        "name": "url",
        "dataType": [ "string" ],
        "moduleConfig": { "text2vec-cohere": { "skip": True } }
    },
    {
        "name": "wiki_id",
        "dataType": [ "int" ],
        "moduleConfig": { "text2vec-cohere": { "skip": True } }
    },
    {
        "name": "views",
        "dataType": [ "number" ],
        "moduleConfig": { "text2vec-cohere": { "skip": True } }
    },
    {
        "name": "lang",
        "dataType": [ "string" ],
        "moduleConfig": { "text2vec-cohere": { "skip": True } }
    },
    {
        "name": "lang_id",
        "dataType": [ "int" ],
        "moduleConfig": { "text2vec-cohere": { "skip": True } }
    },
    # The importer writes `num_langs` (from the dataset's `langs` field) on every object,
    # so declare it explicitly instead of relying on auto-schema.
    # Assumes `langs` is an integer count of languages — TODO confirm against the dataset card.
    {
        "name": "num_langs",
        "dataType": [ "int" ],
        "moduleConfig": { "text2vec-cohere": { "skip": True } }
    },
    ]
}

# Register the Articles class definition with the Weaviate instance.
schema_api = client.schema
schema_api.create_class(article_schema)
print("The collection schema has been created")

In [None]:
# Step 1 - configure Weaviate batching for bulk imports:
#   * start with batches of 100 objects,
#   * let the client grow/shrink the batch size dynamically based on performance,
#   * retry a batch up to 3 times when a request times out.
client.batch.configure(
    timeout_retries=3,
    dynamic=True,
    batch_size=100,
)

In [17]:
from datasets import load_dataset
from tqdm import tqdm

def import_wiki_data(lang, lang_id, max_rows, class_name="Articles"):
    """Stream up to ``max_rows`` Wikipedia passages for one language into Weaviate.

    Rows are read from the Hugging Face dataset
    ``Cohere/wikipedia-22-12-{lang}-embeddings`` (``train`` split, ``streaming=True``).
    Each row is added to the module-level ``client``'s batch as an object of
    ``class_name``, with the row's precomputed ``emb`` embedding supplied as its vector.

    Args:
        lang: Language code; used in the dataset name and stored in the ``lang`` property.
        lang_id: Integer id stored in the ``lang_id`` property.
        max_rows: Maximum number of rows to import. Values <= 0 import nothing.
        class_name: Weaviate class to insert into (default ``"Articles"``).
    """
    from itertools import islice

    # Clamp so islice never sees a negative stop (which would raise ValueError).
    limit = max(max_rows, 0)
    print(f"Importing {max_rows} data items for {lang}")

    data = load_dataset(f"Cohere/wikipedia-22-12-{lang}-embeddings", split='train', streaming=True)

    counter = 0

    with client.batch as batch:
        # islice stops after `limit` rows (checking the limit *before* inserting, unlike a
        # post-insert `break`), and lets tqdm see the iterator finish so the bar closes cleanly.
        for item in tqdm(islice(data, limit), total=limit):
            data_to_insert = {
                'wiki_id': item['wiki_id'],
                'title': item['title'],
                'text': item['text'],
                'url': item['url'],
                'lang': lang,
                'lang_id': lang_id,
                'views': item['views'],
                # dataset field `langs` is stored under the `num_langs` property
                'num_langs': item['langs'],
            }

            # uuid=None lets Weaviate assign an id; the dataset's embedding is used as the vector.
            batch.add_data_object(data_to_insert, class_name, uuid=None, vector=item["emb"])
            counter += 1

    print(f"Imported {counter} items for {lang}")

In [None]:
# Number of passages to import per language (the name says "country", but it is per language).
import_per_country = 10_000

# Languages in import order; lang_id is the 1-based position in this tuple (en=1 ... hi=10).
LANGUAGES = ("en", "de", "fr", "es", "it", "ja", "ar", "zh", "ko", "hi")
for lang_id, lang in enumerate(LANGUAGES, start=1):
    import_wiki_data(lang, lang_id, import_per_country)

### Show number of imported items

In [None]:
# Test that all data has loaded – get the object count of the Articles class
result = (
    client.query.aggregate("Articles")
    .with_fields("meta { count }")
    .do()
)

# GraphQL errors can come back inside the response body rather than as an exception;
# surface them instead of failing later with an opaque KeyError/TypeError.
if result.get("errors"):
    raise RuntimeError(f"Aggregate query failed: {result['errors']}")

# Aggregate returns a list with one entry per group; without groupBy there is exactly one.
object_count = result["data"]["Aggregate"]["Articles"][0]["meta"]["count"]
print("Object count: ", object_count)