# Fine-tune the Instructor-Model for Dialogue Summarization using SageMaker Training Jobs


# NOTE:  THIS NOTEBOOK WILL TAKE ABOUT 20 MINUTES TO COMPLETE.

# PLEASE BE PATIENT.

In [2]:
import boto3
import sagemaker
import pandas as pd

# SageMaker SDK session; it supplies the default bucket below and is reused
# later when attaching a Predictor to the deployed endpoint.
sess = sagemaker.Session()
# Default S3 bucket of the session; later used to build the checkpoint URI
# and the S3 console link.
bucket = sess.default_bucket()
# Execution role that is passed to the PyTorch estimator further down.
role = sagemaker.get_execution_role()
# Region of the default boto3 session; used for the client below and for
# the console links.
region = boto3.Session().region_name

import botocore.config

# Client configuration that sets 'gaia/1.0' as the extra user-agent value.
config = botocore.config.Config(
    user_agent_extra='gaia/1.0'
)

# Low-level SageMaker client built with the region and config above.
# NOTE(review): `sm` is not referenced again in the visible cells.
sm = boto3.Session().client(service_name="sagemaker", 
                            region_name=region, 
                            config=config)

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml


# _PRE-REQUISITE: You need to have successfully run the notebooks in the `PREPARE` section before you continue with this notebook._

In [3]:
# Restore `processed_train_data_s3_uri` from IPython's %store persistence
# (expected to have been saved by the notebooks in the PREPARE section).
%store -r processed_train_data_s3_uri

In [4]:
# Guard cell: merely referencing the name raises NameError when the
# preceding `%store -r` did not restore it.
try:
    processed_train_data_s3_uri
except NameError:
    # Same three-line banner as before, emitted through a single print call.
    print(
        "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++",
        "[ERROR] Please run the notebooks in the PREPARE section before you continue.",
        "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++",
        sep="\n",
    )

In [5]:
# Show the restored S3 URI of the training split.
print(processed_train_data_s3_uri)

s3://sagemaker-us-west-2-079002598131/sagemaker-scikit-learn-2023-10-25-19-28-08-587/output/train


In [6]:
# Restore `processed_validation_data_s3_uri` from IPython's %store persistence
# (expected to have been saved by the notebooks in the PREPARE section).
%store -r processed_validation_data_s3_uri

In [7]:
# Guard cell: merely referencing the name raises NameError when the
# preceding `%store -r` did not restore it.
try:
    processed_validation_data_s3_uri
except NameError:
    # Same three-line banner as before, emitted through a single print call.
    print(
        "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++",
        "[ERROR] Please run the notebooks in the PREPARE section before you continue.",
        "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++",
        sep="\n",
    )

In [8]:
# Show the restored S3 URI of the validation split.
print(processed_validation_data_s3_uri)

s3://sagemaker-us-west-2-079002598131/sagemaker-scikit-learn-2023-10-25-19-28-08-587/output/validation


In [9]:
# Restore `processed_test_data_s3_uri` from IPython's %store persistence
# (expected to have been saved by the notebooks in the PREPARE section).
%store -r processed_test_data_s3_uri

In [10]:
# Guard cell: merely referencing the name raises NameError when the
# preceding `%store -r` did not restore it.
try:
    processed_test_data_s3_uri
except NameError:
    # Same three-line banner as before, emitted through a single print call.
    print(
        "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++",
        "[ERROR] Please run the notebooks in the PREPARE section before you continue.",
        "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++",
        sep="\n",
    )

In [11]:
# Show the restored S3 URI of the test split.
print(processed_test_data_s3_uri)

s3://sagemaker-us-west-2-079002598131/sagemaker-scikit-learn-2023-10-25-19-28-08-587/output/test


# Specify the Dataset in S3
We are using the train, validation, and test splits created in the previous section.

In [12]:
# Print the training-split prefix, then list the objects stored under it.
print(processed_train_data_s3_uri)

# `$name` interpolates the Python variable into the shell command.
!aws s3 ls $processed_train_data_s3_uri/

s3://sagemaker-us-west-2-079002598131/sagemaker-scikit-learn-2023-10-25-19-28-08-587/output/train
2023-10-25 19:35:07    2540571 1698262500203.parquet
2023-10-25 19:35:07    2545157 1698262503103.parquet


