In [1]:
import  torch, os
import  numpy as np
from    MiniImagenet import MiniImagenet
import  scipy.stats
from    torch.utils.data import DataLoader
from    torch.optim import lr_scheduler
import  random, sys, pickle
import  argparse
import time
from metapgd import Meta

from torch.utils.tensorboard import SummaryWriter
# TensorBoard writer shared by main(). The run directory name encodes the
# hyper-parameters (meta_lr-adv_lr) and must be kept in sync with the argparse
# defaults by hand. No `comment=` is passed: SummaryWriter only uses it to
# build a default log_dir, so it has no effect when log_dir is given.
writer = SummaryWriter('../run-final/pgd2/PGD-0.001-0.0003')

def mean_confidence_interval(accs, confidence=0.95):
    """Return the sample mean of `accs` and the half-width of its confidence interval.

    The half-width uses Student's t distribution with n - 1 degrees of
    freedom times the standard error of the mean, so the interval is
    ``[m - h, m + h]``.

    Args:
        accs: 1-D array-like of samples (e.g. per-episode accuracies). Lists
            are accepted and converted with ``np.asarray``.
        confidence: two-sided confidence level, e.g. 0.95.

    Returns:
        (m, h): the sample mean and the interval half-width. ``h`` is not
        meaningful (NaN) for fewer than two samples.
    """
    accs = np.asarray(accs)
    n = accs.shape[0]
    m, se = np.mean(accs), scipy.stats.sem(accs)
    # Public `ppf` rather than the private `_ppf`: same quantile, but it is
    # supported API and validates its arguments.
    h = se * scipy.stats.t.ppf((1 + confidence) / 2, n - 1)
    return m, h

def _build_config(n_way):
    """Return the Learner layer specification for an `n_way`-way classifier.

    Entry layouts, as echoed by the printed Learner repr:
      ('conv2d',     [ch_out, ch_in, kernel_h, kernel_w, stride, padding])
      ('max_pool2d', [kernel, stride, padding])
      ('linear',     [out_features, in_features])

    For 28x28 inputs the spatial size goes 28 -> 26 -> 13 -> 11 -> 5 -> 3 -> 2,
    so the flattened feature vector has 32 * 2 * 2 entries. (A four-conv
    variant ending in a 32 * 5 * 5 head, which fits 84x84 inputs, was used
    previously.)
    """
    return [
        ('conv2d', [32, 3, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 1, 0]),
        ('flatten', []),
        ('linear', [n_way, 32 * 2 * 2]),
    ]


def _load_checkpoint(maml, filename, device):
    """Load `maml.net` weights from `filename` if that file exists.

    Returns the (epoch, step) stored in the checkpoint, or (0, 0) when no
    checkpoint file is found.
    """
    if not os.path.isfile(filename):
        print("=> no checkpoint found at '{}'".format(filename))
        return 0, 0
    print("=> loading checkpoint '{}'".format(filename))
    # map_location keeps a checkpoint written on another device loadable.
    checkpoint = torch.load(filename, map_location=device)
    maml.net.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {})"
          .format(filename, checkpoint['epoch']))
    return checkpoint['epoch'], checkpoint['step']


def _evaluate(maml, mini_test, device):
    """Fine-tune on every test episode and return the averaged metrics.

    Returns (accs, accs_adv, accs_adv_prior, loss_q, loss_q_adv). The three
    accuracy results are averaged over episodes (axis 0) and cast to
    float16. The two losses are scalar means over episodes.
    """
    db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=0, pin_memory=True)
    accs_all, accs_adv_all, accs_adv_prior_all = [], [], []
    losses, losses_adv = [], []
    for x_spt, y_spt, x_qry, y_qry in db_test:
        # The loader batch size is 1: drop that leading batch dimension.
        x_spt, y_spt = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device)
        x_qry, y_qry = x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

        accs, accs_adv, accs_adv_prior, loss_q, loss_q_adv = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
        accs_all.append(accs)
        accs_adv_all.append(accs_adv)
        accs_adv_prior_all.append(accs_adv_prior)
        losses.append(loss_q.item())
        losses_adv.append(loss_q_adv.item())

    # Stacked results are presumably [episodes, update_step + 1] -- confirm
    # against Meta.finetunning. Average over episodes.
    return (np.array(accs_all).mean(axis=0).astype(np.float16),
            np.array(accs_adv_all).mean(axis=0).astype(np.float16),
            np.array(accs_adv_prior_all).mean(axis=0).astype(np.float16),
            np.array(losses).mean(),
            np.array(losses_adv).mean())


