# Decoding a substitution cipher using Markov Chain Monte Carlo 

Suppose we are given some encoded text and know, for some reason, that it was encoded with a substitution cipher, in which every character is consistently replaced by another. For example, take the text `WAR AND PEACE` and map every character to the next character in the alphabet, with `Z` mapping to the space and the space mapping back to `A`. This gives the cipher text `XBSABOEAQFBDF`.
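The shift cipher in this example can be checked with a few lines of Scala. This is a minimal sketch that assumes the 27-symbol alphabet of the example (the letters `A` to `Z` followed by the space); the names `exampleAlphabet` and `shiftByOne` are hypothetical and are not used elsewhere in the notebook, which works with a larger alphabet that also contains digits.

```scala
// 27-symbol alphabet assumed for this example only: A-Z followed by the space
val exampleAlphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ "

// Hypothetical helper: map each character to the next one, so 'Z' -> ' ' and ' ' -> 'A'
def shiftByOne(c: Char): Char =
  exampleAlphabet((exampleAlphabet.indexOf(c) + 1) % exampleAlphabet.length)

val cipher = "WAR AND PEACE".map(shiftByOne)
println(cipher) // XBSABOEAQFBDF
```

This shift is just one particular key; a general substitution cipher may map each symbol to any other symbol, as long as no two symbols share an image.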

Our goal is to decode such cipher text. In other words, given text that is encoded with a substitution cipher, we wish to identify the key that was used to encode the plain text, and apply the inverse of the key to decode the input. 

## Solution strategy

The solution strategy follows the standard statistical approach to breaking substitution ciphers. Suppose we have a function that assigns a probability to any string based on its character transitions, so that English-like text scores high and gibberish scores low. Then identifying the key becomes a search over all keys for the one whose decoded text has the highest probability under this model. There are far too many keys to enumerate, so instead of an exhaustive search we will explore the space of keys with a random walk.
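To see why exhaustive search is hopeless, note that a key is a permutation of the alphabet, so an alphabet of $n$ symbols has $n!$ possible keys. The sketch below computes this count for the 37-symbol alphabet (26 letters, 10 digits and the space) that this notebook defines in its setup cell; the helper `factorial` is introduced only for illustration.

```scala
// Number of substitution keys = number of permutations of the alphabet = n!
def factorial(n: Int): BigInt = (1 to n).foldLeft(BigInt(1))(_ * _)

val alphabetSize = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 ".length // 37 symbols
val keySpaceSize = factorial(alphabetSize)
println(keySpaceSize) // prints a 44-digit number, about 1.38 x 10^43
```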

First, we need a function that can assign a probability to any English sentence. We will use a first-order Markov model over characters, in which each character depends only on the one before it (a second-order model would condition on the previous two characters instead). For example, a first-order model estimates the probability of `SLEEP` as $\Pr(\text{SLEEP}) = \Pr(S)\,\Pr(L \mid S)\,\Pr(E \mid L)\,\Pr(E \mid E)\,\Pr(P \mid E)$. All that remains is to estimate these transition probabilities from a large corpus.
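In general, for a string $c_1 c_2 \dots c_n$ the first-order model factorizes as below. Taking logarithms turns the product into a sum, which is what the implementation computes. Note that the notebook's `logProbability` function sums only the transition terms and leaves out the initial-character term $\log \Pr(c_1)$.

$$
\log \Pr(c_1 c_2 \dots c_n) = \log \Pr(c_1) + \sum_{i=2}^{n} \log \Pr(c_i \mid c_{i-1})
$$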

Once we have this model, we will use a Markov Chain Monte Carlo (MCMC) method, specifically the Metropolis-Hastings algorithm, to search the discrete space of all substitutions (that is, all keys) for the most probable one. This strategy is described in the introduction to Persi Diaconis' 2009 paper *The Markov Chain Monte Carlo revolution*.


## Initial setup

First, some boilerplate. We will restrict ourselves to uppercase letters, digits and the space character, and we will read *War and Peace* to estimate the transition probabilities.

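```scala
import scala.util.Random
import scala.io.Source
// Ammonite/almond `$ivy` import: resolves the Vegas plotting library
import $ivy.`org.vegas-viz::vegas:0.3.8`
import vegas._

// Training corpus used to estimate character transition probabilities
val book = "data/war-and-peace.txt"
// Pseudo-count used to smooth the bigram probabilities
val epsilon = 1E-9

val random = new Random
// Symbols a key may permute: uppercase letters, digits and the space
val alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 "
```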


## Building the bigram probability table

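Next, let us build the table of character bigram probabilities for English from *War and Peace*. We read the whole book, normalize it to the alphabet above, and count how often each character (unigram) and each pair of adjacent characters (bigram) occurs. The transition probability $\Pr(b \mid a)$ is the ratio of the count of the bigram $ab$ to the count of $a$, smoothed with the pseudo-count $\epsilon$ so that bigrams that never occur in the book still get a small nonzero probability: $\Pr(b \mid a) = \frac{N(ab) + \epsilon}{N(a) + |\mathcal{A}|\,\epsilon}$, where $|\mathcal{A}|$ is the number of symbols in the alphabet.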
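To make the smoothing concrete, the sketch below applies the same formula as the notebook's `smooth` function to made-up counts: suppose, hypothetically, that `Q` occurs 100 times and is always followed by `U`. The observed transition gets almost all of the probability mass, each of the other 36 possible followers gets a tiny but nonzero share, and the 37 probabilities still sum to one.

```scala
// Same add-epsilon formula as the notebook's `smooth`
def smooth(numerator: Int, denominator: Int, numOptions: Int, smoothingFactor: Double): Double =
  (numerator + smoothingFactor) / (denominator + numOptions * smoothingFactor)

val epsilon = 1e-9
val numOptions = 37 // size of the notebook's alphabet

// Hypothetical counts: N(Q) = 100 and N(QU) = 100, so N(Qx) = 0 for every other x
val pSeen = smooth(100, 100, numOptions, epsilon)  // Pr(U | Q), just below 1
val pUnseen = smooth(0, 100, numOptions, epsilon)  // Pr(x | Q) for any other x, about 1e-11
val total = pSeen + (numOptions - 1) * pUnseen     // sums to 1 up to rounding
```

A larger $\epsilon$ would move more mass to unseen bigrams; with $\epsilon = 10^{-9}$ they are heavily penalized but never impossible, which keeps the log probabilities finite.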

```scala
// Count how many times each distinct item occurs
def count[T](items: Iterator[T]): Map[T, Int] = {
  val counts = collection.mutable.Map[T, Int]()
  items.foreach { item => counts(item) = counts.getOrElse(item, 0) + 1 }
  counts.toMap
}

// Add-epsilon smoothing: (count + eps) / (total + options * eps)
def smooth(numerator: Int, denominator: Int, numOptions: Int, smoothingFactor: Double): Double =
  (numerator + smoothingFactor) / (denominator + numOptions * smoothingFactor)

// Map from each two-character string "ab" to the smoothed Pr(b | a)
lazy val transitions: Map[String, Double] = {
  // Normalize the book: uppercase, characters outside the alphabet become
  // spaces, and runs of whitespace collapse to a single space
  val bookText = Source.fromFile(book).getLines()
    .mkString(" ").map(_.toUpper)
    .replaceAll("[^A-Z0-9 ]", " ").replaceAll("\\s+", " ")

  val alphabetSize = alphabet.size
  val alphaCounts = count(bookText.iterator)   // unigram counts N(a)
  val pairCounts = count(bookText.sliding(2))  // bigram counts N(ab)

  val bigramProbabilities = for (prev <- alphabet; current <- alphabet) yield {
    val bigram = s"$prev$current"
    val bigramCount = pairCounts.getOrElse(bigram, 0)
    val prevCount = alphaCounts.getOrElse(prev, 0)
    bigram -> smooth(bigramCount, prevCount, alphabetSize, epsilon)
  }

  bigramProbabilities.toMap
}
```



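Using the transition table, we define the log probability of a text as the sum of the log probabilities of all the bigrams it contains. Working with logarithms avoids the numerical underflow that would result from multiplying hundreds of probabilities smaller than one.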
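As a toy check of this definition, the sketch below uses a tiny hand-made transition table (the table `toyTransitions`, its values and the function `toyLogProbability` are hypothetical) and confirms that summing the bigram log probabilities equals the log of the product of the transition probabilities.

```scala
// Hypothetical transition probabilities for a two-letter alphabet
val toyTransitions = Map("AA" -> 0.5, "AB" -> 0.5, "BA" -> 0.25, "BB" -> 0.75)

// Sum of log Pr(current | previous) over all adjacent pairs;
// a one-character text has no pairs, so its score is 0
def toyLogProbability(text: String): Double =
  text.sliding(2).filter(_.length == 2).map(pair => math.log(toyTransitions(pair))).sum

val lp = toyLogProbability("ABA") // log(0.5) + log(0.25) = log(0.125)
```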

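```scala
// Sum of log Pr(b | a) over every adjacent pair "ab" in the text.
// Windows shorter than two characters (a one-character text) are skipped.
def logProbability(text: String): Double =
  text.sliding(2).filter(_.length == 2).map(pair => math.log(transitions(pair))).sum

logProbability("QZ")
```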

```
res2_1: Double = -28.472157174217898
```

The large negative value reflects that `QZ` receives only the small smoothed probability, as expected for a bigram that is essentially absent from English text.

## Encoding and decoding sentences

Let us build a basic substitution cipher. The key is a one-to-one map from characters to characters, and encoding a text means replacing each of its characters with the character the key maps it to.

```scala
// Replace every character of the text with its image under the key
def encode(text: String, key: Char => Char): String = text map key
```


To decode cipher text, we will have to invert the key and apply it to the cipher text. 
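Inverting works because the key is a permutation: every cipher character is the image of exactly one plain character. The self-contained sketch below assumes a small hypothetical key over a four-symbol alphabet (`smallAlphabet`, `smallKey`), builds the inverse map the same way as the notebook's `decode`, and checks that decoding undoes encoding.

```scala
// Hypothetical four-symbol alphabet and key, for illustration only
val smallAlphabet = "ABC "
val smallKey = Map('A' -> 'C', 'B' -> ' ', 'C' -> 'A', ' ' -> 'B')

def encode(text: String, key: Char => Char): String = text map key

// Build the inverse map (cipher char -> plain char), as in the notebook's `decode`
def decode(text: String, key: Char => Char): String = {
  val inverse = smallAlphabet.map(char => key(char) -> char).toMap
  encode(text, inverse)
}

val cipher = encode("CAB A", smallKey) // "AC BC"
val plain = decode(cipher, smallKey)   // "CAB A"
```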

```scala
// Invert the key (cipher character -> plain character) and apply it
def decode(text: String, key: Char => Char): String = {
  val inverse = alphabet.map(char => key(char) -> char).toMap
  encode(text, inverse)
}
```


Let us check that this works. We first generate a random key by mapping the alphabet to a random permutation of itself. If we encode a text with this key and then decode the result, we should get back the original text.
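Because `randomKey` draws from an unseeded `Random`, each run produces a different key. For reproducible experiments one could seed the generator. The sketch below is a hypothetical variant, `seededKey`, that takes the seed as a parameter; it also checks that the result is a permutation, i.e. that every symbol appears exactly once as an image.

```scala
import scala.util.Random

val alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 "

// Hypothetical helper: a random permutation key driven by a fixed seed
def seededKey(seed: Long): Map[Char, Char] = {
  val shuffled = new Random(seed).shuffle(alphabet.toList)
  alphabet.zip(shuffled).toMap
}

val keyA = seededKey(42L)
val keyB = seededKey(42L) // same seed, so the same key as keyA
```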

```scala
// A random key: pair each symbol with a symbol from a shuffled copy of the alphabet
def randomKey: Map[Char, Char] = {
  val shuffled = random.shuffle(alphabet.toList)
  alphabet.zip(shuffled).toMap
}

// Encode with a fresh random key, decode, and display all three texts
def testDecoder(plainText: String): Unit = {
  val key = randomKey
  val coded = encode(plainText, key)
  val decoded = decode(coded, key)
  publish.html(s"<b>Plain text</b> $plainText")
  publish.html(s"<b>Cipher text</b> $coded")
  publish.html(s"<b>Decoded text</b> $decoded")
}

testDecoder("MARY HAD A LITTLE LAMB")
```


## Deciphering codes

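Finally, we have all the pieces needed for the MCMC decoder. The state space we explore is the set of keys, and we need a way to propose a new key given the current one. A simple proposal is to pick two characters at random and swap their images under the key. For example, if a key maps `A -> B` and `C -> D`, swapping `A` and `C` gives a key that maps `A -> D` and `C -> B`. Because every pair is equally likely to be chosen, this proposal is symmetric (the chance of proposing key $k'$ from $k$ equals the chance of proposing $k$ from $k'$), so the Metropolis-Hastings acceptance rule reduces to a ratio of probabilities.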
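The swap can be illustrated deterministically. The hypothetical helper `swapImages` below exchanges the images of two given characters using `Map.updated` and reproduces the `A -> B`, `C -> D` example from the text; the notebook's `transpose` does the same thing with two randomly chosen characters.

```scala
// Hypothetical helper: swap the images of characters a and b under the key
def swapImages(key: Map[Char, Char], a: Char, b: Char): Map[Char, Char] =
  key.updated(a, key(b)).updated(b, key(a))

val before = Map('A' -> 'B', 'C' -> 'D', 'E' -> 'F')
val after = swapImages(before, 'A', 'C') // maps A -> D, C -> B, E -> F
```

Swapping a character with itself leaves the key unchanged, so the proposal occasionally "stays put"; this does not break the symmetry argument.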

```scala
// Propose a new key by swapping the images of two randomly chosen characters.
// If both picks are the same character, the key is returned unchanged.
def transpose(key: Map[Char, Char]): Map[Char, Char] = {
  val chars = key.keys.toIndexedSeq
  val i = chars(random.nextInt(chars.size))
  val j = chars(random.nextInt(chars.size))
  key.updated(i, key(j)).updated(j, key(i))
}
```



Now the decoder itself. The basic approach is to perform the following steps for multiple iterations:
1. Use the current key to decode the cipher text and compute the log probability of the decoded text.
2. Propose a change to the key using the transpose function above and compute the log probability of the decoded text that uses the changed key.
3. If the proposed key has a higher score than the current one, move to the proposed key. Otherwise, toss a coin that comes up heads with probability equal to the ratio of the proposed key's probability to the current key's probability. If it comes up heads, move to the proposed key; otherwise, keep the current key.
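Since the scores are log probabilities, the coin toss in step 3 can be done in log space: drawing $u$ uniformly from $[0, 1)$ and accepting when $\log u < \text{proposed} - \text{current}$ happens with probability $\min(1, P_{\text{proposed}} / P_{\text{current}})$. The hypothetical helper `accept` below isolates that rule and takes $u$ as a parameter so that it can be tested deterministically.

```scala
// Metropolis acceptance test in log space.
// currentScore and proposedScore are log probabilities; u is a uniform draw in [0, 1).
def accept(currentScore: Double, proposedScore: Double, u: Double): Boolean =
  proposedScore > currentScore || math.log(u) < proposedScore - currentScore

val better = accept(-100.0, -95.0, 0.99)                      // true: always move uphill
val luckyDownhill = accept(-100.0, -102.0, math.exp(-3.0))    // true: log u = -3 < -2
val unluckyDownhill = accept(-100.0, -102.0, math.exp(-1.0))  // false: log u = -1 >= -2
```

In the sampler itself, `u` comes from `random.nextDouble()` on every iteration in which the proposal is not an improvement.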

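The loop also does some bookkeeping: it records the log probability at every iteration for plotting, and it periodically prints the decoded text together with the number of characters that changed since the previous printout.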

```scala
def decrypt(cipherText: String, iters: Int) = {
  val scoringFunction = logProbability _

  // Start from a random key and cache its score
  var currentKey = randomKey
  var currentScore = scoringFunction(decode(cipherText, currentKey))

  var last = cipherText
  // Print every iters / 10 iterations (at least 1, to avoid dividing by zero)
  val logEvery = math.max(1, iters / 10)

  val scores = new collection.mutable.ListBuffer[Map[String, Double]]

  for (iter <- 0 to iters) {
    val decoded = decode(cipherText, currentKey)
    val score = currentScore

    // Propose a neighbouring key by swapping two characters
    val changedKey = transpose(currentKey)
    val changedScore = scoringFunction(decode(cipherText, changedKey))

    // Metropolis rule: always accept an improvement; otherwise accept with
    // probability exp(changedScore - score), compared in log space
    val logAcceptance = changedScore - score
    if (changedScore > score || math.log(random.nextDouble()) < logAcceptance) {
      currentKey = changedKey
      currentScore = changedScore
    }

    // Bookkeeping for plotting and logging
    scores.append(Map("Iteration" -> iter.toDouble, "Log probability" -> score))
    if ((iter < 10000 && iter % logEvery == 0) || iter % 10000 == 0) {
      val diff = last.zip(decoded).count(p => p._1 != p._2)
      last = decoded
      publish.html(s"<li>[$iter: $diff differences] <code>$decoded</code></li>")
    }
  }

  // Plot the log probability of the current key over the iterations
  Vegas().withData(scores).
    encodeX("Iteration", Quant).
    encodeY("Log probability", Quant).
    mark(Line).show
}
```


## Putting it all together

Let us put it all together. We need some text to encode and then decode, so we read it from a file and normalize it the same way as the training corpus.
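The normalization can be checked on a string literal. The sketch below factors the cleaning steps into a hypothetical `normalize` function (the notebook inlines them in `text` and in `transitions`) and shows that punctuation turns into whitespace and that runs of whitespace collapse into a single space.

```scala
// Hypothetical helper: the cleaning steps applied to the line-joined file contents
def normalize(raw: String): String =
  raw.toUpperCase
    .replaceAll("[^A-Z0-9 ]", " ") // anything outside the alphabet becomes a space
    .replaceAll("\\s+", " ")       // collapse runs of whitespace

val cleaned = normalize("Hello, World!") // "HELLO WORLD "
```

Note the trailing space: the final `!` becomes a space that is not stripped, which is harmless because the space is part of the alphabet.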

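```scala
// Read a file and normalize it like the training corpus: uppercase,
// characters outside the alphabet become spaces, whitespace runs collapse
def text(textfile: String): String = Source.fromFile(textfile).getLines()
  .mkString(" ").toUpperCase
  .replaceAll("[^A-Z0-9 ]", " ").replaceAll("\\s+", " ")
```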


Next, we will need a key that will encode the text. Let us generate it randomly and encode the text. 

```scala
val key = randomKey
```

```
key: Map[Char, Char] = Map(
  'E' -> 'B',
  'X' -> 'S',
  '8' -> 'Z',
  '4' -> '6',
  '9' -> 'U',
  'N' -> 'K',
  'T' -> '3',
  'Y' -> 'F',
  'J' -> '4',
  'U' -> 'P',
  'F' -> 'I',
...
```

```scala
def encodedText(file: String) = encode(text(file), key)
val cipherText = encodedText("data/2.txt")
```

```
cipherText: String = "HKCQ3V3HQ3HTQCVKOCHKCQ3V3HQ3HTV9C1NFQHTQC3NBCAB3YJ1J9HQCNVQ3HK5QCV95JYH3NACHQCVCAVYLJDCTNVHKCAJK3BCTVY9JCATATCAB3NJOCIJYCJ83VHKHK5CVCQB2PBKTBCJICYVKOJACQVA19BQCIYJACVC1YJ8V8H9H3FCOHQ3YH8P3HJKCIJYCWNHTNCOHYBT3CQVA19HK5CHQCOHIIHTP93C3NHQCQB2PBKTBCTVKC8BCPQBOC3JCV11YJSHAV3BC3NBCOHQ3YH8P3HJKCBC5C3JC5BKBYV3BCVCNHQ3J5YVACJYC3JCTJA1P3BCVKCHK3B5YV9CQPTNCVQCVKCBS1BT3BOCDV9PBCAB3YJ1J9HQCNVQ3HK5QCVKOCJ3NBYCATATCV95JYH3NAQCVYBC5BKBYV99FCPQBOCIJYCQVA19HK5CIYJACAP93HCOHABKQHJKV9COHQ3YH8P3HJKQCBQ1BTHV99FCWNBKC3NBCKPA8BYCJICOHABKQHJKQCHQCNH5NCIJYCQHK59BCOHABKQHJKV9COHQ3YH8P3HJKQCJ3NBYCAB3NJOQCVYBCPQPV99FCVDVH9V89BCBC5CVOV13HDBCYB4BT3HJKCQVA19HK5C3NV3CTVKCOHYBT39FCYB3PYKCHKOB1BKOBK3CQVA19BQCIYJAC3NBCOHQ3YH8P3HJKCVKOCVYBCIYBBCIYJAC3NBC1YJ89BACJICVP3JTJYYB9V3BOCQVA19BQC3NV3CHQCHKNBYBK3CHKCATATCAB3NJOQC"
```

Finally, let us see whether our MCMC-based method can recover the original text.

```scala
decrypt(cipherText, iters = 10000)
```