In [1]:
import json
import random
import pathlib
import joblib
from collections import defaultdict
import numpy as np
import plotly.express as px

In [2]:
# Seed Python's `random` module so the example utterances drawn with
# random.sample in the report cell are reproducible across runs.
random.seed(42)

In [3]:
# --- Run configuration ---------------------------------------------------
# Dataset to process; it selects the input file and the merge threshold below.
output_key = 'multi-woz2'
# output_key = 'topicalchat'
# input_json = 'data/multi-woz2.spacy.dialogact.discourse.dialogtagger.topicmodel.0310.json'
# input_json = 'data/convai1.spacy.dialogact.discourse.dialogtagger.topicmodel.0310.json'
# input_json = 'data/convai2.spacy.dialogact.discourse.dialogtagger.0110.json'
# input_json = 'data/multi-woz2.spacy.dialogact.discourse.dialogtagger.0110.json'

# Feature family used to describe every utterance (see get_features).
features_name = 'dialog_tagger_features'
# features_name = 'discourse_features'
# features_name = 'topic_model_features'

if output_key == 'multi-woz2':
    input_json = f'data/{output_key}.spacy.dialogact.discourse.dialogtagger.topicmodel.0310.json'
    # 0.74 was used for dialog_tagger_features with n=4, n=2 and n=1
    # (top-10 clusters cover 99% of data).
    # Topic-model run noted by the author (n=4, threshold 0.7): 11043 original
    # clusters, 1006 after merging; top-10 clusters cover ~29% of data,
    # top-25 48%, top-50 61%.
    min_thresh = 0.74
    # min_thresh = 0.6  # for topic model
elif output_key == 'topicalchat':
    input_json = f'data/{output_key}.train.spacy.dialogact.discourse.topicmodel.0310.json'
    # Per-feature-family thresholds for this dataset.
    topicalchat_thresholds = {
        'dialog_tagger_features': 0.65,
        'discourse_features': 0.7,
        'topic_model_features': 0.64,
    }
    if features_name not in topicalchat_thresholds:
        # Fail here rather than with a NameError on `min_thresh` many cells later.
        raise ValueError(f"No merge threshold configured for features_name={features_name!r}")
    min_thresh = topicalchat_thresholds[features_name]
else:
    # Without this branch `input_json` / `min_thresh` would silently stay undefined.
    raise ValueError(f"Unsupported output_key: {output_key!r}")

# Window size: each clustered row combines the features of n consecutive utterances.
n = 1

# All artefacts of this run (plots, report, joblib dump) are written below this folder.
output_folder = f'data/results/{output_key}_{n}_{features_name}/'
pathlib.Path(output_folder).mkdir(parents=True, exist_ok=True)
data_path = output_folder + 'data.joblib'

In [4]:
# Load the annotated dialogs. The encoding is pinned to UTF-8 (the standard
# encoding for JSON) instead of relying on the platform's locale default.
with open(input_json, 'r', encoding='utf-8') as f:
    dialogs = json.load(f)

In [5]:
# Length of the longest dialog thread (shown as the cell output).
max(len(dialog['thread']) for dialog in dialogs.values())

44

