In [1]:
from sympy import *
from sympy.combinatorics import Permutation
import numpy as np
import random
import traceback
import inflector
import logging

In [2]:
# Notebook-wide logger: INFO level, timestamped messages sent to a stream handler.
logger = logging.getLogger("LADataGeneratorForAI")
logger.setLevel(logging.INFO)
handler_format = logging.Formatter("%(asctime)s: %(message)s")
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(handler_format)
# Clear existing handlers before attaching the new one (presumably so that
# re-running this cell does not stack duplicate handlers on the same logger).
if logger.hasHandlers():
    logger.handlers.clear()
logger.addHandler(stream_handler)

# Iteration cap for a single pass of SolveLinEq._search_solution_onetry.
MAX_SEARCH_COUNT = 10
# Number of random systems make_linear_problem_and_solution tries before giving up.
MAX_LOOP = 100

# English inflector; its pluralize() is used when the problem sentences are built.
ie = inflector.English()


def egcd(a, b):
    """Extended Euclidean algorithm.

    Returns a tuple ``(g, x, y)`` such that ``a * x + b * y == g``, where
    ``g`` is the last non-zero remainder of the Euclidean division chain.
    """
    # (coef_a, coef_b) express the current ``b`` in terms of the inputs,
    # (next_a, next_b) express the current ``a``.
    coef_a, coef_b = 0, 1
    next_a, next_b = 1, 0
    while a != 0:
        quotient = b // a
        remainder = b % a
        updated_a = coef_a - next_a * quotient
        updated_b = coef_b - next_b * quotient
        b, a = a, remainder
        coef_a, coef_b = next_a, next_b
        next_a, next_b = updated_a, updated_b
    return b, coef_a, coef_b


def modinv(a, m):
    """Modular inverse of ``a`` modulo ``m``.

    Returns ``x % m`` where ``x`` is the coefficient of ``a`` reported by
    ``egcd(a, m)``, or ``None`` when ``egcd`` does not report a gcd of 1.
    """
    divisor, coefficient, _ = egcd(a, m)
    return coefficient % m if divisor == 1 else None


def swap(arr, m, n):
    """Exchange entries ``m`` and ``n`` of ``arr`` in place.

    For a one-dimensional ``arr`` the two elements are exchanged; otherwise
    the two rows (first axis) are exchanged. Nothing is returned.
    """
    target, source = [m, n], [n, m]
    if len(arr.shape) != 1:
        arr[target, :] = arr[source, :]
        return
    arr[target] = arr[source]


def format_coeff(value, head=False):
    """Render a coefficient with its sign for use inside an equation string.

    With ``head=True`` the value is the first term: negative values get a
    leading ``-`` and non-negative values get no sign. Otherwise the sign
    (``+`` or ``-``) is followed by a space and the absolute value.
    """
    non_negative = value >= 0
    magnitude = abs(value)
    if head:
        sign = "" if non_negative else "-"
        return f"{sign}{magnitude}"
    sign = "+" if non_negative else "-"
    return f"{sign} {magnitude}"


def make_linear_problem(
    n, m, sol_range=100, coeff_range=100, unit_list=None, show=False
):
    """Generate a random integer linear system ``c @ x == y``.

    Args:
        n: Number of linear equations (rows of ``c``).
        m: Number of variables (columns of ``c``, length of ``x``).
        sol_range: Exclusive upper bound of each solution value; values are
            drawn from ``[0, sol_range)``.
        coeff_range: Exclusive upper bound of each coefficient; values are
            drawn from ``[1, coeff_range)``.
        unit_list: Optional sequence whose items have the equation kind at
            index 0. The row whose kind is ``"number"`` (if any) gets all
            coefficients set to 1, so that equation is a plain count.
        show: When true, print the generated equations and the solution.

    Returns:
        ``(c, y, x)``: coefficient matrix of shape ``(n, m)``, right-hand
        side of length ``n`` and the hidden solution of length ``m``.
    """
    c = np.random.randint(1, coeff_range, size=(n, m))
    if unit_list is not None:
        unit_name = [v[0] for v in unit_list]
        # list.index raises ValueError when the item is missing (it never
        # returns -1), so membership has to be tested before looking it up.
        if "number" in unit_name:
            c[unit_name.index("number"), :] = 1
    x = np.random.randint(sol_range, size=(m))
    # cx = y
    y = c.dot(x)
    if show:
        for i1 in range(n):
            for i2 in range(m):
                # Only the first term of a row is printed without a "+" sign.
                print(f"{format_coeff(c[i1][i2], head=(i2 == 0))}*x{i2} ", end="")
            print(f"== {y[i1]}")
        for i2 in range(m):
            print(f"x{i2} = {x[i2]}")
    return c, y, x


