In [None]:
import kipoiseq
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from train import TrainModule, FastaStringExtractor

import matplotlib.pyplot as plt
import seaborn as sns

import pyBigWig
from kipoiseq import Interval
import math
import random
from scipy.stats import pearsonr

# Seed Python's built-in `random` module so any later use of it is reproducible.
random.seed(2077)
# Width of the genomic window handed to the model; used as the default
# `region_len` of generate_inputs and as the resize width of the target interval.
SEQUENCE_LENGTH = 65536

def one_hot_encode(sequence):
    """One-hot encode a nucleotide string into a ``(len(sequence), 5)`` array.

    The channel order is A, T, C, G, N. Note that this is *not* the
    alphabetical A, C, G, T order used by
    ``kipoiseq.transforms.functional.one_hot_dna``.

    Args:
        sequence: String made of the characters A, T, C, G, N. Lower-case
            letters (as found in soft-masked FASTA files) are treated like
            their upper-case counterparts.

    Returns:
        ``np.float32`` array of shape ``(len(sequence), 5)`` holding exactly
        one 1 per row.

    Raises:
        KeyError: If ``sequence`` contains a character outside A/T/C/G/N.
    """
    channel_of = {'A': 0, 'T': 1, 'C': 2, 'G': 3, 'N': 4}
    # Upper-case first so soft-masked (lower-case) bases do not raise KeyError.
    channel_idx = np.array([channel_of[base] for base in sequence.upper()], dtype=int)
    encoded = np.zeros((len(channel_idx), len(channel_of)), dtype=np.float32)
    # Fancy indexing sets one column per row; a no-op for an empty sequence.
    encoded[np.arange(len(channel_idx)), channel_idx] = 1
    return encoded

def plot_tracks(tracks, interval, height=1.5):
    """Draw each track as a filled curve in its own row of a shared-x figure.

    Args:
        tracks: Mapping ``title -> 1-D sequence of values``; one subplot is
            created per entry, in iteration order.
        interval: Object with ``start`` and ``end`` attributes; the values of
            each track are spread evenly over ``[start, end]`` and
            ``str(interval)`` labels the x-axis of the bottom subplot.
        height: Height of a single subplot row (figure width is fixed at 20).

    Raises:
        ValueError: If ``tracks`` is empty.
    """
    if not tracks:
        raise ValueError("tracks must contain at least one entry")
    # squeeze=False always yields a 2-D array of axes, so a single track works
    # too (the default returns a bare Axes object that cannot be zipped over).
    fig, axes_grid = plt.subplots(len(tracks), 1, figsize=(20, height * len(tracks)),
                                  sharex=True, squeeze=False)
    axes = axes_grid[:, 0]
    for ax, (title, y) in zip(axes, tracks.items()):
        ax.fill_between(np.linspace(interval.start, interval.end, num=len(y)), y)
        ax.set_title(title)
    # Applies to every axes of the figure, so one call after the loop suffices.
    sns.despine(fig=fig, top=True, right=True, bottom=True)
    axes[-1].set_xlabel(str(interval))
    plt.tight_layout()
    
def variant_generator(vcf_file):
    """Yield a ``kipoiseq.Variant`` for every ALT allele listed in a VCF file.

    Args:
        vcf_file: Path to an uncompressed, tab-separated VCF file.

    Yields:
        One ``kipoiseq.Variant`` per (record, ALT allele) pair, built from the
        first five columns (CHROM, POS, ID, REF, ALT). POS is passed on as the
        string read from the file.
    """
    with open(vcf_file) as f:
        for line in f:
            # Meta/header lines ("##..." and "#CHROM ...") and blank lines
            # carry no variant; "##" lines do not even have five columns.
            if line.startswith('#') or not line.strip():
                continue
            # Strip the line ending so it cannot leak into ALT on 5-column files.
            chrom, pos, variant_id, ref, alt_field = line.rstrip('\r\n').split('\t')[:5]
            # Split ALT alleles and return individual variants as output.
            for alt in alt_field.split(','):
                yield kipoiseq.Variant(chrom=chrom, pos=pos, ref=ref, alt=alt, id=variant_id)
        
