# Installation

To use the Azure Storage Connector for PyTorch, we can install it with `pip`:

In [None]:
%pip install azstoragetorch



And we can confirm `azstoragetorch` is installed by importing it:

In [None]:
import azstoragetorch
print(azstoragetorch)
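
Printing the module shows where it was imported from. To also see which release is installed, a minimal sketch using the standard library's `importlib.metadata` could look like the following. The helper name `installed_version` is hypothetical and not part of `azstoragetorch`.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(distribution_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return version(distribution_name)
    except PackageNotFoundError:
        return None


# Prints a version string if the package is installed, otherwise None.
print(installed_version("azstoragetorch"))
```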

Let's also install some other packages we'll need for later demos:

In [None]:
%pip install Pillow torchvision

# Bootstrap

Before running the rest of this notebook, run the cell below to bootstrap the resources it needs. Make sure to replace `<replace-account-name>` with the Azure Storage account name you want to use for this notebook.

Running the `bootstrap.py` script will:
* Create a container named `azstoragetorchintro`
* Create a local directory `local-models` with a ResNet-18 model
* Upload the ResNet-18 model to the `azstoragetorchintro` container
* Upload the Caltech 101 dataset to the `azstoragetorchintro` container

In [None]:
%run bootstrap.py "https://<replace-account-name>.blob.core.windows.net"

Copy the `CONTAINER_URL` output value from the previous cell into the `CONTAINER_URL` value in the cell below:

In [None]:
CONTAINER_URL = "https://<replace-account-name>.blob.core.windows.net/azstoragetorchintro"
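
Blob URLs in this notebook are formed by appending a blob name to `CONTAINER_URL`, so a leftover placeholder in the container URL only surfaces later as a failed request. As a rough sketch, a small helper can catch that early. The name `blob_url` and the account name `myaccount` are made up for illustration and are not part of `azstoragetorch`.

```python
PLACEHOLDER = "<replace-account-name>"


def blob_url(container_url: str, blob_name: str) -> str:
    """Join a container URL and a blob name into a blob URL (hypothetical helper)."""
    if PLACEHOLDER in container_url:
        raise ValueError("Replace the account name placeholder in the container URL first.")
    # Normalize slashes so exactly one "/" separates the container and the blob name.
    return f"{container_url.rstrip('/')}/{blob_name.lstrip('/')}"


# "myaccount" is an illustrative account name, not a real one.
example_container_url = "https://myaccount.blob.core.windows.net/azstoragetorchintro"
print(blob_url(example_container_url, "models/resnet18_weights.pth"))
```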

# Loading and saving PyTorch models


## Loading a model

The core interface for loading a PyTorch model is the `torch.load()` function. Say we have model weights stored in the local directory `local-models`. We can load them with `torch.load()` by passing in either the name of the file or a file-like object from `open()`:

In [None]:
import torch

# Load from string of filename 
state_dict = torch.load("local-models/resnet18_weights.pth", weights_only=True)

# Load from file-like object
with open("local-models/resnet18_weights.pth", "rb") as f:
    state_dict = torch.load(f, weights_only=True)

print(state_dict)


These can then be loaded directly into the model:

In [None]:
import torchvision.models

resnet_model = torchvision.models.resnet18()
resnet_model.load_state_dict(state_dict)

`azstoragetorch` offers the `BlobIO` file-like object to easily load the model weights from a blob in Azure Blob Storage as if you were loading the models locally from disk. Just provide the URL to the blob and `rb` as the mode (just like you would for `open()`):

In [None]:
from azstoragetorch.io import BlobIO

with BlobIO(f"{CONTAINER_URL}/models/resnet18_weights.pth", "rb") as f:
    state_dict = torch.load(f, weights_only=True)
    print(state_dict)

## Saving a model

PyTorch offers the `torch.save()` function for saving a model. It allows you to save models locally:

In [None]:
# Save by filename
torch.save(resnet_model.state_dict(), "local-models/resnet18_weights_saved.pth")

# Or save using file-like object
with open("local-models/resnet18_weights_saved_by_filelike.pth", "wb") as f:
    torch.save(resnet_model.state_dict(), f)

And we can see the locally saved copies of the weights:

In [None]:
import os
os.listdir("local-models")

To upload the weights to Azure Blob Storage, we can use `BlobIO` again but this time in write mode (i.e., `wb`):

In [None]:
with BlobIO(f"{CONTAINER_URL}/models/resnet18_weights_saved.pth", "wb") as f:
    torch.save(resnet_model.state_dict(), f)

# PyTorch datasets

`azstoragetorch` offers a map-style dataset, `BlobDataset`, and an iterable-style dataset, `IterableBlobDataset`. To instantiate a dataset, use one of their class methods. For example, use `from_container_url()` to build the dataset by listing blobs in an Azure Storage container:

In [None]:
from azstoragetorch.datasets import BlobDataset

dataset = BlobDataset.from_container_url(CONTAINER_URL, prefix="datasets/caltech101")
print(len(dataset))
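
The difference between the two dataset styles comes down to which Python protocol each one supports. The sketch below uses plain Python classes with made-up names (`MapStyleNames`, `IterableStyleNames`) over a list of blob names. It is only an illustration of the two protocols, not how `azstoragetorch` implements its datasets.

```python
from typing import Iterator, List


class MapStyleNames:
    """Map-style: supports len() and random access by integer index."""

    def __init__(self, names: List[str]) -> None:
        self._names = list(names)

    def __len__(self) -> int:
        return len(self._names)

    def __getitem__(self, index: int) -> str:
        return self._names[index]


class IterableStyleNames:
    """Iterable-style: only supports sequential iteration."""

    def __init__(self, names: List[str]) -> None:
        self._names = list(names)

    def __iter__(self) -> Iterator[str]:
        yield from self._names


# Illustrative blob names only.
names = ["a/image_0001.jpg", "a/image_0002.jpg", "b/image_0001.jpg"]
map_style = MapStyleNames(names)
iterable_style = IterableStyleNames(names)

print(len(map_style), map_style[2])  # length and random access
print(next(iter(iterable_style)))    # samples arrive in order, one at a time
```

This is why `len(dataset)` and `dataset[index]` work on a map-style dataset, while an iterable-style dataset is consumed with `iter()` and `next()` or a `for` loop.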

Data samples in the dataset map directly to blobs in the container. By default, each sample returned from a dataset is a dictionary representation of the blob. For example, we can access an arbitrary sample from our map-style dataset:

In [None]:
sample = dataset[4827]
print(sample)

And the `data` of the sample can be rendered into an image:

In [None]:
from PIL import Image
import io

img = Image.open(io.BytesIO(sample["data"]))
display(img)

Furthermore, these datasets can be provided directly to a PyTorch `DataLoader`:

In [None]:
from torch.utils.data import DataLoader

loader = DataLoader(dataset)
for batch in loader:
    print(batch)
    break

However, this is likely not the format that a PyTorch model expects. Specifically, a model will want its input as a `torch.Tensor`. Samples can be converted using a `transform` callable:

In [None]:
import torchvision.transforms
from PIL import Image

# Preprocessing based on the recommendation from PyTorch:
# https://pytorch.org/hub/pytorch_vision_resnet/
# Built once so the pipeline is not recreated for every blob.
IMG_TRANSFORM = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def blob_to_category_and_tensor(blob):
    with blob.reader() as f:
        img = Image.open(f).convert("RGB")
        img_tensor = IMG_TRANSFORM(img)
    # Get second to last component of blob name which will be the image category. For example:
    # blob.blob_name -> datasets/caltech101/dalmatian/image_0001.jpg
    # category -> dalmatian
    category = blob.blob_name.split("/")[-2]
    return category, img_tensor
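
The category lookup in the transform depends only on the layout of the blob name, so it can be checked without touching storage. A minimal sketch, assuming blob names follow the `<prefix>/<category>/<file>` layout described in the comments above (`category_from_blob_name` is a hypothetical helper):

```python
def category_from_blob_name(blob_name: str) -> str:
    """Return the second-to-last path component of a blob name."""
    parts = blob_name.split("/")
    if len(parts) < 2:
        # A name such as "image_0001.jpg" has no parent directory to use as a category.
        raise ValueError(f"Blob name has no parent directory: {blob_name!r}")
    return parts[-2]


print(category_from_blob_name("datasets/caltech101/dalmatian/image_0001.jpg"))  # dalmatian
```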
    

We can now provide this transform to the dataset:

In [None]:
from azstoragetorch.datasets import IterableBlobDataset

iterable_dataset = IterableBlobDataset.from_container_url(
    CONTAINER_URL,
    prefix="datasets/caltech101/dalmatian/",
    transform=blob_to_category_and_tensor
)
print(next(iter(iterable_dataset)))

We can run the ResNet-18 model from before in `eval()` mode to double-check our transformations:

In [None]:
from torchvision.models import ResNet18_Weights
CATEGORIES = ResNet18_Weights.DEFAULT.meta["categories"]

loader = DataLoader(iterable_dataset, batch_size=32)

resnet_model.eval()
for _, img_tensors in loader:
    # Output tensor of confidence scores across each image for each supported category
    output = resnet_model(img_tensors)
    # Retrieve highest value index where indexes map to category ids
    category_ids = torch.argmax(output, dim=1)
    # Print human readable category (e.g. "dalmatian") for index with highest value
    print([CATEGORIES[category_id] for category_id in category_ids])
    break
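
To make the argmax step concrete, here is a small sketch that uses plain Python lists in place of tensors. Each inner list stands in for one image's scores, and the index of the highest score selects the category name, which is the same row-wise selection that `torch.argmax(output, dim=1)` performs. The scores, the category names, and the helper `top_categories` are invented for illustration.

```python
from typing import List, Sequence


def top_categories(scores_per_image: Sequence[Sequence[float]], categories: Sequence[str]) -> List[str]:
    """Map each row of scores to the category name with the highest score."""
    results = []
    for scores in scores_per_image:
        # Index of the largest score in this row.
        best_index = max(range(len(scores)), key=scores.__getitem__)
        results.append(categories[best_index])
    return results


example_categories = ["tabby", "dalmatian", "zebra"]
example_scores = [
    [0.1, 2.5, -1.0],  # highest score at index 1
    [3.0, 0.2, 0.1],   # highest score at index 0
]
print(top_categories(example_scores, example_categories))  # ['dalmatian', 'tabby']
```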

# Cleanup

Run the `cleanup.py` script to clean up all resources created by this notebook. Make sure to replace `<replace-account-name-from-bootstrap>` with the Azure Storage account name you specified in the Bootstrap section.

Running the script will:
* Delete the `azstoragetorchintro` container and all blobs in it
* Delete the `local-models` directory

In [None]:
%run cleanup.py "https://<replace-account-name-from-bootstrap>.blob.core.windows.net"