def and_join(term):
    """Join strings as an English list: ``"a, b and c"``.

    All items but the last are separated by ``", "`` and the last one is
    attached with ``" and "``. A single item is returned unchanged and an
    empty sequence yields an empty string.
    """
    if len(term) <= 1:
        return ", ".join(term)
    *leading, final = term
    return " and ".join((", ".join(leading), final))


class BadEquation(Exception):
    """Raised by SolveLinEq.solve when the system cannot be reduced to the form it needs."""


class TooDifficultProblem(Exception):
    """Raised by SolveLinEq.solve when the solution search finds no non-negative solution."""


class InvalidSolution(Exception):
    """Raised by SolveLinEq.solve when the found values do not satisfy every equation."""


class FailToMakeProblem(Exception):
    """Raised by make_linear_problem_and_solution when all MAX_LOOP attempts fail."""

In [3]:
class SolveLinEq:
    def __init__(self, A, y):
        """Store the system ``A x = y`` and start an empty solution log.

        ``A`` and ``y`` are converted to sympy matrices; ``self.M`` is the
        augmented matrix ``[A, -y]`` so every row reads ``A[i] . x - y[i] = 0``.
        """
        self.log = []
        coeff_matrix = Matrix(A)
        rhs_vector = Matrix(y)
        self.A = coeff_matrix
        self.y = rhs_vector
        # M = [A, -y]
        self.M = coeff_matrix.row_join(-rhs_vector)

    def print_log(self):
        for log in self.log:
            print(log)

    def get_log(self):
        return self.log

    def append_log(self, log_line):
        self.log.append(log_line)

    def append_equation_log(self, M):
        for i1 in range(M.shape[0]):
            head = True
            log_line = f"Eq{i1}: "
            for i2 in range(M.shape[1] - 1):
                if M[i1, i2] != 0:
                    log_line += f"{format_coeff(M[i1, i2], head=head)}*x{i2} "
                    head = False
            if M[i1, M.shape[1] - 1] != 0:
                log_line += f"{format_coeff(M[i1, M.shape[1]-1], head=head)}"
                head = False
            log_line += " = 0"
            self.append_log(log_line)

    def append_solution_log(self, sol):
        for k, v in sol.items():
            self.append_log(f"{k} = {v}")

    def _is_same_eqvec(self, eqv1, eqv2):
        """Tell whether two equation row vectors are parallel.

        The dot product of the rows is divided by both row lengths; the rows
        count as the same equation when the absolute value of that ratio is 1.
        """
        norm1 = sqrt((eqv1 * eqv1.transpose())[0, 0])
        norm2 = sqrt((eqv2 * eqv2.transpose())[0, 0])
        cosine = (eqv1 * eqv2.transpose() / norm1 / norm2)[0, 0]
        return bool(abs(cosine) == 1)

    def _find_same_eqvec(self, M):
        same_eqvec_list = []
        for i1 in range(M.shape[0]):
            for i2 in range(i1 + 1, M.shape[0]):
                eqv1 = M[i1, :]
                eqv2 = M[i2, :]
                if self._is_same_eqvec(eqv1, eqv2):
                    same_eqvec_list.append((i1, i2))
        return same_eqvec_list

    def _remove_same_eqvec(self, M):
        same_eqvec_list = self._find_same_eqvec(M)
        for same_eqvec in same_eqvec_list:
            self.append_log(
                f"Eq{same_eqvec[0]} and Eq{same_eqvec[1]} are same equations."
            )
        remove_cols = set([elem[1] for elem in same_eqvec_list])
        remove_list = sorted(list(remove_cols))
        if len(remove_list) > 0:
            log_line = "Remove same Eqs. "
            log_line += ",".join([f"Eq{eqv_num}" for eqv_num in remove_list]) + "."
            self.append_log(log_line)
        else:
            return False, M
        col_list = sorted(list(set(range(self.M.shape[0])) - remove_cols))
        logger.debug(f"col_list={col_list}")
        M = M[col_list, :]
        return True, M

    def _LU(self, x, M):
        """LU-decompose the coefficient part of ``M`` and log the derived equations.

        Args:
            x: Sequence of symbols; ``x[j]`` is paired with column ``j`` of ``U``.
            M: Augmented matrix; all columns but the last are coefficients and
                the last column is the constant term of each ``... = 0`` row.

        Returns:
            ``(Uy, eq_list)``: ``Uy`` is ``U`` joined with the transformed
            constant column, ``eq_list`` holds one ``Eq(<row of Uy>, 0)`` per row.
        """
        A = M[:, 0:-1]
        y = M[:, -1]
        L, U, p = A.LUdecomposition()

        # p is the third value returned by LUdecomposition; wrapping it in a
        # Permutation lets each row's counterpart be looked up with perm(i).
        perm = Permutation(p)
        if perm.size > 0:
            self.append_log("Change Eq order.")
            # NOTE(review): the message reads "Eq{i} to New Eq{perm(i)}" while the
            # copy below fills new row i from old row perm(i); the two agree for a
            # single swap — confirm the direction for longer cycles.
            for i in range(perm.size):
                if i != perm(i):
                    self.append_log(f"Eq{i} to New Eq{perm(i)}")
            logger.debug(f"perm={perm}")
            M1 = M.copy()
            y1 = y.copy()
            for i in range(perm.size):
                M1[i, :] = M[perm(i), :]
                y1[i] = y[perm(i)]
            M = M1
            y = y1
            self.append_equation_log(M)

        # Rows of inv(L) give each new equation as a combination of the
        # (reordered) input equations. y is the last column of M, so -y is
        # negated once more before the transformation.
        invL = L.inv()
        y1 = invL * -y

        self.append_log("Making new Eqs.")
        for i1 in range(invL.shape[0]):
            head = True
            eq_calc_str = []
            for i2 in range(invL.shape[1]):
                if invL[i1, i2] != 0:
                    eq_calc_str.append(
                        f"{format_coeff(invL[i1, i2], head=head)}*Eq{i2}"
                    )
                    head = False
            self.append_log(f'New Eq{i1} is {" ".join(eq_calc_str)}')

        # Build one symbolic equation "U[i] . x - y1[i] = 0" per row of U;
        # info_line is the same equation as text, for the debug log only.
        eq_list = []
        for i1 in range(U.shape[0]):
            term = 0
            info_line = ""
            head = True
            for i2 in range(U.shape[1]):
                if U[i1, i2] == 0:
                    continue
                term += U[i1, i2] * x[i2]
                info_line += f"{format_coeff(U[i1, i2], head=head)}*x{i2} "
                head = False
            term += -y1[i1]
            info_line += f"== {y1[i1]}"
            logger.debug(info_line)
            eq_list.append(Eq(term, 0))

        # M' = [U, -y1]
        Uy = U.row_join(-y1)
        self.append_log("New Eqs are followings.")
        self.append_equation_log(Uy)
        return Uy, eq_list

    def _adjust_equation(self, eqvec):
        def findx(v):
            # find first index of coeff > 0
            for i in range(v.shape[1]):
                if v[0, i] != 0:
                    return i
            return None

        idx = findx(eqvec)
        coeff = eqvec[0, idx]
        eqvec1 = zeros(*eqvec.shape)
        eqvec1 = eqvec / coeff
        logger.debug(eqvec1)
        denom_list = [eqvec1[0, i].denominator for i in range(eqvec1.shape[1])]
        eqlcm = ilcm(*tuple(denom_list))
        logger.debug(eqlcm)
        eqvec1 = eqvec1 * eqlcm
        coeff_list = [eqvec1[0, i] for i in range(eqvec1.shape[1])]
        eqgcd = igcd(*tuple(coeff_list))
        logger.debug(eqgcd)
        eqvec1 = eqvec1 * eqgcd
        return eqvec1[0, idx:], idx

    def _solve_int_equation(self, eqvec, mod_x_idx):
        """Rewrite a one-row integer equation modulo the coefficient at ``mod_x_idx``.

        A column whose entry is coprime to that coefficient is picked, all
        other entries are multiplied by the modular inverse of the picked
        entry (with flipped sign), and the row is reduced modulo the coefficient.

        Args:
            eqvec: One-row matrix of coefficients followed by the constant term.
            mod_x_idx: Column whose entry is used as the modulus.

        Returns:
            ``(row, idx)`` with the reduced row and the picked column, or
            ``(None, None)`` when no coprime column or no inverse exists.
        """
        # solve modular equation (divisor is coeff of x[mod_x_idx])
        def findx(v, m):
            # First column whose entry is coprime to m; every column of v is a
            # candidate.
            for i in range(v.shape[1]):
                if gcd(v[0, i], m) == 1:
                    return i
            return None

        idx = findx(eqvec, eqvec[0, mod_x_idx])
        if idx is None:
            return None, None
        # A negative entry is shifted by the modulus once before inversion.
        mi = modinv(
            eqvec[0, idx]
            if eqvec[0, idx] >= 0
            else eqvec[0, idx] + eqvec[0, mod_x_idx],
            eqvec[0, mod_x_idx],
        )
        if mi is None:
            return None, None
        # NOTE(review): the assignments below write into the matrix object that
        # was passed in; only the final `%` creates a new matrix. The entry at
        # mod_x_idx is skipped and therefore keeps its value.
        for i in range(eqvec.shape[1]):
            if i == mod_x_idx:
                continue
            if i == idx:
                eqvec[0, i] = eqvec[0, i] * mi
            else:
                eqvec[0, i] = -eqvec[0, i] * mi
        eqvec = eqvec % eqvec[0, mod_x_idx]
        return eqvec, idx

    def _make_solution_equation(
        self, x, N, mod_eq, last_eq_first_x_idx, mod_eq_dep_x_idx, N_coeff
    ):
        """Build the relation that expresses the dependent variable of ``mod_eq``.

        Column ``i`` of ``mod_eq`` belongs to ``x[last_eq_first_x_idx + i]``;
        the last column is a constant. The relation
        ``x_dep - (sum of the other terms) + N * N_coeff = 0`` is created,
        its solution for ``x_dep`` is logged, and the relation is returned
        together with the indices of the variables that appear in the sum.
        """
        # NOTE(review): the scan starts at column 1, so column 0 never
        # contributes a term — confirm this is intended for every mod_x_idx.
        const_col = mod_eq.shape[1] - 1
        dep_symbol = x[last_eq_first_x_idx + mod_eq_dep_x_idx]
        rhs = 0
        x_idx_list = []
        for col in range(1, mod_eq.shape[1]):
            if col == mod_eq_dep_x_idx:
                continue
            if col == const_col:
                rhs += mod_eq[0, col]
            elif mod_eq[0, col] != 0:
                var_idx = last_eq_first_x_idx + col
                x_idx_list.append(var_idx)
                rhs += mod_eq[0, col] * x[var_idx]
        relation = Eq(dep_symbol - rhs + N * N_coeff, 0)
        solved = solve(relation, (dep_symbol))
        self.append_log(f"x{last_eq_first_x_idx+mod_eq_dep_x_idx} = {solved[0]}")
        return relation, x_idx_list

    def _search_solution_onetry(self, vl, t):
        """Run one randomized search for multipliers that make ``t`` non-negative.

        Starting from ``t``, a multiple of one row of ``vl`` is added per step
        until no component of ``t`` is below ``-0.0001`` or MAX_SEARCH_COUNT
        steps have been taken.

        Args:
            vl: 2-D array; each row is a direction that may be added to ``t``.
            t: 1-D array, the starting point.

        Returns:
            ``(t, c, search_log)`` on success, where ``c[i]`` is the
            accumulated multiplier of row ``i`` and ``search_log`` holds a
            ``(t, c)`` snapshot per step; ``(None, None, None)`` otherwise.
        """
        def check(t):
            # True when no component is (noticeably) negative.
            logger.debug(f"check t {t}")
            return not (sum(t < -0.0001) > 0)

        def inner(v1, v2):
            # Cosine of the angle between v1 and v2.
            v1_len = np.sqrt(np.inner(v1, v1))
            v2_len = np.sqrt(np.inner(v2, v2))
            return np.inner(v1, v2) / v1_len / v2_len

        def find_vec(vl, t):
            # t1 keeps the negative components of t and scales the positive
            # ones by -0.1; a row is a candidate when its cosine with t1 is
            # negative. Row 0 may be used with either sign (flag0).
            t1 = (t < 0) * t + (t > 0) * t * -0.1
            dist = [inner(vl[i, :], t1) for i in range(vl.shape[0])]
            flag0 = -1 if dist[0] > 0 else 1
            dist[0] = flag0 * dist[0]
            idx = np.where((np.array(dist) < 0) == True)
            # NOTE(review): idx[0] can be empty when no row qualifies —
            # confirm how np.random.choice behaves for an empty candidate list.
            idx0 = np.random.choice(idx[0])
            flag = flag0 if idx0 == 0 else 1
            outv = vl[idx0, :] * flag
            return outv, idx0, flag

        def find_minv(v):
            # Random negative component of v and its position.
            idx = np.where((v < 0) == True)
            idx0 = np.random.choice(idx[0])
            return v[idx0], idx0

        search_log = []
        c = np.zeros(vl.shape[0])

        count = 0
        last_idx0 = -1
        while True:
            search_log.append((t.copy(), c.copy()))
            count += 1
            if count > MAX_SEARCH_COUNT:
                logger.debug(f"t={t}")
                break
            logger.debug(f"t={t}")
            if check(t):
                logger.debug(f"count={count}")
                return t, c, search_log
            v, idx0, flag = find_vec(vl, t)
            # The same row twice in a row: fall back to row 0 with a positive sign.
            if last_idx0 == idx0:
                idx0 = 0
                v = vl[idx0, :]
                flag = 1
            tv, idx = find_minv(t)
            vv = v[idx]
            # Step size a: the rounded ratio that would cancel the chosen
            # negative component, never zero, and damped once |a| exceeds 5.
            if vv == 0:
                a = 1 if -tv > 0 else -1
            else:
                a = np.round(-tv / vv)
                if a == 0:
                    a = 1 if -tv / vv > 0 else -1
                else:
                    da = 5
                    if abs(a) > da:
                        a = np.sign(a) * ((abs(a) - da) / 2 + da)
            t = t + v * a
            c[idx0] += a * flag
            last_idx0 = idx0
        return None, None, None

    def _search_solution(self, vl, t):
        for ii in range(10):
            rest, resc, search_log = self._search_solution_onetry(vl, t)
            if rest is not None:
                return rest, resc, search_log
        return None, None, None

    def _make_solution_param(self, sol):
        valn = self.M.shape[1] - 1
        px = list(set(self.x) - set(sol.keys()))
        px.reverse()
        solx = list(sol.keys())
        logger.debug(f"px={px}")
        logger.debug(f"solx={solx}")
        vl = np.zeros([valn - len(sol) + 1, len(sol)])
        t = np.zeros(len(sol))
        for i2 in range(vl.shape[1]):
            vl[0, i2] = sol[solx[i2]].coeff(self.N)
            for i1 in range(vl.shape[0] - 1):
                vl[i1 + 1, i2] = sol[solx[i2]].coeff(px[i1])
        for i2 in range(vl.shape[1]):
            logger.debug(sol[solx[i2]])
            t[i2] = sol[solx[i2]].as_coefficients_dict()[1]
        return vl, t, solx, px

    def _check_matrix(self, vl):
        def is_int(a):
            if abs(a - round(a)) > 0.0001:
                return False
            return True

        for i1 in range(vl.shape[0]):
            for i2 in range(vl.shape[1]):
                if not is_int(vl[i1, i2]):
                    return False

        return True

    def solve(self):
        """Solve the stored system and record every step in the log.

        Returns:
            ``(None, None, None)`` when there are more equations than variables;
            ``(sol, None, None)`` when the counts are equal (``sol`` is what
            sympy's ``solve`` returned); otherwise ``(sol, rest, resc)`` where
            ``sol`` expresses the dependent variables symbolically, ``rest``
            holds the values found for them and ``resc`` the values found for
            ``N`` followed by the free variables.

        Raises:
            BadEquation: No usable modular relation could be derived from the
                last equation, or the coefficient matrix is not integral.
            TooDifficultProblem: The search did not find a solution.
            InvalidSolution: The found values do not satisfy every equation.
        """
        self.append_log("Eqs are followings.")
        self.append_equation_log(self.M)

        # Duplicate equations carry no information; drop them first.
        flag, self.M = self._remove_same_eqvec(self.M)
        if flag:
            self.append_log("Eqs are followings.")
            self.append_equation_log(self.M)

        eqvn = self.M.shape[0]
        valn = self.M.shape[1] - 1
        if eqvn > valn:
            self.append_log(
                "The number of Eqs is greater than the number of variables."
            )
            self.append_log("There is no solution.")
            return None, None, None

        self.x = symbols(" ".join(f"x{i}" for i in range(valn)))
        self.N = symbols("N")

        Uy, eq_list = self._LU(self.x, self.M)
        # As many equations as variables: let sympy solve the system directly.
        if eqvn == valn:
            sol = solve(tuple(eq_list), tuple(self.x))
            sol_list = []
            for x1 in list(sol.keys()):
                sol_v = sol[x1]
                sol_list.append(f"{x1} = {sol_v}")
            self.append_log(f'Solution is {", ".join(sorted(sol_list))}')
            return sol, None, None

        # Fewer equations than variables: turn the last reduced equation into
        # a modular relation, trying each of its coefficients as the modulus.
        last_eq, last_eq_first_x_idx = self._adjust_equation(Uy[-1, :])
        logger.debug(f"last_eq_first_x_idx={last_eq_first_x_idx}")
        for mod_x_idx in range(last_eq.shape[1] - 1):
            mod_eq, mod_eq_dep_x_idx = self._solve_int_equation(last_eq, mod_x_idx)
            if mod_eq is not None:
                break
        if mod_eq is None:
            raise BadEquation()

        logger.debug(f"mod_x_idx={mod_x_idx}")
        self.append_log(f"Eq{eqvn-1} leads the following relation.")
        exp, x_idx_list = self._make_solution_equation(
            self.x,
            self.N,
            mod_eq,
            last_eq_first_x_idx,
            mod_eq_dep_x_idx,
            last_eq[0, mod_x_idx],
        )
        self.append_log(f"N is integer.")
        logger.debug(f"exp={exp}")

        # Variables that appear in the relation stay free; all others are
        # solved for in terms of them and N.
        sol_xidx = set(range(valn)) - set(x_idx_list)
        logger.debug(f"x_idx_list={x_idx_list}")
        logger.debug(f"sol_xidx={sol_xidx}")
        sol = solve((*tuple(eq_list), exp), tuple([self.x[i] for i in sol_xidx]))
        self.append_log("Calculating Eqs for dependent variables.")
        self.append_solution_log(sol)

        vl, t, solx, px = self._make_solution_param(sol)
        if not self._check_matrix(vl):
            raise BadEquation()

        rest, resc, search_log = self._search_solution(vl, t)
        if rest is None:
            raise TooDifficultProblem()

        self.append_log(
            f'Searching a solution for {and_join([ f"{v} >= 0" for v in solx + px ])}.'
        )
        logger.debug(f"search_log = {search_log}")
        # Replay every search step in the log: v1 holds the dependent values,
        # v2 holds N followed by the free-variable values.
        for ii, (v1, v2) in enumerate(search_log):
            message_solx = [f"{solx[i]} = {int(v)}" for i, v in enumerate(v1)]
            message_N = [f"N = {int(v2[0])}"]
            message_px = [f"{px[i]} = {int(v)}" for i, v in enumerate(v2[1:])]
            self.append_log(
                f"Step {ii+1}: {and_join(message_N + message_px)}, then {and_join(message_solx)}."
            )
            minus_v = [f"{solx[i]} = {int(v)}" for i, v in enumerate(v1) if v < 0]
            if len(minus_v) > 0:
                self.append_log(
                    f"There are minus value variables, {and_join(minus_v)}."
                )
            else:
                self.append_log(f"There is no minus value variable.")

        sol_eq = []
        sol_list = []
        for i, v in enumerate(rest):
            sol_list.append(f"{solx[i]} = {int(v)}")
            sol_eq.append(Eq(solx[i], v))
        # resc[0] is the value of N (not reported as part of the solution);
        # the remaining entries belong to the free variables px.
        for i, v in enumerate(resc):
            if i == 0:
                sol_eq.append(Eq(self.N, v))
            else:
                sol_list.append(f"{px[i-1]} = {int(v)}")
                sol_eq.append(Eq(px[i - 1], v))

        self.append_log(f"One solution is {and_join(sorted(sol_list))}.")
        # Substitute the found integers into every reduced equation as a final check.
        for i, eq in enumerate(eq_list):
            eq_eval = eq.lhs
            for veq in sol_eq:
                eq_eval = eq_eval.subs(veq.lhs, int(veq.rhs))
            logger.debug(f"Eq{i}: {eq_eval}")
            if eq_eval != 0:
                raise InvalidSolution()

        return sol, rest, resc

