In [10]:
from typing import List, Tuple, Iterable
import numpy as np
import random
from copy import deepcopy
import gurobipy as gp
from gurobipy import GRB

In [None]:
class ReboundModel:

    def build_model(self):
        m = gp.Model("ReboundOptimization")
        return m

    def Lay0_form(self, m):
        blocks = []
        for i in range(6):
            block_vars = [[m.addVar(vtype=GRB.BINARY, name=f"x_{i}_{r}_{c}") for c in range(4)] for r in range(4)]
            blocks.append(block_vars)

        all_vars = [blocks[i][r][c] for i in range(6) for r in range(4) for c in range(4)]
        m.addConstr(gp.quicksum(all_vars) <= 48, "max_ones")############这里添加了对L0层的约束
        State = [[[0]*4 for _ in range(4)] for _ in range(12)]
        State[0], State[1], State[2], State[3], State[6], State[7] = blocks[0], blocks[1], blocks[2], blocks[3], blocks[4], blocks[5]
        return State
    
    def SubBytes(self, State):
        return State
    
    def ShiftRows(self, State):
        new_state = []
        for block in State:
            # 如果 block 是 numpy array
            # if isinstance(block, np.ndarray):
            #     shifted_block = np.zeros_like(block)
            #     for r in range(4):
            #         shifted_block[r] = np.roll(block[r], -r)  # 左移 r 个位置
            #     new_state.append(shifted_block)
            # else:
                # 如果是 Gurobi 或 pulp 的变量矩阵，使用列表推导
            shifted_block = [[None]*4 for _ in range(4)]
            for r in range(4):
                for c in range(4):
                    shifted_block[r][c] = block[r][(c + r) % 4]
            new_state.append(shifted_block)
        return new_state
    
    def ShiftRows(self, State):
        new_state = []
        for block in State:
            # 如果 block 是 numpy array
            # if isinstance(block, np.ndarray):
            #     shifted_block = np.zeros_like(block)
            #     for r in range(4):
            #         shifted_block[r] = np.roll(block[r], -r)  # 左移 r 个位置
            #     new_state.append(shifted_block)
            # else:
                # 如果是 Gurobi 或 pulp 的变量矩阵，使用列表推导
            shifted_block = [[None]*4 for _ in range(4)]
            for r in range(4):
                for c in range(4):
                    shifted_block[r][c] = block[r][(c - r) % 4]
            new_state.append(shifted_block)
        return new_state

    def truncated_diff_column(self, m, input_col, col_name="col"):
        """
        构造单列截断差分的输出变量和约束，并统计0的数量
        如果整列输入为0，则不统计零 cell
        """
        # 输出列变量
        output_col = [m.addVar(vtype=GRB.BINARY, name=f"{col_name}_y{r}") for r in range(4)]
        
        # 计算输入列中1的数量
        i_sum = gp.quicksum(input_col)
        j_sum = gp.quicksum(output_col)
        
        # 列活跃变量
        active_col = m.addVar(vtype=GRB.BINARY, name=f"{col_name}_active")
        m.addConstr(i_sum >= active_col, name=f"{col_name}_active_lower")
        m.addConstr(i_sum <= 4 * active_col, name=f"{col_name}_active_upper")
        
        # 核心约束
        m.addConstr(i_sum + j_sum >= 5 * active_col, name=f"{col_name}_ij_constraint")
        
        # 输入全0 → 输出全0
        for y in output_col:
            m.addConstr(y <= active_col, name=f"{col_name}_output_zero_if_inactive")
        
        # 统计列中为0的 cell 数量，只统计活跃列
        zero_terms = []
        for y in output_col:
            if isinstance(y, (int, float)):
                zero_terms.append((1 - y) * active_col)  # 常数
            else:
                zero_terms.append((1 - y) * active_col)  # Gurobi变量#######为了排除掉一列都是0的情况
        zero_count = gp.quicksum(zero_terms)

        return output_col, zero_count



    def MixColumns(self, State, m):
        new_state = []
        all_cells = []  # 所有输出 cell
        total_zero_expr = 0  # 所有列的零值数量

        for b, block in enumerate(State):
            c1, c2, c3, c4 = [], [], [], []
            for i in range(4):
                input_col = [block[j][i] for j in range(4)]
                output, zero_count = self.truncated_diff_column(m, input_col, col_name=f"b{b}_col{i}")
                
                # 汇总
                total_zero_expr += zero_count
                all_cells.extend(output)

                # 按列装入
                c1.append(output[0])
                c2.append(output[1])
                c3.append(output[2])
                c4.append(output[3])

            new_block = [[c1[r], c2[r], c3[r], c4[r]] for r in range(4)]
            new_state.append(new_block)

        # 统计活跃 cell 数
        active_count = 0
        expr_terms = []
        for cell in all_cells:
            if isinstance(cell, (int, float)):
                active_count += cell
            else:
                expr_terms.append(cell)
        active_count_expr = active_count + gp.quicksum(expr_terms)

        return new_state, active_count_expr, total_zero_expr

    
    def AES(self, State, m):
        State = self.SubBytes(State)
        State = self.ShiftRows(State)
        State, active_count, zero_count = self.MixColumns(State, m)
        return State, active_count, zero_count


    def XOR(self, m, State1, State2, Prob):
        """
        XOR 两个状态：按 block → cell 对应
        - 0 ⊕ 0 = 0
        - 0 ⊕ 1 / 1 ⊕ 0 = 1
        - 1 ⊕ 1 = 新变量 z
        - 如果有 Gurobi 变量参与 → 新变量 z
        """
        new_state = []
        white_terms = []  # 收集所有 (1 - z)
        for b in range(len(State1)):  # 遍历 block
            block1 = State1[b]
            block2 = State2[b]
            new_block = [[None]*4 for _ in range(4)]
            for r in range(4):
                for c in range(4):
                    x = block1[r][c]
                    y = block2[r][c]

                    # 1. 两边都是常数
                    if isinstance(x, (int, float)) and isinstance(y, (int, float)):
                        if x == 0 and y == 0:
                            new_block[r][c] = 0
                        elif (x == 1 and y == 0) or (x == 0 and y == 1):
                            new_block[r][c] = 1
                        elif x == 1 and y == 1:
                            # 1 ⊕ 1 -> 新变量
                            z = m.addVar(vtype=GRB.BINARY, name=f"xor_{b}_{r}_{c}")
                            new_block[r][c] = z
                            white_terms.append(1 - z)
                    else:
                        # 2. 至少有一个是 Gurobi 变量
                        z = m.addVar(vtype=GRB.BINARY, name=f"xor_{b}_{r}_{c}")
                        # 线性化 XOR：z = x + y (mod 2)
                        # 等价约束：z = x + y - 2*and(x,y)
                        m.addConstr(z >= x - y, name=f"xor_lb1_{b}_{r}_{c}")
                        m.addConstr(z >= y - x, name=f"xor_lb2_{b}_{r}_{c}")
                        m.addConstr(z <= x + y, name=f"xor_ub_{b}_{r}_{c}")
                        new_block[r][c] = z

                        # 补充 w，用于检测 (x=1,y=1,z=0)
                        w = m.addVar(vtype=GRB.BINARY, name=f"white_{b}_{r}_{c}")
                        # w ≤ x, w ≤ y, w ≤ 1-z, w ≥ x+y+1-z-2
                        m.addConstr(w <= x, name=f"w_le_x_{b}_{r}_{c}")
                        m.addConstr(w <= y, name=f"w_le_y_{b}_{r}_{c}")
                        m.addConstr(w <= 1 - z, name=f"w_le_z_{b}_{r}_{c}")
                        m.addConstr(w >= x + y + (1 - z) - 2, name=f"w_ge_{b}_{r}_{c}")
                        white_terms.append(w)
            new_state.append(new_block)

        if white_terms:
            new_white = gp.quicksum(white_terms) * 8
        else:
            new_white = 0

        return new_state, Prob + new_white
    
    def forwards(self, State, m, Prob):
        new_state = []
        new_state[0],  = self.AES(State[2], m)
        new_state[0], active_count, zero_count = self.XOR(m, new_state[0], State[8], Prob)


