# Exploratory Data Analysis Starter

## Import packages

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Shows plots in jupyter notebook
%matplotlib inline

# Set plot style
sns.set(color_codes=True)

---

## Loading data with Pandas

We need to load `client_data.csv` and `price_data.csv` into individual dataframes so that we can work with them in Python. For this notebook, it is assumed that the CSV files are placed in a `task_3/data/` folder relative to the notebook, which is the path used in the cell below. If they are stored elsewhere, please adjust the path passed to `read_csv` accordingly.

In [None]:
# Load both source tables (paths are relative to the notebook's working directory).
client_df, price_df = (
    pd.read_csv(path)
    for path in ('task_3/data/client_data.csv', 'task_3/data/price_data.csv')
)

You can view the first 3 rows of a dataframe using the `head(3)` method. Similarly, if you want to see the last 3, you can use `tail(3)`.

In [None]:
# First three rows of the client table
client_df.head(n=3)

In [None]:
# First three rows of the price table
price_df.head(n=3)

---

## Descriptive statistics of data

### Data types

It is useful to first understand the data that you're dealing with along with the data types of each column. The data types may dictate how you transform and engineer features.

To get an overview of the data types within a data frame, use the `info()` method.

In [None]:
# Overview of the client table's columns and their data types
client_df.info()

In [None]:
# Overview of the price table's columns and their data types
price_df.info()

### Statistics

Now let's look at some statistics about the datasets. We can do this by using the `describe()` method.

In [None]:
# Summary statistics for the client table
client_df.describe()

In [None]:
# Summary statistics for the price table
price_df.describe()

---

## Data visualization

If you're working in Python, two of the most popular packages for visualization are `matplotlib` and `seaborn`. We highly recommend that you use these, or at least become familiar with them, because they are ubiquitous!

Below are some functions that you can use to get started with visualizations. 

In [None]:
def plot_stacked_bars(dataframe, title_, size_=(18, 10), rot_=0, legend_="upper right",
                      labels_=("Retention", "Churn")):
    """
    Plot a stacked bar chart of ``dataframe`` and annotate each segment with its value.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Data to plot; passed to ``DataFrame.plot(kind="bar", stacked=True)``.
    title_ : str
        Chart title.
    size_ : tuple
        Figure size, passed as ``figsize``.
    rot_ : int
        Rotation of the x tick labels, in degrees.
    legend_ : str
        Legend location (matplotlib ``loc`` string).
    labels_ : sequence of str
        Legend entries, one per plotted column. Defaults to
        ``("Retention", "Churn")``, the labels this function always used before.
    """
    ax = dataframe.plot(
        kind="bar",
        stacked=True,
        figsize=size_,
        rot=rot_,
        title=title_,
    )

    # Write each segment's value on the bars
    annotate_stacked_bars(ax, textsize=14)
    # Target the axes we just drew on rather than pyplot's implicit "current" axes
    ax.legend(list(labels_), loc=legend_)
    ax.set_ylabel("Company base (%)")
    plt.show()

def annotate_stacked_bars(ax, pad=0.99, colour="white", textsize=13):
    """
    Write the height of every bar/segment in ``ax`` as text on the bar.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes whose ``patches`` (the bar rectangles) are annotated.
    pad : float
        Factor the bar-centre coordinates are multiplied by before the text
        is placed.
    colour : str
        Text colour.
    textsize : int or float
        Font size of the annotation.

    Heights are rounded to one decimal place. Segments whose rounded height is
    zero (including ``-0.0``) or NaN are skipped.
    """
    import math

    for patch in ax.patches:
        height = round(patch.get_height(), 1)
        # Numeric test: ``-0.0 == 0`` is True, whereas the old comparison
        # ``str(...) == '0.0'`` let "-0.0" (and "nan") through.
        if height == 0 or math.isnan(height):
            continue
        # NOTE(review): the -0.05 x-shift presumably compensates for the
        # left-anchored text so it looks roughly centred on the bar.
        x = (patch.get_x() + patch.get_width() / 2) * pad - 0.05
        y = (patch.get_y() + patch.get_height() / 2) * pad
        ax.annotate(str(height), (x, y), color=colour, size=textsize)

def plot_distribution(dataframe, column, ax, bins_=50):
    """
    Plot the distribution of ``column`` as a histogram stacked by churn status.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Must contain ``column`` and a ``churn`` column in which 0 marks a
        retained company and 1 a churned one.
    column : str
        Numeric column to plot.
    ax : matplotlib.axes.Axes
        Axes to draw the histogram on.
    bins_ : int
        Number of histogram bins.
    """
    # Temporary frame with one column per group. The two Series are aligned on
    # the union of their indices, so each row is NaN in the other group's column.
    temp = pd.DataFrame({
        "Retention": dataframe.loc[dataframe["churn"] == 0, column],
        "Churn": dataframe.loc[dataframe["churn"] == 1, column],
    })
    temp.plot(kind="hist", bins=bins_, ax=ax, stacked=True)
    ax.set_xlabel(column)
    # Plain (non-scientific) tick labels on the x-axis
    ax.ticklabel_format(style="plain", axis="x")

