In [8]:
%matplotlib inline
import os

In [9]:
# %load utils.py
import os,sys
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import glob as gb
import re
import ipdb
import pickle
import copy

np.seterr(divide='ignore')
split_re = re.compile(r'([()|])')
###################################################################################

reflexives = {
	'myself',
	'yourself',
	'himself',
	'herself',
	'itself',
	'yourselves',
	'themselves',
	'ourselves',
}

indexicals = {	# not present in OntoNotes
	'that',
	'this',
}

dep_rels_considered = [
	'SBJ',
	'OBJ',
	'PMOD',
]

third_person_prp = [
	'he',
	'him',
	'his',		# always NMOD, so never extracted
	'she',
	'her',
	'hers',
	'it',
	'its',		# always NMOD or COORD, so never extracted
	'they',
	'them',
	'their',	# always NMOD, so never extracted
]

pronounBydeprel = {
	'SBJ': ['he','him','she','her','hers','it','they'],
	'OBJ': ['him','her','it','them','they'],
	'PMOD': ['him','her','it','them']
}

pronoun_type_debug = [
	'sing_mas',
	'sing_fem',
	'sing_neu',
	'plural',
	'<NOT FOUND>',
]

pronounBytype = {
	'sing_mas': ['he','him','his'],
	'sing_fem': ['she','her','hers'],
	'sing_neu': ['it','its'],
	'plural': ['they','them','their']
}

#################################################################################################
class Entity:
	def __init__(self,_phrase='',_pos='-',global_pos=-np.inf):
		self.phrase = _phrase
		self.pos = _pos
		self.global_pos = global_pos
		self.entity_id = -1
		self.referent_id = -1 # global id in referents vocabulary
		self.dep_rel = ""

		
	def __str__(self):
		gp = self.global_pos
		if gp==-np.inf:
			gp=-1
		res = "[ phrase:%s |  pos: %s |  global_pos: %i |  entity_id: %i  | dep_rel: %s ]" % (self.phrase,self.pos,gp,self.entity_id,self.dep_rel)
		return res

	def __repr__(self):
		gp = self.global_pos
		if gp==-np.inf:
			gp=-1
		res = "[ phrase:%s |  pos: %s |  global_pos: %i |  entity_id: %i  | dep_rel: %s ]" % (self.phrase,self.pos,gp,self.entity_id,self.dep_rel)
		return res

