In [1]:
import numpy as np
import matplotlib.colors as mcolors
import pandas as pd
from sklearn.decomposition import PCA
        
def process_color_vector(raw_color_data, mode, date_format="%Y-%m-%d", compute_norm=False, **kwargs):
    """
    Processes raw color data into a format suitable for plotting in one of three modes:
      - 'categorical': expects labels (e.g., ['DM', 'GAN', ...])
      - 'continuous': expects numerical values (e.g., citation counts)
      - 'date': expects date strings which will be converted to Unix timestamps

    Parameters:
    -----------
    raw_color_data : array-like
        The raw data extracted from document metadata.
    mode : str
        One of 'categorical', 'continuous', or 'date'.
    date_format : str, optional
        The format to use when parsing dates (default: "%Y-%m-%d").
        Values that are missing or do not match the format become NaN.
    compute_norm : bool, optional
        For mode 'date': if True, computes and returns a matplotlib Normalize instance
        based on valid date values (default: False).
    **kwargs:
        Accepted for forward compatibility; currently ignored.

    Returns:
    --------
    processed : list, np.ndarray or pd.Series
        - 'categorical': a list of the input labels, in order.
        - 'continuous': a float np.ndarray.
        - 'date': a float pd.Series of seconds since the Unix epoch
          (1970-01-01 00:00 UTC), NaN where a value could not be parsed.
    norm : matplotlib.colors.Normalize or None
        For mode 'date': a Normalize spanning the valid timestamps if compute_norm is
        True and at least one date parsed, otherwise None.
        For other modes, always None.

    Raises:
    -------
    ValueError
        If `mode` is not one of the supported values.
    """
    if mode == 'categorical':
        processed = list(raw_color_data)
        norm = None

    elif mode == 'continuous':
        processed = np.array(raw_color_data, dtype=float)
        norm = None

    elif mode == 'date':
        # Unparseable / missing entries become NaT thanks to errors='coerce'.
        dates_series = pd.to_datetime(pd.Series(raw_color_data), format=date_format, errors='coerce')
        # Subtract the Unix epoch and divide by one second instead of casting to int64:
        # pandas >= 2.0 raises on astype('int64') when NaT is present, whereas
        # timedelta division maps NaT to NaN on every pandas version.
        # tz-aware input must be compared against a tz-aware (UTC) epoch.
        tz = getattr(dates_series.dtype, 'tz', None)
        epoch = pd.Timestamp(0, tz='UTC') if tz is not None else pd.Timestamp(0)
        processed = (dates_series - epoch) / pd.Timedelta(seconds=1)
        norm = None
        if compute_norm:
            valid = processed.dropna()
            if not valid.empty:
                norm = mcolors.Normalize(vmin=valid.min(), vmax=valid.max())
    else:
        raise ValueError("Invalid mode. Choose 'categorical', 'continuous', or 'date'.")

    return processed, norm

def generate_plot_from_query_new(
    query,
    vector_store,
    color_var,
    x_axis_comp=0,
    y_axis_comp=1,
    n_docs=10,
    n_components=.95,
    labels=None
):
    """
    Retrieve documents for one or more queries, reduce their abstract embeddings
    with PCA, and build the matching color vector.

    Despite the name, nothing is drawn here: the function returns the data a
    plotting routine needs.

    Parameters
    ----------
    query : str or list of str
        Query text(s). Each is prefixed with "search_query: " before retrieval.
        A bare string is treated as a single query.
    vector_store :
        Object exposing ``similarity_search_with_score(text, k=...)`` that returns
        ``(document, score)`` pairs; ``document.metadata`` is read for the keys
        "abstractEmbedding", "citationCount" and "publicationDate".
    color_var : str
        One of 'labels', 'citationCount' or 'dates'.
    x_axis_comp, y_axis_comp : int
        Not used by this function; kept for interface compatibility.
    n_docs : int
        Number of documents requested per query.
    n_components : int or float
        Passed to sklearn's PCA (a float in (0, 1) keeps enough components to
        explain that fraction of the variance).
    labels : sequence, optional
        One label per query; required when color_var == 'labels'.

    Returns
    -------
    reduced_embeddings : np.ndarray
        PCA projection of all retrieved embeddings, one row per document, in
        query order.
    colors : list, np.ndarray or pd.Series
        Output of ``process_color_vector`` for the chosen color_var
        (list of labels, float citation counts, or float Unix-second dates).

    Raises
    ------
    ValueError
        For an unknown color_var, missing/mismatched labels, no retrieved
        documents, or documents lacking an "abstractEmbedding".
    """
    # Validate cheaply before doing any retrieval / PCA work.
    if color_var not in ('labels', 'citationCount', 'dates'):
        raise ValueError(f'color_var must be one of: labels, citationCount, dates. Current value is {color_var}')

    # A bare string would otherwise be iterated character by character.
    if isinstance(query, str):
        query = [query]
    prefixed_queries = ["search_query: " + q for q in query]

    if color_var == 'labels':
        if labels is None:
            raise ValueError("color_var='labels' requires `labels` (one label per query).")
        if len(labels) != len(prefixed_queries):
            raise ValueError(
                f"Expected one label per query: got {len(labels)} labels "
                f"for {len(prefixed_queries)} queries."
            )

    embedding_rows, citation_counts, dates, point_labels = [], [], [], []
    for i, q in enumerate(prefixed_queries):
        retrieved_docs = vector_store.similarity_search_with_score(q, k=n_docs)
        metadata = [doc.metadata for doc, _ in retrieved_docs]
        embedding_rows.extend(m.get("abstractEmbedding") for m in metadata)
        citation_counts.extend(m.get("citationCount") for m in metadata)
        dates.extend(m.get("publicationDate") for m in metadata)
        if color_var == 'labels':
            # Repeat the label once per document actually returned (the store may
            # return fewer than n_docs hits) so points and labels stay aligned.
            point_labels.extend([labels[i]] * len(metadata))

    if not embedding_rows:
        raise ValueError("No documents were retrieved for the given query/queries.")
    if any(row is None for row in embedding_rows):
        raise ValueError("Some retrieved documents have no 'abstractEmbedding' in their metadata.")

    embeddings = np.array(embedding_rows)
    pca = PCA(n_components=n_components)
    reduced_embeddings = pca.fit_transform(embeddings)

    if color_var == 'labels':
        colors, _ = process_color_vector(point_labels, mode='categorical')
    elif color_var == 'citationCount':
        colors, _ = process_color_vector(citation_counts, mode='continuous')
    else:  # 'dates'
        # The Normalize instance is not part of the return value, so skip computing it.
        colors, _ = process_color_vector(dates, mode='date', date_format="%Y-%m-%d")
    return reduced_embeddings, colors

In [2]:
from langchain_neo4j import Neo4jVector
from langchain_ollama import OllamaEmbeddings

import os
import sys
from pathlib import Path

# Root of the lit_review project. Override with the LIT_REVIEW_ROOT environment
# variable when running on another machine; the default is the original author's path.
LIT_REVIEW_ROOT = Path(os.environ.get('LIT_REVIEW_ROOT', '/home/TomKerby/Research/lit_review'))

# Make the project's package and config directories importable. Guard against
# appending duplicates when this cell is re-run in the same kernel.
for _extra_path in (LIT_REVIEW_ROOT / 'lit_review', LIT_REVIEW_ROOT / 'configs'):
    if str(_extra_path) not in sys.path:
        sys.path.append(str(_extra_path))

import utils
# NOTE(review): not used in the visible cells; the notebook calls generate_plot_from_query_new.
from visualization import generate_plot_from_query
from rag_config import config

kg = utils.load_kg(config)

In [3]:
# Retrieval query for Neo4jVector: ranks Paper nodes by cosine similarity between their
# abstract embedding and the query embedding, keeping the top $k ($embedding and $k are
# presumably supplied by Neo4jVector at search time).
# Papers without an abstractEmbedding are excluded: their score would be null, and
# Cypher sorts nulls FIRST under ORDER BY ... DESC, so they would displace real matches
# and later hand None embeddings to PCA.
vis_query = """
MATCH (p:Paper)
WHERE p.abstract IS NOT NULL AND p.abstract <> ''
  AND p.abstractEmbedding IS NOT NULL
WITH DISTINCT p, vector.similarity.cosine(p.abstractEmbedding, $embedding) AS score
ORDER BY score DESC LIMIT $k
RETURN p.abstract AS text, score, properties(p) AS metadata
"""

abstract_vector = Neo4jVector.from_existing_index(
    OllamaEmbeddings(model=config['embedding']['model_id']),
    graph=kg, 
    index_name='abstract_embeddings',
    embedding_node_property='abstractEmbedding',
    text_node_property='abstract',
    retrieval_query=vis_query,
)

# Plotting GAN vs. diffusion-model (DM) papers

In [4]:
query = [
    "Papers that discuss methods applied to diffusion models.",
    "Papers that discuss methods applied to GANs.",
]
color_var = "labels"

# Project the retrieved abstracts with PCA and color each point by the query it came from.
generate_plot_from_query_new(
    query=query,
    vector_store=abstract_vector,
    color_var=color_var,
    labels=["DM", "GAN"],
    n_docs=10,
    n_components=.95,
    x_axis_comp=0,
    y_axis_comp=1,
)

(array([[ 0.31927266, -0.15681114, -0.13806565, -0.06683156,  0.11723362,
          0.06682453,  0.04238628, -0.14865868, -0.01113825, -0.01472197,
         -0.03742936, -0.04186704,  0.02750564, -0.021306  , -0.06025469],
        [ 0.28393004, -0.02760603, -0.17532565, -0.04815694,  0.11687503,
         -0.02330416,  0.02766311, -0.11069427, -0.08334369, -0.05281043,
         -0.01833625, -0.06352981, -0.03714362, -0.06939606, -0.00725304],
        [ 0.26381232,  0.07527303,  0.01278658, -0.03716339,  0.19505319,
          0.02839648, -0.01408408, -0.00661593,  0.18287047,  0.18242148,
          0.1296776 , -0.09603207,  0.09748977,  0.10702666, -0.01371722],
        [ 0.27436888, -0.09879017, -0.19256725,  0.1044493 , -0.03970548,
         -0.05677795,  0.09787898,  0.03472447,  0.04794987, -0.00526212,
         -0.07228386,  0.01162421, -0.02661758, -0.01172602, -0.10346575],
        [ 0.20323067, -0.17889988, -0.02217942,  0.09675187, -0.19152109,
          0.08738232,  0.0048649 ,

In [5]:
query = [
    "Papers that discuss methods applied to diffusion models.",
    "Papers that discuss methods applied to GANs.",
]
color_var = "dates"

# Same retrieval as the label cell, but colored by publication date.
# `labels` is passed for parity with that cell; the 'dates' branch does not read it.
generate_plot_from_query_new(
    query=query,
    vector_store=abstract_vector,
    color_var=color_var,
    labels=["DM", "GAN"],
    n_docs=10,
    n_components=.95,
    x_axis_comp=0,
    y_axis_comp=1,
)

(array([[ 0.31927266, -0.15681114, -0.13806565, -0.06683156,  0.11723362,
          0.06682453,  0.04238628, -0.14865868, -0.01113825, -0.01472197,
         -0.03742936, -0.04186704,  0.02750564, -0.021306  , -0.06025469],
        [ 0.28393004, -0.02760603, -0.17532565, -0.04815694,  0.11687503,
         -0.02330416,  0.02766311, -0.11069427, -0.08334369, -0.05281043,
         -0.01833625, -0.06352981, -0.03714362, -0.06939606, -0.00725304],
        [ 0.26381232,  0.07527303,  0.01278658, -0.03716339,  0.19505319,
          0.02839648, -0.01408408, -0.00661593,  0.18287047,  0.18242148,
          0.1296776 , -0.09603207,  0.09748977,  0.10702666, -0.01371722],
        [ 0.27436888, -0.09879017, -0.19256725,  0.1044493 , -0.03970548,
         -0.05677795,  0.09787898,  0.03472447,  0.04794987, -0.00526212,
         -0.07228386,  0.01162421, -0.02661758, -0.01172602, -0.10346575],
        [ 0.20323067, -0.17889988, -0.02217942,  0.09675187, -0.19152109,
          0.08738232,  0.0048649 ,

In [None]:
query = [
    "Papers that discuss methods applied to diffusion models.",
]
color_var = "citationCount"

# Single query: color the diffusion-model papers by citation count.
generate_plot_from_query_new(
    query=query,
    vector_store=abstract_vector,
    color_var=color_var,
    labels=None,
    n_docs=10,
    n_components=.95,
    x_axis_comp=0,
    y_axis_comp=1,
)

In [None]:
query = [
    "Papers that discuss methods applied to diffusion models.",
]
color_var = "dates"

# Single query: color the diffusion-model papers by publication date.
generate_plot_from_query_new(
    query=query,
    vector_store=abstract_vector,
    color_var=color_var,
    labels=None,
    n_docs=10,
    n_components=.95,
    x_axis_comp=0,
    y_axis_comp=1,
)

In [None]:
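# TODO: draft for exporting a figure as a base64-encoded PNG. It cannot run as-is:
# `plot_pca_embeddings`, `date_floats` and `norm` are not defined in this notebook.
# import io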
# import base64

# # Generate the figure
# fig = plot_pca_embeddings(embeddings, date_floats, is_date=True, norm=norm, cmap='viridis')

# # Save figure to a buffer
# buf = io.BytesIO()
# fig.savefig(buf, format='png')
# buf.seek(0)
# encoded_img = base64.b64encode(buf.read()).decode('utf-8')