In [1]:
#!pip install torch
#!pip install torchmetrics
#!pip install torchvision 
#!pip install torch==1.10.0+cu111 torchvision==0.11.0+cu111 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html


In [1]:
import copy
import os
import random
import subprocess
from time import time

import numpy as np
import pandas as pd
import tenseal as ts
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from torch.utils.data import Subset
from torchsummary import summary
from tqdm import tqdm

import medmnist
from medmnist import INFO, Evaluator

# Report the installed MedMNIST version and its homepage.
print(f"MedMNIST v{medmnist.__version__} @ {medmnist.HOMEPAGE}")


MedMNIST v2.1.0 @ https://github.com/MedMNIST/MedMNIST/


# We first work on a 2D dataset

In [2]:
#data_flag = 'pathmnist'
#data_flag = 'breastmnist'
data_flag = 'dermamnist'

verbose = False

# Experiment parameters. They appear to be supplied from outside the notebook
# (e.g. an injected parameters cell executed before this one -- TODO confirm).
# The globals() lookups keep any injected value and otherwise fall back to the
# previously hard-coded defaults, so this cell no longer raises NameError when
# the notebook is run top-to-bottom on a fresh kernel.
scheme = globals().get("scheme", "tfhe")       # one of "plain", "tfhe", "ckks"
n_parts = globals().get("n_parts", 2)          # number of simulated clients / data shards
poly_choice = globals().get("poly_choice", 0)  # index into the per-scheme parameter list
N_rounds = globals().get("N_rounds", 10)       # repetitions in the statistics cell

print(f"scheme: {scheme}, n_parts: {n_parts}, poly_choice: {poly_choice}")

download = True

NUM_EPOCHS = 5
BATCH_SIZE = 64
lr = 0.001

info = INFO[data_flag]
task = info['task']
n_channels = info['n_channels']
n_classes = len(info['label'])
print(f"n_channels: {n_channels}, n_classes: {n_classes}")

DataClass = getattr(medmnist, info['python_class'])


scheme: tfhe, n_parts: 2, poly_choice: 0
n_channels: 3, n_classes: 7


## First, we read the MedMNIST data, preprocess it, and wrap it in `DataLoader` objects.

In [3]:
# preprocessing: convert images to tensors, then normalize each channel with
# mean 0.5 / std 0.5
data_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[.5], std=[.5]),
    ]
)

# load the data: transformed train/test splits plus an untransformed copy of train
train_dataset = DataClass(split='train', transform=data_transform, download=download)
test_dataset = DataClass(split='test', transform=data_transform, download=download)
pil_dataset = DataClass(split='train', download=download)

# encapsulate data into dataloader form; evaluation loaders use a doubled,
# unshuffled batch
eval_batch_size = 2 * BATCH_SIZE
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
train_loader_at_eval = data.DataLoader(dataset=train_dataset, batch_size=eval_batch_size, shuffle=False)
test_loader = data.DataLoader(dataset=test_dataset, batch_size=eval_batch_size, shuffle=False)


Using downloaded and verified file: /home/jovyan/.medmnist/dermamnist.npz
Using downloaded and verified file: /home/jovyan/.medmnist/dermamnist.npz
Using downloaded and verified file: /home/jovyan/.medmnist/dermamnist.npz


In [4]:
# Shard the training set into n_parts disjoint subsets, one per simulated client.
# Ceiling division guarantees every sample lands in some shard and, unlike the
# previous int(len / n_parts) + 1, never leaves the last shard empty when the
# length is a multiple of n_parts (8 samples / 4 parts used to give 3, 3, 2, 0).
n_subset = -(-len(train_dataset) // n_parts)

idxs = list(range(len(train_dataset)))
random.shuffle(idxs)

print(len(train_dataset), len(idxs), n_subset, n_subset*n_parts)

subsets = [Subset(train_dataset, idxs[i*n_subset:(i+1)*n_subset]) for i in range(n_parts)]
assert sum(len(s) for s in subsets) == len(train_dataset), "shards must cover the training set"
loaders = [data.DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True) for dataset in subsets]


7007 7007 3504 7008


In [5]:
# Optionally show the dataset summaries, separated by a rule line.
if verbose:
    for item in (train_dataset, "===================", test_dataset):
        print(item)