"""
Read entities, considering maximally spanning NPs (ignoring nested NPs and nested entities).
Entity selection is done with constraints in NNP and PRP - see paper -.

"""
def read_conll2010_task1_max_np(filename,n_docs='all'):
	documents = []
	sentence = []
	count_docs = 0
	nesting_depth = np.inf
	global_line = 0

	for line in open(filename):
		line = line.strip('\n')
		if line.startswith('#begin'):
			doc = []
			entity_filter = np.zeros(1000,dtype=bool)	# {ent_id: True, if there is at least one NNP in the chain}
			chain_lens = np.zeros(1000)	# {ent_id: count of elements in the chain}

			nesting_depth = np.inf
			event = Entity()
			global_line = 0
			continue
			#all temp files are reset
		if line.startswith('#end'):
			documents.append(doc)
			count_docs+=1
			if n_docs!='all':
				if count_docs>=n_docs:
					break
			continue
		
		if line=='':
			nesting_depth = np.inf	# default val=INF
			continue

		## MAIN STUFF
		global_line += 1 		# increment in each valid line
		comp = line.split('\t')
		token = comp[1].lower()		# consider lowercase for all
		pos = comp[4]
		dep_rel = comp[10]
		coref_str = comp[-1]

		"""
		# only consecutive proper nouns
		if nesting_depth!=np.inf:
			if pos in ['NNP','NNPS'] and event.pos==pos:
				event.phrase += " "+token
			else:
				event.pos = '-'
		"""
		
		# maximally spanning NPs considered, labeled as NNP is there is an NNP in the phrase
		if nesting_depth!=np.inf:	# only enters for maximally entities
			event.phrase += " "+token
			if pos in ['NNP','NNPS']:
				event.pos = pos
			elif event.pos not in ['NNP','NNPS']:
				event.pos = '-'

			if dep_rel=='SBJ' and event.dep_rel!='SBJ':
				event.dep_rel = dep_rel
			elif dep_rel=='OBJ' and event.dep_rel!='SBJ':
				event.dep_rel = dep_rel
			elif dep_rel=='PMOD' and event.dep_rel not in ['SBJ','OBJ','NAME']:
				event.dep_rel = dep_rel

		if coref_str=='_':
			continue

		temp = split_re.split(coref_str)
		splitted = [a for a in temp if a!='']
		k=0

		while(k<len(splitted)):		# | no hace nada
			if splitted[k]=='(':
				if nesting_depth==np.inf:
					nesting_depth=0
					event = Entity(token,pos,global_line)	# initialization of active entity
					event.dep_rel = dep_rel
				nesting_depth+=1
			elif splitted[k].isdigit():
				id = int(splitted[k])
				if nesting_depth==1:	# depth of 1 corresponds to maximal entity
					event.entity_id = id
			elif splitted[k]==')':
				nesting_depth-=1
				if nesting_depth==0: 	# if ) is closing maximal
					# cleaning phrase
					phrase = event.phrase.strip(' ')
					if len(phrase)>1:
						if phrase[-2:]==' .' or phrase[-2:]==' ,':	# if ends in point or comma
							phrase = phrase[:-2]
						phrase = phrase.strip(' ')	# get rid of spaces on sides
					event.phrase = phrase

					if any([ event.pos in ['NNP','NNPS'], 
							 all([	event.pos in ['PRP','PRP$'],		# all constraints in selection of pronouns
							 		token in third_person_prp,		# third person pronouns
							 		token not in reflexives,		# no reflexives
							 		token not in indexicals,		# no indexicals
							 		event.dep_rel in dep_rels_considered,	# pronouns in SUBJ, OBJ or PMOD position
							 	])
						]):
						# filtering preprocessing
						if event.pos in ['NNP','NNPS']:				# found an NNP in the chain
							entity_filter[event.entity_id] = True
							if event.dep_rel not in dep_rels_considered:
								event.dep_rel = 'SBJ'
						doc.append(event)

					nesting_depth=np.inf
			k+=1
	#END-FOR-FILE

	documents = np.array(documents)
	
	return documents

"""
Filter entities according to the following rules:
 - at least one NNP in chain
 - if one NNP in chain is present in the lexicon, consider all NNPs of chain
"""
def filter_entities(documents,lexicon):
	filtered_docs = []
	for doc in documents:
		# all elements in chains with at least one valid NNP will be considered
		valid_chains = set()
		for entity in doc:
			if any([ entity.pos=='NNPS',								#plural goes in anyway
					 entity.pos=='NNP' and entity.phrase in lexicon,	# only if present in lexicon vocabulary
					]):
				valid_chains.add(entity.entity_id)
		# filter documents
		new_doc = np.array( [entity for entity in doc if entity.entity_id in valid_chains ] )
		filtered_docs.append(new_doc)

	filtered_docs = np.array(filtered_docs)
	return filtered_docs	


#################################################################################################
"""
Read NP lexicon with genre and number counts. Only considered the cases: (\w )+, (! \w), (\w !)
return:
	counts: (V,4) : [masculine_count  femenine_count  neutral_count  plural_count]
	nouns_vocab (V,) : vocabulary of NP
"""
def read_noun_lexicon(_dir):
	nouns_vocab = []
	counts = []
	for filename in gb.glob(_dir+'*'):
		for line in open(filename,encoding='latin-1'):
			line = line.strip("\n")
			last_line=line
			if line=='':
				continue
			splitted = line.split("\t")
			phrase = splitted[0]

			mas,fem,neu,plu = [int(a) for a in splitted[1].split(" ")]
			sp_np = phrase.split(' ')

			save = False
			if all([word.isalpha() for word in sp_np]):
				save = True
			"""
			# save <! word> or <word !>
			if len(sp_np)>1:
				if any([
					 len(sp_np)>1 and (sp_np[0]=='!' and sp_np[1].isalpha()),
					 len(sp_np)>1 and (sp_np[1]=='!' and sp_np[0].isalpha()),
				]):
					save = True
			"""
			if save:
				counts.append([mas,fem,neu,plu])
				nouns_vocab.append(phrase)
	counts = np.array(counts)
	nouns_vocab = np.array(nouns_vocab)
	return counts,nouns_vocab

"""
Get intersection between corpus vocabulary and genre_number lexicon vocabulary
"""
def get_vocab_from_lexicon(documents,lexicon_vocab):
	vocab = set()
	for doc in documents:
		for ent in doc:
			if ent.phrase in lexicon_vocab:
				vocab.add(ent.phrase)
	return vocab



