# Tritonlytics Multilabel Classification - Standard CSS Themes

Experiments in building a language model (LM) and a multilabel classification model for survey comments captured in the Tritonlytics survey delivery system.

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import sys
sys.path.append('..')

In [None]:
from fastai.text import *   # Quick accesss to NLP functionality
from fastai.callbacks import *

import pdb
from tritonlytics import Metrics as metrics_util, DataGeneration as dg_util, PandasUtil as pd_util
from tritonlytics.evaluation import *
from tritonlytics.callbacks import RocAucEvaluation

import dill as pickle

import spacy
# English and Spanish spaCy pipelines, loaded in that order
spacy_en, spacy_es = spacy.load('en'), spacy.load('es')

# plotting config
import seaborn as sns
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (9,6)

# pandas display config: show long/wide frames and longer text cells (set as option/value pairs)
pd.set_option('display.max_rows', 500,
              'display.max_columns', 500,
              'display.width', 1000,
              'display.max_colwidth', 100)

In [None]:
print('fastai version: ' + str(__version__))

In [None]:
# Index of the CUDA device to use; change this to match the local machine.
GPU_ID = 1

if torch.cuda.is_available() and GPU_ID < torch.cuda.device_count():
    torch.cuda.set_device(GPU_ID)
    print(f'Using GPU #{torch.cuda.current_device()}')
elif torch.cuda.is_available():
    # the requested device does not exist (e.g. a single-GPU box): stay on the default device
    # instead of failing with an invalid device ordinal
    print(f'GPU #{GPU_ID} not available ({torch.cuda.device_count()} visible); '
          f'using GPU #{torch.cuda.current_device()}')
else:
    print('CUDA not available; running on CPU')

## Utility methods

In [None]:
def convert_to_snakecase(name):
    """Convert a CamelCase / PascalCase identifier to snake_case.

    e.g. 'QuestionAnsID' -> 'question_ans_id', 'AnswerText_NonEnglish' -> 'answer_text_non_english'.
    """
    # insert '_' before a Capitalized word that follows any character ('AnswerText' -> 'Answer_Text')
    spaced = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
    # insert '_' between a lowercase letter/digit and an uppercase letter ('AnsID' -> 'Ans_ID')
    spaced = re.sub('([a-z0-9])([A-Z])', r'\1_\2', spaced)
    # an existing '_' before a capital produces '__' above; collapse it back to '_'
    return spaced.lower().replace('__', '_')

In [None]:
# Contraction -> expansion map (regex -> replacement), based on the apostrophe lookup dict at
# https://drive.google.com/file/d/0B1yuv8YaUVlZZ1RzMFJmc1ZsQmM/view
# Keys are regexes intended for re_replace (case-insensitive, match case preserved); when merged
# into a replacement table they are applied one after another in insertion order.
appos_regex_repl = {
    r"\baren't\b" : "are not",
    r"\bcan't\b" : "cannot",
    r"\bcouldn't\b" : "could not",
    r"\bdidn't\b" : "did not",
    r"\bdoesn't\b" : "does not",
    r"\bdon't\b" : "do not",
    r"\bhadn't\b" : "had not",
    r"\bhasn't\b" : "has not",
    r"\bhaven't\b" : "have not",
    r"\bhe'd\b" : "he would",
    r"\bhe'll\b" : "he will",
    r"\bhe's\b" : "he is",
    # "i'd" is ambiguous ("I would" / "I had"); a dict can hold only one value per key, so use the
    # more common reading (previously a duplicate key silently made this "I had")
    r"\bi'd\b" : "I would",
    r"\bi'll\b" : "I will",
    r"\bi'm\b" : "I am",
    r"\bisn't\b" : "is not",
    # intentionally no r"\bits\b" entry: "its" is normally the possessive and must not become "it is"
    r"\bit's\b" : "it is",
    r"\bit'll\b" : "it will",
    r"\bi've\b" : "I have",
    r"\blet's\b" : "let us",
    r"\bmightn't\b" : "might not",
    r"\bmustn't\b" : "must not",
    r"\bshan't\b" : "shall not",
    r"\bshe'd\b" : "she would",
    r"\bshe'll\b" : "she will",
    r"\bshe's\b" : "she is",
    r"\bshouldn't\b" : "should not",
    r"\bthat's\b" : "that is",
    r"\bthere's\b" : "there is",
    r"\bthey'd\b" : "they would",
    r"\bthey'll\b" : "they will",
    r"\bthey're\b" : "they are",
    r"\bthey've\b" : "they have",
    r"\bwe'd\b" : "we would",
    r"\bwe're\b" : "we are",
    r"\bweren't\b" : "were not",
    r"\bwe've\b" : "we have",
    r"\bwhat'll\b" : "what will",
    r"\bwhat're\b" : "what are",
    r"\bwhat's\b" : "what is",
    r"\bwhat've\b" : "what have",
    r"\bwhere's\b" : "where is",
    r"\bwho'd\b" : "who would",
    r"\bwho'll\b" : "who will",
    r"\bwho're\b" : "who are",
    r"\bwho's\b" : "who is",
    r"\bwho've\b" : "who have",
    r"\bwon't\b" : "will not",
    r"\bwouldn't\b" : "would not",
    r"\byou'd\b" : "you would",
    r"\byou'll\b" : "you will",
    r"\byou're\b" : "you are",
    r"\byou've\b" : "you have",
    # generic fallback for any remaining "'re" (the specific forms above are applied first)
    r"\b're\b" : " are",
    r"\bwasn't\b" : "was not",
    r"\bwe'll\b" : "we will",
    # no trailing \b: there is no word boundary between a final apostrophe and a following space
    r"\btryin'" : "trying"
}

In [None]:
# Emoticon -> word map (plain substring replacements), based on
# https://www.kaggle.com/prashantkikani/pooled-gru-with-preprocessing
# When applied one key at a time in insertion order with str.replace, a key that contains another
# key must come BEFORE it, otherwise it can never match (":dd" must precede ":d").
# NOTE(review): substring matching ignores context -- ":/" also matches URL schemes such as
# "http://", and "8)" / "):" can match ordinary parenthesized text. "&lt;3" and ":&gt;" only
# match HTML-escaped text. Confirm these are acceptable before enabling this table.
emoji_str_repls = {
    "&lt;3": " love ",
    ":]" : " happy ",
    "=)" : " happy ",
    "8)": " happy ",
    ":-)": " happy ",
    ":)": " happy ",
    "(-:": " happy ",
    "(:": " happy ",
    ":&gt;": " happy ",
    ":')": " happy ",
    ":dd": " laughing ",
    ":d": " laughing ",
    ";-)" : " wink ",
    ";)": " wink ",
    ":p": " playful ",
    ":o" : " surprise ",
    ":-(": " sad ",
    ":(": " sad ",
    "=(" : " sad ",
    "):" : " sad ",
    ":/": " skeptical ",
    ":s": " skeptical ",
    ":-s": " skeptical ",
    "^^": " nervous ",
    "^_^": " nervous ",
    "-_-" : " shame ",
}

In [None]:
# Abbreviations and common misspellings seen in survey comments, as regex -> replacement.
# Intended for re_replace (case-insensitive, match case preserved); when merged into a replacement
# table the keys are applied one after another in insertion order.
spelling_regex_repls = {
    # abbreviations
    r"\bacctg\b" : "acct",
    r"\badd'l\b" : "additional",
    # the pattern consumes the trailing whitespace, so put it back ("r u ok" -> "are you ok")
    r"\br\s\b": "are ",
    r"\bu\s\b": "you ",
    # "i m" -> "I am"
    r"\bi\sm\b": "I am",
    r"'cause\b" : "because",
    r"\b(ha)+\b": "haha",
    # require at least two "he"s so the pronoun "he" is left untouched
    r"\b(he){2,}\b": "haha",
    r"\bya+y\b": "yay",
    r"\bwa+y\b": "way",
    r"\bf'real\b" : "for real",
    r"\bgr8\b" : "great",
    r"\bintl\b" : "int'l",
    # common misspellings
    r"\bbailable\b" : "available",
    r"\babilty\b" : "ability",
    r"\babsolutly\b" : "absolutely",
    r"\babsoultely\b" : "absolutely",
    r"\bacces\b" : "access",
    r"\baccesability\b" : "accessibility",
    r"\baccesbility\b" : "accessibility",
    r"\baccesibility\b" : "accessibility",
    r"\baccessability\b" : "accessibility",
    r"\baccessbility\b" : "accessibility",
    r"\baccesable\b" : "accessible",
    r"\baccesible\b" : "accessible",
    r"\baccessable\b" : "accessible",
    r"\bacessible\b" : "accessible",
    # NOTE(review): "assessable" is a real adjective and is mapped to a noun here -- confirm intent
    r"\bassessable\b" : "availability",
    r"\baccidently\b" : "accidentally",
    r"\baccomadate\b" : "accommodate",
    r"\baccomdate\b" : "accommodate",
    r"\baccomidate\b" : "accommodate",
    r"\baccomodate\b" : "accommodate",
    r"\baccomadating\b" : "accommodating",
    r"\baccomidating\b" : "accommodating",
    r"\baccomodating\b" : "accommodating",
    r"\baccomadations\b" : "accommodations",
    r"\baccomodation\b" : "accommodation",
    r"\baccouting\b" : "accounting",
    r"\baccross\b" : "across",
    r"\badditonal\b" : "additional",
    r"\baddtionally\b" : "additionally",
    r"\badminstration\b" : "administration",
    r"\badminstrative\b" : "administrative",
    r"\badminstrator\b" : "administrator",
    r"\badress\b" : "address",
    r"\badvancment\b" : "advancement",
    r"\badvertized\b" : "advertised",
    r"\bafforable\b" : "affordable",
    r"\bafordable\b" : "affordable",
    r"\bafterall\b" : "after all",
    r"\bafterhours\b" : "after hours",
    r"\baggresive\b" : "aggressive",
    r"\bagressive\b" : "aggressive",
    r"\bagressions\b" : "aggressions",
    r"\balittle\b" : "a little",
    r"\balll\b" : "all",
    r"\balloted\b" : "allotted",
    r"\ballthough\b" : "although",
    r"\balthought\b" : "although",
    r"\ballways\b" : "always",
    r"\balos\b" : "also",
    r"\balot\b" : "a lot",
    r"\balotted\b" : "allotted",
    r"\bammount\b" : "amount",
    r"\bammounts\b" : "amounts",
    r"\bamoung\b" : "among",
    r"\bamoungst\b" : "amongst",
    r"\bannouncment\b" : "announcement",
    r"\baparments\b" : "apartments",
    r"\bapparrel\b" : "apparel",
    r"\bappartment\b" : "apartment",
    r"\bappriciate\b" : "appreciate",
    r"\bassitance\b" : "assistance",
    r"\bassitant\b" : "assistant",
    r"\batleast\b" : "at least",
    r"\battentative\b" : "attentive",
    r"\battrocious\b" : "atrocious",
    r"\bavaiable\b" : "available",
    r"\bavaible\b" : "available",
    r"\bavailabe\b" : "available",
    r"\bavailble\b" : "available",
    r"\bavailiable\b" : "available",
    r"\bavailible\b" : "available",
    r"\bavaliable\b" : "available",
    r"\bavalible\b" : "available",
    r"\bavilable\b" : "available",
    r"\bavailiability\b" : "availability",
    r"\bavailabiltiy\b" : "availability",
    r"\bavailabilty\b" : "availability",
    r"\bavailablility\b" : "availability",
    r"\bavailablity\b" : "availability",
    r"\bavailibility\b" : "availability",
    r"\bavaliability\b" : "availability",
    r"\bavaliablity\b" : "availability",
    r"\bavalibility\b" : "availability",
    r"\bactivies\b" : "activities",
    r"\bactivites\b" : "activities",
    r"\bactualy\b" : "actually",
    r"\bacutally\b" : "actually",
    r"\bammenities\b" : "amenities",
    r"\bantoher\b" : "another",
    r"\baswell\b" : "as well",
    r"\baweful\b" : "awful",
    r"\bawfull\b" : "awful",
    r"\bawsome\b" : "awesome",
    r"\bbeacuse\b" : "because",
    r"\bbearly\b" : "barely",
    r"\bbeaurocracy\b" : "bureaucracy",
    r"\bbeaurocratic\b" : "bureaucratic",
    r"\bbecasue\b" : "because",
    r"\bbecuase\b" : "because",
    r"\bbecuse\b" : "because",
    r"\bbefor\b" : "before",
    r"\bbeggining\b" : "beginning",
    r"\bbegining\b" : "beginning",
    r"\bbeleive\b" : "believe",
    r"\bbelive\b" : "believe",
    r"\bbenificial\b" : "beneficial",
    r"\bbenifit\b" : "benefit",
    r'\bbugetary\b' : "budgetary",
    r'\bbuiding\b' : "building",
    r'\bbuidling\b' : "building",
    r'\bbuisness\b' : "business",
    r'\bbuliding\b' : "building",
    r"\bbureacracy\b" : "bureaucracy",
    r"\bburitto\b" : "burrito",
    r"\bbussiness\b" : "business",
    r"\bcalender\b" : "calendar",
    r"\bcan;t\b" : "can't",
    r"\bcasher\b" : "cashier",
    r'\bcatagories\b' : "categories",
    r'\bcatagory\b' : "category",
    r"\bcheapter\b" : "cheaper",
    r"\bcheeper\b" : "cheaper",
    r'\bclasss\b' : "class",
    r'\bclassses\b' : "classes",
    r"\bcleaniness\b" : "cleanliness",
    r"\bcmapus\b" : "campus",
    r'\bcofee\b' : "coffee",
    r'\bcoffe\b' : "coffee",
    r'\bcollegue\b' : "colleague",
    r'\bcoment\b' : "comment",
    r'\bcoments\b' : "comments",
    r'\bcomming\b' : "coming",
    r'\bcommittment\b' : "commitment",
    r'\bcommment\b' : "comment",
    r'\bcommuication\b' : "communication",
    r'\bcommunter\b' : "commuter",
    r'\bcommunters\b' : "commuters",
    r'\bcomotion\b' : "commotion",
    r'\bcomparision\b' : "comparison",
    r'\bcompatability\b' : "compatibility",
    r'\bcompatable\b' : "compatible",
    r'\bcompetative\b' : "competitive",
    r'\bcompetetive\b' : "competitive",
    r'\bcompetive\b' : "competitive",
    r'\bcompletly\b' : "completely",
    r"\bcomraderie\b" : "camaraderie",
    r'\bcomradery\b' : "camaraderie",
    r'\bcomunication\b' : "communication",
    r'\bcomunity\b' : "community",
    r'\bconcious\b' : "conscious",
    r'\bcondusive\b' : "conducive",
    r'\bconection\b' : "connection",
    r"\bconfortable\b" : "comfortable",
    r'\bconsistant\b' : "consistent",
    r'\bconsistantly\b' : "consistently",
    r'\bconsistenly\b' : "consistently",
    r'\bcontinously\b' : "continuously",
    r'\bcontruction\b' : "construction",
    r'\bconveinent\b' : "convenient",
    r'\bconveinient\b' : "convenient",
    r'\bconveniant\b' : "convenient",
    r'\bconveniece\b' : "convenience",
    r'\bconveninent\b' : "convenient",
    r'\bconvienance\b' : "convenience",
    r'\bconvienant\b' : "convenient",
    r'\bconvience\b' : "convenience",
    r'\bconvienence\b' : "convenience",
    r'\bconvienent\b' : "convenient",
    r'\bconvienet\b' : "convenient",
    r'\bconvienience\b' : "convenience",
    r'\bconvienient\b' : "convenient",
    r'\bconvient\b' : "convenient",
    r'\bconviently\b' : "conveniently",
    r'\bconvinence\b' : "convenience",
    r'\bconvinent\b' : "convenient",
    r'\bconvinience\b' : "convenience",
    r'\bconvinient\b' : "convenient",
    r'\bcorteous\b' : "courteous",
    r'\bcostodial\b' : "custodial",
    r'\bcoureous\b' : "courteous",
    r'\bcourtis\b' : "courteous",
    r'\bcouteous\b' : "courteous",
    r'\bcovenient\b' : "convenient",
    r'\bcroweded\b' : "crowded",
    r'\bcurteous\b' : "courteous",
    r'\bcurtesy\b' : "courtesy",
    r'\bcurtious\b' : "courteous",
    r"\bdeaprtment\b" : "department",
    r"\bdecission\b" : "decision",
    r'\bdefinately\b' : "definitely",
    r'\bdefinetely\b' : "definitely",
    r'\bdefinetly\b' : "definitely",
    r'\bdefinitley\b' : "definitely",
    r'\bdefinitly\b' : "definitely",
    r'\bdelievered\b' : "delivered",
    r'\bdeliverers\b' : "deliveries",
    r'\bdeparment\b' : "department",
    r'\bdeparments\b' : "departments",
    r'\bdepartement\b' : "department",
    # "(s)" plural markers: match the closing paren too so no stray ")" is left behind
    r"\bdepartment\(s\)" : "departments",
    r'\bdepartmet\b' : "department",
    r'\bdepratment\b' : "department",
    r"\bdeptartment\b" : "department",
    r'\bdescrimination\b' : "discrimination",
    r'\bdesireable\b' : "desirable",
    r"\bdiffernt\b" : "different",
    r"\bdiffrent\b" : "different",
    r'\bdinig\b' : "dining",
    r'\bdirverse\b' : "diverse",
    r'\bdisapointed\b' : "disappointed",
    r'\bdisapointing\b' : "disappointing",
    r'\bdisasterous\b' : "disastrous",
    r'\bdisatisfied\b' : "dissatisfied",
    r'\bdisbursment\b' : "disbursement",
    r'\bdisbursments\b' : "disbursements",
    r'\bdiscretely\b' : "discreetly",
    r'\bdiscusting\b' : "disgusting",
    r'\bdisfunctional\b' : "dysfunctional",
    r'\bdispensors\b' : "dispensers",
    r'\bdispersement\b' : "disbursement",
    r'\bdissapointed\b' : "disappointed",
    r'\bdissapointing\b' : "disappointing",
    r'\bdissapointment\b' : "disappointment",
    r'\bdissappointed\b' : "disappointed",
    r'\bdissappointing\b' : "disappointing",
    r'\bdissatified\b' : "dissatisfied",
    r'\bdiveristy\b' : "diversity",
    r'\bdivison\b' : "division",
    r'\bdivsion\b' : "division",
    r"\bdoens't\b" : "doesn't",
    r"\bdoes't\b" : "doesn't",
    r"\bdoesn;t\b" : "doesn't",
    r"\bdon;t\b" : "don't",
    r'\bdonot\b' : "do not",
    r"\bdosen't\b" : "doesn't",
    r"\bdosent\b" : "doesn't",
    r'\bdumbells\b' : "dumbbells",
    r'\bdurring\b' : "during",
    r"\beatting\b" : "eating",
    r"\beduation\b" : "education",
    r'\beffeciency\b' : "efficiency",
    r'\beffecient\b' : "efficient",
    r'\befficency\b' : "efficiency",
    r'\befficent\b' : "efficient",
    r'\beffiecient\b' : "efficient",
    r'\beimplying\b' : "implying",
    r'\bembarassed\b' : "embarrassed",
    r'\bembarassing\b' : "embarrassing",
    r'\bembarassment\b' : "embarrassment",
    r'\bemploee\b' : "employee",
    r'\bemploye\b' : "employee",
    r'\bemployee\(s\)' : "employees",
    r'\bemployeed\b' : "employed",
    r'\bemployement\b' : "employment",
    r'\bemployes\b' : "employees",
    r'\bemployess\b' : "employees",
    r'\bemplyee\b' : "employee",
    r'\bemplyees\b' : "employees",
    r'\bempolyees\b' : "employees",
    r'\bencoutered\b' : "encountered",
    r'\benought\b' : "enough",
    r'\benrollement\b' : "enrollment",
    r'\benviorment\b' : "environment",
    r'\benviornment\b' : "environment",
    r'\benvirnment\b' : "environment",
    r'\benviroment\b' : "environment",
    r'\benvironement\b' : "environment",
    r'\bequiped\b' : "equipped",
    r'\bespcially\b' : "especially",
    r'\bespecailly\b' : "especially",
    r'\bespecialy\b' : "especially",
    r'\bespeically\b' : "especially",
    r"\besthetically\b" : "aesthetically",
    r"\bethinicity\b" : "ethnicity",
    r"\bevaulation\b" : "evaluation",
    r"\beventhough\b" : "even though",
    r'\beverday\b' : "every day",
    r'\beverthing\b' : "everything",
    r'\beveryones\b' : "everyones",
    r'\beverythings\b' : "everythings",
    r'\beveryway\b' : "every way",
    r'\beveyone\b' : "everyone",
    r'\beveything\b' : "everything",
    r'\bevrything\b' : "everything",
    r'\bexcelent\b' : "excellent",
    r'\bexcellant\b' : "excellent",
    r'\bexellent\b' : "excellent",
    r'\bexhorbitant\b' : "exorbitant",
    r'\bexistance\b' : "existence",
    r'\bexpecially\b' : "especially",
    r'\bexpensice\b' : "expensive",
    r'\bexpereince\b' : "experience",
    r'\bexperiance\b' : "experience",
    r'\bexperince\b' : "experience",
    r'\bexpierence\b' : "experience",
    r'\bexpirence\b' : "experience",
    r'\bexplaination\b' : "explanation",
    r'\bexremely\b' : "extremely",
    r'\bextemely\b' : "extremely",
    r'\bextention\b' : "extension",
    r'\bextermely\b' : "extremely",
    r'\bextreamly\b' : "extremely",
    r'\bextrememly\b' : "extremely",
    r'\bextremly\b' : "extremely",
    r"\bfacilites\b" : "facilities",
    r'\bfacilties\b' : "facilities",
    r'\bfacilty\b' : "facility",
    r'\bfaculity\b' : "faculty",
    r'\bfacutly\b' : "faculty",
    r'\bfiancial\b' : "financial",
    r"\bfinacial\b" : "financial",
    r"\bfirendly\b" : "friendly",
    r'\bflexability\b' : "flexibility",
    r'\bflexibilty\b' : "flexibility",
    r'\bflexiblity\b' : "flexibility",
    r"\bflourescent\b" : "fluorescent",
    r'\bfreindly\b' : "friendly",
    r'\bfreqency\b' : "frequency",
    r'\bfreqent\b' : "frequent",
    r'\bfriednly\b' : "friendly",
    r'\bfrusterating\b' : "frustrating",
    r'\bfrusturating\b' : "frustrating",
    r'\bfustrating\b' : "frustrating",
    r'\bgovenor\b' : "governor",
    r"\bgraffitti\b" : "graffiti",
    r"\bgrafitti\b" : "graffiti",
    r"\bgreatful\b" : "grateful",
    r"\bguarenteed\b" : "guaranteed",
    r"\bguidlines\b" : "guidelines",
    r"\bguranteed\b" : "guaranteed",
    r"\bhappend\b" : "happened",
    r'\bharrass\b' : "harass",
    r'\bharrassed\b' : "harassed",
    r'\bharrassing\b' : "harassing",
    r'\bharrassment\b' : "harassment",
    r"\bhavn't\b" : "haven't",
    r'\bhealtheir\b' : "healthier",
    r'\bhealthly\b' : "healthy",
    r'\bhealtier\b' : "healthier",
    r'\bhealty\b' : "healthy",
    r'\bheathy\b' : "healthy",
    r'\bheirarchy\b' : "hierarchy",
    r'\bhelful\b' : "helpful",
    r'\bhelpfull\b' : "helpful",
    r'\bhelpul\b' : "helpful",
    r'\bhighschool\b' : "high school",
    r'\bhighschools\b' : "high schools",
    r'\bhorendous\b' : "horrendous",
    r'\bhorible\b' : "horrible",
    r'\bhouseing\b' : "housing",
    r'\bi"m\b' : "i'm",
    r'\bi"ve\b' : "i've",
    r'\bimplimented\b' : "implemented",
    r'\bimporve\b' : "improve",
    r'\bimposible\b' : "impossible",
    r'\bimprovment\b' : "improvement",
    r'\bimprovments\b' : "improvements",
    r'\bincompetant\b' : "incompetent",
    r'\binconsistant\b' : "inconsistent",
    r'\binconveinent\b' : "inconvenient",
    r'\binconvience\b' : "inconvenience",
    r'\binconvienent\b' : "inconvenient",
    r'\binconvienient\b' : "inconvenient",
    r'\binconvient\b' : "inconvenient",
    r'\binconvinient\b' : "inconvenient",
    r'\bindentify\b' : "identify",
    r'\bindependant\b' : "independent",
    r'\bindividual\(s\)' : "individuals",
    r'\binforced\b' : "enforced",
    r'\binformaiton\b' : "information",
    r'\binformtion\b' : "information",
    r'\binfront\b' : "in front",
    r'\binnout\b' : "in-n-out",
    r'\binsentive\b' : "incentive",
    r'\binsufficent\b' : "insufficient",
    r'\binterenet\b' : "internet",
    r'\binterent\b' : "internet",
    r'\bintermural\b' : "intramural",
    r'\bintramurals\b' : "intramurals",
    r'\binvironment\b' : "environment",
    r'\bissue\(s\)' : "issues",
    r'\bit;s\b' : "it's",
    r'\bitem\(s\)' : "items",
    r"\bjob\(s\)" : "jobs",
    r'\bknowledable\b' : "knowledgeable",
    r'\bknowledeable\b' : "knowledgeable",
    r'\bknowledegable\b' : "knowledgeable",
    r'\bknowledgable\b' : "knowledgeable",
    r'\bknowledgably\b' : "knowledgeably",
    r'\bknowledgeably\b' : "knowledgeably",
    r'\bknowledgeble\b' : "knowledgeable",
    r'\bknowlegable\b' : "knowledgeable",
    r'\bknowlegeable\b' : "knowledgeable",
    r'\bliek\b' : "like",
    r'\blieke\b' : "like",
    r'\blimted\b' : "limited",
    r'\bmaintainance\b' : "maintenance",
    r'\bmaintaince\b' : "maintenance",
    r'\bmaintainence\b' : "maintenance",
    r'\bmaintanance\b' : "maintenance",
    r'\bmaintance\b' : "maintenance",
    r'\bmaintanence\b' : "maintenance",
    r'\bmaintenace\b' : "maintenance",
    r'\bmaintenances\b' : "maintenance",
    r'\bmaintence\b' : "maintenance",
    r'\bmaintenece\b' : "maintenance",
    r'\bmaintenence\b' : "maintenance",
    r'\bmaitenance\b' : "maintenance",
    r'\bmanager\(s\)' : "managers",
    r'\bmanagment\b' : "management",
    r'\bmanangement\b' : "management",
    r'\bmangement\b' : "management",
    r'\bmangers\b' : "managers",
    r'\bmanuever\b' : "maneuver",
    r'\bmintues\b' : "minutes",
    r'\bmoblie\b' : "mobile",
    r'\bmulitple\b' : "multiple",
    r'\bn\?a\b' : "n/a",
    r'\bna\b' : "n/a",
    r'\bneccessary\b' : "necessary",
    r'\bnecesary\b' : "necessary",
    r'\bneedes\b' : "needs",
    r'\bneeed\b' : "need",
    r'\bnonexistant\b' : "nonexistent",
    r'\bnothig\b' : "nothing",
    r'\bnothjng\b' : "nothing",
    r'\bnoticable\b' : "noticeable",
    r'\bobsurd\b' : "absurd",
    r'\bocassional\b' : "occasional",
    r'\boccassion\b' : "occasion",
    r'\boccassional\b' : "occasional",
    r'\boccassionally\b' : "occasionally",
    r'\boccassions\b' : "occasions",
    r'\boccations\b' : "occasions",
    r'\boccurances\b' : "occurrences",
    r'\boccured\b' : "occurred",
    r'\boccuring\b' : "occurring",
    r'\boccurr\b' : "occur",
    r'\bofcourse\b' : "of course",
    r'\bofferred\b' : "offered",
    r'\bopinon\b' : "opinion",
    r'\bopitions\b' : "options",
    r'\boportunities\b' : "opportunities",
    r'\bopperation\b' : "operation",
    r'\boppertunities\b' : "opportunities",
    r'\boppinion\b' : "opinion",
    r'\bopportunites\b' : "opportunities",
    r'\bopportunties\b' : "opportunities",
    r'\boppotunities\b' : "opportunities",
    r'\boppurtunities\b' : "opportunities",
    r'\boppurtunity\b' : "opportunity",
    r'\borgnized\b' : "organized",
    r'\boutragous\b' : "outrageous",
    r'\bpage\(s\)' : "pages",
    r'\bpakages\b' : "packages",
    r'\bparkibg\b' : "parking",
    r'\bparkig\b' : "parking",
    r'\bparkign\b' : "parking",
    r'\bparkinglots\b' : "parking lots",
    r'\bpartime\b' : "part-time",
    r'\bparttime\b' : "part-time",
    r'\bpatroling\b' : "patrolling",
    r'\bpeopel\b' : "people",
    r'\bpermitt\b' : "permit",
    r'\bperson\(s\)' : "persons",
    r'\bpersonel\b' : "personnel",
    r'\bpersonell\b' : "personnel",
    r'\bpharamcy\b' : "pharmacy",
    r'\bpleasent\b' : "pleasant",
    r'\bplently\b' : "plenty",
    r'\bplesant\b' : "pleasant",
    r'\bpositon\b' : "position",
    r'\bposses\b' : "possess",
    r'\bpossition\b' : "position",
    r'\bpostion\b' : "position",
    r'\bpostions\b' : "positions",
    r'\bpostition\b' : "position",
    r'\bpostive\b' : "positive",
    r'\bpractioner\b' : "practitioner",
    r'\bpractioners\b' : "practitioners",
    r'\bprefered\b' : "preferred",
    r'\bpreferrably\b' : "preferably",
    r'\bpreform\b' : "perform",
    r'\bpreforming\b' : "performing",
    r'\bpricess\b' : "prices",
    r'\bpriciples\b' : "principles",
    r'\bpricy\b' : "pricey",
    r'\bprking\b' : "parking",
    r'\bproceedures\b' : "procedures",
    r'\bprocurment\b' : "procurement",
    r'\bprofessionaly\b' : "professionally",
    r'\bproffessional\b' : "professional",
    r'\bproffit\b' : "profit",
    r'\bprofitt\b' : "profit",
    r'\bprogam\b' : "program",
    r'\bpromissed\b' : "promised",
    r'\bpublically\b' : "publicly",
    r'\bqucik\b' : "quick",
    r'\bquestion\(s\)' : "questions",
    r'\bquestionaire\b' : "questionnaire",
    r'\breall\b' : "really",
    r'\brealy\b' : "really",
    r'\breccomend\b' : "recommend",
    r'\breccommend\b' : "recommend",
    r'\breceieve\b' : "receive",
    r'\breciept\b' : "receipt",
    r'\breciepts\b' : "receipts",
    r'\brecieve\b' : "receive",
    r'\brecieved\b' : "received",
    r'\brecieves\b' : "receives",
    r'\brecieving\b' : "receiving",
    r'\brecived\b' : "received",
    r'\brecomend\b' : "recommend",
    r'\brecomended\b' : "recommended",
    r'\brediculous\b' : "ridiculous",
    r'\brediculously\b' : "ridiculously",
    r'\brefered\b' : "referred",
    r'\brefering\b' : "referring",
    r'\bregeants\b' : "regents",
    r'\bregistar\b' : "registrar",
    r'\bregistars\b' : "registrars",
    r'\bregulary\b' : "regularly",
    r'\breimbursment\b' : "reimbursement",
    r'\breponse\b' : "response",
    r'\breponsive\b' : "responsive",
    r'\brepresentitive\b' : "representative",
    r'\breserach\b' : "research",
    r'\bresonable\b' : "reasonable",
    r'\bresouces\b' : "resources",
    r'\bresourses\b' : "resources",
    r'\bresponsed\b' : "responded",
    r'\bresponsibilites\b' : "responsibilities",
    r'\bresponsiblities\b' : "responsibilities",
    r'\bresponsiblity\b' : "responsibility",
    r'\brestaraunts\b' : "restaurants",
    r'\brestraunts\b' : "restaurants",
    r'\brestuarant\b' : "restaurant",
    r'\brestuarants\b' : "restaurants",
    r'\bresturant\b' : "restaurant",
    r'\bresturants\b' : "restaurants",
    r'\bridiculus\b' : "ridiculous",
    r'\briduculous\b' : "ridiculous",
    r'\broomate\b' : "roommate",
    r'\broomates\b' : "roommates",
    r'\bsaleries\b' : "salaries",
    r'\bsandwhich\b' : "sandwich",
    r'\bsandwhiches\b' : "sandwiches",
    r'\bsandwitches\b' : "sandwiches",
    r'\bsatifaction\b' : "satisfaction",
    r'\bsatified\b' : "satisfied",
    r'\bsattelite\b' : "satellite",
    r'\bsceience\b' : "science",
    r'\bschedual\b' : "schedule",
    r'\bseemless\b' : "seamless",
    r'\bselction\b' : "selection",
    r'\bsenority\b' : "seniority",
    r'\bsensative\b' : "sensitive",
    r'\bsensored\b' : "censored",
    r'\bseperate\b' : "separate",
    r'\bseperation\b' : "separation",
    r'\bserivce\b' : "service",
    r'\bserivces\b' : "services",
    r'\bserive\b' : "service",
    r'\bserives\b' : "services",
    r'\bservicesi\b' : "services",
    r'\bservidces\b' : "services",
    r'\bservive\b' : "survive",
    r'\bservives\b' : "survives",
    r'\bseverly\b' : "severely",
    r'\bsevice\b' : "service",
    r'\bsevices\b' : "services",
    r'\bshcool\b' : "school",
    r'\bshoud\b' : "should",
    r'\bshoudl\b' : "should",
    r'\bshutttle\b' : "shuttle",
    r'\bsimiliar\b' : "similar",
    r'\bsomeitmes\b' : "sometimes",
    r'\bsomeone\(s\)' : "someones",
    r'\bsomeones\b' : "someones",
    r'\bsometiems\b' : "sometimes",
    r'\bsomone\b' : "someone",
    r'\bsomthing\b' : "something",
    r'\bsophmore\b' : "sophomore",
    r'\bspecialy\b' : "especially",
    r'\bstafff\b' : "staff",
    r'\bstatment\b' : "statement",
    r'\bstong\b' : "strong",
    r'\bstongly\b' : "strongly",
    r'\bstoping\b' : "stopping",
    r'\bstrabucks\b' : "starbucks",
    r'\bstressfull\b' : "stressful",
    r'\bstructure\(s\)' : "structures",
    r'\bstucture\b' : "structure",
    r'\bstuctures\b' : "structures",
    r'\bstuden\b' : "student",
    r'\bstudent\(s\)' : "students",
    r'\bstudetns\b' : "students",
    r'\bstudnet\b' : "student",
    r'\bstudnets\b' : "students",
    r'\bsucess\b' : "success",
    r'\bsudent\b' : "student",
    r'\bsudents\b' : "students",
    r'\bsuperintendant\b' : "superintendent",
    r'\bsuperviser\b' : "supervisor",
    r'\bsupervisor\(s\)' : "supervisors",
    r'\bsupervisores\b' : "supervisors",
    r'\bsuport\b' : "support",
    r'\bsupples\b' : "supplies",
    r'\bsuppossed\b' : "supposed",
    r'\bsuprised\b' : "surprised",
    r'\bsuvey\b' : "survey",
    r'\bsytem\b' : "system",
    r'\bthats\b' : "that's",
    r"\bthe're\b" : "they're",
    r'\btheives\b' : "thieves",
    r'\bthiefs\b' : "thieves",
    r'\bthreating\b' : "threatening",
    r'\bthroughly\b' : "thoroughly",
    r'\bthrought\b' : "throughout",
    r'\bthroughtout\b' : "throughout",
    r'\btodays\b' : "today's",
    r'\btraing\b' : "training",
    r'\btrainning\b' : "training",
    r'\btranfers\b' : "transfers",
    r'\btransfered\b' : "transferred",
    r'\btransfering\b' : "transferring",
    r'\btransporation\b' : "transportation",
    r'\btransportaion\b' : "transportation",
    r'\btransportations\b' : "transportations",
    r'\btransportion\b' : "transportation",
    r'\btrashbags\b' : "trash bags",
    r'\btrashcans\b' : "trash cans",
    r'\btremedously\b' : "tremendously",
    r'\btshirt\b' : "t-shirt",
    r'\btshirts\b' : "t-shirts",
    r'\btution\b' : "tuition",
    r'\btutition\b' : "tuition",
    r'\bunaccessible\b' : "inaccessible",
    r'\bunconvenient\b' : "inconvenient",
    r'\bunecessary\b' : "unnecessary",
    r'\bunflexible\b' : "inflexible",
    r'\bunforseen\b' : "unforeseen",
    r'\buniverisity\b' : "university",
    r'\buniveristy\b' : "university",
    r'\buniverity\b' : "university",
    # misspelling -> correct spelling (previously mapped the correct form to the misspelling)
    r'\bunknowledgable\b' : "unknowledgeable",
    r'\bunneccessary\b' : "unnecessary",
    r'\bunrealiable\b' : "unreliable",
    r'\buntill\b' : "until",
    r'\bunversity\b' : "university",
    r'\buseability\b' : "usability",
    r'\busefull\b' : "useful",
    r'\bususally\b' : "usually",
    r'\bvaccum\b' : "vacuum",
    r'\bvaccuum\b' : "vacuum",
    r'\bvaction\b' : "vacation",
    r'\bvacume\b' : "vacuum",
    r'\bvariaty\b' : "variety",
    r'\bvarities\b' : "varieties",
    r'\bvarity\b' : "variety",
    r'\bvegeterian\b' : "vegetarian",
    r'\bvegitarian\b' : "vegetarian",
    r'\bvegitarians\b' : "vegetarians",
    r'\bvegtables\b' : "vegetables",
    r'\bventillation\b' : "ventilation",
    r'\bveriety\b' : "variety",
    r'\bvisted\b' : "visited",
    r'\bvistor\b' : "visitor",
    r'\bvistors\b' : "visitors",
    r'\bweeekends\b' : "weekends",
    r'\bwierd\b' : "weird",
    r'\bwirless\b' : "wireless",
    r'\bwithdrawl\b' : "withdrawal",
    r'\bwoudl\b' : "would",
    r"\bwoudn't\b" : "wouldn't",
    r"\bthier\b" : "their",
    r"\bappartments\b" : "apartments",
    r"\bbenifits\b" : "benefits",
    r"\bexistant\b" : "existent",
    r"\bsaftey\b" : "safety",
    r'\bdon"t\b' : "don't",
}

In [None]:
# Plain-substring fixes for HTML-entity residue, typographic quotes and corpus markers.
# Order matters when the entries are applied one after another with str.replace: the literal
# backslash sequences ('\\n', '\\"') must be handled before the final '\\' -> ' \\ ' entry.
# NOTE(review): entity keys omit the leading '&' (e.g. 'amp;'), so "&amp;" becomes "&&" and
# "&#39;" becomes "&'" -- confirm this is intended.
weirdchar_str_repls = {
    "#39;" : "'",
    'amp;' : '&',
    '#146;' : "'",
    'nbsp;' : ' ',
    '#36;' : '$',
    '\\n' : "\n",
    'quot;' : "'",
    '’' : "'",
    "´" : "'",
    "`" : "'",
    '“' : '"',
    '”' : '"',
    '<br />' : "\n",
    '\\"' : '"',
    '<unk>' : 'u_n',
    ' @.@ ' : '.',
    ' @-@ ' : '-',
    '\\' : ' \\ ',
    '•' : '-'
}

In [None]:
# does regex replace making the substitution the same case
def re_replace(word, replacement, text):
    """Case-insensitively replace regex `word` in `text`, matching the case of each match.

    For each match: an all-lowercase match gets `replacement.lower()`, a Title-case match gets
    `replacement` with only its first letter upper-cased, an all-uppercase match gets
    `replacement.upper()`, and any other mix gets `replacement` unchanged.
    """
    def _capitalize_first(s):
        # Upper-case only the first letter. str.title() would also capitalize letters after
        # apostrophes and every word ("doesn't" -> "Doesn'T", "a lot" -> "A Lot").
        for i, ch in enumerate(s):
            if ch.isalpha():
                return s[:i] + ch.upper() + s[i + 1:]
        return s

    def func(match):
        g = match.group()
        if g.islower(): return replacement.lower()
        if g.istitle(): return _capitalize_first(replacement)
        if g.isupper(): return replacement.upper()
        return replacement

    # a callable replacement means backslashes in `replacement` are not treated as escapes
    return re.sub(word, func, text, flags=re.I)

# define regex and string replacements (both empty by default, so make_replacements is a no-op
# until they are populated)
re_repls = {}  # e.g., { **spelling_regex_repls } 
str_repls = {} # e.g., { **weirdchar_str_repls }

def make_replacements(t:str) -> str:
    """Apply `re_repls` (regex, case-preserving via re_replace) then `str_repls` (literal) to `t`.

    Entries are applied one after another in dict insertion order, so earlier replacements are
    visible to later ones.
    """
    for pattern, repl in re_repls.items(): t = re_replace(pattern, repl, t)
    # replace each literal with its mapped value (previously the whole text was substituted in)
    for old, new in str_repls.items(): t = t.replace(old, new)
    return t

# ensure am|pm is considered its own token (7:00pm -> 7:00 pm, 7am-10pm -> 7 am -10 pm)
def fix_ampm(t:str) -> str:
    """Insert a space between a number and a directly attached am/pm (or a.m./p.m.) marker.

    A trailing space is also added after the marker. Matching is case-sensitive.
    """
    # NOTE: 'am'/'pm' and 'a.m.'/'p.m.' are tried before their '-' variants, so the '-' variants
    # never match and any following hyphen is left in place.
    return re.sub(r'(\d+)(am|pm|am\-|pm\-|a\.m\.|p\.m\.|a\.m\.\-|p\.m\.\-)', r'\1 \2 ', t)

# try to handle places where a new sentence doesn't begin with a space (e.g., I like dogs.I like cats)
# without breaking apart things like urls and emails
def fix_sentence_ends(t:str) -> str:
    """Add spaces around a letter run glued to a preceding '.', e.g. 'dogs.I like' -> 'dogs. I  like'.

    The run is left alone when the dot follows 'www', when the run starts with com/edu/org/net,
    when it is a lone 'm' (as in 'a.m.'/'p.m.'), or when it is followed by '@' or by
    '.com'/'.edu'/'.org'/'.net'.
    """
    sentence_end = r'(?<!www)\.((?!com|edu|org|net|m\b)[a-zA-Z]+)(?!(@|\.(com|edu|org|net)))\b'
    return re.sub(sentence_end, r'. \1 ', t)

# separate hyphen|tilde if it is at beginning of letter/digit
def fix_hyphenated_words(t:str) -> str:
    """Split a run of '-' or '~' from the letter/digit it prefixes, when the run follows whitespace.

    e.g. 'items -great' -> 'items - great'. The preceding whitespace character becomes a single space.
    """
    leading_dash = r'\s(\-+|~+)([a-zA-Z0-9])'
    return re.sub(leading_dash, r' \1 \2', t)


# Our own pre-rules, in the order they are applied.
_custom_pre_rules = [make_replacements, fix_ampm, fix_sentence_ends, fix_hyphenated_words]

# fastai's default text pre-rules run first; the custom rules are appended after them.
custom_tok_rules = defaults.text_pre_rules + _custom_pre_rules

# use this customized Tokenizer for qualitative data
tokenizer = Tokenizer(pre_rules=custom_tok_rules)

## Configuration

In [None]:
# Data layout: language-model artifacts under PATH/lm, and this notebook's
# standard-CSS-theme classifier under PATH/classification/standard_themes/css.
PATH = Path('../data')
CLEAN_DATA_PATH = Path('../data/clean')

LM_PATH = PATH/'lm'
CLS_PATH = PATH/'classification'
STANDARD_THEME_PATH = CLS_PATH/'standard_themes'
STANDARD_THEME_CSS_PATH = STANDARD_THEME_PATH/'css'

# make sure each area has models/ and tmp/ sub-directories (safe to re-run)
for _area in (LM_PATH, STANDARD_THEME_CSS_PATH):
    (_area/'models').mkdir(parents=True, exist_ok=True)
    (_area/'tmp').mkdir(exist_ok=True)

In [None]:
# Column -> dtype map for the verbatims CSV (passed to pd.read_csv for LM_PATH/'all.csv').
lm_dtypes = {
    # answer
    'Id': int, 'QuestionAnsID': int, 'AnswerText': str, 'AnswerText_NonEnglish': str, 'Language': str,
    # survey / respondent
    'SurveyID': int, 'SurveyTypeID': int, 'BenchmarkSurveyType': str, 'ClientId': str, 'RspID': int,
    # question
    'QuestionCategoryAbbr': str, 'QuestionText': str, 'QuestionClass': str,
    'QuestionCategoryID': float, 'QuestionReportAbbr': str, 'QuestionCategoryLabel': str,
    'BenchmarkLevel1': str, 'BenchmarkLevel2': str, 'BenchmarkLevel3': str, 'ClientBenchmarkLevel': str,
    # organisational groups (codes are float, presumably because they can be missing)
    'GroupCode': float, 'GroupID': str,
    'GroupLevel1Code': float, 'GroupLevel1Name': str,
    'GroupLevel2Code': float, 'GroupLevel2Name': str,
    'GroupLevel3Code': float, 'GroupLevel3Name': str,
    'GroupLevel4Code': float, 'GroupLevel4Name': str,
    'GroupLevel5Code': float, 'GroupLevel5Name': str,
    'GroupLevel6Code': float, 'GroupLevel6Name': str,
    'GroupLevel7Code': float, 'GroupLevel7Name': str,
    'GroupLevel8Code': float, 'GroupLevel8Name': str,
}

# same map with snake_case column names
lm_dtypes_sc = dict((convert_to_snakecase(col), dtype) for col, dtype in lm_dtypes.items())

# standard CSS theme label columns (integer indicators); this order is the label order
standard_theme_css_dtypes = dict.fromkeys([
    'accessible_to_customers',
    'consistency_in_policies_information',
    'cost_fees',
    'courteous_professional_staff',
    'effective_communications',
    'effectively_uses_websites_online_documentation',
    'helpful_staff',
    'knowledgeable_staff',
    'moving_in_a_positive_direction',
    'overall_satisfaction',
    'process_improvement',
    'provides_effective_advice_guidance',
    'provides_training_on_processes_applications',
    'resolves_problems_effectively',
    'responds_to_requests_within_an_acceptable_time',
    'understands_my_needs_and_requirements',
], int)

# date columns
date_cols = []

STANDARD_THEME_CSS_LABELS = list(standard_theme_css_dtypes.keys())

In [None]:
# Load the LM vocabulary saved under LM_PATH (presumably by the LM notebook).
# `pickle` is dill here; only unpickle files you created yourself (it can execute code).
# The `with` block closes the file handle, which the bare open() never did.
with open(LM_PATH/'vocab.pkl', 'rb') as vocab_file:
    vocab = pickle.load(vocab_file)

In [None]:
# vocabulary size (number of entries in vocab.itos)
len(vocab.itos)

## Classifier

The classifier is basically a linear layer custom head on top of the LM backbone

In [None]:
# number of texts per chunk handed to TokenizeProcessor
chunksize = 24000

# AWD-LSTM sizes: BPTT length, embedding size, hidden size, number of layers
# NOTE(review): em_sz/nh/nl are not passed to text_classifier_learner below; kept for reference
bptt = 70
em_sz = 400
nh = 1150
nl = 3

bsz = 80   # batch size
wd = 1e-7  # weight decay (reset to 0. before the fine-tuning cells)

In [None]:
# text column(s) the classifier reads; more than one column may be listed
corpus_cols = ['answer_text']

# suffix naming which version of the text is used (e.g. '_cleaned'); '' = raw text
corpus_suf = ''

In [None]:
# training split, plus test.csv which serves as the validation set here
train_df, valid_df = [pd.read_csv(STANDARD_THEME_CSS_PATH/csv_name) for csv_name in ('train.csv', 'test.csv')]

Remove any rows where the `corpus_cols` values are NaN

In [None]:
# drop rows missing text in any of the corpus columns (nothing to classify)
train_df = train_df.dropna(subset=corpus_cols)
valid_df = valid_df.dropna(subset=corpus_cols)

In [None]:
# label names; this is also the class order passed to label_from_df below
STANDARD_THEME_CSS_LABELS

In [None]:
# fastai v1 (as of 11/15/2018) expects all labels of a multi-label example in one
# delimited column: join the names of every label column that is set (non-zero).
def _active_label_names(row):
    "Space-joined names of the columns in `row` whose value is truthy."
    return ' '.join(row.index[row.astype(bool)])

train_df['labels'] = train_df[STANDARD_THEME_CSS_LABELS].apply(_active_label_names, axis=1)
valid_df['labels'] = valid_df[STANDARD_THEME_CSS_LABELS].apply(_active_label_names, axis=1)

train_df[['labels'] + STANDARD_THEME_CSS_LABELS].head()

In [None]:
# Tokenize with the custom tokenizer and numericalize with the LM vocab, so token ids
# line up with the pretrained encoder loaded later.
cls_processor = [
    TokenizeProcessor(tokenizer=tokenizer, chunksize=chunksize),
    NumericalizeProcessor(vocab=vocab),
]

def _text_list(df):
    "TextList over `df`'s corpus columns using the shared classification processors."
    return TextList.from_df(df, path=STANDARD_THEME_CSS_PATH, cols=corpus_cols, processor=cls_processor)

# multi-label targets: space-delimited names in 'labels', class order = STANDARD_THEME_CSS_LABELS
data_clas = (ItemLists(path=STANDARD_THEME_CSS_PATH, train=_text_list(train_df), valid=_text_list(valid_df))
             .label_from_df(cols='labels', classes=STANDARD_THEME_CSS_LABELS, label_delim=' ')
             .databunch(bs=bsz))

data_clas.save('data_cls_standard_theme_css.pkl')

In [None]:
# reload the saved classification DataBunch (lets the cells below run without rebuilding it)
data_clas = load_data(STANDARD_THEME_CSS_PATH, f'data_cls_standard_theme_css.pkl', bs=bsz)

In [None]:
# peek at the first 10 vocabulary entries
data_clas.train_ds.vocab.itos[:10]

In [None]:
# first training example: its text and its multi-label target
print(data_clas.train_ds.x[0])
print(data_clas.train_ds.y[0])

In [None]:
# number of examples and vocab size for train / valid (vocab sizes are expected to match: same processors)
print(len(data_clas.train_ds), len(data_clas.train_ds.vocab.itos))
print(len(data_clas.valid_ds), len(data_clas.valid_ds.vocab.itos))

In [None]:
# iterator over training batches, for manual inspection
it = iter(data_clas.train_dl)

In [None]:
# one batch: batch[0] holds the token ids, batch[1] the targets
batch = next(it)
# NOTE(review): the first two prints repeat information shown by the third
print(batch[0].size())
print(batch[1].size())
print(batch[0].size(), batch[0].type(), batch[1].size(), batch[1].type(), bsz)

In [None]:
# decode the first document of the batch back into tokens
' '.join([ data_clas.train_ds.vocab.itos[idx] for idx in batch[0][0,:] ])

In [None]:
# show a few decoded examples with their labels
data_clas.show_batch()

### Configure a forward or backwards run

In [None]:
# Direction of this run: True uses the backwards encoder/data, False the forwards one.
backwards = True

m_suf = '_multilabel' #'_cleaned'
if backwards:
    m_pre = 'bwd_'
else:
    m_pre = 'fwd_'

# reload the DataBunch so its text order matches the chosen direction
data_clas = load_data(STANDARD_THEME_CSS_PATH, 'data_cls_standard_theme_css.pkl', bs=bsz, backwards=backwards)

### Build the classifier (baseline)

In [None]:
class MultiLabelClassifier(nn.Module):
    """Final model stage that turns classifier logits into per-label probabilities.

    Receives the (logits, raw_outputs, outputs) tuple from the previous stage, applies an
    element-wise sigmoid to the logits and, when `y_range` is set, rescales the result
    linearly from [0, 1] into [y_range[0], y_range[1]]. The RNN outputs pass through untouched.
    """

    def __init__(self, y_range=None):
        super().__init__()
        self.y_range = y_range

    def forward(self, input):
        logits, raw_outputs, outputs = input
        probs = torch.sigmoid(logits)
        if self.y_range:
            low, high = self.y_range[0], self.y_range[1]
            probs = probs * (high - low)
            probs = probs + low
        return probs, raw_outputs, outputs

We set up the dropouts for the model - these values have been chosen after experimentation. If you need to update them for custom LMs, you can change the weighting factor (`drop_mult`, 0.5 here) based on the amount of data you have. With more data you can reduce the dropout factor; for small datasets you can reduce overfitting by choosing a higher dropout factor. *No other dropout value requires tuning*

We first train only the last layer group (the classifier head): everything else, including the pretrained encoder, is frozen.

We also keep track of the optimal threshold (`opt_th`), F-score (`fscore`) and multi-label accuracy (`multilbl_accuracy`) metrics.

In [None]:
# F-beta weighting used by the metrics below (beta=1 -> F1) and the `start` value passed to metrics_util
beta = 1
start = 0.1

# NOTE: SaveModelCallback below monitors 'fscore' by name, so keep these as plain named functions.
def fscore(preds, targs):
    "F-beta score via metrics_util.best_fscore (presumably the best over a threshold sweep beginning at `start`)."
    return metrics_util.best_fscore(preds, targs, beta, start=start)
    
def opt_th(preds, targs):
    "Threshold chosen by metrics_util.best_fthresh for F-beta (presumably the F-beta-maximizing one)."
    return metrics_util.best_fthresh(preds, targs, beta=beta, start=start)

def multilbl_accuracy(preds, targs):
    "Multi-label accuracy via metrics_util.multi_accuracy."
    return metrics_util.multi_accuracy(preds, targs, beta=beta, start=start)

In [None]:
# Release a previous learner before building a new one. Best-effort: on a fresh kernel
# `learn` does not exist yet (NameError), and it may already be None (AttributeError).
# Catch Exception (not a bare except) so Ctrl-C still interrupts.
try: learn.purge()
except Exception: pass
learn = None
torch.cuda.empty_cache()

In [None]:
# Build the AWD-LSTM classifier; pretrained=False because the fine-tuned LM encoder
# is loaded explicitly later with load_encoder.
learn = None
gc.collect()
learn = text_classifier_learner(
    data_clas,
    arch=AWD_LSTM,
    pretrained=False,
    drop_mult=0.5,
    bptt=bptt,
    lin_ftrs=[50],
    ps=[0.1],
    alpha=2.,
    beta=1.,
)

In [None]:
# Append the sigmoid head as module '2' so the model outputs probabilities, which is
# what F.binary_cross_entropy expects as input.
learn.model.add_module('2', MultiLabelClassifier())

learn.loss_func = F.binary_cross_entropy
learn.metrics = [opt_th, fscore, multilbl_accuracy]
learn.clip = 25.  # gradient clipping

learn.model_dir = 'models'

In [None]:
# Keep the weights with the best validation 'fscore' as {m_pre}cls_bestmodel{m_suf};
# SaveModelCallback reloads them when training ends.
# (To select on validation loss instead, use monitor='val_loss', mode='min'.)
best_model_cb = partial(SaveModelCallback, monitor='fscore', mode='max', name=f'{m_pre}cls_bestmodel{m_suf}')

# re-running this cell must not register the callback twice
learn.callback_fns = [cb for cb in learn.callback_fns
                      if getattr(cb, 'func', None) is not SaveModelCallback] + [best_model_cb]
# learn.callback_fns.append(RocAucEvaluation)  # optional extra per-epoch evaluation

In [None]:
best_model_path = STANDARD_THEME_CSS_PATH/f'models/{m_pre}cls_bestmodel{m_suf}*'
!rm {best_model_path}

In [None]:
# copy the LM encoder weights (both fwd_ and bwd_ *_lm_enc.pth) from LM_PATH/models into this classifier's models/ dir
! cp {LM_PATH/'models/*_lm_enc.pth'} {STANDARD_THEME_CSS_PATH/'models/'}  

In [None]:
# load the encoder that matches the chosen direction (fwd_lm_enc or bwd_lm_enc)
learn.load_encoder(f'{m_pre}lm_enc')

In [None]:
# uncomment to inspect the model architecture
# learn.model

In [None]:
# max learning rate for the first (head-only) stage; no weight decay from here on
lr = 5e-1
wd = 0.

In [None]:
# train only the last layer group; the encoder stays frozen
learn.freeze()

In [None]:
# learning-rate range test starting at lr/1000, then plot loss vs. learning rate
learn.lr_find(lr/1000, wd=wd)
learn.recorder.plot()

We set learning rates and fit the classifier. We first run one epoch with only the classifier head trainable (the pretrained encoder stays frozen), then gradually unfreeze deeper layer groups (`freeze_to(-2)`, `freeze_to(-3)`, `unfreeze()`) using discriminative learning rates.

In [None]:
%%time
# stage 1: one epoch training only the classifier head
learn.fit_one_cycle(1, lr, wd=wd)

In [None]:
# checkpoint after stage 1
learn.save(f'{m_pre}cls_last_ft{m_suf}')

In [None]:
# no reload needed: SaveModelCallback loads the best weights automatically when training ends
# learn = learn.load(f'{m_pre}cls_last_ft{m_suf}')

In [None]:
# stage 2: unfreeze the last two layer groups; LRs spread from 5e-2/2.6**4 to 5e-2 across groups,
# momentum cycled between 0.8 and 0.7
learn.freeze_to(-2)
learn.fit_one_cycle(1, slice(5e-2/(2.6**4),5e-2), moms=(0.8,0.7))

In [None]:
learn.save(f'{m_pre}cls_last2_ft{m_suf}')

In [None]:
# no reload needed: SaveModelCallback loads the best weights automatically when training ends
# learn = learn.load(f'{m_pre}cls_last2_ft{m_suf}')

In [None]:
# stage 3: unfreeze the last three layer groups with 10x lower learning rates
learn.freeze_to(-3)
learn.fit_one_cycle(1, slice(5e-3/(2.6**4),5e-3), moms=(0.8,0.7))

In [None]:
learn.save(f'{m_pre}cls_last3_ft{m_suf}')

In [None]:
# no reload needed: SaveModelCallback loads the best weights automatically when training ends
# learn = learn.load(f'{m_pre}cls_last3_ft{m_suf}')

In [None]:
# stage 4: fine-tune the whole model for 20 epochs
learn.unfreeze()
learn.fit_one_cycle(20, slice(5e-3/(2.6**4),5e-3), moms=(0.8,0.7))

In [None]:
learn.save(f'{m_pre}cls{m_suf}')

Export model for inference

In [None]:
# export the learner to a .pkl that load_learner can restore for inference
learn.export(file=f'{m_pre}export_clas{m_suf}.pkl')

Use it for inference

In [None]:
# restore the exported learner (no training data needed)
inf_learn = load_learner(STANDARD_THEME_CSS_PATH, file=f'{m_pre}export_clas{m_suf}.pkl')

In [None]:
# NOTE(review): overrides the exported class names with the theme labels, presumably so predict()
# reports these names — confirm the export doesn't already carry them
inf_learn.data.single_ds.y.classes = STANDARD_THEME_CSS_LABELS

In [None]:
# smoke test on an ad-hoc comment
inf_learn.predict('The pay is too low and parking stinks on campus.  Where is my salary increase?')

Review final validation loss for best model

In [None]:
# reload the best-fscore checkpoint and get validation predictions plus per-item losses
learn = learn.load(f'{m_pre}cls_bestmodel{m_suf}')
probs, targs, loss = learn.get_preds(DatasetType.Valid, with_loss=True)

print(f'Validation Loss: {loss.mean()}')
# assumes `loss` is (n_items, n_labels), so dim=0 averages per label — TODO confirm
print(f'Validation Loss (per label): {loss.mean(dim=0)}')

In [None]:
print(STANDARD_THEME_CSS_LABELS)

In [None]:
# predict with the in-memory learner (best checkpoint)
learn.predict("There are not enough people to do the work")

## Save models, csvs, to zip and download (optional)

In [None]:
from IPython.display import FileLink

In [None]:
# (optional) zip the LM and classifier model dirs, excluding the wikitext-103 weights, and link for download
# !zip -r models.zip {LM_PATH}/models/ {CLS_PATH}/models  -x {LM_PATH}/models/lstm_wt103/\*

# FileLink('models.zip')

In [None]:
# (optional) zip the verbatims CSVs and link for download
# !zip verbatims-csvs.zip {PATH}/verbatims.csv {PATH}/verbatims-entities.csv {PATH}/verbatims-meta.csv

# FileLink('verbatims-csvs.zip')

## Review predictions

### Predict labels for our validation dataset, including the actual documents

In [None]:
# predictions for a single model using the learner's model and data loaders
learn.load(f'{m_pre}cls_bestmodel{m_suf}')
# GPU #1, matching torch.cuda.set_device(1) at the top of the notebook
learn.model.cuda(1)
probs, targs, docs = get_cls_predictions(learn, DatasetType.Valid, vocab)

probs.shape, targs.shape, len(docs)

In [None]:
# validation target array shape, number of classes, and class names
data_clas.valid_ds.y.items.shape, data_clas.valid_ds.y.c, data_clas.valid_ds.y.classes

In [None]:
# Thresholds for F0.5 (precision-weighted), F1 and F2 (recall-weighted), as chosen by
# metrics_util.best_fthresh with start=0.1, end=.3.
threshold_f05, threshold_f1, threshold_f2 = [
    metrics_util.best_fthresh(probs, targs, beta=b, start=0.1, end=.3).item()
    for b in (0.5, 1, 2)
]

threshold_f05, threshold_f1, threshold_f2

In [None]:
# fastai F-beta at the F1 threshold; probs are already sigmoid outputs, hence sigmoid=False
res = fbeta(probs, targs, thresh=threshold_f1, beta=1, sigmoid=False)
res

In [None]:
from sklearn import metrics

In [None]:
# scikit-learn cross-check: F1 averaged over samples
res = metrics.fbeta_score(targs, (probs > threshold_f1), beta=1, average='samples')
res

In [None]:
# element-wise accuracy over all (document, label) decisions at the F1 threshold
# (despite its name, `preds` holds this accuracy, not predictions)
preds = ((probs > threshold_f1).byte() == targs.byte()).float().mean()
preds.item()

In [None]:
# fastai accuracy_thresh at each optimal threshold; probs are already probabilities (sigmoid=False)
val_acc_f05, val_acc_f1, val_acc_f2 = [
    accuracy_thresh(probs, targs, th, sigmoid=False).item()
    for th in (threshold_f05, threshold_f1, threshold_f2)
]

val_acc_f05, val_acc_f1, val_acc_f2

### Review classifier

In [None]:
import sklearn
from sklearn import metrics
print (sklearn.__version__)

In [None]:
probs.shape, targs.shape

In [None]:
# flatten so every (document, label) decision becomes one binary sample for the metrics below
# (use targs[:,0] / probs[:,0] instead to evaluate just the first label)
eval_targs = targs.flatten() # targs[:,0]
eval_probs = probs.flatten() # probs[:,0]

#### Classification Accuracy

The percentage of correct predictions.  Answers the question, *"Overall, how often is the classifier correct?"*

In [None]:
# In multilabel classification, this function computes subset accuracy: 
# the set of labels predicted for a sample must exactly match ALL the corresponding set of labels in y_true.
print(metrics.accuracy_score(targs, (probs > threshold_f1)))

In [None]:
# accuracy over the flattened (document, label) decisions
print(metrics.accuracy_score(eval_targs, (eval_probs > threshold_f1).float()))

#### Null Accuracy
 
The accuracy achieved by always predicting the most frequent class.  Answers the question, *"What would the accuracy be by always predicting the most frequent case?"*

In [None]:
# most frequent target value among the flattened decisions, and how often it occurs
# (u_classes is reused as tick labels by the confusion-matrix plot below)
u_classes, u_counts = np.unique(eval_targs, return_counts=True)
most_freq_class, most_freq_class_count = u_classes[np.argmax(u_counts)], np.max(u_counts)
print(most_freq_class, most_freq_class_count)

In [None]:
# null accuracy: fraction of decisions equal to the most frequent value
most_freq_class_count / len(eval_targs)

#### Cohen's kappa

This measure is intended to compare labelings by different human annotators (not a classifier vs. ground truth)

Kappa scores are between -1 and 1 (>= .8 is generally considered good agreement; <= 0 means no agreement, e.g., practically random labels)

In [None]:
# Cohen's kappa between targets and thresholded predictions over all (document, label) decisions
print(metrics.cohen_kappa_score(eval_targs, (eval_probs > threshold_f1).float()))

#### Confusion Matrix

Describes the performance of a classification model

In [None]:
def plot_confusion_matrix(cm, classes, normalize=False, 
                          title='Confusion matrix', cmap=plt.cm.Blues, print_info=False):
    """
    Plot (and optionally print) a confusion matrix on the current matplotlib axes.

    Args:
        cm: square matrix in sklearn.metrics.confusion_matrix layout
            (rows = true labels, columns = predicted labels).
        classes: tick labels for rows and columns.
        normalize: if True, divide each row by its total so cells show fractions of each
            true class; rows with no samples are shown as 0 instead of NaN.
        title: axes title.
        cmap: matplotlib colormap for the cells.
        print_info: if True, also print the (normalized) matrix.
    """
    if normalize:
        cm = np.asarray(cm, dtype=float)
        row_totals = cm.sum(axis=1, keepdims=True)
        # avoid 0/0 -> NaN (and a RuntimeWarning) for classes absent from y_true
        cm = np.divide(cm, row_totals, out=np.zeros_like(cm), where=row_totals != 0)
        if print_info: print("Normalized confusion matrix")
    elif print_info:
        print('Confusion matrix, without normalization')

    if print_info: print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar(fraction=0.046, pad=0.04)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    # switch the annotation colour at half the max so text stays readable on dark cells
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.grid(None)

In [None]:
# confusion matrix over all flattened (document, label) decisions at the F1 threshold
cm = metrics.confusion_matrix(eval_targs, (eval_probs > threshold_f1).float())

In [None]:
# u_classes (from the null-accuracy cell) supplies the tick labels
# Plot non-normalized confusion matrix
fig = plt.figure(figsize=(12,8))
plt.subplot(1, 2, 1)
plot_confusion_matrix(cm, classes=u_classes,
                      title='Confusion matrix, without normalization')

# Plot normalized confusion matrix
plt.subplot(1, 2, 2)
plot_confusion_matrix(cm, classes=u_classes, normalize=True,
                      title='Normalized confusion matrix')

fig.subplots_adjust(wspace=0.5)
plt.show()

In [None]:
# per-class precision / recall / F1 over the flattened decisions at the F1 threshold
# (`labels` is keyword-only in recent scikit-learn releases)
print(metrics.classification_report(eval_targs, (eval_probs > threshold_f1).float(), labels=[0, 1]))

### Raw probability distribution

Useful to see how the threshold can be adjusted to increase sensitivity or specificity

In [None]:
# distribution of every predicted label probability (all documents x all labels),
# binned in tenths of [0, 1]
plt.hist(eval_probs, bins=10, range=(0, 1))
plt.title('Histogram of predicted probabilities')
plt.xlabel('Predicted probability (all labels)')
plt.ylabel('Frequency');

The histogram shows how you can **decrease** the threshold for predicting a label in order to **increase the sensitivity** (recall) of the classifier, at the cost of more false positives.

#### ROC curves and Area Under the Curve (AUC)

***ROC Curve*** answers the question, *"How would sensitivity and specificity be affected by various thresholds without changing the threshold?"*  It is a way **to visualize the performance of a binary classifier.**

The ROC curve can help you **choose a threshold** that balances sensitivity and specificity based on your particular business case.

ROC curves visualize all possible classification thresholds whereas misclassification rate only represents your error rate for a single threshold.

A classifier that does a good job at separating the classes will have an ROC curve that hugs the upper left corner of the plot. Conversely, a classifier that does a poor job separating the classes will have an ROC curve that is close to the diagonal line (0,0 -> 1,1). That diagonal line represents a classifier that does no better than random guessing.

In [None]:
# ROC curve treating every (document, label) decision as one binary sample
fpr, tpr, thresholds = metrics.roc_curve(eval_targs, eval_probs)

In [None]:
# ROC over all flattened label decisions, with the chance diagonal for reference.
# xlim/ylim must be *called*: `plt.xlim = [...]` would replace the pyplot function.
plt.plot(fpr, tpr)
plt.plot([0, 1], [0, 1], linestyle='--', color='grey')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.0])
plt.title('ROC curve for all labels')
plt.xlabel('False Positive Rate (1 - Specificity)')
plt.ylabel('True Positive Rate (Sensitivity/Recall)')
plt.grid(True)

***AUC*** = the percentage of the ROC plot that is underneath the curve.  

AUC summarizes the performance of a classifier in a **single number**.  It says, *"If you randomly chose one positive and one negative observation, what is the likelihood that your classifier will assign a higher predicted probability to the positive observation."*

**An AUC of ~ 0.8 is very good while an AUC of ~ 0.5 represents a poor classifier.**

The ROC curve and AUC are insensitive to whether your predicted probabilities are properly calibrated to actually represent probabilities of class membership (e.g., it works if predicted probs range from 0.9 to 1 instead of 0 to 1). All the AUC metric cares about is how well your classifier separates the two classes.

Notes:
1.  AUC is useful even when there is **high class imbalance** (unlike classification accuracy)
2.  AUC is useful even when predicted probabilities are not properly calibrated (e.g., not between 0 and 1)

In [None]:
# AUC over all flattened (document, label) decisions, i.e. a micro-averaged AUC
print(metrics.roc_auc_score(eval_targs, eval_probs))

Let's look at things label by label ...

In [None]:
# Per-label evaluation. For each theme label, collect: null accuracy, the predicted-probability
# histogram, ROC/AUC, and, at each of the F0.5/F1/F2-optimal thresholds, accuracy, Cohen's
# kappa, the confusion matrix and precision/recall/F-score/support.
label_metrics = {
    'thresholds': { 'f-beta05': threshold_f05, 'f-beta1': threshold_f1, 'f-beta2': threshold_f2 }
}

for lbl_idx, lbl_name in enumerate(STANDARD_THEME_CSS_LABELS):
    lbl_targs, lbl_probs = targs[:, lbl_idx], probs[:, lbl_idx]

    lbl_metrics = label_metrics[lbl_name] = {
        'accuracies': {}, 'cohen_kappas': {}, 'confusion_matrices': {}, 'roc': {}, 'report': {},
    }

    # null accuracy: accuracy we'd get by always predicting the most common class
    u_classes, u_counts = np.unique(lbl_targs, return_counts=True)
    most_freq_class, most_freq_class_count = u_classes[np.argmax(u_counts)], np.max(u_counts)
    lbl_metrics['null_accuracy'] = most_freq_class_count / len(lbl_targs)

    # raw probability distribution: (counts, bin_edges)
    lbl_metrics['probability_distribution'] = np.histogram(lbl_probs)

    # ROC/AUC are undefined when this label has a single class in the validation set
    # (roc_auc_score raises ValueError), so record NaN and empty curves instead
    if len(u_classes) > 1:
        lbl_metrics['roc_auc'] = metrics.roc_auc_score(lbl_targs, lbl_probs)
        fpr, tpr, thresholds = metrics.roc_curve(lbl_targs, lbl_probs)
    else:
        lbl_metrics['roc_auc'] = float('nan')
        fpr, tpr, thresholds = np.array([]), np.array([]), np.array([])
    lbl_metrics['roc'].update(fpr=fpr, tpr=tpr, thresholds=thresholds)

    for k, v in label_metrics['thresholds'].items():
        lbl_preds = lbl_probs > v
        lbl_metrics['accuracies'][k] = metrics.accuracy_score(lbl_targs, lbl_preds)
        lbl_metrics['cohen_kappas'][k] = metrics.cohen_kappa_score(lbl_targs, lbl_preds)
        # labels=[0, 1] keeps the matrix 2x2 and the report arrays length 2 even when
        # a class is missing from both the targets and the predictions
        lbl_metrics['confusion_matrices'][k] = metrics.confusion_matrix(lbl_targs, lbl_preds, labels=[0, 1])

        precision, recall, fbeta_score, support = metrics.precision_recall_fscore_support(
            lbl_targs, lbl_preds, labels=[0, 1])
        lbl_metrics['report'][k] = {
            'precision': precision, 'recall': recall, 'fbeta_score': fbeta_score, 'support': support,
        }
          

In [None]:
# For every label in `label_metrics`, print a text report (null accuracy, AUC, accuracy /
# Cohen's kappa / classification report at each threshold) and draw the confusion
# matrices, the ROC curve and the predicted-probability histogram.
for lbl in label_metrics.keys():
    if (lbl == 'thresholds'): continue
    
    print(f'{lbl.upper()}\n')
    
    print(f'Null Accuracy:\t{label_metrics[lbl]["null_accuracy"]}')
    print(f'AUC Score:\t{label_metrics[lbl]["roc_auc"]}')
    print('')
    
    print(''.join([ f'\t\t{threshold}({np.round(v, 4)})' for threshold, v in label_metrics['thresholds'].items() ]))
    
    print('Accuracy:\t', end='')
    for threshold, v in label_metrics['thresholds'].items():
        print(f'{label_metrics[lbl]["accuracies"][threshold]}\t', end='')
    print('')
    
    print('Cohen\'s Kappa:\t', end='')
    for threshold, v in label_metrics['thresholds'].items():
        print(f'{label_metrics[lbl]["cohen_kappas"][threshold]}\t', end='')
    print('\n')
    
    print('Classification Reports:')
    for k, report in label_metrics[lbl]['report'].items():
        print(f'{k}')
        print(f'{"":<20}' + ''.join([ f'{sub_key:<20}' for sub_key in report.keys() ]))
        
        # one row per class in the report (normally 0 and 1; fewer if a class is absent)
        for i in range(len(report['support'])):
            print(f'{i:<20}' + ''.join([ f'{np.round(v[i],4):<20}' 
                                      for v in report.values() ]))
        
        print(f'{"avg/total":<20}' + ''.join([ f'{ np.round(v.mean(),4) if (sub_key != "support") else np.round(v.sum(),4):<20}' 
                                     for sub_key, v in report.items() ]))
        print('')
    print('\n')
    
    print('Confusion Matrices:')
    for threshold, v in label_metrics['thresholds'].items():
        cm = label_metrics[lbl]['confusion_matrices'][threshold]
        
        # Plot non-normalized confusion matrix
        fig = plt.figure(figsize=(12,8))
        plt.subplot(1, 2, 1)
        plot_confusion_matrix(cm, classes=[0,1], 
                              title=f'Confusion matrix, without normalization ({threshold}: {np.round(v,4)})')

        # Plot normalized confusion matrix
        plt.subplot(1, 2, 2)
        plot_confusion_matrix(cm, classes=[0,1], normalize=True, 
                              title=f'Normalized confusion matrix ({threshold}: {np.round(v,4)})')

        fig.subplots_adjust(wspace=0.5)
        plt.show()
    print('\n')
    
    print('ROC Curve:')
    plt.figure(figsize=(12,8))
    plt.plot(label_metrics[lbl]['roc']['fpr'], label_metrics[lbl]['roc']['tpr'])
    # xlim/ylim must be called: `plt.xlim = [...]` would replace the pyplot function
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.0])
    plt.title(f'ROC curve for {lbl}')
    plt.xlabel('False Positive Rate (1 - Specificity)')
    plt.ylabel('True Positive Rate (Sensitivity/Recall)')
    plt.grid(True)
    plt.show()
    
    print('Predicted Probability Distribution:')
    counts, bin_edges = label_metrics[lbl]['probability_distribution']
    plt.figure(figsize=(12,8))
    # np.histogram bins span [min, max] of the probabilities (not [0, 1]), so draw each bar
    # with its real bin width, left-aligned on the bin's lower edge
    plt.bar(bin_edges[:-1], counts, width=np.diff(bin_edges), align='edge')
    plt.xlim([0.0, 1.0])
    plt.show()
    
    print('\n')
    print('-'*100)
    print('\n')

### Ensemble forwards and backwards passes

In [None]:
# Free previously loaded ensemble learners. Best-effort: on a fresh kernel they do not
# exist yet, and each is purged independently so one missing learner doesn't skip the
# other. gc and the CUDA cache are always cleaned up.
try: learn_fwd.purge()
except Exception: pass
try: learn_bwd.purge()
except Exception: pass
learn_fwd = learn_bwd = None
gc.collect()
torch.cuda.empty_cache()

bsz = 80
m_suf = '_multilabel'

# Load the exported forward / backward classifiers and attach the saved DataBunch
# (forward and backward token order respectively) so get_preds has a validation set.
learn_fwd = load_learner(STANDARD_THEME_CSS_PATH, file=f'fwd_export_clas{m_suf}.pkl')
data_fwd = load_data(STANDARD_THEME_CSS_PATH, f'data_cls_standard_theme_css.pkl', bs=bsz)
learn_fwd.data = data_fwd

learn_bwd = load_learner(STANDARD_THEME_CSS_PATH, file=f'bwd_export_clas{m_suf}.pkl')
data_bwd = load_data(STANDARD_THEME_CSS_PATH, f'data_cls_standard_theme_css.pkl', bs=bsz, backwards=True)
learn_bwd.data = data_bwd

In [None]:
# validation predictions with per-item losses; ordered=True returns them in dataset order
# so forward and backward rows line up for averaging
probs_fwd, lbl_fwd, loss_fwd = learn_fwd.get_preds(ordered=True, with_loss=True)
probs_bwd, lbl_bwd, loss_bwd = learn_bwd.get_preds(ordered=True, with_loss=True)

probs_fwd.shape, probs_bwd.shape, loss_fwd.shape

In [None]:
# mean validation loss of the forward model, the backward model, and their average
loss_fwd.mean(), loss_bwd.mean(), (loss_fwd.mean() + loss_bwd.mean()) / 2

In [None]:
# ensemble: simple average of the forward and backward probabilities
probs_final = (probs_fwd + probs_bwd) / 2

#### Results

In [None]:
# determine optimal threshold based on desired f-score
threshold_f05 = metrics_util.best_fthresh(probs_fwd, lbl_fwd, beta=0.5, start=0.1, end=.3).item()
threshold_f1 = metrics_util.best_fthresh(probs_fwd, lbl_fwd, beta=1, start=0.1, end=.3).item()
threshold_f2 = metrics_util.best_fthresh(probs_fwd, lbl_fwd, beta=2, start=0.1, end=.3).item()

threshold_f05, threshold_f1, threshold_f2

# determine accuracy based on optimal threshold
val_acc_f05 = accuracy_thresh(probs_fwd, lbl_fwd, threshold_f05, sigmoid=False).item()
val_acc_f1 = accuracy_thresh(probs_fwd, lbl_fwd, threshold_f1, sigmoid=False).item()
val_acc_f2 = accuracy_thresh(probs_fwd, lbl_fwd, threshold_f2, sigmoid=False).item()

print('Fowards Only\n-------------')
print(f'f05:\tOptimal threshold = {threshold_f05}\t(Accuracy = {val_acc_f05})')
print(f'f1:\tOptimal threshold = {threshold_f1}\t\t(Accuracy = {val_acc_f1})')
print(f'f2:\tOptimal threshold = {threshold_f2}\t(Accuracy = {val_acc_f2})')

print(f'\nAccuracy: {accuracy_thresh(probs_fwd, lbl_fwd, sigmoid=False)}')

In [None]:
# determine optimal threshold based on desired f-score
threshold_f05 = metrics_util.best_fthresh(probs_bwd, lbl_fwd, beta=0.5, start=0.1, end=.3).item()
threshold_f1 = metrics_util.best_fthresh(probs_bwd, lbl_fwd, beta=1, start=0.1, end=.3).item()
threshold_f2 = metrics_util.best_fthresh(probs_bwd, lbl_fwd, beta=2, start=0.1, end=.3).item()

threshold_f05, threshold_f1, threshold_f2

# determine accuracy based on optimal threshold
val_acc_f05 = accuracy_thresh(probs_bwd, lbl_fwd, threshold_f05, sigmoid=False).item()
val_acc_f1 = accuracy_thresh(probs_bwd, lbl_fwd, threshold_f1, sigmoid=False).item()
val_acc_f2 = accuracy_thresh(probs_bwd, lbl_fwd, threshold_f2, sigmoid=False).item()

print('Backwards Only\n-------------')
print(f'f05:\tOptimal threshold = {threshold_f05} (Accuracy = {val_acc_f05})')
print(f'f1:\tOptimal threshold = {threshold_f1} (Accuracy = {val_acc_f1})')
print(f'f2:\tOptimal threshold = {threshold_f2} (Accuracy = {val_acc_f2})')

print(f'\nAccuracy: {accuracy_thresh(probs_bwd, lbl_fwd, sigmoid=False)}')

In [None]:
# determine optimal threshold based on desired f-score
threshold_f05 = metrics_util.best_fthresh(probs_final, lbl_fwd, beta=0.5, start=0.1, end=.3).item()
threshold_f1 = metrics_util.best_fthresh(probs_final, lbl_fwd, beta=1, start=0.1, end=.3).item()
threshold_f2 = metrics_util.best_fthresh(probs_final, lbl_fwd, beta=2, start=0.1, end=.3).item()

threshold_f05, threshold_f1, threshold_f2

# determine accuracy based on optimal threshold
val_acc_f05 = accuracy_thresh(probs_final, lbl_fwd, threshold_f05, sigmoid=False).item()
val_acc_f1 = accuracy_thresh(probs_final, lbl_fwd, threshold_f1, sigmoid=False).item()
val_acc_f2 = accuracy_thresh(probs_final, lbl_fwd, threshold_f2, sigmoid=False).item()

print('Ensemble\n-------------')
print(f'f05:\tOptimal threshold = {threshold_f05} (Accuracy = {val_acc_f05})')
print(f'f1:\tOptimal threshold = {threshold_f1} (Accuracy = {val_acc_f1})')
print(f'f2:\tOptimal threshold = {threshold_f2} (Accuracy = {val_acc_f2})')

print(f'\nAccuracy: {accuracy_thresh(probs_final, lbl_fwd, sigmoid=False)}')

In [None]:
# ensemble validation loss: average of the forward and backward models' mean losses
final_valid_loss = (loss_fwd.mean() + loss_bwd.mean()) / 2

### Inference (ad-hoc documents)

In [None]:
# label names (presumably the column order of the probability vectors below)
print(STANDARD_THEME_CSS_LABELS)

In [None]:
# a few hand-written comments to sanity-check the classifier
test_comments = [
    'The parking situation REALLY sucks around here.  It needs to be fixed',
    'I LOVE working at UCSD!!!  It is wonderful',
    """Some staff are just uninformed.There is no support for solo-individual study (no closed off rooms).
        Once a guy (quite tall) walked in into the girl's restroom and used the stalls standing up. 
        There was no line in the guy's restroom. This happened when I done and was going to walk out. 
        I was extremely uncomfortable""",
    "I love UCSD!!! It is a terrible place to work!",
    "I was really uncomfortable to express my opinion!!!"
]

# NOTE(review): this uses the single in-memory `learn` model (built with backwards=True above),
# while threshold_f1 was last tuned on the fwd/bwd ensemble probabilities — confirm intended
doc_probs, doc_preds, doc_toks = get_cls_doc_predictions(learn.model, vocab, tokenizer, test_comments, 
                                                         threshold=threshold_f1)

In [None]:
# sanity check: one probability vector, one prediction and one token list per comment
len(doc_probs), len(doc_preds), len(doc_toks)

In [None]:
# print each tokenized comment with its label probabilities and thresholded predictions
for d_probs, d_preds, d_toks in zip(doc_probs, doc_preds, doc_toks):
    print(f'> {" ".join([t for t in d_toks])}\nProbabilities:\t{d_probs}\nPredictions:\t{d_preds}\n')

### Inference (batch ensemble)

In [None]:
import datetime

# today's date as YYYYMMDD (presumably for naming output files)
yyyymmdd = datetime.date.today().strftime("%Y%m%d")

m_suf = '_multilabel'

# batch inference runs on GPU #1 when CUDA is available, otherwise on the CPU
# device = torch.device('cpu')
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
print(device)

In [None]:
verbatims_df = pd.read_csv(LM_PATH/'all.csv', dtype={**lm_dtypes}, parse_dates=[])

# keep only CSS survey verbatims; na=False treats a missing BenchmarkSurveyType as "not CSS"
# (otherwise the mask contains NaN and boolean indexing raises ValueError)
inf_df = verbatims_df[verbatims_df.BenchmarkSurveyType.str.startswith('CSS', na=False)].copy()
inf_df.reset_index(drop=True, inplace=True)
print(len(inf_df))

corpus_cols = ['AnswerText']  # ['question_text', 'answer_text']

In [None]:
def concat_pool(raw_outputs):
    """Build document vectors from the last RNN layer's outputs.

    `raw_outputs[-1]` is indexed as (batch, seq_len, hidden). Returns, concatenated along
    dim 1, the output at the final time step, the max-pool and the average-pool over all
    time steps: a (batch, 3 * hidden) tensor.

    NOTE(review): pooling is not masked, so padding positions (if any) are included in the
    max/avg pools.
    """
    hidden_seq = raw_outputs[-1]
    n_docs = hidden_seq.shape[0]

    # pool over time: move the time axis last -> (batch, hidden, seq_len)
    over_time = hidden_seq.permute(0, 2, 1)
    avg_pool = F.adaptive_avg_pool1d(over_time, 1).view(n_docs, -1)
    max_pool = F.adaptive_max_pool1d(over_time, 1).view(n_docs, -1)
    final_step = hidden_seq[:, -1, :]

    return torch.cat([final_step, max_pool, avg_pool], 1)

In [None]:
def get_classification_results(backwards:bool=False, m_suf:str='multilabel', bs:int=128):
    """Run an exported forward or backward classifier over every row of `inf_df`.

    Loads `{fwd|bwd}_export_clas_{m_suf}.pkl` from STANDARD_THEME_CSS_PATH. It numericalizes
    `inf_df[corpus_cols]` with that learner's own processor, then pushes length-sorted batches
    through the model on `device` with gradients disabled.

    Args:
        backwards: load the backward (`bwd_*`) model and collate reversed token sequences.
        m_suf: suffix of the exported learner file.
            NOTE(review): this default has no leading underscore, unlike the notebook-level
            `m_suf = '_multilabel'`. Confirm it matches the name used at export time.
        bs: inference batch size.

    Returns:
        (all_probs, all_vecs, all_concat_vecs), re-ordered to match the rows of `inf_df`:
        - the model's first output;
        - the last-timestep hidden state of the final RNN layer;
        - `concat_pool` of the final RNN layer.
    """
    model_prefix = 'bwd' if backwards else 'fwd'

    # 1. grab learner, procs, and data
    inf_learn = load_learner(STANDARD_THEME_CSS_PATH, file=f'{model_prefix}_export_clas_{m_suf}.pkl')
    txt_procs = inf_learn.data.train_ds.processor
    inf_data = TextList.from_df(inf_df, cols=corpus_cols, processor=txt_procs).split_none().label_empty()

    # 2. length-sorted batches keep padding to a minimum; row order is restored in step 4
    collate_fn = partial(pad_collate, pad_first=True, backwards=backwards)
    doc_lengths = [len(t) for t in inf_data.train.x.items]
    sampler = SortSampler(inf_data.train.x, key=doc_lengths.__getitem__)
    dl = DeviceDataLoader.create(inf_data.train, bs=bs, sampler=sampler, collate_fn=collate_fn, device=device)

    # 3. get probs and document vectors
    inf_learn.model = inf_learn.model.to(device).eval()

    test_probs, doc_vecs, concat_doc_vecs = [], [], []
    with torch.no_grad():
        for index, (xb, _) in enumerate(dl):
            if index % 1000 == 0:  print(index)

            # reset the RNN hidden state before every batch (original note: without this you will OOM)
            inf_learn.model.reset()

            # NOTE(review): in stock fastai v1 the text classifier head returns raw logits, not
            # probabilities. Confirm this exported model applies a sigmoid before these values
            # are thresholded against threshold_f1.
            probs, raw_outputs, _ = inf_learn.model(xb)

            # to_detach detaches from the graph (and, by fastai's default, moves to CPU)
            # so accumulated results do not hold on to device memory
            test_probs.append(to_detach(probs))
            doc_vecs.append(to_detach(raw_outputs[-1][:,-1,:]))
            concat_doc_vecs.append(to_detach(concat_pool(raw_outputs)))

    all_probs = torch.cat(test_probs)
    all_vecs = torch.cat(doc_vecs)
    all_concat_vecs = torch.cat(concat_doc_vecs)

    # 4. ensure results are returned in order: argsort of the sampler permutation is its inverse
    if hasattr(dl, 'sampler'):
        sampler_idxs = [i for i in dl.sampler]
        reverse_sampler = np.argsort(sampler_idxs)

        all_probs = all_probs[reverse_sampler]
        all_vecs = all_vecs[reverse_sampler]
        all_concat_vecs = all_concat_vecs[reverse_sampler]

    # 5. release the learner, data and loader before the next model is loaded
    del inf_learn, inf_data, dl
    gc.collect()
    torch.cuda.empty_cache()  # no-op when CUDA was never initialised

    return all_probs, all_vecs, all_concat_vecs

In [None]:
%time

probs_fwd, vecs_fwd, concat_vecs_fwd = get_classification_results(backwards=False)
probs_bwd, vecs_bwd, concat_vecs_bwd = get_classification_results(backwards=True)

probs_final = (probs_fwd + probs_bwd) / 2

print(probs_final.shape)
print(probs_fwd.shape, vecs_fwd.shape, concat_vecs_fwd.shape)
print(probs_bwd.shape, vecs_bwd.shape, concat_vecs_bwd.shape)

Combine `inf_df` with the ensembled per-label probabilities (one `prob_<label>` column per label) into a new `final_df`

In [None]:
# One `prob_<label>` column per theme label, in the model's output order
prob_labels = [f'prob_{lbl}' for lbl in STANDARD_THEME_CSS_LABELS]
probs_df = pd.DataFrame(data=probs_final.numpy(), columns=prob_labels)
probs_df.head()

In [None]:
# Model outputs are positional. inf_df was reset to a RangeIndex, and probs_df has the default one.
# pd.concat(axis=1) aligns on the index and silently NaN-pads any mismatch, so check alignment first.
assert inf_df.index.equals(probs_df.index), \
    f'index mismatch: {len(inf_df)} comment rows vs {len(probs_df)} prediction rows'
final_df = pd.concat([inf_df, probs_df], axis=1)

Add a binary `pred_<label>` column for each label, using the F1-optimal threshold (`threshold_f1`)

In [None]:
# 1 when the label's probability exceeds the F1-tuned threshold, else 0
for label in STANDARD_THEME_CSS_LABELS:
    is_positive = final_df['prob_' + label] > threshold_f1
    final_df['pred_' + label] = is_positive.astype(np.int64)

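Record the tuned thresholds, their validation accuracies, and the final validation loss on every row, so the exported file documents how its predictions were made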

In [None]:
# Broadcast the tuned thresholds and validation metrics onto every row of final_df
run_metadata = {
    'threshold_f05': threshold_f05,
    'threshold_f1': threshold_f1,
    'threshold_f2': threshold_f2,
    'val_acc_f05': val_acc_f05,
    'val_acc_f1': val_acc_f1,
    'val_acc_f2': val_acc_f2,
    'val_loss': final_valid_loss.item(),
}
for col_name, col_value in run_metadata.items():
    final_df[col_name] = col_value

In [None]:
# Preview the first rows of the scored frame
final_df.head(5)

In [None]:
# Persist the scored rows; yyyymmdd and m_suf come from the inference setup cell,
# so `datetime` does not need to be re-imported here
final_df.to_csv(STANDARD_THEME_CSS_PATH/f'{yyyymmdd}_ensemble_predictions{m_suf}.csv', index=False)

### Save document vectors

In [None]:
%time 

np.save(STANDARD_THEME_CSS_PATH/f'{yyyymmdd}_fwd_concat_docvecs_d400{m_suf}.npy', concat_vecs_fwd.numpy())  
np.save(STANDARD_THEME_CSS_PATH/f'{yyyymmdd}_fwd_docvecs_d400{m_suf}.npy', vecs_fwd.numpy())

np.save(STANDARD_THEME_CSS_PATH/f'{yyyymmdd}_bwd_concat_docvecs_d400{m_suf}.npy', concat_vecs_bwd.numpy())  
np.save(STANDARD_THEME_CSS_PATH/f'{yyyymmdd}_bwd_docvecs_d400{m_suf}.npy', vecs_bwd.numpy())

### Playground

In [None]:
# Every row carries the same thresholds, so inspecting the first one is enough
first_row = final_df.iloc[0]
first_row.threshold_f05, first_row.threshold_f1, first_row.threshold_f2

In [None]:
# Number of parameter groups the learner uses (e.g. for discriminative learning rates)
n_layer_groups = len(learn.layer_groups)
n_layer_groups

In [None]:
# Print each layer group; a plain loop avoids building and displaying a throwaway list of Nones
for layer_group in learn.layer_groups:
    print(f'{layer_group}\n')