In [126]:
%reset -f
import pickle
from torch.utils.data import DataLoader
from datasets import Dataset, DatasetDict, load_from_disk, concatenate_datasets

In [127]:
dataset_path = f'/data/rozen/home/e0833634/lama/protllama/batch_script/uniref50_random90split_8k_512_first_1million_dataset.hf'
dataset = load_from_disk(dataset_path)
dataset

DatasetDict({
    train: Dataset({
        features: ['attention_mask', 'input_ids', 'labels'],
        num_rows: 253440
    })
    valid: Dataset({
        features: ['attention_mask', 'input_ids', 'labels'],
        num_rows: 232711
    })
})

In [128]:
small_dataset_dict = DatasetDict({
    'train': dataset['train'].select(range(10)),
    'valid': dataset['valid'].select(range(10))
})
small_dataset_dict

DatasetDict({
    train: Dataset({
        features: ['attention_mask', 'input_ids', 'labels'],
        num_rows: 10
    })
    valid: Dataset({
        features: ['attention_mask', 'input_ids', 'labels'],
        num_rows: 10
    })
})

In [129]:
with open('/data/rozen/home/e0833634/lama/protllama/batch_script/train_intermediate_checkpoint_batches_1000000.pkl', 'rb') as f:
    batch_indices_train = pickle.load(f)
batch_indices_train = batch_indices_train[:10]
batch_indices_train

[[769188, 769231, 769281, 770055, 770166, 770414],
 [55962],
 [951805, 951834],
 [34696, 36909],
 [39459, 39650, 39713, 39794, 39797, 39858, 40168, 40532, 40713, 40750, 41035],
 [139898, 139951, 140097],
 [223249, 224129],
 [181],
 [227032,
  227315,
  227702,
  227839,
  227841,
  228178,
  228288,
  228420,
  228457,
  228582],
 [94817, 94880]]

In [154]:
from torch.utils.data import Dataset, DataLoader
import torch

class DynamicBatchingDataset(Dataset):
    def __init__(self, dataset_dict, batch_indices):
        self.dataset_dict = dataset_dict['train']
        self.batch_indices = batch_indices  # This is mainly for informational purposes, if needed.

    def __len__(self):
        return len(self.dataset_dict['attention_mask'])  # Assuming each entry in dataset_dict represents a batch

    def __getitem__(self, idx):
        #batch_idx = self.batch_indices[idx]
        # Directly retrieve the batch using the index
        """returns [seq_number, token_length], return one batch at a time"""
        attention_mask = torch.tensor(self.dataset_dict['attention_mask'][idx])
        input_ids = torch.tensor(self.dataset_dict['input_ids'][idx])
        label = torch.tensor(self.dataset_dict['labels'][idx])

        return {
            'attention_mask': attention_mask,
            'input_ids': input_ids,
            'labels': label
        }

    @staticmethod
    def custom_collate_fn(batch):
        """Return the first element of the batch because each element is already a batch.
            (prevent auto batching from pytorch DataLoader, otherwise, the batch_size=1 will add another dimension during data retrieval)
        """
        return batch[0]

#batch_indices = [[181], [55, 266]]  # Your actual batch indices
train_dataset = DynamicBatchingDataset(small_dataset_dict, batch_indices_train)

# Batch size is set to 1 because your dataset itself is returning batches
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=False, collate_fn=DynamicBatchingDataset.custom_collate_fn)


In [155]:
for idx, batch in enumerate(train_dataloader):
    print(idx, batch)

0 {'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1

In [156]:
import numpy as np
np.shape(batch['attention_mask'])

torch.Size([2, 228])

In [160]:
import pytorch_lightning as pl
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM
import sentencepiece as spm
class CustomDataModule(pl.LightningDataModule):
    def __init__(self, dataset_dict, batch_indices_train, batch_size=1):
        super().__init__()
        self.dataset_dict = dataset_dict
        self.batch_indices_train = batch_indices_train
        self.batch_size = batch_size

        self.tokenizer = self.tokenizer_generation('protein', '8k')

    @staticmethod
    def tokenizer_generation(target, vocab_size):
        if target == 'original':
            tokenizer = LlamaTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer')
            tokenizer.pad_token = tokenizer.unk_token
            return tokenizer
        elif target == 'protein':
            tokenizer_path = '/data/rozen/home/e0833634/lama/protllama/batch_script/'
            tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path + "protein_%s.model" % (vocab_size))
            return tokenizer
        else:
            raise ValueError('Have not prepared tokenizer for this target')

    def prepare_data(self):
        # Possibly download data, set transforms, etc.
        pass

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        self.train_dataset = DynamicBatchingDataset(self.dataset_dict, self.batch_indices_train)
        # Repeat similar steps for validation and test datasets if needed

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=False, collate_fn=DynamicBatchingDataset.custom_collate_fn)

