In [19]:
import json
import torch
from transformers import AutoTokenizer, BertModel, LEDForConditionalGeneration
from sklearn.cluster import KMeans, AgglomerativeClustering
from sklearn.metrics import silhouette_samples, silhouette_score, pairwise_distances_argmin_min
from sklearn.preprocessing import MinMaxScaler
from sklearn.decomposition import PCA
from datasets import load_metric
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

In [2]:
# Load the augmented test set. It is indexed by integer position below, so it is
# presumably a JSON array of examples (each with a 'documents' field) — confirm.
# JSON is UTF-8 by spec; passing the encoding explicitly avoids depending on the
# platform's locale default (e.g. cp1252 on Windows), which can fail on non-ASCII text.
with open('augmented_test.json', 'r', encoding='utf-8') as f:
    augmented_dataset = json.load(f)

# Load the indices of the examples to analyse and select that subset.
# NOTE(review): `.values.flatten()` assumes the CSV holds a single column of indices;
# if it was written with a pandas index column, that column would be interleaved
# with the real indices here — confirm how cohere_sample_indices.csv was produced.
sample_indices = pd.read_csv('cohere_sample_indices.csv').values.flatten()
dataset_small = [augmented_dataset[i] for i in sample_indices]

In [3]:
# Load the BERT tokenizer and encoder from the same "bert-base-uncased" checkpoint.
# BertModel does not use the checkpoint's pretraining-head weights (cls.*), so the
# "Some weights ... were not used" warning printed by this cell is expected.
tokenizer, model = (
    AutoTokenizer.from_pretrained("bert-base-uncased"),
    BertModel.from_pretrained("bert-base-uncased"),
)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
def embed_documents(documents, max_length=512, text_tokenizer=None, encoder=None):
    """Embed each document as the mean of the encoder's last-layer token states.

    The average is taken over real tokens only (weighted by the attention mask),
    so a short document's embedding is not diluted by padding positions.

    Args:
        documents: iterable of document strings.
        max_length: maximum number of tokens per document; longer documents are truncated.
        text_tokenizer: tokenizer to use; defaults to the notebook-level ``tokenizer``.
        encoder: model to use; defaults to the notebook-level ``model``.

    Returns:
        list with one tensor per document, each of shape ``(1, hidden_size)``.
    """
    # Only evaluate the notebook globals when no override is given.
    tok = text_tokenizer if text_tokenizer is not None else tokenizer
    enc = encoder if encoder is not None else model

    doc_embeddings = []
    # Inference only: skipping autograd avoids building a graph per document.
    with torch.no_grad():
        for doc in documents:
            # Each document is encoded on its own, so padding would only add compute;
            # truncation keeps inputs within the model's position limit.
            inputs = tok(doc, max_length=max_length, truncation=True, return_tensors="pt")
            last_hidden_states = enc(**inputs).last_hidden_state  # (batch=1, seq_len, hidden)

            # Masked mean pooling: zero out non-attended positions and divide by the
            # number of real tokens (clamped so an empty mask cannot divide by zero).
            mask = inputs["attention_mask"].unsqueeze(-1).to(last_hidden_states.dtype)
            summed = (last_hidden_states * mask).sum(dim=1)
            counts = mask.sum(dim=1).clamp(min=1.0)
            doc_embeddings.append(summed / counts)

    return doc_embeddings

In [None]:
def analyze_clusters_kmeans(documents, k_max=10, reference_k=3, random_state=0):
    """Pick the KMeans cluster count that maximises the silhouette score.

    Documents are embedded with ``embed_documents`` and KMeans is fit for every
    k in ``range(2, min(k_max, len(documents)))``. The upper bound keeps
    k <= n_samples - 1, which ``silhouette_score`` requires.

    Args:
        documents: iterable of document strings.
        k_max: exclusive upper bound on k (default 10, i.e. k in 2..9).
        reference_k: k whose score is also returned for comparison (default 3).
        random_state: seed passed to KMeans for reproducible fits.

    Returns:
        ``(best_k, best_score, score_at_reference_k)``. The sentinels ``0`` / ``-2``
        (outside silhouette's [-1, 1] range) mean "not evaluated". This happens when
        the k range is empty (fewer than 3 documents), every fit was degenerate, or
        ``reference_k`` was not scored.
    """
    best_score = -2
    best_k = 0
    reference_score = -2

    documents = list(documents)
    upper_k = min(k_max, len(documents))
    if upper_k <= 2:
        # No k can be scored; skip the embedding forward passes entirely.
        return best_k, best_score, reference_score

    # Stack the per-document (1, hidden) embeddings into an (n_docs, hidden) matrix.
    doc_embeddings = torch.cat(embed_documents(documents)).detach().numpy()

    for k in range(2, upper_k):
        labels = KMeans(n_clusters=k, random_state=random_state, n_init=10).fit(doc_embeddings).labels_
        # Duplicate documents can leave KMeans with fewer distinct clusters than k;
        # silhouette_score raises ValueError when fewer than 2 labels are present.
        if len(np.unique(labels)) < 2:
            continue
        ss = silhouette_score(doc_embeddings, labels)
        if ss > best_score:  # strict '>' keeps the smallest k on ties
            best_score = ss
            best_k = k
        if k == reference_k:
            reference_score = ss

    return best_k, best_score, reference_score

