In [None]:
import pathlib as pl

import glob
import pandas as pd
import numpy as np

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

import itertools, warnings, pickle, time, os

from collections import Counter, defaultdict
from typing import List, Dict, Tuple

# -----------------------------------------------------------------------
# CAMERA-READY PLOTTING (thanks Alex Boyd!)
# -----------------------------------------------------------------------
# The following code is borrowed from material provided by Alex!
FULL_WIDTH = 5.50107
COL_WIDTH  = 4.50461


# Put at top of plotting script (requires tex be installed though)
matplotlib.rc('font', family='serif', size=20)
matplotlib.rc('text', usetex=True)


def adjust(fig, left=0.0, right=1.0, bottom=0.0, top=1.0, wspace=0.0, hspace=0.0):
    """Set the subplot layout of ``fig`` in a single call.

    All arguments are forwarded unchanged to ``fig.subplots_adjust``:

    - ``left`` / ``right`` / ``bottom`` / ``top``: the sides of the subplots
      of the figure.
    - ``wspace`` / ``hspace``: the amount of width / height reserved for blank
      space between subplots.
    """
    layout = dict(
        left=left,
        right=right,
        bottom=bottom,
        top=top,
        wspace=wspace,
        hspace=hspace,
    )
    fig.subplots_adjust(**layout)
    

def save_fig(fig, name, **kwargs):
    """Save ``fig`` as ``./camera_ready/images/{name}.pdf`` with a tight bounding box.

    The output directory is created if it does not exist yet, so the call
    also works from a fresh checkout. Any extra keyword arguments are
    forwarded to ``fig.savefig``.
    """
    out_path = f"./camera_ready/images/{name}.pdf"
    # `savefig` does not create missing parent directories.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path, bbox_inches="tight", **kwargs)


# Axes formatting
from matplotlib.ticker import MultipleLocator, PercentFormatter


# Accessibility
sns.set_palette(sns.color_palette("colorblind"))
matplotlib.rcParams["axes.prop_cycle"] = matplotlib.cycler(color=sns.color_palette("colorblind"))


# Composite plots 
def disable_axis(ax):
    """Blank out ``ax``: no ticks, no tick labels, spine colour set to ``None``.

    The axes object is also pushed behind other artists via a very low z-order.
    """
    ax.set_zorder(-100)  # Avoids a visual rendering bug
    # Clear ticks and tick labels on the x axis first, then on the y axis.
    for set_ticks, set_labels in (
        (ax.set_xticks, ax.set_xticklabels),
        (ax.set_yticks, ax.set_yticklabels),
    ):
        set_ticks([])
        set_labels([])
    plt.setp(ax.spines.values(), color=None)

# Dataset generation - Stage 1.  Word Selection

This notebook represents the very first step in our data generation pipeline. Since our goal is to create a dataset that is gender-invariant and free of gender co-occurring words, we will have to make sure that the words we use to bootstrap the generation of our dataset are, themselves, abiding by the properties we defined. In particular, we want to make sure that the words selected by our procedure satisfy the following property:

$$ \delta(w) = \texttt{PMI}(w, \texttt{"she"}) - \texttt{PMI}(w, \texttt{"he"}) \in [-\eta, \eta] $$, where $w$ is a word in the vocabulary and $\eta$ is a limit on how skewed a word can be towards either of the gendered words. As detailed in the paper, we first compute the $\texttt{PMI}$ values and then empirically bin the distribution into 20 symmetric bins around the origin.



