# VFLの実装

In [1]:
import os
import random
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import OneHotEncoder, MinMaxScaler

from tqdm.notebook import tqdm
import time

In [2]:
def load_data():
    """Read the Adult train/test CSV files from the working directory.

    Returns:
        tuple: ``(tr, te, num_cols, cat_cols, lab_col)`` where ``tr`` and
        ``te`` are DataFrames with named columns, ``num_cols`` / ``cat_cols``
        list the numeric / categorical feature columns and ``lab_col`` is the
        name of the label column.
    """
    lab_col = 'over_50k'
    num_cols = ['age', 'fnlwgt', 'education_num', 'capital_gain', 'capital_loss', 'hours_per_week']
    cat_cols = ['workclass', 'education', 'marital_status', 'occupation', 'relationship', 'race', 'sex', 'native_country']

    # Column names in file order; neither file ships a usable header row.
    h = [
        'age', 'workclass', 'fnlwgt', 'education', 'education_num', 'marital_status', 'occupation', 'relationship', 
        'race', 'sex', 'capital_gain', 'capital_loss', 'hours_per_week', 'native_country', 'over_50k'
    ]

    tr = pd.read_csv('adult.data', header=None)
    # The first line of the test file is not a data row, so skip it.
    te = pd.read_csv('adult.test', header=None, skiprows=1)

    for frame in (tr, te):
        frame.columns = h

    # Test labels carry a trailing dot; keep only the text before it.
    te['over_50k'] = te['over_50k'].str.split('.', expand=True)[0]

    return tr, te, num_cols, cat_cols, lab_col

def create_encoders(dat, cat_cols, num_cols):
    """Fit one encoder per feature column and return them keyed by column name.

    Args:
        dat: DataFrame the encoders are fitted on.
        cat_cols: Categorical column names; each gets a ``OneHotEncoder``
            that ignores categories unseen at fit time.
        num_cols: Numeric column names; each gets a ``MinMaxScaler``.

    Returns:
        dict: column name -> fitted encoder/scaler.
    """
    fitted = {}

    # Categorical columns: one-hot encoding, fitted on the string form of the values.
    for col in tqdm(cat_cols):
        as_column = dat[col].astype(str).values.reshape(-1, 1)
        onehot = OneHotEncoder(handle_unknown='ignore')
        onehot.fit(as_column)
        fitted[col] = onehot

    # Numeric columns: min-max scaling.
    for col in tqdm(num_cols):
        as_column = dat[col].values.reshape(-1, 1)
        minmax = MinMaxScaler()
        minmax.fit(as_column)
        fitted[col] = minmax

    return fitted

def encode(dat, cat_cols, num_cols, encoders):
    """Apply fitted encoders to a DataFrame and return the encoded copy.

    Each categorical column is replaced by its one-hot columns, named
    ``<column>_<i>`` and appended after the remaining columns in ``cat_cols``
    order. Each numeric column is overwritten in place (same position) by its
    scaled values. The input frame ``dat`` is left unmodified.

    Args:
        dat: DataFrame to encode.
        cat_cols: Categorical column names to one-hot encode.
        num_cols: Numeric column names to scale.
        encoders: Mapping of column name -> fitted object exposing
            ``transform`` on a 2-D ``(n_rows, 1)`` array.

    Returns:
        pandas.DataFrame: A new, encoded frame.
    """
    onehot_frames = []
    for c in tqdm(cat_cols):
        out = encoders[c].transform(dat[c].astype(str).values.reshape(-1, 1))
        # Sparse results are densified with toarray(), which gives a plain
        # ndarray (todense() would give the legacy np.matrix type).
        if hasattr(out, 'toarray'):
            out = out.toarray()
        out = np.asarray(out)
        keys = [f'{c}_{i}' for i in range(out.shape[1])]
        onehot_frames.append(pd.DataFrame(out, columns=keys, index=dat.index))

    # drop() returns a new frame, so the assignments below never touch `dat`.
    encoded = dat.drop(columns=list(cat_cols))

    for c in tqdm(num_cols):
        out = encoders[c].transform(dat[c].values.reshape(-1, 1))
        encoded[c] = np.asarray(out).flatten()

    # A single concat instead of repeated multi-column inserts avoids the
    # DataFrame-fragmentation PerformanceWarning.
    return pd.concat([encoded, *onehot_frames], axis=1)

