In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data.dataset import Dataset
import torch.nn as nn
from copy import deepcopy
from torch import optim
from torch.nn.modules.module import Module
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
import cvxpy as cp
import mosek

In [None]:
# Seed both RNGs used in this notebook (torch and numpy's legacy global
# generator) so that random draws are repeatable across fresh-kernel runs.
seed = 2021
torch.manual_seed(seed)
np.random.seed(seed)
# Request deterministic torch kernels; warn_only=True downgrades "no
# deterministic implementation available" from an error to a warning.
torch.use_deterministic_algorithms(True, warn_only=True)

Init Solver

In [None]:
def lpsolver(prob, verbose=False):
  """Solve `prob`, falling back through several solver configurations.

  Attempts, in order: MOSEK with default settings, ECOS_BB, and finally MOSEK
  forced to its free-simplex optimizer with `bfs=True`. The first attempt that
  does not raise `cp.error.SolverError` wins.

  Args:
    prob: object exposing `solve(solver=..., verbose=..., **kwargs)`
      (a `cp.Problem` at the call site in this file).
    verbose: forwarded unchanged to `prob.solve`.

  Returns:
    Whatever `prob.solve` returns for the first successful attempt.

  Raises:
    cp.error.SolverError: if every attempt raised `cp.error.SolverError`;
      chained from the last failure so the root cause stays in the traceback.
  """
  for s in (cp.MOSEK, cp.ECOS_BB):
    try:
      return prob.solve(solver=s, verbose=verbose)
    except cp.error.SolverError as e:
      # Report which solver failed and why instead of an anonymous message.
      print('==> Solver Error ({}): {}'.format(s, e))

  # Last resort: MOSEK's simplex optimizer. The mosek parameters are only
  # evaluated here, i.e. only when the earlier attempts have all failed.
  try:
    return prob.solve(solver=cp.MOSEK,
                      mosek_params={mosek.iparam.optimizer: mosek.optimizertype.free_simplex},
                      bfs=True, verbose=verbose)
  except cp.error.SolverError as e:
    print('==> Solver Error ({} simplex): {}'.format(cp.MOSEK, e))
    raise cp.error.SolverError('All solvers failed.') from e

Make Dataset

In [None]:
class MyDataset(Dataset):
  """Map-style dataset over a pair of index-aligned containers.

  `X` holds the features and `y` the labels; `attr` is a second reference to
  the same `X` object (no copy is made).
  """

  def __init__(self, X, y):
    super().__init__()
    self.X = X
    self.y = y
    self.attr = X

  def __getitem__(self, item):
    """Return the `(features, label)` pair stored at position `item`."""
    features = self.X[item]
    label = self.y[item]
    return features, label

  def __len__(self):
    """Number of samples, taken from the feature container."""
    return len(self.X)

def gen_data(n, p_g1=0.2):
  """Sample a 2-D binary classification set with a two-component class 1.

  Labels are drawn uniformly from {0, 1}. Class 0 is a unit Gaussian at
  (0, 0). Each class-1 point comes from a unit Gaussian at (1, 3) with
  probability `p_g1`, otherwise from a unit Gaussian at (3, 0).

  Uses numpy's global RNG, so results depend on `np.random.seed`.

  Args:
    n: number of samples.
    p_g1: probability that a class-1 sample is assigned to the (1, 3)
      component. Defaults to 0.2.

  Returns:
    (X, Y): `X` is a float32 array of shape (n, 2), `Y` an integer array of
    shape (n,) with values in {0, 1}.
  """
  unit_cov = [[1, 0], [0, 1]]
  Y = np.random.choice([0, 1], size=n, p=[0.5, 0.5])
  X = np.zeros((n, 2), dtype=np.float32)

  # Class 0: single Gaussian at the origin.
  mask_c0 = Y == 0
  X[mask_c0] = np.random.multivariate_normal([0, 0], unit_cov, size=mask_c0.sum())

  # Class 1: mixture of two Gaussians. One uniform draw per sample (for all n,
  # not only class 1) decides the component; about p_g1 of class 1 lands in g1.
  mask_c1 = Y == 1
  mask_g1 = mask_c1 & (np.random.rand(n) < p_g1)
  mask_g2 = mask_c1 & ~mask_g1
  X[mask_g1] = np.random.multivariate_normal([1, 3], unit_cov, size=mask_g1.sum())
  X[mask_g2] = np.random.multivariate_normal([3, 0], unit_cov, size=mask_g2.sum())
  return X, Y

