In [2]:

import itertools
import numpy as np
from sklearn.metrics import accuracy_score,roc_auc_score,recall_score,f1_score


import scipy.sparse as sp
import sys
# insert at 1, 0 is the script path (or '' in REPL)
sys.path.insert(1, '../')

In [3]:
import time
import argparse
import numpy as np
import dgl
from utils import feature_norm
from sklearn.metrics import accuracy_score,roc_auc_score,recall_score,f1_score
import pickle as pk
import pdb

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from utils import load_data, accuracy,load_pokec
from models.FairGNN import FairGNN, FairGNN1
from models.VarFairGNN import VarFairGNN

# Training settings
parser = argparse.ArgumentParser()
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='Disables CUDA training.')
parser.add_argument('--fastmode', action='store_true', default=False,
                    help='Validate during training pass.')
parser.add_argument('--seed', type=int, default=42, help='Random seed.')
parser.add_argument('--epochs', type=int, default=2000,
                    help='Number of epochs to train.')
parser.add_argument('--lr', type=float, default=0.001,
                    help='Initial learning rate.')
parser.add_argument('--weight_decay', type=float, default=1e-5,
                    help='Weight decay (L2 loss on parameters).')
parser.add_argument('--hidden', type=int, default=128,
                    help='Number of hidden units of the sensitive attribute estimator')
parser.add_argument('--dropout', type=float, default=.5,
                    help='Dropout rate (1 - keep probability).')
parser.add_argument('--alpha', type=float, default=4,
                    help='The hyperparameter of alpha')
parser.add_argument('--beta', type=float, default=0.01,
                    help='The hyperparameter of beta')
parser.add_argument('--model', type=str, default="GAT",
                    help='the type of model GCN/GAT')
parser.add_argument('--dataset', type=str, default='nba',
                    choices=['pokec_z','pokec_n','nba'])
parser.add_argument('--num-hidden', type=int, default=64,
                    help='Number of hidden units of classifier.')
parser.add_argument("--num-heads", type=int, default=1,
                        help="number of hidden attention heads")
parser.add_argument("--num-out-heads", type=int, default=1,
                    help="number of output attention heads")
parser.add_argument("--num-layers", type=int, default=1,
                    help="number of hidden layers")
parser.add_argument("--residual", action="store_true", default=False,
                    help="use residual connection")
parser.add_argument("--attn-drop", type=float, default=.0,
                    help="attention dropout")
parser.add_argument('--negative-slope', type=float, default=0.2,
                    help="the negative slope of leaky relu")
parser.add_argument('--acc', type=float, default=0.688,
                    help='the selected FairGNN accuracy on val would be at least this high')
parser.add_argument('--roc', type=float, default=0.745,
                    help='the selected FairGNN ROC score on val would be at least this high')
parser.add_argument('--sens_number', type=int, default=200,
                    help="the number of sensitive attributes")
parser.add_argument('--label_number', type=int, default=500,
                    help="the number of labels")


def fair_metric(output,idx):
    val_y = labels[idx].cpu().numpy()
    idx_s0 = sens.cpu().numpy()[idx.cpu().numpy()]==0
    idx_s1 = sens.cpu().numpy()[idx.cpu().numpy()]==1

    idx_s0_y1 = np.bitwise_and(idx_s0,val_y==1)
    idx_s1_y1 = np.bitwise_and(idx_s1,val_y==1)

#     pdb.set_trace()
    pred_y = (output[idx].squeeze()>0).type_as(labels).cpu().numpy()
    parity = abs(sum(pred_y[idx_s0])/sum(idx_s0)-sum(pred_y[idx_s1])/sum(idx_s1))
    equality = abs(sum(pred_y[idx_s0_y1])/sum(idx_s0_y1)-sum(pred_y[idx_s1_y1])/sum(idx_s1_y1))
    # counterfair = sum(pred_y[])

    return parity,equality



args = parser.parse_known_args()[0]


args.cuda = not args.no_cuda and torch.cuda.is_available()
print(args)
#%%
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Load datprint(args.dataset)
print(args.dataset)

