In [None]:
import h5py
import numpy as np

with h5py.File('../datasets/ATMega8515_raw_traces.h5', 'r') as f:
    N = 5000 
    traces = f['traces'][:N].astype(np.int16)
    
    metadata = f['metadata'][:N]
    plaintexts = metadata['plaintext'][:N]   # shape: (N, 16)
    keys = metadata['key'][:N]              # shape: (N, 16)

    ciphertexts = metadata['ciphertext'][:N] if 'ciphertext' in metadata.dtype.names else None
    masks = metadata['masks'][:N] if 'masks' in metadata.dtype.names else None

print("Plaintext shape:", plaintexts.shape)   
print("Ciphertext shape:", None if ciphertexts is None else ciphertexts.shape)
print("Trace shape:", traces.shape)           
print("Masks shape:", None if masks is None else masks.shape)
print("Key shape:", keys.shape)  

Plaintext shape: (5000, 16)
Ciphertext shape: (5000, 16)
Trace shape: (5000, 100000)
Masks shape: (5000, 16)
Key shape: (5000, 16)


In [2]:
import scalib.modeling
import scalib.attacks
import scalib.postprocessing

SBOX = np.array(
    [
        0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB,
        0x76, 0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4,
        0x72, 0xC0, 0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 0x34, 0xA5, 0xE5, 0xF1, 0x71,
        0xD8, 0x31, 0x15, 0x04, 0xC7, 0x23, 0xC3, 0x18, 0x96, 0x05, 0x9A, 0x07, 0x12, 0x80, 0xE2,
        0xEB, 0x27, 0xB2, 0x75, 0x09, 0x83, 0x2C, 0x1A, 0x1B, 0x6E, 0x5A, 0xA0, 0x52, 0x3B, 0xD6,
        0xB3, 0x29, 0xE3, 0x2F, 0x84, 0x53, 0xD1, 0x00, 0xED, 0x20, 0xFC, 0xB1, 0x5B, 0x6A, 0xCB,
        0xBE, 0x39, 0x4A, 0x4C, 0x58, 0xCF, 0xD0, 0xEF, 0xAA, 0xFB, 0x43, 0x4D, 0x33, 0x85, 0x45,
        0xF9, 0x02, 0x7F, 0x50, 0x3C, 0x9F, 0xA8, 0x51, 0xA3, 0x40, 0x8F, 0x92, 0x9D, 0x38, 0xF5,
        0xBC, 0xB6, 0xDA, 0x21, 0x10, 0xFF, 0xF3, 0xD2, 0xCD, 0x0C, 0x13, 0xEC, 0x5F, 0x97, 0x44,
        0x17, 0xC4, 0xA7, 0x7E, 0x3D, 0x64, 0x5D, 0x19, 0x73, 0x60, 0x81, 0x4F, 0xDC, 0x22, 0x2A,
        0x90, 0x88, 0x46, 0xEE, 0xB8, 0x14, 0xDE, 0x5E, 0x0B, 0xDB, 0xE0, 0x32, 0x3A, 0x0A, 0x49,
        0x06, 0x24, 0x5C, 0xC2, 0xD3, 0xAC, 0x62, 0x91, 0x95, 0xE4, 0x79, 0xE7, 0xC8, 0x37, 0x6D,
        0x8D, 0xD5, 0x4E, 0xA9, 0x6C, 0x56, 0xF4, 0xEA, 0x65, 0x7A, 0xAE, 0x08, 0xBA, 0x78, 0x25,
        0x2E, 0x1C, 0xA6, 0xB4, 0xC6, 0xE8, 0xDD, 0x74, 0x1F, 0x4B, 0xBD, 0x8B, 0x8A, 0x70, 0x3E,
        0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E, 0xE1,
        0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF,
        0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB,
        0x16,
    ],
    dtype=np.uint16,
)

In [3]:
selected_plaintexts = plaintexts[:, 2:].astype(np.uint16)
selected_keys = keys[:, 2:].astype(np.uint16)
selected_masks = masks[:, 2:16].astype(np.uint16)

