# Legitimize probabilities

The probabilities we're currently using (particularly in the zoomed-in 1000x1000 analysis) are a mess: they are normalized within the 1000x1000 window rather than over the entire corpus, I'm not sure what the marginal distributions should be, and so on. In this notebook I'll write down some conventions to use in the future.

w.r.t. individual lexical probability, we have three choices:
1. $p(a) = \#(a) / |A|$: normalize an adjective's frequency by the size of the adjective class, measured in tokens
2. $p(a) = \#(a) / |\text{corpus}|$: normalize by the total size of the corpus (number of tokens)
3. $p(a) = \#(a) / (|A| + |N|)$: normalize by the combined token count of adjectives and nouns

With options 2 and 3, $p(a)$ and $p(n)$ share a common denominator and are therefore directly comparable (this is a plus!).
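
To make the three options concrete, here is a minimal sketch on made-up counts. The words, counts, and `corpus_size` are hypothetical and not taken from our data. It shows that under option 1, an adjective with fewer tokens than a noun can still receive a higher probability, whereas options 2 and 3 keep both classes on a common scale.

```python
# Hypothetical toy counts, not corpus data
adj_counts = {"big": 3, "red": 1}
noun_counts = {"dog": 10, "car": 30}
corpus_size = 100  # hypothetical total number of tokens of any UPOS

adj_total = sum(adj_counts.values())    # |A|, in adjective tokens
noun_total = sum(noun_counts.values())  # |N|, in noun tokens

# Option 1: each class normalized by its own token total
p1_adj = {a: c / adj_total for a, c in adj_counts.items()}
p1_noun = {n: c / noun_total for n, c in noun_counts.items()}

# Option 2: both classes normalized by corpus size
p2_adj = {a: c / corpus_size for a, c in adj_counts.items()}
p2_noun = {n: c / corpus_size for n, c in noun_counts.items()}

# Option 3: both classes normalized by the shared Adj+Noun total
shared = adj_total + noun_total
p3_adj = {a: c / shared for a, c in adj_counts.items()}
p3_noun = {n: c / shared for n, c in noun_counts.items()}

# "big" (3 tokens) vs. "dog" (10 tokens)
print(p1_adj["big"] > p1_noun["dog"])  # True: option 1 reverses the count order
print(p2_adj["big"] < p2_noun["dog"])  # True
print(p3_adj["big"] < p3_noun["dog"])  # True
```

One trade-off is that under options 2 and 3, the adjective probabilities alone no longer sum to 1. In this toy example they sum to 4/100 and 4/44, respectively.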

w.r.t. joint probability, we have co-occurrence data $f_{ij} = \#(a_i n_j)$, so
$$p(a_i n_j) = \frac{ \#(a_i n_j) }{
    \sum_{j=1}^{|N|} \sum_{i=1}^{|A|} \#(a_i n_j)
}$$

What is this the probability of? It is the probability of encountering the Adj-N pair $a_i n_j$
among all Adj-N pair tokens in the corpus.
An alternative would have been to normalize over all two-token pairs of any UPOS, but that is
equivalent to the decision between options 2 and 3 above.
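
Here is a small illustration of the joint normalization. The sketch builds $p(a_i n_j)$ from a toy count matrix `f` with hypothetical values, where rows are adjectives and columns are nouns. The denominator is simply the sum of all entries, which is the double sum in the formula.

```python
import numpy as np

# Hypothetical co-occurrence counts f[i, j] = #(a_i n_j); rows: adjectives, columns: nouns
f = np.array([[4, 0, 1],
              [2, 3, 0]], dtype=float)

T = f.sum()     # double sum over all adjectives and nouns
p = f / T       # p(a_i n_j): probability of each Adj-N pair among all pairs

print(T)        # 10.0
print(p[0, 0])  # 0.4
```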



It is also true that
$$
p(a_in_j) \;\; \propto \;\; \#(a_in_j)
$$
because every pair shares the same denominator. So normalizing within our subset of interest was fine: it changes all the probabilities by the same constant factor.
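
Here is a quick numerical check of this proportionality, again on hypothetical toy counts. Normalizing over all pairs, or only within a sub-window (here, arbitrarily, the first two nouns), scales every probability by the same factor. As a result, ratios between pairs are unchanged.

```python
import numpy as np

f = np.array([[4, 0, 1],
              [2, 3, 0]], dtype=float)  # hypothetical counts

p_full = f / f.sum()               # normalized over all pairs
window = f[:, :2]                  # hypothetical sub-window: first two nouns only
p_window = window / window.sum()   # normalized within the window

# The ratio between two pairs does not depend on the denominator
print(np.isclose(p_full[0, 0] / p_full[1, 1], p_window[0, 0] / p_window[1, 1]))  # True
```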

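Let
$T = \sum_{j=1}^{|N|} \sum_{i=1}^{|A|} \#(a_i n_j)$ be the total over all Adj-N pairs,
$T' = \sum_{j=1}^{1000} \sum_{i=1}^{1000} \#(a_i n_j)$ be the total within the top-1000 window,
and $p'(a_i n_j) = \#(a_i n_j) / T'$ be the window-normalized probability.

Then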
$$
p(a_i n_j) =  \frac{\#(a_i n_j)}{T'} \cdot \frac{T'}{T} = p'(a_in_j) \frac{T'}{T}
$$
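
The rescaling identity can be checked the same way. In this sketch, which uses hypothetical toy counts, `T` and `T_prime` play the roles of $T$ and $T'$. Multiplying the window-normalized `p_prime` by $T'/T$ recovers the corpus-normalized probabilities of the window cells.

```python
import numpy as np

f = np.array([[4, 0, 1],
              [2, 3, 0]], dtype=float)  # hypothetical counts over the full vocabulary
window = f[:, :2]                       # hypothetical window: first two nouns

T = f.sum()             # T: total over all Adj-N pairs
T_prime = window.sum()  # T': total within the window

p = window / T              # p(a_i n_j) for the window cells
p_prime = window / T_prime  # p'(a_i n_j), normalized within the window

print(np.allclose(p_prime * T_prime / T, p))  # True
```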

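However, how does this change conditional probabilities? Here $p(a_i) = \sum_{j'=1}^{|N|} p(a_i n_{j'})$ is the adjective marginal of the pair distribution, not one of the unigram estimates above: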

$$
p(n_j | a_i) = \frac{
    \#(a_i n_j)
}{
    \sum_{j'=1}^{|N|} \#(a_in_{j'})
}
=
\frac{p(a_in_j)}{p(a_i)} = \frac{p'(a_in_j)}{p(a_i)}\frac{T'}{T} \; \propto \; \frac{p'(a_in_j)}{p(a_i)}
$$
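
Numerically, with the same hypothetical toy counts, computing $p(n_j \mid a_i)$ directly from counts agrees with computing it as $p(a_i n_j)/p(a_i)$. The window-normalized `p_prime` gives the same values once multiplied by $T'/T$. Since $T'/T$ is a single constant, it does not affect comparisons. Note, however, that the window rows need not sum to 1, because nouns outside the window are dropped.

```python
import numpy as np

f = np.array([[4, 0, 1],
              [2, 3, 0]], dtype=float)  # hypothetical counts
T = f.sum()
p = f / T               # p(a_i n_j)
p_a = p.sum(axis=1)     # p(a_i) = sum_j' p(a_i n_j'), the pair marginal

# p(n_j | a_i) two ways: directly from counts, and as p(a_i n_j) / p(a_i)
cond_counts = f / f.sum(axis=1, keepdims=True)
cond_probs = p / p_a[:, None]
print(np.allclose(cond_counts, cond_probs))  # True

# Same conditionals from the window-normalized p' (first two nouns), rescaled by T'/T
window = f[:, :2]
T_prime = window.sum()
p_prime = window / T_prime
cond_from_window = p_prime / p_a[:, None] * T_prime / T
print(np.allclose(cond_from_window, cond_counts[:, :2]))  # True
```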

## Now we'll load co-occurrence data

In [58]:
import pandas as pd
from composlang.utils import minmax
import numpy as np
from tqdm.auto import tqdm
from matplotlib import pyplot as plt
from scipy.special import logsumexp
import typing

In [4]:
adj_freq = pd.read_pickle("./adj_freqs.pkl")
noun_freq = pd.read_pickle("./noun_freqs.pkl")
pair_freq = pd.read_pickle("./pair_freq.pkl")
freqs = dict(adj_freq=adj_freq, noun_freq=noun_freq, pair_freq=pair_freq)

In [7]:
A = sum(adj_freq.values())
N = sum(noun_freq.values())
print(f"{A = :,}    {N = :,}")

A = 34,994,090    N = 104,576,156


In [8]:
AN = sum(pair_freq.values())
print(f"{AN = :,}")

AN = 25,178,311


In [19]:
# log-probabilities; unigrams use option 1 (normalize by the class token total)
adj_p = {k: np.log(v / A) for k, v in adj_freq.items()}
noun_p = {k: np.log(v / N) for k, v in noun_freq.items()}
pair_p = {k: np.log(v / AN) for k, v in pair_freq.items()}

Now that these are all probability distributions, each normalized by a sensible factor,
we can verify some of their properties.

In particular, each should sum to 1 on its own.
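
Because the stored values are log-probabilities, this check can also stay in log space by using `scipy.special.logsumexp`: a normalized distribution has a log-sum of (approximately) 0. Here is a minimal sketch on hypothetical counts, built the same way as `adj_p`:

```python
import numpy as np
from scipy.special import logsumexp

# Hypothetical counts, turned into log-probabilities like adj_p above
counts = {"big": 3, "red": 1}
total = sum(counts.values())
log_p = {k: np.log(v / total) for k, v in counts.items()}

# logsumexp(log p) = log(sum p), which is 0 for a normalized distribution
log_total = logsumexp(list(log_p.values()))
print(np.isclose(log_total, 0.0))  # True
```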

In [20]:
sum(map(np.exp, adj_p.values())), sum(map(np.exp, noun_p.values())), sum(
    map(np.exp, pair_p.values())
)

(1.0000000000032314, 0.9999999999912396, 1.0000000001624214)

In [22]:
pd.to_pickle(adj_p, "./adj_p.pkl")
pd.to_pickle(noun_p, "./noun_p.pkl")
pd.to_pickle(pair_p, "./pair_p.pkl")

## Corpus-based conditionals
Next, we estimate conditional probabilities within the subset $A_{<1000}, N_{<1000}$ (the top 1000 adjectives and nouns), because estimating them over the full joint distribution would take too long.
First, we create a joint distribution matrix to make the conditional normalizations easier.

Let `joint` be our joint distribution over the top 1000 lexical items:
`joint[i,j]` $= \log p'(a_i n_j)$, the corpus log-probability renormalized within the window (i.e. $\log p(a_i n_j) + \log(T/T')$).
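
Conditionalization itself is just a row-wise or column-wise subtraction in log space. The sketch below uses a hypothetical 2x3 matrix with adjectives as rows and nouns as columns. It subtracts each row's `logsumexp` to get $\log p(n_j \mid a_i)$, and each column's `logsumexp` to get $\log p(a_i \mid n_j)$. The `[:, None]` makes the row totals broadcast across columns.

```python
import numpy as np
from scipy.special import logsumexp

# Hypothetical joint in log space: rows = adjectives a_i, columns = nouns n_j
with np.errstate(divide="ignore"):  # log(0) -> -inf without a warning
    log_joint = np.log(np.array([[4, 0, 1],
                                 [2, 3, 0]], dtype=float) / 10.0)

# log p(n_j | a_i): subtract each row's log-total (logsumexp over nouns, axis=1)
log_n_given_a = log_joint - logsumexp(log_joint, axis=1)[:, None]
# log p(a_i | n_j): subtract each column's log-total (logsumexp over adjectives, axis=0)
log_a_given_n = log_joint - logsumexp(log_joint, axis=0)

print(np.allclose(np.exp(log_n_given_a).sum(axis=1), 1.0))  # True
print(np.allclose(np.exp(log_a_given_n).sum(axis=0), 1.0))  # True
```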

In [186]:
class JointDist:
    def __init__(
        self,
        adj_weights: typing.Mapping[str, float],  # adjective -> log p(a)
        noun_weights: typing.Mapping[str, float],  # noun -> log p(n)
        pair_weights: typing.Mapping[str, float],  # "adj noun" -> log p(a n)
        m: int = 1_000,  # keep only the first m adjectives and nouns (iteration order)
    ):

        self.m = m
        self.adj_index: typing.Dict[str, int] = {}
        self.noun_index: typing.Dict[str, int] = {}
        self.adj_weights = []
        self.noun_weights = []

        for i, a in enumerate(adj_weights):
            if i >= m:
                break
            if a not in self.adj_index:
                self.adj_index[a] = len(self.adj_index)
                self.adj_weights.append(adj_weights[a])
        for i, n in enumerate(noun_weights):
            if i >= m:
                break
            if n not in self.noun_index:
                self.noun_index[n] = len(self.noun_index)
                self.noun_weights.append(noun_weights[n])

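        # start every cell at log(0) = -inf; unobserved pairs keep that value
        self.joint = np.full((m, m), -np.inf)
        for an, p in tqdm(pair_weights.items(), total=len(pair_weights)):
            try:
                self[an.split()] = p
            except KeyError:
                pass  # adjective or noun falls outside the top-m window
            except TypeError:
                print(an)  # key does not split into exactly one adjective and one noun

        # renormalize within the window so the entries sum to 1 (log-sum 0);
        # joint[i, j] then holds log p'(a_i n_j)
        self.joint -= logsumexp(self.joint)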

    def get_index(self, adj, noun) -> typing.Tuple[int, int]:
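        # strings are looked up in the index; other keys (int, slice, Ellipsis)
        # are used as-is, and None selects the whole axis
        aix = self.adj_index[adj] if isinstance(adj, str) else (... if adj is None else adj)
        nix = self.noun_index[noun] if isinstance(noun, str) else (... if noun is None else noun)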
        return aix, nix

    def __getitem__(self, key) -> typing.Union[float, np.ndarray]:
        if isinstance(key, str):
            key = key.split()
        ix = self.get_index(*key)
        return self.joint[ix]

    def __setitem__(self, key, value):
        ix = self.get_index(*key)
        self.joint[ix] = value

    def conditionalize(self, axis: int = 0) -> "JointDist":
        """
        axis 0 corresponds to marginalizing over adjectives to get p(n|a).
            we sum over axis 0 (sum the distribution corresponding to each adjective)
            and divide by it, so that row is left with p(n|a)
        axis 1 corresponds to marginalizing over nouns to get p(a|n) (by summing over axis 1)
            we sum over axis 1 (sum the distribution corresponding to each noun)
            and divide by it, so that column is left with p(a|n)
        """
        import copy
        from scipy.special import logsumexp

        if axis == 0:
            new_joint = self.joint - logsumexp(self.joint, axis=1 - axis)[:, None]
        elif axis == 1:
            new_joint = self.joint - logsumexp(self.joint, axis=1 - axis)
        else:
            raise ValueError("axis must be 0 or 1 for bivariate joint distribution")
        self_copy = copy.deepcopy(self)
        self_copy.joint = new_joint
        return self_copy

    def get_marginal_of_axis(self, axis: int) -> np.ndarray:
        """
        Return the log marginal distribution (one-dimensional) along the given axis.
        axis=0 gives log p(a) (logsumexp over nouns); axis=1 gives log p(n).
        """
        # entries are log-probabilities, so marginalize with logsumexp rather than .sum()
        return logsumexp(self.joint, axis=1 - axis)
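
Because `joint` stores log-probabilities, a marginal has to be computed with `logsumexp`. Summing the log entries with `.sum()` instead gives the log of a product of probabilities, not a marginal. Here is a small check on a hypothetical 2x2 joint:

```python
import numpy as np
from scipy.special import logsumexp

# Hypothetical joint (all entries nonzero to keep the arithmetic simple)
p = np.array([[0.1, 0.2],
              [0.3, 0.4]])
log_joint = np.log(p)

log_p_a = logsumexp(log_joint, axis=1)  # log p(a_i) = log sum_j p(a_i, n_j)
naive = log_joint.sum(axis=1)           # log prod_j p(a_i, n_j): not a marginal

print(np.allclose(np.exp(log_p_a), [0.3, 0.7]))  # True
print(np.allclose(np.exp(naive), [0.02, 0.12]))  # True
```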

In [166]:
logsumexp(joint.joint, axis=1)[987]

-8.877082674548296

In [167]:
logsumexp(joint["peculiar", ...])

-8.877082674548296

In [162]:
def invert_dict(d):
    return {v: k for k, v in d.items()}


invert_dict(joint.noun_index)[987]

'off-'

In [187]:
joint = JointDist(adj_p, noun_p, pair_p)

other :   --
serum 25-hydroxyvitamin bone
poor t .
quick .   --
beautiful .   --
white can ,


In [188]:
logsumexp(joint[:, :])

0.0

In [206]:
cond = joint.conditionalize(axis=0)

## Excellent! We now have a joint distribution that sums to 1 (its logsumexp is 0)!

In [208]:
pd.to_pickle(joint, "./joint1000.pkl")
pd.to_pickle(cond, "./cond_N_A_1000.pkl")