In [1]:
import numpy as np
import math
import random
import torch
from torch.autograd import grad
from torch.optim import SGD, Adam
from torch.nn import MSELoss, Parameter, Module, Linear, BCEWithLogitsLoss
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [16]:
class Example1Dataset:
    """
    Example 1 in Aubin et. al (regression, heteroskedastic data)
    """

    def __init__(self, dim_inv=5, dim_spu=5, n_envs=3):
        self.scramble = torch.eye(dim_inv + dim_spu)
        self.dim_inv = dim_inv
        self.dim_spu = dim_spu
        self.dim = dim_inv + dim_spu

        self.var_env = {}

        assert n_envs > 1
        self.var_env = {'E0': 0.1, 'E1': 1.5}

        if n_envs == 3:
            self.var_env['E2'] = 2.0
        if n_envs > 3:
            self.var_env = {**self.var_env, **{'E'+ str(e): ((10-0.01)*torch.rand(1) + 0.01).item() for e in range(3, n_envs)}}

        self.wxy = torch.randn(self.dim_inv, self.dim_inv) / self.dim_inv
        self.wyz = torch.randn(self.dim_inv, self.dim_spu) / self.dim_spu

    def sample(self, n=1000, env="E0", split="train"):
        var = self.var_env[env]
        x = torch.randn(n, self.dim_inv) * var
        y = x @ self.wxy + torch.randn(n, self.dim_inv) * var
        z = y @ self.wyz + torch.randn(n, self.dim_spu)

        if split == "test":
            z = z[torch.randperm(len(z))]

        inputs = torch.cat((x, z), -1) @ self.scramble
        outputs = y.sum(1, keepdim=True)

        return inputs, outputs

In [17]:
class Example2Dataset:
    """
    Example 1 in Aubin et. al (classification, cows and camels, 
    directly copied from https://github.com/facebookresearch/InvarianceUnitTests/blob/main/scripts/datasets.py)
    """

    def __init__(self, dim_inv, dim_spu, n_envs):
        self.scramble = torch.eye(dim_inv + dim_spu)
        self.dim_inv = dim_inv
        self.dim_spu = dim_spu
        self.dim = dim_inv + dim_spu

        self.task = "classification"
        self.envs = {}

        if n_envs >= 2:
            self.envs = {
                'E0': {"p": 0.95, "s": 0.3},
                'E1': {"p": 0.97, "s": 0.5}
            }
        if n_envs >= 3:
            self.envs["E2"] = {"p": 0.99, "s": 0.7}
        if n_envs > 3:
            for env in range(3, n_envs):
                self.envs["E" + str(env)] = {
                    "p": torch.zeros(1).uniform_(0.9, 1).item(),
                    "s": torch.zeros(1).uniform_(0.3, 0.7).item()
                }
        print("Environments variables:", self.envs)

        # foreground is 100x noisier than background
        self.snr_fg = 1e-2
        self.snr_bg = 1

        # foreground (fg) denotes animal (cow / camel)
        cow = torch.ones(1, self.dim_inv)
        self.avg_fg = torch.cat((cow, cow, -cow, -cow))

        # background (bg) denotes context (grass / sand)
        grass = torch.ones(1, self.dim_spu)
        self.avg_bg = torch.cat((grass, -grass, -grass, grass))

    def sample(self, n=1000, env="E0", split="train"):
        p = self.envs[env]["p"]
        s = self.envs[env]["s"]
        w = torch.Tensor([p, 1 - p] * 2) * torch.Tensor([s] * 2 + [1 - s] * 2)
        i = torch.multinomial(w, n, True)
        x = torch.cat((
            (torch.randn(n, self.dim_inv) /
                math.sqrt(10) + self.avg_fg[i]) * self.snr_fg,
            (torch.randn(n, self.dim_spu) /
                math.sqrt(10) + self.avg_bg[i]) * self.snr_bg), -1)

        if split == "test":
            x[:, self.dim_spu:] = x[torch.randperm(len(x)), self.dim_spu:]

        inputs = x @ self.scramble
        outputs = x[:, :self.dim_inv].sum(1, keepdim=True).gt(0).float()

        return inputs, outputs

