Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions src/websockets/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import argparse
import asyncio
import os
import ssl
import sys
from typing import Generator
from typing import Any, Generator

from .asyncio.client import ClientConnection, connect
from .asyncio.messages import SimpleQueue
Expand Down Expand Up @@ -101,9 +102,9 @@ async def send_outgoing_messages(
break


async def interactive_client(uri: str) -> None:
async def interactive_client(uri: str, **kwargs: Any) -> None:
try:
websocket = await connect(uri)
websocket = await connect(uri, **kwargs)
except Exception as exc:
print(f"Failed to connect to {uri}: {exc}.")
sys.exit(1)
Expand Down Expand Up @@ -151,6 +152,11 @@ def main(argv: list[str] | None = None) -> None:
group = parser.add_mutually_exclusive_group()
group.add_argument("--version", action="store_true")
group.add_argument("uri", metavar="<uri>", nargs="?")
parser.add_argument(
"--insecure",
action="store_true",
help="Disable SSL certificate verification for wss:// connections.",
)
args = parser.parse_args(argv)

if args.version:
Expand All @@ -171,8 +177,15 @@ def main(argv: list[str] | None = None) -> None:
except ImportError: # readline isn't available on all platforms
pass

kwargs: dict[str, Any] = {}
if args.insecure:
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
kwargs["ssl"] = ssl_context

# Remove the try/except block when dropping Python < 3.11.
try:
asyncio.run(interactive_client(args.uri))
asyncio.run(interactive_client(args.uri, **kwargs))
except KeyboardInterrupt: # pragma: no cover
pass
26 changes: 25 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# Run a test server in a thread. This is easier than running an asyncio server
# because we would have to run main() in a thread, due to using asyncio.run().
from .sync.server import get_uri, run_server
from .utils import SERVER_CONTEXT


vt100_commands = re.compile(r"\x1b\[[A-Z]|\x1b[78]|\r")
Expand Down Expand Up @@ -109,4 +110,27 @@ def test_connection_failure(self):

def test_no_args(self):
output = self.run_main([], expected_exit_code=2)
self.assertEqual(output, "usage: websockets [--version | <uri>]\n")
self.assertEqual(output, "usage: websockets [--version] [--insecure] [<uri>]\n")

def test_insecure_connection(self):
def text_handler(websocket):
websocket.send("secure hello")

with run_server(text_handler, ssl=SERVER_CONTEXT) as server:
server_uri = get_uri(server)
output = self.run_main([server_uri, "--insecure"], "")
self.assertEqual(
remove_commands_and_prompts(output),
add_connection_messages("\n< secure hello\n", server_uri),
)

def test_insecure_connection_fails_without_flag(self):
def text_handler(websocket):
websocket.send("secure hello")

with run_server(text_handler, ssl=SERVER_CONTEXT) as server:
server_uri = get_uri(server)
output = self.run_main([server_uri], expected_exit_code=1)
self.assertTrue(
output.startswith(f"Failed to connect to {server_uri}: ")
)