In [None]:
import pandas as pd
import mosek
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data.dataset import Dataset
import torch.nn as nn
from copy import deepcopy
from torch import optim
from torch.nn.modules.module import Module
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
import cvxpy as cp
from torch.utils.data.sampler import SubsetRandomSampler


# Seed both RNGs used by the notebook (torch and numpy) so a fresh run is repeatable.
seed = 2023
torch.manual_seed(seed)
np.random.seed(seed)
# Ask torch for deterministic kernels; warn_only=True warns instead of raising
# when an operation has no deterministic implementation.
torch.use_deterministic_algorithms(True, warn_only=True)
# Switch for the optional plain-ERM baseline cell further down (skipped while False).
erm_ = False

# skiprows=1 drops the first line of each file (presumably the header row) and
# header=None gives integer column labels, so column 0 holds the class label.
data_train = pd.read_csv('fashion-mnist_train.csv', header=None, skiprows=1)
data_test = pd.read_csv('fashion-mnist_test.csv', header=None, skiprows=1)

In [None]:
# Restrict both splits to the three classes studied here:
# label 0 (tshirt), label 2 (pullover) and label 6 (shirt). Column 0 is the label.
data_train, data_test = (
    frame.loc[frame[0].isin([0, 2, 6])] for frame in (data_train, data_test)
)

In [None]:
# Sanity check: (rows, columns) of each split after keeping only the three classes.
data_train.shape, data_test.shape

In [None]:
# subsample data frame for solver stability, comment out for erm performance
# Keeps a fixed-seed 80% sample of the training rows and renumbers the index.
# NOTE: this rebinds data_train from itself, so running the cell twice
# subsamples twice; re-run the loading cells first when iterating.
data_train = data_train.sample(frac=0.8, random_state=42).reset_index(drop=True)

In [None]:
# Stack train on top of test so the label remapping and feature extraction run
# once, then split the arrays back apart by row count.
n_test = data_test.shape[0]
n_train = data_train.shape[0]
merge = pd.concat([data_train, data_test], axis=0)
# Original labels -> contiguous class ids: 0 (tshirt) -> 0, 2 (pullover) -> 1, 6 (shirt) -> 2.
dct = {0: 0, 2: 1, 6: 2}
# The lambda (rather than map(dct)) raises KeyError on any unexpected label.
y = merge[merge.columns[0]].map(lambda x: dct[x]).to_numpy()
merge = merge.drop(merge.columns[0], axis=1)

merge_df = merge
X = merge.to_numpy()
# Split on the train row count: slicing with -n_test would return an empty
# train set (and the whole matrix as "test") if n_test were ever 0.
X_train, y_train = X[:n_train, :], y[:n_train]
X_test, y_test = X[n_train:, :], y[n_train:]
# convert to float32
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
assert X_train.shape[0] == y_train.shape[0] == n_train, "train features/labels misaligned"
assert X_test.shape[0] == y_test.shape[0] == n_test, "test features/labels misaligned"
print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)

In [None]:
# normalize data: divide every feature by 255 (assumes 8-bit pixel values — TODO confirm).
# Both splits are rebuilt from the raw matrix X instead of dividing X_train / X_test
# in place, so re-running this cell cannot scale the data a second time.
X_train = X[:X.shape[0] - n_test, :].astype('float32') / 255
X_test = X[X.shape[0] - n_test:, :].astype('float32') / 255

In [None]:
# Training matrix shape followed by the number of samples in class 0, 1 and 2.
X_train.shape, (y_train == 0).sum(), (y_train == 1).sum(), (y_train == 2).sum()

In [None]:
# Row positions of each remapped class (0 = tshirt, 1 = pullover, 2 = shirt)
# inside the train and test label vectors.
tshirt_idx_train, pull_idx_train, shirt_idx_train = (
    np.where(y_train == label)[0] for label in range(3)
)

tshirt_idx_test, pull_idx_test, shirt_idx_test = (
    np.where(y_test == label)[0] for label in range(3)
)

Init Solver