# For every sampled example, run the KMeans silhouette sweep and record its
# document count, the best k, that k's score, and the score obtained at k=3.
num_docs, ks, scores, scores_k_3 = [], [], [], []

for data in tqdm(dataset_small):
    docs = data['documents']
    best_k, best_score, score_k_3 = analyze_clusters_kmeans(docs)

    num_docs.append(len(docs))
    ks.append(best_k)
    scores.append(best_score)
    scores_k_3.append(score_k_3)


In [15]:
# Make the top-level plots directory if it doesn't exist.
import os

# exist_ok=True makes this idempotent and avoids the check-then-create race of
# os.path.exists() followed by os.mkdir().
os.makedirs('plots', exist_ok=True)

In [None]:
# We're going to plot 2 bar charts, both keyed by the number of documents in an
# example: the average silhouette score, and the average chosen k.

# First, group the per-example results by document count. Examples for which no k
# could be evaluated (sentinel best_k == 0 with score -2) are skipped. Otherwise
# they would drag the averages outside silhouette's valid [-1, 1] range.
num_docs_to_scores = {}
num_docs_to_ks = {}

for doc_count, k, score in zip(num_docs, ks, scores):
    if k == 0:
        continue
    # setdefault(...).append mutates in place instead of rebuilding the list each time.
    num_docs_to_scores.setdefault(doc_count, []).append(score)
    num_docs_to_ks.setdefault(doc_count, []).append(k)

num_docs_to_avg_scores = {
    doc_count: sum(group) / len(group) for doc_count, group in num_docs_to_scores.items()
}
num_docs_to_avg_ks = {
    doc_count: sum(group) / len(group) for doc_count, group in num_docs_to_ks.items()
}


# Now we can plot the bar charts.
# makedirs also creates 'plots' itself if the directory cell was skipped.
os.makedirs('plots/kmeans', exist_ok=True)

# Bars are drawn from sorted lists rather than dict views. Lists work on every
# matplotlib version, and sorting gives a deterministic order.
score_doc_counts = sorted(num_docs_to_avg_scores)
fig, ax = plt.subplots()
ax.bar(score_doc_counts, [num_docs_to_avg_scores[n] for n in score_doc_counts])
ax.set_title('Average Silhouette Score for Number of Documents')
ax.set_xlabel('Number of Documents')
ax.set_ylabel('Average Silhouette Score')
fig.savefig('plots/kmeans/avg_silhoutte_score_num_docs.png')
plt.show()
plt.close(fig)

k_doc_counts = sorted(num_docs_to_avg_ks)
fig, ax = plt.subplots()
ax.bar(k_doc_counts, [num_docs_to_avg_ks[n] for n in k_doc_counts])
ax.set_title('Average K Value for Number of Documents')
ax.set_xlabel('Number of Documents')
ax.set_ylabel('Average K Value')
fig.savefig('plots/kmeans/avg_k_value_num_docs.png')
plt.show()
plt.close(fig)

# Next, plot the distribution of document counts.
fig, ax = plt.subplots()
if num_docs:
    # Document counts are integers. One unit-wide bin centred on each count avoids
    # the uneven grouping produced by matplotlib's default 10 equal-width bins.
    ax.hist(num_docs, bins=np.arange(min(num_docs), max(num_docs) + 2) - 0.5)
ax.set_title('Distribution of Number of Documents')
ax.set_xlabel('Number of Documents')
ax.set_ylabel('Frequency')
fig.savefig('plots/kmeans/distribution_num_docs.png')
plt.show()
plt.close(fig)