class AdultDataset(Dataset):
    """Map-style dataset pairing a feature row with its label.

    Args:
        x: pandas object whose ``.values`` holds the feature matrix.
        y: pandas object whose ``.values`` holds the labels.
    """

    def __init__(self, x, y):
        # In a real VFL deployment x would be held by the client and y by the server.
        features, labels = x.values, y.values
        self.x = torch.Tensor(features)
        self.y = torch.Tensor(labels)

    def __len__(self):
        """Return the number of samples (rows of the feature tensor)."""
        n_samples = len(self.x)
        return n_samples

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at position ``idx``."""
        sample_x = self.x[idx, :]
        sample_y = self.y[idx]
        return sample_x, sample_y

def preprocess(tr, te, cat_cols, num_cols, lab_col, batch_size=1024):
    """Encode features and labels and wrap both splits in DataLoaders.

    Encoders are fitted on the training split only and then applied to both
    splits. Labels are converted to 0/1 integers.

    Args:
        tr: Training DataFrame.
        te: Test DataFrame.
        cat_cols: Categorical feature column names.
        num_cols: Numeric feature column names.
        lab_col: Name of the label column.
        batch_size: Batch size for both loaders (default 1024).

    Returns:
        tuple: ``(tr_dl, te_dl, pos_weight)``. The training loader shuffles,
        the test loader does not. ``pos_weight`` is a one-element FloatTensor
        holding ``#negatives / #positives`` of the training labels.

    Raises:
        ValueError: If a non-numeric label column contains a value other than
            ``<=50K`` / ``>50K`` (surrounding whitespace and a trailing dot
            are ignored).
    """
    label_map = {'<=50K': 0, '>50K': 1}

    def _binary_labels(frame):
        """Return frame[lab_col] as a 0/1 integer Series."""
        raw = frame[lab_col]
        if pd.api.types.is_numeric_dtype(raw):
            # Already numeric: pass through unchanged.
            return raw
        # Explicit map instead of Series.replace: replace relies on pandas'
        # deprecated silent downcasting to turn the object column into ints.
        mapped = raw.astype(str).str.strip().str.rstrip('.').map(label_map)
        if mapped.isna().any():
            unknown = sorted(set(raw[mapped.isna()].astype(str)))
            raise ValueError(f'Unexpected label value(s) in {lab_col!r}: {unknown}')
        return mapped.astype('int64')

    encoders = create_encoders(tr, cat_cols, num_cols)
    tr = encode(tr, cat_cols, num_cols, encoders)
    te = encode(te, cat_cols, num_cols, encoders)

    tr_x = tr.drop(lab_col, axis=1)
    tr_y = _binary_labels(tr)
    te_x = te.drop(lab_col, axis=1)
    te_y = _binary_labels(te)

    tr_ds = AdultDataset(tr_x, tr_y)
    te_ds = AdultDataset(te_x, te_y)

    tr_dl = DataLoader(tr_ds, batch_size=batch_size, shuffle=True)
    te_dl = DataLoader(te_ds, batch_size=batch_size, shuffle=False)

    # Positive-class weight: ratio of negative to positive training samples.
    pos_weight = (tr_y.shape[0] - tr_y.sum()) / tr_y.sum()

    return tr_dl, te_dl, torch.FloatTensor([pos_weight])

# MLP

In [3]:
seed = 42
def set_seed(seed):
    """Seed the PyTorch (CPU and CUDA), NumPy and ``random`` generators.

    Args:
        seed: Integer seed passed to every generator.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    
class MLP(torch.nn.Module):
    """Three-layer perceptron: in -> hidden -> hidden // 2 -> out, with ReLU in between.

    Args:
        in_size: Number of input features.
        hidden_size: Width of the first hidden layer; the second is half of it.
        out_size: Number of outputs (raw values, no final activation).
    """

    def __init__(self, in_size, hidden_size, out_size):
        super().__init__()
        half_hidden = hidden_size // 2
        self.i2h = torch.nn.Linear(in_size, hidden_size)
        self.h2h = torch.nn.Linear(hidden_size, half_hidden)
        self.h2o = torch.nn.Linear(half_hidden, out_size)

        # Fix the initial values: the generators are re-seeded before every
        # layer, then weights get Xavier-uniform values and biases are set to one.
        for layer in (self.i2h, self.h2h, self.h2o):
            set_seed(seed)
            torch.nn.init.xavier_uniform_(layer.weight.data)
            torch.nn.init.ones_(layer.bias.data)

    def forward(self, x):
        """Return the network output for the input batch ``x``."""
        first_hidden = F.relu(self.i2h(x))
        second_hidden = F.relu(self.h2h(first_hidden))
        return self.h2o(second_hidden)
        
