# Pre-Requisite Libraries and Initialization


In [None]:
!pip install transformers
!pip install datasets
!pip install pyamg

In [None]:
import pickle
import torch
from sklearn.cluster import spectral_clustering
import numpy as np

from scipy import sparse
from scipy.sparse import linalg
from scipy.sparse import coo_matrix
from scipy.sparse import csr_matrix
from scipy.sparse import lil_matrix
from scipy.sparse import csgraph
from scipy.linalg import fractional_matrix_power

import pyamg
import networkx as nx
import matplotlib.pyplot as plt
import statistics
import random

from numpy import True_
from transformers import FeatureExtractionPipeline, pipeline, RobertaTokenizer, RobertaModel
from datasets import load_dataset
import nltk
from nltk.corpus import stopwords

In [None]:
# Colab environment setup: mount Google Drive and work from the project folder.
# Change file path in FOLDERNAME variable for new Drive

from google.colab import drive
# force_remount=True re-mounts even when /content/drive is already mounted.
drive.mount('/content/drive', force_remount=True)

# Name of the project folder inside "My Drive".
FOLDERNAME = 'curis22'
assert FOLDERNAME is not None, "[!] Enter the foldername."

import sys
# Make Python modules stored in the project folder importable.
sys.path.append('/content/drive/My Drive/{}'.format(FOLDERNAME))

# IPython magic: switch the working directory to the project folder so the
# relative file names used by later cells resolve inside it.
%cd /content/drive/My\ Drive/$FOLDERNAME/

# RoBERTa Pre-processing

In [None]:
# Stop-word list: NLTK's English stop words plus two extra entries ('=' and a
# bare newline) that the sentence-cleaning code treats as noise.
nltk.download('stopwords')
stopword = stopwords.words('english')
stopword.extend(['=', '\n'])

# Pretrained RoBERTa tokenizer and encoder.
tokenizer, model = (
    RobertaTokenizer.from_pretrained("roberta-base"),
    RobertaModel.from_pretrained("roberta-base"),
)

# Feature-extraction pipeline built from the pair above.
# NOTE: this rebinds the name `pipeline` imported from transformers.
pipeline = FeatureExtractionPipeline(
    model=model,
    tokenizer=tokenizer,
    framework="pt",
)

# WikiText-2 (raw); only the training split is used.
dataset = load_dataset("wikitext",  'wikitext-2-raw-v1')
ds = dataset['train']

In [None]:
# Randomly select n sentences from the training split and embed them with RoBERTa.
# random.sample needs a real sequence of strings, so the text column is
# materialised as a list first (iterating the Dataset itself yields row dicts).
# NOTE(review): assumes the column is named 'text', as on the wikitext dataset card.
ds = random.sample(list(dataset['train']['text']), 3000)

embeds = []   # one (kept_tokens, features) tensor per processed sentence
tokens = []   # decoded token strings, aligned 1-to-1 with the rows kept in `embeds`

for key, sentence in enumerate(ds):
  # cleaning up the string: split on whitespace and drop the "=" markers
  sen_split = [word for word in sentence.split() if word != "="]
  if not sen_split:
    continue  # empty line or markup only: nothing to embed

  str_to_process = " ".join(sen_split)
  outputs = torch.FloatTensor(pipeline(str_to_process))
  words = outputs[0].shape[0]  # number of tokens, including <s> and </s>

  # Boolean mask over the sentence's tokens. Positions left False (<s>, </s>
  # and every stop word) are dropped from the saved embeddings.
  sen_mask = torch.zeros(words, dtype=torch.bool)

  position = 1  # index 0 is the <s> token
  for sen_ind, word in enumerate(sen_split):
    # Count the word pieces this word produces *in context*: every word after
    # the first is preceded by a space in the joined string, and RoBERTa's BPE
    # folds that space into the token.
    piece = word if sen_ind == 0 else " " + word
    num_wordgram = len(tokenizer.tokenize(piece))
    # Never run into the trailing </s>, even if the counts disagree.
    stop = min(position + num_wordgram, words - 1)

    # fill up the corresponding span in the boolean mask tensor
    if word not in stopword:
      sen_mask[position:stop] = True
    position = stop

  embeds.append(outputs[0][sen_mask])

  # Keep exactly the tokens whose vectors were kept, so that the two lists
  # stay aligned.
  token_ids = tokenizer(str_to_process)['input_ids']
  decoding = tokenizer.batch_decode(token_ids)
  tokens.extend(tok for tok, keep in zip(decoding, sen_mask.tolist()) if keep)

  if key % 100 == 0: # partially saves embeddings
    with open('random_word_embeds_wikitext_' + str(key), 'wb') as file1:
      pickle.dump(embeds, file1)

    with open('random_wiki_tokens_' + str(key), 'wb') as file2:
      pickle.dump(tokens, file2)

