# Cloning the repository and installing dependencies

In [None]:
# Clone the repository and install its dependencies with `uv sync`.
# The `cd` below only applies to this shell subprocess; the kernel's own working
# directory is changed separately with `%cd` in the next cell.
# NOTE(review): `git clone` fails if DNA-Diffusion/ already exists, so this cell
# is not safe to re-run as-is on a kernel restart.
!git clone https://github.com/pinellolab/DNA-Diffusion.git && cd DNA-Diffusion && uv sync

In [2]:
# Move the kernel itself into the cloned repository; later cells use
# repo-relative paths such as data/ and checkpoints/.
%cd DNA-Diffusion

/content/DNA-Diffusion


# Creating a Simulated Dataset

Below is an example of how to use a new dataset with the DNA-Diffusion library. The dummy dataset contains three 200 bp sequences (two random, one poly-A), all labeled with the cell type "CELL_A", and is written as a tab-separated file with `chr`, `sequence`, and `TAG` columns. We show that this dataset can be used to regenerate the file "encode_data.pkl" that is used to train the model.

In [3]:
import os

import numpy as np
import pandas as pd

# Fixed seed so the dummy sequences are identical on every Restart & Run All.
# (Model training itself is still unseeded; see the note before sampling.)
SEED = 42
SEQUENCE_LENGTH = 200
rng = np.random.default_rng(SEED)

tags = ['CELL_A', 'CELL_A', 'CELL_A']
# Named `chromosomes` rather than `chr` so the builtin chr() is not shadowed
# for the rest of the kernel session.
chromosomes = ["chr1", "chr2", "chr3"]

records = []
for i, (tag, chromosome) in enumerate(zip(tags, chromosomes)):
    if i == 2:
        # The third entry is a homopolymer (all "A") rather than a random
        # sequence; presumably a low-complexity example — TODO confirm intent.
        sequence = "A" * SEQUENCE_LENGTH
    else:
        sequence = ''.join(rng.choice(list('ACGT'), size=SEQUENCE_LENGTH))
    records.append({'chr': chromosome, 'sequence': sequence, 'TAG': tag})

# Build the frame once from the collected records instead of growing it row by row.
df = pd.DataFrame(records, columns=['chr', 'sequence', 'TAG'])

# Tab-separated file with columns chr / sequence / TAG; train.py reads it via
# the `data.data_path` override in the training cell below.
os.makedirs('data', exist_ok=True)
df.to_csv('data/dummy_data.txt', index=False, sep='\t')

# Basic Training Example

Below is an example of training with the debug configuration (`-cn train_debug`). This trains the model on a single sequence for a minimum of 5 epochs, with an early-stopping patience of 2 epochs. We also show that the data file can be overridden within the CLI call to use the new dataset. It is important to set `data.load_saved_data=False` so that the additional metadata used to train the model is regenerated rather than loaded from the previously saved file.

In [4]:
# `-cn train_debug` selects the debug training config (echoed in the output below).
# `data.data_path=...` points training at the dummy dataset written above, and
# `data.load_saved_data=False` makes it rebuild its metadata instead of loading
# the saved data/encode_data.pkl.
!uv run train.py -cn train_debug data.data_path='data/dummy_data.txt' data.load_saved_data=False

model:
  _target_: src.dnadiffusion.models.unet.UNet
  dim: 200
  channels: 1
  dim_mults:
  - 1
  - 2
  - 4
  resnet_block_groups: 4
data:
  _target_: src.dnadiffusion.data.dataloader.get_dataset
  data_path: data/dummy_data.txt
  saved_data_path: data/encode_data.pkl
  load_saved_data: false
  debug: true
optimizer:
  _target_: torch.optim.Adam
  lr: 0.0001
diffusion:
  _target_: src.dnadiffusion.models.diffusion.Diffusion
  timesteps: 50
  beta_start: 0.0001
  beta_end: 0.2
training:
  distributed: false
  precision: float32
  num_workers: 1
  pin_memory: false
  batch_size: 1
  sample_batch_size: 1
  num_epochs: 2200
  min_epochs: 5
  patience: 2
  log_step: 50
  sample_epoch: 50000
  number_of_samples: 10
  use_wandb: false

  return fn(*args, **kwargs)
  0% 9/2200 [01:51<6:44:30, 11.08s/it]Early stopping at epoch 9, Best val loss: 0.30490654706954956 achieved at epoch 7
  0% 9/2200 [01:51<7:33:13, 12.41s/it]


# Basic sequence generation example using the created checkpoint

The model trained successfully, so we now use the checkpoint with the lowest validation loss to generate 1 sequence per cell type. Because the training seed is not fixed, the result is not deterministic and your validation loss may differ slightly from the cached output shown here.

In [5]:
import os
import re

checkpoint_dir = "checkpoints"
files = os.listdir(checkpoint_dir)
file_paths = sorted(
    os.path.join(checkpoint_dir, f)
    for f in files
    if os.path.isfile(os.path.join(checkpoint_dir, f)) and ".gitkeep" not in f
)

