In [81]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np
import torchvision
from collections import defaultdict
from tqdm import tqdm
from torch.autograd import Variable
from torch.autograd import grad, backward
from torchvision import datasets , transforms
from torchvision.transforms import Compose
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split
from tqdm import tqdm

In [82]:
class SimpleNetwork(nn.Module):
    """Small convolutional classifier: two conv + max-pool stages, then three linear layers.

    ``fc1`` takes 16 * 5 * 5 flattened features, which is what the two
    5x5-conv / 2x2-pool stages produce for a 3-channel 32x32 input
    (32 -> 28 -> 14 -> 10 -> 5). The last layer emits 10 raw scores per sample.
    """

    def __init__(self):
        super().__init__()

        # Convolutional feature extractor (the single pooling layer is shared by both stages).
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)

        # Fully connected head.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return the 10 output scores for each sample in the batch ``x``."""
        features = x
        for conv in (self.conv1, self.conv2):
            features = self.pool(F.relu(conv(features)))

        # Flatten every dimension except the batch dimension.
        hidden = features.flatten(start_dim=1)
        for linear in (self.fc1, self.fc2):
            hidden = F.relu(linear(hidden))

        # No activation on the last layer: raw scores are returned.
        return self.fc3(hidden)


In [83]:
# Locations of the dataset root and of the directory the trained model is saved to.
data_path = '/workspace/nvflare/FedProx/Data'
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# The name ``transforms`` is rebound from the imported module to the composed
# pipeline (later cells pass it as ``transform=``). The individual transforms
# are therefore looked up on ``torchvision.transforms`` so that re-running this
# cell keeps working instead of failing on ``Compose.ToTensor``.
# NOTE(review): 0.1307 / 0.3081 appear to be the usual single-channel MNIST
# statistics, while the dataset loaded below is CIFAR10 -- confirm this is intended.
transforms = Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])
BATCH_SIZE = 4
EPOCH = 10
LR = 1e-2
model_pth = '/workspace/nvflare/FedProx/model/'
model_name= 'model_second_half_tune_param.pth'

In [84]:
# CIFAR10 training split; downloaded into data_path when not already present.
ds = datasets.CIFAR10(
    root=data_path,
    train=True,
    download=True,
    transform=transforms,
)

Files already downloaded and verified


