Skip to content

Commit

Permalink
Fix D* (#247)
Browse files Browse the repository at this point in the history
Signed-off-by: Joostlek <joostlek@outlook.com>
  • Loading branch information
joostlek committed Feb 21, 2024
1 parent 32183d4 commit 3bc659f
Show file tree
Hide file tree
Showing 13 changed files with 74 additions and 6 deletions.
4 changes: 0 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,6 @@ ignore = [
"TRY400",
"COM812", # Conflicts with other rules
"ISC001", # Conflicts with other rules
"D100", # documentation info, should be removed after fixes
"D101", # documentation info, should be removed after fixes
"D102", # documentation info, should be removed after fixes
"D103", # documentation info, should be removed after fixes
"PLR2004", # Just annoying, not really useful
"PLR0912",
"PLW2901",
Expand Down
2 changes: 2 additions & 0 deletions roombapy/const.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Constants for roombapy."""

MQTT_ERROR_MESSAGES = {
0: None,
1: "Bad protocol",
Expand Down
9 changes: 8 additions & 1 deletion roombapy/discovery.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Module for discovering Roomba devices on the local network."""
from __future__ import annotations

import logging
Expand All @@ -9,6 +10,8 @@


class RoombaDiscovery:
"""Class for discovering Roomba devices on the local network."""

udp_bind_address = ""
udp_address = "<broadcast>"
udp_port = 5678
Expand All @@ -18,16 +21,18 @@ class RoombaDiscovery:
log = None

def __init__(self):
"""Init discovery."""
"""Initialize the discovery class."""
self.server_socket = _get_socket()
self.log = logging.getLogger(__name__)

def find(self, ip=None):
"""Find Roomba devices on the local network."""
if ip is not None:
return self.get(ip)
return self.get_all()

def get_all(self):
"""Get all Roomba devices on the local network."""
self._start_server()
self._broadcast_message(self.amount_of_broadcasted_messages)
robots = set()
Expand All @@ -40,11 +45,13 @@ def get_all(self):
return robots

def get(self, ip):
"""Get Roomba device with the specified IP address."""
self._start_server()
self._send_message(ip)
return self._get_response(ip)

def _get_response(self, ip=None):
"""Get a response from the Roomba device."""
try:
while True:
raw_response, addr = self.server_socket.recvfrom(1024)
Expand Down
4 changes: 4 additions & 0 deletions roombapy/entry_points.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Entry points for the roombapy package."""
import logging
import sys

Expand All @@ -9,6 +10,7 @@


def discovery():
"""Discover Roomba devices on the local network."""
roomba_ip = _get_ip_from_arg()

roomba_discovery = RoombaDiscovery()
Expand All @@ -22,6 +24,7 @@ def discovery():


def password():
"""Get password for a Roomba device."""
roomba_ip = _get_ip_from_arg()
_validate_ip(roomba_ip)
_wait_for_input()
Expand All @@ -37,6 +40,7 @@ def password():


def connect():
"""Connect to a Roomba device."""
roomba_ip = _get_ip_from_arg()
_validate_ip(roomba_ip)

Expand Down
14 changes: 13 additions & 1 deletion roombapy/remote_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Roomba remote client."""
import logging
import ssl
from functools import cache
Expand Down Expand Up @@ -25,6 +26,8 @@ def generate_tls_context() -> ssl.SSLContext:


class RoombaRemoteClient:
"""Roomba remote client."""

address = None
port = None
blid = None
Expand All @@ -35,7 +38,7 @@ class RoombaRemoteClient:
on_disconnect = None

def __init__(self, address, blid, password, port=8883):
"""Create mqtt client."""
"""Initialize the Roomba remote client."""
self.address = address
self.blid = blid
self.password = password
Expand All @@ -44,21 +47,27 @@ def __init__(self, address, blid, password, port=8883):
self.mqtt_client = self._get_mqtt_client()

def set_on_message(self, on_message):
"""Set the on message callback."""
self.mqtt_client.on_message = on_message

def set_on_connect(self, on_connect):
"""Set the on connect callback."""
self.on_connect = on_connect

def set_on_publish(self, on_publish):
"""Set the on publish callback."""
self.mqtt_client.on_publish = on_publish

def set_on_subscribe(self, on_subscribe):
"""Set the on subscribe callback."""
self.mqtt_client.on_subscribe = on_subscribe

def set_on_disconnect(self, on_disconnect):
"""Set the on disconnect callback."""
self.on_disconnect = on_disconnect

def connect(self):
"""Connect to the Roomba."""
attempt = 1
while attempt <= MAX_CONNECTION_RETRIES:
self.log.info(
Expand All @@ -80,12 +89,15 @@ def connect(self):
return False

def disconnect(self):
"""Disconnect from the Roomba."""
self.mqtt_client.disconnect()

def subscribe(self, topic):
"""Subscribe to a topic."""
self.mqtt_client.subscribe(topic)

def publish(self, topic, payload):
"""Publish a message to a topic."""
self.mqtt_client.publish(topic, payload)

def _open_mqtt_connection(self):
Expand Down
11 changes: 11 additions & 0 deletions roombapy/roomba.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,21 @@ def __init__(self, remote_client, continuous=True, delay=1):
self.client_error = None

def register_on_message_callback(self, callback):
"""Register a function to be called when a message is received."""
self.on_message_callbacks.append(callback)

def register_on_disconnect_callback(self, callback):
"""Register a function to be called when a disconnect occurs."""
self.on_disconnect_callbacks.append(callback)

def _init_remote_client_callbacks(self):
"""Initialize the remote client callbacks."""
self.remote_client.set_on_message(self.on_message)
self.remote_client.set_on_connect(self.on_connect)
self.remote_client.set_on_disconnect(self.on_disconnect)

def connect(self):
"""Connect to the Roomba."""
if self.roomba_connected or self.periodic_connection_running:
return

Expand All @@ -116,12 +120,14 @@ def _connect(self):
return is_connected

def disconnect(self):
"""Disconnect from the Roomba."""
if self.continuous:
self.remote_client.disconnect()
else:
self.stop_connection = True

def periodic_connection(self):
"""Periodic connection to the Roomba."""
# only one connection thread at a time!
if self.periodic_connection_running:
return
Expand All @@ -139,6 +145,7 @@ def periodic_connection(self):
self.periodic_connection_running = False

def on_connect(self, error):
"""On connect callback."""
self.log.info("Connecting to Roomba %s", self.remote_client.address)
self.client_error = error
if error is not None:
Expand All @@ -153,6 +160,7 @@ def on_connect(self, error):
self.remote_client.subscribe(self.topic)

def on_disconnect(self, error):
"""On disconnect callback."""
self.roomba_connected = False
self.client_error = error
if error is not None:
Expand All @@ -171,6 +179,7 @@ def on_disconnect(self, error):
self.log.info("Disconnected from Roomba %s", self.remote_client.address)

def on_message(self, _mosq, _obj, msg):
"""On message callback."""
if self.exclude != "":
if self.exclude in msg.topic:
return
Expand Down Expand Up @@ -203,6 +212,7 @@ def on_message(self, _mosq, _obj, msg):
callback(json_data)

def send_command(self, command, params=None):
"""Send a command to the Roomba."""
if params is None:
params = {}

Expand All @@ -223,6 +233,7 @@ def send_command(self, command, params=None):
self.remote_client.publish("cmd", str_command)

def set_preference(self, preference, setting):
"""Set a preference on the Roomba."""
self.log.debug("Set preference: %s, %s", preference, setting)
val = setting
# Parse boolean string
Expand Down
2 changes: 2 additions & 0 deletions roombapy/roomba_factory.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Factory class to create Roomba class to control your robot."""
from roombapy import Roomba
from roombapy.remote_client import RoombaRemoteClient

Expand All @@ -9,6 +10,7 @@ class RoombaFactory:
def create_roomba(
address=None, blid=None, password=None, continuous=True, delay=1
):
"""Create a Roomba instance."""
remote_client = RoombaFactory._create_remote_client(
address, blid, password
)
Expand Down
5 changes: 5 additions & 0 deletions roombapy/roomba_info.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Module for RoombaInfo class."""
from __future__ import annotations

from functools import cached_property
Expand All @@ -16,6 +17,8 @@


class RoombaInfo(BaseModel):
"""Class for storing information about a Roomba device."""

hostname: str
firmware: str = Field(alias="sw")
ip: str
Expand All @@ -28,6 +31,7 @@ class RoombaInfo(BaseModel):
@field_validator("hostname")
@classmethod
def hostname_validator(cls, value: str) -> str:
"""Validate the hostname."""
if "-" not in value:
raise ValueError(f"hostname does not contain a dash: {value}")
model_name, blid = value.split("-")
Expand All @@ -39,6 +43,7 @@ def hostname_validator(cls, value: str) -> str:

@cached_property
def blid(self) -> str:
"""Return the BLID."""
return self.hostname.split("-")[1]

class Config:
Expand Down
6 changes: 6 additions & 0 deletions tests/abstract_test_roomba.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""This module provides an abstract test class for Roomba."""
from roombapy import Roomba, RoombaFactory

ROOMBA_CONFIG = {
Expand All @@ -11,6 +12,8 @@


class AbstractTestRoomba:
"""Abstract test class for Roomba."""

@staticmethod
def get_default_roomba(
address=ROOMBA_CONFIG["host"],
Expand All @@ -19,6 +22,7 @@ def get_default_roomba(
continuous=ROOMBA_CONFIG["continuous"],
delay=ROOMBA_CONFIG["delay"],
) -> Roomba:
"""Get a default Roomba."""
return RoombaFactory.create_roomba(
address=address,
blid=blid,
Expand All @@ -29,6 +33,8 @@ def get_default_roomba(

@staticmethod
def get_message(topic, payload):
"""Get a message."""

class Message:
pass

Expand Down
8 changes: 8 additions & 0 deletions tests/test_decode.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Test the decoding of the Roomba discovery messages."""
from roombapy.discovery import RoombaDiscovery, _decode_data

TEST_ROOMBA_INFO = """
Expand All @@ -14,34 +15,41 @@


def test_skip_garbage() -> None:
"""Test skipping garbage data."""
assert _decode_data(b"\x0f\x00\xff\xf0") is None


def test_skip_own_messages() -> None:
"""Test skipping own messages."""
assert _decode_data(RoombaDiscovery.roomba_message.encode()) is None


def test_skip_broken_json() -> None:
"""Test skipping broken JSON."""
assert _decode_data(b'{"test": 1') is None


def test_skip_unknown_json() -> None:
"""Test skipping unknown JSON."""
assert _decode_data(b'{"test": 1}') is None


def test_skip_unknown_hostname() -> None:
"""Test skipping unknown hostname."""
assert _decode_data(b'{"hostname": "test"}') is None
assert _decode_data(TEST_ROOMBA_INFO.encode()) is None


def test_skip_hostnames_without_blid() -> None:
"""Test skipping hostnames without BLID."""
decoded = _decode_data(
TEST_ROOMBA_INFO.replace("hostname_placeholder", "iRobot-").encode()
)
assert decoded is None


def test_allow_approved_hostnames() -> None:
"""Test allowing approved hostnames."""
blid = "test"
for hostname in [f"Roomba-{blid}", f"iRobot-{blid}"]:
decoded = _decode_data(
Expand Down
4 changes: 4 additions & 0 deletions tests/test_discovery.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
"""Test for the discovery module."""
from roombapy.discovery import RoombaDiscovery


class TestDiscovery:
"""Test the discovery module."""

def test_discovery_with_wrong_msg(self):
"""Test discovery with wrong message."""
# given
discovery = RoombaDiscovery()

Expand Down
4 changes: 4 additions & 0 deletions tests/test_roomba.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
"""Test the Roomba class."""
from tests import abstract_test_roomba


class TestRoomba(abstract_test_roomba.AbstractTestRoomba):
"""Test the Roomba class."""

def test_roomba_with_data(self):
"""Test Roomba with data."""
# given
roomba = self.get_default_roomba()

Expand Down
Loading

0 comments on commit 3bc659f

Please sign in to comment.