In [None]:
import json
import os
import random
from collections import Counter

import pandas as pd
import plotly.express as px
from sklearn.model_selection import train_test_split

# Show long frames and long dialogue texts without truncation while inspecting the data.
pd.set_option('display.max_rows', 250)
pd.set_option('display.max_colwidth', 250)

In [None]:
# Actor table, indexed by actor id (first CSV column).
act = pd.read_csv('table/actor.csv', index_col=0)

# Dialogue lines indexed by the first two CSV columns (presumably (conversation id, line id);
# level 0 is later collected as conversation ids). Unused columns are dropped.
convs = pd.read_csv('table/convs.csv', index_col=[0, 1]).drop(
    columns=['isGroup', 'canvasRect_width', 'canvasRect_height']
)
# Lines without an actor get the placeholder actor id -1.0.
convs['Actor'] = convs['Actor'].fillna(-1.0)
# `outgoingLinks` appears to be stored as text like "[(a, b), ...]": swap parentheses for
# brackets so it parses as JSON, then turn every pair back into a tuple (index keys are tuples).
convs['outgoingLinks'] = convs['outgoingLinks'].apply(
    lambda raw: [tuple(pair) for pair in json.loads(raw.replace('(', '[').replace(')', ']'))]
)

In [None]:
# Names of the actors whose lines are kept; the player ("You") is always included.
with open('actors.txt', 'r') as f:
    act_list = ['You'] + [line.rstrip('\n') for line in f]

In [None]:
# Row id of the player character ("You") in the actor table.
self_idx = act.index[act['Name'] == 'You'].tolist()[0]

# Actor ids to keep: every actor listed in actors.txt plus the placeholder ids 0 and -1
# (lines of those two ids are re-linked and dropped as "tmp actors" further down).
act_list_idx = act.index[act['Name'].isin(act_list)].tolist() + [0, -1]

# Level-0 index values of every conversation that has at least one line by a kept actor.
conv_idx_act = list(set(
    convs.index[convs['Actor'].isin(act_list_idx)].get_level_values(level=0)
))
cut_convs = convs.loc[conv_idx_act]

In [None]:
# Keep only lines spoken by a selected actor (placeholder ids 0 / -1 included).
convs = convs[convs['Actor'].isin(act_list_idx)]

## Clear dangling outgoing links (targets no longer in the data)

In [None]:
# Drop outgoing links whose target line was filtered out above, so every remaining link
# points at a line that is still present in `convs`.
index = convs.index
kept_links = convs['outgoingLinks'].apply(
    lambda links: [link for link in links if link in index]
)
rem_links = int(kept_links.apply(len).sum())  # links that survive
link_removed = int(convs['outgoingLinks'].apply(len).sum()) - rem_links  # dangling links dropped
# `assign` builds a new frame instead of writing cell by cell into a filtered slice.
convs = convs.assign(outgoingLinks=kept_links)
link_removed, rem_links

## add incoming links

In [None]:
# Start every line with its own, *distinct* empty list of incoming links.
# `[[]] * len(convs)` would make all rows share one list object, so any in-place
# `.append` on one row would silently show up in every row.
convs = convs.assign(ingoingLinks=[[] for _ in range(len(convs))])

In [None]:
# Build reverse edges: for every link A -> B, record A in B's `ingoingLinks`.
# Only `outgoingLinks` is read and only `ingoingLinks` is written, so assigning while
# iterating over the link column is safe.
for idx_from, links in convs['outgoingLinks'].items():
    for idx_to in links:
        ing_links = set(convs.loc[idx_to, 'ingoingLinks'])
        ing_links.add(idx_from)
        convs.at[idx_to, 'ingoingLinks'] = list(ing_links)
            

## Clear solo lines (no incoming and no outgoing links)

In [None]:
# A "solo" line has neither outgoing nor incoming links, so it can never be part of a path.
convs = convs[(convs['outgoingLinks'].apply(len) > 0) | (convs['ingoingLinks'].apply(len) > 0)]

## remove tmp actors

#### Re-link the graph around these actors' lines before removing them

