In [1]:
%matplotlib inline

import pickle
import chromadb
import numpy as np
from pprint import pprint
from tqdm.auto import tqdm
import sys
from matplotlib import pyplot as plt
import random

sys.path.append("../")
from utils.parse_arxiv import *
from make_vectordb import get_embedding_model

[nltk_data] Downloading package punkt to /home/zyang37/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
# Create a Chroma PersistentClient pointed at the ../data/chroma_dbs/ directory.
chroma_client = chromadb.PersistentClient(path="../data/chroma_dbs/")
# Last expression of the cell: displays the collections known to this client.
chroma_client.list_collections()

[Collection(name=arxiv_title_meta),
 Collection(name=cnn_headline),
 Collection(name=arxiv_title),
 Collection(name=arxiv_abstract),
 Collection(name=arxiv_abstract_meta),
 Collection(name=cnn_article)]

In [4]:
# Pick the collection to work with by name and fetch a handle to it.
coll_name = "arxiv_title"
collection = chroma_client.get_collection(name=coll_name)

In [27]:
# Query the collection with the text "testing", asking for 5 results and
# including only their embeddings in the response.
results = collection.query(query_texts="testing", n_results=5, include=['embeddings'])
# Take the embeddings entry at index 0 (the single query text above) and turn
# it into a NumPy array so rows can be selected by index.
res_embeds = np.array(results['embeddings'][0])

In [29]:
def positive_embeds():
    """Placeholder: not implemented yet (the body is a bare ``pass``)."""
    pass

# Select rows 0, 1 and 2 of res_embeds with an index list and display the
# shape of the selection.
res_embeds[[0,1,2],].shape

(3, 384)

In [None]:
# Load the pickled dataset from ../data/arxiv/filtered_data.pickle.
# NOTE(review): pickle.load can execute arbitrary code while unpickling, so
# only load this file when it comes from a trusted source.
with open("../data/arxiv/filtered_data.pickle", "rb") as f:
    data = pickle.load(f)

In [None]:
# Preview the first rows of the loaded data
# (presumably a pandas DataFrame; confirm against the pickle's producer).
data.head()

In [None]:
def aggregate_by_keywords(data):
    """Build inverted indexes from metadata values to row positions.

    Walks ``data`` row by row (by position) and records each row position
    under every author and every category parsed from the row, under the
    parsed journal value, and under the parsed year.

    Args:
        data: table exposing ``.shape`` and the columns ``authors``,
            ``journal-ref``, ``categories`` and ``update_date`` with
            positional ``.iloc`` access (presumably a pandas DataFrame;
            confirm against the loader).

    Returns:
        Tuple ``(author_dict, cat_dict, journal_dict, year_dict)``. Each dict
        maps a parsed key to the list of row positions where it occurred, in
        the order the rows were visited.
    """
    by_author, by_category, by_journal, by_year = {}, {}, {}, {}

    for row in tqdm(range(data.shape[0])):
        # A row may list several authors; index the row under each of them.
        for name in parse_authors(data['authors'].iloc[row]):
            by_author.setdefault(name, []).append(row)

        # The parsed journal value is used as one key (it is not iterated).
        journal_key = parse_journal(data['journal-ref'].iloc[row])
        by_journal.setdefault(journal_key, []).append(row)

        # A row may belong to several categories; index it under each.
        for category in parse_categories(data['categories'].iloc[row]):
            by_category.setdefault(category, []).append(row)

        # Year is derived from the row's update_date field.
        year_key = parse_year(data['update_date'].iloc[row])
        by_year.setdefault(year_key, []).append(row)

    return by_author, by_category, by_journal, by_year

In [None]:
# Build the four metadata -> row-position indexes, then list the rows per year.
author_dict, cat_dict, journal_dict, year_dict = aggregate_by_keywords(data)
for k, row_ids in year_dict.items():
    print(k, row_ids)

In [None]:
# Number of rows indexed under each author.
for k, row_ids in author_dict.items():
    print(k, len(row_ids))

