# Example

In [1]:
from utils import datasets, kl_cpd, models
%load_ext autoreload
%autoreload 2

import warnings
warnings.filterwarnings("ignore")

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

import os
# os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="1"

In [2]:
experiments_name = 'explosion'
train_dataset, test_dataset = datasets.CPDDatasets(experiments_name=experiments_name).get_dataset_()

Equal sampling is impossible, do random sampling.


# Example with Different Layers

In [3]:
from utils import datasets, kl_cpd, models_v2 as models, nets_tl, nets_original
%load_ext autoreload
%autoreload 2

import warnings
warnings.filterwarnings("ignore")

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

import torch

import os
# os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="1"

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
experiments_name = 'explosion'
train_dataset, test_dataset = datasets.CPDDatasets(experiments_name=experiments_name).get_dataset_()

Equal sampling is impossible, do random sampling.


In [5]:
args = {}

block_type = "tcl"
bias = "all"

args["seed"] = 102
args["block_type"] = block_type
args["bias"] = bias
args["epochs"] = 30

args['wnd_dim'] = 4 # 8
args['batch_size'] = 8
args['lr'] = 1e-4
args['weight_decay'] = 0.
args['grad_clip'] = 10
args['CRITIC_ITERS'] = 5
args['weight_clip'] = .1
args['lambda_ae'] = 0.1 #0.001
args['lambda_real'] = 10 #0.1
args['num_layers'] = 1
args['window_1'] = 4 # 8
args['window_2'] = 4 # 8
args['sqdist'] = 50

if args["block_type"] == "tcl3d":
    # For TCL3D
    args['data_dim'] = (192, 8, 8)
    args['RNN_hid_dim'] = (16, 4, 4) # 3072
    args['emb_dim'] = (32, 8, 8) # 3072
    args['bias_rank'] = 1
    
elif args["block_type"] == "tcl":
    # For TCL
    args['data_dim'] = (192, 8, 8)
    args['RNN_hid_dim'] = (64, 8, 8) # 3072
    args['emb_dim'] = (64, 8, 8) # 3072
    args['bias_rank'] = 1

elif args["block_type"] == "linear":
    # For Linear
    args['data_dim'] = 12288
    args['RNN_hid_dim'] = 16 # 3072
    args['emb_dim'] = 100 # 3072

elif args["block_type"] == "masked":
    # For Linear
    args['data_dim'] = 12288
    args['RNN_hid_dim'] = 16 # 3072
    args['emb_dim'] = 100 # 3072
    args["alphaD"] = 1e-3
    args["alphaG"] = 1e-4


In [6]:

seed = args["seed"]
models.fix_seeds(seed)
experiments_name = ('explosion')
    
if args["block_type"] == "linear":
    netG = nets_original.NetG(args)
    netD = nets_original.NetD(args)
elif args["block_type"] == "masked":
    netG = nets_original.NetG_Masked(args)
    netD = nets_original.NetD_Masked(args)
else:
    netG = nets_tl.NetG_TL(args, block_type=args["block_type"], bias=args["bias"])
    netD = nets_tl.NetD_TL(args, block_type=args["block_type"], bias=args["bias"])

kl_cpd_model = models.KLCPDVideo(netG, netD, args, train_dataset=train_dataset, test_dataset=test_dataset)

Using cache found in /home/eromanenkova/.cache/torch/hub/facebookresearch_pytorchvideo_main


In [7]:
# %%debug

logger = TensorBoardLogger(save_dir='logs/explosion', name='kl_cpd')
early_stop_callback = EarlyStopping(monitor="val_mmd2_real_D", stopping_threshold=1e-5, 
                                    verbose=True, mode="min", patience=5)



for param in kl_cpd_model.extractor.parameters():
    param.requires_grad = False

