In [1]:
import boto3
import os
from sagemaker import get_execution_role, Session
from sagemaker.pytorch import PyTorchModel

# Initialize a SageMaker session; it provides the default bucket and the
# upload helper used further down.
sagemaker_session = Session()

# IAM role passed to the SageMaker model. Overridable through the
# SAGEMAKER_ROLE_ARN environment variable so the notebook can run under a
# different account without editing code; the default is the original role.
role = os.environ.get(
    "SAGEMAKER_ROLE_ARN",
    "arn:aws:iam::016114370410:role/tf-binding-sites",
)

# S3 key prefix that receives the inference input files.
prefix = "tf-binding-sites/inference/input"

# Local path whose contents are uploaded as inference input. Overridable
# through TF_BINDING_JSONL_DIR because the default is specific to one machine.
local_dir = os.environ.get(
    "TF_BINDING_JSONL_DIR",
    "/Users/wejarrard/projects/tf-binding/data/jsonl",
)

# Initialize the S3 client
s3 = boto3.client('s3')

# Bucket used for the inputs: the session's default bucket.
bucket_name = sagemaker_session.default_bucket()

# Function to delete all objects in a specified S3 bucket/prefix
def delete_s3_objects(bucket_name, prefix):
    """Delete every object in ``bucket_name`` whose key starts with ``prefix``.

    Uses the module-level ``s3`` client. Listing goes through the
    ``list_objects_v2`` paginator because a single call returns at most
    1000 keys, which would leave larger prefixes only partly cleaned.

    Note that ``prefix`` is a plain string prefix: without a trailing slash
    it also matches sibling keys such as ``<prefix>_old/...``.

    Args:
        bucket_name: Name of the S3 bucket to clean.
        prefix: Key prefix selecting the objects to delete.
    """
    paginator = s3.get_paginator('list_objects_v2')
    deleted_count = 0
    for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix):
        # Pages without matches carry no 'Contents' key at all.
        for item in page.get('Contents', []):
            s3.delete_object(Bucket=bucket_name, Key=item['Key'])
            deleted_count += 1

    if deleted_count:
        print(f"Deleted all objects in {bucket_name}/{prefix}")
    else:
        print(f"No objects found in {bucket_name}/{prefix} to delete.")

# Fail before touching S3 if there is nothing to upload: the delete below is
# destructive, so a wrong or missing local path would otherwise wipe the
# existing inputs and only then fail on upload.
if not os.path.exists(local_dir):
    raise FileNotFoundError(f"Local input path does not exist: {local_dir}")

# Delete existing files from the specified S3 location
delete_s3_objects(bucket_name, prefix)

# Upload new files to the specified S3 location; the returned S3 URI is the
# input for the transform job.
inputs = sagemaker_session.upload_data(path=local_dir, key_prefix=prefix)
print(f"Input spec: {inputs}")


sagemaker.config INFO - Not applying SDK defaults from location: /Library/Application Support/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /Users/wejarrard/Library/Application Support/sagemaker/config.yaml
Deleted all objects in sagemaker-us-west-2-016114370410/tf-binding-sites/inference/input
Input spec: s3://sagemaker-us-west-2-016114370410/tf-binding-sites/inference/input


In [2]:

# Get model artifact location by estimator.model_data, or give an S3 key directly
model_artifact_s3_location = "s3://tf-binding-sites/finetuning/results/output/AR-LOO-THP-1-2024-05-15-00-00-49-482/output/model.tar.gz"

# Create PyTorchModel from saved model artifact.
# source_dir is the local directory containing the entry-point script; it is
# overridable through TF_BINDING_INFERENCE_SRC because the default path is
# specific to one machine.
pytorch_model = PyTorchModel(
    model_data=model_artifact_s3_location,
    role=role,
    framework_version="2.1",
    py_version="py310",
    # image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:2.1.0-gpu-py310-cu118-ubuntu20.04-sagemaker-v1.8",
    source_dir=os.environ.get(
        "TF_BINDING_INFERENCE_SRC",
        "/Users/wejarrard/projects/tf-binding/src/inference/scripts",
    ),
    entry_point="inference.py",
    # code_location="inference/code",
    sagemaker_session=sagemaker_session,
)


# S3 location passed to the transformer as its output path.
output_path = "s3://tf-binding-sites/inference/output/"

# Build the batch transformer from the PyTorchModel object: one
# ml.g5.2xlarge instance, multiple records per request, one concurrent
# request at a time.
transformer = pytorch_model.transformer(
    instance_type="ml.g5.2xlarge",
    instance_count=1,
    strategy="MultiRecord",
    max_payload=100,
    max_concurrent_transforms=1,
    output_path=output_path,
)


In [3]:
# Start the transform job over everything under the uploaded S3 prefix,
# splitting the JSON Lines input by line.
transformer.transform(
    data=inputs,
    data_type="S3Prefix",
    content_type="application/jsonlines",
    split_type="Line",
    # compression_type="Gzip",
    wait=False,
)

# wait=False means the call does not block until the job finishes, so report
# where the output will go rather than claiming it has been saved.
print(f"Transform job submitted; output will be written to: {output_path}")

INFO:sagemaker:Creating transform job with name: pytorch-inference-2024-06-12-23-05-54-065


Transformation output saved to: s3://tf-binding-sites/inference/output/
