In [1]:
import json
import random
import pathlib
import joblib
from collections import defaultdict
import plotly.express as px

In [2]:
def get_output_path(output_key='topicalchat', feature_name='discourse_features', n=2):
    """Return the results folder for one (dataset, n, feature set) run.

    The folder is ``data/results/{output_key}_{n}_{feature_name}/`` and always
    ends with a slash, so callers can append file names directly.

    Args:
        output_key: dataset key, e.g. 'topicalchat' or 'multi-woz2'.
        feature_name: feature-set name, e.g. 'dialog_tagger_features',
            'discourse_features' or 'topic_model_features'.
        n: the window size the run was produced with (part of the folder name).

    Returns:
        The folder path as a string.
    """
    return f'data/results/{output_key}_{n}_{feature_name}/'


def get_data_path(output_key='topicalchat', feature_name='discourse_features', n=2):
    """Return the path of the ``data.joblib`` artifact inside a run's results folder.

    Takes the same arguments as ``get_output_path`` and reuses it, so the two
    paths can never disagree about the folder layout.
    """
    return get_output_path(output_key, feature_name, n) + 'data.joblib'

In [3]:
def load_data(data_path):
    """Deserialize and return the joblib artifact stored at ``data_path``.

    NOTE: joblib files are pickle-based, so loading one can execute arbitrary
    code -- only point this at result files produced by a trusted pipeline.
    """
    return joblib.load(data_path)

def jaccard_similarity(s1, s2):
    """Return the Jaccard similarity len(s1 & s2) / len(s1 | s2), rounded to 5 decimals.

    Two empty inputs are defined to have similarity 0.

    Args:
        s1, s2: set-like objects supporting ``intersection`` and ``union``
            (e.g. ``set`` / ``frozenset``).

    Returns:
        A float in [0, 1].
    """
    if not s1 and not s2:
        return 0.0
    # The intersection is a subset of the union, and the union is non-empty
    # here, so the ratio already lies in [0, 1]; no clamping is needed.
    return round(len(s1.intersection(s2)) / len(s1.union(s2)), 5)

In [4]:
# The two feature-extraction runs to compare: same dataset, same n.
dataset_name = 'topicalchat'
n = 4
feature_name_1 = 'dialog_tagger_features'
feature_name_2 = 'discourse_features'

f1_data_path = get_data_path(output_key=dataset_name, feature_name=feature_name_1, n=n)
f1_output_path = get_output_path(output_key=dataset_name, feature_name=feature_name_1, n=n)
f1_data = load_data(f1_data_path)

f2_data_path = get_data_path(output_key=dataset_name, feature_name=feature_name_2, n=n)
f2_output_path = get_output_path(output_key=dataset_name, feature_name=feature_name_2, n=n)
f2_data = load_data(f2_data_path)

In [5]:
def calc_row_hash(row, n=4):
    """Return a hash identifying a row by its utterance window ``text0``..``text{n-1}``.

    The texts are hashed as a tuple rather than concatenated, so windows whose
    utterance boundaries differ (e.g. ('ab', 'c') vs ('a', 'bc')) no longer
    collide. ``n`` defaults to 4, the window size previously hard-coded here.

    Note: Python salts ``str`` hashes per process (PYTHONHASHSEED), so the
    values are only comparable within one kernel session; do not persist them.
    """
    return hash(tuple(row[f'text{i}'] for i in range(n)))

In [6]:
def _invert_cluster_index(cluster_to_id):
    """Invert a {cluster: cluster_id} mapping into {cluster_id: cluster}.

    Ids are inserted in ascending order; if several clusters share an id, the
    one that sorts last (by id, stable) wins.
    """
    id_to_cluster = {}
    for cluster, cluster_id in sorted(cluster_to_id.items(), key=lambda item: item[1]):
        id_to_cluster[cluster_id] = cluster
    return id_to_cluster


