In [None]:
import h5py
import pandas as pd
import numpy as np
import os
import tensorflow as tf
import glob

In [None]:
# Particle class names. The position of each name in this list is the index
# that is set to 1.0 in the one-hot label vectors built by the label helpers.
vocab = [
    'electrons',
    'protons',
    'muons',
    'pions',
    'gamma',
]

In [None]:
def get_label(file_path):
    """Build a one-hot class label from a file path.

    The class name is the second-to-last '/'-separated component of
    ``file_path`` (the directory that directly contains the file). It is
    compared against every entry of the module-level ``vocab`` list.

    Args:
        file_path: scalar string tensor holding the path of one file.

    Returns:
        A float32 tensor of shape [len(vocab)] with 1.0 at the position of the
        matching class name. If the directory name is not in ``vocab`` the
        result is all zeros (no error is raised).
    """
    # `tf.string_split` is not part of the TF 2.x top-level namespace; the
    # compat.v1 alias is the same op and also exists in late TF 1.x releases.
    parts = tf.compat.v1.string_split([file_path], '/')
    label = parts.values[-2]
    # Broadcasting the scalar label against the vocabulary compares it with
    # every class name in a single op instead of one op per class.
    matches = tf.equal(label, vocab)
    return tf.cast(matches, tf.float32)

In [None]:
def get_label_csv(file_path):
    """Build a one-hot class label from a CSV file path.

    Works like the image label helper, except that the class name is the
    third-to-last '/'-separated component of ``file_path`` (i.e. the grandparent
    directory of the file). It is compared against every entry of the
    module-level ``vocab`` list.

    Args:
        file_path: scalar string tensor holding the path of one file.

    Returns:
        A float32 tensor of shape [len(vocab)] with 1.0 at the position of the
        matching class name. If the directory name is not in ``vocab`` the
        result is all zeros (no error is raised).
    """
    # `tf.string_split` is not part of the TF 2.x top-level namespace; the
    # compat.v1 alias is the same op and also exists in late TF 1.x releases.
    parts = tf.compat.v1.string_split([file_path], '/')
    label = parts.values[-3]
    # Broadcasting the scalar label against the vocabulary compares it with
    # every class name in a single op instead of one op per class.
    matches = tf.equal(label, vocab)
    return tf.cast(matches, tf.float32)

In [None]:
def extract_img(img_dir,image_size,grayscale):
    """Load one JPEG file and return it together with its one-hot label.

    Args:
        img_dir: path of the image file to read (despite the name, this is a
            single file path, not a directory).
        image_size: side length in pixels; the image is resized to
            ``image_size x image_size``.
        grayscale: when truthy, the RGB image is reduced to one channel.

    Returns:
        A tuple ``(img, label)``: ``img`` is a float tensor with 3 channels
        (or 1 channel when ``grayscale`` is set) and ``label`` is the one-hot
        vector produced by ``get_label`` for the same path.
    """
    raw_bytes = tf.io.read_file(img_dir)

    # Decode the compressed bytes to a uint8 RGB tensor, then convert to
    # float32 (`convert_image_dtype` maps uint8 into the [0, 1] range).
    pixels = tf.image.convert_image_dtype(
        tf.image.decode_jpeg(raw_bytes, channels=3),
        tf.float32,
    )

    # Min-max rescale the full-resolution image first, then resize it.
    pixels = tf.image.resize(scale_normalize(pixels), [image_size, image_size])

    if grayscale:
        pixels = tf.image.rgb_to_grayscale(pixels)

    return pixels, get_label(img_dir)

In [None]:
def scale_normalize(tensor):
    """Min-max rescale a tensor to the range [0, 1].

    Computes ``(tensor - min) / (max - min)``, where ``min`` and ``max`` are
    taken over all elements of ``tensor``.

    Args:
        tensor: floating-point tensor (the safe division used here does not
            accept integer dtypes).

    Returns:
        A tensor with the same shape as ``tensor``. A constant input
        (``max == min``) maps to all zeros instead of NaN.
    """
    lowest = tf.reduce_min(tensor)
    value_range = tf.subtract(tf.reduce_max(tensor), lowest)
    # `tf.div` no longer exists in TF 2.x, and a plain division produces NaN
    # (0 / 0) for a constant image; `divide_no_nan` returns 0 in that case.
    return tf.math.divide_no_nan(tf.subtract(tensor, lowest), value_range)

