In [15]:


# STEP 2: Import libraries
import tenseal as ts
import numpy as np
import pandas as pd

# STEP 3: Load data files (pre-FC features and single FC layer params)
features   = np.load("NEW/pre_fc1_features.npy")          # expected shape [N, 512]
labels     = np.load("NEW/pre_fc1_labels.npy")            # expected shape [N]
fc1_weight = pd.read_csv("NEW/fc1_folded_weight.csv", header=None).values  # expected [10, 512]
fc1_bias   = pd.read_csv("NEW/fc1_folded_bias.csv",   header=None).values.flatten()  # expected [10]

# NOTE: no explicit float64 cast is needed here; the inference loop passes
# the arrays through `.tolist()`, which yields Python floats (doubles)
# regardless of the stored dtype.

# CIFAR-10 label mapping (class index -> class name)
cifar10_labels = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                  'dog', 'frog', 'horse', 'ship', 'truck']

# Fail fast on mismatched artifacts instead of erroring (or mislabelling)
# partway through a long encrypted run.
if features.ndim != 2:
    raise ValueError(f"features must be 2-D [N, D]; got shape {features.shape}")
if labels.shape[0] != features.shape[0]:
    raise ValueError(
        f"labels has {labels.shape[0]} entries but features has "
        f"{features.shape[0]} rows"
    )
if fc1_weight.shape[1] != features.shape[1]:
    raise ValueError(
        f"fc1_weight expects {fc1_weight.shape[1]} inputs but features have "
        f"{features.shape[1]} columns"
    )
if fc1_bias.shape[0] != fc1_weight.shape[0]:
    raise ValueError(
        f"fc1_bias has {fc1_bias.shape[0]} entries but fc1_weight has "
        f"{fc1_weight.shape[0]} rows"
    )
if fc1_weight.shape[0] != len(cifar10_labels):
    raise ValueError(
        f"fc1_weight has {fc1_weight.shape[0]} output rows but there are "
        f"{len(cifar10_labels)} CIFAR-10 classes"
    )

# STEP 4: Setup TenSEAL CKKS context
# CKKS rule of thumb: the intermediate primes of the coefficient modulus
# should have the same bit size as log2(global_scale), so that the rescale
# after each multiplication brings the scale back to global_scale.
USE_HIGH_PRECISION = False
if USE_HIGH_PRECISION:
    # 2**45 scale -> 45-bit intermediate primes. With 40-bit primes the
    # scale would grow by 2**5 on every rescale.
    context = ts.context(
        ts.SCHEME_TYPE.CKKS,
        poly_modulus_degree=16384,
        coeff_mod_bit_sizes=[60, 45, 45, 45, 60]
    )
    context.global_scale = 2**45
else:
    # 2**40 scale -> 40-bit intermediate primes.
    context = ts.context(
        ts.SCHEME_TYPE.CKKS,
        poly_modulus_degree=8192,
        coeff_mod_bit_sizes=[60, 40, 40, 60]
    )
    context.global_scale = 2**40

# Galois (rotation) keys are needed for the slot summation performed by
# CKKSVector.dot in the inference loop.
context.generate_galois_keys()

# STEP 5: Degree-4 polynomial activation with manual scale control
def poly_activation_safe(enc_val, context):
    """Apply a degree-4 polynomial activation to a one-slot CKKS vector.

    Evaluates  p(x) = A*x**4 + B*x**3 + C*x**2 + D*x + E  on slot 0 of
    ``enc_val``.

    This is not a purely homomorphic evaluation. The input is decrypted, and
    its powers x, x**2, x**3, x**4 are computed in plaintext. Each power is
    re-encrypted under ``context``, and only the final weighted sum is done
    on ciphertexts. Decrypting therefore requires ``context`` to hold the
    secret key (a client-side step). The plaintext value is visible wherever
    this function runs.

    Args:
        enc_val: CKKS vector whose first slot holds the value to activate.
        context: TenSEAL context used to re-encrypt the powers.

    Returns:
        A CKKS vector, encrypted under ``context``, holding p(x).
    """
    # Coefficients from the highest degree down: x**4, x**3, x**2, x, constant.
    # The x**3 coefficient (~1.6e-17) is far below CKKS encoding precision,
    # so that term contributes essentially nothing.
    coef_x4, coef_x3, coef_x2, coef_x1, const = (
        -0.00068481,
        -1.59833239e-17,
        0.0887234775,
        0.5,
        0.738099333,
    )

    # Client-side: recover the plaintext value from slot 0.
    value = enc_val.decrypt()[0]

    # Plaintext powers [x, x**2, x**3, x**4], built by repeated multiplication.
    powers = [value]
    for _ in range(3):
        powers.append(powers[-1] * value)

    # Re-encrypt each power, in ascending order, as its own one-slot vector.
    enc_p1, enc_p2, enc_p3, enc_p4 = (ts.ckks_vector(context, [p]) for p in powers)

    # Weighted sum of encrypted powers, plus the plaintext constant term.
    return coef_x4 * enc_p4 + coef_x3 * enc_p3 + coef_x2 * enc_p2 + coef_x1 * enc_p1 + const

