# Google Colab Intro
#### This notebook downloads data and performs computations. These computations are made significantly faster with the use of a GPU, provided free by Google as part of Colab.

#### Press the play button on the left side of each cell to run it. Alternatively, select a cell and press Shift+Enter or Ctrl+Enter to run it.
#### Double-click the top of a cell to inspect and change the code inside. Double-click the right side of the cell to hide the code again.
#### Have fun!

#### (Bug me on GitHub if a significantly faster C++ version would be useful to you.)

# Setup

In [1]:
# @title Install dependencies
!pip install biopython
!pip install torch
print("\nDone")

Collecting biopython
[?25l  Downloading https://files.pythonhosted.org/packages/76/02/8b606c4aa92ff61b5eda71d23b499ab1de57d5e818be33f77b01a6f435a8/biopython-1.78-cp36-cp36m-manylinux1_x86_64.whl (2.3MB)
[K     |████████████████████████████████| 2.3MB 7.6MB/s 
Installing collected packages: biopython
Successfully installed biopython-1.78

Done


In [2]:
# @title Import python packages
import os
import sys
import gzip
import copy
import time
import pandas
import pickle
import numpy as np
from tqdm.auto import tqdm
from Bio import SeqIO
from scipy.special import expit
from scipy.special import logit
from multiprocessing import Pool

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm

# # import networkx as nx
# from scipy.sparse import coo_matrix
# from scipy.sparse import csr_matrix
# from scipy.sparse.csgraph import bellman_ford

print("Done")

Done


In [3]:
# @title Set parameters (double click top of cell to change defaults)
# Print what the program is doing.
verbose = True

# Maximum ORF (open reading frame) overlap length in nucleotides.
max_gene_overlap = 60

# Minimum ORF length in nucleotides.
min_orf_length = 60

# Use kmer prefilter to increase gene sensitivity.
# May not play nice with very high GC genomes.
# protein_kmer_filter = True
protein_kmer_filter = False

# Use mmseqs2 and a gene score cutoff to remove most false positive predictions.
# mmseqs2_gene_filter = True
mmseqs2_gene_filter = False


# Nucleotide to amino acid translation table. 11 for most bacteria/archaea,
# 4 for Mycoplasma/Spiroplasma. Only tables 11 and 4 are implemented by the
# ORF finder.
translation_table = 11
# translation_table = 4

# Maximum number of forward connections in the directed acyclic graph used to
# find a set of coherent genes in each genome.
# Higher values will slow execution time and increase memory usage,
# but may slightly increase performance.
# Recommended range ~30-50
max_forward_connections = 50

# Batch sizes for the temporal convolutional networks that score genes
# (gene_batch_size) and translation initiation sites (TIS_batch_size).
# Small batches and big batches slow down the model. Very big batches may
# crash the GPU.
gene_batch_size = 200
TIS_batch_size = 1000

# Where the pre-trained gene model should be saved (passed to
# torch.hub.set_dir). NOTE(review): some setup cells hard-code
# "/home/salzberg-lab_Balrog_master", i.e. they assume this stays "/home".
model_dir = "/home"

# All following are internal parameters. Change at your own risk.
# NOTE(review): the scoring code in the "Find genes" cell multiplies the GTG
# indicator by weight_TTG and the TTG indicator by weight_GTG — confirm
# whether that swap is intended before retuning these weights.
weight_gene_prob = 0.9746869839852076 
weight_TIS_prob = 0.25380288790532707 
score_threshold = 0.47256101519707244
weight_ATG = 0.84249804151264 
weight_GTG = 0.7083689705744909
weight_TTG = 0.7512400826652517 
unidirectional_penalty_per_base = 3.895921717182765 # 3' 5' overlap
convergent_penalty_per_base = 4.603432608883688 # 3' 3' overlap
divergent_penalty_per_base = 3.3830814940689975 # 5' 5' overlap
k_seengene = 10  # kmer length used by the protein kmer filter
multimer_threshold = 2  # matching kmers needed for an ORF to count as "seen"

print("Done")

Done


In [4]:
# @title Load pre-trained gene and translation initiation site models

# If you're interested in the inner workings of the temporal convolutional
# network, see hubconf.py in the GitHub repo below.

repo = "salzberg-lab/Balrog"
# repo = "salzberg-lab/Balrog:develop"

torch.hub.set_dir(model_dir)
use_gpu = torch.cuda.device_count() > 0
print("GPU detected..." if use_gpu else "No GPU detected, using CPU...")

# geneTCN forces a fresh download of the repo; tisTCN reuses the cached copy.
model = torch.hub.load(repo, "geneTCN", force_reload=True)
model_tis = torch.hub.load(repo, "tisTCN", force_reload=False)
if use_gpu:
    model = model.cuda()
    model_tis = model_tis.cuda()

time.sleep(0.5)
print("\nDone")

GPU detected...


Downloading: "https://github.com/salzberg-lab/Balrog/archive/master.zip" to /home/master.zip
Using cache found in /home/salzberg-lab_Balrog_master



Done


In [5]:
# @title Prepare protein kmer filter
if protein_kmer_filter:
    # decompress kmer filter
    seengene_dir = "/content/kmerfilter/"
    !mkdir {seengene_dir}
    %cd {seengene_dir}
    !tar -xvzf /home/salzberg-lab_Balrog_master/kmer_filter/genexa_10mer_thresh2_minusARF_all.tar.gz
    # !tar -xvzf /home/salzberg-lab_Balrog_develop/kmer_filter/genexa_10mer_thresh2_minusARF_all.tar.gz
    genexa_kmer_path = os.path.join(seengene_dir, "10mer_thresh2_minusARF_all.pkl")

    # load kmer filter
    with open(genexa_kmer_path, "rb") as f:
        aa_kmer_set = pickle.load(f)

print("\nDone")


Done


In [6]:
# @title Prepare mmseqs2

if mmseqs2_gene_filter:
    ![ $(uname -m) = "x86_64" ] && echo "64bit: Yes" || echo "64bit: No"
    !grep -q sse4_1 /proc/cpuinfo && echo "SSE4.1: Yes" || echo "SSE4.1: No"
    !grep -q avx2 /proc/cpuinfo && echo "AVX2: Yes" || echo "AVX2: No"

    # install
    !mkdir /content/mmseqs2
    %cd /content/mmseqs2
    !wget https://mmseqs.com/latest/mmseqs-linux-avx2.tar.gz
    !tar xvzf mmseqs-linux-avx2.tar.gz

    # decompress fasta
    !mkdir /content/mmseqs2/genexa
    %cd /content/mmseqs2/genexa
    !!tar -xvzf /home/salzberg-lab_Balrog_master/protein_filter/genexa_genes.tar.gz
    # !!tar -xvzf /home/salzberg-lab_Balrog_develop/protein_filter/genexa_genes.tar.gz
    genexa_fasta_path = "/content/mmseqs2/genexa/genexa_genes.fasta"

    # create DB
    genexa_DB_path = "/content/mmseqs2/genexa/genexaDB"
    !/content/mmseqs2/mmseqs/bin/mmseqs createdb {genexa_fasta_path} {genexa_DB_path}

    # build mmseqs index
    !mkdir /content/mmseqs2/tmp
    !/content/mmseqs2/mmseqs/bin/mmseqs createindex {genexa_DB_path} /content/mmseqs2/tmp

    # download swissprot DB
    !mkdir /content/mmseqs2/swissprot
    !/content/mmseqs2/mmseqs/bin/mmseqs databases UniProtKB/Swiss-Prot /content/mmseqs2/swissprot/swissprotDB /content/mmseqs2/tmp
    swissprot_DB_path = "/content/mmseqs2/swissprot/swissprotDB"
    
print("\nDone")


Done


# Gene Prediction

In [26]:
# @title Upload prokaryotic genomes as FASTA or gzipped FASTA.
# Colab-only: opens a browser file picker. The returned dict is keyed by the
# uploaded file names; the "Find genes" cell iterates those keys and reopens
# each file by name relative to the current working directory (which earlier
# %cd commands may have changed).
from google.colab import files
genome_dict = files.upload()

print("\nDone")

Saving GCA_000830275.1_ASM83027v1_genomic.fna.gz to GCA_000830275.1_ASM83027v1_genomic.fna.gz
Saving GCF_000015765.1_ASM1576v1_genomic.fna.gz to GCF_000015765.1_ASM1576v1_genomic.fna.gz
Saving GCF_000022205.1_ASM2220v1_genomic.fna.gz to GCF_000022205.1_ASM2220v1_genomic.fna.gz
Saving GCF_000024185.1_ASM2418v1_genomic.fna.gz to GCF_000024185.1_ASM2418v1_genomic.fna.gz
Saving GCF_000969965.1_ASM96996v1_genomic.fna.gz to GCF_000969965.1_ASM96996v1_genomic.fna.gz

Done


In [27]:
# @title Find genes

def tokenize_aa_seq(aa_seq):
    """ Convert amino acid letters to integers.

    The 20 standard residues map to 1..20. Stop ("*"), unknown ("X") and any
    other symbol map to 0. Biopython can emit ambiguity letters such as "B",
    "Z" or "J" for codons containing IUPAC ambiguity bases; these are treated
    like "X" instead of raising KeyError.

    Args:
        aa_seq: iterable of one-letter amino-acid codes (e.g. a str).

    Returns:
        1-D int64 tensor with one token per residue. An empty input gives an
        empty int64 tensor (not the float32 default of torch.tensor([])).
    """
    table = {"L": 1,
             "V": 2,
             "I": 3,
             "M": 4,
             "C": 5,
             "A": 6,
             "G": 7,
             "S": 8,
             "T": 9,
             "P": 10,
             "F": 11,
             "Y": 12,
             "W": 13,
             "E": 14,
             "D": 15,
             "N": 16,
             "Q": 17,
             "K": 18,
             "R": 19,
             "H": 20,
             "*": 0,
             "X": 0}
    tokenized = torch.tensor([table.get(aa, 0) for aa in aa_seq], dtype=torch.long)
    return tokenized


def get_start_codon(seq, orfcoords, strand):
    """Return an ORF's 3-nt start codon, read 5'->3' on the ORF's own strand.

    Forward ORFs (strand == 1): the three bases immediately before
    orfcoords[0]. Any other strand value: the three bases starting at
    orfcoords[1], reverse-complemented (seq must provide reverse_complement,
    e.g. a Bio.Seq).
    """
    if strand != 1:
        # reverse strand
        pos = orfcoords[1]
        return seq[pos:pos + 3].reverse_complement()
    # forward strand
    pos = orfcoords[0]
    return seq[pos - 3:pos]


def find_ORFs(nuc_seq, minimum_length):
    """Find all open reading frames in one reading frame of a nucleotide string.

    nuc_seq is read codon by codon from index 0. Position 0 always opens a
    candidate ORF (so genes running off the sequence start are kept), as does
    every start codon. A stop codon closes every pending start. So does the
    final codon when len(nuc_seq) is a multiple of 3. Each (start, stop) pair
    spanning at least ``minimum_length`` nucleotides is recorded.

    Args:
        nuc_seq: nucleotide string, 5'->3'. Codon comparisons are upper-case
            only.
        minimum_length: minimum ``stop - start`` distance in nucleotides.

    Returns:
        List of ``(start, stop)`` index pairs, where ``stop`` is the index of
        the closing codon.

    Raises:
        NotImplementedError: if the global ``translation_table`` is neither 11
            nor 4. In a notebook, sys.exit() only produced a bare SystemExit.
    """
    if translation_table == 11:
        stops = {"TAA", "TAG", "TGA"}
    elif translation_table == 4:
        stops = {"TAA", "TAG"}  # TGA is not a stop codon in table 4
    else:
        raise NotImplementedError(
            "Translation table " + str(translation_table)
            + " not implemented. Please open a GitHub issue if this is a problem.")
    starts = {"ATG", "GTG", "TTG"}

    ORF_startstop = []
    pending_starts = []
    seq_len = len(nuc_seq)
    for i in range(0, seq_len, 3):
        codon = nuc_seq[i:i + 3]
        if i == 0 or codon in starts:
            pending_starts.append(i)
            continue
        if (codon in stops or i + 3 == seq_len) and pending_starts:
            ORF_startstop.extend((start, i) for start in pending_starts
                                 if i - start >= minimum_length)
            pending_starts = []
    return ORF_startstop

# def find_ORFs(nuc_seq, minimum_length):
#     MAX_STARTS = 20

#     """find positions of all open reading frames in given reading frame"""
#     if translation_table == 11:
#         starts = set(["ATG", "GTG", "TTG"])
#         stops = set(["TAA", "TAG", "TGA"])
#     elif translation_table == 4:
#         starts = set(["ATG", "GTG", "TTG"])
#         stops = set(["TAA", "TAG"])
#     else:
#         print("Translation table ", translation_table, " not implemented. Please open a GitHub issue if this is a problem.")
#         sys.exit()

#     ORF_startstop = []
#     temp_starts = []
#     l = len(nuc_seq)
#     for i in range(0, l, 3):
#         if (i==0 or nuc_seq[i:i+3] in starts) and (len(temp_starts) < MAX_STARTS):
#             temp_starts.append(i)
#             continue
#         if ((nuc_seq[i:i+3] in stops) or (i+3==l)) and len(temp_starts) != 0:
#             for start in temp_starts:
#                 if (i-start >= minimum_length):
#                     ORF_startstop.append((start, i))
#             temp_starts = []
#     return ORF_startstop


def get_ORF_info(seq_list):
    """Locate and translate every ORF in all six reading frames of each contig.

    Frames are ordered [0, 1, 2, r0, r1, r2]. Forward frame f starts at offset
    f and is trimmed to a multiple of 3. Reverse frame rf is the reverse
    complement of that same trimmed span, so r0 ends at offset 0 but may not
    start at the last base of the contig.

    Sequences are upper-cased first. find_ORFs and the downstream codon and
    nucleotide encoders only recognise upper-case letters, so soft-masked
    (lower-case) FASTA would otherwise yield no real ORFs.

    Args:
        seq_list: contig sequences (Bio.Seq-like objects: .upper(),
            .complement() and .translate() are called on them).

    Returns:
        (ORF_seq, ORF_nucseq, ORF_coord), each with one entry per contig and
        six per-frame lists inside:
        ORF_seq[c][f]    amino-acid strings of frame f's ORFs, reversed
                         because the model is trained with the first amino
                         acid directly upstream of the stop codon;
        ORF_nucseq[c][f] the whole frame sequence as a str, 5'->3' on its
                         own strand;
        ORF_coord[c][f]  (low, high) contig coordinates of each ORF. The
                         start codon is excluded (+3 on the low end for
                         forward frames, -3 on the high end for reverse).
    """
    ORF_seq = []
    ORF_coord = []
    ORF_nucseq = []
    for seq in seq_list:
        seq = seq.upper()
        seq_c = seq.complement()
        l = len(seq)
        frame_ends = [(l - f) - (l - f) % 3 + f for f in range(3)]

        # forward frames 0, 1, 2 followed by reverse frames r0, r1, r2; each
        # entry is that frame's nucleotide sequence read 5'->3'
        frame_seqs = ([seq[f:frame_ends[f]] for f in range(3)]
                      + [seq_c[f:frame_ends[f]][::-1] for f in range(3)])

        contig_aa = []
        contig_coords = []
        for frame_idx, frame_seq in enumerate(frame_seqs):
            offset = frame_idx % 3
            orfs = find_ORFs(str(frame_seq), min_orf_length)

            # standardize coords
            if frame_idx < 3:
                coords = [(start + 3 + offset, stop + offset) for start, stop in orfs]
            else:
                end = frame_ends[offset]
                coords = [(end - stop, end - start - 3) for start, stop in orfs]

            # translate once per frame, then slice out each ORF
            aa = str(frame_seq.translate(table=translation_table, to_stop=False))
            contig_aa.append([aa[start // 3:stop // 3][::-1] for start, stop in orfs])
            contig_coords.append(coords)

        ORF_seq.append(contig_aa)
        ORF_coord.append(contig_coords)
        ORF_nucseq.append([str(frame_seq) for frame_seq in frame_seqs])  # all 5' to 3'
    return ORF_seq, ORF_nucseq, ORF_coord


def analyze_overlap(coords0, coords1, strand0, strand1,
                    unidirectional_penalty_per_base,
                    convergent_penalty_per_base,
                    divergent_penalty_per_base):
    """Decide whether two ORFs can both be genes and what their overlap costs.

    Assumes coords0 starts at or before coords1 (the DAG loop walks ORFs
    sorted by start coordinate). Overlap is measured as
    ``coords0[1] - coords1[0]``; fully nested ORFs get no special handling.

    Returns:
        (compatible, penalty), one of:
        - (True, 0) if the ORFs do not overlap;
        - (False, 0) if the overlap exceeds the global ``max_gene_overlap``;
        - (False, 0) if both are on the same strand with the same 3' end;
        - (True, overlap * the per-base penalty) for a unidirectional
          (3'-5'), convergent (3'-3') or divergent (5'-5') overlap;
        - (True, 0) if none of those patterns match.
    """
    overlap = coords0[1] - coords1[0] # TODO account for fully overlapped gene
    if overlap <= 0:
        return True, 0
    if overlap > max_gene_overlap:
        return False, 0

    # (5' end, 3' end) of each ORF: strand 1 reads low->high, otherwise high->low
    five0, three0 = (coords0[0], coords0[1]) if strand0 == 1 else (coords0[1], coords0[0])
    five1, three1 = (coords1[0], coords1[1]) if strand1 == 1 else (coords1[1], coords1[0])

    if strand0 == strand1 and three0 == three1:
        # same frame, same stop codon: alternative starts of one ORF
        return False, 0
    if (three0 < five0) == (three1 < five1):
        return True, overlap * unidirectional_penalty_per_base
    if five0 < three1 <= three0 or five1 < three0 <= three1:
        return True, overlap * convergent_penalty_per_base
    if three0 < five1 <= five0 or three1 < five0 <= five1:
        return True, overlap * divergent_penalty_per_base
    return True, 0 # edge case of exactly 1 ORF

def predict(X):
    """Score a batch of amino-acid token sequences with the gene model.

    X is a 2-D integer tensor of tokens 0..20, presumably (batch, length).
    It is one-hot encoded, moved to the GPU when one is available, passed
    through the global ``model`` without gradients, and squashed with the
    logistic sigmoid (expit). The result is returned on the CPU. On the GPU
    path the encoded input is freed and the CUDA cache emptied afterwards.
    """
    model.eval()
    use_gpu = torch.cuda.device_count() > 0
    with torch.no_grad():
        X_enc = F.one_hot(X, 21).permute(0, 2, 1).float()
        if use_gpu:
            X_enc = X_enc.cuda()
        probs = expit(model(X_enc).cpu())
        if use_gpu:
            del X_enc
            torch.cuda.empty_cache()

    return probs

def predict_tis(X):
    """Score a batch of encoded nucleotide windows with the TIS model.

    X is a 2-D integer tensor of nucleotide tokens 0..3, presumably
    (batch, window). It is one-hot encoded to 4 channels, moved to the GPU
    when one is available, run through the global ``model_tis`` without
    gradients, and passed through the logistic sigmoid (expit). The result
    is returned on the CPU.
    """
    model_tis.eval()
    with torch.no_grad():
        X_enc = F.one_hot(X, 4).permute(0, 2, 1).float()
        if torch.cuda.device_count() > 0:
            X_enc = X_enc.cuda()
        probs = expit(model_tis(X_enc).cpu())
    return probs

# Nucleotide -> integer token for the TIS model (one-hot width 4 in predict_tis).
# Every IUPAC ambiguity code maps to 0 (the same token as "A"), matching the
# N/M/R/Y/W/K entries. S, B, D, H and V are included so a genome containing
# them near a start codon does not raise KeyError during TIS encoding.
nuc_encode = {"A":0,
              "T":1,
              "G":2,
              "C":3,
              "N":0,
              "M":0,
              "R":0,
              "Y":0,
              "W":0,
              "K":0,
              "S":0,
              "B":0,
              "D":0,
              "H":0,
              "V":0}

# Start codon -> index; the scoring code treats 0/1/2 as ATG/GTG/TTG.
start_enc = {"ATG":0,
             "GTG":1,
             "TTG":2}

def tensor_to_seq(tensor):
    """Map integer amino-acid tokens back to a one-letter sequence string.

    This is the inverse of the tokenizer's table: token 0 becomes "X" and
    tokens 1..20 become L V I M C A G S T P F Y W E D N Q K R H. The
    elements of ``tensor`` must be hashable ints (the mmseqs caller converts
    them with int()); any other value raises KeyError.
    """
    token_to_letter = dict(enumerate("XLVIMCAGSTPFYWEDNQKRH"))
    return "".join(token_to_letter[token] for token in tensor)

def kmerfilter(seq):
    """Return whether enough of ``seq``'s distinct k-mers are in ``aa_kmer_set``.

    Uses the globals ``k_seengene`` (k), ``multimer_threshold`` (minimum
    number of matching distinct k-mers) and ``aa_kmer_set`` (loaded by the
    kmer-filter setup cell). Returns a numpy bool.
    """
    matches = np.sum([kmer in aa_kmer_set for kmer in kmerize(seq, k_seengene)])
    return matches >= multimer_threshold

def kmerize(seq, k):
    """Return the set of distinct length-k windows of a token sequence.

    Args:
        seq: 1-D sequence with a ``tolist()`` method (a token tensor here).
        k: window length.

    Returns:
        Set of k-tuples of Python ints; empty when ``len(seq) < k``.
    """
    # Convert once. Slicing the tensor and calling .tolist() for every window
    # builds a new tensor view per position, which is slow on long ORFs.
    tokens = seq.tolist()
    return {tuple(tokens[i:i + k]) for i in range(len(tokens) - k + 1)}

# find genes for each uploaded genome
# Per-genome accumulators: the loop below appends one entry per uploaded file.
# NOTE(review): GCF_list and the contig_gene_* lists are not filled in the
# code visible here; presumably they collect results for output — confirm.
GCF_list = []
contig_name_list = []  # per genome: list of FASTA record ids
contig_length_list = []  # per genome: list of contig lengths in nucleotides
contig_seq_list = []  # per genome: list of contig sequences (record.seq)
contig_gene_coord_list = []  # presumably per genome: predicted gene coords — TODO confirm
contig_gene_strand_list = []  # presumably per genome: predicted gene strands — TODO confirm

for genome_name in genome_dict.keys():
    if verbose:
        print("Reading fasta file", str(genome_name) + "...\n")

    # read genome sequence
    seq_list = []
    contig_name_sublist = []
    contig_length_sublist = []
    if os.path.splitext(genome_name)[1].lower() == ".gz":
        with gzip.open(genome_name, "rt") as f:
            for record in SeqIO.parse(f, "fasta"):
                seq_list.append(record.seq)
                contig_name_sublist.append(record.id)
                contig_length_sublist.append(len(record.seq))
    else:
        with open(genome_name, "rt") as f:
            for record in SeqIO.parse(f, "fasta"):
                seq_list.append(record.seq)
                contig_name_sublist.append(record.id)
                contig_length_sublist.append(len(record.seq))
    contig_name_list.append(contig_name_sublist)
    contig_length_list.append(contig_length_sublist)
    contig_seq_list.append(seq_list)

    # get sequences and coordinates of ORFs
    if verbose:
        print("Finding and translating open reading frames...\n")

    ORF_seq_list, ORF_nucseq_list, ORF_coord_list = get_ORF_info(seq_list)

    # combine ORFs to submit to GPU in batches
    ORF_seq_combine = []
    for i, contig in enumerate(ORF_seq_list):
        for j, frame in enumerate(contig):
            for k, coord in enumerate(frame):
                ORF_seq_combine.append(coord)

    # encode amino acids as integers
    if verbose:
        print("Encoding amino acids...\n")
    ORF_seq_enc = [tokenize_aa_seq(x) for x in ORF_seq_combine]

    # seengene check
    if protein_kmer_filter:
        if verbose:
            print("Applying protein kmer filter...\n")
        seengene = []
        for s in ORF_seq_enc:
            kmerset = kmerize(s, k_seengene)
            s = [x in aa_kmer_set for x in kmerset]
            seen = np.sum(s) >= multimer_threshold

            seengene.append(seen)

    # score
    if verbose:
        print("Scoring ORFs with temporal convolutional network...\n")

    # sort by length to minimize impact of batch padding 
    ORF_lengths = np.asarray([len(x) for x in ORF_seq_enc])
    length_idx = np.argsort(ORF_lengths)
    ORF_seq_sorted = [ORF_seq_enc[i] for i in length_idx]

    # pad to allow creation of batch matrix
    prob_list = []
    for i in tqdm(range(0, len(ORF_seq_sorted), gene_batch_size), unit=" batch"):
        batch = ORF_seq_sorted[i:i+gene_batch_size]
        seq_lengths = torch.LongTensor(list(map(len, batch)))
        seq_tensor = torch.zeros((len(batch), seq_lengths.max())).long()

        for idx, (seq, seqlen) in enumerate(zip(batch, seq_lengths)):
            seq_tensor[idx, :seqlen] = torch.LongTensor(seq)

        pred_all = predict(seq_tensor)

        pred = []
        for j, length in enumerate(seq_lengths):
            subseq = pred_all[j, 0, 0:int(length)]
            predprob = float(expit(torch.mean(logit(subseq))))
            pred.append(predprob)
        
        prob_list.extend(pred)
    prob_arr = np.asarray(prob_list, dtype=float)

    # unsort
    unsort_idx = np.argsort(length_idx)
    ORF_prob = prob_arr[unsort_idx]

    # recombine ORFs
    idx = 0
    ORF_gene_score = copy.deepcopy(ORF_coord_list) # fill coord matrix with scores
    for i, contig in enumerate(ORF_gene_score):
        for j, frame in enumerate(contig):
            for k, coord in enumerate(frame):
                ORF_gene_score[i][j][k] = float(ORF_prob[idx])
                idx += 1

    # create strand information
    ORF_strand_flat = []
    for i, seq in enumerate(ORF_seq_list):
        if not any(seq):
            ORF_strand_flat.append([])
            continue
        n_forward = len(seq[0]) + len(seq[1]) + len(seq[2])
        n_reverse = len(seq[3]) + len(seq[4]) + len(seq[5])
        ORF_allframe_strand = [1]*n_forward + [-1]*n_reverse
        ORF_strand_flat.append(ORF_allframe_strand)

    # flatten coords within contigs
    ORF_coord_flat = [[item for sublist in x for item in sublist] for x in ORF_coord_list]

    # get ORF lengths
    ORF_length_flat = [[coords[1]-coords[0] for coords in x] for x in ORF_coord_flat]
    
    if verbose:
        print("Scoring translation initiation sites...\n")

    # extract nucleotide sequence surrounding potential start codons
    ORF_TIS_seq = copy.deepcopy(ORF_coord_list)
    ORF_start_codon = copy.deepcopy(ORF_coord_list)

    for i, contig in enumerate(ORF_TIS_seq):
        n = 0 # count to index into flat structure # TODO make sure this works as expected

        nucseq = ORF_nucseq_list[i][0] # easier to use coords relative to single nucseq
        nucseq_c = ORF_nucseq_list[i][3][::-1]
        contig_nuclength = len(nucseq)


        for j, frame in enumerate(contig):
            for k, temp in enumerate(frame):
                if any(temp):
                    coords = ORF_coord_list[i][j][k]
                    strand = ORF_strand_flat[i][n]
                    n += 1
                    if strand == 1:
                        fiveprime = coords[0]
                        if fiveprime >= 16 + 3: # NOTE 16 HARD CODED HERE
                            downstream = nucseq[fiveprime: fiveprime + 16]
                            upstream = nucseq[fiveprime - 16 - 3: fiveprime - 3]
                            start_codon = start_enc[nucseq[fiveprime-3: fiveprime]]
                            TIS_seq = torch.tensor([nuc_encode[c] for c in (upstream + downstream)[::-1]], dtype=int) # model scores 3' to 5' direction
                        else:
                            TIS_seq = -1 # deal with gene fragments later
                            start_codon = 2

                        ORF_TIS_seq[i][j][k] = TIS_seq
                        ORF_start_codon[i][j][k] = start_codon
                        
                    else: # reverse strand
                        fiveprime = coords[1]
                        if contig_nuclength - fiveprime + 3 >= 16 + 3: # NOTE 16 HARD CODED HERE
                            downstream = nucseq_c[fiveprime - 16: fiveprime][::-1]
                            upstream = nucseq_c[fiveprime + 3: fiveprime + 3 + 16][::-1]
                            start_codon = start_enc[nucseq_c[fiveprime: fiveprime + 3][::-1]]
                            TIS_seq = torch.tensor([nuc_encode[c] for c in (upstream + downstream)[::-1]], dtype=int) # model scores 3' to 5' direction
                        else:
                            TIS_seq = -1 # deal with gene fragments later
                            start_codon = 2
                            
                        ORF_TIS_seq[i][j][k] = TIS_seq
                        ORF_start_codon[i][j][k] = start_codon

    # flatten TIS for batching
    ORF_TIS_prob = copy.deepcopy(ORF_TIS_seq)

    ORF_TIS_seq_flat = []
    ORF_TIS_seq_idx = []
    for i, contig in enumerate(ORF_TIS_seq):
        for j, frame in enumerate(contig):
            for k, seq in enumerate(frame):
                if type(seq) == int: # fragment
                    ORF_TIS_prob[i][j][k] = 0.5 # HOW BEST TO DEAL WITH FRAGMENT TIS?
                elif len(seq) != 32:
                    ORF_TIS_prob[i][j][k] = 0.5 
                else:
                    ORF_TIS_seq_flat.append(seq)
                    ORF_TIS_seq_idx.append((i, j, k))

    # batch score TIS
    TIS_prob_list = []
    for i in tqdm(range(0, len(ORF_TIS_seq_flat), TIS_batch_size), unit=" batch"):
        batch = ORF_TIS_seq_flat[i:i+TIS_batch_size]
        TIS_stacked = torch.stack(batch)
        pred = predict_tis(TIS_stacked)

        TIS_prob_list.extend(pred)
    y_pred_TIS = np.asarray(TIS_prob_list, dtype=float)

    # reindex batched scores
    for i, prob in enumerate(y_pred_TIS):
        idx = ORF_TIS_seq_idx[i]
        ORF_TIS_prob[idx[0]][idx[1]][idx[2]] = float(prob)

    # combine all info into single score for each ORF
    if protein_kmer_filter:
        ORF_score_flat = []
        for i, contig in enumerate(ORF_gene_score):
            if not any(contig):
                ORF_score_flat.append([])
                continue
            temp = []
            seengene_idx = 0
            for j, frame in enumerate(contig):
                for k, geneprob in enumerate(frame):
                    length = ORF_coord_list[i][j][k][1] - ORF_coord_list[i][j][k][0] + 1 
                    TIS_prob = ORF_TIS_prob[i][j][k]
                    start_codon = ORF_start_codon[i][j][k]
                    ATG = start_codon == 0
                    GTG = start_codon == 1
                    TTG = start_codon == 2

                    combprob =   geneprob * weight_gene_prob \
                            + TIS_prob * weight_TIS_prob \
                            + ATG * weight_ATG \
                            + GTG * weight_TTG \
                            + TTG * weight_GTG
                    maxprob = weight_gene_prob + weight_TIS_prob + max(weight_ATG, weight_TTG, weight_GTG)
                    probthresh = score_threshold * maxprob
                    score = (combprob - probthresh) * length  + 1e6*seengene[seengene_idx]
                    seengene_idx += 1

                    temp.append(score)
            ORF_score_flat.append(temp)

    else:
        ORF_score_flat = []
        for i, contig in enumerate(ORF_gene_score):
            if not any(contig):
                ORF_score_flat.append([])
                continue
            temp = []
            for j, frame in enumerate(contig):
                for k, geneprob in enumerate(frame):
                    length = ORF_coord_list[i][j][k][1] - ORF_coord_list[i][j][k][0] + 1 
                    TIS_prob = ORF_TIS_prob[i][j][k]
                    start_codon = ORF_start_codon[i][j][k]
                    ATG = start_codon == 0
                    GTG = start_codon == 1
                    TTG = start_codon == 2

                    combprob =   geneprob * weight_gene_prob \
                            + TIS_prob * weight_TIS_prob \
                            + ATG * weight_ATG \
                            + GTG * weight_TTG \
                            + TTG * weight_GTG
                    maxprob = weight_gene_prob + weight_TIS_prob + max(weight_ATG, weight_TTG, weight_GTG)
                    probthresh = score_threshold * maxprob
                    score = (combprob - probthresh) * length

                    temp.append(score)
            ORF_score_flat.append(temp)

    # DAGs to maximize geneiness on each contig
    contig_gene_coord = []
    contig_gene_strand = []

    for i, coords in enumerate(ORF_coord_flat):
        if verbose:
            print("Creating graph of contig " + str(i) + "...\n")

        # sort coords, lengths, strands, and scores
        startpos = np.array([x[0] for x in coords])
        sortidx = list(np.argsort(startpos))

        coords_sorted = [coords[j] for j in sortidx]

        lengths = ORF_length_flat[i]
        lengths_sorted = [lengths[j] for j in sortidx]

        scores = ORF_score_flat[i]
        scores_sorted = [scores[j] for j in sortidx]

        strands = ORF_strand_flat[i]
        strands_sorted = [strands[j] for j in sortidx]

        # create DAG
        # keep track of graph path and score
        predecessor = np.zeros(len(scores_sorted), dtype=np.int64)
        max_path_score = np.zeros(len(scores_sorted), dtype=np.int64)

        # add null starting node
        n_connections = 0
        idx_offset = 1
        while n_connections < max_forward_connections:
            k = idx_offset
            idx_offset += 1
            if k > len(scores_sorted)-1: # dont try to add edge past last ORF
                n_connections += 1
                continue
            edge_weight = scores_sorted[k-1]

            # initial scores from null node
            max_path_score[k] = edge_weight
            predecessor[k] = 0
            
            n_connections += 1

        # add edges between compatible ORFs
        for j in tqdm(range(1, len(scores_sorted)-1), unit=" ORF"):
            n_connections = 0
            idx_offset = 1

            while n_connections < max_forward_connections:
                k = j + idx_offset
                idx_offset += 1

                if k > len(scores_sorted)-1: # dont try to add edge past end of contigs
                    n_connections += 1
                    continue 

                coords0 = coords_sorted[j-1]
                coords1 = coords_sorted[k-1]

                strand0 = strands_sorted[j-1]
                strand1 = strands_sorted[k-1]

                compat, penalty = analyze_overlap(coords0, coords1, 
                                                  strand0, strand1,
                                                  unidirectional_penalty_per_base,
                                                  convergent_penalty_per_base,
                                                  divergent_penalty_per_base)

                if compat:
                    score = scores_sorted[k-1] - penalty

                    path_score = max_path_score[j] + score
                    if path_score > max_path_score[k]:
                        max_path_score[k] = path_score
                        predecessor[k] = j


                    n_connections += 1


        # solve for geneiest path through contig
        if verbose:
            print("Maximizing geneiness...")

        pred_idx = np.argmax(max_path_score)
        pred_path = []
        while pred_idx > 0:
            pred_path.append(pred_idx)
            pred_idx = predecessor[pred_idx]

        # max_ORF_PATH = [x-1 for x in max_ORF_PATH_withnull[1:]]
        max_ORF_PATH = [x-1 for x in pred_path[:]] # 0 isnt added

        gene_predict_coords = [coords_sorted[j] for j in max_ORF_PATH]
        gene_predict_strand = [strands_sorted[j] for j in max_ORF_PATH]

        # mmseqs filter
        if mmseqs2_gene_filter:
            if verbose:
                print("\nFiltering predictions with mmseqs2...")

            # get amino acid sequence from coherent ORFs
            # 3' TO 5'
            aa_sorted = [ORF_seq_enc[j] for j in sortidx]
            aa_tensor = [aa_sorted[j] for j in max_ORF_PATH]
            aa_seq = [tensor_to_seq([int(y) for y in x]) for x in aa_tensor]

            # make temp dir to store mmseqs stuff
            finding_empty_dir = True
            dir_idx = 0
            while finding_empty_dir:
                dirpath = os.path.join("/content/mmseqs2/tmp", str(dir_idx))
                if os.path.isdir(dirpath):
                    dir_idx += 1
                    continue
                else:
                    !mkdir {dirpath}
                    finding_empty_dir = False
            
            # mmseqs create query DB     3' to 5'
            query_fasta_path_35 = os.path.join(dirpath, "candidate_genes_35.fasta")
            with open(query_fasta_path_35, "w") as f:
                for i, s in enumerate(aa_seq):
                    f.writelines(">" + str(i) + "\n")
                    f.writelines(str(s) + "\n")

            # mmseqs create query DB     5' to 3'
            query_fasta_path_53 = os.path.join(dirpath, "candidate_genes_53.fasta")
            with open(query_fasta_path_53, "w") as f:
                for i, s in enumerate(aa_seq):
                    f.writelines(">" + str(i) + "\n")
                    f.writelines(str(s)[::-1] + "\n")

            query_DB_path_35 = os.path.join(dirpath, "candidateDB_35")
            !/content/mmseqs2/mmseqs/bin/mmseqs createdb {query_fasta_path_35} {query_DB_path_35}
            
            query_DB_path_53 = os.path.join(dirpath, "candidateDB_53")
            !/content/mmseqs2/mmseqs/bin/mmseqs createdb {query_fasta_path_53} {query_DB_path_53}
            
            # mmseqs search
            results_DB_path_35 = os.path.join(dirpath, "resultsDB_35")
            !/content/mmseqs2/mmseqs/bin/mmseqs search -s 7.0 {query_DB_path_35} {genexa_DB_path} {results_DB_path_35} {dirpath}

            # convert to readable format
            m8_path_genexa = os.path.join(dirpath, "resultDB_genexa.m8")
            !/content/mmseqs2/mmseqs/bin/mmseqs convertalis {query_DB_path_35} {genexa_DB_path} {results_DB_path_35} {m8_path_genexa} --format-output "query,target,evalue,raw"

            # load search results
            mmseqs_results_genexa = pandas.read_table(m8_path_genexa, header=None, names=["query", "target", "evalue", "raw"]).to_numpy()

            # get hits
            hit_idx_genexa = np.unique(mmseqs_results_genexa[:, 0]).astype(int)

            # mmseqs search
            results_DB_path_53 = os.path.join(dirpath, "resultsDB_53")
            !/content/mmseqs2/mmseqs/bin/mmseqs search -s 7.0 {query_DB_path_53} {swissprot_DB_path} {results_DB_path_53} {dirpath}

            # convert to readable format
            m8_path_secondary = os.path.join(dirpath, "resultDB_secondary.m8")
            !/content/mmseqs2/mmseqs/bin/mmseqs convertalis {query_DB_path_53} {swissprot_DB_path} {results_DB_path_53} {m8_path_secondary} --format-output "query,target,evalue,raw"

            # load search results
            mmseqs_results_secondary = pandas.read_table(m8_path_secondary, header=None, names=["query", "target", "evalue", "raw"]).to_numpy()

            # get hits
            hit_idx_secondary = np.unique(mmseqs_results_secondary[:, 0]).astype(int)

            # filter gene predictions, keep if mmseqs hit or high enough gene score
            cutoff = 200

            cutoffpath = [x for i, x in enumerate(max_ORF_PATH) if (scores_sorted[x] > cutoff or (i in hit_idx_genexa or i in hit_idx_secondary))]
            gene_predict_coords = [coords_sorted[j] for j in cutoffpath]
            gene_predict_strand = [strands_sorted[j] for j in cutoffpath]

            graph_score_cutoff = [scores_sorted[j] for j in cutoffpath]
            contig_gene_coord.append(gene_predict_coords)
            contig_gene_strand.append(gene_predict_strand)

            n_genes = len(gene_predict_coords)
            if verbose:
                print("found", n_genes, "genes\n\n")


        else:
            cutoff = 100
            cutoffpath = [x for x in max_ORF_PATH if scores_sorted[x] > cutoff]
            gene_predict_coords = [coords_sorted[j] for j in cutoffpath]
            gene_predict_strand = [strands_sorted[j] for j in cutoffpath]

            graph_score_cutoff = [scores_sorted[j] for j in cutoffpath]
            contig_gene_coord.append(gene_predict_coords)
            contig_gene_strand.append(gene_predict_strand)

            n_genes = len(gene_predict_coords)
            if verbose:
                print("found", n_genes, "genes\n\n")
    contig_gene_coord_list.append(contig_gene_coord)
    contig_gene_strand_list.append(contig_gene_strand)

print("Done")

Reading fasta file GCA_000830275.1_ASM83027v1_genomic.fna.gz...

Finding and translating open reading frames...

Encoding amino acids...

Applying protein kmer filter...

Scoring ORFs with temporal convolutional network...



HBox(children=(FloatProgress(value=0.0, max=315.0), HTML(value='')))


Scoring translation initiation sites...



HBox(children=(FloatProgress(value=0.0, max=63.0), HTML(value='')))


Creating graph of contig 0...



HBox(children=(FloatProgress(value=0.0, max=62860.0), HTML(value='')))


Maximizing geneiness...

Filtering predictions with mmseqs2...
createdb /content/mmseqs2/tmp/26/candidate_genes_35.fasta /content/mmseqs2/tmp/26/candidateDB_35 

MMseqs Version:       	df69c26e1c9aaeaa3f5d72fd6e782d02742b2b0c
Database type         	0
Shuffle input database	true
Createdb mode         	0
Write lookup file     	1
Offset of numeric ids 	0
Compressed            	0
Verbosity             	3

Converting sequences
[1920] 0s 4ms
Time for merging to candidateDB_35_h: 0h 0m 0s 2ms
Time for merging to candidateDB_35: 0h 0m 0s 2ms
Database type: Aminoacid
Time for processing: 0h 0m 0s 14ms
createdb /content/mmseqs2/tmp/26/candidate_genes_53.fasta /content/mmseqs2/tmp/26/candidateDB_53 

MMseqs Version:       	df69c26e1c9aaeaa3f5d72fd6e782d02742b2b0c
Database type         	0
Shuffle input database	true
Createdb mode         	0
Write lookup file     	1
Offset of numeric ids 	0
Compressed            	0
Verbosity             	3

Converting sequences
[1920] 0s 5ms
Time for merging to ca

HBox(children=(FloatProgress(value=0.0, max=447.0), HTML(value='')))


Scoring translation initiation sites...



HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))


Creating graph of contig 0...



HBox(children=(FloatProgress(value=0.0, max=89307.0), HTML(value='')))


Maximizing geneiness...

Filtering predictions with mmseqs2...
createdb /content/mmseqs2/tmp/27/candidate_genes_35.fasta /content/mmseqs2/tmp/27/candidateDB_35 

MMseqs Version:       	df69c26e1c9aaeaa3f5d72fd6e782d02742b2b0c
Database type         	0
Shuffle input database	true
Createdb mode         	0
Write lookup file     	1
Offset of numeric ids 	0
Compressed            	0
Verbosity             	3

Converting sequences
[2930] 0s 6ms
Time for merging to candidateDB_35_h: 0h 0m 0s 2ms
Time for merging to candidateDB_35: 0h 0m 0s 2ms
Database type: Aminoacid
Time for processing: 0h 0m 0s 17ms
createdb /content/mmseqs2/tmp/27/candidate_genes_53.fasta /content/mmseqs2/tmp/27/candidateDB_53 

MMseqs Version:       	df69c26e1c9aaeaa3f5d72fd6e782d02742b2b0c
Database type         	0
Shuffle input database	true
Createdb mode         	0
Write lookup file     	1
Offset of numeric ids 	0
Compressed            	0
Verbosity             	3

Converting sequences
[2930] 0s 6ms
Time for merging to ca

HBox(children=(FloatProgress(value=0.0, max=801.0), HTML(value='')))


Scoring translation initiation sites...



HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))


Creating graph of contig 0...



HBox(children=(FloatProgress(value=0.0, max=113736.0), HTML(value='')))


Maximizing geneiness...

Filtering predictions with mmseqs2...
createdb /content/mmseqs2/tmp/28/candidate_genes_35.fasta /content/mmseqs2/tmp/28/candidateDB_35 

MMseqs Version:       	df69c26e1c9aaeaa3f5d72fd6e782d02742b2b0c
Database type         	0
Shuffle input database	true
Createdb mode         	0
Write lookup file     	1
Offset of numeric ids 	0
Compressed            	0
Verbosity             	3

Converting sequences
[4142] 0s 8ms
Time for merging to candidateDB_35_h: 0h 0m 0s 2ms
Time for merging to candidateDB_35: 0h 0m 0s 3ms
Database type: Aminoacid
Time for processing: 0h 0m 0s 22ms
createdb /content/mmseqs2/tmp/28/candidate_genes_53.fasta /content/mmseqs2/tmp/28/candidateDB_53 

MMseqs Version:       	df69c26e1c9aaeaa3f5d72fd6e782d02742b2b0c
Database type         	0
Shuffle input database	true
Createdb mode         	0
Write lookup file     	1
Offset of numeric ids 	0
Compressed            	0
Verbosity             	3

Converting sequences
[4142] 0s 7ms
Time for merging to ca

HBox(children=(FloatProgress(value=0.0, max=25767.0), HTML(value='')))


Maximizing geneiness...

Filtering predictions with mmseqs2...
createdb /content/mmseqs2/tmp/29/candidate_genes_35.fasta /content/mmseqs2/tmp/29/candidateDB_35 

MMseqs Version:       	df69c26e1c9aaeaa3f5d72fd6e782d02742b2b0c
Database type         	0
Shuffle input database	true
Createdb mode         	0
Write lookup file     	1
Offset of numeric ids 	0
Compressed            	0
Verbosity             	3

Converting sequences
[1718] 0s 4ms
Time for merging to candidateDB_35_h: 0h 0m 0s 2ms
Time for merging to candidateDB_35: 0h 0m 0s 2ms
Database type: Aminoacid
Time for processing: 0h 0m 0s 13ms
createdb /content/mmseqs2/tmp/29/candidate_genes_53.fasta /content/mmseqs2/tmp/29/candidateDB_53 

MMseqs Version:       	df69c26e1c9aaeaa3f5d72fd6e782d02742b2b0c
Database type         	0
Shuffle input database	true
Createdb mode         	0
Write lookup file     	1
Offset of numeric ids 	0
Compressed            	0
Verbosity             	3

Converting sequences
[1718] 0s 4ms
Time for merging to ca

HBox(children=(FloatProgress(value=0.0, max=20499.0), HTML(value='')))


Maximizing geneiness...

Filtering predictions with mmseqs2...
createdb /content/mmseqs2/tmp/30/candidate_genes_35.fasta /content/mmseqs2/tmp/30/candidateDB_35 

MMseqs Version:       	df69c26e1c9aaeaa3f5d72fd6e782d02742b2b0c
Database type         	0
Shuffle input database	true
Createdb mode         	0
Write lookup file     	1
Offset of numeric ids 	0
Compressed            	0
Verbosity             	3

Converting sequences
[1415] 0s 4ms
Time for merging to candidateDB_35_h: 0h 0m 0s 2ms
Time for merging to candidateDB_35: 0h 0m 0s 2ms
Database type: Aminoacid
Time for processing: 0h 0m 0s 12ms
createdb /content/mmseqs2/tmp/30/candidate_genes_53.fasta /content/mmseqs2/tmp/30/candidateDB_53 

MMseqs Version:       	df69c26e1c9aaeaa3f5d72fd6e782d02742b2b0c
Database type         	0
Shuffle input database	true
Createdb mode         	0
Write lookup file     	1
Offset of numeric ids 	0
Compressed            	0
Verbosity             	3

Converting sequences
[1415] 0s 4ms
Time for merging to ca

HBox(children=(FloatProgress(value=0.0, max=506.0), HTML(value='')))


Scoring translation initiation sites...



HBox(children=(FloatProgress(value=0.0, max=102.0), HTML(value='')))


Creating graph of contig 0...



HBox(children=(FloatProgress(value=0.0, max=101058.0), HTML(value='')))


Maximizing geneiness...

Filtering predictions with mmseqs2...
createdb /content/mmseqs2/tmp/31/candidate_genes_35.fasta /content/mmseqs2/tmp/31/candidateDB_35 

MMseqs Version:       	df69c26e1c9aaeaa3f5d72fd6e782d02742b2b0c
Database type         	0
Shuffle input database	true
Createdb mode         	0
Write lookup file     	1
Offset of numeric ids 	0
Compressed            	0
Verbosity             	3

Converting sequences
[4849] 0s 8ms
Time for merging to candidateDB_35_h: 0h 0m 0s 2ms
Time for merging to candidateDB_35: 0h 0m 0s 3ms
Database type: Aminoacid
Time for processing: 0h 0m 0s 21ms
createdb /content/mmseqs2/tmp/31/candidate_genes_53.fasta /content/mmseqs2/tmp/31/candidateDB_53 

MMseqs Version:       	df69c26e1c9aaeaa3f5d72fd6e782d02742b2b0c
Database type         	0
Shuffle input database	true
Createdb mode         	0
Write lookup file     	1
Offset of numeric ids 	0
Compressed            	0
Verbosity             	3

Converting sequences
[4849] 0s 8ms
Time for merging to ca

HBox(children=(FloatProgress(value=0.0, max=786.0), HTML(value='')))


Scoring translation initiation sites...



HBox(children=(FloatProgress(value=0.0, max=158.0), HTML(value='')))


Creating graph of contig 0...



HBox(children=(FloatProgress(value=0.0, max=157113.0), HTML(value='')))


Maximizing geneiness...

Filtering predictions with mmseqs2...
createdb /content/mmseqs2/tmp/32/candidate_genes_35.fasta /content/mmseqs2/tmp/32/candidateDB_35 

MMseqs Version:       	df69c26e1c9aaeaa3f5d72fd6e782d02742b2b0c
Database type         	0
Shuffle input database	true
Createdb mode         	0
Write lookup file     	1
Offset of numeric ids 	0
Compressed            	0
Verbosity             	3

Converting sequences
[9697] 0s 16ms
Time for merging to candidateDB_35_h: 0h 0m 0s 2ms
Time for merging to candidateDB_35: 0h 0m 0s 4ms
Database type: Aminoacid
Time for processing: 0h 0m 0s 35ms
createdb /content/mmseqs2/tmp/32/candidate_genes_53.fasta /content/mmseqs2/tmp/32/candidateDB_53 

MMseqs Version:       	df69c26e1c9aaeaa3f5d72fd6e782d02742b2b0c
Database type         	0
Shuffle input database	true
Createdb mode         	0
Write lookup file     	1
Offset of numeric ids 	0
Compressed            	0
Verbosity             	3

Converting sequences
[9697] 0s 14ms
Time for merging to 

In [28]:
# @title Download genome annotation (you may need to rerun this cell and/or allow multiple downloads in browser)
# TODO support different output formats

def write_GFF(start, end, contig, strand, contig_name_all, contig_length_all, contig_seq_all, GFF_path_out,
              include_fasta=False, bases_per_line=60):
    """Write gene predictions to a GFF3 file laid out like Prokka's GFF output.

    All output lines are built in memory first and the file is opened only
    afterwards, so a formatting error never leaves a truncated GFF on disk.

    Parameters
    ----------
    start, end : sequences of str or int
        Per-gene coordinates. ``start`` is written as ``int(start) + 1``
        (presumably converting a 0-based start to GFF's 1-based convention);
        ``end`` is written unchanged.
    contig : sequence
        Per-gene sequence id (GFF column 1).
    strand : sequence
        Per-gene strand symbol, e.g. "+" or "-" (GFF column 7).
    contig_name_all, contig_length_all : sequences
        One entry per contig; emitted as ``##sequence-region <name> 1 <length>``.
    contig_seq_all : sequence
        One sequence per contig; only used when ``include_fasta`` is True.
    GFF_path_out : str or path-like
        Output path; an existing file is overwritten.
    include_fasta : bool, default False
        If True, append a ``##FASTA`` section containing the contig sequences.
    bases_per_line : int, default 60
        Line width of the FASTA section.

    Raises
    ------
    ValueError
        If ``start``/``end``/``contig``/``strand`` differ in length, if there
        are fewer contig lengths (or, with ``include_fasta``, sequences) than
        contig names, or if ``bases_per_line`` is not positive. All checks
        run before the output file is opened.
    """
    column_lengths = [len(start), len(end), len(contig), len(strand)]
    if len(set(column_lengths)) != 1:
        raise ValueError("start, end, contig and strand must have equal lengths, got "
                         + ", ".join(str(n) for n in column_lengths))
    if len(contig_length_all) < len(contig_name_all):
        raise ValueError("contig_length_all has fewer entries than contig_name_all")
    if include_fasta:
        if len(contig_seq_all) < len(contig_name_all):
            raise ValueError("contig_seq_all has fewer entries than contig_name_all")
        if bases_per_line < 1:
            raise ValueError("bases_per_line must be a positive integer")

    # header
    lines = ["##gff-version 3\n"]

    # contig names and lengths; fields are joined first so no trailing space precedes "\n"
    for i in range(len(contig_name_all)):
        lines.append(" ".join(["##sequence-region",
                               str(contig_name_all[i]),
                               "1",
                               str(contig_length_all[i])]) + "\n")

    # CDS features: exactly 9 tab-separated columns per gene
    attributes = "inference=ab initio prediction:Balrog;product=hypothetical protein"
    for seqid, gene_start, gene_end, gene_strand in zip(contig, start, end, strand):
        columns = [str(seqid),
                   "Balrog",                  # source
                   "CDS",                     # feature type
                   str(int(gene_start) + 1),
                   str(gene_end),
                   ".",                       # score; TODO: replace . with actual gene score
                   str(gene_strand),
                   "0",                       # phase
                   attributes]
        lines.append("\t".join(columns) + "\n")

    # optional contig sequences
    if include_fasta:
        lines.append("##FASTA\n")
        for i in range(len(contig_name_all)):
            seq = str(contig_seq_all[i])
            lines.append(">" + str(contig_name_all[i]) + "\n")
            lines.extend(seq[j:j + bases_per_line] + "\n" for j in range(0, len(seq), bases_per_line))

    with open(GFF_path_out, "wt") as f:
        f.writelines(lines)

# Flatten each genome's nested per-contig predictions and write one GFF per genome.
# NOTE(review): relies on genome_dict, contig_gene_coord_list, contig_gene_strand_list,
# contig_name_list, contig_length_list and contig_seq_list from earlier cells sharing the
# same genome order, and on `files` (presumably google.colab.files) being imported earlier.
fasta_names = list(genome_dict.keys())
gff_names = [str(x) + "__.gff" for x in fasta_names] # simpler than removing all combinations of fasta and gz from the end

# only files successfully generated in this run are offered for download,
# so a stale GFF from a previous run is never handed back after a failure
written_gffs = []
for genome_idx, gff in enumerate(gff_names):
    # parallel per-gene columns consumed by write_GFF
    contig_gene_start_flat = []
    contig_gene_end_flat = []
    contig_gene_strand_flat = []
    contig_gene_contig_flat = []
    try:
        genome_strands = contig_gene_strand_list[genome_idx]
        for i, contig_gene_coord in enumerate(contig_gene_coord_list[genome_idx]): # TODO get rid of jenky nested lists
            contig = str(contig_name_list[genome_idx][i])  # same for every gene on this contig
            for k, coord in enumerate(contig_gene_coord):
                # widen by 3 nt on each side (presumably to include start/stop codons — confirm)
                start = str(coord[0] - 3)
                end = str(coord[1] + 3)
                strand = "+" if genome_strands[i][k] == 1 else "-"

                contig_gene_start_flat.append(start)
                contig_gene_end_flat.append(end)
                contig_gene_strand_flat.append(strand)
                contig_gene_contig_flat.append(contig)

        write_GFF(contig_gene_start_flat, contig_gene_end_flat, contig_gene_contig_flat, contig_gene_strand_flat,
                  contig_name_list[genome_idx], contig_length_list[genome_idx], contig_seq_list[genome_idx], gff)
    except Exception as err:
        # report why instead of silently swallowing (a bare except would also trap KeyboardInterrupt)
        print("Could not generate", gff, "-", repr(err))
    else:
        written_gffs.append(gff)

for gff in written_gffs:
    try:
        files.download(gff)
    except Exception as err:
        print("Could not download", gff, "-", repr(err))

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

# MIT License
 Copyright (c) 2020 Markus J. Sommer & Steven L. Salzberg

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.