## Preparing the Cooking Recipes Dataset

We will need some data to work with. For the purposes of this demo we will make use of the cat-edge-Cooking recipes dataset, in which every recipe is a set of ingredients labelled with a cuisine. Even though the data is simple (a recipe is just a list of ingredients), it offers enough complications to show how we can piece together embeddings using ``vectorizers``.

In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
# from src import paths
# from src.data import Dataset

In [2]:
import numpy as np
import matplotlib.colors
import seaborn as sns
import pandas as pd

In [8]:
import csv

In [3]:
def read_format_recipes(recipe_min_size=3, data_dir='../data/cat-edge-Cooking'):
    """Load the cat-edge-Cooking files and build the recipe/label/color tables.

    Four tab-separated, header-less files are read from ``data_dir``:
    ``node-labels.txt``, ``hyperedges.txt``, ``hyperedge-labels.txt`` and
    ``hyperedge-label-identities.txt``.

    Args:
        recipe_min_size: Minimum number of ingredients a recipe must have to
            be kept.
        data_dir: Directory holding the four files. Defaults to the location
            the notebook has always read from.

    Returns:
        A 5-tuple ``(recipes, recipes_label_id, ingredients_id, label_name,
        color_key)``:

        * ``recipes`` -- list of recipes, each a list of ingredient names,
          restricted to recipes with at least ``recipe_min_size`` ingredients.
        * ``recipes_label_id`` -- DataFrame with one row per kept recipe and
          the columns ``index`` (row position in the unfiltered label file)
          and ``label``.
        * ``ingredients_id`` -- DataFrame with the single column
          ``Ingredient``, indexed from 1.
        * ``label_name`` -- DataFrame indexed from 1 with the columns
          ``country`` and ``new_label`` (``'<group>.<country>'``).
        * ``color_key`` -- dict mapping every ``'<group>.<country>'`` label,
          plus ``'ingredient'``, to a hex color string.

    Raises:
        KeyError: If a recipe references an ingredient ID that is not present
            in ``node-labels.txt``.
        ValueError: If a country does not match exactly one grouped label.
    """
    from pathlib import Path

    data_dir = Path(data_dir)

    ingredients_id = pd.read_csv(data_dir / 'node-labels.txt', sep='\t', header=None)
    # Shift the default 0-based index so that it starts at 1; the recipe file
    # is looked up against this shifted index below.
    ingredients_id.index = [x+1 for x in ingredients_id.index]
    ingredients_id.columns = ['Ingredient']

    # Each row of the hyperedge file is one recipe: a list of ingredient IDs.
    with open(data_dir / 'hyperedges.txt', newline='') as hyperedges:
        recipes_with_id = list(csv.reader(hyperedges, delimiter='\t'))

    # Build the ID -> name table once; a plain dict lookup avoids creating a
    # pandas row object for every single ingredient of every recipe.
    ingredient_name = ingredients_id['Ingredient'].to_dict()
    recipes_all = [[ingredient_name[int(i)] for i in ids] for ids in recipes_with_id]

    # Keep recipes with at least `recipe_min_size` ingredients
    recipe_sizes = np.array([len(x) for x in recipes_all])
    keep_recipes = np.where(recipe_sizes >= recipe_min_size)[0]
    recipes = [recipes_all[i] for i in keep_recipes]

    recipes_label_id_all = pd.read_csv(data_dir / 'hyperedge-labels.txt', sep='\t', header=None)
    recipes_label_id_all.columns = ['label']
    # reset_index() keeps the original row position as an 'index' column.
    recipes_label_id = recipes_label_id_all.iloc[keep_recipes].reset_index()

    label_name = pd.read_csv(data_dir / 'hyperedge-label-identities.txt', sep='\t', header=None)
    label_name.columns = ['country']
    # NOTE(review): presumably the values in 'label' are 1-based IDs into this
    # table, hence the shifted index -- confirm against the data files.
    label_name.index = [x+1 for x in label_name.index]

    grps_tmp = {
        'asian' : ('chinese', 'filipino', 'japanese','korean', 'thai', 'vietnamese'),
        'american' : ('brazilian', 'mexican', 'southern_us'),
        'english' : ('british', 'irish'),
        'islands' : ('cajun_creole', 'jamaican'),
        'europe' : ('french', 'italian', 'spanish'),
        'others' : ('greek', 'indian', 'moroccan', 'russian')
    }

    grps = {key:[key+'.'+x for x in value] for key, value in grps_tmp.items()}

    # One seaborn palette per group: (group, palette name, palette size,
    # slice of the palette that is actually used). The order of this table
    # fixes the insertion order of `color_key`.
    palette_specs = (
        ('asian', "Blues", 6, slice(0, None)),
        ('american', "Purples", 4, slice(1, None)),
        ('others', "YlOrRd", 4, slice(None)),
        ('europe', "light:teal", 4, slice(1, None)),
        ('islands', "light:#660033", 4, slice(1, 3)),
        ('english', "YlGn", 4, slice(1, None)),
    )
    color_key = {}
    for group, palette, n_colors, used in palette_specs:
        for l, c in zip(grps[group], sns.color_palette(palette, n_colors)[used]):
            color_key[l] = matplotlib.colors.rgb2hex(c)
    color_key["ingredient"] = "#777777bb"

    new_names = [name for names in grps.values() for name in names]

    # Attach the grouped label to every country. Each country must match
    # exactly one grouped label; otherwise the labels would end up shifted
    # against the rows they belong to.
    new_labels = []
    for country in label_name.country:
        matches = [new_name for new_name in new_names if country in new_name]
        if len(matches) != 1:
            raise ValueError(
                f"Expected exactly one grouped label for country {country!r}, "
                f"found {matches}"
            )
        new_labels.append(matches[0])
    label_name['new_label'] = new_labels

    return(recipes, recipes_label_id, ingredients_id, label_name, color_key)

