In [1]:
import json
import random
import pathlib
import joblib
from collections import defaultdict
import plotly.express as px

In [2]:
# Fix the RNG seed so the random example sampling used when building the report is reproducible.
random.seed(42)

In [3]:
# Dataset to analyse; it selects the input file and the merge threshold below.
output_key = 'multi-woz2'
# output_key = 'topicalchat'
# Earlier input files, kept for reference:
# input_json = 'data/multi-woz2.spacy.dialogact.discourse.dialogtagger.topicmodel.0310.json'
# input_json = 'data/convai1.spacy.dialogact.discourse.dialogtagger.topicmodel.0310.json'
# input_json = 'data/convai2.spacy.dialogact.discourse.dialogtagger.0110.json'
# input_json = 'data/multi-woz2.spacy.dialogact.discourse.dialogtagger.0110.json'

# Feature family used to describe an utterance (consumed by get_features).
# features_name = 'dialog_tagger_features'
features_name = 'discourse_features'
# features_name = 'topic_model_features'

# min_thresh is the minimal Jaccard similarity at which two clusters get merged.
if output_key == 'multi-woz2':
    input_json = f'data/{output_key}.spacy.dialogact.discourse.dialogtagger.topicmodel.0310.json'
    min_thresh = 0.5 # dialog_tagger_features, n=4, n=2, n=1 (top-10 99% of data)
elif output_key == 'topicalchat':
    input_json = f'data/{output_key}.train.spacy.dialogact.discourse.topicmodel.0310.json'
    min_thresh = 0.74
    if features_name == 'topic_model_features':
        min_thresh = 0.69
else:
    # Fail here instead of with a NameError on input_json/min_thresh in a later cell.
    raise ValueError(f"Unsupported output_key: {output_key!r}; expected 'multi-woz2' or 'topicalchat'")

# Number of consecutive utterances joined into one n-gram row.
n = 4

# All artefacts of this run (dump, histograms, report) go into one folder
# named after the dataset, the window size and the feature family.
output_folder = f'data/results/{output_key}_{n}_{features_name}/'
pathlib.Path(output_folder).mkdir(parents=True, exist_ok=True)
data_path = output_folder + 'data.joblib'

In [4]:
# JSON text is UTF-8; pin the encoding so loading does not depend on the
# platform's default locale encoding.
with open(input_json, 'r', encoding='utf-8') as f:
    dialogs = json.load(f)

In [5]:
# Length of the longest dialog thread (number of utterances).
max(len(dialog['thread']) for dialog in dialogs.values())

44

In [6]:
def get_features(utterance):
    """Return the list of cluster features for one utterance.

    The feature family is chosen by the notebook-level ``features_name`` setting:

    * ``'dialog_tagger_features'`` - one ``"dim_<dimension> comm_func_<function>"``
      string per entry of ``utterance['SVM_predictions']``;
    * ``'discourse_features'`` - ``utterance['single_discourse_type']`` plus
      ``utterance['pair_discourse_type']`` when it is set and is not ``'PAIR_NONE'``;
    * ``'topic_model_features'`` - the first element of every entry of
      ``utterance['topic_model_features']``.

    Only the keys required by the selected family are read, so utterances that
    lack the annotations of the other families can still be processed.

    Raises:
        ValueError: if ``features_name`` is not one of the values listed above.
        KeyError: if the utterance lacks a key required by the selected family.
    """
    if features_name == 'dialog_tagger_features':
        return [
            f"dim_{p['dimension']} comm_func_{p['communicative_function']}"
            for p in utterance['SVM_predictions']
        ]
    if features_name == 'discourse_features':
        features = [utterance['single_discourse_type']]
        pair_discourse_type = utterance.get('pair_discourse_type')
        # 'PAIR_NONE' is treated the same as a missing pair type: no extra feature.
        if pair_discourse_type and pair_discourse_type != 'PAIR_NONE':
            features.append(pair_discourse_type)
        return features
    if features_name == 'topic_model_features':
        return [f[0] for f in utterance['topic_model_features']]
    raise ValueError(
        f"Unknown features_name: {features_name!r}; expected 'dialog_tagger_features', "
        "'discourse_features' or 'topic_model_features'"
    )

def get_thread_key(n):
    """Build the dialog-dict key under which the n-gram thread for window size ``n`` is stored."""
    return f'thread{n}'

# def get_thread_key(n):
#     if n == 1:
#         key = 'thread'
#     else:
#         key = f'thread{n}'
#     return key

