In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig, AutoTokenizer
from datasets import load_dataset
from omegaconf import OmegaConf, DictConfig
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import os
import copy 
from transformers import AutoModelForSequenceClassification

In [2]:
class EMA:
    """
    Exponential Moving Average (EMA) of model weights.

    Args:
        model (torch.nn.Module): The model for which to maintain the EMA.
        decay (float): The decay rate for EMA (between 0 and 1).
        device (torch.device): The device on which the EMA model will reside.
    """

    def __init__(self, model, decay, device=None):
        self.decay = decay
        self.device = device or next(model.parameters()).device
        self.model = copy.deepcopy(model).eval()
        self.model.to(self.device)
        for param in self.model.parameters():
            param.requires_grad_(False)

    def update(self, model):
        """
        Update the EMA parameters based on the current model parameters.

        Args:
            model (torch.nn.Module): The current training model to update from.
        """
        with torch.no_grad():
            for ema_param, model_param in zip(self.model.parameters(), model.parameters()):
                if model_param.requires_grad:
                    ema_param.data.mul_(self.decay).add_(model_param.data, alpha=1 - self.decay)

    def apply_shadow(self, model):
        """
        Copy the EMA parameters to the training model.

        Args:
            model (torch.nn.Module): The model to which EMA parameters will be copied.
        """
        model.load_state_dict(self.model.state_dict())

    def __call__(self, model):
        """
        Update method callable directly with the model.
        """
        self.update(model)

In [None]:
class Data2VecText(nn.Module):
    """
    Data2Vec model for text modality.

    Args:
        encoder (nn.Module): The encoder module (e.g., a Transformer-based model).
        cfg (omegaconf.DictConfig): Configuration containing model properties.
    """

    def __init__(self, encoder, cfg):
        super(Data2VecText, self).__init__()
        self.embed_dim = cfg.model.embed_dim
        self.encoder = encoder
        self.cfg = cfg
        self.ema = EMA(self.encoder, cfg.model.ema_decay, cfg.device)
        self.regression_head = self._build_regression_head()

    def _build_regression_head(self):
        """
        Constructs the regression head for the model.

        Returns:
            nn.Module: A sequential model consisting of linear and activation layers.
        """
        return nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_dim * 2),
            nn.GELU(),
            nn.Linear(self.embed_dim * 2, self.embed_dim)
        )

    def forward(self, src, trg=None, mask=None):
        """
        Forward pass through Data2Vec text model.

        Args:
            src (Tensor): Source tokens (masked inputs for training).
            trg (Tensor, optional): Target tokens (unmasked inputs for training, `None` otherwise).
            mask (Tensor, optional): Boolean masked indices (not directly used here but can be useful for extensions).

        Returns:
            Tensor or tuple: Output from the encoder or a tuple of encoder and EMA outputs.
        """
        # Model forward in online mode (student)
        x = self.encoder(src)['encoder_out']  # Fetch the last layer outputs

        if trg is None:
            return x

        # Model forward in offline mode (teacher)
        with torch.no_grad():
            self.ema.model.eval()
            y = self.ema.model(trg)['encoder_states']  # Fetch the last transformer layers outputs

            # Layer normalization for text
            y = [F.layer_norm(tl.float(), tl.shape[-1:]) for tl in y]
            y = sum(y) / len(y)
            if self.cfg.model.normalize_targets:
                y = F.layer_norm(y.float(), y.shape[-1:])

        x = x[mask]
        y = y[mask]

        x = self.regression_head(x)

        return x, y