In [None]:
rebound_model = ReboundModel()
m = rebound_model.build_model()
State = rebound_model.Lay0_form(m)
State = rebound_model.SubBytes(State)
State = rebound_model.ShiftRows(State)
State, active_count, zero_count = rebound_model.MixColumns(State, m)
State, active_count, zero_count = rebound_model.AES(State, m)
State, Prob = rebound_model.XOR(m, State, State, 0)
print(State[4])

[[<gurobi.Var *Awaiting Model Update*>, <gurobi.Var *Awaiting Model Update*>, <gurobi.Var *Awaiting Model Update*>, <gurobi.Var *Awaiting Model Update*>], [<gurobi.Var *Awaiting Model Update*>, <gurobi.Var *Awaiting Model Update*>, <gurobi.Var *Awaiting Model Update*>, <gurobi.Var *Awaiting Model Update*>], [<gurobi.Var *Awaiting Model Update*>, <gurobi.Var *Awaiting Model Update*>, <gurobi.Var *Awaiting Model Update*>, <gurobi.Var *Awaiting Model Update*>], [<gurobi.Var *Awaiting Model Update*>, <gurobi.Var *Awaiting Model Update*>, <gurobi.Var *Awaiting Model Update*>, <gurobi.Var *Awaiting Model Update*>]]


In [None]:
# 总和约束：最多48个1#######这是初始L0层的约束，参与AES的块里最多只有一半的格子活跃
all_vars = [blocks[i][r][c] for i in range(6) for r in range(4) for c in range(4)]
m.addConstr(gp.quicksum(all_vars) <= 48, "max_ones")

# 示例目标（可替换为密码学相关目标）
m.setObjective(gp.quicksum(all_vars), GRB.MAXIMIZE)

# 求解
m.optimize()

# 查看结果
for i in range(6):
    print(f"Block {i}:")
    for r in range(4):
        row = [int(blocks[i][r][c].X) for c in range(4)]
        print(row)