In [None]:
# libraries
import torch
from transformers import XLNetConfig, XLNetForTokenClassification
import pytorch_lightning as pl
import itertools
import hyperopt
import pygad
import matplotlib.pyplot as plt
# suppress warnings
import warnings
warnings.filterwarnings("ignore")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# reproducibility: seed the global RNGs once, through Lightning, before any
# random work (model init, GA population) happens further down
seed_val = 4
pl.seed_everything(seed=seed_val)

# integer token ids for the experimental conditions; they come right after the
# 64 codon ids (0-63)
condition_values = {
    name: token
    for token, name in enumerate(['CTRL', 'ILE', 'LEU', 'LEU_ILE', 'LEU_ILE_VAL', 'VAL'], start=64)
}
# reverse lookup: condition token id -> condition name
inverse_condition_values = {token: name for name, token in condition_values.items()}

# standard genetic code: DNA codon -> one-letter amino acid ('_' = stop codon)
codon_to_aa = {
    'ATA': 'I', 'ATC': 'I', 'ATT': 'I', 'ATG': 'M',
    'ACA': 'T', 'ACC': 'T', 'ACG': 'T', 'ACT': 'T',
    'AAC': 'N', 'AAT': 'N', 'AAA': 'K', 'AAG': 'K',
    'AGC': 'S', 'AGT': 'S', 'AGA': 'R', 'AGG': 'R',
    'CTA': 'L', 'CTC': 'L', 'CTG': 'L', 'CTT': 'L',
    'CCA': 'P', 'CCC': 'P', 'CCG': 'P', 'CCT': 'P',
    'CAC': 'H', 'CAT': 'H', 'CAA': 'Q', 'CAG': 'Q',
    'CGA': 'R', 'CGC': 'R', 'CGG': 'R', 'CGT': 'R',
    'GTA': 'V', 'GTC': 'V', 'GTG': 'V', 'GTT': 'V',
    'GCA': 'A', 'GCC': 'A', 'GCG': 'A', 'GCT': 'A',
    'GAC': 'D', 'GAT': 'D', 'GAA': 'E', 'GAG': 'E',
    'GGA': 'G', 'GGC': 'G', 'GGG': 'G', 'GGT': 'G',
    'TCA': 'S', 'TCC': 'S', 'TCG': 'S', 'TCT': 'S',
    'TTC': 'F', 'TTT': 'F', 'TTA': 'L', 'TTG': 'L',
    'TAC': 'Y', 'TAT': 'Y', 'TAA': '_', 'TAG': '_',
    'TGC': 'C', 'TGT': 'C', 'TGA': '_', 'TGG': 'W',
}

# reverse table: amino acid -> synonymous codons. Each list keeps the order in
# which its codons appear in the table above; that order later becomes the
# per-position gene space handed to the GA, so keep the table order stable.
aa_to_codon = {}
for codon, aa in codon_to_aa.items():
    aa_to_codon.setdefault(aa, []).append(codon)

In [None]:
# model parameters (the same values are spelled out inside `model_name` below)
annot_thresh, longZerosThresh_val, percNansThresh_val = 0.3, 20, 0.05
d_model_val, n_layers_val, n_heads_val = 256, 3, 8
dropout_val = 0.1
lr_val = 1e-4
batch_size_val = 1
loss_fun_name = '5L'

# name of the trained run; its checkpoint is loaded from saved_models/<model_name>/best_model
model_name = 'PLabelXLNetDHConds DS: DeprNA [3, 256, 8] FT: [PEL] BS: 1 Loss: 5L Data Conds: [NZ: 20, PNTh: 0.05, AnnotThresh: 0.3] Seed: 4[Exp: set1, PLQ1: 1, PLQ2: None, ImpDen: impute] Noisy: False'
output_loc = f'saved_models/{model_name}'

# condition token id -> condition name
condition_dict_values = dict(zip(range(64, 70), ['CTRL', 'ILE', 'LEU', 'LEU_ILE', 'LEU_ILE_VAL', 'VAL']))

class XLNetDH(XLNetForTokenClassification):
    """XLNet token classifier whose classification layer has two outputs per token.

    Identical to ``XLNetForTokenClassification`` except that the classifier
    built by the parent is replaced with ``Linear(config.d_model, 2)``.
    """

    def __init__(self, config):
        super().__init__(config)
        # Size the head from the config being built, not from the notebook-level
        # ``d_model_val``: ``from_pretrained`` constructs this class from the
        # checkpoint's own config, so the input width must follow that config.
        self.classifier = torch.nn.Linear(config.d_model, 2, bias=True)