In [6]:
def get_features(utterance, feature_set=None):
    """Return the list of feature strings that describe one utterance.

    Only the annotation required by the selected feature family is read, so a
    dataset that lacks the other annotations can still be processed.

    Parameters
    ----------
    utterance : dict
        Annotated utterance. Depending on the feature family it must contain:
        - 'dialog_tagger_features': ``utterance['SVM_predictions']``, a list of
          dicts with 'dimension' and 'communicative_function' keys;
        - 'discourse_features': ``utterance['single_discourse_type']`` and,
          optionally, ``utterance['pair_discourse_type']``;
        - 'topic_model_features': ``utterance['topic_model_features']``, a
          sequence of items whose first element is used as the feature.
    feature_set : str, optional
        Feature family to extract. Defaults to the notebook-level
        ``features_name`` setting.

    Returns
    -------
    list of str

    Raises
    ------
    ValueError
        If the feature family is not one of the three supported names.
    """
    selected = features_name if feature_set is None else feature_set

    if selected == 'dialog_tagger_features':
        return [
            f"dim_{p['dimension']} comm_func_{p['communicative_function']}"
            for p in utterance['SVM_predictions']
        ]

    if selected == 'discourse_features':
        features = [utterance['single_discourse_type']]
        pair_discourse_type = utterance.get('pair_discourse_type')
        # 'PAIR_NONE' is treated the same as a missing pair annotation.
        if pair_discourse_type and pair_discourse_type != 'PAIR_NONE':
            features.append(pair_discourse_type)
        return features

    if selected == 'topic_model_features':
        return [f[0] for f in utterance['topic_model_features']]

    # The previous code raised an undefined `ArgumentError` (i.e. a NameError)
    # here, followed by an unreachable "combine every feature family" branch.
    raise ValueError(
        f"Unknown features_name {selected!r}; expected 'dialog_tagger_features', "
        "'discourse_features' or 'topic_model_features'"
    )

def get_thread_key(n):
    """Return the dialog-dict key used for the thread built from windows of size ``n``."""
    return f'thread{n}'

# def get_thread_key(n):
#     if n == 1:
#         key = 'thread'
#     else:
#         key = f'thread{n}'
#     return key

# def get_features(utterance):
#     dialog_act_features = [p[0] for p in utterance['predictions']]
#     pos_features = "|".join([feats['pos'] for feats in utterance['features_dict']])
#     single_discourse_type = utterance['single_discourse_type']
#     pair_discourse_type = utterance.get('pair_discourse_type')
    
#     dialog_tagger_features = [f"dim_{p['dimension']} comm_func_{p['communicative_function']}" for p in utterance['SVM_predictions']]

#     features = dialog_act_features  + dialog_tagger_features + [single_discourse_type] + [pos_features]
#     if pair_discourse_type:
#         features += [pair_discourse_type]
#     return features

In [7]:
# Attach the selected features to every utterance as a hashable frozenset,
# so feature sets can later be used as cluster keys.
for dialog in dialogs.values():
    for utterance in dialog['thread']:
        utterance['final_features'] = frozenset(get_features(utterance))
        

In [8]:
def collect_clusters(dialogs):
    """Index the distinct per-utterance feature sets and label every utterance.

    Each distinct ``row['final_features']`` value found in the ``'thread'``
    lists becomes one cluster. Clusters are numbered in the order given by
    the string form of their sorted features, and every row receives the id
    of its cluster under ``row['cluster_id']`` (the dialogs are modified in place).

    Returns a dict mapping feature frozenset -> cluster id.
    """
    distinct = {
        row['final_features']
        for dialog in dialogs.values()
        for row in dialog['thread']
    }
    ordered = sorted(distinct, key=lambda feats: str(sorted(feats)))
    clusters_index = dict(zip(ordered, range(len(ordered))))

    for dialog in dialogs.values():
        for row in dialog['thread']:
            row['cluster_id'] = clusters_index[row['final_features']]
    return clusters_index

In [9]:
import copy

