# Transfer learning using Sentence Transformers and Scikit-Learn

In [None]:
!pip install pinnacledb
!pip install sentence-transformers
!pip install datasets scikit-learn

In this example, we'll demonstrate how to implement transfer learning with SuperDuperDB.
You'll find related examples on vector search, and simple training examples using scikit-learn, in the
`notebooks/` directory of the project. Transfer learning uses similar components and can be combined with vector search: vectors are, after all, also featurizations of
data, and can serve as inputs to downstream learning tasks.
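The idea can be made concrete with a toy, dependency-free sketch. A frozen featurizer turns text into vectors and is never updated. Only a small downstream classifier is trained on those vectors. Here that classifier is a nearest-centroid model, chosen for brevity instead of the SVC used later. `frozen_featurize`, its six-word vocabulary, and the example reviews are all made up for illustration. In this notebook, the featurizer role is played by the `all-MiniLM-L6-v2` sentence transformer.

```python
from collections import Counter
import math

# Hypothetical stand-in for a pretrained encoder: it maps text to a fixed-length,
# L2-normalised bag-of-words vector and is never updated during training.
VOCAB = ["good", "great", "love", "bad", "awful", "boring"]

def frozen_featurize(text: str) -> list[float]:
    counts = Counter(text.lower().split())
    vec = [float(counts[w]) for w in VOCAB]
    norm = math.sqrt(sum(v * v for v in vec)) or 1.0
    return [v / norm for v in vec]

def fit_centroids(vectors, labels):
    """Downstream model: one mean vector (centroid) per class."""
    sums, counts = {}, Counter(labels)
    for vec, label in zip(vectors, labels):
        acc = sums.setdefault(label, [0.0] * len(vec))
        for i, v in enumerate(vec):
            acc[i] += v
    return {label: [v / counts[label] for v in acc] for label, acc in sums.items()}

def predict(centroids, vec):
    """Assign the class whose centroid is closest in squared Euclidean distance."""
    return min(centroids, key=lambda c: sum((a - b) ** 2 for a, b in zip(centroids[c], vec)))

# Made-up training data: 1 = positive, 0 = negative
train_texts = ["great film, love it", "good and great", "awful and boring", "bad, really bad"]
train_labels = [1, 1, 0, 0]

# Only the centroids are "trained"; the featurizer stays frozen
centroids = fit_centroids([frozen_featurize(t) for t in train_texts], train_labels)
print(predict(centroids, frozen_featurize("I love this great movie")))  # 1
print(predict(centroids, frozen_featurize("boring and bad")))           # 0
```

Swapping the toy featurizer for a real sentence embedding model, and the centroids for an SVC, gives the structure of the pipeline built below.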

Let's first connect to MongoDB via SuperDuperDB. You can read explanations of how to do this in
the docs and in the `notebooks/` directory.

In [None]:
from pinnacledb import pinnacle
from pinnacledb.backends.mongodb import Collection
import os

# Uncomment one of the following lines to use a bespoke MongoDB deployment.
# If MONGODB_URI is not set, we default to an in-memory mongomock database for testing.

mongodb_uri = os.getenv("MONGODB_URI", "mongomock://test")
# mongodb_uri = "mongodb://localhost:27017"
# mongodb_uri = "mongodb://pinnacle:pinnacle@mongodb:27017/documents"
# mongodb_uri = "mongodb://<user>:<pass>@<mongo_cluster>/<database>"
# mongodb_uri = "mongodb+srv://<username>:<password>@<atlas_cluster>/<database>"

# Super-Duper your Database!
db = pinnacle(mongodb_uri)

collection = Collection('transfer')

We'll use textual data labelled with sentiment to test the functionality. Transfer learning
can be applied to any data that SuperDuperDB models can process.

In [None]:
import numpy
from datasets import load_dataset

from pinnacledb import Document as D

data = load_dataset("imdb")

N_DATAPOINTS = 500    # increase in order to improve quality

# Sample random indices first, then build documents only for the sampled rows
train_data = [
    D({'_fold': 'train', **data['train'][int(i)]})
    for i in numpy.random.permutation(len(data['train']))[:N_DATAPOINTS]
]

# A smaller validation sample drawn from the IMDB test split
valid_data = [
    D({'_fold': 'valid', **data['test'][int(i)]})
    for i in numpy.random.permutation(len(data['test']))[:N_DATAPOINTS // 10]
]

db.execute(collection.insert_many(train_data))

r = db.execute(collection.find_one())
r
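The sampling step can also be written with plain dictionaries and the standard library. This makes the document shape explicit: each record is an original dataset row plus a `_fold` tag. `sample_fold` is a hypothetical helper, and the rows are made up to mimic the IMDB fields `text` and `label`.

```python
import random

def sample_fold(rows, n, fold, seed=None):
    """Hypothetical helper: pick n rows at random (without replacement) and tag them with `_fold`."""
    rng = random.Random(seed)
    indices = rng.sample(range(len(rows)), k=min(n, len(rows)))
    return [{"_fold": fold, **rows[i]} for i in indices]

# Made-up rows with the same fields as the IMDB dataset ("text" and "label")
train_rows = [{"text": f"train review {i}", "label": i % 2} for i in range(1000)]
test_rows = [{"text": f"test review {i}", "label": i % 2} for i in range(1000)]

N_DATAPOINTS = 50
train = sample_fold(train_rows, N_DATAPOINTS, "train", seed=0)
valid = sample_fold(test_rows, N_DATAPOINTS // 10, "valid", seed=1)
print(len(train), len(valid))  # 50 5
print(sorted(train[0]))        # ['_fold', 'label', 'text']
```

Drawing the indices first means only the sampled rows are copied into documents. That matters here because the IMDB `train` and `test` splits each hold 25,000 reviews.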

Let's create a SuperDuperDB model based on a `sentence_transformers` model.
You'll notice that we don't need a native SuperDuperDB integration with a model library
in order to use that library with SuperDuperDB. In this case, we just need
to configure the `Model` wrapper so that it interoperates correctly with the `SentenceTransformer` class. After doing this, we can link the model to a collection and daemonize it using the `listen=True` keyword argument:

In [None]:
from pinnacledb import Model
import sentence_transformers
from pinnacledb.ext.numpy import array

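m = Model(
    identifier='all-MiniLM-L6-v2',
    object=sentence_transformers.SentenceTransformer('all-MiniLM-L6-v2'),
    # all-MiniLM-L6-v2 produces 384-dimensional embeddings
    encoder=array('float32', shape=(384,)),
    # call SentenceTransformer.encode to compute the embeddings
    predict_method='encode',
    # `encode` accepts a list of inputs, so predictions can be made in batches
    batch_predict=True,
)

# Embed the "text" field of every document in the collection,
# and keep listening so that newly inserted documents are processed too
m.predict(
    X='text',
    db=db,
    select=collection.find(),
    listen=True
)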
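As we understand it, `listen=True` makes the model process the documents currently selected and then keep processing new documents as they are inserted. Each output is written back into its document under `_outputs.<field>.<model identifier>`; the last cell of this notebook reads `r['_outputs']['text']['svc']`. The toy, in-memory sketch below mimics that behaviour. `ToyCollection` and its methods are hypothetical and say nothing about how SuperDuperDB implements listeners internally.

```python
from typing import Callable

class ToyCollection:
    """Hypothetical, simplified model of a collection with listening models (not SuperDuperDB's implementation)."""

    def __init__(self):
        self.docs: list[dict] = []
        self.listeners: list[tuple[str, str, Callable]] = []

    def _apply(self, doc, key, identifier, fn):
        # Store the output next to the input, keyed by input field and model identifier
        doc.setdefault("_outputs", {}).setdefault(key, {})[identifier] = fn(doc[key])

    def listen(self, key, identifier, fn):
        # Process the documents that already exist...
        for doc in self.docs:
            self._apply(doc, key, identifier, fn)
        # ...and remember the model so that future inserts are processed as well
        self.listeners.append((key, identifier, fn))

    def insert_many(self, docs):
        for doc in docs:
            for key, identifier, fn in self.listeners:
                self._apply(doc, key, identifier, fn)
            self.docs.append(doc)

coll = ToyCollection()
coll.insert_many([{"text": "a great film"}])
# A trivial "model": the text length as a one-dimensional feature
coll.listen("text", "length-model", lambda t: [float(len(t))])
coll.insert_many([{"text": "bad"}])
print([d["_outputs"]["text"]["length-model"] for d in coll.docs])  # [[12.0], [3.0]]
```

The first document was inserted before the listener existed and is processed when `listen` is called. The second is processed at insert time.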

Now that we've created and added the model that computes features for the `"text"` field, we can train a
downstream model using scikit-learn:

In [None]:
from sklearn.svm import SVC

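# Wrap a scikit-learn support vector classifier as a SuperDuperDB model;
# `postprocess` converts each NumPy prediction to a plain Python int
model = pinnacle(
    SVC(gamma='scale', class_weight='balanced', C=100, verbose=True),
    postprocess=lambda x: int(x)
)

# Train on the stored sentence embeddings (not the raw text) to predict the sentiment label
model.fit(
    X='text',
    y='label',
    db=db,
    select=collection.find().featurize({'text': 'all-MiniLM-L6-v2'}),
)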
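Our reading of `.featurize({'text': 'all-MiniLM-L6-v2'})` is that it gives the downstream model the stored embedding from `_outputs.text.all-MiniLM-L6-v2` in place of the raw `text` field. The SVC is therefore trained on 384-dimensional vectors rather than on strings. The sketch below expresses that substitution over plain dicts. Here `featurize` is a hypothetical helper, and the two-dimensional vectors are made up to keep the output short.

```python
def featurize(doc, mapping):
    """Hypothetical helper: replace each input field with a stored model output."""
    out = dict(doc)  # shallow copy, so the stored document is left untouched
    for field, model in mapping.items():
        out[field] = doc["_outputs"][field][model]
    return out

# Made-up documents with precomputed (toy, 2-d) embeddings
docs = [
    {"text": "great", "label": 1, "_outputs": {"text": {"all-MiniLM-L6-v2": [0.9, 0.1]}}},
    {"text": "awful", "label": 0, "_outputs": {"text": {"all-MiniLM-L6-v2": [0.1, 0.8]}}},
]

featurized = [featurize(d, {"text": "all-MiniLM-L6-v2"}) for d in docs]
X = [d["text"] for d in featurized]   # feature vectors, ready for a classifier's fit(X, y)
y = [d["label"] for d in featurized]
print(X, y)  # [[0.9, 0.1], [0.1, 0.8]] [1, 0]
```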

Now that the model has been trained, we can apply it to the database, again daemonizing it
with `listen=True`.

In [None]:
model.predict(
    X='text',
    db=db,
    select=collection.find().featurize({'text': 'all-MiniLM-L6-v2'}),
    listen=True,
)
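The `postprocess=lambda x: int(x)` argument converts each prediction before it is stored. scikit-learn's `predict` returns a NumPy array, so individual predictions are NumPy scalars such as `numpy.int64`. We assume the conversion to a built-in `int` exists so that the value can be written to MongoDB, since to our knowledge PyMongo's BSON encoder rejects NumPy scalar types. The sketch below shows the conversion in isolation, using a made-up prediction array and the standard `json` encoder as a stand-in for BSON.

```python
import json
import numpy

# Made-up predictions shaped like the output of SVC.predict on three inputs
raw = numpy.array([1, 0, 1], dtype=numpy.int64)

postprocess = lambda x: int(x)  # the same postprocess passed to `pinnacle(...)` above
stored = [postprocess(p) for p in raw]

# JSON stands in for BSON here; to our knowledge both reject NumPy integer scalars
try:
    json.dumps({"svc": raw[0]})
except TypeError as err:
    print("raw prediction:", err)

print(json.dumps({"svc": stored[0]}))  # {"svc": 1}
```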

To verify that this process has worked, we can sample a random record and check that its prediction is sensible:

In [None]:
r = next(db.execute(collection.aggregate([{'$sample': {'size': 1}}])))
print(r['text'][:100])                # first 100 characters of the review
print(r['_outputs']['text']['svc'])   # predicted label: 0 = negative, 1 = positive
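A single random record is only an anecdotal check. A broader check compares the stored prediction with the `label` field across many records, for example those returned by a larger `$sample`. The hypothetical `accuracy` helper below works on plain dicts shaped like the documents in this notebook and skips records that the listener has not processed yet. The records themselves are made up.

```python
def accuracy(records, key="text", model="svc"):
    """Hypothetical helper: fraction of processed records whose prediction matches `label`."""
    scored = [r for r in records if model in r.get("_outputs", {}).get(key, {})]
    if not scored:
        raise ValueError("no records carry a prediction yet")
    hits = sum(r["_outputs"][key][model] == r["label"] for r in scored)
    return hits / len(scored)

# Made-up records in the same shape as the sampled documents
records = [
    {"label": 1, "_outputs": {"text": {"svc": 1}}},
    {"label": 0, "_outputs": {"text": {"svc": 1}}},
    {"label": 0, "_outputs": {"text": {"svc": 0}}},
    {"label": 1},  # not yet processed by the listener, so it is skipped
]
print(accuracy(records))  # 0.6666666666666666
```

The SVC was fitted on the same `train` fold documents, so accuracy measured on them will be optimistic. A fairer estimate would come from held-out documents, such as the `valid_data` sample built earlier, which this notebook creates but does not insert.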