# Unit 1 Loading and Understanding the Dataset


## Introduction 🚀

Welcome to the first step in our journey of preparing data for **drawing recognition**! In this lesson, we will focus on loading and understanding the dataset. This step matters because the quality and structure of your data can significantly affect the performance of your drawing recognition model. By the end of this lesson, you'll be able to download and inspect a dataset, laying the foundation for the later steps in data preparation.

-----

## What You'll Learn 🧠

In this lesson, you will learn how to load a dataset specifically designed for drawing recognition. We will use a dataset from Google's **Quick, Draw!** project, which contains millions of drawings across various categories. The drawings in Quick, Draw! are simple, hand-drawn sketches created by people around the world. Each drawing represents a specific object or concept, such as a cat, house, or bicycle, and is stored as a 28x28 grayscale image.

-----

## What are `.npy` Files? 📂

The dataset files you will download have a ***.npy*** extension. ***.npy*** files are a binary file format used by NumPy to efficiently store arrays on disk. They are commonly used in machine learning projects because they allow for fast reading and writing of large numerical datasets. In this case, each ***.npy*** file contains thousands of 28x28 pixel images for a specific drawing category, stored as NumPy arrays.
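To make the format concrete, here is a minimal sketch of a round trip through a `.npy` file using `np.save` and `np.load`. The array is synthetic: its shape `(5, 784)` and `uint8` dtype are assumptions chosen to resemble a tiny category file, not values read from the real dataset.

```python
import os
import tempfile

import numpy as np

# Synthetic stand-in for a category file: 5 "drawings" of 784 pixels each.
# The shape and dtype are assumptions for illustration only.
fake_drawings = np.zeros((5, 784), dtype=np.uint8)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'example.npy')
    np.save(path, fake_drawings)   # write the array in .npy format
    restored = np.load(path)       # read it back as a NumPy array

# Shape and dtype survive the round trip unchanged.
print(restored.shape, restored.dtype)
```

Because the file stores the shape and data type alongside the raw values, no extra parsing is needed after loading.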

Here's a quick look at the code you'll be working with:

```python
import urllib.request
import os

categories = ['cat', 'house', 'airplane', 'apple', 'bicycle']
base_url = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/'

os.makedirs('quickdraw_data', exist_ok=True)

for category in categories:
    file_path = f'quickdraw_data/{category}.npy'
    if not os.path.exists(file_path):
        print(f"Downloading {category}...")
        urllib.request.urlretrieve(base_url + category + '.npy', file_path)
    else:
        print(f"{category}.npy already exists.")
```

This code snippet demonstrates how to download and store datasets for different categories of drawings. You'll learn how to automate the download process and ensure that your data is organized and ready for analysis.
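The five categories above are single words, so plain string concatenation produces valid URLs. If you later add a category whose name contains a space, the name would need percent-encoding first. The sketch below shows one way to do that with `urllib.parse.quote`; the helper name `category_url` is hypothetical, and it assumes multi-word categories follow the same URL pattern as the single-word ones.

```python
from urllib.parse import quote

BASE_URL = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/'


def category_url(category: str) -> str:
    """Build a download URL for a category, percent-encoding spaces.

    Hypothetical helper; assumes every category file is named '<category>.npy'.
    """
    return BASE_URL + quote(category) + '.npy'


print(category_url('cat'))
print(category_url('ice cream'))  # the space becomes %20
```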

Here is a quick preview of images from the 'apple' category of the **Quick, Draw!** dataset:

-----

## Why It Matters ✅

Understanding how to load and inspect your dataset is essential because it allows you to verify the data's integrity and structure before diving into more complex preprocessing tasks. By mastering these initial steps, you ensure that your data is reliable and suitable for training a drawing recognition model. This foundational knowledge will empower you to handle datasets confidently, paving the way for successful machine learning projects.

Excited to get started? Let's move on to the practice section and put these concepts into action!

## Expanding Categories for Drawing Recognition

Now that you've seen how the download script works, let's expand our collection by adding another category. The Quick, Draw! dataset contains many different drawing types, and adding more variety will help create a more robust recognition model.

Your task is to add "dog" to the categories list in our download script. This simple modification will automatically download the dog drawings dataset or recognize if it already exists in your data folder.

Adding new categories is a common task when working with classification problems, as you often need to adjust your dataset based on your specific requirements.

```python
import urllib.request
import os

# TODO: Add 'dog' to the categories list
categories = ['cat', 'house', 'airplane', 'apple', 'bicycle', ________]
base_url = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/'

os.makedirs('quickdraw_data', exist_ok=True)

for category in categories:
    file_path = f'quickdraw_data/{category}.npy'
    if not os.path.exists(file_path):
        print(f"Downloading {category}...")
        urllib.request.urlretrieve(base_url + category + '.npy', file_path)
    else:
        print(f"{category}.npy already exists.")

```

Here is the completed script:

```python
import urllib.request
import os

# 'dog' has been added to the categories list
categories = ['cat', 'house', 'airplane', 'apple', 'bicycle', 'dog']
base_url = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/'

os.makedirs('quickdraw_data', exist_ok=True)

for category in categories:
    file_path = f'quickdraw_data/{category}.npy'
    if not os.path.exists(file_path):
        print(f"Downloading {category}...")
        urllib.request.urlretrieve(base_url + category + '.npy', file_path)
    else:
        print(f"{category}.npy already exists.")
```

### Explanation

The string `'dog'` was added to the `categories` list. When the script runs, the `for` loop will now iterate through this updated list. It will check for `dog.npy` and, if the file is not found locally in the `quickdraw_data` directory, it will download it from the specified `base_url`.

## Inspecting Dataset Structure and Properties

Great job adding the dog category! Now that we have our datasets downloaded, we need to verify that they're correctly loaded and understand their structure before moving to preprocessing.

In this task, you'll add code to inspect each downloaded .npy file. For each category, you'll load the file, examine its shape, count the number of images, and check the data type.

This verification step is crucial as it helps identify any issues with the downloaded data and gives you insight into what you're working with before applying any transformations.

```python
import urllib.request
import numpy as np
import os

categories = ['cat', 'house', 'airplane', 'apple', 'bicycle', 'dog']
base_url = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/'

os.makedirs('quickdraw_data', exist_ok=True)

for category in categories:
    file_path = f'quickdraw_data/{category}.npy'
    if not os.path.exists(file_path):
        print(f"Downloading {category}...")
        urllib.request.urlretrieve(base_url + category + '.npy', file_path)
    else:
        print(f"{category}.npy already exists.")

# TODO: Add code to verify the downloaded data
print("\nVerifying downloaded data:")
for category in categories:
    file_path = f'quickdraw_data/{category}.npy'
    if os.path.exists(file_path):
        # TODO: Load the data file using np.load()
        data = ________
        
        # TODO: Print the category name, shape, number of images, and data type
        print(f"Category: {category}")
        print(f"  Shape: ________")
        print(f"  Number of images: ________")
        print(f"  Data type: {data.dtype}")
    else:
        print(f"Warning: {category} data file not found!")

```

Here is the completed script to inspect the dataset:

```python
import urllib.request
import numpy as np
import os

categories = ['cat', 'house', 'airplane', 'apple', 'bicycle', 'dog']
base_url = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/'

os.makedirs('quickdraw_data', exist_ok=True)

for category in categories:
    file_path = f'quickdraw_data/{category}.npy'
    if not os.path.exists(file_path):
        print(f"Downloading {category}...")
        urllib.request.urlretrieve(base_url + category + '.npy', file_path)
    else:
        print(f"{category}.npy already exists.")

# Verify the downloaded data
print("\nVerifying downloaded data:")
for category in categories:
    file_path = f'quickdraw_data/{category}.npy'
    if os.path.exists(file_path):
        # Load the data file using np.load()
        data = np.load(file_path)

        # Print the category name, shape, number of images, and data type
        print(f"Category: {category}")
        print(f"  Shape: {data.shape}")
        print(f"  Number of images: {data.shape[0]}")
        print(f"  Data type: {data.dtype}")
    else:
        print(f"Warning: {category} data file not found!")
```

### Explanation

1.  **`data = np.load(file_path)`**: This line uses NumPy's `load()` function to read the specified ***.npy*** file from disk and load its contents into the `data` variable as a NumPy array.
2.  **`Shape: {data.shape}`**: Every NumPy array has a `.shape` attribute, which is a tuple representing its dimensions. For these files, the shape is `(number_of_images, number_of_pixels)`, so you'll see something like `(123456, 784)`.
3.  **`Number of images: {data.shape[0]}`**: The number of images corresponds to the first dimension of the array's shape. We access this value using the index `[0]`.
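If you only need the shape and data type, `np.load` accepts `mmap_mode='r'`, which memory-maps the file instead of reading every image into memory. The following is a minimal sketch of that idea; `describe_npy` is a hypothetical helper, and the temporary file holds synthetic data rather than a real category file.

```python
import os
import tempfile

import numpy as np


def describe_npy(path):
    """Return (shape, dtype) of a .npy file without loading all of its data.

    Hypothetical helper for illustration.
    """
    data = np.load(path, mmap_mode='r')  # memory-map the file in read-only mode
    return data.shape, data.dtype


with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'sample.npy')
    # Synthetic file: 3 flattened "drawings" of 784 pixels each.
    np.save(path, np.zeros((3, 784), dtype=np.uint8))
    shape, dtype = describe_npy(path)

print(shape, dtype)
print(shape[1] == 28 * 28)  # 784 values per row correspond to a 28x28 image
```

This can be handy when a category file is large and you only want a quick sanity check.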

## Visualizing Drawings with Matplotlib

Great work on inspecting the dataset properties! Now let's take our understanding one step further by visualizing the actual drawings.

After verifying numerical properties, it's always good practice to look at your data. For image datasets, this means displaying some sample images to confirm they match their labels and aren't corrupted.

In this task, you'll display one random drawing from each category using matplotlib. You'll need to load the data, select a random sample, reshape it to the proper dimensions (28x28 pixels), and display it in a grid with appropriate labels.

This visual inspection completes our initial data exploration before moving on to preprocessing steps.

```python
import urllib.request
import numpy as np
import os
import matplotlib.pyplot as plt

categories = ['cat', 'house', 'airplane', 'apple', 'bicycle', 'dog']
base_url = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/'

os.makedirs('quickdraw_data', exist_ok=True)

for category in categories:
    file_path = f'quickdraw_data/{category}.npy'
    if not os.path.exists(file_path):
        print(f"Downloading {category}...")
        urllib.request.urlretrieve(base_url + category + '.npy', file_path)
    else:
        print(f"{category}.npy already exists.")

# Verify downloaded data
print("\nVerifying downloaded data:")
for category in categories:
    file_path = f'quickdraw_data/{category}.npy'
    if os.path.exists(file_path):
        data = np.load(file_path)
        print(f"Category: {category}")
        print(f"  Shape: {data.shape}")
        print(f"  Number of images: {len(data)}")
        print(f"  Data type: {data.dtype}")
    else:
        print(f"Warning: {category} data file not found!")

# TODO: Add code to visualize sample drawings from each category
plt.figure(figsize=(15, 3))
for i, category in enumerate(categories):
    file_path = f'quickdraw_data/{category}.npy'
    if os.path.exists(file_path):
        # TODO: Load the data file
        data = ________
        
        # TODO: Select a random drawing index
        random_idx = ________
        
        # TODO: Get the drawing at the random index
        drawing = ________
        
        # TODO: Reshape the drawing to 28x28 pixels
        drawing = drawing.reshape(________, ________)
        
        # Create subplot
        plt.subplot(1, len(categories), i+1)
        # TODO: Display the image with a grayscale colormap
        plt.imshow(________, cmap=________)
        plt.title(category)
        plt.axis('off')
    else:
        print(f"Cannot visualize {category}: file not found")

# Save the figure instead of displaying it
plt.tight_layout()
os.makedirs('static/images', exist_ok=True)  # make sure the output folder exists
plt.savefig('static/images/plot.png')

```

Here's the completed script to visualize the drawings.

```python
import urllib.request
import numpy as np
import os
import matplotlib.pyplot as plt

categories = ['cat', 'house', 'airplane', 'apple', 'bicycle', 'dog']
base_url = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/'

os.makedirs('quickdraw_data', exist_ok=True)

for category in categories:
    file_path = f'quickdraw_data/{category}.npy'
    if not os.path.exists(file_path):
        print(f"Downloading {category}...")
        urllib.request.urlretrieve(base_url + category + '.npy', file_path)
    else:
        print(f"{category}.npy already exists.")

# Verify downloaded data
print("\nVerifying downloaded data:")
for category in categories:
    file_path = f'quickdraw_data/{category}.npy'
    if os.path.exists(file_path):
        data = np.load(file_path)
        print(f"Category: {category}")
        print(f"  Shape: {data.shape}")
        print(f"  Number of images: {len(data)}")
        print(f"  Data type: {data.dtype}")
    else:
        print(f"Warning: {category} data file not found!")

# Visualize one random drawing from each category
plt.figure(figsize=(15, 3))
for i, category in enumerate(categories):
    file_path = f'quickdraw_data/{category}.npy'
    if os.path.exists(file_path):
        # Load the data file
        data = np.load(file_path)

        # Select a random drawing index in the range [0, len(data))
        random_idx = np.random.randint(len(data))

        # Get the drawing at the random index (a flat array of 784 pixels)
        drawing = data[random_idx]

        # Reshape the drawing to 28x28 pixels
        drawing = drawing.reshape(28, 28)

        # Create subplot
        plt.subplot(1, len(categories), i+1)
        # Display the image with a grayscale colormap
        plt.imshow(drawing, cmap='gray')
        plt.title(category)
        plt.axis('off')
    else:
        print(f"Cannot visualize {category}: file not found")

# Save the figure instead of displaying it
plt.tight_layout()
os.makedirs('static/images', exist_ok=True)  # make sure the output folder exists
plt.savefig('static/images/plot.png')
```

### Explanation of Changes

  * **`data = np.load(file_path)`**: This line uses NumPy's `load()` function to read the specified ***.npy*** file from disk and load its contents into the `data` variable.
  * **`random_idx = np.random.randint(len(data))`**: This selects a single **random integer** from 0 up to, but not including, `len(data)`, so the result is always a valid index into the array.
  * **`drawing = data[random_idx]`**: This retrieves the specific drawing (which is a 1D array of 784 pixels) at the randomly selected index.
  * **`drawing = drawing.reshape(28, 28)`**: This reshapes the flat 1D array of 784 pixels into a 2D array of **28x28 pixels**, which is the required format for displaying it as an image.
  * **`plt.imshow(drawing, cmap='gray')`**: This uses Matplotlib to display the 2D `drawing` array. The `cmap='gray'` argument ensures the image is rendered in **grayscale**, which is appropriate for this dataset.
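The selection and reshape steps above can be tried without downloading anything. The sketch below uses a synthetic array in place of a real category file and a seeded generator from `np.random.default_rng` so the pick is repeatable; the seed value and the array contents are assumptions for illustration.

```python
import numpy as np

# Synthetic stand-in for a category file: 3 flattened "drawings" of 784 pixels.
data = (np.arange(3 * 784) % 256).astype(np.uint8).reshape(3, 784)

# A seeded generator returns the same index on every run.
rng = np.random.default_rng(seed=42)
idx = rng.integers(len(data))   # integer in the range [0, len(data))

flat = data[idx]                # 1D array of 784 pixel values
drawing = flat.reshape(28, 28)  # 2D array: 28 rows by 28 columns

# reshape fills row by row, so pixel (row, col) is flat[row * 28 + col].
row, col = 1, 5
print(drawing.shape)
print(drawing[row, col] == flat[row * 28 + col])
```

Seeding is optional for a quick look at the data, but it helps when you want a figure that shows the same samples each time the script runs.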

## Analyzing Drawing Dataset Distribution

Now that you've visualized individual drawings, let's analyze the dataset distribution — a critical step before training any machine learning model.

In this task, you'll create a bar chart showing the number of drawings in each category. This visualization will help you determine whether your dataset is balanced or if some categories are overrepresented.

You'll count the drawings in each category file, create a bar chart, and analyze the distribution. This analysis helps identify potential biases that could affect your model's performance across different drawing types.

```python
import urllib.request
import numpy as np
import os
import matplotlib.pyplot as plt

categories = ['cat', 'house', 'airplane', 'apple', 'bicycle', 'dog']
base_url = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/'

os.makedirs('quickdraw_data', exist_ok=True)

for category in categories:
    file_path = f'quickdraw_data/{category}.npy'
    if not os.path.exists(file_path):
        print(f"Downloading {category}...")
        urllib.request.urlretrieve(base_url + category + '.npy', file_path)
    else:
        print(f"{category}.npy already exists.")

# TODO: Create a bar chart showing the distribution of drawings across categories
# Create an empty list to store the count of drawings for each category
category_counts = ________

# Loop through each category and count the number of drawings
for category in categories:
    file_path = f'quickdraw_data/{category}.npy'
    if os.path.exists(file_path):
        # TODO: Load the data and append its length to category_counts
        data = ________
        category_counts.append(________)
    else:
        category_counts.append(0)
        print(f"Warning: Could not count {category} drawings - file not found")

# Create the distribution plot
plt.figure(figsize=(10, 6))
# TODO: Create a bar chart with categories on x-axis and counts on y-axis
plt.bar(________, ________, color='skyblue')
plt.title('Distribution of Drawings Across Categories')
plt.xlabel('Categories')
plt.ylabel('Number of Drawings')
plt.xticks(rotation=45)
plt.tight_layout()

# TODO: Add count labels on top of each bar
for i, count in enumerate(category_counts):
    plt.text(i, count + max(category_counts)*0.01, f"{count:,}", 
             ha='center', va='bottom', fontsize=10)

# Save the distribution chart
os.makedirs('static/images', exist_ok=True)  # make sure the output folder exists
plt.savefig('static/images/plot.png')

# TODO: Print analysis of the distribution
print("\nDataset Distribution Analysis:")
# Calculate and print the average number of drawings per category
avg_count = ________
print(f"Average drawings per category: {avg_count:.2f}")

# Find and print the most represented category
print(f"Most represented category: {categories[________]} with {max(category_counts):,} drawings")

# Find and print the least represented category
print(f"Least represented category: {categories[________]} with {min(category_counts):,} drawings")

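# Check if dataset is balanced (using a simple threshold)
if min(category_counts) == 0:
    # A missing file is counted as 0, which would cause a division by zero below
    print("At least one category has no drawings, so the max/min ratio is undefined")
elif max(category_counts) / min(category_counts) > 1.5:
    print("The dataset appears to be imbalanced (max/min ratio > 1.5)")
else:
    print("The dataset appears to be relatively balanced")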

```

Here is the completed script for analyzing the dataset distribution:

```python
import urllib.request
import numpy as np
import os
import matplotlib.pyplot as plt

categories = ['cat', 'house', 'airplane', 'apple', 'bicycle', 'dog']
base_url = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/'

os.makedirs('quickdraw_data', exist_ok=True)

for category in categories:
    file_path = f'quickdraw_data/{category}.npy'
    if not os.path.exists(file_path):
        print(f"Downloading {category}...")
        urllib.request.urlretrieve(base_url + category + '.npy', file_path)
    else:
        print(f"{category}.npy already exists.")

# Create a bar chart showing the distribution of drawings across categories
# Create an empty list to store the count of drawings for each category
category_counts = []

# Loop through each category and count the number of drawings
for category in categories:
    file_path = f'quickdraw_data/{category}.npy'
    if os.path.exists(file_path):
        # Load the data and append its length to category_counts
        data = np.load(file_path)
        category_counts.append(len(data))
    else:
        category_counts.append(0)
        print(f"Warning: Could not count {category} drawings - file not found")

# Create the distribution plot
plt.figure(figsize=(10, 6))
# Create a bar chart with categories on x-axis and counts on y-axis
plt.bar(categories, category_counts, color='skyblue')
plt.title('Distribution of Drawings Across Categories')
plt.xlabel('Categories')
plt.ylabel('Number of Drawings')
plt.xticks(rotation=45)
plt.tight_layout()

# Add count labels on top of each bar
for i, count in enumerate(category_counts):
    plt.text(i, count + max(category_counts)*0.01, f"{count:,}", 
             ha='center', va='bottom', fontsize=10)

# Save the distribution chart
os.makedirs('static/images', exist_ok=True)  # make sure the output folder exists
plt.savefig('static/images/plot.png')

# Print analysis of the distribution
print("\nDataset Distribution Analysis:")
# Calculate and print the average number of drawings per category
avg_count = np.mean(category_counts)
print(f"Average drawings per category: {avg_count:,.0f}")

# Find and print the most represented category
most_represented_idx = np.argmax(category_counts)
print(f"Most represented category: {categories[most_represented_idx]} with {max(category_counts):,} drawings")

# Find and print the least represented category
least_represented_idx = np.argmin(category_counts)
print(f"Least represented category: {categories[least_represented_idx]} with {min(category_counts):,} drawings")

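# Check if dataset is balanced (using a simple threshold)
if min(category_counts) == 0:
    # A missing file is counted as 0, which would cause a division by zero below
    print("At least one category has no drawings, so the max/min ratio is undefined")
elif max(category_counts) / min(category_counts) > 1.5:
    print("The dataset appears to be imbalanced (max/min ratio > 1.5)")
else:
    print("The dataset appears to be relatively balanced")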

```

### Explanation of Changes

  * **`category_counts = []`**: Initializes an empty list to store the number of drawings for each category.
  * **`data = np.load(file_path)`**: Loads the data file for the current category.
  * **`category_counts.append(len(data))`**: Appends the number of images (i.e., the length of the loaded array) to our `category_counts` list.
  * **`plt.bar(categories, category_counts, ...)`**: Creates the bar chart, passing the list of category names as the x-axis labels and the list of counts as the y-axis values.
  * **`avg_count = np.mean(category_counts)`**: Uses NumPy's `mean()` function to calculate the average number of drawings across all categories.
  * **`np.argmax(category_counts)`**: Finds the **index** of the category with the highest number of drawings. This index is then used to retrieve the corresponding category name from the `categories` list.
  * **`np.argmin(category_counts)`**: Similarly, finds the **index** of the category with the lowest number of drawings.
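A missing file is recorded as a count of 0, and a max/min ratio has to handle that case. One possible way to isolate the balance check is a small helper like the one below. The name `imbalance_ratio` is hypothetical, and ignoring empty categories is a design choice made for this sketch, not something the lesson prescribes.

```python
def imbalance_ratio(counts):
    """Return max/min over the non-empty categories, or None if all are empty.

    Hypothetical helper: categories with a count of 0 (for example, a missing
    file) are skipped so the division is always defined.
    """
    present = [c for c in counts if c > 0]
    if not present:
        return None
    return max(present) / min(present)


# Made-up counts for illustration; the third category stands for a missing file.
example_counts = [120, 100, 0, 80]
ratio = imbalance_ratio(example_counts)
print(ratio)        # 120 / 80 = 1.5
print(ratio > 1.5)  # False: exactly at the threshold, so not flagged
```

Whether to skip empty categories or report them as an error depends on your workflow; for a training set, a category with no drawings usually deserves a warning of its own.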