# Background

In [4]:
# notes and figures from various sources
import pyplot.pyplot as py
import numpy as np
import json, yaml


ImportError: No module named pyplot.pyplot

Define a temperature and an alpha parameter, and fix the words in the lexicon and the prior probabilities.

---

RSA Fundamentals:

Elements: literal listener $L_0$, pragmatic speaker $S_1$, pragmatic listener $L_2$, cost function on dialogue (fixed utterances?) $\kappa$, inverse temperature parameter $\alpha$, prior probabilities of referents $P$, prior probabilities (truth functions?) of utterances referring to referents $L$, current referent $r$, current utterance $u$.

---

Literal Listener:

$l_0(r\,|\,u, L)\,\propto\,L(u\,|\,r)\,P(r)$



Pragmatic Speaker:


$s_1(u\,|\,r, L)\,\propto\,e^{\,\alpha\log{l_0(r\,|\,u, L)}\,-\,\kappa(u)}$


Pragmatic Listener:

$l_2(r\,|\,u, L)\,\propto\,s_1(u\,|\,r,\,L)\,P(r)$


---

This back-and-forth nature addresses Grice's theory of conversational implicature but poses an obvious paradoxical problem: it is always optimal to soft-maximise against the previous agent, i.e., to take strategy $S_{n-1}$ at whatever iteration you are at and use Bayes' rule to invert its decision procedure. I think most of the work in this setting so far sets recursion depth = 3.

The final equilibrium state is currently one in which only one agent is optimal w.r.t. the game; the other is *nearly* optimal, i.e., off by one iteration.

---

(read Goodman NIPS 13, Potts NAACL 18)

---


In [None]:
# Disabled draft: the three helpers below are wrapped in a string literal,
# so none of them is actually defined when this cell runs.
# They call `logsumexp`, which nothing visible in this notebook imports
# (presumably scipy's logsumexp was intended -- TODO confirm before enabling).
'''def get_s1(l0, alpha):
    l0_a = l0 * alpha; s1 = l0_a - logsumexp(l0_a, axis=2)
    return s1

def get_l2(s1):
    l2 = s1 - logsumexp(s1, axis=3)
    return l2

def get_lstar(l0, l2, bw):
    unnorm = (bw * l0 + (1 - bw) * l2)[:, :, 0, :]; lstar = unnorm - logsumexp(unnorm, axis=2)
    return lstar'''


# Old RSA

In [None]:
#shared lexicon, experiment with concatenated i.e., shared, individual. incorporate perturbations.

Assumptions made:

1. All agents have knowledge of a shared lexicon that they update priors over.
2. They have full knowledge of the cost function.
3. Main assumption is that there exists a single, shared lexicon that everyone should be using and I, as an agent, know that the other agent knows it and this is what I want to learn. 

# Neural RSA

In [None]:
#incorporate shared, individual. incorporate perturbations (?)

LSTM encoders output Gaussian distributions over referents, then softmax and classify for most probable referent.

# Testing implementation

In [None]:
import numpy as np
from operator import itemgetter
import os, sys, re, random
from collections import defaultdict
from itertools import combinations

def row_norm(m):
    """Scale every row of the 2-D input ``m`` so that it sums to 1.

    Works for both ``np.ndarray`` and ``np.matrix`` input (the result keeps
    the input's array type) and always performs true division, so integer
    input is not floor-divided.

    A row that sums to zero yields NaN/inf entries, as numpy division does.
    """
    # Reshape the totals into a column so they broadcast across each row.
    # (np.matrix sums stay 2-D, which broke the former transpose-based
    # formulation; an explicit column vector is correct for both types.)
    totals = np.sum(m, axis=1).reshape(-1, 1)
    return np.true_divide(m, totals)

def col_norm(m):
    """Scale every column of ``m`` so that it sums to 1.

    Always performs true division, so integer input is not floor-divided
    (which ``np.divide`` does for integer arrays under Python 2).

    A column that sums to zero yields NaN/inf entries, as numpy division does.
    """
    totals = np.sum(m, axis=0)
    return np.true_divide(m, totals)

