# Import libraries

In [4]:
from time import sleep
import pandas as pd
import os

import boto3

# One low-level client per AWS service used in this notebook. No region or
# credentials are passed explicitly here.
s3, comprehend, application_autoscaling = (
    boto3.client(service_name)
    for service_name in ("s3", "comprehend", "application-autoscaling")
)


In [2]:
# Deployment-specific settings. Each value can be overridden through an
# environment variable so the notebook can be pointed at another bucket or
# account without editing code; the defaults are the original values.
# Requires `os`, which is imported in the notebook's import cell.

# S3 bucket that receives the training CSV.
BUCKET_NAME = os.environ.get("NEWS_BUCKET_NAME", "skills-hmoon-dataset")
# Key prefix ("folder") inside the bucket.
KEY_PREFIX = os.environ.get("NEWS_KEY_PREFIX", "mynews")
# IAM role handed to Comprehend as its data access role.
COMPREHEND_DATA_ACCESS_ROLE_ARN = os.environ.get(
    "COMPREHEND_DATA_ACCESS_ROLE_ARN",
    "arn:aws:iam::856210586235:role/comprehend-custom-role",
)


# Prepare data

In [25]:
DATA_ROOT_PATH = "./data/BBC News Summary/News Articles/"

# Collect every article file below DATA_ROOT_PATH.
# - Only *.txt files are kept, so stray entries (.DS_Store, notebook
#   checkpoints, ...) cannot break the parsing step that follows.
# - The list is sorted because os.walk yields entries in an arbitrary,
#   filesystem-dependent order; sorting makes the row order reproducible.
# - Paths are joined with "/" on purpose: the category label is later derived
#   with filepath.split("/").
news_text_file_paths = sorted(
    f"{root}/{file}"
    for root, dirs, files in os.walk(DATA_ROOT_PATH)
    for file in files
    if file.endswith(".txt")
)


In [28]:
# One (title, description, category, filepath) tuple per article file.
news_data = []

for filepath in news_text_file_paths:
    # Decode explicitly instead of relying on the platform's default encoding,
    # which differs between machines. Undecodable bytes are replaced with
    # U+FFFD rather than aborting the whole load.
    # NOTE(review): assumes the articles are (mostly) UTF-8/ASCII - confirm
    # against the dataset.
    with open(filepath, "r", encoding="utf-8", errors="replace") as f:
        lines = f.readlines()

    # The title is read from the first line and the description from the
    # third; name the offending file if that layout is not present.
    if len(lines) < 3:
        raise ValueError(
            f"Unexpected article layout (fewer than 3 lines): {filepath}"
        )

    title = lines[0]  # trailing "\n" is kept; it later separates title and description
    description = lines[2]
    category = filepath.split("/")[-2]  # parent directory name is used as the label
    news_data.append((title, description, category, filepath))


In [32]:
# Tabulate the parsed articles: one row per file, one column per tuple field.
df = pd.DataFrame(
    data=news_data,
    columns=["title", "description", "category", "filepath"],
)
df

Unnamed: 0,title,description,category,filepath
0,Ad sales boost Time Warner profit\n,Quarterly profits at US media giant TimeWarner...,business,./data/BBC News Summary/News Articles/business...
1,Dollar gains on Greenspan speech\n,The dollar has hit its highest level against t...,business,./data/BBC News Summary/News Articles/business...
2,Yukos unit buyer faces loan claim\n,The owners of embattled Russian oil giant Yuko...,business,./data/BBC News Summary/News Articles/business...
3,High fuel prices hit BA's profits\n,British Airways has blamed high fuel prices fo...,business,./data/BBC News Summary/News Articles/business...
4,Pernod takeover talk lifts Domecq\n,Shares in UK drinks and food firm Allied Domec...,business,./data/BBC News Summary/News Articles/business...
...,...,...,...,...
2220,BT program to beat dialler scams\n,BT is introducing two initiatives to help beat...,tech,./data/BBC News Summary/News Articles/tech/397...
2221,Spam e-mails tempt net shoppers\n,Computer users across the world continue to ig...,tech,./data/BBC News Summary/News Articles/tech/398...
2222,Be careful how you code\n,A new European directive could put software wr...,tech,./data/BBC News Summary/News Articles/tech/399...
2223,US cyber security chief resigns\n,The man making sure US computer networks are s...,tech,./data/BBC News Summary/News Articles/tech/400...


