In [170]:
import numpy as np
from scipy.stats import skewnorm, skew
import matplotlib.pyplot as plt
import pandas as pd


# Simulation parameters. The global NumPy RNG is seeded first because
# simulate_data draws all of its random values from it.
np.random.seed(12345)

classes = [0.2, 0.5, 0.3]  # share of observations in each class
n_vars = 10                # number of feature columns
n = 1000                   # total number of observations requested
max_mu = 20                # bound for the random integer class means
max_sigma = 10             # bound for the random integer scales
max_skew = 7               # bound for the random integer skewnorm shape params


def simulate_data(classes, n_vars, n, max_mu, max_sigma, max_skew):
    """Simulate a labelled dataset whose classes differ only in their locations.

    All classes share one per-variable scale (``sigma``) and one per-variable
    skewnorm shape parameter (``skew``). Each class gets its own vector of
    locations (``mu``). Every column is drawn independently from
    ``scipy.stats.skewnorm``.

    All randomness comes from NumPy's global legacy RNG (``np.random.randint``
    and ``skewnorm.rvs`` without ``random_state``). Call ``np.random.seed``
    beforehand for reproducible results. The per-class row counts are printed.

    Parameters
    ----------
    classes : sequence of float
        Proportion of observations in each class. Must be non-negative and sum
        to 1, up to floating-point tolerance.
    n_vars : int
        Number of feature columns.
    n : int
        Requested total number of observations. Class ``i`` gets
        ``round(classes[i] * n)`` rows, so the total can differ from ``n``
        through rounding. For example, three classes of 1/3 with n=1000
        give 999 rows.
    max_mu : int
        Locations are random integers in ``[-max_mu, max_mu)``.
    max_sigma : int
        Scales are random integers in ``[1, max_sigma)``; must be > 1.
    max_skew : int
        Shape parameters are random integers in ``[-max_skew, max_skew)``.
        NOTE(review): the range is asymmetric because the upper bound is
        exclusive. This is kept as-is so seeded results do not change.

    Returns
    -------
    X : ndarray of shape (n_rows, n_vars)
        Feature matrix, rows grouped by class in order 0, 1, ...
    y : ndarray of shape (n_rows,)
        Class index of each row, stored as float.
    df : pandas.DataFrame
        Columns ``class, x1, ..., x<n_vars>`` holding the same values as
        ``y`` and ``X``. ``X`` and ``y`` are independent copies.

    Raises
    ------
    ValueError
        If the proportions do not sum to 1 or any proportion is negative.
    """
    def rng(mu, sigma, skew, n=1):
        """Draw ``n`` rows; column i ~ skewnorm(skew[i], loc=mu[i], scale=sigma[i])."""
        k = len(mu)
        if not (k == len(sigma) == len(skew)):
            raise ValueError("Mu, Sigma and Skew should be same length")

        n = int(n)
        samples = np.zeros((n, k))
        for i in range(k):
            samples[:, i] = skewnorm.rvs(skew[i], loc=mu[i], scale=sigma[i], size=n)
        return samples

    proportions = np.asarray(classes, dtype=float)
    # Floating-point sums of proportions are not always exactly 1.0, so an
    # exact `!= 1` test can reject valid inputs. Use a small absolute tolerance.
    if not np.isclose(proportions.sum(), 1.0, rtol=0.0, atol=1e-9):
        raise ValueError("Classes dont sum up to 1")
    if np.any(proportions < 0):
        raise ValueError("Class proportions must be non-negative")

    n_classes = len(proportions)
    # Keep the draw order (sigma, skew, mu, then samples) unchanged so a given
    # seed reproduces the same data as before.
    sigma = np.random.randint(1, max_sigma, n_vars)
    skew = np.random.randint(-max_skew, max_skew, n_vars)
    mu = np.random.randint(-max_mu, max_mu, (n_classes, n_vars))

    n_obs_class = np.round(proportions * n)
    print(n_obs_class)

    # Class i occupies rows bounds[i]:bounds[i + 1].
    bounds = np.concatenate(([0], np.cumsum(n_obs_class))).astype(int)

    data = np.zeros((bounds[-1], n_vars + 1))
    for i in range(n_classes):
        start, end = bounds[i], bounds[i + 1]
        data[start:end, 0] = i
        data[start:end, 1:] = rng(mu[i, :], sigma, skew, end - start)

    # Copies, so in-place edits to X or y (e.g. standardising features) cannot
    # silently alter df, which may share memory with `data`.
    X = data[:, 1:].copy()
    y = data[:, 0].copy()

    columns = ["class"] + ["x" + str(j) for j in range(1, n_vars + 1)]
    df = pd.DataFrame(data, columns=columns)

    return X, y, df

