## 4. Training function per worker
This function is executed by each worker during training. It handles data loading, tokenization, model initialization, and the training loop. It also selects the device automatically, preferring a CUDA GPU, then MPS (on Apple Silicon), then the CPU.

### Tokenizer
The tokenizer function converts text into input IDs and attention masks.

Padding (padding="max_length") and truncation (truncation=True) give every example the same length, so the examples can be stacked into fixed-size batches. The function is applied with the dataset's map method. The batched=True argument passes multiple examples to the function per call, which is more efficient than processing them one at a time.

The resulting dataset keeps the original columns and adds the tokenized inputs, in the numeric format the model expects.
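
To make the effect of padding and truncation concrete, here is a minimal pure-Python sketch. It does not call the Hugging Face tokenizer: toy_encode, the whitespace split, the tiny vocabulary, the PAD_ID and UNK_ID values, and max_length=6 are all hypothetical stand-ins, chosen only to show how every example ends up with the same length and how the attention mask separates real tokens from padding.

```python
from typing import Dict, List

PAD_ID = 0  # assumed padding id for this toy example
UNK_ID = 1  # assumed id for words missing from the toy vocabulary


def toy_encode(
    texts: List[str], vocab: Dict[str, int], max_length: int
) -> Dict[str, List[List[int]]]:
    """Encode a batch of texts into fixed-length ids and attention masks."""
    input_ids, attention_mask = [], []
    for text in texts:
        ids = [vocab.get(word, UNK_ID) for word in text.lower().split()]
        ids = ids[:max_length]  # truncation: drop tokens past max_length
        n_real = len(ids)
        n_pad = max_length - n_real
        input_ids.append(ids + [PAD_ID] * n_pad)  # padding: fill up to max_length
        attention_mask.append([1] * n_real + [0] * n_pad)  # 1 = real token, 0 = pad
    return {"input_ids": input_ids, "attention_mask": attention_mask}


# Hypothetical vocabulary and inputs, for illustration only
vocab = {"the": 2, "food": 3, "was": 4, "great": 5, "service": 6, "slow": 7}
batch = toy_encode(
    ["The food was great", "The service was slow and the food was cold"],
    vocab,
    max_length=6,
)
print(batch["input_ids"])       # [[2, 3, 4, 5, 0, 0], [2, 6, 4, 7, 1, 2]]
print(batch["attention_mask"])  # [[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1]]
```

The short sentence is padded with two zeros and the long one is cut to six tokens, so both rows can be stacked into one rectangular batch.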

### Dataloaders
Dataloaders load the dataset in batches for training and evaluation, which keeps memory use bounded on large datasets. The training DataLoader shuffles the data and collates it into batches.

The collate_fn is set to transformers.default_data_collator, which stacks the already-padded examples into tensors and renames the label column to labels. The padding itself is done by the tokenizer, not the collator. The batch_size is the per-worker batch size, read from config["batch_size_per_worker"]. Note that a plain DataLoader does not split the data across workers on its own: each worker iterates over the full subset unless a distributed sampler is added, for example with ray.train.torch.prepare_data_loader.
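
The idea of each worker handling its own portion of the data can be sketched without any framework. The helpers below are hypothetical (shard_indices, batches, world_size, and rank are illustrative names, and the strided split is just one common scheme); they are not taken from this notebook or from any library, and the sizes are made up.

```python
from typing import List


def shard_indices(num_examples: int, world_size: int, rank: int) -> List[int]:
    """Return the example indices assigned to one worker (strided split)."""
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    return list(range(rank, num_examples, world_size))


def batches(indices: List[int], batch_size: int) -> List[List[int]]:
    """Group indices into consecutive batches of at most batch_size."""
    return [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]


# Made-up sizes: 10 examples, 2 workers, 3 examples per worker per step
num_examples, world_size, batch_size_per_worker = 10, 2, 3
for rank in range(world_size):
    shard = shard_indices(num_examples, world_size, rank)
    print(rank, batches(shard, batch_size_per_worker))
# 0 [[0, 2, 4], [6, 8]]
# 1 [[1, 3, 5], [7, 9]]
```

With this kind of split, the shards are disjoint and the effective global batch size is batch_size_per_worker multiplied by the number of workers (6 in this toy setup).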

In [3]:
def train_func_per_worker(config: Dict):
    
    # Datasets
    dataset = load_dataset("yelp_review_full")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    
    # Tokenization function
    def tokenize_function(examples):
        """Tokenize a batch of examples.

        Converts text into input IDs and attention masks, padding and
        truncating every example to the model's maximum length so that
        all inputs have the same size.
        """
        return tokenizer(examples["text"], padding="max_length", truncation=True)

    lr = config["lr"]
    epochs = config["epochs"]
    batch_size = config["batch_size_per_worker"]

    # Select a small subset so the example runs quickly;
    # in a real-world scenario, you would use the entire dataset
    SMALL_SIZE = 100
    # map applies the tokenizer to the dataset; batched=True passes
    # multiple examples per call, which is more efficient
    train_dataset = dataset["train"].select(range(SMALL_SIZE)).map(tokenize_function, batched=True)
    eval_dataset = dataset["test"].select(range(SMALL_SIZE)).map(tokenize_function, batched=True)

    # Prepare the dataloaders for each worker, keyed by phase name
    # so the training loop can look them up per phase
    dataloaders = {}
    dataloaders["train"] = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=transformers.default_data_collator,
        batch_size=batch_size,
    )
    dataloaders["test"] = torch.utils.data.DataLoader(
        eval_dataset,
        shuffle=False,  # evaluation does not need shuffling
        collate_fn=transformers.default_data_collator,
        batch_size=batch_size,
    )

    # Obtain GPU device automatically
    # device = ray.train.torch.get_device()
    
    # Alternatively, you can specify the device manually
    # Check if CUDA or MPS is available and set device accordingly
    # This is useful for running on different hardware configurations
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps") # For Apple Silicon Macs
    else:
        device = torch.device("cpu")

    # Prepare model and optimizer
    # Load a pre-trained BERT model for sequence classification
    # The model is initialized with the number of labels for classification
    model = AutoModelForSequenceClassification.from_pretrained(
        "bert-base-cased", num_labels=5
    )
    # The model is moved to the selected device (GPU, MPS, or CPU)
    model = model.to(device)
    
    # SGD with momentum updates the model parameters during training
    # The learning rate comes from the config; momentum is fixed at 0.9
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    # Training loop: for each epoch, train on the training dataloader,
    # then run the model on the evaluation dataloader
    for epoch in range(epochs):
        # Each epoch has a training and validation phase
        for phase in ["train", "test"]:
            if phase == "train":
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            for batch in dataloaders[phase]:
                # Move the batch tensors to the selected device
                batch = {k: v.to(device) for k, v in batch.items()}

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward pass
                with torch.set_grad_enabled(phase == "train"):
                    # The batch includes labels, so the model returns
                    # both the loss and the logits (raw predictions)
                    outputs = model(**batch)
                    loss = outputs.loss

                    # backward + optimize only if in training phase
                    if phase == "train":
                        loss.backward()  # Backpropagate to compute gradients
                        optimizer.step()  # Update parameters from the gradients
                        print(f"train epoch:[{epoch}]\tloss:{loss.item():.6f}")
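
A loss printed for every batch can be noisy, and a per-phase average is often easier to compare across epochs. As a minimal sketch, assuming a hypothetical RunningAverage helper that is not part of this notebook, a sample-weighted mean of batch losses could be accumulated as follows. The (loss, batch size) pairs are made-up numbers for illustration only.

```python
class RunningAverage:
    """Sample-weighted running mean, e.g. of per-batch losses."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, value: float, n: int) -> None:
        """Add a batch-mean value computed over n samples."""
        self.total += value * n
        self.count += n

    @property
    def mean(self) -> float:
        """Mean over all samples seen so far (NaN if none)."""
        return self.total / self.count if self.count else float("nan")


# Made-up (batch loss, batch size) pairs; the last batch is smaller
eval_loss = RunningAverage()
for batch_loss, n in [(1.5, 4), (1.0, 4), (0.5, 2)]:
    eval_loss.update(batch_loss, n)
print(f"eval loss: {eval_loss.mean:.4f}")  # eval loss: 1.1000
```

Weighting by batch size keeps a smaller final batch from counting as much as a full one.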