# def get_features(utterance):
#     dialog_act_features = [p[0] for p in utterance['predictions']]
#     pos_features = "|".join([feats['pos'] for feats in utterance['features_dict']])
#     single_discourse_type = utterance['single_discourse_type']
#     pair_discourse_type = utterance.get('pair_discourse_type')
    
#     dialog_tagger_features = [f"dim_{p['dimension']} comm_func_{p['communicative_function']}" for p in utterance['SVM_predictions']]

#     features = dialog_act_features  + dialog_tagger_features + [single_discourse_type] + [pos_features]
#     if pair_discourse_type:
#         features += [pair_discourse_type]
#     return features

In [7]:
# Attach the selected features to every utterance as a frozenset, which is
# hashable and can therefore be used as a cluster key later on.
for dialog in dialogs.values():
    for row in dialog['thread']:
        row['final_features'] = frozenset(get_features(row))

In [8]:
def collect_clusters(dialogs):
    """Index the distinct per-utterance feature sets and tag every utterance with its index.

    Each distinct ``row['final_features']`` value found in the ``'thread'`` of
    any dialog becomes one cluster. Clusters are numbered in the order given by
    the string form of their sorted members, and every row gets a
    ``'cluster_id'`` entry (the dialogs are modified in place).

    Returns:
        dict mapping each feature set to its cluster index.
    """
    all_rows = [row for dialog in dialogs.values() for row in dialog['thread']]
    ordered = list({row['final_features'] for row in all_rows})
    ordered.sort(key=lambda feature_set: str(sorted(feature_set)))
    clusters_index = {feature_set: position for position, feature_set in enumerate(ordered)}
    for row in all_rows:
        row['cluster_id'] = clusters_index[row['final_features']]
    return clusters_index

In [9]:
def collect_n_cluster(dialogs, n):
    """Build n-utterance windows for every dialog and index their merged feature sets.

    For every position of a dialog's ``'thread'`` that has ``n - 1`` predecessors,
    a new row is created holding the union of the ``'final_features'`` of the
    ``n`` utterances in the window and their texts as ``'text0'`` (oldest) ..
    ``'text{n-1}'`` (newest). The rows are stored on the dialog under
    ``get_thread_key(n)``; dialogs too short to yield a window get no such key.
    Every distinct merged feature set is one cluster, numbered in the order of
    the string form of its sorted members, and each window row receives its
    ``'cluster_id'``. The dialogs are modified in place.

    Returns:
        dict mapping each merged feature set to its cluster index.

    Raises:
        AssertionError: if no dialog carries the thread key afterwards.
    """
    thread_key = get_thread_key(n)
    seen_clusters = set()
    for dialog in dialogs.values():
        thread = dialog['thread']
        windows = []
        # The first complete window ends at index n - 1.
        for last in range(max(n - 1, 0), len(thread)):
            # Window members ordered from the oldest to the newest utterance.
            members = [thread[last - back] for back in range(n - 1, -1, -1)]
            merged_features = frozenset().union(*(member['final_features'] for member in members))
            window_row = {'final_features': merged_features}
            for position, member in enumerate(members):
                window_row[f'text{position}'] = member['text']
            seen_clusters.add(merged_features)
            windows.append(window_row)
        if windows:
            dialog[thread_key] = windows

    ordered = sorted(seen_clusters, key=lambda feature_set: str(sorted(feature_set)))
    clusters_index = {feature_set: position for position, feature_set in enumerate(ordered)}

    found_thread_key = False
    for dialog in dialogs.values():
        if thread_key in dialog:
            found_thread_key = True
            for window_row in dialog[thread_key]:
                window_row['cluster_id'] = clusters_index[window_row['final_features']]
    assert found_thread_key
    return clusters_index

In [10]:
def jaccard_similarity(s1, s2):
    """Return the Jaccard similarity ``|s1 & s2| / |s1 | s2|`` rounded to 5 decimals.

    Args:
        s1, s2: sets (or frozensets) to compare.

    Returns:
        A float in ``[0, 1]``. Two empty sets are treated as identical and
        yield ``1.0`` instead of raising ``ZeroDivisionError``.
    """
    union_size = len(s1.union(s2))
    if union_size == 0:
        # Both sets are empty: the ratio is undefined, so use the
        # "identical sets" convention.
        return 1.0
    # The ratio of two set sizes is always within [0, 1]; no clamping is needed.
    return round(len(s1.intersection(s2)) / union_size, 5)

In [11]:
# Build the n-utterance windows on every dialog and index their distinct merged feature sets.
clusters = collect_n_cluster(dialogs, n)

