<a href="https://colab.research.google.com/github/xutian1113/pytorch_practice/blob/main/glue_CoLA_practice.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# !pip install datasets
# !pip install tqdm

## CoLA Description
CoLA (Corpus of Linguistic Acceptability) is a dataset for evaluating a model's ability to judge whether a given sentence is grammatically acceptable or unacceptable.
- Task Type: binary classification (Acceptable or Unacceptable)
- Goal: Predict whether a given sentence is grammatically acceptable (1) or unacceptable (0).
- Data Source: Sentences are taken from published linguistic literature.
- Labels:
  - 1 (Acceptable Sentence) → The sentence follows standard English grammar.
  - 0 (Unacceptable Sentence) → The sentence is grammatically incorrect.
  
### CoLA Dataset Statistics
| **Split**      | **Number of Samples** |
|---------------|----------------------|
| **Train**     | 8,551                |
| **Validation** | 1,043                |
| **Test**      | 1,063 (labels not public) |



```
DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 8551
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1043
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1063
    })
})
```

### dataset["train"][0]
```
{'sentence': "Our friends won't buy this analysis, let alone the next one we propose.",
 'label': 1,
 'idx': 0}
```

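### dataset statistics
Label counts per split, followed by the fraction of acceptable (label 1) sentences:
- train: ({0: 2528, 1: 6023}, 0.704362062916618)
- val: ({0: 322, 1: 721}, 0.6912751677852349)
- test: unavailable (labels are not public)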
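
The statistics above can be reproduced with a few lines of plain Python. The sketch below rebuilds the training labels from the reported counts (with the real data the list would come from something like `dataset["train"]["label"]`) and also scores a majority-class baseline. The helper names `label_stats` and `mcc_from_counts` are illustrative and are not part of the notebook. Because roughly 70% of the sentences are acceptable, always predicting 1 already reaches about 0.70 accuracy while its MCC is 0, which is one reason GLUE reports MCC for CoLA.

```python
from collections import Counter
from math import sqrt

def label_stats(labels):
    """Return (per-class counts, fraction of positive labels)."""
    counts = Counter(labels)
    return dict(sorted(counts.items())), counts[1] / len(labels)

def mcc_from_counts(tp, tn, fp, fn):
    """Matthews correlation coefficient from confusion-matrix counts."""
    denom = sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    # By convention the score is 0 when any marginal count is zero.
    return (tp * tn - fp * fn) / denom if denom > 0 else 0.0

# Label list rebuilt from the reported counts (assumption: stands in for the real split).
train_labels = [0] * 2528 + [1] * 6023
counts, pos_frac = label_stats(train_labels)
print(counts, pos_frac)

# Majority-class baseline: always predict 1 (acceptable).
tp, fn = counts[1], 0
fp, tn = counts[0], 0
baseline_acc = (tp + tn) / len(train_labels)
baseline_mcc = mcc_from_counts(tp, tn, fp, fn)
print(f"baseline accuracy={baseline_acc:.4f}, MCC={baseline_mcc:.4f}")
```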


In [2]:
import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, get_linear_schedule_with_warmup
from datasets import load_dataset
import numpy as np
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import matthews_corrcoef
from tqdm import tqdm


In [3]:
def accuracy(preds, labels):
  _, predicted = torch.max(preds, 1)
  correct = (predicted == labels).sum().item()
  return correct / len(labels)

def f1_score(preds, labels):
  _, predicted = torch.max(preds, 1)
  tp = ((predicted == 1) & (labels == 1)).sum().item()
  fp = ((predicted == 1) & (labels == 0)).sum().item()
  fn = ((predicted == 0) & (labels == 1)).sum().item()
  precision = tp / (tp + fp) if tp + fp > 0 else 0
  recall = tp / (tp + fn) if tp + fn > 0 else 0
  f1 = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0
  return f1

def precision(preds, labels):
  _, predicted = torch.max(preds, 1)
  tp = ((predicted == 1) & (labels == 1)).sum().item()
  fp = ((predicted == 1) & (labels == 0)).sum().item()
  precision = tp / (tp + fp) if tp + fp > 0 else 0
  return precision


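def auc_roc(preds, labels):
  # preds: model logits of shape (N, 2); labels: true labels (0 or 1)
  # Note: the score passed to roc_auc_score is the argmax class (a hard 0/1 label),
  # not a probability, so the result equals balanced accuracy.
  _, predicted = torch.max(preds, 1)
  preds_np = predicted.detach().cpu().numpy()
  labels_np = labels.detach().cpu().numpy()

  # Calculate AUC-ROC using scikit-learn's roc_auc_score
  auc_score = roc_auc_score(labels_np, preds_np)

  return auc_score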


def auc_pr(preds, labels):
  """
  Calculates the AUPRC score for binary classification.

  Args:
    preds: Model logits (tensor) of shape (N, 2); the argmax class is used as the score.
    labels: True labels (tensor) (0 or 1).

  Returns:
    AUPRC score (float).
  """
  # Convert predictions and labels to NumPy arrays
  _, predicted = torch.max(preds, 1)
  preds_np = predicted.detach().cpu().numpy()
  labels_np = labels.detach().cpu().numpy()

  # Calculate AUPRC using scikit-learn's average_precision_score
  auprc_score = average_precision_score(labels_np, preds_np)

  return auprc_score

def confusion_matrix_func(preds, labels, num_classes=2):
  """
  Calculates the confusion matrix.

  Args:
    preds: Model logits (tensor) of shape (N, num_classes).
    labels: True labels (tensor).
    num_classes: Number of classes. Defaults to 2 for binary classification.

  Returns:
    Confusion matrix (NumPy array).
  """
  # Get predicted class indices
  _, predicted = torch.max(preds, 1)

  # Convert predictions and labels to NumPy arrays
  predicted_np = predicted.detach().cpu().numpy()
  labels_np = labels.detach().cpu().numpy()

  # Calculate confusion matrix using scikit-learn's confusion_matrix
  cm = confusion_matrix(labels_np, predicted_np, labels=range(num_classes))

  return cm

def recall(preds, labels):
  """
  Calculates the recall score for binary classification.

  Args:
    preds: Model logits (tensor) of shape (N, 2).
    labels: True labels (tensor).

  Returns:
    Recall score (float).
  """
  # Get predicted class indices
  _, predicted = torch.max(preds, 1)

  # Calculate true positives (TP), false negatives (FN)
  tp = ((predicted == 1) & (labels == 1)).sum().item()
  fn = ((predicted == 0) & (labels == 1)).sum().item()

  # Calculate recall
  recall_score = tp / (tp + fn) if tp + fn > 0 else 0  # Avoid division by zero

  return recall_score

def MCC(preds, labels):
  """
  Calculates the Matthews correlation coefficient (MCC) for binary classification.

  Args:
    preds: Model logits (tensor) of shape (N, 2).
    labels: True labels (tensor).

  Returns:
    MCC score (float).
  """
  # Get predicted class indices
  _, predicted = torch.max(preds, 1)

  # Convert predictions and labels to NumPy arrays
  predicted_np = predicted.detach().cpu().numpy()
  labels_np = labels.detach().cpu().numpy()

  # Calculate MCC using scikit-learn's matthews_corrcoef (y_true first, then y_pred)
  mcc_score = matthews_corrcoef(labels_np, predicted_np)
  return mcc_score


In [4]:
# 1. load the GLUE CoLA dataset
dataset = load_dataset('glue', 'cola')

# 2. load the tokenizer for bert-base-cased
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
# converting raw text into numerical representations that can be processed by BERT model.

# 3. define a tokenization function for a single sentence

def tokenize_function(example):
    return tokenizer(example["sentence"], truncation=True, padding="max_length", max_length=128)

'''
input_ids: Token IDs of the sentence, padded or truncated to max_length.
token_type_ids: Segment IDs; all 0 here because each example is a single sentence.
attention_mask: Marks real tokens (1) versus padding (0).
'''
# 4. Tokenize the dataset (batched processing)
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# 5. Remove unnecessary columns and set the format to PyTorch tensors
# Remove columns that are not used by the model (like the original sentences and index)
tokenized_datasets = tokenized_datasets.remove_columns(["sentence", "idx"])
tokenized_datasets.set_format("torch")

# 6. Create DataLoaders for training and validation
train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, batch_size=8)
val_dataloader = DataLoader(tokenized_datasets["validation"], batch_size=8)


# 7. Set up the device, model, optimizer, and scheduler
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BertForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)
model.to(device)

# optimizer = AdamW(model.parameters(), lr=2e-5)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
epochs = 20
total_steps = len(train_dataloader) * epochs

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
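
The scheduler configured above scales the learning rate linearly: up from 0 during warmup (none here, since `num_warmup_steps=0`), then down to 0 at `num_training_steps`. The sketch below mirrors that rule in plain Python so the learning rate at any step can be inspected without building a model. `linear_schedule_lr` is an illustrative helper, not a `transformers` function, and the step count assumes 1069 batches per epoch (8,551 training examples at batch size 8, rounded up).

```python
def linear_schedule_lr(step, base_lr, num_warmup_steps, num_training_steps):
    """Learning rate after `step` optimizer updates under linear warmup and linear decay."""
    if step < num_warmup_steps:
        factor = step / max(1, num_warmup_steps)
    else:
        factor = max(
            0.0,
            (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps),
        )
    return base_lr * factor

steps_per_epoch = 1069              # assumption: ceil(8551 / 8)
total_steps = steps_per_epoch * 20  # epochs = 20 as configured above

for step in (0, total_steps // 2, total_steps):
    print(step, linear_schedule_lr(step, 2e-5, 0, total_steps))
```

One consequence of tying `total_steps` to 20 epochs is that the learning rate is still close to its initial value in the first few epochs; stopping early leaves most of the decay unused.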


In [5]:
best_val_loss = float('inf')
best_model_state = None
best_epoch = 0
epochs_no_improve = 0
patience = 3

# 8. Training loop
model.train()
for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}")
    total_loss = 0
    all_preds = []
    all_labels = []
    for batch in tqdm(train_dataloader, desc=f"Training Epoch {epoch + 1}"):
        # Move batch to the device
        batch = {key: value.to(device) for key, value in batch.items()}
        batch['labels'] = batch['label']
        del batch['label']
        # Forward pass
        outputs = model(**batch)
        loss = outputs.loss
        total_loss += loss.item()

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        all_preds.append(outputs.logits.detach())
        all_labels.append(batch['labels'].detach())

    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    train_acc = accuracy(all_preds, all_labels)
    train_auc = auc_roc(all_preds, all_labels)
    train_rec = recall(all_preds, all_labels)
    train_avg_loss = total_loss / len(train_dataloader)
    train_mcc = MCC(all_preds, all_labels)
    print(f"Epoch {epoch+1}/{epochs} - Loss: {train_avg_loss:.4f}, Accuracy: {train_acc:.4f}, AUC-ROC: {train_auc:.4f}, Recall: {train_rec:.4f}, MCC: {train_mcc:.4f}")

    # 9. Validation loop
    model.eval()
    val_preds = []
    val_labels = []
    val_loss = 0
    for batch in val_dataloader:
        batch = {key: value.to(device) for key, value in batch.items()}
        batch['labels'] = batch['label']
        del batch['label']
        with torch.no_grad():
            outputs = model(**batch)
            val_loss += outputs.loss.item()
        logits = outputs.logits
        val_preds.append(logits.detach())
        val_labels.append(batch['labels'].detach())

    val_preds = torch.cat(val_preds)
    val_labels = torch.cat(val_labels)
    val_acc = accuracy(val_preds, val_labels)
    val_auc = auc_roc(val_preds, val_labels)
    val_rec = recall(val_preds, val_labels)
    val_avg_loss = val_loss / len(val_dataloader)
    val_mcc = MCC(val_preds, val_labels)
    if val_avg_loss < best_val_loss:
        best_val_loss = val_avg_loss
        # Clone the tensors: state_dict() returns references that later training steps overwrite
        best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        best_epoch = epoch
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        # if epochs_no_improve >= patience:
        #     print(f"Early stopping at epoch {epoch+1} due to no improvement in validation loss for {patience} epochs.")
        #     break
    print(f"Validation - Loss: {val_avg_loss:.4f},  Accuracy: {val_acc:.4f}, AUC-ROC: {val_auc:.4f}, Recall: {val_rec:.4f}, MCC: {val_mcc:.4f}")


    model.train()  # Switch back to training mode for the next epoch


Epoch 1/20


Training Epoch 1: 100%|██████████| 1069/1069 [00:51<00:00, 20.67it/s]


Epoch 1/20 - Loss: 0.4989, Accuracy: 0.7659, AUC-ROC: 0.6699, Recall: 0.9047, MCC: 0.3908
Validation - Loss: 0.4441,  Accuracy: 0.8006, AUC-ROC: 0.7028, Recall: 0.9584, MCC: 0.5026
Epoch 2/20


Training Epoch 2: 100%|██████████| 1069/1069 [00:51<00:00, 20.88it/s]


Epoch 2/20 - Loss: 0.2825, Accuracy: 0.8862, AUC-ROC: 0.8553, Recall: 0.9309, MCC: 0.7229
Validation - Loss: 0.3951,  Accuracy: 0.8265, AUC-ROC: 0.7894, Recall: 0.8863, MCC: 0.5878
Epoch 3/20


Training Epoch 3: 100%|██████████| 1069/1069 [00:51<00:00, 20.87it/s]


Epoch 3/20 - Loss: 0.1593, Accuracy: 0.9415, AUC-ROC: 0.9267, Recall: 0.9630, MCC: 0.8588
Validation - Loss: 0.4973,  Accuracy: 0.8236, AUC-ROC: 0.7701, Recall: 0.9098, MCC: 0.5712
Epoch 4/20


Training Epoch 4: 100%|██████████| 1069/1069 [00:51<00:00, 20.89it/s]


Epoch 4/20 - Loss: 0.0924, Accuracy: 0.9678, AUC-ROC: 0.9610, Recall: 0.9778, MCC: 0.9227
Validation - Loss: 0.6177,  Accuracy: 0.8207, AUC-ROC: 0.7603, Recall: 0.9182, MCC: 0.5611
Epoch 5/20


Training Epoch 5: 100%|██████████| 1069/1069 [00:51<00:00, 20.88it/s]


Epoch 5/20 - Loss: 0.0740, Accuracy: 0.9757, AUC-ROC: 0.9706, Recall: 0.9831, MCC: 0.9416
Validation - Loss: 0.6008,  Accuracy: 0.8025, AUC-ROC: 0.7514, Recall: 0.8849, MCC: 0.5231
Epoch 6/20


Training Epoch 6: 100%|██████████| 1069/1069 [00:51<00:00, 20.88it/s]


Epoch 6/20 - Loss: 0.0414, Accuracy: 0.9856, AUC-ROC: 0.9831, Recall: 0.9892, MCC: 0.9655
Validation - Loss: 0.9687,  Accuracy: 0.8054, AUC-ROC: 0.7295, Recall: 0.9279, MCC: 0.5171
Epoch 7/20


Training Epoch 7: 100%|██████████| 1069/1069 [00:51<00:00, 20.88it/s]


Epoch 7/20 - Loss: 0.0420, Accuracy: 0.9863, AUC-ROC: 0.9836, Recall: 0.9902, MCC: 0.9672
Validation - Loss: 0.9233,  Accuracy: 0.8111, AUC-ROC: 0.7491, Recall: 0.9112, MCC: 0.5369
Epoch 8/20


Training Epoch 8: 100%|██████████| 1069/1069 [00:51<00:00, 20.88it/s]


Epoch 8/20 - Loss: 0.0303, Accuracy: 0.9898, AUC-ROC: 0.9872, Recall: 0.9937, MCC: 0.9755
Validation - Loss: 0.9443,  Accuracy: 0.8198, AUC-ROC: 0.7485, Recall: 0.9348, MCC: 0.5554
Epoch 9/20


Training Epoch 9: 100%|██████████| 1069/1069 [00:51<00:00, 20.89it/s]


Epoch 9/20 - Loss: 0.0289, Accuracy: 0.9891, AUC-ROC: 0.9869, Recall: 0.9924, MCC: 0.9739
Validation - Loss: 0.9854,  Accuracy: 0.8063, AUC-ROC: 0.7379, Recall: 0.9168, MCC: 0.5222
Epoch 10/20


Training Epoch 10: 100%|██████████| 1069/1069 [00:51<00:00, 20.87it/s]


Epoch 10/20 - Loss: 0.0204, Accuracy: 0.9917, AUC-ROC: 0.9897, Recall: 0.9945, MCC: 0.9801
Validation - Loss: 1.2007,  Accuracy: 0.8111, AUC-ROC: 0.7396, Recall: 0.9265, MCC: 0.5332
Epoch 11/20


Training Epoch 11:  29%|██▊       | 306/1069 [00:14<00:36, 20.86it/s]


KeyboardInterrupt: 
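
In the log above, validation loss is lowest at epoch 2 and rises afterwards while training loss keeps falling, which is the pattern the commented-out early-stopping check in the training loop is meant to catch. As a minimal sketch, the same patience rule can be replayed on the logged validation losses. `early_stopping_epoch` is an illustrative helper and reports 1-indexed epochs, unlike `best_epoch` in the loop, which is 0-indexed.

```python
def early_stopping_epoch(val_losses, patience):
    """Return (stop_epoch, best_epoch), both 1-indexed; stop_epoch is None if never triggered."""
    best_loss = float("inf")
    best_epoch = 0
    epochs_no_improve = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch, epochs_no_improve = loss, epoch, 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                return epoch, best_epoch
    return None, best_epoch

# Validation losses for epochs 1 to 10, copied from the log above.
val_losses = [0.4441, 0.3951, 0.4973, 0.6177, 0.6008,
              0.9687, 0.9233, 0.9443, 0.9854, 1.2007]

stop_epoch, best_epoch = early_stopping_epoch(val_losses, patience=3)
print(f"best epoch: {best_epoch}, stop after epoch: {stop_epoch}")
```

With `patience = 3`, the rule would have stopped after epoch 5 and kept the epoch 2 weights.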

In [None]:
# The GLUE test split ships with placeholder labels (-1), so loss and metrics
# cannot be computed locally; this cell only collects predictions.
test_dataloader = DataLoader(tokenized_datasets["test"], batch_size=8)
model.eval()
test_preds = []

for batch in test_dataloader:
    batch = {key: value.to(device) for key, value in batch.items()}
    del batch['label']  # placeholder labels would break the loss computation
    with torch.no_grad():
        outputs = model(**batch)
    test_preds.append(outputs.logits.detach())

test_preds = torch.cat(test_preds)
test_pred_labels = test_preds.argmax(dim=1)

n_acceptable = (test_pred_labels == 1).sum().item()
n_unacceptable = (test_pred_labels == 0).sum().item()
print(f"Test - predicted acceptable: {n_acceptable}, predicted unacceptable: {n_unacceptable}")