def _parse_gff_attributes(attribute_field):
    """Turn a GFF3 attribute string ``k1=v1;k2=v2`` into a ``{key: value}`` dict."""
    pairs = (item.split('=', 1) for item in attribute_field.split(';') if '=' in item)
    return {key.strip(): value for key, value in pairs}


def get_cds_interval(gff_df, gene_name):
    """Return coordinates and CDS intervals of the first transcript of a gene.

    Attributes are looked up by key (``gene_name``, ``ID``, ``transcript_id``,
    ``transcript_name``) rather than by their position in column 8, so the
    result does not depend on the attribute ordering of the annotation file.

    Args:
        gff_df: Header-less GFF3 table with integer column labels 0..8
            (0 = chrom, 2 = feature type, 3 = start, 4 = end, 8 = attributes).
        gene_name: Value of the ``gene_name`` attribute to select.

    Returns:
        ``None`` if the gene has no ``transcript`` row; otherwise a dict with
        keys ``chrom``, ``gene_start``, ``gene_end`` (coordinates of the first
        transcript row), ``cds_intervals`` (list of ``(start, end)`` tuples of
        that transcript's CDS rows) and ``gene_name``. Note that ``gene_name``
        holds the transcript's ``transcript_name`` attribute, not the gene
        symbol.

    Raises:
        KeyError: If a transcript row lacks ``ID``, a CDS row lacks
            ``transcript_id``, or a CDS row refers to an unknown transcript.
    """
    # Keep only the rows whose gene_name attribute matches exactly.
    row_gene_names = gff_df[8].str.extract(r'(?:^|;)\s*gene_name=([^;]*)', expand=False)
    gene_rows = gff_df[row_gene_names == gene_name]

    # One entry per transcript, in file order, keyed by the transcript's ID.
    transcripts = {}
    for _, row in gene_rows[gene_rows[2] == 'transcript'].iterrows():
        attrs = _parse_gff_attributes(row[8])
        transcripts[attrs['ID']] = {
            'chrom': row[0],
            'start': row[3],
            'end': row[4],
            'name': attrs.get('transcript_name'),
            'cds_intervals': [],
        }

    # Attach every CDS segment to the transcript it belongs to.
    for _, row in gene_rows[gene_rows[2] == 'CDS'].iterrows():
        transcript_id = _parse_gff_attributes(row[8])['transcript_id']
        transcripts[transcript_id]['cds_intervals'].append((row[3], row[4]))

    if not transcripts:
        return None
    first_transcript = next(iter(transcripts.values()))
    return {
        'chrom': first_transcript['chrom'],
        'gene_start': first_transcript['start'],
        'gene_end': first_transcript['end'],
        'cds_intervals': first_transcript['cds_intervals'],
        'gene_name': first_transcript['name'],
    }
    
def generate_inputs(region, fasta_file, bw_file, region_len=SEQUENCE_LENGTH):
    """Read a per-position bigWig signal for a window centred on ``region``.

    The region is resized to ``region_len``; parts of the window that fall
    outside the chromosome are zero-padded, NaNs are replaced by 0 and the
    result is transformed with ``log(x + 1)``.

    Args:
        region: Object with ``chrom``, ``start`` and ``end`` attributes.
        fasta_file: Unused; kept so existing calls keep working.
        bw_file: Path of the bigWig file to read.
        region_len: Width the region is resized to before reading.

    Returns:
        ``np.float32`` array of shape ``(1, L)`` where ``L`` is the width of
        the resized window.

    Raises:
        ValueError: If ``region.chrom`` is not present in the bigWig file.
    """
    bw = pyBigWig.open(bw_file)
    try:
        chrom = region.chrom
        chromosome_length = bw.chroms(chrom)
        if chromosome_length is None:
            raise ValueError(f"chromosome {chrom!r} not found in {bw_file}")
        interval = Interval(chrom, region.start, region.end).resize(region_len)
        # Only the part of the window lying on the chromosome can be queried.
        clipped_start = max(interval.start, 0)
        clipped_end = min(interval.end, chromosome_length)
        signal = np.asarray(bw.values(chrom, clipped_start, clipped_end), dtype=np.float32)
    finally:
        # Release the file handle even when the lookup above fails.
        bw.close()

    # Zero-fill whatever the window overhangs on either chromosome end.
    pad_upstream = np.zeros(max(-interval.start, 0), dtype=np.float32)
    pad_downstream = np.zeros(max(interval.end - chromosome_length, 0), dtype=np.float32)
    padded = np.concatenate([pad_upstream, signal, pad_downstream])

    # `nan` must be passed by keyword: the second positional argument of
    # np.nan_to_num is `copy`, not the replacement value.
    target = np.nan_to_num(padded[np.newaxis, :], nan=0.0)
    return np.log(target + 1)

