In [21]:
import sys
sys.path.append('../')

from dataset import*
from synthetic_concept_model import *
from synthetic_coop_model import *
from baselines import *
from torch.utils.data import DataLoader
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import pickle
import torch
import numpy
import random

random.seed(7)
numpy.random.seed(seed=7)
torch.manual_seed(7)

<torch._C.Generator at 0x7fa6d91dd830>

In [22]:
# !nvidia-smi
torch.cuda.is_available()

True

In [23]:
print('__CUDNN VERSION:', torch.backends.cudnn.version())
print('__Number CUDA Devices:', torch.cuda.device_count())
print('__CUDA Device Name:',torch.cuda.get_device_name(0))
print('__CUDA Device Total Memory [GB]:',torch.cuda.get_device_properties(0).total_memory/1e9)

__CUDNN VERSION: 8902
__Number CUDA Devices: 1
__CUDA Device Name: Tesla V100-SXM2-32GB
__CUDA Device Total Memory [GB]: 34.079637504


## Data generation

In [24]:
feature_dim_info = dict()
label_dim_info = dict()
transform_dim = 100000 #100000

intersections = get_intersections(num_modalities=2)

feature_dim_info['12'] = 10
feature_dim_info['1'] = 6
feature_dim_info['2'] = 6

label_dim_info['12'] = 10
label_dim_info['1'] = 6
label_dim_info['2'] = 6
num_concepts = 1
transforms_2concept = None
transforms_2hd = None
num_data = 10000
noise=0.3
pos_prob=0.5
# total_data, total_labels, total_concepts, total_raw_features = generate_data_concepts(num_data, num_concepts,
#                                                                                       feature_dim_info,
#                                                                                       label_dim_info,
#                                                                                       transform_dim=transform_dim,
#                                                                                      noise=noise,
#                                                                                      pos_prob=pos_prob)

# synth_data_dict = {'total_data':total_data, 'total_labels':total_labels, 'total_concepts':total_concepts, 'total_raw_features':total_raw_features}
# synth_data_file_name = '../synth_data/'+'synth_data_exp2_'+str(noise)+'_'+str(pos_prob)+'.pkl'
# pickle.dump(synth_data_dict, open(synth_data_file_name, 'wb'))

In [5]:
# Data splitting & loading
synth_data_dict = pickle.load(open('../synth_data/'+'synth_data_exp2_'+str(noise)+'_'+str(pos_prob)+'.pkl', 'rb'))
total_data = synth_data_dict['total_data']
total_labels = synth_data_dict['total_labels']
total_concepts = synth_data_dict['total_concepts']
total_raw_features = synth_data_dict['total_raw_features']

dataset = MultiConcept(total_data, total_labels, total_concepts, 0)
batch_size = 100
pretrain_dataset, finetune_dataset = torch.utils.data.random_split(dataset,  
                                                            [int(0.8 * num_data), num_data - int(0.8 * num_data)])
train_dataset, val_dataset = torch.utils.data.random_split(pretrain_dataset,
                                                           [int(0.8 * len(pretrain_dataset)), len(pretrain_dataset) - int(0.8 * len(pretrain_dataset))])

train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True,
                          batch_size=batch_size)
val_loader = DataLoader(val_dataset, shuffle=True, batch_size=batch_size, drop_last=True)

In [25]:
len(pretrain_dataset), len(finetune_dataset)

(8000, 2000)

In [6]:
# transform_dim = 101024

## Experiment 2

## Scenario 1
Pretrain on a large dataset $D_L:=\{x_i,c_i\}$ by minimizing InfoNCE_CLUB between $(Z_{\bar{c}},c|x)$ and fine-tune on small dataset $D_S:=\{x_j,c_j,y_j\}$ with supervised learning $x \rightarrow Z_{\bar{c}}, c \rightarrow y$

