In [1]:
from collections import defaultdict, Counter
import collections
import pprint
import math
import bz2
import string
import argparse
from ngram import LM
from nlm_scorer import NlmScorer
import nlm
from copy import deepcopy
from datetime import datetime
# from multiprocessing import Pool
import torch
from torch.multiprocessing import Pool, set_start_method

# Request the 'spawn' start method from torch.multiprocessing before any
# worker pool is created further down in the notebook. If the call raises
# RuntimeError (as happens when this cell is executed again in the same
# kernel), the error is reported with a message and execution carries on.
# NOTE(review): the message assumes the method already in place is 'spawn';
# the RuntimeError itself does not say which method was set — confirm.
try:
    torch.multiprocessing.set_start_method('spawn')
except RuntimeError:
    print('Start method already set to spawn!')

Start method already set to spawn!




In [4]:
# Load the two scorers used by the search: the character 6-gram language
# model and the neural language model wrapped in an NlmScorer.
#
# The scorer is deliberately kept under the name `nlm` because later cells
# pass it around under that name. That rebinding shadows the imported `nlm`
# module, so the module is imported here under a private alias: without it,
# running this cell a second time would look up `load_model` on the
# NlmScorer instance instead of on the module.
import nlm as _nlm_module

lm = LM("data/6-gram-wiki-char.lm.bz2", n=6, verbose=False)
model = _nlm_module.load_model("data/mlstm_ns.pt", cuda=False)
nlm = NlmScorer(model, cuda=False)

Reading language model from data/6-gram-wiki-char.lm.bz2...
Done.
Loading model data/mlstm_ns.pt..
Model on board!


In [7]:
def read_file(filename):
    """Return the complete text content of ``filename``.

    A path whose last four characters are ``.bz2`` is opened through
    ``bz2.open`` in text mode; any other path is opened as a regular text
    file. The handle is closed when the ``with`` block exits.
    """
    suffix = filename[-4:]
    if suffix == ".bz2":
        handle = bz2.open(filename, 'rt')
    else:
        handle = open(filename, 'r')
    with handle:
        return handle.read()

def check_limits(mappings, ext_limits, letter_to_check=0):
    """Check that plaintext letters are not used more often than allowed.

    Args:
        mappings: dict from cipher symbol to plaintext letter.
        ext_limits: maximum number of cipher symbols that may share one
            plaintext letter.
        letter_to_check: when ``None``, every plaintext letter in
            ``mappings`` is checked. Otherwise only the occurrences of this
            value among the mapped letters are counted (the default ``0``
            therefore counts values equal to ``0``).

    Returns:
        True when the checked letter(s) occur at most ``ext_limits`` times.
    """
    assigned = list(mappings.values())

    # Fast path: only the letter that was just added needs checking.
    if letter_to_check is not None:
        return assigned.count(letter_to_check) <= ext_limits

    usage = Counter(assigned)
    return not any(times > ext_limits for times in usage.values())

def score_single_seq(t):
    """Score one ``(index, segment)`` pair with the global n-gram model ``lm``.

    The segment at index 0 is scored with ``lm.score_seq``; every other
    segment is scored with ``lm.score_patial_seq``.
    """
    position, segment = t
    # An earlier experiment sent segments of 20+ characters to the neural
    # scorer (nlm.score_seq) instead; that routing is currently disabled.
    if position == 0:
        return lm.score_seq(segment)
    return lm.score_patial_seq(segment)

# Shared worker pool that score() uses to evaluate segments via pool.map.
# It is created once, when this cell runs; 12 is the argument handed to Pool
# (the requested number of worker processes).
pool = Pool(12)

def score(mappings, cipher_text, lm, nlm):
    """Score a partial decipherment of ``cipher_text``.

    Each cipher symbol is replaced by its plaintext letter from ``mappings``;
    symbols without a mapping become a space. The result is split on
    whitespace into segments, each ``(index, segment)`` pair is scored by
    ``score_single_seq`` through the module-level ``pool``, and the segment
    scores are summed.

    ``lm`` and ``nlm`` are accepted for interface compatibility but are not
    used inside this function.
    """
    plain_chars = (mappings.get(symbol, ' ') for symbol in cipher_text)
    segments = ''.join(plain_chars).split()
    indexed_segments = list(enumerate(segments))
    return sum(pool.map(score_single_seq, indexed_segments))

