In [1]:
import time
import numpy as np
import pandas as pd
from transformers import AdamW, get_linear_schedule_with_warmup
import torch
from torch import nn
from torch.utils.data import dataloader

from _classifier import BertClassifier, BERT16SDatasetForPhylaClassification, GeneratePhylumLabels, TrainTestSplit

I0813 13:34:05.460933 4556096960 file_utils.py:39] PyTorch version 1.5.0 available.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


### Add Phylum Lables to Dataset 

In [2]:
label_generator = GeneratePhylumLabels(data_path='SILVA_parsed_V2.tsv')
label_generator.save('SILVA_parsed_V2__labeled.tsv')
num_classes = label_generator.num_classes

  if (await self.run_code(code, result,  async_=asy)):


In [3]:
label_generator.other_label

array([41])

In [44]:
num_classes

42

### Train-Test Split 

In [5]:
train_df, test_df = TrainTestSplit('SILVA_parsed_V2__labeled.tsv').train_test_split()

train_df.to_csv('SILVA_parsed_V2__labeled__train.tsv', sep='\t')
test_df.to_csv('SILVA_parsed_V2__labeled__test.tsv', sep='\t')

  if (await self.run_code(code, result,  async_=asy)):


### Create Dataset 

In [None]:
trainset = BERT16SDatasetForPhylaClassification(
    vocab_path='model/vocab.txt', 
    data_path='SILVA_parsed_V2__labeled__train.tsv')

testset = BERT16SDatasetForPhylaClassification(
    vocab_path='model/vocab.txt', 
    data_path='SILVA_parsed_V2__labeled__test.tsv')

In [4]:
batch_size = 32
num_workers = 4

In [5]:
train_loader = dataloader.DataLoader(
    dataset=trainset,
    batch_size=batch_size,
    num_workers=num_workers
)

test_loader = dataloader.DataLoader(
    dataset=testset,
    batch_size=batch_size,
    num_workers=num_workers
)

### Define Model 

In [6]:
if torch.cuda.is_available():       
    device = torch.device("cuda")
    print(f'There are {torch.cuda.device_count()} GPU(s) available.')
    print('Device name:', torch.cuda.get_device_name(0))

else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")

No GPU available, using the CPU instead.


In [39]:
def initialize_model(epochs):
    """Initialize the Bert Classifier, the optimizer and the learning rate scheduler.
    """
    # Instantiate Bert Classifier
    bert_classifier = BertClassifier(path='model/', num_classes=num_classes, freeze_bert=True)

    # Tell PyTorch to run the model on GPU
    bert_classifier.to(device)

    # Create the optimizer
    optimizer = AdamW(
        bert_classifier.parameters(),
        lr=5e-5,    # Default learning rate
        eps=1e-8    # Default epsilon value
    )

    # Total number of training steps
    total_steps = len(trainset) * epochs

    # Set up the learning rate scheduler
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0, # Default value
        num_training_steps=total_steps)
    
    return bert_classifier, optimizer, scheduler

In [40]:
# Specify loss function
loss_fn = nn.CrossEntropyLoss()

### Define Train Loop 

In [41]:
def train(model, train_dataloader, val_dataloader=None, epochs=4, evaluation=False):
    """
    Train loop.
    """
    for epoch_i in range(epochs):
        # Print the header of the result table
        print(f"{'Epoch':^7} | {'Batch':^15} | {'LR':^7} | {'Train Loss':^12} | {'Val Loss':^10} | {'Val Acc':^9} | {'Elapsed':^9}")
        print("-"*90)

        # Measure the elapsed time of each epoch
        t0_epoch, t0_batch = time.time(), time.time()

        total_loss, batch_loss, batch_counts = 0, 0, 0

        model.train()
        num_steps = len(train_dataloader)
        for step, batch in enumerate(train_dataloader):
              
            batch_counts += 1
            b_input_ids, b_labels = tuple(t.to(device) for t in batch)
            model.zero_grad()
            logits = model(b_input_ids)

            loss = loss_fn(logits, b_labels.view(-1,))
            batch_loss += loss.item()
            total_loss += loss.item()

            # back-propagation
            loss.backward()
            # clip the norm of the gradients to 1.0 to prevent "exploding gradients"
            #torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            scheduler.step()

            if (step % 50 == 0 and step != 0) or (step == len(train_dataloader) - 1):
                time_elapsed = time.time() - t0_batch
                print(f"{epoch_i + 1:^7} | {step:^7}/{num_steps:^7} | {np.round(scheduler.get_lr()[-1], 7):^7}| {batch_loss / batch_counts:^12.6f} | {'-':^10} | {'-':^9} | {time_elapsed:^9.2f}")
                batch_loss, batch_counts = 0, 0
                t0_batch = time.time()           

        avg_train_loss = total_loss / len(train_dataloader)

        print("-"*70)

        if evaluation == True:
            val_loss, val_accuracy = evaluate(model, val_dataloader)
            time_elapsed = time.time() - t0_epoch
            
            print(f"{epoch_i + 1:^7} | {'-':^15} | {'-':^7} | {avg_train_loss:^12.6f} | {val_loss:^10.6f} | {val_accuracy:^9.2f} | {time_elapsed:^9.2f}")
            print("-"*90)
        print("\n")


