Skip to content

Commit

Permalink
Add emeter support for strip sockets (#203)
Browse files Browse the repository at this point in the history
* Add support for plugs with emeters.

* Tweaks for emeter

* black

* tweaks

* tweaks

* more tweaks

* dry

* flake8

* flake8

* legacy typing

* Update kasa/smartstrip.py

Co-authored-by: Teemu R. <tpr@iki.fi>

* reduce

* remove useless delegation

* tweaks

* tweaks

* dry

* tweak

* tweak

* tweak

* tweak

* update tests

* wrap

* preen

* prune

* prune

* prune

* guard

* adjust

* robust

* prune

* prune

* reduce dict lookups by 1

* Update kasa/smartstrip.py

Co-authored-by: Teemu R. <tpr@iki.fi>

* delete utils

* isort

Co-authored-by: Brendan Burns <brendan.d.burns@gmail.com>
Co-authored-by: Teemu R. <tpr@iki.fi>
  • Loading branch information
3 people committed Sep 23, 2021
1 parent d720288 commit 94e5a90
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 101 deletions.
83 changes: 37 additions & 46 deletions kasa/smartdevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
You may obtain a copy of the license at
http://www.apache.org/licenses/LICENSE-2.0
"""
import collections.abc
import functools
import inspect
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum, auto
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Set

from .emeterstatus import EmeterStatus
from .exceptions import SmartDeviceException
Expand Down Expand Up @@ -51,6 +52,16 @@ class WifiNetwork:
rssi: Optional[int] = None


def merge(d, u):
"""Update dict recursively."""
for k, v in u.items():
if isinstance(v, collections.abc.Mapping):
d[k] = merge(d.get(k, {}), v)
else:
d[k] = v
return d


def requires_update(f):
"""Indicate that `update` should be called before accessing this method.""" # noqa: D202
if inspect.iscoroutinefunction(f):
Expand Down Expand Up @@ -204,6 +215,11 @@ def _create_request(

return request

def _verify_emeter(self) -> None:
"""Raise an exception if there is no emeter."""
if not self.has_emeter:
raise SmartDeviceException("Device has no emeter")

async def _query_helper(
self, target: str, cmd: str, arg: Optional[Dict] = None, child_ids=None
) -> Any:
Expand Down Expand Up @@ -240,13 +256,17 @@ async def _query_helper(

return result

@property # type: ignore
@requires_update
def features(self) -> Set[str]:
"""Return a set of features that the device supports."""
return set(self.sys_info["feature"].split(":"))

@property # type: ignore
@requires_update
def has_emeter(self) -> bool:
"""Return True if device has an energy meter."""
sys_info = self.sys_info
features = sys_info["feature"].split(":")
return "ENE" in features
return "ENE" in self.features

async def get_sys_info(self) -> Dict[str, Any]:
"""Retrieve system information."""
Expand Down Expand Up @@ -374,10 +394,8 @@ def location(self) -> Dict:
@requires_update
def rssi(self) -> Optional[int]:
"""Return WiFi signal strenth (rssi)."""
sys_info = self.sys_info
if "rssi" in sys_info:
return int(sys_info["rssi"])
return None
rssi = self.sys_info.get("rssi")
return None if rssi is None else int(rssi)

@property # type: ignore
@requires_update
Expand Down Expand Up @@ -410,16 +428,12 @@ async def set_mac(self, mac):
@requires_update
def emeter_realtime(self) -> EmeterStatus:
"""Return current energy readings."""
if not self.has_emeter:
raise SmartDeviceException("Device has no emeter")

self._verify_emeter()
return EmeterStatus(self._last_update[self.emeter_type]["get_realtime"])

async def get_emeter_realtime(self) -> EmeterStatus:
"""Retrieve current energy readings."""
if not self.has_emeter:
raise SmartDeviceException("Device has no emeter")

self._verify_emeter()
return EmeterStatus(await self._query_helper(self.emeter_type, "get_realtime"))

def _create_emeter_request(self, year: int = None, month: int = None):
Expand All @@ -429,23 +443,12 @@ def _create_emeter_request(self, year: int = None, month: int = None):
if month is None:
month = datetime.now().month

import collections.abc

def update(d, u):
"""Update dict recursively."""
for k, v in u.items():
if isinstance(v, collections.abc.Mapping):
d[k] = update(d.get(k, {}), v)
else:
d[k] = v
return d

req: Dict[str, Any] = {}
update(req, self._create_request(self.emeter_type, "get_realtime"))
update(
merge(req, self._create_request(self.emeter_type, "get_realtime"))
merge(
req, self._create_request(self.emeter_type, "get_monthstat", {"year": year})
)
update(
merge(
req,
self._create_request(
self.emeter_type, "get_daystat", {"month": month, "year": year}
Expand All @@ -458,9 +461,7 @@ def update(d, u):
@requires_update
def emeter_today(self) -> Optional[float]:
"""Return today's energy consumption in kWh."""
if not self.has_emeter:
raise SmartDeviceException("Device has no emeter")

self._verify_emeter()
raw_data = self._last_update[self.emeter_type]["get_daystat"]["day_list"]
data = self._emeter_convert_emeter_data(raw_data)
today = datetime.now().day
Expand All @@ -474,9 +475,7 @@ def emeter_today(self) -> Optional[float]:
@requires_update
def emeter_this_month(self) -> Optional[float]:
"""Return this month's energy consumption in kWh."""
if not self.has_emeter:
raise SmartDeviceException("Device has no emeter")

self._verify_emeter()
raw_data = self._last_update[self.emeter_type]["get_monthstat"]["month_list"]
data = self._emeter_convert_emeter_data(raw_data)
current_month = datetime.now().month
Expand Down Expand Up @@ -516,9 +515,7 @@ async def get_emeter_daily(
:param kwh: return usage in kWh (default: True)
:return: mapping of day of month to value
"""
if not self.has_emeter:
raise SmartDeviceException("Device has no emeter")

self._verify_emeter()
if year is None:
year = datetime.now().year
if month is None:
Expand All @@ -538,9 +535,7 @@ async def get_emeter_monthly(self, year: int = None, kwh: bool = True) -> Dict:
:param kwh: return usage in kWh (default: True)
:return: dict: mapping of month to value
"""
if not self.has_emeter:
raise SmartDeviceException("Device has no emeter")

self._verify_emeter()
if year is None:
year = datetime.now().year

Expand All @@ -553,17 +548,13 @@ async def get_emeter_monthly(self, year: int = None, kwh: bool = True) -> Dict:
@requires_update
async def erase_emeter_stats(self) -> Dict:
"""Erase energy meter statistics."""
if not self.has_emeter:
raise SmartDeviceException("Device has no emeter")

self._verify_emeter()
return await self._query_helper(self.emeter_type, "erase_emeter_stat", None)

@requires_update
async def current_consumption(self) -> float:
"""Get the current power consumption in Watt."""
if not self.has_emeter:
raise SmartDeviceException("Device has no emeter")

self._verify_emeter()
response = EmeterStatus(await self.get_emeter_realtime())
return float(response["power"])

Expand Down
112 changes: 73 additions & 39 deletions kasa/smartstrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from kasa.smartdevice import (
DeviceType,
EmeterStatus,
SmartDevice,
SmartDeviceException,
requires_update,
Expand All @@ -15,6 +16,15 @@
_LOGGER = logging.getLogger(__name__)


def merge_sums(dicts):
"""Merge the sum of dicts."""
total_dict: DefaultDict[int, float] = defaultdict(lambda: 0.0)
for sum_dict in dicts:
for day, value in sum_dict.items():
total_dict[day] += value
return total_dict


class SmartStrip(SmartDevice):
"""Representation of a TP-Link Smart Power Strip.
Expand Down Expand Up @@ -75,11 +85,7 @@ def __init__(self, host: str) -> None:
@requires_update
def is_on(self) -> bool:
"""Return if any of the outlets are on."""
for plug in self.children:
is_on = plug.is_on
if is_on:
return True
return False
return any(plug.is_on for plug in self.children)

async def update(self):
"""Update some of the attributes.
Expand All @@ -97,6 +103,10 @@ async def update(self):
SmartStripPlug(self.host, parent=self, child_id=child["id"])
)

if self.has_emeter:
for plug in self.children:
await plug.update()

async def turn_on(self, **kwargs):
"""Turn the strip on."""
await self._query_helper("system", "set_relay_state", {"state": 1})
Expand Down Expand Up @@ -140,16 +150,16 @@ def state_information(self) -> Dict[str, Any]:

async def current_consumption(self) -> float:
"""Get the current power consumption in watts."""
consumption = sum(await plug.current_consumption() for plug in self.children)

return consumption
return sum([await plug.current_consumption() for plug in self.children])

async def set_alias(self, alias: str) -> None:
"""Set the alias for the strip.
:param alias: new alias
"""
return await super().set_alias(alias)
@requires_update
async def get_emeter_realtime(self) -> EmeterStatus:
"""Retrieve current energy readings."""
emeter_rt = await self._async_get_emeter_sum("get_emeter_realtime", {})
# Voltage is averaged since each read will result
# in a slightly different voltage since they are not atomic
emeter_rt["voltage_mv"] = int(emeter_rt["voltage_mv"] / len(self.children))
return EmeterStatus(emeter_rt)

@requires_update
async def get_emeter_daily(
Expand All @@ -163,14 +173,9 @@ async def get_emeter_daily(
:param kwh: return usage in kWh (default: True)
:return: mapping of day of month to value
"""
emeter_daily: DefaultDict[int, float] = defaultdict(lambda: 0.0)
for plug in self.children:
plug_emeter_daily = await plug.get_emeter_daily(
year=year, month=month, kwh=kwh
)
for day, value in plug_emeter_daily.items():
emeter_daily[day] += value
return emeter_daily
return await self._async_get_emeter_sum(
"get_emeter_daily", {"year": year, "month": month, "kwh": kwh}
)

@requires_update
async def get_emeter_monthly(self, year: int = None, kwh: bool = True) -> Dict:
Expand All @@ -179,20 +184,45 @@ async def get_emeter_monthly(self, year: int = None, kwh: bool = True) -> Dict:
:param year: year for which to retrieve statistics (default: this year)
:param kwh: return usage in kWh (default: True)
"""
emeter_monthly: DefaultDict[int, float] = defaultdict(lambda: 0.0)
for plug in self.children:
plug_emeter_monthly = await plug.get_emeter_monthly(year=year, kwh=kwh)
for month, value in plug_emeter_monthly:
emeter_monthly[month] += value
return await self._async_get_emeter_sum(
"get_emeter_monthly", {"year": year, "kwh": kwh}
)

return emeter_monthly
async def _async_get_emeter_sum(self, func: str, kwargs: Dict[str, Any]) -> Dict:
"""Retreive emeter stats for a time period from children."""
self._verify_emeter()
return merge_sums(
[await getattr(plug, func)(**kwargs) for plug in self.children]
)

@requires_update
async def erase_emeter_stats(self):
"""Erase energy meter statistics for all plugs."""
for plug in self.children:
await plug.erase_emeter_stats()

@property # type: ignore
@requires_update
def emeter_this_month(self) -> Optional[float]:
"""Return this month's energy consumption in kWh."""
return sum([plug.emeter_this_month for plug in self.children])

@property # type: ignore
@requires_update
def emeter_today(self) -> Optional[float]:
"""Return this month's energy consumption in kWh."""
return sum([plug.emeter_today for plug in self.children])

@property # type: ignore
@requires_update
def emeter_realtime(self) -> EmeterStatus:
"""Return current energy readings."""
emeter = merge_sums([plug.emeter_realtime for plug in self.children])
# Voltage is averaged since each read will result
# in a slightly different voltage since they are not atomic
emeter["voltage_mv"] = int(emeter["voltage_mv"] / len(self.children))
return EmeterStatus(emeter)


class SmartStripPlug(SmartPlug):
"""Representation of a single socket in a power strip.
Expand All @@ -214,12 +244,22 @@ def __init__(self, host: str, parent: "SmartStrip", child_id: str) -> None:
self._device_type = DeviceType.StripSocket

async def update(self):
"""Override the update to no-op and inform the user."""
_LOGGER.warning(
"You called update() on a child device, which has no effect."
"Call update() on the parent device instead."
"""Query the device to update the data.
Needed for properties that are decorated with `requires_update`.
"""
self._last_update = await self.parent.protocol.query(
self.host, self._create_emeter_request()
)
return

def _create_request(
self, target: str, cmd: str, arg: Optional[Dict] = None, child_ids=None
):
request: Dict[str, Any] = {
"context": {"child_ids": [self.child_id]},
target: {cmd: arg},
}
return request

async def _query_helper(
self, target: str, cmd: str, arg: Optional[Dict] = None, child_ids=None
Expand All @@ -245,12 +285,6 @@ def led(self) -> bool:
"""
return False

@property # type: ignore
@requires_update
def has_emeter(self) -> bool:
"""Children have no emeter to my knowledge."""
return False

@property # type: ignore
@requires_update
def device_id(self) -> str:
Expand Down

0 comments on commit 94e5a90

Please sign in to comment.