In [4]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from utils import (
    plot_3d_motion_frames_multiple,
    plot_3d_motion_animation,
    plot_3d_motion_frames_multiple,
    activation_dict,
)
from glob import glob
import matplotlib.pyplot as plt
import os
import shutil
import numpy as np
from modules.loss import VAE_Loss

from torch import Tensor
from typing import List, Optional
import copy

def lengths_to_mask(lengths: List[int],
                    device: torch.device,
                    max_len: Optional[int] = None) -> Tensor:
    """
    Build a boolean padding mask from per-sequence lengths.

    Parameters:
        lengths: Valid length of each sequence (a list of ints or a 1-D tensor).
        device: Device on which the mask is created.
        max_len: Width of the mask. When omitted (or 0) the largest entry of
                 ``lengths`` is used; an empty ``lengths`` then gives width 0.

    Returns:
        Bool tensor of shape (len(lengths), max_len) whose entry [i, t] is True
        when t < lengths[i].
    """
    # as_tensor accepts lists and tensors alike; torch.tensor() would warn about
    # copy-constructing when the caller already passes a tensor.
    lengths = torch.as_tensor(lengths, device=device)
    if not max_len:
        max_len = int(lengths.max()) if lengths.numel() else 0
    positions = torch.arange(max_len, device=device)
    # (1, max_len) < (batch, 1) broadcasts to (batch, max_len)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)

def _get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    copies = []
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    return nn.ModuleList(copies)

def _get_clone(module):
    """Return one deep copy of ``module``; the copy shares no state with the original."""
    duplicate = copy.deepcopy(module)
    return duplicate

class SkipTransformerEncoder(nn.Module):
    """
    Stack of transformer encoder layers with long skip connections.

    The stack is split into ``(num_layers - 1) // 2`` input blocks, one middle
    block and the same number of output blocks. The output of every input block
    is remembered; each output block first receives its input concatenated with
    the most recently remembered activation, projected back to ``d_model`` by a
    dedicated linear layer.

    Parameters:
        encoder_layer: Layer that is deep-copied for every block.
        num_layers: Total number of encoder layers; must be odd.
        norm: Optional module applied to the final output.
        d_model: Feature width used to size the skip-fusion linear layers.
    """

    def __init__(self, encoder_layer, num_layers, norm=None, d_model=256):
        super().__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.norm = norm

        # One middle layer plus equally many input and output layers.
        assert num_layers % 2 == 1
        half = (num_layers - 1) // 2

        self.input_blocks = _get_clones(encoder_layer, half)
        self.middle_block = _get_clone(encoder_layer)
        self.output_blocks = _get_clones(encoder_layer, half)
        # Fuses [current, skip] (2 * d_model wide) back down to d_model.
        fuse = nn.Linear(2 * self.d_model, self.d_model)
        self.linear_blocks = _get_clones(fuse, half)

        self._reset_parameters()

    def _reset_parameters(self):
        """Re-initialise every parameter with more than one dimension using Xavier-uniform."""
        matrices = (p for p in self.parameters() if p.dim() > 1)
        for weight in matrices:
            nn.init.xavier_uniform_(weight)

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        """
        Run ``src`` through the input, middle and output blocks.

        ``mask`` and ``src_key_padding_mask`` are forwarded unchanged to every
        encoder layer. ``pos`` is accepted but not used by this implementation.
        """
        layer_kwargs = {'src_mask': mask,
                        'src_key_padding_mask': src_key_padding_mask}

        hidden = src
        skips = []
        for layer in self.input_blocks:
            hidden = layer(hidden, **layer_kwargs)
            skips.append(hidden)

        hidden = self.middle_block(hidden, **layer_kwargs)

        # Skips are consumed last-in-first-out, hence the reversed iteration.
        for layer, fuse, skip in zip(self.output_blocks, self.linear_blocks, reversed(skips)):
            hidden = fuse(torch.cat([hidden, skip], dim=-1))
            hidden = layer(hidden, **layer_kwargs)

        if self.norm is None:
            return hidden
        return self.norm(hidden)

