# Cloning repository and installing dependencies

In [None]:
# Clone the DNA-Diffusion repository and install its dependencies with `uv sync`.
# The `cd` here only applies to the shell spawned by `!`, so the next cell uses
# `%cd` to move the notebook itself into the repository.
# NOTE(review): this clones the default branch (no pinned commit), so outputs may
# drift from the cached results shown in this notebook.
!git clone https://github.com/pinellolab/DNA-Diffusion.git && cd DNA-Diffusion && uv sync

In [5]:
# `%cd` (unlike `!cd`) changes the kernel's working directory, so every later
# cell (including `!uv run ...` commands) runs from inside the repository.
%cd DNA-Diffusion

/content/DNA-Diffusion


# Basic Training Example

Below is an example training run using the `train_debug` config. In debug mode the model is trained on a single sequence for at least 5 epochs, and training stops early once the validation loss has not improved for 2 consecutive epochs (`patience: 2`).

In [6]:
# Run training with the `train_debug` config (`-cn` selects the config by name;
# the fully resolved config is echoed at the top of the output below).
# The next cell looks for the resulting checkpoints in `checkpoints/`.
!uv run train.py -cn train_debug

model:
  _target_: src.dnadiffusion.models.unet.UNet
  dim: 200
  channels: 1
  dim_mults:
  - 1
  - 2
  - 4
  resnet_block_groups: 4
data:
  _target_: src.dnadiffusion.data.dataloader.get_dataset
  data_path: data/K562_hESCT0_HepG2_GM12878_12k_sequences_per_group.txt
  saved_data_path: data/encode_data.pkl
  load_saved_data: true
  debug: true
optimizer:
  _target_: torch.optim.Adam
  lr: 0.0001
diffusion:
  _target_: src.dnadiffusion.models.diffusion.Diffusion
  timesteps: 50
  beta_start: 0.0001
  beta_end: 0.2
training:
  distributed: false
  precision: float32
  num_workers: 1
  pin_memory: false
  batch_size: 1
  sample_batch_size: 1
  num_epochs: 2200
  min_epochs: 5
  patience: 2
  log_step: 50
  sample_epoch: 50000
  number_of_samples: 10
  use_wandb: false

  return fn(*args, **kwargs)
  0% 6/2200 [01:14<5:18:05,  8.70s/it]Early stopping at epoch 6, Best val loss: 0.28959813714027405 achieved at epoch 4
  0% 6/2200 [01:15<7:37:23, 12.51s/it]


# Basic sequence generation example using the created checkpoint

The model has finished training, so we now use the checkpoint with the lowest validation loss to generate 1 sequence per cell type. Because the training seed is not fixed, training is not deterministic, and your validation loss may differ slightly from the cached output shown here.

In [16]:
import os
import re

# Directory holding the checkpoints created by the training run above.
checkpoint_dir = "checkpoints"

# Checkpoint filenames embed the validation loss, e.g.
# "model_epoch4_step5_valloss_0.289598.pt" -> 0.289598.
VALLOSS_PATTERN = re.compile(r"valloss_(\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)\.pt$")


def checkpoint_val_loss(path):
    """Return the validation loss encoded in a checkpoint filename, or None if absent."""
    match = VALLOSS_PATTERN.search(os.path.basename(path))
    return float(match.group(1)) if match else None


files = os.listdir(checkpoint_dir)
file_paths = sorted(
    os.path.join(checkpoint_dir, f)
    for f in files
    if os.path.isfile(os.path.join(checkpoint_dir, f)) and ".gitkeep" not in f
)

# Choose by the parsed loss rather than by sort position: a lexicographic sort
# ranks "epoch10" before "epoch2" and says nothing about which loss is lowest.
val_loss_by_path = {
    path: loss
    for path in file_paths
    if (loss := checkpoint_val_loss(path)) is not None
}
if not val_loss_by_path:
    raise FileNotFoundError(
        f"No checkpoint ending in 'valloss_<number>.pt' found in '{checkpoint_dir}'. "
        "Run the training cell first."
    )
best_checkpoint = min(val_loss_by_path, key=val_loss_by_path.get)

print(f"All available checkpoints: \n{file_paths}")
print(f"\nUsing checkpoint with lowest validation loss: \n{best_checkpoint}")

All available checkpoints: 
['checkpoints/model_epoch2_step3_valloss_0.348331.pt', 'checkpoints/model_epoch4_step5_valloss_0.289598.pt']

Using checkpoint with lowest validation loss: 
checkpoints/model_epoch4_step5_valloss_0.289598.pt


In [17]:
# Generate 1 sequence per cell type from the checkpoint chosen in the previous cell.
# `$best_checkpoint` is IPython's shell-variable expansion: it substitutes the value
# of the Python variable `best_checkpoint` into the command line.
!uv run sample.py sampling.number_of_samples=1 sampling.sample_batch_size=1 sampling.checkpoint_path=$best_checkpoint

