In [1]:
# Move the working directory up one level; later cells use relative paths
# such as `plots/` (presumably relative to the project root -- confirm).
%cd ..

/misc/vlgscratch4/LakeGroup/wentao/multimodal-baby


In [2]:
import itertools
import functools
import copy
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
import pandas as pd
import seaborn as sns
from torchvision.transforms.functional import resized_crop
from multimodal.multimodal_lit import MultiModalLitModel
from multimodal.multimodal_data_module import PAD_TOKEN_ID, UNK_TOKEN_ID, SOS_TOKEN_ID, EOS_TOKEN_ID, normalizer
from ngram import NGramModel
from analysis_tools.utils import *
from analysis_tools.pos_tags import *
from analysis_tools.word_categories import *
from analysis_tools.token_items_data import *
from analysis_tools.plotting import *
import analysis_tools.plotting as plotting
from analysis_tools.multimodal_visualization import *
from analysis_tools.processing import *
from analysis_tools.build_data import *
from analysis_tools.checkpoints import *


# Notebook-wide plotting defaults; individual plots may override them.
figsize = (8, 7)

# seaborn 'paper' context with custom font sizes
paper_context = sns.plotting_context('paper')
paper_context.update({
    'font.size': 10.,
    'axes.labelsize': 10.,
    'axes.titlesize': 14.,
    'legend.title_fontsize': 9.6,
})
paper_context.update(dict.fromkeys(
    ('xtick.labelsize', 'ytick.labelsize', 'legend.fontsize'), 8.8))

# relation plots keep only the left and bottom spines
_left_bottom_spines = {
    'axes.spines.left': True,
    'axes.spines.bottom': True,
    'axes.spines.right': False,
    'axes.spines.top': False,
}
unticked_relation_style = sns.axes_style('white')
unticked_relation_style.update(_left_bottom_spines)
ticked_relation_style = sns.axes_style('ticks')
ticked_relation_style.update(_left_bottom_spines)

# heatmaps start from the unticked style and drop the remaining spines too
heatmap_style = copy.copy(unticked_relation_style)
heatmap_style.update({'axes.spines.left': False, 'axes.spines.bottom': False})

font = 'serif'
sns.set_theme(
    context=paper_context,
    style=unticked_relation_style,
    palette=sns.color_palette('tab20'),
    font=font,
    rc={'figure.figsize': figsize},
)

# console display of arrays and frames
np.set_printoptions(suppress=True, precision=2, linewidth=120)
pd.set_option('display.width', 120)

# resolution and bounding box of rendered / saved figures
plt.rcParams.update({
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.,
})

# output format of saved figures and whether figures are saved or shown
plot_format = 'pdf'
saving_fig = True

# Backend / inline-format selection for the chosen plot format.
if plot_format == 'pgf':
    matplotlib.use('pgf')
    plt.rcParams.update({
        "pgf.texsystem": "pdflatex",
        "pgf.preamble": "\n".join([
             r"\usepackage[T1]{fontenc}",
        ]),
    })
elif plot_format == 'svg':
    from matplotlib_inline.backend_inline import set_matplotlib_formats
    set_matplotlib_formats('svg')

# `plotting.output_fig` is the hook called after each plot is drawn:
# it either saves the current figure under plots/ or shows it inline.
if saving_fig:
    def _save_fig(fname, format='png'):
        """Save the current figure to plots/<fname>.<format> and clear it.

        fname: file name without extension
        format: file extension / format passed through the output path
        """
        import os  # local import: only needed when figures are saved

        out_path = f'plots/{fname}.{format}'
        # savefig does not create missing directories; make sure the target exists
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        print(f'saving plot {fname}')
        plt.savefig(out_path, transparent=True)
        plt.clf()
    plotting.output_fig = functools.partial(_save_fig, format=plot_format)
else:
    plotting.output_fig = lambda fname: plt.show()

output_fig = plotting.output_fig

# run models on GPU when available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# load model from checkpoint

# pick the dataset, then the names of the checkpoints to load for it
dataset_name = "saycam"
ori_names = {
    # three seeds (0, 1, 2) for each of the three model families
    "saycam": [
        f"{family} {seed}"
        for family in ("LSTM", "LSTM Captioning", "CBOW")
        for seed in range(3)
    ],
    "coco": [
        "lm",
        "capt_ft",
        "capt_attn_gt_ft",
        "capt_attn_gt_reg_ft",
        "capt_attn_gt_reg_untie_ft",
        "cbow",
    ],
}[dataset_name]

# resolve every selected name to its checkpoint path
dataset_checkpoint_paths = all_checkpoint_paths[dataset_name]
ori_checkpoint_paths = [dataset_checkpoint_paths[selected_name] for selected_name in ori_names]

# one loaded model per entry of ori_checkpoint_paths, in the same order
ori_models = []

# data module; built from the args of the first neural checkpoint loaded
data = None

for checkpoint_path in ori_checkpoint_paths:
    if "gram" in checkpoint_path:
        # n-gram entries are built from the training data rather than loaded;
        # they need `data` and `vocab_size`, which only exist once a neural
        # checkpoint has been processed in the other branch.
        assert data is not None, (
            f"n-gram model {checkpoint_path} must come after a neural checkpoint, "
            "which provides the data module and vocab_size")
        # the text before the first '-' is parsed as an int and passed as the
        # first argument (presumably the n-gram order -- confirm)
        ngram_model = build_ngram_model(int(checkpoint_path.split('-')[0]), vocab_size, data.train_dataloader())
        model = ngram_model

    else:
        # debugging switch: dump the hyper-parameters stored in the checkpoint
        print_dict_args = False
        if print_dict_args:
            ckpt = torch.load(checkpoint_path, map_location=device)
            print(ckpt['hyper_parameters']['args'])

        print(f"load model from {checkpoint_path}")
        lit_model = MultiModalLitModel.load_from_checkpoint(checkpoint_path, map_location=device)
        #print(lit_model.args)
        lit_model.to(device)

        if data is None:
            # build data and vocab according to the model
            data, args = build_data(args=lit_model.args, return_args=True)
            dataset_name = args.dataset

            word2idx = lit_model.text_encoder.word2idx
            idx2word = lit_model.text_encoder.idx2word

            vocab = lit_model.text_encoder.vocab
            vocab_size = len(vocab)
            print(f'vocab_size = {vocab_size}')
            # check consistency between vocab and idx2word
            for idx in range(vocab_size):
                assert idx in idx2word

        else:
            # every further checkpoint must match the dataset of the first one;
            # the message reports the dataset of the offending checkpoint
            model_dataset = lit_model.args["dataset"]
            assert model_dataset == dataset_name, f"checkpoint {checkpoint_path} ran on a different dataset {model_dataset}"

        lit_model.eval()
        model = lit_model

    ori_models.append(model)


# Every entry of `names` stands for a group of loaded models whose predictions
# are averaged; e.g. with 'lm' in the list, all models named 'lm*' are
# aggregated into 'lm'.
names = (
    ['LSTM', 'LSTM Captioning', 'CBOW', 'Contrastive', 'Joint bs16', 'Joint bs512'][:-3]
    if True else ori_names
)

# group name -> indices into ori_models
groups = {name: [] for name in names}

for i, ori_name in enumerate(ori_names):
    # a model belongs to the longest group name that is a prefix of its name
    prefix_matches = [candidate for candidate in names if ori_name.startswith(candidate)]
    best_name = max(prefix_matches, key=len, default='')
    groups[best_name].append(i)

# the first model of every group acts as the group's representative
models = []
for name, group in groups.items():
    assert group, f"no models corresponds to {name}"
    representative = ori_models[group[0]]
    models.append(representative)

