Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Obtain information about robots from cloud #278

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 180 additions & 7 deletions poetry.lock

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ paho-mqtt = "~1.6.1"
mashumaro = {version = "^3.12"}
click = { version = "^8.1", optional = true }
tabulate = { version = "^0.9", optional = true }
requests = "^2.31.0"
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we use aiohttp or httpx for this to keep the lib async?

Copy link
Collaborator Author

@Orhideous Orhideous Mar 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are only three requests, and they will be made once a year at most.
After all, there is asyncio.to_thread. What do you think?


[tool.poetry.extras]
cli = ["click", "tabulate"]
Expand All @@ -41,6 +42,7 @@ codespell = "^2.2.6"
mypy = "^1.8"
types-paho-mqtt = "~1.6.0"
types-tabulate = "~0.9.0"
types-requests = "^2.31.0"

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand All @@ -66,6 +68,11 @@ ignore = [
"PLR0912",
]
select = ["ALL"]
[tool.ruff.lint.per-file-ignores]
# These are reverse-engineered classes
"roombapy/cloud/models/map/full.py" = ["RUF009", "D101"]
"roombapy/cloud/aws.py" = ["PLR0913"]


[tool.ruff.format]
docstring-code-format = true
Expand All @@ -76,6 +83,13 @@ strict_optional = true
strict = true
packages = ["roombapy", "tests"]

[[tool.mypy.overrides]]
module = "roombapy.cloud.models.map.full"
disable_error_code = "assignment"
[[tool.mypy.overrides]]
module = "roombapy.cloud.models.serialization"
disable_error_code = "no-any-return,type-arg,call-overload"

[tool.pydantic-mypy]
init_forbid_extra = true
init_typed = true
Expand Down
1 change: 1 addition & 0 deletions roombapy/cloud/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Remote robot control."""
273 changes: 273 additions & 0 deletions roombapy/cloud/aws.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
"""Generate AWSv4 signature headers.

Code is partially borrowed from tedder/requests-aws4auth (MIT licensed)
"""

from __future__ import annotations

import datetime
import hashlib
import hmac
import posixpath
import re
import shlex
from dataclasses import dataclass
from typing import Iterable
from urllib.parse import parse_qs, quote, unquote, urlparse

# Header names eligible for signing. The "x-amz-*" entry is not a literal
# header name: _get_canonical_headers treats it as a prefix wildcard for
# every header starting with "x-amz-" (except "x-amz-client-context").
DEFAULT_HEADERS = frozenset({"host", "content-type", "date", "x-amz-*"})
# Same set without "content-type"; this is the default value of
# `include_headers` in generate_aws_headers.
DEFAULT_IROBOT_HEADERS = frozenset({"host", "date", "x-amz-*"})

# Plain mapping of header name to header value used throughout this module.
Headers = dict[str, str]


def normalize_whitespace(text: str) -> str:
    """Collapse every run of whitespace in *text* into a single space.

    Quoted sections are left untouched: the text is tokenised with
    ``shlex`` in non-POSIX mode, which keeps the quotes and whatever they
    enclose together in one token.

    :param text: Text to normalise
    :returns: The normalised text, or *text* itself when it contains no
        whitespace at all
    """
    contains_whitespace = re.search(r"\s", text) is not None
    if not contains_whitespace:
        return text
    tokens = shlex.split(text, posix=False)
    return " ".join(tokens)


def _canonicalize_query_string(query_string: str) -> str:
"""Parse and format querystring as per AWS4 auth requirements.

Perform percent quoting as needed.
"""
safe_qs_unresvd = "-_.~"
space = " "
query_string = query_string.split(space)[0]
# prevent parse_qs from interpreting semicolon as an alternative
# delimiter to ampersand
query_string = query_string.replace(";", "%3B")
qs_items = {}
for name, vals in parse_qs(query_string, keep_blank_values=True).items():
key = quote(name, safe=safe_qs_unresvd)
values = [quote(val, safe=safe_qs_unresvd) for val in vals]
qs_items[key] = values

return "&".join(
[
f"{name}={value}"
for name, values in sorted(qs_items.items())
for value in sorted(values)
]
)


def _canonicalize_path(request_path: str, service: str) -> str:
"""Generate the canonical path as per AWS4 auth requirements.

