In [None]:
import torch
import torch.nn as nn
import numpy as np

from datasets import load_dataset
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import DataLoader, TensorDataset, RandomSampler

# Seed the global NumPy and PyTorch random number generators with a fixed value.
np.random.seed(42)
torch.manual_seed(42)

In [None]:
# Experiment configuration read by the later cells.
model_type = "medium" # tiny, mini, small, medium, base, large (key into the checkpoint table)
task_name = "mnli" # cola, mnli, qnli, qqp (GLUE task name; key into the task table)
learning_rate = 2e-05 # in the visible cells only the commented-out train_model call uses this
epochs = 3 # in the visible cells only the commented-out train_model call uses this
batch_size = 32 # batch size passed to every create_dataloader call
max_length = 128 # tokenizer truncation / padding length

In [None]:
# Hugging Face checkpoint id for each supported BERT size.
checkpoint_by_size = {
    'tiny': 'google/bert_uncased_L-2_H-128_A-2',
    'mini': 'google/bert_uncased_L-4_H-256_A-4',
    'small': 'google/bert_uncased_L-4_H-512_A-8',
    'medium': 'google/bert_uncased_L-8_H-512_A-8',
    'base': 'google/bert_uncased_L-12_H-768_A-12',
    'large': 'bert-large-uncased'
}

model_name = checkpoint_by_size[model_type]
tokenizer = BertTokenizer.from_pretrained(model_name)


def make_tokenize_fn(*text_columns):
    """Build a batch tokenizer for the given text column(s).

    The returned callable passes the named columns of a batch to the global
    `tokenizer` positionally, truncating and padding every example to the
    global `max_length`.
    """
    def tokenize_batch(data):
        texts = (data[column] for column in text_columns)
        return tokenizer(*texts, truncation=True, max_length=max_length, padding='max_length')

    return tokenize_batch


# Per-task settings: label count, the split used for evaluation, and how to tokenize a batch.
task_settings = {
    "qnli": {
        "num_labels": 2,
        "test_dataset_name": "validation",
        "tokenize": make_tokenize_fn('question', 'sentence'),
    },
    "mnli": {
        "num_labels": 3,
        "test_dataset_name": "validation_matched",
        "tokenize": make_tokenize_fn('premise', 'hypothesis'),
    },
    "qqp": {
        "num_labels": 2,
        "test_dataset_name": "validation",
        "tokenize": make_tokenize_fn('question1', 'question2'),
    },
    "cola": {
        "num_labels": 2,
        "test_dataset_name": "validation",
        "tokenize": make_tokenize_fn('sentence'),
    },
}

task = task_settings[task_name]
saved_path = f'../ignore/task/bert-{model_type}_{task_name}.pt'
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=task["num_labels"])

In [None]:
# Load the GLUE dataset for the selected task.
dataset = load_dataset('glue', task_name)
# Task-specific tokenization callable chosen in the setup cell.
tokenize = task['tokenize']
# Tokenize every split in batches of up to 5000 examples per call.
tokenized_dataset = dataset.map(tokenize, batched=True, batch_size = 5000)

def create_dataloader(dataset, batch_size, random=False):
    """Wrap the tokenized columns of `dataset` in a DataLoader.

    Args:
        dataset: Mapping-like object whose 'input_ids', 'attention_mask' and
            'label' entries can be converted with `torch.tensor`.
        batch_size: Number of examples per batch.
        random: When True, examples are drawn through a RandomSampler;
            otherwise they are served in their stored order.

    Returns:
        A DataLoader yielding (input_ids, attention_mask, labels) tuples.
    """
    columns = ('input_ids', 'attention_mask', 'label')
    tensor_dataset = TensorDataset(*(torch.tensor(dataset[column]) for column in columns))
    # Build the sampler only when shuffling is requested: the sequential path
    # never uses it, and constructing it eagerly can fail on an empty split
    # with PyTorch releases whose RandomSampler rejects empty datasets.
    sampler = RandomSampler(tensor_dataset) if random else None
    return DataLoader(tensor_dataset, batch_size=batch_size, sampler=sampler)

# One evaluation split feeds two loaders: its first tenth becomes the
# validation loader and the remainder becomes the test loader.
eval_split = tokenized_dataset[task['test_dataset_name']]
num_validation = len(eval_split) // 10

train_data_loader = create_dataloader(tokenized_dataset['train'], batch_size, random=True)
validation_data_loader = create_dataloader(eval_split[:num_validation], batch_size)
test_data_loader = create_dataloader(eval_split[num_validation:], batch_size)

In [None]:
import sys
import os
from copy import deepcopy

sys.path.append(os.path.relpath(".."))
sys.path.append(os.path.relpath("."))

from utils import train_model, evaluate_model, draw_activation, draw_weight, replace_modules
from smooth_quant import quantize_per_tensor_asymmetric, quantize_per_tensor_symmetric, get_act_scales, FakeQuantLinear, SmoothQuantLinear

