In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data.dataset import Dataset
import torch.nn as nn
from copy import deepcopy
from torch import optim
from torch.nn.modules.module import Module
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
import cvxpy as cp

In [None]:
# Reproducibility: seed both PyTorch and NumPy (gen_data draws from NumPy's
# global RNG, the model init and samplers draw from PyTorch's).
seed = 2021
torch.manual_seed(seed)
np.random.seed(seed)
# Ask PyTorch for deterministic kernels; warn_only=True makes ops without a
# deterministic implementation emit a warning instead of raising.
torch.use_deterministic_algorithms(True, warn_only=True)

Init Solver

In [None]:
def lpsolver(prob, verbose=False):
  """Solve a CVXPY problem, falling back through several solvers.

  Solvers are tried in order: MOSEK, then ECOS_BB, and finally MOSEK's
  simplex optimizer (only when the optional `mosek` package is importable).
  The first solver that does not raise `cp.error.SolverError` wins.

  Args:
    prob: object exposing `solve(solver=..., verbose=..., **kwargs)`
      (a `cp.Problem` in this notebook).
    verbose: forwarded to every `prob.solve` call.

  Returns:
    Whatever `prob.solve` returns for the first solver that succeeds.

  Raises:
    cp.error.SolverError: if every solver attempt failed.
  """
  solvers = [cp.MOSEK, cp.ECOS_BB]
  for s in solvers:
    try:
      return prob.solve(solver=s, verbose=verbose)
    except cp.error.SolverError:
      print('==> Solver Error')

  # Last resort: MOSEK's simplex optimizer. `mosek` is an optional dependency
  # that is not imported at the top of the notebook, so import it lazily here;
  # without this guard the fallback died with NameError instead of reaching
  # the final SolverError below.
  try:
    import mosek
  except ImportError:
    mosek = None

  if mosek is not None:
    try:
      return prob.solve(solver=cp.MOSEK,
                        mosek_params={mosek.iparam.optimizer: mosek.optimizertype.free_simplex},
                        bfs=True, verbose=verbose)
    except cp.error.SolverError:
      print('==> Solver Error')

  raise cp.error.SolverError('All solvers failed.')

Make Dataset

In [None]:
class MyDataset(Dataset):
  """Map-style dataset that pairs a feature container with a label container.

  Indexing returns `(X[item], y[item])`; the length is `len(X)`.
  """

  def __init__(self, X, y):
    super().__init__()
    # `attr` is an alias of the features (same object, not a copy).
    self.X = self.attr = X
    self.y = y

  def __len__(self):
    return len(self.X)

  def __getitem__(self, item):
    features, label = self.X[item], self.y[item]
    return features, label

def gen_data(n):
  """Sample a 2-D "disc vs. ring" binary classification problem.

  Draws `n` points from a standard 2-D Gaussian (NumPy global RNG), then
  randomly drops about half of the points whose norm lies in the band
  (sqrt(2) / 1.3, 1.3 * sqrt(2)) around the class boundary. A point is
  labelled 1 when its norm is below sqrt(2), else 0.

  Args:
    n: number of points drawn before thinning; fewer are returned.

  Returns:
    (X, Y, pop_masks): float32 features of shape (m, 2), int64 labels of
    shape (m,), and an (always empty) list of population masks.
  """
  X = np.random.multivariate_normal([0, 0], [[1, 0], [0, 1]], size=n)
  thr = 2 ** 0.5
  norm = np.linalg.norm(X, axis=1)
  # candidates for removal: points in the band around the decision boundary
  to_rem = (norm < 1.3 * thr) & (norm > thr / 1.3)
  # subsample: only drop each candidate with probability 0.5
  to_rem = to_rem & (np.random.rand(n) < 0.5)
  keep = ~to_rem
  X = X[keep]
  Y = norm[keep] < thr
  pop_masks = []
  # np.long was removed in NumPy 1.24 (and is platform-dependent where it
  # exists); int64 is the label dtype nn.CrossEntropyLoss expects.
  return X.astype(np.float32), Y.astype(np.int64), pop_masks

