Skip to content

Commit

Permalink
Only create ssl context once (#62)
Browse files Browse the repository at this point in the history
Only create ssl context once
  • Loading branch information
bdraco committed Sep 6, 2023
1 parent 8944299 commit dc2a369
Showing 1 changed file with 31 additions and 15 deletions.
46 changes: 31 additions & 15 deletions sense_energy/asyncsenseable.py
@@ -1,6 +1,7 @@
import asyncio
import ssl
import sys
from functools import lru_cache
from time import time

import aiohttp
Expand All @@ -15,6 +16,21 @@
else:
from asyncio import timeout as asyncio_timeout


@lru_cache(maxsize=None)
def get_ssl_context(ssl_verify: bool, ssl_cafile: str) -> ssl.SSLContext:
"""Create or set the SSL context. Use custom ssl verification, if specified."""
if not ssl_verify:
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
elif ssl_cafile:
ssl_context = ssl.create_default_context(cafile=ssl_cafile)
else:
ssl_context = ssl.create_default_context()
return ssl_context


class ASyncSenseable(SenseableBase):
def __init__(
self,
Expand Down Expand Up @@ -42,14 +58,7 @@ def __init__(

def set_ssl_context(self, ssl_verify, ssl_cafile):
"""Create or set the SSL context. Use custom ssl verification, if specified."""
if not ssl_verify:
self.ssl_context = ssl.create_default_context()
self.ssl_context.check_hostname = False
self.ssl_context.verify_mode = ssl.CERT_NONE
elif ssl_cafile:
self.ssl_context = ssl.create_default_context(cafile=ssl_cafile)
else:
self.ssl_context = ssl.create_default_context()
self.ssl_context = get_ssl_context(ssl_verify, ssl_cafile)

async def authenticate(self, username, password, ssl_verify=True, ssl_cafile=""):
"""Authenticate with username (email) and password. Optionally set SSL context as well.
Expand All @@ -59,9 +68,11 @@ async def authenticate(self, username, password, ssl_verify=True, ssl_cafile="")

# Get auth token
async with self._client_session.post(
API_URL + "authenticate", headers=self.headers, timeout=self.api_timeout, data=auth_data
API_URL + "authenticate",
headers=self.headers,
timeout=self.api_timeout,
data=auth_data,
) as resp:

# check MFA code required
if resp.status == 401:
data = await resp.json()
Expand Down Expand Up @@ -91,9 +102,11 @@ async def validate_mfa(self, code):

# Get auth token
async with self._client_session.post(
API_URL + "authenticate/mfa", headers=self.headers, timeout=self.api_timeout, data=mfa_data
API_URL + "authenticate/mfa",
headers=self.headers,
timeout=self.api_timeout,
data=mfa_data,
) as resp:

# check for 200 return
if resp.status != 200:
raise SenseAuthenticationException(f"API Return Code: {resp.status}")
Expand All @@ -111,9 +124,11 @@ async def renew_auth(self):

# Get auth token
async with self._client_session.post(
API_URL + "renew", headers=self.headers, timeout=self.api_timeout, data=renew_data
API_URL + "renew",
headers=self.headers,
timeout=self.api_timeout,
data=renew_data,
) as resp:

# check for 200 return
if resp.status != 200:
raise SenseAuthenticationException(f"API Return Code: {resp.status}")
Expand Down Expand Up @@ -145,7 +160,8 @@ async def update_realtime(self, retry=True):

async def async_realtime_stream(self, callback=None, single=False):
"""Reads realtime data from websocket. Data is passed to callback if available.
Continues reading realtime stream data forever unless 'single' is set to True."""
Continues reading realtime stream data forever unless 'single' is set to True.
"""
url = WS_URL % (self.sense_monitor_id, self.sense_access_token)
# hello, features, [updates,] data
async with websockets.connect(url, ssl=self.ssl_context) as ws:
Expand Down

0 comments on commit dc2a369

Please sign in to comment.