In [1]:
from keras.models import load_model
import joblib
import numpy as np
from keras.utils import to_categorical




In [2]:
import sys
sys.path.append('../src/')
import vcf2onehot


In [3]:
def label2onehot(label: np.ndarray) -> np.ndarray:
    """Bin continuous scores into discrete classes and one-hot encode them.

    Scores are bucketed with ``np.digitize`` against the thresholds
    0.25, 0.75, 1.25 and 1.75, which yields class indices 0 through 4.

    Args:
        label: Array-like of numeric scores.

    Returns:
        The one-hot array produced by ``to_categorical``; its last axis
        always has ``len(thresholds) + 1`` (5) entries.
    """
    thresholds = [0.25, 0.75, 1.25, 1.75]
    # np.digitize returns indices in 0..len(thresholds), i.e. one more class
    # than there are thresholds.
    num_classes = len(thresholds) + 1

    categorical_labels = np.digitize(label, thresholds)

    # Pin the width explicitly: without num_classes, to_categorical infers it
    # as max(index) + 1, so a batch that lacks the top class would come back
    # with fewer than 5 columns.
    return to_categorical(categorical_labels, num_classes=num_classes)

In [30]:
def get_model(model_path: str = '../save_model/FinalModel/final_25-03-2024_02-32/model.h5'):
    """Load a saved Keras model from disk.

    Args:
        model_path: Path of the saved model file. Defaults to the path that
            was previously hard-coded here. A checkpoint such as
            '../save_model/ModelCheckPoint/final_25-03-2024_02-32/model.088-0.0545-0.9746.h5'
            can be passed instead.

    Returns:
        The object returned by ``load_model(model_path)``.
    """
    return load_model(model_path)

In [31]:
def get_data(vcf_file=None, seq_file=None, joblib_file=None):
    """Load model inputs from one of three possible sources.

    The sources are checked in the order ``joblib_file``, ``seq_file``,
    ``vcf_file``; the first one that is not ``None`` is used and the others
    are ignored.

    Args:
        vcf_file: Path to a VCF file, passed to ``vcf2onehot.build_seqs``.
        seq_file: Path to a whitespace-delimited text file with one sample
            per line: ``<sample_id> <hap1> <hap2>``, where each haplotype is
            a comma-separated list of integers. Blank lines are skipped.
        joblib_file: Path to a joblib dump holding a mapping with the keys
            ``'X'`` and ``'activate_score'``.

    Returns:
        * ``joblib_file``: a tuple ``(X, y)`` where ``y`` is
          ``label2onehot(data['activate_score'])``, or ``(None, None)`` when
          the file does not exist.
        * ``seq_file`` / ``vcf_file``: the ``'X'`` entry of
          ``vcf2onehot.format_seqs(...)``, or ``None`` when the file does
          not exist.
        * no source given: ``None``.
    """
    if joblib_file is not None:
        try:
            # NOTE(security): joblib.load unpickles arbitrary objects, so it
            # must only be pointed at trusted files.
            data = joblib.load(joblib_file)
            X = data['X']
            y = label2onehot(data['activate_score'])
            return X, y
        except FileNotFoundError:
            print(f"File {joblib_file} not found.")
            return None, None

    if seq_file is not None:
        try:
            sample_seq = {}
            with open(seq_file, 'r') as f:
                for line in f:
                    fields = line.split()
                    if not fields:
                        # A blank line (e.g. a trailing newline at the end of
                        # the file) carries no sample; skip it instead of
                        # failing on fields[1].
                        continue
                    seqs_hap1 = [int(x) for x in fields[1].split(',')]
                    seqs_hap2 = [int(x) for x in fields[2].split(',')]

                    sample_seq[fields[0]] = [seqs_hap1, seqs_hap2]
            print(sample_seq.keys())
            return vcf2onehot.format_seqs(sample_seq)['X']
        except FileNotFoundError:
            print(f"File {seq_file} not found.")
            return None

    if vcf_file is not None:
        try:
            seqs = vcf2onehot.build_seqs(vcf_file)
            return vcf2onehot.format_seqs(seqs)['X']
        except FileNotFoundError:
            # Report the VCF path here; the original message printed seq_file,
            # which is None on this branch.
            print(f"File {vcf_file} not found.")
            return None

    # No source was supplied.
    return None


In [32]:
def get_score(predictions: np.array):
    """Translate predicted class indices into their score values.

    Args:
        predictions: Iterable of class indices; each must be one of 0-4.

    Returns:
        A list with one score per prediction, using the mapping
        0 -> 0, 1 -> 0.5, 2 -> 1, 3 -> 1.5, 4 -> 2.

    Raises:
        KeyError: If a prediction is not one of the five known classes.
    """
    class_to_score = {0: 0, 1: 0.5, 2: 1, 3: 1.5, 4: 2}

    # Look every prediction up in the table, preserving input order.
    return list(map(class_to_score.__getitem__, predictions))

In [71]:
# Load model inputs from the .seq file. On the seq_file path get_data returns
# only X (no labels) and prints the sample IDs it parsed.
X = get_data(seq_file='../data/PRJEB19931.seq')