load model from checkpoints/lm_dataset_saycam_captioning_False_text_encoder_lstm_embedding_dim_512_dropout_i_0.5_dropout_o_0.0_batch_size_16_lr_0.006_lr_scheduler_True_weight_decay_0.04_seed_0/epoch=29.ckpt


  rank_zero_deprecation(


Using base transforms
Calling prepare_data!
SAYCam transcripts have already been downloaded. Skipping this step.
Transcripts have already been renamed. Skipping this step.
Transcripts have already been preprocessed. Skipping this step.
Training frames have already been extracted. Skipping this step.
Training metadata files have already been created. Skipping this step.
Shuffled training metadata file has already been created. Skipping this step.
Evaluation frames have already been filtered. Skipping this step.
Evaluation frames have already been extracted. Skipping this step.
Filtered evaluation frames have already been extracted. Skipping this step.
Evaluation metadata files have already been created. Skipping this step.
Evaluation metadata files have already been created. Skipping this step.
Extra evaluation metadata files have already been created. Skipping this step.
Extra evaluation metadata files have already been created. Skipping this step.
Vocabulary file already exists. Skipp

In [4]:
# get sum values (counts, vector representations, losses) across the training set

# split name -> aggregated items of that split; filled by the per-split loop in this cell
split_items = {}

# splits that are aggregated and analysed
used_splits = ['train', 'val']


def print_items(items, n=20):
    """Print the first n rows of items followed by three summary rows.

    The summary rows sum the 'cnt' column and the per-model columns in `names`
    over all rows whose first index level is not PAD_TOKEN_ID; the second row
    additionally leaves out the SOS rows and the third also the EOS rows.
    They are labelled 'ppl', 'ppl_wo_sos' and 'ppl_wo_sos_eos'.
    """
    for _, row in items.head(n).iterrows():
        print(row_str(row, names))

    print()

    columns = ['cnt', *names]

    def emit_total(label):
        # lay the running sums out like a regular row so row_str can format them
        summary = pd.Series(running_total, index=items.columns)
        summary[token_field] = label
        print(row_str(summary, names))

    not_pad = items.index.map(lambda key: key[0] != PAD_TOKEN_ID)
    running_total = items.loc[not_pad, columns].sum(axis=0)
    emit_total('ppl')

    # successively remove the contribution of the special tokens
    for special_id, label in ((SOS_TOKEN_ID, 'ppl_wo_sos'), (EOS_TOKEN_ID, 'ppl_wo_sos_eos')):
        running_total -= items.loc[(special_id,), columns].sum(axis=0)
        emit_total(label)


def remove_foils_wrapper(dataloader):
    """Yield the batches of dataloader with x reduced to x[:, 0].

    Each batch must be a 4-tuple (x, y, y_len, raw_y); only index 0 along the
    second axis of x is kept (presumably the target image among the foils --
    confirm), the other three fields are passed through unchanged.
    """
    yield from (
        (images[:, 0], tokens, lengths, raw_text)
        for images, tokens, lengths, raw_text in dataloader
    )

# batch size used when iterating the train / val / test splits for analysis
my_batch_size = 256
# split name -> zero-argument factory returning a fresh iterable over that split.
# Factories are used because each consumer calls them again to get a new pass
# over the data; `data` and `my_batch_size` are looked up when a factory is called.
dataloader_fns = {
    # fixed order and no dropped last batch
    'train': lambda: data.train_dataloader(batch_size=my_batch_size, shuffle=False, drop_last=False),
    # val/test dataloaders are indexed: [0] is used for the plain splits, [1] for the eval_* splits
    'val': lambda: data.val_dataloader(batch_size=my_batch_size)[0],
    'test': lambda: data.test_dataloader(batch_size=my_batch_size)[0],
    # eval_* loaders are passed through remove_foils_wrapper, which keeps x[:, 0] of every batch
    'eval_val': lambda: remove_foils_wrapper(data.val_dataloader()[1]),
    'eval_test': lambda: remove_foils_wrapper(data.test_dataloader()[1]),
}

# Build the aggregated items of every used split and store them in split_items.
for split in used_splits:
    dataloader_fn = dataloader_fns[split]

    pos_tags = get_pos_tags(dataloader_fn(), dataset_name, split)

    # items of every individual model; the result is wrapped by torch_cache with a
    # per-checkpoint, per-split path (presumably an on-disk cache -- confirm);
    # all_token_items are ignored for the train split
    ori_model_items = [
        torch_cache(checkpoint_path + f'.{split}.cache.pt')(get_model_items)(
            model, dataloader_fn(), pos_tags, ignore_all_token_items=(split == 'train'))
        for model, checkpoint_path in zip(ori_models, ori_checkpoint_paths)]
    # combine the items of the models within each group (one entry per group, in `names` order)
    model_items = [mean_model_items([ori_model_items[i] for i in group], idx=-1) for group in groups.values()]
    # transpose from per-group tuples to per-field lists and stack each field across groups
    items = ModelItems(*[stack_items(items_list, names, idx2word) for items_list in list(zip(*model_items))])
    extend_items(items.token_items, names, idx2word)
    extend_items(items.token_pos_items, names, idx2word)
    if items.all_token_items is not None:
        extend_items(items.all_token_items, names, idx2word)
    split_items[split] = items

    print_items(items.token_items)

    # when all_token_items is available, add one '<group name> probs' column per group
    if split_items[split].all_token_items is not None:
        for name, group in groups.items():
            ori_probs = []
            for i in group:
                model = ori_models[i]
                # n-gram models are left out of the probability average
                if isinstance(model, NGramModel):
                    continue
                probs = get_model_probs(model, dataloader_fns[split](), pos_tags)
                ori_probs.append(probs)
            # NOTE(review): ori_probs is empty if a group holds only n-gram models -- confirm mean_probs handles that
            probs = mean_probs(ori_probs)
            split_items[split].all_token_items[f'{name} probs'] = probs

load cached pos tags: dataset_cache/saycam/train.pos.cache
load from checkpoints/lm_dataset_saycam_captioning_False_text_encoder_lstm_embedding_dim_512_dropout_i_0.5_dropout_o_0.0_batch_size_16_lr_0.006_lr_scheduler_True_weight_decay_0.04_seed_0/epoch=29.ckpt.train.cache.pt
load from checkpoints/lm_dataset_saycam_captioning_False_text_encoder_lstm_embedding_dim_512_dropout_i_0.5_dropout_o_0.0_batch_size_16_lr_0.006_lr_scheduler_True_weight_decay_0.04_seed_1/epoch=38.ckpt.train.cache.pt
load from checkpoints/lm_dataset_saycam_captioning_False_text_encoder_lstm_embedding_dim_512_dropout_i_0.5_dropout_o_0.0_batch_size_16_lr_0.006_lr_scheduler_True_weight_decay_0.04_seed_2/epoch=28.ckpt.train.cache.pt
load from checkpoints/lm_dataset_saycam_captioning_True_text_encoder_lstm_embedding_dim_512_dropout_i_0.5_dropout_o_0.0_batch_size_16_lr_0.006_lr_scheduler_True_weight_decay_0.04_seed_0/epoch=29.ckpt.train.cache.pt
load from checkpoints/lm_dataset_saycam_captioning_True_text_encoder_lstm_embe

100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:02<00:00,  3.28it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:02<00:00,  3.46it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:02<00:00,  3.53it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:04<00:00,  1.77it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:04<00:00,  1.76it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:04<00:00,  1.76it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:02<00:00,  3.71it/s]
100%|██████████████████████████████████████████████████

In [5]:
# plot statistics of syntactic categories

pos_tag_dfs = []

splits = ['train', 'val', 'test']

# one row per token, holding its POS tag and the split it comes from
for split in splits:
    dataloader_fn = dataloader_fns[split]
    pos_tags = get_pos_tags(dataloader_fn(), dataset_name, split)

    pos_tag_df = pd.DataFrame(data=list(itertools.chain.from_iterable(pos_tags)), columns=['pos'])
    pos_tag_df['split'] = split
    pos_tag_dfs.append(pos_tag_df)

pos_tag_df = pd.concat(pos_tag_dfs)

# add one categorical column per POS mapping
for pos_field, pos_mapping in pos_mappings.items():
    pos_tag_df[pos_field] = pos_tag_df['pos'].map(pos_mapping).astype('category')

pos_field = 'syntactic category'

g = sns.catplot(kind='count', data=pos_tag_df, x='split', hue=pos_field, palette=pos_palette)
g.figure.set_size_inches(*figsize)

for split in splits:
    print(f'{split:5s} #examples: {len(data.datasets[split]):5d}')

# row masks and token counts per split, computed once and reused below
split_masks = {split: pos_tag_df['split'] == split for split in splits}
split_cnts = {split: int(mask.sum()) for split, mask in split_masks.items()}

for split in splits:
    print(f'{split:5s} #tokens: {split_cnts[split]:6d}')

for split, pos in itertools.product(splits, pos_tag_df.dtypes[pos_field].categories):
    split_cnt = split_cnts[split]
    split_pos_cnt = int((split_masks[split] & (pos_tag_df[pos_field] == pos)).sum())
    print(f'{split:5s} {pos:15s} #tokens: {split_pos_cnt:6d} {split_pos_cnt / split_cnt :7.2%}')

# Label every bar with the count it actually displays. seaborn draws grouped bars
# one hue level at a time, so the order of g.ax.patches does not follow
# product(splits, categories); pairing them positionally puts counts on the wrong bars.
for bar in g.ax.patches:
    bar_height = bar.get_height()
    # skip degenerate patches (zero width or no height), which are not count bars
    if bar.get_width() == 0 or not np.isfinite(bar_height):
        continue
    g.ax.text(bar.get_x() + bar.get_width()/2., bar_height + 40, f'{int(round(bar_height)):d}', ha="center", fontsize='xx-small')

output_fig('dataset distribution')

load cached pos tags: dataset_cache/saycam/train.pos.cache
load cached pos tags: dataset_cache/saycam/val.pos.cache
load cached pos tags: dataset_cache/saycam/test.pos.cache
train #examples: 33737
val   #examples:  1874
test  #examples:  1875
train #tokens: 292475
val   #tokens:  16103
test  #tokens:  16168
train .               #tokens: 122736  41.96%
train adjective       #tokens:   8190   2.80%
train adverb          #tokens:  13679   4.68%
train cardinal number #tokens:   1247   0.43%
train function word   #tokens:  75610  25.85%
train noun            #tokens:  27132   9.28%
train verb            #tokens:  43881  15.00%
val   .               #tokens:   6882  42.74%
val   adjective       #tokens:    419   2.60%
val   adverb          #tokens:    757   4.70%
val   cardinal number #tokens:     81   0.50%
val   function word   #tokens:   4117  25.57%
val   noun            #tokens:   1446   8.98%
val   verb            #tokens:   2401  14.91%
test  .               #tokens:   6733  41.64%
t

<Figure size 2400x2100 with 0 Axes>

In [6]:
# --- settings for the example visualization in this cell ---
# word to look for in the examples; True asks for it interactively, a falsy value disables the search
searched_word = "ball"
# whether to draw GradCAM / attention maps at all
visualize_models = True
# stop after this many visualized examples
n_print_example = 5
# models used for text generation; the `if False` switch currently disables generation
textgen_model_ns = [model_n for model_n, (name, model) in enumerate(zip(names, models)) if 'attn' in name] if False else []
# whether to also visualize the crops produced by get_views
multiple_views = False
# visualize every step of the utterance instead of only the searched-word steps
all_steps = False

# indices into `models`: GradCAM for captioning or attention models, attention maps for attention models
gradcam_model_ns = [model_n for model_n, model in enumerate(models) if model.text_encoder.captioning or model.text_encoder.has_attention]
attn_model_ns = [model_n for model_n, model in enumerate(models) if model.text_encoder.has_attention]
if not visualize_models:
    gradcam_model_ns = []
    attn_model_ns = []
