In [1]:
from typing import Callable, Dict, Tuple
from dotenv import load_dotenv
import json
import time
import numpy as np
import web3
from logging import INFO

from utils import *


  from .autonotebook import tqdm as notebook_tqdm


In [11]:
# load_dotenv()
# abi = json.load(open('contract/build/contracts/FugaController.json', 'r'))['abi']
# S3_ACCESS_KEY = os.environ.get('S3_ACCESS_KEY')
# S3_SCRETE_KEY = os.environ.get('S3_SCRETE_KEY')
# BUCKET_NAME = 'fugaeth'
# HTTP_PROVIDER = os.environ.get('HTTP_PROVIDER')

In [12]:

def listen_for_event(contract, event_name):
    """Yield the ``args`` of every new ``event_name`` event emitted by ``contract``.

    A filter starting at the latest block is created once; it is then polled
    forever, sleeping 5 seconds after each batch of new entries. The generator
    never terminates on its own.
    """
    poller = contract.events[event_name].createFilter(fromBlock='latest')

    while True:
        # get_new_entries() is evaluated eagerly when the generator expression is built,
        # exactly once per polling round.
        matching_args = (
            entry.args
            for entry in poller.get_new_entries()
            if entry.event == event_name
        )
        yield from matching_args
        # back off before polling the filter again
        time.sleep(5)

def web3_connection(s3, w3, contract) -> Tuple[Callable[[], Dict], Callable[[Dict], None]]:
    """
    Wraps the contract's "ServerMessage" event stream and ``handle_send`` into a
    ``(receive, send)`` pair of callables.

    ``receive()`` returns the args of the next ServerMessage event (it waits on
    ``listen_for_event``'s polling loop); it returns None only if that iterator is
    exhausted. ``send(msg)`` forwards ``msg`` to ``handle_send`` with the bound
    ``s3``/``w3``/``contract`` and returns its result.
    """
    server_events = listen_for_event(contract, "ServerMessage")

    def receive():
        # next() with a default maps exhaustion to None instead of raising StopIteration
        return next(server_events, None)

    def send(msg):
        return handle_send(msg, s3, w3, contract)

    return receive, send

# def listen_for_event(contract, event_name):
#     # create a filter to listen for the specified event
#     event_filter = contract.events[event_name].createFilter(fromBlock='latest')

#     while True:
#         # check if any new events have been emitted
#         for event in event_filter.get_new_entries():
#             # if the specified event has been emitted, return its message
#             if event.event == event_name:
#                 yield event.args
#                 return
#         # wait for new events
# #         time.sleep(5)

# def web3_connection(s3, w3, contract) -> Tuple[Callable[[], Dict], Callable[[Dict], None]]:
#     """
#     Creates a connection to the blockchain and returns a function to receive messages and a function to send messages.    
#     """
#     # create a filter to listen for the specified event
#     web3_message_iterator = listen_for_event(contract, "ServerMessage")

#     # create a function that returns the next message
#     def receive():
#         try:
#             return next(web3_message_iterator)
#         except StopIteration:
#             return None
#     def send(msg):
#         return handle_send(msg, s3, w3, contract)
#     # receive: Callable[[], Dict] = lambda: next(web3_message_iterator)
#     # send: Callable[[Dict], None] = lambda msg: handle_send(msg, s3, w3, contract)

#     return (receive, send)


def handle_receive(client, msg, s3, w3, contract, start_round_delay=20):
    """
    Handles a message received from the contract's ServerMessage event.

    Dispatches on ``msg['field']``:
      - "Finished": stop the client loop.
      - "Ready": send a ``joinRound`` transaction.
      - "JoinRound": wait ``start_round_delay`` seconds, then send ``startRound``.
      - "ConfigIns": read ``getConfig`` and store it on the client.
      - "FitIns": download and hash-check peer models, then aggregate via ``aggregate_fit``.
      - "EvaluateIns": download each listed model and evaluate it on the client.
    Unknown fields are reported and ignored.

    Returns ``(response_message_or_None, sleep_duration, keep_going)``.
    ``keep_going`` is False only for "Finished"; ``sleep_duration`` is always 0 here.
    """
    field = msg['field']

    if field == "Finished":
        print("finished")
        return None, 0, False

    if field == "Ready":
        print("join round")
        send_transaction(w3, contract, contract.functions.joinRound())
        return None, 0, True

    if field == "JoinRound":
        # give other clients time to join before the round is started
        time.sleep(start_round_delay)
        print("call start round")
        send_transaction(w3, contract, contract.functions.startRound())
        return None, 0, True

    if field == "ConfigIns":
        return _handle_config_ins(client, w3, contract)

    if field == "FitIns":
        return _handle_fit_ins(client, s3, w3, contract)

    if field == "EvaluateIns":
        return _handle_evaluate_ins(client, s3, w3, contract)

    # Previously an unknown field fell through and returned None, which crashed the
    # caller's tuple unpacking; keep listening instead.
    print(f"ignoring unknown field: {field!r}")
    return None, 0, True


