In [1]:
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, length
from pyspark.sql.types import ArrayType, IntegerType, StructType, StructField
from datasets import load_dataset
from transformers import BertTokenizer, BertForSequenceClassification
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import IterableDataset, DataLoader
import os
import uuid
import pandas as pd
import numpy as np
from pyspark.sql.functions import pandas_udf
import pyarrow.parquet as pq
import glob
import logging
import time

In [2]:
spark = SparkSession.builder \
        .appName("Distributed BERT Fine-Tuning with Preprocessing") \
        .config("spark.driver.memory", "4g") \
        .config("spark.executor.memory", "4g") \
        .config("spark.cores.max", 4) \
        .config("spark.mongodb.input.uri", "mongodb://localhost:27017/sentiment_db.reviews") \
        .config("spark.mongodb.output.uri", "mongodb://localhost:27017/sentiment_db.reviews") \
        .config("spark.jars.packages", "org.mongodb.spark:mongo-spark-connector_2.12:3.0.1") \
        .getOrCreate()

25/04/25 19:14:17 WARN Utils: Your hostname, yPC resolves to a loopback address: 127.0.1.1; using 10.255.255.254 instead (on interface lo)
25/04/25 19:14:17 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Ivy Default Cache set to: /home/goodh/.ivy2/cache
The jars for the packages stored in: /home/goodh/.ivy2/jars
org.mongodb.spark#mongo-spark-connector_2.12 added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-d5fe131a-ac5d-4d22-9abd-81c213a111bf;1.0
	confs: [default]
	found org.mongodb.spark#mongo-spark-connector_2.12;3.0.1 in central
	found org.mongodb#mongodb-driver-sync;4.0.5 in central


:: loading settings :: url = jar:file:/home/goodh/miniconda3/envs/5003/lib/python3.9/site-packages/pyspark/jars/ivy-2.5.1.jar!/org/apache/ivy/core/settings/ivysettings.xml


	found org.mongodb#bson;4.0.5 in central
	found org.mongodb#mongodb-driver-core;4.0.5 in central
:: resolution report :: resolve 77ms :: artifacts dl 3ms
	:: modules in use:
	org.mongodb#bson;4.0.5 from central in [default]
	org.mongodb#mongodb-driver-core;4.0.5 from central in [default]
	org.mongodb#mongodb-driver-sync;4.0.5 from central in [default]
	org.mongodb.spark#mongo-spark-connector_2.12;3.0.1 from central in [default]
	---------------------------------------------------------------------
	|                  |            modules            ||   artifacts   |
	|       conf       | number| search|dwnlded|evicted|| number|dwnlded|
	---------------------------------------------------------------------
	|      default     |   4   |   0   |   0   |   0   ||   4   |   0   |
	---------------------------------------------------------------------
:: retrieving :: org.apache.spark#spark-submit-parent-d5fe131a-ac5d-4d22-9abd-81c213a111bf
	confs: [default]
	0 artifacts copied, 4 already re

In [None]:
imdb_dataset = load_dataset("imdb")
imdb_df = pd.concat([
    imdb_dataset['train'].to_pandas()[['text', 'label']],
    imdb_dataset['test'].to_pandas()[['text', 'label']],
]).sample(frac=1, random_state=42).reset_index(drop=True)
imdb_df['source'] = 'IMDB'

sst2_dataset = load_dataset('glue', 'sst2')
sst2_df = sst2_dataset['train'].to_pandas()[['sentence', 'label']].sample(frac=1, random_state=42).reset_index(drop=True)
sst2_df = sst2_df.rename(columns={'sentence': 'text'})
sst2_df['source'] = 'SST-2'

In [85]:
imdb_spark_df = spark.createDataFrame(imdb_df).select(col('text'), col('label').cast('integer'), col('source'))
sst2_spark_df = spark.createDataFrame(sst2_df).select(col('text'), col('label').cast('integer'), col('source'))

In [22]:
def create_batch_tokenizer_udf(max_length=128):
    def tokenize_batch(texts: pd.Series) -> pd.DataFrame:
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        encodings = tokenizer(
            texts.tolist(),
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors="np"
        )
        return pd.DataFrame({
            "input_ids": [ids.tolist() for ids in encodings["input_ids"]],
            "attention_mask": [mask.tolist() for mask in encodings["attention_mask"]]
        })
    
    schema = StructType([
        StructField("input_ids", ArrayType(IntegerType())),
        StructField("attention_mask", ArrayType(IntegerType()))
    ])
    
    return pandas_udf(tokenize_batch, schema)

In [7]:
processed_df = imdb_spark_df.filter(length(col('text'))>=10)

In [8]:
tokenize_udf = create_batch_tokenizer_udf(128)
tokenized_df = processed_df.withColumn('encoding', tokenize_udf(col('text')))

In [9]:
tokenized_df = tokenized_df.select(
    col("label").cast("integer").alias("label"),
    col("source"),
    col("encoding.input_ids").alias("input_ids"),
    col("encoding.attention_mask").alias("attention_mask")
)

In [10]:
train_spark_df, test_spark_df = imdb_spark_df.randomSplit([0.8, 0.2], seed=42)
sst2_spark_test_df = sst2_spark_df

