In [1]:
import json
import random
from collections import defaultdict

In [2]:
# Seed the stdlib RNG so the random.sample calls further down are repeatable
# when the notebook is executed top to bottom.
random.seed(42)

In [3]:
# Path of the annotated dialog dump to analyse.
# Switch datasets by changing which of the lines below is left uncommented.
# input_json = 'data/convai1.spacy.dialogact.discourse.dialogtagger.3009.json'
# input_json = 'data/convai2.spacy.dialogact.discourse.dialogtagger.0110.json'
input_json = 'data/multi-woz2.spacy.dialogact.discourse.dialogtagger.0110.json'

In [4]:
# Load the whole annotated corpus into memory.
# The encoding is pinned to UTF-8 (the standard JSON encoding) so that reading
# the file does not depend on the platform's locale default.
with open(input_json, 'r', encoding='utf-8') as f:
    dialogs = json.load(f)

In [5]:
def get_features(utterance, include_pos=False):
    """Build the list of categorical features describing one utterance.

    Args:
        utterance: dict for a single dialog turn. The keys read here are
            'predictions' (sequence whose items carry a label at index 0),
            'SVM_predictions' (dicts with 'dimension' and
            'communicative_function'), 'single_discourse_type' and, optionally,
            'pair_discourse_type'. 'features_dict' (dicts with a 'pos' key) is
            read only when ``include_pos`` is True.
        include_pos: when True, additionally append the "|"-joined sequence of
            per-token 'pos' values as one feature. Off by default, which
            matches the previous behaviour where this feature was disabled.

    Returns:
        list of str: the kept prediction labels, the dialog-tagger features,
        the single discourse type, optionally the POS feature, and finally the
        pair discourse type when it is set and is not 'PAIR_NONE'.
    """
    # Only labels containing '_dci' are kept; every other prediction label is dropped.
    dialog_act_features = [p[0] for p in utterance['predictions'] if '_dci' in p[0]]
    dialog_tagger_features = [
        f"dim_{p['dimension']} comm_func_{p['communicative_function']}"
        for p in utterance['SVM_predictions']
    ]

    features = dialog_act_features + dialog_tagger_features + [utterance['single_discourse_type']]

    # The POS string is built lazily: it was previously computed for every
    # utterance even though it was never used, and it made 'features_dict'
    # a mandatory key for no reason.
    if include_pos:
        features.append("|".join(feats['pos'] for feats in utterance['features_dict']))

    pair_discourse_type = utterance.get('pair_discourse_type')
    if pair_discourse_type and pair_discourse_type != 'PAIR_NONE':
        features.append(pair_discourse_type)
    return features

# def get_features(utterance):
#     dialog_act_features = [p[0] for p in utterance['predictions']]
#     pos_features = "|".join([feats['pos'] for feats in utterance['features_dict']])
#     single_discourse_type = utterance['single_discourse_type']
#     pair_discourse_type = utterance.get('pair_discourse_type')
    
#     dialog_tagger_features = [f"dim_{p['dimension']} comm_func_{p['communicative_function']}" for p in utterance['SVM_predictions']]

#     features = dialog_act_features  + dialog_tagger_features + [single_discourse_type] + [pos_features]
#     if pair_discourse_type:
#         features += [pair_discourse_type]
#     return features

In [6]:
# Attach a hashable, order-independent feature signature to every utterance;
# utterances with the same signature end up in the same cluster later on.
for dialog in dialogs.values():
    for row in dialog['thread']:
        row['final_features'] = frozenset(get_features(row))

In [7]:
def collect_clusters(dialogs):
    """Index the distinct feature sets of the corpus and label every utterance.

    Args:
        dialogs: dict of dialogs; each value holds a 'thread' list of rows, and
            every row must already carry a hashable 'final_features' value.

    Returns:
        dict mapping each distinct 'final_features' value to an integer id.
        Ids are assigned in the order obtained by sorting the feature sets on
        the string form of their sorted elements, which makes the numbering
        deterministic. As a side effect each row gets a 'cluster_id' entry.
    """
    distinct_feature_sets = {
        row['final_features']
        for dialog in dialogs.values()
        for row in dialog['thread']
    }

    ordered_feature_sets = sorted(distinct_feature_sets, key=lambda fs: str(sorted(fs)))
    clusters_index = {fs: position for position, fs in enumerate(ordered_feature_sets)}

    # Second pass: write the freshly assigned id back onto every utterance.
    for dialog in dialogs.values():
        for row in dialog['thread']:
            row['cluster_id'] = clusters_index[row['final_features']]
    return clusters_index

