In [1]:
import sys
sys.path.append('/Users/Cybele/GIT/birkhoff/birkhoff/')
sys.path.append('/Users/Cybele/GIT/birkhoff/src/')
sys.path.append('/Users/Cybele/GIT/birkhoff/')

import string
import math
import random

import matplotlib.pyplot as plt
%matplotlib inline

from scipy.optimize import linear_sum_assignment

import autograd.numpy as np
import autograd.numpy.random as npr
import autograd.scipy as scipy
from autograd import grad
from autograd.optimizers import adam,sgd
from autograd.scipy.misc import logsumexp
from birkhoff.primitives import birkhoff_to_perm



In [2]:
# This is not very important: just define the methods for trimming and process the text

n_alph_car=26
# This function takes as input a decryption key and creates a dict for key where each letter in the decryption key
# maps to a alphabet For example if the decryption key is "DGHJKL...." this function will create a dict like {D:A,G:B,H:C....} 
def create_cipher_dict(cipher):
    cipher_dict = {}
    alphabet_list = list(string.ascii_uppercase[:n_alph_car])
    for i in range(len(cipher)):
        cipher_dict[alphabet_list[i]] = cipher[i]
    return cipher_dict


# This function takes a text and applies the cipher/key on the text and returns text.
def apply_cipher_on_text(text, cipher):
    cipher_dict = create_cipher_dict(cipher)
    text = list(text)
    newtext = ""
    for elem in text:
        if elem.upper() in cipher_dict:
            newtext += cipher_dict[elem.upper()]
        else:
            newtext += " "
    return newtext


# This function takes as input a path to a long text and creates scoring_params dict which contains the
# number of time each pair of alphabet appears together
# Ex. {'AB':234,'TH':2343,'CD':23 ..}
def create_scoring_params_dict(longtext_path):
    scoring_params = {}
    alphabet_list = list(string.ascii_uppercase[:n_alph_car])
    with open(longtext_path) as fp:
        for line in fp:
            data = list(line.strip())
            for i in range(len(data) - 1):
                alpha_i = data[i].upper()
                alpha_j = data[i + 1].upper()
                if alpha_i not in alphabet_list and alpha_i != " ":
                    alpha_i = " "
                if alpha_j not in alphabet_list and alpha_j != " ":
                    alpha_j = " "
                key = alpha_i + alpha_j
                if key in scoring_params:
                    scoring_params[key] += 1
                else:
                    scoring_params[key] = 1
    return scoring_params


def score_params_on_cipher(text):
    scoring_params = {}
    alphabet_list = list(string.ascii_uppercase[:n_alph_car])
    data = list(text.strip())
    for i in range(len(data) - 1):
        alpha_i = data[i].upper()
        alpha_j = data[i + 1].upper()
        if alpha_i not in alphabet_list and alpha_i != " ":
            alpha_i = " "
        if alpha_j not in alphabet_list and alpha_j != " ":
            alpha_j = " "
        key = alpha_i + alpha_j
        if key in scoring_params:
            scoring_params[key] += 1
        else:
            scoring_params[key] = 1
    return scoring_params


# This function takes the text to be decrypted and a cipher to score the cipher.
# This function returns the log(score) metric

def get_cipher_score(text, cipher, scoring_params):
    cipher_dict = create_cipher_dict(cipher)
    decrypted_text = apply_cipher_on_text(text, cipher)
    scored_f = score_params_on_cipher(decrypted_text)
    cipher_score = 0
    for k, v in scored_f.iteritems():
        if k in scoring_params:
            cipher_score += v * math.log(scoring_params[k])
    return cipher_score


# Generate a proposal cipher by swapping letters at two random location
def generate_cipher(cipher):
    pos1 = random.randint(0, len(list(cipher)) - 1)
    pos2 = random.randint(0, len(list(cipher)) - 1)
    if pos1 == pos2:
        return generate_cipher(cipher)
    else:
        cipher = list(cipher)
        pos1_alpha = cipher[pos1]
        pos2_alpha = cipher[pos2]
        cipher[pos1] = pos2_alpha
        cipher[pos2] = pos1_alpha
        return "".join(cipher)


# Toss a random coin with robability of head p. If coin comes head return true else false.
def random_coin(p):
    unif = random.uniform(0, 1)
    if unif >= p:
        return False
    else:
        return True