In [10]:
def collect_n_cluster(dialogs, n):
    """Build windowed threads of ``n`` consecutive utterances and index their clusters.

    For every dialog, each window of ``n`` consecutive rows of
    ``dialog['thread']`` becomes one new row containing:

    - ``'final_features'``: the union of the window's features, each prefixed
      with its position in the window (``0_`` = oldest, ``{n-1}_`` = newest);
    - ``'text0'`` .. ``'text{n-1}'``: the utterance texts, oldest first;
    - ``'cluster_id'``: the id of the row's feature set in the returned index.

    The new rows are stored in place under ``dialog[get_thread_key(n)]``.
    Dialogs shorter than ``n`` utterances get no such key.

    Parameters
    ----------
    dialogs : dict
        Maps dialog id -> dialog dict with a ``'thread'`` list whose rows
        carry ``'final_features'`` (frozenset of str) and ``'text'``.
    n : int
        Window size, at least 1.

    Returns
    -------
    dict
        Maps feature frozenset -> cluster id. Ids follow the order given by
        the string form of the sorted features.

    Raises
    ------
    ValueError
        If ``n`` is smaller than 1.
    AssertionError
        If no dialog is long enough to produce a window.
    """
    if n < 1:
        raise ValueError(f"n must be at least 1, got {n!r}")

    thread_key = get_thread_key(n)
    clusters = set()
    for dialog in dialogs.values():
        thread = dialog['thread']
        n_thread = []
        for end in range(n - 1, len(thread)):
            window = thread[end - n + 1:end + 1]  # oldest -> newest
            # frozensets of strings are immutable, so no defensive copy is needed.
            new_cluster = frozenset(
                f"{pos}_{feature}"
                for pos, row in enumerate(window)
                for feature in row['final_features']
            )
            new_row = {'final_features': new_cluster}
            for pos, row in enumerate(window):
                new_row[f'text{pos}'] = row['text']
            clusters.add(new_cluster)
            n_thread.append(new_row)
        if n_thread:
            dialog[thread_key] = n_thread

    ordered = sorted(clusters, key=lambda feats: str(sorted(feats)))
    clusters_index = {feats: i for i, feats in enumerate(ordered)}

    at_least_one_with_thread_key = False
    for dialog in dialogs.values():
        if thread_key not in dialog:
            continue
        at_least_one_with_thread_key = True
        for row in dialog[thread_key]:
            row['cluster_id'] = clusters_index[row['final_features']]
    assert at_least_one_with_thread_key, f"no dialog has at least {n} utterances"
    return clusters_index

In [11]:
def jaccard_similarity(s1, s2):
    """Return the Jaccard similarity |s1 & s2| / |s1 | s2|, rounded to 5 decimals.

    Two empty sets are treated as identical and yield 1.0 instead of raising
    ZeroDivisionError. The ratio of two set sizes always lies in [0, 1], so
    no clamping is required.
    """
    union_size = len(s1.union(s2))
    if union_size == 0:
        return 1.0
    return round(len(s1.intersection(s2)) / union_size, 5)

In [12]:
# Build the n-utterance windowed threads in `dialogs` and get the initial
# (unmerged) index: feature frozenset -> cluster id.
clusters = collect_n_cluster(dialogs, n)

In [13]:
# Number of distinct clusters before merging.
len(clusters)

27

In [14]:
def get_sims_info(clusters, for_hist=False):
    """Pair every cluster with its most similar later-indexed cluster.

    Parameters
    ----------
    clusters : dict
        Maps feature frozenset -> cluster id. Ids must be unique and are
        expected to be ``0 .. len(clusters) - 1``.
    for_hist : bool
        If True, return one entry for every cluster that has a later-indexed
        neighbour (used for the similarity histogram). If False, return a
        greedy matching in which every cluster appears in at most one pair.

    Returns
    -------
    list of dict
        Entries ``{'source_ind': i, 'target_ind': j, 'sim': s}`` with
        ``j > i``; on ties the highest ``j`` wins. The last cluster has no
        later neighbour and therefore never appears as a source.
    """
    assert len(clusters.values()) == len(set(clusters.values()))
    reverse_index = {i: c for c, i in clusters.items()}
    total = len(clusters)
    used_inds = set()
    sims_info = []
    # Stop at total - 1: the final index has nothing to be compared with.
    for i in range(total - 1):
        if not for_hist and i in used_inds:
            # Already paired; its best match would be discarded anyway.
            continue
        max_sim = 0
        max_ind = 0
        for j in range(i + 1, total):
            cur_sim = jaccard_similarity(reverse_index[i], reverse_index[j])
            if max_sim <= cur_sim:
                max_sim = cur_sim
                max_ind = j
        if for_hist:
            sims_info.append({'source_ind': i, 'target_ind': max_ind, 'sim': max_sim})
        elif max_ind not in used_inds:
            # NOTE(review): a pair reserves both indices even when its
            # similarity is low, which can block a later, better pairing
            # within the same pass - confirm this greedy behaviour is intended.
            sims_info.append({'source_ind': i, 'target_ind': max_ind, 'sim': max_sim})
            used_inds.add(i)
            used_inds.add(max_ind)
    return sims_info

