In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig, AutoTokenizer,AutoModelForSequenceClassification
from datasets import load_dataset
from omegaconf import OmegaConf, DictConfig
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import os
import copy

In [2]:
class EMA:
    """
    Exponential-moving-average (EMA) copy of a model, used as the Data2Vec teacher.

    The teacher is a frozen deep copy of the model given at construction. Each call to
    `step` moves its floating-point state towards the student's:

        teacher = decay * teacher + (1 - decay) * student

    Args:
        model (nn.Module): model to copy (the student at initialisation time)
        decay_rate (float): initial decay; larger values make the teacher change more slowly
        device: device the teacher copy is moved to
    """

    def __init__(self, model: nn.Module, decay_rate, device=None):
        # deep copy so that later student updates never alias the teacher's tensors
        self.model = copy.deepcopy(model)
        # the teacher is only updated through `step`, never by backward()
        self.model.requires_grad_(False)
        self.device = device
        self.model.to(device)
        # single source of truth for the decay; `decay_rate` is a property alias of it
        self.decay = decay_rate
        # number of completed `step` calls (read by callers to anneal the decay)
        self.update_count = 0

    @property
    def decay_rate(self):
        """Alias of `decay`, kept for backward compatibility."""
        return self.decay

    @decay_rate.setter
    def decay_rate(self, value):
        self.decay = value

    def step(self, n_model: nn.Module):
        """
        A single EMA step of the teacher's parameters towards `n_model`.

        Floating-point state-dict entries are averaged in float32 and cast back to the
        teacher's dtype by `load_state_dict`. Non-floating-point entries (e.g. integer index
        buffers or counters) are copied from the student, because averaging them in float
        and casting back to an integer dtype can truncate values.

        Args:
            n_model (nn.Module): the student model that the teacher follows
        """
        teacher_state = self.model.state_dict()
        new_state = {}
        with torch.no_grad():
            for key, student_value in n_model.state_dict().items():
                teacher_value = teacher_state[key]
                if not torch.is_floating_point(teacher_value):
                    new_state[key] = student_value.to(teacher_value.device)
                    continue
                # out-of-place ops: `.float()` may return the teacher's own tensor, which an
                # in-place `mul_` would silently mutate
                averaged = teacher_value.float() * self.decay
                student_float = student_value.to(device=averaged.device, dtype=averaged.dtype)
                new_state[key] = averaged + student_float * (1 - self.decay)

        # load_state_dict copies the values back into the teacher's tensors
        self.model.load_state_dict(new_state, strict=False)
        self.update_count += 1

    def set_decay(self, decay_rate):
        """Set the decay used by subsequent `step` calls."""
        self.decay = decay_rate

