Skip to content

Commit

Permalink
refactor(socialaccount): Subproviders
Browse files Browse the repository at this point in the history
  • Loading branch information
pennersr committed Jul 1, 2023
1 parent 75ab261 commit cc5279b
Show file tree
Hide file tree
Showing 66 changed files with 656 additions and 411 deletions.
24 changes: 24 additions & 0 deletions ChangeLog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ Note worthy changes

- New provider: Miro.

- It is now possible to manage OpenID Connect providers via the Django
admin. Simply add a `SocialApp` for each OpenID Connect provider.


Security notice
---------------
Expand All @@ -35,6 +38,27 @@ Backwards incompatible changes
- The Mozilla Persona provider has been removed. The project was shut down on
November 30th 2016.

- A large internal refactor has been performed to be able to add support for
providers oferring one or more subproviders. This refactor has the following
impact:

- The provider registry methods ``get_list()``, ``by_id()`` have been
removed. The registry now only providers access to the provider classes, not
the instances.

- ``provider.get_app()`` has been removed -- use ``provider.app`` instead.

- ``SocialApp.objects.get_current()`` has been removed.

- The ``SocialApp`` model now has additional fields ``provider_id``, and
``settings``.

- The OpenID Connect provider ``SOCIALACCOUNT_PROVIDERS`` settings structure
changed. Instead of the OpenID Connect specific ``SERVERS`` construct, it
now uses the regular ``APPS`` approach. Please refer to the OpenID Connect
documentation for details.



0.54.0 (2023-03-31)
*******************
Expand Down
127 changes: 117 additions & 10 deletions allauth/socialaccount/adapter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from __future__ import absolute_import

from django.core.exceptions import ValidationError
from django.core.exceptions import (
ImproperlyConfigured,
MultipleObjectsReturned,
ValidationError,
)
from django.db.models import Q
from django.urls import reverse
from django.utils.translation import gettext_lazy as _

Expand Down Expand Up @@ -188,18 +193,120 @@ def deserialize_instance(self, model, data):
def serialize_instance(self, instance):
return serialize_instance(instance)

def get_app(self, request, provider, config=None):
def list_providers(self, request):
from allauth.socialaccount.providers import registry

ret = []
provider_classes = registry.get_class_list()
apps = self.list_apps(request)
apps_map = {}
for app in apps:
apps_map.setdefault(app.provider, []).append(app)
for provider_class in provider_classes:
provider_apps = apps_map.get(provider_class.id, [])
if not provider_apps:
if provider_class.uses_apps:
continue
provider_apps = [None]
for app in provider_apps:
provider = provider_class(request=request, app=app)
ret.append(provider)
return ret

def get_provider(self, request, provider):
"""Looks up a `provider`, supporting subproviders by looking up by
`provider_id`.
"""
from allauth.socialaccount.providers import registry

provider_class = registry.get_class(provider)
if provider_class is None or provider_class.uses_apps:
app = self.get_app(request, provider=provider)
if not provider_class:
# In this case, the `provider` argument passed was a
# `provider_id`.
provider_class = registry.get_class(app.provider)
if not provider_class:
raise ImproperlyConfigured(f"unknown provider: {app.provider}")
return provider_class(request, app=app)
elif provider_class:
assert not provider_class.uses_apps
return provider_class(request, app=None)
else:
raise ImproperlyConfigured(f"unknown provider: {app.provider}")

def list_apps(self, request, provider=None, client_id=None):
"""SocialApp's can be setup in the database, or, via
`settings.SOCIALACCOUNT_PROVIDERS`. This methods returns a uniform list
of all known apps matching the specified criteria, and blends both
(db/settings) sources of data.
"""
# NOTE: Avoid loading models at top due to registry boot...
from allauth.socialaccount.models import SocialApp

config = config or app_settings.PROVIDERS.get(provider, {}).get("APP")
if config:
app = SocialApp(provider=provider)
for field in ["client_id", "secret", "key", "certificate_key"]:
setattr(app, field, config.get(field))
else:
app = SocialApp.objects.get_current(provider, request)
return app
# Map provider to the list of apps.
provider_to_apps = {}