"""
Annotate each phrase in vocab with genre and number (pronoun type id)
"""
def annotate_pro_type(lex_counts,lex_vocab,corpus_vocab):
	pro_type_ids = []
	for phrase in corpus_vocab:
		pro_id=-1
		if phrase in lex_vocab:
			p_idx = np.nonzero(lex_vocab==phrase)[0][0]
			pro_id = lex_counts[p_idx,:].argmax()
		elif phrase in third_person_prp:
			for _type,pros in pronounBytype.items():
				if phrase in pros:
					pro_id = pronoun_type_debug.index(_type)
					break
		else:
			pro_id = pronoun_type_debug.index('plural')		# if not in lexicon, PLURAL, since all NNPs are considered
		pro_type_ids.append(pro_id)
	pro_type_ids = np.array(pro_type_ids)
	return pro_type_ids

def saveObject(obj, name='model'):
	with open(name + '.pickle', 'wb') as fd:
		pickle.dump(obj, fd, protocol=pickle.HIGHEST_PROTOCOL)

def uploadObject(obj_name):
	# Load tagger
	with open(obj_name + '.pickle', 'rb') as fd:
		obj = pickle.load(fd)
	return obj

#################################################################################################
"""
Get orreference chains from a single document
@param doc: list of Entity objects
"""
def get_coref_chains(doc):
	chains = {}
	for entity in doc:
		ent_id = entity.entity_id
		if ent_id not in chains:
			chains[ent_id] = []
		pair = tuple([entity.phrase,entity.pos])
		if pair not in chains[ent_id]:
			chains[ent_id].append(pair)
	
	return chains

def get_referents(doc,lexicon):
	referents = {}	# 1000 assumed to be max ent_id
	chains = get_coref_chains(doc)

	for chain_id,chain in chains.items():
		for ref_exp in chain:
			text = ref_exp[0]
			pos = ref_exp[1]
			if pos=='NNP' and text in lexicon:
				referents[chain_id] = text
			elif pos=='NNPS':
				referents[chain_id] = text
				break
	return referents

'''
Reformat corpus and get vocabulary of refering expressions (pronouns and referents).
return: new_documents: [[entity_obj with phrase:ref_exp_id, and entity_id]]
		ref_exp_vocab: [all referents + all pronouns (third person PRP)]  -> assumption: a proper name only refers to itself
'''
def reformat_data(documents,lexicon):
	ref_exp_vocab=[]
	# make referents vocabulary and coreference chains
	for doc in documents:
		referents = get_referents(doc,lexicon)
		for entity in doc:
			referent = referents[entity.entity_id]
			ref_id=-1
			if referent in ref_exp_vocab:
				ref_id = ref_exp_vocab.index(referent)
			else:
				ref_id = len(ref_exp_vocab)
				ref_exp_vocab.append(referent)
			entity.referent_id = ref_id
			if entity.pos[0]=='N':			# replace original noun with referent (assumption: proper names only refer to themselves)
				entity.phrase = referent
	# add pronouns
	for pro in third_person_prp:
		if pro not in ref_exp_vocab:
			ref_exp_vocab.append(pro)

	# reformat documents
	new_documents = []
	for doc in documents:
		new_doc = copy.deepcopy(doc)
		for event in new_doc:
			event.phrase = ref_exp_vocab.index(event.phrase)
		new_documents.append(new_doc)
	new_documents = np.array(new_documents)
	ref_exp_vocab = np.array(ref_exp_vocab)

	return new_documents,ref_exp_vocab


def debug_dep_rel_pronoun(data):
	debug = {
	'SBJ': set(),
	'OBJ': set(),
	'PMOD': set()
	}
	for doc in data:
		for ent in doc:
			if ent.pos[0]=='P':
				debug[ent.dep_rel].add(ent.phrase)
	return debug


In [10]:
# %load class_speaker_model.py
import numpy as np
import scipy
from utils import  third_person_prp, pronoun_type_debug, pronounBydeprel, pronounBytype
from sklearn.metrics import classification_report, accuracy_score
import ipdb

