Skip to content

Commit

Permalink
datajoint_tools update
Browse files Browse the repository at this point in the history
  • Loading branch information
ChihweiLHBird committed Sep 30, 2020
1 parent 8ca5b46 commit 074ccad
Showing 1 changed file with 22 additions and 4 deletions.
26 changes: 22 additions & 4 deletions pyrfume/datajoint_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def set_dj_definition(cls, type_map: dict = None) -> None:
_type_map = {
"int": "int",
"str": "varchar(256)",
"float": "float",
"float": "float",
"Quantity": "float",
"datetime": "datetime",
"datetime.datetime": "datetime",
"bool": "tinyint"
Expand All @@ -54,24 +55,41 @@ def set_dj_definition(cls, type_map: dict = None) -> None:
default = getattr(cls, attr)
if isinstance(default, dict):
# Assume the class of objects in some_dict.keys() have corresponding tables in the database
# Assume values of the dict are primitive type which is in the _type_map

# For example, components: Dict[ClassA, int] = {a: 1, b: 2}
# key_cls_name would be "ClassA"
# part_cls_name would be "Component",
# note that the "s" at the end of the dict name will be removed.
#

# skip if type_hint doesn't suggest the type of keys and values in the dict.
if type_hint.__name__ == 'dict':
continue

part_cls_name = attr[0].upper() + attr[1:]
part_cls_name = part_cls_name[:-1] if part_cls_name[-1] == 's' else part_cls_name
key_cls_name = type_hint.__args__[0].__forward_arg__


key_type = type_hint.__args__[0]
value_type = type_hint.__args__[1]

from typing import ForwardRef
key_cls_name = key_type.__forward_arg__ if isinstance(key_type, ForwardRef) else key_type.__name__
value_type = value_type.__forward_arg__ if isinstance(value_type, ForwardRef) else value_type.__name__

assert value_type in _type_map
value_type = _type_map[value_type]

part_cls = type(
part_cls_name,
(dj.Part, object),
{
"definition": """
-> %s
-> %s
""" % (cls.__name__, key_cls_name)
---
value = NULL : %s
""" % (cls.__name__, key_cls_name, value_type)
}
)
cls_dict = dict(vars(cls))
Expand Down

0 comments on commit 074ccad

Please sign in to comment.