In [1]:
from collections import Counter, defaultdict, deque
import datetime
import itertools
import logging
import os
import sys

import edlib
import networkx as nx
import numpy as np

sys.path.append('../')


In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
from assembly_utils.map_utils import map_monoreads_to_monoassembly
from sequence_graph.db_graph import DeBruijnGraph
from sequence_graph.idb_graph import get_idb_monostring_set
from sequence_graph.db_graph_3col import DeBruijnGraph3Color
from sequence_graph.path_graph import PathDeBruijnGraph, IDBMappings
from sd_parser.sd_parser import SD_Report

from utils.various import filter_sublsts_n2_dict, index


In [4]:
k = 300
cen = 'X'

In [1144]:
now = datetime.datetime.now()
date = f'{now.year}{now.month:02}{now.day:02}'
date

'20200928'

In [6]:
## cen6
# monomers_fn = '/Poppy/abzikadze/centroFlye/centroFlye_repo/data/monomers/cen6/cen6_monomers_w_hybrids_chm13.fasta'
# sd_report_assembly_fn = '/Poppy/abzikadze/centroFlye/centroFlye_repo/experiments/20200808/cF_cen6_mid1/SD_assembly/final_decomposition.tsv'
# assembly_fn = '/Poppy/abzikadze/centroFlye/centroFlye_repo/data-share/assemblies/human/chm13/cen6/cFM6_v0.3.0_polish_combined.fasta'
# sd_report_assembly = SD_Report(sd_report_fn=sd_report_assembly_fn, monomers_fn=monomers_fn, sequences_fn=assembly_fn, mode='assembly')


In [7]:
## cenX
monomers_fn = '/Poppy/abzikadze/centroFlye/centroFlye_repo/data/monomers/cenX/cenX_monomers_w_hybrids_chm13_20200617.fasta'
sd_report_assembly_fn = '/Poppy/abzikadze/centroFlye/centroFlye_repo/experiments/20200831/cF_cenX_mid1/SD_assembly/final_decomposition.tsv'
assembly_fn = '/Poppy/abzikadze/centroFlye/centroFlye_repo/data/assemblies/cenX/cFX_v0.8.3_tQ_57ef8d6.fasta'
sd_report_assembly = SD_Report(sd_report_fn=sd_report_assembly_fn, monomers_fn=monomers_fn, sequences_fn=assembly_fn, mode='assembly')


In [8]:
monoassembly = sd_report_assembly.monostring_set.monostrings
raw_monoassembly = next(iter(monoassembly.values())).raw_monostring
len(raw_monoassembly)
# raw_monoassembly

18133

In [9]:
# dbs_assembly, _ = get_idb_monostring_set(string_set=sd_report_assembly.monostring_set,
#                                          mink=k, maxk=k,
#                                          outdir=f'../../experiments/{date}/idb_cen{cen}_assembly_k{k}',
#                                          mode='assembly')

In [10]:
## cen6
# dbs_assembly = {}
# dbs_assembly[k] = \
#    DeBruijnGraph.from_pickle('/Poppy/abzikadze/centroFlye/centroFlye_repo/experiments/20200831/cF_cen6_mid1/idb/db_k300.pickle')



In [11]:
## cenX
dbs_assembly = {}
dbs_assembly[k] = \
   DeBruijnGraph.from_pickle('/Poppy/abzikadze/centroFlye/centroFlye_repo/experiments/20200831/cF_cenX_mid1/idb/db_k300.pickle')



In [12]:
## cen6
# reads_fn = '/Poppy/abzikadze/centroFlye/centroFlye_repo/data/reads/cen6/centromeric_reads_6_rel5__S1C6H1L.d6z1_rel5_rds.fasta'

# sd_report_fn = '/Poppy/abzikadze/centroFlye/centroFlye_repo/experiments/20200609/SD_a967a0a_cen6_chm13_rel5_with_hybrids/final_decomposition.tsv'

# sd_report = SD_Report(sd_report_fn=sd_report_fn, monomers_fn=monomers_fn, sequences_fn=reads_fn, mode='ont')



In [13]:
## cenX
reads_fn = '/Poppy/abzikadze/centroFlye/centroFlye_repo/data/reads/cenX/centromeric_reads_X_rel5.fasta'

sd_report_fn = '/Poppy/abzikadze/centroFlye/centroFlye_repo/experiments/20200617/SD_8327dec_cenX_rel5/final_decomposition.tsv'

sd_report = SD_Report(sd_report_fn=sd_report_fn, monomers_fn=monomers_fn, sequences_fn=reads_fn, mode='ont')



In [14]:
monoreads_set = sd_report.monostring_set

In [15]:
locations = map_monoreads_to_monoassembly(monoreads_set, monoassembly)

# New mappings

In [171]:
from itertools import count