if args.dataset != 'nba':
    if args.dataset == 'pokec_z':
        dataset = 'region_job'
    else:
        dataset = 'region_job_2'
    sens_attr = "region"
    predict_attr = "I_am_working_in_field"
    label_number = args.label_number
    sens_number = args.sens_number
    seed = 20
    path="../../dataset/pokec/"
    test_idx=False
else:
    dataset = 'nba'
    sens_attr = "country"
    predict_attr = "SALARY"
    label_number = 100
    sens_number = 50
    seed = 20
    path = "../../dataset/NBA"
    test_idx = True
print(dataset)

adj, features, labels, idx_train, idx_val, idx_test, sens, idx_sens_train = load_pokec(dataset,
                                                                                    sens_attr,
                                                                                    predict_attr,
                                                                                    path=path,
                                                                                    label_number=label_number,
                                                                                    sens_number=sens_number,
                                                                                    seed=seed,test_idx=test_idx)
print(len(idx_test))

G = dgl.from_scipy(adj)
if dataset == 'nba':
    features = feature_norm(features)

labels[labels>1]=1
if sens_attr:
    sens[sens>0]=1
# Model and optimizer

args.data = features

Namespace(acc=0.688, alpha=4, attn_drop=0.0, beta=0.01, cuda=True, dataset='nba', dropout=0.5, epochs=2000, fastmode=False, hidden=128, label_number=500, lr=0.001, model='GAT', negative_slope=0.2, no_cuda=False, num_heads=1, num_hidden=64, num_layers=1, num_out_heads=1, residual=False, roc=0.745, seed=42, sens_number=200, weight_decay=1e-05)
nba
nba
Loading nba dataset from ../../dataset/NBA
213


In [4]:
def compute_loss(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score])
    labels = torch.cat([torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]).cuda()
    return F.binary_cross_entropy_with_logits(scores, labels)

def compute_auc(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score]).cpu().numpy()
    labels = torch.cat(
        [torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]).cpu().numpy()
    return roc_auc_score(labels, scores)

class MLPPredictor(nn.Module):
    def __init__(self, h_feats):
        super(MLPPredictor, self).__init__()
        self.W1 = nn.Linear(h_feats * 2, h_feats)
        self.W2 = nn.Linear(h_feats, 1)

    def forward(self, edges1,edges2): # N,h (N,h)
        h = torch.cat([edges1, edges2], 1)
        out=self.W2(F.relu(self.W1(h))).squeeze(1)
        return out
        

