In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, BertModel
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from pytorch_lightning import LightningModule, Trainer, callbacks
import evaluate
import numpy

In [None]:
# Configuration: tune these to trade accuracy for run time.
batch_size = 16
SEED = 42
MODEL_NAME = "bert-base-cased"  # keep in sync with the checkpoint BertLightning loads
N_TRAIN = 5120  # number of examples sampled from the train split
N_EVAL = 512  # number of examples sampled from the test split

# Seed torch's global RNG so DataLoader shuffling (which draws from it when no generator is
# given) and any later random weight init are reproducible on Restart & Run All.
# The dataset subsampling below is seeded separately via shuffle(seed=...).
torch.manual_seed(SEED)

dataset = load_dataset("yelp_review_full")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


def tokenize_function(examples):
    """Tokenize a batch of reviews, padding/truncating every example to the tokenizer's max length.

    Args:
        examples: a batch dict as passed by ``Dataset.map(..., batched=True)``; only the
            ``"text"`` entry is read.

    Returns:
        The tokenizer's encoding for the batch (its fields become new dataset columns).
    """
    return tokenizer(examples["text"], padding="max_length", truncation=True)


# Get the dataset into the format the PyTorch model expects: drop the raw text, rename the
# target column to "labels", and make indexing return torch tensors.
tokenized_datasets = (
    dataset.map(tokenize_function, batched=True)
    .remove_columns(["text"])
    .rename_column("label", "labels")
)
tokenized_datasets.set_format("torch")  # in-place on the DatasetDict

# Use a fixed random sample of the full dataset to keep fine-tuning fast.
small_train_dataset = tokenized_datasets["train"].shuffle(seed=SEED).select(range(N_TRAIN))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=SEED).select(range(N_EVAL))

# Data loaders
train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=batch_size)
eval_dataloader = DataLoader(small_eval_dataset, batch_size=batch_size)

