/
fields.py
176 lines (141 loc) · 6.73 KB
/
fields.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import json
import logging
import os
from functools import lru_cache
from gettext import gettext as _
from cryptography.fernet import Fernet, MultiFernet
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.db.models import FileField, JSONField, Lookup
from django.db.models.fields import Field, TextField
from django.utils.encoding import force_bytes, force_str
from pulpcore.app.files import TemporaryDownloadedFile
from pulpcore.app.loggers import deprecation_logger
_logger = logging.getLogger(__name__)


@lru_cache(maxsize=1)
def _fernet():
    """Return a MultiFernet built from the key file at ``settings.DB_ENCRYPTION_KEY``.

    The key file is read only once: the zero-argument function is wrapped in
    ``lru_cache(maxsize=1)``, so the same MultiFernet is returned on every later call.

    Each line of the file is stripped of surrounding whitespace. Blank lines and lines whose
    first non-whitespace character is ``#`` are skipped. Every remaining line is passed to
    ``Fernet`` as a key, in file order. Per the cryptography docs, MultiFernet encrypts with
    the first key and tries every key when decrypting, which allows key rotation.

    Returns:
        MultiFernet: The cached cipher for the configured keys.

    Raises:
        OSError: If the key file cannot be opened.
        ValueError: If a key line is not a valid Fernet key, or if the file contains no keys.
    """
    key_path = settings.DB_ENCRYPTION_KEY
    # Lazy %-formatting: the message is only built when DEBUG logging is enabled.
    _logger.debug("Loading encryption key from %s", key_path)
    with open(key_path, "rb") as key_file:
        # Strip before testing for "#" so that indented comment lines are ignored too,
        # instead of being handed to Fernet() as a (bogus) key.
        lines = [raw_line.strip() for raw_line in key_file]
    return MultiFernet([Fernet(line) for line in lines if line and not line.startswith(b"#")])
class ArtifactFileField(FileField):
    """
    A custom FileField that always saves files to location specified by 'upload_to'.

    The field can be set as either a path to the file or File object. In both cases the file is
    moved or copied to the location specified by 'upload_to' field parameter.
    """

    def pre_save(self, model_instance, add):
        """
        Return FieldFile object which specifies path to the file to be stored in database.

        There are two ways to get artifact into Pulp: sync and upload.

        The upload case
         - file is not stored yet, aka file._committed = False
         - nothing to do here in addition to Django pre_save actions

        The sync case:
         - file is already stored in a temporary location, aka file._committed = True
         - it needs to be moved into Pulp artifact storage if it's not there
         - TemporaryDownloadedFile takes care of correctly set storage path
         - only then Django pre_save actions should be performed

        Args:
            model_instance (`class::pulpcore.plugin.Artifact`): The instance this field belongs to.
            add (bool): Whether the instance is being saved to the database for the first time.
                        Ignored by Django pre_save method.

        Returns:
            FieldFile object just before saving.

        Raises:
            ValueError: If the file already lives inside the ``<MEDIA_ROOT>/artifact/``
                directory but not at this artifact's own storage path.
        """
        file = model_instance.file
        artifact_storage_path = self.upload_to(model_instance, "")

        # The file may be referenced either by its storage-relative path or by the same
        # path joined onto MEDIA_ROOT.
        already_in_place = file.name in [
            artifact_storage_path,
            os.path.join(settings.MEDIA_ROOT, artifact_storage_path),
        ]

        # Joining a trailing "" yields "<MEDIA_ROOT>/artifact/<sep>", so only paths *inside*
        # the artifact directory match; sibling paths such as "<MEDIA_ROOT>/artifact_tmp/..."
        # are no longer mistaken for artifact storage.
        artifact_dir = os.path.join(settings.MEDIA_ROOT, "artifact", "")
        is_in_artifact_storage = file.name.startswith(artifact_dir)

        if not already_in_place and is_in_artifact_storage:
            raise ValueError(
                _(
                    "The file referenced by the Artifact is already present in "
                    "Artifact storage. Files must be stored outside this location "
                    "prior to Artifact creation."
                )
            )

        # Sync case: a committed file that is not yet recorded under its final storage path.
        # Marking it uncommitted makes Django's pre_save store it (see docstring above).
        if file._committed and file.name != artifact_storage_path:
            if not already_in_place:
                # NOTE(review): the open handle is handed to TemporaryDownloadedFile and is
                # presumably closed by it / the storage save - confirm ownership.
                file._file = TemporaryDownloadedFile(open(file.name, "rb"))
            file._committed = False
        return super().pre_save(model_instance, add)
