For the past five years, you've honed your skills as a Senior Data Scientist for a global university. Your team leverages its data analytics and machine learning skill sets to help other departments make data-driven decisions. One such department is the procurement team, which is trying to decide on the best new mobile phone to offer to the university's employees. For the last week, a Junior Data Scientist on your team has been developing a workflow to help provide insight to the procurement team. You will be reviewing their code to ensure it's ready to ship to production.

The first chunk of code that you'll be reviewing is your colleague's function to prepare smartphone data from a CSV file for visualization. After ingesting and cleaning the smartphone data, your colleague has prepared a function that plots a variable passed to it against `"price"`. However, some of the code within this function is copied and pasted, which does not adhere to DRY (Don't Repeat Yourself) principles. Make sure to refactor the code appropriately, using the `column_to_label()` function defined below.

Wow, your colleague even included a unit test to ensure `NaN` values were removed from the cleaned DataFrame! However, it doesn't seem like the unit test is passing when executed. Re-work this unit test to ensure that it matches the transformation logic in the `prepare_smartphone_data()` function.

Once you've made changes to the `test_nan_values` unit test, you'll want to ensure that these unit tests execute with `ExitCode.OK`. This means that the `pytest` tests have passed, and the code is one step closer to being shipped to production.
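
To make the point about matching the test to the transformation logic concrete, here is a minimal sketch on a small in-memory DataFrame. The rows and values are invented for illustration and are not taken from the smartphone dataset. It assumes the cleaning rule is the one described in this notebook: drop rows that are missing `battery_capacity` or `os`.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for the smartphone data; all values are invented.
raw = pd.DataFrame(
    {
        "os": ["android", None, "ios", "android"],
        "battery_capacity": [5000.0, 4500.0, np.nan, 4000.0],
        "avg_rating": [8.1, 7.9, 8.4, np.nan],
    }
)

# Same rule as the cleaning step: drop rows missing battery capacity or OS.
cleaned = raw.dropna(subset=["battery_capacity", "os"])

# Column-specific checks agree with the cleaning rule.
nan_in_subset = int(cleaned[["battery_capacity", "os"]].isnull().sum().sum())

# A frame-wide check can still find NaN in columns the rule never touched.
nan_anywhere = int(cleaned.isnull().sum().sum())

print(nan_in_subset, nan_anywhere)
```

In this sketch, a test asserting that the whole DataFrame is free of `NaN` would fail because of the untouched `avg_rating` column, while a test scoped to `battery_capacity` and `os` passes.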

For context, there is a print statement in the `prepare_smartphone_data()` function in the first cell of the notebook below that can be used to visualize the dataset your Junior Data Scientist has been working with. Feel free to update this line of code as needed. This can then be removed after the dataset has been investigated. Best of luck!

In [11]:
import os
import pandas as pd

def prepare_smartphone_data(file_path):
    """
    Loads and cleans smartphone data for visualization and analysis.

    Cleaning steps:
        - Validates file path existence
        - Loads CSV data into a DataFrame
        - Retains only relevant columns for analysis
        - Removes rows with missing battery capacity or OS
        - Converts price from cents to dollars

    Parameters:
        file_path (str): Path to the raw smartphone CSV file

    Returns:
        pd.DataFrame: Cleaned DataFrame with selected columns and transformed price
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Smartphone data not found at: {file_path}")

    # Load raw data
    raw_data = pd.read_csv(file_path)

    # Define relevant columns
    columns_to_keep = [
        "brand_name",
        "os",
        "price",
        "avg_rating",
        "processor_speed",
        "battery_capacity",
        "screen_size"
    ]

    # Retain only relevant columns
    trimmed_data = raw_data[columns_to_keep]

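    # Drop rows missing battery capacity or OS; take an explicit copy so the
    # price update below does not trigger a SettingWithCopyWarning
    cleaned_data = trimmed_data.dropna(subset=["battery_capacity", "os"]).copy()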

    # Convert price from cents to dollars
    cleaned_data["price"] = cleaned_data["price"] / 100

    return cleaned_data
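
The cleaning steps above can be tried without the real file. The following sketch runs the same sequence (select columns, drop incomplete rows, rescale price) on a few lines of CSV text held in memory. The sample rows are invented, only a subset of the columns is used, and `DataFrame.assign` is shown as one possible alternative to in-place column assignment, not as the approach this notebook takes.

```python
import io

import pandas as pd

# Hypothetical sample rows; prices are in cents, as the docstring above assumes.
csv_text = """brand_name,os,price,battery_capacity
alpha,android,54999,5000
beta,,19989,4500
gamma,ios,64900,
"""

raw = pd.read_csv(io.StringIO(csv_text))

cleaned = (
    raw[["brand_name", "os", "price", "battery_capacity"]]
    # Drop rows with a missing battery capacity or OS.
    .dropna(subset=["battery_capacity", "os"])
    # assign() returns a new DataFrame, so the raw data is left unchanged.
    .assign(price=lambda df: df["price"] / 100)
)

print(cleaned)
```

Only the `alpha` row survives in this sample, because `beta` has no OS and `gamma` has no battery capacity.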

In [12]:
import seaborn as sns
import matplotlib.pyplot as plt

def column_to_label(column_name):
    """
    Converts a column name to a formatted label for plots.

    Parameters:
        column_name (str): Original column name

    Returns:
        str: Human-readable label
    """
    if isinstance(column_name, str):
        return " ".join(column_name.split("_")).title()
    raise TypeError("Please make sure to pass a value of type 'str'.")

def visualize_versus_price(clean_data, x):
    """
    Visualizes the relationship between a selected feature and price.

    Parameters:
        clean_data (pd.DataFrame): Cleaned smartphone data
        x (str): Column name to plot on the x-axis

    Returns:
        None
    """
    label = column_to_label(x)

    sns.scatterplot(x=x, y="price", data=clean_data, hue="os")
    plt.xlabel(label)
    plt.ylabel("Price ($)")
    plt.title(f"{label} vs. Price")
    plt.tight_layout()
    plt.show()
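
As a quick check of the label helper used above, the sketch below calls it on a few column names and on a non-string value. The function body is reproduced from this cell so the snippet runs on its own; the column names passed in are the ones kept by the cleaning step.

```python
def column_to_label(column_name):
    """Reproduced from the cell above so this snippet is self-contained."""
    if isinstance(column_name, str):
        return " ".join(column_name.split("_")).title()
    raise TypeError("Please make sure to pass a value of type 'str'.")


# Build the axis labels for a few columns in one place.
labels = {
    name: column_to_label(name)
    for name in ["battery_capacity", "avg_rating", "os"]
}

# Non-string input is rejected rather than silently converted.
try:
    column_to_label(42)
    raised_type_error = False
except TypeError:
    raised_type_error = True

print(labels, raised_type_error)
```

Note that `str.title()` turns `"os"` into `"Os"` rather than `"OS"`, so acronym-style columns may need special handling if the label casing matters.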

In [13]:
# Import required packages
import pytest
import ipytest

ipytest.config.rewrite_asserts = True
__file__ = "notebook.ipynb"


@pytest.fixture()
def clean_smartphone_data():
    """
    Fixture to load and clean smartphone data from CSV.
    """
    return prepare_smartphone_data("./data/smartphones.csv")


def test_nan_values(clean_smartphone_data):
    """
    Test that there are no NaN values in 'battery_capacity' or 'os'
    after cleaning.
    """
    assert clean_smartphone_data["battery_capacity"].isnull().sum() == 0
    assert clean_smartphone_data["os"].isnull().sum() == 0


ipytest.run("-qq")

.                                                                                            [100%]


<ExitCode.OK: 0>