def main():
    """Meta-train the `Meta` learner on MiniImagenet for 30 epochs.

    Uses the module-level `args` (parsed in the ``__main__`` block) and
    `writer`.

    Every 10 training steps it prints and logs the training metrics and saves
    a checkpoint of `maml.net`. After the last step of each epoch it
    evaluates on the test episodes and logs the per-epoch metrics plus the
    cumulative timing counters.
    """
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)
    config = _build_config(args.n_way)

    device = torch.device('cuda:0')
    maml = Meta(args, config, device).to(device)

    # NOTE(review): weights are read from `load_filename` but checkpoints are
    # written to `save_filename`. The stored epoch/step are not used to resume
    # the epoch loop either, so a loaded file only acts as an initialisation.
    # Confirm this is intended.
    load_filename = 'mamlfgsmeps2_8.pt'
    save_filename = 'mamlfgsmeps4_2.pt'
    _load_checkpoint(maml, load_filename, device)

    num = sum(np.prod(p.shape) for p in maml.parameters() if p.requires_grad)
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means the total number of sampled episodes.
    mini = MiniImagenet('../../dataset/', mode='train', n_way=args.n_way, k_shot=args.k_spt,
                        k_query=args.k_qry, batchsz=4000, resize=args.imgsz)
    mini_test = MiniImagenet('../../dataset/', mode='test', n_way=args.n_way, k_shot=args.k_spt,
                             k_query=args.k_qry, batchsz=100, resize=args.imgsz)

    # Running count of tasks processed before the current batch. It is used
    # as the global step for the training logs.
    tot_step = -args.task_num
    exec_time = 0    # cumulative training wall time over all epochs (evaluation excluded)
    sample_time = 0  # cumulative time spent generating adversarial samples during training
    for epoch in range(30):
        t = time.perf_counter()
        # Each batch holds `task_num` episodes (tasks).
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=0, pin_memory=True)
        # Index of the epoch's final step: 124 for 4000 episodes with task_num=32.
        last_step = len(db) - 1
        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
            tot_step += args.task_num
            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)

            accs, accs_adv, loss_q, loss_q_adv, make_time = maml(x_spt, y_spt, x_qry, y_qry)
            sample_time += make_time
            if step % 10 == 0:
                print('step:', step, '\ttraining acc:', accs)
                print('step:', step, '\ttraining acc_adv:', accs_adv)
                writer.add_scalar("acc/train", accs[-1], tot_step)
                writer.add_scalar("acc_adv/train", accs_adv[-1], tot_step)
                writer.add_scalar("loss/train", loss_q, tot_step)
                writer.add_scalar("loss_adv/train", loss_q_adv, tot_step)
                state = {'epoch': epoch, 'step': step, 'state_dict': maml.net.state_dict()}
                torch.save(state, save_filename)

            # Evaluate once per epoch, after its final training step. The step
            # is derived from len(db), so this holds for any task_num.
            # Meta.finetunning is understood to work on a copy of the network,
            # so evaluation should not affect training (confirm in Meta).
            if step == last_step:
                exec_time += time.perf_counter() - t
                (test_accs, test_accs_adv, test_accs_adv_prior,
                 test_loss, test_loss_adv) = _evaluate(maml, mini_test, device)
                print('Test acc:', test_accs)
                print('Test acc_adv:', test_accs_adv)
                print('Test acc_adv_prior:', test_accs_adv_prior)

                writer.add_scalar("acc/test_epoch", test_accs[-1], epoch)
                writer.add_scalar("acc_adv/test_epoch", test_accs_adv[-1], epoch)
                writer.add_scalar("loss/epoch", test_loss, epoch)
                writer.add_scalar("loss_adv/epoch", test_loss_adv, epoch)

                writer.add_scalar("train_time/epoch", exec_time, epoch)
                writer.add_scalar("make_time/epoch", sample_time, epoch)
                # Push this epoch's scalars to disk so TensorBoard shows them now.
                writer.flush()
            

