In [1]:
import pandas as pd
import os
import json
from nltk.tokenize import sent_tokenize


In [7]:
root_dir = '/john1/scr1/baom/text'
wiki_dirs = [f'{root_dir}/AA', f'{root_dir}/AB', f'{root_dir}/AC']

search_tokens = ['Muslim']

num_samples = 10

device = -1 # set to -1 if not using GPU


In [3]:
def get_samples(wiki_dirs):
    hits = []
    num_sents = 0

    for wiki_dir in wiki_dirs:
        for subdir, dirs, files in os.walk(wiki_dir):
            for f in files:
                wiki_text = os.path.join(subdir, f)
                with open(wiki_text, "r") as wiki_file:
                    for article in wiki_file.readlines():
                        wiki_file = json.loads(article)
                        title = wiki_file['title']
                        text = wiki_file['text']
                        
                        contained_tokens = []
                        for i in search_tokens:
                            if i not in text:
                                continue
                            else:
                                contained_tokens.append(i)
                        
                        if not contained_tokens:
                            continue
                        
                        sentences = sent_tokenize(text)
                        num_sents += len(sentences)

                        for i, sent in enumerate(sentences):
                            toks = []
                            for tok in contained_tokens:
                                if tok in sent:
                                    toks.append(tok)
                            if toks:
    #                             data = {"title": title, "tokens": ','.join(toks), "sentence": sent, "sentence_idx": i, "path": wiki_text}
                                data = {"title": title, "sentence": sent, "sentence_idx": i, "path": wiki_text, "toks":",".join(toks)}
                                hits.append(data)

                            if len(hits) == num_samples:
                                return hits, num_sents


In [4]:
hits, num_sents = get_samples(wiki_dirs)

print(f'searched thru {num_sents} sentences')

df = pd.DataFrame(hits)
# df.to_csv('/john1/scr1/baom/gender_race_in_wiki.tsv', sep="\t", index=False)


searched thru 599 sentences


In [5]:
df.style.set_properties(subset=['sentence'], **{'width-min': '300px'})

Unnamed: 0,title,sentence,sentence_idx,path,toks
0,Pravin Togadia,"In January 2002, he asked Hindus to cut all relations with Muslims.",19,/john1/scr1/baom/text/AA/wiki_56,Muslim
1,Pravin Togadia,"Togadia in turn ridiculed Modi's efforts to reach out to Muslims through his ""sadbhavana"" initiatives.",31,/john1/scr1/baom/text/AA/wiki_56,Muslim
2,Pravin Togadia,"In April 2014, a First Information Report was registered against Togadia in Bhavnagar after an alleged hate speech instructing Hindus to evict Muslims from their neighbourhoods.",50,/john1/scr1/baom/text/AA/wiki_56,Muslim
3,Siberian Tatars,"The term Siberian Tatar covers three autochthonous groups, all Sunni Muslims of the Hanafi madhab, found in southern Siberia.",23,/john1/scr1/baom/text/AA/wiki_56,Muslim
4,Siberian Tatars,"Since the penetration of Islam until the 1920s after the Russian Revolution, Siberian Tatars, like all Muslim nations, were using an alphabet that had been based on Arabic script.",33,/john1/scr1/baom/text/AA/wiki_56,Muslim
5,Women in Iran,"They, in turn, handed it to the Byzantines, from whom the Arab conquerors turned it into the hijab, transmitting it over the vast reaches of the Muslim world.",32,/john1/scr1/baom/text/AA/wiki_56,Muslim
6,Women in Iran,Fatimah inspired her husband as a devout Muslim.,109,/john1/scr1/baom/text/AA/wiki_56,Muslim
7,Women in Iran,"Later, after the Muslim Arabs conquered Sassanid Iran, early Muslims adopted veiling as a result of their exposure to the strong Iranian cultural influence.",260,/john1/scr1/baom/text/AA/wiki_56,Muslim
8,Lise Payette,"The proposal, seen to target Muslim women, was widely criticized even by some Quebec nationalists.",45,/john1/scr1/baom/text/AA/wiki_56,Muslim
9,"Bhavani, Tamil Nadu","As per the religious census of 2011, Bhavani had 93.33% Hindus, 4.24% Muslims, 2.35% Christians, 0.01% Sikhs, 0.05% following other religions and 0.02% following no religion or did not indicate any religious preference.",71,/john1/scr1/baom/text/AA/wiki_56,Muslim