Not documented anywhere, determined from aws4_testsuite examples,
problem reports and testing against the live services.
"""
safe_chars = "/~"
qs = ""
fixed_path = request_path
if "?" in fixed_path:
fixed_path, qs = fixed_path.split("?", 1)
fixed_path = posixpath.normpath(fixed_path)
fixed_path = re.sub("/+", "/", fixed_path)
if request_path.endswith("/") and not fixed_path.endswith("/"):
fixed_path += "/"
full_path = fixed_path
# S3 seems to require unquoting first. 'host' service is used in
# amz_testsuite tests
if service in ["s3", "host"]:
full_path = unquote(full_path)
full_path = quote(full_path, safe=safe_chars)
if qs:
qm = "?"
full_path = qm.join((full_path, qs))
return full_path


def _get_canonical_headers(
    url: str, headers: Headers, include_header_names: Iterable[str]
) -> tuple[str, str]:
    """Generate the Canonical Headers section of the Canonical Request.

    :param url: URL used to derive the Host header when *headers* has none
    :param headers: Existing request headers (not modified)
    :param include_header_names: Names of the headers to include in the
        canonical and signed headers, compared case-insensitively. Two
        wildcards are understood: ``*`` selects every header and
        ``x-amz-*`` selects every header starting with ``x-amz-`` except
        ``x-amz-client-context``, which appears to break mobile analytics
        auth if included.

    :returns: Canonical Headers and the Signed Headers strs as a tuple
        (canonical_headers, signed_headers).
    """
    include = {name.lower() for name in include_header_names}
    include_all = "*" in include
    include_amz = "x-amz-*" in include

    headers = headers.copy()
    # AWS requires the host header to be signed. Add it when the caller did
    # not supply one. The check is case-insensitive so that an existing
    # "Host" entry is not merged with a second, synthetic value below.
    # NOTE(review): any port in the URL is dropped here - confirm that
    # non-default ports never need to be signed.
    if not any(name.strip().lower() == "host" for name in headers):
        headers["host"] = urlparse(str(url)).netloc.split(":")[0]

    # Aggregate values of headers whose names collide once lower-cased:
    # AMZ requires them to be concatenated into a single header with a
    # lowercase name. This can only happen with a plain (case-sensitive)
    # dict.
    canonical_headers: dict[str, list[str]] = {}
    for header, value in headers.items():
        hdr = header.strip().lower()
        selected = (
            hdr in include
            or include_all
            or (
                include_amz
                and hdr.startswith("x-amz-")
                and hdr != "x-amz-client-context"
            )
        )
        if not selected:
            # Values of unsigned headers are never parsed, so an odd value
            # there cannot make whitespace normalisation fail.
            continue
        val = normalize_whitespace(value).strip()
        canonical_headers.setdefault(hdr, []).append(val)

    # Flatten to "name:value\n" lines, sorted by header name; the signed
    # header list uses the same names in the same order.
    signed_names = sorted(canonical_headers)
    cano_headers = "".join(
        f"{hdr}:{','.join(sorted(canonical_headers[hdr]))}\n"
        for hdr in signed_names
    )
    signed_headers = ";".join(signed_names)
    return cano_headers, signed_headers


def _get_signature(amz_date: str, canonical_request: str, scope: str) -> bytes:
"""Generate the AWS4 auth signature to sign for the request.