# Two independent draws of 1000 points each from the same generator.
X_train, y_train = gen_data(1000)
X_test, y_test = gen_data(1000)
# print(X_train.shape, y_train.shape)
dataset_train = MyDataset(X_train, y_train)
# Second wrapper over the same training arrays.
dataset_test_train = MyDataset(X_train, y_train)
# NOTE(review): dataset_valid and dataset_test wrap the same (X_test, y_test)
# arrays, so there is no separate validation split here.
dataset_valid = MyDataset(X_test, y_test)
dataset_test = MyDataset(X_test, y_test)

In [None]:
# Scatter the training points, one colour per class, so the geometry of the
# two classes is visible before training.
plt.figure(figsize=(5, 5))
plt.scatter(X_train[y_train == 0, 0], X_train[y_train == 0, 1], label='class 0', s=5)
plt.scatter(X_train[y_train == 1, 0], X_train[y_train == 1, 1], label='class 1', s=5)
# Label the figure so it can be read on its own when the notebook is skimmed.
plt.xlabel('x[0]')
plt.ylabel('x[1]')
plt.title('Training data')
plt.legend()
# Render explicitly so the cell output is the figure, not the Legend repr.
plt.show()

Make Model

In [None]:
# Linear classifier: one affine layer from the 2 input features to 2 logits.
model = nn.Sequential(nn.Linear(2, 2, bias=True))
# reduction='none' keeps one loss value per sample; the training and
# evaluation helpers reduce it themselves.
loss = nn.CrossEntropyLoss(reduction='none')

ERM

In [None]:
def erm(model: Module, loader: DataLoader, optimizer: optim.Optimizer, criterion, device: str, iters=0):
  """Run mini-batch gradient steps on the mean of `criterion` over `loader`.

  The model is put in training mode and updated in place. Training stops once
  `iters` batches have been processed; with the default `iters=0` the stop
  condition never triggers, so the whole loader is consumed.
  """
  model.train()
  for step, (inputs, targets) in enumerate(loader, start=1):
    inputs = inputs.to(device)
    targets = targets.to(device)
    # criterion yields per-sample values; average them into a scalar objective.
    batch_loss = criterion(model(inputs), targets).mean()
    optimizer.zero_grad()
    batch_loss.backward()
    optimizer.step()
    if step == iters:
      break

def test(model: Module, loader: DataLoader, criterion, device: str):
  """Test the avg and group acc of the model"""

  model.eval()
  total_correct = 0
  total_loss = 0
  total_num = 0
  l_rec = []
  c_rec = []

  with torch.no_grad():
    for _, (inputs, targets) in enumerate(loader):
      inputs, targets = inputs.to(device), targets.to(device)
      labels = targets
      outputs = model(inputs)
      predictions = torch.argmax(outputs, dim=1)
      c = (predictions == labels)
      c_rec.append(c.detach().cpu().numpy())
      correct = c.sum().item()
      l = criterion(outputs, labels).view(-1)
      l_rec.append(l.detach().cpu().numpy())
      loss = l.sum().item()
      total_correct += correct
      total_loss += loss
      total_num += len(inputs)
  l_vec = np.concatenate(l_rec)
  c_vec = np.concatenate(c_rec)
  return total_correct / total_num, total_loss / total_num, c_vec, l_vec

RAI GAME

