# License

Copyright 2019 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");  
you may not use this file except in compliance with the License.  
You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing,  
software distributed under the License is distributed on an "AS IS" BASIS,  
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
See the License for the specific language governing permissions and  
limitations under the License.

# Pre-work

In [0]:
# @title Upload files (skip this if this is run locally)

# Use this cell to upload the following files to the Colab runtime:
#   1. requirements.txt          -- dependencies installed by the next cell
#   2. e2e_demo_credential.json  -- credentials file (GOOGLE_APPLICATION_CREDENTIALS)
#   3. roots.pem                 -- CA certificates for the MQTT TLS connection
#   4. rs256.key                 -- private key used to sign the device JWTs
# NOTE(review): if a file with the same name was already uploaded, the new copy
# may be saved under a different name -- make sure the paths in the
# Configurations cell match the files actually present.
from google.colab import files
# Opens the Colab upload dialog; `uploaded` is presumably a mapping of file name
# to contents (it is not used further in this notebook).
uploaded = files.upload()

In [3]:
# @title Install missing packages

# Run this cell to install packages if some are missing: everything listed in
# the uploaded requirements.txt, plus the libraries the code below imports
# directly -- gcsfs (reads the CSV from GCS in load_data), PyJWT (signs the
# device tokens in Publisher.create_jwt) and paho-mqtt (the MQTT client).
# The leading `!` runs each line as a shell command from the notebook.
!pip install -r requirements.txt
!pip install gcsfs
!pip install pyjwt
!pip install paho-mqtt



In [0]:
# @title Import libraries

import json
import os
import gcsfs
import numpy as np
import pandas as pd
import time
import google.cloud.bigquery
import datetime
import random
import ssl
import jwt
import paho.mqtt.client as mqtt
import pandas as pd
import json
import sys

In [0]:
# @title Configurations

# project related
GOOGLE_CLOUD_PROJECT = 'my-project-fy' #@param
# Credentials file uploaded in the first cell, exported through the standard
# GOOGLE_APPLICATION_CREDENTIALS environment variable so Google Cloud client
# libraries can find it.
GOOGLE_APPLICATION_CREDENTIALS = 'e2e_demo_credential.json' #@param
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = GOOGLE_APPLICATION_CREDENTIALS

# IoT related
REGION = 'asia-east1' #@param
REGISTRY_ID = 'fy-project-registry' #@param
# One Publisher (simulated meter) is started per device ID.
DEVICE_IDS = ['fy-raspi'] #@param
# Private key used to sign the device JWTs. It must name the key file uploaded
# in the first cell (rs256.key). NOTE(review): if the key was uploaded more than
# once it may have been saved under another name such as 'rs256 (2).key' --
# adjust this path accordingly.
PRIVATE_KEY_PATH = 'rs256.key' #@param
# CA certificates used for the TLS connection to the MQTT bridge.
CA_CERTIFICATES_PATH = 'roots.pem' #@param

# data related
# CSV with the test power readings, read by load_data.
POWER_DATA_PATH = 'gs://gcp_blog/e2e_demo/test.csv'
# BigQuery dataset name. NOTE(review): not referenced by the cells shown here.
DATASET_ID = 'EnergyDisaggregation'

# Start to publish data

When the following cells are run, the program mimics multiple electric meters, each of which sends readings at the specified interval to a server running on Google App Engine (GAE).  
The GAE server forwards the readings to Cloud ML Engine (CMLE), where a trained model is deployed, and then publishes the data together with CMLE's predictions to Pub/Sub. Both the data and the predictions are also saved in BigQuery for future analysis.

In [0]:
# @title Utility

def load_data(project, path, column='gross'):
  """Load test power data from a CSV file on GCS.

  The CSV's first column becomes the index (index_col=0) and a single column
  is returned.

  Args:
    project: str, Google Cloud project ID used to access GCS.
    path: str, GCS path of the CSV file, e.g. 'gs://bucket/file.csv'.
    column: str, name of the column to return. Defaults to 'gross'.

  Returns:
    pandas.Series, the requested column (active power data), indexed by the
    CSV's first column.
  """
  fs = gcsfs.GCSFileSystem(project=project)
  with fs.open(path) as f:
    power_data = pd.read_csv(f, index_col=0)[column]
  return power_data

# The maximum backoff time before giving up, in seconds.
MAXIMUM_BACKOFF_TIME = 32

