# WEAT in Python2

This notebook provides all functions required for the WEAT (Word Embedding Association Test) method. For default execution, import this Python script and call the function "call(model, X, Y, A, B, tag='none', permutation_num=500000)" from an external file.

The following data sets define the association concepts; here, only positive and negative attributes are used. Each set is named after the terminology of its source paper, but the underlying concepts are similar.

In [None]:
# Association (attribute) sets taken from the IAT literature.

# Greenwald et al. (1998) - Measuring Individual Differences In Implicit Cognition - The Implicit Association Test
gw_pos = ['caress', 'freedom', 'health', 'love', 'peace', 'cheer', 'friend', 'heaven', 'loyal', 'pleasure',
          'diamond', 'gentle', 'honest', 'lucky', 'rainbow', 'diploma', 'gift', 'honor', 'miracle', 'sunrise',
          'family', 'happy', 'laughter', 'paradise', 'vacation']

gw_neg = ('abuse crash filth murder sickness accident death grief poison stink assault disaster hatred pollute '
          'tragedy bomb divorce jail poverty ugly cancer evil kill rotten vomit agony prison').split()

# Nosek, Banaji, Greenwald (2002) - Math = Male, Me = Female, Therefore Math != Me
Pleasant_1 = 'assertive athletic strong compassion support sympathetic laughter champion paradise vacation'.split()
Unpleasant_1 = 'brutal destroy ruthless confusion insecure naive bad poor waste crude'.split()

Pleasant_2 = 'ambition cuddle excitement glory joy love paradise pleasure romantic miracle'.split()
Unpleasant_2 = 'agony death detest disaster humiliate jealousy punishment stress tragedy war'.split()

Pleasant_3 = 'affectionate cozy enjoyment friend hug laughter passion peace snuggle triumph'.split()
Unpleasant_3 = 'afraid crucify despise failure hatred irritate nightmare slap terrible violent'.split()

# Nosek et al. (2002) - Harvesting implicit group attitudes and beliefs from a demonstration web site
harvest_good = 'Joy Love Peace Wonderful Pleasure Friend Laughter Happy'.lower().split()
harvest_bad = 'Agony Terrible Horrible Nasty Evil War Awful Failure Death'.lower().split()

# Monteith & Pettit (2011) - Implicit and explicit stigmatizing attitudes and stereotypes about depression.
mp_good = 'positive pleasant enjoy glorious wonderful bliss'.split()
mp_bad = 'negative horrible agony terrible unpleasant despise'.split()

# Union of all sets above (duplicates removed): general positive / negative attribute words.
generalPos = set(Pleasant_1 + Pleasant_2 + Pleasant_3 + gw_pos + harvest_good + mp_good)
generalNeg = set(Unpleasant_1 + Unpleasant_2 + Unpleasant_3 + gw_neg + harvest_bad + mp_bad)

First, the imports.

In [None]:
import itertools
import logging
import math
import os

import matplotlib.pyplot as plt
import numpy
from gensim.models import KeyedVectors

from util import *
# logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)

Load the pretrained Google word2vec (W2V) model.
The full and the slim model can be downloaded at https://github.com/mmihaltz/word2vec-GoogleNews-vectors and https://github.com/eyaler/word2vec-slim, respectively. In most cases, the slim version is more than sufficient and much faster.

In [None]:
# Load the slim GoogleNews word2vec model. To use the full model instead, pass
# './modelsGoogle/GoogleNews-vectors-negative300.bin' to the same call.
# `KeyedVectors` is imported directly from gensim.models, so it is used without a `models.` prefix.
m = KeyedVectors.load_word2vec_format('./modelsGoogle/word2vec-slim/GoogleNews-vectors-negative300-SLIM.bin', binary=True)

