# In [None]:
import torch



def ComputeGradient(gradients, curr_losses, mem_losses, gradnorm_mom):
    """Blend per-task gradients into one update direction.

    Every task's per-parameter gradients are flattened into a single vector,
    the solver is asked for one weight per task, and the result is the
    weighted sum of the task gradients, computed parameter by parameter.

    Args:
        gradients: One entry per task; each entry is a sequence of
            per-parameter gradient tensors. All tasks are indexed with the
            parameter positions of the first task.
        curr_losses: Forwarded unchanged to ``ComputeTol``.
        mem_losses: Forwarded unchanged to ``ComputeTol``.
        gradnorm_mom: Forwarded unchanged to ``ComputeTol``.

    Returns:
        list: One combined gradient per parameter position, ordered like
        ``gradients[0]``.
    """
    # One flat vector per task, so the solver can take dot products.
    flat_per_task = [
        torch.cat([grad.reshape(-1) for grad in task_grads], 0)
        for task_grads in gradients
    ]

    tols = ComputeTol(curr_losses, mem_losses, gradnorm_mom)
    # NOTE(review): `min_norm_solvers` is not imported in the visible part of
    # this file, although a function with the same name is defined further
    # down -- confirm which one is meant to be called.
    weights = min_norm_solvers.find_min_norm_element_with_tol(flat_per_task, tols)

    combined = []
    for param_idx in range(len(gradients[0])):
        blended = 0
        for task_idx, task_grads in enumerate(gradients):  # for each task
            blended += weights[task_idx] * task_grads[param_idx]
        combined.append(blended)
    return combined

def ComputeTol(curr_losses, mem_losses, gradnorm_mom):
    """Convert gradient-norm statistics into softmax-normalised tolerances.

    The loss arrays only determine how many slots there are (the memory
    losses, when non-empty, occupy the first slot, followed by one slot per
    entry of ``curr_losses``) and are checked to be non-empty; their values
    do not enter the result.

    Args:
        curr_losses: Iterable of NumPy arrays, one per current task.
        mem_losses: NumPy array of memory losses; skipped when its length is 0.
        gradnorm_mom: Indexable holding one statistic per slot.

    Returns:
        torch.Tensor: float64 tensor ``softmax(gradnorm_mom[:k] / 5)`` over
        the ``k`` slots.
    """
    loss_tensors = [torch.from_numpy(mem_losses)] if len(mem_losses) > 0 else []
    loss_tensors.extend(torch.from_numpy(loss) for loss in curr_losses)

    raw = []
    for slot, loss in enumerate(loss_tensors):
        assert len(loss) > 0
        raw.append(gradnorm_mom[slot])

    temperature = 5  # softmax temperature
    scaled = torch.tensor(raw, dtype=torch.float64) / temperature
    return softmax(scaled, 0)

def softmax(x, axis=None):
    """Return the numerically stable softmax of ``x``.

    Args:
        x: Input tensor.
        axis: Dimension to normalise over. ``None`` (the default) normalises
            over all elements of ``x`` at once. Previously the default was
            passed straight to ``Tensor.max(dim=None, ...)``, which raises a
            ``TypeError``.

    Returns:
        torch.Tensor: Tensor of the same shape as ``x`` whose entries sum to
        one along ``axis`` (or over the whole tensor when ``axis`` is None).
    """
    if axis is None:
        # Global reduction: subtract the overall maximum so exp() cannot
        # overflow, then normalise by the total.
        exps = torch.exp(x - x.max())
        return exps / exps.sum()

    # Subtracting the per-slice maximum leaves the result unchanged
    # mathematically but keeps exp() in a safe range.
    shifted = x - x.max(dim=axis, keepdim=True).values
    exps = torch.exp(shifted)
    return exps / exps.sum(dim=axis, keepdim=True)

def _projection2simplex_with_tol(y, tols):
    """Shift ``y`` along ``tols`` by a scalar threshold and clip at zero.

    Returns ``max(y - t * tols, 0)`` elementwise. The threshold ``t`` starts
    as ``(sum(y * tols) - 1) / sum(tols * tols)`` over all coordinates and is
    replaced by the first prefix value (scanning coordinates from the largest
    ``y`` downwards, excluding the smallest) for which
    ``tols[i] * t_prefix > y[i]``.

    NOTE(review): presumably meant as a projection onto
    ``{z >= 0, sum(tols * z) = 1}`` -- confirm. A later definition with the
    same name in this file rebinds this name.
    """
    descending = torch.argsort(y).flip(0)

    # Fallback threshold that uses every coordinate.
    threshold = (torch.sum(y * tols) - 1.0) / torch.sum(tols * tols)

    weighted_sum = 0.0
    sq_tol_sum = 0.0
    for idx in descending[:-1]:
        # Accumulate from the largest entry of y towards the smallest.
        weighted_sum = weighted_sum + y[idx] * tols[idx]
        sq_tol_sum = sq_tol_sum + tols[idx] * tols[idx]
        candidate = (weighted_sum - 1.) / sq_tol_sum
        # Original author's note: this condition can basically never be met.
        if tols[idx] * candidate > y[idx]:
            threshold = candidate
            break

    shifted = y - threshold * tols
    return torch.max(shifted, torch.zeros_like(y))


