In [3]:
!pip install -q matplotlib

In [None]:
from multiprocessing.connection import Listener
from typing import Dict
import matplotlib.pyplot as plt
from collections import defaultdict

# (host, port) pair handed to the Listener that receives the training metrics.
address = ('localhost', 6000)     # family is deduced to be 'AF_INET'

class EpochData:
    """One metrics sample: the loss and accuracy reported for `cid` at `epoch`.

    Attributes:
        cid: Identifier of the reporting party (presumably a client id).
        epoch: Epoch number the sample belongs to.
        loss: Loss value reported for that epoch.
        accuracy: Accuracy value reported for that epoch.
    """

    def __init__(self, cid: int, epoch: int, loss: float, accuracy: float):
        # Store the four reported values unchanged.
        self.cid, self.epoch = cid, epoch
        self.loss, self.accuracy = loss, accuracy

    def __str__(self):
        """Return the single-line, human-readable form used when logging."""
        summary = f'cid: {self.cid}, epoch: {self.epoch} loss: {self.loss}, accuracy: {self.accuracy}'
        return summary

# Every EpochData sample received so far, in arrival order.
training_matrices = []

# Metrics server: accepts one connection at a time and reads a single
# UTF-8, colon-separated message from it.  Recognised messages:
#   'exit'                       -> stop the server
#   '<cid>:<epoch>:<loss>:<acc>' -> store and print an EpochData sample
#   '<cid>:<time>'               -> print a timing value for that cid
# Anything else is reported and skipped; the server keeps running.
#
# NOTE(review): the authkey is hard-coded in the source; consider loading it
# from configuration if this listener is ever reachable by untrusted peers.
with Listener(address, authkey=b'ri3uhisd') as listener:
    while True:
        try:
            with listener.accept() as conn:
                try:
                    data = conn.recv_bytes().decode('utf-8')

                    if data == 'exit':
                        break

                    fields = data.split(':')
                    if len(fields) == 4:
                        cid = int(fields[0])
                        epoch = int(fields[1])
                        loss = float(fields[2])
                        accuracy = float(fields[3])
                        epoch_data = EpochData(cid, epoch, loss, accuracy)
                        training_matrices.append(epoch_data)
                        print(epoch_data)
                    elif len(fields) == 2:
                        cid = int(fields[0])
                        time = float(fields[1])
                        print(f'time for cid {cid}: {time}')
                    else:
                        # Messages with any other field count used to be
                        # dropped silently; surface them like parse failures.
                        print('unable to log data')
                        print(f'unexpected message format: {data!r}')
                except EOFError:
                    # The peer closed the connection without sending anything.
                    print('client disconnected')
                except KeyboardInterrupt:
                    print('server stopped by interrupt')
                    break
                except Exception as e:
                    # Best effort: a bad message (e.g. a non-numeric field)
                    # must not take the server down.
                    print('unable to log data')
                    print(e)
        except KeyboardInterrupt:
            # Interrupted while blocked in accept().
            print('server stopped by interrupt')
            break
        except Exception as e:
            # accept() itself failed (e.g. authentication error); keep serving.
            print('server crashed with error')
            print(e)

# Group the collected samples by cid so each one gets its own curve.
# The samples live in `training_matrices`; `data` only holds the last raw
# message received by the listener and must not be iterated here.
data_by_cid = defaultdict(list)
for item in training_matrices:
    data_by_cid[item.cid].append(item)

# Figure 1: loss per epoch, one line per cid.
plt.figure(figsize=(10, 5))
for cid, items in data_by_cid.items():
    epochs = [item.epoch for item in items]
    losses = [item.loss for item in items]
    plt.plot(epochs, losses, label=f'CID {cid}')

plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss vs Epoch')
plt.legend()
plt.show()

# Figure 2: accuracy per epoch, one line per cid.
plt.figure(figsize=(10, 5))
for cid, items in data_by_cid.items():
    epochs = [item.epoch for item in items]
    accuracies = [item.accuracy for item in items]
    plt.plot(epochs, accuracies, label=f'CID {cid}')

plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Accuracy vs Epoch')
plt.legend()
plt.show()