-
Notifications
You must be signed in to change notification settings - Fork 2.7k
/
Copy pathmodels.py
153 lines (120 loc) · 5.74 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from datetime import timedelta
from typing import Any
from annoying.fields import AutoOneToOneField
from core.utils.common import is_community
from django.db import models
from django.utils.translation import gettext_lazy as _
from organizations.models import Organization
from rest_framework_simplejwt.backends import TokenBackend
from rest_framework_simplejwt.exceptions import TokenError
from rest_framework_simplejwt.tokens import RefreshToken
from rest_framework_simplejwt.tokens import api_settings as simple_jwt_settings
class JWTSettings(models.Model):
    """Per-organization switches and lifetime settings for API token auth.

    The owning organization doubles as the primary key, and the settings are
    reachable from an organization through the ``jwt`` related name.
    """

    organization = AutoOneToOneField(Organization, related_name='jwt', primary_key=True, on_delete=models.DO_NOTHING)
    api_tokens_enabled = models.BooleanField(
        _('JWT API tokens enabled'),
        default=False,
        help_text='Enable JWT API token authentication for this organization',
    )
    api_token_ttl_days = models.IntegerField(
        _('JWT API token time to live (days)'),
        default=(200 * 365),  # "eternity", 200 years
        help_text='Number of days before JWT API tokens expire',
    )
    legacy_api_tokens_enabled = models.BooleanField(
        _('legacy API tokens enabled'),
        default=True,
        help_text='Enable legacy API token authentication for this organization',
    )
    created_at = models.DateTimeField(_('created at'), auto_now_add=True)
    updated_at = models.DateTimeField(_('updated at'), auto_now=True)

    def has_view_permission(self, user):
        """Return whether ``user`` may read these settings.

        The decision is delegated entirely to the organization's own
        ``has_permission`` check.
        """
        organization = self.organization
        return organization.has_permission(user)

    def has_modify_permission(self, user):
        """Return whether ``user`` may change these settings.

        A user rejected by the organization's ``has_permission`` check is
        always refused. Beyond that:

        * Label Studio open-source (``is_community()``): every user accepted
          by the organization may modify the settings.
        * Label Studio enterprise: only the owner or an administrator may.
        """
        organization = self.organization
        if not organization.has_permission(user):
            return False

        # open-source: no role distinction
        if is_community():
            return True

        # enterprise: prefer the role flags on the user object when present,
        # otherwise fall back to comparing against the organization's creator.
        if hasattr(user, 'is_owner'):
            is_owner = user.is_owner
        else:
            is_owner = user.id == organization.created_by.id

        is_administrator = False
        if hasattr(user, 'is_administrator'):
            is_administrator = user.is_administrator

        return is_owner or is_administrator

    def has_permission(self, user):
        """Generic permission hook; applies the same rule as modification."""
        return self.has_modify_permission(user)
class LSTokenBackend(TokenBackend):
    """Token backend whose default encoding drops the JWT signature.

    Builds on simple jwt's ``TokenBackend`` and offers two encoders:
    ``encode`` yields a truncated token (header + payload only), while
    ``encode_full`` yields the complete token (header + payload + signature).
    The truncated form is intended for places, such as the frontend, where
    the signature should not be exposed.
    """

    def encode(self, payload: dict[str, Any]) -> str:
        """Encode ``payload`` and return the token without its signature.

        Args:
            payload: Dictionary containing the JWT claims to encode

        Returns:
            A string of the form ``<header>.<payload>``; the signature
            segment produced by the parent encoder is discarded.
        """
        full_token = super().encode(payload)
        # The parent's output must split into exactly three dot-separated
        # segments; the unpacking raises otherwise.
        header, claims, _signature = full_token.split('.')
        return f'{header}.{claims}'

    def encode_full(self, payload: dict[str, Any]) -> str:
        """Encode ``payload`` and return the complete, signed token.

        Args:
            payload: Dictionary containing the JWT claims to encode

        Returns:
            The parent encoder's output unchanged (header, payload and
            signature).
        """
        full_token = super().encode(payload)
        return full_token
class LSAPIToken(RefreshToken):
    """JWT-based API token backed by ``LSTokenBackend``.

    Subclasses ``RefreshToken`` and swaps in ``LSTokenBackend``, so the
    backend's default encoding produces the truncated (signature-less) form,
    while ``get_full_jwt`` gives access to the complete signed token. The
    class-level default lifetime is effectively unlimited.
    """

    # "eternity" (200 years)
    lifetime = timedelta(days=200 * 365)

    # Same configuration values simple jwt uses for its own backend, passed
    # positionally in the order the original code supplies them.
    _token_backend = LSTokenBackend(
        simple_jwt_settings.ALGORITHM,
        simple_jwt_settings.SIGNING_KEY,
        simple_jwt_settings.VERIFYING_KEY,
        simple_jwt_settings.AUDIENCE,
        simple_jwt_settings.ISSUER,
        simple_jwt_settings.JWK_URL,
        simple_jwt_settings.LEEWAY,
        simple_jwt_settings.JSON_ENCODER,
    )

    def get_full_jwt(self) -> str:
        """Return this token as a complete JWT, signature included.

        Returns:
            The string produced by the backend's ``encode_full`` for this
            token's payload.
        """
        backend = self.get_token_backend()
        return backend.encode_full(self.payload)

    def blacklist(self):
        """Blacklist this token, refusing to do so twice.

        ``check_blacklist`` runs first so that a token which is already
        blacklisted is reported instead of being silently accepted.

        Raises:
            rest_framework_simplejwt.exceptions.TokenError: If the token is already blacklisted.
        """
        self.check_blacklist()
        return super().blacklist()
class TruncatedLSAPIToken(LSAPIToken):
    """API token built from a JWT that carries only header and payload.

    Used when the frontend has access to truncated refresh tokens only, i.e.
    tokens whose signature segment is missing.
    """

    def __init__(self, token, *args, **kwargs):
        """Normalize ``token`` to header + payload and load it unverified.

        Any segments beyond the first two are discarded, a placeholder
        signature is appended so the string has the three-part JWT shape, and
        the parent initializer is invoked with ``verify=False``.

        Raises:
            TokenError: If ``token`` has fewer than two dot-separated parts.
        """
        segments = token.split('.')
        if len(segments) < 2:
            raise TokenError('Invalid Label Studio token')

        # Keep header and payload only; anything after them is dropped.
        header, claims = segments[:2]

        # Dummy signature of exactly 43 'x' characters, chosen to match the
        # expected JWT signature length.
        placeholder_signature = 'x' * 43
        rebuilt = f'{header}.{claims}.{placeholder_signature}'

        super().__init__(rebuilt, *args, verify=False, **kwargs)