# Toycode for PyTorch-Lightning

So far I have bumped into many technical difficulties related to Pytorch-lightning and Tensorboard. In this notebook, I will use a simple model and dataset to experiment with different functionalities including: 

<ul>
    <li>Tensorboard logging</li>
    <li>Callback</li> 
    <li>Freeze parameters</li>
</ul>

Other things I need to experiment 

<ul>
    <li>Rouge metric and other metrics</li>
    <li>Optimizer and scheduler</li>
</ul>

## Resources 

<ul>
    <li><a href="https://pytorch-lightning.readthedocs.io/en/stable/">PyTorch-lightning doc</a></li>
    <li><a href="https://pytorch-lightning.readthedocs.io/en/stable/rapid_prototyping_templates.html">Rapid prototyping</a></li>
</ul>
    

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from transformers import AdamW
import time

In [2]:
nSamp = 40000

# Generate `nSamp` points
X = (np.random.rand(nSamp, 5) ** 2 * 5).tolist()
# Labels are the norms of `nSamp` points
y = (np.sqrt(np.sum(np.square(X), axis = 1))).tolist()

In [3]:
class MyDataset(Dataset): 
    def __init__(self, X, y): 
        super().__init__()
        self.X = X
        self.y = y
        
    def __len__(self): 
        return len(y)
    
    def __getitem__(self,idx):
        return {
            'source': torch.tensor(X[idx]), 
            'target': torch.tensor([y[idx]]) # !! Even the target is a single number, we wrap it as vector
        }

In [4]:
class MyModel(pl.LightningModule): 
    ''' Part 1: Define the architecture of model in init '''
    def __init__(self, hparams):
        super(MyModel, self).__init__()
        self.layer1 = nn.Linear(5, 10)
        self.layer2 = nn.Linear(10, 8)
        self.layer3 = nn.Linear(8, 1)
        self.hparams = hparams 
        
    ''' Part 2: Define the forward propagation '''
    def forward(self, x): 
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = self.layer3(x)
        return x
    
    ''' Part 3: Prepare optimizer and scheduler '''
    def configure_optimizers(self): 
        optimizer = AdamW(self.parameters(), lr = self.hparams['learning_rate'])
        return optimizer
    
    ''' Part 4.1: Training logic '''
    def training_step(self, batch, batch_idx): 
        X = batch['source']
        y = batch['target']
        y_hat = self(X)    # Calls forward function 
        loss = F.mse_loss(y_hat, y)
        # We can log any metric that is (1) numeric (a single number); (2) aggregated / averaged from a batch
        self.log('train_loss', loss)
        return loss
    
    ''' Part 4.2: Validation logic '''
    def validation_step(self, batch, batch_idx): 
        X = batch['source']
        y = batch['target']
        y_hat = self(X)
        loss = F.mse_loss(y_hat, y)
        self.log('val_loss', loss)
        
    ''' Part 4.3: Test logic '''
    def test_step(self, batch, batch_idx): 
        X = batch['source']
        y = batch['target']
        y_hat = self(X)
        loss = F.mse_loss(y_hat, y)
        self.log('test_loss', loss)
        
    ''' Part 5: Data loaders '''
    def train_dataloader(self): 
        dataset = MyDataset(X[:int(0.7 * nSamp)], y[:int(0.7 * nSamp)])
        return DataLoader(dataset, batch_size = hparams['batch_size'])
    
    def val_dataloader(self): 
        dataset = MyDataset(X[int(0.7 * nSamp):int(0.9 * nSamp)], y[int(0.7 * nSamp):int(0.9 * nSamp)])
        return DataLoader(dataset, batch_size = hparams['batch_size'])
    
    def test_dataloader(self): 
        dataset = MyDataset(X[int(0.9 * nSamp):], y[int(0.9 * nSamp):])
        return DataLoader(dataset, batch_size = hparams['batch_size'])

In [5]:
hparams = {
    'learning_rate': 3e-4, 
    'batch_size': 16
}

In [6]:
model = MyModel(hparams)
trainer = pl.Trainer(gpus = 1, max_epochs = 2, progress_bar_refresh_rate = 20)

start = time.time()
trainer.fit(model)
end = time.time()
print(f'Total time: {end - start}s')

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type   | Params
----------------------------------
0 | layer1 | Linear | 60    
1 | layer2 | Linear | 88    
2 | layer3 | Linear | 9     


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…


Total time: 16.130785703659058s


In [7]:
trainer.test()

HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

