In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pdb
import numpy as np
import copy
import random


class PCGrad:
    """Optimizer wrapper applying PCGrad gradient surgery for multi-task losses.

    Each objective is back-propagated separately. Whenever two task gradients
    conflict (negative dot product), the component of one along the other is
    projected out. The projected gradients are then merged and written to
    ``p.grad`` for every parameter of the wrapped optimizer, so a subsequent
    ``step()`` applies them.

    Args:
        optimizer: wrapped optimizer; its ``param_groups`` define the
            parameters (and their order) that gradients are computed for.
        reduction: how to merge gradient entries that *every* objective
            produced a gradient for: ``"mean"`` (default) or ``"sum"``.
            Entries that only some objectives reach are always summed.
    """

    _REDUCTIONS = ("mean", "sum")

    def __init__(self, optimizer, reduction="mean"):
        self._optim, self._reduction = optimizer, reduction

    @property
    def optimizer(self):
        """The wrapped optimizer."""
        return self._optim

    def zero_grad(self):
        """Clear the gradients of the parameters (``p.grad`` is set to ``None``)."""
        return self._optim.zero_grad(set_to_none=True)

    def step(self):
        """Update the parameters with the gradients currently stored in ``p.grad``."""
        return self._optim.step()

    def pc_backward(self, objectives):
        """Compute PCGrad-projected gradients and store them in ``p.grad``.

        Args:
            objectives: iterable of scalar loss tensors, one per task. Each is
                back-propagated with ``retain_graph=True``, so the losses may
                share parts of a single autograd graph.

        Raises:
            ValueError: if the reduction is not ``"mean"``/``"sum"`` or if
                ``objectives`` is empty.
        """
        # Fail fast, before paying for one backward pass per objective.
        if self._reduction not in self._REDUCTIONS:
            raise ValueError(
                f"invalid reduction method: {self._reduction!r} "
                f"(expected one of {self._REDUCTIONS})"
            )
        grads, shapes, has_grads = self._pack_grad(objectives)
        if not grads:
            raise ValueError("pc_backward() requires at least one objective")
        pc_grad = self._project_conflicting(grads, has_grads)
        pc_grad = self._unflatten_grad(pc_grad, shapes[0])
        self._set_grad(pc_grad)

    def _project_conflicting(self, grads, has_grads, shapes=None):
        """Remove pairwise gradient conflicts and merge the task gradients.

        Args:
            grads: list of flattened per-task gradients (1-D tensors with the
                same layout). The list itself is not modified.
            has_grads: list of flattened 0/1 masks with the same layout as
                ``grads``; 1 where that task produced a gradient.
            shapes: unused; kept for backward compatibility.

        Returns:
            One flattened gradient tensor. Entries shared by all tasks are
            merged with ``self._reduction``; the remaining entries are summed.

        Raises:
            ValueError: if ``self._reduction`` is not ``"mean"``/``"sum"``.
        """
        # True where every task produced a gradient for that entry.
        shared = torch.stack(has_grads).prod(0).bool()
        pc_grad = [g.clone() for g in grads]
        # Shuffle a private copy of the list so the caller's list is untouched.
        others = list(grads)
        for g_i in pc_grad:
            random.shuffle(others)
            for g_j in others:
                g_i_g_j = torch.dot(g_i, g_j)
                if g_i_g_j < 0:
                    # Project g_i onto the normal plane of g_j (in place on the clone).
                    # A negative dot product implies g_j is non-zero, so no division by 0.
                    g_i -= g_i_g_j * g_j / (g_j.norm() ** 2)

        stacked = torch.stack(pc_grad)
        merged_grad = torch.zeros_like(grads[0])
        if self._reduction == "mean":
            merged_grad[shared] = stacked[:, shared].mean(dim=0)
        elif self._reduction == "sum":
            merged_grad[shared] = stacked[:, shared].sum(dim=0)
        else:
            raise ValueError(
                f"invalid reduction method: {self._reduction!r} "
                f"(expected one of {self._REDUCTIONS})"
            )

        merged_grad[~shared] = stacked[:, ~shared].sum(dim=0)
        return merged_grad

    def _set_grad(self, grads):
        """Assign ``grads`` (one tensor per parameter, in ``param_groups`` order) to ``p.grad``."""
        idx = 0
        for group in self._optim.param_groups:
            for p in group["params"]:
                p.grad = grads[idx]
                idx += 1

    def _pack_grad(self, objectives):
        """Back-propagate each objective separately and collect its gradients.

        Gradients are cleared before every backward pass, so each entry
        reflects a single objective.

        Returns:
            grads: list of flattened gradients, one per objective.
            shapes: list (per objective) of the parameter shapes.
            has_grads: list of flattened 0/1 masks marking which entries
                received a gradient from that objective.
        """
        grads, shapes, has_grads = [], [], []
        for obj in objectives:
            self._optim.zero_grad(set_to_none=True)
            obj.backward(retain_graph=True)
            grad, shape, has_grad = self._retrieve_grad()
            grads.append(self._flatten_grad(grad, shape))
            has_grads.append(self._flatten_grad(has_grad, shape))
            shapes.append(shape)
        return grads, shapes, has_grads

    def _unflatten_grad(self, grads, shapes):
        """Split a flattened gradient back into one tensor per parameter shape."""
        unflatten_grad, idx = [], 0
        for shape in shapes:
            # int(): np.prod(()) is the float 1.0 for 0-dim parameters, and a
            # float is not a valid slice index.
            length = int(np.prod(shape))
            unflatten_grad.append(grads[idx : idx + length].view(shape).clone())
            idx += length
        return unflatten_grad

    def _flatten_grad(self, grads, shapes):
        """Concatenate a list of tensors into one 1-D tensor (``shapes`` is unused)."""
        return torch.cat([g.flatten() for g in grads])

    def _retrieve_grad(self):
        """Read the current ``p.grad`` of every optimizer parameter.

        Parameters without a gradient (e.g. another task's head in a
        multi-head model) get a zero gradient and a zero mask.

        Returns:
            grad: list of gradient tensors (clones, or zeros).
            shape: list of parameter shapes.
            has_grad: list of masks, ones where ``p.grad`` existed, else zeros.
        """
        grad, shape, has_grad = [], [], []
        for group in self._optim.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    shape.append(p.shape)
                    grad.append(torch.zeros_like(p))
                    has_grad.append(torch.zeros_like(p))
                    continue
                shape.append(p.grad.shape)
                grad.append(p.grad.clone())
                has_grad.append(torch.ones_like(p))
        return grad, shape, has_grad


