In [32]:
import abc
from abc import ABC, abstractmethod
from collections import Counter
from enum import Enum
from typing import Dict, Optional, Tuple

# IMPLEMENTATION: Core Neural Type system classes 

In [33]:
class NeuralTypeComparisonResult(Enum):
    """The result of comparing two neural type objects for compatibility.

    Each member describes A relative to B for the call ``A.compare(B)``.
    """

    SAME = 0
    LESS = 1  # A's type is a subclass (a more specific version) of B's type
    GREATER = 2  # B's type is a subclass (a more specific version) of A's type
    DIM_INCOMPATIBLE = 3  # Same axis kinds but different sizes; a resize connector might fix incompatibility
    TRANSPOSE_SAME = 4  # A transpose and/or converting between lists and tensors will make them same
    CONTAINER_SIZE_MISMATCH = 5  # A and B contain different number of elements (NOTE(review): not returned by any compare() shown here)
    INCOMPATIBLE = 6  # A and B are incompatible
    SAME_TYPE_INCOMPATIBLE_PARAMS = 7  # A and B are of the same type but parametrized differently

In [34]:
class AxisKindAbstract(Enum):
    """Abstract base Enum describing what a varying axis dimension means.

    In practice you will almost always use the AxisKind Enum. If AxisKind does not
    fit your needs, subclass this Enum with your OWN members and use that Enum
    instead of AxisKind. (Python only allows subclassing an Enum that has no
    members, which is why this class defines none.)
    """
    pass


class AxisKind(AxisKindAbstract):
    """Standard axis kinds: what a varying axis dimension means.

    For example, whether a dimension corresponds to batch, time, width, etc.
    """
    Batch = 0
    Time = 1
    Dimension = 2
    Width = 3
    Height = 4

    def __str__(self):
        # Lower-cased member name, e.g. "batch"; from_str accepts this form back.
        return str(self.name).lower()

    @staticmethod
    def from_str(label):
        """Return the AxisKind matching a short or full string label.

        Matching is case-insensitive and ignores surrounding whitespace.
        Accepted labels:
            Batch:     "b", "n", "batch"
            Time:      "t", "time"
            Dimension: "d", "c", "channel", "dimension"
            Width:     "w", "width"
            Height:    "h", "height"

        Raises:
            ValueError: if ``label`` matches none of the labels above.
        """
        _label = label.lower().strip()
        # Built per call on purpose: a dict assigned in an Enum class body would
        # itself become an Enum member.
        lookup = {
            "b": AxisKind.Batch,
            "n": AxisKind.Batch,
            "batch": AxisKind.Batch,
            "t": AxisKind.Time,
            "time": AxisKind.Time,
            "d": AxisKind.Dimension,
            "c": AxisKind.Dimension,
            "channel": AxisKind.Dimension,
            # Full member name, so that from_str(str(AxisKind.Dimension)) round-trips.
            "dimension": AxisKind.Dimension,
            "w": AxisKind.Width,
            "width": AxisKind.Width,
            "h": AxisKind.Height,
            "height": AxisKind.Height,
        }
        kind = lookup.get(_label)
        if kind is None:
            raise ValueError(f"Can't create AxisKind from {label}")
        return kind
            

In [35]:
class AxisType(object):
    """Represents the semantics of one axis and, optionally, its size.

    Args:
        kind (AxisKindAbstract): what the axis represents, e.g. AxisKind.Batch.
        size (int, optional): fixed size of the axis; None (the default) leaves
            the size unspecified.
        is_list (bool, default=False): True if the axis is a Python list rather
            than a tensor dimension. A list axis cannot have a fixed size.

    Raises:
        ValueError: if ``size`` is given together with ``is_list=True``.
    """
    def __init__(self,
                 kind: "AxisKindAbstract",
                 size: Optional[int] = None,
                 is_list: bool = False):
        if size is not None and is_list:
            raise ValueError("The axis can't be list and have a fixed size")
        self.kind = kind
        self.size = size
        self.is_list = is_list

    def __repr__(self):
        return f"AxisType(kind={self.kind!r}, size={self.size!r}, is_list={self.is_list!r})"

