In [30]:
EPOCHS = 5
ROUNDS = 10

SERVER_ADRESS = "tcp://localhost:5555"
SUBSCRIPTION_ADRESS = "tcp://localhost:5557"
ID = 123

In [31]:
import pickle
import zmq
import torch
from torch import nn, optim
import torchvision
import torchvision.transforms as transforms
import pickle
import base64
import time

# Helpers
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

def write_data(file_name, data):
    if type(data) == bytes:
        #bytes to base64
        data = base64.b64encode(data)
         
    with open(file_name, 'wb') as f: 
        f.write(data)
 
def read_data(file_name):
    with open(file_name, "rb") as f:
        data = f.read()
    #base64 to bytes
    return base64.b64decode(data)

In [32]:
context = zmq.Context()
# socket is the server socket
socket = context.socket(zmq.DEALER)
socket.connect(SERVER_ADRESS)
identity = str(ID)
socket.identity = identity.encode("ascii")
# sub_socker is subscribed to the server
sub_socket = context.socket(zmq.SUB)
sub_socket.connect(SUBSCRIPTION_ADRESS)
sub_socket.setsockopt_string(zmq.SUBSCRIBE, '')

print("Context created, connecting to server...")

Context created, connecting to server...


In [33]:
from model_def import SimpleNN

# Takes a pickled model, trains it for n epochs and returns the pickled model
def train(model_bytes):
    model = pickle.loads(model_bytes)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    # Training setup
    trainset = torchvision.datasets.MNIST('MNIST_data/', download=True, train=True, transform=transform)
    dataset_range = range(int(len(trainset) / 2))
    trainset_range = torch.utils.data.Subset(trainset, list(dataset_range))
    trainloader = torch.utils.data.DataLoader(trainset_range, batch_size=64, shuffle=True)

    start_total = time.time()
    for epoch in range(EPOCHS):
        print(f'Epoch {epoch}/{EPOCHS}')
        start = time.time()
        running_loss = 0.0

        for i, (images, labels) in enumerate(trainloader):
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 100 == 99:
                print(f'Epoch {epoch}/{EPOCHS}, batch {i+1}: {running_loss/100}')
                running_loss = 0.0
        
        end = time.time()
        print(f'Time elapsed for epoch {epoch}: {end-start}')

    end_total = time.time()
    print(f'Total time elapsed: {end_total-start_total}')
    print('Finished Training')

    return pickle.dumps(model)

In [34]:
# Run the client
def run():
    model_bytes = None
    for request in range(ROUNDS):
        print(f'Round {request}')
        if request == 0:
            # First round, request initial model
            print("Requesting model...")
            message = b"New"
            socket.send(message)

            model_bytes = socket.recv()
            print("Model received")
        else:
            print("Requesting model update...")
            model_bytes = sub_socket.recv()
            print("Model update received")
        
        print("Training...")
        trained_model_bytes = train(model_bytes)
        print("Training done")
        print("Sending model update...")
        socket.send(trained_model_bytes)

run()

Round 0
Requesting model...
Model received
Training...
Epoch 0/5
Epoch 0/5, batch 100: 1.8930507016181946
Epoch 0/5, batch 200: 1.1756234347820282
Epoch 0/5, batch 300: 0.8002745914459228
Epoch 0/5, batch 400: 0.6495162111520767
Time elapsed for epoch 0: 1.5560247898101807
Epoch 1/5
Epoch 1/5, batch 100: 0.5291568952798843
Epoch 1/5, batch 200: 0.4782855647802353
Epoch 1/5, batch 300: 0.43944292068481444
Epoch 1/5, batch 400: 0.4132793764770031
Time elapsed for epoch 1: 1.5663681030273438
Epoch 2/5
Epoch 2/5, batch 100: 0.38382028594613077
Epoch 2/5, batch 200: 0.3852822887897491
Epoch 2/5, batch 300: 0.37155962705612183
Epoch 2/5, batch 400: 0.3788030073046684
Time elapsed for epoch 2: 1.6420321464538574
Epoch 3/5
Epoch 3/5, batch 100: 0.3627372823655605
Epoch 3/5, batch 200: 0.3496367146074772
Epoch 3/5, batch 300: 0.3305418634414673
Epoch 3/5, batch 400: 0.3500203981250525
Time elapsed for epoch 3: 2.448975086212158
Epoch 4/5
Epoch 4/5, batch 100: 0.316870768815279
Epoch 4/5, batch 