In [12]:
# Number of distinct n-gram clusters before any merging.
len(clusters)

394

In [13]:
def get_sims_info(clusters, for_hist=False):
    """Pair every cluster with its most similar higher-indexed cluster.

    Args:
        clusters: dict mapping a feature set to its unique cluster index;
            the indices are expected to be ``0 .. len(clusters) - 1``.
        for_hist: when true, report the best pair of every cluster (used for
            the similarity histogram). When false, pair greedily: a pair is
            reported only if neither of its clusters is part of an earlier pair.

    Returns:
        list of ``{'source_ind', 'target_ind', 'sim'}`` dicts. A cluster that
        has no higher-indexed cluster to compare with (the last one) produces
        no entry, rather than a made-up pair with cluster 0 at similarity 0.
    """
    assert len(clusters.values()) == len(set(clusters.values()))
    reverse_index = {i: c for c, i in clusters.items()}
    used_inds = set()
    sims_info = []
    for i in range(len(clusters)):
        max_sim = 0
        max_ind = None
        for j in range(i + 1, len(clusters)):
            cur_sim = jaccard_similarity(reverse_index[i], reverse_index[j])
            # '<=' keeps the last candidate among equally similar ones.
            if max_sim <= cur_sim:
                max_sim = cur_sim
                max_ind = j
        if max_ind is None:
            # Nothing was compared, so there is no pair to report.
            continue
        if for_hist:
            sims_info.append({'source_ind': i, 'target_ind': max_ind, 'sim': max_sim})
        elif i not in used_inds and max_ind not in used_inds:
            sims_info.append({'source_ind': i, 'target_ind': max_ind, 'sim': max_sim})
            used_inds.add(i)
            used_inds.add(max_ind)
    return sims_info

In [14]:
# Greedy best-match pairs between the initial clusters; input of the first merge pass.
sims_info = get_sims_info(clusters)

In [15]:
# Between unique clusters: histogram of each cluster's best similarity to another cluster.
# The figure is written to the output folder as HTML (auto_open=True also opens it).
sims_info_for_hist = get_sims_info(clusters, True)
fig = px.histogram(sims_info_for_hist, x="sim", nbins=len(sims_info_for_hist))
fig.write_html(output_folder + 'sims_hist.html', auto_open=True)

In [16]:
def filter_sims_and_merge_clusters(sims_info, clusters, threshold, dialogs, n):
    """Repeatedly merge cluster pairs whose similarity reaches ``threshold``.

    One pass unions the two clusters of every ``sims_info`` entry with
    ``sim >= threshold`` into a new cluster, carries all other clusters over
    unchanged, and renumbers the clusters from 0. If at least one pair was
    merged, the ``'cluster_id'`` of every row stored under ``get_thread_key(n)``
    in ``dialogs`` is rewritten in place to the new numbering and the function
    recurses with the similarities recomputed for the new clusters. The
    recursion stops with the first pass that merges nothing.

    Args:
        sims_info: list of ``{'source_ind', 'target_ind', 'sim'}`` dicts whose
            indices refer to the values of ``clusters``.
        clusters: dict mapping a feature set to its unique cluster index.
        threshold: minimal similarity for a pair to be merged.
        dialogs: dict of dialogs; rows under the thread key are updated in place.
        n: window size, used only to derive the thread key.

    Returns:
        dict mapping each resulting feature set to its cluster index.

    Raises:
        AssertionError: if cluster indices are not unique, or if a merge
            happened but no dialog carries the thread key.
    """
    new_clusters = {}
    cluster_id = 0
    used_inds = []
    is_merged = False
    # old cluster index -> index of the cluster it belongs to after this pass
    cluster_id_mapping = {}
    assert len(clusters.values()) == len(set(clusters.values()))
    reverse_index = {i: c for c, i in sorted(clusters.items(), key=lambda x: x[1])}
    for e in sims_info:
        if e['sim'] >= threshold:
            new_cluster = frozenset().union(reverse_index[e['source_ind']], reverse_index[e['target_ind']])
            # Different pairs may produce the same union; it is registered only once.
            if new_cluster not in new_clusters:
                new_clusters[new_cluster] = cluster_id
                cluster_id += 1
                is_merged = True
            used_inds += [e['source_ind'], e['target_ind']]
            cluster_id_mapping[e['source_ind']] = new_clusters[new_cluster]
            cluster_id_mapping[e['target_ind']] = new_clusters[new_cluster]
    used_inds = set(used_inds)
    # Carry over every cluster that took part in no accepted pair.
    for ind in reverse_index.keys():
        if ind not in used_inds:
            # An untouched cluster may equal a union created above; reuse that id then.
            if reverse_index[ind] not in new_clusters:
                new_clusters[reverse_index[ind]] = cluster_id
                cluster_id += 1
            cluster_id_mapping[ind] = new_clusters[reverse_index[ind]]
    if is_merged is True:
        at_least_one_with_thread_key = False
        for dialog_id, dialog in dialogs.items():
            key = get_thread_key(n)
            if key not in dialog:
                continue
            else:
                at_least_one_with_thread_key = True
                thread = dialog[key]
            # Move the rows to the new numbering before the next pass.
            for ind, row in enumerate(thread):
                row['cluster_id'] = cluster_id_mapping[row['cluster_id']]
        assert at_least_one_with_thread_key
        # Something was merged: recompute similarities on the new clusters and try again.
        return filter_sims_and_merge_clusters(get_sims_info(new_clusters), new_clusters, threshold, dialogs, n)
    return new_clusters

