# __Using the Custom Model in ThanoSQL__ 

## __0. Prepare Dataset and Model__

This tutorial uses the Beans dataset. The dataset consists of leaf images taken in the field in different districts of Uganda by the Makerere AI Lab in collaboration with the National Crops Resources Research Institute (NaCRRI), the national body in charge of agricultural research in Uganda.



Reference: <https://github.com/AI-Lab-Makerere/ibean>

### __Prepare Dataset__

#### Download and Unzip Data

In [1]:
import os
from shutil import unpack_archive
from urllib.request import urlretrieve

url = "https://storage.googleapis.com/ibeans"

# For each split: download its zip archive, extract it into the current
# directory, then delete the archive so only the extracted files remain.
for split in ("train", "validation", "test"):
    archive = f"{split}.zip"
    urlretrieve(f"{url}/{archive}", archive)
    unpack_archive(archive, ".")
    os.remove(archive)

#### Install Necessary Packages

In [None]:
!pip install torch torchvision

#### Create a Training Dataset 
The following code block is adapted from [this PyTorch tutorial](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html) and has been modified to fit this tutorial's needs.

In [None]:
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.datasets import ImageFolder

# Splits that take part in training; each is a directory name on disk.
_SPLITS = ("train", "validation")

# Per-channel mean / std normalization shared by both pipelines.
_normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# Training images get random crop/flip augmentation; validation images are
# resized and center-cropped deterministically.
_train_steps = [
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    _normalize,
]
_validation_steps = [
    T.Resize(224),
    T.CenterCrop(224),
    T.ToTensor(),
    _normalize,
]
data_transforms = {
    "train": T.Compose(_train_steps),
    "validation": T.Compose(_validation_steps),
}

# One ImageFolder dataset per split, read from the directory of the same name.
image_datasets = {
    name: ImageFolder(name, data_transforms[name]) for name in _SPLITS
}

# Batches of 8; only the training split is shuffled.
dataloaders = {
    name: DataLoader(dataset, batch_size=8, shuffle=(name == "train"))
    for name, dataset in image_datasets.items()
}

# Number of samples per split, used to average loss and accuracy per epoch.
dataset_sizes = {name: len(dataset) for name, dataset in image_datasets.items()}

### __Prepare the Model__

#### Create a Model Training Function

In [4]:
import time
import copy
import torch


def train_model(model, criterion, optimizer, num_epochs=3):
    """Fine-tune ``model`` and return it with its best validation weights loaded.

    Every epoch runs a training pass followed by a validation pass over the
    module-level ``dataloaders``. Per-phase loss and accuracy are averaged
    with the module-level ``dataset_sizes`` and printed.

    Args:
        model: Model to train; it is moved to the GPU when one is available,
            otherwise to the CPU.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer whose ``zero_grad()`` / ``step()`` drive the
            parameter updates.
        num_epochs: Number of train/validation epochs to run.

    Returns:
        The same ``model`` object, loaded with the weights of the epoch that
        reached the highest validation accuracy (or its initial weights if no
        epoch scored above 0.0).
    """
    start_time = time.time()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    best_model_weights = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f"Epoch {epoch}/{num_epochs - 1}")
        print("-" * 10)

        # Every epoch goes through a training and validation phase
        for phase in ["train", "validation"]:
            is_training = phase == "train"
            if is_training:
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Forward propagation; gradients are tracked only while training
                with torch.set_grad_enabled(is_training):
                    outputs = model(inputs)
                    preds = torch.argmax(outputs, dim=1)
                    loss = criterion(outputs, labels)

                    # Backward propagation during training phase only
                    if is_training:
                        loss.backward()
                        optimizer.step()

                # Statistics: accumulate plain Python numbers so the epoch
                # metrics (and best_acc) are floats rather than device tensors,
                # independent of integer-tensor division semantics.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

            # Save if the model accuracy is higher than the previous accuracy
            if phase == "validation" and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_weights = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - start_time
    print(f"Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s")
    # Six decimals, spelled out explicitly (same output as the default precision).
    print(f"Best val Acc: {best_acc:.6f}")

    model.load_state_dict(best_model_weights)
    return model

#### Load the Model

This tutorial uses MobileViT v2 because it achieves high accuracy for a lightweight model.

In [None]:
# Load the pretrained "mobilevitv2_050" model from the pytorch-image-models
# hub repository, configured for 3 output classes.
model = torch.hub.load(
    "rwightman/pytorch-image-models",
    "mobilevitv2_050",
    pretrained=True,
    num_classes=3,
)

# Cross-entropy loss and an Adam optimizer over all model parameters.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

#### Train and Save a Model

In [6]:
# Train for three epochs; train_model returns the model with its best
# validation weights loaded.
trained_model = train_model(model, criterion, optimizer, num_epochs=3)

Epoch 0/2
----------
train Loss: 0.5634 Acc: 0.7921
validation Loss: 0.2599 Acc: 0.8947

