In [15]:
# Minimal implementation of Invariant Risk Minimization (IRM) in PyTorch.
# Paper: https://arxiv.org/pdf/1907.02893v1.pdf

#!pip install torch torchvision torchaudio

In [16]:
import torch
from torch.autograd import grad

In [17]:
def compute_penalty(losses, dummy_w):
    """Return the summed product of two half-batch gradients w.r.t. ``dummy_w``.

    ``losses`` is split into its even-indexed and odd-indexed entries along
    the first dimension.  For each half, the mean loss is differentiated
    with respect to ``dummy_w``; the two gradients are multiplied
    element-wise and the product is summed.

    Both gradients are taken with ``create_graph=True`` so the returned
    value stays differentiable and can be part of a loss that is later
    back-propagated.
    """
    half_gradients = []
    for offset in (0, 1):
        # offset 0 -> rows 0, 2, 4, ...; offset 1 -> rows 1, 3, 5, ...
        half_mean = losses[offset::2].mean()
        (half_grad,) = grad(half_mean, dummy_w, create_graph=True)
        half_gradients.append(half_grad)

    even_grad, odd_grad = half_gradients
    return (even_grad * odd_grad).sum()

def example_1(n=10000, d=2, env=1):
    """Sample one synthetic environment of ``n`` points.

    Three ``(n, d)`` standard-normal tensors are drawn in this order:

    1. ``x``  = noise scaled by ``env``
    2. ``y``  = ``x`` + noise scaled by ``env``
    3. ``z``  = ``y`` + unscaled noise

    Returns a tuple ``(features, target)`` where ``features`` is ``x`` and
    ``z`` concatenated along dim 1 (shape ``(n, 2 * d)``) and ``target`` is
    the row-wise sum of ``y`` with shape ``(n, 1)``.
    """
    first_block = torch.randn(n, d).mul(env)
    response = first_block + torch.randn(n, d).mul(env)
    # Only this last noise term is independent of `env`.
    second_block = response + torch.randn(n, d)

    features = torch.cat([first_block, second_block], dim=1)
    target = response.sum(dim=1, keepdim=True)
    return features, target

In [18]:
# Seed the global RNG so the sampled environments and the per-step shuffles
# drawn with torch.randperm are reproducible on Restart & Run All.
torch.manual_seed(0)

# Linear featurizer: one weight per input column.  example_1 with its default
# d=2 returns 2 * d = 4 feature columns, hence the (4, 1) shape.
phi = torch.nn.Parameter(torch.ones(4, 1))
# Scalar multiplier on the predictions; it is only differentiated against
# (see compute_penalty) and is deliberately not handed to the optimizer.
dummy_w = torch.nn.Parameter(torch.tensor([1.0]))

opt = torch.optim.SGD([phi], lr=1e-3)
# reduction="none" keeps one loss value per element so that compute_penalty
# can split the losses into even- and odd-indexed halves.
mse = torch.nn.MSELoss(reduction="none")

# Two training environments that differ only in the `env` noise scale.
environments = [
    example_1(env=0.1),
    example_1(env=1.0),
]

In [None]:
# Training loop: each step sums, over all environments, the mean squared error
# and the gradient penalty, then takes one SGD step on `phi`.
for iteration in range(50000):
    error, penalty = 0, 0
    for x_e, y_e in environments:
        # New random order every step, so the even/odd split performed inside
        # compute_penalty uses different sample halves each time.
        p = torch.randperm(x_e.shape[0])
        prediction_e = (x_e[p] @ phi) * dummy_w
        error_e = mse(prediction_e, y_e[p])
        penalty += compute_penalty(error_e, dummy_w)
        error += error_e.mean()

    opt.zero_grad()
    # The prediction error is weighted by 1e-5 relative to the penalty.
    objective = 1e-5 * error + penalty
    objective.backward()
    opt.step()

    # Report the current featurizer weights every 1000 steps.
    if iteration % 1000 == 0:
        print(phi)

Parameter containing:
tensor([[0.8543],
        [0.8573],
        [0.6790],
        [0.6883]], requires_grad=True)
Parameter containing:
tensor([[0.9389],
        [0.9409],
        [0.1639],
        [0.1647]], requires_grad=True)
Parameter containing:
tensor([[0.9757],
        [0.9769],
        [0.1118],
        [0.1120]], requires_grad=True)
Parameter containing:
tensor([[0.9886],
        [0.9896],
        [0.0881],
        [0.0882]], requires_grad=True)
Parameter containing:
tensor([[0.9949],
        [0.9960],
        [0.0742],
        [0.0744]], requires_grad=True)
Parameter containing:
tensor([[0.9986],
        [0.9999],
        [0.0645],
        [0.0654]], requires_grad=True)
Parameter containing:
tensor([[1.0011],
        [1.0024],
        [0.0579],
        [0.0587]], requires_grad=True)
Parameter containing:
tensor([[1.0028],
        [1.0042],
        [0.0528],
        [0.0533]], requires_grad=True)
Parameter containing:
tensor([[1.0042],
        [1.0056],
        [0.0485],
    