In [1]:
# !curl -L https://github.com/lucidrains/enformer-pytorch/raw/main/data/test-sample.pt -o ../data/test-sample.pt

In [1]:
# import os
# os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"

In [1]:
import pandas as pd
import numpy as np
from pathlib import Path

import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import TensorDataset
from torch.cuda.amp import autocast, GradScaler
from enformer_pytorch import Enformer, GenomeIntervalDataset

import kipoiseq
import seaborn as sns
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Pick the compute backend: Apple-silicon MPS first, then CUDA, falling back to CPU.
if torch.backends.mps.is_available():
    device_name, banner = "mps", "Running on MPS:"
elif torch.cuda.is_available():
    device_name, banner = "cuda", "Running on GPU:"
else:
    device_name, banner = "cpu", "Running on CPU:"

device = torch.device(device_name)
print(banner, device)

Running on MPS: mps


In [5]:
# Optional sanity check: run the official pretrained weights on the bundled test
# sample and require a minimum correlation. Off by default because it downloads
# the checkpoint and runs a full-length forward pass.
TEST_MODEL = False

if TEST_MODEL:
    enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough')
    enformer = enformer.to(device)

    data = torch.load('../data/test-sample.pt', map_location=device)
    seq, target = (data[key].to(device) for key in ('sequence', 'target'))

    with torch.no_grad():
        corr_coef = enformer(seq, target=target, return_corr_coef=True, head='human')

    print(corr_coef)
    assert corr_coef > 0.1

In [3]:
import pandas as pd
import numpy as np
from datetime import datetime
from pathlib import Path

def avg_bin(array, n_bins):
    """Downsample ``array`` to ``n_bins`` values by averaging consecutive chunks.

    The chunks come from ``np.array_split``, so when ``len(array)`` is not a
    multiple of ``n_bins`` they differ in length by at most one. If ``n_bins``
    exceeds the number of elements, some chunks are empty and their mean is
    ``nan`` (numpy also emits a RuntimeWarning).

    Returns a plain Python list with one mean per chunk, not an ndarray.
    """
    means = []
    for chunk in np.array_split(array, n_bins):
        means.append(np.mean(chunk))
    return means

# Data: processed promoter table; its 'values' column holds the per-row target track.
# NOTE: read_pickle can execute arbitrary code - only load pickles produced by this project.
seq_data = pd.read_pickle("../data/processed/PROMOTERS.pkl")

# Target: stack the per-row tracks into one tensor and add a trailing axis for the
# single output track (assumes every 'values' entry is 1-D and equally long - TODO confirm).
target = seq_data['values']
target = torch.stack([torch.tensor(i) for i in target]).unsqueeze(-1)


# DataLoaders
batch_size = 1  # a T4 GPU only has enough memory for a batch size of 1

seq_ds = GenomeIntervalDataset(
    bed_file = '../data/processed/PROMOTERS.bed',  # columns 0, 1, 2 must be <chromosome>, <start>, <end>
    fasta_file = '../data/hg38.fa',                # path to the reference genome FASTA
    return_seq_indices = True,                     # nucleotide indices (ACGTN) rather than one-hot encodings
    context_length = 196_608,
)

# Sequences and targets are paired purely by position (the two loaders are zipped in the
# training loop), so the BED file and the pickle must list the same rows in the same order.
# A length mismatch would silently truncate the data or pair sequences with the wrong targets.
assert len(seq_ds) == len(target), (
    f"PROMOTERS.bed yields {len(seq_ds)} intervals but PROMOTERS.pkl has {len(target)} targets"
)

seq_dl = DataLoader(seq_ds, batch_size=batch_size, shuffle=False)

target_ds = TensorDataset(target)
target_dl = DataLoader(target_ds, batch_size=batch_size, shuffle=False)

In [None]:
from enformer_pytorch import seq_indices_to_one_hot # is this only neccessary for CPU and MPS??
from enformer_pytorch.finetune import HeadAdapterWrapper
from datetime import datetime
from pathlib import Path
from torch.cuda.amp import autocast, GradScaler
import torch

# Training loop: fine-tune the pretrained Enformer trunk with a new single-track head,
# saving a checkpoint after every epoch.
model_path = Path("../test-models/")

# Setup paths: one timestamped folder per run so runs never overwrite each other
now = datetime.now()
formatted_date_time = now.strftime("%Y-%m-%d_%H-%M-%S")
folder_path = model_path.joinpath(formatted_date_time)
folder_path.mkdir(parents=True, exist_ok=True)

