In [1]:
!pip install transformers



In [2]:
!git clone https://github.com/ronakdm/input-marginalization.git

fatal: destination path 'input-marginalization' already exists and is not an empty directory.


In [3]:
import pickle
import numpy as np
import time
import datetime
import random
import torch

from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset, random_split

from transformers import BertForSequenceClassification, AdamW, BertConfig, get_linear_schedule_with_warmup

In [4]:
LEARNING_RATE = 2e-5
ADAMW_TOLERANCE = 1e-8
BATCH_SIZE = 32
EPOCHS = 2

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print("Running on '%s'." % device)

Running on 'cuda'.


In [5]:
def load_dataset(dataset_name):
    # `dataset_name` can be "train", "valid", or "test".
    input_ids = pickle.load(open("input-marginalization/preprocessed_data/SST-2/input_ids_%s" % dataset_name, "rb"))
    attention_masks = pickle.load(open("input-marginalization/preprocessed_data/SST-2/attention_masks_%s" % dataset_name, "rb"))
    labels = pickle.load(open("input-marginalization/preprocessed_data/SST-2/labels_%s" % dataset_name, "rb"))

    return TensorDataset(input_ids, attention_masks, labels)


train_dataset = load_dataset("train")
val_dataset = load_dataset("valid")

train_dataloader = DataLoader(train_dataset, sampler = RandomSampler(train_dataset), batch_size = BATCH_SIZE)
validation_dataloader = DataLoader(val_dataset, sampler = SequentialSampler(val_dataset), batch_size = BATCH_SIZE)

print('{:>5,} training samples.'.format(len(train_dataset)))
print('{:>5,} validation samples.'.format(len(val_dataset)))

6,919 training samples.
  876 validation samples.


In [6]:
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels = 2, 
    output_attentions = False,
    output_hidden_states = False,
)

model.to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [7]:
optimizer = AdamW(model.parameters(), lr = LEARNING_RATE, eps = ADAMW_TOLERANCE)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps = 0, num_training_steps = EPOCHS * len(train_dataset))

In [8]:
def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)


def format_time(elapsed):
    '''
    Takes a time in seconds and returns a string hh:mm:ss
    '''
    # Round to the nearest second.
    elapsed_rounded = int(round((elapsed)))
    
    # Format as hh:mm:ss
    return str(datetime.timedelta(seconds=elapsed_rounded))

In [13]:
seed_val = 42

random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)


training_stats = []
total_t0 = time.time()

for epoch_i in range(EPOCHS):
    
    # ========================================
    #               Training
    # ========================================
    
    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, EPOCHS))
    print('Training...')

    
    t0 = time.time()
    total_train_loss = 0

    model.train()

    for step, batch in enumerate(train_dataloader):


        if step % 40 == 0 and not step == 0:
            elapsed = format_time(time.time() - t0)
            print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(train_dataloader), elapsed))

        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)

        model.zero_grad()        

        output = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels, return_dict=True)

        loss = output.loss
        logits = output.logits
        
        total_train_loss += loss.item()

        loss.backward()

        # TODO: See if this is needed.
        # Clip the norm of the gradients to 1.0.
        # This is to help prevent the "exploding gradients" problem.
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()


    avg_train_loss = total_train_loss / len(train_dataloader)            
    training_time = format_time(time.time() - t0)

    print("")
    print("  Average training loss: {0:.2f}".format(avg_train_loss))
    print("  Training epcoh took: {:}".format(training_time))
        
    # ========================================
    #               Validation
    # ========================================

    print("")
    print("Running Validation...")

    t0 = time.time()

    model.eval()

    total_eval_accuracy = 0
    total_eval_loss = 0
    nb_eval_steps = 0

    for batch in validation_dataloader:
        
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)
        

        with torch.no_grad():        
            output = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)
            loss = output.loss
            logits = output.logits
            
        total_eval_loss += loss.item()

        logits = logits.detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy()

        total_eval_accuracy += flat_accuracy(logits, label_ids)
        

    avg_val_accuracy = total_eval_accuracy / len(validation_dataloader)
    print("  Accuracy: {0:.2f}".format(avg_val_accuracy))
    avg_val_loss = total_eval_loss / len(validation_dataloader)
    validation_time = format_time(time.time() - t0)
    print("  Validation Loss: {0:.2f}".format(avg_val_loss))
    print("  Validation took: {:}".format(validation_time))

    training_stats.append(
        {
            'epoch': epoch_i + 1,
            'Training Loss': avg_train_loss,
            'Valid. Loss': avg_val_loss,
            'Valid. Accur.': avg_val_accuracy,
            'Training Time': training_time,
            'Validation Time': validation_time
        }
    )

print("")
print("Training complete!")

print("Total training took {:} (h:mm:ss)".format(format_time(time.time()-total_t0)))


Training...
  Batch    40  of    217.    Elapsed: 0:00:18.
  Batch    80  of    217.    Elapsed: 0:00:36.
  Batch   120  of    217.    Elapsed: 0:00:55.
  Batch   160  of    217.    Elapsed: 0:01:14.
  Batch   200  of    217.    Elapsed: 0:01:32.

  Average training loss: 0.35
  Training epcoh took: 0:01:40

Running Validation...
  Accuracy: 0.91
  Validation Loss: 0.23
  Validation took: 0:00:03

Training...
  Batch    40  of    217.    Elapsed: 0:00:19.
  Batch    80  of    217.    Elapsed: 0:00:37.
  Batch   120  of    217.    Elapsed: 0:00:56.
  Batch   160  of    217.    Elapsed: 0:01:15.
  Batch   200  of    217.    Elapsed: 0:01:34.

  Average training loss: 0.14
  Training epcoh took: 0:01:41

Running Validation...
  Accuracy: 0.92
  Validation Loss: 0.24
  Validation took: 0:00:03

Training complete!
Total training took 0:03:28 (h:mm:ss)


In [15]:
# ========================================
#               Testing
# ========================================

test_dataset = load_dataset("test")
test_dataloader = DataLoader(test_dataset, sampler = RandomSampler(test_dataset), batch_size = BATCH_SIZE)
print('{:>5,} test samples.'.format(len(test_dataset)))

print("")
print("Testing...")

t0 = time.time()

model.eval()

total_test_accuracy = 0
total_test_loss = 0
nb_test_steps = 0

for batch in test_dataloader:
    
    b_input_ids = batch[0].to(device)
    b_input_mask = batch[1].to(device)
    b_labels = batch[2].to(device)

    with torch.no_grad():        
        output = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)
        loss = output.loss
        logits = output.logits
        
    total_test_loss += loss.item()

    logits = logits.detach().cpu().numpy()
    label_ids = b_labels.to('cpu').numpy()

    total_test_accuracy += flat_accuracy(logits, label_ids)
    

avg_test_accuracy = total_test_accuracy / len(test_dataloader)
print("  Accuracy: {0:.2f}".format(avg_test_accuracy))
avg_test_loss = total_test_loss / len(test_dataloader)
test_time = format_time(time.time() - t0)
print("  Test Loss: {0:.2f}".format(avg_test_loss))
print("  Test took: {:}".format(test_time))

1,822 test samples.

Testing...
  Accuracy: 0.93
  Test Loss: 0.20
  Test took: 0:00:08
