In [1]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [2]:
import sys
sys.path.append('../')

from dataset import*
from baselines import*
from synthetic_concept_model import *
from synthetic_coop_model import *
from torch.utils.data import DataLoader
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import pickle
import torch
import numpy
import random
import os

random.seed(7)
numpy.random.seed(seed=7)
torch.manual_seed(7)

<torch._C.Generator at 0x7f1909f81770>

In [3]:
# !nvidia-smi
torch.cuda.is_available()

True

## Data generation

In [4]:
feature_dim_info = dict()
label_dim_info = dict()
transform_dim = 10000

intersections = get_intersections(num_modalities=2)

feature_dim_info['12'] = 10
feature_dim_info['1'] = 6
feature_dim_info['2'] = 6

label_dim_info['12'] = 10
label_dim_info['1'] = 6
label_dim_info['2'] = 6
num_concepts = 1
transforms_2concept = None
transforms_2hd = None
num_data = 1000
noise=0.3
pos_prob=0.5
total_data, total_labels, total_concepts, total_raw_features = generate_data_concepts(num_data, num_concepts,
                                                                                      feature_dim_info,
                                                                                      label_dim_info,
                                                                                      transform_dim=transform_dim,
                                                                                     noise=noise,
                                                                                     pos_prob=pos_prob)


Current generated data : 0
Current generated data : 100
Current generated data : 200
Current generated data : 300
Current generated data : 400
Current generated data : 500
Current generated data : 600
Current generated data : 700
Current generated data : 800
Current generated data : 900


In [5]:
# total_raw_features

In [6]:
# Data splitting & loading
# dataset = MultiConcept(total_data, total_labels, total_concepts, 0)
dataset = MultiConcept_w_Features(total_data, total_labels, total_concepts, 0, total_raw_features)

batch_size = 100
trainval_dataset, test_dataset = torch.utils.data.random_split(dataset,  
                                                            [int(0.7 * num_data), num_data - int(0.7 * num_data)])
train_dataset, val_dataset = torch.utils.data.random_split(trainval_dataset,
                                                           [int(0.8 * len(trainval_dataset)), len(trainval_dataset) - int(0.8 * len(trainval_dataset))])

## Experiment 1


## Scenario 1 
One known concept $c_1$ derived from information components $W_{U_1}, W_s$. Label $Y$ is composed of information components $y=f(W_{U_1}, W_s, W_{U_2})$. We try to recover $W_{U_2}$ by $\arg \max_{Z_x} I(Z_x;Y|Z_{c_1})$, assuming that $Z_{c_1}$ represents $\{W_{U_1}, W_s\}$


In [21]:
"""

!!! CHANGE THE train_concept_encoder AND  train_concept_informed_model FUNCTIONS IF NEEDED!!!

"""

def factorCBM_exp1(train_dataset, test_dataset,  num_eval=10, save_path='./results'): #concept_encoder, model, device,

    # models
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    hidden_dim = 512
    embed_dim = 50
  
    acc_list, pre_list, recall_list, f1_list = [], [], [], []
    # teval = tqdm(range(num_eval))
    for idx in range(num_eval):
        concept_encoder = ConceptEncoder(transform_dim, embed_dim, 1, hidden_dim).to(device)
        model = ConceptCLSUP_full_concept(transform_dim, embed_dim, 2, hidden_dim, embed_dim).to(device) #ConceptCLSUP_full_concept
        train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True,
                          batch_size=batch_size)
        val_loader = DataLoader(val_dataset, shuffle=True, batch_size=batch_size, drop_last=True)
        test_loader = DataLoader(test_dataset, shuffle=False, drop_last=False)
        # train concpet encoder
        trained_concept_encoder = train_concept_encoder(concept_encoder, train_loader,val_loader, 1000, device, 1e-5, 1e-5, 25, 3, '../trained_models')
        # train concept informed model
        trained_concept_informed_model = train_concept_informed_model(trained_concept_encoder, model, train_loader, val_loader, 1000, device, 1e-5, 25, 3, '../trained_models') #train_concept_informed_model
        
        train_embeds_1 = trained_concept_encoder.get_embedding(torch.stack([sample[0] for sample in  train_dataset]).to(device)).detach().cpu().numpy() #torch.stack([sample[1] for sample in  train_dataset]).to(device).detach().cpu().numpy()#
        train_embeds_2 = trained_concept_informed_model.get_embedding(torch.stack([sample[0] for sample in  train_dataset]).to(device)).detach().cpu().numpy()
        train_embeds = np.concatenate((train_embeds_1, train_embeds_2), axis=1) #train_embeds_2 #
        train_labels = np.array([sample[2].item() for sample in  train_dataset])

        test_embeds_1 = trained_concept_encoder.get_embedding(torch.stack([sample[0] for sample in  test_dataset]).to(device)).detach().cpu().numpy() #torch.stack([sample[1] for sample in  test_dataset]).to(device).detach().cpu().numpy()#
        test_embeds_2 = trained_concept_informed_model.get_embedding(torch.stack([sample[0] for sample in  test_dataset]).to(device)).detach().cpu().numpy()
        test_embeds = np.concatenate((test_embeds_1, test_embeds_2), axis=1) #test_embeds_2 #
        test_labels = np.array([sample[2].item() for sample in  test_dataset])

        clf = LogisticRegression(max_iter=1000).fit(train_embeds, train_labels)
        predictions = clf.predict(test_embeds)
        
        accuracy = accuracy_score(test_labels, predictions)
        precision = precision_score(test_labels, predictions)
        recall = recall_score(test_labels, predictions)
        f1 = f1_score(test_labels, predictions)
        
        acc_list.append(accuracy)
        pre_list.append(precision)
        recall_list.append(recall)
        f1_list.append(f1)
    
    dict_results = {'accuracy':acc_list,
                    'precision':pre_list,
                    'recall':recall_list,
                    'f1_score':f1_list}
    df_results = pd.DataFrame(data=dict_results)
    
    print(f'Accuracy:{df_results.accuracy.mean():0.3f} \u00B1 {2*df_results.accuracy.std():0.3f}')
    print(f'Precision:{df_results.precision.mean():0.3f} \u00B1 {2*df_results.precision.std():0.3f}')
    print(f'Recall:{df_results.recall.mean():0.3f} \u00B1 {2*df_results.recall.std():0.3f}')
    print(f'F1-score:{df_results.f1_score.mean():0.3f} \u00B1 {2*df_results.f1_score.std():0.3f}')
    
    directory = save_path + '/' + time.strftime("%Y%m%d")
    
    if not os.path.exists(directory):
        os.makedirs(directory)
    
    timestr = time.strftime("%H%M%S")
    file_path = directory +  '/factorCBM_exp1_'+ str(num_eval) + '_' + timestr +'.csv'
    df_results.to_csv(file_path)
    
    
    return df_results

