# Fine-Tune with MNIST

Follow the steps below to train a Towhee operator in a Jupyter notebook. This example fine-tunes a pretrained model (e.g., ResNet-18) on the MNIST dataset.

## 1. Download Operator

Download the operator files together with the Jupyter notebook.

In [None]:
! git clone https://towhee.io/towhee/resnet-image-embedding.git
# `! cd` runs in a throwaway subshell; `%cd` changes the notebook's working directory
%cd resnet-image-embedding
! ls

Then run the Python cells in the following steps to train and test the Towhee operator.

## 2. Setup Operator

Create the operator and load the model by name.

In [None]:
# import sys
# sys.path.append('..')

from resnet_image_embedding import ResnetImageEmbedding

# Set num_classes=10 for MNIST dataset
op = ResnetImageEmbedding('resnet18', num_classes=10)

## 3. Configure Trainer

Override the default training configuration values where needed.

In [None]:
from towhee.trainer.training_config import TrainingConfig

training_config = TrainingConfig(
    batch_size=64,
    epoch_num=2,
    output_dir='mnist_output'
)

## 4. Prepare Dataset

The example here uses the MNIST dataset for both training and evaluation.

In [None]:
from torchvision import transforms
from towhee import dataset

# Mean and standard deviation of MNIST pixel values after scaling to [0, 1]
mean = 0.1307
std = 0.3081

mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    # Repeat the single grayscale channel three times to get a 3-channel image
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
    transforms.Normalize(mean=[mean] * 3, std=[std] * 3),
])
train_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=True)
eval_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=False)
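
To make the normalization step concrete, the following is a minimal pure-Python sketch of the per-pixel arithmetic that `transforms.Normalize` applies, together with its inverse, which the prediction step uses to display an image. The helper names `normalize` and `denormalize` are illustrative and are not part of Towhee or torchvision.

```python
MEAN = 0.1307  # same constants as in the cell above
STD = 0.3081


def normalize(pixel: float, mean: float = MEAN, std: float = STD) -> float:
    """Map a pixel in [0, 1] to its normalized value: (pixel - mean) / std."""
    return (pixel - mean) / std


def denormalize(value: float, mean: float = MEAN, std: float = STD) -> float:
    """Invert normalize(): value * std + mean."""
    return value * std + mean


if __name__ == "__main__":
    # Round-trip a few pixel intensities through both helpers
    for pixel in (0.0, 0.5, 1.0):
        z = normalize(pixel)
        print(f"pixel={pixel:.1f} normalized={z:+.4f} restored={denormalize(z):.1f}")
```

A pixel equal to the dataset mean maps to 0, darker pixels map to negative values, and brighter pixels map to positive values.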

## 5. Start Training

Start training on MNIST. Training takes about 30-100 minutes on a CPU machine and is much faster on a GPU machine.

In [None]:
op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)

Watch the epoch progress bars: if the loss decreases while the metric increases, the model is training properly.
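
The same check can be written down as a small helper. This is only a sketch: it assumes you have copied the per-epoch loss and metric values into two Python lists by hand, and the function name `training_looks_healthy` is hypothetical rather than part of Towhee.

```python
from typing import Sequence


def training_looks_healthy(losses: Sequence[float], metrics: Sequence[float]) -> bool:
    """Return True if the loss ended lower and the metric ended higher than they started."""
    if len(losses) < 2 or len(metrics) < 2:
        raise ValueError("need at least two epochs to compare")
    return losses[-1] < losses[0] and metrics[-1] > metrics[0]
```

Comparing only the first and last epochs keeps the check simple. It matches the two-epoch configuration used here, but it can hide fluctuations in longer runs.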


## 6. Predict after Training

After training, you can make predictions with the operator. Comparing the predicted results with the actual labels lets you measure the accuracy of the fine-tuned model.

In [None]:
import matplotlib.pyplot as plt
import torch
import random

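# Pick a random image from the evaluation set and predict its digit.
# randrange excludes the upper bound, so the index is always valid.
img_index = random.randrange(len(eval_data.dataset))
img_tensor = eval_data.dataset[img_index][0]
img = img_tensor.numpy().transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
# Undo the normalization so the image is displayed with its original intensities
display_img = (img * std + mean).clip(0, 1)
plt.imshow(display_img)
plt.show()
# Add a batch dimension and move the tensor to the trainer's device
test_img = img_tensor.unsqueeze(0).to(op.trainer.configs.device)
out = op.trainer.predict(test_img)
predict_num = torch.argmax(torch.softmax(out, dim=-1)).item()
print('this picture is number {}'.format(predict_num))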
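
To go from a single prediction to an accuracy figure, one option is to count how often the predicted digit matches the label over a set of samples. The sketch below is framework-agnostic: `accuracy` and `predict_label` are hypothetical names, and the commented wiring at the end only reuses calls from the cell above, assuming each item of `eval_data.dataset` is an `(image, label)` pair.

```python
from typing import Callable, Iterable, Tuple, TypeVar

X = TypeVar("X")


def accuracy(predict_label: Callable[[X], int], samples: Iterable[Tuple[X, int]]) -> float:
    """Return the fraction of samples whose predicted label equals the true label."""
    correct = 0
    total = 0
    for x, label in samples:
        if predict_label(x) == label:
            correct += 1
        total += 1
    if total == 0:
        raise ValueError("no samples to evaluate")
    return correct / total


# Possible wiring in this notebook (an assumption, not executed here):
#
# def predict_label(img):
#     out = op.trainer.predict(img.unsqueeze(0).to(op.trainer.configs.device))
#     return torch.argmax(out, dim=-1).item()
#
# subset = (eval_data.dataset[i] for i in range(100))
# print(accuracy(predict_label, subset))
```

Evaluating one image at a time is slow but easy to follow. Batching the inputs would be the natural next step for the full evaluation set.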