In [None]:
def lpsolver(prob, verbose=False):
  """Solve ``prob`` by trying several solver configurations in turn.

  Order of attempts: MOSEK, ECOS_BB, then MOSEK restricted to its free simplex
  optimizer with ``bfs=True``. The first attempt that does not raise
  ``cp.error.SolverError`` wins.

  Args:
    prob: Object exposing ``solve(solver=..., verbose=..., **kwargs)``; a
      ``cp.Problem`` in this notebook.
    verbose: Forwarded unchanged to ``prob.solve``.

  Returns:
    Whatever ``prob.solve`` returned for the first successful attempt.

  Raises:
    cp.error.SolverError: If every attempt raised ``SolverError``. The last
      underlying failure is attached as ``__cause__``.
  """
  last_error = None
  for s in (cp.MOSEK, cp.ECOS_BB):
    try:
      return prob.solve(solver=s, verbose=verbose)
    except cp.error.SolverError as e:
      # Report which solver failed and why instead of a bare marker line.
      last_error = e
      print('==> Solver Error ({}): {}'.format(s, e))

  # Last resort: MOSEK simplex. The parameter dict is built only when this
  # attempt is actually reached.
  try:
    return prob.solve(solver=cp.MOSEK,
                      mosek_params={'MSK_IPAR_OPTIMIZER': mosek.optimizertype.free_simplex},
                      bfs=True, verbose=verbose)
  except cp.error.SolverError as e:
    last_error = e
    print('==> Solver Error (MOSEK simplex): {}'.format(e))

  raise cp.error.SolverError('All solvers failed.') from last_error

Make Dataset

In [None]:
class MyDataset(Dataset):
  """In-memory dataset of (features, label) pairs backed by numpy arrays.

  Attributes:
    X: Features cast to float32.
    y: Labels cast to int64.
    attr: The feature array exactly as passed in (not cast).
  """

  def __init__(self, X, y):
    """Store float32 features and int64 labels.

    Args:
      X: numpy array of features, indexed along its first axis.
      y: numpy array of integer class labels aligned with ``X``.
    """
    super().__init__()
    self.X = X.astype('float32')
    # Use an explicit 64-bit dtype: 'long' follows the platform's C long, which
    # is 32-bit on some platforms, while torch expects int64 class targets.
    self.y = y.astype('int64')
    self.attr = X

  def __getitem__(self, item):
    """Return the ``(features, label)`` pair at position ``item``."""
    return self.X[item], self.y[item]

  def __len__(self):
    """Number of samples (length of the feature array)."""
    return len(self.X)

# Two dataset objects per split (MyDataset casts on construction).
# NOTE(review): dataset_valid wraps the same arrays as dataset_test, so there is
# no held-out validation split; anything tuned on "valid" is tuned on test data.
dataset_train = MyDataset(X_train, y_train)
dataset_test_train = MyDataset(X_train, y_train)
dataset_valid = MyDataset(X_test, y_test)
dataset_test = MyDataset(X_test, y_test)

Make Model

In [None]:
# A single linear layer mapping 784 input features to 3 class logits.
model = nn.Sequential(nn.Linear(784, 3, bias=True))
# reduction='none' yields one loss value per sample; erm() averages it and
# test() keeps the per-sample values.
loss = nn.CrossEntropyLoss(reduction='none')
# Unshuffled, so per-sample outputs of test() follow the row order of the test split.
testloader = DataLoader(dataset_test, batch_size=32, shuffle=False)

ERM

In [None]:
def erm(model: Module, loader: DataLoader, optimizer: optim.Optimizer, criterion, device: str, iters=0):
  """Empirical Risk Minimization (ERM): take optimizer steps over ``loader``.

  Args:
    model: Network to train; switched to train mode.
    loader: Iterable of ``(inputs, targets)`` batches.
    optimizer: Optimizer already bound to ``model``'s parameters.
    criterion: Loss returning per-sample values; its mean is minimized.
    device: Device the batches are moved to.
    iters: Stop after this many batches. With the default 0 the counter never
      matches, so the whole loader is consumed.

  Returns:
    Mean of the per-batch losses as a float, or NaN when the loader yielded no
    batch. The running loss used to be accumulated and then thrown away;
    existing callers that ignore the return value are unaffected.
  """

  model.train()
  steps = 0
  total_loss = 0.0
  for inputs, targets in loader:
    inputs, targets = inputs.to(device), targets.to(device)
    batch_loss = criterion(model(inputs), targets).mean()
    total_loss += batch_loss.item()
    optimizer.zero_grad()
    batch_loss.backward()
    optimizer.step()
    steps += 1
    if steps == iters:
      break
  return total_loss / steps if steps else float('nan')