In [None]:
# This block converts previously processed embeddings into row/column format suitable for further processing.
# Change file names as needed

with open('rand_word_embed_wikitext1400', 'rb') as file1:
  word_embed = pickle.load(file1)

# Each entry is unpacked as a (1, num_tokens, dim) tensor; dropping the leading
# axis and concatenating along the token axis yields one row per token.
b = torch.cat([entry[0] for entry in word_embed], dim=0)
tot_discrete, dim = b.shape
col = dim

with open('rand_word_embed_wikitext_1400.pt', 'wb') as file2:
  pickle.dump(b, file2)

In [None]:
# OPTIONAL token recovery step to get back every token processed to map 1-to-1 with vectors.
# Included retroactively in main processing step.

with open('random_wiki_sentences_1400', 'rb') as file1:
  selected_sentences = pickle.load(file1)

tokens = []

for sentence in selected_sentences:
  if sentence == '':
    continue

  # cleaning up the string: split on whitespace, drop "=" markers, re-join
  sen_split = [word for word in sentence.split() if word != "="]
  str_to_process = " ".join(sen_split)

  # Decode every input id on its own so that one string is produced per token.
  token_ids = tokenizer(str_to_process)['input_ids']
  tokens.extend(tokenizer.batch_decode(token_ids))

with open('random_wiki_sentences_1400_tokens', 'wb') as file1:
  pickle.dump(tokens, file1)

# Orthogonal Matching Pursuit


In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

with open('rand_word_embed_wikitext_1400.pt', 'rb') as file1:
  ds = pickle.load(file1)

# One 768-dimensional row per token; the row count is inferred from the data
# instead of being hard-coded.
embedding = torch.reshape(ds, (-1, 768))
# Scale every row to unit Euclidean norm.
embedding = embedding / torch.linalg.norm(embedding, dim=1, keepdim=True)
# Tensor.to returns a new tensor, so the result has to be re-bound.
embedding = embedding.to(device) # code runs on CUDA

shape = list(embedding.shape)
row, col = shape

In [None]:
# initialization / hyperparameters: adjust as needed
# Of these, orthogonal_matching_pursuit_hybrid in this file reads only
# `sparsity_high` and `norm_threshold_highest`; the other values are alternative
# settings that no code visible in this file uses.

# Candidate limits on the number of vectors selected per signal.
sparsity_low = 80
sparsity_med = 120
sparsity_high = 250
# Candidate residual-norm thresholds at which the pursuit stops early.
norm_threshold = 0.001
norm_threshold_higher = 0.01
norm_threshold_high = 0.1
norm_threshold_highest = 0.3

In [None]:
# Index set selected for every signal processed so far, in call order.
cached_index_set = []

# Pass `max_atoms` / `tol` to switch hyperparameters without editing globals.