In [68]:
def test_model(model, dummy_w, test_data, task):
    
    x_test, y_test = test_data
    if task == "regression":
        criterion = MSELoss()
    elif task == "classification":
        criterion = BCEWithLogitsLoss()
    else:
      raise ValueError("Choose b/w 'regression' and 'classification' only.")

    with torch.no_grad():
        y_pred = x_test @ model.linear.weight.T * dummy_w
        test_loss = criterion(y_pred, y_test)

    if task=="regression":
        return test_loss.item()  
    else:
        return min(test_loss.item(), 1-test_loss.item())

In [24]:
class LinearModel(Module):
    def __init__(self, input_dim, output_dim):
        super(LinearModel, self).__init__()
        self.linear = Linear(input_dim, output_dim, bias=False)

    def forward(self, x):
        return self.linear(x)


dim_inv, dim_spu = 5, 5
n_envs = 3
learning_rate = 1e-3

# IRM

The loss function for this method is:
\begin{equation}
    \min_{f:\mathcal{X}\rightarrow \mathcal{Y}}\sum_{e \in [m]} R^e(1\cdot f) + \lambda \cdot \mathbb{D}(w, f, e)\Big|_{w=1.0},
\end{equation}

where $R^e(\phi)$ is the risk of function $\phi$ under environment $E_e$, $f$ is the invariant classifier, $\mathbb{D}(w,f,e)$ is a metric for how close $w$ is to minimizing $R^e(w\cdot f)$ (e.g. $||\nabla_wR^e(w\cdot f)||^2$) when $w\overset{(set)}{=}1.0$, and $\lambda$ is a hyperparameter balancing the predictive power (the second term), and invariance ($1\cdot f$).

This minimal implementation is inspired by Appendix D of the original IRM paper: https://arxiv.org/pdf/1907.02893.pdf

In [25]:
def compute_penalty(losses, dummy_w):
    g1 = grad(losses[0::2].mean(), dummy_w, create_graph=True)[0]
    g2 = grad(losses[1::2].mean(), dummy_w, create_graph=True)[0]
    return (g1 * g2).sum()


def train_irm(model, dummy_w, environments, optimizer, num_epochs=50000, print_interval=10000, task="regression"):
    if task == "regression":
        loss = MSELoss(reduction="none")
    elif task == "classification":
        loss = BCEWithLogitsLoss(reduction="none")
    else:
        raise ValueError("Choose b/w 'regression' and 'classification' only.")
    
    for epoch in range(num_epochs):
        total_error = 0
        irm_penalty = 0

        for inputs, targets in environments:
            error = loss(inputs @ model.linear.weight.T * dummy_w, targets)
            irm_penalty += compute_penalty(error, dummy_w)
            total_error += error.mean()

        optimizer.zero_grad()
        (1e-5 * total_error + irm_penalty).backward()
        optimizer.step()

        if epoch % print_interval == 0:
            print(f"Epoch {epoch}: {model.linear.weight}")

#### Regression (Example 1) - IRM

Observe that IRM performs well, and the obtained test errors are close to those in Table 1 of Aubin et al: https://arxiv.org/pdf/2102.10867.pdf.

In [26]:
#IRM
model = LinearModel(dim_inv + dim_spu, 1)
dummy_w = Parameter(torch.Tensor([1.0]))
optimizer = Adam(model.parameters(), lr=learning_rate)

task = "regression"

# Training
example1 = Example1Dataset(dim_inv, dim_spu, n_envs)
environments = [example1.sample(n=1000, env=f"E{i}") for i in range(n_envs)]
train_irm(model, dummy_w, environments, optimizer, task = task)