In [67]:
checkpoint['sst2_eval_losses']

[0.6795020675783696,
 0.6675350687243516,
 0.6599975805975139,
 0.6501290568721974,
 0.6425037231352815,
 0.6386665645982164,
 0.6346162742448962,
 0.6312412393831311,
 0.6338303024067389,
 0.6215268512384063,
 0.6137265110306407,
 0.6172930431199946,
 0.6118318956605655,
 0.6112569093229877,
 0.6062058276781482,
 0.6037738823795816,
 0.6018424541966235,
 0.5990889155725376,
 0.6015712148927984,
 0.5980112003289891]

In [68]:
checkpoint.keys()

dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'scaler_state_dict', 'train_losses', 'imdb_eval_losses', 'sst2_eval_losses'])

In [66]:
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2, hidden_dropout_prob=0.3, attention_probs_dropout_prob=0.3)
optimizer = torch.optim.AdamW(model.classifier.parameters(), lr=2e-5)
scaler = torch.amp.GradScaler('cuda')
checkpoint = torch.load("./checkpoints/20250425_210603_bert_finetuned_epoch_20.pt")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scaler.load_state_dict(checkpoint['scaler_state_dict'])

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [23]:
class LazyParquetDataset(IterableDataset):
    def __init__(self, parquet_path, rank, world_size, batch_size=1000):
        self.parquet_files = sorted(glob.glob(os.path.join(parquet_path, "*.parquet")))
        self.rank = rank
        self.world_size = world_size
        self.batch_size = batch_size
        
        # Shard files across ranks
        files_per_rank = len(self.parquet_files) // world_size
        start_idx = rank * files_per_rank
        end_idx = (rank + 1) * files_per_rank if rank < world_size - 1 else len(self.parquet_files)
        self.parquet_files = self.parquet_files[start_idx:end_idx]
    
    def __iter__(self):
        for file in self.parquet_files:
            # print(f"Rank {self.rank} reading Parquet file: {file}")
            parquet_file = pq.ParquetFile(file)
            for batch in parquet_file.iter_batches(batch_size=self.batch_size):
                df = batch.to_pandas()
                for _, row in df.iterrows():
                    yield {
                        "input_ids": torch.tensor(row["input_ids"], dtype=torch.long),
                        "attention_mask": torch.tensor(row["attention_mask"], dtype=torch.long),
                        "labels": torch.tensor(row["label"], dtype=torch.long)
                    }

In [49]:
train_dataset = LazyParquetDataset('archive/train/', 0, 1)
train_loader = DataLoader(train_dataset, batch_size=8)

In [42]:
count

64

In [47]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
correct = total = count = 0
with torch.no_grad():
    for batch in train_loader:
        input_ids, attention_mask, labels = batch['input_ids'].to(device), batch['attention_mask'].to(device), batch['labels'].to(device)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        predictions = torch.argmax(outputs.logits, dim=-1)
        print(predictions.tolist(), labels.tolist(), (predictions==labels).tolist())
        correct += (predictions == labels).sum().item()
        total += labels.size(0)
        count += predictions.shape[0]