class TestNet(nn.Module):
    """Minimal fully shared model: a single ``nn.Linear(3, 4)`` layer."""

    def __init__(self):
        super().__init__()
        # The attribute name determines the parameter names
        # ("_linear.weight", "_linear.bias").
        self._linear = nn.Linear(in_features=3, out_features=4)

    def forward(self, x):
        """Apply the linear layer to ``x`` (last dimension of size 3)."""
        out = self._linear(x)
        return out


class MultiHeadTestNet(nn.Module):
    """Two-task model: a shared ``Linear(3, 2)`` trunk feeding two ``Linear(2, 4)`` heads."""

    def __init__(self):
        super().__init__()
        # Keep the construction order (trunk, head1, head2) so that seeded
        # initialisation yields the same weights.
        self._linear = nn.Linear(in_features=3, out_features=2)
        self._head1 = nn.Linear(in_features=2, out_features=4)
        self._head2 = nn.Linear(in_features=2, out_features=4)

    def forward(self, x):
        """Return ``(head1(trunk(x)), head2(trunk(x)))``; the trunk is evaluated once."""
        shared = self._linear(x)
        return tuple(head(shared) for head in (self._head1, self._head2))

def assert_matches_joint_backward(net, losses, shared_prefix="_linear."):
    """Check PCGrad's ``p.grad`` against a plain backward of ``sum(losses)``.

    Only meaningful when the task gradients do not conflict, because PCGrad
    then performs no projection. Parameters whose name starts with
    ``shared_prefix`` are expected to be reached by every loss. PCGrad
    averages those (default "mean" reduction), so they must equal the joint
    gradient divided by ``len(losses)``. Every other parameter is summed, so
    it must equal the joint gradient itself.

    ``p.grad`` is not modified: the reference uses ``torch.autograd.grad``.
    """
    params = list(net.named_parameters())
    # The graph is still alive because pc_backward used retain_graph=True.
    joint = torch.autograd.grad(sum(losses), [p for _, p in params])
    for (name, p), g in zip(params, joint):
        scale = len(losses) if name.startswith(shared_prefix) else 1
        assert torch.allclose(p.grad, g / scale, atol=1e-6), name


