---
---
# 1) Exploratory Data Analysis
This notebook contains an adventurous dive into the Overhead-MNIST satellite image data set. Poised to become the benchmark satellite image data set, it serves as an orbital analogy to the famous MNIST handwritten digit pictures. 


----
# 2) Installs & Imports
* An accelerated runtime is not required

In [None]:
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import rcParams
%matplotlib inline

# Shared figure defaults so every plot in the notebook has the same look
rcParams.update({
    'figure.facecolor': 'lightgray',
    'figure.figsize': (13, 5),
})

# Fixed random-state value, echoed so a run can be reproduced
rs = 42
print('Random state: ', rs)

---
# 3) Load & View Data
* Check for missing values

In [None]:
# Read train.csv into a DataFrame
train = pd.read_csv('../input/overheadmnist/version2/train.csv')

# Report missing values BEFORE dropping rows: counting after the drop
# always yields 0 and would hide any gaps in the raw file
print('Missing: ', train.isna().sum().sum())

# Drop incomplete rows without mutating in place, so the cell is idempotent
train = train.dropna(axis = 0)

# Per-class metadata table
classes = pd.read_csv('../input/overheadmnist/version2/classes.csv')

train.head()

In [None]:
# Display the table loaded from classes.csv
classes

In [None]:
# Per-class training counts (the `train_count` column) as a NumPy array
class_lengths = np.array(list(classes['train_count'].to_numpy()))
class_lengths

---
# 4) Data Exploration
Useful plotting libraries: 
 > * matplotlib.pyplot
 > * pandas (built-in pyplot)
 > * seaborn

In [None]:
# Total pixel intensity per class (largest first), plus a per-image normalization
label_groups = train.groupby('label')
tot_sums = (
    label_groups.sum()
    .sum(axis = 1)
    .sort_values(ascending = False)
    .to_frame('sum')
)

# Divide by the number of training images in each class. Both operands are
# indexed by label, so pandas aligns them by class. A positional division by
# an array in classes.csv order would pair the sum-sorted rows with the wrong
# class sizes.
tot_sums['norm_sum'] = tot_sums['sum'] / label_groups.size()

print(tot_sums)

# One horizontal bar chart per column
for column, title in (('sum', 'Total Pixel Sums (raw)'),
                      ('norm_sum', 'Total Pixel Sums (normalized)')):
    ax = tot_sums[column].plot.barh(ec = 'k')
    ax.set_title(title)
    plt.show()

### Remarks
The normalized sums account for difference in sample sizes.

## Heatmaps
Color is an excellent way to add information to a plot. 
Homogeneous clusters, repeating patterns, and color gradients can all portray value information.

### Sum Correlation & Covariance

In [None]:
# Aggregate once, then draw one titled heatmap per pixel-by-pixel matrix
sum_by_label = train.groupby('label').sum()

for compute, title in ((sum_by_label.corr, 'Pixel Correlation of Per-Class Sums'),
                       (sum_by_label.cov, 'Pixel Covariance of Per-Class Sums')):
    fig, ax = plt.subplots(figsize = (14, 7))
    sns.heatmap(compute(), ax = ax)
    # Title each figure so the two heatmaps can be told apart
    ax.set_title(title)

### Median Correlation & Covariance

In [None]:
# Aggregate once, then draw one titled heatmap per pixel-by-pixel matrix
median_by_label = train.groupby('label').median()

for compute, title in ((median_by_label.corr, 'Pixel Correlation of Per-Class Medians'),
                       (median_by_label.cov, 'Pixel Covariance of Per-Class Medians')):
    fig, ax = plt.subplots(figsize = (14, 7))
    sns.heatmap(compute(), ax = ax)
    # Title each figure so the two heatmaps can be told apart
    ax.set_title(title)

### Mean Correlation & Covariance

In [None]:
# Aggregate once, then draw one titled heatmap per pixel-by-pixel matrix
mean_by_label = train.groupby('label').mean()

for compute, title in ((mean_by_label.corr, 'Pixel Correlation of Per-Class Means'),
                       (mean_by_label.cov, 'Pixel Covariance of Per-Class Means')):
    fig, ax = plt.subplots(figsize = (14, 7))
    sns.heatmap(compute(), ax = ax)
    # Title each figure so the two heatmaps can be told apart
    ax.set_title(title)

### Standard Deviation Correlation & Covariance

In [None]:
# Aggregate once, then draw one titled heatmap per pixel-by-pixel matrix
std_by_label = train.groupby('label').std()

for compute, title in ((std_by_label.corr, 'Pixel Correlation of Per-Class Standard Deviations'),
                       (std_by_label.cov, 'Pixel Covariance of Per-Class Standard Deviations')):
    fig, ax = plt.subplots(figsize = (14, 7))
    sns.heatmap(compute(), ax = ax)
    # Title each figure so the two heatmaps can be told apart
    ax.set_title(title)

### Pearson r Correlation

In [None]:
# Pearson correlation between every pair of columns, on a fixed [-1, 1] color scale
fig, ax = plt.subplots(figsize = (14, 7))
sns.heatmap(train.corr(), vmin = -1, vmax = 1, ax = ax)
_ = ax.set_title('Heatmap of Pearson r Correlation Matrix')

