-
Notifications
You must be signed in to change notification settings - Fork 57
/
generators.py
271 lines (241 loc) · 11.3 KB
/
generators.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
# -*- mode:python; coding:utf-8 -*-
# Copyright (c) 2020 IBM Corp. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Capabilities to allow the generation of various oscal objects."""
import inspect
import logging
import math
import typing
import uuid
from datetime import date, datetime
from enum import Enum
from typing import Any, Dict, List, Type, TypeVar, Union, cast
import pydantic.v1.networks
from pydantic.v1 import ConstrainedStr
import trestle.common.const as const
import trestle.common.err as err
import trestle.common.type_utils as utils
from trestle.common import str_utils
from trestle.common.str_utils import AliasMode
from trestle.core.base_model import OscalBaseModel
from trestle.oscal import OSCAL_VERSION
from trestle.oscal.common import Base64
from trestle.oscal.common import Base64Datatype
from trestle.oscal.common import Methods
from trestle.oscal.common import ObservationTypeValidValues
from trestle.oscal.common import OscalVersion
from trestle.oscal.common import TaskValidValues
from trestle.oscal.ssp import DateDatatype
logger = logging.getLogger(__name__)

# Type variable for any OSCAL model class handled by the generators below.
TG = TypeVar('TG', bound=OscalBaseModel)

# Canned Base64 sample: the raw value, a fully-populated Base64 object built from it,
# and that object's class (compared against by ``is_by_type``).
sample_base64_value = 0
sample_base64 = Base64(filename=const.REPLACE_ME, media_type=const.REPLACE_ME, value=sample_base64_value)
type_base64 = type(sample_base64)

# Canned values for date fields and for the enum-valued Union fields detected by the is_enum_* helpers.
sample_date_value = '2400-02-29'
sample_task_valid_value = TaskValidValues.milestone
sample_method = Methods.EXAMINE
sample_observation_type_valid_value = ObservationTypeValidValues.historic
def safe_is_sub(sub: Any, parent: Any) -> bool:
    """Report whether ``sub`` is a class that derives from ``parent``.

    Unlike a bare ``issubclass`` call, non-class inputs (instances, typing constructs, None)
    do not raise; they simply yield False.
    """
    if not inspect.isclass(sub):
        return False
    return issubclass(sub, parent)
def _union_has_enum_named(type_: type, enum_name: str) -> bool:
    """Return True if ``type_`` is a Union with a member that renders as ``<enum 'enum_name'>``.

    Matching is done on the rendered string of each Union argument, i.e. by the enum class name.
    """
    if utils.get_origin(type_) == Union:
        target = f"<enum '{enum_name}'>"
        return any(f'{arg}' == target for arg in typing.get_args(type_))
    return False


def is_enum_method(type_: type) -> bool:
    """Test whether ``type_`` is a Union that includes the ``Methods`` enum."""
    return _union_has_enum_named(type_, 'Methods')


def is_enum_task_valid_value(type_: type) -> bool:
    """Test whether ``type_`` is a Union that includes the ``TaskValidValues`` enum."""
    return _union_has_enum_named(type_, 'TaskValidValues')


def is_enum_observation_type_valid_value(type_: type) -> bool:
    """Test whether ``type_`` is a Union that includes the ``ObservationTypeValidValues`` enum."""
    return _union_has_enum_named(type_, 'ObservationTypeValidValues')
def _sample_constrained_str(type_: Any, field_name: str) -> str:
    """Pick a sample value for a pydantic ConstrainedStr subtype, mostly by inspecting the field name."""
    # This is messy: the value has to satisfy the type's constraints, but the regex itself is not
    # interpreted; a few known field names / pattern shapes are special-cased instead.
    # TODO: handle regex directly
    if field_name == 'uuid':
        return str(uuid.uuid4())
    # Some values (e.g. location_uuid entries inside a list) arrive with field_name='', so also
    # sniff the regex for the leading hex block of a UUID pattern. getattr keeps this safe for
    # types matched only by name, and for a regex stored as a plain string.
    regex = getattr(type_, 'regex', None)
    pattern = getattr(regex, 'pattern', regex)
    if pattern and pattern.startswith('^[0-9A-Fa-f]{8}'):
        return const.SAMPLE_UUID_STR
    if field_name == 'date_authorized':
        return date.today().isoformat()
    if field_name == 'oscal_version':
        return OSCAL_VERSION
    if 'uuid' in field_name:
        return const.SAMPLE_UUID_STR
    # Only known case where a UUID is required but 'uuid' is not in the name (singular or plural).
    if field_name.rstrip('s') == 'member_of_organization':
        return const.SAMPLE_UUID_STR
    return const.REPLACE_ME


