In [None]:
import  torch, os
import  numpy as np
from    MiniImagenet import MiniImagenet
import  scipy.stats
from    torch.utils.data import DataLoader
from    torch.optim import lr_scheduler
import  random, sys, pickle
import  argparse
import skimage.transform

from LoadUnlableData import UnlabData

#from MAMLROBUST import Meta
# import dataset_input
# import utilities
import time
from MetaFT import Meta



def mean_confidence_interval(accs, confidence=0.95):
    """Return the sample mean of *accs* and the half-width of its confidence interval.

    The half-width is ``sem(accs) * t.ppf((1 + confidence) / 2, n - 1)``, i.e. a
    two-sided Student-t interval around the mean, so the interval is
    ``[m - h, m + h]``.

    Args:
        accs: per-episode accuracies; a NumPy array or any array-like (e.g. a
            plain list), indexed along its first axis.
        confidence: two-sided confidence level, e.g. 0.95.

    Returns:
        (m, h): the mean and the interval half-width. With a single sample
        the standard error is undefined and ``h`` is NaN.
    """
    # Accept lists as well as arrays; the original required ``.shape``.
    accs = np.asarray(accs)
    n = accs.shape[0]
    m, se = np.mean(accs), scipy.stats.sem(accs)
    # Public ``ppf`` rather than the private ``_ppf``: ``_ppf`` bypasses
    # argument validation and is not part of SciPy's stable API.
    h = se * scipy.stats.t.ppf((1 + confidence) / 2, n - 1)
    return m, h


def _build_config(n_way):
    """Return the Learner layer specification handed to ``Meta``.

    Four conv(3x3, 32 ch) -> ReLU -> BatchNorm -> max-pool stages, then a
    flatten and an ``n_way``-output linear head. The ``32 * 5 * 5`` head input
    is the feature-map size an 84x84 image reaches after these stages
    (84 -> 41 -> 19 -> 8 -> 5); other image sizes would need it recomputed.
    """
    return [
        ('conv2d', [32, 3, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 1, 0]),
        ('flatten', []),
        ('linear', [n_way, 32 * 5 * 5])
    ]


def _load_checkpoint(maml, filename, device):
    """Restore ``maml.net`` weights from *filename* when that file exists.

    Returns ``(start_epoch, start_step)`` read from the checkpoint, or
    ``(0, 0)`` when no checkpoint file is found.
    """
    if not os.path.isfile(filename):
        print("=> no checkpoint found at '{}'".format(filename))
        return 0, 0

    print("=> loading checkpoint '{}'".format(filename))
    # map_location lets a checkpoint written on one device (e.g. a GPU) be
    # restored on whichever device this run is using.
    checkpoint = torch.load(filename, map_location=device)
    start_epoch = checkpoint['epoch']
    start_step = checkpoint['step']
    maml.net.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {})"
          .format(filename, checkpoint['epoch']))
    return start_epoch, start_step


def _sample_unlabeled(tinyimg, y_spt, n_tasks, n_way, per_class, imgc, imgsz):
    """Gather *per_class* unlabeled images for every (task, way) slot.

    Slot ``[i][j]`` is filled from ``tinyimg.train_data[y_spt[i][j]]``.
    Returns a zero-initialised CPU tensor of shape
    ``(n_tasks, n_way, per_class, imgc, imgsz, imgsz)`` with those batches
    copied in.

    *y_spt* should still be on the CPU so reading labels does not force a
    device synchronisation per element.
    """
    x_unlab = torch.zeros((n_tasks, n_way, per_class, imgc, imgsz, imgsz))
    for i in range(n_tasks):
        for j in range(n_way):
            # NOTE(review): only the first n_way support labels are consulted;
            # that covers one label per class only if k_spt == 1 -- confirm
            # the intended behaviour for k_spt > 1.
            index = int(y_spt[i][j])
            batch = tinyimg.train_data[index].get_next_batch(per_class, multiple_passes=True)
            x_unlab[i][j] = torch.from_numpy(batch.astype(np.float32))
    return x_unlab


