## Example with the 3D Shapes dataset

In [1]:
from tqdm import tqdm
import time
import argparse
import os
import sys
from abc import abstractmethod
from torch.nn import functional as F

In [2]:
import pandas as pd
import numpy as np
import einops

In [3]:
import torch
from torch import nn
import torch.nn.init as init
from torchvision import datasets
from torchvision import transforms as T
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader

In [4]:
# Put the 3dshape_dataset directory first on the import path, presumably so
# that the `from dataset import *` in the next cell resolves to the project's
# own module.
# NOTE(review): this is an absolute, user-specific path; it has to be edited
# when the notebook is run from another account or machine.
sys.path.insert(0,"/nfshomes/xliu1231/Causal_Disentangle/3dshape_dataset")

In [5]:
from dataset import *

### Prepare dataset

In [6]:
# Distribution over the values of each listed factor (presumably one relative
# weight per possible value -- confirm against sample_3dshape_dataset).
# Factors that are not listed are sampled randomly.
dist = dict(
    shape=[1, 2, 1, 4],
    object_hue=[1, 1, 1, 2, 2, 3, 1, 3, 3, 1],
)

In [None]:
# Draw 5000 samples according to `dist` and save them: the data goes under
# `data_path`, the labels into the CSV at `label_path`.
sample_3dshape_dataset(
    dist,
    data_size=5000,
    data_path='/nfshomes/xliu1231/Causal_Disentangle/data',
    label_path='/nfshomes/xliu1231/Causal_Disentangle/data/sample.csv',
)

In [None]:
# torchvision's ToTensor transform; handed to ShapesDataset below as its
# `transform` argument.
transform = T.ToTensor()

In [None]:
# _FACTORS_IN_ORDER = ['floor_hue', 'wall_hue', 'object_hue', 'scale', 'shape', 'orientation']
# `attr` lists the indices of the factors of interest that will be used as
# labels (presumably indices into _FACTORS_IN_ORDER -- confirm against
# ShapesDataset).

train_dataset = ShapesDataset(
    data_path='/nfshomes/xliu1231/Causal_Disentangle/data',
    attr_path='/nfshomes/xliu1231/Causal_Disentangle/data/sample.csv',
    # The original `attr=attr=[...]` repeated the keyword and did not parse.
    attr=[0, 1, 2, 3],
    transform=transform,
)

In [None]:
# Shuffled mini-batches of 64 from the training set, loaded by one worker
# process into pinned host memory.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=1,
    pin_memory=True,
)

In [None]:
class data_prefetcher():
    """Wrap a loader and stage each upcoming batch on the GPU ahead of use.

    The next ``(input, target)`` pair is pulled from the loader and its copy
    to the GPU is issued on a dedicated CUDA stream.  ``next()`` hands out the
    staged pair and immediately starts staging the following one.  Once the
    loader is exhausted, ``next()`` returns ``(None, None)``.
    """

    def __init__(self, loader):
        self.loader = iter(loader)
        # Side stream on which the host-to-device copies are issued.
        self.stream = torch.cuda.Stream()
        self.preload()

    def preload(self):
        """Fetch the next batch and enqueue its transfer to the GPU."""
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            # Loader exhausted: mark the end with a pair of Nones.
            self.next_input = self.next_target = None
            return

        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)

    def next(self):
        """Return the staged ``(input, target)`` pair and stage the next one."""
        consumer = torch.cuda.current_stream()
        # Do not let the consumer stream run ahead of the pending copies.
        consumer.wait_stream(self.stream)
        staged = (self.next_input, self.next_target)
        for tensor in staged:
            if tensor is not None:
                # Tell the allocator the tensor is in use on the consumer stream.
                tensor.record_stream(consumer)
        self.preload()
        return staged

### Now learn a causal map from labels/data