In [None]:
class Data2Vec(nn.Module):
    """
    Data2Vec student/teacher wrapper.

    The student is `encoder`. The teacher is an `EMA` copy of it whose decay is linearly
    annealed from `ema_decay` to `ema_end_decay` over `ema_anneal_end_step` EMA updates.

    Args:
        encoder (nn.Module): student encoder, called as ``encoder(src, mask)``. It must return
            a dict with ``'encoder_out'`` (last hidden state) and ``'encoder_states'``
            (sequence of per-layer hidden states), as the `Encoder` class in this notebook does.
        modality (str): 'text' or 'image'
        embed_dim (int): hidden size of the encoder outputs
        ema_decay (float): starting teacher decay
        ema_end_decay (float): final teacher decay
        ema_anneal_end_step (int): EMA updates after which the decay stays at `ema_end_decay`
        device: device for the teacher copy

    Raises:
        ValueError: if `modality` is not supported
    """

    def __init__(self, encoder, modality, embed_dim, ema_decay, ema_end_decay, ema_anneal_end_step, device):
        # nn.Module must be initialised before sub-modules can be assigned as attributes
        super().__init__()
        if modality not in ('text', 'image'):
            raise ValueError(f'Given modality is not accepted: {modality}')

        self.encoder = encoder
        self.modality = modality
        self.embed_dim = embed_dim
        self.device = device
        self.EMAModule = EMA(self.encoder, decay_rate=ema_decay, device=device)

        # student -> target projection; `forward` always applies it, so text needs one as well
        self.regression_head = nn.Linear(self.embed_dim, self.embed_dim)

        # parameters for the EMA module
        self.ema_decay = ema_decay                          # starting decay rate
        self.ema_end_decay = ema_end_decay                  # ending decay rate
        self.ema_anneal_end_step = ema_anneal_end_step      # EMA updates needed to reach the end decay

    def stepEMA(self):
        """
        Anneal the teacher decay, then (if the decay is below 1) take one EMA step
        of the teacher towards the current student encoder.
        """
        updates = self.EMAModule.update_count
        if updates >= self.ema_anneal_end_step:
            decay = self.ema_end_decay
        else:
            # linear interpolation: ema_decay at 0 updates -> ema_end_decay at the anneal end.
            # (Annealing must run whenever the two differ, not only when start >= end.)
            remaining = 1 - updates / self.ema_anneal_end_step
            decay = self.ema_end_decay - (self.ema_end_decay - self.ema_decay) * remaining
        self.EMAModule.set_decay(decay)

        # a decay of 1 leaves the teacher unchanged, so the step is skipped
        if decay < 1:
            self.EMAModule.step(self.encoder)

    def forward(self, src, k, target=None, mask=None):
        """
        Student forward, plus teacher targets when `target` is given.

        Args:
            src: (masked) input passed to the student encoder
            k (int): number of top teacher layers averaged into the target
            target: unmasked input for the teacher; if None only the student runs
            mask: boolean mask of positions to predict; required with `target`
                (assumes it indexes the first dims of the encoder output — TODO confirm shapes)

        Returns:
            The student's last hidden state if `target` is None, otherwise ``(x, y)``:
            the regressed student states and the averaged, layer-normalised teacher
            states, both selected at the masked positions.

        Raises:
            ValueError: if `target` is given without `mask`, or `k` < 1
        """
        x = self.encoder(src, mask)['encoder_out']          # 1) student encoding

        # student-only mode: no target to regress towards
        if target is None:
            return x
        if mask is None:
            raise ValueError('`mask` is required when `target` is given')
        if k < 1:
            raise ValueError(f'k must be a positive number of layers, got {k}')
        mask = mask.bool()

        # teacher mode: build regression targets without tracking gradients
        with torch.no_grad():
            self.EMAModule.model.eval()
            y = self.EMAModule.model(target, ~mask)['encoder_states']        # 2) per-layer outputs
            y = y[-k:]                                                      # 3) keep the TOP k layers
            y = [F.layer_norm(layer.float(), layer.shape[-1:]) for layer in y]  # 4) normalise each layer
            y = sum(y) / len(y)                                             # 5) average over the k layers

        x = self.regression_head(x[mask])                   # 6) regress masked student states
        y = y[mask]                                         # 7) matching teacher targets
        return x, y