def generate_outputs(region, fasta_file, bw_file, nBins=1024, region_len=65536):
    """Read a bigWig signal for a window around ``region`` and average it in 64-wide bins.

    The region is resized to ``region_len``; parts of the window outside the
    chromosome are zero-padded. The padded signal is averaged over consecutive
    bins of 64 positions, NaN bins are set to 0 and the result is transformed
    with ``log(x + 1)``.

    Args:
        region: Object with ``chrom``, ``start`` and ``end`` attributes.
        fasta_file: Unused; kept so existing calls keep working.
        bw_file: Path of the bigWig file to read.
        nBins: Unused; the bin width is fixed at 64, so the number of bins is
            the window width divided by 64 (1024 for the default
            ``region_len``).
        region_len: Width the region is resized to before reading; the
            resulting window width must be a multiple of 64.

    Returns:
        ``np.float32`` array of shape ``(1, n_bins)``.

    Raises:
        ValueError: If ``region.chrom`` is not present in the bigWig file, or
            (from ``reshape``) if the window width is not a multiple of 64.
    """
    bin_width = 64
    bw = pyBigWig.open(bw_file)
    try:
        chrom = region.chrom
        chromosome_length = bw.chroms(chrom)
        if chromosome_length is None:
            raise ValueError(f"chromosome {chrom!r} not found in {bw_file}")
        interval = Interval(chrom, region.start, region.end).resize(region_len)
        # Only the part of the window lying on the chromosome can be queried.
        clipped_start = max(interval.start, 0)
        clipped_end = min(interval.end, chromosome_length)
        signal = np.asarray(bw.values(chrom, clipped_start, clipped_end), dtype=np.float32)
    finally:
        # Release the file handle even when the lookup above fails.
        bw.close()

    # Zero-fill whatever the window overhangs on either chromosome end.
    pad_upstream = np.zeros(max(-interval.start, 0), dtype=np.float32)
    pad_downstream = np.zeros(max(interval.end - chromosome_length, 0), dtype=np.float32)
    padded = np.concatenate([pad_upstream, signal, pad_downstream])

    # NOTE(review): the mean is taken before NaNs are removed, so a bin that
    # contains any NaN becomes 0 as a whole. Kept as-is to preserve the values
    # this function has always produced - confirm this is intended.
    averages = np.mean(padded.reshape(-1, bin_width), axis=1)

    # `nan` must be passed by keyword: the second positional argument of
    # np.nan_to_num is `copy`, not the replacement value.
    target = np.nan_to_num(averages[np.newaxis, :], nan=0.0)
    return np.log(target + 1)




# ---- Model: load the trained checkpoint and switch to evaluation mode ----
device = 'cuda:0'
checkpoint = '/data/slurm/hejl/riboseq/results_DNA/bigmodel/bigmodel_h512_l12_lr1e-5/models/epoch=38-step=746889.ckpt'
model = TrainModule.load_from_checkpoint(checkpoint).to(device)
model = model.eval()

# ---- Input files ----
#GTF_FILE = '/data/slurm/leixiong/m6A_prediction/data/gencode.v42.chr_patch_hapl_scaff.annotation.gtf.gz'
SEQUENCE_LENGTH = 65536
fasta_file = '/data/slurm/hejl/riboseq/data/hg38/hg38.fa'#str(DATA_PATH / 'hg38.fa')
fasta_extractor = FastaStringExtractor(fasta_file)
gff_file = '/data/slurm/hejl/riboseq/gencode.v43.annotation.gff3'
#rna_bw_file = '/data/slurm/hejl/riboseq/data/hg38/K562/GSE153597/input_features/rnaseq.bw'
rna_bw_file = '/data/slurm/hejl/riboseq/data/hg38/neuron/GSE90469_2/input_features/rnaseq.bw'
ribo_bw_file = '/data/slurm/hejl/riboseq/data/hg38/neuron/GSE90469_2/output_features/riboseq.bw'