# vocabulary: 64 codon ids (0-63) + 6 condition ids (64-69) + 1 padding id (70) = 71
config = XLNetConfig(vocab_size=71, pad_token_id=70, d_model=d_model_val, n_layer=n_layers_val, n_head=n_heads_val, d_inner=d_model_val, num_labels=1, dropout=dropout_val)

# NOTE: this randomly initialised instance is replaced by the from_pretrained
# call below. It is kept on purpose: its weight initialisation advances the
# seeded torch RNG, and the GA's initial population (torch.randint, later cell)
# draws from that RNG, so removing it would change the generated sequences.
model = XLNetDH(config)

# load the trained weights; from_pretrained is a classmethod, so call it on the class
model = XLNetDH.from_pretrained(output_loc + "/best_model")
model.to(device)
# inference only: eval mode disables dropout, and freezing the parameters stops
# every forward pass in this notebook from building an autograd graph
model.eval()
model.requires_grad_(False);

In [None]:

def getDensity(codon_list, condition_token=64):
    """Predict the summed model output for a batch of codon-id sequences.

    Parameters
    ----------
    codon_list : iterable of sequences of int
        Codon ids (0-63), one sequence per entry. All sequences must have the
        same length so they can be stacked into a single tensor; rows of a 2-D
        numpy array (as handed over by the GA fitness function) also work.
    condition_token : int, default 64
        Token prepended to every sequence. 64 is the 'CTRL' condition id.

    Returns
    -------
    dict
        ``'loss'``   – 1-D tensor with one summed prediction per sequence,
        ``'input'``  – the ``(batch, 1 + length)`` long tensor fed to the model,
        ``'status'`` – ``hyperopt.STATUS_OK``, so the dict doubles as a hyperopt result.
    """
    # prepend the condition token to each sequence
    batch = [[condition_token, *seq] for seq in codon_list]
    inputs = torch.tensor(batch, dtype=torch.long).to(device)

    # scoring only: no gradients are needed, so skip autograd bookkeeping
    with torch.no_grad():
        logits = model(inputs)["logits"]

    # drop the condition-token position, sum over the classifier outputs,
    # then sum over all codon positions
    per_position = torch.sum(logits[:, 1:, :], dim=2)
    sum_value = torch.sum(per_position, dim=1)

    return {'loss': sum_value, 'input': inputs, 'status': hyperopt.STATUS_OK}

