In [None]:
# Import python packages
import streamlit as st
import pandas as pd
from snowflake.ml.ray.datasource import SFStageImageDataSource, SFStageTextDataSource


# We can also use Snowpark for our analyses!
from snowflake.snowpark.context import get_active_session
session = get_active_session()


### Create a Data Source to read unstructured data

In [None]:
# reading the image, resize the image to 256 x 256 to lower memory requirement and help performance
image_source = SFStageImageDataSource(
    stage_location = "@DATA_STAGE_RAY/images/",
    database = "ST_DB",
    schema = "ST_SCHEMA",
    image_size=(256, 256),
)

In [None]:
# reading the label
# data is loaded after these two steps
# create pointer to data in stages
label_source = SFStageTextDataSource(
    stage_location = "@DATA_STAGE_RAY/labels/",
    database = "ST_DB",
    schema = "ST_SCHEMA",
)

### Load into a ray dataset

In [None]:
import ray

# everything is lazy loaded
# turns data source into a dataset
image_ds = ray.data.read_datasource(image_source)

In [None]:
print(image_ds.schema())

In [None]:
print(f'Total load {image_ds.count()} images')
image_ds.show(2)

In [None]:
# read the dataset, use 8 workers to read from the stage concurrently
label_ds = ray.data.read_datasource(label_source, concurrency=8)

In [None]:
print(label_ds.schema())

In [None]:
label_ds.show(1)

### Batch-process both datasets to add columns
**Image Dataset**: add a join key, encode the images, standardize the image
**Label Dataset**: add a join key, interpret the labels

Image standardization

In [None]:
import numpy as np
from typing import Dict
import base64
import os

def process_image(row):
    """Prepare a single image record for joining and storage.

    A 2-D (grayscale) ``row['image']`` is replaced by a 3-channel array made
    of three copies of the original. The resulting array is base64-encoded
    into ``row['encoded_image']``, and ``row['join_id']`` is set to the last
    '/'-separated component of ``row['file_name']`` with its extension removed.

    The incoming ``row`` mapping is modified in place and returned.
    """
    pixels = row['image']
    is_grayscale = len(pixels.shape) == 2
    if is_grayscale:
        # Repeat the single channel three times along a new trailing axis.
        row['image'] = np.stack([pixels, pixels, pixels], axis=-1)

    row['encoded_image'] = base64.b64encode(row['image'])

    # Drop the extension first, then keep only the final path component.
    path_without_ext, _ = os.path.splitext(row['file_name'])
    row['join_id'] = path_without_ext.split('/')[-1]
    return row

# processed_image_ds = image_ds.map_batches(convert_to_torch, concurrency=4)
processed_image_ds = image_ds.map(process_image)

In [None]:
# force trigger operation for 1 image
processed_image_ds.show(1)

In [None]:
import os

def expand_label_column(batch: pd.DataFrame) -> pd.DataFrame:
    """Split each row's whitespace-separated ``text`` into typed label columns.

    Every ``text`` value must contain exactly five tokens: four floats
    (xmin, ymin, xmax, ymax) followed by an integer class id.

    Args:
        batch: Frame with a ``text`` column and a ``file_name`` column.

    Returns:
        A new frame with columns ``join_id``, ``file_name``, ``xmin``,
        ``ymin``, ``xmax``, ``ymax`` and ``class``. ``join_id`` is the file's
        base name without extension, with ``'_test'`` appended.

    Raises:
        ValueError: If a ``text`` value does not split into five tokens, or a
            token cannot be converted to the expected numeric type.
    """
    column_order = ('join_id', 'file_name', 'xmin', 'ymin', 'xmax', 'ymax', 'class')
    parsed = {name: [] for name in column_order}

    for _, record in batch.iterrows():
        fields = record['text'].strip().split()

        # Reject malformed label lines before converting anything.
        if len(fields) != 5:
            raise ValueError(f"Expected 5 values in text, but got {len(fields)} values")

        xmin, ymin, xmax, ymax = (float(token) for token in fields[:4])
        class_id = int(fields[4])

        source_path = record['file_name']
        stem = os.path.splitext(source_path)[0].split('/')[-1]

        parsed['join_id'].append(stem + '_test')
        parsed['file_name'].append(source_path)
        parsed['xmin'].append(xmin)
        parsed['ymin'].append(ymin)
        parsed['xmax'].append(xmax)
        parsed['ymax'].append(ymax)
        parsed['class'].append(class_id)

    # Dict insertion order fixes the output column order.
    return pd.DataFrame(parsed)