In [42]:
def evaluate(model, val_dataloader):
    """
    Evaluate model performance.
    """
    model.eval()

    val_accuracy = []
    val_loss = []

    for batch in val_dataloader:
        b_input_ids, b_labels = tuple(t.to(device) for t in batch)

        with torch.no_grad():
            logits = model(b_input_ids)

        loss = loss_fn(logits, b_labels.view(-1,))
        val_loss.append(loss.item())

        preds = torch.argmax(logits, dim=1).flatten()

        accuracy = (preds == b_labels.view(-1,)).cpu().numpy().mean() * 100
        val_accuracy.append(accuracy)

    # compute the average accuracy and loss over the validation set.
    val_loss = np.mean(val_loss)
    val_accuracy = np.mean(val_accuracy)

    return val_loss, val_accuracy

### Train! 

In [43]:
%%time
bert_classifier, optimizer, scheduler = initialize_model(epochs=5)
train(bert_classifier, train_loader, test_loader, epochs=5, evaluation=True)

I0813 21:28:43.049000 4556096960 configuration_utils.py:263] loading configuration file model/config.json
I0813 21:28:43.070768 4556096960 configuration_utils.py:301] Model config BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 256,
  "initializer_range": 0.02,
  "intermediate_size": 1024,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 4,
  "num_hidden_layers": 4,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 15621
}

I0813 21:28:43.075200 4556096960 modeling_utils.py:648] loading weights file model/pytorch_model.bin


 Epoch  |      Batch      |   LR    |  Train Loss  |  Val Loss  |  Val Acc  |  Elapsed 
