In [1]:
%reset -f
import torch

import pickle
from torch.utils.data import DataLoader
from datasets import Dataset, DatasetDict, load_from_disk, concatenate_datasets
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data.distributed import DistributedSampler

In [1]:
from torch.optim.lr_scheduler import _LRScheduler

In [2]:
dataset_path = f'/data/rozen/home/e0833634/lama/protllama/batch_script/uniref50_random90split_8k_512_first_1million_dataset.hf'
dataset = load_from_disk(dataset_path)
dataset

DatasetDict({
    train: Dataset({
        features: ['attention_mask', 'input_ids', 'labels'],
        num_rows: 253440
    })
    valid: Dataset({
        features: ['attention_mask', 'input_ids', 'labels'],
        num_rows: 232711
    })
})

In [3]:
small_dataset_dict = DatasetDict({
    'train': dataset['train'].select(range(10)),
    'valid': dataset['valid'].select(range(10))
})
small_dataset_dict

DatasetDict({
    train: Dataset({
        features: ['attention_mask', 'input_ids', 'labels'],
        num_rows: 10
    })
    valid: Dataset({
        features: ['attention_mask', 'input_ids', 'labels'],
        num_rows: 10
    })
})

In [4]:
small_dataset_dict.save_to_disk('small_small.hf')

Saving the dataset (0/1 shards):   0%|          | 0/10 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/10 [00:00<?, ? examples/s]

In [4]:
with open('/data/rozen/home/e0833634/lama/protllama/batch_script/train_intermediate_checkpoint_batches_1000000.pkl', 'rb') as f:
    batch_indices_train = pickle.load(f)
batch_indices_train = batch_indices_train[:10]
batch_indices_train

[[769188, 769231, 769281, 770055, 770166, 770414],
 [55962],
 [951805, 951834],
 [34696, 36909],
 [39459, 39650, 39713, 39794, 39797, 39858, 40168, 40532, 40713, 40750, 41035],
 [139898, 139951, 140097],
 [223249, 224129],
 [181],
 [227032,
  227315,
  227702,
  227839,
  227841,
  228178,
  228288,
  228420,
  228457,
  228582],
 [94817, 94880]]

In [5]:
with open('/data/rozen/home/e0833634/lama/protllama/batch_script/valid_intermediate_checkpoint_batches_1000000.pkl', 'rb') as f:
    batch_indices_val = pickle.load(f)
batch_indices_val = batch_indices_val[:10]
batch_indices_val

[[178154, 178417],
 [913267, 913283, 913593, 913710, 913882, 913995],
 [419540],
 [291646, 292406],
 [63716, 63773, 63979, 64054, 64094, 64127, 64206, 64264, 64340, 64459, 64520],
 [363967, 364260, 364909],
 [595243, 595280, 595306, 595422, 595542],
 [220497],
 [78772, 79030],
 [689663, 689682]]

In [33]:
from torch.utils.data import Dataset, DataLoader
import torch

class DynamicBatchingDataset(Dataset):
    def __init__(self, dataset_dict, batch_indices):
        print('Initializing dataset...')
        self.dataset_dict = dataset_dict
        self.batch_indices = batch_indices  # This is mainly for informational purposes, if needed.

    def __len__(self):
        return len(self.dataset_dict['attention_mask'])  # Assuming each entry in dataset_dict represents a batch

    def __getitem__(self, idx):
        # Check if idx is an integer or a list
        if isinstance(idx, int):
            indices = [idx]
        else:
            indices = idx

        attention_masks = []
        input_ids = []
        labels = []
        for index in indices:
            attention_masks.append(torch.tensor(self.dataset_dict['attention_mask'][index]))
            input_ids.append(torch.tensor(self.dataset_dict['input_ids'][index]))
            labels.append(torch.tensor(self.dataset_dict['labels'][index]))

        return {
            'attention_mask': attention_masks,
            'input_ids': input_ids,
            'labels': labels
        }


    @staticmethod
    def collate_fn(batch):
        # Since DataLoader's batch_size is 1, batch[0] contains your pre-batched data
        item = batch[0]

        attention_mask = item['attention_mask'][0]
        input_ids = item['input_ids'][0]
        labels = item['labels'][0]

        # These are already pre-padded, so you can directly return
        return {
            'attention_mask': attention_mask,
            'input_ids': input_ids,
            'labels': labels
        }


    @staticmethod
    def dynamic_padding_collate_fn(batch):
        # Extract sequences from the batch
        #print(batch)

        attention_masks = [item['attention_mask'] for item in batch]
        input_ids = [item['input_ids'] for item in batch]
        labels = [item['labels'] for item in batch]
        #print(attention_masks.size())
        # Flatten the sequences and then pad them
        attention_masks_flat = [sequence for batch_item in attention_masks for sequence in batch_item]
        print([seq.shape for seq in attention_masks_flat])
        print(attention_masks_flat)
        input_ids_flat = [sequence for batch_item in input_ids for sequence in batch_item]
        labels_flat = [sequence for batch_item in labels for sequence in batch_item]

        # Pad sequences
        attention_masks_padded = torch.cat(attention_masks_flat, dim=0)
        input_ids_padded = torch.cat(input_ids_flat, dim=0)
        labels_padded = torch.cat(labels_flat, dim=0)

        attention_masks_padded = pad_sequence(attention_masks_flat, batch_first=True)
        input_ids_padded = pad_sequence(input_ids_flat, batch_first=True)
        labels_padded = pad_sequence(labels_flat, batch_first=True, padding_value=-100)

        return {
            'attention_mask': attention_masks_padded,
            'input_ids': input_ids_padded,
            'labels': labels_padded
        }

