# ResNet Classification on CIFAR-10

In this example, we show how to use a torchvision model to run a federated classification task.

## Dataset: CIFAR-10

You can download the CIFAR-10 dataset through this link: 
- https://webank-ai-1251170195.cos.ap-guangzhou.myqcloud.com/fate/examples/data/cifar-10.zip
  
The original CIFAR-10 dataset is available from:
- https://www.cs.toronto.edu/~kriz/cifar.html
  
For convenience of demonstration, all clients use the same dataset.

In [7]:
# Show the kernel's current working directory.
# NOTE(review): the recorded execution count (In [7]) suggests this ran after
# the os.chdir cell below; on a fresh kernel the output would be the directory
# Jupyter was launched from, not the path shown.
pwd

'/data/projects/fate/persistence/fate'

In [2]:
# Make the FATE client and FATE python packages importable from this kernel,
# and run the rest of the notebook from the FATE project root.
import os
import sys

FATE_PROJECT_BASE = '/data/projects/fate/persistence/fate'

# Guard the appends so re-running this cell does not keep growing sys.path.
for _extra_path in (os.path.join(FATE_PROJECT_BASE, 'python', 'fate_client'),
                    os.path.join(FATE_PROJECT_BASE, 'python')):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

os.chdir(FATE_PROJECT_BASE)
os.environ['FATE_PROJECT_BASE'] = FATE_PROJECT_BASE

## Local Test

First, we test our model and dataset locally. If that works, we can submit a federated task.

In [3]:
from pipeline.component.nn import save_to_fate

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
%%save_to_fate model resnet.py

# model
import torch as t
from torch import nn
from torchvision.models import resnet18, ResNet18_Weights

class Resnet(nn.Module):
    """ResNet-18 backbone followed by a 10-way linear head.

    The torchvision ``resnet18()`` backbone emits 1000 features, which
    ``classifier`` maps to 10 outputs.

    In training mode ``forward`` returns the raw logits. In eval mode it
    returns a softmax over dim 1 of those logits instead.
    NOTE(review): if validation loss is computed with CrossEntropyLoss while
    the model is in eval mode, it will receive probabilities rather than
    logits — confirm this is intended.
    """

    def __init__(self):
        super().__init__()
        # Submodules are registered in the same order as before so that
        # state_dict keys and the printed module tree are unchanged.
        self.resnet = resnet18()
        self.classifier = t.nn.Linear(1000, 10)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        logits = self.classifier(self.resnet(x))
        if not self.training:
            return self.softmax(logits)
        return logits
    

In [5]:
# Build the locally defined model and print its layer structure.
model = Resnet()
print(repr(model))

