In [1]:
import argparse

import apache_beam as beam
from apache_beam.io import ReadFromText
from apache_beam.io import WriteToText
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions

import pandas
import tensorflow_hub as hub
from joblib import load

In [2]:
# Load the Universal Sentence Encoder (v4) from TF Hub. The resulting `embed`
# callable is what the cells and DoFns below use to turn text into embeddings.
sentence_encoder = ('https://tfhub.dev/google/'
                    'universal-sentence-encoder/4')
embed = hub.load(handle=sentence_encoder)

In [3]:
# Load the pre-trained topic model from disk (a KMeans model, judging by the
# filename — confirm). joblib files are pickle-based: only load trusted files.
model = load(filename='kmeans_shakespeare.joblib')

In [4]:
# Example of how to format data for embedding and prediction

sample_text_string = "We know what we are, but know not what we may be"
# Build a one-row DataFrame from a list; its column 0 holds the text.
sample_text_dataframe = pandas.DataFrame([sample_text_string])
# Run sample text through the embedding model
embedding_sample = embed(sample_text_dataframe[0]).numpy().tolist()
# Predict the cluster the text belongs to; [0] is the label for the single row
predicted_topic = model.predict(embedding_sample)[0]

print('input text: {} | predicted topic: {}'.format(sample_text_string, predicted_topic))

input text: We know what we are, but know not what we may be | predicted topic: 4


In [5]:
# Another example to compare model output for a different input string

sample_text_string_2 = "Give every man thy ear, but few thy voice"
# Build a one-row DataFrame from a list; its column 0 holds the text.
sample_text_dataframe_2 = pandas.DataFrame([sample_text_string_2])
# Run sample text through the embedding model
embedding_sample_2 = embed(sample_text_dataframe_2[0]).numpy().tolist()
# Predict the cluster the text belongs to; [0] is the label for the single row
predicted_topic_2 = model.predict(embedding_sample_2)[0]

print('input text: {} | predicted topic: {}'.format(sample_text_string_2, predicted_topic_2))

input text: Give every man thy ear, but few thy voice | predicted topic: 3


In [6]:
class CleanTextDoFn(beam.DoFn):
  """Replace tabs with spaces and drop blank or too-short lines.

  A line is emitted only if, after tab replacement, it contains at least one
  non-whitespace character and is at least ``min_length`` characters long.
  """

  def __init__(self, min_length=2):
    """
    Args:
      min_length: minimum length (in characters, after tab replacement) a
        line must have to be kept. The default of 2 drops empty and
        single-character lines.
    """
    super().__init__()
    self._min_length = min_length

  def process(self, element):
    """Clean a single input line.

    Args:
      element: one line of text read from the input file.
    Yields:
      The line with every tab replaced by a space, unless it is blank
      (whitespace only) or shorter than ``min_length`` characters.
    """
    cleaned_element = element.replace('\t', ' ')

    # Whitespace-only lines are effectively empty; without this check they
    # would pass the length test and be sent on to the embedder.
    if not cleaned_element.strip():
      return

    if len(cleaned_element) >= self._min_length:
      yield cleaned_element

In [7]:
class EmbeddingDoFn(beam.DoFn):
  """Generate a sentence embedding for each non-empty input line."""

  def process(self, element):
    """Embed a single line of text.

    Args:
      element: one line of text.
    Yields:
      A dict with keys 'input_text' (the line itself) and 'input_embedding'
      (the encoder output converted with ``.numpy().tolist()``). Nothing is
      yielded for an empty line.
    """
    # Skip empty input before calling the encoder, so no embedding work is
    # spent on a value that would be discarded anyway.
    if not element:
      return

    # Build a one-row DataFrame from a list; its column 0 holds the line.
    input_dataframe = pandas.DataFrame([element])
    # NOTE(review): `embed` is the module-level encoder loaded in an earlier
    # cell. This relies on it being visible where the DoFn executes; on a
    # remote runner it would likely need to be loaded per worker (e.g. in
    # DoFn.setup) — confirm before deploying.
    input_embedding = embed(input_dataframe[0]).numpy().tolist()

    yield {
        'input_text': element,
        'input_embedding': input_embedding,
    }

In [8]:
class PredictTopicDoFn(beam.DoFn):
  """Predict a topic label for each embedded line."""

  def process(self, element):
    """Predict the topic of a single embedded line.

    Args:
      element: a dict with keys 'input_text' and 'input_embedding', in the
        shape emitted by EmbeddingDoFn.
    Yields:
      A dict with keys 'text' (the input line) and 'predicted_topic' (the
      first label returned by ``model.predict``).
    """
    # NOTE(review): `model` is the module-level model loaded in an earlier
    # cell; on a remote runner it would likely need to be loaded per worker
    # (e.g. in DoFn.setup) — confirm before deploying.
    labels = model.predict(element['input_embedding'])

    # The embedding comes from a one-row batch, so the first label is the
    # prediction for this line.
    topic = labels[0]
    # WriteToText writes str() of the dict, which uses repr() of each value;
    # NumPy 2 scalars repr as e.g. `np.int32(4)`. Convert to a plain value.
    if hasattr(topic, 'item'):
      topic = topic.item()

    yield {
        'text': element['input_text'],
        'predicted_topic': topic,
    }

In [11]:
def run(argv=None):
  """Build and execute the topic-prediction pipeline.

  Reads lines from ``--input``, cleans them, embeds each line, predicts a
  topic for it, and writes the resulting records to ``--output``.

  Args:
    argv: command-line arguments. ``--input`` and ``--output`` are consumed
      here; everything else is forwarded to Beam's ``PipelineOptions``. When
      None, argparse reads ``sys.argv[1:]``.
  """
  arg_parser = argparse.ArgumentParser()
  arg_parser.add_argument(
      '--input',
      dest='input',
      default='./data/kinglear.txt',
      help='Input file to process.')
  arg_parser.add_argument(
      '--output',
      dest='output',
      default='output.txt',
      help='Output file to write results to.')
  args, beam_args = arg_parser.parse_known_args(argv)

  # NOTE(review): SetupOptions is imported in this notebook but never used.
  # If the DoFns' reliance on notebook globals must survive a remote runner,
  # consider `options.view_as(SetupOptions).save_main_session = True` — confirm.
  options = PipelineOptions(beam_args)

  # Exiting the `with` block runs the pipeline and waits for it to finish.
  with beam.Pipeline(options=options) as pipeline:
    (pipeline
     | 'Read' >> ReadFromText(args.input)
     | 'Clean' >> beam.ParDo(CleanTextDoFn())
     | 'Embed' >> beam.ParDo(EmbeddingDoFn())
     | 'Predict Topic' >> beam.ParDo(PredictTopicDoFn())
     | 'Write' >> WriteToText(args.output))

In [12]:
# Pass an explicit empty argv so the pipeline runs with the defaults declared
# in run(). With argv=None, argparse would read sys.argv[1:] — in a notebook,
# the kernel launcher's own arguments — and forward them to PipelineOptions.
run(argv=[])