In [167]:
import pytorch_lightning as pl
import transformers
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from transformers import get_cosine_schedule_with_warmup
import os
import logging as log
import glob
from argparse import ArgumentParser
from protllama.bin.data import PretrainDataset
import torch
class pretrainLlama(pl.LightningModule):
    def __init__(self, hparam) -> None:
        super(pretrainLlama, self).__init__()
        self.save_hyperparameters()
        self.hparam = hparam  # need to contain epoch, target, date, learning rate, batch_size, num_frozen_epochs
        self.MODEL_CONFIGS = self.retrieve_config()
        self.__build_model()
        self.tokenizer = self.tokenizer_generation('protein', '8k')

    @staticmethod
    def tokenizer_generation(target, vocab_size):
        if target == 'original':
            tokenizer = LlamaTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer')
            tokenizer.pad_token = tokenizer.unk_token
            return tokenizer
        elif target == 'protein':
            tokenizer_path = '/data/rozen/home/e0833634/lama/protllama/batch_script/'
            tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path + "protein_%s.model" % (vocab_size))
            return tokenizer
        else:
            raise ValueError('Have not prepared tokenizer for this target')

    def retrieve_config(self):
        """ return transformers DATASET object"""
        if self.hparam.target == 'original':
            config_dict = {'7b': LlamaConfig(max_position_embeddings=self.hparam.max_position_embeddings,
                                             hidden_size=self.hparam.hidden_size,
                                             intermediate_size=self.hparam.intermediate_size)}
            return config_dict['7b']
        elif self.hparam.target == 'protein':
            config_dict = {
                'protllama2': LlamaConfig(max_position_embeddings=self.hparam.max_position_embeddings,  # maximum length
                                          hidden_size=self.hparam.hidden_size,
                                          transformers_version=transformers.__version__,
                                          intermediate_size=self.hparam.intermediate_size,
                                          vocab_size=int(self.hparam.vocab_size.rstrip('k')) * 1000)}
            print(config_dict['protllama2'])
            return config_dict['protllama2']
        else:
            raise ValueError('Have not prepared dataset for this target')

    def __build_model(self) -> None:
        """start model building, can add customized classification head"""
        self.model = LlamaForCausalLM(self.MODEL_CONFIGS)
        print(self.model.lm_head.weight)

    def configure_optimizers(self):
        """set learning rates"""
        if self.hparam.scheduler == 'linear':
            parameters = self.model.parameters()
            optimizer = AdamW(parameters, lr=self.hparam.learning_rate, betas=(0.9, 0.95), weight_decay=0.1)
            lr_schedulers = {
                "scheduler": get_linear_schedule_with_warmup(optimizer,
                                                            num_warmup_steps=100,
                                                            num_training_steps=self.hparam.epoch * self.hparam.train_dataset_length),
                "name": 'learning_rate_logs'
            }
            return [optimizer], [lr_schedulers]
        elif self.hparam.scheduler == 'cosine':
            """llama behavior, end learning rate matches 10% of the maximum learning rate
                hard-coded to be 10% first
            """
            parameters = self.model.parameters()
            optimizer = AdamW(parameters, lr=self.hparam.learning_rate, betas=(0.9, 0.95), weight_decay=0.1)
            lr_schedulers = {
                "scheduler": get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=100,
                                                            num_training_steps=self.hparam.epoch * self.hparam.train_dataset_length,
                                                            num_cycles=0.39758361765,
                                                            # number of waves in the cosine schedule - e.g. 0.5 for period 2 cos wave means take 0 to 1
                                                            last_epoch=-1  # index of the last epoch when resuming training
                                                            ),
                "name": 'learning_rate_logs'
            }
            return [optimizer], [lr_schedulers]
        else:
            raise ValueError('You need to specify a scheduler first. Default is linear')

    def forward(self, **inputs):
        """ Pytorch forward function
        Returns:
        dict with model outputs (loss, logits, hidden layer, attention)
        """
        return self.model(**inputs)

    def training_step(self, batch, batch_nb: int, verbose=True):
        outputs = self.forward(**batch)
        loss_train = outputs[0]

        # Compute the perplexity
        perplexity = torch.exp(outputs[0].cpu())  # Ensure outputs are on CPU

        # Accuracy computation
        # Shifting
        shift_logits = outputs[1][..., :-1, :].contiguous().argmax(
            dim=-1).cpu()  # Ensure outputs and argmax result are on CPU
        if verbose:
            print('model predict?')
            print(shift_logits)

        # Assuming 'labels' is a key in batch containing true token IDs
        shift_labels = batch['labels'][..., 1:].contiguous().cpu()  # Move labels to CPU
        if verbose:
            print('model true?')
            print(shift_labels)

        non_padding_mask = shift_labels != -100

        # Compare predictions to true labels, but only for non-padding tokens
        acc_train = ((shift_logits == shift_labels) & non_padding_mask).sum().item() / non_padding_mask.sum().item()
        if verbose:
            print(acc_train)

        # Log
        self.log('train_loss', loss_train, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('train_perplexity', perplexity, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('train_accuracy', acc_train, on_step=False, on_epoch=True, prog_bar=True, logger=True)

        return loss_train

    def validation_step(self, batch, batch_nb: int):
        """ Similar to the training step but with the model in eval mode.
        Returns:
            - dictionary passed to the validation_end function.
        """
        outputs = self.forward(**batch)
        loss_val = outputs[0].cpu()

        # Compute the perplexity
        perplexity = torch.exp(loss_val)  # Ensure outputs are on CPU

        # Accuracy computation
        # Shifting
        shift_logits = outputs[1][..., :-1, :].contiguous().argmax(
            dim=-1).cpu()  # Ensure outputs and argmax result are on CPU

        # Assuming 'labels' is a key in batch containing true token IDs
        shift_labels = batch['labels'][..., 1:].contiguous().cpu()  # Move labels to CPU

        non_padding_mask = shift_labels != -100

        # Compare predictions to true labels, but only for non-padding tokens
        acc_val = ((shift_logits == shift_labels) & non_padding_mask).sum().item() / non_padding_mask.sum().item()

        # Log
        self.log('val_loss', loss_val, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_perplexity', perplexity, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_accuracy', acc_val, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss_val

    @classmethod
    def add_model_specific_args(cls, parser: ArgumentParser):
        """parser for hyperparameters"""
        parser.add_argument('--learning_rate', type=float, default=3e-4, help='Learning rate for Adam optimizer')
        parser.add_argument('--scheduler', type=str, default='linear', help='Learning rate scheduler, either linear '
                                                                            'or cosine')
        parser.add_argument('--epoch', type=int, default=1, help='number of epochs for the training')
        #parser.add_argument('--batch_size', type=int, default=2, help='Batch sizes, sequence number per batch')
        return parser

In [168]:
import argparse
import os
from protllama.bin.data import PretrainDataset
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger

# set wandb offline on HPC
os.environ['WANDB_MODE'] = "offline"

from types import SimpleNamespace

hparam = SimpleNamespace(
    date='Oct_11',
    target='protein',
    max_position_embeddings=512,
    vocab_size='8k',
    hidden_size=640,
    intermediate_size=1720,
    save_top_k=1,
    scheduler='linear',
    learning_rate=3e-4,
    epoch=1
    #... add all other arguments similarly
)


#dm = PretrainDataset(target=hparam.target,
                     #max_sequence_length=hparam.max_position_embeddings)
data_module = CustomDataModule(dataset_dict=small_dataset_dict, batch_indices_train=batch_indices_train, batch_size=1)

# make sure dataset has "training" key
hparam.train_dataset_length = 10
training_log_path = str('protllama/pl_logs/')
if not os.path.exists(training_log_path):
    os.makedirs(training_log_path)
logger = WandbLogger(project="protllama2",
                     name=f"{hparam.target}_{hparam.date}_pre-training_log", #display on the web
                     save_dir='protllama/pl_logs/',
                     job_type='model-training',
                     group=f'pretrain_protllama2_{hparam.vocab_size}_{hparam.max_position_embeddings}',
                     id='version_%s' % str(1))
seed_everything(42)
model = pretrainLlama(hparam)
early_stop_callback = EarlyStopping(
    monitor="loss",
    min_delta=0.0,
    patience=0,  # number of epoch with no improvement
    verbose=True,
    mode="min",
)
training_model_path = str('protllama/pl_model_cache/')
if not os.path.exists(training_model_path):
    os.makedirs(training_model_path)
checkpoint_callback = ModelCheckpoint(
    dirpath=training_model_path,
    filename="{epoch}-{train_loss:.2f}-{val_loss:.2f}-%s_%s_%s_%s" % (hparam.target, hparam.date, hparam.vocab_size, hparam.max_position_embeddings),
    save_top_k=hparam.save_top_k,
    verbose=True,
    monitor="val_loss",
    mode="min",
)
lr_monitor = LearningRateMonitor(
    logging_interval='epoch'
)
trainer = Trainer(
    devices=1,
    accelerator='gpu',
    limit_train_batches=3,
    max_epochs=1,
    logger=logger,
    # max_epochs=1,
    # min_epochs=1,
    callbacks=[TQDMProgressBar(refresh_rate=10), lr_monitor],
    deterministic=True,
    enable_model_summary=True
)

# automatic garbage collection
import gc
gc.collect()

trainer.fit(model, datamodule=data_module)

Global seed set to 42


LlamaConfig {
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 640,
  "initializer_range": 0.02,
  "intermediate_size": 1720,
  "max_position_embeddings": 512,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 32,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "transformers_version": "4.33.1",
  "use_cache": true,
  "vocab_size": 8000
}



GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Parameter containing:
tensor([[ 1.2041e-02,  9.0041e-03, -3.4151e-03,  ..., -7.1120e-05,
          2.6942e-02,  5.4405e-03],
        [ 2.2820e-02,  4.9283e-03, -1.2922e-02,  ...,  1.2707e-02,
         -9.8123e-03, -4.3455e-03],
        [ 6.4636e-03,  8.7062e-03,  4.9835e-02,  ...,  4.0261e-03,
         -3.7080e-04,  9.0960e-04],
        ...,
        [ 2.2562e-02, -8.0692e-04, -7.5625e-03,  ...,  1.4612e-02,
         -2.3076e-02, -5.0668e-03],
        [ 5.5981e-03, -3.8602e-03, -3.7671e-02,  ...,  1.4204e-02,
          2.7897e-02, -5.4544e-03],
        [-6.7332e-05,  2.3461e-02,  1.0129e-02,  ..., -2.0542e-02,
          1.8716e-02,  4.7361e-02]], requires_grad=True)


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [GPU-d4342c04-7329-6ec9-9124-b61038ac3411,GPU-0958174f-473c-d99c-c8b5-c9c88c212a45]

  | Name  | Type             | Params
-------------------------------------------
0 | model | LlamaForCausalLM | 168 M 
-------------------------------------------
168 M     Trainable params
0         Non-trainable params
168 M     Total params
673.549   Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

model predict?
tensor([[5876, 2453, 2453, 2453, 2453, 2453, 2453, 2453, 2453,  547,  547,  547,
          547,  547,  547,  547,  547,  547,  547,  547,  547,  547,  547,  547,
          547,  547,  547,  547,  547,  547,  547,  547,  547,  547,  547,  547,
          547,  547,  547,  547,  547,  547,  547,  547,  547,  547,  547,  547,
          547,  547,  547,  547,  547,  547,  547,  547,  547,  547,  547,  547,
          547,  547,  547,  547,  547,  547,  547,  547,  547,  547,  547,  547,
          547,  547,  547,  547,  547,  547,  547,  547,  547,  547,  547],
        [5876, 5876, 7949, 7949, 7949, 7949, 7949, 7896, 7896, 7896, 7896, 7896,
         7896, 7896, 7896, 7896, 7896, 7896, 7896, 7896, 7896, 7896, 7896, 7896,
         7896, 7896, 7896, 7896, 7896, 7896, 7896, 7896, 7896, 7896, 7896, 7896,
         7896, 7896, 7896, 7896, 7896, 7896, 7896, 7896, 7896, 4726, 7896, 7896,
         7896,  987, 7896,  987,  987, 7896,  987, 7896, 4679,  583,  987, 4679,
          583,  58

`Trainer.fit` stopped: `max_epochs=1` reached.


In [166]:
outputs = np.array([[5876, 5876, 5876, 5876, 5876, 2376, 2376, 2376, 2376, 2376, 2376, 2376,
         2376, 2376, 5394, 2376, 2376, 2376, 5394,  547, 5394,  547,  547,  547,
          547,  547, 5662,  547,  547,  547,  547,  547,  547,  547,  547,  547,
          547,  547,  547,  547,  547,  547,  547,  547,  547,  547,  547,  547,
          547,  547,  547,  547,  547,  547,  547,  547,  547,  547,  547,  547,
          547,  547,  547,  547,  547,  547,  547,  547,  547,  547,  547,  547,
          547,  547,  547,  547,  547,  547,  547,  547,  547,  547,  547]])


array([[5876, 5876, 5876, 5876, 5876, 2376, 2376, 2376, 2376, 2376, 2376,
        2376, 2376, 2376, 5394, 2376, 2376, 2376, 5394,  547, 5394,  547,
         547,  547,  547,  547, 5662,  547,  547,  547,  547,  547,  547,
         547,  547,  547,  547,  547,  547,  547,  547,  547,  547,  547,
         547,  547,  547,  547,  547,  547,  547,  547,  547,  547,  547,
         547,  547,  547,  547,  547,  547,  547,  547,  547,  547,  547,
         547,  547,  547,  547,  547,  547,  547,  547,  547,  547,  547,
         547,  547,  547,  547,  547,  547]])

In [None]:

        shift_logits = outputs[1][..., :-1, :].contiguous().argmax(
            dim=-1).cpu()  # Ensure outputs and argmax result are on CPU
        if verbose:
            print('model predict?')
            print(shift_logits)

        # Assuming 'labels' is a key in batch containing true token IDs
        shift_labels = batch['labels'][..., 1:].contiguous().cpu()  # Move labels to CPU
        if verbose:
            print('model true?')
            print(shift_logits)

        non_padding_mask = shift_labels != -100

        # Compare predictions to true labels, but only for non-padding tokens
        acc_train = ((shift_logits == shift_labels) & non_padding_mask).sum().item() / non_padding_mask.sum().item()
        if verbose:
            print(acc_train)

In [None]:
# batch indices
with open('/data/rozen/home/e0833634/lama/protllama/batch_script/train_intermediate_checkpoint_batches_1000000.pkl', 'rb') as f:
    batch_indices_train = pickle.load(f)
# batch indices
with open('/data/rozen/home/e0833634/lama/protllama/batch_script/valid_intermediate_checkpoint_batches_1000000.pkl', 'rb') as f:
    batch_indices_valid = pickle.load(f)

In [97]:
len(batch_indices_train)

253440

In [98]:
len(batch_indices_train)

253440

In [99]:
# build smaller test set
small_batches = [batch for batch in batch_indices_train if all(idx < 1000 for idx in batch)]
small_batches

[[181],
 [816],
 [327],
 [790],
 [227],
 [896],
 [125],
 [435, 618, 756, 881],
 [55, 266],
 [325, 345],
 [823],
 [895],
 [975],
 [643],
 [18],
 [471],
 [558, 647, 653, 681, 697, 764, 864, 882, 917],
 [24],
 [17],
 [35],
 [108, 150, 194, 300, 368, 672],
 [609],
 [955],
 [755],
 [855],
 [617],
 [122, 298],
 [829],
 [51],
 [487, 657],
 [490],
 [365],
 [976],
 [470, 775],
 [730],
 [85, 322],
 [182],
 [530, 541, 543, 548, 712, 785],
 [151, 216, 230, 483, 549, 608, 711, 760],
 [447],
 [250],
 [234],
 [269, 443, 452, 515],
 [313, 956],
 [691],
 [102],
 [935],
 [584, 592],
 [614, 640, 934],
 [843],
 [926],
 [544, 734, 876],
 [696, 731, 736, 773, 965],
 [101, 255, 312, 680],
 [662],
 [572, 604, 690],
 [699, 827, 922],
 [723, 873, 905],
 [212, 219, 655],
 [62],
 [892],
 [299],
 [283],
 [115],
 [840],
 [98],
 [87],
 [0],
 [25],
 [480],
 [522],
 [251],
 [692],
 [525, 605, 978],
 [792],
 [256],
 [163, 781],
 [114],
 [126, 382],
 [82],
 [4, 80, 103],
 [113, 143, 274],
 [794],
 [364],
 [76, 204, 286,

In [None]:
batch_indices = small_batches
train_dataset = CustomDataset(dataset, batch_indices)
train_dataloader = DataLoader(train_dataset, batch_size=None, shuffle=False)  # Note that batch_size is None because your dataset itself is returning batches

In [107]:
train_dataset

KeyboardInterrupt: 

In [103]:
# Flatten the batch indices to select the data from the original dataset
flat_indices = [idx for batch in small_batches for idx in batch]

# Create the shrunken dataset using the flattened indices
shrunken_dataset_train = dataset["train"].select(flat_indices)

# Create a mapping from original indices to their position in the shrunken dataset
index_mapping = {original_idx: shrunken_idx for shrunken_idx, original_idx in enumerate(flat_indices)}

# Define the custom collate function
def collate_fn(batch_indices_list):
    # Convert the original indices to shrunken dataset indices
    mapped_indices = [index_mapping[idx] for idx_list in batch_indices_list for idx in idx_list]

    batch_ = {
        'attention_mask': [shrunken_dataset_train['attention_mask'][i] for i in mapped_indices],
        'input_ids': [shrunken_dataset_train['input_ids'][i] for i in mapped_indices],
        'labels': [shrunken_dataset_train['labels'][i] for i in mapped_indices]
    }
    return batch_

# Custom Dataset to wrap the batch indices
class BatchedDataset(Dataset):
    def __init__(self, batch_indices):
        self.batch_indices = batch_indices

    def __len__(self):
        return len(self.batch_indices)

    def __getitem__(self, idx):
        return self.batch_indices[idx]

# DataLoader
dataloader = DataLoader(BatchedDataset(small_batches), batch_size=1, shuffle=False, collate_fn=collate_fn)

# Usage
for batch in dataloader:
    print(batch)

TypeError: list indices must be integers or slices, not list

In [72]:
len(small_batches)

205

In [51]:
flattened_list = [num for sublist in small_batches for num in sublist]
mapping_dict = {num: idx for idx, num in enumerate(flattened_list)}
mapping_dict

{181: 0,
 816: 1,
 327: 2,
 790: 3,
 227: 4,
 896: 5,
 125: 6,
 435: 7,
 618: 8,
 756: 9,
 881: 10,
 55: 11,
 266: 12,
 325: 13,
 345: 14,
 823: 15,
 895: 16,
 975: 17,
 643: 18,
 18: 19,
 471: 20,
 558: 21,
 647: 22,
 653: 23,
 681: 24,
 697: 25,
 764: 26,
 864: 27,
 882: 28,
 917: 29,
 24: 30,
 17: 31,
 35: 32,
 108: 33,
 150: 34,
 194: 35,
 300: 36,
 368: 37,
 672: 38,
 609: 39,
 955: 40,
 755: 41,
 855: 42,
 617: 43,
 122: 44,
 298: 45,
 829: 46,
 51: 47,
 487: 48,
 657: 49,
 490: 50,
 365: 51,
 976: 52,
 470: 53,
 775: 54,
 730: 55,
 85: 56,
 322: 57,
 182: 58,
 530: 59,
 541: 60,
 543: 61,
 548: 62,
 712: 63,
 785: 64,
 151: 65,
 216: 66,
 230: 67,
 483: 68,
 549: 69,
 608: 70,
 711: 71,
 760: 72,
 447: 73,
 250: 74,
 234: 75,
 269: 76,
 443: 77,
 452: 78,
 515: 79,
 313: 80,
 956: 81,
 691: 82,
 102: 83,
 935: 84,
 584: 85,
 592: 86,
 614: 87,
 640: 88,
 934: 89,
 843: 90,
 926: 91,
 544: 92,
 734: 93,
 876: 94,
 696: 95,
 731: 96,
 736: 97,
 773: 98,
 965: 99,
 101: 100,
 255: 

In [79]:
len(mapping_dict)

368

In [85]:
# build smaller test set
small_batches_valid = [batch for batch in batch_indices_valid if all(idx < 1000 for idx in batch)]
small_batches_valid

[[920],
 [787],
 [475],
 [885],
 [256],
 [374, 469],
 [0],
 [819],
 [64, 410, 491],
 [9],
 [78, 129, 399, 505, 667],
 [697],
 [527],
 [174, 197],
 [561, 564, 682, 738, 852, 856, 957],
 [535],
 [234, 873],
 [861],
 [319, 341, 604, 616, 650],
 [725, 946],
 [128, 607, 829, 967],
 [547],
 [179, 267, 344, 409, 423, 433, 460, 468, 486, 669, 692, 731, 752, 871, 884],
 [921],
 [334, 662, 791, 979],
 [242, 597],
 [960],
 [538],
 [65, 314],
 [329],
 [20],
 [349],
 [514, 892, 912],
 [617],
 [434],
 [168],
 [618, 645],
 [761],
 [259],
 [377],
 [886],
 [777],
 [39],
 [707],
 [769],
 [21, 204, 285, 357, 489, 898, 951],
 [750],
 [257],
 [272],
 [931],
 [250, 336],
 [940],
 [524],
 [690],
 [1],
 [33],
 [325, 508, 649, 658],
 [260],
 [376],
 [360],
 [860],
 [798],
 [571],
 [236, 400, 583, 610],
 [258],
 [843, 923],
 [466],
 [663],
 [86, 88, 211, 387],
 [574],
 [316],
 [942],
 [563],
 [762],
 [153, 201, 215, 365, 452, 529, 606, 894, 917],
 [596],
 [366, 455, 502],
 [117],
 [266, 503],
 [331],
 [708],
 [

In [73]:
len(small_batches_valid)

110

In [54]:
flattened_list_valid = [num for sublist in small_batches_valid for num in sublist]
mapping_dict_valid = {num: idx for idx, num in enumerate(flattened_list_valid)}
mapping_dict_valid

{920: 0,
 787: 1,
 475: 2,
 885: 3,
 256: 4,
 374: 5,
 469: 6,
 0: 7,
 819: 8,
 64: 9,
 410: 10,
 491: 11,
 9: 12,
 78: 13,
 129: 14,
 399: 15,
 505: 16,
 667: 17,
 697: 18,
 527: 19,
 174: 20,
 197: 21,
 561: 22,
 564: 23,
 682: 24,
 738: 25,
 852: 26,
 856: 27,
 957: 28,
 535: 29,
 234: 30,
 873: 31,
 861: 32,
 319: 33,
 341: 34,
 604: 35,
 616: 36,
 650: 37,
 725: 38,
 946: 39,
 128: 40,
 607: 41,
 829: 42,
 967: 43,
 547: 44,
 179: 45,
 267: 46,
 344: 47,
 409: 48,
 423: 49,
 433: 50,
 460: 51,
 468: 52,
 486: 53,
 669: 54,
 692: 55,
 731: 56,
 752: 57,
 871: 58,
 884: 59,
 921: 60,
 334: 61,
 662: 62,
 791: 63,
 979: 64,
 242: 65,
 597: 66,
 960: 67,
 538: 68,
 65: 69,
 314: 70,
 329: 71,
 20: 72,
 349: 73,
 514: 74,
 892: 75,
 912: 76,
 617: 77,
 434: 78,
 168: 79,
 618: 80,
 645: 81,
 761: 82,
 259: 83,
 377: 84,
 886: 85,
 777: 86,
 39: 87,
 707: 88,
 769: 89,
 21: 90,
 204: 91,
 285: 92,
 357: 93,
 489: 94,
 898: 95,
 951: 96,
 750: 97,
 257: 98,
 272: 99,
 931: 100,
 250: 101

In [56]:
mapping_dict_valid[787]

1

In [81]:
for batch in small_batches:
    print(dataset["train"].select(batch))
    print(batch)

Dataset({
    features: ['attention_mask', 'input_ids', 'labels'],
    num_rows: 1
})
[181]
Dataset({
    features: ['attention_mask', 'input_ids', 'labels'],
    num_rows: 1
})
[816]
Dataset({
    features: ['attention_mask', 'input_ids', 'labels'],
    num_rows: 1
})
[327]
Dataset({
    features: ['attention_mask', 'input_ids', 'labels'],
    num_rows: 1
})
[790]
Dataset({
    features: ['attention_mask', 'input_ids', 'labels'],
    num_rows: 1
})
[227]
Dataset({
    features: ['attention_mask', 'input_ids', 'labels'],
    num_rows: 1
})
[896]
Dataset({
    features: ['attention_mask', 'input_ids', 'labels'],
    num_rows: 1
})
[125]
Dataset({
    features: ['attention_mask', 'input_ids', 'labels'],
    num_rows: 4
})
[435, 618, 756, 881]
Dataset({
    features: ['attention_mask', 'input_ids', 'labels'],
    num_rows: 2
})
[55, 266]
Dataset({
    features: ['attention_mask', 'input_ids', 'labels'],
    num_rows: 2
})
[325, 345]
Dataset({
    features: ['attention_mask', 'input_ids', 

In [74]:
small_datasets_train = [dataset["train"].select(batch) for batch in small_batches]
small_datasets_valid = [dataset["valid"].select(batch) for batch in small_batches_valid]

In [76]:
small_train_dataset = concatenate_datasets(small_datasets_train)

In [77]:

small_valid_dataset = concatenate_datasets(small_datasets_valid)

In [78]:
# Use list comprehension to get all small datasets and then concatenate them
dataset_ = DatasetDict({
            'train': small_train_dataset,
            'valid': small_valid_dataset
        })
dataset_

DatasetDict({
    train: Dataset({
        features: ['attention_mask', 'input_ids', 'labels'],
        num_rows: 368
    })
    valid: Dataset({
        features: ['attention_mask', 'input_ids', 'labels'],
        num_rows: 248
    })
})

In [26]:
small_batches[0]

[181]

In [29]:
dataset_['train']['original_index']

[181,
 816,
 327,
 790,
 227,
 896,
 125,
 435,
 618,
 756,
 881,
 55,
 266,
 325,
 345,
 823,
 895,
 975,
 643,
 18,
 471,
 558,
 647,
 653,
 681,
 697,
 764,
 864,
 882,
 917,
 24,
 17,
 35,
 108,
 150,
 194,
 300,
 368,
 672,
 609,
 955,
 755,
 855,
 617,
 122,
 298,
 829,
 51,
 487,
 657,
 490,
 365,
 976,
 470,
 775,
 730,
 85,
 322,
 182,
 530,
 541,
 543,
 548,
 712,
 785,
 151,
 216,
 230,
 483,
 549,
 608,
 711,
 760,
 447,
 250,
 234,
 269,
 443,
 452,
 515,
 313,
 956,
 691,
 102,
 935,
 584,
 592,
 614,
 640,
 934,
 843,
 926,
 544,
 734,
 876,
 696,
 731,
 736,
 773,
 965,
 101,
 255,
 312,
 680,
 662,
 572,
 604,
 690,
 699,
 827,
 922,
 723,
 873,
 905,
 212,
 219,
 655,
 62,
 892,
 299,
 283,
 115,
 840,
 98,
 87,
 0,
 25,
 480,
 522,
 251,
 692,
 525,
 605,
 978,
 792,
 256,
 163,
 781,
 114,
 126,
 382,
 82,
 4,
 80,
 103,
 113,
 143,
 274,
 794,
 364,
 76,
 204,
 286,
 547,
 629,
 613,
 997,
 479,
 718,
 741,
 783,
 521,
 120,
 237,
 306,
 500,
 197,
 223,
 437,
 31,

In [28]:
for i in dataset_['train']['original_index']:
    if i == 819:
        print(dataset_['train']['input_ids'])

In [48]:
small_batches[0]

[181]

In [50]:
mapping_table[181]

725

In [47]:
original_indices = [mapping_table[idx] for idx in small_batches[0]]
original_indices

[725]

In [64]:
original_indices = [mapping_dict[idx] for idx in small_batches[1]]
original_indices

[1]

In [65]:
small_batches[1]

[816]

In [61]:
mapping_dict[181]

0

In [70]:
len(dataset_['train']['attention_mask'][0])

11

In [None]:
[dataset_[split_name]['input_ids'][i] for i in small_batches[0]]

In [66]:
def collate_fn(batch_indices, split_name):
    # Use the mapping table to get the original indices
    mapped_indices = [mapping_dict[idx] for idx in batch_indices[0]]

    batch_ = {
        'attention_mask': [dataset_[split_name]['attention_mask'][i] for i in mapped_indices],
        'input_ids': [dataset_[split_name]['input_ids'][i] for i in mapped_indices],
        'labels': [dataset_[split_name]['labels'][i] for i in mapped_indices]
    }
    return batch_

In [86]:
def collate_fn(batch_indices, split_name):
        # Given a list of indices, retrieve the corresponding batch from your HuggingFace Dataset
    batch_ = {
        'attention_mask': [dataset[split_name]['attention_mask'][i] for i in batch_indices[0]],
        'input_ids': [dataset[split_name]['input_ids'][i] for i in batch_indices[0]],
        'labels': [dataset[split_name]['labels'][i] for i in batch_indices[0]]
    }
    return batch_

In [92]:
dataset['train']['input_ids'][435]

[[1,
  404,
  5798,
  453,
  1204,
  376,
  1015,
  2602,
  1296,
  2756,
  3843,
  7725,
  2407,
  817,
  6581,
  2230,
  422,
  6352,
  885,
  6007,
  675,
  1421,
  7029,
  1524,
  5810,
  6630,
  1354,
  5256,
  4,
  1522,
  754,
  1211,
  2065,
  5927,
  378,
  648,
  3065,
  2837,
  1133,
  3620,
  3543,
  1023,
  5723,
  1538,
  1688,
  97,
  461,
  842,
  5150,
  2267,
  3944,
  2036,
  4756,
  25,
  572,
  6245],
 [1,
  260,
  38,
  3356,
  1121,
  4169,
  600,
  1744,
  3275,
  778,
  479,
  796,
  3,
  3425,
  2062,
  6486,
  6838,
  3822,
  1074,
  3752,
  5262,
  7456,
  6896,
  1203,
  3211,
  2529,
  6704,
  1300,
  7382,
  749,
  2155,
  4669,
  2297,
  3726,
  2400,
  580,
  3565,
  2068,
  140,
  1545,
  3300,
  350,
  1920,
  7455,
  170,
  61,
  525,
  597,
  3127,
  1688,
  2077,
  352,
  1290,
  2318,
  56,
  2662],
 [1,
  820,
  602,
  5306,
  141,
  2374,
  996,
  1810,
  2340,
  459,
  697,
  40,
  4029,
  2869,
  492,
  1481,
  1055,
  36,
  4597,
  1258,
  10

In [87]:
from functools import partial
train_dataloader = DataLoader(small_batches, batch_size=1, pin_memory=True, shuffle=False, num_workers=1,
                                    collate_fn=partial(collate_fn, split_name='train'))
train_dataloader

<torch.utils.data.dataloader.DataLoader at 0x2aac87608df0>

In [88]:
for batch in enumerate(train_dataloader):
    print(batch)
    break

(0, {'attention_mask': [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 