# RAG – Configure the collection with a generator and load data

## Get API keys and URLs

In [None]:
import os
from dotenv import load_dotenv

load_dotenv()

WEAVIATE_URL = os.getenv("WEAVIATE_URL")
WEAVIATE_KEY = os.getenv("WEAVIATE_KEY")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# Confirm the variables were loaded WITHOUT echoing the secret values:
# notebook output is routinely saved and shared, which would leak the keys.
print(WEAVIATE_URL)
print("WEAVIATE_KEY:", "set" if WEAVIATE_KEY else "MISSING")
print("OPENAI_API_KEY:", "set" if OPENAI_API_KEY else "MISSING")

## Connect to Weaviate

In [None]:
import weaviate
from weaviate.classes.init import Auth

# Header carrying the OpenAI key to Weaviate; presumably consumed by the
# OpenAI vectorizer / generative modules the collection is configured with.
openai_headers = {"X-OpenAI-Api-Key": OPENAI_API_KEY}

# Open a connection to the Weaviate Cloud cluster using API-key auth.
client = weaviate.connect_to_weaviate_cloud(
    cluster_url=WEAVIATE_URL,
    auth_credentials=Auth.api_key(WEAVIATE_KEY),
    headers=openai_headers,
)

client.is_ready()

## Recreate the Wiki collection

In [None]:
from weaviate.classes.config import Configure

def create_wiki_collection(
    collection_name="Wiki",
    embedding_model="text-embedding-3-small",
    generative_model="gpt-4",
):
    """(Re)create a collection with an OpenAI vectorizer and an OpenAI generator.

    WARNING: if a collection called ``collection_name`` already exists it is
    deleted first, together with all of its objects.

    Args:
        collection_name: Name of the collection to (re)create.
        embedding_model: OpenAI embedding model for the ``main_vector`` named
            vector. It should match the model that produced any precomputed
            vectors you import; otherwise stored vectors and query vectors will
            not be comparable.
        generative_model: OpenAI chat model used for generative (RAG) queries,
            e.g. "gpt-4" or "gpt-3.5-turbo".
    """
    if client.collections.exists(collection_name):
        client.collections.delete(collection_name)

    client.collections.create(
        name=collection_name,
        vectorizer_config=[
            Configure.NamedVectors.text2vec_openai(
                name="main_vector",
                model=embedding_model,
                # Properties used to generate the vector.
                source_properties=["title", "text"],
            )
        ],
        # Supported models: https://weaviate.io/developers/weaviate/model-providers/openai/generative#available-models
        generative_config=Configure.Generative.openai(model=generative_model),
    )


create_wiki_collection()

## Load the data from Parquet files

In [None]:
from datasets import load_dataset

def prepare_parquet_dataset(
    data_files="../wiki-data/openai/text-embedding-3-small/*.parquet",
):
    """Return a streaming ``train`` split over the wiki Parquet files.

    Args:
        data_files: A glob pattern, or an iterable of patterns, selecting the
            Parquet files. Relative paths resolve against the current working
            directory. The default keeps the original hard-coded location.

    Returns:
        The streamed ``train`` split from ``load_dataset(..., streaming=True)``.
        Rows are read lazily as the result is iterated.
    """
    # Accept a single pattern or many; always hand load_dataset a fresh list.
    patterns = [data_files] if isinstance(data_files, str) else list(data_files)
    return load_dataset(
        "parquet",
        data_files={"train": patterns},
        split="train",
        streaming=True,
    )

### The import function

In [None]:
from tqdm import tqdm
from weaviate.util import generate_uuid5

def import_wiki_data(
    max_rows=10_000,
    *,
    collection_name="Wiki",
    batch_size=2500,
    concurrent_requests=4,
    max_errors=10,
):
    """Batch-import wiki rows, with their precomputed vectors, into a collection.

    Rows come from ``prepare_parquet_dataset()``. Each row's ``vector`` column
    is attached as the ``main_vector`` named vector, and the object UUID is
    derived deterministically from ``wiki_id`` via ``generate_uuid5``. As a
    result, re-importing the same row targets the same object ID.

    Args:
        max_rows: Maximum number of rows to import. ``None`` imports every row.
        collection_name: Name of the target collection.
        batch_size: Number of objects per batch request.
        concurrent_requests: Number of batch requests sent concurrently.
        max_errors: Stop adding objects once the batch reports more than this
            many errors.
    """
    # Local import keeps this cell self-contained.
    from itertools import islice

    print(f"Importing {max_rows if max_rows is not None else 'all'} data items")

    dataset = prepare_parquet_dataset()
    wiki = client.collections.get(collection_name)

    # Number of objects handed to the batch (failed ones included).
    counter = 0

    with wiki.batch.fixed_size(
        batch_size=batch_size, concurrent_requests=concurrent_requests
    ) as batch:
        # islice enforces the limit *before* an object is added, so
        # max_rows=0 imports nothing and exactly max_rows rows are sent.
        for item in tqdm(islice(dataset, max_rows), total=max_rows):
            batch.add_object(
                properties={
                    "wiki_id": item["wiki_id"],
                    "text": item["text"],
                    "title": item["title"],
                    "url": item["url"],
                },
                uuid=generate_uuid5(item["wiki_id"]),
                vector={"main_vector": item["vector"]},
            )
            counter += 1

            # Abort early instead of pushing the rest of the data into a
            # failing import.
            if batch.number_errors > max_errors:
                print(f"Reached {batch.number_errors} Errors during batch import")
                break

    # Checked after the batch context has closed.
    failed = wiki.batch.failed_objects
    if failed:
        print("Final error check")
        print(f"Some errors {len(failed)}")
        print(failed[-1])

    print(f"Imported {counter} items")
    print("-----------------------------------")

In [None]:
import_wiki_data(max_rows=10_000)  # import the first 10k rows

## Check that the data loaded correctly

In [None]:
# Get a handle to the "Wiki" collection and display len(wiki) as the cell
# output (expected to be the number of imported objects; compare with the
# count printed by the import).
wiki = client.collections.get("Wiki")
len(wiki)

In [None]:
# Fetch one object, including its vector(s), to spot-check the import.
res = wiki.query.fetch_objects(limit=1, include_vector=True)

# Guard against an empty collection instead of raising IndexError.
if res.objects:
    sample = res.objects[0]
    print(sample.properties)
    # With named vectors this is keyed by vector name (e.g. "main_vector").
    print(sample.vector)
else:
    print("No objects found in the Wiki collection")

In [None]:
# Close the Weaviate client when the notebook is finished with it.
client.close()