def _min_norm_2d_with_tol(vecs, dps, tols):
    """Pick the pair of vectors with the smallest two-point cost.

    Every unordered pair ``(i, j)`` with ``i < j`` is scored by
    ``_min_norm_element_from2_with_tol_v2`` from the three dot products
    ``<v_i, v_i>``, ``<v_i, v_j>``, ``<v_j, v_j>`` and the two tolerances
    ``tols[i]``, ``tols[j]``. The pair with the smallest second return value
    wins; on ties the first pair encountered is kept.

    Args:
        vecs: Sequence of tensors; each must support ``.view(-1)``.
        dps: Dict used as a cache of dot products keyed by index pairs. It is
            updated in place; entries already present are not recomputed.
        tols: Indexable of per-vector tolerances.

    Returns:
        tuple: ``(sol, dps)`` where ``sol`` is ``[(i, j), c, d]`` for the best
        pair (``c`` and ``d`` are whatever the helper returned -- presumably
        the mixing coefficient and the cost; confirm against the helper), or
        ``None`` when fewer than two vectors are given. ``dps`` then holds
        every ``(i, j)`` entry for the supplied vectors, diagonal included.
    """
    def _dot(a, b):
        # Plain dot product of the flattened vectors, as a Python float.
        return torch.sum(torch.mul(vecs[a].view(-1), vecs[b].view(-1))).item()

    # Initialised up front so that fewer than two vectors no longer ends in an
    # UnboundLocalError at the return statement.
    sol = None
    dmin = None
    count = len(vecs)
    for i in range(count):
        # Fill the diagonal here so that it exists even when there is no pair
        # to visit (single-vector input).
        if (i, i) not in dps:
            dps[(i, i)] = _dot(i, i)
        for j in range(i + 1, count):
            if (i, j) not in dps:
                dps[(i, j)] = _dot(i, j)
                dps[(j, i)] = dps[(i, j)]
            if (j, j) not in dps:
                dps[(j, j)] = _dot(j, j)

            c, d = _min_norm_element_from2_with_tol_v2(dps[(i, i)], dps[(i, j)], dps[(j, j)], tols[i], tols[j])

            # Strict comparison keeps the earliest pair on ties.
            if dmin is None or d < dmin:
                dmin = d
                sol = [(i, j), c, d]
    return sol, dps


def _next_point_with_tol_v2(cur_val, grad, n, tols, lr):
    """Take one unprojected step of size ``lr`` along ``grad``.

    Args:
        cur_val: Current point.
        grad: Direction to move in.
        n: Accepted for interface compatibility; not used.
        tols: Accepted for interface compatibility; not used.
        lr: Step size multiplying ``grad``.

    Returns:
        ``grad * lr + cur_val``. No projection is applied here.
    """
    step = grad * lr
    return step + cur_val


def _projection2simplex_with_tol(y, tols):
    """Shift ``y`` along ``tols`` by a scalar threshold and clip at zero.

    The result is ``max(y - t * tols, 0)`` elementwise. The threshold ``t``
    defaults to ``(sum(y * tols) - 1) / sum(tols * tols)`` over all
    coordinates. Coordinates are then scanned from the largest ``y`` to the
    second smallest, accumulating ``y[i] * tols[i]`` and ``tols[i] ** 2``; the
    first prefix threshold with ``tols[i] * t_prefix > y[i]`` replaces the
    default and ends the scan.

    NOTE(review): presumably intended as a projection onto
    ``{z >= 0, sum(tols * z) = 1}`` -- confirm.

    Args:
        y: 1-D tensor to project.
        tols: 1-D tensor of per-coordinate tolerances, same length as ``y``.

    Returns:
        torch.Tensor: The shifted and clipped tensor, same shape as ``y``.
    """
    # torch.flip takes `dims` (a list or tuple); the former `dim=0` keyword
    # does not exist and raised a TypeError on every call.
    sorted_idx = torch.flip(torch.argsort(y), dims=[0])

    # Plain Python accumulators: after the first addition they become tensors
    # with the dtype and device of `y * tols`, instead of float32 CPU tensors
    # that would truncate float64 inputs through in-place adds.
    tmpsum = 0.0
    tmpsum_tol = 0.0
    tmax_f = (torch.sum(y * tols) - 1.0) / torch.sum(tols * tols)

    for i in sorted_idx[:-1]:
        tmpsum = tmpsum + y[i] * tols[i]  # accumulate from large to small
        tmpsum_tol = tmpsum_tol + tols[i] * tols[i]  # accumulate from large to small
        tmax = (tmpsum - 1.) / tmpsum_tol
        # Original author's note: this condition can basically never be met.
        if tols[i] * tmax > y[i]:
            tmax_f = tmax
            break

    return torch.max(y - tmax_f * tols, torch.zeros_like(y))