In [7]:
class ConceptCLSUP_Pretrain(nn.Module):
    def __init__(self, x_dim, hidden_dim, embed_dim, layers=2, activation='relu', lr=1e-4):
        super(ConceptCLSUP_Pretrain, self).__init__()
        self.critic_hidden_dim = 512
        self.critic_layers = 1
        self.critic_activation = 'relu'
        self.lr = lr

        # encoders
        self.backbone = mlp(x_dim, hidden_dim, embed_dim, layers, activation)
        self.linears_infonce = mlp(embed_dim, embed_dim, embed_dim, 1, activation) 

        # critics
        concept_dim = 1
        self.club_critic = CLUBInfoNCECritic(embed_dim + x_dim, concept_dim, self.critic_hidden_dim, self.critic_layers, self.critic_activation)

    def forward(self, x, c):
        # compute embedding
        z = self.linears_infonce(self.backbone(x))
        # compute critic scores
        club_infonce_score = self.club_critic(torch.cat([z, x], dim=-1), c)
        return club_infonce_score

    def get_embedding(self, x):
        return self.backbone(x)
    
    def get_backbone(self):
        return self.backbone
    

def train_concept_informed_Pretrain_model(concept_encoder, model, train_loader,val_loader, num_epochs, device, lr, log_interval,
                          save_interval, save_path):

    concept_encoder.eval()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    best_val_err = torch.tensor(1e7)
    tepoch = tqdm(range(num_epochs))
    for epoch in tepoch:
        tepoch.set_description(f"Epoch {epoch}")
        model.train()
        for batch_idx, (data, concept, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            c, z_c = concept_encoder(data)
            loss = model(data, c) #z_c
            loss.backward()
            optimizer.step()
        tepoch.set_postfix(loss=loss.item())
        if epoch % save_interval == 0:
            val_err = 0
            model.eval()
            with torch.no_grad():
                for batch_idx, (data, concept, target) in enumerate(val_loader):
                    data, target = data.to(device), target.to(device)
                    c , z_c = concept_encoder(data)
                    output = model(data, c) #z_c
                    val_err += output
                val_err = val_err / len(val_loader)
            if val_err < best_val_err:
                best_val_err = val_err

            else:
                print('Val loss did not improve')
                torch.save(model.state_dict(), os.path.join(save_path, 'concept_informed_model.pth'))
                return model
    return model


In [30]:
# models
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
hidden_dim = 512
embed_dim = 50 
concept_encoder = ConceptEncoder(transform_dim, embed_dim, 1, hidden_dim, layers=1).to(device)
model = ConceptCLSUP_Pretrain(transform_dim, hidden_dim, embed_dim).to(device)

In [31]:
# train concpet encoder 
trained_concept_encoder = train_concept_encoder(concept_encoder, train_loader,val_loader, 1000, device, 1e-5, 1e-5, 25, 3, '../trained_models')
# train concept informed model
trained_concept_informed_Pretrain_model = train_concept_informed_Pretrain_model(trained_concept_encoder, model, train_loader, val_loader, 1000, device, 1e-5, 25, 3, '../trained_models')
        

Epoch 42:   4%|▍         | 42/1000 [03:47<1:15:54,  4.75s/it, loss=1.37e-8]

Val loss did not improve


Epoch 42:   4%|▍         | 42/1000 [03:48<1:26:52,  5.44s/it, loss=1.37e-8]
Epoch 573:  57%|█████▋    | 573/1000 [3:07:07<2:17:31, 19.32s/it, loss=-6.39e+3]

Val loss did not improve


Epoch 573:  57%|█████▋    | 573/1000 [3:07:10<2:19:28, 19.60s/it, loss=-6.39e+3]


In [33]:
x

10

In [9]:
# train concpet encoder 
trained_concept_encoder = train_concept_encoder(concept_encoder, train_loader,val_loader, 1000, device, 1e-5, 1e-5, 25, 3, '../trained_models')
# train concept informed model
trained_concept_informed_Pretrain_model = train_concept_informed_Pretrain_model(trained_concept_encoder, model, train_loader, val_loader, 1000, device, 1e-5, 25, 3, '../trained_models')
        

Epoch 18:   2%|▏         | 18/1000 [01:33<1:17:12,  4.72s/it, loss=3.53e-7] 

Val loss did not improve


Epoch 18:   2%|▏         | 18/1000 [01:35<1:26:56,  5.31s/it, loss=3.53e-7]
Epoch 612:  61%|██████    | 612/1000 [3:19:51<2:04:56, 19.32s/it, loss=-8.38e+3]

Val loss did not improve


Epoch 612:  61%|██████    | 612/1000 [3:19:54<2:06:44, 19.60s/it, loss=-8.38e+3]


In [10]:
class ConceptCLSUP_Finetune_Sc1(nn.Module):
    def __init__(self, backbone, embed_dim, hidden_dim, **extra_kwargs):
        super(ConceptCLSUP_Finetune_Sc1, self).__init__()
        self.backbone = backbone
        # self.concept_encoder = mlp(1, hidden_dim, 128, 1, 'relu')
        self.fc = nn.Linear(embed_dim+1, 1)

    def forward(self, x, c):
        x = self.backbone(x)
        x_c = torch.cat((x, c), dim=1)
        out = self.fc(x_c)
        return out
def train_fine_tune_model_sc1(model, train_loader, val_loader, num_epochs, lr, weight_decay, device, log_interval,
                          save_interval):
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    best_val_err = torch.tensor(1e7)
    loss_func = nn.BCELoss() #nn.CrossEntropyLoss() #
    model.to(device)
    model.train()
    tepoch = tqdm(range(num_epochs))
    for epoch in tepoch:
        tepoch.set_description(f"Epoch {epoch}")
        for batch_idx, (data, concept, target) in enumerate(train_loader):
            data, concept, target = data.to(device), concept.to(device), target.to(device)
            optimizer.zero_grad()
            logits = model(data, concept)
            preds =  torch.sigmoid(logits) #torch.softmax(logits, dim=-1)
            loss = loss_func(preds, target.float()) #F.cross_entropy(preds, target)
            loss.backward()
            optimizer.step()
        tepoch.set_postfix(loss=loss.item())
        if epoch % save_interval == 0:
            val_err = 0
            model.eval()
            with torch.no_grad():
                for batch_idx, (data, concept, target) in enumerate(val_loader):
                    data, concept, target = data.to(device), concept.to(device), target.to(device)
                    logits = model(data, concept)
                    preds = torch.sigmoid(logits)
                    val_err += loss_func(preds, target.float())
                val_err = val_err / len(val_loader)
            # print('Val loss: {:.6f}'.format(val_err))
            if val_err < best_val_err:
                best_val_err = val_err
            else:
                print('Val loss did not improve')
                return model
    return model

In [11]:
import torch.nn as nn

#New test split
num_data = len(finetune_dataset)
batch_size = 100
new_trainval_dataset, new_test_dataset = torch.utils.data.random_split(finetune_dataset,  
                                                            [int(0.5 * num_data), num_data - int(0.5 * num_data)])
new_train_dataset, new_val_dataset = torch.utils.data.random_split(new_trainval_dataset,
                                                           [int(0.8 * len(new_trainval_dataset)), len(new_trainval_dataset) - int(0.8 * len(new_trainval_dataset))])

new_train_loader = DataLoader(new_train_dataset, shuffle=True, drop_last=True,
                          batch_size=batch_size)
new_val_loader = DataLoader(new_val_dataset, shuffle=True, batch_size=batch_size, drop_last=True)
new_test_loader = DataLoader(new_test_dataset, shuffle=False, drop_last=False)

#Train Final Model

backbone = trained_concept_informed_Pretrain_model.get_backbone()
# new_model = nn.Sequential(backbone, mlp(50, 256, 1, 1, activation= 'relu'))
new_model = ConceptCLSUP_Finetune_Sc1(backbone, embed_dim, hidden_dim)

In [12]:
from baselines import*

# final_model = mlp_train(new_model, new_train_loader, new_val_loader, 100, 1e-5, 1e-5,'cuda', 100, 100)
final_model = train_fine_tune_model_sc1(new_model, new_train_loader, new_val_loader, 1000, 1e-5, 1e-5,'cuda', 100, 100)

test_embeds = torch.stack([sample[0] for sample in  new_test_dataset])#.detach().cpu().numpy()
test_concepts = torch.tensor([sample[1].item() for sample in  new_test_dataset]).unsqueeze(1)
test_labels = np.array([sample[-1].item() for sample in  new_test_dataset])
    
out = final_model(test_embeds.to(device), test_concepts.to(device))
predictions = torch.sigmoid(out).round().detach().cpu().numpy()
        
accuracy = accuracy_score(test_labels, predictions)
precision = precision_score(test_labels, predictions)
recall = recall_score(test_labels, predictions)
f1 = f1_score(test_labels, predictions)

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1-score:", f1)

Epoch 300:  30%|███       | 300/1000 [02:31<05:54,  1.98it/s, loss=0.000155]


Val loss did not improve
Accuracy: 0.861
Precision: 0.8188539741219963
Recall: 0.9152892561983471
F1-score: 0.8643902439024391


## Scenario 2
Pretrain on a large dataset $D_L:=\{x_i,c_i\}$ by minimizing InfoNCE_CLUB between $(Z_{\bar{c}},c|x)$
Proceed to fine-tune on small dataset $D_S:=\{x_j,c_j, y_j\}$ with both supervised learning and InfoNCE_CLUB score 
$\mathop{\arg \min}\limits_{\theta , \phi} \mathcal{L}\bigl( y, f_{\theta}(Z_{\bar{c}},c)\bigr) + \lambda Info_{NCE\_CLUB} \bigl( Z_{\bar{c}},c|x \bigr)$


In [13]:
class ConceptCLSUP_Pretrain_Sc2(nn.Module):
    def __init__(self, x_dim, hidden_dim, embed_dim, layers=2, activation='relu', lr=1e-4, concept_dim=1):
        super(ConceptCLSUP_Pretrain_Sc2, self).__init__()
        self.critic_hidden_dim = 512
        self.critic_layers = 1
        self.critic_activation = 'relu'
        self.lr = lr

        # encoders
        self.backbone = mlp(x_dim, hidden_dim, embed_dim, layers, activation)
        self.linears_club = mlp(embed_dim, embed_dim, embed_dim, 1, activation) 

        # critics
        self.club_critic = CLUBInfoNCECritic(embed_dim + x_dim, concept_dim, self.critic_hidden_dim, self.critic_layers, self.critic_activation)
        

    def forward(self, x, c):
        # compute embedding
        z_comp = self.linears_club(self.backbone(x))
        # compute critic scores
        club_infonce_score = self.club_critic(torch.cat([z_comp, x], dim=-1), c)
        
        return  club_infonce_score #+ infonce_score

    def get_embedding(self, x):
        #what should the embedding be?
        return self.backbone(x)
    
    def get_backbone(self):
        return self.backbone #, self.linears_club

    
class ConceptCLSUP_Finetune_Sc2(nn.Module):
    def __init__(self, backbone, x_dim, hidden_dim, embed_dim, layers=2, activation='relu', lr=1e-4, concept_dim=1):
        super(ConceptCLSUP_Finetune_Sc2, self).__init__()
        self.critic_hidden_dim = 512
        self.critic_layers = 1
        self.critic_activation = 'relu'
        self.lr = lr
        self.backbone = backbone

        # encoders
        self.lable_encoder = mlp(embed_dim, embed_dim, embed_dim, 1, activation) 
        self.linears_club = mlp(embed_dim, embed_dim, embed_dim, 1, activation) 
        self.linears_label = mlp(embed_dim + embed_dim + 1, embed_dim, 1, 1, activation)

        # critics
        self.club_critic = CLUBInfoNCECritic(embed_dim + x_dim, concept_dim, self.critic_hidden_dim, self.critic_layers, self.critic_activation)

    def forward(self, x, c):
        # compute embedding
        z_comp = self.linears_club(self.backbone(x))
        z_y = self.lable_encoder(self.backbone(x))
        logit = self.linears_label(torch.cat([z_comp, z_y, c], dim=-1))
        
        # compute critic scores
        club_infonce_score = self.club_critic(torch.cat([z_comp, x], dim=-1), c)
        
        return  club_infonce_score, logit

    def get_logits(self, x, c):
        z_comp = self.linears_club(self.backbone(x))
        z_y = self.lable_encoder(self.backbone(x))
        logit = self.linear_label(torch.cat([z_comp, z_y, c], dim=-1))
        return logit

def train_fine_tune_model_sc2(model, train_loader, val_loader, num_epochs, lr, weight_decay, device, log_interval,
                          save_interval):
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    best_val_err = torch.tensor(1e7)
    loss_func = nn.BCELoss() #nn.CrossEntropyLoss() #
    model.to(device)
    model.train()
    tepoch = tqdm(range(num_epochs))
    for epoch in tepoch:
        tepoch.set_description(f"Epoch {epoch}")
        for batch_idx, (data, concept, target) in enumerate(train_loader):
            data, concept, target = data.to(device), concept.to(device), target.to(device)
            optimizer.zero_grad()
            score, logits = model(data, concept)
            preds =  torch.sigmoid(logits) #torch.softmax(logits, dim=-1)
            label_loss = loss_func(preds, target.float()) #F.cross_entropy(preds, target)
            ###########??????????????????????????????????????????? lamb
            lamb = 0.5
            loss = lamb *score + label_loss
            loss.backward()
            optimizer.step()
            torch.cuda.empty_cache()
        tepoch.set_postfix(loss=loss.item())
        if epoch % save_interval == 0:
            val_err = 0
            model.eval()
            with torch.no_grad():
                for batch_idx, (data, concept, target) in enumerate(val_loader):
                    data, concept, target = data.to(device), concept.to(device), target.to(device)
                    score, logits = model(data, concept)
                    preds = torch.sigmoid(logits)
                    output = score + loss_func(preds, target.float())
                    val_err += output
                val_err = val_err / len(val_loader)
            # print('Val loss: {:.6f}'.format(val_err))
            if val_err < best_val_err:
                best_val_err = val_err
            else:
                print('Val loss did not improve')
                return model
    return model

In [14]:
# models
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
hidden_dim = 512
embed_dim = 50 
concept_encoder = ConceptEncoder(transform_dim, embed_dim, 1, hidden_dim, layers=1).to(device)
model = ConceptCLSUP_Pretrain_Sc2(transform_dim, hidden_dim, embed_dim).to(device)

In [15]:
# train concpet encoder
trained_concept_encoder = train_concept_encoder(concept_encoder, train_loader,val_loader, 1000, device, 1e-5, 1e-5, 25, 3, '../trained_models')
# train concept informed model
trained_concept_informed_Pretrain_model = train_concept_informed_Pretrain_model(trained_concept_encoder, model, train_loader, val_loader, 1000, device, 1e-5, 25, 3, '../trained_models')
        

Epoch 21:   2%|▏         | 21/1000 [01:46<1:15:39,  4.64s/it, loss=9.91e-8] 

Val loss did not improve


Epoch 21:   2%|▏         | 21/1000 [01:47<1:23:45,  5.13s/it, loss=9.91e-8]
Epoch 672:  67%|██████▋   | 672/1000 [3:39:21<1:45:41, 19.33s/it, loss=-9.37e+3]

Val loss did not improve


Epoch 672:  67%|██████▋   | 672/1000 [3:39:24<1:47:05, 19.59s/it, loss=-9.37e+3]


In [18]:
import torch.nn as nn

#New test split
num_data = len(finetune_dataset)
batch_size = 100
new_trainval_dataset, new_test_dataset = torch.utils.data.random_split(finetune_dataset,  
                                                            [int(0.5 * num_data), num_data - int(0.5 * num_data)])
new_train_dataset, new_val_dataset = torch.utils.data.random_split(new_trainval_dataset,
                                                           [int(0.8 * len(new_trainval_dataset)), len(new_trainval_dataset) - int(0.8 * len(new_trainval_dataset))])

new_train_loader = DataLoader(new_train_dataset, shuffle=True, drop_last=True,
                          batch_size=batch_size)
new_val_loader = DataLoader(new_val_dataset, shuffle=True, batch_size=batch_size, drop_last=True)
new_test_loader = DataLoader(new_test_dataset, shuffle=False, drop_last=False)

In [19]:
#Train Final Model

trained_concept_informed_Pretrain_model.get_backbone()
final_model = ConceptCLSUP_Finetune_Sc2(trained_concept_informed_Pretrain_model.get_backbone(), transform_dim, hidden_dim, embed_dim).to(device)
final_model = train_fine_tune_model_sc2(final_model, new_train_loader, new_val_loader, 1000, 1e-5, 1e-5,'cuda', 100, 100)

test_embeds = torch.stack([sample[0] for sample in  new_test_dataset]).to(device)
test_concepts = torch.tensor([sample[1].item() for sample in  new_test_dataset]).unsqueeze(1).to(device)
test_labels = torch.tensor([sample[-1].item() for sample in  new_test_dataset]).to(device)

predictions = []
idx = 0
while idx+batch_size < len(test_embeds)+1:
    _, out = final_model(test_embeds[idx:idx+batch_size], test_concepts[idx:idx+batch_size])
    idx = idx+batch_size
    predictions.append(torch.sigmoid(out).round().detach().cpu().numpy())
        
predictions = np.reshape(np.array(predictions), test_labels.shape)
accuracy = accuracy_score(test_labels.detach().cpu().numpy(), predictions)
precision = precision_score(test_labels.detach().cpu().numpy(), predictions)
recall = recall_score(test_labels.detach().cpu().numpy(), predictions)
f1 = f1_score(test_labels.detach().cpu().numpy(), predictions)

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1-score:", f1)

Epoch 999: 100%|██████████| 1000/1000 [46:56<00:00,  2.82s/it, loss=-75.8]


Accuracy: 0.805
Precision: 0.8068669527896996
Recall: 0.7817047817047817
F1-score: 0.7940865892291448


## Baseline (Pre-Training with Concepts $x \rightarrow c_1, x \rightarrow y$)

In [20]:
# train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True,
#                           batch_size=batch_size)
# val_loader = DataLoader(val_dataset, shuffle=True, batch_size=batch_size, drop_last=True)
# test_loader = DataLoader(test_dataset, shuffle=False, drop_last=False)

backbone = mlp(transform_dim, 512, 1, layers=3, activation='relu')
trained_backbone = mlp_train_c(backbone, train_loader, val_loader, 1000, 1e-5, 1e-5,'cuda', 100, 100)
FC = mlp(1, 256, 1, 1, activation= 'relu')
model = nn.Sequential(trained_backbone, FC)
trained_model = mlp_train(model, new_train_loader, new_val_loader, 1000, 1e-5, 1e-5,'cuda', 100, 100)

test_embeds = torch.stack([sample[0] for sample in  new_test_dataset]).detach().cpu().numpy()
out = trained_model(torch.tensor(test_embeds).to(device))
predictions = torch.sigmoid(out).round().detach().cpu().numpy()
        
accuracy = accuracy_score(test_labels.detach().cpu().numpy(), predictions)
precision = precision_score(test_labels.detach().cpu().numpy(), predictions)
recall = recall_score(test_labels.detach().cpu().numpy(), predictions)
f1 = f1_score(test_labels.detach().cpu().numpy(), predictions)

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1-score:", f1)

Epoch 200:  20%|██        | 200/1000 [15:56<1:03:45,  4.78s/it, loss=0.000126]


Val loss did not improve


Epoch 200:  20%|██        | 200/1000 [01:47<07:09,  1.86it/s, loss=2.59e-6]


Val loss did not improve
Accuracy: 0.876
Precision: 0.8872017353579176
Recall: 0.8503118503118503
F1-score: 0.8683651804670913