In [None]:
# Re-link the graph around every line of the placeholder actors (0.0 / -1.0) before they are
# dropped: each predecessor of such a line is connected directly to each of its successors.
tmp_act_idx = list(convs[convs['Actor'].isin([0.0, -1.0])].index)
for idx in tmp_act_idx:
    outg, ing = convs.loc[idx, ['outgoingLinks', 'ingoingLinks']]
    # Exclude `idx` itself: if the line links to itself, copying its own links onto its
    # neighbours would leave them pointing at a line that is about to be removed.
    successors = set(outg) - {idx}
    predecessors = set(ing) - {idx}
    for i in predecessors:
        init = set(convs.loc[i, 'outgoingLinks'])
        init.remove(idx)  # KeyError here means outgoing/ingoing links are out of sync
        # `- {i}` avoids creating a self-loop when i is both predecessor and successor.
        convs.at[i, 'outgoingLinks'] = list(init | (successors - {i}))

    for i in successors:
        init = set(convs.loc[i, 'ingoingLinks'])
        init.remove(idx)
        convs.at[i, 'ingoingLinks'] = list(init | (predecessors - {i}))

#### removing them

In [None]:
# Drop the placeholder-actor lines now that the graph has been re-linked around them.
convs = convs[~convs['Actor'].isin([0.0, -1.0])]

#### Check that no links point outside the current data

In [None]:
# Sanity check: every outgoing and incoming link must point at a line that is still present.
# Fail loudly (instead of printing and carrying on) so a broken graph never reaches path
# extraction, and report which link is dangling and in which direction.
index = convs.index
dangling = [
    (idx, column, target)
    for column in ('outgoingLinks', 'ingoingLinks')
    for idx, targets in convs[column].items()
    for target in targets
    if target not in index
]
assert not dangling, f'{len(dangling)} dangling link(s), first few: {dangling[:5]}'

#### Remove solo lines once again

In [None]:
# Re-linking can isolate lines; drop any line with neither outgoing nor incoming links.
convs = convs[(convs['outgoingLinks'].apply(len) > 0) | (convs['ingoingLinks'].apply(len) > 0)]

### removing NaN text

In [None]:
# Lines without text cannot be rendered into examples. Re-link the graph around each of them
# (every predecessor is connected straight to every successor), then drop them.
nan_text_idx = convs[convs['Dialogue Text'].isna()].index
for idx in nan_text_idx:
    outg, ing = convs.loc[idx, ['outgoingLinks', 'ingoingLinks']]
    # Exclude `idx` itself: a self-link would otherwise be copied onto its neighbours and
    # leave them pointing at a line that is dropped below.
    successors = set(outg) - {idx}
    predecessors = set(ing) - {idx}
    for i in predecessors:
        init = set(convs.loc[i, 'outgoingLinks'])
        init.remove(idx)  # KeyError here means outgoing/ingoing links are out of sync
        # `- {i}` avoids creating a self-loop when i is both predecessor and successor.
        convs.at[i, 'outgoingLinks'] = list(init | (successors - {i}))

    for i in successors:
        init = set(convs.loc[i, 'ingoingLinks'])
        init.remove(idx)
        convs.at[i, 'ingoingLinks'] = list(init | (predecessors - {i}))
convs = convs[convs['Dialogue Text'].notna()]
# Re-linking can leave lines with no links at all; drop those as well.
convs = convs[(convs['outgoingLinks'].apply(len) > 0) | (convs['ingoingLinks'].apply(len) > 0)]

In [None]:
# Sanity check: every outgoing and incoming link must point at a line that is still present.
# Fail loudly (instead of printing and carrying on) so a broken graph never reaches path
# extraction, and report which link is dangling and in which direction.
index = convs.index
dangling = [
    (idx, column, target)
    for column in ('outgoingLinks', 'ingoingLinks')
    for idx, targets in convs[column].items()
    for target in targets
    if target not in index
]
assert not dangling, f'{len(dangling)} dangling link(s), first few: {dangling[:5]}'

In [None]:
# Attach the speaker's name (looked up by actor id) and show how many lines each actor has.
convs = convs.assign(ActorName=convs['Actor'].map(act['Name']))
convs['ActorName'].value_counts()

