In [1]:
import sagemaker
from sagemaker import get_execution_role
import json
import boto3

# Session object used later in this notebook for S3 uploads and endpoint deletion.
sess = sagemaker.Session()

# Execution role returned by the SageMaker SDK; passed to the Estimator further down.
role = get_execution_role()
bucket = sess.default_bucket()  # Replace with your own bucket name if needed
print(bucket)
prefix = 'blazingtext/supervised'  # Replace with the prefix under which you want to store the data if needed

sagemaker-us-east-1-869530972998


In [2]:
# download the training examples from dbpedia
!wget https://github.com/saurabh3949/Text-Classification-Datasets/raw/master/dbpedia_csv.tar.gz

--2019-11-25 14:25:47--  https://github.com/saurabh3949/Text-Classification-Datasets/raw/master/dbpedia_csv.tar.gz
Resolving github.com (github.com)... 140.82.113.4
Connecting to github.com (github.com)|140.82.113.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/saurabh3949/Text-Classification-Datasets/master/dbpedia_csv.tar.gz [following]
--2019-11-25 14:25:47--  https://raw.githubusercontent.com/saurabh3949/Text-Classification-Datasets/master/dbpedia_csv.tar.gz
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.248.133
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.248.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 68431223 (65M) [application/octet-stream]
Saving to: ‘dbpedia_csv.tar.gz.6’


2019-11-25 14:25:48 (236 MB/s) - ‘dbpedia_csv.tar.gz.6’ saved [68431223/68431223]



In [3]:
# Unpack the downloaded archive; -v lists every extracted file.
!tar -xzvf dbpedia_csv.tar.gz

dbpedia_csv/
dbpedia_csv/test.csv
dbpedia_csv/classes.txt
dbpedia_csv/train.csv
dbpedia_csv/readme.txt


In [4]:
!head dbpedia_csv/train.csv -n 3

1,"E. D. Abbott Ltd"," Abbott of Farnham E D Abbott Limited was a British coachbuilding business based in Farnham Surrey trading under that name from 1929. A major part of their output was under sub-contract to motor vehicle manufacturers. Their business closed in 1972."
1,"Schwan-Stabilo"," Schwan-STABILO is a German maker of pens for writing colouring and cosmetics as well as markers and highlighters for office use. It is the world's largest manufacturer of highlighter pens Stabilo Boss."
1,"Q-workshop"," Q-workshop is a Polish company located in Poznań that specializes in designand production of polyhedral dice and dice accessories for use in various games (role-playing gamesboard games and tabletop wargames). They also run an online retail store and maintainan active forum community.Q-workshop was established in 2001 by Patryk Strzelewicz – a student from Poznań. Initiallythe company sold its products via online auction services but in 2005 a website and online store wereestablis

In [5]:
# Print the class names; the next cell maps line N of this file to class index N (1-based).
!cat dbpedia_csv/classes.txt

Company
EducationalInstitution
Artist
Athlete
OfficeHolder
MeanOfTransportation
Building
NaturalPlace
Village
Animal
Plant
Album
Film
WrittenWork


In [6]:
# Build the lookup from the 1-based class index (as a string, matching the first
# CSV column) to the class name on the corresponding line of classes.txt.
index_to_label = {}
with open('dbpedia_csv/classes.txt', 'r') as classes_file:
    index_to_label.update(
        (str(line_number), class_name.strip())
        for line_number, class_name in enumerate(classes_file, start=1)
    )
    

In [7]:
# Display the mapping to sanity-check it against classes.txt.
index_to_label

{'1': 'Company',
 '2': 'EducationalInstitution',
 '3': 'Artist',
 '4': 'Athlete',
 '5': 'OfficeHolder',
 '6': 'MeanOfTransportation',
 '7': 'Building',
 '8': 'NaturalPlace',
 '9': 'Village',
 '10': 'Animal',
 '11': 'Plant',
 '12': 'Album',
 '13': 'Film',
 '14': 'WrittenWork'}

In [8]:
from random import shuffle
import multiprocessing
from multiprocessing import Pool
import csv
import nltk
# Fetch the 'punkt' data package for NLTK (presumably required by nltk.word_tokenize,
# which is used below — confirm for the installed NLTK version).
nltk.download('punkt')

[nltk_data] Downloading package punkt to /home/ec2-user/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [9]:
def transform_instance(row):
    """Turn one CSV row into a token list that starts with its ``__label__`` tag.

    Args:
        row: Sequence whose element 0 is the class index (a key of
            ``index_to_label``) and whose elements 1 and 2 are text fields.

    Returns:
        A list: ``"__label__<ClassName>"`` followed by the lower-cased NLTK
        tokens of ``row[1]`` and then of ``row[2]``.
    """
    tokens = ["__label__" + index_to_label[row[0]]]  # label tag always comes first
    for field_position in (1, 2):
        tokens.extend(nltk.word_tokenize(row[field_position].lower()))
    return tokens


