<h1><span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"></ul></div>

In [1]:
import argparse
import os
import collections
import random
import time
from operator import itemgetter

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import spacy
from datasets import load_dataset, list_metrics, load_metric

In [8]:


def is_ascii(s):
    """Return True when every character in ``s`` has a code point below 128.

    An empty ``s`` yields True, since there is no offending character.
    """
    for ch in s:
        if ord(ch) >= 128:
            return False
    return True

def filterText(iterator):
    """Return the texts from ``iterator`` that are long enough and pure ASCII.

    A text is kept when it has at least 100 characters and ``is_ascii``
    accepts it. The input order is preserved and a new list is returned.
    """
    return [
        passage
        for passage in iterator
        # The length test runs first, so short texts never reach is_ascii.
        if len(passage) >= 100 and is_ascii(passage)
    ]



class DataProcessor:
    """Run spaCy NER over texts and build entity-swapped sentence variants.

    Typical use is ``processEnts()`` followed by ``permuteEnts()``:

    * ``processEnts`` fills ``ner_texts`` with ``(sentence, entities)`` pairs,
      ``raw_texts`` with the source document of each pair (same index), and
      ``ents`` with every entity string seen, grouped by label.
    * ``permuteEnts`` replaces one eligible entity per sentence with a
      different entity of the same label, filling ``permuted`` and
      ``changed_ents`` and optionally writing three files per sentence.
    """

    def __init__(
        self, 
        text, 
        write_dir=None, 
        parallel=False
    ):
        """Store the inputs and load the spaCy pipeline.

        Args:
            text: Sized iterable of document strings to process.
            write_dir: Directory for the per-sentence output files, or None
                to keep results in memory only.
            parallel: Only consulted by ``run``.
        """
        self.text = text

        self.write_dir = write_dir
        self.parallel = parallel

        self.raw_texts = []
        self.ner_texts = []
        self.permuted = []
        self.changed_ents = []

        # label -> list of entity strings (duplicates kept, so sampling from
        # it is frequency-weighted).
        self.ents = collections.defaultdict(list)

        self.model = spacy.load(
            "en_core_web_sm", 
            exclude=['tagger', 'parser', 'attribute_ruler', 'lemmatizer']
        )
        self.model.add_pipe('sentencizer')

        # Only entities with these labels are candidates for swapping.
        self.keep_ents = ['PERSON', 'ORG', 'GPE']

    def run(self, func, args):
        """Yield results of ``func``, optionally through a process pool.

        With ``parallel`` set, ``func`` is mapped over each item of ``args``
        in a ``ProcessPoolExecutor`` and every result is yielded. Otherwise
        ``func(*args)`` is called once and its items are yielded.

        NOTE(review): the two branches interpret ``args`` differently (one
        call per item vs. one call with unpacked arguments); that asymmetry is
        inherited from the original code and should be confirmed.
        """
        if self.parallel:
            # Imported locally: the notebook never imports this name.
            from concurrent.futures import ProcessPoolExecutor

            with ProcessPoolExecutor() as executor:
                # Executor.map yields plain results, not futures.
                yield from executor.map(func, args)
        else:
            yield from func(*args)

    def permuteEnts(self):
        """Swap one eligible entity per NER sentence for a same-label entity.

        For each ``(sentence, entities)`` pair in ``ner_texts`` a random
        entity whose label is in ``keep_ents`` is replaced by a random,
        different entity string from ``self.ents[label]``. The new sentence is
        appended to ``permuted`` and ``(old, new)`` to ``changed_ents``.

        When ``write_dir`` is set, three files suffixed with the sentence
        index are written: ``permuted_entities``, ``original_entities`` and
        ``entity_swaps`` (``old|new``).

        Sentences with no eligible entity, or whose label has no alternative
        entity string, are skipped: nothing is appended or written for that
        index.
        """
        if self.write_dir and not os.path.exists(self.write_dir):
            print(f"Warning: {self.write_dir} does not exist. Creating...")
            os.makedirs(self.write_dir, exist_ok=True)

        for idx, (sent, ents) in enumerate(self.ner_texts):
            eligible = [ent for ent in ents if ent[3] in self.keep_ents]
            if not eligible:
                # keep_ents may have changed since processEnts ran.
                continue

            orig_text, start, end, ent_type = random.choice(eligible)

            pool = self.ents[ent_type]
            # Without at least one different candidate the rejection loop
            # below would never terminate. all() stops at the first
            # differing entry, so this check is cheap in the common case.
            if all(candidate == orig_text for candidate in pool):
                continue

            while True:
                replace_ent = random.choice(pool)
                if replace_ent != orig_text:
                    break

            new_sent = sent[:start] + replace_ent + sent[end:]

            if self.write_dir:
                outputs = (
                    (f'permuted_entities.{idx}', new_sent + "\n"),
                    (
                        f'original_entities.{idx}',
                        self.raw_texts[idx].strip('\n').strip(" ") + "\n",
                    ),
                    (f'entity_swaps.{idx}', f"{orig_text}|{replace_ent}\n"),
                )
                for filename, content in outputs:
                    # Context manager closes the handle even if write fails.
                    with open(os.path.join(self.write_dir, filename), 'w') as handle:
                        handle.write(content)

            self.permuted.append(new_sent)
            self.changed_ents.append((orig_text, replace_ent))

    def processEnts(self):
        """Run NER over ``self.text`` and collect the results in ``ner_texts``."""
        self.ner_texts.extend(self.runNER(self.text))

    def runNER(self, texts):
        """Yield one ``(sentence_text, entities)`` pair per qualifying document.

        A sentence qualifies when at least one of its entities has a label in
        ``keep_ents``. Each entity is recorded as
        ``(text, start, end, label)`` with offsets relative to the sentence
        start. All entities of a qualifying sentence (whatever their label)
        are also added to ``self.ents``. For a document with at least one
        qualifying sentence, the document text is appended to ``raw_texts``
        and one qualifying sentence is chosen at random.
        """
        for doc in self.model.pipe(texts):
            processed = []
            for sent in doc.sents:
                if not any(e.label_ in self.keep_ents for e in sent.ents):
                    continue
                ents = []
                for e in sent.ents:
                    ents.append(
                        (
                            e.text,
                            e.start_char - sent.start_char,
                            e.end_char - sent.start_char,
                            e.label_,
                        )
                    )
                    self.ents[e.label_].append(e.text)
                processed.append((sent.text, ents))
            if processed:
                self.raw_texts.append(doc.text)
                yield random.choice(processed)

    def __repr__(self):
        """Summarise how many raw texts, NER pairs, permutations and entities are held."""
        total_ents = sum(len(names) for names in self.ents.values())
        return (f"DataProcessor:<{len(self.text)} RAW>"
                f"<{len(self.ner_texts)} NER>"
                f"<{len(self.permuted)} PERM>"
                f"<{total_ents} ENTS>")


