In [None]:
import torch
from transformers import ElectraConfig, ElectraForSequenceClassification, PreTrainedTokenizerFast, get_linear_schedule_with_warmup
from torch.optim import AdamW
from torch.utils.data import DataLoader, random_split, Dataset
from torch.utils.data.sampler import SubsetRandomSampler
from torch.nn import BCEWithLogitsLoss
import argparse
import os
import pandas as pd
import numpy as np
import logging
import copy
import random
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from torch.utils.data import DataLoader

In [None]:


# Run on the GPU when one is available; checkpoint tensors are mapped straight onto this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Architecture hyperparameters come from the discriminator JSON config; the
# classification head is given a single output logit.
config = ElectraConfig.from_json_file("/Users/wejarrard/projects/atacToChip/finetuning/preprocessing/output/discriminator.json")
config.num_labels = 1

# Fine-tuned checkpoint, expected to be a state_dict loadable by `load_state_dict`.
save_path = 'best_model'

# Build the (randomly initialised) model, then overwrite its weights from the checkpoint.
model = ElectraForSequenceClassification(config)
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
model.load_state_dict(torch.load(save_path, map_location=device))

# Move every parameter/buffer onto `device` so later forward passes match the input tensors.
model = model.to(device)


In [None]:
def get_tokenizer(tokenizer_path):
    """Build a fast tokenizer from a serialized ``tokenizer.json`` file.

    Args:
        tokenizer_path: Path to the tokenizer JSON file.

    Returns:
        A ``PreTrainedTokenizerFast`` constructed with ``max_len=512``.
    """
    # NOTE(review): `max_len` is a legacy alias of `model_max_length` in transformers;
    # confirm the installed version still honours it (callers here crop manually anyway).
    return PreTrainedTokenizerFast(tokenizer_file=tokenizer_path, max_len=512)