# Testing
test_data = [example1.sample(n=200, env=f"E{i}", split="test") for i in range(n_envs)]
for test_set in test_data:
    test_loss = test_model(model, dummy_w, test_set, task)
    print("Test MSE Error:", test_loss)
del model

Epoch 0: Parameter containing:
tensor([[ 0.3121, -0.2900,  0.0049, -0.0209,  0.0370,  0.1799,  0.0769,  0.1310,
          0.1099, -0.0612]], requires_grad=True)
Epoch 10000: Parameter containing:
tensor([[-0.1612,  0.2394,  0.2942, -0.0666, -0.7940,  0.0777, -0.3930,  0.0648,
          0.1549,  0.1043]], requires_grad=True)
Epoch 20000: Parameter containing:
tensor([[-0.1612,  0.2394,  0.2942, -0.0666, -0.7940,  0.0777, -0.3930,  0.0648,
          0.1549,  0.1043]], requires_grad=True)
Epoch 30000: Parameter containing:
tensor([[-0.1612,  0.2394,  0.2942, -0.0666, -0.7940,  0.0777, -0.3930,  0.0648,
          0.1549,  0.1043]], requires_grad=True)
Epoch 40000: Parameter containing:
tensor([[-0.1612,  0.2394,  0.2942, -0.0667, -0.7940,  0.0778, -0.3930,  0.0648,
          0.1549,  0.1043]], requires_grad=True)
Test MSE Error: 0.25268781185150146
Test MSE Error: 12.233926773071289
Test MSE Error: 23.905122756958008


#### Classification (Example 2) - IRM

Observe that the obtained test errors are very close to those in Table 1 of Aubin et al: https://arxiv.org/pdf/2102.10867.pdf

In [27]:
#IRM
model = LinearModel(dim_inv + dim_spu, 1)
dummy_w = Parameter(torch.Tensor([1.0]))
optimizer = Adam(model.parameters(), lr=learning_rate)

task = "classification"

# Training
example2 = Example2Dataset(dim_inv, dim_spu, n_envs)
environments = [example2.sample(n=1000, env=f"E{i}") for i in range(n_envs)]
train_irm(model, dummy_w, environments, optimizer, task = task)

# Testing
test_data = [example2.sample(n=200, env=f"E{i}", split="test") for i in range(n_envs)]
for test_set in test_data:
    test_loss = test_model(model, dummy_w, test_set, task)
    print("Classification Error:", test_loss)
del model

Environments variables: {'E0': {'p': 0.95, 's': 0.3}, 'E1': {'p': 0.97, 's': 0.5}, 'E2': {'p': 0.99, 's': 0.7}}
Epoch 0: Parameter containing:
tensor([[-0.2269, -0.0317, -0.2734,  0.0034, -0.1580, -0.3070,  0.1531, -0.0372,
          0.1949,  0.2485]], requires_grad=True)
Epoch 10000: Parameter containing:
tensor([[ 2.5762,  3.7463,  2.5383,  3.3090,  3.8226, -2.1501,  1.1289, -1.0750,
          0.5808,  1.9628]], requires_grad=True)
Epoch 20000: Parameter containing:
tensor([[11.1213, 12.6520, 11.2289, 11.6377, 13.2351, -2.7434,  1.3525, -1.4648,
          0.6177,  2.6195]], requires_grad=True)
Epoch 30000: Parameter containing:
tensor([[ 7.2465, 21.6740, 18.8479,  7.5583, 22.8085, -2.9074,  1.4526, -1.6931,
          0.8070,  2.6522]], requires_grad=True)
Epoch 40000: Parameter containing:
tensor([[-2.9451, 31.3306, 25.7069, -1.8198, 32.6376, -2.9126,  1.4640, -1.8235,
          0.9804,  2.5488]], requires_grad=True)
Classification Error: 0.3855908513069153
Classification Error: 0.41

# V-REx

