In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig, AutoTokenizer
from datasets import load_dataset
from omegaconf import OmegaConf, DictConfig
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import os
import copy 
from transformers import AutoModelForSequenceClassification

In [2]:
class EMA:
    """
    Modified version of class fairseq.models.ema.EMAModule.

    Args:
        model (nn.Module):
        cfg (DictConfig):
        device (str):
        skip_keys (list): The keys to skip assigning averaged weights to.
    """

    def __init__(self, model: nn.Module, cfg, skip_keys=None):
        self.model = self.deepcopy_model(model)
        self.model.requires_grad_(False)
        self.cfg = cfg
        self.device = cfg.device
        self.model.to(self.device)
        self.skip_keys = skip_keys or set()
        self.decay = self.cfg.model.ema_decay
        self.num_updates = 0

    @staticmethod
    def deepcopy_model(model):
        try:
            model = copy.deepcopy(model)
            return model
        except RuntimeError:
            tmp_path = 'tmp_model_for_ema_deepcopy.pt'
            torch.save(model, tmp_path)
            model = torch.load(tmp_path)
            os.remove(tmp_path)
            return model

    def step(self, new_model: nn.Module):
        """
        One EMA step

        Args:
            new_model (nn.Module): Online model to fetch new weights from

        """
        ema_state_dict = {}
        ema_params = self.model.state_dict()
        for key, param in new_model.state_dict().items():
            ema_param = ema_params[key].float()
            if key in self.skip_keys:
                ema_param = param.to(dtype=ema_param.dtype).clone()
            else:
                ema_param.mul_(self.decay)
                ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1 - self.decay)
            ema_state_dict[key] = ema_param
        self.model.load_state_dict(ema_state_dict, strict=False)
        self.num_updates += 1

    def restore(self, model: nn.Module):
        """
        Reassign weights from another model

        Args:
            model (nn.Module): model to load weights from.

        Returns:
            model with new weights
        """
        d = self.model.state_dict()
        model.load_state_dict(d, strict=False)
        return model

    def state_dict(self):
        return self.model.state_dict()

    @staticmethod
    def get_annealed_rate(start, end, curr_step, total_steps):
        """
        Calculate EMA annealing rate
        """
        r = end - start
        pct_remaining = 1 - curr_step / total_steps
        return end - r * pct_remaining

In [3]:
class Data2Vec(nn.Module):
    """
    Data2Vec main module.

    Args:
         encoder (nn.Module): The encoder module like BEiT, ViT, etc.
         cfg (omegaconf.DictConfig): The config containing model properties
    """
    MODALITIES = ['vision', 'text', 'audio']

    def __init__(self, encoder, cfg, **kwargs):
        super(Data2Vec, self).__init__()
        self.modality = cfg.modality
        self.embed_dim = cfg.model.embed_dim
        self.encoder = encoder
        self.__dict__.update(kwargs)

        self.cfg = cfg
        self.ema = EMA(self.encoder, cfg)  # EMA acts as the teacher
        self.regression_head = self._build_regression_head()

        self.cfg = cfg
        self.ema_decay = self.cfg.model.ema_decay
        self.ema_end_decay = self.cfg.model.ema_end_decay
        self.ema_anneal_end_step = self.cfg.model.ema_anneal_end_step

    def _build_regression_head(self):
        """
        Construct the regression head consisting of linear and activation layers.

        Each modality might have its own regression block.

        Returns:
            A nn.Module layer or block of layers
        """
        if self.modality == 'text':
            return nn.Sequential(nn.Linear(self.embed_dim, self.embed_dim * 2),
                                 nn.GELU(),
                                 nn.Linear(self.embed_dim * 2, self.embed_dim))

        if self.modality in ['audio', 'vision']:
            return nn.Linear(self.embed_dim, self.embed_dim)

    def ema_step(self):
        """
        One EMA step for the offline model until the ending decay value is reached
        """
        if self.ema_decay != self.ema_end_decay:
            if self.ema.num_updates >= self.ema_anneal_end_step:
                decay = self.ema_end_decay
            else:
                decay = self.ema.get_annealed_rate(
                    self.ema_decay,
                    self.ema_end_decay,
                    self.ema.num_updates,
                    self.ema_anneal_end_step,
                )
            self.ema.decay = decay
        if self.ema.decay < 1:
            self.ema.step(self.encoder)

    def forward(self, src, trg=None, mask=None, **kwargs):
        """
        Data2Vec forward method.

        Args:
            src: src tokens (masked inputs for training)
            trg: trg tokens (unmasked inputs for training but left as `None` otherwise)
            mask: bool masked indices, Note: if a modality requires the inputs to be masked before forward this param
            has no effect. (see the Encoder for each modality to see if it uses mask or not)

        Returns:
            Either encoder outputs or a tuple of encoder + EMA outputs

        """
        # model forward in online mode (student)
        x = self.encoder(src, mask, **kwargs)['encoder_out']  # fetch the last layer outputs
        if trg is None:
            return x

        # model forward in offline mode (teacher)
        with torch.no_grad():
            self.ema.model.eval()
            y = self.ema.model(trg, ~mask, **kwargs)['encoder_states']  # fetch the last transformer layers outputs
            y = y[-self.cfg.model.average_top_k_layers:]  # take the last k transformer layers

            # Follow the same layer normalization procedure for text and vision
            if self.modality in ['vision', 'text']:
                y = [F.layer_norm(tl.float(), tl.shape[-1:]) for tl in y]
                y = sum(y) / len(y)
                if self.cfg.model.normalize_targets:
                    y = F.layer_norm(y.float(), y.shape[-1:])

            # Use instance normalization for audio
            elif self.modality == 'audio':
                y = [F.instance_norm(tl.float()) for tl in y]
                y = sum(y) / len(y)
                if self.cfg.model.normalize_targets:
                    y = F.instance_norm(y.transpose(1, 2)).transpose(1, 2)

        x = x[mask]
        y = y[mask]

        x = self.regression_head(x)

        return x, y

