In [1]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import numpy as np

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
import helper
from dataset import TimeSeriesDataset

## Training

In [3]:
from Generator_transformer import Generator
from Discriminator_transformer import Discriminator
from F1_score_check import F1_score_check
from GAN import GAN

# from DeepConvLSTM_model import DeepConvNet
# val_model = DeepConvNet(hidden_size = 256)
# state_dict = torch.load("val_model.ckpt")["state_dict"]

from TransformerClassifier import TransformerClassifier
val_model = TransformerClassifier(in_channels = 6, output_size = 8, d_model = 50, nhead = 5, dim_feedforward = 10000, num_layers = 5)
state_dict = torch.load("RWHAR_transformer.ckpt")["state_dict"]

print("Loading Validation model")
state_dict = helper.remove_prefix_from_dict("model.", state_dict)
val_model.load_state_dict(state_dict, strict = False)
val_model.eval()

start_activity = 1
total_activities = 8
val_iter_size = 3 # Num of validation iterations to perform
batch_size = 32
data_size = (6, 50)
noise_len = 50
period = 100
max_retries = 2
init_dim_feedforward = 2048
dim_feedforward_exponent = 5

success = {}

for chosen_activity in range(start_activity, total_activities+1):
    try_num = 0
    dim_feedforward = init_dim_feedforward
    data = helper.load_RWHAR_activity(sel_location = "chest", activity_num = chosen_activity)
    dtset = TimeSeriesDataset(data)
    train_iter = torch.utils.data.DataLoader(dtset, batch_size = batch_size, shuffle = True, num_workers = 10, pin_memory = True)
    val = torch.ones((batch_size * val_iter_size, 1))
    val_iter = torch.utils.data.DataLoader(val, batch_size = batch_size, num_workers = 10, pin_memory = True)

    while (try_num < max_retries):
        print("Activity ", chosen_activity,", Try ",try_num)
        model = GAN(val_model = val_model, 
                    generator = Generator(noise_len = noise_len, output_size = data_size, nheads = 5, period = period, dim_feedforward = dim_feedforward),
                    discriminator = Discriminator(input_size = data_size, nheads = 5, period = period, dim_feedforward = dim_feedforward),
                    val_expected_output = chosen_activity - 1,
                    num_classes = total_activities,
                    noise_len = noise_len,
                    decay = 1,
                    dis_lr = 0.0002,
                    gen_lr = 0.0002,
                   )

        trainer = pl.Trainer(gpus=-1,
                             max_epochs=100,
                             callbacks = [F1_score_check(threshold_value = 0.95), 
                                         ], # Early stopping callback
                             logger = TensorBoardLogger(save_dir = 'Transformer_GAN_logs/', name = "RWHAR_act_"+str(chosen_activity)),
                             check_val_every_n_epoch = 5,
                             )
        result = trainer.fit(model, train_iter, val_iter)
        # verify if the model is trained
        if trainer.callback_metrics['val_f1_score'] >=0.95 or result != 1:
            print("Success!")
            success[chosen_activity] = trainer.logger.version
            break
        else: # model not trained
            dim_feedforward *= dim_feedforward_exponent
            try_num += 1
            if try_num == max_retries:
                success[chosen_activity] = None
                
print(success)

Loading Validation model
File exists. Loading
Selecting location :  chest
Windowing


GPU available: True, used: True
TPU available: None, using: 0 TPU cores


Done!
Activity  1 , Try  0



  | Name          | Type                  | Params
--------------------------------------------------------
0 | criterion     | BCELoss               | 0     
1 | generator     | Generator             | 467 K 
2 | discriminator | Discriminator         | 315 K 
3 | val_model     | TransformerClassifier | 5.1 M 
--------------------------------------------------------
5.9 M     Trainable params
0         Non-trainable params
5.9 M     Total params
23.705    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

GPU available: True, used: True
TPU available: None, using: 0 TPU cores



Activity  1 , Try  1



  | Name          | Type                  | Params
--------------------------------------------------------
0 | criterion     | BCELoss               | 0     
1 | generator     | Generator             | 2.1 M 
2 | discriminator | Discriminator         | 1.1 M 
3 | val_model     | TransformerClassifier | 5.1 M 
--------------------------------------------------------
8.4 M     Trainable params
0         Non-trainable params
8.4 M     Total params
33.634    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


File exists. Loading
Selecting location :  chest
Windowing


GPU available: True, used: True
TPU available: None, using: 0 TPU cores


Done!
Activity  2 , Try  0



  | Name          | Type                  | Params
--------------------------------------------------------
0 | criterion     | BCELoss               | 0     
1 | generator     | Generator             | 467 K 
2 | discriminator | Discriminator         | 315 K 
3 | val_model     | TransformerClassifier | 5.1 M 
--------------------------------------------------------
5.9 M     Trainable params
0         Non-trainable params
5.9 M     Total params
23.705    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

GPU available: True, used: True
TPU available: None, using: 0 TPU cores



Activity  2 , Try  1



  | Name          | Type                  | Params
--------------------------------------------------------
0 | criterion     | BCELoss               | 0     
1 | generator     | Generator             | 2.1 M 
2 | discriminator | Discriminator         | 1.1 M 
3 | val_model     | TransformerClassifier | 5.1 M 
--------------------------------------------------------
8.4 M     Trainable params
0         Non-trainable params
8.4 M     Total params
33.634    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


File exists. Loading
Selecting location :  chest
Windowing


GPU available: True, used: True
TPU available: None, using: 0 TPU cores


Done!
Activity  3 , Try  0



  | Name          | Type                  | Params
--------------------------------------------------------
0 | criterion     | BCELoss               | 0     
1 | generator     | Generator             | 467 K 
2 | discriminator | Discriminator         | 315 K 
3 | val_model     | TransformerClassifier | 5.1 M 
--------------------------------------------------------
5.9 M     Trainable params
0         Non-trainable params
5.9 M     Total params
23.705    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

RuntimeError: CUDA error: device-side assert triggered