In [None]:
# original sequence which you wish to mutate (provide this without the stop codon)
# NOTE: it is read codon by codon (base triplets) in the next cell, so its
# length should be a whole number of codons (a multiple of 3)
orig_sequence = 'ATGGAAGACGCCAAAAACATAAAGAAAGGCCCGGCGCCATTCTATCCTCTAGAGGATGGAACCGCTGGAGAGCAACTGCATAAGGCTATGAAGAGATACGCCCTGGTTCCTGGAACAATTGCTTTTACAGATGCACATATCGAGGTGAACATCACGTACGCGGAATACTTCGAAATGTCCGTTCGGTTGGCAGAAGCTATGAAACGATATGGGCTGAATACAAATCACAGAATCGTCGTATGCAGTGAAAACTCTCTTCAATTCTTTATGCCGGTGTTGGGCGCGTTATTTATCGGAGTTGCAGTTGCGCCCGCGAACGACATTTATAATGAACGTGAATTGCTCAACAGTATGAACATTTCGCAGCCTACCGTAGTGTTTGTTTCCAAAAAGGGGTTGCAAAAAATTTTGAACGTGCAAAAAAAATTACCAATAATCCAGAAAATTATTATCATGGATTCTAAAACGGATTACCAGGGATTTCAGTCGATGTACACGTTCGTCACATCTCATCTACCTCCCGGTTTTAATGAATACGATTTTGTACCAGAGTCCTTTGATCGTGACAAAACAATTGCACTGATAATGAATTCCTCTGGATCTACTGGGTTACCTAAGGGTGTGGCCCTTCCGCATAGAACTGCCTGCGTCAGATTCTCGCATGCCAGAGATCCTATTTTTGGCAATCAAATCATTCCGGATACTGCGATTTTAAGTGTTGTTCCATTCCATCACGGTTTTGGAATGTTTACTACACTCGGATATTTGATATGTGGATTTCGAGTCGTCTTAATGTATAGATTTGAAGAAGAGCTGTTTTTACGATCCCTTCAGGATTACAAAATTCAAAGTGCGTTGCTAGTACCAACCCTATTTTCATTCTTCGCCAAAAGCACTCTGATTGACAAATACGATTTATCTAATTTACACGAAATTGCTTCTGGGGGCGCACCTCTTTCGAAAGAAGTCGGGGAAGCGGTTGCAAAACGCTTCCATCTTCCAGGGATACGACAAGGATATGGGCTCACTGAGACTACATCAGCTATTCTGATTACACCCGAGGGGGATGATAAACCGGGCGCGGTCGGTAAAGTTGTTCCATTTTTTGAAGCGAAGGTTGTGGATCTGGATACCGGGAAAACGCTGGGCGTTAATCAGAGAGGCGAATTATGTGTCAGAGGACCTATGATTATGTCCGGTTATGTAAACAATCCGGAAGCGACCAACGCCTTGATTGACAAGGATGGATGGCTACATTCTGGAGACATAGCTTACTGGGACGAAGACGAACACTTCTTCATAGTTGACCGCTTGAAGTCTTTAATTAAATACAAAGGATATCAGGTGGCCCCCGCTGAATTGGAATCGATATTGTTACAACACCCCAACATCTTCGACGCGGGCGTGGCAGGTCTTCCCGACGATGACGCCGGTGAACTTCCCGCCGCCGTTGTTGTTTTGGAGCACGGAAAGACGATGACGGAAAAAGAGATCGTGGATTACGTCGCCAGTCAAGTAACAACCGCGAAAAAGTTGCGCGGAGGAGTTGTGTTTGTGGACGAAGTACCGAAAGGTCTTACCGGAAAACTCGACGCAAGAAAAATCAGAGAGATCCTCATAAAGGCCAAGAAGGGCGGAAAGTCCAAATTG' 
# stop codon appended to every optimised sequence when it is printed / written to FASTA
stop_codon = 'TAA'
# used in the name of the output FASTA file
gene_name = 'luciferase'
# number of GA generations (also part of the FASTA file name)
num_gen = 20
# how many of the best sequences to print and write to the FASTA file
num_sequence = 10
# evaluation budget for the (currently commented-out) hyperopt search
hyperopt_max_evals = 50000

In [None]:
# normalise the user inputs: drop whitespace/newlines that often come along when
# a sequence is pasted, and uppercase everything
orig_sequence = ''.join(orig_sequence.split()).upper()
stop_codon = stop_codon.strip().upper()

# integer ids for the 64 codons (0-63).
# NOTE(review): this enumeration order is presumably the one the model was
# trained with — do not change it.
id_to_codon = {idx: ''.join(el) for idx, el in enumerate(itertools.product(['A', 'T', 'C', 'G'], repeat=3))}
codon_to_id = {v: k for k, v in id_to_codon.items()}

# validate before parsing: an empty input, a trailing partial codon or a
# non-ACGT character means this is not a clean coding sequence, so fail loudly
# instead of silently dropping bases or raising an opaque KeyError
if not orig_sequence:
    raise ValueError("orig_sequence is empty")
if len(orig_sequence) % 3 != 0:
    raise ValueError(
        f"orig_sequence has {len(orig_sequence)} bases, which is not a whole number of codons"
    )
_invalid_bases = set(orig_sequence) - set('ACGT')
if _invalid_bases:
    raise ValueError(f"orig_sequence contains non-ACGT characters: {sorted(_invalid_bases)}")

# convert to codon ids, one per consecutive base triplet
orig_codon_list = [codon_to_id[orig_sequence[i:i + 3]] for i in range(0, len(orig_sequence), 3)]

# translate to the amino-acid sequence ('_' marks a stop codon)
orig_aa_sequence = ''.join(codon_to_aa[id_to_codon[codon]] for codon in orig_codon_list)


In [None]:
# baseline: score the unmodified sequence so optimised ones can be compared to it
orig_pred_densitySum = getDensity([orig_codon_list])
orig_density_value = orig_pred_densitySum['loss'][0].item()
print("Original Sequence Density: ", orig_density_value)

In [None]:
# search space: for each amino-acid position, the ids of every synonymous codon
# (this becomes the per-gene space of the GA; list order follows aa_to_codon)
possibilities_codons = [
    [codon_to_id[syn_codon] for syn_codon in aa_to_codon[residue]]
    for residue in orig_aa_sequence
]

