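# TABLES

Defines the `Table` base class and the `PairTable` / `ValueTable` parameter containers, writes each one to `../typyPRISM/core/` (and its unit tests to `../typyPRISM/test/`) via `%%run_and_write`, and finally runs those tests.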

In [2]:
# Profiling / compilation extensions.
# NOTE(review): none of their magics (%memit, %snakeviz, %%cython) are used
# by any cell in this notebook — consider dropping these three lines.
%load_ext memory_profiler
%load_ext snakeviz
%load_ext cython
# Reload edited modules automatically before each cell executes, so changes
# to the written typyPRISM files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2


from IPython.core import debugger
# NOTE(review): leftover debugging shortcut (calling ist() enters the
# debugger); it is not called by any cell in this notebook — consider removing.
ist = debugger.set_trace

In [3]:
# NOTE(review): star import — presumably it exists only to bring in the
# typyMagics class (which likely provides the %%run_and_write cell magic used
# below). An explicit `from py.typyMagics import typyMagics` would make the
# dependency visible.
from py.typyMagics import *
ipy = get_ipython()
# Register the custom magics with the currently running IPython shell
ipy.register_magics(typyMagics)

In [4]:
import sys
# Prepend the parent directory to the import path so the typyPRISM package
# written by the cells below (../typyPRISM/...) is importable.
sys.path[0:0] = ['../']

## DEFINE

In [5]:
%%run_and_write ../typyPRISM/core/Table.py
class Table:
    '''Baseclass used to define tables of parameters
    
    This class should not be used/instantiated directly. It is intended
    only to be inherited from by concrete tables (e.g. PairTable and
    ValueTable), which share the :meth:`listify` helper defined here.
    
    '''
    def listify(self,values):
        '''Helper function that converts any input into a list of inputs.
        
        The purpose of this function is to help with iterating over types,
        and handling the case of a single "str" type being passed. A str is
        itself iterable, so without the special case a type named 'AB' would
        be split into 'A' and 'B'.
        
        Parameters
        ----------
        values: object
            A single type (a str or any non-iterable object) or an iterable
            of types.
        
        Returns
        -------
        list
            ``[values]`` for a str or non-iterable input, otherwise
            ``list(values)``.
        '''
        if isinstance(values,str):
            return [values]
        try:
            iter(values)
        except TypeError:
            # not iterable: treat the input as one single type
            return [values]
        return list(values)
        

Overwriting ../typyPRISM/core/Table.py


In [17]:
%%run_and_write ../typyPRISM/core/PairTable.py
from typyPRISM.core.Table import Table
from typyPRISM.core.MatrixArray import MatrixArray
from typyPRISM.core.Space import Space
from itertools import product
import numpy as np

class PairTable(Table):
    '''Container for data that is keyed by pairs of types
    
    Since PRISM is a theory based in *pair*-correlation functions, it 
    follows that many of the necessary parameters of the theory are specified
    between the pairs of types. The goal of this container is to make setting,
    getting, and checking these data easy.
    
    To start, setter/getter methods have been set up to set groups of types 
    simultaneously. For example::
        
            PT = PairTable(['A','B','C','D'],'density',symmetric=True)
            
            # The following sets the 'A-C' and 'B-C' pairs to be 0.5. Also, since
            # we set symmetric=True above, 'C-A' and 'C-B' are also set
            PT[['A','B'],['C']] = 0.5
    
    This allows for the rapid construction of datasets where many of the parameters
    are repeated. 
    
    Parameters
    ----------
    types: list
        Lists of the types that will be used to key the PairTable. The length of this
        list should be equal to the rank of the PRISM problem to be solved i.e. 
        len(types) == number of sites in system
        
    name: string
        The name of the PairTable. This is simply used as a convenience for identifying
        the table internally. 
    
    symmetric: bool
        If True, the table will automatically set both off-diagonal values during
        assignment e.g. PT['A','B'] = 5 will set 'A-B' and 'B-A'. In that case only
        the upper triangle plus diagonal holds independent data. If False, only the
        explicitly assigned pair is set, and check/setUnset/apply operate on every
        (i,j) pair.
    
    '''
    def __init__(self,types,name,symmetric=True):
        self.types = types
        self.symmetric = symmetric
        self.name = name
        # Nested dict: values[t1][t2] holds the data for the t1-t2 pair.
        # None marks a pair that has not been assigned yet.
        self.values = {t1:{t2:None for t2 in types} for t1 in types}
    
    def __repr__(self):
        return '<PairTable: {}>'.format(self.name)
        
    def __iter__(self):
        '''Yield ((i,j),(t1,t2),value) for every ordered pair of types.'''
        for (i,t1),(j,t2) in product(enumerate(self.types),enumerate(self.types)):
            yield (i,j),(t1,t2),self.values[t1][t2]
            
    def __getitem__(self,index):
        '''Return the value stored for a single (t1,t2) pair.'''
        t1,t2 = index
        return self.values[t1][t2]
    
    def __setitem__(self,index,value):
        '''Assign value to every pair in types1 x types2.
        
        Each side of the index may be a single type or an iterable of types.
        When the table is symmetric the mirrored (t2,t1) pair is set too.
        '''
        types1,types2 = index
        # listify once, outside the loops: an iterator passed as types2 would
        # otherwise be exhausted after the first t1
        types1 = self.listify(types1)
        types2 = self.listify(types2)
        for t1 in types1:
            for t2 in types2:
                self.values[t1][t2] = value
                if self.symmetric and t1!=t2:
                    self.values[t2][t1] = value
    
    def _storedPairs(self):
        '''Iterate over the pairs that hold independent data.
        
        For a symmetric table the lower triangle mirrors the upper one, so only
        the upper triangle plus diagonal is visited. For a non-symmetric table
        every (i,j) pair is independent, so all of them are visited.
        '''
        return self.iterpairs(full=not self.symmetric)
            
    def check(self):
        '''Is everything in the table set?
        
        Raises
        ------
        ValueError
            If any independently stored pair is still None.
        '''
        for i,t,val in self._storedPairs():
            if val is None:
                raise ValueError('PairTable {} is not fully specified!'.format(self.name))
            
            
    def iterpairs(self,full=False,diagonal=True):
        '''Convenience function for looping over table pairs.
        
        Parameters
        ----------
        full: bool
            If True, all i,j pairs (upper and lower diagonal) will be looped over
            
        diagonal: bool
            If True, the i==j (on-diagonal) pairs will be considered when looping.
            Ignored when full is True.
        
        Yields
        ------
        ((i,j),(t1,t2),value) for each selected pair
        '''
        
        if full:
            test = lambda i,j: True
        elif diagonal:
            test = lambda i,j: i<=j
        else:
            test = lambda i,j: i<j
            
        for (i,j),(t1,t2),val in self.__iter__():
            if test(i,j):
                yield (i,j),(t1,t2),(val)
                
    def setUnset(self,value):
        '''Set all values that have not been specified to a value'''
        for i,(t1,t2),v in self._storedPairs():
            if v is None:
                self[t1,t2] = value
                
    def exportToMatrixArray(self,space=Space.Real):
        '''Convenience function for converting a table of arrays to a MatrixArray
        
        Parameters
        ----------
        space: Space
            Space passed on to the newly created MatrixArray
        
        Returns
        -------
        MatrixArray
            rank == len(self.types), length == common length of the stored arrays
        
        Raises
        ------
        ValueError
            If an upper-triangle/diagonal pair is unset, or the arrays do not
            all have the same length
        '''
        lengths = []
        for i,t,val in self.iterpairs():
            if val is None:
                raise ValueError('Can\'t export not-fully specified Table {}'.format(self.name))
            lengths.append(len(val))
            
        if len(set(lengths)) > 1:
            raise ValueError('Arrays in Table are not all the same length. Aborting export.')
        
        length = lengths[0]
        rank = len(self.types)
        MA = MatrixArray(length=length,rank=rank,space=space)
        
        # NOTE(review): only i<=j pairs are written; this presumably relies on
        # MatrixArray.__setitem__ mirroring [i,j] into [j,i] — confirm. For a
        # non-symmetric table the lower-triangle values are not exported.
        for (i,j),t,val in self.iterpairs():
            MA[i,j] = val
        return MA
    
    def apply(self,funk):
        '''Apply a function to all elements in the table in place
        
        Parameters
        ----------
        funk: any object with __call__ method
            function to be called on all table elements. For a symmetric table
            it is called once per unordered pair and the result is written to
            both mirrored entries.
        
        '''
        for i,(t1,t2),val in self._storedPairs():
            self[t1,t2] = funk(val)
        

        

Overwriting ../typyPRISM/core/PairTable.py


In [12]:
%%run_and_write ../typyPRISM/core/ValueTable.py
from typyPRISM.core.Table import Table
import numpy as np

class ValueTable(Table):
    '''Container for data that is keyed by single types
    
    Every type maps to exactly one value. All values start out as None
    (unset) and may be assigned one type at a time, e.g. VT['A'] = 0.4, or
    to several types at once, e.g. VT[['B','C']] = 0.25.
    
    Parameters
    ----------
    types: list
        Lists of the types that will be used to key the ValueTable. The length of this
        list should be equal to the rank of the PRISM problem to be solved i.e. 
        len(types) == number of sites in system
        
    name: string
        The name of the ValueTable. This is simply used as a convenience for identifying
        the table internally. 
    
    '''
    def __init__(self,types,name):
        self.types = types
        self.name = name
        # every type starts unset (None)
        self.values = dict.fromkeys(types)
    
    def __repr__(self):
        return '<ValueTable: {}>'.format(self.name)
        
    def __iter__(self):
        '''Yield (index, type, value) in the order of self.types.'''
        yield from ((pos,typ,self.values[typ]) for pos,typ in enumerate(self.types))
            
    def __getitem__(self,index):
        '''Return the value stored for a single type.'''
        return self.values[index]
    
    def __setitem__(self,index,value):
        '''Assign value to one type or to each type in an iterable of types.'''
        for typ in self.listify(index):
            self.values[typ] = value
            
    def check(self):
        '''Raise ValueError if any type in the table is still unset.'''
        if any(current is None for _,_,current in self):
            raise ValueError('ValueTable {} is not fully specified!'.format(self.name))
            
    def setUnset(self,value):
        '''Fill every still-unset type with value, leaving set ones untouched.'''
        for _,typ,current in self:
            if current is not None:
                continue
            self[typ] = value


Overwriting ../typyPRISM/core/ValueTable.py


In [18]:
%%run_and_write ../typyPRISM/test/PairTable_TestCase.py
import unittest
from typyPRISM.core.PairTable import PairTable
from typyPRISM.core.MatrixArray import MatrixArray
import numpy as np

class PairTable_TestCase(unittest.TestCase):
    def test_get_set(self):
        '''Can we set and get values from the table?'''
        PT = PairTable(['A','B','C'],'density')
        PT['A','B'] = 0.4
        PT['A',['A','C']] = 0.25
        
        self.assertEqual(PT['A','B'],0.4)
        self.assertEqual(PT['B','A'],0.4)
        self.assertEqual(PT['A','A'],0.25)
        self.assertEqual(PT['A','C'],0.25)
        self.assertEqual(PT['C','A'],0.25)
        
    def test_check(self):
        '''Can we check to make sure the table is filled?'''
        PT = PairTable(['A','B','C'],'density')
        PT['A','B'] = 0.4
        PT['A',['A','C']] = 0.25
        self.assertRaises(ValueError,PT.check)
        
    def test_apply(self):
        '''Can we apply a function to the table?'''
        PT = PairTable(['A','B','C'],'density')
        PT['A','B'] = 0.4
        PT.setUnset(0.25)
        
        PT.apply(np.square)
        
        self.assertEqual(PT['A','B'],0.4*0.4)
        self.assertEqual(PT['B','A'],0.4*0.4)
        self.assertEqual(PT['A','A'],0.25*0.25)
        self.assertEqual(PT['A','C'],0.25*0.25)
        self.assertEqual(PT['C','A'],0.25*0.25)
        self.assertEqual(PT['C','C'],0.25*0.25)
        
    def test_setUnset(self):
        '''Can we set all of the unset table values?'''
        PT = PairTable(['A','B','C'],'density')
        PT['A','B'] = 0.4
        PT.setUnset(0.25)
        
        self.assertEqual(PT['A','B'],0.4)
        self.assertEqual(PT['B','A'],0.4)
        self.assertEqual(PT['A','A'],0.25)
        self.assertEqual(PT['A','C'],0.25)
        self.assertEqual(PT['C','A'],0.25)
        self.assertEqual(PT['C','C'],0.25)
        
    def test_iter_full(self):
        '''Can we iterate over all of the table pairs?'''
        types = ['A','B','C']
        ntypes = len(types)
        PT = PairTable(types,'density')
        
        numericPairs = [
                         (0,0),(0,1),(0,2),
                         (1,0),(1,1),(1,2),
                         (2,0),(2,1),(2,2)
                       ]
        
        counter = 0
        for i,t,v in PT.iterpairs(full=True):
            if i in numericPairs:
                numericPairs.remove(i)
            counter+=1
            
        #did we visit all expected pairs?
        self.assertEqual(len(numericPairs),0)
        
        #sanity check, did we visit the correct number of pairs?
        self.assertEqual(counter,ntypes*ntypes)
        
    def test_iter_diagonal(self):
        '''Can we iterate over the upper triangle + diagonal?'''
        types = ['A','B','C']
        ntypes = len(types)
        PT = PairTable(types,'density')
        
        numericPairs = [
                         (0,0),(0,1),(0,2),
                               (1,1),(1,2),
                                     (2,2)
                       ]
        
        counter = 0
        for i,t,v in PT.iterpairs(diagonal=True):
            if i in numericPairs:
                numericPairs.remove(i)
            counter+=1
            
        #did we visit all expected pairs?
        self.assertEqual(len(numericPairs),0)
        
        #sanity check, did we visit the correct number of pairs?
        self.assertEqual(counter,ntypes*(ntypes+1)//2)
        
    def test_iter_triangle(self):
        '''Can we iterate over only the upper triangle?'''
        types = ['A','B','C']
        ntypes = len(types)
        PT = PairTable(types,'density')
        
        numericPairs = [
                         (0,1),(0,2),
                               (1,2),
                       ]
        
        counter = 0
        for i,t,v in PT.iterpairs(diagonal=False):
            if i in numericPairs:
                numericPairs.remove(i)
            counter+=1
            
        #did we visit all expected pairs?
        self.assertEqual(len(numericPairs),0)
        
        #sanity check, did we visit the correct number of pairs?
        self.assertEqual(counter,ntypes*(ntypes-1)//2)
    
    def test_MatrixArray_export(self):
        '''Can we export a table of equal-length arrays to a MatrixArray?'''
        # Imported locally: this test module does not import Space at module
        # level, so running the written file outside the notebook (where Space
        # only leaks in from the PairTable cell) would otherwise raise NameError.
        from typyPRISM.core.Space import Space
        
        types = ['A','B','C']
        length = 1024
        rank = len(types)
        values1 = np.ones(length)
        values2 = np.ones(length)*5.0
        values3 = np.ones(length)*2.1234
        
        MA1 = MatrixArray(length=length,rank=rank,space=Space.Fourier)
        MA1[0,0] = values2
        MA1[0,1] = values1
        MA1[0,2] = values1
        MA1[1,1] = values3
        MA1[1,2] = values3
        MA1[2,2] = values3
        
        PT = PairTable(types,'density')
        PT[['A'],['B','C']] = values1
        PT[['A'],['A']] = values2
        PT.setUnset(values3)
        MA2 = PT.exportToMatrixArray(space=Space.Fourier)
        
        np.testing.assert_array_almost_equal(MA1.data,MA2.data)
        self.assertEqual(MA1.space,MA2.space)
        
        
        


Overwriting ../typyPRISM/test/PairTable_TestCase.py


In [19]:
%%run_and_write ../typyPRISM/test/ValueTable_TestCase.py
import unittest
from typyPRISM.core.ValueTable import ValueTable
import numpy as np

class ValueTable_TestCase(unittest.TestCase):
    def test_get_set(self):
        '''Can we set and get values from the table?'''
        VT = ValueTable(['A','B','C'],'density')
        VT['A'] = 0.4
        VT[['B','C']] = 0.25
        
        self.assertEqual(VT['A'],0.4)
        self.assertEqual(VT['B'],0.25)
        self.assertEqual(VT['C'],0.25)
        
    def test_check(self):
        '''Can we check to make sure the table is filled?'''
        VT = ValueTable(['A','B','C'],'density')
        VT[['A','B']] = 0.4
        self.assertRaises(ValueError,VT.check)
        
        # once every type has a value, check() must pass silently
        VT['C'] = 0.4
        VT.check()
        
    def test_setUnset(self):
        '''Can we set all of the unset table values?'''
        VT = ValueTable(['A','B','C'],'density')
        VT['A'] = 0.4
        VT.setUnset(0.25)
        self.assertEqual(VT['A'],0.4)
        self.assertEqual(VT['B'],0.25)
        self.assertEqual(VT['C'],0.25)
        
    def test_iter(self):
        '''Can we iterate over the table?'''
        types = ['A','B','C']
        ntypes = len(types)
        VT = ValueTable(types,'density')
        
        alphaTypes = ['A','B','C']
        
        counter = 0
        for i,t,v in VT:
            # the yielded index must be the position of the type in `types`
            self.assertEqual(types[i],t)
            # nothing has been assigned yet, so every value is unset
            self.assertIsNone(v)
            if t in alphaTypes:
                alphaTypes.remove(t)
            counter+=1
            
        #did we visit all expected types?
        self.assertEqual(len(alphaTypes),0)
        
        #sanity check, did we visit the correct number of types?
        self.assertEqual(counter,ntypes)
        
        
        
        
        


Overwriting ../typyPRISM/test/ValueTable_TestCase.py


In [20]:
import unittest
# Gather both TestCases into one suite and run it; the returned result object
# is the cell's displayed output.
suite = unittest.TestSuite(
    unittest.TestLoader().loadTestsFromTestCase(case)
    for case in (PairTable_TestCase, ValueTable_TestCase)
)
unittest.TextTestRunner(verbosity=2).run(suite)

test_MatrixArray_export (__main__.PairTable_TestCase) ... ok
test_apply (__main__.PairTable_TestCase)
Can we apply a function to the table? ... ok
test_check (__main__.PairTable_TestCase)
Can we check to make sure the table is filled? ... ok
test_get_set (__main__.PairTable_TestCase)
Can we set and get values from the table? ... ok
test_iter_diagonal (__main__.PairTable_TestCase)
Can we iterate over the upper triangle + diagonal? ... ok
test_iter_full (__main__.PairTable_TestCase)
Can we iterate over all of the table pairs? ... ok
test_iter_triangle (__main__.PairTable_TestCase)
Can we iterate over only the upper triangle? ... ok
test_setUnset (__main__.PairTable_TestCase)
Can wet set all of the unset table values? ... ok
test_check (__main__.ValueTable_TestCase)
Can we check to make sure the table is filled? ... ok
test_get_set (__main__.ValueTable_TestCase)
Can we set and get values from the table? ... ok
test_iter (__main__.ValueTable_TestCase)
Can we iterate over the table? ... ok


<unittest.runner.TextTestResult run=12 errors=0 failures=0>