class Publisher:
    """Publish power data to a device's MQTT topic every n seconds.

    Each instance acts as one device. `start()` connects to the MQTT bridge;
    from then on every successful connect (`on_connect`) and every publish
    reported by paho (`on_publish`, after sleeping `interval` seconds) calls
    `publish()`, which sends the next record. Publishing ends when the records
    run out or `stop()` is called.
    """

    # The initial backoff time after a disconnection occurs, in seconds.
    minimum_backoff_time = 1

    # Whether to wait with exponential backoff before publishing.
    should_backoff = False

    def __init__(self,
            project_id,
            region,
            registry_id,
            device_id,
            private_key_path,
            jwt_algorithm,
            jwt_exp_mins,
            ca_certs_path,
            mqtt_bridge_hostname,
            mqtt_bridge_port,
            seq_len,
            sub_topic='events',
            interval=1):
        """Store the connection settings; nothing connects until start().

        Args:
            project_id: str, GCP project ID (also used as the JWT audience).
            region: str, location of the device registry.
            registry_id: str, device registry ID.
            device_id: str, ID of the device this publisher acts as.
            private_key_path: str, path of the private key that signs JWTs.
            jwt_algorithm: str, JWT signing algorithm, e.g. 'RS256'.
            jwt_exp_mins: int, minutes after which a fresh JWT is created
                (must stay below the 60-minute token lifetime).
            ca_certs_path: str, CA certificates file for TLS.
            mqtt_bridge_hostname: str, MQTT bridge host.
            mqtt_bridge_port: int, MQTT bridge port.
            seq_len: int, number of readings carried in every payload.
            sub_topic: str, topic under /devices/<device_id>/ to publish to.
            interval: float, seconds to wait between publishes.
        """
        self.project_id = project_id
        self.region = region
        self.registry_id = registry_id
        self.device_id = device_id
        self.private_key_path = private_key_path
        self.jwt_algorithm = jwt_algorithm
        self.jwt_exp_mins = jwt_exp_mins
        self.ca_certs_path = ca_certs_path
        self.mqtt_bridge_hostname = mqtt_bridge_hostname
        self.mqtt_bridge_port = mqtt_bridge_port
        self.seq_len = seq_len
        self.sub_topic = sub_topic
        self.interval = interval

        self.client = None  # MQTT client, created by start()
        self._count = 0  # number of records published so far
        self._data = []  # (reading, timestamp) pairs to send, indexed by _count
        self._msg = {}  # sliding-window payload, initialised by start()
        self._mqtt_topic = '/devices/{}/{}'.format(self.device_id, self.sub_topic)
        self._jwt_iat = datetime.datetime.utcnow()
        self._active = False

    def create_jwt(self):
        """Creates a JWT (https://jwt.io) to establish an MQTT connection.

        Returns:
            A JWT signed with the key at `private_key_path` using
            `jwt_algorithm`, with the GCP project ID as audience. The token is
            valid for 60 minutes; `publish` replaces the client (and therefore
            the token) every `jwt_exp_mins` minutes, so keep that value below 60.

        Raises:
            IOError/OSError: If the private key file cannot be read.
            Whatever `jwt.encode` raises if the key is unusable with the
            chosen algorithm.
        """
        now = datetime.datetime.utcnow()
        token = {
            # The time that the token was issued at.
            'iat': now,
            # The time the token expires.
            'exp': now + datetime.timedelta(minutes=60),
            # The audience field should always be set to the GCP project id.
            'aud': self.project_id,
        }

        # Read the private key file.
        with open(self.private_key_path, 'r') as f:
            private_key = f.read()

        print('Creating JWT using {} from private key file {}'.format(
                self.jwt_algorithm, self.private_key_path))

        return jwt.encode(token, private_key, algorithm=self.jwt_algorithm)

    def get_client(self):
        """Create an MQTT client for this device (not yet connected).

        The client_id uniquely identifies this device and, for Google Cloud IoT
        Core, must have the format used below. The client authenticates with a
        freshly created JWT over TLS and has this instance's callbacks attached.
        """
        client = mqtt.Client(
                client_id=('projects/{}/locations/{}/registries/{}/devices/{}'
                           .format(
                                   self.project_id,
                                   self.region,
                                   self.registry_id,
                                   self.device_id)))

        # With Google Cloud IoT Core, the username field is ignored, and the
        # password field is used to transmit a JWT to authorize the device.
        client.username_pw_set(
                username='unused',
                password=self.create_jwt())

        # Enable SSL/TLS support.
        client.tls_set(ca_certs=self.ca_certs_path, tls_version=ssl.PROTOCOL_TLSv1_2)

        # Register message callbacks. https://eclipse.org/paho/clients/python/docs/
        # describes additional callbacks that Paho supports.
        client.on_connect = self.on_connect
        client.on_publish = self.on_publish
        client.on_disconnect = self.on_disconnect

        return client

    def publish(self):
        """Publish the next record; the paho callbacks schedule the one after.

        Called from `on_connect` and `on_publish`, i.e. on the network thread
        started by `loop_start()`.
        """
        # Finish if inactive.
        if not self._active:
            return

        # Finish once every prepared record has been sent.
        if self._count >= len(self._data):
            print('Finished publishing {} records on {}.'.format(
                    self._count, self.device_id))
            self._active = False
            return

        # No manual `self.client.loop()` here: the network loop already runs in
        # the background thread started by `loop_start()` (paho advises against
        # mixing the two), and a manual call would block for up to a second.

        # Wait if backoff is required.
        if Publisher.should_backoff:
            # If backoff time is too large, give up.
            if Publisher.minimum_backoff_time > MAXIMUM_BACKOFF_TIME:
                print('Exceeded maximum backoff time. Giving up.')
                return
            # Otherwise, wait and connect again.
            delay = Publisher.minimum_backoff_time + random.randint(0, 1000) / 1000.0
            print('Waiting for {} before reconnecting.'.format(delay))
            time.sleep(delay)
            Publisher.minimum_backoff_time *= 2
            self.client.connect(self.mqtt_bridge_hostname, self.mqtt_bridge_port)

        # Refresh the token once the JWT refresh period has elapsed.
        # total_seconds() rather than .seconds, which drops whole days.
        seconds_since_issue = (
                datetime.datetime.utcnow() - self._jwt_iat).total_seconds()
        if seconds_since_issue > 60 * self.jwt_exp_mins:
            print('Refreshing token after {:.0f}s'.format(seconds_since_issue))
            self._jwt_iat = datetime.datetime.utcnow()
            self._replace_client()
            # The new client's on_connect resumes with this same record; do not
            # publish on a client that has not finished connecting.
            return

        # Generate payload.
        d, t = self._data[self._count]
        Publisher.rotate_message(self._msg, d, t)
        payload = json.dumps(self._msg).encode('utf-8')

        # Publish "payload" to the MQTT topic. qos=1 means at least once
        # delivery. Cloud IoT Core also supports qos=0 for at most once
        # delivery.
        self.client.publish(self._mqtt_topic, payload, qos=1)
        print('Published: #{0:03d} on {1} sent: ({2}, {3}).'.format(self._count+1, self.device_id, self._msg['power'][-1], self._msg['timestamp'][-1]))
        self._count += 1

    def _replace_client(self):
        """Tear down the current client and connect a new one with a fresh JWT."""
        old_client = self.client
        # Swap first: the callbacks ignore events from any client that is no
        # longer self.client, so the old connection's teardown can neither
        # trigger a backoff nor start a second publish chain.
        self.client = self.get_client()
        old_client.disconnect()
        old_client.loop_stop()

        self.client.loop_start()
        self.client.connect(self.mqtt_bridge_hostname, self.mqtt_bridge_port)

    def start(self, data, start_time=None, cnt=0):
        """Connect to the MQTT bridge and begin publishing `data`.

        Args:
            data: pandas.Series of active power readings whose index holds the
                measurement timestamps.
            start_time: str, publish data collected from start_time.
            cnt: int, number of data records to publish. If cnt<=0, publish all.
        """
        self._prepare_records(data, start_time, cnt)
        self._count = 0
        self.client = self.get_client()
        self._active = True

        self.client.loop_start()

        # Connect to the Google MQTT bridge; on_connect triggers the first publish.
        self.client.connect(self.mqtt_bridge_hostname, self.mqtt_bridge_port)

    def _prepare_records(self, data, start_time, cnt):
        """Build the record list and the initial sliding-window message."""
        data = Publisher.trim_data(data, start_time)
        readings = data.values.tolist()
        timestamps = data.index.values.tolist()

        # The first seq_len-1 readings only seed the window. A leading 0
        # placeholder is rotated out by the first publish, so every payload
        # carries exactly seq_len readings.
        head = self.seq_len - 1
        records = list(zip(readings[head:], timestamps[head:]))
        if cnt > 0:
            records = records[:cnt]

        self._data = records
        self._msg = {'device_id': self.device_id,
                     'power': [0] + readings[:head],
                     'timestamp': [0] + timestamps[:head]}

    def stop(self):
        """Stop publishing and close the MQTT connection (if one was started)."""
        print("Stop loop")
        self._active = False
        if self.client is None:
            # start() was never called: nothing to tear down.
            return
        self.client.loop_stop()
        self.client.disconnect()

    @staticmethod
    def rotate_message(msg, data, time):
        """Slide the payload window one step, in place.

        Drops the oldest entry of msg['power'] and msg['timestamp'] and appends
        the new reading and timestamp, so both lists keep their length.

        Args:
            msg: dict with 'power' and 'timestamp' lists; modified in place.
            data: active power reading to append.
            time: timestamp of the measurement to append.
        Returns:
            None.
        """
        msg['power'].pop(0)
        msg['power'].append(data)
        msg['timestamp'].pop(0)
        msg['timestamp'].append(time)

    @staticmethod
    def trim_data(data, start_time):
        """Trims data with start time.

        Args:
            data: pandas.Series, active power data indexed by timestamp.
            start_time: str, keep rows whose index is >= start_time; None keeps all.
        Returns:
            The trimmed data.
        """
        if start_time is not None:
            print('before data trimming: data.shape={}'.format(data.shape))
            data = data[data.index.values >= start_time]
            print('after data trimming: data.shape={}'.format(data.shape))
        return data

    def error_str(self, rc):
        """Convert a Paho error to a human readable string."""
        return '{}: {}'.format(rc, mqtt.error_string(rc))

    def on_connect(self, unused_client, unused_userdata, unused_flags, rc):
        """Paho callback for when a device connects; starts publishing on success."""
        # Paho passes the reporting client first (the parameter name is kept for
        # compatibility). Ignore events from a client replaced on token refresh.
        if unused_client is not self.client:
            return
        print('connected: {}'.format(mqtt.connack_string(rc)))

        if rc != 0:
            # Connection refused: keep the backoff state and do not publish.
            return

        # After a successful connect, reset backoff time and stop backing off.
        Publisher.should_backoff = False
        Publisher.minimum_backoff_time = 1

        # Trigger initial publish.
        self.publish()

    def on_disconnect(self, unused_client, unused_userdata, rc):
        """Paho callback for when a device disconnects."""
        # Ignore the teardown of a client replaced on token refresh.
        if unused_client is not self.client:
            return
        print('disconnected because: {}'.format(self.error_str(rc)))

        # Since a disconnect occurred, the next loop iteration will wait with
        # exponential backoff.
        Publisher.should_backoff = True

    def on_publish(self, unused_client, unused_userdata, unused_mid):
        """Paho callback when a message is sent to the broker; schedules the next."""
        if unused_client is not self.client:
            return
        time.sleep(self.interval)
        self.publish()

