In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint identifier shared by the tokenizer and the model so the two
# always come from the same pretrained release.
model_id = "facebook/opt-1.3b"

# The two loads do not depend on each other; the tokenizer is fetched first,
# followed by the causal-LM weights.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

config.json:   0%|          | 0.00/653 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.63G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.63G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/685 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/441 [00:00<?, ?B/s]

In [None]:
# Inputs seen by the hooked layer; one tensor is appended per forward pass.
act_list = []


def hook(model, input, output):
    """Forward hook that records the hooked module's first positional input.

    PyTorch invokes forward hooks positionally as ``hook(module, args, output)``,
    so inside this function ``model`` is the hooked module (it shadows the
    global ``model``) and ``input`` is the tuple of positional arguments given
    to that module's ``forward``. The first argument is detached from autograd
    and copied to CPU before being stored in ``act_list``.
    """
    act_list.append(input[0].detach().cpu())


# Resolve the hooked module from its dotted name (relative to ``model.model``)
# so that the label and the module that is actually hooked cannot drift apart.
layer_name = "decoder.layers.0.self_attn.q_proj"
target_layer = model.model.get_submodule(layer_name)
handle = target_layer.register_forward_hook(hook)

# Small, varied calibration set (prose, facts, poetry, code). Each entry is
# tokenized and pushed through the model on its own.
calibration_texts = [
    "The quick brown fox jumps over the lazy dog.",
    "Artificial intelligence is transforming the world at an unprecedented pace.",
    "Photosynthesis is the process by which green plants and some other organisms use sunlight to synthesize foods.",
    "To be, or not to be, that is the question.",
    "The capital of France is Paris, which is known for its cafe culture and landmarks.",
    "Deep learning models require vast amounts of data to train effectively.",
    "The theory of relativity was developed by Albert Einstein in the early 20th century.",
    "Python is a high-level, interpreted programming language known for its readability.",
    "Global warming is the long-term heating of Earth's climate system observed since the pre-industrial period.",
    "In 1492, Christopher Columbus sailed the ocean blue.",
    "The mitochondria is the powerhouse of the cell.",
    "Machine learning algorithms build a model based on sample data, known as training data.",
    "The Great Wall of China is a series of fortifications that were built across the historical northern borders of China.",
    "Quantum mechanics is a fundamental theory in physics that provides a description of the physical properties of nature at the scale of atoms.",
    "I wandered lonely as a cloud that floats on high o'er vales and hills.",
    "The internet is a global system of interconnected computer networks.",
    "Water boils at 100 degrees Celsius at standard atmospheric pressure.",
    "The history of computers dates back to the invention of the abacus.",
    "Music is an art form, and cultural activity, whose medium is sound.",
    "Democracy is a form of government in which the people have the authority to deliberate and decide legislation.",
    "The human brain is the central organ of the human nervous system.",
    "Space exploration helps us understand the universe and our place in it.",
    "Renewable energy is energy that is collected from renewable resources, which are naturally replenished.",
    "Economics is the social science that studies the production, distribution, and consumption of goods and services.",
    "Literature broadly is any collection of written work, but it is also used more narrowly for writings specifically considered to be an art form.",
    "def quicksort(arr): if len(arr) <= 1: return arr",
    "Linear algebra is central to almost all areas of mathematics.",
    "The stock market is a collection of markets and exchanges where regular activities of buying, selling, and issuance of shares of publicly-held companies take place.",
    "Coffee is a brewed drink prepared from roasted coffee beans, the seeds of berries from certain Coffea species.",
    "Optimization is the selection of a best element from some set of available alternatives."
]

# Inference only: no_grad() avoids building an autograd graph for every
# forward pass. The try/finally guarantees the hook is detached even when
# tokenization or a forward pass raises, so a failed run cannot leave a stale
# hook behind that keeps appending to act_list on later model calls.
try:
    with torch.no_grad():
        for text in calibration_texts:
            inputs = tokenizer(text, return_tensors="pt").to(model.device)
            model(**inputs)