In [85]:
# Index lists for the two halves of ds, split at len(ds) // 2.
first_half = list(range(len(ds) // 2))
second_half = list(range(len(ds)))[len(ds) // 2:]

In [86]:
# Keep only the second half of the dataset. The guard makes the cell safe to
# re-run: without it a second execution would wrap the already-reduced Subset
# again, using indices that are out of range for it.
if not isinstance(ds, Subset):
    ds = Subset(ds, second_half)

In [87]:
# Stratified train/validation split of ds (15% held out for validation).
indices = np.arange(len(ds))
# Label of every sample (second element of each ds item), used for stratification.
y = [ds[i][1] for i in range(len(ds))]
train_indices, val_indices = train_test_split(indices, test_size=0.15, stratify=y)
# The indices are positions within ds, so the subsets must wrap ds itself
# (the previous code referenced `dataset`, a name this notebook never defines).
train_ds = Subset(ds, train_indices)
val_ds = Subset(ds, val_indices)
train_data_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
# Validation does not need shuffling; a fixed order keeps evaluation deterministic.
val_data_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
n_iterations = len(train_data_loader)

In [88]:
# Network, loss and optimizer used by the training cells below.
model = SimpleNetwork()
model = model.to(DEVICE)
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(DEVICE)
optimizer = torch.optim.SGD(params=model.parameters(), lr=LR)


In [89]:
def train_model(model,
                data_loader,
                loss_fn,
                opimizer,
                n_example):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        model: network to train; it is switched to training mode.
        data_loader: iterable yielding ``(data, target)`` batches.
        loss_fn: callable ``loss_fn(outputs, targets)`` returning a scalar loss tensor.
        opimizer: optimizer that is zeroed and stepped after every batch. The
            misspelt parameter name is kept so existing callers keep working.
        n_example: denominator used for the accuracy.

    Returns:
        ``(accuracy, mean_loss)``: the number of correct predictions divided by
        ``n_example`` (a 0-dim tensor once at least one batch was processed)
        and the mean of the per-batch loss values.
    """
    # Use the optimizer that was passed in; the body previously ignored the
    # argument and relied on a notebook-level `optimizer` variable instead.
    optimizer = opimizer
    # Run the batches on whatever device the model's parameters live on.
    device = next(model.parameters()).device

    model = model.train()
    losses = []
    correct = 0
    for data, target in data_loader:
        img = data.to(device)
        targets = target.to(device)

        out1 = model(img)
        # Predicted class = index of the largest output score.
        _, pred = torch.max(out1, dim=1)
        correct += torch.sum(pred == targets)

        loss1 = loss_fn(out1, targets)
        losses.append(loss1.item())

        optimizer.zero_grad()
        loss1.backward()
        optimizer.step()
    return correct / n_example, np.mean(losses)

In [90]:
 def val_model(model,
               data_loader,
               loss_fn,
               n_example):
     """Evaluate ``model`` on ``data_loader`` without updating its weights.

     Args:
         model: network to evaluate; it is switched to evaluation mode.
         data_loader: iterable yielding ``(data, target)`` batches.
         loss_fn: callable ``loss_fn(outputs, targets)`` returning a scalar loss tensor.
         n_example: denominator used for the accuracy.

     Returns:
         ``(accuracy, mean_loss)``: the number of correct predictions divided by
         ``n_example`` (a 0-dim tensor once at least one batch was processed)
         and the mean of the per-batch loss values.
     """
     # Run the batches on whatever device the model's parameters live on.
     device = next(model.parameters()).device

     # Evaluation mode (the code previously called model.train() here), so that
     # mode-dependent layers such as dropout behave deterministically.
     model = model.eval()
     losses = []
     correct = 0
     with torch.no_grad():
         for data, target in data_loader:
             img = data.to(device)
             targets = target.to(device)

             out1 = model(img)
             # Predicted class = index of the largest output score.
             _, pred = torch.max(out1, dim=1)
             correct += torch.sum(pred == targets)

             loss1 = loss_fn(out1, targets)
             losses.append(loss1.item())

     return correct / n_example, np.mean(losses)


In [91]:
# Train for EPOCH epochs; after each epoch evaluate on the validation set and
# save the model whenever the validation accuracy beats the best seen so far.
bestacc = -1
n = 0
for epoch in tqdm(range(EPOCH)):
    print(f'Epoch {epoch + 1} / {EPOCH}')
    print("=" * 15)

    # Training pass.
    train_acc, train_loss = train_model(
        model, train_data_loader, loss_fn, optimizer, len(train_ds)
    )
    print(f'Train Accuracy : {train_acc} Train Loss : {train_loss}')

    # Validation pass.
    val_acc, val_loss = val_model(
        model, val_data_loader, loss_fn, len(val_ds)
    )
    print(f'Val Accuracy : {val_acc} Val Loss : {val_loss}')

    # Checkpoint on improvement (the whole model object is saved).
    if val_acc > bestacc:
        bestacc = val_acc
        torch.save(model, model_pth + model_name)

  0%|                                                                                            | 0/10 [00:00<?, ?it/s]

Epoch 1 / 10
Train Accuracy : 0.2691764831542969 Train Loss : 1.9668471538681735


 10%|████████▍                                                                           | 1/10 [00:18<02:50, 18.99s/it]

Val Accuracy : 0.2837333381175995 Val Loss : 2.081955168960191
Epoch 2 / 10
Train Accuracy : 0.4217882454395294 Train Loss : 1.59112768908151


 20%|████████████████▊                                                                   | 2/10 [00:38<02:32, 19.04s/it]

Val Accuracy : 0.4477333426475525 Val Loss : 1.555730192995529
Epoch 3 / 10
Train Accuracy : 0.474823534488678 Train Loss : 1.4549611197557015


 30%|█████████████████████████▏                                                          | 3/10 [00:57<02:13, 19.05s/it]

Val Accuracy : 0.48399999737739563 Val Loss : 1.4383855926424964
Epoch 4 / 10
Train Accuracy : 0.5142588019371033 Train Loss : 1.3612784305445425


 40%|█████████████████████████████████▌                                                  | 4/10 [01:16<01:53, 18.99s/it]

Val Accuracy : 0.5058666467666626 Val Loss : 1.3963114606387326
Epoch 5 / 10
Train Accuracy : 0.5389176607131958 Train Loss : 1.2898450201819358


 50%|██████████████████████████████████████████                                          | 5/10 [01:34<01:34, 18.95s/it]

Val Accuracy : 0.525866687297821 Val Loss : 1.3519710938511753
Epoch 6 / 10
Train Accuracy : 0.5607059001922607 Train Loss : 1.22614872843201


 60%|██████████████████████████████████████████████████▍                                 | 6/10 [01:53<01:15, 18.97s/it]

Val Accuracy : 0.4898666739463806 Val Loss : 1.484388658359869
Epoch 7 / 10
Train Accuracy : 0.580752968788147 Train Loss : 1.17815859408328


 70%|██████████████████████████████████████████████████████████▊                         | 7/10 [02:12<00:56, 18.95s/it]

Val Accuracy : 0.5112000107765198 Val Loss : 1.3999307463799457
Epoch 8 / 10
Train Accuracy : 0.6053176522254944 Train Loss : 1.1137969847547677


 80%|███████████████████████████████████████████████████████████████████▏                | 8/10 [02:31<00:37, 18.95s/it]

Val Accuracy : 0.4994666576385498 Val Loss : 1.5617254828053242
Epoch 9 / 10
Train Accuracy : 0.619105875492096 Train Loss : 1.0695342580356095


 90%|███████████████████████████████████████████████████████████████████████████▌        | 9/10 [02:50<00:18, 18.97s/it]

Val Accuracy : 0.5095999836921692 Val Loss : 1.4733933130545276
Epoch 10 / 10
Train Accuracy : 0.6332706212997437 Train Loss : 1.0288435523747927


100%|███████████████████████████████████████████████████████████████████████████████████| 10/10 [03:09<00:00, 18.97s/it]

Val Accuracy : 0.5226666927337646 Val Loss : 1.4532176851670244





In [92]:
# model_pth = '/workspace/nvflare/FedProx/model/'
# model_name= 'model_second_half_tune_param.pth'
# torch.save(model,model_pth + model_name)