In [4]:
def make_problem_sentence(c, y, item_name, name_list, unit_list):
    """Render the linear system as an English word problem.

    Args:
        c: Coefficient rows; ``c[i][j]`` is the per-item value of quantity
            ``i`` for item ``j``.
        y: Totals, one per row of ``c``.
        item_name: Name of the item category.
        name_list: Item names, one per column of ``c``.
        unit_list: One ``(quantity, unit)`` pair per row of ``c``; the
            quantity ``"number"`` marks the item-count equation.

    Returns:
        List of sentences that make up the problem statement.
    """
    sentences = [
        f"In a market, {len(name_list)} kinds of {item_name}, {and_join([ ie.pluralize(name) for name in name_list ])} are for sale."
    ]
    # Per-item facts: one line per quantity, the count row has none.
    for row, coeffs in enumerate(c):
        if unit_list[row][0] == "number":
            continue
        facts = [
            f"The {unit_list[row][0]} of {name_list[col]} is {value} {unit_list[row][1]}."
            for col, value in enumerate(coeffs)
            if value != 0
        ]
        sentences.append(" ".join(facts))
    sentences.append(f"A shopper buys {item_name} at this market.")
    # Totals observed for the purchase.
    for row, _ in enumerate(c):
        if unit_list[row][0] != "number":
            sentences.append(
                f"The sum of {unit_list[row][0]} is {y[row]} {unit_list[row][1]}."
            )
        else:
            sentences.append(f"The number of {ie.pluralize(item_name)} is {y[row]}.")
    sentences.append(
        f"How many {and_join([ ie.pluralize(name) for name in name_list ])} does the shopper buy?"
    )
    sentences.append("Show one solution.")
    return sentences


