Skip to content

Commit

Permalink
refactor validating handles in ATProto and elsewhere
Browse files Browse the repository at this point in the history
for #982
  • Loading branch information
snarfed committed May 3, 2024
1 parent b8e6782 commit 2bf526a
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 33 deletions.
5 changes: 3 additions & 2 deletions activitypub.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def owns_id(cls, id):
return False

@classmethod
def owns_handle(cls, handle):
def owns_handle(cls, handle, allow_internal=False):
"""Returns True if handle is a WebFinger ``@-@`` handle, False otherwise.
Example: ``@user@instance.com``. The leading ``@`` is optional.
Expand All @@ -171,7 +171,8 @@ def owns_handle(cls, handle):
return False

user, domain = parts
return user and domain and not cls.is_blocklisted(domain)
return user and domain and not cls.is_blocklisted(
domain, allow_internal=allow_internal)

@classmethod
def handle_to_id(cls, handle):
Expand Down
9 changes: 7 additions & 2 deletions atproto.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,9 @@ def owns_id(cls, id):
or id.startswith('https://bsky.app/'))

@classmethod
def owns_handle(cls, handle):
if not re.match(DOMAIN_RE, handle):
def owns_handle(cls, handle, allow_internal=False):
# TODO: implement allow_internal
if not did.HANDLE_RE.fullmatch(handle):
return False

@classmethod
Expand Down Expand Up @@ -248,6 +249,10 @@ def create_for(cls, user):
Args:
user (models.User)
Raises:
ValueError: if the user's handle is invalid, eg begins or ends with an
underscore or dash
"""
assert not isinstance(user, ATProto)

Expand Down
38 changes: 22 additions & 16 deletions ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import re
from urllib.parse import urljoin, urlparse