### Distributions

In [None]:
# Helper function
def class_hist(df, figsize = (14, 14)):
    """Plot a value histogram for one sample row of every class.

    For each distinct value in ``df['label']`` (ascending order), the first
    row carrying that label is taken, its ``label`` entry is dropped, and the
    remaining values are histogrammed with 255 bins on that class's own
    subplot row. The finished figure is shown with ``plt.show()``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``label`` column; all other columns are treated as
        the values to histogram.
    figsize : tuple of float, optional
        Size of the whole figure in inches. Defaults to ``(14, 14)``, the
        per-plot size the earlier implementation passed to ``plot.hist``
        (which appears to have overridden its initial ``(14, 24)`` figure
        size; confirm the rendered size if layout matters).

    Returns
    -------
    None
    """
    vals = np.sort(df['label'].unique())
    if len(vals) == 0:
        # Nothing to draw; plt.subplots() rejects a grid with zero rows
        return

    # One subplot row per class. squeeze=False keeps `axes` two-dimensional
    # even when there is only a single class.
    fig, axes = plt.subplots(len(vals), 1, figsize = figsize,
                             tight_layout = True, squeeze = False)
    for ax, clss in zip(axes[:, 0], vals):
        # First sample of this class, without its label entry
        sample = df[df['label'] == clss].drop('label', axis = 1).iloc[0, :]
        # Draw on the explicit axes instead of relying on the current-axes state
        sample.plot.hist(bins = 255, ax = ax, edgecolor = 'k')
        ax.set_ylabel('Class {}'.format(clss))
    fig.suptitle('Sample Class Histograms')
    plt.show()


In [None]:
# Draw the per-class sample histograms for the training data
class_hist(train)

### Cumulative Distributions

In [None]:
# Empirical CDF of the per-pixel sums, one figure per class
pixel_sums = train.groupby('label').sum().transpose()
for label in pixel_sums.columns:
    sns.displot(pixel_sums[label], kind = 'ecdf', height = 5, aspect = 2.4)

## Means

In [None]:
# Overall mean per class: average each pixel within a class, then average across pixels
class_means = train.groupby('label').mean().mean(axis = 1)
ax = class_means.plot.bar(ec = 'k')
ax.set_title('Class Pixel Means')
ax.set_xlabel('Class')
ax.set_ylabel('Pixel Mean')
plt.show()

In [None]:
# Per-class mean images, pixels as rows and classes as columns.
# The name `tmp` is kept on purpose: other cells may still read it.
tmp = train.groupby('label').mean().transpose()
grid = sns.displot(data = tmp, bins = 255, height = 5, aspect = 2.6, alpha = .5)
plt.title('Pixel Mean Distributions')

In [None]:
# Summary statistics (taken over pixels) of each class's mean image; one row per class
mean_image_stats = train.groupby('label').mean().T.describe()
means = mean_image_stats.loc[['mean', 'min', 'max', 'std']].T
means

In [None]:
# These plots are built from `means`, so the heading and titles say "Means"
# (they previously said "Medians", which mislabelled the figures)
print('Pixel means by class:')
mean_levels = means.drop('std', axis = 1).sort_values(by = 'mean', ascending = False)
ax = mean_levels.T.plot.bar(ec = 'k')
_ = ax.set_title('Trends & Groups in Pixel Means')

# Spread of each class's mean image, largest first
mean_spread = means[['std']].sort_values(by = 'std', ascending = False)
ax = mean_spread.T.plot.bar(ec = 'k')
_ = ax.set_title('Trends in Means Standard Deviation')

In [None]:
# Heatmap of the mean value of every pixel, one row per class
fig, ax = plt.subplots(figsize = (17, 9))
sns.heatmap(train.groupby('label').mean(), ax = ax)
ax.set_title('Mean Pixel Value by Class')

## Sums

In [None]:
# Summary statistics (taken over pixels) of each class's summed image; one row per class
sum_image_stats = train.groupby('label').sum().T.describe()
sums = sum_image_stats.loc[['mean', 'min', 'max', 'std']].T
sums

In [None]:
print('Pixel sums by class:')
# mean / min / max of each class's pixel sums, classes ordered by mean
sum_levels = sums.drop('std', axis = 1).sort_values(by = 'mean', ascending = False)
ax = sum_levels.T.plot.bar(ec = 'k')
_ = ax.set_title('Trends & Groups in Pixel Sums')

# Spread of each class's pixel sums, largest first
sum_spread = sums[['std']].sort_values(by = 'std', ascending = False)
ax = sum_spread.T.plot.bar(ec = 'k')
_ = ax.set_title('Trends in Sums Standard Deviation')


In [None]:
# Heatmap of the summed value of every pixel, one row per class
fig, ax = plt.subplots(figsize = (17, 9))
sns.heatmap(train.groupby('label').sum(), ax = ax)
_ = ax.set_title('Pixel Sums by Class')

## Medians

