In [1]:
%cd ..

/misc/vlgscratch4/LakeGroup/wentao/multimodal-baby


In [2]:
import itertools
import functools
import copy
import re
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
import pandas as pd
import seaborn as sns
from torchvision.transforms.functional import resized_crop
from multimodal.multimodal_lit import MultiModalLitModel
from multimodal.multimodal_data_module import PAD_TOKEN_ID, UNK_TOKEN_ID, SOS_TOKEN_ID, EOS_TOKEN_ID, normalizer
from ngram import NGramModel
from analysis_tools.utils import *
from analysis_tools.pos_tags import *
from analysis_tools.word_categories import *
from analysis_tools.token_items_data import *
from analysis_tools.plotting import *
import analysis_tools.plotting as plotting
from analysis_tools.multimodal_visualization import *
from analysis_tools.processing import *
from analysis_tools.build_data import *
from analysis_tools.checkpoints import *
from analysis_tools.frame_visualization import *


# Default plotting configuration; individual plots may override any of it.
figsize = (8, 7)


def _scaled_context(base, font, label, title, tick, legend, legend_title):
    """Return the seaborn plotting context `base` with the given font sizes applied."""
    context = sns.plotting_context(base)
    context.update({
        'font.size': font,
        'axes.labelsize': label,
        'axes.titlesize': title,
        'xtick.labelsize': tick,
        'ytick.labelsize': tick,
        'legend.fontsize': legend,
        'legend.title_fontsize': legend_title,
    })
    return context


def _relation_style(base):
    """Return the seaborn axes style `base` keeping only the left and bottom spines."""
    style = sns.axes_style(base)
    style.update({
        'axes.spines.left': True,
        'axes.spines.bottom': True,
        'axes.spines.right': False,
        'axes.spines.top': False,
    })
    return style


paper_context = _scaled_context(
    'paper', font=10., label=10., title=14., tick=8.8, legend=8.8, legend_title=9.6)
talk_context = _scaled_context(
    'talk', font=12., label=12., title=16., tick=10.8, legend=10.8, legend_title=11.6)

unticked_relation_style = _relation_style('white')
ticked_relation_style = _relation_style('ticks')

# heatmaps start from the unticked style and drop the two remaining spines
heatmap_style = copy.copy(unticked_relation_style)
heatmap_style.update({
    'axes.spines.left': False,
    'axes.spines.bottom': False,
})

font = 'serif'
sns.set_theme(
    context=paper_context,
    style=unticked_relation_style,
    palette=sns.color_palette('tab20'),
    font=font,
    rc={'figure.figsize': figsize},
)

# console display of arrays and data frames
np.set_printoptions(suppress=True, precision=2, linewidth=120)
pd.options.display.width = 120

# figure resolution and tight, unpadded saving
plt.rcParams.update({
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.,
})

# switches read by the plotting code
add_titles = False
add_lines = True
line_kwargs = dict(linestyle='dashed', linewidth=1, color='black')
plot_format = 'pdf'
saving_fig = True

# Backend tweaks that depend on the chosen output format.
if plot_format == 'pgf':
    # render through the pgf backend, compiling with pdflatex and T1 font encoding
    matplotlib.use('pgf')
    plt.rcParams.update({
        "pgf.texsystem": "pdflatex",
        "pgf.preamble": "\n".join([
             r"\usepackage[T1]{fontenc}",
        ]),
    })
elif plot_format == 'svg':
    # make the inline backend emit svg figures
    from matplotlib_inline.backend_inline import set_matplotlib_formats
    set_matplotlib_formats('svg')

# `output_fig(fname)` is the single exit point for every figure in this notebook:
# it either saves the current figure under plots/ or shows it inline.
if saving_fig:
    def _save_fig(fname, format='png'):
        """Save the current figure as ``plots/<fname>.<format>`` and close it.

        Inputs:
            fname: file name without extension, relative to the ``plots`` directory
            format: file extension / format handed to matplotlib through the file name
        """
        import os  # local import: `os` is not imported at the top of the notebook

        print(f'saving plot {fname}')
        path = f'plots/{fname}.{format}'
        # create the output directory on first use instead of failing in savefig
        os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
        plt.savefig(path, transparent=True)
        # close (not just clear) the figure: cleared figures stay registered with pyplot,
        # accumulate in memory and are rendered as empty "<Figure ... with 0 Axes>" outputs
        plt.close()
    plotting.output_fig = functools.partial(_save_fig, format=plot_format)
else:
    plotting.output_fig = lambda fname: plt.show()

output_fig = plotting.output_fig

# run the models on the GPU when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# load model from checkpoint

# select from list of checkpoints
# `dataset_name` picks both the model names below and the matching entry of all_checkpoint_paths
dataset_name = "saycam"
# names of the individual checkpoints to load for the selected dataset
ori_names = {
    "saycam": [
        "LSTM 0", "LSTM 1", "LSTM 2",
        "LSTM Captioning 0", "LSTM Captioning 1", "LSTM Captioning 2",
        "CBOW 0", "CBOW 1", "CBOW 2",
    ],
    "coco": [
        "lm",
        "capt_ft",
        "capt_attn_gt_ft",
        "capt_attn_gt_reg_ft",
        "capt_attn_gt_reg_untie_ft",
        "cbow",
    ],
}[dataset_name]
# mapping from a checkpoint name to its path (indexed by the names in ori_names below)
checkpoint_paths = all_checkpoint_paths[dataset_name]

# name of each individual checkpoint -> loaded model
ori_models = {}

# The data module and the vocabulary are built once, from the first neural checkpoint.
data = None

# set to True to dump the hyper-parameters stored in every neural checkpoint
print_dict_args = False

for ori_name in ori_names:
    checkpoint_path = checkpoint_paths[ori_name]

    if "gram" in checkpoint_path:
        # n-gram "checkpoints" are named "<n>-gram..." and are fitted on the training data,
        # so they need the data module and vocabulary created by an earlier neural checkpoint
        assert data is not None, (
            f"n-gram model {ori_name} needs the data and vocabulary built from a neural checkpoint; "
            f"list at least one neural checkpoint before it"
        )
        ngram_model = build_ngram_model(int(checkpoint_path.split('-')[0]), vocab_size, data.train_dataloader())
        model = ngram_model

    else:
        if print_dict_args:
            # read the raw checkpoint from the same path that is loaded below
            ckpt = torch.load(checkpoint_path, map_location=device)
            print(ckpt['hyper_parameters']['args'])

        print(f"load model from {checkpoint_path}")
        lit_model = MultiModalLitModel.load_from_checkpoint(checkpoint_path, map_location=device)
        lit_model.to(device)

        if data is None:
            # build data and vocab according to the model
            data, args = build_data(args=lit_model.args, return_args=True)
            dataset_name = args.dataset

            word2idx = lit_model.text_encoder.word2idx
            idx2word = lit_model.text_encoder.idx2word

            vocab = lit_model.text_encoder.vocab
            vocab_size = len(vocab)
            print(f'vocab_size = {vocab_size}')
            # check consistency between vocab and idx2word
            for idx in range(vocab_size):
                assert idx in idx2word, f"token id {idx} is missing from idx2word"

        else:
            # every later checkpoint must have been trained on the dataset the data was built for;
            # report the dataset of the offending checkpoint, not the expected one
            assert lit_model.args["dataset"] == dataset_name, (
                f"checkpoint {checkpoint_path} ran on a different dataset {lit_model.args['dataset']}"
            )

        lit_model.eval()
        model = lit_model

    ori_models[ori_name] = model


# Each name stands for a group of checkpoints whose predictions are averaged.
_average_over_groups = True
if _average_over_groups:
    # keep only the first three group names
    names = ['LSTM', 'LSTM Captioning', 'CBOW', 'Contrastive', 'Joint bs16', 'Joint bs512'][:-3]
    groups = {}
    for name in names:
        # restrict every group to the checkpoints that were actually loaded
        groups[name] = [ori_name for ori_name in all_groups[name] if ori_name in ori_names]
else:
    # every checkpoint forms its own group
    names = ori_names
    groups = {name: [name] for name in names}


# the first checkpoint of every group serves as the representative model of that group
models = {}
for name, group in groups.items():
    assert group, f"no models corresponds to {name}"
    representative_name = group[0]
    models[name] = ori_models[representative_name]

