In [None]:
from Bio import SeqIO
import pandas as pd
from tqdm import tqdm
import os
import shutil
import numpy as np

In [None]:
def prepare_df_for_PlantCad_zeroshot(inp_df, ref_col, alt_col, frame=512, drop_nan_strategy="any"):
    """Turn a variant table into window coordinates centred on each variant.

    Parameters
    ----------
    inp_df : pandas.DataFrame
        Must contain the columns "chr", "start", "end", ``ref_col`` and ``alt_col``.
        The input frame is not modified.
    ref_col, alt_col : str
        Names of the columns holding the reference and alternative alleles.
    frame : int, default 512
        Window size. ``frame // 2 - 1`` positions are taken before ``start``
        and ``frame // 2`` positions after it.
    drop_nan_strategy : {"any", "both"}, default "any"
        "any"  - drop a row if either allele is NaN (zero-shot use).
        "both" - drop a row only if both alleles are NaN (fine-tuning use).

    Returns
    -------
    pandas.DataFrame
        Columns "chr", "start", "end", "pos", "ref", "alt", where "start"/"end"
        delimit the window and "pos" is the original ``start + 1``
        (1-based position). The original row index is retained.

    Raises
    ------
    ValueError
        If ``drop_nan_strategy`` is neither "any" nor "both".
    """
    variants = inp_df.copy()[["chr", "start", "end", ref_col, alt_col]]

    # Decide which rows to discard because of missing alleles
    ref_missing = variants[ref_col].isna()
    alt_missing = variants[alt_col].isna()
    if drop_nan_strategy == "any":
        discard = ref_missing | alt_missing
    elif drop_nan_strategy == "both":
        discard = ref_missing & alt_missing
    else:
        raise ValueError("Invalid drop_nan_strategy. Choose 'any' or 'both'.")
    variants = variants[~discard]

    upstream = frame // 2 - 1
    downstream = frame // 2

    # Variants too close to the chromosome start cannot get a full left flank
    variants = variants[variants["start"] >= upstream]

    variant_start = variants["start"]
    window_columns = {
        "chr": variants["chr"],
        "start": variant_start - upstream,
        "end": variant_start + 1 + downstream,
        "pos": variant_start + 1,  # "pos" counts from 1
        "ref": variants[ref_col],
        "alt": variants[alt_col],
    }
    return pd.DataFrame(window_columns)


In [None]:
def add_seqs_to_df_input_coordinates(genome_fasta, df_zero_shot_input_coordinates):
    """Attach the genomic sequence of every window to a coordinate table.

    Parameters
    ----------
    genome_fasta : str
        Path to a FASTA file; all records are loaded into memory, keyed by record id.
    df_zero_shot_input_coordinates : pandas.DataFrame
        Table with at least the columns "chr", "start" and "end". The slice
        ``[start:end]`` of the matching FASTA record is extracted for each row.
        The input frame is left unmodified.

    Returns
    -------
    pandas.DataFrame
        A copy of the input with an extra "sequences" column (upper-case strings).
        Rows are dropped when the chromosome is not present in the FASTA, when
        ``end`` exceeds the record length, or when ``start`` is negative.
    """
    # Read the whole genome once; lookups below are then plain dict accesses
    genome = {record.id: record.seq for record in SeqIO.parse(genome_fasta, "fasta")}

    def extract_sequence(row):
        chrom_seq = genome.get(row['chr'])
        if chrom_seq is None:
            # Chromosome name not found in the FASTA
            return None
        start = row['start']
        end = row['end']
        # A negative start would wrap around in Python slicing and silently
        # return the wrong (or an empty) sequence, so treat it as out of bounds
        # just like a window running past the end of the record.
        if start < 0 or end > len(chrom_seq):
            return None
        return str(chrom_seq[start:end]).upper()

    # Work on a copy so the caller's DataFrame does not gain a column as a side effect
    df_with_seqs = df_zero_shot_input_coordinates.copy()

    # Row-wise extraction with a progress bar
    tqdm.pandas(desc="Adding sequences")
    df_with_seqs['sequences'] = df_with_seqs.progress_apply(extract_sequence, axis=1)

    # Remove rows for which no sequence could be extracted
    return df_with_seqs.dropna(subset=['sequences'])

In [None]:
def correct_sequence_ref(inp_df, frame=512):
    """Force the variant position of every window sequence to the reference allele.

    Parameters
    ----------
    inp_df : pandas.DataFrame
        Must contain the columns "sequences" (window sequence) and "ref"
        (allele written into the sequence). The input frame is not modified.
    frame : int, default 512
        Window size the sequences were built with. The variant sits at 0-based
        index ``frame // 2 - 1`` of the window (255 for the default), which is
        the offset ``prepare_df_for_PlantCad_zeroshot`` puts before ``start``.

    Returns
    -------
    pandas.DataFrame
        A copy of ``inp_df`` whose "sequences" values have the character at the
        variant index replaced by the row's "ref" value.
    """
    df = inp_df.copy()

    # 0-based index of the variant inside the window
    center = frame // 2 - 1

    # Plain iteration over the two columns avoids building a Series per row
    # (as a row-wise apply does) and also works for an empty frame.
    df["sequences"] = [
        sequence[:center] + ref_allele + sequence[center + 1:]
        for sequence, ref_allele in zip(df["sequences"], df["ref"])
    ]
    return df

# Prepare files