# fully shared network test: both losses use the same output, so every
# parameter is shared and receives the mean of the two task gradients.
torch.manual_seed(4)
x, y = torch.randn(2, 3), torch.randn(2, 4)
net = TestNet()
y_pred = net(x)
pc_adam = PCGrad(optim.Adam(net.parameters()))
pc_adam.zero_grad()
loss1_fn, loss2_fn = nn.L1Loss(), nn.MSELoss()
loss1, loss2 = loss1_fn(y_pred, y), loss2_fn(y_pred, y)
pc_adam.pc_backward([loss1, loss2])
for p in net.parameters():
    print(p.grad)
# For this seed the task gradients do not conflict (the printed gradients are
# exactly half of the joint backward), so the check below holds.
assert_matches_joint_backward(net, [loss1, loss2])
print("-" * 80)

# separate heads test: the `_linear` trunk is shared (mean of task gradients),
# each head is reached by one loss only (gradient passed through unchanged).
torch.manual_seed(4)
x, y = torch.randn(2, 3), torch.randn(2, 4)
net = MultiHeadTestNet()
y_pred_1, y_pred_2 = net(x)
pc_adam = PCGrad(optim.Adam(net.parameters()))
pc_adam.zero_grad()
loss1_fn, loss2_fn = nn.MSELoss(), nn.MSELoss()
loss1, loss2 = loss1_fn(y_pred_1, y), loss2_fn(y_pred_2, y)
pc_adam.pc_backward([loss1, loss2])
for p in net.parameters():
    print(p.grad)
assert_matches_joint_backward(net, [loss1, loss2])

tensor([[-0.5955, -0.1214,  0.7084],
        [-0.0214, -0.3130, -0.1566],
        [ 0.5900,  0.1193, -0.7024],
        [ 0.4459,  0.0869, -0.5328]])
tensor([ 0.1312, -0.3516, -0.1312, -0.1028])
--------------------------------------------------------------------------------
tensor([[ 0.1543,  0.1705, -0.1015],
        [-0.4761, -0.3491,  0.4177]])
tensor([ 0.1266, -0.1860])
tensor([[-0.0333,  0.0969],
        [-0.1064,  0.0918],
        [-0.0381, -0.0076],
        [-0.0717,  0.0192]])
tensor([-0.1726, -0.2270, -0.0211, -0.0894])
tensor([[ 0.1493,  0.0015],
        [-0.3189,  0.1969],
        [ 0.0228, -0.0440],
        [-0.1457,  0.0503]])
tensor([ 0.1246, -0.5640,  0.0848, -0.1987])


In [34]:
x = torch.randint(0, 2, size=(2, 3))
y = torch.randint(0, 2, size=(2, 3))
# Not the last expression of the cell, so Jupyter does not display this tuple.
x, (2 * x).sum()

a = [[1, 2]]
a += [[3, 4]]
a


[[1, 2], [3, 4]]

In [5]:
import torch
import torch.nn.functional as F

num_classes = 4
logits = torch.randn(3, 5, num_classes)
labels = torch.randint(0, num_classes, (3, 5))
# cross_entropy is given (N, C) logits and (N,) class indices here, so the
# two leading dimensions are merged into one.
flat_logits, flat_labels = logits.view(-1, num_classes), labels.view(-1)
loss = F.cross_entropy(flat_logits, flat_labels)
print(flat_logits.shape, flat_labels.shape)
loss

torch.Size([15, 4]) torch.Size([15])


tensor(1.7895)