In [2]:
import pandas as pd
import os
from openai import OpenAI
import openai
from dotenv import load_dotenv, find_dotenv
from qdrant_client import models, QdrantClient
from qdrant_client.http import models as rest
from qdrant_client.http.models import Record
from sentence_transformers import SentenceTransformer

# Pull variables from the nearest .env file into the process environment
load_dotenv(find_dotenv())

# OpenAI client, authenticated with the key taken from the environment
client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))

# Load the pre-computed song embeddings and replace missing cells with empty strings.
# To limit the number of rows, insert .head(100) before .fillna('').
file_path = '../datasets/songs_embeddings.csv'
df = pd.read_csv(file_path).fillna('')

print(f'{len(df)} rows')
df.head()

2222 rows


Unnamed: 0,id,artist,song,lyrics,img_src,sentiment_score,metadata,metadata_vector
0,0,米津玄師,Lemon,夢ならばどれほどよかったでしょう未だにあなたのことを夢にみる忘れた物を取りに帰るように古びた...,https://m.media-amazon.com/images/I/51ZsVIMARh...,-0.669566,米津玄師 / Lemon / 夢ならばどれほどよかったでしょう未だにあなたのことを夢にみる忘...,"[0.0367952436208725, -0.01923556998372078, -0...."
1,1,back number,クリスマスソング,どこかで鐘が鳴ってらしくない言葉が浮かんで寒さが心地よくてあれ　なんで恋なんかしてんだろう聖...,https://m.media-amazon.com/images/I/31eRU7YYby...,-0.066641,back number / クリスマスソング / どこかで鐘が鳴ってらしくない言葉が浮かんで...,"[0.004695550538599491, -0.05602886155247688, -..."
2,2,GReeeeN,キセキ,明日、今日よりも好きになれる　溢れる想いが止まらない今もこんなに好きでいるのに　言葉に出来な...,https://m.media-amazon.com/images/I/51L0WT553N...,0.651807,GReeeeN / キセキ / 明日、今日よりも好きになれる　溢れる想いが止まらない今もこん...,"[0.0385475717484951, -0.042898621410131454, 0...."
3,3,back number,花束,どう思う？これから2人でやっていけると思う？んんどうかなぁでもとりあえずは一緒にいたいと思っ...,https://m.media-amazon.com/images/I/51SHzjh0dC...,-0.711921,back number / 花束 / どう思う？これから2人でやっていけると思う？んんどうか...,"[-0.005037600640207529, -0.05399758741259575, ..."
4,4,RADWIMPS,前前前世 (movie ver.),やっと眼を覚ましたかい　それなのになぜ眼も合わせやしないんだい？「遅いよ」と怒る君　これでも...,https://m.media-amazon.com/images/I/51h10DfD1o...,-0.626399,RADWIMPS / 前前前世 (movie ver.) / やっと眼を覚ましたかい　それな...,"[0.03341161459684372, 0.004328139126300812, -0..."


In [3]:
# For local testing, use an in-memory instance instead of the cloud one:
# qdrant_client = QdrantClient(':memory:')

# Cloud Qdrant connection; the endpoint and API key are read from the environment.
# NOTE(review): if QDRANT_URL / QDRANT_API_KEY are unset, None is passed through here —
# confirm the client's fallback behaviour is acceptable, or fail fast instead.
qdrant_client = QdrantClient(
    api_key=os.getenv('QDRANT_API_KEY'),
    url=os.getenv('QDRANT_URL'),
)

In [None]:
import ast

# Target collection; the vector size is derived from the data so it always
# matches whatever embedding model produced the CSV.
collection_name = 'songs'

if df.empty:
    raise ValueError('No rows loaded from the embeddings CSV; cannot infer the vector size.')

# 'metadata_vector' is stored in the CSV as a stringified list, so parse the first
# row's value to get the dimensionality. .iloc keeps the lookup positional, so it
# still works when df has been sliced/resampled and has no row labelled 0.
vector_size = len(ast.literal_eval(df['metadata_vector'].iloc[0]))

# Create the collection with one named vector, 'metadata', compared by cosine distance.
# recreate_collection deletes any existing collection with this name (and its points)
# before creating it, so re-running this cell starts from an empty collection.
# NOTE(review): recreate_collection is deprecated in newer qdrant-client releases in
# favour of collection_exists / delete_collection / create_collection — confirm
# against the installed version before upgrading.
qdrant_client.recreate_collection(
    collection_name=collection_name,
    vectors_config={
        'metadata': rest.VectorParams(
            distance=rest.Distance.COSINE,
            size=vector_size,
        ),
    },
)

def calculate_payload_length(payload):
    """Return the total character count of the values in ``payload``.

    Only values are measured; keys are ignored. Each value is counted as follows:
    - ``str``: its own length;
    - ``list``: the sum of ``len(str(item))`` over its elements;
    - ``dict``: measured recursively with this same function;
    - anything else: ``len(str(value))``.

    Args:
        payload: Mapping whose values are measured (e.g. a point payload before
            it is sent to Qdrant).

    Returns:
        int: The summed character count.
    """
    def measure(value):
        # Branch order matters: str is checked before the generic fallback.
        if isinstance(value, str):
            return len(value)
        if isinstance(value, list):
            return sum(len(str(item)) for item in value)
        if isinstance(value, dict):
            return calculate_payload_length(value)
        return len(str(value))

    return sum(measure(value) for value in payload.values())

# Add vectors to the collection in batches: one upsert per row would cost one
# HTTP round-trip per song, so points are buffered and sent UPSERT_BATCH_SIZE at a time.
UPSERT_BATCH_SIZE = 64  # points per upsert request; lower it if requests grow too large

# Rough size indicator for the whole upload: it sums payload *characters* with the
# *number of vector components*, so the two units are mixed — treat it as an estimate.
request_length = 0
batch = []

for k, v in df.iterrows():
    # Parse the stringified vector once and reuse it for the size estimate and the point.
    vector = ast.literal_eval(v['metadata_vector'])  # Convert string to list

    # The vector is stored as the point's vector, so drop it from the payload
    # to avoid sending it twice.
    result_dict = v.to_dict()
    result_dict.pop('metadata_vector', None)

    request_length += calculate_payload_length(result_dict) + len(vector)

    batch.append(
        rest.PointStruct(
            id=k,
            vector={'metadata': vector},
            payload=result_dict,
        )
    )

    if len(batch) >= UPSERT_BATCH_SIZE:
        qdrant_client.upsert(collection_name=collection_name, points=batch)
        batch = []

# Flush the final, partially filled batch.
if batch:
    qdrant_client.upsert(collection_name=collection_name, points=batch)

print(f"Payload & vector length for all points: {request_length}")

print(qdrant_client.get_collections())
qdrant_client.count(collection_name=collection_name)