# The Start
Let's evolve our knowledge of embeddings.  We need to embed the text of nodes as well as the text of the question asked of the retriever.

We are using LlamaIndex, with Cohere for embeddings, Ollama for the LLM, and Chroma as a vector store.  Let's get started.

In [None]:
# Step up to the parent directory; presumably the relative paths used in later cells resolve from there — confirm.
%cd ..
%pwd  # To verify the current working directory

In [14]:
import cohere
import os
from llama_index.core.embeddings import BaseEmbedding
from pydantic import PrivateAttr
# Create a custom Cohere embedding class
class CohereEmbedding(BaseEmbedding):
    """LlamaIndex embedding model backed by Cohere's ``embed`` endpoint.

    Queries are embedded with ``input_type``. Node/document text is embedded
    with ``"search_document"`` whenever ``input_type`` is ``"search_query"``,
    because Cohere's v3 models expect stored text and search queries to be
    embedded with different input types. Any other ``input_type`` (for
    example ``"classification"``) is applied to queries and text alike.

    Args:
        model_name: Cohere embedding model identifier.
        input_type: Cohere ``input_type`` used for query embeddings.
        **kwargs: Forwarded to ``BaseEmbedding`` (e.g. ``embed_batch_size``).

    The API key is read from the ``COHERE_API_KEY`` environment variable.
    """

    _client: cohere.Client = PrivateAttr()
    _model_name: str = PrivateAttr()
    _input_type: str = PrivateAttr()
    _document_input_type: str = PrivateAttr()

    def __init__(
        self,
        model_name: str = "embed-english-v3.0",
        input_type: str = "search_query",
        **kwargs,
    ):
        # Forward the model name so the public BaseEmbedding field reflects
        # the model actually in use.
        super().__init__(model_name=model_name, **kwargs)
        self._client = cohere.Client(api_key=os.getenv("COHERE_API_KEY"))
        self._model_name = model_name
        self._input_type = input_type
        # Search use-cases are asymmetric: stored text needs the document
        # counterpart of the query input type.
        self._document_input_type = (
            "search_document" if input_type == "search_query" else input_type
        )

    def _embed(self, texts: list[str], input_type: str) -> list[list[float]]:
        """Embed ``texts`` in one Cohere call and return one vector per text."""
        response = self._client.embed(
            texts=texts,
            model=self._model_name,
            input_type=input_type,
        )
        return response.embeddings

    def _get_query_embedding(self, query: str) -> list[float]:
        """Return the embedding of a single retrieval query."""
        return self._embed([query], self._input_type)[0]

    def _get_text_embedding(self, text: str) -> list[float]:
        """Return the embedding of a single piece of node/document text."""
        return self._embed([text], self._document_input_type)[0]

    def _get_text_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Return embeddings for many texts using batched API calls.

        Batching replaces one HTTP request per text with one request per
        chunk; chunks are capped at Cohere's documented limit of 96 texts.
        """
        max_texts_per_call = 96
        vectors: list[list[float]] = []
        for start in range(0, len(texts), max_texts_per_call):
            chunk = texts[start:start + max_texts_per_call]
            vectors.extend(self._embed(chunk, self._document_input_type))
        return vectors

    async def _aget_query_embedding(self, query: str) -> list[float]:
        """Async variant of ``_get_query_embedding``."""
        import asyncio

        # The client held here is synchronous; run it in a worker thread so
        # the event loop is not blocked while waiting on the network.
        return await asyncio.to_thread(self._get_query_embedding, query)

    async def _aget_text_embedding(self, text: str) -> list[float]:
        """Async variant of ``_get_text_embedding``."""
        import asyncio

        # Same reasoning as the query variant: keep the blocking HTTP call
        # off the event loop.
        return await asyncio.to_thread(self._get_text_embedding, text)


In [None]:
# Configure the global LlamaIndex Settings: Cohere for embeddings, Ollama for the LLM.

## Set embedding model

from llama_index.core import Settings
from llama_index.llms.ollama import Ollama
from llama_index.embeddings.ollama import OllamaEmbedding
# Use the custom CohereEmbedding class defined in an earlier cell as the embedding model.
Settings.embed_model = CohereEmbedding(
    model_name="embed-english-v3.0",
    input_type="search_query"
)

# Alternative kept for reference: a local Ollama embedding model instead of Cohere.
# Settings.embed_model = OllamaEmbedding(
#     model_name='snowflake-arctic-embed',
#     base_url="http://localhost:11434",
#     ollama_additional_kwargs={"mirostat": 0},
# )
## Choose your LLM...
# Ollama LLM using the 'mistral' model; request_timeout=1000.0 is presumably in seconds — confirm.
Settings.llm = Ollama(model='mistral', request_timeout=1000.0)

In [3]:
# Load documents.
# Reads the previously pickled node list from 'text_nodes.pkl' into `text_nodes`.
# NOTE(review): unpickling can execute arbitrary code; only load this file if it
# comes from a trusted source.
import pickle

with open('text_nodes.pkl', 'rb') as f:
    text_nodes = pickle.load(f)

In [None]:
# Display the text of the node at index 10 (the same node is embedded in the next cell).
text_nodes[10].text

In [4]:

# Build Chroma's SentenceTransformer embedding function with its default arguments
# and print the embedding it produces for the node at index 10.
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction

embedding_function = SentenceTransformerEmbeddingFunction()
# The function takes a list of texts, hence the single-element list.
print(embedding_function([text_nodes[10].text]))

[[-0.02265961281955242, -0.005239610094577074, 0.07052355259656906, 0.02561274915933609, 0.10104447603225708, 0.028056371957063675, 0.010161704383790493, 0.02678683213889599, -0.010463953018188477, 0.013094304129481316, -0.00229410408064723, 0.024940816685557365, 0.03212709724903107, 0.05114594101905823, -0.13100771605968475, -0.044179853051900864, -0.02100180834531784, 0.0862240269780159, -0.1389358639717102, -0.006520877592265606, 0.025919483974575996, -0.043282076716423035, -0.00014383887173607945, -0.007276103366166353, -0.05586138367652893, 0.03298725187778473, -0.0278569757938385, -0.014384178444743156, 0.003727296367287636, 0.002785601420328021, -0.010051359422504902, 0.085335873067379, 0.08227475732564926, -0.012291735969483852, -0.00842922180891037, 0.0463767908513546, 0.053411033004522324, -0.02952464483678341, -0.011834043078124523, -0.005725303664803505, 0.049482561647892, -0.05733354762196541, 0.08487400412559509, 0.02219979465007782, 0.002378671197220683, -0.0534917972981

In [None]:
import chromadb

# Collection that will hold one entry per text node.
collection_name = "microsoft_annual_report_2022"
chroma_client = chromadb.Client()

# Start from a clean slate: drop any earlier collection with the same name.
try:
    chroma_client.delete_collection(name=collection_name)
except ValueError:
    print(f"Collection '{collection_name}' does not exist. Proceeding to create.")
else:
    print(f"Existing collection '{collection_name}' deleted.")

chroma_collection = chroma_client.create_collection(
    collection_name,
    embedding_function=embedding_function,
)

# Parallel lists, one element per node: raw text, positional string id, metadata.
documents = [entry.text for entry in text_nodes]
ids = list(map(str, range(len(text_nodes))))
metadatas = [entry.metadata for entry in text_nodes]

chroma_collection.add(
    documents=documents,
    metadatas=metadatas,
    ids=ids,
)
# Final expression of the cell: the number of stored entries.
chroma_collection.count()

In [None]:
# Create a vector index.
# NOTE(review): from_documents is given `text_nodes`; if these are already-parsed nodes
# rather than Documents, VectorStoreIndex(nodes=text_nodes) may be the intended
# constructor — confirm.
# Only `text_nodes` is passed here; the Chroma collection built in the previous cell
# is not wired into this index.
from llama_index.core import VectorStoreIndex

vector_index = VectorStoreIndex.from_documents(text_nodes)


In [None]:
# Print the metadata attached to the first node.
print(text_nodes[0].metadata)