In [6]:
# montage

# NOTE(review): the call below is inside an `if` block, so it is not the cell's
# last top-level expression and Jupyter will not render its return value
# (presumably the montage image) automatically -- wrap it in display(...) if the
# picture is wanted.
if verbose:
    train_dataset.montage(length=20)


## Then, we define a model for illustration (ResNet-18), the objective (loss) function, and the optimizer used for classification.

In [7]:
from torchvision.models import resnet18, mobilenet_v3_small, mnasnet1_3
from torch import nn

def init():
    """Create a fresh model, loss function and optimizer for one training run.

    The model is a torchvision ResNet-18 with ``n_classes`` outputs (taken from
    the MedMNIST ``INFO`` entry in the config cell). Every parameter tensor except
    the last two is frozen, so only those two have ``requires_grad=True`` -- and
    hence only they are picked up by ``flatten_trainable`` for aggregation.

    Returns:
        (model, criterion, optimizer): the network, ``BCEWithLogitsLoss`` for
        multi-label tasks or ``CrossEntropyLoss`` otherwise, and SGD
        (momentum 0.9, learning rate ``lr``) over the trainable tensors.
    """
    model = resnet18(num_classes=n_classes)

    # Freeze all but the last two parameter tensors. In torchvision's ResNet-18
    # these are the classifier head ``fc.weight`` and ``fc.bias``.
    all_params = list(model.parameters())
    for param in all_params[:-2]:
        param.requires_grad = False

    if task == "multi-label, binary-class":
        criterion = nn.BCEWithLogitsLoss()
    else:
        criterion = nn.CrossEntropyLoss()

    # Only the tensors that can receive gradients are handed to the optimizer;
    # SGD would skip the frozen ones anyway because their .grad stays None.
    trainable = [p for p in all_params if p.requires_grad]
    optimizer = optim.SGD(trainable, lr=lr, momentum=0.9)

    return model, criterion, optimizer


In [8]:
#state_dict = model.state_dict()

def flatten_trainable(model):
    """Flatten every trainable parameter of ``model`` into one Python list.

    Parameters are visited in ``model.named_parameters()`` order and each tensor
    is flattened in row-major order, which is the layout ``reshape_trainable``
    consumes. Parameters with ``requires_grad=False`` are skipped.

    Args:
        model: a ``torch.nn.Module``.

    Returns:
        list of Python floats, the concatenated trainable values.
    """
    all_trainable = []
    for _, parameter in model.named_parameters():
        if parameter.requires_grad:
            # detach() reads the same values state_dict() would, without
            # building (and copying) the whole state dict on every call.
            all_trainable.extend(parameter.detach().flatten().tolist())
    return all_trainable

def reshape_trainable(model, all_trainable):
    """Rebuild a state dict whose trainable entries come from a flat vector.

    Inverse of ``flatten_trainable``: trainable parameters are filled in
    ``model.named_parameters()`` order, each from the next ``numel()`` values of
    ``all_trainable``. Non-trainable entries (frozen parameters, buffers) are the
    model's current ``state_dict()`` tensors.

    Args:
        model: the ``torch.nn.Module`` providing names, shapes and dtypes.
        all_trainable: flat sequence (list or 1-D array) of values.

    Returns:
        A shallow copy of ``model.state_dict()`` with the trainable entries
        replaced by new tensors of the original shape and dtype.
    """
    state_dict = model.state_dict().copy()
    offset = 0  # walk an index instead of re-slicing the tail (was O(n^2) copying)
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        n = parameter.numel()
        reference = state_dict[name]
        this_param = all_trainable[offset:offset + n]
        # Keep the parameter's dtype: numpy float64 input would otherwise yield
        # float64 tensors in the returned state dict.
        state_dict[name] = torch.tensor(this_param, dtype=reference.dtype).reshape(reference.shape)
        offset += n
    return state_dict

#trainable_params = flatten_trainable(model)
#print(len(trainable_params))

#test = reshape_trainable(model, trainable_params)
#for name, parameter in model.named_parameters():
#    print((test[name] == state_dict[name]).all())


