In [1]:
# import os
# import zipfile
# root_path = 'mini-imagenet'
# zip_ref = zipfile.ZipFile(os.path.join(root_path,'mini-imagenet.zip'), 'r')
# zip_ref.extractall(root_path)
# zip_ref.close()

In [2]:
import os
import torch
from torch.utils.data import Dataset
from torchvision.transforms import transforms
import numpy as np
import collections
from PIL import Image
import csv
import random

In [3]:
class MiniImagenet(Dataset):
    """
    put mini-imagenet files as :
    root :
        |- images/*.jpg includes all images
        |- train.csv  len(labels):64
        |- test.csv   len(labels):20
        |- val.csv    len(labels):16
    NOTICE: meta-learning is different from general supervised learning, especially the concept of batch and set.
    batch: contains several sets
    sets: conains n_way * k_shot for meta-train set, n_way * n_query for meta-test set.
    """

    def __init__(self, root, mode, batchsz, n_way, k_shot, k_query, resize, startidx=0):
        """
        :param startidx: start to index label from startidx
        """
        
        self.batchsz = batchsz  # batch of set, not batch of imgs
        self.n_way = n_way  # n-way
        self.k_shot = k_shot  # k-shot
        self.k_query = k_query  # for evaluation
        self.setsz = self.n_way * self.k_shot  # num of samples per support set
        self.querysz = self.k_query  # number of samples per set for evaluation
        self.resize = resize  # resize to
        self.startidx = startidx  # index label not from 0, but from startidx
        print('shuffle DB :%s, b:%d, %d-way, %d-shot, %d-query, resize:%d' % (mode, batchsz, n_way, k_shot, k_query, resize))

        if mode == 'train':
            self.transform = transforms.Compose([lambda x: Image.open(x).convert('RGB'),
                                                 transforms.Resize((self.resize, self.resize)),
                                                 # transforms.RandomHorizontalFlip(),
                                                 # transforms.RandomRotation(5),
                                                 transforms.ToTensor(),
                                                 transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
                                                 ])
        else:
            self.transform = transforms.Compose([lambda x: Image.open(x).convert('RGB'),
                                                 transforms.Resize((self.resize, self.resize)),
                                                 transforms.ToTensor(),
                                                 transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
                                                 ])

        self.path = os.path.join(root, 'images')  # image path
        
        # :return: dictLabels: {label1: [filename1, filename2, filename3, filename4,...], }
        dictLabels = self.loadCSV(os.path.join(root, mode + '.csv'))  # csv path
        self.data = []
        self.img2label = {}
        for i, (label, imgs) in enumerate(dictLabels.items()):
            self.data.append(imgs)  # [[img1, img2, ...], [img111, ...]]
            self.img2label[label] = i + self.startidx  # {"img_name[:9]":label}   {'n0134':0,'n0123':1,'n0123':2}
        self.cls_num = len(self.data)

        self.create_batch(self.batchsz)

    def loadCSV(self, csvf):
        """
        return a dict saving the information of csv
        :param splitFile: csv file name
        :return: {label:[file1, file2 ...]}
        """
        dictLabels = {}
        with open(csvf) as csvfile:
            csvreader = csv.reader(csvfile, delimiter=',')
            next(csvreader, None)  # skip (filename, label)
            for i, row in enumerate(csvreader):
                filename = row[0]
                label = row[1]
                # append filename to current label
                if label in dictLabels.keys():
                    dictLabels[label].append(filename)
                else:
                    dictLabels[label] = [filename]
        return dictLabels

    def create_batch(self, batchsz):
        """
        create batch for meta-learning.
        ×episode× here means batch, and it means how many sets we want to retain.
        :param episodes: batch size
        :return:
        """
        self.support_x_batch = []  # support set batch
        self.query_x_batch = []  # query set batch
        for b in range(batchsz):  # for each batch
            # 1.select n_way classes randomly
            selected_cls = np.random.choice(self.cls_num, self.n_way, False)  #False: no duplicate不重复
            np.random.shuffle(selected_cls)
            support_x = []
            query_x = []
            for cls in selected_cls:
                
                # 2. select k_shot + k_query for each class
                selected_imgs_idx = np.random.choice(len(self.data[cls]), self.k_shot + self.k_query, False)
                np.random.shuffle(selected_imgs_idx)
                indexDtrain = np.array(selected_imgs_idx[:self.k_shot])  # idx for Dtrain
                indexDtest = np.array(selected_imgs_idx[self.k_shot:])  # idx for Dtest
                support_x.append(
                    np.array(self.data[cls])[indexDtrain].tolist())        # get all images filename for current Dtrain
            query_x.append(np.array(self.data[cls])[indexDtest].tolist())

            # shuffle the correponding relation between support set and query set
            random.shuffle(support_x)
            random.shuffle(query_x)

            self.support_x_batch.append(support_x)  # append set to current sets
            self.query_x_batch.append(query_x)  # append sets to current sets

    def __getitem__(self, index):
        """
        index means index of sets, 0<= index <= batchsz-1
        :param index:
        :return:
        """
        support_x = torch.FloatTensor(self.setsz, 3, self.resize, self.resize)
        support_y = np.zeros((self.setsz), dtype=np.int32)
        
        query_x = torch.FloatTensor(self.querysz, 3, self.resize, self.resize)
        query_y = np.zeros((self.querysz), dtype=np.int32)
        
        flatten_support_x = [os.path.join(self.path, item)
                             for sublist in self.support_x_batch[index] for item in sublist]
        support_y = np.array(
            [self.img2label[item[:9]]  # filename:n0153282900000005.jpg, the first 9 characters treated as label
             for sublist in self.support_x_batch[index] for item in sublist]).astype(np.int32)
        

        flatten_query_x = [os.path.join(self.path, item)
                           for sublist in self.query_x_batch[index] for item in sublist]
        query_y = np.array([self.img2label[item[:9]]
                            for sublist in self.query_x_batch[index] for item in sublist]).astype(np.int32)
        
        # print('global:', support_y, query_y)
        # support_y: [setsz]
        # query_y: [querysz]
        # unique: [n-way], sorted