# Draw two independent samples (5000 points each before thinning).
X_train, y_train, pop_masks_train = gen_data(5000)
X_test, y_test, pop_masks_test = gen_data(5000)
# print(X_train.shape, y_train.shape)
# NOTE: dataset_test_train wraps the same arrays as dataset_train, and
# dataset_valid wraps the same arrays as dataset_test (no separate validation split).
dataset_train = MyDataset(X_train, y_train)
dataset_test_train = MyDataset(X_train, y_train)
dataset_valid = MyDataset(X_test, y_test)
dataset_test = MyDataset(X_test, y_test)

In [None]:
# Scatter plot of the training points, one colour per class label
plt.figure(figsize=(5, 5))
plt.scatter(X_train[y_train == 0, 0], X_train[y_train == 0, 1], label='class 0', s=5)
plt.scatter(X_train[y_train == 1, 0], X_train[y_train == 1, 1], label='class 1', s=5)
plt.legend()

In [None]:
# pop_masks_train[0].sum(), pop_masks_train[1].sum(), pop_masks_train[2].sum(), pop_masks_train[3].sum(), pop_masks_train[4].sum()

In [None]:
# Sanity check: shapes of the training features and labels
X_train.shape, y_train.shape

Make Model

In [None]:
# model = nn.Sequential(nn.Linear(2, 2, bias=True))
# Small MLP: 2 inputs -> 4 hidden units (ReLU) -> 2 class logits
model = nn.Sequential(nn.Linear(2, 4, bias=True), nn.ReLU(), nn.Linear(4, 2, bias=True))
# reduction='none' keeps one loss value per sample; callers reduce it themselves
loss = nn.CrossEntropyLoss(reduction='none')

ERM

In [None]:
def pgd_attack(model, inputs, targets, criterion, eps=0.5, iters=1):
  """Return sign-gradient adversarial versions of `inputs`.

  Each iteration takes the gradient of the mean loss at the current
  adversarial point and sets the result to
  `clamp(original_inputs + eps * sign(grad), -3, 3)`. The step is always
  applied to the *original* inputs, so with the default `iters=1` this is a
  single FGSM-style step.

  Unlike a plain `loss.backward()`, the gradient is taken with
  `torch.autograd.grad` on a detached copy, so the call neither toggles
  `requires_grad` on the caller's tensor nor accumulates gradients into the
  model parameters.

  Args:
    model: callable mapping inputs to logits.
    inputs: input tensor; left unmodified.
    targets: labels passed to `criterion`.
    criterion: per-sample loss; its output is averaged with `.mean()`.
    eps: perturbation magnitude per coordinate.
    iters: number of gradient evaluations; 0 returns `inputs` unchanged.

  Returns:
    Tensor of perturbed inputs clamped to [-3, 3] (no autograd history).
  """
  ori_inputs = inputs.detach()
  adv_inputs = inputs
  for _ in range(iters):
    # differentiate w.r.t. a fresh leaf so no caller-visible state is touched
    probe = adv_inputs.detach().clone().requires_grad_(True)
    loss = criterion(model(probe), targets).mean()
    (grad,) = torch.autograd.grad(loss, probe)
    delta = eps * torch.sign(grad)
    adv_inputs = torch.clamp(ori_inputs + delta, min=-3, max=3)
  return adv_inputs

def erm(model: Module, loader: DataLoader, optimizer: optim.Optimizer, criterion, device: str, iters=0):
  """Run one training pass of risk minimisation on perturbed inputs.

  Every batch is moved to `device`, perturbed with `pgd_attack`, and used for
  one optimizer step on the mean of `criterion`. The pass stops after `iters`
  batches; since batches are counted from 1, `iters=0` never matches and the
  whole loader is consumed.
  """
  model.train()
  for step, (inputs, targets) in enumerate(loader, start=1):
    inputs = inputs.to(device)
    targets = targets.to(device)
    # train on the adversarially perturbed batch rather than the clean one
    adv_inputs = pgd_attack(model, inputs, targets, criterion)
    batch_loss = criterion(model(adv_inputs), targets).mean()
    optimizer.zero_grad()
    batch_loss.backward()
    optimizer.step()
    if step == iters:
      break