In [22]:
results = factorCBM_exp1(train_dataset, test_dataset, num_eval=10, save_path='./results') #concept_encoder, model, device, 

Epoch 6:   1%|          | 6/1000 [00:00<01:04, 15.35it/s, loss=2.31]


Val loss did not improve


Epoch 39:   4%|▍         | 39/1000 [00:08<03:17,  4.87it/s, loss=-.0142]  


Val loss did not improve


Epoch 6:   1%|          | 6/1000 [00:00<01:01, 16.24it/s, loss=2.16]


Val loss did not improve


Epoch 39:   4%|▍         | 39/1000 [00:07<03:05,  5.17it/s, loss=-.0153]  


Val loss did not improve


Epoch 6:   1%|          | 6/1000 [00:00<01:21, 12.21it/s, loss=2.5] 


Val loss did not improve


Epoch 6:   1%|          | 6/1000 [00:01<03:48,  4.34it/s, loss=-6.73e-5]


Val loss did not improve


Epoch 9:   1%|          | 9/1000 [00:00<00:59, 16.76it/s, loss=2.46]


Val loss did not improve


Epoch 51:   5%|▌         | 51/1000 [00:09<03:04,  5.15it/s, loss=-.0516]  


Val loss did not improve


Epoch 9:   1%|          | 9/1000 [00:00<00:55, 17.97it/s, loss=2.66]


Val loss did not improve


Epoch 12:   1%|          | 12/1000 [00:02<03:38,  4.52it/s, loss=-.000131]


Val loss did not improve


Epoch 6:   1%|          | 6/1000 [00:00<00:55, 17.76it/s, loss=2.58]


Val loss did not improve


Epoch 51:   5%|▌         | 51/1000 [00:09<03:02,  5.20it/s, loss=-.0455]  


Val loss did not improve


Epoch 9:   1%|          | 9/1000 [00:00<01:07, 14.65it/s, loss=2.37]


Val loss did not improve


Epoch 54:   5%|▌         | 54/1000 [00:10<03:04,  5.14it/s, loss=-.0633]  


Val loss did not improve


Epoch 9:   1%|          | 9/1000 [00:00<00:51, 19.22it/s, loss=2.28]


Val loss did not improve


Epoch 57:   6%|▌         | 57/1000 [00:10<03:00,  5.22it/s, loss=-.0722]  


Val loss did not improve


Epoch 6:   1%|          | 6/1000 [00:00<00:59, 16.65it/s, loss=2.63]


Val loss did not improve


Epoch 60:   6%|▌         | 60/1000 [00:11<03:03,  5.13it/s, loss=-.111]   


Val loss did not improve


Epoch 6:   1%|          | 6/1000 [00:00<01:01, 16.24it/s, loss=1.94]


Val loss did not improve


Epoch 51:   5%|▌         | 51/1000 [00:09<03:02,  5.20it/s, loss=-.054]   


Val loss did not improve
Accuracy:0.736 ± 0.129
Precision:0.717 ± 0.133
Recall:0.793 ± 0.110
F1-score:0.750 ± 0.083


