In [38]:
import pandas as pd
import os
import matplotlib.pyplot as plt
import numpy as np
import copy
import unittest

In [39]:
def setElement(df, element):
    """
        Restrict the data to a single metric (a value of the Element column).
        Arguments:
            df: the data frame.
            element: String, the Element value to keep, e.g. "Production".
        Return a dataframe holding only the rows whose Element equals element.
    """
    keepRow = df['Element'].eq(element)
    return df[keepRow]

def createYearList(y1, y2):
    """
        Build the year column names ["Y<y1>", "Y<y1 + 1>", ..., "Y<y2>"].
        Arguments:
            y1: integer, the first year
            y2: integer, the last year (included)
        Return a list of "Y"-prefixed year strings from y1 to y2 inclusive;
        the list is empty when y2 < y1.
    """
    return ["Y" + str(year) for year in range(y1, y2 + 1)]

def cropData(df, agItem):
    """
        Filter the data for a specific crop/food.
        Arguments:
            df: the dataframe
            agItem: String, the crop/food to filter for. Must appear in df's "Item" column.
        Return: the agriculture data filtered by the crop/food.
        Raises ValueError if agItem does not appear in df's "Item" column.
    """
    # Validate against the frame actually being filtered instead of the global
    # itemsInData, which is defined in a later cell and always reflects the full dataset.
    isItem = df["Item"] == agItem
    if not isItem.any():
        txt = "{item} is not an item in the data frame".format(item=agItem)
        raise ValueError(txt)
    return df[isItem]

def checkYear(year):
    """
        Tell whether YEAR lies inside the span of years covered by the dataset.
        Arguments:
            year: int, the year to check
        Return True when firstYrInData <= year <= lastYrInData (both ends inclusive),
        False otherwise.
    """
    return firstYrInData <= year <= lastYrInData

def checkRange(year1, year2):
    """
        Check year range is valid. Raise an error if it's not.
        Arguments:
            year1: int, the first year of the range
            year2: int, the last year of the range
        Raises ValueError if year2 is before year1, or if either year falls outside
        firstYrInData..lastYrInData (as judged by checkYear).
    """
    if year2 < year1:
        txt = "{y2} is before {y1}".format(y2=year2, y1=year1)
        raise ValueError(txt)
    if not (checkYear(year1) and checkYear(year2)):
        # Report the dataset's bounds, not the requested years, so the caller sees what is allowed.
        raise ValueError("Years must be within the range of {y1} and {y2}".format(
            y1=firstYrInData, y2=lastYrInData))

def getYearData(df, year1, year2, keepCols):
    """
        Select the non-year columns plus the year columns Y<year1>..Y<year2>.
        Arguments:
            df: the dataframe
            year1: integer, first year column to keep; validated by checkRange against
                firstYrInData..lastYrInData.
            year2: integer, last year column to keep (inclusive); validated the same way.
            keepCols: list, the columns to keep besides the year columns (not modified)
        Return: the data restricted to keepCols followed by the year columns.
    """
    checkRange(year1, year2)
    yearColumns = createYearList(year1, year2)
    # Copy so that the caller's keepCols list is not extended.
    selectedColumns = copy.deepcopy(keepCols)
    selectedColumns.extend(yearColumns)
    return df[selectedColumns]

def subsetAgData(df, crop, y1, y2, keepCols):
    """
        Narrow the agriculture data to one crop and a span of years.
        Arguments:
            df: the dataframe.
            crop: String, the value of the Item column to keep
            y1: int, the first year column to keep
            y2: int, the last year column to keep (inclusive)
            keepCols: list of the non-year columns to keep alongside the year columns
        Return: the crop's rows restricted to keepCols plus the Y<y1>..Y<y2> columns.
    """
    return getYearData(cropData(df, crop), y1, y2, keepCols)

def getItemUnit(df):
    """
        Look up the measurement unit of the (already subset) data.
        Arguments:
            df: the dataframe; the Unit value of its first row is taken as the unit
        Return: A String for the unit of the item's measurement.
    """
    return df['Unit'].iloc[0]

