In [19]:
import torch
import numpy as np
import learn2learn as l2l
from learn2learn.data import *
import import_ipynb
import utils

In [49]:
class KShotLoader():
    """Sample N-way episodes from a labelled dataset using learn2learn.

    Each sampled task contains ``2 * shots`` examples for each of ``ways``
    classes; ``get_task`` splits those into an adaptation set and an
    evaluation set.

    Args:
        myds: Dataset wrapped in ``l2l.data.MetaDataset``.
        num_tasks: Forwarded unchanged to ``l2l.data.TaskDataset``.
        shots: Number of examples per class placed in EACH of the
            adaptation and evaluation splits (``2 * shots`` are drawn).
        ways: Number of classes per task.
        classes: Iterable of labels tasks may be drawn from.  Defaults to
            every label found by the MetaDataset (sorted).

    Raises:
        ValueError: if ``ways`` is larger than the number of candidate
            classes.
    """

    def __init__(self, myds, num_tasks=1000, shots=2, ways=2, classes=None):
        self.shots = shots
        self.ways = ways
        self.myMds = l2l.data.MetaDataset(myds)
        if classes is None:
            # Use the labels actually present in the data instead of assuming
            # they are exactly 0..n-1.
            classes = sorted(self.myMds.labels_to_indices.keys())
        else:
            classes = list(classes)
        if ways > len(classes):
            # Reject impossible configurations up front with a clear message.
            raise ValueError(
                f"ways={ways} exceeds the {len(classes)} available classes")
        self.my_tasks = l2l.data.TaskDataset(self.myMds, task_transforms=[
            l2l.data.transforms.FilterLabels(self.myMds, classes),
            l2l.data.transforms.NWays(self.myMds, ways),
            # Draw twice the shots so get_task can split each task in two.
            l2l.data.transforms.KShots(self.myMds, 2 * shots),
            l2l.data.transforms.LoadData(self.myMds),
            l2l.data.transforms.RemapLabels(self.myMds),
            l2l.data.transforms.ConsecutiveLabels(self.myMds),
        ], num_tasks=num_tasks)

    def get_task(self):
        """Sample one task and split it into adaptation and evaluation sets.

        Returns:
            ``(d_train, d_test)``, each a ``(data, labels)`` tuple.
            ``d_train`` holds the samples at the first ``shots * ways`` even
            positions (0, 2, 4, ...); ``d_test`` holds all the others.
        """
        data, labels = self.my_tasks.sample()
        # Alternating selection: assuming the samples of each class are
        # contiguous (2*shots of them), every class contributes `shots`
        # samples to each side of the split. TODO confirm ordering.
        adaptation_mask = torch.zeros(data.size(0), dtype=torch.bool)
        adaptation_mask[torch.arange(self.shots * self.ways) * 2] = True
        evaluation_mask = ~adaptation_mask
        d_train = (data[adaptation_mask], labels[adaptation_mask])
        d_test = (data[evaluation_mask], labels[evaluation_mask])
        return d_train, d_test

In [50]:
# Build train/test datasets via the project-local `utils` module; the third
# return value is discarded.
meta_train_ds, meta_test_ds, _ = utils.euclideanDataset(
    n_samples=100,
    n_features=20,
    n_classes=10,
    batch_size=32,
)

In [51]:
# Two disjoint label ranges, named for meta-training (0-4) and meta-testing (5-9).
# NOTE(review): neither list is passed to the KShotLoader cell below.
classes_train = list(range(5))
classes_test = list(range(5, 10))
classes_train, classes_test

([0, 1, 2, 3, 4], [5, 6, 7, 8, 9])

In [56]:
# NOTE(review): `classes` is not passed, so the loader's default class set is
# used and classes_train is unused. Confirm this is intended.
kloader = KShotLoader(myds=meta_train_ds, ways=10)

In [57]:
# Sample a task: d_train is the adaptation split, d_test the evaluation split.
(d_train, d_test) = kloader.get_task()

In [58]:
# Display the adaptation split from get_task as a (data, labels) tuple.
d_train

(tensor([[-2.2269e+00, -3.2695e+00,  5.4442e+00, -4.8867e+00,  4.2291e+00,
           3.5576e+00,  2.5956e+00, -2.0264e+00,  4.8915e-01,  1.2715e+00,
           5.1616e+00,  3.5798e+00, -9.4333e-01,  1.2588e+00,  1.4621e+00,
          -1.2572e+00, -4.5453e+00, -3.3953e+00,  8.0548e+00,  3.9462e+00],
         [-5.9133e+00, -4.5788e+00,  3.7464e+00,  8.7112e+00,  1.8168e+00,
           8.5292e-02,  3.3876e+00,  1.7877e+00,  2.8035e+00,  5.6987e+00,
          -4.9392e+00, -3.6449e-01, -1.0126e+00, -3.3236e+00, -1.1476e+00,
           1.9044e+00, -2.6974e+00,  1.4532e+00, -1.0257e+00,  5.9243e+00],
         [ 7.3079e-01,  3.6682e+00, -4.1600e+00,  1.5629e+00, -4.9982e+00,
          -4.8190e+00,  1.0289e+00,  2.0922e+00,  4.0064e+00, -5.1631e+00,
          -6.5307e+00,  2.1510e+00,  4.1438e-01, -4.7272e+00,  1.4923e-01,
          -4.8651e+00,  1.4452e-01, -6.5634e-01,  8.1033e-01, -1.9831e+00],
         [ 4.3463e+00,  4.3864e+00, -5.6252e+00,  4.3318e+00, -7.6635e-01,
          -2.6735e+00,

In [59]:
# Draw another task from the same loader, replacing the previous split.
(d_train, d_test) = kloader.get_task()

In [17]:
# NOTE(review): the output saved under this cell (execution count 17) was
# recorded before the kloader cell (count 56). It shows 4 samples labelled
# [1, 1, 0, 0], whereas the 10-way, 2-shot kloader yields 20 adaptation
# samples. Restart & Run All to refresh it.
d_train

(tensor([[ 3.6960,  3.6126,  1.0981, -0.9689,  0.7627,  6.3903,  0.9497, -5.2207,
          -1.1006, -1.5891,  2.8220, -0.1910, -1.8938, -1.2033, -4.1035,  3.5602,
           2.1527, -0.2454, -3.7900,  2.2128],
         [ 2.4989,  1.0827,  3.5661,  1.1853,  1.2567,  1.1357,  7.4111, -0.7032,
           0.9706, -3.9028, -1.4625,  4.8641,  7.0013, -2.6186, -0.7804, -0.9068,
          -0.7404, -6.4392, -2.4066,  1.6961],
         [-2.5684,  0.5440, -2.7641,  1.0810, -1.3895,  2.2740, -2.4515,  1.1156,
           1.1550, -0.2270, -1.5329,  7.2667, -1.6682, -0.8797, -4.6239, -1.6831,
          -1.4690,  1.2760, -1.6313,  0.0796],
         [-3.3656,  2.2880, -2.2261,  3.9417,  2.2327,  5.2604, -1.2766,  1.7004,
           4.9334,  4.1778,  3.3986,  2.3361,  1.3132, -0.3988, -3.8092,  3.8010,
          -5.0442,  0.8580, -3.2479,  3.5913]]),
 tensor([1, 1, 0, 0]))