## Scenario 2

Use both the supervised loss and the Info_club constraint to minimize the mutual information between the learned representation and c, then concatenate c to the learned representation and train everything end-to-end.

$\mathop{\arg \min}\limits_{\theta , \phi} \mathcal{L}\bigl( y, f_{\theta}(g_{\phi}(x),c)\bigr) + \lambda Info_{NCE\_CLUB} \bigl( g_{\phi}(x);c \bigr)$


In [9]:
"""

!!! CHANGE THE train_concept_encoder AND  train_concept_informed_model FUNCTIONS IF NEEDED!!!

"""
    
class ConceptCLSUP_Sc2(nn.Module):
    def __init__(self, x_dim, hidden_dim, embed_dim, layers=2, activation='relu', lr=1e-4, concept_dim = 50):
        super(ConceptCLSUP_Sc2, self).__init__()
        self.critic_hidden_dim = 512
        self.critic_layers = 1
        self.critic_activation = 'relu'

        # encoders
        self.backbone = mlp(x_dim, hidden_dim, embed_dim, layers, activation)
        self.linears_infonce = mlp(embed_dim, embed_dim, embed_dim, 1, activation) 
        self.y_projection = mlp(embed_dim + concept_dim, embed_dim, 1, 1, activation)

        # critics
        self.club_critic = CLUBInfoNCECritic(embed_dim , concept_dim, self.critic_hidden_dim, self.critic_layers, self.critic_activation) #MINECritic(
        

    def forward(self, x, c):
        # compute embedding
        z = self.linears_infonce(self.backbone(x))
        # compute critic scores
        club_infonce_score = self.club_critic(z, c) #self.club_critic(torch.cat([z, x], dim=-1), c)self.club_critic(torch.cat([z, x], dim=-1), c) #
        y_encoding = self.y_projection(torch.cat([z,c], dim=-1))
        
        return club_infonce_score, y_encoding 

    
    def get_embedding(self, x):
        return self.backbone(x)
    
    def get_logits(self, x, c):
        z = self.linears_infonce(self.backbone(x))
        return self.y_projection(torch.cat([z,c], dim=-1))
    
    def get_backbone(self):
        return self.backbone
    
    
def train_concept_informed_model_sc2(concept_encoder, model, train_loader,val_loader, num_epochs, device, lr, lamb, log_interval,
                          save_interval, save_path):

    # concept_encoder.eval()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    label_loss_func = torch.nn.BCELoss()
    best_val_err = torch.tensor(1e7)
    tepoch = tqdm(range(num_epochs))
    # print(f'lamda {lamb}')
    for epoch in tepoch:
        tepoch.set_description(f"Epoch {epoch}")
        model.train()
        for batch_idx, (data, concept, target,_) in enumerate(train_loader):
            data, concept, target = data.to(device), concept.to(device) ,target.to(device)
            optimizer.zero_grad()
            c, z_c = concept_encoder(data)
            MI_loss, y_logits = model(data, z_c) #concept
            label_loss = label_loss_func(torch.sigmoid(y_logits), target.float())
            # print(f'label loss {label_loss}, MI loss {MI_loss}')
            loss = label_loss + lamb * MI_loss #
            loss.backward()
            optimizer.step()

            # torch.cuda.empty_cache()
            # gc.collect()
            
        tepoch.set_postfix(loss=loss.item())
        if epoch % save_interval == 0:
            val_err = 0
            model.eval()
            with torch.no_grad():
                for batch_idx, (data, concept, target,_) in enumerate(val_loader):
                    data, concept, target = data.to(device), concept.to(device), target.to(device)
                    c , z_c = concept_encoder(data)
                    # output = model(data, c, target) #z_c
                    MI_loss, y_logits = model(data, z_c)#concept
                    label_loss = label_loss_func(torch.sigmoid(y_logits), target.float())
                    loss = label_loss + lamb * MI_loss 
                    val_err += loss
                val_err = val_err / len(val_loader)
            if val_err < best_val_err:
                best_val_err = val_err

            else:
                print('Val loss did not improve')
                torch.save(model.state_dict(), os.path.join(save_path, 'concept_informed_model.pth'))
                return model
    return model