In [None]:
def adaboost(l_all, w_all):
  """Reweight the samples from the most recent round, AdaBoost-style.

  Args:
    l_all: 2-D array of per-sample losses, one row per round; only the last
      row is used.
    w_all: 2-D array of per-sample weights, one row per round; only the last
      row is used. It is NOT modified.

  Returns:
    (weight, 0): the new normalised weight vector, and a constant 0 so the
    return shape matches `raigame`'s `(weights, value)` pair.
  """
  lss = np.asarray(l_all[-1, :], dtype=float)
  # Copy: `w_all[-1, :]` is a view, and the in-place updates below would
  # otherwise overwrite the caller's stored weights for the previous round.
  weight = np.array(w_all[-1, :], dtype=float)
  eps = lss.T @ weight            # weighted loss of the latest round
  beta = eps / (1 - eps)
  # NOTE(review): with beta < 1 (eps < 0.5) this shrinks the weight of samples
  # whose loss is 1 and leaves loss-0 samples untouched before renormalising.
  # Classic AdaBoost shrinks the correctly classified samples instead
  # (exponent `1 - lss`) -- confirm the direction is intended.
  weight *= np.power(beta, lss)
  weight /= weight.sum()
  return weight, 0

def raigame(P, l_all, eta, constraints_w=('chi2',), alpha=0.99):
    """Solve the adversary's entropy-regularised reweighting problem.

    Maximises, over sample weights `w` with `w >= 0` and `sum(w) == 1`,

        P @ (l_all @ w) + eta * sum(entr(w))

    subject to the extra constraints named in `constraints_w`:

    * ``'cvar'``: ``w <= 1 / (alpha * n)``
    * ``'chi2'``: ``sum_squares(w) <= 1 / (alpha * n)``

    Args:
        P: mixture coefficients over rounds, multiplied against `l_all @ w`.
        l_all: 2-D loss matrix of shape (rounds, n).
        eta: weight of the entropy term.
        constraints_w: iterable of constraint names (subset of 'cvar', 'chi2').
        alpha: level used in the constraint bounds above.

    Returns:
        (w_star, result): the maximiser with negative entries clipped to 0 and
        renormalised to sum to 1, and the value returned by the solver.

    Raises:
        NotImplementedError: if `constraints_w` contains an unknown name.
        cp.error.SolverError: if the solve finished without producing a value
            for `w` (e.g. infeasible problem).
    """
    supported = ('cvar', 'chi2')
    # Validate by name rather than by counting, so a repeated known name is
    # not mistaken for an unknown one.
    unknown = [c for c in constraints_w if c not in supported]
    if unknown:
        print('Constraint {} not implemented'.format(constraints_w))
        raise NotImplementedError('Constraint {} not implemented'.format(unknown))

    n = l_all.shape[1]
    w = cp.Variable(n)

    objective = cp.Maximize(P @ (l_all @ w) + eta * cp.sum(cp.entr(w)))
    constraints = [cp.sum(w) == 1, 0 <= w]
    m = alpha * n
    if 'cvar' in constraints_w:
        constraints.append(w <= 1 / m)
    if 'chi2' in constraints_w:
        constraints.append(cp.sum_squares(w) <= 1 / m)

    prob = cp.Problem(objective, constraints)
    result = lpsolver(prob, verbose=False)

    if w.value is None:
        # Without this check the clipping below fails with an opaque
        # "'NoneType' object does not support item assignment".
        raise cp.error.SolverError(
            'raigame: solver returned no solution (status: {})'.format(prob.status))

    # Clean up solver round-off: clip tiny negatives and renormalise.
    w_star = np.array(w.value, dtype=float)
    w_star[w_star < 0] = 0
    w_star /= w_star.sum()
    return w_star, result

Train

