# Merging Checkpoints

As you can see from the scripts in this project, we ended up batching the comparisons between our keyword utterances ($k \in K$) and our context utterances ($c \in C$). We did this partly to reduce noise in the office where the tower is kept while our tests were running.

The cells below stitch those batched pieces back together, relying largely on the CEDA object/framework to do so.

In [None]:
from CEDA import ceda_model
from tqdm import tqdm
import pandas as pd
import numpy as np
import json
import os

In [None]:
# All inputs and outputs live under one data directory, relative to the
# notebook's working directory.
DATA_DIR = 'data'
CKPT_PATH = DATA_DIR + '/ckpts'    # per-batch CEDA checkpoints to be merged
RAW_PATH = DATA_DIR + '/raw'       # raw inputs (not referenced by the cells in this section)
OUT_PATH = DATA_DIR + '/results'   # merged CSV and anonymization keys are written here
OUT_NAME = 'ceda-results.csv'      # file name of the merged results table

In [None]:
# Accumulator for merged records: the merge loop appends one dict per link per checkpoint.
df = list()

In [None]:
# Merge the per-batch checkpoints back into one record per `link`.
# Each checkpoint is loaded into a single reused CEDA model and its graph
# DataFrame is taken. Rows with non-positive nx/ny or a missing Hxy are
# dropped. Then, for each link, the columns in `sum_cols` are summed and every
# other column is copied from the first remaining row. Records are appended
# to the `df` list.

# NOTE(review): -404.404 appears to be a sentinel CEDA uses for an Hxy value it
# could not compute -- confirm against the CEDA source.
HXY_MISSING = -404.404

mod = ceda_model()
# os.listdir returns entries in arbitrary order; sort so the merge (and the row
# order of the final table) is reproducible across machines and runs.
files = sorted(os.path.join(CKPT_PATH, f) for f in os.listdir(CKPT_PATH))

sum_cols = ['nx', 'Hxy']
skipped = 0

for f in tqdm(files):
    mod.load_from_checkpoint(f)
    graph = mod.graph_df(residualize=False)

    meta_data_cols = [col for col in graph.columns if col not in sum_cols]

    # The row filter does not depend on the link, so apply it once per
    # checkpoint rather than once per link.
    valid = graph.loc[
        (graph['nx'] > 0)
        & (graph['ny'] > 0)
        & (graph['Hxy'] != HXY_MISSING)
    ]

    # Per link, sum Hxy and nx over all pieces.
    for link in graph['link'].unique():
        sub = valid.loc[valid['link'].isin([link])]
        if sub.empty:
            # Every row for this link was filtered out. Taking record [0]
            # below would raise IndexError, so skip it and report the count.
            skipped += 1
            continue
        doc = sub[meta_data_cols].to_dict(orient='records')[0]
        for col in sum_cols:
            doc[col] = sub[col].sum()
        df += [doc]

if skipped:
    print('Skipped', skipped, 'link(s) with no valid rows.')

In [None]:
# `df` holds one dict per merged record. pd.concat only accepts Series/DataFrame
# objects and raises TypeError on dicts, so build the frame directly from the
# records instead. This also yields a fresh RangeIndex, and it still works if
# the cell is re-run after `df` is already a DataFrame.
df = pd.DataFrame(df)
print(df.shape)
df.head()

In [None]:
# Operations to sanitize data

A few last checks: count the missing values in each column.

In [None]:
# Missing-value count per column (isnull is pandas' alias for isna); any
# non-zero count marks a column worth inspecting before anonymizing.
df.isnull().sum(axis=0)

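Let's also take a moment to anonymize some of the data (the `president` and `link` columns), saving the anonymization keys locally as JSON files in the output directory. Keep those key files out of anything you share, since they map the IDs back to the original values.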

In [None]:
# Replace each group's values with random integer IDs (1..n_distinct) and save
# the value -> ID key as JSON, one file per group.
# Each inner list is a group of columns that share one conversion table; the
# key file is named after the group's first column, with any 'x_'/'y_' removed.
anonymize_columns = [['president'], ['link']]

# Make sure the key files can be written even on a fresh checkout.
os.makedirs(OUT_PATH, exist_ok=True)

for cols in anonymize_columns:
    # Distinct values across all columns in the group, in a random order, so
    # the assigned IDs reveal nothing about the original (sorted) ordering.
    values = np.random.permutation(np.unique(df[cols].values))

    conversion = {val: i + 1 for i, val in enumerate(values)}

    # Save the conversion dictionary. The `with` block closes the file even if
    # serialization fails.
    # NOTE(review): this is OUT_PATH, the same directory as the anonymized CSV;
    # don't share the directory as-is.
    key_name = cols[0].replace('x_', '').replace('y_', '') + '.json'
    with open(os.path.join(OUT_PATH, key_name), 'w', encoding='utf-8') as key_file:
        key_file.write(json.dumps(conversion, indent=4))

    # Anonymize the columns in place.
    for col in cols:
        print(col)
        df[col] = [conversion[val] for val in tqdm(df[col].values)]

With that done, let's save the data.

In [None]:
# Write the merged, anonymized table. Create the output directory first so
# to_csv does not fail when data/results does not exist yet.
os.makedirs(OUT_PATH, exist_ok=True)
df.to_csv(os.path.join(OUT_PATH, OUT_NAME), index=False, encoding='utf-8')

In [None]:
# Final (rows, columns) size of the merged table.
(len(df.index), len(df.columns))