In [1]:
# import libraries
import os
import sys
import time
import tqdm
import math
import random
import datasets
from datasets import load_dataset, load_from_disk, concatenate_datasets
from transformers import (
    BertForPreTraining,
    BertTokenizerFast,
    BertConfig,
    DataCollatorForLanguageModeling,
    DataCollatorWithPadding,
    Trainer, 
    TrainingArguments,
    get_scheduler,
    set_seed,
    SchedulerType)
import torch
from torch.utils.data import DataLoader
import numpy as np
from torch.optim import AdamW
from selectionstrategies import SubmodStrategy
from helper_fns import taylor_softmax_v1

# Seed every RNG used below (Python `random`, NumPy, PyTorch) through one helper call.
SEED = 42
set_seed(SEED)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Tokenizer: the pretrained bert-base-uncased WordPiece vocabulary.
checkpoint = "bert-base-uncased"
tokenizer = BertTokenizerFast.from_pretrained(checkpoint)

# Model: a freshly initialised BertForPreTraining built from the default BertConfig
# (no pretrained weights are loaded).
config = BertConfig()
model = BertForPreTraining(config=config)

# Size the input embeddings to the tokenizer vocabulary. The returned embedding module is
# the cell's displayed output (30522 rows per the recorded output).
vocab_size = len(tokenizer)
model.resize_token_embeddings(vocab_size)

Embedding(30522, 768, padding_idx=0)

In [3]:
# raw_data = load_dataset("wikipedia", "20220301.en")
raw_data = load_dataset("bookcorpus")

# Split dataset: without shuffling, the first 70% of the rows become "train" and the
# remaining 30% become "validation".
splits = raw_data["train"].train_test_split(test_size=(30/100), shuffle=False)
raw_data = datasets.DatasetDict({"train": splits["train"], "validation": splits["test"]})

# Tokenize the "text" column when it exists, otherwise fall back to the first column.
column_names = raw_data["train"].column_names
text_column_name = "text" if "text" in column_names else column_names[0]

# Longest sequence the tokenizer/model accepts; used for truncation and pair packing.
max_seq_length = tokenizer.model_max_length

# Define a function to tokenize the dataset
def tokenize_function(examples):
    """Tokenize one batch of raw lines, skipping empty and whitespace-only lines.

    Each kept line becomes its own ``[CLS] ... [SEP]`` sequence, truncated to
    ``max_seq_length``, together with a ``special_tokens_mask``.

    No padding is applied on purpose. ``group_texts`` strips exactly one leading and one
    trailing token per line (``[1:-1]``) before packing lines into sentence pairs. With
    ``padding="max_length"``, the trailing token is a [PAD], so [SEP]/[PAD] tokens leaked
    into every pair and each line counted as a full 510-token segment. Padding every line
    to the maximum length also multiplies the size of the on-disk cache. The MLM data
    collator pads each batch at training time instead.

    Args:
        examples: batch mapping; ``examples[text_column_name]`` is a list of strings.

    Returns:
        The tokenizer output for the kept lines. The number of rows may be smaller than
        the input batch because empty lines are dropped.
    """
    lines = [
        line for line in examples[text_column_name] if len(line) > 0 and not line.isspace()
    ]
    return tokenizer(
        lines,
        truncation=True,
        max_length=max_seq_length,
        return_special_tokens_mask=True,
    )


# Tokenize both splits. The raw text columns are removed so only tokenizer outputs remain;
# the cache is bypassed so edits to tokenize_function always take effect.
print("Tokenizing the dataset")
tokenize_map_kwargs = dict(
    batched=True,
    num_proc=4,
    remove_columns=column_names,
    load_from_cache_file=False,
    desc="Running tokenizer on the entire dataset",
)
tokenized_dataset = raw_data.map(tokenize_function, **tokenize_map_kwargs)
train_dataset, validation_dataset = tokenized_dataset["train"], tokenized_dataset["validation"]


Reusing dataset bookcorpus (/home/UNT/tm0663/.cache/huggingface/datasets/bookcorpus/plain_text/1.0.0/44662c4a114441c35200992bea923b170e6f13f2f0beb7c14e43759cec498700)
100%|██████████| 1/1 [00:00<00:00,  6.31it/s]
Loading cached split indices for dataset at /home/UNT/tm0663/.cache/huggingface/datasets/bookcorpus/plain_text/1.0.0/44662c4a114441c35200992bea923b170e6f13f2f0beb7c14e43759cec498700/cache-3ab949bb3673e959.arrow and /home/UNT/tm0663/.cache/huggingface/datasets/bookcorpus/plain_text/1.0.0/44662c4a114441c35200992bea923b170e6f13f2f0beb7c14e43759cec498700/cache-d6f270e854b67fa1.arrow


Tokenizing the dataset


