-
Notifications
You must be signed in to change notification settings - Fork 0
/
base_model.py
222 lines (186 loc) · 7.95 KB
/
base_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
from abc import ABC, abstractmethod
import os
import json
import pandas as pd
class BaseModel(ABC):
    """
    Abstract base class for a base model.

    Subclasses implement ``set_parameters`` and ``simulate``; this class
    provides default-parameter handling and JSON/CSV persistence helpers.

    Attributes:
        data: The data generated by (or loaded into) the model; None until set.
        model: The model used for simulation; None until set.
        parameters: The parameters of the model, as read from / written to
            JSON; None until set.

    Methods:
        set_parameters: Sets the model parameters (implemented by subclasses).
        simulate: Simulates data (implemented by subclasses).
        check_params: Checks if parameters are set and uses default values if not.
        read_parameters: Reads parameters from a JSON file.
        save_parameters: Saves parameters to a JSON file.
        print_parameters: Prints parameters to the console.
        save_data: Saves data to a CSV file.
        load_data: Loads data from a CSV file.
    """

    def __init__(self):
        """
        Initialize a new instance of the BaseModel class.

        All attributes start as None.
        """
        self.data = None
        self.model = None
        self.parameters = None

    # ------------------------------------------------------------------
    # Private helpers shared by the persistence methods below.
    # ------------------------------------------------------------------

    @staticmethod
    def _require_extension(path, extension, message):
        """Raise ValueError(message) unless `path` ends with `extension` (case-insensitive)."""
        if not path.lower().endswith(extension):
            raise ValueError(message)

    @staticmethod
    def _require_parent_dir(path):
        """
        Raise FileNotFoundError if `path` names a directory that does not exist.

        A bare filename (e.g. "params.json") has an empty directory part and
        refers to the current working directory, so it is always accepted.
        """
        directory = os.path.dirname(path)
        if directory and not os.path.isdir(directory):
            raise FileNotFoundError(f"No directory found at {directory}")

    @staticmethod
    def _require_file(path):
        """Raise FileNotFoundError unless `path` is an existing regular file."""
        if not os.path.isfile(path):
            raise FileNotFoundError(f"No file found at {path}")

    @staticmethod
    def _default_params(sim_type):
        """
        Return a fresh dict of default parameters for `sim_type`.

        A new dict (with new nested lists) is built on every call so callers
        may mutate the result without affecting later calls.

        Raises:
            ValueError: If sim_type is not 'VAR' or 'gMLV'.
        """
        if sim_type == "VAR":
            return {"n_obs": 100, "coefficients": [[0.8, -0.2], [0.3, 0.5]],
                    "initial_values": [[1], [2]], "noise_stddev": 1, "output": "show"}
        if sim_type == "gMLV":
            return {"n": 100, "p": 2, "k": 2, "sigma": 1}
        raise ValueError("sim_type must be 'VAR' or 'gMLV'.")

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    @abstractmethod
    def set_parameters(self):
        """
        Set parameters for the model.

        This method is implemented in the derived classes.
        """
        pass

    def read_parameters(self, filepath):
        """
        Read parameters from a JSON file into ``self.parameters``.

        Parameters:
            filepath (str): The path to the JSON file (extension checked
                case-insensitively).

        Returns:
            None on success; False if the file could not be read or parsed
            (the error is printed and ``self.parameters`` is left unchanged).

        Raises:
            ValueError: If `filepath` does not point to a .json file.
            FileNotFoundError: If `filepath` does not exist.
        """
        self._require_extension(filepath, '.json', "Filepath must point to a .json file.")
        self._require_file(filepath)
        try:
            # JSON text is UTF-8 by specification; don't depend on the locale.
            with open(filepath, 'r', encoding='utf-8') as file:
                self.parameters = json.load(file)
        except Exception as e:
            # Best-effort contract: report and signal failure instead of raising.
            print(f"Error reading parameters from {filepath}: {e}")
            return False

    def check_params(self, params, sim_type):
        """
        Check and update simulation parameters.

        Every key of the defaults for `sim_type` that is missing from `params`
        (or set to None) keeps its default value; every non-None entry of
        `params` overrides or extends the defaults. Warnings and the final
        parameter set are printed.

        Args:
            params (dict or None): The parameters provided for the simulation.
            sim_type (str): The type of simulation, 'VAR' or 'gMLV'.

        Returns:
            dict: The updated simulation parameters (a new dict on every call).

        Raises:
            ValueError: If sim_type is not 'VAR' or 'gMLV'.

        Examples:
            >>> params = {"n_obs": 200, "coefficients": [[0.5, -0.5], [0.2, 0.8]], "output": "save"}
            >>> merged = model.check_params(params, "VAR")  # model: instance of a concrete subclass
            Warning: Missing or None parameters for VAR simulation. Using default values for: ['initial_values', 'noise_stddev']
            Using the following parameters for VAR simulation: {'n_obs': 200, 'coefficients': [[0.5, -0.5], [0.2, 0.8]], 'initial_values': [[1], [2]], 'noise_stddev': 1, 'output': 'save'}
        """
        # NOTE: Implemented here rather than in each derived class to avoid code repetition.
        default_params = self._default_params(sim_type)
        if params is None:
            print(
                f"Warning: No parameters provided for {sim_type} simulation. Using default values.")
        else:
            missing_params = [
                key for key in default_params if key not in params or params[key] is None]
            if missing_params:
                print(
                    f"Warning: Missing or None parameters for {sim_type} simulation. Using default values for: {missing_params}")
            # None values mean "use the default", so they never override.
            for key, value in params.items():
                if value is not None:
                    default_params[key] = value
        print(
            f"Using the following parameters for {sim_type} simulation: {default_params}")
        return default_params

    def print_parameters(self):
        """
        Print the model and its parameters to the console.

        If the instance's parameters are None, prints "No parameters to print."
        """
        print(f"Model: {self.model}")
        if self.parameters is not None:
            print(json.dumps(self.parameters, indent=4))
        else:
            print("No parameters to print.")

    def save_parameters(self, filepath, parameters=None):
        """
        Save parameters to a JSON file.

        Parameters:
            filepath (str): The path to the JSON file (extension checked
                case-insensitively). A bare filename saves into the current
                working directory.
            parameters (dict, optional): The parameters to save. If None, the
                instance's parameters are used.

        Returns:
            None on success or when there is nothing to save; False if
            serialisation or writing failed (the error is printed).

        Raises:
            ValueError: If `filepath` does not point to a .json file.
            FileNotFoundError: If the directory to save the file does not exist.
        """
        self._require_extension(filepath, '.json', "Filepath must point to a .json file.")
        self._require_parent_dir(filepath)
        parameters = parameters if parameters is not None else self.parameters
        if parameters is None:
            print("No parameters to save.")
            return
        try:
            # Serialise first: json.dump writes incrementally, so a value that
            # can't be encoded would otherwise leave a truncated file behind.
            text = json.dumps(parameters)
            with open(filepath, 'w', encoding='utf-8') as file:
                file.write(text)
        except Exception as e:
            print(f"Error saving parameters to {filepath}: {e}")
            return False

    def save_data(self, filename, data=None):
        """
        Save data to a CSV file (no header row, no index column).

        Parameters:
            filename (str): The name of the CSV file (extension checked
                case-insensitively). A bare filename saves into the current
                working directory.
            data (array-like, optional): The data to save; anything accepted by
                ``pandas.DataFrame``. If None, the instance's data is used.

        Returns:
            None on success or when there is nothing to save; False if writing
            failed (the error is printed).

        Raises:
            ValueError: If `filename` does not end with .csv.
            FileNotFoundError: If the directory to save the file does not exist.
        """
        self._require_extension(filename, '.csv', "Filename must end with .csv.")
        self._require_parent_dir(filename)
        data = data if data is not None else self.data
        if data is None:
            print("No data to save.")
            return
        try:
            pd.DataFrame(data).to_csv(filename, index=False, header=False)
        except Exception as e:
            print(f"Error saving data to {filename}: {e}")
            return False

    def load_data(self, filename):
        """
        Load data from a header-less CSV file into ``self.data`` as a list of rows.

        Parameters:
            filename (str): The name of the CSV file (extension checked
                case-insensitively).

        Returns:
            None on success; False if the file could not be parsed (the error
            is printed and ``self.data`` is left unchanged).

        Raises:
            ValueError: If `filename` does not point to a .csv file.
            FileNotFoundError: If `filename` does not exist.
        """
        self._require_extension(filename, '.csv', "Filename must point to a .csv file.")
        self._require_file(filename)
        try:
            self.data = pd.read_csv(filename, header=None).values.tolist()
        except Exception as e:
            print(f"Error reading data from {filename}: {e}")
            return False

    @abstractmethod
    def simulate(self):
        """
        Simulate data based on the specified simulation type and parameters.

        This method is implemented in the derived classes.
        """
        pass