<a href="https://colab.research.google.com/github/tsanoop887-hash/AIF360/blob/main/Math_Aware_INT8_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
device="cuda" if torch.cuda.is_available() else "cpu"

QUANTIZATION OPERATOR

In [None]:
def quantize_int8(x: torch.Tensor,scale: float) -> torch.Tensor:
  """
  Eq.(1) : uniform symmetric INT8
  quantization
  x_q = clip(round(x / s),-128,127)
  """
  x_q=torch.round(x / scale)
  x_q=torch.clamp(x_q , -128, 127)
  return x_q.to(torch.int8)

INT8 LINEAR LAYER (GEMM)


In [None]:
from torch._higher_order_ops.scan import stack_y
def int8_linear(x_q,w_q,s_x,s_w,s_y):
  """
  INT8XINT8 -> INT32 Accumulator -> rescale ->INT 8
  Eq.(3) : y_q = clip(round (s_x*s_w/s_y))*sum(x_q*w_q)))
  """

  acc_int32=torch.matmul(x_q.int(),w_q.int())

  scale =(s_x* s_w) /s_y
  y_fp=acc_int32.float()*scale

  y_q=torch.round(y_fp)
  y_q=torch.clamp(y_q,-128,127)

  return y_q.to(torch.int8)

APPROXIMATE SOFTMAX

In [None]:
def approx_softmax(x):
  """
  polynomial softmax approximation
  Hardware-friendly (LUT/PWL
  implementable)

  """
  x=x-x.max(dim=-1,
  keepdim=True).values
  x=torch.clamp(x,-8,0)

  # second order polynomial approx of exp(x)
  exp_x =1.0+x+0.5*x*x

  return exp_x/exp_x.sum(dim=-1,
      keepdim=True)

INT8 SELF-ATTENTION

In [None]:
def int8_self_attention(X,wq,wk,wv,scales):
  """
  Alogrithm 1:INT8 self Attenttion
  """
  s_x,s_w,s_y = scales

  X_q=quantize_int8(X,s_x)
  wq_q=quantize_int8(wq,s_w)
  wk_q=quantize_int8(wk,s_w)
  wv_q=quantize_int8(wv,s_w)

  Q_q=int8_linear(X_q,wq_q,s_x,s_w,s_y)
  K_q=int8_linear(X_q,wk_q,s_x,s_w,s_y)
  V_q=int8_linear(X_q,wv_q,s_x,s_w,s_y)


  scores=torch.matmul(Q_q,K_q.transpose(-2,-1))
  scores=scores/(Q_q.shape[-1]**0.5)

  A=approx_softmax(scores)

  Y=torch.matmul(A,V_q.float()) # Cast V_q to float to match A

  return Y

INT8 FEED-FORWARD NETWORK

In [None]:
def int8_ffn(X, W1, W2, scales):
    """
    INT8 Feed-Forward Network
    """

    s_x, s_w, s_y = scales

    X_q  = quantize_int8(X,  s_x)
    W1_q = quantize_int8(W1, s_w)
    W2_q = quantize_int8(W2, s_w)

    # First linear
    H_q = int8_linear(X_q, W1_q, s_x, s_w, s_y)
    H   = F.relu(H_q.float() * s_y)

    # Second linear
    H_q2 = quantize_int8(H, s_x)
    Y_q  = int8_linear(H_q2, W2_q, s_x, s_w, s_y)

    return Y_q.float() * s_y


INT8 TRANSFORMER BLOCK