def _is_self_centered(client):
    """Return the client's ``self_centered`` flag (from ``client.config``, else a ``self_centered`` attribute)."""
    config = getattr(client, 'config', None) or {}
    if 'self_centered' in config:
        return bool(config['self_centered'])
    return bool(getattr(client, 'self_centered', False))


def _download_model(s3, model_hash):
    """Download ``models/<model_hash>.bin`` from BUCKET_NAME to the same relative path."""
    object_key = f'models/{model_hash}.bin'
    s3.download_file(BUCKET_NAME, object_key, object_key)


def _handle_config_ins(client, w3, contract):
    """Read the round configuration from ``getConfig`` and store it on the client."""
    print("config start")
    response = read_transaction(w3, contract, 'getConfig')
    print(response)

    learning_rate = float(response['learning_rate'])
    config = {
        'self_centered': response['self_centered'],
        'batch_size': response['batch_size'],
        # 'lr' is kept for existing readers; 'learning_rate' is the key IMDBClient.fit reads.
        'lr': learning_rate,
        'learning_rate': learning_rate,
        'local_epochs': response['local_epochs'],
        'val_steps': response['val_steps'],
    }
    client.set_config(config)
    return {'field': 'ConfigRes'}, 0, True


def _handle_fit_ins(client, s3, w3, contract):
    """Collect verified peer models plus this client's own model and aggregate them."""
    print("fit start")
    own = read_transaction(w3, contract, 'getClient')
    print(own)
    fit_ins = read_transaction(w3, contract, 'FitIns')
    print(fit_ins)

    model_hashes, num_samples, scores = [], [], []

    # Pair each peer model with ITS OWN sample count and score (the old code appended
    # the whole lists for every model).
    for model_hash, num_sample, score in zip(
        fit_ins['model_hashes'], fit_ins['num_samples'], fit_ins['scores']
    ):
        _download_model(s3, model_hash)
        if check_model(model_hash):
            model_hashes.append(model_hash)
            num_samples.append(num_sample)
            scores.append(score)

    self_model_hash = own['model_hash']
    if check_model(self_model_hash):
        model_hashes.append(self_model_hash)
        num_samples.append(own['num_sample'])
        # self-centered: weight own model by the total score of the accepted peers
        scores.append(sum(scores) if _is_self_centered(client) else own['score'])
        # A lone zero score would give the only model zero weight; use 1 instead.
        if len(scores) == 1 and scores[0] == 0:
            scores[0] = 1

    fitres = aggregate_fit(client, model_hashes, num_samples, scores)
    return {'field': 'FitRes', 'data': fitres}, 0, True


def _handle_evaluate_ins(client, s3, w3, contract):
    """Evaluate every model listed by ``EvaluateIns``; returns ``{model_hash: loss}``."""
    print("eval start")
    response = read_transaction(w3, contract, 'EvaluateIns')
    print(response)

    evalres = {}
    for model_hash in response['model_hashes']:
        _download_model(s3, model_hash)
        # NOTE(review): unlike the FitIns path, a failed hash check does not exclude
        # the model here; its result is ignored. Confirm this is intended.
        check_model(model_hash)
        param = read_model(model_hash)
        loss, _, _ = client.evaluate(param)
        evalres[model_hash] = loss

    return {'field': 'EvaluateRes', 'data': evalres}, 0, True