The loss function for this method is:

\begin{equation}
    \min_{f:\mathcal{X}\rightarrow \mathcal{Y}}\sum_{e \in [m]} R^e(f) + \gamma\text{Var}(\{R^1(f), \ldots, R^m(f)\}),
\end{equation}

where the first term is the ERM loss, the second term encourages the risks across environments towards equality, and $\gamma$ is the hyperparameter that balances these two objectives.

In [48]:
def train_vrex(model, dummy_w, environments, optimizer, num_epochs=50000, print_interval=10000, task="regression"):
    if task == "regression":
        loss = MSELoss(reduction="none")
    elif task == "classification":
        loss = BCEWithLogitsLoss(reduction="none")
    else:
        raise ValueError("Choose b/w 'regression' and 'classification' only.")
    
    for epoch in range(num_epochs):
        total_error = 0
        risks_per_env = []
        for inputs, targets in environments:
            error = loss(inputs @ model.linear.weight.T * dummy_w, targets)
            risks_per_env.append(error)
            total_error += error.mean()
        
        vrex_penalty = torch.var(torch.cat(risks_per_env))

        optimizer.zero_grad()
        (1e-5 * total_error + vrex_penalty).backward()
        optimizer.step()

        if epoch % print_interval == 0:
            print(f"Epoch {epoch}: {model.linear.weight}")

#### Regression (Example 1) - V-REx

Observe that while VREx is technically successful at this task, it does not perform as well as IRM on regression. On the other hand...

In [49]:
# V-REx
model = LinearModel(dim_inv + dim_spu, 1)
dummy_w = Parameter(torch.Tensor([1.0]))
optimizer = Adam(model.parameters(), lr=learning_rate)

task = "regression"

# Training
example1 = Example1Dataset(dim_inv, dim_spu, n_envs)
environments = [example1.sample(n=1000, env=f"E{i}") for i in range(n_envs)]
train_vrex(model, dummy_w, environments, optimizer, task = task)

# Testing
test_data = [example1.sample(n=200, env=f"E{i}", split="test") for i in range(n_envs)]
for test_set in test_data:
    test_loss = test_model(model, dummy_w, test_set, task)
    print("Test MSE Error:", test_loss)
del model

Epoch 0: Parameter containing:
tensor([[-0.2612, -0.3114,  0.1031, -0.2789, -0.2893, -0.0623,  0.1871, -0.3028,
          0.0793, -0.2814]], requires_grad=True)
Epoch 10000: Parameter containing:
tensor([[-0.0852,  0.8133, -0.3139, -0.0813,  0.2019, -0.0316, -0.4034,  0.1662,
         -1.6057,  0.1780]], requires_grad=True)
Epoch 20000: Parameter containing:
tensor([[-0.0852,  0.8133, -0.3139, -0.0813,  0.2019, -0.0316, -0.4034,  0.1662,
         -1.6057,  0.1780]], requires_grad=True)
Epoch 30000: Parameter containing:
tensor([[-0.0852,  0.8133, -0.3139, -0.0813,  0.2018, -0.0316, -0.4034,  0.1662,
         -1.6057,  0.1780]], requires_grad=True)
Epoch 40000: Parameter containing:
tensor([[-0.0852,  0.8133, -0.3139, -0.0813,  0.2019, -0.0316, -0.4034,  0.1662,
         -1.6057,  0.1780]], requires_grad=True)
Test MSE Error: 2.438600540161133
Test MSE Error: 26.248167037963867
Test MSE Error: 33.872825622558594


#### Classification (Example 2) - VREx

... VREx performs better (and more consistently) on the classification task across environments.

In [53]:
model = LinearModel(dim_inv + dim_spu, 1)
dummy_w = Parameter(torch.Tensor([1.0]))
optimizer = Adam(model.parameters(), lr=learning_rate)

task = "classification"