# Checkpoint names embed the validation loss, e.g.
# "model_epoch7_step8_valloss_0.304907.pt". Select on that number rather than on
# sort position: a lexicographic sort puts "epoch10" before "epoch5" and says
# nothing about which checkpoint has the lowest loss.
VALLOSS_PATTERN = re.compile(r"valloss_(\d+(?:\.\d+)?(?:[eE][+-]?\d+)?)")


def val_loss_from_name(path):
    """Return the validation loss encoded in a checkpoint file name, or None if absent."""
    match = VALLOSS_PATTERN.search(os.path.basename(path))
    return float(match.group(1)) if match else None


scored_checkpoints = [
    (path, loss)
    for path, loss in ((p, val_loss_from_name(p)) for p in file_paths)
    if loss is not None
]
if not scored_checkpoints:
    raise FileNotFoundError(
        f"No checkpoint with a 'valloss_<number>' file name found in '{checkpoint_dir}'"
    )
best_checkpoint = min(scored_checkpoints, key=lambda item: item[1])[0]

print(f"All available checkpoints: \n{file_paths}")
print(f"\nUsing checkpoint with lowest validation loss: \n{best_checkpoint}")

All available checkpoints: 
['checkpoints/model_epoch5_step6_valloss_0.348039.pt', 'checkpoints/model_epoch7_step8_valloss_0.304907.pt']

Using checkpoint with lowest validation loss: 
checkpoints/model_epoch7_step8_valloss_0.304907.pt


In [7]:
# `$best_checkpoint` is expanded by IPython from the Python variable set in the
# previous cell. One sample is generated for each cell type found in the data file.
!uv run sample.py sampling.number_of_samples=1 sampling.sample_batch_size=1 sampling.checkpoint_path=$best_checkpoint data.data_path='data/dummy_data.txt'

model:
  _target_: src.dnadiffusion.models.unet.UNet
  dim: 200
  channels: 1
  dim_mults:
  - 1
  - 2
  - 4
  resnet_block_groups: 4
data:
  _target_: src.dnadiffusion.data.dataloader.get_dataset_for_sampling
  data_path: data/dummy_data.txt
  saved_data_path: data/encode_data.pkl
  load_saved_data: true
  debug: false
  cell_types: null
diffusion:
  _target_: src.dnadiffusion.models.diffusion.Diffusion
  timesteps: 50
  beta_start: 0.0001
  beta_end: 0.2
sampling:
  checkpoint_path: checkpoints/model_epoch7_step8_valloss_0.304907.pt
  sample_batch_size: 1
  number_of_samples: 1
  guidance_scale: 1.0

Loading checkpoint
Model sent to cuda
Found cell types: ['CELL_A']
Generating 1 samples for cell CELL_A
100% 1/1 [00:02<00:00,  2.40s/it]


# View Generated Sequences

In [8]:
import os
import subprocess

def display_sequences(output_dir="data/outputs"):
    """Print the contents of every generated-sequence file in ``output_dir``.

    Each file gets a header naming its cell type, followed by the file contents
    and a separator line. The cell type is the file name without its extension
    (``CELL_A.txt`` -> ``CELL_A``). Hidden files (e.g. ``.gitkeep``) and
    sub-directories are skipped. Files are visited in sorted name order.

    Args:
        output_dir: Directory containing the generated sequence files.

    Returns:
        None. If ``output_dir`` does not exist, an error message is printed
        instead of raising.
    """
    if not os.path.isdir(output_dir):
        print(f"Error: Directory '{output_dir}' not found.")
        return

    print(f"Displaying sequences from: {output_dir}\n")

    for filename in sorted(os.listdir(output_dir)):
        filepath = os.path.join(output_dir, filename)
        if filename.startswith(".") or not os.path.isfile(filepath):
            continue

        # splitext keeps underscores that belong to the cell-type name;
        # splitting on "_" turned "CELL_A.txt" into "CELL".
        cell_type = os.path.splitext(filename)[0]
        header = f"--- Cell Type: {cell_type} ({filename}) ---"
        print(header)
        # Read the file in Python rather than shelling out to `cat`, which is
        # not available on every platform and spawns an extra process.
        with open(filepath, encoding="utf-8") as handle:
            print(handle.read())
        print("-" * len(header) + "\n")


display_sequences()

Displaying sequences from: data/outputs

--- Cell Type: CELL (CELL_A.txt) ---
TATATTTATGTTAAATTCATGCTTATTTTATATTTTTTTTTTTTTTTGAGTAGTTATTGTATTTTTTATATATTGAAAATATTTTTTTTTTTTACAAAAATAAAATATAAATAACATTTGTAAAATGTTCTAAGTGTGTGTTGCATTTTAATATATAATATTTTTATTTGTAAATAAATAATAATTTTTATTTTTGTTTG
----------------------