def safe_log(x):
    """Natural logarithm of ``x`` with numpy's divide warning silenced.

    ``log(0)`` therefore comes back as ``-inf`` without a RuntimeWarning.
    """
    quiet_division = np.errstate(divide='ignore')
    with quiet_division:
        logged = np.log(x)
    return logged
    
def inner_product(x, y):
    """Return ``np.dot(x, y)`` (thin, readable alias for the dot product)."""
    product = np.dot(x, y)
    return product

def powerset(x, minsize=0, maxsize=None):
    """Return all sub-collections of ``x`` whose size is in [minsize, maxsize].

    Each subset is returned as a list; subsets are ordered by size and,
    within a size, in the order ``itertools.combinations`` produces them.

    Parameters
    ----------
    x : sequence
        Items to combine.
    minsize : int
        Smallest subset size to include (0 includes the empty subset).
    maxsize : int or None
        Largest subset size to include; ``None`` means ``len(x)``.
    """
    if maxsize is None:
        maxsize = len(x)
    return [list(subset)
            for size in range(minsize, maxsize + 1)
            for subset in combinations(x, size)]

def mean_sq_error(x, y):
    """Mean of the squared differences between ``x`` and ``y``."""
    residual = x - y
    squared = residual ** 2
    return np.mean(squared)

def display_matrix(m, rnames=None, cnames=None, title='', digits=4):
    """Print a 2-D matrix to stdout as a right-aligned table.

    The first printed line is a header holding ``title`` (in the row-label
    column) followed by the column names; each further line is a row label
    followed by that row's values rounded to ``digits`` decimals. A blank
    line is printed after the table.

    Parameters
    ----------
    m : 2-D array-like
        Values to show (``np.matrix`` is accepted and treated as an array).
    rnames, cnames : sequence or None
        Row / column labels. When ``None``, the positional indices are used.
    title : object
        Label printed in the top-left corner (converted with ``str``).
    digits : int
        Number of decimals kept when rounding.
    """
    # asarray so that iterating an np.matrix row yields scalars, not 1x1 matrices.
    m = np.round(np.asarray(m), digits)
    nrows, ncols = m.shape
    # The defaults are None; fall back to indices instead of crashing on len(None).
    rnames = [str(i) for i in range(nrows)] if rnames is None else [str(x) for x in rnames]
    cnames = [str(j) for j in range(ncols)] if cnames is None else [str(x) for x in cnames]
    title = str(title)
    rwidth = 2 + max([len(x) for x in rnames] + [len(title), digits + 2])
    cwidth = 2 + max([len(x) for x in cnames] + [digits + 2])
    lines = [title.rjust(rwidth) + ''.join(c.rjust(cwidth) for c in cnames)]
    for i, row in enumerate(m):
        cells = ''.join(str(value).rjust(cwidth) for value in row)
        lines.append(rnames[i].rjust(rwidth) + cells)
    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    print('\n'.join(lines) + '\n')
    
# Smoke check of the two normalisers on a small 2x2 example.
# A plain ndarray is used rather than np.matrix: np.matrix keeps axis sums
# 2-D, which changes how they broadcast in the normalisers.
m = np.array([[1.0, 2.0], [3.0, 4.0]])
m = row_norm(m)
print(m)
m = col_norm(m)
print(m)

#### Define encapsulating class

