# embeddings

> Extract text embedding features.

In [None]:
#| default_exp embeddings

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from sklearn.base import BaseEstimator, TransformerMixin
from textplumber.store import TextFeatureStore
from model2vec import StaticModel
import numpy as np

In [None]:
#| export
class Model2VecEmbedder(BaseEstimator, TransformerMixin):
	""" Scikit-learn pipeline component to generate embeddings using Model2Vec.

	Embeddings are looked up in the supplied feature store first and are only computed
	(and then written back to the store) when at least one of the texts has no stored embedding. """

	# number of texts passed to the model per `encode` call in `transform`
	_BATCH_SIZE = 1000

	def __init__(self,
				feature_store: TextFeatureStore, # the feature store to use - this should be the same feature store used in the SpacyPreprocessor component
				model_name: str = 'minishlab/potion-base-8M' # the model name to use
				):
		self.feature_store = feature_store
		self.model_name = model_name
		# The model is loaded eagerly (rather than in `fit`) so that `transform` and
		# `get_feature_names_out` work on an instance that has not been fitted.
		# NOTE(review): a later `set_params(model_name=...)` does not reload `model_`.
		self.model_ = StaticModel.from_pretrained(self.model_name)

	def fit(self, X, y=None):
		""" Fit is implemented but does nothing. """
		return self

	def transform(self, X):
		""" Generate embeddings for the texts using the Model2Vec model.
		If the embeddings are already in the feature store, they are used instead of recomputing them.
		If any text is missing from the store, all texts are encoded again, in batches of 1000 texts
		to avoid memory issues, and the store is updated with the result. """
		embeddings = self.feature_store.get_embeddings_from_texts(X)
		if not any(x is None for x in embeddings):
			# every text already has a stored embedding, so the store's result is returned untouched
			# NOTE(review): the type returned here is whatever the feature store returns - confirm it matches the ndarray returned below
			return embeddings

		batches = []
		for start in range(0, len(X), self._BATCH_SIZE):
			batch = self.model_.encode(X[start:start + self._BATCH_SIZE])
			# returning as floats seemed to be causing issues with kmeans pipeline component
			batches.append(np.array(batch, dtype=np.double))
		embeddings = np.concatenate(batches, axis=0)
		self.feature_store.update_embeddings(X, embeddings)
		return embeddings

	def get_feature_names_out(self,
				input_features = None # accepted for compatibility with the scikit-learn API (Pipeline and FeatureUnion pass it) and ignored
				):
		""" Get the feature names out from the model - one name (`emb_0`, `emb_1`, ...) per embedding dimension. """
		return [f'emb_{i}' for i in range(self.model_.dim)]

TODO: add a usage example showing `Model2VecEmbedder` in a scikit-learn pipeline.

In [None]:
#| hide
import os

# Smoke test: embed a single text with the default model and check the output shape.
store_path = 'test_embeddings.sqlite'
feature_store = TextFeatureStore(store_path)
model2vec_embedder = None
try:
	model2vec_embedder = Model2VecEmbedder(feature_store=feature_store)
	model2vec_embedder.fit(['Hello, world!'])
	X = model2vec_embedder.fit_transform(['Hello, world!'])
	assert X.shape == (1, 256)
finally:
	# Clean up even when the test fails, so a stale store file cannot leak into a later run.
	# The embedder also holds a reference to the store, so both names are dropped before
	# the file is removed (presumably this is what lets the store release the file - confirm).
	del model2vec_embedder, feature_store
	if os.path.exists(store_path):
		os.remove(store_path)

In [12]:
#| hide
# Export the cells marked `#| export` in this notebook to the Python package.
import nbdev
nbdev.nbdev_export()