------------------------------------------------------------------------------------------




   1    |   50   / 10801  |  5e-05 |   3.575851   |     -      |     -     |   3.36   
   1    |   100  / 10801  |  5e-05 |   3.186916   |     -      |     -     |   2.67   
   1    |   150  / 10801  |  5e-05 |   2.663955   |     -      |     -     |   2.75   
   1    |   200  / 10801  |  5e-05 |   2.238808   |     -      |     -     |   2.79   
   1    |   250  / 10801  |  5e-05 |   2.118240   |     -      |     -     |   2.70   
   1    |   300  / 10801  |  5e-05 |   2.032410   |     -      |     -     |   2.67   
   1    |   350  / 10801  |  5e-05 |   1.983330   |     -      |     -     |   2.76   
   1    |   400  / 10801  |  5e-05 |   1.922099   |     -      |     -     |   2.66   
   1    |   450  / 10801  |  5e-05 |   1.914383   |     -      |     -     |   2.65   
   1    |   500  / 10801  |  5e-05 |   1.915675   |     -      |     -     |   2.71   
   1    |   550  / 10801  |  5e-05 |   1.921506   |     -      |     -     |   2.66   
   1    |   600  / 10801  |  5e-05 |   1.85

   1    |  4750  / 10801  | 4.99e-05|   1.001105   |     -      |     -     |   2.98   
   1    |  4800  / 10801  | 4.99e-05|   0.989876   |     -      |     -     |   2.96   
   1    |  4850  / 10801  | 4.99e-05|   0.992256   |     -      |     -     |   2.95   
   1    |  4900  / 10801  | 4.99e-05|   0.970932   |     -      |     -     |   2.88   
   1    |  4950  / 10801  | 4.99e-05|   1.027093   |     -      |     -     |   2.85   
   1    |  5000  / 10801  | 4.99e-05|   0.960116   |     -      |     -     |   2.85   
   1    |  5050  / 10801  | 4.99e-05|   0.974573   |     -      |     -     |   2.85   
   1    |  5100  / 10801  | 4.99e-05|   0.962808   |     -      |     -     |   2.84   
   1    |  5150  / 10801  | 4.99e-05|   0.919352   |     -      |     -     |   2.81   
   1    |  5200  / 10801  | 4.98e-05|   0.952846   |     -      |     -     |   2.78   
   1    |  5250  / 10801  | 4.98e-05|   0.998562   |     -      |     -     |   2.78   
   1    |  5300  / 10801  | 4.98

   1    |  9450  / 10801  | 4.97e-05|   0.563357   |     -      |     -     |   2.88   
   1    |  9500  / 10801  | 4.97e-05|   0.591606   |     -      |     -     |   2.83   
   1    |  9550  / 10801  | 4.97e-05|   0.635915   |     -      |     -     |   2.82   
   1    |  9600  / 10801  | 4.97e-05|   0.682108   |     -      |     -     |   2.84   
   1    |  9650  / 10801  | 4.97e-05|   0.620580   |     -      |     -     |   2.89   
   1    |  9700  / 10801  | 4.97e-05|   0.592673   |     -      |     -     |   2.82   
   1    |  9750  / 10801  | 4.97e-05|   0.646565   |     -      |     -     |   2.84   
   1    |  9800  / 10801  | 4.97e-05|   0.613155   |     -      |     -     |   2.82   
   1    |  9850  / 10801  | 4.97e-05|   0.606445   |     -      |     -     |   2.85   
   1    |  9900  / 10801  | 4.97e-05|   0.605053   |     -      |     -     |   2.84   
   1    |  9950  / 10801  | 4.97e-05|   0.639989   |     -      |     -     |   2.84   
   1    |  10000 / 10801  | 4.97

   2    |  3100  / 10801  | 4.96e-05|   0.445077   |     -      |     -     |   2.86   
   2    |  3150  / 10801  | 4.96e-05|   0.429202   |     -      |     -     |   2.85   
   2    |  3200  / 10801  | 4.96e-05|   0.437577   |     -      |     -     |   2.86   
   2    |  3250  / 10801  | 4.96e-05|   0.461201   |     -      |     -     |   2.89   
   2    |  3300  / 10801  | 4.96e-05|   0.418579   |     -      |     -     |   2.88   
   2    |  3350  / 10801  | 4.96e-05|   0.404221   |     -      |     -     |   2.88   
   2    |  3400  / 10801  | 4.96e-05|   0.414280   |     -      |     -     |   2.89   
   2    |  3450  / 10801  | 4.96e-05|   0.414561   |     -      |     -     |   2.88   
   2    |  3500  / 10801  | 4.96e-05|   0.444143   |     -      |     -     |   2.87   
   2    |  3550  / 10801  | 4.96e-05|   0.432002   |     -      |     -     |   2.86   
   2    |  3600  / 10801  | 4.96e-05|   0.482920   |     -      |     -     |   2.90   
   2    |  3650  / 10801  | 4.96

   2    |  7800  / 10801  | 4.95e-05|   0.351029   |     -      |     -     |   2.80   
   2    |  7850  / 10801  | 4.95e-05|   0.376700   |     -      |     -     |   2.75   
   2    |  7900  / 10801  | 4.95e-05|   0.324993   |     -      |     -     |   2.75   
   2    |  7950  / 10801  | 4.95e-05|   0.315294   |     -      |     -     |   2.76   
   2    |  8000  / 10801  | 4.95e-05|   0.365817   |     -      |     -     |   2.83   
   2    |  8050  / 10801  | 4.95e-05|   0.365328   |     -      |     -     |   2.81   
   2    |  8100  / 10801  | 4.95e-05|   0.376610   |     -      |     -     |   2.80   
   2    |  8150  / 10801  | 4.95e-05|   0.375353   |     -      |     -     |   2.85   
   2    |  8200  / 10801  | 4.95e-05|   0.355943   |     -      |     -     |   2.87   
   2    |  8250  / 10801  | 4.94e-05|   0.303515   |     -      |     -     |   2.85   
   2    |  8300  / 10801  | 4.94e-05|   0.340950   |     -      |     -     |   2.85   
   2    |  8350  / 10801  | 4.94

   3    |  1450  / 10801  | 4.93e-05|   0.293160   |     -      |     -     |   2.78   
   3    |  1500  / 10801  | 4.93e-05|   0.282545   |     -      |     -     |   2.76   
   3    |  1550  / 10801  | 4.93e-05|   0.309852   |     -      |     -     |   2.81   
   3    |  1600  / 10801  | 4.93e-05|   0.279801   |     -      |     -     |   2.84   
   3    |  1650  / 10801  | 4.93e-05|   0.265368   |     -      |     -     |   2.81   
   3    |  1700  / 10801  | 4.93e-05|   0.333609   |     -      |     -     |   2.82   
   3    |  1750  / 10801  | 4.93e-05|   0.306202   |     -      |     -     |   2.85   
   3    |  1800  / 10801  | 4.93e-05|   0.300373   |     -      |     -     |   2.88   
   3    |  1850  / 10801  | 4.93e-05|   0.282935   |     -      |     -     |   2.88   
   3    |  1900  / 10801  | 4.93e-05|   0.284142   |     -      |     -     |   2.87   
   3    |  1950  / 10801  | 4.93e-05|   0.283176   |     -      |     -     |   2.89   
   3    |  2000  / 10801  | 4.93

   3    |  6150  / 10801  | 4.92e-05|   0.213853   |     -      |     -     |   2.81   
   3    |  6200  / 10801  | 4.92e-05|   0.267753   |     -      |     -     |   2.82   
   3    |  6250  / 10801  | 4.92e-05|   0.235970   |     -      |     -     |   2.80   
   3    |  6300  / 10801  | 4.92e-05|   0.275273   |     -      |     -     |   2.82   
   3    |  6350  / 10801  | 4.92e-05|   0.252543   |     -      |     -     |   2.82   
   3    |  6400  / 10801  | 4.92e-05|   0.248773   |     -      |     -     |   2.87   
   3    |  6450  / 10801  | 4.92e-05|   0.250292   |     -      |     -     |   2.81   
   3    |  6500  / 10801  | 4.92e-05|   0.217746   |     -      |     -     |   2.81   
   3    |  6550  / 10801  | 4.92e-05|   0.261023   |     -      |     -     |   2.81   
   3    |  6600  / 10801  | 4.92e-05|   0.254824   |     -      |     -     |   2.85   
   3    |  6650  / 10801  | 4.92e-05|   0.278864   |     -      |     -     |   2.84   
   3    |  6700  / 10801  | 4.92

   3    |        -        |    -    |   0.257811   |  0.214883  |   95.00   |  716.73  