# Takes as input a text to decrypt and runs a MCMC algorithm for n_iter. Returns the state having maximum score and also
# the last few states 
def MCMC_decrypt(n_iter, cipher_text, scoring_params):
    current_cipher = string.ascii_uppercase  # Generate a random cipher to start
    state_keeper = set()
    best_state = ''
    score = 0
    for i in range(n_iter):
        state_keeper.add(current_cipher)
        proposed_cipher = generate_cipher(current_cipher)
        score_current_cipher = get_cipher_score(cipher_text, current_cipher, scoring_params)
        score_proposed_cipher = get_cipher_score(cipher_text, proposed_cipher, scoring_params)
        acceptance_probability = min(1, math.exp(score_proposed_cipher - score_current_cipher))
        if score_current_cipher > score:
            best_state = current_cipher
        if random_coin(acceptance_probability):
            current_cipher = proposed_cipher
        if i % 500 == 0:
            print "iter", i, ":", apply_cipher_on_text(cipher_text, current_cipher)[0:99]
    return state_keeper, best_state


## Run the Main Program:

scoring_params = create_scoring_params_dict('/Users/Cybele/GIT/MCMC-Cipher-Solver/war_and_peace.txt')

def trimm_text(alphabet_list, text):
    text_trimmed = ''
    for i in range(len(text)):
        if text[i].upper() in alphabet_list:
            text_trimmed +=text[i].upper()
    return text_trimmed

def calculate_N_from_text(alphabet_list, trimmed_text):
    K = len(alphabet_list)
    N = np.zeros((K , K))
    for i in range(len(trimmed_text)-1):
        if trimmed_text[i].upper() in alphabet_list and trimmed_text[i+1].upper() in alphabet_list:
            i0=alphabet_list.index(trimmed_text[i].upper())
            i1=alphabet_list.index(trimmed_text[i+1].upper())
            N[i0,i1]+=1
    return N


def decipher_text(alphabet_list, perm,text):
    str_final =''
    for t in text:
        
        ind = alphabet_list.index(t.upper())
        ind_perm = np.where(perm == ind)[0][0]
        str_final += alphabet_list[ind_perm]
    return str_final

def cipher_text(alphabet_list, perm,text):
    str_final =''
    for t in text:
        ind = alphabet_list.index(t.upper())
        ind_perm = perm[ind]
        str_final += alphabet_list[ind_perm]
    return str_final

In [3]:
#Define text
plain_text = "As Oliver gave this first proof of the free and proper action of his lungs, \
the patchwork coverlet which was carelessly flung over the iron bedstead, rustled; \
the pale face of a young woman was raised feebly from the pillow; and a faint voice imperfectly \
articulated the words, Let me see the child, and die. \
The surgeon had been sitting with his face turned towards the fire: giving the palms of his hands a warm \
and a rub alternately. As the young woman spoke, he rose, and advancing to the bed's head, said, with more kindness \
than might have been expected of him In the beginning was the Word, and the Word was with God, and the Word was God. \
He was with God in the beginning. Through him all things were made; without him nothing was made \
that has been made. In him was life, and that life was the light of all mankind. The light shines in the darkness, \
and the darkness has not overcome it. There was a man sent from God whose name was John. He came as a witness \
to testify concerning that light, so that through him all might believe. He himself was not the light; he came \
only as a witness to the light. The true light that gives light to everyone was coming into the world. \
He was in the world, and though the world was made through him, the world did not recognize him.\
He came to that which was his own, but his own did not receive him.  Yet to all who did receive him, to those who believed \
in his name, he gave the right to become children of God children born not of natural descent, \
nor of human decision or a husband’s will, but born of God.\
What we think, or what we know, or what we believe is, in the end, of little consequence. The only consequence is what we do\
The jour printer with gray head and gaunt jaws works at his case,\
He turns his quid of tobacco while his eyes blurr with the manu\
script"



alphabet_list = list(string.ascii_uppercase[:n_alph_car]) + [' ']

K = len(alphabet_list) # Number of characters in alphabet + space

log_M = np.zeros((K, K )) # matrix with observed (as the text is large, convergeces to 'true')  \
                          # log transition probabilities of characters in originalt text (war and peace)

for key in scoring_params:
    i0 = alphabet_list.index(key[0])
    i1 = alphabet_list.index(key[1])
    log_M[i0,i1] = np.log(scoring_params[key])


N = np.zeros((K, K)) # matrix with observed transitions in (trimmed) plain_text

text_trimmed = trimm_text(alphabet_list,plain_text)

perm_true = np.random.choice(range(K), K ,replace = False) # define a random permutation
#perm_true = range(K)
P_true = np.zeros((K, K))
P_true[np.arange(K), perm_true] = 1