In [15]:
# Greedy pairing of clusters with their most similar neighbour; input of the first merge round.
sims_info = get_sims_info(clusters)

In [16]:
# Between unique clusters: histogram of the similarity between each cluster
# and its closest later-indexed cluster, written as an HTML file into the
# output folder (auto_open=True also opens it in a browser).
sims_info_for_hist = get_sims_info(clusters, True)
fig = px.histogram(sims_info_for_hist, x="sim", nbins=len(sims_info_for_hist))
fig.write_html(output_folder + 'sims_hist.html', auto_open=True)

In [17]:
def _merge_round(sims_info, clusters, threshold):
    """Run one merge pass; return (new_clusters, old_id -> new_id mapping, merged_flag)."""
    assert len(clusters.values()) == len(set(clusters.values()))
    # Ascending id order matters: it fixes the ids given to unmerged clusters.
    reverse_index = {i: c for c, i in sorted(clusters.items(), key=lambda x: x[1])}
    new_clusters = {}
    cluster_id_mapping = {}
    merged_inds = set()
    is_merged = False

    for e in sims_info:
        source_ind, target_ind = e['source_ind'], e['target_ind']
        # A cluster paired with itself is not a merge; ignoring it guarantees
        # that every merging round strictly reduces the number of clusters.
        if e['sim'] >= threshold and source_ind != target_ind:
            new_cluster = frozenset().union(reverse_index[source_ind], reverse_index[target_ind])
            if new_cluster not in new_clusters:
                new_clusters[new_cluster] = len(new_clusters)
                is_merged = True
            merged_inds.update((source_ind, target_ind))
            cluster_id_mapping[source_ind] = new_clusters[new_cluster]
            cluster_id_mapping[target_ind] = new_clusters[new_cluster]

    for ind, cluster in reverse_index.items():
        if ind in merged_inds:
            continue
        # An untouched cluster may coincide with a freshly merged one; reuse its id then.
        if cluster not in new_clusters:
            new_clusters[cluster] = len(new_clusters)
        cluster_id_mapping[ind] = new_clusters[cluster]
    return new_clusters, cluster_id_mapping, is_merged


def _remap_cluster_ids(dialogs, key, cluster_id_mapping):
    """Rewrite row['cluster_id'] in every dialog[key] thread through the mapping."""
    at_least_one_with_thread_key = False
    for dialog in dialogs.values():
        if key not in dialog:
            continue
        at_least_one_with_thread_key = True
        for row in dialog[key]:
            row['cluster_id'] = cluster_id_mapping[row['cluster_id']]
    assert at_least_one_with_thread_key


def filter_sims_and_merge_clusters(sims_info, clusters, threshold, dialogs, n):
    """Repeatedly merge cluster pairs whose similarity reaches ``threshold``.

    Each round unions the feature sets of every pair in ``sims_info`` with
    ``sim >= threshold``, renumbers all clusters, and rewrites
    ``row['cluster_id']`` in the windowed threads of ``dialogs`` in place.
    The pairs for the next round are recomputed with ``get_sims_info``.
    The process stops after the first round in which nothing is merged.

    Implemented as a loop rather than recursion, so the number of merge
    rounds is not limited by Python's recursion limit.

    Parameters
    ----------
    sims_info : list of dict
        Pairs for the first round, as produced by ``get_sims_info(clusters)``.
    clusters : dict
        Maps feature frozenset -> cluster id (unique ids).
    threshold : float
        Minimum similarity for a pair to be merged.
    dialogs : dict
        Dialogs whose ``get_thread_key(n)`` threads carry ``'cluster_id'``.
    n : int
        Window size identifying which thread to update.

    Returns
    -------
    dict
        Maps merged feature frozenset -> final cluster id.

    Raises
    ------
    AssertionError
        If cluster ids are not unique, or a merge happened but no dialog
        contains the windowed thread.
    """
    key = get_thread_key(n)
    while True:
        new_clusters, cluster_id_mapping, is_merged = _merge_round(sims_info, clusters, threshold)
        if not is_merged:
            return new_clusters
        _remap_cluster_ids(dialogs, key, cluster_id_mapping)
        clusters = new_clusters
        sims_info = get_sims_info(clusters)