In [None]:
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
# The checkpoint was presumably produced by running these two lines instead of the load:
#     train_model(model, train_data_loader, epochs=epochs, lr=learning_rate)
#     torch.save(model, saved_path)
model = torch.load(saved_path).to('cuda').eval()
evaluate_model(model, test_data_loader, multiple_classes=task['num_labels'] > 2)

In [None]:
# Gather activation scales for the BERT encoder from the validation batches.
# smooth_quantize reads `act_scales` as a global and iterates it as a mapping
# from dotted module names (relative to model.bert) to scales.
act_scales = get_act_scales(model.bert, validation_data_loader)

In [None]:
# Move the model to the CPU; the deep copies made by smooth_quantize / fake_quantize below then start out on the CPU as well.
model.to('cpu')

def smooth_quantize(model, method, better=False):
    """Return a deep copy of `model` with its calibrated layers quantized.

    Every dotted module name found in the global `act_scales` is resolved
    under the copy's `.bert` attribute and replaced by a wrapper layer.

    Args:
        model: Model exposing a `.bert` submodule; it is deep-copied first and
            the replacements are applied to the copy.
        method: Quantization function forwarded to the wrappers as
            `quantization_method`.
        better: When True, layers whose name contains 'output.dense' or
            'intermediate.dense' (but not 'attention.output.dense') are wrapped
            in FakeQuantLinear; all other layers get SmoothQuantLinear with
            their activation scale.

    Returns:
        The quantized copy.
    """
    quantized = deepcopy(model)

    for layer_name, scale in act_scales.items():
        *parent_path, leaf = layer_name.split(".")

        # Walk down to the module that owns the layer being replaced.
        parent = quantized.bert
        for attr in parent_path:
            parent = getattr(parent, attr)
        original_layer = getattr(parent, leaf)

        # 'output.dense' is also a substring of 'attention.output.dense', so
        # the attention projection has to be excluded explicitly.
        is_attention_output = 'attention.output.dense' in layer_name
        is_feed_forward = 'output.dense' in layer_name or 'intermediate.dense' in layer_name

        if better and is_feed_forward and not is_attention_output:
            replacement = FakeQuantLinear(original_layer, quantization_method=method)
        else:
            replacement = SmoothQuantLinear(original_layer, scale, quantization_method=method)
        setattr(parent, leaf, replacement)

    return quantized


def fake_quantize(model, method):
    """Return a deep copy of `model` processed by `replace_modules` for nn.Linear.

    `replace_modules` receives the copy, the nn.Linear type and a factory that
    wraps a layer in FakeQuantLinear using `method` as `quantization_method`;
    presumably it swaps every nn.Linear for the wrapped version.
    """
    def wrap_linear(linear):
        return FakeQuantLinear(linear, quantization_method=method)

    quantized = deepcopy(model)
    replace_modules(quantized, nn.Linear, wrap_linear)
    return quantized

In [None]:
# Build every quantized variant before any evaluation starts.
smooth_model_sym = smooth_quantize(model, quantize_per_tensor_symmetric)
smooth_model_asym = smooth_quantize(model, quantize_per_tensor_asymmetric)
smooth_model_asym_better = smooth_quantize(model, quantize_per_tensor_asymmetric, better=True)
fq_model_sym = fake_quantize(model, quantize_per_tensor_symmetric)
fq_model_asym = fake_quantize(model, quantize_per_tensor_asymmetric)


def evaluate_on_gpu(title, candidate):
    """Print `title`, move `candidate` to the GPU and evaluate it on the test loader."""
    print(title)
    candidate.to('cuda')
    return evaluate_model(candidate, test_data_loader, multiple_classes=task['num_labels'] > 2)


evaluate_on_gpu("original model", model)
evaluate_on_gpu("smooth quantized model (asymmetric)", smooth_model_asym)
evaluate_on_gpu("smooth quantized model (asymmetric better)", smooth_model_asym_better)
evaluate_on_gpu("fake quantized model (asymmetric)", fq_model_asym)
evaluate_on_gpu("smooth quantized model (symmetric)", smooth_model_sym)
# Left as the cell's final expression so any returned value is still displayed.
evaluate_on_gpu("fake quantized model (symmetric)", fq_model_sym)