def _evaluate(maml, mini_test, device, pin_memory):
    """Fine-tune on every test episode and print the mean clean/adversarial accuracies.

    ``maml.finetunning`` returns three per-step accuracy sequences per
    episode; each is averaged over episodes and printed as float16.
    """
    db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=0, pin_memory=pin_memory)
    accs_all_test = []
    accsadv_all_test = []
    accsadvpr_all_test = []

    for x_spt, y_spt, x_qry, y_qry in db_test:
        # The loader's batch size is 1: drop that leading dimension.
        x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                     x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

        accs, accs_adv, accs_adv_prior = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
        accs_all_test.append(accs)
        accsadv_all_test.append(accs_adv)
        accsadvpr_all_test.append(accs_adv_prior)

    # [episodes, fine-tune steps] -> mean over episodes, one value per step
    accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
    accs_adv = np.array(accsadv_all_test).mean(axis=0).astype(np.float16)
    accs_adv_prior = np.array(accsadvpr_all_test).mean(axis=0).astype(np.float16)
    print('Test acc:', accs)
    print('Test acc_adv:', accs_adv)
    print('Test acc_adv_prior:', accs_adv_prior)


def main():
    """Meta-train ``Meta`` on MiniImagenet with auxiliary unlabeled images.

    Reads the module-level ``args`` namespace built in the ``__main__`` block.
    Every 30 steps the training accuracies are printed; every 500 steps the
    model is fine-tuned and evaluated on the MiniImagenet test split.
    """
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    config = _build_config(args.n_way)

    # Fall back to CPU instead of crashing outright when CUDA is unavailable.
    # NOTE(review): confirm Meta itself does not assume a CUDA device.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    pin_memory = device.type == 'cuda'
    maml = Meta(args, config, device).to(device)

    filename = 'mamltradesrseps2self1.pt'
    start_epoch, start_step = _load_checkpoint(maml, filename, device)
    # NOTE(review): start_step is restored but never used to skip the steps
    # already completed inside the resumed epoch.

    num = sum(p.numel() for p in maml.parameters() if p.requires_grad)
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number
    mini = MiniImagenet('../', mode='train', n_way=args.n_way, k_shot=args.k_spt,
                        k_query=args.k_qry,
                        batchsz=10000, resize=args.imgsz)
    mini_test = MiniImagenet('../', mode='test', n_way=args.n_way, k_shot=args.k_spt,
                             k_query=args.k_qry,
                             batchsz=100, resize=args.imgsz)

    tinyimg = UnlabData()
    unlab_per_class = 20  # unlabeled images drawn for each (task, way) slot

    for epoch in range(start_epoch, args.epoch // 10000):
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=0, pin_memory=pin_memory)

        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
            # Use the real batch size: the final batch can hold fewer than
            # task_num tasks, and indexing past it raised IndexError. Labels
            # are read here, before the transfer, to avoid a device sync per
            # element.
            x_unlab = _sample_unlabeled(tinyimg, y_spt, x_spt.size(0), args.n_way,
                                        unlab_per_class, args.imgc, args.imgsz).to(device)
            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)

            accs, accs_adv = maml(x_spt, y_spt, x_qry, y_qry, x_unlab)

            if step % 30 == 0:
                print('step:', step, '\ttraining acc:', accs)
                print('step:', step, '\ttraining acc_adv:', accs_adv)
                state = {'epoch': epoch, 'step': step, 'state_dict': maml.net.state_dict()}
                # NOTE(review): saving is disabled, and this path differs from
                # the 'mamltradesrseps2self1.pt' file loaded above, so a run
                # never produces a checkpoint it can resume from.
                #torch.save(state, 'mamltradesrseps2self.pt')

            # 2500 is a multiple of 500, so this single test covers both of
            # the original cadences.
            if step % 500 == 0:  # evaluation
                _evaluate(maml, mini_test, device, pin_memory)


