In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig, AutoTokenizer,AutoModelForSequenceClassification
from datasets import load_dataset
from omegaconf import OmegaConf, DictConfig
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import os
import copy

In [13]:
class TextTrainer:
    """
    Build a Data2Vec model, its tokenizer and WikiText train/test loaders from a
    single config, then train and evaluate the model.

    Args:
        cfg (DictConfig): Configuration object. Fields read directly here:
            ``cfg.device``, ``cfg.model.encoder_checkpoint``,
            ``cfg.train.batch_size``, ``cfg.train.eval_batch_size`` and
            ``cfg.optimizer.lr``. ``cfg`` is also forwarded to ``Encoder``,
            ``Data2Vec`` and ``WikiText``.

    Note:
        ``Encoder``, ``Data2Vec`` and ``WikiText`` are not defined or imported in
        the visible cells. Presumably they come from earlier (hidden or deleted)
        cells. TODO confirm, so the notebook survives Restart & Run All.
    """

    def __init__(self, cfg: "DictConfig"):
        self.config = cfg
        self.device = cfg.device

        self.encoder = Encoder(cfg=cfg)
        # run_epoch moves every batch to self.device, so the model must live
        # there too; otherwise a non-CPU device raises a device-mismatch error.
        self.model = Data2Vec(encoder=self.encoder, cfg=cfg).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.model.encoder_checkpoint)

        self.train_dataset = WikiText(cfg, 'train', self.tokenizer)
        self.test_dataset = WikiText(cfg, 'test', self.tokenizer)

        self.train_loader = DataLoader(self.train_dataset, batch_size=cfg.train.batch_size, shuffle=True,
                                       collate_fn=self.train_dataset.collate_fn)
        self.test_loader = DataLoader(self.test_dataset, batch_size=cfg.train.eval_batch_size, shuffle=False,
                                      collate_fn=self.test_dataset.collate_fn)

        self.optimizer = optim.Adam(self.model.parameters(), lr=cfg.optimizer.lr)
        # NOTE(review): CrossEntropyLoss assumes the model returns class logits and
        # batch['labels'] holds class indices. Confirm against Data2Vec.forward and
        # WikiText.collate_fn.
        self.criterion = nn.CrossEntropyLoss()

    def run_epoch(self, mode='train'):
        """
        Run one pass over the training or evaluation data.

        Args:
            mode (str): ``'train'`` iterates ``self.train_loader`` and updates the
                weights. Any other value (e.g. ``'eval'``) iterates
                ``self.test_loader`` with gradient tracking disabled.

        Returns:
            float: Per-sample average loss for the pass. Each batch's loss is
            weighted by its batch size and the sum is divided by the dataset
            length. Returns ``nan`` when the dataset is empty.
        """
        is_train = mode == 'train'
        loader = self.train_loader if is_train else self.test_loader

        if is_train:
            self.model.train()
        else:
            self.model.eval()

        total_loss = 0.0
        # Disable autograd outside training, so an evaluation pass never builds a
        # graph, even when the caller did not wrap it in torch.no_grad().
        with torch.set_grad_enabled(is_train):
            for batch in tqdm(loader, desc=f"{mode.capitalize()} Epoch"):
                inputs = batch['input_ids'].to(self.device)
                labels = batch['labels'].to(self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)

                if is_train:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                # With the default 'mean' reduction the loss is a batch mean.
                # Re-weight it by batch size so the epoch average is per sample.
                total_loss += loss.item() * inputs.size(0)

        num_samples = len(loader.dataset)
        avg_loss = total_loss / num_samples if num_samples else float('nan')
        print(f"{mode.capitalize()} loss: {avg_loss:.4f}")
        return avg_loss

    def train(self, num_epochs):
        """
        Train for ``num_epochs`` epochs, running an evaluation pass after each one.

        Args:
            num_epochs (int): Number of epochs to train the model.

        Returns:
            list[tuple[float, float]]: ``(train_loss, eval_loss)`` for each epoch.
        """
        history = []
        for epoch in range(num_epochs):
            print(f"Epoch {epoch+1}/{num_epochs}")
            train_loss = self.run_epoch('train')
            # run_epoch disables gradients itself for non-train modes.
            eval_loss = self.run_epoch('eval')
            history.append((train_loss, eval_loss))
        return history

In [14]:
# Run configuration: model checkpoint, device, train and optimizer settings.
# The path is resolved relative to the notebook's working directory.
CONFIG_PATH = 'roberta-pretraining.yaml'
cfg = OmegaConf.load(CONFIG_PATH)


In [None]:
# Train for the configured number of epochs; each epoch ends with an evaluation pass.
trainer = TextTrainer(cfg)
num_epochs = cfg.train.num_epochs
trainer.train(num_epochs)

# Final evaluation. It runs under no_grad so this standalone pass never builds an
# autograd graph (train() only disables gradients around its own per-epoch evals).
# NOTE: the model is unchanged since the last epoch, so this repeats that
# epoch's evaluation pass.
with torch.no_grad():
    evaluation_loss = trainer.run_epoch(mode='eval')
print(f"Final Evaluation Loss: {evaluation_loss}")
trained_model = trainer.model

In [None]:

# Sample text data
examples = [
    "A totally engrossing thriller.",
    "Unfortunately, the story is not as strong as the direction or the atmosphere.",
    "This is the best movie I have ever seen."
]

# Map class indices to labels (for SST-2: 0 = negative, 1 = positive)
labels = ["negative", "positive"]

# `from_pretrained` takes a checkpoint path or hub id (a string), not an in-memory
# nn.Module, so `trained_model` cannot be passed to it directly.
# NOTE(review): meaningful sentiment predictions need a checkpoint fine-tuned on
# SST-2. Loading the pre-training encoder checkpoint yields a freshly initialised
# (untrained) classification head. To use the Data2Vec-trained weights, save the
# underlying encoder with `save_pretrained(...)` (assumes it wraps a Hugging Face
# model; TODO confirm), fine-tune it on SST-2, and point this at that directory.
classifier_checkpoint = cfg.model.encoder_checkpoint

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(cfg.model.encoder_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(
    classifier_checkpoint, num_labels=len(labels)
)
model.eval()

# Preprocess the examples
inputs = tokenizer(examples, padding=True, truncation=True, return_tensors="pt")

# Run the model on the examples
with torch.no_grad():  # Disable gradient calculation for efficiency
    outputs = model(**inputs)

# Highest-scoring class index per example, as plain Python ints
predictions = outputs.logits.argmax(dim=-1).tolist()
predicted_labels = [labels[p] for p in predictions]

# Show results
for text, label in zip(examples, predicted_labels):
    print(f"Text: {text}\nPredicted sentiment: {label}\n")