In [1]:
import sys
# Replace the process arguments with a single empty string before the project modules are imported.
# NOTE(review): presumably this keeps argument parsing inside `config` from seeing the
# Jupyter kernel's own command-line arguments — confirm against config.py.
sys.argv = ['']

In [2]:
import os
import models
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from config import cfg
from data import fetch_dataset, make_data_loader, split_dataset, SplitDataset
from utils import save, to_device, process_control, process_dataset, make_optimizer, make_scheduler, resume, collate
from logger import Logger
from metrics import Metric
import time
import datetime
import shutil

from masking_functions import SNIP

import train_classifier

In [19]:
# Let the project resolve its control settings first; the assignments below then
# override individual entries for this notebook run.
process_control()
cfg['data_name'] = 'CIFAR10'
cfg['model_name'] = 'conv'
cfg['data_split_mode'] = 'non-iid-2'
cfg['num_users'] = 100
cfg['num_epochs'] = 100

# One seed per experiment, counting up from the configured initial seed.
seeds = [cfg['init_seed'] + offset for offset in range(cfg['num_experiments'])]

# The model tag is the underscore-joined list of the non-empty components below.
# NOTE(review): cfg['control_name'] is read unchanged here, and the tag in the recorded
# training log still contains "iid" although data_split_mode is overridden to
# 'non-iid-2' above — confirm whether the tag should reflect the overrides.
model_tag_list = [str(seeds[0]), cfg['data_name'], cfg['subset'], cfg['model_name'], cfg['control_name']]
cfg['model_tag'] = '_'.join(filter(None, model_tag_list))

In [4]:
# Display the configured device, data split mode and number of users.
print(cfg['device'])
print(cfg["data_split_mode"])
print(cfg["num_users"])

cuda
non-iid-2
100


In [5]:
# Load the dataset named in the config, then pass it to process_dataset().
# NOTE(review): the return value of process_dataset() is discarded, so it presumably
# works by side effect (e.g. recording dataset properties in cfg) — confirm in utils.
dataset = fetch_dataset(cfg['data_name'], cfg['subset'])
process_dataset(dataset)

fetching data CIFAR10...
data ready


In [6]:
# Display the repr of the training split (size, root, transforms).
dataset['train']

Dataset CIFAR10
Size: 50000
Root: ./data/CIFAR10
Split: train
Subset: label
Transforms: Compose(
    RandomCrop(size=(32, 32), padding=4)
    RandomHorizontalFlip(p=0.5)
    ToTensor()
    Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.201))
)

In [7]:
# Build the global model by looking its constructor up by name on the `models` module.
# This replaces eval() on a formatted source string (which would execute whatever text
# ends up in cfg['model_name']) and drops the second, redundant .to(cfg["device"]) call.
model = getattr(models, cfg['model_name'])(model_rate=cfg["global_model_rate"], track=True).to(cfg["device"])
# Optimizer and learning-rate scheduler come from the project's factory helpers.
optimizer = make_optimizer(model, cfg['lr'])
scheduler = make_scheduler(optimizer)

In [8]:
# Print the module hierarchy of the freshly built model.
print(model)

Conv(
  (blocks): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Scaler()
    (2): BatchNorm2d(64, eps=1e-05, momentum=None, affine=True, track_running_stats=True)
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): Scaler()
    (7): BatchNorm2d(128, eps=1e-05, momentum=None, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): Scaler()
    (12): BatchNorm2d(256, eps=1e-05, momentum=None, affine=True, track_running_stats=True)
    (13): ReLU(inplace=True)
    (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (15): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)

In [9]:
# Start without a split; the following cell computes one when this is None.
data_split = None

In [10]:
# Compute the per-user split only when none is present.
# NOTE(review): label_split is bound only inside this branch, yet later cells read it;
# if data_split were ever preset, label_split would have to be provided as well.
if data_split is None:
    data_split, label_split = split_dataset(dataset, cfg['num_users'], cfg['data_split_mode'])

In [11]:
# Inspect entry 0 of the training split (its length and first element) and entry 0 of label_split.
print(len(data_split['train'][0]))
print(data_split['train'][0][0])
print(label_split[0])

500
39947
[4, 7]