load model from checkpoints/lm_dataset_saycam_captioning_False_text_encoder_lstm_embedding_dim_512_dropout_i_0.5_dropout_o_0.0_batch_size_16_lr_0.006_lr_scheduler_True_weight_decay_0.04_seed_0/epoch=29.ckpt


  rank_zero_deprecation(


Using base transforms
Calling prepare_data!
SAYCam transcripts have already been downloaded. Skipping this step.
Transcripts have already been renamed. Skipping this step.
Transcripts have already been preprocessed. Skipping this step.
Training frames have already been extracted. Skipping this step.
Training metadata files have already been created. Skipping this step.
Shuffled training metadata file has already been created. Skipping this step.
Evaluation frames have already been filtered. Skipping this step.
Evaluation frames have already been extracted. Skipping this step.
Filtered evaluation frames have already been extracted. Skipping this step.
Evaluation metadata files have already been created. Skipping this step.
Evaluation metadata files have already been created. Skipping this step.
Extra evaluation metadata files have already been created. Skipping this step.
Extra evaluation metadata files have already been created. Skipping this step.
Vocabulary file already exists. Skipp

In [4]:
# get sum values (counts, vector representations, losses) across the training set

# split name -> items computed for that split (filled by the loop over used_splits below)
split_items = {}

# splits that are processed in this notebook
used_splits = ['train', 'val']


def print_items(items, n=20):
    """Print the first n rows of items followed by three summary rows.

    The summary rows hold the column sums of 'cnt' and of every model column over all rows
    whose first index level is not PAD_TOKEN_ID: first as is ('ppl'), then with the
    SOS_TOKEN_ID rows subtracted ('ppl_wo_sos'), then with the EOS_TOKEN_ID rows subtracted
    as well ('ppl_wo_sos_eos').
    """
    for _, row in items.iloc[:n].iterrows():
        print(row_str(row, names))

    print()

    columns = ['cnt'] + names
    is_not_pad = items.index.map(lambda index: index[0] != PAD_TOKEN_ID)
    tot_values = items.loc[is_not_pad, columns].sum(axis=0)

    summary_rows = (
        (None, 'ppl'),
        (SOS_TOKEN_ID, 'ppl_wo_sos'),
        (EOS_TOKEN_ID, 'ppl_wo_sos_eos'),
    )
    for removed_token_id, row_label in summary_rows:
        if removed_token_id is not None:
            # the subtraction is cumulative: each summary row builds on the previous one
            tot_values -= items.loc[(removed_token_id,), columns].sum(axis=0)
        tot = pd.Series(tot_values, index=items.columns)
        tot[token_field] = row_label
        print(row_str(tot, names))


def remove_foils_wrapper(dataloader):
    """Yield the batches of dataloader with x replaced by x[:, 0].

    Every batch must unpack into (x, y, y_len, raw_y); the last three are passed through
    unchanged.
    """
    for batch in dataloader:
        x, y, y_len, raw_y = batch
        first_x = x[:, 0]
        yield first_x, y, y_len, raw_y

# batch size used by the train/val/test dataloaders below
my_batch_size = 256
# split name -> zero-argument factory that builds a fresh dataloader for that split.
# The 'eval_*' entries wrap their dataloader in remove_foils_wrapper, which keeps only x[:, 0].
# NOTE(review): val_dataloader()/test_dataloader() return several dataloaders; index 0 is presumably
# the plain split and index 1 the evaluation set with foils — confirm against the data module.
dataloader_fns = {
    'train': lambda: data.train_dataloader(batch_size=my_batch_size, shuffle=False, drop_last=False),
    'val': lambda: data.val_dataloader(batch_size=my_batch_size)[0],
    'test': lambda: data.test_dataloader(batch_size=my_batch_size)[0],
    'eval_val': lambda: remove_foils_wrapper(data.val_dataloader()[1]),
    'eval_test': lambda: remove_foils_wrapper(data.test_dataloader()[1]),
}

# For every used split: compute (or load from cache) the items of every checkpoint, average them
# within each group, stack the groups, store the result in split_items and print a summary.
for split in used_splits:
    dataloader_fn = dataloader_fns[split]

    pos_tags = get_pos_tags(dataloader_fn(), dataset_name, split)

    # per-checkpoint items; torch_cache is keyed by "<checkpoint path>.<split>.cache.pt".
    # all_token_items are skipped for the train split.
    ori_model_items = {
        name: torch_cache(checkpoint_paths[name] + f'.{split}.cache.pt')(get_model_items)(
            model, dataloader_fn(), pos_tags, ignore_all_token_items=(split == 'train'))
        for name, model in ori_models.items()}
    # per-group items, combined over the checkpoints of the group
    model_items = {
        name: mean_model_items([ori_model_items[ori_name] for ori_name in group], idx=-1)
        for name, group in groups.items()}
    # zip(*...) regroups the per-model tuples field by field before stacking them across models
    items = ModelItems(*[stack_items(items_list, names, idx2word) for items_list in list(zip(*model_items.values()))])
    extend_items(items.token_items, names, idx2word)
    extend_items(items.token_pos_items, names, idx2word)
    if items.all_token_items is not None:
        extend_items(items.all_token_items, names, idx2word)
    split_items[split] = items

    print_items(items.token_items)

    # when per-token items exist, also store the group-averaged probabilities in a "<name> probs" column
    if split_items[split].all_token_items is not None:
        for name, group in groups.items():
            ori_probs = []
            for ori_name in group:
                model = ori_models[ori_name]
                # n-gram models are skipped here
                if isinstance(model, NGramModel):
                    continue
                probs = get_model_probs(model, dataloader_fns[split](), pos_tags)
                ori_probs.append(probs)
            # NOTE(review): ori_probs is empty when a group holds only n-gram models — confirm that
            # mean_probs handles an empty list.
            probs = mean_probs(ori_probs)
            split_items[split].all_token_items[f'{name} probs'] = probs.tolist()

load cached pos tags: dataset_cache/saycam/train.pos.cache
load from checkpoints/lm_dataset_saycam_captioning_False_text_encoder_lstm_embedding_dim_512_dropout_i_0.5_dropout_o_0.0_batch_size_16_lr_0.006_lr_scheduler_True_weight_decay_0.04_seed_0/epoch=29.ckpt.train.cache.pt
load from checkpoints/lm_dataset_saycam_captioning_False_text_encoder_lstm_embedding_dim_512_dropout_i_0.5_dropout_o_0.0_batch_size_16_lr_0.006_lr_scheduler_True_weight_decay_0.04_seed_1/epoch=38.ckpt.train.cache.pt
load from checkpoints/lm_dataset_saycam_captioning_False_text_encoder_lstm_embedding_dim_512_dropout_i_0.5_dropout_o_0.0_batch_size_16_lr_0.006_lr_scheduler_True_weight_decay_0.04_seed_2/epoch=28.ckpt.train.cache.pt
load from checkpoints/lm_dataset_saycam_captioning_True_text_encoder_lstm_embedding_dim_512_dropout_i_0.5_dropout_o_0.0_batch_size_16_lr_0.006_lr_scheduler_True_weight_decay_0.04_seed_0/epoch=29.ckpt.train.cache.pt
load from checkpoints/lm_dataset_saycam_captioning_True_text_encoder_lstm_embe

100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:02<00:00,  2.89it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:02<00:00,  3.49it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:02<00:00,  3.49it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:04<00:00,  1.88it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:04<00:00,  1.85it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:04<00:00,  1.84it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:02<00:00,  3.71it/s]
100%|██████████████████████████████████████████████████

In [5]:
# plot statistics of syntactic categories

pos_tag_dfs = []

splits = ['train', 'val', 'test']

# collect one row per token, holding its POS tag and the split it comes from
for split in splits:
    dataloader_fn = dataloader_fns[split]
    pos_tags = get_pos_tags(dataloader_fn(), dataset_name, split)

    pos_tag_df = pd.DataFrame(data=list(itertools.chain.from_iterable(pos_tags)), columns=['pos'])
    pos_tag_df['split'] = split
    pos_tag_dfs.append(pos_tag_df)

pos_tag_df = pd.concat(pos_tag_dfs)

extend_pos(pos_tag_df)

pos_field = 'syntactic category'

g = sns.catplot(kind='count', data=pos_tag_df, x='split', hue=pos_field, palette=palette)

for split in splits:
    print(f'{split:5s} #examples: {len(data.datasets[split]):5d}')

# number of tokens per split, computed once and reused below
split_cnts = {split: len(pos_tag_df[(pos_tag_df['split'] == split)]) for split in splits}
for split in splits:
    print(f'{split:5s} #tokens: {split_cnts[split]:6d}')

# text table: one row per (split, category), in split-major order
for split, pos in itertools.product(splits, pos_tag_df.dtypes[pos_field].categories):
    split_cnt = split_cnts[split]
    split_pos_cnt = len(pos_tag_df[(pos_tag_df['split'] == split) & (pos_tag_df[pos_field] == pos)])
    print(f'{split:5s} {pos:15s} #tokens: {split_pos_cnt:6d} {split_pos_cnt / split_cnt :7.2%}')

# Label every bar with its own height. The bars are not guaranteed to come in the split-major
# order of the table above (seaborn draws count bars one hue level at a time), so pairing
# patches with (split, category) positionally can put counts on the wrong bars.
for bar in g.ax.patches:
    bar_height = bar.get_height()
    if not np.isfinite(bar_height) or bar_height <= 0:
        # nothing to label for missing or empty bars
        continue
    g.ax.text(bar.get_x() + bar.get_width()/2., bar_height + 40, f'{int(round(bar_height)):d}', ha="center", fontsize='xx-small')

output_fig('dataset distribution')

load cached pos tags: dataset_cache/saycam/train.pos.cache
load cached pos tags: dataset_cache/saycam/val.pos.cache
load cached pos tags: dataset_cache/saycam/test.pos.cache
train #examples: 33737
val   #examples:  1874
test  #examples:  1875
train #tokens: 292475
val   #tokens:  16103
test  #tokens:  16168
train noun            #tokens:  27132   9.28%
train verb            #tokens:  43881  15.00%
train adjective       #tokens:   8190   2.80%
train adverb          #tokens:  13679   4.68%
train function word   #tokens:  75610  25.85%
train cardinal number #tokens:   1247   0.43%
train .               #tokens: 122736  41.96%
val   noun            #tokens:   1446   8.98%
val   verb            #tokens:   2401  14.91%
val   adjective       #tokens:    419   2.60%
val   adverb          #tokens:    757   4.70%
val   function word   #tokens:   4117  25.57%
val   cardinal number #tokens:     81   0.50%
val   .               #tokens:   6882  42.74%
test  noun            #tokens:   1502   9.29%
t

<Figure size 1940.62x1500 with 0 Axes>

In [6]:
# word to look for in the examples; True asks for it interactively, a falsy value disables the search
searched_word = "ball"
# whether to draw GradCAM / attention maps of the models
visualize_models = True
# number of matching examples to show
n_print_example = 5
# models used for text generation (disabled)
textgen_models = {name: model for name, model in models.items() if 'attn' in name} if False else {}
# whether to also visualize crops of the image
multiple_views = False
# whether to visualize all steps instead of only the steps of the searched word
all_steps = False

if visualize_models:
    gradcam_models = {name: model for name, model in models.items() if model.text_encoder.captioning or model.text_encoder.has_attention}
    attn_models = {name: model for name, model in models.items() if model.text_encoder.has_attention}
else:
    # The flag used to reset only the unused *_ns lists, so it had no effect on what was drawn;
    # emptying the dictionaries that are actually read makes it work.
    gradcam_models = {}
    attn_models = {}
    # kept for backward compatibility with code that may still read them
    gradcam_model_ns = []
    attn_model_ns = []
n_visualized_models = len(gradcam_models) + len(attn_models)

if searched_word:
    if searched_word is True:
        searched_word = input("search word: ")
    searched_token_id = word2idx.get(searched_word, UNK_TOKEN_ID)
    if searched_token_id == UNK_TOKEN_ID:
        print(f"mapping {searched_word} to UNK")


def get_views(x, grid=(2, 2)):
    """Get multiple views of image x.

    The image is passed through n_inv and cut into grid[0] x grid[1] cells; every cell is
    resized back to the size of the whole image with resized_crop. Returns a list holding the
    whole image followed by the cells in row-major order, each passed through normalizer.
    """
    img = n_inv(x)
    cell_h = img.size(-2) // grid[0]
    cell_w = img.size(-1) // grid[1]
    full_size = img.shape[-2:]
    crops = [
        resized_crop(img, row * cell_h, col * cell_w, cell_h, cell_w, full_size)
        for row in range(grid[0])
        for col in range(grid[1])
    ]
    return [normalizer(view) for view in [img] + crops]


# plot_image preset used for every image panel below: overlying=True and blur=False
show_image = functools.partial(
    plot_image,
    overlying=True,
    blur=False,
    #interpolation='nearest',
)


def visualize_example(x, y, y_len, raw_y, steps=None, model_first=True, prepend_x=False, use_losses=None, example_name='example'):
    """Visualize an example.

    Emits two figures through output_fig, "<example_name> visual map" (image panels with the
    GradCAM and attention maps) and "<example_name> loss heatmap", prints raw_y[0], and, when the
    models are run here, prints the top predictions of every non-contrastive MultiModalLitModel
    at the visualized steps.

    Inputs:
        x: image of the example; shown after n_inv
        y: token ids of the utterance
        y_len: length of the utterance; .item() is called on it when steps is None
        raw_y: raw utterance(s); only raw_y[0] is printed
        steps: list of steps to visualize; None for all steps
        model_first: if True, then the axes are of n_visualized_models * len(steps), else it is transposed
        prepend_x: if True, prepend raw image x before the models
        use_losses: use the designated losses; must be a list of losses, where each losses is a list of loss at each step;
            for example, split_items[split].losses[example_i]; if None, use the losses generated by running the models
        example_name: prefix of the names of the emitted figures
    """
    img = torch_to_numpy_image(n_inv(x))
    y_labels = [idx2word[y_id.item()] for y_id in y]
    if steps is None:
        steps = list(range(y_len.item()))

    n_rows, n_cols = int(prepend_x) + n_visualized_models, len(steps)
    if not model_first:
        n_rows, n_cols = n_cols, n_rows
    ax_size = 5
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(ax_size * n_cols, ax_size * n_rows), squeeze=False)
    if not model_first:
        # transpose the grid so that the axes are always consumed model by model below
        axes = [[axes[j][i] for j in range(n_rows)] for i in range(n_cols)]
    axes_iter = itertools.chain.from_iterable(axes)

    if prepend_x:
        # one raw image per visualized step
        for _ in steps:
            show_image(next(axes_iter), img)

    for name, model in gradcam_models.items():
        gradcams = gradCAM_for_captioning_lm(model, x, y, y_len, steps=steps)
        prefix = f'{name} GradCAM'
        for step_i, (step, gradcam) in enumerate(zip(steps, gradcams)):
            if step == 0:
                # step 0 is drawn without a GradCAM overlay
                show_image(next(axes_iter), img, text=prefix)
            else:
                text = y_labels[step]
                if step_i == 0:
                    text = prefix + ' ' + text
                show_image(next(axes_iter), img, gradcam, text=text)

    for name, model in attn_models.items():
        attns = attention_for_attention_lm(model, x, y, y_len, steps=steps)
        prefix = f'{name} attn'
        for step_i, (step, attn) in enumerate(zip(steps, attns)):
            text = y_labels[step]
            if step_i == 0:
                text = prefix + ' ' + text
            show_image(next(axes_iter), img, attn, text=text)

    # hide whatever axes are left over
    for ax in axes_iter:
        ax.axis("off")
    output_fig(example_name + ' visual map')

    if use_losses is None:
        with torch.no_grad():
            rets = {name: run_model(model, x, y, y_len, single_example=True, return_all=True) for name, model in models.items()}
            losses = [ret[1].cpu().numpy() for ret in rets.values()]
    else:
        losses = use_losses
    print(raw_y[0])
    # Contrastive models are left out of the loss heatmap. The check is case-insensitive because
    # the group is called 'Contrastive', which a plain 'contrastive' substring test never matches.
    kept = [(name, loss) for name, loss in zip(names, losses) if 'contrastive' not in name.lower()]
    if kept:
        names_, losses_ = zip(*kept)
        plot_model_y_value_heatmap(names_, losses_, y_labels)
        output_fig(example_name+' loss heatmap')

    if use_losses is None:
        for name, model in models.items():
            ret = rets[name]
            if not isinstance(model, MultiModalLitModel) or 'contrastive' in name.lower():
                continue
            print(f'{name}:')
            if model.language_model.text_encoder.regressional:
                # regressional models are read one step earlier, and step 0 is dropped
                steps_ = [step - 1 for step in steps if step > 0]
            else:
                steps_ = steps
            logits, labels = ret[2], ret[4]
            probs = logits.softmax(-1)
            print_top_values(probs, idx2word, labels, steps=steps_, value_formatter=prob_formatter)
            print()


