# Transfer learning using Sentence Transformers and Scikit-Learn

In this example, we'll demonstrate how to implement transfer learning with SuperDuperDB.
You'll find related examples of vector search, and simple training examples using scikit-learn, in the
`notebooks/` directory of the project. Transfer learning uses similar components and can be combined with vector search: vectors are, after all, also featurizations of the
data, and can therefore be used as inputs to downstream learning tasks.
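To make the idea concrete, here is a minimal, library-free sketch of the pattern: a frozen featurizer maps each text to a vector, and only a small downstream classifier is fit on those vectors. It assumes a toy setting. The fixed-vocabulary `embed` function is a hypothetical stand-in for a pretrained sentence-transformer, and the nearest-centroid classifier is swapped in for the SVM used later; neither is a SuperDuperDB API.

```python
import math

# Hypothetical fixed vocabulary standing in for a pretrained model's knowledge.
VOCAB = ["great", "loved", "wonderful", "story", "terrible",
         "awful", "boring", "plot", "acting", "film"]
INDEX = {word: i for i, word in enumerate(VOCAB)}


def embed(text: str) -> list[float]:
    """Frozen featurizer: L2-normalised bag-of-words over VOCAB.

    Nothing here is learned from the downstream labels, mirroring how a
    pretrained sentence-transformer is used as-is.
    """
    vec = [0.0] * len(VOCAB)
    for token in text.lower().split():
        if token in INDEX:  # out-of-vocabulary tokens are ignored
            vec[INDEX[token]] += 1.0
    norm = math.sqrt(sum(v * v for v in vec)) or 1.0  # avoid division by zero
    return [v / norm for v in vec]


def fit_centroids(features, labels):
    """Downstream model: the mean feature vector of each class."""
    sums, counts = {}, {}
    for x, label in zip(features, labels):
        acc = sums.setdefault(label, [0.0] * len(x))
        for i, v in enumerate(x):
            acc[i] += v
        counts[label] = counts.get(label, 0) + 1
    return {label: [v / counts[label] for v in acc] for label, acc in sums.items()}


def predict(centroids, x):
    """Return the class whose centroid has the largest dot product with x."""
    return max(centroids, key=lambda c: sum(a * b for a, b in zip(centroids[c], x)))


train_texts = ["great film loved it", "wonderful acting great story",
               "terrible plot awful acting", "boring and awful film"]
train_labels = [1, 1, 0, 0]

# Step 1: featurize with the frozen model; step 2: fit only the small head.
centroids = fit_centroids([embed(t) for t in train_texts], train_labels)

print(predict(centroids, embed("loved the great story")))  # prints 1
```

Only `fit_centroids` sees the labels; swapping in a better featurizer improves the downstream model without retraining the featurizer, which is the point of the approach.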

Let's first connect to MongoDB via SuperDuperDB. Explanations of how to do this can be found in
the docs and in the `notebooks/` directory.

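```python
from pinnacledb import pinnacle
from pinnacledb.datalayer.mongodb.query import Collection
import pymongo

# Wrap the "documents" MongoDB database with a SuperDuperDB datalayer
db = pinnacle(
    pymongo.MongoClient().documents
)

# Query handle for the "transfer" collection
collection = Collection('transfer')
```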



We'll use the IMDB movie-review dataset, whose texts are labelled with sentiment, to test the functionality. Transfer learning
can be used on any data that SuperDuperDB models can process.

```python
import numpy
from datasets import load_dataset

from pinnacledb.core.document import Document as D

# IMDB movie reviews: each row has a 'text' and a binary sentiment 'label'
data = load_dataset("imdb")

# 10,000 random training reviews, tagged with _fold='train'
train_data = [
    D({'_fold': 'train', **data['train'][int(i)]})
    for i in numpy.random.permutation(len(data['train']))[:10000]
]

# 1,000 random test reviews, tagged with _fold='valid' (not inserted below)
valid_data = [
    D({'_fold': 'valid', **data['test'][int(i)]})
    for i in numpy.random.permutation(len(data['test']))[:1000]
]

db.execute(collection.insert_many(train_data))
```
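The data-loading cell draws random, non-overlapping subsets and tags each document with a `_fold` key. The same pattern can be sketched with only the standard library, assuming the records are plain dicts; `make_fold` is a hypothetical helper and the toy records stand in for the IMDB rows.

```python
import random


def make_fold(records, size, fold, seed=None):
    """Sample `size` records without replacement and tag each with `_fold`."""
    rng = random.Random(seed)
    indices = rng.sample(range(len(records)), k=min(size, len(records)))
    # Build new dicts so the source records are left unmodified.
    return [{"_fold": fold, **records[i]} for i in indices]


# Toy stand-ins for data["train"] and data["test"]
toy_train = [{"text": f"review {i}", "label": i % 2} for i in range(50)]
toy_test = [{"text": f"review {i}", "label": i % 2} for i in range(50, 70)]

train_docs = make_fold(toy_train, 10, "train", seed=0)
valid_docs = make_fold(toy_test, 5, "valid", seed=0)

print(len(train_docs), len(valid_docs))  # prints 10 5
```

Sampling indices first, rather than materializing every shuffled record and slicing afterwards, avoids building documents that are immediately discarded.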

```python
r = db.execute(collection.find_one())
r
```

```
Document({'_id': ObjectId('64c126b43a073f413a20791a'), '_fold': 'train', 'text': 'Where to even start? The horrendous acting? The nonsensical plot? The bargain basement effects? The completely loathsome characters? The choppy editing? The headache-inducing Casio keyboard score??? The embarrassingly racist remarks ("Watch it, Charlie!", "Back off, Jackie Chan!!"??? The constant misogyny??? I am a lifelong horror fan, and I have no problem at all with the current "torture-thon" trend of movies. However, this is a poorly-made piece of garbage. I think I suffered more pain watching this than the characters did dying in it! If you like girls being forced to eat stir-fried penis, really poor soft core porn and think lines like "I\'m gonna find that b**** and staple her c*** shut!!" are clever, LIVE FEED is for you.<br /><br />As for me, I feel the need to go wash my eyes out with oven cleaner to prevent from ever seeing this movie again!', 'label': 0, '_outputs': {'text': {'all-MiniLM-L6-v2'
```

Next, let's create a SuperDuperDB model based on a `sentence_transformers` model.
Note that a native SuperDuperDB integration with a model library is not required
in order to use that library with SuperDuperDB. In this case, we only need
to configure the `Model` wrapper so that it interoperates correctly with the `SentenceTransformer` class: `predict_method='encode'` names the method to call, and `batch_predict=True` indicates that this method accepts a batch of inputs. After that, we can link the model to a collection and daemonize it with the `watch=True` keyword:

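```python
from pinnacledb.core.model import Model
import sentence_transformers

from pinnacledb.encoders.numpy.array import array

m = Model(
    identifier='all-MiniLM-L6-v2',
    object=sentence_transformers.SentenceTransformer('all-MiniLM-L6-v2'),
    # all-MiniLM-L6-v2 produces 384-dimensional embeddings
    encoder=array('float32', shape=(384,)),
    predict_method='encode',  # call SentenceTransformer.encode ...
    batch_predict=True,       # ... on batches of texts rather than one at a time
)

# Embed the "text" field of every document, and keep doing so
# for newly inserted documents (watch=True)
m.predict(
    X='text',
    db=db,
    select=collection.find(),
    watch=True
)
```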
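Conceptually, `predict_method` and `batch_predict` tell the wrapper which method of the wrapped object to call, and whether to call it once on a list of inputs or once per input. The class below is a hypothetical, simplified illustration of that idea, not SuperDuperDB's actual `Model` implementation; `ToyEncoder` is an assumed stand-in for `SentenceTransformer`.

```python
class SimpleModelWrapper:
    """Hypothetical wrapper that calls `predict_method` on a wrapped object."""

    def __init__(self, identifier, obj, predict_method, batch_predict=False):
        self.identifier = identifier
        self.method = getattr(obj, predict_method)  # e.g. SentenceTransformer.encode
        self.batch_predict = batch_predict

    def predict(self, inputs):
        if self.batch_predict:
            # One call on the whole list of inputs.
            return list(self.method(inputs))
        # Otherwise call the method once per input.
        return [self.method(x) for x in inputs]


class ToyEncoder:
    """Stand-in for SentenceTransformer: maps each string to a 2-d vector."""

    def __init__(self):
        self.calls = 0

    def encode(self, texts):
        self.calls += 1
        return [[float(len(t)), float(t.count(" "))] for t in texts]


encoder = ToyEncoder()
model = SimpleModelWrapper("toy-encoder", encoder, "encode", batch_predict=True)
vectors = model.predict(["good movie", "bad"])
print(vectors, encoder.calls)  # prints [[10.0, 1.0], [3.0, 0.0]] 1
```

Batching matters for models like sentence-transformers, where one call on many texts is typically much cheaper than many single-text calls.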

Now that we've added the model that computes features for the `"text"` field, we can train a
downstream model on those features using Scikit-Learn:

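```python
from sklearn.svm import SVC

# Wrap a scikit-learn SVM; postprocess converts each prediction to a Python int
model = pinnacle(
    SVC(gamma='scale', class_weight='balanced', C=100, verbose=True),
    postprocess=lambda x: int(x)
)

# Train on the stored all-MiniLM-L6-v2 vectors in place of the raw "text"
model.fit(
    X='text',
    y='label',
    db=db,
    select=collection.find().featurize({'text': 'all-MiniLM-L6-v2'}),
)
```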
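The `find_one` output shown earlier indicates that model outputs are stored under `_outputs` → field → model identifier in each document. A reasonable mental model of `.featurize({'text': 'all-MiniLM-L6-v2'})` is therefore that it swaps the raw `"text"` value for the stored `all-MiniLM-L6-v2` vector before the data reach the downstream model. The function below is a hypothetical sketch of that substitution on plain dicts, under that assumption; it is not SuperDuperDB's implementation, and the 3-d vector is a toy value.

```python
def featurize(doc, mapping):
    """Hypothetical sketch: replace each field named in `mapping` with the
    output stored for that field by the named model under `_outputs`."""
    out = dict(doc)  # shallow copy; the input document is not modified
    for field, model_id in mapping.items():
        out[field] = doc["_outputs"][field][model_id]
    return out


doc = {
    "text": "Where to even start?",
    "label": 0,
    "_outputs": {"text": {"all-MiniLM-L6-v2": [0.12, -0.03, 0.41]}},  # toy vector
}

features = featurize(doc, {"text": "all-MiniLM-L6-v2"})
print(features["text"], features["label"])  # prints [0.12, -0.03, 0.41] 0
```

This is why `X='text'` can stay unchanged in `model.fit`: the SVC receives embeddings under the same key, not raw strings.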

Now that the model has been trained, we can apply it to the database, again daemonizing it
with `watch=True` so that predictions are also computed for documents inserted later.

```python
model.predict(
    X='text',
    db=db,
    select=collection.find().featurize({'text': 'all-MiniLM-L6-v2'}),
    watch=True,
)
```
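Daemonizing a model with `watch=True` means its outputs are kept up to date as documents arrive, rather than computed once. The toy in-memory collection below sketches that idea under simplifying assumptions: `ToyCollection` and its `watch` method are hypothetical, the listener runs synchronously on insert, and a keyword rule stands in for the trained SVC.

```python
class ToyCollection:
    """Hypothetical in-memory collection that runs watched models on insert."""

    def __init__(self):
        self.docs = []
        self.listeners = []  # (input key, model identifier, predict function)

    def watch(self, key, identifier, fn):
        self.listeners.append((key, identifier, fn))
        # Like predict(..., watch=True): first process the existing documents...
        for doc in self.docs:
            self._apply(doc, key, identifier, fn)

    def insert_many(self, docs):
        for doc in docs:
            self.docs.append(doc)
            # ...then every document inserted afterwards.
            for key, identifier, fn in self.listeners:
                self._apply(doc, key, identifier, fn)

    @staticmethod
    def _apply(doc, key, identifier, fn):
        outputs = doc.setdefault("_outputs", {}).setdefault(key, {})
        outputs[identifier] = fn(doc[key])


coll = ToyCollection()
coll.insert_many([{"text": "a great film"}])
# Toy classifier standing in for the SVC: 1 if the text contains "great".
coll.watch("text", "svc", lambda text: int("great" in text))
coll.insert_many([{"text": "an awful film"}])

print([d["_outputs"]["text"]["svc"] for d in coll.docs])  # prints [1, 0]
```

The output location, `_outputs.text.svc`, matches the key read back in the sampling cell.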



To verify that this process has worked, we can sample a random record and sanity-check its prediction (rerun the cell to inspect more records):

```python
r = next(db.execute(collection.aggregate([{'$sample': {'size': 1}}])))
print(r['text'][:100])
print(r['_outputs']['text']['svc'])
```

```
-The movie tells the tale of a prince whose life is wonderful, but after an evil wizard tells him to
1
```
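Inspecting one record at a time is a quick sanity check; for a rough quantitative check, the stored predictions can be compared with the `label` field over a sample. The function below is a minimal sketch over plain dicts shaped like the documents above (prediction under `_outputs.text.svc`); `sample_accuracy` is a hypothetical helper and the toy records are made up for illustration. Since only the training documents were inserted and the SVC was fit on `collection.find()`, agreement measured on this collection reflects fit on the training data rather than held-out accuracy.

```python
def sample_accuracy(docs, model_id="svc", key="text"):
    """Fraction of documents whose stored prediction matches `label`."""
    scored = [d for d in docs if model_id in d.get("_outputs", {}).get(key, {})]
    if not scored:
        raise ValueError("no documents carry predictions from " + model_id)
    hits = sum(d["_outputs"][key][model_id] == d["label"] for d in scored)
    return hits / len(scored)


toy_docs = [
    {"label": 1, "_outputs": {"text": {"svc": 1}}},
    {"label": 0, "_outputs": {"text": {"svc": 0}}},
    {"label": 0, "_outputs": {"text": {"svc": 1}}},
    {"label": 1},  # not yet processed by the watched model; skipped
]
print(sample_accuracy(toy_docs))  # prints 0.6666666666666666
```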