if __name__ == '__main__':

    # One (flag, type, help text, default) row per hyper-parameter.
    # NOTE(review): main() itself runs a fixed 30 epochs and does not read --epoch.
    _options = (
        ('--epoch', int, 'epoch number', 60000),
        ('--n_way', int, 'n way', 5),
        ('--k_spt', int, 'k shot for support set', 1),
        ('--k_qry', int, 'k shot for query set', 15),
        ('--imgsz', int, 'imgsz', 28),
        ('--imgc', int, 'imgc', 3),
        ('--task_num', int, 'meta batch size, namely task num', 32),
        ('--meta_lr', float, 'meta-level outer learning rate', 0.001),  # previously: 0.001 - 0.0002
        ('--adv_lr', float, 'adv-level learning rate', 0.0003),
        ('--rho', float, 'aRUB-rho', 0.03),
        ('--update_lr', float, 'task-level inner update learning rate', 0.01),
        ('--update_step', int, 'task-level inner update steps', 5),
        ('--update_step_test', int, 'update steps for finetunning', 10),
    )
    # Disabled option kept for reference: --fast (store_true, 'whether to use fgsm').

    argparser = argparse.ArgumentParser()
    for _flag, _type, _help, _default in _options:
        argparser.add_argument(_flag, type=_type, help=_help, default=_default)

    # An empty argument list means only the defaults above are used.
    # Presumably this is because a notebook kernel's sys.argv holds the
    # kernel's own flags.
    args = argparser.parse_args(args=[])

    main()

