In [None]:
# Locations of the image directory and of the label CSV file for each split.
img_root_dir = "./images"
train_csv_path = "./label/train.csv"
val_csv_path = "./label/val.csv"
test_csv_path = "./label/test.csv"
import os

import numpy as np
import pandas as pd
import collections

In [None]:
def read_csv(csv_path):
    """Group the image filenames listed in a label CSV by their class label.

    Args:
        csv_path: Path to a CSV file containing a ``label`` column and a
            ``filename`` column.

    Returns:
        A ``collections.defaultdict(list)`` mapping each label to the list of
        filenames carrying that label. Labels appear in order of first
        occurrence and filenames keep their order in the file.
    """
    # Deliberately not named `dict`, so the built-in stays usable.
    files_by_label = collections.defaultdict(list)
    df = pd.read_csv(csv_path)
    # Walk the two columns directly; `iterrows` would build a Series per row.
    for label, filename in zip(df["label"], df["filename"]):
        files_by_label[label].append(filename)
    return files_by_label
# Label -> filenames mapping for each of the three splits.
train_dict = read_csv(train_csv_path)
val_dict = read_csv(val_csv_path)
test_dict = read_csv(test_csv_path)

In [None]:
from PIL import Image
import numpy as np
from torchvision import transforms
from PIL import ImageFile
# Let PIL decode image files whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Shrink images up front to save memory: the shorter edge is scaled to 84.
resize_transform = transforms.Resize(84)
def build_data(data_dict):
    """Load every image named in ``data_dict`` into memory, pre-resized.

    Args:
        data_dict: Mapping from class label to a list of image filenames,
            each relative to ``img_root_dir``.

    Returns:
        A dict with two parallel lists: ``"datas"`` holds the resized RGB PIL
        images and ``"labels"`` holds the integer class index of each image.
        Class indices are assigned 0, 1, 2, ... in the iteration order of
        ``data_dict``.
    """
    datas = []
    labels = []
    for label_index, label in enumerate(data_dict):  # one class per label
        for path in data_dict[label]:  # every file belonging to that class
            img_path = os.path.join(img_root_dir, path)
            # `Image.open` is lazy and keeps the file open; decode inside a
            # `with` block so the handle is always released. `convert` forces
            # the pixel data to be read and returns an independent image, and
            # guarantees 3 channels for the 3-channel normalisation applied
            # later (grayscale/CMYK files would otherwise fail there).
            with Image.open(img_path) as raw_img:
                img = raw_img.convert("RGB")
            datas.append(resize_transform(img))
            labels.append(label_index)
    return {"datas": datas, "labels": labels}


In [None]:
# Load and pre-resize all images of each split into memory once, up front.
train_data = build_data(train_dict)
val_data = build_data(val_dict)
test_data = build_data(test_dict)

In [None]:
from torch.utils.data import Dataset,DataLoader
import torch
class CategoriesSampler():
    """Batch sampler that yields the dataset indices of one few-shot episode.

    Each batch holds ``K_way * N_per`` indices: ``K_way`` randomly chosen
    classes, each contributing ``N_per`` randomly chosen samples. The indices
    are laid out class by class, i.e. the first ``N_per`` entries belong to
    the first sampled class, the next ``N_per`` to the second, and so on.
    """

    def __init__(self, data, n_batch, K_way, N_per):
        """Index the samples of every class.

        Args:
            data: Dict whose ``"labels"`` entry lists the integer class index
                of every sample, e.g. ``[0, 0, 0, 1, 1, 1, 2, 2, ...]``.
            n_batch: Number of episodes produced by one pass over the sampler.
            K_way: Number of classes per episode.
            N_per: Number of samples drawn per class (support + query).

        Raises:
            ValueError: If there are no labels, fewer than ``K_way`` classes,
                or a class with fewer than ``N_per`` samples. Without this
                check the failure only shows up mid-iteration, when
                ``torch.stack`` receives tensors of different lengths.
        """
        self.n_batch = n_batch
        self.K_way = K_way
        self.N_per = N_per
        labels = np.asarray(data["labels"])
        if labels.size == 0:
            raise ValueError("data['labels'] is empty; nothing to sample from")

        # self.index[c] holds the dataset positions of all samples of class c.
        n_classes = int(labels.max()) + 1
        self.index = [
            torch.from_numpy(np.flatnonzero(labels == c))
            for c in range(n_classes)
        ]

        if K_way > n_classes:
            raise ValueError(
                "K_way={} exceeds the number of classes ({})".format(
                    K_way, n_classes))
        smallest_class = min(len(positions) for positions in self.index)
        if N_per > smallest_class:
            raise ValueError(
                "N_per={} exceeds the size of the smallest class ({})".format(
                    N_per, smallest_class))

    def __len__(self):
        """Return the number of episodes per pass."""
        return self.n_batch

    def __iter__(self):
        """Yield ``n_batch`` 1-D index tensors of length ``K_way * N_per``."""
        for _ in range(self.n_batch):
            episode = []
            # Randomly pick the K classes forming the support and query sets.
            classes = torch.randperm(len(self.index))[:self.K_way]
            for c in classes:
                positions = self.index[c]  # e.g. [5, 6, 7, 8, 9]
                chosen = torch.randperm(len(positions))[:self.N_per]  # e.g. [4, 1, 0]
                episode.append(positions[chosen])  # e.g. [9, 6, 5]

            yield torch.stack(episode).reshape(-1)