# Genetic Algorithm Based Gene Optimization

In [None]:
# initial GA population: the original sequence plus 19 random synonymous variants.
# Each variant picks, position by position, one codon uniformly at random from
# that position's synonymous choices (torch RNG, seeded earlier).
initial_population = [orig_codon_list] + [
    [choices[torch.randint(0, len(choices), (1,)).item()] for choices in possibilities_codons]
    for _ in range(19)
]

# genetic algorithm based optimization
def fitness_func(ga_instance, solutions, solution_idx):
    """Batch fitness for pygad (which maximises fitness): lower density => fitter.

    Receives a batch of candidate codon-id sequences and returns one float per
    sequence, equal to the negated predicted density sum.
    """
    densities = getDensity(solutions)['loss']
    # Negate rather than invert: -density is monotone for every real density,
    # whereas 1/density divides by zero at 0 and reverses the ranking when the
    # predicted densities change sign. For all-positive densities the rank
    # order of the population is the same as with 1/density.
    return (-densities).tolist()

last_fitness = 0
def on_generation(ga_instance):
    """pygad per-generation callback: print the best fitness and its change.

    The change is measured against the best fitness seen at the previous call,
    kept in the module-level ``last_fitness``.
    """
    global last_fitness
    best_fitness = ga_instance.best_solution(pop_fitness=ga_instance.last_generation_fitness)[1]
    print(f"Generation = {ga_instance.generations_completed}")
    print(f"Fitness    = {best_fitness}")
    print(f"Change     = {best_fitness - last_fitness}")
    last_fitness = best_fitness

# create an instance of the pygad.GA class
# allow_duplicate_genes must be True here: a coding sequence necessarily reuses
# the same codon id at many positions, and with False pygad tries to make every
# gene value within a solution unique, rewriting repeated genes instead of
# searching freely (NOTE(review): pygad may also apply this to the supplied
# initial population; its warnings are hidden by the global warning filter).
# Population and batch sizes follow the prepared initial population.
ga_instance = pygad.GA(num_generations=num_gen,
                       num_parents_mating=10,
                       fitness_func=fitness_func,
                       sol_per_pop=len(initial_population),
                       num_genes=len(orig_aa_sequence),
                       initial_population=initial_population,
                       gene_type=int,
                       gene_space=possibilities_codons,
                       allow_duplicate_genes=True,
                       fitness_batch_size=len(initial_population),
                       save_solutions=True,
                       on_generation=on_generation)

# run the genetic algorithm
ga_instance.run()

In [None]:
# best solution as reported by pygad's best_solution() after the run
solution = ga_instance.best_solution()
best_codon_ids = solution[0]
best_dna = ''.join(id_to_codon[best_codon_ids[pos]] for pos in range(len(orig_aa_sequence))) + stop_codon
best_density = getDensity([best_codon_ids])['loss'].item()
print("Genetic Algorithm Best Solution: ", best_dna, " \nDensity: ", best_density)

In [None]:
def _per_position_density(codon_ids, condition_token=64):
    """Per-codon model output for one sequence, summed over the classifier outputs.

    Prepends ``condition_token`` (64 = 'CTRL'), runs the model without autograd
    and drops the condition-token position, returning a 1-D numpy array with
    one value per codon.
    """
    tokens = torch.tensor([condition_token, *(int(c) for c in codon_ids)], dtype=torch.long)
    with torch.no_grad():
        logits = model(tokens.unsqueeze(0).to(device))["logits"]
    return torch.sum(torch.squeeze(logits, dim=0), dim=1)[1:].cpu().numpy()


# make a plot with the distribution on the orig sequence and the best sequence
orig_pred = _per_position_density(orig_codon_list)
best_pred = _per_position_density(solution[0])

fig, ax = plt.subplots()
ax.plot(orig_pred, label='Original Sequence')
ax.plot(best_pred, label='Best Sequence')
ax.legend()
ax.set(title='GA - Density Distribution', xlabel='Position', ylabel='Density')
plt.show()

In [None]:
# rank the final GA population by predicted density (lowest first), keep only
# distinct sequences, then print and export the best `num_sequence` of them

def _codons_to_dna(codon_ids):
    """Join codon ids into a DNA string and append the stop codon."""
    return ''.join(id_to_codon[int(c)] for c in codon_ids) + stop_codon