In [10]:
class WikiText(Dataset):
    """
    A Dataset instance for WikiText dataset loaded from HuggingFace datasets.

    Args:
        cfg (DictConfig): config object; reads ``dataset.name``, ``dataset.mlm_probability``,
            ``dataset.clean_dataset`` and, when cleaning, ``dataset.valid_seq_lenghts``
        split: Split to load ['train', 'test']
        tokenizer: A HuggingFace tokenizer
        **kwargs: extra args which are set as dataset properties
    """

    def __init__(self, cfg, split, tokenizer, **kwargs):
        super().__init__()
        self.cfg = cfg
        self.path = cfg.dataset.name
        self.mlm_probability = cfg.dataset.mlm_probability
        raw_data = load_dataset('wikitext', self.path)[split]
        self.data = self.clean_dataset(raw_data) if self.cfg.dataset.clean_dataset else raw_data
        self.tokenizer = tokenizer
        self.__dict__.update(kwargs)

    def clean_dataset(self, data):
        """
        Keep only samples whose raw text length (in characters) lies within
        ``dataset.valid_seq_lenghts`` = [min, max], both inclusive.
        """
        print('Cleaning dataset ...')
        min_seq_len, max_seq_len = self.cfg.dataset.valid_seq_lenghts
        return [
            x for x in tqdm(data, desc='Removing invalid sized inputs: ')
            if min_seq_len <= len(x['text']) <= max_seq_len
        ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """
        Tokenize one raw text sample.

        The tokenizer's default special tokens (e.g. bos/eos) ARE added; no padding is
        applied here (that happens in `collate_fn`). Sequences are truncated to the
        tokenizer's maximum length so long paragraphs cannot overflow the model's
        position embeddings.

        Args:
            index: sample index to pick from dataset

        Returns:
            tokenized outputs (without attention mask)
        """
        raw_text = self.data[index]['text']
        return self.tokenizer(raw_text, return_attention_mask=False, truncation=True)

    def _mask_tokens(self, inputs, special_tokens_mask=None):
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random,
        10% original. Ported from `transformers.data.DataCollatorForLanguageModeling.torch_mask_tokens()`.

        Unlike the HuggingFace version, non-masked label positions are filled with
        `tokenizer.pad_token_id` (not -100). `inputs` is modified in place.

        Args:
            inputs: batch of input token ids
            special_tokens_mask: optional mask of special tokens (never selected for masking);
                computed from the tokenizer when None

        Returns:
            (inputs, labels, masked_indices): the masked inputs, the original ids at masked
            positions (pad id elsewhere), and the boolean mask of selected positions
        """
        labels = inputs.clone()
        # sample tokens for MLM with probability `self.mlm_probability`
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in
                labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = self.tokenizer.pad_token_id
        # 80% of the time, replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time (half of the remaining 20%), replace with a random token
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # the remaining 10% of masked tokens are left unchanged
        return inputs, labels, masked_indices

    def collate_fn(self, batch):
        """
        Pad a batch and apply BERT-style masking. Ported from
        transformers.data.DataCollatorForLanguageModeling.

        Args:
            batch: list of tokenized samples

        Returns:
            (src, trg, masked_indices) as returned by `_mask_tokens`
        """
        batch = self.tokenizer.pad(batch, return_tensors="pt")
        # if a special-tokens mask was preprocessed, pop it from the dict
        special_tokens_mask = batch.pop("special_tokens_mask", None)
        src, trg, masked_indices = self._mask_tokens(
            batch["input_ids"], special_tokens_mask=special_tokens_mask
        )
        return src, trg, masked_indices

In [11]:
class Encoder(nn.Module):
    """
    Encoder model using HuggingFace for NLP.

    The architecture is taken from the config of cfg.model.encoder_checkpoint and built with
    `AutoModel.from_config`, i.e. the weights are freshly initialised (not pretrained).

    Args:
        cfg: An omegaconf.DictConf instance containing all the configurations.
        **kwargs: extra args which are set as model properties
    """

    def __init__(self, cfg, **kwargs):
        super().__init__()
        self.cfg = cfg
        checkpoint = cfg.model.encoder_checkpoint
        model_config = AutoConfig.from_pretrained(checkpoint)
        self.encoder = AutoModel.from_config(model_config)
        self.__dict__.update(kwargs)

    def forward(self, inputs, mask=None, **kwargs):
        """
        Forward inputs through the encoder and extract transformer layer and attention outputs.

        Args:
            inputs: source token ids
            mask: bool masked indices (unused: inputs are already masked for MLM)
            kwargs: keyword args forwarded to the encoder's forward method

        Returns:
            dict with
              'encoder_states': outputs of every transformer layer, first to last
                                (the embedding output is excluded),
              'encoder_out':    output of the last transformer layer,
              'attentions':     per-layer attention weights
        """
        outputs = self.encoder(inputs, output_hidden_states=True, output_attentions=True, **kwargs)
        # HF `hidden_states` = (embedding output, layer_1, ..., layer_N). Drop the embeddings
        # so that "top-k layers" selections address real transformer blocks, including the last.
        hidden_states = outputs['hidden_states']
        return {
            'encoder_states': hidden_states[1:],
            'encoder_out': hidden_states[-1],
            'attentions': outputs['attentions'],
        }

In [12]:
class AverageMeter:
    """Keep the most recent value and the running weighted mean of a scalar metric."""

    def __init__(self, name, fmt=':f'):
        # `fmt` is a format spec (e.g. ':.4f') applied to both the value and the mean
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Forget all accumulated statistics."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` observed `n` times and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = f'{{name}} {{val{self.fmt}}} ({{avg{self.fmt}}})'
        return template.format(**vars(self))


def maybe_save_checkpoint(model, optimizer, path, epoch_num, save_freq):
    """
    Save a Data2Vec checkpoint to ``<path>/<epoch_num>.pt`` when `epoch_num` is a
    multiple of `save_freq`.

    The checkpoint holds the full Data2Vec state dict ('data2vec'), the wrapped
    HuggingFace encoder's state dict ('encoder') and the optimizer state ('optimizer').

    Args:
        model: a Data2Vec nn.Module (must expose `model.encoder.encoder`)
        optimizer: optimizer whose state is saved alongside the model
        path: directory to save checkpoints to (created on first save)
        epoch_num: current epoch number
        save_freq: save every `save_freq` epochs; 0/None disables saving
    """
    # nothing to do this epoch: don't touch the filesystem (and don't divide by zero)
    if not save_freq or epoch_num % save_freq != 0:
        return
    os.makedirs(path, exist_ok=True)
    ckpt_path = os.path.join(path, f'{epoch_num}.pt')
    checkpoint = {'data2vec': model.state_dict(),
                  'encoder': model.encoder.encoder.state_dict(),
                  'optimizer': optimizer.state_dict()}
    torch.save(checkpoint, ckpt_path)
    print(f'Saved checkpoint to `{ckpt_path}`')
        print(f'Saved checkpoint to `{path}`')

In [13]:
class TextTrainer:
    """
    Pre-train a text encoder with the Data2Vec student/teacher regression objective.

    Builds the tokenizer, the `Encoder`, the `Data2Vec` wrapper (student + EMA teacher),
    the WikiText train/test datasets and their DataLoaders from a single config.

    Args:
        cfg (DictConfig): configuration. Required keys: ``model.encoder_checkpoint``,
            ``optimizer.lr``. The optional keys below are read with `OmegaConf.select`;
            the defaults in parentheses are this class's choices.
            NOTE(review): confirm them against roberta-pretraining.yaml.
            ``device`` (cuda if available else cpu), ``modality`` ('text'),
            ``model.embed_dim`` (encoder hidden size), ``model.average_top_k_layers`` (8),
            ``model.ema_decay`` (0.999), ``model.ema_end_decay`` (0.9999),
            ``model.ema_anneal_end_step`` (300000), ``train.batch_size`` (16),
            ``train.eval_batch_size`` (= batch_size), ``criterion.loss_beta`` (0 -> MSE),
            ``train.checkpoints_dir`` (None -> no checkpoints), ``train.save_ckpt_freq`` (1).
    """

    def __init__(self, cfg: DictConfig):
        self.config = cfg
        self.device = self._cfg('device', 'cuda' if torch.cuda.is_available() else 'cpu')

        self.tokenizer = AutoTokenizer.from_pretrained(cfg.model.encoder_checkpoint)
        # Data2Vec takes explicit hyper-parameters (not `cfg`) plus a built encoder
        self.encoder = Encoder(cfg)
        embed_dim = self._cfg('model.embed_dim', None)
        if embed_dim is None:
            embed_dim = self.encoder.encoder.config.hidden_size
        self.model = Data2Vec(
            encoder=self.encoder,
            modality=self._cfg('modality', 'text'),
            embed_dim=embed_dim,
            ema_decay=self._cfg('model.ema_decay', 0.999),
            ema_end_decay=self._cfg('model.ema_end_decay', 0.9999),
            ema_anneal_end_step=self._cfg('model.ema_anneal_end_step', 300000),
            device=self.device,
        ).to(self.device)
        # number of top teacher layers averaged into each regression target
        self.top_k = self._cfg('model.average_top_k_layers', 8)

        self.train_dataset = WikiText(cfg, 'train', self.tokenizer)
        self.test_dataset = WikiText(cfg, 'test', self.tokenizer)
        self.train_loader, self.test_loader = self._build_loaders()

        self.optimizer = optim.Adam(self.model.parameters(), cfg.optimizer.lr)
        self.criterion = self._build_criterion()

    def _cfg(self, key, default=None):
        """Read an optional dotted config key, returning `default` when it is absent."""
        return OmegaConf.select(self.config, key, default=default)

    def _build_loaders(self):
        """Create train/test DataLoaders that batch with each dataset's MLM `collate_fn`."""
        batch_size = self._cfg('train.batch_size', 16)
        eval_batch_size = self._cfg('train.eval_batch_size', batch_size)
        train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True,
                                  collate_fn=self.train_dataset.collate_fn)
        test_loader = DataLoader(self.test_dataset, batch_size=eval_batch_size, shuffle=False,
                                 collate_fn=self.test_dataset.collate_fn)
        return train_loader, test_loader

    def _build_criterion(self):
        """Regression loss between student predictions and teacher targets (MSE if beta == 0)."""
        loss_beta = self._cfg('criterion.loss_beta', 0)
        if loss_beta == 0:
            return nn.MSELoss()
        return nn.SmoothL1Loss(beta=loss_beta)

    def run_epoch(self, mode='train'):
        """
        Run one epoch of training or evaluation.

        Args:
            mode (str): 'train' to optimise (and step the EMA teacher); anything else evaluates.

        Returns:
            Average loss per masked token over the epoch (NaN if no token was masked).
        """
        training = mode == 'train'
        loader = self.train_loader if training else self.test_loader
        self.model.train(training)

        total_loss, total_tokens = 0.0, 0
        # gradients only while training; evaluation never builds an autograd graph
        with torch.set_grad_enabled(training):
            for src, trg, mask in tqdm(loader, desc=f"{mode.capitalize()} Epoch"):
                src, trg, mask = src.to(self.device), trg.to(self.device), mask.to(self.device)
                n_masked = int(mask.sum())
                if n_masked == 0:
                    continue  # nothing to predict; the mean loss would be NaN
                # collate_fn overwrote the masked positions of `src`, and `trg` holds the
                # original ids exactly there, so this rebuilds the unmasked teacher input
                unmasked = torch.where(mask, trg, src)
                x, y = self.model(src, self.top_k, target=unmasked, mask=mask)
                loss = self.criterion(x.float(), y)

                if training:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    # the teacher follows the student after every optimiser step
                    self.model.stepEMA()

                total_loss += loss.item() * n_masked
                total_tokens += n_masked

        avg_loss = total_loss / total_tokens if total_tokens else float('nan')
        print(f"{mode.capitalize()} loss: {avg_loss:.4f}")
        return avg_loss

    def train(self, num_epochs):
        """
        Train for `num_epochs` epochs, evaluating after each one and saving a checkpoint
        via `maybe_save_checkpoint` when ``train.checkpoints_dir`` is configured.

        Args:
            num_epochs (int): Number of epochs to train the model.
        """
        ckpt_dir = self._cfg('train.checkpoints_dir', None)
        save_freq = self._cfg('train.save_ckpt_freq', 1)
        for epoch in range(num_epochs):
            print(f"Epoch {epoch+1}/{num_epochs}")
            self.run_epoch('train')
            self.run_epoch('eval')
            if ckpt_dir:
                maybe_save_checkpoint(self.model, self.optimizer, ckpt_dir, epoch + 1, save_freq)