------------------------------------------------------------------------------------------


 Epoch  |      Batch      |   LR    |  Train Loss  |  Val Loss  |  Val Acc  |  Elapsed 
------------------------------------------------------------------------------------------
   4    |   50   / 10801  | 4.91e-05|   0.219590   |     -      |     -     |   3.38   
   4    |   100  / 10801  | 4.91e-05|   0.212455   |     -      |     -     |   2.99   
   4    |   150  / 10801  | 4.91e-05|   0.247308   |     -      |     -     |   2.88   
   4    |   200  / 10801  | 4.91e-05|   0.205077   |     -      |     -     |   2.80   
   4    |   250  / 10801  | 4.91e-05|   0.246967   |     -      |     -     |   2.83   
   4    |   300  / 10801  | 4.91e-05|   0.249460   |     -      |     -     |   2.75   
   4    |   350  / 10801  | 4.91e-05|   0.225617   |     -      |     -     |   2.77   
   4    |   400  / 10801

   4    |  4550  / 10801  | 4.89e-05|   0.222563   |     -      |     -     |   2.83   
   4    |  4600  / 10801  | 4.89e-05|   0.190561   |     -      |     -     |   2.81   
   4    |  4650  / 10801  | 4.89e-05|   0.220163   |     -      |     -     |   2.81   
   4    |  4700  / 10801  | 4.89e-05|   0.165605   |     -      |     -     |   2.80   
   4    |  4750  / 10801  | 4.89e-05|   0.219296   |     -      |     -     |   2.79   
   4    |  4800  / 10801  | 4.89e-05|   0.202418   |     -      |     -     |   2.75   
   4    |  4850  / 10801  | 4.89e-05|   0.203315   |     -      |     -     |   2.79   
   4    |  4900  / 10801  | 4.89e-05|   0.182199   |     -      |     -     |   2.77   
   4    |  4950  / 10801  | 4.89e-05|   0.221684   |     -      |     -     |   2.75   
   4    |  5000  / 10801  | 4.89e-05|   0.203098   |     -      |     -     |   2.77   
   4    |  5050  / 10801  | 4.89e-05|   0.191417   |     -      |     -     |   2.76   
   4    |  5100  / 10801  | 4.89

   4    |  9250  / 10801  | 4.88e-05|   0.232309   |     -      |     -     |   2.83   
   4    |  9300  / 10801  | 4.88e-05|   0.203771   |     -      |     -     |   2.84   
   4    |  9350  / 10801  | 4.88e-05|   0.152251   |     -      |     -     |   2.89   
   4    |  9400  / 10801  | 4.88e-05|   0.218597   |     -      |     -     |   2.85   
   4    |  9450  / 10801  | 4.88e-05|   0.171704   |     -      |     -     |   2.87   
   4    |  9500  / 10801  | 4.88e-05|   0.184242   |     -      |     -     |   2.87   
   4    |  9550  / 10801  | 4.88e-05|   0.204857   |     -      |     -     |   2.83   
   4    |  9600  / 10801  | 4.88e-05|   0.215463   |     -      |     -     |   2.82   
   4    |  9650  / 10801  | 4.88e-05|   0.213507   |     -      |     -     |   2.89   
   4    |  9700  / 10801  | 4.88e-05|   0.187164   |     -      |     -     |   2.82   
   4    |  9750  / 10801  | 4.88e-05|   0.214217   |     -      |     -     |   2.84   
   4    |  9800  / 10801  | 4.88

   5    |  2900  / 10801  | 4.87e-05|   0.165582   |     -      |     -     |   2.66   
   5    |  2950  / 10801  | 4.87e-05|   0.140575   |     -      |     -     |   2.65   
   5    |  3000  / 10801  | 4.87e-05|   0.187673   |     -      |     -     |   2.69   
   5    |  3050  / 10801  | 4.87e-05|   0.196073   |     -      |     -     |   2.65   
   5    |  3100  / 10801  | 4.87e-05|   0.169473   |     -      |     -     |   2.63   
   5    |  3150  / 10801  | 4.87e-05|   0.184560   |     -      |     -     |   2.60   
   5    |  3200  / 10801  | 4.87e-05|   0.177969   |     -      |     -     |   2.63   
   5    |  3250  / 10801  | 4.87e-05|   0.201723   |     -      |     -     |   2.59   
   5    |  3300  / 10801  | 4.87e-05|   0.174618   |     -      |     -     |   2.62   
   5    |  3350  / 10801  | 4.87e-05|   0.165200   |     -      |     -     |   2.60   
   5    |  3400  / 10801  | 4.87e-05|   0.174261   |     -      |     -     |   2.63   
   5    |  3450  / 10801  | 4.87

   5    |  7600  / 10801  | 4.85e-05|   0.176030   |     -      |     -     |   2.72   
   5    |  7650  / 10801  | 4.85e-05|   0.165660   |     -      |     -     |   2.71   
   5    |  7700  / 10801  | 4.85e-05|   0.173175   |     -      |     -     |   2.72   
   5    |  7750  / 10801  | 4.85e-05|   0.155898   |     -      |     -     |   2.69   
   5    |  7800  / 10801  | 4.85e-05|   0.177572   |     -      |     -     |   2.67   
   5    |  7850  / 10801  | 4.85e-05|   0.213814   |     -      |     -     |   2.65   
   5    |  7900  / 10801  | 4.85e-05|   0.171210   |     -      |     -     |   2.63   
   5    |  7950  / 10801  | 4.85e-05|   0.156597   |     -      |     -     |   2.62   
   5    |  8000  / 10801  | 4.85e-05|   0.199799   |     -      |     -     |   2.63   
   5    |  8050  / 10801  | 4.85e-05|   0.191391   |     -      |     -     |   2.61   
   5    |  8100  / 10801  | 4.85e-05|   0.205604   |     -      |     -     |   2.61   
   5    |  8150  / 10801  | 4.85