Resnet(
  (resnet): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runni

In [8]:
# Load the CIFAR-10 training images with FATE's ImageDataset.
from federatedml.nn.dataset.image import ImageDataset

cifar10_train_dir = '/data/projects/fate/persistence/fate/examples/data/cifar10/train/'
ds = ImageDataset()
ds.load(cifar10_train_dir)

In [9]:
# Element 0 of a sample is the image tensor (presumably element 1 is the label).
first_sample = ds[0]
first_sample[0].shape

torch.Size([3, 32, 32])

In [13]:
# Local sanity check: run one epoch of the FedAvg trainer in local mode on
# the dataset loaded above, with no validation set.
import torch as t
from federatedml.nn.homo.trainer.fedavg_trainer import FedAVGTrainer

trainer = FedAVGTrainer(
    epochs=1,
    batch_size=512,
    data_loader_worker=2,
)
trainer.set_model(model)

optimizer = t.optim.Adam(model.parameters(), lr=0.001)
loss = t.nn.CrossEntropyLoss()

trainer.local_mode()  # switch the trainer to local (non-federated) mode
trainer.train(ds, None, optimizer, loss)  # None: no validation set

ipcl_python failed to import
epoch is 0
100%|██████████| 98/98 [00:35<00:00,  2.76it/s]
epoch loss is 1.4889685620117188


## Submit a federated task

In [28]:
import os

import torch as t
from torch import nn
from pipeline import fate_torch_hook
from pipeline.component import HomoNN
from pipeline.backend.pipeline import PipeLine
from pipeline.component import Reader, Evaluation, DataTransform
from pipeline.interface import Data, Model

# Patch torch so FATE-specific constructs used later (t.nn.CustModel, an
# optimizer built without parameters) can be passed to pipeline components.
fate_torch_hook(t)

fate_project_path = os.path.abspath('/data/projects/fate')
guest = 9999
host = 10000
# Two-party job: the host party also acts as the arbiter.
pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(
    guest=guest, host=host, arbiter=host)

data_0 = {"name": "cifar10", "namespace": "experiment"}
# NOTE(review): this resolves to /data/projects/fate/examples/..., whereas the
# local test loaded /data/projects/fate/persistence/fate/examples/...; confirm
# which directory the FATE Flow server can actually read.
data_path = fate_project_path + '/examples/data/cifar10/train'
# A single bind is enough: guest and host readers both reference this table.
pipeline.bind_table(name=data_0['name'], namespace=data_0['namespace'], path=data_path)

{'namespace': 'experiment', 'table_name': 'cifar10'}

In [29]:
def _make_reader(name):
    """Return a Reader named `name` that reads the `data_0` table on guest and host."""
    reader = Reader(name=name)
    for role, party_id in (('guest', guest), ('host', host)):
        reader.get_party_instance(role=role, party_id=party_id).component_param(table=data_0)
    return reader


# Both readers load the same table; reader_0 is used as training data and
# reader_1 as validation data when the HomoNN component is added.
reader_0 = _make_reader("reader_0")
reader_1 = _make_reader("reader_1")

In [30]:
from pipeline.component.homo_nn import DatasetParam, TrainerParam

# Reference the Resnet class saved to resnet.py by %%save_to_fate, so each
# party can instantiate it from module 'resnet'.
model = t.nn.Sequential(
    t.nn.CustModel(module_name='resnet', class_name='Resnet')
)

nn_loss = t.nn.CrossEntropyLoss()
nn_optimizer = t.optim.Adam(lr=0.001, weight_decay=0.001)
# Use the dataset registered as 'image' (presumably FATE's built-in
# federatedml.nn.dataset.image.ImageDataset).
nn_dataset = DatasetParam(dataset_name='image')
nn_trainer = TrainerParam(trainer_name='fedavg_trainer', epochs=1, batch_size=512,
                          data_loader_worker=4)

nn_component = HomoNN(
    name='nn_0',
    model=model,
    loss=nn_loss,
    optimizer=nn_optimizer,
    dataset=nn_dataset,
    trainer=nn_trainer,
    torch_seed=100,
)

In [31]:
# Wire the DAG: two readers -> HomoNN (train/validate) -> multi-class evaluation.
pipeline.add_component(reader_0)
pipeline.add_component(reader_1)

nn_inputs = Data(train_data=reader_0.output.data, validate_data=reader_1.output.data)
pipeline.add_component(nn_component, data=nn_inputs)

eval_0 = Evaluation(name='eval_0', eval_type='multi')
pipeline.add_component(eval_0, data=Data(data=nn_component.output.data))

<pipeline.backend.pipeline.PipeLine at 0x7fadff1889d0>

In [32]:
# Compile the components added above into a job before submitting it.
pipeline.compile()

<pipeline.backend.pipeline.PipeLine at 0x7fadff1889d0>

In [33]:
# Submit the compiled job and monitor its status until it finishes.
# NOTE(review): the saved run failed while running reader_0 (see output).
# One possible cause is the bound data_path — TODO confirm via FATE Board or
# fate_flow logs before re-running.
pipeline.fit() # submit pipeline here

[32m2023-02-20 07:43:36.618[0m | [1mINFO    [0m | [36mpipeline.utils.invoker.job_submitter[0m:[36mmonitor_job_status[0m:[36m83[0m - [1mJob id is 202302200743362003410
[0m
[32m2023-02-20 07:43:36.627[0m | [1mINFO    [0m | [36mpipeline.utils.invoker.job_submitter[0m:[36mmonitor_job_status[0m:[36m98[0m - [1m[80D[1A[KJob is still waiting, time elapse: 0:00:00[0m
[0mm2023-02-20 07:43:37.645[0m | [1mINFO    [0m | [36mpipeline.utils.invoker.job_submitter[0m:[36mmonitor_job_status[0m:[36m125[0m - [1m
[32m2023-02-20 07:43:37.646[0m | [1mINFO    [0m | [36mpipeline.utils.invoker.job_submitter[0m:[36mmonitor_job_status[0m:[36m127[0m - [1m[80D[1A[KRunning component reader_0, time elapse: 0:00:01[0m
[32m2023-02-20 07:43:38.667[0m | [1mINFO    [0m | [36mpipeline.utils.invoker.job_submitter[0m:[36mmonitor_job_status[0m:[36m127[0m - [1m[80D[1A[KRunning component reader_0, time elapse: 0:00:02[0m
[32m2023-02-20 07:43:39.684[0m | [1mI

ValueError: Job is failed, please check out job 202302200743362003410 by fate board or fate_flow cli