In [None]:
# Summary statistics (taken over pixels) of each class's median image; one row per class
median_image_stats = train.groupby('label').median().T.describe()
medians = median_image_stats.loc[['mean', 'max', 'min', 'std']].T
medians

In [None]:
print('Pixel medians by class:')
# mean / max / min of each class's pixel medians, classes ordered by mean
median_levels = medians.drop('std', axis = 1).sort_values(by = 'mean', ascending = False)
ax = median_levels.T.plot.bar(ec = 'k')
_ = ax.set_title('Trends & Groups in Pixel Medians')

# Spread of each class's pixel medians, largest first
median_spread = medians[['std']].sort_values(by = 'std', ascending = False)
ax = median_spread.T.plot.bar(ec = 'k')
_ = ax.set_title('Trends in Median Standard Deviation')

In [None]:
# Heatmap of the median value of every pixel, one row per class
fig, ax = plt.subplots(figsize = (17, 9))
sns.heatmap(train.groupby('label').median(), ax = ax)
ax.set_title('Median Pixel Value by Class')

In [None]:
# Distribution of each class's median pixel values (histogram + KDE), one figure per class.
# The frame is built here from the medians rather than read from a leftover
# variable, so this Medians cell cannot silently plot another section's data.
median_pixels = train.groupby('label').median().T
for label in median_pixels.columns:
    sns.displot(median_pixels[label], bins = 255, kde = True, height = 5, aspect = 2.4)

## Pixel Ranges

In [None]:
# Per-pixel value range (max - min) within each class; pixels as rows, classes as columns
by_label = train.groupby('label')
pix_range = (by_label.max() - by_label.min()).T
# Summary of each class's ranges across pixels; one row per class
ranges = pix_range.describe().T[['mean', 'max', 'min', 'std']]
pix_range.T

In [None]:
# Display the per-class summary (mean, max, min, std) of the pixel ranges
ranges

In [None]:
# Heatmap of the per-class pixel ranges; the color scale is limited to [219, 255]
fig, ax = plt.subplots(figsize = (17, 9))
sns.heatmap(pix_range.T, cmap = 'rocket', vmin = 219, vmax = 255, ax = ax)
_ = ax.set_title('Heatmap of Pixel Ranges')

In [None]:
# mean / max / min of each class's pixel ranges, classes ordered by mean
range_levels = ranges.drop('std', axis = 1).sort_values(by = 'mean', ascending = False)
ax = range_levels.T.plot.bar(ec = 'k')
_ = ax.set_title('Trends & Groups in Pixel Range')

# Spread of each class's pixel ranges, largest first
range_spread = ranges[['std']].sort_values(by = 'std', ascending = False)
ax = range_spread.T.plot.bar(ec = 'k')
_ = ax.set_title('Trends in Range Standard Deviation')

## Scatter Plots
With 784 pixels there are 306,936 distinct pixel pairs (784 choose 2), so visualizing them all is beyond the scope of this notebook. A subsequent work will explore these relationships in detail if required. Arbitrary pixel pairs and those of possible interest (e.g. the first and last pixel) are presented. Strong linear relationships and/or near-zero variance can be used to guide feature selection.

In [None]:
# Pixel pairs to compare: one figure per entry, one panel per (x, y) pair
scatter_rows = [
    ('Pixel Correlation Examples',
     [('pixel1', 'pixel2'), ('pixel1', 'pixel784'), ('pixel100', 'pixel600')]),
    ('Pixel Correlation Examples 2',
     [('pixel2', 'pixel200'), ('pixel500', 'pixel505'), ('pixel392', 'pixel784')]),
]

for title, pairs in scatter_rows:
    fig, axes = plt.subplots(1, len(pairs), sharey = True, figsize = (13, 5))
    for ax, (x_col, y_col) in zip(axes, pairs):
        # `s` sets the marker size. `size` is seaborn's semantic size variable:
        # passing the constant 20 there adds a "20" legend entry and does not
        # set the marker size.
        sns.scatterplot(data = train, x = x_col, y = y_col, s = 20,
                        hue = 'label', palette = 'Spectral', alpha = .4, ax = ax)
    fig.suptitle(title)
    plt.show()

In [None]:
# Thin the data for the pair plot: keep every 10th row and every 16th column
ROW_STEP, COL_STEP = 10, 16
datum = train.iloc[::ROW_STEP, ::COL_STEP]
datum

In [None]:
# pairplot is a figure-level function that builds its own figure, so no
# plt.figure(...) call precedes it: one would only emit an extra blank figure
# and would not size the grid. Use pairplot's `height` argument to resize facets.
sns.pairplot(datum, hue = 'label', palette = 'Spectral')
plt.show()

---
# 5) Conclusion
Two classes show pixel ranges that are clearly different from the rest of the group. The difference in range standard deviation is notable, as well. Patterns observed in heatmaps might further be exploited by removing highly correlated partners and more direct feature engineering if required. These patterns bear resemblance to bar codes, or perhaps discrete spectra which can be exploited for increased accuracy or reduced processing time in situ.
* Class separation characteristics apparent when aggregating by groups
* Some pixels show correlation and might permit removal

## Next Steps
* Model optimizations
* Feature engineering
---
---