In [None]:
# ---- Hyper-parameters and bookkeeping -------------------------------------
# T rounds; each round trains a fresh model for iters_per_epoch SGD steps.
T = 10
iters_per_epoch = 1000
batch_size = 32
# w_all / l_all grow by one row per round: the sample weights used for
# training and the resulting per-sample 0-1 losses on the training set.
w_all = np.zeros((0, len(dataset_train)))
l_all = np.zeros((0, len(dataset_train)))
# P holds one (unnormalised) mixture coefficient per trained model.
P = np.zeros((1, 0))
n = len(dataset_train)
# Initial weight of the entropy term passed to raigame.
eta = (np.log(T) / (2 * T * np.log(n))) ** 0.5
# Constant added to every loss entry during the line search below.
lamda = -1/2
print(eta, eta * np.log(n))

# Every round restarts from the same initial parameters.
init_state_dict = deepcopy(model.state_dict())
# The first round samples uniformly.
w_all = np.concatenate([w_all, np.ones((1, n)) / n])

test_trainloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=False) # fixed-order pass over the training set, used for evaluation
prev_gamevalue = np.inf
best_gamevalue, best_P = np.inf, None
# State dict of the model trained in each round.
models = []
for t in range(T):
    #set-up model
    model.load_state_dict(init_state_dict)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    
    #set-up loader: draw exactly iters_per_epoch batches, with replacement,
    # according to the latest sample weights
    w_t = w_all[-1, :]
    sampler = WeightedRandomSampler(w_t, iters_per_epoch * batch_size, replacement=True)
    trainloader = DataLoader(dataset_train, batch_size=batch_size, sampler=sampler, num_workers=0, pin_memory=False)
    
    #get ht (or lt): train on the reweighted sample, then record per-sample losses
    erm(model, trainloader, optimizer, loss, 'cpu', iters_per_epoch)
    models.append(deepcopy(model.state_dict()))
    _, _, c_t, ce_t = test(model, test_trainloader, loss, 'cpu')
    l_t = 1 - c_t # get 0-1 loss
    l_all = np.concatenate([l_all, l_t.reshape(1, -1)], axis=0)
    print('t={}, l_t={:.4f}, ce_t={:.4f}'.format(t, l_t.mean(), ce_t.mean()))
    
    # do line search for a: try 20 candidate coefficients for the new model
    # and keep the one giving the smallest regularised game value on the
    # shifted losses (l_all + lamda), with the candidate mixture normalised
    # base_a = 1 / (t + 1)
    base_a = 1
    best_value = np.inf
    best_a = None
    for m_a in np.linspace(0.01, 1.01, 20):
        a = base_a * m_a
        # if a >= 1:
        #     continue
        P_temp = np.concatenate([P, a * np.ones((1, 1))], axis=1)
        P_temp = P_temp / P_temp.sum()
        _, value = raigame(P_temp, l_all + lamda, eta)
        assert value < np.inf
        if value < best_value:
            best_value = value
            best_a = a
    # P itself is stored unnormalised; it is normalised where needed below.
    P = np.concatenate([P, best_a * np.ones((1, 1))], axis=1)
    print('best_a = {}, best_value = {}'.format(best_a, best_value))
    # P = P / P.sum()
    
    # get new game value (eta = 0: no entropy term, unshifted losses)
    # NOTE(review): the maximiser returned by this call is discarded; the
    # 'max_w_star' printed below is the max of w_t, i.e. the weights this
    # round's model was trained with -- confirm that is the intended quantity.
    _, gamevalue = raigame(P / P.sum(), l_all, 0)
    print('t = {}, gamevalue = {}, max_w_star = {}'.format(t + 1, gamevalue, w_t.max()))
    
    # update best gamevalue
    if gamevalue < best_gamevalue:
        best_gamevalue = gamevalue
        best_P = P
        
    # update eta if gamevalue increases
    if prev_gamevalue < gamevalue:
        eta = eta * 2
    prev_gamevalue = gamevalue
    
    #get wt: sample weights for the next round
    # NOTE(review): this call passes the unnormalised P, whereas the line
    # search and the game-value call above pass a mixture normalised to sum
    # to 1 -- confirm the difference is intended.
    w_t, L_P = raigame(P, l_all, eta)
    # w_t, _ = adaboost(l_all, w_all)   
    w_all = np.concatenate([w_all, w_t.reshape(1, -1)], axis=0)

