From 8edd8e688da2436c00844cf41bdc0a16d97d5a5d Mon Sep 17 00:00:00 2001 From: vishesh-baghel Date: Tue, 2 Jan 2024 22:28:12 +0530 Subject: [PATCH] added retrain and train model function to update the predictor model --- src/main.ts | 4 +- src/services/predictionService.ts | 59 ++++++++++++++++++------------ src/services/pullRequestService.ts | 4 +- src/services/repositoryService.ts | 20 ++++++---- 4 files changed, 52 insertions(+), 35 deletions(-) diff --git a/src/main.ts b/src/main.ts index fa9351c..91c6def 100644 --- a/src/main.ts +++ b/src/main.ts @@ -16,7 +16,7 @@ import { import { FileScoreMap } from "./types/FileScoreMap"; import { isValidFilePath } from "./fetch/fetchFiles"; import { connectMindsDB } from "./db/mindsdbConnection"; -import { retrainModel } from "./services/predictionService"; +import { trainPredictorModel } from "./services/predictionService"; import { errorFallbackCommentForPRClosedEvent, errorFallbackCommentForPROpenEvent, @@ -30,7 +30,7 @@ export async function main(app: Probot) { handleAppInstallationCreatedEvents(app); handlePullRequestOpenEvents(app); handlePullRequestClosedEvents(app); - retrainModel(app); + trainPredictorModel(app); debug(app); } diff --git a/src/services/predictionService.ts b/src/services/predictionService.ts index eb772cb..17fefb1 100644 --- a/src/services/predictionService.ts +++ b/src/services/predictionService.ts @@ -7,7 +7,6 @@ import MindsDB, { } from "mindsdb-js-sdk"; const { - MINDSDB_HOST, MONGODB_USER, MONGODB_PASSWORD, MONGODB_PORT, @@ -18,22 +17,45 @@ const { const databaseName = "mongo_datasource"; const projectName = "mindsdb"; const predictorName = "riskscore_predictor"; +const targetField = "riskScore"; const aggregationQuery = `test.trainingfiles.find({})`; -export async function retrainModel(app: Probot) { - trainModel(app); -} +const regressionTrainingOptions: TrainingOptions = { + select: aggregationQuery, + integration: databaseName, + orderBy: "createdAt", + groupBy: "installationId", + window: 100, // How many rows in the past to use when making a future prediction. + horizon: 10, // How many rows in the future to forecast. +}; -async function connectToMindsDB() { - await MindsDB.connect({ - user: "", - password: "", - host: MINDSDB_HOST, - }); +export async function retrainPredictorModel(app: Probot) { + await MindsDB.Models.retrainModel( + predictorName, + targetField, + projectName, + regressionTrainingOptions + ) + .then(() => { + app.log.info(`[${predictorName}] model is retrained successfully`); + }) + .catch((error: any) => { + app.log.error( + `Error occurred while retraining the model [${predictorName}]` + ); + app.log.error(error); + }); } -async function trainModel(app: Probot) { +export async function trainPredictorModel(app: Probot) { try { + const models: Model[] = await MindsDB.Models.getAllModels(projectName); + const modelNames = models.map((model: Model) => model.name); + + if (modelNames.includes(predictorName)) { + app.log.info(`[${predictorName}] model is already present in mindsdb`); + return; + } app.log.info(`Started training the model: [${predictorName}]`); const dbList: Database[] = await MindsDB.Databases.getAllDatabases(); const dbNames: string[] = dbList.map((db: Database) => db.name); @@ -43,18 +65,9 @@ async function trainModel(app: Probot) { app.log.info(`Created database: ${db?.name} in mindsdb successfully`); } - const regressionTrainingOptions: TrainingOptions = { - select: aggregationQuery, - integration: databaseName, - orderBy: "createdAt", - groupBy: "installationId", - window: 100, // How many rows in the past to use when making a future prediction. - horizon: 10, // How many rows in the future to forecast. - }; - let predictionModel: Model | undefined = await MindsDB.Models.trainModel( predictorName, - "riskScore", + targetField, projectName, regressionTrainingOptions ); @@ -65,8 +78,8 @@ async function trainModel(app: Probot) { projectName ); - if (predictionModel?.active) { - app.log.info("Prediction model is active"); + if (predictionModel?.status.match("error")) { + app.log.info("Prediction model training is complete"); clearInterval(intervalId); } }, 2000); diff --git a/src/services/pullRequestService.ts b/src/services/pullRequestService.ts index 23bc1b9..5f6a038 100644 --- a/src/services/pullRequestService.ts +++ b/src/services/pullRequestService.ts @@ -8,7 +8,7 @@ import { getAllCommits } from "../fetch/fetchCommits"; import { calculateRiskScore } from "./riskScoreService"; import { FileType } from "../types/FileType"; import { FileStatus } from "../constants/GithubContants"; -import { retrainModel } from "./predictionService"; +import { retrainPredictorModel } from "./predictionService"; import { TrainingFileType } from "../types/TrainingFileType"; import { TrainingFile } from "../db/models/TrainingFile"; @@ -142,7 +142,7 @@ export async function updateFilesInDb( `Updated the files coming from pull request with ref: [${owner}/${repoName}/pulls/${pullNumber}] successfully for installation id: [${installationId}]` ); - retrainModel(app); + retrainPredictorModel(app); return true; } catch (error: any) { diff --git a/src/services/repositoryService.ts b/src/services/repositoryService.ts index e88c57b..93320af 100644 --- a/src/services/repositoryService.ts +++ b/src/services/repositoryService.ts @@ -10,6 +10,7 @@ import { fetchDetailsWithInstallationId } from "../fetch/fetch"; import { FileType } from "../types/FileType"; import { TrainingFile } from "../db/models/TrainingFile"; import { TrainingFileType } from "../types/TrainingFileType"; +import { retrainPredictorModel } from "./predictionService"; export async function processRepositories( app: Probot, @@ -38,9 +39,9 @@ async function processRepositoryBatch( owner: string, repositories: any[] ): Promise { - await Promise.all( - repositories.map(async (repo) => { - try { + try { + await Promise.all( + repositories.map(async (repo) => { app.log.info(`Started processing repository: [${repo.name}]`); const defaultBranch = await getDefaultBranch( @@ -119,11 +120,14 @@ async function processRepositoryBatch( app.log.info( `Completed the processing of [${owner}/${repo.name}] repository successfully for installation id: [${installationId}]` ); - } catch (error: any) { - app.log.error(error); - } - }) - ); + + retrainPredictorModel(app); + }) + ); + } catch (error: any) { + app.log.error(`Error while processing the repository batch`); + app.log.error(error); + } } async function getDefaultBranch(