rin = np.zeros((selected_masks.shape[0],), dtype=np.uint16)
rout = np.zeros((selected_masks.shape[0],), dtype=np.uint16)

x0 = selected_keys ^ selected_plaintexts ^ selected_masks
x1 = selected_masks
xrin = ((selected_keys ^ selected_plaintexts).T ^ rin).T
y0 = SBOX[selected_keys ^ selected_plaintexts] ^ selected_masks
y1 = selected_masks
yrout = (SBOX[(selected_keys ^ selected_plaintexts).T] ^ rout).T
labels = {}
for i in range(14):
    labels[f"k_{i}"] = selected_keys[:, i]
    labels[f"p_{i}"] = selected_plaintexts[:, i]
    labels[f"x0_{i}"] = x0[:, i]
    labels[f"x1_{i}"] = x1[:, i]
    labels[f"y0_{i}"] = y0[:, i]
    labels[f"y1_{i}"] = y1[:, i]
    labels[f"xrin_{i}"] = xrin[:, i]
    labels[f"yrout_{i}"] = yrout[:, i]
labels[f"rout"] = rout[:]
labels[f"rin"] = rin[:]

In [4]:
from tqdm import tqdm
from scalib.metrics import SNR
NBYTES = 14
def target_variables(byte):
    """variables that will be profiled"""
    return ["rin", "rout"] + [
            f"{base}_{byte}" for base in ("x0", "x1", "xrin", "yrout", "y0", "y1")
            ]
snrs = {v: dict() for i in range(NBYTES) for v in target_variables(i)}
for v, m in tqdm(snrs.items(), total=len(snrs), desc="SNR Variables"):
    snr = SNR(nc=256)
    x = labels[v].astype(np.uint16).reshape((5000, 1))
        # Note: if the traces do not fit in RAM, you can call multiple times fit_u
        # on the same SNR object to do incremental SNR computation.
    snr.fit_u(traces, x)
    m["SNR"] = snr.get_snr()[0, :]
        # Avoid NaN in case of scope over-range
    np.nan_to_num(m["SNR"], nan=0.0)


NR Variables: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 86/86 [00:09<00:00,  9.07it/s]

In [None]:
models = dict()
vs = []
pois = []
ncs = []
ps = []

for k, m in snrs.items():
    poi = np.argsort(m["SNR"])[-512:].astype(np.uint32)
    poi.sort()

    values = labels[k]
    unique_classes = np.unique(values)

    if len(unique_classes) >= 100:
        vs.append(k)
        pois.append(poi)
        ncs.append(256)
        ps.append(8)
        models[k] = {"poi": poi}

if not vs:
    raise ValueError("No variables with enough class diversity found.")

x = np.array([labels[v] for v in vs]).T

mlda = scalib.modeling.MultiLDA(ncs=ncs, ps=ps, pois=pois)
mlda.fit_u(traces, x)
mlda.solve()

for lda, v in zip(mlda.ldas, vs):
    models[v]["lda"] = lda


In [None]:
import copy
import collections
from scalib.attacks import FactorGraph, BPState
SASCA_GRAPH = """
NC 256
TABLE sbox

VAR MULTI x0
VAR MULTI x1
VAR MULTI x
VAR MULTI xp
VAR MULTI xrin

VAR MULTI y0
VAR MULTI y1
VAR MULTI y
VAR MULTI yp
VAR MULTI yrout

PUB MULTI p
VAR SINGLE k

PROPERTY x = p ^ k
PROPERTY x = x0 ^ x1
PROPERTY x = xrin

PROPERTY y = sbox[x]
PROPERTY y = y0 ^ y1
PROPERTY y = yrout

"""

def sasca_graph():
    sasca = scalib.attacks.FactorGraph(SASCA_GRAPH,{"sbox": SBOX.astype(np.uint32)})
    return sasca

    
