In [4]:
import torch
import torch.nn as nn

In [200]:
class custom_GRU_cell(torch.nn.Module):
    """One GRU time step assembled from separate ``nn.Linear`` layers.

    Given the current input ``x_t`` and the previous hidden state ``h_prev``:

        z     = sigmoid(W_z x_t + U_z h_prev)         # update gate
        r     = sigmoid(W_r x_t + U_r h_prev)         # reset gate
        h_hat = tanh(W_h x_t + r * (U_h h_prev))      # candidate state
        h     = z * h_prev + (1 - z) * h_hat

    Each ``W_*`` / ``U_*`` is an ``nn.Linear``, so every one of them has its own bias.
    """

    def __init__(self, input_dim, hidden_dim):
        """Create the six linear maps and three activations of the cell.

        Args:
            input_dim: size of the last dimension of ``x_t``.
            hidden_dim: size of the last dimension of the hidden state.
        """
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # Update gate (z): input projection, hidden projection, squashing.
        self.linear_w_z = nn.Linear(input_dim, hidden_dim)
        self.linear_u_z = nn.Linear(hidden_dim, hidden_dim)
        self.activation_z = nn.Sigmoid()

        # Reset gate (r).
        self.linear_w_r = nn.Linear(input_dim, hidden_dim)
        self.linear_u_r = nn.Linear(hidden_dim, hidden_dim)
        self.activation_r = nn.Sigmoid()

        # Candidate hidden state (h_hat).
        self.linear_w_h = nn.Linear(input_dim, hidden_dim)
        self.linear_u_h = nn.Linear(hidden_dim, hidden_dim)
        self.activation_h = nn.Tanh()

    def forward(self, x_t, h_prev):
        """Advance the hidden state by one step.

        Args:
            x_t: input tensor whose last dimension is ``input_dim``.
            h_prev: previous hidden state whose last dimension is ``hidden_dim``.

        Returns:
            The new hidden state, broadcast-compatible with ``h_prev``.
        """
        update_gate = self.activation_z(self.linear_w_z(x_t) + self.linear_u_z(h_prev))
        reset_gate = self.activation_r(self.linear_w_r(x_t) + self.linear_u_r(h_prev))

        # The reset gate scales the already-projected previous state, not h_prev itself.
        gated_memory = reset_gate * self.linear_u_h(h_prev)
        candidate = self.activation_h(self.linear_w_h(x_t) + gated_memory)

        # Convex combination: the update gate keeps h_prev, its complement admits the candidate.
        keep_old = update_gate * h_prev
        take_new = (1 - update_gate) * candidate
        return keep_old + take_new
    
    
class custom_GRU(torch.nn.Module):
    """Unrolls one ``custom_GRU_cell`` along dimension 1 of the input."""

    def __init__(self, input_dim, hidden_dim):
        """Build the recurrent cell.

        Args:
            input_dim: size of the last dimension of the input sequence.
            hidden_dim: size of the hidden state.
        """
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.cell = custom_GRU_cell(input_dim, hidden_dim)

    def forward(self, inputs):
        """Run the cell over every step of ``inputs``, starting from a zero state.

        Args:
            inputs: tensor whose dimension 0 is the batch and dimension 1 is the
                step axis; for a 3-D input this is ``(batch, steps, input_dim)``.

        Returns:
            A tuple ``(outputs, last_hidden)``:

            * ``outputs`` - the per-step hidden states stacked along dimension 1.
              They are detached from the autograd graph and moved to the CPU,
              so they are for inspection only.
            * ``last_hidden`` - the hidden state after the final step, with
              dimension 1 squeezed out. It stays on the input's device and
              keeps its gradient history.
        """
        # new_zeros inherits device and dtype from `inputs`, so the initial state
        # always matches what the cell's linear layers will receive.
        out_t = inputs.new_zeros(inputs.shape[0], 1, self.hidden_dim)

        outputs = []
        # chunk(steps, dim=1) yields one slice of length 1 per step.
        for x_t in inputs.chunk(inputs.shape[1], dim=1):
            out_t = self.cell(x_t, out_t)
            outputs.append(out_t.squeeze(1).detach().cpu())
        outputs = torch.stack(outputs, 1)
        return outputs, out_t.squeeze(1)

