In [1]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import DBSCAN

from transformers import AutoTokenizer
from transformers import GlueDataset
from transformers import GlueDataTrainingArguments as DataTrainingArguments
from transformers import (HfArgumentParser, Trainer, TrainingArguments,
                          glue_compute_metrics, glue_output_modes,
                          glue_tasks_num_labels, set_seed)

import torch

In [2]:
# Saved CLS embeddings (MNLI, per the filename); read with torch.load in the next cell.
# NOTE(review): absolute machine-specific path — consider deriving it from a configurable base dir.
EMBEDDING_PATH = "/home/nlp/experiments/cls_embeddings_mnli.pth"

In [3]:
# torch.load unpickles the file, which can execute arbitrary code — only load from a trusted location.
cls_embeddings = torch.load(f=EMBEDDING_PATH)

In [4]:
# Quick look at the loaded object: shape of its first element and how many elements it holds.
cls_embeddings[0].shape, len(cls_embeddings)

((256, 768), 1534)

In [5]:
# Stack the loaded chunks into a single 2-D matrix (presumably one row per example).
# Guarded so that re-running this cell is a no-op: calling np.concatenate on the
# already-stacked 2-D array would treat each row as a piece and flatten it to 1-D.
if not (isinstance(cls_embeddings, np.ndarray) and cls_embeddings.ndim == 2):
    cls_embeddings = np.concatenate(cls_embeddings)

In [6]:
# Confirm the shape (rows, columns) of the stacked embedding matrix.
cls_embeddings.shape

(392704, 768)

In [7]:
# Exploratory check on a 1024-row slice; the slice is not stored or used afterwards.
cls_embeddings[:1024].shape

(1024, 768)

In [10]:
# Re-fitting DBSCAN in the notebook is disabled; labels are read from a precomputed file instead.
# Disabled fit: clustering = DBSCAN(eps = 0.2, min_samples = 50).fit(cls_embeddings)
# NOTE(review): provenance of cluster_labels.npy is not recorded here — confirm it came from the fit above.
cluster_labels = np.load(
    '/home/nlp/experiments/cluster_labels.npy'
)
class cluster_labels_patch:
    """Stand-in for a fitted clustering estimator, built from precomputed labels.

    Exposes only ``labels_`` (the attribute name a fitted sklearn DBSCAN uses),
    so cells written against ``clustering.labels_`` work without re-fitting.
    """

    def __init__(self, cluster_labels):
        # Kept by reference; the labels are not copied.
        self.labels_ = cluster_labels

# Wrap the loaded labels so later cells can read them as `clustering.labels_`.
clustering = cluster_labels_patch(cluster_labels=cluster_labels)

In [15]:
# Cluster labels carried by the stand-in object (loaded from cluster_labels.npy).
clustering.labels_

array([  0,   1,   2, ..., 253, 254, 255])

In [13]:
# clustering.core_sample_indices_  # disabled: `clustering` is a cluster_labels_patch that only exposes labels_

In [17]:
# GLUE data configuration for the MNLI task.
data_args = DataTrainingArguments(
    task_name='MNLI',
    data_dir='/home/nlp/data/glue_data/MNLI',
)
# NOTE(review): training_args is not used by any cell visible here.
training_args = TrainingArguments(output_dir='/home/nlp/experiments/trial', do_eval=True)

In [18]:
# NOTE(review): confirm this tokenizer matches the model that produced the CLS embeddings loaded above.
tokenizer = AutoTokenizer.from_pretrained('albert-base-v1')

In [19]:
# Tokenized GLUE dataset built from data_args (MNLI).
# NOTE(review): no split/mode is passed, so the library default is used — confirm it is the
# split the embeddings and cluster labels were computed on (their indices are used into dset below).
dset = GlueDataset(data_args, tokenizer)

