In [1]:
import os
import math
import time
import numpy as np
import pandas as pd

import pyarrow
import fastparquet

# Pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split, Dataset, DataLoader

# Pytorch Lightening
import pytorch_lightning as pl

# For Visualization
import seaborn as sns 
import matplotlib.pyplot as plt
%matplotlib inline     
sns.set(color_codes=True)

In [2]:
# load dataset
def load_data(input_path="", missing_fill=None):
    """Load a genotype table from Parquet and scale the dosage codes to [0, 1].

    Dosage codes are mapped 0 -> 0.0, 1 -> 0.5, 2 -> 1.0 (i.e. divided by 2).
    Any other value (this dataset also contains the code 9) is kept unchanged
    unless ``missing_fill`` is given, in which case it is replaced by that value.

    Args:
        input_path: Path to the Parquet file read with ``pd.read_parquet``.
        missing_fill: Optional value used for every entry that is not 0, 1 or 2
            (including nulls). ``None`` (default) leaves those entries as-is.

    Returns:
        A ``float32`` ``torch.Tensor`` with the same shape as the table.
    """
    # Nullable integer columns become NaN instead of pd.NA objects.
    codes = pd.read_parquet(input_path).to_numpy(dtype=np.float64, na_value=np.nan)

    # A single mask computed from the raw codes: the result cannot depend on the
    # order in which individual replacements are applied.
    is_dosage = np.isin(codes, (0.0, 1.0, 2.0))
    other = codes if missing_fill is None else missing_fill
    scaled = np.where(is_dosage, codes / 2.0, other)

    return torch.as_tensor(scaled, dtype=torch.float32)

In [3]:
input_file = (
    "HO_data/HO_data_filtered/HumanOrigins2067_filtered.parquet"  # filtered HumanOrigins table, relative path
)

In [4]:
# Load the whole normalized table once (float32 tensor) for inspection below.
snp_data = load_data(input_path=input_file)

In [5]:
snp_data.size()  # same torch.Size as .shape

torch.Size([160858, 2067])

In [6]:
# Verification: after normalization only 0.0 / 0.5 / 1.0 should remain, plus 9.0,
# a code that load_data leaves unchanged (presumably the EIGENSTRAT
# missing-genotype code — TODO confirm against the data source).
unique_values = np.unique(snp_data.numpy())
print("Unique values after normalization:", unique_values)  # Should show [0.0, 0.5, 1.0, 9.0]
unexpected_values = set(unique_values.tolist()) - {0.0, 0.5, 1.0, 9.0}
assert not unexpected_values, f"Unexpected values after normalization: {sorted(unexpected_values)}"

Unique values after normalization: [0.  0.5 1.  9. ]


### _LightningDataModule_

In [7]:
class SNPDataModule(pl.LightningDataModule):
    """Serve reproducible train/val/test splits of the normalized SNP table.

    Args:
        input_path: Parquet file passed to ``load_data``.
        batch_size: Batch size used by every DataLoader.
        num_workers: Worker processes used by every DataLoader.
        split_fractions: ``(train, val, test)`` fractions; must sum to 1. The val and
            test sizes are rounded and train takes the remainder, so the lengths always
            sum to the dataset size (80/10/10 on the 160858-row table gives
            128686 / 16086 / 16086).
        seed: Seed of the generator used by ``random_split``.

    Raises:
        ValueError: If ``split_fractions`` is not three values summing to 1.
    """

    def __init__(self, input_path, batch_size=256, num_workers=1,
                 split_fractions=(0.8, 0.1, 0.1), seed=42):
        super().__init__()
        if len(split_fractions) != 3 or not math.isclose(sum(split_fractions), 1.0):
            raise ValueError(
                f"split_fractions must be 3 values summing to 1, got {split_fractions}"
            )
        self.path = input_path
        self.batch_size = batch_size
        self.workers = num_workers
        self.split_fractions = tuple(split_fractions)
        self.seed = seed
        # Filled by setup(): [train, val, test] lengths and the three subsets.
        self.data_split = None
        self.trainset = self.valset = self.testset = None

    # Setup Data
    def setup(self, stage=None):
        """Load the data once and split it reproducibly into train/val/test."""
        if self.trainset is not None:
            # Lightning calls setup() once per stage; avoid re-reading the file.
            return
        full_dataset = load_data(self.path)
        n_total = len(full_dataset)
        _, val_frac, test_frac = self.split_fractions
        n_val = round(n_total * val_frac)
        n_test = round(n_total * test_frac)
        self.data_split = [n_total - n_val - n_test, n_val, n_test]
        self.trainset, self.valset, self.testset = random_split(
            full_dataset,
            self.data_split,
            generator=torch.Generator().manual_seed(self.seed),  # fixed seed for reproducibility
        )

    # Data Loaders
    def _loader(self, dataset, shuffle):
        """Build a DataLoader with this module's batch size and worker count."""
        if dataset is None:
            raise RuntimeError("setup() must be called before requesting a DataLoader")
        # pin_memory=True / persistent_workers=True can be added here for GPU training.
        return DataLoader(
            dataset, batch_size=self.batch_size, shuffle=shuffle, num_workers=self.workers
        )

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self._loader(self.trainset, shuffle=True)

    def val_dataloader(self):
        """Ordered loader over the validation split."""
        return self._loader(self.valset, shuffle=False)

    def test_dataloader(self):
        """Ordered loader over the test split."""
        return self._loader(self.testset, shuffle=False)

In [8]:
# Build the data module: batches of 256 rows, one loader worker.
snp_data_module = SNPDataModule(input_file, num_workers=1, batch_size=256)

In [9]:
# Load the table and create the seeded train/val/test splits.
snp_data_module.setup(stage=None)

In [10]:
# Train DataLoader
# Shuffled batches from the training split (shuffle=True in SNPDataModule.train_dataloader).
train_loader = snp_data_module.train_dataloader()

In [11]:
# Get a batch from DataLoader
# A fresh iterator is created on every run, so re-running this cell draws a new shuffled batch.
sample_batch = next(iter(train_loader))

In [12]:
# Check the batch layout (batch_size, num_markers) instead of only printing it.
assert sample_batch.ndim == 2, f"expected a 2-D batch, got {tuple(sample_batch.shape)}"
assert sample_batch.shape[0] <= snp_data_module.batch_size
assert sample_batch.shape[1] == snp_data.shape[1]
print("Batch Shape:", sample_batch.shape)  # Expected: (batch_size, num_markers)
print("First 5 Samples:\n", sample_batch[:5])  # Show first 5 rows

Batch Shape: torch.Size([256, 2067])
First 5 Samples:
 tensor([[0.0000, 0.0000, 0.0000,  ..., 0.5000, 0.5000, 0.5000],
        [0.0000, 1.0000, 0.0000,  ..., 0.0000, 0.5000, 1.0000],
        [0.5000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.5000,  ..., 1.0000, 0.0000, 1.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])


## _Test Model_

In [16]:
from model import DDPM

In [18]:
# Example data
# Smoke test of the external DDPM from model.py: generate a few samples and check their shape.
batch_size = 16
num_snps = 2067  # matches the second dimension of snp_data (torch.Size([160858, 2067]) above)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize model
ddpm = DDPM(snp_length=num_snps).to(device)

# Generate synthetic SNP samples
# NOTE(review): the recorded run of this cell raised
# "RuntimeError: mat1 and mat2 shapes cannot be multiplied (512x1 and 32x32)".
# The failing matmul is inside model.DDPM, which is not visible here — TODO inspect model.py.
generated_snps = ddpm.sample(batch_size=batch_size, snp_length=num_snps, device=device)
print(generated_snps.shape)  # Expected output: [batch_size, 1, num_snps]

RuntimeError: mat1 and mat2 shapes cannot be multiplied (512x1 and 32x32)

### _LightningModule_

- Model
- Training Hooks (training, validation, testing)
- Data Hooks (training, validation, testing)
- etc.

In [None]:
class NetworkBase(pl.LightningModule):
    """Base LightningModule bundling data splits, loaders, optimizer/scheduler and LR warm-up.

    Subclasses provide ``get_loss(batch, batch_idx)``. The training, validation and
    test steps call it and log the result as ``train_loss`` / ``val_loss`` / ``test_loss``.

    Args:
        input_path: Parquet file read by ``load_data``. Used only when ``hparams``
            has no ``"input_path"`` entry (which takes precedence).
        hparams: Hyperparameter dict, saved with ``save_hyperparameters``. Keys read:
            ``lr`` (required by ``configure_optimizers``) and, optionally,
            ``input_path``, ``batch_size`` (64), ``num_workers`` (4),
            ``val_fraction`` / ``test_fraction`` (0.1 each), ``lr_step_size`` (10),
            ``lr_gamma`` (0.3), ``warmup`` (0 epochs) and ``min_lr`` (0).
    """

    def __init__(self, input_path=None, hparams=None):
        super().__init__()
        hparams = dict(hparams or {})
        if input_path is not None:
            hparams.setdefault("input_path", input_path)

        self.path = hparams.get("input_path")
        self.batch_size = hparams.get("batch_size", 64)
        self.workers = hparams.get("num_workers", 4)
        self.val_fraction = hparams.get("val_fraction", 0.1)
        self.test_fraction = hparams.get("test_fraction", 0.1)
        # [train, val, test] lengths, computed from the dataset size in setup().
        self.data_split = None
        self.trainset, self.valset, self.testset = None, None, None

        # Save hyperparameters (available afterwards as self.hparams)
        self.save_hyperparameters(hparams)

    # Setup Data
    def setup(self, stage=None):
        """Load the data once and split it reproducibly (seed 42) into train/val/test."""
        if self.trainset is not None:
            # Lightning calls setup() for every stage; don't re-read the file.
            return
        full_dataset = load_data(self.path)
        n_total = len(full_dataset)
        n_val = round(n_total * self.val_fraction)
        n_test = round(n_total * self.test_fraction)
        # Train takes the remainder so the lengths always sum to the dataset size
        # (for the 160858-row table this gives 128686 / 16086 / 16086).
        self.data_split = [n_total - n_val - n_test, n_val, n_test]
        self.trainset, self.valset, self.testset = random_split(
            full_dataset,
            self.data_split,
            generator=torch.Generator().manual_seed(42),  # Fixed seed for reproducibility
        )

    # Data Loaders
    def _loader(self, dataset, shuffle):
        """Build a DataLoader with the configured batch size and worker count."""
        if dataset is None:
            raise RuntimeError("setup() must be called before requesting a DataLoader")
        # pin_memory=True / persistent_workers=True can be added here for GPU training.
        return DataLoader(
            dataset, batch_size=self.batch_size, shuffle=shuffle, num_workers=self.workers
        )

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self._loader(self.trainset, shuffle=True)

    def val_dataloader(self):
        """Ordered loader over the validation split."""
        return self._loader(self.valset, shuffle=False)

    def test_dataloader(self):
        """Ordered loader over the test split."""
        return self._loader(self.testset, shuffle=False)

    # Configure Optimizer & Scheduler
    def configure_optimizers(self):
        """AdamW (amsgrad) with an epoch-wise StepLR decay."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams["lr"],
            betas=(0.9, 0.999),
            eps=1e-08,
            amsgrad=True,
        )
        scheduler = {
            "scheduler": torch.optim.lr_scheduler.StepLR(
                optimizer,
                # StepLR multiplies the LR by gamma every step_size epochs, so
                # step_size must be a positive int and gamma < 1 for a decay.
                step_size=int(self.hparams.get("lr_step_size", 10)),
                gamma=self.hparams.get("lr_gamma", 0.3),
            ),
            "interval": "epoch",
            "frequency": 1,
        }
        return [optimizer], [scheduler]

    # Training / Validation / Test Steps
    def _shared_step(self, batch, batch_idx, stage):
        """Compute ``self.get_loss`` for one batch and log it as ``{stage}_loss``."""
        get_loss = getattr(self, "get_loss", None)
        if get_loss is None:
            raise NotImplementedError(
                f"{type(self).__name__} must implement get_loss(batch, batch_idx)"
            )
        loss = get_loss(batch, batch_idx)
        self.log(f"{stage}_loss", loss, prog_bar=(stage == "train"))
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, batch_idx, "test")

    # Optimizer Step
    def on_before_optimizer_step(self, optimizer, *args, **kwargs):
        """Apply epoch-based linear LR warm-up and a lower bound on the LR."""

        # warm up lr: scale linearly during the first `warmup` epochs
        if self.hparams.get("warmup", 0) and (
            self.trainer.current_epoch < self.hparams["warmup"]
        ):
            lr_scale = min(
                1.0, float(self.trainer.current_epoch + 1) / self.hparams["warmup"]
            )
            for pg in optimizer.param_groups:
                pg["lr"] = lr_scale * self.hparams["lr"]

        # after reaching minimum learning rate, stop LR decay
        for pg in optimizer.param_groups:
            pg["lr"] = max(pg["lr"], self.hparams.get("min_lr", 0))

### _Model Development_

- `beta_small` and `beta_large` define the range of noise levels.
- The per-step noise variance β(t) increases linearly from 1e-4 (almost no noise) to 0.02 (more noise) over the diffusion steps.
- `t_range` → Total number of diffusion steps.
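- `in_size` → Number of values per input sample; predictions and targets are flattened to this size for the loss.
- `img_depth` → Number of input channels passed to the U-Net (e.g., 3 for RGB images, 1 for a single SNP track).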
- `self.unet` → A U-Net model that predicts the noise at each step.
- Simply passes the input image (x) and time step (t) to the `U-Net`.
- The `U-Net` predicts the noise (ϵ) at that time step.
- `beta(t)` linearly interpolates between `beta_small` and `beta_large` over time; it controls how much noise is added at each time step.
- `alpha(t)` controls signal preservation. Since β(t) is the noise added at step t, α(t) = 1 − β(t) represents how much of the signal is kept at that step.
- `alpha_bar(t)` (cumulative alpha, ᾱ(t)) is the product of the α values up to step t. It represents how much of the original signal is preserved after t steps.
- `get_loss` selects a random diffusion step t for each sample in the batch and draws Gaussian noise ϵ for each sample.
- Loop: computes the noisy version of each sample, $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$. The first term preserves part of the original signal; the second term adds noise.
- Denoising and loss calculation: the noisy samples are passed through the U-Net to predict the noise (`e_hat`). The loss is the mean squared error between the predicted noise (`e_hat`) and the actual noise (ϵ). This teaches the model to predict the noise correctly, which is what enables denoising.
- `denoise_sample` performs one reverse step. Sampling starts from pure noise ($x_T$); at each step random Gaussian noise $z$ is drawn unless it is the last step.
- It gets the predicted noise (`e_hat`) from the U-Net and computes the less noisy sample
  $$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}}\,\hat\epsilon\right) + \sqrt{\beta_t}\,z$$
  - First term: restores the signal.
  - Second term: removes the predicted noise.
  - Third term: adds slight randomness (for realistic diversity).

In [None]:
class DiffusionModel(NetworkBase):
    """DDPM (Ho et al., 2020) with a U-Net noise predictor, trained through NetworkBase.

    Time steps are 0-indexed (``0 .. t_range - 1``): ``get_loss`` draws them from
    that range, and ``alpha_bar(t)`` is the product of ``alpha(0) .. alpha(t)``.

    Args:
        in_size: Number of values per sample; predictions and targets are flattened
            to ``(-1, in_size)`` for the loss.
        t_range: Number of diffusion steps.
        img_depth: Channel count passed to the U-Net.
        input_path, hparams: Forwarded to ``NetworkBase``.

    NOTE(review): the U-Net is convolutional and expects ``(batch, img_depth, length)``
    inputs, while the SNP loaders in this notebook yield 2-D ``(batch, markers)``
    batches — TODO add a channel axis before using this model on that data.
    """

    def __init__(self, in_size, t_range, img_depth, input_path=None, hparams=None):
        super().__init__(input_path, hparams)
        self.beta_small = 1e-4
        self.beta_large = 0.02
        self.t_range = t_range
        self.in_size = in_size

        # ``Unet`` is defined in a later cell and must exist before instantiation.
        self.unet = Unet(dim=64, dim_mults=(1, 2, 4, 8), channels=img_depth)

    def forward(self, x, t):
        """Predict the noise in ``x`` at steps ``t`` with the U-Net."""
        return self.unet(x, t)

    def beta(self, t):
        """Noise variance at step ``t``: linear interpolation from ``beta_small`` (t=0)."""
        return self.beta_small + (t / self.t_range) * (self.beta_large - self.beta_small)

    def alpha(self, t):
        """Signal kept at step ``t``: ``1 - beta(t)``."""
        return 1 - self.beta(t)

    def alpha_bar(self, t):
        """Cumulative signal kept after step ``t``: product of ``alpha(0) .. alpha(t)``.

        Includes ``alpha(t)`` itself (DDPM's ᾱ_t), so ``alpha_bar(0) < 1`` and
        ``1 - alpha_bar(t)`` is never zero.
        """
        return math.prod(self.alpha(j) for j in range(int(t) + 1))

    def _alpha_bar_table(self, device, dtype):
        """``alpha_bar`` for every step ``0 .. t_range - 1`` as a 1-D tensor."""
        steps = torch.arange(self.t_range, dtype=torch.float64)
        return torch.cumprod(self.alpha(steps), dim=0).to(device=device, dtype=dtype)

    def get_loss(self, batch, batch_idx):
        """
        Corresponds to Algorithm 1 from (Ho et al., 2020): MSE between the true
        noise and the noise predicted from a randomly noised batch.
        """
        # Get a random time step for each sample in the batch
        ts = torch.randint(0, self.t_range, [batch.shape[0]], device=self.device)
        # Generate noise, one for each sample in the batch
        epsilons = torch.randn(batch.shape, device=self.device)
        # ᾱ_t per sample, reshaped to broadcast over the non-batch dimensions
        a_hat = self._alpha_bar_table(self.device, batch.dtype)[ts]
        a_hat = a_hat.view(-1, *([1] * (batch.dim() - 1)))
        noise_imgs = a_hat.sqrt() * batch + (1 - a_hat).sqrt() * epsilons
        # Run the noisy samples through the U-Net to get the predicted noise
        e_hat = self.forward(noise_imgs, ts)
        # MSE between the predicted noise and the actual noise
        return nn.functional.mse_loss(
            e_hat.reshape(-1, self.in_size), epsilons.reshape(-1, self.in_size)
        )

    def denoise_sample(self, x, t):
        """
        Corresponds to the inner loop of Algorithm 2 from (Ho et al., 2020):
        one reverse step from ``x_t`` to ``x_{t-1}``. ``t`` is a scalar tensor.
        """
        with torch.no_grad():
            # No noise is added on the final step (t == 0 with 0-indexed steps);
            # randn_like keeps z on the same device/dtype as x.
            z = torch.randn_like(x) if t > 0 else 0
            # Get the predicted noise from the U-Net
            e_hat = self.forward(x, t.view(1).repeat(x.shape[0]))
            # Perform the denoising step to take the sample from t to t-1
            pre_scale = 1 / math.sqrt(self.alpha(t))
            e_scale = (1 - self.alpha(t)) / math.sqrt(1 - self.alpha_bar(t))
            post_sigma = math.sqrt(self.beta(t)) * z
            return pre_scale * (x - e_scale * e_hat) + post_sigma


In [1]:
# instantiate model
# my_model = DiffusionMode()

- `dim`: Initial dimension size for the filters in the U-Net.
- `dim_mults`: A tuple representing how the dimensions of the feature maps increase during the downsampling and upsampling.
- `channels:` The number of input channels (default 4, e.g. one-hot encoded SNPs: A, C, G, T; `channels=1` is used for the normalized genotype values).
- `prev_dim:` Starts as the number of input channels and is updated as layers are added.
- `Downsampling layers:` For each value in dim_mults, we create a 1D convolutional layer (nn.Conv1d) that increases the depth of the feature maps. prev_dim keeps track of the number of channels from the previous layer. kernel_size=3 and padding=1 keep the sequence length intact during convolutions (1D kernels of width 3).
- `Upsampling layers:` This part creates the decoder, where we reduce the depth of the feature maps and aim to recreate the input SNP sequence.
- `Reversed dim_mults:` Since we are upsampling, we reverse the order of dim_mults.
- Final layer: A 1D convolution with kernel_size=1 that maps the features back to `channels` output channels (the same number as the input).


In [None]:
class Unet(nn.Module):
    """Simple 1-D convolutional U-Net for SNP denoising.

    Level ``i`` has ``dim * dim_mults[i]`` feature channels. The decoder mirrors the
    encoder, and each decoder output is added to the encoder output of the same width
    (skip connection). All convolutions use ``kernel_size=3, padding=1`` (or 1x1),
    so the sequence length is preserved.

    Args:
        dim: Base number of feature channels.
        dim_mults: Channel multipliers per level; ``dim * mult`` must be a positive integer.
        channels: Number of input (and output) channels.

    Raises:
        ValueError: If ``dim_mults`` is empty or gives non-positive / non-integer widths.
    """

    def __init__(self, dim=64, dim_mults=(1, 2, 4, 8), channels=4):
        super().__init__()
        widths = [dim * mult for mult in dim_mults]
        if not widths or any(w <= 0 or w != int(w) for w in widths):
            raise ValueError(f"dim * dim_mults must give positive integer widths, got {widths}")
        widths = [int(w) for w in widths]

        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()

        prev_dim = channels
        for width in widths:
            self.downs.append(nn.Conv1d(prev_dim, width, kernel_size=3, padding=1))
            prev_dim = width

        for width in reversed(widths):
            self.ups.append(nn.Conv1d(prev_dim, width, kernel_size=3, padding=1))
            prev_dim = width

        self.final = nn.Conv1d(prev_dim, channels, kernel_size=1)

    def forward(self, x, t):
        """Map ``(batch, channels, length)`` to an output of the same shape.

        NOTE(review): ``t`` is accepted for API compatibility but not used, so the
        network is not conditioned on the diffusion step — TODO add a time embedding.
        """
        skips = []
        for down in self.downs:
            x = F.relu(down(x))
            skips.append(x)

        for up in self.ups:
            # Convolve first, then add the skip: up-layer k outputs the same width
            # as down-layer (n - 1 - k), which is the next one popped.
            x = F.relu(up(x) + skips.pop())

        return self.final(x)  # Output logits

- `snp_length:` Length of the SNP sequence.
- `t_range:` The number of timesteps used for the diffusion process.
- `beta_small and beta_large:` Parameters that control the amount of noise added at each timestep in the diffusion process.
- `Noise schedule:` The beta(t) function defines the noise level at each timestep. It linearly interpolates between beta_small and beta_large over the range of t.
- `q_sample(x, t):` Adds Gaussian noise with variance beta(t) to the normalized SNP values x, then clamps the result to the valid range [0, 1].
- (Categorical alternative, not implemented below: keep each SNP with probability (1 - beta_t), and with probability beta_t / num_classes switch it to one of the `num_classes` categories (e.g. A, C, G, T), sampling the new values with `torch.multinomial(probs.view(-1, num_classes), 1)`.)
- `Forward pass:` The noisy SNP data x is passed through the U-Net, which predicts the noise that was added.
- `get_loss():` Computes the training loss: the mean squared error between the predicted noise and the actual perturbation (noisy_batch - batch).
- `ts:` Randomly generates timesteps for each sample in the batch.

In [None]:
class DiffusionModel(nn.Module):
    """Gaussian-noise denoiser for normalized SNP vectors built on the 1-D ``Unet``.

    ``q_sample`` perturbs a clean sample with a single Gaussian step of variance
    ``beta(t)`` and clamps it to [0, 1]. The U-Net is trained to predict that
    perturbation (``noisy - clean``), which ``denoise_sample`` subtracts again.
    NOTE(review): unlike DDPM, the noise level does not accumulate over steps
    (no ᾱ_t), so ``t`` only selects a per-step variance — confirm this is intended.

    Args:
        snp_length: Number of markers per sample.
        t_range: Number of diffusion steps.
    """

    def __init__(self, snp_length, t_range):
        super().__init__()
        self.snp_length = snp_length
        self.t_range = t_range
        self.beta_small = 1e-4
        self.beta_large = 0.02

        # Define U-Net (1 input channel for SNPs)
        self.unet = Unet(dim=64, dim_mults=(1, 2, 4, 8), channels=1)

    def beta(self, t):
        """ Defines the noise schedule: a simple linear interpolation. """
        return self.beta_small + (t / self.t_range) * (self.beta_large - self.beta_small)

    def q_sample(self, x, t):
        """
        Forward diffusion process: adds Gaussian noise to SNP data.

        x: SNP data (normalized, continuous).
        t: A scalar time step, or a 1-D tensor with one step per row of ``x``.
        Returns the noisy data clamped to the valid SNP range [0, 1].
        """
        beta_t = torch.as_tensor(self.beta(t), dtype=x.dtype, device=x.device)
        if beta_t.dim() > 0:
            # One variance per row, broadcast over the remaining dimensions.
            beta_t = beta_t.view(-1, *([1] * (x.dim() - 1)))
        x_t = x + torch.randn_like(x) * beta_t.sqrt()
        return torch.clamp(x_t, 0.0, 1.0)

    def forward(self, x, t):
        """Predict the noise added to SNPs.

        ``x`` may be ``(batch, length)`` — a channel axis is added for the
        1-channel U-Net and removed from the output — or ``(batch, 1, length)``.
        """
        if x.dim() == 2:
            return self.unet(x.unsqueeze(1), t).squeeze(1)
        return self.unet(x, t)

    def get_loss(self, batch, batch_idx):
        """
        Training loss: MSE between the predicted and the actual perturbation.
        """
        ts = torch.randint(0, self.t_range, [batch.shape[0]], device=batch.device)

        # Apply Gaussian noise (one step per sample, vectorized)
        noisy_batch = self.q_sample(batch, ts)

        # Run noisy SNPs through the model to predict the noise
        e_hat = self.forward(noisy_batch, ts)

        # Compute MSE loss between predicted noise and actual noise
        return F.mse_loss(e_hat, noisy_batch - batch)

    def denoise_sample(self, x, t):
        """
        Reverse diffusion: one step that subtracts the predicted noise.
        ``t`` is a scalar tensor.
        """
        with torch.no_grad():
            # Predict the noise
            e_hat = self.forward(x, t.view(1).repeat(x.shape[0]))

            # Perform the denoising step: x_(t-1) = x_t - predicted_noise
            x_t_minus_1 = x - e_hat

            # Clip values to stay within valid SNP range
            return torch.clamp(x_t_minus_1, 0.0, 1.0)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

class SinusoidalTimeEmbedding(nn.Module):
    """Map integer time steps to fixed sinusoidal feature vectors.

    For ``k = dim // 2`` frequencies ``f_i = max_period ** (-i / k)``, the row for
    step ``t`` is ``[cos(t * f), sin(t * f)]``, with one extra zero column when
    ``dim`` is odd.

    Args:
        dim: Embedding width.
        max_period: Sets the lowest frequency.
    """

    def __init__(self, dim, max_period=10000):
        super().__init__()
        self.dim, self.max_period = dim, max_period

    def forward(self, timesteps):
        """Embed a 1-D tensor of time steps into a ``(len(timesteps), dim)`` float tensor."""
        n_freqs = self.dim // 2
        log_scale = -math.log(self.max_period)
        freqs = torch.exp(log_scale * torch.arange(n_freqs, dtype=torch.float32) / n_freqs)
        phases = timesteps[:, None].float() * freqs.to(timesteps.device)[None]
        emb = torch.cat((phases.cos(), phases.sin()), dim=-1)
        if self.dim % 2 == 1:
            emb = torch.cat((emb, torch.zeros_like(emb[:, :1])), dim=-1)
        return emb

class ResidualBlock(nn.Module):
    """Two-layer MLP block with additive time conditioning and a skip connection.

    Computes ``relu(fc2(relu(fc1(x) + time_fc(time_emb))) + x)``; both ``x`` and
    ``time_emb`` must have last dimension ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim)
        self.time_fc = nn.Linear(dim, dim)
        self.fc2 = nn.Linear(dim, dim)
        self.activation = nn.ReLU()

    def forward(self, x, time_emb):
        # Inject the time conditioning between the two linear layers.
        hidden = self.activation(self.fc1(x) + self.time_fc(time_emb))
        # Residual connection around the whole block.
        return self.activation(self.fc2(hidden) + x)

class DiffusionModel(nn.Module):
    """Label-conditioned MLP noise predictor for normalized SNP vectors.

    Noise is a single Gaussian step of variance ``beta(t)`` clamped to [0, 1]. The
    network sees the noisy SNPs concatenated with a learned label embedding, and a
    sinusoidal embedding of ``t`` conditions each residual block.

    Args:
        snp_length: Number of markers per sample.
        t_range: Number of diffusion steps.
        time_emb_dim: Width of the sinusoidal time embedding. When it differs from
            ``hidden_dim``, a linear layer projects it to ``hidden_dim``.
        hidden_dim: Width of the hidden layers (default 256).
        num_labels: Number of conditioning classes (default 3 super-populations).
        label_emb_dim: Width of the label embedding (default 3).
    """

    def __init__(self, snp_length, t_range, time_emb_dim=256, hidden_dim=256,
                 num_labels=3, label_emb_dim=3):
        super().__init__()
        self.snp_length = snp_length
        self.t_range = t_range
        self.beta_small = 1e-4
        self.beta_large = 0.02
        self.time_emb_dim = time_emb_dim
        self.hidden_dim = hidden_dim

        # Time embedding, projected to the hidden width only when the widths differ
        # (the default configuration has no projection parameters).
        self.time_embedding = SinusoidalTimeEmbedding(time_emb_dim)
        self.time_proj = (
            nn.Identity() if time_emb_dim == hidden_dim else nn.Linear(time_emb_dim, hidden_dim)
        )

        # Label embedding (default: 3 classes, e.g. super-populations)
        self.label_embedding = nn.Embedding(num_labels, label_emb_dim)

        # Predictor network
        self.input_layer = nn.Linear(snp_length + label_emb_dim, hidden_dim)
        self.res_block1 = ResidualBlock(hidden_dim)
        self.res_block2 = ResidualBlock(hidden_dim)
        self.output_layer = nn.Linear(hidden_dim, snp_length)

    def beta(self, t):
        """ Defines the noise schedule: a simple linear interpolation. """
        return self.beta_small + (t / self.t_range) * (self.beta_large - self.beta_small)

    def q_sample(self, x, t):
        """Forward diffusion: add Gaussian noise of variance ``beta(t)`` and clamp to [0, 1].

        ``t`` is a scalar step or a 1-D tensor with one step per row of ``x``.
        """
        beta_t = torch.as_tensor(self.beta(t), dtype=x.dtype, device=x.device)
        if beta_t.dim() > 0:
            # One variance per row, broadcast over the remaining dimensions.
            beta_t = beta_t.view(-1, *([1] * (x.dim() - 1)))
        x_t = x + torch.randn_like(x) * beta_t.sqrt()
        return torch.clamp(x_t, 0.0, 1.0)  # Keep within SNP range

    @staticmethod
    def _as_batch_timesteps(t, batch_size, device):
        """Return ``t`` as a 1-D tensor of ``batch_size`` steps (a scalar is broadcast)."""
        t = torch.as_tensor(t, device=device)
        return t.expand(batch_size) if t.dim() == 0 else t

    def forward(self, x, t, labels):
        """Predict the noise added to SNPs.

        x: noisy SNPs; the last dimension is ``snp_length``.
        t: time steps — a 1-D tensor with one step per row, or a scalar.
        labels: integer class ids, one per row.
        """
        t = self._as_batch_timesteps(t, x.shape[0], x.device)
        time_emb = self.time_proj(self.time_embedding(t))
        label_emb = self.label_embedding(labels)

        # Concatenate SNP data with label embedding, then run the predictor network
        h = F.relu(self.input_layer(torch.cat([x, label_emb], dim=-1)))
        h = self.res_block1(h, time_emb)
        h = self.res_block2(h, time_emb)
        return self.output_layer(h)

    def get_loss(self, batch, labels):
        """ Compute MSE loss between predicted and actual perturbation. """
        ts = torch.randint(0, self.t_range, [batch.shape[0]], device=batch.device)
        noisy_batch = self.q_sample(batch, ts)
        e_hat = self.forward(noisy_batch, ts, labels)
        return F.mse_loss(e_hat, noisy_batch - batch)

    def denoise_sample(self, x, t, labels):
        """ One reverse step: subtract the predicted noise and clamp to [0, 1]. """
        with torch.no_grad():
            e_hat = self.forward(x, t, labels)
            x_t_minus_1 = x - e_hat
            return torch.clamp(x_t_minus_1, 0.0, 1.0)  # Keep SNPs in valid range

# Optimizer & Learning Rate Scheduler
def get_optimizer_and_scheduler(model, lr=0.0003, t_0=1000):
    """Create an Adam optimizer and a cosine-annealing-with-warm-restarts scheduler.

    Args:
        model: Module whose parameters are optimized.
        lr: Adam learning rate (default 3e-4).
        t_0: Scheduler steps in the first restart cycle (default 1000). Each
            ``scheduler.step()`` call counts as one step, so this is epochs or
            iterations depending on how the caller steps it.

    Returns:
        ``(optimizer, scheduler)``.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=t_0)
    return optimizer, scheduler


In [None]:
# One training batch from the LightningDataModule (a fresh shuffled draw on each run).
sample_batch = next(iter(snp_data_module.train_dataloader()))

In [None]:
# Instantiate the model; snp_length is the marker dimension of the loaded table.
T_RANGE = 1000  # number of diffusion steps (the value used by Ho et al., 2020) — TODO tune
model = DiffusionModel(snp_length=snp_data.shape[1], t_range=T_RANGE)

In [None]:
# Forward pass on the sample batch: the model needs one time step and one class label
# per row. load_data returns no labels, so class 0 is a placeholder —
# TODO replace with real super-population labels.
ts = torch.randint(0, model.t_range, (sample_batch.shape[0],))
labels = torch.zeros(sample_batch.shape[0], dtype=torch.long)
output = model(sample_batch, ts, labels)