In [2]:
# !sudo mongod --dbpath ~/data/mongodb

import findspark
findspark.init()

In [1]:
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, length
from datasets import load_dataset
import pandas as pd
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from transformers import BertTokenizer, BertForSequenceClassification
import os
import numpy as np

In [2]:
# Initialize Spark Session
def init_spark():
    spark = SparkSession.builder \
        .appName("Distributed BERT Fine-Tuning") \
        .getOrCreate()
    return spark

# Load and preprocess data (simplified for example)
def load_and_preprocess_data(spark):
    # Load IMDB dataset (using a small subset for demo)
    imdb_dataset = load_dataset("imdb")
    imdb_df = imdb_dataset["train"].to_pandas()[["text", "label"]].head(100)  # Limit for testing
    spark_df = spark.createDataFrame(imdb_df)
    
    # Filter short texts and standardize labels
    processed_df = spark_df.filter(length(col("text")) >= 10)
    processed_df = processed_df.withColumn("label", col("label").cast("integer"))
    
    # Split into train/test
    train_df, test_df = processed_df.randomSplit([0.8, 0.2], seed=42)
    return train_df, test_df

# Custom PyTorch Dataset for Spark DataFrame
class SparkDataset(Dataset):
    def __init__(self, spark_df, tokenizer, max_length=128):
        # Convert Spark DataFrame to Pandas for simplicity (in practice, use Spark's iterator for large data)
        self.data = spark_df.toPandas()
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        text = self.data.iloc[idx]["text"]
        label = self.data.iloc[idx]["label"]
        
        # Tokenize
        encoding = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        
        return {
            "input_ids": encoding["input_ids"].squeeze(),
            "attention_mask": encoding["attention_mask"].squeeze(),
            "labels": torch.tensor(label, dtype=torch.long)
        }

# Initialize BERT model and tokenizer
def init_bert():
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
    return tokenizer, model

# Setup distributed training
def setup_distributed_training(rank, world_size):
    if world_size > 1:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12345"
        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

# Training function
def train_bert(model, train_loader, rank, world_size, epochs=3, batch_size=8):
    device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
    
    # Wrap model in DDP if multiple GPUs, else use single GPU
    if world_size > 1:
        model = DDP(model, device_ids=[rank])
    model = model.to(device)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch in train_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        if rank == 0:
            print(f"Epoch {epoch+1}, Avg Loss: {total_loss / len(train_loader):.4f}")

# Evaluation function
def evaluate_bert(model, test_loader, rank, dataset_name):
    device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            predictions = torch.argmax(outputs.logits, dim=-1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    
    accuracy = correct / total
    if rank == 0:
        print(f"{dataset_name} Accuracy: {accuracy:.4f}")
    return accuracy

# Main distributed training function
def run_distributed_training(rank, world_size, model, train_dataset, test_dataset, batch_size=8):
    # Setup distributed training
    setup_distributed_training(rank, world_size)
    
    # Create distributed DataLoader
    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank) if world_size > 1 else None
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        shuffle=(train_sampler is None),
        num_workers=0
    )
    
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    
    # Train and evaluate
    train_bert(model, train_loader, rank, world_size)
    evaluate_bert(model, test_loader, rank, "Test")
    
    # Clean up
    if world_size > 1:
        dist.destroy_process_group()


In [3]:

# Initialize Spark and load data
spark = init_spark()
train_df, test_df = load_and_preprocess_data(spark)
train_df, test_df = train_df.limit(6), test_df.limit(6)

# Initialize BERT
tokenizer, model = init_bert()

# Create Datasets
train_dataset = SparkDataset(train_df, tokenizer)
test_dataset = SparkDataset(test_df, tokenizer)

# GPU setup
world_size = torch.cuda.device_count() if torch.cuda.is_available() else 1
print(f"Using {world_size} GPU(s)")

if world_size > 1:
    # Multi-GPU: Use torch.multiprocessing
    import torch.multiprocessing as mp
    mp.spawn(
        run_distributed_training,
        args=(world_size, model, train_dataset, test_dataset),
        nprocs=world_size,
        join=True
    )
else:
    # Single GPU or CPU
    run_distributed_training(0, 1, model, train_dataset, test_dataset)

# Stop Spark
spark.stop()

25/04/13 20:37:33 WARN Utils: Your hostname, yPC resolves to a loopback address: 127.0.1.1; using 10.255.255.254 instead (on interface lo)
25/04/13 20:37:33 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/04/13 20:37:33 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Using 1 GPU(s)
Epoch 1, Avg Loss: 0.4646
Epoch 2, Avg Loss: 0.3847
Epoch 3, Avg Loss: 0.2759
Test Accuracy: 1.0000