:param amz_date: Date this request is valid for
:param canonical_request: The Canonical Request
:param scope: Request scope:
:returns: Signature
"""
hsh = hashlib.sha256(canonical_request.encode())
sig_items = ["AWS4-HMAC-SHA256", amz_date, scope, hsh.hexdigest()]
return "\n".join(sig_items).encode("utf-8")


def _get_canonical_request(
    canonical_headers: str,
    signed_headers: str,
    service: str,
    raw_url: str,
    method: str,
    payload_hash: str,
) -> str:
    """Assemble the AWS authentication Canonical Request string.

    The request is the newline-joined sequence of: upper-cased method,
    canonical path, canonical query string, canonical headers, signed
    headers and payload hash.
    """
    canonical_path = _canonicalize_path(urlparse(raw_url).path, service)
    # AWS handles "extreme" query strings differently to urlparse
    # (see post-vanilla-query-nonunreserved test in aws_testsuite), so the
    # query string is cut out of the raw URL by hand.
    _, _, raw_query = raw_url.partition("?")
    canonical_query = _canonicalize_query_string(raw_query)
    return "\n".join(
        (
            method.upper(),
            canonical_path,
            canonical_query,
            canonical_headers,
            signed_headers,
            payload_hash,
        )
    )


@dataclass
class SigningKey:
    """AWS signing key. Used to sign AWS authentication strings.

    The key is derived for one specific day, region and service; those
    three values are recorded in ``scope``.
    """

    # Credential scope: "<YYYYMMDD>/<region>/<service>/aws4_request"
    scope: str
    # HMAC-SHA256 key derived from the secret key and the scope components
    key: bytes

    @classmethod
    def from_credentials(
        cls,
        *,
        secret_key: str,
        region: str,
        service: str,
        signing_date: datetime.date | None = None,
    ) -> SigningKey:
        """Construct signing key from credentials.

        :param secret_key: AWS secret access key
        :param region: AWS region the key is scoped for
        :param service: AWS service the key is scoped for
        :param signing_date: Day the key is valid for; defaults to the
            current UTC date. Mainly useful for reproducible keys.
        :returns: Signing key together with its credential scope
        """

        def _sign(k: bytes, msg: str) -> bytes:
            """Generate an SHA256 HMAC, encoding msg to UTF-8."""
            return hmac.new(k, msg.encode("utf-8"), hashlib.sha256).digest()

        if signing_date is None:
            signing_date = datetime.datetime.now(tz=datetime.UTC)
        aws_dt = signing_date.strftime("%Y%m%d")

        # Derivation chain: secret -> date -> region -> service -> request
        init_key = ("AWS4" + secret_key).encode("utf-8")
        date_key = _sign(init_key, aws_dt)
        region_key = _sign(date_key, region)
        service_key = _sign(region_key, service)
        key = _sign(service_key, "aws4_request")

        scope = f"{aws_dt}/{region}/{service}/aws4_request"

        # Use cls so that subclasses get instances of their own type.
        return cls(scope=scope, key=key)


def generate_aws_headers(
    *,
    url: str,
    method: str,
    request_headers: Headers,
    service: str,
    access_id: str,
    signing_key: SigningKey,
    session_token: str | None,
    payload: str | None = None,
    include_headers: Iterable[str] = DEFAULT_IROBOT_HEADERS,
) -> Headers:
    """Client-agnostic helper to generate AWSv4 signature.

    :param url: Full request URL
    :param method: Request method
    :param service: AWS service the key is scoped for
    :param access_id: AWS access ID
    :param session_token: STS temporary credentials
    :param signing_key: A SigningKey instance.
    :param request_headers: Headers for this request (not modified)
    :param payload: Payload to be signed; ``None`` is signed as an empty
        body
    :param include_headers: Headers to be signed
    :return: Headers for HTTP client: the request headers plus
        ``x-amz-date``, ``x-amz-content-sha256``, optionally
        ``x-amz-security-token``, and ``Authorization``
    """
    now = datetime.datetime.now(tz=datetime.UTC)
    amz_date = now.strftime("%Y%m%dT%H%M%SZ")

    body = payload.encode() if payload is not None else b""
    payload_hash = hashlib.sha256(body).hexdigest()

    aws_headers: Headers = {
        "x-amz-date": amz_date,
        "x-amz-content-sha256": payload_hash,
    }
    if session_token:
        aws_headers["x-amz-security-token"] = session_token

    # The x-amz-* headers generated above are sent with the request, so
    # they must take part in the signature as well. Canonicalise the merged
    # header set rather than only the caller-supplied headers.
    outgoing_headers = {**request_headers, **aws_headers}

    # generate signature
    canonical_headers, signed_headers = _get_canonical_headers(
        url, outgoing_headers, include_headers
    )
    cano_req = _get_canonical_request(
        canonical_headers, signed_headers, service, url, method, payload_hash
    )
    string_to_sign = _get_signature(amz_date, cano_req, signing_key.scope)
    sig = hmac.new(signing_key.key, string_to_sign, hashlib.sha256).hexdigest()

    auth_str = (
        "AWS4-HMAC-SHA256 "
        f"Credential={access_id}/{signing_key.scope}, "
        f"SignedHeaders={signed_headers}, "
        f"Signature={sig}"
    )
    return {**outgoing_headers, "Authorization": auth_str}
59 changes: 59 additions & 0 deletions roombapy/cloud/login.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Authentication flow for iRobot cloud."""

