In [None]:
!git clone https://github.com/guijiejie/DCMD-main.git

In [None]:
!ls /kaggle/working/

# **Pip**

In [None]:
!pip install appdirs==1.4.4 docker-pycreds==0.4.0 gitdb==4.0.10 gitpython==3.1.32 joblib==1.3.1 numpy==1.25.2 pathtools==0.1.2 protobuf==4.23.4 scikit-learn==1.3.0 scipy==1.11.1 sentry-sdk==1.29.2 setproctitle==1.3.2 smmap==5.0.0 threadpoolctl==3.2.0 wandb==0.15.8
!pip install --upgrade pytorch-lightning pyyaml wandb pandas numpy matplotlib matplotlib-inline scikit-learn tqdm 
!pip install --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# **Replace code to run properly**

In [42]:
!pip install networkx




In [103]:
%%writefile /kaggle/working/DCMD-main/config/Avenue/dcmd_train.yaml 
### Experiment configuration

## General settings
split: 'train' # data split; choices ['train', 'test']
debug: false # if true, load only a few data samples
seed: 999
validation: false # use validation; only for UBnormal
use_hr: false # for validation and test on UBnormal

## Computational resources
accelerator: 'gpu'
devices: [0] # indices of cuda devices to use

## Paths
dir_name: 'train_experiment' # name of the directory of the current experiment
data_dir: '/kaggle/input/avenue/Avenue' # path to the data
exp_dir: '/kaggle/working/DCMD-main/checkpoints' # path to the directory that will contain the current experiment directory
test_path: '/kaggle/input/avenue/Avenue/testing/test_frame_mask' # path to the test data
load_ckpt: 'best.ckpt' # name of the checkpoint to load at inference time
create_experiment_dir: true

## WANDB configuration
use_wandb: false
project_name: "project_name"
wandb_entity: "entity_name"
group_name: "group_name"
use_ema: false

##############################


### Model's configuration

## U-Net's configuration
dropout: 0. # probability of dropout
conditioning_strategy: 'inject'
## Rec configuration
h_dim: 512 # dimension of the bottleneck at the end of the encoder of the conditioning network
latent_dim: 256 # dimension of the latent space of the conditioning encoder
channels: [512,256,512] # channels for the encoder

# Cho forecasting
use_forecasting: true
hidden_dim_forecast: 64
dropout_forecast: 0.2
num_heads_gat: 4
latent_dim_forecast: 32
lambda_forecast: 0.1

# Trong config.yaml của bạn, thêm dòng tương tự:
use_adaptive: true
use_jigsaw: true
layer_channels: [128, 64, 128]
emb_dim: 64

##############################


### Training's configuration

## Diffusion's configuration
noise_steps: 10 # how many diffusion steps to perform

### Optimizer and scheduler's configuration
n_epochs: 10
opt_lr: 0.001

## Losses' configuration
loss_fn: 'smooth_l1' # loss function; choices ['mse', 'l1', 'smooth_l1']

##############################


### Inference's configuration
n_generated_samples: 50 # number of samples to generate
model_return_value: 'loss' # choices ['loss', 'poses', 'all']; if 'loss', the model will return the loss;
                           # if 'poses', the model will return the generated poses; 
                           # if 'all', the model will return both the loss and the generated poses
aggregation_strategy: 'best' # choices ['best', 'mean', 'median', 'random']; if 'best', the best sample will be selected; 
                             # if 'mean', the mean of loss of the samples will be selected; 
                             # if 'median', the median of the loss of the samples will be selected; 
                             # if 'random', a random sample will be selected;
                             # if 'mean_poses', the mean of the generated poses will be selected;
                             # if 'median_poses', the median of the generated poses will be selected;
                             # if 'all', all the generated poses will be selected
filter_kernel_size: 30 # size of the kernel to use for smoothing the anomaly score of each clip
frames_shift: 6 # it compensates the shift of the anomaly score due to the sliding window; 
                # in conjuction with pad_size and filter_kernel_size, it strongly depends on the dataset
save_tensors: false # if true, save the generated tensors for faster inference
load_tensors: false # if true, load the generated tensors for faster inference

##############################


### Dataset's configuration

## Important parameters
dataset_choice: 'HR-Avenue'
seg_len: 7 # length of the window (his+pre)
vid_res: [640,360]
batch_size: 2048
pad_size: 12 # size of the padding 

## Other parameters
headless: false # remove the keypoints of the head
hip_center: false # center the keypoints on the hip
kp18_format: false # use the 18 keypoints format
normalization_strategy: 'robust' # use 'none' to avoid normalization, 'robust' otherwise
num_coords: 2 # number of coordinates to use
num_transform: 5 # number of transformations to apply
num_workers: 4
seg_stride: 1
seg_th: 0
start_offset: 0
symm_range: true
use_fitted_scaler: false

## New configuration
n_his: 3
padding: 'LastFrame'
## translinear configuration
num_layers: 6
num_heads: 8
latent_dims: 512
loss_1_series_weight: 0.01
loss_1_prior_weight: 0
loss_2_series_weight: 0
loss_2_prior_weight: 0.01

Overwriting /kaggle/working/DCMD-main/config/Avenue/dcmd_train.yaml


In [44]:
%%writefile /kaggle/working/DCMD-main/models/transformer.py
import torch
import torch.nn.functional as F
from torch import layer_norm, nn
import numpy as np
from typing import List, Tuple, Union

import math


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


def set_requires_grad(nets, requires_grad=False):
    """Set requies_grad for all the networks.

    Args:
        nets (nn.Module | list[nn.Module]): A list of networks or a single
            network.
        requires_grad (bool): Whether the networks require gradients or not
    """
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


class StylizationBlock(nn.Module):

    def __init__(self, latent_dim, time_embed_dim, dropout):
        super().__init__()
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_embed_dim, 2 * latent_dim),
        )
        self.norm = nn.LayerNorm(latent_dim)
        self.out_layers = nn.Sequential(
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(nn.Linear(latent_dim, latent_dim)),
        )

    def forward(self, h, emb):
        """
        h: B, T, D
        emb: B, D
        """
        # B, 1, 2D
        emb_out = self.emb_layers(emb).unsqueeze(1)
        # scale: B, 1, D / shift: B, 1, D
        scale, shift = torch.chunk(emb_out, 2, dim=2)
        # B, T, D
        h = self.norm(h) * (1 + scale) + shift
        h = self.out_layers(h)
        return h


class FFN(nn.Module):

    def __init__(self, latent_dim, ffn_dim, dropout, time_embed_dim):
        super().__init__()
        self.linear1 = nn.Linear(latent_dim, ffn_dim)
        self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim))
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)

    def forward(self, x, emb):
        """
            x: B, T, D (D=latent_dim)
        """
        y = self.linear2(self.dropout(self.activation(self.linear1(x))))
        y = x + self.proj_out(y, emb)
        return y


