<a href="https://colab.research.google.com/github/yingzibu/MOL2ADMET/blob/main/reference/MultiTaskLoss.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim

class MultiTaskLoss(nn.Module):
    """Combine per-task losses using learnable per-task weights ``eta``.

    For task losses ``L_i`` the combined objective is
    ``sum_i(L_i * exp(-eta_i) + eta_i)``. ``eta`` is registered as an
    ``nn.Parameter``, so it is returned by ``parameters()`` and trained by
    any optimizer built from it (together with ``model``'s parameters when
    ``model`` is an ``nn.Module``).

    Args:
        model: callable mapping the input to one output per task.
        loss_fn: sized, re-iterable collection of loss callables, one per task.
        eta: initial values of the per-task weights, one entry per task.

    Raises:
        ValueError: if ``eta`` is not one-dimensional with exactly one entry
            per loss function.
    """

    def __init__(self, model, loss_fn, eta) -> None:
        super().__init__()
        self.model = model
        self.loss_fn = loss_fn
        # ``torch.as_tensor`` interprets ``eta`` as *data* (the legacy
        # ``torch.Tensor(...)`` constructor treats an int as a size and
        # returns uninitialised memory).
        eta_init = torch.as_tensor(eta, dtype=torch.get_default_dtype())
        if eta_init.dim() != 1 or eta_init.numel() != len(loss_fn):
            raise ValueError(
                f"eta must be 1-D with one entry per loss function: got shape "
                f"{tuple(eta_init.shape)} for {len(loss_fn)} loss function(s)"
            )
        # Clone so the parameter never shares storage with a caller's tensor.
        self.eta = nn.Parameter(eta_init.detach().clone())

    def forward(self, input, targets) -> "tuple[list[torch.Tensor], torch.Tensor]":
        """Run the model and return the per-task losses and their weighted sum.

        Args:
            input: forwarded unchanged to ``self.model``.
            targets: one target per task, in the same order as ``loss_fn``.

        Returns:
            ``(losses, total)`` where ``losses`` is the list of unweighted
            per-task losses and ``total`` is the scalar combined loss.

        Raises:
            ValueError: if the number of model outputs or targets differs
                from the number of loss functions.
        """
        outputs = self.model(input)
        n_tasks = len(self.loss_fn)
        # ``zip`` would silently drop the surplus entries, and the shortened
        # loss vector would then broadcast against ``eta`` -- fail loudly.
        if len(outputs) != n_tasks or len(targets) != n_tasks:
            raise ValueError(
                f"expected {n_tasks} outputs and targets, got "
                f"{len(outputs)} output(s) and {len(targets)} target(s)"
            )
        loss = [fn(out, tgt) for fn, out, tgt in zip(self.loss_fn, outputs, targets)]
        # torch.stack requires every per-task loss to have the same shape
        # (scalars for the default reduction of the built-in losses).
        total_loss = torch.stack(loss) * torch.exp(-self.eta) + self.eta
        return loss, total_loss.sum()  # omit 1/2