reverse_index_f1 = _invert_cluster_index(f1_data['new_clusters'])
reverse_index_f2 = _invert_cluster_index(f2_data['new_clusters'])

In [7]:
def get_thread_key(n):
    """Return the dialog-dict key under which the clustered thread rows live.

    ``n`` is accepted for call-site compatibility but ignored: the same key,
    ``'thread'``, is returned for every value.
    """
    return 'thread'


def calc_cluster_cluster_usage_distribution(dialogs, n):
    """Count how many thread rows were assigned to each cluster id.

    Args:
        dialogs: mapping of dialog id -> dialog dict. Dialogs that lack the
            key returned by ``get_thread_key(n)`` are skipped.
        n: forwarded to ``get_thread_key``.

    Returns:
        ``defaultdict(int)`` mapping ``row['cluster_id']`` -> number of rows.
    """
    thread_key = get_thread_key(n)
    counts = defaultdict(int)
    threads = (dialog[thread_key] for dialog in dialogs.values() if thread_key in dialog)
    for thread in threads:
        for row in thread:
            counts[row['cluster_id']] += 1
    return counts

def calc_cluster_rows_hashes(dialogs, n):
    """Collect, per cluster id, the set of ``calc_row_hash`` values of its rows.

    Args:
        dialogs: mapping of dialog id -> dialog dict. Dialogs that lack the
            key returned by ``get_thread_key(n)`` are skipped.
        n: forwarded to ``get_thread_key``. NOTE(review): it is not passed on
            to ``calc_row_hash``, which decides the hashed window on its own.

    Returns:
        ``defaultdict(set)`` mapping ``row['cluster_id']`` -> set of row hashes.
    """
    thread_key = get_thread_key(n)
    hashes_by_cluster = defaultdict(set)
    for dialog in dialogs.values():
        if thread_key not in dialog:
            continue
        for row in dialog[thread_key]:
            # Hash first so a malformed row raises before any set is created.
            row_hash = calc_row_hash(row)
            hashes_by_cluster[row['cluster_id']].add(row_hash)
    return hashes_by_cluster

# Per feature set: the row-hash set of every cluster (for Jaccard overlap)
# and the number of rows in every cluster (for ranking clusters by usage).
f1_rows_hashes = calc_cluster_rows_hashes(f1_data['dialog'], n)
f1_usage = calc_cluster_cluster_usage_distribution(f1_data['dialog'], n)
f2_rows_hashes = calc_cluster_rows_hashes(f2_data['dialog'], n)
f2_usage = calc_cluster_cluster_usage_distribution(f2_data['dialog'], n)

In [8]:
top_k = 10


def _top_cluster_ids(usage, k):
    """Return the ``k`` cluster ids with the highest row counts, most used first."""
    ranked = sorted(usage.items(), key=lambda item: item[1], reverse=True)
    return [cluster_id for cluster_id, _count in ranked[:k]]


f1_ids = _top_cluster_ids(f1_usage, top_k)
f2_ids = _top_cluster_ids(f2_usage, top_k)

# sims[i][j] is the Jaccard overlap between cluster f1_ids[i] and cluster f2_ids[j].
sims = []
for f1_id in f1_ids:
    f1_hashes = f1_rows_hashes[f1_id]
    sims.append([jaccard_similarity(f1_hashes, f2_rows_hashes[f2_id]) for f2_id in f2_ids])

In [9]:
# Sanity check: sizes of the two top-k id lists (at most top_k each).
tuple(len(ids) for ids in (f1_ids, f2_ids))

(10, 10)