def test(model: Module, loader: DataLoader, criterion, device: str):
  """Test the avg and group acc of the model"""

  model.eval()
  total_correct = 0
  total_loss = 0
  total_num = 0
  l_rec = []
  c_rec = []

  with torch.no_grad():
    for _, (inputs, targets) in enumerate(loader):
      inputs, targets = inputs.to(device), targets.to(device)
      labels = targets
      outputs = model(inputs)
      predictions = torch.argmax(outputs, dim=1)
      c = (predictions == labels)
      c_rec.append(c.detach().cpu().numpy())
      correct = c.sum().item()
      l = criterion(outputs, labels).view(-1)
      l_rec.append(l.detach().cpu().numpy())
      loss = l.sum().item()
      total_correct += correct
      total_loss += loss
      total_num += len(inputs)
  l_vec = np.concatenate(l_rec)
  c_vec = np.concatenate(c_rec)
  return total_correct / total_num, total_loss / total_num, c_vec, l_vec

In [None]:
if erm_:
    # Plain ERM baseline: five passes of up to `iters` AdamW steps each, printing
    # training accuracy after every pass, then test error overall and per class.
    batch_size = 32
    iters = 1000

    optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-5)
    trainloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
    for i in range(5):
        erm(model, trainloader, optimizer, loss, 'cpu', iters=iters)
        acc, _, c_vec, _ = test(model, trainloader, loss, 'cpu')
        print("Train Acc: ", acc, end=', ')

    # The "Loss" figures printed below are error rates (1 - accuracy).
    acc, _, c_vec, _ = test(model, testloader, loss, 'cpu')
    print("Test Loss: ", 1 - acc)
    per_class_reports = (
        ("shirt Loss: ", shirt_idx_test),
        ("pull Loss: ", pull_idx_test),
        ("tshirt Loss: ", tshirt_idx_test),
    )
    for caption, members in per_class_reports:
        print(caption, 1 - c_vec[members].mean())

RAI GAME