def find_min_norm_element_with_tol(vecs, tols):
    """Search for task weights by a QP warm start followed by gradient steps.

    Steps, as written:
      1. Fill ``dps`` with pairwise dot products of ``vecs`` via
         ``_min_norm_2d_with_tol`` and copy them into an ``n x n`` matrix.
      2. Call ``torch.optim.solve_qp`` with that matrix, ``tols`` as ``A``,
         ``b = [1]``, lower bound ``1 / (n + 1)`` and a uniform start.
      3. Repeat up to ``MAX_ITER`` times: move by ``lr = 0.001`` along
         ``grad_dir``; stop once the L1 change is below ``STOP_CRIT``.
      4. Pass the result through ``_projection2simplex_with_tol``.

    Diagnostics are printed on every iteration and before returning.

    Args:
        vecs: Sequence of gradient vectors (one per task).
        tols: Tensor of per-task tolerances; used in dot products with the
            float64 weight vector.

    Returns:
        The weight vector returned by ``_projection2simplex_with_tol``.

    NOTE(review): ``MAX_ITER`` and ``STOP_CRIT`` are not defined in the
    visible part of this file -- confirm where they come from.
    """
    dps = {}
    # `init_sol` is not used below; the call is needed for its side effect of
    # populating `dps`.
    init_sol, dps = _min_norm_2d_with_tol(vecs, dps, tols)

    n = len(vecs)
    iter_count = 0

    # NOTE(review): no dtype is given here, while the vectors below are
    # float64 -- confirm torch.matmul accepts the mix in your setup.
    grad_mat = torch.zeros((n,n))
    for i in range(n):
        for j in range(n):
            grad_mat[i,j] = dps[(i, j)]
    sol_vec = torch.ones([n], dtype=torch.float64) / n
    P = grad_mat
    A = tols
    q = torch.zeros([n], dtype=torch.float64)
    b = torch.tensor([1.], dtype=torch.float64)
    lb = (1/(n+1))*torch.ones([n], dtype=torch.float64)
    # NOTE(review): I am not aware of a `solve_qp` in `torch.optim`; the
    # keyword names resemble those of a stand-alone QP-solver package --
    # confirm where this function is supposed to come from.
    sol_vec = torch.optim.solve_qp(P=P, q=q, A=A, b=b, lb=lb, initvals=sol_vec)
    lr = 0.001

    while iter_count < MAX_ITER:
        # Step direction: negative gradient of the quadratic form minus a term
        # built from <sol_vec, tols>.
        # NOTE(review): the gradient of a penalty 50 * (<sol, tols> - 1)**2
        # would be 100 * (<sol, tols> - 1) * tols; here the 1 is subtracted
        # after multiplying by tols -- confirm this is intended.
        grad_dir = -1.0*torch.matmul(grad_mat, sol_vec) - 100*(torch.dot(sol_vec, tols) * tols - 1)
        print(iter_count, "- 1 gamma*norm", torch.matmul(torch.matmul(sol_vec, grad_mat), sol_vec))
        new_point = _next_point_with_tol_v2(sol_vec, grad_dir, n, tols, lr)
        print(new_point, tols, torch.dot(new_point, tols))
        print(iter_count, "- 2 gamma*norm", torch.matmul(torch.matmul(new_point, grad_mat), new_point))
        new_sol_vec = new_point
        change = new_sol_vec - sol_vec
        if torch.sum(torch.abs(change)) < STOP_CRIT:
            # Converged: the previous point (not `new_point`) is projected and
            # returned.
            sol_vec =  _projection2simplex_with_tol(sol_vec, tols)
            print(sol_vec, tols, torch.dot(sol_vec, tols))
            print(iter_count, "Mapping: gamma*norm", torch.matmul(torch.matmul(sol_vec, grad_mat), sol_vec))
            print('-------------')
            return sol_vec
        sol_vec = new_sol_vec
        iter_count += 1 # delete this line for unlimited optimization
    # Iteration budget exhausted: project the last iterate.
    sol_vec =  _projection2simplex_with_tol(sol_vec, tols)
    print(iter_count, "Mapping: gamma*norm", torch.matmul(torch.matmul(sol_vec, grad_mat), sol_vec))
    print('-------------')
    return sol_vec

# In [None]:
def get_mem(m, gradients, decay=0.9, norm_weight=0.1):
    """Update running gradient-norm statistics, one value per gradient entry.

    For every index ``i`` the new statistic is
    ``decay * m[i] + norm_weight * norm(gradients[i])`` where ``norm`` is
    ``numpy.linalg.norm`` with its default arguments applied to
    ``np.array(gradients[i])``.

    Args:
        m: Indexable of previous statistics; ``m[i]`` is read for every index
            of ``gradients``.
        gradients: Sequence whose entries can be converted with ``np.array``.
        decay: Weight of the previous statistic. Defaults to 0.9, the value
            that used to be hard-coded.
        norm_weight: Weight of the fresh gradient norm. Defaults to 0.1, the
            value that used to be hard-coded.

    Returns:
        list: The updated statistics, in the order of ``gradients``. ``m`` is
        not modified.
    """
    # Imported locally because the visible part of the file never imports
    # numpy; harmless if a module-level import exists as well.
    import numpy as np

    return [
        decay * m[i] + norm_weight * np.linalg.norm(np.array(grad))
        for i, grad in enumerate(gradients)
    ]