In [1]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import numpy as np

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
import helper
from dataset import TimeSeriesDataset

## Training

In [3]:
from Generator_LSTM import Generator
from Discriminator_LSTM import Discriminator
from F1_score_check import F1_score_check
from GAN import GAN
from PAMAP2_model import DeepConvNet

# Pre-trained activity classifier that the GAN uses (via F1_score_check / validation)
# to judge generated samples. One instance is shared by every per-activity run below.
val_model = DeepConvNet(hidden_size = 256)
# map_location="cpu": the weights are copied into the CPU-resident `val_model` by
# load_state_dict anyway, so this avoids needing the GPU the checkpoint was saved on.
state_dict = torch.load("logs/PAMAP2_LSTM_model/version_2/checkpoints/epoch=8.ckpt",
                        map_location = "cpu")["state_dict"]

# Alternative classifier (presumably acceleration-only, judging by its path):
# from PAMAP2_model_acc import DeepConvNet
# val_model = DeepConvNet(hidden_size = 512)
# state_dict = torch.load("Validation_transformer_logs/PAMAP2_acceleration/version_21/checkpoints/epoch=18.ckpt", map_location = "cpu")["state_dict"]

# Checkpoint keys carry a "model." prefix (presumably the attribute name of the
# wrapping LightningModule); strip it so the keys match DeepConvNet's own names.
state_dict = helper.remove_prefix_from_dict("model.", state_dict)
# strict=False tolerates mismatches, so report them instead of silently scoring
# generated data with a partially initialised classifier.
load_report = val_model.load_state_dict(state_dict, strict = False)
if load_report.missing_keys or load_report.unexpected_keys:
    print("Validation model missing keys:", load_report.missing_keys)
    print("Validation model unexpected keys:", load_report.unexpected_keys)
val_model.eval()
# The classifier is reused across all activities; freeze its parameters so no GAN
# run can update it and shift the yardstick for the runs that follow. Gradients
# still flow *through* it to its inputs if the GAN needs them.
val_model.requires_grad_(False)

total_activities = 7
val_iter_size = 3 # Num of validation iterations to perform
batch_size = 20
data_size = (27, 100)  # used as Generator output_size and Discriminator input_size
noise_len = 100        # passed to both the GAN and its Generator
seed = 42              # reseeded per activity so each run is reproducible on its own

for chosen_activity in range(1, total_activities + 1):
    pl.seed_everything(seed)

    data = helper.load_PAMAP2_activity(activity_num = chosen_activity)
    dtset = TimeSeriesDataset(data)
    train_iter = torch.utils.data.DataLoader(dtset, batch_size = batch_size, shuffle = True, num_workers = 10)

    # Placeholder validation set: `val_iter_size` batches of ones. Its values are
    # presumably ignored by GAN.validation_step and it only fixes how many
    # validation steps run (TODO confirm in GAN.py).
    val = torch.ones((batch_size * val_iter_size, 1))
    val_iter = torch.utils.data.DataLoader(val, batch_size = batch_size, num_workers = 10)

    model = GAN(val_model = val_model,
                noise_len = noise_len,  # keep in sync with the Generator below
                # Activities are numbered from 1; presumably the classifier's
                # class indices start at 0 — confirm against PAMAP2_model.
                val_expected_output = chosen_activity - 1,
                generator = Generator(hidden_size = 100, num_layers = 2, bidirectional = False, noise_len = noise_len, output_size = data_size),
                discriminator = Discriminator(hidden_size = 100, bidirectional = False, num_layers = 2, input_size = data_size)
               )

    trainer = pl.Trainer(gpus = -1,
                         max_epochs = 100,
                         callbacks = [F1_score_check()],  # no early stopping is configured
                         logger = TensorBoardLogger(save_dir = 'LSTM_GAN_logs/', name = "PAMAP2_act_" + str(chosen_activity)),
                         check_val_every_n_epoch = 5,
                         )
    trainer.fit(model, train_iter, val_iter)

File exists. Loading
Keep only acceleration
Keep only activity number  1
Windowing


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Done!



  | Name          | Type          | Params
------------------------------------------------
0 | criterion     | BCELoss       | 0     
1 | generator     | Generator     | 765 K 
2 | discriminator | Discriminator | 651 K 
3 | val_model     | DeepConvNet   | 6 M   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


File exists. Loading
Keep only acceleration
Keep only activity number  2
Windowing


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Done!



  | Name          | Type          | Params
------------------------------------------------
0 | criterion     | BCELoss       | 0     
1 | generator     | Generator     | 765 K 
2 | discriminator | Discriminator | 651 K 
3 | val_model     | DeepConvNet   | 6 M   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


File exists. Loading
Keep only acceleration
Keep only activity number  3
Windowing


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Done!



  | Name          | Type          | Params
------------------------------------------------
0 | criterion     | BCELoss       | 0     
1 | generator     | Generator     | 765 K 
2 | discriminator | Discriminator | 651 K 
3 | val_model     | DeepConvNet   | 6 M   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


File exists. Loading
Keep only acceleration
Keep only activity number  4
Windowing


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Done!



  | Name          | Type          | Params
------------------------------------------------
0 | criterion     | BCELoss       | 0     
1 | generator     | Generator     | 765 K 
2 | discriminator | Discriminator | 651 K 
3 | val_model     | DeepConvNet   | 6 M   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


File exists. Loading
Keep only acceleration
Keep only activity number  5
Windowing


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Done!



  | Name          | Type          | Params
------------------------------------------------
0 | criterion     | BCELoss       | 0     
1 | generator     | Generator     | 765 K 
2 | discriminator | Discriminator | 651 K 
3 | val_model     | DeepConvNet   | 6 M   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


File exists. Loading
Keep only acceleration
Keep only activity number  6
Windowing


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Done!



  | Name          | Type          | Params
------------------------------------------------
0 | criterion     | BCELoss       | 0     
1 | generator     | Generator     | 765 K 
2 | discriminator | Discriminator | 651 K 
3 | val_model     | DeepConvNet   | 6 M   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


File exists. Loading
Keep only acceleration
Keep only activity number  7
Windowing


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Done!



  | Name          | Type          | Params
------------------------------------------------
0 | criterion     | BCELoss       | 0     
1 | generator     | Generator     | 765 K 
2 | discriminator | Discriminator | 651 K 
3 | val_model     | DeepConvNet   | 6 M   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




In [None]:
## Things to try to stabilise GAN training
## 3. Use one-sided label smoothing for the discriminator (helps a little)
## 5. Remove the Linear layers and use only convolutional layers (helps a ton)
## 6. Use dropout of 0.5