In [1]:
### * Script to insert content in vector database.
### * Script assumes data is stored in .csv file.

from pymilvus import connections, Collection
from pymilvus import utility, FieldSchema, CollectionSchema, DataType, Collection
from dotenv import load_dotenv
import os
import csv
from sentence_transformers import SentenceTransformer

load_dotenv()

FILE_NAME = "tata_punch_owner_manual.pdf"
FOLDER_NAME = FILE_NAME.split(".")[0]
CSV_FILE_LINK = f"output/{FOLDER_NAME}/{FOLDER_NAME}_csv.csv"
COLLECTION_NAME = "rag_vs"

MILVUS_TOKEN = os.getenv("MILVUS_TOKEN")
MILVUS_URI = os.getenv("MILVUS_URI")
DIMENSION = 384
BATCH_SIZE = 128

  from tqdm.autonotebook import tqdm, trange


In [2]:
collection: Collection = None

### Creating a collection
connections.connect(uri=MILVUS_URI, token=MILVUS_TOKEN)

if not utility.has_collection(collection_name=COLLECTION_NAME):
    print(f"collection {COLLECTION_NAME} not found. creating one.")
    fields = [
        FieldSchema(
            name="id", dtype=DataType.INT64, is_primary=True, auto_id=True
        ),  # id is auto=increment
        FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=65535),
        FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=DIMENSION),
        FieldSchema(name="page_num", dtype=DataType.INT64),
        FieldSchema(name="image_url", dtype=DataType.VARCHAR, max_length=1000),
        FieldSchema(name="file_name", dtype=DataType.VARCHAR, max_length=200),
    ]
    schema = CollectionSchema(fields=fields)
    collection = Collection(name=COLLECTION_NAME, schema=schema)

    index_params = {
        "metric_type": "L2",
        "index_type": "IVF_FLAT",
        "params": {"nlist": 1536},
    }
    collection.create_index(field_name="embedding", index_params=index_params)
    collection.load()
else:
    print(f"collection {COLLECTION_NAME} already exists.")
    collection = Collection(name=COLLECTION_NAME)

collection rag_vs already exists.


In [3]:
transformer = SentenceTransformer("all-MiniLM-L6-v2")


def csv_load(file):
    with open(file, newline="") as f:
        reader = csv.reader(f, delimiter=",")
        for row in reader:
            yield (row[0], row[1], row[2])


def generate_embeddings(data: list[str]):
    embeddings = transformer.encode(data)
    return [x for x in embeddings]

def insert_data(data):
    embeddings = generate_embeddings(data[0]) if data[0] else []
    ins = [
        data[0],
        embeddings,
        data[1],
        data[2],
        data[3],
    ]
    collection.insert(ins)


data_batch = [[], [], [],[]]

count = 0

for content, page_num, image_url in csv_load(CSV_FILE_LINK):
    data_batch[0].append(content)
    data_batch[1].append(int(page_num))
    data_batch[2].append(image_url)
    data_batch[3].append(FILE_NAME)
    if len(data_batch[0]) % BATCH_SIZE == 0:
        insert_data(data_batch)
        data_batch = [[], [], [],[]]
        print(f"\ninserted... {count} contents")
    count += 1

if len(data_batch[0]) != 0:
    insert_data(data_batch)

collection.flush()





inserted... 127 contents

inserted... 255 contents

inserted... 383 contents

inserted... 511 contents

inserted... 639 contents

inserted... 767 contents

inserted... 895 contents

inserted... 1023 contents

inserted... 1151 contents

inserted... 1279 contents

inserted... 1407 contents

inserted... 1535 contents

inserted... 1663 contents

inserted... 1791 contents

inserted... 1919 contents

inserted... 2047 contents

inserted... 2175 contents

inserted... 2303 contents

inserted... 2431 contents