The first function `plot_stacked_bars` is used to plot a stacked bar chart. An example of how you could use this is shown below:

In [None]:
# Count companies per churn flag, turn the counts into percentages and plot them.
churn = client_df[['id', 'churn']].rename(columns={'id': 'Companies'})
churn_total = churn.groupby('churn').count()
churn_percentage = churn_total / churn_total.sum() * 100
plot_stacked_bars(churn_percentage.T, "Churning status", (5, 5), legend_="lower right")

The second function `annotate_stacked_bars` is used by the first function, while the third function `plot_distribution` plots the distribution of a numeric column as a histogram stacked by churn status. An example of how it can be used is given below:

In [None]:
# Consumption-related columns plus `churn`, which plot_distribution splits on.
consumption = client_df[
    ['id', 'cons_12m', 'cons_gas_12m', 'cons_last_month', 'imp_cons', 'has_gas', 'churn']
]

# A single subplot is requested, so `axs` is one Axes object rather than an array.
fig, axs = plt.subplots(figsize=(18, 5), nrows=1)

plot_distribution(consumption, column='cons_12m', ax=axs)

In [None]:
# EDA Template – PowerCo Churn Analysis

# Step 0: Imports
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Set style
sns.set(style='whitegrid')

In [None]:
# Define paths
# NOTE(review): the cells above read `client_data.csv` / `price_data.csv` from
# `task_3/data/`; adjust `data_dir` and the file names below to match your files.
data_dir = Path("data/")  # or "." if files are in root
customers = pd.read_csv(data_dir / "customer_data.csv")
prices = pd.read_csv(data_dir / "price_data.csv")
churn = pd.read_csv(data_dir / "churn_data.csv")

# Preview data. Jupyter only auto-displays a cell's last expression, so bare
# `.head()` calls would show just the final table; print each one instead.
for preview_name, frame in (("Customers", customers), ("Prices", prices), ("Churn", churn)):
    print(f"\n{preview_name}:")
    print(frame.head())


In [None]:
# Check data types. DataFrame.info() prints its report itself and returns None,
# so wrapping it in print() would only add a stray "None" line after each report.
customers.info()
prices.info()
churn.info()


In [None]:
# Summary statistics and unique-value counts for each table.
# The bare describe() expressions were never shown: Jupyter only auto-displays
# a cell's final expression, and this cell ends in a for-loop. Print explicitly.
for df, name in zip([customers, prices, churn], ["Customers", "Prices", "Churn"]):
    print(f"\n{name} - Summary statistics:")
    print(df.describe())
    print(f"\n{name} - Unique values per column:")
    print(df.nunique())


In [None]:
# Visual check for missing data: one heatmap row per record, one column per
# field; cells where the value is missing are drawn in red.
plt.figure(figsize=(10, 6))
sns.heatmap(customers.isnull(), cbar=False, cmap="Reds")
plt.title("Missing Values in Customer Data")

# Percent missing per column. isnull().mean() yields a fraction (0-1), so scale by 100.
print((customers.isnull().mean() * 100).sort_values(ascending=False))


In [None]:
# Numeric distributions: one histogram with a KDE overlay per column.
# NOTE(review): template column names; replace with numeric columns of `customers`.
num_cols = ['fixed_price', 'variable_price', 'consumption', 'forecasted_usage']  # adjust names

for col in num_cols:
    plt.figure(figsize=(7, 4))
    sns.histplot(customers[col], bins=30, kde=True)
    plt.gca().set(title=f"Distribution of {col}", xlabel=col)
    plt.tight_layout()


In [None]:
# Categorical distributions, bars ordered from most to least frequent value.
cat_cols = ['region', 'industry']  # change according to data

for col in cat_cols:
    plt.figure(figsize=(6, 3))
    sns.countplot(x=col, data=customers, order=customers[col].value_counts().index)
    plt.gca().set_title(f"Distribution of {col}")
    plt.xticks(rotation=45)
    plt.tight_layout()


In [None]:
# Merge dataframes (assuming customer_id exists in both tables).
# how='left' keeps every customer; those without a churn record get NaN churn fields.
df = customers.merge(churn, on='customer_id', how='left')

# Churn distribution. Open a fresh figure so the bars are not drawn onto
# whatever axes an earlier plotting call left as pyplot's current axes.
# NOTE(review): the client data loaded earlier names this column 'churn';
# confirm the column name in churn_data.csv.
plt.figure()
sns.countplot(data=df, x='churned')
plt.title("Churn Status Distribution")


Example observations (verify these against the output of the cells above before relying on them):

- The column `consumption` is right-skewed, suggesting that most customers have low usage.
- Fixed and variable prices show some variation over time.
- Missing values are minimal in `customer_data.csv`, although `forecasted_usage` has missing values that need to be imputed.
- Most customers belong to the 'Retail' and 'Industrial' sectors.
