In [15]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, BertModel
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from pytorch_lightning import LightningModule, Trainer, callbacks
import evaluate
import numpy

In [16]:
batch_size = 16
dataset = load_dataset("yelp_review_full")
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)
tokenized_datasets = dataset.map(tokenize_function, batched=True)
# get dataset in appropriate format for pytorch
tokenized_datasets = tokenized_datasets.remove_columns(["text"])
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format("torch")
# use small sample of full dataset
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(5120))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(512))
# data loaders
train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=batch_size)
eval_dataloader = DataLoader(small_eval_dataset, batch_size=batch_size)

Found cached dataset yelp_review_full (/Users/skelley/.cache/huggingface/datasets/yelp_review_full/yelp_review_full/1.0.0/e8e18e19d7be9e75642fc66b198abadb116f73599ec89a69ba5dd8d1e57ba0bf)


  0%|          | 0/2 [00:00<?, ?it/s]

Loading cached processed dataset at /Users/skelley/.cache/huggingface/datasets/yelp_review_full/yelp_review_full/1.0.0/e8e18e19d7be9e75642fc66b198abadb116f73599ec89a69ba5dd8d1e57ba0bf/cache-1b229fb576f6410a.arrow
Loading cached processed dataset at /Users/skelley/.cache/huggingface/datasets/yelp_review_full/yelp_review_full/1.0.0/e8e18e19d7be9e75642fc66b198abadb116f73599ec89a69ba5dd8d1e57ba0bf/cache-769d3116751632a1.arrow
Loading cached shuffled indices for dataset at /Users/skelley/.cache/huggingface/datasets/yelp_review_full/yelp_review_full/1.0.0/e8e18e19d7be9e75642fc66b198abadb116f73599ec89a69ba5dd8d1e57ba0bf/cache-6e4b596798578d8f.arrow
Loading cached shuffled indices for dataset at /Users/skelley/.cache/huggingface/datasets/yelp_review_full/yelp_review_full/1.0.0/e8e18e19d7be9e75642fc66b198abadb116f73599ec89a69ba5dd8d1e57ba0bf/cache-219ca94af128ec87.arrow


In [17]:
class BertLightning(LightningModule):
    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-cased", output_attentions=True)
        self.W = torch.nn.Linear(self.bert.config.hidden_size, 5)
        self.num_classes = 5
        self.loss_function = torch.nn.CrossEntropyLoss()

    def forward(self, input_ids, token_type_ids, attention_mask):
        result = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        logits = self.W(result['last_hidden_state'][:, 0])
        return logits
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=5e-5)
    
    def training_step(self, batch, batch_idx):
        y, input_ids, token_type_ids, attention_mask = batch['labels'], batch['input_ids'], batch['token_type_ids'], batch['attention_mask']
        pred = self(input_ids, token_type_ids, attention_mask)
        loss = self.loss_function(pred, y)
        accuracy = sum(pred.argmax(1) == y)/len(y)
        self.log("training_loss", loss, on_step=True, on_epoch=True)
        self.log("training_accuracy", accuracy, on_step=True, on_epoch=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        y, input_ids, token_type_ids, attention_mask = batch['labels'], batch['input_ids'], batch['token_type_ids'], batch['attention_mask']
        pred = self(input_ids, token_type_ids, attention_mask)
        loss = self.loss_function(pred, y)
        accuracy = sum(pred.argmax(1) == y)/len(y)
        self.log("validation_loss", loss, prog_bar=True, on_step=True, on_epoch=True)
        self.log("validation_accuracy", accuracy, prog_bar=True, on_step=True, on_epoch=True)

In [18]:
checkpoint_callback = callbacks.ModelCheckpoint(dirpath='./bert_lightning/',filename='bert_{epoch}')
model = BertLightning()
trainer = Trainer(max_epochs=3, callbacks=[checkpoint_callback])
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=eval_dataloader)
torch.save(model.state_dict(),'./bert_lightning/weights.pth')

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | 

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.