def orthogonal_matching_pursuit_hybrid(input_signal, index, dictionary=None, max_atoms=None, tol=None):
    """Greedily approximate ``input_signal`` by a combination of dictionary rows.

    At every step the row with the largest inner product with the current
    residual is selected, and the residual is recomputed as the part of the
    signal orthogonal to the span of all rows selected so far. The loop always
    runs once and stops when ``max_atoms`` rows are selected or the residual
    norm is no longer above ``tol``.

    Args:
        input_signal: tensor of shape (1, dim) holding the signal.
        index: row of the dictionary that must never be selected (the signal's
            own row).
        dictionary: (num_rows, dim) tensor of candidate rows. Defaults to the
            module-level ``embedding``.
        max_atoms: maximum number of rows to select. Defaults to the
            module-level ``sparsity_high``.
        tol: residual-norm threshold. Defaults to the module-level
            ``norm_threshold_highest``.

    Returns:
        ``(residual_norm, count)``: the norm of the final residual as a 0-d
        tensor and the number of rows selected.

    Side effects:
        Appends the selected row indices (a LongTensor) to ``cached_index_set``.
    """
    if dictionary is None:
        dictionary = embedding
    if max_atoms is None:
        max_atoms = sparsity_high
    if tol is None:
        tol = norm_threshold_highest

    never = float('-inf')
    max_set = torch.zeros(max_atoms, dtype=torch.long)
    residual = input_signal
    i = 0

    while i < max_atoms and (i == 0 or torch.norm(residual) > tol):
        scores = residual @ dictionary.T
        # Exclude the signal itself and everything already chosen. Masking the
        # scores (rather than zeroing a copy of the dictionary) avoids cloning
        # the whole dictionary per call and also holds when every remaining
        # score is negative.
        scores[0, index] = never
        if i:
            scores[0, max_set[:i]] = never

        best = int(torch.argmax(scores, dim=1).item())
        if scores[0, best] == never:
            break  # no candidate left

        max_set[i] = best
        i += 1

        # Residual = signal minus its projection onto the span of the chosen
        # rows; (x @ pinv(A)) @ A avoids forming the dim x dim projector.
        atoms = dictionary[max_set[:i]]
        residual = input_signal - (input_signal @ torch.linalg.pinv(atoms)) @ atoms

    # Clone so the stored slice does not keep the whole buffer alive when pickled.
    cached_index_set.append(max_set[:i].clone())
    return torch.norm(residual), i

In [None]:
residuals = []   # final residual norm per processed signal
lengths = []     # number of selected vectors per processed signal

# change this range to "parallelize" operations on Sherlock
start, stop = 0, embedding.shape[0]


def _save_omp_checkpoint(tag):
    """Pickle the index sets, residual norms and lengths gathered so far."""
    for prefix, payload in (
        ('wikitext_id_indices_', cached_index_set),
        ('wikitext_residual_distr_', residuals),   # for fixed index size
        ('wikitext_lengths_distr_', lengths),      # for fixed index size
    ):
        with open(prefix + str(tag), 'wb') as handle:
            pickle.dump(payload, handle)


for index in range(start, stop):
    print(index)
    res = orthogonal_matching_pursuit_hybrid(embedding[index][None,:], index)
    # Stored on the CPU so the pickle can be read on a machine without a GPU.
    residuals.append(res[0].cpu())
    lengths.append(res[1])

    if index % 500 == 0:
        _save_omp_checkpoint(index)

# Without this, everything after the last multiple of 500 would be lost.
# The file is tagged with the (exclusive) end of the range.
if stop > start and (stop - 1) % 500 != 0:
    _save_omp_checkpoint(stop)

# "Data Processing" Step after OMP; idiosyncratic to Sherlock batch processing

In [None]:
# Put all batch files in OMP into one folder and index into it; look through the files to make sure they exist. Might need to remount / refresh Drive.
# Code runs through all file names and concatenates intermediate OMP files.
# Before Use, rename files. 3 Files from previous pipeline: Index Set, Lengths, Residuals / Appx Error. Compile all 3.

import os

with open('/content/drive/My Drive/curis22/rand_word_embed_wikitext_1400.pt', 'rb') as file1:
  embedding = pickle.load(file1)
