In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertTokenizer
from datasets import load_dataset

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Load a subset of Wikipedia dataset
dataset = load_dataset("wikipedia", "20220301.en", split="train[:5000]", trust_remote_code=True)

  from .autonotebook import tqdm as notebook_tqdm


Using device: cpu


In [2]:
# Tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)

tokenized_datasets = dataset.map(tokenize_function, batched=True)

tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask'])

# Define a simple BERT model
class SimpleBERT(nn.Module):
    def __init__(self, vocab_size=30522, hidden_dim=768, num_layers=6, num_heads=6):
        super(SimpleBERT, self).__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_ids):
        embedded = self.embedding(input_ids)
        encoded = self.transformer_encoder(embedded)
        logits = self.fc(encoded)
        return logits

# Initialize model
model = SimpleBERT().to(device)

# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=5e-5)

# Training loop
num_epochs = 3
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch in tokenized_datasets:
        input_ids = batch['input_ids'].to(device)
        optimizer.zero_grad()
        outputs = model(input_ids)
        loss = criterion(outputs.view(-1, outputs.size(-1)), input_ids.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss/len(tokenized_datasets)}")

# Save model
torch.save(model.state_dict(), "bert_scratch.pth")
print("Model saved!")


Map: 100%|██████████| 5000/5000 [09:26<00:00,  8.82 examples/s]


Epoch 1, Loss: 1.8949753407254815
Epoch 2, Loss: 0.36290882806442676
Epoch 3, Loss: 0.11137995455454221
Model saved!