In [None]:
def get_cosines(model, word, word_list):
    """
    Calculate the cosine similarity of a word to every word in a word list.

    The aggregate returned here is the arithmetic mean; it may be replaced by another statistic.
    Pairs for which the model raises a KeyError (e.g. an unknown word) are skipped and logged.
    :param model:       underlying model; must provide similarity(word1, word2)
    :param word:        target word
    :param word_list:   list of words to compare the target word with
    :return:            (list of cosine similarities, mean cosine similarity);
                        ([], 0) if no similarity could be computed
    """
    cosines = []
    for elem in word_list:
        try:
            cosines.append(model.similarity(word, elem))    # cosine similarity between the two words
        except KeyError:
            logging.info('no cosine for ' + word + ' and ' + elem + ' available')

    if not cosines:     # nothing to average -> avoid division by zero
        logging.error('No cosine values available')
        return [], 0

    mean_cosine = sum(cosines) / len(cosines)
    return cosines, mean_cosine


def s_word(model, w, A, B, out):
    """
    Calculate the association of a word w with the attribute sets A and B.

    s(w, A, B) = mean cosine(w, a) over A  -  mean cosine(w, b) over B.
    The value is positive when w is closer to A and negative when it is closer to B.
    :param model:       underlying model
    :param w:           target word
    :param A:           association set 1
    :param B:           association set 2
    :param out:         (boolean) only do prints if true
    :return:            s-value
    """
    _, mean_cos_wA = get_cosines(model, w, A)
    _, mean_cos_wB = get_cosines(model, w, B)

    s_word_val = mean_cos_wA - mean_cos_wB          # negative if w belongs to B, otherwise positive
    if out:
        if s_word_val > 0:
            assignment = 'A (' + format(s_word_val, '.4f') + ')'
        elif s_word_val < 0:
            assignment = 'B (' + format(s_word_val, '.4f') + ')'
        else:                                       # exactly 0 (or NaN): no assignment possible
            assignment = 'failed (0)'

        print(w + ': A = ' + format(mean_cos_wA, '.4f') + ' ; B = '
              + format(mean_cos_wB, '.4f') + ' | assignment: ' + assignment)
    return s_word_val


# not  used
def s(model, X, Y, A, B, out):
    """
    Calculate the differential association of the two target sets with the attributes.

    This is the same computation as "s_corrected" but with sums instead of means (no normalisation),
    so it is no longer used by default. Target words that are not in the model are skipped.
    :param model:       underlying model
    :param X:           target set 1
    :param Y:           target set 2
    :param A:           association set 1
    :param B:           association set 2
    :param out:         only do prints if true
    :return:            sum of s_word over X minus sum of s_word over Y;
                        0 if X or Y has no in-vocabulary word (same convention as s_corrected)
    """
    s_xAB_all = []
    s_yAB_all = []

    for x in X:
        if x not in model:
            logging.info(x + ' is not in vocabulary -> skip it')
        else:
            s_xAB_all.append(s_word(model, x, A, B, out))   # ideally positive
    if out:
        print('~~~~~~~~~~*~~~~~~~~~~*~~~~~~~~~~*~~~~~~~~~~*~~~~~~~~~~*~~~~~~~~~~')

    for y in Y:
        if y not in model:
            logging.info(y + ' is not in vocabulary -> skip it')
        else:
            s_yAB_all.append(s_word(model, y, A, B, out))   # ideally negative

    if not s_xAB_all or not s_yAB_all:
        # A one-sided sum is meaningless; return 0 like s_corrected does.
        logging.error('none of the words is in vocabulary')
        return 0

    # Ideally a large value: positive X sum minus negative Y sum.
    return sum(s_xAB_all) - sum(s_yAB_all)


