## Sentence Transformer Models with Snowpark

Using a different embedding model in Snowflake is a common enough use case, so here is a notebook to get started.

I used sentence-transformers version 2.2.2; version 2.7.0 did not work when I ran this in May 2024.
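Since the version matters here, it may help to check what is installed before running the rest of the notebook. The following is a minimal sketch using only the standard library; the helper names `installed_version` and `matches_known_good` are made up for illustration, and the exact-match rule is an assumption (other versions may also work, only 2.2.2 was tested).

```python
from importlib.metadata import PackageNotFoundError, version

# Version this notebook was run with (taken from the note above).
KNOWN_GOOD = "2.2.2"


def installed_version(package="sentence-transformers"):
    """Return the installed version string, or None if the package is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def matches_known_good(found, expected=KNOWN_GOOD):
    """True only when the installed version is exactly the tested one."""
    return found == expected


found = installed_version()
if not matches_known_good(found):
    # Only a warning: untested versions are not necessarily broken.
    print(f"sentence-transformers {found!r} differs from the tested {KNOWN_GOOD}")
```

If the check prints a warning, pinning the package with `pip install sentence-transformers==2.2.2` is one way to reproduce the environment used here.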

This is based on an example from Michael Gorkow: [Sentence Transformers using the Model Registry](https://medium.com/@michaelgorkow/custom-embedding-models-from-hugging-face-in-snowflake-fd9cc79e25c8).

If you need to make a custom embedding model because of the versioning issues, you can use the [following Arctic embedding example](https://github.com/michaelgorkow/snowflake_huggingface/blob/main/arctic_embeddings_snowflake.ipynb).


In [15]:
from snowflake.snowpark.session import Session
# Create a Snowpark session with a default connection.
session = Session.builder.create()

In [16]:
snowflake_environment = session.sql('select current_user(), current_version()').collect()
from snowflake.snowpark.version import VERSION
snowpark_version = VERSION

# Current Environment Details
print('User                        : {}'.format(snowflake_environment[0][0]))
print('Role                        : {}'.format(session.get_current_role()))
print('Database                    : {}'.format(session.get_current_database()))
print('Schema                      : {}'.format(session.get_current_schema()))
print('Warehouse                   : {}'.format(session.get_current_warehouse()))
print('Snowflake version           : {}'.format(snowflake_environment[0][1]))
print('Snowpark for Python version : {}.{}.{}'.format(snowpark_version[0],snowpark_version[1],snowpark_version[2]))

User                        : RSHAH
Role                        : "RAJIV"
Database                    : "RAJIV"
Schema                      : "PUBLIC"
Warehouse                   : "RAJIV"
Snowflake version           : 8.20.0
Snowpark for Python version : 1.15.0a1


In [17]:
from sentence_transformers import SentenceTransformer 
from snowflake.ml.model.model_signature import FeatureSpec, DataType, ModelSignature
# Get the model registry object
from snowflake.ml.registry import Registry

reg = Registry(
    session=session, 
    database_name=session.get_current_database(), 
    schema_name=session.get_current_schema()
    )

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2" )

In [18]:
# In this example the output column will be called EMBEDDING
model_sig = ModelSignature(
    inputs=[
        # Change the name to match your text column
        FeatureSpec(dtype=DataType.STRING, name='TEXT')
    ],
    outputs=[
        FeatureSpec(dtype=DataType.DOUBLE, name='EMBEDDING', shape=(384,))
    ]
)

In [20]:
# Register the model to Snowflake
snow_model_custom = reg.log_model(
    model, 
    version_name='V_8',
    model_name='all_MiniLM_L6_v2', 
    signatures={'encode':model_sig}
    )



Text we will get embeddings for:

In [None]:
df = session.table('IMDB_SAMPLE')
df.show()

----------------------------------------------------------------
|"TEXT"                                              |"LABEL"  |
----------------------------------------------------------------
|Monarch Cove was one of the best Friday night's...  |1        |
|Not only did they get the characters all wrong,...  |0        |
|By no means my favourite Austen novel, and Palt...  |1        |
|I first saw this film when I was flipping throu...  |1        |
|I watched this film last night with anticipatio...  |0        |
|As I have matured, my fascination with the Acad...  |1        |
|Ned Kelly (Ledger), the infamous Australian out...  |0        |
|With the death of GEORGE NADER, on 4 February 2...  |1        |
|A lot of talk has been made about "psychologica...  |1        |
|When his in-laws are viciously murdered by a ga...  |0        |
----------------------------------------------------------------



Now with embeddings

In [28]:
from snowflake.snowpark.functions import col
from snowflake.snowpark import types as T
from snowflake.snowpark import functions as F


# Create Embeddings from Huggingface Model
embedding_df = snow_model_custom.run(df)
# We have to convert the output of the Huggingface model to Snowflake's Vector Datatype
embedding_df = embedding_df.with_column('EMBEDDING', F.col('EMBEDDING').cast(T.VectorType(float,384)))
embedding_df.show()

---------------------------------------------------------------------------------------------------------------------
|"TEXT"                                              |"LABEL"  |"EMBEDDING"                                         |
---------------------------------------------------------------------------------------------------------------------
|Monarch Cove was one of the best Friday night's...  |1        |[-0.1078076884150505, -0.07621019333600998, 0.0...  |
|Not only did they get the characters all wrong,...  |0        |[-0.0718730241060257, -0.0075439647771418095, 0...  |
|By no means my favourite Austen novel, and Palt...  |1        |[-0.06701567769050598, -0.10788971185684204, -0...  |
|I first saw this film when I was flipping throu...  |1        |[-0.08029511570930481, 0.023791490122675896, -0...  |
|I watched this film last night with anticipatio...  |0        |[-0.02025693468749523, -0.029329409822821617, 0...  |
|As I have matured, my fascination with the Acad...  |1 

Take advantage of the Cortex functions in Snowflake, such as the vector distance functions:
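For intuition, here is a rough plain-Python sketch of what an L2 (Euclidean) distance computes and how sorting by it ranks texts. The toy 3-dimensional vectors and the names `l2_distance`, `search`, and `rows` are invented for illustration; the real embeddings have 384 dimensions and the actual computation happens inside Snowflake.

```python
import math


def l2_distance(a, b):
    """Euclidean distance between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same dimension")
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


# Toy stand-ins for the search vector and two stored embeddings.
search = [1.0, 0.0, 0.0]
rows = {
    "close": [0.8, 0.6, 0.0],
    "far": [0.0, 0.0, 1.0],
}

# Smaller distance means more similar, so sort ascending,
# mirroring an order_by on the distance column.
ranked = sorted(rows, key=lambda name: l2_distance(rows[name], search))
print(ranked)
```

The smaller the distance, the closer the text is to the search text, which is why the cell below orders by `VECTOR_DISTANCE` ascending.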

In [29]:
# Finally we can calculate the distance between all the embeddings 
# and our search vector

search_text = "The movie HER is a great movie about AI and love"
closest_texts = embedding_df.with_column(
    'VECTOR_DISTANCE', 
    F.vector_l2_distance(
        F.col('EMBEDDING'), 
        F.call_builtin('all_MiniLM_L6_v2!ENCODE', 
            F.lit(search_text))['EMBEDDING'].cast(T.VectorType(float,384))
    )
).cache_result()

# Show the texts closest to the search text, nearest first
closest_texts.order_by('VECTOR_DISTANCE').drop('EMBEDDING').show(max_width=100)

---------------------------------------------------------------------------------------------------------------------------------------
|"TEXT"                                                                                                |"LABEL"  |"VECTOR_DISTANCE"   |
---------------------------------------------------------------------------------------------------------------------------------------
|By no means my favourite Austen novel, and Paltrow is by no means my favourite actress, but I fou...  |1        |1.156562608266896   |
|Ned Kelly (Ledger), the infamous Australian outlaw and legend. Sort of like Robin Hood, with a mi...  |0        |1.200213786790789   |
|Monarch Cove was one of the best Friday night's drama shown in a long time.I am asking the writer...  |1        |1.2160948661672624  |
|I first saw this film when I was flipping through the movie channels on my parents DirecTV. It wa...  |1        |1.2551056541560486  |
|As I have matured, my fascination with the Acad

Test data, if you need some:

In [None]:
# Create some test data to work with
ai_texts_german = [
    "KI revolutioniert die Geschäftsanalytik, indem sie tiefere Einblicke in Daten bietet.",
    "Unternehmen nutzen KI, um die Analyse und Interpretation komplexer Datensätze zu transformieren.",
    "Mit KI können Unternehmen nun große Datenmengen verstehen, um die Entscheidungsfindung zu verbessern.",
    "Künstliche Intelligenz ist ein Schlüsselwerkzeug für Unternehmen, die ihre Datenanalyse verbessern möchten.",
    "Der Einsatz von KI in Unternehmen hilft dabei, bedeutungsvolle Informationen aus großen Datensätzen zu extrahieren."
]

different_texts_german = [
    "Der große Weiße Hai ist einer der mächtigsten Raubtiere des Ozeans.",
    "Van Goghs Sternennacht stellt die Aussicht aus seinem Zimmer in der Anstalt bei Nacht dar.",
    "Quantencomputing könnte potenziell viele der derzeit verwendeten kryptografischen Systeme brechen.",
    "Die brasilianische Küche ist bekannt für ihre Vielfalt und Reichhaltigkeit, beeinflusst von Europa, Afrika und den amerindischen Kulturen.",
    "Das schnellste Landtier, der Gepard, erreicht Geschwindigkeiten von bis zu 120 km/h."
]
 
search_text = "Maschinelles Lernen ist eine unverzichtbare Ressource für Unternehmen, die ihre Dateneinblicke verbessern möchten."

df = session.create_dataframe(ai_texts_german+different_texts_german, schema=['TEXT'])