In [105]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from utils import (
    plot_3d_motion_frames_multiple,
    plot_3d_motion_animation,
    plot_3d_motion_frames_multiple,
    activation_dict,
)
from glob import glob
import matplotlib.pyplot as plt
import os
import shutil
import numpy as np
from modules.loss import VAE_Loss

from torch import Tensor
from typing import List, Optional
import copy

def lengths_to_mask(lengths: List[int],
                    device: torch.device,
                    max_len: Optional[int] = None) -> Tensor:
    """
    Build a boolean mask marking the valid positions of each sequence.

    Args:
        lengths: number of valid elements for each sequence.
        device: device on which the mask is created.
        max_len: width of the mask. When None, the largest entry of
            ``lengths`` is used (0 if ``lengths`` is empty).

    Returns:
        Bool tensor of shape (len(lengths), max_len); entry [i, j] is True
        iff j < lengths[i].
    """
    lengths = torch.tensor(lengths, device=device)
    if max_len is None:
        # `is None` (rather than truthiness) keeps an explicit max_len=0
        # meaningful; the numel guard avoids max() on an empty tensor,
        # which raises.
        max_len = int(lengths.max()) if lengths.numel() > 0 else 0
    positions = torch.arange(max_len, device=device)
    # Broadcast (1, max_len) against (n, 1) to obtain the (n, max_len) mask.
    return positions.unsqueeze(0) < lengths.unsqueeze(1)

def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

def _get_clone(module):
    return copy.deepcopy(module)