class SpeakerModel:
	def __init__(self,_data,_vocab,_pronoun_type_ids,_lex_counts,_alpha=0.1,_decay=3, _salience='recency'):
		# corpus  Ndocs x entities
		self.data = _data
		# vocab: list of phrases, used to map phr_id in corpus
		self.vocab = _vocab # np array
		self.vocab_size = len(_vocab)
		# number of referents
		self._n_referents = self.vocab_size - len(third_person_prp)
		# pronoun type (e.g. singular_male) ids in same order as vocabulary
		self.pronoun_type_ids = _pronoun_type_ids
		# number of pronoun types
		self._n_pro_type = len(pronoun_type_debug) - 1 # don't count NOT FOUND
		# counts of unseen referents by pronoun types, got from lexicon
		self._u_pro_type = np.zeros(self._n_pro_type)
		# total number of unseen referents
		self._u_total = 0

		# parameter for prob of new referent
		self._alpha = _alpha
		# decay for discourse salience
		self._decay = _decay
		# discourse salience measure, p(r)| vals: ['freq','rec']
		self._salience = _salience

		# number of pronouns and proper names to predict
		self._n_samples = sum([len(doc) for doc in self.data])
		# V: number of ref. expressions that can refer to r, constant across referents
		self._V = self._n_samples / self._n_referents
		# word likelihood given referent p(w|r)
		self._p_w_r = 1.0 / self._V

		#class labels
		self.class_labels = ['PRP','NNP']
		# true class
		self._Y 	 = []
		# predicted class
		self._Y_pred = []

		#calc_unseen_counts lex_counts
		self.calc_unseen_counts(_lex_counts)

		# save metrics
		self._model_loglikelihood = 0.0
		self._total_acc = 0.0
		self._np_acc = 0.0		# proper name accuracy
		self._pro_acc = 0.0		# pronoun accuracy


	def calc_unseen_counts(self,_lex_counts):
		self._u_pro_type = _lex_counts[:self._n_pro_type]
		self._u_total = _lex_counts[-1]

	def speech_cost(self,word_id):
		#return len( self.vocab[word_id] )
		return np.log( len(self.vocab[word_id]) )
		#return np.log( len(self.vocab[word_id]) )+1.0

	'''
	p(r) : discourse salience of referent r up until now
	'''
	def get_salience(self,referent_id,ref_prev_mentions,last_mention_dist):
		p_r = 0
		if ref_prev_mentions==0:			# new referent
			pro_type = self.pronoun_type_ids[referent_id]
			p_r = self._alpha * self._u_pro_type[pro_type] / self._u_total
		else:
			if self._salience=='frequency':
				p_r = ref_prev_mentions						# frequency measure
			else:											# 'rec'
				p_r = np.exp(-last_mention_dist/self._decay)	# recency measure
		return p_r

	'''
	Sum over all potencial referents compatible with w. Sum_{r'} {p(w|r')*p(r')}
	'''
	def get_sum_potencial_referents(self,pos,referent_id,counts_state,mention_state,global_pos):
		# proper noun only refers to itself
		if pos[0]=='N':
			last_mention_dist = global_pos - mention_state[referent_id]
			prev_mentions = counts_state[referent_id]
			return self.get_salience(referent_id,prev_mentions,last_mention_dist)
		# pronoun spotted
		else:
			pro_type = self.pronoun_type_ids[referent_id]
			# get potencial referents ids in vocabulary
			cond = self.pronoun_type_ids==pro_type
			potencial_refs_ids = []
			for i,val in enumerate(cond):
				if val and self.vocab[i] not in third_person_prp:
					potencial_refs_ids.append(i)

			# sum over all potencial referents
			sum_p_r = 0
			for ref_id in potencial_refs_ids:
				last_mention_dist = global_pos - mention_state[ref_id]
				prev_mentions = counts_state[ref_id]
				if prev_mentions!=0:												# sum only active referents
					sum_p_r +=  self.get_salience(ref_id,prev_mentions,last_mention_dist)

			if sum_p_r==0:		# if PRP and there is no referent mentioned before
				return 0.0		# return 
			sum_p_r += self.get_salience(referent_id,0,0)		# add unseen entity prob

			return sum_p_r

	"""
	Predict between pronoun or proper name for each referent as discourse advances.
	"""
	def predict(self):
		self._model_loglikelihood = 0.0		# calculate model log likelihood as it predicts
		for doc in self.data:
			referent_counts = np.zeros(1000)
			last_mention = np.zeros(1000)
			for entity in doc:
				dep_rel = entity.dep_rel
				ref_id = entity.referent_id
				pro_type_name = pronoun_type_debug[ self.pronoun_type_ids[ref_id] ]

				true_pos = entity.pos[:3]
				# true label
				self._Y.append(self.class_labels.index(true_pos))
				# predicting refering expression (POS)
				pred_pos = self.class_labels.index('NNP')	# pick proper name by default

				#print("ref_exp:",self.vocab[entity.phrase])
				#print("referent:",self.vocab[ref_id])

				# discourse salience for referent r
				last_mention_distance = entity.global_pos - last_mention[ref_id]
				p_r = self.get_salience(ref_id,referent_counts[ref_id],last_mention_distance)

				# PROPER NAME CASE
				sum_p_r_np = self.get_sum_potencial_referents('NNP', ref_id, referent_counts, last_mention, entity.global_pos)
				cw_np = self.speech_cost(ref_id)
				log_speaker_np =  -np.log(sum_p_r_np) - np.log(cw_np)

				if abs(log_speaker_np)!=np.inf:
					self._model_loglikelihood += log_speaker_np

				#ipdb.set_trace()

				# PRONOUN CASE
				sum_p_r_pro = self.get_sum_potencial_referents('PRP', ref_id, referent_counts, last_mention, entity.global_pos)
				# speaker's cost: cross agreement and grammatical position
				log_speaker_pro = -np.inf
				for pro in pronounBydeprel[dep_rel]:
					if pro in pronounBytype[pro_type_name]:
						# valid pronouns
						pro_id = np.nonzero(self.vocab==pro)[0][0]
						cw_pro = self.speech_cost(pro_id)

						local_log_speaker_pro = 0.0
						if sum_p_r_pro==0:		# no referent active
							local_log_speaker_pro = -np.inf
						else:
							local_log_speaker_pro = -np.log(sum_p_r_pro) - np.log(cw_pro)
						log_speaker_pro = max(log_speaker_pro,local_log_speaker_pro)

						# add model log likelihood only if Ps != 0
						if local_log_speaker_pro!=-np.inf:
							self._model_loglikelihood += local_log_speaker_pro

				if log_speaker_pro > log_speaker_np:
					pred_pos = self.class_labels.index('PRP')

				# predicted label
				self._Y_pred.append(pred_pos)

				#update counts
				referent_counts[ref_id] += 1
				last_mention[ref_id] = entity.global_pos

				#print("Predicted label:",self.class_labels[pred_pos])
				#ipdb.set_trace()
			#END-FOR-DOC
		#END-FOR-DOCUMENTS
		self._Y = np.array(self._Y)
		self._Y_pred = np.array(self._Y_pred)


	def evaluate(self,verbose=True):
		class_accs = []
		for i in range(len(self.class_labels)):
			label = self.class_labels[i]
			acc_class = sum((self._Y==i) * (self._Y_pred==i)) / sum(self._Y==i)
			if label[0]=='P':
				self._pro_acc = acc_class
			else:
				self._np_acc  = acc_class
			class_accs.append(acc_class)
			if verbose:
				print("Accuracy %s: %.2f" % (label,acc_class*100))
		self._total_acc = accuracy_score(self._Y,self._Y_pred)
		if verbose:
			print("Total accuracy: %.2f" % (self._total_acc*100) )
			print("Model log likelihood: %.2f" % self._model_loglikelihood)
		if 0 in class_accs:
			return 0
		return self._total_acc