In [10]:
def preprocess(input_file, output_file, keep=1):
    """Shuffle, subsample and tokenize a CSV file into a space-delimited text file.

    Every retained row is converted by ``transform_instance`` in a worker pool
    (one process per CPU) and written as one line of space-separated tokens.

    Args:
        input_file: Path of the comma-delimited source CSV.
        output_file: Path of the text file to create or overwrite.
        keep: Fraction of the shuffled rows to retain; the first
            ``int(keep * number_of_rows)`` rows after shuffling are kept.
    """
    # newline='' is what the csv module requires so embedded line breaks are
    # handled by the parser; the explicit encoding avoids locale-dependent
    # failures on the non-ASCII text present in the dataset.
    with open(input_file, 'r', newline='', encoding='utf-8') as csvinfile:
        all_rows = list(csv.reader(csvinfile, delimiter=','))

    shuffle(all_rows)
    all_rows = all_rows[:int(keep * len(all_rows))]

    # The context manager shuts the pool down even when a worker raises, so a
    # failed run does not leave worker processes behind.
    with Pool(processes=multiprocessing.cpu_count()) as pool:
        transformed_rows = pool.map(transform_instance, all_rows)

    with open(output_file, 'w', newline='', encoding='utf-8') as csvoutfile:
        csv_writer = csv.writer(csvoutfile, delimiter=' ', lineterminator='\n')
        csv_writer.writerows(transformed_rows)

In [11]:
def write_output(transformed_rows):
    """Write token rows to the file registered for their row count.

    The destination path is looked up in ``file_to_row_mapping`` under the key
    ``str(len(transformed_rows))``; a missing key raises ``KeyError``.

    Args:
        transformed_rows: List of token lists, written one per line with
            tokens separated by a single space.
    """
    destination = file_to_row_mapping[str(len(transformed_rows))]
    with open(destination, 'w') as out_handle:
        csv.writer(out_handle, delimiter=' ', lineterminator='\n').writerows(transformed_rows)

In [12]:
# Asynchronous preprocessing.
# Maps str(number of rows submitted) -> output path. write_output, which runs as the
# pool callback, finds its destination by the length of the result list it receives.
file_to_row_mapping = {}
def preprocess2(input_file, output_file, keep=1):
    """Start tokenizing a CSV file in the background and return the worker pool.

    Rows are shuffled, truncated to the ``keep`` fraction and submitted with
    ``map_async``; ``write_output`` is invoked as the callback with the
    transformed rows and writes them to ``output_file``.

    Args:
        input_file: Path of the comma-delimited source CSV.
        output_file: Path the callback should write the result to.
        keep: Fraction of the shuffled rows to retain.

    Returns:
        The ``Pool`` (already closed to new work). Call ``.join()`` on it to
        wait until the output file has been written.
    """
    with open(input_file, 'r', newline='', encoding='utf-8') as csvinfile:
        all_rows = list(csv.reader(csvinfile, delimiter=','))
    shuffle(all_rows)
    all_rows = all_rows[:int(keep*len(all_rows))]
    # Register under the row count, which is the key write_output looks up.
    # The previous key, str(all_rows), could never match that lookup, so the
    # callback raised KeyError and no file was produced.
    # NOTE: two inputs that keep the same number of rows share a key; the most
    # recently registered output path wins.
    file_to_row_mapping[str(len(all_rows))] = output_file
    pool = Pool(processes=multiprocessing.cpu_count())
    pool.map_async(transform_instance, all_rows, callback=write_output)
    pool.close()
    return pool

In [13]:
%%time
# Preparing the training dataset

# Since preprocessing the whole dataset might take a couple of minutes,
# we keep 20% of the training dataset for this demo.
# Set keep to 1 if you want to use the complete dataset
preprocess('dbpedia_csv/train.csv', 'dbpedia.train', keep=.2)

# Preparing the validation dataset (all rows are kept: keep defaults to 1)
preprocess('dbpedia_csv/test.csv', 'dbpedia.validation')


CPU times: user 13 s, sys: 1.04 s, total: 14 s
Wall time: 1min 27s


In [14]:
# S3 key prefixes for the two input channels, both under `prefix`.
train_channel = prefix + '/train'
validation_channel = prefix + '/validation'

# Upload the preprocessed files produced by preprocess() under their channel prefixes.
sess.upload_data(path='dbpedia.train', bucket=bucket, key_prefix=train_channel)
sess.upload_data(path='dbpedia.validation', bucket=bucket, key_prefix=validation_channel)

# The training inputs reference the channel prefixes, not the individual uploaded files.
s3_train_data = 's3://{}/{}'.format(bucket, train_channel)
s3_validation_data = 's3://{}/{}'.format(bucket, validation_channel)

In [15]:
# S3 location handed to the Estimator as output_path.
s3_output_location = 's3://{}/{}/output'.format(bucket, prefix)

In [16]:
# Region of the default boto3 session; used to look up the container image below.
region_name = boto3.Session().region_name