# STEP 6: Encrypted inference on N samples (single FC layer + poly activation)
# The FC layer (dot product + bias) runs homomorphically on the encrypted
# features. poly_activation_safe decrypts its input internally, and the final
# class scores are decrypted here to take the argmax.
#
# NOTE(review): with its current coefficients the degree-4 activation
# polynomial is not monotonic (it turns over near x ~= 9.2). Argmax over the
# activated scores can therefore differ from argmax over the raw logits when
# the largest logits fall beyond that point.
MAX_SAMPLES = 10000
N = min(MAX_SAMPLES, features.shape[0])
if N == 0:
    raise ValueError("No samples to evaluate: the features array is empty")

# Convert the layer parameters to Python lists once, rather than re-running
# `.tolist()` on every weight row for every sample inside the loop.
weight_rows = [row.tolist() for row in fc1_weight]
bias_values = [float(b) for b in fc1_bias]
if len(bias_values) != len(weight_rows):
    raise ValueError(
        f"fc1_bias has {len(bias_values)} entries but fc1_weight has "
        f"{len(weight_rows)} rows"
    )

correct = 0
results = []

for i in range(N):
    x = features[i]              # one feature row
    y_true = int(labels[i])

    # Encrypt the feature vector into a single CKKS ciphertext.
    enc_x = ts.ckks_vector(context, x.tolist())

    # Encrypted logits: for each class j, <enc_x, w_j> + b_j.
    enc_logits = []
    for w_row, b in zip(weight_rows, bias_values):
        logit = enc_x.dot(w_row)
        logit += b
        enc_logits.append(logit)

    # Apply the polynomial activation to each encrypted logit.
    enc_activated = [poly_activation_safe(logit, context) for logit in enc_logits]

    # Decrypt the activated scores and take the argmax (client-side).
    decrypted_scores = [val.decrypt()[0] for val in enc_activated]
    predicted = int(np.argmax(decrypted_scores))
    is_correct = predicted == y_true

    results.append({
        "Iteration": i + 1,
        "Actual Label": cifar10_labels[y_true],
        "Predicted/Inferenced Label": cifar10_labels[predicted],
        "Correct": is_correct
    })
    if is_correct:
        correct += 1

accuracy = (correct / N) * 100
df = pd.DataFrame(results)
# Only the last row carries the overall accuracy; the other rows stay blank.
df["Total Accuracy"] = ""
df.loc[df.index[-1], "Total Accuracy"] = f"{accuracy:.2f}%"

print(df)
print(f"\nEncrypted top-1 accuracy on {N} samples (with degree-4 poly activation): {accuracy:.2f}%")


      Iteration Actual Label Predicted/Inferenced Label  Correct  \
0             1          cat                        cat     True   
1             2         ship                       ship     True   
2             3         ship                       ship     True   
3             4     airplane                   airplane     True   
4             5         frog                       frog     True   
...         ...          ...                        ...      ...   
9995       9996         ship                       ship     True   
9996       9997          cat                        cat     True   
9997       9998          dog                        dog     True   
9998       9999   automobile                 automobile     True   
9999      10000        horse                      horse     True   

     Total Accuracy  
0                    
1                    
2                    
3                    
4                    
...             ...  
9995                 
9996   