finally:
    handle.remove()

# Flatten every captured input to [tokens, hidden_dim] and stack along the
# token axis, giving one row per calibration token.
all_acts = torch.cat(
    [act.reshape(-1, act.shape[-1]) for act in act_list], dim=0
)  # Shape: [Total Tokens, Hidden_Dim]
print(f"Total Tokens Collected: {all_acts.shape[0]}")

weight = target_layer.weight.detach().cpu()  # Shape: [Out_Dim, In_Dim]

Total Tokens Collected: 532


In [None]:
# Per-channel activation magnitude (s_x): mean of |x| taken over all
# collected tokens, leaving one value per hidden channel.
per_channel_mean = torch.mean(torch.abs(all_acts), dim=0).float()

print(f"Per Channel Mean Shape: {per_channel_mean.shape}")

# Indices of the five channels with the largest mean magnitude, largest first;
# the print below shows them together with their values.
top_indices = torch.argsort(per_channel_mean, descending=True)[:5]

print(f"Top Indices: {top_indices}, value : {per_channel_mean[top_indices]}")

Per Channel Mean Shape: torch.Size([2048])
Top Indices: tensor([1500, 1767,  306, 2011, 1202]), value : tensor([5.5607, 5.4745, 5.4582, 5.2490, 4.8360])


In [None]:
def pseudo_quantize_grouped(w, group_size=128, n_bits=4):
    """Simulate symmetric group-wise weight quantization (quantize, then dequantize).

    Each row of ``w`` is cut into consecutive groups of ``group_size`` input
    channels. Every group gets its own scale, derived from the group's largest
    absolute value, and its entries are rounded onto the signed integer grid
    ``[-qmax, qmax]`` with ``qmax = 2 ** (n_bits - 1) - 1`` before being mapped
    back to floating point.

    Args:
        w: 2-D weight tensor of shape ``[out_features, in_features]``.
        group_size: Number of consecutive input channels that share one scale.
            Must be positive and divide ``in_features`` exactly.
        n_bits: Bit width of the simulated integer grid. The default of 4
            gives ``qmax = 7``.

    Returns:
        A tensor with the same shape as ``w`` holding the dequantized weights.

    Raises:
        ValueError: If ``w`` is not 2-D, ``n_bits`` is below 2, or
            ``group_size`` is not a positive divisor of ``in_features``.
    """
    if w.dim() != 2:
        raise ValueError(
            f"expected a 2-D weight [out_features, in_features], got shape {tuple(w.shape)}"
        )
    if n_bits < 2:
        raise ValueError(f"n_bits must be at least 2, got {n_bits}")

    out_features, in_features = w.shape
    if group_size <= 0 or in_features % group_size != 0:
        raise ValueError(
            f"group_size ({group_size}) must be a positive divisor of "
            f"in_features ({in_features})"
        )
    num_groups = in_features // group_size

    # Largest magnitude representable on the symmetric signed grid.
    qmax = 2 ** (n_bits - 1) - 1

    # [Out, In] -> [Out, Num_Groups, Group_Size]
    w_grouped = w.reshape(out_features, num_groups, group_size)

    # One scale per group; the small offset keeps the scale strictly positive
    # so an all-zero group does not cause a division by zero.
    max_val = w_grouped.abs().amax(dim=-1, keepdim=True)
    scale = max_val / qmax + 1e-5

    # Quantize & Dequantize
    w_int = torch.round(w_grouped / scale)
    w_int = torch.clamp(w_int, min=-qmax, max=qmax)
    w_dequant = w_int * scale

    return w_dequant.reshape(out_features, in_features)

# Prepare the evaluation inputs: activations and weights in float32 on the
# same device.
x = all_acts.to(weight.device).float()
w_float = weight.float()

# Number of input channels that share one quantization scale.
GROUP_SIZE = 128

# Original Output
y_orig = torch.matmul(x, w_float.t())