def factorCBM_exp1_Sc2(train_dataset, test_dataset, num_eval=10, save_path='./results'): #, concept_encoder, model, device,

    acc_list, pre_list, recall_list, f1_list = [], [], [], []
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    hidden_dim = 512
    embed_dim = 256
    
    # teval = tqdm(range(num_eval))
    for idx in range(num_eval):
        concept_encoder = ConceptEncoder(transform_dim, embed_dim, 1, hidden_dim).to(device)
        model = ConceptCLSUP_Sc2(transform_dim, hidden_dim, embed_dim, 2, 'relu').to(device)
        train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True,
                          batch_size=batch_size)
        val_loader = DataLoader(val_dataset, shuffle=True, batch_size=batch_size, drop_last=True)
        test_loader = DataLoader(test_dataset, shuffle=False, drop_last=False)
        # train concpet encoder
        trained_concept_encoder = train_concept_encoder(concept_encoder, train_loader,val_loader, 1000, device, 1e-5, 1e-5, 25, 3, '../trained_models')
        # train concept informed model concept_encoder
        trained_concept_informed_model = train_concept_informed_model_sc2(trained_concept_encoder, model, train_loader, val_loader, num_epochs=1000, device=device,
                                                                          lr=1e-5, lamb=0.5, log_interval = 25, save_interval = 3, save_path ='../trained_models')
        
        test_embeds =  torch.stack([sample[0] for sample in  test_dataset]).to(device)
        test_concepts = torch.tensor([sample[1] for sample in  test_dataset]).unsqueeze(1).to(device)
        test_labels = np.array([sample[2].item() for sample in  test_dataset])

        
        out = trained_concept_informed_model.get_logits(test_embeds, trained_concept_encoder.get_embedding(test_embeds)) #test_concepts
        
        predictions = torch.sigmoid(out).round().detach().cpu().numpy()
        
        accuracy = accuracy_score(test_labels, predictions)
        precision = precision_score(test_labels, predictions)
        recall = recall_score(test_labels, predictions)
        f1 = f1_score(test_labels, predictions)
        
        acc_list.append(accuracy)
        pre_list.append(precision)
        recall_list.append(recall)
        f1_list.append(f1)
    
    dict_results = {'accuracy':acc_list,
                    'precision':pre_list,
                    'recall':recall_list,
                    'f1_score':f1_list}
    df_results = pd.DataFrame(data=dict_results)
    
    print(f'Accuracy:{df_results.accuracy.mean():0.3f} \u00B1 {2*df_results.accuracy.std():0.3f}')
    print(f'Precision:{df_results.precision.mean():0.3f} \u00B1 {2*df_results.precision.std():0.3f}')
    print(f'Recall:{df_results.recall.mean():0.3f} \u00B1 {2*df_results.recall.std():0.3f}')
    print(f'F1-score:{df_results.f1_score.mean():0.3f} \u00B1 {2*df_results.f1_score.std():0.3f}')
    
    directory = save_path + '/' + time.strftime("%Y%m%d")
    
    if not os.path.exists(directory):
        os.makedirs(directory)
    
    timestr = time.strftime("%H%M%S")
    file_path = directory +  '/factorCBM_exp1_Sc2_'+ str(num_eval) + '_' + timestr +'.csv'
    df_results.to_csv(file_path)
    
    
    return df_results

In [10]:
results = factorCBM_exp1_Sc2(train_dataset, test_dataset, num_eval=10, save_path='./results') #concept_encoder, model, device,

Epoch 6:   1%|          | 6/1000 [00:00<00:58, 16.99it/s, loss=2.81]


Val loss did not improve


Epoch 3:   0%|          | 3/1000 [00:00<02:13,  7.44it/s, loss=0.702]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Val loss did not improve


Epoch 6:   1%|          | 6/1000 [00:00<01:01, 16.15it/s, loss=2.37]


Val loss did not improve


Epoch 3:   0%|          | 3/1000 [00:00<02:16,  7.31it/s, loss=0.692]


Val loss did not improve


Epoch 6:   1%|          | 6/1000 [00:00<01:01, 16.16it/s, loss=2.74]


Val loss did not improve


Epoch 3:   0%|          | 3/1000 [00:00<02:39,  6.27it/s, loss=0.691]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Val loss did not improve


Epoch 9:   1%|          | 9/1000 [00:00<00:57, 17.23it/s, loss=2.4] 


Val loss did not improve


Epoch 12:   1%|          | 12/1000 [00:01<01:39,  9.96it/s, loss=0.692]


Val loss did not improve


Epoch 6:   1%|          | 6/1000 [00:00<00:58, 16.91it/s, loss=2.2] 


Val loss did not improve


Epoch 3:   0%|          | 3/1000 [00:00<02:21,  7.03it/s, loss=0.692]


Val loss did not improve


Epoch 9:   1%|          | 9/1000 [00:00<00:56, 17.68it/s, loss=2.8] 


Val loss did not improve


Epoch 3:   0%|          | 3/1000 [00:00<02:20,  7.11it/s, loss=0.694]


Val loss did not improve


Epoch 3:   0%|          | 3/1000 [00:00<01:15, 13.14it/s, loss=2.55]


Val loss did not improve


Epoch 6:   1%|          | 6/1000 [00:00<02:00,  8.24it/s, loss=0.696]


Val loss did not improve


Epoch 9:   1%|          | 9/1000 [00:00<00:53, 18.42it/s, loss=2]   


Val loss did not improve


Epoch 9:   1%|          | 9/1000 [00:01<01:55,  8.57it/s, loss=0.701]


Val loss did not improve


Epoch 9:   1%|          | 9/1000 [00:00<00:55, 17.73it/s, loss=2.55]


Val loss did not improve


Epoch 6:   1%|          | 6/1000 [00:00<02:10,  7.63it/s, loss=0.691]