In [17]:
# Merge clusters until no reported pair reaches min_thresh; this also rewrites the rows' cluster ids in dialogs.
new_clusters = filter_sims_and_merge_clusters(sims_info, clusters, min_thresh, dialogs, n)

In [18]:
def check_correctness(new_clusters, dialogs, n=1):
    """Check that the dialogs use exactly as many cluster ids as there are clusters.

    Args:
        new_clusters: dict of clusters (only its size is used).
        dialogs: dict of dialogs; rows under ``get_thread_key(n)`` are inspected,
            dialogs without that key are skipped.
        n: window size used to derive the thread key.

    Raises:
        AssertionError: if the number of distinct ``'cluster_id'`` values differs
            from ``len(new_clusters)``; the message reports both counts.
    """
    key = get_thread_key(n)
    cluster_ids = {
        row['cluster_id']
        for dialog in dialogs.values()
        if key in dialog
        for row in dialog[key]
    }
    # A real message instead of `print(...)`, which evaluates to None.
    assert len(cluster_ids) == len(new_clusters), (
        f"{len(cluster_ids)} distinct cluster ids are used in the dialogs, "
        f"but {len(new_clusters)} clusters are defined"
    )

# Sanity check after merging: the cluster ids on the rows must match the merged cluster set.
check_correctness(new_clusters, dialogs, n)

In [19]:
# dump new_clusters, old_clusters, dialogs, n, min_thresh, graphs

def dump_data(n):
    """Persist the clustering results of the current run to ``data_path`` with joblib.

    Only a compact form of the dialogs is stored: for every dialog that has the
    n-gram thread, each row is reduced to its ``'cluster_id'`` and its
    ``'text0'`` .. ``'text{n-1}'`` entries, kept under the key ``'thread'``.
    The dump also contains the notebook-level ``new_clusters``, ``clusters``
    and ``min_thresh`` values and ``n``.
    """
    thread_key = get_thread_key(n)
    text_keys = [f'text{i}' for i in range(n)]
    compressed_dialogs = {}
    for dialog_id, dialog in dialogs.items():
        if thread_key in dialog:
            compressed_dialogs[dialog_id] = {
                'thread': [
                    {'cluster_id': row['cluster_id'], **{text_key: row[text_key] for text_key in text_keys}}
                    for row in dialog[thread_key]
                ]
            }

    payload = {
        'new_clusters': new_clusters,
        'clusters': clusters,
        'n': n,
        'min_thresh': min_thresh,
        'dialog': compressed_dialogs,
    }
    joblib.dump(payload, data_path)

    
def load_data():
    """Load and return the object stored at ``data_path`` by joblib."""
    # NOTE(review): joblib.load unpickles the file, so only load dumps you trust.
    return joblib.load(data_path)

# Persist the results of this run.
dump_data(n)
# To restore a previous run instead, load the dump and unpack it.
# NOTE: dump_data stores the dialogs under the key 'dialog' and in compressed
# form (rows kept under 'thread' with only 'cluster_id' and the texts), so the
# restored value is not a drop-in replacement for the original `dialogs`.
# data = load_data()
# dialogs = data['dialog']
# new_clusters = data['new_clusters']
# clusters = data['clusters']
# n = data['n']
# min_thresh = data['min_thresh']

In [20]:
# Merged cluster id -> feature set; used to list a cluster's features in the report.
reverse_index = {i: c for c, i in sorted(new_clusters.items(), key=lambda x: x[1])} 

