In [1]:
!curl -L https://github.com/lucidrains/enformer-pytorch/raw/main/data/test-sample.pt -o data/test-sample.pt

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100 20.1M  100 20.1M    0     0   752k      0  0:00:27  0:00:27 --:--:--  766k0   984k      0  0:00:20  0:00:10  0:00:10 1091k


In [1]:
import os

# Set before torch is imported (next cell) so ops lacking an MPS kernel fall back to CPU.
os.environ.update(PYTORCH_ENABLE_MPS_FALLBACK="1")

In [2]:
import pandas as pd
import numpy as np
from pathlib import Path

import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import TensorDataset
from torch.cuda.amp import autocast, GradScaler
from enformer_pytorch import Enformer, GenomeIntervalDataset

import kipoiseq
import seaborn as sns
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def _select_device():
    """Return the preferred torch device and print which one was chosen.

    Preference order: Apple MPS, then CUDA, then CPU.
    """
    if torch.backends.mps.is_available():
        chosen = torch.device("mps")
        print("Running on MPS:", chosen)
        return chosen
    if torch.cuda.is_available():
        chosen = torch.device("cuda")
        print("Running on GPU:", chosen)
        return chosen
    chosen = torch.device("cpu")
    print("Running on CPU:", chosen)
    return chosen


device = _select_device()

Running on MPS: mps


In [4]:
TEST_MODEL = False  # set True to sanity-check the pretrained weights before fine-tuning


def check_pretrained_enformer(device, sample_path='data/test-sample.pt', min_corr=0.1):
    """Score the pretrained Enformer on the downloaded test sample.

    Loads ``EleutherAI/enformer-official-rough``, evaluates the stored
    sequence/target pair on the human head, prints the correlation coefficient
    and asserts it exceeds ``min_corr``.

    All tensors stay local so the sample's sequence/target do not shadow the
    training ``target`` / ``enformer`` defined in later cells.

    Args:
        device: torch device to run on.
        sample_path: path of the file fetched by the curl cell.
        min_corr: minimum acceptable correlation coefficient.

    Returns:
        The correlation coefficient returned by Enformer.
    """
    model = Enformer.from_pretrained('EleutherAI/enformer-official-rough').to(device)
    # The file is downloaded from the internet: restrict unpickling to tensors/containers.
    # map_location already places the tensors on `device`.
    sample = torch.load(sample_path, map_location=device, weights_only=True)

    with torch.no_grad():
        corr_coef = model(
            sample['sequence'],
            target=sample['target'],
            return_corr_coef=True,
            head='human',
        )

    print(corr_coef)
    assert corr_coef > min_corr, f"pretrained Enformer correlation too low: {corr_coef}"
    return corr_coef


if TEST_MODEL:
    corr_coef = check_pretrained_enformer(device)

In [27]:
# NOTE(review): this cell uses `target` and `seq_data`, which are only defined by the
# data-loading cell further down (it ran as In[27], before In[29]). On Restart & Run All
# it raises NameError -- move it below that cell.
# Both shapes only display if IPython's ast_node_interactivity is "all" (the saved
# output suggests it is); otherwise only the last expression is shown.
target.shape
seq_data.shape

torch.Size([19993, 896, 1])

(19993, 4)

In [29]:
import pandas as pd
import numpy as np
from datetime import datetime
from pathlib import Path

def avg_bin(array, n_bins):
    """Average ``array`` over ``n_bins`` contiguous chunks.

    Chunks follow ``np.array_split`` semantics: when the length is not divisible
    by ``n_bins`` the leading chunks are one element longer.

    Returns:
        list with one mean per chunk, in order.
    """
    return list(map(np.mean, np.array_split(array, n_bins)))

# Data: processed promoter table; its 'values' column holds one target array per row.
# NOTE: unpickling can execute code -- only read pickles this project produced itself.
seq_data = pd.read_pickle("../data/processed/PROMOTERS.pkl")

# Target: stack the per-row arrays into one tensor and add a trailing track axis
# (saved output: torch.Size([19993, 896, 1])).
target = torch.stack(
    [torch.tensor(track) for track in seq_data['values']]
).unsqueeze(-1)
target.shape


# DataLoaders
batch_size = 4  # per-step batch; lower it if the GPU runs out of memory (a T4 only fits 1)