In [10]:
class WikiText(Dataset):
    """
    A Dataset instance for WikiText dataset loaded from HuggingFace datasets.

    Args:
        cfg (DictConfig): config object
        split: Split to load ['train', 'test']
        tokenizer: A HuggingFace Tokenizer model like BPE
        **kwargs: extra args which are set as dataset properties
    """

    def __init__(self, cfg, split, tokenizer, **kwargs):
        super(WikiText, self).__init__()
        self.cfg = cfg
        self.path = cfg.dataset.name
        self.mlm_probability = cfg.dataset.mlm_probability
        raw_data = load_dataset('wikitext', self.path)[split]
        self.data = self.clean_dataset(raw_data) if self.cfg.dataset.clean_dataset else raw_data
        self.tokenizer = tokenizer
        self.__dict__.update(kwargs)

    def clean_dataset(self, data):
        """
        Cleanup dataset by removing invalid sized samples, etc.
        """
        print('Cleaning dataset ...')
        min_seq_len, max_seq_len = self.cfg.dataset.valid_seq_lenghts
        texts = []
        with tqdm(data, desc='Removing invalid sized inputs: ') as tbar:
            for i, x in enumerate(tbar):
                if len(x['text']) in range(min_seq_len, max_seq_len + 1):
                    texts.append(x)
        return texts

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """
        Only return tokens from raw text with no additions e.g, padding, bos/eos, etc.
        Args:
            index: sample index to pick from dataset

        Returns:
            tokenized outputs
        """
        raw_text = self.data[index]['text']
        tokens = self.tokenizer(raw_text, return_attention_mask=False)
        return tokens

    def _mask_tokens(self, inputs, special_tokens_mask=None):
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Ported
         from `transformers.data.DataCollatorForLanguageModeling.torch_mask_tokens()`
        Args:
            inputs: batch of input tokens
            special_tokens_mask:

        Returns:
            a dict batch of masked and padded inputs/labels

        """
        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in
                labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = self.tokenizer.pad_token_id
        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels, masked_indices

    def collate_fn(self, batch):
        """
        Collate the batch of data using BERT masking strategy. carefully ported from
         transformers.data.DataCollatorForLanguageModeling
        Args:
            batch: batch of data

        Returns:
            same batch of data masked and padded
        """
        batch = self.tokenizer.pad(batch, return_tensors="pt")
        # If special token mask has been preprocessed, pop it from the dict.
        special_tokens_mask = batch.pop("special_tokens_mask", None)
        src, trg, masked_indices = self._mask_tokens(
            batch["input_ids"], special_tokens_mask=special_tokens_mask
        )
        return src, trg, masked_indices


#         batch = self.tokenizer.pad(batch, return_tensors="pt")
#         # If special token mask has been preprocessed, pop it from the dict.
#         special_tokens_mask = batch.pop("special_tokens_mask", None)
#         src, trg, masked_indices = self._mask_tokens(
#             batch["input_ids"], special_tokens_mask=special_tokens_mask
#         )
#         return {
#         'input_ids': src,
#         'labels': trg,
#         'masked_indices': masked_indices
#         }

In [11]:
class Encoder(nn.Module):
    """
    Encoder model using HuggingFace for NLP

    To load your desired model specify model checkpoint under cfg.model.encoder_checkpoint

    Args:
        cfg: An omegaconf.DictConf instance containing all the configurations.
        **kwargs: extra args which are set as model properties
    """

    def __init__(self, cfg, **kwargs):
        super(Encoder, self).__init__()
        self.cfg = cfg
        checkpoint = cfg.model.encoder_checkpoint
        model_config = AutoConfig.from_pretrained(checkpoint)
        self.encoder = AutoModel.from_config(model_config)
        self.__dict__.update(kwargs)

    def forward(self, inputs, mask=None, **kwargs):
        """
        Forward inputs through the encoder and extract transformer/attention layers outputs

        Args:
            inputs: source tokens
            mask: bool masked indices
            kwargs: keyword args specific to the encoder's forward method

        Returns:
            A dictionary of the encoder outputs including transformer layers outputs and attentions outputs

        """
        # Note: inputs are already masked for MLM so mask is not used
        outputs = self.encoder(inputs, output_hidden_states=True, output_attentions=True, **kwargs)
        encoder_states = outputs['hidden_states'][:-1]  # encoder layers outputs separately
        encoder_out = outputs['hidden_states'][-1]      # last encoder output (accumulated)
        attentions = outputs['attentions']
        return {
            'encoder_states': encoder_states,
            'encoder_out': encoder_out,
            'attentions': attentions
        }

In [12]:
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


def maybe_save_checkpoint(model, optimizer, path, epoch_num, save_freq):
    """
    Save a checkpoint specific to Data2Vec
    Args:
        model: a nn.Module instance
        optimizer
        path: path to save checkpoint to
        epoch_num: current epoch number
        save_freq: save frequency based on epoch number

    """
    if not os.path.exists(path):
        os.makedirs(path)
    path = os.path.join(path, f'{epoch_num}.pt')
    if epoch_num % save_freq == 0:
        checkpoint = {'data2vec': model.state_dict(),
                      'encoder': model.encoder.encoder.state_dict(),
                      'optimizer': optimizer.state_dict()}
        torch.save(checkpoint, path)
        print(f'Saved checkpoint to `{path}`')

In [13]:
class TextTrainer:
    """
    Trainer class for training and evaluating the Data2VecText model on text data.

    Args:
        cfg (omegaconf.DictConfig): Configuration object containing all necessary parameters.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self.device = torch.device(cfg.device if torch.cuda.is_available() else 'cpu')

        # Initialize tokenizer, model, optimizer, and data loaders
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.model.encoder_checkpoint)
        self.model = Data2VecText(Encoder(self.cfg), self.cfg).to(self.device)
        self.optimizer = Adam(self.model.parameters(), lr=cfg.optimizer.lr)
        self.criterion = MSELoss()  # Using MSELoss for simplicity, can be changed as needed.

        self.train_dataset = WikiText(cfg, 'train', self.tokenizer)
        self.test_dataset = WikiText(cfg, 'test', self.tokenizer)
        self.train_loader = DataLoader(self.train_dataset, batch_size=cfg.train.batch_size,
                                       collate_fn=self.train_dataset.collate_fn, shuffle=True)
        self.test_loader = DataLoader(self.test_dataset, batch_size=cfg.train.eval_batch_size,
                                      collate_fn=self.test_dataset.collate_fn)

    def train_step(self, src, trg, mask):
        """
        Executes a single training step.

        Args:
            src (Tensor): Input batch of source tokens.
            trg (Tensor): Input batch of target tokens.
            mask (Tensor): Mask indicating the active parts of the input.

        Returns:
            float: The loss value for the step.
        """
        self.model.train()
        self.optimizer.zero_grad()
        x, y = self.model(src, trg, mask)
        loss = self.criterion(x, y)
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def train_epoch(self):
        """
        Runs one epoch of training.

        Returns:
            float: Average loss for this epoch.
        """
        total_loss = 0
        for src, trg, mask in tqdm(self.train_loader, desc="Training", leave=False):
            src, trg, mask = src.to(self.device), trg.to(self.device), mask.to(self.device)
            loss = self.train_step(src, trg, mask)
            total_loss += loss
        return total_loss / len(self.train_loader)

    def evaluate(self):
        """
        Evaluates the model on the test dataset.

        Returns:
            float: Average loss on the test dataset.
        """
        self.model.eval()
        total_loss = 0
        with torch.no_grad():
            for src, trg, mask in tqdm(self.test_loader, desc="Evaluating", leave=False):
                src, trg, mask = src.to(self.device), trg.to(self.device), mask.to(self.device)
                x, y = self.model(src, trg, mask)
                loss = self.criterion(x, y)
                total_loss += loss.item()
        return total_loss / len(self.test_loader)

    def train(self):
        """
        Conducts the full training process across all epochs specified in the configuration.
        """
        for epoch in range(1, self.cfg.train.num_epochs + 1):
            avg_train_loss = self.train_epoch()
            avg_val_loss = self.evaluate()
            print(f'Epoch {epoch}/{self.cfg.train.num_epochs}, Train Loss: {avg_train_loss}, Val Loss: {avg_val_loss}')

In [14]:
# Load configuration
cfg = OmegaConf.load('roberta-pretraining.yaml')  # Adjust the path if necessary


In [15]:
def train_model():
    trainer = TextTrainer(cfg)
    trainer.train()

# Run the training function
train_model()

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

In [None]:
def evaluate_model():
    trainer = TextTrainer(cfg)
    return trainer.evaluate()

# Print evaluation results
print("Evaluation loss:", evaluate_model())
