In [1]:
import os
import random
import uuid
from collections import defaultdict
from timeit import default_timer as timer

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger
from torch.distributions.normal import Normal
# from torch.utils.tensorboard import SummaryWriter

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import os, pickle

  from .autonotebook import tqdm as notebook_tqdm


In [9]:
import torch
from torchmeta.toy import Sinusoid
from torchmeta.transforms import ClassSplitter
from torchmeta.utils.data import BatchMetaDataLoader


class ToTensor1D(object):
    """Wrap a `numpy.ndarray` of any rank into a `torch.Tensor`.

    torchvision's `ToTensor` expects image-like input. This transform instead
    accepts arrays with an arbitrary number of dimensions. The array is cast
    to `float32` before conversion. `torch.tensor` copies its input, so the
    result does not share memory with `array`.
    """

    def __call__(self, array):
        as_float32 = array.astype("float32")
        return torch.tensor(as_float32)

    def __repr__(self):
        return f"{type(self).__name__}()"


def get_sine_loader(batch_size, num_steps, shots=10, test_shots=15):
    """Build a meta-batch loader over randomly generated sinusoid tasks.

    Each task draws `shots + test_shots` samples. `ClassSplitter` then divides
    them into a "train" split of `shots` points and a "test" split of
    `test_shots` points. The dataset contains `batch_size * num_steps` tasks
    and is served in shuffled batches of `batch_size` tasks. Inputs and
    targets are converted to float32 tensors with `ToTensor1D`.
    """
    to_tensor = ToTensor1D()
    splitter = ClassSplitter(
        shuffle=True,
        num_train_per_class=shots,
        num_test_per_class=test_shots,
    )
    sine_tasks = Sinusoid(
        shots + test_shots,
        num_tasks=num_steps * batch_size,
        transform=to_tensor,
        target_transform=to_tensor,
        dataset_transform=splitter,
    )
    return BatchMetaDataLoader(
        sine_tasks,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
    )

In [10]:
def get_loader(task, batch_size, num_steps):
    """Return the task loader registered under the name `task`.

    Only "sine" is implemented; it delegates to `get_sine_loader` with the
    default shot counts. Any other name raises ValueError.
    """
    if task != "sine":
        raise ValueError(f"task={task} is not implemented")
    return get_sine_loader(batch_size=batch_size, num_steps=num_steps)


def get_task(saved, task, batch_size, num_steps, cache_dir="data/saved"):
    """Return meta-learning tasks for `task`, optionally cached on disk.

    Args:
        saved: If False, return a fresh loader from `get_loader` (nothing is
            cached). If True, materialise the loader's batches into a list
            and cache them as a pickle. Later calls with the same
            (task, batch_size, num_steps) then reuse the identical tasks.
        task: Task name forwarded to `get_loader`.
        batch_size: Number of tasks per batch.
        num_steps: Number of batches to generate.
        cache_dir: Directory holding the pickle cache. Defaults to the
            previously hard-coded "data/saved".

    Returns:
        The loader itself when `saved` is False. Otherwise, a list of the
        batches it yields, read from the cache file when one exists.
    """
    if not saved:
        return get_loader(task, batch_size, num_steps)

    os.makedirs(cache_dir, exist_ok=True)
    filename = os.path.join(cache_dir, f"{task}_{batch_size}_{num_steps}.pkl")

    if os.path.exists(filename):
        # SECURITY: pickle.load can execute arbitrary code; only load cache
        # files produced by this notebook, never files from untrusted sources.
        with open(filename, "rb") as handle:
            return pickle.load(handle)

    tasks = list(get_loader(task=task, batch_size=batch_size, num_steps=num_steps))

    # Dump to a temporary file and atomically rename it into place. An
    # interrupted dump then never leaves a truncated .pkl behind that every
    # later run would try (and fail) to load.
    tmp_filename = f"{filename}.{os.getpid()}.tmp"
    try:
        with open(tmp_filename, "wb") as handle:
            pickle.dump(tasks, handle, protocol=pickle.HIGHEST_PROTOCOL)
        os.replace(tmp_filename, filename)
    finally:
        # After a successful os.replace the temp file is gone; this only
        # cleans up after a failed dump.
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)
    return tasks


In [11]:
# Materialise the evaluation tasks and cache them on disk. Later runs then
# reload the same 250 batches of 64 sine tasks instead of sampling new ones.
test_tasks = get_task(saved=True, task="sine", batch_size=64, num_steps=250)