### Code generation for forward substitution

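1. Download the `L` and `b` matrices (Matrix Market `.mtx` files) into `../data`.
2. Rename the files to `../data/{foo}L.mtx` for `L` and `../data/{foo}b.mtx` for `b`.
3. In the next cell, set `mname` so that `matr` (defined as `s_{mname}`) equals `foo`; the output file `opath` is derived from `mname` too — change it if you want a different output path.
4. Run all the cells.
5. Declare the generated function `lsolve_reachset_{mname}` in the header files and recompile the code.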

In [None]:
# Name of the matrix to generate code for.
mname = "tsopf"
# Input files are read as ../data/{matr}L.mtx and ../data/{matr}b.mtx.
matr = "s_" + mname
# Destination of the generated C++ source.
opath = "../src/codegen_" + mname + ".cpp"

In [None]:
import scipy.io as sio
from scipy import sparse
# Directory holding the input Matrix Market files.  The trailing slash is
# required: input paths are formed by appending the file name directly.
data_folder = "../data/"

In [None]:
def find_reachset(n, Gp, Gi, Gx, Bp, Bi, Bx):
    """Return every vertex reachable from the nonzeros of b, in ascending order.

    Runs an iterative depth-first search over the graph of the CSC matrix G
    (column pointers ``Gp``, row indices ``Gi``): an edge v -> Gi[k] exists
    for every stored entry k of column v.  The search is seeded with the row
    indices of the stored entries in the *first* column of the CSC matrix B
    (``Bp[0]`` .. ``Bp[1]`` into ``Bi``).  ``Lsolver`` uses the result as the
    set of columns of L that take part in the forward substitution.

    Args:
        n: number of vertices (rows/columns of G).
        Gp, Gi: CSC index-pointer and index arrays of G.
        Gx: values of G (accepted for symmetry, not used).
        Bp, Bi: CSC index-pointer and index arrays of B.
        Bx: values of B (accepted for symmetry, not used).

    Returns:
        A list of the reached vertex indices, sorted ascending.
    """
    visited = [False] * n
    # Seed the DFS stack with the row indices of column 0 of B.
    stack = [Bi[k] for k in range(Bp[0], Bp[1])]

    while stack:
        node = stack.pop()
        if visited[node]:
            continue
        visited[node] = True
        # Push the not-yet-visited neighbours (row indices of column `node`).
        stack.extend(
            Gi[k] for k in range(Gp[node], Gp[node + 1]) if not visited[Gi[k]]
        )

    return [idx for idx, hit in enumerate(visited) if hit]

# [2, 6,7, 9,10] -> [2], [6,7], [9,10]
def consec(lst):
    """Yield the maximal runs of consecutive integers in *lst*.

    A new run starts whenever an element is not exactly one more than its
    predecessor, e.g. ``[2, 6, 7, 9, 10]`` yields ``[2]``, ``[6, 7]``,
    ``[9, 10]``.  An empty input yields nothing.

    Args:
        lst: an iterable of integers.

    Yields:
        Lists of consecutive integers, in input order.
    """
    it = iter(lst)
    try:
        prev = next(it)
    except StopIteration:
        # Empty input.  Letting StopIteration escape a generator body would be
        # converted into RuntimeError (PEP 479), so stop explicitly instead.
        return
    run = [prev]
    for ele in it:
        if ele == prev + 1:
            run.append(ele)
        else:
            yield run
            run = [ele]
        prev = ele
    yield run
    
# [2, 6,7, 9,10] -> [[2, 3], [6, 8], [9, 11]]
def consec_range(lst):
    """Convert runs of consecutive integers into half-open ``[start, end)`` pairs.

    Each run produced by ``consec(lst)`` becomes ``[min(run), max(run) + 1]``.
    """
    ranges = []
    for run in consec(lst):
        ranges.append([min(run), max(run) + 1])
    return ranges


class ctypes:
    """C type names used when emitting variable declarations.

    NOTE: this name shadows the standard-library ``ctypes`` module within
    this notebook.
    """

    cint = "int"
    cfloat = "float"
    cdouble = "double"
    

# Small builders for C/C++ source fragments.  Each returns a string.

def add(x, y):
    return f"{x} + {y}"

def var(t, x):
    """Declaration ``t x`` (no terminating semicolon)."""
    return f"{t} {x}"

def access(A, i):
    return f"{A}[{i}]"

def assign(a, b):
    return f"{a} = {b}"

def semicolon():
    return f";"

def nextline():
    return f"\n"

def div(a, b):
    return f"{a} / {b}"

def divassign(a, b):
    return f"{a} /= {b}"

def subassign(a, b):
    return f"{a} -= {b}"

def mult(a, b):
    return f"{a} * {b}"

def preinc(x):
    return f"++{x}"

def lt(x, y):
    return f"{x} < {y}"

def gt(x, y):
    return f"{x} > {y}"

def leq(x, y):
    return f"{x} <= {y}"

def geq(x, y):
    return f"{x} >= {y}"

def block(bs):
    """Render each statement string in *bs* on its own line, indented by four
    spaces and terminated with a semicolon."""
    lines = []
    for stmt in bs:
        lines.append("    " + stmt + semicolon() + nextline())
    return "".join(lines)

def loop(i, t, u, bs):
    """Render ``for (i; t; u) {`` followed by the body *bs* verbatim and ``}``."""
    return f"for ({i}; {t}; {u}) {{\n{bs}}}"


class Lsolver(object):
    """Generate specialized C++ for a sparse lower-triangular solve L x = b.

    L and b are read from Matrix Market files.  Only the columns of L in the
    reach set of b's nonzeros (see ``find_reachset``) are processed.  Columns
    with more than two stored nonzeros are "peeled" into straight-line code
    with their CSC offsets hard-coded; runs of the remaining columns are
    emitted as loops over ``reachset``.
    """

    def __init__(self, Lp, bp):
        """Load L and b and partition the reach set into peeled/unpeeled parts.

        Args:
            Lp: path to the Matrix Market file holding L.
            bp: path to the Matrix Market file holding b (only its first
                column seeds the reach set).

        Attributes set:
            L, b: the matrices in CSC format.
            Lp, Li, Lx / bp, bi, bx: their CSC index-pointer, index and data
                arrays.
            reachset: ascending list of column indices of L to process.
            peeled: column indices of L emitted as straight-line code.
            unpeeled: list of ``[start, end)`` ranges of *positions* in
                ``reachset`` emitted as loops.
        """
        # csc_matrix(...) also accepts the dense ndarray that mmread returns
        # for "array"-format files, where calling ``.tocsc()`` would fail.
        L = sparse.csc_matrix(sio.mmread(Lp))
        b = sparse.csc_matrix(sio.mmread(bp))
        # The generated code reads L[j, j] as Lx[Lp[j]], i.e. it expects the
        # diagonal to be the first stored entry of each column.  For a
        # lower-triangular L that holds once row indices are sorted.
        L.sort_indices()
        self.L = L
        self.b = b

        self.Lp = L.indptr
        self.Li = L.indices
        self.Lx = L.data
        self.bp = b.indptr
        self.bi = b.indices
        self.bx = b.data

        self.reachset = find_reachset(
            b.shape[0], self.Lp, self.Li, self.Lx, self.bp, self.bi, self.bx)

        def col_nnz(j):
            return self.Lp[j + 1] - self.Lp[j]

        # Peel columns with more than 2 stored nonzeros (type: [colidx, ...]);
        # record the others by their *position* in reachset.
        self.peeled = []
        unpeeled = []
        for pos, col in enumerate(self.reachset):
            if col_nnz(col) > 2:
                self.peeled.append(col)
            else:
                unpeeled.append(pos)

        # Unpeeled loop ranges over reachset positions, type: [[init, end), ...].
        # Guarded so that an all-peeled reach set yields no ranges instead of
        # reaching consec_range with an empty list.
        self.unpeeled = consec_range(unpeeled) if unpeeled else []

        assert all(isinstance(col, int) for col in self.peeled)
        assert all(len(rng) == 2 for rng in self.unpeeled)

    def codegen(self):
        """Return the body of the generated C++ solve as a source string.

        Emits declarations for the loop variables, then walks the peeled
        columns and the unpeeled ranges in ascending column order: each
        peeled column becomes straight-line code, each unpeeled range a
        ``for`` loop over ``reachset``.  The emitted code refers to ``Lp``,
        ``Li``, ``Lx``, ``x`` and ``reachset``, which must be in scope where
        it is inserted.
        """
        from collections import deque

        # Names of the C++ variables referenced by the generated code.
        p = "p"
        px = "px"
        j = "j"
        Lp = "Lp"
        Li = "Li"
        Lx = "Lx"
        x = "x"
        rs = "reachset"

        def lsolve_update_col(col):
            """Statements that solve x[col] and scatter it into column `col`.

            ``col`` is either a C variable name (unpeeled: offsets are read
            from Lp at run time) or a concrete column index (peeled: offsets
            are hard-coded).  Assumes the diagonal is the column's first entry.
            """
            if isinstance(col, str):
                start = add(access(Lp, col), 1)
                end = access(Lp, add(col, 1))
                inz = access(Lp, col)
            else:
                start = self.Lp[col] + 1
                end = self.Lp[col + 1]
                inz = self.Lp[col]

            return [
                # 1. x_j /= L_{j,j}
                divassign(access(x, col), access(Lx, inz)),
                # 2. x_i -= L_{i,j} * x_j for the remaining entries of column j
                loop(
                    assign(p, start),
                    lt(p, end),
                    preinc(p),
                    block([
                        subassign(
                            access(x, access(Li, p)),
                            mult(access(Lx, p), access(x, col)),
                        ),
                    ]),
                ),
            ]

        stms = [
            var(ctypes.cint, p),
            var(ctypes.cint, px),
            var(ctypes.cint, j),
        ]

        # Both queues are in ascending column order; merge them by column.
        peeled = deque(self.peeled)
        unpeeled = deque(self.unpeeled)
        while peeled or unpeeled:
            if not unpeeled or (
                    peeled and peeled[0] < self.reachset[unpeeled[0][0]]):
                # Peeled column: straight-line code.
                stms += lsolve_update_col(peeled.popleft())
            else:
                # Unpeeled run: loop over its range of reachset positions.
                start, end = unpeeled.popleft()
                stms.append(
                    loop(
                        assign(px, start),
                        lt(px, end),
                        preinc(px),
                        block([assign(j, access(rs, px))]
                              + lsolve_update_col(j)),
                    )
                )

        return block(stms)
    

In [None]:
# Build the input paths, generate the solver body and write the C++ source.
Lp = f"{data_folder}{matr}L.mtx"
bp = f"{data_folder}{matr}b.mtx"
solver = Lsolver(Lp, bp)
if solver.L.shape[0] < 20:
    # Small enough to print the dense pattern for a visual check.
    print(solver.L.toarray().astype(int))
print(f"peeled:   {len(solver.peeled)}")
print(f"unpeeled: {len(solver.unpeeled)}")
print("\n\n")

# codegen() already indents every statement by four spaces and ends with a
# newline, so its output goes at column 0 directly before the closing brace.
with open(opath, "w", encoding="utf-8") as f:
    f.write("""
#include "triangular.h"
#include "minitrace.h"

void lsolve_reachset_{name}(
    int n, int* Lp, int* Li, double* Lx, double* x,
    std::vector<int> reachset)
{{
    MTR_SCOPE_FUNC();
{body}}}
""".format(name=mname, body=solver.codegen()))

