In [None]:
#import correct version
import os

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
import matplotlib.pyplot as plt
%matplotlib inline
import pytorch_lightning as pl
from torchmetrics.functional import accuracy
import transformers
from transformers import BertModel, BertConfig
from transformers import AutoModel, BertTokenizerFast
import pandas as pd

In [None]:
def resolve_data_dir(env_var, default):
    """Return the directory named by environment variable `env_var`, else `default`.

    An unset or empty variable falls back to `default`, so the notebook behaves
    exactly as before unless the variable is deliberately set.
    """
    return os.environ.get(env_var) or default


# Defaults are the original local Windows folders. Set HEALTH_CLAIMS_IN_DATA_PATH /
# HEALTH_CLAIMS_OUT_DATA_PATH to run the notebook on another machine without
# editing this cell. Both values stay plain strings.
IN_DATA_PATH = resolve_data_dir('HEALTH_CLAIMS_IN_DATA_PATH', r'C:\Data Sciences\Data\inputs')
OUT_DATA_PATH = resolve_data_dir('HEALTH_CLAIMS_OUT_DATA_PATH', r'C:\Data Sciences\Data\outputs')

In [None]:
# Load the raw tab-separated train and test splits from the input folder.
# os.path.join picks the separator for the current OS instead of a literal
# backslash; on Windows the resulting paths are the same as before.
pub_health_train = pd.read_csv(os.path.join(IN_DATA_PATH, "train.tsv"), sep='\t')
pub_health_test = pd.read_csv(os.path.join(IN_DATA_PATH, "test.tsv"), sep='\t')

In [None]:
# Training split: discard rows labelled 'snopes', keep only the text and label
# columns, then drop any row that is missing either of them.
pub_health_train = (
    pub_health_train
    .loc[lambda frame: frame['label'] != 'snopes', ['main_text', 'label']]
    .dropna(subset=['main_text', 'label'])
)

pub_health_train.head()

In [None]:
# Test split: apply the same cleaning as the training split. Rows labelled
# 'snopes' are removed here too, because that value has no class id in the label
# mapping and would otherwise become a NaN label. Then keep only the text and
# label columns and drop rows missing either.
pub_health_test = (
    pub_health_test
    .loc[lambda frame: frame['label'] != 'snopes', ['main_text', 'label']]
    .dropna(subset=['main_text', 'label'])
)

In [None]:
# String label -> integer class id expected by the classifier head.
LABEL_TO_ID = {"true":0, "false":1, "unproven":2, "mixture":3}


def encode_labels(frame):
    """Return a copy of `frame` whose 'label' column holds integer class ids.

    Rows whose label is not a key of LABEL_TO_ID are dropped instead of being
    left as NaN, so the column stays an integer dtype. A frame whose labels are
    already integers is returned unchanged (as a copy), which makes re-running
    this cell safe.
    """
    encoded = frame.copy()
    if pd.api.types.is_integer_dtype(encoded['label']):
        return encoded
    encoded['label'] = encoded['label'].map(LABEL_TO_ID)
    encoded = encoded.dropna(subset=['label'])
    encoded['label'] = encoded['label'].astype('int64')
    return encoded


pub_health_train = encode_labels(pub_health_train)
pub_health_test = encode_labels(pub_health_test)

