In [31]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import DBSCAN

from transformers import AutoTokenizer
from transformers import GlueDataset
from transformers import GlueDataTrainingArguments as DataTrainingArguments
from transformers import (HfArgumentParser, Trainer, TrainingArguments,
                          glue_compute_metrics, glue_output_modes,
                          glue_tasks_num_labels, set_seed)

import torch

In [2]:
# Serialized embeddings produced upstream (presumably [CLS]-token vectors, per the
# filename). NOTE(review): absolute machine-specific path — consider a configurable DATA_DIR.
EMBEDDING_PATH = "/home/nlp/experiments/cls_embeddings_mnli.pth"

In [3]:
# Load the saved embeddings; the next cell shows a list whose first element has
# shape (512, 768). NOTE(review): torch.load unpickles the file, so only load trusted
# files. On newer torch releases `weights_only` defaults to True, which may reject a
# pickled list of NumPy arrays — pass weights_only=False explicitly there if needed.
cls_embeddings = torch.load(f=EMBEDDING_PATH)

In [4]:
# (shape of the first saved batch, number of saved batches)
(cls_embeddings[0].shape, len(cls_embeddings))

((512, 768), 767)

In [5]:
# Stack the per-batch arrays into a single (n_rows, hidden) matrix along axis 0.
# Guarded so re-running the cell is a no-op: np.concatenate applied to an
# already-stacked 2-D array would treat its rows as the sequence and join them
# into one long 1-D vector, silently corrupting every downstream cell.
if not isinstance(cls_embeddings, np.ndarray):
    cls_embeddings = np.concatenate(cls_embeddings)

In [6]:
# Rows should equal (number of batches) x (rows per batch) after stacking.
np.shape(cls_embeddings)

(392704, 768)

In [8]:
# Shape of the first 1024 rows (presumably a candidate sample for a quicker trial run).
cls_embeddings[0:1024].shape

(1024, 768)

In [64]:
# DBSCAN hyper-parameters; eps is a radius under DBSCAN's default (euclidean) metric
# in the raw embedding space, min_samples is the neighbourhood size for a core point.
DBSCAN_EPS = 0.2
DBSCAN_MIN_SAMPLES = 50

# NOTE(review): this fits on every stacked embedding row and is expensive; consider
# caching clustering.labels_ with np.save so Restart & Run All stays feasible.
# NOTE(review): the recorded outputs look degenerate. labels_ appears to run
# 0, 1, 2, ..., 509, 510, 511 (the per-batch row count), core_sample_indices_ appears
# to cover every row, and the selected subset (39117 rows) is exactly 51 clusters of
# 767 rows (the number of saved batches). That is what identical saved batches would
# produce — verify the upstream embedding dump before trusting these clusters.
clustering = DBSCAN(eps=DBSCAN_EPS, min_samples=DBSCAN_MIN_SAMPLES).fit(cls_embeddings)

In [65]:
# Cluster label per embedding row (DBSCAN marks noise points with -1).
np.asarray(clustering.labels_)

array([  0,   1,   2, ..., 509, 510, 511])

In [30]:
# Row positions DBSCAN classified as core samples.
np.asarray(clustering.core_sample_indices_)

array([     0,      1,      2, ..., 392701, 392702, 392703])

In [57]:
# GLUE task selection and data location, consumed by GlueDataset below.
data_args = DataTrainingArguments(
    task_name='MNLI',
    data_dir='/home/nlp/data/glue_data/MNLI',
)
# NOTE(review): training_args is not referenced by any cell shown here.
training_args = TrainingArguments(
    do_eval=True,
    output_dir='/home/nlp/experiments/trial',
)

In [52]:
# Tokenizer GlueDataset uses to featurize the examples.
# NOTE(review): confirm 'albert-base-v1' matches the model that produced cls_embeddings.
tokenizer = AutoTokenizer.from_pretrained('albert-base-v1')

