Skip to content

Commit 3a4b72e

Browse files
committed
no function changes
1 parent 033c01e commit 3a4b72e

File tree

8 files changed

+116
-115
lines changed

8 files changed

+116
-115
lines changed

pgml-admin/app/serializers.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,31 @@
22

33
from app.models import Project, Snapshot, Model, Deployment
44

5+
class SnapshotSerializer(serializers.ModelSerializer):
6+
y_column_name = serializers.ListSerializer(child=serializers.CharField())
7+
8+
class Meta:
9+
model = Snapshot
10+
fields = [
11+
"id",
12+
"y_column_name",
13+
"test_size",
14+
"test_sampling",
15+
"status",
16+
"columns",
17+
"analysis",
18+
"sample",
19+
"samples",
20+
"table_size",
21+
"feature_size",
22+
"created_at",
23+
"updated_at",
24+
]
25+
526

627
class ModelSerializer(serializers.ModelSerializer):
28+
snapshot = SnapshotSerializer()
29+
730
class Meta:
831
model = Model
932
fields = [
@@ -27,34 +50,15 @@ class Meta:
2750
fields = [
2851
"id",
2952
"name",
53+
"key_metric_name",
54+
"key_metric_display_name",
3055
"objective",
3156
"created_at",
3257
"updated_at",
3358
"models",
3459
]
3560

3661

37-
class SnapshotSerializer(serializers.ModelSerializer):
38-
y_column_name = serializers.ListSerializer(child=serializers.CharField())
39-
40-
class Meta:
41-
model = Snapshot
42-
fields = [
43-
"id",
44-
"y_column_name",
45-
"test_size",
46-
"test_sampling",
47-
"status",
48-
"columns",
49-
"analysis",
50-
"sample",
51-
"samples",
52-
"table_size",
53-
"feature_size",
54-
"created_at",
55-
"updated_at",
56-
]
57-
5862

5963
class DeploymentSerializer(serializers.ModelSerializer):
6064
class Meta:
@@ -65,8 +69,9 @@ class Meta:
6569

6670
class NewProjectSerializer(serializers.Serializer):
6771
project_name = serializers.CharField()
68-
objective = serializers.CharField()
69-
snapshot_id = serializers.IntegerField()
72+
objective = serializers.CharField(required=False)
73+
relation_name = serializers.CharField(required=False)
74+
y_column_name = serializers.ListSerializer(child=serializers.CharField(), required=False)
7075
algorithms = serializers.ListSerializer(child=serializers.CharField())
7176

7277

pgml-admin/app/static/js/controllers/new-project.js

Lines changed: 26 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -164,16 +164,19 @@ export default class extends Controller {
164164
}
165165

166166
renderAnalysisResult() {
167-
fetch(`/html/snapshots/analysis/?snapshot_id=${this.snapshotData.id}`)
167+
const snapshotData = this.projectData.models[0].snapshot
168+
169+
console.log("Fetching analysis")
170+
fetch(`/html/snapshots/analysis/?snapshot_id=${snapshotData.id}`)
168171
.then(res => res.text())
169172
.then(html => this.analysisResultTarget.innerHTML = html)
170173
.then(() => {
171174
// Render charts
172-
for (name in this.snapshotData.columns) {
175+
for (name in snapshotData.columns) {
173176
const sample = JSON.parse(document.getElementById(name).textContent)
174-
renderDistribution(name, sample, this.snapshotData.analysis[`${name}_dip`])
177+
renderDistribution(name, sample, snapshotData.analysis[`${name}_dip`])
175178

176-
for (target of this.snapshotData.y_column_name) {
179+
for (target of snapshotData.y_column_name) {
177180
if (target === name)
178181
continue
179182

@@ -182,9 +185,9 @@ export default class extends Controller {
182185
}
183186
}
184187

185-
for (target of this.snapshotData.y_column_name) {
188+
for (target of snapshotData.y_column_name) {
186189
const targetSample = JSON.parse(document.getElementById(target).textContent)
187-
renderOutliers(target, targetSample, this.snapshotData.analysis[`${target}_stddev`])
190+
renderOutliers(target, targetSample, snapshotData.analysis[`${target}_stddev`])
188191
}
189192

190193
this.progressBarProgress = 100
@@ -215,52 +218,34 @@ export default class extends Controller {
215218
createSnapshot(event) {
216219
event.preventDefault()
217220

218-
const request = {
219-
"relation_name": this.tableName,
220-
"y_column_name": Array.from(this.targetNames),
221-
}
221+
// Train a linear algorithm by default
222+
this.algorithmNames.add("linear")
222223

223224
this.nextStep()
224225

225226
// Start the progress bar :)
226227
this.progressBarProgress = 2
227-
this.progressBarInterval = setInterval(this.renderProgressBar.bind(this), 750)
228-
console.log("interval set to ", this.progressBarInterval)
228+
this.progressBarInterval = setInterval(this.renderProgressBar.bind(this), 850)
229229

230-
fetch(`/api/snapshots/snapshot/`, {
231-
method: "POST",
232-
cache: "no-cache",
233-
headers: {
234-
"Content-Type": "application/json",
235-
},
236-
redirect: "follow",
237-
body: JSON.stringify(request),
238-
})
239-
.then(res => {
240-
if (res.ok) {
241-
return res.json()
242-
} else {
243-
alert("Failed to create snapshot")
244-
throw Error("Failed to create snapshot")
245-
}
246-
})
247-
.then(json => {
248-
this.snapshotData = json
230+
this.createProject(event, false, () => {
249231
this.renderAnalysisResult()
232+
this.algorithmNames.delete("linear")
250233
})
251234
}
252235

253-
createProject(event) {
236+
createProject(event, redirect = true, callback = null) {
254237
event.preventDefault()
255238

256239
const request = {
257240
"project_name": this.projectName,
258241
"objective": this.objectiveName,
259242
"algorithms": Array.from(this.algorithmNames),
260-
"snapshot_id": this.snapshotData.id,
243+
"relation_name": this.tableName,
244+
"y_column_name": Array.from(this.targetNames),
261245
}
262246

263-
this.createLoader()
247+
if (redirect)
248+
this.createLoader()
264249

265250
fetch(`/api/projects/train/`, {
266251
method: "POST",
@@ -280,7 +265,13 @@ export default class extends Controller {
280265
}
281266
})
282267
.then(json => {
283-
window.location.assign(`/projects/${json.id}`);
268+
this.projectData = json
269+
270+
if (redirect)
271+
window.location.assign(`/projects/${json.id}`);
272+
273+
if (callback)
274+
callback()
284275
})
285276
}
286277

pgml-admin/app/templates/projects/new.html

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,23 @@ <h2>Project name</h2>
1919
</div>
2020
</section>
2121

22+
<section data-new-project-target="step" class="hidden" data-step-name="algorithm-type">
23+
<h2>Objective</h2>
24+
<ol class="objective_list">
25+
<li>
26+
<a href="#" data-action="click->new-project#selectObjective" data-new-project-target="objective" data-objective="regression">Regression</a>
27+
</li>
28+
<li>
29+
<a href="#" data-action="click->new-project#selectObjective" data-new-project-target="objective" data-objective="classification">Classification</a>
30+
</li>
31+
</ol>
32+
33+
<div class="button-container">
34+
<button class="next" data-action="click->new-project#nextStep" data-new-project-target="objectiveNameNext" disabled>Next</button>
35+
<button class="next" data-action="click->new-project#previousStep">Back</button>
36+
</div>
37+
</section>
38+
2239
<section data-new-project-target="step" class="hidden" data-step-name="data-source">
2340
<h2>Data source</h2>
2441

@@ -46,14 +63,14 @@ <h2>Target</h2>
4663
<ol data-new-project-target="trainingLabel" class="object_list snapshot_list"></ol>
4764

4865
<div class="button-container">
49-
<button class="next" data-action="click->new-project#createSnapshot" data-new-project-target="analysisNext" disabled>Snapshot data</button>
66+
<button class="next" data-action="click->new-project#createSnapshot" data-new-project-target="analysisNext" disabled>Run analysis</button>
5067
<button class="next" data-action="click->new-project#previousStep">Back</button>
5168
</div>
5269
</section>
5370

5471
<section data-new-project-target="step" class="hidden" data-step-name="target">
5572
<div style="display: flex; justify-content: center; align-items: center;">
56-
<p>Creating snapshot, this may take a moment...</p>
73+
<p>Performing data analysis, this may take a moment...</p>
5774
</div>
5875

5976
<div class="progress-bar">
@@ -73,28 +90,10 @@ <h2>Analysis</h2>
7390
</div>
7491
</section>
7592

76-
<section data-new-project-target="step" class="hidden" data-step-name="algorithm-type">
77-
<h2>Objective</h2>
78-
<ol class="objective_list">
79-
<li>
80-
<a href="#" data-action="click->new-project#selectObjective" data-new-project-target="objective" data-objective="regression">Regression</a>
81-
</li>
82-
<li>
83-
<a href="#" data-action="click->new-project#selectObjective" data-new-project-target="objective" data-objective="classification">Classification</a>
84-
</li>
85-
</ol>
86-
87-
<div class="button-container">
88-
<button class="next" data-action="click->new-project#nextStep" data-new-project-target="objectiveNameNext" disabled>Next</button>
89-
<button class="next" data-action="click->new-project#previousStep">Back</button>
90-
</div>
91-
</section>
92-
9393
<section data-new-project-target="step" class="hidden" data-step-name="algorithm">
9494
<h2>Algorithms</h2>
9595
<div style="margin-top: 25px" data-new-project-target="algorithmListRegression">
9696
<ol class="algorithm_list">
97-
<li><a href="#" data-action="click->new-project#selectAlgorithm" data-algorithm="linear"><span>LinearRegression</span></a></li>
9897
<li><a href="#" data-action="click->new-project#selectAlgorithm" data-algorithm="ridge"><span>Ridge</span></a></li>
9998
<li><a href="#" data-action="click->new-project#selectAlgorithm" data-algorithm="lasso"><span>Lasso</span></a></li>
10099
<li><a href="#" data-action="click->new-project#selectAlgorithm" data-algorithm="elastic_net"><span>ElasticNet</span></a></li>
@@ -127,7 +126,6 @@ <h2>Algorithms</h2>
127126

128127
<div style="margin-top: 25px" data-new-project-target="algorithmListClassification">
129128
<ol class="algorithm_list">
130-
<li><a href="#" data-action="click->new-project#selectAlgorithm" data-algorithm="linear"><span>LogisticRegression</span></a></li>
131129
<li><a href="#" data-action="click->new-project#selectAlgorithm" data-algorithm="ridge"><span>RidgeClassifier</span></a></li>
132130
<li><a href="#" data-action="click->new-project#selectAlgorithm" data-algorithm="stochastic_gradient_descent"><span>SGDClassifier</span></a></li>
133131
<li><a href="#" data-action="click->new-project#selectAlgorithm" data-algorithm="perceptron"><span>Perceptron</span></a></li>

pgml-admin/app/templates/snapshots/analysis.html

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
<div style="margin: 30px 0">
2+
<h3><span class="material-symbols-outlined">model_training</span>Linear model</h3>
3+
<dl class="model_metrics">
4+
{% for key, value in model.metrics.items %}
5+
{% if key == "search_results" %}
6+
{% else %}
7+
<dt>{{ key }}</dt>
8+
<dd>{{ value|floatformat:"5"}}</dd>
9+
{% endif %}
10+
{% endfor %}
11+
</dl>
12+
</div>
13+
114
<div class="feature_figures">
215
<h3><span class="material-symbols-outlined">label_important</span>Labels</h3>
316
{% for label in labels %}

pgml-admin/app/views/projects.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -75,25 +75,34 @@ def train(self, request):
7575
"""Train a new project."""
7676
serializer = NewProjectSerializer(data=request.data)
7777
if serializer.is_valid():
78+
exists = len(Project.objects.filter(name=serializer.validated_data["project_name"])) > 0
79+
7880
with connection.cursor() as cursor:
79-
cursor.execute(
80-
"""
81-
SELECT * FROM pgml.train_joint(
82-
project_name => %s,
83-
objective => %s,
84-
snapshot_id => %s,
85-
algorithm => %s
81+
if not exists:
82+
cursor.execute(
83+
"""
84+
SELECT * FROM pgml.train_joint(
85+
project_name => %s,
86+
objective => %s,
87+
relation_name => %s,
88+
y_column_name => %s,
89+
algorithm => %s
90+
)
91+
""",
92+
[
93+
serializer.validated_data["project_name"],
94+
serializer.validated_data["objective"],
95+
serializer.validated_data["relation_name"],
96+
serializer.validated_data["y_column_name"],
97+
serializer.validated_data["algorithms"][0],
98+
],
8699
)
87-
""",
88-
[
89-
serializer.validated_data["project_name"],
90-
serializer.validated_data["objective"],
91-
serializer.validated_data["snapshot_id"],
92-
serializer.validated_data["algorithms"][0],
93-
],
94-
)
95-
96-
for algorithm in serializer.validated_data["algorithms"][1:]:
100+
if exists:
101+
algorithms = serializer.validated_data["algorithms"]
102+
else:
103+
algorithms = serializer.validated_data["algorithms"][1:]
104+
105+
for algorithm in algorithms:
97106
cursor.execute(
98107
"""
99108
SELECT * FROM pgml.train(

pgml-admin/app/views/snapshots.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from rest_framework import status
1212

1313
import json
14-
from app.models import Snapshot, Project
14+
from app.models import Snapshot, Project, Model
1515
from app.serializers import SnapshotSerializer, NewSnapshotSerializer
1616

1717
from collections import namedtuple
@@ -108,7 +108,8 @@ def list(self, request):
108108
"type": snapshot.columns[column],
109109
"samples": list(map(lambda x: x[column], snapshot.sample())),
110110
} for column in snapshot.columns.keys() - snapshot.y_column_name
111-
]
111+
],
112+
"model": Model.objects.filter(snapshot=snapshot, algorithm_name="linear").first(),
112113
}
113114

114115
return render(request, "snapshots/analysis.html", context)

pgml-extension/pgml_extension/model.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -770,7 +770,6 @@ def train(
770770
objective: str = None,
771771
relation_name: str = None,
772772
y_column_name: str = None,
773-
snapshot_id: int = None,
774773
algorithm_name: str = "linear",
775774
hyperparams: dict = {},
776775
search: str = None,
@@ -805,18 +804,7 @@ def train(
805804
raise PgMLException(f"Unknown objective `{objective}`, available options are: regression, classification.")
806805

807806
# Create or use an existing snapshot.
808-
#
809-
# If a snapshot_id is given, use that specific snapshot.
810-
# If a relation name is given, snapshot it.
811-
# If none of the above, use the last snapshot created for the project, if any.
812-
#
813-
if snapshot_id is not None:
814-
snapshot = Snapshot.find(snapshot_id)
815-
if snapshot is None:
816-
raise PgMLException(
817-
f"Snapshot with ID {snapshot_id} does not exist."
818-
)
819-
elif relation_name is None:
807+
if relation_name is None:
820808
snapshot = project.last_snapshot
821809
if snapshot is None:
822810
raise PgMLException(

0 commit comments

Comments
 (0)