In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import gc
import glob
import json
import string
import random
import pprint
import numpy as np
import pandas as pd
from tqdm import tqdm
from functools import partial
from argparse import Namespace
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import *

from sklearn.model_selection import StratifiedKFold

import wandb
from wandb.keras import WandbCallback

# Imports for augmentations. 
from albumentations import Compose, RandomResizedCrop, Cutout, Rotate, HorizontalFlip, VerticalFlip,\
                           RandomBrightnessContrast, ShiftScaleRotate, CenterCrop, Resize, Normalize

In [None]:
### Create Kaggle Dataset if not exists 

DATASET_NAME = f'sorghum-100-tfrecords'

!rm -r ../tmp/{DATASET_NAME}

os.makedirs(f'../tmp/{DATASET_NAME}', exist_ok=True)

from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("KAGGLE_KEY")
    
os.environ['KAGGLE_USERNAME'] = 'ayuraj'
os.environ['KAGGLE_KEY'] = secret_value_0

!kaggle datasets init -p ../tmp/{DATASET_NAME}

with open(f'../tmp/{DATASET_NAME}/dataset-metadata.json') as f:
    dataset_meta = json.load(f)
dataset_meta['id'] = f'ayuraj/{DATASET_NAME}'
dataset_meta['title'] = DATASET_NAME
with open(f'../tmp/{DATASET_NAME}/dataset-metadata.json', "w") as outfile:
    json.dump(dataset_meta, outfile)
print(dataset_meta)

!cp ../tmp/{DATASET_NAME}/dataset-metadata.json ../tmp/{DATASET_NAME}/meta.json
!ls ../tmp/{DATASET_NAME}

!kaggle datasets create -u -p ../tmp/{DATASET_NAME}

In [None]:
# Competition input locations. Every path keeps a trailing slash so that file
# names can be appended directly.
ROOT_PATH = '../input/sorghum-id-fgvc-9/'
TRAIN_PATH = f'{ROOT_PATH}train_images/'
TEST_PATH = f'{ROOT_PATH}test/'

def add_train_path(row):
    """Return the train-image path for ``row.image`` (``TRAIN_PATH`` followed by the file name)."""
    return ''.join((TRAIN_PATH, row.image))

def add_test_path(row):
    """Return the test-image path for ``row.image`` (``TEST_PATH`` followed by the file name)."""
    return ''.join((TEST_PATH, row.image))

def parse_label(row):
    """Translate the cultivar name in ``row.target`` into its integer class id.

    Looks the name up in the notebook-global ``label2ids`` mapping, which must
    exist before this is called.
    """
    return int(label2ids[row.target])

# Seed for every shuffle, so that re-running the notebook reproduces the same row order.
SEED = 42

# Directory listings used later to drop rows whose image file is absent.
train_files = sorted(glob.glob(TRAIN_PATH+'*'))
test_files = sorted(glob.glob(TEST_PATH+'*'))

# Prep Train CSV
df = (
    pd.read_csv(ROOT_PATH+'train_cultivar_mapping.csv')
    .rename(columns={'cultivar': 'target'})
    .sample(frac=1, random_state=SEED)
    .reset_index(drop=True)
)

# Sort the class names so that the label -> id encoding is identical on every run.
# Taking them in the (shuffled) order of `unique()` would give each dataset
# version a different encoding, and that encoding is never recorded anywhere.
labels = np.sort(df.target.unique())
label2ids = {label: idx for idx, label in enumerate(labels)}
df['img_path'] = TRAIN_PATH + df['image']
df['target'] = df['target'].map(label2ids).astype(int)

# Prep Test CSV
test_df = (
    pd.read_csv(ROOT_PATH+'sample_submission.csv')
    .rename(columns={'filename': 'image'})
)
test_df['img_path'] = TEST_PATH + test_df['image']

In [None]:
# Peek at the prepared train frame (shuffled, with integer `target` and `img_path`).
df.head(3)

In [None]:
# Peek at the test frame; `image` is the sample submission's renamed `filename` column.
test_df.head(3)

In [None]:
# remove TRAIN missing file
# Keep only rows whose image path appears in the train directory listing.
# `isin` hashes the listing, which avoids scanning the whole list for every row.
train_present = df.img_path.isin(train_files)
missing_files = df.img_path[~train_present].tolist()
print('missing files num: ', len(missing_files))

df = df[train_present].reset_index(drop=True)
# Fixed seed so that the shuffled order (and therefore each row's index) is reproducible.
df = df.sample(frac=1, random_state=42).reset_index(drop=True)

df.head(1)

In [None]:
# remove TEST missing file
# Keep only rows whose image path appears in the test directory listing.
# `isin` hashes the listing, which avoids scanning the whole list for every row.
test_present = test_df.img_path.isin(test_files)
missing_files = test_df.img_path[~test_present].tolist()
print('missing files num: ', len(missing_files))

test_df = test_df[test_present].reset_index(drop=True)

