Support for simple voting ensembles
sdadas committed Oct 8, 2021
1 parent 78c410a commit be601b6
Showing 5 changed files with 120 additions and 12 deletions.
2 changes: 1 addition & 1 deletion preprocess/processor.py
@@ -12,7 +12,7 @@

class TaskProcessor(object):

-    def __init__(self, task: BaseTask, data_path: str, output_path: str, model_path: str, resample: str):
+    def __init__(self, task: BaseTask, data_path: str, output_path: str, model_path: str, resample: str=None):
        self.task: BaseTask = task
        self.data_path: str = data_path
        self.model_path = model_path
97 changes: 97 additions & 0 deletions run_ensemble.py
@@ -0,0 +1,97 @@
import logging
from collections import Counter
from datetime import datetime
from random import choice
from typing import List, Tuple, Optional, Dict

import fire
import string
import fcntl

import json

from preprocess.processor import TaskProcessor
from tasks import TASKS, BaseTask
from train.evaluator import TaskEvaluatorBuilder, TaskEvaluator

logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)
logging.root.setLevel(logging.DEBUG)


class EnsemblePrediction(object):

    def __init__(self, y_true, task: BaseTask):
        self.y_true = y_true
        self.y_pred = []
        self.aggregate = self._vote_ensemble if task.spec().task_type == "classification" else self._avg_ensemble

    def add(self, val):
        self.y_pred.append(val)

    def _vote_ensemble(self):
        counter = Counter(self.y_pred)
        most_common: List[Tuple] = counter.most_common(1)
        return most_common[0][0]

    def _avg_ensemble(self):
        return sum(self.y_pred) / len(self.y_pred)


class EnsembleRunner(object):

    def __init__(self, arch: str, task_name: str, task: BaseTask, input_dir: str, output_dir: str):
        self.task_name = task_name
        self.task = task
        self.arch = arch
        self.input_dir = input_dir
        self.output_dir = output_dir
        self.predictions: Optional[List[EnsemblePrediction]] = None
        self.evaluator = None

    def evaluate_model(self, model_dir: str):
        logging.info("generating predictions for model %s", model_dir)
        builder = TaskEvaluatorBuilder(self.task, self.arch, model_dir, pre_trained_model=True)
        self.evaluator: TaskEvaluator = builder.build()
        y_true, y_pred = self.evaluator.generate_predictions()
        if self.predictions is None:
            self.predictions = [EnsemblePrediction(val, self.task) for val in y_true]
        for idx, pred in enumerate(y_pred):
            self.predictions[idx].add(pred)

    def evaluate_ensemble(self, task_id: str):
        y_true = [val.y_true for val in self.predictions]
        y_pred = [val.aggregate() for val in self.predictions]
        return self.evaluator.evaluate_predictions(y_true, y_pred, task_id)

    def prepare_task(self, model_dir: str):
        processor = TaskProcessor(self.task, self.input_dir, self.output_dir, model_dir)
        processor.prepare()

    def log_score(self, task_name: str, task_id: str, params: Dict, scores: Dict):
        now = datetime.now().strftime("%d/%m/%Y,%H:%M:%S")
        res = {"id": task_id, "task": task_name, "timestamp": now, "scores": scores, "params": params, "ensemble": True}
        with open("runlog.txt", "a", encoding="utf-8") as output_file:
            fcntl.flock(output_file, fcntl.LOCK_EX)
            json.dump(res, output_file)
            output_file.write("\n")
            fcntl.flock(output_file, fcntl.LOCK_UN)


def run_ensemble(arch: str, task: str, model_dirs: List[str], input_dir: str="data", output_dir: str="data_processed"):
    params = dict(locals())
    task_name = task
    task_class = TASKS.get(task)
    if task_class is None: raise Exception(f"Unknown task {task}")
    task = task_class()
    runner = EnsembleRunner(arch, task_name, task, input_dir, output_dir)
    for model_dir in model_dirs:
        runner.prepare_task(model_dir)
        runner.evaluate_model(model_dir)
    rand = ''.join(choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(8))
    task_id = task_name.lower() + "_" + rand
    scores = runner.evaluate_ensemble(task_id)
    runner.log_score(task_name, task_id, params, scores)


if __name__ == '__main__':
    fire.Fire(run_ensemble)
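
A minimal usage sketch for the new script; the architecture name, task key and checkpoint directories below are assumptions, not values taken from this commit:

# Sketch only: "roberta_base", "KLEJ-NKJP" and the checkpoint paths are hypothetical examples.
from run_ensemble import run_ensemble

run_ensemble(
    arch="roberta_base",                 # assumed architecture name
    task="KLEJ-NKJP",                    # assumed key registered in TASKS
    model_dirs=["checkpoints/run1", "checkpoints/run2", "checkpoints/run3"],
    input_dir="data",
    output_dir="data_processed",
)
# Equivalent CLI call via python-fire, with the same assumed values:
#   python run_ensemble.py --arch roberta_base --task KLEJ-NKJP \
#       --model_dirs '["checkpoints/run1", "checkpoints/run2", "checkpoints/run3"]'
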
9 changes: 4 additions & 5 deletions run_tasks.py
@@ -1,12 +1,11 @@
import logging
import string
from random import choice

import fire
import fcntl
from datetime import datetime

import torch.cuda

from preprocess.processor import TaskProcessor
from train.evaluator import TaskEvaluatorBuilder
from tasks import *
@@ -41,8 +40,8 @@ def train_task(self, train_epochs: int, fp16: bool, lr: str, max_sentences: int,
        trainer.train(train_epochs=train_epochs, max_sentences=max_sentences, update_freq=update_freq)

    def evaluate_task(self, verbose: bool=False, sharded_model: bool=False):
-        builder = TaskEvaluatorBuilder(self.task, self.arch, self.model_dir, self.input_dir, self.output_dir,
-                                       verbose=verbose, sharded_model=sharded_model)
+        builder = TaskEvaluatorBuilder(self.task, self.arch, self.model_dir, self.input_dir,
+                                       output_dir=self.output_dir, verbose=verbose, sharded_model=sharded_model)
        evaluator = builder.build()
        return evaluator.evaluate(self.task_id)

@@ -83,7 +82,7 @@ def run_tasks(arch: str, model_dir: str, input_dir: str="data", output_dir: str=
    task = task_class()
    task_id = task_name.lower() + "_" + rand
    cross_validation = cv_folds > 1
-    task_runs = CrossValidatedTask.cv_folds(task, cv_folds) if cross_validation else [task]
+    task_runs = CrossValidatedTask.cv_folds(task, cv_folds, seed) if cross_validation else [task]
    for idx, task_run in enumerate(task_runs):
        task_run_id = task_id
        if cross_validation: task_run_id += f"-fold{idx}"
9 changes: 5 additions & 4 deletions tasks.py
@@ -1,5 +1,4 @@
import json
import logging
import random
import os
from itertools import chain
@@ -120,18 +119,20 @@ def read_simple(self, data_path: str, split: str, separator: str=" ", label_firs

class CrossValidatedTask(BaseTask):

-    def __init__(self, wrapped_task: BaseTask, num_folds: int=4):
+    def __init__(self, wrapped_task: BaseTask, num_folds: int=4, seed: int=None):
        self.wrapped_task: BaseTask = wrapped_task
        self.num_folds = num_folds
        self.folds = None
        self._spec = wrapped_task.spec()
        self.set_fold(0)
+        self.seed = seed

    def set_fold(self, fold: int):
        self.fold = fold
        self._spec.output_dir = f"{self._spec.task_dir}-fold{self.fold}"

    def _read_folds(self, data_path: str):
+        if self.seed is not None: random.seed(self.seed)
        data: List[DataExample] = []
        for record in self.wrapped_task.read(data_path, "train"):
            data.append(record)
@@ -154,8 +155,8 @@ def read(self, data_path: str, split: str) -> Iterable[DataExample]:
        return [rec for rec in self.folds[self.fold]]

    @staticmethod
-    def cv_folds(wrapped_task: BaseTask, num_folds: int=4) -> Iterable[BaseTask]:
-        task = CrossValidatedTask(wrapped_task, num_folds)
+    def cv_folds(wrapped_task: BaseTask, num_folds: int=4, seed: int=None) -> Iterable[BaseTask]:
+        task = CrossValidatedTask(wrapped_task, num_folds, seed)
        for fold in range(num_folds):
            task.set_fold(fold)
            yield task
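
A minimal sketch of what the new seed parameter enables, namely repeatable fold assignment across runs; the task key used here is an assumption:

# Sketch: "KLEJ-NKJP" is assumed to be a key registered in TASKS; any task name would do.
from tasks import TASKS, CrossValidatedTask

base_task = TASKS["KLEJ-NKJP"]()
for fold_task in CrossValidatedTask.cv_folds(base_task, num_folds=4, seed=42):
    print(fold_task.fold)  # 0, 1, 2, 3; with a fixed seed the same splits are produced on every run
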
15 changes: 13 additions & 2 deletions train/evaluator.py
@@ -45,13 +45,16 @@ def _init_prediction_settings(self):
    def _get_label(self, label):
        return self.model.task.label_dictionary.string([label + self.model.task.label_dictionary.nspecial])

-    def evaluate(self, task_id: str="sample_task"):
+    def generate_predictions(self):
        y_true = []
        y_pred = []
        logging.info("generating predictions for task %s", self.task.spec().output_dir)
        for record in self.task.read(self.data_path, "test"):
            y_true.append(self.get_true(record) if record.label is not None else None)
            y_pred.append(self.predict(record))
+        return y_true, y_pred

+    def evaluate_predictions(self, y_true, y_pred, task_id: str="sample_task"):
        if y_true[0] is None:
            logging.info("No test labels available, skipping evaluation for task %s", self.task.spec().output_dir)
            scores = {}
@@ -61,6 +64,10 @@ def evaluate(self, task_id: str="sample_task"):
        self.save_results(y_pred, task_id)
        return scores

+    def evaluate(self, task_id: str="sample_task"):
+        y_true, y_pred = self.generate_predictions()
+        return self.evaluate_predictions(y_true, y_pred, task_id)

    def predict(self, record: DataExample, logits: bool=False):
        tokens = self.model.encode(*record.inputs)
        if tokens.size()[0] > self.maxlen:
@@ -90,7 +97,7 @@ def save_results(self, y_pred: List[any], task_id: str):

class TaskEvaluatorBuilder(object):

-    def __init__(self, task: BaseTask, arch: str, model_dir: str, input_dir: str="data",
+    def __init__(self, task: BaseTask, arch: str, model_dir: str, input_dir: str="data", pre_trained_model=False,
                 output_dir: str="data_processed", verbose=False, sharded_model=False):
        self.task = task
        self.arch = arch
@@ -99,6 +106,7 @@ def __init__(self, task: BaseTask, arch: str, model_dir: str, input_dir: str="da
        self.output_dir = output_dir
        self.verbose = verbose
        self.model_name = os.path.basename(model_dir)
+        self.pre_trained_model = pre_trained_model
        self.task_output_dir: str = os.path.join(self.output_dir, f"{task.spec().output_path()}-bin")
        self.sharded_model = sharded_model

@@ -110,6 +118,9 @@ def build(self) -> TaskEvaluator:
        if arch_type.startswith("xlmr"): arch_type = "roberta"
        model_class = model_classes[arch_type][0]
        spm_path = os.path.join(self.model_dir, "sentencepiece.bpe.model")
+        if self.pre_trained_model:
+            checkpoints_output_dir = self.model_dir
+            checkpoint_file = "model.pt"
        loaded = self.from_pretrained(
            model_name_or_path=checkpoints_output_dir,
            checkpoint_file=checkpoint_file,
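
A rough sketch of the refactored evaluator API (generate predictions once, score them separately) together with the new pre_trained_model flag; the task key, architecture name and checkpoint path are assumptions, and the task data is assumed to have already been preprocessed:

# Sketch: "KLEJ-NKJP", "roberta_base" and "checkpoints/run1" are hypothetical values.
from tasks import TASKS
from train.evaluator import TaskEvaluatorBuilder

task = TASKS["KLEJ-NKJP"]()
builder = TaskEvaluatorBuilder(task, "roberta_base", "checkpoints/run1", pre_trained_model=True)
evaluator = builder.build()
y_true, y_pred = evaluator.generate_predictions()   # run the model over the test split once
scores = evaluator.evaluate_predictions(y_true, y_pred, task_id="nkjp_example")
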