#         unique = np.unique(support_y)
#         random.shuffle(unique)
        # relative means the label ranges from 0 to n-way
#         support_y_relative = np.zeros(self.setsz)
#         query_y_relative = np.zeros(self.querysz)

#         for idx, l in enumerate(unique):
#             support_y_relative[support_y == l] = idx
#             query_y_relative[query_y == l] = idx

        # print('relative:', support_y_relative, query_y_relative)

        for i, path in enumerate(flatten_support_x):
            support_x[i] = self.transform(path)

        for i, path in enumerate(flatten_query_x):
            query_x[i] = self.transform(path)
        # print(support_set_y)
        # return support_x, torch.LongTensor(support_y), query_x, torch.LongTensor(query_y)

        return support_x, torch.LongTensor(support_y), query_x, torch.LongTensor(query_y)

    def __len__(self):
        # as we have built up to batchsz of sets, you can sample some small batch size of sets.
        return self.batchsz

# 模型

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [20]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3,32,3,1,1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2),     #84 -> 42
            
            nn.Conv2d(32,32,3,1,1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2),    #42 -> 21
            
            nn.Conv2d(32,32,3,1,1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2),   #21 -> 10
            
            nn.Conv2d(32,32,3,1,1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2),  #10 -> 5
        )
        self.fc = nn.Linear(32 * 5 * 5, 500)
 
    def forward(self, x):
        x = self.conv(x)
        x = x.reshape(-1,32*5*5)
        x = self.fc(x)
        return x


class Tripletnet(nn.Module):
    def __init__(self, embeddingnet):
        super(Tripletnet, self).__init__()
        self.embeddingnet = embeddingnet
        self.anchor_embedding = nn.Embedding(100,500)  #100个类别，每个类别500维
        
    def forward(self, x_support,y_support):  #[2,3,84,84],[2]
        p = x_support[0].unsqueeze(0)
        n = x_support[1].unsqueeze(0)
        
        embedded_a = self.anchor_embedding(y_support[0])  #support_x中第一张图片为正样本，第二张为负样本
        embedded_a = embedded_a.unsqueeze(0)
        embedded_p = self.embeddingnet(p)
        embedded_n = self.embeddingnet(n)
        return embedded_a, embedded_p, embedded_n


In [34]:
device = torch.device('cuda')
model = Net().to(device)
tnet = Tripletnet(model).to(device)