# Model
enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough')
enformer = HeadAdapterWrapper(
    enformer=enformer,
    num_tracks=1,
    post_transformer_embed=False
).to(device)
_ = enformer.train()

# torch.cuda.amp mixed precision only works on CUDA. Enable it explicitly there and turn it
# off on MPS/CPU instead of relying on the "CUDA is not available" fallback warnings.
use_amp = device.type == 'cuda'
scaler = GradScaler(enabled=use_amp)
optimizer = torch.optim.Adam(enformer.parameters(), lr=0.0001)

num_epochs = 1
accumulation_steps = 1

# The loaders are zipped below, so they must have the same number of batches.
# Otherwise zip silently drops the tail and the loss average below is wrong.
num_batches = len(seq_dl)
assert num_batches > 0, "sequence loader is empty"
assert num_batches == len(target_dl), (
    f"sequence loader has {num_batches} batches but target loader has {len(target_dl)}"
)

for epoch in range(num_epochs):
    running_loss = 0.0  # sum of the raw per-batch losses, used for reporting only
    optimizer.zero_grad(set_to_none=True)
    for idx, (seq_batch, (target_batch,)) in enumerate(zip(seq_dl, target_dl)):
        # The dataset yields nucleotide indices, but the model is fed float one-hot input.
        # Always convert: casting raw indices to float would feed the wrong input.
        seq_batch = seq_indices_to_one_hot(seq_batch)
        seq_batch = seq_batch.to(dtype=torch.float32, device=device)
        target_batch = target_batch.to(dtype=torch.float32, device=device)

        with autocast(enabled=use_amp):
            # Forward pass (the wrapper returns the loss when given a target)
            loss = enformer(seq_batch, target=target_batch)

        # Stop before a NaN/inf propagates into the weights and gets saved as a checkpoint.
        if not torch.isfinite(loss):
            raise FloatingPointError(
                f"non-finite loss {loss.item()} at epoch {epoch + 1}, batch {idx + 1}"
            )

        running_loss += loss.item()

        # Backward pass: scale by 1/accumulation_steps so accumulated gradients are
        # averaged over the window rather than summed.
        scaler.scale(loss / accumulation_steps).backward()

        window_done = (idx + 1) % accumulation_steps == 0
        last_batch = idx + 1 == num_batches
        if window_done or last_batch:  # also flush a trailing partial window
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

    avg_loss = running_loss / num_batches
    print(f"Epoch [{epoch+1}/{num_epochs}], Step [{idx+1}/{len(seq_dl)}], Avg Loss: {avg_loss:.4f}")

    # Save a checkpoint for this epoch (separate name so `model_path` keeps pointing at the directory)
    ckpt_file = folder_path.joinpath(f'enformer-ft_epoch={epoch}_loss={avg_loss:.4f}.pth')
    torch.save(enformer.state_dict(), ckpt_file)


## Testing modules

In [2]:
from eft.data import CustomGenomeIntervalDataset
from eft.models import EnformerTX
from scipy.stats import pearsonr

In [3]:
# Validation split plus the trained Lightning checkpoint to evaluate on it.
data_dir = Path('../data/sequences')
ckpt_path = "../lightning_logs/90-10/version_2/checkpoints/epoch=3-step=1116.ckpt"
batch_size = 4

val = CustomGenomeIntervalDataset(
    bed_file=data_dir / 'val.bed',
    fasta_file='../data/hg38.fa',
    return_seq_indices=True,  # nucleotide indices rather than one-hot encodings
    context_length=196_608,
)
val_dl = DataLoader(val, batch_size=batch_size, shuffle=False)

enformer = EnformerTX.load_from_checkpoint(ckpt_path)

In [4]:
# Run the loaded checkpoint over (the first few) validation batches and collect
# predictions and targets as numpy arrays.
filter_rand = False  # if True, drop rows whose target sums to zero (the random, no-expression sequences)
max_batches = 3      # quick-look cap on the number of validation batches; None = whole split

_ = enformer.eval()

all_target = []
all_pred = []
corr_coefs = []

with torch.no_grad():
    for idx, (seq, target) in enumerate(val_dl):
        if max_batches is not None and idx >= max_batches:
            break

        # filter out random seq with 0 expression
        if filter_rand:
            promoters = target.sum(dim=1) != 0
            seq = seq[promoters]
            target = target[promoters]
            if len(seq) == 0:  # the whole batch was filtered out; nothing to feed the model
                continue

        out = enformer(seq.to(enformer.device), None)
        all_pred.append(out.squeeze(-1).cpu().numpy())
        all_target.append(target.cpu().numpy())

