In [None]:
# !pip install pytreebank
# !pip install loguru
# !pip install transformers

In [None]:
import os
import random
import pandas as pd
import numpy as np
import pytreebank
import torch
from loguru import logger
from torch.utils.data import Dataset
from loguru import logger
from transformers import DistilBertTokenizer, DistilBertConfig, DistilBertForSequenceClassification
from tqdm import tqdm
import matplotlib.pyplot as plt


torch.cuda.is_available()

In [None]:
"""This module defines a configurable SSTDataset class."""

# The tokenizer is created once at cell level; SSTDataset below reads this
# global when it encodes sentences.
logger.info("Loading the tokenizer")
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

# `sst` is looked up by split name inside SSTDataset.__init__ (`sst[split]`).
logger.info("Loading SST")
sst = pytreebank.load_sst()


def rpad(array, n=70):
    """Right-pad (or truncate) a list of token ids to exactly ``n`` entries.

    Args:
        array: list of token ids.
        n: length of the returned list.

    Returns:
        A new list of length ``n``: ``array`` followed by as many zeros as
        needed, or the first ``n`` entries of ``array`` when it is too long.
    """
    current_len = len(array)
    if current_len > n:
        # Slice to n (not n - 1) so truncated and padded sequences share one
        # length and can be stacked into a single batch tensor.
        return array[:n]
    return array + [0] * (n - current_len)


def get_binary_label(label):
    """Collapse a fine-grained label into a binary one.

    Labels below 2 map to 0, labels above 2 map to 1; anything else
    (the middle label 2) is rejected with a ValueError.
    """
    if not label < 2:
        if not label > 2:
            raise ValueError("Invalid label")
        return 1
    return 0


class SSTDataset(Dataset):
    """Configurable SST Dataset.

    Things we can configure:
        - split (train / dev / test)
        - root / all nodes
        - binary / fine-grained

    Every example is stored in ``self.data`` as a ``(token_ids, label)``
    pair, where ``token_ids`` is the ``rpad``-ed output of the module-level
    ``tokenizer``.
    """

    def __init__(self, split="train", root=True, binary=True):
        """Initializes the dataset with given configuration.

        Args:
            split: str
                Dataset split; used as a key into the module-level ``sst``
                mapping (the training code uses train, dev and test).
            root: bool
                If true, only use root nodes. Else, use all nodes.
            binary: bool
                If true, use binary labels and skip examples whose label
                is 2. Else, use fine-grained labels.
        """
        logger.info(f"Loading SST {split} set")
        self.sst = sst[split]

        logger.info("Tokenizing")
        self.data = [
            (self._encode(line), get_binary_label(label) if binary else label)
            for label, line in self._labeled_lines(root)
            # Label 2 has no binary counterpart, so it is dropped in binary mode.
            if not (binary and label == 2)
        ]

    def _labeled_lines(self, root):
        """Yield (label, text) pairs for root nodes only, or for every node."""
        for tree in self.sst:
            if root:
                yield tree.label, tree.to_lines()[0]
            else:
                yield from tree.to_labeled_lines()

    @staticmethod
    def _encode(text):
        """Tokenize ``text`` and pad the ids with ``rpad`` (n=66)."""
        # tokenizer.encode() adds the [CLS]/[SEP] special tokens itself by
        # default; wrapping the text in literal "[CLS] ... [SEP]" as well
        # would duplicate them in every sequence.
        return rpad(tokenizer.encode(text), n=66)

    def __len__(self):
        """Number of (token_ids, label) examples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the token ids as a tensor together with the raw label."""
        X, y = self.data[index]
        X = torch.tensor(X)
        return X, y

    

In [None]:

# CUDA_LAUNCH_BLOCKING=1 is the CUDA debugging switch for synchronous
# kernel launches.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Use the first GPU when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print(device)