Val loss did not improve


Epoch 6:   1%|          | 6/1000 [00:00<01:30, 10.93it/s, loss=2.27]


Val loss did not improve


Epoch 6:   1%|          | 6/1000 [00:00<02:11,  7.53it/s, loss=0.688]


Val loss did not improve
Accuracy:0.496 ± 0.011
Precision:0.395 ± 0.416
Recall:0.800 ± 0.843
F1-score:0.529 ± 0.557


## Baseline 1 (logistic regression on $x$)

In [11]:
results_1 = baseline_1(train_dataset, test_dataset, num_eval=10)

Evaluation 9: 100%|██████████| 10/10 [00:11<00:00,  1.13s/it]

Accuracy:0.757 ± 0.000
Precision:0.739 ± 0.000
Recall:0.784 ± 0.000
F1-score:0.761 ± 0.000





## Baseline 2 (Supervised Representation Learning on $x$)

In [12]:
results_2 = baseline_2(train_dataset, val_dataset, test_dataset, transform_dim=transform_dim, batch_size=batch_size, num_eval=1)

Epoch 200:  20%|██        | 200/1000 [00:09<00:37, 21.18it/s, loss=0.000127]


Val loss did not improve
Accuracy:0.777 ± nan
Precision:0.745 ± nan
Recall:0.831 ± nan
F1-score:0.786 ± nan


## Baseline 3 

### (Logistic Regression on $x,c_1$)

In [13]:
resuts_3_A = baseline_3_A(train_dataset, test_dataset, num_eval=10, save_path='./results')

Evaluation 9: 100%|██████████| 10/10 [00:12<00:00,  1.20s/it]

Accuracy:0.860 ± 0.000
Precision:0.849 ± 0.000
Recall:0.872 ± 0.000
F1-score:0.860 ± 0.000





### (Supervised Representation Learning on $x,c_1$)

In [14]:
results_3_b = baseline_3_B(train_dataset, val_dataset, test_dataset, transform_dim=transform_dim, batch_size=batch_size, num_eval=10)

Epoch 200:  20%|██        | 200/1000 [00:09<00:38, 21.03it/s, loss=0.000136]


Val loss did not improve


Epoch 300:  30%|███       | 300/1000 [00:13<00:32, 21.64it/s, loss=4.51e-5] 


Val loss did not improve


Epoch 200:  20%|██        | 200/1000 [00:09<00:36, 21.75it/s, loss=0.000125]


Val loss did not improve


Epoch 200:  20%|██        | 200/1000 [00:09<00:37, 21.53it/s, loss=0.00011] 


Val loss did not improve


Epoch 200:  20%|██        | 200/1000 [00:09<00:37, 21.47it/s, loss=0.000137]


Val loss did not improve


Epoch 200:  20%|██        | 200/1000 [00:09<00:37, 21.44it/s, loss=0.000141]


Val loss did not improve


Epoch 100:  10%|█         | 100/1000 [00:04<00:42, 21.11it/s, loss=0.000795]


Val loss did not improve


Epoch 200:  20%|██        | 200/1000 [00:09<00:37, 21.52it/s, loss=0.000131]


Val loss did not improve


Epoch 300:  30%|███       | 300/1000 [00:13<00:32, 21.76it/s, loss=4.17e-5] 


Val loss did not improve


Epoch 400:  40%|████      | 400/1000 [00:18<00:27, 21.50it/s, loss=2.12e-5] 

Val loss did not improve
Accuracy:0.783 ± 0.015
Precision:0.766 ± 0.016
Recall:0.808 ± 0.028
F1-score:0.786 ± 0.016





## Baseline 4 (Multi-Task Learning with Concepts $x \rightarrow y, c_1$)

In [15]:
results_4 = baseline_4(train_dataset, val_dataset, test_dataset, transform_dim=transform_dim, batch_size=batch_size, num_eval=10)

Epoch 200:  20%|██        | 200/1000 [00:08<00:35, 22.34it/s, loss=0.0022] 


Val loss did not improve


Epoch 200:  20%|██        | 200/1000 [00:08<00:35, 22.25it/s, loss=0.00223]


Val loss did not improve


Epoch 300:  30%|███       | 300/1000 [00:13<00:31, 22.30it/s, loss=0.000586]


Val loss did not improve


Epoch 200:  20%|██        | 200/1000 [00:08<00:35, 22.33it/s, loss=0.00214]


Val loss did not improve


Epoch 300:  30%|███       | 300/1000 [00:13<00:31, 22.42it/s, loss=0.000607]


Val loss did not improve


Epoch 200:  20%|██        | 200/1000 [00:08<00:35, 22.29it/s, loss=0.00203]


Val loss did not improve


Epoch 300:  30%|███       | 300/1000 [00:13<00:31, 22.31it/s, loss=0.00065] 


Val loss did not improve


Epoch 300:  30%|███       | 300/1000 [00:13<00:31, 22.20it/s, loss=0.000593]