In [None]:
def raigame(P, l_all, eta, constraints_w=['group_doro', 'chi2'], alpha=0.55, group_idx=[pull_idx_train, tshirt_idx_train, shirt_idx_train]):
    """Solve the sample-reweighting side of the game for a fixed model mixture.

    Maximizes ``P @ (l_all @ w) + eta * sum(entr(w))`` over weights ``w`` with
    ``sum(w) == 1`` and ``w >= 1e-10``, plus the constraints named in
    ``constraints_w``.

    Args:
        P: Mixture weights over the rows of ``l_all``; the notebook passes a
            ``(1, num_epochs)`` array.
        l_all: 2-D array of per-sample losses, one row per stored model.
        eta: Weight of the entropy term in the objective.
        constraints_w: Names of constraints to add. Supported: ``'cvar'``
            (``w <= 1 / (alpha * n)``), ``'chi2'``
            (``sum_squares(w) <= 1 / (alpha * n)``) and ``'group_doro'``.
            The default list is read-only here; do not mutate it.
        alpha: Fraction used by the ``'cvar'`` and ``'chi2'`` bounds.
        group_idx: Index arrays, one per group, used by ``'group_doro'``.
            The default is bound when this cell is executed.

    Returns:
        ``(w_star, result)``: the optimal weights clipped at zero and
        renormalized to sum to one, and the value returned by the solver.

    Raises:
        NotImplementedError: If ``constraints_w`` contains an unsupported name.
        cp.error.SolverError: If the solver finished without producing weights
            (for example an infeasible problem).
    """
    num_epochs, n = l_all.shape
    w = cp.Variable(n)

    objective = cp.Maximize(P @ (l_all @ w) + eta * cp.sum(cp.entr(w)))
    constraints = [cp.sum(w) == 1, 1e-10 <= w]

    # Validate by name instead of counting matches, so a repeated supported
    # name is not misreported as an unimplemented constraint.
    supported = ('cvar', 'chi2', 'group_doro')
    unknown = [name for name in constraints_w if name not in supported]
    if unknown:
        print('Constraint {} not implemented'.format(constraints_w))
        raise NotImplementedError('Unsupported constraint(s): {}'.format(unknown))

    m = alpha * n
    if 'cvar' in constraints_w:
        constraints.append(w <= 1 / m)
    if 'chi2' in constraints_w:
        constraints.append(cp.sum_squares(w) <= 1 / m)
    if 'group_doro' in constraints_w:
        # NOTE(review): cp.entr is elementwise, so this lower-bounds each group's
        # own -p*log(p) term rather than the summed entropy across groups.
        # Left as-is because changing it alters results; confirm the intent.
        constraints.append(cp.entr(cp.vstack([cp.sum(w[idx]) for idx in group_idx])) >= 0.1)

    prob = cp.Problem(objective, constraints)
    result = lpsolver(prob, verbose=False)

    # Without this check a failed solve surfaces as an opaque TypeError below.
    if w.value is None:
        raise cp.error.SolverError(
            'raigame: solver produced no solution (status: {})'.format(prob.status))

    # Clean up tiny negative values left by the solver and renormalize.
    w_star = w.value.copy()
    w_star[w_star < 0] = 0
    w_star /= w_star.sum()
    return w_star, result

Train

In [None]:
# Snapshot the current weights; every round below restarts training from this state.
init_state_dict = deepcopy(model.state_dict())
T = 4  # number of game rounds; one model is trained per round
iters_per_epoch = 2000  # optimizer steps per round
batch_size = 512
# Per-round histories with one column per training sample:
# sampling weights (w_all) and 0-1 losses (l_all).
w_all = np.zeros((0, len(dataset_train)))
l_all = np.zeros((0, len(dataset_train)))
# Mixture weights over the models trained so far; gains one column per round.
P = np.zeros((1, 0))
n = len(dataset_train)
# Entropy weight passed to raigame when computing the next sampling weights.
eta =  10 * (np.log(T) / (2 * T * np.log(n))) ** 0.5
print(eta, eta * np.log(n))

# The first round samples the training set uniformly.
w_all = np.concatenate([w_all, np.ones((1, n)) / n])