# Walk through the validation set and visualize the first n_print_example examples
# (those containing the searched word when one is set).
example_i = 0
print_example_i = 0

dataset = data.datasets['val']
for example_i in range(len(dataset)):
    if print_example_i >= n_print_example:
        break

    x, y, y_len, raw_y = dataset[example_i]
    y_len = torch.tensor(y_len)
    # drop everything after the utterance length
    y = y[:y_len]

    if searched_word:
        searched_word_steps = [idx for idx, y_id in enumerate(y) if y_id == searched_token_id]
        if not searched_word_steps:
            continue

    print(f'example #{example_i}:')

    if all_steps:
        steps = None
    else:
        steps = searched_word_steps if searched_word else [0]

    for x_view_i, x_view in enumerate(get_views(x) if multiple_views else [x]):
        # With several views every view needs its own figure name; otherwise each saved figure
        # overwrites the one of the previous view. The name is unchanged for a single view.
        view_name = f'example{example_i} view{x_view_i}' if multiple_views else f'example{example_i}'
        visualize_example(x_view, y, y_len, raw_y, steps=steps, model_first=all_steps, prepend_x=not all_steps, example_name=view_name)

    for name, model in textgen_models.items():
        print(f"generating text from {name}:")
        image_features, image_feature_map = model.model.encode_image(x.unsqueeze(0).to(device))
        beam_seq, log_prob = model.language_model.beam_search_decode(
            batch_size=1,
            beam_width=model.beam_width,
            decode_length=model.decode_length,
            length_penalty_alpha=model.length_penalty_alpha,
            image_features=image_features if model.language_model.text_encoder.captioning else None,
            image_feature_map=image_feature_map if model.language_model.text_encoder.has_attention else None,
        )
        gen_text_ids = beam_seq[0, 0]
        # strip the trailing padding to get the length of the generated text
        gen_text_len = len(gen_text_ids)
        while gen_text_len > 0 and gen_text_ids[gen_text_len - 1] == PAD_TOKEN_ID:
            gen_text_len -= 1
        gen_text_len = torch.tensor(gen_text_len, device=device)
        gen_text_labels = [idx2word[y_id.item()] for y_id in gen_text_ids]
        gen_text = ' '.join(gen_text_labels)
        visualize_example(x, gen_text_ids, gen_text_len, [gen_text], steps=None, model_first=True, prepend_x=False, example_name=f'example{example_i} {name}')

    print_example_i += 1


from textwrap import wrap


def visualize_frame_subplots(
    examples,
    show_original=True,
    show_gradcam=True,
    layout_dim=0,
    add_labels=True,
    add_captions=True,
    fig_width=8.,
    fontsize=5.,
    label_fontsize='medium',
    gridspec_kw={'wspace': 0.025, 'hspace': 0.3},
    gradcam_model=None,
):
    """Visualize frames and utterances (as captions) in subplots
    examples: can be a dict of form {label: [example[i, j] for j in range(ncols)] for i in range(nrows)},
        or a list of form [[example[i, j] for j in range(ncols)] for i in range(nrows)];
        GradCAMs need labels, so a list requires show_gradcam=False
    show_original: bool, whether to show the original image
    show_gradcam: bool, whether to show the gradcam
    layout_dim: int, the extended dimension to arrange original image and gradcam when both are shown
    add_labels: bool, whether to add labels
    add_captions: bool, whether to add captions
    fig_width: passed to frame_subplots
    fontsize: font size of the captions
    label_fontsize: font size of the row labels
    gridspec_kw: passed to frame_subplots; its 'wspace' and 'hspace' are also used to place labels and captions
    gradcam_model: model used to compute the GradCAMs; when None, the notebook-level variable
        `model` is used, which is what this function always did before
    """
    nrows = len(examples)
    if isinstance(examples, dict):
        labels = list(examples.keys())
        examples = list(examples.values())
    else:
        labels = [None] * nrows
        assert not show_gradcam
    ncols = len(examples[0])
    shape = np.array([nrows, ncols])
    # number of axes each example takes along each dimension
    shape_each = np.array([1, 1])
    shape_each[layout_dim] = int(show_original) + int(show_gradcam)
    fig, axes, frame_width = frame_subplots(*(shape * shape_each), fig_width=fig_width, gridspec_kw=gridspec_kw)
    for i, (label, examples_row) in enumerate(zip(labels, examples)):
        if label is not None:
            searched_token_id = word2idx.get(label, UNK_TOKEN_ID)
            if add_labels:
                # the label goes to the left of the row, vertically centered over the axes of the row
                ax = axes[(i + 1) * shape_each[0] - 1][0]
                dy = 0.5 + (shape_each[0] - 1) * (1 + gridspec_kw['hspace']) / 2
                ax.text(-0.3, dy, label, ha='center', va='center', transform=ax.transAxes, fontsize=label_fontsize)
        else:
            assert not show_gradcam

        for j, example in enumerate(examples_row):
            x, y, y_len, raw_y = example
            y_len = torch.tensor(y_len)
            utterance = raw_y[0]
            img = torch_to_numpy_image(n_inv(x))

            ax_loc = np.array([i, j]) * shape_each

            if show_original:
                ax = axes[ax_loc[0]][ax_loc[1]]
                show_image(ax, img)

            if label is not None:
                searched_word_steps = [idx for idx, y_id in enumerate(y) if y_id == searched_token_id]
                # report the label of this row, not the unrelated notebook-level searched_word
                assert len(searched_word_steps) >= 1, f"\"{label}\" not found in \"{utterance}\""
                if len(searched_word_steps) > 1:
                    print(f"\"{label}\" occurs {len(searched_word_steps)} times in \"{utterance}\"")
                if show_gradcam:
                    cam_model = model if gradcam_model is None else gradcam_model
                    gradcams = gradCAM_for_captioning_lm(cam_model, x, y, y_len, steps=searched_word_steps)
                    # only the first occurrence of the label is shown
                    gradcam = gradcams[0]

                    ax_loc_ = ax_loc.copy()
                    ax_loc_[layout_dim] += 1
                    ax = axes[ax_loc_[0]][ax_loc_[1]]
                    show_image(ax, img, gradcam)

            if add_captions:
                wrap_width = get_wrap_width(fontsize, frame_width * shape_each[1] + gridspec_kw['wspace'] * (shape_each[1] - 1), c=110.)
                wrapped_utterance = '\n'.join(wrap(untokenize(utterance), wrap_width))
                if label is not None:
                    # put the first occurrence of the label in bold; the label is escaped so that
                    # regex metacharacters in it are matched literally
                    wrapped_utterance = re.sub(r'\b'+re.escape(label)+r'\b', r'$\\bf{'+label+r'}$', wrapped_utterance, count=1)
                ax = axes[ax_loc[0]][ax_loc[1]]
                dx = 0.5 + (shape_each[1] - 1) * (1 + gridspec_kw['wspace']) / 2
                add_caption(ax, wrapped_utterance, x=dx, y=-0.025, fontsize=fontsize, wrap_width=wrap_width)


# Hand-picked examples per word, for the training and the validation split.
split = 'train'
dataset = data.datasets[split]
_example_ids_by_split = {
    'train': {
        'ball': [545, 1805, 3928, 1102, 1727, 4588],
        'kitty': [3440, 4884, 9534, 8407, 439, 1389],
        'banana': [3388, 13105, 18350, 2576, 3616, 8582],
        'baby': [30, 1080, 4516, 899, 2168, 6701],
    },
    'val': {
        'ball': [101, 257, 353, 651, 756],
        'kitty': [97, 285, 442, 1634, 1862],
        'banana': [27, 657, 1137, 1419, 1687],
        #'hand': [156, 1322, 1402, 1432, 1513],
        #'sand': [307, 1712, 1763],
        'baby': [256, 495, 1072, 1139, 1497],
    },
}
shown_example_ids = _example_ids_by_split[split]

# label -> the examples shown for that label
shown_examples = {
    label: [dataset[example_i] for example_i in example_ids]
    for label, example_ids in shown_example_ids.items()
}
# small sample: the examples at positions 2 and 3 of the first label only
shown_examples_sample = {
    label: [examples_row[2], examples_row[3]]
    for label, examples_row in itertools.islice(shown_examples.items(), 1)
}

# The loop variable must be called `model`: visualize_frame_subplots reads the notebook-level
# `model` when computing the GradCAMs.
for name, model in gradcam_models.items():
    # full figure
    visualize_frame_subplots(shown_examples, fig_width=7.3)
    output_fig(f'{name} GradCAM Examples')
    # sample figure, original and GradCAM side by side, without row labels
    visualize_frame_subplots(shown_examples_sample, layout_dim=1, add_labels=False, fig_width=5.3)
    output_fig(f'{name} GradCAM Examples Sample')

example #101:
normalizing: [vmin, vmax] = [-0.005486, 0.068414] to [0, 1]
saving plot example101 visual map
you love that ball .
saving plot example101 loss heatmap
LSTM:
  2.0% ball     |  19.4% ,         14.8% .          9.6% one        8.4% book       5.8% ?       

LSTM Captioning:
 14.4% ball     |  16.9% .         14.4% ball       8.1% one        7.4% ,          6.4% ?       

CBOW:
  0.7% ball     |  17.8% <eos>      8.6% one        6.3% is         5.3% <unk>      3.2% way     

example #173:
normalizing: [vmin, vmax] = [-0.008447, 0.131132] to [0, 1]
saving plot example173 visual map
there s a ball and puppies and bunnies and bears and carrots .
saving plot example173 loss heatmap
LSTM:
 10.6% ball     |  10.6% ball       8.3% kitty      7.8% baby       6.0% car        4.4% doggy   

LSTM Captioning:
  5.7% ball     |   9.7% kitty      6.9% car        5.7% ball       4.9% frog       3.6% doggy   

CBOW:
  1.7% ball     |   7.2% <sos>      4.4% <unk>      3.7% kitty      3.2% he

<Figure size 1680x367.5 with 0 Axes>

<Figure size 3360x367.5 with 0 Axes>