In [None]:
# Input: a tab-separated, BED-like file with a header row, for example:
# chr     start   end     ref     alt
# Chr1A   119182  119183  A       C
# Chr1A   119183  119184  C       A
# Chr1A   119192  119193  C       T

In [None]:
# Input and output locations. The trailing comments keep the paths of the earlier TAIR10 run.
allele_dataset = "/home/labs/alevy/petrzhu/Wheat/1k_project_liftover_v2.1/SNPs_lifted_final2_sorted_v1.bed" #"/home/labs/alevy/petrzhu/AI_workshop/PlantCaduceus/datasets/TAIR/TAIR10_allele_dataset.txt.gz"
genome_fasta = "/home/labs/alevy/petrzhu/Prog/Bitbucket_msa/iwgsc_refseqv2.1/iwgsc_refseqv2.1_assembly.fa" #"/home/labs/alevy/omerbar/backups/TAIR/A_thaliana.fa"
output_file = "/home/labs/alevy/petrzhu/AI_workshop/PlantCaduceus/datasets/Wheat/ZH_1K_exomes.txt" #"/home/labs/alevy/petrzhu/AI_workshop/PlantCaduceus/datasets/TAIR/TAIR10_ZH_neutral_vs_simulated.txt"
# NOTE(review): "REAMDE.txt" looks like a typo for "README.txt"; the path is kept
# unchanged so existing notes keep being appended to the same file - confirm and rename.
README = "/home/labs/alevy/petrzhu/AI_workshop/PlantCaduceus/datasets/Wheat/REAMDE.txt"
ref_col = 'ref' 
alt_col = 'alt'

# Choose the decompression from the file content rather than hard-coding gzip:
# the dataset path has alternated between a plain .bed and a .txt.gz file, and
# forcing 'gzip' on an uncompressed file makes read_csv fail.
with open(allele_dataset, "rb") as dataset_handle:
    allele_compression = "gzip" if dataset_handle.read(2) == b"\x1f\x8b" else None

print('loading dataset')
df_allele_dataset = pd.read_csv(
    allele_dataset,
    sep="\t",
    header=0,
    compression=allele_compression,
    na_values=["NA", "null", ".", "-", "n/a", "N/A", "NaN"],
)
print('change coordinates')
df_zero_shot_input_coordinates = prepare_df_for_PlantCad_zeroshot(df_allele_dataset, ref_col, alt_col)
df_zero_shot_input = add_seqs_to_df_input_coordinates(genome_fasta, df_zero_shot_input_coordinates)
df_zero_shot_input_corr = correct_sequence_ref(df_zero_shot_input)

df_zero_shot_input_corr.to_csv(output_file, index=False, sep="\t")

# Provenance note for the generated file. Adjacent string literals are used so
# that source indentation does not leak into the text.
readme_text = (
    f"{output_file}: ref_col = {ref_col}, alt_col = {alt_col}. "
    "VCF = /home/labs/alevy/petrzhu/Wheat/1k_project_liftover_v2.1/SNPs_lifted_final2_sorted_v1.vcf "
    "ref_allele = ref from vcf, alt_allele = alt from vcf. \n"
)
# Append with plain Python instead of a shell `echo`, which breaks on quotes or
# spaces in the interpolated text and path.
with open(README, "a") as readme_handle:
    readme_handle.write(readme_text)


# Split files into chunks for downstream parallel analysis

In [None]:
# Reload the table written by the preparation step (same path as `output_file` there).
# Alternative dataset used earlier:
#   "/home/labs/alevy/petrzhu/AI_workshop/PlantCaduceus/datasets/TAIR/TAIR10_ZH_anc_vs_neutral.txt"
input_file = "/home/labs/alevy/petrzhu/AI_workshop/PlantCaduceus/datasets/Wheat/ZH_1K_exomes.txt"

# The file is read as plain text; add compression='gzip' for a gzip-compressed input.
input_df = pd.read_csv(
    input_file,
    sep="\t",
    header=0,
    na_values=["NA", "null", ".", "-", "n/a", "N/A", "NaN"],
)

In [None]:
# Number of rows written to each chunk file
chunk_size = 100000

for name_df, df in zip([input_file], [input_df]):
    # Build "<input path without extension>_chunks". splitext only touches the
    # final extension, whereas str.replace(".txt", ...) would rewrite every
    # ".txt" in the path and leave a path without ".txt" unchanged, so the
    # "folder" to delete would be the input file itself.
    output_folder = os.path.splitext(name_df)[0] + "_chunks"

    # Start from an empty folder so stale chunks from an earlier run cannot mix in.
    # Only a directory is ever removed; if the path is an existing regular file,
    # makedirs below fails loudly instead.
    if os.path.isdir(output_folder):
        shutil.rmtree(output_folder)
        print(f"The content of the {output_folder} was deleted")
    os.makedirs(output_folder, exist_ok=True)
    chunk_count = 0

    print(f"Splitting {len(df)} rows into chunks of size {chunk_size}")
    for chunk_start in range(0, len(df), chunk_size):
        chunk_count += 1
        chunk = df.iloc[chunk_start:chunk_start + chunk_size]
        # Chunk files are numbered from 1
        chunk_file = os.path.join(output_folder, f"chunk_{chunk_count}.tsv")
        chunk.to_csv(chunk_file, sep="\t", index=False)
        # Report progress every 10 chunks
        if chunk_count % 10 == 0:
            print(f"{chunk_count} chunks saved...")
    print(f"All chunks have been saved to the folder: {output_folder}. Total chunks = {chunk_count}")