In [10]:
class WikiText(Dataset):
    """
    A Dataset instance for WikiText dataset loaded from HuggingFace datasets.

    Args:
        cfg (DictConfig): config object
        split: Split to load ['train', 'test']
        tokenizer: A HuggingFace Tokenizer model like BPE
        **kwargs: extra args which are set as dataset properties
    """

    def __init__(self, cfg, split, tokenizer, **kwargs):
        super(WikiText, self).__init__()
        self.cfg = cfg
        self.path = cfg.dataset.name
        self.mlm_probability = cfg.dataset.mlm_probability
        raw_data = load_dataset('wikitext', self.path)[split]
        self.data = self.clean_dataset(raw_data) if self.cfg.dataset.clean_dataset else raw_data
        self.tokenizer = tokenizer
        self.__dict__.update(kwargs)

    def clean_dataset(self, data):
        """
        Cleanup dataset by removing invalid sized samples, etc.
        """
        print('Cleaning dataset ...')
        min_seq_len, max_seq_len = self.cfg.dataset.valid_seq_lenghts
        texts = []
        with tqdm(data, desc='Removing invalid sized inputs: ') as tbar:
            for i, x in enumerate(tbar):
                if len(x['text']) in range(min_seq_len, max_seq_len + 1):
                    texts.append(x)
        return texts

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """
        Only return tokens from raw text with no additions e.g, padding, bos/eos, etc.
        Args:
            index: sample index to pick from dataset

        Returns:
            tokenized outputs
        """
        raw_text = self.data[index]['text']
        tokens = self.tokenizer(raw_text, return_attention_mask=False)
        return tokens

    def _mask_tokens(self, inputs, special_tokens_mask=None):
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Ported
         from `transformers.data.DataCollatorForLanguageModeling.torch_mask_tokens()`
        Args:
            inputs: batch of input tokens
            special_tokens_mask:

        Returns:
            a dict batch of masked and padded inputs/labels

        """
        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in
                labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = self.tokenizer.pad_token_id
        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels, masked_indices

    def collate_fn(self, batch):
        """
        Collate the batch of data using BERT masking strategy. carefully ported from
         transformers.data.DataCollatorForLanguageModeling
        Args:
            batch: batch of data

        Returns:
            same batch of data masked and padded
        """
        batch = self.tokenizer.pad(batch, return_tensors="pt")
        # If special token mask has been preprocessed, pop it from the dict.
        special_tokens_mask = batch.pop("special_tokens_mask", None)
        src, trg, masked_indices = self._mask_tokens(
            batch["input_ids"], special_tokens_mask=special_tokens_mask
        )
        return src, trg, masked_indices


#         batch = self.tokenizer.pad(batch, return_tensors="pt")
#         # If special token mask has been preprocessed, pop it from the dict.
#         special_tokens_mask = batch.pop("special_tokens_mask", None)
#         src, trg, masked_indices = self._mask_tokens(
#             batch["input_ids"], special_tokens_mask=special_tokens_mask
#         )
#         return {
#         'input_ids': src,
#         'labels': trg,
#         'masked_indices': masked_indices
#         }

In [11]:

class Encoder(nn.Module):
    """
    Encoder model using HuggingFace for NLP

    To load your desired model specify model checkpoint under cfg.model.encoder_checkpoint

    Args:
        cfg: An omegaconf.DictConf instance containing all the configurations.
        **kwargs: extra args which are set as model properties
    """

    def __init__(self, cfg, **kwargs):
        super(Encoder, self).__init__()
        self.cfg = cfg
        checkpoint = cfg.model.encoder_checkpoint
        model_config = AutoConfig.from_pretrained(checkpoint)
        self.encoder = AutoModel.from_config(model_config)
        self.__dict__.update(kwargs)

    def forward(self, inputs, mask=None, **kwargs):
        """
        Forward inputs through the encoder and extract transformer/attention layers outputs

        Args:
            inputs: source tokens
            mask: bool masked indices
            kwargs: keyword args specific to the encoder's forward method

        Returns:
            A dictionary of the encoder outputs including transformer layers outputs and attentions outputs

        """
        # Note: inputs are already masked for MLM so mask is not used
        outputs = self.encoder(inputs, output_hidden_states=True, output_attentions=True, **kwargs)
        encoder_states = outputs['hidden_states'][:-1]  # encoder layers outputs separately
        encoder_out = outputs['hidden_states'][-1]      # last encoder output (accumulated)
        attentions = outputs['attentions']
        return {
            'encoder_states': encoder_states,
            'encoder_out': encoder_out,
            'attentions': attentions
        }

In [12]:
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


def maybe_save_checkpoint(model, optimizer, path, epoch_num, save_freq):
    """
    Save a checkpoint specific to Data2Vec
    Args:
        model: a nn.Module instance
        optimizer
        path: path to save checkpoint to
        epoch_num: current epoch number
        save_freq: save frequency based on epoch number

    """
    if not os.path.exists(path):
        os.makedirs(path)
    path = os.path.join(path, f'{epoch_num}.pt')
    if epoch_num % save_freq == 0:
        checkpoint = {'data2vec': model.state_dict(),
                      'encoder': model.encoder.encoder.state_dict(),
                      'optimizer': optimizer.state_dict()}
        torch.save(checkpoint, path)
        print(f'Saved checkpoint to `{path}`')

In [13]:
"""
Train Data2Vec for text. The encoder is loaded from huggingface specified in the config file.
"""

class TextTrainer:
    """
    A Trainer class to train NLP model on Data2Vec.

    Args:
        cfg (DictConfig): the config object containing all properties
    """

    def __init__(self, cfg: DictConfig):
        self.cfg = cfg
        self.num_epochs = self.cfg.train.num_epochs
        self.device = self.cfg.device
        self.ckpt_dir = cfg.train.checkpoints_dir
        self.save_ckpt_freq = cfg.train.save_ckpt_freq
        # Model, Optim, Criterion
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.model.encoder_checkpoint)
        self.encoder = Encoder(cfg=cfg)
        self.model = Data2Vec(encoder=self.encoder, cfg=cfg)
        self.model.to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), cfg.optimizer.lr)
        self.criterion = nn.SmoothL1Loss(reduction='none', beta=cfg.criterion.loss_beta)
        self.criterion.to(self.device)
        # Datasets & Data Loaders
        self.train_dataset = WikiText(cfg, 'train', self.tokenizer)
        self.test_dataset = WikiText(cfg, 'test', self.tokenizer)
        self.train_loader = DataLoader(self.train_dataset, batch_size=cfg.train.batch_size,
                                       collate_fn=self.train_dataset.collate_fn)
        self.test_loader = DataLoader(self.test_dataset, batch_size=cfg.train.eval_batch_size,
                                      collate_fn=self.test_dataset.collate_fn)
        # Tensorboard
        self.tensorboard = SummaryWriter(log_dir=self.cfg.train.log_dir)

        # Trackers
        self.loss_tracker = AverageMeter('loss')

    def train_step(self, batch):
        """
        Train one batch of data and return loss.

        Args:
            batch: A batch of data, inputs, labels and mask with shape [batch_size, seq_len]

        Returns:
            Loss value
        """
        src, trg, mask = batch
        src, trg, mask = src.to(self.device), trg.to(self.device), mask.to(self.device)

        x, y = self.model(src, trg, mask)
        loss = self.criterion(x.float(), y.float()).sum(dim=-1).sum().div(x.size(0))
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

        return loss.item()

    def test_step(self, batch):
        """
        Test a model on one batch of data and return loss.

        Args:
            batch: A batch of data, inputs, labels and mask with shape [batch_size, seq_len]

        Returns:
            Loss value
        """
#         src = batch['input_ids'].to(self.device)
#         trg = batch['labels'].to(self.device)
#         mask = batch['masked_indices'].to(self.device)
        src = batch[0].to(self.device)
        trg = batch[1].to(self.device)
        mask = batch[2].to(self.device)
        x, y = self.model(src, trg, mask=mask)
        loss = self.criterion(x, y)

        return loss.item()
#         input_ids = batch['input_ids'].to(self.device)
#         labels = batch['labels'].to(self.device)
#         attention_mask = batch.get('attention_mask', None)  # Add this if your data includes attention masks

#         # Adjusting the call to match expected parameters
#         outputs = self.model(input_ids, attention_mask=attention_mask)
#         logits = outputs.logits

#         loss = self.criterion(logits, labels)
#         return loss.item()

    def train_epoch(self, epoch_num):
        """
        Train the model for one epoch and verbose using the progress bar.

        Args:
            epoch_num: number of the current epoch

        Returns:
            The average loss through the whole epoch
        """
        self.model.train()
        self.loss_tracker.reset()
        with tqdm(self.train_loader, unit="batch", desc=f'Epoch: {epoch_num}/{self.num_epochs} ',
                  bar_format='{desc:<16}{percentage:3.0f}%|{bar:70}{r_bar}', ascii=" #") as iterator:
            for batch in iterator:
                loss = self.train_step(batch)
                self.model.ema_step()
                self.loss_tracker.update(loss)
                avg_loss = self.loss_tracker.avg
                iterator.set_postfix(loss=avg_loss)

        return avg_loss

    def evaluate(self):
        """
        Evaluate the model on the test set

        Returns:
            The average loss through the whole test dataset
        """
        self.model.eval()
        self.loss_tracker.reset()
        with tqdm(self.test_loader, unit="batch", desc=f'Evaluating... ',
                  bar_format='{desc:<16}{percentage:3.0f}%|{bar:70}{r_bar}', ascii=" #") as iterator:
            with torch.no_grad():
                for batch in iterator:
                    loss = self.test_step(batch)
                    self.loss_tracker.update(loss)
                    avg_loss = self.loss_tracker.avg
                    iterator.set_postfix(loss=avg_loss)

        return avg_loss

#     def train(self):
#         """
#         Train and evaluate the model on the datasets and save checkpoints and write summaries to TensorBoard.