test_df.head(1)

In [None]:
def get_stratified_k_fold(df, target, num_folds, seed=42):
    """
    Return a copy of ``df`` with a stratified integer ``fold`` column.

    Every row is placed in exactly one validation fold, so each row receives
    a fold id in ``[0, num_folds)``. ``df`` itself is not modified.

    Arguments:
    df: Dataframe (any index; rows are addressed by position)
    target: Array-like of labels to stratify on, aligned with the rows of ``df``
    num_folds: Number of folds
    seed: random_state for the shuffle inside StratifiedKFold

    Returns:
    A new dataframe equal to ``df`` plus an int ``fold`` column.
    """
    kfold = StratifiedKFold(num_folds, shuffle=True, random_state=seed)

    # split() yields *positional* indices. Fill a positional array rather than
    # using df.loc, which is only correct when the index happens to be 0..n-1.
    folds = np.empty(len(df), dtype=int)
    for fold, (_, valid_indices) in enumerate(kfold.split(df, target)):
        folds[valid_indices] = fold

    return df.assign(fold=folds)

# Attach a 5-fold assignment stratified on the integer class id.
df = get_stratified_k_fold(df, target=df['target'].to_numpy(), num_folds=5)
df.head(1)

In [None]:
# Number of test images per TFRecord shard.
num_samples = 4096
# Ceiling division: the final, partial shard is counted instead of being dropped.
num_tfrecords = (len(test_df) + num_samples - 1) // num_samples
print(num_tfrecords)

# Number of CV folds; the train data gets one TFRecord shard per fold.
num_folds = 5

In [None]:
def image_feature(value, quality=95):
    """Return a bytes_list Feature holding ``value`` encoded as JPEG.

    Arguments:
    value: image tensor accepted by ``tf.io.encode_jpeg`` — presumably the
        decoded (height, width, channels) image from ``tf.io.decode_jpeg``;
        TODO confirm for non-RGB sources.
    quality: JPEG quality (0-100) forwarded to ``tf.io.encode_jpeg``; 95 is
        that function's default, so existing callers are unaffected.

    Note: the image is re-encoded, which is lossy even when the source file was not.
    """
    encoded = tf.io.encode_jpeg(value, quality=quality)
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[encoded.numpy()]))