class TemporalSelfAttention(nn.Module):

    def __init__(self, n_frames, latent_dim, num_head, dropout, time_embed_dim, output_attention = True):
        super().__init__()
        self.num_head = num_head
        self.output_attention = output_attention
        self.norm = nn.LayerNorm(latent_dim)
        self.query = nn.Linear(latent_dim, latent_dim, bias=False)
        self.key = nn.Linear(latent_dim, latent_dim, bias=False)
        self.value = nn.Linear(latent_dim, latent_dim, bias=False)
        self.sigma_projection = nn.Linear(latent_dim, num_head, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)
        n_frames = n_frames
        self.distances = torch.zeros((n_frames, n_frames)).cuda(0)

        for i in range(n_frames):
            for j in range(n_frames):
                self.distances[i][j] = abs(i - j)

    def forward(self, x, emb):
        """
        x: B, T, D (D=latent_dim)
        """
        B, T, D = x.shape
        H = self.num_head

        ## series-association
        # B, T, 1, D
        query = self.query(self.norm(x)).unsqueeze(2)
        # B, 1, T, D
        key = self.key(self.norm(x)).unsqueeze(1)
        # B, T, H, D/H
        query = query.view(B, T, H, -1)
        key = key.view(B, T, H, -1)
        scale = 1. / math.sqrt(D/H)
        # B, H, T, T
        scores = torch.einsum('bnhd,bmhd->bhnm', query, key) / math.sqrt(D // H)
        attention = scale * scores
        # B, H, T, T
        series = self.dropout(F.softmax(attention, dim=-1))

        ## prior-association
        sigma = self.sigma_projection(x).view(B, T, H)  # B, T, H
        sigma = sigma.transpose(1, 2)  # B T H ->  B H T
        sigma = torch.sigmoid(sigma * 5) + 1e-5
        sigma = torch.pow(3, sigma) - 1
        sigma = sigma.unsqueeze(-1).repeat(1, 1, 1, T)  # B, H, T, T
        prior = self.distances.unsqueeze(0).unsqueeze(0).repeat(sigma.shape[0], sigma.shape[1], 1, 1) # B, H, T, T
        prior = 1.0 / (math.sqrt(2 * math.pi) * sigma) * torch.exp(-prior ** 2 / 2 / (sigma ** 2)).cuda(0) # B, H, T, T

        # B, T, H, D/H
        value = self.value(self.norm(x)).view(B, T, H, -1)
        # B, T, D
        y = torch.einsum('bhnm,bmhd->bnhd', series, value).reshape(B, T, D)
        y = x + self.proj_out(y, emb)

        if self.output_attention:
            return y.contiguous(), series, prior, sigma
        else:
            return y.contiguous(), None

class TemporalDiffusionTransformerDecoderLayer(nn.Module):

    def __init__(self,
                 n_frames = 7,
                 latent_dim=16,
                 time_embed_dim=16,
                 ffn_dim=32,
                 num_head=4,
                 dropout=0.5
                 ):
        super().__init__()
        self.sa_block = TemporalSelfAttention(
            n_frames, latent_dim, num_head, dropout, time_embed_dim)
        self.ffn = FFN(latent_dim, ffn_dim, dropout, time_embed_dim)

    def forward(self, x, emb):
        x, series, prior, sigma = self.sa_block(x, emb)
        x = self.ffn(x, emb)
        return x, series, prior, sigma


class MotionTransformer(nn.Module):
    def __init__(self,
                 input_feats,
                 num_frames=7,
                 latent_dim=16,
                 ff_size=32,
                 num_layers=8,
                 num_heads=8,
                 dropout=0.2,
                 activation="gelu",
                 output_attention = True,
                 device: Union[str, torch.DeviceObjType] = 'cpu',
                 inject_condition: bool = False,
                 **kargs):
        super().__init__()


        self.input_feats = input_feats # 34
        self.num_frames = num_frames
        self.latent_dim = latent_dim
        self.ff_size = ff_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.activation = activation
        self.output_attention = output_attention
        self.device = device
        self.time_embed_dim = latent_dim
        self.inject_condition = inject_condition

        self.build_model()

    def build_model(self):
        self.sequence_embedding = nn.Parameter(torch.randn(self.num_frames, self.latent_dim))

        # Input Embedding
        self.joint_embed = nn.Linear(self.input_feats, self.latent_dim)
        self.cond_embed = nn.Linear(256, self.time_embed_dim)

        self.time_embed = nn.Sequential(
            nn.Linear(self.latent_dim, self.time_embed_dim),
            nn.SiLU(),
            nn.Linear(self.time_embed_dim, self.time_embed_dim),
        )

        self.temporal_decoder_blocks = nn.ModuleList()
        for i in range(self.num_layers):
            self.temporal_decoder_blocks.append(
                TemporalDiffusionTransformerDecoderLayer(
                    n_frames=self.num_frames,
                    latent_dim=self.latent_dim,
                    time_embed_dim=self.time_embed_dim,
                    ffn_dim=self.ff_size,
                    num_head=self.num_heads,
                    dropout=self.dropout,
                )
            )
        # Output Module
        self.out = zero_module(nn.Linear(self.latent_dim, self.input_feats))


    def forward(self, x, timesteps, condition_data:torch.Tensor=None):
        """
        x: B, T, D (D=C*V)
        """
        B, T = x.shape[0], x.shape[1]

        # B, latent_dim
        emb = self.time_embed(timestep_embedding(timesteps, self.latent_dim))

        # Add conditioning signal
        if self.inject_condition:
            condition_data = self.cond_embed(condition_data)
            emb = emb + condition_data

        # B, T, latent_dim
        h = self.joint_embed(x)
        h = h + self.sequence_embedding.unsqueeze(0)[:, :T, :]
        
        i = 0
        prelist = []
        series_list = []
        prior_list = []
        sigma_list = []
        for module in self.temporal_decoder_blocks:
            if i < (self.num_layers // 2):
                prelist.append(h)
                h, series, prior, sigmas = module(h, emb) # B, T, latent_dim
                series_list.append(series)
                prior_list.append(prior)
                sigma_list.append(sigmas)
            elif i >= (self.num_layers // 2):
                h, series, prior, sigmas = module(h, emb)
                h += prelist[-1]
                series_list.append(series)
                prior_list.append(prior)
                sigma_list.append(sigmas)
                prelist.pop()
            i += 1

        # B, T, C*V
        output = self.out(h).view(B, T, -1).contiguous()
        if self.output_attention:
            return output, series_list, prior_list, sigma_list
        return output

Overwriting /kaggle/working/DCMD-main/models/transformer.py


In [83]:
%%writefile /kaggle/working/DCMD-main/train_DCMD.py
import argparse
import os
import random

import numpy as np
import pytorch_lightning as pl
import torch
import yaml
from models.dcmd import DCMD
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPStrategy
from utils.argparser import init_args
from utils.dataset import get_dataset_and_loader
from utils.ema import EMACallback


if __name__== '__main__':

    # Parse command line arguments and load config file
    parser = argparse.ArgumentParser(description='Pose_AD_Experiment')
    parser.add_argument('-c', '--config', type=str, required=True,
                        default='/your_default_config_file_path')
    
    args = parser.parse_args()
    config_path = args.config
    args = yaml.load(open(args.config), Loader=yaml.FullLoader)
    args = argparse.Namespace(**args)
    args = init_args(args) 
    # Save config file to ckpt_dir
    os.system(f'cp {config_path} {os.path.join(args.ckpt_dir, "config.yaml")}')     
    
    # Set seeds    
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed) 
    pl.seed_everything(args.seed)

    # Set callbacks and logger
    if (hasattr(args, 'diffusion_on_latent') and args.stage == 'pretrain'):
        monitored_metric = 'pretrain_rec_loss'
        metric_mode = 'min'
    elif args.validation:
        monitored_metric = 'AUC'
        metric_mode = 'max'
    else:
        monitored_metric = 'loss'
        metric_mode = 'min'
    callbacks = [ModelCheckpoint(
                    dirpath=args.ckpt_dir,
                    save_top_k=2,
                    save_last=True,
                    monitor=monitored_metric,
                    mode=metric_mode
                )]

    
    callbacks += [EMACallback()] if args.use_ema else [] # Use to achieve exponential moving average
    
    if args.use_wandb:
        callbacks += [LearningRateMonitor(logging_interval='step')]
        wandb_logger = WandbLogger(project=args.project_name, group=args.group_name, entity=args.wandb_entity, 
                                   name=args.dir_name, config=vars(args), log_model='all')
    else:
        wandb_logger = False

    parser.add_argument(
        "--channels",       # vẫn để channels nếu có những chỗ khác dựa vào as well,
        type=int,
        default=3,
        help="Số kênh đầu vào (vd: xyz = 3)."
    )
    # THÊM 2 dòng sau để tạo ra args.in_channels:
    parser.add_argument(
        "--in_channels",
        type=int,
        default=3,
        help="(Thêm) Số kênh đầu vào, để tương thích với những chỗ gọi args.in_channels"
    )
    
    # Get dataset and loaders
    _, train_loader, _, val_loader = get_dataset_and_loader(args, split=args.split, validation=args.validation)
    
    # Initialize model and trainer
    model = DCMD(args)
    
    trainer = pl.Trainer(accelerator=args.accelerator, devices=args.devices, default_root_dir=args.ckpt_dir, max_epochs=args.n_epochs, 
                         logger=wandb_logger, callbacks=callbacks, strategy=DDPStrategy(find_unused_parameters=False),
                         log_every_n_steps=20, num_sanity_val_steps=0, deterministic=True)
    
    # Train the model    
    trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)

Overwriting /kaggle/working/DCMD-main/train_DCMD.py


In [None]:
%%writefile /kaggle/working/DCMD-main/config/Avenue/dcmd_test.yaml
### Experiment configuration

## General settings
split: 'test' # data split; choices ['train', 'test']
debug: false # if true, load only a few data samples
seed: 999
validation: false # use validation; only for UBnormal
use_hr: false # for validation and test on UBnormal

## Computational resources
accelerator: 'gpu'
devices: [0] # indices of cuda devices to use

## Paths
dir_name: 'test_experiment' # name of the directory of the current experiment
data_dir: '/kaggle/input/avenue/Avenue' # path to the data
exp_dir: '/kaggle/working/DCMD-main/checkpoints' # path to the directory that will contain the current experiment directory
test_path: '/kaggle/input/avenue/Avenue/testing/test_frame_mask' # path to the test data
load_ckpt: 'last.ckpt' # name of the checkpoint to load at inference time
create_experiment_dir: false # if true, create a new directory for the current experiment

## WANDB configuration
use_wandb: false
project_name: "project_name"
wandb_entity: "entity_name"
group_name: "group_name"
use_ema: false

##############################


### Model's configuration

## U-Net's configuration
dropout: 0. # probability of dropout
conditioning_strategy: 'inject'
## Rec configuration
h_dim: 512 # dimension of the bottleneck at the end of the encoder of the conditioning network
latent_dim: 256 # dimension of the latent space of the conditioning encoder
channels: [512,256,512] # channels for the encoder

##############################


### Training's configuration

## Diffusion's configuration
noise_steps: 10 # how many diffusion steps to perform

### Optimizer and scheduler's configuration
n_epochs: 10
opt_lr: 0.001

## Losses' configuration
loss_fn: 'smooth_l1' # loss function; choices ['mse', 'l1', 'smooth_l1']

##############################


### Inference's configuration
n_generated_samples: 50 # number of samples to generate
model_return_value: 'loss' # choices ['loss', 'poses', 'all']; if 'loss', the model will return the loss;
                           # if 'poses', the model will return the generated poses; 
                           # if 'all', the model will return both the loss and the generated poses
aggregation_strategy: 'best' # choices ['best', 'mean', 'median', 'random']; if 'best', the best sample will be selected; 
                             # if 'mean', the mean of loss of the samples will be selected; 
                             # if 'median', the median of the loss of the samples will be selected; 
                             # if 'random', a random sample will be selected;
                             # if 'mean_poses', the mean of the generated poses will be selected;
                             # if 'median_poses', the median of the generated poses will be selected;
                             # if 'all', all the generated poses will be selected
filter_kernel_size: 30 # size of the kernel to use for smoothing the anomaly score of each clip
frames_shift: 6 # it compensates the shift of the anomaly score due to the sliding window; 
                # in conjuction with pad_size and filter_kernel_size, it strongly depends on the dataset
save_tensors: true # if true, save the generated tensors for faster inference
load_tensors: false # if true, load the generated tensors for faster inference

##############################


### Dataset's configuration

## Important parameters
dataset_choice: 'HR-Avenue'
seg_len: 7 # length of the window (cond+noised)
vid_res: [640,360]
batch_size: 2048
pad_size: 12 # size of the padding

## Other parameters
headless: false # remove the keypoints of the head
hip_center: false # center the keypoints on the hip
kp18_format: false # use the 18 keypoints format
normalization_strategy: 'robust' # use 'none' to avoid normalization, 'robust' otherwise
num_coords: 2 # number of coordinates to use
num_transform: 5 # number of transformations to apply
num_workers: 4
seg_stride: 1
seg_th: 0
start_offset: 0
symm_range: true
use_fitted_scaler: false

## New configuration
n_his: 3
padding: 'LastFrame'
## translinear configuration
num_layers: 6
num_heads: 8
latent_dims: 512
loss_1_series_weight: 0.01
loss_1_prior_weight: 0
loss_2_series_weight: 0
loss_2_prior_weight: 0.01

In [None]:
%%writefile /kaggle/working/DCMD-main/models/dcmd.py
import argparse
import os
from math import prod
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt

from models.stsae.stsae import STSAE
from sklearn.metrics import roc_curve, roc_auc_score
from torch.optim import Adam
from tqdm import tqdm

from utils.diffusion_utils import Diffusion
from utils.eval_utils import (compute_var_matrix, filter_vectors_by_cond,
                              get_avenue_mask, get_hr_ubnormal_mask, pad_scores, score_process)
from utils.model_utils import processing_data, my_kl_loss
from models.transformer import MotionTransformer
from utils.tools import get_dct_matrix, generate_pad, padding_traj


class DCMD(pl.LightningModule):

    losses = {'l1': nn.L1Loss, 'smooth_l1': nn.SmoothL1Loss, 'mse': nn.MSELoss}
    conditioning_strategies = {'inject': 'inject'}

    def __init__(self, args: argparse.Namespace) -> None:
        """
        This class implements DCMD model.

        Args:
            args (argparse.Namespace): arguments containing the hyperparameters of the model
        """

        super(DCMD, self).__init__()

        ## Log the hyperparameters of the model
        self.save_hyperparameters(args)

        ## Set the internal variables of the model
        # Data parameters
        self.n_frames = args.seg_len
        self.num_coords = args.num_coords
        self.n_joints = self._infer_number_of_joint(args)

        ## Model parameters
        # Main network
        self.dropout = args.dropout
        self.conditioning_strategy = self.conditioning_strategies[args.conditioning_strategy]
        # Conditioning network
        self.cond_h_dim = args.h_dim
        self.cond_latent_dim = args.latent_dim
        self.cond_channels = args.channels
        self.cond_dropout = args.dropout

        ## Training and inference parameters
        self.learning_rate = args.opt_lr
        self.loss_fn = self.losses[args.loss_fn](reduction='none')
        self.noise_steps = args.noise_steps
        self.aggregation_strategy = args.aggregation_strategy
        self.n_generated_samples = args.n_generated_samples
        self.model_return_value = args.model_return_value
        self.gt_path = args.gt_path
        self.split = args.split
        self.use_hr = args.use_hr
        self.ckpt_dir = args.ckpt_dir
        self.save_tensors = args.save_tensors
        self.num_transforms = args.num_transform
        self.anomaly_score_pad_size = args.pad_size
        self.anomaly_score_filter_kernel_size = args.filter_kernel_size
        self.anomaly_score_frames_shift = args.frames_shift
        self.dataset_name = args.dataset_choice

        # New parameters
        self.n_his = args.n_his
        self.padding = args.padding
        self.num_layers = args.num_layers
        self.num_heads = args.num_heads
        self.latent_dims = args.latent_dims
        self.automatic_optimization = False
        self.loss_1_series_weight = args.loss_1_series_weight
        self.loss_1_prior_weight = args.loss_1_prior_weight
        self.loss_2_series_weight = args.loss_2_series_weight
        self.loss_2_prior_weight = args.loss_2_prior_weight
        self.idx_pad, self.zero_index = generate_pad(self.padding, self.n_his, self.n_frames-self.n_his)

        ## Set the noise scheduler for the diffusion process
        self._set_diffusion_variables()

        ## Build the model
        self.build_model()


    def build_model(self) -> None:
        """
        Build the model according to the specified hyperparameters.
        """

        # Prediction Model
        pre_model = MotionTransformer(
            input_feats=2 * self.n_joints,
            num_frames=self.n_frames,
            num_layers=self.num_layers,
            num_heads=self.num_heads,
            latent_dim=self.latent_dims,
            dropout=self.dropout,
            device=self.device,
            inject_condition=(self.conditioning_strategy == 'inject')
        )

        # Reconstruction Model
        rec_model = STSAE(
            c_in=self.num_coords,
            h_dim=self.cond_h_dim,
            latent_dim=self.cond_latent_dim,
            n_frames=self.n_his,
            dropout=self.cond_dropout,
            n_joints=self.n_joints,
            layer_channels=self.cond_channels,
            device=self.device)

        self.pre_model, self.rec_model = pre_model, rec_model


    def forward(self, input_data: List[torch.Tensor], aggr_strategy: str = None, return_: str = None) -> List[torch.Tensor]:
        """
        Forward pass of the model.
        """

        ## Unpack data: tensor_data is the input data, meta_out is a list of metadata
        tensor_data, meta_out = self._unpack_data(input_data)
        B = tensor_data.shape[0]

        ## Select frames to reconstruct and to predict
        history_data = tensor_data[:, :, :self.n_his, :]
        x_0 = padding_traj(history_data, self.padding, self.idx_pad, self.zero_index)

        generated_xs = []
        # Generate m future predictions
        for _ in range(self.n_generated_samples):

            ## Reconstruction —— AE model
            condition_embedding, rec_his_data = self.rec_model(history_data)

            ## Prediction —— diffusion model
            ## DCT transformation
            dct_m, idct_m = get_dct_matrix(self.n_frames)
            dct_m_all = dct_m.float().to(self.device)
            idct_m_all = idct_m.float().to(self.device)
            # (B, C, T, V) -> (B, T, V, C)
            x = x_0.permute(0, 2, 3, 1).contiguous()
            # (B, T, V, C) -> (B, T, C*V)
            x = x.reshape([x.shape[0], self.n_frames, -1])
            y = torch.matmul(dct_m_all, x)  # [B, T, C*V]

            ## Generate gaussian noise of the same shape as the y
            y_d = torch.randn_like(y, device=self.device)

            ## (t ∈ T, T-1, ..., 1)
            for i in reversed(range(1, self.noise_steps)):

                ### Prediction (Two branches)
                ## Set the time step
                t = torch.full(size=(B,), fill_value=i, dtype=torch.long, device=self.device)
                t_prev = torch.full(size=(B,), fill_value=i, dtype=torch.long, device=self.device)
                t_prev[0] = 0

                ## Generate gaussian noise of the same shape as the predicted noise
                noise_pre = torch.randn_like(y_d, device=self.device) if i > 1 else torch.zeros_like(y_d, device=self.device)

                ## First branch
                # Predict the noise
                predicted_noise_pre, series, prior, _ = self.pre_model(y_d, t, condition_data=condition_embedding)
                # Get the alpha and beta values and expand them to the shape of the predicted noise
                alpha_pre = self._alpha[t][:, None, None]
                alpha_hat_pre = self._alpha_hat[t][:, None, None]
                beta_pre = self._beta[t][:, None, None]
                # Recover the predicted sequence
                y_d = (1 / torch.sqrt(alpha_pre)) * (y_d - ((1 - alpha_pre) / (torch.sqrt(1 - alpha_hat_pre))) * predicted_noise_pre) \
                    + torch.sqrt(beta_pre) * noise_pre
                ## Second branch
                alpha_hat_prev = self._alpha_hat[t_prev][:, None, None]
                # Add noise
                y_n = (torch.sqrt(alpha_hat_prev) * y) + (torch.sqrt(1 - alpha_hat_prev) * noise_pre)
                ## Mask completion
                # Get M values
                mask = torch.zeros_like(x, device=self.device) # [batch, T, C*V]
                for m in range(0, self.n_his):
                    mask[:, m, :] = 1
                # iDCT transformation
                y_d_idct = torch.matmul(idct_m_all, y_d)
                y_n_idct = torch.matmul(idct_m_all, y_n)
                # mask-mul
                m_mul_y_n = torch.mul(mask, y_n_idct)
                m_mul_y_d = torch.mul((1-mask), y_d_idct)
                # together
                m_y = m_mul_y_d + m_mul_y_n
                # DCT again
                y_d = torch.matmul(dct_m_all, m_y)

            # iDCT
            pre_future_data = torch.matmul(idct_m_all, y_d)
            # (B, T, C*V) -> (B, T, V, C)
            pre_future_data = pre_future_data.reshape(pre_future_data.shape[0], pre_future_data.shape[1], -1, 2)
            # (B, T, V, C) -> (B, C, T, V)
            pre_future_data = pre_future_data.permute(0, 3, 1, 2).contiguous()
            # select future sequences
            pre_future_data = pre_future_data[:,:,self.n_his:,:]

            ## Reconstruction + Prediction
            xs = torch.cat((rec_his_data, pre_future_data), dim=2)

            generated_xs.append(xs)

        selected_x, loss_of_selected_x = self._aggregation_strategy(generated_xs, tensor_data, aggr_strategy)

        return self._pack_out_data(selected_x, loss_of_selected_x, [tensor_data] + meta_out, return_=return_)


    def training_step(self, batch: List[torch.Tensor], batch_idx: int) -> torch.float32:
        """
        Training step of the model.
        """

        ## Get the optimizer returned in configuration_optimizers()
        opt = self.optimizers()

        ## Unpack data: tensor_data is the input data
        tensor_data, _ = self._unpack_data(batch)

        ## Select frames to reconstruct and to predict
        history_data = tensor_data[:, :, :self.n_his, :] # Used for rec (first n_his)
        x_0 = tensor_data # Used for pre（all）

        ## Reconstruction
        # Encode the history data
        condition_embedding, rec_his_data = self.rec_model(history_data)
        # Compute the rec_loss
        rec_loss = torch.mean(self.loss_fn(rec_his_data, history_data))
        self.log('rec_loss', rec_loss)

        ## Prediction
        # DCT transformation
        dct_m, _ = get_dct_matrix(self.n_frames)
        dct_m_all = dct_m.float().to(self.device)
        # (B, C, T, V) -> (B, T, V, C)
        x = x_0.permute(0, 2, 3, 1).contiguous()
        # (B, T, V, C) -> (B, T, C*V)
        x = x.reshape([x.shape[0], self.n_frames, -1]) # [batch, T, C*V]
        y_0 = torch.matmul(dct_m_all, x)

        # Sample the time steps and corrupt the data
        t = self.noise_scheduler.sample_timesteps(y_0.shape[0]).to(self.device)
        y_t, pre_noise = self.noise_scheduler.noise_motion(y_0, t) # (B, T, C*(V-1))

        # Predict the noise
        pre_predicted_noise, series, prior, _ = self.pre_model(y_t, t, condition_data=condition_embedding)

        # Compute the pre_loss
        # Calculate Association discrepancy
        series_loss = 0.0
        prior_loss = 0.0
        for u in range(len(prior)):
            # Pdetach, S <-> Maximize
            series_loss += \
                (torch.mean(my_kl_loss(
                    series[u],
                    (prior[u] / torch.unsqueeze(torch.sum(prior[u], dim=-1), dim=-1).repeat(1, 1, 1, self.n_frames)).detach()))
                + torch.mean(my_kl_loss(
                    (prior[u] / torch.unsqueeze(torch.sum(prior[u], dim=-1), dim=-1).repeat(1, 1, 1, self.n_frames)).detach(),
                    series[u])))
            # P, Sdetach <-> Minimize
            prior_loss += \
                (torch.mean(my_kl_loss(
                    (prior[u] / torch.unsqueeze(torch.sum(prior[u], dim=-1), dim=-1).repeat(1, 1, 1, self.n_frames)),
                    series[u].detach()))
                + torch.mean(my_kl_loss(
                    series[u].detach(),
                    (prior[u] / torch.unsqueeze(torch.sum(prior[u], dim=-1), dim=-1).repeat(1, 1, 1, self.n_frames)))))
        series_loss = series_loss / len(prior)
        prior_loss = prior_loss / len(prior)

        pre_loss = torch.mean(self.loss_fn(pre_predicted_noise, pre_noise))
        self.log('pre_loss', pre_loss)

        ## Compute loss1 & loss2
        loss1 = rec_loss + pre_loss \
                - self.loss_1_series_weight * series_loss \
                + self.loss_1_prior_weight * prior_loss
        self.log('loss1', loss1)
        loss2 = rec_loss + pre_loss \
                + self.loss_2_prior_weight * prior_loss \
                + self.loss_2_series_weight * series_loss
        self.log('loss2', loss2)

        ## Minimax strategy
        self.manual_backward(loss1, retain_graph=True)
        self.manual_backward(loss2)
        opt.step()
        opt.zero_grad()


    def test_step(self, batch: List[torch.Tensor], batch_idx: int) -> None:
        """
        Test step of the model. It saves the output of the model and the input data as
        List[torch.Tensor]: [predicted poses and the loss, tensor_data, transformation_idx, metadata, actual_frames]

        Args:
            batch (List[torch.Tensor]): list containing the following tensors:
                                         - tensor_data: tensor of shape (B, C, T, V) containing the input sequences
                                         - transformation_idx
                                         - metadata
                                         - actual_frames
            batch_idx (int): index of the batch
        """

        self._test_output_list.append(self.forward(batch))
        return


    def on_test_epoch_start(self) -> None:
        """
        Called when the test epoch begins.
        """

        super().on_test_epoch_start()
        self._test_output_list = []
        return


    def on_test_epoch_end(self) -> float:
        """
        Test epoch end of the model.

        Returns:
            float: test auc score
        """

        out, gt_data, trans, meta, frames = processing_data(self._test_output_list)
        del self._test_output_list
        if self.save_tensors:
            tensors = {'prediction': out, 'gt_data': gt_data,
                       'trans': trans, 'metadata': meta, 'frames': frames}
            self._save_tensors(tensors, split_name=self.split, aggr_strategy=self.aggregation_strategy, n_gen=self.n_generated_samples)
        auc_score = self.post_processing(out, gt_data, trans, meta, frames)
        self.log('AUC', auc_score)
        return auc_score


    def validation_step(self, batch: List[torch.Tensor], batch_idx: int) -> None:
        """
        Validation step of the model. It saves the output of the model and the input data as
        List[torch.Tensor]: [predicted poses and the loss, tensor_data, transformation_idx, metadata, actual_frames]

        Args:
            batch (List[torch.Tensor]): list containing the following tensors:
                                         - tensor_data: tensor of shape (B, C, T, V) containing the input sequences
                                         - transformation_idx
                                         - metadata
                                         - actual_frames
            batch_idx (int): index of the batch
        """

        self._validation_output_list.append(self.forward(batch))
        return


    def on_validation_epoch_start(self) -> None:
        """
        Called when the test epoch begins.
        """

        super().on_validation_epoch_start()
        self._validation_output_list = []
        return


    def on_validation_epoch_end(self) -> float:
        """
        Validation epoch end of the model.

        Returns:
            float: validation auc score
        """

        out, gt_data, trans, meta, frames = processing_data(self._validation_output_list)
        del self._validation_output_list
        if self.save_tensors:
            tensors = {'prediction': out, 'gt_data': gt_data,
                       'trans': trans, 'metadata': meta, 'frames': frames}
            self._save_tensors(tensors, split_name=self.split, aggr_strategy=self.aggregation_strategy, n_gen=self.n_generated_samples)
        auc_score = self.post_processing(out, gt_data, trans, meta, frames)
        self.log('AUC', auc_score, sync_dist=True)
        return auc_score


    def configure_optimizers(self) -> Dict:
        """
        Configure the optimizers and the learning rate schedulers.

        Returns:
            Dict: dictionary containing the optimizers, the learning rate schedulers and the metric to monitor
        """

        optimizer = Adam(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99, last_epoch=-1)

        return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'AUC'}


    def post_processing(self, out: np.ndarray, gt_data: np.ndarray, trans: np.ndarray, meta: np.ndarray, frames: np.ndarray) -> float:
        """
        Post processing of the model.

        Args:
            out (np.ndarray): output of the model
            gt_data (np.ndarray): ground truth data
            trans (np.ndarray): transformation index
            meta (np.ndarray): metadata
            frames (np.ndarray): frame indexes of the data

        Returns:
            float: auc score
        """

        all_gts = [file_name for file_name in os.listdir(self.gt_path) if file_name.endswith('.npy')]
        all_gts = sorted(all_gts)
        scene_clips = [(int(fn.split('_')[0]), int(fn.split('_')[1].split('.')[0])) for fn in all_gts]
        hr_ubnormal_masked_clips = get_hr_ubnormal_mask(self.split) if (self.use_hr and (self.dataset_name == 'UBnormal')) else {}
        hr_avenue_masked_clips = get_avenue_mask() if self.dataset_name == 'HR-Avenue' else {}

        num_transform = self.num_transforms
        model_scores_transf = {}
        dataset_gt_transf = {}

        for transformation in tqdm(range(num_transform)):
            # iterating over each transformation T

            dataset_gt = []
            model_scores = []
            cond_transform = (trans == transformation)

            out_transform, gt_data_transform, meta_transform, frames_transform = filter_vectors_by_cond([out, gt_data, meta, frames], cond_transform)

            for idx in range(len(all_gts)):
                # iterating over each clip C with transformation T

                scene_idx, clip_idx = scene_clips[idx]

                gt = np.load(os.path.join(self.gt_path, all_gts[idx]))
                n_frames = gt.shape[0]

                cond_scene_clip = (meta_transform[:, 0] == scene_idx) & (meta_transform[:, 1] == clip_idx)
                out_scene_clip, gt_scene_clip, meta_scene_clip, frames_scene_clip = filter_vectors_by_cond([out_transform, gt_data_transform,
                                                                                                            meta_transform, frames_transform],
                                                                                                           cond_scene_clip)

                figs_ids = sorted(list(set(meta_scene_clip[:, 2])))
                error_per_person = []
                error_per_person_max_loss = []

                for fig in figs_ids:
                    # iterating over each actor A in each clip C with transformation T

                    cond_fig = (meta_scene_clip[:, 2] == fig)
                    out_fig, _, frames_fig = filter_vectors_by_cond([out_scene_clip, gt_scene_clip, frames_scene_clip], cond_fig)
                    loss_matrix = compute_var_matrix(out_fig, frames_fig, n_frames)
                    fig_reconstruction_loss = np.nanmax(loss_matrix, axis=0)
                    if self.anomaly_score_pad_size != -1:
                        fig_reconstruction_loss = pad_scores(fig_reconstruction_loss, gt, self.anomaly_score_pad_size)

                    error_per_person.append(fig_reconstruction_loss)
                    error_per_person_max_loss.append(max(fig_reconstruction_loss))

                clip_score = np.stack(error_per_person, axis=0)
                clip_score_log = np.log1p(clip_score)
                clip_score = np.mean(clip_score, axis=0) + (np.amax(clip_score_log, axis=0)-np.amin(clip_score_log, axis=0))

                # removing the non-HR frames for UBnormal dataset
                if (scene_idx, clip_idx) in hr_ubnormal_masked_clips:
                    clip_score = clip_score[hr_ubnormal_masked_clips[(scene_idx, clip_idx)]]
                    gt = gt[hr_ubnormal_masked_clips[(scene_idx, clip_idx)]]

                # removing the non-HR frames for Avenue dataset
                if clip_idx in hr_avenue_masked_clips:
                    clip_score = clip_score[np.array(hr_avenue_masked_clips[clip_idx])==1]
                    gt = gt[np.array(hr_avenue_masked_clips[clip_idx])==1]

                # Abnormal score per frame
                clip_score = score_process(clip_score, self.anomaly_score_frames_shift, self.anomaly_score_filter_kernel_size)
                model_scores.append(clip_score)

                dataset_gt.append(gt)

            model_scores = np.concatenate(model_scores, axis=0)

            dataset_gt = np.concatenate(dataset_gt, axis=0)

            model_scores_transf[transformation] = model_scores
            dataset_gt_transf[transformation] = dataset_gt

        # aggregating the anomaly scores for all transformations
        pds = np.mean(np.stack(list(model_scores_transf.values()), 0), 0)
        gt = dataset_gt_transf[0]

        # computing the AUC
        auc = roc_auc_score(gt, pds)

        return auc


    def test_on_saved_tensors(self, split_name: str) -> float:
        """
        Skip the prediction step and test the model on the saved tensors.

        Args:
            split_name (str): split name (val, test)

        Returns:
            float: auc score
        """

        tensors = self._load_tensors(split_name, self.aggregation_strategy, self.n_generated_samples)
        auc_score = self.post_processing(tensors['prediction'], tensors['gt_data'], tensors['trans'],
                                         tensors['metadata'], tensors['frames'])
        print(f'AUC score: {auc_score:.6f}')
        return auc_score


    ## Helper functions

    def _aggregation_strategy(self, generated_xs: List[torch.Tensor], input_sequence: torch.Tensor, aggr_strategy: str) -> Tuple[torch.Tensor]:
        """
        Aggregates the generated samples and returns the selected one and its reconstruction error.
        Strategies:
            - all: returns all the generated samples
            - random: returns a random sample
            - best: returns the sample with the lowest reconstruction loss
            - worst: returns the sample with the highest reconstruction loss
            - mean: returns the mean of the losses of the generated samples
            - median: returns the median of the losses of the generated samples
            - mean_pose: returns the mean of the generated samples
            - median_pose: returns the median of the generated samples

        Args:
            generated_xs (List[torch.Tensor]): list of generated samples
            input_sequence (torch.Tensor): ground truth sequence
            aggr_strategy (str): aggregation strategy

        Raises:
            ValueError: if the aggregation strategy is not valid

        Returns:
            Tuple[torch.Tensor]: selected sample and its reconstruction error
        """

        aggr_strategy = self.aggregation_strategy if aggr_strategy is None else aggr_strategy
        if aggr_strategy == 'random':
            return generated_xs[np.random.randint(len(generated_xs))], None # Added None as it was missing

        B, repr_shape = input_sequence.shape[0], input_sequence.shape[1:]
        compute_loss = lambda x: torch.mean(self.loss_fn(x, input_sequence).reshape(-1, prod(repr_shape)), dim=-1)
        losses = [compute_loss(x) for x in generated_xs]

        if aggr_strategy == 'all':
            dims_idxs = list(range(2, len(repr_shape)+2))
            dims_idxs = [1, 0] + dims_idxs
            selected_x = torch.stack(generated_xs).permute(*dims_idxs)
            loss_of_selected_x = torch.stack(losses).permute(1, 0)
        elif aggr_strategy == 'mean':
            selected_x = None
            loss_of_selected_x = torch.mean(torch.stack(losses), dim=0)
        elif aggr_strategy == 'mean_pose':
            selected_x = torch.mean(torch.stack(generated_xs), dim=0)
            loss_of_selected_x = compute_loss(selected_x)
        elif aggr_strategy == 'median':
            loss_of_selected_x, _ = torch.median(torch.stack(losses), dim=0)
            selected_x = None
        elif aggr_strategy == 'median_pose':
            selected_x, _ = torch.median(torch.stack(generated_xs), dim=0)
            loss_of_selected_x = compute_loss(selected_x)
        elif aggr_strategy == 'best' or aggr_strategy == 'worst':
            strategy = (lambda x, y: x < y) if aggr_strategy == 'best' else (lambda x, y: x > y)
            loss_of_selected_x = torch.full((B,), fill_value=(1e10 if aggr_strategy == 'best' else -1.), device=self.device)
            selected_x = torch.zeros((B, *repr_shape)).to(self.device)

            for i in range(len(generated_xs)):
                mask = strategy(losses[i], loss_of_selected_x)
                loss_of_selected_x[mask] = losses[i][mask]
                selected_x[mask] = generated_xs[i][mask]
        elif 'quantile' in aggr_strategy:
            q = float(aggr_strategy.split(':')[-1])
            loss_of_selected_x = torch.quantile(torch.stack(losses), q, dim=0)
            selected_x = None
        else:
            raise ValueError(f'Unknown aggregation strategy {aggr_strategy}')

        # Ensuring selected_x and loss_of_selected_x are always returned
        if selected_x is None and loss_of_selected_x is None:
             # Default to mean loss if strategy doesn't return both
             loss_of_selected_x = torch.mean(torch.stack(losses), dim=0)


        return selected_x, loss_of_selected_x


    def _infer_number_of_joint(self, args: argparse.Namespace) -> int:
        """
        Infer the number of joints based on the dataset parameters.

        Args:
            args (argparse.Namespace): arguments containing the hyperparameters of the model

        Returns:
            int: number of joints
        """

        if args.headless:
            joints_to_consider = 14
        elif args.kp18_format:
            joints_to_consider = 18
        else:
            joints_to_consider = 17
        return joints_to_consider


    def _load_tensors(self, split_name: str, aggr_strategy: str, n_gen: int) -> Dict[str, torch.Tensor]:
        """
        Loads the tensors from the experiment directory.

        Args:
            split_name (str): name of the split (train, val, test)
            aggr_strategy (str): aggregation strategy
            n_gen (int): number of generated samples

        Returns:
            Dict[str, torch.Tensor]: dictionary containing the tensors. The keys are inferred from the file names.
        """

        name = 'saved_tensors_{}_{}_{}'.format(split_name, aggr_strategy, n_gen)
        path = os.path.join(self.ckpt_dir, name)
        if not os.path.exists(path):
            os.mkdir(path)
        tensor_files = os.listdir(path)
        tensors = {}
        for t_file in tensor_files:
            t_name = t_file.split('.')[0]
            tensors[t_name] = torch.load(os.path.join(path, t_file))
        return tensors


    def _pack_out_data(self, selected_x: torch.Tensor, loss_of_selected_x: torch.Tensor, additional_out: List[torch.Tensor], return_: str) -> List[torch.Tensor]:
        """
        Packs the output data according to the return_ argument.

        Args:
            selected_x (torch.Tensor): generated samples selected among the others according to the aggregation strategy
            loss_of_selected_x (torch.Tensor): loss of the selected samples
            additional_out (List[torch.Tensor]): additional output data (ground truth, meta-data, etc.)
            return_ (str): return strategy. Can be 'pose', 'loss', 'all'

        Raises:
            ValueError: if return_ is None and self.model_return_value is None

        Returns:
            List[torch.Tensor]: output data
        """

        if return_ is None:
            if self.model_return_value is None:
                raise ValueError('Either return_ or self.model_return_value must be set')
            else:
                return_ = self.model_return_value

        if return_ == 'poses':
            out = [selected_x]
        elif return_ == 'loss':
            out = [loss_of_selected_x]
        elif return_ == 'all':
            # Check if both are available before adding to the list
            out = []
            if loss_of_selected_x is not None:
                out.append(loss_of_selected_x)
            if selected_x is not None:
                 out.append(selected_x)

        return out + additional_out


    def _save_tensors(self, tensors: Dict[str, torch.Tensor], split_name: str, aggr_strategy: str, n_gen: int) -> None:
        """
        Saves the tensors in the experiment directory.

        Args:
            tensors (Dict[str, torch.Tensor]): tensors to save
            split_name (str): name of the split (val, test)
            aggr_strategy (str): aggregation strategy
            n_gen (int): number of generated samples
        """

        name = 'saved_tensors_{}_{}_{}'.format(split_name, aggr_strategy, n_gen)
        path = os.path.join(self.ckpt_dir, name)
        if not os.path.exists(path):
            os.mkdir(path)
        for t_name, tensor in tensors.items():
            torch.save(tensor, os.path.join(path, t_name + '.pt'))


    def _set_diffusion_variables(self) -> None:
        """
        Sets the diffusion variables.
        """

        self.noise_scheduler = Diffusion(noise_steps=self.noise_steps, n_joints=self.n_joints,
                                         device=self.device, time=self.n_frames)
        self._beta_ = self.noise_scheduler.schedule_noise()
        self._alpha_ = (1. - self._beta_)
        self._alpha_hat_ = torch.cumprod(self._alpha_, dim=0)

    def _unpack_data(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Unpacks the data.

        Args:
            x (torch.Tensor): list containing the input data, the transformation index, the metadata and the actual frames.

        Returns:
            Tuple[torch.Tensor, List[torch.Tensor]]: input data, list containing the transformation index, the metadata and the actual frames.
        """
        tensor_data = x[0].to(self.device)
        transformation_idx = x[1]
        metadata = x[2]
        actual_frames = x[3]
        meta_out = [transformation_idx, metadata, actual_frames]
        return tensor_data, meta_out


    @property
    def _beta(self) -> torch.Tensor:
        return self._beta_.to(self.device)


    @property
    def _alpha(self) -> torch.Tensor:
        return self._alpha_.to(self.device)


    @property
    def _alpha_hat(self) -> torch.Tensor:
        return self._alpha_hat_.to(self.device)

In [None]:
%%writefile /kaggle/working/DCMD-main/eval_DCMD.py
import argparse
import os

import pytorch_lightning as pl
import yaml
from models.dcmd import DCMD
from utils.argparser import init_args
from utils.dataset import get_dataset_and_loader



if __name__== '__main__':
    
    # Parse command line arguments and load config file
    parser = argparse.ArgumentParser(description='DCMD')
    parser.add_argument('-c', '--config', type=str, required=True)
    args = parser.parse_args()
    args = yaml.load(open(args.config), Loader=yaml.FullLoader)
    args = argparse.Namespace(**args)
    args = init_args(args)

    # Initialize the model
    model = DCMD(args)
    
    if args.load_tensors:
        # Load tensors and test
        model.test_on_saved_tensors(split_name=args.split)
    else:
        # Load test data
        print('Loading data and creating loaders.....')
        ckpt_path = '/kaggle/input/checkpoints/kaggle/working/DCMD-main/checkpoints/HR-Avenue/train_experiment/last.ckpt'
        dataset, loader, _, _ = get_dataset_and_loader(args, split=args.split)
        
        # Initialize trainer and test
        trainer = pl.Trainer(accelerator=args.accelerator, devices=args.devices[:1],
                             default_root_dir=args.ckpt_dir, max_epochs=1, logger=False)
        out = trainer.test(model, dataloaders=loader, ckpt_path=ckpt_path)

In [46]:
%%writefile /kaggle/working/DCMD-main/utils/get_robust_data.py
import os
import numpy as np
import pickle

from copy import deepcopy

import torch

from utils.data import load_trajectories, extract_global_features
from utils.data import change_coordinate_system, scale_trajectories, aggregate_autoencoder_data
from utils.data import input_trajectories_missing_steps
from utils.preprocessing import remove_short_trajectories, aggregate_rnn_autoencoder_data


def save_scaler(scaler, path):
    with open(path, 'wb') as scaler_file:
        pickle.dump(scaler, scaler_file)
    
        
def load_scaler(path):
    with open(path, 'rb') as scaler_file:
        scaler = pickle.load(scaler_file)
    return scaler



# Load trajectory data and convert it into a format suitable for RNN autoencoders for training and testing of joint models
def data_of_combined_model(**args):
    # General
    exp_dir = args.get('exp_dir', '')
    split = args.get('split', 'train')
    normalize_pose = args.get('normalize_pose', False)
    trajectories_path = args.get('trajectories_path', '')
    include_global = args.get('include_global', True)
    debug = args.get('debug', False)
    if 'train' in split:
      # đổi thành path khác trong folder data nếu train dataset khác
        subfolder = '/kaggle/input/avenue/Avenue/training'
    elif 'test' in split:
        subfolder = '/kaggle/input/avenue/Avenue/testing'
    else:
        subfolder = 'validating'
    trajectories_path = os.path.join(trajectories_path, f'{subfolder}/trajectories')
    video_resolution = args.get('vid_res', (1080,720))
    video_resolution = np.array(video_resolution, dtype=np.float32)
    # Architecture
    reconstruct_original_data = args.get('reconstruct_original_data', False) 
    input_length = args.get('seg_len', 12)
    seg_stride = args.get('seg_stride', 1) - 1 
    pred_length = 0 
    # Training
    input_missing_steps = False # args.input_missing_steps
    
    if normalize_pose == True:
        global_normalisation_strategy = args.get('normalization_strategy', 'robust')
        local_normalisation_strategy = args.get('normalization_strategy', 'robust')
        out_normalisation_strategy = args.get('normalization_strategy', 'robust')


    trajectories = load_trajectories(trajectories_path, debug=debug, split=split)
    print('\nLoaded %d trajectories.' % len(trajectories))

    trajectories = remove_short_trajectories(trajectories, input_length=input_length,
                                             input_gap=seg_stride, pred_length=pred_length)
    print('\nRemoved short trajectories. Number of trajectories left: %d.' % len(trajectories))

    # trajectories, trajectories_val = split_into_train_and_test(trajectories, train_ratio=0.8, seed=42)

    if input_missing_steps:
        trajectories = input_trajectories_missing_steps(trajectories)
        print('\nInputted missing steps of trajectories.')

    # Global
    if include_global:
        global_trajectories = extract_global_features(deepcopy(trajectories), video_resolution=video_resolution)

        global_trajectories = change_coordinate_system(global_trajectories, video_resolution=video_resolution,
                                                        coordinate_system='global', invert=False)

        print('\nChanged global trajectories\'s coordinate system to global.')
        
        X_global, y_global, X_global_meta, y_global_meta = aggregate_rnn_autoencoder_data(global_trajectories, 
                                                                                        input_length=input_length,
                                                                                        input_gap=seg_stride, pred_length=pred_length, 
                                                                                        return_ids=True)
        
        if normalize_pose == True:
            # nếu test avenue thì dùng dòng này
            #/content/checkpoints/HR-Avenue/train_experiment/local_robust.pickle
            # default: scaler_path = os.path.join(exp_dir, f'global_{global_normalisation_strategy}.pickle')
            scaler_path = '/kaggle/input/checkpoints/kaggle/working/DCMD-main/checkpoints/HR-Avenue/train_experiment/local_robust.pickle'
            if split == 'train':
                _, global_scaler = scale_trajectories(aggregate_autoencoder_data(global_trajectories),
                                                    strategy=global_normalisation_strategy)
                save_scaler(global_scaler, scaler_path)
            else:
                global_scaler = load_scaler(scaler_path)

            X_global, _ = scale_trajectories(X_global, scaler=global_scaler, strategy=global_normalisation_strategy)
            
            if y_global is not None:
                y_global, _ = scale_trajectories(y_global, scaler=global_scaler,
                                                strategy=global_normalisation_strategy)
                
            print('\nNormalised global trajectories using the %s normalisation strategy.' % global_normalisation_strategy)
    
    else:
        X_global, X_global_meta = None, None
    
    # Local
    local_trajectories = deepcopy(trajectories) if reconstruct_original_data else trajectories

    local_trajectories = change_coordinate_system(local_trajectories, video_resolution=video_resolution,
                                                  coordinate_system='bounding_box_centre', invert=False)

    print('\nChanged local trajectories\'s coordinate system to bounding_box_centre.')

    X_local, y_local, X_local_meta, y_local_meta = aggregate_rnn_autoencoder_data(local_trajectories, input_length=input_length, 
                                                                                  input_gap=seg_stride, pred_length=pred_length,
                                                                                  return_ids=True)
    
    if normalize_pose == True:
        #scaler_path = '/content/drive/MyDrive/DCMD-main-main/checkpoints/Avenue/test_experiment/local_robust.pickle'
        scaler_path = os.path.join(exp_dir, f'local_{local_normalisation_strategy}.pickle')

        if split == 'train':
            _, local_scaler = scale_trajectories(aggregate_autoencoder_data(local_trajectories),
                                                strategy=local_normalisation_strategy)
            save_scaler(local_scaler, scaler_path)
        else:
            local_scaler = load_scaler("/kaggle/input/checkpoints/kaggle/working/DCMD-main/checkpoints/HR-Avenue/train_experiment/local_robust.pickle")

        X_local, _ = scale_trajectories(X_local, scaler=local_scaler, strategy=local_normalisation_strategy)

        if y_local is not None:
            y_local, _ = scale_trajectories(y_local, scaler=local_scaler, strategy=local_normalisation_strategy)
        
        print('\nNormalised local trajectories using the %s normalisation strategy.' % local_normalisation_strategy)

    # (Optional) Reconstruct the original data
    if reconstruct_original_data:
        print('\nReconstruction/Prediction target is the original data.')
        out_trajectories = trajectories
        
        out_trajectories = change_coordinate_system(out_trajectories, video_resolution=video_resolution,
                                                    coordinate_system='global', invert=False)
    
        print('\nChanged target trajectories\'s coordinate system to global.')
        
        scaler_path = os.path.join(exp_dir, f'out_{out_normalisation_strategy}.pickle')
    
        if split == 'train':
            _, out_scaler = scale_trajectories(aggregate_autoencoder_data(out_trajectories),
                                               strategy=out_normalisation_strategy)
            save_scaler(out_scaler, scaler_path)
        else:
            out_scaler = load_scaler(scaler_path)
        
        ######## X_out_{}, y_out_{} numpy arrays

        X_out, y_out, X_out_meta, y_out_meta = aggregate_rnn_autoencoder_data(out_trajectories, input_length=input_length, 
                                                                              input_gap=seg_stride, pred_length=pred_length,
                                                                              return_ids=True)

        X_out, _ = scale_trajectories(X_out, scaler=out_scaler, strategy=out_normalisation_strategy)
        
        if y_out is not None:
            y_out, _ = scale_trajectories(y_out, scaler=out_scaler, strategy=out_normalisation_strategy)
            
        print('\nNormalised target trajectories using the %s normalisation strategy.' % out_normalisation_strategy)
        
            
    if pred_length > 0:
        
        if reconstruct_original_data:
            return (X_global, X_global_meta), \
                   (X_local, X_local_meta), \
                   (X_out, X_out_meta), \
                   (y_global, y_global_meta), \
                   (y_local, y_local_meta), \
                   (y_out, y_out_meta) 
        else:
            return (X_global, X_global_meta), \
                   (X_local, X_local_meta), \
                   (y_global, y_global_meta), \
                   (y_local, y_local_meta)
    else:
        if reconstruct_original_data:
            return (X_global, X_global_meta), \
                   (X_local, X_local_meta), \
                   (X_out, X_out_meta)
        else:
            return (X_global, X_global_meta), \
                   (X_local, X_local_meta)

Overwriting /kaggle/working/DCMD-main/utils/get_robust_data.py


# **Train** 

In [56]:
!touch /kaggle/working/DCMD-main/models/__init__.py
!touch /kaggle/working/DCMD-main/models/gcae/__init__.py
!touch /kaggle/working/DCMD-main/models/common/__init__.py
!touch /kaggle/working/DCMD-main/models/stsae/__init__.py


In [62]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m19.9 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [142]:
!python /kaggle/working/DCMD-main/train_DCMD.py --config /kaggle/working/DCMD-main/config/Avenue/dcmd_train.yaml

Experiment directories created in /kaggle/working/DCMD-main/checkpoints/HR-Avenue/train_experiment

Loaded 2786 trajectories.

Removed short trajectories. Number of trajectories left: 1735.

Changed local trajectories's coordinate system to bounding_box_centre.

Normalised local trajectories using the robust normalisation strategy.
Traceback (most recent call last):
  File "/kaggle/working/DCMD-main/train_DCMD.py", line 85, in <module>
    model = DCMD(args)
            ^^^^^^^^^^
  File "/kaggle/working/DCMD-main/models/dcmd.py", line 98, in __init__
    self.build_model()
  File "/kaggle/working/DCMD-main/models/dcmd.py", line 122, in build_model
    self.rec_model = STSAE(
                     ^^^^^^
  File "/kaggle/working/DCMD-main/models/stsae/stsae.py", line 79, in __init__
    super(STSAE, self).__init__(c_in, h_dim, latent_dim, n_frames, n_joints,
  File "/kaggle/working/DCMD-main/models/stsae/stsae.py", line 45, in __init__
    self.build_model()
  File "/kaggle/working/DCMD-

If download checkpoint needed, run this code and download checkpoints.zip file in Output

In [None]:
# 1. Zip toàn bộ folder checkpoints
!zip -r /kaggle/working/checkpoints.zip /kaggle/working/DCMD-main/checkpoints

# 2. (tuỳ chọn) kiểm tra xem zip đã tạo xong chưa
!ls -lh /kaggle/working/checkpoints.zip


# **Testing**

In [None]:
import os
os.makedirs('/kaggle/working/DCMD-main/checkpoints/HR-Avenue/test_experiment', exist_ok=True)

In [None]:
!python /kaggle/working/DCMD-main/eval_DCMD.py --config /kaggle/working/DCMD-main/config/Avenue/dcmd_test.yaml

In [None]:
# 1. Zip toàn bộ folder checkpoints
!zip -r /kaggle/working/checkpoints.zip /kaggle/working/DCMD-main/checkpoints



# **Chỉnh phương pháp**

In [114]:
%%writefile /kaggle/working/DCMD-main/models/gcae/stsgcn.py
# Bạn có thể đăng code ST-SGCN gốc ở đây, hoặc để trống nếu không dùng.
class STSGCN:
    def __init__(self, *args, **kwargs):
        raise NotImplementedError("STSGCN chưa được định nghĩa.")


Overwriting /kaggle/working/DCMD-main/models/gcae/stsgcn.py


In [130]:
%%writefile /kaggle/working/DCMD-main/models/stsae/stsae.py

import torch
import torch.nn as nn
from typing import List, Union
from argparse import Namespace
from models.common.components import build_encoder, build_decoder

class STSE(nn.Module):
    def __init__(self, c_in: int, h_dim: int, latent_dim: int, n_frames: int,
                 n_joints: int, layer_channels: List[int], dropout: float,
                 device: Union[str, torch.device], use_adaptive=False,
                 use_jigsaw=False, emb_dim=None) -> None:
        super(STSE, self).__init__()

        # Save config to pass to decoder
        self._cfg = Namespace(
                    input_dim=c_in,
                    in_channels=c_in,
                    h_dim=h_dim,
                    latent_dim=latent_dim,
                    n_frames=n_frames,
                    n_joints=n_joints,
                    layer_channels=layer_channels,
                    dropout=dropout,
                    device=device,
                    emb_dim=emb_dim,
                    use_adaptive=use_adaptive,
                    use_jigsaw=use_jigsaw
                )

        # Set attributes
        self.input_dim = c_in
        self.h_dim = h_dim
        self.latent_dim = latent_dim
        self.n_frames = n_frames
        self.n_joints = n_joints
        self.layer_channels = layer_channels
        self.dropout = dropout
        self.device = device
        self.use_adaptive = use_adaptive
        self.use_jigsaw = use_jigsaw
        self.emb_dim = emb_dim

        # Build model
        self.build_model()

    def build_model(self):
        self.encoder = build_encoder(self._cfg)
        self.btlnk = nn.Linear(self.h_dim * self.n_frames * self.n_joints, self.latent_dim)

    def encode(self, X: torch.Tensor, return_shape: bool = False, t: torch.Tensor = None):
        X = X.unsqueeze(4)
        N, C, T, V, M = X.size()
        X = X.permute(0, 4, 3, 1, 2).contiguous()
        X = X.view(N * M, V, C, T).permute(0, 2, 3, 1).contiguous()

        X, *_ = self.encoder(X, t)
        N, C, T, V = X.size()
        X = X.view(N, -1).contiguous()
        X = X.view(N, M, self.h_dim, T, V).permute(0, 2, 3, 4, 1).contiguous()
        X_shape = X.size()
        X = X.view(N, -1).contiguous()

        X = self.btlnk(X)

        if return_shape:
            return X, X_shape
        return X

    def forward(self, X: torch.Tensor, t: torch.Tensor = None):
        return self.encode(X, return_shape=False, t=t), None


class STSAE(STSE):
    def __init__(self, c_in: int, h_dim: int, latent_dim: int, n_frames: int,
                 n_joints: int, layer_channels: List[int], dropout: float,
                 device: Union[str, torch.device], use_adaptive=False,
                 use_jigsaw=False, emb_dim=None) -> None:
        super(STSAE, self).__init__(c_in, h_dim, latent_dim, n_frames, n_joints,
                                    layer_channels, dropout, device, use_adaptive,
                                    use_jigsaw, emb_dim)

    def build_model(self):
        super().build_model()
        self.decoder = build_decoder(self._cfg)
        self.rev_btlnk = nn.Linear(self.latent_dim, self.h_dim * self.n_frames * self.n_joints)

    def decode(self, Z: torch.Tensor, input_shape, t: torch.Tensor = None):
        Z = self.rev_btlnk(Z)
        N, C, T, V, M = input_shape
        Z = Z.view(input_shape).contiguous()
        Z = Z.permute(0, 4, 1, 2, 3).contiguous()
        Z = Z.view(N * M, C, T, V).contiguous()
        Z = self.decoder(Z, t)
        return Z

    def forward(self, X: torch.Tensor, t: torch.Tensor = None):
        hidden_X, X_shape = self.encode(X, return_shape=True, t=t)
        X = self.decode(hidden_X, X_shape, t)
        return hidden_X, X


Overwriting /kaggle/working/DCMD-main/models/stsae/stsae.py


In [141]:
%%writefile /kaggle/working/DCMD-main/models/gcae/enhanced_stgcn.py
import math
import random
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import networkx as nx
from networkx.algorithms.community import girvan_newman

from torch_geometric.nn import GATConv
from torch_geometric.utils import dense_to_sparse
from typing import Union, List, Tuple

# --- Adaptive + Jigsaw (GDN) những hàm tiện ích ---
def create_graph_from_adjacency_matrix(adj_matrix):
    G = nx.DiGraph()
    for i, targets in enumerate(adj_matrix):
        for j in targets:
            if i != j:
                G.add_edge(i, j.item())
    return G

def detect_communities_girvan_newman(G, target_communities=4):
    try:
        G_undirected = G.to_undirected()
        comm_iter = girvan_newman(G_undirected)
        for communities in itertools.islice(comm_iter, target_communities-1):
            if len(communities) >= target_communities:
                return communities
    except:
        return None

def flip_two_communities_directed(adj_matrix, communities, community_id1, community_id2):
    flipped_adj_matrix = np.copy(adj_matrix)
    nodes1 = list(communities[community_id1])
    nodes2 = list(communities[community_id2])
    min_len = min(len(nodes1), len(nodes2))
    nodes1, nodes2 = nodes1[:min_len], nodes2[:min_len]
    node_mapping = {node: node for node in range(len(adj_matrix))}
    for n1, n2 in zip(nodes1, nodes2):
        node_mapping[n1] = n2
        node_mapping[n2] = n1
    for i in range(len(adj_matrix)):
        for j in range(len(adj_matrix)):
            flipped_adj_matrix[node_mapping[i], node_mapping[j]] = adj_matrix[i, j]
    return flipped_adj_matrix

def adjacency_matrix_to_list(flipped_adj_matrix, topk):
    adjacency_list = []
    for i, row in enumerate(flipped_adj_matrix):
        connected_nodes = np.nonzero(row)[0]
        filtered_nodes = [node for node in connected_nodes if node != i][:topk-1]
        filtered_nodes.insert(0, i)
        if len(filtered_nodes) < topk:
            filtered_nodes.extend([i] * (topk - len(filtered_nodes)))
        adjacency_list.append(filtered_nodes[:topk])
    return adjacency_list

# --- Adaptive Graph Convolution with Jigsaw (GDN) ---

class AdaptiveGraphConv(nn.Module):
    def __init__(self, time_dim, joints_dim, embed_dim=64, topk=8, use_jigsaw=True, jigsaw_prob=0.3):
        super().__init__()
        self.time_dim = time_dim
        self.joints_dim = joints_dim
        self.embed_dim = embed_dim
        self.topk = topk
        self.use_jigsaw = use_jigsaw
        self.jigsaw_prob = jigsaw_prob

        self.A = nn.Parameter(torch.FloatTensor(time_dim, joints_dim, joints_dim))
        self.T = nn.Parameter(torch.FloatTensor(joints_dim, time_dim, time_dim))

        self.node_embedding = nn.Embedding(joints_dim, embed_dim)
        self.adaptive_weight = nn.Parameter(torch.FloatTensor(1))

        self.attention = nn.MultiheadAttention(embed_dim, num_heads=4, batch_first=True)

        if use_jigsaw:
            self.puzzle_classifier = nn.Sequential(
                nn.AdaptiveAvgPool1d(1),
                nn.Flatten(),
                nn.Linear(embed_dim, 32),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(32, 7)
            )

        self.init_params()

    def init_params(self):
        stdv = 1. / math.sqrt(self.A.size(1))
        self.A.data.uniform_(-stdv, stdv)
        stdv = 1. / math.sqrt(self.T.size(1))
        self.T.data.uniform_(-stdv, stdv)
        nn.init.kaiming_uniform_(self.node_embedding.weight, a=math.sqrt(5))
        nn.init.constant_(self.adaptive_weight, 0.5)

    def get_adaptive_adjacency(self, device):
        all_embeddings = self.node_embedding(torch.arange(self.joints_dim).to(device))
        weights = all_embeddings.view(self.joints_dim, -1)
        cos_ji_mat = torch.matmul(weights, weights.T)
        normed_mat = torch.matmul(weights.norm(dim=-1).view(-1,1), weights.norm(dim=-1).view(1,-1))
        cos_ji_mat = cos_ji_mat / (normed_mat + 1e-8)
        topk_indices = torch.topk(cos_ji_mat, self.topk, dim=-1)[1]
        return topk_indices, all_embeddings

    def apply_jigsaw_puzzle(self, adj_indices):
        if not (self.training and self.use_jigsaw and random.random() < self.jigsaw_prob):
            return adj_indices, 6
        try:
            X1 = adj_indices.cpu().numpy()
            G = create_graph_from_adjacency_matrix(X1)
            partition = detect_communities_girvan_newman(G, target_communities=4)
            if partition is None or len(partition) < 4:
                return adj_indices, 6
            communities = {i: list(c) for i, c in enumerate(partition)}
            community_pairs = [(0,1), (0,2), (0,3), (1,2), (1,3), (2,3)]
            selected_combination = random.randint(0, 5)
            community_id1, community_id2 = community_pairs[selected_combination]
            adj_matrix_directed = np.zeros((len(X1), len(X1)))
            for ii, targets in enumerate(X1):
                for jj in targets:
                    if ii != jj:
                        adj_matrix_directed[ii, jj] = 1
            flipped_adj_matrix = flip_two_communities_directed(
                adj_matrix_directed, communities, community_id1, community_id2
            )
            flipped_adj_list = adjacency_matrix_to_list(flipped_adj_matrix, self.topk)
            output = torch.tensor(flipped_adj_list, dtype=adj_indices.dtype).to(adj_indices.device)
            return output, selected_combination
        except:
            return adj_indices, 6

    def forward(self, X):
        device = X.device
        batch_size = X.shape[0]
        adj_indices, node_embeddings = self.get_adaptive_adjacency(device)
        adj_indices, puzzle_label = self.apply_jigsaw_puzzle(adj_indices)

        adaptive_A = torch.zeros(self.time_dim, self.joints_dim, self.joints_dim, device=device)
        for t in range(self.time_dim):
            for i, neighbors in enumerate(adj_indices):
                adaptive_A[t, i, neighbors] = 1.0
        combined_A = (1 - self.adaptive_weight) * self.A + self.adaptive_weight * adaptive_A

        X_temp = torch.einsum('nctv,vtq->ncqv', (X, self.T)).contiguous()
        X_out = torch.einsum('nctv,tvw->nctw', (X_temp, combined_A)).contiguous()

        puzzle_pred = None
        if self.use_jigsaw and hasattr(self, 'puzzle_classifier'):
            embed_features = node_embeddings.mean(dim=0, keepdim=True).expand(batch_size, -1)
            puzzle_pred = self.puzzle_classifier(embed_features.unsqueeze(-1))

        return X_out, puzzle_pred, puzzle_label


class Enhanced_ST_GCNN_layer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size: Union[Tuple[int], List[int]],
                 stride, time_dim, joints_dim, dropout, bias=True,
                 emb_dim=None, use_adaptive=True, use_jigsaw=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.time_dim = time_dim
        self.joints_dim = joints_dim
        self.dropout = dropout
        self.bias = bias
        self.emb_dim = emb_dim
        self.kernel_size = kernel_size
        self.use_adaptive = use_adaptive
        self.use_jigsaw = use_jigsaw
        self.build_model()

    def build_model(self):
        kernel_size = self.kernel_size  # <== FIX HERE
        padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)

        if self.use_adaptive:
            self.gcn = AdaptiveGraphConv(
                self.time_dim, self.joints_dim,
                embed_dim=64, topk=8, use_jigsaw=self.use_jigsaw
            )
        else:
            self.gcn = ConvTemporalGraphical(self.time_dim, self.joints_dim)

        self.tcn = nn.Sequential(
            nn.Conv2d(self.in_channels, self.out_channels, kernel_size, (self.stride, self.stride), padding, bias=self.bias),
            nn.BatchNorm2d(self.out_channels),
            nn.Dropout(self.dropout, inplace=True),
        )

        if self.stride != 1 or self.in_channels != self.out_channels:
            self.residual = nn.Sequential(
                nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, stride=(1, 1), bias=self.bias),
                nn.BatchNorm2d(self.out_channels)
            )
        else:
            self.residual = nn.Identity()

        self.prelu = nn.PReLU()

        if self.emb_dim is not None:
            self.emb_layer = nn.Sequential(
                nn.SiLU(),
                nn.Linear(self.emb_dim, self.out_channels),
            )

    def forward(self, X, t=None):
        res = self.residual(X)
        if self.use_adaptive:
            X, puzzle_pred, puzzle_label = self.gcn(X)
        else:
            X = self.gcn(X)
            puzzle_pred, puzzle_label = None, None

        X = self.tcn(X)
        X = X + res
        X = self.prelu(X)

        if self.emb_dim is not None and t is not None:
            emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, X.shape[-2], X.shape[-1]).contiguous()
            X = X + emb

        return X.contiguous(), puzzle_pred, puzzle_label

class ConvTemporalGraphical(nn.Module):
    def __init__(self, time_dim, joints_dim):
        super().__init__()
        self.A = nn.Parameter(torch.FloatTensor(time_dim, joints_dim, joints_dim))
        stdv = 1. / math.sqrt(self.A.size(1))
        self.A.data.uniform_(-stdv, stdv)
        self.T = nn.Parameter(torch.FloatTensor(joints_dim, time_dim, time_dim))
        stdv = 1. / math.sqrt(self.T.size(1))
        self.T.data.uniform_(-stdv, stdv)

    def forward(self, X):
        X = torch.einsum('nctv,vtq->ncqv', (X, self.T)).contiguous()
        X = torch.einsum('nctv,tvw->nctw', (X, self.A)).contiguous()
        return X


class EnhancedEncoder(nn.Module):
    """
    Lớp Encoder “Enhanced” dùng GDN (Adaptive + Jigsaw).
    """
    def __init__(self,
                 input_dim: int,
                 layer_channels: List[int],
                 hidden_dimension: int,
                 n_frames: int,
                 n_joints: int,
                 dropout: float,
                 bias: bool = True,
                 use_adaptive: bool = True,
                 use_jigsaw: bool = True):
        super().__init__()

        self.input_dim = input_dim
        self.layer_channels = layer_channels
        self.hidden_dimension = hidden_dimension
        self.n_frames = n_frames
        self.n_joints = n_joints
        self.dropout = dropout
        self.bias = bias
        self.use_adaptive = use_adaptive
        self.use_jigsaw = use_jigsaw

        self.build_model()

    def build_model(self):
        input_channels = self.input_dim
        layer_channels = self.layer_channels + [self.hidden_dimension]
        kernel_size = [1, 1]
        stride = 1
        model_layers = nn.ModuleList()

        for i, channels in enumerate(layer_channels):
            if i == 0 and self.use_adaptive:
                model_layers.append(
                    Enhanced_ST_GCNN_layer(
                        in_channels=input_channels,
                        out_channels=channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        time_dim=self.n_frames,
                        joints_dim=self.n_joints,
                        dropout=self.dropout,
                        bias=self.bias,
                        emb_dim=self.hidden_dimension,
                        use_adaptive=True,
                        use_jigsaw=self.use_jigsaw
                    )
                )
            else:
                model_layers.append(
                    Enhanced_ST_GCNN_layer(
                        in_channels=input_channels,
                        out_channels=channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        time_dim=self.n_frames,
                        joints_dim=self.n_joints,
                        dropout=self.dropout,
                        bias=self.bias,
                        emb_dim=None,
                        use_adaptive=False,
                        use_jigsaw=False
                    )
                )
            input_channels = channels

        self.model_layers = model_layers

    def forward(self, X: torch.Tensor, t: torch.Tensor = None):
        layers_out = [X]
        puzzle_preds = []
        puzzle_labels = []

        for layer in self.model_layers:
            if hasattr(layer, 'use_adaptive') and layer.use_adaptive:
                out_X, puzzle_pred, puzzle_label = layer(layers_out[-1], t)
                if puzzle_pred is not None:
                    puzzle_preds.append(puzzle_pred)
                if puzzle_label is not None:
                    puzzle_labels.append(puzzle_label)
            else:
                out_X, _, _ = layer(layers_out[-1], t)
            layers_out.append(out_X)

        return layers_out[-1], layers_out[:-1], puzzle_preds, puzzle_labels


class EnhancedDCMDLoss(nn.Module):
    """
    Loss chính + auxiliary Jigsaw puzzle.
    """
    def __init__(self, main_loss_weight: float = 1.0, puzzle_loss_weight: float = 0.1):
        super(EnhancedDCMDLoss, self).__init__()
        self.main_loss_weight = main_loss_weight
        self.puzzle_loss_weight = puzzle_loss_weight
        self.puzzle_criterion = nn.CrossEntropyLoss()

    def forward(self,
                main_pred: torch.Tensor,
                main_target: torch.Tensor,
                puzzle_preds: List[torch.Tensor] = None,
                puzzle_labels: List = None):
        main_loss = F.mse_loss(main_pred, main_target)
        total_loss = self.main_loss_weight * main_loss

        puzzle_loss = torch.tensor(0.0, device=main_pred.device)
        if puzzle_preds and puzzle_labels and len(puzzle_preds) > 0 and len(puzzle_labels) > 0:
            for pred, label in zip(puzzle_preds, puzzle_labels):
                if isinstance(label, (int, list)):
                    if isinstance(label, list):
                        label_tensor = torch.tensor(label, device=pred.device)
                    else:
                        label_tensor = torch.tensor([label] * pred.shape[0], device=pred.device)
                    puzzle_loss += self.puzzle_criterion(pred, label_tensor)
            total_loss += self.puzzle_loss_weight * puzzle_loss

        return total_loss, main_loss, puzzle_loss


Overwriting /kaggle/working/DCMD-main/models/gcae/enhanced_stgcn.py


In [91]:
%%writefile /kaggle/working/DCMD-main/models/gcae/graph_layer.py
import torch
from torch.nn import Parameter, Linear
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
from torch_geometric.nn.inits import glorot, zeros


class GraphLayer(MessagePassing):
    def __init__(self, in_channels, out_channels, heads=1, concat=True,
                 negative_slope=0.2, dropout=0, bias=True, inter_dim=-1, **kwargs):
        super(GraphLayer, self).__init__(aggr='add', **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout

        self.__alpha__ = None

        self.lin = Linear(in_channels, heads * out_channels, bias=False)

        self.att_i = Parameter(torch.Tensor(1, heads, out_channels))
        self.att_j = Parameter(torch.Tensor(1, heads, out_channels))
        self.att_em_i = Parameter(torch.Tensor(1, heads, out_channels))
        self.att_em_j = Parameter(torch.Tensor(1, heads, out_channels))

        if bias and concat:
            self.bias = Parameter(torch.Tensor(heads * out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.lin.weight)
        glorot(self.att_i)
        glorot(self.att_j)
        zeros(self.att_em_i)
        zeros(self.att_em_j)
        zeros(self.bias)

    def forward(self, x, edge_index, embedding=None, return_attention_weights=False):
        """
        x: node features (num_nodes, in_channels) hoặc tuple cho graph đôi (x_src, x_dst)
        edge_index: (2, num_edges)
        embedding: nếu có, shape (num_nodes, emb_dim)
        return_attention_weights: nếu True, trả thêm attention weights
        """
        if torch.is_tensor(x):
            x = self.lin(x)
            x = (x, x)
        else:
            x = (self.lin(x[0]), self.lin(x[1]))

        edge_index, _ = remove_self_loops(edge_index)
        edge_index, _ = add_self_loops(edge_index, num_nodes=x[1].size(self.node_dim))

        out = self.propagate(edge_index, x=x, embedding=embedding, edges=edge_index,
                             return_attention_weights=return_attention_weights)

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.bias is not None:
            out = out + self.bias

        if return_attention_weights:
            alpha, self.__alpha__ = self.__alpha__, None
            return out, (edge_index, alpha)
        else:
            return out

    def message(self, x_i, x_j, edge_index_i, size_i, embedding, edges, return_attention_weights):
        """
        x_i, x_j: (num_edges, heads * out_channels) đã reshape thành (num_edges, heads, out_channels)
        edge_index_i: first row of edge_index
        embedding: (num_nodes, emb_dim) hoặc None
        edges: toàn bộ edge_index
        return_attention_weights: bool
        """
        x_i = x_i.view(-1, self.heads, self.out_channels)
        x_j = x_j.view(-1, self.heads, self.out_channels)

        if embedding is not None:
            emb_i = embedding[edge_index_i]      # (num_edges, emb_dim)
            emb_j = embedding[edges[0]]          # (num_edges, emb_dim)
            emb_i = emb_i.unsqueeze(1).repeat(1, self.heads, 1)  # (num_edges, heads, emb_dim)
            emb_j = emb_j.unsqueeze(1).repeat(1, self.heads, 1)

            key_i = torch.cat((x_i, emb_i), dim=-1)
            key_j = torch.cat((x_j, emb_j), dim=-1)

            cat_att_i = torch.cat((self.att_i, self.att_em_i), dim=-1)  # (1, heads, out_ch+emb_dim)
            cat_att_j = torch.cat((self.att_j, self.att_em_j), dim=-1)

            alpha = (key_i * cat_att_i).sum(-1) + (key_j * cat_att_j).sum(-1)  # (num_edges, heads)
        else:
            cat_att_i = self.att_i
            cat_att_j = self.att_j
            alpha = (x_i * cat_att_i).sum(-1) + (x_j * cat_att_j).sum(-1)

        alpha = alpha.view(-1, self.heads, 1)
        alpha = F.leaky_relu(alpha, self.negative_slope)
        self.node_dim = 0
        torch.use_deterministic_algorithms(False)
        alpha = softmax(alpha, edge_index_i, num_nodes=size_i)
        torch.use_deterministic_algorithms(True)

        if return_attention_weights:
            self.__alpha__ = alpha

        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return x_j * alpha.view(-1, self.heads, 1)

    def __repr__(self):
        return f'{self.__class__.__name__}({self.in_channels}, {self.out_channels}, heads={self.heads})'


Overwriting /kaggle/working/DCMD-main/models/gcae/graph_layer.py


In [92]:
%%writefile /kaggle/working/DCMD-main/models/gcae/forecasting_gat.py
import torch
import torch.nn as nn
from torch_geometric.utils import dense_to_sparse

from models.gcae.graph_layer import GraphLayer


class GATForecast(nn.Module):
    """
    Dùng GraphLayer (GAT) + TCN để dự báo frame kế tiếp.
    Input: x_hist: (N, C, T, V)
    Output: pred_next: (N, C, V)
    """

    def __init__(self,
                 in_channels: int,
                 num_nodes: int,
                 hidden_dim: int,
                 heads: int = 4,
                 dropout: float = 0.2,
                 use_embedding: bool = False):
        super(GATForecast, self).__init__()
        self.in_channels = in_channels
        self.num_nodes = num_nodes
        self.hidden_dim = hidden_dim
        self.heads = heads
        self.dropout = dropout
        self.use_embedding = use_embedding

        # 1) GraphLayer: input_dim = in_channels, out_channels = hidden_dim, heads=heads
        self.graph_conv = GraphLayer(
            in_channels=in_channels,
            out_channels=hidden_dim,
            heads=heads,
            concat=True,
            negative_slope=0.2,
            dropout=dropout,
            bias=True
        )

        # 2) Conv1d để gom thông tin theo T
        self.tcn = nn.Sequential(
            nn.Conv1d(
                in_channels=hidden_dim * heads,
                out_channels=hidden_dim * heads,
                kernel_size=3,
                padding=1
            ),
            nn.BatchNorm1d(hidden_dim * heads),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout, inplace=True)
        )

        # 3) FC map ẩn → in_channels (xyz)
        self.fc_out = nn.Linear(hidden_dim * heads, in_channels)

    def forward(self, x_hist: torch.Tensor, A: torch.Tensor, embedding: torch.Tensor = None):
        """
        x_hist: (N, C, T, V)
        A: adjacency dense (V, V)
        embedding: nếu có, shape (V, emb_dim)
        """
        N, C, T, V = x_hist.shape

        # 1) Tạo batched node features: (N, T, V, C) → (N*T*V, C)
        x_perm = x_hist.permute(0, 2, 3, 1).contiguous()  # (N, T, V, C)
        x_nodes = x_perm.view(N * T * V, C)               # (N*T*V, C)

        # 2) Tạo edge_index từ A (dense) rồi batch cho N*T
        edge_index, _ = dense_to_sparse(A)  # (2, E)
        num_graphs = N * T
        node_offsets = torch.arange(0, num_graphs * V, step=V, device=A.device)

        # Mở rộng edge_index
        edge0 = edge_index[0].repeat(num_graphs) + node_offsets.repeat_interleave(edge_index.size(1))
        edge1 = edge_index[1].repeat(num_graphs) + node_offsets.repeat_interleave(edge_index.size(1))
        batch_edge_index = torch.stack([edge0, edge1], dim=0)  # (2, num_graphs*E)

        # 3) GraphLayer
        if self.use_embedding and (embedding is not None):
            emb = embedding.repeat(N * T, 1)  # (N*T*V, emb_dim) — tùy kiểu embedding
            x_gc = self.graph_conv(x_nodes, batch_edge_index, emb)  # (N*T*V, hidden_dim*heads)
        else:
            x_gc = self.graph_conv(x_nodes, batch_edge_index)      # (N*T*V, hidden_dim*heads)

        # 4) Reshape → (N, V, hidden_dim*heads, T)
        x_gc = x_gc.view(N, T, V, -1).permute(0, 2, 3, 1).contiguous()  # (N, V, H, T)
        x_gc = x_gc.view(N * V, -1, T)  # (N*V, H, T)

        # 5) TCN 1D theo chiều T
        x_tcn = self.tcn(x_gc)  # (N*V, H, T)
        feat_last = x_tcn[:, :, -1]  # lấy cuối cùng → (N*V, H)

        # 6) FC → (N*V, C) → reshape về (N, V, C) → (N, C, V)
        pred_flat = self.fc_out(feat_last)  # (N*V, C)
        pred = pred_flat.view(N, V, C)      # (N, V, C)
        pred = pred.permute(0, 2, 1).contiguous()  # (N, C, V)
        return pred


Overwriting /kaggle/working/DCMD-main/models/gcae/forecasting_gat.py


In [120]:
%%writefile /kaggle/working/DCMD-main/models/common/components.py
import torch
import torch.nn as nn
from typing import List, Tuple, Union

# Nếu bạn có STSGCN gốc trong stsgcn.py, import vào đây (nếu không có, đặt STSGCN = None)
try:
    from models.gcae.stsgcn import STSGCN
except ImportError:
    STSGCN = None

# Import Enhanced components
try:
    from models.gcae.enhanced_stgcn import EnhancedEncoder, EnhancedDCMDLoss
except ImportError:
    EnhancedEncoder = None
    EnhancedDCMDLoss = None

# Import module forecasting (GATForecast) từ forecasting_gat.py
try:
    from models.gcae.forecasting_gat import GATForecast
except ImportError:
    GATForecast = None


def build_encoder(cfg):
    """
    Tạo encoder thích hợp:
    - Nếu dùng adaptive + jigsaw thì dùng EnhancedEncoder.
    - Nếu không thì dùng encoder gốc (bạn có thể định nghĩa riêng, hoặc mặc định trả về None).
    """
    if getattr(cfg, 'use_adaptive', False):
        from models.gcae.enhanced_stgcn import EnhancedEncoder  # đảm bảo bạn có file này
        return EnhancedEncoder(
            input_dim=cfg.input_dim,
            layer_channels=cfg.layer_channels,
            hidden_dimension=cfg.h_dim,
            n_frames=cfg.n_frames,
            n_joints=cfg.n_joints,
            dropout=cfg.dropout,
            emb_dim=getattr(cfg, 'emb_dim', None),
            use_adaptive=True,
            use_jigsaw=getattr(cfg, 'use_jigsaw', False)
        )
    else:
        from models.gcae.stsgcn import STSGCN
        return STSGCN(
            input_dim=cfg.input_dim,
            layer_channels=cfg.layer_channels,
            hidden_dimension=cfg.h_dim,
            n_frames=cfg.n_frames,
            n_joints=cfg.n_joints,
            dropout=cfg.dropout
        )


def build_decoder(cfg):
    """
    Nếu bạn có decoder gốc (ví dụ STSGCN-based decoder), import và trả về ở đây.
    Nếu không, trả None.
    """
    return None


def build_loss(cfg):
    """
    Nếu cfg.use_adaptive (hoặc cfg.use_jigsaw) ⇒ trả EnhancedDCMDLoss,
    ngược lại ⇒ trả nn.MSELoss().
    """
    if getattr(cfg, 'use_adaptive', False):
        if EnhancedDCMDLoss is None:
            raise ImportError("EnhancedDCMDLoss không tìm thấy tại models/gcae/enhanced_stgcn.py")
        return EnhancedDCMDLoss(
            main_loss_weight=cfg.loss_1_prior_weight,    # hoặc cfg.main_loss_weight do bạn định nghĩa
            puzzle_loss_weight=cfg.loss_2_prior_weight   # hoặc cfg.puzzle_loss_weight
        )
    else:
        return nn.MSELoss()


def build_forecasting(cfg):
    """
    Tạo và trả GATForecast nếu cfg.use_forecasting == True.
    """
    if GATForecast is None:
        raise ImportError("GATForecast không tìm thấy tại models/gcae/forecasting_gat.py")
    return GATForecast(
        in_channels=cfg.channels,
        num_nodes=cfg.n_joints,
        hidden_dim=cfg.hidden_dim_forecast,
        heads=cfg.num_heads_gat,
        dropout=cfg.dropout_forecast,
        use_embedding=False
    )


class Encoder(nn.Module):
    """
    Lớp STS-GCN gốc, dùng nếu bạn có stsgcn.ST_GCNN_layer.
    """
    def __init__(self,
                 input_dim: int,
                 layer_channels: List[int],
                 hidden_dimension: int,
                 n_frames: int,
                 n_joints: int,
                 dropout: float,
                 bias: bool = True) -> None:
        super().__init__()

        try:
            import models.gcae.stsgcn as stsgcn
        except ImportError:
            stsgcn = None

        self.input_dim = input_dim
        self.layer_channels = layer_channels
        self.hidden_dimension = hidden_dimension
        self.n_frames = n_frames
        self.n_joints = n_joints
        self.dropout = dropout
        self.bias = bias

        self.build_model()

    def build_model(self):
        try:
            import models.gcae.stsgcn as stsgcn
        except ImportError:
            raise ImportError("Không tìm thấy stsgcn.py hoặc ST_GCNN_layer")

        input_channels = self.input_dim
        layer_channels = self.layer_channels + [self.hidden_dimension]
        kernel_size = [1, 1]
        stride = 1
        model_layers = nn.ModuleList()
        for channels in layer_channels:
            model_layers.append(
                stsgcn.ST_GCNN_layer(
                    in_channels=input_channels,
                    out_channels=channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    time_dim=self.n_frames,
                    joints_dim=self.n_joints,
                    dropout=self.dropout,
                    bias=self.bias
                )
            )
            input_channels = channels

        self.model_layers = model_layers

    def forward(self, X: torch.Tensor, t: torch.Tensor = None) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        layers_out = [X]
        for layer in self.model_layers:
            out_X = layer(layers_out[-1], t)
            layers_out.append(out_X)
        return layers_out[-1], layers_out[:-1]


class Decoder(nn.Module):
    """
    Lớp Decoder gốc (STS-GCN) nếu dùng.
    """
    def __init__(self,
                 output_dim: int,
                 layer_channels: List[int],
                 hidden_dimension: int,
                 n_frames: int,
                 n_joints: int,
                 dropout: float,
                 bias: bool = True) -> None:
        super().__init__()

        try:
            import models.gcae.stsgcn as stsgcn
        except ImportError:
            stsgcn = None

        self.output_dim = output_dim
        self.layer_channels = layer_channels[::-1]
        self.hidden_dimension = hidden_dimension
        self.n_frames = n_frames
        self.n_joints = n_joints
        self.dropout = dropout
        self.bias = bias

        self.build_model()

    def build_model(self):
        try:
            import models.gcae.stsgcn as stsgcn
        except ImportError:
            raise ImportError("Không tìm thấy stsgcn.py hoặc ST_GCNN_layer")

        input_channels = self.hidden_dimension
        layer_channels = self.layer_channels + [self.output_dim]
        kernel_size = [1, 1]
        stride = 1
        model_layers = nn.ModuleList()
        for channels in layer_channels:
            model_layers.append(
                stsgcn.ST_GCNN_layer(
                    in_channels=input_channels,
                    out_channels=channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    time_dim=self.n_frames,
                    joints_dim=self.n_joints,
                    dropout=self.dropout,
                    bias=self.bias
                )
            )
            input_channels = channels

        self.model_layers = model_layers

    def forward(self, X: torch.Tensor, t: torch.Tensor = None) -> torch.Tensor:
        for layer in self.model_layers:
            X = layer(X, t)
        return X


class DecoderResiduals(Decoder):
    """
    Lớp Decoder có residual connections, nếu cần.
    """
    def build_model(self) -> None:
        super().build_model()
        self.out = nn.Linear(self.n_frames, self.n_frames)

    def forward(self, X: torch.Tensor, t: torch.Tensor, residuals: List[torch.Tensor]) -> torch.Tensor:
        for layer in self.model_layers:
            out_X = layer(X, t)
            X = out_X + residuals.pop()
        X = self.out(X.permute(0, 1, 3, 2).contiguous()).permute(0, 1, 3, 2).contiguous()
        return X


class Denoiser(nn.Module):
    """
    Mô hình denoiser gốc cho diffusion.
    """
    def __init__(self,
                 input_size: int,
                 hidden_sizes: List[int],
                 cond_size: int = None,
                 bias: bool = True,
                 device: Union[str, torch.device] = 'cpu') -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_sizes = hidden_sizes
        self.cond_size = cond_size
        self.embedding_dim = self.cond_size
        self.bias = bias
        self.device = device

        self.build_model()

    def build_model(self) -> None:
        self.net = nn.ModuleList()
        self.cond_layers = nn.ModuleList() if self.cond_size is not None else None
        n_layers = len(self.hidden_sizes)
        input_size = self.input_size
        for idx, next_dim in enumerate(self.hidden_sizes):
            if self.cond_size is not None:
                self.cond_layers.append(nn.Linear(self.cond_size, next_dim, bias=self.bias))
            if idx == n_layers - 1:
                self.net.append(nn.Linear(input_size, next_dim, bias=self.bias))
            else:
                self.net.append(nn.Sequential(
                    nn.Linear(input_size, next_dim, bias=self.bias),
                    nn.BatchNorm1d(next_dim),
                    nn.ReLU(inplace=True)
                ))
                input_size = next_dim

    def pos_encoding(self, t: torch.Tensor, channels: int) -> torch.Tensor:
        inv_freq = 1.0 / (
            10000 ** (torch.arange(0, channels, 2, device=self.device).float() / channels)
        ).to(t.device)
        pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
        pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
        pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
        return pos_enc

    def forward(self, X: torch.Tensor, t: torch.Tensor, cond: torch.Tensor = None) -> torch.Tensor:
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.embedding_dim)
        if cond is not None:
            cond = t + cond
        else:
            cond = t
        for i in range(len(self.net)):
            X = self.net[i](X)
            if cond is not None:
                X = X + self.cond_layers[i](cond)
        return X


Overwriting /kaggle/working/DCMD-main/models/common/components.py


In [108]:
%%writefile /kaggle/working/DCMD-main/models/dcmd.py
import argparse
import os
from math import prod
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt

from models.stsae.stsae import STSAE
from sklearn.metrics import roc_curve, roc_auc_score
from torch.optim import Adam
from tqdm import tqdm

from utils.diffusion_utils import Diffusion
from utils.eval_utils import (compute_var_matrix, filter_vectors_by_cond,
                              get_avenue_mask, get_hr_ubnormal_mask, pad_scores, score_process)
from utils.model_utils import processing_data, my_kl_loss
from models.transformer import MotionTransformer
from utils.tools import get_dct_matrix, generate_pad, padding_traj


class DCMD(pl.LightningModule):

    losses = {'l1': nn.L1Loss, 'smooth_l1': nn.SmoothL1Loss, 'mse': nn.MSELoss}
    conditioning_strategies = {'inject': 'inject'}

    def __init__(self, args: argparse.Namespace) -> None:
        """
        This class implements DCMD model.

        Args:
            args (argparse.Namespace): arguments containing các hyperparameters
        """
        super(DCMD, self).__init__()

        ## Log hyperparameters
        self.save_hyperparameters(args)

        ## Thiết lập biến nội bộ
        self.n_frames = args.seg_len
        self.num_coords = args.num_coords
        self.n_joints = self._infer_number_of_joint(args)

        self.dropout = args.dropout
        self.conditioning_strategy = self.conditioning_strategies[args.conditioning_strategy]
        self.cond_h_dim = args.h_dim
        self.cond_latent_dim = args.latent_dim
        self.cond_channels = args.channels
        self.cond_dropout = args.dropout

        self.learning_rate = args.opt_lr
        self.loss_fn = self.losses[args.loss_fn](reduction='none')
        self.noise_steps = args.noise_steps
        self.aggregation_strategy = args.aggregation_strategy
        self.n_generated_samples = args.n_generated_samples
        self.model_return_value = args.model_return_value
        self.gt_path = args.gt_path
        self.split = args.split
        self.use_hr = args.use_hr
        self.ckpt_dir = args.ckpt_dir
        self.save_tensors = args.save_tensors

        # SỬA: args.num_transform (không phải num_transforms)
        self.num_transforms = args.num_transform
        self.anomaly_score_pad_size = args.pad_size
        self.anomaly_score_filter_kernel_size = args.filter_kernel_size
        self.anomaly_score_frames_shift = args.frames_shift
        self.dataset_name = args.dataset_choice

        self.n_his = args.n_his
        self.padding = args.padding
        self.num_layers = args.num_layers
        self.num_heads = args.num_heads
        self.latent_dims = args.latent_dims
        self.automatic_optimization = False
        self.loss_1_series_weight = args.loss_1_series_weight
        self.loss_1_prior_weight = args.loss_1_prior_weight
        self.loss_2_series_weight = args.loss_2_series_weight
        self.loss_2_prior_weight = args.loss_2_prior_weight
        self.idx_pad, self.zero_index = generate_pad(self.padding, self.n_his, self.n_frames - self.n_his)

        ## Khởi noise scheduler
        self._set_diffusion_variables()

        # Forecasting MLP: thay args.in_channels → args.channels
        if args.use_forecasting:
            self.fct_forecasting = nn.Sequential(
                nn.Linear(args.channels[-1] * self.n_joints, args.latent_dim_forecast),
                nn.ReLU(),
            )

        ## Xây dựng model
        self.build_model()

    def build_model(self) -> None:
        """
        Build model theo các hyperparameters.
        """
        args = self.hparams

        # Prediction Model (Transformer)
        pre_model = MotionTransformer(
            input_feats=2 * self.n_joints,
            num_frames=self.n_frames,
            num_layers=self.num_layers,
            num_heads=self.num_heads,
            latent_dim=self.latent_dims,
            dropout=self.dropout,
            device=self.device,
            inject_condition=(self.conditioning_strategy == 'inject')
        )

        # Reconstruction Model (STSAE)
        default_layers = getattr(args, 'layer_channels', [128, 64, 128])
        default_use_adaptive = getattr(args, 'use_adaptive', False)
        default_use_jigsaw = getattr(args, 'use_jigsaw', False)
        self.rec_model = STSAE(
            c_in=args.channels,                  # dùng args.channels
            h_dim=args.h_dim,
            latent_dim=args.latent_dim,
            n_frames=args.seg_len,
            n_joints=self.n_joints,
            layer_channels=args.layer_channels,
            dropout=args.dropout,
            device=self.device,
            use_adaptive=getattr(args, 'use_adaptive', False),
            use_jigsaw = getattr(args, 'use_jigsaw', False),
            emb_dim = args.emb_dim

        )

        self.pre_model = pre_model
        self.rec_model = self.rec_model

        # Loss function: nếu dùng adaptive ⇒ EnhancedDCMDLoss, else MSELoss
        from models.common.components import build_loss
        self.criterion = build_loss(args)

    def forward(self, input_data: List[torch.Tensor], aggr_strategy: str = None, return_: str = None) -> List[torch.Tensor]:
        """
        Forward pass of the model.
        """
        tensor_data, meta_out = self._unpack_data(input_data)
        B = tensor_data.shape[0]

        history_data = tensor_data[:, :, :self.n_his, :]
        x_0 = padding_traj(history_data, self.padding, self.idx_pad, self.zero_index)

        generated_xs = []
        for _ in range(self.n_generated_samples):

            # Reconstruction
            condition_embedding, rec_his_data = self.rec_model(history_data)

            # Prediction (diffusion)
            dct_m, idct_m = get_dct_matrix(self.n_frames)
            dct_m_all = dct_m.float().to(self.device)
            idct_m_all = idct_m.float().to(self.device)
            x = x_0.permute(0, 2, 3, 1).contiguous()
            x = x.reshape([x.shape[0], self.n_frames, -1])
            y = torch.matmul(dct_m_all, x)

            y_d = torch.randn_like(y, device=self.device)
            for i in reversed(range(1, self.noise_steps)):
                t = torch.full(size=(B,), fill_value=i, dtype=torch.long, device=self.device)
                t_prev = torch.full(size=(B,), fill_value=i, dtype=torch.long, device=self.device)
                t_prev[0] = 0
                noise_pre = torch.randn_like(y_d, device=self.device) if i > 1 else torch.zeros_like(y_d, device=self.device)

                predicted_noise_pre, series, prior, _ = self.pre_model(y_d, t, condition_data=condition_embedding)
                alpha_pre = self._alpha[t][:, None, None]
                alpha_hat_pre = self._alpha_hat[t][:, None, None]
                beta_pre = self._beta[t][:, None, None]
                y_d = (1 / torch.sqrt(alpha_pre)) * (y_d - ((1 - alpha_pre) / (torch.sqrt(1 - alpha_hat_pre))) * predicted_noise_pre) \
                    + torch.sqrt(beta_pre) * noise_pre

                alpha_hat_prev = self._alpha_hat[t_prev][:, None, None]
                y_n = (torch.sqrt(alpha_hat_prev) * y) + (torch.sqrt(1 - alpha_hat_prev) * noise_pre)

                mask = torch.zeros_like(x, device=self.device)
                for m in range(0, self.n_his):
                    mask[:, m, :] = 1

                y_d_idct = torch.matmul(idct_m_all, y_d)
                y_n_idct = torch.matmul(idct_m_all, y_n)
                m_mul_y_n = torch.mul(mask, y_n_idct)
                m_mul_y_d = torch.mul((1 - mask), y_d_idct)
                m_y = m_mul_y_d + m_mul_y_n
                y_d = torch.matmul(dct_m_all, m_y)

            pre_future_data = torch.matmul(idct_m_all, y_d)
            pre_future_data = pre_future_data.reshape(pre_future_data.shape[0], pre_future_data.shape[1], -1, 2)
            pre_future_data = pre_future_data.permute(0, 3, 1, 2).contiguous()
            pre_future_data = pre_future_data[:, :, self.n_his:, :]

            xs = torch.cat((rec_his_data, pre_future_data), dim=2)
            generated_xs.append(xs)

        selected_x, loss_of_selected_x = self._aggregation_strategy(generated_xs, tensor_data, aggr_strategy)
        return self._pack_out_data(selected_x, loss_of_selected_x, [tensor_data] + meta_out, return_=return_)

    def training_step(self, batch: List[torch.Tensor], batch_idx: int) -> torch.float32:
        opt = self.optimizers()
        tensor_data, _ = self._unpack_data(batch)
        history_data = tensor_data[:, :, :self.n_his, :]
        x_0 = tensor_data

        # Reconstruction Loss
        condition_embedding, rec_his_data = self.rec_model(history_data)
        rec_loss = torch.mean(self.loss_fn(rec_his_data, history_data))
        self.log('rec_loss', rec_loss)

        # Forecasting (nếu bật)
        if self.hparams.use_forecasting:
            # Dùng mạng forecasting để dự báo next frame từ history_data
            from models.common.components import build_forecasting
            if not hasattr(self, 'forecasting'):
                self.forecasting = build_forecasting(self.hparams)
            forecast_pred = self.forecasting(history_data, self.adjacency_matrix)  # (B, C, V)
            flat = forecast_pred.view(forecast_pred.size(0), -1)  # (B, C*V)
            forecast_emb = self.fct_forecasting(flat)  # (B, latent_dim_forecast)
            gt_next = history_data[:, :, -1, :]  # last frame in history
            loss_forecast = torch.mean(self.loss_fn(forecast_pred, gt_next))
            self.log('loss_forecast', loss_forecast)
        else:
            forecast_emb = None
            loss_forecast = 0.0

        # Prediction Branch (diffusion)
        dct_m, _ = get_dct_matrix(self.n_frames)
        dct_m_all = dct_m.float().to(self.device)
        x = x_0.permute(0, 2, 3, 1).contiguous()
        x = x.reshape([x.shape[0], self.n_frames, -1])
        y_0 = torch.matmul(dct_m_all, x)

        t = self.noise_scheduler.sample_timesteps(y_0.shape[0]).to(self.device)
        y_t, pre_noise = self.noise_scheduler.noise_motion(y_0, t)

        if forecast_emb is not None:
            cond_data = torch.cat([condition_embedding, forecast_emb], dim=-1)
        else:
            cond_data = condition_embedding

        pre_predicted_noise, series, prior, _ = self.pre_model(y_t, t, condition_data=cond_data)

        series_loss = 0.0
        prior_loss = 0.0
        for u in range(len(prior)):
            series_loss += \
                (torch.mean(my_kl_loss(
                    series[u],
                    (prior[u] / torch.unsqueeze(torch.sum(prior[u], dim=-1), dim=-1).repeat(1, 1, 1, self.n_frames)).detach()))
                 + torch.mean(my_kl_loss(
                    (prior[u] / torch.unsqueeze(torch.sum(prior[u], dim=-1), dim=-1).repeat(1, 1, 1, self.n_frames)).detach(),
                    series[u])))
            prior_loss += \
                (torch.mean(my_kl_loss(
                    (prior[u] / torch.unsqueeze(torch.sum(prior[u], dim=-1), dim=-1).repeat(1, 1, 1, self.n_frames)),
                    series[u].detach()))
                 + torch.mean(my_kl_loss(
                    series[u].detach(),
                    (prior[u] / torch.unsqueeze(torch.sum(prior[u], dim=-1), dim=-1).repeat(1, 1, 1, self.n_frames)))))

        series_loss = series_loss / len(prior)
        prior_loss = prior_loss / len(prior)

        pre_loss = torch.mean(self.loss_fn(pre_predicted_noise, pre_noise))
        self.log('pre_loss', pre_loss)

        loss1 = rec_loss + pre_loss \
                - self.loss_1_series_weight * series_loss \
                + self.loss_1_prior_weight * prior_loss
        self.log('loss1', loss1)
        loss2 = rec_loss + pre_loss \
                + self.loss_2_prior_weight * prior_loss \
                + self.loss_2_series_weight * series_loss
        self.log('loss2', loss2)

        self.manual_backward(loss1, retain_graph=True)
        self.manual_backward(loss2)
        opt.step()
        opt.zero_grad()

    def test_step(self, batch: List[torch.Tensor], batch_idx: int) -> None:
        self._test_output_list.append(self.forward(batch))
        return

    def on_test_epoch_start(self) -> None:
        super().on_test_epoch_start()
        self._test_output_list = []
        return

    def on_test_epoch_end(self) -> float:
        out, gt_data, trans, meta, frames = processing_data(self._test_output_list)
        del self._test_output_list
        if self.save_tensors:
            tensors = {'prediction': out, 'gt_data': gt_data,
                       'trans': trans, 'metadata': meta, 'frames': frames}
            self._save_tensors(tensors, split_name=self.split, aggr_strategy=self.aggregation_strategy, n_gen=self.n_generated_samples)
        auc_score = self.post_processing(out, gt_data, trans, meta, frames)
        self.log('AUC', auc_score)
        return auc_score

    def validation_step(self, batch: List[torch.Tensor], batch_idx: int) -> None:
        self._validation_output_list.append(self.forward(batch))
        return

    def on_validation_epoch_start(self) -> None:
        super().on_validation_epoch_start()
        self._validation_output_list = []
        return

    def on_validation_epoch_end(self) -> float:
        out, gt_data, trans, meta, frames = processing_data(self._validation_output_list)
        del self._validation_output_list
        if self.save_tensors:
            tensors = {'prediction': out, 'gt_data': gt_data,
                       'trans': trans, 'metadata': meta, 'frames': frames}
            self._save_tensors(tensors, split_name=self.split, aggr_strategy=self.aggregation_strategy, n_gen=self.n_generated_samples)
        auc_score = self.post_processing(out, gt_data, trans, meta, frames)
        self.log('AUC', auc_score, sync_dist=True)
        return auc_score

    def configure_optimizers(self) -> Dict:
        optimizer = Adam(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99, last_epoch=-1)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'AUC'}

    def post_processing(self, out: np.ndarray, gt_data: np.ndarray, trans: np.ndarray, meta: np.ndarray, frames: np.ndarray) -> float:
        all_gts = [file_name for file_name in os.listdir(self.gt_path) if file_name.endswith('.npy')]
        all_gts = sorted(all_gts)
        scene_clips = [(int(fn.split('_')[0]), int(fn.split('_')[1].split('.')[0])) for fn in all_gts]
        hr_ubnormal_masked_clips = get_hr_ubnormal_mask(self.split) if (self.use_hr and (self.dataset_name == 'UBnormal')) else {}
        hr_avenue_masked_clips = get_avenue_mask() if self.dataset_name == 'HR-Avenue' else {}

        num_transform = self.num_transforms
        model_scores_transf = {}
        dataset_gt_transf = {}

        for transformation in tqdm(range(num_transform)):
            dataset_gt = []
            model_scores = []
            cond_transform = (trans == transformation)

            out_transform, gt_data_transform, meta_transform, frames_transform = filter_vectors_by_cond(
                [out, gt_data, meta, frames], cond_transform)

            for idx in range(len(all_gts)):
                scene_idx, clip_idx = scene_clips[idx]
                gt = np.load(os.path.join(self.gt_path, all_gts[idx]))
                n_frames = gt.shape[0]

                cond_scene_clip = (meta_transform[:, 0] == scene_idx) & (meta_transform[:, 1] == clip_idx)
                out_scene_clip, gt_scene_clip, meta_scene_clip, frames_scene_clip = filter_vectors_by_cond(
                    [out_transform, gt_data_transform, meta_transform, frames_transform],
                    cond_scene_clip)

                figs_ids = sorted(list(set(meta_scene_clip[:, 2])))
                error_per_person = []
                error_per_person_max_loss = []

                for fig in figs_ids:
                    cond_fig = (meta_scene_clip[:, 2] == fig)
                    out_fig, _, frames_fig = filter_vectors_by_cond(
                        [out_scene_clip, gt_scene_clip, frames_scene_clip], cond_fig)
                    loss_matrix = compute_var_matrix(out_fig, frames_fig, n_frames)
                    fig_reconstruction_loss = np.nanmax(loss_matrix, axis=0)
                    if self.anomaly_score_pad_size != -1:
                        fig_reconstruction_loss = pad_scores(fig_reconstruction_loss, gt, self.anomaly_score_pad_size)

                    error_per_person.append(fig_reconstruction_loss)
                    error_per_person_max_loss.append(max(fig_reconstruction_loss))

                clip_score = np.stack(error_per_person, axis=0)
                clip_score_log = np.log1p(clip_score)
                clip_score = np.mean(clip_score, axis=0) + (np.amax(clip_score_log, axis=0) - np.amin(clip_score_log, axis=0))

                if (scene_idx, clip_idx) in hr_ubnormal_masked_clips:
                    clip_score = clip_score[hr_ubnormal_masked_clips[(scene_idx, clip_idx)]]
                    gt = gt[hr_ubnormal_masked_clips[(scene_idx, clip_idx)]]

                if clip_idx in hr_avenue_masked_clips:
                    clip_score = clip_score[np.array(hr_avenue_masked_clips[clip_idx]) == 1]
                    gt = gt[np.array(hr_avenue_masked_clips[clip_idx]) == 1]

                clip_score = score_process(clip_score, self.anomaly_score_frames_shift, self.anomaly_score_filter_kernel_size)
                model_scores.append(clip_score)
                dataset_gt.append(gt)

            model_scores = np.concatenate(model_scores, axis=0)
            dataset_gt = np.concatenate(dataset_gt, axis=0)
            model_scores_transf[transformation] = model_scores
            dataset_gt_transf[transformation] = dataset_gt

        pds = np.mean(np.stack(list(model_scores_transf.values()), 0), 0)
        gt = dataset_gt_transf[0]
        auc = roc_auc_score(gt, pds)
        return auc

    def test_on_saved_tensors(self, split_name: str) -> float:
        tensors = self._load_tensors(split_name, self.aggregation_strategy, self.n_generated_samples)
        auc_score = self.post_processing(tensors['prediction'], tensors['gt_data'], tensors['trans'],
                                         tensors['metadata'], tensors['frames'])
        print(f'AUC score: {auc_score:.6f}')
        return auc_score

    ## Helper functions

    def _aggregation_strategy(self, generated_xs: List[torch.Tensor], input_sequence: torch.Tensor, aggr_strategy: str) -> Tuple[torch.Tensor]:
        aggr_strategy = self.aggregation_strategy if aggr_strategy is None else aggr_strategy
        if aggr_strategy == 'random':
            return generated_xs[np.random.randint(len(generated_xs))], None

        B, repr_shape = input_sequence.shape[0], input_sequence.shape[1:]
        compute_loss = lambda x: torch.mean(self.loss_fn(x, input_sequence).reshape(-1, prod(repr_shape)), dim=-1)
        losses = [compute_loss(x) for x in generated_xs]

        if aggr_strategy == 'all':
            dims_idxs = list(range(2, len(repr_shape) + 2))
            dims_idxs = [1, 0] + dims_idxs
            selected_x = torch.stack(generated_xs).permute(*dims_idxs)
            loss_of_selected_x = torch.stack(losses).permute(1, 0)
        elif aggr_strategy == 'mean':
            selected_x = None
            loss_of_selected_x = torch.mean(torch.stack(losses), dim=0)
        elif aggr_strategy == 'mean_pose':
            selected_x = torch.mean(torch.stack(generated_xs), dim=0)
            loss_of_selected_x = compute_loss(selected_x)
        elif aggr_strategy == 'median':
            loss_of_selected_x, _ = torch.median(torch.stack(losses), dim=0)
            selected_x = None
        elif aggr_strategy == 'median_pose':
            selected_x, _ = torch.median(torch.stack(generated_xs), dim=0)
            loss_of_selected_x = compute_loss(selected_x)
        elif aggr_strategy == 'best' or aggr_strategy == 'worst':
            strategy = (lambda x, y: x < y) if aggr_strategy == 'best' else (lambda x, y: x > y)
            loss_of_selected_x = torch.full((B,), fill_value=(1e10 if aggr_strategy == 'best' else -1.), device=self.device)
            selected_x = torch.zeros((B, *repr_shape)).to(self.device)
            for i in range(len(generated_xs)):
                mask = strategy(losses[i], loss_of_selected_x)
                loss_of_selected_x[mask] = losses[i][mask]
                selected_x[mask] = generated_xs[i][mask]
        elif 'quantile' in aggr_strategy:
            q = float(aggr_strategy.split(':')[-1])
            loss_of_selected_x = torch.quantile(torch.stack(losses), q, dim=0)
            selected_x = None
        else:
            raise ValueError(f'Unknown aggregation strategy {aggr_strategy}')

        if selected_x is None and loss_of_selected_x is None:
            loss_of_selected_x = torch.mean(torch.stack(losses), dim=0)

        return selected_x, loss_of_selected_x

    def _infer_number_of_joint(self, args: argparse.Namespace) -> int:
        if args.headless:
            joints_to_consider = 14
        elif args.kp18_format:
            joints_to_consider = 18
        else:
            joints_to_consider = 17
        return joints_to_consider

    def _load_tensors(self, split_name: str, aggr_strategy: str, n_gen: int) -> Dict[str, torch.Tensor]:
        name = 'saved_tensors_{}_{}_{}'.format(split_name, aggr_strategy, n_gen)
        path = os.path.join(self.ckpt_dir, name)
        if not os.path.exists(path):
            os.mkdir(path)
        tensor_files = os.listdir(path)
        tensors = {}
        for t_file in tensor_files:
            t_name = t_file.split('.')[0]
            tensors[t_name] = torch.load(os.path.join(path, t_file))
        return tensors

    def _pack_out_data(self, selected_x: torch.Tensor, loss_of_selected_x: torch.Tensor, additional_out: List[torch.Tensor], return_: str) -> List[torch.Tensor]:
        if return_ is None:
            if self.model_return_value is None:
                raise ValueError('Either return_ or self.model_return_value must be set')
            else:
                return_ = self.model_return_value

        if return_ == 'poses':
            out = [selected_x]
        elif return_ == 'loss':
            out = [loss_of_selected_x]
        elif return_ == 'all':
            out = []
            if loss_of_selected_x is not None:
                out.append(loss_of_selected_x)
            if selected_x is not None:
                out.append(selected_x)

        return out + additional_out

    def _save_tensors(self, tensors: Dict[str, torch.Tensor], split_name: str, aggr_strategy: str, n_gen: int) -> None:
        name = 'saved_tensors_{}_{}_{}'.format(split_name, aggr_strategy, n_gen)
        path = os.path.join(self.ckpt_dir, name)
        if not os.path.exists(path):
            os.mkdir(path)
        for t_name, tensor in tensors.items():
            torch.save(tensor, os.path.join(path, t_name + '.pt'))

    def _set_diffusion_variables(self) -> None:
        self.noise_scheduler = Diffusion(noise_steps=self.noise_steps, n_joints=self.n_joints,
                                         device=self.device, time=self.n_frames)
        self._beta_ = self.noise_scheduler.schedule_noise()
        self._alpha_ = (1. - self._beta_)
        self._alpha_hat_ = torch.cumprod(self._alpha_, dim=0)

    def _unpack_data(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        tensor_data = x[0].to(self.device)
        transformation_idx = x[1]
        metadata = x[2]
        actual_frames = x[3]
        meta_out = [transformation_idx, metadata, actual_frames]
        return tensor_data, meta_out

    @property
    def _beta(self) -> torch.Tensor:
        return self._beta_.to(self.device)

    @property
    def _alpha(self) -> torch.Tensor:
        return self._alpha_.to(self.device)

    @property
    def _alpha_hat(self) -> torch.Tensor:
        return self._alpha_hat_.to(self.device)


Overwriting /kaggle/working/DCMD-main/models/dcmd.py