In [12]:
# Build the train and test loaders from entry 0 of the respective splits only.
# NOTE(review): presumably entry 0 corresponds to the first of the cfg['num_users'] users,
# so all training/testing below runs on that single user's data — confirm against split_dataset.
data_loader = make_data_loader({'train': SplitDataset(dataset['train'], data_split['train'][0])})['train']
test_loader = make_data_loader({'test': SplitDataset(dataset['test'], data_split['test'][0])})['test']

In [13]:
# Compute the pruning masks with SNIP using the training loader.
# NOTE(review): 0.05 is presumably the fraction of weights to keep — confirm in
# masking_functions.SNIP and consider giving it a named constant.
keep_mask = SNIP(model, 0.05, data_loader, cfg['device'])

tensor(77756, device='cuda:0')


In [14]:
# Check the type of the first mask returned by SNIP.
print(type(keep_mask[0]))

<class 'torch.Tensor'>


In [15]:
def apply_prune_mask(net, keep_masks):
    """Zero the pruned weights of ``net`` and mask their gradients.

    Every ``nn.Conv2d`` / ``nn.Linear`` module yielded by ``net.modules()`` is
    paired, in order, with one entry of ``keep_masks``. For each pair, the
    weights where the mask equals 0 are set to 0 in place, and a gradient hook
    is registered on the weight that multiplies incoming gradients by the mask.
    Biases are not touched. Each call registers one additional hook per layer.

    Args:
        net: module whose Conv2d/Linear layers are pruned in place.
        keep_masks: iterable of tensors, one per prunable layer, each with the
            same shape as that layer's weight (0 = prune, 1 = keep).

    Raises:
        ValueError: if the number of masks differs from the number of
            prunable layers.
        AssertionError: if a mask's shape differs from its layer's weight shape.
    """
    # Keep only the modules that carry a prunable weight so they line up
    # one-to-one with the masks.
    prunable_layers = [layer for layer in net.modules()
                       if isinstance(layer, (nn.Conv2d, nn.Linear))]
    keep_masks = list(keep_masks)

    # zip() stops at the shorter sequence, which would silently leave some
    # layers unpruned; fail loudly instead.
    if len(prunable_layers) != len(keep_masks):
        raise ValueError(
            'Expected one mask per prunable layer, got {} masks for {} layers'.format(
                len(keep_masks), len(prunable_layers)))

    def hook_factory(mask):
        # The hook is created through this factory so that each hook captures
        # its own mask; a closure defined directly in the loop would bind late
        # and every hook would see the last mask.
        def hook(grads):
            return grads * mask

        return hook

    for layer, keep_mask in zip(prunable_layers, keep_masks):
        # Explicit raise (rather than `assert`) so the check survives `python -O`.
        if layer.weight.shape != keep_mask.shape:
            raise AssertionError(
                'Mask shape {} does not match weight shape {}'.format(
                    tuple(keep_mask.shape), tuple(layer.weight.shape)))

        # mask[i] == 0 --> prune parameter, mask[i] == 1 --> keep parameter.
        # Step 1: set the masked weights to zero without recording the
        #         operation in autograd.
        with torch.no_grad():
            layer.weight[keep_mask == 0.] = 0.
        # Step 2: mask the gradients of those weights on every backward pass.
        layer.weight.register_hook(hook_factory(keep_mask))

In [16]:
# Prune `model` in place with the masks computed by SNIP in the earlier cell.
apply_prune_mask(model, keep_mask)