In [36]:
class ElementType(ABC):
    """Abstract class defining semantics of the tensor elements.

    We are relying on Python for inheritance checking: compare() uses the class
    hierarchy (issubclass) to relate two element types. Subclasses can further
    parametrize themselves via ``type_parameters`` or declare named ``fields``.
    """

    @abstractmethod
    def __str__(self):
        pass

    @property
    def type_parameters(self) -> Dict:
        """Override this property to parametrize your type.

        Two instances of the same class only compare SAME if these dicts are equal.
        """
        return {}

    @property
    def fields(self) -> Optional[Tuple]:
        """Override this property (keep it a @property) to declare named fields.

        compare() reads ``self.fields`` without calling it, so overriding it with a
        plain method makes two instances of the same class compare INCOMPATIBLE.
        """
        return None

    def compare(self, second) -> NeuralTypeComparisonResult:
        """Compare this element type (A) with ``second`` (B).

        Returns:
            LESS if A's class is a strict subclass of B's, GREATER if B's class is a
            strict subclass of A's, INCOMPATIBLE for unrelated classes. For the same
            class: SAME_TYPE_INCOMPATIBLE_PARAMS if ``type_parameters`` differ,
            INCOMPATIBLE if ``fields`` differ, otherwise SAME.
        """
        first_t = type(self)
        second_t = type(second)

        # First, check general compatibility via the class hierarchy.
        if first_t != second_t:
            if issubclass(first_t, second_t):
                return NeuralTypeComparisonResult.LESS
            if issubclass(second_t, first_t):
                return NeuralTypeComparisonResult.GREATER
            return NeuralTypeComparisonResult.INCOMPATIBLE

        # Same class: all parameters must match (same keys and equal values).
        if self.type_parameters != second.type_parameters:
            return NeuralTypeComparisonResult.SAME_TYPE_INCOMPATIBLE_PARAMS

        # Same class and parameters: all fields must match.
        if self.fields == second.fields:
            return NeuralTypeComparisonResult.SAME
        return NeuralTypeComparisonResult.INCOMPATIBLE

class VoidType(ElementType):
    """Void-like type which is compatible with everything.

    Only this class's compare() short-circuits: when a VoidType is the *second*
    operand, the first operand's compare() decides the result.
    """
    def __str__(self):
        return "void type. compatible with everything"

    def compare(self, second: ElementType) -> NeuralTypeComparisonResult:
        # Always SAME, regardless of what `second` is.
        return NeuralTypeComparisonResult.SAME

In [37]:
class NeuralType(object):
    """This is the main class which would represent neural type concept.

    nmTensors derives from this. It is used to represent *the types* of inputs and outputs.

    Args:
        elements_type (ElementType): semantics of the tensor elements.
        axes (iterable): either all strings (parsed with AxisKind.from_str into
            unsized, non-list axes) or all AxisType instances. List axes must come
            before tensor axes.
        optional (bool, default=False): stored as-is on the instance.

    Raises:
        ValueError: if the axes mix strings and AxisType instances, contain any
            other kind of object, or place a list axis after a tensor axis.
    """
    def __init__(self, elements_type: ElementType, axes: Tuple, optional=False):
        # Materialize once so generators work and validation sees the same axes
        # that are converted below.
        axes = tuple(axes)
        self.__check_sanity(axes)
        self.elements_type = elements_type
        # A string is shorthand for an unsized, non-list axis of that kind.
        self.axes_tuple = tuple(
            AxisType(AxisKind.from_str(axis), None) if isinstance(axis, str) else axis
            for axis in axes
        )
        self.optional = optional

    def compare(self, second) -> NeuralTypeComparisonResult:
        """Compare this neural type (A) with ``second`` (B).

        Axes are checked before element types:
          * different number of axes, or different multisets of axis kinds
            -> INCOMPATIBLE
          * same kinds but different (kind, size) multisets -> DIM_INCOMPATIBLE
          * axes equal position by position (kind, size, is_list)
            -> result of comparing the element types
          * otherwise (axes permuted and/or list<->tensor) -> TRANSPOSE_SAME if the
            element types compare SAME, else INCOMPATIBLE
        """
        axes_a = self.axes_tuple
        axes_b = second.axes_tuple

        # Different rank: no transpose or resize can reconcile the two.
        if len(axes_a) != len(axes_b):
            return NeuralTypeComparisonResult.INCOMPATIBLE

        # Multisets, not dicts: repeated kinds (e.g. several Dimension axes with
        # different sizes) must each be accounted for, not collapsed into one.
        if Counter(axis.kind for axis in axes_a) != Counter(axis.kind for axis in axes_b):
            return NeuralTypeComparisonResult.INCOMPATIBLE
        sizes_a = Counter((axis.kind, axis.size) for axis in axes_a)
        sizes_b = Counter((axis.kind, axis.size) for axis in axes_b)
        if sizes_a != sizes_b:
            return NeuralTypeComparisonResult.DIM_INCOMPATIBLE

        # Exact positional match means no transpose/list conversion is needed.
        dimensions_pass = all(
            axis_a.kind == axis_b.kind
            and axis_a.size == axis_b.size
            and axis_a.is_list == axis_b.is_list
            for axis_a, axis_b in zip(axes_a, axes_b)
        )

        element_comparison_result = self.elements_type.compare(second.elements_type)
        if dimensions_pass:
            return element_comparison_result
        if element_comparison_result == NeuralTypeComparisonResult.SAME:
            return NeuralTypeComparisonResult.TRANSPOSE_SAME
        return NeuralTypeComparisonResult.INCOMPATIBLE

    def __check_sanity(self, axes):
        """Validate ``axes`` before they are converted.

        Raises:
            ValueError: on mixed str/AxisType axes, on axes of any other type, or
                when a list axis follows a tensor (non-list) axis.
        """
        # Either every axis is a string shorthand ...
        if all(isinstance(axis, str) for axis in axes):
            return
        # ... or every axis is a full AxisType. Checking all positions makes the
        # result independent of the order in which strings and AxisTypes are mixed.
        if not all(isinstance(axis, AxisType) for axis in axes):
            if all(isinstance(axis, (str, AxisType)) for axis in axes):
                raise ValueError("Either use full class names or all strings")
            raise ValueError("axis type must be either str or AxisType instance")
        # Check that list axes come before any tensor dimension.
        saw_tensor_dim = False
        for axis in axes:
            if not axis.is_list:
                saw_tensor_dim = True
            elif saw_tensor_dim:
                raise ValueError("You have list dimension after Tensor dimension. All list dimensions must preceed Tensor dimensions")

