## <ins> Tutorial: Linear Regression </ins>
The goal of this notebook is to introduce PyTorch Lightning with a simple example. By the end of this tutorial, I hope you will be comfortable with the following:
- The basics of PyTorch
- The structure of PyTorch Lightning

## <ins> Configuration Parameters </ins>
Every experiment needs some control parameters, so here is a basic way to initialize them and prepare them for development. 

In [1]:
# Author: MINDFUL
# Purpose: Configuration ( Linear Regression )

#--------------------------------
# Parameters: All Paths (I / O)
#--------------------------------

path_results = "/develop/results/linear_regression/"
path_dataset = "/develop/data/regression/linear/data.csv"

#-------------------------------
# Parameters: Training Model
#-------------------------------

# Config: Validation Rate

valid_rate = 1

# Config: Randomization

seed = 123 

# Config: CPU 

num_workers = 1

# Config: GPU

use_gpu = 0
gpu_list = [0, 1]

# Config: Gradient Descent

batch_size = 16
num_epochs = 100 
learning_rate = 0.01

# Config: Logger 
# - 0 : Tensorboard
# - 1 : Custom Logger

logger_choice = 0

# Create: Parameter Container 

params = { "path_results": path_results, "path_dataset": path_dataset,
           "valid_rate": valid_rate, "seed": seed, "num_workers": num_workers, "use_gpu": use_gpu, 
           "gpu_list": gpu_list, "batch_size": batch_size, "num_epochs": num_epochs, "learning_rate": learning_rate }


## <ins> Verbose Warnings </ins>

PyTorch Lightning is notorious for its warnings. Some are helpful during debugging; others are only suggestions for improving performance. Because of this, here is how to filter them. Note that `warnings.filterwarnings("ignore")` silences every warning, including the useful ones.

In [2]:
#--------------------------------
# Remove: Irrelevant Warnings
#--------------------------------

import warnings

warnings.filterwarnings("ignore")
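
If silencing everything feels too aggressive, a narrower filter is possible. The sketch below uses only the standard `warnings` module and ignores `UserWarning` messages that mention `num_workers`, while letting other warnings through. The warning texts are made up for illustration and are not actual PyTorch Lightning messages.

```python
import warnings

# Record warnings inside this block so we can inspect what got through
with warnings.catch_warnings(record=True) as caught:

    # Show everything by default ...
    warnings.simplefilter("always")

    # ... except UserWarnings whose message mentions num_workers
    warnings.filterwarnings("ignore", category=UserWarning, message=".*num_workers.*")

    # Illustrative warnings (invented text, not real library output)
    warnings.warn("example: consider increasing num_workers", UserWarning)
    warnings.warn("example: this feature is deprecated", DeprecationWarning)

for item in caught:
    print(item.category.__name__, "-", item.message)
```

Filters added later take precedence over earlier ones, which is why the specific "ignore" rule wins over the general "always" rule.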

## <ins> Python Libraries </ins> 

A big strength of Python is its large library support. Let's import some libraries and discuss their importance.

Standard Libraries

- warnings: Controls how warnings are reported or suppressed

Third-Party Libraries

- numpy: Linear algebra, data representation (e.g., matrices, vectors), and more
- matplotlib: Visualizations / plots
- torch: PyTorch, the underlying deep learning library
- pytorch_lightning: PyTorch with additional tools that support organization and simplification

Custom Libraries

- Loader: Convert dataset to pytorch format
- Logger: Experiment logging tool for results


In [3]:
#--------------------------------
# Import: Basic Python Libraries
#--------------------------------

import warnings
import numpy as np
import matplotlib.pyplot as plt

#--------------------------------
# Import: Pytorch Libraries
#--------------------------------

import torch
import torch.nn as nn

from typing import Optional
from torch.utils.data import DataLoader
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loggers.base import rank_zero_experiment
from pytorch_lightning import LightningModule, LightningDataModule

#--------------------------------
# Import: Custom Python Libraries
#--------------------------------

import utils

from custom_logger import Logger
from loader import Dataset as Pytorch_Dataset


In [4]:
# Now let's prepare our dataset loader. It's just a CSV file, so we can parse it line by line.

In [5]:
#--------------------------------
# Initialize: Custom Dataset 
#--------------------------------

class Dataset:

    def __init__(self, samples, labels):

        self.labels = labels
        self.samples = samples

#--------------------------------
# Load: Training Dataset (.CSV)
#--------------------------------

def load_data(path):

    # Parse: Each CSV Row Into Floats (file is closed automatically)

    data = []
    with open(path, "r") as data_file:
        for line in data_file:
            if line.strip():
                data.append([ float(ele) for ele in line.split(",") ])

    data = np.asarray(data)

    samples, labels = data[:, :-1], data[:, -1]

    return Dataset(samples, labels)
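
As an aside, the same parsing can be sketched with `np.loadtxt`. This is only an alternative under the assumption that the file is purely numeric and comma separated, with the label in the last column. The CSV text below is a made-up stand-in for the real file.

```python
import io
import numpy as np

# Hypothetical CSV contents: one feature column, one label column
csv_text = "0.0,1.0\n1.0,3.0\n2.0,5.0\n"

# np.loadtxt accepts a path or a file-like object; ndmin=2 keeps a
# 2D array even if the file has a single row
data = np.loadtxt(io.StringIO(csv_text), delimiter=",", ndmin=2)

# Split: all but the last column are samples, the last column is labels
samples, labels = data[:, :-1], data[:, -1]

print(samples.shape, labels.shape)
```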


In [6]:
# For PyTorch, datasets need a specific format in order to take advantage of PyTorch utilities (e.g., dataloaders).
# Because of this, let's make a data module that satisfies this requirement.
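
The custom `loader.Dataset` is not shown in this notebook, so here is a rough guess at the "specific format" it provides. A map-style dataset only needs `__len__` and `__getitem__`. The class name `ArrayDataset` and the toy data are hypothetical. In practice it is more common to subclass `torch.utils.data.Dataset`, which is skipped here to keep the sketch dependency-light.

```python
import numpy as np
from types import SimpleNamespace

class ArrayDataset:
    """Hypothetical stand-in for loader.Dataset (the real one is not shown)."""

    def __init__(self, data):
        # Assumes `data` exposes .samples and .labels, like the Dataset container above
        self.samples = np.asarray(data.samples, dtype = np.float32)
        self.labels = np.asarray(data.labels, dtype = np.float32)

    def __len__(self):
        # Number of (sample, label) pairs
        return len(self.labels)

    def __getitem__(self, index):
        # One (sample, label) pair; a DataLoader stacks these into batches
        return self.samples[index], self.labels[index]

# Toy data for illustration: y = 2x + 1
toy = SimpleNamespace(samples = np.array([[0.0], [1.0], [2.0]]),
                      labels = np.array([1.0, 3.0, 5.0]))

dataset = ArrayDataset(toy)
sample, label = dataset[1]
print(len(dataset), sample, label)
```

Casting to float32 matters because `nn.Linear` uses float32 weights by default, while NumPy parses floats as float64.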

In [7]:
#--------------------------------
# Create: Lightning Data Module
#--------------------------------

class PLDM(LightningDataModule):

    def __init__(self, params):
        
        super().__init__() 
                           
        # Load: Dataset Parameters
                           
        self.data = params["train"]
                           
        # Load: Processing Parameters

        self.batch = params["batch_size"]
        self.workers = params["num_workers"]
        
    #----------------------------
    # Create: Training Datasets 
    #----------------------------
                           
    def setup(self, stage: Optional[str] = None):

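        # Create: Pytorch Datasets
        # Note: The same data backs both splits, so validation
        # visualizes the fit rather than measuring generalization

        self.train = Pytorch_Dataset(self.data)
        self.valid = Pytorch_Dataset(self.data)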

    #----------------------------
    # Create: Training DataLoader
    #----------------------------

    def train_dataloader(self):

        return DataLoader( self.train, batch_size = self.batch,
                           num_workers = self.workers, shuffle = True, persistent_workers = True )

    #----------------------------
    # Create: Validation Loader
    #----------------------------

    def val_dataloader(self):

        return DataLoader( self.valid, batch_size = self.batch,
                           num_workers = self.workers, persistent_workers = True )
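
The `setup` method above wraps the same data for both training and validation, which is fine for visualizing the fit. If a held-out validation set is wanted instead, one option is to split the row indices first. The helper `split_indices` and the 80/20 ratio below are assumptions for illustration, not part of the original notebook.

```python
import numpy as np

def split_indices(num_samples, valid_fraction = 0.2, seed = 123):
    """Hypothetical helper: shuffle row indices and split them in two."""

    # Seeded generator so the split is reproducible
    rng = np.random.default_rng(seed)
    order = rng.permutation(num_samples)

    # First chunk goes to validation, the rest to training
    num_valid = int(round(num_samples * valid_fraction))

    return order[num_valid:], order[:num_valid]

train_idx, valid_idx = split_indices(10)
print(len(train_idx), len(valid_idx))
```

The resulting index arrays could then select rows of `samples` and `labels` before wrapping them, or be passed to `torch.utils.data.Subset`.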


In [8]:
#--------------------------------
# Initialize: Lightning Model
#--------------------------------

class Linear_Regression(LightningModule):

    def __init__(self, params):

        super().__init__()

        # Load: Model Parameters
        
        self.max_epochs = params["num_epochs"]
        self.learning_rate = params["learning_rate"]

        # Initialize: Regression Model 

        self.regressor = nn.Linear(1, 1)

    #----------------------------
    # Create: Objective Function
    #----------------------------

    def objective(self, preds, labels):
    
        # Format: Labels

        labels = labels.reshape(preds.size()).type(preds.type())

        # Objective: Mean Squared Error

        cost = nn.MSELoss() 

        loss = cost(preds, labels) 

        # Logging: Loss

        self.log("loss", loss, on_step = True, on_epoch = True)

        return loss

    #----------------------------
    # Create: Optimizer Function
    #----------------------------

    def configure_optimizers(self):

        optimizer = torch.optim.Adam(self.parameters(), lr = self.learning_rate)

        return optimizer

    #----------------------------
    # Create: Model Forward Pass
    #----------------------------

    def forward(self, samples):

        return self.regressor(samples)

    #----------------------------
    # Create: Train Cycle (Batch)
    #----------------------------

    def training_step(self, batch, batch_idx):

        # Load: Data Batch

        samples, labels = batch

        preds = self(samples)

        # Calculate: Training Loss

        loss = self.objective(preds, labels)
       
        return loss

    #----------------------------
    # Run: Post Training Script
    #----------------------------

    def training_epoch_end(self, train_step_outputs): 

        # Update: Training Plots

        if(logger_choice == 1):
            
            if(self.current_epoch > 0):

                logger = self.logger.experiment

                logger.log_training_loss(self.current_epoch)

                # Finalize: Learned Features & Metrics ( Video )

                if(self.current_epoch == self.max_epochs - 1):

                    logger.finalize_results()

    #----------------------------
    # Create: Validation Cycle 
    #----------------------------

    def validation_step(self, batch, batch_idx):

        samples, labels = batch

        preds = self(samples)
    
        return samples, labels, preds

    #----------------------------
    # Run: Post Validation Script
    #----------------------------

    def validation_epoch_end(self, val_step_outputs): 

        # Organize: Validation Outputs
 
        all_samples, all_labels, all_preds = [], [], []
    
        for group in val_step_outputs:

            samples, labels, preds = group

            all_labels.append( labels )
            all_samples.append( samples )
            all_preds.append( preds.detach() )

        all_preds = torch.cat(all_preds)
        all_labels = torch.cat(all_labels)
        all_samples = torch.cat(all_samples)

        # Logger: Visualizations  

        
        if(logger_choice == 0):
        
            self.log_linear_regression(all_samples, all_labels, all_preds, self.current_epoch)
            
        else:
            
            logger = self.logger.experiment
            logger.log_linear_regression(all_samples, all_labels, all_preds, self.current_epoch)
            
    
    #----------------------------
    # Logging: Regression Plot
    #----------------------------

    def log_linear_regression(self, samples, labels, preds, epoch, z = 4, f_s = 20, p_s = (15, 11)):

        # Format: Feature Vectors (moved to CPU so matplotlib can plot them)

        preds = torch.squeeze(preds).cpu()
        labels = torch.squeeze(labels).cpu()
        samples = torch.squeeze(samples).cpu()

        # Format: Plot

        plt.style.use("seaborn")

        # Assign: Figure Name

        name = str(epoch).zfill(z) + ".png"

        # Visualize: Results

        fig = plt.figure(figsize = p_s)

        ax = fig.add_subplot()

        ax.scatter( samples, labels, c = "blue")
        ax.plot( samples, preds, c = "red")
        ax.set_xlabel("x1", fontsize = f_s)
        ax.set_ylabel("x2", fontsize = f_s)

        fig.suptitle("Learned Linear Regression", fontsize = f_s)

        plt.subplots_adjust(top = 0.90)

        logger = self.logger.experiment
        logger.add_figure(name,  plt.gcf())
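
To see what the module above organizes, it can help to compare it with a hand-written PyTorch loop. The sketch below mirrors `forward`, `objective`, `configure_optimizers`, and `training_step` in plain PyTorch. The toy data (y = 2x + 1), the learning rate, and the step count are assumptions chosen for illustration, and the loop uses the full dataset per step rather than mini-batches.

```python
import torch
import torch.nn as nn

torch.manual_seed(123)

# Hypothetical toy data: y = 2x + 1, noise-free
samples = torch.linspace(-1.0, 1.0, 64).reshape(-1, 1)
labels = 2.0 * samples + 1.0

# Same building blocks as the Lightning module
regressor = nn.Linear(1, 1)
cost = nn.MSELoss()
optimizer = torch.optim.Adam(regressor.parameters(), lr = 0.05)

# Loss before any training, for comparison
with torch.no_grad():
    initial_loss = cost(regressor(samples), labels).item()

for step in range(300):

    optimizer.zero_grad()                    # clear old gradients
    loss = cost(regressor(samples), labels)  # forward pass + objective
    loss.backward()                          # compute gradients
    optimizer.step()                         # update weight and bias

with torch.no_grad():
    final_loss = cost(regressor(samples), labels).item()

print(f"loss: {initial_loss:.4f} -> {final_loss:.4f}")
```

Lightning runs the `zero_grad`, `backward`, and `step` calls for you, which is why `training_step` only needs to return the loss.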


In [9]:
# Lastly, let's create a "Trainer" that will train and validate our model

In [10]:
# Initialize: Global Seed

seed_everything(seed, workers = True)

# Load: Dataset (.CSV)
     
dataset = load_data(path_dataset)
        
params["train"] = dataset
    
# Initialize: Formatter

dataset = PLDM(params)

# Initialize: Model

model = Linear_Regression(params)

# Initialize: Logger 

if(logger_choice == 0):
    logger = pl_loggers.TensorBoardLogger(path_results, name = "", version = 0)
else:
    logger = Logger(path_results, name = "", version = 0)
    
# Train: Model

if(use_gpu):

    # Initialize: GPU Trainer

    trainer = Trainer( logger = logger,
                       deterministic = True,
                       default_root_dir = path_results,
                       check_val_every_n_epoch = valid_rate,
                       max_epochs = num_epochs, num_nodes = 1,
                       num_sanity_val_steps = 0, gpus = gpu_list,
                       plugins = DDPPlugin(find_unused_parameters=False, ) )
else:

    # Initialize: CPU Trainer

    trainer = Trainer( logger = logger,
                       deterministic = True,
                       max_epochs = num_epochs,
                       num_sanity_val_steps = 0,
                       default_root_dir = path_results,
                       check_val_every_n_epoch = valid_rate )

trainer.fit(model, dataset)


Global seed set to 123
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name      | Type   | Params
-------------------------------------
0 | regressor | Linear | 2     
-------------------------------------
2         Trainable params
0         Non-trainable params
2         Total params
0.000     Total estimated model params size (MB)


Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]
