-
Notifications
You must be signed in to change notification settings - Fork 3
Importance sampling
Prajjwal Bhargava edited this page on Jun 26, 2020 · 1 revision
import torch
from fluence.sampling.clustering import Clustering_Arguments, Clustering_Processor
# Clustering configuration. The argument class follows the style of
# Hugging Face's TrainingArguments.
clustering_args = Clustering_Arguments(
    batch_size=32,
    num_clusters_elements=32,
    embedding_path="/home/nlp/experiments/cls_embeddings_mnli.pth",
    num_clusters=8,
    cluster_output_path="/home/nlp/experiments/tmp/c.pth",
)
# The processor is given the arguments as a plain dict (vars() of the object
# built above -- the original snippet referenced an undefined `clustering_obj`).
clustering_proc = Clustering_Processor(vars(clustering_args))

# Embeddings used to size the percentage-based sample below.
# NOTE(review): assumes the .pth file stores a tensor with one row per
# training example -- confirm against how the embeddings were saved.
embeddings = torch.load(clustering_args.embedding_path)

# Pick ONE of the sampling strategies below. Each one reassigns
# `cluster_indices`, so only the last assignment reaches the Subset.

# (a) By data percentage.
# NOTE(review): `data_pct` is not passed to Clustering_Arguments above;
# confirm it has a sensible default, or set it explicitly.
cluster_indices = clustering_proc.get_cluster_indices_by_pct(
    clustering_args.data_pct, embeddings.shape[0]
)

# (b) By a fixed number of elements (`num_clusters_elements`).
cluster_indices = clustering_proc.get_cluster_indices_by_num(
    clustering_args.num_clusters_elements
)

# (c) By centroids.
# NOTE(review): the original snippet repeated the get_cluster_indices_by_num
# call from (b) verbatim here, which does not perform centroid-based
# sampling. Look up the centroid-based method in fluence.sampling.clustering.

# Restrict the training set to the selected indices. `GlueDataset`,
# `data_args` and `tokenizer` presumably come from the surrounding Hugging Face
# transformers training setup (not shown here).
train_dataset = GlueDataset(data_args, tokenizer)
train_dataset = torch.utils.data.Subset(train_dataset, cluster_indices)