<Figure size 2100x367.5 with 0 Axes>

<Figure size 1260x367.5 with 0 Axes>

<Figure size 2100x367.5 with 0 Axes>

<Figure size 2190x3611.27 with 0 Axes>

<Figure size 1590x390.184 with 0 Axes>

In [7]:
# Pick which vector of a token is analysed below, together with its display name:
# index 0 selects 'mean_vector', index 1 (used here) selects 'embedding'.
vector_attr, repres_name = [('mean_vector', 'Mean Representation'), ('embedding', 'Embedding')][1]

In [8]:
# cosine matrices for some tested words

def split_tokens(inp):
    """Split inp on whitespace and return the token id of every token.

    Tokens missing from word2idx are mapped to UNK_TOKEN_ID; a message is printed for every
    token that ends up as UNK_TOKEN_ID.
    """
    def to_token_id(token):
        token_id = word2idx.get(token, UNK_TOKEN_ID)
        if token_id == UNK_TOKEN_ID:
            print(f'mapping {token} to UNK')
        return token_id

    return [to_token_id(token) for token in inp.split()]

def get_items_from_inp(inp, token_items=split_items['train'].token_items):
    """Return the rows of token_items for the whitespace-separated words in inp.

    Raises KeyError when any of the words is mapped to UNK_TOKEN_ID.
    """
    token_ids = split_tokens(inp)
    if token_ids.count(UNK_TOKEN_ID) > 0:
        raise KeyError
    return token_items.loc[token_ids]


# Cosine matrices of the word vectors, then of the differences (vector1 - vector0) of word pairs.
for _header, _extra_kwargs, _word_lists in (
    (
        'cosine matrices:',
        {},
        {
            "color": "red orange yellow green blue purple brown black white",
            "people": "boy girl mommy daddy grandpa grandma",
        },
    ),
    (
        'cosine matrices for the differentiations:',
        {'diff': True},
        {
            "VB-VBZ": "do does go goes play plays get gets have has make makes look looks",
            "VB-VBD": "do did go went play played get got eat ate have had make made look looked fly flew",
            "VB-VBG": "do doing go going play playing get getting eat eating have having look looking fly flying drive driving stand standing crawl crawling",
            "male-female": "boy girl mommy daddy grandpa grandma",
        },
    ),
):
    print(_header)
    print()
    for figname, inp in _word_lists.items():
        try:
            items = get_items_from_inp(inp)
        except KeyError:
            # a word of this list is out of vocabulary: skip the figure
            continue
        plot_vector_sim_heatmap(items, names, vector_attr=vector_attr, one_figure=True, figname=figname, **_extra_kwargs)

cosine matrices:

saving plot color
saving plot people
cosine matrices for the differentiations:

saving plot VB-VBZ
saving plot VB-VBD
saving plot VB-VBG
saving plot male-female


<Figure size 6930x4620 with 0 Axes>

<Figure size 5040x3360 with 0 Axes>

<Figure size 10080x3360 with 0 Axes>

<Figure size 12600x4200 with 0 Axes>

<Figure size 15120x5040 with 0 Axes>

<Figure size 5040x1680 with 0 Axes>

In [9]:
# Projections to compute for the token vectors of every model; only PCA is switched on.
_use_tsne, _use_eigen, _use_pca = False, False, True

for split in used_splits:
    token_items = split_items[split].token_items

    for name in names:
        print(f'{name}:')
        if _use_tsne:
            get_tsne_points(token_items[name], get_attr=vector_attr)
            extend_point_items(token_items, name, 'tsne')
        if _use_eigen:
            get_eigen_points(token_items[name], get_attr=vector_attr)
            extend_point_items(token_items, name, 'eigen')
        if _use_pca:
            get_pca_points(token_items[name], get_attr=vector_attr)
            extend_point_items(token_items, name, 'pca')

LSTM:
PCA done.
LSTM Captioning:
PCA done.
CBOW:
PCA done.
LSTM:
PCA done.
LSTM Captioning:
PCA done.
CBOW:
PCA done.


In [10]:
def get_top_token_items(token_items):
    """Return token_items sorted by 'cnt' in descending order (stable), without the untypical words."""
    by_count = token_items.sort_values('cnt', ascending=False, kind='stable')
    is_untypical = by_count[token_field].isin(untypical_words)
    return by_count[~is_untypical]

def get_pos_items(token_items, pos_field='syntactic category'):
    """Split ``token_items`` into one sub-frame per category of ``pos_field``.

    ``pos_field`` must be a categorical column. Every declared category gets a key,
    in category order, even when no row carries it (the value is then an empty frame).
    """
    used_poses = token_items.dtypes[pos_field].categories
    pos_items = {}
    for pos in used_poses:
        belongs_to_pos = token_items[pos_field] == pos
        pos_items[pos] = token_items[belongs_to_pos]
    return pos_items

# Print, for every split, the most frequent token items per syntactic category,
# then the overall top 100, then a few hand-picked words.
for split in used_splits:
    print(f'{split} split:')

    top_token_items = get_top_token_items(split_items[split].token_items)
    pos_items = get_pos_items(top_token_items)

    # per-category listing: at most the 50 most frequent items of each category
    for pos, items in pos_items.items():
        print(f'number of {pos}s: {len(items)}')
        for _, row in items[:50].iterrows():
            print(row_str(row, names))

    # check some items
    for _, row in top_token_items[:100].iterrows():
        print(row_str(row, names))
    print()
    for word in ['look', 'need', 'draw']:
        try:
            token_id = word2idx[word]
            # NOTE(review): .iterrows() requires .loc to return a DataFrame, i.e. the
            # index must hold token_id more than once; a single match would give a
            # Series and raise AttributeError, which is not caught below -- confirm.
            for _, row in top_token_items.loc[token_id].iterrows():
                print(row_str(row, names))
        except KeyError:
            # word not in the vocabulary or not among the remaining items: skip it
            pass

train split:
number of nouns: 1100
NNP  sam           505:    48.205    36.268    93.581
NN   ball          504:     9.229     4.456    33.163
NN   kitty         471:    13.768     8.270    42.983
NN   baby          390:     7.656     3.228    15.232
NN   book          308:     5.681     4.027    56.617
NN   train         242:    22.360    12.322    56.589
NN   water         235:    36.356    21.865    83.275
NN   time          228:    13.167    11.240    75.908
NN   bear          225:    35.474    22.402    92.249
NN   car           184:    21.185    14.405    71.850
NN   banana        184:    27.397    10.630   111.932
NN   poo           174:    30.004    15.905    55.634
NN   way           173:     8.482     6.555   173.185
NN   truck         163:    12.629     7.204    39.295
NNS  things        160:    25.827    16.167    90.585
NNS  shoes         152:    18.365    10.618    60.477
NN   bunny         145:    57.011    20.520   145.401
NN   bread         143:    32.420    13.497   1

.    !            1819:    12.186     9.430    12.480
IN   on           1768:     6.420     5.590    12.285
UH   okay         1760:    26.081    25.654    18.520
IN   in           1649:     7.861     7.482    17.812
PRP$ your         1621:     9.881     9.879     4.710
VBP  have         1584:    11.989    10.021    10.614
DT   this         1411:    28.523    24.724    15.477
VB   put          1299:    13.290    11.405    16.771
UH   oh           1272:    30.678    31.567    16.537
RB   here         1150:    37.639    29.693    47.358
WP   what         1138:    36.450    31.130    27.280
DT   all          1137:    36.251    27.725    11.401
IN   of           1134:     2.944     2.639     3.731
RB   now          1068:    42.463    41.957    63.051
VB   see          1057:    30.043    25.928    45.373
DT   some          977:    12.806    10.091     7.894
VB   get           878:    18.489    16.741    23.311
VB   let           875:    55.995    51.956     9.321
JJ   right         858:     

IN   of             58:     4.577     5.043     5.627
DT   some           54:    18.276    15.861     8.245
PRP  me             53:     7.483     7.711    15.688
IN   with           43:    12.939    13.284    17.939
PRP  they           43:    50.531    42.412    40.735
IN   at             39:     4.376     4.312     5.886
WRB  where          38:   108.009    83.534    31.618
IN   for            36:    61.177    47.332    47.928
RP   out            36:    19.359    19.446    43.021
RP   up             32:    67.069    74.303   107.420
PRP  he             30:    53.398    40.060    25.260
MD   'll            30:    16.162    14.580    13.116
CC   but            27:   102.054   120.246    73.580
TO   na             27:     1.005     1.005     1.117
IN   if             26:    94.125   118.744    64.724
RP   off            23:    18.599    19.551   177.608
DT   those          23:   148.958   129.611    88.123
MD   will           22:    43.682    50.528    46.950
PRP  them           21:    2

In [11]:
def build_linkage_by_same_value(s):
    """Build a linkage matrix that clusters the positions of ``s`` holding equal values.

    Positions with the same value are merged at distance 0; the resulting groups are
    then merged with each other at distance 1. Missing values (NaN/None/NA) are
    treated as equal to each other and form one group.

    s: pd.Series; its index is ignored, positions 0..len(s)-1 are used as leaf ids.
    Returns a float array of shape (max(len(s) - 1, 0), 4) whose rows are
    (cluster id, cluster id, distance, number of leaves in the merged cluster).
    """
    s = s.reset_index(drop=True).sort_values(kind='stable')
    n = len(s)
    if n == 0:
        # nothing to merge; the original code failed on group[0] here
        return np.empty((0, 4))

    def same_value(a, b):
        # NaN != NaN, so compare missing values explicitly; otherwise the scan below
        # could never advance past a missing value.
        a_missing, b_missing = pd.isna(a), pd.isna(b)
        if a_missing or b_missing:
            return a_missing and b_missing
        return a == b

    # get groups by finding contiguous segments of equal values on the sorted series
    groups = []
    l = 0
    while l < n:
        r = l + 1  # a run always contains its first element, which guarantees progress
        while r < n and same_value(s.iloc[l], s.iloc[r]):
            r += 1
        groups.append(s.index[l:r].tolist())
        l = r

    # initialization: one placeholder entry per leaf so that cluster ids index into Z
    Z = [(idx, idx, 0., 1) for idx in range(n)]

    def merge(idx0, idx1, distance):
        Z.append((idx0, idx1, distance, Z[idx0][3] + Z[idx1][3]))
        return len(Z) - 1

    def merge_group(group, distance):
        # chain-merge all members of the group into a single cluster and return its id
        root = group[0]
        for idx in group[1:]:
            root = merge(root, idx, distance)
        return root

    # merge items within each group
    group_idxes = [merge_group(group, 0.) for group in groups]
    # merge groups
    merge_group(group_idxes, 1.)

    # drop the leaf placeholders; reshape keeps the (k, 4) shape even when k == 0
    return np.array(Z[n:], dtype=float).reshape(-1, 4)


def PearsonRResult_format(result):
    """Format a correlation result as ``(statistic, p=pvalue)``.

    ``result`` only needs ``statistic`` and ``pvalue`` attributes.
    """
    statistic_text = f'{result.statistic:.2f}'
    pvalue_text = f'{result.pvalue:6.4f}'
    return f'({statistic_text}, p={pvalue_text})'


def print_2d_array(array, element_format):
    """Print a 2-D iterable, one row per line.

    Each element is rendered with ``element_format`` and followed by a single space
    (so non-empty lines end with a trailing space).
    """
    for row in array:
        rendered = ''.join(str(element_format(element)) + ' ' for element in row)
        print(rendered)


def print_rsa_results(rsa_results):
    """Print the 2-D array of RSA results, formatting each cell with PearsonRResult_format."""
    print_2d_array(rsa_results, element_format=PearsonRResult_format)


