In [1]:
import  torch, os
import  numpy as np
from    MiniImagenet import MiniImagenet
import  scipy.stats
from    torch.utils.data import DataLoader
from    torch.optim import lr_scheduler
import  random, sys, pickle
import  argparse
import time

#from ANIP import Meta
#from metafgsmanil import Meta
#from metafgsm import Meta
#from MAMLMeta import Meta
#from meta import Meta
#from Adv_Quer import Meta
#from metafgsmnewnew import Meta
from metafgsm import Meta


def mean_confidence_interval(accs, confidence=0.95):
    """Return the sample mean and the half-width of its confidence interval.

    The half-width is the standard error of the mean scaled by the
    two-sided Student-t critical value with ``n - 1`` degrees of freedom,
    so the interval is ``(m - h, m + h)``.

    Args:
        accs: 1-D sequence or array of observations (e.g. per-episode
            accuracies).
        confidence: Two-sided confidence level, strictly between 0 and 1.

    Returns:
        Tuple ``(m, h)`` of the mean and the interval half-width. With
        fewer than two observations the standard error is undefined and
        ``h`` is NaN.
    """
    # Accept plain lists/tuples as well as ndarrays.
    accs = np.asarray(accs)
    n = accs.shape[0]
    m, se = np.mean(accs), scipy.stats.sem(accs)
    # Use the public ppf: the private _ppf bypasses argument checking.
    h = se * scipy.stats.t.ppf((1 + confidence) / 2, n - 1)
    return m, h


def _load_checkpoint(maml, filename, device):
    """Restore the learner weights from ``filename`` when that file exists.

    Args:
        maml: Meta-learner whose ``net`` receives the saved ``state_dict``.
        filename: Path of the checkpoint file.
        device: Device the saved tensors are mapped onto while loading.

    Returns:
        The epoch stored in the checkpoint, or 0 when no checkpoint exists.
    """
    if not os.path.isfile(filename):
        print("=> no checkpoint found at '{}'".format(filename))
        return 0

    print("=> loading checkpoint '{}'".format(filename))
    # map_location keeps loading independent of the device that wrote the file.
    checkpoint = torch.load(filename, map_location=device)
    maml.net.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {})"
          .format(filename, checkpoint['epoch']))
    return checkpoint['epoch']


def _evaluate(maml, mini_test, device):
    """Fine-tune on every test episode and print the averaged accuracies.

    Args:
        maml: Meta-learner exposing ``finetunning``.
        mini_test: Dataset of test episodes.
        device: Device the episode tensors are moved to.
    """
    db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=0, pin_memory=True)
    accs_all_test = []
    accsadv_all_test = []
    accsadvpr_all_test = []

    for x_spt, y_spt, x_qry, y_qry in db_test:
        # The loader batches one episode at a time; drop that leading axis.
        x_spt, y_spt = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device)
        x_qry, y_qry = x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

        accs, accs_adv, accs_adv_prior = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
        accs_all_test.append(accs)
        accsadv_all_test.append(accs_adv)
        accsadvpr_all_test.append(accs_adv_prior)

    # Average over episodes; float16 only shortens the printed numbers.
    accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
    accs_adv = np.array(accsadv_all_test).mean(axis=0).astype(np.float16)
    accs_adv_prior = np.array(accsadvpr_all_test).mean(axis=0).astype(np.float16)
    print('Test acc:', accs)
    print('Test acc_adv:', accs_adv)
    print('Test acc_adv_prior:', accs_adv_prior)


def main():
    """Train the meta-learner on MiniImagenet, checkpointing and evaluating.

    Reads its hyper-parameters from the module-level ``args`` namespace.
    When a checkpoint file is present the learner weights are restored and
    training resumes at the stored epoch. The step within that epoch is not
    restored, because the loader is rebuilt and reshuffled per epoch.
    """
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # Four conv/relu/bn/max-pool stages followed by a linear classifier.
    config = [
        ('conv2d', [32, 3, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 1, 0]),
        ('flatten', []),
        ('linear', [args.n_way, 32 * 5 * 5])
    ]

    device = torch.device('cuda:2')
    maml = Meta(args, config, device).to(device)

    # Single source of truth for the checkpoint path (load and save).
    filename = 'mamlfgsmeps2_4.pt'
    start_epoch = _load_checkpoint(maml, filename, device)

    # NOTE: despite the label below, this counts scalar elements, not tensors.
    num = sum(p.numel() for p in maml.parameters() if p.requires_grad)
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number
    mini = MiniImagenet('../../../dataset/', mode='train', n_way=args.n_way, k_shot=args.k_spt,
                        k_query=args.k_qry,
                        batchsz=10000, resize=args.imgsz)
    mini_test = MiniImagenet('../../../dataset/', mode='test', n_way=args.n_way, k_shot=args.k_spt,
                             k_query=args.k_qry,
                             batchsz=100, resize=args.imgsz)

    # Resume from the checkpointed epoch instead of always restarting at 0.
    for epoch in range(start_epoch, args.epoch // 10000):
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=0, pin_memory=True)

        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
            # Wall-clock timing of steps 1-499 and 501-999.
            if step in (1, 501):
                t = time.perf_counter()
            elif step in (499, 999):
                print(time.perf_counter() - t)

            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)

            accs, accs_adv = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 30 == 0:
                print('step:', step, '\ttraining acc:', accs)
                print('step:', step, '\ttraining acc_adv:', accs_adv)
                state = {'epoch': epoch, 'step': step, 'state_dict': maml.net.state_dict()}
                torch.save(state, filename)

            if step % 1000 == 0:  # evaluation
                _evaluate(maml, mini_test, device)