def attack(traces, labels, models):
    """Run a SASCA attack using new scalib API and evaluate key distribution."""
    secret_key = [int(labels[f"k_{i}"][0]) for i in range(NBYTES)]
    key_distribution = []

    for i in range(NBYTES):
        sasca = copy.deepcopy(sasca_graph())
        
        # Set the labels for the plaintext byte
        p = labels[f"p_{i}"].astype(np.uint32)
        bp = BPState(sasca, traces.shape[0], {"p": p})

        # Set the evidence for target variables
        for var in target_variables(i): 
            if var not in models:
                
                continue
            var_name = var.split('_')[0]
            model = models[var]
            poi = model["poi"]
            lda = model["lda"]
            prs = lda.predict_proba(traces[:, poi])
            prs = prs / prs.sum(axis=1, keepdims=True)  # Normalize
            bp.set_evidence(var_name, prs)

        # Run belief propagation
        bp.bp_loopy(it=5, initialize_states=True)

        # Get the key distribution
        distri = bp.get_distribution("k")
        key_distribution.append(distri)  # Assuming distri is (1, 256)
        
    key_distribution = np.array(key_distribution)
    return secret_key, key_distribution

    
def run_attack_eval(traces, labels, models):
    """Run a SASCA attack on the given traces and evaluate its performance.
    Returns the log2 of the rank of the true key.
    """
    secret_key, key_distribution = attack(traces, labels, models)
    if isinstance(secret_key, (int, np.integer, np.float64)):
        secret_key = [int(secret_key)]
    for i in tqdm(range(len(secret_key))):
        k = secret_key[i]
        probs = key_distribution[i]
        sorted_indices = np.argsort(probs)[::-1] 
        rank = np.where(sorted_indices == k)[0][0]
        print(f"[Byte {i}] True key = {k}, Rank = {rank}, Prob = {probs[k]:.5f}")
    rmin, r, rmax = scalib.postprocessing.rank_accuracy(
        -np.log2(key_distribution), secret_key, max_nb_bin=2**20
    )
    
    lrmin, lr, lrmax = (np.log2(rmin), np.log2(r), np.log2(rmax))
    return lr


In [20]:
ranks = []

for a in tqdm(range(5), desc="attacks"):
    trace_a = traces[a:a+1, :]  
    label_a = {k: val[a:a+1] for k, val in labels.items()}  
    rank = run_attack_eval(trace_a, label_a, models)
    ranks.append(rank)



  -np.log2(key_distribution), secret_key, max_nb_bin=2**20
█████| 14/14 [00:00<00:00, 7435.77it/s]

[Byte 0] True key = 224, Rank = 1, Prob = 0.03970
[Byte 1] True key = 242, Rank = 1, Prob = 0.13169
[Byte 2] True key = 114, Rank = 0, Prob = 0.28072
[Byte 3] True key = 33, Rank = 2, Prob = 0.07738
[Byte 4] True key = 254, Rank = 5, Prob = 0.08108
[Byte 5] True key = 16, Rank = 0, Prob = 0.42093
[Byte 6] True key = 167, Rank = 4, Prob = 0.04331
[Byte 7] True key = 141, Rank = 25, Prob = 0.00187
[Byte 8] True key = 74, Rank = 2, Prob = 0.10452
[Byte 9] True key = 220, Rank = 7, Prob = 0.05129
[Byte 10] True key = 142, Rank = 0, Prob = 0.45632
[Byte 11] True key = 73, Rank = 8, Prob = 0.02412
[Byte 12] True key = 4, Rank = 0, Prob = 0.84773
[Byte 13] True key = 105, Rank = 2, Prob = 0.01898



100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 15518.04it/s]
attacks:  40%|██████████████████████████████████████████████████████████                                                                                       | 2/5 [00:00<00:00, 14.71it/s]

