In [1]:
import os
from collections import Counter, OrderedDict

from croniter import (CroniterBadCronError, CroniterBadDateError,
                      CroniterNotAlphaError, croniter)
from marshmallow import (Schema, ValidationError, fields, post_dump,
                         validates_schema)

from uuid import uuid4

In [2]:
def validate_cron(expression):
    """Validate cron format"""
    try:
        _ = croniter(expression)
    except (CroniterNotAlphaError, CroniterBadDateError, CroniterBadCronError):
        raise ValidationError("Invalid cron expression.")

### Why just these 3???
def validate_variable_definition(text):
    """Validate variable definition field."""
    if text not in ['String', 'Price', 'Categorical']:
        message = "Variable definition must be " \
                  "'String', 'Price' or 'Categorical'."
        raise ValidationError(message)

In [3]:
# pylint: disable=no-init
class NoNullDump(object):  # pylint: disable= too-few-public-methods
    """ Strip null values before dumping """
    # pylint: disable=no-self-use
    @post_dump
    def strip_null(self, data):
        """ Strip null values before dumping """
        return {k: v for (k, v) in data.items() if v is not None}
    

class ColumnSchema(Schema, NoNullDump):
    """ Basic column schema """
    name = fields.Str(required=True)
    friendly_name = fields.Str()
    variable_definition = fields.Str(validate=validate_variable_definition)
    resolution_alias = fields.Str()
    hidden = fields.Bool(default=False)
    use_as_label = fields.Bool(default=False)
    
    
class IndexSchema(Schema, NoNullDump):
    """ Basic index column schema, similar to column """
    name = fields.Str(required=True)
    friendly_name = fields.Str()
    variable_definition = fields.Str(validate=validate_variable_definition)
    resolution_alias = fields.Str()


# FIXME: Why is it called a NodeList???
class NodeListSchema(Schema, NoNullDump):
    """ NodeList list schema """
    name = fields.Str(required=True)
    table_name = fields.Str(required=True)
    index_column = fields.Nested(IndexSchema, required=True)
    metadata_columns = fields.Nested(ColumnSchema, many=True, required=False)
    labels = fields.List(fields.Str(), required=False)


class EdgeListSchema(Schema, NoNullDump):
    """ Edge list schema """
    name = fields.Str(required=True)
    table_name = fields.Str(required=True)
    merge_same = fields.Bool(default=True)
    directed = fields.Bool(default=True)
    # FIXME: deprecate index column? IDK...
    index_column = fields.Nested(IndexSchema, required=False)
    source_column = fields.Nested(IndexSchema, required=True)
    target_column = fields.Nested(IndexSchema, required=True)
    weight_column = fields.Nested(ColumnSchema, required=False)
    metadata_columns = fields.Nested(ColumnSchema, many=True, required=False)
    source_metadata_columns = fields.Nested(ColumnSchema, many=True, required=False)
    target_metadata_columns = fields.Nested(ColumnSchema, many=True, required=False)
    labels = fields.List(fields.Str())
    source_labels = fields.List(fields.Str())
    target_labels = fields.List(fields.Str())
    edge_category = fields.Str()


class CommunityListSchema(Schema, NoNullDump):
    """ Community list schema """
    name = fields.Str(required=True)
    table_name = fields.Str(required=True)
    index_column = fields.Nested(IndexSchema, required=True)
    metadata_columns = fields.Nested(ColumnSchema, many=True, required=False)
    labels = fields.List(fields.Str())
    grouping_columns = fields.Nested(ColumnSchema, many=True, required=True)
    edge_labels = fields.List(fields.Str())
    edge_metadata_columns = fields.Nested(ColumnSchema, many=True, required=False)
    edge_category = fields.Str()

    @validates_schema
    def check_grouping_columns(self, data):  # pylint: disable=no-self-use
        """checks at least one grouping column"""
        if len(data["grouping_columns"]) < 1:
            raise ValidationError("Require 1 or more grouping column(s)")


