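Inference of a single count label $c$ and $n$ pairs of coupled labels $(x_i, y_i)$, $i=1,\ldots,n$, subject to $c=x_i+y_i$ for every $i$, leads to

$\max\limits_{c=x_1+y_1\atop {\cdots\atop c=x_n+y_n}} \log P_c(c)+\sum_{i=1}^n \big(\log P_x(x_i)+\log P_y(y_i)\big)$

For a fixed $c$ the pairs decouple, so the problem reduces to $\max_c \big[\log P_c(c)+\sum_{i=1}^n \phi_i(c)\big]$ with $\phi_i(c)=\max_{x=0,\ldots,c}\big(\log P_x(x)+\log P_y(c-x)\big)$.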

In [31]:
import numpy as np

# Build a random problem instance. Every column (window) of each table is a
# probability distribution over 0..n_max_events events; the log-tables are kept.
n_windows = 200
n_max_events = 49
n_coupled_labels = 2


def _random_column_distributions(n_rows, n_cols):
    """Return an (n_rows, n_cols) array of uniform draws scaled so each column sums to 1."""
    draws = np.random.rand(n_rows, n_cols)
    return draws / draws.sum(axis=0)


Pc = _random_column_distributions(n_max_events + 1, n_windows)
log_Pc = np.log(Pc)

log_P = []
for _ in range(n_coupled_labels):
    # Px is drawn before Py; keep this order so a fixed seed reproduces the instance.
    Px = _random_column_distributions(n_max_events + 1, n_windows)
    Py = _random_column_distributions(n_max_events + 1, n_windows)
    log_P.append((np.log(Px), np.log(Py)))

print(log_Pc.shape)
print(log_P[0][0].shape)
print(log_P[0][1].shape)
print(log_P[0][1].shape)


(50, 200)
(50, 200)
(50, 200)


In [32]:
def struct_inference(log_Pc, log_P):
    """Jointly infer the count label and the coupled labels in every window.

    Independently for each window (column ``w``) this solves

        max_c  log_Pc[c, w] + sum_i max_{x=0..c} (log_Px_i[x, w] + log_Py_i[c - x, w])

    For a fixed ``c`` the coupled pairs decouple, so every pair contributes
    ``phi_i(c) = max_x log_Px_i[x] + log_Py_i[c - x]``. This is tabulated for all
    ``c``, and ``c`` is then the argmax of ``log_Pc + sum_i phi_i``.

    Args:
        log_Pc: 2-D array of shape (n_events + 1, n_wins). ``log_Pc[c, w]`` is
            the log-probability of count ``c`` in window ``w``.
        log_P: sequence of pairs ``(log_Px, log_Py)``, one per coupled label
            pair. Each is a 2-D array with ``n_wins`` columns and at least
            ``n_events + 1`` rows (rows beyond ``n_events`` are never read).

    Returns:
        ``(c, lab)``. ``c`` is an int array of length ``n_wins`` holding the
        inferred count per window. ``lab[i]`` holds the inferred ``x`` label
        of pair ``i`` per window; the matching ``y`` label is ``c - lab[i]``.

    Raises:
        ValueError: if a ``log_Px``/``log_Py`` table has fewer than
            ``n_events + 1`` rows. In that case the x/y pairing would silently
            be wrong.

    Ties are broken towards the smallest index (``np.argmax`` semantics).
    """
    n_labels = log_Pc.shape[0]  # count values 0..n_events
    n_wins = log_Pc.shape[1]
    cols = np.arange(n_wins)  # per-window column selector for fancy indexing

    score = np.copy(log_Pc)
    arg_phi = []
    for i, pair in enumerate(log_P):
        log_Px, log_Py = pair[0], pair[1]
        if log_Px.shape[0] < n_labels or log_Py.shape[0] < n_labels:
            raise ValueError(
                f"log_P[{i}] tables need at least {n_labels} rows, got "
                f"{log_Px.shape[0]} and {log_Py.shape[0]}"
            )
        # Row j of the flipped table is log_Py[-1 - j]. Its last c + 1 rows are
        # therefore log_Py[c], ..., log_Py[0], aligned with log_Px[0], ..., log_Px[c].
        log_flip_Py = np.flipud(log_Py)

        phi = np.empty((n_labels, n_wins))
        arg_phi_i = np.empty((n_labels, n_wins), dtype=int)
        for c in range(n_labels):
            # tmp[x, w] = log Px(x) + log Py(c - x) for x = 0..c
            tmp = log_Px[:c + 1, :] + log_flip_Py[-(c + 1):, :]
            arg_phi_i[c, :] = np.argmax(tmp, axis=0)
            phi[c, :] = np.max(tmp, axis=0)
        arg_phi.append(arg_phi_i)
        score += phi

    c = np.argmax(score, axis=0)
    # For every pair, read the optimal x at the count chosen for each window.
    lab = [arg_phi_i[c, cols] for arg_phi_i in arg_phi]
    return c, lab


In [33]:
# log_Pc.argmax(0)
# log_P[0][0].argmax(0)

In [34]:
# Run the structured inference on log_Pc / log_P and report the labels found.
c, lab = struct_inference(log_Pc, log_P)

print("number of windows:", log_Pc.shape[1])
print("max num of events in a single window:", log_Pc.shape[0] - 1)
print("count_label:", c)
for pair_idx, x_labels in enumerate(lab):
    # the y label of each pair is implied by the constraint c = x + y
    print("coupled_labels", pair_idx, "   x:", x_labels, "y:", c - x_labels)

number of windows: 200
max num of events in a single window: 49
count_label: [44 28 27 44 28 17 37 26 39 47 48 18 20 49 29 44 27 32 42 36 35 33 29 42
 42 45 31 42 33 48 31 39 28 40 23 47 43 43 39 44 35 31 47 19 38 35 48 47
 21 39 41 49 46 43 47 43 44 32 44 49 43 45 36 43 35 38 42 28 44 39 40 32
 34 40 29 48 41 33 34 33 13 41 25 19 44 37 16 46 27 33 46 41 49 22 31 19
 49 18 42 49 26 31 36 48 42 43 48 26 41 40 48 44 36 41 30 33 33 49 46 38
 36 15 44 40 46 36 36 22 47 46 46 27 45 42 47 42 43 37 28 49 28 34 17 49
 46 33 42 37 49 36 40 45 22 46 20 29 42 36 47 47 39 27 42 48 38 40 40 33
 47 28 48 48 40 19 47 15 20 40 28 47 49 23 34 29 46 25 23 49 47 26 16 33
 42 35 46 48 26 39 27 38]
coupled_labels 0    x: [24 28  1  8 28 16 27  1 11 10 26  9 17 30 12 10 24  2  9 12 19  9  9  8
 30 36  3 15 15 22 12 14  2 40 13 26  5  4 39  0 26 22 39  5 16 19 12 10
 14  4 13 40 24 43 33 37 17 31  2 29 40 45 22 24 27 31  8  9  2 23 26 15
 29  1 28 23 11 17 34  8  8 25  9 11 29  5 13 38 21 32 12 34 36  5 25  

In [37]:
lab[0]

array([24, 28,  1,  8, 28, 16, 27,  1, 11, 10, 26,  9, 17, 30, 12, 10, 24,
        2,  9, 12, 19,  9,  9,  8, 30, 36,  3, 15, 15, 22, 12, 14,  2, 40,
       13, 26,  5,  4, 39,  0, 26, 22, 39,  5, 16, 19, 12, 10, 14,  4, 13,
       40, 24, 43, 33, 37, 17, 31,  2, 29, 40, 45, 22, 24, 27, 31,  8,  9,
        2, 23, 26, 15, 29,  1, 28, 23, 11, 17, 34,  8,  8, 25,  9, 11, 29,
        5, 13, 38, 21, 32, 12, 34, 36,  5, 25,  6, 30,  8, 26,  7,  3, 31,
       25,  5, 20, 38, 42, 17, 17, 11, 19, 29,  1, 20, 22,  1,  8, 45, 36,
       24, 26,  2,  3,  6, 19, 20, 11,  5, 32,  5, 39, 21, 14, 34, 12, 11,
       20, 22, 17,  9,  2, 34,  6, 23,  0,  7, 42, 26, 29, 26, 23, 16, 10,
       41,  8, 28, 39, 22, 14, 34, 10, 10,  7, 18,  0, 37, 24,  2, 46, 25,
        7, 13, 33,  6,  2,  8,  2, 37,  8, 38, 32, 12, 14, 18,  1, 22, 10,
       25,  7, 22, 15,  6, 40,  9, 10,  4,  6, 11,  3,  8])

### Eyedea

In [1]:
import os
os.chdir('..')

In [2]:
from src import *

In [None]:
def inference(root_uuid, prefix, model_name, inference_function, n_folds=5,
              coupled_labels=None):
    """Validate a saved model on the testing files of each sub-run of ``root_uuid``.

    For each ``i`` in ``range(n_folds)`` the run ``f'{root_uuid}/{i}'``
    (presumably one cross-validation fold) is processed in four steps:
    its config is loaded, ``inference_function`` and ``coupled_labels`` are
    set on it, the model ``model_name`` is loaded, and the model is validated
    on a ``DataPool`` built from ``config.testing_files``.

    Args:
        root_uuid: identifier of the parent run; sub-runs are ``{root_uuid}/{i}``.
        prefix: currently unused. NOTE(review): presumably meant to tag saved
            results (e.g. 'tst_coupled'); confirm the intent.
        model_name: name of the saved model loaded for every sub-run.
        inference_function: value assigned to ``config.inference_function``.
        n_folds: number of sub-runs to evaluate (default 5).
        coupled_labels: label pairs assigned to ``config.coupled_labels``;
            defaults to ``[['n_incoming', 'n_outgoing'], ['n_CAR', 'n_NOT_CAR']]``.

    Returns:
        dict mapping each sub-run uuid to the summary returned by
        ``validate_datapool``. These summaries were previously discarded.
    """
    if coupled_labels is None:
        coupled_labels = [['n_incoming', 'n_outgoing'], ['n_CAR', 'n_NOT_CAR']]

    summaries = {}
    for i in range(n_folds):
        uuid = f'{root_uuid}/{i}'
        print(uuid)
        config = load_config_locally(uuid)
        config.inference_function = inference_function
        # fresh nested lists per config so configs never share one mutable object
        config.coupled_labels = [list(pair) for pair in coupled_labels]
        datapool = DataPool(config.testing_files, config)
        model, _ = load_model_locally(uuid, model_name)
        summaries[uuid] = validate_datapool(datapool, model, config, Part.WHOLE)
    return summaries


# Evaluation settings (the COUPLED inference function on the Eyedea runs).
root_uuid = '013_eyedea_all_aligned_RX100_direction_types'
prefix = 'tst_coupled'
model_name = 'rvce'
inference_function = InferenceFunction.COUPLED

inference(
    root_uuid=root_uuid,
    prefix=prefix,
    model_name=model_name,
    inference_function=inference_function,
)