In [11]:
# %load alternative_models.py
from class_speaker_model import SpeakerModel
import numpy as np

class SM_NoDiscourse(SpeakerModel):
	'''
	p(r) : Uniform discourse salience
	'''
	def get_salience(self,referent_id,ref_prev_mentions,last_mention_dist):
		return 1.0 / self._n_referents

class SM_NoCost(SpeakerModel):
	def speech_cost(self,word_id):
		return 1

class SM_NoUnseen(SpeakerModel):
	def calc_unseen_counts(self,_lex_counts):
		self._u_pro_type = np.ones(self._n_pro_type)
		self._u_total = 1		#cte

In [12]:
data_conll_dir = "../datasets/semeval_2010_t1_eng/data"
data_noun_lexicon_dir = "../datasets/noun_gender_number/"


### Data and resources reading configuration

In [13]:
n_docs='all'
update_pron_types = False		# True: load agreement annotation for referents
update_lexicon = False			# True: load vocabulary intersection between genre_number lexicon and corpus

GN_counts = []
GN_np_vocab = []

### Reading training data

In [14]:
data_file = os.path.join(data_conll_dir,'en.train.txt')
docs = read_conll2010_task1_max_np(data_file,n_docs=n_docs)

### Read proper names present in genre_number lexicon

In [15]:
gn_read = False
if update_lexicon:
    print("Reading genre_number lexicon...")
    GN_counts, GN_np_vocab = read_noun_lexicon(data_noun_lexicon_dir)	# 1.9 GB of memory OMG
    gn_read = True
    lexicon = get_vocab_from_lexicon(docs,GN_np_vocab)
    saveObject(lexicon,'names_in_gnlexicon')
