<a href="https://colab.research.google.com/github/xutian1113/pytorch_practice/blob/main/glue_SST_2_practice.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install datasets
!pip install tqdm

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m9.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (1

## SST-2 Description
The Stanford Sentiment Treebank (SST-2) is a binary sentiment classification dataset that is part of the GLUE (General Language Understanding Evaluation) benchmark. It is derived from the Stanford Sentiment Treebank (SST), which originally contained fine-grained sentiment labels (very negative, negative, neutral, positive, very positive). However, in SST-2, the task is simplified into binary classification:

- Label 0 → Negative Sentiment
- Label 1 → Positive Sentiment

### Key Features of SST-2
- Task Type: Binary Sentiment Classification
- Input Type:	Single Sentence
- Output Labels:	0 (Negative), 1 (Positive)
- Dataset Size:	67,349 (Train), 872 (Validation), 1,821 (Test)


| **sentences**      | **label (sentiment)** |
|-------------------------------------|----------------------|
|This movie was amazing! I loved it.|	 1 (Positive)|
|A complete waste of time, very disappointing.	| 0 (Negative)|
|It had great visuals but lacked a solid story.	| 1 (Positive)|
|The worst film I have ever seen.	| 0 (Negative) |



### SST-2 Dataset Statistics
| **Split**      | **Number of Samples** |
|---------------|----------------------|
| **Train**     | 67,349               |
| **Validation** | 872                |
| **Test**      | 1,821 (Labels not public) |



```
DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'idx'],
        num_rows: 1821
    })
})

```

### dataset["train"][0]
```
{'sentence': 'hide new secretions from the parental units ',
 'label': 0,
 'idx': 0}
```

### dataset statistics


In [2]:
import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, AdamW, get_linear_schedule_with_warmup
from datasets import load_dataset
import numpy as np
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import matthews_corrcoef
from tqdm import tqdm


In [3]:
def accuracy(preds, labels):
  """Fraction of examples whose highest-scoring class matches the label.

  Args:
    preds: 2-D tensor of per-class scores; the predicted class is the argmax
      along dim 1.
    labels: Tensor of true class indices, one per row of `preds`.

  Returns:
    Accuracy as a float. Returns 0.0 for an empty batch instead of raising
    ZeroDivisionError, matching the zero-denominator handling of the other
    metric helpers in this notebook.
  """
  total = len(labels)
  if total == 0:
    return 0.0
  _, predicted = torch.max(preds, 1)
  correct = (predicted == labels).sum().item()
  return correct / total

def f1_score(preds, labels):
  """F1 score of the positive class (label 1).

  Args:
    preds: 2-D tensor of per-class scores; the predicted class is the argmax
      along dim 1.
    labels: Tensor of true labels (0 or 1).

  Returns:
    Harmonic mean of precision and recall, or 0 when it is undefined.
  """
  predicted = preds.max(dim=1).indices
  predicted_pos = predicted == 1
  actual_pos = labels == 1

  true_pos = (predicted_pos & actual_pos).sum().item()
  false_pos = (predicted_pos & (labels == 0)).sum().item()
  false_neg = ((predicted == 0) & actual_pos).sum().item()

  # Each ratio falls back to 0 when its denominator is empty.
  prec = 0
  if true_pos + false_pos > 0:
    prec = true_pos / (true_pos + false_pos)

  rec = 0
  if true_pos + false_neg > 0:
    rec = true_pos / (true_pos + false_neg)

  if prec + rec > 0:
    return 2 * (prec * rec) / (prec + rec)
  return 0

def precision(preds, labels):
  """Precision of the positive class (label 1).

  Args:
    preds: 2-D tensor of per-class scores; the predicted class is the argmax
      along dim 1.
    labels: Tensor of true labels (0 or 1).

  Returns:
    TP / (TP + FP), or 0 when nothing was predicted positive.
  """
  predicted_pos = preds.max(dim=1).indices == 1
  true_pos = (predicted_pos & (labels == 1)).sum().item()
  false_pos = (predicted_pos & (labels == 0)).sum().item()

  flagged = true_pos + false_pos
  if flagged > 0:
    return true_pos / flagged
  return 0


def auc_roc(preds, labels):
  """Area under the ROC curve for binary classification.

  The curve is built from the softmax probability of class 1. Thresholded
  0/1 predictions (the argmax) would collapse the curve to a single
  operating point and no longer measure how well the model ranks examples.

  Args:
    preds: 2-D tensor of per-class logits; column 1 is taken as the
      positive class.
    labels: Tensor of true labels (0 or 1).

  Returns:
    The value returned by scikit-learn's `roc_auc_score`.
  """
  # Continuous positive-class score; detached and cast to float32 before
  # conversion to NumPy.
  positive_scores = torch.softmax(preds.detach().float(), dim=1)[:, 1]
  scores_np = positive_scores.cpu().numpy()
  labels_np = labels.detach().cpu().numpy()

  return roc_auc_score(labels_np, scores_np)


def auc_pr(preds, labels):
  """
  Calculates the AUPRC (average precision) score for binary classification.

  The score is computed from the softmax probability of class 1. Hard 0/1
  predictions (the argmax) would reduce the precision-recall curve to a
  single threshold.

  Args:
    preds: 2-D tensor of per-class logits; column 1 is taken as the
      positive class.
    labels: True labels (tensor) (0 or 1).

  Returns:
    AUPRC score as returned by scikit-learn's `average_precision_score`.
  """
  # Continuous positive-class score; detached and cast to float32 before
  # conversion to NumPy.
  positive_scores = torch.softmax(preds.detach().float(), dim=1)[:, 1]
  scores_np = positive_scores.cpu().numpy()
  labels_np = labels.detach().cpu().numpy()

  return average_precision_score(labels_np, scores_np)

def confusion_matrix_func(preds, labels, num_classes=2):
  """
  Builds the confusion matrix of argmax predictions against the true labels.

  Args:
    preds: 2-D tensor of per-class scores; the predicted class is the argmax
      along dim 1.
    labels: True labels (tensor).
    num_classes: Number of classes; the matrix is indexed by
      `range(num_classes)`. Defaults to 2 for binary classification.

  Returns:
    Confusion matrix as returned by scikit-learn's `confusion_matrix`,
    called with the true labels first and the predictions second.
  """
  predicted_classes = preds.max(dim=1).indices

  # scikit-learn works on host-side NumPy arrays.
  y_pred = predicted_classes.detach().cpu().numpy()
  y_true = labels.detach().cpu().numpy()

  return confusion_matrix(y_true, y_pred, labels=range(num_classes))

def recall(preds, labels):
  """
  Recall of the positive class (label 1).

  Args:
    preds: 2-D tensor of per-class scores; the predicted class is the argmax
      along dim 1.
    labels: True labels (tensor) (0 or 1).

  Returns:
    TP / (TP + FN), or 0 when there are no positive labels.
  """
  predicted = preds.max(dim=1).indices
  actual_pos = labels == 1

  hits = ((predicted == 1) & actual_pos).sum().item()
  misses = ((predicted == 0) & actual_pos).sum().item()

  # Guard against an empty positive class.
  if hits + misses > 0:
    return hits / (hits + misses)
  return 0

def MCC(preds, labels):
  """
  Calculates the Matthews correlation coefficient (MCC) for binary classification.

  Args:
    preds: 2-D tensor of per-class scores; the predicted class is the argmax
      along dim 1.
    labels: True labels (tensor).

  Returns:
    MCC score as returned by scikit-learn's `matthews_corrcoef`.
  """
  # Get predicted class indices
  _, predicted = torch.max(preds, 1)

  # Convert predictions and labels to NumPy arrays
  predicted_np = predicted.detach().cpu().numpy()
  labels_np = labels.detach().cpu().numpy()

  # scikit-learn's signature is (y_true, y_pred); pass the arguments in that
  # order so the call reads correctly and matches the other sklearn-based
  # helpers in this notebook.
  mcc_score = matthews_corrcoef(labels_np, predicted_np)
  return mcc_score


In [9]:
# 1. Load the GLUE SST-2 dataset
dataset = load_dataset('glue', 'sst2')

# 2. Load the tokenizer for bert-base-cased
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
# The tokenizer converts raw text into the numerical representations that the BERT model processes.
# 3. Define the tokenization function for the single-sentence input

def tokenize_function(example):
    """Tokenize the "sentence" field, truncating and padding every sequence to 128 tokens."""
    sentences = example["sentence"]
    return tokenizer(
        sentences,
        truncation=True,
        padding="max_length",
        max_length=128,
    )

# Fields produced by the tokenizer:
#   input_ids: numerical token ids of the sentence.
#   token_type_ids: segment ids (SST-2 is a single-sentence task, so there is
#                   no second segment to distinguish).
#   attention_mask: marks which tokens are real (1) and which are padding (0).

# 4. Tokenize the dataset (batched processing)
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# 5. Remove unnecessary columns and set the format to PyTorch tensors
# Remove columns that are not used by the model (the original sentence and the index)
tokenized_datasets = tokenized_datasets.remove_columns(["sentence", "idx"])
tokenized_datasets.set_format("torch")

# 6. Create DataLoaders for training and validation
train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, batch_size=8)
val_dataloader = DataLoader(tokenized_datasets["validation"], batch_size=8)


