## Preparing the 20 Newsgroups Dataset

We need some data to work with. For this demo we use the 20 newsgroups dataset. Although 20 newsgroups is a toy dataset, it has enough complications to show how we can piece together embeddings using ``vectorizers``.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from src import paths
from src.data import Dataset

2023-03-30 14:06:28,392 - utils - INFO - Note: NumExpr detected 32 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2023-03-30 14:06:28,393 - utils - INFO - NumExpr defaulting to 8 threads.


In [3]:
import numpy as np
import matplotlib.colors
import seaborn as sns

In [4]:
ds_in = Dataset.load('20_newsgroups')
print(ds_in.DESCR)

2023-03-30 14:06:32,392 - datasets - INFO - Generated output datasets: ['20_newsgroups'] via edge:'_20_newsgroups'



The 20 Newsgroups dataset is a collection of approximately 20,000 newsgroup documents, partitioned (nearly) evenly across 20 different newsgroups.

The data is organized into 20 different newsgroups, each corresponding to a different topic. Some of the newsgroups are very closely related to each other (e.g. comp.sys.ibm.pc.hardware / comp.sys.mac.hardware), while others are highly unrelated (e.g. misc.forsale / soc.religion.christian).

Here are the categories:

 * `alt.atheism`,
 * `comp.graphics`,
 * `comp.os.ms-windows.misc`,
 * `comp.sys.ibm.pc.hardware`,
 * `comp.sys.mac.hardware`,
 * `comp.windows.x`,
 * `misc.forsale`,
 * `rec.autos`,
 * `rec.motorcycles`,
 * `rec.sport.baseball`,
 * `rec.sport.hockey`,
 * `sci.crypt`,
 * `sci.electronics`,
 * `sci.med`,
 * `sci.space`,
 * `soc.religion.christian`,
 * `talk.politics.guns`,
 * `talk.politics.mideast`,
 * `talk.politics.misc`,
 * `talk.religion.misc`

The current version is obtained by wrapping `sklearn.datasets.fetch_20newsgroups`

First, we will do a little pruning: the 20 newsgroups dataset contains some extremely short posts (once the headers, footers and quotes are removed) that would muddy the results. We keep only posts longer than 200 characters; the threshold of 200 is chosen somewhat arbitrarily.

In [5]:
prune_limit = 200

In [6]:
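# Boolean mask: True for posts longer than prune_limit characters.
long_enough = [len(t) > prune_limit for t in ds_in.data]
# Apply the same mask to the targets and the posts so they stay aligned.
targets = np.array(ds_in.target)[long_enough]
news_data = [t for t, keep in zip(ds_in.data, long_enough) if keep]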
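
The key property of the pruning step is that posts and targets are filtered together, so each remaining post keeps its own label. A minimal sketch of that idea on toy data follows; `prune_short` and the toy values are hypothetical stand-ins for illustration, not part of the notebook's `Dataset` API.

```python
def prune_short(texts, labels, limit=200):
    """Keep only (text, label) pairs whose text is longer than `limit` characters."""
    kept = [(t, y) for t, y in zip(texts, labels) if len(t) > limit]
    kept_texts = [t for t, _ in kept]
    kept_labels = [y for _, y in kept]
    return kept_texts, kept_labels


# Toy stand-ins for ds_in.data and ds_in.target (hypothetical values).
toy_texts = ["short post", "x" * 200, "y" * 201]
toy_labels = ["sci.med", "rec.autos", "sci.space"]

pruned_texts, pruned_labels = prune_short(toy_texts, toy_labels, limit=200)
print(len(pruned_texts), pruned_labels)
```

Because the comparison is strict, a post of exactly 200 characters is dropped along with the shorter ones.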

For each newsgroup post, the target is the name of the newsgroup the post was sent to. The specific newsgroups fall into broader groups, such as religion, politics, computing, sport and science. While some of the broad groups can be gleaned programmatically from the newsgroup name (with its dotted hierarchy), others require more care (for example, alt.atheism belongs with the religion topic). We will hand-curate the newsgroups into 6 broad categories:
* religion
* politics
* sport
* comp
* sci
* misc
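
To see why the dotted hierarchy alone is not enough, here is a small sketch that groups the 20 names by their top-level prefix. The helper `top_level` is a hypothetical name introduced for illustration only.

```python
from collections import defaultdict

newsgroups = [
    "alt.atheism", "comp.graphics", "comp.os.ms-windows.misc",
    "comp.sys.ibm.pc.hardware", "comp.sys.mac.hardware", "comp.windows.x",
    "misc.forsale", "rec.autos", "rec.motorcycles", "rec.sport.baseball",
    "rec.sport.hockey", "sci.crypt", "sci.electronics", "sci.med",
    "sci.space", "soc.religion.christian", "talk.politics.guns",
    "talk.politics.mideast", "talk.politics.misc", "talk.religion.misc",
]


def top_level(name):
    """Return the first component of a dotted newsgroup name."""
    return name.split(".", 1)[0]


# Group the newsgroup names by their top-level prefix.
by_prefix = defaultdict(list)
for name in newsgroups:
    by_prefix[top_level(name)].append(name)

for prefix, names in sorted(by_prefix.items()):
    print(prefix, len(names))
```

The prefixes give seven groups rather than six, and they do not line up with the topics: the religion newsgroups are scattered across `alt`, `soc` and `talk`, `talk` mixes politics with religion, and `rec` mixes sport with autos and motorcycles.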

Using these broad categories, we will create a custom color palette for visualizing results, so that different newsgroups in the same category get similar colors.

In [7]:
religion = ("alt.atheism", "talk.religion.misc", "soc.religion.christian")
politics = ("talk.politics.misc", "talk.politics.mideast", "talk.politics.guns")
sport = ("rec.sport.baseball", "rec.sport.hockey")
comp = (
    "comp.graphics",
    "comp.os.ms-windows.misc",
    "comp.sys.ibm.pc.hardware",
    "comp.sys.mac.hardware",
    "comp.windows.x",
)
sci = (
    "sci.crypt",
    "sci.electronics",
    "sci.med",
    "sci.space",
)
misc = (
    "misc.forsale",
    "rec.autos",
    "rec.motorcycles",
)
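
Since the grouping is curated by hand, it is easy to forget a newsgroup or list one twice. One possible sanity check is sketched below on a deliberately incomplete toy grouping; `check_partition` is a hypothetical helper, not something the notebook defines.

```python
def check_partition(groups, expected):
    """Return (missing, duplicated) labels for a dict of group name -> labels."""
    seen = set()
    duplicated = set()
    for labels in groups.values():
        for label in labels:
            if label in seen:
                duplicated.add(label)
            seen.add(label)
    missing = set(expected) - seen
    return missing, duplicated


# Toy example: "rec.motorcycles" is deliberately left out of the grouping.
toy_groups = {
    "sport": ("rec.sport.baseball", "rec.sport.hockey"),
    "misc": ("misc.forsale", "rec.autos"),
}
toy_expected = [
    "rec.sport.baseball", "rec.sport.hockey",
    "misc.forsale", "rec.autos", "rec.motorcycles",
]

missing, duplicated = check_partition(toy_groups, toy_expected)
print("missing:", missing, "duplicated:", duplicated)
```

In the notebook, the same check could be run with the six tuples above as `groups` and the unique values of the targets as `expected`; both sets should come back empty.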

In [8]:
color_key = {}
# Each broad category draws its shades from one seaborn palette.
for label, c in zip(religion, sns.color_palette("Blues", 4)[1:]):
    color_key[label] = matplotlib.colors.rgb2hex(c)
for label, c in zip(politics, sns.color_palette("Purples", 4)[1:]):
    color_key[label] = matplotlib.colors.rgb2hex(c)
for label, c in zip(comp, sns.color_palette("YlOrRd", 5)):
    color_key[label] = matplotlib.colors.rgb2hex(c)
for label, c in zip(sci, sns.color_palette("light:teal", 5)[1:]):
    color_key[label] = matplotlib.colors.rgb2hex(c)
for label, c in zip(sport, sns.color_palette("light:#660033", 4)[1:3]):
    color_key[label] = matplotlib.colors.rgb2hex(c)
for label, c in zip(misc, sns.color_palette("YlGn", 4)[1:]):
    color_key[label] = matplotlib.colors.rgb2hex(c)
color_key["word"] = "#777777bb"
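
The last entry, `"#777777bb"`, has eight hex digits rather than six; matplotlib reads the final two digits as the alpha (opacity) channel. The sketch below decodes such a string using only the standard library; `hex_to_rgba` is a hypothetical helper for illustration, not a matplotlib function.

```python
def hex_to_rgba(hex_color):
    """Convert '#RRGGBB' or '#RRGGBBAA' to a tuple of floats in [0, 1]."""
    digits = hex_color.lstrip("#")
    if len(digits) == 6:
        digits += "ff"  # no alpha given: treat the color as fully opaque
    if len(digits) != 8:
        raise ValueError(f"expected 6 or 8 hex digits, got {hex_color!r}")
    return tuple(int(digits[i:i + 2], 16) / 255 for i in range(0, 8, 2))


r, g, b, a = hex_to_rgba("#777777bb")
print(round(r, 3), round(g, 3), round(b, 3), round(a, 3))
```

The result is a mid-gray at roughly 73% opacity.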

With a dataset and a carefully designed color palette, we are in good shape to analyze different embedding methods, using UMAP to visualize the embeddings.

## Save this Dataset
Let's save this as a dataset for easy re-use in our other notebooks, and add the color palette to the metadata of the dataset. 

Note: This Dataset has already been added to the catalog and the following cells do not need to be run again. They are included here as a reference.

In [9]:
from src.helpers import notebook_as_transformer

In [10]:
new_dataset_name = f'{ds_in.name}_pruned'
new_data = news_data
new_target = targets
new_metadata = ds_in.metadata.copy()
new_metadata['color_key'] = color_key
added_descr_txt = f"""\n This dataset is a subselection of the {ds_in.name} Dataset where we have pruned out any post of {prune_limit} \
characters or fewer ({prune_limit} is chosen somewhat arbitrarily). A custom `color_key` can be found in the metadata."""
new_metadata['descr'] += added_descr_txt

new_ds = Dataset(dataset_name=new_dataset_name, data=new_data, target=new_target,
                 metadata=new_metadata)


In [11]:
# Due to various design choices in Jupyter, we need to specify this name manually.
nbname = '00-20-newsgroups-setup.ipynb'
dsdict = notebook_as_transformer(notebook_name=nbname,
                                 input_datasets=[ds_in],
                                 output_datasets=[new_ds],
                                 overwrite_catalog=True)