In [None]:
def get_diff_inputs(target_model, ref_model, data_loader, device='cuda'):
    """Collect the examples on which two models predict different classes.

    Args:
        target_model: Model under inspection; called with `input_ids`,
            `attention_mask` and `labels` keyword arguments and expected to
            return an object with a `.logits` attribute.
        ref_model: Reference model, called the same way.
        data_loader: Iterable of batches whose first three items are the
            input ids, attention mask and labels.
        device: Device both models and every batch are moved to.

    Returns:
        A list of dicts with 'input_ids' and 'attention_mask' entries, one per
        disagreeing example, each keeping a leading batch dimension of 1.
    """
    target_model.to(device)
    ref_model.to(device)

    diff_inputs = []

    with torch.no_grad():
        for batch in data_loader:
            batch_inputs = tuple(t.to(device) for t in batch)
            inputs = {
                'input_ids': batch_inputs[0],
                'attention_mask': batch_inputs[1],
                'labels': batch_inputs[2]
            }

            target_outputs = target_model(**inputs)
            # Use the reference that was passed in; the previous code called the
            # notebook-level `model` here and silently ignored `ref_model`.
            ref_outputs = ref_model(**inputs)

            target_predictions = target_outputs.logits.detach().cpu().numpy().argmax(axis=1)
            ref_predictions = ref_outputs.logits.detach().cpu().numpy().argmax(axis=1)

            for idx in np.where(target_predictions != ref_predictions)[0]:
                diff_inputs.append({
                    # unsqueeze(0) restores the batch dimension so the dict can
                    # be fed straight back into a model.
                    'input_ids': batch_inputs[0][idx].unsqueeze(0),
                    'attention_mask': batch_inputs[1][idx].unsqueeze(0)
                })

    return diff_inputs

In [None]:
# Validation examples on which the asymmetric smooth-quantized model and the original model predict different classes.
diff_inputs = get_diff_inputs(smooth_model_asym, model, validation_data_loader)

In [None]:
# Shape check on the first collected example; raises IndexError when no predictions differed.
print(diff_inputs[0]['input_ids'].shape)

In [None]:
import functools

# The hooks call the original model's layers on the quantized model's inputs, so the original must sit on the GPU too.
model.to('cuda')
# Loss above which mse_from_origin prints and records a layer.
error_threshold = 0.005
# Forward-hook handles appended by register_hooks.
hooks = []
# (layer name, hook input tuple) pairs appended by mse_from_origin.
danger_ins = []
# Not written to anywhere in the visible cells.
danger_outs = []

def mse_from_origin(m, x, y, name, origin_model):
    """Forward hook that flags a layer whose output drifts from its reference.

    Feeds the hooked layer's first input (with dim 0 squeezed) through
    `origin_model` and takes the MSE against `y[0]`. When the loss exceeds the
    global `error_threshold`, the layer name and loss are printed and
    `(name, x)` is appended to the global `danger_ins`.

    Args:
        m: Hooked module (unused).
        x: Tuple of positional inputs given to the hooked module.
        y: Output of the hooked module.
        name: Label used when reporting the layer.
        origin_model: Reference layer to compare against.

    Returns:
        None, so the hooked module's output is left unchanged.
    """
    # assumes batches of size 1, so x[0].squeeze(0) and y[0] line up -- TODO confirm
    reference = origin_model(x[0].squeeze(0))
    loss = nn.functional.mse_loss(y[0], reference)
    if not (loss > error_threshold):
        return
    print(name, loss)
    danger_ins.append((name, x))
        
def register_hooks(model, origin_model, module_name):
    """Attach `mse_from_origin` to every SmoothQuantLinear found under `model`.

    Walks `model` and `origin_model` in parallel by child name. Each
    SmoothQuantLinear child gets a forward hook bound to its dotted name and
    to the same-named child of `origin_model`; any other child is descended
    into. Hook handles are appended to the global `hooks` list.

    Args:
        model: Module tree to search for SmoothQuantLinear layers.
        origin_model: Module tree providing the same-named reference children.
        module_name: Dotted name prefix for the children of `model`.
    """
    for child_name, child in model.named_children():
        qualified_name = f'{module_name}.{child_name}'
        origin_child = getattr(origin_model, child_name)

        if not isinstance(child, SmoothQuantLinear):
            register_hooks(child, origin_child, qualified_name)
            continue

        hook_fn = functools.partial(mse_from_origin, name=qualified_name, origin_model=origin_child)
        hooks.append(child.register_forward_hook(hook_fn))
            
# Replay the disagreeing examples through the quantized model with the
# comparison hooks attached. The try/finally guarantees the hooks are detached
# even if a forward pass raises, so they cannot leak into later cells.
try:
    register_hooks(smooth_model_asym.bert, model.bert, 'bert')

    smooth_model_asym.to('cuda')
    smooth_model_asym.eval()

    with torch.no_grad():
        # Named `diff_input` rather than `input` to avoid shadowing the builtin.
        for diff_input in diff_inputs:
            smooth_model_asym(**diff_input)
finally:
    for hook in hooks:
        hook.remove()
    # Drop the now-dead handles so they are not removed again later.
    hooks.clear()

In [None]:
num = 5

# Print and plot at most `num` flagged layers. Slicing replaces the manual
# counter, and the loop no longer binds `input`, which used to shadow the
# builtin for the rest of the kernel session.
for layer_name, hook_inputs in danger_ins[:max(num, 0)]:
    print(layer_name)
    # hook_inputs is the positional-input tuple captured by the hook; plot its first entry.
    draw_activation(hook_inputs[0])