# In [291]:
import os
import binascii
# hexdump function
def hexdump(src, offset=0, length=16):
    """Return a classic hex + ASCII dump of ``src`` as one string.

    Each output line is ``"<addr>  <hex bytes>  <ascii>\\n"`` where ``addr`` is
    the position in ``src`` plus ``offset`` (at least 4 hex digits), the hex
    column is padded to ``length * 3`` characters and characters that are not
    plain printable ASCII are shown as ``'.'``.

    :param src: byte string to dump (iterated one character at a time).
    :param offset: value added to every displayed address.
    :param length: number of characters shown per line.
    :return: the dump, or ``''`` when ``src`` is empty.
    """
    # A character is "printable" when its repr is just the quoted character
    # itself (3 chars); escapes such as '\\x00' or '\\\\' are longer.
    ascii_table = [chr(code) if len(repr(chr(code))) == 3 else '.'
                   for code in range(128)]
    rows = []
    for pos in range(0, len(src), length):
        codes = [ord(ch) for ch in src[pos:pos + length]]
        hex_column = ' '.join("%02x" % code for code in codes)
        text_column = ''.join(ascii_table[code] if code <= 127 else '.'
                              for code in codes)
        rows.append("%04x  %-*s  %s\n" % (pos + offset, length * 3, hex_column, text_column))
    return ''.join(rows)

def hexdump_diff(src, dst, offset=0, length=16):
    """Return a side-by-side ASCII rendering of ``src`` and ``dst``.

    Each line is ``"<addr>  <src ascii>  <dst ascii>\\n"``; characters that are
    not plain printable ASCII are shown as ``'.'``.  The number of lines is
    driven by ``len(src)`` only, so any part of ``dst`` beyond the end of
    ``src`` is not shown.

    :param src: byte string shown in the first column.
    :param dst: byte string shown in the second column.
    :param offset: value added to every displayed address.
    :param length: number of characters shown per line and per column.
    :return: the rendered text, or ``''`` when ``src`` is empty.
    """
    # Printable means: repr() is just the quoted character (3 chars long).
    ascii_table = [chr(code) if len(repr(chr(code))) == 3 else '.'
                   for code in range(128)]

    def as_text(chunk):
        codes = [ord(ch) for ch in chunk]
        return ''.join(ascii_table[code] if code <= 127 else '.' for code in codes)

    rows = []
    for pos in range(0, len(src), length):
        left = as_text(src[pos:pos + length])
        right = as_text(dst[pos:pos + length])
        rows.append("%04x  %s  %s\n" % (pos + offset, left, right))
    return ''.join(rows)

def hexdump_2x2(section1, section2, offset=0, length=16):
    """Render two sections, each decoded with both keys, as four ASCII columns.

    Columns, left to right:

    1. ``section1.section_out`` (section 1 with its own key),
    2. ``section1.show(section2.xor_key)`` (section 1 with section 2's key),
    3. ``section2.section_out`` (section 2 with its own key),
    4. ``section2.show(section1.xor_key)`` (section 2 with section 1's key).

    A header line naming the hex keys is emitted first.  It is indented by the
    width of the ``%04x`` address column plus its two-space separator so the
    labels line up with the data columns.

    :param section1: object exposing ``section_out``, ``show()``, ``xor_key``
        and ``hex_key``.
    :param section2: same interface as ``section1``.
    :param offset: value added to every displayed address.
    :param length: number of characters per column per line.
    :return: header plus one line per ``length`` characters of column 1.
    """
    # Printable means: repr() is just the quoted character (3 chars long).
    ascii_table = [chr(code) if len(repr(chr(code))) == 3 else '.'
                   for code in range(128)]

    def as_text(chunk):
        codes = [ord(ch) for ch in chunk]
        return ''.join(ascii_table[code] if code <= 127 else '.' for code in codes)

    columns = (
        section1.section_out,
        section1.show(section2.xor_key),
        section2.section_out,
        section2.show(section1.xor_key),
    )
    # Same width as the "%04x  " prefix used on every data row.
    indent = 2 + len("%04x" % offset)
    header = (" " * indent
              + "p1key1: %s  " % section1.hex_key
              + "p1key2: %s  " % section2.hex_key
              + "p2key2: %s  " % section2.hex_key
              + "p2key1: %s" % section1.hex_key
              + "\n")
    rows = [header]
    for pos in range(0, len(columns[0]), length):
        cells = tuple(as_text(column[pos:pos + length]) for column in columns)
        rows.append("%04x  %s  %s  %s  %s\n" % ((pos + offset,) + cells))
    return ''.join(rows)