class Module(object):
    """RSA model over a fixed set of states (referents) and messages.

    Listener matrices are laid out messages x states and speaker matrices
    states x messages; this is the labelling the display methods apply.

    Parameters
    ----------
    lexica : sequence of 2-D arrays, optional
        Candidate lexica; iterated by ``lex_likelihood`` and the joint
        display methods.
    baselexicon : 2-D array, optional
        Lexicon used by ``rsa`` when none is passed.
    states : list
        State names. Required: its length sizes the default prior.
    costs : 1-D array, optional
        Per-message costs. Defaults to zeros, with ``nullcost`` in the last
        position when ``nullmsg`` is true.
    messages : list
        Message names. Required.
    prior : 1-D array, optional
        Prior over states; uniform when omitted.
    lexprior : indexable, optional
        Weight of each lexicon, indexed by its position in ``lexica``.
        When omitted it is uniform over ``lexcount`` entries if ``lexcount``
        is given, otherwise a defaultdict returning 1.0 for every index.
    lexcount : int or sized collection, optional
        Number of lexica (or a collection whose length is that number).
    temperature, alpha : float
        Scaling factors applied in ``S``.
    beta : float
        Stored on the instance; not read by any method of this class.
    nullmsg : bool
        Whether the last message is a null message with cost ``nullcost``
        (only consulted when ``costs`` is omitted).
    nullcost : float
        Cost assigned to the null message.
    """

    def __init__(self,
                 lexica=None,
                 baselexicon=None,
                 states=None,
                 costs=None,
                 messages=None,
                 prior=None,
                 lexprior=None,
                 lexcount=None,
                 temperature=1.0,
                 alpha=1.0,
                 beta=1.0,
                 nullmsg=True,
                 nullcost=5.0):
        self.lexica = lexica
        self.baselexicon = baselexicon
        self.states = states
        self.costs = costs
        self.messages = messages
        self.prior = prior
        self.lexprior = lexprior
        self.lexcount = lexcount
        self.temperature = temperature
        self.alpha = alpha
        self.beta = beta
        self.nullmsg = nullmsg
        self.nullcost = nullcost

        # Fill in defaults. ``is None`` is used instead of ``== None``
        # because ``==`` on a numpy array compares element-wise.
        if self.prior is None:
            self.prior = np.repeat(1.0 / len(self.states), len(self.states))
        # Only build a lexicon prior when the caller did not supply one, and
        # store it in ``lexprior`` (never in the state prior).
        if self.lexprior is None:
            if self.lexcount is not None:
                if hasattr(self.lexcount, '__len__'):
                    count = len(self.lexcount)
                else:
                    count = int(self.lexcount)
                self.lexprior = np.repeat(1.0 / count, count)
            else:
                self.lexprior = defaultdict(lambda: 1.0)
        if self.costs is None:
            self.costs = np.zeros(len(self.messages))
            if self.nullmsg:
                self.costs[-1] = self.nullcost
        # Results of the most recent ``run``.
        self.final_listener = np.zeros((len(self.messages), len(self.states)))
        self.final_speaker = None


####  Interaction iterative functions

    def rsa(self, lex=None):
        """Run one literal-listener -> speaker -> listener pass.

        Uses ``self.baselexicon`` when ``lex`` is None and returns the three
        matrices as ``[literal, speaker, listener]``.
        """
        if lex is None:
            lex = self.baselexicon
        literal = self.l0(lex)
        speaker = self.S(literal)
        listener = self.L(speaker)
        return [literal, speaker, listener]

    def run_base_model(self, lex, n=2, display=True, digits=4):
        """Call ``run`` starting from the literal listener for ``lex``."""
        return self.run(
            n=n,
            display=display,
            digits=digits,
            initial_listener=self.l0(lex),
            start_level=0)

    def run(self,
            initial_listener,
            n=2,
            display=True,
            digits=4,
            start_level=0,
            ):
        """Alternate speaker and listener updates from ``initial_listener``.

        ``n - 1`` speaker/listener pairs are appended after the initial
        listener, so the returned list has ``2 * (n - 1) + 1`` matrices
        (just the initial listener when ``n <= 1``). The last speaker and
        listener are stored in ``final_speaker`` / ``final_listener``;
        ``final_speaker`` is None when no speaker was computed.

        When ``display`` is true the whole sequence is printed with
        ``display_iteration``.
        """
        langs = [initial_listener]
        for i in range(1, (n - 1) * 2, 2):
            langs.append(self.S(langs[i - 1]))
            langs.append(self.L(langs[i]))

        if len(langs) < 2:
            self.final_speaker = None
            self.final_listener = langs[-1]
        else:
            self.final_speaker, self.final_listener = langs[-2:]

        if display:
            self.display_iteration(langs, start_level=start_level, digits=digits)
        return langs