embedding = torch.reshape(embedding, (-1, 768))
shape = list(embedding.shape)
row, col = shape

batch_prefix = 'wikitext_lengths_distr_high_res_high_len_rand_'


def _load_batch(tag):
  """Unpickle the intermediate file whose name ends in ``tag``."""
  with open(batch_prefix + str(tag), 'rb') as handle:
    return pickle.load(handle)


# Report every expected checkpoint (one per 500 rows) that is missing.
file_num = 0
while file_num <= row:
  if not os.path.exists(batch_prefix + str(file_num)):
    print(str(file_num) + ' hasn\'t been added!')
  file_num += 500
print(file_num)

compiled = []

# A checkpoint is kept only when the next one is not longer: while the lists
# are still growing, the next file already contains everything in this one.
# Each file is read once and carried over to the next comparison.
res = _load_batch(0)
for file_num in range(0, 90000 - 500, 500):
  next_res = _load_batch(file_num + 500)
  if len(res) >= len(next_res):
    compiled.extend(res)
  res = next_res

# The two closing checkpoints are always appended.
for tail in (90000, 91454):
  compiled.extend(_load_batch(tail))

# Spectral Clustering

In [None]:
# Open prerequisite ".pt" file and list of indices

with open('/home/users/youngch/wiki_ind_highlenres_0.2', 'rb') as file:
  data_inds = pickle.load(file)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

with open('rand_word_embed_wikitext_1400.pt', 'rb') as file1:
  ds = pickle.load(file1)

# One 768-dimensional row per token, scaled to unit Euclidean norm.
whole_embed = torch.reshape(ds, (-1, 768))
whole_embed = whole_embed / torch.linalg.norm(whole_embed, dim=1, keepdim=True)
# The embeddings deliberately stay on the CPU: the following cells hand them to
# SciPy sparse matrices and NumPy routines. (The earlier `whole_embed.to(device)`
# discarded its result, so it never moved the data either.)

shape = list(whole_embed.shape)
row, col = shape

# Sparse coefficient matrix; filled in row by row in the next cell.
c = lil_matrix((row, row), dtype=np.double)

In [None]:
# ith_dp = processing all linear combinations in data_inds in order, so it's the ith datapoint in our orig. dataset
# ind = the indices of the vectors in that element's linear combination
for ith_dp, subset in enumerate(data_inds):
  # access all the vectors in the index set from the parent embedding
  sub_index_set = whole_embed[subset]
  # transpose to obtain the right dimensions for matrix multiplication:
  # least-squares coefficients of the datapoint over its selected vectors
  ci = torch.linalg.pinv(sub_index_set.T) @ whole_embed[ith_dp]
  for j, ind in enumerate(subset):
    ind = int(ind)
    if ind == ith_dp:
      print("Oh no! An element has its own index in its linear combination!")
    # store a plain float; SciPy matrices do not take 0-d tensors reliably
    c[ith_dp, ind] = float(ci[j])

# Symmetric, non-negative affinity matrix.
W = (abs(c) + abs(c.T)).tocsr() # efficient storage
one_vec = np.ones(row)

# Number of non-zero affinities per row.
W.eliminate_zeros()
degree = W.getnnz(axis=1).tolist()

with open('wiki_rand_nonzero_degrees_normalized', 'wb') as file1:
  pickle.dump(degree, file1)

# Symmetric normalisation D^-1/2 W D^-1/2, where D holds the row sums of W.
# Rows without any affinity get 0 instead of an infinite scaling factor.
row_sums = np.asarray(W @ one_vec).ravel()
inv_sqrt = np.zeros_like(row_sums)
connected = row_sums > 0
inv_sqrt[connected] = row_sums[connected] ** -0.5
W1 = sparse.diags(inv_sqrt)
L = W1 @ (W @ W1)
L = abs(L)

with open('wiki_rand_whole_exp_Laplacian_normalized', 'wb') as file1:
  pickle.dump(L, file1)

