In [None]:

import torch
import torch.nn as nn
from fairscale.nn import PipeModule
from transformers import AutoModelForSequenceClassification, AutoTokenizer



In [None]:
# Load the pretrained checkpoint and its matching tokenizer.
# The checkpoint name and seed are constants so they can be changed in one place.
# Seeding *before* `from_pretrained` makes the newly initialised classification head
# reproducible. DataLoader shuffling and dropout draw from the same global torch RNG,
# so they become reproducible across Restart & Run All as well.
MODEL_NAME = "bert-base-uncased"
SEED = 42

torch.manual_seed(SEED)
# MRPC (paraphrase vs. not paraphrase) is binary, so the classification head needs 2 outputs.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)



In [None]:
# Pipeline parallelism is intentionally not applied to this model.
# fairscale's pipeline wrapper (`fairscale.nn.Pipe`) expects an `nn.Sequential` that is
# split into stages via a `balance` list, and it feeds each stage positional tensor input.
# The training loop instead calls `model(**inputs)` with keyword arguments and reads
# `outputs.logits`, and a Hugging Face model object cannot be used that way once wrapped.
# NOTE(review): `PipeModule` does not appear to be exported by `fairscale.nn`, so the
# `from fairscale.nn import PipeModule` line in the imports cell is expected to fail.
# Remove that line.

# Move the model to the accelerator *before* building the optimizer, as the `torch.optim`
# docs recommend, so the optimizer is constructed over device-resident parameters.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Training hyperparameters.
training_args = {
    "batch_size": 32,
    "num_epochs": 10,
    "learning_rate": 5e-5,
}

# Optimizer over all model parameters, plus the classification loss applied to the logits.
optimizer = torch.optim.Adam(model.parameters(), lr=training_args["learning_rate"])
criterion = nn.CrossEntropyLoss()



In [None]:
# Load GLUE/MRPC and build the train/eval DataLoaders.
# NOTE(review): `load_dataset` is not imported anywhere in this notebook. It comes from the
# Hugging Face `datasets` package (`from datasets import load_dataset`) and should be added
# to the imports cell.
dataset = load_dataset("glue", "mrpc")
train_dataset = dataset["train"]
# Evaluate on the labelled validation split. GLUE test splits are distributed without gold
# labels (label == -1 in the Hugging Face glue dataset), which makes CrossEntropyLoss fail
# and accuracy meaningless.
eval_dataset = dataset["validation"]
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=training_args["batch_size"], shuffle=True
)
eval_loader = torch.utils.data.DataLoader(
    eval_dataset, batch_size=training_args["batch_size"], shuffle=False
)

# nn.DataParallel requires the wrapped module's parameters and buffers to be on
# device_ids[0] (the first visible GPU by default), so place the model there first.
# This is a no-op if the model is already there. By default `Module.to` updates
# parameters in place, so the optimizer built earlier still references them.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.DataParallel(model.to(device))



In [None]:

# Train your model, then report per-epoch train loss and validation loss/accuracy.


def _encode_batch(batch, device):
    """Tokenize one collated MRPC batch and move model inputs and labels to ``device``.

    Args:
        batch: dict produced by the DataLoader's collate function. ``"sentence1"`` and
            ``"sentence2"`` are indexed as text and ``"label"`` must be a tensor. This
            assumes the default collate, which yields lists of str and a label tensor;
            re-check if a custom ``collate_fn`` is added.
        device: the ``torch.device`` the model's parameters live on.

    Returns:
        ``(inputs, labels)``: the tokenizer output and the label tensor, both on ``device``.
    """
    inputs = tokenizer(
        batch["sentence1"],
        batch["sentence2"],
        return_tensors="pt",
        padding=True,
        truncation=True,  # never exceed the model's maximum sequence length
    )
    return inputs.to(device), batch["label"].to(device)


def _evaluate(model, loader, device):
    """Return ``(mean loss, accuracy)`` of ``model`` over *every* batch in ``loader``.

    Loss is weighted by batch size, so a short final batch is not over-counted.
    Returns ``(0.0, 0.0)`` for an empty loader instead of failing.
    """
    model.eval()
    total_loss, total_correct, total_seen = 0.0, 0, 0
    # No gradients are needed for evaluation; this avoids building the autograd graph.
    with torch.no_grad():
        for batch in loader:
            inputs, labels = _encode_batch(batch, device)
            logits = model(**inputs).logits
            total_loss += criterion(logits, labels).item() * labels.size(0)
            total_correct += (logits.argmax(dim=-1) == labels).sum().item()
            total_seen += labels.size(0)
    total_seen = max(total_seen, 1)  # guard the divisions against an empty loader
    return total_loss / total_seen, total_correct / total_seen


# Read the device from the parameters, because `nn.DataParallel` (unlike a bare Hugging
# Face model) has no `.device` attribute. This assumes every parameter is on the same
# device, which holds after a single `.to(...)` call.
device = next(model.parameters()).device

for epoch in range(training_args["num_epochs"]):
    model.train()
    train_loss_sum, train_seen = 0.0, 0
    for batch in train_loader:
        # The raw batch contains Python strings, which have no `.to()`. Tokenize first,
        # then move only the resulting tensors.
        inputs, labels = _encode_batch(batch, device)
        outputs = model(**inputs)
        loss = criterion(outputs.logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss_sum += loss.item() * labels.size(0)
        train_seen += labels.size(0)

    train_loss = train_loss_sum / max(train_seen, 1)
    eval_loss, eval_acc = _evaluate(model, eval_loader, device)
    print(
        f"Epoch {epoch + 1}, Train Loss: {train_loss:.4f}, "
        f"Eval Loss: {eval_loss:.4f}, Accuracy: {eval_acc:.4f}"
    )