In [11]:
import numpy as np
import torch
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.uniform import Uniform

In [None]:
class LowDimLinearRep:
    """Skeleton of a model that maps inputs to a low-dimensional representation
    and then applies a linear predictor on top of it.

    Args:
        input_d: dimension of the raw input x.
        feature_d: dimension of the representation produced by `representation`.
        output_d: dimension of the prediction.
        n_source_tasks: number of source tasks.
        n1: stored as-is; presumably the number of samples per source task
            (TODO confirm).
        n2: stored as-is; presumably the number of target-task samples
            (TODO confirm).
    """

    def __init__(self, input_d, feature_d, output_d, n_source_tasks, n1, n2):
        self.input_d = input_d
        self.feature_d = feature_d
        self.output_d = output_d
        self.n_source_tasks = n_source_tasks
        self.n1 = n1
        self.n2 = n2

    def low_dim_check(self):
        """Check that the representation is strictly lower-dimensional than the input.

        Raises:
            AssertionError: if ``self.feature_d >= self.input_d``.
        """
        # An explicit raise (instead of an `assert` statement) keeps the check
        # active under `python -O`; the exception type is unchanged so callers
        # catching AssertionError keep working.
        if not self.feature_d < self.input_d:
            raise AssertionError("feature dimension should be less than input dimension")

    def representation(self, x):
        """Map input x to a low-dimensional representation of dim self.feature_d.

        Placeholder: not implemented yet, currently returns None.
        """
        pass

    def prediction(self, phi):
        """Linear prediction from the representation phi of x.

        Placeholder: not implemented yet, currently returns None.
        """
        pass

In [None]:
class Simulation:
    """Synthetic-data simulator for the low-dimensional linear-representation setting.

    Args:
        input_d: dimension d of the inputs.
        feature_d: dimension k of the ground-truth representation / task heads.
    """

    def __init__(self, input_d, feature_d):
        # Both attributes are read by the sampling methods below
        # (generate_data uses input_d, specialization_t uses feature_d).
        self.input_d = input_d
        self.feature_d = feature_d

    def generate_data(self, n, rho):
        """Generate n input samples that are rho^2-subgaussian (assumption 4.1).

        Each sample is a vector of d = self.input_d independent N(0, rho^2)
        coordinates; a vector with independent N(0, rho^2) entries is
        rho^2-subgaussian.

        Args:
            n: number of samples.
            rho: standard deviation of every coordinate (must be >= 0).

        Returns:
            torch.Tensor of shape (n, self.input_d), one sample per row.
        """
        d = self.input_d
        return torch.normal(0.0, rho, size=(n, d))

    def representation(self, x):
        """Ground-truth representation function.

        Placeholder: not implemented yet, currently returns None.
        """
        pass

    def prediction(self, phi):
        """Ground-truth prediction function.

        Placeholder: not implemented yet, currently returns None.
        """
        pass

    def specialization_t(self, A3=True, A4=True, target_task=False):
        """Sample a task-specific head w_t ~ N(0, Sigma) (for assumption 4.3).

        Sigma = diag(ev) with k = self.feature_d evenly spaced eigenvalues:
        - A3=True:  ev spans [1, 10], so max/min eigenvalue = 10, i.e. O(1) in k.
        - A3=False: ev spans [1, 10*k], so max/min eigenvalue grows with k.
        If both A4 and target_task are True, every eigenvalue is divided by k.

        Returns:
            (w_t, ev_max, ev_min): w_t is a tensor of shape (k,); ev_max and
            ev_min are the maximum and minimum eigenvalues of Sigma (after any
            division by k).
        """
        k = self.feature_d
        if A3:
            ev = np.linspace(1.0, 10.0, k)
        else:
            ev = np.linspace(1.0, 10.0*k, k)

        if A4 and target_task:
            ev /= k

        mean = torch.zeros(k)
        cov = torch.diag(torch.tensor(ev.astype(np.float32)))

        p_t = MultivariateNormal(mean, covariance_matrix=cov)
        w_t = p_t.sample()

        ev_max = np.max(ev)
        ev_min = np.min(ev)
        return w_t, ev_max, ev_min

    def noise(self, sigma, n):
        """Simulate a Gaussian noise vector of length n with std sigma.

        Returns a numpy array (not a torch tensor, unlike the other samplers).
        """
        return np.random.normal(0, sigma, n)