class CascadingTransformerAutoEncoder(nn.Module):
    """
    Variational autoencoder for fixed-length sequences of 66 features per frame.

    Encoder: linear frame embedding -> skip transformer -> strided Conv2d over
    the frame axis -> MLP producing ``mu`` and ``logvar``. Decoder: MLP ->
    ConvTranspose2d back to ``seq_len`` frames -> MLP -> skip transformer ->
    linear projection to 66 features.

    Parameters:
        d_model: Width of the frame embedding and of the latent vector. Must be
                 divisible by 64 because the transformer layers use 64 heads.
        nhead, num_layers, dim_feedforward, dropout, activation: Accepted for
                 interface compatibility but NOT used; the transformers are
                 fixed at 64 heads, 7 layers, feed-forward 1024, dropout 0.1,
                 GELU. They stay ignored so existing checkpoints keep loading.
        seq_len: Number of frames per sequence (at least 8). All layer sizes
                 that depend on the sequence length are derived from it.
        verbose: Print intermediate tensor shapes when True.
    """

    def __init__(self, d_model, nhead, num_layers, dim_feedforward=2048, dropout=0.1, activation='relu', seq_len=120, verbose=False):
        super(CascadingTransformerAutoEncoder, self).__init__()
        frame_kernel = 8   # frames covered by one conv window
        frame_stride = 6   # frames between consecutive conv windows
        if seq_len < frame_kernel:
            raise ValueError(f"seq_len must be at least {frame_kernel}, got {seq_len}")

        self.latent_size = 1   # 1 for single timestep
        self.latent_dim = d_model # 256
        self.seq_len = seq_len
        self.verbose = verbose
        self.conv1_out_channels = 32
        self.activation = nn.LeakyReLU()

        # Sizes that used to be hard-coded for seq_len=420 (69, 2208, 4, 273).
        self.conv_frames = (seq_len - frame_kernel) // frame_stride + 1
        # Frames the strided conv drops at the end; the transposed conv adds
        # them back through output_padding so decoding returns exactly seq_len.
        leftover_frames = (seq_len - frame_kernel) % frame_stride
        flat_dim = self.conv1_out_channels * self.conv_frames
        dec_channels = 7
        dec_kernel = 8
        # Width produced by the transposed conv along the last axis (stride 1).
        dec_width = self.conv1_out_channels - 1 + dec_kernel

        # ENCODER
        self.skel_enc = nn.Linear(66, d_model)

        self.skip_trans_enc = SkipTransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(
                d_model=d_model, nhead=64, dim_feedforward=1024,
                dropout=0.1, activation='gelu',
                norm_first=False, batch_first=True),
            num_layers=7,
            norm=nn.LayerNorm(d_model),
            d_model=d_model
        )

        # Kernel spans the whole feature axis, so the output width is 1.
        self.conv2d_enc = nn.Conv2d(
                            in_channels=1,
                            out_channels=self.conv1_out_channels,
                            kernel_size=(frame_kernel, d_model),
                            stride=(frame_stride, 1),
                            padding=(0, 0))

        # Final layer emits mu and logvar side by side (2 * latent_dim).
        self.enc_final_linear = nn.Sequential(
            nn.Linear(flat_dim, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 2 * self.latent_dim),
        )
        # DECODER
        self.linear_dec = nn.Sequential(
            nn.Linear(self.latent_dim, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 1024),
            nn.LeakyReLU(),
            nn.Linear(1024, flat_dim),
            nn.LeakyReLU(),
        )

        self.transconv2d_dec = nn.ConvTranspose2d(
                            in_channels=1,
                            out_channels=dec_channels,
                            kernel_size=dec_kernel,
                            stride=(frame_stride, 1),
                            padding=(0, 0),
                            output_padding=(leftover_frames, 0))

        self.linear_dec2 = nn.Sequential(
            nn.Linear(dec_channels * dec_width, d_model),
            nn.LeakyReLU(),
            nn.Linear(d_model, d_model),
            nn.LeakyReLU(),
            nn.Linear(d_model, d_model),
        )

        self.skip_trans_dec2 = SkipTransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(
                d_model=d_model, nhead=64, dim_feedforward=1024,
                dropout=0.1, activation='gelu',
                norm_first=False, batch_first=True),
            num_layers=7,
            norm=nn.LayerNorm(d_model),
            d_model=d_model
        )

        self.final_layer = nn.Linear(d_model, 66)

    def forward(self, src: Tensor, lengths: Optional[List[int]] = None):
        """
        Encode ``src``, sample a latent and decode it.

        Parameters:
            src: Tensor of shape (batch, frames, 66).
            lengths: Optional valid length of every sequence. When omitted,
                     every sequence is treated as full length and nothing is
                     masked in the output.

        Returns:
            (reconstruction, mu, logvar)
        """
        z, lengths, mu, logvar = self.encode(src, lengths)
        output = self.decode(z, lengths)
        return output, mu, logvar

    def encode(self, src, lengths: Optional[List[int]] = None):
        """
        Map ``src`` (batch, frames, 66) to a sampled latent.

        The number of frames has to be compatible with ``seq_len`` because the
        flattened conv output feeds a linear layer of fixed size.

        Returns:
            (latent, lengths, mu, logvar); ``lengths`` is the argument as
            given, or ``[frames] * batch`` when it was omitted.
        """
        if self.verbose: print('ENCODING')
        # get shapes
        bs, nframes, nfeats = src.shape
        if self.verbose: print('batch size:', bs, 'nframes:', nframes, 'nfeats:', nfeats)

        # A padded batch tensor carries no per-sequence length information, so
        # without explicit lengths every sequence counts as full length.
        if lengths is None:
            lengths = [nframes] * bs

        # skeletal embedding
        x = self.skel_enc(src)
        if self.verbose: print('skel_enc:', x.shape)

        # pass through transformerencoder with skip connections
        x = self.skip_trans_enc(x)
        if self.verbose: print('skip trans enc:', x.shape)

        # make small with conv2d (treat the sequence as a 1-channel image)
        x = x.unsqueeze(1)
        x = self.conv2d_enc(x)
        x = self.activation(x)
        if self.verbose: print('conv2d:', x.shape)

        x = torch.flatten(x, start_dim=1)
        if self.verbose: print('flattened:', x.shape)

        # map linear
        x = self.enc_final_linear(x)
        if self.verbose: print('final linear:', x.shape)

        mu, logvar = x[:, :self.latent_dim], x[:, self.latent_dim:]
        # resample; exp(0.5 * logvar) overflows later than exp(logvar) ** 0.5
        std = torch.exp(0.5 * logvar)
        latent = torch.distributions.Normal(mu, std).rsample()

        if self.verbose:
            print('latent:', latent.shape)
            latentdim = torch.prod(torch.tensor(latent.shape[1:]))
            print('latentdim:', latentdim)
        return latent, lengths, mu, logvar

    def decode(self, z: Tensor, lengths: List[int]):
        """
        Reconstruct sequences from latents.

        Parameters:
            z: Latent tensor of shape (batch, latent_dim).
            lengths: Valid length of each sequence; frames at or beyond a
                     sequence's length are zeroed in the result.

        Returns:
            Tensor of shape (batch, seq_len, 66).
        """
        if self.verbose: print('DECODING')
        mask = lengths_to_mask(lengths, z.device, self.seq_len)
        bs, nframes = mask.shape
        if self.verbose: print('batch size:', bs, 'nframes:', nframes, 'z shape:', z.shape)

        # map linear
        z = self.linear_dec(z)
        if self.verbose: print('linear:', z.shape)
        z = z.view(bs, 1, self.conv_frames, self.conv1_out_channels)

        if self.verbose: print('linear:', z.shape)
        # expand with convtranspose2d
        z = self.transconv2d_dec(z)
        z = self.activation(z)
        if self.verbose: print('transconv1d:', z.shape)

        # (batch, channels, frames, width) -> (batch, frames, channels * width)
        z = z.permute(0, 2, 1, 3).flatten(start_dim=2)
        z = self.linear_dec2(z)
        z = self.activation(z)
        if self.verbose: print('linear:', z.shape)

        # apply transformer
        z = self.skip_trans_dec2(z)
        if self.verbose: print('skip trans dec2:', z.shape)

        # final layer
        output = self.final_layer(z)
        if self.verbose: print('final layer:', output.shape)
        # zero padded frames without mutating the autograd-tracked tensor in place
        output = output.masked_fill(~mask.unsqueeze(-1), 0.0)
        feats = output
        if self.verbose: print('feats:', feats.shape)

        return feats