def dropRegionRows(df, firstRegionCode=420):
    """
        Return the agriculture data frame without the region rows.
        Arguments:
            df: the dataframe; must have an "Area Code" column
            firstRegionCode: int, rows whose Area Code is >= this value are treated as
                region rows and dropped. Defaults to 420, the cut-off this notebook
                has always used. NOTE(review): confirm against the dataset's area-code list.
        Return: the rows with Area Code < firstRegionCode.
    """
    return df[df["Area Code"] < firstRegionCode]

def addOtherSum(df):
    """
        Collapse the rows labelled "Other" into one summed row per year.
        Arguments:
            df: The dataframe, with "Label", "Year" and "Amount" columns
        Return: a new dataframe with columns Label, Year, Amount: every non-"Other" row
        unchanged, followed by one "Other" row per year holding that year's summed Amount
        (missing amounts are skipped by the sum). The input frame is not modified.
    """
    isOther = df["Label"] == "Other"
    otherSum = df[isOther].groupby("Year", as_index=False)["Amount"].sum()
    otherSum["Label"] = "Other"
    # Select by mask rather than dropping by index label: with duplicate index labels,
    # drop() would also remove non-"Other" rows that share a label with an "Other" row.
    kept = df.loc[~isOther, ["Label", "Year", "Amount"]]
    return pd.concat([kept, otherSum], ignore_index=True)

def getTopXSubset(df, topX):
    """
        Get the top X areas by total Amount for the range of years and add a Label column.
        Areas in the top X keep their area name as the label; all others are labelled "Other".
        Arguments:
            df: the dataframe, with "Area" and "Amount" columns
            topX: int, how many areas to keep under their own name
        Return: a copy of df with a "Label" column; the caller's frame is not modified.
    """
    totals = (df.groupby(["Area"])['Amount'].sum()
                .reset_index()
                .sort_values(by=['Amount'], ascending=[False]))
    topXcountries = totals['Area'].to_numpy()[:topX]
    # assign() returns a new frame, which avoids mutating (and SettingWithCopy warnings on)
    # the caller's data; isin() replaces a per-row linear scan of the top-X array.
    isTop = df['Area'].isin(topXcountries)
    return df.assign(Label=df['Area'].where(isTop, "Other"))

def yearsToRows(df, yearColumns):
    """
        Reshape the year columns into rows (wide to long) with a pandas melt.
        Arguments:
            df: the dataset (agData); must contain the id columns
                "Area Code", "Area", "Item", "Element" and "Unit"
            yearColumns: the year columns to turn into rows
        Return: a long dataframe with the id columns plus "Year" (the former
        column name) and "Amount" (its value).
    """
    idColumns = ["Area Code", "Area", "Item", "Element", "Unit"]
    return df.melt(id_vars=idColumns, value_vars=yearColumns,
                   var_name="Year", value_name="Amount")


def findMidPoint(y1, y2):
    """
        Calculate the point halfway between two years.
        Arguments:
            y1: int (or float), the first year
            y2: int (or float), the second year
        Return: float, (y1 + y2) / 2. Uses true division, so e.g. 2019 and 2020
        give 2019.5 rather than an int.
    """
    return (y1 + y2) / 2

def getItemData(item, df, element, y1, y2, keepCols):
    """
        Narrow the data to one item and element, then move the year columns to rows.
        Arguments:
            item: String, the item to get data for.
            df: the dataframe
            element: String, one of the Element column values
            y1: int, start year
            y2: int, end year
            keepCols: list of columns to keep
        Returns a long dataframe (see yearsToRows) for the item and element over
        Y<y1>..Y<y2>, without region rows and without the "China" area.
    """
    yearColumns = createYearList(y1, y2)
    itemRows = subsetAgData(setElement(df, element), item, y1, y2, keepCols)
    areaRows = dropRegionRows(itemRows)
    # China is removed by name on top of the Area Code filter; presumably the data also
    # lists its parts separately and keeping it would double count. TODO confirm.
    chinaIndex = areaRows[areaRows["Area"] == "China"].index
    return yearsToRows(areaRows.drop(chinaIndex), yearColumns)
    

