In [1]:
# !curl -L https://github.com/lucidrains/enformer-pytorch/raw/main/data/test-sample.pt -o data/test-sample.pt

In [1]:
import os
# Let ops without an MPS kernel fall back to CPU. This is set before the cell
# that imports torch (presumably torch reads it at import time — confirm).
os.environ.update({"PYTORCH_ENABLE_MPS_FALLBACK": "1"})

In [2]:
import pandas as pd
import numpy as np
from pathlib import Path

import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import TensorDataset
from torch.cuda.amp import autocast, GradScaler
from enformer_pytorch import Enformer, GenomeIntervalDataset

import kipoiseq
import seaborn as sns
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Pick the compute device: Apple MPS first, then CUDA, otherwise CPU.
device_type, banner = "cpu", "Running on CPU:"
if torch.backends.mps.is_available():
    device_type, banner = "mps", "Running on MPS:"
elif torch.cuda.is_available():
    device_type, banner = "cuda", "Running on GPU:"

device = torch.device(device_type)
print(banner, device)

Running on MPS: mps


In [5]:
# Optional smoke test of the pretrained weights. It scores ../data/test-sample.pt
# (downloadable from the enformer-pytorch repo) and requires a correlation
# coefficient above 0.1.
# NOTE(review): this reuses the global names `enformer` and `target`, which
# later cells reassign.
TEST_MODEL = False

if TEST_MODEL:
    enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough').to(device)
    data = torch.load('../data/test-sample.pt', map_location=device)
    seq = data['sequence'].to(device)
    target = data['target'].to(device)

    with torch.no_grad():
        corr_coef = enformer(seq, target=target, return_corr_coef=True, head='human')

    print(corr_coef)
    assert corr_coef > 0.1

In [3]:
import pandas as pd
import numpy as np
from datetime import datetime
from pathlib import Path

def avg_bin(array, n_bins):
    """Average ``array`` over ``n_bins`` contiguous chunks.

    Chunks come from ``np.array_split``, so their sizes differ by at most one.

    Args:
        array: 1-D array-like of numbers.
        n_bins: number of chunks to average over.

    Returns:
        A list with one mean per chunk, in order.
    """
    chunks = np.array_split(array, n_bins)
    return list(map(np.mean, chunks))

# Data: per-interval signal values. Row i must correspond to row i of PROMOTERS.bed.
# NOTE(review): pd.read_pickle can execute arbitrary code — only load trusted files.
seq_data = pd.read_pickle("../data/processed/PROMOTERS.pkl")

# Target: stack each row's values into one tensor and add a trailing track axis.
# Assumes every entry is a 1-D sequence of equal length -> (n_intervals, n_values, 1).
target = seq_data['values']
target = torch.stack([torch.tensor(i) for i in target]).unsqueeze(-1)


# DataLoaders
batch_size = 1 # T4 only enough memory for 1 batch size

seq_ds = GenomeIntervalDataset(
    bed_file = '../data/processed/PROMOTERS.bed',                       # bed file - columns 0, 1, 2 must be <chromosome>, <start position>, <end position>
    fasta_file = '../data/hg38.fa',                        # path to fasta file
    return_seq_indices = True,                          # return nucleotide indices (ACGTN) or one hot encodings
    context_length = 196_608,
)

target_ds = TensorDataset(target)

# The training loop zip()s the two loaders. zip() silently stops at the shorter
# one, so a length mismatch would drop or misalign samples. Fail loudly instead.
assert len(seq_ds) == len(target_ds), (
    f"{len(seq_ds)} intervals in PROMOTERS.bed but {len(target_ds)} targets in PROMOTERS.pkl"
)

# shuffle=False on BOTH loaders keeps sample i of one paired with sample i of the other.
seq_dl = DataLoader(seq_ds, batch_size=batch_size, shuffle=False)
target_dl = DataLoader(target_ds, batch_size=batch_size, shuffle=False)

In [None]:
from enformer_pytorch import seq_indices_to_one_hot  # needed on every device: see the encoding step in the loop
from enformer_pytorch.finetune import HeadAdapterWrapper
from datetime import datetime
from pathlib import Path
from torch.cuda.amp import autocast, GradScaler
import torch

# Training loop
model_path = Path("../test-models/")  # root directory for all runs

# Setup paths: each run writes into its own timestamped sub-folder
now = datetime.now()
formatted_date_time = now.strftime("%Y-%m-%d_%H-%M-%S")
folder_path = model_path.joinpath(formatted_date_time)
folder_path.mkdir(parents=True, exist_ok=True)