In [None]:
class INT8TransformerBlock(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()

        self.Wq = nn.Parameter(torch.randn(d_model, d_model))
        self.Wk = nn.Parameter(torch.randn(d_model, d_model))
        self.Wv = nn.Parameter(torch.randn(d_model, d_model))
        self.W1 = nn.Parameter(torch.randn(d_model, d_ff))
        self.W2 = nn.Parameter(torch.randn(d_ff, d_model))

        # Fixed-point scales (can be RL-tuned later)
        self.scales = (0.02, 0.02, 0.02)

    def forward(self, X):
        Y = int8_self_attention(X, self.Wq, self.Wk, self.Wv, self.scales)
        X = X + Y  # Residual connection

        Z = int8_ffn(X, self.W1, self.W2, self.scales)
        return X + Z


In [None]:
# Test configuration
B, N, D = 2, 8, 32
D_FF = 64

# Random input
X = torch.randn(B, N, D)

# Model
model = INT8TransformerBlock(D, D_FF)

# Forward pass
Y = model(X)

print("Input shape :", X.shape)
print("Output shape:", Y.shape)
print("Sample output:", Y[0, 0, :5])

Input shape : torch.Size([2, 8, 32])
Output shape: torch.Size([2, 8, 32])
Sample output: tensor([ 23.9947,   1.1611,  27.3593, -43.4831,   3.0750])


SIMULATOR

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Helper functions
def quantize_int8(x: torch.Tensor,scale: float) -> torch.Tensor:
  """
  Eq.(1) : uniform symmetric INT8
  quantization
  x_q = clip(round(x / s),-128,127)
  """
  x_q=torch.round(x / scale)
  x_q=torch.clamp(x_q , -128, 127)
  return x_q.to(torch.int8)

def int8_linear(x_q,w_q,s_x,s_w,s_y):
  """
  INT8XINT8 -> INT32 Accumulator -> rescale ->INT 8
  Eq.(3) : y_q = clip(round (s_x*s_w/s_y))*sum(x_q*w_q)))
  """

  acc_int32=torch.matmul(x_q.int(),w_q.int())

  scale =(s_x* s_w) /s_y
  y_fp=acc_int32.float()*scale

  y_q=torch.round(y_fp)
  y_q=torch.clamp(y_q,-128,127)

  return y_q.to(torch.int8)

def approx_softmax(x):
  """
  polynomial softmax approximation
  Hardware-friendly (LUT/PWL
  implementable)

  """
  x=x-x.max(dim=-1,
  keepdim=True).values
  x=torch.clamp(x,-8,0)

  # second order polynomial approx of exp(x)
  exp_x =1.0+x+0.5*x*x

  return exp_x/exp_x.sum(dim=-1,
      keepdim=True)

def int8_self_attention(X,wq,wk,wv,scales):
  """
  Alogrithm 1:INT8 self Attenttion
  """
  s_x,s_w,s_y = scales

  X_q=quantize_int8(X,s_x)
  wq_q=quantize_int8(wq,s_w)
  wk_q=quantize_int8(wk,s_w)
  wv_q=quantize_int8(wv,s_w)

  Q_q=int8_linear(X_q,wq_q,s_x,s_w,s_y)
  K_q=int8_linear(X_q,wk_q,s_x,s_w,s_y)
  V_q=int8_linear(X_q,wv_q,s_x,s_w,s_y)


  scores=torch.matmul(Q_q,K_q.transpose(-2,-1))
  scores=scores/(Q_q.shape[-1]**0.5)

  A=approx_softmax(scores)

  Y=torch.matmul(A,V_q.float()) # Cast V_q to float to match A

  return Y

def int8_ffn(X, W1, W2, scales):
    """
    INT8 Feed-Forward Network
    """

    s_x, s_w, s_y = scales

    X_q  = quantize_int8(X,  s_x)
    W1_q = quantize_int8(W1, s_w)
    W2_q = quantize_int8(W2, s_w)

    # First linear
    H_q = int8_linear(X_q, W1_q, s_x, s_w, s_y)
    H   = F.relu(H_q.float() * s_y)

    # Second linear
    H_q2 = quantize_int8(H, s_x)
    Y_q  = int8_linear(H_q2, W2_q, s_x, s_w, s_y)

    return Y_q.float() * s_y

# Placeholder Metrics class (from mgZaB8R3XCjp)
class Metrics:
    def latency_cycles(self):
        return 100 # Dummy value
    def energy(self):
        return 50 # Dummy value
    def __init__(self):
        self.mac_ops = 1000 # Dummy value
        self.sram_reads = 200 # Dummy value
        self.sram_writes = 100 # Dummy value

# INT8TransformerBlock definition (from VX27jOOxyCCx and mgZaB8R3XCjp)
class INT8TransformerBlock(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()

        self.Wq = nn.Parameter(torch.randn(d_model, d_model))
        self.Wk = nn.Parameter(torch.randn(d_model, d_model))
        self.Wv = nn.Parameter(torch.randn(d_model, d_model))
        self.W1 = nn.Parameter(torch.randn(d_model, d_ff))
        self.W2 = nn.Parameter(torch.randn(d_ff, d_model))

        # Fixed-point scales (can be RL-tuned later)
        self.scales = (0.02, 0.02, 0.02)

    def forward(self, X):
        Y = int8_self_attention(X, self.Wq, self.Wk, self.Wv, self.scales)
        X = X + Y  # Residual connection

        Z = int8_ffn(X, self.W1, self.W2, self.scales)
        return X + Z

# AcceleratorSimulator definition (from J66ndkU9C0jk)
class AcceleratorSimulator:
    def __init__(self, macs_per_cycle=16, sram_per_cycle=4):
        self.macs_per_cycle = macs_per_cycle
        self.sram_per_cycle = sram_per_cycle

    def run(self, model, input_tensor):
        metrics = Metrics()
        # The INT8TransformerBlock's forward method does not accept 'metrics', so remove it.
        # This assumes the model's forward method should not be passed 'metrics'.
        _ = model(input_tensor) # Assuming model.forward(input_tensor) is correct

        return {
            "latency": metrics.latency_cycles(),
            "energy": metrics.energy(),
            "mac_ops": metrics.mac_ops,
            "sram_ops": metrics.sram_reads + metrics.sram_writes
        }

# Re-define model and X for this cell to ensure they are available
B, N, D = 2, 8, 32
D_FF = 64
X = torch.randn(B, N, D)
model = INT8TransformerBlock(D, D_FF)

# Instantiate the AcceleratorSimulator
simulator = AcceleratorSimulator(macs_per_cycle=16, sram_per_cycle=4)

# Now run the simulator with the model and input
result = simulator.run(model, X)

reward = -(
    result["latency"] +
    0.1 * result["energy"]
)


accelerator_env

In [13]:
import gym
import numpy as np

class AcceleratorEnv(gym.Env):
    def __init__(self, model, input_shape):
        super().__init__()

        self.model = model
        self.sim = AcceleratorSimulator(macs_per_cycle=16, sram_per_cycle=4) # Initialize simulator directly here
        self.x = torch.randn(*input_shape)

        # Actions:
        # scale_id ∈ {0.01, 0.02, 0.04}
        # bitwidth ∈ {8, 6}
        # tile ∈ {8, 16, 32}
        self.action_space = gym.spaces.MultiDiscrete([3, 2, 3])

        self.observation_space = gym.spaces.Box(
            low=0, high=1e6, shape=(4,), dtype=np.float32
        )

        self.layer_id = 0

    def reset(self):
        self.layer_id = 0
        return np.zeros(4, dtype=np.float32)

    def step(self, action):
        scale_id, bw_id, tile_id = action

        scale = [0.01, 0.02, 0.04][scale_id]
        bitwidth = [8, 6][bw_id]
        tile = [8, 16, 32][tile_id]

        # Apply action to model
        self.model.scales = (scale, scale, scale)

        # Simulate lower precision penalty
        accuracy_penalty = 0.02 if bitwidth == 6 else 0.0

        # Corrected call to run method
        result = self.sim.run(self.model, self.x)

        latency = result["latency"]
        energy = result["energy"]

        reward = -(
            latency +
            0.1 * energy +
            100 * accuracy_penalty
        )

        self.layer_id += 1
        done = self.layer_id >= 1  # one block per episode

        obs = np.array([
            self.layer_id,
            latency,
            energy,
            accuracy_penalty
        ], dtype=np.float32)

        return obs, reward, done, {}


In [14]:
!pip install stable_baselines3
!pip install 'shimmy>=2.0'
from stable_baselines3 import PPO

env = AcceleratorEnv(
    model=INT8TransformerBlock(32, 64),
    input_shape=(2, 8, 32)
)

agent = PPO(
    "MlpPolicy",
    env,
    verbose=1,
    learning_rate=3e-4,
    n_steps=128
)

agent.learn(total_timesteps=5000)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.




---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1        |
|    ep_rew_mean     | -106     |
| time/              |          |
|    fps             | 314      |
|    iterations      | 1        |
|    time_elapsed    | 0        |
|    total_timesteps | 128      |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 1            |
|    ep_rew_mean          | -106         |
| time/                   |              |
|    fps                  | 299          |
|    iterations           | 2            |
|    time_elapsed         | 0            |
|    total_timesteps      | 256          |
| train/                  |              |
|    approx_kl            | 3.931159e-05 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -2.89        |
|    explained_variance   | 0            |
|    learning_r

<stable_baselines3.ppo.ppo.PPO at 0x79fcfc999430>