#### Agents

    def l0(self, lex):
        """Literal listener: weight ``lex`` by the state prior, then row-normalise."""
        return row_norm(lex * self.prior)

    def L(self, speaker):
        """Listener responding to ``speaker``: ``l0`` applied to its transpose."""
        return self.l0(speaker.T)

    def S(self, listener):
        """Speaker responding to ``listener``.

        Computes ``exp(temperature * (alpha * log(listener.T) - costs))`` and
        row-normalises it. ``safe_log`` is used so zero listener entries
        become ``-inf`` (hence probability 0) without a warning.
        """
        utilities = (self.alpha * safe_log(listener.T)) - self.costs
        return row_norm(np.exp(self.temperature * utilities))

    def s1(self, lex):
        """Speaker built directly on the literal listener for ``lex``."""
        return self.S(self.l0(lex))

    def l1(self, lex):
        """Listener built on ``s1(lex)``."""
        return self.L(self.s1(lex))

    def lex_likelihood(self):
        """Column-normalised, prior-weighted column sums of ``s1`` per lexicon.

        Row ``i`` of the un-normalised table is ``sum(s1(lexica[i]), axis=0)``
        multiplied by ``lexprior[i]``.
        """
        p = np.array([np.sum(self.s1(lex), axis=0) * self.lexprior[i]
                      for i, lex in enumerate(self.lexica)])
        return col_norm(p)

    def listener_lexical_marginalisation(self, listener):
        """Sum ``listener`` over axis 1.

        In ``display_expertise_iteration`` the input is indexed by message,
        then lexicon, then state, so this sums the lexicon axis out.
        """
        return np.sum(listener, axis=1)

    def speaker_lexical_marginalisation(self, speaker):
        """Sum ``speaker`` over axis 0 and row-normalise the result.

        In ``display_expertise_iteration`` the input holds one states x
        messages matrix per lexicon, so this sums the lexicon axis out.
        """
        return row_norm(np.sum(speaker, axis=0))

    # The display code refers to these helpers with the "-ization" spelling;
    # expose both spellings so either call resolves.
    listener_lexical_marginalization = listener_lexical_marginalisation
    speaker_lexical_marginalization = speaker_lexical_marginalisation

    def lex2str(self, lex):
        """Format a lexicon as ``'msg={state,...}; ...'`` for display labels.

        For each message (row of ``lex``, paired with ``self.messages`` in
        order) the states whose entry is greater than zero are listed.
        """
        entries = []
        for msg, row in zip(self.messages, np.asarray(lex)):
            true_states = [str(state) for state, value in zip(self.states, row)
                           if value > 0]
            entries.append('{}={{{}}}'.format(msg, ','.join(true_states)))
        return '; '.join(entries)

