From 72e841349b14994912d4d77704bbf912d3782339 Mon Sep 17 00:00:00 2001
From: Th3G33k <666th3g33k666@monmail.fr.nf>
Date: Fri, 5 Apr 2024 22:23:47 -1000
Subject: [PATCH 1/9] Add custom task `register_pipeline`

---
 src/env.js       |  2 ++
 src/pipelines.js | 73 +++++++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 74 insertions(+), 1 deletion(-)

diff --git a/src/env.js b/src/env.js
index 109a065e4..9bca38d9f 100644
--- a/src/env.js
+++ b/src/env.js
@@ -106,6 +106,8 @@ export const env = {
     localModelPath: localModelPath,
     useFS: FS_AVAILABLE,
 
+    customTasks: {},
+
     /////////////////// Cache settings ///////////////////
     useBrowserCache: WEB_CACHE_AVAILABLE,
 
diff --git a/src/pipelines.js b/src/pipelines.js
index 2b064d522..43c20043b 100644
--- a/src/pipelines.js
+++ b/src/pipelines.js
@@ -69,6 +69,7 @@ import {
     interpolate,
 } from './utils/tensor.js';
 import { RawImage } from './utils/image.js';
+import { env } from './env.js';
 
 
 /**
@@ -3031,6 +3032,7 @@ const SUPPORTED_TASKS = Object.freeze({
         "type": "text",
     },
     "image-feature-extraction": {
+        // no tokenizer
         "processor": AutoProcessor,
         "pipeline": ImageFeatureExtractionPipeline,
         "model": [AutoModelForImageFeatureExtraction, AutoModel],
@@ -3116,7 +3118,7 @@ export async function pipeline(
     task = TASK_ALIASES[task] ?? task;
 
     // Get pipeline info
-    const pipelineInfo = SUPPORTED_TASKS[task.split('_', 1)[0]];
+    const pipelineInfo = SUPPORTED_TASKS[task.split('_', 1)[0]] ?? env.customTasks[task.split('_', 1)[0]];
    if (!pipelineInfo) {
         throw Error(`Unsupported pipeline: ${task}. Must be one of [${Object.keys(SUPPORTED_TASKS)}]`)
     }
@@ -3211,4 +3213,73 @@ async function loadItems(mapping, model, pretrainedOptions) {
     }
 
     return result;
+}
+
+/**
+ * Register a custom task pipeline.
+ * @param {string} task The name of the custom task to register.
+ *
+ * **Example:** Custom task: `audio-feature-extraction`.
+ * ```javascript + * import { + * Pipeline, + * read_audio, + * register_pipeline, + * ClapAudioModelWithProjection, + * AutoProcessor, + * pipeline + * } from '@xenova/transformers'; + * + * class AudioFeatureExtractionPipeline extends Pipeline { + * constructor(options) { + * super(options); + * } + * async _call(input, kwargs = {}) { + * input = await read_audio(input) + * input = await this.processor(input); + * let { audio_embeds } = await this.model(input); + * return audio_embeds + * } + * } + * + * register_pipeline('audio-feature-extraction', { + * pipeline: AudioFeatureExtractionPipeline, + * model: ClapAudioModelWithProjection, + * processor: AutoProcessor, + * model_name: 'Xenova/larger_clap_music_and_speech' + * }) + * + * let pipe = await pipeline('audio-feature-extraction'); + * let out = await pipe('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav') + * console.log(out) + * ``` + */ +export function register_pipeline( + task, + { + tokenizer = undefined, + pipeline = undefined, + model = undefined, + processor = undefined, + model_name = '', + type = '' + } = {} +) { + if(!( + ('prototype' in pipeline) && + (pipeline.prototype instanceof Pipeline) && + ("_call" in pipeline.prototype) + )){ + throw Error('pipeline class must inherits from Pipeline, and contains _call') + } + + env.customTasks[task] = { + tokenizer, + pipeline, + model, + processor, + 'default': {model: model_name}, + type + } + } \ No newline at end of file From 286fefa883a15028de14621a67203707b793933d Mon Sep 17 00:00:00 2001 From: Th3G33k <666th3g33k666@monmail.fr.nf> Date: Sat, 6 Apr 2024 11:03:56 -1000 Subject: [PATCH 2/9] Change model_name in register_pipeline --- src/pipelines.js | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pipelines.js b/src/pipelines.js index 43c20043b..6ca6dfad7 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -3246,7 +3246,7 @@ async function loadItems(mapping, model, pretrainedOptions) { * pipeline: AudioFeatureExtractionPipeline, * model: ClapAudioModelWithProjection, * processor: AutoProcessor, - * model_name: 'Xenova/larger_clap_music_and_speech' + * default_model: 'Xenova/larger_clap_music_and_speech' * }) * * let pipe = await pipeline('audio-feature-extraction'); @@ -3261,7 +3261,7 @@ export function register_pipeline( pipeline = undefined, model = undefined, processor = undefined, - model_name = '', + default_model = '', type = '' } = {} ) { @@ -3278,7 +3278,7 @@ export function register_pipeline( pipeline, model, processor, - 'default': {model: model_name}, + 'default': {model: default_model}, type } From 28423a6ae266f22f20d6103ebf5df8d7851f419c Mon Sep 17 00:00:00 2001 From: Th3G33k <4394090+Th3G33k@users.noreply.github.com> Date: Wed, 8 May 2024 08:44:46 -1000 Subject: [PATCH 3/9] add custom tasks to SUPPORTED_TASKS --- src/env.js | 2 -- src/pipelines.js | 27 ++++++++++++--------------- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/src/env.js b/src/env.js index 9bca38d9f..109a065e4 100644 --- a/src/env.js +++ b/src/env.js @@ -106,8 +106,6 @@ export const env = { localModelPath: localModelPath, useFS: FS_AVAILABLE, - customTasks: {}, - /////////////////// Cache settings /////////////////// useBrowserCache: WEB_CACHE_AVAILABLE, diff --git a/src/pipelines.js b/src/pipelines.js index 6ca6dfad7..499dfe7e9 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -69,7 +69,6 @@ import { interpolate, } from './utils/tensor.js'; import { RawImage } from './utils/image.js'; -import { env } from 
'./env.js'; /** @@ -2758,7 +2757,7 @@ export class DepthEstimationPipeline extends (/** @type {new (options: ImagePipe } } -const SUPPORTED_TASKS = Object.freeze({ +const SUPPORTED_TASKS = { "text-classification": { "tokenizer": AutoTokenizer, "pipeline": TextClassificationPipeline, @@ -3032,7 +3031,6 @@ const SUPPORTED_TASKS = Object.freeze({ "type": "text", }, "image-feature-extraction": { - // no tokenizer "processor": AutoProcessor, "pipeline": ImageFeatureExtractionPipeline, "model": [AutoModelForImageFeatureExtraction, AutoModel], @@ -3043,7 +3041,7 @@ const SUPPORTED_TASKS = Object.freeze({ }, "type": "image", }, -}) +} // TODO: Add types for TASK_ALIASES @@ -3118,7 +3116,7 @@ export async function pipeline( task = TASK_ALIASES[task] ?? task; // Get pipeline info - const pipelineInfo = SUPPORTED_TASKS[task.split('_', 1)[0]] ?? env.customTasks[task.split('_', 1)[0]]; + const pipelineInfo = SUPPORTED_TASKS[task.split('_', 1)[0]]; if (!pipelineInfo) { throw Error(`Unsupported pipeline: ${task}. Must be one of [${Object.keys(SUPPORTED_TASKS)}]`) } @@ -3232,12 +3230,11 @@ async function loadItems(mapping, model, pretrainedOptions) { * * class AudioFeatureExtractionPipeline extends Pipeline { * constructor(options) { - * super(options); + * super(options) * } * async _call(input, kwargs = {}) { - * input = await read_audio(input) - * input = await this.processor(input); - * let { audio_embeds } = await this.model(input); + * input = await read_audio(input).then(input=>this.processor(input)) + * let { audio_embeds } = await this.model(input) * return audio_embeds * } * } @@ -3257,10 +3254,10 @@ async function loadItems(mapping, model, pretrainedOptions) { export function register_pipeline( task, { - tokenizer = undefined, - pipeline = undefined, - model = undefined, - processor = undefined, + tokenizer, + pipeline, + model, + processor, default_model = '', type = '' } = {} @@ -3270,10 +3267,10 @@ export function register_pipeline( (pipeline.prototype instanceof Pipeline) && ("_call" in pipeline.prototype) )){ - throw Error('pipeline class must inherits from Pipeline, and contains _call') + throw Error('pipeline class must inherit from Pipeline, and contains _call') } - env.customTasks[task] = { + SUPPORTED_TASKS[task] = { tokenizer, pipeline, model, From 091b321ccba39c2ac2ada8e55f66cc19dd2de97b Mon Sep 17 00:00:00 2001 From: Th3G33k <4394090+Th3G33k@users.noreply.github.com> Date: Wed, 8 May 2024 20:39:20 -1000 Subject: [PATCH 4/9] models.js getModelClassFromName + refactor model mapping --- src/models.js | 503 ++++++++++++++++++++++++----------------------- src/pipelines.js | 23 ++- 2 files changed, 271 insertions(+), 255 deletions(-) diff --git a/src/models.js b/src/models.js index 0cc865abd..a62ec904b 100644 --- a/src/models.js +++ b/src/models.js @@ -110,6 +110,7 @@ import { medianFilter } from './utils/maths.js'; import { EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js'; import { LogitsSampler } from './generation/logits_sampler.js'; import { apis } from './env.js'; +import * as modelsExport from './models.js'; ////////////////////////////////////////////////// // Model types: used internally @@ -134,6 +135,16 @@ const MODEL_TYPE_MAPPING = new Map(); const MODEL_NAME_TO_CLASS_MAPPING = new Map(); const MODEL_CLASS_TO_NAME_MAPPING = new Map(); +/** + * Get model class from name + * @param {string} name + * @returns + */ +export function getModelClassFromName(name){ + let cls = MODEL_NAME_TO_CLASS_MAPPING.get(name); + if(!cls) 
console.warn(name + ' undefined'); + return cls; +} /** * Constructs an InferenceSession using a model file located at the specified path. @@ -5899,316 +5910,319 @@ export class PretrainedMixin { } const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([ - ['bert', ['BertModel', BertModel]], - ['nomic_bert', ['NomicBertModel', NomicBertModel]], - ['roformer', ['RoFormerModel', RoFormerModel]], - ['electra', ['ElectraModel', ElectraModel]], - ['esm', ['EsmModel', EsmModel]], - ['convbert', ['ConvBertModel', ConvBertModel]], - ['camembert', ['CamembertModel', CamembertModel]], - ['deberta', ['DebertaModel', DebertaModel]], - ['deberta-v2', ['DebertaV2Model', DebertaV2Model]], - ['mpnet', ['MPNetModel', MPNetModel]], - ['albert', ['AlbertModel', AlbertModel]], - ['distilbert', ['DistilBertModel', DistilBertModel]], - ['roberta', ['RobertaModel', RobertaModel]], - ['xlm', ['XLMModel', XLMModel]], - ['xlm-roberta', ['XLMRobertaModel', XLMRobertaModel]], - ['clap', ['ClapModel', ClapModel]], - ['clip', ['CLIPModel', CLIPModel]], - ['clipseg', ['CLIPSegModel', CLIPSegModel]], - ['chinese_clip', ['ChineseCLIPModel', ChineseCLIPModel]], - ['siglip', ['SiglipModel', SiglipModel]], - ['mobilebert', ['MobileBertModel', MobileBertModel]], - ['squeezebert', ['SqueezeBertModel', SqueezeBertModel]], - ['wav2vec2', ['Wav2Vec2Model', Wav2Vec2Model]], - ['wav2vec2-bert', ['Wav2Vec2BertModel', Wav2Vec2BertModel]], - ['unispeech', ['UniSpeechModel', UniSpeechModel]], - ['unispeech-sat', ['UniSpeechSatModel', UniSpeechSatModel]], - ['hubert', ['HubertModel', HubertModel]], - ['wavlm', ['WavLMModel', WavLMModel]], - ['audio-spectrogram-transformer', ['ASTModel', ASTModel]], - ['vits', ['VitsModel', VitsModel]], - - ['detr', ['DetrModel', DetrModel]], - ['table-transformer', ['TableTransformerModel', TableTransformerModel]], - ['vit', ['ViTModel', ViTModel]], - ['mobilevit', ['MobileViTModel', MobileViTModel]], - ['owlvit', ['OwlViTModel', OwlViTModel]], - ['owlv2', ['Owlv2Model', Owlv2Model]], - ['beit', ['BeitModel', BeitModel]], - ['deit', ['DeiTModel', DeiTModel]], - ['convnext', ['ConvNextModel', ConvNextModel]], - ['convnextv2', ['ConvNextV2Model', ConvNextV2Model]], - ['dinov2', ['Dinov2Model', Dinov2Model]], - ['resnet', ['ResNetModel', ResNetModel]], - ['swin', ['SwinModel', SwinModel]], - ['swin2sr', ['Swin2SRModel', Swin2SRModel]], - ['donut-swin', ['DonutSwinModel', DonutSwinModel]], - ['yolos', ['YolosModel', YolosModel]], - ['dpt', ['DPTModel', DPTModel]], - ['glpn', ['GLPNModel', GLPNModel]], - - ['hifigan', ['SpeechT5HifiGan', SpeechT5HifiGan]], - ['efficientnet', ['EfficientNetModel', EfficientNetModel]], + ['bert', BertModel], + ['nomic_bert', NomicBertModel], + ['roformer', RoFormerModel], + ['electra', ElectraModel], + ['esm', EsmModel], + ['convbert', ConvBertModel], + ['camembert', CamembertModel], + ['deberta', DebertaModel], + ['deberta-v2', DebertaV2Model], + ['mpnet', MPNetModel], + ['albert', AlbertModel], + ['distilbert', DistilBertModel], + ['roberta', RobertaModel], + ['xlm', XLMModel], + ['xlm-roberta', XLMRobertaModel], + ['clap', ClapModel], + ['clip', CLIPModel], + ['clipseg', CLIPSegModel], + ['chinese_clip', ChineseCLIPModel], + ['siglip', SiglipModel], + ['mobilebert', MobileBertModel], + ['squeezebert', SqueezeBertModel], + ['wav2vec2', Wav2Vec2Model], + ['wav2vec2-bert', Wav2Vec2BertModel], + ['unispeech', UniSpeechModel], + ['unispeech-sat', UniSpeechSatModel], + ['hubert', HubertModel], + ['wavlm', WavLMModel], + ['audio-spectrogram-transformer', ASTModel], + ['vits', 
VitsModel], + + ['detr', DetrModel], + ['table-transformer', TableTransformerModel], + ['vit', ViTModel], + ['mobilevit', MobileViTModel], + ['owlvit', OwlViTModel], + ['owlv2', Owlv2Model], + ['beit', BeitModel], + ['deit', DeiTModel], + ['convnext', ConvNextModel], + ['convnextv2', ConvNextV2Model], + ['dinov2', Dinov2Model], + ['resnet', ResNetModel], + ['swin', SwinModel], + ['swin2sr', Swin2SRModel], + ['donut-swin', DonutSwinModel], + ['yolos', YolosModel], + ['dpt', DPTModel], + ['glpn', GLPNModel], + + ['hifigan', SpeechT5HifiGan], + ['efficientnet', EfficientNetModel], ]); const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([ - ['t5', ['T5Model', T5Model]], - ['longt5', ['LongT5Model', LongT5Model]], - ['mt5', ['MT5Model', MT5Model]], - ['bart', ['BartModel', BartModel]], - ['mbart', ['MBartModel', MBartModel]], - ['marian', ['MarianModel', MarianModel]], - ['whisper', ['WhisperModel', WhisperModel]], - ['m2m_100', ['M2M100Model', M2M100Model]], - ['blenderbot', ['BlenderbotModel', BlenderbotModel]], - ['blenderbot-small', ['BlenderbotSmallModel', BlenderbotSmallModel]], + ['t5', T5Model], + ['longt5', LongT5Model], + ['mt5', MT5Model], + ['bart', BartModel], + ['mbart', MBartModel], + ['marian', MarianModel], + ['whisper', WhisperModel], + ['m2m_100', M2M100Model], + ['blenderbot', BlenderbotModel], + ['blenderbot-small', BlenderbotSmallModel], ]); const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([ - ['bloom', ['BloomModel', BloomModel]], - ['gpt2', ['GPT2Model', GPT2Model]], - ['gptj', ['GPTJModel', GPTJModel]], - ['gpt_bigcode', ['GPTBigCodeModel', GPTBigCodeModel]], - ['gpt_neo', ['GPTNeoModel', GPTNeoModel]], - ['gpt_neox', ['GPTNeoXModel', GPTNeoXModel]], - ['codegen', ['CodeGenModel', CodeGenModel]], - ['llama', ['LlamaModel', LlamaModel]], - ['openelm', ['OpenELMModel', OpenELMModel]], - ['qwen2', ['Qwen2Model', Qwen2Model]], - ['phi', ['PhiModel', PhiModel]], - ['phi3', ['Phi3Model', Phi3Model]], - ['mpt', ['MptModel', MptModel]], - ['opt', ['OPTModel', OPTModel]], - ['mistral', ['MistralModel', MistralModel]], - ['starcoder2', ['Starcoder2Model', Starcoder2Model]], - ['falcon', ['FalconModel', FalconModel]], - ['stablelm', ['StableLmModel', StableLmModel]], + ['bloom', BloomModel], + ['gpt2', GPT2Model], + ['gptj', GPTJModel], + ['gpt_bigcode', GPTBigCodeModel], + ['gpt_neo', GPTNeoModel], + ['gpt_neox', GPTNeoXModel], + ['codegen', CodeGenModel], + ['llama', LlamaModel], + ['openelm', OpenELMModel], + ['qwen2', Qwen2Model], + ['phi', PhiModel], + ['phi3', Phi3Model], + ['mpt', MptModel], + ['opt', OPTModel], + ['mistral', MistralModel], + ['starcoder2', Starcoder2Model], + ['falcon', FalconModel], + ['stablelm', StableLmModel], ]); const MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = new Map([ - ['speecht5', ['SpeechT5ForSpeechToText', SpeechT5ForSpeechToText]], - ['whisper', ['WhisperForConditionalGeneration', WhisperForConditionalGeneration]], + ['speecht5', SpeechT5ForSpeechToText], + ['whisper', WhisperForConditionalGeneration], ]); const MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = new Map([ - ['speecht5', ['SpeechT5ForTextToSpeech', SpeechT5ForTextToSpeech]], + ['speecht5', SpeechT5ForTextToSpeech], ]); +// @ts-ignore const MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = new Map([ - ['vits', ['VitsModel', VitsModel]], - ['musicgen', ['MusicgenForConditionalGeneration', MusicgenForConditionalGeneration]], + ['vits', VitsModel], + ['musicgen', MusicgenForConditionalGeneration], ]); +// @ts-ignore const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([ - ['bert', 
['BertForSequenceClassification', BertForSequenceClassification]], - ['roformer', ['RoFormerForSequenceClassification', RoFormerForSequenceClassification]], - ['electra', ['ElectraForSequenceClassification', ElectraForSequenceClassification]], - ['esm', ['EsmForSequenceClassification', EsmForSequenceClassification]], - ['convbert', ['ConvBertForSequenceClassification', ConvBertForSequenceClassification]], - ['camembert', ['CamembertForSequenceClassification', CamembertForSequenceClassification]], - ['deberta', ['DebertaForSequenceClassification', DebertaForSequenceClassification]], - ['deberta-v2', ['DebertaV2ForSequenceClassification', DebertaV2ForSequenceClassification]], - ['mpnet', ['MPNetForSequenceClassification', MPNetForSequenceClassification]], - ['albert', ['AlbertForSequenceClassification', AlbertForSequenceClassification]], - ['distilbert', ['DistilBertForSequenceClassification', DistilBertForSequenceClassification]], - ['roberta', ['RobertaForSequenceClassification', RobertaForSequenceClassification]], - ['xlm', ['XLMForSequenceClassification', XLMForSequenceClassification]], - ['xlm-roberta', ['XLMRobertaForSequenceClassification', XLMRobertaForSequenceClassification]], - ['bart', ['BartForSequenceClassification', BartForSequenceClassification]], - ['mbart', ['MBartForSequenceClassification', MBartForSequenceClassification]], - ['mobilebert', ['MobileBertForSequenceClassification', MobileBertForSequenceClassification]], - ['squeezebert', ['SqueezeBertForSequenceClassification', SqueezeBertForSequenceClassification]], + ['bert', BertForSequenceClassification], + ['roformer', RoFormerForSequenceClassification], + ['electra', ElectraForSequenceClassification], + ['esm', EsmForSequenceClassification], + ['convbert', ConvBertForSequenceClassification], + ['camembert', CamembertForSequenceClassification], + ['deberta', DebertaForSequenceClassification], + ['deberta-v2', DebertaV2ForSequenceClassification], + ['mpnet', MPNetForSequenceClassification], + ['albert', AlbertForSequenceClassification], + ['distilbert', DistilBertForSequenceClassification], + ['roberta', RobertaForSequenceClassification], + ['xlm', XLMForSequenceClassification], + ['xlm-roberta', XLMRobertaForSequenceClassification], + ['bart', BartForSequenceClassification], + ['mbart', MBartForSequenceClassification], + ['mobilebert', MobileBertForSequenceClassification], + ['squeezebert', SqueezeBertForSequenceClassification], ]); const MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = new Map([ - ['bert', ['BertForTokenClassification', BertForTokenClassification]], - ['roformer', ['RoFormerForTokenClassification', RoFormerForTokenClassification]], - ['electra', ['ElectraForTokenClassification', ElectraForTokenClassification]], - ['esm', ['EsmForTokenClassification', EsmForTokenClassification]], - ['convbert', ['ConvBertForTokenClassification', ConvBertForTokenClassification]], - ['camembert', ['CamembertForTokenClassification', CamembertForTokenClassification]], - ['deberta', ['DebertaForTokenClassification', DebertaForTokenClassification]], - ['deberta-v2', ['DebertaV2ForTokenClassification', DebertaV2ForTokenClassification]], - ['mpnet', ['MPNetForTokenClassification', MPNetForTokenClassification]], - ['distilbert', ['DistilBertForTokenClassification', DistilBertForTokenClassification]], - ['roberta', ['RobertaForTokenClassification', RobertaForTokenClassification]], - ['xlm', ['XLMForTokenClassification', XLMForTokenClassification]], - ['xlm-roberta', ['XLMRobertaForTokenClassification', 
XLMRobertaForTokenClassification]], + ['bert', BertForTokenClassification], + ['roformer', RoFormerForTokenClassification], + ['electra', ElectraForTokenClassification], + ['esm', EsmForTokenClassification], + ['convbert', ConvBertForTokenClassification], + ['camembert', CamembertForTokenClassification], + ['deberta', DebertaForTokenClassification], + ['deberta-v2', DebertaV2ForTokenClassification], + ['mpnet', MPNetForTokenClassification], + ['distilbert', DistilBertForTokenClassification], + ['roberta', RobertaForTokenClassification], + ['xlm', XLMForTokenClassification], + ['xlm-roberta', XLMRobertaForTokenClassification], ]); const MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = new Map([ - ['t5', ['T5ForConditionalGeneration', T5ForConditionalGeneration]], - ['longt5', ['LongT5ForConditionalGeneration', LongT5ForConditionalGeneration]], - ['mt5', ['MT5ForConditionalGeneration', MT5ForConditionalGeneration]], - ['bart', ['BartForConditionalGeneration', BartForConditionalGeneration]], - ['mbart', ['MBartForConditionalGeneration', MBartForConditionalGeneration]], - ['marian', ['MarianMTModel', MarianMTModel]], - ['m2m_100', ['M2M100ForConditionalGeneration', M2M100ForConditionalGeneration]], - ['blenderbot', ['BlenderbotForConditionalGeneration', BlenderbotForConditionalGeneration]], - ['blenderbot-small', ['BlenderbotSmallForConditionalGeneration', BlenderbotSmallForConditionalGeneration]], + ['t5', T5ForConditionalGeneration], + ['longt5', LongT5ForConditionalGeneration], + ['mt5', MT5ForConditionalGeneration], + ['bart', BartForConditionalGeneration], + ['mbart', MBartForConditionalGeneration], + ['marian', MarianMTModel], + ['m2m_100', M2M100ForConditionalGeneration], + ['blenderbot', BlenderbotForConditionalGeneration], + ['blenderbot-small', BlenderbotSmallForConditionalGeneration], ]); const MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = new Map([ - ['bloom', ['BloomForCausalLM', BloomForCausalLM]], - ['gpt2', ['GPT2LMHeadModel', GPT2LMHeadModel]], - ['gptj', ['GPTJForCausalLM', GPTJForCausalLM]], - ['gpt_bigcode', ['GPTBigCodeForCausalLM', GPTBigCodeForCausalLM]], - ['gpt_neo', ['GPTNeoForCausalLM', GPTNeoForCausalLM]], - ['gpt_neox', ['GPTNeoXForCausalLM', GPTNeoXForCausalLM]], - ['codegen', ['CodeGenForCausalLM', CodeGenForCausalLM]], - ['llama', ['LlamaForCausalLM', LlamaForCausalLM]], - ['openelm', ['OpenELMForCausalLM', OpenELMForCausalLM]], - ['qwen2', ['Qwen2ForCausalLM', Qwen2ForCausalLM]], - ['phi', ['PhiForCausalLM', PhiForCausalLM]], - ['phi3', ['Phi3ForCausalLM', Phi3ForCausalLM]], - ['mpt', ['MptForCausalLM', MptForCausalLM]], - ['opt', ['OPTForCausalLM', OPTForCausalLM]], - ['mbart', ['MBartForCausalLM', MBartForCausalLM]], - ['mistral', ['MistralForCausalLM', MistralForCausalLM]], - ['starcoder2', ['Starcoder2ForCausalLM', Starcoder2ForCausalLM]], - ['falcon', ['FalconForCausalLM', FalconForCausalLM]], - ['trocr', ['TrOCRForCausalLM', TrOCRForCausalLM]], - ['stablelm', ['StableLmForCausalLM', StableLmForCausalLM]], + ['bloom', BloomForCausalLM], + ['gpt2', GPT2LMHeadModel], + ['gptj', GPTJForCausalLM], + ['gpt_bigcode', GPTBigCodeForCausalLM], + ['gpt_neo', GPTNeoForCausalLM], + ['gpt_neox', GPTNeoXForCausalLM], + ['codegen', CodeGenForCausalLM], + ['llama', LlamaForCausalLM], + ['openelm', OpenELMForCausalLM], + ['qwen2', Qwen2ForCausalLM], + ['phi', PhiForCausalLM], + ['phi3', Phi3ForCausalLM], + ['mpt', MptForCausalLM], + ['opt', OPTForCausalLM], + ['mbart', MBartForCausalLM], + ['mistral', MistralForCausalLM], + ['starcoder2', Starcoder2ForCausalLM], + ['falcon', 
FalconForCausalLM], + ['trocr', TrOCRForCausalLM], + ['stablelm', StableLmForCausalLM], ]); const MODEL_FOR_MASKED_LM_MAPPING_NAMES = new Map([ - ['bert', ['BertForMaskedLM', BertForMaskedLM]], - ['roformer', ['RoFormerForMaskedLM', RoFormerForMaskedLM]], - ['electra', ['ElectraForMaskedLM', ElectraForMaskedLM]], - ['esm', ['EsmForMaskedLM', EsmForMaskedLM]], - ['convbert', ['ConvBertForMaskedLM', ConvBertForMaskedLM]], - ['camembert', ['CamembertForMaskedLM', CamembertForMaskedLM]], - ['deberta', ['DebertaForMaskedLM', DebertaForMaskedLM]], - ['deberta-v2', ['DebertaV2ForMaskedLM', DebertaV2ForMaskedLM]], - ['mpnet', ['MPNetForMaskedLM', MPNetForMaskedLM]], - ['albert', ['AlbertForMaskedLM', AlbertForMaskedLM]], - ['distilbert', ['DistilBertForMaskedLM', DistilBertForMaskedLM]], - ['roberta', ['RobertaForMaskedLM', RobertaForMaskedLM]], - ['xlm', ['XLMWithLMHeadModel', XLMWithLMHeadModel]], - ['xlm-roberta', ['XLMRobertaForMaskedLM', XLMRobertaForMaskedLM]], - ['mobilebert', ['MobileBertForMaskedLM', MobileBertForMaskedLM]], - ['squeezebert', ['SqueezeBertForMaskedLM', SqueezeBertForMaskedLM]], + ['bert', BertForMaskedLM], + ['roformer', RoFormerForMaskedLM], + ['electra', ElectraForMaskedLM], + ['esm', EsmForMaskedLM], + ['convbert', ConvBertForMaskedLM], + ['camembert', CamembertForMaskedLM], + ['deberta', DebertaForMaskedLM], + ['deberta-v2', DebertaV2ForMaskedLM], + ['mpnet', MPNetForMaskedLM], + ['albert', AlbertForMaskedLM], + ['distilbert', DistilBertForMaskedLM], + ['roberta', RobertaForMaskedLM], + ['xlm', XLMWithLMHeadModel], + ['xlm-roberta', XLMRobertaForMaskedLM], + ['mobilebert', MobileBertForMaskedLM], + ['squeezebert', SqueezeBertForMaskedLM], ]); const MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = new Map([ - ['bert', ['BertForQuestionAnswering', BertForQuestionAnswering]], - ['roformer', ['RoFormerForQuestionAnswering', RoFormerForQuestionAnswering]], - ['electra', ['ElectraForQuestionAnswering', ElectraForQuestionAnswering]], - ['convbert', ['ConvBertForQuestionAnswering', ConvBertForQuestionAnswering]], - ['camembert', ['CamembertForQuestionAnswering', CamembertForQuestionAnswering]], - ['deberta', ['DebertaForQuestionAnswering', DebertaForQuestionAnswering]], - ['deberta-v2', ['DebertaV2ForQuestionAnswering', DebertaV2ForQuestionAnswering]], - ['mpnet', ['MPNetForQuestionAnswering', MPNetForQuestionAnswering]], - ['albert', ['AlbertForQuestionAnswering', AlbertForQuestionAnswering]], - ['distilbert', ['DistilBertForQuestionAnswering', DistilBertForQuestionAnswering]], - ['roberta', ['RobertaForQuestionAnswering', RobertaForQuestionAnswering]], - ['xlm', ['XLMForQuestionAnswering', XLMForQuestionAnswering]], - ['xlm-roberta', ['XLMRobertaForQuestionAnswering', XLMRobertaForQuestionAnswering]], - ['mobilebert', ['MobileBertForQuestionAnswering', MobileBertForQuestionAnswering]], - ['squeezebert', ['SqueezeBertForQuestionAnswering', SqueezeBertForQuestionAnswering]], + ['bert', BertForQuestionAnswering], + ['roformer', RoFormerForQuestionAnswering], + ['electra', ElectraForQuestionAnswering], + ['convbert', ConvBertForQuestionAnswering], + ['camembert', CamembertForQuestionAnswering], + ['deberta', DebertaForQuestionAnswering], + ['deberta-v2', DebertaV2ForQuestionAnswering], + ['mpnet', MPNetForQuestionAnswering], + ['albert', AlbertForQuestionAnswering], + ['distilbert', DistilBertForQuestionAnswering], + ['roberta', RobertaForQuestionAnswering], + ['xlm', XLMForQuestionAnswering], + ['xlm-roberta', XLMRobertaForQuestionAnswering], + ['mobilebert', 
MobileBertForQuestionAnswering], + ['squeezebert', SqueezeBertForQuestionAnswering], ]); const MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = new Map([ - ['vision-encoder-decoder', ['VisionEncoderDecoderModel', VisionEncoderDecoderModel]], + ['vision-encoder-decoder', VisionEncoderDecoderModel], ]); const MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = new Map([ - ['llava', ['LlavaForConditionalGeneration', LlavaForConditionalGeneration]], + ['llava', LlavaForConditionalGeneration], ]); const MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = new Map([ - ['vision-encoder-decoder', ['VisionEncoderDecoderModel', VisionEncoderDecoderModel]], + ['vision-encoder-decoder', VisionEncoderDecoderModel], ]); const MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = new Map([ - ['vit', ['ViTForImageClassification', ViTForImageClassification]], - ['mobilevit', ['MobileViTForImageClassification', MobileViTForImageClassification]], - ['beit', ['BeitForImageClassification', BeitForImageClassification]], - ['deit', ['DeiTForImageClassification', DeiTForImageClassification]], - ['convnext', ['ConvNextForImageClassification', ConvNextForImageClassification]], - ['convnextv2', ['ConvNextV2ForImageClassification', ConvNextV2ForImageClassification]], - ['dinov2', ['Dinov2ForImageClassification', Dinov2ForImageClassification]], - ['resnet', ['ResNetForImageClassification', ResNetForImageClassification]], - ['swin', ['SwinForImageClassification', SwinForImageClassification]], - ['segformer', ['SegformerForImageClassification', SegformerForImageClassification]], - ['efficientnet', ['EfficientNetForImageClassification', EfficientNetForImageClassification]], + ['vit', ViTForImageClassification], + ['mobilevit', MobileViTForImageClassification], + ['beit', BeitForImageClassification], + ['deit', DeiTForImageClassification], + ['convnext', ConvNextForImageClassification], + ['convnextv2', ConvNextV2ForImageClassification], + ['dinov2', Dinov2ForImageClassification], + ['resnet', ResNetForImageClassification], + ['swin', SwinForImageClassification], + ['segformer', SegformerForImageClassification], + ['efficientnet', EfficientNetForImageClassification], ]); const MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = new Map([ - ['detr', ['DetrForObjectDetection', DetrForObjectDetection]], - ['table-transformer', ['TableTransformerForObjectDetection', TableTransformerForObjectDetection]], - ['yolos', ['YolosForObjectDetection', YolosForObjectDetection]], + ['detr', DetrForObjectDetection], + ['table-transformer', TableTransformerForObjectDetection], + ['yolos', YolosForObjectDetection], ]); const MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = new Map([ - ['owlvit', ['OwlViTForObjectDetection', OwlViTForObjectDetection]], - ['owlv2', ['Owlv2ForObjectDetection', Owlv2ForObjectDetection]], + ['owlvit', OwlViTForObjectDetection], + ['owlv2', Owlv2ForObjectDetection], ]); const MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = new Map([ - ['detr', ['DetrForSegmentation', DetrForSegmentation]], - ['clipseg', ['CLIPSegForImageSegmentation', CLIPSegForImageSegmentation]], + ['detr', DetrForSegmentation], + ['clipseg', CLIPSegForImageSegmentation], ]); const MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = new Map([ - ['segformer', ['SegformerForSemanticSegmentation', SegformerForSemanticSegmentation]], + ['segformer', SegformerForSemanticSegmentation], ]); const MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = new Map([ - ['sam', ['SamModel', SamModel]], + ['sam', SamModel], ]); +// @ts-ignore const MODEL_FOR_CTC_MAPPING_NAMES = new Map([ - ['wav2vec2', 
['Wav2Vec2ForCTC', Wav2Vec2ForCTC]], - ['wav2vec2-bert', ['Wav2Vec2BertForCTC', Wav2Vec2BertForCTC]], - ['unispeech', ['UniSpeechForCTC', UniSpeechForCTC]], - ['unispeech-sat', ['UniSpeechSatForCTC', UniSpeechSatForCTC]], - ['wavlm', ['WavLMForCTC', WavLMForCTC]], - ['hubert', ['HubertForCTC', HubertForCTC]], + ['wav2vec2', Wav2Vec2ForCTC], + ['wav2vec2-bert', Wav2Vec2BertForCTC], + ['unispeech', UniSpeechForCTC], + ['unispeech-sat', UniSpeechSatForCTC], + ['wavlm', WavLMForCTC], + ['hubert', HubertForCTC], ]); const MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = new Map([ - ['wav2vec2', ['Wav2Vec2ForSequenceClassification', Wav2Vec2ForSequenceClassification]], - ['wav2vec2-bert', ['Wav2Vec2BertForSequenceClassification', Wav2Vec2BertForSequenceClassification]], - ['unispeech', ['UniSpeechForSequenceClassification', UniSpeechForSequenceClassification]], - ['unispeech-sat', ['UniSpeechSatForSequenceClassification', UniSpeechSatForSequenceClassification]], - ['wavlm', ['WavLMForSequenceClassification', WavLMForSequenceClassification]], - ['hubert', ['HubertForSequenceClassification', HubertForSequenceClassification]], - ['audio-spectrogram-transformer', ['ASTForAudioClassification', ASTForAudioClassification]], + ['wav2vec2', Wav2Vec2ForSequenceClassification], + ['wav2vec2-bert', Wav2Vec2BertForSequenceClassification], + ['unispeech', UniSpeechForSequenceClassification], + ['unispeech-sat', UniSpeechSatForSequenceClassification], + ['wavlm', WavLMForSequenceClassification], + ['hubert', HubertForSequenceClassification], + ['audio-spectrogram-transformer', ASTForAudioClassification], ]); const MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = new Map([ - ['wavlm', ['WavLMForXVector', WavLMForXVector]], + ['wavlm', WavLMForXVector], ]); const MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = new Map([ - ['unispeech-sat', ['UniSpeechSatForAudioFrameClassification', UniSpeechSatForAudioFrameClassification]], - ['wavlm', ['WavLMForAudioFrameClassification', WavLMForAudioFrameClassification]], - ['wav2vec2', ['Wav2Vec2ForAudioFrameClassification', Wav2Vec2ForAudioFrameClassification]], + ['unispeech-sat', UniSpeechSatForAudioFrameClassification], + ['wavlm', WavLMForAudioFrameClassification], + ['wav2vec2', Wav2Vec2ForAudioFrameClassification], ]); const MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES = new Map([ - ['vitmatte', ['VitMatteForImageMatting', VitMatteForImageMatting]], + ['vitmatte', VitMatteForImageMatting], ]); const MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = new Map([ - ['swin2sr', ['Swin2SRForImageSuperResolution', Swin2SRForImageSuperResolution]], + ['swin2sr', Swin2SRForImageSuperResolution], ]) const MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = new Map([ - ['dpt', ['DPTForDepthEstimation', DPTForDepthEstimation]], - ['depth_anything', ['DepthAnythingForDepthEstimation', DepthAnythingForDepthEstimation]], - ['glpn', ['GLPNForDepthEstimation', GLPNForDepthEstimation]], + ['dpt', DPTForDepthEstimation], + ['depth_anything', DepthAnythingForDepthEstimation], + ['glpn', GLPNForDepthEstimation], ]) // NOTE: This is custom to Transformers.js, and is necessary because certain models // (e.g., CLIP) are split into vision and text components const MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES = new Map([ - ['clip', ['CLIPVisionModelWithProjection', CLIPVisionModelWithProjection]], - ['siglip', ['SiglipVisionModel', SiglipVisionModel]], + ['clip', CLIPVisionModelWithProjection], + ['siglip', SiglipVisionModel], ]) const MODEL_CLASS_TYPE_MAPPING = [ @@ -6244,31 +6258,16 @@ const MODEL_CLASS_TYPE_MAPPING = 
[ [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], ]; -for (const [mappings, type] of MODEL_CLASS_TYPE_MAPPING) { - // @ts-ignore - for (const [name, model] of mappings.values()) { - MODEL_TYPE_MAPPING.set(name, type); - MODEL_CLASS_TO_NAME_MAPPING.set(model, name); - MODEL_NAME_TO_CLASS_MAPPING.set(name, model); - } -} - const CUSTOM_MAPPING = [ // OVERRIDE: // TODO: Refactor to allow class to specify model - ['MusicgenForConditionalGeneration', MusicgenForConditionalGeneration, MODEL_TYPES.Musicgen], + [MusicgenForConditionalGeneration, MODEL_TYPES.Musicgen], - ['CLIPTextModelWithProjection', CLIPTextModelWithProjection, MODEL_TYPES.EncoderOnly], - ['SiglipTextModel', SiglipTextModel, MODEL_TYPES.EncoderOnly], - ['ClapTextModelWithProjection', ClapTextModelWithProjection, MODEL_TYPES.EncoderOnly], - ['ClapAudioModelWithProjection', ClapAudioModelWithProjection, MODEL_TYPES.EncoderOnly], + [CLIPTextModelWithProjection, MODEL_TYPES.EncoderOnly], + [SiglipTextModel, MODEL_TYPES.EncoderOnly], + [ClapTextModelWithProjection, MODEL_TYPES.EncoderOnly], + [ClapAudioModelWithProjection, MODEL_TYPES.EncoderOnly], ] -for (const [name, model, type] of CUSTOM_MAPPING) { - MODEL_TYPE_MAPPING.set(name, type); - MODEL_CLASS_TO_NAME_MAPPING.set(model, name); - MODEL_NAME_TO_CLASS_MAPPING.set(name, model); -} - /** * Helper class which is used to instantiate pretrained models with the `from_pretrained` function. @@ -6490,6 +6489,24 @@ export class AutoModelForImageFeatureExtraction extends PretrainedMixin { static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES]; } +for(let name in modelsExport){ + MODEL_CLASS_TO_NAME_MAPPING.set(modelsExport[name], name); + MODEL_NAME_TO_CLASS_MAPPING.set(name, modelsExport[name]); +} + +for (const [mappings, type] of MODEL_CLASS_TYPE_MAPPING) { + // @ts-ignore + for (const model of mappings.values()) { + const name = MODEL_CLASS_TO_NAME_MAPPING.get(model); + MODEL_TYPE_MAPPING.set(name, type); + } +} + +for (const [model, type] of CUSTOM_MAPPING) { + const name = MODEL_CLASS_TO_NAME_MAPPING.get(model); + MODEL_TYPE_MAPPING.set(name, type); +} + ////////////////////////////////////////////////// ////////////////////////////////////////////////// diff --git a/src/pipelines.js b/src/pipelines.js index c46c44046..df8e08ff7 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -41,6 +41,7 @@ import { AutoModelForDepthEstimation, AutoModelForImageFeatureExtraction, PreTrainedModel, + getModelClassFromName, } from './models.js'; import { AutoProcessor, @@ -3317,8 +3318,6 @@ async function loadItems(mapping, model, pretrainedOptions) { * Pipeline, * read_audio, * register_pipeline, - * ClapAudioModelWithProjection, - * AutoProcessor, * pipeline * } from '@xenova/transformers'; * @@ -3335,8 +3334,8 @@ async function loadItems(mapping, model, pretrainedOptions) { * * register_pipeline('audio-feature-extraction', { * pipeline: AudioFeatureExtractionPipeline, - * model: ClapAudioModelWithProjection, - * processor: AutoProcessor, + * model: 'ClapAudioModelWithProjection', + * processor: 'AutoProcessor', * default_model: 'Xenova/larger_clap_music_and_speech' * }) * @@ -3349,7 +3348,7 @@ export function register_pipeline( task, { tokenizer, - pipeline, + pipeline: pipelineClass, model, processor, default_model = '', @@ -3357,18 +3356,18 @@ export function register_pipeline( } = {} ) { if(!( - ('prototype' in pipeline) && - (pipeline.prototype instanceof Pipeline) && - ("_call" in pipeline.prototype) + ('prototype' in 
pipelineClass) && + (pipelineClass.prototype instanceof Pipeline) && + ("_call" in pipelineClass.prototype) )){ throw Error('pipeline class must inherit from Pipeline, and contains _call') } SUPPORTED_TASKS[task] = { - tokenizer, - pipeline, - model, - processor, + tokenizer: tokenizer == 'AutoTokenizer' ? AutoTokenizer : tokenizer, + pipeline: pipelineClass, + model: typeof model == 'string' ? getModelClassFromName(model) : model, + processor: processor == 'AutoProcessor' ? AutoProcessor : processor, 'default': {model: default_model}, type } From a86bfe30fdca76bd9041a7e882a67e5c8c6c6e98 Mon Sep 17 00:00:00 2001 From: Th3G33k <4394090+Th3G33k@users.noreply.github.com> Date: Wed, 8 May 2024 21:20:28 -1000 Subject: [PATCH 5/9] models.js fix model mapping --- src/models.js | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/models.js b/src/models.js index a62ec904b..7ccd9d829 100644 --- a/src/models.js +++ b/src/models.js @@ -1027,7 +1027,8 @@ export class PreTrainedModel extends Callable { for (const model_mapping of generate_compatible_mappings) { const supported_models = model_mapping.get(modelType); if (supported_models) { - generate_compatible_classes.add(supported_models[0]); + const name = MODEL_CLASS_TO_NAME_MAPPING.get(supported_models); + generate_compatible_classes.add(name); } } @@ -5897,7 +5898,7 @@ export class PretrainedMixin { if (!modelInfo) { continue; // Item not found in this mapping } - return await modelInfo[1].from_pretrained(pretrained_model_name_or_path, options); + return await modelInfo.from_pretrained(pretrained_model_name_or_path, options); } if (this.BASE_IF_FAIL) { From 5ed44f2428efd7e8c831ed11a06328bbf96ffd2b Mon Sep 17 00:00:00 2001 From: Th3G33k <4394090+Th3G33k@users.noreply.github.com> Date: Fri, 10 May 2024 03:05:09 -1000 Subject: [PATCH 6/9] Beautify code --- src/models.js | 8 ++++---- src/pipelines.js | 17 +++++++++-------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/models.js b/src/models.js index 7ccd9d829..f91ef6535 100644 --- a/src/models.js +++ b/src/models.js @@ -142,7 +142,7 @@ const MODEL_CLASS_TO_NAME_MAPPING = new Map(); */ export function getModelClassFromName(name){ let cls = MODEL_NAME_TO_CLASS_MAPPING.get(name); - if(!cls) console.warn(name + ' undefined'); + if (!cls) console.warn(name + ' undefined'); return cls; } @@ -1027,7 +1027,7 @@ export class PreTrainedModel extends Callable { for (const model_mapping of generate_compatible_mappings) { const supported_models = model_mapping.get(modelType); if (supported_models) { - const name = MODEL_CLASS_TO_NAME_MAPPING.get(supported_models); + const name = supported_models[0] || MODEL_CLASS_TO_NAME_MAPPING.get(supported_models); generate_compatible_classes.add(name); } } @@ -5898,7 +5898,7 @@ export class PretrainedMixin { if (!modelInfo) { continue; // Item not found in this mapping } - return await modelInfo.from_pretrained(pretrained_model_name_or_path, options); + return await (modelInfo[1] || modelInfo).from_pretrained(pretrained_model_name_or_path, options); } if (this.BASE_IF_FAIL) { @@ -6490,7 +6490,7 @@ export class AutoModelForImageFeatureExtraction extends PretrainedMixin { static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES]; } -for(let name in modelsExport){ +for (let name in modelsExport){ MODEL_CLASS_TO_NAME_MAPPING.set(modelsExport[name], name); MODEL_NAME_TO_CLASS_MAPPING.set(name, modelsExport[name]); } diff --git a/src/pipelines.js b/src/pipelines.js index df8e08ff7..110215470 100644 --- 
a/src/pipelines.js +++ b/src/pipelines.js @@ -3345,8 +3345,7 @@ async function loadItems(mapping, model, pretrainedOptions) { * ``` */ export function register_pipeline( - task, - { + task, { tokenizer, pipeline: pipelineClass, model, @@ -3355,11 +3354,11 @@ export function register_pipeline( type = '' } = {} ) { - if(!( - ('prototype' in pipelineClass) && - (pipelineClass.prototype instanceof Pipeline) && - ("_call" in pipelineClass.prototype) - )){ + if (!( + ('prototype' in pipelineClass) && + (pipelineClass.prototype instanceof Pipeline) && + ("_call" in pipelineClass.prototype) + )) { throw Error('pipeline class must inherit from Pipeline, and contains _call') } @@ -3368,7 +3367,9 @@ export function register_pipeline( pipeline: pipelineClass, model: typeof model == 'string' ? getModelClassFromName(model) : model, processor: processor == 'AutoProcessor' ? AutoProcessor : processor, - 'default': {model: default_model}, + 'default': { + model: default_model + }, type } From 306bbb3111f8f16cb51bafa1909592f4b2ff1027 Mon Sep 17 00:00:00 2001 From: Th3G33k <4394090+Th3G33k@users.noreply.github.com> Date: Fri, 10 May 2024 10:38:08 -1000 Subject: [PATCH 7/9] Allow updating existing supported_tasks --- src/pipelines.js | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/pipelines.js b/src/pipelines.js index 110215470..d6bce692e 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -3362,15 +3362,22 @@ export function register_pipeline( throw Error('pipeline class must inherit from Pipeline, and contains _call') } - SUPPORTED_TASKS[task] = { + const custom = { tokenizer: tokenizer == 'AutoTokenizer' ? AutoTokenizer : tokenizer, pipeline: pipelineClass, model: typeof model == 'string' ? getModelClassFromName(model) : model, processor: processor == 'AutoProcessor' ? AutoProcessor : processor, - 'default': { + 'default': (!default_model ? {} : { model: default_model - }, + }), type + }; + + if (task in SUPPORTED_TASKS) { + for (let key in custom) { + if (custom[key]) SUPPORTED_TASK[task][key] = custom[key]; + } } + else SUPPORTED_TASK[task] = custom; -} \ No newline at end of file +} From 8aac78e49bfa02527e5fd6f4bbad6183dfd46997 Mon Sep 17 00:00:00 2001 From: Th3G33k <4394090+Th3G33k@users.noreply.github.com> Date: Sat, 11 May 2024 08:11:26 -1000 Subject: [PATCH 8/9] Update pipelines.js --- src/pipelines.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pipelines.js b/src/pipelines.js index d6bce692e..ef9a0db96 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -3375,9 +3375,9 @@ export function register_pipeline( if (task in SUPPORTED_TASKS) { for (let key in custom) { - if (custom[key]) SUPPORTED_TASK[task][key] = custom[key]; + if (custom[key]) SUPPORTED_TASKS[task][key] = custom[key]; } } - else SUPPORTED_TASK[task] = custom; + else SUPPORTED_TASKS[task] = custom; } From 00ff4f915c7b2f5fcc84e8b9195990602b421bf1 Mon Sep 17 00:00:00 2001 From: Th3G33k <4394090+Th3G33k@users.noreply.github.com> Date: Sat, 11 May 2024 23:36:40 -1000 Subject: [PATCH 9/9] Update pipelines.js --- src/pipelines.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pipelines.js b/src/pipelines.js index ef9a0db96..3d3538b2b 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -3367,7 +3367,7 @@ export function register_pipeline( pipeline: pipelineClass, model: typeof model == 'string' ? getModelClassFromName(model) : model, processor: processor == 'AutoProcessor' ? AutoProcessor : processor, - 'default': (!default_model ? 
{} : { + 'default': (!default_model ? '' : { model: default_model }), type
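
For reference, here is how the API reads once the full series is applied. This is a minimal usage sketch assembled from the docstring example in PATCH 1/9 and the string-based class resolution introduced in PATCH 4/9; it assumes a build of `@xenova/transformers` with all nine patches applied (the model id and audio URL come from the docstring example).

```javascript
import {
    Pipeline,
    read_audio,
    register_pipeline,
    pipeline
} from '@xenova/transformers';

// Custom pipeline classes must extend Pipeline and define _call
// (enforced by the prototype check inside register_pipeline).
class AudioFeatureExtractionPipeline extends Pipeline {
    async _call(input, kwargs = {}) {
        // Decode the audio, then prepare model inputs with the auto-loaded processor.
        input = await read_audio(input).then(audio => this.processor(audio));
        const { audio_embeds } = await this.model(input);
        return audio_embeds;
    }
}

register_pipeline('audio-feature-extraction', {
    pipeline: AudioFeatureExtractionPipeline,
    // Strings are resolved internally: model class names go through
    // getModelClassFromName, while 'AutoProcessor' / 'AutoTokenizer'
    // map to their respective Auto classes.
    model: 'ClapAudioModelWithProjection',
    processor: 'AutoProcessor',
    default_model: 'Xenova/larger_clap_music_and_speech'
});

const extractor = await pipeline('audio-feature-extraction');
const embeds = await extractor('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav');
console.log(embeds);
```

Note that after PATCH 7/9, registering a task name that already exists in `SUPPORTED_TASKS` merges the truthy fields of the new entry into the existing one rather than replacing it, so `register_pipeline` can also be used to override only part of a built-in task — for example, just its default model.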