Running tokenizer on the entire dataset #0:   0%|          | 10/12951 [00:02<42:34,  5.07ba/s]]
Running tokenizer on the entire dataset #0:   0%|          | 12/12951 [00:02<43:18,  4.98ba/s]
Running tokenizer on the entire dataset #0:   0%|          | 13/12951 [00:02<43:50,  4.92ba/s]
Running tokenizer on the entire dataset #0:   0%|          | 14/12951 [00:02<43:15,  4.99ba/s]
[A

[A[A
Running tokenizer on the entire dataset #0:   0%|          | 15/12951 [00:03<42:58,  5.02ba/s]
Running tokenizer on the entire dataset #0:   0%|          | 16/12951 [00:03<42:53,  5.03ba/s]

[A[A
Running tokenizer on the entire dataset #0:   0%|          | 17/12951 [00:03<42:20,  5.09ba/s]

[A[A
Running tokenizer on the entire dataset #0:   0%|          | 18/12951 [00:03<41:53,  5.15ba/s]

[A[A
Running tokenizer on the entire dataset #0:   0%|          | 19/12951 [00:03<41:44,  5.16ba/s]

[A[A
[A

Running tokenizer on the entire dataset #0:   0%|          | 20/12951 [00:04<48:36,  4.43ba/s]


In [None]:
# Group the texts into chunks of max_seq_length
def group_texts(examples, idx, split, tokenized_datasets):
    """Build BERT pre-training pairs (segment A / segment B + NSP label) from a batch of sentences.

    Used with ``Dataset.map(batched=True, with_indices=True)``.

    ``examples`` holds one tokenized sentence per row, each wrapped as ``[CLS] ... [SEP]``.
    ``idx`` holds those rows' indices in ``tokenized_datasets[split]``. The batch is treated
    as one "document": consecutive sentences are accumulated until about
    ``target_seq_length`` tokens, then split into segment A and segment B. With probability
    ``nsp_probability`` (and always when only one sentence was accumulated), segment B is
    instead read from a random window of the same split that does not overlap this batch.

    Assumes un-padded tokenizer output: ``[1:-1]`` must remove exactly the [CLS] and [SEP]
    of each sentence.

    Returns:
        A dict of lists with the tokenizer's keys plus ``next_sentence_label``
        (1 = segment B is random, 0 = segment B is the actual continuation). Every row is
        ``[CLS] A [SEP] B [SEP]``, with token type 1 on ``B [SEP]``.
    """
    # Account for [CLS], [SEP], [SEP]
    max_num_tokens = max_seq_length - 3
    # We *usually* want to fill up the entire sequence since we are padding
    # to `max_seq_length` anyways, so short sequences are generally wasted
    # computation. However, we *sometimes*
    # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
    # sequences to minimize the mismatch between pre-training and fine-tuning.
    # The `target_seq_length` is just a rough target however, whereas
    # `max_seq_length` is a hard limit.
    short_seq_prob = 0.1
    nsp_probability = 0.5
    target_seq_length = max_num_tokens
    if random.random() < short_seq_prob:
        target_seq_length = random.randint(2, max_num_tokens)

    # Encode "" once. For BERT this is "[CLS] [SEP]", so slicing each per-key list gives the
    # values contributed by a lone [CLS] ([:-1]) and a lone [SEP] ([1:]). This used to be
    # recomputed four times for every emitted pair.
    empty_encoding = tokenizer("", return_special_tokens_mask=True)
    encoding_keys = list(empty_encoding.keys())
    cls_piece = {k: list(v[:-1]) for k, v in empty_encoding.items()}
    sep_piece = {k: list(v[1:]) for k, v in empty_encoding.items()}

    source_dataset = tokenized_datasets[split]
    num_source_rows = len(source_dataset)
    batch_len = len(idx)

    # We DON'T just concatenate all of the tokens from a document into a long
    # sequence and choose an arbitrary split point because this would make the
    # next sentence prediction task too easy. Instead, we split the input into
    # segments "A" and "B" based on the actual "sentences" provided by the user
    # input.
    result = {k: [] for k in encoding_keys}
    result['next_sentence_label'] = []
    current_chunk = []
    current_length = 0
    i = 0
    while i < batch_len:
        # Drop this sentence's own [CLS] and [SEP]; they are re-added around A and B below.
        segment = {k: examples[k][i][1:-1] for k in examples.keys()}
        current_chunk.append(segment)
        current_length += len(segment['input_ids'])
        if i == batch_len - 1 or current_length >= target_seq_length:
            if current_chunk:
                # `a_end` is how many segments from `current_chunk` go into the `A`
                # (first) sentence.
                a_end = 1
                if len(current_chunk) >= 2:
                    a_end = random.randint(1, len(current_chunk) - 1)
                tokens_a = {k: [] for k in encoding_keys}
                for j in range(a_end):
                    for k, v in current_chunk[j].items():
                        tokens_a[k].extend(v)

                tokens_b = {k: [] for k in encoding_keys}
                if len(current_chunk) == 1 or random.random() < nsp_probability:
                    # Random next
                    is_random_next = True
                    target_b_length = target_seq_length - len(tokens_a["input_ids"])
                    # This should rarely go for more than one iteration for large
                    # corpora. However, just to be careful, we try to make sure that
                    # the random document is not the same as the document
                    # we're processing.
                    for _ in range(10):
                        random_segment_index = random.randint(0, num_source_rows - batch_len - 1)
                        # Rows read below lie in [random_segment_index, random_segment_index + batch_len).
                        # Accept the draw only if that window lies entirely before or after this
                        # batch's rows [idx[0], idx[-1]]. The previous test compared
                        # random_segment_index +/- batch_len against idx, which accepted windows that
                        # overlap the batch and rejected adjacent, non-overlapping ones.
                        if random_segment_index + batch_len <= idx[0] or random_segment_index > idx[-1]:
                            break

                    random_start = random.randint(0, batch_len - 1)
                    for j in range(random_start, batch_len):
                        # Fetch the row once instead of once per key.
                        row = source_dataset[random_segment_index + j]
                        for k in examples.keys():
                            tokens_b[k].extend(row[k][1:-1])
                        if len(tokens_b['input_ids']) >= target_b_length:
                            break
                    # We didn't actually use these segments so we "put them back" so
                    # they don't go to waste.
                    num_unused_segments = len(current_chunk) - a_end
                    i -= num_unused_segments
                else:
                    # Actual next
                    is_random_next = False
                    for j in range(a_end, len(current_chunk)):
                        for k, v in current_chunk[j].items():
                            tokens_b[k].extend(v)

                # Trim the longer segment one token at a time until the pair fits.
                # We want to sometimes truncate from the front and sometimes from the
                # back to add more randomness and avoid biases.
                while len(tokens_a['input_ids']) + len(tokens_b['input_ids']) > max_num_tokens:
                    trunc_tokens = tokens_a if len(tokens_a['input_ids']) > len(
                        tokens_b['input_ids']) else tokens_b
                    if random.random() < 0.5:
                        for k in trunc_tokens.keys():
                            del trunc_tokens[k][0]
                    else:
                        for k in trunc_tokens.keys():
                            trunc_tokens[k].pop()

                # [CLS] A [SEP]
                inp = {k: list(v) for k, v in cls_piece.items()}
                for k, v in tokens_a.items():
                    inp[k].extend(v)
                for k, v in sep_piece.items():
                    inp[k].extend(v)
                # B [SEP]: always terminate segment B, then mark B and its [SEP] as type 1
                # (previously the [SEP] was only appended when token_type_ids existed).
                for k, v in sep_piece.items():
                    tokens_b[k].extend(v)
                if 'token_type_ids' in tokens_b:
                    tokens_b['token_type_ids'] = [1] * len(tokens_b['token_type_ids'])
                for k, v in tokens_b.items():
                    inp[k].extend(v)
                inp['next_sentence_label'] = int(is_random_next)
                for k, v in inp.items():
                    result[k].append(v)
            current_chunk = []
            current_length = 0
        i += 1
    return result

In [5]:
# Pack both tokenized splits into NSP sentence pairs.
# The two map calls share every argument except the split name and the progress-bar
# label. group_texts uses the split name to draw random "next" segments from that split.
grouped_splits = {}
for split_name, split_label, split_dataset in (
    ("train", "Train", train_dataset),
    ("validation", "Validation", validation_dataset),
):
    grouped_splits[split_name] = split_dataset.map(
        group_texts,
        fn_kwargs={'split': split_name, 'tokenized_datasets': tokenized_dataset},
        batched=True,
        batch_size=1000,
        num_proc=8,
        load_from_cache_file=False,
        with_indices=True,
        desc=f"Grouping {split_label} texts in chunks of {max_seq_length}",
    )
train_dataset = grouped_splits["train"]
validation_dataset = grouped_splits["validation"]

# Persist the prepared pairs so later sessions can reload them with load_from_disk.
prepared_data = datasets.DatasetDict({"train": train_dataset, "validation": validation_dataset})
prepared_data.save_to_disk("/mnt/DATA/bookcorpus/bert/prepared")

  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)



[A[A[A

[A[A
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)





[A[A[A[A[A





[A[A[A[A[A[A



Grouping Train texts in chunks of 512 #0:   0%|          | 1/6476 [00:05<10:17:01,  5.72s/ba]

[A[A


[A[A[A
[A





[A[A[A[A[A[A




[A[A[A[A[A



Grouping Train texts in chunks of 512 #0:   0%|          | 2/6476 [00:10<8:48:06,  4.89s/ba] 

[A[A


[A[A[A
[A




[A[A[A[A[A



[A[A[A[A





Grouping Train texts in chunks of 512 #0:   0%|          | 3/64

OSError: [Errno 28] Error writing bytes to file. Detail: [errno 28] No space left on device

In [None]:
# Source data for the NSP split: the grouped training pairs.
# (Alternative when resuming: dataset=load_from_disk('bert_dataset_prepared'))
dataset = prepared_data['train']
# Reload the tokenizer; extract_first_sentences reads its [SEP] id.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

def extract_first_sentences(examples, sep_token_id=None):
    """Truncate every example to its first segment, up to and including the first [SEP].

    Used as a batched ``Dataset.map`` function on grouped ``[CLS] A [SEP] B [SEP]`` pairs.
    Only ``[CLS] A [SEP]`` is kept.

    Args:
        examples: batch mapping with ``input_ids`` (list of token-id lists). Any of
            ``attention_mask``, ``token_type_ids`` and ``special_tokens_mask`` that are
            present are cut to the same length. Other columns are left untouched.
        sep_token_id: id of the separator token. Defaults to ``tokenizer.sep_token_id``
            (the notebook's global tokenizer), so existing ``map`` calls are unchanged.

    Returns:
        ``examples``, modified in place.

    Raises:
        ValueError: if an example contains no separator token.
    """
    if sep_token_id is None:
        sep_token_id = tokenizer.sep_token_id
    # Only truncate per-token columns that actually exist; a missing one used to raise KeyError.
    per_token_columns = [
        column
        for column in ("input_ids", "attention_mask", "token_type_ids", "special_tokens_mask")
        if column in examples
    ]
    for i, input_ids in enumerate(examples["input_ids"]):
        cut = input_ids.index(sep_token_id) + 1
        for column in per_token_columns:
            examples[column][i] = examples[column][i][:cut]
    return examples

# Split the grouped pairs by NSP label: group_texts stores next_sentence_label = 1 when
# segment B is random and 0 when it is the actual continuation.
nsp_zero = dataset.filter(lambda examples: [x == 0 for x in examples["next_sentence_label"]], batched=True, num_proc=96, keep_in_memory=True)
nsp_zero.save_to_disk("/mnt/DATA/bookcorpus/bert/nsp_zero")
nsp_one = dataset.filter(lambda examples: [x == 1 for x in examples["next_sentence_label"]], batched=True, num_proc=96, keep_in_memory=True)
nsp_one.save_to_disk("/mnt/DATA/bookcorpus/bert/nsp_one")

# Keep only "[CLS] A [SEP]" of each pair, in SEPARATE variables.
# The subset-selection cell computes representations from first_sent_nsp_zero /
# first_sent_nsp_one. It then selects training examples from the full pairs in
# nsp_zero / nsp_one. Overwriting nsp_zero / nsp_one here would have made the training
# subset first-sentence-only, without next_sentence_label.
first_sent_nsp_zero = nsp_zero.map(extract_first_sentences, batched=True, num_proc=96, remove_columns=["next_sentence_label", "special_tokens_mask"], keep_in_memory=True)
first_sent_nsp_one = nsp_one.map(extract_first_sentences, batched=True, num_proc=96, remove_columns=["next_sentence_label", "special_tokens_mask"], keep_in_memory=True)

# save datasets
first_sent_nsp_zero.save_to_disk("/mnt/DATA/bookcorpus/bert/first_sent_nsp_zero")
first_sent_nsp_one.save_to_disk("/mnt/DATA/bookcorpus/bert/first_sent_nsp_one")

In [None]:
# Initialize Random Subset Selection: a uniform random 10% of the grouped training pairs.
subset_fraction = 0.1
num_samples = round(len(train_dataset) * subset_fraction)
# random.sample accepts any sequence. Sampling from range() yields the same indices as
# sampling from list(range(...)) without allocating one int per training example.
init_subset_indices = [random.sample(range(len(train_dataset)), num_samples)]

full_dataset = train_dataset
subset_dataset = full_dataset.select(init_subset_indices[0])

In [None]:
# Inspect one selected training pair (displayed as the cell output).
first_subset_example = subset_dataset[0]; first_subset_example

## Data collator and dataloaders

In [None]:
mlm_probability = 0.15
per_device_train_batch_size = 32
per_device_eval_batch_size = 32

# The MLM collator pads each batch and selects tokens for masking with probability
# `mlm_probability`.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=mlm_probability)

# Dataloaders creation: both training loaders shuffle; evaluation keeps dataset order.
train_loader_kwargs = dict(shuffle=True, collate_fn=data_collator, batch_size=per_device_train_batch_size)
warmstart_dataloader = DataLoader(train_dataset, **train_loader_kwargs)
subset_dataloader = DataLoader(subset_dataset, **train_loader_kwargs)
eval_dataloader = DataLoader(
    validation_dataset, collate_fn=data_collator, batch_size=per_device_eval_batch_size
)

## Preparing the optimizer and learning-rate schedule

In [None]:
# Optimizer: AdamW over two parameter groups. Parameters whose name contains "bias" or
# "LayerNorm.weight" get no weight decay; every other parameter is decayed.
weight_decay = 0.01
# NOTE(review): 5e-3 is far above the learning rates usually used for BERT pre-training — confirm.
learning_rate = 5e-3

no_decay = ["bias", "LayerNorm.weight"]
decay_params, no_decay_params = [], []
for param_name, param in model.named_parameters():
    target = no_decay_params if any(marker in param_name for marker in no_decay) else decay_params
    target.append(param)

optimizer_grouped_parameters = [
    {"params": decay_params, "weight_decay": weight_decay},
    {"params": no_decay_params, "weight_decay": 0.0},
]

optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)

In [None]:
# Linear warm-up from 0 to `learning_rate` over `num_warmup_steps`, then linear decay to 0
# at `num_training_steps`.
lr_scheduler_type = SchedulerType.LINEAR
num_warmup_steps = 10
# Must cover the whole run; keep it equal to `max_train_steps` in the training cells (1000).
# It used to equal num_warmup_steps (10), which makes the linear schedule reach 0 right
# after warm-up, so every optimizer step after step 10 ran with lr = 0.
num_training_steps = 1000

lr_scheduler = get_scheduler(
    name=lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)

## Training the model

In [None]:
# Subset-selection strategy (project class SubmodStrategy). Parameter meanings are read
# from their names: 'fl' with a LazyGreedy optimizer over cosine similarity of feature
# representations, using 1500 randomly formed partitions — confirm against selectionstrategies.
num_partitions = 1500
partition_strategy = 'random'
ss_optimizer = 'LazyGreedy'
subset_strategy_kwargs = dict(
    logger=None,
    smi_func_type='fl',
    num_partitions=num_partitions,
    partition_strategy=partition_strategy,
    optimizer=ss_optimizer,
    similarity_criterion='feature',
    metric='cosine',
    eta=1,
    stopIfZeroGain=False,
    stopIfNegativeGain=False,
    verbose=False,
    lambdaVal=1,
)
subset_strategy = SubmodStrategy(**subset_strategy_kwargs)

In [None]:
# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
if device.type == "cuda":
    # Release cached blocks left over from earlier cells before training starts.
    torch.cuda.empty_cache()

model.to(device)  # Batches are moved to `device` inside the loops below

max_train_steps = 1000
# Report the batch size of the dataloader that is actually iterated. It was hard-coded
# to 1 here, while warmstart_dataloader was built with its own batch size.
per_device_train_batch_size = warmstart_dataloader.batch_size
num_warmstart_epochs = 100
num_processes = 1
gradient_accumulation_steps = 1
checkpointing_steps = 1000
output_dir = "./model"


def save_state_dict(model, directory, filename):
    """Save ``model.state_dict()`` as ``directory/filename`` and return the path.

    ``directory`` is created if needed. The old code joined the checkpoint name with
    itself (``step_N/step_N``) into a directory that was never created, and clobbered
    the ``output_dir`` global.
    """
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, filename)
    torch.save(model.state_dict(), path)
    return path


def evaluate_perplexity(model, dataloader, num_examples, eval_batch_size, device):
    """Return ``exp(mean loss)`` of ``model`` over ``dataloader`` (``inf`` on overflow).

    Each batch's mean loss is repeated ``eval_batch_size`` times. The concatenation is
    cut to ``num_examples`` before averaging. Returns ``nan`` if the dataloader yields
    no batches.

    NOTE: batches carry both MLM ``labels`` and ``next_sentence_label``, so the model loss
    (and this "perplexity") includes the NSP term, not only masked-LM.
    """
    model.eval()
    losses = []
    for batch in dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        losses.append(outputs.loss.repeat(eval_batch_size))
    if not losses:
        return float("nan")
    mean_loss = torch.cat(losses)[:num_examples].mean().item()
    try:
        return math.exp(mean_loss)
    except OverflowError:
        return float("inf")


# Train!
total_batch_size = per_device_train_batch_size * num_processes * gradient_accumulation_steps
main_start_time = time.time()
print(f"  Num examples = {len(train_dataset)}")
print(f"  Num warm-start epochs = {num_warmstart_epochs}")
print(f"  Instantaneous batch size per device = {per_device_train_batch_size}")
print(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
print(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
print(f"  Total optimization steps = {max_train_steps}")

completed_steps = 0

print("Begin the training.")
timing = []
warmstart_start_time = time.time()
for epoch in range(num_warmstart_epochs):
    if epoch == 0:
        print("Begin the warm-start")
    model.train()
    for step, batch in enumerate(warmstart_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        start_time = time.time()
        # The MLM collator yields `labels` (MLM targets) and passes `next_sentence_label`
        # (NSP targets) through. BertForPreTraining takes both, and token_type_ids is
        # optional, so the whole batch is passed as-is. The old fallback passed the NSP
        # labels as `labels`, i.e. as MLM targets.
        outputs = model(**batch)
        loss = outputs.loss
        print(f"Completed Steps: {1+completed_steps}; Loss: {loss.detach().float()}; lr: {lr_scheduler.get_last_lr()};")
        (loss / gradient_accumulation_steps).backward()
        # Step after every `gradient_accumulation_steps` micro-batches and on the last one.
        # `step % n == 0` would step on the very first micro-batch.
        if (step + 1) % gradient_accumulation_steps == 0 or step == len(warmstart_dataloader) - 1:
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            completed_steps += 1
            # Checkpoint right after an optimizer step so each step is saved at most once.
            if isinstance(checkpointing_steps, int) and completed_steps % checkpointing_steps == 0:
                save_state_dict(model, output_dir, f"step_{completed_steps}")
        if completed_steps >= max_train_steps:
            break
        timing.append([(time.time() - start_time), 0])

    perplexity = evaluate_perplexity(
        model, eval_dataloader, len(validation_dataset), per_device_eval_batch_size, device
    )
    print(f"Steps {completed_steps}: perplexity: {perplexity}")
    reached_max_steps = completed_steps >= max_train_steps
    if epoch == num_warmstart_epochs - 1 or reached_max_steps:
        print("End the warm-start")
    if reached_max_steps:
        # Leave the epoch loop as well. Otherwise each remaining epoch would take one more
        # optimizer step past the limit and run a full evaluation pass.
        break

# Save the state after warm-start
save_state_dict(model, output_dir, f"after_warmstart_step_{completed_steps}")
warmstart_end_time = time.time()
print(f"Completed warm-start in {warmstart_end_time - warmstart_start_time} seconds")


In [None]:
from torch.utils.data import DataLoader
import torch
import os
import time
import math
from accelerate import Accelerator

# Initialize Accelerator
accelerator = Accelerator()

# NOTE(review): model, optimizer and lr_scheduler are the objects the warm-start cell already
# trained and stepped; this loop continues from the scheduler's current position.
model, optimizer, lr_scheduler, warmstart_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, lr_scheduler, warmstart_dataloader, eval_dataloader)

num_epochs = 100
gradient_accumulation_steps = 1
checkpointing_steps = 1000
output_dir = "./model_checkpoints"
max_train_steps = 1000
completed_steps = 0


def save_unwrapped_state(filename):
    """Save the unwrapped model's ``state_dict`` to ``output_dir/filename``; return the path.

    The directory is created first, because saving into a missing directory fails. The
    model is unwrapped so that keys do not carry a distributed-wrapper prefix.
    """
    os.makedirs(output_dir, exist_ok=True)
    path = os.path.join(output_dir, filename)
    accelerator.save(accelerator.unwrap_model(model).state_dict(), path)
    return path


for epoch in range(num_epochs):
    model.train()
    for step, batch in enumerate(warmstart_dataloader):
        # Pass the whole collated batch: `labels` are the MLM targets and
        # `next_sentence_label` the NSP targets. Passing next_sentence_label as `labels`
        # fed the NSP labels to the MLM loss.
        outputs = model(**batch)
        loss = outputs.loss / gradient_accumulation_steps
        accelerator.backward(loss)

        if (step + 1) % gradient_accumulation_steps == 0 or step == len(warmstart_dataloader) - 1:
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            completed_steps += 1

            if completed_steps % checkpointing_steps == 0:
                save_unwrapped_state(f"checkpoint_{completed_steps}.pt")

        if completed_steps >= max_train_steps:
            break

    model.eval()
    eval_losses = []
    for batch in eval_dataloader:
        with torch.no_grad():
            outputs = model(**batch)
        eval_losses.append(outputs.loss.item())

    # An empty evaluation loader yields nan instead of a ZeroDivisionError.
    avg_loss = sum(eval_losses) / len(eval_losses) if eval_losses else float("nan")
    try:
        perplexity = math.exp(avg_loss)
    except OverflowError:
        perplexity = float("inf")
    print(f"Epoch {epoch+1}, Step {completed_steps}, Loss: {avg_loss:.2f}, Perplexity: {perplexity}")

    if completed_steps >= max_train_steps:
        # Stop the epoch loop too, so no further steps or evaluation passes run past the limit.
        break

# Save the final model state
final_model_path = save_unwrapped_state("final_model.pt")


In [None]:
import pickle
from accelerate.utils import broadcast_object_list
# `import tqdm` at the top binds the module, so calling `tqdm(...)` raised TypeError.
from tqdm.auto import tqdm as auto_tqdm

# NOTE(review): this cell was ported from a training script.
# Names it reads that no cell of this notebook defines:
#   - `args`: selection_strategy, num_warmstart_epochs, resume_from_checkpoint,
#     layer_for_similarity_computation, parallel_processes, temperature, seed, subset_dir,
#     per_device_train_batch_size
#   - `logger`
#   - `first_sent_nsp_zero_dataloader`, `first_sent_nsp_one_dataloader`
# Datasets it expects:
#   - `first_sent_nsp_zero` / `first_sent_nsp_one`: segment-A-only datasets
#   - `nsp_zero` / `nsp_one`: full-pair datasets

probs_nsp_zero = []
probs_nsp_one = []
greedyList_nsp_zero = []
greedyList_nsp_one = []
gains_nsp_zero = []
gains_nsp_one = []
# Initialised so the saving step below works on every path. Previously the Random-Online
# branch (and non-main processes) left these undefined, and pickling them raised NameError.
partition_indices_nsp_zero = partition_indices_nsp_one = None
greedyIdx_nsp_zero = greedyIdx_nsp_one = None
subset_indices_nsp_zero = subset_indices_nsp_one = None


def compute_first_sentence_representations(bert_model, dataloader, num_rows, class_name):
    """Mean-pool segment-A hidden states for every example served by ``dataloader``.

    Must run on every process because ``accelerator.gather`` is collective. Pooling
    averages only positions with attention_mask == 1 and token_type_ids == 0.

    Returns:
        On the main process, a numpy array cut to ``num_rows`` rows (drops rows that
        ``gather`` may have duplicated to even out process batches). On other processes,
        ``None``.
    """
    pbar = auto_tqdm(range(len(dataloader)), disable=not accelerator.is_local_main_process)
    model.eval()
    representations = []
    total_cnt = 0
    total_storage = 0
    representations_start_time = time.time()
    for batch in dataloader:
        with torch.no_grad():
            output = bert_model(**batch, output_hidden_states=True)
        embeddings = output["hidden_states"][args.layer_for_similarity_computation]
        mask = batch['attention_mask'].unsqueeze(-1).expand(embeddings.size()).float()
        segment_a = (batch['token_type_ids'].unsqueeze(-1).expand(embeddings.size()).float()) == 0
        mask = mask * segment_a
        mean_pooled = torch.sum(embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
        mean_pooled = accelerator.gather(mean_pooled)
        total_cnt += mean_pooled.size(0)
        if accelerator.is_main_process:
            mean_pooled = mean_pooled.cpu()
            total_storage += sys.getsizeof(mean_pooled.storage())
            representations.append(mean_pooled)
        pbar.update(1)
    result = None
    if accelerator.is_main_process:
        stacked = torch.cat(representations, dim=0)[:num_rows]
        total_storage += sys.getsizeof(stacked.storage())
        result = stacked.numpy()
        logger.info('Representations({}) Size: {}, Total number of samples: {}'.format(class_name, total_storage/(1024 * 1024), total_cnt))
        logger.info('Length of indices: {}'.format(num_rows))
        logger.info('Representations({}) gathered. Shape of representations: {}. Length of indices: {}'.format(class_name, result.shape, num_rows))
    logger.info(f"Representations({class_name}) computed in {time.time()-representations_start_time} seconds")
    return result


def sample_subset_from_partitions(representations, num_rows, class_budget):
    """Run submodular selection on ``representations`` and sample about ``class_budget`` rows.

    Main process only. Each partition gets a share of ``class_budget`` proportional to its
    size, capped at its size minus one. Indices are drawn without replacement from that
    partition's greedy ordering, using Taylor-softmax probabilities of its gains.

    Returns:
        (selected_indices, partition_indices, greedy_indices, gains, greedy_lists, probs)
    """
    batch_indices = list(range(num_rows))
    partition_indices, greedy_indices, gains = subset_strategy.select(
        len(batch_indices)-1, batch_indices, representations,
        parallel_processes=args.parallel_processes, return_gains=True)
    # Split the flat greedy ordering into one slice per partition (lengths follow `gains`).
    greedy_lists = []
    offset = 0
    for partition_gains in gains:
        greedy_lists.append(greedy_indices[offset:offset + len(partition_gains)])
        offset += len(partition_gains)
    probs = [taylor_softmax_v1(torch.from_numpy(np.array([partition_gains])/args.temperature)).numpy()[0] for partition_gains in gains]
    rng = np.random.default_rng(args.seed+completed_steps)
    selected = []
    for greedy_list, partition_prob in zip(greedy_lists, probs):
        partition_budget = min(math.ceil((len(partition_prob)/len(batch_indices)) * class_budget), len(partition_prob)-1)
        selected.extend(rng.choice(greedy_list, size=partition_budget, replace=False, p=partition_prob).tolist())
    return selected, partition_indices, greedy_indices, gains, greedy_lists, probs


if (args.num_warmstart_epochs != 0) or (args.resume_from_checkpoint):
    logger.info(f"Beginning the subset selection after warm-start or resuming from checkpoint")
    start_time = time.time()
    if args.selection_strategy == 'Random-Online':
        if accelerator.is_main_process:
            # Sampling from range() gives the same indices as from list(range()) without the list.
            subset_indices_nsp_zero = [random.sample(range(len(first_sent_nsp_zero)), math.floor(num_samples/2))]
            subset_indices_nsp_one = [random.sample(range(len(first_sent_nsp_one)), math.ceil(num_samples/2))]
        else:
            subset_indices_nsp_zero = [[]]
            subset_indices_nsp_one = [[]]
    elif args.selection_strategy in ["fl", "logdet", "gc", "disparity-min"]:
        logger.info(f"Performing Subset selection for NSP class 0")
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        bert_model = accelerator.prepare(unwrapped_model.bert)

        representations_nsp_zero = compute_first_sentence_representations(
            bert_model, first_sent_nsp_zero_dataloader, len(first_sent_nsp_zero), "NSP Class 0")
        if accelerator.is_main_process:
            (selected_nsp_zero, partition_indices_nsp_zero, greedyIdx_nsp_zero, gains_nsp_zero,
             greedyList_nsp_zero, probs_nsp_zero) = sample_subset_from_partitions(
                representations_nsp_zero, len(first_sent_nsp_zero), math.floor(num_samples/2))
            subset_indices_nsp_zero = [selected_nsp_zero]
        else:
            subset_indices_nsp_zero = [[]]

        logger.info(f"Performing Subset selection for NSP class 1")
        representations_nsp_one = compute_first_sentence_representations(
            bert_model, first_sent_nsp_one_dataloader, len(first_sent_nsp_one), "NSP Class 1")
        if accelerator.is_main_process:
            (selected_nsp_one, partition_indices_nsp_one, greedyIdx_nsp_one, gains_nsp_one,
             greedyList_nsp_one, probs_nsp_one) = sample_subset_from_partitions(
                representations_nsp_one, len(first_sent_nsp_one), math.ceil(num_samples/2))
            subset_indices_nsp_one = [selected_nsp_one]
        else:
            subset_indices_nsp_one = [[]]
    accelerator.wait_for_everyone()
    # Share the main process's selection with every other process (fills the lists in place).
    broadcast_object_list(subset_indices_nsp_zero)
    broadcast_object_list(subset_indices_nsp_one)
    timing.append([0, time.time()-start_time])
    logger.info(f"First subset selection completed. Total Time taken(including embeddings computation): {time.time()-start_time}")

# Fail clearly instead of with a NameError further down.
if subset_indices_nsp_zero is None or subset_indices_nsp_one is None:
    raise RuntimeError(
        "Subset selection did not run: it needs num_warmstart_epochs != 0 or "
        "resume_from_checkpoint, and a supported selection_strategy."
    )

if accelerator.is_main_process:
    os.makedirs(args.subset_dir, exist_ok=True)
    for filename, indices in (
        (f"nsp_zero_subset_indices_after_step_{completed_steps}.pt", subset_indices_nsp_zero),
        (f"nsp_one_subset_indices_after_step_{completed_steps}.pt", subset_indices_nsp_one),
    ):
        torch.save(torch.tensor(indices), os.path.join(args.subset_dir, filename))
    for filename, obj in (
        (f"nsp_zero_gains_after_step_{completed_steps}.pkl", gains_nsp_zero),
        (f"nsp_one_gains_after_step_{completed_steps}.pkl", gains_nsp_one),
        (f"nsp_zero_partition_indices_after_step_{completed_steps}.pkl", partition_indices_nsp_zero),
        (f"nsp_one_partition_indices_after_step_{completed_steps}.pkl", partition_indices_nsp_one),
        (f"nsp_zero_greedy_indices_after_step_{completed_steps}.pkl", greedyIdx_nsp_zero),
        (f"nsp_one_greedy_indices_after_step_{completed_steps}.pkl", greedyIdx_nsp_one),
    ):
        with open(os.path.join(args.subset_dir, filename), "wb") as f:
            pickle.dump(obj, f)
accelerator.wait_for_everyone()

# Train on the selected full pairs (with next_sentence_label), not on first sentences only.
nsp_zero_subset_dataset = nsp_zero.select(subset_indices_nsp_zero[0])
nsp_one_subset_dataset = nsp_one.select(subset_indices_nsp_one[0])
# Concatenate the two datasets
subset_dataset = concatenate_datasets([nsp_zero_subset_dataset, nsp_one_subset_dataset])
subset_dataloader = DataLoader(
    subset_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size)
subset_dataloader = accelerator.prepare(subset_dataloader)

logger.info("Begin the main training loop with importance re-sampling, after warm-start")