Skip to content

Commit

Permalink
Share code from Processor.from_pretrained
Browse files Browse the repository at this point in the history
  • Loading branch information
hans00 committed Jun 12, 2024
1 parent 790b082 commit e72ad8f
Showing 1 changed file with 14 additions and 37 deletions.
51 changes: 14 additions & 37 deletions src/processors.js
Original file line number Diff line number Diff line change
Expand Up @@ -2210,7 +2210,7 @@ export class AutoImageProcessor {
/**
* @typedef {import('./utils/hub.js').PretrainedOptions} PretrainedOptions
* @typedef {import('./tokenizers.js').TokenizerModel} TokenizerModel
* @typedef {Object.<string, FeatureExtractor | TokenizerModel>} ProcessorArgs
* @typedef {Object.<string, FeatureExtractor | TokenizerModel> & { config: Object }} ProcessorArgs
* @typedef {{ from_pretrained: (name_or_path: string, options: PretrainedOptions) => Promise<FeatureExtractor |TokenizerModel>}} AttributeLoader
* @typedef {Object.<string, AttributeLoader>} ProcessorAttributes
*/
Expand All @@ -2228,7 +2228,7 @@ export class Processor extends Callable {

/**
* Creates a new Processor with the given feature extractor.
* @param {ProcessorArgs & { config: Object }} args The config or function used to extract features from the input.
* @param {ProcessorArgs} args The config or function used to extract features from the input.
*/
constructor(args) {
super();
Expand Down Expand Up @@ -2458,18 +2458,9 @@ export class AutoProcessor {
revision,
})

let {
processor_class,
feature_extractor_type,
image_processor_type,
} = preprocessorConfig;
let { processor_class } = preprocessorConfig;

/**
* @type {ProcessorArgs}
*/
let args = {
config: preprocessorConfig,
};
let cls = this.PROCESSOR_CLASS_MAPPING[processor_class];

let options = {
progress_callback,
Expand All @@ -2478,32 +2469,18 @@ export class AutoProcessor {
local_files_only,
revision,
}
if (feature_extractor_type) {
args.feature_extractor = await AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path, options);
} else if (image_processor_type) {
args.feature_extractor = await AutoImageProcessor.from_pretrained(pretrained_model_name_or_path, options);
} else if (!processor_class) {
throw new Error(`Missing required 'feature_extractor_type' or 'image_processor_type' or 'processor_class' in config.`);
}

let cls = this.PROCESSOR_CLASS_MAPPING[processor_class] ?? Processor;
if (cls.ATTRIBUTES) {
let promises = Object.entries(cls.ATTRIBUTES)
.filter(([key]) => !(key in args))
.map(([key, attr_cls]) =>
[
key,
attr_cls.from_pretrained(pretrained_model_name_or_path, {
progress_callback,
cache_dir,
local_files_only,
revision,
})
]
);
Object.assign(args, Object.fromEntries(await Promise.all(promises)));
// Check if the processor class is a feature extractor only
if (cls?.ATTRIBUTES && Object.keys(cls.ATTRIBUTES).length === 1 && cls.ATTRIBUTES.feature_extractor) {
return new cls({
config: preprocessorConfig,
feature_extractor: cls.ATTRIBUTES.feature_extractor.from_pretrained(pretrained_model_name_or_path, options),
});
} else if (!cls) {
throw new Error(`Unknown Processor class: ${processor_class}`);
} else {
return await cls.from_pretrained(pretrained_model_name_or_path, options);
}
return new cls(args);
}
}
//////////////////////////////////////////////////

0 comments on commit e72ad8f

Please sign in to comment.