loss_fn = torch.nn.TripletMarginLoss(margin = 100)
loss_ce = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(tnet.parameters(), lr=0.1)

In [None]:
mini_train = MiniImagenet('mini-imagenet/', mode='train', n_way=2, k_shot=1,
                        k_query=1,batchsz=1000, resize=84, startidx=0)
mini_test = MiniImagenet('mini-imagenet/', mode='test', n_way=2, k_shot=1,
                         k_query=1,batchsz=1000, resize=84, startidx=64)
mini_val = MiniImagenet('mini-imagenet/', mode='val', n_way=2, k_shot=1,
                         k_query=1,batchsz=1000, resize=84, startidx=84)


for epoch in range(30):
    db = torch.utils.data.DataLoader(mini_train, 4, shuffle=True, num_workers=0, pin_memory=True) 
    correct_num = 0
    
    for k,(support_x,support_y,query_x,query_y) in enumerate(db):
        support_x,support_y,query_x,query_y = support_x.to(device),support_y.to(device),query_x.to(device),query_y.to(device)
        pred_cos = list()
        q_y = list()
        pred = list()
        for i in support_y.eq(query_y).int():
            idx= torch.argmax(i)
            q_y.append(idx.item())
            
        for i in range(4):
            
            for step in range(4):
                embedded_a, embedded_p, embedded_n = tnet(support_x[i],support_y[i])
                loss = loss_fn(embedded_a, embedded_p, embedded_n)  #loss = criterion(anchor, positive, negative)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            embedded_query = model(query_x[i])
            q_p_cos = F.cosine_similarity(embedded_query,embedded_p)
            q_n_cos = F.cosine_similarity(embedded_query,embedded_n)
            pred_cos.append([q_p_cos.item(),q_n_cos.item()])
            

        if q_p_cos > q_n_cos:
            pred.append(support_y[i][0].item())
        else:
            pred.append(support_y[i][1].item())

        for step in range(4):
            pred_cos_tensor = torch.tensor(pred_cos, requires_grad=True)
            loss_query = loss_ce(pred_cos_tensor,torch.tensor(q_y))
            optimizer.zero_grad()
            loss_query.backward()
            optimizer.step()

        correct_num += torch.tensor(pred).eq(query_y.squeeze().cpu()).sum().item()
        if k % 100 == 0:
            print('train_acc:',correct_num/1000)
    
#----------------------------------------------val验证--------------------------------------------------------------------
    
        db_val = torch.utils.data.DataLoader(mini_val, 4, shuffle=True, num_workers=0, pin_memory=True) 
        val_correct_num = 0

        for support_x,support_y,query_x,query_y in db_val:
            support_x,support_y,query_x,query_y = support_x.to(device),support_y.to(device),query_x.to(device),query_y.to(device)
            pred_cos = list()
            q_y = list()
            pred = list()
            for i in support_y.eq(query_y).int():
                idx= torch.argmax(i)
                q_y.append(idx.item())

            for i in range(4):

                for step in range(4):
                    embedded_a, embedded_p, embedded_n = tnet(support_x[i],support_y[i])
                    loss = loss_fn(embedded_a, embedded_p, embedded_n)  #loss = criterion(anchor, positive, negative)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                embedded_query = model(query_x[i])
                q_p_cos = F.cosine_similarity(embedded_query,embedded_p)
                q_n_cos = F.cosine_similarity(embedded_query,embedded_n)
                pred_cos.append([q_p_cos.item(),q_n_cos.item()])

                if q_p_cos > q_n_cos:
                    pred.append(support_y[i][0].item())
                else:
                    pred.append(support_y[i][1].item())

            for step in range(4):
                pred_cos_tensor = torch.tensor(pred_cos, requires_grad=True)
                loss_query = loss_ce(pred_cos_tensor,torch.tensor(q_y))
                optimizer.zero_grad()
                loss_query.backward()
                optimizer.step()

            val_correct_num += torch.tensor(pred).eq(query_y.squeeze().cpu()).sum().item()
        print('val_acc:',correct_num/1000)
    