In [90]:
# Featurized MNLI dataset; no mode is passed, so GlueDataset's default split is used.
# Its positional indices are what cluster_indices later selects from.
dset = GlueDataset(
    data_args,
    tokenizer,
)

In [97]:
def get_cluster_indices(cluster_num, labels_array):
    """Return the positions whose label equals ``cluster_num``.

    Args:
        cluster_num: Label value to look up.
        labels_array: NumPy array of labels compared element-wise to ``cluster_num``.

    Returns:
        np.ndarray: Integer positions (along the first axis) of matching labels.
    """
    is_member = labels_array == cluster_num
    return np.nonzero(is_member)[0]

In [155]:
def get_coupled_cluster_indices(labels, data_pct, original_len, exclude_labels=()):
    """Greedily collect whole clusters until a fraction of the dataset is reached.

    Clusters are visited in ascending label order. Each cluster is taken in its
    entirety while the running total stays within ``int(original_len * data_pct)``;
    the first cluster that would overflow that budget stops the scan (later,
    possibly smaller clusters are not tried).

    Args:
        labels: 1-D array-like of per-sample cluster labels (e.g. ``DBSCAN.labels_``).
        data_pct: Target fraction of ``original_len`` to keep.
        original_len: Size of the dataset the budget is computed against.
        exclude_labels: Labels to skip entirely. DBSCAN marks noise as ``-1``;
            pass ``exclude_labels=(-1,)`` to keep noise points out of the subset.

    Returns:
        list[int]: Positions in ``labels`` of the selected samples, grouped by
        cluster. Always a list (possibly empty), never ``None``.
    """
    labels = np.asarray(labels)
    budget = int(original_len * data_pct)
    excluded = set(exclude_labels)
    cluster_indices, current_len = [], 0
    # np.unique returns sorted labels, so the visiting order is deterministic
    # (iterating a ``set`` gives no ordering guarantee).
    for cluster_num in np.unique(labels):
        if cluster_num in excluded:
            continue
        members = np.nonzero(labels == cluster_num)[0]
        # A cluster that lands exactly on the budget is still accepted.
        if current_len + len(members) > budget:
            break
        current_len += len(members)
        cluster_indices.extend(members.tolist())
    # Reached when every cluster fit (previously the function returned None here).
    return cluster_indices

In [160]:
# Keep whole DBSCAN clusters until roughly 10% of the dataset is covered.
cluster_indices = get_coupled_cluster_indices(
    labels=clustering.labels_,
    data_pct=0.1,
    original_len=len(dset),
)

In [161]:
# Number of selected samples (compare against the 10% budget computed below).
len(cluster_indices)

39117

In [162]:
# Validate the selection before wrapping it: Subset indexes lazily, so an invalid
# index would otherwise only surface later (e.g. inside a DataLoader).
# NOTE(review): the stacked embeddings have 392704 rows while the recorded outputs
# imply len(dset) == 392702, so label positions are not guaranteed to be valid
# dataset indices.
assert cluster_indices is not None, "get_coupled_cluster_indices returned None"
assert all(0 <= idx < len(dset) for idx in cluster_indices), (
    "cluster_indices contains positions outside the dataset; "
    "embedding rows and dataset examples are out of sync"
)
ds = torch.utils.data.Subset(dset, cluster_indices)

In [163]:
# Subset size; should equal len(cluster_indices).
len(ds)

39117

In [138]:
# The 10% budget before get_coupled_cluster_indices truncates it with int().
len(dset) * 0.1

39270.200000000004

In [165]:
# Previously saved label array (producer not shown here).
# NOTE(review): its tail (..., 333, 334, 335) differs from clustering.labels_ above
# (..., 509, 510, 511), so it comes from a different run — don't mix the two.
np.load(file='/home/nlp/experiments/cluster_labels.npy')

array([  0,   1,   2, ..., 333, 334, 335])