gene_name = 'OR4F29'
#gene_name = 'KRT19'

# ---- Locate the gene and build the window centred on it ----
gff_df = pd.read_csv(gff_file, sep='\t', comment='#', header=None)
gene_info = get_cds_interval(gff_df, gene_name)
if gene_info is None:
    # get_cds_interval returns None when the gene has no transcript row;
    # fail here with a clear message instead of a TypeError further down.
    raise ValueError(f"No transcript found for gene {gene_name!r} in {gff_file}")

gene_interval = kipoiseq.Interval(gene_info['chrom'], gene_info['gene_start'], gene_info['gene_end'])
region_interval = gene_interval
target_interval = Interval(region_interval.chrom, region_interval.start, region_interval.end).resize(SEQUENCE_LENGTH)

# ---- Model input: one-hot sequence (5 channels) + RNA-seq signal (1 channel) ----
ref_seq = fasta_extractor.extract(target_interval)
ref_emb = torch.Tensor(one_hot_encode(ref_seq)).to(device)

epi = torch.Tensor(generate_inputs(region_interval, fasta_file, rna_bw_file)[0]).unsqueeze(1).to(device)
reference_input = torch.cat([ref_emb, epi], dim = 1).unsqueeze(0)
# Inference only: no autograd graph is needed.
with torch.no_grad():
    pred = model(reference_input)[0].detach().cpu().numpy()

# generate_outputs already returns a float32 numpy array; row 0 is the track.
gt = generate_outputs(region_interval, fasta_file, ribo_bw_file)[0]

# ---- Plot and score; titles name the dataset actually loaded (neuron) ----
tracks = {'Ground Truth(Neuron)': gt,
          'Prediction(Neuron)': pred
          }

plot_tracks(tracks, target_interval, height=1) #

correlation, p_value = pearsonr(pred,gt)

print('Pearson: ')
print(correlation)


##

In [None]:
import kipoiseq
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from train import TrainModule, FastaStringExtractor

import matplotlib.pyplot as plt
import seaborn as sns

import pyBigWig
from kipoiseq import Interval
import math
import random
from scipy.stats import pearsonr

# Seed Python's built-in `random` module so any later use of it is reproducible.
random.seed(2077)
# Width of the genomic window handed to the model; used as the default
# `region_len` of generate_inputs and as the resize width of the target interval.
SEQUENCE_LENGTH = 65536

def one_hot_encode(sequence):
    """One-hot encode a nucleotide string into a ``(len(sequence), 5)`` array.

    The channel order is A, T, C, G, N. Note that this is *not* the
    alphabetical A, C, G, T order used by
    ``kipoiseq.transforms.functional.one_hot_dna``.

    Args:
        sequence: String made of the characters A, T, C, G, N. Lower-case
            letters (as found in soft-masked FASTA files) are treated like
            their upper-case counterparts.

    Returns:
        ``np.float32`` array of shape ``(len(sequence), 5)`` holding exactly
        one 1 per row.

    Raises:
        KeyError: If ``sequence`` contains a character outside A/T/C/G/N.
    """
    channel_of = {'A': 0, 'T': 1, 'C': 2, 'G': 3, 'N': 4}
    # Upper-case first so soft-masked (lower-case) bases do not raise KeyError.
    channel_idx = np.array([channel_of[base] for base in sequence.upper()], dtype=int)
    encoded = np.zeros((len(channel_idx), len(channel_of)), dtype=np.float32)
    # Fancy indexing sets one column per row; a no-op for an empty sequence.
    encoded[np.arange(len(channel_idx)), channel_idx] = 1
    return encoded