# Default data source for prepData, resolved against the working directory when this cell runs.
_DEFAULT_AG_DATA_PATH = os.path.join(os.getcwd(), "Data", "Production_Crops_Livestock_E_All_Data.csv")
# Holds the default dataframe after its first read so repeated prepData calls reuse it.
_defaultAgDataCache = {}


def _loadDefaultAgData():
    """
        Read the default agriculture CSV on first use; return the cached frame afterwards.
    """
    if "agData" not in _defaultAgDataCache:
        _defaultAgDataCache["agData"] = pd.read_csv(_DEFAULT_AG_DATA_PATH, encoding="latin-1")
    return _defaultAgDataCache["agData"]


def prepData(item, df=None, element="Production", y1=2000, y2=2020,
             keepCols=None, topX=10, useTopX=True):
    """
        Prepare the data for plotting.
        Arguments:
            item: String, the crop/food to prepare (a value of the Item column)
            df: the dataframe. When None, Data/Production_Crops_Livestock_E_All_Data.csv
                (relative to the working directory) is read on first use and reused afterwards.
            element: String, the element like "Production"
            y1: int, start year
            y2: int, end year
            keepCols: the columns besides the year columns; defaults to
                ["Area Code", "Area", "Item", "Element", "Unit"]
            topX: int, the number of top producers for the plot
            useTopX: bool, when True label areas outside the top X as "Other" and
                collapse those rows into one summed row per year
        Returns a tuple of the prepped data frame and the unit
    """
    # The default frame used to be read inside the default-argument expression, i.e. at
    # cell-definition time, even when every caller passes df (and the definition failed
    # outright if the CSV was missing). Load it lazily instead.
    if df is None:
        df = _loadDefaultAgData()
    if keepCols is None:
        keepCols = ["Area Code", "Area", "Item", "Element", "Unit"]
    df = getItemData(item, df, element, y1, y2, keepCols)
    unit = getItemUnit(df)
    if useTopX:
        df = getTopXSubset(df, topX)
        df = addOtherSum(df)
    return (df, unit)

In [40]:
# Load the full crops & livestock table once; later cells and the tests read these globals.
pathToData = os.path.join(os.getcwd(), "Data", "Production_Crops_Livestock_E_All_Data.csv")
agData = pd.read_csv(pathToData, encoding="latin-1")

# Lookup values and defaults shared by later cells.
itemsInData = sorted(agData["Item"].unique())
elementsToChooseFrom = agData["Element"].unique()
keepCols = ["Area Code", "Area", "Item", "Element", "Unit"]
fruit = "Oranges"
year1, year2 = 2000, 2020
firstYrInData, lastYrInData = 1961, 2020
topX = 10

# Small test fixtures: a few production rows restricted to the 2019 and 2020 year columns.
keepColumnsList = ["Area", "Area Code", "Unit", "Element", "Item", "Y2019", "Y2020"]

agDataOranges = agData.loc[(agData["Item"] == "Oranges") & (agData["Element"] == "Production"),
                           keepColumnsList].head(5)

agDataGrapes = agData.loc[(agData["Item"] == "Grapes") & (agData["Element"] == "Production"),
                          keepColumnsList].iloc[10:15]


In [41]:
def findMidPoint(y1, y2):
    """
        Average two years. This re-defines the findMidPoint from the earlier cell.
        Arguments:
            y1: int, the first year
            y2: int, the second year
        Return: float, the value halfway between y1 and y2 (e.g. 2019.5 for 2019 and 2020).
    """
    total = y1 + y2
    return total / 2