def prune(beams, beamsize):
    """Return the ``beamsize`` best hypotheses, highest score first.

    Each element of ``beams`` is a ``(mappings, score)`` pair; ranking uses
    the second item. Hypotheses with equal scores keep their input order.
    """
    def hypothesis_score(beam):
        return beam[1]

    ranked = list(beams)
    ranked.sort(key=hypothesis_score, reverse=True)
    return ranked[:beamsize]


def beam_search(cipher_text, lm, nlm, ext_order, ext_limits, init_beamsize):
    """Beam search over cipher-symbol -> plaintext-letter mappings.

    Cipher symbols are fixed one at a time in the order given by
    ``ext_order``. At each step every hypothesis in the beam is extended with
    each lowercase ASCII letter, extensions that violate ``ext_limits`` are
    dropped, the rest are scored with ``score`` and the beam is pruned.

    Args:
        cipher_text: the ciphertext to decipher.
        lm, nlm: scorers forwarded unchanged to ``score``.
        ext_order: sequence of cipher symbols, in the order they are fixed.
        ext_limits: maximum number of cipher symbols per plaintext letter.
        init_beamsize: beam size for the first symbol; it shrinks by a factor
            of 0.94 per fixed symbol but never drops below 1.

    Returns:
        The best ``(mappings, score)`` pair. For an empty ``ext_order`` this
        is ``({}, 0)``.

    Raises:
        IndexError: if no hypothesis survives a step (every extension was
            rejected by ``check_limits``).
    """
    alphabet = string.ascii_lowercase
    beam = [({}, 0)]

    for cardinality, cipher_letter in enumerate(ext_order):
        # Without the lower bound, a small init_beamsize decays to 0 after a
        # few symbols, which empties the beam and makes the final lookup fail.
        beamsize = max(1, int(init_beamsize * (0.94 ** cardinality)))
        print("Searching for {}/{} letter".format(cardinality, len(ext_order)))
        print("Current size of searching tree: {}".format(len(beam)))

        candidates = []
        for mappings, _ in beam:
            for plain_letter in alphabet:
                # Mappings are flat symbol -> letter dicts built only here,
                # so a shallow copy is sufficient.
                ext_mappings = dict(mappings)
                ext_mappings[cipher_letter] = plain_letter
                # Only the letter that was just added can exceed the limit.
                if check_limits(ext_mappings, ext_limits, plain_letter):
                    candidates.append(
                        (ext_mappings, score(ext_mappings, cipher_text, lm, nlm))
                    )
        beam = prune(candidates, beamsize)

    beam.sort(key=lambda b: b[1], reverse=True)
    return beam[0]

def contiguous_score(cipher, order):
    """Score how well the symbols in ``order`` form contiguous runs in ``cipher``.

    The cipher is scanned for maximal runs of symbols that belong to
    ``order``. A finished run of length k (capped at 8) adds one to the tally
    for k; while a run is already at the cap, each further covered symbol adds
    one more to the tally for 8. The result is the weighted sum of the
    tallies, where runs shorter than 2 weigh nothing and runs of 7 and 8
    weigh 2 and 3.
    """
    covered = set(order)
    weights = (0, 0, 1, 1, 1, 1, 1, 2, 3)
    longest = len(weights) - 1
    tallies = [0] * len(weights)

    run = 0
    for symbol in cipher:
        if symbol not in covered:
            # Run interrupted: record it (if any) and start over.
            if run:
                tallies[run] += 1
            run = 0
            continue
        if run == longest:
            tallies[longest] += 1
        else:
            run += 1

    # Record a run that reaches the end of the cipher.
    if run:
        tallies[run] += 1

    return sum(weight * tally for weight, tally in zip(weights, tallies))

def prune_orders(orders, beamsize):
    """Return the ``beamsize`` largest entries of ``orders``, largest first.

    Entries are compared with their natural ordering (for the
    ``(scores, order)`` tuples used by the caller, that means by the score
    list first).
    """
    ranked = list(orders)
    ranked.sort(reverse=True)
    return ranked[:beamsize]

