In [58]:
import json
import random
from collections import defaultdict

In [59]:
# Seed Python's `random` module so the example sentences sampled in the report are reproducible.
random.seed(42)

In [60]:
# Path of the annotated dialog corpus (JSON) that the next cell loads.
input_json = 'data/topicalchat.train.spacy.dialogact.discourse.topicmodel.0310.json'
# Window size: number of consecutive utterances that collect_n_cluster combines into one row.
n = 2
# Alternative corpora; uncomment one of these to switch datasets.
# input_json = 'data/convai2.spacy.dialogact.discourse.dialogtagger.0110.json'
# input_json = 'data/multi-woz2.spacy.dialogact.discourse.dialogtagger.0110.json'

In [61]:
# Load the annotated dialogs (used below as a mapping dialog id -> dialog dict).
# The encoding is pinned to UTF-8 so the result does not depend on the
# platform's default text encoding.
with open(input_json, 'r', encoding='utf-8') as f:
    dialogs = json.load(f)

In [62]:
# Number of rows in the longest dialog thread of the corpus.
max(len(dialog['thread']) for dialog in dialogs.values())

53

In [63]:
def get_features(utterance, feature_groups=('discourse',)):
    """Build the list of categorical features used to cluster one utterance.

    Only the requested feature groups are computed, so an utterance does not
    need the annotations of groups that are switched off.

    Args:
        utterance: dict describing one utterance. Keys read per group:
            - 'dialog_act': 'predictions' (falling back to
              'cobot_predictions'); the first element of every prediction
              whose label contains '_dci' is kept.
            - 'dialog_tagger': 'SVM_predictions', a list of dicts with
              'dimension' and 'communicative_function'.
            - 'discourse': 'single_discourse_type' (required) and the
              optional 'pair_discourse_type'.
            - 'topic_model': 'topic_model_features'; the first element of
              every entry is kept.
            - 'pos': 'features_dict'; the 'pos' values are joined with '|'
              into a single feature.
        feature_groups: collection of group names to include. The default,
            ('discourse',), is the feature set this notebook clusters on.

    Returns:
        list of feature values, in the order dialog_act, dialog_tagger,
        discourse, topic_model, pos.
    """
    features = []

    if 'dialog_act' in feature_groups:
        predictions = utterance.get('predictions', utterance.get('cobot_predictions')) or []
        features += [p[0] for p in predictions if '_dci' in p[0]]

    if 'dialog_tagger' in feature_groups:
        features += [
            f"dim_{p['dimension']} comm_func_{p['communicative_function']}"
            for p in utterance['SVM_predictions']
        ]

    if 'discourse' in feature_groups:
        features.append(utterance['single_discourse_type'])
        pair_discourse_type = utterance.get('pair_discourse_type')
        # 'PAIR_NONE' is skipped, exactly like a missing or empty pair type.
        if pair_discourse_type and pair_discourse_type != 'PAIR_NONE':
            features.append(pair_discourse_type)

    if 'topic_model' in feature_groups:
        features += [f[0] for f in utterance['topic_model_features']]

    if 'pos' in feature_groups:
        features.append("|".join(feats['pos'] for feats in utterance['features_dict']))

    return features

def get_thread_key(n):
    """Return the dialog-dict key under which the windows of size ``n`` are stored."""
    return f'thread{n}'

# def get_thread_key(n):
#     if n == 1:
#         key = 'thread'
#     else:
#         key = f'thread{n}'
#     return key

# def get_features(utterance):
#     dialog_act_features = [p[0] for p in utterance['predictions']]
#     pos_features = "|".join([feats['pos'] for feats in utterance['features_dict']])
#     single_discourse_type = utterance['single_discourse_type']
#     pair_discourse_type = utterance.get('pair_discourse_type')
    
#     dialog_tagger_features = [f"dim_{p['dimension']} comm_func_{p['communicative_function']}" for p in utterance['SVM_predictions']]

#     features = dialog_act_features  + dialog_tagger_features + [single_discourse_type] + [pos_features]
#     if pair_discourse_type:
#         features += [pair_discourse_type]
#     return features

In [64]:
# Attach to every thread row the frozenset of features it will be clustered by.
for dialog in dialogs.values():
    for row in dialog['thread']:
        row['final_features'] = frozenset(get_features(row))