In [18]:
# Merge clusters whose similarity reaches min_thresh; also rewrites the
# cluster ids stored in the windowed threads of `dialogs`.
new_clusters = filter_sims_and_merge_clusters(sims_info, clusters, min_thresh, dialogs, n)

In [19]:
def check_correctness(new_clusters, dialogs, n=1):
    """Verify that the rows reference exactly as many clusters as ``new_clusters`` holds.

    Collects the distinct ``row['cluster_id']`` values from the windowed
    thread ``get_thread_key(n)`` of every dialog that has one and compares
    their count with ``len(new_clusters)``.

    Raises
    ------
    AssertionError
        If the two counts differ. It is raised explicitly so the check also
        runs under ``python -O`` and carries a readable message.
    """
    key = get_thread_key(n)
    cluster_ids = {
        row['cluster_id']
        for dialog in dialogs.values()
        if key in dialog
        for row in dialog[key]
    }
    if len(cluster_ids) != len(new_clusters):
        raise AssertionError(
            f"{len(cluster_ids)} distinct cluster ids are used in the dialogs, "
            f"but the cluster index contains {len(new_clusters)} clusters"
        )

# Sanity check: the merged index and the ids stored in `dialogs` must agree in size.
check_correctness(new_clusters, dialogs, n)

In [20]:
# dump new_clusters, old_clusters, dialogs, n, min_thresh, graphs

def dump_data(n):
    """Persist the clustering result of this run to ``data_path`` with joblib.

    Reads the notebook-level ``dialogs``, ``new_clusters``, ``clusters``,
    ``min_thresh`` and ``data_path``. Every dialog that has the windowed
    thread for ``n`` is reduced to ``{'thread': [...]}`` where each row keeps
    only ``'cluster_id'`` and ``'text0'`` .. ``'text{n-1}'``. The reduced
    dialogs are stored under the payload key ``'dialog'``.
    """
    key = get_thread_key(n)
    text_keys = [f'text{i}' for i in range(n)]

    compressed_dialogs = {}
    for dialog_id, dialog in dialogs.items():
        if key not in dialog:
            continue
        compressed_dialogs[dialog_id] = {
            'thread': [
                {'cluster_id': row['cluster_id'], **{text_key: row[text_key] for text_key in text_keys}}
                for row in dialog[key]
            ]
        }

    payload = {
        'new_clusters': new_clusters,
        'clusters': clusters,
        'n': n,
        'min_thresh': min_thresh,
        'dialog': compressed_dialogs,
    }
    joblib.dump(payload, data_path)

    
def load_data():
    """Load and return the joblib payload stored at ``data_path``."""
    return joblib.load(data_path)

# Write the result of this run to data_path.
dump_data(n)
# To restore a previous run instead, uncomment the lines below.
# NOTE: dump_data stores the reduced dialogs under the key 'dialog' (not
# 'dialogs'), and each of them only has a single 'thread' list.
# data = load_data()
# dialogs = data['dialog']
# new_clusters = data['new_clusters']
# clusters = data['clusters']
# n = data['n']
# min_thresh = data['min_thresh']

In [21]:
# Inverse of the merged index: cluster id -> feature frozenset (used by the report).
reverse_index = {i: c for c, i in sorted(new_clusters.items(), key=lambda x: x[1])} 