def replace_name(name_list, sol_log):
    """Replace the variable tokens ``x0, x1, ...`` in log lines by item names.

    Each whole token ``x<i>`` with ``i < len(name_list)`` becomes
    ``name_list[i]``. The substitution is done in a single pass per line, so
    ``x10`` is never mistaken for ``x1`` followed by ``0`` and text inside an
    already inserted name is never replaced again.

    Args:
        name_list: Item names; position ``i`` names variable ``x<i>``.
        sol_log: Iterable of log lines.

    Returns:
        New list with the rewritten lines.
    """
    # Imported locally: `from sympy import *` also exports a function called
    # `re`, which would shadow (or be shadowed by) a module-level import.
    import re

    replace_dict = {f"x{i}": name for i, name in enumerate(name_list)}
    token = re.compile(r"\bx\d+\b")

    def substitute(match):
        # Tokens without a matching name are left untouched.
        return replace_dict.get(match.group(0), match.group(0))

    return [token.sub(substitute, log) for log in sol_log]


def make_linear_problem_and_solution(problem_param):
    """Generate a random problem that `SolveLinEq` can solve, with its solution log.

    Up to MAX_LOOP random systems are tried; a system is discarded when
    solving it raises any exception.

    Args:
        problem_param: Mapping with the keys ``"sol_range"``,
            ``"coeff_range"``, ``"name_list"``, ``"equations_setting"`` and
            ``"item_name"``.

    Returns:
        ``(prob_log, sol_log)``: the problem sentences and the solution log
        with variable names replaced by item names.

    Raises:
        FailToMakeProblem: If no attempt succeeded.
    """
    sol_range = problem_param["sol_range"]
    coeff_range = problem_param["coeff_range"]
    name_list = problem_param["name_list"]
    unit_list = problem_param["equations_setting"]
    item_name = problem_param["item_name"]
    val_n = len(name_list)
    eq_n = len(unit_list)
    bad_eq_total = difficult_total = invalid_total = 0

    for _attempt in range(MAX_LOOP):
        c, y, sol_x = make_linear_problem(
            eq_n,
            val_n,
            unit_list=unit_list,
            sol_range=sol_range,
            coeff_range=coeff_range,
        )
        sle = SolveLinEq(c[0:eq_n, :], y[0:eq_n])
        try:
            _sol, _rest, _resc = sle.solve()
        except BadEquation:
            bad_eq_total += 1
        except TooDifficultProblem:
            difficult_total += 1
        except InvalidSolution:
            invalid_total += 1
        except Exception:
            # Any other failure only discards this attempt; the traceback is
            # kept at debug level.
            logger.debug(traceback.format_exc())
        else:
            prob_log = make_problem_sentence(c, y, item_name, name_list, unit_list)
            sol_log = replace_name(name_list, sle.get_log())
            logger.debug("----- result -----")
            logger.debug(f"sol_x={sol_x}")
            logger.debug(f"count_bad_eq={bad_eq_total}")
            logger.debug(f"count_difficult={difficult_total}")
            logger.debug(f"count_invalid={invalid_total}")
            return prob_log, sol_log
    raise FailToMakeProblem()


