In [None]:
# Import required libraries
import boto3
import sagemaker
from sagemaker import get_execution_role

# Initialize SageMaker session and role
sagemaker_session = sagemaker.Session()
role = get_execution_role()  # IAM role the training and compilation jobs run under

# Default to the session's managed bucket (created on first use if missing) so the
# notebook survives Restart & Run All; the former '<your-unique-id>' placeholder is
# not a valid S3 bucket name. Replace with your own bucket name string if preferred.
bucket = sagemaker_session.default_bucket()
prefix = 'gan-mnist'  # S3 key prefix for every artifact this notebook writes
region = boto3.Session().region_name

In [None]:
import torch
from torchvision import datasets, transforms

# ToTensor() produces pixel values in [0, 1]; Normalize with mean=0.5 and
# std=0.5 then rescales them to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Fetch the MNIST training split into ./data on the notebook instance
# (torchvision skips the download when the files are already present).
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

In [None]:
from sagemaker.pytorch import PyTorch

# Stage the MNIST files downloaded into ./data in S3 so the training container can
# actually see them; previously fit() was called without inputs and the local
# download was never used.
# NOTE(review): assumes gan_mnist.py reads its data from the default 'training'
# channel (SM_CHANNEL_TRAINING) — confirm. If the script downloads MNIST itself,
# this upload is harmless but redundant.
train_data_uri = sagemaker_session.upload_data(
    './data', bucket=bucket, key_prefix=f'{prefix}/data'
)

estimator = PyTorch(
    entry_point='gan_mnist.py',
    source_dir='.',  # Directory containing the script
    role=role,
    framework_version='1.9',
    py_version='py38',
    instance_count=1,
    instance_type='ml.m5.large',
    output_path=f's3://{bucket}/{prefix}/output',
    sagemaker_session=sagemaker_session,
    hyperparameters={
        'n_epochs': 5,
        'batch_size': 64,
        'lr': 0.0002,
        'latent_dim': 100  # must match the input_shape used at compile time
    }
)

# Start training; a single S3 URI is exposed to the job as the 'training' channel.
estimator.fit(train_data_uri)

In [None]:
# S3 location where SageMaker Neo writes the compiled model.
compiled_model_path = 's3://{}/{}/compiled'.format(bucket, prefix)

In [None]:
# The training job already wrote its artifact (model.tar.gz) to S3 under
# output_path; use that URI instead of uploading a local 'model.tar.gz', which
# nothing in this notebook creates.
model_artifact = estimator.model_data
print('Trained model artifact:', model_artifact)

In [None]:
# Compile the trained model with SageMaker Neo for the ml_c5 instance family.
# compile_model() always compiles the estimator's own training artifact
# (estimator.model_data). Extra keyword arguments are forwarded to
# estimator.create_model(), which has no `model` parameter, so the artifact
# must not be passed that way (doing so raises TypeError).
# NOTE(review): Neo expects a traced TorchScript model inside model.tar.gz, and
# the [1, 100] shape below must match the 'latent_dim' hyperparameter — confirm
# that gan_mnist.py saves the generator accordingly.
compiled_model = estimator.compile_model(
    target_instance_family='ml_c5',
    input_shape={'latent_vector': [1, 100]},
    output_path=compiled_model_path,
    framework='pytorch',
    framework_version=estimator.framework_version,  # stays in sync with training
    role=role
)