def handle_send(msg, s3, w3, contract):
    """
    Handles a message to be sent to the contract.

    - "ConfigRes": call ``ConfigRes()``.
    - "FitRes": ``msg['data']`` is ``(parameters, num_examples, metrics)``; the parameters
      are hashed with ``make_hash``, uploaded with ``upload_model``, and
      ``FitRes(model_hash, num_examples)`` is called.
    - "EvaluateRes": ``msg['data']`` is ``{model_hash: loss}``; calls
      ``EvaluateRes(model_hashes, losses)`` with keys and values in dict order.

    Returns whatever ``send_transaction`` returns (presumably the transaction receipt).

    Raises:
        ValueError: if ``msg['field']`` is not one of the fields above (previously this
            surfaced as an UnboundLocalError on ``function``).
    """
    field = msg['field']

    if field == 'ConfigRes':
        print("config complete")
        function = contract.functions.ConfigRes()

    elif field == 'FitRes':
        print("fit complete")
        parameters_prime, num_examples_train, _metrics = msg['data']

        model_hash = make_hash(parameters_prime)
        # the model must be in storage before its hash is published on-chain
        upload_model(s3, parameters_prime)

        function = contract.functions.FitRes(model_hash, num_examples_train)

    elif field == 'EvaluateRes':
        print("eval complete")
        evalres = msg['data']
        function = contract.functions.EvaluateRes(list(evalres.keys()), list(evalres.values()))

    else:
        raise ValueError(f"handle_send: unsupported message field {field!r}")

    return send_transaction(w3, contract, function)



def start_web3_client(client, contract_address, abi):
    """Run ``client``'s message loop against the contract at ``contract_address``.

    Each connection opens S3 (``s3_connection``) and an HTTP web3 provider
    (``HTTP_PROVIDER``), then feeds every non-None ServerMessage to
    ``handle_receive`` and sends any reply it produces. When ``handle_receive``
    reports ``keep_going=False``, the client shuts down if the returned sleep
    duration is 0, otherwise it sleeps that long and reconnects.
    """
    while True:
        s3 = s3_connection()
        provider = web3.HTTPProvider(HTTP_PROVIDER)
        w3 = web3.Web3(provider)
        contract = w3.eth.contract(address=contract_address, abi=abi)
        receive, send = web3_connection(s3, w3, contract)

        keep_going = True
        while keep_going:
            server_message = receive()
            print("server message : ", server_message)
            if server_message is None:
                continue
            client_message, sleep_duration, keep_going = handle_receive(
                client, server_message, s3, w3, contract
            )
            if client_message is not None:
                send(client_message)

        if sleep_duration == 0:
            print("Disconnect and shut down")
            return

        # wait, then reconnect with fresh S3/web3 handles
        print("Sleeping for {} seconds".format(sleep_duration))
        time.sleep(sleep_duration)



In [13]:
import random
from collections import OrderedDict

