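# Recreate Stratified TFRecords

Split the Cassava Leaf Disease training images into label-stratified folds and write one TFRecord file per fold.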

References:

[How To Create TFRecords](https://www.kaggle.com/cdeotte/how-to-create-tfrecords)

In [None]:
# LOAD LIBRARIES
import numpy as np, pandas as pd, os
import matplotlib.pyplot as plt, cv2
import tensorflow as tf, re, math
import glob
from sklearn.model_selection import StratifiedKFold

In [None]:
# Number of stratified folds; one TFRecord file is written per fold.
FOLDS = 7
# Side length in pixels of the square size every image is resized to.
IMG_SIZE = 512
# Random seed for the shuffled stratified fold split.
SEED = 2020

In [None]:
# Root directory of the competition input data (train.csv and train_images/).
BASE = "../input/cassava-leaf-disease-classification"

In [None]:
# Load the training metadata; later cells read its `image_id` and `label` columns.
train_csv_path = BASE + os.sep + 'train.csv'
train = pd.read_csv(train_csv_path)

In [None]:
# Preview the first rows of the metadata table.
train.head(n=10)

In [None]:
# Assign every row to exactly one validation fold. The split is stratified by
# label, so each fold keeps approximately the overall class distribution.
folds = train.copy()
Fold = StratifiedKFold(n_splits=FOLDS, shuffle=True, random_state=SEED)
# split() yields *positional* indices. Collect them in a positional array
# instead of writing through .loc, which is only correct for a default
# RangeIndex. This also avoids a float column that has to be cast back to int.
fold_ids = np.full(len(folds), -1, dtype=int)
for n, (train_index, val_index) in enumerate(Fold.split(folds, folds['label'])):
    fold_ids[val_index] = n
assert (fold_ids >= 0).all(), 'every row must be assigned to a validation fold'
folds['fold'] = fold_ids
print(folds.groupby(['fold', 'label']).size())

In [None]:
def _bytes_feature(value):
  """Wrap a single string / bytes value in a tf.train.Feature holding a BytesList.

  An EagerTensor input is first converted with `.numpy()`, because BytesList
  won't unpack a string from an EagerTensor.
  """
  eager_tensor_type = type(tf.constant(0))
  raw = value.numpy() if isinstance(value, eager_tensor_type) else value
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw]))

def _float_feature(value):
  """Wrap one float / double in a tf.train.Feature holding a FloatList."""
  float_list = tf.train.FloatList(value=[value])
  return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
  """Wrap one bool / enum / int / uint in a tf.train.Feature holding an Int64List."""
  int64_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int64_list)

In [None]:
def serialize_example(feature0, feature1):
  """Serialize one (image, label) pair into a tf.train.Example.

  Args:
    feature0: bytes stored under the 'image' key (the caller passes
      JPEG-encoded image bytes).
    feature1: integer stored under the 'target' key (the caller passes the
      class label).

  Returns:
    The serialized Example protocol buffer as bytes.
  """
  features = tf.train.Features(feature={
      'image': _bytes_feature(feature0),
      'target': _int64_feature(feature1),
  })
  return tf.train.Example(features=features).SerializeToString()

In [None]:
def _read_resized_bgr(path):
    """Read the image at `path` with OpenCV and resize it to IMG_SIZE x IMG_SIZE.

    The returned array is in OpenCV's native BGR channel order.
    Raises FileNotFoundError if OpenCV cannot read the file.
    """
    img = cv2.imread(path)
    if img is None:  # cv2.imread signals failure by returning None, not by raising
        raise FileNotFoundError('Could not read image: %s' % path)
    return cv2.resize(img, (IMG_SIZE, IMG_SIZE))


def _encode_jpeg(img_bgr, quality=94):
    """JPEG-encode a BGR image array and return the encoded bytes.

    Do not convert to RGB before calling this: cv2.imencode interprets
    3-channel input as BGR, so an RGB array would be stored with red and
    blue swapped.
    Raises ValueError if OpenCV reports an encoding failure.
    """
    ok, buf = cv2.imencode('.jpg', img_bgr, (cv2.IMWRITE_JPEG_QUALITY, quality))
    if not ok:
        raise ValueError('JPEG encoding failed')
    return buf.tobytes()  # ndarray.tostring() is deprecated / removed in NumPy 2


# Write one TFRecord file per fold, named train<fold>-<image count>.tfrec.
for f in range(FOLDS):
    fold_rows = folds[folds['fold'] == f]
    ct = len(fold_rows)
    print('Writing TFRecord %i of %i (%i images)...' % (f + 1, FOLDS, ct))
    with tf.io.TFRecordWriter('train%.2i-%i.tfrec' % (f, ct)) as writer:
        for k, (image_id, label) in enumerate(zip(fold_rows['image_id'], fold_rows['label'])):
            img = _read_resized_bgr(BASE + '/train_images/' + image_id)
            if k == 0:
                # matplotlib expects RGB. Convert only for this preview so the
                # encoded JPEG keeps the BGR order that cv2.imencode requires.
                plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
                plt.show()
            # int() turns the numpy scalar from the DataFrame into a plain Python int.
            writer.write(serialize_example(_encode_jpeg(img), int(label)))
            if k % 100 == 0:
                print(k, ', ', end='')
    print()  # finish the progress line before the next fold's header