else:
    lexicon = uploadObject('names_in_gnlexicon')

### Data filtering and formating

In [16]:
docs = filter_entities(docs,lexicon)
formated_data,vocab = reformat_data(docs,lexicon)

### Number of pronouns and proper names in corpus, after filtering

In [17]:
pros = 0
nps = 0
for doc in docs:
    nps  += sum([1 for ent in doc if ent.pos[0]=='N'])
    pros += sum([1 for ent in doc if ent.pos[0]=='P'])

print("Pronouns: ",pros)
print("Proper nouns: ",nps)
print("Ref_vocab: ",len(vocab))

Pronouns:  251
Proper nouns:  1044
Ref_vocab:  458


### Annotate agreement information for referents

In [18]:
pro_type_ids = []
if update_pron_types:
    if not gn_read:
        print("Reading genre_number lexicon...")
        GN_counts, GN_np_vocab = read_noun_lexicon(data_noun_lexicon_dir)	# 1.9 GB of memory OMG
    pro_type_ids = annotate_pro_type(GN_counts,GN_np_vocab,vocab)

    n_pro_types = len(pronoun_type_debug)-1
    pro_type_counts = [sum(GN_counts[:,i]!=0) for i in range(n_pro_types)]
    pro_type_counts.append( GN_counts.shape[0] )

    saveObject(pro_type_ids,'pro_type_ids')
    saveObject(pro_type_counts,'pro_type_counts')
else:
    pro_type_ids = uploadObject('pro_type_ids')
    pro_type_counts = uploadObject('pro_type_counts')

In [19]:
# Check agreement information for first 20 phrases in vocabulary
for ii in range(len(pronoun_type_debug)-1):
    for i in range(20):
        if pro_type_ids[i]==ii:
            print("%80s : %15s" % (vocab[i],pronoun_type_debug[pro_type_ids[i]] ) )
    print("-"*100)

                                                                             god :        sing_mas
                                                                     mike jensen :        sing_mas
                                                                         richard :        sing_mas
----------------------------------------------------------------------------------------------------
                                                                           betsy :        sing_fem
----------------------------------------------------------------------------------------------------
                                                                         alabama :        sing_neu
                                                                   sunday school :        sing_neu
                                                                american medical :        sing_neu
                                                                 social security :        sing_neu
      

### Tune parameters for speaker model

In [20]:
#Finding optimum values of alpha and decay
alphas = [1e-4, 1e-3, 1e-2, 0.1,1,10]
decays = [1e-2,0.1, 1, 3, 5, 10,100,1e3]
salience_measures = ['recency','frequency']

a_opt, d_opt = 0,0
acc = 0

for salience in salience_measures:
    for alpha in alphas:
        for decay in decays:
            spk = SpeakerModel(	_data=formated_data,
                                _vocab=vocab,
                                _pronoun_type_ids=pro_type_ids,
                                _lex_counts=pro_type_counts,
                                _alpha=alpha,
                                _decay=decay,
                                _salience=salience)
            spk.predict()
            total_acc = spk.evaluate(verbose=False)
            if total_acc>acc:
                acc = total_acc
                a_opt=alpha
                d_opt=decay
        #END-FOR-DECAY
    #END-FOR-ALPHA
    print("%s: Optimum parameters:------" % salience.upper())
    print("   alpha:",a_opt)
    print("   decay:",d_opt)
    print("Training with optimum parameters...")
    spk = SpeakerModel(	_data=formated_data,
                        _vocab=vocab,
                        _pronoun_type_ids=pro_type_ids,
                        _lex_counts=pro_type_counts,
                        _alpha=a_opt,
                        _decay=d_opt,
                        _salience=salience)
    spk.predict()
    spk.evaluate()
    print("-------------------------------------------------------")