# Usage examples

#### Example 1. Define Jasper Encoder's outputs

In [38]:
class ChannelType(ElementType):
    """Element type: a value of a convolutional channel."""

    def __str__(self):
        text = "convolutional channel value"
        return text


class AcousticEncodedRepresentation(ChannelType):
    """ChannelType specialization describing the acoustic encoder's output."""

    def __str__(self):
        text = "encoded representation returned by the acoustic encoder model"
        return text

In [39]:
# Fully explicit form: every axis is an AxisType with an unspecified (None) size.
T1 = NeuralType(
    elements_type=AcousticEncodedRepresentation(),
    axes=(
        AxisType(kind=AxisKind.Batch, size=None),
        AxisType(kind=AxisKind.Dimension, size=None),
        AxisType(kind=AxisKind.Time, size=None),
    ),
)

In [40]:
# Shorthand form: string labels are parsed by AxisKind.from_str into unsized axes.
T2 = NeuralType(elements_type=AcousticEncodedRepresentation(), axes=("B", "D", "T"))

In [41]:
# The explicit and shorthand declarations describe the same type.
T1.compare(second=T2)

<NeuralTypeComparisonResult.SAME: 0>

#### Example 2. Define AudioData layer output - support for frequency

In [42]:
# PARAMETRIZED TYPE Example
class AudioSignal(ElementType):
    """Element type for an audio signal, parametrized by ``freq``.

    ``freq`` is exposed through ``type_parameters``, so two AudioSignal instances
    with different ``freq`` values compare SAME_TYPE_INCOMPATIBLE_PARAMS.

    Args:
        freq (int, default=16000): frequency parameter of the signal
            (presumably the sampling rate in Hz — TODO confirm).
    """
    def __init__(self, freq=16000):
        self._params = {'freq': freq}

    def __str__(self):
        return f"audio signal (freq={self._params['freq']})"

    @property
    def type_parameters(self):
        # Return a copy so callers cannot mutate this instance's parameters.
        return dict(self._params)


# Identical axes; only the freq parameter of the element type differs.
T1 = NeuralType(elements_type=AudioSignal(freq=16000), axes=("B", "T"))
T2 = NeuralType(elements_type=AudioSignal(freq=8000), axes=("B", "T"))

In [43]:
# Same element class, different parameters.
T1.compare(second=T2)

<NeuralTypeComparisonResult.SAME_TYPE_INCOMPATIBLE_PARAMS: 7>

#### Example 3. Transpose Same Example

In [44]:
# Same element type and axes as T2, but in (time, batch) order.
T3 = NeuralType(elements_type=AudioSignal(freq=8000), axes=("T", "B"))
T3.compare(second=T2)

<NeuralTypeComparisonResult.TRANSPOSE_SAME: 4>

#### Example 4. Input to the data augmentation module (SpecAugment)

In [45]:
class SpectogramType(ChannelType):
    """Generic spectrogram element type; parent of the specific spectrogram types below."""
    def __str__(self):
        return "generic spectrogram type"