import flwr as fl
import torch
from datasets import load_dataset, load_metric
from torch.utils.data import DataLoader
from transformers import (
    AdamW,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_data(num_samples=10, batch_size=32, seed=None, model_name="albert-base-v2"):
    """Load a small random IMDB subset (training and eval) as DataLoaders.

    Args:
        num_samples: number of reviews kept from each of the train and test splits.
            The same shuffled indices are used for both splits.
        batch_size: batch size for both DataLoaders.
        seed: optional seed for drawing the subset indices; ``None`` keeps using the
            global ``random`` module state (the original behaviour).
        model_name: Hugging Face checkpoint whose tokenizer is used.

    Returns:
        ``(trainloader, testloader)``; only the train loader shuffles.
    """
    raw_datasets = load_dataset("imdb")

    # Remove the unused split first so it is not shuffled needlessly.
    del raw_datasets["unsupervised"]
    raw_datasets = raw_datasets.shuffle(seed=42)

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # The same indices select from both splits, so draw from the smaller range.
    rng = random if seed is None else random.Random(seed)
    split_size = min(len(raw_datasets["train"]), len(raw_datasets["test"]))
    population = rng.sample(range(split_size), num_samples)

    # Select before tokenizing so only the kept reviews are tokenized
    # (previously every review in both splits was tokenized, then discarded).
    for split in ("train", "test"):
        raw_datasets[split] = raw_datasets[split].select(population)

    tokenized_datasets = raw_datasets.map(
        lambda examples: tokenizer(examples["text"], truncation=True), batched=True
    )
    tokenized_datasets = tokenized_datasets.remove_columns("text")
    # the model's forward() expects the targets under "labels"
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    trainloader = DataLoader(
        tokenized_datasets["train"],
        shuffle=True,
        batch_size=batch_size,
        collate_fn=data_collator,
    )
    testloader = DataLoader(
        tokenized_datasets["test"], batch_size=batch_size, collate_fn=data_collator
    )

    return trainloader, testloader


In [14]:

def train(net, trainloader, epochs, learning_rate):
    """Fine-tune ``net`` in place for ``epochs`` passes over ``trainloader``.

    Uses an AdamW optimizer with ``lr=learning_rate``. Each batch dict is moved to
    DEVICE and passed as keyword arguments to ``net``; the ``.loss`` attribute of the
    model output is back-propagated.
    """
    opt = AdamW(net.parameters(), lr=learning_rate)
    net.train()
    for _epoch in range(epochs):
        for raw_batch in trainloader:
            inputs = {name: tensor.to(DEVICE) for name, tensor in raw_batch.items()}
            net(**inputs).loss.backward()
            opt.step()
            opt.zero_grad()


def test(net, testloader):
    metric = load_metric("accuracy")
    loss = 0
    net.eval()
    for batch in testloader:
        batch = {k: v.to(DEVICE) for k, v in batch.items()}
        with torch.no_grad():
            outputs = net(**batch)
        logits = outputs.logits
        loss += outputs.loss.item()
        predictions = torch.argmax(logits, dim=-1)
        metric.add_batch(predictions=predictions, references=batch["labels"])
    loss /= len(testloader.dataset)
    accuracy = metric.compute()["accuracy"]
    return loss, accuracy




In [15]:
class IMDBClient(fl.client.NumPyClient):
    """Flower NumPy client that fine-tunes an albert-base-v2 binary classifier on an IMDB subset.

    ``config`` starts as None and is filled by ``set_config``; it must be set before
    ``fit`` runs, which reads ``local_epochs`` and the learning rate from it.
    """

    def __init__(self) -> None:
        self.init_model()
        self.trainloader, self.testloader = load_data()
        self.config = None

    def init_model(self):
        """(Re)create a fresh albert-base-v2 sequence classifier (2 labels) on DEVICE."""
        self.net = AutoModelForSequenceClassification.from_pretrained(
            "albert-base-v2", num_labels=2
        ).to(DEVICE)

    def set_config(self, config):
        """Store the round configuration dict used by ``fit``."""
        self.config = config

    def get_parameters(self, config=None):
        """Return the model's state_dict values as NumPy arrays, in state_dict order.

        ``config`` is accepted (and ignored) for compatibility with Flower's
        ``get_parameters(config)`` signature.
        """
        return [val.cpu().numpy() for _, val in self.net.state_dict().items()]

    def set_parameters(self, parameters):
        """Load ``parameters`` (arrays in state_dict key order) into the model."""
        params_dict = zip(self.net.state_dict().keys(), parameters)
        # as_tensor keeps each array's dtype (e.g. integer buffers) instead of
        # coercing everything to float32 as torch.Tensor(...) does.
        state_dict = OrderedDict({k: torch.as_tensor(v) for k, v in params_dict})
        self.net.load_state_dict(state_dict, strict=True)

    def _learning_rate(self):
        """Learning rate from the config; accepts 'learning_rate' or the 'lr' key set by handle_receive."""
        if "learning_rate" in self.config:
            return float(self.config["learning_rate"])
        return float(self.config["lr"])

    def fit(self, parameters, config=None):
        """Train from ``parameters`` for ``config['local_epochs']`` epochs.

        Returns ``(updated_parameters, num_training_examples, {})``. ``config`` is
        accepted for Flower compatibility; training settings come from ``self.config``.
        """
        self.set_parameters(parameters)
        print("Training Started...")
        train(
            self.net,
            self.trainloader,
            epochs=self.config["local_epochs"],
            learning_rate=self._learning_rate(),
        )
        print("Training Finished.")
        print("Test Results : ", test(self.net, self.testloader))
        # Report examples, not batches: len(loader) is the number of batches.
        return self.get_parameters(), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config=None):
        """Evaluate ``parameters`` on the test loader.

        Returns ``(loss, num_test_examples, {"accuracy": accuracy})``.
        """
        self.set_parameters(parameters)
        loss, accuracy = test(self.net, self.testloader)
        return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)}

In [16]:
# Create the federated client. Its constructor loads the albert-base-v2 classifier
# (from_pretrained may download weights on first use) and builds the IMDB subset loaders.
client = IMDBClient()

Some weights of the model checkpoint at albert-base-v2 were not used when initializing AlbertForSequenceClassification: ['predictions.decoder.bias', 'predictions.bias', 'predictions.LayerNorm.bias', 'predictions.dense.bias', 'predictions.LayerNorm.weight', 'predictions.dense.weight', 'predictions.decoder.weight']
- This IS expected if you are initializing AlbertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing AlbertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of AlbertForSequenceClassification were not initialized from the model checkpoint at albert-base-v2 and are newly initialized: ['classifier.weight', 'classifier.bias']
You sho