# number of map rows per example (one per GradCAM model plus one per attention model)
n_visualized_models = len(gradcam_model_ns) + len(attn_model_ns)

if searched_word:
    if searched_word is True:
        # NOTE(review): interactive input makes the cell non-reproducible under Run All
        searched_word = input("search word: ")
    # words missing from the vocabulary fall back to the UNK token
    searched_token_id = word2idx.get(searched_word, UNK_TOKEN_ID)
    if searched_token_id == UNK_TOKEN_ID:
        print(f"mapping {searched_word} to UNK")


def get_views(x, grid=(2, 2)):
    """Return several views of image x, each passed through `normalizer`.

    The first view is n_inv(x) itself (n_inv presumably undoes the input
    normalization -- confirm); it is followed by one crop per cell of a
    grid[0] x grid[1] grid in row-major order, each resized back to the full
    image size.
    """
    img = n_inv(x)
    cell_h = img.size(-2) // grid[0]
    cell_w = img.size(-1) // grid[1]
    full_size = img.shape[-2:]
    crops = [
        resized_crop(img, row * cell_h, col * cell_w, cell_h, cell_w, full_size)
        for row, col in itertools.product(range(grid[0]), range(grid[1]))
    ]
    return [normalizer(view) for view in [img, *crops]]


# plot_image with this notebook's fixed options (overlying=True, blur=False);
# callers pass the axes, the image and optionally a map and a text label
show_image = functools.partial(
    plot_image,
    overlying=True,
    blur=False,
    #interpolation='nearest',
)


def visualize_example(x, y, y_len, raw_y, steps=None, model_first=True, prepend_x=False, use_losses=None, example_name='example'):
    """Visualize an example.

    Draws the GradCAM / attention maps of the models in gradcam_model_ns and
    attn_model_ns for the chosen steps, then a per-step loss heatmap of the
    models, and finally (when the models are run here) prints the top predicted
    tokens at the chosen steps. Both figures are handed to output_fig.

    Inputs:
        x: the image of the example
        y: token ids of the utterance
        y_len: length of the utterance; .item() is called on it when steps is None
        raw_y: raw utterance(s); raw_y[0] is printed
        steps: list of steps to visualize; None for all steps
        model_first: if True, then the axes are of n_visualized_models * len(steps), else it is transposed
        prepend_x: if True, prepend raw image x before the models
        use_losses: use the designated losses; must be a list of losses, where each losses is a list of loss at each step;
            for example, split_items[split].losses[example_i]; if None, use the losses generated by running the models
        example_name: prefix of the names passed to output_fig
    """
    img = torch_to_numpy_image(n_inv(x))
    y_labels = [idx2word[y_id.item()] for y_id in y]
    if steps is None:
        steps = list(range(y_len.item()))

    # one row per visualized model (plus the raw image), one column per step
    n_rows, n_cols = int(prepend_x) + n_visualized_models, len(steps)
    if not model_first:
        n_rows, n_cols = n_cols, n_rows
    n_axes = n_rows * n_cols
    ax_size = 5
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(ax_size * n_cols, ax_size * n_rows), squeeze=False)
    if not model_first:
        # transpose so that iteration below is always model-major
        axes = [[axes[j][i] for j in range(n_rows)] for i in range(n_cols)]
    # axes are consumed one by one in drawing order
    axes_iter = itertools.chain.from_iterable(axes)

    if prepend_x:
        for step_i, step in enumerate(steps):
            show_image(next(axes_iter), img)

    for model_n in gradcam_model_ns:
        model, name = models[model_n], names[model_n]
        gradcams = gradCAM_for_captioning_lm(model, x, y, y_len, steps=steps)
        prefix = f'{name} GradCAM'
        for step_i, (step, gradcam) in enumerate(zip(steps, gradcams)):
            if step == 0:
                # step 0 shows the plain image with the model label, without a map
                show_image(next(axes_iter), img, text=prefix)
            else:
                text = y_labels[step]
                # the first drawn panel of a model carries the model label
                if step_i == 0:
                    text = prefix + ' ' + text
                show_image(next(axes_iter), img, gradcam, text=text)

    for model_n in attn_model_ns:
        model, name = models[model_n], names[model_n]
        attns = attention_for_attention_lm(model, x, y, y_len, steps=steps)
        prefix = f'{name} attn'
        for step_i, (step, attn) in enumerate(zip(steps, attns)):
            text = y_labels[step]
            if step_i == 0:
                text = prefix + ' ' + text
            show_image(next(axes_iter), img, attn, text=text)

    # hide the axes that were not drawn on
    for ax in axes_iter:
        ax.axis("off")
    output_fig(example_name + ' visual map')

    if use_losses is None:
        with torch.no_grad():
            rets = [run_model(model, x, y, y_len, single_example=True, return_all=True) for model in models]
            # element 1 of each return value is used as the per-step losses
            losses = [ret[1].cpu().numpy() for ret in rets]
    else:
        losses = use_losses
    print(raw_y[0])
    # NOTE(review): the 'contrastive' filter is case-sensitive; a model named 'Contrastive' would not be excluded -- confirm intent
    names_, losses_ = zip(*[(name, loss) for name, loss in zip(names, losses) if 'contrastive' not in name])
    plot_model_y_value_heatmap(names_, losses_, y_labels)
    output_fig(example_name+' loss heatmap')

    # top predictions can only be printed when the models were run here (rets exists)
    if use_losses is None:
        for name, model, ret in zip(names, models, rets):
            if not isinstance(model, MultiModalLitModel) or 'contrastive' in name:
                continue
            print(f'{name}:')
            if model.language_model.text_encoder.regressional:
                # for these models the steps are shifted by one and step 0 is dropped
                steps_ = [step - 1 for step in steps if step > 0]
            else:
                steps_ = steps
            # elements 2 and 4 of the return value are used as logits and labels
            logits, labels = ret[2], ret[4]
            probs = logits.softmax(-1)
            print_top_values(probs, idx2word, labels, steps=steps_, value_formatter=prob_formatter)
            print()


#examples = examples_from_dataloader(dataloader_fns['val']())

example_i = 0
# number of examples visualized so far
print_example_i = 0

# Walk the validation set and visualize the first n_print_example examples
# (those containing the searched word when a search word is set).
#for example_i, (x, y, y_len, raw_y) in enumerate(examples):
dataset = data.datasets['val']
for example_i in range(len(dataset)):
    x, y, y_len, raw_y = dataset[example_i]
    y_len = torch.tensor(y_len)
    # keep only the first y_len tokens
    y = y[:y_len]

    if searched_word:
        # positions of the searched token in the utterance; skip examples without it
        searched_word_steps = [idx for idx, y_id in enumerate(y) if y_id == searched_token_id]
        if not searched_word_steps:
            continue
    
    print(f'example #{example_i}:')

    if all_steps:
        steps = None
    else:
        steps = searched_word_steps if searched_word else [0]

    for x_view_i, x_view in enumerate(get_views(x) if multiple_views else [x]):
        visualize_example(x_view, y, y_len, raw_y, steps=steps, model_first=all_steps, prepend_x=not all_steps, example_name=f'example{example_i}')

    # optionally generate text with beam search and visualize the generated sequence
    for model_n in textgen_model_ns:
        name = names[model_n]
        print(f"generating text from {name}:")
        model = models[model_n]
        image_features, image_feature_map = model.model.encode_image(x.unsqueeze(0).to(device))
        beam_seq, log_prob = model.language_model.beam_search_decode(
            batch_size=1,
            beam_width=model.beam_width,
            decode_length=model.decode_length,
            length_penalty_alpha=model.length_penalty_alpha,
            image_features=image_features if model.language_model.text_encoder.captioning else None,
            image_feature_map=image_feature_map if model.language_model.text_encoder.has_attention else None,
        )
        gen_text_ids = beam_seq[0, 0]
        # strip trailing padding to obtain the generated length
        gen_text_len = len(gen_text_ids)
        while gen_text_len > 0 and gen_text_ids[gen_text_len - 1] == PAD_TOKEN_ID:
            gen_text_len -= 1
        gen_text_len = torch.tensor(gen_text_len, device=device)
        gen_text_labels = [idx2word[y_id.item()] for y_id in gen_text_ids]
        gen_text = ' '.join(gen_text_labels)
        visualize_example(x, gen_text_ids, gen_text_len, [gen_text], steps=None, model_first=True, prepend_x=False, example_name=f'example{example_i} {name}')

    print_example_i += 1
    if print_example_i >= n_print_example:
        break


from textwrap import wrap
import re

# For every GradCAM-capable model, draw a grid of hand-picked validation
# examples: one row per word, one column per example, with the GradCAM map of
# the step at which the word occurs and the utterance printed below the image.