In [10]:
def get_cluster_examples(dialogs, n):
    """Group every thread row's text window by its cluster id.

    Args:
        dialogs: mapping of dialog id -> dialog dict. Dialogs that lack the
            key returned by ``get_thread_key(n)`` are skipped.
        n: number of texts read per row (``text0`` .. ``text{n-1}``).

    Returns:
        ``defaultdict(list)`` mapping ``row['cluster_id']`` -> list of windows,
        each window being ``[row['text0'], ..., row[f'text{n-1}']]``, in
        dialog and thread order.
    """
    thread_key = get_thread_key(n)
    examples = defaultdict(list)
    for dialog in dialogs.values():
        if thread_key not in dialog:
            continue
        for row in dialog[thread_key]:
            window = [row[f'text{i}'] for i in range(n)]
            examples[row['cluster_id']].append(window)
    return examples

# Example text windows per cluster, for both feature sets (f1 first, then f2).
f1_examples, f2_examples = (get_cluster_examples(run['dialog'], n) for run in (f1_data, f2_data))

In [14]:
# Inspect one hand-picked pair of clusters side by side.
f1_cluster_id = 18
f2_cluster_id = 14
max_examples = 5  # number of example windows printed per cluster


def _print_cluster(header, reverse_index, examples, cluster_id, limit):
    """Print a header line, the cluster's description and up to ``limit`` example windows."""
    print(header)
    print(reverse_index[cluster_id])
    print()
    for window in examples[cluster_id][:limit]:
        for i, sentence in enumerate(window):
            print(f"- S{i}: {sentence}")
        print("")


print(jaccard_similarity(f1_rows_hashes[f1_cluster_id], f2_rows_hashes[f2_cluster_id]))
print()

_print_cluster(f"-----------{feature_name_1}-------------", reverse_index_f1, f1_examples, f1_cluster_id, max_examples)
print("")
_print_cluster(f"-----------{feature_name_2}-----------", reverse_index_f2, f2_examples, f2_cluster_id, max_examples)

0.00178

-----------dialog_tagger_features-------------
frozenset({'dim_Feedback comm_func_Feedback', 'dim_Task comm_func_SetQ', 'dim_Task comm_func_Statement'})

- S0: How are you today?
- S1: Great! I was just doing a little shopping online. Browsing the Amazon internet jungle... You?
- S2: Amazon is trying to take over the world. 
- S3: It amazes me that the company is only 25 years old and is already valued as the most valuable retailer in the US. Ahead of Walmart!

- S0: I like Chamber of Secrets.
- S1: Is that the second book?
- S2: Yes..you are right. 
- S3: What was it about again?

- S0: Is that the second book?
- S1: Yes..you are right. 
- S2: What was it about again?
- S3: It's Harry's third year at Hogwarts. Do you know?

- S0: Yes..you are right. 
- S1: What was it about again?
- S2: It's Harry's third year at Hogwarts. Do you know?
- S3: Third year? I thought second year. 

- S0: What was it about again?
- S1: It's Harry's third year at Hogwarts. Do you know?
- S2: Third 

In [12]:
import numpy as np

# Similarity matrix as an array: rows follow f1_ids, columns follow f2_ids.
z = np.asarray(sims)

In [13]:
import plotly.graph_objects as go

# Heatmap of the Jaccard overlap between the top clusters of both feature sets.
# z's rows follow f1_ids, so it is transposed to put feature-1 clusters on the
# x axis (columns) and feature-2 clusters on the y axis (rows).
heatmap = go.Heatmap(
    z=z.T,
    x=[f'f1_{cluster_id}' for cluster_id in f1_ids],
    y=[f'f2_{cluster_id}' for cluster_id in f2_ids],
    colorscale='Viridis',
)
fig = go.Figure(data=heatmap)
fig.update_layout(
    title=f"{feature_name_1} clusters with {feature_name_2} clusters in {dataset_name}",
    xaxis_title=f"{feature_name_1} clusters",
    yaxis_title=f"{feature_name_2} clusters",
)
# NOTE(review): the file name says "scatter_plot" although this is a heatmap;
# kept unchanged so existing output paths stay stable.
html_path = f1_output_path + f'test_{dataset_name}_{feature_name_1}_{feature_name_2}_compare_{n}_scatter_plot.html'
fig.write_html(html_path, auto_open=True)