# Monomial matrices and their transformations

## Libraries

In [None]:
import sys; sys.path.append("../modules")
import random

import numpy as np
import sympy; sympy.init_printing()

import Permutations as pm
from Grid import *

from tqdm.notebook import tqdm
# ----- Debugger -----
# from IPython.core.debugger import Pdb; Pdb().set_trace()

## Monomial Matrix Class

In [None]:
class MonomialMatrix(dict):
    '''Monomial matrix stored as a pair (diagonal, permutation).

    The two factors are kept under the dict keys 'diagonal' and
    'permutation'.  The attribute ``matrix`` holds the sympy product
    ``diag(*diagonal) * Matrix(permutation.matrix())`` and ``size`` is
    ``len(diagonal)``.
    '''
    def __init__(self, diag, perm):
        '''Build the monomial matrix diag(diag) * P(perm).

        Raises TypeError unless `diag` is a list and `perm` is a
        ``pm.Permutation``; raises ValueError if ``len(diag) != perm.size``.
        '''
        if not (isinstance(diag, list) and isinstance(perm, pm.Permutation)):
            raise TypeError("diag must be a list and perm a pm.Permutation")
        if not len(diag) == perm.size: raise ValueError("len(diag) != perm.size")

        self['diagonal'] = diag
        self['permutation'] = perm
        self.size = len(diag)
        self.matrix = sympy.diag(*self['diagonal'])*sympy.Matrix(self['permutation'].matrix())

    def associate_PM(self):
        '''Return the permutation obtained by blowing each entry up into a block.

        Diagonal entry k, with x = |d[k]|, becomes x consecutive rows of a
        permutation matrix of total size sum(|d|).  Those rows are sent to
        the x consecutive columns of the block sitting at position
        ``p.act(k)`` (blocks laid out in target order via ``p.inverse()``):
        in increasing order when d[k] >= 0, reversed when d[k] < 0.
        The returned ``pm.Permutation`` lists, for each row, the column it
        is sent to.
        '''
        d, p = self['diagonal'], self['permutation']
        pinv = p.inverse()
        widths = [abs(v) for v in d]

        # offsets[t] = first column of the block at target position t.
        offsets, total = [], 0
        for t in range(len(d)):
            offsets.append(total)
            total += widths[pinv.act(t)]

        img = []
        for k, x in enumerate(widths):
            start = offsets[p.act(k)]
            if d[k] == x:   # non-negative entry: block keeps its orientation
                img.extend(range(start, start + x))
            else:           # negative entry: block is reversed
                img.extend(range(start + x - 1, start - 1, -1))
        return pm.Permutation(img)

    def show_in_grid(self):
        '''Print the matrix as an ASCII grid.

        Each non-zero entry is printed as its absolute value; negative
        entries are coloured red with ANSI escape codes.  Cells are widened
        so that multi-digit entries stay aligned (width 3 for single digits).
        '''
        rows = np.array(self.matrix)
        digits = max((len(str(abs(v))) for v in rows.flat if v != 0), default=1)
        width = max(3, digits + 2)
        hline = ("+" + "-" * width) * self.size + "+"
        blank = "|" + " " * width
        for row in rows:
            nonzero = np.where(row != 0)[0]
            print(hline)
            if nonzero.size == 0:
                print(blank * self.size + "|")
                continue
            idx = int(nonzero[0])
            val = row[idx]
            text = str(abs(val))
            shown = text if val > 0 else "\033[31m{}\033[0m".format(text)
            # pad using the visible length only (escape codes take no width)
            cell = "| " + shown + " " * (width - 1 - len(text))
            print(blank * idx + cell + blank * (self.size - idx - 1) + "|")
        print(hline)
#---    
MM = MonomialMatrix

### scratch

In [None]:
# Quick look at a 3x3 example: the monomial matrix, the permutation matrix
# of its associated permutation, and the grid rendering.
mm = MM([-4, 2, 3], pm.Permutation([1, 0, 2]))
aprm = mm.associate_PM()
pmtx = sympy.Matrix(aprm.matrix())

display(mm.matrix, pmtx)
mm.show_in_grid()

## Correctness checking function