class LinearModel(nn.Module):
    """Scale the input by the product of two learnable weights.

    The model holds a single parameter ``weight`` with two entries; its
    output is ``weight[0] * weight[1] * input``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Allocated uninitialised; reset_parameters() fills in the values.
        self.weight = nn.Parameter(torch.empty(2))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Draw both weights uniformly from [0, 0.5], then add 0.2 to the second."""
        nn.init.uniform_(self.weight, a=0.0, b=0.5)
        offset = torch.tensor([0.0, 0.2])
        # The in-place shift is initialisation, not part of the autograd graph.
        with torch.no_grad():
            self.weight.add_(offset)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``input`` multiplied by the product of the two weights."""
        scale = self.weight.prod()
        return scale * input

class MultiTaskModel(nn.Module):
    """Two independent ``LinearModel`` sub-models, one per task.

    ``f1`` is applied to ``input[0:1]`` and ``f2`` to ``input[1:2]``; the
    two sub-models share no parameters.
    """

    def __init__(self) -> None:
        super().__init__()
        self.f1 = LinearModel()
        self.f2 = LinearModel()

    def forward(self, input) -> "list[torch.Tensor]":
        """Return ``[f1(input[0:1]), f2(input[1:2])]``.

        Args:
            input: indexable along its first dimension; element 0 feeds the
                first task and element 1 the second. Slices (not integer
                indexing) are used, so each sub-model receives a length-1
                view rather than a 0-d value.

        Returns:
            A list with one output per task, in task order (not a single
            tensor).
        """
        return [self.f1(input[0:1]), self.f2(input[1:2])]

In [3]:
# Train the two-task toy model with the learnable-weight loss and log every
# step to <output_dir>/history.csv for the plotting cells.
output_dir = "linear_model"
os.makedirs(output_dir, exist_ok=True)

loss_fn1 = nn.MSELoss()
loss_fn2 = nn.MSELoss()

mtl = MultiTaskLoss(model=MultiTaskModel(),
                    loss_fn=[loss_fn1, loss_fn2],
                    eta=[1.0, 1.0])
# mtl.parameters() covers eta and both sub-models' weights (see the printed list).
optimizer = optim.SGD(mtl.parameters(), lr=0.1)
print(list(mtl.parameters()))

# The single training example never changes, so build it once: each task's
# target equals its own input element.
x = torch.Tensor([1.0, 0.5])
(y1, y2) = [torch.Tensor([v]) for v in x.tolist()]

# ``with`` guarantees the log is flushed and closed even if a step raises.
with open(os.path.join(output_dir, 'history.csv'), 'w') as history:
    history.write('step,total_loss,loss1,loss2,eta1,eta2,weight1_1,weight1_2,weight2_1,weight2_2\n')

    for i in range(50):
        optimizer.zero_grad()
        loss, total_loss = mtl(x, [y1, y2])

        # Snapshot the values *before* the optimizer step so that each row
        # pairs the losses with the parameters that produced them.
        eta_values = mtl.eta.tolist()
        w1 = mtl.model.f1.weight.tolist()
        w2 = mtl.model.f2.weight.tolist()
        print(i, total_loss.item(), [v.item() for v in loss], eta_values, w1, w2)

        row = [i, total_loss.item(), *(v.item() for v in loss), *eta_values, *w1, *w2]
        history.write(','.join(map(str, row)) + '\n')

        total_loss.backward()
        optimizer.step()


[Parameter containing:
tensor([1., 1.], requires_grad=True), Parameter containing:
tensor([0.4296, 0.4023], requires_grad=True), Parameter containing:
tensor([0.0145, 0.6686], requires_grad=True)]
0 2.341888427734375 [0.6841792464256287, 0.24517010152339935] [1.0, 1.0] [0.4296114146709442, 0.40233731269836426] [0.014518111944198608, 0.6686078310012817]
1 2.188544750213623 [0.6487130522727966, 0.24115124344825745] [0.9251695275306702, 0.9090192914009094] [0.4540970027446747, 0.4284827709197998] [0.02669708803296089, 0.6688722968101501]
2 2.033690929412842 [0.6079931259155273, 0.2368135303258896] [0.8508886694908142, 0.8187357187271118] [0.48146188259124756, 0.4574835002422333] [0.039931539446115494, 0.6694005131721497]
3 1.8762346506118774 [0.5615814328193665, 0.2321346551179886] [0.7768521308898926, 0.7291789054870605] [0.5119280815124512, 0.4895465672016144] [0.054296910762786865, 0.6702574491500854]
4 1.7149593830108643 [0.5092496275901794, 0.22709015011787415] [0.70267653465271, 0.6

In [None]:
import pandas as pd

# Load the per-step training log and plot the combined loss against the step.
history_path = os.path.join(output_dir, "history.csv")
df_hist = pd.read_csv(history_path)
df_hist.plot(
    x="step",
    y="total_loss",
)

In [None]:
# Both unweighted per-task losses against the step, on a logarithmic y-axis.
df_hist.plot(
    x="step",
    y=["loss1", "loss2"],
    logy=True,
)

In [None]:
# The two learnable loss weights (eta1, eta2) against the step.
df_hist.plot(
    x="step",
    y=["eta1", "eta2"],
)

In [None]:
# Trajectory of the logged weight1_1 / weight1_2 columns in the weight plane.
df_hist.plot(
    x="weight1_1",
    y="weight1_2",
    xlim=(-1, 2),
    ylim=(-1, 2),
)

In [None]:
# Trajectory of the logged weight2_1 / weight2_2 columns in the weight plane.
df_hist.plot(
    x="weight2_1",
    y="weight2_2",
    xlim=(-1, 2),
    ylim=(-1, 2),
)