In [1]:
import tensorflow as tf 

In [3]:
%load_ext autoreload

In [4]:
%autoreload 2

In [None]:
def parse_window_all(element, read_hits=False):
    context_features = {
        's_f': tf.io.FixedLenFeature([24], tf.float32),
        's_l': tf.io.FixedLenFeature([3], tf.int64),
        's_m': tf.io.FixedLenFeature([8], tf.float32),
        # window class
        'w_cl' : tf.io.FixedLenFeature([], tf.int64),
        # number of clusters
        'n_cl' : tf.io.FixedLenFeature([], tf.int64),
        # flag (pdgid id)
        'f' :  tf.io.FixedLenFeature([], tf.int64)
    }
    clusters_features = {
        "cl_f" : tf.io.FixedLenSequenceFeature([22], dtype=tf.float32),
        "cl_m" : tf.io.FixedLenSequenceFeature([1], dtype=tf.float32),
        "cl_l" : tf.io.FixedLenSequenceFeature([6], dtype=tf.int64),
    }
    if read_hits:
        context_features['s_h'] = tf.io.FixedLenFeature([], tf.string)
        clusters_features["cl_h0"] = tf.io.RaggedFeature(dtype=tf.float32)
        clusters_features["cl_h1"] = tf.io.RaggedFeature(dtype=tf.float32)
        clusters_features["cl_h2"] = tf.io.RaggedFeature(dtype=tf.float32)
        clusters_features["cl_h4"] = tf.io.RaggedFeature(dtype=tf.float32)

    ex = tf.io.parse_single_sequence_example(element, context_features=context_features, sequence_features=clusters_features)
    
    if read_hits:
        seed_hits = tf.io.parse_tensor(ex[0]['s_h'], out_type=tf.float32)
        seed_hits.set_shape(tf.TensorShape((None, 4))) 
        ex[0]['s_h'] = seed_hits
        
        cluster_hits = tf.ragged.stack([ex[1]['cl_h0'], ex[1]['cl_h1'],ex[1]['cl_h2'],ex[1]['cl_h4']],axis=2)
        ex[1]['cl_h'] = cluster_hits 
        ex[1].pop("cl_h0")
        ex[1].pop("cl_h1")
        ex[1].pop("cl_h2")
        ex[1].pop("cl_h4")
    
    return ex

In [107]:
def parse_windows_sparse(elements, read_hits=False):
    context_features = {
        's_f': tf.io.FixedLenFeature([24], tf.float32),
        's_l': tf.io.FixedLenFeature([3], tf.int64),
        's_m': tf.io.FixedLenFeature([8], tf.float32),
        # window class
        'w_cl' : tf.io.FixedLenFeature([], tf.int64),
        # number of clusters
        'n_cl' : tf.io.FixedLenFeature([], tf.int64),
        # flag (pdgid id)
        'f' :  tf.io.FixedLenFeature([], tf.int64)
    }
    clusters_features = {
        "cl_f" : tf.io.FixedLenSequenceFeature([22], dtype=tf.float32),
        "cl_m" : tf.io.FixedLenSequenceFeature([1], dtype=tf.float32),
        "cl_l" : tf.io.FixedLenSequenceFeature([6], dtype=tf.int64),
    }
    if read_hits:
        # context_features['s_h'] = tf.io.FixedLenFeature([1], tf.string)
        clusters_features["cl_h0"] = tf.io.VarLenFeature(dtype=tf.float32)
        clusters_features["cl_h1"] = tf.io.VarLenFeature(dtype=tf.float32)
        clusters_features["cl_h2"] = tf.io.VarLenFeature(dtype=tf.float32)
        clusters_features["cl_h4"] = tf.io.VarLenFeature(dtype=tf.float32)

    ex = tf.io.parse_sequence_example(elements, context_features=context_features, sequence_features=clusters_features,name="input")
    
    
    
    return ex

