## This notebook shows how we generated the Schubert structure constants dataset. See: https://doc.sagemath.org/html/en/reference/combinat/sage/combinat/schubert_polynomial.html

In [7]:
import numpy as np
import random
random.seed(int(32))

In [2]:
X = SchubertPolynomialRing(ZZ)

In [3]:
def swap(perm, ind1, ind2):
    newperm = [0]*len(perm)
    for i in range(len(perm)):
        if i == ind1:
            newperm[i] = perm[ind2]
        elif i == ind2:
            newperm[i] = perm[ind1]
        else:
            newperm[i] = perm[i]
    return newperm

In [10]:
n = 4
positives = []
negatives = []
for p1 in Permutations(n):
    for p2 in Permutations(n):
        #Compute the product of the permutations
        product = X(p1)*X(p2)
        #Make a list of (perm, coeff) that appear in the product
        permutations_in_product = [p[0] for p in list(product)]
        for (perm, coeff) in list(product):
            positives.append((p1, p2, perm, coeff))
            #Construct negative examples by swapping two letters in the positive permutation
            if len(perm) > 1:
                ind1, ind2 = random.sample(range(len(perm)), 2)
                ind = np.random.choice(len(perm)-1)
                newperm = swap(perm, ind1, ind2)
                #Check that the new permutation isn't in the product
                if newperm not in permutations_in_product: 
                    negatives.append((p1, p2, newperm, 0))
                else:
                    print(f"{newperm} in {permutations_in_product}")

perm1: [1, 2, 3, 4]
perm2: [1, 2, 3, 4]
product
X[1]
[([1], 1)]
list of permutations appearing in product
[[1]]
perm2: [1, 2, 4, 3]
product
X[1, 2, 4, 3]
[([1, 2, 4, 3], 1)]
list of permutations appearing in product
[[1, 2, 4, 3]]
1 3
swapped permutation: [1, 3, 4, 2]
perm2: [1, 3, 2, 4]
product
X[1, 3, 2]
[([1, 3, 2], 1)]
list of permutations appearing in product
[[1, 3, 2]]
0 1
swapped permutation: [3, 1, 2]
perm2: [1, 3, 4, 2]
product
X[1, 3, 4, 2]
[([1, 3, 4, 2], 1)]
list of permutations appearing in product
[[1, 3, 4, 2]]
2 3
swapped permutation: [1, 3, 2, 4]
perm2: [1, 4, 2, 3]
product
X[1, 4, 2, 3]
[([1, 4, 2, 3], 1)]
list of permutations appearing in product
[[1, 4, 2, 3]]
1 2
swapped permutation: [1, 2, 4, 3]
perm2: [1, 4, 3, 2]
product
X[1, 4, 3, 2]
[([1, 4, 3, 2], 1)]
list of permutations appearing in product
[[1, 4, 3, 2]]
1 2
swapped permutation: [1, 3, 4, 2]
perm2: [2, 1, 3, 4]
product
X[2, 1]
[([2, 1], 1)]
list of permutations appearing in product
[[2, 1]]
0 1
swapped pe

In [12]:
all_examples = positives + negatives

In [15]:
random.shuffle(all_examples)
split = 0.8
ds_size = int(len(all_examples))

all_examples_train = all_examples[:math.ceil(ds_size*split)]
all_examples_test = all_examples[math.ceil(ds_size*split):]


In [16]:
arr_train = []
for row in all_examples_train:
    arr_train.append(str(row))

In [17]:
len(arr_train)

1684

In [18]:
arr_test = []
for row in all_examples_test:
    arr_test.append(str(row))

In [216]:
len(arr_test)

2101641

In [217]:
np.savetxt(f'schubert_structure_coefficients_triples_{n}_train.txt', arr_train, fmt = "%s")
np.savetxt(f'schubert_structure_coefficients_triples_{n}_test.txt', arr_test, fmt = "%s")