In [17]:
# Look up the BlazingText image URI for this region with the "latest" tag.
# NOTE(review): get_image_uri is the v1 SageMaker Python SDK API; in SDK v2 it is replaced
# by sagemaker.image_uris.retrieve — confirm which SDK version this notebook targets.
container = sagemaker.amazon.amazon_estimator.get_image_uri(region_name, "blazingtext", "latest")
print('Using SageMaker BlazingText container: {} ({})'.format(container, region_name))

Using SageMaker BlazingText container: 811284229777.dkr.ecr.us-east-1.amazonaws.com/blazingtext:latest (us-east-1)


In [18]:
from sagemaker.tuner import ContinuousParameter, IntegerParameter

# Search space handed to the HyperparameterTuner.
# NOTE(review): batch_size does not appear among the tunable hyperparameters AWS documents
# for BlazingText text classification (mode="supervised" is set on the estimator below) —
# confirm the tuning job accepts it, or swap in a documented one.
hyperparameter_ranges = {'learning_rate': ContinuousParameter(0.05, 0.06),
                         'batch_size' : IntegerParameter(10, 20)
                        }

In [19]:
# Name of the metric the tuner is asked to optimise (passed to HyperparameterTuner below).
objective_metric_name = 'validation:accuracy'


In [20]:
# Tuning budget passed to HyperparameterTuner: total jobs and concurrent jobs.
max_jobs = 5
max_parallel_jobs = 3

In [None]:
# Estimator for the BlazingText container; the tuner below is built around it.
# NOTE(review): train_volume_size is presumably in GB and train_max_run in seconds
# (360000 s = 100 h) — confirm against the SDK version in use.
bt_model = sagemaker.estimator.Estimator(container,
                                         role, 
                                         train_instance_count=1, 
                                         train_instance_type='ml.c4.4xlarge',
                                         train_volume_size = 30,
                                         train_max_run = 360000,
                                         input_mode= 'File',
                                         output_path=s3_output_location,
                                         sagemaker_session=sess)

In [None]:
# Fixed hyperparameters set on the estimator; learning_rate and batch_size are not set
# here because they are listed in hyperparameter_ranges for the tuner.
bt_model.set_hyperparameters(mode="supervised",
                            epochs=10,
                            min_count=2,
                            vector_dim=10,
                            early_stopping=True,
                            patience=4,
                            min_epochs=5,
                            word_ngrams=2)

In [None]:
from sagemaker.tuner import HyperparameterTuner

# Tuner over hyperparameter_ranges, scored by objective_metric_name, limited to
# max_jobs jobs in total with max_parallel_jobs running at once.
# NOTE(review): no objective_type is passed, so the SDK default applies — confirm it is
# 'Maximize', which is what an accuracy metric needs.
tuner = HyperparameterTuner(estimator=bt_model,
                                    objective_metric_name=objective_metric_name,
                                    hyperparameter_ranges=hyperparameter_ranges,
                                    max_jobs=max_jobs,
                                    max_parallel_jobs=max_parallel_jobs)

In [None]:
# Describe the two S3 prefixes as plain-text inputs and name them 'train' and 'validation'.
train_data = sagemaker.session.s3_input(s3_train_data, distribution='FullyReplicated', 
                        content_type='text/plain', s3_data_type='S3Prefix')
validation_data = sagemaker.session.s3_input(s3_validation_data, distribution='FullyReplicated', 
                             content_type='text/plain', s3_data_type='S3Prefix')
data_channels = {'train': train_data, 'validation': validation_data}

In [None]:
# Start the tuning job on the two channels, then wait for it via tuner.wait().
# NOTE(review): confirm that HyperparameterTuner.fit honours logs=True in the installed
# SDK version; it may be silently ignored.
tuner.fit(inputs=data_channels, logs=True)
tuner.wait()

In [None]:
# Deploy a model from the tuning run to a single-instance endpoint.
# NOTE(review): presumably this deploys the best training job's model — confirm.
text_classifier = tuner.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')

In [None]:
sentences = ["Convair was an american aircraft manufacturing company which later expanded into rockets and spacecraft.",
            "Berwick secondary college is situated in the outer melbourne metropolitan suburb of berwick ."]
# Apply the same preprocessing as during data preparation: transform_instance lower-cases
# the text before running the nltk tokenizer, so inference input must be lower-cased too.
# Otherwise capitalised words ("Convair", "Berwick") never match the training vocabulary.
tokenized_sentences = [' '.join(nltk.word_tokenize(sent.lower())) for sent in sentences]

In [None]:
# JSON request body: the tokenized sentences plus a configuration with k=2
# (presumably the number of labels returned per sentence — confirm).
payload = {"instances" : tokenized_sentences, "configuration": {"k": 2}}

response = text_classifier.predict(json.dumps(payload))

In [None]:
# Decode the JSON response and pretty-print it.
predictions = json.loads(response)
print(json.dumps(predictions, indent=2))

In [None]:
# Tear down the endpoint created by tuner.deploy once the predictions are no longer needed.
sess.delete_endpoint(text_classifier.endpoint)