def random_print_linear_problem_and_solution(problem_param, count):
    """Print ``count`` randomly generated problems together with their solutions.

    For every problem a random subset of the item names is drawn whose size
    lies between the number of equations and the number of names.
    ``problem_param`` itself is left unmodified. Attempts that raise
    ``FailToMakeProblem`` are retried without limit.

    Args:
        problem_param: Settings mapping as expected by
            `make_linear_problem_and_solution`; ``"name_list"`` is the pool
            the item names are drawn from.
        count: Number of problems to print; nothing is printed for
            ``count <= 0``.

    Raises:
        ValueError: If there are fewer item names than equations.
    """
    name_pool = list(problem_param["name_list"])
    min_names = len(problem_param["equations_setting"])
    if min_names > len(name_pool):
        raise ValueError(
            f"name_list has {len(name_pool)} names but {min_names} equations are requested"
        )
    # Each trial gets its own name subset through a shallow copy, so the
    # caller's dict keeps the full pool and the call can be repeated.
    trial_param = dict(problem_param)
    printed = 0
    while printed < count:
        k = random.randint(min_names, len(name_pool))
        trial_param["name_list"] = random.sample(name_pool, k=k)
        try:
            prob_log, sol_log = make_linear_problem_and_solution(trial_param)
        except FailToMakeProblem:
            continue
        printed += 1
        print(f"----- Problem {printed} -----")
        for log in prob_log:
            print(log)
        print(f"----- Solution {printed} -----")
        for log in sol_log:
            print(log)