if not all_pred:
    raise RuntimeError(
        "no validation samples were evaluated (empty loader, or every row was filtered out)"
    )

all_target = np.concatenate(all_target)
all_pred = np.concatenate(all_pred)

# Surface NaN predictions immediately instead of letting them poison downstream metrics.
n_nan = int(np.isnan(all_pred).sum())
if n_nan:
    print(f"WARNING: {n_nan} of {all_pred.size} predicted values are NaN")

# TODO: per-sample Pearson r and its distribution (currently disabled):
# corr_coefs = [pearsonr(t, p)[0] for t, p in zip(all_target, all_pred)]
# avg_cc = np.mean(corr_coefs)
# sns.violinplot(corr_coefs)
# sns.swarmplot(corr_coefs, color='k')
# plt.title(f"Average pearsonr={avg_cc:.4f}")

In [6]:
# Inspect the collected targets and predictions.
# NOTE(review): by default Jupyter only echoes the last expression; both arrays appear in the
# saved output, presumably because InteractiveShell.ast_node_interactivity is set to 'all' - confirm.
# NOTE(review): every value visible in the all_pred output below is NaN - investigate the
# checkpoint / forward pass before computing any metrics.
all_target
all_pred

array([[0.        , 0.02049967, 0.04243984, ..., 0.01965334, 0.05225051,
        0.05141183],
       [0.03897707, 0.05997542, 0.04063922, ..., 0.39189282, 0.27917278,
        0.11266407],
       [0.01102549, 0.08504591, 0.09557272, ..., 0.01599106, 0.01685769,
        0.0076321 ],
       ...,
       [0.02227262, 0.00398912, 0.01756323, ..., 0.05750633, 0.05960308,
        0.05663969],
       [0.        , 0.00470938, 0.01711997, ..., 0.00089461, 0.00304725,
        0.00053117],
       [0.06579287, 0.05983687, 0.02326989, ..., 0.01822758, 0.        ,
        0.0148728 ]], dtype=float32)

array([[nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       ...,
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan]], dtype=float32)

In [29]:
## downsample train/val: load the 50-50 splits as raw BED tables (no header row)
# NOTE: this rebinds `val`, replacing the validation Dataset created earlier in the notebook.
bed_read_kwargs = dict(sep='\t', header=None)
train = pd.read_csv("../data/sequences/train_50-50.bed", **bed_read_kwargs)
val = pd.read_csv("../data/sequences/val_50-50.bed", **bed_read_kwargs)
train.head()

Unnamed: 0,0,1,2,3,4
0,chr22,15527192,15529192,promoter,"[0.01468216348439455, 0.005263423758812926, 0...."
1,chr12,27709822,27711822,promoter,"[0.052080263807014984, 0.05263428305360404, 0...."
2,chrY,49339788,49438092,random,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
3,chrY,49165093,49263397,random,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
4,chr22,14123836,14222140,random,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."


In [30]:
def balance_dataset(dataset, label_column, positive_fraction):
    """
    Balances the dataset based on the specified fraction of positive cases.

    Parameters:
    - dataset (pd.DataFrame): The dataset to be balanced.
    - label_column (int or str): The column in the dataset containing the class labels.
    - positive_fraction (float): The desired fraction of positive cases in the dataset.

    Returns:
    - pd.DataFrame: A balanced dataset.
    """

    # Count the number of positive instances
    pos_count = dataset[label_column].value_counts().loc['promoter']

    # Calculate the number of negative instances needed for the desired balance
    total_count = pos_count / positive_fraction
    neg_count = int(total_count - pos_count)

    # Extract positive and negative samples
    positives = dataset[dataset[label_column] == 'promoter']
    negatives = dataset[dataset[label_column] != 'promoter'].sample(n=neg_count)

    # Concatenate and shuffle the dataset
    balanced_dataset = pd.concat([positives, negatives]).sample(frac=1).reset_index(drop=True)

    return balanced_dataset

# Rebalance both splits to 90% promoter rows (labels live in column 3).
# The train split is sampled before the val split.
balanced_train, balanced_val = (
    balance_dataset(split_df, 3, 0.9) for split_df in (train, val)
)


In [32]:
# Write the balanced splits back out as header-less, tab-separated BED files.
bed_outputs = {
    "../data/sequences/train.bed": balanced_train,
    "../data/sequences/val.bed": balanced_val,
}
for out_path, split_df in bed_outputs.items():
    split_df.to_csv(out_path, sep='\t', header=False, index=False)