In [None]:
class CausalDAG(nn.Module):
    """Learnable linear structural causal model over a set of concepts.

    The module owns a trainable ``num_concepts x num_concepts`` matrix ``A``
    (initialised to zeros) that plays the role of the causal diagram, plus one
    small MLP per concept that is applied after the linear "mask" step
    ``A^T x``.

    Args:
        num_concepts: number of concepts (nodes of the diagram).
        dim_per_concept: number of dimensions used to represent one concept in
            the latent ``z`` path.
        inference: accepted for interface compatibility; not used by this class.
        bias: if True, a trainable bias of length ``num_concepts`` (initialised
            to zeros) is added by ``calculate_z`` / ``calculate_epsilon``.
        g_dim: hidden width of the per-concept MLPs.
    """

    def __init__(self, num_concepts, dim_per_concept, inference=False, bias=False, g_dim=32):
        super().__init__()
        self.num_concepts = num_concepts
        self.dim_per_concept = dim_per_concept

        self.A = nn.Parameter(torch.zeros(num_concepts, num_concepts))
        # Constant identity, kept as a frozen parameter so that it follows the
        # module across devices and stays in ``parameters()`` / ``state_dict()``.
        self.I = nn.Parameter(torch.eye(num_concepts), requires_grad=False)
        if bias:
            # The original referenced the undefined name ``Parameter`` and used
            # uninitialised memory; use nn.Parameter with a defined start value.
            self.bias = nn.Parameter(torch.zeros(num_concepts))
        else:
            self.register_parameter('bias', None)

        # One MLP per concept for the latent path (dim_per_concept -> same) ...
        self.nets_z = nn.ModuleList(
            nn.Sequential(
                nn.Linear(dim_per_concept, g_dim),
                nn.ELU(),
                nn.Linear(g_dim, dim_per_concept),
            )
            for _ in range(num_concepts)
        )
        # ... and one per concept for the scalar label path (1 -> 1).
        self.nets_label = nn.ModuleList(
            nn.Sequential(
                nn.Linear(1, g_dim),
                nn.ELU(),
                nn.Linear(g_dim, 1),
            )
            for _ in range(num_concepts)
        )

    def _linear_over_concepts(self, x, weight):
        """Apply ``weight`` (and ``self.bias``) along the concept axis of ``x``.

        ``x`` is either 2-D with the concept axis last, or 3-D with the concept
        axis in position 1; the 3-D case is permuted so that ``F.linear`` acts
        on the concept axis and then permuted back.
        """
        if x.dim() > 2:  # one concept is represented by multiple dimensions
            out = F.linear(x.permute(0, 2, 1), weight, self.bias)
            return out.permute(0, 2, 1).contiguous()
        return F.linear(x, weight, self.bias)

    def calculate_z(self, epsilon):
        """Convert epsilon to z using the SCM assumption and causal diagram A.

        Computes ``z = (I - A^T)^{-1} epsilon`` (plus the bias, if any).
        """
        C = torch.inverse(self.I - self.A.t())
        return self._linear_over_concepts(epsilon, C)

    def calculate_epsilon(self, z):
        """Convert z back to epsilon using the causal diagram A.

        Computes ``epsilon = (I - A^T) z`` (plus the bias, if any).

        NOTE(review): when ``bias`` is enabled the bias is *added* here as well
        as in ``calculate_z``, so the two methods are exact inverses of each
        other only without a bias -- confirm whether that is intended.
        """
        C_inv = self.I - self.A.t()
        # The original 2-D branch referenced the undefined name ``C``.
        return self._linear_over_concepts(z, C_inv)

    def mask(self, x):
        """Return ``A^T x``; a 2-D ``x`` first gets a trailing singleton axis."""
        if x.dim() == 2:
            x = x.unsqueeze(dim=-1)
        return torch.matmul(self.A.t(), x)

    def g_z(self, x):
        """Apply the per-concept latent MLPs (nonlinearity for a more stable
        approximation).

        Returns a tensor reshaped to ``(-1, num_concepts, dim_per_concept)``.
        """
        flat = x.reshape(-1, self.num_concepts * self.dim_per_concept)
        chunks = torch.split(flat, self.dim_per_concept, dim=1)
        mapped = [net(chunk) for net, chunk in zip(self.nets_z, chunks)]
        return torch.cat(mapped, dim=1).reshape(
            [-1, self.num_concepts, self.dim_per_concept]
        )

    def g_label(self, x):
        """Apply the per-concept label MLPs (nonlinearity for a more stable
        approximation).

        Returns a tensor reshaped to ``(-1, num_concepts)``.
        """
        flat = x.reshape(-1, self.num_concepts)
        chunks = torch.split(flat, 1, dim=1)
        mapped = [net(chunk) for net, chunk in zip(self.nets_label, chunks)]
        return torch.cat(mapped, dim=1).reshape([-1, self.num_concepts])

    def forward(self, x, islabel=False):
        """Mask ``x`` with ``A^T`` and run the matching per-concept MLPs.

        With ``islabel=True`` the scalar label networks are used, otherwise the
        latent (``dim_per_concept``-wide) networks.
        """
        masked = self.mask(x)
        if islabel:
            return self.g_label(masked)
        return self.g_z(masked)