# First, populate it with the DB backed apps.
db_apps = SocialApp.objects.on_site(request)
if provider:
db_apps = db_apps.filter(
Q(provider_id="", provider=provider) | Q(provider_id=provider)
)
if client_id:
db_apps = db_apps.filter(client_id=client_id)
for app in db_apps:
apps = provider_to_apps.setdefault(app.provider, [])
apps.append(app)

# Then, extend it with the settings backed apps.
for p, pcfg in app_settings.PROVIDERS.items():
app_configs = pcfg.get("APPS")
if app_configs is None:
app_config = pcfg.get("APP")
if app_config is None:
continue
app_configs = [app_config]

apps = provider_to_apps.setdefault(p, [])
for config in app_configs:
app = SocialApp(provider=p)
for field in [
"name",
"provider_id",
"client_id",
"secret",
"key",
"certificate_key",
"settings",
]:
if field in config:
setattr(app, field, config[field])
if client_id and app.client_id != client_id:
continue
if (
provider
and app.provider_id != provider
and app.provider != provider
):
continue
apps.append(app)

# Flatten the list of apps.
apps = []
for provider_apps in provider_to_apps.values():
apps.extend(provider_apps)
return apps

def get_app(self, request, provider, client_id=None):
from allauth.socialaccount.models import SocialApp

apps = self.list_apps(request, provider=provider, client_id=client_id)
if len(apps) > 1:
raise MultipleObjectsReturned
elif len(apps) == 0:
raise SocialApp.DoesNotExist()
return apps[0]


def get_adapter(request=None):
Expand Down
30 changes: 29 additions & 1 deletion allauth/socialaccount/app_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,35 @@ def PROVIDERS(self):
"""
Provider specific settings
"""
return self._setting("PROVIDERS", {})
ret = self._setting("PROVIDERS", {})
oidc = ret.get("openid_connect")
if oidc:
ret["openid_connect"] = self._migrate_oidc(oidc)
return ret

def _migrate_oidc(self, oidc):
servers = oidc.get("SERVERS")
if servers is None:
return oidc
ret = {}
apps = []
for server in servers:
app = dict(**server["APP"])
app_settings = {}
if "token_auth_method" in server:
app_settings["token_auth_method"] = server["token_auth_method"]
app_settings["server_url"] = server["server_url"]
app.update(
{
"name": server.get("name", ""),
"provider_id": server["id"],
"settings": app_settings,
}
)
assert app["provider_id"]
apps.append(app)
ret["APPS"] = apps
return ret

@property
def EMAIL_REQUIRED(self):
Expand Down
29 changes: 29 additions & 0 deletions allauth/socialaccount/migrations/0004_app_provider_id_settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Generated by Django 3.2.19 on 2023-06-30 13:16

from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("socialaccount", "0003_extra_data_default_dict"),
]

operations = [
migrations.AddField(
model_name="socialapp",
name="provider_id",
field=models.CharField(
blank=True, max_length=200, verbose_name="provider ID"
),
),
migrations.AddField(
model_name="socialapp",
name="settings",
field=models.JSONField(blank=True, default=dict),
),
migrations.AlterField(
model_name="socialaccount",
name="provider",
field=models.CharField(max_length=200, verbose_name="provider"),
),
]
51 changes: 33 additions & 18 deletions allauth/socialaccount/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,31 @@


class SocialAppManager(models.Manager):
def get_current(self, provider, request=None):
cache = {}
if request:
cache = getattr(request, "_socialapp_cache", {})
request._socialapp_cache = cache
app = cache.get(provider)
if not app:
if allauth.app_settings.SITES_ENABLED:
site = get_current_site(request)
app = self.get(sites__id=site.id, provider=provider)
else:
app = self.get(provider=provider)
cache[provider] = app
return app
def on_site(self, request):
if allauth.app_settings.SITES_ENABLED:
site = get_current_site(request)
return self.filter(sites__id=site.id)
return self.all()