In [None]:
def add_test_groups(test_groups, target_dict):
    """Append every multi-member group of ``target_dict`` to ``test_groups``.

    Args:
        test_groups: list of groups; it is extended in place.
        target_dict: mapping whose values are sized collections (groups).

    Returns:
        The same ``test_groups`` object, after every value of ``target_dict``
        with more than one member has been appended in dict order. The
        appended groups are the dict's own value objects, not copies.
    """
    test_groups.extend(
        members for members in target_dict.values() if len(members) > 1
    )
    return test_groups

In [None]:
# Collect every metadata group with more than one member, visiting the
# indexes in the order: authors, categories, journals, years.
test_groups = []
for keyword_index in (author_dict, cat_dict, journal_dict, year_dict):
    test_groups = add_test_groups(test_groups, keyword_index)

In [None]:
# Histogram of group sizes: freq[size] is the number of groups of that size.
freq = {}
for group in test_groups:
    group_len = len(group)
    freq[group_len] = freq.get(group_len, 0) + 1

In [None]:
# Print each group size together with how many groups have that size.
for k, count in freq.items():
    print(k, count)

In [None]:
# Load the arXiv configuration JSON.
# (load_json is not imported by name here; presumably it comes from the
# star import of utils.parse_arxiv. Confirm.)
cfg = load_json('../config/arxiv_cfg.json')

In [None]:
# Build the embedding callable from the 'vectorDB' section of the config.
embed_func = get_embedding_model(cfg['vectorDB'])


In [None]:
def embed_all(df, embed_func):
    """Embed the title and the abstract of every row of ``df``.

    ``embed_func`` is called once per text with a single-element list, and
    the first element of whatever it returns is kept. For each row the title
    is embedded before the abstract.

    Args:
        df: table with ``title`` and ``abstract`` columns that support
            positional ``.iloc`` access, and a ``.shape`` attribute.
        embed_func: callable taking a list of texts and returning an
            indexable collection of embeddings.

    Returns:
        ``(test_embeddings, gt_embeddings)``: two lists aligned with the rows
        of ``df``, holding the title embeddings and the abstract embeddings
        respectively.
    """
    title_vecs, abstract_vecs = [], []
    for row in tqdm(range(df.shape[0])):
        title_text = df['title'].iloc[row]
        title_vecs.append(embed_func([title_text])[0])

        abstract_text = df['abstract'].iloc[row]
        abstract_vecs.append(embed_func([abstract_text])[0])
    return title_vecs, abstract_vecs

In [None]:
# Embed every title ("test") and every abstract ("ground truth") in the data.
# This makes two embed_func calls per row, so it can be slow on large tables.
test_embeddings, gt_embeddings = embed_all(data, embed_func)

In [None]:
# Convert the lists of embeddings to NumPy arrays so that groups of rows can
# be selected with a list of row positions (as calc_group_vars does).
# Note: this rebinds the same names, so re-running the cell re-wraps the arrays.
test_embeddings = np.array(test_embeddings)
gt_embeddings = np.array(gt_embeddings)

In [None]:
# Small hand-written 3x3 matrix used to sanity-check row_var.
test_arr = [
    [1, 1, 2], 
    [2, 2, 4],
    [3, 3, 6]
]

In [None]:
def row_var(arr):
    """Return the summed per-column variance of ``arr``.

    The variance is taken along axis 0 (across rows, separately for each
    column) with NumPy's default settings, and the resulting values are
    added together into one number.

    Args:
        arr: array-like accepted by ``np.var``.

    Returns:
        The sum of the axis-0 variances.
    """
    column_variances = np.var(arr, axis=0)
    return np.sum(column_variances)

In [None]:
# Sanity check on the toy matrix: the column variances are 2/3, 2/3 and 8/3,
# so this should print a value of about 4.
print(row_var(test_arr))