In [6]:
# Settings for the generated word problems.
problem_param = {
    # One (quantity, unit) pair per equation; the number of entries is the
    # number of equations, and "number" marks the item-count equation.
    "equations_setting": [
        ("number", ""),
        ("price", "cents"),
        ("weight", "g"),
        # Disabled fourth quantity:
        # ("volume", "cm3"),
    ],
    # Category name used in the sentences.
    "item_name": "fruits",
    # Pool of item names; a random subset is drawn for every problem.
    "name_list": [
        "apple",
        "blueberry",
        "kiwifruit",
        "orange",
        "melon",
        "grape",
        "strawberry",
    ],
    # Exclusive upper bounds passed on to make_linear_problem.
    "sol_range": 10,
    "coeff_range": 100,
}

# NOTE(review): neither `random` nor `np.random` is seeded in the visible
# cells, so each run prints different problems — confirm that is intended.
random_print_linear_problem_and_solution(problem_param, 2)

----- Problem 1 -----
In a market, 4 kinds of fruits, grapes, oranges, apples and melons are for sale.
The price of grape is 48 cents. The price of orange is 46 cents. The price of apple is 56 cents. The price of melon is 8 cents.
The weight of grape is 81 g. The weight of orange is 4 g. The weight of apple is 69 g. The weight of melon is 66 g.
A shopper buys fruits at this market.
The number of fruits is 20.
The sum of price is 904 cents.
The sum of weight is 1479 g.
How many grapes, oranges, apples and melons does the shopper buy?
Show one solution.
----- Solution 1 -----
Eqs are followings.
Eq0: 1*grape + 1*orange + 1*apple + 1*melon - 20 = 0
Eq1: 48*grape + 46*orange + 56*apple + 8*melon - 904 = 0
Eq2: 81*grape + 4*orange + 69*apple + 66*melon - 1479 = 0
Making new Eqs.
New Eq0 is 1*Eq0
New Eq1 is -48*Eq0 + 1*Eq1
New Eq2 is 1767*Eq0 - 77/2*Eq1 + 1*Eq2
New Eqs are followings.
Eq0: 1*grape + 1*orange + 1*apple + 1*melon - 20 = 0
Eq1: -2*orange + 8*apple - 40*melon + 56 = 0
Eq2: -320*