class SkipTransformerEncoder(nn.Module):
    """Encoder stack whose second half re-uses the outputs of the first half.

    The stack is made of ``(num_layers - 1) // 2`` input layers, one middle
    layer and the same number of output layers, all deep copies of
    ``encoder_layer``. Each output layer receives the running activation
    concatenated (on the last axis) with the output of the mirrored input
    layer, projected from ``2 * d_model`` back to ``d_model`` by a linear
    layer.

    Args:
        encoder_layer: template layer; its forward must accept ``src_mask``
            and ``src_key_padding_mask`` keyword arguments.
        num_layers: total number of layers; must be odd.
        norm: optional module applied to the final output.
        d_model: size of the last axis of the activations.
    """

    def __init__(self, encoder_layer, num_layers, norm=None, d_model=256):
        super().__init__()
        self.d_model = d_model

        self.num_layers = num_layers
        self.norm = norm

        assert num_layers % 2 == 1

        # Same number of layers on each side of the single middle layer.
        half = (num_layers - 1) // 2
        self.input_blocks = _get_clones(encoder_layer, half)
        self.middle_block = _get_clone(encoder_layer)
        self.output_blocks = _get_clones(encoder_layer, half)
        # Projects a [activation, skip] concatenation back to d_model.
        fuse_template = nn.Linear(2 * self.d_model, self.d_model)
        self.linear_blocks = _get_clones(fuse_template, half)

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dimension."""
        matrices = (param for param in self.parameters() if param.dim() > 1)
        for weight in matrices:
            nn.init.xavier_uniform_(weight)

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        """Run ``src`` through the input, middle and output layers.

        ``mask`` and ``src_key_padding_mask`` are forwarded unchanged to
        every layer; ``pos`` is accepted but not used.
        """
        layer_kwargs = dict(src_mask=mask,
                            src_key_padding_mask=src_key_padding_mask)

        hidden = src
        skips = []
        for layer in self.input_blocks:
            hidden = layer(hidden, **layer_kwargs)
            skips.append(hidden)

        hidden = self.middle_block(hidden, **layer_kwargs)

        # Skips are consumed last-in, first-out: the first output layer is
        # paired with the last input layer.
        for layer, fuse, skip in zip(self.output_blocks,
                                     self.linear_blocks,
                                     reversed(skips)):
            hidden = fuse(torch.cat([hidden, skip], dim=-1))
            hidden = layer(hidden, **layer_kwargs)

        if self.norm is None:
            return hidden
        return self.norm(hidden)

class CascadingTransformerAutoEncoder(nn.Module):
    """Sequence autoencoder alternating skip-transformer stages with strided
    2-D (transposed) convolutions.

    ``encode`` embeds 66 features per frame, shrinks the frame axis with two
    strided convolutions and samples a latent from Normal(mu, std).
    ``decode`` expands the latent with two transposed convolutions and maps
    it back to 66 features per frame.

    NOTE(review): ``nhead``, ``num_layers`` and ``dim_feedforward`` are
    accepted but never read; every stage uses hard-coded sizes. ``dropout``
    and ``activation`` are only forwarded to ``skip_trans_dec``. ``d_model``
    only sizes ``skel_enc`` (and ``latent_dim``) while ``skip_trans_enc`` is
    fixed at 256, so other values presumably break ``encode`` - TODO confirm.
    """
    def __init__(self, d_model, nhead, num_layers, dim_feedforward=2048, dropout=0.1, activation='relu', seq_len=120, verbose=False):
        super(CascadingTransformerAutoEncoder, self).__init__()
        self.latent_size = 1   # 1 for single timestep
        self.latent_dim = d_model # 256
        # Width of the mask built in decode().
        self.seq_len = seq_len
        # When True, encode()/decode() print the shape after every stage.
        self.verbose = verbose
        self.conv1_out_channels = 8
        
        # ENCODER
        # Per-frame embedding: 66 input features -> d_model.
        self.skel_enc = nn.Linear(66, d_model)
    
        self.skip_trans_enc = SkipTransformerEncoder(
            encoder_layer= nn.TransformerEncoderLayer(
                d_model=256, nhead=8, dim_feedforward=1024, 
                dropout=0.1, activation='relu', 
                norm_first=False, batch_first=True),
            num_layers=3,
            norm=nn.LayerNorm(256),
            d_model=256
        )

        # The kernel spans the whole 256-wide feature axis, so that axis
        # collapses to 1; the frame axis is strided by 2.
        self.conv2d_enc = nn.Conv2d(
                            in_channels=1,
                            out_channels=self.conv1_out_channels,
                            kernel_size=(3, 256),
                            stride=(2, 1),
                            padding=(0, 0))
        
        # NOTE(review): constructed (and therefore part of the parameters),
        # but its call in encode() is commented out.
        self.skip_trans_enc2 = SkipTransformerEncoder(
            encoder_layer= nn.TransformerEncoderLayer(
                d_model=self.conv1_out_channels, nhead=8, dim_feedforward=64,
                dropout=0.1, activation='relu',
                norm_first=False, batch_first=True),
            num_layers=3,
            norm=nn.LayerNorm(self.conv1_out_channels),
            d_model=self.conv1_out_channels
        )

        # Shrinks the frame axis again (kernel 4, stride 3); the width-1
        # kernel leaves the 8-channel axis untouched.
        self.conv2d_enc2 = nn.Conv2d(
                            in_channels=1,
                            out_channels=1,
                            kernel_size=(4, 1),
                            stride=(3, 1),
                            padding=(0, 0))
        
        # DECODER
        self.transconv2d_dec = nn.ConvTranspose2d(
                            in_channels=1,
                            out_channels=8,
                            kernel_size=(4, 1),
                            stride=(3, 2),
                            padding=(0, 0), 
                            output_padding=(1, 1))
        
        # The only stage that honours the `dropout` / `activation` arguments.
        self.skip_trans_dec = SkipTransformerEncoder(
            encoder_layer= nn.TransformerEncoderLayer(
                d_model=64, nhead=4, dim_feedforward=256,
                dropout=dropout, activation=activation,
                norm_first=False, batch_first=True),
            num_layers=3,
            norm=nn.LayerNorm(64),
            d_model=64
        )

        self.transconv2d_dec2 = nn.ConvTranspose2d(
                            in_channels=1,
                            out_channels=1,
                            kernel_size=(3, 64),
                            stride=(2, 3),
                            padding=(0, 0), 
                            output_padding=(1, 2))
        
        self.skip_trans_dec2 = SkipTransformerEncoder(
            encoder_layer= nn.TransformerEncoderLayer(
                d_model=255, nhead=15, dim_feedforward=1024,
                dropout=0.1, activation='relu',
                norm_first=False, batch_first=True),
            num_layers=3,
            norm=nn.LayerNorm(255),
            d_model=255
        )
        
        # Maps the 255-wide decoder features back to 66 features per frame.
        self.final_layer = nn.Linear(255, 66)
        
    def forward(self, src: Tensor):
        """Encode ``src``, decode the sampled latent and return
        ``(output, mu, logvar)``; the mask produced by ``decode`` is dropped.
        """
        z, lengths, mu, logvar = self.encode(src)
        # mu, logvar = dist[:1], dist[1:]
        # z = self.reparameterize(mu, logvar)
        output, mask = self.decode(z, lengths)
        return output, mu, logvar

    def encode(self, src):
        """Encode ``src`` into a sampled latent.

        Args:
            src: tensor unpacked as (batch, nframes, nfeats); ``skel_enc``
                expects nfeats == 66.

        Returns:
            ``(latent, lengths, mu, logvar)``. ``lengths`` holds
            ``len(item)`` for every batch item, ``mu`` is the first 4
            channels of the last axis, ``logvar`` the remaining ones, and
            ``latent`` is a reparameterised sample from Normal(mu, std).
            In the recorded run a (32, 420, 66) input gave a (32, 69, 4)
            latent.
        """
        # get lengths
        lengths = [len(feature) for feature in src]


        if self.verbose: print('ENCODING')
        # get shapes
        bs, nframes, nfeats = src.shape
        if self.verbose: print('batch size:', bs, 'nframes:', nframes, 'nfeats:', nfeats)

        # skeletal embedding
        x = self.skel_enc(src)
        if self.verbose: print('skel_enc:', x.shape)

        # pass through transformerencoder with skip connections
        x = self.skip_trans_enc(x)
        if self.verbose: print('skip trans enc:', x.shape)

        # make small with conv2d
        # add a channel axis so the (frames, features) plane is convolved
        x = x.unsqueeze(1)
        x = self.conv2d_enc(x)
        if self.verbose: print('conv2d:', x.shape)

        # pass through transformerencoder with skip connections
        # drop the collapsed feature axis and move channels last
        x = x.squeeze(-1).permute(0, 2, 1)
        # NOTE(review): the second transformer stage is disabled, so the
        # print below reports the shape after the permute only.
        # x = self.skip_trans_enc2(x)
        if self.verbose: print('skip trans enc2:', x.shape)

        # make small with conv2d
        x = x.unsqueeze(1)
        x = self.conv2d_enc2(x).squeeze(1)
        if self.verbose: print('conv2d2:', x.shape)

        # split the last axis: first 4 channels are mu, the rest logvar
        mu, logvar = x[:, :, :4], x[:, :, 4:]
        # resample
        # std = sqrt(exp(logvar)); rsample() keeps the sample differentiable
        std = logvar.exp().pow(0.5)
        dist = torch.distributions.Normal(mu, std)
        latent = dist.rsample()

        if self.verbose: print('latent:', latent.shape)
        # number of latent values per batch item; only used for the print
        latentdim = torch.prod(torch.tensor(latent.shape[1:]))
        if self.verbose: print('latentdim:', latentdim)
        return latent, lengths, mu, logvar
    
    
    def decode(self, z: Tensor, lengths: List[int]):
        """Decode a latent back to per-frame features.

        Args:
            z: latent as returned by ``encode``.
            lengths: per-item lengths used to build the mask.

        Returns:
            ``(feats, mask)``. ``mask`` comes from ``lengths_to_mask`` with
            width ``self.seq_len``; ``feats`` is the final-layer output with
            positions where ``mask`` is False set to 0 and the first two
            axes swapped ((420, 32, 66) in the recorded run).

        NOTE(review): ``output[~mask] = 0`` presumably requires the decoded
        frame count to equal ``self.seq_len`` - TODO confirm for inputs
        other than the recorded 420-frame case.
        """
        if self.verbose: print('DECODING')
        mask = lengths_to_mask(lengths, z.device, self.seq_len)
        bs, nframes = mask.shape
        if self.verbose: print('batch size:', bs, 'nframes:', nframes)

        # expand with convtranspose2d
        # (this first print shows the latent shape before expansion)
        if self.verbose: print('transconv2d:', z.shape)
        z = z.unsqueeze(1)
        z = self.transconv2d_dec(z)
        if self.verbose: print('transconv2d:', z.shape)

        # apply transformer
        # bring the frame axis forward, then flatten the remaining two axes
        # into one feature axis
        z = z.permute(0, 2, 1, 3)
        z = z.reshape(z.shape[0], z.shape[1], -1)
        z = self.skip_trans_dec(z)
        if self.verbose: print('skip trans dec:', z.shape)

        # expand with convtranspose2d
        z = z.unsqueeze(1)
        z = self.transconv2d_dec2(z)
        if self.verbose: print('transconv2d2:', z.shape)

        # apply transformer
        z = z.squeeze(1)
        z = self.skip_trans_dec2(z)
        if self.verbose: print('skip trans dec2:', z.shape)
        
        # final layer
        output = self.final_layer(z)
        if self.verbose: print('final layer:', output.shape)
        # zero every position the mask marks as invalid (in place)
        output[~mask] = 0
        # swap the first two axes of the output
        feats = output.permute(1, 0, 2)
        if self.verbose: print('feats:', feats.shape)

        return feats, mask
    

# Smoke test: one random batch of 32 sequences, 420 frames, 66 features each.
sample = torch.randn(32, 420, 66)

# verbose=True makes encode()/decode() print the shape after every stage.
model = CascadingTransformerAutoEncoder(
    d_model=256,
    nhead=8,
    num_layers=5,
    dim_feedforward=2048,
    dropout=0.1,
    activation='relu',
    seq_len=420,
    verbose=True,
)
output, mu, logvar = model(sample)

ENCODING
batch size: 32 nframes: 420 nfeats: 66
skel_enc: torch.Size([32, 420, 256])
skip trans enc: torch.Size([32, 420, 256])
conv2d: torch.Size([32, 8, 209, 1])
skip trans enc2: torch.Size([32, 209, 8])
conv2d2: torch.Size([32, 69, 8])
latent: torch.Size([32, 69, 4])
latentdim: tensor(276)
DECODING
batch size: 32 nframes: 420
transconv2d: torch.Size([32, 69, 4])
transconv2d: torch.Size([32, 8, 209, 8])
skip trans dec: torch.Size([32, 209, 64])
transconv2d2: torch.Size([32, 1, 420, 255])
skip trans dec2: torch.Size([32, 420, 255])
final layer: torch.Size([32, 420, 66])
feats: torch.Size([420, 32, 66])


In [94]:
mask

tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        ...,
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]])

torch.Size([32, 8, 60, 1])