In [1]:
# Full code for the simulation including imports, function definitions, and the simulation loop

import numpy as np


def softmax(x):
    """Return the softmax of ``x`` taken over *all* of its elements.

    The normalisation is global (no axis argument): every element of ``x``
    shares one maximum and one sum, so a 2-D input is not normalised row by
    row. The global maximum is subtracted before exponentiating so large
    inputs do not overflow ``np.exp``; the shift cancels in the ratio.
    """
    shifted = x - np.max(x)
    weights = np.exp(shifted)
    total = weights.sum()
    return weights / total


def quantize_to_8bit(x):
    """Quantize values in [0, 1] to 8-bit integer codes in [0, 255].

    Each value is scaled by 255 and rounded to the nearest integer
    (``np.round`` rounds exact halves to even). The rounded values are
    clipped to [0, 255] before the cast. Without the clip, an out-of-range
    float cast to ``np.uint8`` is platform-dependent and typically wraps
    (e.g. 1.01 -> 258 -> 2).

    Args:
        x: Array-like or scalar, nominally in [0, 1]. Lists are converted with
            ``np.asarray`` so they are scaled element-wise rather than
            repeated by ``list * 255``.

    Returns:
        ``np.uint8`` array (or NumPy scalar) with the shape of ``x`` and every
        element in [0, 255].
    """
    # Clip *after* rounding so out-of-range inputs saturate instead of
    # overflowing the uint8 cast.
    return np.clip(np.round(np.asarray(x) * 255), 0, 255).astype(np.uint8)


def quantize_to_4bit(x):
    """Quantize values in [0, 1] to 4-bit integer codes in [0, 15].

    Each value is scaled by 15 and rounded to the nearest integer
    (``np.round`` rounds exact halves to even). The rounded values are
    clipped to [0, 15], so inputs slightly above 1 or below 0 saturate. They
    no longer produce codes such as 16 (which fit in the ``np.uint8``
    container but not in 4 bits) or wrap around on the cast.

    Args:
        x: Array-like or scalar, nominally in [0, 1]. Lists are converted with
            ``np.asarray`` so they are scaled element-wise rather than
            repeated by ``list * 15``.

    Returns:
        ``np.uint8`` array (or NumPy scalar) with the shape of ``x`` and every
        element in [0, 15].
    """
    # Clip after rounding: the codes must stay within 4 bits even though the
    # storage dtype is 8 bits wide.
    return np.clip(np.round(np.asarray(x) * 15), 0, 15).astype(np.uint8)


def normalize_to_0_1(x, max_value=255):
    """Rescale ``x`` to the unit interval by dividing it by ``max_value``.

    ``max_value`` is the largest code of the source representation; this
    file passes 255 for 8-bit codes (the default) and 15 for 4-bit codes.
    Inputs in ``[0, max_value]`` therefore map onto ``[0, 1]``. No clipping
    is applied to values outside that range.
    """
    normalized = x / max_value
    return normalized


def _requantize_to_4bit(value):
    """Scale an 8-bit-range value to [0, 1], quantize it to 4 bits, rescale to [0, 1]."""
    return normalize_to_0_1(quantize_to_4bit(normalize_to_0_1(value)), max_value=15)


def simulation_with_quantized_softmax(
    num_simulations=100000, vector_length=128, rng=None
):
    """Estimate the error of replacing a quantized softmax by a hard argmax.

    Each trial draws two random 8-bit integer vectors ``a`` and ``b`` and
    compares two ways of weighting ``b`` by ``a``:

    * soft: ``softmax(a)`` quantized to 8 bits and rescaled to [0, 1], then
      dotted with ``b``;
    * hard: ``b[argmax(a)]``, i.e. a one-hot weight on the largest entry of
      ``a`` (``np.argmax`` picks the first index on ties).

    Both scalars are scaled from the 8-bit range to [0, 1], quantized to
    4 bits and rescaled to [0, 1]; their absolute difference is the trial's
    error.

    Args:
        num_simulations: Number of independent trials; must be >= 1.
        vector_length: Length of each random vector (default 128).
        rng: Optional ``np.random.Generator`` used for sampling, e.g.
            ``np.random.default_rng(0)`` for a reproducible run. When
            ``None``, the legacy global ``np.random`` state is used, so
            ``np.random.seed`` controls the results as before.

    Returns:
        Mean absolute error over all trials, on the [0, 1] scale (multiply by
        100 for a percentage).

    Raises:
        ValueError: If ``num_simulations`` is less than 1; the mean of zero
            trials is undefined (``np.mean([])`` would return ``nan``).
    """
    if num_simulations < 1:
        raise ValueError(f"num_simulations must be >= 1, got {num_simulations}")

    # Integer draws in [0, 256) model 8-bit values.
    draw = np.random.randint if rng is None else rng.integers

    errors = np.empty(num_simulations)
    for trial in range(num_simulations):
        vector_a = draw(0, 256, vector_length)
        # Soft path: quantize the softmax weights to 8 bits, then rescale
        # them to [0, 1] before weighting vector_b.
        soft_weights = normalize_to_0_1(quantize_to_8bit(softmax(vector_a)))

        vector_b = draw(0, 256, vector_length)
        soft_result = np.sum(soft_weights * vector_b)
        # Hard path: equivalent to dotting vector_b with a one-hot vector
        # placed at argmax(vector_a).
        hard_result = vector_b[np.argmax(vector_a)]

        errors[trial] = np.abs(
            _requantize_to_4bit(soft_result) - _requantize_to_4bit(hard_result)
        )

    return np.mean(errors)


# Run the simulation with its default settings and report the mean absolute
# error (computed on the [0, 1] scale) as a percentage.
average_error_with_quantized_softmax = simulation_with_quantized_softmax()
_error_as_percent = average_error_with_quantized_softmax * 100
print("Average error with quantized softmax:", _error_as_percent, "%")

Average error with quantized softmax: 8.558466666666668 %
