In [33]:
# import the necessary libraries
import argparse
import datetime
import time
import logging
import math
import os
import sys
import random
import datasets
import torch
from torch.optim import AdamW
from datasets import load_dataset, load_from_disk, concatenate_datasets
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers
from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import set_seed, broadcast_object_list
from transformers import (
    BertConfig,
    BertTokenizerFast,
    BertForPreTraining,
    DataCollatorForLanguageModeling,
    DataCollatorWithPadding,
    SchedulerType,
    get_scheduler,
)
from transformers.utils.versions import require_version
from selectionstrategies import SubmodStrategy
from accelerate import InitProcessGroupKwargs
from selectionstrategies.helper_fns import taylor_softmax_v1
import numpy as np
import pickle
import faiss

os.environ['TOKENIZERS_PARALLELISM']='true'
device = "cuda" if torch.cuda.is_available() else "cpu"

In [34]:
# Create a logger
logger = logging.getLogger(__name__)

# Set the logging level
logger.setLevel(logging.INFO)

# Create a file handler
file_handler = logging.FileHandler('logfile.log')

# Create a formatter
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Set the formatter for the file handler
file_handler.setFormatter(formatter)

# Add the file handler to the logger
logger.addHandler(file_handler)

# Log a message
logger.info('Logger created')

In [35]:
# Variables 
dataset_name = "Salesforce/wikitext"
dataset_config_name = "wikitext-2-raw-v1"
validation_split_percentage = 80
model_config_name = "google-bert/bert-base-uncased"
tokenizer_name = "bert-base-uncased"
use_slow_tokenizer = False  # Bool
num_workers = None # (int)
max_seq_len = 128
short_seq_prob = 0.1
nsp_probability = 0.1
batch_size = 1000

In [36]:
# Get and Preprocess the dataset for the task.
raw_datasets = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1")

if 'validation' not in raw_datasets.keys():
    raw_datasets=raw_datasets["train"].train_test_split(test_size=(validation_split_percentage/100), shuffle=False)
    raw_datasets=datasets.DatasetDict({"train": raw_datasets["train"], "validation": raw_datasets["test"]})

In [37]:
raw_datasets

DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 4358
    })
    train: Dataset({
        features: ['text'],
        num_rows: 36718
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 3760
    })
})

In [38]:
# # Uncomment to use custom config

# config = BertConfig(
#     vocab_size=vocab_size,
#     hidden_size=hidden_size,
#     num_hidden_layers=num_hidden_layers,
#     num_attention_heads=num_attention_heads,
#     intermediate_size=intermediate_size,
#     hidden_act="gelu",
#     hidden_dropout_prob=0.1,
#     attention_probs_dropout_prob=0.1,
#     max_position_embeddings=512,
#     type_vocab_size=2,
#     initializer_range=0.02,
#     layer_norm_eps=1e-12,
#     position_embedding_type="absolute",
# )

In [39]:
# Create and instance of the model along with its tokenizer

# Tokenizer
tokenizer = BertTokenizerFast.from_pretrained(tokenizer_name, use_fast= not use_slow_tokenizer)

# Load the model
config = BertConfig.from_pretrained(model_config_name)

# Instantiating the model
model = BertForPreTraining(config)

# Resizing the token embeddings to fit the tokenizer
model.resize_token_embeddings(len(tokenizer))

Embedding(30522, 768, padding_idx=0)

In [40]:
# Tokenize and group the data based on the kind of model

column_names=raw_datasets['train'].column_names
text_column_name="text" if "text" in column_names else column_names[0]

def tokenize_function(examples):
    return tokenizer(examples[text_column_name])

tokenized_dataset = raw_datasets.map(
    tokenize_function,
    batched=True,
    num_proc=num_workers, 
    remove_columns=column_names,
    desc="Running tokenizer on every text in dataset"
)

# Grouping the data 
from experiment_utils import group_texts

train_dataset = tokenized_dataset["train"]
eval_dataset = tokenized_dataset["validation"]

