# Transfer learning

In [1]:
APPLY = True
EAGER = False
COLLECTION_NAME = '<var:table_name>' if not APPLY else 'sample_transfer_learning'
MODALITY = 'text'

In [2]:
from superduper import superduper, CFG

db = superduper('mongomock://test_db')

[32m2025-Mar-22 15:40:00.38[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.misc.importing[0m:[36m13  [0m | [1mLoading plugin: mongodb[0m
[32m2025-Mar-22 15:40:00.49[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.datalayer[0m:[36m51  [0m | [1mBuilding Data Layer[0m
[32m2025-Mar-22 15:40:00.49[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.datalayer[0m:[36m68  [0m | [1mData Layer built[0m
[32m2025-Mar-22 15:40:00.49[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.backends.base.cluster[0m:[36m109 [0m | [1mCluster initialized in 0.00 seconds.[0m
[32m2025-Mar-22 15:40:00.49[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.build[0m:[36m148 [0m | [1mConfiguration: 
 +----------------+-------------------------------+
| Configuration  |             Value             |
+----------------+--------------------------

<!-- TABS -->
## Get useful sample data

In [3]:
def getter():
    import json
    import random
    import subprocess

    subprocess.run([
        'curl', '-O', 'https://superduperdb-public-demo.s3.amazonaws.com/text_classification.json',
    ])
    with open("text_classification.json", "r") as f:
        data = json.load(f)
    subprocess.run(['rm', 'text_classification.json'])
    data = data[:200]
    def fold(): return {True: 'valid', False: 'train'}[random.random() < 0.1]
    data = [{**r, '_fold': fold()} for r in data]
    return data

After obtaining the data, we insert it into the database.

<!-- TABS -->
## Insert simple data

After turning on auto_schema, we can directly insert data, and superduper will automatically analyze the data type, and match the construction of the table and datatype.

In [4]:
if APPLY:
    data = getter()
    from superduper import Table

    db.apply(
        Table(
            COLLECTION_NAME,
            fields={'x': 'str', 'y': 'int'},
        ),
        force=True,
    )
    
    ids = db[COLLECTION_NAME].insert(data)

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  1 1298k    1 17003    0     0  16151      0  0:01:22  0:00:01  0:01:21 16208

[32m2025-Mar-22 15:40:15.47[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.datalayer[0m:[36m388 [0m | [1mLoad (('Table', 'sample_transfer_learning')) from metadata...[0m
[32m2025-Mar-22 15:40:15.47[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.apply[0m:[36m94  [0m | [1mFound these changes and/ or additions that need to be made:[0m
[32m2025-Mar-22 15:40:15.47[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.apply[0m:[36m96  [0m | [1m----------------------------------------------------------------------------------------------------[0m
[32m2025-Mar-22 15:40:15.47[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.apply[0m:[36m97  [0m | [1mMETADATA EVENTS:[0m
[32m2025-Mar-22 15:40:15.47[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.apply[0m:[36m98  [0m | [1m----------------------------------------

100 1298k  100 1298k    0     0   711k      0  0:00:01  0:00:01 --:--:--  713k


<!-- TABS -->
## Compute features

In [5]:
import sentence_transformers
from superduper import Listener
from superduper_sentence_transformers import SentenceTransformer


embedding = SentenceTransformer(
    identifier="embedding",
    model='all-MiniLM-L6-v2',
    postprocess=lambda x: x.tolist(),
)

[2025-03-22 15:40:21] datasets INFO PyTorch version 2.5.1 available.
[2025-03-22 15:40:21] sentence_transformers.SentenceTransformer INFO Load pretrained SentenceTransformer: all-MiniLM-L6-v2


In [7]:
feature_extractor_listener = Listener(
    model=embedding,
    select=db[COLLECTION_NAME],
    key='x',
    identifier="features"
)


if APPLY and EAGER:
    feature_extractor_listener = db.apply(
        feature_extractor_listener,
        force=True,
    )

<!-- TABS -->
## Build and train classifier

In [8]:
from superduper_sklearn import Estimator, SklearnTrainer
from sklearn.svm import SVC


scikit_model = Estimator(
    identifier="my-model-scikit",
    object=SVC(),
    trainer=SklearnTrainer(
        "my-scikit-trainer",
        key=(feature_extractor_listener.outputs, "y"),
        select=db[COLLECTION_NAME].outputs(feature_extractor_listener.predict_id),
    ),
    upstream=[feature_extractor_listener],
)

Define a validation for evaluating the effect after training.

In [9]:
from superduper import Dataset, Metric, Validation

def acc(x, y):
    return sum([xx == yy for xx, yy in zip(x, y)]) / len(x)

accuracy = Metric(identifier="acc", object=acc)

t = db[COLLECTION_NAME]
select = t.filter(t['_fold'] == 'valid').outputs(feature_extractor_listener.predict_id)

validation = Validation(
    "transfer_learning_performance",
    key=(feature_extractor_listener.outputs, "y"),
    datasets=[
        Dataset(
            identifier="my-valid",
            select=select,
        )
    ],
    metrics=[accuracy],
)
scikit_model.validation = validation

If we execute the apply function, then the model will be added to the database, and because the model has a Trainer, it will perform training.

In [10]:
if APPLY and EAGER:
    db.apply(scikit_model, force=True)

Get the training metrics

In [12]:
from superduper import Application

application = Application(
    identifier='transfer-learning',
    components=[feature_extractor_listener, scikit_model],
)

In [13]:
if APPLY:
    db.apply(application)

[32m2025-Mar-22 15:40:57.31[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.datalayer[0m:[36m388 [0m | [1mLoad (('Application', 'transfer-learning')) from metadata...[0m
[32m2025-Mar-22 15:40:57.31[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.datalayer[0m:[36m388 [0m | [1mLoad (('Listener', 'features')) from metadata...[0m
[32m2025-Mar-22 15:40:57.31[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.datalayer[0m:[36m388 [0m | [1mLoad (('SentenceTransformer', 'embedding')) from metadata...[0m
[32m2025-Mar-22 15:40:58.88[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.datalayer[0m:[36m388 [0m | [1mLoad (('Table', '_outputs__features__c1b5d86988903f267e496024ca7b7ad2')) from metadata...[0m
[32m2025-Mar-22 15:40:58.89[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.datalayer[0m:[36m388 [0m | [1mLoa

  


[32m2025-Mar-22 15:41:00.15[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.event[0m:[36m142 [0m | [1mCreating superduper_sentence_transformers.model.SentenceTransformer:embedding:34fa3ac0f2f9c2363b1ccab9b2b3f230[0m
[32m2025-Mar-22 15:41:00.15[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.datalayer[0m:[36m388 [0m | [1mLoad (('Table', 'ArtifactRelations')) from metadata...[0m
[32m2025-Mar-22 15:41:00.16[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.datalayer[0m:[36m388 [0m | [1mLoad (('Table', 'Table')) from metadata...[0m
[32m2025-Mar-22 15:41:00.16[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.datalayer[0m:[36m388 [0m | [1mLoad (('Table', 'SentenceTransformer')) from metadata...[0m
[32m2025-Mar-22 15:41:00.17[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.datalayer[0m:[36m388 [0m | [1mLo

Batches:   0%|          | 0/7 [00:00<?, ?it/s]

[32m2025-Mar-22 15:41:01.38[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.datalayer[0m:[36m388 [0m | [1mLoad (('Table', '_outputs__features__c1b5d86988903f267e496024ca7b7ad2')) from metadata...[0m
[32m2025-Mar-22 15:41:01.41[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.datalayer[0m:[36m388 [0m | [1mLoad (('Table', 'Job')) from metadata...[0m
[32m2025-Mar-22 15:41:01.45[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.datalayer[0m:[36m388 [0m | [1mLoad (('Table', 'Job')) from metadata...[0m
[32m2025-Mar-22 15:41:01.61[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper_sklearn.model[0m:[36m42  [0m | [1mLoading dataset into memory for Estimator.fit[0m


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 177/177 [00:00<00:00, 3405467.01it/s]


[32m2025-Mar-22 15:41:03.23[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.datalayer[0m:[36m388 [0m | [1mLoad (('Estimator', 'my-model-scikit')) from metadata...[0m
[32m2025-Mar-22 15:41:03.24[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.datalayer[0m:[36m388 [0m | [1mLoad (('Listener', 'features')) from metadata...[0m
[32m2025-Mar-22 15:41:03.24[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.datalayer[0m:[36m388 [0m | [1mLoad (('SentenceTransformer', 'embedding')) from metadata...[0m
[32m2025-Mar-22 15:41:04.87[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.datalayer[0m:[36m388 [0m | [1mLoad (('Table', '_outputs__features__c1b5d86988903f267e496024ca7b7ad2')) from metadata...[0m
[32m2025-Mar-22 15:41:04.88[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.datalayer[0m:[36m388 [0m | [1mLoad ((

[32m2025-Mar-22 15:41:04.89[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.apply[0m:[36m94  [0m | [1mFound these changes and/ or additions that need to be made:[0m
[32m2025-Mar-22 15:41:04.89[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.apply[0m:[36m96  [0m | [1m----------------------------------------------------------------------------------------------------[0m
[32m2025-Mar-22 15:41:04.89[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.apply[0m:[36m97  [0m | [1mMETADATA EVENTS:[0m
[32m2025-Mar-22 15:41:04.89[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.apply[0m:[36m98  [0m | [1m----------------------------------------------------------------------------------------------------[0m
[32m2025-Mar-22 15:41:04.89[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.apply[0m:[36m105 [0m | [1m[0]: 

In [14]:
if APPLY:
    model = db.load('Estimator', 'my-model-scikit')
    print(model.metric_values)

[32m2025-Mar-22 15:41:09.21[0m| [1mINFO    [0m | [36mDuncans-MacBook-Pro.local[0m| [36msuperduper.base.datalayer[0m:[36m388 [0m | [1mLoad (('Estimator', 'my-model-scikit')) from metadata...[0m
{'my-valid/acc': 0.8695652173913043}


In [16]:
from superduper import Template, Table, Schema
from superduper.components.dataset import RemoteData

t = Template(
    'transfer_learning',
    default_tables=[Table(
        'sample_transfer_learning',
        fields={'x': 'str', 'y': 'int'},
        data=RemoteData(
            'text_classification',
            getter=getter,
        ),
    )],
    template=application,
    substitutions={'docs': 'table_name', 'text': 'modality'},
    template_variables=['table_name', 'framework', 'modality'],
    types={
        'table_name': {
            'type': 'str',
            'default': 'sample_transfer_learning',
        },
        'modality': {
            'type': 'str',
            'default': 'text',
        },
        'framework': {
            'type': 'str',
            'default': 'scikit-framework',
        },
    },
    db=db
)

In [17]:
t.export('.')