In [1]:
import json
import random
from collections import defaultdict

import plotly.express as px

In [2]:
# Fix Python's global RNG so the random.sample(...) calls later in the notebook are reproducible.
random.seed(a=42)

In [3]:
# Dialog corpus to analyse; other prepared corpora are kept below for quick switching.
input_json = 'data/convai1.spacy.dialogact.discourse.dialogtagger.3009.json'
# input_json = 'data/convai2.spacy.dialogact.discourse.dialogtagger.0110.json'
# input_json = 'data/multi-woz2.spacy.dialogact.discourse.dialogtagger.0110.json'

# Window size: number of consecutive utterances whose features are merged into one cluster candidate.
n = 4

In [4]:
# JSON text is UTF-8 by specification; don't depend on the platform's default locale encoding
# (the dialogs may contain non-ASCII text).
with open(input_json, 'r', encoding='utf-8') as f:
    dialogs = json.load(f)

In [5]:
def get_features(utterance, include_pos=False):
    """Collect the categorical features describing one utterance.

    Features, in order:
      - dialog-act labels: the first element of each entry of
        utterance['predictions'] whose label contains '_dci';
      - one "dim_<dimension> comm_func_<communicative_function>" string per
        entry of utterance['SVM_predictions'];
      - utterance['single_discourse_type'];
      - only when include_pos is True: the '|'-joined 'pos' values of
        utterance['features_dict'] as a single feature;
      - utterance['pair_discourse_type'], when present, truthy and not 'PAIR_NONE'.

    Args:
        utterance: one row of a dialog thread (dict).
        include_pos: add the POS signature feature (off by default, matching
            the previous behaviour where it was computed but not used).

    Returns:
        A new list of features.
    """
    dialog_act_features = [p[0] for p in utterance['predictions'] if '_dci' in p[0]]
    dialog_tagger_features = [
        f"dim_{p['dimension']} comm_func_{p['communicative_function']}"
        for p in utterance['SVM_predictions']
    ]
    features = dialog_act_features + dialog_tagger_features + [utterance['single_discourse_type']]
    if include_pos:
        # Only touch 'features_dict' when the POS feature is actually requested.
        features.append("|".join(feats['pos'] for feats in utterance['features_dict']))
    pair_discourse_type = utterance.get('pair_discourse_type')
    # 'PAIR_NONE' carries no information, so it is not used as a feature.
    if pair_discourse_type and pair_discourse_type != 'PAIR_NONE':
        features.append(pair_discourse_type)
    return features

def get_thread_key(n):
    """Return the dialog key under which the n-utterance window rows are stored.

    The key is always 'thread<n>' (so 'thread1' for n == 1), which is distinct
    from the per-utterance list kept under 'thread'.
    """
    return f'thread{n}'

# def get_thread_key(n):
#     if n == 1:
#         key = 'thread'
#     else:
#         key = f'thread{n}'
#     return key

# def get_features(utterance):
#     dialog_act_features = [p[0] for p in utterance['predictions']]
#     pos_features = "|".join([feats['pos'] for feats in utterance['features_dict']])
#     single_discourse_type = utterance['single_discourse_type']
#     pair_discourse_type = utterance.get('pair_discourse_type')
    
#     dialog_tagger_features = [f"dim_{p['dimension']} comm_func_{p['communicative_function']}" for p in utterance['SVM_predictions']]

#     features = dialog_act_features  + dialog_tagger_features + [single_discourse_type] + [pos_features]
#     if pair_discourse_type:
#         features += [pair_discourse_type]
#     return features

In [6]:
# Attach the frozen feature set of every utterance to its row; clustering works on these sets.
for dialog in dialogs.values():
    for row in dialog['thread']:
        row['final_features'] = frozenset(get_features(row))

In [7]:
def collect_clusters(dialogs):
    """Index the distinct per-utterance feature sets and tag every row with its index.

    Every unique row['final_features'] across all dialog['thread'] lists becomes
    a cluster. Clusters are numbered in order of str(sorted(features)), so the
    numbering is deterministic. Each row gets a 'cluster_id' (in place).

    Returns:
        dict mapping feature frozenset -> cluster id.
    """
    unique_features = {
        row['final_features']
        for dialog in dialogs.values()
        for row in dialog['thread']
    }
    ordered = sorted(unique_features, key=lambda feats: str(sorted(feats)))
    clusters_index = {feats: idx for idx, feats in enumerate(ordered)}
    for dialog in dialogs.values():
        for row in dialog['thread']:
            row['cluster_id'] = clusters_index[row['final_features']]
    return clusters_index

In [8]:
def collect_n_cluster(dialogs, n):
    """Build sliding-window clusters over n consecutive utterances of each dialog.

    For every position ind >= n - 1 in dialog['thread'], the window
    thread[ind - n + 1 : ind + 1] becomes a new row holding:
      - 'final_features': the union of the window rows' 'final_features';
      - 'text0' .. 'text{n-1}': the window rows' 'text', oldest first.
    The new rows are stored in dialog[get_thread_key(n)], only for dialogs with
    at least one full window. Unique unions are numbered in order of
    str(sorted(features)) and each new row gets its 'cluster_id' (in place).

    Args:
        dialogs: dict of dialogs, each with a 'thread' list of rows carrying
            'final_features' and 'text'.
        n: window size; must be >= 1.

    Returns:
        dict mapping feature frozenset -> cluster id.

    Raises:
        ValueError: if n < 1 (an empty window has no features to cluster).
        AssertionError: if no dialog has at least n utterances.
    """
    if n < 1:
        raise ValueError(f"window size n must be >= 1, got {n}")
    clusters = set()
    thread_key = get_thread_key(n)
    for dialog in dialogs.values():
        thread = dialog['thread']
        n_thread = []
        # `end` is the index of the newest utterance in the window.
        for end in range(n - 1, len(thread)):
            window = thread[end - n + 1:end + 1]
            new_cluster = frozenset().union(*(row['final_features'] for row in window))
            new_row = {'final_features': new_cluster}
            for j, row in enumerate(window):
                new_row[f'text{j}'] = row['text']
            clusters.add(new_cluster)
            n_thread.append(new_row)
        if n_thread:
            dialog[thread_key] = n_thread
    ordered = sorted(clusters, key=lambda x: str(sorted(x)))
    clusters_index = {f: i for i, f in enumerate(ordered)}
    at_least_one_with_thread_key = False
    for dialog in dialogs.values():
        if thread_key not in dialog:
            continue
        at_least_one_with_thread_key = True
        for row in dialog[thread_key]:
            row['cluster_id'] = clusters_index[row['final_features']]
    assert at_least_one_with_thread_key, f"no dialog has at least {n} utterances"
    return clusters_index

In [9]:
def jaccard_similarity(s1, s2):
    """Jaccard similarity |s1 & s2| / |s1 | s2| of two sets, rounded to 5 decimals.

    The ratio is always within [0, 1], so no clamping is needed. Two empty sets
    are treated as identical and get 1.0 instead of raising ZeroDivisionError.
    """
    union_size = len(s1.union(s2))
    if union_size == 0:
        return 1.0
    return round(len(s1.intersection(s2)) / union_size, 5)

In [10]:
# Builds the dialog[get_thread_key(n)] window rows (for n == 1 too) and their cluster index.
clusters = collect_n_cluster(dialogs=dialogs, n=n)

In [11]:
# Per-utterance clustering; also writes row['cluster_id'] on every dialog['thread'] row.
uno_clusters = collect_clusters(dialogs=dialogs)

In [12]:
# n == 1 clusters single utterances; larger n clusters sliding windows of n utterances.
clusters = collect_clusters(dialogs) if n == 1 else collect_n_cluster(dialogs, n)

In [13]:
# Number of distinct feature sets before merging.
len(clusters.keys())

735

In [14]:
def get_sims_info(clusters, for_hist=False):
    """Pair every cluster with its most Jaccard-similar cluster of higher id.

    `clusters` maps feature frozenset -> integer id. Ids must be unique and are
    expected to be 0..len(clusters)-1, because they are used as indices.

    For each id i, the candidate partner is the j > i with the highest
    jaccard_similarity; ties go to the largest such j. The last id has no
    j > i and therefore produces no entry (no fake pair with target 0).

    Args:
        clusters: dict frozenset -> cluster id.
        for_hist: if False, pair greedily one-to-one: a candidate pair is kept
            only if neither endpoint was used by an earlier kept pair. If True,
            keep every candidate pair (for plotting the similarity distribution).

    Returns:
        list of {'source_ind': i, 'target_ind': j, 'sim': similarity} dicts.
    """
    assert len(clusters.values()) == len(set(clusters.values()))
    reverse_index = {i: c for c, i in sorted(clusters.items(), key=lambda x: x[1])}
    n_clusters = len(clusters)
    used_inds = set()
    sims_info = []
    # The last index has no partner j > i, so it is never a source.
    for i in range(n_clusters - 1):
        max_sim = 0
        max_ind = i + 1
        for j in range(i + 1, n_clusters):
            cur_sim = jaccard_similarity(reverse_index[i], reverse_index[j])
            # `<=` keeps the last j among equally similar candidates.
            if max_sim <= cur_sim:
                max_sim = cur_sim
                max_ind = j
        if for_hist:
            sims_info.append({'source_ind': i, 'target_ind': max_ind, 'sim': max_sim})
        elif i not in used_inds and max_ind not in used_inds:
            sims_info.append({'source_ind': i, 'target_ind': max_ind, 'sim': max_sim})
            used_inds.update((i, max_ind))
    return sims_info

In [15]:
sims_info = get_sims_info(clusters)
# Quick fingerprint of the greedy pairing: sums of the source and target indices.
source_inds, target_inds = [], []
for pair in sims_info:
    source_inds.append(pair['source_ind'])
    target_inds.append(pair['target_ind'])
sum(source_inds), sum(target_inds)

(83560, 130314)

In [16]:
# Distribution of best-match similarities between unique clusters.
sims_info_for_hist = get_sims_info(clusters, for_hist=True)
fig = px.histogram(sims_info_for_hist, x="sim", nbins=len(sims_info_for_hist))
fig.write_html('sims_hist.html', auto_open=True)

In [17]:
# Pairs of clusters whose Jaccard similarity is >= this value are merged.
min_thresh: float = 0.74

In [18]:
def _merge_similar_pairs(sims_info, reverse_index, threshold):
    """Run one merge round; return (new_clusters, cluster_id_mapping, is_merged).

    new_clusters maps frozenset -> new id (0, 1, ... in creation order),
    cluster_id_mapping maps every old id in reverse_index to its new id, and
    is_merged tells whether two distinct clusters reached the threshold.
    """
    new_clusters = {}
    cluster_id_mapping = {}
    used_inds = set()
    is_merged = False
    for e in sims_info:
        source_ind, target_ind = e['source_ind'], e['target_ind']
        # A self-pair merges nothing; counting it as a merge could loop forever.
        if e['sim'] < threshold or source_ind == target_ind:
            continue
        new_cluster = frozenset().union(reverse_index[source_ind], reverse_index[target_ind])
        if new_cluster not in new_clusters:
            new_clusters[new_cluster] = len(new_clusters)
        is_merged = True
        used_inds.update((source_ind, target_ind))
        cluster_id_mapping[source_ind] = new_clusters[new_cluster]
        cluster_id_mapping[target_ind] = new_clusters[new_cluster]
    # Unmerged clusters keep their feature set; one equal to a merged union shares its id.
    for ind, cluster in reverse_index.items():
        if ind in used_inds:
            continue
        if cluster not in new_clusters:
            new_clusters[cluster] = len(new_clusters)
        cluster_id_mapping[ind] = new_clusters[cluster]
    return new_clusters, cluster_id_mapping, is_merged


def _remap_thread_cluster_ids(dialogs, key, cluster_id_mapping):
    """Rewrite row['cluster_id'] of every dialog[key] row through cluster_id_mapping (in place)."""
    at_least_one_with_thread_key = False
    for dialog in dialogs.values():
        if key not in dialog:
            continue
        at_least_one_with_thread_key = True
        for row in dialog[key]:
            row['cluster_id'] = cluster_id_mapping[row['cluster_id']]
    assert at_least_one_with_thread_key


def filter_sims_and_merge_clusters(sims_info, clusters, threshold, dialogs, n):
    """Repeatedly merge cluster pairs whose similarity reaches `threshold`.

    Each round unions the feature sets of the pairs in sims_info with
    sim >= threshold, renumbers all clusters, and rewrites the 'cluster_id' of
    the rows in dialog[get_thread_key(n)] in place. The next round's pairs come
    from get_sims_info on the merged clusters. Rounds stop once nothing is
    merged. Each round strictly reduces the cluster count, so the loop
    terminates; it is iterative to avoid hitting Python's recursion limit.

    Args:
        sims_info: pairs from get_sims_info(clusters) for the first round.
        clusters: dict frozenset -> cluster id (ids are used as indices).
        threshold: minimum similarity for a pair to be merged.
        dialogs: dialogs whose window rows carry the current cluster ids.
        n: window size selecting dialog[get_thread_key(n)].

    Returns:
        dict frozenset -> cluster id after the last round.
    """
    key = get_thread_key(n)
    while True:
        assert len(clusters.values()) == len(set(clusters.values()))
        reverse_index = {i: c for c, i in sorted(clusters.items(), key=lambda x: x[1])}
        new_clusters, cluster_id_mapping, is_merged = _merge_similar_pairs(
            sims_info, reverse_index, threshold)
        if not is_merged:
            # Nothing merged: ids were reassigned in the same order, so rows need no update.
            return new_clusters
        _remap_thread_cluster_ids(dialogs, key, cluster_id_mapping)
        clusters = new_clusters
        sims_info = get_sims_info(clusters)

In [19]:
# NOTE: this rewrites row['cluster_id'] in `dialogs` in place using the ids from `clusters`;
# re-run the clustering cells before running this cell again, otherwise the mapping is applied twice.
new_clusters = filter_sims_and_merge_clusters(
    sims_info=sims_info,
    clusters=clusters,
    threshold=min_thresh,
    dialogs=dialogs,
    n=n,
)

In [20]:
def check_correctness(new_clusters, dialogs, n=1):
    """Assert that the rows of dialog[get_thread_key(n)] use exactly len(new_clusters) distinct ids.

    Dialogs without that key are skipped.

    Raises:
        AssertionError: with both counts in the message when they differ.
    """
    key = get_thread_key(n)
    cluster_ids = {
        row['cluster_id']
        for dialog in dialogs.values()
        if key in dialog
        for row in dialog[key]
    }
    # A real message string: `assert ..., print(...)` would only print and leave the message as None.
    assert len(cluster_ids) == len(new_clusters), (
        f"{len(cluster_ids)} distinct cluster ids in dialogs vs {len(new_clusters)} clusters"
    )

check_correctness(new_clusters, dialogs, n)

In [21]:
preview = list(new_clusters.items())[:3]
print(len(clusters), len(new_clusters), preview)
# Merged cluster id -> feature set, used to print cluster contents further down.
reverse_index = {
    cluster_id: feats
    for feats, cluster_id in sorted(new_clusters.items(), key=lambda item: item[1])
}

735 172 [(frozenset({'dim_Task comm_func_Commissive', 'SINGLE_VP_COORD', 'SINGLE_APPOSITION', 'SINGLE_S_COORD', 'Information_DeliveryIntent_dci', 'General_ChatIntent_dci', 'PAIR_ANAPHORA', 'dim_Feedback comm_func_Feedback', 'SINGLE_CATAPHORA', 'dim_Task comm_func_Statement', 'dim_Task comm_func_Directive', 'Information_RequestIntent_dci'}), 0), (frozenset({'dim_Feedback comm_func_Feedback', 'dim_Task comm_func_Statement', 'InteractiveIntent_dci', 'PAIR_CONN', 'SINGLE_VP_COORD', 'dim_Task comm_func_Directive', 'SINGLE_APPOSITION', 'SINGLE_S_COORD', 'General_ChatIntent_dci'}), 1), (frozenset({'SINGLE_VP_COORD', 'SINGLE_APPOSITION', 'SINGLE_S_COORD', 'dim_SocialObligationManagement comm_func_Salutation', 'Information_DeliveryIntent_dci', 'General_ChatIntent_dci', 'dim_Feedback comm_func_Feedback', 'SINGLE_CATAPHORA', 'dim_Task comm_func_Statement', 'InteractiveIntent_dci', 'dim_Task comm_func_Directive', 'Information_RequestIntent_dci'}), 2)]


In [22]:
def calc_cluster_cluster_usage_distribution(dialogs, n):
    """Count the rows per cluster id in dialog[get_thread_key(n)].

    Dialogs without that key are skipped.

    Returns:
        defaultdict(int) mapping cluster_id -> number of rows with that id.
    """
    thread_key = get_thread_key(n)
    usage = defaultdict(int)
    threads = (dialog[thread_key] for dialog in dialogs.values() if thread_key in dialog)
    for thread in threads:
        for row in thread:
            usage[row['cluster_id']] += 1
    return usage

cluster_usage_distribution = calc_cluster_cluster_usage_distribution(dialogs, n)

In [23]:
# Bar per cluster id, sorted by total usage.
usage_items = list(cluster_usage_distribution.items())
xs = [f"c_{cluster_id}" for cluster_id, _ in usage_items]
usage_counts = [count for _, count in usage_items]
fig = px.histogram(x=xs, y=usage_counts, nbins=len(usage_items), labels={'x': 'cluster id'})
fig.update_layout(xaxis={'categoryorder': 'total descending'})
fig.write_html('usage_hist.html', auto_open=True)

In [24]:
def get_cluster_examples(dialogs, n):
    """Group the window texts of dialog[get_thread_key(n)] rows by cluster id.

    Each example is the list [row['text0'], ..., row['text{n-1}']]. Dialogs
    without that key are skipped.

    Returns:
        defaultdict(list) mapping cluster_id -> list of examples.
    """
    thread_key = get_thread_key(n)
    examples_by_cluster = defaultdict(list)
    for dialog in dialogs.values():
        if thread_key not in dialog:
            continue
        for row in dialog[thread_key]:
            window_texts = [row[f'text{i}'] for i in range(n)]
            examples_by_cluster[row['cluster_id']].append(window_texts)
    return examples_by_cluster

cluster_examples = get_cluster_examples(dialogs, n)

In [25]:
# How many of the most frequent clusters to inspect, and how many random windows to show for each.
TOP_K_CLUSTERS = 10
N_SAMPLES_PER_CLUSTER = 10

top_cluster_ids = []
most_frequent = sorted(cluster_usage_distribution.items(), key=lambda x: x[1], reverse=True)[:TOP_K_CLUSTERS]
for cluster_id, freq in most_frequent:
    print(f"Cluster id: {cluster_id}; Frequency: {freq}")
    print("-------------")
    for e in sorted(reverse_index[cluster_id]):
        print(e)
    print("Samples: ")
    examples = cluster_examples[cluster_id]
    # random.sample raises ValueError when asked for more items than the cluster has.
    for example_sentences in random.sample(examples, min(N_SAMPLES_PER_CLUSTER, len(examples))):
        print("-")
        for i, s in enumerate(example_sentences):
            print(f"-- S{i}: {s}")
    print("-------------")
    print()
    top_cluster_ids.append(cluster_id)

Cluster id: 2; Frequency: 54
-------------
General_ChatIntent_dci
Information_DeliveryIntent_dci
Information_RequestIntent_dci
InteractiveIntent_dci
SINGLE_APPOSITION
SINGLE_CATAPHORA
SINGLE_S_COORD
SINGLE_VP_COORD
dim_Feedback comm_func_Feedback
dim_SocialObligationManagement comm_func_Salutation
dim_Task comm_func_Directive
dim_Task comm_func_Statement
Samples: 
-
-- S0: yo!
-- S1: What's up?
-- S2: cool, how are you?
-- S3: I'm fine.
-
-- S0: What is it to you?
-- S1: What is pythogerous theorem
-- S2: Interesting question.
-- S3: I didn't know guassian method could be used!
-
-- S0: I'm the one who's here.
-- S1: what is copyright?
-- S2: What is your purpose in asking?
-- S3: I want to know
-
-- S0: Do you think i should tone it down?
-- S1: yes please
-- S2: What a joke.
-- S3: are u telling that im a joke
-
-- S0: What a joke.
-- S1: U said that already
-- S2: Are we in a loop
-- S3: Megyn, you're a joke. you're a joke.
-
-- S0: I'm not a man.
-- S1: Huh?
-- S2: What a joke.
-- 


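TODO:
- General_ChatIntent_dci — confidence (frequency within the cluster)
- SINGLE_APPOSITION — confidence
- Turn — distribution, mean
- Попробовать двойки, четверки (try windows of 2 and 4 utterances)

- Описать все признаки (describe all features)

- Попробовать Левенштейна (try Levenshtein distance)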

In [26]:
tuple(reverse_index[cluster_id] for cluster_id in (0, 3, 5, 19))

(frozenset({'General_ChatIntent_dci',
            'Information_DeliveryIntent_dci',
            'Information_RequestIntent_dci',
            'PAIR_ANAPHORA',
            'SINGLE_APPOSITION',
            'SINGLE_CATAPHORA',
            'SINGLE_S_COORD',
            'SINGLE_VP_COORD',
            'dim_Feedback comm_func_Feedback',
            'dim_Task comm_func_Commissive',
            'dim_Task comm_func_Directive',
            'dim_Task comm_func_Statement'}),
 frozenset({'General_ChatIntent_dci',
            'Information_DeliveryIntent_dci',
            'Information_RequestIntent_dci',
            'Opinion_ExpressionIntent_dci',
            'PAIR_CONN',
            'SINGLE_APPOSITION',
            'SINGLE_RELATIVE',
            'SINGLE_S_COORD',
            'SINGLE_VP_COORD',
            'dim_Feedback comm_func_Feedback',
            'dim_Task comm_func_Directive',
            'dim_Task comm_func_SetQ',
            'dim_Task comm_func_Statement'}),
 frozenset({'ClarificationIntent_dc

Имея разметку по реплике, можно построить разметку по N репликам (given per-utterance annotations, annotations over windows of N consecutive utterances can be built; `final_features` is the feature set before clustering):
- thread => `{'1': [{'text': 'blabla', 'cluster_id': 232}, ...], '2': [{'text0': 'blabla', 'text1': 'blabla', 'cluster_id': 211}, ...]}`
- thread => `[{'text': 'blabla', 'cluster_id': 232, 'final_features': frozenset(...)}, ...]`
- thread2 => `[{'text0': 'blabla', 'text1': 'blabla', 'cluster_id': 211, 'final_features': frozenset(...)}, ...]`