shuffle DB :train, b:1000, 2-way, 1-shot, 1-query, resize:84
shuffle DB :test, b:1000, 2-way, 1-shot, 1-query, resize:84
shuffle DB :val, b:1000, 2-way, 1-shot, 1-query, resize:84
train_acc: 0.001
val_acc: 0.001
val_acc: 0.002
val_acc: 0.003
val_acc: 0.004
val_acc: 0.005
val_acc: 0.005
val_acc: 0.006
val_acc: 0.007
val_acc: 0.008
val_acc: 0.008
val_acc: 0.009
val_acc: 0.01
val_acc: 0.012
val_acc: 0.012
val_acc: 0.012
val_acc: 0.013
val_acc: 0.014
val_acc: 0.014
val_acc: 0.015
val_acc: 0.016
val_acc: 0.016
val_acc: 0.016
val_acc: 0.017


In [None]:
from copy import deepcopy
from torch import nn

class Meta(nn.Module):
    def __init__(self, config):
        super(Meta, self).__init__()   
        self.update_lr = 0.1 ## learner中的学习率，即\alpha
        self.meta_lr = 1e-3 ## meta-learner的学习率，即\beta
        self.n_way = 2 ## 5种类型
        self.k_shot = 1 ## 一个样本
        self.k_query = 1 ## 15个查询样本
        self.update_step = 5   ## task-level inner update steps
        self.update_step_test = 5 ## 用在finetunning这个函数中
        
        self.net = Learner(config)     ## base-learner
        self.meta_optim = torch.optim.Adam(self.net.parameters(), lr = self.meta_lr)
        
    def forward(self, x_support, y_support, x_query, y_query):  # [2,5,3,84,84]
        task_num, n, c, h, w = x_support.size()   # [2,5,3,84,84]
        querysz = x_query.size(1)      ## 1
        losses_q = [0 for _ in range(self.update_step +1)] ## losses_q[i] is the loss on step i
        corrects = [0 for _ in range(self.update_step +1)]
        
            
        pred_cos = list()
        q_y = list()
        pred = list()
        for i in support_y.eq(query_y).int():
            idx= torch.argmax(i)
            q_y.append(idx.item())

        for i in range(task_num):
            for step in range(4):
                embedded_a, embedded_p, embedded_n = tnet(support_x[i],support_y[i])
                loss = loss_fn(embedded_a, embedded_p, embedded_n)  #loss = criterion(anchor, positive, negative)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            embedded_query = model(query_x[i])
            q_p_cos = F.cosine_similarity(embedded_query,embedded_p)
            q_n_cos = F.cosine_similarity(embedded_query,embedded_n)
            pred_cos.append([q_p_cos.item(),q_n_cos.item()])

            if q_p_cos > q_n_cos:
                pred.append(support_y[i][0].item())
            else:
                pred.append(support_y[i][1].item())

        for step in range(4):
            pred_cos_tensor = torch.tensor(pred_cos, requires_grad=True)
            loss_query = loss_ce(pred_cos_tensor,torch.tensor(q_y))
            optimizer.zero_grad()
            loss_query.backward()
            optimizer.step()

        correct_num += torch.tensor(pred).eq(query_y.squeeze().cpu()).sum().item()
        
        
        
        
            
            ## 第0步更新
            logits = self.net(x_support[i], vars=None, bn_training = True)   ## return 一个经过各层计算后的y
            ## logits : 5*5的tensor
            loss = F.cross_entropy(logits, y_support[i])  ## 计算Loss值
            grad = torch.autograd.grad(loss, self.net.parameters())      ##计算梯度。如果输入x，输出是y，则求y关于x的导数（梯度）
            tuples = zip(grad, self.net.parameters() ) ##将梯度grad和参数\theta一一对应起来
            ## fast_weights这一步相当于求了一个\theta - \alpha*\nabla(L)
            fast_weights = list( map(lambda p: p[1] - self.update_lr * p[0], tuples) ) ##更新子任务的theta__

            
            
            ### 在query集上进行测试，计算准确率
            ## 这一步使用的是更新前的参数
            with torch.no_grad():
                logits_q = self.net(x_query[i], self.net.parameters(), bn_training = True) ## logits_q :torch.Size([75, 5])
                loss_q = F.cross_entropy(logits_q, y_query[i]) ## y_query : torch.Size([75])
                losses_q[0] += loss_q ##将loss存在数组的第一个位置
                pred_q = F.softmax(logits_q, dim = 1).argmax(dim=1) ## size = (75)
                correct = torch.eq(pred_q, y_query[i]).sum().item()## item()取出tensor中的数字
                corrects[0] += correct
            
            ### 在query集上进行测试，计算准确率
            ## 这一步使用的是更新后的参数
            with torch.no_grad():
                logits_q = self.net(x_query[i], fast_weights, bn_training = True)   #fast_weights
                loss_q = F.cross_entropy(logits_q, y_query[i])
                losses_q[1] += loss_q
                pred_q = F.softmax(logits_q, dim = 1).argmax(dim=1)
                correct = torch.eq(pred_q, y_query[i]).sum().item()
                corrects[1] += correct
             
            
            for k in range(1, self.update_step):   #k: [1,2,3,4]
                logits = self.net(x_support[i], fast_weights, bn_training =True)
                loss = F.cross_entropy(logits, y_support[i])
                grad = torch.autograd.grad(loss, fast_weights)
                tuples = zip(grad,fast_weights)
                fast_weights = list(map(lambda p:p[1] - self.update_lr * p[0], tuples))  #更新参数
                
                if k < self.update_step - 1:  #k:[1,2,3]
                    with torch.no_grad():   
                        logits_q = self.net(x_query[i], fast_weights, bn_training = True)
                        loss_q = F.cross_entropy(logits_q, y_query[i])
                        losses_q[k+1] += loss_q
                        
                else:   #k: [4]
                    logits_q = self.net(x_query[i], fast_weights, bn_training = True)
                    loss_q = F.cross_entropy(logits_q, y_query[i])
                    losses_q[k+1] += loss_q
                
                with torch.no_grad():
                    pred_q = F.softmax(logits_q, dim=1).argmax(dim = 1)
                    correct = torch.eq(pred_q, y_query[i]).sum().item()
                    corrects[k+1] += correct
        
        
        ## 在一组task_num个任务结束后，求一个平均的loss,  残差可以吗？？？
        loss_q = losses_q[-1] / task_num     
        
        self.meta_optim.zero_grad()   ## 梯度清零
        loss_q.backward()        ## 计算梯度
        #self.meta_optim = torch.optim.Adam(self.net.parameters(), lr = self.meta_lr)
        self.meta_optim.step()   ## 用设置好的优化方法来迭代模型参数，这一步是meta步迭代, 更新多了 会不会过拟合???????????
        
        accs = np.array(corrects) / (querysz * task_num)   #corrects/75*task
        
        return accs
        
    
    def finetunning(self, x_support, y_support, x_query, y_query):
        assert len(x_support.shape) == 4
        
        querysz = x_query.size(0)
        
        corrects = [0 for _ in range(self.update_step_test + 1)]
        
        # in order to not ruin the state of running_mean/variance and bn_weight/bias
        # we finetunning on the copied model instead of self.net
        net = deepcopy(self.net)
        
        logits = net(x_support)
        loss = F.cross_entropy(logits, y_support)
        grad = torch.autograd.grad(loss, net.parameters())
        fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, net.parameters())))
        
        
        ## 开始训练前的准确率
        with torch.no_grad():
            logits_q = net(x_query, net.parameters(), bn_training = True)
            pred_q = F.softmax(logits_q, dim =1).argmax(dim=1)
            correct = torch.eq(pred_q, y_query).sum().item()
            corrects[0] += correct
         
        ## 训练后的准确率
        with torch.no_grad():
            logits_q = net(x_query, fast_weights, bn_training = True)
            pred_q = F.softmax(logits_q, dim = 1).argmax(dim=1)
            correct = torch.eq(pred_q, y_query).sum().item()
            corrects[1] += correct
            
        for k in range(1, self.update_step_test):
            logits = net(x_support, fast_weights, bn_training=True)
            loss = F.cross_entropy(logits, y_support)
            grad = torch.autograd.grad(loss, fast_weights)
            fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0],   zip(grad, fast_weights)))
            
            logits_q = net(x_query, fast_weights, bn_training=True)
            loss_q = F.cross_entropy(logits_q, y_query)
            
            with torch.no_grad():
                pred_q = F.softmax(logits_q, dim =1).argmax(dim=1)
                correct = torch.eq(pred_q, y_query).sum().item()
                corrects[k+1] += correct
                
        del net
        
        accs = np.array(corrects) / querysz
        
        return accs