In [65]:
def collect_clusters(dialogs):
    """Index the distinct per-row feature sets and label every row with its id.

    Every row of every dialog's 'thread' must already carry 'final_features'.
    The distinct feature sets are ordered by the string form of their sorted
    members and numbered from 0; that number is written to each row as
    'cluster_id' (the dialogs are modified in place).

    Returns:
        dict mapping feature set -> cluster id.
    """
    distinct = set()
    for dialog in dialogs.values():
        distinct.update(row['final_features'] for row in dialog['thread'])

    ordered = sorted(distinct, key=lambda feats: str(sorted(feats)))
    clusters_index = dict(zip(ordered, range(len(ordered))))

    for dialog in dialogs.values():
        for row in dialog['thread']:
            row['cluster_id'] = clusters_index[row['final_features']]
    return clusters_index

In [66]:
def collect_n_cluster(dialogs, n):
    """Build windows of ``n`` consecutive rows and index their feature sets.

    For every position of a dialog's 'thread' that has ``n - 1`` predecessors,
    a new row is created holding
      * 'final_features': the union of the 'final_features' of the ``n`` rows
        ending at that position, and
      * 'text0' .. 'text{n-1}': their texts, oldest first.
    The list of new rows is stored in the dialog under ``get_thread_key(n)``
    (dialogs too short to produce a window get no such key). The distinct
    unions are then ordered by the string form of their sorted members,
    numbered from 0, and each new row receives its number as 'cluster_id'.

    Returns:
        dict mapping feature set -> cluster id.

    Raises:
        AssertionError: if no dialog carries the window key.
    """
    thread_key = get_thread_key(n)
    distinct = set()

    for dialog in dialogs.values():
        thread = dialog['thread']
        windows = []
        # Positions below n - 1 do not have enough predecessors for a full window.
        for last in range(max(n - 1, 0), len(thread)):
            # Rows of the window, oldest first.
            window_rows = [thread[last - back] for back in range(n - 1, -1, -1)]
            combined = frozenset().union(*(r['final_features'] for r in window_rows))
            new_row = {'final_features': combined}
            for position, member in enumerate(window_rows):
                new_row[f'text{position}'] = member['text']
            distinct.add(combined)
            windows.append(new_row)
        if windows:
            dialog[thread_key] = windows

    ordered = sorted(distinct, key=lambda feats: str(sorted(feats)))
    clusters_index = {feats: position for position, feats in enumerate(ordered)}

    labelled_any = False
    for dialog in dialogs.values():
        if thread_key in dialog:
            labelled_any = True
            for row in dialog[thread_key]:
                row['cluster_id'] = clusters_index[row['final_features']]
    assert labelled_any
    return clusters_index

In [67]:
def jaccard_similarity(s1, s2):
    """Return the Jaccard similarity |s1 & s2| / |s1 | s2|, rounded to 5 decimals.

    Args:
        s1: a set or frozenset.
        s2: a set or frozenset (any iterable accepted by ``set.union``).

    Returns:
        float in [0, 1]. Two empty sets are treated as identical and give 1.0
        instead of raising ZeroDivisionError.
    """
    union_size = len(s1.union(s2))
    if union_size == 0:
        return 1.0
    # The ratio of an intersection to its union is always within [0, 1],
    # so no clamping is needed after rounding.
    return round(len(s1.intersection(s2)) / union_size, 5)

In [68]:
# Build the n-row windows for every dialog and index their distinct feature sets.
clusters = collect_n_cluster(dialogs, n)

In [69]:
# Number of distinct window-level feature sets (clusters) before any merging.
len(clusters)

270

In [70]:
def get_sims_info(clusters, for_hist=False):
    """Find, for each cluster, its most similar cluster among those with a higher id.

    Args:
        clusters: mapping feature set -> cluster id. Ids must be unique and
            are used as positions 0 .. len(clusters) - 1 for the lookups.
        for_hist: when False (default) a cluster takes part in at most one
            returned pair; pairs are chosen greedily in ascending id order.
            When True the best match of every source cluster is reported,
            whether or not either side was already used.

    Returns:
        list of dicts with keys 'source_ind', 'target_ind' and 'sim' (the
        Jaccard similarity of the two feature sets). A cluster with no
        higher-id candidate (the last one, or the only one) produces no
        entry, so no pair is ever fabricated against cluster 0.

    Raises:
        AssertionError: if two clusters share an id.
    """
    assert len(clusters.values()) == len(set(clusters.values()))
    reverse_index = {cluster_id: members for members, cluster_id in clusters.items()}
    total = len(clusters)
    used_inds = set()
    sims_info = []
    # Stop one short of the end: the last cluster has nothing left to compare with.
    for i in range(total - 1):
        max_sim = 0
        max_ind = i + 1
        for j in range(i + 1, total):
            cur_sim = jaccard_similarity(reverse_index[i], reverse_index[j])
            # ">=" keeps the tie-break of the scan: the highest-id candidate wins.
            if cur_sim >= max_sim:
                max_sim = cur_sim
                max_ind = j
        if for_hist:
            sims_info.append({'source_ind': i, 'target_ind': max_ind, 'sim': max_sim})
        elif i not in used_inds and max_ind not in used_inds:
            sims_info.append({'source_ind': i, 'target_ind': max_ind, 'sim': max_sim})
            used_inds.add(i)
            used_inds.add(max_ind)
    return sims_info