# pylint: disable=no-self-use
class GraphSchema(Schema, NoNullDump):
    """ Graph specification schema """
    name = fields.Str(required=True)
    data_uri = fields.Str(required=False)
    data_uri_env = fields.Str(required=False)
    graph_uri = fields.Str(required=False)
    graph_uri_env = fields.Str(required=False)
    poll = fields.Str(required=True, validate=validate_cron)
    node_lists = fields.Nested(NodeListSchema, many=True, required=False)
    edge_lists = fields.Nested(EdgeListSchema, many=True, required=False)
    community_lists = fields.Nested(CommunityListSchema, many=True, required=False)

    @validates_schema
    def check_uri_and_env(self, data):  # pylint: disable=no-self-use
        """ Additional validation """
        # ensure that not both the URI and environment variables are specified
        # and at least one of them is specified
        for uri, uri_env in [('data_uri', 'data_uri_env'),
                             ('graph_uri', 'graph_uri_env')]:

            if uri in data and uri_env in data:
                raise ValidationError(
                    'Graph specification cannot contain both {uri} and '
                    '{uri_env}. Only one of them should be specified.'
                    .format(uri=uri, uri_env=uri_env))
            elif uri not in data and uri_env not in data:
                raise ValidationError(
                    'Graph specification should contain at least {uri} or '
                    '{uri_env}.'
                    .format(uri=uri, uri_env=uri_env))

In [4]:
def from_dict_generic(data, schema, clsobj):
    """Generic implementation taking schema-conforming data and instatiating class."""
    if data:
        validated_data, errors = schema().load(data)
        if errors:
            raise ValueError(errors)
        return clsobj(**validated_data)
    else:
        return None

def to_dict_generic(schema, clsobj):
    """Reverse of 'from_dict' with additional auto-generated fields."""
    data, errors = schema().dump(clsobj)
    if errors:
        raise ValueError(errors)
    else:
        return data

In [5]:
class ColumnSpec(object):
    """Column specification model"""
    # pylint: disable=too-many-arguments
    def __init__(self, **kwargs):
        self.name = kwargs["name"]
        self.friendly_name = kwargs.get("friendly_name")
        self.variable_definition = kwargs.get("variable_definition")
        self.resolution_alias = kwargs.get("resolution_alias")
        self.hidden = kwargs.get("hidden") or False
        self.use_as_label = kwargs.get("use_as_label") or False
        
    @classmethod
    def from_dict(cls, data):
        return from_dict_generic(data, ColumnSchema, cls)
    
    def to_dict(self):
        return to_dict_generic(ColumnSchema, self)

In [6]:
class BaseListSpec(object):
    """Base list specification"""
    # pylint: disable=too-many-arguments
    def __init__(self, **kwargs):
        self.name = kwargs["name"]
        self.table_name = kwargs["table_name"]
        self.index_column = ColumnSpec.from_dict(kwargs.get("index_column"))
        self.metadata_columns = [
            ColumnSpec.from_dict(col) for col in kwargs.get("metadata_columns") or []
        ]
        self.labels = kwargs.get("labels") or []

In [7]:
class NodeListSpec(BaseListSpec):
    """Node list specification"""
    # pylint: disable=too-many-arguments
    def __init__(self, **kwargs):
        super(NodeListSpec, self).__init__(**kwargs)
        
    @classmethod
    def from_dict(cls, data):
        return from_dict_generic(data, NodeListSchema, cls)
    
    def to_dict(self):
        return to_dict_generic(NodeListSchema, self)

In [8]:
class EdgeListSpec(BaseListSpec):
    """Edge list specification"""
    # pylint: disable=too-many-arguments
    def __init__(self, **kwargs):
        super(EdgeListSpec, self).__init__(**kwargs)
        # merge_same
        self.directed = kwargs.get("directed") or True
        self.source_column = ColumnSpec.from_dict(kwargs["source_column"])
        self.target_column = ColumnSpec.from_dict(kwargs["target_column"])
        # weight_column
        self.source_metadata_columns = [
            ColumnSpec.from_dict(col) for col in kwargs.get("source_metadata_columns") or []
        ]
        self.target_metadata_columns = [
            ColumnSpec.from_dict(col) for col in kwargs.get("target_metadata_columns") or []
        ]
        self.source_labels = kwargs.get("source_labels") or []
        self.target_labels = kwargs.get("target_labels") or []
        self.edge_category = kwargs.get("edge_category") or str(uuid4())
        
    @classmethod
    def from_dict(cls, data):
        return from_dict_generic(data, EdgeListSchema, cls)
    
    def to_dict(self):
        return to_dict_generic(EdgeListSchema, self)