# 7. Set up the device, model, optimizer, and scheduler
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BertForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
epochs = 5
# The scheduler is stepped once per batch, so it needs the total number of optimizer steps.
total_steps = len(train_dataloader) * epochs

# Linear decay of the learning rate over all training steps, with no warmup.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# 8. Training loop with a validation pass after every epoch.
# Tracks the checkpoint with the lowest average validation loss.
best_val_loss = float('inf')
best_model_state = None
best_epoch = 0
epochs_no_improve = 0
patience = 3

model.train()
for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}")
    total_loss = 0
    all_preds = []
    all_labels = []
    for batch in tqdm(train_dataloader, desc=f"Training Epoch {epoch + 1}"):
        # Move batch to the device
        batch = {key: value.to(device) for key, value in batch.items()}
        # The model expects the target under the key 'labels', the dataset provides 'label'.
        batch['labels'] = batch['label']
        del batch['label']
        # Forward pass
        outputs = model(**batch)
        loss = outputs.loss
        total_loss += loss.item()

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        all_preds.append(outputs.logits.detach())
        all_labels.append(batch['labels'].detach())

    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    train_acc = accuracy(all_preds, all_labels)
    train_auc = auc_roc(all_preds, all_labels)
    train_rec = recall(all_preds, all_labels)
    train_avg_loss = total_loss / len(train_dataloader)
    train_mcc = MCC(all_preds, all_labels)
    print(f"Epoch {epoch+1}/{epochs} - Loss: {train_avg_loss:.4f}, Accuracy: {train_acc:.4f}, AUC-ROC: {train_auc:.4f}, Recall: {train_rec:.4f}, MCC: {train_mcc:.4f}")

    # 9. Validation loop
    model.eval()
    val_preds = []
    val_labels = []
    val_loss = 0
    for batch in val_dataloader:
        batch = {key: value.to(device) for key, value in batch.items()}
        batch['labels'] = batch['label']
        del batch['label']
        with torch.no_grad():
            outputs = model(**batch)
            val_loss += outputs.loss.item()
        logits = outputs.logits
        val_preds.append(logits.detach())
        val_labels.append(batch['labels'].detach())

    val_preds = torch.cat(val_preds)
    val_labels = torch.cat(val_labels)
    val_acc = accuracy(val_preds, val_labels)
    val_auc = auc_roc(val_preds, val_labels)
    val_rec = recall(val_preds, val_labels)
    val_avg_loss = val_loss / len(val_dataloader)
    val_mcc = MCC(val_preds, val_labels)
    if val_avg_loss < best_val_loss:
        best_val_loss = val_avg_loss
        # state_dict() hands back references to the live parameters, which later
        # optimizer steps would overwrite. Clone each tensor (onto the CPU) so the
        # snapshot really stays at this epoch's weights.
        best_model_state = {
            name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()
        }
        best_epoch = epoch
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        # Early stopping is left disabled; uncomment to stop after `patience`
        # epochs without improvement in validation loss.
        # if epochs_no_improve >= patience:
        #     print(f"Early stopping at epoch {epoch+1} due to no improvement in validation loss for {patience} epochs.")
        #     break
    print(f"Validation - Loss: {val_avg_loss:.4f},  Accuracy: {val_acc:.4f}, AUC-ROC: {val_auc:.4f}, Recall: {val_rec:.4f}, MCC: {val_mcc:.4f}")


    model.train()  # Switch back to training mode for the next epoch