In [8]:
def jaccard_similarity(s1, s2):
    """Return the Jaccard similarity |s1 & s2| / |s1 | s2|, rounded to 5 decimals.

    Args:
        s1, s2: set-like objects supporting ``intersection`` and ``union``.

    Returns:
        float in [0, 1]. Two empty sets are treated as identical and give 1.0
        (previously this case raised ZeroDivisionError).
    """
    union_size = len(s1.union(s2))
    if union_size == 0:
        # Both sets are empty: the ratio is undefined, so use the common
        # convention that identical sets have similarity 1.
        return 1.0
    # The ratio of two sizes with intersection <= union is always within
    # [0, 1], so no clamping is needed after rounding.
    return round(len(s1.intersection(s2)) / union_size, 5)

In [9]:
# Index of the unique (not yet merged) clusters: feature set -> cluster id.
# The call also stamps a 'cluster_id' on every utterance in `dialogs`.
uno_clusters = collect_clusters(dialogs)

In [10]:
def get_sims_info(clusters, for_hist=False):
    """Find, for each cluster, the most similar cluster with a higher id.

    Args:
        clusters: dict mapping a feature set to its unique integer id. The ids
            are looked up as 0..len(clusters)-1, so they are expected to be
            contiguous and zero-based.
        for_hist: when False (default) the result is a greedy selection of
            disjoint pairs: a match is kept only if neither of its members
            already belongs to a previously kept pair. When True the
            disjointness filter is switched off and every match is reported.

    Returns:
        list of ``{'source_ind', 'target_ind', 'sim'}`` dicts ordered by
        ascending 'source_ind'. The last cluster has no higher-indexed partner
        and therefore never appears as a source; an empty or single-element
        index yields an empty list.
    """
    used_inds = set()
    sims_info = []
    assert len(clusters.values()) == len(set(clusters.values()))
    reverse_index = {i: c for c, i in clusters.items()}
    n_clusters = len(clusters)

    # Stop one short of the end: for the last cluster there is nothing left to
    # compare with, and emitting a placeholder match (target 0, sim 0) would
    # create a bogus pair -- a self-pair when only one cluster exists.
    for i in range(n_clusters - 1):
        max_sim = 0
        max_ind = i + 1
        for j in range(i + 1, n_clusters):
            cur_sim = jaccard_similarity(reverse_index[i], reverse_index[j])
            # '<=' means that among equally similar candidates the one with
            # the highest index wins.
            if max_sim <= cur_sim:
                max_sim = cur_sim
                max_ind = j

        if for_hist:
            sims_info.append({'source_ind': i, 'target_ind': max_ind, 'sim': max_sim})
        elif i not in used_inds and max_ind not in used_inds:
            sims_info.append({'source_ind': i, 'target_ind': max_ind, 'sim': max_sim})
            used_inds.add(i)
            used_inds.add(max_ind)
    return sims_info

In [11]:
# Greedy pairing of the unique clusters with their most similar higher-indexed cluster.
sims_info = get_sims_info(uno_clusters)
source_inds = [e['source_ind'] for e in sims_info]
target_inds = [e['target_ind'] for e in sims_info]
# Presumably a quick fingerprint of the pairing, to compare between runs.
sum(source_inds), sum(target_inds)

(315055, 801483)

In [12]:
import plotly.express as px

# Similarity histogram between the unique clusters.
# for_hist=True switches off the one-pair-per-cluster filtering, so the
# reported matches are not limited to disjoint pairs.
sims_info_for_hist = get_sims_info(uno_clusters, True)
fig = px.histogram(sims_info_for_hist, x="sim", nbins=len(sims_info_for_hist))
# NOTE(review): auto_open=True presumably launches a browser on the written file -- confirm this is wanted on headless runs.
fig.write_html('sims_hist.html', auto_open=True)

In [13]:
# Minimum Jaccard similarity at which two clusters get merged (passed as
# `threshold` to filter_sims_and_merge_clusters).
# NOTE(review): presumably picked by eye from sims_hist.html for this dataset -- confirm.
min_thresh = 0.74