def s_corrected(model, X, Y, A, B, out):
    """
    Calculate the differential association of the two target sets with the attributes.

    Corrected version of s(model, X, Y, A, B, out): s_word is averaged over the in-vocabulary
    words of each target set instead of summed. Target sets with different numbers of elements,
    or with words missing from the model, can therefore be compared.
    :param model:       underlying model
    :param X:           target set 1
    :param Y:           target set 2
    :param A:           association set 1
    :param B:           association set 2
    :param out:         only do prints if true
    :return:            mean s_word over X minus mean s_word over Y;
                        0 if X or Y has no in-vocabulary word
    """
    s_xAB_all = []
    s_yAB_all = []

    for x in X:
        if x not in model:
            logging.info(x + ' in vocabulary (s, x): ' + str(x in model))
        else:
            s_xAB_all.append(s_word(model, x, A, B, out))   # ideally positive
    if out:
        print('~~~~~~~~~~*~~~~~~~~~~*~~~~~~~~~~*~~~~~~~~~~*~~~~~~~~~~*~~~~~~~~~~')

    for y in Y:
        if y not in model:
            logging.info(y + ' in vocabulary (s, y): ' + str(y in model))
        else:
            s_yAB_all.append(s_word(model, y, A, B, out))   # ideally negative

    if not s_xAB_all or not s_yAB_all:                  # guard against division by zero
        logging.error('none of the words is in vocabulary')
        return 0

    # Average over the words actually scored. Dividing by len(X)/len(Y) would count
    # out-of-vocabulary words as zeros and pull the means towards 0.
    mean_sXAB = sum(s_xAB_all) / len(s_xAB_all)
    mean_sYAB = sum(s_yAB_all) / len(s_yAB_all)

    return mean_sXAB - mean_sYAB                        # ideally a large value


def effect_size(model, X, Y, A, B):
    """
    Calculate the effect size of the association:

        (mean s_word over X  -  mean s_word over Y) / sample standard deviation of s_word over X and Y

    Only target words contained in the model are used.
    :param model:       underlying model
    :param X:           target set 1
    :param Y:           target set 2
    :param A:           association set 1
    :param B:           association set 2
    :return:            effect size; 0 if X or Y has no in-vocabulary word,
                        or if all s-values are equal (zero standard deviation)
    """
    s_values_x = []
    s_values_y = []

    for x in X:
        if x not in model:
            logging.info(x + ' in vocabulary (effect_size, x): ' + str(x in model))
        else:
            s_values_x.append(s_word(model, x, A, B, False))
    for y in Y:
        if y not in model:
            logging.info(y + ' in vocabulary (effect_size, y): ' + str(y in model))
        else:
            s_values_y.append(s_word(model, y, A, B, False))

    if not s_values_x or not s_values_y:                # guard against division by zero
        logging.error('non of the words is in vocabulary')
        return 0

    mean_s_val_x = sum(s_values_x) / len(s_values_x)
    mean_s_val_y = sum(s_values_y) / len(s_values_y)

    s_values_all = s_values_x + s_values_y              # at least 2 values thanks to the guard above
    mean_s_val_all = sum(s_values_all) / len(s_values_all)
    # Sample variance (n - 1 denominator). Dividing the sum directly avoids `1 / (n - 1)`,
    # which is integer division (= 0) under Python 2.
    variance = sum((val - mean_s_val_all) ** 2 for val in s_values_all) / (len(s_values_all) - 1)
    std_dev = math.sqrt(variance)
    if std_dev == 0:
        logging.error('standard deviation of the s-values is 0 -> effect size undefined')
        return 0

    return (mean_s_val_x - mean_s_val_y) / std_dev


def _full_partitions(all_target_words, size_x):
    """
    Return every split of the pooled target words into [new X (size_x words), new Y (the rest)],
    except the original split.
    :param all_target_words:    list of pooled target words (X followed by Y)
    :param size_x:              number of words in the new X
    :return:                    list of [new_X, new_Y] pairs
    """
    partitions = []
    # itertools.combinations yields combinations in index order, so the first one is the
    # original X (the leading size_x words); it is skipped.
    # NOTE(review): exact permutation tests usually include the observed split -- confirm this is intended.
    for combination in itertools.islice(itertools.combinations(all_target_words, size_x), 1, None):
        rest = list(all_target_words)
        for elem in combination:
            rest.remove(elem)
        partitions.append([list(combination), rest])
    return partitions


def _random_partitions(all_target_words, size_x, iterations):
    """
    Lazily yield `iterations` random splits of the pooled target words into [new X (size_x words), new Y].
    Generated on demand so that a large number of samples is never held in memory at once.
    """
    for _ in range(iterations):
        shuffled = numpy.random.permutation(all_target_words).tolist()
        yield [shuffled[:size_x], shuffled[size_x:]]