In [22]:
def calc_cluster_cluster_usage_distribution(dialogs, n):
    """Count how many windowed rows are assigned to each cluster.

    Only dialogs that contain the thread ``get_thread_key(n)`` contribute.
    Returns a ``defaultdict(int)`` mapping cluster id -> number of rows.
    """
    key = get_thread_key(n)
    usage = defaultdict(int)
    windowed_threads = (dialog[key] for dialog in dialogs.values() if key in dialog)
    for thread in windowed_threads:
        for row in thread:
            usage[row['cluster_id']] += 1
    return usage

# cluster id -> number of windowed rows assigned to it after merging.
cluster_usage_distribution = calc_cluster_cluster_usage_distribution(dialogs, n)

In [24]:
# Plot the number of rows per cluster (x: cluster label "c_<id>", y: row count),
# with categories ordered by descending total, and write it as an HTML file
# into the output folder (auto_open=True also opens it in a browser).
xs = [f"c_{e}" for e in list(cluster_usage_distribution.keys())]
fig = px.histogram(x=xs, y=list(cluster_usage_distribution.values()), nbins=len(cluster_usage_distribution.keys()), labels={'x': 'cluster id'})
fig.update_layout(xaxis={'categoryorder':'total descending'})
fig.write_html(output_folder + 'usage_hist.html', auto_open=True)

In [25]:
def get_cluster_examples(dialogs, n):
    """Group the utterance texts of every windowed row by cluster.

    Returns a ``defaultdict(list)`` mapping cluster id -> list of examples,
    where each example is the list ``[text0, ..., text{n-1}]`` of one row of
    the thread ``get_thread_key(n)``.
    """
    key = get_thread_key(n)
    text_keys = [f'text{i}' for i in range(n)]
    cluster_examples = defaultdict(list)
    for dialog in dialogs.values():
        if key not in dialog:
            continue
        for row in dialog[key]:
            texts = [row[text_key] for text_key in text_keys]
            cluster_examples[row['cluster_id']].append(texts)
    return cluster_examples

# cluster id -> list of example windows (each a list of n utterance texts).
cluster_examples = get_cluster_examples(dialogs, n)

In [26]:
def get_original_features(dialogs, n):
    """Group the rows' own feature sets by the cluster they ended up in.

    Returns a ``defaultdict(list)`` mapping cluster id -> list of the
    ``'final_features'`` values of all rows of the thread
    ``get_thread_key(n)`` carrying that cluster id.
    """
    key = get_thread_key(n)
    orig_features = defaultdict(list)
    rows = (row for dialog in dialogs.values() if key in dialog for row in dialog[key])
    for row in rows:
        orig_features[row['cluster_id']].append(row['final_features'])
    return orig_features

# cluster id -> feature sets of its rows as they were before merging
# (read as a global by calc_feature_confidence).
orig_features = get_original_features(dialogs, n)

In [27]:
def get_cluster_by_turns_distribution(dialogs, n):
    """Count, per cluster, how often it occurs at each position of the windowed thread.

    Returns a ``defaultdict(dict)`` shaped ``{cluster_id: {turn_num: count, ...}, ...}``
    where ``turn_num`` is the row's index inside ``dialog[get_thread_key(n)]``.
    """
    key = get_thread_key(n)
    cluster_turns_distr = defaultdict(dict)
    for dialog in dialogs.values():
        if key not in dialog:
            continue
        for turn_num, row in enumerate(dialog[key]):
            turn_counts = cluster_turns_distr[row['cluster_id']]
            turn_counts[turn_num] = turn_counts.get(turn_num, 0) + 1
    return cluster_turns_distr
# cluster id -> {turn index: number of occurrences}.
cluster_turns_distr = get_cluster_by_turns_distribution(dialogs, n)