In [9]:
class CommunityListSpec(BaseListSpec):
    """Community list model - list of nodes to the community group they belong to"""
    # pylint: disable=too-many-arguments
    def __init__(self, **kwargs):
        super(CommunityListSpec, self).__init__(**kwargs)
        self.grouping_columns = kwargs["grouping_columns"]
        self.edge_labels = kwargs.get("edge_labels") or []
        self.edge_metadata_columns = [
            ColumnSpec.from_dict(col) for col in kwargs.get("edge_metadata_columns") or []
        ]
        self.edge_category = kwargs.get("edge_category") or str(uuid4())
        
    @classmethod
    def from_dict(cls, data):
        return from_dict_generic(data, CommunityListSchema, cls)
    
    def to_dict(self):
        return to_dict_generic(CommunityListSchema, self)

In [10]:
class GraphSpec(object):
    """Graph specification model"""
    # pylint: disable=too-many-instance-attributes
    def __init__(self, **kwargs):
        self.name = kwargs["name"]
        self.data_uri = kwargs.get("data_uri") or os.environ.get(kwargs.get("data_uri_env"))
        self.graph_uri = kwargs.get("graph_uri") or os.environ.get(kwargs.get("graph_uri_env"))
        self.poll = kwargs.get("poll") or "0 0 * * *"
        self.node_lists = [
            NodeListSpec.from_dict(d) for d in kwargs.get("node_lists") or []
        ]
        self.edge_lists = [
            EdgeListSpec.from_dict(d) for d in kwargs.get("edge_lists") or []
        ]
        self.community_lists = [
            CommunityListSpec.from_dict(d) for d in kwargs.get("community_lists") or []
        ]
        
    @classmethod
    def from_dict(cls, data):
        return from_dict_generic(data, GraphSchema, cls)
    
    def to_dict(self):
        return to_dict_generic(GraphSchema, self)

In [11]:
import json
data = json.load(open(os.path.join(os.getcwd(), "data", "test_graph_spec.json"), "r"))

In [13]:
GraphSpec.from_dict(data).to_dict()

{'data_uri': 'data_uri_value',
 'community_lists': [{'edge_labels': ['edge_group_type'],
   'labels': ['toffee_groups'],
   'metadata_columns': [],
   'table_name': 'test_data_toffee_groups_list',
   'grouping_columns': [{'hidden': False,
     'use_as_label': False,
     'friendly_name': 'common_group',
     'name': 'group'}],
   'edge_category': '4f348cf8-817d-40b9-a527-a9447869be41',
   'edge_metadata_columns': [],
   'index_column': {'resolution_alias': 'sweets',
    'variable_definition': 'String',
    'name': 'id'},
   'name': 'toffee group nodes'}],
 'poll': '0 0 * * *',
 'node_lists': [{'labels': ['chocolate'],
   'metadata_columns': [],
   'table_name': 'test_data_chocolate_node_list',
   'index_column': {'resolution_alias': 'chocolate',
    'variable_definition': 'String',
    'name': 'id'},
   'name': 'chocolate nodes'},
  {'labels': ['sweets'],
   'metadata_columns': [{'hidden': False,
     'use_as_label': False,
     'friendly_name': 'sweetness number',
     'variable_defin

In [189]:
CommunityListSpec.from_dict(data["community_lists"][0]).to_dict()

{'index_column': {'name': 'id',
  'variable_definition': 'String',
  'resolution_alias': 'sweets'},
 'grouping_columns': [{'use_as_label': False,
   'hidden': False,
   'name': 'group',
   'friendly_name': 'common_group'}],
 'labels': ['toffee_groups'],
 'table_name': 'test_data_toffee_groups_list',
 'name': 'toffee group nodes',
 'metadata_columns': []}

In [195]:
NodeListSpec.from_dict(data["node_lists"][1]).to_dict()

{'index_column': {'name': 'id',
  'variable_definition': 'String',
  'resolution_alias': 'sweets'},
 'labels': ['sweets'],
 'table_name': 'test_data_sweets_node_list',
 'name': 'sweets nodes',
 'metadata_columns': [{'variable_definition': 'String',
   'use_as_label': False,
   'hidden': False,
   'name': 'prop',
   'friendly_name': 'sweetness number'}]}

In [177]:
data["edge_lists"][0]

{'source_column': {'variable_definition': 'String', 'name': 'chocolate_s'},
 'name': 'chocolate_relations',
 'labels': ['chocolate'],
 'table_name': 'test_data_chocolate_edge_list',
 'target_column': {'variable_definition': 'String', 'name': 'chocolate_t'}}

In [182]:
EdgeListSpec.from_dict(data["edge_lists"][4])

IndexError: list index out of range