In [42]:
# Model input = title followed by description. The title still ends with its
# original "\n", so the two parts end up on separate lines.
# Vectorized concatenation replaces the row-wise apply(axis=1): same result
# for these string columns, without a Python-level call per row.
df["text"] = df["title"] + df["description"]
df["text"]

0       Ad sales boost Time Warner profit\nQuarterly p...
1       Dollar gains on Greenspan speech\nThe dollar h...
2       Yukos unit buyer faces loan claim\nThe owners ...
3       High fuel prices hit BA's profits\nBritish Air...
4       Pernod takeover talk lifts Domecq\nShares in U...
                              ...                        
2220    BT program to beat dialler scams\nBT is introd...
2221    Spam e-mails tempt net shoppers\nComputer user...
2222    Be careful how you code\nA new European direct...
2223    US cyber security chief resigns\nThe man makin...
2224    Losing yourself in online gaming\nOnline role ...
Name: text, Length: 2225, dtype: object

In [43]:
# Write the training set as "<category>,<text>" rows: no header row, no index column.
df[["category", "text"]].to_csv("news_train.csv", index=False, header=False)

In [47]:
# Upload the local training file to <BUCKET_NAME>/<KEY_PREFIX>/news_train.csv.
s3.upload_file(
    Filename="news_train.csv",
    Bucket=BUCKET_NAME,
    Key=f"{KEY_PREFIX}/news_train.csv",
)

# Create custom classifier

In [49]:
# Location of the training CSV uploaded in the previous step.
train_s3_uri = f"s3://{BUCKET_NAME}/{KEY_PREFIX}/news_train.csv"

# Create the custom classifier. The data access role comes from the
# configuration cell rather than a second copy of the ARN literal, so the
# role only has to be changed in one place.
classifier_arn = comprehend.create_document_classifier(
    DocumentClassifierName="mynews",
    LanguageCode="en",
    DataAccessRoleArn=COMPREHEND_DATA_ACCESS_ROLE_ARN,
    InputDataConfig={
        "DataFormat": "COMPREHEND_CSV",
        "S3Uri": train_s3_uri,
    },
)["DocumentClassifierArn"]
classifier_arn

'arn:aws:comprehend:us-east-1:856210586235:document-classifier/mynews'

In [51]:
# To reuse an already created classifier (e.g. after a kernel restart),
# uncomment the next line instead of re-running the creation cell:
# classifier_arn = "arn:aws:comprehend:us-east-1:856210586235:document-classifier/mynews"
status = "SUBMITTED"

print(f"Waiting for classifier: {classifier_arn}..", end="")
# Poll every 10 seconds while the classifier is still queued or training.
while status in ("SUBMITTED", "TRAINING"):
    print(".", end="")
    sleep(10)
    response = comprehend.describe_document_classifier(
        DocumentClassifierArn=classifier_arn
    )
    status = response["DocumentClassifierProperties"]["Status"]
print(status)

# Fail loudly instead of letting later cells run against an unusable model.
# `response` is always bound here because the loop body runs at least once.
if status not in ("TRAINED", "TRAINED_WITH_WARNING"):
    message = response["DocumentClassifierProperties"].get(
        "Message", "no message returned"
    )
    raise RuntimeError(
        f"Classifier training ended with status {status}: {message}"
    )

Waiting for classifier: arn:aws:comprehend:us-east-1:856210586235:document-classifier/mynews............................................................................................................................................................................................................................TRAINED


In [52]:
# Create an endpoint for the classifier, requesting a single inference unit.
create_endpoint_response = comprehend.create_endpoint(
    EndpointName="mynews-endpoint",
    ModelArn=classifier_arn,
    DesiredInferenceUnits=1,
)
endpoint_arn = create_endpoint_response["EndpointArn"]
endpoint_arn

'arn:aws:comprehend:us-east-1:856210586235:document-classifier-endpoint/mynews-endpoint'