# Unshuffled loader, so per-sample outputs of test() follow dataset_train row order.
test_trainloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=False) # for testing
prev_gamevalue = np.inf
best_gamevalue, best_P = np.inf, None
models = []
for t in range(T):
    # Reset the model to the snapshot and build a fresh optimizer for this round.
    model.load_state_dict(init_state_dict)
    optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-1)

    # Draw training samples with replacement according to the latest weights.
    w_t = w_all[-1, :]
    sampler = WeightedRandomSampler(w_t, iters_per_epoch * batch_size, replacement=True)
    trainloader = DataLoader(dataset_train, batch_size=batch_size, sampler=sampler, num_workers=0, pin_memory=False)
    
    # Train this round's model, keep a copy of its weights, and record its
    # per-sample correctness on the (unshuffled) training set.
    erm(model, trainloader, optimizer, loss, 'cpu', iters_per_epoch)
    models.append(deepcopy(model.state_dict()))
    _, _, c_t, ce_t = test(model, test_trainloader, loss, 'cpu')
    l_t = 1 - c_t # 0-1 loss: 1 where the prediction was wrong
    l_all = np.concatenate([l_all, l_t.reshape(1, -1)], axis=0)
    print('t={}, loss={}'.format(t, l_t.mean()))
    # Line search over the mixing weight a given to the new model: try a few
    # multiples of 1 / (t + 1) and keep the one with the smallest game value.
    base_a = 1 / (t + 1)
    best_value = np.inf
    best_a = 1 / (t + 1)
    for m_a in [0.1, 1, 1.5]:
        a = base_a * m_a
        # a >= 1 would leave no weight for the earlier models.
        if a >= 1:
            continue
        P_temp = np.concatenate([(1 - a) * P, a * np.ones((1, 1))], axis=1)
        P_temp = P_temp / P_temp.sum()
        _, value = raigame(P_temp, l_all, 0)
        # Skip candidates whose game value is +inf or NaN.
        if not value < np.inf:
            continue   
        if value < best_value:
            best_value = value
            best_a = a
    # Append the new model with the chosen weight and renormalize the mixture.
    P = np.concatenate([(1 - best_a) * P, best_a * np.ones((1, 1))], axis=1)
    P = P / P.sum()
    
    # Game value of the updated mixture, computed without the entropy term (eta = 0).
    # NOTE: the printed max_w_star is the maximum of w_t, the weights used for
    # sampling in this round, not of the weights solved for below.
    _, gamevalue = raigame(P, l_all, 0)
    print('t = {}, gamevalue = {}, max_w_star = {}'.format(t + 1, gamevalue, w_t.max()))
    
    # Remember the mixture with the lowest game value seen so far.
    # P is rebound (never modified in place) each round, so no copy is needed here.
    if gamevalue < best_gamevalue:
        best_gamevalue = gamevalue
        best_P = P
        
    # Double eta whenever the game value went up compared with the previous round.
    if prev_gamevalue < gamevalue:
        eta = eta * 2
    prev_gamevalue = gamevalue
    
    # Solve for the next round's sampling weights, now with the entropy term.
    w_t, L_P = raigame(P, l_all, eta)
    w_all = np.concatenate([w_all, w_t.reshape(1, -1)], axis=0)

In [None]:
# best_P has one column per round up to the best one; right-pad it with zeros
# to T columns so models from later rounds get weight 0. P and best_P end up
# naming the same array.
P = best_P = np.pad(best_P, ((0, 0), (0, T - best_P.shape[1])))

In [None]:
# Display the selected mixture weights (one column per round).
P

Analysis

In [None]:
# Reference mixtures to compare against the learned P:
#   P_erm  - all weight on the first round's model (trained with uniform sampling)
#   P_unif - equal weight on every round's model
P_erm = np.zeros((1, T))
P_erm[0, 0] = 1
P_unif = np.ones((1, T)) / T

# Game value (eta = 0) of each mixture; only the solver value is kept.
game_erm, game_opt, game_unif = (
    raigame(mixture, l_all, 0)[1] for mixture in (P_erm, P, P_unif)
)

print('game_erm = {}, game_opt = {}, game_unif = {}'.format(game_erm, game_opt, game_unif))

In [None]:
# Analyze CVaR Loss
def get_cvar_loss(P, l, alpha):
    """Mean of the largest ``alpha`` fraction of the mixture's per-sample losses.

    Args:
        P: Mixture weights, multiplied with ``l`` as ``P @ l``.
        l: Loss matrix with one column per sample.
        alpha: Fraction of samples (largest losses first) to average over.

    Returns:
        The tail mean formatted as a string with three decimals.
    """
    exp_loss = np.sort((P @ l).reshape(-1))[::-1]
    n = exp_loss.shape[0]
    # int() truncates, so a small alpha * n used to select zero samples and the
    # mean of that empty slice was NaN. Always average at least the worst sample.
    k = max(1, int(alpha * n))
    return format(exp_loss[:k].mean(), '.3f')

# Tail loss of the learned, ERM and uniform mixtures for a range of tail fractions.
for alpha in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]:
    opt_cvar, erm_cvar, unif_cvar = (
        get_cvar_loss(mixture, l_all, alpha) for mixture in (P, P_erm, P_unif)
    )
    print('alpha = {}, opt_cvar = {}, erm_cvar = {}, unif_cvar = {}'.format(alpha, opt_cvar, erm_cvar, unif_cvar))