def search_ext_order(cipher, beamsize):
    """Beam-search an order in which to fix the cipher symbols.

    The order starts with the most frequent symbol (the first one seen wins a
    tie). It is then grown one symbol at a time; every partial order is
    scored with ``contiguous_score`` and only the ``beamsize`` best partial
    orders, as ranked by ``prune_orders``, are kept.

    Args:
        cipher: the ciphertext (whitespace already removed by the caller).
        beamsize: number of partial orders kept after each extension.

    Returns:
        A list containing every distinct cipher symbol exactly once, in the
        chosen order. An empty cipher yields an empty list.
    """
    # An empty cipher has no symbols to order; previously this fell through
    # to removing a non-existent start symbol and raised KeyError.
    if not cipher:
        return []

    freq = Counter(cipher)
    # max() returns the first symbol with the highest count.
    start = max(freq, key=freq.get)
    remaining = set(freq) - {start}

    # Each entry is (score history with the newest score first, order so far).
    orders = [([0], [start])]
    for _ in range(len(remaining)):
        expanded = []
        for scores, order in orders:
            for symbol in remaining:
                if symbol in order:
                    continue
                # Both lists are flat, so concatenation copies them safely.
                new_order = order + [symbol]
                new_scores = [contiguous_score(cipher, new_order)] + scores
                expanded.append((new_scores, new_order))
        orders = prune_orders(expanded, beamsize)

    orders.sort(reverse=True)
    return orders[0][1]

In [None]:
# Load the ciphertext and drop every whitespace character.
cipher = ''.join(ch for ch in read_file('data/cipher.txt') if not ch.isspace())

# Search settings: symbol extension order (found with an order beam of 100),
# the per-letter mapping limit, and the initial hypothesis beam size.
ext_order = search_ext_order(cipher, 100)
ext_limits = 8
beamsize = 1000000

print('Start deciphering...')
search_start = datetime.now()
mappings, sc = beam_search(cipher, lm, nlm, ext_order, ext_limits, beamsize)
search_end = datetime.now()
print('Deciphering completed after {}'.format(search_end - search_start))
print('Mappings: ', mappings)

# Render the decipherment; symbols without a mapping are shown as '_'.
deciphered = ''.join(mappings.get(c, '_') for c in cipher)
print('Decipherment: {} \nscore: {}'.format(deciphered, sc))


def read_gold(gold_file):
    """Read the reference text from ``gold_file``.

    Returns:
        The file content with leading and trailing whitespace stripped, as a
        list of single characters.
    """
    with open(gold_file) as handle:
        text = handle.read()
    return list(text.strip())

def symbol_error_rate(dec, _gold):
    """Fraction of symbols in ``dec`` that differ from the reference text.

    Args:
        dec: the deciphered text.
        _gold: path of the reference file, read with ``read_gold``.

    Returns:
        ``wrong / len(reference)`` as a float. When the decipherment and the
        reference differ in length, every symbol is counted as wrong and the
        result is 1.0. When both are empty the result is 0.0.
    """
    gold = read_gold(_gold)

    # A length mismatch cannot be aligned position by position, so it is
    # reported as a total miss.
    if len(gold) != len(dec):
        return 1.0

    # Both texts are empty: nothing to get wrong. This previously divided by
    # zero.
    if not gold:
        return 0.0

    correct = sum(1 for d, g in zip(dec, gold) if d == g)
    return (len(gold) - correct) / len(gold)
    
# Compare the decipherment with the gold reference and report both rates as
# percentages.
gold_file = "data/ref.txt"
ser = symbol_error_rate(deciphered, gold_file)
error_pct, accuracy_pct = ser*100, (1-ser)*100
print('Error: ', error_pct, 'Accuracy: ', accuracy_pct)

Start deciphering...
Searching for 0/54 letter
Current size of searching tree: 1
Searching for 1/54 letter
Current size of searching tree: 26
Searching for 2/54 letter
Current size of searching tree: 676
Searching for 3/54 letter
Current size of searching tree: 17576
Searching for 4/54 letter
Current size of searching tree: 456976