In [17]:
# Load the FugaController ABI here: the cell that used to define `abi` is commented out,
# so on a fresh kernel `abi` would be undefined (or be the Test.json ABI from a later cell).
with open('contract/build/contracts/FugaController.json', 'r') as abi_file:
    abi = json.load(abi_file)['abi']

start_web3_client(client, '0xbB8977D0604ABe3FA3d368e15D96426Ed91a06C0', abi)

s3 bucket connected!
server message :  AttributeDict({'field': 'JoinRound'})
call start round
server message :  AttributeDict({'field': 'ConfigIns'})
config start


  f"The log with transaction hash: {log['transactionHash']!r} and "


AttributeDict({'self_centered': True, 'batch_size': 1, 'learning_rate': '0.00005', 'local_epochs': 1, 'val_steps': 5})
config complete


KeyboardInterrupt: 

In [None]:
import os

# Credentials come from the process environment; fail fast with a clear message
# instead of a cryptic error from from_mnemonic(None) or an invalid Infura URL.
INFURA_API_KEY = os.environ.get('INFURA_API_KEY')
MNEMONIC = os.environ.get('MNEMONIC')
_missing_env = [
    name for name, value in (('INFURA_API_KEY', INFURA_API_KEY), ('MNEMONIC', MNEMONIC))
    if not value
]
if _missing_env:
    raise RuntimeError(f"Missing environment variables: {', '.join(_missing_env)}")

with open('contract/build/contracts/Test.json', 'r') as abi_file:
    abi = json.load(abi_file)['abi']

# Derive the signing account from the mnemonic; eth-account requires this explicit
# opt-in before from_mnemonic can be used.
w3 = web3.Web3()
w3.eth.account.enable_unaudited_hdwallet_features()
account = w3.eth.account.from_mnemonic(MNEMONIC, account_path="m/44'/60'/0'/0/0")

# Connect to Sepolia through Infura and bind the Test contract.
w3 = web3.Web3(web3.HTTPProvider(f'https://sepolia.infura.io/v3/{INFURA_API_KEY}'))
contract = w3.eth.contract(address='0x8309F6a4BBb56caEC4c197Efd00287e748017162', abi=abi)

In [None]:
# Build, sign and send a setScore(1) transaction from `account`, then show its receipt.
args = [1]
function = contract.functions.setScore(*args)

# Pass 'from' and 'nonce' up front so buildTransaction's own gas estimate (and the
# explicit one below) runs as `account`. Previously both estimates ran with no sender,
# which misestimates or reverts for functions that depend on msg.sender.
# buildTransaction fills in 'to' (the contract address) itself.
transaction = function.buildTransaction({
    'from': account.address,
    'nonce': w3.eth.getTransactionCount(account.address),
})

gas_estimate = w3.eth.estimateGas(transaction)
transaction['gas'] = gas_estimate

signed_transaction = w3.eth.account.signTransaction(transaction, account.key)
transaction_hash = w3.eth.sendRawTransaction(signed_transaction.rawTransaction)

# Block until the transaction is mined.
transaction_receipt = w3.eth.waitForTransactionReceipt(transaction_hash)

transaction_receipt

AttributeDict({'blockHash': HexBytes('0xc7fab322a8e6b67850bebec4fa7a406bc1df1827582458463f539ff3e0f5f7e5'),
 'blockNumber': 3295748,
 'contractAddress': None,
 'cumulativeGasUsed': 1024850,
 'effectiveGasPrice': 1000000008,
 'from': '0x08cA4DCa530A7c02d7755455AeB24B63A41C8608',
 'gasUsed': 23802,
 'logs': [],
 'logsBloom': HexBytes('0x00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000'),
 'status': 1,
 'to': '0x8309F6a4BBb56caEC4c197Efd00287e748017162',
 'transactionHash': HexBytes('0x99c487885d648391a653ced4ecf2d21157c40349967e2094

In [None]:
# NOTE(review): `.call()` executes setScore as a read-only eth_call simulation. Unlike the
# signed transaction in the cell above, it does not change on-chain state; the displayed
# value is just the simulated return. Confirm a dry run is what is intended here.
contract.functions.setScore(2).call()

[]