# Smoke test: push one random batch of 16 sequences (420 frames x 66 features)
# through the autoencoder on CPU with shape printing enabled.
device = 'cpu'
sample = torch.randn(16, 420, 66).to(device)

model = CascadingTransformerAutoEncoder(
    d_model=256,
    nhead=8,
    num_layers=5,
    dim_feedforward=2048,
    dropout=0.1,
    activation='relu',
    seq_len=420,
    verbose=True,
)
model = model.to(device)
output, mu, logvar = model(sample)

ENCODING
batch size: 16 nframes: 420 nfeats: 66
skel_enc: torch.Size([16, 420, 256])
skip trans enc: torch.Size([16, 420, 256])
conv2d: torch.Size([16, 32, 69, 1])
flattened: torch.Size([16, 2208])
final linear: torch.Size([16, 512])
latent: torch.Size([16, 256])
latentdim: tensor(256)
DECODING
batch size: 16 nframes: 420 z shape: torch.Size([16, 256])
linear: torch.Size([16, 2208])
linear: torch.Size([16, 1, 69, 32])
transconv1d: torch.Size([16, 7, 420, 39])
linear: torch.Size([16, 420, 256])
skip trans dec2: torch.Size([16, 420, 256])
final layer: torch.Size([16, 420, 66])
feats: torch.Size([16, 420, 66])


