In [2]:
import argparse
import numpy as np
import os
import torch
from torch.utils.data import DataLoader

from model import BertAttentionScoreExtractor
from utils.data_preprocessing import load_dataset
from utils.attention_utils import extract_kmer_attention_vectors

In [3]:
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
# Hyper-parameters are declared through argparse but parsed from an empty list,
# so the notebook always uses the defaults and ignores Jupyter's own argv.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--batch_size",
    type=int,
    default=8,  # the training set has 2968 sequences; 2968 / 8 = 371 full batches, no ragged tail
    help="",
)
parser.add_argument(
    "--max_length",
    type=int,
    default=200,
    help="",
)
args = parser.parse_args(args=[])

In [5]:
# Root folder under which every saved attention-score array is written.
# It is created up front; an already-existing directory is not an error.
output_dir = "./outputs/attention_scores"
os.makedirs(name=output_dir, exist_ok=True)

In [6]:
kmer_values = [3, 4, 5, 6]
model_date = "2025-02-27_V2"

# Every array from this model run goes into one dated sub-folder of `output_dir`.
# Reusing `output_dir` (instead of re-spelling the path) keeps the two from drifting apart.
# The folder does not depend on the k-mer, so it is created once, outside the loop.
run_output_dir = os.path.join(output_dir, model_date)
os.makedirs(run_output_dir, exist_ok=True)

for kmer in kmer_values:
    # load_dataset and extract_kmer_attention_vectors both receive `args`,
    # so the per-k-mer settings are stored on it before each call.
    args.kmer = kmer
    args.model_path = f"./outputs/identifier_models/{model_date}/{kmer}-mer"
    args.train_data_path = f"./data/enhancer_identification/{kmer}-mer_identification_train.txt"

    # Drop the previous k-mer's model before loading the next one.
    # Otherwise both models are resident on the device during from_pretrained.
    model = None
    if device.type == "cuda":
        torch.cuda.empty_cache()
    model = BertAttentionScoreExtractor.from_pretrained(args.model_path, output_attentions=True).to(device)

    # shuffle=False keeps the rows of the saved array in dataset order.
    train_dataset = load_dataset(args, validation=False)
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False)

    train_attention_scores = extract_kmer_attention_vectors(model, train_dataloader, args)

    # Save the train_attention_scores
    file_path = os.path.join(run_output_dir, f"{kmer}-mer_train_attention_scores.npy")
    np.save(file_path, train_attention_scores)

    print(f"{kmer}-mer attention scores saved to {file_path}")
    print(f"Size: {train_attention_scores.shape}")

Attention scores shape for 3-mer: (2968, 12, 200, 200)
3-mer attention scores saved to ./outputs/attention_scores/2025-02-27_V2/3-mer_train_attention_scores.npy
Size: (2968, 200)
Attention scores shape for 4-mer: (2968, 12, 199, 199)
4-mer attention scores saved to ./outputs/attention_scores/2025-02-27_V2/4-mer_train_attention_scores.npy
Size: (2968, 200)
Attention scores shape for 5-mer: (2968, 12, 198, 198)
5-mer attention scores saved to ./outputs/attention_scores/2025-02-27_V2/5-mer_train_attention_scores.npy
Size: (2968, 200)
Attention scores shape for 6-mer: (2968, 12, 197, 197)
6-mer attention scores saved to ./outputs/attention_scores/2025-02-27_V2/6-mer_train_attention_scores.npy
Size: (2968, 200)