In [None]:
def correctness_checking(code, restriction=''):
    '''Build a starting monomial matrix from `code` and its predicted reduction.

    Parameters
    ----------
    code : sequence of 4 ints
        (a, b, c, r).  The notebook's callers skip r == 0; with r == 0 the
        r<c branch would raise ZeroDivisionError at ``(a-r) % (2*r)``.
    restriction : str
        Substring filter on the case label: the resulting matrix is built
        only if `restriction` occurs in the label, otherwise the label
        becomes "(x)".

    Returns
    -------
    ([sdata, rdata], case)
        sdata / rdata are dicts with keys 'mm' (a MonomialMatrix, or None
        for rdata when no result was built), '#orbits' (number of cycles of
        the associated permutation, or None) and 'memo' (extra display text).
        `case` is the comma-separated label of the branches taken, or "(x)"
        when a < r or the label does not match `restriction`.
    '''
    a, b, c, r = code[0], code[1], code[2], code[3]

    #-- Starting MM --
    starting_mm = MM([-a,b,c,-a,-(a-r),c,b,-(a-r),-2*r], pm.Permutation([8,2,1,5,3,7,6,0,4]))
    sdata = {'mm': starting_mm, '#orbits': len(pm.cycle_decomp(starting_mm.associate_PM())), 'memo': ''}

    #-- Resulting MM --
    resulting_mm, rnum, memo = None, None, ''
    if a-r >= 0:
        case = "a>=r"
        if r-c == 0:
            case += ", r=c"
            if restriction in case:
                resulting_mm = MM([-r,-a,-(a-r),r,a-r,-2*r,2*b], pm.Permutation([4,2,0,3,5,1,6]))
                memo = ", b={}".format(b)
            else: case = "(x)"
        elif r-c > 0:
            case += ", r>c"
            X, Y = b%(r-c), (a-r)%(2*r)
            Xi = (r-c)-X
            if restriction in case:
                resulting_mm = MM([Xi,X,c,Xi,X,c,2*Y], pm.Permutation([3,2,1,6,5,4,0]))
            else: case = "(x)"
        else: # r-c < 0
            case += ", r<c"
            X, Y = b%(c-r), (a-r)%(2*r)
            Xi, Yi = (c-r)-X, (2*r)-Y
            if r-Xi == 0:
                case += ", r=Xi"
                if restriction in case:
                    resulting_mm = MM([-Y,-Yi,-2*r,Y,2*X], pm.Permutation([2,3,0,1,4]))
                    memo = ", (Y,r)+X=({},{})+{}={}+{}".format(Y, r, X, np.gcd(Y, r),X)
                else: case = "(x)"
            else: # r-Xi != 0
                # Append (not overwrite) so the "a>=r, r<c" prefix is kept,
                # like every other branch; otherwise restriction="r<c" would
                # silently drop these codes.
                case += ", r!=Xi"
                # Reduce Xi until Xi <= r.  Xi-r >= 1 inside the loop, and
                # the new Xi = (Xi-r) - X%(Xi-r) is >= 1 and < old Xi, so
                # the loop terminates (r >= 1 here).
                while r < Xi:
                    X = X%(Xi-r)
                    Xi = (Xi-r)-X
                if r-Xi == 0:
                    case += ", r%=Xi"
                    if restriction in case:
                        resulting_mm = MM([-Y,-Yi,-2*r,Y,2*X], pm.Permutation([2,3,0,1,4]))
                        memo = ", (Y,r)+X=({},{})+{}={}+{}".format(Y, r, X, np.gcd(Y, r),X)
                    else: case = "(x)"
                else:
                    Z = X%abs(r-Xi)
                    Zi = abs(r-Xi)-Z
                    if restriction in case:
                        resulting_mm = MM([Zi,Z,Xi,Zi,Z,Xi,2*Y], pm.Permutation([3,2,1,6,5,4,0]))
                    else: case = "(x)"
    else: case = "(x)"
    if resulting_mm is not None: rnum = len(pm.cycle_decomp(resulting_mm.associate_PM()))
    rdata = {'mm': resulting_mm, '#orbits': rnum, 'memo': memo}
#----    
    return [sdata, rdata], case

### scratch

In [None]:
N, case, skip_count = 20, "(x)", 0
restriction = "" #"r=c" #"r!=Xi" #"" #"r=Xi"
max_tries = 1000  # give up after this many evaluated codes

# Draw random codes [a, b, c, r] (r reduced mod 2a+b+c, r != 0) until
# correctness_checking reports a case other than "(x)".
while case == "(x)" and skip_count < max_tries:
    code = [random.randint(0, N) for _ in range(4)]
    # code = [7,1,5,2] #[7,12,13,3] #[12,4,3,3] #[10,9,4,8] #[8,6,2,7] #[19,4,4,1] #
    modulus = 2*code[0] + code[1] + code[2]
    if modulus == 0:
        continue  # a = b = c = 0: r cannot be reduced modulo 0; draw again
    code[3] = code[3] % modulus
    if code[3] != 0:
        data = correctness_checking(code, restriction)
        case = data[1]
        skip_count += 1

#-- Display --
print("skip_count={}".format(skip_count))
print("code={}, case=[ {} ]".format(code,case))
if case == "(x)":
    # Cap reached: rdata['mm'] is None, so there is nothing to display.
    print("no case matching restriction {!r} found".format(restriction))
else:
    for d in data[0]:
        d['mm'].show_in_grid()
        num = int(d['#orbits']/2)
        print("# of components = {}".format(num)+d['memo'])

    nums = [data[0][i]['#orbits'] for i in [0,1]]
    if not nums[0] == nums[1]:
        print("nums are different!! nums = {}".format(nums) )

### Experiments

