In [28]:
import random
from functools import partial
from pathlib import Path
from typing import *

import numpy as np
import pandas as pd
import torch
import torch.optim as optim
from overrides import overrides

from allennlp.data import Instance
from allennlp.data.token_indexers import TokenIndexer
from allennlp.data.tokenizers import Token
from allennlp.nn import util as nn_util

In [29]:
class Config(dict):
    """Hyperparameter mapping whose entries are also readable as attributes.

    ``Config(lr=1e-3)["lr"]`` and ``Config(lr=1e-3).lr`` give the same value.
    Change values through :meth:`set` so the mapping and the attributes stay
    in sync (plain item or attribute assignment updates only one of them).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Mirror every entry as an instance attribute.
        for name in kwargs:
            setattr(self, name, kwargs[name])

    def set(self, key, val):
        """Store ``val`` under ``key`` both as a mapping entry and as an attribute."""
        self[key] = val
        setattr(self, key, val)
        
config = Config(
    testing=True,
    seed=1,
    batch_size=64,
    lr=3e-4,
    epochs=2,
    hidden_sz=64,
    max_seq_len=100, # necessary to limit memory usage
    max_vocab_size=100000,
)

# `seed` was previously declared but never applied. Seed Python's `random`, NumPy and
# PyTorch so weight initialisation and any random batching are reproducible across runs.
random.seed(config.seed)
np.random.seed(config.seed)
torch.manual_seed(config.seed)


In [30]:
# True when PyTorch sees a CUDA device; later cells use it to choose device 0 or -1 (CPU).
USE_GPU = torch.cuda.is_available()

In [31]:
from allennlp.common.checks import ConfigurationError

In [32]:
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.dataset_readers import DatasetReader

## Prepare dataset

In [10]:
# DataFrame columns used as targets; the model gets one output logit per entry (out_sz=len(label_cols)).
label_cols = ["science", "funny"]
# Alternative, larger label set:
# label_cols = ['science', 'funny', 'engineering', 'compsci',
#                  'machinelearning', 'datascience', 'math', 'statistics']

In [33]:
from allennlp.data.fields import TextField, MetadataField, ArrayField

class RedditDatasetReader(DatasetReader):
    """Reads a pickled DataFrame of Reddit posts into AllenNLP ``Instance``s.

    The DataFrame is expected to have a ``documents`` text column plus one column
    per name in the notebook-level ``label_cols``. Each row becomes an instance with
    ``tokens`` (TextField), ``label`` (float32 ArrayField of the label columns) and
    ``id`` (the row's DataFrame index).
    """

    def __init__(self, tokenizer: Callable[[str], List[str]]=lambda x: x.split(),
                 token_indexers: Dict[str, TokenIndexer] = None,
                 max_seq_len: Optional[int]=config.max_seq_len) -> None:
        """
        Args:
            tokenizer: maps raw text to a list of token strings.
            token_indexers: indexers for the TextField; defaults to a single-id indexer.
            max_seq_len: maximum number of tokens kept per document (None = no limit).
        """
        super().__init__(lazy=False)
        # Imported locally so the default does not depend on a later notebook cell having run.
        from allennlp.data.token_indexers import SingleIdTokenIndexer
        self.tokenizer = tokenizer
        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
        self.max_seq_len = max_seq_len

    @overrides
    def text_to_instance(self, tokens: List[Token],
                         labels: np.ndarray=None,
                         id: str=None,) -> Instance:
        """Build one instance; the ``label`` field is omitted when ``labels`` is None."""
        fields = {"tokens": TextField(tokens, self.token_indexers)}

        if labels is not None:
            # Rows from `iterrows` can be object-dtype; cast so the label tensor is float32.
            fields["label"] = ArrayField(array=np.asarray(labels, dtype=np.float32))

        fields["id"] = MetadataField(id)
        return Instance(fields)

    @overrides
    def _read(self, file_path: str) -> Iterator[Instance]:
        """Yield one instance per row of the pickled DataFrame at ``file_path``.

        NOTE: ``pd.read_pickle`` can execute arbitrary code - only read trusted files.
        """
        df = pd.read_pickle(file_path)
        for index, row in df.iterrows():
            tokens = [Token(x) for x in self.tokenizer(row.documents)]
            # Enforce the length cap here so it holds for any tokenizer, not just a truncating one.
            if self.max_seq_len is not None:
                tokens = tokens[:self.max_seq_len]
            yield self.text_to_instance(tokens, row[label_cols].values, index)
            
            

## Prepare token handlers

In [34]:
from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
from allennlp.data.token_indexers import SingleIdTokenIndexer

# the token indexer is responsible for mapping tokens to integers
token_indexer = SingleIdTokenIndexer()

# Build the spaCy splitter once: the previous version constructed a new
# SpacyWordSplitter for every document, repeating its setup on each call.
_word_splitter = SpacyWordSplitter(language='en_core_web_sm', pos_tags=False)

def tokenizer(x: str, max_len: Optional[int] = config.max_seq_len) -> List[str]:
    """Split ``x`` with spaCy and return at most ``max_len`` token strings (all if None)."""
    return [w.text for w in _word_splitter.split_words(x)[:max_len]]

In [35]:
# Reader that tokenizes with the spaCy `tokenizer` and indexes into the "tokens" namespace.
reader = RedditDatasetReader(tokenizer=tokenizer, token_indexers=dict(tokens=token_indexer))

In [36]:
# Read the training split (a pickled DataFrame) into a list of instances.
train_ds = reader.read('reddit_thread_label_small_train')
assert len(train_ds) > 0, "no instances were read from the training pickle"

53it [00:00, 105.27it/s]

TextField of length 33 with text: 
 		[Scientists, are, using, imaging, tests, to, show, for, the, first, time, that, fructose, can,
		trigger, brain, changes, that, may, lead, to, overeating, ., Fructose, is, a, sugar, that,
		saturates, the, American, diet, .]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 4 with text: 
 		[Planets, oldest, fossils, found]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 60 with text: 
 		[In, Brazil, 8, species, have, been, put, on, the, list, for, cloning, ., The, jaguar, being, one,
		of, them, ., There, are, already, 420, samples, of, wild, tissue, ., When, cloning, begins, ,, the,
		jaguars, ,, (, and, the, other, 7, species, ), ,, will, be, kept, in, captivity, *, ,, in, case,
		the, population, of, wild, jaguars, collapsed, .]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 42 with text: 
 		[", Individuals, who, believe, that, racial, groups, have, fixed

166it [00:00, 179.68it/s]

TextField of length 12 with text: 
 		[Atoms, at, negative, absolute, temperature, :, The, hottest, systems, in, the, world]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 16 with text: 
 		[Scalable, nanopatterned, surfaces, designed, by, MIT, researchers, could, make, for, more,
		efficient, power, generation, and, desalination]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 45 with text: 
 		[It, used, to, be, that, the, cause, of, death, could, only, be, determined, by, cutting, a, corpse,
		open, ., But, a, new, ,, virtual, procedure, developed, by, Swiss, researchers, is, providing, new,
		insights, into, dead, bodies, ., It, could, help, identify, previously, undiscovered, murders, .]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 16 with text: 
 		[El, Niño, is, stronger, this, century, ,, but, still, ca, n’t, be, linked, to, climate, change]
 		and TokenIndexers : {'tokens': 'SingleIdT

259it [00:00, 259.97it/s]

TextField of length 11 with text: 
 		[Giant, squid, captured, alive, on, film, for, the, first, time, .]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 31 with text: 
 		[It, 's, So, Hot, in, Australia, That, They, Added, New, Colors, to, the, Weather, Map, ,, ",
		because, their, country, is, ,, you, know, ,, kind, of, on, fire, ., "]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 34 with text: 
 		[NOAA, :, 2012, was, warmest, and, second, most, extreme, year, on, record, for, the, contiguous,
		U.S., (, Every, state, in, the, contiguous, U.S., had, an, above, -, average, annual, temperature,
		for, 2012, ., )]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 22 with text: 
 		[Outburst, from, Enormous, Black, Hole, 11-Billion, Years, Ago, Swept, Past, Earth, in, 2011,
		--"Brighter, than, All, the, Stars, in, Milky, Way, "]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField o

358it [00:00, 339.58it/s]

TextField of length 35 with text: 
 		[Biggest, Structure, in, Universe, -, Large, Quasar, Group, is, 4, Billion, Light, Years, Across, :,
		", This, new, ,, Huge, -, LQG, appears, to, be, the, largest, structure, currently, known, in, the,
		early, Universe, "]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 19 with text: 
 		[Largest, structure, in, the, universe, discovered, !, It, 's, so, big, ,, theory, says, it, should,
		n't, exist, .]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 7 with text: 
 		[NGC, 6872, :, Largest, Spiral, Galaxy, Known]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 50 with text: 
 		[Solar, variability, and, terrestrial, climate, :, A, new, report, issued, by, the, National,
		Research, Council, (, NRC, ), ,, ", The, Effects, of, Solar, Variability, on, Earth, 's, Climate, ,,
		", lays, out, some, of, the, surprisingly, complex, ways, that, solar, activity, can, 

467it [00:01, 418.70it/s]

TextField of length 15 with text: 
 		[How, 19-year, -, old, activist, Zack, Kopplin, is, making, life, hell, for, Louisiana, 's,
		creationists]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 12 with text: 
 		[Change, in, human, social, behavior, in, response, to, a, common, vaccine, .]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 9 with text: 
 		[Inside, NASA, 's, Deal, for, Inflatable, Space, Station, Room]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 14 with text: 
 		[Fecal, Transplants, demonstrate, 94, %, success, rate, for, Clostridium, infections, in,
		randomized, clinical, trial]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 8 with text: 
 		[Mathematicians, aim, to, take, publishers, out, of, publishing]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 30 with text: 
 		[Mathematical, breakthrough, sets, out, rules, 

579it [00:01, 459.24it/s]

TextField of length 23 with text: 
 		[Biologists, may, have, solved, Peto, 's, paradox, -, why, do, larger, animals, such, as, whales,
		or, elephants, get, fewer, cancers, than, humans, ?]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 15 with text: 
 		[New, Huntington, 's, disease, cell, model, derived, from, embryonic, stem, cells, [, FASEB, J, ]]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 9 with text: 
 		[Mars, may, have, been, inhabited, by, microorganisms, :, Study]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 15 with text: 
 		[Men, without, a, sense, of, smell, exhibit, a, strongly, reduced, number, of, sexual,
		relationships, .]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 35 with text: 
 		[Study, suggests, increased, diagnosis, rate, of, attention, -, deficit, /, hyperactivity, disorder,
		., White, male, children, who, live, in, high, -, in

690it [00:01, 459.52it/s]

TextField of length 12 with text: 
 		[Diet, ,, Parental, Behavior, ,, and, Preschool, Can, Boost, Children, ’s, IQ]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 12 with text: 
 		[Sealed, water, garden, thriving, for, 40, years, without, any, water, or, air]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 38 with text: 
 		[Greenland, ice, a, benchmark, for, warming, :, Greenland, was, about, eight, degrees, warmer,
		130,000, years, ago, than, it, is, today, ,, an, analysis, of, an, almost, three, -, kilometre, -,
		long, ice, core, in, Greenland, has, revealed, .]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 16 with text: 
 		[Can, studying, the, girl, who, does, n't, age, unlock, the, ", fountain, of, youth, ", ?]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 10 with text: 
 		[Forget, the, Flu, :, The, Norovirus, Is, Taking, Over, America]
 		and TokenInd

TextField of length 9 with text: 
 		[Heat, from, North, American, cities, causing, warmer, winters, .]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 22 with text: 
 		[University, of, New, Orleans, has, NOT, discovered, a, cure, for, Huntington, 's, disease, ..., (,
		teardown, of, bad, science, press, release, )]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 63 with text: 
 		[", Researchers, have, found, ,, for, instance, ,, that, when, a, parent, acts, affectionately,
		with, his, or, her, infant, —, through, micro, -, moments, of, love, like, making, eye, contact, ,,
		smiling, ,, hugging, ,, and, playing, —, oxytocin, levels, in, both, the, parent, and, the, child,
		rise, in, sync, ., Love, is, a, single, act, ,, performed, by, two, brains, ., "]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 52 with text: 
 		[Domestic, cats, :, ", one, of, the, single, greatest, human, -, linked, thr

784it [00:01, 279.22it/s]

TextField of length 10 with text: 
 		[Scientists, unveil, Staphyloccocus, aureus, superbug, 's, secret, to, antibiotic, resistance]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 9 with text: 
 		[Scientists, Uncover, a, Previously, Unknown, Mechanism, of, Memory, Formation]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 59 with text: 
 		[Two, studies, show, that, silibinin, ,, found, in, milk, thistle, ,, protects, against, UV, -,
		induced, skin, cancer, --, “, When, you, treat, human, skin, cells, with, silibinin, ,, nothing,
		happens, ., It, ’s, not, toxic, ., But, when, you, damage, these, cells, with, UVA, radiation, ,,
		treatment, with, silibinin, kills, the, cells, ,, ”, says, lead, researcher, .]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 14 with text: 
 		[Natural, gas, vehicles, may, help, reduce, oil, dependence, even, more, than, electric, cars, .]
 		and TokenIndexers : {'

897it [00:02, 373.42it/s]

TextField of length 26 with text: 
 		[Study, by, CDC, :, ", Relative, to, normal, weight, ..., overweight, [, 25-<30, BMI, ], was,
		associated, with, significantly, lower, all, -, cause, mortality, ., "]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 6 with text: 
 		[kermadec, trench, the, deep, water, womble]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 9 with text: 
 		[Condensation, ,, Not, Temperature, ,, May, Drive, Global, Winds]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 10 with text: 
 		[Scientists, find, new, target, for, treating, wide, spectrum, of, cancers]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 33 with text: 
 		[By, cloaking, nanoparticles, in, the, membranes, of, white, blood, cells, ,, scientists, may, have,
		found, a, way, to, prevent, the, body, from, recognizing, and, destroying, them, before, they,
		deliver, their, drug, pay

989it [00:02, 369.12it/s]

TextField of length 11 with text: 
 		[Did, viruses, evolve, from, an, extinct, fourth, domain, of, life, ?]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 45 with text: 
 		[Astronomers, have, found, that, six, percent, of, red, dwarf, stars, have, habitable, ,, Earth, -,
		sized, planets, ., Since, red, dwarfs, are, the, most, common, stars, in, our, galaxy, ,, the,
		closest, Earth, -, like, planet, could, be, just, 13, light, -, years, away, .]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 59 with text: 
 		[", Extinction, :, Beyond, dinosaurs, and, dodos, ", -, The, extinction, of, an, animal, species,
		always, brings, finality, ,, but, in, the, wake, of, the, demise, of, species, other, creatures,
		can, prosper, ., More, than, 99, %, of, the, species, which, once, roamed, our, planet, no, longer,
		exist, ,, yet, a, rich, range, of, plants, and, animals, survived, .]
 		and TokenIndexers : {'tokens': 'SingleIdTokenI

1121it [00:02, 476.02it/s]

TextField of length 23 with text: 
 		[Towering, chimney, -, like, sedimentary, rock, spires, known, as, hoodoos, may, provide, an,
		indication, of, an, area, 's, past, earthquake, activity, ,, study]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 8 with text: 
 		[Global, increasing, trends, in, annual, maximum, daily, precipitation]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 18 with text: 
 		[On, November, 13, ,, a, rare, type, of, solar, eclipse, will, occur, on, and, near, the, Atlantic,
		.]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 9 with text: 
 		[Remember, when, Christopher, Walken, looked, like, Scarlett, Johansson, ?]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 2 with text: 
 		[Relaxed, ducks]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 3 with text: 
 		[NBA, players, ...]
 		and TokenIndexers : {'tokens'

1257it [00:02, 561.46it/s]

TextField of length 21 with text: 
 		[Went, to, the, local, pub, with, girlfriend, and, she, ordered, a, fish, and, chips, ,, came, with,
		a, free, beer, .]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 16 with text: 
 		[Came, across, the, original, in, r, /, aww, and, felt, that, it, needed, some, context, .]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 16 with text: 
 		[Sprite, Zero, ,, you, 've, angered, my, vision, -, impaired, dad, for, the, last, time, !]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 5 with text: 
 		[the, artist, is, a, genius]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 5 with text: 
 		[2, reasons, why, winter, sucks]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 4 with text: 
 		[Un, -, American, haircuts]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 18 with text

1394it [00:03, 615.99it/s]

TextField of length 2 with text: 
 		[Every, flight]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 4 with text: 
 		[Fanatics, these, days, .]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 6 with text: 
 		[Take, action, ,, take, control, .]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 11 with text: 
 		[I, left, my, dog, unattended, with, the, 2, year, old, ...]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 23 with text: 
 		[A, family, member, 's, girlfriend, had, surgery, ,, he, stood, above, her, as, she, was, waking,
		up, and, said, this, [, FB, ]]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 5 with text: 
 		[The, most, Canadian, headline, ever]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 6 with text: 
 		[Cat, fight, (, without, cats, )]
 		and TokenIndexers : {'tokens': 'SingleIdToken

1530it [00:03, 628.19it/s]

TextField of length 13 with text: 
 		[My, english, -, speaking, friend, got, scared, by, reading, the, first, three, words]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 4 with text: 
 		[Hussein, 's, DJ, crisis]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 22 with text: 
 		[In, honor, of, MLK, Jr, ,, I, had, my, fourth, graders, write, their, own, ', I, have, a, dream, ',
		speeches, ....]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 6 with text: 
 		[Apartment, converted, into, a, garage, .]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 5 with text: 
 		[The, Magic, Of, Facial, Hair]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 5 with text: 
 		[With, what, ?, Attitude, ?]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 9 with text: 
 		[Sometimes, we, forget, how, beautiful, life, really, is

1595it [00:03, 609.99it/s]

TextField of length 3 with text: 
 		[Seems, ....., legit]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 7 with text: 
 		[Flynn, Rider, 's, too, clever, for, Disney]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 5 with text: 
 		[My, favorite, yearbook, quote, yet]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 11 with text: 
 		[Thank, you, reddit, for, preparing, me, for, my, test, today, .]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 4 with text: 
 		[..., fuck, it, !]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 9 with text: 
 		[The, Airline, of, Middle, Earth, is, doing, it, right]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 8 with text: 
 		[Holy, Shit, It, 's, Me, !, !, !]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 12 with text: 
 		[The, nutrition, 

1719it [00:03, 589.56it/s]

TextField of length 6 with text: 
 		[Beverage, shield, ..., activate, !, !]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 12 with text: 
 		[My, friend, found, this, on, the, floor, of, her, 8th, grade, classrom]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 19 with text: 
 		[Going, to, a, formal, affair, this, evening, ., Girlfriend, said, do, n't, fuck, this, up, .,
		Decisions, decisions, ...]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 14 with text: 
 		[Of, all, the, cars, to, put, a, lift, kit, on, ..., A, PT, cruiser]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 14 with text: 
 		[How, to, tell, your, overprotective, Dad, you, are, travelling, to, Hong, Kong, -, Imgur]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 3 with text: 
 		[I, need, one]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of l

1848it [00:03, 604.46it/s]

TextField of length 19 with text: 
 		[My, dog, just, got, out, of, surgery, ., I, think, he, might, still, be, high, from, the,
		anaesthetic, .]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 6 with text: 
 		[Monty, Python, 's, Horse, Action, Figure]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 5 with text: 
 		[And, suddenly, a, snow, monster]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 10 with text: 
 		[The, speed, of, this, car, is, too, damn, high, !]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 15 with text: 
 		[", I, swear, honey, ,, I, 'm, just, as, surprised, as, you, are, !, "]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 8 with text: 
 		[My, friend, 's, turtle, is, super, happy, =)]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 7 with text: 
 		[The, Shit, I, See, At, Walmart, 

2000it [00:04, 494.37it/s]

TextField of length 18 with text: 
 		[Leo, Dicaprio, and, Jonah, Hill, share, a, high, -, five, after, swimming, with, a, topless, woman,
		in, Miami]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 4 with text: 
 		[Disappointed, Generic, Brand, Cat]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 10 with text: 
 		[This, is, on, a, fire, truck, near, my, house, .]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 13 with text: 
 		[I, work, in, a, restaurant, ., This, is, our, Christopher, Walk, -, In]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 5 with text: 
 		[Time, to, walk, funny, .]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 7 with text: 
 		[News, sites, need, to, find, scarier, pedophiles]
 		and TokenIndexers : {'tokens': 'SingleIdTokenIndexer'}
TextField of length 9 with text: 
 		[Can, you, pass, me, the, ext, hard, d




In [15]:
# Number of training instances read.
len(train_ds)

2000

In [16]:
#vars(train_ds[0].fields["tokens"])

In [17]:
# Build the token vocabulary from the training instances, capped by `config.max_vocab_size`.
vocab = Vocabulary.from_instances(train_ds, max_vocab_size=config.max_vocab_size)

100%|██████████| 2000/2000 [00:00<00:00, 72693.47it/s]


## Prepare iterator
The iterator batches the data and prepares it for input into the model. We'll use a `BucketIterator`, which groups text sequences of similar lengths into the same batch to minimize padding.

In [23]:
from allennlp.data.iterators import BucketIterator

In [24]:
# Batch instances of similar "tokens" length together to minimise padding.
iterator = BucketIterator(
    sorting_keys=[("tokens", "num_tokens")],
    batch_size=config.batch_size,
)

We need to tell the iterator how to numericalize the text data. We do this by passing the vocabulary to the iterator. This step is easy to forget, so be careful!

In [25]:
# Attach the vocabulary so the iterator can turn tokens into integer ids when batching.
iterator.index_with(vocab)

## Read a sample batch

In [27]:
# Pull a single batch from the bucket iterator to inspect what the model will receive.
batch = next(iter(iterator(train_ds)))
batch

{'tokens': {'tokens': tensor([[  11,  625, 1349,  ...,   21,    0,    0],
          [1527, 2601,   10,  ...,   21,    0,    0],
          [  24, 1803,  378,  ..., 3283,    0,    0],
          ...,
          [ 730,  218,    7,  ...,    2,    0,    0],
          [5280,  852,  368,  ...,  134, 1055,    0],
          [  65,   12,   97,  ..., 4440,    2,    0]])},
 'label': tensor([[0., 1.],
         [0., 1.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [0., 1.],
         [1., 0.],
         [0., 1.],
         [1., 0.],
         [0., 1.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [0., 1.],
         [1., 0.],
         [1., 0.],
         [0., 1.],
         [0., 1.],
         [0., 1.],
         [1., 0.],
         [1., 0.],
         [0., 1.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [0., 1.],
         [1., 0.],
         [1., 0.],
  

In [22]:
# Token-id tensor of the sample batch; padded positions hold id 0.
batch["tokens"]["tokens"]

tensor([[5341, 2299, 1400,    5, 5342, 1512,    8,  214,    4,   81,  225,    0],
        [2978,    9, 2979,    9,  438, 1159,  456,  113,  117,    8, 1679,    0],
        [ 823, 6578,   61,   28,  827,    5,    7, 1090,    6,  603,   21,    0],
        [  87,   63,   29, 2119,  186,  463,  421, 2120,    9,   44,  841,    0],
        [2145,  398,   15,  208,   54,   51,    6,    3,  189,  160,    2,    0],
        [ 595, 5483, 2418,   24,  666,   24,   53,   29,  357,    5, 5484,  495],
        [5271, 5272,  193,    5,  511,  379,    6, 2101,  156,    8, 5273,    0],
        [  65, 2528,   31,    2,   65, 2528, 1545,    2,  251, 1628,   31,    2],
        [  22, 6329, 2026,    9, 6330,   25, 1878,  133,  100,    3,   67,    2],
        [4125,   15,   39,  365,   23, 1240, 4126,  565,   41, 1227,    2,    0],
        [5691, 5692,  743,    9, 1654, 1380,  422,   16,  324,  267,   32, 1469],
        [ 205,  115,    3, 3486, 3487,  721, 1264,  105,  935,    5, 1806,  587],
        [5759,  

In [19]:
# Presumably (batch_size, longest sequence in this batch) - the second dim varies per batch.
batch["tokens"]["tokens"].shape


torch.Size([64, 9])

## Prepare Model

In [20]:
import torch
import torch.nn as nn
import torch.optim as optim

In [21]:
from allennlp.modules.seq2vec_encoders import Seq2VecEncoder, PytorchSeq2VecWrapper
from allennlp.nn.util import get_text_field_mask
from allennlp.models import Model
from allennlp.modules.text_field_embedders import TextFieldEmbedder

class BaselineModel(Model):
    """Multi-label text classifier: embed tokens, encode to one vector, project to logits.

    Each of the ``out_sz`` outputs is an independent logit trained with
    ``BCEWithLogitsLoss`` (one sigmoid per label column).
    """

    def __init__(self, word_embeddings: TextFieldEmbedder,
                 encoder: Seq2VecEncoder,
                 out_sz: int=len(label_cols),
                 vocab: Optional[Vocabulary] = None):
        """
        Args:
            word_embeddings: embeds the ``tokens`` dict produced by the iterator.
            encoder: reduces the embedded, masked sequence to a single vector.
            out_sz: number of output logits; defaults to ``len(label_cols)`` at definition time.
            vocab: the model's vocabulary. When omitted, the notebook-level ``vocab``
                is used (the previous behaviour); passing it explicitly avoids
                depending on hidden global state.
        """
        super().__init__(vocab if vocab is not None else globals()["vocab"])
        self.word_embeddings = word_embeddings
        self.encoder = encoder
        self.projection = nn.Linear(self.encoder.get_output_dim(), out_sz)
        self.loss = nn.BCEWithLogitsLoss()

    def forward(self, tokens: Dict[str, torch.Tensor],
                id: Any = None,
                label: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Return ``{"class_logits": ...}``, plus ``"loss"`` when ``label`` is given.

        ``label`` is optional so unlabeled batches can be scored too; ``id`` is
        accepted because every batch carries it, but it is not used.
        """
        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)
        state = self.encoder(embeddings, mask)
        class_logits = self.projection(state)

        output = {"class_logits": class_logits}
        if label is not None:
            # BCEWithLogitsLoss needs float targets shaped like the logits.
            output["loss"] = self.loss(class_logits, label.float())
        return output