def generateDataset(
    writeDir, 
    process=True,
    sample=int(1e5),
    set='train', 
    pct='10',
    cache_dir="/Volumes/External_HD/Dev/datasets/wikitext",
    ):
    """Sample WikiText-103 passages and build a ``DataProcessor`` over them.

    Args:
        writeDir: Directory handed to ``DataProcessor`` as ``write_dir``.
        process: When true, run NER and entity permutation (restricted to
            ``PERSON`` entities) before returning.
        sample: Maximum number of passages to draw before filtering.
        set: Dataset split name (shadows the builtin; name kept for callers).
        pct: Percentage of the split to load, as used in the split slice.
        cache_dir: Hugging Face datasets cache directory.

    Returns:
        The ``DataProcessor`` instance (processed when ``process`` is true).
    """
    wikitext = load_dataset(
            'wikitext', 
            'wikitext-103-raw-v1', 
            cache_dir=cache_dir, 
            split=f'{set}[:{pct}%]'
        )
    # Materialise the column once; it is both measured and indexed below.
    texts = wikitext['text']

    random.seed(123)
    # NOTE(review): the last 100 rows (and row 0 when sampling) are never
    # selected; kept as in the original, intent not evident here.
    wiki_len = len(texts) - 100
    if wiki_len <= sample:
        passage_idxs = list(range(wiki_len))
    else:
        passage_idxs = random.sample(range(1, wiki_len), sample)

    # Plain indexing is used because itemgetter(*idxs) returns a bare string
    # for a single index and raises for none.
    res_list = [texts[i] for i in passage_idxs]
    sampleText = filterText(res_list)
    dp = DataProcessor(sampleText, write_dir=writeDir)

    if process:
        print("running processor")
        dp.keep_ents = ['PERSON']
        dp.processEnts()
        print(dp)   
        dp.permuteEnts()
        print(dp)

    return dp


