In [None]:
# import tensorflow as tf
import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import tempfile
import pprint

from PIL import Image, ImageOps
from sklearn.model_selection import KFold, GroupKFold, StratifiedKFold
from tqdm import tqdm

import cv2
import glob
import io
import os
import yaml

import IPython.display as display

%matplotlib inline

In [None]:
class CFG:
    """Run-wide configuration constants for this notebook."""

    # Reproducibility / cross-validation
    SEED = 6718
    N_SPLITS = 5

    # Training schedule
    EPOCHS = 5
    BATCH_SIZE = 32  # noted as "REPLICAS * 32"; presumably scale by device count — TODO confirm

    # Image size and output location
    IMG_SIZE = 512
    OUTPUT_DIR = ''

In [None]:
import albumentations as A
import albumentations.pytorch.transforms as T


# Per-channel normalization statistics in RGB order (these match the commonly
# used ImageNet mean/std values).
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

_train_augs = [
    A.HorizontalFlip(p=0.5),
    A.RandomResizedCrop(CFG.IMG_SIZE, CFG.IMG_SIZE),
    A.Cutout(max_h_size=56, max_w_size=56, num_holes=5, p=0.5),
    A.Normalize(mean, std),
]

# The 'val' pipeline only resizes. CreateTFRecordDataset applies it and the
# result is handed to tf.io.encode_jpeg, which needs uint8 pixels, so
# Normalize (and tensor conversion) are intentionally not applied here.
_val_augs = [
    A.Resize(CFG.IMG_SIZE, CFG.IMG_SIZE),
]

transform = {
    'train': A.Compose(_train_augs),
    'val': A.Compose(_val_augs),
}

In [None]:
# Kaggle competition data; every path below lives under ROOT_PATH.
ROOT_PATH = '/kaggle/input/plant-pathology-2020-fgvc7/'
TRAIN_PATH, TEST_PATH, SUB_PATH = (
    f'{ROOT_PATH}{stem}.csv' for stem in ('train', 'test', 'sample_submission')
)
IMG_PATH = f'{ROOT_PATH}images/'

# Target columns read from train.csv (and filled with placeholders for test).
LABELS = ['healthy', 'multiple_diseases', 'rust', 'scab']

In [None]:
train = pd.read_csv(TRAIN_PATH)

# Assign every row to one cross-validation fold. Note: KFold.split ignores `y`
# (it is accepted only for API compatibility), so these folds are NOT
# stratified by label.
y = train[LABELS].values
splitter = KFold(n_splits=CFG.N_SPLITS, random_state=CFG.SEED, shuffle=True)

fold_of_row = np.zeros(len(train), dtype=int)
for fold, (_, valid_rows) in enumerate(splitter.split(train, y)):
    fold_of_row[valid_rows] = fold
train['kfold'] = fold_of_row

print(train.shape)
train.head()

In [None]:
# Test rows get placeholder label columns (-1) so they share the train schema
# and can go through the same dataset / serialization path.
test = pd.read_csv(TEST_PATH).assign(**{label: -1 for label in LABELS})

print(test.shape)
test.head()

In [None]:
# Helpers that turn (image, label) pairs into serialized tf.train.Example records.


def _bytes_feature(value):
    """Return a ``tf.train.Feature`` holding a single bytes value.

    ``value`` may be ``bytes`` or an eager string tensor. A tensor is unwrapped
    with ``.numpy()`` first because ``BytesList`` won't unpack a string from an
    ``EagerTensor``.
    """
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _float_feature(value):
    """Return a ``tf.train.Feature`` holding a single float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


def _int64_feature(value):
    """Return a ``tf.train.Feature`` holding a single bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _int64_list_feature(value):
    """Return a ``tf.train.Feature`` holding a sequence of integers.

    ``value`` is already array-like, so it is passed to ``Int64List`` without
    the extra ``[]`` wrapping used by the scalar helpers. An eager tensor is
    unwrapped with ``.numpy()`` first, consistent with ``_bytes_feature``.
    """
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def serialize_example(feature0, feature1):
    """Serialize one ``(image, label)`` pair into a ``tf.train.Example`` string.

    Args:
        feature0: image tensor accepted by ``tf.io.encode_jpeg`` (uint8 pixels);
            stored JPEG-encoded under the ``'image'`` key.
        feature1: array-like of integer labels, stored under the ``'label'`` key.

    Returns:
        The serialized ``tf.train.Example`` as ``bytes``.
    """
    feature = {
        # encode_jpeg yields a scalar string tensor; _bytes_feature unwraps it.
        'image': _bytes_feature(tf.io.encode_jpeg(feature0)),
        'label': _int64_list_feature(feature1),
    }
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()


def generator(dataset=None):
    """Yield a serialized example for every ``(image, label)`` pair in ``dataset``.

    Args:
        dataset: iterable of ``(image, label)`` pairs. When omitted (which is how
            ``tf.data.Dataset.from_generator(generator, ...)`` calls it), the
            module-level ``features_dataset`` is used, looked up at call time.
    """
    source = features_dataset if dataset is None else dataset
    for features in source:
        yield serialize_example(*features)

In [None]:
import random
import torch
import torch.nn as nn