if __name__ == '__main__':

    # One row per command-line option: (flag, value type, default, help text).
    option_table = [
        ('--epoch', int, 60000, 'epoch number'),
        ('--n_way', int, 5, 'n way'),
        ('--k_spt', int, 1, 'k shot for support set'),
        ('--k_qry', int, 15, 'k shot for query set'),
        ('--imgsz', int, 84, 'imgsz'),
        ('--imgc', int, 3, 'imgc'),
        ('--task_num', int, 4, 'meta batch size, namely task num'),
        ('--meta_lr', float, 1e-3, 'meta-level outer learning rate'),
        ('--update_lr', float, 0.01, 'task-level inner update learning rate'),
        ('--update_step', int, 5, 'task-level inner update steps'),
        ('--update_step_test', int, 10, 'update steps for finetunning'),
    ]

    argparser = argparse.ArgumentParser()
    for flag, value_type, default_value, help_text in option_table:
        argparser.add_argument(flag, type=value_type, help=help_text, default=default_value)

    # An empty argument list is parsed on purpose, so only the defaults above
    # apply and the process's own command line is ignored.
    args = argparser.parse_args(args=[])

    # main() reads the module-level `args` bound just above.
    main()

Namespace(epoch=60000, imgc=3, imgsz=84, k_qry=15, k_spt=1, meta_lr=0.001, n_way=5, task_num=4, update_lr=0.01, update_step=5, update_step_test=10)
=> loading checkpoint 'mamlfgsmeps2_4.pt'
=> loaded checkpoint 'mamlfgsmeps2_4.pt' (epoch 1)
Meta(
  (net): Learner(
    conv2d:(ch_in:3, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:2, padding:0)
    conv2d:(ch_in:32, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:2, padding:0)
    conv2d:(ch_in:32, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:2, padding:0)
    conv2d:(ch_in:32, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:1, padding:0)
    flatten:()
    linear:(in:800, out:5)
    
    (vars): ParameterList(
        (0): Parameter containing: [torch.cuda.FloatTensor of size 32x3x3x3 (GPU 2)]
        (1): Parameter containing: [torch.cud

step: 810 	training acc: [0.21333333 0.36333333 0.41666667 0.42666667 0.43333333 0.42666667]
step: 810 	training acc_adv: [0.   0.   0.   0.   0.   0.33]
step: 840 	training acc: [0.25       0.38       0.41       0.45666667 0.45333333 0.45333333]
step: 840 	training acc_adv: [0.   0.   0.   0.   0.   0.34]
step: 870 	training acc: [0.15666667 0.28       0.34666667 0.42       0.43333333 0.44666667]
step: 870 	training acc_adv: [0.  0.  0.  0.  0.  0.3]
step: 900 	training acc: [0.2        0.35333333 0.39666667 0.42       0.43       0.43333333]
step: 900 	training acc_adv: [0.         0.         0.         0.         0.         0.31666667]
step: 930 	training acc: [0.25333333 0.42333333 0.48       0.50333333 0.49333333 0.50666667]
step: 930 	training acc_adv: [0.  0.  0.  0.  0.  0.4]
step: 960 	training acc: [0.25333333 0.36666667 0.41666667 0.42666667 0.43333333 0.42666667]
step: 960 	training acc_adv: [0.         0.         0.         0.         0.         0.30666667]
step: 990 	train

step: 2100 	training acc: [0.19666667 0.40666667 0.48333333 0.52       0.52       0.53666667]
step: 2100 	training acc_adv: [0.   0.   0.   0.   0.   0.41]
step: 2130 	training acc: [0.17333333 0.3        0.34       0.35333333 0.35666667 0.36      ]
step: 2130 	training acc_adv: [0.         0.         0.         0.         0.         0.23333333]
step: 2160 	training acc: [0.17       0.37666667 0.43666667 0.45       0.44666667 0.44333333]
step: 2160 	training acc_adv: [0.         0.         0.         0.         0.         0.33333333]
step: 2190 	training acc: [0.2        0.36333333 0.46333333 0.47333333 0.48333333 0.49      ]
step: 2190 	training acc_adv: [0.         0.         0.         0.         0.         0.33666667]
step: 2220 	training acc: [0.26333333 0.40666667 0.42       0.43666667 0.44333333 0.45333333]
step: 2220 	training acc_adv: [0.   0.   0.   0.   0.   0.33]
step: 2250 	training acc: [0.17       0.28666667 0.36333333 0.38666667 0.40333333 0.40666667]
step: 2250 	traini

step: 900 	training acc: [0.17666667 0.32666667 0.41333333 0.44666667 0.45       0.45      ]
step: 900 	training acc_adv: [0.         0.         0.         0.         0.         0.30666667]
step: 930 	training acc: [0.27666667 0.44       0.48666667 0.50333333 0.50333333 0.50333333]
step: 930 	training acc_adv: [0.   0.   0.   0.   0.   0.36]
step: 960 	training acc: [0.18666667 0.43333333 0.46666667 0.49666667 0.50666667 0.51333333]
step: 960 	training acc_adv: [0.         0.         0.         0.         0.         0.37333333]
step: 990 	training acc: [0.16333333 0.32333333 0.4        0.41666667 0.43       0.44333333]
step: 990 	training acc_adv: [0.         0.         0.         0.         0.         0.29333333]
1156.2234692079946
Test acc: [0.2001 0.3403 0.3728 0.3813 0.386  0.3894 0.3909 0.3904 0.3914 0.3909
 0.3909]
Test acc_adv: [0.02066 0.07733 0.0977  0.10626 0.1088  0.1095  0.11    0.10986 0.1093
 0.1096  0.10986]
Test acc_adv_prior: [0.1006 0.2223 0.256  0.2722 0.276  0.2761 

KeyboardInterrupt: 