[Byte 0] True key = 224, Rank = 0, Prob = 0.99989
[Byte 1] True key = 242, Rank = 0, Prob = 0.57223
[Byte 2] True key = 114, Rank = 0, Prob = 0.99975
[Byte 3] True key = 33, Rank = 3, Prob = 0.09972
[Byte 4] True key = 254, Rank = 5, Prob = 0.03227
[Byte 5] True key = 16, Rank = 0, Prob = 0.99183
[Byte 6] True key = 167, Rank = 2, Prob = 0.03726
[Byte 7] True key = 141, Rank = 0, Prob = 0.98412
[Byte 8] True key = 74, Rank = 0, Prob = 0.74637
[Byte 9] True key = 220, Rank = 9, Prob = 0.02369
[Byte 10] True key = 142, Rank = 4, Prob = 0.03492
[Byte 11] True key = 73, Rank = 5, Prob = 0.05674
[Byte 12] True key = 4, Rank = 0, Prob = 0.99572
[Byte 13] True key = 105, Rank = 0, Prob = 0.78124




00%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 11722.95it/s]

[Byte 0] True key = 224, Rank = 0, Prob = 0.87737
[Byte 1] True key = 242, Rank = 0, Prob = 0.91949
[Byte 2] True key = 114, Rank = 0, Prob = 0.91951
[Byte 3] True key = 33, Rank = 69, Prob = 0.00035
[Byte 4] True key = 254, Rank = 7, Prob = 0.02003
[Byte 5] True key = 16, Rank = 1, Prob = 0.32235
[Byte 6] True key = 167, Rank = 0, Prob = 0.97946
[Byte 7] True key = 141, Rank = 0, Prob = 0.72284
[Byte 8] True key = 74, Rank = 1, Prob = 0.11107
[Byte 9] True key = 220, Rank = 1, Prob = 0.15459
[Byte 10] True key = 142, Rank = 0, Prob = 0.44900
[Byte 11] True key = 73, Rank = 96, Prob = 0.00001
[Byte 12] True key = 4, Rank = 0, Prob = 0.96331
[Byte 13] True key = 105, Rank = 0, Prob = 0.95446



100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 11264.20it/s]
attacks:  80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                             | 4/5 [00:00<00:00, 14.13it/s]

[Byte 0] True key = 224, Rank = 0, Prob = 0.88095
[Byte 1] True key = 242, Rank = 0, Prob = 0.24538
[Byte 2] True key = 114, Rank = 0, Prob = 0.96185
[Byte 3] True key = 33, Rank = 10, Prob = 0.02137
[Byte 4] True key = 254, Rank = 71, Prob = 0.00008
[Byte 5] True key = 16, Rank = 0, Prob = 0.57467
[Byte 6] True key = 167, Rank = 13, Prob = 0.00965
[Byte 7] True key = 141, Rank = 7, Prob = 0.03395
[Byte 8] True key = 74, Rank = 1, Prob = 0.10705
[Byte 9] True key = 220, Rank = 0, Prob = 0.35070
[Byte 10] True key = 142, Rank = 1, Prob = 0.08610
[Byte 11] True key = 73, Rank = 0, Prob = 0.55108
[Byte 12] True key = 4, Rank = 2, Prob = 0.27272
[Byte 13] True key = 105, Rank = 0, Prob = 0.46781




00%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 10215.77it/s]

[Byte 0] True key = 224, Rank = 0, Prob = 0.44096
[Byte 1] True key = 242, Rank = 5, Prob = 0.05101
[Byte 2] True key = 114, Rank = 14, Prob = 0.00130
[Byte 3] True key = 33, Rank = 0, Prob = 0.36079
[Byte 4] True key = 254, Rank = 0, Prob = 0.75916
[Byte 5] True key = 16, Rank = 0, Prob = 0.53490
[Byte 6] True key = 167, Rank = 0, Prob = 0.23773
[Byte 7] True key = 141, Rank = 0, Prob = 0.98236
[Byte 8] True key = 74, Rank = 56, Prob = 0.00019
[Byte 9] True key = 220, Rank = 0, Prob = 0.22674
[Byte 10] True key = 142, Rank = 6, Prob = 0.02148
[Byte 11] True key = 73, Rank = 14, Prob = 0.00780
[Byte 12] True key = 4, Rank = 0, Prob = 0.98922
[Byte 13] True key = 105, Rank = 0, Prob = 0.94749



ttacks: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 13.23it/s]