In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from typing import List, Tuple, Iterable


In [None]:
class SingleTaskDataset(Dataset[tuple[torch.Tensor, torch.Tensor]]):
    # create a dataset of sequences of length `n_control_bits` + `n_task_bits`
    # the sequences are bit strings. The first `n_control_bits` bits describe the task: 
    # they are zero everywhere except for the `task_num`-th bit, which is 1.
    # The next `n_task_bits` bits are random.
    # The target is the parity of the relevant variables, which are at the indices of the task bits
    # given in `relevant_vars`.
        def __init__(self,
                     n_task_bits: int,
                     n_control_bits: int,
                     task_num: int,
                     relevant_vars: torch.Tensor,
                     dataset_length: int):

            assert len(relevant_vars.shape) == 1
            assert relevant_vars.dtype == torch.int64
            assert relevant_vars.shape[0] <= n_task_bits
            assert all([0 <= i < n_task_bits for i in relevant_vars])
            self.data = torch.zeros(dataset_length, n_control_bits + n_task_bits)
            self.task_bits = torch.randint(0, 2, (dataset_length, n_task_bits), dtype=torch.float32)
            self.data[:, n_control_bits + 1:] = self.task_bits
            self.data[:, task_num] = 1.
            self.dataset_length = dataset_length
            
            self.relevant_vars = relevant_vars

        def __len__(self):
            return self.dataset_length

        def __getitem__(self, idx: int | slice):
            x = self.data[idx]
            y = self.task_bits[idx, self.relevant_vars].sum() % 2
            return x, y

class MultiTaskDataset(Dataset[tuple[torch.Tensor, torch.Tensor]]):
    """Group of `SingleTaskDataset`s, one per entry of `relevant_vars`."""

    def __init__(self,
                 n_task_bits: int,
                 n_control_bits: int,
                 relevant_vars: List[torch.Tensor],
                 dataset_length_per_task: int | list[int]):
        """Create one `SingleTaskDataset` for each entry of `relevant_vars`.

        Args:
            n_task_bits: forwarded to every `SingleTaskDataset`.
            n_control_bits: forwarded to every `SingleTaskDataset`.
            relevant_vars: one tensor per task; entry `i` is passed to the
                dataset built with `task_num=i`.
            dataset_length_per_task: either a single length that is repeated
                `n_control_bits` times, or a list that must contain exactly
                `n_control_bits` lengths.
        """
        # Normalise the length argument to a list with `n_control_bits` entries.
        if isinstance(dataset_length_per_task, list):
            assert len(dataset_length_per_task) == n_control_bits
            data_set_lengths = dataset_length_per_task
        else:
            data_set_lengths = [dataset_length_per_task] * n_control_bits
        
        # One dataset per task, each with its own length.
        # NOTE(review): this iterates over len(relevant_vars) while
        # `data_set_lengths` has n_control_bits entries; if relevant_vars is
        # longer, data_set_lengths[i] goes out of range — confirm the two
        # are meant to be equal.
        self.datasets = [SingleTaskDataset(n_task_bits, n_control_bits, i, relevant_vars[i], data_set_lengths[i]) for i in range(len(relevant_vars))]
        
        # NOTE(review): the next statement overwrites the list built just
        # above. `dataset_length` is neither a parameter nor a local of this
        # method, so unless a module/notebook-level name `dataset_length`
        # exists these two lines raise NameError. They look like leftovers
        # from an earlier version — TODO confirm and remove or fix.
        self.datasets = [SingleTaskDataset(n_task_bits, n_control_bits, i, relevant_vars[i], dataset_length) for i in range(len(relevant_vars))]
        self.dataset_length = dataset_length