In [None]:
# Mean 0-1 training loss of each round's model (unweighted over samples).
l = l_all.mean(axis=1)
# Displayed as the cell output: 0.5 * log((1 - l) / l) per round. The value is
# infinite (with a numpy warning) when a round's mean loss is exactly 0 or 1.
np.log((1-l) / l) / 2

In [None]:
# Right-pad best_P with zeros up to T columns, so rounds trained after the
# best game value was reached get zero mixture weight. From here on P refers
# to this padded best mixture.
best_P = np.concatenate([best_P, np.zeros((1, T - best_P.shape[1]))], axis=1)
P = best_P

In [None]:
# Normalise the mixture coefficients to sum to 1 and display them.
P = P / P.sum()
P

Analysis

In [None]:
# Baseline mixture: all mass on the first round's model, which was trained
# with uniform sample weights.
P_erm = np.zeros((1, T))
P_erm[:, 0] = 1

# Game value of each mixture with eta = 0, i.e. without the entropy term.
_, game_erm = raigame(P_erm, l_all, 0)
_, game_opt = raigame(P, l_all, 0)

print('game_erm = {}, game_opt = {}'.format(game_erm, game_opt))

In [None]:
# Analyze CVaR Loss
def get_cvar_loss(P, l, alpha):
    """Average mixture loss over the worst `alpha` fraction of samples.

    Args:
        P: mixture coefficients; `P @ l` is flattened to one expected loss
            per sample.
        l: loss matrix with one row per model and one column per sample.
        alpha: fraction of samples (largest expected losses first) to average.

    Returns:
        str: the average, formatted with three decimals.
    """
    exp_loss = np.sort((P @ l).reshape(-1))[::-1]
    n = exp_loss.shape[0]
    # floor(alpha * n) with a small tolerance, so that products such as
    # 0.29 * 100 == 28.999999999999996 are not truncated to 28. Keep at least
    # one sample so a small alpha does not average an empty slice (nan).
    k = max(1, int(alpha * n + 1e-9))
    return format(exp_loss[:k].mean(), '.3f')

# Compare the learned mixture P with the first-round-only baseline P_erm
# across a sweep of alpha levels (alpha = 1 averages over every sample).
for alpha in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]:
    print('alpha = {}, opt_cvar = {}, erm_cvar = {}'.format(alpha, get_cvar_loss(P, l_all, alpha), get_cvar_loss(P_erm, l_all, alpha)))

In [None]:
# Analyze worst-class loss
def get_worst_class_loss(P, l, dataset):
    """Print the mean mixture loss of every class and the largest of them.

    Args:
        P: mixture coefficients; `P @ l` is flattened to one expected loss
            per sample.
        l: loss matrix with one column per sample of `dataset`.
        dataset: object with `__len__` and a label container `y` aligned with
            the columns of `l`.

    Returns:
        None. Results are printed only.
    """
    assert len(dataset) == l.shape[1]
    exp_loss = (P @ l).reshape(-1)
    labels = np.asarray(dataset.y)
    # Iterate over the labels actually present, rather than assuming they are
    # exactly 0..K-1; otherwise a missing or shifted label gives an empty
    # class (nan mean) and silently skips a real one.
    dct = {cls: exp_loss[labels == cls].mean() for cls in np.unique(labels).tolist()}
    print("Class Loss: ", dct)
    print("Worst Class Loss: ", max(dct.values()))
    return

# Per-class and worst-class training loss: learned mixture first, then the
# first-round-only baseline.
get_worst_class_loss(P, l_all, dataset_train)
get_worst_class_loss(P_erm, l_all, dataset_train)