
Importance sampling


Sampling with Clustering

import torch
from transformers import GlueDataset

from fluence.sampling.clustering import Clustering_Arguments, Clustering_Processor

# Similar to Hugging Face's TrainingArguments
clustering_args = Clustering_Arguments(
    batch_size=32,
    num_clusters_elements=32,
    embedding_path="/home/nlp/experiments/cls_embeddings_mnli.pth",
    num_clusters=8,
    cluster_output_path="/home/nlp/experiments/tmp/c.pth",
)

# The embeddings that were clustered; assumed here to be a tensor saved at `embedding_path`
embeddings = torch.load(clustering_args.embedding_path)

# `clustering_obj` is the fitted clustering object produced by the clustering step
# (its construction is not shown on this page); the processor reads its attributes
clustering_proc = Clustering_Processor(vars(clustering_obj))

# Indices can now be selected in several ways: by data percentage,
# by number of cluster elements, or from the cluster centroids.

# By data percentage
cluster_indices = clustering_proc.get_cluster_indices_by_pct(
    clustering_args.data_pct, embeddings.shape[0]
)

# By number of cluster elements
cluster_indices = clustering_proc.get_cluster_indices_by_num(
    clustering_args.num_clusters_elements
)

# `data_args` (a GlueDataTrainingArguments instance) and `tokenizer` come from the usual
# Hugging Face GLUE setup; only the cluster-selected examples are kept for training
train_dataset = GlueDataset(data_args, tokenizer)
train_dataset = torch.utils.data.Subset(train_dataset, cluster_indices)
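
For context, here is a minimal sketch of feeding the sub-sampled dataset to a standard Hugging Face Trainer. The model and output directory below are illustrative assumptions, not part of fluence:

from transformers import Trainer, TrainingArguments

# Hypothetical training setup: `model` is any sequence-classification model loaded elsewhere;
# only the cluster-sampled subset is passed as the training data
training_args = TrainingArguments(output_dir="/tmp/mnli_cluster_subset")
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.train()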