In [None]:
class MiniImageNet(Dataset):
    """In-memory dataset pairing preloaded images with integer class labels."""

    def __init__(self, data):
        """Store the samples and build the per-item preprocessing pipeline.

        Args:
            data: Dict with parallel lists ``"datas"`` (images) and
                ``"labels"`` (class index of each image).
        """
        self.datas, self.labels = data["datas"], data["labels"]
        # Centre-crop to 84x84, convert to a tensor, then normalise each of
        # the three channels with a fixed mean and standard deviation.
        pipeline_steps = [
            transforms.CenterCrop(84),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]
        self.transform = transforms.Compose(pipeline_steps)

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.datas)

    def __getitem__(self, i):
        """Return ``(transformed image, label)`` for sample ``i``."""
        sample = self.datas[i]
        target = self.labels[i]
        return self.transform(sample), target

In [None]:

import torch.nn as nn
import torch
import numpy as np
import torch.nn.functional as F
import torch.optim as optim


class CNN_Net(nn.Module):
    """Convolutional feature extractor: four conv blocks, then a flatten."""

    def __init__(self, input_dim):
        """Build the encoder.

        Args:
            input_dim: Number of channels of the input images.
        """
        super().__init__()
        self.input_dim = input_dim

        def make_block(channels_in, channels_out):
            # 3x3 conv (padding keeps the spatial size) -> batch norm -> ReLU
            # -> 2x2 max-pool.
            layers = [
                nn.Conv2d(channels_in, channels_out, 3, padding=1),
                nn.BatchNorm2d(channels_out),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ]
            return nn.Sequential(*layers)

        # Channel widths along the encoder; consecutive pairs define a block.
        widths = [input_dim, 64, 64, 64, 64]
        blocks = [
            make_block(c_in, c_out)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        self.encoder = nn.Sequential(*blocks)

    def forward(self, x):
        """Encode a batch of images and flatten each result to a vector."""
        features = self.encoder(x)
        batch_size = features.size(0)
        return features.view(batch_size, -1)


class Prototypicl_Net():
    """Prototypical-network trainer and evaluator for few-shot classification.

    Bundles the ``CNN_Net`` encoder with its optimizer, learning-rate
    scheduler and the train/val/test datasets built from the module-level
    ``train_data``, ``val_data`` and ``test_data``.
    """

    def __init__(self, input_dim):
        """Create the encoder, datasets, optimizer and scheduler.

        Args:
            input_dim: Number of channels of the input images.
        """
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        # Encoder output size for 84x84 inputs (64 channels * 5 * 5). Kept as
        # an attribute for compatibility; the distance computation reads the
        # embedding size from the tensors themselves.
        self.z_dim = 1600
        self.cnn_net = CNN_Net(input_dim).to(self.device)

        self.train_dataset = MiniImageNet(train_data)
        self.val_dataset = MiniImageNet(val_data)
        self.test_dataset = MiniImageNet(test_data)

        # Optimizer; the scheduler halves the learning rate on every
        # `scheduler.step()`, which `train` calls once per epoch.
        self.optimizer = torch.optim.Adam(self.cnn_net.parameters(), lr=0.001)
        self.scheduler = optim.lr_scheduler.StepLR(
            self.optimizer, 1, gamma=0.5, last_epoch=-1)
        print("使用：", self.device)

    def cal_euc_distance(self, query_z, center, K_way, N_query):
        """Squared Euclidean distance between every query and every prototype.

        Args:
            query_z: Query embeddings, shape ``(K_way*N_query, z_dim)``.
            center: Class prototypes, shape ``(K_way, z_dim)``.
            K_way: Number of classes (kept for interface compatibility).
            N_query: Queries per class (kept for interface compatibility).

        Returns:
            Tensor of shape ``(K_way*N_query, K_way)``.
        """
        # Broadcasting (Q, 1, D) - (1, K, D) -> (Q, K, D) works for any
        # embedding size instead of assuming the hard-coded `self.z_dim`.
        diff = query_z.unsqueeze(1) - center.unsqueeze(0)
        return torch.pow(diff, 2).sum(2)

    def loss_acc(self, query_z, center, K_way, N_query):
        """Compute the episode loss and accuracy.

        Args:
            query_z: Query embeddings, shape ``(K_way*N_query, z_dim)``,
                ordered class by class.
            center: Class prototypes, shape ``(K_way, z_dim)``.
            K_way: Number of classes in the episode.
            N_query: Number of query samples per class.

        Returns:
            ``(loss, acc)``: the mean negative log-probability of the true
            class under a softmax over negative distances, and the fraction
            of queries whose nearest prototype is their own class.
        """
        # Row k of target_inds is filled with k: shape (K_way, N_query).
        target_inds = torch.arange(0, K_way).view(K_way, 1).expand(
            K_way, N_query).long().to(query_z.device)

        distance = self.cal_euc_distance(query_z, center, K_way, N_query)
        predict_label = torch.argmin(distance, dim=1)  # (K_way*N_query,)

        acc = torch.eq(target_inds.contiguous().view(-1),
                       predict_label).float().mean()

        log_prob = F.log_softmax(-distance, dim=1).view(
            K_way, N_query, K_way)  # (K_way, N_query, K_way)
        loss = -log_prob.gather(
            dim=2, index=target_inds.unsqueeze(2)).view(-1).mean()
        return loss, acc

    def set_forward_loss(self, K_way, N_shot, N_query, sample_datas):
        """Run one episode through the encoder and score it.

        Args:
            K_way: Number of classes in the episode.
            N_shot: Support samples per class.
            N_query: Query samples per class.
            sample_datas: Batch of ``K_way*(N_shot+N_query)`` inputs ordered
                class by class; within a class the first ``N_shot`` samples
                are used as support and the rest as queries.

        Returns:
            ``(loss, acc)`` as produced by ``loss_acc``.
        """
        # Embed support and query samples together, then regroup per class.
        z = self.cnn_net(sample_datas)  # (K_way*(N_shot+N_query), z_dim)
        z = z.view(K_way, N_shot + N_query, -1)

        support_z = z[:, :N_shot]  # (K_way, N_shot, z_dim)
        query_z = z[:, N_shot:].contiguous().view(K_way * N_query, -1)

        # A class prototype is the mean of its support embeddings.
        center = torch.mean(support_z, dim=1)  # (K_way, z_dim)
        return self.loss_acc(query_z, center, K_way, N_query)

    def train(self, epochs, epoch_size):
        """Train the encoder episodically and report accuracy after each epoch.

        Every epoch runs ``epoch_size`` episodes of 20-way, 5-shot
        classification with 15 queries per class, steps the scheduler, then
        prints 5-way 1-shot and 5-shot accuracy (600 episodes each) on the
        validation and test splits.

        Args:
            epochs: Number of epochs.
            epoch_size: Number of training episodes per epoch.
        """
        K_way = 20
        N_shot = 5
        N_query = 15

        self.cnn_net.train()
        train_sampler = CategoriesSampler(
            train_data, epoch_size, K_way, N_shot + N_query)

        train_loader = DataLoader(dataset=self.train_dataset, batch_sampler=train_sampler,
                                  num_workers=16, pin_memory=True)

        for epoch in range(epochs):
            for batch_data in train_loader:
                # Labels from the loader are not needed: the class of each
                # sample is implied by its position in the episode.
                imgs = batch_data[0].to(self.device)
                loss, acc = self.set_forward_loss(K_way, N_shot, N_query, imgs)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            self.scheduler.step()

            val_acc1 = self.eval_model(val_data, self.val_dataset, 5, 1, 15, 600)
            val_acc2 = self.eval_model(val_data, self.val_dataset, 5, 5, 15, 600)

            test_acc1 = self.eval_model(test_data, self.test_dataset, 5, 1, 15, 600)
            test_acc2 = self.eval_model(test_data, self.test_dataset, 5, 5, 15, 600)

            print("验证集：1-shot：{:.4},5-shot：{:.4}".format(val_acc1, val_acc2))

            print("测试集：1-shot：{:.4},5-shot：{:.4}".format(test_acc1, test_acc2))

    def eval_model(self, datas, data_set, K_way, N_shot, N_query, eval_step):
        """Return the mean episode accuracy over ``eval_step`` episodes.

        Args:
            datas: Dict with a ``"labels"`` list, used to build the sampler.
            data_set: Dataset the sampled indices refer to.
            K_way: Number of classes per episode.
            N_shot: Support samples per class.
            N_query: Query samples per class.
            eval_step: Number of evaluation episodes.

        Returns:
            Sum of the per-episode accuracies divided by ``eval_step``.
        """
        self.cnn_net.eval()
        batch_sampler = CategoriesSampler(
            datas, eval_step, K_way, N_shot + N_query)
        data_loader = DataLoader(dataset=data_set, batch_sampler=batch_sampler,
                                 num_workers=16, pin_memory=True)
        accs = []
        # No gradients are needed for evaluation; skipping autograd avoids
        # building a graph for every episode.
        with torch.no_grad():
            for batch_data in data_loader:
                imgs = batch_data[0].to(self.device)
                _, acc = self.set_forward_loss(K_way, N_shot, N_query, imgs)
                accs.append(acc.item())
        # Hand the model back in training mode, as the training loop expects.
        self.cnn_net.train()
        return sum(accs) / eval_step

In [None]:
# Build the model for inputs with 3 channels, then train for 50 epochs of
# 2000 episodes each.
net = Prototypicl_Net(3)
net.train(50, 2000)