In [195]:
# Random dummy batch; its last dimension (25) matches the input_dim given to custom_GRU below.
x = torch.rand(3, 5, 25)

In [196]:
# Instantiate the hand-written GRU with input_dim=25 and hidden_dim=8.
model = custom_GRU(25, 8)

### Test custom_GRU_TCL

In [5]:
from utils import datasets, kl_cpd, models_v2 as models, nets_tl, nets_original
# Reload edited project modules automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

import torch 

import warnings
# NOTE(review): this silences every warning category, including deprecation
# notices from torch / pytorch_lightning.
warnings.filterwarnings("ignore")

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

import os
# os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
# Expose only the GPU with index 1 to CUDA.
# NOTE(review): torch is already imported at this point; the setting only takes
# effect if CUDA has not been initialised yet - confirm, or move it above the imports.
os.environ["CUDA_VISIBLE_DEVICES"]="1"

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
# Shared hyper-parameter dict; the following cells add their keys to it.
args = {}
# Block variant selected for the nets_tl networks built in the cells below.
block_type="tcl3d"

# Name of the experiment whose train/test datasets are loaded.
experiments_name = 'explosion'
train_dataset, test_dataset = datasets.CPDDatasets(experiments_name=experiments_name).get_dataset_()

Equal sampling is impossible, do random sampling.


In [7]:
# Pick the layer-size preset for the current block type and merge it into `args`.
# An unrecognised block type leaves `args` untouched.
args.update({
    # For TCL3D
    "tcl3d": {
        'data_dim': (192, 8, 8),
        'RNN_hid_dim': (32, 8, 8),  # 3072
        'emb_dim': (64, 8, 8),  # 3072
        'bias_rank': 4,
    },
    # For TCL
    "tcl": {
        'data_dim': (192, 8, 8),
        'RNN_hid_dim': (32, 8, 8),  # 3072
        'emb_dim': (64, 8, 8),  # 3072
    },
    # For TRL
    "trl": {
        'data_dim': (192, 8, 8),
        'RNN_hid_dim': (8, 8, 8, 8),  # 3072
        'emb_dim': (8, 8, 8, 8),  # 3072
        'ranks_input': (16, 4, 4, 4, 4, 4, 4),  # 3072
        'ranks_output': (4, 4, 4, 4, 16, 4, 4),  # 3072
        'ranks_gru': (4, 4, 4, 4, 4, 4, 4, 4),  # 3072
    },
}.get(block_type, {}))

# Random probe tensor and the NetD_TL network it is fed to in the next cell.
input = torch.randn(5, 4, 192, 8, 8)
model = nets_tl.NetD_TL(args, block_type=block_type, bias="all")

In [8]:
# Forward the random tensor through NetD_TL and show the shape of the second returned element.
model(input)[1].shape

torch.Size([5, 4, 192, 8, 8])

In [9]:
# Number of scalar parameters held by `model`.
total = sum(tensor.numel() for tensor in model.parameters())

total

287232

In [10]:
# Layer-size presets keyed by block type; the entry for the current
# `block_type` (if any) is merged into `args`.
size_presets = {
    # For TCL3D
    "tcl3d": dict(
        data_dim=(192, 8, 8),
        RNN_hid_dim=(32, 8, 8),  # 3072
        emb_dim=(64, 8, 8),  # 3072
        bias_rank=4,
    ),
    # For TCL
    "tcl": dict(
        data_dim=(192, 8, 8),
        RNN_hid_dim=(32, 8, 8),  # 3072
        emb_dim=(64, 8, 8),  # 3072
    ),
    # For TRL
    "trl": dict(
        data_dim=(192, 8, 8),
        RNN_hid_dim=(8, 8, 8, 8),  # 3072
        emb_dim=(8, 8, 8, 8),  # 3072
        ranks_input=(16, 4, 4, 4, 4, 4, 4),  # 3072
        ranks_output=(4, 4, 4, 4, 16, 4, 4),  # 3072
        ranks_gru=(4, 4, 4, 4, 4, 4, 4, 4),  # 3072
    ),
}
for key, value in size_presets.get(block_type, {}).items():
    args[key] = value