def _sample_constrained_int(type_: Any) -> int:
    """Return the smallest integer meeting the lower-bound and ``multiple_of`` constraints of a conint type.

    Upper bounds (``le`` / ``lt``) are not considered. With no lower bound the search starts at 0.
    """
    multiple = type_.multiple_of or 1  # type: ignore  # no multiple_of: every integer qualifies
    # Compare with None explicitly: a bound of 0 is falsy but still meaningful (gt=0 must give >= 1).
    lower_bounds = []
    if type_.ge is not None:  # type: ignore
        lower_bounds.append(type_.ge)  # type: ignore
    if type_.gt is not None:  # type: ignore
        lower_bounds.append(type_.gt + 1)  # type: ignore
    floor = max(lower_bounds) if lower_bounds else 0
    # Round up to the nearest multiple; integer ceil-division stays exact and is correct for negative floors.
    return -(-floor // multiple) * multiple


def generate_sample_value_by_type(
    type_: type,
    field_name: str,
) -> Union[datetime, bool, int, str, float, Enum]:
    """Given a type, return a sample value suitable for a field of that type.

    Args:
        type_: The field type: a plain builtin, a pydantic constrained type, an Enum, a Union
            containing one of the known OSCAL enums, Base64, or a pydantic network type.
        field_name: Name of the field being filled. Used to choose better samples for constrained
            strings (uuids, dates, oscal_version). May be '' for values nested in collections.

    Returns:
        A sample value. Types not recognized here fall back to an empty dict.

    Raises:
        TrestleError: if ``type_`` is the bare ``list`` type.
    """
    # FIXME: Should be in separate generator module as it inherits EVERYTHING
    if is_enum_method(type_):
        return sample_method
    if is_enum_task_valid_value(type_):
        return sample_task_valid_value
    if is_enum_observation_type_valid_value(type_):
        return sample_observation_type_valid_value
    if type_ is Base64:
        return sample_base64
    if type_ is datetime:
        return datetime.now().astimezone()
    if type_ is bool:
        return False
    if type_ is int:
        return 0
    if type_ is float:
        return 0.00
    if safe_is_sub(type_, ConstrainedStr) or (hasattr(type_, '__name__') and 'ConstrainedStr' in type_.__name__):
        return _sample_constrained_str(type_, field_name)
    if hasattr(type_, '__name__') and 'ConstrainedIntValue' in type_.__name__:
        return _sample_constrained_int(type_)
    if safe_is_sub(type_, Enum):
        # keys and values diverge due to hyphens in oscal names, so take the first member object itself
        return type_(list(type_.__members__.values())[0])  # type: ignore
    if type_ is str:
        if field_name == 'oscal_version':
            return OSCAL_VERSION
        return const.REPLACE_ME
    if type_ is pydantic.v1.networks.EmailStr:
        return pydantic.v1.networks.EmailStr('dummy@sample.com')
    if type_ is pydantic.v1.networks.AnyUrl:
        # TODO: Cleanup: this should be usable from a url.. but it's not intuitive.
        # The scheme attribute must agree with the URL text, which is https.
        return pydantic.v1.networks.AnyUrl('https://sample.com/replaceme.html', scheme='https', host='sample.com')
    if type_ is list:
        raise err.TrestleError(f'Unable to generate sample for type {type_}')
    # default to empty dict for anything else
    return {}  # type: ignore
def is_by_type(model_type: Union[Type[TG], List[TG], Dict[str, TG]]) -> bool:
    """Tell whether ``model_type`` is the class of the module-level Base64 sample.

    Fields of that type are filled via ``generate_sample_value_by_type`` rather than by recursing
    into the model.
    """
    return bool(model_type == type_base64)
def generate_sample_model(
    model: Union[Type[TG], List[TG], Dict[str, TG]], include_optional: bool = False, depth: int = -1
) -> TG:
    """Given a model class, generate an object of that class with sample values.

    Can generate optional variables with an enabled flag. Any array objects will have a single entry injected into it.

    Note: Trestle generate will not activate recursive loops irrespective of the depth flag.

    Args:
        model: The model type provided. Typically for a user as an OscalBaseModel Subclass. If it is a
            list/dict collection type, a one-entry list / dict of the inner type is returned instead.
        include_optional: Whether or not to generate optional fields.
        depth: Depth of the tree at which optional fields are generated. Negative values (default) removes the limit.
            Once the limit is reached, nested models only receive their required fields.

    Returns:
        The generated instance with a pro-forma values filled out as best as possible.

    Raises:
        TrestleError: if ``model`` is not an OscalBaseModel subclass and is not wrapped in a list or dict.
    """
    effective_optional = include_optional and depth != 0
    # Depth handed to nested models. Once the limit is hit (depth == 0) it must stay at 0: passing
    # ``depth - 1`` would give -1, which means "unlimited" and re-enables optional fields further down.
    # Negative (unlimited) depth is passed through unchanged.
    child_depth = depth - 1 if depth > 0 else depth
    model_type = model
    # Normalise collections: remember the container (e.g. list) in model_type and work on the inner type.
    if utils.is_collection_field_type(model):  # type: ignore
        model_type = utils.get_origin(model)  # type: ignore
        model = utils.get_inner_type(model)  # type: ignore
    model = cast(TG, model)  # type: ignore
    model_dict = {}  # type: ignore
    # this block is needed to avoid situations where an inbuilt is inside a list / dict.
    # the only time dict ever appears is with include_all, which is handled specially
    # the only type of collection possible after OSCAL 1.0.0 is list
    if safe_is_sub(model, OscalBaseModel):
        for field in model.__fields__:  # type: ignore
            if model_type in [OscalVersion]:
                model_dict[field] = OSCAL_VERSION
                break
            if field == 'include_all':
                if include_optional:
                    model_dict[field] = {}
                continue
            outer_type = model.__fields__[field].outer_type_  # type: ignore
            # next appears to be needed for python 3.7
            if utils.get_origin(outer_type) == Union:
                outer_type = outer_type.__args__[0]
            if model.__fields__[field].required or effective_optional:  # type: ignore
                # FIXME could be ForwardRef('SystemComponentStatus')
                if utils.is_collection_field_type(outer_type):
                    inner_type = utils.get_inner_type(outer_type)
                    # A collection of the model itself would recurse forever; skip it.
                    if inner_type == model:
                        continue
                    model_dict[field] = generate_sample_model(
                        outer_type, include_optional=include_optional, depth=child_depth
                    )
                elif is_by_type(outer_type):
                    model_dict[field] = generate_sample_value_by_type(outer_type, field)
                elif safe_is_sub(outer_type, OscalBaseModel):
                    model_dict[field] = generate_sample_model(
                        outer_type, include_optional=include_optional, depth=child_depth
                    )
                else:
                    # Handle special cases (hacking)
                    if model_type in [Base64Datatype]:
                        model_dict[field] = sample_base64_value
                    elif model_type in [Base64]:
                        if field == 'filename':
                            model_dict[field] = sample_base64.filename
                        elif field == 'media_type':
                            model_dict[field] = sample_base64.media_type
                        elif field == 'value':
                            model_dict[field] = sample_base64.value
                    elif model_type in [DateDatatype]:
                        model_dict[field] = sample_date_value
                    # Hacking here:
                    # Root models should ideally not exist, however, sometimes we are stuck with them.
                    # If that is the case we need sufficient information on the type in order to generate a model.
                    # E.g. we need the type of the container.
                    elif field == '__root__' and hasattr(model, '__name__'):
                        model_dict[field] = generate_sample_value_by_type(
                            outer_type, str_utils.classname_to_alias(model.__name__, AliasMode.FIELD)
                        )
                    else:
                        model_dict[field] = generate_sample_value_by_type(outer_type, field)
    else:
        # Not an OSCAL model: only a list/dict of plain values is supported here.
        # Note: this assumes list constrains in oscal are always 1 as a minimum size. if two this may still fail.
        if model_type is list:
            return [generate_sample_value_by_type(model, '')]  # type: ignore
        if model_type is dict:
            return {const.REPLACE_ME: generate_sample_value_by_type(model, '')}  # type: ignore
        raise err.TrestleError('Unhandled collection type.')
    if model_type is list:
        return [model(**model_dict)]  # type: ignore
    if model_type is dict:
        return {const.REPLACE_ME: model(**model_dict)}  # type: ignore
    return model(**model_dict)  # type: ignore