[0, 0, 0, 0, 0, 0, 0, 0] [0, 1, 0, 0, 0, 1, 1, 0] [True, False, True, True, True, False, False, True]
[0, 0, 0, 0, 0, 0, 0, 0] [0, 0, 0, 0, 0, 0, 0, 1] [True, True, True, True, True, True, True, False]
[0, 0, 0, 0, 0, 0, 0, 0] [1, 1, 0, 0, 1, 0, 0, 1] [False, False, True, True, False, True, True, False]
[0, 0, 0, 0, 0, 0, 0, 0] [1, 1, 0, 0, 0, 0, 1, 0] [False, False, True, True, True, True, False, True]
[0, 0, 0, 0, 0, 0, 0, 0] [1, 1, 1, 0, 0, 0, 0, 0] [False, False, False, True, True, True, True, True]
[0, 0, 0, 0, 0, 0, 0, 0] [0, 1, 1, 1, 0, 0, 0, 1] [True, False, False, False, True, True, True, False]
[0, 0, 0, 0, 0, 0, 0, 0] [1, 1, 1, 0, 1, 0, 0, 0] [False, False, False, True, False, True, True, True]
[0, 0, 0, 0, 0, 0, 0, 0] [1, 0, 1, 0, 0, 0, 0, 0] [False, True, False, True, True, True, True, True]
[0, 0, 0, 0, 0, 0, 0, 0] [0, 0, 1, 0, 1, 1, 0, 0] [True, True, False, True, False, False, True, True]
[0, 0, 0, 0, 0, 0, 0, 0] [1, 1, 0, 0, 1, 0, 0, 0] [False, False, True, True, False

In [15]:
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, length
from pyspark.sql.types import ArrayType, IntegerType, StructType, StructField
from datasets import load_dataset
from transformers import BertTokenizer, BertForSequenceClassification
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import IterableDataset, DataLoader
import os
import uuid
import pandas as pd
import numpy as np
from pyspark.sql.functions import pandas_udf
import pyarrow.parquet as pq
import glob
import logging
import time

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize Spark with MongoDB connector
def init_spark(num_cpus = None):
    # if num_spark_executor_core: logger.info(f"{num_spark_executor_core} cores for executor")
    # else: logger.info(f"number of cores for executor UNDEFINED")
    if num_cpus: logger.info(f"{num_cpus} cores for spark")
    else: logger.info(f"num_cpus UNDEFINED")
    spark = SparkSession.builder \
        .appName("Distributed BERT Fine-Tuning with Preprocessing") \
        .config("spark.driver.memory", "4g") \
        .config("spark.executor.memory", "4g") \
        .config("spark.cores.max", num_cpus) \
        .config("spark.mongodb.input.uri", "mongodb://localhost:27017/sentiment_db.reviews") \
        .config("spark.mongodb.output.uri", "mongodb://localhost:27017/sentiment_db.reviews") \
        .config("spark.jars.packages", "org.mongodb.spark:mongo-spark-connector_2.12:3.0.1") \
        .config("spark.mongodb.input.partitionerOptions.partitionSizeMB", "256") \
        .getOrCreate()
        # .config("spark.driver.cores", "2") \
        # .config("spark.executor.cores", str(num_spark_executor_core) if num_spark_executor_core else 4) \
        # .config("spark.executor.instances", 3) \
        # .config("spark.default.parallelism", 10) \
    return spark

# Load IMDB and SST-2 data to MongoDB
def load_data_to_mongodb(spark):
    # IMDB dataset
    start_time = time.time()
    logger.info("Loading IMDB dataset...")
    imdb_dataset = load_dataset("imdb")
    imdb_df = pd.concat([
        imdb_dataset["train"].to_pandas()[["text", "label"]],  # use 50 for debug
        imdb_dataset["test"].to_pandas()[["text", "label"]]
    ])
    imdb_df["source"] = "IMDB"
    imdb_spark_df = spark.createDataFrame(imdb_df).select(col("text"), col("label").cast("integer"), col("source"))
    
    # SST-2 dataset
    logger.info("Loading SST-2 dataset...")
    sst2_dataset = load_dataset("glue", "sst2")
    sst2_df = sst2_dataset["train"].to_pandas()[["sentence", "label"]]  # use 50 for debug
    sst2_df = sst2_df.rename(columns={"sentence": "text"})
    sst2_df["source"] = "SST-2"
    sst2_spark_df = spark.createDataFrame(sst2_df).select(col("text"), col("label").cast("integer"), col("source"))
    
    logger.info("Writing datasets to MongoDB...")
    imdb_spark_df.write.format("mongo").mode("append").save()
    sst2_spark_df.write.format("mongo").mode("append").save()
    return time.time() - start_time, imdb_spark_df, sst2_spark_df

# Batch tokenizer UDF
def create_batch_tokenizer_udf(max_length=128):
    def tokenize_batch(texts: pd.Series) -> pd.DataFrame:
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        encodings = tokenizer(
            texts.tolist(),
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors="np"
        )
        return pd.DataFrame({
            "input_ids": [ids.tolist() for ids in encodings["input_ids"]],
            "attention_mask": [mask.tolist() for mask in encodings["attention_mask"]]
        })
    
    schema = StructType([
        StructField("input_ids", ArrayType(IntegerType())),
        StructField("attention_mask", ArrayType(IntegerType()))
    ])
    
    return pandas_udf(tokenize_batch, schema)

# Preprocess data and save to Parquet
def preprocess_data(spark, imdb_spark_df, sst2_spark_df, output_dir, max_length=128):
    start_time = time.time()

    # Load and preprocess data
    logger.info("Reading data from MongoDB...")
    # raw_df = spark.read.format("mongo").load()
    raw_df = imdb_spark_df.union(sst2_spark_df)
    processed_df = raw_df.filter(length(col("text")) >= 10)
    
    # Apply distributed batch tokenization
    logger.info("Tokenizing data...")
    tokenize_udf = create_batch_tokenizer_udf(max_length)
    tokenized_df = processed_df.withColumn("encoding", tokenize_udf(col("text")))
    
    # Extract input_ids and attention_mask
    tokenized_df = tokenized_df.select(
        col("label").cast("integer").alias("label"),
        col("source"),
        col("encoding.input_ids").alias("input_ids"),
        col("encoding.attention_mask").alias("attention_mask")
    )
    
    # Split IMDB into train/test
    imdb_df = tokenized_df.filter(col("source") == "IMDB")
    train_df, test_df = imdb_df.randomSplit([0.8, 0.2], seed=42)
    sst2_test_df = tokenized_df.filter(col("source") == "SST-2")
    
    # Save to Parquet with dynamic partitioning
    num_partitions = max(16, spark.sparkContext.defaultParallelism * 2)  # Adjust based on cluster size
    logger.info(f"num_partitions {num_partitions}")
    train_path = os.path.join(output_dir, f"train_{uuid.uuid4().hex}")
    test_path = os.path.join(output_dir, f"test_{uuid.uuid4().hex}")
    sst2_test_path = os.path.join(output_dir, f"sst2_{uuid.uuid4().hex}")
    
    _start_time = time.time()
    logger.info(f"Writing Parquet files: train={train_path}, test={test_path}, sst2={sst2_test_path}")
    train_df.select("input_ids", "attention_mask", "label").repartition(num_partitions).write.mode("overwrite").parquet(train_path)
    logger.info(f"{time.time()-_start_time:.4f}s for train_df partition")
    # _start_time = time.time()
    # test_df.select("input_ids", "attention_mask", "label").repartition(num_partitions).write.mode("overwrite").parquet(test_path)
    # logger.info(f"{time.time()-_start_time:.4f}s for test_df partition")
    # _start_time = time.time()
    # sst2_test_df.select("input_ids", "attention_mask", "label").repartition(num_partitions).write.mode("overwrite").parquet(sst2_test_path)
    # logger.info(f"{time.time()-_start_time:.4f}s for sst2_df partition")
    
    # Store processed data in MongoDB for reference
    # train_collection = f"train_{uuid.uuid4().hex}"
    # test_collection = f"test_{uuid.uuid4().hex}"
    # sst2_collection = f"sst2_{uuid.uuid4().hex}"
    # train_df.write.format("mongo").option("collection", train_collection).mode("overwrite").save()
    # test_df.write.format("mongo").option("collection", test_collection).mode("overwrite").save()
    # sst2_test_df.write.format("mongo").option("collection", sst2_collection).mode("overwrite").save()
    
    preprocess_time = time.time() - start_time
    return train_path, test_path, sst2_test_path, train_collection, test_collection, sst2_collection, preprocess_time

# Check for cached Parquet files
def check_cached_parquet(output_dir):
    train_path = test_path = sst2_test_path = None
    train_collection = test_collection = sst2_collection = None
    
    for dir_name in os.listdir(output_dir):
        if dir_name.startswith("train_"):
            train_path = os.path.join(output_dir, dir_name)
            train_collection = dir_name
        elif dir_name.startswith("test_"):
            test_path = os.path.join(output_dir, dir_name)
            test_collection = dir_name
        elif dir_name.startswith("sst2_"):
            sst2_test_path = os.path.join(output_dir, dir_name)
            sst2_collection = dir_name
    
    if train_path and test_path and sst2_test_path:
        logger.info(f"Found cached Parquet files: train={train_path}, test={test_path}, sst2={sst2_test_path}")
        return train_path, test_path, sst2_test_path, train_collection, test_collection, sst2_collection
    return None

# Lazy-loading Parquet dataset
class LazyParquetDataset(IterableDataset):
    def __init__(self, parquet_path, rank, world_size, batch_size=1000):
        self.parquet_files = sorted(glob.glob(os.path.join(parquet_path, "*.parquet")))
        self.rank = rank
        self.world_size = world_size
        self.batch_size = batch_size
        
        # Shard files across ranks
        files_per_rank = len(self.parquet_files) // world_size
        start_idx = rank * files_per_rank
        end_idx = (rank + 1) * files_per_rank if rank < world_size - 1 else len(self.parquet_files)
        self.parquet_files = self.parquet_files[start_idx:end_idx]
    
    def __iter__(self):
        for file in self.parquet_files:
            logger.debug(f"Rank {self.rank} reading Parquet file: {file}")
            parquet_file = pq.ParquetFile(file)
            for batch in parquet_file.iter_batches(batch_size=self.batch_size):
                df = batch.to_pandas()
                for _, row in df.iterrows():
                    yield {
                        "input_ids": torch.tensor(row["input_ids"], dtype=torch.long),
                        "attention_mask": torch.tensor(row["attention_mask"], dtype=torch.long),
                        "labels": torch.tensor(row["label"], dtype=torch.long)
                    }

# Training and evaluation
def train_and_evaluate(rank, world_size, train_path, test_path, sst2_test_path, train_collection, test_collection, sst2_collection, finetune_time, batch_size=8, epochs=3):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    
    model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
    model = DDP(model.to(rank), device_ids=[rank])
    
    # Create datasets
    train_dataset = LazyParquetDataset(train_path, rank, world_size)
    test_dataset = LazyParquetDataset(test_path, rank, world_size)
    sst2_test_dataset = LazyParquetDataset(sst2_test_path, rank, world_size)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=0)
    sst2_test_loader = DataLoader(sst2_test_dataset, batch_size=batch_size, num_workers=0)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    # scaler = torch.cuda.amp.GradScaler()  # For mixed-precision training
    scaler = torch.amp.GradScaler('cuda')  # For mixed-precision training
    
    # Measure training wall time
    dist.barrier()  # Synchronize all ranks before timing
    train_start_time = time.time()
    
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        num_batches = 0
        for batch in train_loader:
            input_ids = batch["input_ids"].to(rank)
            attention_mask = batch["attention_mask"].to(rank)
            labels = batch["labels"].to(rank)
            
            # with torch.cuda.amp.autocast():
            with torch.amp.autocast('cuda'):
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss
            
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            
            total_loss += loss.item()
            num_batches += 1
        
        logger.info(f"GPU[{rank}], Epoch {epoch+1}, Avg Loss: {total_loss / num_batches:.4f}")
    
    dist.barrier()  # Synchronize all ranks after training
    train_end_time = time.time()
    train_wall_time = train_end_time - train_start_time
    
    # Aggregate max training time across ranks
    train_wall_time_tensor = torch.tensor(train_wall_time, dtype=torch.float64).cuda(rank)
    dist.all_reduce(train_wall_time_tensor, op=dist.ReduceOp.MAX)
    train_wall_time_max = train_wall_time_tensor.item()
    
    # Log training time only from rank 0
    if rank == 0:
        finetune_time[0] = train_wall_time_max
        logger.info(f"Training wall time (max across ranks): {train_wall_time_max:.2f} seconds")
    
    model.eval()
    for dataset_name, loader in [("IMDB Test", test_loader), ("SST-2 Test", sst2_test_loader)]:
        correct = total = 0
        with torch.no_grad():
            for batch in loader:
                input_ids = batch["input_ids"].to(rank)
                attention_mask = batch["attention_mask"].to(rank)
                labels = batch["labels"].to(rank)
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                predictions = torch.argmax(outputs.logits, dim=-1)
                correct += (predictions == labels).sum().item()
                total += labels.size(0)
        logger.info(f"GPU[{rank}]: {dataset_name} Accuracy: {correct / total:.4f}")
    
    dist.destroy_process_group()

