# Transfer learning

<!-- TABS -->
## Connect to pinnacle

In [None]:
from pinnacle import pinnacle

# Connect using a mongomock URI (a mock MongoDB backend), so no running
# database server is required; the database is named "test_db".
db = pinnacle('mongomock:///test_db')

<!-- TABS -->
## Get useful sample data

In [None]:
# <tab: Text-Classification>
!curl -O https://pinnacledb-public-demo.s3.amazonaws.com/text_classification.json
import json

with open("text_classification.json", "r") as f:
    data = json.load(f)
num_classes = 2

In [None]:
# <tab: Image-Classification>
!curl -O https://pinnacledb-public-demo.s3.amazonaws.com/images_classification.zip && unzip images_classification.zip
import json
from PIL import Image

with open('images/images.json', 'r') as f:
    data = json.load(f)
    
data = [{'x': Image.open(d['image_path']), 'y': d['label']} for d in data]
num_classes = 2

After obtaining the data, we reshape it into records that are ready to be inserted into the database.

In [None]:
# <tab: Text-Classification>
# Rename the generic 'x'/'y' fields to the field names used in the database:
# 'txt' for the input text and 'label' for the target.
datas = [{'txt': d['x'], 'label': d['y']} for d in data]

In [None]:
# <tab: Image-Classification>
# Rename the generic 'x'/'y' fields to the field names used in the database:
# 'image' for the input image and 'label' for the target.
datas = [{'image': d['x'], 'label': d['y']} for d in data]

<!-- TABS -->
## Insert simple data

With `auto_schema` enabled, we can insert the data directly: pinnacle automatically infers the data type of each field and creates the matching table schema and datatypes.

In [None]:
from pinnacle import Document

table_or_collection = db['docs']

# Wrap every raw record in a Document and insert them all with a single query.
documents = [Document(record) for record in datas]
ids = db.execute(table_or_collection.insert(documents))

# Unfiltered select over the table/collection; reused by the later steps.
select = table_or_collection.select()

<!-- TABS -->
## Compute features

In [None]:
# <tab: Text>
key = 'txt'
import sentence_transformers
from pinnacle import vector, Listener
from pinnacle_sentence_transformers import SentenceTransformer

# Sentence encoder wrapped as a pinnacle model; the postprocess step turns
# each output into a plain Python list via .tolist().
encoder = sentence_transformers.SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
pinnaclemodel = SentenceTransformer(
    identifier="embedding",
    object=encoder,
    postprocess=lambda x: x.tolist(),
)

# Listener that applies the model to the `key` field of the documents
# returned by `select`.
feature_listener = Listener(
    model=pinnaclemodel,
    select=select,
    key=key,
    identifier="features",
)
jobs, listener = db.apply(feature_listener)

In [None]:
# <tab: Image>
# Document field holding the PIL image that the feature model reads.
key = 'image'
import torchvision.models as models
from torchvision import transforms
from pinnacle_torch import TorchModel
from pinnacle import Listener
from PIL import Image
class TorchVisionEmbedding:
    """Bundle a pre-trained ResNet-18 with the preprocessing it is fed with.

    ``resnet`` is the network in evaluation mode; ``preprocess`` converts one
    image into the tensor passed to it.
    """

    def __init__(self):
        # Load the pre-trained ResNet-18 model.
        # NOTE(review): recent torchvision releases deprecate ``pretrained=``
        # in favour of ``weights=``; left unchanged because the torchvision
        # version this notebook targets is not visible here.
        self.resnet = models.resnet18(pretrained=True)

        # Set the model to evaluation mode
        self.resnet.eval()

        # Build the transform pipeline once rather than on every
        # preprocess() call; it does not depend on the image.
        self._transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def preprocess(self, image):
        """Resize, center-crop, tensorize and normalize ``image``.

        Args:
            image: The image to convert (a PIL image in this notebook).

        Returns:
            The transformed image tensor.
        """
        return self._transform(image)
        
model = TorchVisionEmbedding()

# Wrap the ResNet as a pinnacle model: the helper's preprocess prepares the
# input, and postprocess converts each output tensor to a plain Python list.
pinnaclemodel = TorchModel(
    identifier='my-vision-model-torch',
    object=model.resnet,
    preprocess=model.preprocess,
    postprocess=lambda x: x.numpy().tolist(),
)

# Listener that applies the model to the `key` field of the documents
# returned by `select`.
feature_listener = Listener(
    model=pinnaclemodel,
    select=select,
    key=key,
    identifier="features",
)
jobs, listener = db.apply(feature_listener)

## Choose features key from feature listener