def bytes_feature(value):
    """Return a bytes_list Feature from a ``str`` (UTF-8 encoded) or ``bytes`` value."""
    # Previously `.encode()` was called unconditionally, so the documented
    # bytes input raised AttributeError.
    if isinstance(value, str):
        value = value.encode()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def int64_feature(value):
    """Wrap a single bool / enum / int / uint value in an int64_list Feature."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)


def create_train_example(image, target, fold, index, img_name):
    """Build a ``tf.train.Example`` for one labelled training image.

    Keys written: image (JPEG bytes), target, fold, csv_index, image_name.
    """
    features = tf.train.Features(feature=dict(
        image=image_feature(image),
        target=int64_feature(target),
        fold=int64_feature(fold),
        csv_index=int64_feature(index),
        image_name=bytes_feature(img_name),
    ))
    return tf.train.Example(features=features)

def parse_train_tfrecord_fn(example):
    """Parse one serialized training Example and decode its image with 3 channels.

    Returns a dict with keys image, target, fold, csv_index and image_name.
    """
    string_spec = tf.io.FixedLenFeature([], tf.string)
    int_spec = tf.io.FixedLenFeature([], tf.int64)
    schema = {
        "image": string_spec,
        "target": int_spec,
        "fold": int_spec,
        "csv_index": int_spec,
        "image_name": string_spec,
    }
    parsed = tf.io.parse_single_example(example, schema)
    parsed["image"] = tf.io.decode_jpeg(parsed["image"], channels=3)
    return parsed


def create_test_example(image, index, img_name):
    """Build a ``tf.train.Example`` for one unlabelled test image.

    Keys written: image (JPEG bytes), csv_index, image_name.
    """
    features = tf.train.Features(feature=dict(
        image=image_feature(image),
        csv_index=int64_feature(index),
        image_name=bytes_feature(img_name),
    ))
    return tf.train.Example(features=features)

def parse_test_tfrecord_fn(example):
    """Parse one serialized test Example and decode its image with 3 channels.

    Returns a dict with keys image, csv_index and image_name.
    """
    string_spec = tf.io.FixedLenFeature([], tf.string)
    schema = {
        "image": string_spec,
        "csv_index": tf.io.FixedLenFeature([], tf.int64),
        "image_name": string_spec,
    }
    parsed = tf.io.parse_single_example(example, schema)
    parsed["image"] = tf.io.decode_jpeg(parsed["image"], channels=3)
    return parsed

In [None]:
# List the ../tmp scratch area where the dataset is staged.
!ls ../tmp

In [None]:
train_tfrecords_dir = f'../tmp/{DATASET_NAME}/train'
test_tfrecords_dir = f'../tmp/{DATASET_NAME}/test'

os.makedirs(train_tfrecords_dir, exist_ok=True)
os.makedirs(test_tfrecords_dir, exist_ok=True)

!ls ../tmp/{DATASET_NAME}

In [None]:
# Write one TFRecord shard per CV fold. The sample count is part of the file
# name (train_fold_<fold>-<count>.tfrec).
# NOTE(review): `idx` (stored as csv_index) is the row label of the in-memory,
# shuffled and filtered `df`, not the line number in train_cultivar_mapping.csv.
# That frame is never saved, so consider persisting it if csv_index matters downstream.
for fold in range(num_folds):
    fold_df = df[df.fold == fold]
    shard_path = train_tfrecords_dir + "/train_fold_%.2i-%i.tfrec" % (fold, len(fold_df))

    with tf.io.TFRecordWriter(shard_path) as writer:
        # iterrows() has no len(); total= lets tqdm show a real progress bar.
        rows = tqdm(fold_df.iterrows(), total=len(fold_df), desc=f"fold {fold}")
        for idx, row in rows:
            image = tf.io.decode_jpeg(tf.io.read_file(row.img_path))
            example = create_train_example(image, row.target, row.fold, idx, row.image)
            writer.write(example.SerializeToString())

In [None]:
# List the written train shards (one per fold, sample count in the file name).
!ls ../tmp/{DATASET_NAME}/train

In [None]:
# Sanity check: decode the first example of the first train shard and display it.
# A dedicated name avoids overwriting `train_files` (the image-path list that the
# missing-file cell checks against) if cells are re-run out of order.
train_tfrec_files = sorted(glob.glob(f"{train_tfrecords_dir}/*.tfrec"))
assert train_tfrec_files, f"no .tfrec shards found in {train_tfrecords_dir}"
raw_dataset = tf.data.TFRecordDataset(train_tfrec_files[0])
parsed_dataset = raw_dataset.map(parse_train_tfrecord_fn)

for features in parsed_dataset.take(1):
    for key, value in features.items():
        if key != "image":
            print(f"{key}: {value}")

    print(f"Image shape: {features['image'].shape}")
    fig, ax = plt.subplots(figsize=(7, 7))
    ax.imshow(features["image"].numpy())
    ax.set_title(features["image_name"].numpy().decode())
    plt.show()

In [None]:
# Split the test frame into consecutive chunks of `num_samples` rows, one
# TFRecord shard per chunk. Positional, end-exclusive `iloc` slices over
# range(0, len, step) cover every row exactly once, including the final partial
# chunk. (`.loc[a:b]` is end-inclusive, so it repeated boundary rows across shards.)
for tfrec_num, start in enumerate(range(0, len(test_df), num_samples)):
    shard_df = test_df.iloc[start : start + num_samples]
    shard_path = test_tfrecords_dir + "/test_%.2i-%i.tfrec" % (tfrec_num, len(shard_df))

    with tf.io.TFRecordWriter(shard_path) as writer:
        # iterrows() has no len(); total= lets tqdm show a real progress bar.
        rows = tqdm(shard_df.iterrows(), total=len(shard_df), desc=f"test shard {tfrec_num}")
        for idx, row in rows:
            image = tf.io.decode_jpeg(tf.io.read_file(row.img_path))
            example = create_test_example(image, idx, row.image)
            writer.write(example.SerializeToString())

In [None]:
# List the written test shards (sample count in the file name).
!ls ../tmp/{DATASET_NAME}/test

In [None]:
# Sanity check: decode the first example of the first test shard and display it.
# A dedicated name avoids overwriting `test_files` (the image-path list that the
# missing-file cell checks against) if cells are re-run out of order.
test_tfrec_files = sorted(glob.glob(f"{test_tfrecords_dir}/*.tfrec"))
assert test_tfrec_files, f"no .tfrec shards found in {test_tfrecords_dir}"
raw_dataset = tf.data.TFRecordDataset(test_tfrec_files[0])
parsed_dataset = raw_dataset.map(parse_test_tfrecord_fn)

for features in parsed_dataset.take(1):
    for key, value in features.items():
        if key != "image":
            print(f"{key}: {value}")

    print(f"Image shape: {features['image'].shape}")
    fig, ax = plt.subplots(figsize=(7, 7))
    ax.imshow(features["image"].numpy())
    ax.set_title(features["image_name"].numpy().decode())
    plt.show()

In [None]:
from datetime import datetime

# Timestamp used as the Kaggle dataset version message (format YYYYmmdd-HHMMSS).
version_name = f"{datetime.now():%Y%m%d-%H%M%S}"
print(version_name)

In [None]:
# Upload the staged directory as a new version of the Kaggle dataset, using the
# timestamp as the version message. Per the Kaggle CLI, `-r zip` uploads the
# sub-directories (train/, test/) as zip archives and `-q` suppresses progress output.
!kaggle datasets version -m {version_name} -p ../tmp/{DATASET_NAME} -r zip -q