In [3]:
# To reuse an already created endpoint (e.g. after a kernel restart),
# uncomment the next line instead of re-running the creation cell:
# endpoint_arn = "arn:aws:comprehend:us-east-1:856210586235:document-classifier-endpoint/mynews-endpoint"
status = "CREATING"

print(f"Waiting for endpoint: {endpoint_arn}..", end="")
# Poll every 3 seconds until the endpoint leaves the CREATING state.
while status == "CREATING":
    print(".", end="")
    sleep(3)
    response = comprehend.describe_endpoint(
        EndpointArn=endpoint_arn
    )
    status = response["EndpointProperties"]["Status"]
print(status)

# Surface a failed creation here rather than in a later cell.
# `response` is always bound because the loop body runs at least once.
if status == "FAILED":
    message = response["EndpointProperties"].get("Message", "no message returned")
    raise RuntimeError(f"Endpoint creation failed: {message}")

Waiting for endpoint: arn:aws:comprehend:us-east-1:856210586235:document-classifier-endpoint/mynews-endpoint...IN_SERVICE


In [55]:
def classify_document(text, endpoint_arn):
    """Classify one document through a Comprehend endpoint.

    Uses the module-level ``comprehend`` client.

    Args:
        text: Document text to classify.
        endpoint_arn: ARN of the endpoint to send the request to.

    Returns:
        The response of ``comprehend.classify_document``, unmodified.
    """
    return comprehend.classify_document(Text=text, EndpointArn=endpoint_arn)

In [58]:
# Try the endpoint on a sample article.
title = "Musical future for phones"
description = "Analyst Bill Thompson has seen the future and it is in his son's hands."

# Title and description are separated by a newline, as in the training text.
text = "\n".join((title, description))
classify_document(text, endpoint_arn)


{'Classes': [{'Name': 'tech', 'Score': 0.9997249245643616},
  {'Name': 'entertainment', 'Score': 9.097243309952319e-05},
  {'Name': 'sport', 'Score': 6.961222970858216e-05}],
 'ResponseMetadata': {'RequestId': '9c6ce21b-c02b-405b-be79-79c14862933e',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': '9c6ce21b-c02b-405b-be79-79c14862933e',
   'content-type': 'application/x-amz-json-1.1',
   'content-length': '156',
   'date': 'Mon, 21 Aug 2023 06:52:01 GMT'},
  'RetryAttempts': 0}}

# Autoscale endpoint

In [5]:
# Register the endpoint's inference units as a scalable target with a
# capacity range of 1 to 5.
register_response = application_autoscaling.register_scalable_target(
    ServiceNamespace="comprehend",
    ResourceId=endpoint_arn,
    ScalableDimension="comprehend:document-classifier-endpoint:DesiredInferenceUnits",
    MinCapacity=1,
    MaxCapacity=5,
)
scalable_target_arn = register_response["ScalableTargetARN"]
scalable_target_arn

'arn:aws:application-autoscaling:us-east-1:856210586235:scalable-target/0cm96110101d81a24307ac1537372faeaa1e'

In [6]:
# Target-tracking configuration: aim for a value of 70 on the predefined
# ComprehendInferenceUtilization metric.
scaling_policy = dict(
    TargetValue=70,
    PredefinedMetricSpecification=dict(
        PredefinedMetricType="ComprehendInferenceUtilization",
    ),
)

# Attach the policy to the same resource and dimension that were registered
# as the scalable target.
put_policy_response = application_autoscaling.put_scaling_policy(
    PolicyName="mynewsEndpointScalingPolicy",
    PolicyType="TargetTrackingScaling",
    ServiceNamespace="comprehend",
    ResourceId=endpoint_arn,
    ScalableDimension="comprehend:document-classifier-endpoint:DesiredInferenceUnits",
    TargetTrackingScalingPolicyConfiguration=scaling_policy,
)
policy_arn = put_policy_response["PolicyARN"]
policy_arn

'arn:aws:autoscaling:us-east-1:856210586235:scalingPolicy:6110101d-81a2-4307-ac15-37372faeaa1e:resource/comprehend/arn:aws:comprehend:us-east-1:856210586235:document-classifier-endpoint/mynews-endpoint:policyName/mynewsEndpointScalingPolicy'