# word -> indices of the validation examples shown for it
shown_examples = {
    'ball': [101, 257, 353, 651, 756],
    'kitty': [97, 285, 442, 1634, 1862],
    'banana': [27, 657, 1137, 1419, 1687],
    #'hand': [156, 1322, 1402, 1432, 1513],
    #'sand': [307, 1712, 1763],
    'baby': [256, 495, 1072, 1139, 1497],
}
n_rows = len(shown_examples)
n_cols = 5
ax_size = 2.2
for model_n in gradcam_model_ns:
    model, name = models[model_n], names[model_n]
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(ax_size * n_cols, ax_size * n_rows), squeeze=False)
    for (searched_word, example_ids), axes_line in zip(shown_examples.items(), axes):
        searched_token_id = word2idx.get(searched_word, UNK_TOKEN_ID)
        # Escape the word so that regex metacharacters in it are matched literally,
        # and substitute through a function so that the replacement text is not
        # interpreted as a regex template.
        word_pattern = re.compile(r'\b' + re.escape(searched_word) + r'\b')
        bold_word = r'$\bf{' + searched_word + r'}$'
        for example_i, ax in zip(example_ids, axes_line):
            x, y, y_len, raw_y = dataset[example_i]
            y_len = torch.tensor(y_len)
            utterance = raw_y[0]
            img = torch_to_numpy_image(n_inv(x))
            searched_word_steps = [idx for idx, y_id in enumerate(y) if y_id == searched_token_id]
            # the picked examples are expected to contain the word exactly once
            assert len(searched_word_steps) == 1, (
                f"expected exactly one occurrence of {searched_word!r} in example {example_i}, "
                f"found {len(searched_word_steps)}")
            gradcams = gradCAM_for_captioning_lm(model, x, y, y_len, steps=searched_word_steps)
            step = searched_word_steps[0]
            gradcam = gradcams[0]
            show_image(ax, img, gradcam)
            # wrap the utterance and render the searched word in bold (mathtext)
            wrapped_utterance = '\n'.join(wrap(utterance, 40))
            wrapped_utterance = word_pattern.sub(lambda match: bold_word, wrapped_utterance)
            ax.text(0.5, -0.05, wrapped_utterance, ha='center', va='top', transform=ax.transAxes, fontsize='xx-small')
        # label the row with its word, to the left of the first image
        ax = axes_line[0]
        ax.text(-0.3, 0.5, searched_word, ha='center', va='top', transform=ax.transAxes, fontsize='medium')
    output_fig(f'{name} GradCAM Examples')

example #101:
normalizing: [vmin, vmax] = [-0.005486, 0.068414] to [0, 1]
saving plot example101 visual map
you love that ball .
saving plot example101 loss heatmap
LSTM:
  2.0% ball     |  19.4% ,         14.8% .          9.6% one        8.4% book       5.8% ?       

LSTM Captioning:
 14.4% ball     |  16.9% .         14.4% ball       8.1% one        7.4% ,          6.4% ?       

CBOW:
  0.7% ball     |  17.8% <eos>      8.6% one        6.3% is         5.3% <unk>      3.2% way     

example #173:
normalizing: [vmin, vmax] = [-0.008447, 0.131133] to [0, 1]
saving plot example173 visual map
there s a ball and puppies and bunnies and bears and carrots .
saving plot example173 loss heatmap
LSTM:
 10.6% ball     |  10.6% ball       8.3% kitty      7.8% baby       6.0% car        4.4% doggy   

LSTM Captioning:
  5.7% ball     |   9.7% kitty      6.9% car        5.7% ball       4.9% frog       3.6% doggy   

CBOW:
  1.7% ball     |   7.2% <sos>      4.4% <unk>      3.7% kitty      3.2% he

<Figure size 1680x367.5 with 0 Axes>

<Figure size 3360x367.5 with 0 Axes>

<Figure size 2100x367.5 with 0 Axes>

<Figure size 1260x367.5 with 0 Axes>

<Figure size 2100x367.5 with 0 Axes>

<Figure size 3300x2640 with 0 Axes>

In [7]:
# Token representation used by the analyses below, as (attribute name, display name);
# the trailing index picks one of the two options.
vector_attr, repres_name = (
    ('mean_vector', 'Mean Representation'),
    ('embedding', 'Embedding'),
)[1]

In [8]:
# cosine matrices for some tested words

def split_tokens(inp):
    """Split inp on whitespace and return the token id of every word.

    Words missing from word2idx are mapped to UNK_TOKEN_ID, and a message is
    printed for every word that ends up as UNK.
    """
    def to_token_id(word):
        word_id = word2idx.get(word, UNK_TOKEN_ID)
        if word_id == UNK_TOKEN_ID:
            print(f'mapping {word} to UNK')
        return word_id

    return [to_token_id(word) for word in inp.split()]

def get_items_from_inp(inp, token_items=split_items['train'].token_items):
    """Return the rows of token_items for the whitespace-separated words in inp.

    token_items defaults to the token items of the train split; the default is
    bound once, when this function is defined.
    Raises KeyError if any word maps to UNK_TOKEN_ID.
    """
    token_ids = split_tokens(inp)
    has_unknown_word = UNK_TOKEN_ID in token_ids
    if not has_unknown_word:
        return token_items.loc[token_ids]
    raise KeyError


# cosine matrices
print('cosine matrices:')
print()
for figname, inp in {
    "color": "red orange yellow green blue purple brown black white",
    "people": "boy girl mommy daddy grandpa grandma",
}.items():
    try:
        items = get_items_from_inp(inp)
    except KeyError:
        # a word is unknown or has no row in the token items: report the skipped
        # figure instead of dropping it silently
        print(f'skipping plot {figname}: not all words are available')
        continue
    plot_vector_sim_heatmap(items, names, vector_attr=vector_attr, one_figure=True, figname=figname)

# cosine matrices for the differentiations (vector1 - vector0)
print('cosine matrices for the differentiations:')
print()
# the words are listed in pairs (word0 word1 word0 word1 ...)
for figname, inp in {
    "VB-VBZ": "do does go goes play plays get gets have has make makes look looks",
    "VB-VBD": "do did go went play played get got eat ate have had make made look looked fly flew",
    "VB-VBG": "do doing go going play playing get getting eat eating have having look looking fly flying drive driving stand standing crawl crawling",
    "male-female": "boy girl mommy daddy grandpa grandma",
}.items():
    try:
        items = get_items_from_inp(inp)
    except KeyError:
        print(f'skipping plot {figname}: not all words are available')
        continue
    plot_vector_sim_heatmap(items, names, diff=True, vector_attr=vector_attr, one_figure=True, figname=figname)

cosine matrices:

saving plot color
saving plot people
cosine matrices for the differentiations:

saving plot VB-VBZ
saving plot VB-VBD
saving plot VB-VBG
saving plot male-female


<Figure size 6930x4620 with 0 Axes>

<Figure size 5040x3360 with 0 Axes>

<Figure size 10080x3360 with 0 Axes>

<Figure size 12600x4200 with 0 Axes>

<Figure size 15120x5040 with 0 Axes>

<Figure size 5040x1680 with 0 Axes>

In [9]:
# Projections computed for the token vectors of every model, in this order;
# each entry is (name passed to extend_point_items, function computing the points).
point_builders = (
    ('tsne', get_tsne_points),
    ('eigen', get_eigen_points),
    ('pca', get_pca_points),
)

for split in used_splits:
    token_items = split_items[split].token_items

    for n, name in enumerate(names):
        print(f'{name}:')
        for point_kind, build_points in point_builders:
            # compute the points from the chosen representation, then expose them on token_items
            build_points(token_items[name], get_attr=vector_attr)
            extend_point_items(token_items, name, point_kind)

LSTM:




T-SNE done.
SVD done.
PCA done.
LSTM Captioning:




T-SNE done.
SVD done.
PCA done.
CBOW:




T-SNE done.
SVD done.
PCA done.
LSTM:




T-SNE done.
SVD done.
PCA done.
LSTM Captioning:




T-SNE done.
SVD done.
PCA done.
CBOW:




T-SNE done.
SVD done.
PCA done.


In [10]:
# Per split:
#   split_pos_items[split][pos]                -> frequency-sorted items of one syntactic category
#   split_pos_pos_items[split][(pos_i, pos_j)] -> items of two categories, interleaved
#   split_all_pos_items[split]                 -> items of all categories, interleaved
split_pos_items, split_pos_pos_items, split_all_pos_items = {}, {}, {}

for split in used_splits:
    print(f'{split} split:')

    top_token_items = split_items[split].token_items.sort_values('cnt', ascending=False, kind='stable')   # sort by cnt, most frequent first
    top_token_items = top_token_items[~top_token_items[token_field].isin(untypical_words)]  # remove untypical words
    pos_field = 'syntactic category'
    used_poses = top_token_items.dtypes[pos_field].categories
    pos_items = {pos: top_token_items[top_token_items[pos_field] == pos] for pos in used_poses}
    split_pos_items[split] = pos_items
    # print the size and the 50 most frequent items of every category
    for pos in used_poses:
        items = pos_items[pos]
        print(f'number of {pos}s: {len(items)}')
        for _, row in items[:50].iterrows():
            print(row_str(row, names))

    # number of consecutive items taken from one category before switching to the other
    interleaving_step = 1

    pos_pos_items = {}
    for i_pos in range(len(used_poses)):
        for j_pos in range(len(used_poses)):
            if i_pos != j_pos:
                pos_i = used_poses[i_pos]
                pos_i_items = pos_items[pos_i]
                pos_j = used_poses[j_pos]
                pos_j_items = pos_items[pos_j]
                interleaved_dfs = []
                # start value so that the else clause below yields i == 0 when the loop body never runs
                i = -interleaving_step
                for i in range(0, min(len(pos_i_items), len(pos_j_items)), interleaving_step):
                    interleaved_dfs.append(pos_i_items[i:i+interleaving_step])
                    interleaved_dfs.append(pos_j_items[i:i+interleaving_step])
                else:
                    # the loop never breaks, so this always runs: move past the last
                    # interleaved chunk and append what is left of both categories
                    i += interleaving_step
                    interleaved_dfs.append(pos_i_items[i:])
                    interleaved_dfs.append(pos_j_items[i:])
                pos_pos_items[(pos_i, pos_j)] = pd.concat(interleaved_dfs)
    split_pos_pos_items[split] = pos_pos_items

    # interleave all categories; slices past the end of a shorter category are empty
    interleaved_dfs = []
    for i in range(0, max(map(len, pos_items.values())), interleaving_step):
        for pos in used_poses:
            items = pos_items[pos]
            interleaved_dfs.append(items[i:i+interleaving_step])
    all_pos_items = pd.concat(interleaved_dfs)
    split_all_pos_items[split] = all_pos_items

    # check some items
    for _, row in top_token_items[:100].iterrows():
        print(row_str(row, names))
    print()
    for word in ['look', 'need', 'draw']:
        try:
            token_id = word2idx[word]
            for _, row in top_token_items.loc[token_id].iterrows():
                print(row_str(row, names))
        except KeyError:
            # words missing from the vocabulary or from the items are not printed
            pass