In [None]:
# Analyze worst-class loss
def get_worst_class_loss(P, l, dataset):
    """Print the mixture's mean loss for each class and the largest of them.

    Args:
        P: Mixture weights, multiplied with ``l`` as ``P @ l``.
        l: Loss matrix with one column per sample of ``dataset``.
        dataset: Object with ``__len__`` and a label array ``y``. Classes are
            enumerated as ``0 .. number of distinct labels - 1``.

    Returns:
        None; results are printed only.
    """
    assert len(dataset) == l.shape[1]
    mixture_loss = (P @ l).reshape(-1)
    per_class = {}
    for label in range(len(np.unique(dataset.y))):
        members = np.where(dataset.y == label)[0]
        print(label, members.shape)
        per_class[label] = mixture_loss[members].mean()
    print("Class Loss: ", per_class)
    print("Worst Class Loss: ", max(per_class.values()))

# Per-class training loss of the learned mixture, then of the ERM-only mixture.
get_worst_class_loss(P, l_all, dataset_train)
get_worst_class_loss(P_erm, l_all, dataset_train)

In [None]:
# Zero every mixture entry from column 5 onwards and renormalize
# (with T = 4 there are no such columns, so this only renormalizes).
# Work on a copy: P can be the very same array object as best_P, and the
# slice assignment below would otherwise silently edit best_P too.
P = P.copy()
P[:, 5:] = 0
P = P / P.sum()
P

In [None]:
def get_acc(P, models, loader):
    """Accuracy of each stored model on ``loader`` and their ``P``-weighted mix.

    Each state dict is loaded into the notebook-level ``model`` in turn, so
    ``model`` is left holding the last entry of ``models``.

    Args:
        P: Mixture weights, multiplied with the accuracy list as ``P @ acc``.
        models: Sequence of state dicts compatible with ``model``.
        loader: Batches to evaluate on.

    Returns:
        ``(per_model_accuracy_array, P @ per_model_accuracy)``.
    """
    per_model = []
    for state in models:
        model.load_state_dict(state)
        accuracy, *_ = test(model, loader, loss, 'cpu')
        per_model.append(accuracy)
    return np.array(per_model), P @ per_model

from torch.utils.data.sampler import SubsetRandomSampler

# One test loader per class (0 = tshirt, 1 = pullover, 2 = shirt), each drawing
# only that class's rows. The sampler visits them in random order, which does
# not affect the accuracies computed below.
tshirt_testloader, pull_testloader, shirt_testloader = (
    DataLoader(
        dataset_test,
        batch_size=batch_size,
        sampler=SubsetRandomSampler(np.where(dataset_test.y == label)[0]),
        num_workers=0,
        pin_memory=False,
    )
    for label in range(3)
)

# Per-model and mixture accuracy of the learned P on each class and overall.
acc_tshirt, acc_tshirt_avg = get_acc(P, models, tshirt_testloader)
acc_pull, acc_pull_avg = get_acc(P, models, pull_testloader)
acc_shirt, acc_shirt_avg = get_acc(P, models, shirt_testloader)
overall_acc, overall_acc_avg = get_acc(P, models, testloader)

# Error rates (1 - accuracy): tshirt, pullover, shirt, overall.
1 - acc_tshirt_avg, 1 - acc_pull_avg, 1 - acc_shirt_avg, 1 - overall_acc_avg

In [None]:
# Same evaluation for the ERM-only mixture, in the same loader order
# (tshirt, pullover, shirt, overall). This overwrites the acc_* variables.
(
    (acc_tshirt, acc_tshirt_avg),
    (acc_pull, acc_pull_avg),
    (acc_shirt, acc_shirt_avg),
    (overall_acc, overall_acc_avg),
) = (
    get_acc(P_erm, models, class_loader)
    for class_loader in (tshirt_testloader, pull_testloader, shirt_testloader, testloader)
)

# Error rates (1 - accuracy): tshirt, pullover, shirt, overall.
1 - acc_tshirt_avg, 1 - acc_pull_avg, 1 - acc_shirt_avg, 1 - overall_acc_avg