In [13]:
# Print the validation-split prefix, then list the objects stored under it.
print(processed_validation_data_s3_uri)

# `$name` interpolates the Python variable into the shell command.
!aws s3 ls $processed_validation_data_s3_uri/

s3://sagemaker-us-west-2-079002598131/sagemaker-scikit-learn-2023-10-25-19-28-08-587/output/validation
2023-10-25 19:35:07     150701 1698262500203.parquet
2023-10-25 19:35:07     150220 1698262503103.parquet


In [14]:
# Print the test-split prefix, then list the objects stored under it.
print(processed_test_data_s3_uri)

# `$name` interpolates the Python variable into the shell command.
!aws s3 ls $processed_test_data_s3_uri/

s3://sagemaker-us-west-2-079002598131/sagemaker-scikit-learn-2023-10-25-19-28-08-587/output/test
2023-10-25 19:35:07     157115 1698262500203.parquet
2023-10-25 19:35:07     153865 1698262503103.parquet


# Specify S3 Input Data

In [15]:
from sagemaker.inputs import TrainingInput

# Wrap each processed split's S3 prefix as an input channel for the training job.
s3_input_train_data = TrainingInput(s3_data=processed_train_data_s3_uri)
s3_input_validation_data = TrainingInput(s3_data=processed_validation_data_s3_uri)
s3_input_test_data = TrainingInput(s3_data=processed_test_data_s3_uri)

# Dump the three channel configurations, one per line.
print(
    s3_input_train_data.config,
    s3_input_validation_data.config,
    s3_input_test_data.config,
    sep="\n",
)

{'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix', 'S3Uri': 's3://sagemaker-us-west-2-079002598131/sagemaker-scikit-learn-2023-10-25-19-28-08-587/output/train', 'S3DataDistributionType': 'FullyReplicated'}}}
{'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix', 'S3Uri': 's3://sagemaker-us-west-2-079002598131/sagemaker-scikit-learn-2023-10-25-19-28-08-587/output/validation', 'S3DataDistributionType': 'FullyReplicated'}}}
{'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix', 'S3Uri': 's3://sagemaker-us-west-2-079002598131/sagemaker-scikit-learn-2023-10-25-19-28-08-587/output/test', 'S3DataDistributionType': 'FullyReplicated'}}}


# Setup Hyper-Parameters for FLAN Model

In [16]:
# Identifier of the base model to fine-tune; forwarded to src/train.py as the
# `model_checkpoint` hyper-parameter.
model_checkpoint='google/flan-t5-base'

In [17]:
# Values forwarded to src/train.py through the estimator's `hyperparameters` dict.
epochs = 1 # increase this if you want to train for a longer period
learning_rate = 0.00001
weight_decay = 0.01
train_batch_size = 4
validation_batch_size = 4
test_batch_size = 4
# Training-job infrastructure settings passed to the PyTorch estimator
# (instance_count, instance_type, volume_size, input_mode).
train_instance_count = 1
train_instance_type = "ml.c5.9xlarge"
train_volume_size = 1024
input_mode = "FastFile"
# Also forwarded to src/train.py as a hyper-parameter.
train_sample_percentage = 0.01 # increase this if you want to train on more data

# Setup Metrics To Track Model Performance

In [18]:
# Metric name -> log regex pairs handed to the estimator as `metric_definitions`.
# Each regex captures the numeric value that follows the quoted loss key.
metrics_definitions = [
    {"Name": metric_name, "Regex": log_pattern}
    for metric_name, log_pattern in (
        ("train:loss", "'train_loss': ([0-9\\.]+)"),
        ("validation:loss", "'eval_loss': ([0-9\\.]+)"),
    )
]

# Specify Checkpoint S3 Location
This location is used for Spot Instance training. If a node is replaced, the new node resumes training from the latest checkpoint.

In [19]:
import uuid

# Use a fresh random UUID so every run of this cell gets its own checkpoint prefix.
checkpoint_s3_prefix = "checkpoints/{}".format(str(uuid.uuid4()))
# Full S3 URI under the session's default bucket; passed to the estimator as
# `checkpoint_s3_uri`.
checkpoint_s3_uri = "s3://{}/{}/".format(bucket, checkpoint_s3_prefix)

print(checkpoint_s3_uri)

s3://sagemaker-us-west-2-079002598131/checkpoints/13aab899-f629-4e25-a856-0fbfe8dce688/