trainer = pl.Trainer(
    max_epochs=1, # 100
    gpus='1',
    # devices='1',
    benchmark=True,
    check_val_every_n_epoch=5,
    gradient_clip_val=args['grad_clip'],
    logger=logger,
    callbacks=early_stop_callback
)

trainer.fit(kl_cpd_model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]

  | Name      | Type       | Params
-----------------------------------------
0 | netG      | NetG_TL    | 338 K 
1 | netD      | NetD_TL    | 338 K 
2 | extractor | Sequential | 2.0 M 
-----------------------------------------
677 K     Trainable params
2.0 M     Non-trainable params
2.7 M     Total params
10.736    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

torch.Size([8, 4, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Siz

Training: 0it [00:00, ?it/s]

torch.Size([8, 4, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Siz

torch.Size([8, 4, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Siz

torch.Size([8, 4, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Siz

torch.Size([8, 4, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Siz

torch.Size([8, 4, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Siz

torch.Size([8, 4, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Siz

torch.Size([8, 4, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Siz

torch.Size([8, 4, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Siz

torch.Size([8, 4, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Siz

torch.Size([8, 4, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Siz

torch.Size([8, 4, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Siz

torch.Size([8, 4, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Siz

torch.Size([8, 4, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Size([8, 64, 8, 8])
torch.Siz

In [17]:
kl_cpd_model.netD.mask_loss().item()#, kl_cpd_model.netG.mask_loss().item()

40358.89453125

In [18]:
kl_cpd_model.netD.fc1.weight.data.norm(p=1, dim=1)

tensor([ 0.2787,  0.3042,  0.1664,  0.1965, 10.1422,  0.1573,  0.2564,  0.2095,  0.2782,  0.1643,  0.6125,  0.1514,  0.1625,  0.1806,  3.7543,  0.5240,  0.2675,  9.9076,  0.1713,  0.9702,  0.1636,
         0.2404,  0.1800,  1.0309,  0.6973,  2.9935,  2.1880,  5.0784,  0.5837,  0.1535,  0.1729,  0.1696,  0.4811,  7.8527,  1.6646,  0.1721,  0.4175,  0.1725,  0.1691,  4.6674,  4.3273,  0.1700,
         1.0074,  0.4366,  0.1632,  0.2968,  0.4841,  0.1684,  0.1547,  0.1551,  3.6853,  0.3235,  0.2754,  0.1913,  7.6544,  0.1643,  0.2551,  0.1604,  4.8249, 11.6831,  0.3494,  0.1560,  0.1585,
         0.1584,  1.0185,  0.2031,  1.8946,  0.1800,  0.1552,  0.1558,  5.4217,  0.4044,  0.1674,  7.2464,  0.4134,  0.1729,  0.2358,  0.1794,  5.1118,  0.9262,  0.1531,  0.1718,  0.8117,  7.6098,
         0.1677,  0.1636,  1.3359,  0.6282,  0.7006,  0.1548,  0.3499,  0.8079,  1.6855,  1.2281,  3.2816,  0.2088,  0.1541,  0.1657,  0.1646,  0.1559])

In [19]:
kl_cpd_model.netD.fc2.weight.data.norm(p=1, dim=0)

tensor([394.5896, 394.2867, 401.7104, 394.1988, 398.6168, 400.3464, 395.2538, 396.5166, 392.7541, 394.0240, 397.5858, 401.0594, 399.2442, 392.5852, 397.7998, 400.4600, 396.0073, 395.0919, 396.7904,
        398.8372, 400.6757, 398.6306, 395.3871, 393.4721, 398.5417, 395.5612, 392.9665, 398.3795, 398.5013, 400.8709, 397.7074, 402.6966, 401.1615, 396.8614, 400.7512, 397.2556, 396.5844, 395.0372,
        400.0142, 400.5509, 394.4132, 403.1198, 399.2444, 393.7011, 398.5826, 398.9307, 396.4456, 397.3938, 398.9660, 395.6059, 400.2199, 400.4991, 401.6479, 398.1312, 399.4326, 404.6075, 399.3232,
        398.0941, 394.9440, 397.4167, 400.2268, 397.9948, 398.5245, 399.6082, 397.7357, 397.2970, 402.6093, 400.3627, 396.2129, 399.5329, 397.2866, 394.1641, 401.1922, 399.8922, 397.4161, 400.3550,
        404.8687, 402.9345, 400.5465, 400.1188, 396.8640, 395.5505, 399.8173, 399.7761, 392.1123, 399.9803, 395.4568, 397.4171, 400.5359, 399.0378, 394.8703, 405.4914, 394.0442, 397.1314, 395.4755,
        40

In [30]:
mask = kl_cpd_model.netD.fc1.weight.data.norm(p=1, dim=1) > 0.2

In [5]:
kl_cpd_model.netD.mask1[0, 0] = mask * 1.

NameError: name 'mask' is not defined

In [32]:
kl_cpd_model.netD.mask1.sum()

tensor(58., device='cuda:0')

In [37]:
mask = kl_cpd_model.netD.fc2.bias.data.norm(p=1, dim=0)
mask

tensor(398.8411)

In [10]:
# kl_cpd_model.netG.masks[1][0,0]

In [None]:
torch.__version__

# Estimate parameters and flops

In [12]:
total_param = 0
for param in kl_cpd_model.netG.parameters():
    total_param += param.numel()

total_param

119552

In [13]:
total_param = 0
for param in kl_cpd_model.netD.parameters():
    total_param += param.numel()

total_param

161120

In [7]:
total_param = 0
for param in kl_cpd_model.extractor.parameters():
    total_param += param.numel()

total_param

2006130

In [10]:
kl_cpd_model.extractor

Sequential(
  (0): ResNetBasicStem(
    (conv): Conv2plus1d(
      (conv_t): Conv3d(3, 24, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1), bias=False)
      (conv_xy): Conv3d(24, 24, kernel_size=(5, 1, 1), stride=(1, 1, 1), padding=(2, 0, 0), groups=24, bias=False)
    )
    (norm): BatchNorm3d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ReLU()
  )
  (1): ResStage(
    (res_blocks): ModuleList(
      (0): ResBlock(
        (branch1_conv): Conv3d(24, 24, kernel_size=(1, 1, 1), stride=(1, 2, 2), bias=False)
        (branch2): BottleneckBlock(
          (conv_a): Conv3d(24, 54, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
          (norm_a): BatchNorm3d(54, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act_a): ReLU()
          (conv_b): Conv3d(54, 54, kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=[1, 1, 1], groups=54, bias=False)
          (norm_b): Sequential(
            (0): BatchNorm3d(54, ep

In [8]:
kl_cpd_model.train_dataset.__getitem__(0)[0].shape

torch.Size([3, 16, 256, 256])

In [5]:
from fvcore.nn import FlopCountAnalysis
import torch

In [7]:
inputs = (torch.randn((1, 3, 16, 256, 256)),)
flops = FlopCountAnalysis(kl_cpd_model.extractor, inputs)

In [7]:
flops.by_module()

Unsupported operator aten::mean encountered 15 time(s)
Unsupported operator aten::sigmoid encountered 15 time(s)
Unsupported operator aten::mul encountered 15 time(s)
Unsupported operator aten::add encountered 26 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
1.res_blocks.1.branch2.norm_b.1, 2.res_blocks.1.branch2.norm_b.1, 2.res_blocks.3.branch2.norm_b.1, 3.res_blocks.1.branch2.norm_b.1, 3.res_blocks.3.branch2.norm_b.1, 3.res_blocks.5.branch2.norm_b.1, 3.res_blocks.7.branch2.norm_b.1, 3.res_blocks.9.branch2.norm_b.1, 4.res_blocks.1.branch2.norm_b.1, 4.res_blocks.3.branch2.norm_b.1, 4.res_blocks.5.branch2.norm_b.1


Counter({'': 6626536704,
         '0': 232783872,
         '0.conv': 201326592,
         '0.conv.conv_t': 169869312,
         '0.conv.conv_xy': 31457280,
         '0.norm': 31457280,
         '0.activation': 0,
         '1': 1271662272,
         '1.res_blocks': 0,
         '1.res_blocks.0': 654312288,
         '1.res_blocks.0.branch1_conv': 37748736,
         '1.res_blocks.0.branch2': 616563552,
         '1.res_blocks.0.branch2.conv_a': 339738624,
         '1.res_blocks.0.branch2.norm_a': 70778880,
         '1.res_blocks.0.branch2.act_a': 0,
         '1.res_blocks.0.branch2.conv_b': 95551488,
         '1.res_blocks.0.branch2.norm_b': 17695584,
         '1.res_blocks.0.branch2.norm_b.0': 17694720,
         '1.res_blocks.0.branch2.norm_b.1': 864,
         '1.res_blocks.0.branch2.norm_b.1.block': 864,
         '1.res_blocks.0.branch2.norm_b.1.block.0': 432,
         '1.res_blocks.0.branch2.norm_b.1.block.1': 0,
         '1.res_blocks.0.branch2.norm_b.1.block.2': 432,
         '1.res_block

In [12]:
flops.by_operator()

Counter({'conv': 6093974784, 'batch_norm': 532561920})

In [6]:
x, noise = torch.randn(1, 16, 192,8,8), torch.randn(1, 1, 4, 4)
# x, noise = torch.randn(1, 16, 12288), torch.randn(1,16)
flops = FlopCountAnalysis(kl_cpd_model.netG, (x, x, noise))

In [7]:
flops.total()

Unsupported operator aten::mul encountered 3216 time(s)
Unsupported operator aten::add_ encountered 1560 time(s)
Unsupported operator aten::add encountered 129 time(s)
Unsupported operator aten::sigmoid encountered 64 time(s)
Unsupported operator aten::tanh encountered 32 time(s)
Unsupported operator aten::rsub encountered 32 time(s)
Unsupported operator aten::clone encountered 1 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.



29323264

In [8]:
kl_cpd_model.netD(x)
flops = FlopCountAnalysis(kl_cpd_model.netD, (x,))

In [9]:
flops.total()

Unsupported operator aten::mul encountered 3200 time(s)
Unsupported operator aten::add_ encountered 1552 time(s)
Unsupported operator aten::add encountered 128 time(s)
Unsupported operator aten::sigmoid encountered 64 time(s)
Unsupported operator aten::tanh encountered 32 time(s)
Unsupported operator aten::rsub encountered 32 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.



28622848

# Test Masked Nets

In [16]:
import torch
x, noise = torch.randn(1, 16, 12288), torch.randn(1,16)
kl_cpd_model.netG(x, x, noise);

In [17]:
kl_cpd_model.netD(x);

# Test TCL and TRL

In [41]:
from utils.nets_tl import TCL, TRL, TCL3D, TRLhalf
%load_ext autoreload
%autoreload 2

import torch
import tensorly as tl
from tensorly.tenalg import multi_mode_dot, tensordot
tl.set_backend('pytorch')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
input_shape = (9, 8, 7)
output_shape = (6, 5, 4)

model = TCL(input_shape, output_shape, bias_rank=False, normalize=False)
U = [u.clone().detach() for u in model.factors]

In [12]:
x = torch.randn(input_shape)

y = model(x)
y2 = multi_mode_dot(x, U, transpose=True)

tl.norm(y - y2), y.shape

(tensor(3.1164e-05, grad_fn=<CopyBackwards>), torch.Size([6, 5, 4]))

In [13]:
model = TCL(input_shape, output_shape, bias_rank=2)
y = model(x)

In [14]:
input_shape = (9, 8, 7)
output_shape = (6, 5, 4)

model = TCL3D(input_shape, output_shape, bias_rank=0, normalize=False)
# U = [u.weight.data.clone().detach() for u in model.factors]
U = [u.clone().detach() for u in model.factors]

In [15]:
x = torch.randn(input_shape)

y = model(x)
y2 = multi_mode_dot(x, U, transpose=True)

tl.norm(y - y2), y.shape

(tensor(3.1226e-05, grad_fn=<CopyBackwards>), torch.Size([6, 5, 4]))

In [16]:
model = TCL3D(input_shape, output_shape, bias_rank="full")
y = model(x)

In [24]:
input_shape = (11, 10, 9)
output_shape = (8, 7)
core_shape = (6, 5, 4, 3, 2)

model = TRL(input_shape, output_shape, core_shape, bias_rank=False, freeze_modes=[])
Uinner = [u.clone().detach() for u in model.factors_inner]
Uouter = [u.clone().detach() for u in model.factors_outer]
core = model.core.clone().detach()

In [25]:
x = torch.randn(input_shape)

y = model(x)
y1 = multi_mode_dot(x, Uinner, transpose=True)
y2 = tensordot(y1, core, ([0, 1, 2], [0, 1, 2]))
y3 = multi_mode_dot(y2, Uouter, transpose=True)

tl.norm(y - y3), y.shape, y1.shape, y2.shape, y3.shape

(tensor(6817.6353, grad_fn=<CopyBackwards>),
 torch.Size([8, 7]),
 torch.Size([6, 5, 4]),
 torch.Size([3, 2]),
 torch.Size([8, 7]))

In [57]:
input_shape = (11, 10, 9)
output_shape = (8, 7)
core_shape = (6, 5, 4)

model = TRLhalf(input_shape, output_shape, core_shape, bias_rank=False, freeze_modes=[], normalize=False)
Uinner = [u.clone().detach() for u in model.factors_inner]
# Uouter = [u.clone().detach() for u in model.factors_outer]
core = model.core.clone().detach()

TRL: abc,ad,be,cf,defgh->gh


In [59]:
x = torch.randn(input_shape)

y = model(x)
y1 = multi_mode_dot(x, Uinner, transpose=True)
y2 = tensordot(y1, core, ([0, 1, 2], [0, 1, 2]))
# y3 = multi_mode_dot(y2, Uouter, transpose=True)

tl.norm(y - y2), y.shape, y1.shape, y2.shape, y3.shape

TRL forward: abc,ad,be,cf,defgh->gh, X shape: torch.Size([11, 10, 9]), 3, torch.Size([6, 5, 4, 8, 7])


(tensor(0.0005, grad_fn=<CopyBackwards>),
 torch.Size([8, 7]),
 torch.Size([6, 5, 4]),
 torch.Size([8, 7]),
 torch.Size([8, 7]))

# Load Masked Model

In [1]:
from utils import datasets, kl_cpd, models_v2 as models, nets_tl, nets_original, metrics

import warnings
warnings.filterwarnings("ignore")

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

import torch
import matplotlib.pyplot as plt
import numpy as np

import os

args_local = {"ext_name": "x3d_m", "timestamp": 1653480783}

In [2]:
experiments_name = 'explosion'
train_dataset, test_dataset = datasets.CPDDatasets(experiments_name=experiments_name).get_dataset_()

name = args_local["ext_name"]
timestamp = args_local["timestamp"]
checkpoint = torch.load(f'saves/models/model_{name}_tl_{timestamp}.pth')    
state_dict = checkpoint["checkpoint"]
args = checkpoint["args"]
block_type = args["block_type"]
bias = args["bias"]
if "name" not in args:
    args["name"] = name


seed = 0 # args["seed"]
models.fix_seeds(seed)
experiments_name = ('explosion')
    
if block_type == "linear":
    netG = nets_original.NetG(args)
    netD = nets_original.NetD(args)
elif args["block_type"] == "masked":
    netG = nets_original.NetG_Masked(args)
    netD = nets_original.NetD_Masked(args)
else:
    netG = nets_tl.NetG_TL(args, block_type=block_type, bias=bias)
    netD = nets_tl.NetD_TL(args, block_type=block_type, bias=bias)

extractor = torch.hub.load('facebookresearch/pytorchvideo:main', args["name"], pretrained=True)
extractor = torch.nn.Sequential(*list(extractor.blocks[:5]))

kl_cpd_model = models.KLCPDVideo(netG, netD, args, train_dataset=train_dataset, test_dataset=test_dataset, extractor=extractor)
kl_cpd_model.load_state_dict(state_dict)

Equal sampling is impossible, do random sampling.


Using cache found in /home/eromanenkova/.cache/torch/hub/facebookresearch_pytorchvideo_main


<All keys matched successfully>

In [3]:
kl_cpd_model.netD.mask_loss().item()#, kl_cpd_model.netG.mask_loss().item()

80.34313201904297

In [4]:
kl_cpd_model.netD.fc1.weight.data.norm(p=1, dim=1)

tensor([ 0.1619,  0.1587,  0.1654,  1.4966,  3.1787,  0.1532,  0.1614,  0.1567,  0.5163,  0.1555,  0.1786,  0.1497,  0.1537,  0.1584,  0.2136,  0.1579,  0.1750,  0.7142,  0.1520,  0.1626,  0.1559,
         1.4104,  0.1614,  0.1769,  0.1743,  0.2265,  0.1777,  0.1983,  0.1613,  0.1595,  0.1569,  0.1571,  0.1661,  5.3898,  0.1769,  0.1547,  0.1558, 11.1559,  0.1546,  6.2068,  0.2086,  0.1427,
         0.1580,  0.1969,  0.1663,  0.1605,  0.1563,  0.1512,  0.1414,  0.1546,  0.3022,  0.1583,  0.1634,  0.1629,  2.4532,  0.1590,  0.7235,  0.1506,  0.2006,  0.9223,  0.1620,  0.1620,  0.1528,
         0.1522,  0.2117,  0.1640,  0.2124,  0.1568,  0.1574,  0.1530,  6.4739,  0.1671,  0.1600,  0.2782,  0.1630,  3.8637,  0.1824,  0.1546,  0.2138,  0.1583,  0.1559,  0.1570,  0.1671,  0.4818,
         0.1596,  0.1570,  0.1751,  1.5675,  0.4229,  0.1594,  0.1640,  0.1730,  0.1605,  0.1723,  0.2427,  0.1574,  0.1472,  0.1530,  0.1568,  0.1586])

In [5]:
kl_cpd_model.netD.fc2.weight.data.norm(p=1, dim=0)

tensor([0.1487, 0.1487, 0.1479, 0.1477, 0.1482, 0.1494, 0.1482, 0.1468, 0.1500, 0.1497, 0.1501, 0.1485, 0.1492, 0.1470, 0.1481, 0.1489, 0.1504, 0.1498, 0.1491, 0.1484, 0.1496, 0.1480, 0.1477, 0.1472,
        0.1487, 0.1478, 0.1482, 0.1494, 0.1492, 0.1492, 0.1485, 0.1499, 0.1501, 0.1485, 0.1479, 0.1489, 0.1483, 0.1492, 0.1495, 0.1488, 0.1494, 0.1476, 0.1479, 0.1517, 0.1471, 0.1483, 0.1468, 0.1463,
        0.1476, 0.1483, 0.1486, 0.1463, 0.1484, 0.1484, 0.1494, 0.1494, 0.1483, 0.1492, 0.1490, 0.1482, 0.1502, 0.1485, 0.1493, 0.1486, 0.1485, 0.1468, 0.1494, 0.1487, 0.1509, 0.1495, 0.1487, 0.1496,
        0.1483, 0.1499, 0.1502, 0.1477, 0.1495, 0.1478, 0.1508, 0.1487, 0.1492, 0.1476, 0.1494, 0.1477, 0.1506, 0.1504, 0.1483, 0.1506, 0.1493, 0.1492, 0.1507, 0.1486, 0.1458, 0.1492, 0.1465, 0.1479,
        0.1471, 0.1494, 0.1483, 0.1469])

In [9]:
l1_norm = kl_cpd_model.netD.fc1.weight.data.norm(p=1, dim=1)
median = torch.quantile(l1_norm, q=0.9)
mask = l1_norm > median
mask.sum(), median

(tensor(10), tensor(0.9711))

In [10]:
kl_cpd_model.netD.mask1[0, 0] = mask * 1.
kl_cpd_model.netD.mask1[0, 0]

tensor([0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.])

In [11]:
l1_norm = kl_cpd_model.netD.fc2.weight.data.norm(p=1, dim=0)
median = torch.quantile(l1_norm, q=0.9)
mask = l1_norm > median
mask.sum(), median, l1_norm

(tensor(10),
 tensor(0.1502),
 tensor([0.1487, 0.1487, 0.1479, 0.1477, 0.1482, 0.1494, 0.1482, 0.1468, 0.1500, 0.1497, 0.1501, 0.1485, 0.1492, 0.1470, 0.1481, 0.1489, 0.1504, 0.1498, 0.1491, 0.1484, 0.1496, 0.1480, 0.1477, 0.1472,
         0.1487, 0.1478, 0.1482, 0.1494, 0.1492, 0.1492, 0.1485, 0.1499, 0.1501, 0.1485, 0.1479, 0.1489, 0.1483, 0.1492, 0.1495, 0.1488, 0.1494, 0.1476, 0.1479, 0.1517, 0.1471, 0.1483, 0.1468, 0.1463,
         0.1476, 0.1483, 0.1486, 0.1463, 0.1484, 0.1484, 0.1494, 0.1494, 0.1483, 0.1492, 0.1490, 0.1482, 0.1502, 0.1485, 0.1493, 0.1486, 0.1485, 0.1468, 0.1494, 0.1487, 0.1509, 0.1495, 0.1487, 0.1496,
         0.1483, 0.1499, 0.1502, 0.1477, 0.1495, 0.1478, 0.1508, 0.1487, 0.1492, 0.1476, 0.1494, 0.1477, 0.1506, 0.1504, 0.1483, 0.1506, 0.1493, 0.1492, 0.1507, 0.1486, 0.1458, 0.1492, 0.1465, 0.1479,
         0.1471, 0.1494, 0.1483, 0.1469]))

In [12]:
kl_cpd_model.netD.mask2[0, 0] = mask * 1.
kl_cpd_model.netD.mask2[0, 0]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 1., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.])

In [13]:
args["alphaD"], args["alphaG"] = 0, 0

In [14]:
logger = TensorBoardLogger(save_dir='logs/explosion', name='kl_cpd')
early_stop_callback = EarlyStopping(monitor="val_mmd2_real_D", stopping_threshold=1e-5, 
                                    verbose=True, mode="min", patience=5)



for param in kl_cpd_model.extractor.parameters():
    param.requires_grad = False

trainer = pl.Trainer(
    max_epochs=5, # 100
    gpus='1',
    # devices='1',
    benchmark=True,
    check_val_every_n_epoch=1,
    gradient_clip_val=args['grad_clip'],
    logger=logger,
    callbacks=early_stop_callback
)

trainer.fit(kl_cpd_model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name      | Type        | Params
------------------------------------------
0 | netG      | NetG_Masked | 1.4 M 
1 | netD      | NetD_Masked | 2.5 M 
2 | extractor | Sequential  | 2.0 M 
------------------------------------------
4.0 M     Trainable params
2.0 M     Non-trainable params
6.0 M     Total params
23.865    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Stopping threshold reached: val_mmd2_real_D = 1.244907821273955e-08 < 1e-05. Signaling Trainer to stop.


In [15]:
torch.save({"checkpoint": kl_cpd_model.state_dict(), "args": args}, 
            f'saves/models/model_{args_local["ext_name"]}_tl_8{args_local["timestamp"]}.pth')

# Load checkpoint

In [5]:
import torch
from pathlib import Path

In [6]:
path_models = Path("saves/models")
path_results = Path("saves/results")
for model_checkpoint in path_models.iterdir():
    print(model_checkpoint)
    checkpoint = torch.load(model_checkpoint)
    # print({key:checkpoint["args"][key] for key in ['RNN_hid_dim', 'emb_dim', 'bias_rank']})
    print(checkpoint["args"])
    try:
        print((path_results / f'metrics{Path(model_checkpoint).name[5:-4]}.txt').open().read())
    except:
        pass
    print()

saves/models/model_x3d_m_tl_1652866491.pth
{'seed': 102, 'block_type': 'tcl3d', 'bias': 'all', 'epochs': 150, 'wnd_dim': 4, 'batch_size': 8, 'lr': 0.0001, 'weight_decay': 0.0, 'grad_clip': 10, 'CRITIC_ITERS': 5, 'weight_clip': 0.1, 'lambda_ae': 0.1, 'lambda_real': 10, 'num_layers': 1, 'window_1': 4, 'window_2': 4, 'sqdist': 50, 'data_dim': (192, 8, 8), 'RNN_hid_dim': (32, 4, 4), 'emb_dim': (64, 8, 8), 'bias_rank': 8}
SEED: x3d_m tn 25
AUC: 1.4929453895212106
Time to FA 15.2317, delay detection 0.3175 for best-F1 threshold: 0.0513
TN 266, FP 34, FN 12, TP 3 for best-F1 threshold: 0.0513
Max F1 0.1154: for best-F1 threshold 0.0513
COVER 0.9433: for best-F1 threshold 0.0513
Max COVER 0.9806299603174603: for threshold -0.001
----------------------------------------------------------------------
SEED: x3d_m tn 25 norm
AUC: 1.935243146019494
Time to FA 15.9429, delay detection 0.3492 for best-F1 threshold: 0.8893
TN 296, FP 4, FN 14, TP 1 for best-F1 threshold: 0.8893
Max F1 0.1: for best-F1

In [None]:
# saves/models/model_x3d_m_tl_1653380917.pth
{'seed': 102, 'block_type': 'tcl3d', 'bias': 'all', 'epochs': 150, 'name': 'x3d_m', 'wnd_dim': 4, 'batch_size': 8, 'lr': 0.0001, 'weight_decay': 0.0, 'grad_clip': 10, 'CRITIC_ITERS': 5, 'weight_clip': 0.1, 'lambda_ae': 0.1, 'lambda_real': 10, 'num_layers': 1, 'window_1': 4, 'window_2': 4, 'sqdist': 50, 'data_dim': (192, 8, 8), 'RNN_hid_dim': (16, 4, 4), 'emb_dim': (32, 8, 8), 'bias_rank': 8}
# saves/models/model_x3d_m_tl_1653374508.pth
{'seed': 102, 'block_type': 'tcl3d', 'bias': 'all', 'epochs':  50, 'name': 'x3d_m', 'wnd_dim': 4, 'batch_size': 8, 'lr': 0.0001, 'weight_decay': 0.0, 'grad_clip': 10, 'CRITIC_ITERS': 5, 'weight_clip': 0.1, 'lambda_ae': 0.1, 'lambda_real': 10, 'num_layers': 1, 'window_1': 4, 'window_2': 4, 'sqdist': 50, 'data_dim': (192, 8, 8), 'RNN_hid_dim': (16, 4, 4), 'emb_dim': (32, 8, 8), 'bias_rank': 8}