ciphered_text =cipher_text(alphabet_list, perm_true, text_trimmed)
N = calculate_N_from_text(alphabet_list, ciphered_text)



In [20]:
#Here, the Revelant functions

def logistic(psi): return np.exp(psi)/(1+np.exp(psi))

def round_to_perm(P):
    """ Rounds matrix P to its closest permutation, using the slow hungarian"""
    
    K = P.shape[0]
    
    assert P.shape == (K, K)
    row, col = linear_sum_assignment(-P)
    P = np.zeros((K, K))
    P[row, col] = 1.0
    return P

def unpack_params_sparse(params, K, M):
    """Unpacks params in the sparse matrices logit_mu, logit_sigma, with sparsity pattern given by M. It could be improved
     by avoiding the for"""

    indexes = np.where(M.flatten() == 1)
    logit_mu = []
    logit_sigma = []
    
    cont = 0
    
    for i in range(K*K):
        if i in indexes[0]:
            logit_mu.append(params[:n_params][cont])
            logit_sigma.append(params[n_params:][cont])
            cont+=1
        else:
            logit_mu.append(0)
            logit_sigma.append(1)
    return np.reshape(np.array(logit_mu),(K,K)), np.reshape(np.array(logit_sigma),(K,K))
    

#Rowwise normalization
def rowwise_softmax(psi):
    """Rowwise normalization"""
    maxes = np.amax(psi, axis=1)
    maxes = maxes.reshape(maxes.shape[0], 1)
    e = np.exp(psi - maxes)
    dist = (e.T / np.sum(e, axis=1)).T
    return dist

def columnwise_softmax(psi):
    """Columnwise normalization"""
    maxes = np.amax(psi, axis=0)
    maxes = maxes.reshape(1, maxes.shape[0])
    e = np.exp(psi - maxes)
    dist = (e/ np.sum(e, axis=0))
    return dist

def log_prior(P, sigma_prior, K ):
    """Consider a product (coordinate-wise) of mixtures of two gaussians with std sigma_prior and centers at 0 and 1)"""
    mixture_centers = [0 , 1]
    differences = np.tile(np.reshape(P, (K,K,1)), (1, 1, 2)) - mixture_centers
    n = K * K 

    return np.sum(logsumexp(- (differences ** 2 / (2 * sigma_prior )**2), axis = 2)) - n/2 *np.log( 2* np.pi) \
    - n* np.log(2) - n*np.log(sigma_prior)

def log_prob(P, log_M, N):
    """
    M represents the real transition probabilities  (inferred from a text) of letter
    N is matrix containing the observed transitions from i-ths symbol to j-th symbol.

    p is a Permutation matrix, p[i,j] = 1 if the i-the symbol is ciphered to the j-th symbol"""

    
    return np.trace(np.dot(log_M.T, np.dot(P, np.dot(N, P.T))))


def sample_to_pi(sample, temp):
    
    return sample * temp + (1 - temp) * round_to_perm(sample) 
    

def get_samples(params, noise,  temp, limits, M):
    """ obtains a reparameterized variable (sample) and a one closer to the center of the voronoi cell"""
    K = noise.shape[1]
    logit_mu, logit_sigma = unpack_params_sparse(params, K, M)
    
    #Do ten iterations of sinkhorn propagation (I've seen with 3 we also do a good job)
    mu = np.exp(logit_mu*M)
    for j in range(10):
        
        mu = rowwise_softmax(np.log(mu*M))
        mu = columnwise_softmax(np.log(mu*M))
    
    sigma = limits[0] + (limits[1] - limits[0]) * logistic(logit_sigma)
    sample = noise * sigma + mu
    
    return (sample, np.array([sample_to_pi(sample[i,:,:], temp) for i in range(sample.shape[0])]))
    
def log_density_gaussian(params, temp, limits, K):
    
    logit_mu, logit_sigma = unpack_params_sparse(params,  K, M)
    sigma = limits[0] + (limits[1] - limits[0]) * logistic(logit_sigma)
    log_sigma =np.log(sigma)
    entropy = 0.5 * log_sigma.size * (1.0 + np.log(2 * np.pi)) + np.sum(log_sigma)
    
    return entropy + K *(K) * np.log(temp)