dict_keys(['HG00276', 'HG00436', 'HG00589', 'HG01190', 'NA06991', 'NA07000', 'NA07019', 'NA07029', 'NA07055', 'NA07056', 'NA07348', 'NA07357', 'NA10831', 'NA10847', 'NA10851', 'NA10854', 'NA11832', 'NA11839', 'NA11993', 'NA12003', 'NA12006', 'NA12145', 'NA12156', 'NA12717', 'NA12813', 'NA12873', 'NA18484', 'NA18509', 'NA18518', 'NA18519', 'NA18524', 'NA18526', 'NA18540', 'NA18544', 'NA18552', 'NA18564', 'NA18565', 'NA18617', 'NA18855', 'NA18861', 'NA18868', 'NA18942', 'NA18952', 'NA18959', 'NA18966', 'NA18973', 'NA18980', 'NA18992', 'NA19003', 'NA19007', 'NA19095', 'NA19109', 'NA19122', 'NA19143', 'NA19147', 'NA19174', 'NA19176', 'NA19178', 'NA19207', 'NA19213', 'NA19226', 'NA19239', 'NA19789', 'NA19819', 'NA19908', 'NA19917', 'NA19920', 'NA20296', 'NA20509', 'NA21781'])


In [72]:
# Check the dimensions of the loaded input array before passing it to the model.
X.shape

(70, 14868, 13)

In [73]:
# Load the saved Keras model from the path configured in get_model().
model = get_model()

In [74]:
# Show the layers, output shapes and parameter counts of the loaded model.
model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv1d_1 (Conv1D)           (None, 2970, 70)          17360     
                                                                 
 batch_1 (BatchNormalizatio  (None, 2970, 70)          280       
 n)                                                              
                                                                 
 relu_1 (ReLU)               (None, 2970, 70)          0         
                                                                 
 maxpooling_1 (MaxPooling1D  (None, 990, 70)           0         
 )                                                               
                                                                 
 conv1d_2 (Conv1D)           (None, 196, 46)           35466     
                                                                 
 batch_2 (BatchNormalizatio  (None, 196, 46)           1

In [47]:
# Load a labelled batch. On the joblib_file path get_data returns (X, y),
# with y one-hot encoded by label2onehot from the stored 'activate_score'.
X_test, y_test = get_data(joblib_file='../data/input_data/batch_126.joblib')

In [23]:
# Evaluate the model on the labelled batch. The two values in the output are
# presumably [loss, accuracy] — TODO confirm against the model's compiled metrics.
# NOTE(review): this cell's execution count (23) is lower than the counts of
# the cells that define `model` (73) and `X_test` (47); re-run the notebook
# top to bottom to confirm the result is reproducible.
model.evaluate(X_test, y_test)




[0.03168383985757828, 0.9940000176429749]

In [75]:
# Run inference on the inputs loaded from the .seq file.
pred = model.predict(X)



In [76]:
# Display the raw model outputs (one row of per-class values per sample).
pred

array([[3.23991477e-02, 1.01728465e-05, 9.66758251e-01, 3.96954747e-05,
        7.92718900e-04],
       [2.63585425e-05, 6.97078349e-06, 9.99965906e-01, 6.33612501e-07,
        8.36593728e-08],
       [5.07644882e-06, 1.78142334e-04, 9.99796093e-01, 2.02863030e-05,
        3.86227299e-07],
       [1.85539350e-02, 4.15298418e-05, 9.81260300e-01, 2.10658163e-05,
        1.23206439e-04],
       [4.50947955e-02, 3.96103802e-04, 9.54441905e-01, 6.53739262e-05,
        1.84414023e-06],
       [3.16594414e-05, 7.16761127e-03, 9.92539227e-01, 2.56726431e-04,
        4.83209260e-06],
       [1.91778727e-02, 1.64266512e-05, 9.80794728e-01, 7.40912947e-06,
        3.52272264e-06],
       [2.63512247e-05, 5.47488825e-03, 9.94230866e-01, 2.67089141e-04,
        8.32010301e-07],
       [1.99344698e-02, 1.70503365e-04, 9.79056954e-01, 6.01309934e-04,
        2.36722975e-04],
       [7.01119425e-05, 9.86875035e-04, 9.98648584e-01, 2.72147183e-04,
        2.22271483e-05],
       [5.20025313e-01, 3.5382

In [77]:
# For each sample, take the index of the class with the highest model output.
predict_class = np.argmax(pred, axis=1)

In [78]:
# Display the predicted class index of every sample.
predict_class

array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2,
       2, 1, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2,
       2, 2, 2, 2], dtype=int64)

In [80]:
# Convert class indices to scores with get_score (0→0, 1→0.5, 2→1, 3→1.5, 4→2).
get_score(predict_class)

[1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 1,
 1,
 1,
 1,
 0.5,
 1,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 0,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 1]

In [46]:
# Build model inputs directly from a VCF file. On the vcf_file path get_data
# returns only X (no labels).
data = get_data(vcf_file='../data/simulated_cyp2d6_diplotypes/batch_500.vcf')

317 1406


In [47]:
# Predict on the VCF-derived inputs, take the argmax class of each sample and
# map it to a score in one step.
get_score(np.argmax(model.predict(data), axis=1))



[0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0.5,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0.5,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0.5,
 0.5,
 0,
 0,
 0,
 0,
 0.5,
 0,
 0,
 0,
 0,
 0,
 0,
 0.5,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0.5,
 0,
 0,
 0,
 0,
 0.5,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 0,
 0,
 0,
 0.5,
 0,
 0,
 0,
 0,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5