# Training
example2 = Example2Dataset(dim_inv, dim_spu, n_envs)
environments = [example2.sample(n=1000, env=f"E{i}") for i in range(n_envs)]
train_vrex(model, dummy_w, environments, optimizer, task = task)

# Testing
test_data = [example2.sample(n=200, env=f"E{i}", split="test") for i in range(n_envs)]
for test_set in test_data:
    test_loss = test_model(model, dummy_w, test_set, task)
    print("Classification Error:", test_loss)
del model

Environments variables: {'E0': {'p': 0.95, 's': 0.3}, 'E1': {'p': 0.97, 's': 0.5}, 'E2': {'p': 0.99, 's': 0.7}}
Epoch 0: Parameter containing:
tensor([[ 0.0520, -0.1443, -0.0036,  0.2110,  0.0670, -0.1072,  0.0702,  0.0170,
          0.1575,  0.0132]], requires_grad=True)
Epoch 10000: Parameter containing:
tensor([[2.9376e-02, 2.8515e-02, 3.0621e-02, 2.8830e-02, 2.7850e-02, 5.0312e-05,
         3.9900e-05, 5.5910e-05, 3.4785e-05, 6.0376e-05]], requires_grad=True)
Epoch 20000: Parameter containing:
tensor([[2.9434e-02, 2.8563e-02, 3.0667e-02, 2.8878e-02, 2.7901e-02, 6.1922e-05,
         5.5021e-05, 6.7942e-05, 5.2970e-05, 7.2490e-05]], requires_grad=True)
Epoch 30000: Parameter containing:
tensor([[2.9415e-02, 2.8554e-02, 3.0652e-02, 2.8867e-02, 2.7886e-02, 5.3133e-05,
         4.6232e-05, 5.9157e-05, 4.4189e-05, 6.3690e-05]], requires_grad=True)