In [94]:
mask

tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        ...,
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]])

In [19]:
# Two toy sequences padded to 7 steps: the first has 4 valid steps, the second 7.
v1_pred, v1_true = [.9, 2.1, 3.1, 4.1, 0, 0, 0], [1, 2, 3, 4, 0, 0, 0]
v2_pred, v2_true = [.9, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1], [1, 2, 3, 4, 5, 6, 7]

# number of valid steps per sequence
lengths = [4, 7]

# summed squared error, then normalised by each sequence's own length
criterion = nn.MSELoss(reduction='sum')
v1_loss = criterion(torch.tensor(v1_pred), torch.tensor(v1_true)) / lengths[0]
v2_loss = criterion(torch.tensor(v2_pred), torch.tensor(v2_true)) / lengths[1]

print('v1 loss:', v1_loss.item())
print('v2 loss:', v2_loss.item())

# batch loss is the plain average of the per-sequence losses
total_loss = (v1_loss + v2_loss) / len(lengths)
print('total loss:', total_loss.item())

v1 loss: 0.009999986737966537
v2 loss: 0.009999984875321388
total loss: 0.009999985806643963


In [24]:
import torch
import torch.nn as nn

class BatchSequenceMSELoss:
    """
    Mean squared error over padded sequences, normalised per sequence.

    Inputs are expected to be 2-D, (batch, steps). Steps at or beyond a
    sequence's length do not contribute to its loss.
    """

    def __init__(self):
        self.criterion = nn.MSELoss(reduction='none')  # Compute loss per element

    def compute_loss(self, preds, trues, lengths):
        """
        Return one loss value per sequence.

        Parameters:
            preds: Predictions, list or tensor of shape (batch, steps).
            trues: Targets with the same shape as ``preds``.
            lengths: Number of valid steps per sequence (list or 1-D tensor).

        Returns:
            Tensor of shape (batch,): squared error summed over the valid steps
            divided by the sequence length (a zero length yields 0, not NaN).
        """
        # as_tensor keeps an existing tensor attached to the autograd graph;
        # torch.tensor() would copy AND detach it, so no gradient would flow.
        preds = torch.as_tensor(preds, dtype=torch.float32)
        trues = torch.as_tensor(trues, dtype=torch.float32, device=preds.device)
        lengths = torch.as_tensor(lengths, device=preds.device)

        # Compute MSE loss per element
        losses = self.criterion(preds, trues)

        # Mask out losses beyond each sequence's length
        steps = torch.arange(losses.size(1), device=losses.device)
        mask = steps.unsqueeze(0) < lengths.unsqueeze(1)
        masked_losses = losses * mask

        # Sum and normalize by the effective sequence length; clamp avoids 0/0
        denom = lengths.to(torch.float32).clamp(min=1)
        loss_per_sequence = masked_losses.sum(dim=1) / denom
        return loss_per_sequence

    def total_loss(self, predictions, targets, lengths):
        """Return the mean of the per-sequence losses over the batch."""
        loss_per_sequence = self.compute_loss(predictions, targets, lengths)
        return loss_per_sequence.mean()