# Reloading right away is redundant in a straight run; it is kept so the cell
# can be resumed from this point with the matrix saved above.
with open('wiki_rand_whole_exp_Laplacian_normalized', 'rb') as file1:
  L = pickle.load(file1)

cluster_num = 1440 # advisory square root of how many points we have
resolution_booster = 0

# One cluster label per datapoint.
labels = spectral_clustering(
        L,
        eigen_solver = 'amg',
        assign_labels = 'cluster_qr',
        n_clusters = (cluster_num + resolution_booster),
        )

# create a dictionary with labels & indices... reevaluate the index set with a new matrix with all elems in the same subspace and look at its rank...
cluster_mapping = {}
for i, label in enumerate(labels):
  cluster_mapping.setdefault(label, []).append(i)

with open('wiki_rand_cluster_assignments_normalized', 'wb') as file1:
  pickle.dump(cluster_mapping, file1)

In [None]:
# Matrix rank of the embeddings gathered in each cluster, in cluster-id order.
ranks = [
    np.linalg.matrix_rank(whole_embed[cluster_mapping[cluster_id]])
    for cluster_id in range(len(cluster_mapping))
]

with open('wiki_rand_rank_normalized' , 'wb') as file1:
  pickle.dump(ranks, file1)

# Projections

In [None]:
# Calculate effective singular values; graphing subplots is computationally taxing for large number of clusters; consider switching it off below

# Requires "labels" --> cluster assignments from spectral step (cluster id -> list of
# datapoint indices). To resume from Drive instead, unpickle
# 'wiki_rand_cluster_assignments_normalized_1440' into `labels`.

# spectral_clustering returns one label per datapoint; regroup it into
# cluster id -> member indices, which is what the code below indexes.
if not isinstance(labels, dict):
  grouped = {}
  for point, label in enumerate(labels):
    grouped.setdefault(int(label), []).append(point)
  labels = grouped

plot_each_cluster = True  # set to False to skip the one-figure-per-cluster plots

eff_singular_vals = []
percentages = []

for key in range(len(labels)):
  mat = whole_embed[labels[key]]
  # Only the singular values are needed here.
  s = torch.linalg.svdvals(mat)
  energy = s ** 2
  s = torch.cumsum(energy / energy.sum(), dim=0) # normalizing
  x_pt = np.linspace(0, 1, len(s))

  # Fallback when the curve never flattens (or has a single point): keep every
  # direction, so there is always exactly one entry per cluster.
  eff_rank = len(s)
  share = float(s[-1])

  if len(s) > 1:
    slopes = (s[1:] - s[:-1]) / float(x_pt[1]) # subtracts consecutive numbers
    flat = torch.nonzero(slopes <= 0.25)
    if len(flat) > 0:
      ind = int(flat[0])  # first position where the cumulative curve flattens
      eff_rank = ind + 1
      share = float(s[ind + 1])

  eff_singular_vals.append(eff_rank)
  percentages.append(share)

  if plot_each_cluster:
    fig, ax = plt.subplots()
    ax.plot(x_pt, s)
    ax.set_title('Singular Value Decomp at Cluster '+ str(key))

print(eff_singular_vals)
fig, ax = plt.subplots()
ax.hist(eff_singular_vals)
ax.set_title('Overall Effective Ranks')

fig, ax = plt.subplots()
ax.hist(percentages)
ax.set_title('Percentage Represented')

In [None]:
# Restore the per-cluster effective ranks saved by an earlier run.
with open('num_eff_singular_values_rand_1400', 'rb') as file1:
  eff_singular_vals = pickle.load(file1)

In [None]:
# For every cluster keep the leading right singular vectors, one row per
# effective dimension.
# NOTE(review): expects `labels` to map cluster id -> datapoint indices.
singular_vector_spaces = []
for key in range(len(labels)):
  mat = whole_embed[labels[key]]
  # Reduced SVD: only `vh` is used, so the square `u` of a full SVD is never built.
  _, _, vh = torch.linalg.svd(mat, full_matrices=False)
  # Scaling these rows by their singular values was tried and did not work so far.
  singular_vector_spaces.append(vh[:eff_singular_vals[key]])