processed_label_ds = label_ds.map_batches(expand_label_column, concurrency=6, batch_format='pandas')

In [None]:
processed_label_ds.show(1)

### Merge image source and label source into a single dataset
We have two ways of achieving this: 1) if the customer is more familiar with `pandas.DataFrame` and the data fits into memory, we can convert all the data into pandas (or write it into Snowflake) and do the rest of the operations there. 2) If the data does not fit into memory, we can use the Ray dataset directly to do the processing.

**Note**: Ray datasets are not naturally architected to support join operations, so it is better to use another method (in-memory / Snowflake) to perform joins.

#### Let's start with method 2, just to show that it is possible to do joins with Ray as well

In [None]:
# Join the processed image dataset with the processed label dataset on the
# shared `join_id` column, staying inside Ray (no conversion to pandas).
# The original author's note: Ray dataset did not offer a JOIN between two
# datasets at the time, so a utility function is used for joins in the
# container runtime. This cell was marked "skip this".
# NOTE(review): `optimized_dataset_join` is not defined or imported anywhere in
# the visible notebook, so on a fresh kernel this cell raises NameError unless
# the helper is provided elsewhere — confirm its origin or remove the cell.

joined_ds = optimized_dataset_join(
    processed_image_ds,
    processed_label_ds,
    left_on='join_id',
    right_on='join_id'
)

In [None]:
# This call currently is slow
# skip this

joined_ds.show(1)

#### Method 1: convert both dataset into pandas and perform joins

In [None]:
# show how to convert a ray dataset to a panda dataframe
image_df = processed_image_ds.drop_columns(cols=['image']).to_pandas()

In [None]:
# pandas - return first 5 rows
image_df.head()

In [None]:
label_df = processed_label_ds.to_pandas()

In [None]:
label_df.head()

In [None]:
# perform merge 
merged_train_df = pd.merge(image_df, label_df, how='inner', on='join_id')


In [None]:
merged_train_df.head()

## Save the Transformed Dataset to a snowflake table
Customers may also easily save the processed image dataset and label dataset into Snowflake.

In [None]:
from snowflake.ml.ray.datasink import SnowflakeTableDatasink

# Point the active session at the role, database and schema used for the
# tables written below.
# NOTE(review): ACCOUNTADMIN is presumably far more privileged than a table
# write requires — confirm whether a narrower role can be used here.
session.use_role(role="ACCOUNTADMIN")
session.use_database(database="ST_DB")
session.use_schema(schema="ST_SCHEMA")

# Sink describing the target table for the processed image dataset.
# NOTE(review): `auto_create_table` / `override` presumably create the table
# when missing and replace existing contents — verify against the datasink docs.
table_to_save = "RAY_DEMO_JAN21_IMAGE_DS"
datasink = SnowflakeTableDatasink(
    table_name=table_to_save,
    database = "ST_DB",
    schema = "ST_SCHEMA",
    auto_create_table=True,
    override=True,
)

In [None]:
processed_image_ds.drop_columns(cols=['image']).write_datasink(datasink, concurrency=4)

In [None]:
SELECT * FROM RAY_DEMO_JAN21_IMAGE_DS;

