# Introduction to Matplotlib - Data Visualization

Welcome to the Matplotlib tutorial! **Matplotlib** is the most widely used Python library for creating visualizations. A picture is worth a thousand words, and with data, a good visualization can reveal patterns and insights that are hard to see in tables of numbers.

## What is Matplotlib?

Matplotlib provides:
- **Publication-quality plots** - Professional charts for reports and papers
- **Wide variety of plot types** - Line plots, scatter plots, bar charts, histograms, and more
- **Extensive customization** - Control every aspect of your plots
- **Integration with NumPy and Pandas** - Works seamlessly with the tools you've already learned

## What You'll Learn

In this notebook, you'll learn:
1. **Line plots** - Visualizing trends and continuous data
2. **Scatter plots** - Showing relationships between variables
3. **Bar charts** - Comparing categories
4. **Histograms** - Visualizing distributions
5. **Customization** - Colors, labels, legends, styles
6. **Subplots** - Multiple plots in one figure
7. **Pandas integration** - Plotting directly from DataFrames

## Prerequisites

This notebook builds on concepts from:
- **Notebook 2 (NumPy)** - We'll use NumPy arrays and functions like `linspace`
- **Notebook 3 (Pandas)** - We'll plot data from DataFrames

Let's get started!

## Setup - Importing Libraries

We'll import Matplotlib's `pyplot` interface, NumPy (from notebook 2), and Pandas (from notebook 3):

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# For displaying plots in the notebook
%matplotlib inline

# Optional: Set default figure size
plt.rcParams['figure.figsize'] = (10, 6)

print("Libraries imported successfully!")
print(f"Matplotlib version: {matplotlib.__version__}")
print(f"NumPy version: {np.__version__}")
print(f"Pandas version: {pd.__version__}")

## 1. Line Plots - Visualizing Trends

Line plots are perfect for showing trends over time or in continuous data. They connect consecutive data points with straight line segments, in the order the points are given.

### Basic Line Plot

Let's create our first plot! We'll use **`np.linspace()`** which we learned in notebook 2 to create evenly spaced x-values.

In [None]:
# ===== BASIC LINE PLOT =====
# Create x values using linspace (from notebook 2)
# linspace(start, stop, num_points) creates evenly spaced values
x = np.linspace(0, 10, 50)  # 50 points from 0 to 10
y = np.sin(x)  # Calculate sine of each x value

# Create the plot
plt.plot(x, y)
plt.title('Sine Wave')  # Add title
plt.xlabel('x')         # Label x-axis
plt.ylabel('sin(x)')    # Label y-axis
plt.grid(True)          # Add grid
plt.show()              # Display the plot
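
Because `np.linspace()` includes the stop value by default, 50 points produce 49 gaps, so the spacing in the cell above is 10/49 (about 0.204), not 0.2. A quick check using only standard NumPy arguments (`endpoint=False` drops the stop value):

```python
import numpy as np

# 50 points from 0 to 10, both endpoints included -> 49 equal gaps
x = np.linspace(0, 10, 50)
step = x[1] - x[0]
print(len(x), x[0], x[-1])  # 50 0.0 10.0
print(round(step, 4))       # 0.2041, i.e. 10 / 49

# endpoint=False excludes the stop value, so the 50 gaps are each 10 / 50 = 0.2
x_open = np.linspace(0, 10, 50, endpoint=False)
print(round(x_open[-1], 4))  # 9.8
```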

In [None]:
# ===== MULTIPLE LINES ON ONE PLOT =====
x = np.linspace(0, 2 * np.pi, 100)  # Recall: linspace from notebook 2
y1 = np.sin(x)
y2 = np.cos(x)
y3 = np.sin(x) * np.cos(x)

plt.figure(figsize=(10, 6))  # Set figure size
plt.plot(x, y1, label='sin(x)')      # label for legend
plt.plot(x, y2, label='cos(x)')
plt.plot(x, y3, label='sin(x)·cos(x)')

plt.title('Trigonometric Functions', fontsize=14)
plt.xlabel('x (radians)', fontsize=12)
plt.ylabel('y', fontsize=12)
plt.legend()  # Show legend
plt.grid(True, alpha=0.3)  # Semi-transparent grid
plt.show()

### Customizing Line Plots

You can customize colors, line styles, markers, and more:

In [None]:
# ===== LINE CUSTOMIZATION =====
x = np.linspace(0, 10, 20)
y1 = x
y2 = x ** 2
y3 = x ** 0.5

plt.figure(figsize=(10, 6))

# Different line styles
plt.plot(x, y1, color='blue', linestyle='-', linewidth=2, label='Linear')
plt.plot(x, y2, color='red', linestyle='--', linewidth=2, label='Quadratic')
plt.plot(x, y3, color='green', linestyle=':', linewidth=2, label='Square Root')

# Add markers
plt.plot(x, y1, 'o', color='blue', markersize=5)  # Circle markers
plt.plot(x, y2, 's', color='red', markersize=5)   # Square markers
plt.plot(x, y3, '^', color='green', markersize=5) # Triangle markers

plt.title('Line Styles and Markers')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

# Common line styles: '-' (solid), '--' (dashed), ':' (dotted), '-.' (dash-dot)
# Common markers: 'o' (circle), 's' (square), '^' (triangle), '*' (star), '+' (plus)
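
In the cell above, the markers come from separate `plt.plot` calls, so the legend entries show only the lines. One alternative, sketched below, passes `marker=` in the same call as the line style so each legend entry shows both. It uses the object-oriented `fig, ax` interface that the subplots section relies on.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 10, 20)
series = [
    (x, 'blue', '-', 'o', 'Linear'),
    (x ** 2, 'red', '--', 's', 'Quadratic'),
    (x ** 0.5, 'green', ':', '^', 'Square Root'),
]

fig, ax = plt.subplots(figsize=(10, 6))
for y, color, style, marker, label in series:
    # One call per series: line and markers share a single legend entry
    ax.plot(x, y, color=color, linestyle=style, linewidth=2,
            marker=marker, markersize=5, label=label)

ax.set_title('Line Styles and Markers (combined)')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.legend()
ax.grid(True, alpha=0.3)
```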

In [None]:
# ===== SHORTHAND NOTATION =====
# You can combine color, marker, and line style in one string
x = np.linspace(0, 5, 10)
y = x ** 2

plt.figure(figsize=(10, 6))
plt.plot(x, y, 'ro-', label='Red circles with solid line')  # r=red, o=circle, -=solid
plt.plot(x, y + 5, 'bs--', label='Blue squares with dashed line')  # b=blue, s=square, --=dashed
plt.plot(x, y + 10, 'g^:', label='Green triangles with dotted line')  # g=green, ^=triangle, :=dotted

plt.title('Shorthand Notation Examples')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

### Real-World Example: Temperature Over Time

In [None]:
# ===== TEMPERATURE EXAMPLE =====
# Simulate temperature data for a week
days = np.arange(1, 8)  # Days 1-7 (recall: arange from notebook 2)
temperatures = [22, 24, 23, 25, 26, 24, 23]
day_labels = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']

plt.figure(figsize=(10, 6))
plt.plot(days, temperatures, 'o-', color='orange', linewidth=2, markersize=8)
plt.xticks(days, day_labels)  # Custom x-axis labels
plt.title('Weekly Temperature Forecast', fontsize=16, fontweight='bold')
plt.xlabel('Day of Week', fontsize=12)
plt.ylabel('Temperature (°C)', fontsize=12)
plt.grid(True, alpha=0.3, linestyle='--')
plt.ylim(20, 28)  # Set y-axis limits

# Add horizontal line for average
avg_temp = np.mean(temperatures)
plt.axhline(y=avg_temp, color='red', linestyle='--', label=f'Average: {avg_temp:.1f}°C')
plt.legend()

plt.tight_layout()  # Adjust spacing
plt.show()

## 2. Scatter Plots - Showing Relationships

Scatter plots show individual data points and are great for visualizing relationships between two variables.

### Basic Scatter Plot

In [None]:
# ===== BASIC SCATTER PLOT =====
# Generate random data using NumPy (from notebook 2)
np.random.seed(42)  # For reproducibility
x = np.random.randn(100)  # 100 random points from standard normal
y = 2 * x + np.random.randn(100) * 0.5  # y related to x with some noise

plt.figure(figsize=(10, 6))
plt.scatter(x, y, alpha=0.6)  # alpha controls transparency
plt.title('Scatter Plot Example')
plt.xlabel('X Variable')
plt.ylabel('Y Variable')
plt.grid(True, alpha=0.3)
plt.show()
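
To put a number on the relationship a scatter plot suggests, one option is to overlay a least-squares line. The sketch below regenerates the same data as the cell above, fits a degree-1 polynomial with `np.polyfit` (which returns the slope first, then the intercept), and reports the Pearson correlation from `np.corrcoef`. Since the data were built as `y = 2x + noise`, the fitted slope should come out close to 2.

```python
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)
x = np.random.randn(100)
y = 2 * x + np.random.randn(100) * 0.5  # same recipe as the cell above

# Degree-1 least-squares fit: returns [slope, intercept]
slope, intercept = np.polyfit(x, y, 1)
r = np.corrcoef(x, y)[0, 1]  # Pearson correlation coefficient

x_line = np.linspace(x.min(), x.max(), 2)  # two points are enough for a straight line
fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(x, y, alpha=0.6, label='data')
ax.plot(x_line, slope * x_line + intercept, 'r-', linewidth=2,
        label=f'fit: y = {slope:.2f}x + {intercept:.2f} (r = {r:.2f})')
ax.set_xlabel('X Variable')
ax.set_ylabel('Y Variable')
ax.legend()
ax.grid(True, alpha=0.3)
```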

### Customizing Scatter Plots

You can control size, color, and transparency of points:

In [None]:
# ===== SCATTER PLOT CUSTOMIZATION =====
np.random.seed(42)
n_points = 50
x = np.random.rand(n_points) * 10
y = np.random.rand(n_points) * 10
colors = np.random.rand(n_points)  # Values in [0, 1), mapped to colors by the colormap
sizes = np.random.rand(n_points) * 1000  # Marker areas in points² (s sets area, not diameter)

plt.figure(figsize=(10, 6))
scatter = plt.scatter(x, y, c=colors, s=sizes, alpha=0.6, cmap='viridis', edgecolors='black')
plt.colorbar(scatter, label='Color Value')  # Add color scale
plt.title('Customized Scatter Plot', fontsize=14)
plt.xlabel('X Variable')
plt.ylabel('Y Variable')
plt.grid(True, alpha=0.3)
plt.show()

### Real-World Example: Height vs Weight

In [None]:
# ===== HEIGHT VS WEIGHT EXAMPLE =====
np.random.seed(42)

# Generate sample data
heights_male = np.random.normal(175, 7, 50)  # mean=175cm, std=7cm
weights_male = heights_male * 0.9 + np.random.normal(0, 5, 50) - 80

heights_female = np.random.normal(162, 6, 50)  # mean=162cm, std=6cm
weights_female = heights_female * 0.9 + np.random.normal(0, 5, 50) - 80

plt.figure(figsize=(10, 6))
plt.scatter(heights_male, weights_male, alpha=0.6, s=60, c='blue', label='Male')
plt.scatter(heights_female, weights_female, alpha=0.6, s=60, c='red', label='Female')

plt.title('Height vs Weight by Gender', fontsize=14, fontweight='bold')
plt.xlabel('Height (cm)', fontsize=12)
plt.ylabel('Weight (kg)', fontsize=12)
plt.legend(fontsize=10)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 3. Bar Charts - Comparing Categories

Bar charts are perfect for comparing values across different categories.

### Vertical Bar Chart

In [None]:
# ===== VERTICAL BAR CHART =====
categories = ['Product A', 'Product B', 'Product C', 'Product D', 'Product E']
sales = [120, 95, 150, 80, 135]

plt.figure(figsize=(10, 6))
plt.bar(categories, sales, color='steelblue', edgecolor='black')
plt.title('Product Sales Comparison', fontsize=14, fontweight='bold')
plt.xlabel('Product', fontsize=12)
plt.ylabel('Sales (units)', fontsize=12)
plt.grid(axis='y', alpha=0.3)  # Only horizontal grid lines
plt.tight_layout()
plt.show()

In [None]:
# ===== HORIZONTAL BAR CHART =====
plt.figure(figsize=(10, 6))
plt.barh(categories, sales, color='coral', edgecolor='black')  # barh = horizontal
plt.title('Product Sales Comparison (Horizontal)', fontsize=14, fontweight='bold')
plt.xlabel('Sales (units)', fontsize=12)
plt.ylabel('Product', fontsize=12)
plt.grid(axis='x', alpha=0.3)  # Only vertical grid lines
plt.tight_layout()
plt.show()
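
To print each bar's value on the chart, one option is `Axes.bar_label` (added in Matplotlib 3.4), which takes the container returned by `bar` or `barh`. A minimal sketch using the object-oriented `fig, ax` interface:

```python
import matplotlib.pyplot as plt

categories = ['Product A', 'Product B', 'Product C', 'Product D', 'Product E']
sales = [120, 95, 150, 80, 135]

fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.bar(categories, sales, color='steelblue', edgecolor='black')
labels = ax.bar_label(bars, padding=3)  # one text label per bar, just above its top
ax.set_ylim(0, max(sales) * 1.1)        # leave headroom so the tallest label fits
ax.set_ylabel('Sales (units)')
ax.grid(axis='y', alpha=0.3)
```

For the horizontal chart, the same call works on the container returned by `ax.barh`.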

### Grouped Bar Chart

In [None]:
# ===== GROUPED BAR CHART =====
categories = ['Q1', 'Q2', 'Q3', 'Q4']
sales_2023 = [120, 135, 148, 160]
sales_2024 = [130, 142, 155, 170]

x = np.arange(len(categories))  # Label positions
width = 0.35  # Bar width

plt.figure(figsize=(10, 6))
plt.bar(x - width/2, sales_2023, width, label='2023', color='skyblue', edgecolor='black')
plt.bar(x + width/2, sales_2024, width, label='2024', color='lightcoral', edgecolor='black')

plt.title('Quarterly Sales Comparison', fontsize=14, fontweight='bold')
plt.xlabel('Quarter', fontsize=12)
plt.ylabel('Sales ($1000s)', fontsize=12)
plt.xticks(x, categories)
plt.legend()
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()
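
The `x - width/2` and `x + width/2` offsets above are the two-series case of a general rule: with `n` series of bar width `w`, series `i` (counting from 0) is shifted by `(i - (n - 1) / 2) * w`, which keeps each group centered on its tick. A sketch of a small helper built on that rule (`grouped_bar` is a hypothetical name, not a Matplotlib function):

```python
import numpy as np
import matplotlib.pyplot as plt

def grouped_bar(ax, categories, series, total_width=0.8):
    """Draw one bar per series at each category; `series` maps label -> values."""
    n = len(series)
    width = total_width / n
    x = np.arange(len(categories))
    for i, (label, values) in enumerate(series.items()):
        offset = (i - (n - 1) / 2) * width  # centers the group on the tick
        ax.bar(x + offset, values, width, label=label, edgecolor='black')
    ax.set_xticks(x)
    ax.set_xticklabels(categories)
    ax.legend()

fig, ax = plt.subplots(figsize=(10, 6))
# total_width=0.7 with two series gives width 0.35 and offsets of -0.175 / +0.175,
# matching the cell above
grouped_bar(ax, ['Q1', 'Q2', 'Q3', 'Q4'],
            {'2023': [120, 135, 148, 160], '2024': [130, 142, 155, 170]},
            total_width=0.7)
ax.set_ylabel('Sales ($1000s)')
```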

### Stacked Bar Chart

In [None]:
# ===== STACKED BAR CHART =====
categories = ['North', 'South', 'East', 'West']
product_a = [30, 35, 25, 40]
product_b = [25, 30, 35, 30]
product_c = [20, 25, 20, 25]

plt.figure(figsize=(10, 6))
plt.bar(categories, product_a, label='Product A', color='#FF6B6B')
plt.bar(categories, product_b, bottom=product_a, label='Product B', color='#4ECDC4')
plt.bar(categories, product_c, bottom=np.array(product_a) + np.array(product_b), 
        label='Product C', color='#95E1D3')

plt.title('Regional Sales by Product (Stacked)', fontsize=14, fontweight='bold')
plt.xlabel('Region', fontsize=12)
plt.ylabel('Sales (units)', fontsize=12)
plt.legend()
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()
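
With more than a couple of layers, writing out each `bottom=` sum by hand gets error-prone. A common pattern keeps a running total that starts at zero and grows by each layer's values, so every new layer starts where the previous one ended:

```python
import numpy as np
import matplotlib.pyplot as plt

categories = ['North', 'South', 'East', 'West']
products = {
    'Product A': [30, 35, 25, 40],
    'Product B': [25, 30, 35, 30],
    'Product C': [20, 25, 20, 25],
}

fig, ax = plt.subplots(figsize=(10, 6))
bottom = np.zeros(len(categories))  # running top of each stack
for label, values in products.items():
    ax.bar(categories, values, bottom=bottom, label=label)
    bottom += values  # the next layer starts where this one ends

ax.set_ylabel('Sales (units)')
ax.legend()
print(bottom)  # [75. 90. 80. 95.]
```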

## 4. Histograms - Visualizing Distributions

Histograms show the distribution of a single variable by grouping data into bins.

### Basic Histogram

In [None]:
# ===== BASIC HISTOGRAM =====
# Generate random data from normal distribution (from notebook 2)
np.random.seed(42)
data = np.random.randn(1000)  # 1000 samples from standard normal

plt.figure(figsize=(10, 6))
plt.hist(data, bins=30, color='skyblue', edgecolor='black', alpha=0.7)
plt.title('Histogram of Normal Distribution', fontsize=14, fontweight='bold')
plt.xlabel('Value', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()
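
`plt.hist` (and `np.histogram`, which does the same counting without drawing) splits the data range into `bins` equal-width intervals and counts the samples in each, so `bins=30` yields 30 counts and 31 edges. With `density=True` the heights are rescaled so the total area of the bars is 1, which makes histograms of different sample sizes comparable:

```python
import numpy as np

np.random.seed(42)
data = np.random.randn(1000)

counts, edges = np.histogram(data, bins=30)
print(len(counts), len(edges))  # 30 31
print(counts.sum())             # 1000: every sample falls in exactly one bin

density, _ = np.histogram(data, bins=30, density=True)
widths = np.diff(edges)
print(round(float((density * widths).sum()), 6))  # 1.0: total area under the bars
```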

In [None]:
# ===== MULTIPLE HISTOGRAMS =====
# Compare two distributions
np.random.seed(42)
data1 = np.random.normal(100, 15, 1000)  # mean=100, std=15
data2 = np.random.normal(110, 20, 1000)  # mean=110, std=20

plt.figure(figsize=(10, 6))
plt.hist(data1, bins=30, alpha=0.6, label='Group A', color='blue', edgecolor='black')
plt.hist(data2, bins=30, alpha=0.6, label='Group B', color='red', edgecolor='black')

plt.title('Comparing Two Distributions', fontsize=14, fontweight='bold')
plt.xlabel('Score', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.legend()
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()
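
In the cell above, each `plt.hist` call picks its own 30 bins over its own data range, so the two sets of bars have slightly different edges and widths. Passing the same edge array to both calls, for example one computed by `np.histogram_bin_edges` on the combined data, keeps the bars aligned. A sketch:

```python
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)
data1 = np.random.normal(100, 15, 1000)
data2 = np.random.normal(110, 20, 1000)

# 30 equal-width bins covering both samples
edges = np.histogram_bin_edges(np.concatenate([data1, data2]), bins=30)

fig, ax = plt.subplots(figsize=(10, 6))
_, bins1, _ = ax.hist(data1, bins=edges, alpha=0.6, label='Group A',
                      color='blue', edgecolor='black')
_, bins2, _ = ax.hist(data2, bins=edges, alpha=0.6, label='Group B',
                      color='red', edgecolor='black')
ax.set_xlabel('Score')
ax.set_ylabel('Frequency')
ax.legend()
```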

In [None]:
# ===== HISTOGRAM WITH STATISTICS =====
np.random.seed(42)
data = np.random.normal(75, 10, 500)

plt.figure(figsize=(10, 6))
n, bins, patches = plt.hist(data, bins=25, color='lightgreen', edgecolor='black', alpha=0.7)

# Add vertical lines for mean and median
mean_val = np.mean(data)
median_val = np.median(data)

plt.axvline(mean_val, color='red', linestyle='--', linewidth=2, label=f'Mean: {mean_val:.1f}')
plt.axvline(median_val, color='blue', linestyle='--', linewidth=2, label=f'Median: {median_val:.1f}')

plt.title('Exam Scores Distribution', fontsize=14, fontweight='bold')
plt.xlabel('Score', fontsize=12)
plt.ylabel('Number of Students', fontsize=12)
plt.legend()
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()

print("Statistics:")
print(f"  Mean: {mean_val:.2f}")
print(f"  Median: {median_val:.2f}")
print(f"  Std Dev: {np.std(data):.2f}")

### 🎯 Practice Exercise: Basic Plots

Create the following visualizations:

1. Create a line plot showing y = x³ for x from -5 to 5 (use `np.linspace()` from notebook 2)
2. Generate 200 random points from a uniform distribution (use `np.random.uniform()` from notebook 2) and create a scatter plot
3. Create a bar chart showing monthly expenses: Rent=1200, Food=400, Transport=200, Entertainment=300, Other=250
4. Generate 1000 samples from a normal distribution with mean=50 and std=10, then create a histogram with 40 bins

In [None]:
# Your code here:


## 5. Subplots - Multiple Plots in One Figure

Subplots allow you to create multiple plots side by side in a single figure.

### Basic Subplots

In [None]:
# ===== SUBPLOTS (2x2 GRID) =====
# Create data
x = np.linspace(0, 10, 100)
y1 = np.sin(x)
y2 = np.cos(x)
y3 = x
y4 = x ** 2

# Create 2x2 grid of subplots
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Plot 1: Top-left (row 0, col 0)
axes[0, 0].plot(x, y1, 'b-')
axes[0, 0].set_title('Sine Wave')
axes[0, 0].set_xlabel('x')
axes[0, 0].set_ylabel('sin(x)')
axes[0, 0].grid(True, alpha=0.3)

# Plot 2: Top-right (row 0, col 1)
axes[0, 1].plot(x, y2, 'r-')
axes[0, 1].set_title('Cosine Wave')
axes[0, 1].set_xlabel('x')
axes[0, 1].set_ylabel('cos(x)')
axes[0, 1].grid(True, alpha=0.3)

# Plot 3: Bottom-left (row 1, col 0)
axes[1, 0].plot(x, y3, 'g-')
axes[1, 0].set_title('Linear')
axes[1, 0].set_xlabel('x')
axes[1, 0].set_ylabel('y')
axes[1, 0].grid(True, alpha=0.3)

# Plot 4: Bottom-right (row 1, col 1)
axes[1, 1].plot(x, y4, 'm-')
axes[1, 1].set_title('Quadratic')
axes[1, 1].set_xlabel('x')
axes[1, 1].set_ylabel('x²')
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()  # Adjust spacing between subplots
plt.show()
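
The four panels above repeat the same five calls. Because `axes` is a NumPy array, `axes.flat` iterates over it in row-major order, so a loop over a list of panel settings can produce the same figure with less repetition. A sketch:

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 10, 100)
panels = [
    (np.sin(x), 'b-', 'Sine Wave', 'sin(x)'),
    (np.cos(x), 'r-', 'Cosine Wave', 'cos(x)'),
    (x, 'g-', 'Linear', 'y'),
    (x ** 2, 'm-', 'Quadratic', 'x²'),
]

fig, axes = plt.subplots(2, 2, figsize=(12, 10))
# axes.flat visits [0, 0], [0, 1], [1, 0], [1, 1] in that order
for ax, (y, fmt, title, ylabel) in zip(axes.flat, panels):
    ax.plot(x, y, fmt)
    ax.set_title(title)
    ax.set_xlabel('x')
    ax.set_ylabel(ylabel)
    ax.grid(True, alpha=0.3)

fig.tight_layout()
```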

In [None]:
# ===== SUBPLOTS IN A ROW =====
fig, axes = plt.subplots(1, 3, figsize=(15, 4))  # 1 row, 3 columns

# Generate data
np.random.seed(42)
data = np.random.randn(1000)

# Plot 1: Line plot
x = np.linspace(0, 10, 50)
axes[0].plot(x, np.sin(x), 'b-o')
axes[0].set_title('Line Plot')
axes[0].grid(True, alpha=0.3)

# Plot 2: Scatter plot
axes[1].scatter(np.random.randn(100), np.random.randn(100), alpha=0.6)
axes[1].set_title('Scatter Plot')
axes[1].grid(True, alpha=0.3)

# Plot 3: Histogram
axes[2].hist(data, bins=30, color='green', alpha=0.7, edgecolor='black')
axes[2].set_title('Histogram')
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
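
The type of `axes` depends on the grid shape: `plt.subplots` returns a 2-D NumPy array when there are several rows and columns, a 1-D array for a single row or column (as in the cell above), and a single `Axes` object when called with no grid arguments. Passing `squeeze=False` always gives a 2-D array, so the same `axes[row, col]` indexing works for any shape. A quick check:

```python
import matplotlib.pyplot as plt
from matplotlib.axes import Axes

fig1, axes_grid = plt.subplots(2, 2)
fig2, axes_row = plt.subplots(1, 3)
fig3, single_ax = plt.subplots()
fig4, axes_2d = plt.subplots(1, 3, squeeze=False)

print(axes_grid.shape)              # (2, 2)
print(axes_row.shape)               # (3,)
print(isinstance(single_ax, Axes))  # True
print(axes_2d.shape)                # (1, 3)

# Close these demo figures so the notebook doesn't display four empty grids
for f in (fig1, fig2, fig3, fig4):
    plt.close(f)
```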

In [None]:
# ===== SUBPLOTS WITH SHARED AXES =====
fig, axes = plt.subplots(2, 1, figsize=(10, 8), sharex=True)  # Share x-axis

x = np.linspace(0, 2*np.pi, 100)

# Top plot
axes[0].plot(x, np.sin(x), 'b-', label='sin(x)')
axes[0].set_ylabel('sin(x)')
axes[0].set_title('Sine and Cosine Functions')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Bottom plot
axes[1].plot(x, np.cos(x), 'r-', label='cos(x)')
axes[1].set_xlabel('x (radians)')
axes[1].set_ylabel('cos(x)')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 6. Advanced Customization

### Figure Styles

Matplotlib comes with built-in styles:

In [None]:
# ===== AVAILABLE STYLES =====
print("Available styles:")
print(plt.style.available)

In [None]:
# ===== USING DIFFERENT STYLES =====
x = np.linspace(0, 10, 100)
y1 = np.sin(x)
y2 = np.cos(x)

# Try different styles
# 'seaborn-v0_8' needs Matplotlib 3.6+ (older versions name it 'seaborn')
styles = ['default', 'ggplot', 'seaborn-v0_8']

fig = plt.figure(figsize=(15, 4))

for i, style in enumerate(styles):
    # Styles are read when an Axes is created, so create each subplot inside the context
    with plt.style.context(style):
        ax = fig.add_subplot(1, 3, i + 1)
        ax.plot(x, y1, label='sin(x)')
        ax.plot(x, y2, label='cos(x)')
        ax.set_title(f'Style: {style}')
        ax.legend()
        ax.grid(True)

plt.tight_layout()
plt.show()

### Annotations and Text

In [None]:
# ===== ADDING ANNOTATIONS =====
x = np.linspace(0, 2*np.pi, 100)
y = np.sin(x)

plt.figure(figsize=(10, 6))
plt.plot(x, y, 'b-', linewidth=2)

# Find maximum point
max_idx = np.argmax(y)
max_x = x[max_idx]
max_y = y[max_idx]

# Add annotation with arrow
plt.annotate('Maximum',
             xy=(max_x, max_y),  # Point to annotate
             xytext=(max_x + 1, max_y - 0.3),  # Text location
             arrowprops=dict(arrowstyle='->', color='red', lw=2),
             fontsize=12,
             color='red')

# Add text box
plt.text(4, -0.5, 'Sine wave from 0 to 2π',
         bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5),
         fontsize=11)

plt.title('Annotated Sine Wave', fontsize=14, fontweight='bold')
plt.xlabel('x (radians)')
plt.ylabel('sin(x)')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

### Saving Figures

In [None]:
# ===== SAVING FIGURES =====
x = np.linspace(0, 10, 100)
y = np.sin(x)

plt.figure(figsize=(10, 6))
plt.plot(x, y, 'b-', linewidth=2)
plt.title('Figure to Save')
plt.xlabel('x')
plt.ylabel('sin(x)')
plt.grid(True, alpha=0.3)

# Save in different formats
plt.savefig('my_plot.png', dpi=300, bbox_inches='tight')  # High resolution PNG
plt.savefig('my_plot.pdf', bbox_inches='tight')           # PDF (vector format)
plt.savefig('my_plot.svg', bbox_inches='tight')           # SVG (vector format)

plt.show()
print("✓ Figures saved!")
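
For raster formats such as PNG, `dpi` sets the pixel size: pixels = inches × dpi, so a 10 × 6 inch figure saved at `dpi=300` is 3000 × 1800 pixels (before `bbox_inches='tight'` trims the margins). A sketch that checks this by saving to an in-memory buffer instead of a file and reading it back with `plt.imread`:

```python
import io
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot([0, 1, 2], [0, 1, 0])

buf = io.BytesIO()
fig.savefig(buf, format='png', dpi=300)  # no bbox_inches, so the full canvas is kept
buf.seek(0)
image = plt.imread(buf)  # array of shape (rows, columns, RGBA channels)
print(image.shape)       # (1800, 3000, 4)
plt.close(fig)
```

Calling `fig.savefig` on a specific figure object, rather than `plt.savefig`, avoids saving the wrong figure when several are open.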

## 7. Plotting with Pandas

Pandas (from notebook 3) integrates seamlessly with Matplotlib. You can plot directly from DataFrames!

### Creating Sample Data with Pandas

In [None]:
# ===== CREATE SAMPLE DATAFRAME =====
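# Recall: DataFrames from notebook 3
np.random.seed(42)  # For reproducibility
# 'MS' = month start; recent pandas versions deprecate the plain 'M' alias
dates = pd.date_range('2024-01-01', periods=12, freq='MS')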
sales_data = pd.DataFrame({
    'Month': dates,
    'Product_A': np.random.randint(100, 200, 12),
    'Product_B': np.random.randint(80, 180, 12),
    'Product_C': np.random.randint(90, 190, 12)
})

print("Sales DataFrame:")
print(sales_data.head())

In [None]:
# ===== PLOTTING FROM DATAFRAME - LINE PLOT =====
# Simple one-liner!
sales_data.plot(x='Month', y=['Product_A', 'Product_B', 'Product_C'], 
                figsize=(12, 6), 
                title='Monthly Sales by Product',
                ylabel='Sales (units)',
                grid=True)
plt.tight_layout()
plt.show()
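
`DataFrame.plot()` draws with Matplotlib and returns the Matplotlib `Axes` it drew on. Passing an existing Axes through the `ax=` argument places a pandas plot inside a subplot grid, and the returned Axes can then be customized with ordinary Matplotlib methods. A minimal sketch with a small, separately seeded DataFrame:

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

np.random.seed(0)
df = pd.DataFrame({
    'Month': pd.date_range('2024-01-01', periods=12, freq='MS'),
    'Product_A': np.random.randint(100, 200, 12),
    'Product_B': np.random.randint(80, 180, 12),
})

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# pandas draws into the Axes passed as ax= and returns an Axes to keep customizing
line_ax = df.plot(x='Month', y=['Product_A', 'Product_B'], ax=axes[0],
                  title='Monthly Sales')
df[['Product_A', 'Product_B']].plot(kind='box', ax=axes[1], title='Sales Spread')
line_ax.set_ylabel('Sales (units)')  # regular Matplotlib methods still work
fig.tight_layout()
```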

In [None]:
# ===== BAR PLOT FROM DATAFRAME =====
# Average sales by product
avg_sales = sales_data[['Product_A', 'Product_B', 'Product_C']].mean()

avg_sales.plot(kind='bar', 
               figsize=(10, 6),
               color=['skyblue', 'lightcoral', 'lightgreen'],
               edgecolor='black',
               title='Average Monthly Sales by Product',
               ylabel='Average Sales (units)',
               rot=0)  # rot=0 means no rotation of x-labels
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()
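
When a DataFrame with several numeric columns is plotted with `kind='bar'`, pandas draws one bar per column at each index label, which builds a grouped bar chart like the one in Section 3 without computing offsets by hand. A sketch reusing the quarterly figures from that example:

```python
import pandas as pd
import matplotlib.pyplot as plt

quarterly = pd.DataFrame(
    {'2023': [120, 135, 148, 160], '2024': [130, 142, 155, 170]},
    index=['Q1', 'Q2', 'Q3', 'Q4'],
)

# Each column becomes one series; pandas groups the bars by index label
ax = quarterly.plot(kind='bar', figsize=(10, 6), rot=0, edgecolor='black',
                    title='Quarterly Sales Comparison')
ax.set_ylabel('Sales ($1000s)')
ax.grid(axis='y', alpha=0.3)
```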

In [None]:
# ===== SCATTER PLOT FROM DATAFRAME =====
sales_data.plot(kind='scatter',
                x='Product_A',
                y='Product_B',
                figsize=(10, 6),
                alpha=0.6,
                s=100,
                c='Product_C',  # Color by Product_C values
                cmap='viridis',
                colorbar=True,
                title='Product A vs Product B Sales')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# ===== HISTOGRAM FROM DATAFRAME =====
ax = sales_data[['Product_A', 'Product_B', 'Product_C']].plot(kind='hist',
                                                                bins=15,
                                                                alpha=0.6,
                                                                figsize=(10, 6),
                                                                title='Distribution of Sales')
ax.set_xlabel('Sales (units)')  # set via Matplotlib: pandas honors xlabel= for histograms only from 2.0
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# ===== BOX PLOT FROM DATAFRAME =====
# Box plots show distribution with quartiles
sales_data[['Product_A', 'Product_B', 'Product_C']].plot(kind='box',
                                                           figsize=(10, 6),
                                                           title='Sales Distribution by Product',
                                                           ylabel='Sales (units)')
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()

### Real-World Example: Analyzing Employee Data

In [None]:
# ===== CREATE EMPLOYEE DATAFRAME =====
# Similar to examples from notebook 3
np.random.seed(42)
n_employees = 100

employees = pd.DataFrame({
    'Department': np.random.choice(['Sales', 'IT', 'HR', 'Finance'], n_employees),
    'Salary': np.random.normal(60000, 15000, n_employees),
    'YearsExperience': np.random.randint(0, 20, n_employees),
    'PerformanceScore': np.random.randint(1, 11, n_employees)
})

employees['Salary'] = employees['Salary'].clip(lower=30000)  # Minimum salary

print("Employee Data:")
print(employees.head())
print(f"\nTotal employees: {len(employees)}")

In [None]:
# ===== VISUALIZATION 1: AVERAGE SALARY BY DEPARTMENT =====
# Using groupby from notebook 3
dept_avg = employees.groupby('Department')['Salary'].mean().sort_values()

dept_avg.plot(kind='barh',
              figsize=(10, 6),
              color='steelblue',
              edgecolor='black',
              title='Average Salary by Department',
              xlabel='Average Salary ($)')
plt.grid(axis='x', alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# ===== VISUALIZATION 2: SALARY VS EXPERIENCE =====
fig, axes = plt.subplots(1, 2, figsize=(15, 6))

# Scatter plot
for dept in employees['Department'].unique():
    dept_data = employees[employees['Department'] == dept]
    axes[0].scatter(dept_data['YearsExperience'], 
                   dept_data['Salary'],
                   alpha=0.6,
                   s=50,
                   label=dept)

axes[0].set_title('Salary vs Experience by Department', fontsize=12, fontweight='bold')
axes[0].set_xlabel('Years of Experience')
axes[0].set_ylabel('Salary ($)')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Salary distribution histogram
axes[1].hist(employees['Salary'], bins=20, color='coral', edgecolor='black', alpha=0.7)
axes[1].axvline(employees['Salary'].mean(), color='red', linestyle='--', 
               linewidth=2, label=f"Mean: ${employees['Salary'].mean():,.0f}")
axes[1].set_title('Salary Distribution', fontsize=12, fontweight='bold')
axes[1].set_xlabel('Salary ($)')
axes[1].set_ylabel('Number of Employees')
axes[1].legend()
axes[1].grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()
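
Iterating over `employees.groupby('Department')` yields each department name together with its rows, which replaces the boolean-mask filter in the scatter loop above and visits departments in sorted order. A sketch that rebuilds a similar DataFrame so it runs on its own:

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

np.random.seed(42)
n_employees = 100
employees = pd.DataFrame({
    'Department': np.random.choice(['Sales', 'IT', 'HR', 'Finance'], n_employees),
    'Salary': np.random.normal(60000, 15000, n_employees).clip(min=30000),
    'YearsExperience': np.random.randint(0, 20, n_employees),
})

fig, ax = plt.subplots(figsize=(8, 6))
# groupby yields (department name, sub-DataFrame) pairs, one per department
for dept, group in employees.groupby('Department'):
    ax.scatter(group['YearsExperience'], group['Salary'], alpha=0.6, s=50, label=dept)

ax.set_xlabel('Years of Experience')
ax.set_ylabel('Salary ($)')
ax.legend(title='Department')
ax.grid(True, alpha=0.3)
```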

### 🎯 Practice Exercise: Pandas and Visualization

Create a DataFrame and visualize it:

```python
# Stock price data
dates = pd.date_range('2024-01-01', periods=30, freq='D')
stocks = pd.DataFrame({
    'Date': dates,
    'AAPL': np.random.normal(150, 10, 30).cumsum() + 1000,
    'GOOGL': np.random.normal(140, 12, 30).cumsum() + 1000,
    'MSFT': np.random.normal(145, 8, 30).cumsum() + 1000
})
```

Tasks:
1. Create a line plot showing all three stocks over time
2. Calculate and plot the daily returns (difference between consecutive days) for AAPL
3. Create a bar chart comparing the average price of each stock
4. Create a 2x2 subplot figure showing: (1) line plot of all stocks, (2) histogram of AAPL prices, (3) scatter plot of AAPL vs GOOGL, (4) box plot of all stocks

In [None]:
# Your code here:


## Congratulations! 🎉

You've completed the Matplotlib basics tutorial! You now understand:

✅ **Line plots** - Visualizing trends and continuous data  
✅ **Scatter plots** - Showing relationships between variables  
✅ **Bar charts** - Comparing categories  
✅ **Histograms** - Visualizing distributions  
✅ **Customization** - Colors, styles, labels, annotations  
✅ **Subplots** - Multiple plots in one figure  
✅ **Pandas integration** - Plotting directly from DataFrames  

## What's Next?

Now that you know the visualization basics:
- **Advanced Matplotlib** - 3D plots, animations, custom styles
- **Seaborn** - Statistical visualizations with beautiful defaults
- **Plotly** - Interactive visualizations
- **Real projects** - Apply visualization to real datasets

## Key Takeaways

1. **Choose the right plot type**: Line plots for trends, scatter for relationships, bars for categories, histograms for distributions
2. **Always label your plots**: Title, axis labels, and legends make plots understandable
3. **Use `np.linspace()` for smooth curves** (from notebook 2)
4. **Pandas DataFrames can be plotted directly** with `.plot()` (from notebook 3)
5. **`plt.tight_layout()`** adjusts the padding between and around subplots so titles and axis labels don't overlap
6. **Save high-resolution figures** with `savefig(dpi=300)`

## Quick Reference

```python
# Line plot
plt.plot(x, y)

# Scatter plot
plt.scatter(x, y)

# Bar chart
plt.bar(categories, values)

# Histogram
plt.hist(data, bins=20)

# Subplots
fig, axes = plt.subplots(2, 2)

# From Pandas
df.plot(x='col1', y='col2')
```

Keep visualizing your data! 📊✨