solutions = ga_instance.population
# one batched forward pass over the whole population (same batch size as the fitness)
solutions_density = getDensity(solutions)['loss'].tolist()

# sort by density only (solutions are arrays and must not be compared), then
# drop repeated sequences: elitism/parent retention can leave copies in the population
ranked = sorted(zip(solutions_density, solutions), key=lambda pair: pair[0])
seen_sequences = set()
solutions, solutions_density = [], []
for density, sol in ranked:
    key = tuple(int(c) for c in sol)
    if key in seen_sequences:
        continue
    seen_sequences.add(key)
    solutions.append(sol)
    solutions_density.append(density)

# never index past the available (distinct) solutions
n_report = min(num_sequence, len(solutions))

print("Top Solutions: ")
for i in range(n_report):
    print("Sequence: ", _codons_to_dna(solutions[i]), " \nDensity: ", solutions_density[i])

# make a fasta file out of this; `with` closes it even if a write fails
fasta_path = 'best_solutionsGA_' + gene_name + '_numgen' + str(num_gen) + '.fasta'
with open(fasta_path, 'w') as f:
    for i in range(n_report):
        f.write(">Seq" + str(i + 1) + " density: " + str(solutions_density[i]) + "\n")
        f.write(_codons_to_dna(solutions[i]) + "\n")

# Hyperopt Based Gene Optimization

In [None]:
# # manipulate orig_aa_sequence so that the codon_list has the least densitySum

# # sequence of codons with num_codon variables
# optimized_sequence = []
# for i in range(len(orig_aa_sequence)):
#     optimized_sequence.append(hyperopt.hp.choice(f"codon_{i}", possibilities_codons[i]))

# # define the objective function
# def objective(codon_list):
#     pred_densitySum = getDensity([codon_list])
#     return pred_densitySum

# # run the hyperparameter search
# trials = hyperopt.Trials()
# best = hyperopt.fmin(objective, optimized_sequence, algo=hyperopt.tpe.suggest, max_evals=hyperopt_max_evals, trials=trials)

# # print best sequence
# best_sequence = [possibilities_codons[i][best[f"codon_{i}"]] for i in range(len(orig_aa_sequence))]
# print("Best Sequence: ", best_sequence, " \n Density: ", objective(best_sequence)['loss'])


In [None]:
# # # best sequence
# trials = sorted(trials.results, key=lambda x: x['loss'])
# print("Best Sequence: ", ''.join([id_to_codon[trials[0]['input'][0][j].item()] for j in range(1, len(orig_aa_sequence) + 1)]) + stop_codon, " \nDensity: ", trials[0]['loss'])


In [None]:
# # make a plot with the distribution on the orig sequence and the best sequence
# orig_pred = model(torch.tensor([64] + orig_codon_list).long().unsqueeze(0).to(device))
# orig_pred = torch.sum(torch.squeeze(orig_pred["logits"], dim=0), dim=1)[1:].detach().cpu().numpy()

# best_pred = model(torch.tensor(trials[0]['input'][0].tolist()).long().unsqueeze(0).to(device))
# best_pred = torch.sum(torch.squeeze(best_pred["logits"], dim=0), dim=1)[1:].detach().cpu().numpy()

# # plot the distributions
# plt.figure()
# plt.plot(orig_pred, label='Original Sequence')
# plt.plot(best_pred, label='Best Sequence')
# plt.legend()
# plt.title('Hyperopt - Density Distribution')
# plt.xlabel('Position')
# plt.ylabel('Density')
# plt.show()

In [None]:
# print("Top Trials: ")
# # print the top num_sequence trials with sequence and loss
# for i in range(num_sequence):
#     print("Sequence: ", ''.join([id_to_codon[trials[i]['input'][0][j].item()] for j in range(1, len(orig_aa_sequence) + 1)]) + stop_codon, " \nDensity: ", trials[i]['loss'])

# # make a fasta file out of this
# f = open('best_solutionsHopt_' + gene_name + '_evals' + str(hyperopt_max_evals) + '.fasta', 'w')

# for i in range(num_sequence):
#     f.write(">Seq" + str(i+1) + " density: " + str(trials[i]['loss']) + "\n")
#     f.write(''.join([id_to_codon[trials[i]['input'][0][j].item()] for j in range(1, len(orig_aa_sequence) + 1)]) + stop_codon + "\n")

# f.close()