In [1074]:
class IDBMapping:
    def __init__(self, mappings):
        self.mappings = mappings
    
    def add(self, st, en, new):
        for r_id, mapping in self.mappings.items():
            for i, a, b in zip(count(), mapping, mapping[1:]):
                if a == st and b == en:
                    self.mappings[r_id].insert(i+1, new)
    
    def remove(self, edge):
        for r_id, mapping in self.mappings.items():
            self.mappings[r_id] = list(filter(lambda e: e != edge, mapping))
    
    def merge(self, st, en):
        # merge en into st
        for r_id, mapping in self.mappings.items():
            if len(mapping) and mapping[0] == en:
                self.mappings[r_id][0] = st
        self.remove(en)
        
    def get_active_connections(self):
        ac = set()
        for mapping in self.mappings.values():
            for a, b in zip(mapping, mapping[1:]):
                ac.add((a, b))
        return ac

In [238]:
idb_mapping_new = IDBMapping({0: [0, 1, 2, 3], 1: [1, 3, 5, 6, 1, 2]})

In [239]:
idb_mapping_new.mappings

{0: [0, 1, 2, 3], 1: [1, 3, 5, 6, 1, 2]}

In [240]:
idb_mapping_new.add(1, 2, -1)

In [241]:
idb_mapping_new.mappings

{0: [0, 1, -1, 2, 3], 1: [1, 3, 5, 6, 1, -1, 2]}

In [242]:
idb_mapping_new.remove(-1)

In [243]:
idb_mapping_new.mappings

{0: [0, 1, 2, 3], 1: [1, 3, 5, 6, 1, 2]}

In [244]:
idb_mapping_new.merge(0, 1)

In [245]:
idb_mapping_new.mappings

{0: [0, 2, 3], 1: [0, 3, 5, 6, 2]}

# Light DB

In [None]:
        def process_complex_fast():
            # complex vertex
            all_ac = self.idb_mappings.get_active_connections()

            in_indexes = [self.edge2index[e_in] for e_in in inedges]
            out_indexes = [self.edge2index[e_out] for e_out in outedges]

            ac = defaultdict(list)

            for e_in in in_indexes:
                for e_out in out_indexes:
                    if (e_in, e_out) in all_ac:
                        ac[e_in].append(e_out)
                        ac[e_out].append(e_in)

            # unglue and merge isolates
            isolates = {(e_in, e_out) for e_in in in_indexes for e_out in out_indexes
                        if ac[e_in] == [e_out] and ac[e_out] == [e_in]}
            print('isolates', isolates)
            for e_index, e_outdex in isolates:
                e_in = self.index2edge[e_index]
                e_out = self.index2edge[e_outdex]
                in_seq = self.edge2seq[e_index]
                out_seq = self.edge2seq[e_outdex]
                assert in_seq[-self.k+1:] == out_seq[:self.k-1]
                seq = in_seq + out_seq[self.k-1:]
                self.nx_graph.remove_edge(*e_in)
                self.nx_graph.remove_edge(*e_out)
                key = self.nx_graph.add_edge(e_in[0], e_out[1])
                del self.edge2index[e_in]
                new_edge = (e_in[0], e_out[1], key)
                self.edge2index[new_edge] = e_index
                self.index2edge[e_index] = new_edge
                self.idb_mappings.merge(e_index, e_outdex)
                self.edge2seq[e_index] = seq
                del self.edge2seq[e_outdex]

            # create new nodes and redirect existing edges
            for index in in_indexes:
                if len(ac[index]) >= 2:
                    w = self.max_node_index
                    edge = self.index2edge[index]
                    print(index, edge)
                    self.nx_graph.remove_edge(*edge)
                    del self.edge2index[edge]
                    new_edge = (edge[0], w, 0)
                    self.edge2index[new_edge] = index
                    self.index2edge[index] = new_edge
                    self.nx_graph.add_edge(*new_edge)                        
                    self.max_node_index += 1
                    
            for index in out_indexes:
                if len(ac[index]) >= 2:
                    w = self.max_node_index
                    edge = self.index2edge[index]
                    print(index, edge)
                    self.nx_graph.remove_edge(*edge)
                    del self.edge2index[edge]
                    new_edge = (w, edge[1], 0)
                    self.edge2index[new_edge] = index
                    self.index2edge[index] = new_edge
                    self.nx_graph.add_edge(*new_edge)                        
                    self.max_node_index += 1

            # create new edges b/w new nodes
            for e_index in in_indexes:
                if len(ac[e_index]) >= 2:
                    for e_outdex in ac[e_index]:
                        if len(ac[e_outdex]) >= 2:
                            in_seq = self.edge2seq[e_index]
                            out_seq = self.edge2seq[e_outdex]
                            assert in_seq[-self.k+1:] == out_seq[:self.k-1]
                            seq = in_seq[-self.k:] + [out_seq[self.k-1]]
                            assert len(seq) == self.k + 1
                            in_edge = self.index2edge[e_index]
                            out_edge = self.index2edge[e_outdex]
                            new_edge = (in_edge[1], out_edge[0], 0)
                            self.nx_graph.add_edge(*new_edge)
                            self.edge2index[new_edge] = self.max_edge_index
                            self.index2edge[self.max_edge_index] = new_edge
                            self.edge2seq[self.max_edge_index] = seq
                            self.max_edge_index += 1

            # move and extend the rest of the edges
            for e_index in in_indexes:
                if len(ac[e_index]) == 1:
                    e_outdex = ac[e_index][0]
                    if len(ac[e_outdex]) == 1:
                        continue
                    in_edge = self.index2edge[e_index]
                    out_edge = self.index2edge[e_outdex]
                    in_seq = self.edge2seq[e_index]
                    out_seq = self.edge2seq[e_outdex]
                    assert in_seq[-self.k+1:] == out_seq[:self.k-1]
                    seq = in_seq + [out_seq[self.k-1]]
                    self.edge2seq[e_index] = seq
                    self.nx_graph.remove_edge(*in_edge)
                    new_edge = (in_edge[0], out_edge[0], 0)
                    self.nx_graph.add_edge(*new_edge)
                    del self.edge2index[in_edge]
                    self.edge2index[new_edge] = e_index
                    self.index2edge[e_index] = new_edge

            for e_outdex in out_indexes:
                if len(ac[e_outdex]) == 1:
                    e_index = ac[e_outdex][0]
                    if len(ac[e_index]) == 1:
                        continue
                    in_edge = self.index2edge[e_index]
                    out_edge = self.index2edge[e_outdex]
                    in_seq = self.edge2seq[e_index]
                    out_seq = self.edge2seq[e_outdex]
                    assert in_seq[-self.k+1:] == out_seq[:self.k-1]
                    seq = [in_seq[-self.k]] + out_seq
                    self.edge2seq[e_outdex] = seq
                    self.nx_graph.remove_edge(*out_edge)
                    new_edge = (in_edge[1], out_edge[1], 0)
                    self.nx_graph.add_edge(*new_edge)
                    del self.edge2index[out_edge]
                    self.edge2index[new_edge] = e_outdex
                    self.index2edge[e_outdex] = new_edge

            if self.nx_graph.in_degree(u) == self.nx_graph.out_degree(u) == 0:
                self.nx_graph.remove_node(u)

