# Happy Whale Images Converted to HDF5 for Faster Batch Loading

In [None]:
import h5py
import cv2
import io
from PIL import Image
import os
import numpy as np
from tqdm import tqdm

# paths: directories holding the competition's train and test images;
# both are walked recursively by list_files() further down.
TRAIN_IMAGES = '../input/happy-whale-and-dolphin/train_images'
TEST_IMAGES = '../input/happy-whale-and-dolphin/test_images'



def list_files(gtdir):
    """Recursively collect the path of every file found under ``gtdir``.

    Args:
        gtdir: Directory to walk.

    Returns:
        list: One ``os.path.join(root, name)`` entry per file, in the order
        ``os.walk`` yields them. Empty when nothing is found.
    """
    return [
        os.path.join(parent, name)
        for parent, _subdirs, names in os.walk(gtdir)
        for name in names
    ]

def tohdf5(file_list, out_file_path='train_images.hdf5'):
    """Resize every image to 224x224 RGB and store it in one HDF5 file.

    Each image becomes its own dataset, keyed by the image's base file name,
    so base names must be unique across ``file_list`` (h5py refuses to create
    a dataset whose name already exists).

    Args:
        file_list: Paths of the image files to convert.
        out_file_path: Destination HDF5 file; it is created or truncated.
    """
    print ('=> Converting images to hdf5')
    print ('=> Total Images To Process : {}'.format(len(file_list)))
    # Both context managers guarantee the HDF5 file and the progress bar are
    # closed even if an image fails to load.
    with h5py.File(out_file_path, "w") as h5, tqdm(total=len(file_list)) as pbar:
        for f_ in file_list:
            with Image.open(f_) as source:
                # Store a uniform 3-channel layout: grayscale (and any other
                # non-RGB mode) is converted so every dataset is (224, 224, 3).
                image = source if source.mode == 'RGB' else source.convert('RGB')
                image = np.array(image.resize((224, 224)))
            h5.create_dataset(os.path.basename(f_), data=image)
            # tqdm.update() takes an increment, not the running total.
            pbar.update(1)
    print('=>  Finished Converting images to hdf5')
       
# Build one HDF5 archive per split: (banner, image directory, output file).
for banner, image_dir, h5_name in (
    ('=> ========= Converting Train Images ========= <=', TRAIN_IMAGES, 'train_images.hdf5'),
    ('=> ========= Converting Test Images ========= <=', TEST_IMAGES, 'test_images.hdf5'),
):
    print(banner)
    file_list = list_files(image_dir)
    tohdf5(file_list, out_file_path=h5_name)

# Example DataLoader and Dataset in PyTorch

In [None]:
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import pandas as pd

# Competition CSV files: training metadata and the sample submission.
TRAIN_CSV = '../input/happy-whale-and-dolphin/train.csv'
TEST_CSV = '../input/happy-whale-and-dolphin/sample_submission.csv'

# Directory containing the generated ``*_images.hdf5`` files.
# Default: the current working directory (where the conversion above writes).
# When the files come from an attached input dataset instead, switch to:
#DATASET_ROOT = '../input/happy-whale-to-hdf5-224x224'
DATASET_ROOT = './'

# Load the training metadata into a DataFrame.
train_df = pd.read_csv(TRAIN_CSV)

# Train-time preprocessing: convert to a tensor, then normalize with the
# ImageNet per-channel mean and standard deviation.
train_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


# Label Encoder
def get_label_encoder_decoder(unique_values):
    """Build label -> index and index -> label lookup tables.

    Args:
        unique_values: Iterable of labels; each label is numbered by its
            position in the iteration order.

    Returns:
        tuple: ``(label_encoder, label_decoder)`` where ``label_encoder`` maps
        each label to its index and ``label_decoder`` maps each index back to
        its label. If a label repeats, the encoder keeps its last index.
    """
    # Materialize once so one-shot iterators can feed both dictionaries.
    numbered = list(enumerate(unique_values))
    label_encoder = {name: position for position, name in numbered}
    label_decoder = {position: name for position, name in numbered}
    return label_encoder, label_decoder

# Module-level lookup tables between each distinct `individual_id` in train_df
# and an integer index; the encoder is read by the dataset class when it
# fetches a training item.
label_encoder_ind_id, label_decoder_ind_id = get_label_encoder_decoder(train_df['individual_id'].unique())