def plot_tracks(tracks, interval, height=1.5):
    """Draw each track as a filled curve in its own row of a shared-x figure.

    Args:
        tracks: Mapping ``title -> 1-D sequence of values``; one subplot is
            created per entry, in iteration order.
        interval: Object with ``start`` and ``end`` attributes; the values of
            each track are spread evenly over ``[start, end]`` and
            ``str(interval)`` labels the x-axis of the bottom subplot.
        height: Height of a single subplot row (figure width is fixed at 20).

    Raises:
        ValueError: If ``tracks`` is empty.
    """
    if not tracks:
        raise ValueError("tracks must contain at least one entry")
    # squeeze=False always yields a 2-D array of axes, so a single track works
    # too (the default returns a bare Axes object that cannot be zipped over).
    fig, axes_grid = plt.subplots(len(tracks), 1, figsize=(20, height * len(tracks)),
                                  sharex=True, squeeze=False)
    axes = axes_grid[:, 0]
    for ax, (title, y) in zip(axes, tracks.items()):
        ax.fill_between(np.linspace(interval.start, interval.end, num=len(y)), y)
        ax.set_title(title)
    # Applies to every axes of the figure, so one call after the loop suffices.
    sns.despine(fig=fig, top=True, right=True, bottom=True)
    axes[-1].set_xlabel(str(interval))
    plt.tight_layout()
    
def variant_generator(vcf_file):
    """Yield a ``kipoiseq.Variant`` for every ALT allele listed in a VCF file.

    Args:
        vcf_file: Path to an uncompressed, tab-separated VCF file.

    Yields:
        One ``kipoiseq.Variant`` per (record, ALT allele) pair, built from the
        first five columns (CHROM, POS, ID, REF, ALT). POS is passed on as the
        string read from the file.
    """
    with open(vcf_file) as f:
        for line in f:
            # Meta/header lines ("##..." and "#CHROM ...") and blank lines
            # carry no variant; "##" lines do not even have five columns.
            if line.startswith('#') or not line.strip():
                continue
            # Strip the line ending so it cannot leak into ALT on 5-column files.
            chrom, pos, variant_id, ref, alt_field = line.rstrip('\r\n').split('\t')[:5]
            # Split ALT alleles and return individual variants as output.
            for alt in alt_field.split(','):
                yield kipoiseq.Variant(chrom=chrom, pos=pos, ref=ref, alt=alt, id=variant_id)
        
def _parse_gff_attributes(attribute_field):
    """Turn a GFF3 attribute string ``k1=v1;k2=v2`` into a ``{key: value}`` dict."""
    pairs = (item.split('=', 1) for item in attribute_field.split(';') if '=' in item)
    return {key.strip(): value for key, value in pairs}


def get_cds_interval(gff_df, gene_name):
    """Return coordinates and CDS intervals of the first transcript of a gene.

    Attributes are looked up by key (``gene_name``, ``ID``, ``transcript_id``,
    ``transcript_name``) rather than by their position in column 8, so the
    result does not depend on the attribute ordering of the annotation file.

    Args:
        gff_df: Header-less GFF3 table with integer column labels 0..8
            (0 = chrom, 2 = feature type, 3 = start, 4 = end, 8 = attributes).
        gene_name: Value of the ``gene_name`` attribute to select.

    Returns:
        ``None`` if the gene has no ``transcript`` row; otherwise a dict with
        keys ``chrom``, ``gene_start``, ``gene_end`` (coordinates of the first
        transcript row), ``cds_intervals`` (list of ``(start, end)`` tuples of
        that transcript's CDS rows) and ``gene_name``. Note that ``gene_name``
        holds the transcript's ``transcript_name`` attribute, not the gene
        symbol.

    Raises:
        KeyError: If a transcript row lacks ``ID``, a CDS row lacks
            ``transcript_id``, or a CDS row refers to an unknown transcript.
    """
    # Keep only the rows whose gene_name attribute matches exactly.
    row_gene_names = gff_df[8].str.extract(r'(?:^|;)\s*gene_name=([^;]*)', expand=False)
    gene_rows = gff_df[row_gene_names == gene_name]

    # One entry per transcript, in file order, keyed by the transcript's ID.
    transcripts = {}
    for _, row in gene_rows[gene_rows[2] == 'transcript'].iterrows():
        attrs = _parse_gff_attributes(row[8])
        transcripts[attrs['ID']] = {
            'chrom': row[0],
            'start': row[3],
            'end': row[4],
            'name': attrs.get('transcript_name'),
            'cds_intervals': [],
        }

    # Attach every CDS segment to the transcript it belongs to.
    for _, row in gene_rows[gene_rows[2] == 'CDS'].iterrows():
        transcript_id = _parse_gff_attributes(row[8])['transcript_id']
        transcripts[transcript_id]['cds_intervals'].append((row[3], row[4]))

    if not transcripts:
        return None
    first_transcript = next(iter(transcripts.values()))
    return {
        'chrom': first_transcript['chrom'],
        'gene_start': first_transcript['start'],
        'gene_end': first_transcript['end'],
        'cds_intervals': first_transcript['cds_intervals'],
        'gene_name': first_transcript['name'],
    }
    