In [28]:
def get_cluster_features_distribution(dialogs, n):
    """Count in how many windowed rows each feature occurs.

    Returns a ``defaultdict(int)`` mapping feature -> number of rows of the
    thread ``get_thread_key(n)`` whose ``'final_features'`` contain it.
    """
    key = get_thread_key(n)
    features_distr = defaultdict(int)
    rows = (row for dialog in dialogs.values() if key in dialog for row in dialog[key])
    for row in rows:
        for feature in row['final_features']:
            features_distr[feature] += 1
    return features_distr
# feature -> number of windowed rows containing it (the report's "Orig freq").
features_distr = get_cluster_features_distribution(dialogs, n)

In [29]:
def calc_feature_confidence(feature, cluster_id, features_distr, cluster_examples, original_features=None):
    """Return the share of a cluster's rows whose own feature set contains ``feature``.

    Parameters
    ----------
    feature : str
        Feature to look for.
    cluster_id : hashable
        Cluster whose rows are inspected.
    features_distr : dict
        Unused; kept so existing calls keep working.
    cluster_examples : dict
        Maps cluster id -> list of examples; its length is the cluster size.
    original_features : dict, optional
        Maps cluster id -> list of the rows' feature sets. Defaults to the
        notebook-level ``orig_features``.

    Returns
    -------
    float
        ``matching rows / cluster size`` rounded to 2 decimals, or 0.0 for a
        cluster without examples (instead of raising ZeroDivisionError).
    """
    if original_features is None:
        original_features = orig_features
    # .get() avoids inserting empty entries when the inputs are defaultdicts.
    final_cluster_size = len(cluster_examples.get(cluster_id, ()))
    if final_cluster_size == 0:
        return 0.0
    matching = sum(1 for fs in original_features.get(cluster_id, ()) if feature in fs)
    return round(matching / final_cluster_size, 2)

In [30]:
def get_mean_and_std_turn(cluster_turns_distr, cluster_id):
    """Return (mean, std) of the turn indices at which a cluster occurs.

    Every turn index is weighted by its occurrence count. The mean is
    rounded to 1 decimal and the standard deviation to 2 decimals.
    """
    turn_counts = cluster_turns_distr[cluster_id]
    expanded = np.array([
        turn_num
        for turn_num, count in turn_counts.items()
        for _ in range(count)
    ])
    mean_turn = round(np.mean(expanded), 1)
    std_turn = round(np.std(expanded), 2)
    return mean_turn, std_turn

# Example: mean and standard deviation of the turn index for cluster 2.
get_mean_and_std_turn(cluster_turns_distr, 2)

(8.9, 5.63)

In [35]:
def calc_coverage(top_k, usage_distribution=None):
    """Return the percentage of rows covered by the ``top_k`` most used clusters.

    Parameters
    ----------
    top_k : int
        Number of most frequent clusters to include.
    usage_distribution : dict, optional
        Maps cluster id -> row count. Defaults to the notebook-level
        ``cluster_usage_distribution``.

    Returns
    -------
    float
        Coverage in percent, rounded to a whole number. The rounding is done
        after scaling to percent so the result has no binary-float artefacts
        such as 28.999999999999996. An empty distribution yields 0.0.
    """
    if usage_distribution is None:
        usage_distribution = cluster_usage_distribution
    total = sum(usage_distribution.values())
    if total == 0:
        return 0.0
    top_freqs = sorted(usage_distribution.values(), reverse=True)[:top_k]
    return round(100 * sum(top_freqs) / total, 0)