def plotProductionCrop(df, element, item, y1, y2, topX, plottingFunc,
                       keepCols=None, useTopX=True):
    """
        Put all the methods together and create a plot for the item from y1 to y2
        inclusive for the top X producers.
        Arguments:
            df: dataframe
            element: String, example: "Production"
            item: String, the crop/food to plot
            y1: int, start year
            y2: int, end year
            topX: int, the number of top producers for the plot
            plottingFunc: a function(ax, years, plotDict) that draws the series,
                like stackPlot or stackBarPlot
            keepCols: the columns besides the year columns; defaults to
                ["Area Code", "Area", "Item", "Element", "Unit"]
            useTopX: bool, passed through to prepData
    """
    # None sentinel instead of a shared mutable list default.
    if keepCols is None:
        keepCols = ["Area Code", "Area", "Item", "Element", "Unit"]
    preppedDf, unit = prepData(item, df, element, y1, y2, keepCols, topX, useTopX=useTopX)
    # Wording used in the plot title for each element.
    elementMap = {"Production": "Producers", "Area harvested": "Growers By Area", "Stocks": "Producers",
                  "Yield": "Most Efficient Producers"}
    createPlot(preppedDf, item, unit, topX, element, elementMap, plottingFunc)

In [42]:
def createPlot(df, item, unit, topX, element, elementMap, plottingFunc):
    """
        Draw Amount per year, one series per label, and show the figure.
        Arguments:
            df: prepared dataframe with "Year" values like "Y2019", an "Amount" column and
                a "Label" column. Falls back to "Area" when there is no "Label" column,
                e.g. when the data was prepared with useTopX=False.
            item: String, the crop/food, used in the title
            unit: String, the unit of Amount, used in the y-axis label
            topX: int, the number of top producers, used in the title
            element: String, the element such as "Production"
            elementMap: dict mapping an element to the wording used in the title;
                elements missing from it are shown as-is
            plottingFunc: function(ax, years, plotDict) that draws the series, e.g. stackPlot
        The caller's dataframe is not modified.
    """
    labelCol = 'Label' if 'Label' in df.columns else 'Area'
    # Convert "Y2019" -> 2019 on a copy; converting in place made a second call on the
    # same frame fail (ints cannot be sliced).
    yearNumbers = df['Year'].astype(str).str.lstrip('Y').astype(int)
    plotDf = df.assign(Year=yearNumbers)
    yearList = np.sort(plotDf['Year'].unique())

    # Keep first-appearance order so the legend/stack order is reproducible (set order is not).
    plotDict = {}
    for label in pd.unique(plotDf[labelCol]):
        amounts = plotDf.loc[plotDf[labelCol] == label].groupby('Year')['Amount'].sum()
        # Align every series to the same x values; missing years and NaN amounts become 0.
        plotDict[label] = amounts.reindex(yearList, fill_value=0).to_numpy()

    startYear = yearList[0]
    endRangeYear = yearList[-1]
    midYear = findMidPoint(startYear, endRangeYear)
    quarterYear = findMidPoint(startYear, midYear)
    threeQuarterYear = findMidPoint(midYear, endRangeYear)

    fig, ax = plt.subplots(figsize=(10, 10))
    plottingFunc(ax, yearList, plotDict)
    ax.ticklabel_format(useOffset=False, style='plain')
    ax.set_xticks([startYear, quarterYear, midYear, threeQuarterYear, endRangeYear])
    ax.set_title("Top {topX} {element} of {crop} between {y1} and {y2}".format(
        crop=item, element=elementMap.get(element, element),
        y1=startYear, y2=endRangeYear, topX=topX), fontsize=20.0)
    # Place the legend outside the axes, to the right.
    ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    ax.set_xlabel("Year", fontsize=15.0)
    ax.set_ylabel("{element} in {unit}".format(element=element, unit=unit), fontsize=15.0, labelpad=15.0)
    plt.show()