In [71]:
# Greedy best-match pairs (each cluster used at most once); these feed the merge step.
sims_info = get_sims_info(clusters)

In [72]:
import plotly.express as px

# Histogram of best-match similarities between the unique clusters.
# for_hist=True reports the best match per source cluster without the
# "each cluster used at most once" restriction applied to the merge pairs.
sims_info_for_hist = get_sims_info(clusters, True)
fig = px.histogram(sims_info_for_hist, x="sim", nbins=len(sims_info_for_hist))
# Saves the figure as HTML and, because auto_open=True, opens it in a browser.
fig.write_html('sims_hist.html', auto_open=True)

In [73]:
# Minimum Jaccard similarity at which a pair of clusters is merged.
min_thresh = 0.74

In [74]:
def filter_sims_and_merge_clusters(sims_info, clusters, threshold, dialogs, n):
    """Merge similar clusters round by round until no pair reaches ``threshold``.

    In each round every pair from ``sims_info`` whose 'sim' is at least
    ``threshold`` is replaced by the union of its two feature sets; all other
    clusters are carried over unchanged. Ids are reassigned from 0. If the
    round merged anything, the 'cluster_id' of every row stored under
    ``get_thread_key(n)`` is remapped in place and a new round starts with
    freshly computed pairs (``get_sims_info``) over the merged clusters.

    Args:
        sims_info: list of dicts with 'source_ind', 'target_ind' and 'sim'.
        clusters: mapping feature set -> cluster id (ids must be unique).
        threshold: minimum similarity for a pair to be merged.
        dialogs: mapping dialog id -> dialog dict; modified in place.
        n: window size, selects which row list of each dialog is relabelled.

    Returns:
        dict mapping feature set -> cluster id after the last round.

    Raises:
        AssertionError: if cluster ids are not unique, or if a round merged
            clusters but no dialog carries the window key.
    """
    thread_key = get_thread_key(n)

    while True:
        assert len(clusters.values()) == len(set(clusters.values()))
        # id -> feature set; built in ascending id order because the order is
        # what numbers the carried-over clusters below.
        members_by_id = dict(
            (idx, members)
            for members, idx in sorted(clusters.items(), key=lambda item: item[1])
        )
        merged_index = {}   # feature set -> new id
        id_remap = {}       # old id -> new id
        any_merged = False

        # Pass 1: fuse every sufficiently similar pair into one cluster.
        for pair in sims_info:
            if pair['sim'] >= threshold:
                fused = frozenset().union(
                    members_by_id[pair['source_ind']],
                    members_by_id[pair['target_ind']],
                )
                # New feature sets get the next free id; known ones keep theirs.
                new_id = merged_index.setdefault(fused, len(merged_index))
                any_merged = True
                id_remap[pair['source_ind']] = new_id
                id_remap[pair['target_ind']] = new_id

        # Pass 2: clusters that took part in no merge are carried over as they are.
        for idx, members in members_by_id.items():
            if idx not in id_remap:
                id_remap[idx] = merged_index.setdefault(members, len(merged_index))

        if not any_merged:
            return merged_index

        # Rewrite the ids stored on the window rows to the new numbering.
        relabelled_any = False
        for dialog in dialogs.values():
            if thread_key in dialog:
                relabelled_any = True
                for row in dialog[thread_key]:
                    row['cluster_id'] = id_remap[row['cluster_id']]
        assert relabelled_any

        # Next round works on the merged clusters and their fresh best matches.
        sims_info = get_sims_info(merged_index)
        clusters = merged_index

In [75]:
# Merge clusters whose Jaccard similarity is at least min_thresh; the window rows in `dialogs` are relabelled in place.
new_clusters = filter_sims_and_merge_clusters(sims_info, clusters, min_thresh, dialogs, n)