In [9]:
def addition_tree_plain(parts):
    """Sum ``parts`` with a balanced, pairwise addition tree (no encryption).

    The list is split in half recursively and the halves are combined with
    ``+``, so it works for numbers, numpy arrays and any type defining ``+``
    (``addition_tree_ckks`` reuses it on encrypted vectors).

    Args:
        parts: non-empty list of addable items.

    Returns:
        (times, total): ``times`` is the placeholder ``[-1, -1, -1, -1]`` (nothing
        is timed in plaintext) and ``total`` is the sum of all parts.

    Raises:
        ValueError: if ``parts`` is empty (this used to recurse until
            RecursionError).
    """
    n = len(parts)
    if n == 0:
        raise ValueError("addition_tree_plain needs at least one part")
    times = [-1, -1, -1, -1]
    if n == 1:
        return times, parts[0]
    # Both halves are non-empty because n >= 2.
    return times, addition_tree_plain(parts[:n//2])[1] + addition_tree_plain(parts[n//2:])[1]

def addition_tree_tfhe(parts, key):
    """Sum equally long vectors element-wise with the external TFHE ``add_vectors`` tool.

    All vectors are concatenated and passed on the command line of
    ``target/release/add_vectors`` together with ``key`` and the module-level
    settings ``prec``, ``padd``, ``lower`` and ``upper`` (defined in the
    cryptographic-parameters cell). Values outside ``[lower, upper]`` only
    trigger a printed warning.

    The last line of the tool's output is parsed as space-separated timings,
    then ``" ["``, then comma-separated result values (the final character,
    presumably ``]``, is dropped).

    Args:
        parts: list of equal-length 1-D numeric sequences (one per client).
        key: key path passed to the tool.

    Returns:
        (times, result): list of float timings reported by the tool and the
        result vector as a numpy array.

    Raises:
        RuntimeError: if the tool exits with a non-zero status.
    """
    n_vectors = len(parts)
    m_entries = len(parts[0])
    v_concat = np.concatenate(parts)
    if np.max(v_concat) > upper:
        print(f"warning {np.max(v_concat)}>{upper}")
    if np.min(v_concat) < lower:
        print(f"warning {np.min(v_concat)}<{lower}")

    # Argument list instead of a shell string: no shell parsing of the key path
    # or values, and the exit status can be checked.
    args = ["target/release/add_vectors", str(key), str(prec), str(padd),
            str(lower), str(upper), str(n_vectors), str(m_entries)]
    args.extend("%f" % a for a in v_concat)
    # stderr is merged into stdout, matching what subprocess.getoutput captured.
    proc = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
    output = proc.stdout.rstrip("\n")
    if proc.returncode != 0:
        raise RuntimeError(f"add_vectors exited with status {proc.returncode}: {output[-500:]}")

    out = output.split("\n")[-1]
    times, result = out.split(" [")
    times = [float(a) for a in times.split(" ")]
    result = [float(a) for a in result[:-1].split(",")]
    return times, np.array(result)

def addition_tree_ckks(vecs, context):
    """Encrypt each vector with CKKS, sum the ciphertexts pairwise, then decrypt.

    Args:
        vecs: list of plaintext vectors (one per client), each passed to
            ``ts.ckks_vector``.
        context: TenSEAL CKKS context used for encryption and decryption.

    Returns:
        (times, aggr_vec): ``times`` is ``(-1, enc_time, add_time, dec_time)`` in
        seconds -- mean encryption time per vector, total time of the addition
        tree, and decryption time -- and ``aggr_vec`` is the decrypted sum.
    """
    count = len(vecs)

    start = time()
    encrypted = [ts.ckks_vector(context, vec) for vec in vecs]
    enc_time = (time() - start) / count

    start = time()
    aggr_enc = addition_tree_plain(encrypted)[1]
    add_time = time() - start

    start = time()
    aggr_vec = aggr_enc.decrypt()
    dec_time = time() - start

    return (-1, enc_time, add_time, dec_time), aggr_vec

# Small sample input for manually smoke-testing the addition helpers above
# (uncomment one of the calls); not used elsewhere in the visible notebook.
# The CKKS calls need `context` from the cryptographic-parameters cell.
testlist = [1,2,3,4]
#addition_tree_plain(testlist)
#addition_tree_ckks([torch.tensor([a]) for a in testlist], context)
#addition_tree_ckks([np.array([a]) for a in testlist], context)
#addition_tree_ckks([[a] for a in testlist], context)


In [10]:
# cryptographic parameters
# Selects the aggregation routine (`addition_tree`), the maximum chunk length
# per aggregation call (`max_vec_length`) and a label for result files.

if scheme=="tfhe":
    context = [256, 512, 1024][poly_choice]
    # NOTE(review): the 1024 entry points to "def_80_2014_1", which breaks the
    # size-in-name pattern of the other keys -- possibly a typo for
    # "def_80_1024_1"; confirm against the files in keys/.
    key = {None:"N/A",  256:"keys/def_80_256_1", 512:"keys/def_80_512_1", 1024:"keys/def_80_2014_1"}[context]
    prec = 6
    # Presumably the headroom bits the TFHE tool reserves for summing n_parts
    # values: ceil(log2(n_parts)), computed exactly with integers. The previous
    # int(np.log2(n_parts)) floored and under-counted for non-powers of two
    # (3 parts -> 1 bit instead of 2); powers of two are unchanged.
    padd = int(n_parts - 1).bit_length()
    lower = -1.0
    upper = 1.0
    addition_tree = lambda parts: addition_tree_tfhe(parts, key)
    max_vec_length = 500
    model_name = f"tfhe_{context}"
    
elif scheme=="ckks":
    param = [4096,8192][poly_choice]
    if param==4096:
        poly_mod_degree = 4096
        coeff_mod_bit_sizes = [40, 20, 40]
        context = ts.context(ts.SCHEME_TYPE.CKKS, poly_mod_degree, -1, coeff_mod_bit_sizes)
        context.global_scale = 2 ** 20
    else:
        poly_mod_degree = 8192
        #coeff_mod_bit_sizes = [40, 21, 21, 21, 21, 21, 21, 40]
        coeff_mod_bit_sizes = [60,40,40,60]
        context = ts.context(ts.SCHEME_TYPE.CKKS, poly_mod_degree, -1, coeff_mod_bit_sizes)
        #context.global_scale = 2 ** 21
        context.global_scale = 2 ** 40
    # this key is needed for doing dot-product operations
    context.generate_galois_keys()
    addition_tree = lambda parts: addition_tree_ckks(parts, context)
    # CKKS packs poly_mod_degree / 2 values per ciphertext.
    max_vec_length = param//2
    model_name = f"ckks_{param}"

else:
    addition_tree = addition_tree_plain
    max_vec_length = 10000000000000000  # effectively unlimited: one chunk
    model_name = f"plain_text"
    
print(model_name)
    

tfhe_256


## Next, we can start to train and evaluate!

In [11]:
def _train_local(model, criterion, optimizer, loader):
    """Run one pass of SGD over ``loader`` for a single simulated client."""
    model.train()
    for inputs, targets in loader:
        # forward + backward + optimize
        optimizer.zero_grad()
        outputs = model(inputs)
        if task == 'multi-label, binary-class':
            targets = targets.to(torch.float32)
        else:
            # reshape(-1) rather than squeeze(): a final batch of size 1 would
            # otherwise collapse to a 0-d target, which CrossEntropyLoss does not
            # accept for a (1, C) batch of logits.
            targets = targets.reshape(-1).long()
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()


def _aggregate(client_chunks, n_chunks):
    """Average the clients' chunked parameter vectors element-wise.

    ``client_chunks[i][j]`` is chunk ``j`` of client ``i``. Each chunk position is
    summed across clients with ``addition_tree`` (plain, TFHE or CKKS depending on
    the configured scheme); the concatenated sum is divided by the client count.

    Returns:
        (times, mean_params): timings from the last ``addition_tree`` call and the
        averaged flat parameter vector as a numpy array.
    """
    times = None
    summed = []
    for j in range(n_chunks):
        times, chunk_sum = addition_tree([chunks[j] for chunks in client_chunks])
        summed.append(chunk_sum)
    return times, np.concatenate(summed) / len(client_chunks)


def train_distr(verbose = False):
    """Simulate federated averaging over ``n_parts`` clients and return the model.

    In each of the ``NUM_EPOCHS`` rounds every client starts from the current
    global state, trains one pass over its own shard ``loaders[i]`` and
    contributes its flattened trainable parameters, split into chunks of at most
    ``max_vec_length`` values. The chunks are summed via ``addition_tree`` and
    averaged to form the next global state.

    Args:
        verbose: print round and client progress.

    Returns:
        (times, model): the last aggregation's per-stage timings multiplied by
        the number of chunks (an estimate for aggregating the full vector), and
        the model loaded with the final global state.
    """
    model, criterion, optimizer = init()
    # Deep copy: model.state_dict() tensors share storage with the live
    # parameters and buffers, so a shallow copy is mutated in place by
    # optimizer.step() -- later clients would start from the previous client's
    # weights instead of the global model.
    global_state = copy.deepcopy(model.state_dict())
    # Defaults so NUM_EPOCHS == 0 returns instead of raising UnboundLocalError.
    times, n_chunks = [-1, -1, -1, -1], 1

    for epoch in range(NUM_EPOCHS):
        if verbose: print(f"epoch: {epoch+1}")
        params = []

        for i in range(n_parts):
            if verbose: print(f"part: {i+1}")
            model.load_state_dict(global_state)
            # Fresh optimizer state per client: otherwise SGD momentum built on
            # one client's data leaks into the next client's local training.
            optimizer.state.clear()
            _train_local(model, criterion, optimizer, loaders[i])

            trainable_params = flatten_trainable(model)
            n_chunks = int(np.ceil(len(trainable_params) / max_vec_length))
            params.append(np.array_split(trainable_params, n_chunks))

        times, aggr_params = _aggregate(params, n_chunks)
        # NOTE(review): non-trainable entries (e.g. BatchNorm running statistics,
        # updated because the frozen backbone runs in train mode) are taken from
        # the last client's local state rather than aggregated.
        global_state = copy.deepcopy(reshape_trainable(model, list(aggr_params)))

    model.load_state_dict(global_state)

    return [t*n_chunks for t in times], model
    
if verbose:
    times, model = train_distr(verbose = False)
    # A bare expression inside an if-block is never rendered by Jupyter, so the
    # timings are printed explicitly.
    print(times)


In [12]:
# evaluation

def test(split, model):
    """Evaluate ``model`` on one MedMNIST split with the MedMNIST ``Evaluator``.

    Args:
        split: ``'train'`` uses ``train_loader_at_eval``; any other value uses
            ``test_loader``. The same string is passed to ``Evaluator``.
        model: network to evaluate; it is switched to eval mode.

    Returns:
        The result of ``Evaluator(data_flag, split).evaluate(y_score)``; the
        callers in this notebook format it as ``(auc, acc)``.
    """
    model.eval()
    data_loader = train_loader_at_eval if split == 'train' else test_loader

    batch_scores = []
    with torch.no_grad():
        for inputs, _targets in data_loader:
            logits = model(inputs)
            if task == 'multi-label, binary-class':
                # Independent per-label probabilities, matching the
                # BCEWithLogitsLoss used in training; a softmax across labels
                # would couple them.
                batch_scores.append(logits.sigmoid())
            else:
                batch_scores.append(logits.softmax(dim=-1))

    # Only scores are passed to the Evaluator (it presumably loads the split's
    # labels itself), so the former unused y_true accumulator is gone. The
    # scores are concatenated once instead of growing a tensor per batch.
    y_score = torch.cat(batch_scores, 0) if batch_scores else torch.tensor([])
    evaluator = Evaluator(data_flag, split)
    metrics = evaluator.evaluate(y_score.numpy())
    return metrics
        
# Optionally report AUC / accuracy on both splits for the model trained above.
if verbose:
    print('==> Evaluating ...')

    for split_name, template in (('train', 'train  auc: %.3f  acc:%.3f'),
                                 ('test', 'test  auc: %.3f  acc:%.3f')):
        metrics = test(split_name, model)
        print(template % metrics)


## Collect stats for training and evaluation

In [15]:
# Repeat the full federated training N_rounds times, recording wall-clock
# training time, test accuracy and the aggregation timings of each run.
tst = []

for i in range(N_rounds):
    t0 = time()
    times, model = train_distr(verbose=False)
    train_time = time() - t0

    accuracy = test('test', model)[1]

    # columns: model / param (from model_name), n_parts, crypt_time, crypt_acc,
    # then enc/add/dec timings (times[0] is not recorded)
    row = model_name.split("_") + [n_parts, train_time, accuracy, times[1], times[2], times[3]]
    tst.append(row)

    show_progress = N_rounds <= 10 or i % (N_rounds // 10) == 0 or i == N_rounds - 1
    if show_progress:
        print(f"Round {i+1} completed in {time()-t0}")

res = pd.DataFrame(tst, columns=["model", "param", "n_parts", "crypt_time","crypt_acc","enc_time","add_time","dec_time"])
res

Round 1 completed in 277.39720129966736
Round 2 completed in 257.21616435050964
Round 3 completed in 273.9937992095947
Round 4 completed in 268.4945878982544
Round 5 completed in 272.4022145271301
Round 6 completed in 266.8021488189697
Round 7 completed in 271.58902168273926
Round 8 completed in 282.976717710495
Round 9 completed in 265.8248641490936
Round 10 completed in 269.79818201065063


Unnamed: 0,model,param,n_parts,crypt_time,crypt_acc,enc_time,add_time,dec_time
0,plain,none,8,266.208445,0.693766,-1,-1,-1
1,plain,none,8,248.608489,0.682294,-1,-1,-1
2,plain,none,8,260.884558,0.676808,-1,-1,-1
3,plain,none,8,257.798172,0.687781,-1,-1,-1
4,plain,none,8,261.999813,0.688279,-1,-1,-1
5,plain,none,8,258.201778,0.682294,-1,-1,-1
6,plain,none,8,259.196676,0.680798,-1,-1,-1
7,plain,none,8,252.307324,0.684289,-1,-1,-1
8,plain,none,8,254.130729,0.686284,-1,-1,-1
9,plain,none,8,261.705513,0.689776,-1,-1,-1


In [16]:
# Summary statistics over all rounds, saved next to the raw results.
tmp = res.describe()
# to_csv does not create missing directories, so make sure figs/ exists first.
os.makedirs("figs", exist_ok=True)
tmp.to_csv(f"figs/stats_{data_flag}_{model_name}_{n_parts}_{N_rounds}.csv")
tmp


Unnamed: 0,n_parts,crypt_time,crypt_acc,enc_time,add_time,dec_time
count,10.0,10.0,10.0,10.0,10.0,10.0
mean,8.0,258.10415,0.685237,-1.0,-1.0,-1.0
std,0.0,5.18986,0.004934,0.0,0.0,0.0
min,8.0,248.608489,0.676808,-1.0,-1.0,-1.0
25%,8.0,255.04759,0.682294,-1.0,-1.0,-1.0
50%,8.0,258.699227,0.685287,-1.0,-1.0,-1.0
75%,8.0,261.500274,0.688155,-1.0,-1.0,-1.0
max,8.0,266.208445,0.693766,-1.0,-1.0,-1.0


In [17]:
# to_csv does not create missing directories, so make sure figs/ exists first.
os.makedirs("figs", exist_ok=True)
res.to_csv(f"figs/results_{data_flag}_{model_name}_{n_parts}_{N_rounds}.csv")



In [19]:
#model = resnet18(num_classes=n_classes)
#summary(model, (3, 28, 28)) # 


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 14, 14]           9,408
       BatchNorm2d-2           [-1, 64, 14, 14]             128
              ReLU-3           [-1, 64, 14, 14]               0
         MaxPool2d-4             [-1, 64, 7, 7]               0
            Conv2d-5             [-1, 64, 7, 7]          36,864
       BatchNorm2d-6             [-1, 64, 7, 7]             128
              ReLU-7             [-1, 64, 7, 7]               0
            Conv2d-8             [-1, 64, 7, 7]          36,864
       BatchNorm2d-9             [-1, 64, 7, 7]             128
             ReLU-10             [-1, 64, 7, 7]               0
       BasicBlock-11             [-1, 64, 7, 7]               0
           Conv2d-12             [-1, 64, 7, 7]          36,864
      BatchNorm2d-13             [-1, 64, 7, 7]             128
             ReLU-14             [-1, 6