#### $\rightarrow$ Display functions

    def display_expertise_iteration(self, langs, digits=4):
        """Display the full iteration for the expertise model.

        ``langs`` is read in (joint listener, expert speaker) pairs; each is
        printed in full and then in its lexicon-marginalised form.
        """
        level = 1
        for index in range(0, len(langs) - 1, 2):
            self.display_joint_listener_matrices(
                langs[index], level=level, digits=digits)
            self.display_listener_matrix(
                self.listener_lexical_marginalisation(langs[index]),
                title="{} - marginalized".format(level),
                digits=digits)
            level += 1
            self.display_expert_speaker_matrices(
                langs[index + 1], level=level, digits=digits)
            self.display_speaker_matrix(
                self.speaker_lexical_marginalisation(langs[index + 1]),
                title='{} - marginalized'.format(level),
                digits=digits)

    def display_iteration(self, langs, start_level=0, digits=4):
        """Display the full iteration for any model except expertise.

        ``langs[0]`` is printed as a listener; the remaining entries
        alternate speaker, listener, with the level number increasing after
        each listener.
        """
        self.display_listener_matrix(
            langs[0], title=start_level, digits=digits)
        start_level += 1
        display_funcs = (self.display_speaker_matrix,
                         self.display_listener_matrix)
        for i, lang in enumerate(langs[1:]):
            display_funcs[i % 2](lang, title=start_level, digits=digits)
            if i % 2:
                start_level += 1

    def display_speaker_matrix(self, mat, title='', digits=4):
        """Pretty-print a speaker matrix (states x messages) to stdout."""
        display_matrix(
            mat,
            title='S{}'.format(title),
            rnames=self.states,
            cnames=self.messages,
            digits=digits)

    def display_listener_matrix(self, mat, title='', digits=4):
        """Pretty-print a listener matrix (messages x states) to stdout."""
        display_matrix(
            mat,
            title='L{}'.format(title),
            rnames=self.messages,
            cnames=self.states,
            digits=digits)

    def display_joint_listener(self, mat, title='', digits=4):
        """Pretty-print (to stdout) the lexicon x state joint probability
        table for a given message."""
        lexnames = ['Lex%s: %s' % (i, self.lex2str(lex))
                    for i, lex in enumerate(self.lexica)]
        display_matrix(
            mat,
            rnames=lexnames,
            cnames=self.states,
            title=title,
            digits=digits)

    def display_joint_listener_matrices(self, mats, level=1, digits=4):
        """Pretty-print (to stdout) the lexicon x state joint probability
        table for every message; ``mats[i]`` belongs to ``messages[i]``."""
        for i, mat in enumerate(mats):
            self.display_joint_listener(
                mat,
                title='L{} - {}'.format(level, self.messages[i]),
                digits=digits)

    def display_expert_speaker_matrices(self, mats, level=1, digits=4):
        """Pretty-print (to stdout) a list of state x message conditional
        probability tables, one for each lexicon."""
        for i, mat in enumerate(mats):
            self.display_speaker_matrix(
                mat,
                title='{} - Lex{} {}'.format(level, i, self.lex2str(self.lexica[i])),
                digits=digits)

    def get_best_inferences(self, digits=4):
        """Map each message to its most probable state(s) under ``final_listener``.

        Returns a dict ``{message: [(state, probability_string), ...]}``
        holding every state tied for the row maximum; probabilities are
        rounded to ``digits`` decimals and converted with ``str``.
        """
        # Round to avoid tiny distinctions that don't even display.
        # np.round returns a new array, so final_listener is left untouched.
        mat = np.round(np.asarray(self.final_listener), 10)
        best_inferences = {}
        for i, msg in enumerate(self.messages):
            row_max = np.max(mat[i])
            best_inferences[msg] = [(w, str(np.round(mat[i, j], digits)))
                                    for j, w in enumerate(self.states)
                                    if mat[i, j] == row_max]
        return best_inferences



     

In [None]:
# test!

if __name__ == '__main__':
    # Demo: run the base model on one lexicon and print the final listener
    # and speaker. print(...) with a single argument behaves the same on
    # Python 2 and Python 3.
    print('Inside main!')

    # Rows over the two states (ref_1, ref_2) for the propositions, plus
    # the row used for the null message.
    TT = [1.0, 1.0]
    TF = [1.0, 0.0]
    FT = [0.0, 1.0]
    nullsem = [1.0, 0.0]

    # Logically distinct lexica: one row each for word_1, word_2 and null.
    lexica = [
        np.array([TT, TT, nullsem]),
        np.array([TT, TF, nullsem]),
        np.array([TT, FT, nullsem]),
        np.array([TF, TT, nullsem]),
        np.array([TF, TF, nullsem]),
        np.array([TF, FT, nullsem]),
        np.array([FT, TT, nullsem]),
        np.array([FT, TF, nullsem]),
        np.array([FT, FT, nullsem])]

    mod = Module(
        lexica=lexica,
        messages=['word_1', 'word_2', 'null'],
        costs=np.array([1.0, 2.0, 5.0]),
        states=['ref_1', 'ref_2'],
        prior=np.array([2.0/3.0, 1.0/3.0]),
        lexprior=np.repeat(1.0/len(lexica), len(lexica)),
        temperature=3.0,
        alpha=1.0,
        beta=1.0)

    n = 2

    print('n = ' + str(n) + '\n')
    print('Creating models!\n')
    # display=False: the matrices are printed explicitly below instead.
    baselangs = mod.run_base_model(lexica[6], n=n, display=False)

    print('Essentially RSA does:\nliteral = self.l0(lex)\nspeaker = self.S(literal)\nlistener = self.L(speaker)\n')
    print('Final listener!\n')
    mod.display_listener_matrix(
        baselangs[-1],
        title=" - Base model")
    print('Final speaker!\n')
    mod.display_speaker_matrix(
        baselangs[-2],
        title=" - Base model")
    
    
    
    
    