# In [9]:
import os, re, nltk
from cached_property import cached_property

# English Punkt sentence tokenizer, loaded once when this module is executed.
# read_document() uses it to split each subsection's text into sentences.
sent_tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')


class Document:
    """ Bundle the sentences of one Wiki article with the file they came from.

    Behaves like a read-only sequence: len() and indexing are forwarded to
    the underlying `sentences` collection.
    """
    def __init__(self, sentences, filename):
        self.sentences = sentences
        self.filename = filename

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        # Plain delegation, so whatever index types `sentences` accepts work here.
        return self.sentences[idx]

    def __repr__(self):
        count = len(self.sentences)
        return 'Document containing %d sentences' % count


class Sentence:
    """ One sentence: its raw text, its cleaned tokens and its label.

    Tokens are produced by splitting the text on whitespace and passing each
    piece through clean_token(). len() and indexing operate on the tokens,
    while repr() shows the raw text wrapped in double quotes.
    """
    def __init__(self, text, label):
        self.text = text
        cleaned = []
        for piece in text.split():
            cleaned.append(clean_token(piece))
        self.tokens = cleaned
        self.label = label

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        return self.tokens[idx]

    def __repr__(self):
        return ''.join(('"', self.text, '"'))


class LazyVectors:
    """Load only those vectors from GloVE that are in the vocab.

    Embedding indices are laid out as two leading zero rows (commented in
    `weights` as padding and UNK; `unk_idx` is 1) followed by one row per
    vocab word, so vocab word number i lives at embedding index i + 2.

    NOTE(review): `Vectors` and `torch` are not imported in the visible part
    of this file -- presumably torchtext.vocab.Vectors and PyTorch, imported
    elsewhere; confirm.
    """

    # Embedding index returned by `stoi` for strings outside the vocab.
    unk_idx = 1

    def __init__(self, name='glove.840B.300d.txt'):
        """ Requires the glove vectors to be in a folder named .vector_cache

        In Bash/Terminal from directory this class is in:
        >> mkdir .vector_cache
        >> mv glove.840B.300d.txt .vector_cache/glove.840B.300d.txt

        Note that the vocab is built right away, which accesses `loader`
        and therefore loads the vectors during construction.
        """
        self.name = name
        self.set_vocab()

    @cached_property
    def loader(self):
        """Vectors object for `self.name`, created on first access and cached."""
        return Vectors(self.name)

    def get_vocab(self, filename='vocabulary.txt'):
        """Read the corpus vocab from `filename`.

        The file content is split on commas; no other normalisation is done.
        """
        with open(filename, 'r') as f:
            vocab = f.read().split(',')
        return vocab

    def set_vocab(self):
        """Set corpus vocab """
        # Intersect with model vocab.
        self.vocab = [v for v in self.get_vocab() if v in self.loader.stoi]

        # Map string -> intersected index.
        self._stoi = {s: i for i, s in enumerate(self.vocab)}

    def weights(self):
        """Build weights tensor for embedding layer.

        Returns two zero rows of width `self.loader.dim` followed by the
        loader's vector for every vocab word, in vocab order.
        """
        # Select vectors for vocab words.
        weights = torch.stack([
            self.loader.vectors[self.loader.stoi[s]]
            for s in self.vocab
        ])

        # Padding + UNK zeros rows.
        return torch.cat([
            torch.zeros((2, self.loader.dim)),
            weights,
        ])

    def stoi(self, s):
        """Map string -> embedding index.

        Returns the word's intersected index shifted by 2 (to skip the two
        leading zero rows), or `unk_idx` when `s` is not in the vocab.
        """
        idx = self._stoi.get(s)
        # Compare against None explicitly: index 0 is a valid vocab position
        # but is falsy, so a plain truthiness test would report it as unknown.
        if idx is None:
            return self.unk_idx
        return idx + 2


def crawl_directory(dirname):
    """ Walk a nested directory and yield the path of every file below it.

    Files whose name ends in '.DS_Store' are skipped. This is a generator:
    paths are produced lazily, each one being the walked directory path
    joined with the file name.
    """
    for path, _subdirs, files in os.walk(dirname):
        for name in files:
            if not name.endswith('.DS_Store'):
                yield os.path.join(path, name)

def read_document(filename, minlen=0):
    """ Read in a Wiki-727 file

    The first line of the file is skipped. Every later line starting with
    '========' closes the current subsection; all other lines are appended
    to the current subsection, each preceded by a single space. Each
    subsection is then split into sentences, and the last sentence of every
    subsection is labelled 1 while all others are labelled 0.

    filename: path of the file to read.
    minlen:   subsections whose accumulated text (leading spaces included)
              is not longer than this many characters are dropped.

    Returns a Document holding one Sentence per sentence, in file order.
    """
    subsections, current = [], []
    with open(filename) as f:
        for line in f.readlines()[1:]:
            if line.startswith('========'):
                subsections.append(''.join(current))
                current = []
            else:
                # Strip only the newline; slicing off the last character
                # would eat real text when the final line has no newline.
                current.append(' ' + line.rstrip('\n'))
        subsections.append(''.join(current))

    # Keep the sentences grouped per subsection for now: the labels depend
    # on where each subsection ends, which is lost once the list is flat.
    sections = [sent_tokenizer.tokenize(text.strip())
                for text in subsections if len(text) > minlen]
    # A whitespace-only subsection tokenizes to nothing; drop it so it
    # cannot contribute a label without a matching sentence.
    sections = [sents for sents in sections if sents]

    labels = doc_to_labels(sections)
    sentences = flatten(sections)

    return Document([Sentence(text, label)
                     for text, label in zip(sentences, labels)],
                    filename)
        
def doc_to_labels(document):
    """ Convert Wiki-727 file to labels
    (last sentence of a subsection is 1, otherwise 0)

    document: iterable of subsections, each a sized collection of sentences.

    Returns one flat list with a label per sentence, in order. Empty
    subsections contribute no labels.
    """
    labels = []
    for section in document:
        size = len(section)
        if size == 0:
            # Nothing to label; emitting a 1 here would shift every
            # following label by one position.
            continue
        labels.extend([0] * (size - 1))
        labels.append(1)
    return labels

def clean_token(token):
    """ Normalise a single whitespace-delimited token.

    A token made up only of digit characters becomes the placeholder '#NUM'.
    Any other token is lower-cased and stripped of every character that is
    not a-z, 0-9 or whitespace, so punctuation (apostrophes included) is
    deleted rather than split off: "Don't" -> 'dont'. Tokens that mix digits
    with other characters are not pure digits, so '1,000' becomes '1000'.
    A token consisting only of punctuation becomes ''.
    """
    if token.isdigit():
        return '#NUM'
    # A former follow-up substitution replacing apostrophes with a space was
    # removed: it could never match because this pattern already deletes them.
    return re.sub(r"[^a-z0-9\s]", '', token.lower())

def flatten(alist):
    """ Concatenate the items of every sub-iterable into a single list.

    Only one level of nesting is removed; the order of items is preserved.
    """
    merged = []
    for group in alist:
        merged.extend(group)
    return merged