# Main
NUM_CPUs = 4
NUM_GPUs = 1
logger.info("Initializing Spark...")
# os.environ['PYSPARK_PYTHON'] = '/home/goodh/miniconda3/envs/5003/bin/python'
# os.environ['PYSPARK_DRIVER_PYTHON'] = '/home/goodh/miniconda3/envs/5003/bin/python'
spark = init_spark(NUM_CPUs)

# Output directory for Parquet files
output_dir = "processed_data"
os.makedirs(output_dir, exist_ok=True)

# Check for cached Parquet files
# cached_data = check_cached_parquet(output_dir)
cached_data = None
train_path = test_path = sst2_test_path = train_collection = test_collection = sst2_collection = None
preprocess_time = 0

if cached_data:
    logger.info("Cached Parquet files found. Skipping data loading and preprocessing...")
    train_path, test_path, sst2_test_path, train_collection, test_collection, sst2_collection = cached_data
else:
    # Load and preprocess data
    logger.info("No cached Parquet files found. Running full pipeline...")
    logger.info("Loading data to MongoDB...")
    load_data_time, imdb_spark_df, sst2_spark_df = load_data_to_mongodb(spark)
    logger.info(f"Data loading to MongoDB took {load_data_time:.2f} seconds")
    
    logger.info("Distributed preprocessing and saving to Parquet...")
    train_path, test_path, sst2_test_path, train_collection, test_collection, sst2_collection, preprocess_time = preprocess_data(spark, imdb_spark_df, sst2_spark_df, output_dir)
    logger.info(f"Distributed preprocessing took {preprocess_time:.2f} seconds")