def hexdump_kt(section1, section2, offset=0, length=16):
    """Show section 1 as hex next to section 2 as ASCII.

    The first line is a header naming both hex keys; it is indented by the
    width of the ``%04x`` address column plus its two-space separator so the
    labels sit above their columns.  Every following line is
    ``"<addr>  <hex of section1.section_out>  <ascii of section2.section_out>\\n"``.

    :param section1: object exposing ``section_out`` and ``hex_key``; rendered
        in the hex column and drives the number of lines.
    :param section2: object exposing ``section_out`` and ``hex_key``; rendered
        in the ASCII column.
    :param offset: value added to every displayed address.
    :param length: number of characters shown per line.
    :return: the rendered text.
    """
    # Printable means: repr() is just the quoted character (3 chars long).
    ascii_table = [chr(code) if len(repr(chr(code))) == 3 else '.'
                   for code in range(128)]
    hex_source = section1.section_out
    text_source = section2.section_out
    # Same width as the "%04x  " prefix used on every data row.
    indent = 2 + len("%04x" % offset)
    rows = [" " * indent + "Section 1 Hexdump " + section1.hex_key + " " * 24
            + "Section 2 ASCII " + section2.hex_key + "\n"]
    for pos in range(0, len(hex_source), length):
        hex_column = ' '.join("%02x" % ord(ch) for ch in hex_source[pos:pos + length])
        codes = [ord(ch) for ch in text_source[pos:pos + length]]
        text_column = ''.join(ascii_table[code] if code <= 127 else '.'
                              for code in codes)
        rows.append("%04x  %-*s  %s\n" % (pos + offset, length * 3, hex_column, text_column))
    return ''.join(rows)

# xor function
from itertools import cycle
def xor(data, key):
    """XOR every character of ``data`` with ``key``, repeating ``key`` as needed.

    :param data: string to transform.
    :param key: string whose characters are cycled over ``data``.
    :return: a string as long as ``data`` (empty when ``key`` is empty, since
        there is nothing to pair the data with).
    """
    out = []
    for data_char, key_char in zip(data, cycle(key)):
        out.append(chr(ord(data_char) ^ ord(key_char)))
    return ''.join(out)

# In [292]:
# define sections
class Section(object):
    """One fixed-size slice of a firmware image together with its XOR state.

    The raw bytes are taken as ``data[start:start + length]`` and never
    modified.  ``xor_key``, ``section_out`` and ``xor_freq`` hold the key
    currently applied, the decoded bytes and the frequency statistics; all
    three are ``None`` until a key has been applied.

    Two sections compare equal when they start at the same address.
    """

    def __init__(self, data, start, length):
        self._start = start
        self._length = length
        self._data = data[start:start+length]
        self.reset()

    @property
    def start(self):
        """Address of the first byte of this section in the source data."""
        return self._start

    @property
    def data(self):
        """The raw (still XOR-ed) bytes of this section."""
        return self._data

    @property
    def hex_key(self):
        """Hexlified current key, or ``None`` when no key has been applied."""
        # hexlify(None) raises TypeError; report "no key" instead of crashing.
        if self.xor_key is None:
            return None
        return binascii.hexlify(self.xor_key)

    def reset(self):
        """Forget the current key, decoded output and statistics."""
        self.xor_freq = None
        self.xor_key = None
        self.section_out = None

    def reset_xor(self, cb):
        """Ask ``cb`` for a key, apply it and return it.

        :param cb: callable taking the raw bytes and returning
            ``(hex_key, stats)``.
        :return: the hex key returned by ``cb``.
        """
        key, stats = cb(self._data)
        self.xor_freq = stats
        self.xor(binascii.unhexlify(key))
        return key

    def xor(self, key):
        """Apply the raw (not hexlified) 4-byte ``key`` and remember the result.

        :return: the decoded bytes, also stored in ``section_out``.
        """
        assert len(key) == 4
        self.xor_key = key
        self.section_out = xor(self._data, key)
        return self.section_out

    def show(self, key):
        """Return the bytes decoded with the raw 4-byte ``key`` without storing anything."""
        assert len(key) == 4
        return xor(self._data, key)

    def __eq__(self, o):
        try:
            return self._start == o._start
        except AttributeError:
            # Not section-like: let Python try the reflected comparison.
            return NotImplemented

    def __ne__(self, o):
        # Python 2 does not derive != from ==, so spell it out.
        result = self.__eq__(o)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        # Keep hashing consistent with __eq__ (equal sections hash equal).
        return hash(self._start)
    
