In [None]:
# In your SageMaker Notebook cell:
import sagemaker
from sagemaker.pytorch import PyTorch
import os

# --- 1. SageMaker session, execution role, and S3 output location ---
sess = sagemaker.Session()
# IAM role the training job runs under; it needs SageMaker permissions.
role = sagemaker.get_execution_role()

# Bucket that receives this job's outputs.
# NOTE: this is an explicitly hard-coded name, not the session's default bucket.
s3_output_bucket = "spajjuri-transformers"

# Key prefix that groups everything this training job writes.
s3_job_output_prefix = "transformer-model-v2-training-job"

# Full destination, s3://<bucket>/<prefix>, used as the estimator's output path.
s3_output_uri = "s3://" + "/".join((s3_output_bucket, s3_job_output_prefix))


# --- 2. Configure PyTorch Estimator for Managed Spot Training ---
estimator = PyTorch(
    entry_point='run.py',                # Main training script, resolved inside source_dir
    source_dir='./my_transformer_code',  # Local directory holding run.py and its supporting modules;
                                         # the SDK packages it and uploads it to S3 for the job.
    role=role,                           # IAM role for the training job
    sagemaker_session=sess,              # Reuse the session created above instead of an implicit one
    framework_version='2.1',             # PyTorch version of the training container
    py_version='py310',                  # Python version of the training container
    instance_count=1,                    # Single instance
    instance_type='ml.g4dn.xlarge',      # Instance type to train on
    output_path=s3_output_uri,           # S3 prefix under which the job's artifacts are saved

    # --- Managed Spot Training Configuration ---
    use_spot_instances=True,             # Enable Managed Spot Training
    max_run=3 * 3600,                    # Max training time in seconds (3 hours); SageMaker
                                         # terminates the job once this is exceeded.
    max_wait=6 * 3600,                   # Max total time in seconds (6 hours) for the spot job to
                                         # complete: time spent waiting for Spot capacity PLUS
                                         # training time. Must be >= max_run.

    # --- Checkpoint Configuration (CRUCIAL for resuming interrupted Spot jobs) ---
    # S3 location SageMaker uses to persist and restore checkpoints across Spot
    # interruptions. It must be in the same region as the training job.
    # NOTE(review): the training script itself must write to / resume from the
    # container's local checkpoint directory for this to have any effect - confirm in run.py.
    checkpoint_s3_uri=f"{s3_output_uri}/checkpoints",

    # --- Hyperparameters to pass to your script ---
    # Passed to the entry point (run.py) as command-line arguments, e.g. `--num-epochs 200`.
    hyperparameters={
        'num-epochs': 200,
        'batch-size': 128,
        'embed-size': 512,
        'num-layers': 6,
        'forward-expansion': 4,
        'heads': 8,
        'dropout': 0.1,
        'max-length': 256,
        'lr': 0.001,
        'warmup-steps': 4000,
        'clip': 5.0,
    },

    # False (the SDK default) uploads the job's artifacts to S3 as a compressed
    # archive. Set this to True if the artifacts should be uploaded uncompressed.
    disable_output_compression=False,
)

# --- 3. Start the training job ---
# No input data channel is passed here.
# NOTE(review): this assumes the training script fetches its own dataset - confirm in run.py.
# If the data lived on S3 (e.g., `s3://my-data-bucket/multi30k/`), it would be provided as:
# estimator.fit({'training': 's3://my-data-bucket/multi30k/'})
# With the default wait=True, fit() blocks and streams the job logs until the job ends.
estimator.fit()

# The job has finished by the time execution reaches this point, so report
# where the outputs are rather than where they "will be".
print(f"Training job outputs are in: {s3_output_uri}")
# Exact S3 location of the trained model artifact (e.g., for inference).
print(f"Trained model artifacts are in: {estimator.model_data}")

INFO:sagemaker.telemetry.telemetry_logging:SageMaker Python SDK will collect telemetry to help us better understand our user's needs, diagnose issues, and deliver additional features.
To opt out of telemetry, please disable via TelemetryOptOut parameter in SDK defaults config. For more information, refer to https://sagemaker.readthedocs.io/en/stable/overview.html#configuring-and-using-defaults-with-the-sagemaker-python-sdk.
INFO:sagemaker.image_uris:image_uri is not presented, retrieving image_uri based on instance_type, framework etc.
INFO:sagemaker.image_uris:image_uri is not presented, retrieving image_uri based on instance_type, framework etc.
INFO:sagemaker:Creating training-job with name: pytorch-training-2025-07-20-17-46-54-991


2025-07-20 17:46:56 Starting - Starting the training job...
2025-07-20 17:47:10 Starting - Preparing the instances for training...
2025-07-20 17:47:59 Downloading - Downloading the training image..