In [21]:
def calc_cluster_cluster_usage_distribution(dialogs, n):
    """Count how many n-gram rows are assigned to each cluster id.

    Dialogs without the ``get_thread_key(n)`` thread are skipped.

    Returns:
        ``defaultdict(int)`` mapping cluster id to the number of rows.
    """
    thread_key = get_thread_key(n)
    usage = defaultdict(int)
    for dialog in dialogs.values():
        if thread_key in dialog:
            for row in dialog[thread_key]:
                usage[row['cluster_id']] += 1
    return usage

# Rows per merged cluster; drives the usage histogram and the report below.
cluster_usage_distribution = calc_cluster_cluster_usage_distribution(dialogs, n)

In [22]:
# Bar-style histogram of cluster usage: one "c_<id>" category per cluster,
# ordered by descending total, written to the output folder as HTML.
xs = [f"c_{e}" for e in list(cluster_usage_distribution.keys())]
fig = px.histogram(x=xs, y=list(cluster_usage_distribution.values()), nbins=len(cluster_usage_distribution.keys()), labels={'x': 'cluster id'})
fig.update_layout(xaxis={'categoryorder':'total descending'})
fig.write_html(output_folder + 'usage_hist.html', auto_open=True)

In [23]:
def get_cluster_examples(dialogs, n):
    """Group the texts of every n-gram row by the row's cluster id.

    Dialogs without the ``get_thread_key(n)`` thread are skipped.

    Returns:
        ``defaultdict(list)`` mapping cluster id to a list of examples, each
        example being the list ``[row['text0'], ..., row['text{n-1}']]``.
    """
    thread_key = get_thread_key(n)
    text_keys = [f'text{i}' for i in range(n)]
    examples = defaultdict(list)
    for dialog in dialogs.values():
        if thread_key not in dialog:
            continue
        for row in dialog[thread_key]:
            texts = [row[text_key] for text_key in text_keys]
            examples[row['cluster_id']].append(texts)
    return examples

# Example n-grams per merged cluster; sampled from when building the report.
cluster_examples = get_cluster_examples(dialogs, n)

In [24]:
def get_cluster_by_turns_distribution(dialogs, n):
    """Count, per cluster, how often it occurs at each position of the n-gram thread.

    Dialogs without the ``get_thread_key(n)`` thread are skipped.

    Returns:
        ``defaultdict(dict)`` of the form
        ``{cluster_id: {turn_num: count, ...}, ...}`` where ``turn_num`` is the
        row's index within its dialog's n-gram thread.
    """
    thread_key = get_thread_key(n)
    turns_by_cluster = defaultdict(dict)
    for dialog in dialogs.values():
        if thread_key not in dialog:
            continue
        for turn_num, row in enumerate(dialog[thread_key]):
            turn_counts = turns_by_cluster[row['cluster_id']]
            turn_counts[turn_num] = turn_counts.get(turn_num, 0) + 1
    return turns_by_cluster
# Per-cluster counts of the thread positions at which the cluster occurs; used in the report.
cluster_turns_distr = get_cluster_by_turns_distribution(dialogs, n)

In [25]:
def get_cluster_features_distribution(dialogs, n):
    """Collect, for every feature, the ids of the clusters whose rows contain it.

    Dialogs without the ``get_thread_key(n)`` thread are skipped.

    Returns:
        ``defaultdict(set)`` of the form ``{feature: {cluster_id1, cluster_id2, ...}}``.
    """
    thread_key = get_thread_key(n)
    clusters_by_feature = defaultdict(set)
    rows = (
        row
        for dialog in dialogs.values()
        if thread_key in dialog
        for row in dialog[thread_key]
    )
    for row in rows:
        for feature in row['final_features']:
            clusters_by_feature[feature].add(row['cluster_id'])
    return clusters_by_feature
# Feature -> ids of the clusters containing it; the report derives "feature confidence" from its size.
features_distr = get_cluster_features_distribution(dialogs, n)

In [26]:
def calc_coverage(top_k):
    """Return the share of rows covered by the ``top_k`` most used clusters, in percent.

    Reads the notebook-level ``cluster_usage_distribution``
    (cluster id -> number of rows).

    Returns:
        The coverage as a float rounded to a whole percent (e.g. ``57.0``);
        ``0.0`` when the distribution is empty.
    """
    total = sum(cluster_usage_distribution.values())
    if total == 0:
        return 0.0
    top_frequencies = sorted(cluster_usage_distribution.values(), reverse=True)[:top_k]
    # Scale to percent before rounding: rounding the fraction first and then
    # multiplying by 100 yields artefacts such as 56.99999999999999.
    return float(round(sum(top_frequencies) / total * 100))
