# Creating, saving and loading datasets with CSBMs

In [1]:
import multiprocessing
import os
import torch
import numpy as np

from csbm import MultiClassCSBM, FeatureCSBM, StructureCSBM, ClassCSBM, HomophilyCSBM

In [2]:
np.set_printoptions(precision=2, suppress=True)
torch.set_printoptions(precision=3)
csbm = MultiClassCSBM(n=100, classes=2, sigma_square=0.1)
csbm_f = FeatureCSBM(n=100, classes=2, sigma_square=0.1)

In [3]:
for _ in range(10):
    csbm.evolve()
    print(f'{csbm.get_per_class_feature_shift_mmd_with_rbf_kernel():.4f}')
    csbm_f.evolve()
    print(f'{csbm_f.get_per_class_feature_shift_mmd_with_rbf_kernel():.4f}')
    print()

0.0412
0.0406

0.0402
0.0415

0.0407
0.0418

0.0403
0.0409

0.0407
0.0426

0.0402
0.0446

0.0407
0.0494

0.0402
0.0585

0.0403
0.0635

0.0402
0.0663



In [4]:
from measures import mmd_max_rbf, mmd_rbf

In [5]:
for i in range(10):
    mmd = 0
    mmd_f = 0
    for c in range(csbm.classes):
        x = csbm.x[(csbm.t == 0) & (csbm.y == c)]
        z = csbm.x[(csbm.t == i) & (csbm.y == c)]
        
        x_f = csbm_f.x[(csbm_f.t == 0) & (csbm_f.y == c)]
        z_f = csbm_f.x[(csbm_f.t == i) & (csbm_f.y == c)]
        
        mmd += mmd_max_rbf(x, z)
        mmd_f += mmd_max_rbf(x_f, z_f)
    mmd /= csbm.classes
    mmd_f /= csbm_f.classes
    print(f'Control:\t{mmd:.4f}')
    print(f'Feature:\t{mmd_f:.4f}\n')

Control:	0.0000
Feature:	0.0000

Control:	0.0412
Feature:	0.0406

Control:	0.0402
Feature:	0.0415

Control:	0.0407
Feature:	0.0418

Control:	0.0403
Feature:	0.0409

Control:	0.0407
Feature:	0.0426

Control:	0.0402
Feature:	0.0446

Control:	0.0407
Feature:	0.0494

Control:	0.0402
Feature:	0.0585

Control:	0.0403
Feature:	0.0635

