# In [0]:
import os
import tensorflow as tf
import matplotlib.pyplot as plt
import math
import tqdm

def check_length(file_name):
    """Return the number of lines in *file_name*.

    Used to size the train/valid/test splits of the data file, which holds
    one image index and its label per line.
    """
    line_count = 0
    with open(file_name, 'r') as handle:
        for _ in handle:
            line_count += 1
    return line_count
    
#Train, Validation and Test data    
def divide_data_file(data_file, train_split=0.8, valid_split=0.1):
    data_file_dir = os.path.abspath(os.path.dirname(data_file)) #In the data file, index and label are in a line
    length = check_length(data_file) #Length of the data
    train_index = int(length * train_split) #Train data
    valid_index = train_index + int(length * valid_split) #Valid data
    #Create 3 txt (Train, Valid, Test) and write there the appropriate number of the data 
    with open(data_file, 'r') as df, open(os.path.join(data_file_dir, 'train.txt'), 'w') as train, open(os.path.join(data_file_dir, 'valid.txt'), 'w') as valid, open(os.path.join(data_file_dir, 'test.txt'), 'w') as test:
        for index, line in enumerate(df): #Enumerate get an index to every line
            if index < train_index:
                train.write(line)
            elif index < valid_index:
                valid.write(line)
            else:
                test.write(line)
    
    return train_index, valid_index - train_index, length - valid_index #Return the idnexes of each shards

def load_file(file_name):
    """Lazily yield ``[index, text]`` lists from the tab-separated *file_name*.

    Each line has the form ``index<TAB>text``. The trailing line terminator is
    removed, so ``text`` never ends with a newline. Without this, every label
    except possibly the last one would carry a newline. Only the first tab
    separates the fields, so tabs inside ``text`` are kept and the caller can
    always unpack two values. Blank lines, such as a trailing empty line, are
    skipped because they carry no image index.

    Being a generator, the file is read one line at a time and stays open
    until iteration finishes.
    """
    with open(file_name, 'r') as f:
        for line in f:
            line = line.rstrip('\r\n')
            if not line:
                continue
            yield line.split('\t', 1)
def load_image(idx, image_dir='D:/Dani/bme/5.félév/dl/hf/full_images', size=(200, 200)):
    """Read ``<image_dir>/<idx>.jpg`` and resize it to *size*.

    Args:
        idx: image index; the file name is ``str(idx) + '.jpg'``.
        image_dir: directory holding the images (defaults to the original
            hard-coded location).
        size: ``(height, width)`` target size passed to
            ``tf.image.resize_images``.

    Returns:
        The resized image as evaluated by ``Session.run``, with shape
        ``(height, width, channels)``.
    """
    img = plt.imread(os.path.join(image_dir, str(idx) + '.jpg'))
    if img.ndim == 2:
        # Black-and-white images come back 2-D. Add a channel axis because
        # resize_images needs a 3-D (or 4-D) input.
        # NOTE(review): colour JPEGs presumably decode to 3 channels, so the
        # flattened length of grey and colour images differs. Confirm the
        # TFRecord reader handles variable-length 'image' features.
        img = img[:, :, None]
    # Build the resize op in a throw-away graph and close the session when
    # done. Using tf.Session() on the default graph would leak one session per
    # image and keep adding ops to the global graph on every call.
    with tf.Graph().as_default() as graph, tf.Session(graph=graph) as sess:
        return sess.run(tf.image.resize_images(img, list(size)))

def create_tfrecord_names(tfrecord_path, mode, num_files):
    """Open *num_files* TFRecord writers under ``<tfrecord_path>/<mode>/``.

    Despite its name, this returns open ``tf.python_io.TFRecordWriter``
    objects, not file names. Writer ``i`` writes
    ``<tfrecord_path>/<mode>/<i>.tfrecords``. The caller owns the writers and
    must close each of them.

    Args:
        tfrecord_path: root output directory; created if missing.
        mode: split name ('train', 'valid' or 'test'), used as sub-directory.
        num_files: number of shard files / writers to create.
    """
    mode_path = os.path.join(tfrecord_path, mode)
    # makedirs also creates tfrecord_path itself if needed. os.mkdir would
    # raise FileNotFoundError in that case, and exist_ok avoids the race of a
    # separate isdir() check.
    os.makedirs(mode_path, exist_ok=True)
    return [tf.python_io.TFRecordWriter(os.path.join(mode_path, '{name}.tfrecords'.format(name=index)))
            for index in range(num_files)]
      
   

  def create_tfrecords(data_file, tfrecord_path, max_size=5000, train_split=0.8, valid_split=0.1): #max_size: Maximum data in a tfrecord
    train_length, valid_length, test_length = divide_data_file(data_file, train_split, valid_split) #call divide_data, we get the shards
    data_file_dir = os.path.abspath(os.path.dirname(data_file)) #Source data
    
    modes = { #3 modes
        'train': [  #Create a directory for a train tfrecords
            create_tfrecord_names(tfrecord_path, 'train', math.ceil(train_length / max_size)),
            os.path.join(data_file_dir, 'train.txt') 
        ],
        'valid': [  #Create a directory for a validation tfrecords
            create_tfrecord_names(tfrecord_path, 'valid', math.ceil(valid_length / max_size)),
            os.path.join(data_file_dir, 'valid.txt')
        ],
        'test': [  #Create a directory for a test tfrecords
            create_tfrecord_names(tfrecord_path, 'test', math.ceil(test_length / max_size)),
            os.path.join(data_file_dir, 'test.txt')
        ]
        
    }
             
    for mode in modes:
        tfrecords_list = modes[mode][0] #There is a list for tfrecords, for every mode
        data_file = modes[mode][1] #There is a data_file, for every mode
        prev_tfrecords_index = 0
        with tqdm.tqdm() as pbar: #It's an indicator, shows us the actual state of the generate
            for index, (image_index, text) in enumerate(load_file(data_file)): #Getting the index and the label of the image from load_file
                pbar.update() #update the indicator
                tfrecords_index = index // max_size #eg: max_size is 5000, we have 14970 data, there will be 3 tf records
                example = tf.train.Example( #Need an example and features
                    features=tf.train.Features(
                        feature={'image': tf.train.Feature(float_list=tf.train.FloatList(value=load_image(image_index).reshape(-1))), #Features are - image and text - so create these pairs 
                                 'text': tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.compat.as_bytes(text)]))}
                ))
                tfrecords_list[tfrecords_index].write(example.SerializeToString())
                if prev_tfrecords_index != tfrecords_index:
                    tfrecords_list[prev_tfrecords_index].close()
                    prev_tfrecords_index += 1;
 
        
if __name__ == '__main__':
    # Only convert the data set when run as a script or notebook cell (where
    # __name__ is also '__main__'), not as a side effect of importing the module.
    create_tfrecords('D:/Dani/bme/5.félév/dl/hf/output.txt', 'D:/Dani/bme/5.félév/dl/hf/tfrecords')