In [14]:
def filter_sims_and_merge_clusters(sims_info, clusters, threshold, dialogs):
    """Repeatedly merge cluster pairs whose similarity reaches ``threshold``.

    One round works as follows: every pair in ``sims_info`` with
    ``sim >= threshold`` is replaced by the union of its two feature sets;
    all clusters that took part in no such pair are carried over unchanged;
    ids are renumbered from zero in the order the resulting clusters are first
    seen (merged pairs first, then the untouched clusters by ascending old id).
    If at least one merge happened, the 'cluster_id' of every utterance is
    remapped and another round is run on freshly computed similarities.

    Args:
        sims_info: list of ``{'source_ind', 'target_ind', 'sim'}`` dicts that
            refer to the ids in ``clusters``.
        clusters: dict mapping a feature set to its unique integer id.
        threshold: minimum similarity for a pair to be merged.
        dialogs: dict of dialogs whose 'thread' rows carry a 'cluster_id';
            the rows are updated in place.

    Returns:
        dict mapping each final feature set to its integer id.
    """
    while True:
        assert len(clusters.values()) == len(set(clusters.values()))
        features_by_id = {
            idx: feats for feats, idx in sorted(clusters.items(), key=lambda kv: kv[1])
        }

        merged_clusters = {}   # feature set -> new id for this round
        old_to_new = {}        # old id -> new id
        paired = set()         # old ids consumed by a merge
        merged_any = False

        for pair in sims_info:
            if pair['sim'] >= threshold:
                src, tgt = pair['source_ind'], pair['target_ind']
                combined = frozenset().union(features_by_id[src], features_by_id[tgt])
                if combined not in merged_clusters:
                    # The next free id always equals the number of clusters so far.
                    merged_clusters[combined] = len(merged_clusters)
                    merged_any = True
                paired.update((src, tgt))
                old_to_new[src] = merged_clusters[combined]
                old_to_new[tgt] = merged_clusters[combined]

        # Carry over everything that was not merged. A feature set may already
        # be present if some union happened to equal it, in which case the
        # existing id is reused.
        for idx, feats in features_by_id.items():
            if idx in paired:
                continue
            if feats not in merged_clusters:
                merged_clusters[feats] = len(merged_clusters)
            old_to_new[idx] = merged_clusters[feats]

        if not merged_any:
            return merged_clusters

        for dialog in dialogs.values():
            for row in dialog['thread']:
                row['cluster_id'] = old_to_new[row['cluster_id']]

        # Next round: recompute the pairing on the merged clusters.
        sims_info = get_sims_info(merged_clusters)
        clusters = merged_clusters

In [15]:
# Merge similar clusters until no pair reaches min_thresh; this also rewrites
# the 'cluster_id' of the utterances in `dialogs` in place.
new_clusters = filter_sims_and_merge_clusters(sims_info, uno_clusters, min_thresh, dialogs)

In [16]:
def check_correctness(new_clusters, dialogs):
    """Assert that the utterances use exactly as many cluster ids as exist.

    Args:
        new_clusters: the cluster index (only its size is inspected).
        dialogs: dict of dialogs whose 'thread' rows carry a 'cluster_id'.

    Raises:
        AssertionError: if the number of distinct 'cluster_id' values found in
            the dialogs differs from ``len(new_clusters)``.
    """
    ids_in_use = {
        row['cluster_id']
        for dialog in dialogs.values()
        for row in dialog['thread']
    }
    assert len(ids_in_use) == len(new_clusters)

# Sanity check: the ids left on the utterances must match the merged index in number.
check_correctness(new_clusters, dialogs)

In [17]:
# Cluster counts before and after merging, plus a peek at the first merged clusters.
print(len(uno_clusters), len(new_clusters), list(new_clusters.items())[:3])
# Inverted index (cluster id -> feature set), used by the cells below.
reverse_index = {i: c for c, i in sorted(new_clusters.items(), key=lambda x: x[1])} 

1690 663 [(frozenset({'SINGLE_CATAPHORA', 'PAIR_ANAPHORA', 'SINGLE_S_COORD', 'Information_RequestIntent_dci', 'Information_DeliveryIntent_dci', 'dim_Task comm_func_Statement', 'dim_SocialObligationManagement comm_func_Thanking', 'General_ChatIntent_dci'}), 0), (frozenset({'SINGLE_S_COORD', 'PAIR_ANAPHORA', 'Information_RequestIntent_dci', 'dim_SocialObligationManagement comm_func_Apology', 'dim_Task comm_func_Statement', 'PAIR_CONN', 'Information_DeliveryIntent_dci', 'General_ChatIntent_dci'}), 1), (frozenset({'SINGLE_S_COORD', 'Information_RequestIntent_dci', 'dim_Task comm_func_Directive', 'Information_DeliveryIntent_dci', 'dim_Task comm_func_Statement', 'PAIR_CONN', 'dim_SocialObligationManagement comm_func_Thanking', 'General_ChatIntent_dci'}), 2)]


In [18]:
# Number of utterances assigned to each merged cluster.
cluster_usage_distribution = defaultdict(int)
for dialog in dialogs.values():
    for row in dialog['thread']:
        cluster_usage_distribution[row['cluster_id']] += 1

In [19]:
# Utterance count per merged cluster, with the clusters ordered from the most
# to the least used one (categoryorder 'total descending').
xs = [f"c_{e}" for e in list(cluster_usage_distribution.keys())]
fig = px.histogram(x=xs, y=list(cluster_usage_distribution.values()), nbins=len(cluster_usage_distribution.keys()), labels={'x': 'cluster id'})
fig.update_layout(xaxis={'categoryorder':'total descending'})
fig.write_html('usage_hist.html', auto_open=True)