In [None]:
class HealthClaimClassifier(pl.LightningModule):
    """Four-class health-claim classifier: frozen BERT encoder plus a trainable head.

    The pooled output of ``bert-base-uncased`` is passed through ``new_layers``,
    which emits log-probabilities over the four label ids. Training and test
    data are read from the notebook-level DataFrames ``pub_health_train`` and
    ``pub_health_test`` (columns ``main_text`` and ``label``) when
    ``prepare_data`` runs.

    Args:
        max_seq_len: Token length every text is padded or truncated to.
        batch_size: Batch size used by both dataloaders.
        learning_rate: Learning rate handed to Adam.
    """

    def __init__(self, max_seq_len=512, batch_size=128, learning_rate = 0.001):
        super().__init__()
        self.learning_rate = learning_rate
        self.max_seq_len = max_seq_len
        self.batch_size = batch_size
        # The head already ends in LogSoftmax, so the matching criterion is
        # NLLLoss. CrossEntropyLoss would apply log-softmax a second time.
        self.loss = nn.NLLLoss()

        # Pretrained encoder used purely as a fixed feature extractor.
        self.pretrain_model  = AutoModel.from_pretrained('bert-base-uncased')
        self.pretrain_model.eval()
        for param in self.pretrain_model.parameters():
            param.requires_grad = False

        # Trainable classification head: 768-d pooled vector -> 4 log-probabilities.
        self.new_layers = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512,4),
            nn.LogSoftmax(dim=1)
        )

    def train(self, mode=True):
        """Set train/eval mode on the module while keeping the encoder in eval mode.

        ``nn.Module.train`` recurses into every child, which would re-enable
        dropout inside the frozen encoder; it is switched back to eval here.
        """
        super().train(mode)
        self.pretrain_model.eval()
        return self

    def _encode_texts(self, tokenizer, texts):
        """Tokenize `texts` into fixed-length ``(input_ids, attention_mask)`` tensors."""
        encoded = tokenizer(
            texts,
            max_length=self.max_seq_len,
            padding='max_length',  # replaces the deprecated pad_to_max_length=True
            truncation=True,
            return_token_type_ids=False,
        )
        return torch.tensor(encoded['input_ids']), torch.tensor(encoded['attention_mask'])

    @staticmethod
    def _batch_accuracy(log_probs, targets):
        """Fraction of rows in the batch whose arg-max class equals the target."""
        preds = torch.argmax(log_probs, dim=1)
        # Computed directly with torch so it does not depend on which
        # torchmetrics `accuracy` signature is installed.
        return (preds == targets).float().mean()

    def prepare_data(self):
        """Tokenize the train and test texts and store them as tensors on ``self``.

        NOTE(review): state is assigned here, as in the original. Lightning
        recommends doing this in ``setup`` for multi-process training; confirm
        before running on more than one device.
        """
        tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

        self.train_seq, self.train_mask = self._encode_texts(
            tokenizer, pub_health_train["main_text"].tolist()
        )
        # Class-index targets must be integer (long) tensors for the loss.
        self.train_y = torch.tensor(pub_health_train["label"].tolist(), dtype=torch.long)

        self.test_seq, self.test_mask = self._encode_texts(
            tokenizer, pub_health_test["main_text"].tolist()
        )
        self.test_y = torch.tensor(pub_health_test["label"].tolist(), dtype=torch.long)

    def forward(self, encode_id, mask):
        """Return per-class log-probabilities for a batch of token ids and masks."""
        # The encoder is frozen, so no autograd graph is needed for it.
        with torch.no_grad():
            encoder_out = self.pretrain_model(encode_id, attention_mask=mask)
        # Index 1 is the pooled output for both the tuple return and the
        # ModelOutput return; unpacking a ModelOutput would yield its keys.
        pooled = encoder_out[1]
        return self.new_layers(pooled)

    def train_dataloader(self):
        """Build the training DataLoader, reshuffled every epoch."""
        train_dataset = TensorDataset(self.train_seq, self.train_mask, self.train_y)
        self.train_dataloader_obj = DataLoader(
            train_dataset, batch_size=self.batch_size, shuffle=True
        )
        return self.train_dataloader_obj

    def test_dataloader(self):
        """Build the test DataLoader, kept in dataset order."""
        test_dataset = TensorDataset(self.test_seq, self.test_mask, self.test_y)
        self.test_dataloader_obj = DataLoader(test_dataset, batch_size=self.batch_size)
        return self.test_dataloader_obj

    def training_step(self, batch, batch_idx):
        """Compute loss and accuracy for one training batch and log both per epoch."""
        encode_id, mask, targets = batch
        outputs = self(encode_id, mask)
        train_accuracy = self._batch_accuracy(outputs, targets)
        loss = self.loss(outputs, targets)
        self.log('train_accuracy', train_accuracy, prog_bar=True, on_step=False, on_epoch=True)
        self.log('train_loss', loss, on_step=False, on_epoch=True)
        return {"loss":loss, 'train_accuracy': train_accuracy}

    def test_step(self, batch, batch_idx):
        """Compute loss and accuracy for one test batch.

        The batch size is returned as well so ``test_epoch_end`` can weight
        each batch by its number of samples.
        """
        encode_id, mask, targets = batch
        outputs = self(encode_id, mask)
        test_accuracy = self._batch_accuracy(outputs, targets)
        loss = self.loss(outputs, targets)
        return {
            "test_loss":loss,
            "test_accuracy":test_accuracy,
            "batch_size": targets.size(0),
        }

    def test_epoch_end(self, outputs):
        """Log and return the overall test accuracy across all test batches."""
        batch_accuracies = torch.stack([out['test_accuracy'] for out in outputs])
        # Weight by batch size: a plain mean of batch means over-weights a
        # smaller final batch.
        batch_sizes = torch.tensor(
            [out['batch_size'] for out in outputs],
            dtype=batch_accuracies.dtype,
            device=batch_accuracies.device,
        )
        total_test_accuracy = (batch_accuracies * batch_sizes).sum() / batch_sizes.sum()
        self.log('total_test_accuracy', total_test_accuracy, on_step=False, on_epoch=True)
        return total_test_accuracy

    def configure_optimizers(self):
        """Create an Adam optimizer over the trainable (non-frozen) parameters."""
        trainable_params = [p for p in self.parameters() if p.requires_grad]
        optimizer = optim.Adam(params=trainable_params, lr = self.learning_rate)
        return optimizer

In [9]:
# Seed the random number generators before the model is built so the head's
# random initialisation and dropout are repeatable from run to run.
pl.seed_everything(42)

model = HealthClaimClassifier()

# fast_dev_run=True is Lightning's quick smoke-test mode; set it to False for a
# real training run.
trainer = pl.Trainer(fast_dev_run=True)
trainer.fit(model)