X, y, df = simulate_data(
    classes=classes,
    n_vars=n_vars,
    n=n,
    max_mu=max_mu,
    max_sigma=max_sigma,
    max_skew=max_skew,
)

display(df)


[200. 500. 300.]


Unnamed: 0,class,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10
0,0.0,-14.013540,15.677506,1.829222,4.949057,13.254788,14.347968,9.176275,2.089042,15.429028,-22.377022
1,0.0,-10.042260,12.523824,1.747542,11.727868,16.894793,14.343651,7.761414,-0.503880,14.724127,-17.746363
2,0.0,-15.107836,16.174970,2.934717,10.046998,11.393264,15.722016,11.095267,1.116959,15.736638,-31.947955
3,0.0,-11.811799,18.342785,0.445583,9.465393,17.918978,18.181107,6.383164,3.183795,15.455131,-32.552531
4,0.0,-12.723650,11.638859,2.347462,10.760957,16.959377,15.480157,9.198635,-1.729150,13.881437,-18.498099
5,0.0,-12.052482,17.529102,-0.274735,11.261768,13.312943,15.117240,7.555262,-7.434681,13.374003,-27.113609
6,0.0,-16.985712,18.475640,3.400556,10.813930,12.169179,12.895663,8.319978,-0.646054,13.522222,-25.668826
7,0.0,-18.446821,13.954067,2.142966,7.725867,13.621330,16.439200,6.127532,-2.447928,14.718486,-32.268982
8,0.0,-19.547022,12.909681,0.718306,16.682876,12.799038,15.513667,8.327688,0.719324,14.408015,-20.774027
9,0.0,-10.251888,17.468064,2.448775,12.867457,17.179283,14.221798,7.428287,1.153779,14.556602,-19.470916


In [None]:
    
    

# NOTE(review): this cell holds only commented-out exploration code.
# If restored, `x` and `z` are not defined in any visible cell. At this point
# in file order, no visible cell has created a top-level `data`:
# simulate_data's `data` is local, and the 3x3 scratch `data` comes from a
# later cell. `y` would be the class-label vector returned by simulate_data.
# Rewrite against `X` / `y` / `df`, or delete the cell before sharing.
# display(np.cov(data.T))




# plt.plot(x, y, 'x')
# plt.axis('equal')
# plt.show()


# plt.hist(x)
# plt.hist(y)
# # plt.hist(z)

# from mpl_toolkits import mplot3d
# fig = plt.figure()
# ax = plt.axes(projection='3d')
# ax.scatter3D(x, y, z, c=z);

In [156]:
# Appears to be a scratch check of the `[:, 1:]` column slicing used in
# simulate_data: drop the first column of a small integer array.
# A private RandomState(1234) yields the same numbers as np.random.seed(1234)
# followed by np.random.randint. Unlike that version, it does not reset the
# global RNG that simulate_data draws from.
data = np.random.RandomState(1234).randint(0, 10, (3, 3))
display(data)
data[:, 1:]

array([[3, 6, 5],
       [4, 8, 9],
       [1, 7, 9]])

array([[6, 5],
       [8, 9],
       [7, 9]])