In [None]:
# Adjacency mapping of the dialogue graph: line index -> list of successor line indices.
convs_dict = dict(convs['outgoingLinks'].items())
def recursive_count(node, prev_nodes, max_depth=5, graph=None):
    """Count the dialogue paths reachable from `node` by depth-first search.

    Successors are read from `graph[node]`. The walk never steps onto a line that is already
    on the current path, and that includes `node` itself, so self-loops are not followed.

    Args:
        node: key of the line to start from.
        prev_nodes: lines already on the current path (callers normally pass `[]`).
        max_depth: once more than this many lines precede `node`, the branch is counted as
            one path without being expanded further (the original hard-coded 5).
        graph: mapping node -> list of successor nodes. `None` uses the global `convs_dict`.

    Returns:
        The number of paths. A line without successors counts as 1. A line whose every
        successor is already on the path counts as 0.
    """
    if len(prev_nodes) > max_depth:
        return 1
    if graph is None:
        graph = convs_dict
    links = graph[node]
    if len(links) == 0:
        return 1

    path_so_far = prev_nodes + [node]
    return sum(
        recursive_count(link, path_so_far, max_depth=max_depth, graph=graph)
        for link in links
        if link != node and link not in prev_nodes
    )


def recursive_path(node, prev_nodes, max_len=7, min_len=3, graph=None):
    """Enumerate dialogue paths starting at `node` by depth-first search.

    Successors are read from `graph[node]`. The walk never steps onto a line that is already
    on the current path, and that includes `node` itself, so self-loops are not followed.

    Args:
        node: key of the line to expand.
        prev_nodes: lines already on the current path (callers normally pass `[]`).
        max_len: once `prev_nodes` holds this many lines, the path is cut and returned as
            `prev_nodes` (the current `node` is not appended).
        min_len: a path that ends at a line without successors is kept only if at least this
            many lines precede that final line.
        graph: mapping node -> list of successor nodes. `None` uses the global `convs_dict`.

    Returns:
        A list of paths, each a list of node keys. Branches where every successor is already
        on the path contribute nothing.
    """
    if len(prev_nodes) >= max_len:
        # Copy so sibling branches never hand back the same list object.
        return [list(prev_nodes)]
    if graph is None:
        graph = convs_dict
    links = graph[node]
    if len(links) == 0:
        if len(prev_nodes) >= min_len:
            return [prev_nodes + [node]]
        return []

    paths = []
    path_so_far = prev_nodes + [node]
    for link in links:
        if link != node and link not in prev_nodes:
            # Forward the limits: previously only the top-level call honoured non-defaults.
            paths += recursive_path(
                link, path_so_far, max_len=max_len, min_len=min_len, graph=graph
            )
    return paths

def filter_paths(paths):
    """Keep only the paths in which "You" speaks at some position before the last line.

    Paths where the player speaks never, or only on the final line, are dropped. Input order
    is preserved. Speaker names come from the `ActorName` column of the global `convs` frame.
    """
    def _you_speaks(node):
        return convs.loc[node]['ActorName'] == 'You'

    return [path for path in paths if any(_you_speaks(node) for node in path[:-1])]


def populate_dial(path):
    """Cut one dialogue path into training prefixes.

    A prefix (everything before the line) is emitted at every "You" line that has at least
    one line in front of it. The whole path is emitted too when its last line is not spoken
    by "You". Speaker names come from the `ActorName` column of the global `convs` frame.
    """
    def _is_you(node):
        return convs.loc[node]['ActorName'] == 'You'

    prefixes = [list(path[:pos]) for pos, node in enumerate(path) if _is_you(node) and pos > 0]
    if not _is_you(path[-1]):
        prefixes.append(list(path))
    return prefixes
                
        

In [None]:
def path_to_example(path):
    """Render a dialogue path as a list of text turns.

    Each "You" line becomes its own turn, containing only the raw text with no speaker tag.
    Consecutive lines from other speakers are merged into one turn: one `[Speaker]: text`
    line each, joined by newlines. Text and speaker come from the global `convs` frame.
    """
    turns = []
    others = []

    def flush_others():
        if others:
            turns.append('\n'.join(others))
            others.clear()

    for node in path:
        row = convs.loc[node]
        text, actor = row['Dialogue Text'], row['ActorName']
        if actor != 'You':
            others.append(f'[{actor}]: {text}')
            continue
        flush_others()
        turns.append(text)
    flush_others()
    return turns

