[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/weaviate/recipes/blob/main/weaviate-features/services-research/contextual_document_embeddings.ipynb)

# Contextual Document Embeddings

Notebook author: Danny Williams @ Weaviate


## Overview

[Contextual Document Embeddings](https://arxiv.org/abs/2410.02525) (CDE) is a recent technique for embedding documents that takes into account the context of neighbouring documents or chunks from the same corpus.

The model is hosted on Hugging Face under [jxm/cde-small-v1](https://huggingface.co/jxm/cde-small-v1). It can be loaded with the `sentence-transformers` library, so it is easy to run locally to embed documents, and the resulting vectors can be used directly in Weaviate.

### Contextual Document Embeddings

In short, the model differs from standard embedding models in two ways:

1. During training, similar documents are clustered together into the same batches, so the model has to learn to distinguish between closely related documents rather than easy, unrelated ones.
2. During inference, the model takes in the document together with embeddings of a sample of documents from the same dataset, and outputs an embedding that takes this context into account.
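
To make the first point concrete, here is a toy sketch of cluster-based batching. It assumes NumPy and uses a plain k-means on pre-computed vectors; `cluster_batches` is a hypothetical helper, and this is not the paper's exact clustering or batch-packing procedure. Each cluster becomes one training batch, so the other documents in a batch (the in-batch negatives for contrastive training) are similar and therefore harder to tell apart.

```python
import numpy as np

def cluster_batches(embeddings, n_clusters, n_iters=10, seed=0):
    """Group row indices of `embeddings` into batches of similar items (toy k-means)."""
    embeddings = np.asarray(embeddings, dtype=float)
    rng = np.random.default_rng(seed)
    # initialise centroids with randomly chosen, distinct data points (fancy indexing copies)
    centroids = embeddings[rng.choice(len(embeddings), size=n_clusters, replace=False)]
    for _ in range(n_iters):
        # squared Euclidean distance from every point to every centroid: shape (n, k)
        dists = ((embeddings[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
        labels = dists.argmin(axis=1)
        for k in range(n_clusters):
            members = embeddings[labels == k]
            if len(members) > 0:  # keep the old centroid if a cluster empties
                centroids[k] = members.mean(axis=0)
    # one batch per non-empty cluster: each batch holds mutually similar documents
    batches = [np.flatnonzero(labels == k).tolist() for k in range(n_clusters)]
    return [b for b in batches if b]

# two well-separated toy "topics"
rng = np.random.default_rng(1)
toy = np.vstack([rng.normal(0, 0.1, (5, 2)), rng.normal(5, 0.1, (5, 2))])
print(cluster_batches(toy, n_clusters=2))
```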

### Setup

First, let's quiet some library logging and install the necessary libraries:


In [1]:
import os
os.environ["LOG_LEVEL"] = "error"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
%%capture
!pip install sentence-transformers
!pip install weaviate-client
!pip install einops
!pip install scikit-learn

## Example Data

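Let's first set up an example to show off the model's capabilities: posts from the `rec.sport.baseball` category of the 20 Newsgroups dataset, with headers, footers and quotes removed.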

In [3]:
from sklearn.datasets import fetch_20newsgroups

# newsgroups posts specifically about baseball
data = fetch_20newsgroups(subset='train', categories=['rec.sport.baseball'], remove=('headers', 'footers', 'quotes'))

In [4]:
# example of one piece of text
print(data["data"][0])


Maybe it's just me, but the combination of those *young* faces peeking out
from under oversized aqua helmets screams "Little League" in every fibre of
my being...



## Using the model

The model can be used as-is, like a standard embedding model, or we can supply context from the surrounding corpus to adapt the embeddings to our dataset.

### As-is

In [5]:
%%capture
# load the model
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('jxm/cde-small-v1', trust_remote_code=True)

In [6]:
# embed a single document
single_non_contextual_embedding = model.encode(
    data["data"][0],
    prompt_name="document",
    convert_to_tensor=True
)
print(f"Embedding shape: {single_non_contextual_embedding.shape}")
print(f"Embedding: {single_non_contextual_embedding[:4].tolist()}...")



Embedding shape: torch.Size([768])
Embedding: [0.2638474404811859, 0.1533348560333252, -0.2799359858036041, 0.05649382993578911]...


In [7]:
# embed the entire dataset
non_contextual_document_embeddings = model.encode(
    data["data"],
    prompt_name="document",
    convert_to_tensor=True
)
print(f"Dataset embeddings shape: {non_contextual_document_embeddings.shape}")


Dataset embeddings shape: torch.Size([597, 768])


### With Context

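First, we compute the embeddings that will serve _as context_. The model expects a mini-corpus of a fixed size, given by `transductive_corpus_size` in its config (512 for this model), so we embed the first 512 posts of our dataset. The model then uses these embeddings to provide context when embedding documents later on.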

In [8]:
minicorpus_size = model[0].config.transductive_corpus_size
print(f"Mini-corpus size: {minicorpus_size}")

dataset_embeddings = model.encode(
    data["data"][:minicorpus_size],
    prompt_name="document",
    convert_to_tensor=True
)

Mini-corpus size: 512
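
The cell above simply takes the first `minicorpus_size` posts. The `cde-small-v1` model card instead suggests using a sample that is representative of the corpus (for example a random sample), with exactly `transductive_corpus_size` documents, oversampling if the corpus is smaller. A minimal sketch of that idea; `sample_minicorpus` is a hypothetical helper, not part of `sentence-transformers`:

```python
import random

def sample_minicorpus(corpus, size, seed=0):
    """Return exactly `size` documents drawn at random from `corpus`."""
    if not corpus:
        raise ValueError("corpus must not be empty")
    rng = random.Random(seed)  # seeded so the context is reproducible
    if len(corpus) >= size:
        # large enough: sample without replacement
        return rng.sample(corpus, k=size)
    # too small: oversample with replacement to reach the required size
    return rng.choices(corpus, k=size)

docs = [f"doc {i}" for i in range(10)]
print(sample_minicorpus(docs, 4))
```

In this notebook it could replace the slice, e.g. `model.encode(sample_minicorpus(data["data"], minicorpus_size), prompt_name="document", convert_to_tensor=True)`.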


Now we can use these embeddings to provide context when embedding a document, simply by passing them to `model.encode` via the `dataset_embeddings` argument.

In [9]:
single_contextual_embedding = model.encode(
    data["data"][0],
    prompt_name="document",
    dataset_embeddings=dataset_embeddings,
    convert_to_tensor=True
)

print(f"Embedding shape: {single_contextual_embedding.shape}")
print(f"Embedding: {single_contextual_embedding[:4].tolist()}...")


Embedding shape: torch.Size([768])
Embedding: [0.025975484400987625, 0.016233962029218674, 0.05705662816762924, -0.003157366067171097]...


This embedding is clearly different from the non-contextual one, because the model is now conditioned on the `dataset_embeddings`. These embeddings 'prime' the model with information about the dataset as a whole, so the embedding of each new document reflects how it relates to the rest of this particular corpus.
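
Comparing raw vector values says little on its own. A more meaningful check is whether contextualisation changes which documents end up close to each other. A sketch, assuming NumPy and cosine similarity; `top_k_neighbours` and `neighbour_overlap` are hypothetical helpers:

```python
import numpy as np

def top_k_neighbours(embeddings, k):
    """Indices of each row's k most cosine-similar other rows."""
    normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    sims = normed @ normed.T
    np.fill_diagonal(sims, -np.inf)  # exclude each document itself
    return np.argsort(-sims, axis=1)[:, :k]

def neighbour_overlap(a, b, k=10):
    """Mean fraction of shared top-k neighbours between two embeddings of the same documents."""
    na, nb = top_k_neighbours(a, k), top_k_neighbours(b, k)
    shared = [len(set(ra) & set(rb)) / k for ra, rb in zip(na, nb)]
    return float(np.mean(shared))
```

Calling `neighbour_overlap(non_contextual_document_embeddings.cpu().numpy(), contextual_document_embeddings.cpu().numpy())` would give one rough measure of how much the context reshapes neighbourhoods in this dataset: 1.0 means identical neighbours, lower values mean more change.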

In [10]:
contextual_document_embeddings = model.encode(
    data["data"],
    prompt_name="document",
    dataset_embeddings=dataset_embeddings,
    convert_to_tensor=True
)
print(f"Dataset embeddings shape: {contextual_document_embeddings.shape}")


Dataset embeddings shape: torch.Size([597, 768])


## Weaviate

Let's create two collections in Weaviate to query: one with the non-contextual embeddings, and one with the contextual embeddings.
First, we start a local embedded Weaviate instance.

In [11]:
import weaviate

# start an embedded Weaviate instance, running locally alongside this notebook
client = weaviate.connect_to_embedded()



Weaviate lets you [bring your own vectors](https://weaviate.io/developers/weaviate/starter-guides/custom-vectors), which is what we will do here: each object stores a post's text alongside the embedding we computed for it earlier.

In [12]:
import weaviate.classes as wvc

def create_collection_with_vectors(name, texts, vectors):
    # drop any leftover collection from a previous run so this cell can be re-run
    if client.collections.exists(name):
        client.collections.delete(name)
    # no vectorizer: we supply our own vectors
    collection = client.collections.create(
        name,
        vectorizer_config=wvc.config.Configure.Vectorizer.none()
    )
    # each object holds the post text plus its precomputed embedding vector
    objs = [
        wvc.data.DataObject(properties={"text": text}, vector=vector.tolist())
        for text, vector in zip(texts, vectors)
    ]
    result = collection.data.insert_many(objs)
    if result.has_errors:
        print(f"{name}: {len(result.errors)} objects failed to insert")
    return collection

# Collection 1: non-contextual embeddings
non_contextual_collection = create_collection_with_vectors(
    "non_contextual_embeddings", data["data"], non_contextual_document_embeddings
)

# Collection 2: contextual embeddings
contextual_collection = create_collection_with_vectors(
    "contextual_embeddings", data["data"], contextual_document_embeddings
)

Now we can query the two collections.

First, we embed our query, using the `query` prompt rather than the `document` prompt.

In [13]:
query = "How much is Rickey Henderson being paid?"
non_contextual_query_embedding = model.encode(
    query,
    prompt_name="query",
    convert_to_tensor=True
)
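
The `prompt_name` argument selects a prompt string from the model's configuration, which `sentence-transformers` prepends to each input before encoding. A sketch of the effect, assuming the `cde-small-v1` prompts are `"search_query: "` and `"search_document: "` (run `print(model.prompts)` to confirm); `PROMPTS` and `apply_prompt` are hypothetical names:

```python
# Assumed prompt strings; inspect `model.prompts` for the real values.
PROMPTS = {"query": "search_query: ", "document": "search_document: "}

def apply_prompt(text, prompt_name, prompts=PROMPTS):
    """Prepend the named prompt to `text`, mimicking encode(..., prompt_name=...)."""
    if prompt_name not in prompts:
        raise KeyError(f"unknown prompt_name: {prompt_name!r}")
    return prompts[prompt_name] + text

print(apply_prompt("How much is Rickey Henderson being paid?", "query"))
# prints: search_query: How much is Rickey Henderson being paid?
```

Because the model expects these asymmetric prefixes, queries and documents should each be encoded with their own `prompt_name`.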

Then we retrieve the three closest objects with a `near_vector` query in Weaviate.

In [14]:
non_contextual_results = non_contextual_collection.query.near_vector(
    near_vector = non_contextual_query_embedding.tolist(),
    limit=3
)
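
Conceptually, `near_vector` returns the objects whose stored vectors are closest to the query vector. Weaviate does this with an approximate HNSW index and, unless configured otherwise, cosine distance (1 minus cosine similarity). The brute-force NumPy sketch below computes the same ranking exactly, which can be handy for sanity-checking a small collection like this one; `brute_force_near_vector` is a hypothetical helper:

```python
import numpy as np

def brute_force_near_vector(query, vectors, limit=3):
    """Exact top-`limit` search by cosine distance; returns (row index, distance) pairs."""
    q = query / np.linalg.norm(query)
    v = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)
    distances = 1.0 - v @ q  # cosine distance, Weaviate's default metric
    order = np.argsort(distances)[:limit]
    return [(int(i), float(distances[i])) for i in order]

vecs = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]])
print(brute_force_near_vector(np.array([1.0, 0.1]), vecs, limit=2))
```

Applied to `non_contextual_query_embedding.cpu().numpy()` and `non_contextual_document_embeddings.cpu().numpy()`, it should usually return the same top results as Weaviate, since HNSW is approximate but accurate on a collection this small.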

In [15]:
print("Top 3 results (non-contextual):\n")
for i, obj in enumerate(non_contextual_results.objects):
    print(f"Result {i+1}")
    print("_____")
    print(f"Text: {obj.properties['text']}")
    print("_____")
    print("\n\n")

Top 3 results (non-contextual):

Result 1
_____
Text: Davis will be paid by three clubs this year, I think the Phils are
responsbible for about $600,000 or so.  They didn't wait for him to clear
waivers as three other clubs were also very interested in him.  A gamble?
Yes.

Won the CY Young, too, for that year.
_____



Result 2
_____
Text: : I believe that Rusty Staub was also a jewish ball-player
: Also, Mordaci Brown back in the early 20th century.  He was a pitcher whose
: nickname was "3 fingers" Brown....for obvious reasons....he had 3 fingers.

0 for 2, ma_ind25.

Daniel Patrick Staub is a Catholic school kid from Nawlins, Mordecai
Brown a farm kid (probably Protestant) from somewhere in the Midwest.
He lost those fingers in a farm machinery accident.

Jim Palmer isn't Jewish himself, but Mr. Jockey Shorts's adoptive 
parents are.

Also, I'm not absolutely certain that Carew actually converted.  His
wife and children certainly are Jewish.

--
_____



Result 3
_____
Text: 
Wasn'

Now we do the same for the contextual embeddings. First, embed the query with the same `dataset_embeddings`, exactly as we did for the documents.

In [16]:
contextual_query_embedding = model.encode(
    query,
    prompt_name="query",
    dataset_embeddings=dataset_embeddings,
    convert_to_tensor=True
)

Then we run the same `near_vector` search against the contextual collection.

In [17]:
contextual_results = contextual_collection.query.near_vector(
    near_vector = contextual_query_embedding.tolist(),
    limit=3
)

In [18]:
print("Top 3 results (contextual):\n")
for i, obj in enumerate(contextual_results.objects):
    print(f"Result {i+1}")
    print("_____")
    print(f"Text: {obj.properties['text']}")
    print("_____")
    print("\n\n")

Top 3 results (contextual):

Result 1
_____
Text: 
Actually, I could care less what his salary is.  It has something to do
with the fact that we live in America, and everyone is entitled to
whatever he can legally obtain.  If Sandy Alderson and the Haas family
willingly negotiate a salary of $35 million per year with Rickey, I couldn't
care less.

But what REALLY GETS MY GOAT is the bullshit he spouted in spring training,
about `Well... sometimes I may not play as hard, or might be hurt more
often, in a place where I'm not appreciated'.  This quote was in the Chronicle
about the second week of camp, and strongly suggests that he was going to 
dog it all year if the ownership didn't kiss his butt and ante up some
more money.  For God's sake, Rickey, you signed a contract 4 years ago,
now honor it and play!  

Say all you want to about Steve Garvey, and believe
me, I hated him too, but at least when he put his signature on a piece
of paper he shut his mouth and played hard until the cont