# Summary statistics for the plain-KMeans run.
# Sentinel results (score -2, meaning "no k could be evaluated") are excluded.
# Silhouette scores are only defined on [-1, 1], so averaging sentinels skews the mean.
os.makedirs('plots/kmeans', exist_ok=True)
valid_scores = [s for s in scores if -1 <= s <= 1]
valid_scores_k_3 = [s for s in scores_k_3 if -1 <= s <= 1]
avg_score = sum(valid_scores) / len(valid_scores) if valid_scores else float('nan')
avg_score_k_3 = sum(valid_scores_k_3) / len(valid_scores_k_3) if valid_scores_k_3 else float('nan')

with open('plots/kmeans/analysis.txt', 'w') as f:
    f.write(f"Avg silhoutte: {avg_score}\n")
    f.write(f"Avg silhoutte k=3: {avg_score_k_3}\n")

    f.write("Num docs, q1, q2, q3, min, max, var\n")

    # Quartiles, range and variance of the chosen k for each document count, in
    # ascending document-count order. The loop variable is deliberately NOT `ks`.
    # Reusing that name would overwrite the global per-example list and break a
    # re-run of the aggregation cell.
    for num_doc in sorted(num_docs_to_ks):
        group_ks = [k for k in num_docs_to_ks[num_doc] if k > 0]  # drop k=0 sentinels
        if not group_ks:
            continue
        q1, q2, q3 = np.quantile(group_ks, [0.25, 0.5, 0.75])
        var = np.var(group_ks)
        f.write(f"{num_doc}: {q1}, {q2}, {q3}, {min(group_ks)}, {max(group_ks)}, {var} \n")

In [None]:
def analyze_clusters_kmeans_pca(documents, pca_variance=0.99, k_max=10, reference_k=3, random_state=0):
    """Pick the KMeans cluster count that maximises silhouette on PCA-reduced embeddings.

    Documents are embedded with ``embed_documents`` and projected onto the principal
    components that together explain ``pca_variance`` of the variance. KMeans is
    then fit for every k in ``range(2, min(k_max, len(documents)))``. The upper bound
    keeps k <= n_samples - 1, which ``silhouette_score`` requires.

    Args:
        documents: iterable of document strings.
        pca_variance: fraction of variance to retain, passed as PCA ``n_components``.
        k_max: exclusive upper bound on k (default 10, i.e. k in 2..9).
        reference_k: k whose score is also returned for comparison (default 3).
        random_state: seed passed to KMeans for reproducible fits.

    Returns:
        ``(best_k, best_score, score_at_reference_k)``. The sentinels ``0`` / ``-2``
        (outside silhouette's [-1, 1] range) mean "not evaluated". This happens when
        the k range is empty (fewer than 3 documents), every fit was degenerate, or
        ``reference_k`` was not scored.
    """
    best_score = -2
    best_k = 0
    reference_score = -2

    documents = list(documents)
    upper_k = min(k_max, len(documents))
    if upper_k <= 2:
        # No k can be scored; skip the embedding forward passes and a degenerate PCA fit.
        return best_k, best_score, reference_score

    # Stack the per-document (1, hidden) embeddings into an (n_docs, hidden) matrix.
    doc_embeddings = torch.cat(embed_documents(documents)).detach().numpy()

    # A single fit_transform both fits and projects. A separate fit() first would
    # fit PCA twice for the same result.
    doc_embeddings = PCA(n_components=pca_variance).fit_transform(doc_embeddings)

    for k in range(2, upper_k):
        labels = KMeans(n_clusters=k, random_state=random_state, n_init=10).fit(doc_embeddings).labels_
        # Duplicate documents can leave KMeans with fewer distinct clusters than k;
        # silhouette_score raises ValueError when fewer than 2 labels are present.
        if len(np.unique(labels)) < 2:
            continue
        ss = silhouette_score(doc_embeddings, labels)
        if ss > best_score:  # strict '>' keeps the smallest k on ties
            best_score = ss
            best_k = k
        if k == reference_k:
            reference_score = ss

    return best_k, best_score, reference_score

# For every sampled example, run the PCA + KMeans silhouette sweep and record its
# document count, the best k, that k's score, and the score obtained at k=3.
num_docs, ks, scores, scores_k_3 = [], [], [], []

for data in tqdm(dataset_small):
    docs = data['documents']
    best_k, best_score, score_k_3 = analyze_clusters_kmeans_pca(docs)

    num_docs.append(len(docs))
    ks.append(best_k)
    scores.append(best_score)
    scores_k_3.append(score_k_3)


In [None]:
# We're going to plot 2 bar charts, both keyed by the number of documents in an
# example: the average silhouette score, and the average chosen k.

