In [1]:
import logging
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertModel, get_linear_schedule_with_warmup
from datasets import load_dataset
from torch.cuda.amp import autocast, GradScaler
from tqdm.auto import tqdm

# Suppress specific warnings
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)

# Check GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Load SNLI dataset
dataset = load_dataset("snli")

# Load Tokenizer
tokenizer = BertTokenizer.from_pretrained("distilbert-base-uncased")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cpu


In [2]:
# Tokenization Function with fixed padding
def tokenize_function(examples):
    return tokenizer(examples["premise"], examples["hypothesis"], padding="max_length", truncation=True, max_length=64)

# Tokenize Dataset
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Remove only existing columns dynamically
columns_to_remove = ['premise', 'hypothesis']
available_columns = tokenized_datasets['train'].column_names
columns_to_remove = [col for col in columns_to_remove if col in available_columns]
tokenized_datasets = tokenized_datasets.remove_columns(columns_to_remove)

# Filter out invalid labels (-1 or out of range)
def filter_invalid_labels(example):
    return example["label"] in [0, 1, 2]  # Keep only valid labels

tokenized_datasets = tokenized_datasets.filter(filter_invalid_labels)

# Format dataset for PyTorch
tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

# Adjust the number of samples to match the dataset size
train_subset = tokenized_datasets['train'].select(range(min(10000, len(tokenized_datasets['train']))))  # First 10,000 or max available samples
train_dataloader = DataLoader(train_subset, batch_size=4, shuffle=True, num_workers=8)

eval_subset = tokenized_datasets['validation'].select(range(min(10000, len(tokenized_datasets['validation']))))  # First 10,000 or max available samples
eval_dataloader = DataLoader(eval_subset, batch_size=4, num_workers=8)

# Sentence-BERT (SBERT) Model Definition
class SBERT(nn.Module):
    def __init__(self, bert_path="distilbert-base-uncased", hidden_dim=768):
        super(SBERT, self).__init__()
        self.bert = BertModel.from_pretrained(bert_path)
        self.fc = nn.Linear(hidden_dim * 3, 3)  # Output 3 classes (entailment, contradiction, neutral)
    
    def mean_pooling(self, token_embeds, attention_mask):
        """Mean Pooling for Sentence Representation"""
        in_mask = attention_mask.unsqueeze(-1).expand(token_embeds.size()).float()
        return torch.sum(token_embeds * in_mask, 1) / torch.clamp(in_mask.sum(1), min=1e-9)
    
    def forward(self, input_ids_a, attention_mask_a, input_ids_b, attention_mask_b):
        """Forward Pass for Training"""
        u = self.bert(input_ids_a, attention_mask=attention_mask_a).last_hidden_state
        v = self.bert(input_ids_b, attention_mask=attention_mask_b).last_hidden_state
        
        u_mean = self.mean_pooling(u, attention_mask_a)
        v_mean = self.mean_pooling(v, attention_mask_b)
        
        # Concatenation: (u, v, |u - v|)
        x = torch.cat([u_mean, v_mean, torch.abs(u_mean - v_mean)], dim=-1)
        logits = self.fc(x)
        return logits  # No softmax needed with CrossEntropyLoss

# Initialize Model
model = SBERT().to(device)

# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=2e-5)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=len(train_dataloader) * 2)

# Mixed Precision (AMP) & Gradient Scaling
scaler = GradScaler()

# Gradient Accumulation
accumulation_steps = 2  # Accumulate gradients over 2 batches

# Training Loop
num_epochs = 1  # Limiting to 1 epoch for faster training
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for step, batch in enumerate(tqdm(train_dataloader)):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()

        # Split into two sentence pairs
        mid_point = input_ids.shape[1] // 2
        input_ids_a, input_ids_b = input_ids[:, :mid_point], input_ids[:, mid_point:]
        attention_mask_a, attention_mask_b = attention_mask[:, :mid_point], attention_mask[:, mid_point:]

        with autocast():  # Mixed Precision
            outputs = model(input_ids_a, attention_mask_a, input_ids_b, attention_mask_b)
            loss = criterion(outputs, labels) / accumulation_steps  # Normalize loss by accumulation steps

        scaler.scale(loss).backward()

        if (step + 1) % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_dataloader)}")

# Save Model
torch.save(model.state_dict(), "sbert_task2.pth")
print("Model saved!")

You are using a model of type distilbert to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.
Some weights of BertModel were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['embeddings.LayerNorm.bias', 'embeddings.LayerNorm.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.word_embeddings.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.inte

Epoch 1, Loss: 0.5695841263890267
Model saved!
