In [1]:
from datasets import load_dataset, Dataset
from transformers import BertTokenizer, BertForMaskedLM
import numpy as np
import time
from transformers import AdamW, get_linear_schedule_with_warmup

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
#load dataset
books = load_dataset("opus_books", "en-fr")

# load model
model_name = "prajjwal1/bert-tiny"
tokenizer = BertTokenizer.from_pretrained(model_name, cache_dir="llm-tokenizer")
model = BertForMaskedLM.from_pretrained(model_name, cache_dir="llm-model")

In [6]:
from torch.utils.data import Dataset, DataLoader
import torch

class CustomDataset(Dataset):
    def __init__(self, books, tokenizer, from_language='en', to_language='fr', max_length=128):
        self.books = books
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.from_language = from_language
        self.to_language = to_language
        self.examples = []

        self._prepare_examples()

    def _prepare_examples(self):
        # Generate incremental source-target pairs for each data pair
        print(len(self.books['train']['translation']))
        for pair in self.books['train']['translation'][:10000]:
            source_text = pair[self.from_language]
            target_text = pair[self.to_language]
            combined_text = "SOURCE: " + source_text + "; TRANSLATION: " + target_text

            # Tokenize the entire combined text once
            encoding = self.tokenizer(
                combined_text,
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            input_ids = encoding['input_ids'].flatten()

            # Create incremental source-target pairs
            for i in range(1, len(input_ids)):
                # Create prefix up to the i-th token
                input_prefix = input_ids[:i]
                target_token = input_ids[i]

                # Only add pairs if within max length
                if len(input_prefix) < self.max_length:
                    # Pad input_prefix to max_length
                    padded_input = torch.cat([
                        input_prefix,
                        torch.full((self.max_length - len(input_prefix),), self.tokenizer.pad_token_id)
                    ])

                    # Prepare label for next token prediction
                    label = torch.full((self.max_length,), -100)
                    label[i] = target_token  # Only the next token is the target

                    # Add example to dataset
                    self.examples.append({
                        'input_ids': padded_input,
                        'attention_mask': (padded_input != self.tokenizer.pad_token_id).long(),
                        'labels': label
                    })

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]

In [4]:
dataset = train_dataset = CustomDataset(books, tokenizer)
dataloader = DataLoader(
        train_dataset, batch_size=32, shuffle=True, num_workers=0
)
dataset

127085


<__main__.CustomDataset at 0x1579cbc20>

In [12]:

def train_model(model, train_loader, val_loader, device, num_epochs=3):
    # Initialize optimizer only with parameters that require gradients
    optimizer = AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=1e-5
    )

    total_steps = len(train_loader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=total_steps
    )
    
    for epoch in range(num_epochs):
        start_time = time.time()
        model.train()
        total_train_loss = 0
        right_predictions = 0
        for batch in train_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
        
            model.zero_grad()
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            # Calculate accuracy
            preds = torch.argmax(outputs.logits, dim=1)
            right_predictions += torch.sum(preds == labels).item()

            loss = outputs.loss
            total_train_loss += loss.item()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

        # Validation
        model.eval()
        total_val_loss = 0
        predictions = []
        true_labels = []

        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )

                loss = outputs.loss
                total_val_loss += loss.item()

                preds = torch.argmax(outputs.logits, dim=1)
                predictions.extend(preds.cpu().numpy())
                true_labels.extend(labels.cpu().numpy())

        avg_train_loss = total_train_loss / len(train_loader)
        avg_val_loss = total_val_loss / len(val_loader)
        accuracy = np.mean(np.array(predictions) == np.array(true_labels))
        end_time = time.time()

        print(f'Epoch {epoch + 1}:')
        print(f'Average training loss: {avg_train_loss:.4f}')
        print(f'Average validation loss: {avg_val_loss:.4f}')
        print(f'Right predictions: {right_predictions} out of {len(train_loader) * 32}')
        print(f'Validation Accuracy: {accuracy:.4f}')
        print(f'Time taken for epoch: {end_time - start_time:.2f} seconds')
        print('-' * 60)

train_model(model, dataloader, dataloader, "cpu")

RuntimeError: The size of tensor a (30522) must match the size of tensor b (128) at non-singleton dimension 1