Namespace(epoch=60000, imgc=3, imgsz=28, k_qry=15, k_spt=1, meta_lr=0.001, n_way=5, task_num=32, update_lr=0.01, update_step=5, update_step_test=10)
init
=> no checkpoint found at 'mamlfgsmeps2_8.pt'
Meta(
  (net): Learner(
    conv2d:(ch_in:3, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:2, padding:0)
    conv2d:(ch_in:32, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:2, padding:0)
    conv2d:(ch_in:32, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:1, padding:0)
    flatten:()
    linear:(in:128, out:5)
    
    (vars): ParameterList(
        (0): Parameter containing: [torch.float32 of size 32x3x3x3 (GPU 0)]
        (1): Parameter containing: [torch.float32 of size 32 (GPU 0)]
        (2): Parameter containing: [torch.float32 of size 32 (GPU 0)]
        (3): Parameter containing: [torch.float32 of size 32 (GPU 0)]
        (4): Param

KeyboardInterrupt: 

In [None]:
import  torch, os
import  numpy as np
from    MiniImagenet import MiniImagenet
import  scipy.stats
from    torch.utils.data import DataLoader
from    torch.optim import lr_scheduler
import  random, sys, pickle
import  argparse
import time
from metaaRUB import Meta

from torch.utils.tensorboard import SummaryWriter
# TensorBoard writer shared by main(). The run directory name encodes the
# hyper-parameters (meta_lr-adv_lr-rho) and must be kept in sync with the
# argparse defaults by hand. No `comment=` is passed: SummaryWriter only uses
# it to build a default log_dir, so it has no effect when log_dir is given.
writer = SummaryWriter('../run-final/pgd2/aRUB-0.001-0.001-0.03')

def mean_confidence_interval(accs, confidence=0.95):
    """Return the sample mean of `accs` and the half-width of its confidence interval.

    The half-width uses Student's t distribution with n - 1 degrees of
    freedom times the standard error of the mean, so the interval is
    ``[m - h, m + h]``.

    Args:
        accs: 1-D array-like of samples (e.g. per-episode accuracies). Lists
            are accepted and converted with ``np.asarray``.
        confidence: two-sided confidence level, e.g. 0.95.

    Returns:
        (m, h): the sample mean and the interval half-width. ``h`` is not
        meaningful (NaN) for fewer than two samples.
    """
    accs = np.asarray(accs)
    n = accs.shape[0]
    m, se = np.mean(accs), scipy.stats.sem(accs)
    # Public `ppf` rather than the private `_ppf`: same quantile, but it is
    # supported API and validates its arguments.
    h = se * scipy.stats.t.ppf((1 + confidence) / 2, n - 1)
    return m, h

def _build_config(n_way):
    """Return the Learner layer specification for an `n_way`-way classifier.

    Entry layouts, as echoed by the printed Learner repr:
      ('conv2d',     [ch_out, ch_in, kernel_h, kernel_w, stride, padding])
      ('max_pool2d', [kernel, stride, padding])
      ('linear',     [out_features, in_features])

    For 28x28 inputs the spatial size goes 28 -> 26 -> 13 -> 11 -> 5 -> 3 -> 2,
    so the flattened feature vector has 32 * 2 * 2 entries. (A four-conv
    variant ending in a 32 * 5 * 5 head, which fits 84x84 inputs, was used
    previously.)
    """
    return [
        ('conv2d', [32, 3, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 1, 0]),
        ('flatten', []),
        ('linear', [n_way, 32 * 2 * 2]),
    ]


def _load_checkpoint(maml, filename, device):
    """Load `maml.net` weights from `filename` if that file exists.

    Returns the (epoch, step) stored in the checkpoint, or (0, 0) when no
    checkpoint file is found.
    """
    if not os.path.isfile(filename):
        print("=> no checkpoint found at '{}'".format(filename))
        return 0, 0
    print("=> loading checkpoint '{}'".format(filename))
    # map_location keeps a checkpoint written on another device loadable.
    checkpoint = torch.load(filename, map_location=device)
    maml.net.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {})"
          .format(filename, checkpoint['epoch']))
    return checkpoint['epoch'], checkpoint['step']


def _evaluate(maml, mini_test, device):
    """Fine-tune on every test episode and return the averaged metrics.

    Returns (accs, accs_adv, accs_adv_prior, loss_q, loss_q_adv). The three
    accuracy results are averaged over episodes (axis 0) and cast to
    float16. The two losses are scalar means over episodes.
    """
    db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=0, pin_memory=True)
    accs_all, accs_adv_all, accs_adv_prior_all = [], [], []
    losses, losses_adv = [], []
    for x_spt, y_spt, x_qry, y_qry in db_test:
        # The loader batch size is 1: drop that leading batch dimension.
        x_spt, y_spt = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device)
        x_qry, y_qry = x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

        accs, accs_adv, accs_adv_prior, loss_q, loss_q_adv = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
        accs_all.append(accs)
        accs_adv_all.append(accs_adv)
        accs_adv_prior_all.append(accs_adv_prior)
        losses.append(loss_q.item())
        losses_adv.append(loss_q_adv.item())

    # Stacked results are presumably [episodes, update_step + 1] -- confirm
    # against Meta.finetunning. Average over episodes.
    return (np.array(accs_all).mean(axis=0).astype(np.float16),
            np.array(accs_adv_all).mean(axis=0).astype(np.float16),
            np.array(accs_adv_prior_all).mean(axis=0).astype(np.float16),
            np.array(losses).mean(),
            np.array(losses_adv).mean())


