In [None]:
import os
import sagemaker
from sagemaker import get_execution_role
from sagemaker.tensorflow import TensorFlow

# IAM role the notebook is running under; passed to the estimator below.
role = get_execution_role()

# S3 layout: both input channels and the output prefix live in one bucket.
bucket = 'sourceofimages'  # replace with your bucket name
train_data_uri = f's3://{bucket}/train'
val_data_uri = f's3://{bucket}/validation'
output_location = f's3://{bucket}/output'

# Hyperparameters forwarded to my_training_script.py.
hyperparameters = {
    'epochs': 10,
    'batch-size': 32,
    'learning-rate': 0.01,
}

# Channel name -> S3 prefix mapping given to fit().
input_channels = {
    'train': train_data_uri,
    'validation': val_data_uri,
}

# Single-instance TensorFlow 2.3.0 / Python 3.7 estimator.
estimator = TensorFlow(
    entry_point='my_training_script.py',
    role=role,
    instance_count=1,
    instance_type='ml.m5.4xlarge',
    framework_version='2.3.0',
    py_version='py37',
    output_path=output_location,
    hyperparameters=hyperparameters,
)

# Start training on both channels; wait=True is passed explicitly so the call
# is expected to block until the job finishes.
estimator.fit(input_channels, wait=True)


In [None]:
import boto3
import tarfile
import json
import matplotlib.pyplot as plt

# Download the latest training job's model archive from S3, unpack it into the
# working directory, and plot the accuracy curves stored in history.json.
# NOTE(review): assumes the training script wrote history.json at the root of
# the model archive with 'accuracy' and 'val_accuracy' entries - confirm
# against my_training_script.py.

# Name of the most recent training job started through `estimator`
training_job_name = estimator.latest_training_job.name

# Create a SageMaker client
sm = boto3.client('sagemaker')

# Get the details of the training job
response = sm.describe_training_job(TrainingJobName=training_job_name)

# Get the S3 URI of the model artifacts
model_artifacts_s3_uri = response['ModelArtifacts']['S3ModelArtifacts']

# Create an S3 client
s3 = boto3.client('s3')

# Split "s3://<bucket>/<key>" into its two parts. Fail early on anything that
# is not an S3 URI instead of handing a malformed bucket/key to download_file.
if not model_artifacts_s3_uri.startswith("s3://"):
    raise ValueError(f"Unexpected model artifact URI: {model_artifacts_s3_uri!r}")
# Strip only the leading scheme (count=1) and split once so the key keeps
# every remaining "/" untouched.
bucket, key = model_artifacts_s3_uri.replace("s3://", "", 1).split("/", 1)

# Download the model.tar.gz file
s3.download_file(bucket, key, 'model.tar.gz')

# Extract the model.tar.gz file into the current working directory.
with tarfile.open('model.tar.gz', 'r:gz') as tar:
    if hasattr(tarfile, 'data_filter'):
        # The 'data' filter rejects absolute paths, links that escape the
        # destination and special files, and avoids the Python 3.12+
        # DeprecationWarning for unfiltered extraction.
        tar.extractall(filter='data')
    else:
        # Older interpreters have no extraction filters; keep the original call.
        tar.extractall()

# Load the training history
with open('history.json', 'r') as f:
    history = json.load(f)

# Plot the training and validation accuracy on an explicit figure/axes pair
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(history['accuracy'], label='Training Accuracy')
ax.plot(history['val_accuracy'], label='Validation Accuracy')
ax.set_title('Training and Validation Accuracy Over Time')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.legend()
# Save before show(): some backends release the figure once it is displayed.
fig.savefig("output.png")
plt.show()