class MelSpectogramType(SpectogramType):
    """Mel spectrogram; as a SpectogramType subclass it compares LESS than SpectogramType."""
    def __str__(self):
        return "mel spectrogram type"


class MFCCSpectogramType(SpectogramType):
    """MFCC spectrogram; a SpectogramType subclass unrelated to MelSpectogramType."""
    def __str__(self):
        return "mfcc spectrogram type"

In [46]:
# SpecAugment's input uses the generic SpectogramType; the two data layers output subtypes of it.
SpecAugmentInput = NeuralType(elements_type=SpectogramType(), axes=("B", "D", "T"))
DL1_out = NeuralType(elements_type=MelSpectogramType(), axes=("B", "D", "T"))
DL2_out = NeuralType(elements_type=MFCCSpectogramType(), axes=("B", "D", "T"))

In [47]:
# Sibling subtypes (neither subclasses the other).
DL1_out.compare(second=DL2_out)

<NeuralTypeComparisonResult.INCOMPATIBLE: 6>

In [48]:
# The generic input type is a superclass of the mel output's element type.
SpecAugmentInput.compare(second=DL1_out)

<NeuralTypeComparisonResult.GREATER: 2>

In [49]:
# The generic input type is a superclass of the MFCC output's element type.
SpecAugmentInput.compare(second=DL2_out)

<NeuralTypeComparisonResult.GREATER: 2>

#### Example 5. List of lists of 3-dimensional tensors

In [50]:
# Two list axes (batch, time) in front of a 32 x 128 x 256 tensor.
T1 = NeuralType(
    elements_type=ChannelType(),
    axes=(
        AxisType(kind=AxisKind.Batch, size=None, is_list=True),
        AxisType(kind=AxisKind.Time, size=None, is_list=True),
    ) + tuple(
        AxisType(kind=AxisKind.Dimension, size=dim_size, is_list=False)
        for dim_size in (32, 128, 256)
    ),
)

In [51]:
# Same kinds and sizes as T1, but batch and time are tensor dimensions, not lists.
T2 = NeuralType(
    elements_type=ChannelType(),
    axes=tuple(
        AxisType(kind=axis_kind, size=axis_size, is_list=False)
        for axis_kind, axis_size in (
            (AxisKind.Batch, None),
            (AxisKind.Time, None),
            (AxisKind.Dimension, 32),
            (AxisKind.Dimension, 128),
            (AxisKind.Dimension, 256),
        )
    ),
)

In [52]:
# Converting between list and tensor axes is currently reported as TRANSPOSE_SAME.
# TODO: decide whether list<->tensor conversion deserves its own comparison result.
T2.compare(second=T1)

<NeuralTypeComparisonResult.TRANSPOSE_SAME: 4>

#### Example 6. Structs

In [53]:
class BoundingBox(ElementType):
    """Struct-like element type with the named fields X, Y, W, H."""
    def __str__(self):
        return "bounding box from detection model"

    @property
    def fields(self):
        # Must be a property, like ElementType.fields: compare() reads self.fields
        # without calling it. As a plain method, two BoundingBox instances would
        # expose distinct bound methods, which are unequal -> INCOMPATIBLE.
        return ("X", "Y", "W", "H")

# User-defined axis kinds: subclass AxisKindAbstract with your own members.
class AxisKind2(AxisKindAbstract):
    """Example user-defined axis-kind Enum with a single Image member."""

    Image = 0

In [54]:
# A list (batch) of lists (images) of bounding boxes.
T1 = NeuralType(
    elements_type=BoundingBox(),
    axes=tuple(
        AxisType(kind=axis_kind, size=None, is_list=True)
        for axis_kind in (AxisKind.Batch, AxisKind2.Image)
    ),
)

In [55]:
class BadBoundingBox(ElementType):
    """Struct-like element type missing the W field: fields are X, Y, H."""
    def __str__(self):
        return "bad bounding box from detection model"

    @property
    def fields(self):
        # Must be a property, like ElementType.fields: compare() reads self.fields
        # without calling it.
        return ("X", "Y", "H")

In [56]:
# Same axes as T1, but with the BadBoundingBox element type defined just above.
# BadBoundingBox and BoundingBox are unrelated classes, so the element types are
# incompatible.
T2 = NeuralType(
    elements_type=BadBoundingBox(),
    axes=(
        AxisType(kind=AxisKind.Batch, size=None, is_list=True),
        AxisType(kind=AxisKind2.Image, size=None, is_list=True),
    ),
)

In [57]:
# Identical axes, so the result is decided by the element types alone.
T2.compare(second=T1)

<NeuralTypeComparisonResult.INCOMPATIBLE: 6>