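## This notebook shows how we generated the Schubert polynomial dataset. Each example records two permutations p1, p2 in S_n, a permutation w, and the coefficient of the Schubert polynomial S_w in the product S_{p1} * S_{p2}. The dataset also includes sampled examples whose coefficient is zero.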
## See: https://doc.sagemath.org/html/en/reference/combinat/sage/combinat/schubert_polynomial.html

In [None]:
import numpy as np
import itertools
import random
from sage.all import SchubertPolynomialRing, Permutations, ZZ
import math
# Seed every random number generator the notebook draws from.
# - `random` samples the transpositions and shuffles the dataset.
# - numpy's global generator supplies np.random.geometric when the
#   zero-coefficient examples are built.
# numpy's generator is independent of `random`, so it must be seeded
# separately; otherwise the generated dataset changes on every run.
SEED = 32
random.seed(SEED)
np.random.seed(SEED)

In [None]:
# Sage ring of Schubert polynomials with integer coefficients. Below, X(p)
# for a permutation p is used as the Schubert polynomial indexed by p.
# See the SageMath link in the header.
X = SchubertPolynomialRing(ZZ)

In [None]:
# When n = 3, the permutations in the product can all be embedded in S_5
# When n = 4, the permutations in the product can all be embedded in S_7
# When n = 5, the permutations in the product can all be embedded in S_9
# When n = 6, the permutations in the product can all be embedded in S_11
# That is, nn = 2n - 1 in every case above. Deriving nn from n keeps the two
# values in sync when n is changed.

n = 3  # the factors p1, p2 range over all permutations in S_n
nn = 2 * n - 1  # every product permutation is padded to length nn (S_nn)

In [None]:
def swap(perm, ind1, ind2):
    """Return a new list equal to *perm* with positions *ind1* and *ind2* exchanged.

    *perm* itself is not modified. Output position ``ind1`` takes the value
    at ``ind2``, position ``ind2`` takes the value at ``ind1``, and every
    other position keeps its own value.
    """
    # Map each swapped position to the position it copies from; every other
    # position falls back to copying from itself.
    source_of = {ind2: ind1, ind1: ind2}
    return [perm[source_of.get(pos, pos)] for pos in range(len(perm))]

def construct_zero_coeff_example(perm, n):
    """Scramble *perm* by a random number of distinct position transpositions.

    In this notebook the result is used as a candidate zero-coefficient
    example. The caller discards it if it appears in the product.

    The number of transpositions is drawn from a geometric distribution with
    p = 0.2, so it is always at least 1. It is capped at the number of
    distinct position pairs. The pairs are sampled without replacement and
    applied in the order sampled.

    Args:
        perm: Permutation in one-line notation (any sequence).
        n: Unused; kept so existing call sites keep working.

    Returns:
        A new list; *perm* itself is not modified.

    Randomness comes from both numpy's global generator and the ``random``
    module, so both must be seeded for reproducible output.
    """
    # Pair up positions of perm itself rather than the global ``nn``. For the
    # length-nn permutations built in the main loop the result is identical.
    # Shorter inputs no longer index past the end of perm.
    position_pairs = list(itertools.combinations(range(len(perm)), 2))
    # The count can't exceed the number of distinct transpositions available.
    number_of_transpositions = min(int(np.random.geometric(0.20)), len(position_pairs))
    transpositions = random.sample(position_pairs, number_of_transpositions)

    # Swap in place on a copy, instead of rebuilding the list per transposition.
    result = list(perm)
    for i, j in transpositions:
        result[i], result[j] = result[j], result[i]
    return result

positive_coeff_triples = []
zero_coeff_triples = []
P = Permutations(n)


def _embed(perm, size):
    """Return *perm* as a list padded with fixed points up to length *size*.

    Raises ValueError if *perm* is already longer than *size*. Padding would
    then do nothing, and the example would silently have a different length
    from the rest of the dataset.
    """
    perm = list(perm)
    if len(perm) > size:
        raise ValueError(f"permutation {perm} does not fit in S_{size}; increase nn")
    return perm + list(range(len(perm) + 1, size + 1))


for p1 in P:
    for p2 in P:
        # Multiply the Schubert polynomials indexed by p1 and p2. Iterating
        # the result yields (permutation, coefficient) pairs.
        product = X(p1) * X(p2)
        terms = list(product)

        # Every permutation occurring in the product, embedded in S_{nn}.
        embedded_permutations_in_product = [_embed(perm, nn) for perm, _ in terms]

        for (_, coeff), embedded_perm in zip(terms, embedded_permutations_in_product):
            positive_coeff_triples.append((p1, p2, embedded_perm, coeff))

            # Build a zero-coefficient example by applying a random number of
            # transpositions to this permutation.
            # NOTE(review): nothing prevents the same newperm from being
            # generated more than once, so duplicate zero-coefficient rows
            # are possible.
            if len(embedded_perm) > 1:
                newperm = construct_zero_coeff_example(embedded_perm, n)

                # Keep it only if the new permutation isn't in the product.
                if newperm not in embedded_permutations_in_product:
                    zero_coeff_triples.append((p1, p2, newperm, 0))
                else:
                    print(f"{newperm} in {embedded_permutations_in_product}, not adding to zero coeff triples")

In [None]:
# Number of examples taken directly from the products: one row per term of
# each product X(p1) * X(p2). Shown as the cell's output.
len(positive_coeff_triples)

In [None]:
# Number of zero-coefficient examples that passed the not-in-product check.
# Shown as the cell's output.
len(zero_coeff_triples)

In [None]:
# Combined dataset: product terms first, then zero-coefficient examples.
# The next cell shuffles it.
all_examples = [*positive_coeff_triples, *zero_coeff_triples]

In [None]:
# Shuffle in place with the seeded `random` module. The first 80% (rounded
# up) becomes the training split and the remainder the test split.
random.shuffle(all_examples)
split = 0.8
ds_size = len(all_examples)
train_end = math.ceil(ds_size * split)

all_examples_train, all_examples_test = all_examples[:train_end], all_examples[train_end:]


In [None]:
# Each line of the output files is the str() of one (p1, p2, perm, coeff)
# tuple. np.savetxt writes each element of the 1-D string array on its own
# line.
arr_train = list(map(str, all_examples_train))
arr_test = list(map(str, all_examples_test))
np.savetxt(f'schubert_{n}_train.txt', arr_train, fmt = "%s")
np.savetxt(f'schubert_{n}_test.txt', arr_test, fmt = "%s")