## Prepare embeddings

In [22]:
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder

# Size the embedding table from the vocabulary actually built instead of
# `config.max_vocab_size + 2`. The fixed size allocates unused rows whenever the corpus
# has fewer distinct tokens than the cap, and raises a TypeError if the cap is None.
token_embedding = Embedding(num_embeddings=vocab.get_vocab_size("tokens"),
                            embedding_dim=300, padding_index=0)
# the embedder maps the input tokens to the appropriate embedding matrix
word_embeddings: TextFieldEmbedder = BasicTextFieldEmbedder({"tokens": token_embedding})

In [23]:
from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper
# Bidirectional LSTM over the embeddings, wrapped as a Seq2VecEncoder (one vector per sequence).
encoder: Seq2VecEncoder = PytorchSeq2VecWrapper(
    nn.LSTM(
        input_size=word_embeddings.get_output_dim(),
        hidden_size=config.hidden_sz,
        batch_first=True,
        bidirectional=True,
    )
)

Notice how simple and modular the model initialization is: all of the complexity is delegated to the individual components.

In [24]:
# Assemble the classifier from the embedder and encoder built above.
model = BaselineModel(word_embeddings=word_embeddings, encoder=encoder)

In [25]:

# Move the model to the GPU when one is available. (The former `else: model` branch
# did nothing: an expression inside an if/else is neither displayed nor used.)
if USE_GPU:
    model = model.cuda()

