In [1]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import numpy as np

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
import helper
from dataset import TimeSeriesDataset

## Training

In [None]:
from Generator_LSTM import Generator
from Discriminator_LSTM import Discriminator
from F1_score_check import F1_score_check
from GAN import GAN


# from DeepConvLSTM_model import DeepConvNet
# val_model = DeepConvNet(hidden_size = 256)
# state_dict = torch.load("logs/PAMAP2_LSTM_model/version_2/checkpoints/epoch=8.ckpt")["state_dict"]

from TransformerClassifier import TransformerClassifier
val_model = TransformerClassifier(in_channels = 6, output_size = 8, d_model = 50, nhead = 5, dim_feedforward = 10000, num_layers = 5)
state_dict = torch.load("RWHAR_transformer.ckpt")["state_dict"]

state_dict = helper.remove_prefix_from_dict("model.", state_dict)
val_model.load_state_dict(state_dict, strict = False)
val_model.eval()

start_activity = 1
total_activities = 8
val_iter_size = 3 # Num of validation iterations to perform
batch_size = 20
data_size = (6, 50)
noise_len = 100

success = {}

for chosen_activity in range(start_activity, total_activities+1):
    data = helper.load_RWHAR_activity(sel_location = "chest", activity_num = chosen_activity)
    train_iter, val_iter = helper.get_dataloaders(data, batch_size = batch_size, output_size = total_activities, val_pc = val_iter_size)

    model = GAN(val_model = val_model, 
                noise_len = noise_len, 
                val_expected_output = chosen_activity-1,
                generator = Generator(hidden_size = 100, num_layers = 2, bidirectional = False, noise_len = noise_len, output_size = data_size),
                discriminator = Discriminator(hidden_size = 100, bidirectional = False, num_layers = 2, input_size = data_size),
                num_classes = total_activities
               )

    trainer = pl.Trainer(gpus=-1,
                         max_epochs=100,
                         callbacks = [F1_score_check(), 
                                     ], # Early stopping callback
                         logger = TensorBoardLogger(save_dir = 'LSTM_GAN_logs/', name = "RWHAR_act_"+str(chosen_activity)),
                         check_val_every_n_epoch = 5,
                         )
    trainer.fit(model, train_iter, val_iter)
    # verify if the model is trained
    if trainer.callback_metrics['val_f1_score'] >=0.95:
        print("Success!")
        success[chosen_activity] = trainer.logger.version
        break
    else: # model not traineds:
        success[chosen_activity] = None
                
print(success)

File exists. Loading
Selecting location :  chest
Windowing


GPU available: True, used: True
TPU available: None, using: 0 TPU cores


Done!



  | Name          | Type                  | Params
--------------------------------------------------------
0 | criterion     | BCELoss               | 0     
1 | generator     | Generator             | 765 K 
2 | discriminator | Discriminator         | 636 K 
3 | val_model     | TransformerClassifier | 5.1 M 
--------------------------------------------------------
6.5 M     Trainable params
0         Non-trainable params
6.5 M     Total params
26.176    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


File exists. Loading
Selecting location :  chest
Windowing


GPU available: True, used: True
TPU available: None, using: 0 TPU cores


Done!



  | Name          | Type                  | Params
--------------------------------------------------------
0 | criterion     | BCELoss               | 0     
1 | generator     | Generator             | 765 K 
2 | discriminator | Discriminator         | 636 K 
3 | val_model     | TransformerClassifier | 5.1 M 
--------------------------------------------------------
6.5 M     Trainable params
0         Non-trainable params
6.5 M     Total params
26.176    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


File exists. Loading
Selecting location :  chest
Windowing


GPU available: True, used: True
TPU available: None, using: 0 TPU cores


Done!



  | Name          | Type                  | Params
--------------------------------------------------------
0 | criterion     | BCELoss               | 0     
1 | generator     | Generator             | 765 K 
2 | discriminator | Discriminator         | 636 K 
3 | val_model     | TransformerClassifier | 5.1 M 
--------------------------------------------------------
6.5 M     Trainable params
0         Non-trainable params
6.5 M     Total params
26.176    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




File exists. Loading
Selecting location :  chest
Windowing


In [None]:
## Things to Try to stabilise GAN
## 3. Use one sided label smoothing for discriminator (helps a little)
## 5. Remove Linear and use conv only (helps a ton)
## 6. Use dropout of 0.5 