In [5]:
def generate_similarity_graph(args,G,features,labels,idx_train,idx_val,idx_test, sens,idx_sens_train, noisy_estimate_factor=1.0,test_split=0.1):

    
    model = FairGNN1(nfeat = features.shape[1], args = args)
    # model = VarFairGNN(nfeat = features.shape[1], args = args)
    # model.estimator.load_state_dict(torch.load("./checkpoint/GCN_sens_{}_ns_{}".format(dataset,sens_number)))
    if args.cuda:
        G = G.to(torch.device('cuda:0'))
        model.cuda()
        features = features.cuda()
        labels = labels.cuda()
        idx_train = idx_train.cuda()
        idx_val = idx_val.cuda()
        idx_test = idx_test.cuda()
        sens = sens.cuda()
        idx_sens_train = idx_sens_train.cuda()

        
  
    # Train model
    t_total = time.time()
    best_result = {}
    best_fair = 100

    acc_test_list = []
    roc_test_list = []
    parity_list = []
    equality_list = []
    
    from sklearn.metrics import accuracy_score,roc_auc_score,recall_score,f1_score

    for epoch in range(args.epochs):
        t = time.time()
        model.train()
        model.optimize(G,features,labels,idx_train,sens,idx_sens_train)
        cov = model.cov
        cls_loss = model.cls_loss
        adv_loss = model.adv_loss
        model.eval()
        output,s,z = model(G, features)
        acc_val = accuracy(output[idx_val], labels[idx_val])
        roc_val = roc_auc_score(labels[idx_val].cpu().numpy(),output[idx_val].detach().cpu().numpy())


        acc_sens = accuracy(s[idx_test], sens[idx_test])

        parity_val, equality_val = fair_metric(output,idx_val)

        acc_test = accuracy(output[idx_test], labels[idx_test])
        roc_test = roc_auc_score(labels[idx_test].cpu().numpy(),output[idx_test].detach().cpu().numpy())
        parity,equality = fair_metric(output,idx_test)
        if acc_val > args.acc and roc_val > args.roc:
            acc_test_list.append(acc_test.item())
            roc_test_list.append(roc_test)
            parity_list.append(parity)
            equality_list.append(equality)    
            if best_fair > parity_val + equality_val :
                best_fair = parity_val + equality_val

                best_result['acc'] = acc_test.item()
                best_result['roc'] = roc_test
                best_result['parity'] = parity
                best_result['equality'] = equality

            print("=================================")

            print('Epoch: {:04d}'.format(epoch+1),
                'cov: {:.4f}'.format(cov.item()),
                'cls: {:.4f}'.format(cls_loss.item()),
                'adv: {:.4f}'.format(adv_loss.item()),
                'acc_val: {:.4f}'.format(acc_val.item()),
                "roc_val: {:.4f}".format(roc_val),
                "parity_val: {:.4f}".format(parity_val),
                "equality: {:.4f}".format(equality_val))
            print("Test:",
                    "accuracy: {:.4f}".format(acc_test.item()),
                    "roc: {:.4f}".format(roc_test),
                    "acc_sens: {:.4f}".format(acc_sens),
                    "parity: {:.4f}".format(parity),
                    "equality: {:.4f}".format(equality))



    print("Optimization Finished!")
    print("Total time elapsed: {:.4f}s".format(time.time() - t_total))

    print('============performace on test set=============')
    if len(best_result) > 0:
        print("Test:",
                "accuracy: {:.4f}".format(best_result['acc']),
                "roc: {:.4f}".format(best_result['roc']),
                "acc_sens: {:.4f}".format(acc_sens),
                "parity: {:.4f}".format(best_result['parity']),
                "equality: {:.4f}".format(best_result['equality']))
        print("Test:",
                "accuracy: {:.4f}".format(np.mean(acc_test_list)),
                "roc: {:.4f}".format(np.mean(roc_test_list)),
                "parity: {:.4f}".format(np.mean(parity_list)),
                "equality: {:.4f}".format(np.mean(equality_list)))
    else:
        print("Please set smaller acc/roc thresholds")
     
    g=G
    u, v = g.cpu().edges()

    eids = np.arange(g.number_of_edges())
    eids = np.random.permutation(eids)
    nedges=int(len(eids)*noisy_estimate_factor)
    test_size = int(nedges * test_split)
    train_size = nedges - test_size
    test_pos_u, test_pos_v = u[eids[:test_size]], v[eids[:test_size]]
    train_pos_u, train_pos_v = u[eids[test_size:nedges]], v[eids[test_size:nedges]]

    # Find all negative edges and split them for training and testing
    adj1 = sp.coo_matrix((np.ones(len(u)), (u.numpy(), v.numpy())))
    adj1_neg = 1 - adj1.todense() - np.eye(g.number_of_nodes())
    neg_u, neg_v = np.where(adj1_neg != 0)

    neg_eids = np.random.choice(len(neg_u), nedges// 2)
    test_neg_u, test_neg_v = neg_u[neg_eids[:test_size]], neg_v[neg_eids[:test_size]]
    train_neg_u, train_neg_v = neg_u[neg_eids[test_size:nedges ]], neg_v[neg_eids[test_size:nedges ]]
    
    train_pos_edges1=z[train_pos_u].cuda().detach()
    train_pos_edges2=z[train_pos_v].cuda().detach()
    test_pos_edges1=z[test_pos_u].cuda().detach()
    test_pos_edges2=z[test_pos_v].cuda().detach()
    train_neg_edges1=z[train_neg_u].cuda().detach()
    train_neg_edges2=z[train_neg_v].cuda().detach()
    test_neg_edges1=z[test_neg_u].cuda().detach()
    test_neg_edges2=z[test_neg_v].cuda().detach()
    
    pred1 = MLPPredictor(args.num_hidden)
    pred1.cuda()

        # ----------- 3. set up loss and optimizer -------------- #
    # in this case, loss will in training loop
    optimizer = torch.optim.Adam(pred1.parameters(), lr=0.002)

    # ----------- 4. training -------------------------------- #
    all_logits = []
    for e in range(100):
        # forward
        pred1.train()
        pos_score = pred1(train_pos_edges1, train_pos_edges2)
        neg_score = pred1(train_neg_edges1, train_neg_edges1)
        loss = compute_loss(pos_score, neg_score)
    #     print('In epoch {}, loss: {}'.format(e, loss.data))
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if e % 5 == 0:
            print('In epoch {}, loss: {}'.format(e, loss.data))

    #     ----------- 5. check results ------------------------ #
        from sklearn.metrics import roc_auc_score
        with torch.no_grad():
            pos_score = pred1(test_pos_edges1, test_pos_edges2)
            neg_score = pred1(test_neg_edges1, test_neg_edges1)
            print('AUC', compute_auc(pos_score, neg_score))


    pred1.eval()
    # ----------- 5. check results ------------------------ #
    from sklearn.metrics import roc_auc_score
    with torch.no_grad():
        pos_score = pred1(test_pos_edges1, test_pos_edges2)
        neg_score = pred1(test_neg_edges1, test_neg_edges1)
        print('AUC', compute_auc(pos_score, neg_score))
        
    #for evaluation just use
    #torch.special.expit(pos_score)
    return z,pred1,pos_score,neg_score

In [6]:
z,pred,pos_score,neg_score=generate_similarity_graph(args,G,features,labels,idx_train,idx_val,idx_test, sens,idx_sens_train, noisy_estimate_factor=1.0,test_split=0.1)

Epoch: 0822 cov: 0.0005 cls: 0.4683 adv: 0.6830 acc_val: 0.6948 roc_val: 0.7564 parity_val: 0.0554 equality: 0.1515
Test: accuracy: 0.6948 roc: 0.7564 acc_sens: 0.7230 parity: 0.0554 equality: 0.1515
Epoch: 0878 cov: 0.0002 cls: 0.4558 adv: 0.6832 acc_val: 0.6901 roc_val: 0.7464 parity_val: 0.0006 equality: 0.0676
Test: accuracy: 0.6901 roc: 0.7464 acc_sens: 0.7230 parity: 0.0006 equality: 0.0676
Epoch: 0892 cov: 0.0001 cls: 0.4168 adv: 0.6850 acc_val: 0.6901 roc_val: 0.7499 parity_val: 0.0251 equality: 0.1046
Test: accuracy: 0.6901 roc: 0.7499 acc_sens: 0.7230 parity: 0.0251 equality: 0.1046
Epoch: 0895 cov: 0.0024 cls: 0.4495 adv: 0.6804 acc_val: 0.6901 roc_val: 0.7566 parity_val: 0.0215 equality: 0.0917
Test: accuracy: 0.6901 roc: 0.7566 acc_sens: 0.7230 parity: 0.0215 equality: 0.0917
Epoch: 0896 cov: 0.0035 cls: 0.4783 adv: 0.6820 acc_val: 0.6995 roc_val: 0.7554 parity_val: 0.0006 equality: 0.0796
Test: accuracy: 0.6995 roc: 0.7554 acc_sens: 0.7230 parity: 0.0006 equality: 0.0796


Epoch: 1383 cov: 0.0010 cls: 0.5091 adv: 0.6822 acc_val: 0.6901 roc_val: 0.7470 parity_val: 0.0045 equality: 0.0202
Test: accuracy: 0.6901 roc: 0.7470 acc_sens: 0.7230 parity: 0.0045 equality: 0.0202
Optimization Finished!
Total time elapsed: 56.8426s
Test: accuracy: 0.6901 roc: 0.7470 acc_sens: 0.7230 parity: 0.0045 equality: 0.0202
Test: accuracy: 0.6943 roc: 0.7517 parity: 0.0179 equality: 0.0533
In epoch 0, loss: 0.7478482127189636
AUC 0.5676658204666514
AUC 0.6232696604835981
AUC 0.6076970319221269
AUC 0.6148915115774511
AUC 0.6322087870411813
In epoch 5, loss: 0.5891706347465515
AUC 0.6520317769175314
AUC 0.666793155346606
AUC 0.6769902641442389
AUC 0.6857401915395943
AUC 0.693596868262716
In epoch 10, loss: 0.5613959431648254
AUC 0.7023912126171498
AUC 0.7122648036599574
AUC 0.7232381415260984
AUC 0.7339753007540633
AUC 0.7455303299496723
In epoch 15, loss: 0.5288335084915161
AUC 0.7571526252131159
AUC 0.7694317789675449
AUC 0.7814993115371343
AUC 0.7926125201157573
AUC 0.805294

In [7]:
torch.special.expit(pos_score)

tensor([0.9998, 1.0000, 0.9995,  ..., 1.0000, 0.7308, 0.4055], device='cuda:0')

In [95]:
def generate_auxiliary_pred(feats,pred):

    def compute_similarity(pred, a, b, eps=1e-8):
        """
        eps for numerical stability
        """
        n=a.shape[0]
        sim_mat= torch.zeros([n,n], dtype=torch.float32).cuda()
        for i in range(n):
            x=a[i].repeat(n, 1)
#             print(x)
            sim_mat[i]= torch.special.expit(pred(x,b))
#             print(sim_mat[i])
        sim_mat=(sim_mat+torch.transpose(sim_mat, 0, 1))/2
        return sim_mat

    sim_mat = compute_similarity(pred,feats,feats)
#     print(sim_mat)
    thresh = torch.mean(sim_mat)  + 0.6824*torch.std(sim_mat)
    sim_mat = torch.where(sim_mat>thresh, sim_mat, torch.zeros_like(sim_mat))
    print(thresh)
#     print(sim_mat)
    adj = sp.coo_matrix(sim_mat.cpu().detach().numpy(), dtype=np.float32)
    G_aux = dgl.from_scipy(adj, eweight_name='sim')

    return G_aux,adj
sim_graph=None
sim_graph,adj=generate_auxiliary_pred(z,pred)
sim_graph.edges()[0].shape,sim_graph.edata['sim']

tensor(0.9999, device='cuda:0', grad_fn=<AddBackward0>)


(torch.Size([24650]),
 tensor([1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]))

In [96]:
404*404

163216

In [101]:
def generate_auxiliary1(feats):

    def compute_similarity(a, b, eps=1e-8):
        """
        eps for numerical stability
        """
        a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
        a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
        b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
        sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
        return sim_mt

    sim_mat = compute_similarity(feats,feats)
    thresh = torch.mean(sim_mat)  + 0.75*torch.std(sim_mat)
    sim_mat = torch.where(sim_mat>thresh, sim_mat, torch.zeros_like(sim_mat))
    adj = sp.coo_matrix(sim_mat.cpu().detach().numpy(), dtype=np.float32)
    G_aux = dgl.from_scipy(adj, eweight_name='sim')

    return G_aux

In [102]:
sim_graph=generate_auxiliary1(z)
sim_graph.edges()[0].shape,sim_graph.edata['sim']

(torch.Size([47093]),
 tensor([1.0000, 0.9773, 0.9534,  ..., 0.9907, 0.9735, 1.0000]))

In [105]:

args.model="GCN"
args.num_hidden=128
args.acc=0.69
args.roc=0.76
args.alpha=100
args.beta=1

In [106]:
# model = FairGNN(nfeat = features.shape[1], args = args)
model = VarFairGNN(nfeat = features.shape[1], args = args,aux_graph=sim_graph)
# model.estimator.load_state_dict(torch.load("./checkpoint/GCN_sens_{}_ns_{}".format(dataset,sens_number)))
if args.cuda:
    G = G.to(torch.device('cuda:0'))
    model.cuda()
    features = features.cuda()
    labels = labels.cuda()
    idx_train = idx_train.cuda()
    idx_val = idx_val.cuda()
    idx_test = idx_test.cuda()
    sens = sens.cuda()
    idx_sens_train = idx_sens_train.cuda()


# Train model
t_total = time.time()
best_result = {}
best_fair = 100

acc_test_list = []
roc_test_list = []
parity_list = []
equality_list = []
loss_list = {'cls': [], 'lip':[]}

for epoch in range(args.epochs):
    t = time.time()
    model.train()
    model.optimize(G,features,labels,idx_train,sens,idx_sens_train)
    cls_loss = model.cls_loss
    lip_loss = model.lip_loss
    loss_list['cls'].append(cls_loss.item())
    loss_list['lip'].append(lip_loss.item())
    model.eval()
    output = model(G, features)
    acc_val = accuracy(output[idx_val], labels[idx_val])
    roc_val = roc_auc_score(labels[idx_val].cpu().numpy(),output[idx_val].detach().cpu().numpy())


    parity_val, equality_val = fair_metric(output,idx_val)

    acc_test = accuracy(output[idx_test], labels[idx_test])
    roc_test = roc_auc_score(labels[idx_test].cpu().numpy(),output[idx_test].detach().cpu().numpy())
    parity,equality = fair_metric(output,idx_test)
    if acc_val > args.acc and roc_val > args.roc:
        acc_test_list.append(acc_test.item())
        roc_test_list.append(roc_test)
        parity_list.append(parity)
        equality_list.append(equality)
        if best_fair > parity_val + equality_val :
            best_fair = parity_val + equality_val

            best_result['acc'] = acc_test.item()
            best_result['roc'] = roc_test
            best_result['parity'] = parity
            best_result['equality'] = equality

        print("=================================")

        print('Epoch: {:04d}'.format(epoch+1),
            'cls: {:.4f}'.format(cls_loss.item()),
            'lip: {:.4f}'.format(lip_loss.item()),
            'acc_val: {:.4f}'.format(acc_val.item()),
            "roc_val: {:.4f}".format(roc_val),
            "parity_val: {:.4f}".format(parity_val),
            "equality: {:.4f}".format(equality_val))
        print("Test:",
                "accuracy: {:.4f}".format(acc_test.item()),
                "roc: {:.4f}".format(roc_test),
                "parity: {:.4f}".format(parity),
                "equality: {:.4f}".format(equality))

print("Optimization Finished!")
print("Total time elapsed: {:.4f}s".format(time.time() - t_total))

print('============performace on test set=============')
if len(best_result) > 0:
    print("Test:",
            "accuracy: {:.4f}".format(best_result['acc']),
            "roc: {:.4f}".format(best_result['roc']),
            "parity: {:.4f}".format(best_result['parity']),
            "equality: {:.4f}".format(best_result['equality']))
    print("Test:",
            "accuracy: {:.4f}".format(np.mean(acc_test_list)),
            "roc: {:.4f}".format(np.mean(roc_test_list)),             
            "parity: {:.4f}".format(np.mean(parity_list)),
            "equality: {:.4f}".format(np.mean(equality_list)))
else:
    print("Please set smaller acc/roc thresholds")

with open('loss.pk', 'wb') as handle:
    pk.dump(loss_list, handle, protocol=pk.HIGHEST_PROTOCOL)


Epoch: 0456 cls: 0.6716 lip: 64.3662 acc_val: 0.6948 roc_val: 0.7774 parity_val: 0.0611 equality: 0.0783
Test: accuracy: 0.6948 roc: 0.7774 parity: 0.0611 equality: 0.0783
Epoch: 0585 cls: 0.6706 lip: 71.4568 acc_val: 0.6948 roc_val: 0.7806 parity_val: 0.0351 equality: 0.0542
Test: accuracy: 0.6948 roc: 0.7806 parity: 0.0351 equality: 0.0542
Epoch: 0699 cls: 0.6672 lip: 74.3355 acc_val: 0.6948 roc_val: 0.7815 parity_val: 0.0715 equality: 0.1140
Test: accuracy: 0.6948 roc: 0.7815 parity: 0.0715 equality: 0.1140
Epoch: 0734 cls: 0.6731 lip: 44.9744 acc_val: 0.6948 roc_val: 0.7694 parity_val: 0.0975 equality: 0.1381
Test: accuracy: 0.6948 roc: 0.7694 parity: 0.0975 equality: 0.1381
Epoch: 0868 cls: 0.6698 lip: 59.3132 acc_val: 0.7089 roc_val: 0.7831 parity_val: 0.0597 equality: 0.0908
Test: accuracy: 0.7089 roc: 0.7831 parity: 0.0597 equality: 0.0908
Epoch: 0945 cls: 0.6679 lip: 68.2695 acc_val: 0.7042 roc_val: 0.7778 parity_val: 0.0506 equality: 0.1024
Test: accuracy: 0.7042 roc: 0.7778 

In [107]:
sim_graph=None

In [108]:
# model = FairGNN(nfeat = features.shape[1], args = args)
model = VarFairGNN(nfeat = features.shape[1], args = args,aux_graph=sim_graph)
# model.estimator.load_state_dict(torch.load("./checkpoint/GCN_sens_{}_ns_{}".format(dataset,sens_number)))
if args.cuda:
    G = G.to(torch.device('cuda:0'))
    model.cuda()
    features = features.cuda()
    labels = labels.cuda()
    idx_train = idx_train.cuda()
    idx_val = idx_val.cuda()
    idx_test = idx_test.cuda()
    sens = sens.cuda()
    idx_sens_train = idx_sens_train.cuda()


# Train model
t_total = time.time()
best_result = {}
best_fair = 100

acc_test_list = []
roc_test_list = []
parity_list = []
equality_list = []
loss_list = {'cls': [], 'lip':[]}

for epoch in range(args.epochs):
    t = time.time()
    model.train()
    model.optimize(G,features,labels,idx_train,sens,idx_sens_train)
    cls_loss = model.cls_loss
    lip_loss = model.lip_loss
    loss_list['cls'].append(cls_loss.item())
    loss_list['lip'].append(lip_loss.item())
    model.eval()
    output = model(G, features)
    acc_val = accuracy(output[idx_val], labels[idx_val])
    roc_val = roc_auc_score(labels[idx_val].cpu().numpy(),output[idx_val].detach().cpu().numpy())


    parity_val, equality_val = fair_metric(output,idx_val)

    acc_test = accuracy(output[idx_test], labels[idx_test])
    roc_test = roc_auc_score(labels[idx_test].cpu().numpy(),output[idx_test].detach().cpu().numpy())
    parity,equality = fair_metric(output,idx_test)
    if acc_val > args.acc and roc_val > args.roc:
        acc_test_list.append(acc_test.item())
        roc_test_list.append(roc_test)
        parity_list.append(parity)
        equality_list.append(equality)
        if best_fair > parity_val + equality_val :
            best_fair = parity_val + equality_val

            best_result['acc'] = acc_test.item()
            best_result['roc'] = roc_test
            best_result['parity'] = parity
            best_result['equality'] = equality

        print("=================================")

        print('Epoch: {:04d}'.format(epoch+1),
            'cls: {:.4f}'.format(cls_loss.item()),
            'lip: {:.4f}'.format(lip_loss.item()),
            'acc_val: {:.4f}'.format(acc_val.item()),
            "roc_val: {:.4f}".format(roc_val),
            "parity_val: {:.4f}".format(parity_val),
            "equality: {:.4f}".format(equality_val))
        print("Test:",
                "accuracy: {:.4f}".format(acc_test.item()),
                "roc: {:.4f}".format(roc_test),
                "parity: {:.4f}".format(parity),
                "equality: {:.4f}".format(equality))

print("Optimization Finished!")
print("Total time elapsed: {:.4f}s".format(time.time() - t_total))

print('============performace on test set=============')
if len(best_result) > 0:
    print("Test:",
            "accuracy: {:.4f}".format(best_result['acc']),
            "roc: {:.4f}".format(best_result['roc']),
            "parity: {:.4f}".format(best_result['parity']),
            "equality: {:.4f}".format(best_result['equality']))
    print("Test:",
            "accuracy: {:.4f}".format(np.mean(acc_test_list)),
            "roc: {:.4f}".format(np.mean(roc_test_list)),             
            "parity: {:.4f}".format(np.mean(parity_list)),
            "equality: {:.4f}".format(np.mean(equality_list)))
else:
    print("Please set smaller acc/roc thresholds")

with open('loss.pk', 'wb') as handle:
    pk.dump(loss_list, handle, protocol=pk.HIGHEST_PROTOCOL)


Epoch: 0534 cls: 0.6699 lip: 106.5875 acc_val: 0.7136 roc_val: 0.7639 parity_val: 0.0088 equality: 0.0671
Test: accuracy: 0.7136 roc: 0.7639 parity: 0.0088 equality: 0.0671
Epoch: 0551 cls: 0.6682 lip: 107.8012 acc_val: 0.6948 roc_val: 0.7676 parity_val: 0.0168 equality: 0.0417
Test: accuracy: 0.6948 roc: 0.7676 parity: 0.0168 equality: 0.0417
Epoch: 0677 cls: 0.6680 lip: 114.9420 acc_val: 0.7089 roc_val: 0.7673 parity_val: 0.0052 equality: 0.0293
Test: accuracy: 0.7089 roc: 0.7673 parity: 0.0052 equality: 0.0293
Epoch: 0747 cls: 0.6666 lip: 116.7011 acc_val: 0.7230 roc_val: 0.7662 parity_val: 0.0690 equality: 0.0783
Test: accuracy: 0.7230 roc: 0.7662 parity: 0.0690 equality: 0.0783
Epoch: 0781 cls: 0.6697 lip: 94.8912 acc_val: 0.6948 roc_val: 0.7641 parity_val: 0.0326 equality: 0.0301
Test: accuracy: 0.6948 roc: 0.7641 parity: 0.0326 equality: 0.0301
Epoch: 0815 cls: 0.6638 lip: 129.2169 acc_val: 0.6901 roc_val: 0.7647 parity_val: 0.0261 equality: 0.0181
Test: accuracy: 0.6901 roc: 0.

Epoch: 1812 cls: 0.6632 lip: 124.6200 acc_val: 0.6995 roc_val: 0.7604 parity_val: 0.0806 equality: 0.1024
Test: accuracy: 0.6995 roc: 0.7604 parity: 0.0806 equality: 0.1024
Epoch: 1904 cls: 0.6608 lip: 146.0544 acc_val: 0.7183 roc_val: 0.7681 parity_val: 0.0258 equality: 0.0671
Test: accuracy: 0.7183 roc: 0.7681 parity: 0.0258 equality: 0.0671
Epoch: 1908 cls: 0.6631 lip: 128.1092 acc_val: 0.7230 roc_val: 0.7734 parity_val: 0.0193 equality: 0.0194
Test: accuracy: 0.7230 roc: 0.7734 parity: 0.0193 equality: 0.0194
Epoch: 1975 cls: 0.6607 lip: 133.2095 acc_val: 0.7089 roc_val: 0.7635 parity_val: 0.0023 equality: 0.0551
Test: accuracy: 0.7089 roc: 0.7635 parity: 0.0023 equality: 0.0551
Epoch: 1983 cls: 0.6582 lip: 162.2287 acc_val: 0.7089 roc_val: 0.7688 parity_val: 0.0128 equality: 0.0430
Test: accuracy: 0.7089 roc: 0.7688 parity: 0.0128 equality: 0.0430
Optimization Finished!
Total time elapsed: 38.3942s
Test: accuracy: 0.7042 roc: 0.7678 parity: 0.0092 equality: 0.0052
Test: accuracy: 