## SageMaker Training with Weights & Biases



In [None]:
import boto3
import json
import os

from botocore.exceptions import ClientError

from dotenv import load_dotenv
# Pull local settings (AWS_ACCOUNT_ID, WANDB_API_KEY, ...) into os.environ.
# The path is relative to the notebook's working directory.
load_dotenv(dotenv_path="../../.env")

In [None]:
import os
import json

import sagemaker
from sagemaker.pytorch import PyTorch

session = sagemaker.Session()
# Region of the default SageMaker/boto3 session; used as the region for the training job.
region = session.boto_region_name

AWS_ACCOUNT_ID = os.environ.get('AWS_ACCOUNT_ID')
if not AWS_ACCOUNT_ID:
    # Without this check the role ARN (and ECR image URI) would silently contain "None".
    raise RuntimeError(
        "AWS_ACCOUNT_ID is not set; define it in ../../.env or the environment."
    )

# Execution role assumed by the training job. This notebook only builds the ARN;
# the role itself must already exist in the account.
iam_role = f"arn:aws:iam::{AWS_ACCOUNT_ID}:role/sagemaker-execution-role"

# S3 prefix passed to the estimator as output_path (where job artifacts are written).
training_job_output = "s3://sagemaker-wandb-samples/training-jobs/"


In [None]:


wandb_secret_name = "weights_and_bias_secret"

# Fail fast: with a missing key, json.dumps would store the literal `null` as the
# API key and the training job would only fail later when it tries to log in to W&B.
wandb_api_key = os.environ.get("WANDB_API_KEY")
if not wandb_api_key:
    raise RuntimeError(
        "WANDB_API_KEY is not set; define it in ../../.env or the environment."
    )

# Secrets Manager client in the default boto3 region. NOTE(review): the training job
# must read the secret from this same region (see AWS_DEFAULT_REGION in the estimator).
secretsmanager = boto3.client('secretsmanager')

# Secret payload: a JSON object holding the W&B API key.
secret_string = json.dumps({"WANDB_API_KEY": wandb_api_key})

# Upsert the secret. update_secret raises ResourceNotFoundException when the secret
# does not exist yet, so there is no need to read the current value (GetSecretValue)
# just to test for existence.
try:
    response = secretsmanager.update_secret(
        SecretId=wandb_secret_name,
        SecretString=secret_string,
    )
    print(f"Secret updated successfully: {wandb_secret_name}")
except ClientError as e:
    if e.response['Error']['Code'] != 'ResourceNotFoundException':
        # Permission, throttling, pending-deletion, etc.: launching training without a
        # usable secret would only fail later, so surface the error here.
        raise
    response = secretsmanager.create_secret(
        Name=wandb_secret_name,
        SecretString=secret_string,
    )
    print(f"Secret created successfully: {response['ARN']}")

### Configure the estimator (instance, image, environment, hyperparameters)

In [None]:
instance_type = 'ml.g5.xlarge'
# The training image must live in ECR in the same region as the training job, so use the
# session region instead of hard-coding us-east-1.
image_uri = f"{AWS_ACCOUNT_ID}.dkr.ecr.{region}.amazonaws.com/mnist-training:latest"

estimator = PyTorch(
    entry_point="train.py",
    source_dir="code",
    role=iam_role,
    instance_type=instance_type,
    instance_count=1,
    volume_size=50,
    output_path=training_job_output,
    hyperparameters={
        "epochs": 5,
    },
    environment={
        "WANDB_SECRET_NAME": wandb_secret_name,
        "WANDB_PROJECT": "MNIST",
        # Region the training script uses for region-scoped resources such as the W&B
        # secret; must match the region the secret was created in (the session region).
        "AWS_DEFAULT_REGION": region,
        # Optional settings, left unset here. NOTE(review): presumably read by train.py
        # to resume from a W&B checkpoint artifact; confirm against code/train.py.
        # "WANDB_CHECKPOINT_NAME": "<artifact-name>",
        # "WANDB_CHECKPOINT_TAG": "latest",
    },
    image_uri=image_uri,
)

In [None]:
# Launch the training job without blocking (wait=False): the cell returns after the job is
# submitted instead of waiting for it to finish. NOTE(review): to block later, something
# like estimator.latest_training_job.wait() should work; confirm for the installed SDK.
estimator.fit(wait=False)