test loss: 0.07056723535060883
test loss: 0.10660012066364288
test loss: 0.08180828392505646
test loss: 0.06224687397480011
test loss: 0.0994734913110733
test loss: 0.06271513551473618
test loss: 0.09154578298330307
test loss: 0.08703809231519699
test loss: 0.042949095368385315
test loss: 0.09328750520944595
test loss: 0.07743780314922333
test loss: 0.061594583094120026
test loss: 0.11613646149635315
test loss: 0.10203823447227478
test loss: 0.12461048364639282
test loss: 0.06826229393482208
test loss: 0.09328098595142365
test loss: 0.06989350914955139
test loss: 0.09224238246679306
test loss: 0.0701751559972763
test loss: 0.059912677854299545
test loss: 0.024801447987556458
test loss: 0.026453083381056786
test loss: 0.10041327029466629
test loss: 0.07021427899599075
test loss: 0.08009334653615952
test loss: 0.10775065422058105
test loss: 0.08325160294771194
test loss: 0.10658274590969086
test loss: 0.047958992421627045
test loss: 0.13010534644126892
test loss: 0.08168311417102814
test

test loss: 0.07873810827732086
test loss: 0.09289657324552536
test loss: 0.10570532083511353
test loss: 0.13814973831176758
test loss: 0.0722019299864769
test loss: 0.028554826974868774
test loss: 0.09070280194282532
test loss: 0.09393583238124847
test loss: 0.10113972425460815
test loss: 0.0994059294462204
test loss: 0.0893283486366272
test loss: 0.07913295179605484
test loss: 0.05246931314468384
test loss: 0.11401693522930145
test loss: 0.06844791769981384
test loss: 0.10376542806625366
test loss: 0.054520539939403534
test loss: 0.09056300669908524
test loss: 0.061915308237075806
test loss: 0.05712869018316269
test loss: 0.08664386719465256
test loss: 0.041125066578388214
test loss: 0.07865997403860092
test loss: 0.05941224470734596
test loss: 0.04252193123102188
test loss: 0.05628820136189461
test loss: 0.05488055944442749
test loss: 0.10520045459270477
test loss: 0.06125236302614212
test loss: 0.09866541624069214
test loss: 0.11587987095117569
test loss: 0.1240650862455368
test los

test loss: 0.10460186004638672
test loss: 0.06677919626235962
test loss: 0.13456407189369202
test loss: 0.12161386758089066
test loss: 0.065012127161026
test loss: 0.09830457717180252
test loss: 0.06899227201938629
test loss: 0.145668163895607
test loss: 0.057223379611968994
test loss: 0.1009112298488617
test loss: 0.07671003043651581
test loss: 0.07389690726995468
test loss: 0.09184983372688293
test loss: 0.07058108597993851
test loss: 0.05065827816724777
test loss: 0.05015117675065994
test loss: 0.16703903675079346
test loss: 0.06030430644750595
test loss: 0.13900262117385864
test loss: 0.16409623622894287
test loss: 0.05747305229306221
test loss: 0.14395669102668762
test loss: 0.06961900740861893
test loss: 0.10422500967979431
test loss: 0.05825541168451309
test loss: 0.07121515274047852
test loss: 0.09969053417444229
test loss: 0.05190003663301468
test loss: 0.06982491910457611
test loss: 0.1383909285068512
test loss: 0.09478609263896942
test loss: 0.06914899498224258
test loss: 0.

test loss: 0.09211041033267975
test loss: 0.05647309124469757
test loss: 0.06524097174406052
test loss: 0.09832818061113358
test loss: 0.04965042695403099
test loss: 0.09813772141933441
test loss: 0.05535106360912323
test loss: 0.10142924636602402
test loss: 0.08222630620002747
test loss: 0.15272411704063416
test loss: 0.06267894804477692
test loss: 0.044334348291158676
test loss: 0.04830383509397507
test loss: 0.05458392947912216
test loss: 0.06753334403038025
test loss: 0.06548978388309479
test loss: 0.040391068905591965
test loss: 0.0652860552072525
test loss: 0.06902171671390533
test loss: 0.031613197177648544
test loss: 0.06742040812969208
test loss: 0.0703127458691597
test loss: 0.0706527978181839
test loss: 0.11881061643362045
test loss: 0.09978818148374557
test loss: 0.07597698271274567
test loss: 0.11091794073581696
test loss: 0.047667644917964935
test loss: 0.054661259055137634
test loss: 0.05068361014127731
test loss: 0.06048218160867691
test loss: 0.0494147390127182
test lo