In [1125]:
class LightMappedDBGraph:
    def __init__(self, nx_graph,
                 edge2seq, edge2index, index2edge,
                 max_edge_index, max_node_index,
                 k,
                 idb_mappings):
        self.nx_graph = nx_graph
        self.edge2seq = edge2seq
        self.edge2index = edge2index
        self.index2edge = index2edge
        self.max_edge_index = max_edge_index
        self.max_node_index = max_node_index
        self.k = k
        self.idb_mappings = idb_mappings

    @classmethod
    def fromDB(cls, db, string_set, neutral_symbs=None, raw_mappings=None):
        if raw_mappings is None:
            raw_mappings = db.map_strings(string_set,
                only_unique_paths=True,
                neutral_symbs=neutral_symbs)

        nx_graph = nx.MultiDiGraph()
        edge2seq = {}
        edge_index = 0
        edge2index = {}
        index2edge = {}
        for u, v, key, data in db.nx_graph.edges(data=True, keys=True):
            nx_graph.add_edge(u, v, key)
            edge2index[(u, v, key)] = edge_index
            index2edge[edge_index] = (u, v, key)
            edge2seq[edge_index] = list(data['string'])
            edge_index += 1
            
        mappings = {}
        for r_id, (raw_mapping, _, _) in raw_mappings.items():
            mappings[r_id] = [edge2index[edge] for edge in raw_mapping]
        
        idb_mappings = IDBMapping(mappings)
        max_node_index = 1 + max(nx_graph.nodes)

        return cls(nx_graph=nx_graph,
                   edge2seq=edge2seq,
                   edge2index=edge2index,
                   index2edge=index2edge,
                   max_edge_index=edge_index,
                   max_node_index=max_node_index,
                   k=db.k,
                   idb_mappings=idb_mappings)
        
    def move_edge(self, e1_st, e1_en, e1_key,
                  e2_st, e2_en, e2_key=None):
        old_edge = (e1_st, e1_en, e1_key)
        i = self.edge2index[old_edge]
        self.nx_graph.remove_edge(*old_edge)
        e2_key = self.nx_graph.add_edge(e2_st, e2_en, key=e2_key)
        new_edge = (e2_st, e2_en, e2_key)
        self.edge2index[new_edge] = i
        del self.edge2index[old_edge]
        self.index2edge[i] = new_edge
    
    def remove_edge(self, edge):
        self.idb_mappings.remove(self.edge2index[edge])
        self.nx_graph.remove_edge(*edge)
        index = self.edge2index[edge]
        del self.edge2index[edge]
        del self.index2edge[index]
        del self.edge2seq[index]
        
        for e in list(self.nx_graph.in_edges(edge[1], keys=True)):
            self.move_edge(*e, e[0], edge[0])
        
        for e in list(self.nx_graph.out_edges(edge[1], keys=True)):
            self.move_edge(*e, edge[0], e[1])
    
    def get_new_vertex_index(self):
        self.max_node_index += 1
        return self.max_node_index - 1
    
    def merge_edges(self, e1, e2):
        # merge edge e2 into e1
        i = self.edge2index[e1]
        j = self.edge2index[e2]
        self.idb_mappings.merge(e1, e2)
        in_seq = self.edge2seq[i]
        out_seq = self.edge2seq[j]
        assert in_seq[-self.k+1:] == out_seq[:self.k-1]
        seq = in_seq + out_seq[self.k-1:]
        self.edge2seq[i] = seq
        self.remove_edge(e2)
        self.move_edge(*e1, e1[0], e2[1])
        
    def add_edge(self, i, j, seq):
        in_edge = self.index2edge[i]
        out_edge = self.index2edge[j]
        new_edge = (in_edge[1], out_edge[0])
        key = self.nx_graph.add_edge(*new_edge)
        new_edge = (*new_edge, key)
        self.edge2index[new_edge] = self.max_edge_index
        self.index2edge[self.max_edge_index] = new_edge
        self.edge2seq[self.max_edge_index] = seq
        self.max_edge_index += 1
        
    def __process_vertex(self, u):
        def process_simple():
            if indegree == 1 and outdegree == 1:
                # node on nonbranching path - should not be happening
                assert False

            if indegree == 0 and outdegree == 0:
                # isolate - should be removed
                self.nx_graph.remove_node(u)

            elif indegree == 0 and outdegree > 0:
                # starting vertex
                for j in out_indexes[1:]:
                    old_edge = self.index2edge[j]
                    new_edge = (self.get_new_vertex_index(), old_edge[1], 0)
                    self.move_edge(*old_edge, *new_edge)

            elif indegree > 0 and outdegree == 0:
                # ending vertex
                for i in in_indexes[1:]:
                    old_edge = self.index2edge[i]
                    new_edge = (old_edge[0], self.get_new_vertex_index(), 0)
                    self.move_edge(*old_edge, *new_edge)

            elif indegree == 1 and outdegree > 1:
                # simple 1-in vertex
                assert len(in_indexes) == 1
                in_index = in_indexes[0]
                in_seq = self.edge2seq[in_index]
                c = in_seq[-self.k]
                for j in out_indexes:
                    assert self.edge2seq[j][:self.k-1] == in_seq[-self.k+1:]
                    self.edge2seq[j].insert(0, c)

            elif indegree > 1 and outdegree == 1:
                # simple 1-out vertex
                assert len(out_indexes) == 1
                out_index = out_indexes[0]
                out_seq = self.edge2seq[out_index]
                c = out_seq[self.k-1]
                for i in in_indexes:
                    assert self.edge2seq[i][-self.k+1:] == out_seq[:self.k-1]
                    self.edge2seq[i].append(c)

        def process_complex():
            # complex vertex
            for i in in_indexes:
                old_edge = self.index2edge[i]
                new_edge = (old_edge[0], self.get_new_vertex_index(), 0)
                self.move_edge(*old_edge, *new_edge)

            for j in out_indexes:
                old_edge = self.index2edge[j]
                new_edge = (self.get_new_vertex_index(), old_edge[1], 0)
                self.move_edge(*old_edge, *new_edge)
            
            all_ac = self.idb_mappings.get_active_connections()
            ac_s2e = defaultdict(set)
            ac_e2s = defaultdict(set)
            for e_in in in_indexes:
                for e_out in out_indexes:
                    if (e_in, e_out) in all_ac:
                        ac_s2e[e_in].add(e_out)
                        ac_e2s[e_out].add(e_in)

            merged = {}
            for i in ac_s2e:
                for j in ac_s2e[i]:
                    if i in merged:
                        i = merged[i]
                    if j in merged:
                        j = merged[j]
                    e_i = self.index2edge[i]
                    e_j = self.index2edge[j]
                    in_seq = self.edge2seq[i]
                    out_seq = self.edge2seq[j]
                    assert in_seq[-self.k+1:] == out_seq[:self.k-1]
                    if len(ac_s2e[i]) == len(ac_e2s[j]) == 1:
                        self.merge_edges(e_i, e_j)
                        merged[j] = i
                    elif len(ac_s2e[i]) >= 2 and len(ac_e2s[j]) >= 2:
                        seq = in_seq[-self.k:] + [out_seq[self.k-1]]
                        assert len(seq) == self.k + 1
                        self.add_edge(i, j, seq)
                    elif len(ac_s2e[i]) == 1 and len(ac_e2s[j]) >= 2:
                        # extend left edge to the right
                        self.move_edge(*e_i, e_i[0], e_j[0])
                        seq = in_seq + [out_seq[self.k-1]]
                        self.edge2seq[i] = seq
                    elif len(ac_e2s[j]) == 1 and len(ac_s2e[i]) >= 2:
                        # extend right edge to the left
                        self.move_edge(*e_j, e_i[1], e_j[1])
                        seq = [in_seq[-self.k]] + out_seq
                        self.edge2seq[j] = seq
                    else:
                        print(len(ac[i]), len(ac[j]))
                        assert False
            
            self.nx_graph.remove_node(u)
        
        in_indexes = [self.edge2index[e_in]
                      for e_in in self.nx_graph.in_edges(u, keys=True)]
        out_indexes = [self.edge2index[e_out]
                       for e_out in self.nx_graph.out_edges(u, keys=True)]

        indegree = self.nx_graph.in_degree(u)
        outdegree = self.nx_graph.out_degree(u)

        if indegree >= 2 and outdegree >= 2:
            process_complex()
        else:
            process_simple()

    def assert_validity(self):
        self.max_edge_index == 1 + max(self.index2edge)
        self.max_node_index == 1 + max(self.nx_graph.nodes)
        edges = set(self.nx_graph.edges(keys=True))
        assert edges == set(self.edge2index.keys())
        assert edges == set(self.index2edge.values())
        assert all(edge == self.index2edge[self.edge2index[edge]]
                   for edge in self.edge2index)
        for node in self.nx_graph.nodes:
            for in_edge in self.nx_graph.in_edges(node, keys=True):
                e_index = self.edge2index[in_edge]
                in_seq = self.edge2seq[e_index]
                for out_edge in self.nx_graph.out_edges(node, keys=True):
                    e_outdex = self.edge2index[out_edge]
                    out_seq = self.edge2seq[e_outdex]
                    assert in_seq[-self.k+1:] == out_seq[:self.k-1]
        
    def increase_k_by_one(self):           
        for u in list(self.nx_graph.nodes):
            self.__process_vertex(u)

        self.k += 1
        collapsed_edges = [edge for edge in self.nx_graph.edges
                           if len(self.edge2seq[self.edge2index[edge]]) == self.k - 1]
        [self.remove_edge(edge) for edge in collapsed_edges]  # remove collapsed edges
        
        self.assert_validity()
        
    def toDB(self):
        nx_graph = nx.MultiDiGraph()
        nodeindex2label = {}
        nodelabel2index = {}
        for i, (u, v, key) in self.index2edge.items():
            print(u, v, key)
            seq = tuple(self.edge2seq[i])
            u_label = seq[:self.k-1]
            v_label = seq[-self.k+1:]
            nodelabel2index[u_label] = u
            nodelabel2index[v_label] = v
            nodeindex2label[u] = u_label
            nodeindex2label[v] = v_label
            edge_len = len(seq) - self.k + 1
            cov = [1] * edge_len
            mean_cov = np.mean(cov)
            label = f'index={i}\nlen={edge_len}\ncov={mean_cov:0.2f}'
            nx_graph.add_edge(u, v, key=key,
                              coverage=cov,
                              edge_index=i,
                              edge_len=edge_len,
                              label=label,
                              string=seq,
                              color='black')
        print(nodelabel2index)
        print(nx_graph.nodes)
        print(len(nodelabel2index), len(nodelabel2index), len(nx_graph.nodes))
        db = DeBruijnGraph(k=self.k,
                           nx_graph=nx_graph,
                           nodeindex2label=nodeindex2label,
                           nodelabel2index=nodelabel2index)
        return db