RECENCY: Optimum parameters:------
   alpha: 0.0001
   decay: 100
Training with optimum parameters...
Accuracy PRP: 86.85
Accuracy NNP: 73.18
Total accuracy: 75.83
Model log likelihood: 7303.38
-------------------------------------------------------
FREQUENCY: Optimum parameters:------
   alpha: 0.0001
   decay: 100
Training with optimum parameters...
Accuracy PRP: 79.28
Accuracy NNP: 71.93
Total accuracy: 73.36
Model log likelihood: 3883.20
-------------------------------------------------------


### Run speaker models

In [21]:
#alpha = 1e-4
#decay = 100
alpha = a_opt
decay = d_opt
salience_measures = ['recency','frequency']
print("Running models...")
print("  alpha: ",alpha)
print("  decay: ",decay)
print("-"*150)

print("%12s | %15s || %5s | %11s | %14s | %9s |" % ("Model","Discourse","T_ACC","Pronoun_ACC","ProperName_ACC","Log-lhood") )
print("="*100)
# Complete model
for salience in salience_measures:
    spk = SpeakerModel( _data=formated_data,
                        _vocab=vocab,
                        _pronoun_type_ids=pro_type_ids,
                        _lex_counts=pro_type_counts,
                        _alpha=alpha,
                        _decay=decay,
                        _salience=salience)
    spk.predict()
    spk.evaluate(verbose=False)
    print("%12s | %15s || %2.2f | %11.2f | %14.2f | %9.2f |" % 
            ("complete", salience, spk._total_acc*100, spk._pro_acc*100, spk._np_acc*100, spk._model_loglikelihood) )
print("-"*100)

# No discourse model
spk = SM_NoDiscourse(	_data=formated_data,
                        _vocab=vocab,
                        _pronoun_type_ids=pro_type_ids,
                        _lex_counts=pro_type_counts,
                        _alpha=alpha,
                        _decay=decay,
                        _salience=salience)
spk.predict()
spk.evaluate(verbose=False)
print("%12s | %15s || %2.2f | %11.2f | %14.2f | %9.2f |" % 
        ("-discourse", "NA", spk._total_acc*100, spk._pro_acc*100, spk._np_acc*100, spk._model_loglikelihood) )
print("-"*100)

# No cost model
for salience in salience_measures:
    spk = SM_NoCost( _data=formated_data,
                    _vocab=vocab,
                    _pronoun_type_ids=pro_type_ids,
                    _lex_counts=pro_type_counts,
                    _alpha=alpha,
                    _decay=decay,
                    _salience=salience)
    spk.predict()
    spk.evaluate(verbose=False)
    print("%12s | %15s || %2.2f | %11.2f | %14.2f | %9.2f |" % 
            ("-cost", salience, spk._total_acc*100, spk._pro_acc*100, spk._np_acc*100, spk._model_loglikelihood) )
print("-"*100)

# No estimates of unseen referents
for salience in salience_measures:
    spk = SM_NoUnseen(  _data=formated_data,
                        _vocab=vocab,
                        _pronoun_type_ids=pro_type_ids,
                        _lex_counts=pro_type_counts,
                        _alpha=alpha,
                        _decay=decay,
                        _salience=salience)
    spk.predict()
    spk.evaluate(verbose=False)
    print("%12s | %15s || %2.2f | %11.2f | %14.2f | %9.2f |" % 
            ("-unseen", salience, spk._total_acc*100, spk._pro_acc*100, spk._np_acc*100, spk._model_loglikelihood) )
print("-"*100)

Running models...
  alpha:  0.0001
  decay:  100
------------------------------------------------------------------------------------------------------------------------------------------------------
       Model |       Discourse || T_ACC | Pronoun_ACC | ProperName_ACC | Log-lhood |
    complete |         recency || 75.83 |       86.85 |          73.18 |   7303.38 |
    complete |       frequency || 73.36 |       79.28 |          71.93 |   3883.20 |
----------------------------------------------------------------------------------------------------
  -discourse |              NA || 66.49 |       39.04 |          73.08 |  14199.09 |
----------------------------------------------------------------------------------------------------
       -cost |         recency || 80.62 |        0.00 |         100.00 |   8091.49 |
       -cost |       frequency || 80.62 |        0.00 |         100.00 |   4671.30 |
----------------------------------------------------------------------------------------