train split:
number of .s: 222
.    <sos>       33737:     1.000     1.000     9.869
.    <eos>       33370:     1.462     1.409     2.862
.    .           16566:     3.050     2.802     1.817
,    ,            9479:     5.927     5.832     6.801
.    ?            8420:     2.382     2.105     3.315
UH   yeah         4923:     8.659     8.512     4.786
``   "            2043:    13.632     5.386    20.110
.    !            1819:    12.186     9.430    12.480
UH   okay         1760:    26.081    25.654    18.520
UH   oh           1272:    30.678    31.567    16.537
UH   ok            812:    62.735    53.224    52.453
,    ...           748:    45.401    35.485   125.832
UH   no            668:    67.770    52.507    50.516
UH   alright       642:    75.060    69.066    38.375
,    -             486:    18.118    13.404    23.628
UH   huh           253:    13.403    10.324    28.507
UH   well          250:    82.771    70.561   112.284
UH   yea           244:   149.527    51.817    55.7

.    <sos>       33737:     1.000     1.000     9.869
.    <eos>       33370:     1.462     1.409     2.862
.    .           16566:     3.050     2.802     1.817
,    ,            9479:     5.927     5.832     6.801
PRP  you          9235:     5.436     5.131     3.463
.    ?            8420:     2.382     2.105     3.315
DT   the          5841:     5.862     5.071     3.255
UH   yeah         4923:     8.659     8.512     4.786
TO   to           4540:     2.349     2.202     2.492
DT   a            4539:     6.554     5.162     3.592
PRP  it           4435:     9.466     8.726     8.141
CC   and          3813:    10.125     8.347     9.804
DT   that         3742:    17.595    15.068     9.711
PRP  we           3273:    12.192    10.776     5.585
PRP  i            3032:    19.200    18.311     4.120
EX   there        2921:    18.062    14.930    12.888
VBP  do           2850:    17.917    16.139     8.508
VBP  want         2816:     4.856     4.714     4.144
``   "            2043:    1

.    <sos>        1874:     1.000     1.000     9.994
.    <eos>        1857:     1.556     1.526     3.042
.    .             941:     3.684     3.575     1.886
,    ,             543:     7.732     8.272     8.293
PRP  you           512:     5.923     5.827     3.533
.    ?             452:     2.990     2.758     3.444
DT   the           311:     8.950     8.374     4.620
UH   yeah          251:     9.072     9.071     4.668
PRP  it            242:    11.236    10.751     9.312
DT   a             238:     7.951     6.829     3.719
DT   that          226:    24.554    22.002    10.493
TO   to            220:     3.882     3.904     3.425
CC   and           220:    16.097    15.848    14.038
PRP  we            190:    16.151    14.660     5.782
EX   there         167:    24.307    20.305    13.323
PRP  i             162:    21.951    21.078     4.043
VBP  do            153:    21.250    20.322    10.201
``   "             129:    21.941    10.461    22.343
VBP  want          123:     

In [11]:
def plot_dendrogram(
    items,
    names,
    vector_attr='mean_vector',
    heatmap=False,
    annot=False,
    size=0.7,
    color_threshold=None,
    heatmap_linkage=None,
    tag_field='POS tag',
    ctg_field='POS tag',
    ll_tag_field='pos',
    ll_with_cnt=True,
    ll_with_ppl=True,
    title=None
):
    """linkage clustering and dendrogram plotting

    First plots one representational-similarity heatmap over all models, then,
    for each model, an average-linkage / cosine-distance dendrogram and
    (optionally) a clustered cosine-similarity heatmap. Every figure is handed
    to output_fig under a name built from the model name and title.

    items: pd.DataFrame
    names: the names of the models to plot
    vector_attr: use value.vector_attr; default: 'mean_vector'; can be 'embedding'
    heatmap: bool, plot something like plot_sim_heatmap
    annot: passed to the heatmap; whether to write the value in each cell
    size: currently unused; kept for backward compatibility
    color_threshold: passed to scipy's dendrogram
    heatmap_linkage: the row_linkage and col_linkage used in heatmap; can be any of:
        None: use the linkage result from the vectors of current model
        "first": use the linkage result of the first model
        "tag": build the linkage by clustering items by tag_field
        result from linkage function
      any other string raises ValueError when heatmap is True
    tag_field: the field of tags to obtain the palette of the sidebar of heatmaps
    ctg_field: the field as the category in dendrograms
    ll_tag_field: the field of tags to append to leaf labels; set to None or empty string if unwanted
    ll_with_cnt: whether to append cnt to leaf labels
    ll_with_ppl: whether to append ppl to leaf labels
    title: title of the plots; when None it is left out of the figure names
    """
    from scipy.cluster.hierarchy import dendrogram, linkage

    def join_title(*parts):
        # Drop a missing title instead of formatting it as the literal "None".
        return ' '.join(str(part) for part in parts if part is not None)

    n_items = len(items)
    vectors = [get_np_attrs_from_values(items[name], vector_attr) for name in names]

    title_ = join_title(title, 'RSA')
    plot_repres_sim_heatmap(vectors, names, title=title_)
    output_fig(title_)

    # build color map
    colors = items[tag_field].astype('O').map(get_palette(tag_field)).tolist()

    # Resolve the heatmap linkage. The type is checked before any comparison so
    # that a precomputed linkage array is never compared to a string.
    Z_heatmap = None
    if heatmap:
        if isinstance(heatmap_linkage, str):
            if heatmap_linkage == "tag":  # build Z_heatmap based on tag_field
                Z_heatmap = build_linkage_by_same_value(items[tag_field])
            elif heatmap_linkage != "first":
                raise ValueError(f'unknown heatmap_linkage: {heatmap_linkage!r}')
        elif heatmap_linkage is not None:  # use heatmap_linkage
            Z_heatmap = heatmap_linkage

    # Keyword arguments shared by every leaf label. The tag width is only
    # computed when a tag field is requested, so ll_tag_field=None works.
    leaf_label_kwargs = dict(tag_field=ll_tag_field, sep='  ', with_cnt=ll_with_cnt)
    if ll_tag_field:
        leaf_label_kwargs['tag_width'] = max(map(len, items[ll_tag_field]), default=0)

    for n, (name, V) in enumerate(zip(names, vectors)):
        print(f'{name}:')
        Z = linkage(V, method='average', metric='cosine')  # of shape (number of merges = n_items - 1, 4)

        def llf(index):
            if index < n_items:
                return row_llf(
                    items.iloc[index],
                    name=name if ll_with_ppl else None,
                    baseline_name=None if n == 0 else names[0],
                    **leaf_label_kwargs,
                )
            else:
                # internal node: merge index, cluster size, merge distance
                merge_index = index - n_items
                return f'{merge_index} {int(Z[merge_index, 3])} {Z[merge_index, 2]:.3f}'

        # A link gets its category color only if every leaf below it shares
        # that category; mixed clusters are drawn in black.
        ctg_palette = get_palette(ctg_field)
        ctg_sets = [{ctg} for ctg in items[ctg_field]]
        link_colors = []
        for link in Z:
            ctg_set = ctg_sets[int(link[0])] | ctg_sets[int(link[1])]
            ctg_sets.append(ctg_set)
            if len(ctg_set) == 1:
                ctg = next(iter(ctg_set))
                link_color = ctg_palette[ctg]
            else:
                link_color = 'black'
            link_colors.append(link_color)

        p = 10000

        plt.figure(figsize=(4, 0.15 * min(p, n_items))) # 0.1
        dendrogram(
            Z,
            truncate_mode='lastp',
            p=p,
            color_threshold=color_threshold,
            orientation='left',
            leaf_rotation=0.,
            #leaf_font_size=16.,
            leaf_label_func=llf,
            link_color_func=lambda k: link_colors[k - n_items],
        )

        title_ = join_title(name, title, 'Dendrogram')
        if title is not None:
            plt.title(title_)
        #plt.show()
        output_fig(title_)

        if heatmap:
            if heatmap_linkage is None:
                Z_heatmap = Z
            elif isinstance(heatmap_linkage, str) and heatmap_linkage == "first" and n == 0:
                # keep the first model's linkage for all later models
                Z_heatmap = Z

            prefix_labels = [row_llf(row, tag_field=tag_field, with_cnt=False) for _, row in items.iterrows()]
            llf_labels = list(map(llf, range(n_items)))

            matrix = cosine_matrix(V)

            # symmetric color range taken from the off-diagonal entries
            off_diag = ~np.eye(matrix.shape[0], matrix.shape[1], dtype=bool)
            v = np.max(np.abs(matrix[off_diag]))
            vmin = -v
            vmax = +v

            g = sns.clustermap(
                matrix,
                row_linkage=Z_heatmap,
                col_linkage=Z_heatmap,
                figsize=(8, 8),
                cbar_pos=None,
                # kwargs for heatmap
                vmin=vmin, vmax=vmax, center=0,
                annot=annot, fmt='.2f',
                xticklabels=prefix_labels,
                yticklabels=prefix_labels,
                row_colors=colors,
                col_colors=colors,
                square=True,
                #cbar=False,
                dendrogram_ratio=0., # remove all dendrograms
                colors_ratio=0.02,
            )
            g.ax_col_dendrogram.remove()

            title_ = join_title(name, title, 'Similarity Heatmap')
            if title is not None:
                plt.title(title_)
            output_fig(title_)