# Run distributed training
world_size = NUM_GPUs if NUM_GPUs else max(1, torch.cuda.device_count())
logger.info(f"Using {world_size} GPU(s)")

logger.info("Distributed fine-tuning...")
import torch.multiprocessing as mp
finetune_time = torch.zeros(world_size, dtype=torch.float32).share_memory_()
mp.spawn(
    train_and_evaluate,
    args=(world_size, train_path, test_path, sst2_test_path, train_collection, test_collection, sst2_collection, finetune_time),
    nprocs=world_size,
    join=True
)

# append results
result = f"{time.strftime('%Y/%m/%d-%H:%M:%S')}\t{NUM_CPUs}\t\t{NUM_GPUs}\t\t{preprocess_time:.2f}\t\t{finetune_time[0]:.2f}\n"
logger.info(result)
with open("out/results.out", "a") as f:
    f.write(result)

spark.stop()

INFO:__main__:Initializing Spark...
INFO:__main__:4 cores for spark
25/04/22 23:28:22 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.
INFO:__main__:No cached Parquet files found. Running full pipeline...
INFO:__main__:Loading data to MongoDB...
INFO:__main__:Loading IMDB dataset...
INFO:__main__:Loading SST-2 dataset...
INFO:__main__:Writing datasets to MongoDB...
25/04/22 23:28:59 WARN TaskSetManager: Stage 5 contains a task of very large size (1327 KiB). The maximum recommended task size is 1000 KiB.
INFO:__main__:Data loading to MongoDB took 38.81 seconds                        
INFO:__main__:Distributed preprocessing and saving to Parquet...
INFO:__main__:Reading data from MongoDB...
INFO:__main__:Tokenizing data...
INFO:__main__:num_partitions 56
INFO:__main__:Writing Parquet files: train=processed_data/train_5933d11a44e643b297ab8a2a7b33a7ad, test=processed_data/test_3d891be5c6144acfb2990326693eaf61, sst2=processed_data/sst2_0d

ProcessExitedException: process 0 terminated with exit code 1