In [None]:
def project_input_to_permutation(sample_input, subspaces=None, subspace_dims=None):
  """Assign a vector to its best-matching subspace and rank that subspace's directions.

  The input is projected onto every subspace basis; the subspace with the
  largest projection norm wins. Inside the winner, directions are ranked by
  absolute coefficient and receive the values 1 (largest) to k (smallest) in a
  vector that has one slot per direction of every subspace, laid out subspace
  after subspace. All other slots stay 0.

  Args:
    sample_input: vector of length dim (torch tensor or NumPy array).
    subspaces: list of (k_i, dim) basis tensors. Defaults to the module-level
      ``singular_vector_spaces``.
    subspace_dims: list with the k_i of every subspace. Defaults to the
      module-level ``eff_singular_vals``.

  Returns:
    ``(winner, perm_vector, largest_projections)``: the index of the winning
    subspace, the rank vector described above, and the projection norm for
    every subspace.
  """
  if subspaces is None:
    subspaces = singular_vector_spaces
  if subspace_dims is None:
    subspace_dims = eff_singular_vals

  # Accept NumPy rows as well and match the dtype of the bases.
  sample_input = torch.as_tensor(sample_input, dtype=subspaces[0].dtype).reshape(-1)

  # Norm of the projection onto each subspace (k x dim times dim).
  largest_projections = torch.tensor(
      [float(torch.linalg.norm(space @ sample_input)) for space in subspaces]
  )
  # just need the largest, subspace the input is most likely in
  proj = int(torch.argmax(largest_projections))

  subspace_proj = torch.abs(subspaces[proj] @ sample_input)
  subspace_sort = torch.argsort(subspace_proj, descending=True)
  # The winner's slots start right after those of all preceding subspaces.
  subspace_sort = subspace_sort + int(sum(subspace_dims[:proj]))

  # find these k vectors in our space and give them values 1 through k, rest has value 0. Easily change to multiprobe by changing k and num_elem.
  num_elem = int(subspace_dims[proj])
  perm_vector = torch.zeros(int(sum(subspace_dims)))
  perm_vector[subspace_sort] = torch.arange(1, num_elem + 1, dtype=perm_vector.dtype)

  return proj, perm_vector, largest_projections

# Example of a cluster OOD retrieval task

In [None]:
# Project one stored embedding and list every token that shares its cluster.
i1 = whole_embed[2000]
proj1, v1, p1 = project_input_to_permutation(i1)
print('The current embed is ' + str(tokens[2000]))

tokens_in_clus = []
for member in labels[proj1]:
  member_token = tokens[member]
  tokens_in_clus.append(member_token)
  print(member_token)
print(len(tokens_in_clus))

# Second projection, to compare the two rank vectors.
i2 = whole_embed[2000]
print('The current embed is ' + str(tokens[2000]))
proj2, v2, p2 = project_input_to_permutation(i2)

print(v1, v2)
print(torch.dot(v1, v2))

# Distribution of the projection norms over all clusters for the first input.
plt.hist(p1)

In [None]:
example = 'I try to gauge of the power of women in politics with my own political index' # comparison: "ACE is , broadly speaking , a measure of the power of the hurricane multiplied by the length of time it existed , so storms that last a long time..."
sen = torch.FloatTensor(pipeline(example))
# Scale every token vector to unit norm: the norm runs over the last (feature)
# axis, not over the token axis.
sen = sen / torch.linalg.norm(sen, dim=-1, keepdim=True)
x,y,z = sen.shape

token_ids = tokenizer(example)['input_ids']
decoding = tokenizer.batch_decode(token_ids)
print(decoding)