train_dataset = train_dataset.map(
    group_texts, 
    fn_kwargs={'split': 'train', 'tokenizer':tokenizer, 'max_seq_length': max_seq_len, 
               'short_seq_prob':short_seq_prob, 'nsp_probability':nsp_probability, 'tokenized_datasets':tokenized_dataset},
    batched=True,
    batch_size=batch_size,
    num_proc=num_workers,
    with_indices=True,
    desc=f"Grouping Train texts into chucks of {max_seq_len}"
)

eval_dataset = eval_dataset.map(
    group_texts, 
    fn_kwargs={'split': 'validation', 'tokenizer':tokenizer, 'max_seq_length': max_seq_len, 
               'short_seq_prob':short_seq_prob, 'nsp_probability':nsp_probability, 'tokenized_datasets':tokenized_dataset},
    batched=True,
    batch_size=batch_size,
    num_proc=num_workers,
    with_indices=True,
    desc=f"Grouping Validation texts into chucks of {max_seq_len}"
)

In [41]:
eval_dataset

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask', 'next_sentence_label'],
    num_rows: 1196
})

In [42]:
tokenizer.decode(eval_dataset[9]['input_ids'])

'[CLS] lobster pots, although lines baited with octopus or cuttlefish sometimes succeed in tempting them out, to allow them to be caught in a net or by hand. in 2008, 4 @, @ 386 t of h. gammarus were caught across europe and north africa, of which 3 @, @ 462 t ( 79 % ) was caught in the british isles ( including the channel islands ). the minimum landing size for h. gammarus is a carapace length of 87 mm ( 3 @. @ 4 [SEP] aquaculture systems for h. gammarus are under development, and production rates are still very low. [SEP]'

In [43]:
# Prepare the data
prepared_data = datasets.DatasetDict({"train": train_dataset, "validation": eval_dataset})
dataset=prepared_data['train']

def extract_first_sentences(examples):
    for i, input_ids in enumerate(examples["input_ids"]):
        idx=input_ids.index(tokenizer.sep_token_id)
        examples["input_ids"][i]=input_ids[:idx+1]
        examples["attention_mask"][i]=examples["attention_mask"][i][:idx+1]
        examples["token_type_ids"][i]=examples["token_type_ids"][i][:idx+1]
        examples["special_tokens_mask"][i]=examples["special_tokens_mask"][i][:idx+1]
    return examples

# Separate the data into those that have the next sentence labels and those that do not.
nsp_zero=dataset.filter(lambda examples: [x==0 for x in examples["next_sentence_label"]], batched=True, num_proc=num_workers, keep_in_memory=True)
nsp_one=dataset.filter(lambda examples: [x==1 for x in examples["next_sentence_label"]], batched=True, num_proc=num_workers, keep_in_memory=True)

# Extract the first sentences from both datasets
first_sent_nsp_zero=nsp_zero.map(extract_first_sentences, batched=True, num_proc=num_workers, remove_columns=["next_sentence_label", "special_tokens_mask"], keep_in_memory=True)
first_sent_nsp_one=nsp_one.map(extract_first_sentences, batched=True, num_proc=num_workers, remove_columns=["next_sentence_label", "special_tokens_mask"], keep_in_memory=True)

Filter:   0%|          | 0/12198 [00:00<?, ? examples/s]

Filter:   0%|          | 0/12198 [00:00<?, ? examples/s]

Map:   0%|          | 0/6467 [00:00<?, ? examples/s]

Map:   0%|          | 0/5731 [00:00<?, ? examples/s]

In [44]:
tokenizer.decode(first_sent_nsp_one[0]['input_ids'])

'[CLS] of movement limited by their action gauge. up to nine characters can be assigned to a single mission. during gameplay, characters will call out if something happens to them, such as their health points ( hp ) getting low or being knocked out by enemy attacks. each character has specific " potentials ", skills unique to each character. they are divided into " personal potential ", which are innate skills that remain unaltered unless otherwise dictated by the story and can either help or impede a character, and " battle potentials ", which are grown throughout the game and always grant boons to a character. to learn battle [SEP]'

In [45]:
subset_fraction = 0.25

# Sample a subset of the train dataset
num_samples = int(round(len(train_dataset) * subset_fraction, 0))
init_subset_indices = [random.sample(list(range(len(train_dataset))), num_samples)]
full_dataset=train_dataset
subset_dataset = full_dataset.select(init_subset_indices[0])

