The model in `4_binary_classification` was quite simple, but we had to write a lot of redundant code for some basic work like getting evaluation metrics. We may write cleaner code with [pytorch-lightning](https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html).

In [1]:
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import pytorch_lightning as pl
import pytorch_lightning.metrics.sklearns as plm
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import random_split

In [2]:
# See https://pytorch-lightning.readthedocs.io/en/latest/lightning-module.html
# for the lifecycle in LightningModule.
class LogisticRegressionModel(pl.LightningModule):

    def __init__(self, dat, test_dat, hparams, *args, **kwargs):
        super().__init__()
        self.hparams = hparams
        self.dat = dat
        self.test_dat = test_dat
        
        self.l1 = nn.Linear(test_dat.x.shape[1], 1)

    def forward(self, x):
        y_hat = self.l1(x)
        return y_hat
    
    def setup(self, step):
        # step is either "fit" or "test"; not relevant

        # Split the dataset into a training set and a validation set
        validation_set_size = int(self.dat.x.shape[0] * self.hparams.validation_split)
        training_set_size = self.dat.x.shape[0] - validation_set_size
        train_dat, val_dat = random_split(dat, [training_set_size, validation_set_size])
        self.train_dat = train_dat
        self.val_dat = val_dat

    def configure_optimizers(self):
        return [torch.optim.RMSprop(self.parameters(), lr=self.hparams.learning_rate)]

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)  # or just self(x)
        loss = nn.BCEWithLogitsLoss()(y_hat, y)
        
        y_pred = (y_hat.detach().clone().sigmoid() > self.hparams.classification_threshold).float()
        
        logs = {
            "train_loss": loss,
            "train_accuracy": plm.Accuracy()(y_pred, y),
            "train_AUC": plm.AUROC()(y_pred, y)
        }
        return {"loss": loss, "log": logs}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        return {"val_loss": nn.BCEWithLogitsLoss()(y_hat, y)}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        logs = {"val_loss": avg_loss}
        return {"avg_val_loss": avg_loss, "log": logs}
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        return {"test_loss": nn.BCEWithLogitsLoss()(y_hat, y)}
    
    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
        logs = {"test_loss": avg_loss}
        return {"avg_test_loss": avg_loss, "log": logs}

    def train_dataloader(self):
        return DataLoader(self.train_dat, batch_size=self.hparams.batch_size, num_workers=8)

    def val_dataloader(self):
        return DataLoader(self.val_dat, batch_size=self.hparams.batch_size, num_workers=4)
    
    def test_dataloader(self):
        return DataLoader(self.test_dat, batch_size=self.hparams.batch_size, num_workers=4)

In [3]:
# Prepare data
## Download
train_df = pd.read_csv("https://download.mlcc.google.com/mledu-datasets/california_housing_train.csv")
test_df = pd.read_csv("https://download.mlcc.google.com/mledu-datasets/california_housing_test.csv")
train_df = train_df.reindex(np.random.permutation(train_df.index)) # shuffle the training set

## Calculate the Z-scores of each column
train_df_mean = train_df.mean()
train_df_std = train_df.std()
train_df_norm = (train_df - train_df_mean)/train_df_std

test_df_mean = test_df.mean()
test_df_std  = test_df.std()
test_df_norm = (test_df - test_df_mean)/test_df_std

## Create true label
threshold_in_Z = 1.0
train_df_norm["median_house_value_is_high"] = (train_df_norm["median_house_value"] > threshold_in_Z).astype(float)
test_df_norm["median_house_value_is_high"] = (test_df_norm["median_house_value"] > threshold_in_Z).astype(float)

## Map features and labels into a tensor dataset
class HousingDataset(Dataset):
    def __init__(self, X, y):
        self.x = torch.tensor(X.to_numpy().reshape(-1, X.shape[1]), dtype=torch.float)
        self.y = torch.tensor(y.to_numpy().reshape(-1, 1), dtype=torch.float)
    
    def __getitem__(self, idx):
        return (self.x[idx], self.y[idx])
    
    def __len__(self):
        return len(self.y)


dat = HousingDataset(train_df_norm[["median_income", "total_rooms"]], train_df_norm["median_house_value_is_high"])
test_dat = HousingDataset(test_df_norm[["median_income", "total_rooms"]], test_df_norm["median_house_value_is_high"])

In [4]:
# Remove old versions for TensorBoard
! rm -rf ./lightning_logs

In [None]:
%load_ext tensorboard

%tensorboard --logdir lightning_logs --bind_all

In [6]:
# Hyperparameters
hparams = {
    "learning_rate": 0.001,
    "batch_size": 100,
    "classification_threshold": 0.35,
    "validation_split": 0.2
}
epochs = 20

# Train model
trainer = pl.Trainer(gpus=[0], max_epochs=epochs)
model = LogisticRegressionModel(dat, test_dat, hparams)

trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name | Type   | Params
--------------------------------
0 | l1   | Linear | 3     


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




1

In [7]:
trainer.test(model)

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
TEST RESULTS
{'avg_test_loss': tensor(0.3132, device='cuda:0'),
 'test_loss': tensor(0.3132, device='cuda:0')}
--------------------------------------------------------------------------------



{'avg_test_loss': 0.3131873607635498, 'test_loss': 0.3131873607635498}