if __name__ == '__main__':

    # (flag, type, help text, default) for every command-line option. They are
    # registered in this order, so ``--help`` lists them as before.
    cli_spec = (
        ('--epoch', int, 'epoch number', 30000),
        ('--n_way', int, 'n way', 5),
        ('--k_spt', int, 'k shot for support set', 1),
        ('--k_qry', int, 'k shot for query set', 15),
        ('--imgsz', int, 'imgsz', 84),
        ('--imgc', int, 'imgc', 3),
        ('--task_num', int, 'meta batch size, namely task num', 4),
        ('--meta_lr', float, 'meta-level outer learning rate', 1e-3),
        ('--update_lr', float, 'task-level inner update learning rate', 0.01),
        ('--update_step', int, 'task-level inner update steps', 5),
        ('--update_step_test', int, 'update steps for finetunning', 10),
        ('--batch-size', int, 'batch size', 2048),
        ('--dataset-path', str, 'dataset folder', './cifar10'),
    )

    argparser = argparse.ArgumentParser()
    for flag, arg_type, help_text, default in cli_spec:
        argparser.add_argument(flag, type=arg_type, help=help_text, default=default)

    # Parsing an explicit empty list ignores sys.argv, so every option takes
    # its default. This is presumably so the cell also runs inside a notebook
    # kernel whose launcher passes its own arguments.
    args = argparser.parse_args(args=[])

    main()

Namespace(batch_size=2048, dataset_path='./cifar10', epoch=30000, imgc=3, imgsz=84, k_qry=15, k_spt=1, meta_lr=0.001, n_way=5, task_num=4, update_lr=0.01, update_step=5, update_step_test=10)
=> no checkpoint found at 'mamltradesrseps2self1.pt'
Meta(
  (net): Learner(
    conv2d:(ch_in:3, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:2, padding:0)
    conv2d:(ch_in:32, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:2, padding:0)
    conv2d:(ch_in:32, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:2, padding:0)
    conv2d:(ch_in:32, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:1, padding:0)
    flatten:()
    linear:(in:800, out:5)
    
    (vars): ParameterList(
        (0): Parameter containing: [torch.cuda.FloatTensor of size 32x3x3x3 (GPU 0)]
        (1): Parameter containing: [torch.



step: 0 	training acc: [0.17333333 0.22333333 0.23333333 0.24333333 0.25       0.24333333]
step: 0 	training acc_adv: [0. 0. 0. 0. 0. 0.]


	nonzero()
Consider using one of the following signatures instead:
	nonzero(*, bool as_tuple) (Triggered internally at  /pytorch/torch/csrc/utils/python_arg_parser.cpp:766.)
  corr_ind = (torch.eq(pred_q, y_qry) == True).nonzero()


Test acc: [0.1965 0.2351 0.2408 0.2429 0.2417 0.2428 0.2432 0.2439 0.2444 0.2451
 0.2454]
Test acc_adv: [0.0008   0.002533 0.0028   0.0028   0.002533 0.002934 0.0028   0.003067
 0.0028   0.0028   0.003067]
Test acc_adv_prior: [0.003874 0.01078  0.01132  0.01107  0.01028  0.011955 0.01127  0.01234
 0.01117  0.011246 0.01224 ]
step: 30 	training acc: [0.18333333 0.36333333 0.35666667 0.37       0.35666667 0.36333333]
step: 30 	training acc_adv: [0. 0. 0. 0. 0. 0.]
step: 60 	training acc: [0.19       0.28666667 0.32       0.31       0.31333333 0.32333333]
step: 60 	training acc_adv: [0. 0. 0. 0. 0. 0.]
step: 90 	training acc: [0.19333333 0.32666667 0.34666667 0.36       0.38333333 0.38666667]
step: 90 	training acc_adv: [0. 0. 0. 0. 0. 0.]
step: 120 	training acc: [0.23       0.24333333 0.29333333 0.30666667 0.32       0.31      ]
step: 120 	training acc_adv: [0. 0. 0. 0. 0. 0.]
step: 150 	training acc: [0.19333333 0.3        0.31       0.30666667 0.29666667 0.32      ]
step: 150 	trainin