# Data exploration

In [1]:
# Imports
%matplotlib inline
from matplotlib import pyplot as plt

from collections import Counter

from src.model import DLDLMTokenizer
from src.data import DLDLMCorpus

from typing import Tuple, Dict, List

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/vincenzoscotti_polimi/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to
[nltk_data]     /Users/vincenzoscotti_polimi/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [6]:
# Tokenizer built from the 'gpt2' checkpoint and then passed through the
# project's `extend_from_gpt2_tokenizer` hook (argument semantics are defined in
# src.model -- TODO confirm what the `0` selects).
tokenizer: DLDLMTokenizer = (
    DLDLMTokenizer
    .from_pretrained('gpt2')
    .extend_from_gpt2_tokenizer(0)
)

# One DLDLMCorpus per data split; every split shares the tokenizer and settings.
splits: Tuple[str, str, str] = ('train', 'validation', 'test')
corpus_splits: Dict[str, DLDLMCorpus] = dict(
    (
        split_name,
        DLDLMCorpus(
            '../resources/data/raw/',
            tokenizer,
            split_name,
            '../resources/data/cache/',
            corpus_list=['dailydialog', 'empatheticdialogues', 'personachat', 'wizard_of_wikipedia'],
            max_context_length=256,
            max_response_length=128,
            count_word_tokens=True,
        ),
    )
    for split_name in splits
)

# How many of the most frequent words to keep per split when plotting.
top_n = 30

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'GPT2Tokenizer'. 
The class this function is called from is 'DLDLMTokenizer'.


In [7]:
# Get most common words per split
def _count_words(corpus) -> Counter:
    """Accumulate the per-sample 'word_counts' of ``corpus`` into one Counter.

    The running total is updated in place. Summing with ``sum(..., Counter())``
    creates a new Counter (re-copying the whole accumulated vocabulary) for
    every single sample, which makes the cost grow with
    ``number of samples x vocabulary size``.

    Note: unlike ``Counter.__add__``, ``Counter.update`` does not discard keys
    whose count is zero or negative; this is presumably irrelevant for word
    occurrence counts, which should be strictly positive -- verify against
    DLDLMCorpus if in doubt.

    :param corpus: indexable, sized collection whose items expose a
        word -> count mapping under the 'word_counts' key.
    :return: Counter with the total occurrences of each word in ``corpus``.
    """
    total: Counter = Counter()
    for idx in range(len(corpus)):
        total.update(corpus[idx]['word_counts'])
    return total


counts: Dict[str, Counter] = {
    split: _count_words(corpus_splits[split])
    for split in splits
}


KeyboardInterrupt



In [None]:
# Plot word counts
# One horizontal bar chart per split, showing its `top_n` most frequent words.
figure, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 20), sharex=True)
axes = axes.flatten()
for key_idx, split in enumerate(counts):
    top_words = counts[split].most_common(top_n)
    # Unzip the (word, count) pairs; fall back to empty sequences so that a
    # split without any counted word yields an empty panel instead of an
    # unpacking error.
    items, occurrences = zip(*top_words) if top_words else ((), ())
    ax = axes[key_idx]
    # Bar lengths are the per-word occurrences of this split (the whole
    # `counts` dictionary used to be passed here by mistake).
    ax.barh(items, occurrences, height=0.7)
    ax.set_title(f"Split: {split.capitalize()}", fontdict={"fontsize": 20})
    # Most frequent word at the top of the panel.
    ax.invert_yaxis()
    ax.tick_params(axis="both", which="major", labelsize=20)
    for side in ("top", "right", "left"):
        ax.spines[side].set_visible(False)
# The figure-level title does not depend on the split, so it is set only once.
figure.suptitle("Split-wise most common words", fontsize=32)
plt.subplots_adjust(top=0.90, bottom=0.05, wspace=0.90, hspace=0.3)
plt.show()

In [None]:
# Plot sorted word distribution
# Merge the per-split counters into a single corpus-wide counter (only a
# handful of Counters are added here, so plain `sum` is fine).
all_counts: Counter = sum((counts[split] for split in counts), Counter())
# Occurrences of every distinct word, in ascending order (integer counts).
occurrences: List[int] = sorted(all_counts.values())

dist_figure, dist_ax = plt.subplots(figsize=(20, 15))
dist_ax.bar(range(len(occurrences)), occurrences)
dist_ax.set_xlim(0, len(occurrences))
# NOTE(review): the upper bound is fixed at 1e5, so words occurring more often
# than that are clipped -- confirm this is intended.
dist_ax.set_ylim(1, 1e5)
dist_ax.set_yscale('log')
dist_ax.set_xlabel("Word rank (ascending number of occurrences)")
dist_ax.set_ylabel("Occurrences (log scale)")
dist_ax.set_title("Sorted word distribution")
dist_ax.grid()
plt.show()