In [46]:
first_sent_nsp_zero

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 6467
})

In [47]:
train_dataset

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask', 'next_sentence_label'],
    num_rows: 12198
})

In [48]:
subset_dataset

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask', 'next_sentence_label'],
    num_rows: 3050
})

In [49]:
mlm_probability = 0.15
train_batch_size = 4
eval_batch_size = 4

# Create datacollators
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=mlm_probability)
data_collator_embd = DataCollatorWithPadding(tokenizer=tokenizer)

# warmstart dataloader (train on all the train dataset during warmup)
warmstart_dataloader = DataLoader(train_dataset.remove_columns(['special_tokens_mask']), shuffle=True, collate_fn=data_collator, batch_size=train_batch_size)

# first sent nsp zero dataloader
first_sent_nsp_zero_dataloader=DataLoader(first_sent_nsp_zero, shuffle=False, collate_fn=data_collator_embd, batch_size=eval_batch_size)

# first sent nsp one  dataloader
first_sent_nsp_one_dataloader=DataLoader(first_sent_nsp_one, shuffle=False, collate_fn=data_collator_embd, batch_size=eval_batch_size)

# subset dataloader (train)
subset_dataloader=DataLoader(subset_dataset.remove_columns(['special_tokens_mask']), shuffle=True, collate_fn=data_collator, batch_size=train_batch_size,)

# eval dataloader (validation & testing)
eval_dataloader=DataLoader(eval_dataset.remove_columns(['special_tokens_mask']), collate_fn=data_collator, batch_size=eval_batch_size)

In [50]:
learning_rate = 1e-4
scheduler_name = 'linear'
num_warmup_steps = 10
num_training_steps = 10

# Optimizer
optimizer = AdamW(model.parameters(), lr=learning_rate)

# learning scheduler
lr_scheduler = get_scheduler(
    name=scheduler_name,
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps
)

In [51]:
warmer = iter(eval_dataloader)
X= next(warmer)
X['input_ids'].shape

torch.Size([4, 128])

In [52]:
warmstart_epochs = 1
completed_steps = 0

model.to(device)
# Warmstart the model: Train the model with the warmstart data for warmstart epochs
for epoch in range(warmstart_epochs):
    if epoch==0:
        print("Begining warmstart")
    model.train() # Setup the model for training
    for step, batch in enumerate(warmstart_dataloader):
        outputs = model(**batch)
        loss = outputs.loss
        print(f"Completed Steps: {1+completed_steps}; Loss: {loss.detach().float()}; lr: {lr_scheduler.get_last_lr()};")
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        completed_steps += 1
        if completed_steps >= num_warmup_steps:
            break

    # model.eval()
    # losses = []
    # for step, batch in enumerate(eval_dataloader):
    #     with torch.no_grad():  # Use torch.no_grad() instead of inference_mode()
    #         outputs = model(**batch)
        
    #     loss = outputs.loss
    #     losses.append(loss.view(1))  # Add the loss as a 1-dimensional tensor

    # if losses:
    #     losses = torch.cat(losses)
    #     losses = losses[:len(eval_dataset)]
    #     try:
    #         perplexity = math.exp(torch.mean(losses).item())
    #     except OverflowError:
    #         perplexity = float("inf")
    # else:
    #     perplexity = float("inf")
    
    # print(f"Epoch {epoch + 1}: Perplexity: {perplexity}")

# Plot both training & perplexity


Begining warmstart
Completed Steps: 1; Loss: 11.272369384765625; lr: [0.0];
Completed Steps: 2; Loss: 11.205646514892578; lr: [1e-05];
Completed Steps: 3; Loss: 11.192643165588379; lr: [2e-05];
Completed Steps: 4; Loss: 10.960800170898438; lr: [3e-05];
Completed Steps: 5; Loss: 10.243038177490234; lr: [4e-05];
Completed Steps: 6; Loss: 10.265896797180176; lr: [5e-05];
Completed Steps: 7; Loss: 10.641788482666016; lr: [6e-05];
Completed Steps: 8; Loss: 10.159079551696777; lr: [7e-05];
Completed Steps: 9; Loss: 10.350421905517578; lr: [8e-05];
Completed Steps: 10; Loss: 10.05018138885498; lr: [9e-05];


