In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for root_dir, _subdirs, file_names in os.walk('/kaggle/input'):
    # Print the full path of every file found under the Kaggle input mount.
    for file_name in file_names:
        print(os.path.join(root_dir, file_name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/dgraph-fin/Readme.md
/kaggle/input/dgraph-fin/dgraphfin.npz


In [2]:
!pip install pygod torch_geometric

Collecting pygod
  Downloading pygod-1.1.0-py3-none-any.whl.metadata (15 kB)
Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
Downloading pygod-1.1.0-py3-none-any.whl (86 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m26.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric, pygod
Successfully installed pygod-1.1.0 torch_geometric-2.6.1


In [3]:
import os
import math
import numpy as np
from datetime import datetime
from typing import Callable, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
from torch import Tensor
from torch_geometric.datasets import DGraphFin
from torch_geometric.transforms import BaseTransform
from torch_geometric.nn import GCN
from torch_geometric.utils import to_dense_adj

import tqdm

# NOTE: Ensure that the following modules are available in your Kaggle environment.
from pygod.metric import *
from pygod.metric.metric import *
from pygod.utils import load_data
from pygod.nn.decoder import DotProductDecoder
from pygod.nn.functional import double_recon_loss

####################################
# Set device (Kaggle provides GPU if enabled)
####################################
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

####################################
# Utility Functions & Schedules
####################################
def extract(a, t, x_shape):
    """Pick ``a[t]`` per sample and reshape so it broadcasts against ``x_shape``.

    Args:
        a: 1-D schedule tensor (indexed on CPU, so ``a`` is expected on CPU).
        t: Integer index tensor; its first dimension is the batch size.
        x_shape: Shape of the tensor the result will be multiplied with.

    Returns:
        Tensor of shape ``(len(t), 1, ..., 1)`` with ``len(x_shape) - 1``
        trailing singleton dims, moved to ``t``'s device.
    """
    picked = a.gather(-1, t.cpu())
    target_shape = (t.shape[0],) + (1,) * (len(x_shape) - 1)
    return picked.reshape(*target_shape).to(t.device)

def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return ``timesteps`` noise variances spaced linearly in ``[beta_start, beta_end]``.

    Args:
        timesteps: Number of diffusion steps (length of the returned tensor).
        beta_start: First beta value (default 1e-4, the previously hard-coded value).
        beta_end: Last beta value (default 0.02, the previously hard-coded value).

    Returns:
        1-D tensor produced by ``torch.linspace``.
    """
    return torch.linspace(beta_start, beta_end, timesteps)

# Linear beta schedule plus the derived cumulative terms; sample_free() uses the
# sqrt_* cumulative products to noise latents to an arbitrary timestep.
timesteps = 500
betas = linear_beta_schedule(timesteps=timesteps)
alphas = 1.0 - betas
alphas_cumprod = alphas.cumprod(dim=0)
# Shifted right by one with a leading 1.0 (cumulative product "before" step 0).
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = (1.0 / alphas).sqrt()
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()

####################################
# Graph Autoencoder (Graph_AE)
####################################
class Graph_AE(nn.Module):
    """Graph autoencoder: one shared GNN encoder feeding two decoders.

    The encoder maps node features to ``hid_dim`` embeddings. An attribute
    decoder (same backbone) maps embeddings back to ``in_dim`` features, and
    a ``DotProductDecoder`` reconstructs structure. ``num_layers`` is split
    between the encoder (floor of half) and the decoders (ceil of half; the
    structure decoder gets one layer fewer).

    Args:
        in_dim: Number of input node features.
        hid_dim: Embedding / hidden width.
        num_layers: Total layer budget; must be >= 2.
        dropout: Dropout passed to every sub-network.
        act: Activation passed to every sub-network.
        sigmoid_s: Forwarded to ``DotProductDecoder``.
        backbone: GNN class used to build encoder/decoders (default ``GCN``).
        **kwargs: Extra keyword arguments forwarded to every sub-network.
    """

    def __init__(self,
                 in_dim,
                 hid_dim=64,
                 num_layers=4,
                 dropout=0.,
                 act=torch.nn.functional.relu,
                 sigmoid_s=False,
                 backbone=GCN,
                 **kwargs):
        super().__init__()

        assert num_layers >= 2, "Number of layers must be >= 2."
        n_enc_layers = math.floor(num_layers / 2)
        n_dec_layers = math.ceil(num_layers / 2)

        # Keep this construction order: parameter initialisation draws from
        # the global RNG in the order the sub-networks are built.
        self.shared_encoder = backbone(in_channels=in_dim, hidden_channels=hid_dim,
                                       num_layers=n_enc_layers, out_channels=hid_dim,
                                       dropout=dropout, act=act, **kwargs)
        self.attr_decoder = backbone(in_channels=hid_dim, hidden_channels=hid_dim,
                                     num_layers=n_dec_layers, out_channels=in_dim,
                                     dropout=dropout, act=act, **kwargs)
        self.struct_decoder = DotProductDecoder(in_dim=hid_dim, hid_dim=hid_dim,
                                                num_layers=n_dec_layers - 1,
                                                dropout=dropout, act=act,
                                                sigmoid_s=sigmoid_s,
                                                backbone=backbone, **kwargs)

        self.loss_func = double_recon_loss
        # Last embedding produced by encode(); refreshed on every encode call.
        self.emb = None

    def forward(self, x, edge_index):
        """Encode then decode; returns ``(x_hat, s_hat, embedding)``."""
        embedding = self.encode(x, edge_index)
        self.emb = embedding
        x_hat, s_hat = self.decode(embedding, edge_index)
        return x_hat, s_hat, embedding

    def encode(self, x, edge_index):
        """Run the shared encoder, cache the result on ``self.emb`` and return it."""
        embedding = self.shared_encoder(x, edge_index)
        self.emb = embedding
        return embedding

    def decode(self, emb, edge_index):
        """Return ``(attribute reconstruction, structure reconstruction)`` for ``emb``."""
        return (self.attr_decoder(emb, edge_index),
                self.struct_decoder(emb, edge_index))

####################################
# DiffGAD (combining autoencoder + diffusion models)
####################################
class DiffGAD(BaseTransform):
    """Graph anomaly detector: graph autoencoder + latent diffusion models.

    Calling an instance with a PyGOD dataset name (``model('disney')``):

    1. loads the data with ``pygod.utils.load_data``;
    2. trains a :class:`Graph_AE` and reloads its checkpoint;
    3. repeats ``num_trial`` times: trains an unconditional diffusion model
       (which also estimates a "common feature"), trains a diffusion model
       conditioned on that feature, then scores nodes with guided sampling
       and reports ROC-AUC for every forward-noising timestep.

    The autoencoder is frozen after step 2. Every diffusion model is trained
    and evaluated on eval-mode, gradient-free AE embeddings.

    Args:
        name: Free-form identifier (stored only).
        hid_dim: AE embedding size; default ``2 ** int(log2(num_features) - 1)``.
        diff_dim: Hidden width of the diffusion MLPs; default ``2 * hid_dim``.
        ae_epochs: Autoencoder training epochs.
        diff_epochs: Maximum epochs for each diffusion model.
        patience: Early-stopping patience (epochs without a new best loss).
        lr: Adam learning rate for the diffusion models.
        wd: Adam weight decay for the diffusion models.
        lamda: Guidance weight passed to ``sample_dm_free``.
        sample_steps: Sampler steps; ``<= 0`` skips denoising and decodes the
            noised latents directly.
        radius: Stored; not used by this class.
        ae_dropout: Autoencoder dropout.
        ae_lr: Autoencoder learning rate.
        ae_alpha: Weight passed to the AE reconstruction loss.
        verbose: Print progress messages.
        num_trial: Number of independent diffusion train/evaluate rounds.
    """

    def __init__(self,
                 name="",
                 hid_dim=None,
                 diff_dim=None,
                 ae_epochs=300,
                 diff_epochs=800,
                 patience=100,
                 lr=0.005,
                 wd=0.,
                 lamda=0.0,
                 sample_steps=50,
                 radius=1,
                 ae_dropout=0.3,
                 ae_lr=0.05,
                 ae_alpha=0.8,
                 verbose=True,
                 num_trial=20):

        self.name = name
        self.hid_dim = hid_dim
        self.diff_dim = diff_dim
        self.ae_epochs = ae_epochs
        self.diff_epochs = diff_epochs
        self.patience = patience
        self.lr = lr
        self.wd = wd
        self.sample_steps = sample_steps
        self.verbose = verbose
        self.lamda = lamda
        self.num_trial = num_trial

        self.common_feat = None
        self.dm = None

        self.ae = None
        self.ae_dropout = ae_dropout
        self.ae_lr = ae_lr
        self.ae_alpha = ae_alpha
        self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)
        # Number of forward-noising timesteps probed in sample_free (module-level schedule).
        self.timesteps = timesteps
        self.radius = radius

    def forward(self, dset):
        """Train everything on dataset ``dset`` and evaluate ``num_trial`` times.

        Returns:
            A list with one entry per trial; each entry is the list of
            ``eval_roc_auc`` values, one per forward-noising timestep.
        """
        self.dataset = dset
        data = load_data(self.dataset)

        if self.hid_dim is None:
            self.hid_dim = 2 ** int(math.log2(data.x.size(1)) - 1)
        if self.diff_dim is None:
            self.diff_dim = 2 * self.hid_dim

        self.ae = Graph_AE(in_dim=data.num_node_features,
                           hid_dim=self.hid_dim,
                           dropout=self.ae_dropout).to(device)
        self.save_dir = os.path.join(os.getcwd(), 'models', self.dataset, 'full_batch')
        self.ae_path = os.path.join(self.save_dir, f"{self.ae_dropout}_{self.ae_lr}_{self.ae_alpha}_{self.hid_dim}")
        os.makedirs(self.ae_path, exist_ok=True)

        ######################## Train Autoencoder #######################
        self.train_ae(data)
        ae_dict = torch.load(os.path.join(self.ae_path, 'Graph_AE.pt'), map_location=device)
        self.ae.load_state_dict(ae_dict['state_dict'])

        all_auc = []
        for _ in tqdm.tqdm(range(self.num_trial)):
            # Unconditional diffusion model; reload its best checkpoint together
            # with the common feature saved alongside it.
            denoise_fn = MLPDiffusion(self.hid_dim, self.diff_dim).to(device)
            self.dm = Model(denoise_fn=denoise_fn, hid_dim=self.hid_dim).to(device)
            self.train_dm(data)
            dm_dict = torch.load(os.path.join(self.ae_path, 'edm.pt'), map_location=device)
            self.dm.load_state_dict(dm_dict['state_dict'])
            self.common_feat = dm_dict['common_feat']
            if self.verbose:
                print("Common feature:", self.common_feat)

            # Diffusion model conditioned on the common feature.
            denoise_condition = MLPDiffusion(self.hid_dim, self.diff_dim).to(device)
            self.dm_condition = Model(denoise_fn=denoise_condition, hid_dim=self.hid_dim).to(device)
            self.train_dm_condition(data)
            dm_free_dict = torch.load(os.path.join(self.ae_path, 'conditional_edm.pt'), map_location=device)
            self.dm_condition.load_state_dict(dm_free_dict['state_dict'])

            # Evaluation (guided sampling)
            all_auc.append(self.sample_free(self.dm_condition, self.dm, data))
        return all_auc

    def _encode_frozen(self, data):
        """Encode ``data`` with the AE in eval mode and without an autograd graph.

        Only diffusion-model parameters are optimised after AE training, so the
        AE is treated as frozen. This keeps AE dropout off and stops backward()
        from filling gradients on AE parameters.
        """
        self.ae.eval()
        with torch.no_grad():
            return self.ae.encode(data.x.to(device), data.edge_index.to(device))

    def train_ae(self, data):
        """Train the graph autoencoder and save its final weights to ``Graph_AE.pt``."""
        if self.verbose:
            print('Training autoencoder ...')
        optimizer = torch.optim.Adam(self.ae.parameters(), self.ae_lr, weight_decay=0.01)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)

        # Graph tensors do not change between epochs: move / densify once.
        x = data.x.to(device)
        edge_index = data.edge_index.to(device)
        # max_num_nodes keeps the adjacency N x N even when trailing nodes have no edges.
        s = to_dense_adj(edge_index, max_num_nodes=x.size(0))[0]

        self.ae.train()
        for _epoch in range(1, self.ae_epochs + 1):
            optimizer.zero_grad()
            x_, s_, _ = self.ae(x, edge_index)
            score = self.ae.loss_func(x, x_, s, s_, self.ae_alpha)
            loss = torch.mean(score)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

        # Only the last epoch's weights were ever kept, so save once after training.
        torch.save({'state_dict': self.ae.state_dict()},
                   os.path.join(self.ae_path, 'Graph_AE.pt'))

    def train_dm(self, data):
        """Train the unconditional diffusion model ``self.dm`` with early stopping.

        Also iteratively estimates the common feature. It starts as the mean
        embedding. Each later epoch replaces it with a softmax(cosine
        similarity / 5)-weighted average of the previous epoch's denoised
        embeddings. The best-loss checkpoint (weights + common feature) is
        written to ``edm.pt``.

        Returns:
            The common feature from the last epoch that ran.
        """
        if self.verbose:
            print('Training diffusion model (unconditional) ...')
        optimizer = torch.optim.Adam(self.dm.parameters(), lr=self.lr, weight_decay=self.wd)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
        inputs = self._encode_frozen(data)
        self.dm.train()
        best_loss = float('inf')
        patience = 0
        common_feat = torch.mean(inputs, dim=0)
        reconstructed = None
        for epoch in range(self.diff_epochs):
            if reconstructed is not None:
                s_v = self.cos(common_feat, reconstructed)
                omega = softmax_with_temperature(s_v, t=5).reshape(1, -1)
                common_feat = torch.mm(omega, reconstructed)

            loss, _, reconstructed = self.dm(inputs)
            loss = loss.mean()
            # Only the values are needed for the next common-feature update.
            reconstructed = reconstructed.detach()

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.dm.parameters(), 1.0)
            optimizer.step()
            loss_value = loss.item()  # plain float: don't keep the graph alive via best_loss
            if self.verbose and epoch % 10 == 0:
                print("Epoch: {:04d} loss= {:.5f}".format(epoch, loss_value))
            scheduler.step()

            if loss_value < best_loss:
                best_loss = loss_value
                patience = 0
                torch.save({'state_dict': self.dm.state_dict(),
                            'common_feat': common_feat},
                           os.path.join(self.ae_path, 'edm.pt'))
            else:
                patience += 1
                if patience >= self.patience:
                    if self.verbose:
                        print('Early stopping')
                    break

        return common_feat

    def train_dm_condition(self, data):
        """Train ``self.dm_condition`` conditioned on ``self.common_feat``.

        Uses early stopping; the best-loss weights go to ``conditional_edm.pt``.
        """
        if self.verbose:
            print('Training diffusion model (conditional) ...')
        optimizer = torch.optim.Adam(self.dm_condition.parameters(), lr=self.lr, weight_decay=self.wd)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
        inputs = self._encode_frozen(data)
        self.dm_condition.train()
        best_loss = float('inf')
        patience = 0
        for epoch in range(self.diff_epochs):
            loss, _, _ = self.dm_condition(inputs, common_feat=self.common_feat)
            loss = loss.mean()

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.dm_condition.parameters(), 1.0)
            optimizer.step()
            loss_value = loss.item()
            if self.verbose and epoch % 10 == 0:
                print("Epoch: {:04d} loss= {:.5f}".format(epoch, loss_value))
            scheduler.step()

            if loss_value < best_loss:
                best_loss = loss_value
                patience = 0
                torch.save({'state_dict': self.dm_condition.state_dict()},
                           os.path.join(self.ae_path, 'conditional_edm.pt'))
            else:
                patience += 1
                if patience >= self.patience:
                    if self.verbose:
                        print('Early stopping')
                    break

    def sample_free(self, condition_model, uncondition_model, data):
        """Score nodes at every forward-noising timestep and report ROC-AUC.

        For each timestep ``i``, the AE embeddings are noised to step ``i``,
        denoised with guided sampling (``sample_dm_free``), decoded, and scored
        with the AE reconstruction loss.

        Returns:
            List of ``eval_roc_auc`` values, one per timestep.
        """
        self.ae.eval()
        condition_model.eval()
        uncondition_model.eval()
        condition_net = condition_model.denoise_fn_D
        uncondition_net = uncondition_model.denoise_fn_D
        auc = []
        x = data.x.to(device)
        edge_index = data.edge_index.to(device)
        y = data.y.bool()
        with torch.no_grad():
            Z_0 = self.ae.encode(x, edge_index)
            # Loop-invariant: computed once instead of once per timestep.
            s = to_dense_adj(edge_index, max_num_nodes=x.size(0))[0]
            ###############  forward process  ####################
            noise = torch.randn_like(Z_0)
            for i in range(self.timesteps):
                t = torch.tensor([i] * Z_0.size(0), device=device).long()
                sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, Z_0.shape)
                sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, Z_0.shape)
                Z_t = sqrt_alphas_cumprod_t * Z_0 + sqrt_one_minus_alphas_cumprod_t * noise

                if self.sample_steps > 0:
                    reconstructed = sample_dm_free(condition_net, uncondition_net, Z_t, self.sample_steps,
                                                   common_feat=self.common_feat, lamda=self.lamda)
                else:
                    # No denoising requested: decode the noised latents directly
                    # (previously this path raised NameError).
                    reconstructed = Z_t
                x_, s_ = self.ae.decode(reconstructed, edge_index)
                score = self.ae.loss_func(x, x_, s, s_, self.ae_alpha)

                pyg_auc = eval_roc_auc(y, score.cpu())
                auc.append(pyg_auc)
                print("timestep:{}, pyg_AUC: {:.4f}".format(i, pyg_auc))
        return auc

####################################
# Diffusion Model Components (from diffusion_models.py)
####################################

# Either a name or a callable that builds an nn.Module.
ModuleType = Union[str, Callable[..., nn.Module]]

# Noise-level range of the sampler's sigma schedule and the exponent rho
# used to space the sigma steps (see sample_dm / sample_dm_free).
SIGMA_MIN, SIGMA_MAX = 0.002, 80
rho = 7

# Stochastic "churn": temporary noise increase applied while S_min <= sigma <= S_max.
S_churn, S_min, S_max, S_noise = 1, 0, float('inf'), 1

def softmax_with_temperature(input, t=1, axis=-1):
    """Softmax of ``input / t`` along ``axis``.

    Uses ``torch.softmax``, which subtracts the max before exponentiating, so
    large logits no longer overflow to inf/NaN. It also normalises with the
    reduced axis kept, so 2-D (and higher) inputs are normalised along
    ``axis``. The previous implementation divided by a non-keepdim sum, which
    broadcast along the wrong axis for multi-dimensional input. For 1-D input
    the result equals ``exp(x/t) / sum(exp(x/t))``.

    Args:
        input: Tensor of logits.
        t: Temperature; larger values give a flatter distribution.
        axis: Dimension to normalise over.

    Returns:
        Tensor with the same shape as ``input`` whose slices along ``axis`` sum to 1.
    """
    return torch.softmax(input / t, dim=axis)

class EDMLoss:
    """EDM-style denoising loss that also returns a per-row reconstruction score.

    Noise levels are drawn log-normally: ``sigma = exp(P_mean + P_std * N(0, 1))``,
    one per row of ``data``. ``hid_dim``, ``gamma``, ``opts``, ``cos`` and
    ``KLDiv`` are stored but not used by ``__call__``.
    """

    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5, hid_dim=100,
                 gamma=5, opts=None):
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data
        self.hid_dim = hid_dim
        self.gamma = gamma
        self.opts = opts
        self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)
        self.KLDiv = nn.KLDivLoss(reduction='batchmean')

    def __call__(self, denoise_fn, data, common_feat=None):
        """Noise ``data``, denoise it with ``denoise_fn`` and measure the error.

        Assumes ``data`` is 2-D (rows, features).

        Args:
            denoise_fn: Callable ``(noisy, sigma, common_feat) -> denoised``.
            data: Clean inputs.
            common_feat: Optional conditioning forwarded to ``denoise_fn``.

        Returns:
            ``(loss, score, denoised)``: element-wise squared error weighted by
            ``(sigma^2 + sigma_data^2) / (sigma * sigma_data)^2``, the per-row
            L2 norm of the residual, and the denoiser output.
        """
        sigma = (self.P_mean + self.P_std * torch.randn(data.shape[0], device=data.device)).exp()
        weight = (sigma ** 2 + self.sigma_data ** 2) / ((sigma * self.sigma_data) ** 2)

        noisy = data + torch.randn_like(data) * sigma.unsqueeze(1)
        denoised = denoise_fn(noisy, sigma, common_feat)

        squared_err = (denoised - data) ** 2
        weighted = weight.unsqueeze(1) * squared_err
        row_score = squared_err.sum(1).sqrt()
        return weighted, row_score, denoised

class PositionalEmbedding(nn.Module):
    """Sinusoidal embedding of a 1-D tensor of scalars (e.g. noise labels).

    Output row ``k`` is ``[cos(x_k * f), sin(x_k * f)]``, where the frequencies
    ``f_j = (1 / max_positions) ** (j / d)`` for ``j < num_channels // 2`` and
    ``d = num_channels // 2`` (or ``num_channels // 2 - 1`` when ``endpoint``).
    """

    def __init__(self, num_channels, max_positions=10000, endpoint=False):
        super().__init__()
        self.num_channels = num_channels
        self.max_positions = max_positions
        self.endpoint = endpoint

    def forward(self, x):
        """Map ``x`` of shape ``(N,)`` to ``(N, 2 * (num_channels // 2))``."""
        half = self.num_channels // 2
        divisor = half - 1 if self.endpoint else half
        exponents = torch.arange(half, dtype=torch.float32, device=x.device) / divisor
        freqs = (1 / self.max_positions) ** exponents
        angles = torch.outer(x, freqs.to(x.dtype))
        return torch.cat([angles.cos(), angles.sin()], dim=1)

class MLPDiffusion(nn.Module):
    """MLP denoiser conditioned on a noise label and, optionally, a common feature.

    Args:
        d_in: Size of the input vectors (and of the output).
        dim_t: Hidden width and size of the noise-label embedding.
    """

    def __init__(self, d_in, dim_t=512):
        super().__init__()
        self.dim_t = dim_t
        wide = dim_t * 2
        # Sub-modules are built in the original order so parameter
        # initialisation and state_dict keys stay the same.
        self.proj = nn.Linear(d_in, dim_t)
        self.mlp = nn.Sequential(
            nn.Linear(dim_t, wide), nn.SiLU(),
            nn.Linear(wide, wide), nn.SiLU(),
            nn.Linear(wide, dim_t), nn.SiLU(),
            nn.Linear(dim_t, d_in),
        )
        self.map_noise = PositionalEmbedding(num_channels=dim_t)
        self.time_embed = nn.Sequential(nn.Linear(dim_t, dim_t), nn.SiLU(), nn.Linear(dim_t, dim_t))
        self.feat_proj = nn.Linear(d_in, dim_t)
        # Not used by forward(); kept so checkpoints have the same keys.
        self.head = nn.Sequential(nn.SiLU(), nn.Linear(dim_t, dim_t))

    def forward(self, x, noise_labels, common_feat=None):
        """Predict a ``d_in``-sized output for ``x`` at noise level ``noise_labels``."""
        noise_emb = self.map_noise(noise_labels)
        # map_noise returns [cos | sin]; swap the two halves to [sin | cos].
        rows = noise_emb.shape[0]
        noise_emb = noise_emb.reshape(rows, 2, -1).flip(1).reshape(*noise_emb.shape)
        hidden = self.proj(x) + self.time_embed(noise_emb)
        if common_feat is not None:
            hidden = hidden + self.feat_proj(common_feat)
        return self.mlp(hidden)

class Precond(nn.Module):
    """EDM preconditioning wrapper around a raw denoiser ``denoise_fn``.

    Output: ``c_skip * x + c_out * F(c_in * x, log(sigma) / 4, common_feat)``
    with ``c_skip = sd^2 / (sigma^2 + sd^2)``,
    ``c_out = sigma * sd / sqrt(sigma^2 + sd^2)`` and
    ``c_in = 1 / sqrt(sd^2 + sigma^2)``, where ``sd = sigma_data``.
    """

    def __init__(self, denoise_fn, hid_dim, sigma_min=0, sigma_max=float('inf'), sigma_data=0.5):
        super().__init__()
        self.hid_dim = hid_dim
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data
        self.denoise_fn_F = denoise_fn

    def forward(self, x, sigma, common_feat=None):
        """Return the preconditioned denoised estimate of ``x`` at noise level ``sigma``."""
        f32 = torch.float32
        x = x.to(f32)
        sigma = sigma.to(f32).reshape(-1, 1)
        sd_sq = self.sigma_data ** 2
        sigma_sq = sigma ** 2

        c_skip = sd_sq / (sigma_sq + sd_sq)
        c_out = sigma * self.sigma_data / (sigma_sq + sd_sq).sqrt()
        c_in = 1 / (sd_sq + sigma_sq).sqrt()
        c_noise = sigma.log() / 4

        raw = self.denoise_fn_F((c_in * x).to(f32), c_noise.flatten(), common_feat)
        return c_skip * x + c_out * raw.to(f32)

    def round_sigma(self, sigma):
        """Return ``sigma`` as a tensor (no rounding is applied)."""
        return torch.as_tensor(sigma)

class Model(nn.Module):
    """Trainable diffusion model: a ``Precond``-wrapped denoiser plus ``EDMLoss``.

    Args:
        denoise_fn: Raw denoising network (e.g. ``MLPDiffusion``).
        hid_dim: Embedding size, forwarded to ``Precond`` and ``EDMLoss``.
        P_mean, P_std: Log-normal noise-level parameters of the loss.
        sigma_data: Data scale shared by the loss weighting and the preconditioner.
        gamma, opts: Forwarded to ``EDMLoss``.
        pfgmpp: Accepted for compatibility; unused.
    """

    def __init__(self, denoise_fn, hid_dim, P_mean=-1.2, P_std=1.2, sigma_data=0.5, gamma=5, opts=None, pfgmpp=False):
        super().__init__()
        # sigma_data must reach both halves. Otherwise Precond's c_skip/c_out/c_in
        # would use its own default while the loss weighting used the value passed here.
        self.denoise_fn_D = Precond(denoise_fn, hid_dim, sigma_data=sigma_data)
        self.loss_fn = EDMLoss(P_mean, P_std, sigma_data, hid_dim=hid_dim, gamma=gamma, opts=opts)

    def forward(self, x, common_feat=None):
        """Return ``(scalar loss, per-row score, denoised x)`` for one random noise draw."""
        loss, score, reconstructed = self.loss_fn(self.denoise_fn_D, x, common_feat)
        return loss.mean(-1).mean(), score, reconstructed

def sample_step(net, num_steps, i, t_cur, t_next, x_next, common_feat=None):
    """One stochastic Heun step of the sampler, moving from ``t_cur`` to ``t_next``.

    Noise is first raised to ``t_hat = t_cur * (1 + gamma)`` (churn). Then an
    Euler step is taken, followed by a trapezoidal (2nd-order) correction on
    every step except the last (``i == num_steps - 1``).
    """
    in_churn_window = S_min <= t_cur <= S_max
    gamma = min(S_churn / num_steps, math.sqrt(2) - 1) if in_churn_window else 0
    t_hat = net.round_sigma(t_cur + gamma * t_cur)
    x_hat = x_next + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * torch.randn_like(x_next)

    # Euler predictor.
    d_cur = (x_hat - net(x_hat, t_hat, common_feat).to(torch.float32)) / t_hat
    x_euler = x_hat + (t_next - t_hat) * d_cur
    if i >= num_steps - 1:
        return x_euler

    # 2nd-order corrector.
    d_prime = (x_euler - net(x_euler, t_next, common_feat).to(torch.float32)) / t_next
    return x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

def sample_dm(net, noise, num_steps, common_feat=None):
    """Sample from ``net`` starting at ``noise`` with the ``rho``-spaced sigma schedule.

    Sigmas run from ``min(SIGMA_MAX, net.sigma_max)`` down to
    ``max(SIGMA_MIN, net.sigma_min)`` and then to 0. ``noise`` is scaled by the
    first sigma before sampling.

    Args:
        net: Preconditioned denoiser exposing ``sigma_min``, ``sigma_max`` and ``round_sigma``.
        noise: Starting tensor.
        num_steps: Number of non-zero sigma levels.
        common_feat: Optional conditioning passed to every ``net`` call.

    Returns:
        The final sample, computed under ``torch.no_grad()``.
    """
    step_indices = torch.arange(num_steps, dtype=torch.float32, device=noise.device)
    sigma_min = max(SIGMA_MIN, net.sigma_min)
    sigma_max = min(SIGMA_MAX, net.sigma_max)
    # For num_steps == 1 the original denominator was 0 (0/0 -> NaN schedule);
    # clamping gives the schedule [sigma_max, 0] and leaves num_steps >= 2 unchanged.
    denom = max(num_steps - 1, 1)
    t_steps = (sigma_max ** (1 / rho) + step_indices / denom * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])])
    z = noise.to(torch.float32) * t_steps[0]
    with torch.no_grad():
        for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
            z = sample_step(net, num_steps, i, t_cur, t_next, z, common_feat)
    return z

def sample_step_free(condition_net, uncondition_net, num_steps, i, t_cur, t_next, x_next, common_feat=None, lamda=None):
    """One guided stochastic Heun step from ``t_cur`` to ``t_next``.

    The step direction combines both denoisers as
    ``(1 + lamda) * d_uncond - lamda * d_cond``, which moves the unconditional
    direction away from the conditional (common-feature) one by ``lamda``.

    Args:
        condition_net: Denoiser called with ``common_feat``.
        uncondition_net: Denoiser called without conditioning; also supplies ``round_sigma``.
        num_steps: Total steps in the schedule (controls churn and the final-step check).
        i: Index of this step.
        t_cur, t_next: Current and next noise levels.
        x_next: Current sample.
        common_feat: Conditioning for ``condition_net``.
        lamda: Guidance weight; ``None`` means 0 (plain unconditional step)
            instead of failing on ``1 + None``.

    Returns:
        The sample at ``t_next``.
    """
    if lamda is None:
        lamda = 0.0

    def _guided_direction(x, sigma):
        # Conditional net first, then unconditional (same order as before).
        denoised_condition = condition_net(x, sigma, common_feat=common_feat).to(torch.float32)
        denoised_uncondition = uncondition_net(x, sigma).to(torch.float32)
        d_condition = (x - denoised_condition) / sigma
        d_uncondition = (x - denoised_uncondition) / sigma
        return (1 + lamda) * d_uncondition - lamda * d_condition

    x_cur = x_next
    gamma = min(S_churn / num_steps, math.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
    t_hat = uncondition_net.round_sigma(t_cur + gamma * t_cur)
    x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * torch.randn_like(x_cur)

    d_cur = _guided_direction(x_hat, t_hat)
    x_next = x_hat + (t_next - t_hat) * d_cur
    # Apply 2nd order correction.
    if i < num_steps - 1:
        d_prime = _guided_direction(x_next, t_next)
        x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
    return x_next

def sample_dm_free(condition_net, uncondition_net, noise, num_steps, common_feat=None, lamda=None):
    """Guided sampling that combines a conditional and an unconditional denoiser.

    Uses the same ``rho``-spaced sigma schedule as the unconditional sampler,
    taken from ``uncondition_net``'s sigma bounds. ``noise`` is scaled by the
    first (largest) sigma before sampling.

    Args:
        condition_net: Denoiser that takes ``common_feat``.
        uncondition_net: Unconditional denoiser (also defines the sigma bounds).
        noise: Starting tensor.
        num_steps: Number of non-zero sigma levels.
        common_feat: Conditioning for ``condition_net``.
        lamda: Guidance weight; ``None`` is treated as 0.

    Returns:
        The final sample, computed under ``torch.no_grad()``.
    """
    if lamda is None:
        lamda = 0.0
    step_indices = torch.arange(num_steps, dtype=torch.float32, device=noise.device)
    sigma_min = max(SIGMA_MIN, uncondition_net.sigma_min)
    sigma_max = min(SIGMA_MAX, uncondition_net.sigma_max)
    # Clamp avoids 0/0 (NaN schedule) when num_steps == 1; unchanged for num_steps >= 2.
    denom = max(num_steps - 1, 1)
    t_steps = (sigma_max ** (1 / rho) + step_indices / denom * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([uncondition_net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])])
    z = noise.to(torch.float32) * t_steps[0]
    with torch.no_grad():
        for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
            z = sample_step_free(condition_net, uncondition_net, num_steps, i, t_cur, t_next, z, common_feat, lamda=lamda)
    return z

####################################
# Main Execution (replacing main.py)
####################################
if __name__ == '__main__':
    # Kaggle notebooks have no CLI, so hyper-parameters are set inline here.
    dataset = 'disney'
    lamda, ae_lr, ae_alpha, ae_dropout = 0.2, 0.05, 0.1, 0.3

    # Build the detector and run the full train/evaluate pipeline on `dataset`.
    model = DiffGAD(lr=0.004, ae_alpha=ae_alpha, ae_lr=ae_lr,
                    ae_dropout=ae_dropout, lamda=lamda)
    model(dataset)


  data = torch.load(file_path)


Training autoencoder ...


  ae_dict = torch.load(os.path.join(self.ae_path, 'Graph_AE.pt'))
  0%|          | 0/20 [00:00<?, ?it/s]

Training diffusion model (unconditional) ...
Epoch: 0000 loss= 41.74825
Epoch: 0010 loss= 28.71758
Epoch: 0020 loss= 22.57305
Epoch: 0030 loss= 19.26435
Epoch: 0040 loss= 16.66113
Epoch: 0050 loss= 9.22207
Epoch: 0060 loss= 2.07658
Epoch: 0070 loss= 1.01230
Epoch: 0080 loss= 0.82857
Epoch: 0090 loss= 0.70719
Epoch: 0100 loss= 0.70981
Epoch: 0110 loss= 0.67289
Epoch: 0120 loss= 0.57775
Epoch: 0130 loss= 0.63716
Epoch: 0140 loss= 0.67159
Epoch: 0150 loss= 0.59822
Epoch: 0160 loss= 0.66566
Epoch: 0170 loss= 0.60800
Epoch: 0180 loss= 0.56894
Epoch: 0190 loss= 0.62744
Epoch: 0200 loss= 0.59100
Epoch: 0210 loss= 0.56577
Epoch: 0220 loss= 0.58603
Epoch: 0230 loss= 0.50014
Epoch: 0240 loss= 0.53733
Epoch: 0250 loss= 0.58627
Epoch: 0260 loss= 0.53912
Epoch: 0270 loss= 0.59731
Epoch: 0280 loss= 0.54537
Epoch: 0290 loss= 0.56867
Epoch: 0300 loss= 0.48801
Epoch: 0310 loss= 0.57566
Epoch: 0320 loss= 0.52488
Early stopping
Common feature: tensor([[-4.8218,  4.7629, -5.3470, -4.6599, -4.4118,  5.6766

  dm_dict = torch.load(os.path.join(self.ae_path, 'edm.pt'))


Training diffusion model (conditional) ...
Epoch: 0000 loss= 38.58255
Epoch: 0010 loss= 27.31596
Epoch: 0020 loss= 16.04342
Epoch: 0030 loss= 13.42252
Epoch: 0040 loss= 5.97612
Epoch: 0050 loss= 1.92771
Epoch: 0060 loss= 1.52439
Epoch: 0070 loss= 1.00807
Epoch: 0080 loss= 0.73409
Epoch: 0090 loss= 0.73142
Epoch: 0100 loss= 0.70292
Epoch: 0110 loss= 0.65698
Epoch: 0120 loss= 0.68502
Epoch: 0130 loss= 0.65894
Epoch: 0140 loss= 0.68381
Epoch: 0150 loss= 0.61781
Epoch: 0160 loss= 0.56299
Epoch: 0170 loss= 0.64625
Epoch: 0180 loss= 0.61310
Epoch: 0190 loss= 0.55782
Epoch: 0200 loss= 0.59901
Epoch: 0210 loss= 0.55578
Epoch: 0220 loss= 0.62338
Epoch: 0230 loss= 0.65824
Epoch: 0240 loss= 0.58399
Epoch: 0250 loss= 0.57690
Early stopping


  dm_free_dict = torch.load(os.path.join(self.ae_path, 'conditional_edm.pt'))


timestep:0, pyg_AUC: 0.4859
timestep:1, pyg_AUC: 0.4887
timestep:2, pyg_AUC: 0.4915
timestep:3, pyg_AUC: 0.4887
timestep:4, pyg_AUC: 0.4859
timestep:5, pyg_AUC: 0.4845
timestep:6, pyg_AUC: 0.4873
timestep:7, pyg_AUC: 0.4929
timestep:8, pyg_AUC: 0.4915
timestep:9, pyg_AUC: 0.4887
timestep:10, pyg_AUC: 0.4915
timestep:11, pyg_AUC: 0.4901
timestep:12, pyg_AUC: 0.4915
timestep:13, pyg_AUC: 0.4915
timestep:14, pyg_AUC: 0.4901
timestep:15, pyg_AUC: 0.4873
timestep:16, pyg_AUC: 0.4944
timestep:17, pyg_AUC: 0.4887
timestep:18, pyg_AUC: 0.4915
timestep:19, pyg_AUC: 0.4873
timestep:20, pyg_AUC: 0.4859
timestep:21, pyg_AUC: 0.4887
timestep:22, pyg_AUC: 0.4887
timestep:23, pyg_AUC: 0.4958
timestep:24, pyg_AUC: 0.4887
timestep:25, pyg_AUC: 0.4944
timestep:26, pyg_AUC: 0.4901
timestep:27, pyg_AUC: 0.4901
timestep:28, pyg_AUC: 0.4887
timestep:29, pyg_AUC: 0.4873
timestep:30, pyg_AUC: 0.4929
timestep:31, pyg_AUC: 0.4887
timestep:32, pyg_AUC: 0.4929
timestep:33, pyg_AUC: 0.4901
timestep:34, pyg_AUC: 0.

  5%|▌         | 1/20 [01:39<31:25, 99.22s/it]

timestep:499, pyg_AUC: 0.4915
Training diffusion model (unconditional) ...
Epoch: 0000 loss= 43.09970
Epoch: 0010 loss= 23.89958
Epoch: 0020 loss= 20.00839
Epoch: 0030 loss= 15.99837
Epoch: 0040 loss= 13.97786
Epoch: 0050 loss= 8.65574
Epoch: 0060 loss= 2.16946
Epoch: 0070 loss= 0.96767
Epoch: 0080 loss= 0.84541
Epoch: 0090 loss= 0.72149
Epoch: 0100 loss= 0.65871
Epoch: 0110 loss= 0.61976
Epoch: 0120 loss= 0.61856
Epoch: 0130 loss= 0.66414
Epoch: 0140 loss= 0.56016
Epoch: 0150 loss= 0.60452
Epoch: 0160 loss= 0.63460
Epoch: 0170 loss= 0.57259
Epoch: 0180 loss= 0.63114
Epoch: 0190 loss= 0.60596
Epoch: 0200 loss= 0.58978
Epoch: 0210 loss= 0.56615
Epoch: 0220 loss= 0.60732
Epoch: 0230 loss= 0.53996
Epoch: 0240 loss= 0.63378
Epoch: 0250 loss= 0.61328
Epoch: 0260 loss= 0.61691
Epoch: 0270 loss= 0.61950
Epoch: 0280 loss= 0.54428
Epoch: 0290 loss= 0.58976
Epoch: 0300 loss= 0.52828
Epoch: 0310 loss= 0.56554
Early stopping
Common feature: tensor([[-4.7810,  4.7166, -5.3294, -4.6056, -4.4004,  5.

  dm_dict = torch.load(os.path.join(self.ae_path, 'edm.pt'))


Epoch: 0020 loss= 16.06802
Epoch: 0030 loss= 10.82818
Epoch: 0040 loss= 3.66511
Epoch: 0050 loss= 1.80205
Epoch: 0060 loss= 1.07684
Epoch: 0070 loss= 1.02895
Epoch: 0080 loss= 0.99625
Epoch: 0090 loss= 0.60326
Epoch: 0100 loss= 0.62781
Epoch: 0110 loss= 0.54011
Epoch: 0120 loss= 0.64524
Epoch: 0130 loss= 0.66701
Epoch: 0140 loss= 0.59699
Epoch: 0150 loss= 0.58893
Epoch: 0160 loss= 0.61192
Epoch: 0170 loss= 0.61490
Epoch: 0180 loss= 0.58743
Epoch: 0190 loss= 0.57469
Epoch: 0200 loss= 0.63438
Epoch: 0210 loss= 0.55815
Epoch: 0220 loss= 0.58641
Epoch: 0230 loss= 0.59232
Epoch: 0240 loss= 0.55423
Epoch: 0250 loss= 0.58151
Epoch: 0260 loss= 0.58581
Epoch: 0270 loss= 0.52504
Epoch: 0280 loss= 0.57083
Epoch: 0290 loss= 0.60417
Epoch: 0300 loss= 0.53191
Epoch: 0310 loss= 0.58192
Epoch: 0320 loss= 0.57072
Epoch: 0330 loss= 0.53715
Epoch: 0340 loss= 0.61432
Early stopping


  dm_free_dict = torch.load(os.path.join(self.ae_path, 'conditional_edm.pt'))


timestep:0, pyg_AUC: 0.4944
timestep:1, pyg_AUC: 0.4873
timestep:2, pyg_AUC: 0.4845
timestep:3, pyg_AUC: 0.4873
timestep:4, pyg_AUC: 0.4873
timestep:5, pyg_AUC: 0.4887
timestep:6, pyg_AUC: 0.4901
timestep:7, pyg_AUC: 0.4901
timestep:8, pyg_AUC: 0.4958
timestep:9, pyg_AUC: 0.4915
timestep:10, pyg_AUC: 0.4915
timestep:11, pyg_AUC: 0.4859
timestep:12, pyg_AUC: 0.4901
timestep:13, pyg_AUC: 0.4901
timestep:14, pyg_AUC: 0.4915
timestep:15, pyg_AUC: 0.4887
timestep:16, pyg_AUC: 0.4887
timestep:17, pyg_AUC: 0.4958
timestep:18, pyg_AUC: 0.4915
timestep:19, pyg_AUC: 0.4915
timestep:20, pyg_AUC: 0.4915
timestep:21, pyg_AUC: 0.4887
timestep:22, pyg_AUC: 0.4859
timestep:23, pyg_AUC: 0.4845
timestep:24, pyg_AUC: 0.4915
timestep:25, pyg_AUC: 0.4901
timestep:26, pyg_AUC: 0.4901
timestep:27, pyg_AUC: 0.4873
timestep:28, pyg_AUC: 0.4901
timestep:29, pyg_AUC: 0.4873
timestep:30, pyg_AUC: 0.4915
timestep:31, pyg_AUC: 0.4901
timestep:32, pyg_AUC: 0.4915
timestep:33, pyg_AUC: 0.4929
timestep:34, pyg_AUC: 0.

 10%|█         | 2/20 [03:18<29:47, 99.29s/it]

timestep:498, pyg_AUC: 0.4887
timestep:499, pyg_AUC: 0.4915
Training diffusion model (unconditional) ...
Epoch: 0000 loss= 34.30563
Epoch: 0010 loss= 24.52308
Epoch: 0020 loss= 18.01850
Epoch: 0030 loss= 15.73822
Epoch: 0040 loss= 14.46130
Epoch: 0050 loss= 8.25099
Epoch: 0060 loss= 1.74654
Epoch: 0070 loss= 1.25022
Epoch: 0080 loss= 0.74627
Epoch: 0090 loss= 0.62320
Epoch: 0100 loss= 0.77344
Epoch: 0110 loss= 0.75890
Epoch: 0120 loss= 0.60660
Epoch: 0130 loss= 0.70672
Epoch: 0140 loss= 0.64233
Epoch: 0150 loss= 0.66480
Epoch: 0160 loss= 0.68096
Epoch: 0170 loss= 0.57385
Epoch: 0180 loss= 0.59245
Epoch: 0190 loss= 0.73376
Epoch: 0200 loss= 0.61577
Epoch: 0210 loss= 0.55745
Epoch: 0220 loss= 0.61199
Epoch: 0230 loss= 0.57383
Epoch: 0240 loss= 0.63540
Epoch: 0250 loss= 0.57566
Epoch: 0260 loss= 0.52862
Epoch: 0270 loss= 0.60758
Epoch: 0280 loss= 0.53857
Epoch: 0290 loss= 0.54783
Epoch: 0300 loss= 0.55545
Epoch: 0310 loss= 0.59615
Epoch: 0320 loss= 0.54385
Epoch: 0330 loss= 0.46250
Epoch:

  dm_dict = torch.load(os.path.join(self.ae_path, 'edm.pt'))


Epoch: 0020 loss= 17.71773
Epoch: 0030 loss= 8.76839
Epoch: 0040 loss= 1.72248
Epoch: 0050 loss= 0.90573
Epoch: 0060 loss= 1.06366
Epoch: 0070 loss= 0.76197
Epoch: 0080 loss= 0.79150
Epoch: 0090 loss= 0.64370
Epoch: 0100 loss= 0.70343
Epoch: 0110 loss= 0.61989
Epoch: 0120 loss= 0.60040
Epoch: 0130 loss= 0.58892
Epoch: 0140 loss= 0.80324
Epoch: 0150 loss= 0.53787
Epoch: 0160 loss= 0.69766
Epoch: 0170 loss= 0.64403
Epoch: 0180 loss= 0.80786
Epoch: 0190 loss= 0.71461
Epoch: 0200 loss= 0.81721
Epoch: 0210 loss= 0.65608
Epoch: 0220 loss= 0.66994
Epoch: 0230 loss= 0.50288
Epoch: 0240 loss= 0.65217
Epoch: 0250 loss= 0.51890
Epoch: 0260 loss= 0.51229
Epoch: 0270 loss= 0.57621
Epoch: 0280 loss= 0.52907
Epoch: 0290 loss= 0.60237
Epoch: 0300 loss= 0.57984
Epoch: 0310 loss= 0.61565
Epoch: 0320 loss= 0.55260
Epoch: 0330 loss= 0.61863
Epoch: 0340 loss= 0.53616
Epoch: 0350 loss= 0.53919
Epoch: 0360 loss= 0.54303
Epoch: 0370 loss= 0.56474
Epoch: 0380 loss= 0.63315
Epoch: 0390 loss= 0.53331
Epoch: 0400

  dm_free_dict = torch.load(os.path.join(self.ae_path, 'conditional_edm.pt'))


timestep:0, pyg_AUC: 0.4873
timestep:1, pyg_AUC: 0.4873
timestep:2, pyg_AUC: 0.4873
timestep:3, pyg_AUC: 0.4831
timestep:4, pyg_AUC: 0.4845
timestep:5, pyg_AUC: 0.4944
timestep:6, pyg_AUC: 0.4859
timestep:7, pyg_AUC: 0.4859
timestep:8, pyg_AUC: 0.4915
timestep:9, pyg_AUC: 0.4845
timestep:10, pyg_AUC: 0.4929
timestep:11, pyg_AUC: 0.4845
timestep:12, pyg_AUC: 0.4915
timestep:13, pyg_AUC: 0.4873
timestep:14, pyg_AUC: 0.4901
timestep:15, pyg_AUC: 0.4915
timestep:16, pyg_AUC: 0.4929
timestep:17, pyg_AUC: 0.4901
timestep:18, pyg_AUC: 0.4958
timestep:19, pyg_AUC: 0.4887
timestep:20, pyg_AUC: 0.4845
timestep:21, pyg_AUC: 0.4873
timestep:22, pyg_AUC: 0.4929
timestep:23, pyg_AUC: 0.4873
timestep:24, pyg_AUC: 0.4901
timestep:25, pyg_AUC: 0.4887
timestep:26, pyg_AUC: 0.4915
timestep:27, pyg_AUC: 0.4887
timestep:28, pyg_AUC: 0.4831
timestep:29, pyg_AUC: 0.4944
timestep:30, pyg_AUC: 0.4944
timestep:31, pyg_AUC: 0.4915
timestep:32, pyg_AUC: 0.4873
timestep:33, pyg_AUC: 0.4958
timestep:34, pyg_AUC: 0.

 15%|█▌        | 3/20 [05:00<28:28, 100.49s/it]

timestep:498, pyg_AUC: 0.4859
timestep:499, pyg_AUC: 0.4859
Training diffusion model (unconditional) ...
Epoch: 0000 loss= 39.86916
Epoch: 0010 loss= 30.48896
Epoch: 0020 loss= 21.99897
Epoch: 0030 loss= 15.32752
Epoch: 0040 loss= 13.91725
Epoch: 0050 loss= 10.19799
Epoch: 0060 loss= 1.75214
Epoch: 0070 loss= 0.96667
Epoch: 0080 loss= 0.66522
Epoch: 0090 loss= 0.65110
Epoch: 0100 loss= 0.58038
Epoch: 0110 loss= 0.56416
Epoch: 0120 loss= 0.58306
Epoch: 0130 loss= 0.63422
Epoch: 0140 loss= 0.56405
Epoch: 0150 loss= 0.53804
Epoch: 0160 loss= 0.59004
Epoch: 0170 loss= 0.57265
Epoch: 0180 loss= 0.60825
Epoch: 0190 loss= 0.65617
Epoch: 0200 loss= 0.59594
Epoch: 0210 loss= 0.63938
Epoch: 0220 loss= 0.66307
Epoch: 0230 loss= 0.58697
Epoch: 0240 loss= 0.60111
Epoch: 0250 loss= 0.53910
Epoch: 0260 loss= 0.57162
Epoch: 0270 loss= 0.59770
Epoch: 0280 loss= 0.64961
Epoch: 0290 loss= 0.52843
Early stopping
Common feature: tensor([[-4.8257,  4.6789, -5.3808, -4.5979, -4.4125,  5.6426,  5.4814, -5.239

  dm_dict = torch.load(os.path.join(self.ae_path, 'edm.pt'))


Epoch: 0010 loss= 26.02908
Epoch: 0020 loss= 15.64923
Epoch: 0030 loss= 10.09938
Epoch: 0040 loss= 2.52251
Epoch: 0050 loss= 0.99488
Epoch: 0060 loss= 0.73221
Epoch: 0070 loss= 1.24813
Epoch: 0080 loss= 0.61150
Epoch: 0090 loss= 0.57408
Epoch: 0100 loss= 0.57113
Epoch: 0110 loss= 0.63948
Epoch: 0120 loss= 0.65565
Epoch: 0130 loss= 0.62418
Epoch: 0140 loss= 0.59848
Epoch: 0150 loss= 0.60206
Epoch: 0160 loss= 0.64204
Epoch: 0170 loss= 0.59867
Epoch: 0180 loss= 0.58721
Epoch: 0190 loss= 0.61809
Epoch: 0200 loss= 0.54739
Epoch: 0210 loss= 0.57173
Epoch: 0220 loss= 0.53242
Epoch: 0230 loss= 0.48657
Epoch: 0240 loss= 0.60381
Epoch: 0250 loss= 0.62325
Epoch: 0260 loss= 0.57790
Epoch: 0270 loss= 0.52301
Epoch: 0280 loss= 0.65179
Epoch: 0290 loss= 0.62602
Epoch: 0300 loss= 0.64857
Epoch: 0310 loss= 0.52779
Epoch: 0320 loss= 0.60063
Epoch: 0330 loss= 0.56770
Epoch: 0340 loss= 0.60228
Epoch: 0350 loss= 0.52561
Epoch: 0360 loss= 0.55866
Epoch: 0370 loss= 0.57887
Epoch: 0380 loss= 0.51621
Epoch: 03

  dm_free_dict = torch.load(os.path.join(self.ae_path, 'conditional_edm.pt'))


timestep:0, pyg_AUC: 0.4915
timestep:1, pyg_AUC: 0.4901
timestep:2, pyg_AUC: 0.4915
timestep:3, pyg_AUC: 0.4845
timestep:4, pyg_AUC: 0.4859
timestep:5, pyg_AUC: 0.4859
timestep:6, pyg_AUC: 0.4944
timestep:7, pyg_AUC: 0.4859
timestep:8, pyg_AUC: 0.4929
timestep:9, pyg_AUC: 0.4859
timestep:10, pyg_AUC: 0.4915
timestep:11, pyg_AUC: 0.4915
timestep:12, pyg_AUC: 0.4887
timestep:13, pyg_AUC: 0.4887
timestep:14, pyg_AUC: 0.4873
timestep:15, pyg_AUC: 0.4873
timestep:16, pyg_AUC: 0.4859
timestep:17, pyg_AUC: 0.4901
timestep:18, pyg_AUC: 0.4887
timestep:19, pyg_AUC: 0.4944
timestep:20, pyg_AUC: 0.4887
timestep:21, pyg_AUC: 0.4929
timestep:22, pyg_AUC: 0.4887
timestep:23, pyg_AUC: 0.4873
timestep:24, pyg_AUC: 0.4887
timestep:25, pyg_AUC: 0.4887
timestep:26, pyg_AUC: 0.4859
timestep:27, pyg_AUC: 0.4929
timestep:28, pyg_AUC: 0.4887
timestep:29, pyg_AUC: 0.4859
timestep:30, pyg_AUC: 0.4901
timestep:31, pyg_AUC: 0.4887
timestep:32, pyg_AUC: 0.4901
timestep:33, pyg_AUC: 0.4915
timestep:34, pyg_AUC: 0.

 20%|██        | 4/20 [06:41<26:49, 100.57s/it]

timestep:499, pyg_AUC: 0.4901
Training diffusion model (unconditional) ...
Epoch: 0000 loss= 41.68215
Epoch: 0010 loss= 28.47275
Epoch: 0020 loss= 24.73520
Epoch: 0030 loss= 14.04542
Epoch: 0040 loss= 9.71388
Epoch: 0050 loss= 2.42823
Epoch: 0060 loss= 1.32007
Epoch: 0070 loss= 0.89542
Epoch: 0080 loss= 0.67976
Epoch: 0090 loss= 0.90967
Epoch: 0100 loss= 0.65428
Epoch: 0110 loss= 0.61078
Epoch: 0120 loss= 0.60525
Epoch: 0130 loss= 0.59291
Epoch: 0140 loss= 0.64545
Epoch: 0150 loss= 0.65498
Epoch: 0160 loss= 0.68747
Epoch: 0170 loss= 0.66787
Epoch: 0180 loss= 0.65678
Epoch: 0190 loss= 0.63102
Epoch: 0200 loss= 0.63408
Epoch: 0210 loss= 0.54873
Epoch: 0220 loss= 0.67817
Epoch: 0230 loss= 0.57302
Epoch: 0240 loss= 0.61365
Epoch: 0250 loss= 0.67216
Epoch: 0260 loss= 0.52865
Epoch: 0270 loss= 0.58357
Epoch: 0280 loss= 0.61939
Epoch: 0290 loss= 0.54986
Epoch: 0300 loss= 0.57297
Epoch: 0310 loss= 0.57103
Epoch: 0320 loss= 0.60140
Epoch: 0330 loss= 0.61300
Epoch: 0340 loss= 0.47295
Epoch: 0350

  dm_dict = torch.load(os.path.join(self.ae_path, 'edm.pt'))


Epoch: 0010 loss= 25.91788
Epoch: 0020 loss= 16.83441
Epoch: 0030 loss= 12.59090
Epoch: 0040 loss= 4.39100
Epoch: 0050 loss= 1.13026
Epoch: 0060 loss= 0.71802
Epoch: 0070 loss= 0.71378
Epoch: 0080 loss= 0.73922
Epoch: 0090 loss= 0.91783
Epoch: 0100 loss= 0.64987
Epoch: 0110 loss= 0.58740
Epoch: 0120 loss= 0.65692
Epoch: 0130 loss= 0.66047
Epoch: 0140 loss= 0.58717
Epoch: 0150 loss= 0.63363
Epoch: 0160 loss= 0.55541
Epoch: 0170 loss= 0.60270
Epoch: 0180 loss= 0.67354
Epoch: 0190 loss= 0.63424
Epoch: 0200 loss= 0.61216
Epoch: 0210 loss= 0.60508
Epoch: 0220 loss= 0.61136
Epoch: 0230 loss= 0.60094
Early stopping


  dm_free_dict = torch.load(os.path.join(self.ae_path, 'conditional_edm.pt'))


timestep:0, pyg_AUC: 0.4873
timestep:1, pyg_AUC: 0.4901
timestep:2, pyg_AUC: 0.4887
timestep:3, pyg_AUC: 0.4887
timestep:4, pyg_AUC: 0.4887
timestep:5, pyg_AUC: 0.4915
timestep:6, pyg_AUC: 0.4816
timestep:7, pyg_AUC: 0.4929
timestep:8, pyg_AUC: 0.4845
timestep:9, pyg_AUC: 0.4887
timestep:10, pyg_AUC: 0.4901
timestep:11, pyg_AUC: 0.4873
timestep:12, pyg_AUC: 0.4944
timestep:13, pyg_AUC: 0.4887
timestep:14, pyg_AUC: 0.4915
timestep:15, pyg_AUC: 0.4929
timestep:16, pyg_AUC: 0.4929
timestep:17, pyg_AUC: 0.4901
timestep:18, pyg_AUC: 0.4915
timestep:19, pyg_AUC: 0.4901
timestep:20, pyg_AUC: 0.4944
timestep:21, pyg_AUC: 0.4887
timestep:22, pyg_AUC: 0.4887
timestep:23, pyg_AUC: 0.4831
timestep:24, pyg_AUC: 0.4901
timestep:25, pyg_AUC: 0.4901
timestep:26, pyg_AUC: 0.4887
timestep:27, pyg_AUC: 0.4887
timestep:28, pyg_AUC: 0.4901
timestep:29, pyg_AUC: 0.4873
timestep:30, pyg_AUC: 0.4887
timestep:31, pyg_AUC: 0.4859
timestep:32, pyg_AUC: 0.4887
timestep:33, pyg_AUC: 0.4845
timestep:34, pyg_AUC: 0.

 25%|██▌       | 5/20 [08:21<25:08, 100.60s/it]

timestep:499, pyg_AUC: 0.4929
Training diffusion model (unconditional) ...
Epoch: 0000 loss= 36.81496
Epoch: 0010 loss= 24.01574
Epoch: 0020 loss= 20.37944
Epoch: 0030 loss= 17.50788
Epoch: 0040 loss= 10.74729
Epoch: 0050 loss= 5.85092
Epoch: 0060 loss= 2.17113
Epoch: 0070 loss= 1.07578
Epoch: 0080 loss= 0.98621
Epoch: 0090 loss= 0.76385
Epoch: 0100 loss= 0.66929
Epoch: 0110 loss= 0.67320
Epoch: 0120 loss= 0.63755
Epoch: 0130 loss= 0.60513
Epoch: 0140 loss= 0.61207
Epoch: 0150 loss= 0.59184
Epoch: 0160 loss= 0.59582
Epoch: 0170 loss= 0.55232
Epoch: 0180 loss= 0.57364
Epoch: 0190 loss= 0.61054
Epoch: 0200 loss= 0.61766
Epoch: 0210 loss= 0.63381
Epoch: 0220 loss= 0.52509
Epoch: 0230 loss= 0.64633
Epoch: 0240 loss= 0.57522
Epoch: 0250 loss= 0.61636
Epoch: 0260 loss= 0.51932
Epoch: 0270 loss= 0.55697
Epoch: 0280 loss= 0.56068
Epoch: 0290 loss= 0.60778
Epoch: 0300 loss= 0.55290
Epoch: 0310 loss= 0.55296
Epoch: 0320 loss= 0.56979
Epoch: 0330 loss= 0.60260
Epoch: 0340 loss= 0.60295
Epoch: 035

  dm_dict = torch.load(os.path.join(self.ae_path, 'edm.pt'))


Epoch: 0010 loss= 21.09327
Epoch: 0020 loss= 12.61672
Epoch: 0030 loss= 7.89281
Epoch: 0040 loss= 2.13680
Epoch: 0050 loss= 0.98075
Epoch: 0060 loss= 0.79291
Epoch: 0070 loss= 0.91469
Epoch: 0080 loss= 0.77614
Epoch: 0090 loss= 0.62193
Epoch: 0100 loss= 1.23853
Epoch: 0110 loss= 0.69466
Epoch: 0120 loss= 0.57718
Epoch: 0130 loss= 0.64769
Epoch: 0140 loss= 0.65806
Epoch: 0150 loss= 0.53412
Epoch: 0160 loss= 0.65618
Epoch: 0170 loss= 0.64133
Epoch: 0180 loss= 0.62494
Epoch: 0190 loss= 0.63264
Epoch: 0200 loss= 0.55154
Epoch: 0210 loss= 0.58608
Epoch: 0220 loss= 0.60385
Epoch: 0230 loss= 0.53290
Epoch: 0240 loss= 0.61490
Epoch: 0250 loss= 0.52041
Epoch: 0260 loss= 0.54703
Epoch: 0270 loss= 0.59778
Epoch: 0280 loss= 0.54817
Epoch: 0290 loss= 0.50545
Epoch: 0300 loss= 0.61654
Epoch: 0310 loss= 0.56467
Epoch: 0320 loss= 0.52343
Epoch: 0330 loss= 0.54388
Epoch: 0340 loss= 0.60762
Early stopping


  dm_free_dict = torch.load(os.path.join(self.ae_path, 'conditional_edm.pt'))


timestep:0, pyg_AUC: 0.4915
timestep:1, pyg_AUC: 0.4915
timestep:2, pyg_AUC: 0.4873
timestep:3, pyg_AUC: 0.4859
timestep:4, pyg_AUC: 0.4915
timestep:5, pyg_AUC: 0.4915
timestep:6, pyg_AUC: 0.4901
timestep:7, pyg_AUC: 0.4901
timestep:8, pyg_AUC: 0.4845
timestep:9, pyg_AUC: 0.4901
timestep:10, pyg_AUC: 0.4929
timestep:11, pyg_AUC: 0.4845
timestep:12, pyg_AUC: 0.4845
timestep:13, pyg_AUC: 0.4859
timestep:14, pyg_AUC: 0.4859
timestep:15, pyg_AUC: 0.4859
timestep:16, pyg_AUC: 0.4887
timestep:17, pyg_AUC: 0.4887
timestep:18, pyg_AUC: 0.4887
timestep:19, pyg_AUC: 0.4915
timestep:20, pyg_AUC: 0.4887
timestep:21, pyg_AUC: 0.4873
timestep:22, pyg_AUC: 0.4901
timestep:23, pyg_AUC: 0.4901
timestep:24, pyg_AUC: 0.4887
timestep:25, pyg_AUC: 0.4859
timestep:26, pyg_AUC: 0.4915
timestep:27, pyg_AUC: 0.4901
timestep:28, pyg_AUC: 0.4845
timestep:29, pyg_AUC: 0.4887
timestep:30, pyg_AUC: 0.4929
timestep:31, pyg_AUC: 0.4901
timestep:32, pyg_AUC: 0.4873
timestep:33, pyg_AUC: 0.4901
timestep:34, pyg_AUC: 0.

 30%|███       | 6/20 [10:02<23:28, 100.61s/it]

timestep:499, pyg_AUC: 0.4901
Training diffusion model (unconditional) ...
Epoch: 0000 loss= 39.36583
Epoch: 0010 loss= 24.33446
Epoch: 0020 loss= 23.34266
Epoch: 0030 loss= 16.61592
Epoch: 0040 loss= 13.98286
Epoch: 0050 loss= 8.16397
Epoch: 0060 loss= 1.66162
Epoch: 0070 loss= 0.66561
Epoch: 0080 loss= 0.70240
Epoch: 0090 loss= 0.77302
Epoch: 0100 loss= 0.78796
Epoch: 0110 loss= 0.68809
Epoch: 0120 loss= 0.57944
Epoch: 0130 loss= 0.60001
Epoch: 0140 loss= 0.60059
Epoch: 0150 loss= 0.63417
Epoch: 0160 loss= 0.55416
Epoch: 0170 loss= 0.66118
Epoch: 0180 loss= 0.62555
Epoch: 0190 loss= 0.60469
Epoch: 0200 loss= 0.68234
Epoch: 0210 loss= 0.65110
Epoch: 0220 loss= 0.58275
Epoch: 0230 loss= 0.60707
Epoch: 0240 loss= 0.55465
Early stopping


  dm_dict = torch.load(os.path.join(self.ae_path, 'edm.pt'))


Common feature: tensor([[-4.8299,  4.7370, -5.3792, -4.6727, -4.4175,  5.6911,  5.5365, -5.3083]],
       device='cuda:0')
Training diffusion model (conditional) ...
Epoch: 0000 loss= 39.87531
Epoch: 0010 loss= 19.38387
Epoch: 0020 loss= 13.05482
Epoch: 0030 loss= 8.89157
Epoch: 0040 loss= 1.90042
Epoch: 0050 loss= 1.07729
Epoch: 0060 loss= 1.18685
Epoch: 0070 loss= 0.72021
Epoch: 0080 loss= 0.65817
Epoch: 0090 loss= 0.65867
Epoch: 0100 loss= 0.51590
Epoch: 0110 loss= 0.54971
Epoch: 0120 loss= 0.71047
Epoch: 0130 loss= 0.62737
Epoch: 0140 loss= 0.57290
Epoch: 0150 loss= 0.68967
Epoch: 0160 loss= 0.56217
Epoch: 0170 loss= 0.56579
Epoch: 0180 loss= 0.70108
Epoch: 0190 loss= 0.55479
Epoch: 0200 loss= 0.64591
Epoch: 0210 loss= 0.63127
Epoch: 0220 loss= 0.53420
Epoch: 0230 loss= 0.60165
Epoch: 0240 loss= 0.63206
Early stopping


  dm_free_dict = torch.load(os.path.join(self.ae_path, 'conditional_edm.pt'))


timestep:0, pyg_AUC: 0.4944
timestep:1, pyg_AUC: 0.4859
timestep:2, pyg_AUC: 0.4929
timestep:3, pyg_AUC: 0.4929
timestep:4, pyg_AUC: 0.4845
timestep:5, pyg_AUC: 0.4831
timestep:6, pyg_AUC: 0.4859
timestep:7, pyg_AUC: 0.4887
timestep:8, pyg_AUC: 0.4831
timestep:9, pyg_AUC: 0.4944
timestep:10, pyg_AUC: 0.4929
timestep:11, pyg_AUC: 0.4901
timestep:12, pyg_AUC: 0.4958
timestep:13, pyg_AUC: 0.4929
timestep:14, pyg_AUC: 0.4915
timestep:15, pyg_AUC: 0.4901
timestep:16, pyg_AUC: 0.4944
timestep:17, pyg_AUC: 0.4873
timestep:18, pyg_AUC: 0.4845
timestep:19, pyg_AUC: 0.4859
timestep:20, pyg_AUC: 0.4915
timestep:21, pyg_AUC: 0.4901
timestep:22, pyg_AUC: 0.4915
timestep:23, pyg_AUC: 0.4944
timestep:24, pyg_AUC: 0.4845
timestep:25, pyg_AUC: 0.4887
timestep:26, pyg_AUC: 0.4958
timestep:27, pyg_AUC: 0.4873
timestep:28, pyg_AUC: 0.4901
timestep:29, pyg_AUC: 0.4929
timestep:30, pyg_AUC: 0.4859
timestep:31, pyg_AUC: 0.4929
timestep:32, pyg_AUC: 0.4845
timestep:33, pyg_AUC: 0.4901
timestep:34, pyg_AUC: 0.

 35%|███▌      | 7/20 [11:41<21:43, 100.26s/it]

timestep:499, pyg_AUC: 0.4845
Training diffusion model (unconditional) ...
Epoch: 0000 loss= 36.04852
Epoch: 0010 loss= 21.20682
Epoch: 0020 loss= 17.15179
Epoch: 0030 loss= 15.52214
Epoch: 0040 loss= 12.81039
Epoch: 0050 loss= 7.01130
Epoch: 0060 loss= 1.65614
Epoch: 0070 loss= 0.89220
Epoch: 0080 loss= 0.72031
Epoch: 0090 loss= 0.69835
Epoch: 0100 loss= 0.62402
Epoch: 0110 loss= 0.68970
Epoch: 0120 loss= 0.57058
Epoch: 0130 loss= 0.56835
Epoch: 0140 loss= 0.55246
Epoch: 0150 loss= 0.64890
Epoch: 0160 loss= 0.58354
Epoch: 0170 loss= 0.58221
Epoch: 0180 loss= 0.57907
Epoch: 0190 loss= 0.58469
Epoch: 0200 loss= 0.54268
Epoch: 0210 loss= 0.49849
Epoch: 0220 loss= 0.53378
Epoch: 0230 loss= 0.56737
Epoch: 0240 loss= 0.63000
Epoch: 0250 loss= 0.53736
Epoch: 0260 loss= 0.57091
Epoch: 0270 loss= 0.59361
Epoch: 0280 loss= 0.51767
Epoch: 0290 loss= 0.62360
Epoch: 0300 loss= 0.57281
Epoch: 0310 loss= 0.51701
Epoch: 0320 loss= 0.49488
Epoch: 0330 loss= 0.56683
Epoch: 0340 loss= 0.56669
Early stop

  dm_dict = torch.load(os.path.join(self.ae_path, 'edm.pt'))


Epoch: 0010 loss= 18.65797
Epoch: 0020 loss= 12.12986
Epoch: 0030 loss= 8.37839
Epoch: 0040 loss= 1.58492
Epoch: 0050 loss= 1.37858
Epoch: 0060 loss= 0.84123
Epoch: 0070 loss= 0.59859
Epoch: 0080 loss= 0.62477
Epoch: 0090 loss= 0.75054
Epoch: 0100 loss= 0.61656
Epoch: 0110 loss= 0.64501
Epoch: 0120 loss= 0.78631
Epoch: 0130 loss= 0.65359
Epoch: 0140 loss= 0.58450
Epoch: 0150 loss= 0.53318
Epoch: 0160 loss= 0.54930
Epoch: 0170 loss= 0.54437
Epoch: 0180 loss= 0.68925
Epoch: 0190 loss= 0.55348
Epoch: 0200 loss= 0.56601
Epoch: 0210 loss= 0.51698
Epoch: 0220 loss= 0.61534
Epoch: 0230 loss= 0.57268
Epoch: 0240 loss= 0.53197
Epoch: 0250 loss= 0.56855
Epoch: 0260 loss= 0.58838
Epoch: 0270 loss= 0.58513
Epoch: 0280 loss= 0.58654
Epoch: 0290 loss= 0.56222
Epoch: 0300 loss= 0.58426
Epoch: 0310 loss= 0.57357
Epoch: 0320 loss= 0.56261
Epoch: 0330 loss= 0.53348
Epoch: 0340 loss= 0.50835
Epoch: 0350 loss= 0.51057
Epoch: 0360 loss= 0.61422
Epoch: 0370 loss= 0.53901
Early stopping


  dm_free_dict = torch.load(os.path.join(self.ae_path, 'conditional_edm.pt'))


timestep:0, pyg_AUC: 0.4915
timestep:1, pyg_AUC: 0.4901
timestep:2, pyg_AUC: 0.4873
timestep:3, pyg_AUC: 0.4845
timestep:4, pyg_AUC: 0.4929
timestep:5, pyg_AUC: 0.4887
timestep:6, pyg_AUC: 0.4915
timestep:7, pyg_AUC: 0.4873
timestep:8, pyg_AUC: 0.4929
timestep:9, pyg_AUC: 0.4873
timestep:10, pyg_AUC: 0.4887
timestep:11, pyg_AUC: 0.4845
timestep:12, pyg_AUC: 0.4929
timestep:13, pyg_AUC: 0.4915
timestep:14, pyg_AUC: 0.4915
timestep:15, pyg_AUC: 0.4901
timestep:16, pyg_AUC: 0.4845
timestep:17, pyg_AUC: 0.4915
timestep:18, pyg_AUC: 0.4887
timestep:19, pyg_AUC: 0.4887
timestep:20, pyg_AUC: 0.4859
timestep:21, pyg_AUC: 0.4887
timestep:22, pyg_AUC: 0.4887
timestep:23, pyg_AUC: 0.4873
timestep:24, pyg_AUC: 0.4887
timestep:25, pyg_AUC: 0.4915
timestep:26, pyg_AUC: 0.4873
timestep:27, pyg_AUC: 0.4901
timestep:28, pyg_AUC: 0.4859
timestep:29, pyg_AUC: 0.4845
timestep:30, pyg_AUC: 0.4915
timestep:31, pyg_AUC: 0.4958
timestep:32, pyg_AUC: 0.4915
timestep:33, pyg_AUC: 0.4887
timestep:34, pyg_AUC: 0.

 40%|████      | 8/20 [13:23<20:07, 100.62s/it]

timestep:499, pyg_AUC: 0.4887
Training diffusion model (unconditional) ...
Epoch: 0000 loss= 40.08904
Epoch: 0010 loss= 24.59587
Epoch: 0020 loss= 18.23821
Epoch: 0030 loss= 15.96839
Epoch: 0040 loss= 10.45609
Epoch: 0050 loss= 1.78387
Epoch: 0060 loss= 2.07045
Epoch: 0070 loss= 0.84090
Epoch: 0080 loss= 0.73431
Epoch: 0090 loss= 0.77535
Epoch: 0100 loss= 0.69004
Epoch: 0110 loss= 0.68497
Epoch: 0120 loss= 0.73894
Epoch: 0130 loss= 0.71031
Epoch: 0140 loss= 0.69673
Epoch: 0150 loss= 0.59125
Epoch: 0160 loss= 0.58074
Epoch: 0170 loss= 0.58226
Epoch: 0180 loss= 0.63660
Epoch: 0190 loss= 0.62512
Epoch: 0200 loss= 0.55665
Epoch: 0210 loss= 0.62328
Epoch: 0220 loss= 0.61565
Epoch: 0230 loss= 0.58649
Epoch: 0240 loss= 0.58926
Epoch: 0250 loss= 0.56000
Epoch: 0260 loss= 0.58832
Epoch: 0270 loss= 0.59169
Epoch: 0280 loss= 0.55733
Epoch: 0290 loss= 0.54055
Epoch: 0300 loss= 0.60310
Epoch: 0310 loss= 0.58121
Epoch: 0320 loss= 0.53030
Epoch: 0330 loss= 0.55677
Epoch: 0340 loss= 0.58418
Epoch: 035

  dm_dict = torch.load(os.path.join(self.ae_path, 'edm.pt'))


Epoch: 0030 loss= 12.70227
Epoch: 0040 loss= 5.18452
Epoch: 0050 loss= 1.14904
Epoch: 0060 loss= 0.97878
Epoch: 0070 loss= 0.70795
Epoch: 0080 loss= 0.58560
Epoch: 0090 loss= 0.61537
Epoch: 0100 loss= 0.60637
Epoch: 0110 loss= 0.58919
Epoch: 0120 loss= 0.65843
Epoch: 0130 loss= 0.60368
Epoch: 0140 loss= 0.63667
Epoch: 0150 loss= 0.62997
Epoch: 0160 loss= 0.64687
Epoch: 0170 loss= 0.58297
Epoch: 0180 loss= 0.54966
Epoch: 0190 loss= 0.52281
Epoch: 0200 loss= 0.58987
Epoch: 0210 loss= 0.58067
Epoch: 0220 loss= 0.61337
Epoch: 0230 loss= 0.55011
Epoch: 0240 loss= 0.60010
Epoch: 0250 loss= 0.62341
Epoch: 0260 loss= 0.67706
Epoch: 0270 loss= 0.53032
Epoch: 0280 loss= 0.58574
Epoch: 0290 loss= 0.60657
Epoch: 0300 loss= 0.50432
Epoch: 0310 loss= 0.60256
Epoch: 0320 loss= 0.57000
Epoch: 0330 loss= 0.58846
Epoch: 0340 loss= 0.55091
Epoch: 0350 loss= 0.58578
Early stopping


  dm_free_dict = torch.load(os.path.join(self.ae_path, 'conditional_edm.pt'))


timestep:0, pyg_AUC: 0.4845
timestep:1, pyg_AUC: 0.4944
timestep:2, pyg_AUC: 0.4915
timestep:3, pyg_AUC: 0.4887
timestep:4, pyg_AUC: 0.4887
timestep:5, pyg_AUC: 0.4887
timestep:6, pyg_AUC: 0.4873
timestep:7, pyg_AUC: 0.4929
timestep:8, pyg_AUC: 0.4915
timestep:9, pyg_AUC: 0.4873
timestep:10, pyg_AUC: 0.4901
timestep:11, pyg_AUC: 0.4944
timestep:12, pyg_AUC: 0.4859
timestep:13, pyg_AUC: 0.4887
timestep:14, pyg_AUC: 0.4831
timestep:15, pyg_AUC: 0.4873
timestep:16, pyg_AUC: 0.4915
timestep:17, pyg_AUC: 0.4859
timestep:18, pyg_AUC: 0.4873
timestep:19, pyg_AUC: 0.4859
timestep:20, pyg_AUC: 0.4859
timestep:21, pyg_AUC: 0.4929
timestep:22, pyg_AUC: 0.4901
timestep:23, pyg_AUC: 0.4873
timestep:24, pyg_AUC: 0.4859
timestep:25, pyg_AUC: 0.4915
timestep:26, pyg_AUC: 0.4873
timestep:27, pyg_AUC: 0.4901
timestep:28, pyg_AUC: 0.4915
timestep:29, pyg_AUC: 0.4901
timestep:30, pyg_AUC: 0.4929
timestep:31, pyg_AUC: 0.4873
timestep:32, pyg_AUC: 0.4816
timestep:33, pyg_AUC: 0.4845
timestep:34, pyg_AUC: 0.

 45%|████▌     | 9/20 [15:05<18:31, 101.00s/it]

timestep:498, pyg_AUC: 0.4915
timestep:499, pyg_AUC: 0.4887
Training diffusion model (unconditional) ...
Epoch: 0000 loss= 40.26085
Epoch: 0010 loss= 22.15419
Epoch: 0020 loss= 20.86200
Epoch: 0030 loss= 18.15471
Epoch: 0040 loss= 16.44251
Epoch: 0050 loss= 11.18345
Epoch: 0060 loss= 6.28638
Epoch: 0070 loss= 1.63418
Epoch: 0080 loss= 1.74932
Epoch: 0090 loss= 0.85999
Epoch: 0100 loss= 0.68743
Epoch: 0110 loss= 0.55744
Epoch: 0120 loss= 0.63184
Epoch: 0130 loss= 0.62867
Epoch: 0140 loss= 0.67911
Epoch: 0150 loss= 0.63033
Epoch: 0160 loss= 0.61746
Epoch: 0170 loss= 0.59930
Epoch: 0180 loss= 0.52183
Epoch: 0190 loss= 0.62963
Epoch: 0200 loss= 0.62968
Epoch: 0210 loss= 0.60948
Epoch: 0220 loss= 0.60556
Epoch: 0230 loss= 0.58215
Epoch: 0240 loss= 0.53245
Epoch: 0250 loss= 0.55564
Early stopping
Common feature: tensor([[-4.8127,  4.7619, -5.4034, -4.6698, -4.4210,  5.6952,  5.4990, -5.2999]],
       device='cuda:0')
Training diffusion model (conditional) ...
Epoch: 0000 loss= 37.85318
Epoch

  dm_dict = torch.load(os.path.join(self.ae_path, 'edm.pt'))


Epoch: 0030 loss= 11.97665
Epoch: 0040 loss= 5.09608
Epoch: 0050 loss= 2.03420
Epoch: 0060 loss= 1.02793
Epoch: 0070 loss= 0.79431
Epoch: 0080 loss= 0.64497
Epoch: 0090 loss= 0.57750
Epoch: 0100 loss= 0.82892
Epoch: 0110 loss= 0.68237
Epoch: 0120 loss= 0.62658
Epoch: 0130 loss= 0.54531
Epoch: 0140 loss= 0.65706
Epoch: 0150 loss= 0.69874
Epoch: 0160 loss= 0.64743
Epoch: 0170 loss= 0.59881
Epoch: 0180 loss= 0.65388
Epoch: 0190 loss= 0.53005
Epoch: 0200 loss= 0.55560
Epoch: 0210 loss= 0.55213
Epoch: 0220 loss= 0.56445
Epoch: 0230 loss= 0.61492
Epoch: 0240 loss= 0.62763
Epoch: 0250 loss= 0.59007
Epoch: 0260 loss= 0.61574
Epoch: 0270 loss= 0.53579
Epoch: 0280 loss= 0.62249
Epoch: 0290 loss= 0.57000
Epoch: 0300 loss= 0.57561
Early stopping


  dm_free_dict = torch.load(os.path.join(self.ae_path, 'conditional_edm.pt'))


timestep:0, pyg_AUC: 0.4831
timestep:1, pyg_AUC: 0.4816
timestep:2, pyg_AUC: 0.4929
timestep:3, pyg_AUC: 0.4873
timestep:4, pyg_AUC: 0.4901
timestep:5, pyg_AUC: 0.4873
timestep:6, pyg_AUC: 0.4944
timestep:7, pyg_AUC: 0.4887
timestep:8, pyg_AUC: 0.4887
timestep:9, pyg_AUC: 0.4859
timestep:10, pyg_AUC: 0.4873
timestep:11, pyg_AUC: 0.4929
timestep:12, pyg_AUC: 0.4887
timestep:13, pyg_AUC: 0.4901
timestep:14, pyg_AUC: 0.4901
timestep:15, pyg_AUC: 0.4873
timestep:16, pyg_AUC: 0.4887
timestep:17, pyg_AUC: 0.4887
timestep:18, pyg_AUC: 0.4831
timestep:19, pyg_AUC: 0.4901
timestep:20, pyg_AUC: 0.4887
timestep:21, pyg_AUC: 0.4845
timestep:22, pyg_AUC: 0.4901
timestep:23, pyg_AUC: 0.4915
timestep:24, pyg_AUC: 0.4915
timestep:25, pyg_AUC: 0.4901
timestep:26, pyg_AUC: 0.4929
timestep:27, pyg_AUC: 0.4845
timestep:28, pyg_AUC: 0.4774
timestep:29, pyg_AUC: 0.4901
timestep:30, pyg_AUC: 0.4887
timestep:31, pyg_AUC: 0.4887
timestep:32, pyg_AUC: 0.4901
timestep:33, pyg_AUC: 0.4873
timestep:34, pyg_AUC: 0.

 50%|█████     | 10/20 [16:46<16:52, 101.23s/it]

timestep:498, pyg_AUC: 0.4873
timestep:499, pyg_AUC: 0.4859
Training diffusion model (unconditional) ...
Epoch: 0000 loss= 41.35893
Epoch: 0010 loss= 24.60151
Epoch: 0020 loss= 18.33657
Epoch: 0030 loss= 16.52004
Epoch: 0040 loss= 13.95027
Epoch: 0050 loss= 5.52830
Epoch: 0060 loss= 1.37850
Epoch: 0070 loss= 0.81585
Epoch: 0080 loss= 0.75493
Epoch: 0090 loss= 0.77594
Epoch: 0100 loss= 0.68952
Epoch: 0110 loss= 0.63322
Epoch: 0120 loss= 0.62033
Epoch: 0130 loss= 0.60450
Epoch: 0140 loss= 0.63107
Epoch: 0150 loss= 0.50966
Epoch: 0160 loss= 0.56602
Epoch: 0170 loss= 0.61068
Epoch: 0180 loss= 0.66395
Epoch: 0190 loss= 0.58228
Epoch: 0200 loss= 0.61942
Epoch: 0210 loss= 0.54799
Epoch: 0220 loss= 0.56241
Epoch: 0230 loss= 0.50945
Epoch: 0240 loss= 0.60310
Epoch: 0250 loss= 0.65515
Epoch: 0260 loss= 0.61181
Epoch: 0270 loss= 0.57547
Epoch: 0280 loss= 0.50513
Epoch: 0290 loss= 0.51688
Epoch: 0300 loss= 0.55942
Epoch: 0310 loss= 0.68189
Epoch: 0320 loss= 0.54025
Epoch: 0330 loss= 0.54704
Epoch:

  dm_dict = torch.load(os.path.join(self.ae_path, 'edm.pt'))


Epoch: 0010 loss= 36.01245
Epoch: 0020 loss= 22.90716
Epoch: 0030 loss= 16.17031
Epoch: 0040 loss= 7.57759
Epoch: 0050 loss= 2.18238
Epoch: 0060 loss= 0.72347
Epoch: 0070 loss= 0.68237
Epoch: 0080 loss= 1.38942
Epoch: 0090 loss= 0.68676
Epoch: 0100 loss= 1.04547
Epoch: 0110 loss= 0.79572
Epoch: 0120 loss= 0.62468
Epoch: 0130 loss= 0.58597
Epoch: 0140 loss= 0.57311
Epoch: 0150 loss= 0.58915
Epoch: 0160 loss= 0.60219
Epoch: 0170 loss= 0.58091
Epoch: 0180 loss= 0.67272
Epoch: 0190 loss= 0.56651
Epoch: 0200 loss= 0.60999
Epoch: 0210 loss= 0.55475
Epoch: 0220 loss= 0.65925
Epoch: 0230 loss= 0.53609
Epoch: 0240 loss= 0.57257
Epoch: 0250 loss= 0.62324
Epoch: 0260 loss= 0.55636
Epoch: 0270 loss= 0.64556
Epoch: 0280 loss= 0.50953
Epoch: 0290 loss= 0.62862
Epoch: 0300 loss= 0.57529
Epoch: 0310 loss= 0.54115
Early stopping


  dm_free_dict = torch.load(os.path.join(self.ae_path, 'conditional_edm.pt'))


timestep:0, pyg_AUC: 0.4887
timestep:1, pyg_AUC: 0.4929
timestep:2, pyg_AUC: 0.4929
timestep:3, pyg_AUC: 0.4901
timestep:4, pyg_AUC: 0.4901
timestep:5, pyg_AUC: 0.4915
timestep:6, pyg_AUC: 0.4929
timestep:7, pyg_AUC: 0.4873
timestep:8, pyg_AUC: 0.4901
timestep:9, pyg_AUC: 0.4929
timestep:10, pyg_AUC: 0.4859
timestep:11, pyg_AUC: 0.4873
timestep:12, pyg_AUC: 0.4901
timestep:13, pyg_AUC: 0.4887
timestep:14, pyg_AUC: 0.4901
timestep:15, pyg_AUC: 0.4915
timestep:16, pyg_AUC: 0.4915
timestep:17, pyg_AUC: 0.4929
timestep:18, pyg_AUC: 0.4929
timestep:19, pyg_AUC: 0.4901
timestep:20, pyg_AUC: 0.4901
timestep:21, pyg_AUC: 0.4915
timestep:22, pyg_AUC: 0.4901
timestep:23, pyg_AUC: 0.4901
timestep:24, pyg_AUC: 0.4873
timestep:25, pyg_AUC: 0.4873
timestep:26, pyg_AUC: 0.4929
timestep:27, pyg_AUC: 0.4831
timestep:28, pyg_AUC: 0.4929
timestep:29, pyg_AUC: 0.4901
timestep:30, pyg_AUC: 0.4929
timestep:31, pyg_AUC: 0.4887
timestep:32, pyg_AUC: 0.4887
timestep:33, pyg_AUC: 0.4901
timestep:34, pyg_AUC: 0.

 55%|█████▌    | 11/20 [18:30<15:17, 101.95s/it]

timestep:499, pyg_AUC: 0.4915
Training diffusion model (unconditional) ...
Epoch: 0000 loss= 39.35904
Epoch: 0010 loss= 29.17165
Epoch: 0020 loss= 19.43969
Epoch: 0030 loss= 18.19963
Epoch: 0040 loss= 12.00288
Epoch: 0050 loss= 6.20477
Epoch: 0060 loss= 2.22260
Epoch: 0070 loss= 1.02518
Epoch: 0080 loss= 0.95035
Epoch: 0090 loss= 0.93086
Epoch: 0100 loss= 0.73464
Epoch: 0110 loss= 0.62667
Epoch: 0120 loss= 0.58846
Epoch: 0130 loss= 0.64864
Epoch: 0140 loss= 0.65306
Epoch: 0150 loss= 0.58918
Epoch: 0160 loss= 0.53883
Epoch: 0170 loss= 0.62951
Epoch: 0180 loss= 0.59401
Epoch: 0190 loss= 0.52517
Epoch: 0200 loss= 0.64239
Epoch: 0210 loss= 0.61863
Epoch: 0220 loss= 0.61715
Epoch: 0230 loss= 0.66073
Epoch: 0240 loss= 0.63735
Epoch: 0250 loss= 0.55691
Epoch: 0260 loss= 0.64955
Epoch: 0270 loss= 0.50248
Epoch: 0280 loss= 0.58078
Epoch: 0290 loss= 0.59196
Epoch: 0300 loss= 0.62263
Epoch: 0310 loss= 0.57856
Epoch: 0320 loss= 0.52537
Epoch: 0330 loss= 0.50644
Epoch: 0340 loss= 0.54605
Epoch: 035

  dm_dict = torch.load(os.path.join(self.ae_path, 'edm.pt'))


Epoch: 0030 loss= 12.92207
Epoch: 0040 loss= 3.58091
Epoch: 0050 loss= 1.32048
Epoch: 0060 loss= 1.11041
Epoch: 0070 loss= 0.70642
Epoch: 0080 loss= 0.66843
Epoch: 0090 loss= 0.89318
Epoch: 0100 loss= 0.70181
Epoch: 0110 loss= 0.67280
Epoch: 0120 loss= 0.64324
Epoch: 0130 loss= 0.51843
Epoch: 0140 loss= 0.62808
Epoch: 0150 loss= 0.66568
Epoch: 0160 loss= 0.56307
Epoch: 0170 loss= 0.63642
Epoch: 0180 loss= 0.71387
Epoch: 0190 loss= 0.51330
Epoch: 0200 loss= 0.54778
Epoch: 0210 loss= 0.64114
Epoch: 0220 loss= 0.59166
Epoch: 0230 loss= 0.54896
Epoch: 0240 loss= 0.60114
Epoch: 0250 loss= 0.56157
Epoch: 0260 loss= 0.56298
Epoch: 0270 loss= 0.56409
Epoch: 0280 loss= 0.61256
Epoch: 0290 loss= 0.55272
Epoch: 0300 loss= 0.55279
Epoch: 0310 loss= 0.57982
Epoch: 0320 loss= 0.55660
Epoch: 0330 loss= 0.49976
Epoch: 0340 loss= 0.61634
Epoch: 0350 loss= 0.52848
Epoch: 0360 loss= 0.64468
Epoch: 0370 loss= 0.62946
Epoch: 0380 loss= 0.55514
Epoch: 0390 loss= 0.60773
Epoch: 0400 loss= 0.52353
Epoch: 0410

  dm_free_dict = torch.load(os.path.join(self.ae_path, 'conditional_edm.pt'))


timestep:0, pyg_AUC: 0.4901
timestep:1, pyg_AUC: 0.4929
timestep:2, pyg_AUC: 0.4873
timestep:3, pyg_AUC: 0.4915
timestep:4, pyg_AUC: 0.4887
timestep:5, pyg_AUC: 0.4887
timestep:6, pyg_AUC: 0.4915
timestep:7, pyg_AUC: 0.4929
timestep:8, pyg_AUC: 0.4873
timestep:9, pyg_AUC: 0.4873
timestep:10, pyg_AUC: 0.4915
timestep:11, pyg_AUC: 0.4901
timestep:12, pyg_AUC: 0.4887
timestep:13, pyg_AUC: 0.4901
timestep:14, pyg_AUC: 0.4901
timestep:15, pyg_AUC: 0.4915
timestep:16, pyg_AUC: 0.4915
timestep:17, pyg_AUC: 0.4845
timestep:18, pyg_AUC: 0.4915
timestep:19, pyg_AUC: 0.4887
timestep:20, pyg_AUC: 0.4944
timestep:21, pyg_AUC: 0.4873
timestep:22, pyg_AUC: 0.4887
timestep:23, pyg_AUC: 0.4901
timestep:24, pyg_AUC: 0.4887
timestep:25, pyg_AUC: 0.4901
timestep:26, pyg_AUC: 0.4901
timestep:27, pyg_AUC: 0.4915
timestep:28, pyg_AUC: 0.4859
timestep:29, pyg_AUC: 0.4929
timestep:30, pyg_AUC: 0.4859
timestep:31, pyg_AUC: 0.4915
timestep:32, pyg_AUC: 0.4873
timestep:33, pyg_AUC: 0.4887
timestep:34, pyg_AUC: 0.

 60%|██████    | 12/20 [20:14<13:41, 102.63s/it]

timestep:499, pyg_AUC: 0.4901
Training diffusion model (unconditional) ...
Epoch: 0000 loss= 36.23397
Epoch: 0010 loss= 26.90117
Epoch: 0020 loss= 19.87727
Epoch: 0030 loss= 15.22818
Epoch: 0040 loss= 11.90355
Epoch: 0050 loss= 6.82045
Epoch: 0060 loss= 2.14924
Epoch: 0070 loss= 1.06985
Epoch: 0080 loss= 0.75501
Epoch: 0090 loss= 0.88825
Epoch: 0100 loss= 0.67324
Epoch: 0110 loss= 0.63991
Epoch: 0120 loss= 0.60692
Epoch: 0130 loss= 0.69814
Epoch: 0140 loss= 0.59540
Epoch: 0150 loss= 0.54465
Epoch: 0160 loss= 0.66323
Epoch: 0170 loss= 0.71100
Epoch: 0180 loss= 0.52948
Epoch: 0190 loss= 0.53572
Epoch: 0200 loss= 0.60856
Epoch: 0210 loss= 0.57210
Epoch: 0220 loss= 0.61477
Epoch: 0230 loss= 0.67811
Epoch: 0240 loss= 0.56939
Epoch: 0250 loss= 0.54642
Epoch: 0260 loss= 0.62267
Epoch: 0270 loss= 0.65880
Epoch: 0280 loss= 0.59095
Epoch: 0290 loss= 0.53844
Epoch: 0300 loss= 0.58203
Epoch: 0310 loss= 0.54537
Epoch: 0320 loss= 0.56969
Epoch: 0330 loss= 0.56328
Epoch: 0340 loss= 0.56745
Epoch: 035

  dm_dict = torch.load(os.path.join(self.ae_path, 'edm.pt'))


Epoch: 0030 loss= 4.25458
Epoch: 0040 loss= 1.27438
Epoch: 0050 loss= 0.87749
Epoch: 0060 loss= 0.80064
Epoch: 0070 loss= 0.68512
Epoch: 0080 loss= 1.12678
Epoch: 0090 loss= 0.67713
Epoch: 0100 loss= 0.65629
Epoch: 0110 loss= 0.66010
Epoch: 0120 loss= 0.59068
Epoch: 0130 loss= 0.67171
Epoch: 0140 loss= 0.59263
Epoch: 0150 loss= 0.66711
Epoch: 0160 loss= 0.67503
Epoch: 0170 loss= 0.55934
Epoch: 0180 loss= 0.54927
Epoch: 0190 loss= 0.59628
Epoch: 0200 loss= 0.65954
Epoch: 0210 loss= 0.64192
Epoch: 0220 loss= 0.56661
Epoch: 0230 loss= 0.62024
Early stopping


  dm_free_dict = torch.load(os.path.join(self.ae_path, 'conditional_edm.pt'))


timestep:0, pyg_AUC: 0.4859
timestep:1, pyg_AUC: 0.4901
timestep:2, pyg_AUC: 0.4873
timestep:3, pyg_AUC: 0.4887
timestep:4, pyg_AUC: 0.4873
timestep:5, pyg_AUC: 0.4887
timestep:6, pyg_AUC: 0.4901
timestep:7, pyg_AUC: 0.4859
timestep:8, pyg_AUC: 0.4859
timestep:9, pyg_AUC: 0.4887
timestep:10, pyg_AUC: 0.4901
timestep:11, pyg_AUC: 0.4887
timestep:12, pyg_AUC: 0.4887
timestep:13, pyg_AUC: 0.4859
timestep:14, pyg_AUC: 0.4887
timestep:15, pyg_AUC: 0.4873
timestep:16, pyg_AUC: 0.4901
timestep:17, pyg_AUC: 0.4901
timestep:18, pyg_AUC: 0.4887
timestep:19, pyg_AUC: 0.4915
timestep:20, pyg_AUC: 0.4901
timestep:21, pyg_AUC: 0.4944
timestep:22, pyg_AUC: 0.4915
timestep:23, pyg_AUC: 0.4901
timestep:24, pyg_AUC: 0.4901
timestep:25, pyg_AUC: 0.4873
timestep:26, pyg_AUC: 0.4887
timestep:27, pyg_AUC: 0.4887
timestep:28, pyg_AUC: 0.4887
timestep:29, pyg_AUC: 0.4915
timestep:30, pyg_AUC: 0.4915
timestep:31, pyg_AUC: 0.4873
timestep:32, pyg_AUC: 0.4901
timestep:33, pyg_AUC: 0.4915
timestep:34, pyg_AUC: 0.

 65%|██████▌   | 13/20 [21:57<11:59, 102.78s/it]

timestep:498, pyg_AUC: 0.4859
timestep:499, pyg_AUC: 0.4887
Training diffusion model (unconditional) ...
Epoch: 0000 loss= 40.11938
Epoch: 0010 loss= 24.70359
Epoch: 0020 loss= 19.34404
Epoch: 0030 loss= 15.07860
Epoch: 0040 loss= 14.36049
Epoch: 0050 loss= 5.95601
Epoch: 0060 loss= 2.74010
Epoch: 0070 loss= 1.84789
Epoch: 0080 loss= 0.80541
Epoch: 0090 loss= 0.71407
Epoch: 0100 loss= 0.71716
Epoch: 0110 loss= 0.60542
Epoch: 0120 loss= 0.61667
Epoch: 0130 loss= 0.61614
Epoch: 0140 loss= 0.67708
Epoch: 0150 loss= 0.64543
Epoch: 0160 loss= 0.60725
Epoch: 0170 loss= 0.59986
Epoch: 0180 loss= 0.60719
Epoch: 0190 loss= 0.61792
Epoch: 0200 loss= 0.60540
Epoch: 0210 loss= 0.59125
Epoch: 0220 loss= 0.60447
Epoch: 0230 loss= 0.61284
Epoch: 0240 loss= 0.51267
Epoch: 0250 loss= 0.55766
Epoch: 0260 loss= 0.55647
Epoch: 0270 loss= 0.53824
Epoch: 0280 loss= 0.53275
Epoch: 0290 loss= 0.59926
Epoch: 0300 loss= 0.55750
Epoch: 0310 loss= 0.55885
Epoch: 0320 loss= 0.53951
Epoch: 0330 loss= 0.52188
Epoch:

  dm_dict = torch.load(os.path.join(self.ae_path, 'edm.pt'))


Epoch: 0030 loss= 6.72635
Epoch: 0040 loss= 3.04723
Epoch: 0050 loss= 0.92867
Epoch: 0060 loss= 1.79798
Epoch: 0070 loss= 0.73109
Epoch: 0080 loss= 0.97270
Epoch: 0090 loss= 0.72456
Epoch: 0100 loss= 0.59531
Epoch: 0110 loss= 0.62335
Epoch: 0120 loss= 0.65345
Epoch: 0130 loss= 0.55997
Epoch: 0140 loss= 0.58046
Epoch: 0150 loss= 0.49292
Epoch: 0160 loss= 0.63118
Epoch: 0170 loss= 0.58131
Epoch: 0180 loss= 0.55873
Epoch: 0190 loss= 0.63147
Epoch: 0200 loss= 0.56926
Epoch: 0210 loss= 0.53257
Epoch: 0220 loss= 0.61724
Epoch: 0230 loss= 0.57747
Epoch: 0240 loss= 0.55292
Epoch: 0250 loss= 0.56004
Epoch: 0260 loss= 0.55808
Epoch: 0270 loss= 0.50165
Epoch: 0280 loss= 0.60063
Epoch: 0290 loss= 0.47494
Epoch: 0300 loss= 0.52866
Epoch: 0310 loss= 0.49927
Epoch: 0320 loss= 0.56499
Epoch: 0330 loss= 0.47230
Epoch: 0340 loss= 0.50826
Epoch: 0350 loss= 0.61077
Epoch: 0360 loss= 0.55989
Epoch: 0370 loss= 0.51098
Epoch: 0380 loss= 0.48933
Epoch: 0390 loss= 0.47726
Epoch: 0400 loss= 0.46712
Epoch: 0410 

  dm_free_dict = torch.load(os.path.join(self.ae_path, 'conditional_edm.pt'))


timestep:0, pyg_AUC: 0.4816
timestep:1, pyg_AUC: 0.4831
timestep:2, pyg_AUC: 0.4901
timestep:3, pyg_AUC: 0.4873
timestep:4, pyg_AUC: 0.4845
timestep:5, pyg_AUC: 0.4915
timestep:6, pyg_AUC: 0.4887
timestep:7, pyg_AUC: 0.4901
timestep:8, pyg_AUC: 0.4915
timestep:9, pyg_AUC: 0.4929
timestep:10, pyg_AUC: 0.4915
timestep:11, pyg_AUC: 0.4958
timestep:12, pyg_AUC: 0.4873
timestep:13, pyg_AUC: 0.4901
timestep:14, pyg_AUC: 0.4901
timestep:15, pyg_AUC: 0.4915
timestep:16, pyg_AUC: 0.4901
timestep:17, pyg_AUC: 0.4915
timestep:18, pyg_AUC: 0.4859
timestep:19, pyg_AUC: 0.4901
timestep:20, pyg_AUC: 0.4887
timestep:21, pyg_AUC: 0.4873
timestep:22, pyg_AUC: 0.4887
timestep:23, pyg_AUC: 0.4915
timestep:24, pyg_AUC: 0.4901
timestep:25, pyg_AUC: 0.4887
timestep:26, pyg_AUC: 0.4915
timestep:27, pyg_AUC: 0.4873
timestep:28, pyg_AUC: 0.4901
timestep:29, pyg_AUC: 0.4944
timestep:30, pyg_AUC: 0.4915
timestep:31, pyg_AUC: 0.4873
timestep:32, pyg_AUC: 0.4915
timestep:33, pyg_AUC: 0.4901
timestep:34, pyg_AUC: 0.

 70%|███████   | 14/20 [23:43<10:21, 103.60s/it]

timestep:498, pyg_AUC: 0.4915
timestep:499, pyg_AUC: 0.4915
Training diffusion model (unconditional) ...
Epoch: 0000 loss= 38.32571
Epoch: 0010 loss= 29.86969
Epoch: 0020 loss= 25.57851
Epoch: 0030 loss= 20.39316
Epoch: 0040 loss= 15.33101
Epoch: 0050 loss= 11.71732
Epoch: 0060 loss= 3.65759
Epoch: 0070 loss= 1.83816
Epoch: 0080 loss= 1.10079
Epoch: 0090 loss= 0.77652
Epoch: 0100 loss= 0.75840
Epoch: 0110 loss= 0.61433
Epoch: 0120 loss= 0.62306
Epoch: 0130 loss= 0.66919
Epoch: 0140 loss= 0.65748
Epoch: 0150 loss= 0.59054
Epoch: 0160 loss= 0.63096
Epoch: 0170 loss= 0.61209
Epoch: 0180 loss= 0.64359
Epoch: 0190 loss= 0.60465
Epoch: 0200 loss= 0.62956
Epoch: 0210 loss= 0.60971
Epoch: 0220 loss= 0.57378
Epoch: 0230 loss= 0.59157
Epoch: 0240 loss= 0.58462
Epoch: 0250 loss= 0.62646
Epoch: 0260 loss= 0.56112
Epoch: 0270 loss= 0.59263
Epoch: 0280 loss= 0.52960
Epoch: 0290 loss= 0.64433
Epoch: 0300 loss= 0.61016
Epoch: 0310 loss= 0.55779
Epoch: 0320 loss= 0.53568
Epoch: 0330 loss= 0.55813
Epoch

  dm_dict = torch.load(os.path.join(self.ae_path, 'edm.pt'))


Epoch: 0010 loss= 26.55460
Epoch: 0020 loss= 17.73634
Epoch: 0030 loss= 13.58685
Epoch: 0040 loss= 7.75981
Epoch: 0050 loss= 1.78495
Epoch: 0060 loss= 0.85992
Epoch: 0070 loss= 0.67967
Epoch: 0080 loss= 0.68469
Epoch: 0090 loss= 0.65749
Epoch: 0100 loss= 0.58378
Epoch: 0110 loss= 0.52620
Epoch: 0120 loss= 0.65414
Epoch: 0130 loss= 0.58833
Epoch: 0140 loss= 0.57408
Epoch: 0150 loss= 0.57252
Epoch: 0160 loss= 0.55539
Epoch: 0170 loss= 0.56246
Epoch: 0180 loss= 0.63358
Epoch: 0190 loss= 0.60719
Epoch: 0200 loss= 0.62587
Early stopping


  dm_free_dict = torch.load(os.path.join(self.ae_path, 'conditional_edm.pt'))


timestep:0, pyg_AUC: 0.4887
timestep:1, pyg_AUC: 0.4831
timestep:2, pyg_AUC: 0.4915
timestep:3, pyg_AUC: 0.4929
timestep:4, pyg_AUC: 0.4887
timestep:5, pyg_AUC: 0.5000
timestep:6, pyg_AUC: 0.4915
timestep:7, pyg_AUC: 0.4859
timestep:8, pyg_AUC: 0.4915
timestep:9, pyg_AUC: 0.4816
timestep:10, pyg_AUC: 0.4929
timestep:11, pyg_AUC: 0.4972
timestep:12, pyg_AUC: 0.4929
timestep:13, pyg_AUC: 0.4915
timestep:14, pyg_AUC: 0.4901
timestep:15, pyg_AUC: 0.4859
timestep:16, pyg_AUC: 0.4873
timestep:17, pyg_AUC: 0.4887
timestep:18, pyg_AUC: 0.4929
timestep:19, pyg_AUC: 0.4929
timestep:20, pyg_AUC: 0.4887
timestep:21, pyg_AUC: 0.4859
timestep:22, pyg_AUC: 0.4944
timestep:23, pyg_AUC: 0.4915
timestep:24, pyg_AUC: 0.4845
timestep:25, pyg_AUC: 0.4873
timestep:26, pyg_AUC: 0.4859
timestep:27, pyg_AUC: 0.4845
timestep:28, pyg_AUC: 0.4859
timestep:29, pyg_AUC: 0.4887
timestep:30, pyg_AUC: 0.4859
timestep:31, pyg_AUC: 0.4873
timestep:32, pyg_AUC: 0.4859
timestep:33, pyg_AUC: 0.4831
timestep:34, pyg_AUC: 0.

 75%|███████▌  | 15/20 [25:25<08:35, 103.08s/it]

timestep:498, pyg_AUC: 0.4845
timestep:499, pyg_AUC: 0.4859
Training diffusion model (unconditional) ...
Epoch: 0000 loss= 46.96055
Epoch: 0010 loss= 27.72296
Epoch: 0020 loss= 18.71268
Epoch: 0030 loss= 16.67502
Epoch: 0040 loss= 12.09468
Epoch: 0050 loss= 7.14389
Epoch: 0060 loss= 2.99753
Epoch: 0070 loss= 0.69929
Epoch: 0080 loss= 0.64417
Epoch: 0090 loss= 0.76861
Epoch: 0100 loss= 0.65354
Epoch: 0110 loss= 0.62975
Epoch: 0120 loss= 0.66175
Epoch: 0130 loss= 0.62493
Epoch: 0140 loss= 0.59529
Epoch: 0150 loss= 0.57452
Epoch: 0160 loss= 0.60846
Epoch: 0170 loss= 0.61542
Epoch: 0180 loss= 0.67766
Epoch: 0190 loss= 0.63992
Epoch: 0200 loss= 0.61400
Epoch: 0210 loss= 0.59975
Epoch: 0220 loss= 0.59754
Epoch: 0230 loss= 0.57064
Epoch: 0240 loss= 0.53588
Epoch: 0250 loss= 0.58023
Epoch: 0260 loss= 0.61076
Epoch: 0270 loss= 0.57571
Epoch: 0280 loss= 0.53753
Epoch: 0290 loss= 0.54681
Epoch: 0300 loss= 0.55722
Epoch: 0310 loss= 0.58031
Epoch: 0320 loss= 0.53478
Epoch: 0330 loss= 0.50970
Epoch:

  dm_dict = torch.load(os.path.join(self.ae_path, 'edm.pt'))


Epoch: 0020 loss= 13.65278
Epoch: 0030 loss= 9.15199
Epoch: 0040 loss= 3.73097
Epoch: 0050 loss= 1.14531
Epoch: 0060 loss= 0.90801
Epoch: 0070 loss= 0.67545
Epoch: 0080 loss= 0.68483
Epoch: 0090 loss= 0.62712
Epoch: 0100 loss= 0.61185
Epoch: 0110 loss= 0.56833
Epoch: 0120 loss= 0.61043
Epoch: 0130 loss= 0.63586
Epoch: 0140 loss= 0.52281
Epoch: 0150 loss= 0.64929
Epoch: 0160 loss= 0.61932
Epoch: 0170 loss= 0.57780
Epoch: 0180 loss= 0.53643
Epoch: 0190 loss= 0.56999
Epoch: 0200 loss= 0.48895
Epoch: 0210 loss= 0.58098
Epoch: 0220 loss= 0.51550
Epoch: 0230 loss= 0.54480
Epoch: 0240 loss= 0.62267
Epoch: 0250 loss= 0.59923
Epoch: 0260 loss= 0.57330
Epoch: 0270 loss= 0.51541
Epoch: 0280 loss= 0.54510
Epoch: 0290 loss= 0.61283
Epoch: 0300 loss= 0.58172
Early stopping


  dm_free_dict = torch.load(os.path.join(self.ae_path, 'conditional_edm.pt'))


timestep:0, pyg_AUC: 0.4901
timestep:1, pyg_AUC: 0.4873
timestep:2, pyg_AUC: 0.4901
timestep:3, pyg_AUC: 0.4859
timestep:4, pyg_AUC: 0.4831
timestep:5, pyg_AUC: 0.4901
timestep:6, pyg_AUC: 0.4915
timestep:7, pyg_AUC: 0.4915
timestep:8, pyg_AUC: 0.4901
timestep:9, pyg_AUC: 0.4915
timestep:10, pyg_AUC: 0.4915
timestep:11, pyg_AUC: 0.4887
timestep:12, pyg_AUC: 0.4929
timestep:13, pyg_AUC: 0.4901
timestep:14, pyg_AUC: 0.4915
timestep:15, pyg_AUC: 0.4901
timestep:16, pyg_AUC: 0.4929
timestep:17, pyg_AUC: 0.4929
timestep:18, pyg_AUC: 0.4915
timestep:19, pyg_AUC: 0.4887
timestep:20, pyg_AUC: 0.4859
timestep:21, pyg_AUC: 0.4859
timestep:22, pyg_AUC: 0.4887
timestep:23, pyg_AUC: 0.4901
timestep:24, pyg_AUC: 0.4901
timestep:25, pyg_AUC: 0.4915
timestep:26, pyg_AUC: 0.4816
timestep:27, pyg_AUC: 0.4887
timestep:28, pyg_AUC: 0.4845
timestep:29, pyg_AUC: 0.4901
timestep:30, pyg_AUC: 0.4901
timestep:31, pyg_AUC: 0.4929
timestep:32, pyg_AUC: 0.4901
timestep:33, pyg_AUC: 0.4901
timestep:34, pyg_AUC: 0.

 80%|████████  | 16/20 [27:06<06:50, 102.68s/it]

timestep:499, pyg_AUC: 0.4915
Training diffusion model (unconditional) ...
Epoch: 0000 loss= 38.02254
Epoch: 0010 loss= 27.99958
Epoch: 0020 loss= 19.41045
Epoch: 0030 loss= 16.66020
Epoch: 0040 loss= 14.44121
Epoch: 0050 loss= 4.17237
Epoch: 0060 loss= 1.21599
Epoch: 0070 loss= 0.84787
Epoch: 0080 loss= 0.61324
Epoch: 0090 loss= 0.64598
Epoch: 0100 loss= 0.61378
Epoch: 0110 loss= 0.63308
Epoch: 0120 loss= 0.65770
Epoch: 0130 loss= 0.59138
Epoch: 0140 loss= 0.66857
Epoch: 0150 loss= 0.62097
Epoch: 0160 loss= 0.54267
Epoch: 0170 loss= 0.59196
Epoch: 0180 loss= 0.69642
Epoch: 0190 loss= 0.57597
Epoch: 0200 loss= 0.59372
Epoch: 0210 loss= 0.56244
Epoch: 0220 loss= 0.59800
Epoch: 0230 loss= 0.66805
Epoch: 0240 loss= 0.63774
Epoch: 0250 loss= 0.59312
Epoch: 0260 loss= 0.52820
Epoch: 0270 loss= 0.58407
Epoch: 0280 loss= 0.56977
Early stopping
Common feature: tensor([[-4.7623,  4.6843, -5.3228, -4.6303, -4.3717,  5.5789,  5.4012, -5.1959]],
       device='cuda:0')
Training diffusion model (co

  dm_dict = torch.load(os.path.join(self.ae_path, 'edm.pt'))


Epoch: 0010 loss= 28.25689
Epoch: 0020 loss= 22.48223
Epoch: 0030 loss= 11.13611
Epoch: 0040 loss= 5.92358
Epoch: 0050 loss= 0.92838
Epoch: 0060 loss= 0.90961
Epoch: 0070 loss= 0.68554
Epoch: 0080 loss= 0.69580
Epoch: 0090 loss= 0.74024
Epoch: 0100 loss= 0.66947
Epoch: 0110 loss= 0.57658
Epoch: 0120 loss= 0.63795
Epoch: 0130 loss= 0.54824
Epoch: 0140 loss= 0.65746
Epoch: 0150 loss= 0.64628
Epoch: 0160 loss= 0.57223
Epoch: 0170 loss= 0.62112
Epoch: 0180 loss= 0.59720
Epoch: 0190 loss= 0.54443
Epoch: 0200 loss= 0.59654
Epoch: 0210 loss= 0.54539
Epoch: 0220 loss= 0.52675
Epoch: 0230 loss= 0.60256
Epoch: 0240 loss= 0.47855
Epoch: 0250 loss= 0.63726
Epoch: 0260 loss= 0.59217
Epoch: 0270 loss= 0.53564
Epoch: 0280 loss= 0.62018
Epoch: 0290 loss= 0.58246
Epoch: 0300 loss= 0.56762
Epoch: 0310 loss= 0.57211
Epoch: 0320 loss= 0.55664
Epoch: 0330 loss= 0.58422
Early stopping


  dm_free_dict = torch.load(os.path.join(self.ae_path, 'conditional_edm.pt'))


timestep:0, pyg_AUC: 0.4915
timestep:1, pyg_AUC: 0.4901
timestep:2, pyg_AUC: 0.4901
timestep:3, pyg_AUC: 0.4873
timestep:4, pyg_AUC: 0.4887
timestep:5, pyg_AUC: 0.4887
timestep:6, pyg_AUC: 0.4915
timestep:7, pyg_AUC: 0.4958
timestep:8, pyg_AUC: 0.4915
timestep:9, pyg_AUC: 0.4929
timestep:10, pyg_AUC: 0.4901
timestep:11, pyg_AUC: 0.4901
timestep:12, pyg_AUC: 0.4816
timestep:13, pyg_AUC: 0.4887
timestep:14, pyg_AUC: 0.4972
timestep:15, pyg_AUC: 0.4845
timestep:16, pyg_AUC: 0.4845
timestep:17, pyg_AUC: 0.4845
timestep:18, pyg_AUC: 0.4859
timestep:19, pyg_AUC: 0.4873
timestep:20, pyg_AUC: 0.4859
timestep:21, pyg_AUC: 0.4831
timestep:22, pyg_AUC: 0.4944
timestep:23, pyg_AUC: 0.4859
timestep:24, pyg_AUC: 0.4873
timestep:25, pyg_AUC: 0.4887
timestep:26, pyg_AUC: 0.4901
timestep:27, pyg_AUC: 0.4873
timestep:28, pyg_AUC: 0.4831
timestep:29, pyg_AUC: 0.4972
timestep:30, pyg_AUC: 0.4958
timestep:31, pyg_AUC: 0.4831
timestep:32, pyg_AUC: 0.4873
timestep:33, pyg_AUC: 0.4915
timestep:34, pyg_AUC: 0.

 85%|████████▌ | 17/20 [28:48<05:06, 102.19s/it]

timestep:499, pyg_AUC: 0.4859
Training diffusion model (unconditional) ...
Epoch: 0000 loss= 35.14841
Epoch: 0010 loss= 34.99444
Epoch: 0020 loss= 22.03244
Epoch: 0030 loss= 16.98173
Epoch: 0040 loss= 13.87246
Epoch: 0050 loss= 6.01774
Epoch: 0060 loss= 1.51885
Epoch: 0070 loss= 0.91288
Epoch: 0080 loss= 0.66664
Epoch: 0090 loss= 0.97027
Epoch: 0100 loss= 0.68271
Epoch: 0110 loss= 0.57238
Epoch: 0120 loss= 0.62101
Epoch: 0130 loss= 0.65714
Epoch: 0140 loss= 0.60411
Epoch: 0150 loss= 0.53573
Epoch: 0160 loss= 0.67983
Epoch: 0170 loss= 0.61286
Epoch: 0180 loss= 0.58674
Epoch: 0190 loss= 0.58891
Epoch: 0200 loss= 0.53657
Epoch: 0210 loss= 0.53797
Epoch: 0220 loss= 0.51880
Epoch: 0230 loss= 0.53463
Epoch: 0240 loss= 0.52000
Epoch: 0250 loss= 0.56848
Epoch: 0260 loss= 0.50461
Epoch: 0270 loss= 0.54934
Epoch: 0280 loss= 0.54307
Epoch: 0290 loss= 0.56943
Epoch: 0300 loss= 0.53307
Epoch: 0310 loss= 0.45466
Epoch: 0320 loss= 0.48800
Epoch: 0330 loss= 0.53185
Epoch: 0340 loss= 0.54679
Epoch: 035

  dm_dict = torch.load(os.path.join(self.ae_path, 'edm.pt'))


Epoch: 0010 loss= 25.78571
Epoch: 0020 loss= 13.88853
Epoch: 0030 loss= 3.97133
Epoch: 0040 loss= 1.88572
Epoch: 0050 loss= 1.00469
Epoch: 0060 loss= 0.73359
Epoch: 0070 loss= 0.68733
Epoch: 0080 loss= 0.67551
Epoch: 0090 loss= 0.58386
Epoch: 0100 loss= 0.60935
Epoch: 0110 loss= 0.64305
Epoch: 0120 loss= 0.59034
Epoch: 0130 loss= 0.65237
Epoch: 0140 loss= 0.60386
Epoch: 0150 loss= 0.55887
Epoch: 0160 loss= 0.62363
Epoch: 0170 loss= 0.61377
Epoch: 0180 loss= 0.72035
Epoch: 0190 loss= 0.61290
Epoch: 0200 loss= 0.62514
Epoch: 0210 loss= 0.55095
Epoch: 0220 loss= 0.56740
Epoch: 0230 loss= 0.57327
Epoch: 0240 loss= 0.57332
Epoch: 0250 loss= 0.56735
Epoch: 0260 loss= 0.57198
Early stopping


  dm_free_dict = torch.load(os.path.join(self.ae_path, 'conditional_edm.pt'))


timestep:0, pyg_AUC: 0.4929
timestep:1, pyg_AUC: 0.4887
timestep:2, pyg_AUC: 0.4873
timestep:3, pyg_AUC: 0.4831
timestep:4, pyg_AUC: 0.4859
timestep:5, pyg_AUC: 0.4901
timestep:6, pyg_AUC: 0.4887
timestep:7, pyg_AUC: 0.4873
timestep:8, pyg_AUC: 0.4873
timestep:9, pyg_AUC: 0.4901
timestep:10, pyg_AUC: 0.4887
timestep:11, pyg_AUC: 0.4873
timestep:12, pyg_AUC: 0.4944
timestep:13, pyg_AUC: 0.4887
timestep:14, pyg_AUC: 0.4915
timestep:15, pyg_AUC: 0.4915
timestep:16, pyg_AUC: 0.4915
timestep:17, pyg_AUC: 0.4944
timestep:18, pyg_AUC: 0.4944
timestep:19, pyg_AUC: 0.4873
timestep:20, pyg_AUC: 0.4901
timestep:21, pyg_AUC: 0.4901
timestep:22, pyg_AUC: 0.4901
timestep:23, pyg_AUC: 0.4887
timestep:24, pyg_AUC: 0.4915
timestep:25, pyg_AUC: 0.4958
timestep:26, pyg_AUC: 0.4873
timestep:27, pyg_AUC: 0.4859
timestep:28, pyg_AUC: 0.4901
timestep:29, pyg_AUC: 0.4915
timestep:30, pyg_AUC: 0.4859
timestep:31, pyg_AUC: 0.4887
timestep:32, pyg_AUC: 0.4915
timestep:33, pyg_AUC: 0.4929
timestep:34, pyg_AUC: 0.

 90%|█████████ | 18/20 [30:29<03:23, 101.97s/it]

timestep:498, pyg_AUC: 0.4915
timestep:499, pyg_AUC: 0.4887
Training diffusion model (unconditional) ...
Epoch: 0000 loss= 43.06204
Epoch: 0010 loss= 27.59716
Epoch: 0020 loss= 21.59381
Epoch: 0030 loss= 18.75920
Epoch: 0040 loss= 10.48564
Epoch: 0050 loss= 6.08166
Epoch: 0060 loss= 1.56193
Epoch: 0070 loss= 0.83657
Epoch: 0080 loss= 0.63943
Epoch: 0090 loss= 0.56184
Epoch: 0100 loss= 0.54279
Epoch: 0110 loss= 0.58262
Epoch: 0120 loss= 0.56408
Epoch: 0130 loss= 0.61578
Epoch: 0140 loss= 0.56002
Epoch: 0150 loss= 0.57404
Epoch: 0160 loss= 0.62803
Epoch: 0170 loss= 0.60751
Epoch: 0180 loss= 0.56756
Epoch: 0190 loss= 0.57722
Epoch: 0200 loss= 0.66611
Epoch: 0210 loss= 0.55400
Epoch: 0220 loss= 0.59936
Epoch: 0230 loss= 0.48546
Epoch: 0240 loss= 0.61613
Epoch: 0250 loss= 0.58629
Epoch: 0260 loss= 0.56399
Epoch: 0270 loss= 0.57619
Epoch: 0280 loss= 0.57528
Epoch: 0290 loss= 0.62420
Epoch: 0300 loss= 0.53195
Epoch: 0310 loss= 0.56362
Early stopping
Common feature: tensor([[-4.8256,  4.7612, 

  dm_dict = torch.load(os.path.join(self.ae_path, 'edm.pt'))


Epoch: 0010 loss= 29.35122
Epoch: 0020 loss= 17.37148
Epoch: 0030 loss= 13.07904
Epoch: 0040 loss= 8.53599
Epoch: 0050 loss= 1.63882
Epoch: 0060 loss= 0.75722
Epoch: 0070 loss= 0.68202
Epoch: 0080 loss= 0.68124
Epoch: 0090 loss= 0.59762
Epoch: 0100 loss= 0.57713
Epoch: 0110 loss= 0.63347
Epoch: 0120 loss= 0.60384
Epoch: 0130 loss= 0.56794
Epoch: 0140 loss= 0.57344
Epoch: 0150 loss= 0.51956
Epoch: 0160 loss= 0.60322
Epoch: 0170 loss= 0.52328
Epoch: 0180 loss= 0.55711
Epoch: 0190 loss= 0.52471
Epoch: 0200 loss= 0.53583
Epoch: 0210 loss= 0.63664
Epoch: 0220 loss= 0.60140
Epoch: 0230 loss= 0.52650
Early stopping


  dm_free_dict = torch.load(os.path.join(self.ae_path, 'conditional_edm.pt'))


timestep:0, pyg_AUC: 0.4859
timestep:1, pyg_AUC: 0.4873
timestep:2, pyg_AUC: 0.4915
timestep:3, pyg_AUC: 0.4873
timestep:4, pyg_AUC: 0.4873
timestep:5, pyg_AUC: 0.4915
timestep:6, pyg_AUC: 0.4915
timestep:7, pyg_AUC: 0.4887
timestep:8, pyg_AUC: 0.4901
timestep:9, pyg_AUC: 0.4845
timestep:10, pyg_AUC: 0.4873
timestep:11, pyg_AUC: 0.4887
timestep:12, pyg_AUC: 0.4901
timestep:13, pyg_AUC: 0.4873
timestep:14, pyg_AUC: 0.4816
timestep:15, pyg_AUC: 0.4845
timestep:16, pyg_AUC: 0.4887
timestep:17, pyg_AUC: 0.4887
timestep:18, pyg_AUC: 0.4873
timestep:19, pyg_AUC: 0.4873
timestep:20, pyg_AUC: 0.4901
timestep:21, pyg_AUC: 0.4873
timestep:22, pyg_AUC: 0.4901
timestep:23, pyg_AUC: 0.4887
timestep:24, pyg_AUC: 0.4873
timestep:25, pyg_AUC: 0.4873
timestep:26, pyg_AUC: 0.4929
timestep:27, pyg_AUC: 0.4845
timestep:28, pyg_AUC: 0.4915
timestep:29, pyg_AUC: 0.4915
timestep:30, pyg_AUC: 0.4915
timestep:31, pyg_AUC: 0.4873
timestep:32, pyg_AUC: 0.4901
timestep:33, pyg_AUC: 0.4873
timestep:34, pyg_AUC: 0.

 95%|█████████▌| 19/20 [32:08<01:41, 101.12s/it]

timestep:498, pyg_AUC: 0.4901
timestep:499, pyg_AUC: 0.4915
Training diffusion model (unconditional) ...
Epoch: 0000 loss= 34.92589
Epoch: 0010 loss= 23.20242
Epoch: 0020 loss= 21.63348
Epoch: 0030 loss= 18.02320
Epoch: 0040 loss= 17.32604
Epoch: 0050 loss= 8.92889
Epoch: 0060 loss= 2.53881
Epoch: 0070 loss= 1.02574
Epoch: 0080 loss= 0.97408
Epoch: 0090 loss= 0.69554
Epoch: 0100 loss= 0.62242
Epoch: 0110 loss= 0.67575
Epoch: 0120 loss= 0.65921
Epoch: 0130 loss= 0.69209
Epoch: 0140 loss= 0.62293
Epoch: 0150 loss= 0.58934
Epoch: 0160 loss= 0.56748
Epoch: 0170 loss= 0.64608
Epoch: 0180 loss= 0.64471
Epoch: 0190 loss= 0.58903
Epoch: 0200 loss= 0.54537
Epoch: 0210 loss= 0.55741
Epoch: 0220 loss= 0.58101
Epoch: 0230 loss= 0.56812
Epoch: 0240 loss= 0.63996
Epoch: 0250 loss= 0.56430
Epoch: 0260 loss= 0.59395
Epoch: 0270 loss= 0.55968
Epoch: 0280 loss= 0.63170
Epoch: 0290 loss= 0.50627
Epoch: 0300 loss= 0.63401
Epoch: 0310 loss= 0.49713
Epoch: 0320 loss= 0.53002
Epoch: 0330 loss= 0.49404
Epoch:

  dm_dict = torch.load(os.path.join(self.ae_path, 'edm.pt'))


Epoch: 0030 loss= 6.09505
Epoch: 0040 loss= 2.86004
Epoch: 0050 loss= 1.28045
Epoch: 0060 loss= 0.86548
Epoch: 0070 loss= 0.82537
Epoch: 0080 loss= 0.69127
Epoch: 0090 loss= 0.62643
Epoch: 0100 loss= 0.62424
Epoch: 0110 loss= 0.60933
Epoch: 0120 loss= 0.56010
Epoch: 0130 loss= 0.60436
Epoch: 0140 loss= 0.62943
Epoch: 0150 loss= 0.59407
Epoch: 0160 loss= 0.60490
Epoch: 0170 loss= 0.60129
Epoch: 0180 loss= 0.56619
Epoch: 0190 loss= 0.52025
Epoch: 0200 loss= 0.61643
Epoch: 0210 loss= 0.54675
Epoch: 0220 loss= 0.60639
Epoch: 0230 loss= 0.59419
Epoch: 0240 loss= 0.57819
Epoch: 0250 loss= 0.54647
Epoch: 0260 loss= 0.54994
Epoch: 0270 loss= 0.57956
Epoch: 0280 loss= 0.58097
Epoch: 0290 loss= 0.52719
Epoch: 0300 loss= 0.59527
Epoch: 0310 loss= 0.54600
Epoch: 0320 loss= 0.58268
Early stopping


  dm_free_dict = torch.load(os.path.join(self.ae_path, 'conditional_edm.pt'))


timestep:0, pyg_AUC: 0.4816
timestep:1, pyg_AUC: 0.4887
timestep:2, pyg_AUC: 0.4915
timestep:3, pyg_AUC: 0.4915
timestep:4, pyg_AUC: 0.4859
timestep:5, pyg_AUC: 0.4901
timestep:6, pyg_AUC: 0.4845
timestep:7, pyg_AUC: 0.4859
timestep:8, pyg_AUC: 0.4929
timestep:9, pyg_AUC: 0.4901
timestep:10, pyg_AUC: 0.4887
timestep:11, pyg_AUC: 0.4944
timestep:12, pyg_AUC: 0.4845
timestep:13, pyg_AUC: 0.4915
timestep:14, pyg_AUC: 0.4901
timestep:15, pyg_AUC: 0.4873
timestep:16, pyg_AUC: 0.4901
timestep:17, pyg_AUC: 0.4887
timestep:18, pyg_AUC: 0.4887
timestep:19, pyg_AUC: 0.4859
timestep:20, pyg_AUC: 0.4831
timestep:21, pyg_AUC: 0.4929
timestep:22, pyg_AUC: 0.4929
timestep:23, pyg_AUC: 0.4859
timestep:24, pyg_AUC: 0.4873
timestep:25, pyg_AUC: 0.4859
timestep:26, pyg_AUC: 0.4929
timestep:27, pyg_AUC: 0.4915
timestep:28, pyg_AUC: 0.4845
timestep:29, pyg_AUC: 0.4887
timestep:30, pyg_AUC: 0.4901
timestep:31, pyg_AUC: 0.4901
timestep:32, pyg_AUC: 0.4915
timestep:33, pyg_AUC: 0.4901
timestep:34, pyg_AUC: 0.

100%|██████████| 20/20 [33:48<00:00, 101.43s/it]

timestep:498, pyg_AUC: 0.4859
timestep:499, pyg_AUC: 0.4887