train_dataset = DynamicBatchingDataset(small_dataset_dict['train'], batch_indices_train)

# Batch size is set to 1 because your dataset itself is returning batches
#train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=False, collate_fn=DynamicBatchingDataset.custom_collate_fn)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=False, collate_fn=DynamicBatchingDataset.collate_fn, num_workers=2)

Initializing dataset...


In [34]:
import numpy as np
for idx, batch in enumerate(train_dataloader):
    print(batch['input_ids'].shape)
    print(idx, batch)
    print(batch['input_ids'])

torch.Size([6, 84])
0 {'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1,

In [12]:
import numpy as np
np.shape(batch['attention_mask'])

torch.Size([2, 228])

In [13]:
 def __getitem__(self, idx):
        #batch_idx = self.batch_indices[idx]
        # Directly retrieve the batch using the index
        print(idx, type(idx))
        print(len(self.dataset_dict['attention_mask']))
        #print(self.dataset_dict['attention_mask'])
        print(np.shape(self.dataset_dict['attention_mask'][idx]))
        if torch.distributed.is_initialized():
            if torch.distributed.get_rank() == 0:
                print(f"Process 0 is fetching index: {idx}")
            elif torch.distributed.get_rank() == 1:
                print(f"Process 1 is fetching index: {idx}")
        """returns [seq_number, token_length], return one batch at a time"""
        attention_mask = torch.tensor(self.dataset_dict['attention_mask'][idx])
        input_ids = torch.tensor(self.dataset_dict['input_ids'][idx])
        label = torch.tensor(self.dataset_dict['labels'][idx])

        return {
            'attention_mask': attention_mask,
            'input_ids': input_ids,
            'labels': label
        }

In [14]:
import pytorch_lightning as pl
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM
import sentencepiece as spm
class CustomDataModule(pl.LightningDataModule):
    def __init__(self, dataset_dict, batch_indices_train, batch_indices_val, batch_size=1):
        super().__init__()
        self.dataset_dict = dataset_dict
        self.batch_indices_train = batch_indices_train
        self.batch_indices_val = batch_indices_val
        self.batch_size = batch_size

        self.tokenizer = self.tokenizer_generation('protein', '8k')

    @staticmethod
    def tokenizer_generation(target, vocab_size):
        if target == 'original':
            tokenizer = LlamaTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer')
            tokenizer.pad_token = tokenizer.unk_token
            return tokenizer
        elif target == 'protein':
            tokenizer_path = '/data/rozen/home/e0833634/lama/protllama/batch_script/'
            tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path + "protein_%s.model" % (vocab_size))
            return tokenizer
        else:
            raise ValueError('Have not prepared tokenizer for this target')

    def prepare_data(self):
        # Possibly download data, set transforms, etc.
        pass

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        self.train_dataset = DynamicBatchingDataset(self.dataset_dict['train'], self.batch_indices_train)
        # Repeat similar steps for validation and test datasets if needed
        self.val_dataset = DynamicBatchingDataset(self.dataset_dict['valid'], self.batch_indices_val)

    def train_dataloader(self):
        print("dataloader created...")
        d = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2, collate_fn=DynamicBatchingDataset.collate_fn)
        if self.trainer.global_rank == 0:
            for idx, batch in enumerate(d):
                print(idx, batch)
        return d

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2, collate_fn=DynamicBatchingDataset.collate_fn)

  warn(f"Failed to load image Python extension: {e}")