In [20]:
def get_cluster_indices(cluster_num, labels_array):
    """Return the indices (along the first axis) whose label equals ``cluster_num``.

    Args:
        cluster_num: the cluster label to look for.
        labels_array: numpy array of per-example labels.

    Returns:
        numpy integer array of matching indices, in ascending order.
    """
    membership_mask = labels_array == cluster_num
    return np.nonzero(membership_mask)[0]

In [30]:
def get_coupled_cluster_indices_by_pct(labels, data_pct, original_len):
    """Select whole clusters, in ascending label order, until a size budget is reached.

    The budget is ``int(original_len * data_pct)`` examples. Clusters are taken
    greedily in ascending label order. Selection stops at the first cluster that
    would push the running total above the budget; later (possibly smaller)
    clusters are not considered.

    Args:
        labels: array-like of per-example cluster labels (assumed 1-D).
        data_pct: fraction of ``original_len`` to use as the budget.
        original_len: size of the full dataset the budget is relative to.

    Returns:
        list of example indices (numpy integers), grouped cluster by cluster
        and ascending within each cluster. Empty when ``labels`` is empty or
        the first cluster alone exceeds the budget.
    """
    labels = np.asarray(labels)  # accept plain lists as well as arrays
    budget = int(original_len * data_pct)
    cluster_indices, current_len = [], 0
    # np.unique yields labels in sorted order, so the selection is deterministic.
    # NOTE(review): a DBSCAN noise label (-1), if present, is treated as an ordinary
    # cluster and sorts first — confirm whether it should be excluded.
    for cluster_id in np.unique(labels):
        curr_cluster_indices = np.where(labels == cluster_id)[0]
        # A cluster that lands exactly on the budget is kept; only exceeding it stops selection.
        if current_len + len(curr_cluster_indices) > budget:
            break
        cluster_indices.extend(curr_cluster_indices)
        current_len += len(curr_cluster_indices)
    # Reached both on early stop and when every cluster fits within the budget.
    return cluster_indices

In [31]:
# Select whole clusters until roughly 10% of dset is covered.
cluster_indices = get_coupled_cluster_indices_by_pct(
    labels=clustering.labels_,
    data_pct=0.1,
    original_len=len(dset),
)

In [23]:
# Number of example indices selected by the 10%-budget cluster selection.
# NOTE(review): execution counts show this output (In [23]) predates the current
# definition (In [30]) and call (In [31]) — re-run before trusting the recorded value.
len(cluster_indices)

38350

In [24]:
# Restrict the GLUE dataset to the examples of the selected clusters.
# Subset only indexes into dset on access, so an out-of-range index would surface late;
# check it up front. NOTE(review): the embeddings hold 1534 * 256 = 392704 rows —
# confirm the cluster labels line up one-to-one with dset's examples.
assert not cluster_indices or max(cluster_indices) < len(dset), "cluster index out of range for dset"
ds = torch.utils.data.Subset(dset, cluster_indices)

In [25]:
# Size of the Subset; equals len(cluster_indices) at the time ds was built.
len(ds)

38350

In [32]:
# How many of the lowest-numbered cluster labels (0 .. num_clusters - 1) to pool in the next steps.
num_clusters = 16

In [39]:
def get_coupled_cluster_indices_by_num(labels, num_clusters):
    """Pool the example indices of clusters ``0 .. num_clusters - 1``.

    Returns a flat list ordered cluster by cluster (all of cluster 0, then
    cluster 1, ...), with each cluster's indices in ascending order. Labels
    outside that range are ignored.
    """
    return [
        example_idx
        for cluster_id in range(num_clusters)
        for example_idx in get_cluster_indices(cluster_id, labels)
    ]

In [40]:
# NOTE(review): this reassigns cluster_indices; `ds` was built from the earlier
# percentage-based selection and is not updated by this cell.
cluster_indices = get_coupled_cluster_indices_by_num(
    labels=clustering.labels_,
    num_clusters=num_clusters,
)

In [41]:
# Number of example indices in clusters 0 .. num_clusters - 1.
len(cluster_indices)

24544