In [53]:
selection_strategy = 'fl'
num_partitions = 2000 # Default is 5000
partition_strategy = 'random'
submod_optimizer = 'LazyGreedy'

# Define subset selection strategies

subset_strategy = SubmodStrategy(logger, selection_strategy,
    num_partitions=num_partitions, partition_strategy=partition_strategy,
    optimizer=submod_optimizer, similarity_criterion='feature', 
    metric='cosine', eta=1, stopIfZeroGain=False, 
    stopIfNegativeGain=False, verbose=False, lambdaVal=1)

In [68]:
selection_strategy = 'fl'
layer_for_similarity_computation = 9
temperature = 0.5
seed = 23
parallel_processes = 3
# num_samples has already been defined when creating subset.
probs_nsp_zero=[]
greedyList_nsp_zero=[]
gains_nsp_zero=[]


# Begin subset selection for first_sent_nsp_zero
if selection_strategy == 'Random-Online':
    subset_indices_nsp_zero = [random.sample(list(range(len(first_sent_nsp_zero))), math.floor(num_samples/2))]
    subset_indices_nsp_one = [random.sample(list(range(len(first_sent_nsp_one))), math.ceil(num_samples/2))]
elif selection_strategy in ['fl', 'logdet', 'gc', 'disparity-sum']:
    # Choose a selection strategy
    model.eval() # Set the model in evaluation model 
    representations_nsp_zero=[]
    batch_indices_nsp_zero=[]
    total_cnt=0
    total_storage=0
    # Unwrap the model and set it in evaluation mode.
    print("Performing Subset selection for NSP class 0")
    for step, batch in enumerate(first_sent_nsp_zero_dataloader):
        with torch.no_grad():
            output = model(**batch, output_hidden_states=True)
        embeddings=output["hidden_states"][layer_for_similarity_computation]
        # print(f"Embeddings shape: {embeddings.shape}")
        mask=(batch['attention_mask'].unsqueeze(-1).expand(embeddings.size()).float())
        mask1=((batch['token_type_ids'].unsqueeze(-1).expand(embeddings.size()).float())==0)
        mask=mask*mask1
        mean_pooled=torch.sum(embeddings*mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
        total_cnt += mean_pooled.size(0)
        mean_pooled = mean_pooled.cpu()
        total_storage += sys.getsizeof(mean_pooled.storage())
        representations_nsp_zero.append(mean_pooled)
        # print(f"Current total representations: {len(representations_nsp_zero)}")
        
    print(f"Final number of representations: {len(representations_nsp_zero)}")
    representations_nsp_zero=torch.cat(representations_nsp_zero, dim=0)
    representations_nsp_zero=representations_nsp_zero[:len(first_sent_nsp_zero)]
    total_storage += sys.getsizeof(representations_nsp_zero.storage())
    representations_nsp_zero=representations_nsp_zero.numpy()
    print('Representations(NSP Class 0) Size: {}, Total number of samples: {}'.format(total_storage/(1024 * 1024), total_cnt))
    batch_indices_nsp_zero=list(range(len(first_sent_nsp_zero)))
    print('Length of indices: {}'.format(len(batch_indices_nsp_zero)))
    print('Representations(NSP Class 0) gathered. Shape of representations: {}. Length of indices: {}'.format(representations_nsp_zero.shape, len(batch_indices_nsp_zero)))

    partition_indices_nsp_zero, greedyIdx_nsp_zero, gains_nsp_zero = subset_strategy.select(len(batch_indices_nsp_zero)-1, 
                                                                                            batch_indices_nsp_zero, representations_nsp_zero, 
                                                                                            parallel_processes=parallel_processes, return_gains=True)
    subset_indices_nsp_zero = [[]]
    i=0
    for p in gains_nsp_zero:
        greedyList_nsp_zero.append(greedyIdx_nsp_zero[i:i+len(p)])         
        i+=len(p)
    probs_nsp_zero=[taylor_softmax_v1(torch.from_numpy(np.array([partition_gains])/temperature)).numpy()[0] for partition_gains in gains_nsp_zero]
    print(f"Taylor Softmax Prop: {probs_nsp_zero}")
    rng=np.random.default_rng(seed+completed_steps)
    for i, partition_prob in enumerate(probs_nsp_zero):
        print(f"{i}: Partition probablity :{partition_prob}")
        partition_budget=min(math.ceil((len(partition_prob)/len(batch_indices_nsp_zero)) * math.floor(num_samples/2)), len(partition_prob)-1)
        subset_indices_nsp_zero[0].extend(rng.choice(greedyList_nsp_zero[i], size=partition_budget, replace=False, p=partition_prob).tolist())

nsp_zero_subset_dataset=nsp_zero.select(subset_indices_nsp_zero[0])

# Using the list and the selection strategy, get the indices and the gains of each data point in the list.

Performing Subset selection for NSP class 0
Final number of representations: 1617
Representations(NSP Class 0) Size: 37.966644287109375, Total number of samples: 6467
Length of indices: 6467
Representations(NSP Class 0) gathered. Shape of representations: (6467, 768). Length of indices: 6467
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4

100%|██████████| 1617/1617 [00:00<00:00, 2466.44it/s][Iteration  [Iteration 33 of  of 33]] [Iteration 1 of 3] [Iteration 2 of 3]Iteration 3 of 3]33% [Iteration 1 of 3]3]33% [Iteration 1 of 3]% [Iteration 1 of 3]