import torch
from torch import nn
from torch.nn import functional as F

class VAE_Loss(nn.Module):
    """Weighted sum of named loss components (L2, L1, KL, BCE) for a VAE."""

    def __init__(self, loss_weights):
        """
        Initialize the loss module with the weight of every loss component.

        Parameters:
            loss_weights (dict): Maps a component key of the form 'name_method'
                                 to its weight. Components whose weight is 0,
                                 None or missing are skipped.
        """
        super(VAE_Loss, self).__init__()
        self.loss_weights = loss_weights

    def forward(self, loss_data, lengths=None):
        """
        Calculate the weighted total loss from the provided loss data.

        Parameters:
            loss_data (dict): Maps 'name_method' keys to the data of that
                              component: {'rec', 'true'} for L2 / L1 / BCE and
                              {'mu', 'logvar'} for KL. The method is the text
                              after the last underscore of the key.
            lengths: Accepted for interface compatibility; not used by this
                     implementation.

        Returns:
            total loss: Sum of all weighted components (the float 0.0 when
                        every component was skipped).
            dict: Weighted loss per component.
            dict: Unweighted loss per component.

        Raises:
            ValueError: If a key has no '_method' suffix or the method is unknown.
        """
        total_loss = 0.0
        losses_unscaled = {}
        losses_scaled = {}
        for k, v in loss_data.items():
            # Split on the LAST underscore so component names may themselves
            # contain underscores (e.g. 'motion_rec_L2').
            name, sep, method = k.rpartition('_')
            if not sep:
                raise ValueError(f"Loss component '{k}' must be named '<name>_<method>'.")

            weight = self.loss_weights.get(k)
            if weight is None or weight == 0:
                continue

            if method == 'L2':
                losses_unscaled[k] = F.mse_loss(v['rec'], v['true'], reduction='mean')
            elif method == 'L1':
                losses_unscaled[k] = F.l1_loss(v['rec'], v['true'], reduction='mean')
            elif method == 'KL':
                losses_unscaled[k] = self.kl_divergence(v['mu'], v['logvar'])
            elif method == 'BCE':
                losses_unscaled[k] = F.binary_cross_entropy(v['rec'], v['true'], reduction='mean')
            else:
                raise ValueError(f"Invalid loss method '{method}' provided for loss component '{k}'.")

            losses_scaled[k] = weight * losses_unscaled[k]
            total_loss += losses_scaled[k]

        return total_loss, losses_scaled, losses_unscaled

    def kl_divergence(self, mu, logvar):
        """
        Calculate the KL divergence between N(mu, exp(logvar)) and N(0, 1).

        Parameters:
            mu (Tensor): The mean vector of the latent space distribution.
            logvar (Tensor): The log variance vector of the latent space distribution.

        Returns:
            Tensor: The KL divergence summed over ALL elements (it is not
                    averaged over the batch).
        """
        return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())