In [1129]:
# graph with a loop

class DBLoop3:
    k=3
    nx_graph = nx.MultiDiGraph()
    nx_graph.add_edge(0, 1, string='CCAA')
    nx_graph.add_edge(1, 1, string='AACAA')
    nx_graph.add_edge(1, 1, string='AAGAA')
    nx_graph.add_edge(1, 2, string='AATT')
    nx_graph.add_edge(3, 1, string='GGAA')
    nx_graph.add_edge(1, 4, string='AARR')

lightdb = LightMappedDBGraph.fromDB(DBLoop3(), string_set={},
                                    raw_mappings={0: ([(0, 1, 0), (1, 1, 0), (1, 2, 0)], 0, 0),
                                                  1: ([(3, 1, 0), (1, 1, 1), (1, 1, 1), (1, 4, 0)], 0, 0),
                                                  2: ([(3, 1, 0), (1, 2, 0)], 0, 0)})
lightdb.increase_k_by_one()
assert [(edge, lightdb.edge2index[edge]) for edge in lightdb.nx_graph.edges] == \
    [((0, 11, 0), 0), ((3, 8, 0), 5), ((7, 10, 0), 6), ((7, 4, 0), 4), ((8, 10, 0), 7), ((8, 11, 0), 8),
     ((10, 7, 0), 2), ((11, 2, 0), 3)]