class FirmwarePart(object):
    """A firmware file split into a header and equally sized XOR-ed sections.

    The first ``start`` bytes of the file are kept as the header; the rest is
    cut into ``Section`` objects of ``length`` bytes each (the last one may be
    shorter).  Indexing with an address returns the section containing it.
    """

    def __init__(self, fname, start, length):
        """Read ``fname`` and build the section list.

        :param fname: path of the firmware file, opened in binary mode.
        :param start: address of the first section (size of the header).
        :param length: size of every section in bytes.
        """
        self.filename = fname
        # Context manager so the handle is closed even if read() fails.
        with open(fname, 'rb') as fin:
            self.data = fin.read()
        self.start = start
        self.length = length
        self._header = self.data[:start]
        self.sections = [Section(self.data, _start, length)
                         for _start in range(start, len(self.data), length)]
        self.reset()

    def reset(self):
        """Clear key, output and statistics of every section."""
        for section in self.sections:
            section.reset()
        return

    @property
    def name(self):
        """Base name of the firmware file."""
        return os.path.basename(self.filename)

    def __iter__(self):
        return iter(self.sections)

    def __getitem__(self, address):
        """Return the section containing ``address``.

        :raises IndexError: if ``address`` lies before the first section or
            past the last one.
        """
        if address < self.start:
            # A negative list index would silently pick a section from the end.
            raise IndexError("address 0x%x is below the first section (0x%x)"
                             % (address, self.start))
        # Floor division keeps the index an int under true division too.
        return self.sections[(address - self.start) // self.length]

    def apply_xor(self, key_list):
        """Apply one hex key per section, pairing keys and sections in order.

        Also stores, for each section, a Counter of its raw 4-byte words
        (hexlified); a trailing group shorter than 4 bytes is ignored.

        :param key_list: list of hex key strings.
        """
        assert isinstance(key_list, list)
        applied = 0
        for section, key in zip(self.sections, key_list):
            section.xor(binascii.unhexlify(key))
            raw = section._data
            usable = len(raw) - len(raw) % 4
            words = [binascii.hexlify(raw[pos:pos + 4]) for pos in range(0, usable, 4)]
            section.xor_freq = Counter(words)
            applied += 1
        print("%d/%d sections Xor-ed" % (applied, len(self.sections)))

    def search_xor_keys_by_addr(self, cb, start=0, stop=None):
        """Run ``cb`` on ``self.data[start:stop]`` and return its ``(key, stats)``.

        ``cb`` must return a hex key of 8 characters.
        """
        key, stats = cb(self.data[start:stop])
        assert len(key) == 8
        return key, stats

    def search_xor_keys(self, cb, first=0, last=None):
        """Let ``cb`` pick a key for ``self.sections[first:last]`` and apply it.

        :return: the hex keys of *all* sections (see ``get_section_keys``).
        """
        for section in self.sections[first:last]:
            section.reset_xor(cb)
        keys = self.get_section_keys()
        nb_distinct = len(Counter(keys))
        print("%d sections/keys and %d distinct keys" % (len(keys), nb_distinct))
        return keys

    def get_section_keys(self):
        """Return the hex key of every section, ``None`` where no key is set."""
        return [binascii.hexlify(section.xor_key) if section.xor_key is not None else None
                for section in self.sections]

    def get_section_freqs(self):
        """Return the stored frequency statistics of every section."""
        return [section.xor_freq for section in self.sections]

    def get_xor_byte_for(self, target_char, address):
        """Return the key byte that turns the raw byte at ``address`` into ``target_char``."""
        section = self[address]
        c = section.data[address-section.start]
        return chr(ord(c) ^ ord(target_char))

    def write_to_file(self, filename):
        """Write the header followed by every decoded section to ``filename``.

        Sections that were never decoded are written unchanged (XOR with an
        all-zero key) after a notice naming their start address.
        """
        with open(filename, 'wb') as fout:
            fout.write(self._header)
            for s in self.sections:
                if s.section_out is None:
                    # No key exists yet, so identify the section by address.
                    print("section 0x%x has not unxored" % s.start)
                    s.xor(binascii.unhexlify('00000000'))
                fout.write(s.section_out)
        return

    def __len__(self):
        return len(self.data)

# In [293]:
import struct
def and_(w1, w2):
    """Bitwise AND of two packed words.

    Both arguments are unpacked with the native ``"I"`` struct format, AND-ed
    and packed back with the same format.
    """
    (left,) = struct.unpack("I", w1)
    (right,) = struct.unpack("I", w2)
    return struct.pack("I", left & right)

from collections import Counter
def find_xor_key(_section):
    """Guess the XOR key as the most frequent 4-byte word of ``_section``.

    The idea is that runs of 0x00000000 in the plain text show up as the key
    itself once XOR-ed, so the most common word is the best candidate.

    :param _section: raw byte string; a trailing group shorter than 4 bytes
        is ignored.
    :return: ``(key, freqs)`` where ``key`` is the hexlified most common word
        and ``freqs`` is a Counter of all hexlified words.
    :raises IndexError: if ``_section`` holds fewer than 4 bytes.
    """
    # Explicit slicing works on any byte-string type, unlike joining the
    # characters produced by zip(*[iter(...)]*4).
    usable = len(_section) - len(_section) % 4
    words = [binascii.hexlify(_section[pos:pos + 4]) for pos in range(0, usable, 4)]
    freqs = Counter(words)
    key = freqs.most_common(1)[0][0]
    return key, freqs
# Counter.most_common() is the winner

def find_xor_key_3(_section):
    '''
    Frequency analysis with a second pass for sections that have few zeros.

    The most common 4-byte word is accepted as the key when it occurs at
    least ``len(_section) / 16`` times.  Otherwise every word seen more than
    once is compared to the most common repeated word: when the two differ
    by fewer than 5 bits, their bitwise AND is recorded as a candidate, and
    the most frequent candidate wins.

    Returns ``(hex_key, counter)``: the counter is the word frequencies when
    the first-pass key is kept, or the candidate frequencies otherwise.
    '''
    key, freqs = find_xor_key(_section)
    top_word, top_count = freqs.most_common(1)[0]
    assert key == top_word
    # 25% of the words being identical is typical of zero-filled holes.
    if top_count >= len(_section) // (4 * 4):
        return key, freqs

    # Second layer: only words that repeat are worth comparing.
    repeated = [binascii.unhexlify(word)
                for word, count in freqs.most_common() if count > 1]
    if not repeated:
        return key, freqs

    reference = repeated[0]
    candidates = []
    for other in repeated[1:]:
        delta = struct.unpack("I", xor(reference, other))[0]
        differing_bits = bin(delta).count("1")
        if differing_bits < 5:
            candidates.append(binascii.hexlify(and_(reference, other)))

    if not candidates:
        return key, freqs
    votes = Counter(candidates)
    return votes.most_common(1)[0][0], votes


# In [294]:
def get_address_from_i(i):
    """Return the start address of section number ``i`` (0x800 + i * 0x200)."""
    section_span = i * 0x200
    return 0x800 + section_span

def get_i_from_address(addr):
    """Return the number of the section containing ``addr``.

    Inverse of ``0x800 + i * 0x200``.  Floor division is used so the result
    is always an integer usable as a list index, whatever division mode the
    interpreter applies to ``/``.
    """
    return (addr - 0x800) // 0x200
    
def key_rev_case(key):
    """Return ``key`` XOR-ed with 0x20 on every byte (flips ASCII letter case)."""
    case_mask = binascii.unhexlify('20202020')
    return xor(key, case_mask)

def key_reverse(key):
    """Return ``key`` XOR-ed with 0xff on every byte (all bits inverted)."""
    invert_mask = binascii.unhexlify('ffffffff')
    return xor(key, invert_mask)

def hkey_reverse(key):
    """Hex-string variant of ``key_reverse``: takes and returns a hex key."""
    raw_key = binascii.unhexlify(key)
    inverted = key_reverse(raw_key)
    return binascii.hexlify(inverted)

def show_rev_case(section, key=None):
    addr = section._start
    i = get_i_from_address(addr)
    if key is None:
        key = section.xor_key
    key = key_rev_case(key)
    hex_key = binascii.hexlify(key)
    print hex_key
    print hexdump(section.show(key))
    print "if REV_CASE better", change_string(i, hex_key, addr, "revcase")
    return hex_key

# replace xor key due to known text attack
def known_text_attack(_sections, address, target_bytes):
    """Derive the 4-byte key that decodes ``address..address+3`` to ``target_bytes``.

    Each key byte is obtained from ``_sections.get_xor_byte_for`` for the
    corresponding expected character; only the first four characters of
    ``target_bytes`` are used.

    :return: ``(key, hex_key)`` - the raw key and its hexlified form.
    """
    assert len(target_bytes) >= 4
    # We expect to read target_bytes at this address instead of garbage.
    key_bytes = [_sections.get_xor_byte_for(target_bytes[pos], address + pos)
                 for pos in range(4)]
    raw_key = ''.join(key_bytes)
    return raw_key, binascii.hexlify(raw_key)

def try_known_text_attack(_sections, address, known_text):
    # handle border cases, cross section
    # we expect 'known_text'
    i = get_i_from_address(address)
    key,hkey = known_text_attack(_sections, address, known_text)
    _section = _sections[address]
    new_out = _section.show(key)
    print hexdump_diff(_section.section_out, new_out,_section._start)
    #print hexdump(new_out, _section._start)
    print "if KNOWN_TEXT seems right,", change_string(i, hkey, address, "kt")
    return hkey

def try_known_text_attack_before(_sections, address, known_text):
    # handle border cases, cross section
    # we expect 'known_text'
    i = get_i_from_address(address)
    key,hkey = known_text_attack(_sections, address, known_text)
    _section = _sections[address]
    new_out = _section.show(key)
    before = _sections[address-0x200]
    print hexdump(before.section_out[0x100:], before._start+0x100)
    print hexdump(new_out[:0x100], _section._start)
    print "if KNOWN_TEXT seems right,", change_string(i, hkey, address, "kt")
    return hkey

def try_known_text_attack_after(_sections, address, known_text):
    # handle border cases, cross section
    # we expect 'known_text'
    i = get_i_from_address(address)
    key,hkey = known_text_attack(_sections, address, known_text)
    _section = _sections[address]
    new_out = _section.show(key)
    after = _sections[address+0x200]
    print "if KNOWN_TEXT below seems right,", change_string(i, hkey, address, "kt")
    print hexdump(new_out[0x100:], _section._start+0x100)
    print hexdump(after.section_out[:0x100], after._start)
    return hkey

def comp(section1, section2):
    x1 = section1.hex_key
    x2 = section2.hex_key
    addr = section1._start
    i = get_i_from_address(addr)
    
    # difference in key
    if x1 == x2:
        print "****** SAME KEY %s ******" % x1
    print 'section1', x1, x2
    # print hexdump(section1.section_out, addr)
    # print hexdump_diff(section1.section_out, section1.show(binascii.unhexlify(x2)), addr)
    print "if column 2 better,", change_string(i, x2, addr, "comp")

    # print 'section2', x2, x1
    # print hexdump(section2.section_out, addr)
    # print hexdump_diff(section2.section_out, section2.show( binascii.unhexlify(x1)), addr)
    print "if column 4 better,", change_string(i, x1, addr, "comp")
    
    print hexdump_2x2(section1, section2, addr)
    print "Otherwise add %d to ignore" % i
    
    # print 'unity1 find_xor_key_3'
    find_xor_key_3(section1._data)
    # print 'unity2 find_xor_key_3'
    find_xor_key_3(section2._data)
    
def comp_at(s1, s2,  address):
    """Compare the sections of ``s1`` and ``s2`` located at ``address``.

    :return: ``(section_of_s1, section_of_s2, index)``.
    """
    index = get_i_from_address(address)
    first, second = s1.sections[index], s2.sections[index]
    comp(first, second)
    return first, second, index

def change_string(i, key, address, r=''):
    """Format the ``change_key(...)`` line an analyst can paste to fix a key.

    :param i: section index.
    :param key: hex key to set.
    :param address: address shown in the trailing comment (8 hex digits).
    :param r: free-form reason tag shown in the trailing comment.
    """
    template = "change_key( %d, '%s') # %s @0x%0.8x"
    fields = (i, key, r, address)
    return template % fields

def try_key(section, new_key):
    addr = section._start
    i = get_i_from_address(addr)
    print "if KEY below is better", change_string(i, new_key, addr)
    print hexdump(section.show(binascii.unhexlify(new_key)), section._start)

def show_orig(_sections, address):
    _section = _sections[address]
    print hexdump(_section._data, _section._start)
    
def try_reverse_key(_section, key=None):
    addr = _section._start
    i = get_i_from_address(addr)
    if key is None:
        key = _section.xor_key
    else:
        key = binascii.unhexlify(key)        
    key = key_reverse(key)
    hex_key = binascii.hexlify(key)
    #print hex_key
    print "if REVERSE_KEY below is better", change_string(i, hex_key, addr)
    print hexdump(_section.show(key))
    return hex_key


# In [6]:
## keys bit pattern
def show_bit_h(_slice, bitnum):
    """Return one bit of every key in ``int_keys[_slice]`` as a '0'/'1' string.

    Each key is rendered as a binary string zero-filled to 32 characters and
    the character at index ``bitnum`` is taken, so index 0 is the leftmost
    (most significant) position.  ``int_keys`` is a module-level sequence
    that must be defined before calling.
    """
    column = []
    for value in int_keys[_slice]:
        binary = bin(value)[2:].zfill(32)
        column.append(binary[bitnum])
    return "".join(column)

def show_bit_sliced(_slice, bitrange=range(0,32)):
    for bitnum in bitrange:
        print "bit %02d" % bitnum, show_bit_h(_slice, bitnum)

def show_bit_vertical_hexdump(cb, bitnum, step, _slice):    
    print "bit %02d, width: %d, keys %d->%d" % (bitnum, step, _slice.start, _slice.stop)
    #for num in range(1512, 1680, step):
    allbits = cb(_slice, bitnum)
    bits = ''
    for num in range(_slice.start, _slice.stop, step):
        addr = get_address_from_i(num)
        # print r, hex(o), num, _slice
        prevbits = bits
        bits = allbits[num-_slice.start: num+step-_slice.start]
        #p = pattern.findPatternText(bits)
        print "%08x" % addr, bits, "same as prev: ", prevbits == bits#, p
    #print '\npattern:', pattern.findPatternText(show_bit_h(_slice, bitnum))

def show_bit_vhexdump_diff(bitnum, step, _slice):    
    print "bit %02d, width: %d, keys %d->%d" % (bitnum, step, _slice.start, _slice.stop)
    print " "*8,"From Keys", "Generation".zfill(step*2-20).replace("0"," ")
    #for num in range(1512, 1680, step):
    allbits = show_bit_h(_slice, bitnum)
    allgenbits = gen_bit_h(_slice, bitnum)
    for num in range(_slice.start, _slice.stop, step):
        diff = ''
        addr = get_address_from_i(num)
        bits1 = allbits[num-_slice.start: num+step-_slice.start]
        bits2 = allgenbits[num-_slice.start: num+step-_slice.start]
        if bits1 != bits2:
            diff = '  <<<<<'
        print "%08x" % addr, bits1, " | ", bits2, diff
    return

def show_bit_vhexdump(bitnum, step, _slice, cb=show_bit_h):
    print 'From keys:',
    return show_bit_vertical_hexdump(show_bit_h, bitnum, step, _slice)

def show_bit_vhexdump_gen(bitnum, step, _slice):
    print 'From Gen :',
    return show_bit_vertical_hexdump(gen_bit_h, bitnum, step, _slice)

def gen_bit_h(_slice, bitnum):
    """Return the bits generated for ``bitnum`` over ``_slice`` as one string.

    Looks up ``generators[bitnum]`` (a module-level collection that must be
    defined before calling), resets it and joins what ``generate(_slice)``
    yields.
    """
    bit_source = generators[bitnum]
    bit_source.reset()
    return "".join(bit_source.generate(_slice))

def alt_ticks_gen(mod_val, ticks_val1, retval1, ticks_val2, retval2, val):
    """Pick a return value depending on where ``val`` falls in a cycle.

    :param mod_val: cycle length; ``val % mod_val`` is the tick examined.
    :param ticks_val1: container of ticks that select ``retval1``.
    :param retval1: value returned when the tick is in ``ticks_val1``.
    :param ticks_val2: container of ticks that select ``retval2``.
    :param retval2: value returned when the tick is in ``ticks_val2``.
    :param val: the counter value being tested.
    :return: ``retval1``, ``retval2`` or ``False`` when the tick is in neither.
    :raises RuntimeError: if the tick is listed in both containers.
    """
    m = val % mod_val
    in_first = m in ticks_val1
    in_second = m in ticks_val2
    if in_first and in_second:
        raise RuntimeError("alt_ticks_gen: %d (%d %% %d) are both in tick1 and tick2"
                           % (m, val, mod_val))
    if in_first:
        return retval1
    if in_second:
        return retval2
    return False

def find_inverts_bit_pattern(data, invert_p1, invert_p2):
    p = data
    i = 0
    i_s = []
    p1 = invert_p1
    p2 = invert_p2
    while i >=0 and i < len(p):
        try:
            f1 = p.index(p1, i) + len(p1)
        except ValueError:
            f1 = len(p)
        try:
            f2 = p.index(p2, i) + len(p2)
        except ValueError:
            f2 = len(p)
        if f1 == f2:
            #print "f1==f2", f1, len(p)
            break
        elif f1 < f2:
            i_s.append((p1, f1))
            i = f1
        else: # f2 < f1:
            i_s.append((p2, f2))
            i = f2
    # debug
    #for p,i in i_s:
    #    print i, 
    #print ''
    return i_s

def find_inverts_diff2(_slice, bitnum, invert_patterns):
    offset = _slice.start
    ret1 = []
    ret2 = []
    def p_info(i1, i1_diff, pat1, i2, i2_diff, pat2):
        i1+=offset
        i2+=offset
        extra=''
        if pat1 != pat2 or i1 != i2:
            extra=' <<<'
        ret1.append("(%d) %d %s" % (i1, i1_diff, pat1))
        ret2.append("(%d) %d %s" % (i2, i2_diff, pat2))
        
    data1 = show_bit_h(_slice, bitnum)
    data2 = gen_bit_h(_slice, bitnum)
    p1 = invert_patterns[bitnum][0]
    p2 = invert_patterns[bitnum][1]
    print "bit %02d Orig versus Generated" % bitnum
    i_s1 = find_inverts_bit_pattern(data1, p1, p2)
    i_s2 = find_inverts_bit_pattern(data2, p1, p2)
    pat1, indice_prev_1 = i_s1[0]
    pat2, indice_prev_2 = i_s2[0]
    p_info(indice_prev_1, 0, pat1,indice_prev_2, 0, pat2)
    for ind, ((pat1, indice1), (pat2,indice2)) in enumerate(zip(i_s1[1:], i_s2[1:])):
        #pat2,indice2 = i_s2[ind+1]
        p_info(indice1, indice1-indice_prev_1, pat1,
               indice2, indice2-indice_prev_2, pat2)
        indice_prev_1 = indice1
        indice_prev_2 = indice2
    # difflib
    import difflib
    for x in difflib.ndiff(ret1, ret2):
        print x
    print ''
    return


def find_inverts_diff(_slice, bitnum, invert_patterns):
    offset = _slice.start
    def p_info(i1, i1_diff, pat1, i2, i2_diff, pat2):
        i1+=offset
        i2+=offset
        extra=''
        if pat1 != pat2 or i1 != i2:
            extra=' <<<'
        print "(%d) %0.2d %s | (%d) %0.2d %s %s" % (i1, i1_diff, pat1, i2, i2_diff, pat2, extra)
        
    data1 = show_bit_h(_slice, bitnum)
    data2 = gen_bit_h(_slice, bitnum)
    p1 = invert_patterns[bitnum][0]
    p2 = invert_patterns[bitnum][1]
    print "bit %02d Orig versus Generated" % bitnum
    i_s1 = find_inverts_bit_pattern(data1, p1, p2)
    i_s2 = find_inverts_bit_pattern(data2, p1, p2)
    pat1, indice_prev_1 = i_s1[0]
    pat2, indice_prev_2 = i_s2[0]
    p_info(indice_prev_1, 0, pat1,indice_prev_2, 0, pat2)
    for ind, ((pat1, indice1), (pat2,indice2)) in enumerate(zip(i_s1[1:], i_s2[1:])):
        #pat2,indice2 = i_s2[ind+1]
        p_info(indice1, indice1-indice_prev_1, pat1,
               indice2, indice2-indice_prev_2, pat2)
        indice_prev_1 = indice1
        indice_prev_2 = indice2
    print ''
    return


def find_inverts(_slice, bitnum, invert_patterns, with_diff=True):
    offset = _slice.start
    def p_info(i1, i1_diff, pat1):
        i1+=offset
        print "(%d) %d %s " % (i1, i1_diff, pat1)
    data = show_bit_h(_slice, bitnum)
    p1 = invert_patterns[bitnum][0]
    p2 = invert_patterns[bitnum][1]
    print "bit %02d" % bitnum
    i_s = find_inverts_bit_pattern(data, p1, p2)
    pattern, indice_start = i_s[0]
    if with_diff:
        print "diffs:"
        for ind, (pattern, indice) in enumerate(i_s[1:]):
            p_info(indice, indice-indice_start, pattern)
            indice_start = indice
        print ''
    print ''
    return

def find_inverts_gen(_slice, bitnum, invert_patterns, with_diff=True):
    offset = _slice.start
    def p_info(i1, i1_diff, pat1):
        i1+=offset
        print "(%d) %d %s " % (i1, i1_diff, pat1)
    data = gen_bit_h(_slice, bitnum)
    p1 = invert_patterns[bitnum][0]
    p2 = invert_patterns[bitnum][1]
    print "GENERATED bit %02d" % bitnum
    i_s = find_inverts_bit_pattern(data, p1, p2)
    pattern, indice_start = i_s[0]
    if with_diff:
        print "diffs:"
        for ind, (pattern, indice) in enumerate(i_s[1:]):
            p_info(indice, indice-indice_start, pattern)
            indice_start = indice
        print ''
    print ''
    return

def fix_key_bit(keys, bitnum, address, bitval):
    """Overwrite one bit of the hex key of the section at ``address``.

    :param keys: list of hex key strings, modified in place.
    :param bitnum: position in the 32-character binary rendering of the key
        (0 is the leftmost, most significant position).
    :param address: address used to find the key's index in ``keys``.
    :param bitval: the character ``'0'`` or ``'1'`` to store at ``bitnum``.
    :return: ``(old_key, new_key)`` as hex strings.
    """
    i = get_i_from_address(address)
    old_key = keys[i]
    bits = list(bin(int(old_key, 16))[2:].zfill(32))
    bits[bitnum] = bitval
    # Zero-pad to 8 digits: a bare '%x' drops leading zeros and the result
    # would no longer unhexlify to a 4-byte key.
    keys[i] = '%08x' % int(''.join(bits), 2)
    return old_key, keys[i]

# In [296]:
class CyclicalTicker(object):
    """
    Modular counter that answers ``retval1`` on every tick, except when the
    counter value is listed in ``tick_list``, where it answers ``retval2``.

    The counter starts at ``init`` and is incremented modulo ``mod_val``
    *before* being tested, so with the default ``init=-1`` the first tick
    examines counter value 0.  ``wrap_cb`` is accepted but not used.
    """
    def __init__(self, mod_val, retval1, tick_list, retval2, init=-1, wrap_cb=None):
        self.mod_val = mod_val
        self.retval1 = retval1
        self.retval2 = retval2
        self.tick_list = tick_list
        self.init = init
        self.counter = init
        # Last value returned by tick(); None until the first tick.
        self.tock = None

    def tick(self):
        """Advance the counter by one and return the value for the new position."""
        self.counter = (self.counter + 1) % self.mod_val
        self.tock = self.retval2 if self.counter in self.tick_list else self.retval1
        return self.tock

    @property
    def has_wrapped(self):
        """True when the counter currently sits at 0."""
        return self.counter == 0

    @property
    def is_val1(self):
        """True when the last tick returned ``retval1``."""
        return self.tock == self.retval1

    @property
    def is_val2(self):
        """Negation of ``is_val1``."""
        return not self.is_val1

    def reset(self):
        """Put the counter back to its initial value."""
        self.counter = self.init

    def __str__(self):
        fields = (self.counter, self.mod_val, self.retval1, self.retval2, self.tick_list)
        return "<Cycl.T C:%d M:%d r1:%s r2:%s tlist:%s" % fields

class BinaryCyclicalTicker(CyclicalTicker):
    """
    Ticker answering True on every tick except when the counter reaches
    ``mod_val - 1``, where it answers False.
    """
    def __init__(self, mod_val, init=-1):
        last_position = mod_val - 1
        super(BinaryCyclicalTicker, self).__init__(
            mod_val, True, [last_position], False, init)
    
class Generator(object):
    """Base class for bit generators.

    Subclasses implement ``reset()`` (return to the initial state) and
    ``bit()`` (produce the next bit).  The constructor stores ``init`` and
    calls ``reset()``, so subclasses must provide it.
    """
    def __init__(self, init):
        self.init = init
        self.reset()

    def generate(self, _slice):
        """Yield the bits for positions ``_slice.start`` to ``_slice.stop - 1``.

        Generation always starts from position 0: the bits before
        ``_slice.start`` are produced and discarded.  The generator is reset
        once the last bit has been yielded.
        """
        # Start from 0 and throw away everything up to the slice start.
        for _ in range(_slice.start):
            self.bit()
        # Then produce the requested bits.
        for _ in range(_slice.start, _slice.stop):
            yield self.bit()
        self.reset()
        # Falling off the end finishes the generator; an explicit
        # `raise StopIteration` becomes RuntimeError under PEP 479.

    def reset(self):
        raise NotImplementedError()

    def bit(self):
        raise NotImplementedError()

class KeysBasedBitGenerator(Generator):
    """Generator that reads its bits straight from a list of hex keys.

    ``bit()`` returns the character at index ``bitnum`` of the 32-character
    binary rendering of ``keys[counter]`` (index 0 is the leftmost position)
    and then advances the counter.
    """
    def __init__(self, keys, bitnum):
        # The base constructor is not called: there is no `init` state here.
        self.keys = keys
        self.bitnum = bitnum
        self.counter = 0

    def generate(self, _slice):
        """Yield the selected bit of every key in ``_slice.start`` to ``_slice.stop - 1``."""
        for ind in range(_slice.start, _slice.stop):
            self.counter = ind
            yield self.bit()
        # Falling off the end finishes the generator; an explicit
        # `raise StopIteration` becomes RuntimeError under PEP 479.

    def bit(self):
        """Return the selected bit of the current key and move to the next key."""
        value = int(self.keys[self.counter], 16)
        bit = bin(value)[2:].zfill(32)[self.bitnum]
        self.counter += 1
        return bit

    def reset(self):
        """Point back at the first key."""
        self.counter = 0
        return