In [3]:
class TorchDataset(torch.utils.data.Dataset):
    """Map-style dataset over the per-sample files of an entity-swap corpus.

    Each sample ``ID`` is read from three files in the split directory:
    ``original_entities.ID``, ``permuted_entities.ID`` and
    ``entity_swaps.ID`` (a single ``old|new`` line).
    """

    # Split name -> directory relative to ``data_loc``.
    _SPLIT_DIRS = {
        'train': "data",
        'valid': "data/valid",
        'test': "data/test",
    }

    def __init__(self, list_IDs, tokenizer, data_loc="..", dataset='train', max_length=200):
        """
        Args:
            list_IDs: Sequence of sample ids (file suffixes).
            tokenizer: Callable accepting ``(text, truncation=, max_length=,
                padding=)`` and returning a mapping with ``input_ids`` and
                ``attention_mask``.
            data_loc: Root directory containing ``data/``.
            dataset: One of ``'train'``, ``'valid'`` or ``'test'``.
            max_length: Padded/truncated length for sentences.
        """
        self.list_IDs = list_IDs
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Entities are padded/truncated to a much shorter fixed length.
        self.ent_length = 5
        self.dataset = dataset
        self.loc = data_loc

    def tokenize(self, textList, ent=False):
        """Tokenize each text to fixed length and return tensor pairs.

        Args:
            textList: Non-empty list of strings.
            ent: Use ``ent_length`` instead of ``max_length`` when true.

        Returns:
            A list of ``(input_ids, attention_mask)`` tensor pairs, or the
            single pair itself when ``textList`` has exactly one element.
        """
        length = self.ent_length if ent else self.max_length
        tokList = []
        for text in textList:
            tok = self.tokenizer(
                text,
                truncation=True,
                max_length=length, 
                padding="max_length"
            )
            tokList.append(
                (
                    torch.tensor(tok['input_ids']), 
                    torch.tensor(tok['attention_mask'])
                )
            )
        if len(tokList) > 1:
            return tokList
        return tokList[0]

    def __len__(self):
        """Number of sample ids."""
        return len(self.list_IDs)

    def _split_path(self):
        """Return the directory of the configured split.

        Raises:
            ValueError: If ``self.dataset`` is not a known split name.
        """
        try:
            subdir = self._SPLIT_DIRS[self.dataset]
        except KeyError:
            raise ValueError(
                f"Unknown dataset split {self.dataset!r}; "
                f"expected one of {sorted(self._SPLIT_DIRS)}"
            ) from None
        return f"{self.loc}/{subdir}"

    def __getitem__(self, index):
        """Load and tokenize one sample.

        Returns:
            ``(raw, perm, new_ent_tok, old_ent_tok)`` where each element is an
            ``(input_ids, attention_mask)`` tensor pair.
        """
        ## TODO: make this specific to a batch 

        ID = self.list_IDs[index]
        path = self._split_path()

        with open(f"{path}/original_entities.{ID}") as raw_file:
            raw_sample = raw_file.read()
        with open(f"{path}/permuted_entities.{ID}") as perm_file:
            permuted_sample = perm_file.read()
        with open(f"{path}/entity_swaps.{ID}") as ent_file:
            ents = ent_file.read().strip().split('|')
        # The swap file holds "old|new".
        old_ent = ents[0]
        new_ent = ents[-1]

        # A leading space is prepended to every text before tokenizing.
        raw, perm = self.tokenize([" "+raw_sample, " "+permuted_sample])
        new_ent_tok, old_ent_tok = self.tokenize([" "+new_ent, " "+old_ent], ent=True)

        return raw, perm, new_ent_tok, old_ent_tok


In [5]:
# -p makes the cell re-runnable: no error when the directory already exists.
!mkdir -p data_tests

In [13]:
# Build an unprocessed DataProcessor over the full train split
# (process=False: NER and permutation are not run here).
# Alternative kept for reference, with a larger sample of 1e6 passages:
# dp = generateDataset('data_tests', set='train', pct='100', process=False, sample=int(1e6))
dp = generateDataset('data_tests', set='train', pct='100', process=False, sample=int(1e5))