class SocialApp(models.Model):
objects = SocialAppManager()

# The provider type, e.g. "google", "telegram", "saml".
provider = models.CharField(
verbose_name=_("provider"),
max_length=30,
choices=providers.registry.as_choices(),
)
# For providers that support subproviders, such as OpenID Connect and SAML,
# this ID identifies that instance. SocialAccount's originating from app
# will have their `provider` field set to the `provider_id` if available,
# else `provider`.
provider_id = models.CharField(
verbose_name=_("provider ID"),
max_length=200,
blank=True,
)
name = models.CharField(verbose_name=_("name"), max_length=40)
client_id = models.CharField(
verbose_name=_("client id"),
Expand All @@ -58,6 +59,8 @@ class SocialApp(models.Model):
key = models.CharField(
verbose_name=_("key"), max_length=191, blank=True, help_text=_("Key")
)
settings = models.JSONField(default=dict, blank=True)

if allauth.app_settings.SITES_ENABLED:
# Most apps can be used across multiple domains, therefore we use
# a ManyToManyField. Note that Facebook requires an app per domain
Expand All @@ -79,13 +82,18 @@ class Meta:
def __str__(self):
return self.name

def get_provider(self, request):
provider_class = providers.registry.get_class(self.provider)
return provider_class(request=request, app=self)


class SocialAccount(models.Model):
user = models.ForeignKey(allauth.app_settings.USER_MODEL, on_delete=models.CASCADE)
# Given a `SocialApp` from which this account originates, this field equals
# the app's `app.provider_id` if available, `app.provider` otherwise.
provider = models.CharField(
verbose_name=_("provider"),
max_length=30,
choices=providers.registry.as_choices(),
max_length=200,
)
# Just in case you're wondering if an OpenID identity URL is going
# to fit in a 'uid':
Expand Down Expand Up @@ -129,8 +137,15 @@ def get_profile_url(self):
def get_avatar_url(self):
return self.get_provider_account().get_avatar_url()

def get_provider(self):
return providers.registry.by_id(self.provider)
def get_provider(self, request=None):
provider = getattr(self, "_provider", None)
if provider:
return provider
adapter = get_adapter(request)
provider = self._provider = adapter.get_provider(
request, provider=self.provider
)
return provider

def get_provider_account(self):
return self.get_provider().wrap_account(self)
Expand Down
18 changes: 13 additions & 5 deletions allauth/socialaccount/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,25 @@
from collections import OrderedDict

from django.apps import apps
from django.conf import settings

from allauth.utils import import_attribute


class ProviderRegistry(object):
def __init__(self):
self.provider_map = OrderedDict()
self.loaded = False

def get_list(self, request=None):
def get_class_list(self):
self.load()
return [provider_cls(request) for provider_cls in self.provider_map.values()]
return list(self.provider_map.values())

def register(self, cls):
self.provider_map[cls.id] = cls

def by_id(self, id, request=None):
self.load()
return self.provider_map[id](request=request)
def get_class(self, id):
return self.provider_map.get(id)

def as_choices(self):
self.load()
Expand All @@ -41,7 +43,13 @@ def load(self):
except ImportError:
pass
else:
provider_settings = getattr(settings, "SOCIALACCOUNT_PROVIDERS", {})
for cls in getattr(provider_module, "provider_classes", []):
provider_class = provider_settings.get(cls.id, {}).get(
"provider_class"
)
if provider_class:
cls = import_attribute(provider_class)
self.register(cls)
self.loaded = True

Expand Down
2 changes: 1 addition & 1 deletion allauth/socialaccount/providers/apple/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class AppleOAuth2Client(OAuth2Client):
def generate_client_secret(self):
"""Create a JWT signed with an apple provided private key"""
now = datetime.utcnow()
app = get_adapter().get_app(self.request, "apple")
app = get_adapter(self.request).get_app(self.request, "apple")
if not app.key:
raise ImproperlyConfigured("Apple 'key' missing")
if not app.certificate_key:
Expand Down

0 comments on commit cc5279b

Please sign in to comment.