# Build the plain-text report for the top_k most used clusters.
top_k = 10
report_parts = [
    "Help:\n",
    "- Turn dist calculated as: Turn X freq / Frequency \n",
    # This line used to describe a formula that calc_feature_confidence does not compute.
    "- Feature confidence calculated as: number of utterances in the cluster whose original features contain the feature / Frequency \n",
    "\n",
    f"Threshold: {min_thresh}. Original clusters: {len(clusters)}. After merging: {len(new_clusters)}. \n",
    f"Top-{top_k} clusters covers {calc_coverage(top_k)}% of data, top-25 {calc_coverage(25)}%, top-50 {calc_coverage(50)}% \n",
    "\n",
]
top_cluster_ids = []
most_used_clusters = sorted(cluster_usage_distribution.items(), key=lambda x: x[1], reverse=True)[:top_k]
for cluster_id, freq in most_used_clusters:
    # Clusters seen only once are not reported.
    if freq < 2:
        continue
    total_turns_count = len(cluster_turns_distr[cluster_id])
    report_parts.append(f"Cluster id: {cluster_id}; Frequency: {freq}; Present in {total_turns_count} turns.\n ")
    mean_turn, std_turn = get_mean_and_std_turn(cluster_turns_distr, cluster_id)
    report_parts.append(f"Mean turn: {mean_turn} +- {std_turn} \n")
    report_parts.append("Top 3 freq turns: \n")
    sorted_cluster_turns_distr = sorted(cluster_turns_distr[cluster_id].items(), key=lambda x: x[1], reverse=True)[:3]
    for turn_num, count in sorted_cluster_turns_distr:
        report_parts.append(f" - Turn {turn_num} freq: {count}, dist: {round(count / freq, 2)}\n")
    report_parts.append("-------------\n")
    # Sorting by the whole feature string keeps the previous primary order
    # (first character) and makes the order independent of the hash seed.
    for feature in sorted(reverse_index[cluster_id]):
        confidence = calc_feature_confidence(feature, cluster_id, features_distr, cluster_examples)
        report_parts.append(f"{feature}: {confidence}. ")
        report_parts.append(f'Orig freq: {features_distr[feature]}\n')
    report_parts.append("Samples: \n")
    examples = cluster_examples[cluster_id]
    # At most 10 random examples per cluster; reproducible through the seed set earlier.
    for example_sentences in random.sample(examples, min(10, len(examples))):
        report_parts.append("- \n")
        for i, s in enumerate(example_sentences):
            report_parts.append(f"-- S{i}: {s} \n")
    report_parts.append("-------------\n")
    report_parts.append("\n")
    top_cluster_ids.append(cluster_id)
result_report = "".join(report_parts)

In [36]:
# Peek at the first three dialog ids.
list(dialogs)[:3]

['SNG01856.json', 'SNG0129.json', 'PMUL1635.json']

In [37]:
# for row in dialogs['PMUL1635.json']['thread1']:
#     print(reverse_index[row['cluster_id']], row['cluster_id'])
#     print('S0 ', row['text0'])
#     print()

In [38]:
# Save the report next to the other artefacts and echo it in the notebook.
# The encoding is pinned to UTF-8 so utterances with non-ASCII characters
# are written identically on every platform.
report_path = output_folder + f'{output_key}_n-{n}_{features_name}_report.txt'
with open(report_path, 'w', encoding='utf-8') as f:
    print(result_report, file=f)
print(result_report)

Help:
- Turn dist calculated as: Turn X freq / Frequency 
- Feature confidence calculated as: min(1, number_of_clusters_that_has_this_feature / original total feature freq) 

Threshold: 0.74. Original clusters: 27. After merging: 27. 
Top-10 clusters covers 98.0% of data, top-25 100.0%, top-50 100.0% 

Cluster id: 25; Frequency: 81056; Present in 41 turns.
 Mean turn: 6.6 +- 5.03 
Top 3 freq turns: 
 - Turn 3 freq: 7601, dist: 0.09
 - Turn 1 freq: 7095, dist: 0.09
 - Turn 5 freq: 6952, dist: 0.09
-------------
0_dim_Task comm_func_Statement: 1.0. Orig freq: 85923
Samples: 
- 
-- S0: Yes, there are 11 museums in the centre of town. What other information would you like? 
- 
-- S0: Warkworth does offer internet. 
- 
-- S0: OK, you're booked at the Holiday Inn, reference# W1TLHAWK. Can I help you with anything else today? 
- 
-- S0: Yes please. I need it for one person on Saturday. A 3 night stay. 
- 
-- S0: I have TR6416 that will get you to Cambridge by 07:52. Do you need any tickets to