Epoch 1/5


Training Epoch 1:  15%|█▌        | 1291/8419 [01:01<05:38, 21.07it/s]

In [None]:
# 10. Inference on the test split.
# The dataset statistics above note that the SST-2 test labels are not public.
# Labels are therefore kept out of the forward pass, and metrics are reported
# only if the split turns out to carry real 0/1 labels.
# NOTE(review): the hub copy presumably stores a placeholder label such as -1
# for the test split; verify against the dataset card.
test_dataloader = DataLoader(tokenized_datasets["test"], batch_size=8)

# Evaluate the checkpoint with the lowest validation loss, if training recorded one.
if best_model_state is not None:
    model.load_state_dict(best_model_state)

model.eval()
test_preds = []
test_labels = []

for batch in test_dataloader:
    batch = {key: value.to(device) for key, value in batch.items()}
    # Do not hand labels to the model: a missing column would raise KeyError and
    # placeholder labels are not valid targets for the classification loss.
    labels = batch.pop('label', None)
    with torch.no_grad():
        outputs = model(**batch)
    logits = outputs.logits
    test_preds.append(logits.detach())
    if labels is not None:
        test_labels.append(labels.detach())

test_preds = torch.cat(test_preds)
test_labels = torch.cat(test_labels) if test_labels else None

# Metrics are only meaningful when every label is a real class index (0 or 1).
has_real_labels = (
    test_labels is not None
    and bool(((test_labels == 0) | (test_labels == 1)).all())
)

if has_real_labels:
    test_acc = accuracy(test_preds, test_labels)
    test_auc = auc_roc(test_preds, test_labels)
    test_rec = recall(test_preds, test_labels)
    print(f"Test - Accuracy: {test_acc:.4f}, AUC-ROC: {test_auc:.4f}, Recall: {test_rec:.4f}")
else:
    test_acc = test_auc = test_rec = None
    predicted_classes = test_preds.argmax(dim=1)
    num_positive = int((predicted_classes == 1).sum().item())
    num_negative = int((predicted_classes == 0).sum().item())
    print(f"Test - labels not available; predicted {num_positive} positive / {num_negative} negative sentences")