<img src="https://saturn-public-assets.s3.us-east-2.amazonaws.com/example-resources/saturn.png" width="300">

# Inference with Snowflake and Saturn Cloud

This notebook contains steps for loading image files from a Snowflake unstructured table and running image classification inference on them.
Follow along in [our guide on the Snowflake website](https://quickstarts.snowflake.com/).

In [None]:
import pandas as pd
import requests, io, os, datetime, re  # noqa: E401
import snowflake.connector
from snowflake.connector.pandas_tools import write_pandas
import torch
from torchvision import transforms, models
import dask
from PIL import Image

# Tensor -> PIL image converter (the visualization cell builds its own copy as well).
to_pil = transforms.ToPILImage()
from dask_saturn import SaturnCluster
from dask.distributed import Client

## Set up Snowflake Connection

Credentials are stored in the Saturn Cloud credentials tool.

In [None]:
# Keyword arguments passed to snowflake.connector.connect(). The user name and
# password come from environment variables; the other settings are fixed here.
conn_kwargs = {
    "user": os.environ["SNOWFLAKE_USER"],
    "password": os.environ["SNOWFLAKE_PASSWORD"],
    "account": "mf80263.us-east-2.aws",
    "warehouse": "COMPUTE_WH",
    "database": "clothing",
    "schema": "PUBLIC",
    "role": "sysadmin",
}

## Set up Dask Cluster

In [None]:
# Instantiate the Saturn Cloud cluster object and connect a Dask client to it.
cluster = SaturnCluster()
client = Client(cluster)
# Block until 2 workers are available before any work is submitted.
client.wait_for_workers(2)
# Bare expression: shows the client summary as the cell output.
client

## Define Functions

### Data Preprocessing

In [None]:
@dask.delayed
def preprocess(list_img_attr):
    """Download one image and turn it into a tensor ready for inference.

    Args:
        list_img_attr: One row of the Snowflake query result as a list, read
            by position. With the query used in this notebook, index 0 is
            FILE_URL, 2 is SIZE, 3 is LAST_MODIFIED and 4 is the pre-signed
            download URL (SIGNEDURL).

    Returns:
        ``[name, tensor, truth, path, snow_path, filesize, orig_timestamp]``
        where ``name`` (file name) and ``truth`` (class folder) are parsed
        from the pre-signed URL.

    Raises:
        requests.HTTPError: If the download returns an HTTP error status,
            for example because the pre-signed link has expired.
        ValueError: If the URL does not contain the expected
            ``clothing-dataset-small/test/<class>/<file>.jpg`` layout.
    """
    snow_path = list_img_attr[0]
    filesize = list_img_attr[2]
    orig_timestamp = list_img_attr[3]
    path = list_img_attr[4]

    transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(250),
            transforms.ToTensor(),
        ]
    )

    response = requests.get(path)
    # Surface HTTP failures here instead of letting PIL choke on an error page.
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content)).convert("RGB")
    nvis = transform(image)

    # One search yields both the class folder (group 1) and file name (group 2).
    found = re.search(r"clothing-dataset-small/test/([a-z-]+)\/([^\/]+(\.jpg))", path)
    if found is None:
        # Drop the query string so the URL signature does not end up in logs.
        raise ValueError(
            "Unexpected image path layout: " + str(path).split("?", 1)[0]
        )
    truth = found.group(1)
    name = found.group(2)

    return [name, nvis, truth, path, snow_path, filesize, orig_timestamp]

In [None]:
@dask.delayed
def reformat(batch):
    """Regroup a list of per-image records into per-field columns.

    ``batch`` holds one list per image; the result holds one list per field.
    The field at position 1 is stacked into a single tensor and moved to
    ``device``, a notebook-level global that is not defined in this function.
    """
    columns = [list(field) for field in zip(*batch)]
    stacked = torch.stack(columns[1])
    columns[1] = stacked.to(device)
    return columns

### Human Readable Predictions