def train_MLP(tr_dl, te_dl, pos_weight, in_size=None, hidden_size=32, n_epochs=30, lr=0.01):
    """Train the single-party MLP baseline and report test loss and ROC-AUC.

    Prints the summed per-batch training loss after every epoch and, after
    training, the summed per-batch test loss and the test ROC-AUC.

    Args:
        tr_dl: DataLoader yielding ``(features, labels)`` training batches.
        te_dl: DataLoader yielding ``(features, labels)`` test batches.
        pos_weight: Tensor passed to ``BCEWithLogitsLoss(pos_weight=...)``.
        in_size: Number of input features. When ``None`` it is read from the
            first sample of ``tr_dl.dataset`` instead of being hard-coded, so
            the model follows the width produced by the one-hot encoding.
        hidden_size: Width of the first hidden layer (default 32).
        n_epochs: Number of training epochs (default 30).
        lr: Adam learning rate (default 0.01).
    """
    if in_size is None:
        # Indexing the dataset directly draws nothing from the random
        # generators, so the shuffling order of the loader is unaffected.
        in_size = tr_dl.dataset[0][0].shape[-1]

    mlp = MLP(in_size, hidden_size, 1)
    optimizer = torch.optim.Adam(mlp.parameters(), lr=lr)
    criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    for epoch in range(n_epochs):
        tr_loss = 0
        mlp.train()
        for batch_x, batch_y in tqdm(tr_dl):
            optimizer.zero_grad()
            pred_y = mlp(batch_x)
            loss = criterion(pred_y.flatten(), batch_y)
            loss.backward()
            optimizer.step()
            tr_loss += loss.item()
        print(f'Epoch: {epoch}, Training loss: {tr_loss:.4f}')

    te_loss = 0
    pred_y_list = []
    true_y_list = []
    mlp.eval()
    # No gradients are needed for evaluation; skip building the autograd graph.
    with torch.no_grad():
        for batch_x, batch_y in tqdm(te_dl):
            pred_y = mlp(batch_x)
            loss = criterion(pred_y.flatten(), batch_y)
            te_loss += loss.item()
            # The model outputs logits; apply sigmoid to get scores for ROC-AUC.
            pred_y_list.extend(torch.sigmoid(pred_y.flatten()).tolist())
            true_y_list.extend(batch_y.tolist())

    score = roc_auc_score(true_y_list, pred_y_list)
    print(f'Test loss: {te_loss:.4f}')
    print(f'Test ROC-AUC: {score:.4f}')

# VFL

In [4]:
class ClientModel(torch.nn.Module):
    """Client-side part of the split network: one linear layer followed by ReLU.

    Args:
        in_size: Number of input features.
        hidden_size: Size of the embedding produced by ``forward``.
    """

    def __init__(self, in_size, hidden_size):
        super().__init__()
        self.i2h = torch.nn.Linear(in_size, hidden_size)

        # Fix the initial values: re-seed, then Xavier-uniform weights and all-one biases.
        init = torch.nn.init
        set_seed(seed)
        init.xavier_uniform_(self.i2h.weight.data)
        init.ones_(self.i2h.bias.data)

    def forward(self, x):
        """Return the embedding ``relu(i2h(x))`` for the input batch ``x``."""
        return F.relu(self.i2h(x))
    
class ServerModel(torch.nn.Module):
    """Server-side part of the split network: hidden -> hidden // 2 -> out.

    Args:
        hidden_size: Size of the embedding received as input.
        out_size: Number of outputs (raw values, no final activation).
    """

    def __init__(self, hidden_size, out_size):
        super().__init__()
        half_hidden = hidden_size // 2
        self.h2h = torch.nn.Linear(hidden_size, half_hidden)
        self.h2o = torch.nn.Linear(half_hidden, out_size)

        # Fix the initial values: the generators are re-seeded before each
        # layer, then weights get Xavier-uniform values and biases are set to one.
        for layer in (self.h2h, self.h2o):
            set_seed(seed)
            torch.nn.init.xavier_uniform_(layer.weight.data)
            torch.nn.init.ones_(layer.bias.data)

    def forward(self, h):
        """Return the output computed from the embedding batch ``h``."""
        hidden = F.relu(self.h2h(h))
        return self.h2o(hidden)