In [7]:
# @title Load power consumption data

# Active power readings from the CSV on GCS (see load_data).
power_data = load_data(GOOGLE_CLOUD_PROJECT, POWER_DATA_PATH)
# print() call form works on both Python 2 and 3 (the Py2 print statement
# is a SyntaxError on Python 3).
print(power_data.shape)

(432000,)


In [9]:
# @title Send power data

# One Publisher per simulated meter (device ID); each publishes from the paho
# network thread it starts in start().
publishers = []
try:
    for device_id in DEVICE_IDS:
        publisher = Publisher(
                project_id=GOOGLE_CLOUD_PROJECT,
                region=REGION,
                registry_id=REGISTRY_ID,
                device_id=device_id,
                private_key_path=PRIVATE_KEY_PATH,
                ca_certs_path=CA_CERTIFICATES_PATH,
                jwt_exp_mins=20,
                jwt_algorithm='RS256',
                mqtt_bridge_hostname='mqtt.googleapis.com',
                mqtt_bridge_port=8883,
                seq_len=20,
                sub_topic='events')

        publisher.start(data=power_data, start_time='2013-09-21 20:18:00')
        publishers.append(publisher)

    # Keep the cell running while any publisher is still active; interrupt the
    # cell (KeyboardInterrupt) to stop early.
    while any(publisher._active for publisher in publishers):
        time.sleep(1)
except KeyboardInterrupt:
    pass
finally:
    # Always release the MQTT connections, whatever ended the wait (interrupt,
    # completion, or an error while starting a later publisher).
    for publisher in publishers:
        publisher.stop()

Published: #120 on fy-raspi sent: (544, 2013-09-21 20:31:49).
Published: #121 on fy-raspi sent: (548, 2013-09-21 20:31:55).
Stop loop