In [None]:
table_to_save = "RAY_DEMO_JAN21_LABEL_DS"
datasink = SnowflakeTableDatasink(
    table_name=table_to_save,
    database = "ST_DB",
    schema = "ST_SCHEMA",
    auto_create_table=True,
    override=True,
)
processed_label_ds.write_datasink(datasink, concurrency=4)

In [None]:
SELECT * FROM RAY_DEMO_JAN21_LABEL_DS;

In [None]:
# Save the combined (image + label) data. `merged_train_df` is the pandas frame
# produced by the inner merge on `join_id`; it is converted back into a Ray
# dataset so it can be written through the same kind of table datasink as the
# image and label tables.
# NOTE(review): the table name spelling ("COMINED") is kept unchanged in case
# anything outside this notebook already refers to it.
table_to_save = "RAY_DEMO_JAN21_COMINED_DS"
datasink = SnowflakeTableDatasink(
    table_name=table_to_save,
    database = "ST_DB",
    schema = "ST_SCHEMA",
    auto_create_table=True,
    override=True,
)
# Fix: this cell previously wrote `processed_label_ds` a second time, so the
# "combined" table was just a copy of the label table.
combined_ds = ray.data.from_pandas(merged_train_df)
combined_ds.write_datasink(datasink, concurrency=4)

### Continued: Train a Pytorch Model 

The following steps show how we can use the processed Snowflake tables to train a PyTorch model.

The idea is to take the processed data that we just generated with the Ray dataset and create a training loop that 1) defines the training hyperparameters, 2) defines the model architecture, and 3) trains the model in a distributed fashion using a multi-node cluster.

In [None]:

import torch 
from torchvision import models, transforms
from PIL import Image
import io
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from snowflake.ml.modeling.distributors.pytorch import get_context
from torch.utils.data import DataLoader

# ----------------------- 1. load data ---------------------------------------------

# First we load the data from snowflake: these are the image and label tables
# written by the datasink cells earlier in this notebook.
images = session.table("RAY_DEMO_JAN21_IMAGE_DS")
labels = session.table("RAY_DEMO_JAN21_LABEL_DS")

# Join images with their labels on the shared key; `joined` is later wrapped
# in a DataConnector and handed to the distributed trainer.
# NOTE(review): confirm "join_id" matches the column name/casing as stored in
# the auto-created tables.
joined = images.join(labels, on="join_id")