def plot_dendrogram(
    items,
    names,
    baseline_name=None,
    vector_attr='mean_vector',
    heatmap=False,
    annot=False,
    size=0.7,
    color_threshold=None,
    heatmap_linkage=None,
    tag_field='POS tag',
    tag_palette=palette,
    ctg_field='POS tag',
    ctg_palette=palette,
    ll_tag_field='pos',
    ll_with_cnt=True,
    ll_with_ppl=True,
    title=None,
    rc_context={'font.family': 'monospace'},
):
    """linkage clustering and dendrogram plotting

    First plots an RSA heatmap across all models, then for each model a dendrogram
    (average linkage on cosine distance) and, optionally, a clustered cosine
    similarity heatmap. Every figure is handed to output_fig.

    items: pd.DataFrame
    names: the names of the models to plot
    baseline_name: the model whose ppl is subtracted in leaf labels; defaults to names[0]
    vector_attr: use value.vector_attr; default: 'mean_vector'; can be 'embedding'
    heatmap: bool, plot something like plot_sim_heatmap
    annot: passed to the heatmap as `annot` (cell annotations)
    size: not used in the body of this function
    color_threshold: passed through to dendrogram
    heatmap_linkage: the row_linkage and col_linkage used in heatmap; can be any of:
        None: use the linkage result from the vectors of current model
        "first": use the linkage result of the first model
        "tag": build the linkage by clustering items by tag_field
        result from linkage function
    tag_field: the field of tags to obtain the palette of the sidebar of heatmaps
    tag_palette: palette for tag_field
    ctg_field: the field as the category in dendrograms
    ctg_palette: palette for ctg_field
    ll_tag_field: the field of tags to append to leaf labels; set to None or empty string if unwanted
    ll_with_cnt: whether to append cnt to leaf labels
    ll_with_ppl: whether to append ppl to leaf labels
    title: title of the plots
    rc_context: rc context for plots

    NOTE(review): any other string passed as heatmap_linkage leaves Z_heatmap
    unassigned, so the clustermap call would raise NameError.
    """
    from analysis_tools.hierarchy import dendrogram, linkage

    # use the first name as the baseline by default
    if baseline_name is None:
        baseline_name = names[0]

    n_items = len(items)
    # one array of vectors per model, in the order of `names`
    vectors = [get_np_attrs_from_values(items[name], vector_attr) for name in names]

    # NOTE(review): with title=None this becomes "None RSA"; unlike the titles below,
    # the `title is not None` check is not applied here.
    _title = f'{title} RSA'
    ax, rsa_results = plot_rsa_heatmap(vectors, names)
    print_rsa_results(rsa_results)
    if add_titles:
        plt.title(_title)
    output_fig(_title)

    # build color map
    colors = items[tag_field].astype('O').map(tag_palette).tolist()

    # linkages that do not depend on the current model are prepared once, up front
    if heatmap:
        if heatmap_linkage == "tag":  # build Z_heatmap based on tag_field
            Z_heatmap = build_linkage_by_same_value(items[tag_field])
        elif not (heatmap_linkage is None or isinstance(heatmap_linkage, str)):  # use heatmap_linkage
            Z_heatmap = heatmap_linkage

    # width of the longest tag, used to align the tag column in leaf labels
    ll_tag_width = max(map(len, items[ll_tag_field])) if ll_tag_field else 0

    for n, (name, V) in enumerate(zip(names, vectors)):
        print(f'{name}:')
        Z = linkage(V, method='average', metric='cosine')  # of shape (number of merges = n_items - 1, 4)

        def llf(index):
            # leaf label function: indices below n_items are original items,
            # larger indices refer to merged clusters (rows of Z)
            if index < n_items:
                row = items.iloc[index]
                s = row_prefix_str(
                    row,
                    tag_field=ll_tag_field,
                    tag_width=ll_tag_width,
                    sep='  ',
                    with_cnt=ll_with_cnt,
                    align=ll_tag_field,
                )
                if ll_with_ppl:
                    s += '  '
                    ppl = row[name].ppl
                    if name != baseline_name:
                        # show the signed difference to the baseline before the value itself
                        baseline_ppl = row[baseline_name].ppl
                        s += f'{ppl-baseline_ppl:+9.2f}='
                    s += f'{ppl:8.2f}'
                return s
            else:
                merge_index = index - n_items
                return f'{merge_index} {int(Z[merge_index, 3])} {Z[merge_index, 2]:.3f}'

        # ctg_sets[k] is the set of categories under cluster k: leaves first,
        # then one entry per merge in the order of Z
        ctg_sets = [{ctg} for ctg in items[ctg_field]]
        for link in Z:
            ctg_set = ctg_sets[int(link[0])] | ctg_sets[int(link[1])]
            ctg_sets.append(ctg_set)

        def link_color_func(k):
            # a link gets its category color only when everything below it
            # belongs to a single category; mixed clusters are drawn in black
            ctg_set = ctg_sets[k]
            if len(ctg_set) == 1:
                ctg = next(iter(ctg_set))
                return ctg_palette[ctg]
            else:
                return 'black'

        # upper bound on the number of leaves shown by truncate_mode='lastp'
        p = 10000

        plt.figure(figsize=(2 if ll_tag_field else 2.8, 0.15 * min(p, n_items)))
        with plt.rc_context(rc_context):
            dendrogram(
                Z,
                truncate_mode='lastp',
                p=p,
                color_threshold=color_threshold,
                orientation='left',
                leaf_rotation=0.,
                #leaf_font_size=16.,
                leaf_label_func=llf,
                link_color_func=link_color_func,
                color_branch=True,
            )
        plt.xticks(ticks=[], labels=[])

        _title = f"{name} {title} Dendrogram"
        if add_titles and title is not None:
            plt.title(_title)
        output_fig(_title)

        if heatmap:
            # model-dependent linkage choices; with "first", Z_heatmap set at n == 0
            # is kept for all later models
            if heatmap_linkage is None:
                Z_heatmap = Z
            elif heatmap_linkage == "first":
                if n == 0:
                    Z_heatmap = Z

            llf_labels = list(map(llf, range(n_items)))

            matrix = cosine_matrix(V)

            # symmetric color range taken from the off-diagonal entries only
            off_diag = ~np.eye(matrix.shape[0], matrix.shape[1], dtype=bool)
            v = np.max(np.abs(matrix[off_diag]))
            vmin = -v
            vmax = +v

            with plt.rc_context(rc_context):
                g = sns.clustermap(
                    matrix,
                    row_linkage=Z_heatmap,
                    col_linkage=Z_heatmap,
                    figsize=(8, 8),
                    cbar_pos=None,
                    # kwargs for heatmap
                    vmin=vmin, vmax=vmax, center=0,
                    annot=annot, fmt='.2f',
                    xticklabels=llf_labels,
                    yticklabels=llf_labels,
                    row_colors=colors,
                    col_colors=colors,
                    #cbar=False,
                    dendrogram_ratio=0., # remove all dendrograms
                    colors_ratio=0.02,
                )
            g.ax_col_dendrogram.remove()

            _title = f"{name} {title} Similarity Heatmap"
            if add_titles and title is not None:
                plt.title(_title)
            output_fig(_title)


def get_cat_items(items, cats, n_each, cat_field):
    """Select the first ``n_each`` items of every category and label them.

    items: pd.DataFrame of token items; the word is read from ``token_field``
    cats: dict mapping category name -> collection of words in that category
    n_each: number of items to keep per category; categories with fewer
        matching items are dropped entirely
    cat_field: name of the column that receives the category name

    Returns a new DataFrame (the input is not modified) whose ``cat_field`` column
    is categorical. If no category has enough items, an empty frame with the
    ``cat_field`` column is returned.
    """
    dfs = []
    for cat_name, words in cats.items():
        df = items[items[token_field].isin(words)].copy()
        df[cat_field] = cat_name
        if len(df) >= n_each:  # ignore categories without enough items
            dfs.append(df[:n_each])
    if dfs:
        items = pd.concat(dfs)
    else:
        # Copy the empty slice before adding a column: assigning into the bare slice
        # is chained assignment on the caller's frame (SettingWithCopyWarning).
        items = items[:0].copy()
        items[cat_field] = []
    items[cat_field] = items[cat_field].astype('category')
    return items


def item_combinations(token_items, with_all=False):
    """Build the named item subsets used for dendrogram / heatmap plots.

    token_items: pd.DataFrame of token items
    with_all: also include the (large) "All" and "All Noun vs Verb" subsets

    Returns a dict mapping a subset title to a tuple
    (items, tag_field, ctg_field, ll_tag_field, heatmap_linkage).
    """
    from pandas.api.types import CategoricalDtype

    syn_cat_field = 'syntactic category'
    sem_cat_field = 'semantic category'

    top_token_items = get_top_token_items(token_items)
    pos_items = get_pos_items(top_token_items)
    nouns = pos_items['noun']
    verbs = pos_items['verb']

    # top nouns plus verbs relabelled with their subcategory, in a fixed category order
    verb_subcat_items = get_cat_items(verbs, pos_subcats['verb'], 12, syn_cat_field)
    syn_cat_items = pd.concat([nouns[:24], verb_subcat_items])
    syn_cat_dtype = CategoricalDtype(['noun', 'trans. verb', 'intrans. verb'], ordered=True)
    syn_cat_items[syn_cat_field] = syn_cat_items[syn_cat_field].astype(syn_cat_dtype)

    # the 12 most frequent items of every category except the excluded ones
    excluded_poses = ['cardinal number', '.', 'function word']
    all_syn_cat_items = pd.concat(
        [items[:12] for pos, items in pos_items.items() if pos not in excluded_poses]
    )

    # tuples are (items, tag_field, ctg_field, ll_tag_field, heatmap_linkage)
    combinations = {}
    combinations["All Syntactic Categories"] = (
        all_syn_cat_items, syn_cat_field, syn_cat_field, syn_cat_field, 'tag')
    combinations["Syntactic Categories"] = (
        syn_cat_items, syn_cat_field, syn_cat_field, syn_cat_field, 'tag')
    combinations["Semantic Categories"] = (
        get_cat_items(nouns, pos_subcats['noun'], 6, sem_cat_field),
        sem_cat_field, sem_cat_field, sem_cat_field, 'tag')
    combinations["Verb Transitivity"] = (
        get_cat_items(verbs, pos_subcats['verb'], 25, sem_cat_field),
        sem_cat_field, sem_cat_field, sem_cat_field, 'tag')

    if with_all:
        combinations["All"] = (top_token_items, 'POS tag', 'POS tag', 'pos', 'first')
        combinations["All Noun vs Verb"] = (
            pd.concat([nouns, verbs]), 'POS tag', syn_cat_field, 'pos', 'first')
    return combinations


# RSA over all token items of the train split, followed by dendrograms and
# similarity heatmaps for each item subset returned by item_combinations.
for split in ['train']:
    token_items = split_items[split].token_items

    # one array of vectors per model, in the order of `names`
    vectors = [get_np_attrs_from_values(token_items[name], vector_attr) for name in names]

    _title = f'All RSA'
    ax, rsa_results = plot_rsa_heatmap(vectors, names)
    print_rsa_results(rsa_results)
    if add_titles:
        plt.title(_title)
    output_fig(_title)

    for title, (items, tag_field, ctg_field, ll_tag_field, heatmap_linkage) in item_combinations(token_items).items():
        print(f'{title}:')
        with sns.axes_style(heatmap_style, rc={'font.family': [font]}):
            # leaf labels are reduced to the bare word here: the tag, cnt and ppl
            # suffixes are all switched off (ll_tag_field from the tuple is unused)
            plot_dendrogram(
                items,
                names,
                heatmap=True,
                heatmap_linkage=heatmap_linkage,
                title=f'{repres_name} {title}',
                vector_attr=vector_attr,
                tag_field=tag_field,
                ctg_field=ctg_field,
                ll_tag_field=None,#ll_tag_field,
                ll_with_cnt=False,
                ll_with_ppl=False,
                rc_context={'font.family': font},
            )