In [17]:
def train(data_loader, model, optimizer, logger, epoch):
    """Run one training epoch over ``data_loader`` and log progress.

    Each batch is collated, moved to ``cfg['device']``, passed through the model,
    back-propagated, gradient-norm clipped to 1 and applied with ``optimizer``.
    The training metrics of every batch are appended to ``logger``; every
    ``int(len(data_loader) * cfg['log_interval'] + 1)`` batches an info entry
    (model tag, progress, learning rate, time estimates) is appended and written.

    Args:
        data_loader: iterable of batches; each collated batch must expose ``['img']``.
        model: module returning a dict that contains ``'loss'``.
        optimizer: optimizer updating ``model``'s parameters.
        logger: project logger receiving metric and info entries.
        epoch: current epoch number, used for display and time estimates.

    Returns:
        None.
    """
    metric = Metric()
    model.train(True)
    start_time = time.time()
    for step, batch in enumerate(data_loader):
        batch = collate(batch)
        batch_size = batch['img'].size(0)
        batch = to_device(batch, cfg['device'])

        # Forward / backward / update.
        optimizer.zero_grad()
        output = model(batch)
        if cfg['world_size'] > 1:
            # Reduce the loss to its mean when more than one device is configured.
            output['loss'] = output['loss'].mean()
        output['loss'].backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()

        evaluation = metric.evaluate(cfg['metric_name']['train'], batch, output)
        logger.append(evaluation, 'train', n=batch_size)

        # Only emit a progress entry on the configured interval.
        num_batches = len(data_loader)
        if step % int((num_batches * cfg['log_interval']) + 1) != 0:
            continue

        # Average time per batch so far drives both remaining-time estimates.
        batch_time = (time.time() - start_time) / (step + 1)
        lr = optimizer.param_groups[0]['lr']
        epoch_finished_time = datetime.timedelta(seconds=round(batch_time * (num_batches - step - 1)))
        remaining_epochs_time = datetime.timedelta(
            seconds=round((cfg['num_epochs'] - epoch) * batch_time * num_batches))
        exp_finished_time = epoch_finished_time + remaining_epochs_time
        info_lines = [
            'Model: {}'.format(cfg['model_tag']),
            'Train Epoch: {}({:.0f}%)'.format(epoch, 100. * step / num_batches),
            'Learning rate: {}'.format(lr),
            'Epoch Finished Time: {}'.format(epoch_finished_time),
            'Experiment Finished Time: {}'.format(exp_finished_time),
        ]
        logger.append({'info': info_lines}, 'train', mean=False)
        logger.write('train', cfg['metric_name']['train'])
    return
def stats(data_loader, model):
    """Return a fresh copy of ``model`` after train-mode forward passes over ``data_loader``.

    A new model of type ``cfg['model_name']`` is built with ``track=True``, loaded
    with ``model``'s state dict (``strict=False``), switched to train mode and run
    over every batch under ``torch.no_grad()``; no parameters are updated.
    NOTE(review): presumably the purpose is to accumulate BatchNorm running
    statistics on this data before evaluation — confirm against the model code.

    Args:
        data_loader: iterable of batches to run through the copy.
        model: trained model whose state dict is copied.

    Returns:
        The newly built model (left in train mode).
    """
    with torch.no_grad():
        # Resolve the constructor by attribute lookup rather than eval() on a
        # formatted source string.
        model_cls = getattr(models, cfg['model_name'])
        test_model = model_cls(model_rate=cfg["global_model_rate"], track=True).to(cfg["device"])
        test_model.load_state_dict(model.state_dict(), strict=False)
        test_model.train(True)
        for batch in data_loader:
            batch = collate(batch)
            batch = to_device(batch, cfg['device'])
            # Output is discarded; only the forward pass's side effects matter here.
            test_model(batch)
    return test_model

def test(data_loader, model, logger, epoch):
    """Evaluate ``model`` on ``data_loader`` and write the test metrics to ``logger``.

    The model is put in evaluation mode and every batch is processed under
    ``torch.no_grad()``. Per-batch metrics are appended to ``logger`` weighted by
    batch size; afterwards one info entry is appended and the 'test' log is written.

    Args:
        data_loader: iterable of batches; each collated batch must expose ``['img']``.
        model: module returning a dict that contains ``'loss'``.
        logger: project logger receiving metric and info entries.
        epoch: epoch number shown in the info entry.

    Returns:
        None.
    """
    with torch.no_grad():
        metric = Metric()
        model.train(False)
        for batch in data_loader:
            batch = collate(batch)
            batch_size = batch['img'].size(0)
            batch = to_device(batch, cfg['device'])
            output = model(batch)
            if cfg['world_size'] > 1:
                # Reduce the loss to its mean when more than one device is configured.
                output['loss'] = output['loss'].mean()
            evaluation = metric.evaluate(cfg['metric_name']['test'], batch, output)
            logger.append(evaluation, 'test', batch_size)
        summary_lines = [
            'Model: {}'.format(cfg['model_tag']),
            'Test Epoch: {}({:.0f}%)'.format(epoch, 100.),
        ]
        logger.append({'info': summary_lines}, 'test', mean=False)
        logger.write('test', cfg['metric_name']['test'])
    return

