## Investigate effect of quantization on model predictions and performance

In [38]:
import itertools
import json
import os

import numpy as np
import pandas as pd
import torch

from utils import MODEL_TO_NAME, load_data

In [None]:
# Prefix of the "<root_dir>/results/<model name>/<precision>" directories.
root_dir = ''

# Precision settings to look for under each model's results directory.
quantization = ['4_bit_quant', '8_bit_quant', '16_bit_quant', 'full']

# Task names passed to load_data.
tasks = ['atis', 'snips', 'clinic150', 'massive']

# Model identifiers; each one is a key of MODEL_TO_NAME.
models = [
    'llama-3.1-8b-instruct',
    'gemma-2-9b-it',
    'phi-3-medium-4k-instruct',
    'mistral-7b-instruct',
]

In [44]:
# For every model, tally how often predictions failed to map to a label
# (pred == -1) at each precision that has results on disk, then correlate the
# predictions of every pair of available precisions.
#
# Produces:
#   overall_performance -- list with one DataFrame per model, one row per
#                          precision found on disk
#   overall_correlation -- DataFrame with one row per (model, precision pair)
overall_performance = []
overall_correlation = []
for model_type in models:
    print(f"Tallying model {model_type}")
    model_name = MODEL_TO_NAME[model_type]

    # Keyed by precision, in `quantization` order (dicts keep insertion
    # order). Precisions without a results directory are simply absent, so a
    # model with no results at all yields no pairs instead of a KeyError.
    entries_by_precision = {}
    for precision in quantization:
        # os.path.join keeps the path relative when root_dir is empty; plain
        # string formatting would produce the absolute path "/results/...".
        results_dir = os.path.join(root_dir, 'results', model_name, precision)
        if not os.path.exists(results_dir):
            continue

        precision_entry = {
            'model': model_type,
            'quantization': precision,
        }

        task_preds = []
        for task in tasks:
            _, preds, _, _, _ = load_data(task, model_type, model_name, precision)
            task_preds.append(preds)  # failed preds are -1
            precision_entry[f"{task}-preds"] = preds
            precision_entry[f"{task}-fr"] = (preds == -1).float().mean()

        all_preds = torch.concat(task_preds)
        failed_map_rate = (all_preds == -1).float().mean()
        print(f"{precision} failed mapping rate: {failed_map_rate:.4f}")
        precision_entry['failed_map_rate'] = failed_map_rate
        precision_entry['all_preds'] = all_preds

        entries_by_precision[precision] = precision_entry

    # Every unordered pair of available precisions, in the same order a
    # nested "i < j" index loop would visit them.
    for precision_a, precision_b in itertools.combinations(entries_by_precision, 2):
        entry_a = entries_by_precision[precision_a]
        entry_b = entries_by_precision[precision_b]
        correlation_entry = {
            'model': model_type,
            'precision_a': precision_a,
            'precision_b': precision_b,
        }

        # Per-task failure rates and prediction correlation.
        for task in tasks:
            preds_a = entry_a[f"{task}-preds"]
            preds_b = entry_b[f"{task}-preds"]
            correlation_entry[f"{task}_fr_a"] = entry_a[f"{task}-fr"].item()
            correlation_entry[f"{task}_fr_b"] = entry_b[f"{task}-fr"].item()
            correlation_entry[f"{task}-corr-ab"] = np.corrcoef(preds_a, preds_b)[0, 1]

        correlation_entry['failed_map_rate_a'] = entry_a['failed_map_rate'].item()
        correlation_entry['failed_map_rate_b'] = entry_b['failed_map_rate'].item()

        # Correlation over all tasks pooled together: once on every sample
        # (failed predictions included as -1) and once restricted to samples
        # where both precisions produced a mapped prediction.
        preds_a = entry_a['all_preds']
        preds_b = entry_b['all_preds']
        both_mapped = (preds_a != -1) & (preds_b != -1)
        correlation_entry['ovr_corr'] = np.corrcoef(preds_a, preds_b)[0, 1]
        correlation_entry['suc_corr'] = np.corrcoef(preds_a[both_mapped], preds_b[both_mapped])[0, 1]

        overall_correlation.append(correlation_entry)

    overall_performance.append(pd.DataFrame(list(entries_by_precision.values())))
    print()

overall_correlation = pd.DataFrame(overall_correlation)

Tallying model llama-3.1-8b-instruct
4_bit_quant failed mapping rate: 0.0056
8_bit_quant failed mapping rate: 0.0055
full failed mapping rate: 0.0047

Tallying model gemma-2-9b-it
4_bit_quant failed mapping rate: 0.0145
8_bit_quant failed mapping rate: 0.0123
full failed mapping rate: 0.0118

Tallying model phi-3-medium-4k-instruct
4_bit_quant failed mapping rate: 0.1016
8_bit_quant failed mapping rate: 0.0077
16_bit_quant failed mapping rate: 0.0076

Tallying model mistral-7b-instruct
4_bit_quant failed mapping rate: 0.0519
8_bit_quant failed mapping rate: 0.0608
full failed mapping rate: 0.0607



In [45]:
# Export the pairwise precision-correlation table to an Excel workbook,
# leaving out the DataFrame index column.
overall_correlation.to_excel('quantization_effects.xlsx', index=False)

In [21]:
# Compare the 4-bit and 8-bit predictions of a single model on every task.
model_type = 'llama-3.1-8b-instruct'
model_name = 'Meta-Llama-3.1-8B-Instruct'

for task in tasks:
    l4, p4, t4, m4, _ = load_data(task, model_type, model_name, quantization[0])
    l8, p8, t8, m8, _ = load_data(task, model_type, model_name, quantization[1])
    # To add the full-precision run (load_data returns five values):
    # lf, pf, tf, mf, _ = load_data(task, model_type, model_name, 'full')

    # Sanity check, labels should all be the same. Done inside the loop so
    # every task is verified; after the loop it would only see the last task.
    assert (l4 == l8).all(), f"{task}: labels differ between 4bit and 8bit runs"
    # assert (l4 == lf).all(), f"{task}: labels differ between 4bit and full runs"

    compare_predictions(p4, p8, f"{task} 4bit vs 8bit")
    # compare_predictions(p8, pf, f"{task} 8bit vs full")
    # compare_predictions(p4, pf, f"{task} 4bit vs full")

[ATIS]
0.8037 average equivalent prediction atis 4bit vs 8bit
0.6680 pearsons r
[SNIPS]
0.8061 average equivalent prediction snips 4bit vs 8bit
0.7977 pearsons r
[CLINIC150]
0.8938 average equivalent prediction clinic150 4bit vs 8bit
0.8903 pearsons r
[MASSIVE]
0.8001 average equivalent prediction massive 4bit vs 8bit
0.8672 pearsons r