In [20]:
# Utterance texts grouped by merged cluster id, in corpus order.
cluster_examples = defaultdict(list)
for dialog in dialogs.values():
    for row in dialog['thread']:
        cluster_examples[row['cluster_id']].append(row['text'])

In [21]:
# Report the ten most frequent merged clusters: their feature set and up to
# ten randomly drawn example utterances each.
top_cluster_ids = []
ranked_clusters = sorted(cluster_usage_distribution.items(), key=lambda x: x[1], reverse=True)
for cluster_id, freq in ranked_clusters[:10]:
    print(f"Cluster id: {cluster_id}; Frequency: {freq}")
    print("-------------")
    for e in sorted(reverse_index[cluster_id]):
        print(e)
    print("Samples: ")
    examples = cluster_examples[cluster_id]
    # random.sample raises ValueError when asked for more items than the
    # population holds, so cap the sample size for clusters with < 10 utterances.
    for s in random.sample(examples, min(10, len(examples))):
        print(f"- {s}")
    print("-------------")
    print()
    top_cluster_ids.append(cluster_id)

Cluster id: 85; Frequency: 14125
-------------
General_ChatIntent_dci
Information_DeliveryIntent_dci
Information_RequestIntent_dci
SINGLE_S_COORD
dim_Task comm_func_Statement
Samples: 
- I found 14 for you. Do you have a specific area you'd like?
- Are you sure that there are no hotels on the west side of town? With or without internet?
- No, it's a 3 star guesthouse. Do you want to look for something different?
- The yippee noddle bar is in the moderate price range. Would you like to book a table? 
- Can you find me a guesthouse that includes wi-fi?
- There are 11 possible choices. Do you have any additional preferences?
- The booking was successful. Your table will be reserved for 15 minutes.
Your reference number is : 3MC54ZYM.  Is there anything else I can do for you? 
- Are there any moderately priced hotels that don't have free parking, but have wifi?
- Yes, I need the guesthouse to have free internet and parking. Can you get me their phone number?
- I need to know the following 

In [22]:
# Scratch check of random.sample; note that it draws from the seeded global RNG.
random.sample([1,2], 1)

[1]

In [23]:
# Inspect the feature sets of a few merged clusters by id.
reverse_index[0], reverse_index[3], reverse_index[5], reverse_index[19]

(frozenset({'General_ChatIntent_dci',
            'Information_DeliveryIntent_dci',
            'Information_RequestIntent_dci',
            'PAIR_ANAPHORA',
            'SINGLE_CATAPHORA',
            'SINGLE_S_COORD',
            'dim_SocialObligationManagement comm_func_Thanking',
            'dim_Task comm_func_Statement'}),
 frozenset({'General_ChatIntent_dci',
            'Information_DeliveryIntent_dci',
            'Information_RequestIntent_dci',
            'PAIR_ANAPHORA',
            'SINGLE_APPOSITION',
            'dim_SocialObligationManagement comm_func_Apology',
            'dim_Task comm_func_Directive',
            'dim_Task comm_func_Statement'}),
 frozenset({'General_ChatIntent_dci',
            'Information_DeliveryIntent_dci',
            'Information_RequestIntent_dci',
            'PAIR_ANAPHORA',
            'SINGLE_CONN_START',
            'dim_Task comm_func_Statement'}),
 frozenset({'General_ChatIntent_dci',
            'Information_RequestIntent_dci',
    

0. Выбросил топики. Изначально нашлось 248 уникальных кластеров на датасете ConvAI 1.
1. Пороги см. в sims_hist.html — я делал отсечку по 0.66
2. После объединения кластеров по порогу их осталось 109
3. Ниже распределение объединённых кластеров, см. usage_hist.html. Видно, что есть длинный хвост. 
4. Вот примеры наиболее частотных кластеров 
```
0: (frozenset({'General_ChatIntent_dci',
            'Information_DeliveryIntent_dci',
            'Information_RequestIntent_dci',
            'PAIR_ANAPHORA',
            'SINGLE_S_COORD',
            'dim_Task comm_func_Commissive',
            'dim_Task comm_func_Statement'}),
3: frozenset({'General_ChatIntent_dci',
            'Information_RequestIntent_dci',
            'PAIR_ANAPHORA',
            'SINGLE_APPOSITION',
            'SINGLE_VP_COORD',
            'dim_Task comm_func_Statement'}),
5: frozenset({'ClarificationIntent_dci',
            'General_ChatIntent_dci',
            'PAIR_CONN',
            'SINGLE_S_COORD',
            'dim_Task comm_func_Statement'}),
19: frozenset({'General_ChatIntent_dci',
            'InteractiveIntent_dci',
            'SINGLE_APPOSITION',
            'dim_Task comm_func_Statement'}))
```            
            
Сейчас главный вопрос: как интерпретировать получившиеся кластеры? 

Предлагаю созвониться, а то не очень понятно, что делать дальше.