In [1]:
import pandas as pd
import numpy as np
import altair as alt
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.cluster import KMeans

In [None]:
!pip install cohere
import cohere
co = cohere.ClientV2('Secret key')



**Step 1: Prepare the Dataset**

We'll work with a subset of the Airline Travel Information System (ATIS) intent classification dataset [Source]. The following code loads the dataset into a pandas DataFrame, `df_orig`, with two columns, "intent" and "query". The next cells then take a small sample, `df`, of 90 inquiries sent to airline travel inquiry systems, and keep only its "query" column.

In [10]:
# Load the full ATIS intent training split into a dataframe.
# Column labels are assigned explicitly via `names=`.
ATIS_TRAIN_URL = 'https://raw.githubusercontent.com/cohere-ai/notebooks/main/notebooks/data/atis_intents_train.csv'
df_orig = pd.read_csv(ATIS_TRAIN_URL, names=['intent', 'query'])
df_orig


Unnamed: 0,intent,query
0,atis_flight,i want to fly from boston at 838 am and arriv...
1,atis_flight,what flights are available from pittsburgh to...
2,atis_flight_time,what is the arrival time in san francisco for...
3,atis_airfare,cheapest airfare from tacoma to orlando
4,atis_airfare,round trip fares from pittsburgh to philadelp...
...,...,...
4829,atis_airfare,what is the airfare for flights from denver t...
4830,atis_flight,do you have any flights from denver to baltim...
4831,atis_airline,which airlines fly into and out of denver
4832,atis_flight,does continental fly from boston to san franc...


In [11]:
# Take a small, reproducible sample for illustration purposes: 10% of all rows
# (fixed random_state), restricted to three of the intents.
sample_classes = ['atis_airfare', 'atis_airline', 'atis_ground_service']
df = (
    df_orig
    .sample(frac=0.1, random_state=30)
    .loc[lambda frame: frame['intent'].isin(sample_classes)]
)
# Hold the sampled rows out of df_orig. This relies on df still carrying the
# original index labels, so it must happen before the index is reset below.
# NOTE(review): df_orig is overwritten here, so re-running this cell samples
# from an already-reduced frame; re-run the loading cell first to reproduce.
df_orig = df_orig.drop(df.index)
df = df.reset_index(drop=True)
df

Unnamed: 0,intent,query
0,atis_airline,which airlines fly from boston to washington ...
1,atis_airline,show me the airlines that fly between toronto...
2,atis_airfare,show me round trip first class tickets from n...
3,atis_airfare,i'd like the lowest fare from denver to pitts...
4,atis_ground_service,show me a list of ground transportation at bo...
...,...,...
86,atis_ground_service,what ground transportation is there in atlanta
87,atis_airline,can i take a single airline from la to charlo...
88,atis_airfare,what is the cost for a one way trip from pitt...
89,atis_ground_service,what ground transportation is available in ba...


In [12]:

# Set the intent labels aside for later use, then remove them from the working
# frame so only the query text remains. The guard keeps the cell re-runnable:
# without it, a second run raises KeyError because 'intent' is already gone.
if 'intent' in df.columns:
    intents = df['intent']  # save for a later need
    df = df.drop(columns=['intent'])
df

Unnamed: 0,query
0,which airlines fly from boston to washington ...
1,show me the airlines that fly between toronto...
2,show me round trip first class tickets from n...
3,i'd like the lowest fare from denver to pitts...
4,show me a list of ground transportation at bo...
...,...
86,what ground transportation is there in atlanta
87,can i take a single airline from la to charlo...
88,what is the cost for a one way trip from pitt...
89,what ground transportation is available in ba...


**Step 2: Turn Text into Embeddings**

Next, we embed each inquiry by calling Cohere’s Embed endpoint with `co.embed()`. It takes texts as input and returns embeddings as output. We supply four parameters:

- **texts:** The list of texts you want to embed
- **model:** The model used to generate the embeddings
- **input_type:** The type of text being embedded. At the time of writing, there are four options:
  - **search_document:** For documents against which search is performed
  - **search_query:** For search queries
  - **classification:** For when the embeddings will be used as input to a text classifier
  - **clustering:** For when you want to cluster the embeddings
- **embedding_types:** The numeric format of the returned embeddings; we request `["float"]`


For every piece of text passed to the Embed endpoint, a fixed-length sequence of numbers is returned (1536 by default for `embed-v4.0`; the length can be changed with the `output_dimension` parameter). Together, these numbers capture the meaning of that piece of text.