(1.00, p=0.0000) (0.87, p=0.0000) (0.33, p=0.0000) 
(0.87, p=0.0000) (1.00, p=0.0000) (0.33, p=0.0000) 
(0.33, p=0.0000) (0.33, p=0.0000) (1.00, p=0.0000) 
saving plot All RSA
All Syntactic Categories:
(1.00, p=0.0000) (0.84, p=0.0000) (0.65, p=0.0000) 
(0.84, p=0.0000) (1.00, p=0.0000) (0.64, p=0.0000) 
(0.65, p=0.0000) (0.64, p=0.0000) (1.00, p=0.0000) 
saving plot Embedding All Syntactic Categories RSA
LSTM:
saving plot LSTM Embedding All Syntactic Categories Dendrogram
saving plot LSTM Embedding All Syntactic Categories Similarity Heatmap
LSTM Captioning:
saving plot LSTM Captioning Embedding All Syntactic Categories Dendrogram
saving plot LSTM Captioning Embedding All Syntactic Categories Similarity Heatmap
CBOW:
saving plot CBOW Embedding All Syntactic Categories Dendrogram
saving plot CBOW Embedding All Syntactic Categories Similarity Heatmap
Syntactic Categories:
(1.00, p=0.0000) (0.82, p=0.0000) (0.71, p=0.0000) 
(0.82, p=0.0000) (1.00, p=0.0000) (0.70, p=0.0000) 
(0.71, p=0.0

  self._figure = plt.figure(figsize=figsize)


saving plot LSTM Embedding Verb Transitivity Similarity Heatmap
LSTM Captioning:
saving plot LSTM Captioning Embedding Verb Transitivity Dendrogram
saving plot LSTM Captioning Embedding Verb Transitivity Similarity Heatmap
CBOW:
saving plot CBOW Embedding Verb Transitivity Dendrogram
saving plot CBOW Embedding Verb Transitivity Similarity Heatmap


<Figure size 1500x1500 with 0 Axes>

<Figure size 840x2160 with 0 Axes>

<Figure size 2400x2400 with 0 Axes>

<Figure size 840x2160 with 0 Axes>

<Figure size 2400x2400 with 0 Axes>

<Figure size 840x2160 with 0 Axes>

<Figure size 1500x1500 with 0 Axes>

<Figure size 840x2160 with 0 Axes>

<Figure size 2400x2400 with 0 Axes>

<Figure size 840x2160 with 0 Axes>

<Figure size 2400x2400 with 0 Axes>

<Figure size 840x2160 with 0 Axes>

<Figure size 1500x1500 with 0 Axes>

<Figure size 840x2160 with 0 Axes>

<Figure size 2400x2400 with 0 Axes>

<Figure size 840x2160 with 0 Axes>

<Figure size 2400x2400 with 0 Axes>

<Figure size 840x2160 with 0 Axes>

<Figure size 1500x1500 with 0 Axes>

<Figure size 840x2250 with 0 Axes>

<Figure size 2400x2400 with 0 Axes>

<Figure size 840x2250 with 0 Axes>

<Figure size 2400x2400 with 0 Axes>

<Figure size 840x2250 with 0 Axes>

<Figure size 2400x2400 with 0 Axes>

In [12]:
from scipy.stats import ttest_rel


def get_best_split_accuracy(label_0, label_1):
    """Find the split position that best separates two labels in an ordered sequence.

    A split at position i predicts label 0 for the first i items and label 1 for
    the rest; its score is the number of items whose prediction matches.

    label_0: boolean pd.Series, True where the item has label 0
    label_1: boolean pd.Series aligned with label_0, True where the item has label 1

    Prints the label counts and the best accuracy.
    Returns (best_i, best_accuracy); the earliest best position is kept on ties.
    best_accuracy is NaN when neither label occurs at all.
    """
    # .sum() counts the True entries and, unlike value_counts()[True],
    # does not raise KeyError when a label never occurs
    n_label_0 = int(label_0.sum())
    n_label_1 = int(label_1.sum())
    n_total = n_label_0 + n_label_1

    # split at 0: everything is predicted as label 1
    count_0 = 0
    count_1 = n_label_1
    best_i, best_n = 0, count_0 + count_1

    for i, (label_i_0, label_i_1) in enumerate(zip(label_0, label_1), start=1):
        # moving item i-1 to the label-0 side
        count_0 += label_i_0
        count_1 -= label_i_1

        if count_0 + count_1 > best_n:
            best_i, best_n = i, count_0 + count_1

    best_accuracy = best_n / n_total if n_total else float('nan')
    print(f'n_label_0: {n_label_0}, n_label_1: {n_label_1}, best_accuracy: {best_accuracy:.2%}')
    return best_i, best_accuracy


def plot_ROC(y_true, y_score, pos_label=None, label="", **kwargs):
    """Draw one ROC curve on the current axes.

    y_true, y_score, pos_label: passed to sklearn.metrics.roc_curve
    label: legend text; the AUC is appended to it
    kwargs: forwarded to plt.plot
    """
    from sklearn import metrics

    fpr, tpr, _ = metrics.roc_curve(y_true, y_score, pos_label=pos_label)
    roc_auc = metrics.auc(fpr, tpr)
    curve_label = f"{label} (AUC = {roc_auc:0.2f})"
    plt.plot(fpr, tpr, label=curve_label, **kwargs)


def plot_ROC_end(**kwargs):
    """Finish a ROC figure: draw the diagonal, fix both axes to [0, 1], add labels and legend.

    kwargs: forwarded to plt.plot for the diagonal line
    """
    plt.plot([0, 1], [0, 1], **kwargs)
    ax = plt.gca()
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.0])
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.legend(loc="lower right")


def ttest(a, label, alternative):
    """Run a paired t-test of ``a`` against an all-zero sample and print a summary line.

    a: the values to test
    label: text printed at the start of the summary line
    alternative: passed to scipy.stats.ttest_rel
    Returns the result object of ttest_rel.
    """
    zero_baseline = np.zeros_like(a)
    result = ttest_rel(a, zero_baseline, alternative=alternative)
    summary = (
        f'{label:15} #examples: {len(a):5} mean: {np.mean(a):5.2f} '
        f't-test result: statistic: {result.statistic:6.2f} pvalue: {result.pvalue:5.3f}'
    )
    print(summary)
    return result


def analyze_value(
    items,
    value_attr,
    value_suffix,
    model_names,
    cat_field="syntactic category",
    palette=palette,
    title=None,
    lines=None,
    find_best_threshold=False,
    plotting=True,
    add_pvalue_asterisks=True,
    ROC=False,
    ttest_alternative='two-sided',
):
    """Compare a per-item value across categories, for one or more models.

    Prints the category sizes, draws one box plot of all models side by side, and
    then for each model runs a t-test against zero per category (and over all
    items) and optionally draws a per-category box plot, a KDE plot and ROC curves.

    items: items; a DataFrame with a categorical `cat_field` column and one
        value column per model
    value_attr: the name of the value
    value_suffix: model_name + value_suffix is the field of the value of the model
    model_names: names of the models to analyze
    cat_field: the category field
    palette: the palette for the categories; a dict mapping each category name to a color
    title: base title of the plots; defaults to value_attr
    lines: positions at which to draw reference lines on the value axis, or None
    find_best_threshold: also report the best noun-vs-rest split accuracy;
        requires cat_field == "syntactic category"
    plotting: whether to draw the per-model plots
    add_pvalue_asterisks: append */**/*** to category tick labels according to
        the t-test p-value (<= 0.05 / 0.01 / 0.001)
    ROC: also draw one-vs-rest ROC curves per category (only when plotting)
    ttest_alternative: the `alternative` passed to the t-tests
    """
    model_field = 'model'
    if title is None:
        title = value_attr
    flierprops = dict(marker='o', markersize=2, markerfacecolor='none')

    n_cats = len(items.dtypes[cat_field].categories)
    for cat in items.dtypes[cat_field].categories:
        cat_items = items[items[cat_field] == cat]
        if len(cat_items) > 0:
            print(f'{cat:15} #: {len(cat_items)}')

    # long format: one row per (item, model) so all models share one box plot
    items_long = items.melt(
        id_vars=[cat_field],
        value_vars=[name + value_suffix for name in model_names],
        var_name=model_field,
        value_name=value_attr)
    # recover the model name from the column name; slicing by explicit length
    # (instead of s[:-len(suffix)]) also works when value_suffix is empty
    items_long[model_field] = items_long[model_field].map(lambda s: s[:len(s) - len(value_suffix)])
    g = sns.catplot(
        items_long,
        x=model_field,
        y=value_attr,
        hue=cat_field,
        palette=palette,
        kind="box",
        flierprops=flierprops,
        medianprops=dict(color="white", alpha=0.7),
    )
    if lines:
        ax = plt.gca()
        for line in lines:
            ax.axhline(line, **line_kwargs)
    g.figure.set_size_inches(1 * n_cats, 4)
    if add_titles:
        plt.title(title)
    output_fig(title)

    for name in model_names:
        print(f'{name}:')
        value_field = name + value_suffix
        # sorted by value so that a threshold corresponds to a split position
        cur_items = items.sort_values(value_field)

        if find_best_threshold:
            assert cat_field == "syntactic category"
            label_0 = cur_items[cat_field].isin(["noun"])
            label_1 = ~label_0 #cur_items[cat_field].isin(["function word", "adjective", "adverb", "cardinal number"])
            best_split_i, best_split_acc = get_best_split_accuracy(label_0, label_1)
            print(f'best_split_accuracy: {best_split_acc:.2%}')
            # value at the best split position; computed but not used further below
            threshold = cur_items.iloc[min(best_split_i, len(cur_items)-1)][value_field]

        ttest_results = {}
        for cat in cur_items.dtypes[cat_field].categories:
            ttest_results[cat] = ttest(cur_items.loc[cur_items[cat_field] == cat, value_field], cat, ttest_alternative)
        ttest(cur_items[value_field], "all tokens", ttest_alternative)

        if plotting:
            _title = f'{name} {title}'
            g = sns.catplot(
                cur_items,
                x=value_field,
                y=cat_field,
                palette=palette,
                kind='box',
                flierprops=flierprops,
            )
            if add_pvalue_asterisks:
                ax = plt.gca()
                labels = [item.get_text() for item in ax.get_yticklabels()]
                for i in range(len(labels)):
                    label = labels[i]
                    ttest_result = ttest_results[label]
                    if ttest_result.pvalue <= 0.001:
                        label += '***'
                    elif ttest_result.pvalue <= 0.01:
                        label += '**'
                    elif ttest_result.pvalue <= 0.05:
                        label += '*'
                    labels[i] = label
                ax.set_yticklabels(labels)
            if lines:
                ax = plt.gca()
                for line in lines:
                    ax.axvline(line, **line_kwargs)
            g.figure.set_size_inches(6, 0.2 * n_cats)
            if add_titles:
                plt.title(_title)
            output_fig(_title)

            _title = f'{name} {title} Distribution'
            sns.kdeplot(
                cur_items,
                x=value_field,
                hue=cat_field,
                palette=palette,
                bw_adjust=.5,
            )
            if lines:
                ax = plt.gca()
                for line in lines:
                    ax.axvline(line, **line_kwargs)
            plt.gcf().set_size_inches(6, 3)
            if add_titles:
                plt.title(_title)
            output_fig(_title)

            if ROC:
                for cat in cur_items.dtypes[cat_field].categories:
                    # the value is negated, so lower values score higher for `cat`
                    plot_ROC(
                        cur_items[cat_field],
                        -cur_items[value_field],
                        pos_label=cat,
                        label=f"{cat} vs others",
                        color=palette[cat])
                # separate the suffix with a space (previously "...DistributionROC")
                __title = _title + ' ROC'
                plot_ROC_end()
                if add_titles:
                    plt.title(__title)
                output_fig(__title)


