In [13]:
import os
import math
import time
import numpy as np
import pandas as pd


import dataclasses
from typing import Sequence
import functools
from typing import Tuple  # Add this line to import Tuple
from torch import optim
#import pytorch_warmup as warmup

# Pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split, Dataset, DataLoader

# Pytorch Lightening
import pytorch_lightning as pl

### _Diffusion/Noising Process_

In [14]:
# Positional Embedding, Time Sampling and DDPM Process
from model import SinusoidalPositionalEmbeddings, UniformDiscreteTimeSampler, DDPMProcess

### _U-Net_

In [15]:
# Residual Join, Downsampling, Upsampling, ConvBlock and ResnetBlock
from model import Residual, DownsampleConv, UpsampleConv, ConvBlock, ResnetBlock

In [20]:
from model import Unet1D

## _Final Diffusion Model_

In [18]:
class DiffusionModel(nn.Module):
    """Diffusion model with 1D Convolutional network for SNP data."""

    def __init__(self, diffusion_process, time_sampler, net_config, data_shape):
        super(DiffusionModel, self).__init__()
        self._process = diffusion_process
        self._time_sampler = time_sampler
        self._net_config = net_config
        self._data_shape = data_shape
        self.net_fwd = Net(net_config)  # Uses Net with ResidualConv1D

    def loss(self, x0: torch.Tensor) -> torch.Tensor:
        """
        Computes MSE between true noise and predicted noise.
        The network's goal is to correctly predict noise (eps) from noisy observations.

        Args:
            x0 (torch.Tensor): Original clean input data (batch_size, seq_len)

        Returns:
            torch.Tensor: MSE loss
        """
        t = self._time_sampler.sample(shape=(x0.shape[0],))  # Sample time
        eps = torch.randn_like(x0, device=x0.device)         # Sample noise
        xt = self._process.sample(x0, t, eps)                # Corrupt the data
        net_outputs = self.net_fwd(xt, t)             # Pass through Conv1D model
        loss = torch.mean((net_outputs - eps) ** 2)          # Compute MSE loss
        return loss

    def loss_per_timesteps(self, x0: torch.Tensor, eps: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
        """
        Computes loss at specific timesteps.

        Args:
            x0 (torch.Tensor): Original clean input data.
            eps (torch.Tensor): Sampled noise.
            timesteps (torch.Tensor): Selected timesteps.

        Returns:
            torch.Tensor: Loss values for each timestep.
        """
        losses = []
        for t in timesteps:
            t = int(t.item()) * torch.ones((x0.shape[0],), dtype=torch.int32, device=x0.device)
            xt = self._process.sample(x0, t, eps)
            net_outputs = self.net_fwd(xt, t)
            loss = torch.mean((net_outputs - eps) ** 2)
            losses.append(loss)
        return torch.stack(losses)

    def _reverse_process_step(self, xt: torch.Tensor, t: int) -> torch.Tensor:
        """
        Reverse diffusion step to estimate x_{t-1} given x_t.

        Args:
            xt (torch.Tensor): Noisy input at time t.
            t (int): Current timestep.

        Returns:
            torch.Tensor: Estimated previous timestep data.
        """
        t = t * torch.ones((xt.shape[0],), dtype=torch.int32, device=xt.device)
        eps_pred = self.net_fwd(xt, t)  # Predict epsilon
        sqrt_a_t = self._process.alpha(t) / self._process.alpha(t - 1)
        inv_sqrt_a_t = 1.0 / sqrt_a_t
        beta_t = 1.0 - sqrt_a_t ** 2
        inv_sigma_t = 1.0 / self._process.sigma(t)
        mean = inv_sqrt_a_t * (xt - beta_t * inv_sigma_t * eps_pred)
        std = torch.sqrt(beta_t)
        z = torch.randn_like(xt)
        return mean + std * z


    def sample(self, x0, sample_size):
        """
        Samples from the learned reverse diffusion process without conditioning.
    
        Args:
            x0 (torch.Tensor): Initial input (not used, only for device reference).
            sample_size (int): Number of samples.
    
        Returns:
            torch.Tensor: Generated samples.
        """
        with torch.no_grad():
            x = torch.randn((sample_size,) + self._data_shape, device=x0.device)
            for t in range(self._process.tmax, 0, -1):
                x = self._reverse_process_step(x, t)  
        return x

## _Instantiating_

In [19]:
# create the model
diffusion_process = DiscreteDDPMProcess(num_diffusion_timesteps=1000)
time_sampler = UniformDiscreteTimeSampler(diffusion_process.tmin, diffusion_process.tmax)
model = DiffusionModel(diffusion_process, time_sampler, net_config=NetConfig(), data_shape=(6,))

NameError: name 'DiscreteDDPMProcess' is not defined