In [None]:
class BertLightning(LightningModule):
    """Pretrained ``bert-base-cased`` encoder with a linear classification head, trained via Lightning.

    The head is applied to the encoder's hidden state at sequence position 0 (where BERT
    tokenizers place the [CLS] token).

    Args:
        num_classes: number of output classes. The default of 5 matches the original
            hard-coded head used for ``yelp_review_full`` in this notebook.
        lr: learning rate passed to the optimizer in ``configure_optimizers``.
    """

    def __init__(self, num_classes=5, lr=5e-5):
        super().__init__()
        # Store init args in checkpoints so load_from_checkpoint rebuilds the same architecture.
        self.save_hyperparameters()
        self.num_classes = num_classes
        self.lr = lr
        self.bert = BertModel.from_pretrained("bert-base-cased", output_attentions=True)
        # Head width comes from num_classes instead of a second, independent literal 5.
        self.W = torch.nn.Linear(self.bert.config.hidden_size, num_classes)
        self.loss_function = torch.nn.CrossEntropyLoss()

    def forward(self, input_ids, token_type_ids, attention_mask):
        """Return classification logits of shape (batch, num_classes).

        Inputs are the tokenizer's tensors; presumably each is (batch, seq_len), so confirm
        against the DataLoader output.
        """
        result = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        # Take the hidden state at position 0 for every example in the batch.
        logits = self.W(result['last_hidden_state'][:, 0])
        return logits

    def configure_optimizers(self):
        """Adam over all parameters at ``self.lr``."""
        # NOTE(review): AdamW is imported at the top of the notebook but plain Adam is used here;
        # BERT fine-tuning usually uses AdamW, so confirm which was intended.
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def _shared_step(self, batch):
        """Run the model on one batch and return ``(loss, accuracy)`` as tensors."""
        y = batch['labels']
        pred = self(batch['input_ids'], batch['token_type_ids'], batch['attention_mask'])
        loss = self.loss_function(pred, y)
        # Vectorized on-device mean; builtin sum() iterated over the batch element by element.
        accuracy = (pred.argmax(dim=1) == y).float().mean()
        return loss, accuracy

    def training_step(self, batch, batch_idx):
        """Log training loss/accuracy per step and per epoch; return the loss to optimize."""
        loss, accuracy = self._shared_step(batch)
        self.log("training_loss", loss, on_step=True, on_epoch=True)
        self.log("training_accuracy", accuracy, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss/accuracy per step and per epoch, also shown in the progress bar."""
        loss, accuracy = self._shared_step(batch)
        self.log("validation_loss", loss, prog_bar=True, on_step=True, on_epoch=True)
        self.log("validation_accuracy", accuracy, prog_bar=True, on_step=True, on_epoch=True)

In [None]:
# Fine-tune for 3 epochs. ModelCheckpoint writes ./bert_lightning/bert_epoch=<N>.ckpt;
# afterwards the raw state dict is also saved so weights can be restored without Lightning.
checkpoint_callback = callbacks.ModelCheckpoint(
    dirpath='./bert_lightning/',
    filename='bert_{epoch}',
)
model = BertLightning()
trainer = Trainer(
    max_epochs=3,
    callbacks=[checkpoint_callback],
)
trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=eval_dataloader,
)
torch.save(obj=model.state_dict(), f='./bert_lightning/weights.pth')

In [None]:
# 'bert_epoch=2.ckpt' is what ModelCheckpoint(filename='bert_{epoch}') writes after the last of
# max_epochs=3 (epochs are 0-indexed).
# load_from_checkpoint is a classmethod that builds and RETURNS a new model. Calling it on an
# instance and discarding the result (as this cell used to) loaded nothing.
# map_location="cpu" keeps the model on CPU, matching the CPU batches used in the cells below.
loaded_model = BertLightning.load_from_checkpoint('./bert_lightning/bert_epoch=2.ckpt', map_location="cpu")
# Inference mode for the evaluation cells: disables dropout so predictions are deterministic.
loaded_model.eval()
# Alternatively restore the plain state dict saved with torch.save after training
# (the same final-epoch weights); the displayed result confirms all keys matched.
loaded_model.load_state_dict(torch.load('./bert_lightning/weights.pth', map_location="cpu"))

In [None]:
metric = evaluate.load("accuracy")


def _to_numpy(values):
    """Return ``values`` as a NumPy array, detaching torch tensors and moving them to CPU first."""
    if isinstance(values, torch.Tensor):
        return values.detach().cpu().numpy()
    return numpy.asarray(values)


def compute_metrics(eval_pred):
    """Compute accuracy for a ``(logits, labels)`` pair with the ``evaluate`` accuracy metric.

    Args:
        eval_pred: a ``(logits, labels)`` tuple. Either element may be a NumPy array or a
            torch tensor. Class scores are taken from the last axis of ``logits``.

    Returns:
        The dict produced by ``metric.compute``, with an ``"accuracy"`` entry.
    """
    logits, labels = eval_pred
    # Convert explicitly: numpy.argmax on a torch tensor only works through NumPy's
    # fallback/wrap path (and not at all for CUDA tensors).
    predictions = numpy.argmax(_to_numpy(logits), axis=-1)
    return metric.compute(predictions=predictions, references=_to_numpy(labels))

# Sanity check on a single evaluation batch before running the full loop.
loss_function = torch.nn.CrossEntropyLoss()
batch = next(iter(eval_dataloader))
labels, input_ids, token_type_ids, attention_mask = batch['labels'], batch['input_ids'], batch['token_type_ids'], batch['attention_mask']
# Show shapes only: the token tensors are padded to the tokenizer's max length, so printing
# them in full is mostly noise.
print(labels.shape, input_ids.shape, token_type_ids.shape, attention_mask.shape)
# Disable dropout for inference; eval() is idempotent, so this is safe if already set.
loaded_model.eval()
with torch.no_grad():
    # Call the module rather than .forward() so nn.Module hooks run.
    result = loaded_model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
    print(labels)
    print(result)
    print(loss_function(result, labels))
    print(result.argmax(dim=-1))
    # Pass NumPy arrays, the input type compute_metrics' numpy.argmax is designed for.
    print(compute_metrics((result.cpu().numpy(), labels.cpu().numpy())))

In [None]:
# Accuracy over the whole evaluation sample, accumulated batch by batch.
# Disable dropout first; otherwise every prediction comes from a randomly perturbed network.
loaded_model.eval()
count = 0
correct = 0
with torch.no_grad():
    for batch in eval_dataloader:
        labels = batch['labels']
        predictions = loaded_model(batch['input_ids'], batch['token_type_ids'], batch['attention_mask'])
        # Count exact matches as an integer instead of rebuilding it from a float accuracy.
        new_correct = int((predictions.argmax(dim=-1) == labels).sum())
        correct += new_correct
        count += len(labels)
        print(count, new_correct, correct)
print(correct / count)