lightdb.edge2seq
assert lightdb.edge2seq == \
    {0: ['C', 'C', 'A', 'A', 'C', 'A', 'A', 'T'],
     2: ['A', 'A', 'G', 'A', 'A'],
     3: ['A', 'A', 'T', 'T'],
     4: ['G', 'A', 'A', 'R', 'R'],
     5: ['G', 'G', 'A', 'A'],
     6: ['G', 'A', 'A', 'G'],
     7: ['G', 'A', 'A', 'G'],
     8: ['G', 'A', 'A', 'T']}

In [1063]:
# graph with a loop

class DBLoop2:
    k=3
    nx_graph = nx.MultiDiGraph()
    nx_graph.add_edge(0, 1, string='CCAA')
    nx_graph.add_edge(1, 1, string='AACAA')
    nx_graph.add_edge(1, 1, string='AAGAA')
    nx_graph.add_edge(1, 2, string='AATT')

lightdb = LightMappedDBGraph.fromDB(DBLoop2(), string_set={},
                                    raw_mappings={0: ([(0, 1, 0), (1, 1, 0), (1, 2, 0)], 0, 0),
                                                  1: ([(0, 1, 0), (1, 1, 1), (1, 2, 0)], 0, 0)})
lightdb.increase_k_by_one()
assert [(edge, lightdb.edge2index[edge]) for edge in lightdb.nx_graph.edges] == \
    [((0, 3, 0), 0), ((3, 8, 0), 1), ((3, 8, 1), 2), ((8, 2, 0), 3)]
