In [None]:
import functools
import importlib
import gc
import os
import sys
from dataclasses import dataclass
from multiprocessing import Pool
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
import torch
import transformers
from torch import nn
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from transformers.modeling_bert import BertPreTrainedModel

In [None]:
@dataclass
class Config:
    """Static settings for the attention-visualisation notebook.

    No attribute below is type-annotated, so ``@dataclass`` generates no fields:
    every value is a plain class attribute shared by all instances. The instance is
    passed to ``datasets.get_test_set`` (from ``dataset.py``), which presumably reads
    the length / tokenisation settings — TODO confirm which ones.
    """

    # Rows of test.csv to read (passed as ``pd.read_csv(nrows=...)``); None reads all rows.
    nrows = None
    # Kaggle input mount and the data / code / checkpoint directories beneath it.
    inputdir = Path('/kaggle/input')
    datadir = inputdir / 'google-quest-challenge'
    codedir = inputdir / 'bert-base-random-code'
    pretrained_modeldir = inputdir / 'bert-base-pretrained/stackx-base-cased'
    finetune_modeldir = inputdir / 'bert-base-pseudo-noleak-random'
    # Text columns of the test frame.
    input_columns = ["question_title", "question_body", "answer"]
    # Prediction targets; the model head is built with one output per entry.
    target_columns = [
        "question_asker_intent_understanding",
        "question_body_critical",
        "question_conversational",
        "question_expect_short_answer",
        "question_fact_seeking",
        "question_has_commonly_accepted_answer",
        "question_interestingness_others",
        "question_interestingness_self",
        "question_multi_intent",
        "question_not_really_a_question",
        "question_opinion_seeking",
        "question_type_choice",
        "question_type_compare",
        "question_type_consequence",
        "question_type_definition",
        "question_type_entity",
        "question_type_instructions",
        "question_type_procedure",
        "question_type_reason_explanation",
        "question_type_spelling",
        "question_well_written",
        "answer_helpful",
        "answer_level_of_information",
        "answer_plausible",
        "answer_relevance",
        "answer_satisfaction",
        "answer_type_instructions",
        "answer_type_procedure",
        "answer_type_reason_explanation",
        "answer_well_written",
    ]
    # Token budgets. 26 + 260 + 210 = 496 of the 500 positions; the remaining 4 are
    # presumably reserved for special tokens added in dataset.py — TODO confirm.
    max_sequence_length = 500
    max_title_length = 26
    max_question_length = 260
    max_answer_length = 210
    batch_size = 8  # not referenced by the visible notebook cells; presumably read by dataset.py
    head_tail = True  # presumably selects head+tail truncation in dataset.py — TODO confirm
    # Derived from target_columns (== 30) so the head size cannot drift out of sync.
    num_classes = len(target_columns)

config = Config()