In [14]:
def get_embeddings(texts, model="embed-v4.0", input_type="search_document",
                   batch_size=96, client=None):
    """Embed ``texts`` with Cohere's Embed endpoint and return float vectors.

    Parameters
    ----------
    texts : iterable of str
        Texts to embed. They are sent in order, in chunks of ``batch_size``.
    model : str
        Cohere embedding model name.
    input_type : str
        One of Cohere's input types ("search_document", "search_query",
        "classification", "clustering").
    batch_size : int
        Maximum number of texts per request. The Embed endpoint accepts at
        most 96 texts per call, so larger inputs are split into several calls.
    client : optional
        Object exposing ``embed(...)`` like ``cohere.ClientV2``. Defaults to
        the notebook's global ``co`` (pass a stub here for testing).

    Returns
    -------
    list of list of float
        One embedding per input text, in input order. Empty input returns
        ``[]`` without calling the API.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    if client is None:
        client = co
    texts = list(texts)  # allow any iterable and support slicing
    embeddings = []
    for start in range(0, len(texts), batch_size):
        response = client.embed(
            texts=texts[start:start + batch_size],
            model=model,
            input_type=input_type,
            embedding_types=["float"],
        )
        # Batches are processed in order, so extending preserves alignment
        # between inputs and embeddings.
        embeddings.extend(response.embeddings.float)
    return embeddings

# Embed every query and keep each vector next to its text.
df['query_embeds'] = get_embeddings(list(df['query']))
df

Unnamed: 0,query,query_embeds
0,which airlines fly from boston to washington ...,"[0.05444336, -0.021362305, -0.002029419, -0.03..."
1,show me the airlines that fly between toronto...,"[0.022460938, 0.010925293, -0.015136719, -0.01..."
2,show me round trip first class tickets from n...,"[-0.053710938, 0.029418945, -0.0051574707, 0.0..."
3,i'd like the lowest fare from denver to pitts...,"[0.048339844, 0.017211914, -0.020507812, -0.01..."
4,show me a list of ground transportation at bo...,"[0.046875, -0.0038146973, 0.008178711, -0.0532..."
...,...,...
86,what ground transportation is there in atlanta,"[0.026855469, -0.0015258789, -0.020629883, -0...."
87,can i take a single airline from la to charlo...,"[0.044433594, 0.012084961, 0.055664062, 0.0227..."
88,what is the cost for a one way trip from pitt...,"[0.026489258, -0.0014953613, -0.0095825195, -0..."
89,what ground transportation is available in ba...,"[0.055664062, -0.0024261475, -0.005645752, -0...."


**Step 3: Visualize Embeddings with a Heatmap**
Let’s build some visual intuition by plotting these numbers as a heatmap. Each embedding has far too many dimensions to display at once, so we first compress them to a much smaller number, say 10.

The `get_pc()` function below does this via Principal Component Analysis (PCA), which projects the embeddings onto the few directions that retain as much of their variance as possible. We set `embeds_pc` to the ten-dimensional version of the query embeddings.


In [16]:
# Function to return the principal components
def get_pc(arr, n):
    pca = PCA(n_components=n)
    embeds_transform = pca.fit_transform(arr)
    return embeds_transform

# Reduce embeddings to a handful of principal components to aid visualization.
N_COMPONENTS = 10  # dimensions kept for the heatmap

embeds = np.array(df['query_embeds'].tolist())
embeds_pc = get_pc(embeds, N_COMPONENTS)

# Sanity check: one row per query, N_COMPONENTS columns.
assert embeds_pc.shape == (len(df), N_COMPONENTS)
# Preview a few rows rather than dumping the whole array into the output.
embeds_pc[:5]

array([[-8.08249255e-02,  4.32685894e-01,  2.21944750e-01,
        -1.34397962e-01,  2.76273469e-02,  8.79350006e-02,
         4.70661998e-02,  2.54822227e-02, -6.96558623e-02,
        -5.33133991e-02],
       [-8.39247939e-02,  5.20698125e-01, -1.87679448e-01,
         2.02222098e-01, -1.86410615e-01,  8.28716690e-03,
         3.16765676e-02,  1.61495237e-01, -1.55668336e-02,
        -2.18885656e-01],
       [-2.32537001e-01, -3.41024509e-02, -4.69066138e-01,
        -2.70110925e-01,  3.36086051e-02,  1.05755733e-01,
        -1.58089652e-01,  1.34079861e-01,  1.19964983e-01,
        -5.87898732e-03],
       [-2.90012298e-01, -9.89706201e-02, -7.69238411e-02,
         3.84677699e-01, -2.17212115e-01,  6.17422263e-02,
        -9.27147294e-02,  1.95261975e-01, -8.98451603e-02,
         6.02084285e-02],
       [ 5.08917874e-01,  1.31731103e-01,  1.39267605e-01,
        -1.40794018e-01, -2.71150062e-01,  1.27401576e-01,
        -7.26644700e-02,  3.36761049e-02,  9.43410531e-02,
        -5.