-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
2774a7f
commit 7c20ed3
Showing
9 changed files
with
196 additions
and
56 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from .susy import Susy | ||
from .wine import Wine | ||
from .higgs import Higgs | ||
from .us_census import Census | ||
from .covertype import Covertype | ||
from .osteoarthritis import Osteoarthritis |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import pandas as pd | ||
import torch | ||
|
||
from .experiment import Experiment | ||
|
||
|
||
class Covertype(Experiment): | ||
|
||
def __init__(self): | ||
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/covtype/covtype.data.gz" | ||
self.load_dataset_if_not_exists('data', url=url) | ||
# See covtype.info for description of the attributes | ||
dataset = pd.read_csv("data/covtype.data.gz", compression='gzip', delimiter=',', header=None) | ||
X = torch.transpose(torch.tensor(dataset.values).float(), 0, 1) | ||
|
||
super().__init__(X[4:8, :], name="Covertype") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import pandas as pd | ||
import requests | ||
|
||
from cieg.eigenvectors import * | ||
from cieg.experiments.methods import * | ||
from cieg.experiments.utils import * | ||
from cieg.utils.covariance import cov | ||
from cieg.utils.draw import * | ||
|
||
|
||
class Experiment: | ||
def __init__(self, X, name): | ||
self.X = X | ||
self.name = name | ||
|
||
@staticmethod | ||
def load_dataset_if_not_exists(path, url): | ||
if not os.path.exists(path): | ||
os.mkdir(path) | ||
|
||
filename = os.path.join(path, os.path.basename(url)) | ||
if not os.path.exists(filename): | ||
print("Loading dataset") | ||
data = requests.get(url).content | ||
with open(filename, "wb") as file: | ||
file.write(data) | ||
|
||
def run(self, path): | ||
X = preprocess(self.X) | ||
X_pd = pd.DataFrame(torch.transpose(X, 0, 1).data.numpy(), dtype='float64') | ||
|
||
sigma = cov(X) | ||
emp_prec = torch.inverse(sigma) | ||
eig = get_eig(sigma) | ||
|
||
print("----------- RESAMPLING - bootstrap pmatrix -----------") | ||
pmatrix_lower, pmatrix_upper = resample_pmatrix(X_pd, n_iterations=100, rng=check_random_state(0)) | ||
print_pmatrix_bounds(pmatrix_lower, pmatrix_upper, emp_prec) | ||
check_pmatrix_bounds(pmatrix_lower, pmatrix_upper, emp_prec) | ||
plot_and_save_bounds(pmatrix_lower, | ||
pmatrix_upper, | ||
emp_prec, | ||
f"Bounds on the precision matrix using resampling on {self.name}", | ||
path) | ||
|
||
print("----------- OUR METHOD -----------") | ||
# Bounds on eigendecomposition | ||
eigvals_lower, eigvals_upper, eigvects_lower, eigvects_upper = cieg(X, sigma, eig) | ||
print_eig_bounds(eigvals_lower, eigvals_upper, eigvects_lower, eigvects_upper, eig) | ||
check_eig_bounds(eigvals_lower, eigvals_upper, eigvects_lower, eigvects_upper, eig) | ||
|
||
# Bounds on precision matrix | ||
pmatrix_lower, pmatrix_upper, _ = pmatrix_bounds(eigvals_lower, | ||
eigvals_upper, | ||
eigvects_lower, | ||
eigvects_upper, | ||
sigma, eig) | ||
print_pmatrix_bounds(pmatrix_lower, pmatrix_upper, emp_prec) | ||
check_pmatrix_bounds(pmatrix_lower, pmatrix_upper, emp_prec) | ||
plot_and_save_bounds(pmatrix_lower, | ||
pmatrix_upper, | ||
emp_prec, | ||
f"Bounds on the precision matrix using our method on {self.name}", | ||
path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import pandas as pd | ||
import torch | ||
|
||
from .experiment import Experiment | ||
|
||
|
||
class Higgs(Experiment): | ||
|
||
def __init__(self, **kwargs): | ||
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/00280/HIGGS.csv.gz" | ||
columns = ['class', 'lepton pT', 'lepton eta', 'lepton phi', 'missing energy magnitude', 'missing energy phi', | ||
'jet 1 pt', 'jet 1 eta', 'jet 1 phi', 'jet 1 b-tag', 'jet 2 pt', 'jet 2 eta', 'jet 2 phi', | ||
'jet 2 b-tag', 'jet 3 pt', 'jet 3 eta', 'jet 3 phi', 'jet 3 b-tag', 'jet 4 pt', 'jet 4 eta', | ||
'jet 4 phi', 'jet 4 b-tag', 'm_jj', 'm_jjj', 'm_lv', 'm_jlv', 'm_bb', 'm_wbb', 'm_wwbb'] | ||
|
||
self.load_dataset_if_not_exists('data', url=url) | ||
dataset = pd.read_csv("data/HIGGS.csv.gz", | ||
compression='gzip', | ||
delimiter=',', | ||
header=None, | ||
names=columns, | ||
usecols=kwargs.pop('usecols')) | ||
|
||
X = torch.transpose(torch.tensor(dataset.values).float(), 0, 1) | ||
super().__init__(X, kwargs.pop("name")) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,64 +1,20 @@ | ||
import warnings | ||
|
||
import pandas as pd | ||
|
||
from cieg.eigenvectors import * | ||
from cieg.experiments.methods import * | ||
from cieg.experiments.utils import * | ||
from cieg.utils import * | ||
from cieg.utils.covariance import cov | ||
|
||
warnings.filterwarnings( | ||
"ignore", category=UserWarning | ||
) | ||
|
||
path = create_folders() | ||
|
||
columns_to_read = ['Side', 'WOMAC', 'OSTM'] # OSTM OSFM | ||
print(f"Features used: {columns_to_read}") | ||
|
||
# Read data | ||
X_pd = pd.read_csv('../data/oai_most_bl_aleksei_sep20_w_dataset_col.csv', usecols=columns_to_read) | ||
if 'Side' in columns_to_read: | ||
X_pd['Side'] = X_pd['Side'].map({"R": 0, "L": 1}) | ||
X_pd = X_pd.dropna() | ||
X = torch.transpose(torch.tensor(X_pd.values).float(), 0, 1) | ||
print(f"X shape: {X.shape}") | ||
|
||
X = preprocess(X) | ||
X_pd = pd.DataFrame(torch.transpose(X, 0, 1).data.numpy(), dtype='float64') | ||
from .experiment import Experiment | ||
|
||
sigma = cov(X) | ||
emp_prec = torch.inverse(sigma) | ||
eig = get_eig(sigma) | ||
|
||
class Osteoarthritis(Experiment): | ||
|
||
print("----------- RESAMPLING - bootstrap pmatrix -----------") | ||
pmatrix_lower, pmatrix_upper = resample_pmatrix(X_pd, n_iterations=100, rng=check_random_state(0)) | ||
print_pmatrix_bounds(pmatrix_lower, pmatrix_upper, emp_prec) | ||
check_pmatrix_bounds(pmatrix_lower, pmatrix_upper, emp_prec) | ||
plot_and_save_bounds(pmatrix_lower, | ||
pmatrix_upper, | ||
emp_prec, | ||
"Osteoarthritis bounds on the precision matrix using resampling", | ||
path) | ||
def __init__(self, **kwargs): | ||
columns_to_read = kwargs.pop("usecols") | ||
print(f"Features used: {columns_to_read}") | ||
|
||
print("----------- OUR METHOD -----------") | ||
# Bounds on eigendecomposition | ||
eigvals_lower, eigvals_upper, eigvects_lower, eigvects_upper = cieg(X, sigma, eig) | ||
print_eig_bounds(eigvals_lower, eigvals_upper, eigvects_lower, eigvects_upper, eig) | ||
check_eig_bounds(eigvals_lower, eigvals_upper, eigvects_lower, eigvects_upper, eig) | ||
# Read data | ||
X_pd = pd.read_csv('data/oai_most_bl_aleksei_sep20_w_dataset_col.csv', usecols=columns_to_read) | ||
if 'Side' in columns_to_read: | ||
X_pd['Side'] = X_pd['Side'].map({"R": 0, "L": 1}) | ||
X_pd = X_pd.dropna() | ||
X = torch.transpose(torch.tensor(X_pd.values).float(), 0, 1) | ||
|
||
# Bounds on precision matrix | ||
pmatrix_lower, pmatrix_upper, _ = pmatrix_bounds(eigvals_lower, | ||
eigvals_upper, | ||
eigvects_lower, | ||
eigvects_upper, | ||
sigma, eig) | ||
print_pmatrix_bounds(pmatrix_lower, pmatrix_upper, emp_prec) | ||
check_pmatrix_bounds(pmatrix_lower, pmatrix_upper, emp_prec) | ||
plot_and_save_bounds(pmatrix_lower, | ||
pmatrix_upper, | ||
emp_prec, | ||
"Osteoarthritis bounds on the precision matrix using our method", | ||
path) | ||
super().__init__(X, "Osteoarthritis") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import warnings | ||
from cieg.experiments import * | ||
from cieg.utils import create_folders | ||
|
||
warnings.filterwarnings( | ||
"ignore", category=UserWarning | ||
) | ||
|
||
census_columns = ["iClass", "dIncome1", "iEnglish", "dHours"] | ||
higgs_first = ['m_jj', 'm_jjj', 'm_lv', 'm_jlv'] | ||
higgs_second = ['m_bb', 'm_wbb', 'm_wwbb'] | ||
osteoarthritis = ['Side', 'WOMAC', 'OSTM'] # OSTM OSFM | ||
|
||
|
||
EXPERIMENTS = { | ||
'higgs_first': Higgs(usecols=higgs_first, name='Higgs first'), | ||
'higgs_second': Higgs(usecols=higgs_second, name='Higgs second'), | ||
'osteoarthritis': Osteoarthritis(usecols=osteoarthritis) | ||
} | ||
|
||
|
||
def run(): | ||
for exp_name in EXPERIMENTS.keys(): | ||
print(f"Running {exp_name}") | ||
path = create_folders() | ||
experiment = EXPERIMENTS[exp_name] | ||
experiment.run(path) | ||
|
||
|
||
if __name__ == '__main__': | ||
run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import pandas as pd | ||
import torch | ||
|
||
from .experiment import Experiment | ||
|
||
|
||
class Susy(Experiment): | ||
|
||
def __init__(self): | ||
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/00279/SUSY.csv.gz" | ||
self.load_dataset_if_not_exists('data', url=url) | ||
dataset = pd.read_csv("data/SUSY.csv.gz", compression='gzip', delimiter=',', header=None) | ||
|
||
X = torch.transpose(torch.tensor(dataset.values).float(), 0, 1) | ||
super().__init__(X[:4, :], name="Susy") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import pandas as pd | ||
import torch | ||
|
||
from .experiment import Experiment | ||
|
||
|
||
class Census(Experiment): | ||
|
||
def __init__(self, **kwargs): | ||
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/census1990-mld/USCensus1990.data.txt" | ||
self.load_dataset_if_not_exists('data', url=url) | ||
X_input = pd.read_csv('data/USCensus1990.data.txt', usecols=kwargs.pop('usecols')) | ||
|
||
X = torch.transpose(torch.tensor(X_input.values).float(), 0, 1) | ||
super().__init__(X, name="Census") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from sklearn.datasets import load_wine | ||
|
||
from cieg.experiments.utils import * | ||
from .experiment import Experiment | ||
|
||
|
||
class Wine(Experiment): | ||
def __init__(self): | ||
dataset = load_wine() | ||
|
||
X = torch.transpose(torch.tensor(dataset.data).float(), 0, 1) | ||
super().__init__(X[4:7, :], name="Wine") |