In [2]:
from typing import Optional, Union
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import re

In [3]:
# Load the production coffee data (snapshot as at 2023-06-26); the file is tab-separated.
# NOTE(review): the first column loads as "Unnamed: 0", which looks like a saved
# row index -- confirm whether it should be read with index_col=0 instead.
data = pd.read_csv("coffee_data.txt", sep="\t")

In [4]:
# Preview the first three rows of the loaded data.
data.head(3)

Unnamed: 0,name,country_of_origin,roastery,process,varietal,elevation,tasting_notes,added_by,date_added
0,Finca Mumuxa,Guatemala,Carrow,Washed,"Catuai, Caturra",1800.0,"chocolate brownie, melon, stone fruit",Ned,2023-03-09 14:30:11
1,San Martin,Costa Rica,Blossom,Anaerobic Honey,Catuai,,"complex, sweet, syrupy",Ned,2023-03-13 08:21:36
2,Negele Gorbitu,Ethiopia,Friedhats,Anaerobic Natural,"Kurume, Dega, Woshilo",2010.0,"strawberry, papaya, melon, roasted pineapple",Ned,2023-03-13 08:24:18


In [5]:
# Some tidying of variables.
# Every step tolerates values that have already been tidied, so running this
# cell more than once gives the same result instead of raising or blanking
# out the list columns.

# Elevation may arrive as text containing commas or as plain numbers; going
# through str handles both, and missing values convert back to NaN.
data["elevation"] = (
    data["elevation"].astype(str).str.replace(",", "", regex=False).astype(float)
)

# Split comma-separated text into lists; cells that are not text (lists from
# an earlier run, missing values) are left untouched.
data["tasting_notes"] = data["tasting_notes"].map(
    lambda cell: cell.split(", ") if isinstance(cell, str) else cell
)
data["varietal"] = data["varietal"].map(
    lambda cell: cell.split(", ") if isinstance(cell, str) else cell
)

In [9]:
class VariationCalculator:
    """
    Score how varied each entity's rows are, attribute by attribute.

    An entity is a distinct value of ``entity_col`` (for example a user). For
    every other column a variation function is applied to the entity's rows
    and to the whole frame; the entity's score is the ratio of the two.
    """

    def __init__(self, df: pd.DataFrame, entity_col: str, variation_functions: Optional[dict] = None):
        """
        Class for calculating variation rankings for users.

        Parameters
        ----------
        df : pd.DataFrame
            Input data, one row per record.
        entity_col : str
            Column identifying the entity (e.g. user) each row belongs to.
        variation_functions : dict, optional
            Maps a column name to a callable that takes that column's values
            and returns a number. Columns that are missing from the mapping,
            or mapped to None, fall back to ``unique_value_ratio``.

        Raises
        ------
        ValueError
            If ``variation_functions`` names a column that is not in ``df``.
        """
        self.df = df
        self.entity_col = entity_col
        # Work on a copy so the defaults filled in below never leak into the
        # dictionary owned by the caller.
        self.variation_functions = dict(variation_functions) if variation_functions is not None else {}

        for attribute in self.variation_functions:
            if attribute not in self.df.columns:
                raise ValueError(f"Unknown attribute '{attribute}' supplied in variation_functions.")

        # Every column other than the entity column gets a variation function.
        for attribute in self.df.columns:
            if attribute == self.entity_col:
                continue
            if self.variation_functions.get(attribute) is None:
                self.variation_functions[attribute] = self.unique_value_ratio

    def unique_value_ratio(self, X):
        """
        Default method for calculating variation of an attribute. Compares the number
        of unique values against the total number of values.

        Values must be hashable, and ``X`` must not be empty.
        """
        return len(set(X)) / len(X)

    def get_attribute_variation(self, attribute) -> dict:
        """
        Returns a variation index for each person for a given attribute.

        Parameters
        ----------
        attribute : str
            Name of the column to score.

        Returns
        -------
        dict
            A dictionary with the variation index (values) for each person (keys)
            in the input data.

        Raises
        ------
        ValueError
            If ``attribute`` is not a column of the input data.
        """
        if attribute not in self.df.columns:
            raise ValueError(f"Attribute '{attribute}' not found in data.")

        variation_function = self.variation_functions[attribute]

        # Calculate each person's variation value from the frame this
        # calculator was built with (not from any notebook-level variable).
        variation = {
            person: variation_function(group[attribute])
            for person, group in self.df.groupby(self.entity_col)
        }

        # Calculate an overall variation value across all rows.
        group_variation = variation_function(self.df[attribute])

        # Normalize by dividing each person's value by the overall value.
        return {person: index / group_variation for person, index in variation.items()}

    def get_data_variation(self, return_df: bool = False) -> Union[dict, pd.DataFrame]:
        """
        Calculate variation for all attributes.

        Parameters
        ----------
        return_df : bool, default False
            If True, return a DataFrame built from the nested dictionary
            (attributes as columns, entities as the index).

        Returns
        -------
        dict or pd.DataFrame
            ``{attribute: {entity: variation index}}``, or the same content
            as a DataFrame when ``return_df`` is True.
        """
        variation = {
            attribute: self.get_attribute_variation(attribute=attribute)
            for attribute in self.variation_functions
        }
        if return_df:
            return pd.DataFrame(variation)
        return variation

    def plot_data_variation(self, title: str = "Coffee User Variation"):
        """
        Plot radial map of variation by user.

        Draws one polar trace per entity, with one spoke per attribute, and
        shows the figure.

        Parameters
        ----------
        title : str, default "Coffee User Variation"
            Title displayed above the figure.
        """
        variation = self.get_data_variation(return_df=True)
        # Turn column names such as "tasting_notes" into "Tasting Notes".
        variation.columns = [column.replace("_", " ").title() for column in variation.columns]
        fig = go.Figure()
        # groupby drops rows whose entity is missing, so those entities have
        # no row in `variation`; skip them here to avoid a KeyError.
        for entity in self.df[self.entity_col].dropna().unique():
            fig.add_trace(go.Scatterpolar(
                r=[round(x, 2) for x in variation.loc[entity, :].tolist()],
                theta=variation.columns.tolist(),
                fill='toself',
                name=entity
            ))
        fig.update_layout(
            polar=dict(
                radialaxis=dict(
                    visible=False,
                )),
            showlegend=True,
            title=title
        )
        fig.show()


def list_unique_ratio(X):
    """
    Variation function for columns whose cells are lists.

    Flattens every (possibly nested) list in ``X`` into one sequence of
    items, then returns the number of distinct items divided by the total
    number of items. Items that are not lists are counted as they are.
    """
    def _iter_leaves(items):
        # Walk the items in order, descending into lists at any depth.
        for element in items:
            if isinstance(element, list):
                yield from _iter_leaves(element)
            else:
                yield element

    flattened = list(_iter_leaves(X))
    distinct_count = len(set(flattened))
    return distinct_count / len(flattened)


# Build the calculator with per-column overrides: elevation was converted to
# float above, so it uses variance; tasting_notes and varietal were split into
# lists, so they use the list-aware ratio. Other columns use the class default.
dc = VariationCalculator(
    df=data,
    entity_col="added_by",
    variation_functions={
        "elevation": np.var,
        "tasting_notes": list_unique_ratio,
        "varietal": list_unique_ratio,
    },
)
# Variation for a single attribute, then for every attribute as a table.
dc.get_attribute_variation(attribute="roastery")
dc.get_data_variation(return_df=True)
# Radial chart with one trace per entity.
dc.plot_data_variation()