In [None]:
def evaluate_pred_batch(batch, gtruth, classes):
    """Turn a batch of raw model outputs into human readable predictions.

    Args:
        batch: 2-D tensor of class scores, one row per image.
        gtruth: Ground-truth labels, indexable in the same order as ``batch``.
        classes: Mapping from class index to class name.

    Returns:
        Tuple ``(preds, labslist)``. ``preds[i]`` is a one-element list
        ``[(class_name, percentage)]`` for the top-scoring class of image
        ``i``; ``labslist[i]`` is ``gtruth[i]``.
    """
    # Softmax over the class dimension for EVERY row. Keeping only row 0 would
    # report the first image's probabilities for all images in the batch.
    percentage = (torch.nn.functional.softmax(batch, dim=1) * 100).cpu()
    _, indices = torch.sort(batch, descending=True)
    indices = indices.cpu().tolist()

    preds = []
    labslist = []
    for i in range(len(batch)):
        # Only the best-ranked class is reported, wrapped in a list.
        pred = [(classes[idx], percentage[i][idx].item()) for idx in indices[i][:1]]
        preds.append(pred)
        labslist.append(gtruth[i])

    return (preds, labslist)


def is_match(label, pred):
    """Report whether the ground-truth label occurs in the prediction.

    Underscores are treated as spaces on both sides. The label is matched
    literally and only as a whole class name: it must not be preceded or
    followed by a word character or a hyphen, so ``shirt`` does not match a
    ``t-shirt`` prediction.

    Args:
        label: Ground-truth class name.
        pred: Prediction; its ``str()`` form is searched.

    Returns:
        True when the label is found, False otherwise.
    """
    needle = re.escape(label.replace("_", " "))
    haystack = str(pred).replace("_", " ")
    pattern = r"(?<![\w-])" + needle + r"(?![\w-])"
    return re.search(pattern, haystack) is not None

### Run Inference

In [None]:
@dask.delayed
def run_batch_to_s3(iteritem):
    """Run the classifier on one preprocessed batch and score the results.

    Despite the name, nothing is written to S3 here; the outcomes are
    returned to the caller.

    Args:
        iteritem: Seven per-field columns for one batch, in the order
            names, images (stacked tensor), true labels, paths, Snowflake
            paths, file sizes, original timestamps.

    Returns:
        A list with one dict per image holding the name, ground truth,
        prediction (raw, text and probability), the match flag and the
        Snowflake bookkeeping fields.
    """
    names, images, truelabels, paths, snow_paths, filesizes, orig_timestamps = iteritem

    classes = [
        "dress",
        "hat",
        "longsleeve",
        "outwear",
        "pants",
        "shirt",
        "shoes",
        "shorts",
        "skirt",
        "t-shirt",
    ]
    # Class index -> class name.
    classes2 = dict(enumerate(classes))

    # Retrieve, set up model. The device is chosen where this task runs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    resnet = models.resnet50(pretrained=False)
    # map_location lets weights saved on a GPU load on a CPU-only machine.
    resnet.load_state_dict(torch.load("./model/modeltrained.pt", map_location=device))
    resnet = resnet.to(device)
    resnet.eval()

    # The batch may have been placed on a different device when it was built;
    # model and input must share one device for the forward pass.
    images = images.to(device)

    with torch.no_grad():
        pred_batch = resnet(images)

    # Evaluate batch
    preds, labslist = evaluate_pred_batch(pred_batch, truelabels, classes2)

    # Organize prediction results
    outcomes = []
    for j in range(len(images)):
        outcomes.append(
            {
                "name": names[j],
                "ground_truth": labslist[j],
                "prediction": preds[j],
                "prediction_text": preds[j][0][0],
                "prediction_prob": preds[j][0][1],
                "evaluation": is_match(labslist[j], preds[j]),
                "snow_path": snow_paths[j],
                "filesize": filesizes[j],
                "orig_timestamp": orig_timestamps[j],
            }
        )

    return outcomes

## Connect To Snowflake

Query for the image data from the `clothing_test` table.


In [None]:
# Stage name and column name interpolated into the query below.
stage = "clothing_dataset"
relative_path_col = "RELATIVE_PATH"