def train_func():
    """Training entry point run by the PyTorch distributor on each worker.

    Initializes an NCCL process group using the rank / world size reported by
    the distributor context, builds a ResNet-18 based detector that predicts
    one bounding box (4 values) and ``NUM_CLASSES`` class logits per image,
    wraps it in DistributedDataParallel, and trains it for ``EPOCHS`` epochs
    on this worker's shard of the ``'train'`` dataset.

    Takes no arguments and returns nothing; progress is printed.
    """
    # `dist` was never imported at notebook level, which made the original
    # fail with NameError at `dist.init_process_group`. Import it here so the
    # function resolves the name on its own.
    import torch.distributed as dist

    # ----------------------- 2. define training hyperparameters ---------------------------------
    NUM_CLASSES = 10
    BATCH_SIZE  = 32
    EPOCHS      = 3
    context = get_context()
    rank = context.get_rank()
    local_rank = context.get_local_rank()
    world_size = context.get_world_size()
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        rank=rank,
        world_size=world_size
    )
    # Bind this process to its own GPU before any CUDA allocation happens.
    torch.cuda.set_device(local_rank)
    DEVICE = torch.device(f"cuda:{local_rank}")
    train_ds = context.get_dataset_map()['train'].get_shard().to_torch_dataset()
    data_loader = DataLoader(
        train_ds,
        batch_size=BATCH_SIZE,
        shuffle=False,
        pin_memory=True,
    )

    tfm = transforms.Compose([transforms.ToTensor()])  # bytes -> CHW float-tensor

    def batch_to_torch(pdf):
        """pandas → dict(tensors)"""
        # NOTE(review): this expects an "IMG" column holding encoded image-file
        # bytes plus upper-case box/class columns. The tables written earlier
        # in this notebook store a base64 `encoded_image` of the raw array —
        # confirm the column names and decoding against the real table schema.
        imgs = torch.stack([
            tfm(Image.open(io.BytesIO(b))) for b in pdf["IMG"].values
        ])
        boxes = torch.tensor(pdf[["XMIN","YMIN","XMAX","YMAX"]].values, dtype=torch.float32)
        labels = torch.tensor(pdf["CLASS"].values, dtype=torch.long)
        return {"img": imgs, "box": boxes, "cls": labels}

    # ----------------------- 3. model --------------------------------------------
    class TinyDetector(nn.Module):
        """ResNet-18 backbone with a single linear head for box + class outputs."""

        def __init__(self, n_cls):
            super().__init__()
            self.backbone = models.resnet18(weights=None)
            feats = self.backbone.fc.in_features
            # Strip the ImageNet classifier so the backbone emits raw features.
            self.backbone.fc = nn.Identity()
            self.head = nn.Linear(feats, 4 + n_cls)

        def forward(self, x):
            # Fix: the head was defined but never applied, so the raw backbone
            # features were being sliced and the class output had the wrong
            # width (feats - 4 instead of n_cls).
            out = self.head(self.backbone(x))
            bbox, cls = out[:, :4], out[:, 4:]
            return bbox, cls

    net = TinyDetector(NUM_CLASSES).to(DEVICE)
    # Wrap in DDP so gradients are averaged across workers; without this each
    # worker would train an independent copy of the model.
    net = DDP(net, device_ids=[local_rank])
    opt = torch.optim.AdamW(net.parameters(), lr=1e-4)
    l_bbox = nn.SmoothL1Loss()
    l_cls  = nn.CrossEntropyLoss()

    # ----------------------- 4. training loop ------------------------------------
    loss = None  # stays None if the loader never yields a batch
    for ep in range(EPOCHS):
        for batch in data_loader:
            t = batch_to_torch(batch)
            x  = t["img"].to(DEVICE)
            yb = t["box"].to(DEVICE)
            yc = t["cls"].to(DEVICE)

            pb, pc = net(x)
            loss = l_bbox(pb, yb) + l_cls(pc, yc)

            opt.zero_grad()
            loss.backward()
            opt.step()

        if loss is None:
            # Avoid the NameError the original hit when a shard was empty.
            print(f"epoch {ep+1}/{EPOCHS}  no batches in this shard")
        else:
            print(f"epoch {ep+1}/{EPOCHS}  loss={loss.item():.4f}")

    print("✓ training done")

### Use Snowflake Distributed Pytorch API to Train model 
The following API supports training PyTorch models on any kind of cluster available in Snowflake, including multi-GPU or multi-node training.

In [None]:
# Import the Snowflake DataConnector and the distributed PyTorch trainer classes.
# Fix: the module path was misspelled as `data_conenctor`, which fails on import.
from snowflake.ml.data.data_connector import DataConnector
from snowflake.ml.modeling.distributors.pytorch import (
    PyTorchDistributor,
    PyTorchScalingConfig,
    WorkerResourceConfig,
)

# Wrap the joined Snowpark DataFrame so it can be handed to the trainer.
train_data_connector = DataConnector.from_dataframe(joined)

# Create the PyTorch distributor. This will run the training function on the
# specified number of nodes and workers. In this case it runs with 4 nodes and
# 1 worker per node; each worker has access to 6 CPUs and 1 GPU.
pytorch_trainer = PyTorchDistributor(
    train_func=train_func,
    scaling_config=PyTorchScalingConfig(
        num_nodes=4,
        num_workers_per_node=1,
        resource_requirements_per_worker=WorkerResourceConfig(num_cpus=6, num_gpus=1),
    )
)

# The key 'train' is the name train_func uses to look up its dataset shard.
pytorch_trainer.run(
    dataset_map={'train': train_data_connector}
)