### random seed 
def set_seed(seed_value=42):
    """Seed every random number generator used here for reproducibility.

    Seeds, in order: Python's ``random``, NumPy's global generator, torch,
    and torch's CUDA generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_with in seeders:
        seed_with(seed_value)
    
    
def train_one_epoch(model, lossfn, optimizer, dataset, batch_size=32):
    """Run one optimisation pass over ``dataset``.

    Args:
        model: callable as ``model(input_ids=batch)``; element ``[0]`` of its
            output is used as the logits.
        lossfn: criterion applied as ``lossfn(logits, labels)``. It is
            expected to return the batch *mean* (the default reduction of the
            ``torch.nn.CrossEntropyLoss()`` built by ``train``).
        optimizer: optimizer stepping ``model``'s parameters.
        dataset: map-style dataset yielding ``(token_ids, label)`` pairs.
        batch_size: mini-batch size.

    Returns:
        ``(loss, accuracy)``, both averaged per example over the dataset.
    """
    generator = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True
    )
    model.train()
    train_loss, train_acc = 0.0, 0.0
    for batch, labels in tqdm(generator):
        batch, labels = batch.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(input_ids=batch)[0]
        # Back-propagate the criterion the caller passed in, so the lossfn
        # argument is honoured and matches what evaluate_one_epoch reports.
        loss = lossfn(logits, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch mean by the batch size: dividing a plain sum of
        # batch means by the number of examples would shrink the reported
        # loss by roughly the batch size and mis-weight a short last batch.
        train_loss += loss.item() * labels.size(0)
        pred_labels = torch.argmax(logits, dim=1)
        train_acc += (pred_labels == labels).sum().item()
    train_loss /= len(dataset)
    train_acc /= len(dataset)
    return train_loss, train_acc


def evaluate_one_epoch(model, lossfn, optimizer, dataset, batch_size=32):
    """Score ``model`` on ``dataset`` without updating it.

    Args:
        model: callable as ``model(batch)``; element ``[0]`` of its output is
            used as the logits.
        lossfn: criterion applied as ``lossfn(logits, labels)``; expected to
            return the batch mean.
        optimizer: unused; kept so the signature mirrors ``train_one_epoch``.
        dataset: map-style dataset yielding ``(token_ids, label)`` pairs.
        batch_size: mini-batch size.

    Returns:
        ``(loss, accuracy)``, both averaged per example over the dataset.
    """
    # Order does not affect the metrics, so there is no need to shuffle.
    generator = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False
    )
    model.eval()
    loss, acc = 0.0, 0.0
    with torch.no_grad():
        for batch, labels in tqdm(generator):
            batch, labels = batch.to(device), labels.to(device)
            logits = model(batch)[0]
            error = lossfn(logits, labels)
            # Same per-example weighting as in train_one_epoch.
            loss += error.item() * labels.size(0)
            pred_labels = torch.argmax(logits, dim=1)
            acc += (pred_labels == labels).sum().item()
    loss /= len(dataset)
    acc /= len(dataset)
    return loss, acc


def train(
    root=True,
    binary=False,
    bert="distilbert-base-uncased",
    epochs=10,
    batch_size=32,
    patience = 5,
    save=False,
):
    """Fine-tune a DistilBERT classifier on SST with early stopping.

    Args:
        root: passed to SSTDataset; root nodes only when true.
        binary: passed to SSTDataset; when false the classifier is
            configured with 5 labels.
        bert: pretrained checkpoint name for the config and the model.
        epochs: maximum number of epochs.
        batch_size: mini-batch size for training and evaluation.
        patience: number of consecutive epochs whose validation loss is
            higher than the previous epoch's before training stops.
        save: when true, ``torch.save`` the whole model after every epoch.

    Returns:
        ``(train_losses, val_losses, test_losses, train_accuracies,
        val_accuracies, test_accuracies, epoch)`` where the lists hold one
        value per completed epoch and ``epoch`` is the last epoch run
        (0 when ``epochs`` is 0).
    """
    trainset = SSTDataset("train", root=root, binary=binary)
    devset = SSTDataset("dev", root=root, binary=binary)
    testset = SSTDataset("test", root=root, binary=binary)

    # REMOVE BAD TRAINING DATA
    for dataset in (trainset, devset, testset):
        _drop_malformed(dataset)

    train_losses = []
    val_losses = []
    test_losses = []

    train_accuracies = []
    val_accuracies = []
    test_accuracies = []

    # Early stopping parameters. Starting from +inf guarantees the first
    # epoch can never count as a deterioration, whatever the loss scale.
    last_loss = float("inf")
    trigger_times = 0

    config = DistilBertConfig.from_pretrained(bert)
    if not binary:
        config.num_labels = 5

    model = DistilBertForSequenceClassification.from_pretrained(bert, config=config)

    # switch to GPU if available
    model = model.to(device)

    lossfn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

    # Defined up front so the final return is valid even when epochs == 0.
    epoch = 0
    for epoch in range(1, epochs+1):
        train_loss, train_acc = train_one_epoch(
            model, lossfn, optimizer, trainset, batch_size=batch_size
        )
        val_loss, val_acc = evaluate_one_epoch(
            model, lossfn, optimizer, devset, batch_size=batch_size
        )
        test_loss, test_acc = evaluate_one_epoch(
            model, lossfn, optimizer, testset, batch_size=batch_size
        )
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        test_losses.append(test_loss)

        train_accuracies.append(train_acc)
        val_accuracies.append(val_acc)
        test_accuracies.append(test_acc)

        logger.info(f"epoch={epoch}")
        logger.info(
            f"train_loss={train_loss:.4f}, val_loss={val_loss:.4f}, test_loss={test_loss:.4f}"
        )
        logger.info(
            f"train_acc={train_acc:.3f}, val_acc={val_acc:.3f}, test_acc={test_acc:.3f}"
        )

        if save:
            label = "binary" if binary else "fine"
            nodes = "root" if root else "all"
            # Pickles the whole model object, one file per epoch.
            torch.save(model, f"drop_{bert}__{nodes}__{label}_e{epoch}.pickle")

        # Early Stopping: compare against the previous epoch's validation loss.
        current_loss = val_loss
        if current_loss > last_loss:
            trigger_times += 1
            logger.info(f"Trigger Times: {trigger_times}")

            if trigger_times >= patience:
                logger.info(f"Done with Early Stopping at epoch {epoch}!")
                return train_losses, val_losses, test_losses, train_accuracies, val_accuracies, test_accuracies, epoch

        else:
            logger.info('Trigger Times: 0')
            trigger_times = 0

        last_loss = current_loss
    logger.success("Done!")
    return train_losses, val_losses, test_losses, train_accuracies, val_accuracies, test_accuracies, epoch


def _drop_malformed(dataset, expected_len=66):
    """Keep only the examples whose token-id list has ``expected_len`` entries."""
    # Build a filtered list instead of calling list.remove() while iterating:
    # removing during iteration skips the element after each removal, so
    # consecutive malformed examples would survive.
    dataset.data = [example for example in dataset.data if len(example[0]) == expected_len]


In [None]:
# Per-epoch histories of all runs, concatenated run after run.
total_train_losses, total_val_losses, total_test_losses = [], [], []
total_train_accuracies, total_val_accuracies, total_test_accuracies = [], [], []

# Five independent fine-tuning runs with identical settings.
for i in range(1, 6):
    # set_seed(42)
    run_history = train(
        root=True,
        binary=False,
        bert="distilbert-base-uncased",
        epochs=30,
        batch_size=8,
        patience=30,
        save=True,
    )
    (
        train_losses,
        val_losses,
        test_losses,
        train_accuracies,
        val_accuracies,
        test_accuracies,
        epoch,
    ) = run_history

    total_train_losses.extend(train_losses)
    total_val_losses.extend(val_losses)
    total_test_losses.extend(test_losses)
    total_train_accuracies.extend(train_accuracies)
    total_val_accuracies.extend(val_accuracies)
    total_test_accuracies.extend(test_accuracies)

In [None]:

# One row per (run, epoch), in the order the runs were executed.
df = pd.DataFrame(
    list(
        zip(
            total_train_losses,
            total_val_losses,
            total_test_losses,
            total_train_accuracies,
            total_val_accuracies,
            total_test_accuracies,
        )
    ),
    columns=[
        'Train Loss', 'Val Loss', 'Test Loss',
        'Train Accuracy', 'Val Accuracy', 'Test Accuracy',
    ],
)
df.to_csv('DistilBERT_BASE_5.csv')

# Split the row *positions* into five consecutive chunks and select them with
# iloc. Calling np.array_split on the DataFrame itself goes through the
# deprecated DataFrame.swapaxes; splitting a plain index array yields the same
# chunks without it.
# NOTE(review): one chunk per run only if every run logged the same number of
# epochs - confirm when early stopping can trigger.
run_rows = np.array_split(np.arange(len(df)), 5)
df_1, df_2, df_3, df_4, df_5 = (df.iloc[rows] for rows in run_rows)

In [None]:

EPOCH = range(epoch)

fig, axs = plt.subplots(2,3, figsize=(15,8))
fig.suptitle('Horizontally stacked subplots of Losses and Accuracies')

# (axes, DataFrame column, panel title): losses on the top row, accuracies
# on the bottom row.
run_panels = (
    (axs[0,0], 'Train Loss', "Train Loss"),
    (axs[0,1], 'Val Loss', "Validation Loss"),
    (axs[0,2], 'Test Loss', "Test Loss"),
    (axs[1,0], 'Train Accuracy', "Train Accuracy"),
    (axs[1,1], 'Val Accuracy', "Validation Accuracy"),
    (axs[1,2], 'Test Accuracy', "Test Accuracy"),
)

# One curve per run in every panel.
for panel_ax, metric, panel_title in run_panels:
    for run_df in (df_1, df_2, df_3, df_4, df_5):
        panel_ax.plot(EPOCH, run_df[metric])
    panel_ax.set_title(panel_title)

In [None]:
# Re-index every run from 0 so that rows line up by epoch, then place the
# runs side by side; each metric name then appears once per run.
run_frames = [
    run_df.reset_index().drop(['index'], axis=1)
    for run_df in (df_1, df_2, df_3, df_4, df_5)
]
df_concat = pd.concat(run_frames, axis= 1)

# Average the same-named columns across runs, one metric at a time.
df_mean = pd.DataFrame()
for metric in (
    'Train Loss', 'Val Loss', 'Test Loss',
    'Train Accuracy', 'Val Accuracy', 'Test Accuracy',
):
    df_mean[metric] = df_concat[metric].mean(axis=1)

In [None]:
EPOCH = range(epoch)

fig, axs = plt.subplots(2,3, figsize=(15,8))
fig.suptitle('Horizontally stacked subplots of Losses and Accuracies')

# (axes, df_mean column, panel title): losses on the top row, accuracies on
# the bottom row.
mean_panels = (
    (axs[0,0], 'Train Loss', "Train Loss"),
    (axs[0,1], 'Val Loss', "Validation Loss"),
    (axs[0,2], 'Test Loss', "Test Loss"),
    (axs[1,0], 'Train Accuracy', "Train Accuracy"),
    (axs[1,1], 'Val Accuracy', "Validation Accuracy"),
    (axs[1,2], 'Test Accuracy', "Test Accuracy"),
)

# A single curve per panel: the metric averaged over the runs.
for panel_ax, metric, panel_title in mean_panels:
    panel_ax.plot(EPOCH, df_mean[metric])
    panel_ax.set_title(panel_title)