## Jane Street Market Prediction - EDA

This notebook explores the dataset provided in the __[Jane Street Market Prediction](https://www.kaggle.com/c/jane-street-market-prediction)__ competition.

In [None]:
import json
import os

import pandas as pd
import cv2 as cv

import matplotlib.pyplot as plt
import seaborn as sns

# Every data path is derived from the competition folder, which sits in the
# "input" directory next to this notebook's parent directory.
# NOTE(review): the file names used below (train_images,
# label_num_to_disease_map.json) look like an image-classification dataset
# rather than market data — confirm the competition folder name is correct.
comp_folder = os.path.join(os.pardir, "input", "jane-street-market-prediction")
train_images, labels_csv, strings_json = (
    os.path.join(comp_folder, entry)
    for entry in ("train_images", "train.csv", "label_num_to_disease_map.json")
)

Look at the label data frame and the class imbalance.

In [None]:
# Read the label -> disease-name mapping, dropping the "Cassava " prefix from
# each name. JSON is UTF-8 by specification, so the encoding is given explicitly
# instead of relying on the platform default.
with open(strings_json, encoding="utf-8") as map_file:
    label_to_disease = {label: name.replace("Cassava ", "")
                        for label, name in json.load(map_file).items()}

# Read the labels and add a column with the disease names. JSON object keys are
# always strings, so the label column is cast to str before the lookup.
df = pd.read_csv(labels_csv, index_col="image_id")
df["disease"] = df["label"].astype(str).map(label_to_disease)

# take a look at the structure of train.csv
print(df)

# calculate percentage samples with each disease
percentages = 100 * df["disease"].value_counts(normalize=True, sort=False)

# Plot as a horizontal barplot, drawing explicitly on the axes created here
# rather than on whatever pyplot currently considers the active axes.
fig, ax = plt.subplots(figsize=(8, 5))
sns.barplot(x=percentages, y=percentages.index, orient="h", ax=ax)
ax.set_title("Label Distribution")
ax.xaxis.set_visible(False)
ax.set_frame_on(False)

# The x axis is hidden, so write each bar's percentage next to the bar,
# vertically centred on it.
for rect in ax.patches:
    width = rect.get_width()
    ax.text(width, rect.get_y() + rect.get_height() / 2,
            f"  {round(width, 2)}%", va="center")

Check how many images of each shape we have.

In [None]:
# Count how many training images there are of each (height, width, channels) shape.
shapes = {}
unreadable = []

# os.listdir returns entries in arbitrary order; sorting keeps the order in
# which shapes are first seen (and therefore the printed dict) reproducible.
for file in sorted(os.listdir(train_images)):
    image = cv.imread(os.path.join(train_images, file))

    # cv.imread does not raise on a missing/corrupt/non-image file, it returns
    # None, so record the file instead of crashing on image.shape
    if image is None:
        unreadable.append(file)
        continue

    # increase the counter for that shape
    shapes[image.shape] = shapes.get(image.shape, 0) + 1

# only successfully decoded images are counted here
print(f"Number of training images: {sum(shapes.values())}")
print(f"Shapes of the training images: {shapes}")
if unreadable:
    print(f"Files that could not be read as images: {unreadable}")

Show some sample images for each type of disease.

In [None]:
# Layout of the preview grid and the seed used to pick the sample images.
GRID_ROWS, GRID_COLS = 3, 3
SAMPLE_SEED = 0  # fixed so that re-running the notebook shows the same images

for disease in label_to_disease.values():
    disease_df = df[df["disease"] == disease]

    # sample up to one image per grid cell; DataFrame.sample raises if asked
    # for more rows than the class contains, so cap the request
    n_samples = min(GRID_ROWS * GRID_COLS, len(disease_df))
    if n_samples == 0:
        # nothing to show for a disease that has no rows in the label frame
        continue
    plot_df = disease_df.sample(n_samples, random_state=SAMPLE_SEED)

    fig, axes = plt.subplots(GRID_ROWS, GRID_COLS, figsize=(16, 12))
    axes = axes.ravel()

    for ax, image_id in zip(axes, plot_df.index.values):
        image = cv.imread(os.path.join(train_images, image_id))
        if image is None:
            # cv.imread returns None when the file cannot be decoded
            ax.set_title(f"{disease} (unreadable)")
            continue

        # OpenCV loads channels as BGR; convert to RGB so colours display correctly
        ax.imshow(cv.cvtColor(image, cv.COLOR_BGR2RGB))
        ax.set_title(disease)

    # hide ticks and frames on every cell, including ones left empty because
    # the class had fewer images than the grid has cells
    for ax in axes:
        ax.axis("off")

    plt.show()