# Example usage: score two padded toy sequences with the masked per-sequence MSE.

# Effective (unpadded) lengths of the sequences
lengths = [4, 7]

# Predictions and ground truths, both padded with zeros to 7 steps
predictions = [
    [0.9, 2.1, 3.1, 4.1, 0, 0, 0],
    [0.9, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
]
targets = [
    [1, 2, 3, 4, 0, 0, 0],
    [1, 2, 3, 4, 5, 6, 7],
]

# Calculate total loss (mean of the per-sequence losses)
loss_calculator = BatchSequenceMSELoss()
total_loss = loss_calculator.total_loss(predictions, targets, lengths)
print('Total loss:', total_loss.item())


Total loss: 0.009999985806643963


In [26]:
import torch
from torch import nn
from torch.nn import functional as F

class BatchSequenceMSELoss(nn.Module):
    """
    Length-normalised squared error over padded 2-D (batch, steps) sequences.

    Each sequence contributes its squared error summed over the valid steps and
    divided by its length; ``forward`` returns the SUM of these values over the
    batch.
    """

    def __init__(self):
        super().__init__()
        self.criterion = nn.MSELoss(reduction='none')  # Compute loss per element

    def forward(self, preds, trues, lengths):
        """
        Parameters:
            preds: Predictions, list or tensor of shape (batch, steps).
            trues: Targets with the same shape as ``preds``.
            lengths: Number of valid steps per sequence (list or 1-D tensor).

        Returns:
            Scalar tensor: sum over the batch of the per-sequence mean squared
            error (a zero-length sequence contributes 0, not NaN).
        """
        # as_tensor keeps an existing tensor attached to the autograd graph;
        # torch.tensor() would copy AND detach it, so no gradient would flow.
        preds = torch.as_tensor(preds, dtype=torch.float32)
        trues = torch.as_tensor(trues, dtype=torch.float32, device=preds.device)
        lengths = torch.as_tensor(lengths, device=preds.device)

        losses = self.criterion(preds, trues)
        # True for steps that lie inside each sequence's valid length
        steps = torch.arange(losses.size(1), device=losses.device)
        mask = steps.unsqueeze(0) < lengths.unsqueeze(1)
        masked_losses = losses * mask

        # clamp avoids 0/0 for empty sequences
        denom = lengths.to(torch.float32).clamp(min=1)
        loss_per_sequence = masked_losses.sum(dim=1) / denom
        return loss_per_sequence.sum()

class VAE_Loss(nn.Module):
    """
    Weighted sum of named loss components (L2, L1, KL, BCE) for a VAE, with an
    optional length-aware L2 term for padded sequences.
    """

    def __init__(self, loss_weights):
        """
        Parameters:
            loss_weights (dict): Maps a component key of the form 'name_method'
                                 to its weight. Components whose weight is 0,
                                 None or missing are skipped.
        """
        super(VAE_Loss, self).__init__()
        self.loss_weights = loss_weights
        self.sequence_mse_loss = BatchSequenceMSELoss()  # integrate the sequence MSE loss here

    def forward(self, loss_data, lengths=None):
        """
        Calculate the weighted total loss.

        Parameters:
            loss_data (dict): Maps 'name_method' keys to the data of that
                              component: {'rec', 'true'} for L2 / L1 / BCE and
                              {'mu', 'logvar'} for KL. The method is the text
                              after the last underscore of the key.
            lengths: Optional valid length per sequence (list or tensor). When
                     given, L2 components use the masked sequence loss.

        Returns:
            total loss: Sum of all weighted components (the float 0.0 when
                        every component was skipped).
            dict: Weighted loss per component.
            dict: Unweighted loss per component.

        Raises:
            ValueError: If a key has no '_method' suffix or the method is unknown.
        """
        total_loss = 0.0
        losses_unscaled = {}
        losses_scaled = {}
        for k, v in loss_data.items():
            # Split on the LAST underscore so component names may themselves
            # contain underscores (e.g. 'motion_rec_L2').
            name, sep, method = k.rpartition('_')
            if not sep:
                raise ValueError(f"Loss component '{k}' must be named '<name>_<method>'.")

            weight = self.loss_weights.get(k, 0)
            # An explicit None weight would otherwise fail at `weight * loss`.
            if weight is None or weight == 0:
                continue

            if method == 'L2':
                if lengths is not None:
                    # Lists have no .sum(); normalise to a tensor once.
                    lengths_t = torch.as_tensor(lengths)
                    l = self.sequence_mse_loss(v['rec'], v['true'], lengths_t)
                    # NOTE(review): this divides the sum of per-sequence means by
                    # the total number of valid steps; a batch mean would divide
                    # by the number of sequences instead - confirm the intent.
                    losses_unscaled[k] = l / lengths_t.sum()
                else:
                    rec, true = v['rec'], v['true']
                    # F.mse_loss needs tensors; plain lists are converted here.
                    if not isinstance(rec, torch.Tensor):
                        rec = torch.tensor(rec, dtype=torch.float32)
                    if not isinstance(true, torch.Tensor):
                        true = torch.tensor(true, dtype=torch.float32)
                    losses_unscaled[k] = F.mse_loss(rec, true, reduction='mean')

            elif method == 'L1':
                losses_unscaled[k] = F.l1_loss(v['rec'], v['true'], reduction='mean')
            elif method == 'KL':
                losses_unscaled[k] = self.kl_divergence(v['mu'], v['logvar'])
            elif method == 'BCE':
                losses_unscaled[k] = F.binary_cross_entropy(v['rec'], v['true'], reduction='mean')
            else:
                raise ValueError(f"Invalid loss method '{method}' provided for loss component '{k}'.")

            losses_scaled[k] = weight * losses_unscaled[k]
            total_loss += losses_scaled[k]

        return total_loss, losses_scaled, losses_unscaled

    def kl_divergence(self, mu, logvar):
        """
        Return the KL divergence between N(mu, exp(logvar)) and N(0, 1), summed
        over all elements (not averaged over the batch).
        """
        return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# Example of setting up the VAE_Loss and evaluating it on toy data.

# Valid lengths of the two padded sequences
lengths = torch.tensor([4, 7])

# Weight of each loss component (example values)
loss_weights = dict(reconstruction_L2=1.0, kl_KL=0.5)
vae_loss = VAE_Loss(loss_weights)

# Example data: a padded reconstruction pair and a standard-normal latent
loss_data = {}
loss_data['reconstruction_L2'] = {
    'rec': [[0.9, 2.1, 3.1, 4.1, 0, 0, 0], [0.9, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1]],
    'true': [[1, 2, 3, 4, 0, 0, 0], [1, 2, 3, 4, 5, 6, 7]],
}
loss_data['kl_KL'] = {'mu': torch.zeros(10, 32), 'logvar': torch.zeros(10, 32)}

# Calculate loss
total_loss, losses_scaled, losses_unscaled = vae_loss(loss_data, lengths)
print('Total Loss:', total_loss.item())


RuntimeError: Boolean value of Tensor with more than one value is ambiguous

torch.Size([32, 8, 60, 1])