Epoch 40000: Parameter containing:
tensor([[2.9416e-02, 2.8552e-02, 3.0654e-02, 2.8865e-02, 2.7885e-02, 3.9400e-05,
         3.2472e-05, 4.539

# IB-IRM

The loss function for this method is:

\begin{equation}
    \min_{f:\mathcal{X}\rightarrow \mathcal{Y}}\sum_{e \in [m]} R^e(1\cdot f) + \lambda \cdot \mathbb{D}(w, f, e)\Big|_{w=1.0} + \beta\text{Var}(f),
\end{equation}

where the first two terms are the IRM loss terms, and the third term includes a hyperparamter $\beta$ and variance of the predictor (i.e. average variance of predictions across environments). This is a surrogate for the unconditional (of environment) entropy $h(f)$, since for all conitnous random variables, the Gaussian has the highest entropy, and the Gaussian entropy increases with its variance.

In [69]:
def compute_penalty(losses, dummy_w):
    g1 = grad(losses[0::2].mean(), dummy_w, create_graph=True)[0]
    g2 = grad(losses[1::2].mean(), dummy_w, create_graph=True)[0]
    return (g1 * g2).sum()


def train_ibirm(model, dummy_w, environments, optimizer, num_epochs=50000, print_interval=10000, task="regression"):
    if task == "regression":
        loss = MSELoss(reduction="none")
    elif task == "classification":
        loss = BCEWithLogitsLoss(reduction="none")
    else:
        raise ValueError("Choose b/w 'regression' and 'classification' only.")
    
    for epoch in range(num_epochs):
        total_error = 0
        irm_penalty = 0
        preds_per_env = []
        for inputs, targets in environments:
            error = loss(inputs @ model.linear.weight.T * dummy_w, targets)
            irm_penalty += compute_penalty(error, dummy_w)
            total_error += error.mean()
            preds_per_env.append(model(inputs))
        
        ib_penalty = torch.stack(preds_per_env).var(1).mean()

        optimizer.zero_grad()
        (1e-5 * total_error + irm_penalty + 1e1 * ib_penalty).backward()
        optimizer.step()

        if epoch % print_interval == 0:
            print(f"Epoch {epoch}: {model.linear.weight}")

#### Regression (Example 1) - IB-IRM

Observe how IB-IRM performs as well as IRM on regression. According to the original paper: https://arxiv.org/pdf/2106.06607.pdf, this is a result of the IRM penalty 'taking over'. 

In [7]:
# IB-IRM
model = LinearModel(dim_inv + dim_spu, 1)
dummy_w = Parameter(torch.Tensor([1.0]))
optimizer = Adam(model.parameters(), lr=learning_rate)

task = "regression"

# Training
example1 = Example1Dataset(dim_inv, dim_spu, n_envs)
environments = [example1.sample(n=1000, env=f"E{i}") for i in range(n_envs)]
train_ibirm(model, dummy_w, environments, optimizer, task = task)

# Testing
test_data = [example1.sample(n=200, env=f"E{i}", split="test") for i in range(n_envs)]
for test_set in test_data:
    test_loss = test_model(model, dummy_w, test_set, task)
    print("Test MSE Error:", test_loss)
del model

Epoch 0: Parameter containing:
tensor([[-0.2418,  0.0225,  0.0243,  0.0367, -0.2227, -0.2023, -0.0584, -0.0800,
          0.0613, -0.3167]], requires_grad=True)
Epoch 10000: Parameter containing:
tensor([[ 3.7363e-08,  1.9732e-08, -4.8332e-08, -9.8523e-08,  2.3469e-07,
          1.8868e-07,  2.3610e-07,  8.1600e-08, -2.7507e-07, -1.8108e-07]],
       requires_grad=True)
Epoch 20000: Parameter containing:
tensor([[-3.3192e-06, -3.3084e-06,  3.2554e-06,  3.2045e-06, -3.0792e-06,
         -3.1461e-06, -3.1019e-06, -3.2707e-06,  3.0580e-06,  3.1414e-06]],
       requires_grad=True)
Epoch 30000: Parameter containing:
tensor([[-6.3467e-08, -8.2543e-08,  5.3417e-08,  3.3788e-09,  1.3358e-07,
          8.5639e-08,  1.3429e-07, -1.8336e-08, -1.7356e-07, -7.9981e-08]],
       requires_grad=True)
Epoch 40000: Parameter containing:
tensor([[ 3.7631e-08,  1.7183e-08, -4.6454e-08, -9.5986e-08,  2.3369e-07,
          1.8624e-07,  2.3529e-07,  7.9253e-08, -2.7390e-07, -1.8095e-07]],
       requires_gr

#### Classification (Example 2) - IB-IRM

Observe how the classification performance is much better, which the authors attribute to the IB penalty.

In [70]:
# IB-IRM
model = LinearModel(dim_inv + dim_spu, 1)
dummy_w = Parameter(torch.Tensor([1.0]))
optimizer = Adam(model.parameters(), lr=learning_rate)

task = "classification"

# Training
example2 = Example2Dataset(dim_inv, dim_spu, n_envs)
environments = [example2.sample(n=1000, env=f"E{i}") for i in range(n_envs)]
train_ibirm(model, dummy_w, environments, optimizer, task = task)

# Testing
test_data = [example2.sample(n=200, env=f"E{i}", split="test") for i in range(n_envs)]
for test_set in test_data:
    test_loss = test_model(model, dummy_w, test_set, task)
    print("Classification Error:", test_loss)
del model

Environments variables: {'E0': {'p': 0.95, 's': 0.3}, 'E1': {'p': 0.97, 's': 0.5}, 'E2': {'p': 0.99, 's': 0.7}}
Epoch 0: Parameter containing:
tensor([[-0.0835,  0.2956,  0.1429,  0.0585,  0.1501, -0.3108,  0.2187, -0.0169,
          0.0023, -0.1063]], requires_grad=True)
Epoch 10000: Parameter containing:
tensor([[ 1.6629e-05,  1.5276e-05,  1.2986e-05,  1.2288e-05,  1.4347e-05,
          2.1638e-08,  3.0420e-08,  8.2188e-09, -2.1430e-08,  2.7348e-08]],
       requires_grad=True)
Epoch 20000: Parameter containing:
tensor([[ 1.7462e-05,  1.4792e-05,  1.2483e-05,  1.2227e-05,  1.4545e-05,
          2.1551e-08,  3.0551e-08,  8.3012e-09, -2.1676e-08,  2.7806e-08]],
       requires_grad=True)
Epoch 30000: Parameter containing:
tensor([[ 1.7489e-05,  1.4813e-05,  1.2492e-05,  1.2235e-05,  1.4554e-05,
          3.6503e-08,  4.5502e-08,  2.3257e-08, -6.7207e-09,  4.2766e-08]],
       requires_grad=True)
Epoch 40000: Parameter containing:
tensor([[ 1.5770e-05,  1.3088e-05,  1.0752e-05,  1.0496e

# IGA

The loss function for this method is:

\begin{equation}
    \min_{\theta} \sum_{e \in [m]} R^e(f_\theta) + \alpha\sum_{e \in [m]}\Bigg|\Bigg| \nabla_{\theta}R^e(f_\theta) - \frac{1}{m}\sum_{e\in [m]}\nabla_\theta R^e(f_\theta)\Bigg|\Bigg|^2,
\end{equation}

where the first term is the average training risk and the second term is the inter-environmental gradient discrepancy. This formulation is claimed to agree with the general formulation of the IRM objective (in contrast to the practical IRM objective (\ref{eq:2}) for categorical target). $\alpha$ is the hyperparameter that balances the expressiveness and invariance of the predictor.

In [60]:
def train_iga(model, environments, optimizer, num_epochs=50000, print_interval=10000, task="regression"):
    if task == "regression":
        loss = MSELoss(reduction="none")
    elif task == "classification":
        loss = BCEWithLogitsLoss(reduction="none")
    else:
        raise ValueError("Choose b/w 'regression' and 'classification' only.")

    for epoch in range(num_epochs):
        total_error = []
        gradients = []
        for inputs, targets in environments:
            error = loss(inputs @ model.linear.weight.T, targets)
            total_error.append(error)
            gradient = grad(error.mean(), model.linear.weight, create_graph=True)
            gradients.append(gradient)

        avg_loss = torch.stack(total_error).mean()
        avg_gradient = grad(avg_loss, model.linear.weight, create_graph=True)

        penalty_value = 0
        for gradient in gradients:
            for gradient_i, avg_grad_i in zip(gradient, avg_gradient):
                penalty_value += (gradient_i - avg_grad_i).pow(2).sum()

        optimizer.zero_grad()
        (1e-3 * avg_loss + penalty_value).backward()
        optimizer.step()

        if epoch % print_interval == 0:
            print(f"Epoch {epoch}: {model.linear.weight}")

#### Regression (Example 1) - IGA

Observe that this method performs much worse than the prior methods on task 1. (and much worse than the results in table 1 of Aubin et al. : https://arxiv.org/pdf/2102.10867.pdf). Still, note that this is a minimal implementation and since this is explicitly a gradient-based method, its optimization is definitely more finicky (especially with loss balancing).

In [64]:
# IGA
model = LinearModel(dim_inv + dim_spu, 1)
dummy_w = Parameter(torch.Tensor([1.0]))
optimizer = Adam(model.parameters(), lr=learning_rate)

task = "regression"

# Training
example1 = Example1Dataset(dim_inv, dim_spu, n_envs)
environments = [example1.sample(n=1000, env=f"E{i}") for i in range(n_envs)]
train_iga(model, environments, optimizer, task = task)

# Testing
test_data = [example1.sample(n=200, env=f"E{i}", split="test") for i in range(n_envs)]
for test_set in test_data:
    test_loss = test_model(model, dummy_w, test_set, task)
    print("Test MSE Error:", test_loss)
del model

Epoch 0: Parameter containing:
tensor([[ 0.1296, -0.0946, -0.2574, -0.3072,  0.0223, -0.2589, -0.1158,  0.2420,
          0.1853, -0.1739]], requires_grad=True)
Epoch 10000: Parameter containing:
tensor([[ 0.0248, -0.8683, -0.1423, -0.4550,  0.7380, -3.6888,  0.0421,  2.0978,
         -2.1279,  1.5756]], requires_grad=True)
Epoch 20000: Parameter containing:
tensor([[ 0.0250, -0.8683, -0.1424, -0.4551,  0.7380, -3.6923,  0.0394,  2.1005,
         -2.1300,  1.5748]], requires_grad=True)
Epoch 30000: Parameter containing:
tensor([[ 0.0249, -0.8683, -0.1424, -0.4551,  0.7380, -3.6923,  0.0395,  2.1005,
         -2.1300,  1.5748]], requires_grad=True)
Epoch 40000: Parameter containing:
tensor([[ 0.0250, -0.8683, -0.1424, -0.4551,  0.7380, -3.6923,  0.0395,  2.1005,
         -2.1300,  1.5748]], requires_grad=True)
Test MSE Error: 26.61590003967285
Test MSE Error: 39.50508117675781
Test MSE Error: 66.31356811523438


#### Classification (Example 2) - IGA

Interestingly, this minimal implementation of IGA performs amazingly well on task 2, which is in stark contrast to the results provided in table 1 of Aubin et al. : https://arxiv.org/pdf/2102.10867.pdf. 

I am really not sure what to make of this discrepancy... 

In [66]:
# IGA
model = LinearModel(dim_inv + dim_spu, 1)
dummy_w = Parameter(torch.Tensor([1.0]))
optimizer = Adam(model.parameters(), lr=learning_rate)

task = "classification"

# Training
example2 = Example2Dataset(dim_inv, dim_spu, n_envs)
environments = [example2.sample(n=1000, env=f"E{i}") for i in range(n_envs)]
train_iga(model, environments, optimizer, task = task)

# Testing
test_data = [example2.sample(n=200, env=f"E{i}", split="test") for i in range(n_envs)]
for test_set in test_data:
    test_loss = test_model(model, dummy_w, test_set, task)
    print("Classification Error:", test_loss)
del model

Environments variables: {'E0': {'p': 0.95, 's': 0.3}, 'E1': {'p': 0.97, 's': 0.5}, 'E2': {'p': 0.99, 's': 0.7}}
Epoch 0: Parameter containing:
tensor([[-0.2706,  0.2919,  0.1161, -0.1315,  0.1581, -0.0415,  0.0427,  0.2549,
          0.3170,  0.0948]], requires_grad=True)
Epoch 10000: Parameter containing:
tensor([[ 8.1607,  8.5869,  8.5010,  8.0091,  8.2896,  5.9700,  1.1304, -4.7443,
          3.3920, -3.2336]], requires_grad=True)
Epoch 20000: Parameter containing:
tensor([[17.1763, 17.8349, 17.6177, 17.0925, 17.1829,  5.6891,  5.3499, -7.2129,
          3.8819, -4.1236]], requires_grad=True)
Epoch 30000: Parameter containing:
tensor([[27.7137, 28.1844, 28.0984, 27.7895, 27.9050,  3.0330, -0.5247, -2.1420,
          1.0975, -0.6182]], requires_grad=True)
Epoch 40000: Parameter containing:
tensor([[37.2125, 37.8030, 37.6419, 37.3777, 37.4289,  2.5764, -0.6894, -1.3896,
          0.7365, -0.5499]], requires_grad=True)
Classification Error: 0.16034336388111115
Classification Error: 0.1