In [21]:
from scipy.stats import ttest_ind

In [1]:
pip install transformers datasets sentencepiece --quiet

In [2]:
import json
import logging
import os
import re
import string
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader, DistributedSampler, Dataset
import numpy as np
import transformers  # noqa: E402
from transformers import AutoModel, AutoTokenizer, HfArgumentParser  # noqa: E402
from cola_process import process_data
from datasets import load_dataset
import sentencepiece

In [3]:
ds = load_dataset('linxinyuan/cola')
ds_train = pd.DataFrame(ds["train"])
ds_test = pd.DataFrame(ds["test"])



  0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
MODEL_NAME = 'microsoft/deberta-v3-base'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Downloading pytorch_model.bin:   0%|          | 0.00/371M [00:00<?, ?B/s]

Some weights of the model checkpoint at microsoft/deberta-v3-base were not used when initializing DebertaV2Model: ['mask_predictions.LayerNorm.weight', 'lm_predictions.lm_head.bias', 'mask_predictions.dense.bias', 'mask_predictions.dense.weight', 'lm_predictions.lm_head.dense.bias', 'lm_predictions.lm_head.LayerNorm.weight', 'lm_predictions.lm_head.dense.weight', 'mask_predictions.classifier.bias', 'mask_predictions.classifier.weight', 'mask_predictions.LayerNorm.bias', 'lm_predictions.lm_head.LayerNorm.bias']
- This IS expected if you are initializing DebertaV2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DebertaV2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
processed_df = process_data(ds_train, MODEL_NAME)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [30]:
processed_df.head()

Unnamed: 0,0,1,type,ungram,watch_points
0,One more pseudo generalization and I'm giving up.,One more pseudo generalization or I'm giving up.,swap,2,[4]
1,The gardener watered the flowers flat.,The gardener watered the flowers.,insertion,2,[5]
2,They drank the pub.,They drank the pub dry.,deletion,1,[4]
3,We yelled ourselves.,We yelled ourselves hoarse.,deletion,1,[3]
4,They made him president.,They made him angry.,swap,2,[3]


Выделяются четыре типа ошибок: замена слова (swap), вставка слова (insertion), удаление слова (deletion) и изменение порядка слоа (order). Интересуещее нас значение ungram - 1, которое означает, что при изменении произошел переход от граммтичного предложения 1 к неграмматичному предложению 0. watch_point - номер токена, на котором "происходит неграмматичность". 

In [19]:
class ColaDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return self.data.shape[0]
    
    def __getitem__(self, idx):
        out = {0: self.data[0][idx],
               1: self.data[1][idx],
               'watch_points': self.data['watch_points'][idx]}
        return out
  
def collect_statistics(batch, outs, err_type):
    stats = []
    for i in range(batch['watch_points'].shape[0]):
        cur_stats = {}
        wp = batch['watch_points'][i]
        gr_wp_in = torch.stack([layer[i*2, :, :batch['offsets'][i*2], wp].squeeze().sum(dim=0) for layer in outs['attentions']]).sum(dim=0)
        un_wp_in = torch.stack([layer[i*2+1, :, :batch['offsets'][i*2], wp].squeeze().sum(dim=0) for layer in outs['attentions']]).sum(dim=0)
        gr_wp_out = torch.stack([layer[i*2, :, wp, :batch['offsets'][i*2+1]].squeeze().sum(dim=0) for layer in outs['attentions']]).sum(dim=0)
        un_wp_out = torch.stack([layer[i*2+1, :, wp, :batch['offsets'][i*2+1]].squeeze().sum(dim=0) for layer in outs['attentions']]).sum(dim=0)
        cur_stats['mean_in_gr'] = gr_wp_in.mean()
        cur_stats['mean_in_un'] = un_wp_in.mean()
        cur_stats['mean_out_gr'] = gr_wp_out.mean()
        cur_stats['mean_out_un'] = un_wp_out.mean()
        if err_type == 'swap':
            cur_stats['max_in_change_token'] = tokenizer.decode(batch['data']['input_ids'][i*2][torch.argmax((un_wp_in - gr_wp_in).abs())])
            cur_stats['max_out_change_token'] = tokenizer.decode(batch['data']['input_ids'][i*2][torch.argmax((un_wp_out - gr_wp_out).abs())])
        # elif err_type == 'insertion':

        stats.append(cur_stats)
    return stats