# Read one row per file from clothing_test and add a SIGNEDURL column produced
# by get_presigned_url. The column order matters: preprocess() reads the row
# fields by position.
with snowflake.connector.connect(**conn_kwargs) as conn:
    df = pd.read_sql(
        f"""select FILE_URL,
    RELATIVE_PATH, SIZE, LAST_MODIFIED,
    get_presigned_url(@{stage}, {relative_path_col})
    as SIGNEDURL from clothing_test""",
        conn,
    )
    # The pre-signed URL column on its own.
    list_paths = df["SIGNEDURL"]

### Delayed Preprocessing Steps

In [None]:
n = 80  # batch size

# Cut the query result into consecutive frames of at most n rows each.
list_df = [df[start : start + n] for start in range(0, df.shape[0], n)]
# For every frame, collect its rows.
image_rows = [[row for _idx, row in frame.iterrows()] for frame in list_df]
# One delayed preprocess() call per row, still grouped by batch.
image_batches1 = [[preprocess(list(row)) for row in rows] for rows in image_rows]
# One delayed reformat() call per batch.
image_batches = [reformat(delayed_rows) for delayed_rows in image_batches1]

In [None]:
# Notebook-level device: CUDA when available, otherwise CPU. reformat() refers
# to this global by name.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Model file placed on workers

from dask_saturn.plugins import RegisterFiles, sync_files

# Register the RegisterFiles worker plugin, then sync the model directory.
# NOTE(review): sync_files presumably copies this local directory to every
# worker -- confirm against the dask_saturn documentation.
client.register_worker_plugin(RegisterFiles())
sync_files(client, "/home/jovyan/project/examples/model")
# List ./model on the workers; run_batch_to_s3 loads ./model/modeltrained.pt.
client.run(os.listdir, "./model")

## Run Inference on Cluster

In [None]:
%%time

# Map run_batch_to_s3 over the delayed batches on the cluster, then pull the
# mapped values back into the notebook.
futures = client.map(run_batch_to_s3, image_batches)
futures_gathered = client.gather(futures)
# NOTE(review): run_batch_to_s3 is wrapped in dask.delayed, so the gathered
# values are presumably Delayed objects that this call then schedules -- confirm.
futures_computed = client.compute(futures_gathered, sync=False)

import logging

# Collect outcomes batch by batch. A failing batch is logged and kept in
# `errors` instead of aborting the whole run.
results = []
errors = []
for fut in futures_computed:
    try:
        result = fut.result()
    except Exception as e:
        errors.append(e)
        logging.error(e)
    else:
        results.extend(result)

Errors around inability to recognize or read the image may be a result of expired pre-signed links.

## Review Results

In [None]:
# One row per outcome dict collected from the inference run.
df2 = pd.DataFrame(results)
# Show the column dtypes as the cell output.
df2.dtypes

In [None]:
# Preview the first rows of the results frame.
df2.head()

### Check Quality of Inference

This calculation just tells you what percent of your model's predictions were correct.

In [None]:
# Split the outcomes by whether the prediction matched the ground truth.
true_preds = [x["evaluation"] for x in results if x["evaluation"] == True]  # noqa: E712
false_preds = [x["evaluation"] for x in results if x["evaluation"] == False]  # noqa: E712
# Percentage of correct predictions. With no results at all (for example when
# every batch failed) this shows NaN instead of raising ZeroDivisionError.
len(true_preds) / len(results) * 100 if results else float("nan")

## Visualize

This section will show some samples of predictions and contrast with the ground truth.

In [None]:
# Compute every reformatted batch; each one is a list of per-field columns.
sample = dask.compute(*image_batches)
# Regroup: one list per field, holding that field's column from every batch.
s5 = [list(field) for field in zip(*sample)]

# Flatten the per-batch columns of the first three fields (names, image
# tensors, ground-truth labels) into one entry per image.
test_names = [name for group in s5[0] for name in group]
test_tensors = [tensor for group in s5[1] for tensor in group]
test_orig = [label for group in s5[2] for label in group]
# (name, tensor, ground truth) triples, one per image.
test_final = list(zip(test_names, test_tensors, test_orig))

In [None]:
# Pair every outcome with the (name, tensor, ground truth) entry of the same
# image. Every test_final entry takes its name from test_names, so equality on
# the name already implies both names are in that list; separate membership
# tests would only re-scan test_names for each candidate pair.
expanded_list = [(i, j) for i in results for j in test_final if i["name"] == j[0]]