class GenomicsDataset(Dataset):
    """Map-style dataset turning per-region feather files into fixed-length model inputs.

    Every non-empty file found directly inside a key directory of ``dir_label_dict``
    becomes one sample labelled with that key's value. Each file is read with
    ``pd.read_feather`` and must provide the columns ``base`` (one base per row),
    ``reads``, ``pos`` and ``chrom``.

    Args:
        dir_label_dict: Mapping of directory path -> label.
        min_val, max_val: Bounds used to min-max scale read counts into [0, 1].
        dataset: ``"train"`` enables random augmentation (a random 0-10 base prefix
            cut and a random ``SEQ_LEN``-token window); any other value yields a
            deterministic, centred window.
        tokenizer: Object exposing ``encode(text, return_tensors="pt")`` and
            ``decode(token)``. When omitted, the tokenizer at
            ``DEFAULT_TOKENIZER_PATH`` is loaded via ``get_tokenizer``.
    """

    # Number of tokens per sample fed to the model.
    SEQ_LEN = 512
    # Used only when no tokenizer is passed to __init__.
    DEFAULT_TOKENIZER_PATH = os.path.join("/Users/wejarrard/projects/atacToChip/finetuning/preprocessing/output", 'tokenizer.json')

    def __init__(self, dir_label_dict, min_val, max_val, dataset="train", tokenizer=None):
        # Resolve the default lazily. As a default-argument expression, the tokenizer was
        # loaded at class-definition time, which failed whenever the hard-coded path was
        # missing, even if the caller supplied its own tokenizer.
        if tokenizer is None:
            tokenizer = get_tokenizer(self.DEFAULT_TOKENIZER_PATH)
        self.tokenizer = tokenizer
        self.min_val = min_val
        self.max_val = max_val
        self.dataset = dataset
        self.file_paths = []
        self.labels = []
        for dir_path, label in dir_label_dict.items():
            # sorted(): os.listdir order is arbitrary, so without it sample order (and any
            # per-sample outputs saved downstream) could differ between machines/runs.
            # Each file is read once here purely to skip empty ones.
            files = [
                os.path.join(dir_path, name)
                for name in sorted(os.listdir(dir_path))
                if not pd.read_feather(os.path.join(dir_path, name)).empty
            ]
            self.file_paths += files
            self.labels += [label] * len(files)

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return one sample as a dict of 1-D tensors of length ``SEQ_LEN`` (plus a scalar label).

        Keys: ``input_ids`` (token ids, zero-padded), ``position_ids`` (0..SEQ_LEN-1),
        ``chromosome`` (the file's first ``chrom`` value repeated), ``location`` (int64
        genomic start position of each token, 0 for padding), ``reads`` (mean normalised
        read count per token, 0 for padding) and ``labels``.
        """
        file_path = self.file_paths[idx]
        label = self.labels[idx]
        df = pd.read_feather(file_path)

        # Training-time augmentation: drop a small random prefix of 0-10 bases.
        # randint is only drawn in train mode, so evaluation consumes no randomness here.
        cut = random.randint(0, 10) if self.dataset == "train" else 0
        dna_sequence = df["base"].str.cat(sep="")[cut:]
        read_counts = torch.tensor(df["reads"].values.astype(np.float64))[cut:]
        positions = torch.tensor(df["pos"].values.astype(np.int64))[cut:]
        chromosome = df["chrom"].values[0]

        read_counts = self.min_max_norm(read_counts, self.min_val, self.max_val)

        tokenized_sequence = self.tokenizer.encode(dna_sequence, return_tensors="pt")
        n_tokens = tokenized_sequence.shape[1]

        final_read_counts = torch.zeros(n_tokens)
        # int64 rather than the float32 default: genomic coordinates exceed 2**24 and
        # would otherwise be silently rounded.
        final_positions = torch.zeros(n_tokens, dtype=torch.int64)

        # Aggregate the per-base tracks onto tokens: mean read count and first position.
        # NOTE(review): assumes decode(token) yields exactly the bases the token covers
        # (no special tokens / added whitespace); otherwise the alignment drifts and an
        # empty slice makes .min() raise — confirm against the tokenizer.
        cur_pos = 0
        for i, token in enumerate(tokenized_sequence[0]):
            token_length = len(self.tokenizer.decode(token))
            span = slice(cur_pos, cur_pos + token_length)
            final_read_counts[i] = read_counts[span].mean()
            final_positions[i] = positions[span].min()
            cur_pos += token_length

        seq_len = self.SEQ_LEN
        input_ids = tokenized_sequence[0]
        if n_tokens > seq_len:
            # Random window when training, centred window otherwise.
            if self.dataset == "train":
                start = random.randint(0, n_tokens - seq_len)
            else:
                start = (n_tokens - seq_len) // 2
            window = slice(start, start + seq_len)
            input_ids = input_ids[window]
            final_read_counts = final_read_counts[window]
            final_positions = final_positions[window]
        elif n_tokens < seq_len:
            # Right-pad every per-token track with zeros.
            padding = (0, seq_len - n_tokens)
            input_ids = torch.nn.functional.pad(input_ids, padding)
            final_read_counts = torch.nn.functional.pad(final_read_counts, padding)
            final_positions = torch.nn.functional.pad(final_positions, padding)

        return {
            'input_ids': input_ids,
            'position_ids': torch.arange(0, seq_len, 1),
            'chromosome': torch.full((seq_len,), chromosome),
            'location': final_positions,
            'reads': final_read_counts,
            'labels': torch.tensor(label)
        }

    def min_max_norm(self, x, min_val, max_val):
        """Min-max scale floating-point tensor ``x`` into [0, 1] in place and return it.

        Values below ``min_val`` map to 0 and values above ``max_val`` map to 1; NaNs
        are left as NaN.
        """
        return x.sub_(min_val).div_(max_val - min_val).clamp_(0, 1)

In [None]:
    
# Read-count normalisation bounds (presumably those used during training — keep in sync).
min_val, max_val = 0, 366.0038259577389
data_dir = "/Users/wejarrard/projects/atacToChip/finetuning/testing/output/test"

# Directory -> label: "atacseq_only" regions are negatives, "intersecting" regions positives.
dir_label_dict = {
    os.path.join(data_dir, "atacseq_only"): 0,
    # os.path.join(data_dir, "chipseq_only"): 1,
    os.path.join(data_dir, "intersecting"): 1
}

valid_dataset = GenomicsDataset(dir_label_dict, min_val, max_val, dataset='valid')

# create a DataLoader for the dataset (no shuffling so sample order is reproducible)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False)

# Batch entries kept for bookkeeping only; they are never passed to the model.
NON_MODEL_KEYS = {'labels', 'chromosome', 'location'}

# One entry per sample: predicted label (int), true label (int), per-token locations (array).
all_preds = []
all_labels = []
all_locations = []

model.eval()

with torch.no_grad():
    for batch in valid_loader:
        # NOTE(review): this forwards `reads` to the model; the stock transformers
        # ElectraForSequenceClassification.forward has no such argument — confirm a
        # patched model is in use. No attention_mask is passed for the zero padding.
        inputs = {k: v.to(device) for k, v in batch.items() if k not in NON_MODEL_KEYS}
        outputs = model(**inputs)
        # The head has a single logit per sample, shape (batch, 1); squeeze only the
        # trailing axis so a batch of size 1 keeps its batch dimension.
        prob = torch.sigmoid(outputs.logits).squeeze(-1)
        # Strict > 0.5 matches the previous torch.round (round-half-to-even maps 0.5 -> 0).
        pred_labels = (prob > 0.5).long().cpu().numpy()
        all_preds.extend(pred_labels.tolist())
        all_labels.extend(batch['labels'].cpu().numpy().tolist())
        all_locations.extend(batch['location'].cpu().numpy())

# Split per-sample locations by whether the prediction was right.
correct_locations = [loc for pred, true, loc in zip(all_preds, all_labels, all_locations) if pred == true]
incorrect_locations = [loc for pred, true, loc in zip(all_preds, all_labels, all_locations) if pred != true]

# zero_division=0: report 0 instead of warning when a split has no (predicted) positives.
accuracy = accuracy_score(all_labels, all_preds)
precision = precision_score(all_labels, all_preds, zero_division=0)
recall = recall_score(all_labels, all_preds, zero_division=0)
f1 = f1_score(all_labels, all_preds, zero_division=0)

# One row per sample, one integer column per token position (0 = padding).
# '%d' keeps genomic coordinates as integers instead of the default '%.18e'.
np.savetxt("correct_locations_both.txt", correct_locations, fmt='%d')
np.savetxt("incorrect_locations_both.txt", incorrect_locations, fmt='%d')

print(f'Accuracy: {accuracy:.2f}')
print(f'Precision: {precision:.2f}')
print(f'Recall: {recall:.2f}')
print(f'F1 score: {f1:.2f}')
# Full per-token location arrays are in the files above; print only counts to avoid flooding output.
print(f'Correctly predicted locations: {len(correct_locations)} samples (saved to correct_locations_both.txt)')
print(f'Incorrectly predicted locations: {len(incorrect_locations)} samples (saved to incorrect_locations_both.txt)')