test loss: 0.050671838223934174
test loss: 0.13109087944030762
test loss: 0.0645313560962677
test loss: 0.05273529887199402
test loss: 0.11551333218812943
test loss: 0.09815524518489838
test loss: 0.11470183730125427
test loss: 0.04962554946541786
test loss: 0.02964792214334011
test loss: 0.12305491417646408
test loss: 0.09399903565645218
test loss: 0.06741571426391602
test loss: 0.07273216545581818
test loss: 0.08661795407533646
test loss: 0.050397299230098724
test loss: 0.0381157360970974
test loss: 0.12668491899967194
test loss: 0.07550020515918732
test loss: 0.09311654418706894
test loss: 0.03732422739267349
test loss: 0.06092367321252823
test loss: 0.03217519819736481
test loss: 0.13479024171829224
test loss: 0.09734421223402023
test loss: 0.10694960504770279
test loss: 0.1051916629076004
test loss: 0.08994784951210022
test loss: 0.11488886177539825
test loss: 0.048560768365859985
test loss: 0.06630498915910721
test loss: 0.08213050663471222
test loss: 0.07218831777572632
test los

test loss: 0.08413532376289368
test loss: 0.058477818965911865
test loss: 0.05752605199813843
test loss: 0.09496279060840607
test loss: 0.039091527462005615
test loss: 0.07204325497150421
test loss: 0.05505378544330597
test loss: 0.09408002346754074
test loss: 0.09380035847425461
test loss: 0.06276297569274902
test loss: 0.09640797972679138
test loss: 0.07147860527038574
test loss: 0.08326542377471924
test loss: 0.07033602893352509
test loss: 0.0805569440126419
test loss: 0.06466282904148102
test loss: 0.12189629673957825
test loss: 0.072409488260746
test loss: 0.11063681542873383
test loss: 0.0671694204211235
test loss: 0.08257973194122314
test loss: 0.05521692708134651
test loss: 0.09368737041950226
test loss: 0.07679577171802521
test loss: 0.1260359138250351
test loss: 0.09314960241317749
test loss: 0.09250606596469879
test loss: 0.09213300049304962
test loss: 0.10610802471637726
test loss: 0.03516155108809471
test loss: 0.05989931523799896
test loss: 0.0956168919801712
test loss: 0

test loss: 0.07293534278869629
test loss: 0.05006809160113335
test loss: 0.20166029036045074
test loss: 0.14322344958782196
test loss: 0.03145953640341759
test loss: 0.06093330681324005
test loss: 0.11757531017065048
test loss: 0.142354354262352
test loss: 0.11420896649360657
test loss: 0.06795142590999603
test loss: 0.06822577118873596
test loss: 0.08380773663520813
test loss: 0.09646229445934296
test loss: 0.05560609698295593
test loss: 0.06905168294906616
test loss: 0.07902709394693375
test loss: 0.05922075733542442
test loss: 0.02610449120402336
test loss: 0.08586869388818741
test loss: 0.05640198290348053
test loss: 0.1438177525997162
test loss: 0.11535947024822235
test loss: 0.07608577609062195
test loss: 0.08125313371419907
test loss: 0.11468586325645447
test loss: 0.022023338824510574
test loss: 0.025679808109998703
test loss: 0.06735800951719284
test loss: 0.08914583176374435
test loss: 0.09858300536870956
test loss: 0.07664848119020462
test loss: 0.14804401993751526
test loss

1

In [8]:
x1 = [1.,1.,2.,2.,3.]
x2 = X[0]

print(f'{x1} has actual norm {np.sqrt(np.sum(np.square(x1)))}, predicted {model(torch.tensor(x1).to("cuda"))[0]}')
print(f'{x2} has actual norm {np.sqrt(np.sum(np.square(x2)))}, predicted {model(torch.tensor(x2).to("cuda"))[0]}')

[1.0, 1.0, 2.0, 2.0, 3.0] has actual norm 4.358898943540674, predicted 4.50333833694458
[2.600147994400853, 2.215423672456916, 0.025293287273638825, 2.419506900588191, 4.969182847189589] has actual norm 6.497407421614215, predicted 6.323709487915039


In [9]:
# Start tensorboard.
#%load_ext tensorboard
%reload_ext tensorboard
%tensorboard --logdir lightning_logs/

Reusing TensorBoard on port 6006 (pid 29172), started 2:01:03 ago. (Use '!kill 29172' to kill it.)