Reusing dataset wikitext (/Volumes/External_HD/Dev/datasets/wikitext/wikitext/wikitext-103-raw-v1/1.0.0/47c57a6745aa5ce8e16a5355aaa4039e3aa90d1adad87cef1ad4e0f29e74ac91)


In [16]:
# dp.keep_ents = ['PERSON']
# dp.processEnts()
# print(dp)   
# dp.permuteEnts()

In [17]:
from transformers import GPT2Tokenizer

In [35]:
# Load the pretrained 'gpt2' tokenizer using a local cache directory.
# NOTE(review): the cache path is an absolute, machine-specific location.
tokenizer = GPT2Tokenizer.from_pretrained(
        'gpt2', cache_dir="/Volumes/External_HD/Dev/Editable_NLP_Models/hf"
        )
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

In [36]:
import glob

# Upper bound on the number of samples to load (inf = no cap).
max_obs=float('inf')

# Sample files are suffixed with a numeric index, e.g. "permuted_entities.42".
writtenFiles = glob.glob("data/permuted*")
if not writtenFiles:
    raise FileNotFoundError("No files matching 'data/permuted*' were found")

# Highest index present on disk.
fileIndex = max(int(path.split(".")[-1]) for path in writtenFiles)

# Number of samples to expose: indices run 0..fileIndex inclusive, so the
# count is fileIndex + 1. int() keeps range() valid for a float max_obs.
limitIndex = int(min(max_obs, fileIndex + 1))
ds = TorchDataset(list(range(limitIndex)), tokenizer, data_loc='.', dataset='train')


In [57]:
# Iterate the dataset one sample at a time, in shuffled order.
dataloader = torch.utils.data.DataLoader(
        ds,
        batch_size=1,
        shuffle=True
    )


In [58]:
def stripPadding(tok):
    """Return the ids in ``tok[0]`` with every occurrence of 50256 removed.

    ``tok[0]`` is squeezed before filtering. 50256 is presumably the id of
    the token used for padding in this notebook — TODO confirm against the
    tokenizer.
    """
    token_ids = tok[0].squeeze()
    keep_mask = token_ids != 50256
    return token_ids[keep_mask]

def decode(tok):
    """Decode a token pair to text after removing padding ids.

    Passes ``tok`` through ``stripPadding`` and decodes the result with the
    notebook-level ``tokenizer``.
    """
    return tokenizer.decode(stripPadding(tok))