class CreateTFRecordDataset(torch.utils.data.Dataset):
    """Map-style dataset that loads images for TFRecord serialization.

    Each item is read from disk with OpenCV, converted from OpenCV's BGR channel
    order to RGB, and passed through ``transform['val']``.

    Args:
        paths: sequence of image file paths (list, array or pandas Series).
        y: optional per-image targets, indexed in the same order as ``paths``.
            When given, items are ``(image, target)`` pairs; otherwise just
            the image.
    """

    def __init__(self, paths, y=None):
        # Store positionally: indexing a pandas Series in __getitem__ would be
        # label-based and break for inputs without a 0..n-1 RangeIndex.
        self.paths = list(paths)
        self.y = y

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, item):
        path = self.paths[item]
        image = cv2.imread(path)
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning None.
            raise FileNotFoundError(f"could not read image: {path}")
        # OpenCV decodes to BGR; tf.io.encode_jpeg and matplotlib expect RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = transform['val'](image=image)['image']

        if self.y is not None:
            return image, self.y[item]
        return image

In [None]:
# Quick dry-run switch: set to a small integer (e.g. 64) to process only the
# first N rows of train/test; None processes the full data.
DEBUG_N_ROWS = None
if DEBUG_N_ROWS:
    train = train.head(DEBUG_N_ROWS)
    test = test.head(DEBUG_N_ROWS)

In [None]:
def _image_paths(frame):
    """Full JPEG path for every ``image_id`` in ``frame``."""
    return IMG_PATH + frame['image_id'] + '.jpg'


train_paths, train_labels = _image_paths(train), train[LABELS].values
test_paths, test_labels = _image_paths(test), test[LABELS].values

In [None]:
train_dataset = CreateTFRecordDataset(train_paths, train_labels)
test_dataset = CreateTFRecordDataset(test_paths, test_labels)

# Both loaders share settings; shuffle=False keeps batches in CSV row order,
# so shard `step` holds the `step`-th batch of rows.
loader_kwargs = dict(
    shuffle=False,
    batch_size=CFG.BATCH_SIZE,
    num_workers=0,
    pin_memory=True,
)
train_loader = torch.utils.data.DataLoader(train_dataset, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, **loader_kwargs)

In [None]:
# Write one TFRecord shard per DataLoader batch.
for step, (batch_images, batch_labels) in enumerate(tqdm(train_loader, total=len(train_loader))):
    # `generator` serializes whatever the global `features_dataset` holds,
    # so this name must be rebound before each shard is built.
    features_dataset = tf.data.Dataset.from_tensor_slices(
        (batch_images.numpy(), batch_labels.numpy())
    )
    serialized_features_dataset = tf.data.Dataset.from_generator(
        generator, output_types=tf.string, output_shapes=())

    filename = f'train_{step}.tfrec'
    print(f"{filename} writing ...")
    tf.data.experimental.TFRecordWriter(filename).write(serialized_features_dataset)

In [None]:
# Write one TFRecord shard per DataLoader batch (labels are the -1 placeholders).
for step, (batch_images, batch_labels) in enumerate(tqdm(test_loader, total=len(test_loader))):
    # `generator` serializes whatever the global `features_dataset` holds,
    # so this name must be rebound before each shard is built.
    features_dataset = tf.data.Dataset.from_tensor_slices(
        (batch_images.numpy(), batch_labels.numpy())
    )
    serialized_features_dataset = tf.data.Dataset.from_generator(
        generator, output_types=tf.string, output_shapes=())

    filename = f'test_{step}.tfrec'
    print(f"{filename} writing ...")
    tf.data.experimental.TFRecordWriter(filename).write(serialized_features_dataset)

In [None]:
sample_image, _ = train_dataset[0]
plt.imshow(sample_image)

In [None]:
# os.listdir returns entries in arbitrary order, so sort for a reproducible
# record order; match the shard extension rather than any 'tfrec' substring.
tfrec_paths = sorted(p for p in os.listdir() if p.endswith('.tfrec'))

In [None]:
# Fail early with a clear message: an empty filename list otherwise surfaces
# as a less obvious dtype TypeError from TFRecordDataset.
if not tfrec_paths:
    raise FileNotFoundError(
        'No .tfrec files found in the working directory; run the TFRecord-writing cells first.')
raw_dataset = tf.data.TFRecordDataset(tfrec_paths)
raw_dataset

In [None]:
# Create a description of the features.
feature_description = {
    'image': tf.io.FixedLenFeature([], tf.string, default_value=''),
    'label': tf.io.FixedLenSequenceFeature([], tf.int64, default_value=0, allow_missing=True),
}

def _parse_function(example_proto):
  # Parse the input `tf.train.Example` proto using the dictionary above.
  return tf.io.parse_single_example(example_proto, feature_description)

In [None]:
def _decode_image_function(example):
    """Turn a parsed example dict into an ``(image, label)`` pair.

    The JPEG bytes are decoded to a 3-channel image and rescaled from
    ``[0, 255]`` integers to ``float32`` in ``[0, 1]``; the label passes
    through unchanged.
    """
    pixels = tf.cast(tf.image.decode_jpeg(example['image'], channels=3), tf.float32)
    return pixels / 255.0, example['label']

In [None]:
# serialized records -> feature dict -> (float image, label) pairs
parsed_dataset = (
    raw_dataset
    .map(_parse_function)
    .map(_decode_image_function)
)
parsed_dataset

In [None]:
# Preview a few decoded records side by side. Each record gets its own axes:
# calling plt.imshow repeatedly on the current axes draws every image into
# the same axes, leaving only the last one visible.
N_PREVIEW = 2
fig, axes = plt.subplots(1, N_PREVIEW, figsize=(5 * N_PREVIEW, 5), squeeze=False)
for ax, (image, labels) in zip(axes[0], parsed_dataset.take(N_PREVIEW)):
    image = image.numpy()
    labels = labels.numpy()

    print(image.shape)
    print(labels, labels.shape)

    ax.imshow(image)
    ax.set_title(f'label = {labels}')
    ax.axis('off')
plt.show()