# In[None]:
import  torch, os
import  numpy as np
import  scipy.stats
from    torch.utils.data import DataLoader
from    torch.optim import lr_scheduler
import  random, sys, pickle
import  argparse


from torchmeta.datasets import Omniglot
from torchmeta.transforms import Categorical, ClassSplitter, Rotation
from torchvision.transforms import Compose, Resize, ToTensor
from torchmeta.utils.data import BatchMetaDataLoader

from MetaFTOmni import Meta



def mean_confidence_interval(accs, confidence=0.95):
    """Return the sample mean of ``accs`` and the half-width of its confidence interval.

    The half-width is the two-sided Student-t interval
    ``h = t.ppf((1 + confidence) / 2, n - 1) * SEM`` with ``n = len(accs)``.

    Args:
        accs: 1-D array-like of accuracies (ndarray, list, tuple, ...).
        confidence: two-sided confidence level, e.g. 0.95.

    Returns:
        ``(mean, half_width)``. With fewer than two samples the half-width is NaN,
        because the standard error is undefined.
    """
    # asarray lets callers pass plain lists; the original required ``.shape``.
    accs = np.asarray(accs)
    n = accs.shape[0]
    m, se = np.mean(accs), scipy.stats.sem(accs)
    # Public ``ppf`` rather than the private ``_ppf``: same value for valid
    # arguments, but validated and part of SciPy's supported API.
    h = se * scipy.stats.t.ppf((1 + confidence) / 2, n - 1)
    return m, h


