In [None]:
# Upgrade the SageMaker Python SDK in the environment of the running kernel
# (%pip, unlike !pip, installs into the kernel's own interpreter).
# NOTE(review): unpinned — for reproducible runs consider pinning a version,
# e.g. `%pip install -q "sagemaker==X.Y.Z"`; if sagemaker was already imported
# in this kernel, restart it so the upgraded version is picked up.
%pip install --upgrade sagemaker

In [None]:
import sagemaker
import os
import boto3
from botocore.config import Config
from sagemaker import get_execution_role
from sagemaker.tensorflow import TensorFlow

# --- AWS context ---------------------------------------------------------------
# IAM role that the training job will assume.
role = get_execution_role()
sess = sagemaker.Session()
# Take the region and the S3 client from the SageMaker session's own boto session
# so that the bucket, the ECR image region and the S3 uploads all come from a
# single credentials/region configuration instead of separately-created sessions.
region = sess.boto_region_name
bucket = sess.default_bucket()
s3_client = sess.boto_session.client('s3')

# --- Dataset location ----------------------------------------------------------
# S3 key prefix under `bucket` where the dataset is uploaded; the same prefix is
# later passed to the training job as its "train" input.
prefix = 'octo/datasets/aloha_sim_dataset'

# --- Training job configuration -------------------------------------------------
base_job_name = 'octo-flax'
py_version = "py310"  # keep consistent with the py310 tag in image_uri below
# The current notebook directory is passed as source_dir; it must contain entry_point.
source_dir = os.getcwd()
entry_point = "02_finetune_new_observation_action.py"
instance_type = "ml.g5.48xlarge"
instance_count = 1
# SageMaker warm-pool keep-alive: how long the instance is retained after the
# job ends so a follow-up job can reuse it.
keep_alive_period_in_seconds = 300
# NOTE(review): 763104351884 is the AWS Deep Learning Containers ECR account used
# by most commercial regions; some regions use a different account — confirm for
# your region (or build the URI with sagemaker.image_uris.retrieve).
image_uri = f"763104351884.dkr.ecr.{region}.amazonaws.com/tensorflow-training:2.14.1-gpu-py310-cu118-ubuntu20.04-sagemaker"

## Upload the Octo finetuning dataset to Amazon S3

In [None]:
!curl https://rail.eecs.berkeley.edu/datasets/example_sim_data.zip -O
!unzip example_sim_data.zip

In [None]:
# Mirror every file under the extracted dataset directory to
# s3://{bucket}/{prefix}/<relative path>, skipping keys that already exist so the
# cell can be re-run after an interruption without re-uploading everything.
# Note: the skip is by key only; a changed local file with an existing key is not
# re-uploaded.
import posixpath

from botocore.exceptions import ClientError

local_directory = f"{os.getcwd()}/aloha_sim_dataset"

# os.walk yields nothing for a missing directory, which would leave the S3 prefix
# empty and only surface later as a confusing training failure.
if not os.path.isdir(local_directory):
    raise FileNotFoundError(
        f"{local_directory} does not exist; run the download/unzip cell first."
    )

for root, _dirs, files in os.walk(local_directory):
    for filename in files:
        local_path = os.path.join(root, filename)
        relative_path = os.path.relpath(local_path, local_directory)
        # S3 keys always use "/" as the separator, whatever the local OS uses.
        s3_path = posixpath.join(prefix, relative_path.replace(os.sep, "/"))

        print('Searching "%s" in "%s"' % (s3_path, bucket))
        try:
            s3_client.head_object(Bucket=bucket, Key=s3_path)
        except ClientError as err:
            # Only "object not found" means "upload it". Anything else (e.g. 403
            # access denied, throttling) is re-raised instead of being masked.
            # NOTE(review): without s3:ListBucket, S3 reports missing keys as 403;
            # the execution role is assumed to have it on the default bucket.
            error_code = err.response.get("Error", {}).get("Code")
            if error_code not in ("404", "NoSuchKey", "NotFound"):
                raise
            print("Uploading %s..." % s3_path)
            s3_client.upload_file(local_path, bucket, s3_path)
        else:
            print(f"Path found on S3! Skipping {s3_path}...")

In [None]:
!rm example_sim_data.zip
!rm -fr aloha_sim_dataset

## Finetune Octo on Amazon SageMaker

In [None]:
# Configure the SageMaker training job. Passing image_uri pins the exact
# container image, so SageMaker uses it instead of selecting one itself;
# entry_point inside source_dir is the script executed in that container.
estimator_config = dict(
    role=role,
    instance_count=instance_count,
    base_job_name=base_job_name,
    image_uri=image_uri,
    py_version=py_version,
    source_dir=source_dir,
    entry_point=entry_point,
    instance_type=instance_type,
    keep_alive_period_in_seconds=keep_alive_period_in_seconds,
)
flax_estimator = TensorFlow(**estimator_config)

# Expose the uploaded dataset prefix to the job as the "train" input channel.
# NOTE(review): presumably the entry point reads it from SM_CHANNEL_TRAIN — confirm.
train_channel_uri = f"s3://{bucket}/{prefix}"
flax_estimator.fit(inputs={'train': train_channel_uri})