## Problem

The two-bit-env problem with confounded-descendant:

\begin{align*}
    & Y \sim \text{Rad}(c_e)\\
    & X_1 = Y \text{Rad}(a)\\
    & X_2 = Y \text{Rad}(b_e)
\end{align*}

Use $(a, b_e, c_e)$ to denote an environment. I use $a = 0.25$ for all environments. We have 4 training domains, with $b_e = 0.8, 0.9, 0.85, 0.95$ and $c_e = 0.9, 0.1, 0.7, 0.3$ respectively. For test domain, we flip the relationship between $X_2$ and $Y$: $b_e = 0.1$. We use $c_e = 0.5$ 


Since $X_1, X_2, Y$ takes values in $\{1, -1\}$, we use $x_{01}$ to denote the prediction $f(1, -1)$ etc, where first coordinate is for $x_1$ and the second component is for $x_2$. I use $p_{011}$ to denote $P(X_1 = 1, X_2 = -1, Y = -1)$, etc. 

## obj and predictor

We use $\text{obj} := E[\text{loss}(f(X), Y)] + \lambda E[g(f(X), Y)]$ as the objective. Here $g$ is the penalty for IRMv1 and gIRMv1 respectively. We computed obj and find the minimizer over $(x_{00}, x_{10}, x_{01}, x_{11})$. We know the optimal invariant solution should be of the form $(v, -v, v, -v)$.


* Note: I implement gIRM by first performing importance sampling on both the loss and the penalty. (Another way is to do importance sampling only on the penalty. But the two approaches give the same result here.)

In [1]:
from sympy import *
import numpy as np
from scipy.optimize import minimize as scipy_min
np.random.seed(1)

In [2]:
def sqls(zh, z):
    return 0.5 * Pow(zh - z, 2)

def logls(zh, z):
    return log(1 + exp(-z * zh))

def zols(zh, z):
    if Eq(sign(zh), z):
        return 0
    return 1

def IRMparts(e, ls, x00, x10, x01, x11, generalized = False):
    a, b, c = e
   
    
    w = symbols('w')
    
    p000 = (1 - c)*(1 - a)*(1 - b)
    p001 = c*a*b;
    p100 = (1 - c)*a*(1 - b)
    p101 = c*(1 - a)*b
    p010 = (1 - c)*(1 - a)*b
    p011 = c*a*(1 - b)
    p110 = (1 - c)*a*b
    p111 = c * (1 - a) * (1 - b);

    R = p000 * ls(x00, 1) + p001 * ls(x00, -1) + \
    p100 * ls(x10, 1) + p101 * ls(x10, -1) +\
    p010 * ls(x01, 1) + p011 * ls(x01, -1)+\
    p110 * ls(x11, 1) + p111 * ls(x11, -1)
    
    if not generalized:
        dR = R.subs([(x00, w*x00), (x01, w*x01), (x10, w*x10), (x11, w*x11)]).diff(w).subs(w, 1)
        return [R, dR]
     
    c = 0.5 ## after reweighting it's equivalent to the case without prior shift
    p000 = (1 - c)*(1 - a)*(1 - b)
    p001 = c*a*b;
    p100 = (1 - c)*a*(1 - b)
    p101 = c*(1 - a)*b
    p010 = (1 - c)*(1 - a)*b
    p011 = c*a*(1 - b)
    p110 = (1 - c)*a*b
    p111 = c * (1 - a) * (1 - b);

    R_w = p000 * ls(x00, 1) + p001 * ls(x00, -1) + \
    p100 * ls(x10, 1) + p101 * ls(x10, -1) +\
    p010 * ls(x01, 1) + p011 * ls(x01, -1)+\
    p110 * ls(x11, 1) + p111 * ls(x11, -1)

    dR = R_w.subs([(x00, w*x00), (x01, w*x01), (x10, w*x10), (x11, w*x11)]).diff(w).subs(w, 1)

    return [R, dR]

def IRMv1e(e, lam, ls, x00, x10, x01, x11, generalized = False):
    R, dR = IRMparts(e, ls, x00, x10, x01, x11, generalized)
    return R + lam * dR**2

def IRMv1(envs, lam, ls, x00, x10, x01, x11, generalized = False):
    results = 0
    for e in envs:
        results += IRMv1e(e, lam, ls, x00, x10, x01, x11, generalized)
        
    return results

In [3]:
def prior_shift(alpha, lam, betalist, clist, ls, inits, generalized = False):
    envs = []
    for i in range(2):
        envs += [[alpha, betalist[i], clist[i]]]

    x00, x10, x01, x11 = symbols('x00 x10 x01 x11')
    obj = IRMv1(envs, lam ,ls, x00, x10, x01, x11, generalized)

    func = lambdify([x00, x10, x01, x11], obj)
    def func_np(params):
        x00, x10, x01, x11 = params
        return func(x00, x10, x01, x11)

    n_init = inits.shape[0]
    sols = np.zeros((n_init, 4))
    objs = []
    for i in range(n_init):
        res = scipy_min(func_np, inits[i, :], method='Nelder-Mead', tol=1e-10)
        sols[i, :] = res.x
        objs.append(res.fun)

    obj_min = min(objs)
    sol_min = sols[objs.index(obj_min), :]
    print("obj,   sol")
    print((round(obj_min, 2), sol_min.round(2)))

## Without prior shift

In [4]:
n_init = 500
inits = np.random.uniform(low = -2, high = 2, size = (n_init, 4))

lam = 1e+20
alpha = 0.25
# betalist = [0.1, 0.2]
# clist = [0.5, 0.5]
betalist = [0.2, 0.1, 0.15, 0.05]
clist = [0.5, 0.5, 0.5, 0.5]


print("square loss")
prior_shift(alpha, lam, betalist, clist, sqls, inits)

print("logistic loss")
prior_shift(alpha, lam, betalist, clist, logls, inits)

square loss
obj,   sol
(0.75, array([ 0.52, -0.5 ,  0.5 , -0.48]))
logistic loss
obj,   sol
(1.12, array([ 1.16, -1.11,  1.08, -1.03]))


## With prior shift

In [5]:
betalist = [0.2, 0.1, 0.15, 0.05]
clist = [0.9, 0.1, 0.7, 0.3]

print("square loss")
prior_shift(alpha, lam, betalist, clist, sqls, inits)

print("logistic loss")
prior_shift(alpha, lam, betalist, clist, logls, inits)

square loss
obj,   sol
(0.44, array([ 0.93,  0.15, -0.4 , -0.85]))
logistic loss
obj,   sol
(0.72, array([ 2.53, -0.08, -0.93, -3.19]))


## g-IRM fixes the issue with prior-shift

In [6]:
betalist = [0.2, 0.1, 0.15, 0.05]
clist = [0.9, 0.1, 0.7, 0.3]

print("square loss")
prior_shift(alpha, lam, betalist, clist, sqls, inits, generalized = True)

print("logistic loss")
prior_shift(alpha, lam, betalist, clist, logls, inits, generalized = True)

square loss
obj,   sol
(0.75, array([ 0.52, -0.5 ,  0.5 , -0.48]))
logistic loss
obj,   sol
(1.12, array([ 1.16, -1.11,  1.08, -1.03]))
