In [None]:
# Plot correlation heatmaps for several sklearn datasets and inspect
# highly correlated variable pairs further. When two features are strongly
# correlated (collinear), we may be able to drop one of them and so reduce
# the number of features in the dataset.
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn import datasets

In [None]:
def plot_correlation_heatmap(data, title, ax=None, annot=True, figsize=None):
    """
    Plot a Pearson correlation heatmap of a dataset's features and its target.

    Parameters
    ----------
    data : Bunch-like object
        Object exposing ``data`` (2-D feature array), ``feature_names`` and
        ``target``, such as the objects returned by ``sklearn.datasets.load_*``.
    title : str
        Title placed on the heatmap axes.
    ax : matplotlib Axes, optional
        Axes to draw on. If omitted, a new figure of size ``figsize`` is
        created when ``figsize`` is given; otherwise the current axes are
        used, so a preceding ``plt.figure(figsize=...)`` still takes effect.
    annot : bool, default True
        Write the correlation coefficient inside each cell. Consider turning
        this off for datasets with many features, where the numbers overlap.
    figsize : tuple of float, optional
        Figure size, only used when ``ax`` is not given.
    """
    df = pd.DataFrame(data.data, columns=data.feature_names)
    # NOTE: the target is an integer class label; its Pearson correlation with
    # the features is only meaningful when the labels are binary or ordered.
    df['target'] = data.target

    if ax is None:
        if figsize is not None:
            _, ax = plt.subplots(figsize=figsize)
        else:
            ax = plt.gca()

    # Anchor the diverging colormap to the full [-1, 1] correlation range with
    # 0 at the neutral midpoint. Without this, seaborn rescales the colours to
    # each matrix's own min/max, so "white" is not zero correlation and the
    # same colour means different values across datasets.
    sns.heatmap(
        df.corr(),
        annot=annot,
        cmap='coolwarm',
        vmin=-1,
        vmax=1,
        center=0,
        ax=ax,
    )
    # Title the axes we actually drew on, not whatever pyplot considers current.
    ax.set_title(title)
    plt.show()


# Iris data set

In [None]:
# Load Iris and show how its features (and the class label) correlate.
iris = datasets.load_iris()
plot_correlation_heatmap(data=iris, title='Iris dataset')

In [None]:
# Explore the petal width / petal length pair further in a pairplot.
# A dedicated frame name avoids relying on a generic `df` that later cells
# overwrite, and colouring by class name (instead of the integer label)
# makes the legend readable.
iris_df = pd.DataFrame(iris.data, columns=iris.feature_names)
iris_df['target'] = iris.target_names[iris.target]
# Trailing ';' suppresses the PairGrid repr; the figure itself still renders.
sns.pairplot(iris_df, vars=["petal width (cm)", "petal length (cm)"], hue='target');

# Wine data set

In [None]:
# Use a large figure so the annotated heatmap cells stay legible; the
# heatmap helper draws onto the current figure.
plt.figure(figsize=(20, 20))
wine = datasets.load_wine()
plot_correlation_heatmap(data=wine, title='Wine dataset')

In [None]:
# Explore total phenols vs. flavanoids further in a pairplot, using a
# dedicated frame name and colouring points by class name rather than
# the integer label.
wine_df = pd.DataFrame(wine.data, columns=wine.feature_names)
wine_df['target'] = wine.target_names[wine.target]
# Trailing ';' suppresses the PairGrid repr; the figure itself still renders.
sns.pairplot(wine_df, vars=["total_phenols", "flavanoids"], hue='target');

# Breast cancer data set

In [None]:
# Use a large figure so the annotated heatmap cells stay legible; the
# heatmap helper draws onto the current figure.
plt.figure(figsize=(20, 20))
data = datasets.load_breast_cancer()
plot_correlation_heatmap(data=data, title='Breast cancer dataset')

In [None]:
# Pairplot of mean radius and mean perimeter, coloured by class name
# rather than the integer label, using a dedicated frame name.
cancer_df = pd.DataFrame(data.data, columns=data.feature_names)
cancer_df['target'] = data.target_names[data.target]
# Trailing ';' suppresses the PairGrid repr; the figure itself still renders.
sns.pairplot(cancer_df, vars=['mean radius', 'mean perimeter'], hue='target');