lightdb.edge2seq
assert lightdb.edge2seq == \
    {0: ['C', 'C', 'A', 'A'],
     1: ['C', 'A', 'A', 'C', 'A', 'A', 'T'],
     2: ['C', 'A', 'A', 'G', 'A', 'A', 'T'],
     3: ['A', 'A', 'T', 'T']}

In [1064]:
# graph with a loop

class DBLoop:
    k=3
    nx_graph = nx.MultiDiGraph()
    nx_graph.add_edge(0, 1, string='CCAA')
    nx_graph.add_edge(1, 1, string='AACAA')
    nx_graph.add_edge(1, 2, string='AATT')

lightdb = LightMappedDBGraph.fromDB(DBLoop(), string_set={},
                                    raw_mappings={0: ([(0, 1, 0), (1, 1, 0)], 0, 0),
                                                  1: ([(1, 1, 0), (1, 2, 0)], 0, 0)})
lightdb.increase_k_by_one()
assert [(edge, lightdb.edge2index[edge]) for edge in lightdb.nx_graph.edges] == \
    [((0, 2, 0), 0)]
lightdb.edge2seq
assert lightdb.edge2seq == \
    {0: ['C', 'C', 'A', 'A', 'C', 'A', 'A', 'T', 'T']}

In [1065]:
# graph with a complex vertex

class DBComplexVertex1:
    k=3
    nx_graph = nx.MultiDiGraph()
    nx_graph.add_edge(0, 2, string='ACAAA')
    nx_graph.add_edge(1, 2, string='GGAAA')
    nx_graph.add_edge(2, 3, string='AATGC')
    nx_graph.add_edge(2, 4, string='AATT')

lightdb = LightMappedDBGraph.fromDB(DBComplexVertex1(), string_set={},
                                    raw_mappings={0: ([(0, 2, 0), (2, 3, 0)], 0, 0),
                                                  1: ([(0, 2, 0), (2, 4, 0)], 0, 0),
                                                  2: ([(1, 2, 0), (2, 4, 0)], 0, 0)})
assert [(edge, lightdb.edge2index[edge]) for edge in lightdb.nx_graph.edges] == \
    [((0, 2, 0), 0), ((2, 3, 0), 1), ((2, 4, 0), 2), ((1, 2, 0), 3)]
lightdb.increase_k_by_one()
print([(edge, lightdb.edge2index[edge]) for edge in lightdb.nx_graph.edges])
print(lightdb.edge2seq)
assert [(edge, lightdb.edge2index[edge]) for edge in lightdb.nx_graph.edges] == \
    [((0, 5, 0), 0), ((1, 8, 0), 3), ((5, 3, 0), 1), ((5, 8, 0), 4), ((8, 4, 0), 2)]
assert lightdb.edge2seq == \
    {0: ['A', 'C', 'A', 'A', 'A'],
     1: ['A', 'A', 'A', 'T', 'G', 'C'],
     2: ['A', 'A', 'T', 'T'],
     3: ['G', 'G', 'A', 'A', 'A', 'T'],
     4: ['A', 'A', 'A', 'T']}

[((0, 5, 0), 0), ((1, 8, 0), 3), ((5, 3, 0), 1), ((5, 8, 0), 4), ((8, 4, 0), 2)]
{0: ['A', 'C', 'A', 'A', 'A'], 1: ['A', 'A', 'A', 'T', 'G', 'C'], 2: ['A', 'A', 'T', 'T'], 3: ['G', 'G', 'A', 'A', 'A', 'T'], 4: ['A', 'A', 'A', 'T']}


In [1066]:
# graph with a complex vertex