#         """
#         for epoch in range(1, self.num_epochs + 1):
#             print()
#             train_loss = self.train_epoch(epoch)
#             val_loss = self.evaluate()

#             # tensorboard
#             self.tensorboard.add_scalar('train_loss', train_loss, epoch)
#             self.tensorboard.add_scalar('val_loss', val_loss, epoch)

#             # save checkpoint
#             maybe_save_checkpoint(self.model, self.optimizer, self.ckpt_dir, epoch, self.save_ckpt_freq)



    def train(self):
        """
        Load a model from Hugging Face and evaluate on the test dataset.
        """
        # Load a model compatible with your configuration from Hugging Face
        model_checkpoint = 'bert-base-uncased'  # Make sure to use a model that fits your task
        self.model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=2)  # Adjust num_labels based on your task
        self.model.to(self.device)

        # Evaluate the model
        print("Starting evaluation using Hugging Face model...")
        val_loss = self.evaluate()
        print(f"Evaluation Loss: {val_loss}")


In [14]:
# Load configuration
cfg = OmegaConf.load('roberta-pretraining.yaml')  # Adjust the path if necessary


In [15]:
def train_model():
    trainer = TextTrainer(cfg)
    trainer.train()

# Run the training function
train_model()

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

In [None]:
def evaluate_model():
    trainer = TextTrainer(cfg)
    return trainer.evaluate()

# Print evaluation results
print("Evaluation loss:", evaluate_model())