# Build the plain-text report: a short legend, global statistics, and one
# section per cluster for the top_k most used clusters.
top_k = 10
result_report = "Help:\n"
result_report += "- Turn dist calculated as: Turn X freq / Frequency \n"
result_report += "- Feature confidence calculated as: 1 / number_of_clusters_that_has_this_feature \n"
result_report += "\n"
result_report += f"Threshold: {min_thresh}. Original clusters: {len(clusters)}. After merging: {len(new_clusters)}. \n"
result_report += f"Top-{top_k} clusters covers {calc_coverage(top_k)}% of data, top-25 {calc_coverage(25)}%, top-50 {calc_coverage(50)}% \n"
result_report += "\n"
top_cluster_ids = []
for cluster_id, freq in sorted(cluster_usage_distribution.items(), key=lambda x: x[1], reverse=True)[:top_k]:
    # Number of distinct thread positions at which this cluster occurs.
    total_turns_count = len(cluster_turns_distr[cluster_id])
    # Lines end with a bare "\n": a trailing space after the newline would
    # indent the next report line by one character.
    result_report += f"Cluster id: {cluster_id}; Frequency: {freq}; Present in {total_turns_count} turns.\n"
    sorted_cluster_turns_distr = sorted(cluster_turns_distr[cluster_id].items(), key=lambda x: x[1], reverse=True)[:3]
    result_report += "Top 3 freq turns: \n"
    for turn_num, count in sorted_cluster_turns_distr:
        result_report += f" - Turn {turn_num} freq: {count}, dist: {round(count / freq, 2)}\n"
    result_report += "-------------\n"
    # Features shared by fewer clusters are listed first.
    for feature in sorted(reverse_index[cluster_id], key=lambda x: 1 / len(features_distr[x]), reverse=True):
        result_report += f"{feature}: {round(1 / len(features_distr[feature]), 3)}\n"
    result_report += "Samples: \n"
    # At most 10 randomly chosen example n-grams per cluster.
    for example_sentences in random.sample(cluster_examples[cluster_id], min(10, len(cluster_examples[cluster_id]))):
        result_report += "- \n"
        for i, s in enumerate(example_sentences):
            result_report += f"-- S{i}: {s} \n"
    result_report += "-------------\n"
    result_report += "\n"
    top_cluster_ids.append(cluster_id)

In [27]:
# The report quotes dialog text, so write it as UTF-8 rather than relying on
# the platform's default encoding.
with open(output_folder + f'{output_key}_n-{n}_{features_name}_report.txt', 'w', encoding='utf-8') as f:
    print(result_report, file=f)
print(result_report)

Help:
- Turn dist calculated as: Turn X freq / Frequency 
- Feature confidence calculated as: 1 / number_of_clusters_that_has_this_feature 

Threshold: 0.5. Original clusters: 394. After merging: 7. 
Top-10 clusters covers 100.0% of data, top-25 100.0%, top-50 100.0% 

Cluster id: 0; Frequency: 111520; Present in 41 turns.
 Top 3 freq turns: 
 - Turn 0 freq: 10417, dist: 0.09
 - Turn 2 freq: 10156, dist: 0.09
 - Turn 1 freq: 10150, dist: 0.09
-------------
SINGLE_CONN_INNER: 1.0
 SINGLE_APPOSITION: 0.5
 SINGLE_CONN_START: 0.5
 SINGLE_CONN_INNER_ANAPHORA: 0.5
 PAIR_CONN: 0.333
 SINGLE_VP_COORD: 0.333
 SINGLE_S_COORD_ANAPHORA: 0.333
 PAIR_CONN_ANAPHORA: 0.333
 SINGLE_CATAPHORA: 0.333
 SINGLE_S_COORD: 0.333
 SINGLE_RELATIVE: 0.25
 PAIR_ANAPHORA: 0.2
 Samples: 
- 
-- S0: their postcode is 	cb21sj 
-- S1: Thank you.  Have you heard of a particular restaurant, I think it's called Gourmet Burger Kitchen? 
-- S2: that is an expensive restaurant in centre. Would you like reservations 
-- S3: Ye

- multiwoz vs topicalchat compare
- print report to folder (save figures, etc)
- compare, for example, by the per-turn distribution of clusters: e.g. one clustering describes steps X well, while another describes steps Y
- save info to further compare?