del size_presets

# Random probe tensor and the NetG_TL network it is fed to in the next cell.
input = torch.randn(5, 4, 192, 8, 8)
model = nets_tl.NetG_TL(args, block_type=block_type, bias="all")

In [11]:
# Call NetG_TL with the same random tensor as both tensor arguments and 0 as the third; show the output shape.
model(input, input, 0).shape

torch.Size([5, 4, 192, 8, 8])

In [12]:
# Number of scalar parameters held by `model`.
total = sum(map(lambda tensor: tensor.numel(), model.parameters()))

total

242240

In [13]:
# Block variant used for the networks built in the next cell.
block_type = "tcl3d"

# Settings stored in `args`, which is handed to the networks and to KLCPDVideo
# in the next cell.
args['wnd_dim'] = 4
args['batch_size'] = 8
args['lr'] = 1e-4
args['weight_decay'] = 0.
args['grad_clip'] = 10
args['CRITIC_ITERS'] = 5
args['weight_clip'] = .1
args['lambda_ae'] = 0.1 #0.001
args['lambda_real'] = 10 #0.1
args['num_layers'] = 1

# Layer sizes for the selected block type. Every branch sets all the keys it
# needs, so this cell does not depend on values left in `args` by earlier cells.
if block_type == "tcl3d":
    # For TCL3D
    args['data_dim'] = (192, 8, 8)
    args['RNN_hid_dim'] = (32, 8, 8) # 3072
    args['emb_dim'] = (64, 8, 8) # 3072
    args['bias_rank'] = 4

elif block_type == "tcl":
    # For TCL
    args['data_dim'] = (192, 8, 8)
    args['RNN_hid_dim'] = (32, 8, 8) # 3072
    args['emb_dim'] = (64, 8, 8) # 3072

elif block_type == "trl":
    # For TRL
    args['data_dim'] = (192, 8, 8)
    args['RNN_hid_dim'] = (8, 8, 8, 8) # 3072
    args['emb_dim'] = (8, 8, 8, 8) # 3072
    args['ranks_input'] = (16, 4, 4, 4, 4, 4, 4) # 3072
    args['ranks_output'] = (4, 4, 4, 4, 16, 4, 4) # 3072
    args['ranks_gru'] = (4, 4, 4, 4, 4, 4, 4, 4) # 3072

elif block_type == "linear":
    # For Linear
    args['data_dim'] = 12288
    args['RNN_hid_dim'] = 256 # 3072
    args['emb_dim'] = 1024 # 3072

args['window_1'] = 4
args['window_2'] = 4

args['sqdist'] = 50

In [14]:
# Seed via the project helper (presumably seeds the RNGs - confirm in models_v2) before building the networks.
seed = 0
models.fix_seeds(seed)
# NOTE(review): the parentheses do not make a tuple; this is the plain string 'explosion'.
experiments_name = ('explosion')
    
if block_type == "linear":
    # NOTE(review): netG is built from NetD here as well - confirm this is
    # intended and not meant to be a generator class.
    netG = nets_original.NetD(args)
    netD = nets_original.NetD(args)
else:
    netG = nets_tl.NetG_TL(args, block_type=block_type, bias="all")
    netD = nets_tl.NetD_TL(args, block_type=block_type, bias="all")


# Combine both networks, the settings and the datasets loaded earlier into the KL-CPD video model.
kl_cpd_model = models.KLCPDVideo(netG, netD, args, train_dataset=train_dataset, test_dataset=test_dataset)

Using cache found in /home/eromanenkova/.cache/torch/hub/facebookresearch_pytorchvideo_main


In [15]:
# Smoke test: wrap one random tensor in a 1-tuple and run it through the model.
# The module is called directly rather than via `.forward(...)` so that any
# registered hooks run as well.
inputs = (torch.randn((5, 3, 16, 256, 256)),)
kl_cpd_model(inputs)

tensor([0.0144, 0.0142, 0.0143, 0.0176, 0.0154], grad_fn=<SumBackward1>)

In [16]:
# Number of scalar parameters held by the full KL-CPD model.
total = sum(tensor.numel() for tensor in kl_cpd_model.parameters())

total

2535602