def generate_inputs(region, fasta_file, bw_file, region_len=SEQUENCE_LENGTH):
    """Read a per-position bigWig signal for a window centred on ``region``.

    The region is resized to ``region_len``; parts of the window that fall
    outside the chromosome are zero-padded, NaNs are replaced by 0 and the
    result is transformed with ``log(x + 1)``.

    Args:
        region: Object with ``chrom``, ``start`` and ``end`` attributes.
        fasta_file: Unused; kept so existing calls keep working.
        bw_file: Path of the bigWig file to read.
        region_len: Width the region is resized to before reading.

    Returns:
        ``np.float32`` array of shape ``(1, L)`` where ``L`` is the width of
        the resized window.

    Raises:
        ValueError: If ``region.chrom`` is not present in the bigWig file.
    """
    bw = pyBigWig.open(bw_file)
    try:
        chrom = region.chrom
        chromosome_length = bw.chroms(chrom)
        if chromosome_length is None:
            raise ValueError(f"chromosome {chrom!r} not found in {bw_file}")
        interval = Interval(chrom, region.start, region.end).resize(region_len)
        # Only the part of the window lying on the chromosome can be queried.
        clipped_start = max(interval.start, 0)
        clipped_end = min(interval.end, chromosome_length)
        signal = np.asarray(bw.values(chrom, clipped_start, clipped_end), dtype=np.float32)
    finally:
        # Release the file handle even when the lookup above fails.
        bw.close()

    # Zero-fill whatever the window overhangs on either chromosome end.
    pad_upstream = np.zeros(max(-interval.start, 0), dtype=np.float32)
    pad_downstream = np.zeros(max(interval.end - chromosome_length, 0), dtype=np.float32)
    padded = np.concatenate([pad_upstream, signal, pad_downstream])

    # `nan` must be passed by keyword: the second positional argument of
    # np.nan_to_num is `copy`, not the replacement value.
    target = np.nan_to_num(padded[np.newaxis, :], nan=0.0)
    return np.log(target + 1)

def generate_outputs(region, fasta_file, bw_file, nBins=1024, region_len=65536):
    """Read a bigWig signal for a window around ``region`` and average it in 64-wide bins.

    The region is resized to ``region_len``; parts of the window outside the
    chromosome are zero-padded. The padded signal is averaged over consecutive
    bins of 64 positions, NaN bins are set to 0 and the result is transformed
    with ``log(x + 1)``.

    Args:
        region: Object with ``chrom``, ``start`` and ``end`` attributes.
        fasta_file: Unused; kept so existing calls keep working.
        bw_file: Path of the bigWig file to read.
        nBins: Unused; the bin width is fixed at 64, so the number of bins is
            the window width divided by 64 (1024 for the default
            ``region_len``).
        region_len: Width the region is resized to before reading; the
            resulting window width must be a multiple of 64.

    Returns:
        ``np.float32`` array of shape ``(1, n_bins)``.

    Raises:
        ValueError: If ``region.chrom`` is not present in the bigWig file, or
            (from ``reshape``) if the window width is not a multiple of 64.
    """
    bin_width = 64
    bw = pyBigWig.open(bw_file)
    try:
        chrom = region.chrom
        chromosome_length = bw.chroms(chrom)
        if chromosome_length is None:
            raise ValueError(f"chromosome {chrom!r} not found in {bw_file}")
        interval = Interval(chrom, region.start, region.end).resize(region_len)
        # Only the part of the window lying on the chromosome can be queried.
        clipped_start = max(interval.start, 0)
        clipped_end = min(interval.end, chromosome_length)
        signal = np.asarray(bw.values(chrom, clipped_start, clipped_end), dtype=np.float32)
    finally:
        # Release the file handle even when the lookup above fails.
        bw.close()

    # Zero-fill whatever the window overhangs on either chromosome end.
    pad_upstream = np.zeros(max(-interval.start, 0), dtype=np.float32)
    pad_downstream = np.zeros(max(interval.end - chromosome_length, 0), dtype=np.float32)
    padded = np.concatenate([pad_upstream, signal, pad_downstream])

    # NOTE(review): the mean is taken before NaNs are removed, so a bin that
    # contains any NaN becomes 0 as a whole. Kept as-is to preserve the values
    # this function has always produced - confirm this is intended.
    averages = np.mean(padded.reshape(-1, bin_width), axis=1)

    # `nan` must be passed by keyword: the second positional argument of
    # np.nan_to_num is `copy`, not the replacement value.
    target = np.nan_to_num(averages[np.newaxis, :], nan=0.0)
    return np.log(target + 1)




