In [1]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, Dataset, Subset
# from torchvision import transforms, models
import pytorch_lightning as pl
from pytorch_lightning.utilities.model_summary import ModelSummary
from pytorch_lightning.loggers import WandbLogger
from PIL import Image
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR
import wandb
import torchmetrics
import sys
from ray.tune.integration.pytorch_lightning import TuneReportCallback
from ray import tune
from transformers import DistilBertTokenizer

# Lightning moduel

In [2]:
class textClassifier(pl.LightningModule):
    def __init__(self, vocab_size=10_000, embedding_dim=128, hidden_dim=64, learning_rate=3e-4):
        super().__init__()
        self.save_hyperparameters()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.relu = nn.Sequential(
            nn.Linear(embedding_dim, 1)
        )
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        x = self.embedding(x)
#         x = nn.utils.rnn.pack_padded_sequence(x, text_lengths, batch_first=True)
        x, (hidden, cell) = self.lstm(x)
#         x = hidden.squeeze()
        x = self.relu(x)
        y_hat = self.sigmoid(x)
        
        return y_hat
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        print('y_hat:', y_hat.dtype)
        print('y:', y.dtype)
        loss = F.binary_cross_entropy(y_hat, y)
        pred = loss.round()
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        
        return {'loss':loss, 'pred':pred}
        
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        print('y_hat:', y_hat.dtype)
        print('y:', y.dtype)
        loss = F.binary_cross_entropy(y_hat, y)
        pred = loss.round()
        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        
        return {'val_loss':loss, 'pred':pred}
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adamax(self.parameters(), lr=self.hparams.learning_rate)
        return [optimizer]    

# Dataset

In [3]:
import contractions
from bs4 import BeautifulSoup
from unidecode import unidecode
import string
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
import re

class textDataset(Dataset):
    def __init__(self, data_dir):
        self.stop_words = set(stopwords.words('english'))
        self.lemmatizer = WordNetLemmatizer()
        self.tweets = pd.read_csv(data_dir)
        self.tweets['keyword'] = self.tweets['keyword'].fillna('')
        self.tweets['tweet'] = self.tweets['keyword'] + ' ' + self.tweets['text']
        self.tweets = self.text_preprocessing(self.tweets)
        self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    
    def rem_urls(self, data):
        regex = re.compile(f"https?://(www\.)?(\w+)(\.\w+)(/\w*)?")
        data = re.sub(regex, "", data)
        return data
    
    def rem_emails(self, data):
        regex = re.compile("([A-Za-z0-9]+[.-_])*[A-Za-z0-9]+@[A-Za-z0-9-]+(\.[A-Z|a-z]{2,})+")
        data = re.sub(regex, "", data)
        return data
    
    def rem_mention(self, data):
        regex = re.compile('@\w+')
        data = re.sub(regex, '', data)
        return data
    
    def rem_accent(self, data):
        data = unidecode(data)
        return data
    
    def rem_unicode(self, data):
        data = data.encode("ascii", "ignore").decode()
        return data
    
    def rem_punc(self, data):
        data = re.sub(f"[{string.punctuation}]", " ", data)
        return data
    
    def clean_numbers(self, data):
        data = re.sub('[0-9]{5,}', '#'*5, data)
        data = re.sub('[0-9]{4}', '#'*4, data)
        data = re.sub('[0-9]{3}', '#'*3, data)
        data = re.sub('[0-9]{2}', '#'*2, data)
        return data
    
    
    def rem_stopwords(self, data):
        return " ".join([word for word in str(data).split() if word not in self.stop_words])
    
    def rem_extra_space(self, data):
        data = re.sub(' +', ' ', data).strip()
        return data
    
    
    def lemmatize_data(self, data):
        words = [self.lemmatizer.lemmatize(word) for word in data.split()]
        data = ' '.join(words)
        return data
    
    def text_preprocessing(self, data):
        data['tweet'] = data['tweet'].str.lower()
        data['tweet'] = data['tweet'].apply(contractions.fix)
        data['tweet'] = data['tweet'].apply(self.rem_urls)
        data['tweet'] = data['tweet'].apply(self.rem_emails)
        data['tweet'] = data['tweet'].apply(lambda x: BeautifulSoup(x).get_text())
        data['tweet'] = data['tweet'].apply(self.rem_mention)
        data['tweet'] = data['tweet'].str.replace(':\(', 'sadness ')
        data['tweet'] = data['tweet'].str.replace(r':\)[$|\s]*', 'happiness ')
        data['tweet'] = data['tweet'].str.replace(r'\;\)[$|\s]*', 'happiness ')
        data['tweet'] = data['tweet'].apply(self.rem_accent)
        data['tweet'] = data['tweet'].apply(self.rem_unicode)
        data['tweet'] = data['tweet'].apply(self.rem_punc)
        data['tweet'] = data['tweet'].apply(self.clean_numbers)
        data['tweet'] = data['tweet'].apply(self.rem_stopwords)
        data['tweet'] = data['tweet'].apply(self.rem_extra_space)
        data['tweet'] = data['tweet'].apply(self.lemmatize_data)
        return data
        
    def __len__(self):
        return len(self.tweets)
    
    def __getitem__(self, idx):
        text = self.tweets.iloc[idx, -1]
        encoding = self.tokenizer.batch_encode_plus([text], return_tensors="pt", max_length=128, pad_to_max_length=True)
        input_ids = encoding['input_ids'].flatten()
        attention_mask = encoding['attention_mask'].flatten()
        label = [self.tweets.iloc[idx, -2]]
        label = torch.tensor(label, dtype=torch.float32)
        
        return {'text':text, 'input_ids':input_ids, 'attention_mask':attention_mask, 'labels':label}