Val loss did not improve


Epoch 200:  20%|██        | 200/1000 [00:08<00:35, 22.38it/s, loss=0.00235]


Val loss did not improve


Epoch 200:  20%|██        | 200/1000 [00:08<00:35, 22.43it/s, loss=0.00226]


Val loss did not improve
Accuracy:0.772 ± 0.018
Precision:0.755 ± 0.023
Recall:0.796 ± 0.019
F1-score:0.775 ± 0.016


## Baseline 5 (Pre-Training with Concepts $x \rightarrow c_1, x \rightarrow y$)

In [16]:
results_5 = baseline_5(train_dataset, val_dataset, test_dataset, transform_dim=transform_dim, batch_size=batch_size, num_eval=10)

Epoch 200:  20%|██        | 200/1000 [00:09<00:36, 22.00it/s, loss=1.39e-10]


Val loss did not improve


Epoch 200:  20%|██        | 200/1000 [00:09<00:37, 21.43it/s, loss=0.000189]


Val loss did not improve


Epoch 200:  20%|██        | 200/1000 [00:09<00:36, 21.71it/s, loss=2.49e-10]


Val loss did not improve


Epoch 200:  20%|██        | 200/1000 [00:09<00:36, 21.63it/s, loss=0.000215]


Val loss did not improve


Epoch 200:  20%|██        | 200/1000 [00:09<00:36, 21.85it/s, loss=4.07e-11]


Val loss did not improve


Epoch 200:  20%|██        | 200/1000 [00:09<00:37, 21.49it/s, loss=0.000185]


Val loss did not improve


Epoch 300:  30%|███       | 300/1000 [00:13<00:31, 21.97it/s, loss=1.72e-10]


Val loss did not improve


Epoch 200:  20%|██        | 200/1000 [00:09<00:37, 21.62it/s, loss=0.000169]


Val loss did not improve


Epoch 200:  20%|██        | 200/1000 [00:09<00:36, 21.93it/s, loss=6.83e-11]


Val loss did not improve


Epoch 200:  20%|██        | 200/1000 [00:09<00:37, 21.31it/s, loss=0.00022] 


Val loss did not improve


Epoch 200:  20%|██        | 200/1000 [00:09<00:36, 21.87it/s, loss=6.21e-11]


Val loss did not improve


Epoch 200:  20%|██        | 200/1000 [00:09<00:36, 21.72it/s, loss=0.000158]


Val loss did not improve


Epoch 300:  30%|███       | 300/1000 [00:13<00:32, 21.82it/s, loss=1.11e-10]


Val loss did not improve


Epoch 200:  20%|██        | 200/1000 [00:09<00:37, 21.48it/s, loss=0.000257]


Val loss did not improve


Epoch 200:  20%|██        | 200/1000 [00:09<00:36, 21.92it/s, loss=1.68e-10]


Val loss did not improve


Epoch 100:  10%|█         | 100/1000 [00:04<00:42, 21.21it/s, loss=0.00121]


Val loss did not improve


Epoch 200:  20%|██        | 200/1000 [00:09<00:36, 21.84it/s, loss=6.76e-11]


Val loss did not improve


Epoch 200:  20%|██        | 200/1000 [00:09<00:36, 21.78it/s, loss=0.000246]


Val loss did not improve


Epoch 400:  40%|████      | 400/1000 [00:18<00:27, 21.85it/s, loss=1.24e-10]


Val loss did not improve


Epoch 200:  20%|██        | 200/1000 [00:09<00:37, 21.47it/s, loss=0.000203]


Val loss did not improve
Accuracy:0.772 ± 0.011
Precision:0.755 ± 0.013
Recall:0.795 ± 0.018
F1-score:0.775 ± 0.011


## Hard conditioning

In [65]:
!pip install kmeans-pytorch

Collecting kmeans-pytorch
  Downloading kmeans_pytorch-0.3-py3-none-any.whl (4.4 kB)
Installing collected packages: kmeans-pytorch
Successfully installed kmeans-pytorch-0.3


In [96]:
from kmeans_pytorch import kmeans, kmeans_predict

In [116]:
# class ConceptCLSUP_Hard_Cond(nn.Module):
#     def __init__(self, x_dim, hidden_dim, embed_dim, bin_dim, layers=2, activation='relu', lr=1e-4):
#         super(ConceptCLSUP_Hard_Cond, self).__init__()
#         self.critic_hidden_dim = 512
#         self.critic_layers = 1
#         self.critic_activation = 'relu'
#         self.lr = lr

#         # encoders
#         self.backbone = mlp(x_dim, hidden_dim, embed_dim, layers, activation)
#         self.linears_infonce = mlp(embed_dim, embed_dim, embed_dim, 1, activation) 

#         # critics
#         concept_dim = 1
#         self.club_critic = CLUBInfoNCECritic(embed_dim + x_dim, concept_dim, self.critic_hidden_dim, self.critic_layers, self.critic_activation)

