Skip to content

Commit

Permalink
Allow expired accounts to logout (matrix-org#7443)
Browse files Browse the repository at this point in the history
  • Loading branch information
anoadragon453 authored and phil-flex committed Jun 16, 2020
1 parent 7c6437a commit 83c0e1f
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 22 deletions.
1 change: 1 addition & 0 deletions changelog.d/7443.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow expired user accounts to log out their device sessions.
50 changes: 33 additions & 17 deletions synapse/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from netaddr import IPAddress

from twisted.internet import defer
from twisted.web.server import Request

import synapse.logging.opentracing as opentracing
import synapse.types
Expand Down Expand Up @@ -162,19 +163,25 @@ def get_public_keys(self, invite_event):

@defer.inlineCallbacks
def get_user_by_req(
self, request, allow_guest=False, rights="access", allow_expired=False
self,
request: Request,
allow_guest: bool = False,
rights: str = "access",
allow_expired: bool = False,
):
""" Get a registered user's ID.
Args:
request - An HTTP request with an access_token query parameter.
allow_expired - Whether to allow the request through even if the account is
expired. If true, Synapse will still require an access token to be
provided but won't check if the account it belongs to has expired. This
works thanks to /login delivering access tokens regardless of accounts'
expiration.
request: An HTTP request with an access_token query parameter.
allow_guest: If False, will raise an AuthError if the user making the
request is a guest.
rights: The operation being performed; the access token must allow this
allow_expired: If True, allow the request through even if the account
is expired, or session token lifetime has ended. Note that
/login will deliver access tokens regardless of expiration.
Returns:
defer.Deferred: resolves to a ``synapse.types.Requester`` object
defer.Deferred: resolves to a `synapse.types.Requester` object
Raises:
InvalidClientCredentialsError if no user by that token exists or the token
is invalid.
Expand Down Expand Up @@ -205,7 +212,9 @@ def get_user_by_req(

return synapse.types.create_requester(user_id, app_service=app_service)

user_info = yield self.get_user_by_access_token(access_token, rights)
user_info = yield self.get_user_by_access_token(
access_token, rights, allow_expired=allow_expired
)
user = user_info["user"]
token_id = user_info["token_id"]
is_guest = user_info["is_guest"]
Expand Down Expand Up @@ -280,22 +289,28 @@ def _get_appservice_user_id(self, request):
return user_id, app_service

@defer.inlineCallbacks
def get_user_by_access_token(self, token, rights="access"):
def get_user_by_access_token(
self, token: str, rights: str = "access", allow_expired: bool = False,
):
""" Validate access token and get user_id from it
Args:
token (str): The access token to get the user by.
rights (str): The operation being performed; the access token must
allow this.
token: The access token to get the user by
rights: The operation being performed; the access token must
allow this
allow_expired: If False, raises an InvalidClientTokenError
if the token is expired
Returns:
Deferred[dict]: dict that includes:
`user` (UserID)
`is_guest` (bool)
`token_id` (int|None): access token id. May be None if guest
`device_id` (str|None): device corresponding to access token
Raises:
InvalidClientTokenError if a user by that token exists, but the token is
expired
InvalidClientCredentialsError if no user by that token exists or the token
is invalid.
is invalid
"""

if rights == "access":
Expand All @@ -304,7 +319,8 @@ def get_user_by_access_token(self, token, rights="access"):
if r:
valid_until_ms = r["valid_until_ms"]
if (
valid_until_ms is not None
not allow_expired
and valid_until_ms is not None
and valid_until_ms < self.clock.time_msec()
):
# there was a valid access token, but it has expired.
Expand Down Expand Up @@ -575,7 +591,7 @@ async def check_can_change_room_list(self, room_id: str, user: UserID):
return user_level >= send_level

@staticmethod
def has_access_token(request):
def has_access_token(request: Request):
"""Checks if the request has an access_token.
Returns:
Expand All @@ -586,7 +602,7 @@ def has_access_token(request):
return bool(query_params) or bool(auth_headers)

@staticmethod
def get_access_token_from_request(request):
def get_access_token_from_request(request: Request):
"""Extracts the access_token from the request.
Args:
Expand Down
6 changes: 3 additions & 3 deletions synapse/rest/client/v1/logout.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ def on_OPTIONS(self, request):
return 200, {}

async def on_POST(self, request):
requester = await self.auth.get_user_by_req(request)
requester = await self.auth.get_user_by_req(request, allow_expired=True)

if requester.device_id is None:
# the acccess token wasn't associated with a device.
# The access token wasn't associated with a device.
# Just delete the access token
access_token = self.auth.get_access_token_from_request(request)
await self._auth_handler.delete_access_token(access_token)
Expand All @@ -62,7 +62,7 @@ def on_OPTIONS(self, request):
return 200, {}

async def on_POST(self, request):
requester = await self.auth.get_user_by_req(request)
requester = await self.auth.get_user_by_req(request, allow_expired=True)
user_id = requester.user.to_string()

# first delete all of the user's devices
Expand Down
69 changes: 68 additions & 1 deletion tests/rest/client/v1/test_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from mock import Mock

import synapse.rest.admin
from synapse.rest.client.v1 import login
from synapse.rest.client.v1 import login, logout
from synapse.rest.client.v2_alpha import devices
from synapse.rest.client.v2_alpha.account import WhoamiRestServlet

Expand All @@ -20,6 +20,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
logout.register_servlets,
devices.register_servlets,
lambda hs, http_server: WhoamiRestServlet(hs).register(http_server),
]
Expand Down Expand Up @@ -256,6 +257,72 @@ def _delete_device(self, access_token, user_id, password, device_id):
self.render(request)
self.assertEquals(channel.code, 200, channel.result)

@override_config({"session_lifetime": "24h"})
def test_session_can_hard_logout_after_being_soft_logged_out(self):
self.register_user("kermit", "monkey")

# log in as normal
access_token = self.login("kermit", "monkey")

# we should now be able to make requests with the access token
request, channel = self.make_request(
b"GET", TEST_URL, access_token=access_token
)
self.render(request)
self.assertEquals(channel.code, 200, channel.result)

# time passes
self.reactor.advance(24 * 3600)

# ... and we should be soft-logouted
request, channel = self.make_request(
b"GET", TEST_URL, access_token=access_token
)
self.render(request)
self.assertEquals(channel.code, 401, channel.result)
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEquals(channel.json_body["soft_logout"], True)

# Now try to hard logout this session
request, channel = self.make_request(
b"POST", "/logout", access_token=access_token
)
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)

@override_config({"session_lifetime": "24h"})
def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(self):
self.register_user("kermit", "monkey")

# log in as normal
access_token = self.login("kermit", "monkey")

# we should now be able to make requests with the access token
request, channel = self.make_request(
b"GET", TEST_URL, access_token=access_token
)
self.render(request)
self.assertEquals(channel.code, 200, channel.result)

# time passes
self.reactor.advance(24 * 3600)

# ... and we should be soft-logouted
request, channel = self.make_request(
b"GET", TEST_URL, access_token=access_token
)
self.render(request)
self.assertEquals(channel.code, 401, channel.result)
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEquals(channel.json_body["soft_logout"], True)

# Now try to hard log out all of the user's sessions
request, channel = self.make_request(
b"POST", "/logout/all", access_token=access_token
)
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)


class CASTestCase(unittest.HomeserverTestCase):

Expand Down
36 changes: 35 additions & 1 deletion tests/rest/client/v2_alpha/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
from synapse.rest.client.v1 import login
from synapse.rest.client.v1 import login, logout
from synapse.rest.client.v2_alpha import account, account_validity, register, sync

from tests import unittest
Expand Down Expand Up @@ -313,6 +313,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
sync.register_servlets,
logout.register_servlets,
account_validity.register_servlets,
]

Expand Down Expand Up @@ -405,6 +406,39 @@ def test_manual_expire(self):
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
)

def test_logging_out_expired_user(self):
user_id = self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey")

self.register_user("admin", "adminpassword", admin=True)
admin_tok = self.login("admin", "adminpassword")

url = "/_matrix/client/unstable/admin/account_validity/validity"
params = {
"user_id": user_id,
"expiration_ts": 0,
"enable_renewal_emails": False,
}
request_data = json.dumps(params)
request, channel = self.make_request(
b"POST", url, request_data, access_token=admin_tok
)
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)

# Try to log the user out
request, channel = self.make_request(b"POST", "/logout", access_token=tok)
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)

# Log the user in again (allowed for expired accounts)
tok = self.login("kermit", "monkey")

# Try to log out all of the user's sessions
request, channel = self.make_request(b"POST", "/logout/all", access_token=tok)
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)


class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):

Expand Down

0 comments on commit 83c0e1f

Please sign in to comment.