def main():
    """Meta-train the `Meta` learner on MiniImagenet for 30 epochs.

    Uses the module-level `args` (parsed in the ``__main__`` block) and
    `writer`.

    Every 10 training steps it prints and logs the training metrics and saves
    a checkpoint of `maml.net`. After the last step of each epoch it
    evaluates on the test episodes and logs the per-epoch metrics plus the
    cumulative timing counters.
    """
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)
    config = _build_config(args.n_way)

    device = torch.device('cuda:0')
    maml = Meta(args, config, device).to(device)

    # NOTE(review): weights are read from `load_filename` but checkpoints are
    # written to `save_filename`. The stored epoch/step are not used to resume
    # the epoch loop either, so a loaded file only acts as an initialisation.
    # Confirm this is intended.
    load_filename = 'mamlfgsmeps2_8.pt'
    save_filename = 'mamlfgsmeps4_2.pt'
    _load_checkpoint(maml, load_filename, device)

    num = sum(np.prod(p.shape) for p in maml.parameters() if p.requires_grad)
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means the total number of sampled episodes.
    mini = MiniImagenet('../../dataset/', mode='train', n_way=args.n_way, k_shot=args.k_spt,
                        k_query=args.k_qry, batchsz=4000, resize=args.imgsz)
    mini_test = MiniImagenet('../../dataset/', mode='test', n_way=args.n_way, k_shot=args.k_spt,
                             k_query=args.k_qry, batchsz=100, resize=args.imgsz)

    # Running count of tasks processed before the current batch. It is used
    # as the global step for the training logs.
    tot_step = -args.task_num
    exec_time = 0    # cumulative training wall time over all epochs (evaluation excluded)
    sample_time = 0  # cumulative time spent generating adversarial samples during training
    for epoch in range(30):
        t = time.perf_counter()
        # Each batch holds `task_num` episodes (tasks).
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=0, pin_memory=True)
        # Index of the epoch's final step: 124 for 4000 episodes with task_num=32.
        last_step = len(db) - 1
        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
            tot_step += args.task_num
            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)

            accs, accs_adv, loss_q, loss_q_adv, make_time = maml(x_spt, y_spt, x_qry, y_qry)
            sample_time += make_time
            if step % 10 == 0:
                print('step:', step, '\ttraining acc:', accs)
                print('step:', step, '\ttraining acc_adv:', accs_adv)
                writer.add_scalar("acc/train", accs[-1], tot_step)
                writer.add_scalar("acc_adv/train", accs_adv[-1], tot_step)
                writer.add_scalar("loss/train", loss_q, tot_step)
                writer.add_scalar("loss_adv/train", loss_q_adv, tot_step)
                state = {'epoch': epoch, 'step': step, 'state_dict': maml.net.state_dict()}
                torch.save(state, save_filename)

            # Evaluate once per epoch, after its final training step. The step
            # is derived from len(db), so this holds for any task_num.
            # Meta.finetunning is understood to work on a copy of the network,
            # so evaluation should not affect training (confirm in Meta).
            if step == last_step:
                exec_time += time.perf_counter() - t
                (test_accs, test_accs_adv, test_accs_adv_prior,
                 test_loss, test_loss_adv) = _evaluate(maml, mini_test, device)
                print('Test acc:', test_accs)
                print('Test acc_adv:', test_accs_adv)
                print('Test acc_adv_prior:', test_accs_adv_prior)

                writer.add_scalar("acc/test_epoch", test_accs[-1], epoch)
                writer.add_scalar("acc_adv/test_epoch", test_accs_adv[-1], epoch)
                writer.add_scalar("loss/epoch", test_loss, epoch)
                writer.add_scalar("loss_adv/epoch", test_loss_adv, epoch)

                writer.add_scalar("train_time/epoch", exec_time, epoch)
                writer.add_scalar("make_time/epoch", sample_time, epoch)
                # Push this epoch's scalars to disk so TensorBoard shows them now.
                writer.flush()
            

if __name__ == '__main__':

    # One (flag, type, help text, default) row per hyper-parameter.
    # NOTE(review): main() itself runs a fixed 30 epochs and does not read --epoch.
    _options = (
        ('--epoch', int, 'epoch number', 60000),
        ('--n_way', int, 'n way', 5),
        ('--k_spt', int, 'k shot for support set', 1),
        ('--k_qry', int, 'k shot for query set', 15),
        ('--imgsz', int, 'imgsz', 28),
        ('--imgc', int, 'imgc', 3),
        ('--task_num', int, 'meta batch size, namely task num', 32),
        ('--meta_lr', float, 'meta-level outer learning rate', 0.001),  # previously: 0.001 - 0.0002
        ('--adv_lr', float, 'adv-level learning rate', 0.001),
        ('--rho', float, 'aRUB-rho', 0.03),
        ('--update_lr', float, 'task-level inner update learning rate', 0.01),
        ('--update_step', int, 'task-level inner update steps', 5),
        ('--update_step_test', int, 'update steps for finetunning', 10),
    )
    # Disabled option kept for reference: --fast (store_true, 'whether to use fgsm').

    argparser = argparse.ArgumentParser()
    for _flag, _type, _help, _default in _options:
        argparser.add_argument(_flag, type=_type, help=_help, default=_default)

    # An empty argument list means only the defaults above are used.
    # Presumably this is because a notebook kernel's sys.argv holds the
    # kernel's own flags.
    args = argparser.parse_args(args=[])

    main()