# Configuration for the loss-difference analysis.
# use the first name as the baseline by default
baseline_name = names[0]
# only this model is analysed; the commented-out expression would select every
# non-baseline, non-contrastive model instead
model_names = ['LSTM Captioning']#list(filter(lambda name: name != baseline_name and 'contrastive' not in name.lower(), names))
# syntactic categories excluded from the analysis
remove_cats = ['cardinal number', '.']
value_attr = 'Loss Difference'
value_suffix = diff_field_suffix(baseline_name, 'loss')
ttest_alternative = 'two-sided'
title = 'Loss Difference'


for split in ['val']:
    # type level: one row per word type, restricted to types seen at least twice
    print('Type level:')
    token_items = split_items[split].token_items
    if remove_cats:
        token_items = token_items[~token_items['syntactic category'].isin(remove_cats)].copy()
        # drop the removed categories from the dtype so they do not show up as empty groups
        token_items['syntactic category'] = token_items['syntactic category'].cat.remove_unused_categories()
    analyze_value(
        token_items[token_items['cnt'] >= 2],
        value_attr,
        value_suffix,
        model_names,
        title='Type Level ' + title,
        lines=[0] if add_lines else None,
        ttest_alternative=ttest_alternative,
    )

    # token level: all_token_items may be None, in which case this part is skipped
    all_token_items = split_items[split].all_token_items
    if all_token_items is not None:
        print('=' * 100)
        print('Token level:')
        if remove_cats:
            all_token_items = all_token_items[~all_token_items['syntactic category'].isin(remove_cats)].copy()
            all_token_items['syntactic category'] = all_token_items['syntactic category'].cat.remove_unused_categories()
        analyze_value(
            all_token_items.reset_index(drop=True),
            value_attr,
            value_suffix,
            model_names,
            title='Token Level ' + title,
            lines=[0] if add_lines else None,
            plotting=True,
            ttest_alternative=ttest_alternative,
        )

Type level:
noun            #: 220
verb            #: 150
adjective       #: 44
adverb          #: 45
function word   #: 82
saving plot Type Level Loss Difference
LSTM Captioning:
noun            #examples:   220 mean: -0.51 t-test result: statistic:  -9.44 pvalue: 0.000
verb            #examples:   150 mean: -0.29 t-test result: statistic:  -5.58 pvalue: 0.000
adjective       #examples:    44 mean: -0.14 t-test result: statistic:  -1.75 pvalue: 0.087
adverb          #examples:    45 mean: -0.14 t-test result: statistic:  -2.23 pvalue: 0.031
function word   #examples:    82 mean: -0.13 t-test result: statistic:  -3.25 pvalue: 0.002
all tokens      #examples:   541 mean: -0.33 t-test result: statistic: -11.40 pvalue: 0.000
saving plot LSTM Captioning Type Level Loss Difference
saving plot LSTM Captioning Type Level Loss Difference Distribution
Token level:
noun            #: 1406
verb            #: 2346
adjective       #: 413
adverb          #: 740
function word   #: 4042
saving plot To

<Figure size 1500x1200 with 0 Axes>

<Figure size 1800x900 with 0 Axes>

<Figure size 1500x1200 with 0 Axes>

<Figure size 1800x900 with 0 Axes>

In [13]:
def analyze_subcat_value_diff(
    all_items,
    subcats,
    names,
    baseline_name=None,
    model_names=None,
    subcats_name='Subcategories',
    subcat_name='subcategory',
    word_name='word',
    given_subcat=False,
    plotting_subcats=False,
    ttest_alternative='two-sided'
):
    """Analyze how much probability each model puts on the correct subcategory.

    For every item whose word belongs to one of ``subcats``, computes per model the
    probability of the item's own subcategory given that the word is in any of the
    subcategories, and analyzes the difference to the baseline with analyze_value.
    With ``given_subcat`` it additionally analyzes, per subcategory, the probability
    of the word itself given its subcategory.

    all_items: DataFrame of items with, per model, '<name> prob' and '<name> probs' columns
    subcats: dict mapping subcategory name -> words
    names: all model names (the value columns are computed for each of them)
    baseline_name: baseline model; defaults to names[0]
    model_names: models to analyze; defaults to every non-baseline model whose
        name does not contain 'contrastive'
    subcats_name, subcat_name, word_name: texts used in the plot titles
    given_subcat: whether to run the per-subcategory word-level analysis
    plotting_subcats: whether that per-subcategory analysis also plots
    ttest_alternative: the `alternative` passed on to the t-tests
    """
    # use the first name as the baseline by default
    if baseline_name is None:
        baseline_name = names[0]
    if model_names is None:
        model_names = [
            name for name in names
            if name != baseline_name and 'contrastive' not in name.lower()
        ]

    # vocabulary ids of the in-vocabulary words, per subcategory and overall
    subcat2word_ids = {
        cat: [word2idx[word] for word in words if word in word2idx]
        for cat, words in subcats.items()
    }
    all_subcat_words = [
        word
        for words in subcats.values()
        for word in words
        if word in word2idx
    ]
    all_subcat_word_ids = [word2idx[word] for word in all_subcat_words]

    p_field = 'prob'
    ps_field = 'probs'
    subcat_field = 'subcategory'
    p_cats_field = 'prob cats'
    p_subcat_field = 'prob subcat'
    p_subcat_given_cats_field = 'prob subcat given cats'

    in_any_subcat = all_items[token_field].isin(all_subcat_words)
    items = all_items[in_any_subcat].copy()
    items[subcat_field] = items[token_field].map(word2subcat).astype('category')
    print(f'#: {len(items)}')

    for name in names:
        probs_col = f'{name} {ps_field}'
        cats_col = f'{name} {p_cats_field}'
        subcat_col = f'{name} {p_subcat_field}'
        given_cats_col = f'{name} {p_subcat_given_cats_field}'
        # probability mass on all subcategory words
        items[cats_col] = items[probs_col].map(lambda p: np.sum(p[all_subcat_word_ids]))
        # probability mass on the words of the item's own subcategory
        items[subcat_col] = items.apply(
            lambda item: np.sum(item[probs_col][subcat2word_ids[item[subcat_field]]]),
            axis=1,
        )
        items[given_cats_col] = items[subcat_col] / items[cats_col]
        if name != baseline_name:
            extend_items_value_diff(items, name, baseline_name, p_subcat_given_cats_field)

    analyze_value(
        items.reset_index(drop=True),
        p_subcat_given_cats_field + ' Difference',
        diff_field_suffix(baseline_name, p_subcat_given_cats_field),
        model_names,
        cat_field=subcat_field,
        palette=palette,
        title=f'P({subcat_name} | {subcats_name}) Difference',
        lines=[0] if add_lines else None,
        ttest_alternative=ttest_alternative,
    )

    if not given_subcat:
        return

    base_colors = sns.color_palette('tab20')
    for cat, cat_words in subcats.items():
        cat_items = all_items[all_items[token_field].isin(cat_words)].copy()
        cat_items[token_field] = cat_items[token_field].astype('category')
        print(f'{cat:14} #: {len(cat_items)}')
        word_palette = {
            word: base_colors[i % len(base_colors)]
            for i, word in enumerate(cat_words)
        }
        p_cat_field = f'prob {cat}'
        p_given_cat_field = f'prob given {cat}'
        cat_word_ids = subcat2word_ids[cat]
        for name in names:
            prob_col = f'{name} {p_field}'
            probs_col = f'{name} {ps_field}'
            cat_col = f'{name} {p_cat_field}'
            given_cat_col = f'{name} {p_given_cat_field}'
            # probability mass on this subcategory, then P(word | subcategory)
            cat_items[cat_col] = cat_items[probs_col].map(lambda p: np.sum(p[cat_word_ids]))
            cat_items[given_cat_col] = cat_items[prob_col] / cat_items[cat_col]
            if name != baseline_name:
                extend_items_value_diff(cat_items, name, baseline_name, p_given_cat_field)
        analyze_value(
            cat_items.reset_index(drop=True),
            p_given_cat_field + ' Difference',
            diff_field_suffix(baseline_name, p_given_cat_field),
            model_names,
            cat_field=token_field,
            palette=word_palette,
            title=f'P({word_name} | {cat}) Difference',
            lines=[0] if add_lines else None,
            plotting=plotting_subcats,
            ttest_alternative=ttest_alternative,
        )


# Title text for the subcategory family of each part of speech (keys of pos_subcats).
subcats_name = {
    'noun': 'Noun Semantic Subcategories',
    'verb': 'Verb Transitivity SubCategories'
}


# Token-level subcategory analysis on the validation split, run separately for
# each part of speech that has subcategories.
for split in ['val']:
    for pos, subcats in pos_subcats.items():
        items = split_items[split].all_token_items
        # keep only the tokens of the current part of speech
        items = items[items['syntactic category'].isin([pos])].copy()
        analyze_subcat_value_diff(
            items,
            subcats,
            names,
            baseline_name=baseline_name,
            model_names=model_names,
            subcats_name=subcats_name[pos],
            given_subcat=False,
        )

#: 392
animals         #: 99
body_parts      #: 35
clothing        #: 40
food_drink      #: 74
games_routines  #: 4
household       #: 35
places          #: 10
toys            #: 49
vehicles        #: 46
saving plot P(subcategory | Noun Semantic Subcategories) Difference
LSTM Captioning:
animals         #examples:    99 mean:  0.07 t-test result: statistic:   4.46 pvalue: 0.000
body_parts      #examples:    35 mean:  0.01 t-test result: statistic:   0.23 pvalue: 0.818
clothing        #examples:    40 mean:  0.04 t-test result: statistic:   2.11 pvalue: 0.041
food_drink      #examples:    74 mean:  0.10 t-test result: statistic:   4.20 pvalue: 0.000
games_routines  #examples:     4 mean:  0.22 t-test result: statistic:   1.99 pvalue: 0.141
household       #examples:    35 mean:  0.07 t-test result: statistic:   3.16 pvalue: 0.003
places          #examples:    10 mean:  0.03 t-test result: statistic:   0.59 pvalue: 0.572
toys            #examples:    49 mean:  0.16 t-test result: statist

<Figure size 2700x1200 with 0 Axes>

<Figure size 1800x900 with 0 Axes>

<Figure size 600x1200 with 0 Axes>

<Figure size 1800x900 with 0 Axes>

In [14]:
# Settings for the plots in this cell.
pos_field = "syntactic category"

# the first model is used as the baseline
baseline_name = names[0]
# keyword arguments for the text labels drawn next to points
text_label_kwargs = {'fontsize': 7.8}
# matplotlib rc overrides (font sizes) applied via plt.rc_context below
rc_context = {
    'font.size': 9.,
    'axes.labelsize': 9.,
    'axes.titlesize': 13.,
    'xtick.labelsize': 7.8,
    'ytick.labelsize': 7.8,
    'legend.fontsize': 8.,
    'legend.title_fontsize': 7.8,
}

def get_xylabel(label):
    """Translate a data column name into the axis label shown on a plot.

    * ``"logcnt"`` becomes ``"log frequency"``.
    * A name containing ``"pca <k>"`` becomes ``"principal component <k+1>"``
      (column indices are 0-based, the displayed component number is 1-based).
    * A name containing ``"tsne <k>"`` becomes the empty string, i.e. the axis
      gets no label.
    * Anything else is returned unchanged.
    """
    if label == "logcnt":
        return "log frequency"

    pca_hit = re.search(r"\bpca (\d+)\b", label)
    if pca_hit is not None:
        component_number = int(pca_hit.group(1)) + 1
        return f"principal component {component_number}"

    if re.search(r"\btsne (\d+)\b", label) is not None:
        return ""

    return label