def get_subcat_items(items, pos, n_items_from_each_cat=None):
    """Select the first items of each sufficiently large subcategory of `pos`.

    items: pd.DataFrame of items; read through the global `token_field` column
    pos: key into the global `pos_subcats`; when n_items_from_each_cat is None
        it must be 'noun' (6 items per subcategory) or 'verb' (25)
    n_items_from_each_cat: number of leading rows to keep per subcategory;
        subcategories with fewer matching rows are dropped entirely

    Returns a new DataFrame with an added categorical `subcat_field` column
    mapped from the token through the global `word2subcat`. If no subcategory
    qualifies, an empty frame with the same columns is returned.
    """
    if n_items_from_each_cat is None:
        n_items_from_each_cat = {'noun': 6, 'verb': 25}[pos]
    dfs = []
    for cat_name, words in pos_subcats[pos].items():
        df = items[items[token_field].isin(words)].copy()
        if len(df) >= n_items_from_each_cat:
            dfs.append(df[:n_items_from_each_cat])
    # pd.concat raises ValueError on an empty list, so fall back to an empty
    # copy that keeps the column layout.
    selected = pd.concat(dfs) if dfs else items.iloc[:0].copy()
    selected[subcat_field] = selected[token_field].map(word2subcat).astype('category')
    return selected


# Plot RSA heatmaps, dendrograms and similarity heatmaps for each item set.
for split in ['train']:
    for title, items, tag_field, ctg_field, ll_tag_field, heatmap_linkage in (
        ('Most Frequent Words', split_all_pos_items[split][:100], 'POS tag', 'POS tag', 'pos', 'first'),
        ('Noun vs Verb', split_pos_pos_items[split][('noun', 'verb')][:50], 'POS tag', 'syntactic category', 'pos', 'first'),
        ('Semantic Categories', get_subcat_items(split_pos_items[split]['noun'], 'noun'), subcat_field, subcat_field, subcat_field, 'tag'),
        ('Verb Transitivity', get_subcat_items(split_pos_items[split]['verb'], 'verb'), subcat_field, subcat_field, subcat_field, 'tag'),
    )[1:]:  # the first entry ('Most Frequent Words') is skipped
        # Announce the item set being plotted. The previous `name` was never
        # bound in this cell and only resolved through stale kernel state.
        print(f'{title}:')
        with sns.axes_style(heatmap_style, rc={'font.family': [font]}):
            plot_dendrogram(
                items,
                names,
                heatmap=True,
                heatmap_linkage=heatmap_linkage,
                title=f'{repres_name} {title}',
                vector_attr=vector_attr,
                tag_field=tag_field,
                ctg_field=ctg_field,
                ll_tag_field=ll_tag_field,
                ll_with_cnt=False,
                ll_with_ppl=False,
            )

CBOW:
saving plot Embedding Noun vs Verb RSA
LSTM:
saving plot LSTM Embedding Noun vs Verb Dendrogram




saving plot LSTM Embedding Noun vs Verb Similarity Heatmap
LSTM Captioning:
saving plot LSTM Captioning Embedding Noun vs Verb Dendrogram




saving plot LSTM Captioning Embedding Noun vs Verb Similarity Heatmap
CBOW:
saving plot CBOW Embedding Noun vs Verb Dendrogram




saving plot CBOW Embedding Noun vs Verb Similarity Heatmap
CBOW:
saving plot Embedding Semantic Categories RSA
LSTM:
saving plot LSTM Embedding Semantic Categories Dendrogram




saving plot LSTM Embedding Semantic Categories Similarity Heatmap
LSTM Captioning:
saving plot LSTM Captioning Embedding Semantic Categories Dendrogram




saving plot LSTM Captioning Embedding Semantic Categories Similarity Heatmap
CBOW:
saving plot CBOW Embedding Semantic Categories Dendrogram




saving plot CBOW Embedding Semantic Categories Similarity Heatmap
CBOW:
saving plot Embedding Verb Transitivity RSA
LSTM:
saving plot LSTM Embedding Verb Transitivity Dendrogram




saving plot LSTM Embedding Verb Transitivity Similarity Heatmap
LSTM Captioning:
saving plot LSTM Captioning Embedding Verb Transitivity Dendrogram




saving plot LSTM Captioning Embedding Verb Transitivity Similarity Heatmap
CBOW:
saving plot CBOW Embedding Verb Transitivity Dendrogram




saving plot CBOW Embedding Verb Transitivity Similarity Heatmap


<Figure size 1500x1500 with 0 Axes>

<Figure size 1200x2250 with 0 Axes>

<Figure size 2400x2400 with 0 Axes>

<Figure size 1200x2250 with 0 Axes>

<Figure size 2400x2400 with 0 Axes>

<Figure size 1200x2250 with 0 Axes>

<Figure size 1500x1500 with 0 Axes>

<Figure size 1200x2160 with 0 Axes>

<Figure size 2400x2400 with 0 Axes>

<Figure size 1200x2160 with 0 Axes>

<Figure size 2400x2400 with 0 Axes>

<Figure size 1200x2160 with 0 Axes>

<Figure size 1500x1500 with 0 Axes>

<Figure size 1200x2250 with 0 Axes>

<Figure size 2400x2400 with 0 Axes>

<Figure size 1200x2250 with 0 Axes>

<Figure size 2400x2400 with 0 Axes>

<Figure size 1200x2250 with 0 Axes>

<Figure size 2400x2400 with 0 Axes>

In [12]:
from scipy.stats import ttest_rel


def get_best_split_accuracy(label_0, label_1):
    """Find the split position that best separates two boolean label series.

    Both series are walked in parallel in their given order. For a split after
    the first i items, the items before the split are predicted as label 0 and
    the rest as label 1; the score is the number of correct predictions.

    label_0: boolean pd.Series, True where the item belongs to class 0
    label_1: boolean pd.Series, True where the item belongs to class 1

    Returns (best_i, best_accuracy): the earliest split index with the highest
    score, and that score divided by the number of True labels in both series.
    The accuracy is NaN when neither series contains a True label.
    """
    # .sum() counts the True entries and, unlike value_counts()[True], does not
    # raise KeyError when a label never occurs.
    n_label_0 = int(label_0.sum())
    n_label_1 = int(label_1.sum())
    total = n_label_0 + n_label_1

    # With the split at 0 everything is predicted as label 1.
    count_0 = 0
    count_1 = n_label_1
    best_i, best_n = 0, count_0 + count_1

    for i, (label_i_0, label_i_1) in enumerate(zip(label_0, label_1), start=1):
        count_0 += int(label_i_0)
        count_1 -= int(label_i_1)

        if count_0 + count_1 > best_n:
            best_i, best_n = i, count_0 + count_1

    best_accuracy = best_n / total if total else float('nan')
    print(f'n_label_0: {n_label_0}, n_label_1: {n_label_1}, best_accuracy: {best_accuracy:.2%}')
    return best_i, best_accuracy


def plot_ROC(y_true, y_score, pos_label=None, label="", **kwargs):
    """Draw one ROC curve with pyplot, putting the AUC in its legend label.

    y_true, y_score, pos_label: forwarded to sklearn.metrics.roc_curve
    label: prefix of the legend entry; the AUC is appended to it
    kwargs: extra keyword arguments for plt.plot
    """
    from sklearn.metrics import auc, roc_curve

    false_pos_rate, true_pos_rate, _ = roc_curve(y_true, y_score, pos_label=pos_label)
    area = auc(false_pos_rate, true_pos_rate)
    legend_label = f"{label} (AUC = {area:0.2f})"
    plt.plot(false_pos_rate, true_pos_rate, label=legend_label, **kwargs)


def plot_ROC_end(title="", **kwargs):
    """Finish a ROC figure: diagonal line, axis limits, labels, legend, then save.

    title: plot title, also used as the name passed to output_fig
    kwargs: keyword arguments for the diagonal drawn with plt.plot
    """
    unit_range = [0.0, 1.0]
    # diagonal from (0, 0) to (1, 1)
    plt.plot([0, 1], [0, 1], **kwargs)
    plt.xlim(unit_range)
    plt.ylim(unit_range)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(title)
    plt.legend(loc="lower right")
    output_fig(title)


def ttest(a, label, alternative):
    """Run a paired t-test of `a` against zeros and print a one-line summary.

    a: array-like of values
    label: name printed at the start of the line (padded to 15 characters)
    alternative: passed to scipy.stats.ttest_rel
    """
    zeros = np.zeros_like(a)
    result = ttest_rel(a, zeros, alternative=alternative)
    sample_part = f'{label:15} #examples: {len(a):5} mean: {np.mean(a):5.2f}'
    test_part = f't-test result: statistic: {result.statistic:6.2f} pvalue: {result.pvalue:4.2f}'
    # print joins the two parts with a single space
    print(sample_part, test_part)