In [76]:
def check_correctness(new_clusters, dialogs, n=1):
    """Assert that the ids stored on the window rows match the cluster index in number.

    Args:
        new_clusters: mapping feature set -> cluster id.
        dialogs: mapping dialog id -> dialog dict.
        n: window size; rows are read from ``get_thread_key(n)`` and dialogs
            without that key are ignored.

    Raises:
        AssertionError: if the number of distinct 'cluster_id' values on the
            rows differs from ``len(new_clusters)``; the message carries both
            counts.
    """
    key = get_thread_key(n)
    cluster_ids = {
        row['cluster_id']
        for dialog in dialogs.values()
        if key in dialog
        for row in dialog[key]
    }
    assert len(cluster_ids) == len(new_clusters), (
        f"{len(cluster_ids)} distinct cluster ids on rows, "
        f"{len(new_clusters)} clusters in index"
    )

# Sanity check: as many distinct ids on the window rows as clusters in the merged index.
check_correctness(new_clusters, dialogs, n)

In [77]:
# Cluster counts before and after merging, followed by the first three merged clusters.
print(
    len(clusters),
    len(new_clusters),
    list(new_clusters.items())[:3],
)
# Inverse lookup used by the report: merged cluster id -> feature set.
reverse_index = {
    cluster_id: members
    for members, cluster_id in sorted(new_clusters.items(), key=lambda item: item[1])
}

270 154 [(frozenset({'SINGLE_APPOSITION', 'SINGLE_CATAPHORA', 'PAIR_ANAPHORA', 'PAIR_CONN'}), 0), (frozenset({'SINGLE_APPOSITION', 'SINGLE_CONN_INNER', 'PAIR_ANAPHORA', 'PAIR_CONN'}), 1), (frozenset({'SINGLE_APPOSITION', 'PAIR_ANAPHORA', 'PAIR_CONN', 'SINGLE_S_COORD'}), 2)]


In [78]:
def calc_cluster_cluster_usage_distribution(dialogs, n):
    """Count how many window rows carry each cluster id.

    Rows are read from ``get_thread_key(n)``; dialogs without that key are
    skipped.

    Returns:
        defaultdict(int) mapping cluster id -> number of rows.
    """
    key = get_thread_key(n)
    usage = defaultdict(int)
    for dialog in dialogs.values():
        if key in dialog:
            for row in dialog[key]:
                usage[row['cluster_id']] += 1
    return usage

# Number of window rows per merged cluster id.
cluster_usage_distribution = calc_cluster_cluster_usage_distribution(dialogs, n)

In [79]:
# Chart of the row count per cluster; categories are labelled "c_<cluster id>"
# and ordered by descending total.
xs = [f"c_{e}" for e in list(cluster_usage_distribution.keys())]
fig = px.histogram(x=xs, y=list(cluster_usage_distribution.values()), nbins=len(cluster_usage_distribution.keys()), labels={'x': 'cluster id'})
fig.update_layout(xaxis={'categoryorder':'total descending'})
# Saves the figure as HTML and, because auto_open=True, opens it in a browser.
fig.write_html('usage_hist.html', auto_open=True)

In [80]:
def get_cluster_examples(dialogs, n):
    """Group the texts of every window row by the row's cluster id.

    Rows are read from ``get_thread_key(n)``; dialogs without that key are
    skipped.

    Returns:
        defaultdict(list) mapping cluster id -> list of examples, each example
        being the list [row['text0'], ..., row['text{n-1}']].
    """
    key = get_thread_key(n)
    text_keys = [f'text{i}' for i in range(0, n)]
    cluster_examples = defaultdict(list)
    for dialog in dialogs.values():
        if key in dialog:
            for row in dialog[key]:
                texts = [row[text_key] for text_key in text_keys]
                cluster_examples[row['cluster_id']].append(texts)
    return cluster_examples

# Text windows grouped by merged cluster id; the report samples its example sentences from these.
cluster_examples = get_cluster_examples(dialogs, n)

In [81]:
def get_cluster_by_turns_distribution(dialogs, n):
    """Count, per cluster, how often it occurs at each position of the window list.

    Rows are read from ``get_thread_key(n)``; dialogs without that key are
    skipped. The position is the row's index within that list.

    Returns:
        defaultdict(dict) of the form {cluster_id: {position: count, ...}, ...}.
    """
    key = get_thread_key(n)
    cluster_turns_distr = defaultdict(dict)
    for dialog in dialogs.values():
        if key in dialog:
            for position, row in enumerate(dialog[key]):
                per_position = cluster_turns_distr[row['cluster_id']]
                per_position[position] = per_position.get(position, 0) + 1
    return cluster_turns_distr