# ---- Model: load the trained checkpoint and switch to evaluation mode ----
device = 'cuda:0'
checkpoint = '/data/slurm/hejl/riboseq/results_DNA/bigmodel/bigmodel_h512_l12_lr1e-5/models/epoch=38-step=746889.ckpt'
model = TrainModule.load_from_checkpoint(checkpoint).to(device)
model = model.eval()

# ---- Input files ----
#GTF_FILE = '/data/slurm/leixiong/m6A_prediction/data/gencode.v42.chr_patch_hapl_scaff.annotation.gtf.gz'
SEQUENCE_LENGTH = 65536
fasta_file = '/data/slurm/hejl/riboseq/data/hg38/hg38.fa'#str(DATA_PATH / 'hg38.fa')
fasta_extractor = FastaStringExtractor(fasta_file)
gff_file = '/data/slurm/hejl/riboseq/gencode.v43.annotation.gff3'
#rna_bw_file = '/data/slurm/hejl/riboseq/data/hg38/K562/GSE153597/input_features/rnaseq.bw'
# Epithelial supplies both the RNA-seq input and the Ribo-seq ground truth;
# t1/t2/t3 are RNA-seq tracks from other cell types fed to the same model.
rna_bw_file = '/data/slurm/hejl/riboseq/data/hg38/Epithelial/GSE200097/input_features/rnaseq.bw'
ribo_bw_file = '/data/slurm/hejl/riboseq/data/hg38/Epithelial/GSE200097/output_features/riboseq.bw'
rna_bw_file_t1 = '/data/slurm/hejl/riboseq/data/hg38/Prostate/GSE130465/input_features/rnaseq.bw'
rna_bw_file_t2 = '/data/slurm/hejl/riboseq/data/hg38/hTERT-RPE1/GSE138533/input_features/rnaseq.bw'
rna_bw_file_t3 = '/data/slurm/hejl/riboseq/data/hg38/RD/GSE103308/input_features/rnaseq.bw'

#rna_bw_file = '/data/slurm/hejl/riboseq/data/hg38/mean.sorted.bw'
#rna_bw_file = '/data/slurm/hejl/riboseq/Translatomer/data/hg38/mean.sorted.bw'

gene_name = 'KLF6'
#gene_name = 'KRT19'

# ---- Locate the gene and build the window centred on it ----
gff_df = pd.read_csv(gff_file, sep='\t', comment='#', header=None)
gene_info = get_cds_interval(gff_df, gene_name)
if gene_info is None:
    # get_cds_interval returns None when the gene has no transcript row;
    # fail here with a clear message instead of a TypeError further down.
    raise ValueError(f"No transcript found for gene {gene_name!r} in {gff_file}")

gene_interval = kipoiseq.Interval(gene_info['chrom'], gene_info['gene_start'], gene_info['gene_end'])
region_interval = gene_interval
target_interval = Interval(region_interval.chrom, region_interval.start, region_interval.end).resize(SEQUENCE_LENGTH)
ref_seq = fasta_extractor.extract(target_interval)
ref_emb = torch.Tensor(one_hot_encode(ref_seq)).to(device)