In [87]:
def parse_windows_all(elements, bs , read_hits=False):
    context_features = {
        's_f': tf.io.FixedLenFeature([24], tf.float32),
        's_l': tf.io.FixedLenFeature([3], tf.int64),
        's_m': tf.io.FixedLenFeature([8], tf.float32),
        # window class
        'w_cl' : tf.io.FixedLenFeature([], tf.int64),
        # number of clusters
        'n_cl' : tf.io.FixedLenFeature([], tf.int64),
        # flag (pdgid id)
        'f' :  tf.io.FixedLenFeature([], tf.int64)
    }
    clusters_features = {
        "cl_f" : tf.io.FixedLenSequenceFeature([22], dtype=tf.float32),
        "cl_m" : tf.io.FixedLenSequenceFeature([1], dtype=tf.float32),
        "cl_l" : tf.io.FixedLenSequenceFeature([6], dtype=tf.int64),
    }
    if read_hits:
#         context_features['s_h'] = tf.io.FixedLenFeature([1], tf.string)
        clusters_features["cl_h0"] = tf.io.RaggedFeature(dtype=tf.float32)
        clusters_features["cl_h1"] = tf.io.RaggedFeature(dtype=tf.float32)
        clusters_features["cl_h2"] = tf.io.RaggedFeature(dtype=tf.float32)
        clusters_features["cl_h4"] = tf.io.RaggedFeature(dtype=tf.float32)

    ex = tf.io.parse_sequence_example(elements, context_features=context_features, sequence_features=clusters_features)
    
    if read_hits:
#         seed_hits = tf.io.parse_tensor(ex[0]['s_h'], out_type=tf.float32)
#         seed_hits.set_shape(tf.TensorShape((None, 4))) 
#         ex[0]['s_h'] = seed_hits
        
        cluster_hits = tf.ragged.stack([ex[1]['cl_h0'], ex[1]['cl_h1'],ex[1]['cl_h2'],ex[1]['cl_h4']],axis=3)
        ex[1]['cl_h'] = cluster_hits 
        ex[1].pop("cl_h0")
        ex[1].pop("cl_h1")
        ex[1].pop("cl_h2")
        ex[1].pop("cl_h4")
    
    return ex

In [48]:
from global_model.tf_data import * 

In [17]:
def only_cl_hits(*kargs):
    print (kargs[1])
    return kargs

In [108]:
# Create datasets from TFRecord files.
dataset = tf.data.TFRecordDataset(tf.io.gfile.glob("/eos/user/r/rdfexp/ecal/cluster/output_deepcluster_dumper/windows_data/electrons/recordio_allinfo_v2/training/calo_matched/*.proto"))
#dataset = dataset.map(_parse_tfr_element,num_parallel_calls=tf.data.experimental.AUTOTUNE)
#dataset = dataset.shuffle(10000, reshuffle_each_iteration=True)

In [103]:
def cluster_features_and_hits(feat_index): 
    def process(*kargs):
        cl_f = kargs[1]['cl_f']
        cl_l = kargs[1]['cl_l']
        cl_X = tf.gather(cl_f, indices=feat_index,axis=2)
        cl_hits = kargs[1]['        cl_h']
        is_seed = tf.gather(cl_l,indices=[0],axis=2)
        in_sc = tf.gather(cl_l,indices=[3],axis=2)
    #     X = tf.gather(seed['s_f'], indices=[0,1,4,6,16,17,18,19,20,21,22,23])
    #     X_scaled = tf.gather(norm_feat, indices=[0,1,4,6,16,17,18,19,20,21,22,23])
    #     calomatch = seed['s_l'][0]
    #     caloseed = seed['s_l'][1]
    #     seed_pos = tf.cast(tf.where(cls['cl_l'][:,0]==1)[0], tf.float32)
    #     n_cl = [tf.cast(seed['n_cl'], tf.float32)]
    #     X_train = tf.concat([X_scaled, seed_pos, n_cl], axis=0)
    #     X = tf.concat([X, eed_pos, n_cl], axis=0)
        return cl_X, cl_hits, is_seed, in_sc,  kargs[2]["cl_f"]
    return process