# Model: pretrained Enformer trunk wrapped with a new single-track head
enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough')
enformer = HeadAdapterWrapper(
    enformer=enformer,
    num_tracks=1,
    post_transformer_embed=False
).to(device)
_ = enformer.train()

# torch.cuda.amp targets CUDA only. Enable mixed precision explicitly there and
# disable it elsewhere (MPS/CPU) rather than relying on runtime fallbacks.
use_amp = device.type == 'cuda'
scaler = GradScaler(enabled=use_amp)
optimizer = torch.optim.Adam(enformer.parameters(), lr=0.0001)

num_epochs = 1
accumulation_steps = 1  # micro-batches per optimizer step

# zip() below silently stops at the shorter loader, which would drop or
# misalign samples if the sequence and target loaders disagree in length.
assert len(seq_dl) == len(target_dl), (
    f"seq_dl has {len(seq_dl)} batches but target_dl has {len(target_dl)}"
)

for epoch in range(num_epochs):
    running_loss = 0.0  # sum of per-batch (undivided) losses for this epoch
    n_batches = 0
    optimizer.zero_grad()
    for seq_batch, (target_batch,) in zip(seq_dl, target_dl):
        # Integer nucleotide indices must be one-hot encoded on EVERY device.
        # Otherwise the float32 cast below turns them into a meaningless float
        # tensor instead of a one-hot encoding.
        if not seq_batch.is_floating_point():
            seq_batch = seq_indices_to_one_hot(seq_batch)
        seq_batch = seq_batch.to(dtype=torch.float32, device=device)
        target_batch = target_batch.to(dtype=torch.float32, device=device)

        with autocast(enabled=use_amp):
            # Forward pass; given a target, the wrapper returns the loss
            loss = enformer(seq_batch, target=target_batch)

        # Backward pass. Dividing by accumulation_steps makes the accumulated
        # gradient an average over micro-batches rather than a sum.
        scaler.scale(loss / accumulation_steps).backward()

        running_loss += loss.item()
        n_batches += 1

        if n_batches % accumulation_steps == 0:  # Update every accumulation_steps
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

    # Apply gradients left over from a final, incomplete accumulation window
    if n_batches % accumulation_steps != 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    avg_loss = running_loss / max(n_batches, 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Step [{n_batches}/{len(seq_dl)}], Avg Loss: {avg_loss:.4f}")

    # Save a checkpoint per epoch. It uses its own name so model_path keeps
    # meaning the root directory.
    checkpoint_path = folder_path.joinpath(f'enformer-ft_epoch={epoch}_loss={avg_loss:.4f}.pth')
    torch.save(enformer.state_dict(), checkpoint_path)


## Testing modules

In [2]:
from eft.data import CustomGenomeIntervalDataset

In [5]:
prom = pd.read_csv('../data/processed/PROMOTERS.bed', sep='\t', header=None)

# Draw ONE shuffled sample and slice it so train and val are disjoint.
# Two separate .sample(..., random_state=42) calls reuse the same permutation,
# so the val rows ended up being a subset of the train rows (leakage).
n_train, n_val = 200, 10
shuffled = prom.sample(n=n_train + n_val, random_state=42)

out_dir = Path('../tests/data')
out_dir.mkdir(parents=True, exist_ok=True)
shuffled.iloc[:n_train].to_csv(out_dir / 'train.bed', sep='\t', header=False, index=False)
shuffled.iloc[n_train:].to_csv(out_dir / 'val.bed', sep='\t', header=False, index=False)

Unnamed: 0,0,1,2,3
0,chr1,64419,66419,"[0.0539362762208012, 0.0002493201924318617, 0...."
1,chr1,450678,452678,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
2,chr1,685654,687654,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
3,chr1,922923,924923,"[0.07061302648349242, 0.07809266664765098, 0.0..."
4,chr1,958256,960256,"[0.13313712802800265, 0.16205803697759455, 0.1..."


In [3]:
# Build the project's dataset over the promoter intervals, with the same
# context length as the training cell.
# NOTE(review): named `train` but reads the full PROMOTERS.bed, not
# ../tests/data/train.bed — confirm which file was intended.
data_dir = Path('../data/processed')
bed_path = data_dir / 'PROMOTERS.bed'
train = CustomGenomeIntervalDataset(
    bed_file=bed_path,
    fasta_file='../data/hg38.fa',
    context_length=196_608,
    return_seq_indices=True,
)