def test(model: Module, loader: DataLoader, criterion, device: str):
  """Test the avg and group acc of the model"""

  model.eval()
  total_correct = 0
  total_loss = 0
  total_num = 0
  l_rec = []
  c_rec = []

  with torch.no_grad():
    for _, (inputs, targets) in enumerate(loader):
      inputs, targets = inputs.to(device), targets.to(device)
      labels = targets
      outputs = model(inputs)
      predictions = torch.argmax(outputs, dim=1)
      c = (predictions == labels)
      c_rec.append(c.detach().cpu().numpy())
      correct = c.sum().item()
      l = criterion(outputs, labels).view(-1)
      l_rec.append(l.detach().cpu().numpy())
      loss = l.sum().item()
      total_correct += correct
      total_loss += loss
      total_num += len(inputs)
  l_vec = np.concatenate(l_rec)
  c_vec = np.concatenate(c_rec)
  return total_correct / total_num, total_loss / total_num, c_vec, l_vec

In [None]:
# model = nn.Sequential(nn.Linear(2, 4, bias=True), nn.ReLU(), nn.Linear(4, 2, bias=True))

# batch_size = 512
# iters = 1000

# optimizer = optim.AdamW(model.parameters(), lr=0.01, weight_decay=1e-5)
# trainloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
# for i in range(50):
#     erm(model, trainloader, optimizer, loss, 'cpu', iters=iters)
#     acc, _, c_vec, _ = test(model, trainloader, loss, 'cpu')
#     print("Train Acc: ", acc, end=', ')
# acc, _, c_vec, _ = test(model, testloader, loss, 'cpu')
# print("Test Loss: ", 1 - acc)


RAI GAME

In [None]:
def raigame(P, l_all, eta, constraints_w=['chi2'], alpha=0.7, group_idx=pop_masks_train, verbose=False):
    """Solve the adversary's best response over sample weights `w`.

    Maximises `P @ (l_all @ w) + eta * sum(entr(w))` over the probability
    simplex (`sum(w) == 1`, `w >= 0`), subject to the uncertainty-set
    constraints named in `constraints_w`:

    * 'cvar':      `w <= 1 / (alpha * n)`
    * 'chi2':      `sum_squares(w) <= 1 / (alpha * n)`
    * 'group_dro': equal weights inside each group mask of `group_idx`, plus an
                   element-wise `entr(group mass) >= 0.1` constraint.

    Args:
        P: mixture weights over models; callers in this notebook pass a
            (1, T) array.
        l_all: 2-D loss matrix with one row per model and one column per sample.
        eta: weight of the entropy regulariser (0 disables its effect).
        constraints_w: list of constraint names from the set above. The default
            list is never mutated.
        alpha: size parameter of the 'cvar' / 'chi2' sets.
        group_idx: iterable of per-group boolean masks (used by 'group_dro').
            NOTE: the default is bound to `pop_masks_train` when this cell is
            executed; re-run the cell after regenerating the data.
        verbose: forwarded to the solver.

    Returns:
        (w_star, result): the optimal weights, clipped at 0 and renormalised to
        sum to 1, and the value returned by `lpsolver`.

    Raises:
        NotImplementedError: if `constraints_w` contains an unsupported entry.
        cp.error.SolverError: if the solver produced no solution for `w`.
    """
    _, n = l_all.shape
    w = cp.Variable(n)

    objective = cp.Maximize(P @ (l_all @ w) + eta * cp.sum(cp.entr(w)))
    constraints = [cp.sum(w) == 1, 0 <= w]

    # nc counts the requested constraints that have not been recognised yet
    nc = len(constraints_w)
    if 'cvar' in constraints_w:
        m = alpha * n
        constraints.append(w <= 1 / m)
        nc -= 1
    if 'chi2' in constraints_w:
        m = alpha * n
        constraints.append(cp.sum_squares(w) <= 1 / m)
        nc -= 1
    if 'group_dro' in constraints_w:
        for idx in group_idx:
            # tie every member of the group to the group's first member
            locs = np.where(idx)[0]
            for i in locs:
                constraints.append(w[i] == w[locs[0]])
        constraints.append(cp.entr(cp.vstack([cp.sum(w[idx]) for idx in group_idx])) >= 0.1)
        nc -= 1
    if nc > 0:
        print('Constraint {} not implemented'.format(constraints_w))
        raise NotImplementedError

    prob = cp.Problem(objective, constraints)
    # honour the caller's verbosity (it used to be hard-coded to False)
    result = lpsolver(prob, verbose=verbose)

    if w.value is None:
        # e.g. infeasible/unbounded problem: fail with a clear message rather
        # than a TypeError on the comparison below
        raise cp.error.SolverError(
            'raigame: solver returned no solution (status: {})'.format(prob.status))

    # work on a copy so the CVXPY variable's stored value is left untouched
    w_star = np.array(w.value, dtype=float)
    w_star[w_star < 0] = 0
    w_star /= w_star.sum()
    return w_star, result

