## Custom Embeddings


https://docs.llamaindex.ai/en/stable/examples/embeddings/custom_embeddings.html

The example below uses Instructor Embeddings (see the `InstructorEmbedding` package for install/setup details) and implements a custom embeddings class. Instructor embeddings take the text to embed together with an "instruction" that describes the domain of that text. This helps when embedding text from a very specific, specialized topic.
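Concretely, the Instructor model's `encode` method takes a list of `[instruction, text]` pairs rather than bare strings, and that is the input format the class below builds. A minimal stdlib sketch of that pairing step (the helper `make_instructor_inputs` is hypothetical; the commented-out lines assume the `InstructorEmbedding` package is installed and the weights can be downloaded):

```python
from typing import List


def make_instructor_inputs(instruction: str, texts: List[str]) -> List[List[str]]:
    """Pair every text with the same domain instruction (hypothetical helper)."""
    return [[instruction, text] for text in texts]


pairs = make_instructor_inputs(
    "Represent a document for semantic search:",
    ["Paul Graham wrote short stories.", "He programmed an IBM 1401."],
)
print(pairs[0])
# ['Represent a document for semantic search:', 'Paul Graham wrote short stories.']

# With the real package installed, the pairs would be encoded like this:
# from InstructorEmbedding import INSTRUCTOR
# model = INSTRUCTOR("hkunlp/instructor-large")
# vectors = model.encode(pairs)  # one vector per [instruction, text] pair
```

Changing only the instruction string is how the same model is steered toward a different domain, without retraining.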

In [1]:
# Install dependencies
!pip install InstructorEmbedding torch transformers sentence-transformers



In [7]:
import openai
import os


In [8]:
from typing import Any, List
from InstructorEmbedding import INSTRUCTOR

from llama_index.bridge.pydantic import PrivateAttr
from llama_index.embeddings.base import BaseEmbedding

In [9]:
class InstructorEmbeddings(BaseEmbedding):
    """LlamaIndex embedding wrapper around an Instructor model."""

    _model: INSTRUCTOR = PrivateAttr()
    _instruction: str = PrivateAttr()

    def __init__(
        self,
        instructor_model_name: str = "hkunlp/instructor-large",
        instruction: str = "Represent a document for semantic search:",
        **kwargs: Any,
    ) -> None:
        # Initialise the pydantic model first, then set the private attributes.
        super().__init__(**kwargs)
        # Downloads the weights from the Hugging Face Hub if they are not cached.
        self._model = INSTRUCTOR(instructor_model_name)
        self._instruction = instruction

    @classmethod
    def class_name(cls) -> str:
        return "instructor"

    # The async variants simply delegate to the blocking sync implementations.
    async def _aget_query_embedding(self, query: str) -> List[float]:
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        return self._get_text_embedding(text)

    def _get_query_embedding(self, query: str) -> List[float]:
        # encode() takes [instruction, text] pairs and returns a numpy array.
        embeddings = self._model.encode([[self._instruction, query]])
        return embeddings[0].tolist()

    def _get_text_embedding(self, text: str) -> List[float]:
        embeddings = self._model.encode([[self._instruction, text]])
        return embeddings[0].tolist()

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        # Encode a whole batch (up to embed_batch_size texts) in one call.
        embeddings = self._model.encode(
            [[self._instruction, text] for text in texts]
        )
        return embeddings.tolist()
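The async methods in the class above just call the blocking sync versions, so an embedding requested from async code holds up the event loop while the model runs. One alternative, sketched here with a hypothetical stand-in encoder (`blocking_encode`) instead of the real model, is to offload the call to a worker thread with `asyncio.to_thread` (Python 3.9+). Whether this also speeds things up depends on whether the underlying library releases the GIL; at minimum it keeps the event loop responsive.

```python
import asyncio
import time
from typing import List


def blocking_encode(text: str) -> List[float]:
    """Stand-in for a slow, CPU-bound model call (hypothetical)."""
    time.sleep(0.1)
    return [float(len(text)), float(text.count(" "))]


async def aget_text_embedding(text: str) -> List[float]:
    # Run the blocking call in a worker thread so the event loop stays free.
    return await asyncio.to_thread(blocking_encode, text)


async def main() -> List[List[float]]:
    texts = ["growing up", "what did the author do"]
    # Both calls are scheduled concurrently on the default thread pool.
    return await asyncio.gather(*(aget_text_embedding(t) for t in texts))


results = asyncio.run(main())
print(results)
# [[10.0, 1.0], [22.0, 4.0]]
```

Inside a Jupyter notebook an event loop is already running, so `asyncio.run` will fail there; use `await main()` in a cell instead.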

In [10]:
from llama_index import ServiceContext, SimpleDirectoryReader, VectorStoreIndex

## Download Data

In [5]:
# Create the directory the essay will be saved to
!mkdir -p 'data/paul_graham/'

In [6]:
!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/paul_graham/paul_graham_essay.txt' -O 'data/paul_graham/paul_graham_essay.txt'

--2023-12-03 22:16:07--  https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/paul_graham/paul_graham_essay.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8001::154, 2606:50c0:8002::154, 2606:50c0:8003::154, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8001::154|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 75042 (73K) [text/plain]
Saving to: 'data/paul_graham/paul_graham_essay.txt'


2023-12-03 22:16:08 (2.38 MB/s) - 'data/paul_graham/paul_graham_essay.txt' saved [75042/75042]



## Load Documents

In [11]:
# Read every file in data/paul_graham/ into a list of Document objects
documents = SimpleDirectoryReader("./data/paul_graham/").load_data()

In [10]:
!pip install llama-cpp-python

Collecting llama-cpp-python
  Downloading llama_cpp_python-0.2.20.tar.gz (8.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 8.7/8.7 MB 14.8 MB/s eta 0:00:00
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Installing backend dependencies ... done
  Preparing metadata (pyproject.toml) ... done
Collecting diskcache>=5.6.1
  Downloading diskcache-5.6.3-py3-none-any.whl (45 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 45.5/45.5 kB 1.5 MB/s eta 0:00:00
Building wheels for collected packages: llama-cpp-python
  Building wheel for llama-cpp-python (pyproject.toml) ... done
  Created wheel for llama-cpp-python: filename=llama_cpp_python-0.2.20-cp39-cp39-macosx_10_16_x86_64.whl size=2009328 sha256=0cc231882abe7b9bf9495faa290106db7c7872493e6d5273b5beb63770ba8297
  Stored in directory: 

In [None]:
# Run once in a terminal to download and serve the model locally:
#!ollama run llama2:7b

In [2]:
from llama_index.llms import Ollama

In [3]:
llm = Ollama(model="llama2:7b")
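The `Ollama` class does not load llama2 inside Python; it sends HTTP requests to a locally running Ollama server (as far as we know, at `http://localhost:11434` by default), which is why the model has to be pulled and served first with `ollama run llama2:7b`. To show what such a request looks like, here is a stdlib sketch that builds, but does not send, a non-streaming call to Ollama's `/api/generate` REST endpoint. `build_generate_request` is a hypothetical helper, and this is not necessarily the exact request LlamaIndex sends.

```python
import json
import urllib.request


def build_generate_request(
    prompt: str,
    model: str = "llama2:7b",
    base_url: str = "http://localhost:11434",
) -> urllib.request.Request:
    """Build a non-streaming POST to Ollama's /api/generate endpoint."""
    payload = {"model": model, "prompt": prompt, "stream": False}
    return urllib.request.Request(
        url=f"{base_url}/api/generate",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


req = build_generate_request("What did the author do growing up?")
print(req.full_url, req.get_method())

# With a server running, the reply's "response" field holds the generated text:
# with urllib.request.urlopen(req, timeout=120) as resp:
#     print(json.loads(resp.read())["response"])
```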

In [12]:
service_context = ServiceContext.from_defaults(
    llm=llm,
    embed_model=InstructorEmbeddings(embed_batch_size=2),
    chunk_size=512,
)
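`embed_batch_size=2` limits how many texts are handed to the model at once: as we understand `BaseEmbedding`, it collects texts and calls `_get_text_embeddings` once per batch of at most `embed_batch_size` items. A small batch keeps memory use low on a CPU-only machine at the cost of more model calls. The grouping itself can be sketched like this (the `batched` helper is hypothetical, not LlamaIndex's internal code):

```python
from typing import Iterator, List


def batched(texts: List[str], batch_size: int) -> Iterator[List[str]]:
    """Yield consecutive slices of at most batch_size texts."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    for start in range(0, len(texts), batch_size):
        yield texts[start:start + batch_size]


chunks = ["chunk-1", "chunk-2", "chunk-3", "chunk-4", "chunk-5"]
calls = list(batched(chunks, batch_size=2))
print(calls)
# [['chunk-1', 'chunk-2'], ['chunk-3', 'chunk-4'], ['chunk-5']]
```

Each inner list corresponds to one `_get_text_embeddings` call, so five chunks with a batch size of 2 mean three encoder calls.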


load INSTRUCTOR_Transformer
max_seq_length  512


## Index Setup

In [13]:
# Split the documents into chunks, embed each chunk with the Instructor model,
# and store the vectors in an in-memory index (this can take a while on CPU)
index = VectorStoreIndex.from_documents(
    documents, service_context=service_context
)
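With `chunk_size=512`, `from_documents` first splits each document into chunks (nodes), then embeds every chunk with the custom model and stores the vectors. LlamaIndex's default splitter counts tokens and tries to respect sentence boundaries; the sketch below is a simplified stand-in that counts whitespace-separated words and uses a fixed overlap, only to illustrate the sliding-window idea. `chunk_words` and its parameters are hypothetical.

```python
from typing import List


def chunk_words(text: str, chunk_size: int, overlap: int) -> List[str]:
    """Split text into windows of chunk_size words, sharing overlap words."""
    if not 0 <= overlap < chunk_size:
        raise ValueError("need 0 <= overlap < chunk_size")
    words = text.split()
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + chunk_size]))
        if start + chunk_size >= len(words):
            break  # this window already reaches the end of the text
    return chunks


print(chunk_words("a b c d e f g", chunk_size=3, overlap=1))
# ['a b c', 'c d e', 'e f g']
```

The overlap means a sentence cut at a chunk boundary still appears whole in at least one neighbouring chunk more often than with disjoint windows.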

In [14]:
response = index.as_query_engine().query("What did the author do growing up?")
print(response)

Based on the context information provided, the author did the following growing up:

1. Wrote short stories and imagined that they were deep and meaningful.
2. Used an IBM 1401 computer in junior high school to write programs, but found it confusing and couldn't remember any of the programs he wrote.
3. Got a TRS-80 microcomputer kit and started programming more seriously, writing simple games, a program to predict how high model rockets would fly, and a word processor for his father to use.
4. Switched from studying philosophy in college to studying artificial intelligence (AI) because he found it more interesting and thought it was the future of technology.
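At query time, the question is embedded with the same model (through `_get_query_embedding`), and the index returns the stored chunks whose vectors are most similar to it; the LLM then answers using those chunks as context. A stdlib sketch of cosine-similarity top-k ranking, with made-up 3-dimensional vectors standing in for real Instructor embeddings (which are much longer):

```python
import math
from typing import Dict, List, Tuple


def cosine(a: List[float], b: List[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def top_k(query: List[float], store: Dict[str, List[float]], k: int) -> List[Tuple[str, float]]:
    """Return the k (chunk_id, score) pairs with the highest cosine similarity."""
    scored = [(chunk_id, cosine(query, vec)) for chunk_id, vec in store.items()]
    return sorted(scored, key=lambda item: item[1], reverse=True)[:k]


# Toy vectors, not real Instructor embeddings.
store = {
    "childhood": [0.9, 0.1, 0.0],
    "college": [0.5, 0.5, 0.0],
    "startups": [0.0, 0.2, 0.9],
}
hits = top_k([1.0, 0.0, 0.0], store, k=2)
print([chunk_id for chunk_id, _ in hits])
# ['childhood', 'college']
```

Because queries and documents go through the same instruction and model, their vectors live in the same space, which is what makes this direct comparison meaningful.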


In [2]:
#!python3 -m pip install jax-metal

Collecting jax-metal
  Downloading jax_metal-0.0.4-py3-none-macosx_10_14_x86_64.whl (51.9 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 51.9/51.9 MB 5.1 MB/s eta 0:00:00
Collecting jaxlib==0.4.11
  Downloading jaxlib-0.4.11-cp39-cp39-macosx_10_14_x86_64.whl (74.9 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 74.9/74.9 MB 5.8 MB/s eta 0:00:00
Collecting jax==0.4.11
  Downloading jax-0.4.11.tar.gz (1.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 2.6 MB/s eta 0:00:00
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: jax
  Building wheel for jax (pyproject.toml) ... done
  Created wheel for jax: filename=jax-0.

In [3]:
#!python3 -c 'import jax; print(jax.numpy.arange(10))'

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/tubakaraca/opt/anaconda3/lib/python3.9/site-packages/jax/__init__.py", line 35, in <module>
    from jax import config as _config_module
  File "/Users/tubakaraca/opt/anaconda3/lib/python3.9/site-packages/jax/config.py", line 17, in <module>
    from jax._src.config import config  # noqa: F401
  File "/Users/tubakaraca/opt/anaconda3/lib/python3.9/site-packages/jax/_src/config.py", line 24, in <module>
    from jax._src import lib
  File "/Users/tubakaraca/opt/anaconda3/lib/python3.9/site-packages/jax/_src/lib/__init__.py", line 84, in <module>
    cpu_feature_guard.check_cpu_features()
RuntimeError: This version of jaxlib was built using AVX instructions, which your CPU and/or operating system do not support. You may be able work around this issue by building jaxlib from source.


In [3]:
#!pip uninstall jax jaxlib

Found existing installation: jax 0.4.18
Uninstalling jax-0.4.18:
  Would remove:
    /Users/tubakaraca/opt/anaconda3/lib/python3.9/site-packages/jax-0.4.18.dist-info/*
    /Users/tubakaraca/opt/anaconda3/lib/python3.9/site-packages/jax/*
Proceed (Y/n)? ^C
ERROR: Operation cancelled by user

In [4]:
#!pip install --upgrade jax jaxlib

Collecting jax
  Downloading jax-0.4.20-py3-none-any.whl (1.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting jaxlib
  Downloading jaxlib-0.4.20-cp39-cp39-macosx_10_14_x86_64.whl (82.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m82.6/82.6 MB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0mm
Installing collected packages: jaxlib, jax
Successfully installed jax-0.4.20 jaxlib-0.4.20