def _predict_with_rna(rna_bw_path):
    """Run the model on the reference sequence paired with one RNA-seq track.

    Returns the RNA-seq column tensor, the batched model input and the
    prediction for the first (only) batch element as a numpy array.
    """
    rna_track = torch.Tensor(generate_inputs(region_interval, fasta_file, rna_bw_path)[0]).unsqueeze(1).to(device)
    model_input = torch.cat([ref_emb, rna_track], dim = 1).unsqueeze(0)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        prediction = model(model_input)[0].detach().cpu().numpy()
    return rna_track, model_input, prediction


# ---- Predictions: same sequence, RNA-seq input swapped per cell type ----
epi, reference_input, pred = _predict_with_rna(rna_bw_file)
epi_t1, reference_input_t1, pred_t1 = _predict_with_rna(rna_bw_file_t1)
epi_t2, reference_input_t2, pred_t2 = _predict_with_rna(rna_bw_file_t2)
epi_t3, reference_input_t3, pred_t3 = _predict_with_rna(rna_bw_file_t3)

# generate_outputs already returns a float32 numpy array; row 0 is the track.
gt = generate_outputs(region_interval, fasta_file, ribo_bw_file)[0]

# ---- Plot and score every prediction against the Epithelial ground truth ----
tracks = {'Ground Truth(Epithelial)': gt,
          'Prediction(Epithelial)': pred,
          'Prediction(hTERT-RPE1)': pred_t2,
          'Prediction(Prostate)': pred_t1,
          'Prediction(RD)': pred_t3,
          }

plot_tracks(tracks, target_interval, height=1) #

correlation, p_value = pearsonr(pred,gt)
correlation1, p_value1 = pearsonr(pred_t1,gt)
correlation2, p_value2 = pearsonr(pred_t2,gt)
correlation3, p_value3 = pearsonr(pred_t3,gt)
print('Pearson(Epithelial): ')
print(correlation)
print('Pearson(hTERT-RPE1): ')
print(correlation2)
print('Pearson(Prostate): ')
print(correlation1)
print('Pearson(RD): ')
print(correlation3)

In [None]:
def plot_tracks(tracks, interval, height=1.5,
                xlabel="chr10:3745339-3810875",
                ylim=(0, 5),
                save_path='/data/slurm/hejl/riboseq/Translatomer/results/Fig4/KLF6.pdf'):
    """Draw each track in its own row with a common y-range and save the figure.

    The defaults of ``xlabel``, ``ylim`` and ``save_path`` reproduce the
    values this function previously had hard-coded, so existing calls behave
    the same.

    Args:
        tracks: Mapping ``title -> 1-D sequence of values``; one subplot is
            created per entry, in iteration order.
        interval: Object with ``start`` and ``end`` attributes; the values of
            each track are spread evenly over ``[start, end]``.
        height: Height of a single subplot row (figure width is fixed at 20).
        xlabel: Label of the x-axis of the bottom subplot.
        ylim: ``(bottom, top)`` y-limits applied to every subplot.
        save_path: File the figure is written to; ``None`` skips saving.

    Raises:
        ValueError: If ``tracks`` is empty.
    """
    if not tracks:
        raise ValueError("tracks must contain at least one entry")
    # squeeze=False always yields a 2-D array of axes, so a single track works
    # too (the default returns a bare Axes object that cannot be zipped over).
    fig, axes_grid = plt.subplots(len(tracks), 1, figsize=(20, height * len(tracks)),
                                  sharex=True, squeeze=False)
    axes = axes_grid[:, 0]
    for ax, (title, y) in zip(axes, tracks.items()):
        ax.fill_between(np.linspace(interval.start, interval.end, num=len(y)), y)
        ax.set_title(title)
        ax.set_ylim(*ylim)
    # Applies to every axes of the figure, so one call after the loop suffices.
    sns.despine(fig=fig, top=True, right=True, bottom=True)
    axes[-1].set_xlabel(xlabel)
    plt.tight_layout()
    if save_path is not None:
        plt.savefig(save_path)
plot_tracks(tracks, target_interval, height=1)