# Setup Our Script to Run on SageMaker
Prepare our model to run on the managed SageMaker service.

In [20]:
# Display the training script (syntax-highlighted) that the estimator below runs.
!pygmentize src/train.py

[34mimport[39;49;00m [04m[36margparse[39;49;00m[37m[39;49;00m
[34mimport[39;49;00m [04m[36mos[39;49;00m[37m[39;49;00m
[34mimport[39;49;00m [04m[36mjson[39;49;00m[37m[39;49;00m
[34mimport[39;49;00m [04m[36mpprint[39;49;00m[37m[39;49;00m
[37m[39;49;00m
[34mfrom[39;49;00m [04m[36mtransformers[39;49;00m [34mimport[39;49;00m AutoTokenizer, AutoModelForSeq2SeqLM, TrainingArguments, Trainer, GenerationConfig[37m[39;49;00m
[34mfrom[39;49;00m [04m[36mdatasets[39;49;00m [34mimport[39;49;00m load_dataset[37m[39;49;00m
[37m[39;49;00m
[34mdef[39;49;00m [32mlist_files[39;49;00m(startpath):[37m[39;49;00m
[37m    [39;49;00m[33m"""Helper function to list files in a directory"""[39;49;00m[37m[39;49;00m
    [34mfor[39;49;00m root, dirs, files [35min[39;49;00m os.walk(startpath):[37m[39;49;00m
        level = root.replace(startpath, [33m'[39;49;00m[33m'[39;49;00m).count(os.sep)[37m[39;49;00m
        indent = [33m'[39;49;00m[33m 

In [21]:
from sagemaker.pytorch import PyTorch

# Configure the SageMaker PyTorch estimator that runs src/train.py as a training job.
estimator = PyTorch(
    entry_point="train.py",  # script to run, located inside `source_dir`
    source_dir="src",
    role=role,
    instance_count=train_instance_count,
    instance_type=train_instance_type,
    volume_size=train_volume_size,
    checkpoint_s3_uri=checkpoint_s3_uri,  # unique per-run checkpoint location built above
    py_version="py39",
    framework_version="1.13",
    # Forwarded to the entry-point script.
    hyperparameters={
        "epochs": epochs,
        "learning_rate": learning_rate,
        "weight_decay": weight_decay,        
        "train_batch_size": train_batch_size,
        "validation_batch_size": validation_batch_size,
        "test_batch_size": test_batch_size,
        "model_checkpoint": model_checkpoint,
        "train_sample_percentage": train_sample_percentage,
    },
    input_mode=input_mode,
    # Regexes used to extract train/validation loss from the job's log output.
    metric_definitions=metrics_definitions,
)

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml


# Train the Model on SageMaker

In [22]:
# Launch the training job with one input channel per data split.
# `wait=False` returns immediately; a later cell blocks until the job finishes.
estimator.fit(
    inputs={
        "train": s3_input_train_data,
        "validation": s3_input_validation_data,
        "test": s3_input_test_data,
    },
    wait=False,
)

INFO:sagemaker.image_uris:image_uri is not presented, retrieving image_uri based on instance_type, framework etc.
INFO:sagemaker:Creating training-job with name: pytorch-training-2023-10-25-19-41-48-426


Using provided s3_resource


In [23]:
# Keep the job name; it is reused for the console links and for naming the endpoint.
training_job_name = estimator.latest_training_job.name
print("Training Job Name: {}".format(training_job_name))

Training Job Name: pytorch-training-2023-10-25-19-41-48-426


In [24]:
# Import from the public `IPython.display` module; importing `display` from
# `IPython.core.display` is deprecated and emits a DeprecationWarning.
from IPython.display import display, HTML

# Render a clickable link to this training job in the SageMaker console.
display(
    HTML(
        '<b>Review <a target="blank" href="https://console.aws.amazon.com/sagemaker/home?region={}#/jobs/{}">Training Job</a> After About 5 Minutes</b>'.format(
            region, training_job_name
        )
    )
)

  from IPython.core.display import display, HTML


In [25]:
# Import from the public `IPython.display` module; importing `display` from
# `IPython.core.display` is deprecated and emits a DeprecationWarning.
from IPython.display import display, HTML

# Render a clickable link to the CloudWatch log streams prefixed with this job's name.
display(
    HTML(
        '<b>Review <a target="blank" href="https://console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/TrainingJobs;prefix={};streamFilter=typeLogStreamPrefix">CloudWatch Logs</a> After About 5 Minutes</b>'.format(
            region, training_job_name
        )
    )
)

  from IPython.core.display import display, HTML


In [26]:
# Import from the public `IPython.display` module; importing `display` from
# `IPython.core.display` is deprecated and emits a DeprecationWarning.
from IPython.display import display, HTML

# Render a clickable link to the job's folder in the default bucket (S3 console).
display(
    HTML(
        '<b>Review <a target="blank" href="https://s3.console.aws.amazon.com/s3/buckets/{}/{}/?region={}&tab=overview">S3 Output Data</a> After The Training Job Has Completed</b>'.format(
            bucket, training_job_name, region
        )
    )
)

  from IPython.core.display import display, HTML


In [27]:
%%time

# Block until the training job started above finishes; `logs=False` is passed,
# and the captured output shows only job status transitions.
estimator.latest_training_job.wait(logs=False)


2023-10-25 19:41:49 Starting - Starting the training job.
2023-10-25 19:42:03 Starting - Preparing the instances for training..............
2023-10-25 19:43:18 Downloading - Downloading input data....
2023-10-25 19:43:43 Training - Downloading the training image..........................
2023-10-25 19:45:59 Training - Training image download completed. Training in progress...........................................
2023-10-25 19:49:35 Uploading - Uploading generated training model..
2023-10-25 19:49:51 Completed - Training job completed
CPU times: user 568 ms, sys: 69.2 ms, total: 637 ms
Wall time: 8min 4s


# Deploy the Fine-Tuned Model to a Real Time Endpoint

In [28]:
# Name the endpoint after the training job by swapping the name prefix.
endpoint_name = training_job_name.replace("pytorch-training", "summary-tuned")

# Build a deployable model from the trained estimator plus the inference script in src/.
sm_model = estimator.create_model(entry_point="inference.py", source_dir="src")

# Create the real-time endpoint and keep the predictor that `deploy` returns.
predictor = sm_model.deploy(
    endpoint_name=endpoint_name,
    instance_type="ml.m5.2xlarge",
    initial_instance_count=1,
)

INFO:sagemaker:Repacking model artifact (s3://sagemaker-us-west-2-079002598131/pytorch-training-2023-10-25-19-41-48-426/output/model.tar.gz), script artifact (src), and dependencies ([]) into single tar.gz file located at s3://sagemaker-us-west-2-079002598131/pytorch-training-2023-10-25-19-49-54-208/model.tar.gz. This may take some time depending on model size...
INFO:sagemaker:Creating model with name: pytorch-training-2023-10-25-19-49-54-208
INFO:sagemaker:Creating endpoint-config with name summary-tuned-2023-10-25-19-41-48-426
INFO:sagemaker:Creating endpoint with name summary-tuned-2023-10-25-19-41-48-426


-----!

# Zero Shot Inference with the Fine-Tuned Model in a SageMaker Endpoint

In [29]:
zero_shot_prompt = """Summarize the following conversation.

#Person1#: Tom, I've got good news for you.
#Person2#: What is it?
#Person1#: Haven't you heard that your novel has won The Nobel Prize?
#Person2#: Really? I can't believe it. It's like a dream come true. I never expected that I would win The Nobel Prize!
#Person1#: You did a good job. I'm extremely proud of you.
#Person2#: Thanks for the compliment.
#Person1#: You certainly deserve it. Let's celebrate!

Summary:"""

In [30]:
import json
from sagemaker import Predictor
# Attach a Predictor to the endpoint by name, replacing the one returned by deploy().
predictor = Predictor(endpoint_name=endpoint_name, sagemaker_session=sess)

# Send the prompt; the request/response content types are supplied through
# `initial_args` (the second parameter of `predict`).
response = predictor.predict(
    zero_shot_prompt,
    initial_args={
        "ContentType": "application/x-text",
        "Accept": "application/json",
    },
)

# Decode the raw response as UTF-8 and parse it as JSON before printing.
response_json = json.loads(response.decode("utf-8"))
print(response_json)

Tom's novel has won the Nobel Prize.


# Tear Down the Endpoint

In [31]:
# Left disabled on purpose: uncomment the next line to delete the endpoint
# created above once you are finished with it, so it does not keep running.
# predictor.delete_endpoint()