In [1]:
import json
from collections import defaultdict

In [2]:
input_json = 'data/convai1.spacy.dialogact.discourse.dialogtagger.3009.json'

In [3]:
with open(input_json, 'r') as f:
    dialogs = json.load(f)

In [4]:
def get_features(utterance):
    dialog_act_features = [p[0] for p in utterance['predictions'] if '_dci' in p[0]]
    pos_features = "|".join([feats['pos'] for feats in utterance['features_dict']])
    single_discourse_type = utterance['single_discourse_type']
    pair_discourse_type = utterance.get('pair_discourse_type')
    
    dialog_tagger_features = [f"dim_{p['dimension']} comm_func_{p['communicative_function']}" for p in utterance['SVM_predictions']]

    features = dialog_act_features  + dialog_tagger_features + [single_discourse_type] #+ [pos_features]
    if pair_discourse_type and pair_discourse_type != 'PAIR_NONE':
        features += [pair_discourse_type]
    return features

# def get_features(utterance):
#     dialog_act_features = [p[0] for p in utterance['predictions']]
#     pos_features = "|".join([feats['pos'] for feats in utterance['features_dict']])
#     single_discourse_type = utterance['single_discourse_type']
#     pair_discourse_type = utterance.get('pair_discourse_type')
    
#     dialog_tagger_features = [f"dim_{p['dimension']} comm_func_{p['communicative_function']}" for p in utterance['SVM_predictions']]

#     features = dialog_act_features  + dialog_tagger_features + [single_discourse_type] + [pos_features]
#     if pair_discourse_type:
#         features += [pair_discourse_type]
#     return features

In [5]:
for dialog_id, dialog in dialogs.items():
    thread = dialog['thread']
    for row in thread:
        features = get_features(row)
        row['final_features'] = frozenset(features)

In [6]:
def collect_clusters(dialogs):
    clusters = set()
    for dialog_id, dialog in dialogs.items():
        thread = dialog['thread']
        for ind, row in enumerate(thread):
            clusters.add(row['final_features'])
    # Enumerate + add indicies to row
    clusters = sorted(list(clusters))                
    clusters_index = {f: i for i, f in enumerate(clusters)}
    for dialog_id, dialog in dialogs.items():
        thread = dialog['thread']
        for ind, row in enumerate(thread):
            row['cluster_id'] = clusters_index[row['final_features']]
    return clusters_index

In [7]:
def jaccard_similarity(s1, s2):
    return len(s1.intersection(s2)) / len(s1.union(s2))

In [8]:
uno_clusters = collect_clusters(dialogs)

In [9]:
list(uno_clusters.items())[-2], len(list(uno_clusters.items()))

((frozenset({'PAIR_ANAPHORA',
             'SINGLE_S_COORD',
             'dim_Task comm_func_Statement'}),
  246),
 248)

In [10]:
def get_sims_info(clusters):
    sims_info = list(range(len(clusters)))
    reverse_index = {i: c for c, i in clusters.items()}            
    for i in range(len(clusters)):
        max_sim = 0
        max_ind = 0
        for j in range(i+1, len(clusters)):        
            cur_sim = jaccard_similarity(reverse_index[i], reverse_index[j])
            if max_sim < cur_sim:
                max_sim = cur_sim
                max_ind = j
        sims_info[i] = {'source_ind': i, 'target_ind': max_ind, 'sim': max_sim}
    return sims_info

In [11]:
sims_info = get_sims_info(uno_clusters)
sims_info[:2]

[{'source_ind': 0, 'target_ind': 36, 'sim': 0.6666666666666666},
 {'source_ind': 1, 'target_ind': 27, 'sim': 0.8}]

In [12]:
import plotly.express as px

# Между уникальными кластерами
fig = px.histogram(sims_info, x="sim")
fig.write_html('first_figure.html', auto_open=True)

In [13]:
min_thresh = 0.5

In [14]:
def filter_sims_and_merge_clusters(sims_info, clusters, threshold, dialogs):
    new_clusters = {}
    cluster_id = 0
    used_inds = []
    is_merged = False
    cluster_id_mapping = {}
    reverse_index = {i: c for c, i in clusters.items()}
    for e in sims_info:
        if e['sim'] >= threshold:
            new_cluster = reverse_index[e['source_ind']].union(reverse_index[e['target_ind']])
            if new_cluster not in new_clusters:                
                new_clusters[new_cluster] = cluster_id
                cluster_id += 1
            used_inds += [e['source_ind'], e['target_ind']]
            is_merged = True
            cluster_id_mapping[e['source_ind']] = new_clusters[new_cluster]
            cluster_id_mapping[e['target_ind']] = new_clusters[new_cluster]
    used_inds = set(used_inds)
    for ind in reverse_index.keys():
        if ind not in used_inds:
            if reverse_index[ind] not in new_clusters:                
                new_clusters[reverse_index[ind]] = cluster_id
                cluster_id += 1
            cluster_id_mapping[ind] = new_clusters[reverse_index[ind]]    
    if is_merged is True:
        for dialog_id, dialog in dialogs.items():
            thread = dialog['thread']
            for ind, row in enumerate(thread):
                row['cluster_id'] = cluster_id_mapping[row['cluster_id']]
        return filter_sims_and_merge_clusters(get_sims_info(new_clusters), new_clusters, threshold, dialogs)
    return list(set(new_clusters))

In [15]:
new_clusters = filter_sims_and_merge_clusters(sims_info, uno_clusters, min_thresh, dialogs)

In [16]:
def check_correctness(new_clusters, dialogs):
    cluster_ids = set()
    for dialog_id, dialog in dialogs.items():
        thread = dialog['thread']
        for ind, row in enumerate(thread):
            cluster_ids.add(row['cluster_id'])
    assert len(cluster_ids) == len(new_clusters)

check_correctness(new_clusters, dialogs)

In [17]:
len(uno_clusters), len(new_clusters), new_clusters[:10]

(248,
 12,
 [frozenset({'Information_RequestIntent_dci',
             'PAIR_ANAPHORA',
             'SINGLE_CONN_START',
             'SINGLE_S_COORD',
             'User_InstructionIntent_dci',
             'dim_Task comm_func_Statement'}),
  frozenset({'Opinion_ExpressionIntent_dci',
             'PAIR_ANAPHORA',
             'SINGLE_APPOSITION',
             'dim_Task comm_func_SetQ'}),
  frozenset({'InteractiveIntent_dci',
             'SINGLE_VP_COORD',
             'dim_Task comm_func_Statement'}),
  frozenset({'ClarificationIntent_dci',
             'General_ChatIntent_dci',
             'Information_DeliveryIntent_dci',
             'Information_RequestIntent_dci',
             'InteractiveIntent_dci',
             'Opinion_ExpressionIntent_dci',
             'Opinion_RequestIntent_dci',
             'PAIR_ANAPHORA',
             'PAIR_CONN',
             'PAIR_CONN_ANAPHORA',
             'SINGLE_APPOSITION',
             'SINGLE_CATAPHORA',
             'SINGLE_CONN_INNER',
 