In [62]:
# Manual inspection loop: for each sample print the decoded original sentence,
# the permuted sentence, the old entity and the new entity, then wait for
# Enter before showing the next one.
# NOTE(review): input() blocks, so this cell does not finish unattended and
# has to be interrupted to stop.
for raw, perm, new_ent_tok, old_ent_tok in dataloader:
    for tok in [raw, perm, old_ent_tok, new_ent_tok]:
        print(decode(tok), '\n')
        
    val = input("\n\n")

 Tom Savini as Deputy Tolo : Gets his finger bitten off by the sickos, then gets eaten alive by them.
 

  Mulder as Deputy Tolo : Gets his finger bitten off by the sickos, then gets eaten alive by them.
 

 Tom Savini 

 Mulder 




 Warren Martyn and Adrian Wood, the authors of the book I Can 't Believe It's a Bigger and Better Updated Unofficial Simpsons Guide, disliked the episode, writing that it was " very dull " and that Dafoe was not used well. However, Dafoe is one of show runner Josh Weinstein's favorite guest stars. Ian Johnson argued Dafoe's casting was " rare " and " somewhat offbeat ".
 

 Phil Theobald argued Dafoe's casting was " rare " and " somewhat offbeat ".
 

 Ian Johnson 

 Phil Theobald 




 Danielle Beaubrun was the only Saint Lucian swimmer to participate in the Beijing Olympics. She was the youngest member of the Saint Lucian delegation, at age 18. The 2008 Summer Olympics served as Beaubrun's Olympic debut. Beaubrun did not initially qualify for Olympic sta




 Borat attends a United Pentecostal camp meeting, at which Republican U.S. Representative Chip Pickering and Mississippi Supreme Court Chief Justice James W. Smith, Jr. are present. He regains his faith, and forgives Azamat and Pamela. He accompanies church members on a bus to Los Angeles and disembarks to find Azamat dressed as Oliver Hardy ( though Borat thinks that he is dressed as Adolf Hitler ). The two reconcile and Azamat tells Borat where to find Pamela Anderson. Borat finally comes face @-@ to @-@ face with Anderson at a book signing at a Virgin Megastore. After showing Anderson his " traditional marriage sack ", Borat pursues her throughout the store in an attempt to abduct her until he is tackled and handcuffed by security guards. Borat visits Luenell and they return to Kazakhstan together.
 

  Borat attends a United Pentecostal camp meeting, at which Republican U.S. Representative Chip Pickering and Mississippi Supreme Court Chief Justice Steve Vladeck, Jr. are present.




 The next event reported by the Chronicle of the Kings of Alba is dated to 906. This records that :
 

  The next event reported by the Chronicle of the Kings of Lenny Kravitz is dated to 906.
 

 Alba 

 Lenny Kravitz 




 One of the people to whom Jenner posed his question was Noel Stanton, a man from Bedfordshire, England, who was serving in Sydney with the Royal Navy at the time. Stanton became preoccupied with the memory of this meeting for several months afterwards and, the next year, became a committed Christian. Stanton went on to found the Jesus Army in Northampton, England, in 1969. In 1945, Jenner approached Norrie Jeffs, who had just returned from participating in Operation Meridian at Palembang on Sumatra, and, having asked Jeffs his question, Jeffs responded that he was already a Christian. Jenner then invited Jeffs over to his house, where Jeffs met several other visitors, including the woman who would later become his wife. In 1952, another person Jenner accosted wi




 In each of their following two matches, at home against Hampshire and away to Worcestershire, Somerset batted first and then enforced the follow @-@ on after bowling their opponents out cheaply. In each their opponents managed to avoid defeat, and both matches resulted in draws. During the Worcestershire match, Langer's first innings 107 took him past Sir Donald Bradman's total of 28 @,@ 067 first @-@ class runs to become the highest @-@ scoring Australian batsman. Successive draws against Nottinghamshire, Warwickshire, Sussex and Hampshire meant that Somerset travelled to Durham requiring a victory to maintain any realistic hopes of claiming the County Championship title. No play was possible on the third and fourth days, and the match resulted in another draw, leaving Somerset with only a slim mathematical chance of the title. Another draw, against Lancashire, while Durham beat Nottinghamshire, meant that Durham clinched the title. Somerset drew with Worcesters 

 During the Worc

KeyboardInterrupt: Interrupted by user

In [64]:
# NOTE(review): bare string echoed to inspect its repr; it looks like a
# section-heading line from the corpus — confirm whether this cell is still needed.
" = = = = Boston Bruins ( 2007 @-@ 2015 ) = = = ="

' = = = = Boston Bruins ( 2007 @-@ 2015 ) = = = ='

In [202]:
# Show the kernel's working directory; relative paths used in this notebook
# resolve against it.
!pwd

/Users/spencerbraun/dev/repos/editable_nlp


In [194]:
class WikitextDataset(torch.utils.data.Dataset):
    """Map-style dataset of WikiText-103 passages, filtered by length and ASCII.

    Items are plain strings; batching into tensors is left to a collate
    function.
    """

    def __init__(
        self,          
        data_loc="..", 
        dataset='train', 
        pct=100, 
        min_length=100
    ):
        """Load a slice of the split and keep only the passages that pass the filter.

        Args:
            data_loc: Hugging Face datasets cache directory.
            dataset: Split name used in the split slice.
            pct: Percentage of the split to load.
            min_length: Minimum number of characters a passage must have.
        """
        self.min_length = min_length
        self.dataset = load_dataset(
            'wikitext', 
            'wikitext-103-raw-v1', 
            cache_dir=data_loc, 
            split=f'{dataset}[:{pct}%]'
        )
        # Pass the requested threshold through; otherwise filterText would
        # silently fall back to its own default of 100.
        self.filtered = self.filterText(self.dataset['text'], min_len=min_length)

    @staticmethod
    def filterText(iterator, min_len=100):
        """Return the texts with at least ``min_len`` characters, all ASCII.

        Order is preserved and a new list is returned.
        """
        return [
            text
            for text in iterator
            if len(text) >= min_len and all(ord(c) < 128 for c in text)
        ]

    def __len__(self):
        """Number of passages that passed the filter."""
        return len(self.filtered)

    def __getitem__(self, index):
        """Return the filtered passage at ``index`` as a string."""
        return self.filtered[index]