def plot_attentions(batch, outs, err_type, model_name):
    model_name = model_name.replace('/', '_')
    if not os.path.exists(model_name):
        os.mkdir(model_name)
    if not os.path.exists(os.path.join(model_name, err_type)):
        os.mkdir(os.path.join(model_name, err_type))
    for i in range(1):
        cur_tokens1 = [tokenizer.decode(token) for token in batch['data']['input_ids'][i*2]][:batch['offsets'][i*2]]
        cur_tokens0 = [tokenizer.decode(token) for token in batch['data']['input_ids'][i*2+1]][:batch['offsets'][i*2+1]]
        for layer in range(12):
            fig = plt.figure(figsize=(24, 36))
            for head in range(12):
                ax = fig.add_subplot(6, 4, head*2+1)
                ax.matshow(outs['attentions'][layer][i*2, head, :batch['offsets'][i*2],:batch['offsets'][i*2]], cmap='Reds')
                fontdict = {'fontsize': 20}
                ax.set_xticks(range(len(cur_tokens1)))
                ax.set_yticks(range(len(cur_tokens1)))
                ax.set_yticklabels(cur_tokens1, fontdict=fontdict)
                ax.set_xticklabels(cur_tokens1, fontdict=fontdict, rotation=90)
                ax.set_xlabel(f'''Ex {i+1}, layer {layer}, head {head}, grammatical ''', fontdict)  
                ax = fig.add_subplot(6, 4, head*2+2)
                ax.matshow(outs['attentions'][layer][i*2+1, head, :batch['offsets'][i*2+1],:batch['offsets'][i*2+1]], cmap='Reds')
                fontdict = {'fontsize': 20}
                ax.set_xticks(range(len(cur_tokens0)))
                ax.set_yticks(range(len(cur_tokens0)))
                ax.set_yticklabels(cur_tokens0, fontdict=fontdict)
                ax.set_xticklabels(cur_tokens0, fontdict=fontdict, rotation=90)
                ax.set_xlabel(f'''Ex {i+1}, layer {layer}, head {head}, ungrammatical ''', fontdict)  
            plt.tight_layout()
            plt.savefig(os.path.join(model_name, err_type, str(f'ex{i}layer{layer}')))
            plt.close()

encode_plus_kwargs = {
    'add_special_tokens': False,
    'padding': 'longest',
    'pad_to_multiple_of': 1
}

def collate_fn(batch):
    samples = []
    [samples.extend(j) for j in [[batch[i][1], batch[i][0]] for i in range(len(batch))]]
    
    out_data = tokenizer.batch_encode_plus(samples, return_tensors='pt', **encode_plus_kwargs)
    out_batch = {'data':out_data,
                'watch_points': np.array([batch[i]['watch_points'] for i in range(len(batch))]),
                'offsets': (out_data['input_ids'] >0).sum(dim=1).numpy()}
    return out_batch

def run_test(model, tokenizer, err_type):
    cola_ds = ColaDataset(processed_df[(processed_df['type'] == err_type) & (processed_df['ungram'] == 1)].reset_index())
    cola_dl = DataLoader(cola_ds, batch_size=4, num_workers=2, collate_fn=collate_fn)
    all_stat = []
    with torch.no_grad():
        for i, batch in enumerate(cola_dl):
            outs = model(output_attentions=True, **batch['data'])
            if i == 0:
                plot_attentions(batch, outs, err_type, MODEL_NAME)
            all_stat.extend(collect_statistics(batch, outs, err_type))
    return all_stat

In [20]:
stats = run_test(model, tokenizer, 'swap')

In [25]:
mean_in_gr = [ex['mean_in_gr'].item() for ex in  stats]
mean_in_un = [ex['mean_in_un'].item() for ex in  stats]

In [29]:
ttest_ind(mean_in_gr, mean_in_un).pvalue

0.8336242929339761

In [27]:
mean_out_gr = [ex['mean_out_gr'].item() for ex in  stats]
mean_out_un = [ex['mean_out_un'].item() for ex in  stats]

In [28]:
ttest_ind(mean_out_gr, mean_out_un)

Ttest_indResult(statistic=1.365890382826457e-07, pvalue=0.9999998910677874)