In [1]:
from GaussianGenerator import GaussianGenerator
import numpy as np
import threading
import communicator
import matplotlib.pyplot as plt
%matplotlib inline
import pickle
import socket
import time
import pandas as pd

In [2]:
DEBUG = True         # when True, reply() prints the address it replies to
BACKLOG = 5          # max pending connections queued by listen()
HEADERSIZE = 10      # width in bytes of the ASCII length prefix of every message
RECV_BUFSIZE = 4096  # max bytes requested per recv() call


class Server:
	"""
	Tiny pickle-over-TCP messaging endpoint.

	Wire format: a HEADERSIZE-byte ASCII decimal length (left-justified,
	space padded) followed by exactly that many bytes of pickled payload.
	Each message uses its own connection: `send` connects, writes one frame
	and closes; `recv` accepts one connection and reads one frame from it.

	SECURITY: `recv` unpickles whatever a peer sends, and unpickling can
	execute arbitrary code. Only use this between trusted hosts.
	"""

	def __init__(self, port):
		"""
		Create a listening SOCK_STREAM socket bound to this machine's hostname
		(socket.gethostname(), not the loopback interface) at `port`.
		"""
		self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
		self.s.bind((socket.gethostname(), port))
		self.s.listen(BACKLOG)
		self.clientsock = None
		self.clientaddress = None

	@staticmethod
	def _recv_exact(sock, nbytes):
		"""Read exactly `nbytes` bytes from `sock`; raise ConnectionError on early EOF."""
		chunks = []
		remaining = nbytes
		while remaining > 0:
			chunk = sock.recv(min(remaining, RECV_BUFSIZE))
			if not chunk:
				# recv() returning b'' means the peer closed the connection.
				raise ConnectionError(
					f"connection closed with {remaining} of {nbytes} bytes still expected")
			chunks.append(chunk)
			remaining -= len(chunk)
		return b''.join(chunks)

	def recv(self):
		"""
		Block until a peer connects, then read and unpickle one framed message.

		The accepted connection is closed once the frame has been read; the
		peer's address is kept in `self.clientaddress` for `reply`.

		:return: the unpickled payload
		:raises ConnectionError: if the peer disconnects before a full frame arrives
		:raises ValueError: if the length header is not a decimal integer
		"""
		self.clientsock, self.clientaddress = self.s.accept()
		try:
			# Read the fixed-size header first so a short first read can never
			# be mistaken for the whole length field.
			header = self._recv_exact(self.clientsock, HEADERSIZE)
			payload = self._recv_exact(self.clientsock, int(header))
		finally:
			self.clientsock.close()
		return pickle.loads(payload)

	def reply(self, msg, port):
		"""
		Send `msg` to `port` on the host that sent the most recently received message.

		:return: number of bytes written (header + payload)
		:raises RuntimeError: if no message has been received yet
		"""
		if self.clientaddress is None:
			raise RuntimeError("reply() called before any message was received")
		loc = (self.clientaddress[0], port)
		if DEBUG:
			print(loc)
		return self.send(loc, msg)

	def get_socket(self):
		"""Return the underlying listening socket."""
		return self.s

	def send(self, location: "tuple[str, int]", msgtosend):
		"""
		Open a new connection to `location`, write one framed message, then close it.

		:param location: (address, port) to connect to
		:param msgtosend: any picklable object
		:return: number of bytes written (header + payload)
		:raises ValueError: if the payload length does not fit in HEADERSIZE digits
		"""
		payload = pickle.dumps(msgtosend)
		header = f"{len(payload):<{HEADERSIZE}}"
		if len(header) > HEADERSIZE:
			raise ValueError(
				f"payload of {len(payload)} bytes does not fit a {HEADERSIZE}-byte header")
		frame = header.encode('utf-8') + payload
		with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
			s.connect(location)
			# send() may write only part of the buffer; sendall() retries until done.
			s.sendall(frame)
		return len(frame)

	def close(self):
		"""Close the listening socket, releasing the port."""
		self.s.close()