2023-10-16 10:22:22.895647: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [17]:
import pytorch_lightning as pl
import transformers
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from transformers import get_cosine_schedule_with_warmup
import os
import logging as log
import glob
from argparse import ArgumentParser
from protllama.bin.data import PretrainDataset
import torch
class pretrainLlama(pl.LightningModule):
    def __init__(self, hparam) -> None:
        super(pretrainLlama, self).__init__()
        self.save_hyperparameters()
        self.hparam = hparam  # need to contain epoch, target, date, learning rate, batch_size, num_frozen_epochs
        self.MODEL_CONFIGS = self.retrieve_config()
        self.__build_model()
        self.tokenizer = self.tokenizer_generation('protein', '8k')

    @staticmethod
    def tokenizer_generation(target, vocab_size):
        if target == 'original':
            tokenizer = LlamaTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer')
            tokenizer.pad_token = tokenizer.unk_token
            return tokenizer
        elif target == 'protein':
            tokenizer_path = '/data/rozen/home/e0833634/lama/protllama/batch_script/'
            tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path + "protein_%s.model" % (vocab_size))
            return tokenizer
        else:
            raise ValueError('Have not prepared tokenizer for this target')

    def retrieve_config(self):
        """ return transformers DATASET object"""
        if self.hparam.target == 'original':
            config_dict = {'7b': LlamaConfig(max_position_embeddings=self.hparam.max_position_embeddings,
                                             hidden_size=self.hparam.hidden_size,
                                             intermediate_size=self.hparam.intermediate_size)}
            return config_dict['7b']
        elif self.hparam.target == 'protein':
            config_dict = {
                'protllama2': LlamaConfig(max_position_embeddings=self.hparam.max_position_embeddings,  # maximum length
                                          hidden_size=self.hparam.hidden_size,
                                          transformers_version=transformers.__version__,
                                          intermediate_size=self.hparam.intermediate_size,
                                          vocab_size=int(self.hparam.vocab_size.rstrip('k')) * 1000)}
            #print(config_dict['protllama2'])
            return config_dict['protllama2']
        else:
            raise ValueError('Have not prepared dataset for this target')

    def __build_model(self) -> None:
        """start model building, can add customized classification head"""
        self.model = LlamaForCausalLM(self.MODEL_CONFIGS)
        #print(self.model.lm_head.weight)

    def configure_optimizers(self):
        """set learning rates"""
        if self.hparam.scheduler == 'linear':
            parameters = self.model.parameters()
            optimizer = AdamW(parameters, lr=self.hparam.learning_rate, betas=(0.9, 0.95), weight_decay=0.1)
            lr_schedulers = {
                "scheduler": get_linear_schedule_with_warmup(optimizer,
                                                            num_warmup_steps=100,
                                                            num_training_steps=self.hparam.epoch * self.hparam.train_dataset_length),
                "name": 'learning_rate_logs'
            }
            return [optimizer], [lr_schedulers]

    def forward(self, **inputs):
        """ Pytorch forward function
        Returns:
        dict with model outputs (loss, logits, hidden layer, attention)
        """
        return self.model(**inputs)

    def training_step(self, batch, batch_nb: int, verbose=False):
        print(batch.keys())
        if torch.distributed.is_initialized():
            print(f"Process {torch.distributed.get_rank()} starting training step")
        print(batch)
        print(batch['input_ids'].shape)
        outputs = self.forward(**batch)
        loss_train = outputs[0]

        # Compute the perplexity
        perplexity = torch.exp(outputs[0].cpu())  # Ensure outputs are on CPU

        # Accuracy computation
        # Shifting
        shift_logits = outputs[1][..., :-1, :].contiguous().argmax(
            dim=-1).cpu()  # Ensure outputs and argmax result are on CPU

        # Assuming 'labels' is a key in batch containing true token IDs
        shift_labels = batch['labels'][..., 1:].contiguous().cpu()  # Move labels to CPU

        non_padding_mask = shift_labels != -100
        # Compare predictions to true labels, but only for non-padding tokens
        acc_train = ((shift_logits == shift_labels) & non_padding_mask).sum().item() / non_padding_mask.sum().item()

        print('train', loss_train, perplexity, acc_train)

        # Log
        self.log('train_loss', loss_train, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log('train_perplexity', perplexity, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log('train_accuracy', acc_train, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

        return loss_train

    def validation_step(self, batch, batch_nb: int, verbose=False):
        """ Similar to the training step but with the model in eval mode.
        Returns:
            - dictionary passed to the validation_end function.
        """
        outputs = self.forward(**batch)
        loss_val = outputs[0].cpu()

        # Compute the perplexity
        perplexity = torch.exp(loss_val)  # Ensure outputs are on CPU

        # Accuracy computation
        # Shifting
        shift_logits = outputs[1][..., :-1, :].contiguous().argmax(
            dim=-1).cpu()  # Ensure outputs and argmax result are on CPU

        # Assuming 'labels' is a key in batch containing true token IDs
        shift_labels = batch['labels'][..., 1:].contiguous().cpu()  # Move labels to CPU

        non_padding_mask = shift_labels != -100

        # Compare predictions to true labels, but only for non-padding tokens
        acc_val = ((shift_logits == shift_labels) & non_padding_mask).sum().item() / non_padding_mask.sum().item()

        print('val', loss_val, perplexity, acc_val)
        # Log
        self.log('val_loss', loss_val, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log('val_perplexity', perplexity, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log('val_accuracy', acc_val, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        return loss_val


In [18]:
import argparse
import os
from protllama.bin.data import PretrainDataset
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger

# set wandb offline on HPC
os.environ['WANDB_MODE'] = "offline"

from types import SimpleNamespace

hparam = SimpleNamespace(
    date='Oct_11',
    target='protein',
    max_position_embeddings=512,
    vocab_size='8k',
    hidden_size=640,
    intermediate_size=1720,
    save_top_k=1,
    scheduler='linear',
    learning_rate=3e-4,
    epoch=1
    #... add all other arguments similarly
)


#dm = PretrainDataset(target=hparam.target,
                     #max_sequence_length=hparam.max_position_embeddings)
data_module = CustomDataModule(dataset_dict=small_dataset_dict,
                               batch_indices_train=batch_indices_train,
                               batch_indices_val=batch_indices_val,
                               batch_size=1)

# make sure dataset has "training" key
hparam.train_dataset_length = 10
training_log_path = str('protllama/pl_logs/')
if not os.path.exists(training_log_path):
    os.makedirs(training_log_path)
logger = WandbLogger(project="protllama2",
                     name=f"{hparam.target}_{hparam.date}_pre-training_log", #display on the web
                     save_dir='protllama/pl_logs/',
                     job_type='model-training',
                     group=f'pretrain_protllama2_{hparam.vocab_size}_{hparam.max_position_embeddings}',
                     id='version_%s' % str(1))
seed_everything(42)
model = pretrainLlama(hparam)
early_stop_callback = EarlyStopping(
    monitor="loss",
    min_delta=0.0,
    patience=0,  # number of epoch with no improvement
    verbose=True,
    mode="min",
)
training_model_path = str('protllama/pl_model_cache/')
if not os.path.exists(training_model_path):
    os.makedirs(training_model_path)
checkpoint_callback = ModelCheckpoint(
    dirpath=training_model_path,
    filename="{epoch}-{train_loss:.2f}-{val_loss:.2f}-%s_%s_%s_%s" % (hparam.target, hparam.date, hparam.vocab_size, hparam.max_position_embeddings),
    save_top_k=hparam.save_top_k,
    verbose=True,
    monitor="val_loss",
    mode="min",
)
lr_monitor = LearningRateMonitor(
    logging_interval='epoch'
)
#os.environ["NCCL_SOCKET_IFNAME"] = "lo"  # Or another interface if not eth0
os.environ['CUDA_VISIBLE_DEVICES']="0,1"
os.environ["NCCL_IB_DISABLE"] = "1"
os.environ["NCCL_SOCKET_IFNAME"] = "eth0"
os.environ["NCCL_DEBUG"] = "INFO"
os.environ["NCCL_PROTO"] = "TCP"
os.environ["NCCL_BLOCKING_WAIT"] = "1"
os.environ["NCCL_DEBUG_SUBSYS"] = "ALL"

print("About to initialize the Trainer...")
trainer = Trainer(
    devices=1,
    accelerator='gpu',
    #strategy='ddp_notebook',
    num_nodes=1,
    fast_dev_run=True,
    limit_train_batches=2,
    max_epochs=1,
    logger=logger,
    # max_epochs=1,
    # min_epochs=1,
    #callbacks=[TQDMProgressBar(refresh_rate=10), lr_monitor],
    #deterministic=True,
    enable_model_summary=True
)
print("Trainer initialized.")
torch.set_float32_matmul_precision('medium')
# automatic garbage collection
import gc
gc.collect()

if torch.distributed.is_initialized():
    print(f"Process {torch.distributed.get_rank()} initialized")

trainer.fit(model, datamodule=data_module)

  rank_zero_warn(
Global seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.


About to initialize the Trainer...
Trainer initialized.


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type             | Params
-------------------------------------------
0 | model | LlamaForCausalLM | 168 M 
-------------------------------------------
168 M     Trainable params
0         Non-trainable params
168 M     Total params
673.549   Total estimated model params size (MB)


Initializing dataset...
Initializing dataset...
dataloader created...
0 {'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

Training: 0it [00:00, ?it/s]

dict_keys(['attention_mask', 'input_ids', 'labels'])
{'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_steps=1` reached.


val tensor(9.1401) tensor(9321.6484) 0.0


In [None]:
outputs = np.array([[5876, 5876, 5876, 5876, 5876, 2376, 2376, 2376, 2376, 2376, 2376, 2376,
         2376, 2376, 5394, 2376, 2376, 2376, 5394,  547, 5394,  547,  547,  547,
          547,  547, 5662,  547,  547,  547,  547,  547,  547,  547,  547,  547,
          547,  547,  547,  547,  547,  547,  547,  547,  547,  547,  547,  547,
          547,  547,  547,  547,  547,  547,  547,  547,  547,  547,  547,  547,
          547,  547,  547,  547,  547,  547,  547,  547,  547,  547,  547,  547,
          547,  547,  547,  547,  547,  547,  547,  547,  547,  547,  547]])


In [None]:

        shift_logits = outputs[1][..., :-1, :].contiguous().argmax(
            dim=-1).cpu()  # Ensure outputs and argmax result are on CPU
        if verbose:
            print('model predict?')
            print(shift_logits)

        # Assuming 'labels' is a key in batch containing true token IDs
        shift_labels = batch['labels'][..., 1:].contiguous().cpu()  # Move labels to CPU
        if verbose:
            print('model true?')
            print(shift_logits)

        non_padding_mask = shift_labels != -100

        # Compare predictions to true labels, but only for non-padding tokens
        acc_train = ((shift_logits == shift_labels) & non_padding_mask).sum().item() / non_padding_mask.sum().item()
        if verbose:
            print(acc_train)

In [None]:
# batch indices
with open('/data/rozen/home/e0833634/lama/protllama/batch_script/train_intermediate_checkpoint_batches_1000000.pkl', 'rb') as f:
    batch_indices_train = pickle.load(f)
# batch indices
with open('/data/rozen/home/e0833634/lama/protllama/batch_script/valid_intermediate_checkpoint_batches_1000000.pkl', 'rb') as f:
    batch_indices_valid = pickle.load(f)

In [None]:
len(batch_indices_train)

In [None]:
len(batch_indices_train)

In [None]:
# build smaller test set
small_batches = [batch for batch in batch_indices_train if all(idx < 1000 for idx in batch)]
small_batches

In [None]:
batch_indices = small_batches
train_dataset = CustomDataset(dataset, batch_indices)
train_dataloader = DataLoader(train_dataset, batch_size=None, shuffle=False)  # Note that batch_size is None because your dataset itself is returning batches

In [None]:
train_dataset

In [None]:
# Flatten the batch indices to select the data from the original dataset
flat_indices = [idx for batch in small_batches for idx in batch]

# Create the shrunken dataset using the flattened indices
shrunken_dataset_train = dataset["train"].select(flat_indices)

# Create a mapping from original indices to their position in the shrunken dataset
index_mapping = {original_idx: shrunken_idx for shrunken_idx, original_idx in enumerate(flat_indices)}

# Define the custom collate function
def collate_fn(batch_indices_list):
    # Convert the original indices to shrunken dataset indices
    mapped_indices = [index_mapping[idx] for idx_list in batch_indices_list for idx in idx_list]

    batch_ = {
        'attention_mask': [shrunken_dataset_train['attention_mask'][i] for i in mapped_indices],
        'input_ids': [shrunken_dataset_train['input_ids'][i] for i in mapped_indices],
        'labels': [shrunken_dataset_train['labels'][i] for i in mapped_indices]
    }
    return batch_

# Custom Dataset to wrap the batch indices
class BatchedDataset(Dataset):
    def __init__(self, batch_indices):
        self.batch_indices = batch_indices

    def __len__(self):
        return len(self.batch_indices)

    def __getitem__(self, idx):
        return self.batch_indices[idx]

# DataLoader
dataloader = DataLoader(BatchedDataset(small_batches), batch_size=1, shuffle=False, collate_fn=collate_fn)

# Usage
for batch in dataloader:
    print(batch)

In [None]:
len(small_batches)

In [None]:
flattened_list = [num for sublist in small_batches for num in sublist]
mapping_dict = {num: idx for idx, num in enumerate(flattened_list)}
mapping_dict

In [None]:
len(mapping_dict)

In [None]:
# build smaller test set
small_batches_valid = [batch for batch in batch_indices_valid if all(idx < 1000 for idx in batch)]
small_batches_valid

In [None]:
len(small_batches_valid)

In [None]:
flattened_list_valid = [num for sublist in small_batches_valid for num in sublist]
mapping_dict_valid = {num: idx for idx, num in enumerate(flattened_list_valid)}
mapping_dict_valid

In [None]:
mapping_dict_valid[787]

In [None]:
for batch in small_batches:
    print(dataset["train"].select(batch))
    print(batch)

In [None]:
small_datasets_train = [dataset["train"].select(batch) for batch in small_batches]
small_datasets_valid = [dataset["valid"].select(batch) for batch in small_batches_valid]

In [None]:
small_train_dataset = concatenate_datasets(small_datasets_train)

In [None]:

small_valid_dataset = concatenate_datasets(small_datasets_valid)

In [None]:
# Use list comprehension to get all small datasets and then concatenate them
dataset_ = DatasetDict({
            'train': small_train_dataset,
            'valid': small_valid_dataset
        })
dataset_

In [None]:
small_batches[0]

In [None]:
dataset_['train']['original_index']

In [None]:
for i in dataset_['train']['original_index']:
    if i == 819:
        print(dataset_['train']['input_ids'])

In [None]:
small_batches[0]

In [None]:
mapping_table[181]

In [None]:
original_indices = [mapping_table[idx] for idx in small_batches[0]]
original_indices

In [None]:
original_indices = [mapping_dict[idx] for idx in small_batches[1]]
original_indices

In [None]:
small_batches[1]

In [None]:
mapping_dict[181]

In [None]:
len(dataset_['train']['attention_mask'][0])

In [None]:
[dataset_[split_name]['input_ids'][i] for i in small_batches[0]]

In [None]:
def collate_fn(batch_indices, split_name):
    # Use the mapping table to get the original indices
    mapped_indices = [mapping_dict[idx] for idx in batch_indices[0]]

    batch_ = {
        'attention_mask': [dataset_[split_name]['attention_mask'][i] for i in mapped_indices],
        'input_ids': [dataset_[split_name]['input_ids'][i] for i in mapped_indices],
        'labels': [dataset_[split_name]['labels'][i] for i in mapped_indices]
    }
    return batch_

In [None]:
def collate_fn(batch_indices, split_name):
        # Given a list of indices, retrieve the corresponding batch from your HuggingFace Dataset
    batch_ = {
        'attention_mask': [dataset[split_name]['attention_mask'][i] for i in batch_indices[0]],
        'input_ids': [dataset[split_name]['input_ids'][i] for i in batch_indices[0]],
        'labels': [dataset[split_name]['labels'][i] for i in batch_indices[0]]
    }
    return batch_

In [None]:
dataset['train']['input_ids'][435]

In [None]:
from functools import partial
train_dataloader = DataLoader(small_batches, batch_size=1, pin_memory=True, shuffle=False, num_workers=1,
                                    collate_fn=partial(collate_fn, split_name='train'))
train_dataloader

In [None]:
for batch in enumerate(train_dataloader):
    print(batch)
    break