In [2]:
!pip install -Uqq sagemaker transformers[torch] pytorch-lightning datasets wandb


ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
distributed 2022.7.0 requires tornado<6.2,>=6.0.3, but you have tornado 6.3.3 which is incompatible.

In [3]:
import sagemaker

# S3 key prefix under which the Food-101 data is staged.
prefix = "sagemaker/food101"

# Session wrapping the AWS clients used by the cells below; region and default bucket come from it.
sagemaker_session = sagemaker.Session()
region, bucket = sagemaker_session.boto_region_name, sagemaker_session.default_bucket()

# IAM execution role handed to the training estimator.
role = sagemaker.get_execution_role()

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml


In [5]:
!pip install torchvision

Collecting torchvision
  Obtaining dependency information for torchvision from https://files.pythonhosted.org/packages/84/eb/4f6483ae9094e164dc5b9b792e377f7d37823b0bedc3eef3193d416d2bb6/torchvision-0.16.0-cp310-cp310-manylinux1_x86_64.whl.metadata
  Downloading torchvision-0.16.0-cp310-cp310-manylinux1_x86_64.whl.metadata (6.6 kB)
Downloading torchvision-0.16.0-cp310-cp310-manylinux1_x86_64.whl (6.9 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.9/6.9 MB 33.7 MB/s eta 0:00:00
Installing collected packages: torchvision
Successfully installed torchvision-0.16.0

In [None]:
import os
from getpass import getpass

# Toggle Weights & Biases experiment tracking for the rest of the notebook.
wandb_logging = True

if wandb_logging:
    # Prefer a key already exported in the environment (lets "Run All" work non-interactively);
    # otherwise prompt without echoing, so the secret never lands in the saved notebook.
    wandb_api_key = os.environ.get("WANDB_API_KEY") or getpass("Copy your WANDB_API_KEY: ")

In [None]:
import wandb
from pytorch_lightning.loggers import WandbLogger

# Run metadata and checkpoint-logging options for the Lightning W&B logger.
wandb_logger_kwargs = dict(
    project="Foodformer",
    name="VisionTransformer-base",
    checkpoint_name="vit",
    log_model=True,
    save_dir=".",
)

if not wandb_logging:
    logger = None
else:
    # relogin=True makes wandb log in again with this key even if credentials are already cached.
    wandb.login(key=wandb_api_key, relogin=True)
    logger = WandbLogger(**wandb_logger_kwargs)

In [None]:
from functools import partial

from torchvision.datasets import Food101
from transformers import ViTImageProcessor

model_name_or_path = "google/vit-base-patch16-224-in21k"
DATASET_ROOT = "food101_dataset"

# ViTImageProcessor is the non-deprecated replacement for ViTFeatureExtractor (same preprocessing).
# The variable name is kept so downstream code referring to `feature_extractor` keeps working.
feature_extractor = ViTImageProcessor.from_pretrained(model_name_or_path)

# Food101 calls `transform` on each image; ask the processor for PyTorch tensors.
preprocessor = partial(feature_extractor, return_tensors="pt")

# download=True fetches + extracts the archive into DATASET_ROOT; it is skipped when the files already exist,
# so the test split can reuse the same download.
train_ds = Food101(root=DATASET_ROOT, split="train", transform=preprocessor, download=True)
test_ds = Food101(root=DATASET_ROOT, split="test", transform=preprocessor)

labels = train_ds.classes



Downloading https://data.vision.ee.ethz.ch/cvl/food-101.tar.gz to food101_dataset/food-101.tar.gz


100%|██████████| 4996278331/4996278331 [04:38<00:00, 17948913.66it/s]


Extracting food101_dataset/food-101.tar.gz to food101_dataset


In [None]:
import torch

names = ["train", "test"]
datasets = [train_ds, test_ds]

if wandb_logging:
    # 🏺 create our Artifact
    raw_data = wandb.Artifact(
        "food101-train",
        type="dataset",
        description="Custom Food101 dataset, split into train/test",
        metadata={"sizes": [len(dataset) for dataset in datasets]},
    )

    # Serialize each split's Dataset object (via torch.save / pickle) into the artifact.
    for name, data in zip(names, datasets):
        with raw_data.new_file(name + ".pt", mode="wb") as fs:
            torch.save(data, fs)

    # ✍️ Save the artifact to W&B.
    # WandbLogger starts its run lazily, so `wandb.run` is still None here; accessing
    # `logger.experiment` initializes (or returns) the run before logging.
    logger.experiment.log_artifact(raw_data)

In [None]:
# Mirror the local dataset directory to s3://<bucket>/<prefix>; the returned S3 URI is the training input.
# NOTE(review): this copies everything under ./food101_dataset, including the downloaded
# food-101.tar.gz next to the extracted images — consider uploading only the food-101/
# folder if train.py does not need the archive.
inputs = sagemaker_session.upload_data(
    path="./food101_dataset",
    bucket=bucket,
    key_prefix=prefix,
)
print(f"input spec (in this case, just an S3 path): {inputs}")

In [None]:
from sagemaker.pytorch import PyTorch

# Reuse the session, role, default bucket and prefix configured at the top of the notebook instead of
# creating a second Session and writing model artifacts to a hard-coded, account-specific bucket.
pytorch_estimator = PyTorch(
    entry_point="train.py",
    instance_type="ml.g4dn.2xlarge",
    instance_count=1,
    framework_version="2.0.0",
    py_version="py310",
    output_path=f"s3://{bucket}/{prefix}/output",
    role=role,
    sagemaker_session=sagemaker_session,
    # requirements.txt is bundled next to train.py in the uploaded code archive
    # (presumably installed by the training container — confirm in the job logs).
    # Avoid source_dir=".": it would also bundle the local food101_dataset/ directory.
    dependencies=["requirements.txt"],
)

In [None]:
# Launch the training job; the "train" channel points at the S3 prefix the upload cell populated.
training_channels = {"train": f"s3://{bucket}/{prefix}"}
pytorch_estimator.fit(inputs=training_channels, wait=True)

In [None]:
# Host the trained model on a real-time endpoint; the GPU instance keeps running until the endpoint is deleted.
predictor = pytorch_estimator.deploy(
    instance_type="ml.g4dn.2xlarge",
    initial_instance_count=1,
)

In [None]:
# Tear down everything deploy() created. Session.delete_endpoint alone would leave the endpoint
# config and the model behind. Delete the model first: its name is looked up via the endpoint
# config, which delete_endpoint() removes by default.
predictor.delete_model()
predictor.delete_endpoint()

In [None]:
# Close the active W&B run, if any, so remaining data is uploaded.
if wandb.run is not None:
    wandb.finish()