Skip to content

Commit

Permalink
Merge pull request #44 from team-inu/development
Browse files Browse the repository at this point in the history
Development
  • Loading branch information
Porping committed May 6, 2024
2 parents 25911f5 + 8845fbf commit 93254be
Show file tree
Hide file tree
Showing 11 changed files with 155 additions and 223 deletions.
32 changes: 10 additions & 22 deletions entity/prediction.go
Original file line number Diff line number Diff line change
@@ -1,31 +1,19 @@
package entity

type PredictionStatus string

const (
PredictionStatusPending PredictionStatus = "PENDING"
PredictionStatusFailed PredictionStatus = "FAILED"
PredictionStatusDone PredictionStatus = "DONE"
)

type Prediction struct {
Id string `json:"id" gorm:"primaryKey;type:char(255)"`
Status PredictionStatus `json:"status"`
Result string `json:"result"`
PredictedGPAX float64 `json:"predictedGPAX"`
}

type PredictionRepository interface {
GetById(id string) (*Prediction, error)
GetAll() ([]Prediction, error)
GetLatest() (*Prediction, error)
CreatePrediction(prediction *Prediction) error
Update(id string, prediction *Prediction) error
type PredictionRequirements struct {
ProgrammeName string
OldGPAX *float64
MathGPA *float64
EngGPA *float64
SciGPA *float64
School string
Admission string
}

type PredictionUseCase interface {
GetById(id string) (*Prediction, error)
GetAll() ([]Prediction, error)
GetLatest() (*Prediction, error)
CreatePrediction() (*string, error)
Update(id string, status PredictionStatus, result string) error
CreatePrediction(requirements PredictionRequirements) (*Prediction, error)
}
6 changes: 6 additions & 0 deletions entity/student.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ type StudentRepository interface {
Update(id string, student *Student) error
Delete(id string) error
FilterExisted(studentIds []string) ([]string, error)

GetAllSchools() ([]string, error)
GetAllAdmissions() ([]string, error)
}

type StudentUseCase interface {
Expand All @@ -41,4 +44,7 @@ type StudentUseCase interface {
Delete(id string) error
FilterExisted(studentIds []string) ([]string, error)
FilterNonExisted(studentIds []string) ([]string, error)

GetAllSchools() ([]string, error)
GetAllAdmissions() ([]string, error)
}
20 changes: 17 additions & 3 deletions infrastructure/fiber/controller/prediction.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package controller
import (
"github.com/gofiber/fiber/v2"
"github.com/team-inu/inu-backyard/entity"
"github.com/team-inu/inu-backyard/infrastructure/fiber/request"
"github.com/team-inu/inu-backyard/infrastructure/fiber/response"
"github.com/team-inu/inu-backyard/internal/validator"
)
Expand All @@ -19,11 +20,24 @@ func NewPredictionController(validator validator.PayloadValidator, predictionUse
}
}

func (c predictionController) Train(ctx *fiber.Ctx) error {
id, err := c.predictionUseCase.CreatePrediction()
func (c predictionController) Predict(ctx *fiber.Ctx) error {
var payload request.PredictPayload
if ok, err := c.Validator.Validate(&payload, ctx); !ok {
return err
}

prediction, err := c.predictionUseCase.CreatePrediction(entity.PredictionRequirements{
ProgrammeName: payload.ProgrammeName,
OldGPAX: payload.GPAX,
MathGPA: payload.MathGPA,
EngGPA: payload.EngGPA,
SciGPA: payload.SciGPA,
School: payload.School,
Admission: payload.Admission,
})
if err != nil {
return err
}

return response.NewSuccessResponse(ctx, fiber.StatusOK, id)
return response.NewSuccessResponse(ctx, fiber.StatusOK, prediction)
}
21 changes: 21 additions & 0 deletions infrastructure/fiber/controller/student.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,24 @@ func (c studentController) Delete(ctx *fiber.Ctx) error {

return response.NewSuccessResponse(ctx, fiber.StatusOK, nil)
}

func (c studentController) GetAllSchools(ctx *fiber.Ctx) error {
schools, err := c.studentUseCase.GetAllSchools()
if err != nil {
return err
}

return response.NewSuccessResponse(ctx, fiber.StatusOK, map[string]interface{}{
"schools": schools,
})
}
func (c studentController) GetAllAdmissions(ctx *fiber.Ctx) error {
admissions, err := c.studentUseCase.GetAllAdmissions()
if err != nil {
return err
}

return response.NewSuccessResponse(ctx, fiber.StatusOK, map[string]interface{}{
"admissions": admissions,
})
}
11 changes: 11 additions & 0 deletions infrastructure/fiber/request/prediction.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package request

type PredictPayload struct {
ProgrammeName string `json:"programmeName" validate:"required"`
GPAX *float64 `json:"gpax" validate:"required"`
MathGPA *float64 `json:"mathGPA" validate:"required"`
EngGPA *float64 `json:"engGPA" validate:"required"`
SciGPA *float64 `json:"sciGPA" validate:"required"`
School string `json:"school" validate:"required"`
Admission string `json:"admission" validate:"required"`
}
9 changes: 5 additions & 4 deletions infrastructure/fiber/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ type fiberServer struct {
gradeRepository entity.GradeRepository
sessionRepository entity.SessionRepository
coursePortfolioRepository entity.CoursePortfolioRepository
predictionRepository entity.PredictionRepository
courseStreamRepository entity.CourseStreamRepository

studentUseCase entity.StudentUseCase
Expand Down Expand Up @@ -103,7 +102,6 @@ func (f *fiberServer) initRepository() (err error) {
f.gradeRepository = repository.NewGradeRepositoryGorm(f.gorm)
f.sessionRepository = repository.NewSessionRepository(f.gorm)
f.coursePortfolioRepository = repository.NewCoursePortfolioRepositoryGorm(f.gorm)
f.predictionRepository = repository.NewPredictionRepositoryGorm(f.gorm)
f.courseStreamRepository = repository.NewCourseStreamRepository(f.gorm)

return nil
Expand All @@ -128,7 +126,7 @@ func (f *fiberServer) initUseCase() {
scoreUseCase := usecase.NewScoreUseCase(f.scoreRepository, enrollmentUseCase, assignmentUseCase, courseUseCase, userUseCase, studentUseCase)
courseStreamUseCase := usecase.NewCourseStreamUseCase(f.courseStreamRepository, courseUseCase)
coursePortfolioUseCase := usecase.NewCoursePortfolioUseCase(f.coursePortfolioRepository, courseUseCase, userUseCase, enrollmentUseCase, assignmentUseCase, scoreUseCase, studentUseCase, courseLearningOutcomeUseCase, courseStreamUseCase)
predictionUseCase := usecase.NewPredictionUseCase(f.predictionRepository, f.config)
predictionUseCase := usecase.NewPredictionUseCase(f.config)

f.assignmentUseCase = assignmentUseCase
f.authUseCase = authUseCase
Expand Down Expand Up @@ -189,6 +187,9 @@ func (f *fiberServer) initController() error {

api := app.Group("/")

api.Get("/schools", studentController.GetAllSchools)
api.Get("/admissions", studentController.GetAllAdmissions)

// student route
student := api.Group("/students", authMiddleware)

Expand Down Expand Up @@ -367,7 +368,7 @@ func (f *fiberServer) initController() error {
// prediction
prediction := api.Group("/prediction")

prediction.Get("/Train", predictionController.Train)
prediction.Post("/predict", predictionController.Predict)

// authentication route
auth := app.Group("/auth")
Expand Down
56 changes: 10 additions & 46 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,13 @@
import matplotlib.pyplot as plt
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsRegressor
from sklearn.neighbors import KNeighborsClassifier
from sklearn.multioutput import MultiOutputRegressor
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.metrics import r2_score
from sklearn.neural_network import MLPRegressor
from sklearn.linear_model import LinearRegression
from sklearn.svm import SVR
from sklearn.linear_model import Ridge
from sklearn.ensemble import RandomForestRegressor
from sklearn.decomposition import PCA
from sklearn.feature_selection import r_regression
# python3 predict.py <user> <password> <host> <port> <database> <prediction_id>
# python3 predict.py root root mysql 3306 inu_backyard 01HW8F6EAJ1JM46DH98QTMKN15
# python3 predict.py <user> <password> <host> <port> <database> <programme_name> <old_gpax> <math_gpa> <eng_gpa> <sci_gpa> <school> <admission>
# python3 predict.py root root mysql 3306 inu_backyard

# TODO: add visibility

Expand Down Expand Up @@ -64,17 +55,11 @@
pca = PCA()
y_gpax = np.array(data[:,[8]], dtype=float)
Xt_gpax = ctx.fit_transform(X)
# print(r_regression(admission, y_gpax.reshape(-1, 1)))

X_gpax_train, X_gpax_test, y_gpax_train, y_gpax_test = train_test_split(Xt_gpax, y_gpax, test_size=0.25, random_state=0)
X_gpax_pca = pca.fit_transform(X_gpax_train)
# print(y_gpax[0][0])
# print(type(y_gpax[0][0]))
# print(y_gpax[:,0])
# print(X_gpax_train[0])
# print(y_gpax_train[0])
# print (pca.explained_variance_ratio_.cumsum())
yscaler = StandardScaler().fit(y_gpax_train[:,-1].reshape(-1, 1))

yscaler = StandardScaler().fit(y_gpax_train[:,-1].reshape(-1, 1))
y_gpax_train = yscaler.transform(y_gpax_train[:,-1].reshape(-1, 1))

# plt.figure(figsize=(4,4))
Expand All @@ -83,35 +68,22 @@
# plt.ylabel("gpax")
# plt.show()

# modelregr = MLPRegressor(hidden_layer_sizes=(5,15,5), max_iter=5000)
# modelregr = RandomForestRegressor(criterion="squared_error", max_depth=5, n_estimators=1000)
# modelregr = SVR()
modelregr = LinearRegression()
# modelregr = KNeighborsRegressor(n_neighbors=30)

modelregr.fit(X_gpax_pca[:,:42], y_gpax_train)
y_gpax_predict = modelregr.predict(pca.transform(X_gpax_test)[:,:42]).reshape(-1, 1)
y_gpax_predict_iscaled = yscaler.inverse_transform(y_gpax_predict.reshape(-1, 1))

target = pd.DataFrame([[sys.argv[6], sys.argv[7], sys.argv[8], sys.argv[9], sys.argv[10], sys.argv[11], sys.argv[12]]]).to_numpy()
target = ctx.transform(target)
target = pca.transform(target)
prediction = modelregr.predict(target[:,:42])
print(round(yscaler.inverse_transform(prediction.reshape(-1, 1))[0,0], 2))

# test = pd.DataFrame([['regular', 3.99, 4,3.96, 4, 'โรงเรียนจำลอง', 'เรียนดี']]).to_numpy()
# test = pd.DataFrame([['regular', 1, 1,1, 1, 'หมีน้อย', 'เรียนดี']]).to_numpy()
test = pd.DataFrame([['regular', 1, 1,1, 1, 'เตรียมอุดมศึกษาน้อมเกล้า', 'เรียนดี']]).to_numpy()
test = pd.DataFrame([['regular', 4, 4, 4, 4, 'เตรียมอุดมศึกษาน้อมเกล้า', 'เรียนดี']]).to_numpy()
# test = pd.DataFrame([['regular', 3.99, 4,3.96, 4, 'อิสลามวิทยาลัยแห่งประเทศไทย', 'เรียนดี']]).to_numpy()
# test = pd.DataFrame([['regular', 1, 1,1, 1, 'อิสลามวิทยาลัยแห่งประเทศไทย', 'เรียนดี']]).to_numpy()
# test = pd.DataFrame([['regular', 4, 4, 4, 4, 'อิสลามวิทยาลัยแห่งประเทศไทย', 'เรียนดี']]).to_numpy()
# test = pd.DataFrame([['regular', 3.8, 3.45,4, 3.7, 'มหิดลวิทยานุสรณ์', 'เรียนดี']]).to_numpy()
# test = pd.DataFrame([['regular', 1, 1,1, 1, 'มหิดลวิทยานุสรณ์', 'เรียนดี']]).to_numpy()
# test = pd.DataFrame([['regular', 4, 4, 4, 4, 'มหิดลวิทยานุสรณ์', 'เรียนดี']]).to_numpy()
test = ctx.transform(test)
test = pca.transform(test)
prediction = modelregr.predict(test[:,:42])
print(yscaler.inverse_transform(prediction.reshape(-1, 1)))
sys.exit(0)

# print(r2_score(yscaler.transform(y_gpax_test[:,-1].reshape(-1, 1)), y_gpax_predict))


## Predict remark from admission and current GPAX

# y_remark = data[:,[6]]
Expand Down Expand Up @@ -150,11 +122,3 @@
# print(yt[0])
# print(cty.get_feature_names_out())
# print(y_train[:, -1])







sys.exit(0)
69 changes: 0 additions & 69 deletions repository/prediction.go

This file was deleted.

36 changes: 36 additions & 0 deletions repository/student.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package repository

import (
"database/sql"
"fmt"

"github.com/team-inu/inu-backyard/entity"
Expand Down Expand Up @@ -118,3 +119,38 @@ func (r studentRepositoryGorm) FilterExisted(studentIds []string) ([]string, err

return existedIds, nil
}

func (r studentRepositoryGorm) GetAllSchools() ([]string, error) {
var schools []sql.NullString

err := r.gorm.Raw("SELECT DISTINCT school FROM student").Scan(&schools).Error
if err != nil {
return nil, fmt.Errorf("cannot query student: %w", err)
}

nonNullSchool := make([]string, 0)
for _, school := range schools {
if school.Valid {
nonNullSchool = append(nonNullSchool, school.String)
}
}

return nonNullSchool, nil
}
func (r studentRepositoryGorm) GetAllAdmissions() ([]string, error) {
var admissions []sql.NullString

err := r.gorm.Raw("SELECT DISTINCT admission FROM student").Scan(&admissions).Error
if err != nil {
return nil, fmt.Errorf("cannot query student: %w", err)
}

nonNullAdmission := make([]string, 0)
for _, admission := range admissions {
if admission.Valid {
nonNullAdmission = append(nonNullAdmission, admission.String)
}
}

return nonNullAdmission, nil
}
Loading

0 comments on commit 93254be

Please sign in to comment.