In [1]:
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, length
from pyspark.sql.types import ArrayType, IntegerType, StructType, StructField
from datasets import load_dataset
from transformers import BertTokenizer, BertForSequenceClassification
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import IterableDataset, DataLoader
import os
import uuid
import pandas as pd
import numpy as np
from pyspark.sql.functions import pandas_udf
import pyarrow.parquet as pq
import glob
import logging
import time

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize Spark with MongoDB connector
def init_spark(num_cpus = None):
    # if num_spark_executor_core: logger.info(f"{num_spark_executor_core} cores for executor")
    # else: logger.info(f"number of cores for executor UNDEFINED")
    if num_cpus: logger.info(f"{num_cpus} cores for spark")
    else: logger.info(f"num_cpus UNDEFINED")
    spark = SparkSession.builder \
        .appName("Distributed BERT Fine-Tuning with Preprocessing") \
        .config("spark.sql.shuffle.partitions", 28) \
        .config("spark.driver.memory", "4g") \
        .config("spark.executor.memory", "4g") \
        .config("spark.executor.cores", 4) \
        .config("spark.executor.instances", 7) \
        .config("spark.cores.max", num_cpus) \
        .config("spark.mongodb.input.uri", "mongodb://localhost:27017/sentiment_db.reviews") \
        .config("spark.mongodb.output.uri", "mongodb://localhost:27017/sentiment_db.reviews") \
        .config("spark.jars.packages", "org.mongodb.spark:mongo-spark-connector_2.12:3.0.1") \
        .getOrCreate()
        # .config("spark.mongodb.input.partitionerOptions.partitionSizeMB", "256") \
        # .config("spark.driver.cores", "2") \
        # .config("spark.default.parallelism", 10) \
    return spark

# Load IMDB and SST-2 data to MongoDB
def load_data_to_mongodb(spark):
    # IMDB dataset
    start_time = time.time()
    logger.info("Loading IMDB dataset...")
    imdb_dataset = load_dataset("imdb")
    imdb_df = pd.concat([
        imdb_dataset["train"].to_pandas()[["text", "label"]],  # use 50 for debug
        imdb_dataset["test"].to_pandas()[["text", "label"]]
    ])
    imdb_df["source"] = "IMDB"
    imdb_spark_df = spark.createDataFrame(imdb_df).select(col("text"), col("label").cast("integer"), col("source"))
    
    # # SST-2 dataset
    # logger.info("Loading SST-2 dataset...")
    # sst2_dataset = load_dataset("glue", "sst2")
    # sst2_df = sst2_dataset["train"].to_pandas()[["sentence", "label"]]  # use 50 for debug
    # sst2_df = sst2_df.rename(columns={"sentence": "text"})
    # sst2_df["source"] = "SST-2"
    # sst2_spark_df = spark.createDataFrame(sst2_df).select(col("text"), col("label").cast("integer"), col("source"))
    
    logger.info("Writing datasets to MongoDB...")
    imdb_spark_df.write.format("mongo").mode("append").save()
    # sst2_spark_df.write.format("mongo").mode("append").save()
    return time.time() - start_time

# Batch tokenizer UDF
def create_batch_tokenizer_udf(max_length=128):
    def tokenize_batch(texts: pd.Series) -> pd.DataFrame:
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        encodings = tokenizer(
            texts.tolist(),
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors="np"
        )
        return pd.DataFrame({
            "input_ids": [ids.tolist() for ids in encodings["input_ids"]],
            "attention_mask": [mask.tolist() for mask in encodings["attention_mask"]]
        })
    
    schema = StructType([
        StructField("input_ids", ArrayType(IntegerType())),
        StructField("attention_mask", ArrayType(IntegerType()))
    ])
    
    return pandas_udf(tokenize_batch, schema)

# Preprocess data and save to Parquet
def preprocess_data(spark, output_dir, max_length=128):
    start_time = time.time()

    # Load and preprocess data
    logger.info("Reading data from MongoDB...")
    # num_partitions = max(16, spark.sparkContext.defaultParallelism * 2)  # Adjust based on cluster size
    num_partitions = 28
    raw_df = spark.read.format("mongo").load()
    raw_df.repartition(num_partitions)
    processed_df = raw_df.filter(length(col("text")) >= 10)
    
    # Apply distributed batch tokenization
    logger.info("Tokenizing data...")
    tokenize_udf = create_batch_tokenizer_udf(max_length)
    tokenized_df = processed_df.withColumn("encoding", tokenize_udf(col("text")))
    
    # Extract input_ids and attention_mask
    tokenized_df = tokenized_df.select(
        col("label").cast("integer").alias("label"),
        col("source"),
        col("encoding.input_ids").alias("input_ids"),
        col("encoding.attention_mask").alias("attention_mask")
    )
    
    # Split IMDB into train/test
    imdb_df = tokenized_df.filter(col("source") == "IMDB")
    train_df, test_df = imdb_df.randomSplit([0.8, 0.2], seed=42)
    sst2_test_df = tokenized_df.filter(col("source") == "SST-2")
    
    # Save to Parquet with dynamic partitioning
    logger.info(f"num_partitions {num_partitions}")
    train_path = os.path.join(output_dir, f"train_{uuid.uuid4().hex}")
    # test_path = os.path.join(output_dir, f"test_{uuid.uuid4().hex}")
    # sst2_test_path = os.path.join(output_dir, f"sst2_{uuid.uuid4().hex}")
    
    _start_time = time.time()
    logger.info(f"Writing Parquet files: train={train_path}, test={test_path}, sst2={sst2_test_path}")
    train_df.select("input_ids", "attention_mask", "label").write.mode("overwrite").parquet(train_path)
    logger.info(f"{time.time()-_start_time:.4f}s for train_df partition")
    # _start_time = time.time()
    # test_df.select("input_ids", "attention_mask", "label").repartition(num_partitions).write.mode("overwrite").parquet(test_path)
    # logger.info(f"{time.time()-_start_time:.4f}s for test_df partition")
    # _start_time = time.time()
    # sst2_test_df.select("input_ids", "attention_mask", "label").repartition(num_partitions).write.mode("overwrite").parquet(sst2_test_path)
    # logger.info(f"{time.time()-_start_time:.4f}s for sst2_df partition")
    
    # Store processed data in MongoDB for reference
    train_collection = f"train_{uuid.uuid4().hex}"
    # test_collection = f"test_{uuid.uuid4().hex}"
    # sst2_collection = f"sst2_{uuid.uuid4().hex}"
    train_df.write.format("mongo").option("collection", train_collection).mode("overwrite").save()
    # test_df.write.format("mongo").option("collection", test_collection).mode("overwrite").save()
    # sst2_test_df.write.format("mongo").option("collection", sst2_collection).mode("overwrite").save()
    
    preprocess_time = time.time() - start_time
    return train_path, test_path, sst2_test_path, train_collection, test_collection, sst2_collection, preprocess_time