seq_ds = GenomeIntervalDataset(
    bed_file = '../data/processed/PROMOTERS.bed',   # columns 0, 1, 2 must be <chromosome>, <start position>, <end position>
    fasta_file = '../data/hg38.fa',                 # reference genome FASTA
    return_seq_indices = True,                      # yield nucleotide indices (ACGTN) rather than one-hot encodings
    context_length = 196_608,
)
target_ds = TensorDataset(target)

# The training loop zips the two loaders, so BED row i must correspond to target row i.
# A length mismatch guarantees misaligned (or silently truncated) pairs.
# shuffle=False below keeps the loaders in lockstep.
assert len(seq_ds) == len(target_ds), (
    f"BED intervals ({len(seq_ds)}) and targets ({len(target_ds)}) differ in length"
)

seq_dl = DataLoader(seq_ds, batch_size=batch_size, shuffle=False)
target_dl = DataLoader(target_ds, batch_size=batch_size, shuffle=False)

torch.Size([19993, 896, 1])

In [47]:
# Fetch the first batch of each loader so this cell works on a fresh kernel
# (previously it relied on batches left over from a deleted cell).
seq_batch = next(iter(seq_dl))
(target_batch,) = next(iter(target_dl))
seq_batch.shape
target_batch.shape

torch.Size([4, 196608])

torch.Size([4, 896, 1])

In [None]:
from enformer_pytorch.finetune import HeadAdapterWrapper

# Training loop
model_path = Path("../models/")

# Setup paths: one timestamped folder per training run
now = datetime.now()
formatted_date_time = now.strftime("%Y-%m-%d_%H-%M-%S")

folder_path = model_path.joinpath(formatted_date_time)
folder_path.mkdir(parents=True, exist_ok=True)

# Model: pretrained Enformer trunk wrapped with a new single-track head
enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough')
enformer = HeadAdapterWrapper(
    enformer=enformer,
    num_tracks=1,
    post_transformer_embed=False
).to(device)
_ = enformer.train()

# torch.cuda.amp only has an effect on CUDA. On MPS/CPU both objects would just warn
# and disable themselves, so enable them explicitly only on CUDA.
use_amp = device.type == "cuda"
scaler = GradScaler(enabled=use_amp)
optimizer = torch.optim.Adam(enformer.parameters(), lr=0.0001)
losses = []  # per-batch (unscaled) loss across all epochs

num_epochs = 20
accumulation_steps = 8  # optimizer steps every 8 batches -> effective batch = batch_size * accumulation_steps
# zip() stops at the shorter loader, so this is the real number of batches per epoch.
num_batches = min(len(seq_dl), len(target_dl))

optimizer.zero_grad()
for epoch in range(num_epochs):
    epoch_losses = []
    for idx, (seq_batch, (target_batch,)) in enumerate(zip(seq_dl, target_dl)):
        # Keep nucleotide indices as long: Enformer one-hot encodes long index tensors
        # itself. A float (batch, length) tensor would not be converted and would be
        # misread as already-encoded input.
        seq_batch = seq_batch.to(device=device, dtype=torch.long)
        target_batch = target_batch.to(dtype=torch.float32, device=device)

        with autocast(enabled=use_amp):
            # Forward pass: with `target` given, the wrapper returns the loss
            loss = enformer(seq_batch, target=target_batch)

        # Backward pass, scaled so gradients accumulated over the window are averaged
        scaler.scale(loss / accumulation_steps).backward()

        losses.append(loss.item())
        epoch_losses.append(losses[-1])

        # Gradient accumulation: step every `accumulation_steps` batches, and flush
        # a partial window at the end of the epoch so no gradients leak across epochs.
        if (idx + 1) % accumulation_steps == 0 or idx + 1 == num_batches:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

    epoch_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float("nan")
    print(f"Epoch [{epoch+1}/{num_epochs}], Step [{len(epoch_losses)}/{num_batches}], Mean loss: {epoch_loss:.4f}")
    # Save a checkpoint after every epoch, tagged with the mean epoch loss
    model_path = folder_path.joinpath(f'enformer-ft_epoch={epoch}_loss={epoch_loss:.4f}.pth')
    torch.save(enformer.state_dict(), model_path)