Train

In [None]:
# ---- Hyper-parameters and bookkeeping for the RAI game ----
T = 5                   # number of game rounds (one model trained per round)
iters_per_epoch = 3000  # optimizer steps per round
batch_size = 32
w_all = np.zeros((0, len(dataset_train)))  # one row of sample weights per round
l_all = np.zeros((0, len(dataset_train)))  # one row of per-sample 0-1 losses per trained model
P = np.zeros((1, 0))                       # mixture weights over the models trained so far
n = len(dataset_train)
# initial entropy-regularisation strength
eta = (np.log(T) / (2 * T * np.log(n))) ** 0.5
print(eta, eta * np.log(n))

# every round restarts from the same initial weights
init_state_dict = deepcopy(model.state_dict())
# round 0 samples uniformly
w_all = np.concatenate([w_all, np.ones((1, n)) / n])

# unshuffled loader so per-sample results line up with dataset order
test_trainloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=False) # for testing
prev_gamevalue = np.inf
best_gamevalue, best_P = np.inf, None
models = []  # state dict of the model trained in each round
for t in range(T):
    # set up the model: reset to the shared initialisation, fresh optimizer
    model.load_state_dict(init_state_dict)
    # optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.01, weight_decay=1e-2)

    # set up the loader: draw samples with replacement according to the latest weights
    w_t = w_all[-1, :]
    sampler = WeightedRandomSampler(w_t, iters_per_epoch * batch_size, replacement=True)
    trainloader = DataLoader(dataset_train, batch_size=batch_size, sampler=sampler, num_workers=0, pin_memory=False)

    # train this round's model h_t and record its per-sample losses l_t
    erm(model, trainloader, optimizer, loss, 'cpu', iters_per_epoch)
    models.append(deepcopy(model.state_dict()))
    # c_t: per-sample correctness; ce_t (per-sample cross-entropy) is not used below
    _, _, c_t, ce_t = test(model, test_trainloader, loss, 'cpu')
    l_t = 1 - c_t # get 0-1 loss
    print('t={}, loss={}'.format(t, l_t.mean()))
    l_all = np.concatenate([l_all, l_t.reshape(1, -1)], axis=0)

    # line search for the mixing coefficient a of the new model:
    # try 20 multiples of 1 / (t + 1) in [0.5, 1.5] and keep the one with the lowest game value
    base_a = 1 / (t + 1)
    best_value = np.inf
    best_a =  1 / (t + 1)
    for m_a in np.linspace(0.5, 1.5, 20):
        a = base_a * m_a
        # a >= 1 would leave no mass on the earlier models
        if a >= 1:
            continue
        P_temp = np.concatenate([(1 - a) * P, a * np.ones((1, 1))], axis=1)
        P_temp = P_temp / P_temp.sum()
        _, value = raigame(P_temp, l_all, 0)
        # skip candidates whose game value is not a finite-comparable number (inf / nan)
        if not value < np.inf:
            continue
        if value < best_value:
            best_value = value
            best_a = a
    # commit the chosen coefficient and renormalise the mixture
    P = np.concatenate([(1 - best_a) * P, best_a * np.ones((1, 1))], axis=1)
    P = P / P.sum()

    # get new game value (unregularised: eta = 0)
    _, gamevalue = raigame(P, l_all, 0)
    # NOTE: w_t here is still the sampling distribution used for this round's training
    print('t = {}, gamevalue = {}, max_w_star = {}'.format(t + 1, gamevalue, w_t.max()))

    # update best gamevalue
    if gamevalue < best_gamevalue:
        best_gamevalue = gamevalue
        best_P = P

    # update eta if gamevalue increases
    if prev_gamevalue < gamevalue:
        eta = eta * 2
    prev_gamevalue = gamevalue

    # get w_t for the next round: regularised best response; L_P is not used afterwards
    w_t, L_P = raigame(P, l_all, eta)
    # print(w_t[pop_masks_train[0]].sum(), w_t[pop_masks_train[1]].sum(), w_t[pop_masks_train[2]].sum(), w_t[pop_masks_train[3]].sum(), w_t[pop_masks_train[4]].sum())
    w_all = np.concatenate([w_all, w_t.reshape(1, -1)], axis=0)