In [None]:
# Name of the field under which the feature listener stores its outputs
# (presumably a string key; it is used below to index result rows).
input_key = listener.outputs
# Extend the base select with the outputs identified by the listener's
# predict_id, so each row carries both the label and the computed feature.
training_select = select.outputs(listener.predict_id)

We can now read the computed features back from the database.

In [None]:
# Fetch a single row and read its feature to learn the feature length,
# which sizes the classifier's input layer.
first_row = list(training_select.limit(1).execute())[0]
feature = first_row[input_key]
feature_size = len(feature)

<!-- TABS -->
## Build and train classifier

In [None]:
# <tab: Scikit-Learn>
from pinnacle_sklearn import Estimator, SklearnTrainer
from sklearn.svm import SVC

# The trainer reads (feature, label) pairs from the select that includes the
# listener outputs.
svc_trainer = SklearnTrainer(
    "my-trainer",
    key=(input_key, "label"),
    select=training_select,
)

# Support-vector classifier wrapped as a pinnacle estimator.
model = Estimator(identifier="my-model", object=SVC(), trainer=svc_trainer)

In [None]:
# <tab: Torch>
import torch
from torch import nn
from pinnacle_torch.model import TorchModel
from pinnacle_torch.training import TorchTrainer
from torch.nn.functional import cross_entropy


class SimpleModel(nn.Module):
    """Feed-forward classifier: linear layer, ReLU, linear layer.

    Args:
        input_size: Length of each input feature vector.
        hidden_size: Width of the hidden layer.
        num_classes: Number of output scores, one per class.
    """

    def __init__(self, input_size=16, hidden_size=32, num_classes=3):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return the un-normalized class scores for ``x``."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

def preprocess(x):
    """Convert one stored feature (e.g. a list of floats) into a ``torch.Tensor``."""
    return torch.tensor(x)

# Postprocess function for the model output
def postprocess(x):
    """Return the index of the largest entry of ``x`` as a plain ``int``.

    Uses ``.item()``, so the top-1 index tensor must hold exactly one element.
    """
    _, top_index = x.topk(1)
    return int(top_index.item())

def data_transform(features, label):
    """Prepare one training example: tensorize ``features``, pass ``label`` through."""
    feature_tensor = torch.tensor(features)
    return feature_tensor, label

# Build a small feed-forward classifier on top of the precomputed features.
# feature_size is the length of one feature vector read back from the database.
model = TorchModel(
    identifier='my-model',
    object=SimpleModel(feature_size, num_classes=num_classes),
    preprocess=preprocess,
    postprocess=postprocess,
    trainer=TorchTrainer(
        key=(input_key, 'label'),
        identifier='my_trainer',
        objective=cross_entropy,
        loader_kwargs={'batch_size': 10},
        max_iterations=1000,
        validation_interval=100,
        # Train on the select that carries the listener outputs: input_key is
        # the listener's output field, which the plain `select` does not
        # include. This matches the Scikit-Learn variant of this step.
        select=training_select,
        transform=data_transform,
    ),
)

Define a validation to evaluate the model's performance after training.

In [None]:
from pinnacle import Dataset, Metric, Validation


def acc(x, y):
    """Return the fraction of positions at which ``x`` and ``y`` agree.

    Elements are paired with ``zip`` and the number of equal pairs is divided
    by ``len(x)``.
    """
    matches = 0
    for left, right in zip(x, y):
        matches += left == right
    return matches / len(x)


# Expose the acc function as a named metric.
accuracy = Metric(identifier="acc", object=acc)

# Validation dataset drawn from the feature-augmented select, restricted via
# add_fold('valid').
valid_dataset = Dataset(identifier="my-valid", select=training_select.add_fold('valid'))

validation = Validation(
    "transfer_learning_performance",
    key=(input_key, "label"),
    datasets=[valid_dataset],
    metrics=[accuracy],
)

# Attach the validation to whichever classifier was built in the previous step.
model.validation = validation

If we execute the apply function, then the model will be added to the database, and because the model has a Trainer, it will perform training tasks.

In [None]:
# Register the model with the database; because the model carries a trainer,
# this is the call that triggers training.
db.apply(model)

In [None]:
# NOTE(review): presumably displays the model's encoded (serialized)
# representation for inspection — confirm against the pinnacle API.
model.encode()

Get the training metrics

In [None]:
# Reload the stored component of type 'model' by its identifier, then display
# the metric values recorded on it.
model = db.load('model', model.identifier)
model.metric_values

In [None]:
from pinnacle import Template

# Wrap the model as a reusable template named 'transfer-learner'.
# NOTE(review): the substitutions mapping presumably replaces the literal
# 'docs' with a template variable called 'table' — confirm.
t = Template('transfer-learner', template=model, substitutions={'docs': 'table'})

In [None]:
# Export the template, passing the current directory as the target path.
t.export('.')