def _build_config(n_way):
    """Return the layer spec passed to ``Meta``: four conv/ReLU/BN stages, flatten, linear head.

    Each entry is a ``(layer_name, params)`` tuple interpreted by ``Meta``
    (``conv2d`` params presumably follow [out_ch, in_ch, kh, kw, stride, padding]
    — confirm against MetaFTOmni).
    """
    return [
        ('conv2d', [64, 1, 3, 3, 2, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('conv2d', [64, 64, 3, 3, 2, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('conv2d', [64, 64, 3, 3, 2, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('conv2d', [64, 64, 2, 2, 1, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('flatten', []),
        ('linear', [n_way, 64]),
    ]


def _omniglot_tasks(args, meta_train):
    """Build the Omniglot meta-train (``meta_train=True``) or meta-test split as few-shot tasks.

    Each task has ``args.n_way`` classes, with ``args.k_spt`` support images
    (batch key ``"train"``) and ``args.k_qry`` query images (batch key ``"test"``)
    per class.
    """
    dataset = Omniglot("data",
                       num_classes_per_task=args.n_way,
                       meta_train=meta_train,
                       meta_val=False,
                       meta_test=not meta_train,
                       # Resize the images to 28x28 and convert them to PyTorch tensors.
                       transform=Compose([Resize(28), ToTensor()]),
                       # Map class names (e.g. "Glagolitic/character01") to labels 0..n_way-1.
                       target_transform=Categorical(num_classes=args.n_way),
                       # Rotation class augmentation (Santoro et al., 2016) is disabled:
                       # class_augmentations=[Rotation([90, 180, 270])],
                       download=True)
    return ClassSplitter(dataset, shuffle=True,
                         num_train_per_class=args.k_spt,
                         num_test_per_class=args.k_qry)


def _evaluate(maml, data_test, device, max_tasks=10000):
    """Fine-tune ``maml`` on up to ``max_tasks`` random test episodes.

    Returns:
        ``(accs, accs_adv, accs_adv_prior)``: the three results of
        ``maml.finetunning`` averaged over episodes (axis 0) and cast to float16.
        Callers treat ``[-1]`` as the accuracy after the last fine-tuning step.
    """
    db_test = BatchMetaDataLoader(data_test, batch_size=1, num_workers=0)
    accs_all, accs_adv_all, accs_adv_prior_all = [], [], []

    for step_t, batch_test in enumerate(db_test):
        x_spt, y_spt = batch_test["train"]
        x_qry, y_qry = batch_test["test"]
        # batch_size=1, so squeeze(0) drops the leading task dimension.
        x_spt, y_spt = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device)
        x_qry, y_qry = x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

        accs, accs_adv, accs_adv_prior = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
        accs_all.append(accs)
        accs_adv_all.append(accs_adv)
        accs_adv_prior_all.append(accs_adv_prior)

        # Checked after appending so exactly ``max_tasks`` episodes are used
        # (the original stopped at index 10000, i.e. 10001 episodes). The
        # loader itself is effectively endless, so this cap is required.
        if step_t + 1 >= max_tasks:
            break

    def _mean(results):
        return np.array(results).mean(axis=0).astype(np.float16)

    return _mean(accs_all), _mean(accs_adv_all), _mean(accs_adv_prior_all)


def _evaluate_and_checkpoint(maml, data_test, device, best_cl, best_rb, filename):
    """Evaluate ``maml``; save it to ``filename`` if BOTH final clean and adversarial accuracies improve.

    Returns the (possibly updated) ``(best_cl, best_rb)`` pair.
    """
    accs, accs_adv, accs_adv_prior = _evaluate(maml, data_test, device)
    print('Test acc:', accs)
    print('Test acc_adv:', accs_adv)
    print('Test acc_adv_prior:', accs_adv_prior)

    if best_cl < accs[-1] and best_rb < accs_adv[-1]:
        # Store plain Python floats: numpy scalars inside a checkpoint are
        # rejected by torch.load's weights_only mode (default since PyTorch 2.6).
        best_cl = float(accs[-1])
        best_rb = float(accs_adv[-1])
        state = {'cl': best_cl, 'rb': best_rb, 'state_dict': maml.net.state_dict()}
        torch.save(state, filename)
        print(best_cl)
        print(best_rb)
    return best_cl, best_rb


def main():
    """Meta-train the model on Omniglot, evaluating and checkpointing once per epoch.

    Reads hyper-parameters from the module-level ``args`` namespace created in the
    ``__main__`` block. Training resumes from ``filename`` when that file exists.
    An epoch is 10000 meta-updates, and ``args.epoch // 10000`` epochs are run.
    The model is evaluated after the first update of each epoch and once more
    after training finishes.
    """
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    config = _build_config(args.n_way)
    # Fall back to CPU when CUDA is unavailable (the original hard-coded 'cuda').
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    best_cl = 0
    best_rb = 0
    filename = 'mamltrades_eps10_omniglot_5way.pt'
    maml = Meta(args, config, device).to(device)
    if os.path.isfile(filename):
        print("=> loading checkpoint '{}'".format(filename))
        # map_location lets a GPU-saved checkpoint load on the chosen device.
        checkpoint = torch.load(filename, map_location=device)
        best_cl = checkpoint['cl']
        best_rb = checkpoint['rb']
        maml.net.load_state_dict(checkpoint['state_dict'])
    else:
        print("=> no checkpoint found at '{}'".format(filename))

    # Number of trainable scalar parameters (elements, not tensors).
    num = sum(np.prod(p.shape) for p in maml.parameters() if p.requires_grad)
    print(maml)
    print('Total trainable tensors:', num)

    print(best_cl)
    print(best_rb)

    data_train = _omniglot_tasks(args, meta_train=True)
    data_test = _omniglot_tasks(args, meta_train=False)

    # Optional unlabeled-data source (previously ``UnlabData()``); disabled.
    LST = None
    batchsiz = 20  # unlabeled samples fetched per class when LST is enabled

    steps_per_epoch = 10000
    num_epochs = args.epoch // steps_per_epoch
    for _ in range(num_epochs):
        # Each batch holds args.task_num tasks.
        db = BatchMetaDataLoader(data_train, batch_size=args.task_num, num_workers=0)

        for step, batch_train in enumerate(db):
            # The torchmeta loader enumerates n_way-class combinations of the
            # training classes, so it is effectively endless. Without this cap
            # the first epoch never ends and args.epoch has no effect.
            if step >= steps_per_epoch:
                break

            x_spt, y_spt = batch_train["train"]
            x_qry, y_qry = batch_train["test"]
            x_spt, y_spt = x_spt.to(device), y_spt.to(device)
            x_qry, y_qry = x_qry.to(device), y_qry.to(device)

            if LST:
                # NOTE(review): unreachable while LST is None. `resmodel` is not
                # defined in this file, and the (3, 32, 32) buffer does not match
                # the 1-channel 28x28 inputs implied by the config / Resize(28).
                # Fix both before enabling.
                x_unlab = torch.zeros((args.task_num, args.n_way, batchsiz, 3, 32, 32))
                for i in range(args.task_num):
                    with torch.no_grad():
                        *_, outputs = resmodel(x_spt[i])
                        _, y_unlab = outputs.max(1)
                    index = y_unlab.cpu().numpy()
                    for j in range(args.n_way):
                        temp_train = LST.train_data[index[j]].get_next_batch(batchsiz, multiple_passes=True)
                        x_unlab[i][j] = torch.from_numpy(temp_train.astype(np.float32))
                x_unlab = x_unlab.to(device)
            else:
                x_unlab = None

            accs, accs_adv = maml(x_spt, y_spt, x_qry, y_qry, x_unlab)

            if step % 5000 == 0:
                print('step:', step, '\ttraining acc:', accs)
                print('step:', step, '\ttraining acc_adv:', accs_adv)

            if step == 0:  # evaluate once per epoch, after its first meta-update
                best_cl, best_rb = _evaluate_and_checkpoint(
                    maml, data_test, device, best_cl, best_rb, filename)

    if num_epochs > 0:
        # Also evaluate (and possibly checkpoint) the fully trained model.
        best_cl, best_rb = _evaluate_and_checkpoint(
            maml, data_test, device, best_cl, best_rb, filename)


if __name__ == '__main__':

    argparser = argparse.ArgumentParser()
    argparser.add_argument('--epoch', type=int, help='epoch number', default=30000)
    argparser.add_argument('--n_way', type=int, help='n way', default=5)
    argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=1)
    argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=15)
    argparser.add_argument('--imgsz', type=int, help='imgsz', default=32)
    argparser.add_argument('--imgc', type=int, help='imgc', default=3)
    argparser.add_argument('--task_num', type=int, help='meta batch size, namely task num', default=4)
    argparser.add_argument('--meta_lr', type=float, help='meta-level outer learning rate', default=1e-3)
    argparser.add_argument('--update_lr', type=float, help='task-level inner update learning rate', default=0.01)
    argparser.add_argument('--update_step', type=int, help='task-level inner update steps', default=5)
    argparser.add_argument('--update_step_test', type=int, help='update steps for finetunning', default=10)
    
    #argparser.add_argument('--fast', action="store_true", help='whether to use fgsm')

    args = argparser.parse_args(args=[])

    main()