In [None]:
# best_P = np.concatenate([best_P, np.zeros((1, T - best_P.shape[1]))], axis=1)
# P = best_P

In [None]:
# Final mixture weights over the trained models (one column per round)
P

In [None]:
# models[1]

In [None]:
# Mean 0-1 training loss of each round's model
l_all.mean(axis=1)

In [None]:
# plot data and all models with weights as width of line
# plt.figure(figsize=(5, 5))
# plt.scatter(X_train[y_train == 0, 0], X_train[y_train == 0, 1], label='class 0', s=5)
# plt.scatter(X_train[y_train == 1, 0], X_train[y_train == 1, 1], label='class 1', s=5)

# for t in range(T):
#     params = models[t]
#     weight, bias = params['0.weight'].cpu().numpy(), params['0.bias'].cpu().numpy()
#     w_t = P[0, t]
#     x = np.linspace(-6, 6, 100)
#     y = (-bias[0] - weight[0, 0] * x) / weight[0, 1]
#     model.load_state_dict(params)
#     print(model(torch.tensor([[-6., 6.]])).argmax())
#     plt.plot(x, y, label='model {}'.format(t), linewidth=w_t * 10)
# plt.legend()
# plt.ylim(-6, 6)

In [None]:
# Mixture weights over the trained models, shown again before plotting
P

In [None]:
# plot decision boundary of derandomized model
def predict(x, P, models):
    P = np.array(P).reshape(1, -1)
    probs = []
    for t in range(len(models)):
        model.load_state_dict(models[t])
        probs.append(model(x).detach().cpu().argmax(dim=1, keepdim=True))
    probs = torch.cat(probs, dim=1)
    preds = (probs @ P.T)[:, 0]
    print(preds)
    return preds > 0.5

# Evaluation grid covering [-6, 6) x [-6, 6) with step 0.05
x_min, x_max = -6, 6
y_min, y_max = -6, 6
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.05),
                        np.arange(y_min, y_max, 0.05))
# ERM baseline: all mixture weight on the first model (trained with uniform sampling)
pred = predict(torch.tensor(np.c_[xx.ravel(), yy.ravel()]).float(), [1] + [0] * (T - 1), models)
# pred = predict(torch.tensor(np.c_[xx.ravel(), yy.ravel()]).float(), P, models)
print(pred.shape, pred.sum())

