Skip to content

Commit

Permalink
added retrain and train model function to update the predictor model
Browse files Browse the repository at this point in the history
  • Loading branch information
vishesh-baghel committed Jan 2, 2024
1 parent a742c58 commit 8edd8e6
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 35 deletions.
4 changes: 2 additions & 2 deletions src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import {
import { FileScoreMap } from "./types/FileScoreMap";
import { isValidFilePath } from "./fetch/fetchFiles";
import { connectMindsDB } from "./db/mindsdbConnection";
import { retrainModel } from "./services/predictionService";
import { trainPredictorModel } from "./services/predictionService";
import {
errorFallbackCommentForPRClosedEvent,
errorFallbackCommentForPROpenEvent,
Expand All @@ -30,7 +30,7 @@ export async function main(app: Probot) {
handleAppInstallationCreatedEvents(app);
handlePullRequestOpenEvents(app);
handlePullRequestClosedEvents(app);
retrainModel(app);
trainPredictorModel(app);
debug(app);
}

Expand Down
59 changes: 36 additions & 23 deletions src/services/predictionService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import MindsDB, {
} from "mindsdb-js-sdk";

const {
MINDSDB_HOST,
MONGODB_USER,
MONGODB_PASSWORD,
MONGODB_PORT,
Expand All @@ -18,22 +17,45 @@ const {
const databaseName = "mongo_datasource";
const projectName = "mindsdb";
const predictorName = "riskscore_predictor";
const targetField = "riskScore";
const aggregationQuery = `test.trainingfiles.find({})`;

export async function retrainModel(app: Probot) {
trainModel(app);
}
const regressionTrainingOptions: TrainingOptions = {
select: aggregationQuery,
integration: databaseName,
orderBy: "createdAt",
groupBy: "installationId",
window: 100, // How many rows in the past to use when making a future prediction.
horizon: 10, // How many rows in the future to forecast.
};

async function connectToMindsDB() {
await MindsDB.connect({
user: "",
password: "",
host: MINDSDB_HOST,
});
export async function retrainPredictorModel(app: Probot) {
await MindsDB.Models.retrainModel(
predictorName,
targetField,
projectName,
regressionTrainingOptions
)
.then(() => {
app.log.info(`[${predictorName}] model is retrained successfully`);
})
.catch((error: any) => {
app.log.error(
`Error occurred while retraining the model [${predictorName}]`
);
app.log.error(error);
});
}

async function trainModel(app: Probot) {
export async function trainPredictorModel(app: Probot) {
try {
const models: Model[] = await MindsDB.Models.getAllModels(projectName);
const modelNames = models.map((model: Model) => model.name);

if (modelNames.includes(predictorName)) {
app.log.info(`[${predictorName}] model is already present in mindsdb`);
return;
}
app.log.info(`Started training the model: [${predictorName}]`);
const dbList: Database[] = await MindsDB.Databases.getAllDatabases();
const dbNames: string[] = dbList.map((db: Database) => db.name);
Expand All @@ -43,18 +65,9 @@ async function trainModel(app: Probot) {
app.log.info(`Created database: ${db?.name} in mindsdb successfully`);
}

const regressionTrainingOptions: TrainingOptions = {
select: aggregationQuery,
integration: databaseName,
orderBy: "createdAt",
groupBy: "installationId",
window: 100, // How many rows in the past to use when making a future prediction.
horizon: 10, // How many rows in the future to forecast.
};

let predictionModel: Model | undefined = await MindsDB.Models.trainModel(
predictorName,
"riskScore",
targetField,
projectName,
regressionTrainingOptions
);
Expand All @@ -65,8 +78,8 @@ async function trainModel(app: Probot) {
projectName
);

if (predictionModel?.active) {
app.log.info("Prediction model is active");
if (predictionModel?.status.match("error")) {
app.log.info("Prediction model training is complete");
clearInterval(intervalId);
}
}, 2000);
Expand Down
4 changes: 2 additions & 2 deletions src/services/pullRequestService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import { getAllCommits } from "../fetch/fetchCommits";
import { calculateRiskScore } from "./riskScoreService";
import { FileType } from "../types/FileType";
import { FileStatus } from "../constants/GithubContants";
import { retrainModel } from "./predictionService";
import { retrainPredictorModel } from "./predictionService";
import { TrainingFileType } from "../types/TrainingFileType";
import { TrainingFile } from "../db/models/TrainingFile";

Expand Down Expand Up @@ -142,7 +142,7 @@ export async function updateFilesInDb(
`Updated the files coming from pull request with ref: [${owner}/${repoName}/pulls/${pullNumber}] successfully for installation id: [${installationId}]`
);

retrainModel(app);
retrainPredictorModel(app);

return true;
} catch (error: any) {
Expand Down
20 changes: 12 additions & 8 deletions src/services/repositoryService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import { fetchDetailsWithInstallationId } from "../fetch/fetch";
import { FileType } from "../types/FileType";
import { TrainingFile } from "../db/models/TrainingFile";
import { TrainingFileType } from "../types/TrainingFileType";
import { retrainPredictorModel } from "./predictionService";

export async function processRepositories(
app: Probot,
Expand Down Expand Up @@ -38,9 +39,9 @@ async function processRepositoryBatch(
owner: string,
repositories: any[]
): Promise<void> {
await Promise.all(
repositories.map(async (repo) => {
try {
try {
await Promise.all(
repositories.map(async (repo) => {
app.log.info(`Started processing repository: [${repo.name}]`);

const defaultBranch = await getDefaultBranch(
Expand Down Expand Up @@ -119,11 +120,14 @@ async function processRepositoryBatch(
app.log.info(
`Completed the processing of [${owner}/${repo.name}] repository successfully for installation id: [${installationId}]`
);
} catch (error: any) {
app.log.error(error);
}
})
);

retrainPredictorModel(app);
})
);
} catch (error: any) {
app.log.error(`Error while processing the repository batch`);
app.log.error(error);
}
}

async function getDefaultBranch(
Expand Down

0 comments on commit 8edd8e6

Please sign in to comment.