# Check for cached Parquet files
def check_cached_parquet(output_dir):
    train_path = test_path = sst2_test_path = None
    train_collection = test_collection = sst2_collection = None
    
    for dir_name in os.listdir(output_dir):
        if dir_name.startswith("train_"):
            train_path = os.path.join(output_dir, dir_name)
            train_collection = dir_name
        elif dir_name.startswith("test_"):
            test_path = os.path.join(output_dir, dir_name)
            test_collection = dir_name
        elif dir_name.startswith("sst2_"):
            sst2_test_path = os.path.join(output_dir, dir_name)
            sst2_collection = dir_name
    
    if train_path and test_path and sst2_test_path:
        logger.info(f"Found cached Parquet files: train={train_path}, test={test_path}, sst2={sst2_test_path}")
        return train_path, test_path, sst2_test_path, train_collection, test_collection, sst2_collection
    return None

# Lazy-loading Parquet dataset
class LazyParquetDataset(IterableDataset):
    def __init__(self, parquet_path, rank, world_size, batch_size=1000):
        self.parquet_files = sorted(glob.glob(os.path.join(parquet_path, "*.parquet")))
        self.rank = rank
        self.world_size = world_size
        self.batch_size = batch_size
        
        # Shard files across ranks
        files_per_rank = len(self.parquet_files) // world_size
        start_idx = rank * files_per_rank
        end_idx = (rank + 1) * files_per_rank if rank < world_size - 1 else len(self.parquet_files)
        self.parquet_files = self.parquet_files[start_idx:end_idx]
    
    def __iter__(self):
        for file in self.parquet_files:
            logger.debug(f"Rank {self.rank} reading Parquet file: {file}")
            parquet_file = pq.ParquetFile(file)
            for batch in parquet_file.iter_batches(batch_size=self.batch_size):
                df = batch.to_pandas()
                for _, row in df.iterrows():
                    yield {
                        "input_ids": torch.tensor(row["input_ids"], dtype=torch.long),
                        "attention_mask": torch.tensor(row["attention_mask"], dtype=torch.long),
                        "labels": torch.tensor(row["label"], dtype=torch.long)
                    }

# Training and evaluation
def train_and_evaluate(rank, world_size, train_path, test_path, sst2_test_path, train_collection, test_collection, sst2_collection, finetune_time, batch_size=8, epochs=3):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    
    model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
    model = DDP(model.to(rank), device_ids=[rank])
    
    # Create datasets
    train_dataset = LazyParquetDataset(train_path, rank, world_size)
    test_dataset = LazyParquetDataset(test_path, rank, world_size)
    sst2_test_dataset = LazyParquetDataset(sst2_test_path, rank, world_size)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=0)
    sst2_test_loader = DataLoader(sst2_test_dataset, batch_size=batch_size, num_workers=0)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    # scaler = torch.cuda.amp.GradScaler()  # For mixed-precision training
    scaler = torch.amp.GradScaler('cuda')  # For mixed-precision training
    
    # Measure training wall time
    dist.barrier()  # Synchronize all ranks before timing
    train_start_time = time.time()
    
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        num_batches = 0
        for batch in train_loader:
            input_ids = batch["input_ids"].to(rank)
            attention_mask = batch["attention_mask"].to(rank)
            labels = batch["labels"].to(rank)
            
            # with torch.cuda.amp.autocast():
            with torch.amp.autocast('cuda'):
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss
            
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            
            total_loss += loss.item()
            num_batches += 1
        
        logger.info(f"GPU[{rank}], Epoch {epoch+1}, Avg Loss: {total_loss / num_batches:.4f}")
    
    dist.barrier()  # Synchronize all ranks after training
    train_end_time = time.time()
    train_wall_time = train_end_time - train_start_time
    
    # Aggregate max training time across ranks
    train_wall_time_tensor = torch.tensor(train_wall_time, dtype=torch.float64).cuda(rank)
    dist.all_reduce(train_wall_time_tensor, op=dist.ReduceOp.MAX)
    train_wall_time_max = train_wall_time_tensor.item()
    
    # Log training time only from rank 0
    if rank == 0:
        finetune_time[0] = train_wall_time_max
        logger.info(f"Training wall time (max across ranks): {train_wall_time_max:.2f} seconds")
    
    model.eval()
    for dataset_name, loader in [("IMDB Test", test_loader), ("SST-2 Test", sst2_test_loader)]:
        correct = total = 0
        with torch.no_grad():
            for batch in loader:
                input_ids = batch["input_ids"].to(rank)
                attention_mask = batch["attention_mask"].to(rank)
                labels = batch["labels"].to(rank)
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                predictions = torch.argmax(outputs.logits, dim=-1)
                correct += (predictions == labels).sum().item()
                total += labels.size(0)
        logger.info(f"GPU[{rank}]: {dataset_name} Accuracy: {correct / total:.4f}")
    
    dist.destroy_process_group()