In [None]:
def _h_A(A, m):
    A_square = A * A
    x = torch.eye(m).cuda() + torch.div(A_square, m)
    expm_A = torch.matrix_power(x, m)
    h_A = torch.trace(expm_A) - m
    return h_A

In [None]:
# Causal graph over 4 concepts, each represented by a single dimension, with
# per-concept MLPs of hidden width 1; moved to the GPU and put in training mode.
graph = CausalDAG(num_concepts=4, dim_per_concept=1, g_dim=1)
graph = graph.cuda()
graph.train()

In [None]:
# SGD with momentum and weight decay over the graph's parameters.
# The learning-rate expression evaluates to 0.005.
optimizer = torch.optim.SGD(
    graph.parameters(),
    lr=0.01 * float(128) / 256.,
    momentum=0.9,
    weight_decay=5e-4,
)

In [None]:
# Wrap the training loader; constructing the prefetcher already fetches the
# first batch and starts moving it to the GPU.
prefetcher = data_prefetcher(train_loader)

In [None]:
# Pull the first (input, target) pair; the training loop below keeps going
# until `input` comes back as None.
# NOTE(review): `input` shadows the Python builtin of the same name.
input, target = prefetcher.next()
# Acyclicity penalty of the initial adjacency matrix (4 = number of concepts).
h_a = _h_A(graph.A, 4)

In [None]:
# Coefficient that multiplies the acyclicity penalty h_a in the training loss;
# the loop multiplies it by 10 whenever h_a is larger than a quarter of its
# previous value.
c = 1
# Multiplier updated by the loop as alpha + c * h_a; it only appears in the
# commented-out loss term there, so it does not influence training as written.
alpha = 1
# Value of h_a from the previous iteration (0 before the first one).
h_a_prev = 0

In [None]:
# One pass over the prefetched batches: regress the labels through the causal
# graph and penalise cycles in graph.A with the weight c.
while input is not None:

    target = target.cuda()
    output = graph(target, islabel=True)

    # Acyclicity penalty of the current adjacency matrix, computed by the
    # shared helper rather than a hand-copied duplicate of its body.
    m = graph.A.size()[0]
    h_a = _h_A(graph.A, m)

    # mse_loss takes (input, target); the value is the same in either order.
    loss = F.mse_loss(output, target) + c * h_a

    #loss += alpha * h_a + 0.5 * c * h_a * h_a

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print("loss:", loss)

    input, target = prefetcher.next()

    # Penalty bookkeeping works on a plain Python float, so that alpha and
    # h_a_prev do not hold on to this iteration's autograd graph.
    h_a_value = h_a.item()

    alpha = alpha + c * h_a_value

    # Tighten the penalty when h_a did not fall below 25% of its last value.
    if h_a_value > 0.25 * h_a_prev:
        c = 10 * c

    h_a_prev = h_a_value
    
    
    
    
    
    
    