def variational_objective(params, t,  M, sigma_prior, limits_sigma):
    """Provides a stochastic estimate of the variational lower bound."""
    K = M.shape[0]
    
    noise = npr.randn(num_mcmc_samples, K, K)
    logit_mu, logit_sigma = unpack_params_sparse(params,  K, M)
    
    (samples, P_samples) = get_samples(params, noise, temperature(t), limits_sigma, M)
    elbo = 0
    
    for P in P_samples:
        elbo = elbo + log_prob(P,log_M, N) / num_mcmc_samples 
        elbo = elbo + log_prior(P, sigma_prior, K ) / num_mcmc_samples
    elbo = elbo + log_density_gaussian(params, temperature(t), limits_sigma, K)
        
    return [-elbo, log_prob(P_samples[0],log_M,N), log_prior(P, sigma_prior, K ), log_density_gaussian(params, temperature(t), limits_sigma, K)]
 

def callback(params, t, g, perline=10):
    """ Display something every perline iterations"""
    K=27
    num_mcmc_samples = 1
    elbos.append(variational_objective(params, t, M, sigma_prior, limits_sigma))
    params_all.append(params)
    
    """Provides a stochastic estimate of the variational lower bound."""
    def n_correct(P1,P2):
        return P1.shape[0]- np.sum(np.abs(P1-P2))/2.0

    def perm_to_P(perm):
        K=len(perm)
        P = np.zeros((K,K))
        P[range(K),perm] = 1 
        return P
    
    if (t % perline) == 0:
        sys.stdout.write('. [Iter {0}/{1}] VLB: {2:.1f}\n'.format(t, num_adam_iters, -elbos[-1][0]))
        print 'Different components of elbo' + ' ' + str(elbos[-1][1]) + ' ' +  str(elbos[-1][2]) + ' ' +  str(elbos[-1][3])
        
        noise = npr.randn(num_mcmc_samples,K, K)
        (sample, P_samples) = get_samples(params, noise, temperature(t), limits_sigma, M)
        perm_inferred = linear_sum_assignment(-P_samples[0])
        P_inf =perm_to_P(np.array(perm_inferred[1]))
          if t >0:
            print 'Currently decoded text: ' + decipher_text(alphabet_list,np.array(perm_inferred[1]),ciphered_text)[:143]
        else:
            print 'Ciphered text: ' + decipher_text(alphabet_list,np.array(perm_inferred[1]),ciphered_text)[:143]
        print 'For a random sample the number of correctly guessed characters was ' +str(n_correct(P_true,P_inf))
        print 'Log probability of non-rounded sample is ' + str(log_prob(P_samples[0],log_M,N))
        print 'Log probability of rounded sample is ' + str(log_prob(P_inf,log_M,N))
        print 'Log probability of solution is ' + str(log_prob(P_true,log_M,N))
      
    else:
        sys.stdout.write('.')
    sys.stdout.flush()
    

In [21]:
# Global parameters
npr.seed(0)

num_adam_iters = 1000
num_mcmc_samples = 1
limits_sigma =[0.00,1]
sigma_prior =0.1 #variance of mixture of gaussians prior

#Now encode soft contrains (P[i,j] = 0) using a mask "M". This matrix has one entries and 0 if P[i,j] =0 
#(a very small number here as we take logs). Notice I am selecting 2 soft contrains per character
M = np.ones((K,K))
listchar = npr.choice(K, K,replace= False)
for m in listchar:
    
    i = np.where(P_true[m,:] ==1)[0]
    
    random = npr.choice(K, 2, replace = False)
    M[m, [j for j in random if j not in [i]]] = 1e-8
    
def temperature(i): 
    """ for now lets keep the temperature fixed """
    tau0=0.75
    MIN_TEMP = 0.75
    ANNEAL_RATE = 0.1
    np_temp=np.maximum(tau0*np.exp(-ANNEAL_RATE*i),MIN_TEMP)
    return np_temp

indexes = np.where(M.flatten() ==1)[0] #find indexes of efective parameters
n_params = len(indexes)

init_mean = -10*np.ones(n_params)
init_logit_std = -5*npr.randn(n_params)
init_var_params = np.concatenate([init_mean, init_logit_std])


print("Variational inference for matching...")
print("Initializing with MAP estimate")

elbos=[]
# SGD with Adam

var_objective = lambda x,t: variational_objective(x, t , M, sigma_prior, limits_sigma)[0]
gradient = grad(var_objective)
elbos = []
params_all =[]

variational_params = adam(gradient, init_var_params, step_size=0.05, num_iters=num_adam_iters, callback=callback)