# torch dataloader
class DolphinWhaleDatasetH5(Dataset):
    """Dataset that serves images from ``train_images.hdf5`` / ``test_images.hdf5``.

    The HDF5 file is opened lazily, once per process, the first time an image
    is read. An open h5py handle must not be shared with DataLoader worker
    processes (it is not safe to inherit across ``fork`` and cannot be pickled
    for ``spawn``), so each worker opens its own read-only handle instead.

    Args:
        root_dir: Directory containing the ``*_images.hdf5`` files.
        data_frame: DataFrame with an ``image`` column and, when ``is_train``
            is true, an ``individual_id`` column.
        is_train: Selects the train file and labelled items when true,
            otherwise the test file and unlabelled items.
        transforms: Optional callable applied to each image array.

    Raises:
        FileNotFoundError: If the expected HDF5 file does not exist.
    """

    def __init__(self, root_dir, data_frame, is_train=True, transforms=None):
        self.image_names = data_frame['image'].values
        self.is_train = is_train
        if is_train:
            self.labels = data_frame['individual_id'].values
        else:
            # Placeholder labels so that self.labels always exists.
            self.labels = [-1] * len(self.image_names)

        self.transforms = transforms
        print ('=> Reading HDF5 File...')
        hdf5_path = os.path.join(root_dir,'{}_images.hdf5'.format('train' if is_train else 'test'))
        # Fail fast on a missing file even though the actual open is deferred.
        if not os.path.isfile(hdf5_path):
            raise FileNotFoundError('HDF5 file not found: {}'.format(hdf5_path))
        self.hdf5_path = hdf5_path
        self._h5 = None      # h5py.File handle, opened on first access
        self._h5_pid = None  # pid of the process that opened self._h5
        print('=> Dataset created, image hdf5 file is : {}'.format(hdf5_path))

    @property
    def h5(self):
        """Read-only HDF5 handle owned by the current process (opened lazily)."""
        pid = os.getpid()
        # Reopen if nothing is open yet or the handle was inherited from a
        # parent process through fork.
        if self._h5 is None or self._h5_pid != pid:
            self._h5 = h5py.File(self.hdf5_path, 'r')
            self._h5_pid = pid
        return self._h5

    def close(self):
        """Close the HDF5 handle if this process opened one."""
        if self._h5 is not None and self._h5_pid == os.getpid():
            self._h5.close()
        self._h5 = None
        self._h5_pid = None

    def __getstate__(self):
        # Never pickle an open handle; the receiving process reopens lazily.
        state = self.__dict__.copy()
        state['_h5'] = None
        state['_h5_pid'] = None
        return state

    def __len__(self):
        return len(self.image_names)

    def _load_image(self, image_name):
        """Read one image dataset into memory and apply the transforms, if any."""
        image = np.array(self.h5[image_name])
        if self.transforms:
            image = self.transforms(image)
        return image

    def fetch_item_train(self,idx):
        """Return ``{'image', 'label'}`` for a training item.

        The label is the integer index looked up in the module-level
        ``label_encoder_ind_id`` table.
        """
        image_name = self.image_names[idx]
        image = self._load_image(image_name)

        # fetch and encode label
        label = label_encoder_ind_id[self.labels[idx]]

        return {'image':image,
                'label':label,}

    def fetch_item_test(self,idx):
        """Return ``{'image', 'image_name'}`` for a test item."""
        image_name = self.image_names[idx]
        image = self._load_image(image_name)

        return {'image':image,
                'image_name':image_name}

    def __getitem__(self, index):
        if self.is_train:
            return self.fetch_item_train(index)
        else:
            return self.fetch_item_test(index)
    
# Training dataset plus a short smoke test of batched loading
dataset = DolphinWhaleDatasetH5(
    DATASET_ROOT,
    train_df,
    is_train=True,
    transforms=train_transforms,
)

train_loader = DataLoader(dataset, batch_size=4, num_workers=1)

# Print the image-batch shape for at most the first 12 batches, then stop.
for batch_idx, sample_ in enumerate(train_loader):
    inputs = sample_['image']
    print(inputs.shape)
    if batch_idx >= 11:
        break