In [3]:
def ClassAccuracy(X, Y, w):
	"""
	Fraction of samples the linear classifier with weights `w` labels correctly.

	A sample counts as correct when its score X[i] . w is >= 0 and its label
	is 1, or its score is < 0 and its label is -1. Labels other than +/-1 and
	NaN scores are never counted as correct.

	:param X: 2-D array-like, one sample per row
	:param Y: 1-D array-like of labels, one per row of X
	:param w: weight vector (any shape that flattens to X.shape[1] entries)
	:return: accuracy as a Python float in [0, 1]
	:raises ValueError: if X has no rows
	"""
	X = np.asarray(X)
	Y = np.asarray(Y)
	if X.shape[0] == 0:
		raise ValueError("ClassAccuracy needs at least one sample")
	scores = X @ np.ravel(w)
	# Both branches spelled out (rather than sign(scores) == Y) so NaN scores
	# and labels outside {-1, 1} match neither branch.
	correct = ((scores >= 0) & (Y == 1)) | ((scores < 0) & (Y == -1))
	return np.count_nonzero(correct) / X.shape[0]

In [4]:
# Endpoints used by the training loop: "Req" messages go to loc_model,
# gradients go to loc_server, and replies arrive on this worker's own PORT.
PORT, PORT2 = 1232, 1233
HOSTADDR = "container1"
serv = Server(PORT)  # this worker's listener; weight replies are received here
loc_model, loc_server = (HOSTADDR, PORT), (HOSTADDR, PORT2)

In [5]:
N_TEST = 274    # leading rows of the shuffled data used as the test set
N_VAL = 55      # following rows used as the validation set
N_WORKERS = 5   # remaining rows are split as evenly as possible across workers
worker = 1      # 1-based id of the worker this notebook plays

data = pd.read_csv("data_banknote_authentication.txt", header = None)
d_full = data.to_numpy()
# Fixed seed, presumably so every worker process derives the same shuffle/split.
np.random.seed(4)
np.random.shuffle(d_full)

# Features are the first 4 columns, the label the 5th; labels equal to 0
# become -1 because the classifier works with +/-1 labels.
X_full = d_full[:,:4]
Y_full = d_full[:,4]
Y_full[Y_full == 0] = -1
print(X_full.shape[0])

#test set
X_test = X_full[:N_TEST,:]
Y_test = Y_full[:N_TEST]
print(X_test.shape)

#validation set
X_val = X_full[N_TEST:N_TEST + N_VAL,:]
Y_val = Y_full[N_TEST:N_TEST + N_VAL]
print(X_val.shape)

#worker set: np.array_split gives the first (rows % N_WORKERS) shards one extra
# row, i.e. 3 x 209 + 2 x 208 for the 1043 rows left over from 1372.
if not 1 <= worker <= N_WORKERS:
    raise ValueError(f"worker must be in 1..{N_WORKERS}, got {worker}")
X = np.array_split(X_full[N_TEST + N_VAL:], N_WORKERS)[worker - 1]
Y = np.array_split(Y_full[N_TEST + N_VAL:], N_WORKERS)[worker - 1]

1372
(274, 4)
(55, 4)


In [None]:
n = 10              # mini-batch size per gradient
kw = 5.0            # upper bound (seconds) of the random delay before each round
maxit = 40          # number of request / compute / send rounds
MALICIOUS = True    # Byzantine mode: send a sign-flipped, scaled gradient
BYZ_SCALE = 10      # scale factor applied to the gradient in Byzantine mode
step = 0
# Clock-based seed so separate worker runs draw different delays and batches.
np.random.seed(int(time.time()))
dif = 1
Acc = []


def perceptron_gradient(w, x, y):
    """
    Average perceptron-criterion gradient of weights `w` over the batch (x, y).

    Each misclassified sample (score >= 0 with label -1, or score < 0 with
    label 1) contributes -y_i * x_i; the sum is divided by the batch size.
    Works for any number of feature columns.
    """
    scores = x @ np.ravel(w)
    wrong = ((scores >= 0) & (y == -1)) | ((scores < 0) & (y == 1))
    return -(y[wrong][:, None] * x[wrong]).sum(axis=0) / len(y)


while step < maxit:
    # Random delay to emulate an asynchronous / straggling worker.
    time.sleep(np.random.uniform(high=kw))

    # Ask for the current model weights and block until they arrive.
    serv.send(loc_model, "Req")
    w = serv.recv()

    # Mini-batch sampled with replacement from this worker's shard.
    r = np.random.choice(X.shape[0], n)
    g = perceptron_gradient(w, X[r], Y[r])

    if MALICIOUS:
        g = -np.multiply(BYZ_SCALE, g)

    serv.send(loc_server, g)

    # Training accuracy of the weights this round started from.
    Acc.append(ClassAccuracy(X, Y, w))
    dif = Acc[-2] - Acc[-1] if len(Acc) > 1 else 1
    step += 1