Pool Parallel process ended.
Taylor Softmax Prop: [array([0.94880959, 0.02668135, 0.02450906]), array([0.95165489, 0.02464078, 0.02370433]), array([0.94954413, 0.02653621, 0.02391966]), array([0.94882799, 0.02660778, 0.02456423]), array([0.94829299, 0.02656814, 0.02513888]), array([0.94836044, 0.0266782 , 0.02496137]), array([0.95103823, 0.02472994, 0.02423184]), array([0.94582797, 0.02988716, 0.02428488]), array([0.95099425, 0.02524971, 0.02375603]), array([0.95136389, 0.02442303, 0.02421308]), array([0.94917176, 0.02628009, 0.02454815]), array([0.95055135, 0.02550436, 0.02394429]), array([0.94533809, 0.0292432 , 0.02541871]), array([0.95143518, 0.02480612, 0.0237587 ]), array([0.94846989, 0.02638538, 0.02514472]), array([0.94893239, 0.02652226, 0.02454534]), array([0.95133564, 0.02483183, 0.02383252]), array([0.94887474, 0.026608  , 0.02451726]), array([0.94879434, 0.02657607, 0.02462959]), array([0.95181536, 0.0245839 , 0.02360074]), array([0.94908767, 0.02649067, 0.02442166]), arra

In [74]:
gains_nsp_one=[]
probs_nsp_one=[]
greedyList_nsp_one=[]

if selection_strategy in ['fl', 'logdet', 'gc', 'disparity-sum']:
    # Choose a selection strategy
    model.eval() # Set the model in evaluation model 
    representations_nsp_one=[]
    batch_indices_nsp_one=[]
    total_cnt=0
    total_storage=0
    # Unwrap the model and set it in evaluation mode.
    print("Performing Subset selection for NSP class 1")
    for step, batch in enumerate(first_sent_nsp_one_dataloader):
        with torch.no_grad():
            output = model(**batch, output_hidden_states=True)
        embeddings=output["hidden_states"][layer_for_similarity_computation]
        # print(f"Embeddings shape: {embeddings.shape}")
        mask=(batch['attention_mask'].unsqueeze(-1).expand(embeddings.size()).float())
        mask1=((batch['token_type_ids'].unsqueeze(-1).expand(embeddings.size()).float())==0)
        mask=mask*mask1
        mean_pooled=torch.sum(embeddings*mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
        total_cnt += mean_pooled.size(0)
        mean_pooled = mean_pooled.cpu()
        total_storage += sys.getsizeof(mean_pooled.storage())
        representations_nsp_one.append(mean_pooled)
        # print(f"Current total representations: {len(representations_nsp_one)}")
        
    print(f"Final number of representations: {len(representations_nsp_one)}")
    representations_nsp_one=torch.cat(representations_nsp_one, dim=0)
    representations_nsp_one=representations_nsp_one[:len(first_sent_nsp_one)]
    total_storage += sys.getsizeof(representations_nsp_one.storage())
    representations_nsp_one=representations_nsp_one.numpy()
    print('Representations(NSP Class 1) Size: {}, Total number of samples: {}'.format(total_storage/(1024 * 1024), total_cnt))
    batch_indices_nsp_one=list(range(len(first_sent_nsp_one)))
    print('Length of indices: {}'.format(len(batch_indices_nsp_one)))
    print('Representations(NSP Class 0) gathered. Shape of representations: {}. Length of indices: {}'.format(representations_nsp_one.shape, len(batch_indices_nsp_one)))

    partition_indices_nsp_one, greedyIdx_nsp_one, gains_nsp_one = subset_strategy.select(len(batch_indices_nsp_one)-1, 
                                                                                            batch_indices_nsp_one, representations_nsp_one, 
                                                                                            parallel_processes=parallel_processes, return_gains=True)
    subset_indices_nsp_one = [[]]
    i=0
    for p in gains_nsp_one:
        greedyList_nsp_one.append(greedyIdx_nsp_one[i:i+len(p)])         
        i+=len(p)
    probs_nsp_one=[taylor_softmax_v1(torch.from_numpy(np.array([partition_gains])/temperature)).numpy()[0] for partition_gains in gains_nsp_one]
    print(f"Taylor Softmax Prop: {probs_nsp_one}")
    rng=np.random.default_rng(seed+completed_steps)
    for i, partition_prob in enumerate(probs_nsp_one):
        print(f"{i}: Partition probablity :{partition_prob}")
        if len(partition_prob) > 0:
            partition_budget=min(math.ceil((len(partition_prob)/len(batch_indices_nsp_one)) * math.floor(num_samples/2)), len(partition_prob)-1)
            print(f"Partition Budget: {partition_budget}")
            subset_indices_nsp_one[0].extend(rng.choice(greedyList_nsp_one[i], size=partition_budget, replace=False, p=partition_prob).tolist())

nsp_one_subset_dataset=nsp_one.select(subset_indices_nsp_one[0])

Performing Subset selection for NSP class 1
Final number of representations: 1433
Representations(NSP Class 1) Size: 33.645721435546875, Total number of samples: 5731
Length of indices: 5731
Representations(NSP Class 0) gathered. Shape of representations: (5731, 768). Length of indices: 5731
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3
3

100%|██████████| 1911/1911 [00:00<00:00, 2581.56it/s][||||||||||          ]50% [Iteration 1 of 2]n 1 of 22]]22]50% [Iteration 1 of 2]50% [Iteration 1 of 2]1 of 2]