from __future__ import annotations

import logging

import requests

from roombapy.cloud.models import login as models

logger = logging.getLogger(__name__)

# Passed as the `timeout` argument of every HTTP request made by login().
TIMEOUT = 10
# Sent as "app_id" in the iRobot login payload.
APPLICATION_ID = "ANDROID-C7FB240E-DF34-42D7-AE4E-A8C17079A294"
# First request of the login flow; its response is parsed as
# models.DeploymentsResponse.
DISCOVERY_URL = (
"https://disc-prod.iot.irobotapi.com/v1/discover/endpoints?country_code=US"
)
# "%s" is filled with the Gigya datacenter domain from the discovery response.
GIGYA_LOGIN_URL_TEMPLATE = "https://accounts.%s/accounts.login"
# Appended to the selected deployment's `http_base`.
IROBOT_LOGIN_ENDPOINT = "/v2/login"

# Return type of login(): the selected deployment and the iRobot login response.
LoginResponse = tuple[models.Deployment, models.IRobotLoginResponse]


def login(username: str, password: str) -> LoginResponse:
    """Obtain access credentials and robots' details from cloud.

    Three HTTP requests are made in sequence:

    1. GET the discovery endpoint to learn the current deployment and the
       Gigya settings.
    2. POST the user's credentials to the Gigya login endpoint.
    3. POST the Gigya signature, timestamp and user id to the iRobot login
       endpoint of the selected deployment.

    :param username: Account login ID
    :param password: Account password
    :returns: The selected deployment and the parsed iRobot login response
    """
    # NOTE(review): HTTP status codes are never checked here; whatever body
    # comes back is handed straight to the model parsers.

    # 1. Endpoint discovery
    discovery = requests.get(DISCOVERY_URL, timeout=TIMEOUT)
    endpoints = models.DeploymentsResponse.from_json(discovery.text)
    deployment = endpoints.deployments[endpoints.current_deployment]
    gigya = endpoints.gigya

    # 2. Gigya login (form-encoded body)
    gigya_reply = requests.post(
        GIGYA_LOGIN_URL_TEMPLATE % gigya.datacenter_domain,
        data={
            "apiKey": gigya.api_key,
            "loginID": username,
            "password": password,
            "format": "json",
            "targetEnv": "mobile",
        },
        timeout=TIMEOUT,
    )
    gigya_session = models.GigyaLoginResponse.from_json(gigya_reply.text)

    # 3. iRobot login (JSON body) using the Gigya session details
    gigya_proof = {
        "signature": gigya_session.signature,
        "timestamp": gigya_session.signature_timestamp,
        "uid": gigya_session.user_id,
    }
    irobot_reply = requests.post(
        f"{deployment.http_base}{IROBOT_LOGIN_ENDPOINT}",
        json={
            "app_id": APPLICATION_ID,
            "assume_robot_ownership": "0",
            "gigya": gigya_proof,
        },
        timeout=TIMEOUT,
    )
    irobot_session = models.IRobotLoginResponse.from_json(irobot_reply.text)

    return deployment, irobot_session
1 change: 1 addition & 0 deletions roombapy/cloud/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""DTOs for cloud APIs."""
Loading