# Import modules

In [None]:
import os.path
import json
import codecs
from collections import Counter
import random
import math

import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch.utils.data as D

from sklearn.model_selection import train_test_split

## Load files

In [None]:
# Locations of the image folders and their metadata files in the Kaggle input directory.
TRAIN_PATH = "../input/herbarium-2020-fgvc7/nybg2020/train/"
TRAIN_META_PATH = "../input/herbarium-2020-fgvc7/nybg2020/train/metadata.json"

TEST_PATH = "../input/herbarium-2020-fgvc7/nybg2020/test/"
TEST_META_PATH = "../input/herbarium-2020-fgvc7/nybg2020/test/metadata.json"

SUBMISSION_PATH = '../input/herbarium-2020-fgvc7/sample_submission.csv'


def _read_metadata(meta_path):
    """Parse a metadata JSON file, ignoring bytes that are not valid UTF-8."""
    with codecs.open(meta_path, 'r', encoding='utf-8', errors='ignore') as handle:
        return json.load(handle)


train_meta = _read_metadata(TRAIN_META_PATH)
test_meta = _read_metadata(TEST_META_PATH)

## Quick look and represent data as spreadsheets

### Quick look at the train and test sets (as can be seen, no 'annotations', 'categories', or 'regions' are provided for the test set)

In [None]:
# Show which top-level sections each metadata file contains.
for caption, meta in (('Train keys: ', train_meta), ('Test keys: ', test_meta)):
    print(caption, meta.keys())

### Quick look on training data annotations

In [None]:
# Turn the list of annotation records into a table and show it.
annotation_records = train_meta['annotations']
train_df = pd.DataFrame(data=annotation_records)
display(train_df)

### Quick look at training data categories

In [None]:
# Column names are assigned by position, so this relies on the column order
# produced from the category records.
category_columns = ['family', 'genus', 'category_id', 'category_name']
train_cat = pd.DataFrame(data=train_meta['categories'])
train_cat.columns = category_columns
display(train_cat)

### Quick look on training data images info

In [None]:
# Column names are assigned by position, so this relies on the column order
# produced from the image records.
image_columns = ['file_name', 'height', 'image_id', 'license', 'width']
train_img = pd.DataFrame(data=train_meta['images'])
train_img.columns = image_columns
display(train_img)

### Quick look on training data regions info

In [None]:
# Column names are assigned by position, so this relies on the column order
# produced from the region records.
region_columns = ['region_id', 'region_name']
train_reg = pd.DataFrame(data=train_meta['regions'])
train_reg.columns = region_columns
display(train_reg)

### Merge training data to a single spreadsheet

In [None]:
# Attach category, image and region columns to the annotations, in that order.
# Outer joins keep rows that have no match on either side (filled with NaN).
for lookup_table, join_key in (
    (train_cat, 'category_id'),
    (train_img, 'image_id'),
    (train_reg, 'region_id'),
):
    train_df = train_df.merge(lookup_table, on=join_key, how='outer')

In [None]:
# Show the merged training table (annotations joined with category, image and region columns).
display(train_df)

### Print training dataset info

In [None]:
# Print the 'info' section of the training metadata.
dataset_info = train_meta['info']
print(dataset_info)

### Quick look on test (submission) data 

In [None]:
# The test metadata is tabulated from its 'images' records only.
test_image_records = test_meta['images']
test_df = pd.DataFrame(data=test_image_records)
display(test_df)

### Quick look on submission file

In [None]:
# Load the sample submission to see the expected output format.
sample_sub = pd.read_csv(filepath_or_buffer=SUBMISSION_PATH)
display(sample_sub)

## Plot histograms

### Images HW distribution

In [None]:
# Plot how many training images have each height.
# The outer merges can leave rows without image info, so missing heights are
# dropped explicitly; this also works when the column is integer-typed.
heights = train_df['height'].dropna().astype(int).tolist()
# One bin per distinct height. The bin count comes from the heights
# themselves, so this cell does not depend on the widths cell having run.
h, b = np.histogram(heights, bins=max(len(set(heights)), 1))
fig = plt.figure(figsize=(25, 5))
ax = fig.gca()
# b holds the bin edges (one more than the counts); plot against the right edges.
plt.plot(b[1:], h)
plt.grid()
plt.show()