#     def forward(self, x, c, sliced_x):
#         # compute embedding
#         z = self.linears_infonce(self.backbone(x))
#         # compute critic scores
#         # print(f'size of now {x[:,curr_idx:curr_idx+bin_dim].shape}, actual size {embed_dim + bin_dim}')
#         # print(f'size of z {z.shape}')
#         club_infonce_score = self.club_critic(torch.cat([z, sliced_x], dim=-1), c) #
#         return club_infonce_score

#     def get_embedding(self, x):
#         return self.backbone(x)
    
#     def get_backbone(self):
#         return self.backbone
    

def train_concept_informed_HC(concept_encoder, model, train_loader,val_loader, num_epochs, cluster_centers, device, lr, log_interval,
                          save_interval, save_path):

    concept_encoder.eval()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    best_val_err = torch.tensor(1e7)
    tepoch = tqdm(range(num_epochs))
    for epoch in tepoch:
        tepoch.set_description(f"Epoch {epoch}")
        model.train()

        for batch_idx, (data, concept, target,_) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            cluster_ids = kmeans_predict(data, cluster_centers, distance='euclidean', device=device)
            current_data = data[cluster_ids == batch_idx]
            c, z_c = concept_encoder(current_data)
            loss = model(current_data, c, target[cluster_ids == batch_idx])
            
            # bounds = torch.linspace(torch.min(data).item(), torch.max(data).item(), 6).to(device)
            # ind_list = torch.bucketize(data, bounds) 
            # sliced_x = data.clone()
            # sliced_x [ind_list != batch_idx] = 0
            # c, z_c = concept_encoder(data)
            # loss = model(sliced_x, c, target) #z_c
             
            loss.backward()
            optimizer.step()

        tepoch.set_postfix(loss=loss.item())
        if epoch % save_interval == 0:
            val_err = 0
            model.eval()
            with torch.no_grad():
                for batch_idx, (data, concept, target,_) in enumerate(val_loader):
                    data, target = data.to(device), target.to(device)
                    cluster_ids = kmeans_predict(data, cluster_centers, distance='euclidean', device=device)
                    current_data = data[cluster_ids == batch_idx]
                    c , z_c = concept_encoder(current_data)
                    output = model(current_data, c, target[cluster_ids == batch_idx])
                    
                    # bounds = torch.linspace(torch.min(data).item(), torch.max(data).item(), 6).to(device)
                    # ind_list = torch.bucketize(data, bounds) 
                    # sliced_x = data.clone()
                    # sliced_x [ind_list != batch_idx] = 0
                    # c , z_c = concept_encoder(data)
                    # output = model(sliced_x, c, target) #z_c
                    val_err += output
                val_err = val_err / len(val_loader)
            if val_err < best_val_err:
                best_val_err = val_err

            else:
                print('Val loss did not improve')
                torch.save(model.state_dict(), os.path.join(save_path, 'concept_informed_model.pth'))
                return model
    return model

In [121]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
hidden_dim = 512
embed_dim = 256
concept_encoder = ConceptEncoder(transform_dim, embed_dim, 1, hidden_dim, layers=1).to(device)
model = ConceptCLSUP_full_concept(transform_dim, embed_dim, 2, hidden_dim, embed_dim).to(device)

In [122]:
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, batch_size=batch_size)
val_loader = DataLoader(val_dataset, shuffle=True, batch_size=batch_size, drop_last=True)
test_loader = DataLoader(test_dataset, shuffle=False, drop_last=False)
cluster_ids_x, cluster_centers = kmeans(
    X=torch.stack([sample[0] for sample in  train_dataset]), num_clusters=6, distance='euclidean', device=torch.device('cuda:0')
)
# train concpet encoder 
trained_concept_encoder = train_concept_encoder(concept_encoder, train_loader,val_loader, 1000, device, 1e-5, 1e-5, 25, 3, '../trained_models')
# train concept informed model
trained_concept_informed_HC_model = train_concept_informed_HC(trained_concept_encoder, model, train_loader, val_loader, 1000, cluster_centers, device, 1e-5, 25, 3, '../trained_models')

running k-means on cuda:0..


[running kmeans]: 3it [00:00,  7.88it/s, center_shift=0.000000, iteration=3, tol=0.000100]
Epoch 12:   1%|          | 12/1000 [00:00<00:52, 18.89it/s, loss=0.00106]


Val loss did not improve


Epoch 0:   0%|          | 0/1000 [00:00<?, ?it/s]

predicting on cuda..
predicting on cuda..
predicting on cuda..
predicting on cuda..


Epoch 0:   0%|          | 0/1000 [00:00<?, ?it/s, loss=-7.38e-6]

predicting on cuda..


Epoch 1:   0%|          | 1/1000 [00:00<03:42,  4.48it/s, loss=-7.38e-6]

predicting on cuda..
predicting on cuda..
predicting on cuda..
predicting on cuda..
predicting on cuda..
predicting on cuda..