# Forward-propagation
# Backward-propagation

In [8]:
def train_VFL(tr_dl, te_dl, pos_weight, in_size=None, hidden_size=32, n_epochs=30, lr=0.01,
              emb_path='emb.pt', grad_path='grad.pt'):
    """Train the client/server split model and report test loss and ROC-AUC.

    The client model turns features into an embedding and the server model
    turns the embedding into a prediction. The two sides exchange only the
    embedding (client -> server) and the gradient of the loss with respect to
    that embedding (server -> client); the exchange is simulated by saving
    and loading tensors at ``emb_path`` / ``grad_path``.

    Prints the summed per-batch training loss after every epoch and, after
    training, the summed per-batch test loss and the test ROC-AUC.

    Args:
        tr_dl: DataLoader yielding ``(features, labels)`` training batches.
        te_dl: DataLoader yielding ``(features, labels)`` test batches.
        pos_weight: Tensor passed to ``BCEWithLogitsLoss(pos_weight=...)``.
        in_size: Number of input features. When ``None`` it is read from the
            first sample of ``tr_dl.dataset`` instead of being hard-coded.
        hidden_size: Size of the exchanged embedding (default 32).
        n_epochs: Number of training epochs (default 30).
        lr: Adam learning rate for both optimizers (default 0.01).
        emb_path: File used to pass the embedding to the server.
        grad_path: File used to pass the gradient back to the client.
    """
    if in_size is None:
        # Indexing the dataset directly draws nothing from the random
        # generators, so the shuffling order of the loader is unaffected.
        in_size = tr_dl.dataset[0][0].shape[-1]

    client_model = ClientModel(in_size, hidden_size)
    server_model = ServerModel(hidden_size, 1)
    client_optimizer = torch.optim.Adam(client_model.parameters(), lr=lr)
    server_optimizer = torch.optim.Adam(server_model.parameters(), lr=lr)
    criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    # NOTE: torch.load unpickles file contents. Here the files are written by
    # this same function; only load such files from a trusted source.
    for epoch in range(n_epochs):
        tr_loss = 0
        client_model.train()
        server_model.train()
        for batch_x, batch_y in tqdm(tr_dl):
            # Client: compute the embedding and send it to the server.
            client_optimizer.zero_grad()
            h = client_model(batch_x)
            torch.save(h.detach(), emb_path)

            # Server: receive the embedding and update the server model.
            server_optimizer.zero_grad()
            emb = torch.load(emb_path)
            # Track gradients on the received embedding so that d(loss)/d(emb)
            # is available after backward().
            emb.requires_grad_(True)
            pred_y = server_model(emb)
            loss = criterion(pred_y.flatten(), batch_y)
            loss.backward()
            server_optimizer.step()
            tr_loss += loss.item()

            # Server: return the gradient w.r.t. the embedding to the client.
            grad = emb.grad
            torch.save(grad.detach(), grad_path)

            # Client: receive the gradient, continue back-propagation through
            # the client model from the embedding, and update it.
            dldh = torch.load(grad_path)
            h.backward(dldh)
            client_optimizer.step()

        print(f'Epoch: {epoch}, Training loss: {tr_loss:.4f}')

    te_loss = 0
    pred_y_list = []
    true_y_list = []
    client_model.eval()
    server_model.eval()
    # Test data needs no back-propagation, so skip building the autograd graph.
    with torch.no_grad():
        for batch_x, batch_y in tqdm(te_dl):
            # Client: compute the embedding and send it to the server.
            h = client_model(batch_x)
            torch.save(h, emb_path)

            # Server: receive the embedding and produce the prediction.
            emb = torch.load(emb_path)
            pred_y = server_model(emb)
            loss = criterion(pred_y.flatten(), batch_y)
            te_loss += loss.item()
            # The model outputs logits; apply sigmoid to get scores for ROC-AUC.
            pred_y_list.extend(torch.sigmoid(pred_y.flatten()).tolist())
            true_y_list.extend(batch_y.tolist())

    score = roc_auc_score(true_y_list, pred_y_list)
    print(f'Test loss: {te_loss:.4f}')
    print(f'Test ROC-AUC: {score:.4f}')