# Draw t-SNE / PCA scatter plots of every model's token items on the training
# split. plt.rc_context scopes the smaller font sizes to figures drawn here.
with plt.rc_context(rc_context):
    for split in ['train']:
        for name in names:
            print(f'{name}:')
            loss_field = f'{name} loss'
            loss_diff_field = diff_field_name(name, baseline_name, 'loss')
            # NOTE(review): `fields` is not read anywhere in this cell; it is kept
            # because notebook-level names may be used by later cells -- confirm
            # before removing.
            fields = [pos_field, "cnt", "logcnt", conc_field, "AnimPhysical", "AnimMental", "Category", "AoA", loss_field, loss_diff_field]

            token_items = split_items[split].token_items
            # Item subsets to plot, keyed by filter name ('' means unfiltered).
            filtered_items = {
                '': token_items,
                'AnimPhysical': token_items.dropna(subset='AnimPhysical'),
                # Vectorised membership test instead of a per-element lambda.
                'det-adj': get_top_token_items(
                    token_items[token_items['POS tag (compressed)'].isin(["DT", "adjective"])]
                )[:100],
            }

            for filter_name, token_items in filtered_items.items():
                for items_name, (items, tag_field, ctg_field, *unused) in item_combinations(token_items, with_all=True).items():
                    # Filtered subsets are only plotted for the 'All' combination.
                    if filter_name != '' and items_name not in ['All']:
                        continue
                    if len(items) == 0:
                        continue
                    # Small item sets get a smaller figure and per-point text labels.
                    small_plot = len(items) <= 100
                    print(f'{items_name}:')
                    text_label = token_field if small_plot else None
                    items = items.copy()
                    items[ctg_field] = items[ctg_field].cat.remove_unused_categories()

                    # Compute t-SNE points for this subset and expose them as
                    # "<name> tsne <k>" columns used by the configs below.
                    convert_attr_for_each(
                        items[name],
                        get_attr=vector_attr,
                        set_attr='tsne_point',
                        converter=functools.partial(get_tsne_points_from_vectors, perplexity=10, n_iter=1000, init='pca')
                    )
                    extend_point_items(items, name, 'tsne')

                    title_prefix = f"{name} {repres_name} {items_name} "
                    # Each config: (hue, palette, x, y, axis_option, title, plot_reg).
                    configs = {
                        '': [
                            (ctg_field, palette, f"{name} tsne 0", f"{name} tsne 1", "off", title_prefix + "t-SNE", False),
                            (ctg_field, palette,  f"{name} pca 0",  f"{name} pca 1",  "on", title_prefix + "PCA", False),
                            ( "logcnt",    None,  f"{name} pca 0",  f"{name} pca 1",  "on", title_prefix + "PCA", False),
                            (ctg_field, palette,  f"{name} pca 0",         "logcnt",  "on", title_prefix + "Correlation between principal component 1 and log frequency", True),
                            (ctg_field, palette,  f"{name} pca 1",  f"{name} pca 2",  "on", title_prefix + "PCA", False),
                        ],
                        'AnimPhysical': [
                            ("AnimPhysical",None,f"{name} tsne 0", f"{name} tsne 1", "off", title_prefix + "t-SNE with AnimPhysical", False),
                        ],
                        'det-adj': [
                            ("syntactic category", palette, f"{name} tsne 0", f"{name} tsne 1", "off", title_prefix + "t-SNE of det-adj", False),
                        ],
                    }[filter_name]
                    if filter_name == '':
                        # Keep only the first (t-SNE) and last config. The surviving
                        # "PCA" plot therefore shows `pca 1` vs `pca 2`, i.e. principal
                        # components 2 and 3 in get_xylabel's 1-based labelling.
                        configs = configs[:1] + configs[-1:]

                    for hue, _palette, x, y, axis_option, title, plot_reg in configs:
                        xlabel = get_xylabel(x)
                        ylabel = get_xylabel(y)
                        # Built explicitly rather than via globals()[key], so the
                        # code does not depend on these names being notebook globals.
                        kwargs = {
                            'x': x,
                            'y': y,
                            'text_label': text_label,
                            'text_label_kwargs': text_label_kwargs,
                            'xlabel': xlabel,
                            'ylabel': ylabel,
                            'axis_option': axis_option,
                        }
                        if _palette is not None:
                            kwargs['palette'] = _palette
                        plt.figure(figsize=(5, 5) if small_plot else (10, 10))
                        plot_wrapper(sns.scatterplot)(items, hue=hue, **kwargs)
                        if add_titles:
                            plt.title(title)
                        output_fig(title)
                        plt.close()
                        # NOTE(review): no config with plot_reg=True survives the
                        # slice above, so this branch does not run as configured.
                        # It forwards the scatter kwargs (including 'palette') to
                        # regplot -- verify plot_wrapper accepts them before enabling.
                        if plot_reg:
                            _title = title + ' regression'
                            plot_wrapper(sns.regplot)(items, **kwargs)
                            if add_titles:
                                plt.title(_title)
                            output_fig(_title)
                            plt.close()


    # Investigating correlations between concreteness and loss diff.
    # NOTE: ['val'][:0] is an empty list, so this loop is currently disabled;
    # drop the [:0] slice to run it.
    for split in ['val'][:0]:
        # Every model except the baseline is compared against the baseline.
        for name in filter(lambda name: name != baseline_name, names):
            print(f'{name}:')
            loss_field = f'{name} loss'
            loss_diff_field = diff_field_name(name, baseline_name, 'loss')

            token_items = split_items[split].token_items
            # Only tokens that have a concreteness value can be plotted.
            token_items = token_items.dropna(subset=conc_field)
            token_items = token_items[token_items['cnt'] >= 0]
            for items_name, (items, tag_field, ctg_field, *unused) in item_combinations(token_items, with_all=True).items():
                if items_name not in ['All']:
                    continue
                figname_prefix = f'{name} {items_name} '
                print(f'{items_name}:')
                text_label = token_field if len(items) <= 100 else None

                loss_diff_items = items.sort_values(loss_diff_field)
                _title = figname_prefix + 'Concreteness vs loss diff'
                plot_wrapper(sns.scatterplot)(loss_diff_items, x=conc_field, y=loss_diff_field, hue=ctg_field, palette=palette, text_label=text_label, text_label_kwargs=text_label_kwargs)
                if add_lines:
                    plt.gca().axhline(0, **line_kwargs)
                if add_titles:
                    plt.title(_title)
                output_fig(_title)
                plt.close()
                # Disabled: regression of loss diff on concreteness over all words.
                if False:
                    _title = figname_prefix + 'Concreteness vs loss diff all words'
                    plot_wrapper(sns.regplot)(loss_diff_items, x=conc_field, y=loss_diff_field, text_label=text_label, text_label_kwargs=text_label_kwargs)
                    if add_lines:
                        plt.gca().axhline(0, **line_kwargs)
                    if add_titles:
                        plt.title(_title)
                    output_fig(_title)
                    plt.close()
                    # Disabled: loss diff per syntactic category.
                    if False:
                        _title = figname_prefix + f'{pos_field} vs loss diff all words'
                        sns.catplot(loss_diff_items, x=pos_field, y=loss_diff_field, color="b",) #kind="violin", inner="stick",
                        if add_lines:
                            plt.gca().axhline(0, **line_kwargs)
                        if add_titles:
                            plt.title(_title)
                        # Was `output_Fig`, an undefined name (NameError if enabled).
                        output_fig(_title)
                        plt.close()
                # Disabled: same scatter restricted to the words with the highest
                # loss diff (items reversed, limited through n_items).
                if False:
                    print('highest:')
                    _title = figname_prefix + 'Concreteness vs loss diff highest words'
                    plot_wrapper(sns.scatterplot)(loss_diff_items[::-1], x=conc_field, y=loss_diff_field, hue=ctg_field, palette=palette, n_items=n_items, text_label=text_label, text_label_kwargs=text_label_kwargs)
                    if add_lines:
                        plt.gca().axhline(0, **line_kwargs)
                    if add_titles:
                        plt.title(_title)
                    output_fig(_title)
                    plt.close()

LSTM:
All Syntactic Categories:




T-SNE done.
saving plot LSTM Embedding All Syntactic Categories t-SNE
saving plot LSTM Embedding All Syntactic Categories PCA
Syntactic Categories:




T-SNE done.
saving plot LSTM Embedding Syntactic Categories t-SNE
saving plot LSTM Embedding Syntactic Categories PCA
Semantic Categories:




T-SNE done.
saving plot LSTM Embedding Semantic Categories t-SNE
saving plot LSTM Embedding Semantic Categories PCA
Verb Transitivity:




T-SNE done.
saving plot LSTM Embedding Verb Transitivity t-SNE
saving plot LSTM Embedding Verb Transitivity PCA
All:




T-SNE done.
saving plot LSTM Embedding All t-SNE
saving plot LSTM Embedding All PCA
All Noun vs Verb:




T-SNE done.
saving plot LSTM Embedding All Noun vs Verb t-SNE
saving plot LSTM Embedding All Noun vs Verb PCA
All:




T-SNE done.
saving plot LSTM Embedding All t-SNE with AnimPhysical
All:




T-SNE done.
saving plot LSTM Embedding All t-SNE of det-adj
LSTM Captioning:
All Syntactic Categories:




T-SNE done.
saving plot LSTM Captioning Embedding All Syntactic Categories t-SNE
saving plot LSTM Captioning Embedding All Syntactic Categories PCA
Syntactic Categories:




T-SNE done.
saving plot LSTM Captioning Embedding Syntactic Categories t-SNE
saving plot LSTM Captioning Embedding Syntactic Categories PCA
Semantic Categories:




T-SNE done.
saving plot LSTM Captioning Embedding Semantic Categories t-SNE
saving plot LSTM Captioning Embedding Semantic Categories PCA
Verb Transitivity:




T-SNE done.
saving plot LSTM Captioning Embedding Verb Transitivity t-SNE
saving plot LSTM Captioning Embedding Verb Transitivity PCA
All:




T-SNE done.
saving plot LSTM Captioning Embedding All t-SNE
saving plot LSTM Captioning Embedding All PCA
All Noun vs Verb:




T-SNE done.
saving plot LSTM Captioning Embedding All Noun vs Verb t-SNE
saving plot LSTM Captioning Embedding All Noun vs Verb PCA
All:




T-SNE done.
saving plot LSTM Captioning Embedding All t-SNE with AnimPhysical
All:




T-SNE done.
saving plot LSTM Captioning Embedding All t-SNE of det-adj
CBOW:
All Syntactic Categories:




T-SNE done.
saving plot CBOW Embedding All Syntactic Categories t-SNE
saving plot CBOW Embedding All Syntactic Categories PCA
Syntactic Categories:




T-SNE done.
saving plot CBOW Embedding Syntactic Categories t-SNE
saving plot CBOW Embedding Syntactic Categories PCA
Semantic Categories:




T-SNE done.
saving plot CBOW Embedding Semantic Categories t-SNE
saving plot CBOW Embedding Semantic Categories PCA
Verb Transitivity:




T-SNE done.
saving plot CBOW Embedding Verb Transitivity t-SNE
saving plot CBOW Embedding Verb Transitivity PCA
All:




T-SNE done.
saving plot CBOW Embedding All t-SNE
saving plot CBOW Embedding All PCA
All Noun vs Verb:




T-SNE done.
saving plot CBOW Embedding All Noun vs Verb t-SNE
saving plot CBOW Embedding All Noun vs Verb PCA
All:




T-SNE done.
saving plot CBOW Embedding All t-SNE with AnimPhysical
All:




T-SNE done.
saving plot CBOW Embedding All t-SNE of det-adj