Epoch 2:   0%|          | 2/1000 [00:00<03:35,  4.63it/s, loss=-5.16e-5]

predicting on cuda..
predicting on cuda..
predicting on cuda..
predicting on cuda..
predicting on cuda..


Epoch 3:   0%|          | 3/1000 [00:00<03:21,  4.95it/s, loss=-5.32e-8]

predicting on cuda..


Epoch 3:   0%|          | 3/1000 [00:00<03:21,  4.95it/s, loss=7.04e-7] 

predicting on cuda..
predicting on cuda..
predicting on cuda..
predicting on cuda..
predicting on cuda..
Val loss did not improve


Epoch 3:   0%|          | 3/1000 [00:00<05:03,  3.29it/s, loss=7.04e-7]


In [123]:
train_embeds_1 = trained_concept_encoder.get_embedding(torch.stack([sample[0] for sample in  train_dataset]).to(device)).detach().cpu().numpy()
train_embeds_2 = trained_concept_informed_HC_model.get_embedding(torch.stack([sample[0] for sample in  train_dataset]).to(device)).detach().cpu().numpy()
train_embeds = np.concatenate((train_embeds_1, train_embeds_2), axis=1) #train_embeds_2 #
train_labels = np.array([sample[2].item() for sample in  train_dataset])

test_embeds_1 = trained_concept_encoder.get_embedding(torch.stack([sample[0] for sample in  test_dataset]).to(device)).detach().cpu().numpy()
test_embeds_2 = trained_concept_informed_HC_model.get_embedding(torch.stack([sample[0] for sample in  test_dataset]).to(device)).detach().cpu().numpy()
test_embeds = np.concatenate((test_embeds_1, test_embeds_2), axis=1) #test_embeds_2 #
test_labels = np.array([sample[2].item() for sample in  test_dataset])

clf = LogisticRegression(max_iter=1000).fit(train_embeds, train_labels)
predictions = clf.predict(test_embeds)

accuracy = accuracy_score(test_labels, predictions)
precision = precision_score(test_labels, predictions)
recall = recall_score(test_labels, predictions)
f1 = f1_score(test_labels, predictions)

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1-score:", f1)

Accuracy: 0.6066666666666667
Precision: 0.5666666666666667
Recall: 0.9066666666666666
F1-score: 0.6974358974358975


In [124]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
hidden_dim = 512
embed_dim = 256
concept_encoder = ConceptEncoder(transform_dim, embed_dim, 1, hidden_dim, layers=1).to(device)
model = ConceptCLSUP_full_concept(transform_dim, embed_dim, 2, hidden_dim, embed_dim).to(device)

train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, batch_size=batch_size)
val_loader = DataLoader(val_dataset, shuffle=True, batch_size=batch_size, drop_last=True)
test_loader = DataLoader(test_dataset, shuffle=False, drop_last=False)
# train concpet encoder 
trained_concept_encoder = train_concept_encoder(concept_encoder, train_loader,val_loader, 1000, device, 1e-5, 1e-5, 25, 3, '../trained_models')
# train concept informed model
trained_concept_informed_NC_model = train_concept_informed_model(trained_concept_encoder, model, train_loader, val_loader, 1000, device, 1e-5, 25, 3, '../trained_models')

train_embeds_1 = trained_concept_encoder.get_embedding(torch.stack([sample[0] for sample in  train_dataset]).to(device)).detach().cpu().numpy()
train_embeds_2 = trained_concept_informed_NC_model.get_embedding(torch.stack([sample[0] for sample in  train_dataset]).to(device)).detach().cpu().numpy()
train_embeds = np.concatenate((train_embeds_1, train_embeds_2), axis=1) #train_embeds_2 #
train_labels = np.array([sample[2].item() for sample in  train_dataset])

test_embeds_1 = trained_concept_encoder.get_embedding(torch.stack([sample[0] for sample in  test_dataset]).to(device)).detach().cpu().numpy()
test_embeds_2 = trained_concept_informed_NC_model.get_embedding(torch.stack([sample[0] for sample in  test_dataset]).to(device)).detach().cpu().numpy()
test_embeds = np.concatenate((test_embeds_1, test_embeds_2), axis=1) #test_embeds_2 #
test_labels = np.array([sample[2].item() for sample in  test_dataset])

clf = LogisticRegression(max_iter=1000).fit(train_embeds, train_labels)
predictions = clf.predict(test_embeds)

accuracy = accuracy_score(test_labels, predictions)
precision = precision_score(test_labels, predictions)
recall = recall_score(test_labels, predictions)
f1 = f1_score(test_labels, predictions)

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1-score:", f1)

Epoch 21:   2%|▏         | 21/1000 [00:01<00:50, 19.31it/s, loss=0.000469]


Val loss did not improve


Epoch 6:   1%|          | 6/1000 [00:01<05:28,  3.03it/s, loss=-.000234]


Val loss did not improve
Accuracy: 0.6666666666666666
Precision: 0.6106194690265486
Recall: 0.92
F1-score: 0.7340425531914894