In [14]:
# Pre-training configuration; the path is relative to the notebook's working directory
CONFIG_PATH = 'roberta-pretraining.yaml'
cfg = OmegaConf.load(CONFIG_PATH)


In [None]:
trainer = TextTrainer(cfg)
num_epochs = cfg.train.num_epochs  # Number of epochs to train
trainer.train(num_epochs)

# Final held-out evaluation. Wrapped in no_grad, as `train()` does for its eval passes,
# so no autograd graph is built for this extra epoch.
with torch.no_grad():
    evaluation_loss = trainer.run_epoch(mode='eval')
print(f"Final Evaluation Loss: {evaluation_loss}")
trained_model = trainer.model

In [None]:

# Sample text data
examples = [
    "A totally engrossing thriller.",
    "Unfortunately, the story is not as strong as the direction or the atmosphere."
]

# `from_pretrained` expects a directory or hub id, not an nn.Module, so first export the
# HuggingFace model wrapped inside Data2Vec (Data2Vec.encoder -> Encoder.encoder).
EXPORT_DIR = 'data2vec-text-encoder'
trained_model.encoder.encoder.save_pretrained(EXPORT_DIR)

# Label names for SST-2: 0 = negative, 1 = positive
labels = ["negative", "positive"]

tokenizer = AutoTokenizer.from_pretrained(cfg.model.encoder_checkpoint)
# NOTE: only the encoder was pre-trained. The classification head created here is freshly
# initialised (never fine-tuned on SST-2), so these predictions are NOT meaningful until
# the model is fine-tuned.
model = AutoModelForSequenceClassification.from_pretrained(EXPORT_DIR, num_labels=len(labels))
model.eval()

# Preprocess the examples
inputs = tokenizer(examples, padding=True, truncation=True, return_tensors="pt")

# Run the model on the examples
with torch.no_grad():  # Disable gradient calculation for efficiency
    outputs = model(**inputs)

# Process the results
predictions = outputs.logits.argmax(dim=-1).tolist()
predicted_labels = [labels[p] for p in predictions]

# Show results
for text, label in zip(examples, predicted_labels):
    print(f"Text: {text}\nPredicted sentiment: {label}\n")