class DBComplexVertex3:
    k=3
    nx_graph = nx.MultiDiGraph()
    nx_graph.add_edge(0, 2, string='CCAC')
    nx_graph.add_edge(1, 2, string='TTAC')
    nx_graph.add_edge(2, 5, string='ACG')
    nx_graph.add_edge(3, 5, string='AACG')
    nx_graph.add_edge(5, 4, string='CGTA')
    nx_graph.add_edge(5, 6, string='CGA')
    nx_graph.add_edge(6, 7, string='GACC')
    nx_graph.add_edge(6, 8, string='GATT')
    

lightdb = LightMappedDBGraph.fromDB(DBComplexVertex3(), string_set={},
                                    raw_mappings={0: ([(0, 2, 0), (2, 5, 0), (5, 6, 0), (6, 7, 0)], 0, 0),
                                                  1: ([(1, 2, 0), (2, 5, 0), (5, 4, 0)], 0, 0),
                                                  2: ([(3, 5, 0), (5, 6, 0), (6, 8, 0)], 0, 0)})
assert [(edge, lightdb.edge2index[edge]) for edge in lightdb.nx_graph.edges] == \
    [((0, 2, 0), 0), ((2, 5, 0), 1), ((5, 4, 0), 3), ((5, 6, 0), 4), ((1, 2, 0), 2),
     ((6, 7, 0), 6), ((6, 8, 0), 7), ((3, 5, 0), 5)]
lightdb.increase_k_by_one()
assert [(edge, lightdb.edge2index[edge]) for edge in lightdb.nx_graph.edges] == \
    [((0, 2, 0), 0), ((2, 4, 0), 3), ((2, 12, 0), 8), ((1, 2, 0), 2), ((3, 12, 0), 5),
     ((12, 7, 0), 6), ((12, 8, 0), 7)]
assert lightdb.edge2seq == \
    {0: ['C', 'C', 'A', 'C', 'G'],
     2: ['T', 'T', 'A', 'C', 'G'],
     3: ['A', 'C', 'G', 'T', 'A'],
     5: ['A', 'A', 'C', 'G', 'A'],
     6: ['C', 'G', 'A', 'C', 'C'],
     7: ['C', 'G', 'A', 'T', 'T'],
     8: ['A', 'C', 'G', 'A']}

In [1067]:
# graph with a complex vertex

class DBComplexVertex2:
    k=3
    nx_graph = nx.MultiDiGraph()
    nx_graph.add_edge(0, 2, string='ACAAA')
    nx_graph.add_edge(1, 2, string='GGAAA')
    nx_graph.add_edge(2, 3, string='AATGC')
    nx_graph.add_edge(2, 4, string='AATT')

lightdb = LightMappedDBGraph.fromDB(DBComplexVertex2(), string_set={},
                                    raw_mappings={0: ([(0, 2, 0), (2, 3, 0)], 0, 0),
                                                  1: ([(0, 2, 0), (2, 4, 0)], 0, 0),
                                                  2: ([(1, 2, 0), (2, 4, 0)], 0, 0),
                                                  3: ([(1, 2, 0), (2, 3, 0)], 0, 0)})
assert [(edge, lightdb.edge2index[edge]) for edge in lightdb.nx_graph.edges] == \
    [((0, 2, 0), 0), ((2, 3, 0), 1), ((2, 4, 0), 2), ((1, 2, 0), 3)]
lightdb.increase_k_by_one()
assert [(edge, lightdb.edge2index[edge]) for edge in lightdb.nx_graph.edges] == \
    [((0, 5, 0), 0), ((1, 6, 0), 3), ((5, 7, 0), 4), ((5, 8, 0), 5), ((6, 7, 0), 6),
     ((6, 8, 0), 7), ((7, 3, 0), 1), ((8, 4, 0), 2)]
assert lightdb.edge2seq == \
    {0: ['A', 'C', 'A', 'A', 'A'],
     1: list('AATGC'),
     2: list('AATT'),
     3: ['G', 'G', 'A', 'A', 'A'],
     4: list('AAAT'),
     5: list('AAAT'),
     6: list('AAAT'),
     7: list('AAAT')}

In [1068]:
# graph starting vertex

class DBStVertex:
    k=3
    nx_graph = nx.MultiDiGraph()
    nx_graph.add_edge(0, 1, string='AAAAA')
    nx_graph.add_edge(0, 2, string='AAACA')
    nx_graph.add_edge(0, 3, string='AAA')    

    
lightdb = LightMappedDBGraph.fromDB(DBStVertex(), string_set={}, raw_mappings={0: ([(0, 1, 0)], 0, 0),
                                                                               1: ([(0, 2, 0)], 0, 0),
                                                                               2: ([(0, 3, 0)], 0, 0)})
assert lightdb.idb_mappings.mappings == {0: [0], 1: [1], 2: [2]}
lightdb.increase_k_by_one()
assert [(edge, lightdb.edge2index[edge]) for edge in lightdb.nx_graph.edges] == \
        [((0, 1, 0), 0), ((4, 2, 0), 1)]
assert lightdb.idb_mappings.mappings == {0: [0], 1: [1], 2: []}