In [None]:
# Plot how many training images have each width.
# The outer merges can leave rows without image info, so missing widths are
# dropped explicitly; this also works when the column is integer-typed
# (an isinstance(..., float) filter would discard every value in that case).
widths = train_df['width'].dropna().astype(int).tolist()
# One bin per distinct width value.
h, b = np.histogram(widths, bins=max(len(set(widths)), 1))
fig = plt.figure(figsize=(25, 5))
ax = fig.gca()
# b holds the bin edges (one more than the counts); plot against the right edges.
plt.plot(b[1:], h)
plt.grid()
plt.show()

### Distribution category -> image count

In [None]:
# Histogram of category ids using equal-width bins, as many bins as there are
# distinct ids, then plotted from the fullest bin to the emptiest.
category_ids = train_df['category_id']
distinct_categories = len(np.unique(category_ids))
h, b = np.histogram(category_ids, bins=distinct_categories)
h.sort()  # ascending, in place
descending_counts = h[::-1]
fig = plt.figure(figsize=(25, 5))
ax = fig.gca()
plt.plot(descending_counts)
plt.grid()
plt.show()

### Distribution genus -> image count

In [None]:
# Position of the genus column in the merged table.
GENUS_INDEX = 5

# Number of rows per genus, largest first.
genus_column = train_df.iloc[:, GENUS_INDEX]
counts = sorted(Counter(genus_column).values(), reverse=True)

fig = plt.figure(figsize=(25, 5))
ax = fig.gca()
plt.plot(counts)
plt.grid()
plt.show()

### Distribution family -> image count

In [None]:
# Position of the family column in the merged table.
FAMILY_INDEX = 4

# Number of rows per family, largest first.
family_column = train_df.iloc[:, FAMILY_INDEX]
counts = sorted(Counter(family_column).values(), reverse=True)

fig = plt.figure(figsize=(25, 5))
ax = fig.gca()
plt.plot(counts)
plt.grid()
plt.show()

## Prepare simple torch dataset

In [None]:
class HerbariumDataset(D.Dataset):
    """Dataset that yields ``(image, label)`` pairs read from disk.

    Args:
        data: Table with a ``file_name`` column (image path relative to
            ``path``) and a ``category_id`` column (the label).
        path: Directory that the ``file_name`` entries are joined onto.
    """

    def __init__(self, data, path):
        self.data = data
        self.path = path

    def __len__(self):
        """Return the number of rows in the underlying table."""
        return len(self.data)

    def __getitem__(self, i):
        """Load the image at position ``i`` as RGB together with its label.

        Returns:
            Tuple ``(image, label)`` where ``image`` is the decoded picture
            converted from BGR to RGB and ``label`` is the ``category_id``
            value at the same position.

        Raises:
            OSError: If the image file cannot be read or decoded.
        """
        fname = self.data['file_name'].values[i]
        fpath = os.path.join(self.path, fname)
        image = cv2.imread(fpath)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising;
            # fail here with the offending path rather than inside cvtColor.
            raise OSError('Could not read image: {}'.format(fpath))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        label = self.data['category_id'].values[i]

        return image, label

In [None]:
# Hold out part of the labelled training table for evaluation.
train_data, test_data = train_test_split(train_df)

# Both splits come from the training table, so both read images from TRAIN_PATH.
train_dataset, test_dataset = (
    HerbariumDataset(split, TRAIN_PATH) for split in (train_data, test_data)
)

In [None]:
# Show one random training sample with its label.
# randrange excludes the upper bound, so the index is always valid
# (randint(0, len(...)) could return len(...) and raise IndexError).
img, label = train_dataset[random.randrange(len(train_dataset))]
print(label)
plt.imshow(img)

In [None]:
# Show one random held-out sample with its label.
# randrange excludes the upper bound, so the index is always valid
# (randint(0, len(...)) could return len(...) and raise IndexError).
img, label = test_dataset[random.randrange(len(test_dataset))]
print(label)
plt.imshow(img)