In [None]:
def calc_group_vars(test_embeddings, gt_embeddings, test_groups):
    """Compute ``row_var`` for each group in both embedding sets.

    Args:
        test_embeddings: embeddings indexable with a group of row positions
            (e.g. a NumPy array indexed with a list).
        gt_embeddings: second set of embeddings, indexed the same way.
        test_groups: iterable of groups; each group is used directly as the
            index into both embedding sets.

    Returns:
        ``(test_vars, gt_vars)``: two lists with one ``row_var`` value per
        group, in the order the groups were given.
    """
    # Single pass over the groups; test variance is computed before the
    # ground-truth variance for each group.
    paired = [
        (row_var(test_embeddings[members]), row_var(gt_embeddings[members]))
        for members in test_groups
    ]
    test_vars = [test_value for test_value, _ in paired]
    gt_vars = [gt_value for _, gt_value in paired]
    return test_vars, gt_vars

In [None]:
# One variance value per metadata group, for titles and for abstracts.
test_vars, gt_vars = calc_group_vars(test_embeddings, gt_embeddings, test_groups)

In [None]:
# Scatter of per-group variances: titles on x, abstracts on y.
fig, ax = plt.subplots()
# Red reference line drawn in axes coordinates (transform=ax.transAxes): it
# runs corner to corner of the plot area.
# NOTE(review): it coincides with y = x in data units only if both axes have
# the same limits. Confirm that is the intent.
ax.plot([0, 1], [0, 1], 'r', transform=ax.transAxes)
ax.scatter(test_vars, gt_vars, s=1)
ax.set(
    xlabel='Test Variances (Arxiv titles)',
    ylabel='Ground Truth Variances (Arxiv abstracts)',
    title="Metadata Groups Variances",
)
fig.savefig('../data/metadata_vars.png')

In [None]:
# Summed per-column variance over ALL rows, for titles and for abstracts.
# NOTE: this rebinds test_vars / gt_vars from the per-group lists returned by
# calc_group_vars to single numbers. Re-running the metadata scatter cell
# after this one would plot these scalars instead of the lists.
test_vars = row_var(test_embeddings)
gt_vars = row_var(gt_embeddings)
print(test_vars, gt_vars)

In [None]:
# Show the array shapes of the title and abstract embeddings.
print(test_embeddings.shape)
print(gt_embeddings.shape)

In [None]:
def rand_sample(test_embeddings, gt_embeddings, n, group_size, seed=None,
                save_path='../data/rand_vars.png'):
    """Plot title-vs-abstract variances for randomly drawn groups of rows.

    Draws ``n`` groups of ``group_size`` distinct row positions, computes the
    variance of each group in both embedding sets with ``calc_group_vars``,
    scatters the two against each other and saves the figure.

    Args:
        test_embeddings: array of title embeddings; ``shape[0]`` is the
            number of rows that can be sampled.
        gt_embeddings: array of abstract embeddings, indexed by the same row
            positions.
        n: number of random groups to draw.
        group_size: number of distinct rows in each group.
        seed: optional seed for reproducible sampling. With the default
            ``None`` the module-level ``random`` state is used, exactly as
            before this parameter existed.
        save_path: where the figure is written; defaults to the path that
            used to be hard-coded.

    Returns:
        None. The figure is saved to ``save_path`` and left open so the
        notebook's inline backend can display it.

    Raises:
        ValueError: from ``random.sample`` when ``group_size`` is larger than
            the number of rows (or negative).
    """
    # A private generator is created only when a seed is given, so unseeded
    # calls keep drawing from the shared module-level state.
    rng = random if seed is None else random.Random(seed)
    population = range(test_embeddings.shape[0])
    sample_groups = [rng.sample(population, group_size) for _ in range(n)]

    test_vars, gt_vars = calc_group_vars(test_embeddings, gt_embeddings, sample_groups)

    fig, ax = plt.subplots()
    # Reference line in axes coordinates: corner to corner of the plot area.
    # NOTE(review): this matches y = x in data units only when both axes
    # share the same limits. Confirm that is the intent.
    ax.plot([0, 1], [0, 1], 'r', transform=ax.transAxes)
    ax.scatter(test_vars, gt_vars, s=1)
    ax.set_xlabel('Test Variances (Arxiv titles)')
    ax.set_ylabel('Ground Truth Variances (Arxiv abstracts)')
    ax.set_title("Random Sample Variances")
    fig.savefig(save_path)

In [None]:
# Random baseline: 10000 groups of 3 randomly chosen rows each.
rand_sample(test_embeddings, gt_embeddings, 10000, 3)