model:
  _target_: src.dnadiffusion.models.unet.UNet
  dim: 200
  channels: 1
  dim_mults:
  - 1
  - 2
  - 4
  resnet_block_groups: 4
data:
  _target_: src.dnadiffusion.data.dataloader.get_dataset_for_sampling
  data_path: data/K562_hESCT0_HepG2_GM12878_12k_sequences_per_group.txt
  saved_data_path: data/encode_data.pkl
  load_saved_data: true
  debug: false
  cell_types: null
diffusion:
  _target_: src.dnadiffusion.models.diffusion.Diffusion
  timesteps: 50
  beta_start: 0.0001
  beta_end: 0.2
sampling:
  checkpoint_path: checkpoints/model_epoch4_step5_valloss_0.289598.pt
  sample_batch_size: 1
  number_of_samples: 1
  guidance_scale: 1.0

Loading checkpoint
Model sent to cuda
Found cell types: ['GM12878_ENCLB441ZZZ', 'HepG2_ENCLB029COU', 'K562_ENCLB843GMH', 'hESCT0_ENCLB449ZZZ']
Generating 1 samples for cell GM12878_ENCLB441ZZZ
100% 1/1 [00:02<00:00,  2.14s/it]
Generating 1 samples for cell HepG2_ENCLB029COU
100% 1/1 [00:01<00:00,  1.68s/it]
Generating 1 samples for cell K562_ENCLB84

# View Generated Sequences

In [18]:
import os
import subprocess

def display_sequences(output_dir="data/outputs"):
    """Print every generated-sequence file in ``output_dir``, one section per file.

    Files are visited in sorted filename order and ``.gitkeep`` placeholders are
    skipped. The cell-type label in each header is the part of the filename
    before the first underscore (e.g. ``K562_ENCLB843GMH.txt`` -> ``K562``).

    Args:
        output_dir: Directory containing the generated sequence files.
            Defaults to ``"data/outputs"``.

    Returns:
        None. If ``output_dir`` is not an existing directory, an error message is
        printed and the function returns early.
    """
    if not os.path.isdir(output_dir):
        print(f"Error: Directory '{output_dir}' not found.")
        return

    print(f"Displaying sequences from: {output_dir}\n")

    for filename in sorted(os.listdir(output_dir)):
        filepath = os.path.join(output_dir, filename)
        # Test the filename, not the full path, so a directory whose path happens
        # to contain "gitkeep" does not hide every file.
        if not os.path.isfile(filepath) or "gitkeep" in filename:
            continue

        cell_type = filename.split("_")[0]
        header = f"--- Cell Type: {cell_type} ({filename}) ---"
        print(header)
        # Read the file in-process instead of shelling out to `cat`: portable
        # (works without a POSIX shell) and avoids spawning a subprocess per file.
        with open(filepath, encoding="utf-8") as handle:
            print(handle.read())
        print("-" * len(header) + "\n")


display_sequences()

Displaying sequences from: data/outputs

--- Cell Type: GM12878 (GM12878_ENCLB441ZZZ.txt) ---
CAACAAAGTAAAATCGAATAATAAGGCCGCCCTGACCCCAAAGAGAACCTAAAACCAAAACCAATTTTACAAACACCCAAGTTTCCTTCCAACGCGCCGAAAAAATATATTAAGCTTAAGAACACCAAAGAGTCGTTGACAAGCGCCTTTTATCAGACAACGCCCTACCCAAGACTAACGATAATAAAGTGCGAAGAAGG
-------------------------

--- Cell Type: HepG2 (HepG2_ENCLB029COU.txt) ---
AAAGGACAGAACAACTGGTTTTTCTTTAGGTCATTAGGCCCGTTCAAAGAGGAACACAACCACCCGGGGCGCAAAAAAAATTACCCCAGTAGTTGCCAAATCTCTCAATGTCCTTGATACCCACTCCGAGATCCGGGGATGAAAGAACTGGCAGGTTGGGAGAAAAGACCACGGCAATCTGCGCCACAATAATATTCTAT
-----------------------

--- Cell Type: K562 (K562_ENCLB843GMH.txt) ---
GAGGAGCTAAAACTAATGTGACGCCCGACCAAGTGGCACATCATAAAGGATGTTCGCAATTTACCAATGCCACCCAAAAATGATGAGAACTATTCGCCTCCACTGACGAAACTTCACAAGAGTTCCTTATGAAAGTTTTGCAAAAGCAAGCGCCCCGCCGGGTTTCATCTGCCAACCCACACCCGAAAGAAAAACACAAG
----------------------

--- Cell Type: hESCT0 (hESCT0_ENCLB449ZZZ.txt) ---
GAAAGATGTTGTAAAGGGCCAACAACAGATCTCCATACCTTAAGAAGTTAGGAAAATGGGTGAGATGATGAGCGATTTGG

To see an example of sequence generation using our trained checkpoint, check out the example in the README.