In [1]:
from pwn import remote, context
from css.mangle import mangle

# Hex-encoded 8-byte constant that do_cipher_auth_key XORs onto challenge
# keys and nonces (hard-coded for this challenge).
cipher_auth_key = "0bfb91847347be4a"

# Only error-level pwntools log messages are shown.
context.log_level = "error"

def bxor(a, b):
    """XOR two byte sequences element-wise and return the result as bytes.

    Iteration stops at the shorter input (zip semantics), so `b` may be an
    infinite iterator such as itertools.cycle(pad).
    """
    return bytes(x ^ y for x, y in zip(a, b))
def do_cipher_auth_key(x):
    """XOR `x` with the 8-byte key decoded from the hex string `cipher_auth_key`.

    Because bxor stops at the shorter input, the result is
    min(len(x), 8) bytes long.
    """
    return bxor(x, bytes.fromhex(cipher_auth_key))

r = remote("the-other-css.chal.pwni.ng", 1996)

# Our host challenge is 16 zero bytes: the first 8 are the challenge key,
# the last 8 the encrypted host nonce.
host_challenge = bytes([0] * 16)
r.send(host_challenge)
challenge_key = host_challenge[:8]
encrypted_host_nonce = host_challenge[8:]
host_mangling_key = do_cipher_auth_key(challenge_key)

# recvn (not recv): pwntools' recv(n) returns as soon as *any* data is
# available and may yield fewer than n bytes; recvn blocks for exactly n.
# The 8 bytes read here are discarded.  NOTE(review): presumably the peer's
# answer to our host challenge — it is not verified.
r.recvn(8)

host_nonce = do_cipher_auth_key(encrypted_host_nonce)

player_challenge_key = r.recvn(8)
encrypted_player_nonce = r.recvn(8)
player_nonce = do_cipher_auth_key(encrypted_player_nonce)

player_mangling_key = do_cipher_auth_key(player_challenge_key)
response = mangle(player_mangling_key, do_cipher_auth_key(player_nonce))
r.send(response)

# Session key: mangle of the XOR-combined host/player mangling keys applied
# to the XOR-combined nonces.
mangling_key = bxor(host_mangling_key, player_mangling_key)
session_nonce = bxor(host_nonce, player_nonce)
session_key = mangle(mangling_key, session_nonce)

In [2]:
from tqdm import tqdm
from itertools import count, cycle
from css.cipher import Cipher
from css.mode import Mode

# Receive up to 10 sectors of 8208 bytes each. A short read means the server
# has stopped sending, so we keep what we got and stop.
sectors = []
for _ in tqdm(range(10)):
    try:
        # recvn returns b"" on timeout and raises EOFError if the connection
        # closes before 8208 bytes are available.
        ct = r.recvn(8208, timeout=5)
    except EOFError:
        # Connection closed mid-sector: take whatever is still buffered.
        # NOTE(review): relies on pwntools keeping the partial data buffered.
        ct = r.recv()
    if ct:
        # A timeout yields b""; don't record it as a sector.
        sectors.append(ct)
    if len(ct) != 8208:
        break

stream_cipher = Cipher(session_key, Mode.Data)
print(session_key.hex())

# NOTE(review): one Cipher instance decrypts every sector in order, which
# assumes the session keystream runs continuously across sectors.
decrypted_sectors_with_sector_xor = []
for sector in tqdm(sectors):
    decrypted_sectors_with_sector_xor.append(stream_cipher.decrypt(sector))

100%|██████████| 10/10 [00:01<00:00,  7.36it/s]


9b0daeae8ef33846


100%|██████████| 10/10 [00:01<00:00,  8.52it/s]


In [3]:
# Layout as read here: an 8 * 128-byte header, then for each sector an
# 8-byte nonce followed by 8208 bytes of ciphertext.
# NOTE(review): the header's meaning is inferred from the offset only.
sector_params = []
with open("disks/a.disk", "rb") as disk_a_ct:
    for sector_idx in range(10):
        disk_a_ct.seek(8 * 128 + sector_idx * (8208 + 8))
        sector_nonce = disk_a_ct.read(8)
        # A 64-byte prefix suffices: the key search only uses the first
        # cipher_nsamp bytes of the keystream.
        sector_ct = disk_a_ct.read(64)
        sector_pt = decrypted_sectors_with_sector_xor[sector_idx]
        # Known ciphertext XOR known plaintext = per-sector keystream prefix.
        sector_stream = bxor(sector_ct, sector_pt)

        sector_params.append({
            "idx": sector_idx,
            "nonce": sector_nonce.hex(),
            "stream": sector_stream.hex()
        })
sector_params