from arroba import did
from flask import request
from google.cloud.ndb.query import FilterNode, Query
from granary.bluesky import BSKY_APP_URL_RE, web_url_to_at_uri
Expand Down Expand Up @@ -162,50 +161,57 @@ def translate_handle(*, handle, from_, to, enhanced):
Returns:
str: the corresponding handle in ``to``
Raises:
ValueError: if the user's handle is invalid, eg begins or ends with an
underscore or dash
"""
assert handle and from_ and to, (handle, from_, to)
assert from_.owns_handle(handle) is not False or from_.LABEL == 'ui'

if from_.LABEL == 'atproto':
assert did.HANDLE_RE.fullmatch(handle)
if not from_.LABEL == 'ui':
if from_.owns_handle(handle, allow_internal=True) is False:
raise ValueError(f'input handle {handle} is not valid for {from_.LABEL}')

if from_ == to:
return handle

output = None
match from_.LABEL, to.LABEL:
case _, 'activitypub':
domain = f'{from_.ABBREV}{SUPERDOMAIN}'
if enhanced or handle == PRIMARY_DOMAIN or handle in PROTOCOL_DOMAINS:
domain = handle
return f'@{handle}@{domain}'
output = f'@{handle}@{domain}'

case _, 'atproto':
output = handle.lstrip('@').replace('@', '.')
for from_char in ATPROTO_DASH_CHARS:
handle = handle.replace(from_char, '-')
output = output.replace(from_char, '-')

handle = handle.lstrip('@').replace('@', '.')
if enhanced or handle == PRIMARY_DOMAIN or handle in PROTOCOL_DOMAINS:
pass
else:
handle = f'{handle}.{from_.ABBREV}{SUPERDOMAIN}'

assert did.HANDLE_RE.fullmatch(handle)
return handle
output = f'{output}.{from_.ABBREV}{SUPERDOMAIN}'

case 'activitypub', 'web':
user, instance = handle.lstrip('@').split('@')
# TODO: get this from the actor object's url field?
return (f'https://{user}' if user == instance
output = (f'https://{user}' if user == instance
else f'https://{instance}/@{user}')

case _, 'web':
return handle
output = handle

# only for unit tests
case _, 'fake' | 'other' | 'eefake':
return f'{to.LABEL}:handle:{handle}'
output = f'{to.LABEL}:handle:{handle}'

assert output, (handle, from_.LABEL, to.LABEL)
# don't check Web handles because they're sometimes URLs, eg
# @user@instance => https://instance/@user
if to.LABEL != 'web' and to.owns_handle(output, allow_internal=True) is False:
raise ValueError(f'translated handle {output} is not valid for {to.LABEL}')

assert False, (handle, from_.LABEL, to.LABEL)
return output


def translate_object_id(*, id, from_, to):
Expand Down
7 changes: 6 additions & 1 deletion protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def owns_id(cls, id):
return False

@classmethod
def owns_handle(cls, handle):
def owns_handle(cls, handle, allow_internal=False):
"""Returns whether this protocol owns the handle, or None if it's unclear.
To be implemented by subclasses.
Expand All @@ -192,6 +192,8 @@ def owns_handle(cls, handle):
Args:
handle (str)
allow_internal (bool): whether to accept internal domains like
``fed.brid.gy``, ``bsky.brid.gy``, etc, instead of returning False for them
Returns:
bool or None
Expand Down Expand Up @@ -409,6 +411,9 @@ def create_for(cls, user):
Args:
user (models.User): original source user. Shouldn't already have a
copy user for this protocol in :attr:`copies`.
Raises:
ValueError: if we can't create a copy of the given user in this protocol
"""
raise NotImplementedError()

Expand Down
30 changes: 21 additions & 9 deletions tests/test_atproto.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ def test_get_or_create(self, _):
self.assertEqual('han.dull', user.key.get().handle)

def test_owns_id(self):
self.assertEqual(False, ATProto.owns_id('http://foo'))
self.assertEqual(False, ATProto.owns_id('https://bar.baz/biff'))
self.assertEqual(False, ATProto.owns_id('e45fab982'))
self.assertFalse(ATProto.owns_id('http://foo'))
self.assertFalse(ATProto.owns_id('https://bar.baz/biff'))
self.assertFalse(ATProto.owns_id('e45fab982'))

self.assertTrue(ATProto.owns_id('at://did:plc:user/bar/123'))
self.assertTrue(ATProto.owns_id('did:plc:user'))
Expand All @@ -123,12 +123,18 @@ def test_owns_handle(self):
self.assertIsNone(ATProto.owns_handle('foo.com'))
self.assertIsNone(ATProto.owns_handle('foo.bar.com'))

self.assertEqual(False, ATProto.owns_handle('foo'))
self.assertEqual(False, ATProto.owns_handle('@foo'))
self.assertEqual(False, ATProto.owns_handle('@foo.com'))
self.assertEqual(False, ATProto.owns_handle('@foo@bar.com'))
self.assertEqual(False, ATProto.owns_handle('foo@bar.com'))
self.assertEqual(False, ATProto.owns_handle('localhost'))
self.assertFalse(ATProto.owns_handle('foo'))
self.assertFalse(ATProto.owns_handle('@foo'))
self.assertFalse(ATProto.owns_handle('@foo.com'))
self.assertFalse(ATProto.owns_handle('@foo@bar.com'))
self.assertFalse(ATProto.owns_handle('foo@bar.com'))
self.assertFalse(ATProto.owns_handle('localhost'))

self.assertFalse(ATProto.owns_handle('_foo.com'))
self.assertFalse(ATProto.owns_handle('-foo.com'))
self.assertFalse(ATProto.owns_handle('foo_.com'))
self.assertFalse(ATProto.owns_handle('foo-.com'))

# TODO: this should be False
self.assertIsNone(ATProto.owns_handle('web.brid.gy'))

Expand Down Expand Up @@ -701,6 +707,12 @@ def test_create_for(self, mock_post, mock_create_task, mock_zone):

mock_create_task.assert_called()

def test_create_for_bad_handle(self):
# underscores get translated to dashes, and leading/trailing dashes aren't allowed
for bad in 'fake:user_', '_fake:user':
with self.assertRaises(ValueError):
ATProto.create_for(Fake(id=bad))

@patch('google.cloud.dns.client.ManagedZone', autospec=True)
@patch.object(tasks_client, 'create_task', return_value=Task(name='my task'))
@patch('requests.post',
Expand Down
2 changes: 1 addition & 1 deletion tests/testutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def owns_id(cls, id):
or id in cls.fetchable)

@classmethod
def owns_handle(cls, handle):
def owns_handle(cls, handle, allow_internal=False):
return handle.startswith(f'{cls.LABEL}:handle:')

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions web.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,10 +358,10 @@ def owns_id(cls, id):
return False

@classmethod
def owns_handle(cls, handle):
def owns_handle(cls, handle, allow_internal=False):
if handle == PRIMARY_DOMAIN or handle in PROTOCOL_DOMAINS:
return True
elif not is_valid_domain(handle, allow_internal=False):
elif not is_valid_domain(handle, allow_internal=allow_internal):
return False

@classmethod
Expand Down

0 comments on commit 2bf526a

Please sign in to comment.