In [109]:
df = dataset.batch(10).map(lambda obj: parse_windows_sparse(obj, read_hits=True))

In [104]:
df = dataset.batch(10).map(lambda obj: parse_windows_batch(obj,5, read_hits=True)).map(cluster_features_and_hits(get_cluster_features_indexes(["cluster_deta","cluster_dphi"]))).filt

In [110]:
el = next(iter(df))

In [111]:
el

({'f': <tf.Tensor: shape=(10,), dtype=int64, numpy=array([11, 11, 11, 11, 11, 11, 11, 11, 11, 11])>,
  'n_cl': <tf.Tensor: shape=(10,), dtype=int64, numpy=array([6, 5, 4, 6, 7, 8, 5, 6, 4, 3])>,
  's_f': <tf.Tensor: shape=(10, 24), dtype=float32, numpy=
  array([[ 1.47121322e+00, -6.39126003e-01,  8.50000000e+01,
           3.34000000e+02,  0.00000000e+00,  1.14316826e+02,
           4.98746910e+01,  1.22074326e+02,  5.32591705e+01,
           1.20920464e+02,  5.27386475e+01,  9.77760971e-01,
           6.65611634e-03,  6.15860563e-06,  1.18546225e-02,
           4.59328473e-01,  9.77760971e-01,  6.67275162e-03,
           5.94148469e-06,  1.18654398e-02,  4.59328473e-01,
           1.10000000e+01,  6.94015669e-03,  1.02384873e-02],
         [-7.09592998e-01,  2.98527408e+00, -4.10000000e+01,
           1.82000000e+02,  0.00000000e+00,  3.05662518e+01,
           2.42108173e+01,  3.09168797e+01,  2.44885406e+01,
           3.15492878e+01,  2.49847775e+01,  9.93191302e-01,
           9.

In [37]:
obj2 = df.take(1)

In [38]:
next(iter(obj2))

(<tf.Tensor: shape=(6, 22), dtype=float32, numpy=
 array([[ 8.50000000e+01,  3.34000000e+02,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  1.14316826e+02,
          4.98746910e+01,  1.22074326e+02,  5.32591705e+01,
          9.77760971e-01,  6.65611634e-03,  6.15860563e-06,
          1.18546225e-02,  4.59328473e-01,  9.77760971e-01,
          6.67275162e-03,  5.94148469e-06,  1.18654398e-02,
          4.59328473e-01,  1.10000000e+01,  6.94015669e-03,
          1.02384873e-02],
        [ 8.50000000e+01,  3.28000000e+02,  0.00000000e+00,
          3.03769112e-03,  9.30574536e-02,  3.05592632e+00,
          1.32961380e+00,  3.62426043e+00,  1.57689226e+00,
          1.19116390e+00,  1.36063853e-02,  9.78259050e-05,
          1.69923399e-02,  4.50912118e-01,  1.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  9.17003956e-03,
          6.58439577e-01,  2.00000000e+00,  3.01638572e-03,
          8.66006128e-03],
        [ 7.80000000e+01,  3.06000000e+02,  0.00000000e+

In [39]:
ps = ([None,])

ds_train = df.take(10).batch(3)

In [40]:
a = next(iter(ds_train))

InvalidArgumentError: Cannot batch tensors with different shapes in component 0. First element had shape [6,22] and element 1 had shape [5,22].

<tf.Tensor: shape=(9, 4), dtype=float32, numpy=
array([[-41.        , 182.        ,   0.        ,  21.892532  ],
       [-41.        , 183.        ,   0.        ,   0.46643642],
       [-42.        , 183.        ,   0.        ,   0.2937847 ],
       [-42.        , 182.        ,   0.        ,   0.8742974 ],
       [-42.        , 181.        ,   0.        ,   0.56234646],
       [-41.        , 181.        ,   0.        ,   4.902018  ],
       [-40.        , 181.        ,   0.        ,   0.54801035],
       [-40.        , 182.        ,   0.        ,   0.7728225 ],
       [-39.        , 182.        ,   0.        ,   0.25400263]],
      dtype=float32)>