[{'idx': 0,
  'nonce': 'a9debb718177adaa',
  'stream': '50cff4558f8195fbe907baf6480ca15306e5d55cc3f374e0de18dd3e6406d7a7295b3f18ef38d98a51b005f96924e4dde831d69b825139bff8fe1045ba42400e'},
 {'idx': 1,
  'nonce': '4911982d2aafef95',
  'stream': '2599656b5d11fc0fc8682471e676e437199fa8635ce869ef7f8ee678a3b39dca1a5cbd440c6d6464e8be712f959ff9a78241b7bc619b720dd4c3785085c9e0f8'},
 {'idx': 2,
  'nonce': '19d73155de9f486f',
  'stream': 'f849e55da82feebd5600c201a0cddb3650d05a9cf292a4aa125a4738249e96c33b41fdc96a31b9e14e180eee1569af9382edfcd25854e6637a183214d44f08b9'},
 {'idx': 3,
  'nonce': '70230e43f8e4a8a6',
  'stream': 'aee6ae7123581fae8f259fb4460be47e127343d728af86ddfd9ff8d6604191f01d39f6b2d868fe448b4202970a4854dd51ca34a5042723cd1705952e34c6d15b'},
 {'idx': 4,
  'nonce': '9ea411908f185729',
  'stream': '7d52f73b5d6205bbba4c0bcf8c4004e6820c358398ecf358ed69cd02712b1e61143dabc4211d0c972a8bcfdfa986ba60feb177c0bb639c31f58b8d86291418e9'},
 {'idx': 5,
  'nonce': '1a070ad76aa79c50',
  'stream': '5dd1

In [4]:
from functools import lru_cache
from numba import jit
from numba.typed import List

# cipher_nsamp: keystream bytes kept by keystream_to_numall.
# stream_nsamp: keystream bytes (x8 bits) actually checked by the key search.
cipher_nsamp, stream_nsamp = 16, 12

@lru_cache
def taps_to_tapmasks(taps):
    """Split a tap bitmask into a numba typed List of single-bit masks.

    Masks come out in ascending bit order. The result is memoised per
    `taps`, so the same List object is shared between callers and must not
    be mutated.
    """
    tapmasks = List()
    remaining = taps
    while remaining:
        lowest = remaining & -remaining  # isolate the lowest set bit
        tapmasks.append(lowest)
        remaining ^= lowest
    return tapmasks

@jit(nopython=True)
def fast_calc_lfsr_inner(size, seed, tapmasks, iters, flip):
    """Clock a right-shifting `size`-bit LFSR `iters` times and pack its output.

    Each clock computes a feedback bit: for every mask in `tapmasks` it XORs
    in 1 when all of that mask's bits are set in the state (a plain tap test
    for the single-bit masks produced by taps_to_tapmasks). The state is then
    shifted right and the feedback bit inserted at bit `size - 1`.

    The emitted bit is the feedback bit XOR `flip` (expected to be 0 or 1).
    Emitted bits are packed LSB-first into bytes. The returned list holds
    those bytes in emission order, plus a final partial byte when `iters` is
    not a multiple of 8.
    """
    state = seed

    m = 1  # mask of the next output bit inside the byte being built
    ret = 0  # byte currently being assembled
    rets = []
    for _ in range(iters):
        b = 0
        for tapmask in tapmasks:
            b ^= ((state & tapmask) == tapmask)
        state = (b << (size - 1)) | (state >> 1)

        b ^= flip
        ret = ret | (b * m)
        m <<= 1

        # Byte complete: flush it and start a new one.
        if m == 0x100:
            m = 1
            rets.append(ret)
            ret = 0

    # Flush a trailing partial byte, if any.
    if m != 1:
        rets.append(ret)

    return rets

def calc_lfsr(size, seed, taps, iters, flip):
    """Return the first `iters` LFSR output bits as one little-endian integer.

    Bit k of the result is the k-th emitted bit (see fast_calc_lfsr_inner).
    """
    packed = fast_calc_lfsr_inner(size, seed, taps_to_tapmasks(taps), iters, flip)
    return sum(byte << (8 * pos) for pos, byte in enumerate(packed))

def check_key1(key1, numall):
    """Test a 24-bit key1 candidate against an observed keystream.

    `numall` is taken to be (LFSR-1 output + LFSR-2 output) mod 2**bits.
    Subtracting LFSR-1's output for this candidate leaves LFSR-2's
    (inverted) output. Its first 41 bits, un-inverted, equal LFSR-2's state
    after 41 clocks. That state is clocked forward and compared against the
    remaining bits.

    Returns key1 when consistent, otherwise None.
    """
    nbits = stream_nsamp * 8
    lfsr1_seed = ((key1 & 0xfffff8) << 1) | 8 | (key1 & 7)
    lfsr1_out = calc_lfsr(25, lfsr1_seed, 0x19e4001, nbits, 1)
    lfsr2_out = (numall - lfsr1_out) & ((1 << nbits) - 1)
    head_mask = 0x1ffffffffff
    lfsr2_state = (lfsr2_out & head_mask) ^ head_mask
    predicted_tail = calc_lfsr(41, lfsr2_state, 0xfdc0000001, nbits - 41, 1)
    return key1 if predicted_tail == lfsr2_out >> 41 else None

def keystream_to_numall(keystream):
    """Interpret the first `cipher_nsamp` bytes of a hex keystream as a little-endian int.

    Raises ValueError for an empty keystream.
    """
    prefix = bytes.fromhex(keystream)[:cipher_nsamp]
    return int(bytes(reversed(prefix)).hex(), 16)

def recover_key1(stream_idx, posthoc=None):
    """Brute-force the 24-bit key1 of sector `stream_idx`.

    Every candidate in [0, 2**24) is tried with check_key1 against the
    sector's keystream. `posthoc` is an optional hint for where to start.
    The scan runs from the hint to the top of the key space and then wraps
    around to [0, hint), so a wrong hint costs time but never causes a miss.

    Returns the first accepted candidate (also printed), or None if no
    candidate in the whole key space matches.
    """
    from itertools import chain

    key_space = 0x1000000
    numall = keystream_to_numall(sector_params[stream_idx]["stream"])
    start_from = 0 if posthoc is None else posthoc % key_space
    candidates = chain(range(start_from, key_space), range(start_from))
    for i in tqdm(candidates, total=key_space):
        res = check_key1(i, numall)
        if res is not None:
            print("!!!", stream_idx, hex(res))
            return res
    return None

# Start hints for recover_key1: the top byte of each sector's key1 as found
# on an earlier run (see the printed results). For a search from scratch use
# posthocs = [None] * 10 instead.
posthocs = [
    top_byte << 16
    for top_byte in (0xd5, 0x67, 0xd1, 0x20, 0x81, 0x6c, 0xef, 0x6a, 0x46, 0xfb)
]

for sector_param, posthoc in zip(sector_params, posthocs):
    sector_param["key1"] = recover_key1(sector_param["idx"], posthoc=posthoc)

  1%|▏         | 35596/2818048 [00:00<00:53, 51659.50it/s]


!!! 0 0xd58b0c


  0%|          | 30288/10027008 [00:00<00:55, 179479.58it/s]


!!! 1 0x677650


  0%|          | 10171/3080192 [00:00<00:17, 175500.32it/s]


!!! 2 0xd127bb


  0%|          | 34234/14680064 [00:00<01:20, 182841.50it/s]


!!! 3 0x2085ba


  0%|          | 8220/8323072 [00:00<00:47, 173516.35it/s]


!!! 4 0x81201c


  0%|          | 2949/9699328 [00:00<00:56, 171701.08it/s]


!!! 5 0x6c0b85


  1%|▏         | 14007/1114112 [00:00<00:06, 171306.65it/s]


!!! 6 0xef36b7


  0%|          | 37174/9830400 [00:00<00:56, 172390.26it/s]


!!! 7 0x6a9136


  0%|          | 23209/12189696 [00:00<01:06, 183615.52it/s]


!!! 8 0x465aa9


  0%|          | 428/327680 [00:00<00:01, 183161.12it/s]

!!! 9 0xfb01ac





In [5]:
from css.lfsr import LFSR

def undo_lfsr(self):
    """Rewind the LFSR by one clock, updating `self.state` in place.

    This inverts a right-shifting step that inserts the feedback bit at
    bit `size - 1` (the step used by fast_calc_lfsr_inner). The top bit holds
    the last feedback bit. Shifting the remaining bits left restores
    everything except the lost bit 0, which is recovered by XORing the
    feedback bit with the other taps.

    Expects `self.taps` to be an iterable of single-bit masks including
    bit 0. NOTE(review): assumes css.lfsr.LFSR clocks with this convention.
    """
    top = self.size - 1
    feedback = self.state >> top
    rewound = (self.state ^ (feedback << top)) << 1
    for mask in self.taps:
        feedback ^= (rewound & mask) == mask
    self.state = rewound ^ feedback

# Attach undo_lfsr as a method so instances can call lfsr.undo_lfsr().
setattr(LFSR, "undo_lfsr", undo_lfsr)

def generate_key2(stream_idx):
    """Derive the 40-bit key2 of sector `stream_idx` from its recovered key1.

    Repeats check_key1's arithmetic to get LFSR-2's state after 41 clocks,
    rewinds 41 clocks to the seed, then drops the extra bit at position 3.
    That last step inverts the seed layout used for key1:
    keep bits 0-2 and shift the rest down by one.

    Raises ValueError if key1 is None, i.e. recover_key1 found no key.
    """
    key1 = sector_params[stream_idx]["key1"]
    if key1 is None:
        raise ValueError(f"key1 for sector {stream_idx} was not recovered")
    numall = keystream_to_numall(sector_params[stream_idx]["stream"])
    seed1 = ((key1 & 0xfffff8) << 1) | 8 | (key1 & 7)
    num1 = calc_lfsr(25, seed1, 0x19e4001, stream_nsamp * 8, 1)
    num2 = (numall - num1) & ((1 << (stream_nsamp * 8)) - 1)
    head_key2 = (num2 & 0x1ffffffffff) ^ 0x1ffffffffff
    unlfsr2 = LFSR(41, head_key2, 0xfdc0000001)
    for _ in range(41):
        unlfsr2.undo_lfsr()
    lfsr_key2 = unlfsr2.state
    key2 = ((lfsr_key2 >> 1) & 0xfffffffff8) | (lfsr_key2 & 7)
    return key2

for sector_param in sector_params:
    sector_param["key2"] = generate_key2(sector_param["idx"])
    # 16 hex digits: 24-bit key1 (6 digits) followed by 40-bit key2 (10 digits).
    sector_param["key"] = f"{sector_param['key1']:06x}{sector_param['key2']:010x}"
    print(sector_param["key"])

d58b0c1a3a08361c
677650352d147b20
d127bb11efa1b9b5
2085ba6294098918
81201c6acdd7e256
6c0b85ec94e8a9f0
ef36b7f5094b82e5
6a9136427269ab4f
465aa988c4a412fe
fb01ac777938dab9


In [6]:
import z3
from css.table import table

# Known (sector nonce -> sector key) pairs from disk A, as lists of byte
# values; these are the inputs/outputs the solver below must reproduce.
mangle_ins = [list(bytes.fromhex(param["nonce"])) for param in sector_params]
mangle_outs = [list(bytes.fromhex(param["key"])) for param in sector_params]

def mix(key, value):
    """One mixing round: value ^ (value >> 8, zero-filling) ^ key.

    Uses z3.LShR, so `value` is expected to be a z3 bit-vector expression
    (64-bit as used here).
    """
    shifted_down = z3.LShR(value, 8)
    return value ^ shifted_down ^ key

def shift(value):
    """Return value XOR (value << 56).

    On 64-bit z3 bit-vectors the left shift keeps only the low byte, moved
    into the top byte. Plain Python ints are not truncated.
    """
    return value ^ (value << 56)

def build_tabulate_one(solver):
    """Declare an 8-bit -> 8-bit z3 function and pin it to `table` in `solver`.

    One equality per entry of `table` is added, so within `solver` the
    function agrees with the lookup table at every index the table covers.
    """
    tabulate_one = z3.Function("tabulate", z3.BitVecSort(8), z3.BitVecSort(8))
    solver.add(*[tabulate_one(idx) == entry for idx, entry in enumerate(table)])
    return tabulate_one

def tabulate(value, name, tabulate_one, solver):
    """Apply `tabulate_one` to each byte of a 64-bit z3 expression.

    `value` is first bound to a fresh 64-bit variable called `name` through
    an equality added to `solver` (presumably to keep the expression tree
    small). Bytes are substituted from most to least significant and
    concatenated back in that order.
    """
    bound = z3.BitVec(name, 64)
    solver.add(bound == value)
    substituted = [
        tabulate_one(z3.Extract(hi, hi - 7, bound))
        for hi in range(63, -1, -8)
    ]
    return z3.Concat(*substituted)

def u8s_to_bitecval(x):
    """Pack byte values (big-endian) into a z3 BitVecVal of 8 * len(x) bits.

    ("bitecval" is a typo for BitVecVal; the name is kept for its callers.)
    """
    as_bytes = bytes(x)
    as_int = int(as_bytes.hex(), 16)
    return z3.BitVecVal(as_int, len(x) * 8)

s = z3.Solver()

tabulate_one = build_tabulate_one(s)

key = z3.BitVec("key", 64)

# Low 32 bits pinned to the value from an earlier full solve; without this
# hint the solve takes about 30 minutes. Remove the line to solve blind.
s.add(z3.Extract(31, 0, key) == 0xc348b1fb)

# Symbolically replay mangle(key, nonce) for every known (nonce -> sector
# key) pair and require it to hit the known sector key.
for idx, (mangle_in, mangle_out) in enumerate(zip(mangle_ins, mangle_outs)):
    value = u8s_to_bitecval(mangle_in)
    goal = u8s_to_bitecval(mangle_out)

    value = mix(key, value)
    value = shift(value)
    value = mix(key, value)
    value = shift(value)
    value = mix(key, value)
    value = tabulate(value, f"one_{idx}", tabulate_one, s)
    value = shift(value)
    value = mix(key, value)
    value = tabulate(value, f"two_{idx}", tabulate_one, s)
    value = shift(value)
    value = mix(key, value)
    value = shift(value)
    value = mix(key, value)

    s.add(value == goal)

check_result = s.check()
print(check_result)
# s.model() is only meaningful after a sat result; fail loudly otherwise.
if check_result != z3.sat:
    raise RuntimeError(f"no disk key found: solver returned {check_result}")

disk_a_key_hex = f"{s.model()[key].as_long():016x}"
print(disk_a_key_hex)

sat
54e50074c348b1fb


In [7]:
@jit(nopython=True)
def clock_byte_lfsr(size, state, taps):
    """Clock a right-shifting `size`-bit LFSR 8 times.

    Each clock computes a feedback bit from `taps`: for each mask, XOR in 1
    when all of its bits are set in the state (QLFSR passes single-bit
    masks). The state is shifted right and the feedback bit inserted at bit
    `size - 1`. The eight feedback bits form the output byte, first bit in
    the LSB. Unlike fast_calc_lfsr_inner, no output flip is applied.

    Returns (new_state, byte).
    """
    byte = 0
    for bitpos in range(8):
        bit = 0
        for tap in taps:
            bit ^= (state & tap) == tap
        state = (state >> 1) | (bit << (size - 1))
        byte |= bit << bitpos
    return state, byte

class QLFSR:
    """LFSR whose byte-at-a-time clocking runs in numba (clock_byte_lfsr).

    Exposes `size`, `state`, `taps` (a numba typed List of single-bit masks)
    and next_byte(). NOTE(review): assumes this is the interface css.cipher
    expects from its LFSR class.
    """

    def __init__(self, size, seed, taps):
        self.size = size
        self.state = seed
        self.taps = List()
        # Keep only tap bits inside the register width, lowest bit first.
        for mask in (1 << bit for bit in range(size)):
            if taps & mask:
                self.taps.append(mask)

    def next_byte(self):
        """Clock 8 times and return the output byte (first bit in the LSB)."""
        new_state, out = clock_byte_lfsr(self.size, self.state, self.taps)
        self.state = new_state
        return out
    
import css.cipher

# Rebind css.cipher's module-level `LFSR` name to QLFSR.
# NOTE(review): this only affects Cipher objects if css.cipher looks up
# `LFSR` at call time, which the speed-up here relies on.
setattr(css.cipher, "LFSR", QLFSR)

In [8]:
disk_a_key = bytes.fromhex(disk_a_key_hex)
with open("./disks/a.disk", "rb") as f:
    disk_a_key_enc = f.read(8)
# Disk A's first 8 bytes XOR its known key gives the pad; applying the same
# pad to disk B's first 8 bytes yields disk B's key.
# NOTE(review): assumes both disks use the same 8-byte pad.
player_key_xorpad = bxor(disk_a_key, disk_a_key_enc)

# Context managers guarantee both handles are closed and diskB.iso's final
# buffered writes are flushed when the loop ends.
with open("./disks/b.disk", "rb") as f, open("./diskB.iso", "wb") as outfile:
    disk_b_key_enc = f.read(8)
    disk_b_key = bxor(player_key_xorpad, disk_b_key_enc)

    disk_key = disk_b_key
    print(disk_key.hex())

    # Skip the 8 * 128-byte header, then walk (8-byte nonce, 8208-byte
    # ciphertext) sector records until EOF.
    f.seek(8 * 128)
    for sector_index in tqdm(count()):
        sector_nonce = f.read(8)
        if len(sector_nonce) == 0:
            break
        sector_key = mangle(disk_key, sector_nonce)
        sector_cipher = Cipher(sector_key, Mode.Data)
        data = sector_cipher.decrypt(f.read(8208))

        # The first 16 decrypted bytes are a pad repeated over the rest.
        sector_xorpad = data[:16]
        sector_buf = data[16:]

        sector_pln = bxor(sector_buf, cycle(sector_xorpad))
        outfile.write(sector_pln)

11fe4f04a729a3e6


3879it [01:41, 38.22it/s]
