# Cloning repository and installing dependencies

In [None]:
!git clone https://github.com/pinellolab/DNA-Diffusion.git && cd DNA-Diffusion && uv sync

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
[2mnvidia-cufft-cu12[0m [32m--------------[2m----------------[0m[0m 84.01 MiB/190.95 MiB
[2mnvidia-nccl-cu12[0m [32m--------------[2m----------------[0m[0m 84.00 MiB/191.99 MiB
[2mnvidia-cusparse-cu12[0m [32m-------------[2m-----------------[0m[0m 83.59 MiB/206.53 MiB
[2mnvidia-cublas-cu12[0m [32m-------[2m-----------------------[0m[0m 83.89 MiB/374.93 MiB
[2mnvidia-cudnn-cu12[0m [32m-----[2m-------------------------[0m[0m 83.74 MiB/544.54 MiB
[2K[10A[37m⠴[0m [2mPreparing packages...[0m (166/176)
[2mmkdocs-material[0m [32m---------[2m---------------------[0m[0m 2.47 MiB/8.30 MiB
[2mtriton    [0m [32m------------------[2m------------[0m[0m 85.00 MiB/149.25 MiB
[2mnvidia-cusparselt-cu12[0m [32m-----------------[2m-------------[0m[0m 84.10 MiB/149.52 MiB
[2mnvidia-cusolver-cu12[0m [32m-----------------[2m-------------[0m[0m 83.97 MiB/150.90 MiB
[2mnvidia-cufft-cu12

In [None]:
# Make the cloned repository the working directory for the cells that follow.
%cd DNA-Diffusion

/content/DNA-Diffusion


# Basic Training Example

Below is an example of training with the `train_debug` configuration. In this mode the model is trained on a single sequence for a minimum of 5 epochs, and training stops early once the validation loss has not improved for 2 epochs (the patience parameter).

In [None]:
# Run training with the `train_debug` config.
# NOTE(review): `-cn` is presumably Hydra's --config-name flag — confirm against train.py.
!uv run train.py -cn train_debug

model:
  _target_: src.dnadiffusion.models.unet.UNet
  dim: 200
  channels: 1
  dim_mults:
  - 1
  - 2
  - 4
  resnet_block_groups: 4
data:
  _target_: src.dnadiffusion.data.dataloader.get_dataset
  data_path: data/K562_hESCT0_HepG2_GM12878_12k_sequences_per_group.txt
  saved_data_path: data/encode_data.pkl
  load_saved_data: true
  debug: true
optimizer:
  _target_: torch.optim.Adam
  lr: 0.0001
diffusion:
  _target_: src.dnadiffusion.models.diffusion.Diffusion
  timesteps: 50
  beta_start: 0.0001
  beta_end: 0.2
training:
  distributed: false
  precision: float32
  num_workers: 1
  pin_memory: false
  batch_size: 1
  sample_batch_size: 1
  num_epochs: 2200
  min_epochs: 5
  patience: 2
  log_step: 50
  sample_epoch: 50000
  number_of_samples: 10
  use_wandb: false

  return fn(*args, **kwargs)
  0% 7/2200 [00:25<1:58:12,  3.23s/it]Early stopping at epoch 7, Best val loss: 0.30135953426361084 achieved at epoch 5
  0% 7/2200 [00:25<2:13:58,  3.67s/it]


# Basic sequence generation example using the created checkpoint

Training finished successfully: early stopping triggered at epoch 7, and the lowest validation loss (0.30135953426361084) was reached at epoch 5. We will now use the checkpoint saved at that epoch to generate 1 sequence per cell type.

In [None]:
# The checkpoint filename appears to encode the epoch, step and validation loss of the training run above.
# NOTE(review): if re-running training saves a differently named checkpoint, update sampling.checkpoint_path to match.
!uv run sample.py sampling.number_of_samples=1 sampling.sample_batch_size=1 sampling.checkpoint_path="/content/DNA-Diffusion/checkpoints/model_epoch5_step6_valloss_0.301360.pt"

model:
  _target_: src.dnadiffusion.models.unet.UNet
  dim: 200
  channels: 1
  dim_mults:
  - 1
  - 2
  - 4
  resnet_block_groups: 4
data:
  _target_: src.dnadiffusion.data.dataloader.get_dataset
  data_path: data/K562_hESCT0_HepG2_GM12878_12k_sequences_per_group.txt
  saved_data_path: data/encode_data.pkl
  load_saved_data: true
  debug: false
diffusion:
  _target_: src.dnadiffusion.models.diffusion.Diffusion
  timesteps: 50
  beta_start: 0.0001
  beta_end: 0.2
sampling:
  checkpoint_path: /content/DNA-Diffusion/checkpoints/model_epoch5_step6_valloss_0.301360.pt
  sample_batch_size: 1
  number_of_samples: 1
  guidance_scale: 1.0

(<src.dnadiffusion.data.dataloader.SequenceDataset object at 0x7bf643a71a90>, <src.dnadiffusion.data.dataloader.SequenceDataset object at 0x7bf64fa74ce0>, [1, 2, 3, 4], {1: 'GM12878_ENCLB441ZZZ', 2: 'HepG2_ENCLB029COU', 3: 'K562_ENCLB843GMH', 4: 'hESCT0_ENCLB449ZZZ'})
Loading checkpoint
Sending model to device
Generating 1 samples for cell GM12878_ENCLB441ZZ

# View Generated Sequences

In [None]:
import os
import subprocess

def display_sequences(output_dir="data/outputs"):
    """Print the contents of every sequence file found in ``output_dir``.

    Entries are visited in sorted filename order. For each regular file, a
    header naming the cell type (the part of the filename before the first
    underscore) is printed, followed by the file contents and a dashed
    separator line. Files whose name contains "gitkeep" and anything that is
    not a regular file are skipped.

    Args:
        output_dir: Directory containing the sequence files to display.

    Returns:
        None. If ``output_dir`` is not a directory, an error message is
        printed and the function returns without raising.
    """
    if not os.path.isdir(output_dir):
        print(f"Error: Directory '{output_dir}' not found.")
        return

    print(f"Displaying sequences from: {output_dir}\n")

    for filename in sorted(os.listdir(output_dir)):
        filepath = os.path.join(output_dir, filename)

        # Match on the filename only: testing the full path would skip every
        # file whenever a directory component happens to contain "gitkeep".
        if not os.path.isfile(filepath) or "gitkeep" in filename:
            continue

        cell_type = filename.split('_')[0]
        print(f"--- Cell Type: {cell_type} ({filename}) ---")
        # Read the file directly instead of spawning `cat`, which is not
        # available on every platform and costs one process per file.
        with open(filepath) as sequence_file:
            print(sequence_file.read())
        # Closing rule whose length grows with the cell type name.
        print("-" * (len(cell_type) + 18) + "\n")

# Show the sequence files in the default output directory (data/outputs).
display_sequences()

Displaying sequences from: data/outputs

--- Cell Type: GM12878 (GM12878_ENCLB441ZZZ.txt) ---
GCAACCAGGCGTGTGTATACCATGACTCACCAGTGCTTTCCGCCTACATGCTGTGTGCAGTTTATACGATACCGCTTCCTGCGCGTTTCCATGAATGCCACCGTCGCATCACGAGTTACGTGTGAAGATGCAGCAATCGTTTTGTCGAAGCATAACCATGTGTGCCAGCAAACCCATCGACAGTGTATCTCATCATCATG
-------------------------

--- Cell Type: HepG2 (HepG2_ENCLB029COU.txt) ---
AATCAGTATTAATATTCGTACACATGCAATTCACGTGTGGCCTGCCTATTGCAGCACTGATGTCCATGGTACCGGGAAACCTCGTTCAAGTAAGTAACAACCAAAAACATGCCACTAACCGTGAATACACGCGCTTCGAGTCGAATGTATGTATAGGTATTCGTGCAACATGCATATTTTCACACCTAGCACGCGGGATA
-----------------------

--- Cell Type: K562 (K562_ENCLB843GMH.txt) ---
AAGCATAGTGCTTAACAAACGTTAAATCCACCGACATACAACAATTCACCAACAACGTGTGTTTGCAGAAGTTCACAGCCCTGCAAATGTAGAATGGACAAGTCGCCACAACCATAAGAGATTCAAGTACGAGTGCAAGCATGTTTCTCCTATCAATACCGGCGTAAGAAACACCGAGTGTATTATGCCCAACTGCATAG
----------------------

--- Cell Type: hESCT0 (hESCT0_ENCLB449ZZZ.txt) ---
AAAGTCCCAAGCTTCATCACGAGCTTCCCATCCGACAACTCCATGTACAATTCACAACCCTGACCATGAACATCGAGAAT

To see an example of sequence generation using our trained checkpoint, check out the example listed in the README.