In [None]:
def load_module(filename):
    """Import a Python source file as a module and register it in ``sys.modules``.

    The module is named after the file's stem (``dataset.py`` -> ``dataset``). It is
    placed in ``sys.modules`` *before* its code executes, as in the importlib recipe
    for importing a source file directly. Code that runs during import and looks the
    module up by name therefore finds it. If execution fails, the entry is removed again.

    Args:
        filename: Path (``pathlib.Path`` or ``str``) of the ``.py`` file to load.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec/loader can be built for ``filename``.
        Any exception raised while executing the module's code is propagated.
    """
    # ``import importlib`` alone does not guarantee the ``util`` submodule is loaded.
    import importlib.util

    filename = Path(filename)
    name = filename.stem
    spec = importlib.util.spec_from_file_location(name, filename)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot build an import spec for {filename}")
    mod = importlib.util.module_from_spec(spec)
    sys.modules[name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        # Do not leave a half-initialised module registered.
        sys.modules.pop(name, None)
        raise
    return mod

In [None]:
# Load the project's dataset helpers (dataset.py in the code input directory); used below for get_test_set.
datasets = load_module(Path(config.codedir, 'dataset.py'))

In [None]:
test_df = pd.read_csv(config.datadir / 'test.csv', nrows=config.nrows)

# One cased tokenizer per vocabulary: 'origin' is the stock bert-base-cased vocab,
# 'pretrained' is the one shipped with the checkpoint in config.pretrained_modeldir.
tokenizers = {
    key: transformers.BertTokenizer.from_pretrained(source, do_lower_case=False)
    for key, source in (
        ('origin', 'bert-base-cased'),
        ('pretrained', str(config.pretrained_modeldir)),
    )
}

# Encode the same test frame once per tokenizer (same keys, same order).
test_datasets = {
    key: datasets.get_test_set(config, test_df, tok)
    for key, tok in tokenizers.items()
}

In [None]:
class BertAttentionExtractor(BertPreTrainedModel):
    """BERT model whose forward pass returns only the last element of ``self.bert``'s output.

    ``__init__`` switches on ``output_attentions`` (and ``output_hidden_states``), so that
    last element is presumably the tuple of per-layer attention tensors — TODO confirm
    against the installed ``transformers`` version.

    The dropout layers, ``layer_weights`` and ``classifier`` are registered but never used
    by ``forward``. Presumably they mirror the fine-tuning model, so that its checkpoints
    find matching parameter names when loaded with the default (strict) ``load_state_dict``
    — TODO confirm.
    """

    def __init__(self, config):
        # NOTE: mutates the caller's config object. The flags are set before any submodule
        # is built; presumably BertModel reads them at construction time.
        config.output_hidden_states = True
        config.output_attentions = True
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = transformers.BertModel(config)
        self.dropout = nn.Dropout(p=0.2)
        self.high_dropout = nn.Dropout(p=0.5)

        # num_hidden_layers + 1 learnable weights: -3 for every entry except the last, which stays 0.
        n_weights = config.num_hidden_layers + 1
        weights_init = torch.zeros(n_weights).float()
        weights_init.data[:-1] = -3
        self.layer_weights = torch.nn.Parameter(weights_init)
        self.classifier = torch.nn.Linear(config.hidden_size, self.config.num_labels)

        # Keep the submodule creation order above unchanged: it fixes parameter
        # registration order and the order in which random initialisation uses the RNG.
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
    ):
        """Run ``self.bert`` on the inputs and return the last element of its output.

        All arguments are forwarded unchanged to ``transformers.BertModel``. Because
        ``output_attentions`` is enabled in ``__init__``, the result is presumably a tuple
        with one ``(batch, heads, seq, seq)`` attention tensor per layer — TODO confirm
        for the installed ``transformers`` version.
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        return outputs[-1]

In [None]:
# models = dict()
# models['origin'] = BertAttentionExtractor.from_pretrained('bert-base-cased', num_labels=config.num_classes)
# models['pretrained'] = BertAttentionExtractor.from_pretrained(str(config.pretrained_modeldir), num_labels=config.num_classes)
# models['finetuned_fold0'] = BertAttentionExtractor.from_pretrained(str(config.pretrained_modeldir), num_labels=config.num_classes)
# models['finetuned_fold0'] = torch.nn.DataParallel(models['finetuned_fold0'])
# models['finetuned_fold0'].load_state_dict(torch.load(config.finetune_modeldir / 'fold0/best_model.pth', map_location=torch.device('cpu')))
# models['finetuned_fold1'] = BertAttentionExtractor.from_pretrained(str(config.pretrained_modeldir), num_labels=config.num_classes)
# models['finetuned_fold1'] = torch.nn.DataParallel(models['finetuned_fold1'])
# models['finetuned_fold1'].load_state_dict(torch.load(config.finetune_modeldir / 'fold1/best_model.pth', map_location=torch.device('cpu')))

In [None]:
def build_heatmap(i, attention_maps, tokens, n_cols=None):
    """Build the heatmap trace for subplot ``i`` of the attention grid.

    Subplots are numbered row-major, so subplot ``i`` sits at grid cell
    ``(i // n_cols, i % n_cols)`` and displays ``attention_maps[row, col]``.

    Args:
        i: Flat, 0-based subplot index.
        attention_maps: Array where ``attention_maps[row, col]`` is a 2-D map.
            ``plot_attention_heatmap`` passes a ``(layers, heads, seq, seq)`` array,
            so ``row`` selects the layer and ``col`` the head.
        tokens: ``(seq, seq, 2)`` grid of (x-token, y-token) labels. It is only needed
            if the commented-out ``customdata`` hover option is re-enabled.
        n_cols: Number of columns in the subplot grid. Defaults to
            ``attention_maps.shape[1]``. Pass it explicitly whenever the grid is
            narrower than that.

    Returns:
        A ``plotly.graph_objects.Heatmap`` trace.
    """
    if n_cols is None:
        n_cols = attention_maps.shape[1]
    row, col = divmod(i, n_cols)
    heatmap = go.Heatmap(
        z=attention_maps[row, col],
        colorscale=['ivory', 'red', 'purple'],
        # Fixed colour range shared by all subplots; weights above 0.1 saturate.
        zmin=0, zmax=0.1,
        # customdata=tokens,
        hovertemplate="""
        attention: %{z}
        <br>x: %{x}, y: %{y}
        <extra></extra>
        """
        # <br>x: "%{customdata[0]}", y: "%{customdata[1]}"
    )
    return heatmap


def plot_attention_heatmap(model, tokenizer, sample, title, n_rows=12, n_cols=12):
    """Plot an ``n_rows`` x ``n_cols`` grid of self-attention heatmaps for one sample.

    Subplot ``(row, col)`` shows the attention map of layer ``row``, head ``col``.

    Args:
        model: Module whose forward returns a sequence of per-layer attention tensors
            (e.g. ``BertAttentionExtractor``, optionally wrapped in ``DataParallel``).
            Each tensor is assumed to be ``(batch, heads, seq, seq)`` — TODO confirm.
        tokenizer: Tokenizer used to map ``sample.input_ids`` back to tokens.
        sample: Row-like object with ``input_ids``, ``attention_mask`` and
            ``token_type_ids`` (indexable with ``[None]``) plus the
            ``question_title``, ``question_body`` and ``answer`` text for the title.
        title: Label shown in the figure title.
        n_rows, n_cols: Grid size; they must not exceed the number of layers / heads.

    Returns:
        A plotly ``Figure``.

    Raises:
        ValueError: If the grid is larger than the available layers x heads.
    """
    from plotly.subplots import make_subplots

    model.eval()
    # Inference only: skip autograd bookkeeping for the full-length forward pass.
    with torch.no_grad():
        attention_maps = model(
            input_ids=sample.input_ids[None],
            attention_mask=sample.attention_mask[None],
            token_type_ids=sample.token_type_ids[None],
        )
    # Stack per-layer tensors on a new axis 1, then drop the batch axis.
    attention_maps = torch.stack(attention_maps, dim=1).detach().cpu().numpy()[0]
    n_layers, n_heads = attention_maps.shape[:2]
    if n_rows > n_layers or n_cols > n_heads:
        raise ValueError(
            f"grid {n_rows}x{n_cols} exceeds the available {n_layers} layers x {n_heads} heads"
        )

    token_array = np.array(tokenizer.convert_ids_to_tokens(sample.input_ids))
    n_tokens = len(token_array)
    # tokens[y, x] == (token at x, token at y); sized from the actual sequence length.
    tokens = np.stack([
        token_array[None, :].repeat(n_tokens, axis=0),
        token_array[:, None].repeat(n_tokens, axis=1),
    ], axis=-1)

    n_plots = n_rows * n_cols
    fig = make_subplots(n_rows, n_cols)
    # Built serially: each trace is cheap, while a process pool would pickle the whole
    # (layers x heads x seq x seq) attention array once per task.
    for i in tqdm(range(n_plots), total=n_plots):
        row, col = divmod(i, n_cols)
        heatmap = build_heatmap(i, attention_maps, tokens, n_cols=n_cols)
        fig.add_trace(trace=heatmap, row=row + 1, col=col + 1)

    fig.update_layout(
        # Width scales with the number of columns, height with the number of rows.
        width=200 * n_cols,
        height=200 * n_rows,
        title_text=f"""
        Attention heatmap [{title}]
        <br>question_title: {sample.question_title}
        <br>question_body: {sample.question_body}
        <br>answer: {sample.answer}
        """.strip(),
        font=dict(
            size=7,
        ),
    )

    return fig

In [None]:
def build_model(name):
    """Instantiate the attention-extracting BERT identified by ``name``.

    Args:
        name: One of
            * ``'origin'`` — ``bert-base-cased`` weights;
            * ``'pretrained'`` — the checkpoint in ``config.pretrained_modeldir``;
            * ``'finetuned_fold<k>'`` — the pretrained architecture wrapped in
              ``torch.nn.DataParallel``, with weights loaded on CPU from
              ``config.finetune_modeldir / 'fold<k>' / 'best_model.pth'``.

    Returns:
        A ``BertAttentionExtractor``, or a ``DataParallel`` wrapping one for fine-tuned folds.

    Raises:
        ValueError: If ``name`` matches none of the forms above.
    """
    import re

    if name == 'origin':
        return BertAttentionExtractor.from_pretrained('bert-base-cased', num_labels=config.num_classes)
    if name == 'pretrained':
        return BertAttentionExtractor.from_pretrained(str(config.pretrained_modeldir), num_labels=config.num_classes)

    match = re.fullmatch(r'finetuned_(fold\d+)', name)
    if match is None:
        raise ValueError(
            f"unknown model name {name!r}; expected 'origin', 'pretrained' or 'finetuned_fold<k>'"
        )
    model = BertAttentionExtractor.from_pretrained(str(config.pretrained_modeldir), num_labels=config.num_classes)
    # Wrap before loading: the checkpoint is loaded (strictly) into the DataParallel
    # wrapper, so its keys presumably carry the wrapper's 'module.' prefix.
    model = torch.nn.DataParallel(model)
    checkpoint = config.finetune_modeldir / match.group(1) / 'best_model.pth'
    model.load_state_dict(torch.load(checkpoint, map_location=torch.device('cpu')))
    return model

In [None]:
# Quick check on one test row: a small 2 x 2 grid (layers 0-1 x heads 0-1) for one model.
idx, name = 0, 'origin'
n_rows = n_cols = 2

model = build_model(name)
# Only the stock 'origin' model uses the bert-base-cased vocabulary/encodings.
if name == 'origin':
    dataset, tokenizer = test_datasets['origin'], tokenizers['origin']
else:
    dataset, tokenizer = test_datasets['pretrained'], tokenizers['pretrained']

sample = test_df.iloc[idx].copy()
# dataset[idx] yields (input_ids, attention_mask, token_type_ids, <unused here>).
(sample['input_ids'],
 sample['attention_mask'],
 sample['token_type_ids'],
 _) = dataset[idx]
fig = plot_attention_heatmap(
    model, tokenizer, sample,
    title=name, n_rows=n_rows, n_cols=n_cols,
)
fig

In [None]:
# For each model, render the 12 x 12 grid of layer/head attention maps for one test row
# and save it as a standalone HTML file.
idx = 1
n_rows, n_cols = 12, 12
models = [
    'origin',
    'pretrained',
    'finetuned_fold0',
    'finetuned_fold1',
]

for name in models:
    # Release the previous model/figure before building the next one to bound peak memory.
    # Rebinding (instead of `del model, fig`) keeps this cell runnable on a fresh kernel,
    # where `model`/`fig` may not exist yet.
    model = fig = None
    gc.collect()
    model = build_model(name)
    dataset = test_datasets['origin'] if name == 'origin' else test_datasets['pretrained']
    tokenizer = tokenizers['origin'] if name == 'origin' else tokenizers['pretrained']

    sample = test_df.iloc[idx].copy()
    sample['input_ids'], sample['attention_mask'], sample['token_type_ids'], _ = dataset[idx]
    fig = plot_attention_heatmap(model, tokenizer, sample, title=name, n_rows=n_rows, n_cols=n_cols)
    fig.write_html(f"{name}_data={idx:04}.html")

In [None]:
# Shell escape: report filesystem disk usage (presumably to check free space after writing the HTML files).
!df -h