In [None]:
# noqa: W291
import matplotlib.pyplot as plt

cpudevice = torch.device("cpu")

to_pil = transforms.ToPILImage()
# Up to five (outcome, (name, tensor, ground truth)) pairs to display.
imglist = expanded_list[325:330]
f, ax = plt.subplots(nrows=1, ncols=5, figsize=(16, 6))

for i in range(0, 5):
    if i >= len(imglist):
        # The slice produced fewer than five samples: blank the spare panel
        # instead of failing with an IndexError.
        ax[i].set_axis_off()
        continue

    outcome, entry = imglist[i]
    # entry[1] is the image tensor; bring it to the CPU before converting.
    img1 = to_pil(entry[1].to(cpudevice))
    ax[i].imshow(img1)
    ax[i].xaxis.set_visible(False)
    ax[i].yaxis.set_visible(False)
    # Green title for a correct prediction, red otherwise.
    textcol = "green" if outcome["evaluation"] == True else "red"  # noqa: E712
    ax[i].set_title(
        f"Predicted Class: {outcome['prediction_text']} \n"
        f"    Actual Class: {outcome['ground_truth']} ",
        color=textcol,
    )

title = "Sample Images"
f.suptitle(title, fontsize=16)
plt.tight_layout()
plt.show()

## Load Results to Snowflake

Populate a temp table, update the permanent table, then remove the temp table.

In [None]:
# (Re)creates the staging table clothing_temp that receives the predictions.
make_table = """
    CREATE OR REPLACE TABLE clothing_temp
    (
      FILE_URL VARCHAR,
      SIZE NUMBER,
      LAST_MODIFIED TIMESTAMP_LTZ,
      TYPE VARCHAR,
      CONFIDENCE FLOAT8,
      PRED_TIMESTAMP TIMESTAMP_LTZ
    )
    """

# Lists the tables in CLOTHING.PUBLIC; the load cell below does not execute it.
check_library = "show tables in CLOTHING.PUBLIC"

# Copies TYPE, CONFIDENCE and PRED_TIMESTAMP from clothing_temp into
# clothing_test for rows that agree on both FILE_URL and SIZE.
update_query = """
    update clothing_test
      set clothing_test.TYPE = clothing_temp.TYPE,
          clothing_test.CONFIDENCE = clothing_temp.CONFIDENCE,
          clothing_test.PRED_TIMESTAMP = clothing_temp.PRED_TIMESTAMP
      from clothing_temp
      where clothing_test.FILE_URL = clothing_temp.FILE_URL
      and  clothing_test.SIZE = clothing_temp.SIZE
"""

# Drops the staging table.
clean_house = "drop table if exists clothing_temp"

In [None]:
with snowflake.connector.connect(**conn_kwargs) as conn:
    cur = conn.cursor()
    try:
        cur.execute(make_table)
        print("Temp table created.")

        # Select the prediction columns, rename them to the staging table's
        # column names, and stamp every row with the current time.
        # Timestamp.now(tz="UTC") is a true UTC instant; labelling the naive
        # local time from datetime.now() as UTC would be off by the UTC offset.
        snow_df = (
            df2[["snow_path", "filesize", "orig_timestamp", "prediction_text", "prediction_prob"]]
            .rename(
                columns={
                    "snow_path": "FILE_URL",
                    "filesize": "SIZE",
                    "orig_timestamp": "LAST_MODIFIED",
                    "prediction_text": "TYPE",
                    "prediction_prob": "CONFIDENCE",
                }
            )
            .assign(PRED_TIMESTAMP=pd.Timestamp.now(tz="UTC"))
        )

        success, nchunks, nrows, _ = write_pandas(conn, snow_df, "CLOTHING_TEMP")
        print(f"Temp results table created: {success}. Rows inserted in table: {nrows}.")
        res = cur.execute(update_query)
        print(f"Updated {res.rowcount} rows in permanent table from temp source.")
    finally:
        # Always drop the staging table, even when an earlier step failed,
        # and release the cursor afterwards.
        try:
            cur.execute(clean_house)
            print("Temp table removed.")
        finally:
            cur.close()