In [1]:
import pandas as pd
import numpy as np

# DataFrame groupby method

The `groupby` method in Pandas is used to split the data into groups based on some criteria. It is often used to perform operations on each of these groups independently, such as aggregation, transformation, or filtration.

<img src="./images/DF_grouping.webp" style="height:300px">

By groupby method in Pandas we perform the "Split-Apply-Combine" Strategy.

Here's how it works:

1. Splitting: You start by splitting your data into groups based on a column.
2. Applying: You apply a function to each group independently.
3. Combining: The results are then combined into a new DataFrame.

<img src="./images/split-apply-combine.svg" style="height:300px">

### Basic example

In [2]:
# Sample DataFrame
data = {
    'Category': ['A', 'A', 'B', 'B', 'C', 'C'],
    'Values': [1, 2, 3, 4, 5, 6]
}
df = pd.DataFrame(data)
df

Unnamed: 0,Category,Values
0,A,1
1,A,2
2,B,3
3,B,4
4,C,5
5,C,6


In [3]:
# Group by 'Category' and calculate the sum of 'Values'
grouped = df.groupby('Category')
grouped['Values'].sum()

Category
A     3
B     7
C    11
Name: Values, dtype: int64

### Groupby Object

groupby() method returns a [GroupBy object](https://pandas.pydata.org/docs/reference/groupby.html)

In [4]:
# iterate on groupby object
for group_name,group in grouped:
	print(f'group_name:{group_name}')
	print(group)
	print()

group_name:A
  Category  Values
0        A       1
1        A       2

group_name:B
  Category  Values
2        B       3
3        B       4

group_name:C
  Category  Values
4        C       5
5        C       6



In [5]:
# get gorups as Dict {group name -> group labels}.
print(grouped.groups)

{'A': [0, 1], 'B': [2, 3], 'C': [4, 5]}


In [6]:
# select a group
grouped.get_group("A")

Unnamed: 0,Category,Values
0,A,1
1,A,2


## Aggregating functions

<table class="colwidths-given table">
<colgroup>
<col style="width: 20%">
<col style="width: 80%">
</colgroup>
<thead>
<tr class="row-odd"><th class="head"><p>Function</p></th>
<th class="head"><p>Description</p></th>
</tr>
</thead>
<tbody>
<tr class="row-even"><td><p><code class="xref py py-meth docutils literal notranslate"><span class="pre">mean()</span></code></p></td>
<td><p>Compute mean of groups</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-meth docutils literal notranslate"><span class="pre">sum()</span></code></p></td>
<td><p>Compute sum of group values</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-meth docutils literal notranslate"><span class="pre">size()</span></code></p></td>
<td><p>Compute group sizes</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-meth docutils literal notranslate"><span class="pre">count()</span></code></p></td>
<td><p>Compute count of group</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-meth docutils literal notranslate"><span class="pre">std()</span></code></p></td>
<td><p>Standard deviation of groups</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-meth docutils literal notranslate"><span class="pre">var()</span></code></p></td>
<td><p>Compute variance of groups</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-meth docutils literal notranslate"><span class="pre">sem()</span></code></p></td>
<td><p>Standard error of the mean of groups</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-meth docutils literal notranslate"><span class="pre">describe()</span></code></p></td>
<td><p>Generates descriptive statistics</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-meth docutils literal notranslate"><span class="pre">first()</span></code></p></td>
<td><p>Compute first of group values</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-meth docutils literal notranslate"><span class="pre">last()</span></code></p></td>
<td><p>Compute last of group values</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-meth docutils literal notranslate"><span class="pre">nth()</span></code></p></td>
<td><p>Take nth value, or a subset if n is a list</p></td>
</tr>
<tr class="row-odd"><td><p><code class="xref py py-meth docutils literal notranslate"><span class="pre">min()</span></code></p></td>
<td><p>Compute min of group values</p></td>
</tr>
<tr class="row-even"><td><p><code class="xref py py-meth docutils literal notranslate"><span class="pre">max()</span></code></p></td>
<td><p>Compute max of group values</p></td>
</tr>
</tbody>
</table>

In [7]:
df

Unnamed: 0,Category,Values
0,A,1
1,A,2
2,B,3
3,B,4
4,C,5
5,C,6


In [8]:
# Compute count of group, excluding missing values.
grouped.count()

Unnamed: 0_level_0,Values
Category,Unnamed: 1_level_1
A,2
B,2
C,2


In [9]:
# Compute max of group values
grouped.max()

Unnamed: 0_level_0,Values
Category,Unnamed: 1_level_1
A,2
B,4
C,6


In [10]:
# Compute sum of group values
grouped.sum()

Unnamed: 0_level_0,Values
Category,Unnamed: 1_level_1
A,3
B,7
C,11


In [11]:
# Take the nth row from each group if n is an int, or a subset of rows if n is a list of ints.
# continents_group.nth(1)
grouped.nth([0,-1])

Unnamed: 0,Category,Values
0,A,1
1,A,2
2,B,3
3,B,4
4,C,5
5,C,6


## More examples

### Example 1: Sales Data

Scenario: You have a dataset of sales records for a store, including the date, product category, and sales amount. You want to analyze the total sales by category and month.

In [12]:
# Generate sample sales data

rng = np.random.default_rng(42)
dates = pd.date_range('2023-01-01', periods=100, freq='D')
categories = ['Electronics', 'Clothing', 'Furniture']
sales_data = {
    'Date': np.random.choice(dates, 100),
    'Category': np.random.choice(categories, 100),
    'Sales_Amount': np.random.uniform(20, 500, 100)
}
df_sales = pd.DataFrame(sales_data)
df_sales.head()

Unnamed: 0,Date,Category,Sales_Amount
0,2023-03-15,Clothing,91.99666
1,2023-01-11,Clothing,399.06502
2,2023-01-16,Clothing,201.340619
3,2023-02-28,Electronics,58.810047
4,2023-01-04,Furniture,110.784225


In [13]:
### Calculate total sales by category
# Group by 'Category' and calculate total sales
grouped_sales = df_sales.groupby('Category')
grouped_sales['Sales_Amount'].sum()

Category
Clothing       9594.802793
Electronics    8455.121560
Furniture      8124.840420
Name: Sales_Amount, dtype: float64

In [14]:
### Calculate total sales by category and month
# add Month column, extract from 'Date':
df_sales['Month'] = df_sales['Date'].dt.to_period('M')

# Group by 'Category' and 'Month' and calculate total sales
grouped_sales_month = df_sales.groupby(['Category', 'Month'])
grouped_sales_month['Sales_Amount'].sum()

Category     Month  
Clothing     2023-01    2043.936203
             2023-02    5902.079431
             2023-03     841.839688
             2023-04     806.947470
Electronics  2023-01    3498.116774
             2023-02    2299.813938
             2023-03    2530.924610
             2023-04     126.266238
Furniture    2023-01    2110.887860
             2023-02    2683.403718
             2023-03    1749.057290
             2023-04    1581.491551
Name: Sales_Amount, dtype: float64

### Example 2: Customer Transactions

*Scenario*: You have a dataset of customer transactions, including customer ID, transaction date, and transaction amount. You want to perform the following tasks:

1. Calculate the total transaction amount for each customer.
1. Find the average transaction amount per customer.
1. Identify customers with transactions above a certain threshold.
1. Analyze monthly transactions for each customer.


In [15]:
# Generate sample customer transactions data
np.random.seed(42)
customer_ids = np.arange(1, 11)
dates = pd.date_range('2023-01-01', periods=100, freq='D')
transaction_data = {
    'Customer_ID': np.random.choice(customer_ids, 100),
    'Transaction_Date': np.random.choice(dates, 100),
    'Transaction_Amount': np.random.uniform(10, 1000, 100)
}
df_transactions = pd.DataFrame(transaction_data)
df_transactions

Unnamed: 0,Customer_ID,Transaction_Date,Transaction_Amount
0,7,2023-01-12,285.860000
1,4,2023-02-03,909.183227
2,8,2023-02-02,247.166272
3,5,2023-02-17,153.445923
4,7,2023-01-23,494.558233
...,...,...,...
95,10,2023-01-02,25.302050
96,9,2023-01-02,929.035377
97,7,2023-04-02,433.902307
98,9,2023-02-23,966.988271


In [16]:
# Task 1: Calculate the total transaction amount for each customer
total_transaction_amount = df_transactions.groupby('Customer_ID')['Transaction_Amount'].sum()
print("Total transaction amount for each customer:")
total_transaction_amount

Total transaction amount for each customer:


Customer_ID
1     4184.441341
2     3593.600771
3     5195.621393
4     4958.270198
5     3544.582543
6     2582.612421
7     6638.668663
8     7900.080139
9     7658.588096
10    5328.930893
Name: Transaction_Amount, dtype: float64

In [17]:
# Task 2: Find the average transaction amount per customer
average_transaction_amount = df_transactions.groupby('Customer_ID')['Transaction_Amount'].mean()
print("\nAverage transaction amount per customer:")
average_transaction_amount



Average transaction amount per customer:


Customer_ID
1     597.777334
2     359.360077
3     577.291266
4     550.918911
5     354.458254
6     430.435403
7     603.515333
8     526.672009
9     638.215675
10    484.448263
Name: Transaction_Amount, dtype: float64

In [18]:
# Task 3: Identify customers with transactions above a certain threshold (e.g., 5000)
threshold = 5000
high_value_customers = total_transaction_amount[total_transaction_amount > threshold]
print("\nCustomers with total transactions above threshold:")
high_value_customers


Customers with total transactions above threshold:


Customer_ID
3     5195.621393
7     6638.668663
8     7900.080139
9     7658.588096
10    5328.930893
Name: Transaction_Amount, dtype: float64

In [19]:
# Task 4: Analyze monthly transactions for each customer
df_transactions['Month'] = df_transactions['Transaction_Date'].dt.to_period('M')
monthly_transactions = df_transactions.groupby(['Customer_ID', 'Month'])['Transaction_Amount'].sum()
print("Monthly transactions for each customer:")
monthly_transactions

Monthly transactions for each customer:


Customer_ID  Month  
1            2023-01    2669.730527
             2023-02     933.575969
             2023-03     581.134846
2            2023-01    1482.921170
             2023-02     559.648803
             2023-03    1551.030797
3            2023-01     185.339573
             2023-02    2068.273528
             2023-03    1668.905923
             2023-04    1273.102370
4            2023-01     627.652638
             2023-02    2456.910685
             2023-03    1704.506904
             2023-04     169.199971
5            2023-01     661.036763
             2023-02    1574.822042
             2023-03    1043.361527
             2023-04     265.362211
6            2023-01     822.082280
             2023-03    1220.113203
             2023-04     540.416937
7            2023-01    3227.060661
             2023-02    2584.837702
             2023-03     392.867993
             2023-04     433.902307
8            2023-01    1261.213718
             2023-02    2642.995766
       

### Example 3: Employee salaries per department  

*Scenario*: You have a DataFrame with employee data (id, gender, department, salary)

You task is to answer next questions:

1. Find the employee with max salaray
2. How many males and females works in each departments
3. What is the average salary for males and females per each department.
4. In which department the male/female have higher average salary?

In [20]:
# Generate sample data:
data_rows = 10

# Define possible values for each column
employee_ids = range(1, data_rows+1)
genders = rng.choice(['Male', 'Female'], size=data_rows)
departments = rng.choice(['HR', 'IT', 'Finance', 'Marketing'], size=data_rows)
salaries = np.random.uniform(30000, 120000, size=data_rows)

# Create the DataFrame
employee_data = pd.DataFrame({
    'employee_id': employee_ids,
    'gender': genders,
    'department': departments,
    'salary': salaries
})
employee_data

Unnamed: 0,employee_id,gender,department,salary
0,1,Male,Finance,106770.850992
1,2,Female,Marketing,56500.400286
2,3,Female,Finance,64658.795574
3,4,Male,Marketing,106602.300437
4,5,Male,Finance,58522.980464
5,6,Female,Marketing,45254.347202
6,7,Male,Finance,80112.113621
7,8,Female,HR,114253.929674
8,9,Male,Marketing,92642.681701
9,10,Male,IT,81305.505308


In [21]:
### Find the employee with max salaray
max_salary_employee_idx = employee_data['salary'].idxmax()
employee_data.loc[max_salary_employee_idx,:]

# The idxmax() method is used to find the index of the first occurrence of the maximum value along a specified axis in a DataFrame.

employee_id                8
gender                Female
department                HR
salary         114253.929674
Name: 7, dtype: object

In [22]:
### How many males and females works in each departments
gender_size = employee_data.groupby(['department','gender']).size()
gender_size.unstack(fill_value=0)


# The unstack() method in pandas pivots the level of the (typically hierarchical) index labels to the columns, producing a DataFrame with a new shape. This is particularly useful when you have a MultiIndex (an index with multiple levels) and you want to convert one of the index levels into columns.

gender,Female,Male
department,Unnamed: 1_level_1,Unnamed: 2_level_1
Finance,1,3
HR,1,0
IT,0,1
Marketing,2,2


In [23]:
###  What is the average salary for males and females per each department.
average_salaries = employee_data.groupby(['department', 'gender'])['salary'].mean()
average_salaries.unstack()


gender,Female,Male
department,Unnamed: 1_level_1,Unnamed: 2_level_1
Finance,64658.795574,81801.981692
HR,114253.929674,
IT,,81305.505308
Marketing,50877.373744,99622.491069


In [24]:
### In which department the male/female have higher average salary?

# find the department with the highest average salary for each gender
highest_average_salary_by_gender = average_salaries.groupby('gender').idxmax()
highest_average_salary_by_gender.apply( lambda value:value[0] )

# The apply function in pandas is a versatile method that allows you to apply a function along the axis of a DataFrame or a Series. It can be used to apply a function to each element, row, or column of a DataFrame or Series.

gender
Female           HR
Male      Marketing
Name: salary, dtype: object