# First, group the per-example results by document count. Examples for which no k
# could be evaluated (sentinel best_k == 0 with score -2) are skipped. Otherwise
# they would drag the averages outside silhouette's valid [-1, 1] range.
num_docs_to_scores = {}
num_docs_to_ks = {}

for doc_count, k, score in zip(num_docs, ks, scores):
    if k == 0:
        continue
    # setdefault(...).append mutates in place instead of rebuilding the list each time.
    num_docs_to_scores.setdefault(doc_count, []).append(score)
    num_docs_to_ks.setdefault(doc_count, []).append(k)

num_docs_to_avg_scores = {
    doc_count: sum(group) / len(group) for doc_count, group in num_docs_to_scores.items()
}
num_docs_to_avg_ks = {
    doc_count: sum(group) / len(group) for doc_count, group in num_docs_to_ks.items()
}


# Now we can plot the bar charts.
# makedirs also creates 'plots' itself if the directory cell was skipped.
os.makedirs('plots/kmeans_pca', exist_ok=True)

# Bars are drawn from sorted lists rather than dict views. Lists work on every
# matplotlib version, and sorting gives a deterministic order.
score_doc_counts = sorted(num_docs_to_avg_scores)
fig, ax = plt.subplots()
ax.bar(score_doc_counts, [num_docs_to_avg_scores[n] for n in score_doc_counts])
ax.set_title('Average Silhouette Score for Number of Documents')
ax.set_xlabel('Number of Documents')
ax.set_ylabel('Average Silhouette Score')
fig.savefig('plots/kmeans_pca/avg_silhoutte_score_num_docs.png')
plt.show()
plt.close(fig)

k_doc_counts = sorted(num_docs_to_avg_ks)
fig, ax = plt.subplots()
ax.bar(k_doc_counts, [num_docs_to_avg_ks[n] for n in k_doc_counts])
ax.set_title('Average K Value for Number of Documents')
ax.set_xlabel('Number of Documents')
ax.set_ylabel('Average K Value')
fig.savefig('plots/kmeans_pca/avg_k_value_num_docs.png')
plt.show()
plt.close(fig)

# Next, plot the distribution of document counts.
fig, ax = plt.subplots()
if num_docs:
    # Document counts are integers. One unit-wide bin centred on each count avoids
    # the uneven grouping produced by matplotlib's default 10 equal-width bins.
    ax.hist(num_docs, bins=np.arange(min(num_docs), max(num_docs) + 2) - 0.5)
ax.set_title('Distribution of Number of Documents')
ax.set_xlabel('Number of Documents')
ax.set_ylabel('Frequency')
fig.savefig('plots/kmeans_pca/distribution_num_docs.png')
plt.show()
plt.close(fig)


# Summary statistics for the PCA + KMeans run.
# Sentinel results (score -2, meaning "no k could be evaluated") are excluded.
# Silhouette scores are only defined on [-1, 1], so averaging sentinels skews the mean.
os.makedirs('plots/kmeans_pca', exist_ok=True)
valid_scores = [s for s in scores if -1 <= s <= 1]
valid_scores_k_3 = [s for s in scores_k_3 if -1 <= s <= 1]
avg_score = sum(valid_scores) / len(valid_scores) if valid_scores else float('nan')
avg_score_k_3 = sum(valid_scores_k_3) / len(valid_scores_k_3) if valid_scores_k_3 else float('nan')

with open('plots/kmeans_pca/analysis.txt', 'w') as f:
    f.write(f"Avg silhoutte: {avg_score}\n")
    f.write(f"Avg silhoutte k=3: {avg_score_k_3}\n")

    f.write("Num docs, q1, q2, q3, min, max, var\n")

    # Quartiles, range and variance of the chosen k for each document count, in
    # ascending document-count order. The loop variable is deliberately NOT `ks`.
    # Reusing that name would overwrite the global per-example list and break a
    # re-run of the aggregation cell.
    for num_doc in sorted(num_docs_to_ks):
        group_ks = [k for k in num_docs_to_ks[num_doc] if k > 0]  # drop k=0 sentinels
        if not group_ks:
            continue
        q1, q2, q3 = np.quantile(group_ks, [0.25, 0.5, 0.75])
        var = np.var(group_ks)
        f.write(f"{num_doc}: {q1}, {q2}, {q3}, {min(group_ks)}, {max(group_ks)}, {var} \n")

In [None]:
    # closest, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_, doc_embeddings)
