In [1]:
import os
import numpy as np
import pandas as pd
from typing import Optional, Any, Callable
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
# torch
import torch
from torch import nn
# transformer
from transformers.optimization import AdamW, get_scheduler, SchedulerType
# local
from NlpAnalytics import *

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package stopwords to /Users/lunli/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


### Load Models

In [2]:
### load BERT Classifier
# Build the classifier loader for the "bert-base-uncased" checkpoint and ask it
# to load the matching tokenizer as well.
# NOTE(review): the positional 2 and 0.1 are presumably the number of labels and
# the dropout rate -- confirm against BertClassifierLoader's signature.
loader = BertClassifierLoader(
    ClassifierType.BERT_CLASSIFIER_HF,
    "bert-base-uncased",
    2,
    0.1,
    load_tokenizer=True,
)
# native
# loader = BertClassifierLoader(ClassifierType.BERT_CLASSIFIER, "bert-base-uncased", 2, 0.1, hidden_dims=[768], load_tokenizer=True)
# Expose the loaded objects under the names the following cells use.
model = loader.model
tokenizer = loader.tokenizer

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### AdamNLP + Scheduler

In [5]:
# Wrap the currently loaded `model` in the project's AdamNLP optimizer, then
# have that optimizer build its learning-rate schedule over a short demo run.
demo_total_steps = 10
optimizer = AdamNLP(model)
scheduler = optimizer.compile_schedule(total_num_steps=demo_total_steps)



### Demo of AdamW + Scheduler

In [11]:
### prepare everything
# Fixed seed: makes both the train/test split and the random weight
# initialisation of the model below reproducible across "Restart & Run All".
SEED = 42
torch.manual_seed(SEED)

root_path = "./NlpAnalytics/data/dummy_data/"
# Number of feature columns; the column right after them holds the class label.
n_features = 34

# load dataset (the CSV has no header row), split into input (X) and output (y) variables
dataframe = pd.read_csv(os.path.join(root_path, "ionosphere.csv"), header=None)
dataset = dataframe.values
X = dataset[:, 0:n_features].astype(float)
y = dataset[:, n_features]

# encode class values as integers
encoder = LabelEncoder()
encoder.fit(y)
y = encoder.transform(y)

# convert into PyTorch tensors; y becomes a float column vector of shape (N, 1)
# so it lines up with the model's single sigmoid output.
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32).reshape(-1, 1)

# train-test split for evaluation of the model (70% train, shuffled, seeded)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, train_size=0.7, shuffle=True, random_state=SEED
)

# create model: one hidden layer as wide as the input, sigmoid output in [0, 1].
# NOTE: this rebinds `model`, replacing the BERT classifier loaded earlier in
# the notebook.
model = nn.Sequential(
    nn.Linear(n_features, n_features),
    nn.ReLU(),
    nn.Linear(n_features, 1),
    nn.Sigmoid())

In [12]:
### training set up
n_epochs = 50
batch_size = 24
# Start offset of every mini-batch within X_train (the last batch may be short).
batch_start = torch.arange(0, len(X_train), batch_size)
# Binary cross-entropy on probabilities: the model ends in a Sigmoid.
loss_fn = nn.BCELoss()

# Single parameter group: every parameter (biases included) gets the same
# weight decay.
param_optimizer = list(model.named_parameters())
optimizer_grouped_parameters = [{'params': [p for _, p in param_optimizer], 'weight_decay': 0.05}]
# NOTE(review): the AdamW exported by transformers.optimization appears to be
# deprecated in favour of torch.optim.AdamW -- confirm against the installed
# transformers version before upgrading.
optimizer = AdamW(optimizer_grouped_parameters, lr=1e-3)

# The training loop steps the scheduler once per mini-batch, so the schedule
# length must be epochs * batches-per-epoch. Deriving it (instead of the
# previous hard-coded 550) keeps the linear decay reaching zero on the final
# step when n_epochs, batch_size or the training-set size change.
num_training_steps = n_epochs * len(batch_start)
scheduler = get_scheduler(
    SchedulerType.LINEAR,
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)



In [13]:
### training
model.train()
for epoch in range(n_epochs):
    for start in batch_start:
        # Slice the current mini-batch out of the training tensors.
        X_batch = X_train[start:start+batch_size]
        y_batch = y_train[start:start+batch_size]
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch)
        # Standard update: clear old grads, backprop, apply the optimizer step,
        # then advance the learning-rate schedule by one step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        # Read back the learning rate after the scheduler update; printed on
        # every step so the decay is visible in the cell output.
        lr = optimizer.param_groups[0]["lr"]
        print(lr)

# evaluate accuracy after training
model.eval()
# No gradients are needed for evaluation; disabling autograd avoids building a
# graph over the whole test set.
with torch.no_grad():
    y_pred = model(X_test)
# Round the sigmoid output to a hard 0/1 label and compare with the targets.
acc = (y_pred.round() == y_test).float().mean()
acc = float(acc)
print("Model accuracy: %.2f%%" % (acc*100))

0.0009981818181818182
0.0009963636363636364
0.0009945454545454546
0.0009927272727272727
0.000990909090909091
0.000989090909090909
0.0009872727272727273
0.0009854545454545454
0.0009836363636363636
0.0009818181818181818
0.00098
0.0009781818181818181
0.0009763636363636363
0.0009745454545454545
0.0009727272727272728
0.000970909090909091
0.0009690909090909091
0.0009672727272727273
0.0009654545454545455
0.0009636363636363637
0.0009618181818181818
0.00096
0.0009581818181818182
0.0009563636363636365
0.0009545454545454546
0.0009527272727272727
0.0009509090909090909
0.0009490909090909091
0.0009472727272727273
0.0009454545454545454
0.0009436363636363636
0.0009418181818181818
0.00094
0.0009381818181818183
0.0009363636363636364
0.0009345454545454546
0.0009327272727272728
0.000930909090909091
0.0009290909090909091
0.0009272727272727273
0.0009254545454545454
0.0009236363636363636
0.0009218181818181819
0.00092
0.0009181818181818182
0.0009163636363636364
0.0009145454545454546
0.0009127272727272727
0.00