In [None]:
def process_csv(features):
    """Flatten one parsed CSV element into a single-column tensor.

    Args:
        features: mapping from column name to that column's tensor, as yielded
            by the CSV dataset built in ``extract_csv``.

    Returns:
        A tensor of shape [N, 1], where N is the total number of values in the
        element (the columns are stacked along the last axis, then flattened).
    """
    columns = [column for column in features.values()]
    stacked = tf.stack(columns, axis=-1)
    return tf.reshape(stacked, shape=[-1, 1])

In [None]:
def extract_csv(frequency_dir, batch_size=8, num_parallel_calls=8, num_epochs=None):
    """Build a dataset of flattened frequency rows from CSV files.

    Args:
        frequency_dir: file name or glob pattern of the CSV file(s) to read.
        batch_size: number of CSV rows grouped into one dataset element
            (default 8, the previously hard-coded value).
        num_parallel_calls: parallelism of the ``process_csv`` map step
            (default 8, the previously hard-coded value).
        num_epochs: how many times the CSV data is repeated. The default
            ``None`` keeps the ``make_csv_dataset`` default, which cycles
            through the data forever.

    Returns:
        A ``tf.data.Dataset`` whose elements are the output of ``process_csv``:
        one [N, 1] tensor per group of ``batch_size`` rows, read in file order
        (``shuffle=False``).
    """
    # NOTE(review): every element holds `batch_size` rows flattened together,
    # i.e. N = batch_size * number_of_columns. The comments in
    # `add_frequency_data_to_image` expect a [64, 1] tensor per image, so
    # confirm that the CSV layout and `batch_size` actually produce that.
    csv_dataset = tf.data.experimental.make_csv_dataset(
        frequency_dir,
        batch_size=batch_size,
        num_epochs=num_epochs,
        shuffle=False,
    )
    return csv_dataset.map(process_csv, num_parallel_calls=num_parallel_calls)

In [None]:
def add_frequency_data_to_image(image_with_label,frequency_data,grayscale):
    """Append the frequency data to an image as one extra row of pixels.

    Args:
        image_with_label: pair ``(img, label)``; ``img`` has shape
            [height, width, channels].
        frequency_data: rank-2 tensor of shape [width, 1]. The original notes
            describe it as a single 64-element row, i.e. [64, 1], for 64x64
            images -- TODO confirm against the CSV layout.
        grayscale: truthy when ``img`` has 1 channel; otherwise ``img`` is
            treated as having 3 channels.

    Returns:
        A tuple ``(combined_data_point, label)`` where ``combined_data_point``
        is ``img`` with the frequency row concatenated along axis 0, giving
        shape [height + 1, width, channels].
    """
    img = image_with_label[0]
    label = image_with_label[1]

    if grayscale:
        # One channel: the [width, 1] data already matches the image depth.
        freq_row = frequency_data
    else:
        # Three channels: repeat the single column once per channel -> [width, 3].
        freq_row = tf.tile(frequency_data, [1, 3])

    # Add a leading "row" axis -> [1, width, channels], then stack under the image.
    freq_row = tf.expand_dims(freq_row, axis=0)
    combined_data_point = tf.concat([img, freq_row], 0)

    return combined_data_point, label