# Main
# if __name__ == "__main__":
NUM_CPUs = 28
NUM_GPUs = 1
logger.info("Initializing Spark...")
# os.environ['PYSPARK_PYTHON'] = '/home/goodh/miniconda3/envs/5003/bin/python'
# os.environ['PYSPARK_DRIVER_PYTHON'] = '/home/goodh/miniconda3/envs/5003/bin/python'
spark = init_spark(NUM_CPUs)

# Output directory for Parquet files
output_dir = "processed_data"
os.makedirs(output_dir, exist_ok=True)

# Check for cached Parquet files
# cached_data = check_cached_parquet(output_dir)
cached_data = None
train_path = test_path = sst2_test_path = train_collection = test_collection = sst2_collection = None
preprocess_time = 0

if cached_data:
    logger.info("Cached Parquet files found. Skipping data loading and preprocessing...")
    train_path, test_path, sst2_test_path, train_collection, test_collection, sst2_collection = cached_data
else:
    # Load and preprocess data
    logger.info("No cached Parquet files found. Running full pipeline...")
    logger.info("Loading data to MongoDB...")
    load_data_time = load_data_to_mongodb(spark)
    logger.info(f"Data loading to MongoDB took {load_data_time:.2f} seconds")
    
    logger.info("Distributed preprocessing and saving to Parquet...")
    train_path, test_path, sst2_test_path, train_collection, test_collection, sst2_collection, preprocess_time = preprocess_data(spark, output_dir)
    logger.info(f"Distributed preprocessing took {preprocess_time:.2f} seconds")

# Run distributed training
world_size = NUM_GPUs if NUM_GPUs else max(1, torch.cuda.device_count())
logger.info(f"Using {world_size} GPU(s)")

logger.info("Distributed fine-tuning...")
import torch.multiprocessing as mp
finetune_time = torch.zeros(world_size, dtype=torch.float32).share_memory_()
mp.spawn(
    train_and_evaluate,
    args=(world_size, train_path, test_path, sst2_test_path, train_collection, test_collection, sst2_collection, finetune_time),
    nprocs=world_size,
    join=True
)

# append results
result = f"{time.strftime('%Y/%m/%d-%H:%M:%S')}\t{NUM_CPUs}\t\t{NUM_GPUs}\t\t{preprocess_time:.2f}\t\t{finetune_time[0]:.2f}\n"
logger.info(result)
with open("out/results.out", "a") as f:
    f.write(result)

spark.stop()

INFO:__main__:Initializing Spark...
INFO:__main__:28 cores for spark
25/04/22 22:38:20 WARN Utils: Your hostname, yPC resolves to a loopback address: 127.0.1.1; using 10.255.255.254 instead (on interface lo)
25/04/22 22:38:21 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address


:: loading settings :: url = jar:file:/home/goodh/miniconda3/envs/5003/lib/python3.9/site-packages/pyspark/jars/ivy-2.5.1.jar!/org/apache/ivy/core/settings/ivysettings.xml


Ivy Default Cache set to: /home/goodh/.ivy2/cache
The jars for the packages stored in: /home/goodh/.ivy2/jars
org.mongodb.spark#mongo-spark-connector_2.12 added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-286c017c-2100-4589-b54a-b6d869fae13a;1.0
	confs: [default]
	found org.mongodb.spark#mongo-spark-connector_2.12;3.0.1 in central
	found org.mongodb#mongodb-driver-sync;4.0.5 in central
	found org.mongodb#bson;4.0.5 in central
	found org.mongodb#mongodb-driver-core;4.0.5 in central
:: resolution report :: resolve 115ms :: artifacts dl 3ms
	:: modules in use:
	org.mongodb#bson;4.0.5 from central in [default]
	org.mongodb#mongodb-driver-core;4.0.5 from central in [default]
	org.mongodb#mongodb-driver-sync;4.0.5 from central in [default]
	org.mongodb.spark#mongo-spark-connector_2.12;3.0.1 from central in [default]
	---------------------------------------------------------------------
	|                  |            modules            ||   artifacts   

Py4JError: An error occurred while calling o131.parquet