# Per cluster: how many rows sit at each position of the window list.
cluster_turns_distr = get_cluster_by_turns_distribution(dialogs, n)

In [82]:
def get_cluster_features_distribution(dialogs, n):
    """Map every feature to the set of cluster ids whose rows contain it.

    Rows are read from ``get_thread_key(n)``; dialogs without that key are
    skipped. The features are taken from each row's own 'final_features'.

    Returns:
        defaultdict(set) of the form {feature: {cluster_id1, cluster_id2, ...}}.
    """
    key = get_thread_key(n)
    clusters_by_feature = defaultdict(set)
    for dialog in dialogs.values():
        if key in dialog:
            for row in dialog[key]:
                for feature in row['final_features']:
                    clusters_by_feature[feature].add(row['cluster_id'])
    return clusters_by_feature
# Per feature: the set of merged cluster ids whose rows contain that feature.
features_distr = get_cluster_features_distribution(dialogs, n)

In [83]:
# Text report for the 10 most frequent merged clusters: frequency, the three
# most common positions, the cluster's features and up to 10 sampled examples.
print("Help:")
print("- Turn dist calculated as: Turn X freq / Frequency")
print("- Feature confidence calculated as: 1 / number_of_clusters_that_has_this_feature")
print()
print(f"Threshold: {min_thresh}. Original clusters: {len(clusters)}. After merging: {len(new_clusters)}.")
print()
top_cluster_ids = []
for cluster_id, freq in sorted(cluster_usage_distribution.items(), key=lambda x: x[1], reverse=True)[:10]:
    # Number of distinct positions at which this cluster occurs.
    total_turns_count = len(cluster_turns_distr[cluster_id].values())
    print(f"Cluster id: {cluster_id}; Frequency: {freq}; Present in {total_turns_count} turns.")
    sorted_cluster_turns_distr = sorted(cluster_turns_distr[cluster_id].items(), key=lambda x: x[1], reverse=True)[:3]
    top_turns_count = [f"Turn {turn_num} freq: {count}, dist: {round(count / freq, 2)}" for turn_num, count in sorted_cluster_turns_distr]
    print("Top 3 freq turns: ")
    # Print the lines built above instead of formatting the same text twice.
    for turn_line in top_turns_count:
        print(f" - {turn_line}")
    print("-------------")
    # Features shared with fewer clusters come first.
    for feature in sorted(reverse_index[cluster_id], key=lambda x: 1 / len(features_distr[x]), reverse=True):
        print(f"{feature}: {round(1 / len(features_distr[feature]), 3)}")
    print("Samples: ")
    examples = cluster_examples[cluster_id]
    # random.sample raises ValueError when asked for more items than exist,
    # so the sample size is capped at the number of available examples.
    for example_sentences in random.sample(examples, min(10, len(examples))):
        print("-")
        for i, s in enumerate(example_sentences):
            print(f"-- S{i}: {s}")
    print("-------------")
    print()
    top_cluster_ids.append(cluster_id)

Help:
- Turn dist calculated as: Turn X freq / Frequency
- Feature confidence calculated as: 1 / number_of_clusters_that_has_this_feature

Threshold: 0.74. Original clusters: 270. After merging: 154.

Cluster id: 2; Frequency: 42079; Present in 42 turns.
Top 3 freq turns: 
 - Turn 8 freq: 2253, dist: 0.05
 - Turn 10 freq: 2215, dist: 0.05
 - Turn 6 freq: 2209, dist: 0.05
-------------
SINGLE_APPOSITION: 0.029
SINGLE_S_COORD: 0.028
PAIR_CONN: 0.014
PAIR_ANAPHORA: 0.013
Samples: 
-
-- S0: now THAT sounds boring to me
-- S1: If you were bored of your chicken, you could eat it. 
-
-- S0: I would really like to see that one!  I loved the early simpsons, the first 10 seasons are excellent. 
-- S1: yes,  love the simpsons. Ringo starr, george carlin, and alec baldwin have all narrated thomas the tank engine for at least 52 episodes each
-
-- S0: Yes..you are right. 
-- S1: What was it about again?
-
-- S0: I'm originally from England so soccer is my thing, yes, a little rugby too.  I've tried

- print the report to a folder (save the figures, etc.)
- compare clusterings, for example by their per-turn distributions: e.g. one clustering describes steps X well, while another describes steps Y
- save this information for later comparison?