In [None]:
def load_images(img_dir,image_size,num_parallel_calls,grayscale,shuffle=None):
    """Build a dataset of ``(image, one_hot_label)`` pairs from image files.

    Args:
        img_dir: file name or glob pattern matching the image files.
        image_size: side length the images are resized to.
        num_parallel_calls: parallelism of the decoding map step.
        grayscale: when truthy, images are converted to a single channel.
        shuffle: forwarded to ``tf.data.Dataset.list_files``. The default
            ``None`` keeps that function's default behaviour, which is to
            shuffle the file names; pass ``False`` for a deterministic order.

    Returns:
        A ``tf.data.Dataset`` yielding the ``(img, label)`` tuples produced by
        ``extract_img`` for every matched file.
    """
    image_paths = tf.data.Dataset.list_files(img_dir, shuffle=shuffle)
    # The lambda argument has its own name so it no longer shadows `img_dir`.
    return image_paths.map(
        lambda image_path: extract_img(image_path, image_size, grayscale),
        num_parallel_calls=num_parallel_calls,
    )

In [None]:
def combine_csv_and_img(img_dir,frequency_dir,image_size,minibatch_size,grayscale,num_parallel_calls,shuffle_buffer_size=None):
    """Pair every image with its frequency data and merge them into one tensor.

    Args:
        img_dir: file name or glob pattern matching the image files.
        frequency_dir: file name or glob pattern of the CSV file(s).
        image_size: side length the images are resized to.
        minibatch_size: accepted for interface compatibility; not used here
            (batching is done by the caller).
        grayscale: when truthy, images are converted to a single channel.
        num_parallel_calls: parallelism of the map steps.
        shuffle_buffer_size: optional buffer size; when given, the already
            paired samples are shuffled with ``Dataset.shuffle``. Default
            ``None`` leaves the samples in file order.

    Returns:
        A ``tf.data.Dataset`` of ``(combined_data_point, label)`` tuples as
        produced by ``add_frequency_data_to_image``.
    """
    # `list_files` shuffles file names by default, while the CSV rows are read
    # with shuffle=False. Zipping a shuffled image stream with an ordered CSV
    # stream attaches frequency data to arbitrary images (and labels), so the
    # images must be listed in a deterministic order here.
    # NOTE(review): this removes the random mismatch; that the n-th image really
    # corresponds to the n-th CSV element still depends on the data layout.
    image_paths = tf.data.Dataset.list_files(img_dir, shuffle=False)
    img_dataset = image_paths.map(
        lambda image_path: extract_img(image_path, image_size, grayscale),
        num_parallel_calls=num_parallel_calls,
    )

    csv_dataset = extract_csv(frequency_dir)

    combined_dataset = tf.data.Dataset.zip((img_dataset, csv_dataset))
    combined_dataset = combined_dataset.map(
        lambda image_with_label, frequency_data: add_frequency_data_to_image(
            image_with_label, frequency_data, grayscale),
        num_parallel_calls=num_parallel_calls,
    )

    # Shuffling is only safe after pairing, so it is offered here as an option.
    if shuffle_buffer_size:
        combined_dataset = combined_dataset.shuffle(shuffle_buffer_size)

    return combined_dataset
    

In [None]:
def load_dataset(img_dir,frequency_dir,minibatch_size,image_size,num_parallel_calls=8,grayscale=False,is_frequency=True):
    """Create the batched input dataset, with or without frequency data.

    Args:
        img_dir: file name or glob pattern matching the image files.
        frequency_dir: file name or glob pattern of the CSV file(s); only read
            when ``is_frequency`` is true.
        minibatch_size: number of samples per batch.
        image_size: side length the images are resized to.
        num_parallel_calls: parallelism of the map steps (default 8).
        grayscale: when truthy, images are converted to a single channel.
        is_frequency: when true, each image is combined with its frequency
            data; otherwise only the images are loaded.

    Returns:
        A ``tf.data.Dataset`` batched to ``minibatch_size`` with a prefetch
        buffer of one batch.
    """
    if is_frequency:
        samples = combine_csv_and_img(
            img_dir,
            frequency_dir,
            image_size,
            minibatch_size,
            grayscale,
            num_parallel_calls=num_parallel_calls,
        )
    else:
        samples = load_images(
            img_dir,
            image_size=image_size,
            num_parallel_calls=num_parallel_calls,
            grayscale=grayscale,
        )

    # Both variants are batched and prefetched the same way.
    return samples.batch(minibatch_size).prefetch(1)