Epoch 1/2
----------
train Loss: 0.3259 Acc: 0.8762
validation Loss: 0.2687 Acc: 0.9173

Epoch 2/2
----------
train Loss: 0.2883 Acc: 0.8830
validation Loss: 0.1434 Acc: 0.9624

Training complete in 1m 26s
Best val Acc: 0.962406


In [7]:
# Serialize the whole model object (not just its state_dict) to disk.
torch.save(trained_model, "trained_model.pth")

#### Create a DataFrame to Insert into ThanoSQL

In [8]:
import numpy as np
import pandas as pd

# The test split goes through the same preprocessing as the validation split.
test_dataset = ImageFolder("test", data_transforms["validation"])

# Convert every preprocessed image to a NumPy array (labels are discarded)
# and stack them into a single array.
image_arrays = [image.numpy() for image, _label in test_dataset]
data = np.stack(image_arrays)

# One nested-list cell per image; the column name must be "image".
image_cells = pd.Series(data.tolist())
df = pd.DataFrame(image_cells, columns=["image"])
df.to_pickle("test_data.pkl")

As mentioned in the [ThanoSQL Workspace](https://docs.thanosql.ai/en/getting_started/how_to_use_ThanoSQL/#5-thanosql-workspace), you must create an API token and run the cell below before you can execute ThanoSQL queries.

In [None]:
%load_ext thanosql
%thanosql API_TOKEN=<Issued_API_TOKEN>

In [10]:
%%thanosql
COPY beans_test 
OPTIONS (overwrite=True)
FROM "test/udm_tutorial/test_data.pkl"

Success


<div class="admonition note">
    <h4 class="admonition-title">Query Details</h4>
    <ul>
        <li>"<strong>COPY</strong>" specifies the name of the dataset to be saved as a database table. </li>
        <li>"<strong>OPTIONS</strong>" specifies the option values to be used for the <strong>COPY</strong> clause.
        <ul>
           <li>"overwrite": determines whether to overwrite a table if it already exists. If set to True, the old table is replaced with the new table (True|False, DEFAULT: False) </li>
        </ul>
        </li>
    </ul>
</div>

## __1. Check Dataset__

To check the table's contents, run the following query.

In [11]:
%%thanosql
SELECT *
FROM beans_test
LIMIT 5

Unnamed: 0,image
0,"[[[-0.028684020042419434, -0.04580877348780632..."
1,"[[[-0.0629335269331932, -0.0629335269331932, -..."
2,"[[[1.9577873945236206, 1.8721636533737183, 1.7..."
3,"[[[0.21106265485286713, 0.0569397434592247, -0..."
4,"[[[-1.3815395832061768, -1.432913899421692, -1..."


## __2. Upload Custom Model__

To upload a custom model, run the following query.

In [12]:
%%thanosql
UPLOAD MODEL beans_mobilevit
OPTIONS (
    overwrite=True,
    framework="pytorch"
    )
FROM "test/udm_tutorial/trained_model.pth"

Success


## __3. Predict Using a Custom Model__

To predict the result using a custom model, run the following query.

In [13]:
%%thanosql
PREDICT
USING beans_mobilevit
AS (
    SELECT *
    FROM beans_test
    ORDER BY RANDOM()
    LIMIT 5
    )

Unnamed: 0,image,predict_result
0,"[[[-0.09718302637338638, -0.11430778354406357,...","[-1.734525203704834, -1.7788751125335693, 3.94..."
1,"[[[-1.3986643552780151, -1.4500386714935303, -...","[-1.6501493453979492, -1.6760544776916504, 3.8..."
2,"[[[-1.4157891273498535, -1.4842880964279175, -...","[-1.0188010931015015, 3.2187016010284424, -2.8..."
3,"[[[-1.2445416450500488, -1.158917784690857, -1...","[-1.5477955341339111, 2.5614449977874756, -1.3..."
4,"[[[-1.2445416450500488, -1.278791069984436, -1...","[-2.2948460578918457, -1.4243049621582031, 4.2..."


In [14]:
# `_` holds the output of the previously executed cell (the PREDICT result).
pred_df = _
class_names = test_dataset.classes

# Replace each score vector with the name of its highest-scoring class.
pred_df["predict_result"] = pred_df["predict_result"].apply(
    lambda scores: class_names[np.argmax(scores)]
)
pred_df

Unnamed: 0,image,predict_result
0,"[[[-0.09718302637338638, -0.11430778354406357,...",healthy
1,"[[[-1.3986643552780151, -1.4500386714935303, -...",healthy
2,"[[[-1.4157891273498535, -1.4842880964279175, -...",bean_rust
3,"[[[-1.2445416450500488, -1.158917784690857, -1...",bean_rust
4,"[[[-1.2445416450500488, -1.278791069984436, -1...",healthy