In [39]:
loaded_model = BertLightning()
loaded_model.load_from_checkpoint('./bert_lightning/bert_epoch=2.ckpt')
loaded_model.load_state_dict(torch.load('./bert_lightning/weights.pth'))

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.bias', 'cls

<All keys matched successfully>

In [40]:
metric = evaluate.load("accuracy")
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = numpy.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

loss_function = torch.nn.CrossEntropyLoss()
it = iter(eval_dataloader)
batch  = next(it)
labels, input_ids, token_type_ids, attention_mask = batch['labels'], batch['input_ids'], batch['token_type_ids'], batch['attention_mask']
print(labels, input_ids, token_type_ids, attention_mask)
with torch.no_grad():
    result =  loaded_model.forward(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
    print(labels)
    print(result)
    print(loss_function(result, labels))
    print(numpy.argmax(result, axis=-1))
    print(compute_metrics((result, labels)))

tensor([2, 4, 1, 4, 3, 4, 2, 3, 2, 3, 0, 0, 3, 2, 2, 1]) tensor([[  101, 14812, 16442,  ...,     0,     0,     0],
        [  101, 19383,  1303,  ...,     0,     0,     0],
        [  101, 12008, 27788,  ...,     0,     0,     0],
        ...,
        [  101,  1753,  1277,  ...,     0,     0,     0],
        [  101, 13377,   112,  ...,     0,     0,     0],
        [  101,   140,  3161,  ...,     0,     0,     0]]) tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]) tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])
tensor([2, 4, 1, 4, 3, 4, 2, 3, 2, 3, 0, 0, 3, 2, 2, 1])
tensor([[-2.3139e+00, -2.7456e+00, -3.5378e-01,  2.5809e+00,  2.5276e+00],
        [-4.0

In [41]:
count = 0
correct = 0
for step, batch in enumerate(eval_dataloader):
    with torch.no_grad():
        labels, input_ids, token_type_ids, attention_mask = batch['labels'], batch['input_ids'], batch['token_type_ids'], batch['attention_mask']
        predictions = loaded_model(input_ids, token_type_ids, attention_mask)
        new_correct = len(labels) * compute_metrics((predictions, labels))['accuracy']
        correct = correct + new_correct
        count = count + len(labels)
        print(count, new_correct, correct)
print(correct/count)

16 9.0 9.0
32 9.0 18.0
48 7.0 25.0
64 13.0 38.0
80 10.0 48.0
96 5.0 53.0
112 10.0 63.0
128 12.0 75.0
144 7.0 82.0
160 12.0 94.0
176 12.0 106.0
192 7.0 113.0
208 10.0 123.0
224 9.0 132.0
240 9.0 141.0
256 8.0 149.0
272 9.0 158.0
288 9.0 167.0
304 10.0 177.0
320 12.0 189.0
336 11.0 200.0
352 8.0 208.0
368 9.0 217.0
384 6.0 223.0
400 10.0 233.0
416 10.0 243.0
432 10.0 253.0
448 12.0 265.0
464 7.0 272.0
480 11.0 283.0
496 7.0 290.0
512 5.0 295.0
0.576171875


In [72]:
result = torch.tensor([[-1e-6, -1e-6, 1e-6], [1e-6,-1e-6,-1e-6]], dtype=torch.float64)
labels = torch.tensor([2,0])
result = torch.tensor([[0, 1.1, 0]], dtype=torch.float64)
labels = torch.tensor([1])
loss_function(result, labels)

tensor(0.5103, dtype=torch.float64)

In [61]:
out = torch.tensor ([[0.0, 1.0, 0.0]])
out_logits = torch.tensor ([[-1.e6, 1.e6, -1.e6]])
print(torch.softmax (out_logits, 1))
target = torch.argmax (out, dim = 1)
target = torch.tensor([1])
print(target)
print(torch.nn.CrossEntropyLoss() (out, target))
print(torch.nn.CrossEntropyLoss() (out_logits, target))


tensor([[0., 1., 0.]])
tensor([1])
tensor(0.5514)
tensor(0.)