The notebook is organized as follows:
1. Read the co-occurrence counts from PILE as well as the term-frequencies, as collected by [Razeghi et al (2022)](https://aclanthology.org/2022.emnlp-demos.39/).
2. Preprocess the list of words to remove non-English words.
3. Preprocess the remaining words and remove rare words (e.g., words whose count is at or below the 20th percentile of the counts).
4. Compute the $\delta(w) = \texttt{PMI}(w, \texttt{"she"}) - \texttt{PMI}(w, \texttt{"he"})$ value for every word $w$
5. Sample a subset centered around the origin by sampling words that satisfy  $ -\eta \leq \delta(w) \leq \eta$

In [None]:
# Base directory where to find the files
DATA_DIR = "/extra/ucinlp1/cbelem/bias-dataset-project"

## 1. Load term counts

In [None]:
def read_pkl_file(fp: str):
    """Load and return the object pickled at ``fp``.

    Prints the path before reading and the elapsed wall-clock time (in
    minutes) afterwards.

    NOTE: unpickling can execute arbitrary code; only use this on files from
    a trusted source.
    """
    print("Reading file at", fp)
    started_at = time.time()
    with open(fp, 'rb') as handle:
        payload = pickle.load(handle)
    elapsed_minutes = (time.time() - started_at) / 60
    print(f"Time to read file {elapsed_minutes:.2f} min")
    return payload


def read_original_coccurrence_files(parent_dir: str) -> dict: # 5GB
    """Unpickle and return ``all_co_words.pkl`` (co-occurrence counts) from ``parent_dir``."""
    cooccurrence_path = f"{parent_dir}/all_co_words.pkl"
    return read_pkl_file(cooccurrence_path)


def read_original_tf_files(parent_dir) -> dict: # 16M
    """Unpickle and return ``term_frequency.pkl`` (term counts) from ``parent_dir``."""
    term_frequency_path = f"{parent_dir}/term_frequency.pkl"
    return read_pkl_file(term_frequency_path)


def read_pmi_diff(filepath: str) -> pd.DataFrame: #1.9M
    """Read a precomputed PMI-difference file into a sorted DataFrame.

    Each non-blank line is expected to be ``<word>,<value>``. The line is
    split on its LAST comma, so a word may itself contain commas. Blank lines
    (e.g. a trailing newline at the end of the file) are skipped.

    Returns:
        A DataFrame with columns ``word`` and ``pmi_diff``, sorted by
        ascending ``pmi_diff`` with a fresh 0..n-1 index.

    Raises:
        ValueError: if the text after the last comma is not a valid float.
    """
    words, values = [], []
    with open(filepath, "rt") as f:
        for row in f:
            if not row.strip():
                # Blank line: nothing to parse (float("") would raise).
                continue
            word, _, val = row.rpartition(",")
            words.append(word)
            values.append(float(val))  # float() tolerates the trailing newline
    pmi_diff = pd.DataFrame({"word": words, "pmi_diff": values})
    return pmi_diff.sort_values("pmi_diff").reset_index(drop=True)

In [None]:
# Dict form: direct word -> count lookup
TERM_COUNTS_DICT = read_original_tf_files(DATA_DIR)
TERM_COUNTS_TOTAL = sum(TERM_COUNTS_DICT.values())
total_counts = sum(TERM_COUNTS_DICT.values())

# DataFrame form: one row per word, so that more metadata columns can be attached
TERM_COUNTS_DF = pd.DataFrame(TERM_COUNTS_DICT.items(), columns=["word", "counts"])

# Relative frequency of every word (count divided by the sum of all counts)
TERM_COUNTS_DF["freq"] = TERM_COUNTS_DF["counts"] / total_counts
TERM_COUNTS_DF.head(5)

## Preprocessing data

**Keeping English Alphabet words:**
In this section, we wish to exclude numbers, punctuation, non-english words from the list of words. Therefore, a first preprocessing step we do is to exclude any word that is not fully created based on the English alphabet. We use Python's default functionality `str.isalpha` to achieve this.

In [None]:
# Flag the words made only of alphabetic characters.
# NOTE(review): `str.isalpha` is Unicode-aware, so words with accented or
# non-Latin letters also pass this check; they are only removed by the
# language filter applied later on -- confirm this is the intended behaviour.
TERM_COUNTS_DF['isalpha'] = TERM_COUNTS_DF["word"].apply(str.isalpha)

# Summing the boolean column counts the True values and is 0 (rather than a
# KeyError) when no word qualifies.
english_alphabet = TERM_COUNTS_DF["isalpha"].sum()
print(f'{english_alphabet/len(TERM_COUNTS_DF):.2%} of the examples',
      f'(out of {len(TERM_COUNTS_DF)}) belong to the English alphabet.')

# Drop words containing non-alphabetic characters. `.copy()` makes the result
# an independent frame, so the columns added to it by later cells do not
# trigger pandas' SettingWithCopyWarning.
TERM_COUNTS_DF = TERM_COUNTS_DF[TERM_COUNTS_DF["isalpha"]].copy()
TERM_COUNTS_DF.tail(10)

**Removing non-English words**: However, restricting to the English alphabet does not exclude other languages. For example, the spanish word "echaban" would be kept if we only applied the previous procedure. One hypothesis would be to remove the non-English words, using a language detector. However, this may also remove borrowed foreign words that are common in the English language, like _influenza_. 

Therefore, in the following cells, we will use a heuristic approach that keeps a word in `TERM_COUNTS_DF` if one of the following conditions is satisfied:

1. The [fasttext](https://fasttext.cc/docs/en/unsupervised-tutorial.html) character-level language classifier predicts the word to be English with a confidence above `ENGLISH_PRED_THRESHOLD` (a probability in $[0, 1]$).
2. There exists a sense definition for word $w$ in [WordNet](https://wordnet.princeton.edu/).


Note: We also experimented with the [langdetect](https://pypi.org/project/langdetect/) library from Google, but it performs poorly when identifying individual words (e.g., it claims that _hello_ is not English), which leads to a large number of false negatives (i.e., English words flagged as non-English). On the other hand, fasttext proved to be much better at this task. Besides, it also gives the confidence associated with each prediction.

In [None]:
try:
    import fasttext
except:
    # Install fasttext
    !pip install fasttext
    import fasttext
from typing import List, Tuple 

    
FTEXT_MODEL_NAME = "lid.176.bin"
# Download the language detection model (trained w/ 176 languages)
if not os.path.isfile(FTEXT_MODEL_NAME):
    !wget https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin
        
# Load fasttext model
FTEXT_MODEL = fasttext.load_model(FTEXT_MODEL_NAME)

# Language threshold
ENGLISH_PRED_THRESHOLD = 0.6

In [None]:
def fasttext_predict(word: str, model):
    """Return ``(language, confidence)`` for the model's top-1 prediction on ``word``.

    ``model.predict(word, k=1)`` is expected to return a pair
    ``(labels, confidences)``; the ``__label__`` prefix is removed from the
    first label.
    """
    prediction = model.predict(word, k=1)
    top_label = prediction[0][0]
    top_confidence = prediction[1][0]
    return top_label.replace("__label__", ""), top_confidence


# Run the language identifier on every remaining word; each call yields a
# (language, confidence) pair, which is split into two parallel columns.
ft_predictions = TERM_COUNTS_DF["word"].apply(fasttext_predict, model=FTEXT_MODEL)
ft_languages, ft_confidences = zip(*ft_predictions)
TERM_COUNTS_DF["ft_pred_lang"] = ft_languages
TERM_COUNTS_DF["ft_pred_conf"] = ft_confidences

# Summary: share of words predicted as English and number of distinct languages
pred_eng_counts = TERM_COUNTS_DF["ft_pred_lang"].value_counts()["en"]
english_share = pred_eng_counts/len(TERM_COUNTS_DF)
print(f'{english_share:.2%} of the words are predicted to be English')
num_languages = TERM_COUNTS_DF["ft_pred_lang"].nunique()
print(f'Number of unique predicted languages: {num_languages}')
TERM_COUNTS_DF["ft_pred_lang"].value_counts()

In [None]:
try:
    import nltk
except:
    # Install fasttext
    !pip install nltk
    import nltk
from nltk.corpus import wordnet

# Count the number of wordnet senses
TERM_COUNTS_DF["wordnet_counts"] = TERM_COUNTS_DF["word"].apply(lambda x: len(wordnet.synsets(x)))
(TERM_COUNTS_DF["wordnet_counts"].value_counts() / len(TERM_COUNTS_DF)).head(10)

In [None]:
def english_words_mask(df: pd.DataFrame, threshold) -> pd.Series:
    """Return a boolean mask over the rows of ``df`` that are kept as English.

    A row is kept when at least one of the following holds:

    1. ``ft_pred_lang`` is ``"en"`` and ``ft_pred_conf`` is strictly greater
       than ``threshold``;
    2. ``wordnet_counts`` is greater than 0.

    Args:
        df: frame with columns ``ft_pred_lang``, ``ft_pred_conf`` and
            ``wordnet_counts``.
        threshold: confidence that an ``"en"`` prediction must exceed.

    Returns:
        A boolean Series aligned with ``df``'s index.
    """
    confident_english = (df["ft_pred_lang"] == "en") & (df["ft_pred_conf"] > threshold)
    has_wordnet_sense = df["wordnet_counts"] > 0
    return confident_english | has_wordnet_sense


IS_ENGLISH_MASK = english_words_mask(TERM_COUNTS_DF, ENGLISH_PRED_THRESHOLD)

# Words rejected by the mask, ordered alphabetically for inspection
nonengl_terms_orig_pmi = TERM_COUNTS_DF.loc[~IS_ENGLISH_MASK].sort_values("word")
print("Total number of non english words:", len(nonengl_terms_orig_pmi))

# Show every 2000th rejected word as a spot check
print("Examples of words dropped due to being dubbed not english according to our procedure...")
dropped_examples = nonengl_terms_orig_pmi["word"].values[::2000]
print("-", "\n- ".join(dropped_examples))

# Keep only the words accepted by the mask
TERM_COUNTS_DF = TERM_COUNTS_DF.loc[IS_ENGLISH_MASK]
len(TERM_COUNTS_DF)

## Preprocess: Remove rare words

Upon inspection of the remaining words, we observe that some of them (e.g., "succinylacetone", "clientage") are valid but uncommon English words and that others are typos. Since these words are not common English words, they could throw off our model during dataset generation (e.g., degrade the generations or reduce their diversity).

To account for this, we notice that these words are often rarer and occur less frequently in the dataset. As such, we decided to remove them.

In [None]:
TERM_COUNTS_DF

In [None]:
# Plot the term counts vs rank of english words
# NOTE(review): the x value is the row position in TERM_COUNTS_DF, which is a
# frequency rank only if the rows are ordered by decreasing count -- confirm.
sns.lineplot(x=np.arange(len(TERM_COUNTS_DF)), y=TERM_COUNTS_DF["counts"].values)
plt.xscale("log"); plt.xlabel("Term rank")
plt.yscale("log"); plt.ylabel("Term counts")

# Count value at the q-quantile; `q` and `q_val` are reused by the following
# cells as the rare-word cutoff.
q = 0.2
q_val = TERM_COUNTS_DF["counts"].quantile(q)
# Dashed red horizontal line marking the cutoff
plt.axhline(q_val, label=f"{q:.0%} quantile: {q_val}", ls="--", c="r")
plt.legend()
plt.show()

In [None]:
# Words that the rare-word filter removes. The filter keeps only words with
# `counts > q_val`, so words tied exactly at the quantile value are removed as
# well and must be included here (`<=`) for this report to match the filter.
low_freq_terms_alpha = TERM_COUNTS_DF[TERM_COUNTS_DF["counts"] <= q_val].sort_values("counts", ascending=False)
print("Total number of low freq words:", len(low_freq_terms_alpha))

# Show every 2000th removed word as a spot check
print("Examples of words with higher rank (lower frequency):")
print("-", "\n- ".join(low_freq_terms_alpha["word"].values[::2000]))

In [None]:
# Keep only the words whose count is strictly above the quantile cutoff;
# words whose count equals `q_val` are dropped too.
TERM_COUNTS_DF = TERM_COUNTS_DF[TERM_COUNTS_DF["counts"] > q_val]
TERM_COUNTS_DF.tail(20)

`TERM_COUNTS_DF` now contains a good estimate of the English terms that are likely to occur in the data. A few limitations that we should address in the future are:

1. remove slang (e.g., making use of [SlangNet](https://aclanthology.org/L16-1686/), [SlangSD](http://liangwu.me/slangsd/), etc.)
2. remove names
3. remove abbreviations


For now, we will proceed, assuming this is the best subset of the English words we can derive from PILE.

## Computing the PMI for each word

In this section, we will compute the $\texttt{PMIDiff}(w)$ for every word $w$. To that end, we will first define a list of gendered words (eg, "mother", "father", "boy", "girl") and we will compute the $\texttt{PMI}$ of every word and these _group words_. Note that the co-occurrence counts loaded in `TERMS_CO_OCCUR`, consist of counts within a window size 10 after stop words have been removed. These do not refer to co-occurrence counts within the same document. 

$$\texttt{PMI}(w, g) = \log \frac{p(w, g)}{p(w)p(g)}$$, where $w$ is a word in the vocabulary and $g$ is a group word. PMI therefore represents the strength of association between the two words, namely, how likely the two words are to co-occur compared to what would be expected if they appeared independently. A negative value indicates that the words co-occur less often than expected under independence, whereas a positive value indicates that they co-occur more often than expected.


In [None]:
from collections import defaultdict

In [None]:
# This document is huge
TERMS_CO_OCCUR = read_original_coccurrence_files(DATA_DIR)
print("Number of bigrams:", len(TERMS_CO_OCCUR))
# Sum of all co-occurrence counts. It is computed here, over the full table,
# i.e. before TERMS_CO_OCCUR is reduced to the relevant pairs further below.
TERMS_CO_OCCUR_TOTAL = sum(TERMS_CO_OCCUR.values()) # 131M


GROUP_TERMS = [
    ("she", "he"),
    ("her", "his"),
    ("her", "him"),
    ("hers", "his"),
    ("herself", "himself"),
    ("grandmother", "grandfather"),
    ("grandma", "grandpa"),
    ("stepmother", "stepfather"),
    ("stepmom", "stepdad"),
    ("mother", "father"),
    ("mom", "dad"),
    ("aunt", "uncle"),
    ("aunts", "uncles"),
    ("mummy", "daddy"),
    ("sister", "brother"),
    ("sisters", "brothers"),
    ("daughter", "son"),
    ("daughters", "sons"),
    ("female", "male"),
    ("females", "males"),
    ("feminine", "masculine"),
    ("woman", "man"),
    ("women", "men"),
    ("madam", "sir"),
    ("matriarchy", "patriarchy"),
    ("girl", "boy"),
    ("lass", "lad"),
    ("girls", "boys"),
    ("girlfriend", "boyfriend"),
    ("girlfriends", "boyfriends"),
    ("wife", "husband"),
    ("wives", "husbands"),
    ("queen", "king"),
    ("queens", "kings"),
    ("princess", "prince"),
    ("princesses", "princes"),
    ("lady", "lord"),
    ("ladies", "lords"),
]
FEMALE_TERMS, MALE_TERMS = zip(*GROUP_TERMS)

# Every word we need counts for: the filtered vocabulary plus all gendered
# group words (a group word may have been removed from TERM_COUNTS_DF by the
# preprocessing filters, but its counts are still required to compute PMI).
ALL_TERMS = set(TERM_COUNTS_DF.word.values)
# `update` inserts each word individually; `add` would insert the whole tuple
# as a single (never matching) element.
ALL_TERMS.update(FEMALE_TERMS)
ALL_TERMS.update(MALE_TERMS)

# Since we're interested in computing the PMI value for every word and other K words
# we will have to iterate it at least k times (which would be time consuming)
# Therefore, we will filter out the structure to include only pairs where terms defined in 
# `terms` or group words appear.
def select_subset(bigram_counts: dict, terms: set) -> dict:
    """Keep the bigrams in which at least one of the two words is in ``terms``.

    Args:
        bigram_counts: mapping ``(word_a, word_b) -> count``.
        terms: words of interest.

    Returns:
        A new dict with the matching entries, in their original order.
    """
    return {
        pair: count
        for pair, count in bigram_counts.items()
        if pair[0] in terms or pair[1] in terms
    }


# Update term counts dict to contain only the relevant terms
# Keep only the counts of the words in ALL_TERMS. TERM_COUNTS_TOTAL was
# computed earlier from the full dict and is not changed by this filtering.
TERM_COUNTS_DICT = {w: v for w, v in TERM_COUNTS_DICT.items() if w in ALL_TERMS}
# NOTE: this expression is not the last one in the cell, so its value is not displayed.
len(TERM_COUNTS_DICT), len(TERM_COUNTS_DF)

# Update terms co-occurs: keep pairs where at least one word is in ALL_TERMS
TERMS_CO_OCCUR = select_subset(TERMS_CO_OCCUR, ALL_TERMS)
print("Reduced number of bigrams:", len(TERMS_CO_OCCUR)) # roughly 113M pairs remaining

In [None]:
def compute_pmi(unigram_counts: dict, bigram_counts: dict, w, g, unigram_total: int, bigram_total: int):
    """Pointwise mutual information between words ``w`` and ``g``.

    ``p(w)`` and ``p(g)`` are the unigram counts divided by ``unigram_total``;
    ``p(w, g)`` is the sum of the counts stored under ``(w, g)`` and
    ``(g, w)`` divided by ``bigram_total``. Missing entries count as 0.

    Returns:
        ``log p(w, g) - log p(w) - log p(g)``, or ``None`` when any of the
        three probabilities is 0.
    """
    p_w = unigram_counts.get(w, 0) / unigram_total
    p_g = unigram_counts.get(g, 0) / unigram_total

    # The pair is looked up in both orders.
    # NOTE(review): when w == g both lookups hit the same key, so that count
    # is added twice -- confirm whether this is intended.
    pair_count = bigram_counts.get((w, g), 0) + bigram_counts.get((g, w), 0)
    p_joint = pair_count / bigram_total

    if any(p == 0 for p in (p_w, p_g, p_joint)):
        return None

    # A difference of logs is used instead of the log of the ratio for
    # numerical stability.
    return np.log(p_joint) - np.log(p_w) - np.log(p_g)


def compute_pmi_per_group_word(words: List[str], group_words: List[str]):
    """Compute the PMI of every word in ``words`` with every distinct group word.

    Counts and totals are read from the module-level ``TERM_COUNTS_DICT``,
    ``TERMS_CO_OCCUR``, ``TERM_COUNTS_TOTAL`` and ``TERMS_CO_OCCUR_TOTAL``.

    Returns:
        A ``defaultdict(list)`` mapping ``"pmi_<group word>"`` to a list of
        PMI values (``None`` where undefined), aligned with ``words``.
    """
    pmi_columns = defaultdict(list)
    # Duplicated group words are only processed once.
    for g in set(group_words):
        column_name = f"pmi_{g}"
        for w in words:
            value = compute_pmi(
                TERM_COUNTS_DICT,
                TERMS_CO_OCCUR,
                w,
                g,
                TERM_COUNTS_TOTAL,
                TERMS_CO_OCCUR_TOTAL,
            )
            pmi_columns[column_name].append(value)

    return pmi_columns

In [None]:
# Compute the PMI between every word and every female word
PMI_FEMALE = compute_pmi_per_group_word(TERM_COUNTS_DF["word"].values.tolist(), FEMALE_TERMS)

# Compute the PMI between every word and every male word
PMI_MALE = compute_pmi_per_group_word(TERM_COUNTS_DF["word"].values.tolist(), MALE_TERMS)

In [None]:
len(PMI_FEMALE), len(PMI_MALE)

### Computing the PMI difference:

To obtain a sense of how much more likely is a word to co-occur with female words than with male words, we can compute the difference of PMIs as follows:

$$\delta(w, g_F, g_M) = \texttt{PMI}(w, g_F) - \texttt{PMI}(w, g_M)$$, where $g_M$ and $g_F$ represent male and female gendered words, respectively.

In the original version of this work, we simply determined the gendered co-occurrence of a word by computing $\delta(w, \texttt{"she"}, \texttt{"he"})$. However, this may be suboptimal since many other words can be implicitly correlated with gender. 
In this notebook, we will compute the PMI difference as $\max_{(g_F, g_M) \in (G_F, G_M)} |\delta(w, g_F, g_M)|$, where $(G_F, G_M)$ is the list of paired group words (e.g., as defined in `GROUP_TERMS`). The intuition is that we will represent the gender polarity of a word with the strongest existing correlation.

In [None]:
import math

# One row per word and one column per (female, male) pair of GROUP_TERMS.
# A cell is None when either PMI of the pair is undefined (None or NaN) for
# that word.
vocab_words = TERM_COUNTS_DF["word"].values.tolist()
results = defaultdict(list)
results["word"] = vocab_words

for word_idx in range(len(vocab_words)):
    for fterm, mterm in GROUP_TERMS:
        pmi_f = PMI_FEMALE[f"pmi_{fterm}"][word_idx]
        pmi_m = PMI_MALE[f"pmi_{mterm}"][word_idx]

        is_undefined = pmi_f is None or pmi_m is None or math.isnan(pmi_f) or math.isnan(pmi_m)
        results[f"pmi_{fterm}_{mterm}"].append(None if is_undefined else pmi_f - pmi_m)


results = pd.DataFrame(results)
results.info()

#### Marginal distribution of $\delta(w)$

In [None]:
marginals = pd.DataFrame({
    "pmi_value": PMI_MALE["pmi_he"] + PMI_FEMALE["pmi_she"],
    "gender word": ["he"] * len(PMI_MALE["pmi_he"]) + ["she"] * len(PMI_FEMALE["pmi_she"])
})
marginals.head()

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(COL_WIDTH, COL_WIDTH))
sns.histplot(marginals, x="pmi_value", hue="gender word", ax=ax)

#### Joint distribution

In [None]:
    # NOTE(review): orphaned fragment -- this is a single dict entry with no
    # enclosing expression, so the cell cannot run as it stands; the rest of
    # the original cell is missing. Also note that `if f and m` treats a PMI
    # of exactly 0.0 the same way as a missing (None) value.
    "pmi_diff": [f - m if f and m else None for f, m in zip(PMI_FEMALE["pmi_she"], PMI_MALE["pmi_he"])]


### Correlation matrix: Analysis of the correlation between different gendered word pairs

In [None]:
# Index by word and relabel every column from "pmi_<f>_<m>" to "(<f>, <m>)"
gendered_word_pairs = results.set_index("word").copy()
gendered_word_pairs.columns = ["(" + c.replace("pmi_","").replace("_", ", ") + ")" for c in gendered_word_pairs.columns]

# Drop the columns (word pairs) that have no valid pmi diff at all, then
# compute the pairwise Kendall correlation between the remaining columns
subset_cols = sorted([c for c in gendered_word_pairs if gendered_word_pairs[c].isna().sum() < len(gendered_word_pairs)])
gendered_word_pairs = gendered_word_pairs[subset_cols].corr("kendall")


# NOTE: this changes the global matplotlib font size for every later figure too
matplotlib.rc('font', family='serif', size=8)

fig, ax = plt.subplots(1, 1, figsize=(FULL_WIDTH, FULL_WIDTH))
# The mask hides the upper triangle and the diagonal of the symmetric matrix
sns.heatmap(gendered_word_pairs, 
            mask = np.triu(np.ones(gendered_word_pairs.shape)),
            vmin=-1, vmax=1, center=0, cbar_kws={"shrink": 0.7},
            cmap="seismic", square=True, linewidths=0.05, ax=ax,
           )

# ax.set_xticks([])
# ax.set_xticklabels([])
# ax.tick_params(axis='x', labelrotation=33)
adjust(fig)
save_fig(fig, "heatmap__alternative_pmi_definitions", dpi=150)

In [None]:
counts_well_defined = len(results) - results.isna().sum(axis=0).copy()
counts_well_defined = pd.DataFrame(counts_well_defined, columns=["Counts"])
counts_well_defined["fraction"] = counts_well_defined["Counts"] / len(results)
print(counts_well_defined.sort_index().to_latex(float_format='{:0.2%}'.format))

### Drop uncommon words

In [None]:
# Mark the most common words
ORIG_PMI_DF["is_common"] = ORIG_PMI_DF["word"].isin(TERM_COUNTS_DF_ALPHA_UQ["word"].values)
ORIG_PMI_DF["is_common"].value_counts()

In [None]:
low_freq_terms_orig_pmi = ORIG_PMI_DF[ORIG_PMI_DF["is_common"] == False].sort_values("word")
print("Total number of low freq words:", len(low_freq_terms_orig_pmi))

print("Examples of words dropped due to lower frequency:")
print("-", "\n- ".join(low_freq_terms_orig_pmi["word"].values))

In [None]:
plt.figure(figsize=(4, 3), dpi=200)
sns.histplot(ORIG_PMI_DF, x="pmi_diff", hue="is_common")

In [None]:
ORIG_PMI_DF_UQ = ORIG_PMI_DF[ORIG_PMI_DF["is_common"]].reset_index(drop=True)
print(len(ORIG_PMI_DF), "-->", len(ORIG_PMI_DF_UQ), "; delta =", len(ORIG_PMI_DF)-len(ORIG_PMI_DF_UQ))
ORIG_PMI_DF_UQ.head()

In [None]:
ORIG_PMI_DF_UQ.drop("word", axis=1).corr("kendall")

In [None]:
ORIG_PMI_DF_UQ_LANG[mask].sort_values("word").head()

In [None]:
ORIG_PMI_DF_UQ_ENG = ORIG_PMI_DF_UQ_LANG[mask]
ORIG_PMI_DF_UQ_ENG

In [None]:
sns.jointplot(ORIG_PMI_DF_UQ_LANG, x="pred_conf", y="wordnet_counts", hue="is_english", s=5)

In [None]:
sns.jointplot(ORIG_PMI_DF_UQ_LANG, x="pred_conf", y="wordnet_counts", hue="is_english", s=5)

In [None]:
sns.jointplot(ORIG_PMI_DF_UQ_ENG, x="pred_conf", y="pmi_diff", hue="is_english", s=5)

In [None]:
# `sns.jointplot` builds its own figure, so no `plt.figure(...)` call precedes
# it (that would only leave an extra, empty figure behind).
joint_grid = sns.jointplot(ORIG_PMI_DF_UQ_ENG, x="wordnet_counts", y="pmi_diff", s=5)
# Label the joint axes through the grid object. The raw string keeps the
# LaTeX backslash literal (`\d` is not a valid escape sequence).
joint_grid.set_axis_labels("Number of WordNet definitions", r"PMI Difference, $\delta(w)$")
plt.show()

In [None]:
TERM_COUNTS_DF.shape

In [None]:
ORIG_PMI_DF_UQ_ENG.shape[0], TERM_COUNTS_DF.shape[0], round(ORIG_PMI_DF_UQ_ENG.shape[0] / TERM_COUNTS_DF.shape[0], 4)

### Obtain the words

In the variable ORIG_PMI_DF_UQ_ENG, we have the selected English words.
We have yet to reduce the set of words to the ones having the same root.
Since we're using stratified sampling to select one word from each bin, we do not need to care too much about this. If two words with the same root are selected, it is likely that it is because they were sampled from different bins. In which case, it may suggest that there is a significant difference.

In [None]:
num_bins = 20
# define PMI range: the largest absolute PMI difference
pmi_diff_max = ORIG_PMI_DF_UQ_ENG["pmi_diff"].abs().max()
print(pmi_diff_max)

# Round the range up to the next integer so that the edges are symmetric
pmi_diff_max = np.ceil(pmi_diff_max)
# NOTE: `num_bins` equally spaced edges delimit `num_bins - 1` intervals; with
# an even number of symmetric edges, the central interval straddles 0.
bins = np.linspace(-pmi_diff_max, pmi_diff_max, num=num_bins)

pmi_diff_bins = pd.cut(ORIG_PMI_DF_UQ_ENG["pmi_diff"], bins)
ORIG_PMI_DF_UQ_ENG.loc[:,"pmi_diff_bins"] = pmi_diff_bins
ORIG_PMI_DF_UQ_ENG["pmi_diff_bins"].value_counts()

In [None]:
intervals = sorted(ORIG_PMI_DF_UQ_ENG["pmi_diff_bins"].unique())
interval_idx_middle = [ix for ix, interval in enumerate(intervals) if 0 in interval][0]
intervals[interval_idx_middle]

In [None]:
sampling_bin = ORIG_PMI_DF_UQ_ENG[ORIG_PMI_DF_UQ_ENG["pmi_diff_bins"] == intervals[interval_idx_middle]]
sampling_bin = sampling_bin.sort_values("freq", ascending=False)
sampling_bin.head(30)

In [None]:
sampling_bin.tail(30)

In [None]:
sampling_bin.sort_values("pmi_diff").head(30)

In [None]:
sampling_bin.sort_values("pmi_diff").tail(30)

In [None]:
# Label every word by the direction of its PMI difference:
#   "neutral" when pmi_diff lies in [-0.263, 0.263] (takes precedence),
#   "female"  when pmi_diff > 0,
#   "male"    otherwise.
# NOTE(review): 0.263 is hard-coded; presumably it is the half-width of the
# central PMI bin -- confirm and derive it from the bins if so.
pmi_diff_values = ORIG_PMI_DF_UQ_ENG["pmi_diff"]
female_mask = pmi_diff_values > 0
neutral_mask = pmi_diff_values.between(-0.263, 0.263)

ORIG_PMI_DF_UQ_ENG["skews"] = np.select(
    [neutral_mask, female_mask],
    ["neutral", "female"],
    default="male",
)

# Share of words per label
ORIG_PMI_DF_UQ_ENG["skews"].value_counts() / len(ORIG_PMI_DF_UQ_ENG)

In [None]:
ORIG_PMI_DF_UQ_ENG.sample()

In [None]:
def get_wordnet_info(df: pd.DataFrame):
    """Collect the WordNet definitions of the word in every row of ``df``.

    Args:
        df: frame with columns ``word`` and ``wordnet_counts``.

    Returns:
        A list with one dict per row (in row order) mapping synset name to
        its definition; the dict is empty for rows whose ``wordnet_counts``
        is not greater than 0 (WordNet is not queried for those rows).
    """
    def _definitions(word):
        return {synset.name(): synset.definition() for synset in wordnet.synsets(word)}

    return [
        _definitions(row["word"]) if row["wordnet_counts"] > 0 else {}
        for _, row in df.iterrows()
    ]


wordnet_sample = get_wordnet_info(ORIG_PMI_DF_UQ_ENG)
ORIG_PMI_DF_UQ_ENG["wordnet_definitions"] = wordnet_sample

In [None]:
ORIG_PMI_DF_UQ_ENG.to_csv("../results__pool_of_words_by_pmi.csv", index=None)

In [None]:
ORIG_PMI_DF_UQ_ENG_NEUTRAL = ORIG_PMI_DF_UQ_ENG[ORIG_PMI_DF_UQ_ENG["skews"] == "neutral"].copy()
ORIG_PMI_DF_UQ_ENG_NEUTRAL.to_csv("../results__neutral__pool_of_words_by_pmi.csv", index=None)
len(ORIG_PMI_DF_UQ_ENG_NEUTRAL)

In [None]:
import os
# NOTE: BASE_DIR and SAMPLES are assigned here but not used in this cell.
BASE_DIR = ""
SAMPLES = []
# Draw five seeded samples of 100 neutral words (without replacement) and
# write each one to disk.
for i, seed in enumerate((9123, 19223, 8172361, 91283, 72613)):
    sample = ORIG_PMI_DF_UQ_ENG_NEUTRAL.sample(n=100, replace=False, random_state=seed)
    
    # NOTE(review): `num` only selects the output directory; the very same
    # 100-word sample is written to results-words5, -words10 and -words20.
    # Confirm whether the sample was meant to depend on `num`.
    for num in (5, 10, 20):
        os.makedirs(f"../results-words{num}/words{i+1}", exist_ok=True)
        # Full rows of the sample (the DataFrame index is written as well)
        sample.to_csv(f"../results-words{num}/words{i+1}/selected_words__{seed}.csv")
        words = sorted(sample["word"].unique())

        # Plain list of the sampled words, sorted, one per line
        with open(f"../results-words{num}/words{i+1}/words.txt", "w") as f:
            f.write("\n".join(words))

In [None]:
# Stratified sample: 0.5% of the words of every PMI-difference bin.
# A fixed `random_state` makes the draw reproducible across kernel restarts
# (this sample is written to disk further below). `observed=True` restricts
# the grouping to the bins that actually contain words.
sample = ORIG_PMI_DF_UQ_ENG.groupby('pmi_diff_bins', group_keys=False, observed=True).apply(
    lambda x: x.sample(frac=0.005, random_state=9123)
)
sample["skews"].value_counts() / len(sample)

In [None]:
sample["pmi_diff_bins"].value_counts() / len(sample)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 6))
sns.histplot(ORIG_PMI_DF_UQ_ENG, x="pmi_diff", binwidth=0.1, ax=ax, label=f"Original: {len(ORIG_PMI_DF_UQ_ENG)}", stat="probability")
sns.histplot(sample, x="pmi_diff", binwidth=0.1, ax=ax, label=f"Sample: {len(sample)}", stat="probability")
plt.legend()

In [None]:
# Stratified sample: up to 10 words from every PMI-difference bin (all of
# them when a bin holds fewer than 10).
# A fixed `random_state` makes the draw reproducible across kernel restarts.
# `observed=True` restricts the grouping to the bins that actually contain words.
sample2 = ORIG_PMI_DF_UQ_ENG.groupby('pmi_diff_bins', group_keys=False, observed=True).apply(
    lambda x: x.sample(min(len(x), 10), replace=False, random_state=9123)
)
sample2["skews"].value_counts() / len(sample2)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 6))
sns.histplot(ORIG_PMI_DF_UQ_ENG, x="pmi_diff", binwidth=0.1, ax=ax, label=f"Original: {len(ORIG_PMI_DF_UQ_ENG)}", stat="probability")
sns.histplot(sample2, x="pmi_diff", binwidth=0.1, ax=ax, label=f"Sample: {len(sample2)}", stat="probability")
plt.legend()
sample2["pmi_diff_bins"].value_counts() / len(sample2)

#### add wordnet info

In [None]:
sorted(sample2.word)

### Persist

In [None]:
sample.to_csv("../results/selected_words.csv", index=None)

In [None]:
TERM_COUNTS_DICT["he"],TERM_COUNTS_DICT["his"], TERM_COUNTS_DICT["him"]

In [None]:
TERM_COUNTS_DICT["she"], TERM_COUNTS_DICT["her"]

In this file, we plan to select a set of words from the pretraining set in an automatic fashion. We'll try to make an intuitive choice by considering the following:
 
 $$\text{PMI}(w, \text{"she"}) - \text{PMI}(w, \text{"he"}) = \log \frac{P(\text{"she"}|w)}{P(\text{"he"}|w)} - \log \frac{P(\text{"she"})}{P(\text{"he"})}$$

Thus, we will deem words whose odds ratio is 2.5 times smaller or larger to be disproportionately skewed. We will not consider these words for our bias benchmark creation:
- Remove words whose $\frac{P(\text{"she"}|w)}{P(\text{"he"}|w)} \geq \tau \vee \frac{P(\text{"he"}|w)}{P(\text{"she"}|w)} \geq \tau$, where $\tau = 2.5$


## Check original words frequency


In [None]:
orig_df = pd.read_csv("../../experiments-tacl-june-2023/data/pmi_diffs_selected.csv")
# orig_df = orig_df[~orig_df["selected"].isna()]

orig_words_set = set(orig_df["word"].unique())
orig_df["is_common"] = orig_df["word"].isin(TERM_COUNTS_DF_ALPHA_UQ["word"].values)
orig_df["is_common"].value_counts()

In [None]:
# NOTE(review): `TERM_COUNTS_DF_ALPHA` is not defined in any cell visible
# above; on a fresh kernel this cell raises NameError.
sns.lineplot(x=np.arange(len(TERM_COUNTS_DF_ALPHA)), y=TERM_COUNTS_DF_ALPHA["counts"].values)

# Highlight (in red) the words that also appear in `orig_df`.
# NOTE(review): `idx` holds index LABELS but is then used as POSITIONS into
# `.values`; this is only correct if the frame has a default 0..n-1 index.
idx = np.array(TERM_COUNTS_DF_ALPHA[TERM_COUNTS_DF_ALPHA["word"].isin(orig_df["word"])].index)
sns.scatterplot(x=idx, y=TERM_COUNTS_DF_ALPHA["counts"].values[idx], color="red", s=15)
plt.xscale("log"); plt.xlabel("Term rank")
plt.yscale("log"); plt.ylabel("Term counts")

# NOTE: this rebinds the notebook-level `q` and `q_val`.
q = 0.2
q_val = TERM_COUNTS_DF_ALPHA["counts"].quantile(q)
plt.axhline(q_val, label=f"{q:.0%} quantile: {q_val}", ls="--", c="r")
plt.legend()
plt.show()

In [None]:
current_df = pd.read_csv("../../experiments-aug-2023/results/selected_words.csv")
current_df

In [None]:
# Same plot as for `orig_df`, this time highlighting (in black) the words of `current_df`.
# NOTE(review): `TERM_COUNTS_DF_ALPHA` is not defined in any cell visible
# above; on a fresh kernel this cell raises NameError.
sns.lineplot(x=np.arange(len(TERM_COUNTS_DF_ALPHA)), y=TERM_COUNTS_DF_ALPHA["counts"].values)

# NOTE(review): `idx` holds index LABELS but is then used as POSITIONS into
# `.values`; this is only correct if the frame has a default 0..n-1 index.
idx = np.array(TERM_COUNTS_DF_ALPHA[TERM_COUNTS_DF_ALPHA["word"].isin(current_df["word"])].index)
sns.scatterplot(x=idx, y=TERM_COUNTS_DF_ALPHA["counts"].values[idx], color="black", s=15)
plt.xscale("log"); plt.xlabel("Term rank")
plt.yscale("log"); plt.ylabel("Term counts")

# NOTE: this rebinds the notebook-level `q` and `q_val`.
q = 0.2
q_val = TERM_COUNTS_DF_ALPHA["counts"].quantile(q)
plt.axhline(q_val, label=f"{q:.0%} quantile: {q_val}", ls="--", c="r")
plt.legend()
plt.show()

In [None]:
ORIG_PMI_DF_UQ_ENG.sort_values("counts", ascending=False).head(30)

## Originally picked words

In [None]:
# (words-directory number, sampling seed) of the five exported word samples
selected_words_runs = ((1, 9123), (2, 19223), (3, 8172361), (4, 91283), (5, 72613))

# Read the five CSV files back, stack them and drop the fully duplicated rows
orig_words = pd.concat([
    pd.read_csv(f"../results-words5/words{run}/selected_words__{seed}.csv", index_col=0)
    for run, seed in selected_words_runs
]).drop_duplicates()
print(len(orig_words))

In [None]:
orig_words.to_csv("selected_words.csv")




In [None]:
q_val

In [None]:
orig_words[orig_words.word == "whatcha"]

In [None]:
TERM_COUNTS_DF[TERM_COUNTS_DF["wordnet_counts"] >= 1]

In [None]:
orig_words[orig_words.word.isin(["votary", "wale", "waylaid", "waylay", "ween", "spasmodic"])]