In [20]:
# Train the pruned model for cfg['num_epochs'] epochs, logging under
# output/runs/train_<model_tag>. After each epoch a fresh copy of the model is run
# over the training loader by stats() and that copy is evaluated on the test loader.
# NOTE(review): `scheduler` is created in an earlier cell but is never stepped in this
# loop; the recorded output shows the learning rate still at 0.1 in epoch 96 —
# confirm that a constant learning rate is intended.
# NOTE(review): the logged train/test metrics change very smoothly from epoch to epoch,
# which suggests the logger accumulates across epochs — check whether it should be
# reset at the end of each epoch.
last_epoch = 1
logger_path = os.path.join('output', 'runs', 'train_{}'.format(cfg['model_tag']))
logger = Logger(logger_path)
for epoch in range(last_epoch, cfg['num_epochs'] + 1):
    logger.safe(True)
    train(data_loader, model, optimizer, logger, epoch)
    test_model = stats(data_loader, model)
    test(test_loader, test_model, logger, epoch)

Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 1(0%)  Loss: 0.4828  Accuracy: 90.0000  Learning rate: 0.1  Epoch Finished Time: 0:00:01  Experiment Finished Time: 0:02:09
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 1(26%)  Loss: 0.4773  Accuracy: 76.4286  Learning rate: 0.1  Epoch Finished Time: 0:00:01  Experiment Finished Time: 0:01:15
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 1(52%)  Loss: 0.4982  Accuracy: 75.5556  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:01:13
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 1(78%)  Loss: 0.5347  Accuracy: 74.2500  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:01:09
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Test Epoch: 1(100%)  Loss: 1.1498  Accuracy: 59.0000
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 2(0%)  Loss: 0.5292  Accuracy: 74.7059  L

Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 10(52%)  Loss: 0.4880  Accuracy: 77.8826  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:54
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 10(78%)  Loss: 0.4846  Accuracy: 78.0204  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:57
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Test Epoch: 10(100%)  Loss: 0.6822  Accuracy: 68.2000
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 11(0%)  Loss: 0.4835  Accuracy: 78.0838  Learning rate: 0.1  Epoch Finished Time: 0:00:01  Experiment Finished Time: 0:00:47
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 11(26%)  Loss: 0.4817  Accuracy: 78.1712  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:58
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 11(52%)  Loss: 0.4831  Accuracy: 78.

Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Test Epoch: 19(100%)  Loss: 0.7396  Accuracy: 68.2105
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 20(0%)  Loss: 0.4500  Accuracy: 79.8318  Learning rate: 0.1  Epoch Finished Time: 0:00:01  Experiment Finished Time: 0:00:44
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 20(26%)  Loss: 0.4490  Accuracy: 79.8963  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:47
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 20(52%)  Loss: 0.4474  Accuracy: 79.9591  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:52
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 20(78%)  Loss: 0.4463  Accuracy: 80.0101  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:53
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Test Epoch: 20(100%)  Loss: 0.7475  Accuracy: 68.

Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 29(52%)  Loss: 0.4248  Accuracy: 80.9881  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:46
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 29(78%)  Loss: 0.4245  Accuracy: 80.9861  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:46
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Test Epoch: 29(100%)  Loss: 1.0596  Accuracy: 64.4828
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 30(0%)  Loss: 0.4260  Accuracy: 80.9373  Learning rate: 0.1  Epoch Finished Time: 0:00:01  Experiment Finished Time: 0:01:09
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 30(26%)  Loss: 0.4252  Accuracy: 80.9836  Learning rate: 0.1  Epoch Finished Time: 0:00:01  Experiment Finished Time: 0:00:54
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 30(52%)  Loss: 0.4252  Accuracy: 80.

Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Test Epoch: 38(100%)  Loss: 1.2377  Accuracy: 63.4211
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 39(0%)  Loss: 0.4103  Accuracy: 81.7780  Learning rate: 0.1  Epoch Finished Time: 0:00:01  Experiment Finished Time: 0:00:34
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 39(26%)  Loss: 0.4095  Accuracy: 81.8130  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:37
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 39(52%)  Loss: 0.4086  Accuracy: 81.8422  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:38
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 39(78%)  Loss: 0.4082  Accuracy: 81.8763  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:37
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Test Epoch: 39(100%)  Loss: 1.2556  Accuracy: 63.

Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 48(52%)  Loss: 0.3938  Accuracy: 82.6967  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:34
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 48(78%)  Loss: 0.3933  Accuracy: 82.7364  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:33
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Test Epoch: 48(100%)  Loss: 1.2860  Accuracy: 63.5208
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 49(0%)  Loss: 0.3928  Accuracy: 82.7489  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:26
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 49(26%)  Loss: 0.3927  Accuracy: 82.7465  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:30
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 49(52%)  Loss: 0.3921  Accuracy: 82.

Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Test Epoch: 57(100%)  Loss: 1.3766  Accuracy: 62.6667
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 58(0%)  Loss: 0.3781  Accuracy: 83.6268  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:21
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 58(26%)  Loss: 0.3778  Accuracy: 83.6418  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:29
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 58(52%)  Loss: 0.3778  Accuracy: 83.6462  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:27
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 58(78%)  Loss: 0.3775  Accuracy: 83.6540  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:27
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Test Epoch: 58(100%)  Loss: 1.3747  Accuracy: 62.

Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 67(52%)  Loss: 0.3683  Accuracy: 84.1268  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:20
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 67(78%)  Loss: 0.3681  Accuracy: 84.1377  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:20
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Test Epoch: 67(100%)  Loss: 1.3769  Accuracy: 63.7463
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 68(0%)  Loss: 0.3678  Accuracy: 84.1510  Learning rate: 0.1  Epoch Finished Time: 0:00:01  Experiment Finished Time: 0:00:18
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 68(26%)  Loss: 0.3672  Accuracy: 84.1736  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:19
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 68(52%)  Loss: 0.3666  Accuracy: 84.

Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Test Epoch: 76(100%)  Loss: 1.4533  Accuracy: 63.8816
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 77(0%)  Loss: 0.3556  Accuracy: 84.7461  Learning rate: 0.1  Epoch Finished Time: 0:00:01  Experiment Finished Time: 0:00:14
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 77(26%)  Loss: 0.3559  Accuracy: 84.7431  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:14
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 77(52%)  Loss: 0.3558  Accuracy: 84.7531  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:15
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 77(78%)  Loss: 0.3558  Accuracy: 84.7604  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:15
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Test Epoch: 77(100%)  Loss: 1.4473  Accuracy: 64.

Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 86(52%)  Loss: 0.3469  Accuracy: 85.1742  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:08
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 86(78%)  Loss: 0.3468  Accuracy: 85.1702  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:09
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Test Epoch: 86(100%)  Loss: 1.4786  Accuracy: 63.8605
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 87(0%)  Loss: 0.3465  Accuracy: 85.1779  Learning rate: 0.1  Epoch Finished Time: 0:00:01  Experiment Finished Time: 0:00:08
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 87(26%)  Loss: 0.3461  Accuracy: 85.1994  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:08
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 87(52%)  Loss: 0.3457  Accuracy: 85.

Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Test Epoch: 95(100%)  Loss: 1.5053  Accuracy: 64.1474
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 96(0%)  Loss: 0.3394  Accuracy: 85.5420  Learning rate: 0.1  Epoch Finished Time: 0:00:01  Experiment Finished Time: 0:00:03
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 96(26%)  Loss: 0.3392  Accuracy: 85.5500  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:02
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 96(52%)  Loss: 0.3392  Accuracy: 85.5453  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:02
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 96(78%)  Loss: 0.3388  Accuracy: 85.5658  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:02
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Test Epoch: 96(100%)  Loss: 1.5054  Accuracy: 64.

tensor([[[ 0.1525, -0.0393, -0.1730],
         [-0.1450,  0.0000, -0.1476],
         [-0.1478,  0.0538,  0.0000]],

        [[-0.1300,  0.0000,  0.1578],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.1125, -0.1531,  0.1092]],

        [[ 0.0000,  0.1051,  0.0000],
         [ 0.0854, -0.1499, -0.1850],
         [ 0.0000, -0.1001, -0.0954]]], device='cuda:0',
       grad_fn=<SelectBackward>)