In [4]:
def train_val_dataset(dataset, train_split=0.75, val_split=0.25):
    train_idx, val_idx = train_test_split(list(range(len(dataset))), train_size=train_split, test_size=val_split)
    datasets = {'train':Subset(dataset, train_idx), 'val':Subset(dataset, val_idx)}
    return datasets

In [5]:
class textDatamodule(pl.LightningDataModule):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size
        self.tokenizer = None
    def setup(self, stage=None):
        data = textDataset("./nlp-getting-started/train.csv")
        self.tokenizer = data.tokenizer
        data = train_val_dataset(data)
        self.train_dataset = data['train']
        self.val_dataset = data['val']
    
    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
    
    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)    

In [6]:
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath='./check_point/',
    filename='{epoch}-{train_loss:.4f}-{val_loss:.4f}',
    monitor='val_loss',
    mode='min',
    save_top_k=2
)

In [7]:
wandb.init(
      mode='disabled',
      # Set the project where this run will be logged
      project="histopathologic-cancer-classifier", 
      # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
      name=f"Test1", 
      # Track hyperparameters and run metadata
      config={
      "learning_rate": 0.0005,
      "data_size": 1,
      "batch_size":32,
      })
wandb_logger = WandbLogger()

  rank_zero_warn(


In [8]:
trainer = pl.Trainer(accelerator='gpu',
                    devices=1,
                    max_epochs=100,
                    logger=wandb_logger,
                    callbacks=[checkpoint_callback]
                    )

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [9]:
import nltk
nltk.download('stopwords')
nltk.download('wordnet')

[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/haeinpark/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     /Users/haeinpark/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


True

In [10]:
dm = textDatamodule(batch_size=wandb.config['batch_size'])
dm.setup()

  data['tweet'] = data['tweet'].str.replace(':\(', 'sadness ')
  data['tweet'] = data['tweet'].str.replace(r':\)[$|\s]*', 'happiness ')
  data['tweet'] = data['tweet'].str.replace(r'\;\)[$|\s]*', 'happiness ')


In [11]:
text_encoder = textClassifier(vocab_size=dm.tokenizer.vocab_size, embedding_dim=64, hidden_dim=32)

In [None]:
 trainer.fit(text_encoder, dm)

  data['tweet'] = data['tweet'].str.replace(':\(', 'sadness ')
  data['tweet'] = data['tweet'].str.replace(r':\)[$|\s]*', 'happiness ')
  data['tweet'] = data['tweet'].str.replace(r'\;\)[$|\s]*', 'happiness ')

  | Name      | Type       | Params
-----------------------------------------
0 | embedding | Embedding  | 2.0 M 
1 | lstm      | LSTM       | 12.5 K
2 | relu      | Sequential | 65    
3 | sigmoid   | Sigmoid    | 0     
-----------------------------------------
2.0 M     Trainable params
0         Non-trainable params
2.0 M     Total params
7.864     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