In [None]:
# Conversation entry points: lines that no other line links into.
all_starts = convs.index[convs['ingoingLinks'].apply(len).eq(0)].tolist()

In [None]:
# Split by start node, so every dialogue generated from one start lands in the same split.
train_start, test_start = train_test_split(
    all_starts,
    test_size=0.025,
    random_state=42,
)
(len(train_start), len(test_start))

In [None]:
MAX_PATHS_PER_START = 10  # cap on paths sampled from a single start node
SAMPLE_SEED = 42          # seed for path sampling, so the dataset is reproducible


def build_samples(starts, rng, max_paths=MAX_PATHS_PER_START):
    """Turn start nodes into `(start, dialogues)` samples.

    For each start: enumerate paths with `recursive_path`, keep those where "You" speaks
    before the last line (`filter_paths`), and randomly keep at most `max_paths` of them
    using `rng`. Then cut every kept path into training prefixes with `populate_dial`.
    Starts that yield no dialogues are skipped.
    """
    samples = []
    for start in starts:
        paths = filter_paths(recursive_path(start, []))
        if len(paths) > max_paths:
            paths = rng.sample(paths, max_paths)
        dialogues = [prefix for path in paths for prefix in populate_dial(path)]
        if dialogues:
            samples.append((start, dialogues))
    return samples


# A dedicated seeded RNG replaces the unseeded global `random.sample`, whose picks changed
# on every run.
sample_rng = random.Random(SAMPLE_SEED)
samples_train = build_samples(train_start, sample_rng)
samples_test = build_samples(test_start, sample_rng)

In [None]:
# Render every sampled dialogue of every start node into a list of text turns.
dataset_train = [path_to_example(path) for _, paths in samples_train for path in paths]
dataset_test = [path_to_example(path) for _, paths in samples_test for path in paths]

In [None]:
# How many dialogues each training start node contributed.
paths_per_start = [len(paths) for _, paths in samples_train]
if paths_per_start:
    total = sum(paths_per_start)
    print(f'total dialogues: {total}, mean per start node: {total / len(paths_per_start):.2f}')
else:
    print('no training samples')
fig = px.histogram(
    x=paths_per_start,
    title='Dialogues per training start node',
    labels={'x': 'dialogues per start node'},
)
fig.show()

In [None]:
# Deduplicate while keeping first-occurrence order. A `set` also deduplicates, but its
# iteration order for tuples of strings depends on the per-process hash seed, so the saved
# JSON would come out in a different order on every run.
dataset_train = list(dict.fromkeys(tuple(example) for example in dataset_train))
dataset_test = list(dict.fromkeys(tuple(example) for example in dataset_test))
# NOTE(review): the split is by start node, so an identical example can still occur in both
# train and test; decide whether test examples present in train should be removed.
len(dataset_train), len(dataset_test)

In [None]:
# Write both splits as JSON. Create the output directory first so a fresh checkout works.
os.makedirs('dataset/v1', exist_ok=True)
for split_name, split_data in (('train', dataset_train), ('test', dataset_test)):
    with open(f'dataset/v1/{split_name}.json', 'w', encoding='utf-8') as f:
        json.dump(split_data, f, indent=4)

In [None]:
# How often each rendered turn occurs in the training set (first-occurrence order kept).
freq_dict = Counter(line for sample in dataset_train for line in sample)

# Training samples that contain an empty turn, which usually points to an upstream data
# problem. Summarise them instead of dumping every sample once per empty turn.
samples_with_empty = [sample for sample in dataset_train if '' in sample]
print(f'{len(samples_with_empty)} training samples contain an empty turn')
samples_with_empty[:5]


In [None]:
# Turns that occur more than 100 times in the training set (possibly boilerplate; review).
frequent_lines = [(line, count) for line, count in freq_dict.items() if count > 100]
for line, count in frequent_lines:
    print(f'{line}', count)