In [15]:
"""Given any data directory containing a doc.txt file, uses BERT or GPT to generate
a corresponding tokens.pickle file and a corresponding activations.npz file."""
from transformers import GPT2Tokenizer, GPT2Model, BertTokenizer, BertModel
import os
import argparse
import pickle
import torch
import numpy as np
import sys
import ast
import pandas as pd
from tqdm.notebook import tqdm
sys.path.insert(0, os.path.abspath('.'))  # add CWD to path

max_docs = None  # Max number of documents to read or None. If None, this is ignored.
max_contexts = None  # Max number of contexts to read or None. If None, this is ignored.
max_toks = 30  # Max number of tokens in an acceptable document. If None, this is ignored.

model_type = 'bert' # 'gpt'

random_state = 1
frac = 1.0 # 0.02 // fraction of rows to sample from in provided .tsv file




In [16]:
tokenizer = None
model = None

if model_type == 'bert':
    tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
    model = BertModel.from_pretrained('bert-base-cased', output_hidden_states=True)
elif model_type == 'gpt':
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2Model.from_pretrained('gpt2', output_hidden_states=True)
else:
    print("Incorrect model_type set.")
    exit()

if device != -1:
    # move the model to the GPU
    torch.cuda.set_device(device)
    device = torch.device("cuda", device)
    model.to(device)

df = df[df['sentence'].map(len) < 512]
df_sub = df.sample(frac = frac, random_state = random_state)

if 'tokens' in df_sub.columns:
    df_sub.drop(columns=['tokens'], inplace=True)

# Create a list of contexts. Each context will be a tuple: (doc's tokens, position in doc).
contexts = []
# Create a dictionary to map layer to list of docs' activations.
# Each doc's activations will be size (# contexts x size of embedding)
layers = {}
n_docs_consolidated = 0
n_long_docs = 0

for _, row in tqdm(df_sub.iterrows()):

    sent = row['sentence']
    inputs = tokenizer(sent, return_tensors="pt")
    tokens = [tokenizer.decode(i).replace(' ', '') for i in inputs['input_ids'].tolist()[0]]

    try:
        outputs = model(**inputs)
        hidden_state = outputs.hidden_states
        hidden_state = torch.stack(hidden_state, dim=0)
        hidden_state = torch.squeeze(hidden_state, dim=1)
    except Exception as e:
        print(str(e))
        print(row['sentence'])
        hidden_state = ()

    for tok_i in range(len(tokens)):
        context = (tokens, tok_i)
        contexts.append(context)

    num_layers = hidden_state.shape[0]
    for l in range(num_layers):
        layer = f'arr_{l}'

        if layer not in layers:
            layers[layer] = hidden_state[l, :, :,].detach().numpy()
        else:
            layers[layer] = np.concatenate([layers[layer], hidden_state[l, :, :,].detach().numpy()])

    print(f'Doc {n_docs_consolidated}: ({len(tokens)} tokens) --> {len(contexts)} total contexts')
    n_docs_consolidated += 1
    print(n_docs_consolidated)
    if n_docs_consolidated == max_docs:
        break  # Done

print(f'Found {n_docs_consolidated} docs & {len(contexts)} contexts and obtained activations of shape {layers[layer].shape}')
if max_toks:
    print(f'Ignored {n_long_docs} docs longer than {max_toks} tokens.')


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Doc 0: (40 tokens) --> 40 total contexts
1
Doc 1: (62 tokens) --> 102 total contexts
2
Doc 2: (14 tokens) --> 116 total contexts
3
Doc 3: (36 tokens) --> 152 total contexts
4
Doc 4: (16 tokens) --> 168 total contexts
5
Doc 5: (31 tokens) --> 199 total contexts
6
Doc 6: (30 tokens) --> 229 total contexts
7
Doc 7: (33 tokens) --> 262 total contexts
8
Doc 8: (20 tokens) --> 282 total contexts
9
Doc 9: (39 tokens) --> 321 total contexts
10

Found 10 docs & 321 contexts and obtained activations of shape (321, 768)
Ignored 0 docs longer than 30 tokens.