def stackPlot(ax, years, plotDict):
    """
        Create a stack area plot.
        Arguments:
            ax: the axis to draw on
            years: the year list (x values)
            plotDict: dictionary of {country -> [y1, y2, y3, ...]}, one series per label
    """
    # stackplot takes each series as its own positional argument (stackplot(x, *ys)).
    # numpy does not treat a dict_values view as a sequence, so passing it as a single
    # argument does not yield one row per series.
    ax.stackplot(years, *plotDict.values(), labels=list(plotDict.keys()))
    ax.set_xlim(left=years[0], right=years[-1])
    
def stackBarPlot(ax, years, plotDict):
    """
        Create a stacked bar plot: each label's bars sit on top of the previous labels'.
        Arguments:
            ax: the axis to draw on
            years: the year list (x positions)
            plotDict: dictionary of {country -> [y1, y2, y3, ...]}, drawn in insertion order
    """
    runningTop = np.zeros(len(years), dtype=int)
    for label, heights in plotDict.items():
        ax.bar(years, heights, label=label, bottom=runningTop)
        # Rebind rather than +=: the int start array must be allowed to become float.
        runningTop = runningTop + heights
    # Pad one year on each side so the outer bars are not clipped.
    ax.set_xlim(left=years[0] - 1, right=years[-1] + 1)

In [43]:
def getItemsThatHaveElement(df, element):
    """
        Find all the items that have a specified element like "Production" or "Yield".
        Arguments:
            df: dataframe
            element: String
        Return: a list of the distinct items, in order of first appearance, that have
        at least one row with the specified element.
    """
    hasElement = df['Element'] == element
    # unique() yields a numpy array; convert it so the result really is the documented
    # list (and == against a list compares as a whole, not element-wise).
    return df.loc[hasElement, 'Item'].unique().tolist()
    

In [44]:
class TestDataPrep(unittest.TestCase):
    """
        Checks prepData against the small Oranges/Grapes fixtures built in the data cell.
    """

    def test_prepData(self):
        # Top 3 of 5 orange producers, everything else collapsed into "Other".
        prepped, unit = prepData("Oranges", agDataOranges, "Production", 2019, 2020, keepCols, 3)
        self.assertCountEqual(["Algeria", "Argentina", "Australia", "Other"], prepped['Label'].unique())
        self.assertEqual("tonnes", unit)

    def test_withNA(self):
        # The grape fixture contains missing amounts; the "Other" totals must still be numbers.
        prepped = prepData("Grapes", agDataGrapes, "Production", 2019, 2020, keepCols, 2)[0]
        otherRows = prepped[prepped['Label'] == "Other"]

        def otherAmount(yearCol):
            return list(otherRows.loc[otherRows['Year'] == yearCol, "Amount"])[0]

        self.assertEqual(23549.0, otherAmount("Y2019"))
        self.assertEqual(22235.0, otherAmount("Y2020"))

class TestGetItemWithElements(unittest.TestCase):
    def test_getItemsThatHaveElement(self):
        # Reuse the agData frame loaded in the data cell instead of re-reading the whole
        # CSV for this one test.
        agDataTest = agData[agData['Item'].isin(["Oranges", "Meat, chicken"])]
        itemsWithYield = getItemsThatHaveElement(agDataTest, "Yield")
        # list() makes the comparison a plain list equality whether the helper returns a
        # list or a numpy array (an array compares element-wise and is ambiguous for >1 item).
        self.assertEqual(['Oranges'], list(itemsWithYield))
    

In [45]:
# Run the test cases defined above inside the notebook. exit=False keeps unittest from
# calling sys.exit; argv=[''] presumably stops it from parsing the kernel's own
# command-line arguments (TODO confirm).
unittest.main(argv=[''], exit=False, verbosity=2)

test_prepData (__main__.TestDataPrep) ... ok
test_withNA (__main__.TestDataPrep) ... ok
test_getItemsThatHaveElement (__main__.TestGetItemWithElements) ... ok

----------------------------------------------------------------------
Ran 3 tests in 2.583s

OK


<unittest.main.TestProgram at 0x1ad003a9400>