## Basic sanity checks

In [26]:
# Put the sample batch on the same device as the model: GPU 0 if available, else -1 (CPU).
batch = nn_util.move_to_device(batch, 0 if USE_GPU else -1)

In [27]:
# Split the batch into model inputs and targets for the sanity checks below.
tokens = batch["tokens"]
labels = batch["label"]  # previously `batch`, which bound the whole dict, not the labels

In [28]:
# The token-id dict that is fed to the embedder.
tokens

{'tokens': tensor([[   91,  2875,    53,    58,    15,    39,  2898,  1115,     8],
         [ 1135,     2,   188, 21229,   416,    12,    60,  1225,     0],
         [  295,    38,    13,    96,    76,     4,   228,     8,     0],
         [   27,  7538,   313,   727, 69058,     2,  2677, 14582,     0],
         [24244, 43217,  3829, 24222,    12,   276,     7,  8052,     0],
         [ 1009,   127, 19055,   313,   727,  1373,   386,  7546,     0],
         [   23,   148,    24,   200,   785,     4, 11688,    82,  4730],
         [ 9757, 40138,    30,   723,   482,    12, 11502,   783,     0],
         [47124,  1797,  5317,    12, 21606,   883,   539,  4296,     0],
         [  795,   181,    17,   123,  6949,   128,    13,    82,     8],
         [  137,   515,    90, 14709, 11654,  2200,    93,    57,     0],
         [34352,   694,  2111,   197,     5,  1079,     3,   573,  8247],
         [ 1535, 24183,  1993,   345,    56,     3,   498,    32,     8],
         [ 7832, 37408,   65

In [29]:
# 1 for real tokens, 0 for padding (compare with the 0 ids in `tokens`).
mask = get_text_field_mask(tokens)
mask

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0],
 

In [30]:
# Step through BaselineModel.forward by hand to check intermediate results.
# Logits should have one column per entry of `label_cols`.
# NOTE(review): the saved output below has 8 columns, i.e. it comes from a run with a
# larger label set - re-run to refresh it.
embeddings = model.word_embeddings(tokens)
state = model.encoder(embeddings, mask)
class_logits = model.projection(state)
class_logits

tensor([[ 0.0751,  0.0196, -0.0549, -0.0456,  0.0791,  0.0380, -0.0497, -0.0488],
        [ 0.0764,  0.0202, -0.0552, -0.0445,  0.0785,  0.0383, -0.0515, -0.0478],
        [ 0.0749,  0.0197, -0.0531, -0.0472,  0.0794,  0.0389, -0.0503, -0.0482],
        [ 0.0743,  0.0196, -0.0538, -0.0466,  0.0768,  0.0382, -0.0497, -0.0488],
        [ 0.0748,  0.0196, -0.0548, -0.0445,  0.0776,  0.0386, -0.0497, -0.0461],
        [ 0.0753,  0.0180, -0.0540, -0.0455,  0.0793,  0.0408, -0.0514, -0.0491],
        [ 0.0743,  0.0199, -0.0560, -0.0453,  0.0783,  0.0391, -0.0495, -0.0477],
        [ 0.0751,  0.0200, -0.0548, -0.0454,  0.0789,  0.0416, -0.0512, -0.0491],
        [ 0.0742,  0.0215, -0.0541, -0.0447,  0.0768,  0.0397, -0.0499, -0.0493],
        [ 0.0755,  0.0185, -0.0542, -0.0460,  0.0788,  0.0377, -0.0501, -0.0485],
        [ 0.0752,  0.0196, -0.0542, -0.0441,  0.0788,  0.0407, -0.0497, -0.0469],
        [ 0.0750,  0.0179, -0.0548, -0.0453,  0.0808,  0.0416, -0.0522, -0.0485],
        [ 0.0756

In [31]:
# Full forward pass on the batch (which includes labels): returns the logits and the loss.
model(**batch)

{'class_logits': tensor([[ 0.0751,  0.0196, -0.0549, -0.0456,  0.0791,  0.0380, -0.0497, -0.0488],
         [ 0.0764,  0.0202, -0.0552, -0.0445,  0.0785,  0.0383, -0.0515, -0.0478],
         [ 0.0749,  0.0197, -0.0531, -0.0472,  0.0794,  0.0389, -0.0503, -0.0482],
         [ 0.0743,  0.0196, -0.0538, -0.0466,  0.0768,  0.0382, -0.0497, -0.0488],
         [ 0.0748,  0.0196, -0.0548, -0.0445,  0.0776,  0.0386, -0.0497, -0.0461],
         [ 0.0753,  0.0180, -0.0540, -0.0455,  0.0793,  0.0408, -0.0514, -0.0491],
         [ 0.0743,  0.0199, -0.0560, -0.0453,  0.0783,  0.0391, -0.0495, -0.0477],
         [ 0.0751,  0.0200, -0.0548, -0.0454,  0.0789,  0.0416, -0.0512, -0.0491],
         [ 0.0742,  0.0215, -0.0541, -0.0447,  0.0768,  0.0397, -0.0499, -0.0493],
         [ 0.0755,  0.0185, -0.0542, -0.0460,  0.0788,  0.0377, -0.0501, -0.0485],
         [ 0.0752,  0.0196, -0.0542, -0.0441,  0.0788,  0.0407, -0.0497, -0.0469],
         [ 0.0750,  0.0179, -0.0548, -0.0453,  0.0808,  0.0416, -0.0522

In [32]:
# With logits still close to 0 (see above), the BCE loss should be near ln(2) ~= 0.69.
loss = model(**batch)["loss"]

# Train Model

In [33]:
# Adam over every model parameter, with the learning rate from `config`.
optimizer = optim.Adam(model.parameters(), lr=config.lr)

In [34]:
from allennlp.training.trainer import Trainer

# Train on train_ds for config.epochs epochs, batching with `iterator`.
# Device 0 is used when a GPU was detected, otherwise -1 (CPU in AllenNLP).
trainer = Trainer(model=model,
                  optimizer=optimizer,
                  num_epochs=config.epochs,
                  iterator=iterator,
                  train_dataset=train_ds,
                  cuda_device=-1 if not USE_GPU else 0)

In [35]:
# Run the training loop; the return value (presumably a dict of final metrics -- confirm) is kept in `metrics`
metrics = trainer.train()

loss: 0.1857 ||: 100%|██████████| 4884/4884 [59:40<00:00,  1.09s/it]
loss: 0.1038 ||: 100%|██████████| 4884/4884 [1:32:59<00:00,  1.10s/it]


# Generating Predictions

In [36]:
from allennlp.data.iterators import DataIterator
from tqdm import tqdm
from scipy.special import expit # the sigmoid function

def tonp(tsr):
    """Convert a tensor to a NumPy array.

    The tensor is first detached from the autograd graph and copied to the CPU,
    so this also works for tensors that require grad or live on a GPU.
    """
    host_tensor = tsr.detach().cpu()
    return host_tensor.numpy()

class Predictor:
    """Run a model over a dataset and collect per-class probabilities.

    Batches are requested from ``iterator`` with ``shuffle=False``; row ``i`` of the
    result matches the ``i``-th instance of the dataset only if the iterator itself
    does not reorder instances (e.g. ``BasicIterator``).
    """

    def __init__(self, model: Model, iterator: DataIterator,
                 cuda_device: int=-1) -> None:
        """
        Parameters
        ----------
        model : model to run; its forward output must contain a "class_logits" entry.
        iterator : iterator used to batch the instances (already indexed with the vocab).
        cuda_device : device each batch is moved to before the forward pass (-1 = CPU).
        """
        self.model = model
        self.iterator = iterator
        self.cuda_device = cuda_device

    def _extract_data(self, batch) -> np.ndarray:
        """Forward one batch and return sigmoid(class_logits) as a NumPy array."""
        out_dict = self.model(**batch)
        # Independent sigmoid per class (multi-label style), not a softmax over classes.
        return expit(tonp(out_dict["class_logits"]))

    def predict(self, ds: Iterable[Instance]) -> np.ndarray:
        """Return the stacked per-class probabilities for every instance in ``ds``.

        The model is switched to eval mode for the duration of the call and its
        previous train/eval mode is restored afterwards, so predicting mid-training
        does not leave dropout-style layers disabled.

        Raises
        ------
        ValueError
            If ``ds`` yields no batches, so there is nothing to predict on.
        """
        pred_generator = self.iterator(ds, num_epochs=1, shuffle=False)
        pred_generator_tqdm = tqdm(pred_generator,
                                   total=self.iterator.get_num_batches(ds))
        was_training = self.model.training
        self.model.eval()
        preds = []
        try:
            with torch.no_grad():
                for batch in pred_generator_tqdm:
                    batch = nn_util.move_to_device(batch, self.cuda_device)
                    preds.append(self._extract_data(batch))
        finally:
            # Put the model back in whatever mode the caller had it in.
            self.model.train(was_training)
        if not preds:
            # np.concatenate would fail here with a cryptic message; be explicit instead.
            raise ValueError("Predictor.predict received an empty dataset; nothing to predict.")
        return np.concatenate(preds, axis=0)

In [37]:
from allennlp.data.iterators import BasicIterator
# Iterate over the dataset without changing its order, so prediction rows line up
# with dataset rows. Reuse the configured batch size rather than a second hard-coded 64.
seq_iterator = BasicIterator(batch_size=config.batch_size)
# Give the iterator the vocabulary so it can index tokens when building batches.
seq_iterator.index_with(vocab)

In [38]:
# Predict on the training set with the order-preserving iterator:
# one row of per-class (sigmoid) probabilities per training instance.
predictor = Predictor(model, seq_iterator, cuda_device=0 if USE_GPU else -1)
train_preds = predictor.predict(train_ds) 
# Test-set predictions are produced in a later cell, after the test file is read.

100%|██████████| 4884/4884 [01:34<00:00, 51.53it/s]


# A Final Note on Predictors

AllenNLP also provides predictors that take strings as input and output model predictions. They're handy if you want to create a simple demo or need to make predictions on entirely new data, but since we've already read our data as datasets and want to preserve their order, we didn't use them above.

To use one, we need a dataset reader that converts a raw string into the format the model understands — in this case, the same format as the training data. The model expects an `Instance`, not a string!

In [39]:
import numpy as np
import pandas as pd

In [40]:
# Path to the test data; the next cell passes it both to `reader.read` and to `pd.read_pickle`.
fname = 'reddit_thread_label_full_test'

In [41]:
# Score the model on the held-out test file as single-label accuracy: the predicted class
# is the column with the highest probability, and the true class is the label column
# (in `label_cols` order) with the highest value in the pickled DataFrame.
test_ds = reader.read(fname)
predictor = Predictor(model, seq_iterator, cuda_device=0 if USE_GPU else -1)
test_preds = predictor.predict(test_ds)
# NOTE: pd.read_pickle can execute arbitrary code -- only load files you trust.
df = pd.read_pickle(fname)
# Guard against stale kernel state: prediction columns must correspond to label_cols.
assert test_preds.shape[1] == len(label_cols), (
    f"model predicts {test_preds.shape[1]} classes but label_cols has {len(label_cols)}")
# Predictions and labelled rows must line up one-to-one for the comparison to mean anything.
assert len(df) == len(test_preds), f"{len(df)} labelled rows vs {len(test_preds)} predictions"
score = np.argmax(test_preds, axis=1)
# Vectorised argmax over the label columns (replaces a per-row apply and a scratch column).
df['label'] = df[label_cols].to_numpy().argmax(axis=1)
print(f'Score {np.mean(df["label"].to_numpy() == score)}')

8000it [00:08, 895.84it/s] 
100%|██████████| 125/125 [00:02<00:00, 49.53it/s]


Score 0.686125


In [48]:
# Peek at the last few true-label indices (positions into `label_cols`)
df['label'].tail()

7995    7
7996    7
7997    7
7998    7
7999    7
Name: label, dtype: int64

In [42]:
# for label in label_cols:
#     score = np.argmax(train_preds,axis=1)
#     argument = np.argmax(np.array(label_cols) == label)
#     score = np.sum(score == argument)/len(train_preds)
    
#     print(f'{label} has a score of {score}')

In [43]:
# from allennlp.data.fields import TextField, MetadataField, ArrayField

# class NewDatasetReader(DatasetReader):
#     def __init__(self, tokenizer: Callable[[str], List[str]]=lambda x: x.split(),
#                  token_indexers: Dict[str, TokenIndexer] = None,
#                  max_seq_len: Optional[int]=config.max_seq_len) -> None:
#         super().__init__(lazy=False)
#         self.tokenizer = tokenizer
#         self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
#         self.max_seq_len = max_seq_len

#     @overrides
#     def text_to_instance(self, tokens: List[Token],
#                          labels: np.ndarray=None,
#                         id: str=None,) -> Instance:
#         sentence_field = TextField(tokens, self.token_indexers)
#         fields = {"tokens": sentence_field}
        
#         fields["label"] = ArrayField(array=labels)
    
#         id_field = MetadataField(np.random.randint(1000))
#         fields["id"] = id_field
#         return Instance(fields)
    
#     @overrides
#     def _read(self, string: str) -> Iterator[Instance]:
#         # Imports data into the tokenizer
#         #for string in strings:
#         print(string)
#         yield self.text_to_instance([Token(x) for x in self.tokenizer(str(string))])
            

In [44]:
# test_reader = NewDatasetReader(
#     tokenizer=tokenizer,
#     token_indexers={"tokens": token_indexer})

In [45]:
# test_reader.read("this tutorial was great!")

In [46]:
# from allennlp.predictors import SentenceTaggerPredictor
# predictor = SentenceTaggerPredictor(model, dataset_reader=reader)
# tag_logits = predictor.predict("The dog ate the apple")['tag_logits']
# tag_ids = np.argmax(tag_logits, axis=-1)
# print([model.vocab.get_token_from_index(i, 'labels') for i in tag_ids])