def permutation_test(model, X, Y, A, B, tag, permutation_num):
    """
    Permutation test for the WEAT statistic s_corrected(X, Y, A, B).

    The words of X and Y are pooled and re-split into a new X (len(X) words) and a new Y (the rest).
    The statistic of every split is compared with the observed weat value.
    - At most 15 pooled words: every split except the original one is evaluated.
    - More words: permutation_num random splits are sampled, and the absolute value of each
      sampled statistic is compared with the observed weat value.
    A histogram of the sampled statistics is saved via plt_hist(values, tag, weat).
    :param model:           underlying model
    :param X:               target set 1
    :param Y:               target set 2
    :param A:               association set 1
    :param B:               association set 2
    :param tag:             (string) name for the saved plot file (no file extension)
    :param permutation_num: number of random permutations (only used for more than 15 target words)
    :return:                weat value, p value, number of permutations with higher weat value,
                            ~ with equal ~, ~ with lower weat value
                            (p value is nan if no permutation could be compared)
    """
    all_target_words = list(X) + list(Y)
    n = len(all_target_words)
    size_x = len(X)
    if (n % 2) == 1:
        logging.error('ungerade Anzahl von Target Words im Spiel!')     # odd number of target words

    full = n <= 15
    if full:
        logging.info('Do permutation test with full permutation.')
        partitions = _full_partitions(all_target_words, size_x)
        total = len(partitions)
    else:
        logging.info('Too many elements for full permutation. Do random sampling. Iterations: '
                     + str(permutation_num))
        partitions = _random_partitions(all_target_words, size_x, permutation_num)
        total = permutation_num

    weat = s_corrected(model, X, Y, A, B, False)
    plot_vals = []
    higher = equal = lower = 0

    if not full:
        logging.info('betrachte Permutations-P-Werte nur im Betrag')   # compare absolute values only
    logging.info('start loop')
    for i, (perm_x, perm_y) in enumerate(partitions):
        s_val = s_corrected(model, perm_x, perm_y, A, B, False)
        if not full:
            # Only consider samples from the positive half-space, which doubles their count.
            s_val = abs(s_val)
        plot_vals.append(s_val)
        # Progress output; `bar` is not defined in this chunk (presumably from util).
        bar(i, total, 50, "P value calculation: ")
        if s_val > weat:
            higher += 1
        elif weat > s_val:
            lower += 1
        elif weat == s_val:
            equal += 1
        else:                       # only reachable for NaN values
            logging.error('hier stimmt was nicht!')

    compared = higher + equal + lower
    if compared:
        # Share of permutations with the observed or a greater difference.
        p_value = 1 - (lower / float(compared))
    else:
        logging.error('no permutation could be compared with the weat value')
        p_value = float('nan')
    plt_hist(plot_vals, tag, weat)
    return weat, p_value, higher, equal, lower


def plt_hist(list, name, weat=None, bins=30):
    """
    Plot permutation statistics as a histogram to visualise the significance of a result,
    and save the plot to ./plots/plt_<name>.pdf. The sorted values are also printed.
    :param list:        list of weat values to plot (not modified)
    :param name:        (string) name for the saved plot file (no file extension)
    :param weat:        (number) actual weat value, drawn as a red vertical line if given
    :param bins:        number of histogram bins (passed to plt.hist)
    """
    values = list           # the parameter name shadows the builtin; kept for keyword callers
    plt.figure()            # fresh figure so repeated calls do not draw into the same axes
    plt.hist(values, bins)
    if weat is not None:    # 0 is a valid weat value and must still be marked
        plt.axvline(weat, color='r')
    print(sorted(values))   # sorted copy: do not reorder the caller's list

    plt.title('permutation test values')
    plt.xlabel('value')
    os.makedirs('./plots', exist_ok=True)
    plt.savefig('./plots/plt_' + name + '.pdf')
