Skip to content

Commit

Permalink
Merge pull request nats-io#183 from brianshannan/fstrings
Browse files Browse the repository at this point in the history
use fstrings where possible
  • Loading branch information
wallyqs committed Oct 8, 2020
2 parents 4a4d611 + d00bfe4 commit 90751e8
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 78 deletions.
12 changes: 8 additions & 4 deletions examples/nats-req/__main__.py
Expand Up @@ -18,6 +18,7 @@
import signal
import nats


def show_usage():
usage = """
nats-req [-s SERVER] <subject> <data>
Expand Down Expand Up @@ -54,10 +55,7 @@ async def error_cb(e):
async def reconnected_cb():
print(f"Connected to NATS at {nc.connected_url.netloc}...")

options = {
"error_cb": error_cb,
"reconnected_cb": reconnected_cb
}
options = {"error_cb": error_cb, "reconnected_cb": reconnected_cb}

if len(args.creds) > 0:
options["user_credentials"] = args.creds
Expand All @@ -72,6 +70,11 @@ async def reconnected_cb():
print(e)
show_usage_and_die()

async def req_callback(msg):
await msg.respond(b'a response')

await nc.subscribe(args.subject, cb=req_callback)

try:
future = nc.request(args.subject, data.encode())
print(f"Published [{args.subject}] : '{data}'")
Expand All @@ -84,6 +87,7 @@ async def reconnected_cb():
print("nats: request timed out!")
await nc.drain()


if __name__ == '__main__':
loop = asyncio.get_event_loop()
try:
Expand Down
51 changes: 12 additions & 39 deletions nats/aio/client.py
Expand Up @@ -28,6 +28,7 @@
from nats.aio.utils import new_inbox
from nats.aio.nuid import NUID
from nats.protocol.parser import *
from nats.protocol import command as prot_command

__version__ = '0.11.2'
__lang__ = 'python3'
Expand All @@ -38,9 +39,6 @@
CONNECT_OP = b'CONNECT'
PING_OP = b'PING'
PONG_OP = b'PONG'
PUB_OP = b'PUB'
SUB_OP = b'SUB'
UNSUB_OP = b'UNSUB'
OK_OP = b'+OK'
ERR_OP = b'-ERR'
_CRLF_ = b'\r\n'
Expand Down Expand Up @@ -274,12 +272,7 @@ def __init__(
self._client = client

def __repr__(self):
return "<{}: subject='{}' reply='{}' data='{}...'>".format(
self.__class__.__name__,
self.subject,
self.reply,
self.data[:10].decode(),
)
return f"<{self.__class__.__name__}: subject='{self.subject}' reply='{self.reply}' data='{self.data[:10].decode()}...'>"

async def respond(self, data: bytes):
if not self.reply:
Expand Down Expand Up @@ -766,12 +759,7 @@ async def _send_publish(self, subject, reply, payload, payload_size):
# Avoid sending messages with empty replies.
raise ErrBadSubject

payload_size_bytes = ("%d" % payload_size).encode()
pub_cmd = b''.join([
PUB_OP, _SPC_,
subject.encode(), _SPC_,
reply.encode(), _SPC_, payload_size_bytes, _CRLF_, payload, _CRLF_
])
pub_cmd = prot_command.pub_cmd(subject, reply, payload)
self.stats['out_msgs'] += 1
self.stats['out_bytes'] += payload_size
await self._send_command(pub_cmd)
Expand Down Expand Up @@ -829,11 +817,7 @@ def _remove_sub(self, sid, max_msgs=0):
self._subs.pop(sid, None)

async def _send_subscribe(self, sub):
sub_cmd = b''.join([
SUB_OP, _SPC_,
sub._subject.encode(), _SPC_,
sub._queue.encode(), _SPC_, ("%d" % sub._id).encode(), _CRLF_
])
sub_cmd = prot_command.sub_cmd(sub._subject, sub._queue, sub._id)
await self._send_command(sub_cmd)
await self._flush_pending()

Expand Down Expand Up @@ -930,11 +914,7 @@ async def _request_old_style(self, subject, payload, timeout=0.5):
raise ErrTimeout

async def _send_unsubscribe(self, sid, limit=1):
b_limit = b''
if limit > 0:
b_limit = ("%d" % limit).encode()
b_sid = ("%d" % sid).encode()
unsub_cmd = b''.join([UNSUB_OP, _SPC_, b_sid, _SPC_, b_limit, _CRLF_])
unsub_cmd = prot_command.unsub_cmd(sid, limit)
await self._send_command(unsub_cmd)
await self._flush_pending()

Expand Down Expand Up @@ -1064,16 +1044,16 @@ def _setup_server_pool(self, connect_url):
elif ":" in connect_url:
# Expand the scheme for the user
# e.g. 127.0.0.1:4222
uri = urlparse("nats://%s" % connect_url)
uri = urlparse(f"nats://{connect_url}")
else:
# Just use the endpoint with the default NATS port.
# e.g. demo.nats.io
uri = urlparse("nats://%s:4222" % connect_url)
uri = urlparse(f"nats://{connect_url}:4222")

# In case only endpoint with scheme was set.
# e.g. nats://demo.nats.io or localhost:
if uri.port is None:
uri = urlparse("nats://%s:4222" % uri.hostname)
uri = urlparse(f"nats://{uri.hostname}:4222")
except ValueError:
raise NatsError("nats: invalid connect url option")

Expand Down Expand Up @@ -1259,20 +1239,13 @@ async def _attempt_reconnect(self):
# auto unsubscribe the number of messages we have left
max_msgs = sub._max_msgs - sub._received

sub_cmd = b''.join([
SUB_OP, _SPC_,
sub._subject.encode(), _SPC_,
sub._queue.encode(), _SPC_, ("%d" % sid).encode(),
_CRLF_
])
sub_cmd = prot_command.sub_cmd(
sub._subject, sub._queue, sid
)
self._io_writer.write(sub_cmd)

if max_msgs > 0:
b_max_msgs = ("%d" % max_msgs).encode()
b_sid = ("%d" % sid).encode()
unsub_cmd = b''.join([
UNSUB_OP, _SPC_, b_sid, _SPC_, b_max_msgs, _CRLF_
])
unsub_cmd = prot_command.unsub_cmd(sid, max_msgs)
self._io_writer.write(unsub_cmd)

for sid in subs_to_remove:
Expand Down
18 changes: 18 additions & 0 deletions nats/protocol/command.py
@@ -0,0 +1,18 @@
PUB_OP = 'PUB'
SUB_OP = 'SUB'
UNSUB_OP = 'UNSUB'
_CRLF_ = '\r\n'


def pub_cmd(subject, reply, payload):
return f'{PUB_OP} {subject} {reply} {len(payload)}{_CRLF_}'.encode(
) + payload + _CRLF_.encode()


def sub_cmd(subject, queue, sid):
return f'{SUB_OP} {subject} {queue} {sid}{_CRLF_}'.encode()


def unsub_cmd(sid, limit):
limit_s = '' if limit == 0 else f'{limit}'
return f'{UNSUB_OP} {sid} {limit_s}{_CRLF_}'.encode()
40 changes: 13 additions & 27 deletions tests/test_client.py
Expand Up @@ -27,7 +27,7 @@ def test_default_connect_command(self):
nc.options["name"] = None
nc.options["no_echo"] = False
got = nc._connect_command()
expected = 'CONNECT {"echo": true, "lang": "python3", "pedantic": false, "protocol": 1, "verbose": false, "version": "%s"}\r\n' % __version__
expected = f'CONNECT {{"echo": true, "lang": "python3", "pedantic": false, "protocol": 1, "verbose": false, "version": "{__version__}"}}\r\n'
self.assertEqual(expected.encode(), got)

def test_default_connect_command_with_name(self):
Expand All @@ -38,7 +38,7 @@ def test_default_connect_command_with_name(self):
nc.options["name"] = "secret"
nc.options["no_echo"] = False
got = nc._connect_command()
expected = 'CONNECT {"echo": true, "lang": "python3", "name": "secret", "pedantic": false, "protocol": 1, "verbose": false, "version": "%s"}\r\n' % __version__
expected = f'CONNECT {{"echo": true, "lang": "python3", "name": "secret", "pedantic": false, "protocol": 1, "verbose": false, "version": "{__version__}"}}\r\n'
self.assertEqual(expected.encode(), got)

def tests_generate_new_inbox(self):
Expand Down Expand Up @@ -197,7 +197,7 @@ async def test_publish(self):
nc = NATS()
await nc.connect()
for i in range(0, 100):
await nc.publish("hello.%d" % i, b'A')
await nc.publish(f"hello.{i}", b'A')

with self.assertRaises(ErrBadSubject):
await nc.publish("", b'')
Expand All @@ -208,9 +208,7 @@ async def test_publish(self):
self.assertEqual(100, nc.stats['out_msgs'])
self.assertEqual(100, nc.stats['out_bytes'])

endpoint = '127.0.0.1:{port}'.format(
port=self.server_pool[0].http_port
)
endpoint = f'127.0.0.1:{self.server_pool[0].http_port}'
httpclient = http.client.HTTPConnection(endpoint, timeout=5)
httpclient.request('GET', '/varz')
response = httpclient.getresponse()
Expand All @@ -223,7 +221,7 @@ async def test_flush(self):
nc = NATS()
await nc.connect()
for i in range(0, 10):
await nc.publish("flush.%d" % i, b'AA')
await nc.publish(f"flush.{i}", b'AA')
await nc.flush()
self.assertEqual(10, nc.stats['out_msgs'])
self.assertEqual(20, nc.stats['out_bytes'])
Expand Down Expand Up @@ -266,9 +264,7 @@ async def subscription_handler(msg):
self.assertEqual(2, nc.stats['out_msgs'])
self.assertEqual(22, nc.stats['out_bytes'])

endpoint = '127.0.0.1:{port}'.format(
port=self.server_pool[0].http_port
)
endpoint = f'127.0.0.1:{self.server_pool[0].http_port}'
httpclient = http.client.HTTPConnection(endpoint, timeout=5)
httpclient.request('GET', '/connz')
response = httpclient.getresponse()
Expand Down Expand Up @@ -576,9 +572,7 @@ async def subscription_handler(msg):
nc._subs[sub._id].received

await asyncio.sleep(1)
endpoint = '127.0.0.1:{port}'.format(
port=self.server_pool[0].http_port
)
endpoint = f'127.0.0.1:{self.server_pool[0].http_port}'
httpclient = http.client.HTTPConnection(endpoint, timeout=5)
httpclient.request('GET', '/connz')
response = httpclient.getresponse()
Expand Down Expand Up @@ -1966,19 +1960,15 @@ async def replies(msg):
await nc2.subscribe("my-replies.*", cb=replies)
for i in range(0, 201):
await nc2.publish(
"foo",
b'help',
reply="my-replies.%s" % nc._nuid.next().decode()
"foo", b'help', reply=f"my-replies.{nc._nuid.next().decode()}"
)
await nc2.publish(
"bar",
b'help',
reply="my-replies.%s" % nc._nuid.next().decode()
"bar", b'help', reply=f"my-replies.{nc._nuid.next().decode()}"
)
await nc2.publish(
"quux",
b'help',
reply="my-replies.%s" % nc._nuid.next().decode()
reply=f"my-replies.{nc._nuid.next().decode()}"
)

# Relinquish control so that messages are processed.
Expand Down Expand Up @@ -2064,19 +2054,15 @@ async def replies(msg):
await nc2.subscribe("my-replies.*", cb=replies)
for i in range(0, 201):
await nc2.publish(
"foo",
b'help',
reply="my-replies.%s" % nc._nuid.next().decode()
"foo", b'help', reply=f"my-replies.{nc._nuid.next().decode()}"
)
await nc2.publish(
"bar",
b'help',
reply="my-replies.%s" % nc._nuid.next().decode()
"bar", b'help', reply=f"my-replies.{nc._nuid.next().decode()}"
)
await nc2.publish(
"quux",
b'help',
reply="my-replies.%s" % nc._nuid.next().decode()
reply=f"my-replies.{nc._nuid.next().decode()}"
)

# Relinquish control so that messages are processed.
Expand Down
16 changes: 8 additions & 8 deletions tests/test_parser.py
Expand Up @@ -158,7 +158,7 @@ async def test_parse_info(self):
nc = MockNatsClient()
ps = Parser(nc)
server_id = 'A' * 2048
data = '''INFO {"server_id": "%s", "max_payload": 100, "auth_required": false, "connect_urls":["127.0.0.0.1:4223"]}\r\n''' % server_id
data = f'INFO {{"server_id": "{server_id}", "max_payload": 100, "auth_required": false, "connect_urls":["127.0.0.0.1:4223"]}}\r\n'
await ps.parse(data.encode())
self.assertEqual(len(ps.buf), 0)
self.assertEqual(ps.state, AWAITING_CONTROL_LINE)
Expand All @@ -185,13 +185,13 @@ async def payload_test(sid, subject, reply, payload):

ps = Parser(nc)
reply = 'A' * 2043
data = '''PING\r\nMSG hello 1 %s''' % reply
data = f'PING\r\nMSG hello 1 {reply}'
await ps.parse(data.encode())
await ps.parse(b'''AAAAA 0\r\n\r\nMSG hello 1 world 0''')
await ps.parse(b'AAAAA 0\r\n\r\nMSG hello 1 world 0')
self.assertEqual(msgs, 1)
self.assertEqual(len(ps.buf), 19)
self.assertEqual(ps.state, AWAITING_CONTROL_LINE)
await ps.parse(b'''\r\n\r\n''')
await ps.parse(b'\r\n\r\n')
self.assertEqual(msgs, 2)

@async_test
Expand Down Expand Up @@ -219,12 +219,12 @@ async def payload_test(sid, subject, reply, payload):
# FIXME: Malformed long protocol lines will not be detected
# by the client, so we rely on the ping/pong interval
# from the client to give up instead.
data = '''PING\r\nWRONG hello 1 %s''' % reply
data = f'PING\r\nWRONG hello 1 {reply}'
await ps.parse(data.encode())
await ps.parse(b'''AAAAA 0''')
await ps.parse(b'AAAAA 0')
self.assertEqual(ps.state, AWAITING_CONTROL_LINE)
await ps.parse(b'''\r\n\r\n''')
await ps.parse(b'''\r\n\r\n''')
await ps.parse(b'\r\n\r\n')
await ps.parse(b'\r\n\r\n')


if __name__ == '__main__':
Expand Down

0 comments on commit 90751e8

Please sign in to comment.