# 実験

In [10]:
# Load the raw train/test tables and the column groupings, then build the
# DataLoaders and the positive-class weight that are passed to both
# train_MLP and train_VFL in the cells below.
tr, te, num_cols, cat_cols, lab_col = load_data()
tr_dl, te_dl, pos_weight = preprocess(tr, te, cat_cols, num_cols, lab_col)

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  dat[keys] = out
  dat[keys] = out
  dat[keys] = out
  dat[keys] = out


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  dat[keys] = out
  dat[keys] = out
  dat[keys] = out
  dat[keys] = out
  dat[keys] = out


  0%|          | 0/6 [00:00<?, ?it/s]

# MLP

In [11]:
# Time the MLP baseline. perf_counter() is a monotonic, high-resolution clock
# meant for measuring elapsed intervals; time.time() can jump if the system
# clock is adjusted.
st = time.perf_counter()
train_MLP(tr_dl, te_dl, pos_weight)
print(f'Time: {time.perf_counter()-st:.4f}')

  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 0, Training loss: 28.0662


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 1, Training loss: 20.5576


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 2, Training loss: 20.0262


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 3, Training loss: 19.6655


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 4, Training loss: 19.3939


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 5, Training loss: 19.1457


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 6, Training loss: 19.0693


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 7, Training loss: 18.8062


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 8, Training loss: 18.5536


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 9, Training loss: 18.5164


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 10, Training loss: 18.5321


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 11, Training loss: 18.5289


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 12, Training loss: 18.2194


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 13, Training loss: 18.0579


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 14, Training loss: 18.1807


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 15, Training loss: 17.9560


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 16, Training loss: 17.9120


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 17, Training loss: 17.8606


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 18, Training loss: 17.7908


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 19, Training loss: 17.9168


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 20, Training loss: 17.7262


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 21, Training loss: 17.5858


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 22, Training loss: 17.5655


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 23, Training loss: 17.5677


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 24, Training loss: 17.6287


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 25, Training loss: 17.4766


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 26, Training loss: 17.3864


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 27, Training loss: 17.3278


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 28, Training loss: 17.3704


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 29, Training loss: 17.3743


  0%|          | 0/16 [00:00<?, ?it/s]

Test loss: 9.2621
Test ROC-AUC: 0.9060
Time: 15.5836


# VFL

In [12]:
# Time the VFL run. perf_counter() is a monotonic, high-resolution clock
# meant for measuring elapsed intervals; time.time() can jump if the system
# clock is adjusted.
st = time.perf_counter()
train_VFL(tr_dl, te_dl, pos_weight)
print(f'Time: {time.perf_counter()-st:.4f}')

  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 0, Training loss: 28.0662


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 1, Training loss: 20.5576


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 2, Training loss: 20.0262


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 3, Training loss: 19.6655


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 4, Training loss: 19.3939


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 5, Training loss: 19.1457


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 6, Training loss: 19.0693


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 7, Training loss: 18.8062


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 8, Training loss: 18.5536


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 9, Training loss: 18.5164


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 10, Training loss: 18.5321


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 11, Training loss: 18.5289


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 12, Training loss: 18.2194


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 13, Training loss: 18.0579


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 14, Training loss: 18.1807


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 15, Training loss: 17.9560


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 16, Training loss: 17.9120


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 17, Training loss: 17.8606


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 18, Training loss: 17.7908


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 19, Training loss: 17.9168


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 20, Training loss: 17.7262


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 21, Training loss: 17.5858


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 22, Training loss: 17.5655


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 23, Training loss: 17.5677


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 24, Training loss: 17.6287


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 25, Training loss: 17.4766


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 26, Training loss: 17.3864


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 27, Training loss: 17.3278


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 28, Training loss: 17.3704


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch: 29, Training loss: 17.3743


  0%|          | 0/16 [00:00<?, ?it/s]

Test loss: 9.2621
Test ROC-AUC: 0.9060
Time: 16.7219