import seaborn as sns
sns.set_style('whitegrid')
# Put the result into a color plot
# plt.figure(figsize=(5, 5))
plt.scatter(X_train[y_train == 0, 0], X_train[y_train == 0, 1], label='class 0', s=5)
plt.scatter(X_train[y_train == 1, 0], X_train[y_train == 1, 1], label='class 1', s=5)
# Shade the class densities, then draw the decision boundaries as unfilled contour lines.
# `shade=` / `shade_lowest=` are deprecated in seaborn; `fill=True` with the default
# `thresh` is the equivalent of `shade=True, shade_lowest=False`.
sns.kdeplot(x=X_train[y_train == 0, 0], y=X_train[y_train == 0, 1], cmap="Blues", fill=True, alpha=0.7)
sns.kdeplot(x=X_train[y_train == 1, 0], y=X_train[y_train == 1, 1], cmap="Oranges", fill=True, alpha=0.3)
plt.contour(xx, yy, pred.reshape(xx.shape), colors='tab:red', linewidths=0.3, alpha=0.7)
# plt.contourf(xx, yy, pred.reshape(xx.shape), cmap=plt.cm.coolwarm, alpha=0.5)

# RAI mixture: weighted vote of all trained models using the learned P
pred = predict(torch.tensor(np.c_[xx.ravel(), yy.ravel()]).float(), P, models)
print(pred.shape, pred.sum())
# only plot the decision boundary
plt.contour(xx, yy, pred.reshape(xx.shape), colors='tab:green', linewidths=0.3, alpha=0.7)
plt.legend()
# bold text labels for the two boundaries
plt.annotate('RAI',(-4.5,-3), weight='bold', color='tab:green')
plt.annotate('ERM',(-5,-1.3), weight='bold', color='tab:red')
plt.xlim(-6, 6)
plt.ylim(-4, 4)
plt.xlabel('x1')
plt.ylabel('x2')
# plt.savefig('group_dro_rai.pdf', bbox_inches='tight')


In [None]:
# Predictions of the P-weighted model mixture on the evaluation grid
pred = predict(torch.tensor(np.c_[xx.ravel(), yy.ravel()]).float(), P, models)
print(pred.shape, pred.sum())

# Put the result into a color plot: training points over the filled decision regions
plt.figure(figsize=(5, 5))
plt.scatter(X_train[y_train == 0, 0], X_train[y_train == 0, 1], label='class 0', s=5)
plt.scatter(X_train[y_train == 1, 0], X_train[y_train == 1, 1], label='class 1', s=5)
plt.contourf(xx, yy, pred.reshape(xx.shape), cmap=plt.cm.coolwarm, alpha=0.5)
plt.legend()

Analysis

In [None]:
# ERM baseline as a degenerate mixture: all weight on the first model
P_erm = np.zeros((1, T))
P_erm[:, 0] = 1

# Unregularised (eta = 0) game value of the baseline vs. the learned mixture
_, game_erm = raigame(P_erm, l_all, 0)
_, game_opt = raigame(P, l_all, 0)

print('game_erm = {}, game_opt = {}'.format(game_erm, game_opt))

In [None]:
# Analyze CVaR Loss
def get_cvar_loss(P, l, alpha):
    exp_loss = (P @ l).reshape(-1)
    exp_loss = np.sort(exp_loss)[::-1]
    n = exp_loss.shape[0]
    return format(exp_loss[:int(alpha * n)].mean(), '.3f')

# Compare the CVaR loss of the learned mixture (P) and the ERM baseline (P_erm)
# over a sweep of tail fractions alpha
for alpha in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]:
    print('alpha = {}, opt_cvar = {}, erm_cvar = {}'.format(alpha, get_cvar_loss(P, l_all, alpha), get_cvar_loss(P_erm, l_all, alpha)))

In [None]:
# Analyze worst-class loss
def get_worst_class_loss(P, l, dataset):
    assert len(dataset) == l.shape[1]
    exp_loss = (P @ l).reshape(-1)
    num_classes = len(np.unique(dataset.y))
    dct = {}
    for i in range(num_classes):
        idx = np.where(dataset.y == i)[0]
        dct[i] = exp_loss[idx].mean()
    print("Class Loss: ", dct)
    print("Worst Class Loss: ", max(dct.values()))
    return 

# Per-class / worst-class expected loss on the training set:
# learned mixture first, then the ERM baseline
get_worst_class_loss(P, l_all, dataset_train)
get_worst_class_loss(P_erm, l_all, dataset_train)