In [1069]:
# graph ending vertex

class DBEnVertex:
    k=3
    nx_graph = nx.MultiDiGraph()
    nx_graph.add_edge(0, 3, string='AAAA')
    nx_graph.add_edge(1, 3, string='AACA')
    nx_graph.add_edge(2, 3, string='AAA')
    


lightdb = LightMappedDBGraph.fromDB(DBEnVertex(), string_set={}, raw_mappings={0: ([(0, 3, 0)], 0, 0),
                                                                               1: ([(1, 3, 0)], 0, 0),
                                                                               2: ([(2, 3, 0)], 0, 0)})
assert lightdb.idb_mappings.mappings == {0: [0], 1: [1], 2: [2]}
lightdb.increase_k_by_one()
assert [(edge, lightdb.edge2index[edge]) for edge in lightdb.nx_graph.edges] == \
        [((0, 3, 0), 0), ((1, 4, 0), 1)]
assert lightdb.idb_mappings.mappings == {0: [0], 1: [1], 2: []}

In [1070]:
# graph 1-in >1-out

class DB1inVertex:
    k=3
    nx_graph = nx.MultiDiGraph()
    nx_graph.add_edge(0, 1, string='AACAG')
    nx_graph.add_edge(1, 2, string='AGACC')
    nx_graph.add_edge(1, 3, string='AGATT')
    nx_graph.add_edge(1, 4, string='AGAGG')
    

lightdb = LightMappedDBGraph.fromDB(DB1inVertex(), string_set={}, raw_mappings={})
assert [lightdb.edge2seq[lightdb.edge2index[edge]] for edge in lightdb.nx_graph.edges] == \
    [list('AACAG'), ['A', 'G', 'A', 'C', 'C'], ['A', 'G', 'A', 'T', 'T'], ['A', 'G', 'A', 'G', 'G']]
lightdb.increase_k_by_one()
assert [lightdb.edge2seq[lightdb.edge2index[edge]] for edge in lightdb.nx_graph.edges] == \
    [list('AACAG'), ['C', 'A', 'G', 'A', 'C', 'C'], ['C', 'A', 'G', 'A', 'T', 'T'], ['C', 'A', 'G', 'A', 'G', 'G']]

In [1071]:
# graph >1-in 1-out

class DB1outVertex:
    k=3
    nx_graph = nx.MultiDiGraph()
    nx_graph.add_edge(0, 3, string='CCAGA')
    nx_graph.add_edge(1, 3, string='TTAGA')
    nx_graph.add_edge(2, 3, string='GGAGA')
    nx_graph.add_edge(3, 4, string='GAAAA')
    

lightdb = LightMappedDBGraph.fromDB(DB1outVertex(), string_set={}, raw_mappings={})
assert sorted([lightdb.edge2seq[lightdb.edge2index[edge]] for edge in lightdb.nx_graph.edges]) == \
    sorted([list('CCAGA'), list('TTAGA'), list('GGAGA'), list('GAAAA')])
lightdb.increase_k_by_one()
assert sorted([lightdb.edge2seq[lightdb.edge2index[edge]] for edge in lightdb.nx_graph.edges]) == \
    sorted([list('CCAGAA'), list('TTAGAA'), list('GGAGAA'), list('GAAAA')])

In [1164]:
# real graph

lightdb = LightMappedDBGraph.fromDB(dbs_assembly[k], string_set=monoreads_set, neutral_symbs=set('?'),
                                    raw_mappings=mappings)

In [1165]:
K = 302

In [1166]:
for i in range(K-k):
    print(i, lightdb.increase_k_by_one())

0 None
1 None


In [1167]:
transformed_db = lightdb.toDB()

3791 3804 0
3790 3791 1
3804 3803 0
3804 753 0
3811 106 0
3795 3791 0
3795 486 0
3803 3790 0
106 3805 0
106 3787 0
106 2728 0
486 106 0
753 955 0
955 1039 0
955 1039 1
1039 1073 0
1073 753 0
1073 3811 0
3787 3795 0
3787 486 0
3790 3791 0
3806 3795 0
106 3803 0
3795 3811 0
{(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 

In [1168]:
# transformed_db.write_dot(f'../../experiments/{date}/transformed_test', compact=True)

# Testing correctness

In [1169]:
dbs_assembly_reference, _ = \
    get_idb_monostring_set(string_set=sd_report_assembly.monostring_set,
                           mink=K, maxk=K,
                           outdir=None,
                           mode='assembly')

In [1170]:
color3graph = DeBruijnGraph3Color.from_db_graphs(gr_assembly=dbs_assembly_reference[K], gr_reads=transformed_db)

In [1171]:
date

'20200928'

In [1172]:
color3graph.write_dot(f'../../experiments/{date}/c3g_test', compact=True)

In [1173]:
!realpath ../../experiments/{date}/c3g_test

/Poppy/abzikadze/centroFlye/centroFlye_repo/experiments/20200928/c3g_test