class EncryptedTextField(TextField):
    """A TextField whose value is encrypted using the keys in settings.DB_ENCRYPTION_KEY.

    Values are Fernet-encrypted when written to the database and decrypted when read back.
    Fernet tokens for the same plaintext differ between encryptions, so ``primary_key``,
    ``unique`` and ``db_index`` are rejected as meaningless on this column.
    """

    def __init__(self, *args, **kwargs):
        # Checked in this order; the first enabled option raises.
        for option in ("primary_key", "unique", "db_index"):
            if kwargs.get(option):
                raise ImproperlyConfigured(f"EncryptedTextField does not support {option}=True.")
        super().__init__(*args, **kwargs)

    def get_prep_value(self, value):
        """Encrypt ``value`` (a ``str`` or ``None``) before TextField prepares it.

        Raises:
            TypeError: If ``value`` is neither ``None`` nor a ``str``. An explicit check is
                used instead of ``assert`` so the guard survives ``python -O``.
        """
        if value is not None:
            if not isinstance(value, str):
                raise TypeError(
                    f"EncryptedTextField expects a str value, got {type(value).__name__}."
                )
            value = force_str(_fernet().encrypt(force_bytes(value)))
        return super().get_prep_value(value)

    def from_db_value(self, value, expression, connection):
        """Decrypt the stored token back to ``str``; ``None`` passes through unchanged."""
        if value is not None:
            value = force_str(_fernet().decrypt(force_bytes(value)))
        return value
class EncryptedJSONField(JSONField):
    """A JSONField that encrypts its leaf values using settings.DB_ENCRYPTION_KEY.

    The container structure stays in plaintext: dict keys and list positions are kept.
    Every leaf value is serialized with ``json.dumps`` and Fernet-encrypted individually.
    Tuples and sets are stored (and read back) as lists.
    """

    def __init__(self, *args, **kwargs):
        # Fernet tokens differ between encryptions, so these options are rejected.
        for option in ("primary_key", "unique", "db_index"):
            if kwargs.get(option):
                raise ImproperlyConfigured(f"EncryptedJSONField does not support {option}=True.")
        super().__init__(*args, **kwargs)

    def encrypt(self, value):
        """Recursively encrypt the leaves of ``value``.

        Args:
            value: A value serializable by ``json.dumps`` with ``self.encoder``. Dicts are
                walked by value (keys stay plaintext); lists, tuples and sets element-wise.

        Returns:
            The same structure with every leaf replaced by its Fernet token as ``str``.
            Tuples and sets come back as lists.
        """
        if isinstance(value, dict):
            return {k: self.encrypt(v) for k, v in value.items()}
        if isinstance(value, (list, tuple, set)):
            return [self.encrypt(v) for v in value]
        return force_str(_fernet().encrypt(force_bytes(json.dumps(value, cls=self.encoder))))

    def decrypt(self, value):
        """Recursively decrypt a structure produced by :meth:`encrypt`.

        Each leaf token is decrypted and parsed with ``json.loads`` using ``self.decoder``.
        If the plaintext is not valid JSON, the deprecated ``eval`` fallback is used.

        Raises:
            cryptography.fernet.InvalidToken: If a leaf was not encrypted with any
                configured key.
        """
        if isinstance(value, dict):
            return {k: self.decrypt(v) for k, v in value.items()}
        if isinstance(value, (list, tuple, set)):
            return [self.decrypt(v) for v in value]
        dec_value = force_str(_fernet().decrypt(force_bytes(value)))
        try:
            return json.loads(dec_value, cls=self.decoder)
        except json.JSONDecodeError:
            deprecation_logger.info(
                "Failed to decode json in an EncryptedJSONField. Falling back to eval. "
                "Please run pulpcore-manager rotate-db-key to repair. "
                "This is deprecated and will be removed in pulpcore 3.40."
            )
            # SECURITY NOTE(review): eval() executes arbitrary Python. The input is data that
            # was decrypted with our own key, so it only runs what was previously stored in the
            # DB, but anyone able to write encrypted values gains code execution. Remove this
            # fallback as scheduled (pulpcore 3.40) or consider ast.literal_eval.
            return eval(dec_value)

    def get_prep_value(self, value):
        """Encrypt ``value`` before JSONField's own preparation.

        Objects exposing ``as_sql`` (query expressions) are returned unchanged, without
        calling the parent, so they are compiled as SQL rather than encrypted.
        """
        if value is not None:
            if hasattr(value, "as_sql"):
                return value
            value = self.encrypt(value)
        return super().get_prep_value(value)

    def from_db_value(self, value, expression, connection):
        """Let JSONField decode the stored JSON, then decrypt its leaves; ``None`` passes through."""
        if value is not None:
            value = self.decrypt(super().from_db_value(value, expression, connection))
        return value
@Field.register_lookup
class NotEqualLookup(Lookup):
    """Add a ``__ne`` lookup to every model field, rendered as SQL ``<>``.

    It is registered on ``Field`` itself, so all field subclasses inherit it. As with any
    SQL ``<>``, rows where either side is NULL do not match.

    Adapted from the Django "Custom Lookups" how-to:
    https://docs.djangoproject.com/en/3.2/howto/custom-lookups/
    """

    lookup_name = "ne"

    def as_sql(self, compiler, connection):
        """Compile to ``(sql, params)`` for ``lhs <> rhs``.

        The left-hand params come before the right-hand ones, matching the placeholder order.
        """
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        return f"{lhs_sql} <> {rhs_sql}", lhs_params + rhs_params