def analyze_value(
    items,
    value_attr,
    model_names,
    cat_field="syntactic category",
    palette=None,
    title=None,
    hlines=None,
    find_best_threshold=False,
    plotting=True,
    ROC=False,
    ttest_alternative='two-sided',
):
    """Compare a per-item value across categories for each model.

    Prints the item count of every non-empty category, draws one box plot of
    the value per model and category, and then for each model optionally
    searches the best noun/non-noun split, draws the value distribution (and
    ROC curves), and prints t-tests per category and over all items.

    items: items; a DataFrame with a categorical `cat_field` column and one
        '<model name> <value_attr>' column per model
    value_attr: column in items; e.g., 'loss', 'prob'
    model_names: names of the models to analyze
    cat_field: the category field
    palette: the palette for the categories; a dict mapping each category name to a color
    title: base title of the plots; defaults to value_attr
    hlines: passed through to plot for the box plot
    find_best_threshold: print the best split accuracy and threshold between
        nouns and all other categories; requires cat_field == "syntactic category"
    plotting: whether to draw the per-model distribution plots
        (NOTE: inside this function the name shadows the `plotting` module alias)
    ROC: also draw one-vs-rest ROC curves; only used when plotting is True
    ttest_alternative: alternative hypothesis passed to ttest
    """
    if palette is None:
        palette = get_palette(cat_field)
    value_suffix = ' ' + value_attr
    model_field = 'model'
    if title is None:
        title = value_attr

    # Only categories that actually occur are reported, plotted as ROC curves
    # and tested; an empty category has nothing to test.
    present_categories = [
        cat for cat in items.dtypes[cat_field].categories
        if (items[cat_field] == cat).any()
    ]

    for cat in present_categories:
        cat_items = items[items[cat_field] == cat]
        print(f'{cat:15} #: {len(cat_items)}')

    # long format: one row per (item, model) for the box plot
    items_long = items.melt(
        id_vars=[cat_field],
        value_vars=[name + value_suffix for name in model_names],
        var_name=model_field,
        value_name=value_attr)
    items_long[model_field] = items_long[model_field].map(lambda s: s[:-len(value_suffix)])
    plot(
        sns.catplot,
        items_long,
        x=model_field,
        y=value_attr,
        hue=cat_field,
        palette=palette,
        hlines=hlines,
        kind="box",
        figsize=(8, 4),
        medianprops=dict(color="white", alpha=0.7),
        title=title)
    output_fig(title)

    for name in model_names:
        print(f'{name}:')
        value_field = f"{name} {value_attr}"
        cur_items = items.sort_values(value_field)

        if find_best_threshold:
            assert cat_field == "syntactic category"
            label_0 = cur_items[cat_field].isin(["noun"])
            label_1 = ~label_0 #cur_items[cat_field].isin(["function word", "adjective", "adverb", "cardinal number"])
            best_split_i, best_split_acc = get_best_split_accuracy(label_0, label_1)
            print(f'best_split_accuracy: {best_split_acc:.2%}')
            # value of the first item after the split (clamped to the last item)
            threshold = cur_items.iloc[min(best_split_i, len(cur_items)-1)][value_field]
            print(f'best_threshold: {threshold}')

        if plotting:
            _title = f'{name} {title} Distribution'
            plot(
                sns.kdeplot,
                cur_items,
                x=value_field,
                hue=cat_field,
                palette=palette,
                bw_adjust=.5,
                figsize=(8, 4),
                title=_title)
            output_fig(_title)

            if ROC:
                for cat in present_categories:
                    # the negated value is used as the score
                    plot_ROC(
                        cur_items[cat_field],
                        -cur_items[value_field],
                        pos_label=cat,
                        label=f"{cat} vs others",
                        color=palette[cat])
                plot_ROC_end(title=_title + ' ROC')

        for cat in present_categories:
            ttest(cur_items.loc[cur_items[cat_field] == cat, value_field], cat, ttest_alternative)
        ttest(cur_items[value_field], "all tokens", ttest_alternative)


# Settings for the loss-difference analysis below.
value_attr = 'loss diff'
ttest_alternative = 'two-sided'
title = 'Loss Difference'


# use the first name as the baseline by default
baseline_name = names[0]
# alternative: list(filter(lambda name: name != baseline_name and 'contrastive' not in name.lower(), names))
model_names = ['LSTM Captioning']

for split in ['val']:
    # type level: one row per word type, restricted to types seen at least twice
    print('Type level:')
    token_items = split_items[split].token_items
    analyze_value(
        token_items.loc[token_items['cnt'] >= 2],
        value_attr,
        model_names,
        title=f'Type Level {title}',
        hlines=[0],
        ttest_alternative=ttest_alternative,
    )

    # token level: only available when all_token_items was built for the split
    all_token_items = split_items[split].all_token_items
    if all_token_items is not None:
        print('=' * 100)
        print('Token level:')
        analyze_value(
            all_token_items.reset_index(drop=True),
            value_attr,
            model_names,
            title=f'Token Level {title}',
            hlines=[0],
            plotting=False,
            ttest_alternative=ttest_alternative,
        )

Type level:
.               #: 65
adjective       #: 44
adverb          #: 45
cardinal number #: 11
function word   #: 82
noun            #: 220
verb            #: 150
plotting 617/617 = 100.00% items...
saving plot Type Level Loss Difference
LSTM Captioning:
plotting 617/617 = 100.00% items...
saving plot LSTM Captioning Type Level Loss Difference Distribution
.               #examples:    65 mean: -0.16 t-test result: statistic:  -2.09 pvalue: 0.04
adjective       #examples:    44 mean: -0.14 t-test result: statistic:  -1.75 pvalue: 0.09
adverb          #examples:    45 mean: -0.14 t-test result: statistic:  -2.23 pvalue: 0.03
cardinal number #examples:    11 mean: -0.03 t-test result: statistic:  -0.18 pvalue: 0.86
function word   #examples:    82 mean: -0.13 t-test result: statistic:  -3.25 pvalue: 0.00
noun            #examples:   220 mean: -0.51 t-test result: statistic:  -9.44 pvalue: 0.00
verb            #examples:   150 mean: -0.29 t-test result: statistic:  -5.58 pvalue: 0.00

<Figure size 2400x1200 with 0 Axes>

<Figure size 2400x1200 with 0 Axes>

In [13]:
def analyze_subcat_value_diff(
    all_items,
    subcats,
    names,
    baseline_name=None,
    model_names=None,
    subcats_name='Subcategories',
    subcat_name='subcategory',
    word_name='word',
    given_subcat=False,
    plotting_subcats=False,
    ttest_alternative='two-sided'
):
    """Analyze how much probability each model puts on the correct subcategory.

    For every item whose token belongs to one of `subcats`, computes per model
    P(subcategory of the token | any subcategory word) from the model's
    '<name> probs' column, takes the difference to the baseline model and
    passes it to analyze_value. With given_subcat=True it additionally
    analyzes, per subcategory, P(token | that subcategory) in the same way.

    all_items: DataFrame of items with '<name> prob' / '<name> probs' columns
    subcats: dict mapping each subcategory name to its words
    names: all model names; the first one is the default baseline
    baseline_name: model the differences are taken against
    model_names: models to analyze; defaults to every non-baseline name that
        does not contain 'contrastive'
    subcats_name, subcat_name, word_name: wording used in the plot titles
    given_subcat: also run the per-subcategory word-level analysis
    plotting_subcats: whether that word-level analysis draws plots
    ttest_alternative: alternative hypothesis passed to analyze_value
    """
    # use the first name as the baseline by default
    if baseline_name is None:
        baseline_name = names[0]
    if model_names is None:
        model_names = [
            name for name in names
            if name != baseline_name and 'contrastive' not in name.lower()
        ]

    # vocabulary ids per subcategory; words missing from word2idx are ignored
    subcat2word_ids = {}
    for group, group_words in subcats.items():
        subcat2word_ids[group] = [word2idx[word] for word in group_words if word in word2idx]
    all_subcat_words = [
        word
        for group_words in subcats.values()
        for word in group_words
        if word in word2idx
    ]
    all_subcat_word_ids = [word2idx[word] for word in all_subcat_words]

    p_field = 'prob'
    ps_field = 'probs'
    #subcat_field = 'subcat'  # use global subcat_field

    in_any_subcat = all_items[token_field].isin(all_subcat_words)
    items = all_items[in_any_subcat].copy()
    items[subcat_field] = items[token_field].map(word2subcat).astype('category')
    print(f'#: {len(items)}')

    p_cats_field = 'prob cats'
    p_subcat_field = 'prob subcat'
    p_subcat_given_cats_field = 'prob subcat given cats'
    for name in names:
        probs_col = f'{name} {ps_field}'
        cats_col = f'{name} {p_cats_field}'
        subcat_col = f'{name} {p_subcat_field}'
        subcat_given_cats_col = f'{name} {p_subcat_given_cats_field}'
        # probability mass on all subcategory words
        items[cats_col] = items[probs_col].map(lambda p: np.sum(p[all_subcat_word_ids]))
        # probability mass on the words of the item's own subcategory
        items[subcat_col] = items.apply(
            lambda item: np.sum(item[probs_col][subcat2word_ids[item[subcat_field]]]),
            axis=1)
        items[subcat_given_cats_col] = items[subcat_col] / items[cats_col]
        if name != baseline_name:
            extend_items_value_diff(items, name, baseline_name, p_subcat_given_cats_field)
    analyze_value(
        items.reset_index(drop=True),
        p_subcat_given_cats_field + ' diff',
        model_names,
        cat_field=subcat_field,
        title=f'P({subcat_name} | {subcats_name}) Difference',
        hlines=[0],
        ttest_alternative=ttest_alternative,
    )

    if not given_subcat:
        return

    # word-level analysis within each subcategory
    for group, group_words in subcats.items():
        group_items = all_items[all_items[token_field].isin(group_words)].copy()
        group_items[token_field] = group_items[token_field].astype('category')
        print(f'{group:14} #: {len(group_items)}')
        base_colors = sns.color_palette('tab20')
        word_palette = {
            word: base_colors[i % len(base_colors)]
            for i, word in enumerate(group_words)
        }
        p_group_field = f'prob {group}'
        p_given_group_field = f'prob given {group}'
        group_word_ids = subcat2word_ids[group]
        for name in names:
            prob_col = f'{name} {p_field}'
            probs_col = f'{name} {ps_field}'
            group_col = f'{name} {p_group_field}'
            given_group_col = f'{name} {p_given_group_field}'
            group_items[group_col] = group_items[probs_col].map(lambda p: np.sum(p[group_word_ids]))
            group_items[given_group_col] = group_items[prob_col] / group_items[group_col]
            if name != baseline_name:
                extend_items_value_diff(group_items, name, baseline_name, p_given_group_field)
        analyze_value(
            group_items.reset_index(drop=True),
            p_given_group_field + ' diff',
            model_names,
            cat_field=token_field,
            palette=word_palette,
            title=f'P({word_name} | {group}) Difference',
            hlines=[0],
            plotting=plotting_subcats,
            ttest_alternative=ttest_alternative,
        )


