# 04. Fine-tuning

So far we haven't trained a single model. We've only used pretrained ones, which work well enough in general cases but not necessarily in a specific domain with its own terminology. That doesn't mean we need to start from scratch, though. The original embeddings still capture useful information, so we should **slightly adjust them** rather than train a new model from the very beginning.

![](images/fine_tuning.png)

In [None]:
import pandas as pd

from datasets import load_dataset
from torch.utils.data import Dataset
from quaterion.dataset.similarity_samples import SimilarityPairSample

Our dataset should be represented as pairs of a question and the text that answers it (here, the tweet). If we knew that a question had several valid answers, we could put those pairs into a common group. In our case, we'll assume there is a single answer for a given question.
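
To make the role of groups more concrete, here is a minimal sketch in plain Python. It assumes the usual reading of `subgroup`: pairs from different subgroups may be combined into negative examples, while pairs from the same subgroup are never used against each other. `PairSample` and `negative_pairs` are hypothetical stand-ins written for this illustration, not part of Quaterion, and the sample texts are made up.

```python
from dataclasses import dataclass
from itertools import permutations


@dataclass
class PairSample:
    """Hypothetical stand-in for SimilarityPairSample (illustration only)."""

    obj_a: str
    obj_b: str
    subgroup: int


def negative_pairs(batch):
    """Return (obj_a, obj_b) combinations built from different subgroups."""
    return [
        (first.obj_a, second.obj_b)
        for first, second in permutations(batch, 2)
        if first.subgroup != second.subgroup
    ]


# Two questions about the same tweet share a subgroup, the third does not.
batch = [
    PairSample("who won the game?", "tweet about the game", subgroup=1),
    PairSample("what was the score?", "tweet about the game", subgroup=1),
    PairSample("who is touring?", "tweet about a concert", subgroup=2),
]

negatives = negative_pairs(batch)
for question, text in negatives:
    print(f"{question!r} should NOT match {text!r}")
```

Under this assumption, two questions about the same tweet are never treated as a mismatch for that tweet, which is why grouping by the tweet text is a reasonable choice.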

In [None]:
class TweetsQADataset(Dataset):

    def __init__(self, subset: str = "train"):
        self.dataset = pd.DataFrame(load_dataset("tweet_qa")[subset])

    def __getitem__(self, index) -> SimilarityPairSample:
        item = self.dataset.iloc[index]
        return SimilarityPairSample(
            obj_a=item["Question"],
            obj_b=item["Tweet"],
            subgroup=hash(item["Tweet"]),
        )

    def __len__(self):
        return self.dataset.shape[0]
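
One detail worth knowing: Python's built-in `hash()` is salted per process for strings, so the same tweet can get a different subgroup id in another interpreter run unless `PYTHONHASHSEED` is fixed. Within a single run that is harmless, but if the ids ever need to be stored or compared across processes, a digest-based id is an option. The helper below is a hypothetical sketch (the name `stable_subgroup` is an assumption, not part of Quaterion), and the tweet text is invented.

```python
import hashlib


def stable_subgroup(text: str) -> int:
    """Map text to a 63-bit integer that is identical in every Python process."""
    digest = hashlib.sha256(text.encode("utf-8")).digest()
    # Keep the first 8 bytes and drop one bit so the value fits a signed int64.
    return int.from_bytes(digest[:8], "big") >> 1


tweet = "Just landed in Berlin!"
subgroup_id = stable_subgroup(tweet)
print(subgroup_id)
```

It could be dropped into `__getitem__` as `subgroup=stable_subgroup(item["Tweet"])` if reproducible ids matter for your setup.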

Once the dataset is ready, we can prepare the model. Since we want to import it from different notebooks, it is defined in a separate file, `model.py`.

In [None]:
from model import TweetsQAModel

Training is almost identical to what we would do with plain PyTorch or PyTorch Lightning.

In [None]:
import pytorch_lightning as pl

from quaterion import Quaterion
from quaterion.dataset import PairsSimilarityDataLoader

In [None]:
train_dataset = TweetsQADataset("train")
validation_dataset = TweetsQADataset("validation")
train_dataloader = PairsSimilarityDataLoader(train_dataset, batch_size=512)
validation_dataloader = PairsSimilarityDataLoader(validation_dataset, batch_size=512)

In [None]:
pl.seed_everything(42, workers=True)
tweets_qa_model = TweetsQAModel("all-MiniLM-L6-v2")

In [None]:
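trainer = pl.Trainer(
    min_epochs=1,
    max_epochs=100,
    auto_select_gpus=True,  # removed in pytorch_lightning 2.0; requires a 1.x release
    num_sanity_val_steps=2,  # run two validation batches before training starts
)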
Quaterion.fit(tweets_qa_model, trainer, train_dataloader, validation_dataloader)

In [None]:
tweets_qa_model.save_servable("tweets_qa_servable")