# RTN Baseline (Alpha=0)
w_rtn = pseudo_quantize_grouped(w_float, group_size=GROUP_SIZE)
y_rtn = torch.matmul(x, w_rtn.t())
mse_rtn = (y_orig - y_rtn).pow(2).mean().item()

# NOTE(review): the non-ASCII text in the messages printed below looks
# mis-encoded (mojibake); confirm the notebook's file encoding. The literals
# are kept exactly as found.
print("\n" + "="*60)
print(f"{'Alpha':<10} | {'Output MSE':<15} | {'Improvement'}")
print("-" * 60)

best_mse = float('inf')
best_alpha = -1

# Loop-invariant part of the scale: per-channel magnitude with a small offset
# so that pow() never sees an exact zero.
base_scale = (per_channel_mean + 1e-5).to(weight.device)

for alpha in [i * 0.1 for i in range(11)]:
    # Compute the per-channel scale (s = s_x ^ alpha)
    current_scale = base_scale.pow(alpha)

    # Apply AWQ: Scale -> Quantize -> Inverse Scale
    w_scaled = w_float * current_scale
    w_q_scaled = pseudo_quantize_grouped(w_scaled, group_size=GROUP_SIZE)
    w_restored = w_q_scaled / current_scale

    # Compute the output MSE against the unquantized layer
    y_awq = torch.matmul(x, w_restored.t())
    mse = (y_orig - y_awq).pow(2).mean().item()

    # Compare with RTN. At alpha=0 the scale is exactly 1, so the result is
    # identical to the RTN baseline and must not be reported as "Worse".
    if mse < mse_rtn:
        imp_str = "Better! üéâ"
    elif mse > mse_rtn:
        imp_str = "Worse"
    else:
        imp_str = "Same"
    print(f"{alpha:<10.1f} | {mse:<15.6f} | {imp_str}")

    if mse < best_mse:
        best_mse = mse
        best_alpha = alpha

print("="*60)
print(f"Original RTN MSE : {mse_rtn:.6f}")
print(f"Best AWQ MSE     : {best_mse:.6f} (at alpha={best_alpha:.1f})")

# best_mse < mse_rtn implies mse_rtn > 0, so the relative reduction below
# never divides by zero.
if best_mse < mse_rtn:
    print(f"üöÄ ÏÑ±Í≥µÏûÖÎãàÎã§! AWQÍ∞Ä RTNÎ≥¥Îã§ ÏóêÎü¨Î•º ÏïΩ {(mse_rtn - best_mse)/mse_rtn*100:.2f}% Ï§ÑÏòÄÏäµÎãàÎã§.")
else:
    print("ü§î Ìù†, ÏïÑÏßÅÎèÑ RTNÏù¥ Îçî Ï¢ãÎã§Î©¥ Ï±ÑÎÑê ÌäπÏÑ±Ïù¥ ÎÑàÎ¨¥ ÌèâÌÉÑÌï† Ïàò ÏûàÏäµÎãàÎã§.")


Alpha      | Output MSE      | Improvement
------------------------------------------------------------
0.0        | 0.005367        | Worse
0.1        | 0.005034        | Better! üéâ
0.2        | 0.005160        | Better! üéâ
0.3        | 0.005525        | Worse
0.4        | 0.006615        | Worse
0.5        | 0.008009        | Worse
0.6        | 0.010121        | Worse
0.7        | 0.013964        | Worse
0.8        | 0.018731        | Worse
0.9        | 0.025097        | Worse
1.0        | 0.033489        | Worse
Original RTN MSE : 0.005367
Best AWQ MSE     : 0.005034 (at alpha=0.1)
üöÄ ÏÑ±Í≥µÏûÖÎãàÎã§! AWQÍ∞Ä RTNÎ≥¥Îã§ ÏóêÎü¨Î•º ÏïΩ 6.21% Ï§ÑÏòÄÏäµÎãàÎã§.