Variational inference for matching...
Initializing with MAP estimate
. [Iter 0/1000] VLB: 16088.8
Different components of elbo 18763.3601518 -2155.42485927 -519.089463801
For a random sample the number of correctly guessed characters was 0.0
log probability of non-rounded sample is-6973.83146065
Log probability of rounded sample is 11811.4166193
Log probability of solution is 17358.4371991
Ciphered text: NDECMHXYJEBNXYE AHDEGHJD ERJCCGECGE AYEGJYYENVSERJCRYJENO HCVECGEAHDEMWVBDE AYERN OATCJUEOCXYJMY ETAHOAETNDEONJYMYDDMLEGMWVBECXYJE AYEHJCVEKYSD
.......... [Iter 10/1000] VLB: 4817.7
Different components of elbo 7433.67065993 -2144.85947108 -471.087737698
For a random sample the number of correctly guessed characters was 1.0
log probability of non-rounded sample is-3097.37234709
Log probability of rounded sample is 7823.91366282
Log probability of solution is 17358.4371991
Currently decoded text: MRBGXVHWKBPMHWBNAVRBFVKRNBCKGGFBGFBNAWBFKWWBMYTBCKGCWKBMLNVGYBGFBAVRBXIYPRBNAWBCMNLADGKJBLG

KeyboardInterrupt: 

In [None]:
#This is for the analysis of solutions, not relevant for now (it plots the evolution of mu, variances and elbos, etc)
limits_sigma =[0.00,1]
noise = npr.randn(5,K, K)
print temperature(100)
print num_mcmc_samples
N_max = 1000
mu_all =np.zeros((N_max,(K) * K ))
sigma_all =np.zeros((N_max, (K ) * K ))
(sample, P_samples) = get_samples(params_all[-1], noise, temperature(20), limits_sigma, M)

for i in range(N_max):
    
    logit_mu, logit_sigma = unpack_params_sparse(params_all[i], K , M)
    
    #mu = rowwise_softmax(np.hstack((logit_mu, np.zeros((logit_mu.shape[0],1)))))
    mu = np.exp(logit_mu)
    for j in range(10):
        mu = rowwise_softmax(np.log(mu))
        mu = columnwise_softmax(np.log(mu))
        
    sigma = limits[1, 0] + (limits[1, 1] - limits[1, 0]) * logistic(logit_sigma)
    mu_all[i,:] = mu.flatten()
    sigma_all[i,:] = sigma.flatten()

plt.figure()
plt.subplot(221)
plt.plot(mu_all)
plt.subplot(222)
plt.plot(sigma_all)
plt.subplot(223)
plt.plot(np.array([elbos[i][-1] for i in range(N_max)]))
plt.subplot(224)
plt.plot(np.reshape(mu_all[-1], (K,K)))

plt.figure()
#print P_samples[0]
plt.imshow(P_samples[0,:,:])

In [None]:
## MCMC sampler, if want to compare
def perm_to_P(perm):
        K=len(perm)
        P = np.zeros((K,K))
        P[range(K),perm] = 1 
        return P
def P_to_perm(P):
    perm =[]
    for i in range(K):
        perm.append(np.where(P[i,:] ==1)[0][0])
    return perm

def swap_perm(P,i,j):
        perm = P_to_perm(P)
        perm_aux = np.copy(perm)
        perm[i] = perm_aux[j]
        perm[j] = perm_aux[i]
        
        return perm_to_P(perm)
def MCMC_solver(log_M,N,N_iter_MCMC_solver):
    K=log_M.shape[0]
    
    
    
    perm_initial = np.random.choice(range(K),K, replace = False)
    P_initial = perm_to_P(perm_initial)

    P = P_initial
    log_prob_current = log_prob(P,log_M,N,0)
    for i in range(N_iter_MCMC_solver):
        if  i % 100 ==0:
            print 'Iteration ' + str(i)
        indexes = np.random.choice(range(K), 2, replace = False)
        P_prop = swap_perm(P, indexes[0], indexes[1])
        log_prob_prop = log_prob(P_prop, log_M, N, 0)
        if(log_prob_prop > log_prob_current):
            P = P_prop
            log_prob_current  = log_prob_prop
        else:
            p = np.exp(log_prob_prop - log_prob_current)
            if(np.random.uniform(0, 1) < p):
                P = P_prop
                log_prob_current  = log_prob_prop
    perm = []
    #for i in range(K):
     #   perm.append(np.where(P[i,:] ==1)[0][0])
    #return np.array(perm)
    return np.array(P_to_perm(P))

In [None]:
N_iter_MCMC_solver = 1000
perm_MCMC = MCMC_solver(log_M,N,N_iter_MCMC_solver)
print decipher_text(alphabet_list,perm_MCMC,ciphered_text)
print log_prob(np.ones((27,27))/27, log_M,N,5)