Pool Parallel process ended.
Taylor Softmax Prop: [array([0.95981982, 0.04018018]), array([0.96134661, 0.03865339]), array([0.96137383, 0.03862617]), array([0.95578722, 0.04421278]), array([0.96141658, 0.03858342]), array([0.96136335, 0.03863665]), array([0.96145556, 0.03854444]), array([0.96135778, 0.03864222]), array([0.96138345, 0.03861655]), array([0.96139705, 0.03860295]), array([0.96136785, 0.03863215]), array([0.96127955, 0.03872045]), array([0.96136443, 0.03863557]), array([0.96133148, 0.03866852]), array([0.96133358, 0.03866642]), array([0.96011885, 0.03988115]), array([0.96141154, 0.03858846]), array([0.95616052, 0.04383948]), array([0.96135461, 0.03864539]), array([0.95605683, 0.04394317]), array([0.9613626, 0.0386374]), array([0.96136087, 0.03863913]), array([0.96136046, 0.03863954]), array([0.9612856, 0.0387144]), array([0.96135149, 0.03864851]), array([0.9613254, 0.0386746]), array([0.96135255, 0.03864745]), array([0.96141714, 0.03858286]), array([0.95604225, 0.04395775])

In [75]:
subset_indices_nsp_one

[[731,
  4198,
  5551,
  4883,
  1511,
  4634,
  39,
  3104,
  5670,
  1168,
  228,
  4878,
  4952,
  3549,
  1863,
  5582,
  267,
  103,
  1007,
  4915,
  256,
  4986,
  4756,
  2122,
  3832,
  1037,
  1384,
  4067,
  1062,
  661,
  5696,
  940,
  2456,
  2041,
  605,
  1028,
  3440,
  1102,
  3735,
  2807,
  5631,
  2770,
  3047,
  4346,
  5462,
  1263,
  4284,
  2001,
  861,
  5617,
  812,
  960,
  1503,
  3408,
  677,
  2326,
  1588,
  2636,
  2825,
  556,
  2623,
  2895,
  1294,
  821,
  5260,
  4972,
  1654,
  4732,
  4304,
  2091,
  2723,
  1093,
  18,
  1631,
  4557,
  848,
  1082,
  2289,
  3746,
  4644,
  3404,
  5242,
  1718,
  5588,
  5186,
  345,
  2381,
  2293,
  2792,
  4988,
  2196,
  3842,
  802,
  357,
  1461,
  2361,
  1120,
  1733,
  1727,
  5658,
  5647,
  5457,
  2439,
  499,
  3614,
  3153,
  5414,
  5194,
  5547,
  4345,
  804,
  4905,
  3107,
  5500,
  3938,
  3894,
  2823,
  911,
  4582,
  2276,
  1468,
  1333,
  2149,
  3299,
  203,
  331,
  2657,
  2470,
  3

In [63]:
partition_budget

-1

In [71]:
greedyList_nsp_zero

[[6067, 3837, 136],
 [5020, 4149, 2523],
 [796, 1419, 15],
 [1315, 6292, 1395],
 [167, 2406, 4201],
 [277, 1570, 4781],
 [998, 6076, 1618],
 [2320, 5831, 1952],
 [1710, 5653, 4991],
 [2022, 5201, 2374],
 [1216, 260, 2557],
 [807, 3276, 6229],
 [5274, 5562, 5174],
 [5837, 5048, 3258],
 [6241, 2341, 4466],
 [255, 1266, 2260],
 [5860, 2827, 3272],
 [6272, 4721, 1082],
 [4729, 3813, 294],
 [2596, 404, 1538],
 [886, 1229, 982],
 [3710, 4445, 1034],
 [2044, 2728, 3721],
 [1566, 5055, 2961],
 [6252, 1821, 4478],
 [1427, 4064, 3464],
 [4746, 4707, 1927],
 [4549, 2864, 1612],
 [2433, 2034, 2104],
 [1648, 1521, 6005],
 [5039, 5452, 861],
 [3247, 5636, 1177],
 [3814, 4869, 3625],
 [6442, 3556, 6238],
 [610, 2811, 92],
 [127, 5136, 844],
 [2344, 211, 5427],
 [1684, 3858, 5134],
 [2021, 223, 956],
 [4348, 1299, 2332],
 [6070, 1935, 1595],
 [5315, 5439, 809],
 [3714, 3965, 1423],
 [2039, 4961, 968],
 [5788, 2479, 2019],
 [351, 1932, 4352],
 [4595, 89, 5251],
 [1470, 5186, 1322],
 [63, 4992, 4311],
 

In [85]:
# Put all the data into a dataset called subset_dataset
# Concatenate the two datasets
subset_dataset = concatenate_datasets([nsp_zero_subset_dataset, nsp_one_subset_dataset])

subset_dataloader=DataLoader(
    subset_dataset.remove_columns(['special_tokens_mask']), shuffle=True, collate_fn=data_collator, batch_size=train_batch_size)

# add the data to the subset dataloader

In [81]:
subset_dataset.remove_columns(['special_tokens_mask'])

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask', 'next_sentence_label'],
    num_rows: 3527
})

In [86]:
max_train_steps = 10
completed_steps = 0 # I needed to reset this!
select_every = 5 # This is needed to breaking the training and set the subset selection back in motion.
print("Begin the main training loop with importance re-sampling, after warm-start")
while completed_steps<max_train_steps:
    model.train()
    select_subset=False
    for step, batch in enumerate(subset_dataloader):
        train_time=0
        subset_time=0
        start_time=time.time()
        outputs=model(**batch)
        loss=outputs.loss
        logger.info(f"Completed Steps: {1+completed_steps}; Loss: {loss.detach().float()}; lr: {lr_scheduler.get_last_lr()};")
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        completed_steps+=1
           
        train_time += (time.time() - start_time)

        if completed_steps>=max_train_steps:
            break
        
        if (completed_steps)%select_every==0:
            select_subset=True
            break

    if select_subset==True:
        start_time = time.time()
        num_samples = int(round(len(full_dataset) * subset_fraction, 0)) 
        if selection_strategy == 'Random-Online':
            subset_indices_nsp_zero = [random.sample(list(range(len(first_sent_nsp_zero))), math.floor(num_samples/2))]
            subset_indices_nsp_one = [random.sample(list(range(len(first_sent_nsp_one))), math.ceil(num_samples/2))]

        elif selection_strategy in ["fl", "logdet", "gc", "disparity-min"]:
            print(f"Performing Subset selection for NSP class 0")
            sampling_start_time=time.time()
            
            subset_indices_nsp_zero=[[]]
            rng=np.random.default_rng(seed+completed_steps)
            for i, partition_prob in enumerate(probs_nsp_zero):
                partition_budget=min(math.ceil((len(partition_prob)/len(batch_indices_nsp_zero)) * math.floor(num_samples/2)), len(partition_prob)-1)
                subset_indices_nsp_zero[0].extend(rng.choice(greedyList_nsp_zero[i], size=partition_budget, replace=False, p=partition_prob).tolist())

            print("Sampling time(NSP Class 0): {}".format(time.time()-sampling_start_time))
        
            logger.info(f"Performing Subset selection for NSP class 1")
            sampling_start_time=time.time()
            
            subset_indices_nsp_one=[[]]
            rng=np.random.default_rng(seed+completed_steps)
            for i, partition_prob in enumerate(probs_nsp_one):
                if len(partition_prob) > 0:
                    partition_budget=min(math.ceil((len(partition_prob)/len(batch_indices_nsp_one)) * math.ceil(num_samples/2)), len(partition_prob)-1)
                    subset_indices_nsp_one[0].extend(rng.choice(greedyList_nsp_one[i], size=partition_budget, replace=False, p=partition_prob).tolist())

            print("Sampling time(NSP Class 1): {}".format(time.time()-sampling_start_time))

        nsp_zero_subset_dataset=nsp_zero.select(subset_indices_nsp_zero[0])
        nsp_one_subset_dataset=nsp_one.select(subset_indices_nsp_one[0])
        
        # Concatenate the two datasets
        subset_dataset = concatenate_datasets([nsp_zero_subset_dataset, nsp_one_subset_dataset])
        subset_dataloader=DataLoader(
            subset_dataset.remove_columns(['special_tokens_mask']), shuffle=True, collate_fn=data_collator, batch_size=train_batch_size)

        select_subset=False # Setting the up for training after re-sampling

        print(f"Subset selection Completed")
    
    model.eval()
    losses=[]
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            outputs=model(**batch)

        loss=outputs.loss
        losses.append(loss.view(1))

    losses=torch.cat(losses)
    losses=losses[:len(eval_dataset)]
    try:
        perplexity=math.exp(torch.mean(losses))
    except OverflowError:
        perplexity=float("inf")

    print(f"Steps {completed_steps}: perplexity: {perplexity}")

print(f"Saving the final model after {completed_steps} steps.")
print(f"Training completed successfully!")

Begin the main training loop with importance re-sampling, after warm-start
Performing Subset selection for NSP class 0
Sampling time(NSP Class 0): 0.07647418975830078
Sampling time(NSP Class 1): 0.09356141090393066
Subset selection Completed
Steps 5: perplexity: 25712.872199259866
Steps 10: perplexity: 25611.674610323036
Saving the final model after 10 steps.
Training completed successfully!


In [None]:
# Train with importance re-sampling

# Train on the entire dataset once

# Sample using the indices and gains

# Train the model on the sampled dataset

# Evaluate the model 

# Save the model 

In [None]:
# Personal Addition: Inference on the model.