# display name of the subcategory family for each part of speech
subcats_name = {
    'noun': 'Noun Semantic Subcategories',
    'verb': 'Verb Transitivity SubCategories',
}


for split in ['val']:
    for pos, subcats in pos_subcats.items():
        # token-level items of the current part of speech only
        items = split_items[split].all_token_items
        items = items.loc[items['syntactic category'].isin([pos])].copy()
        analyze_subcat_value_diff(
            items,
            subcats,
            names,
            baseline_name=baseline_name,
            model_names=model_names,
            subcats_name=subcats_name[pos],
            given_subcat=False,
        )

#: 392
animals         #: 99
body_parts      #: 35
clothing        #: 40
food_drink      #: 74
games_routines  #: 4
household       #: 35
places          #: 10
toys            #: 49
vehicles        #: 46
plotting 392/392 = 100.00% items...
saving plot P(subcategory | Noun Semantic Subcategories) Difference
LSTM Captioning:
plotting 392/392 = 100.00% items...
saving plot LSTM Captioning P(subcategory | Noun Semantic Subcategories) Difference Distribution
animals         #examples:    99 mean:  0.07 t-test result: statistic:   4.46 pvalue: 0.00
body_parts      #examples:    35 mean:  0.01 t-test result: statistic:   0.23 pvalue: 0.82
clothing        #examples:    40 mean:  0.04 t-test result: statistic:   2.11 pvalue: 0.04
food_drink      #examples:    74 mean:  0.10 t-test result: statistic:   4.20 pvalue: 0.00
games_routines  #examples:     4 mean:  0.22 t-test result: statistic:   1.99 pvalue: 0.14
household       #examples:    35 mean:  0.07 t-test result: statistic:   3.16 pvalue: 0

<Figure size 2400x1200 with 0 Axes>

<Figure size 2400x1200 with 0 Axes>

In [14]:
pos_field = "syntactic category"

baseline_name = names[0]


def item_combinations(split):
    """Map each item-set name to (items, number of items to plot) for a split.

    The sub-category frames are rebuilt by get_subcat_items on every call.
    """
    return {
#       "All":                 (split_all_pos_items[split], None),
#       "Most Frequent Words": (split_all_pos_items[split], len(used_poses) * interleaving_step * (25 // interleaving_step)),
#       "All Noun vs Verb":    (split_pos_pos_items[split][('noun', 'verb')], None),
        "Noun vs Verb":        (split_pos_pos_items[split][('noun', 'verb')], 2 * interleaving_step * (25 // interleaving_step)),
        "Semantic Categories": (get_subcat_items(split_pos_items[split]['noun'], 'noun'), 10000),
        "Verb Transitivity":   (get_subcat_items(split_pos_items[split]['verb'], 'verb'), 10000),
    }


# t-SNE and PCA scatter plots of every model's representations
for split in ['train']:
    for name in names:
        print(f'{name}:')
        loss_field = f'{name} loss'
        loss_diff_field = f'{name} loss diff'
        fields = [pos_field, "cnt", "logcnt", conc_field, "AnimPhysical", "AnimMental", "Category", "AoA", loss_field, loss_diff_field]

        for items_name, (items, n_items) in item_combinations(split).items():
            print(f'{items_name}:')
            token_kwargs = {'fontsize': 'small'} if n_items is not None else None
            pos_hue = subcat_field if items_name in ["Semantic Categories", "Verb Transitivity"] else pos_field
            # work on a copy so dropping unused categories does not touch the shared frame
            items = items.copy()
            items[pos_hue] = items[pos_hue].cat.remove_unused_categories()

            title_prefix = f"{name} {repres_name} {items_name} "
            for hue, x, y, xlabel, ylabel, axis_option, title, plot_reg in [
                (pos_hue, f"{name} tsne 0", f"{name} tsne 1", "", "", "off", title_prefix + "t-SNE", False),
                #(pos_hue, f"{name} pca 0", f"{name} pca 1", "principal component 1", "principal component 2", "on", title_prefix + "PCA", False),
                #("logcnt", f"{name} pca 0", f"{name} pca 1", "principal component 1", "principal component 2", "on", title_prefix + "PCA", False),
                #(pos_hue, f"{name} pca 0", "logcnt", "principal component 1", "log frequency", "on", title_prefix + "Correlation between principal component 1 and log frequency", True),
                (pos_hue, f"{name} pca 1", f"{name} pca 2", "principal component 2", "principal component 3", "on", title_prefix + "PCA", False),
            ]:
                # explicit keyword arguments instead of a globals() lookup
                kwargs = dict(
                    x=x,
                    y=y,
                    n_items=n_items,
                    token_kwargs=token_kwargs,
                    xlabel=xlabel,
                    ylabel=ylabel,
                    axis_option=axis_option,
                    title=title,
                    figsize=figsize,
                )
                plot(sns.scatterplot, items, hue=hue, **kwargs)
                output_fig(title)
                if plot_reg:
                    plot(sns.regplot, items, **kwargs)
                    output_fig(title + ' regression')

# Concreteness vs loss difference plots; disabled by the empty slice `[:0]`.
for split in ['val'][:0]:
    for name in filter(lambda name: name != baseline_name, names):
        print(f'{name}:')
        loss_field = f'{name} loss'
        loss_diff_field = f'{name} loss diff'

        # A dict cannot be sliced, so item sets are selected by name.
        # NOTE(review): the old positional slices [1:2] + [4:6] presumably meant
        # these three entries of the full six-entry table; confirm before re-enabling.
        item_combinations_split = {
            items_name: combination
            for items_name, combination in item_combinations(split).items()
            if items_name in ("Most Frequent Words", "Semantic Categories", "Verb Transitivity")
        }
        for items_name, (items, n_items) in item_combinations_split.items():
            figname_prefix = f'{name} {items_name} '
            print(f'{items_name}:')
            token_kwargs = {'fontsize': 'small'} if n_items is not None else None
            pos_hue = subcat_field if items_name in ["Semantic Categories", "Verb Transitivity"] else pos_field

            loss_diff_items = items.sort_values(loss_diff_field)
            if n_items is not None:
                loss_diff_items = loss_diff_items[loss_diff_items['cnt'] >= 5]
                if n_items * 2 >= len(loss_diff_items):
                    n_items = len(loss_diff_items)
            plot(sns.scatterplot, loss_diff_items, x=conc_field, y=loss_diff_field, hue=pos_hue, n_items=n_items, hlines=[0], token_kwargs=token_kwargs, title="vs", figsize=figsize)
            output_fig(figname_prefix + f'Concreteness vs loss diff {n_items} words')
            if n_items is None and False:
                plot(sns.regplot, loss_diff_items, x=conc_field, y=loss_diff_field, n_items=n_items, hlines=[0], token_kwargs=token_kwargs, title="vs", figsize=figsize)
                output_fig(figname_prefix + 'Concreteness vs loss diff all words')
                #plot(sns.catplot, loss_diff_items, x=pos_field, y=loss_diff_field, n_items=n_items, color="b", hlines=[0], title="vs", figsize=figsize,) #kind="violin", inner="stick",
                #output_fig(figname_prefix + f'{pos_field} vs loss diff all words')
            if n_items is not None and n_items < len(loss_diff_items):
                print('highest:')
                plot(sns.scatterplot, loss_diff_items[::-1], x=conc_field, y=loss_diff_field, hue=pos_hue, n_items=n_items, hlines=[0], token_kwargs=token_kwargs, title="vs", figsize=figsize)
                output_fig(figname_prefix + 'Concreteness vs loss diff highest words')

LSTM:
Noun vs Verb:
plotting 50/1596 = 3.13% items...
saving plot LSTM Embedding Noun vs Verb t-SNE
plotting 50/1596 = 3.13% items...
saving plot LSTM Embedding Noun vs Verb PCA
Semantic Categories:
plotting 48/48 = 100.00% items...
saving plot LSTM Embedding Semantic Categories t-SNE
plotting 48/48 = 100.00% items...
saving plot LSTM Embedding Semantic Categories PCA
Verb Transitivity:
plotting 50/50 = 100.00% items...
saving plot LSTM Embedding Verb Transitivity t-SNE
plotting 50/50 = 100.00% items...
saving plot LSTM Embedding Verb Transitivity PCA
LSTM Captioning:
Noun vs Verb:
plotting 50/1596 = 3.13% items...
saving plot LSTM Captioning Embedding Noun vs Verb t-SNE
plotting 50/1596 = 3.13% items...
saving plot LSTM Captioning Embedding Noun vs Verb PCA
Semantic Categories:
plotting 48/48 = 100.00% items...
saving plot LSTM Captioning Embedding Semantic Categories t-SNE
plotting 48/48 = 100.00% items...
saving plot LSTM Captioning Embedding Semantic Categories PCA
Verb Transitivit

<Figure size 2400x2100 with 0 Axes>