In [205]:
# Load the first 1% of the train split as a passage dataset.
# Note: this rebinds `ds`, which an earlier cell bound to a TorchDataset.
ds = WikitextDataset(
    data_loc="/Volumes/External_HD/Dev/datasets/wikitext", 
    dataset='train', 
    pct=1, 
    min_length=100
)

Reusing dataset wikitext (/Volumes/External_HD/Dev/datasets/wikitext/wikitext/wikitext-103-raw-v1/1.0.0/47c57a6745aa5ce8e16a5355aaa4039e3aa90d1adad87cef1ad4e0f29e74ac91)


In [245]:
# Configure padding on the shared tokenizer: pad on the left and reuse the
# end-of-sequence token as the pad token.
tokenizer.padding_side = "left" 
tokenizer.pad_token = tokenizer.eos_token
def pad_collate(batch):
    """Tokenize a batch of strings into id and attention-mask tensors.

    Uses the notebook-level ``tokenizer`` with truncation at 200 tokens and
    ``padding=True``, and returns ``(input_ids, attention_mask)`` as tensors.
    """
    encoded = tokenizer(
        batch,
        truncation=True,
        max_length=200,
        padding=True,
    )
    input_ids = torch.tensor(encoded['input_ids'])
    attention_mask = torch.tensor(encoded['attention_mask'])
    return input_ids, attention_mask

In [249]:
# Loader over the passage dataset: one passage per batch, in order, tokenized
# by pad_collate in a single worker process.
# Note: this rebinds `dataloader` from the earlier entity-swap loader.
dataloader = torch.utils.data.DataLoader(
        ds,
        batch_size=1,
        num_workers=1,
        pin_memory=True,
        shuffle=False,
        collate_fn=pad_collate
    )

In [250]:
# Print the shape of each batch of token ids, pausing for Enter between
# batches.
# NOTE(review): input() blocks, so this cell has to be interrupted to stop.
for ids, mask in dataloader:
    print(ids.shape)
    input()

torch.Size([1, 107])

torch.Size([1, 112])

torch.Size([1, 200])

torch.Size([1, 200])


KeyboardInterrupt: Interrupted by user

In [None]:
def wikiDataloader(
    tokenizer, 
    bs=10, 
    data_loc='..',
    dataset='train',
    shuffle=False,
    max_length=200,
    min_length=20
    ):
    """Build a DataLoader over filtered WikiText passages.

    Side effect: ``tokenizer`` is switched to left padding and its pad token
    is set to its end-of-sequence token.

    Args:
        tokenizer: Tokenizer callable accepting a list of strings with
            ``truncation``, ``max_length`` and ``padding`` keywords.
        bs: Batch size.
        data_loc: Root directory; the dataset cache is ``{data_loc}/hf``.
        dataset: Split name.
        shuffle: Whether the loader shuffles.
        max_length: Truncation length in tokens.
        min_length: Minimum passage length passed to ``WikitextDataset``.

    Returns:
        A ``torch.utils.data.DataLoader``. For ``bs > 1`` each batch is an
        ``(input_ids, attention_mask)`` tensor pair; for ``bs == 1`` the
        default collation is used, as in the original code.
    """
    tokenizer.padding_side = "left" 
    tokenizer.pad_token = tokenizer.eos_token

    def pad_collate(batch):
        """Tokenize a list of strings, padded to the longest item in the batch."""
        toks = tokenizer(
            batch,
            truncation=True,
            max_length=max_length,
            padding=True
        )
        return torch.tensor(toks['input_ids']), torch.tensor(toks['attention_mask'])

    ds = WikitextDataset(
        data_loc=f"{data_loc}/hf", 
        dataset=dataset,
        pct=100,
        min_length=min_length
        )
    # NOTE(review): with worker processes started via "spawn", a locally
    # defined collate_fn may not be picklable — confirm on the target platform.
    dataloader = torch.utils.data.DataLoader(
        ds,
        batch_size=bs,
        num_workers=2,
        pin_memory=True,
        shuffle=shuffle,
        collate_fn=pad_collate if bs > 1 else None
    )

    return dataloader