In [None]:
R, skipped, ones = 100, 0, 0
rstn = "" #"r!=Xi"
N = 10              # upper bound for each randomly drawn entry of code
max_tries = 100000  # per sample; guards against a restriction that never matches
for count in tqdm(range(R)):
    # Draw codes [a, b, c, r] until one lands in a handled case;
    # `skipped` counts only the rejected draws.
    case, tries = "(x)", 0
    while case == "(x)" and tries < max_tries:
        tries += 1
        code = [random.randint(0, N) for _ in range(4)]
        if code[0] == 0:
            # a = 0 is excluded (and is the only way 2a+b+c can be 0).
            skipped += 1
            continue
        code[3] = code[3] % (2*code[0] + code[1] + code[2])
        if code[3] == 0:
            skipped += 1
            continue
        data = correctness_checking(code, rstn)
        case = data[1]
        if case == "(x)":
            skipped += 1
    if case == "(x)":
        print("no case matching restriction {!r} after {} tries; stopping".format(rstn, tries))
        break

    nums = [data[0][i]['#orbits'] for i in [0,1]]
    if data[0][0]['#orbits']/2 == 1: ones += 1
    if not nums[0] == nums[1]:
        print("nums are different!! nums = {}".format(nums) )
        print("code={}, case=[ {} ]".format(code,case))
        for d in data[0]:
            d['mm'].show_in_grid()
            print("# of components = {}".format(int(d['#orbits']/2))+d['memo'])

print("ones = {},  skipped = {}".format(ones, skipped))

### Nando

In [None]:
def correctness_checking(code, restriction=''):
    '''Revised version: redefines the earlier correctness_checking.

    Builds the starting monomial matrix from code = (a, b, c, r) and, for
    the handled cases, the predicted reduced matrix together with its orbit
    count.

    Parameters
    ----------
    code : sequence of 4 ints
        (a, b, c, r).  The notebook's callers skip r == 0; with r == 0 the
        r<c branch would raise ZeroDivisionError at ``(a-r) % (2*r)``.
    restriction : str
        Substring filter on the case label: the resulting matrix is built
        only if `restriction` occurs in the label, otherwise the label
        becomes "(x)".

    Returns
    -------
    ([sdata, rdata], case)
        sdata / rdata are dicts with keys 'mm', '#orbits' and 'memo'.
        rdata['mm'] / rdata['#orbits'] stay None when no result is built.
        `case` is the comma-separated branch label, or "(x)" when a < r,
        when r < Xi (not handled here), or when the label does not match
        `restriction`.
    '''
    a, b, c, r = code[0], code[1], code[2], code[3]

    #-- Starting MM --
    starting_mm = MM([-a,b,c,-a,-(a-r),c,b,-(a-r),-2*r], pm.Permutation([8,2,1,5,3,7,6,0,4]))
    snum = len(pm.cycle_decomp(starting_mm.associate_PM()))
    sdata = {'mm': starting_mm, '#orbits': snum, 'memo': ''}

    #-- Resulting MM --
    resulting_mm, rnum, memo = None, None, ''
    if a-r >= 0:
        case = "a>=r"
        if r-c == 0:
            case += ", r=c"
            if restriction in case:
                resulting_mm = MM([-r,-a,-(a-r),r,a-r,-2*r,2*b], pm.Permutation([4,2,0,3,5,1,6]))
                # NOTE(review): unlike the first version of this function,
                # 2*b is added to the cycle count here — confirm intended.
                rnum = len(pm.cycle_decomp(resulting_mm.associate_PM())) + 2*b
                memo = ", b={}".format(b)
            else: case = "(x)"
        elif r-c > 0:
            case += ", r>c"
            X, Y = b%(r-c), (a-r)%(2*r)
            Xi = (r-c)-X
            if restriction in case:
                resulting_mm = MM([Xi,X,c,Xi,X,c,2*Y], pm.Permutation([3,2,1,6,5,4,0]))
                rnum = len(pm.cycle_decomp(resulting_mm.associate_PM()))
            else: case = "(x)"
        else: # r-c < 0
            case += ", r<c"
            X, Y = b%(c-r), (a-r)%(2*r)
            Xi, Yi = (c-r)-X, (2*r)-Y
            if r-Xi == 0:
                case += ", r=Xi"
                if restriction in case:
                    resulting_mm = MM([-Y,-Yi,-2*r,Y,2*X], pm.Permutation([2,3,0,1,4]))
                    rnum = len(pm.cycle_decomp(resulting_mm.associate_PM()))
                    memo = ", (Y,r)+X=({},{})+{}={}+{}".format(Y, r, X, np.gcd(Y, r),X)
                else: case = "(x)"
            # `elif` is required: with a plain `if`, the r=Xi result above
            # fell through to the final `else` and was relabelled "(x)".
            elif r-Xi > 0:
                case += ", r>Xi"
                Z = X%abs(r-Xi)
                Zi = abs(r-Xi)-Z
                if restriction in case:
                    resulting_mm = MM([Zi,Z,Xi,Zi,Z,Xi,2*Y], pm.Permutation([3,2,1,6,5,4,0]))
                    rnum = len(pm.cycle_decomp(resulting_mm.associate_PM()))
                else: case = "(x)"
            else: # r < Xi: not handled in this version
                case = "(x)"
    else: case = "(x)"
    rdata = {'mm': resulting_mm, '#orbits': rnum, 'memo': memo}
#----    
    return [sdata, rdata], case