### Careful: some recipes have no ingredient left after the ingredient pruning (based on frequency)

We have inspected this and it only happens to recipes containing a single ingredient. We will remove those recipes.

In [4]:
# Cuisine families: family name -> the cuisines that belong to it.
grps_tmp = dict(
    asian=('chinese', 'filipino', 'japanese', 'korean', 'thai', 'vietnamese'),
    american=('brazilian', 'mexican', 'southern_us'),
    english=('british', 'irish'),
    islands=('cajun_creole', 'jamaican'),
    europe=('french', 'italian', 'spanish'),
    others=('greek', 'indian', 'moroccan', 'russian'),
)

# Same families, with every cuisine qualified as '<family>.<cuisine>'.
grps = {
    family: [f"{family}.{cuisine}" for cuisine in cuisines]
    for family, cuisines in grps_tmp.items()
}

# Give each family its own seaborn palette so related cuisines share a hue.
# Every row is (family, palette name, palette size, slice of shades used);
# the row order determines the insertion order of `color_key`.
color_key = {}
for family, palette_name, n_shades, used in (
    ('asian', "Blues", 6, slice(0, None)),
    ('american', "Purples", 4, slice(1, None)),
    ('others', "YlOrRd", 4, slice(None)),
    ('europe', "light:teal", 4, slice(1, None)),
    ('islands', "light:#660033", 4, slice(1, 3)),
    ('english', "YlGn", 4, slice(1, None)),
):
    shades = sns.color_palette(palette_name, n_shades)[used]
    color_key.update(
        (label, matplotlib.colors.rgb2hex(shade))
        for label, shade in zip(grps[family], shades)
    )

# Ingredients get one shared colour (hex string with an alpha component).
color_key["ingredient"] = "#777777bb"

In [5]:
color_key

{'asian.chinese': '#dbe9f6',
 'asian.filipino': '#bad6eb',
 'asian.japanese': '#89bedc',
 'asian.korean': '#539ecd',
 'asian.thai': '#2b7bba',
 'asian.vietnamese': '#0b559f',
 'american.brazilian': '#b6b6d8',
 'american.mexican': '#8683bd',
 'american.southern_us': '#61409b',
 'others.greek': '#fee187',
 'others.indian': '#feab49',
 'others.moroccan': '#fc5b2e',
 'others.russian': '#d41020',
 'europe.french': '#9bcdcd',
 'europe.italian': '#4da6a6',
 'europe.spanish': '#008080',
 'islands.cajun_creole': '#c4a0b1',
 'islands.jamaican': '#955072',
 'english.british': '#a2d88a',
 'english.irish': '#4cb063',
 'ingredient': '#777777bb'}

With a dataset and a carefully designed color palette we are in good shape to do some analysis of different embedding methods using UMAP to obtain visualizations of the embeddings. 

## Save this Dataset
Let's save this as a dataset for easy re-use in our other notebooks, and add the color palette to the metadata of the dataset. 

Note: This Dataset has already been added to the catalog and the following cells do not need to be run again. They are included here as a reference.

In [9]:
# from src.helpers import notebook_as_transformer

In [10]:
# new_dataset_name = f'{ds_in.name}_pruned'
# new_data = news_data
# new_target = targets
# new_metadata = ds_in.metadata.copy()
# new_metadata['color_key'] = color_key
# added_descr_txt = f"""\n This dataset is a subselection of the {ds_in.name} Dataset where we have pruned out any post less than {prune_limit} \
# characters ({prune_limit} is chosen somewhat arbitrarily). A custom `color_key` can be found in the metadata."""
# new_metadata['descr'] += added_descr_txt

# new_ds = Dataset(dataset_name=new_dataset_name, data=new_data, target=new_target,
#                  metadata=new_metadata)


In [11]:
# # Due to various design choices in Jupyter, we need to specify this name manually.
# nbname = '00-20-newsgroups-setup.ipynb'
# dsdict = notebook_as_transformer(notebook_name=nbname,
#                                  input_datasets=[ds_in],
#                                  output_datasets=[new_ds],
#                                  overwrite_catalog=True)

