In [2]:
import torch
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import transformers
import pandas as pd

from torch import nn, optim
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torchmetrics.functional import accuracy
from transformers import BertModel, BertConfig
from transformers import AutoModel, BertTokenizerFast

In [3]:
pub_health_train = pd.read_csv("./PUBHEALTH/train.tsv", sep='\t')
pub_health_test = pd.read_csv("./PUBHEALTH/test.tsv", sep='\t')

In [4]:
pub_health_train = pub_health_train[pub_health_train['label'] != 'snopes']
pub_health_train = pub_health_train[['main_text','label']]
pub_health_train = pub_health_train.dropna(subset=['main_text', 'label'])
pub_health_train.head()

Unnamed: 0,main_text,label
0,"""Hillary Clinton is in the political crosshair...",false
1,While the financial costs of screening mammogr...,mixture
2,The news release quotes lead researcher Robert...,mixture
3,"The story does discuss costs, but the framing ...",true
4,"""Although the story didn’t cite the cost of ap...",true


In [5]:
pub_health_test = pub_health_test[['main_text','label']]
pub_health_test = pub_health_test.dropna(subset=['main_text', 'label'])

In [6]:
pub_health_train['label'] = pub_health_train['label'].map({"true":0, "false":1, "unproven":2, "mixture":3})
pub_health_test['label'] = pub_health_test['label'].map({"true":0, "false":1, "unproven":2, "mixture":3})

In [13]:
class HealthClaimClassifier(pl.LightningModule):

    def __init__(self, max_seq_len=512, batch_size=64, learning_rate = 0.001):
        super().__init__()
        self.learning_rate = learning_rate
        self.max_seq_len = max_seq_len
        self.batch_size = batch_size
        self.loss = nn.CrossEntropyLoss()

        self.pretrain_model  = AutoModel.from_pretrained('bert-base-uncased', return_dict=False)
        self.pretrain_model.eval()
        for param in self.pretrain_model.parameters():
            param.requires_grad = False

        self.new_layers = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512,4),
            nn.LogSoftmax(dim=1)
        )

    def prepare_data(self):
        tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased', return_dict=False)

        tokens_train = tokenizer.batch_encode_plus(
            pub_health_train["main_text"].tolist(),
            max_length = self.max_seq_len,
            pad_to_max_length=True,
            truncation=True,
            return_token_type_ids=False
        )

        tokens_test = tokenizer.batch_encode_plus(
            pub_health_test["main_text"].tolist(),
            max_length = self.max_seq_len,
            pad_to_max_length=True,
            truncation=True,
            return_token_type_ids=False
        )

        self.train_seq = torch.tensor(tokens_train['input_ids'])
        self.train_mask = torch.tensor(tokens_train['attention_mask'])
        self.train_y = torch.tensor(pub_health_train["label"].tolist())

        self.test_seq = torch.tensor(tokens_test['input_ids'])
        self.test_mask = torch.tensor(tokens_test['attention_mask'])
        self.test_y = torch.tensor(pub_health_test["label"].tolist())

    def forward(self, encode_id, mask):
        _, output= self.pretrain_model(encode_id, attention_mask=mask)
        output = self.new_layers(output)
        return output


    def train_dataloader(self):
        train_dataset = TensorDataset(self.train_seq, self.train_mask, self.train_y)
        self.train_dataloader_obj = DataLoader(train_dataset, batch_size=self.batch_size)
        return self.train_dataloader_obj


    def test_dataloader(self):
        test_dataset = TensorDataset(self.test_seq, self.test_mask, self.test_y)
        self.test_dataloader_obj = DataLoader(test_dataset, batch_size=self.batch_size)
        return self.test_dataloader_obj


    def training_step(self, batch, batch_idx):
        encode_id, mask, targets = batch
        outputs = self(encode_id, mask) 
        preds = torch.argmax(outputs, dim=1)
        train_accuracy = accuracy(preds, targets)
        loss = self.loss(outputs, targets)
        self.log('train_accuracy', train_accuracy, prog_bar=True, on_step=False, on_epoch=True)
        self.log('train_loss', loss, on_step=False, on_epoch=True)
        return {"loss":loss, 'train_accuracy': train_accuracy}


    def test_step(self, batch, batch_idx):
        encode_id, mask, targets = batch
        outputs = self.forward(encode_id, mask)
        preds = torch.argmax(outputs, dim=1)
        test_accuracy = accuracy(preds, targets)
        loss = self.loss(outputs, targets)
        return {"test_loss":loss, "test_accuracy":test_accuracy}
        

    def test_epoch_end(self, outputs):
        test_outs = []
        for test_out in outputs:
            out = test_out['test_accuracy']
            test_outs.append(out)
        total_test_accuracy = torch.stack(test_outs).mean()
        self.log('total_test_accuracy', total_test_accuracy, on_step=False, on_epoch=True)
        return total_test_accuracy

    def configure_optimizers(self):
        params = self.parameters()
        optimizer = optim.Adam(params=params, lr = self.learning_rate)
        return optimizer

In [14]:
model = HealthClaimClassifier()

trainer = pl.Trainer(fast_dev_run=True, gpus=1)
trainer.fit(model)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Run

Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_steps=1` reached.


In [16]:
model = HealthClaimClassifier()
trainer = pl.Trainer(max_epochs=10, gpus=-1)
trainer.fit(model)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOC

Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


In [17]:
trainer.test()

  rank_zero_warn(
Restoring states from the checkpoint path at /home/sd/works/practice-torch/pl-notes/lightning_logs/version_1/checkpoints/epoch=9-step=1540.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at /home/sd/works/practice-torch/pl-notes/lightning_logs/version_1/checkpoints/epoch=9-step=1540.ckpt
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
   total_test_accuracy      0.5993106961250305
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'total_test_accuracy': 0.5993106961250305}]