# Keyed by token string: a token that occurs twice keeps its last occurrence.
tokens_in_clus = {}
perm_vectors = {}
for i, word in enumerate(decoding):
  if word in ('<s>', '</s>'):
    continue
  proj, v, p = project_input_to_permutation(sen[0][i])
  tokens_in_clus[word] = [tokens[member] for member in labels[proj]]
  perm_vectors[word] = v

print(tokens_in_clus)
for word, members in tokens_in_clus.items():
  print(word, members)

print(perm_vectors)

# comparing the embeddings for closely across two contexts
print(v1)
print(perm_vectors[' power'])
print(torch.nonzero(v1))
print(torch.nonzero(perm_vectors[' power']))
print(torch.dot(v1, perm_vectors[' power']))

print(tokens_in_clus[' power'])

# Benchmarking


In [None]:
conda install -c conda-forge faiss

In [None]:
import faiss

In [None]:
# Load in all relevant files; redirect to personal file path
# This does not work on Colab!!!

def _load_pickle(file_path):
  """Unpickle one file, closing it even if loading fails."""
  with open(file_path, 'rb') as handle:
    return pickle.load(handle)


data_inds = _load_pickle('/Users/youngchen/Downloads/wiki_inds_normalized')
tokens = _load_pickle('/Users/youngchen/Downloads/random_wiki_sentences_1400_tokens')

embedding = _load_pickle('/Users/youngchen/Downloads/rand_word_embed_wikitext_1400.pt')
# One 768-dimensional row per token, scaled to unit Euclidean norm; the row
# count is inferred from the data.
whole_embed = torch.reshape(embedding, (-1, 768))
whole_embed = whole_embed / torch.linalg.norm(whole_embed, dim=1, keepdim=True)

labels = _load_pickle('/Users/youngchen/Downloads/wiki_rand_cluster_assignments_normalized_1440')
eff_singular_vals = _load_pickle('/Users/youngchen/Downloads/num_eff_singular_values_rand_1400')

In [None]:
nb = whole_embed.shape[0]      # database size
nq = 100                       # nb of queries (the cells below loop over 100 queries)
# float32 and C-contiguous, enforced explicitly for the index.
xb = np.ascontiguousarray(whole_embed.numpy(), dtype=np.float32)

# Queries are drawn from the database itself; `indice` remembers which rows.
indice = torch.tensor(random.sample(range(nb), nq))
xq = np.ascontiguousarray(whole_embed[indice].numpy(), dtype=np.float32)

index = faiss.IndexFlatL2(xb.shape[1])   # build the index
print(index.is_trained)
index.add(xb)                  # add vectors to the index
print(index.ntotal)

In [None]:
k = 30                       # we want to see the k nearest neighbors

# Sanity check on the first ten database rows: only the result shapes are shown.
sanity_D, sanity_I = index.search(xb[:10], k)
print(sanity_I.shape)
print(sanity_D.shape)

# Actual search: D holds the distances, I the neighbor indices per query.
D, I = index.search(xq, k)
print(I)                   # neighbors of the queries

In [None]:
# For every query, the members of the cluster it is projected into.
nearest = {}

for i, i1 in enumerate(xq):
    # xq holds NumPy rows; hand the projection routine a tensor.
    proj1, v1, p1 = project_input_to_permutation(torch.as_tensor(i1))
    # Query i is database row indice[i], so that is the token to report.
    print('The current embed is ' + str(tokens[int(indice[i])]))

    nearest[i] = labels[proj1]

print(nearest)

In [None]:
# accuracy metric ; if one of the 30-NNs are in the cluster, count it as "accurate"
# Each query contributes at most one hit, so the result stays within [0, 1].
# NOTE(review): the queries are drawn from the database, so each query is its
# own nearest neighbor; consider excluding it before counting.
right_neighbors = 0

for query_id, cluster_members in nearest.items():
    neighbors_faiss = set(I[query_id].tolist())
    if any(int(member) in neighbors_faiss for member in cluster_members):
        right_neighbors += 1

accuracy = right_neighbors / len(nearest) if nearest else 0.0
print(accuracy)