diff --git a/src/models.js b/src/models.js index 0cc865abd..f91ef6535 100644 --- a/src/models.js +++ b/src/models.js @@ -110,6 +110,7 @@ import { medianFilter } from './utils/maths.js'; import { EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js'; import { LogitsSampler } from './generation/logits_sampler.js'; import { apis } from './env.js'; +import * as modelsExport from './models.js'; ////////////////////////////////////////////////// // Model types: used internally @@ -134,6 +135,16 @@ const MODEL_TYPE_MAPPING = new Map(); const MODEL_NAME_TO_CLASS_MAPPING = new Map(); const MODEL_CLASS_TO_NAME_MAPPING = new Map(); +/** + * Get model class from name + * @param {string} name + * @returns + */ +export function getModelClassFromName(name){ + let cls = MODEL_NAME_TO_CLASS_MAPPING.get(name); + if (!cls) console.warn(name + ' undefined'); + return cls; +} /** * Constructs an InferenceSession using a model file located at the specified path. @@ -1016,7 +1027,8 @@ export class PreTrainedModel extends Callable { for (const model_mapping of generate_compatible_mappings) { const supported_models = model_mapping.get(modelType); if (supported_models) { - generate_compatible_classes.add(supported_models[0]); + const name = supported_models[0] || MODEL_CLASS_TO_NAME_MAPPING.get(supported_models); + generate_compatible_classes.add(name); } } @@ -5886,7 +5898,7 @@ export class PretrainedMixin { if (!modelInfo) { continue; // Item not found in this mapping } - return await modelInfo[1].from_pretrained(pretrained_model_name_or_path, options); + return await (modelInfo[1] || modelInfo).from_pretrained(pretrained_model_name_or_path, options); } if (this.BASE_IF_FAIL) { @@ -5899,316 +5911,319 @@ export class PretrainedMixin { } const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([ - ['bert', ['BertModel', BertModel]], - ['nomic_bert', ['NomicBertModel', NomicBertModel]], - ['roformer', ['RoFormerModel', RoFormerModel]], - ['electra', ['ElectraModel', ElectraModel]], - ['esm', ['EsmModel', EsmModel]], - ['convbert', ['ConvBertModel', ConvBertModel]], - ['camembert', ['CamembertModel', CamembertModel]], - ['deberta', ['DebertaModel', DebertaModel]], - ['deberta-v2', ['DebertaV2Model', DebertaV2Model]], - ['mpnet', ['MPNetModel', MPNetModel]], - ['albert', ['AlbertModel', AlbertModel]], - ['distilbert', ['DistilBertModel', DistilBertModel]], - ['roberta', ['RobertaModel', RobertaModel]], - ['xlm', ['XLMModel', XLMModel]], - ['xlm-roberta', ['XLMRobertaModel', XLMRobertaModel]], - ['clap', ['ClapModel', ClapModel]], - ['clip', ['CLIPModel', CLIPModel]], - ['clipseg', ['CLIPSegModel', CLIPSegModel]], - ['chinese_clip', ['ChineseCLIPModel', ChineseCLIPModel]], - ['siglip', ['SiglipModel', SiglipModel]], - ['mobilebert', ['MobileBertModel', MobileBertModel]], - ['squeezebert', ['SqueezeBertModel', SqueezeBertModel]], - ['wav2vec2', ['Wav2Vec2Model', Wav2Vec2Model]], - ['wav2vec2-bert', ['Wav2Vec2BertModel', Wav2Vec2BertModel]], - ['unispeech', ['UniSpeechModel', UniSpeechModel]], - ['unispeech-sat', ['UniSpeechSatModel', UniSpeechSatModel]], - ['hubert', ['HubertModel', HubertModel]], - ['wavlm', ['WavLMModel', WavLMModel]], - ['audio-spectrogram-transformer', ['ASTModel', ASTModel]], - ['vits', ['VitsModel', VitsModel]], - - ['detr', ['DetrModel', DetrModel]], - ['table-transformer', ['TableTransformerModel', TableTransformerModel]], - ['vit', ['ViTModel', ViTModel]], - ['mobilevit', ['MobileViTModel', MobileViTModel]], - ['owlvit', ['OwlViTModel', OwlViTModel]], - ['owlv2', ['Owlv2Model', Owlv2Model]], - ['beit', ['BeitModel', BeitModel]], - ['deit', ['DeiTModel', DeiTModel]], - ['convnext', ['ConvNextModel', ConvNextModel]], - ['convnextv2', ['ConvNextV2Model', ConvNextV2Model]], - ['dinov2', ['Dinov2Model', Dinov2Model]], - ['resnet', ['ResNetModel', ResNetModel]], - ['swin', ['SwinModel', SwinModel]], - ['swin2sr', ['Swin2SRModel', Swin2SRModel]], - ['donut-swin', ['DonutSwinModel', DonutSwinModel]], - ['yolos', ['YolosModel', YolosModel]], - ['dpt', ['DPTModel', DPTModel]], - ['glpn', ['GLPNModel', GLPNModel]], - - ['hifigan', ['SpeechT5HifiGan', SpeechT5HifiGan]], - ['efficientnet', ['EfficientNetModel', EfficientNetModel]], + ['bert', BertModel], + ['nomic_bert', NomicBertModel], + ['roformer', RoFormerModel], + ['electra', ElectraModel], + ['esm', EsmModel], + ['convbert', ConvBertModel], + ['camembert', CamembertModel], + ['deberta', DebertaModel], + ['deberta-v2', DebertaV2Model], + ['mpnet', MPNetModel], + ['albert', AlbertModel], + ['distilbert', DistilBertModel], + ['roberta', RobertaModel], + ['xlm', XLMModel], + ['xlm-roberta', XLMRobertaModel], + ['clap', ClapModel], + ['clip', CLIPModel], + ['clipseg', CLIPSegModel], + ['chinese_clip', ChineseCLIPModel], + ['siglip', SiglipModel], + ['mobilebert', MobileBertModel], + ['squeezebert', SqueezeBertModel], + ['wav2vec2', Wav2Vec2Model], + ['wav2vec2-bert', Wav2Vec2BertModel], + ['unispeech', UniSpeechModel], + ['unispeech-sat', UniSpeechSatModel], + ['hubert', HubertModel], + ['wavlm', WavLMModel], + ['audio-spectrogram-transformer', ASTModel], + ['vits', VitsModel], + + ['detr', DetrModel], + ['table-transformer', TableTransformerModel], + ['vit', ViTModel], + ['mobilevit', MobileViTModel], + ['owlvit', OwlViTModel], + ['owlv2', Owlv2Model], + ['beit', BeitModel], + ['deit', DeiTModel], + ['convnext', ConvNextModel], + ['convnextv2', ConvNextV2Model], + ['dinov2', Dinov2Model], + ['resnet', ResNetModel], + ['swin', SwinModel], + ['swin2sr', Swin2SRModel], + ['donut-swin', DonutSwinModel], + ['yolos', YolosModel], + ['dpt', DPTModel], + ['glpn', GLPNModel], + + ['hifigan', SpeechT5HifiGan], + ['efficientnet', EfficientNetModel], ]); const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([ - ['t5', ['T5Model', T5Model]], - ['longt5', ['LongT5Model', LongT5Model]], - ['mt5', ['MT5Model', MT5Model]], - ['bart', ['BartModel', BartModel]], - ['mbart', ['MBartModel', MBartModel]], - ['marian', ['MarianModel', MarianModel]], - ['whisper', ['WhisperModel', WhisperModel]], - ['m2m_100', ['M2M100Model', M2M100Model]], - ['blenderbot', ['BlenderbotModel', BlenderbotModel]], - ['blenderbot-small', ['BlenderbotSmallModel', BlenderbotSmallModel]], + ['t5', T5Model], + ['longt5', LongT5Model], + ['mt5', MT5Model], + ['bart', BartModel], + ['mbart', MBartModel], + ['marian', MarianModel], + ['whisper', WhisperModel], + ['m2m_100', M2M100Model], + ['blenderbot', BlenderbotModel], + ['blenderbot-small', BlenderbotSmallModel], ]); const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([ - ['bloom', ['BloomModel', BloomModel]], - ['gpt2', ['GPT2Model', GPT2Model]], - ['gptj', ['GPTJModel', GPTJModel]], - ['gpt_bigcode', ['GPTBigCodeModel', GPTBigCodeModel]], - ['gpt_neo', ['GPTNeoModel', GPTNeoModel]], - ['gpt_neox', ['GPTNeoXModel', GPTNeoXModel]], - ['codegen', ['CodeGenModel', CodeGenModel]], - ['llama', ['LlamaModel', LlamaModel]], - ['openelm', ['OpenELMModel', OpenELMModel]], - ['qwen2', ['Qwen2Model', Qwen2Model]], - ['phi', ['PhiModel', PhiModel]], - ['phi3', ['Phi3Model', Phi3Model]], - ['mpt', ['MptModel', MptModel]], - ['opt', ['OPTModel', OPTModel]], - ['mistral', ['MistralModel', MistralModel]], - ['starcoder2', ['Starcoder2Model', Starcoder2Model]], - ['falcon', ['FalconModel', FalconModel]], - ['stablelm', ['StableLmModel', StableLmModel]], + ['bloom', BloomModel], + ['gpt2', GPT2Model], + ['gptj', GPTJModel], + ['gpt_bigcode', GPTBigCodeModel], + ['gpt_neo', GPTNeoModel], + ['gpt_neox', GPTNeoXModel], + ['codegen', CodeGenModel], + ['llama', LlamaModel], + ['openelm', OpenELMModel], + ['qwen2', Qwen2Model], + ['phi', PhiModel], + ['phi3', Phi3Model], + ['mpt', MptModel], + ['opt', OPTModel], + ['mistral', MistralModel], + ['starcoder2', Starcoder2Model], + ['falcon', FalconModel], + ['stablelm', StableLmModel], ]); const MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = new Map([ - ['speecht5', ['SpeechT5ForSpeechToText', SpeechT5ForSpeechToText]], - ['whisper', ['WhisperForConditionalGeneration', WhisperForConditionalGeneration]], + ['speecht5', SpeechT5ForSpeechToText], + ['whisper', WhisperForConditionalGeneration], ]); const MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = new Map([ - ['speecht5', ['SpeechT5ForTextToSpeech', SpeechT5ForTextToSpeech]], + ['speecht5', SpeechT5ForTextToSpeech], ]); +// @ts-ignore const MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = new Map([ - ['vits', ['VitsModel', VitsModel]], - ['musicgen', ['MusicgenForConditionalGeneration', MusicgenForConditionalGeneration]], + ['vits', VitsModel], + ['musicgen', MusicgenForConditionalGeneration], ]); +// @ts-ignore const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([ - ['bert', ['BertForSequenceClassification', BertForSequenceClassification]], - ['roformer', ['RoFormerForSequenceClassification', RoFormerForSequenceClassification]], - ['electra', ['ElectraForSequenceClassification', ElectraForSequenceClassification]], - ['esm', ['EsmForSequenceClassification', EsmForSequenceClassification]], - ['convbert', ['ConvBertForSequenceClassification', ConvBertForSequenceClassification]], - ['camembert', ['CamembertForSequenceClassification', CamembertForSequenceClassification]], - ['deberta', ['DebertaForSequenceClassification', DebertaForSequenceClassification]], - ['deberta-v2', ['DebertaV2ForSequenceClassification', DebertaV2ForSequenceClassification]], - ['mpnet', ['MPNetForSequenceClassification', MPNetForSequenceClassification]], - ['albert', ['AlbertForSequenceClassification', AlbertForSequenceClassification]], - ['distilbert', ['DistilBertForSequenceClassification', DistilBertForSequenceClassification]], - ['roberta', ['RobertaForSequenceClassification', RobertaForSequenceClassification]], - ['xlm', ['XLMForSequenceClassification', XLMForSequenceClassification]], - ['xlm-roberta', ['XLMRobertaForSequenceClassification', XLMRobertaForSequenceClassification]], - ['bart', ['BartForSequenceClassification', BartForSequenceClassification]], - ['mbart', ['MBartForSequenceClassification', MBartForSequenceClassification]], - ['mobilebert', ['MobileBertForSequenceClassification', MobileBertForSequenceClassification]], - ['squeezebert', ['SqueezeBertForSequenceClassification', SqueezeBertForSequenceClassification]], + ['bert', BertForSequenceClassification], + ['roformer', RoFormerForSequenceClassification], + ['electra', ElectraForSequenceClassification], + ['esm', EsmForSequenceClassification], + ['convbert', ConvBertForSequenceClassification], + ['camembert', CamembertForSequenceClassification], + ['deberta', DebertaForSequenceClassification], + ['deberta-v2', DebertaV2ForSequenceClassification], + ['mpnet', MPNetForSequenceClassification], + ['albert', AlbertForSequenceClassification], + ['distilbert', DistilBertForSequenceClassification], + ['roberta', RobertaForSequenceClassification], + ['xlm', XLMForSequenceClassification], + ['xlm-roberta', XLMRobertaForSequenceClassification], + ['bart', BartForSequenceClassification], + ['mbart', MBartForSequenceClassification], + ['mobilebert', MobileBertForSequenceClassification], + ['squeezebert', SqueezeBertForSequenceClassification], ]); const MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = new Map([ - ['bert', ['BertForTokenClassification', BertForTokenClassification]], - ['roformer', ['RoFormerForTokenClassification', RoFormerForTokenClassification]], - ['electra', ['ElectraForTokenClassification', ElectraForTokenClassification]], - ['esm', ['EsmForTokenClassification', EsmForTokenClassification]], - ['convbert', ['ConvBertForTokenClassification', ConvBertForTokenClassification]], - ['camembert', ['CamembertForTokenClassification', CamembertForTokenClassification]], - ['deberta', ['DebertaForTokenClassification', DebertaForTokenClassification]], - ['deberta-v2', ['DebertaV2ForTokenClassification', DebertaV2ForTokenClassification]], - ['mpnet', ['MPNetForTokenClassification', MPNetForTokenClassification]], - ['distilbert', ['DistilBertForTokenClassification', DistilBertForTokenClassification]], - ['roberta', ['RobertaForTokenClassification', RobertaForTokenClassification]], - ['xlm', ['XLMForTokenClassification', XLMForTokenClassification]], - ['xlm-roberta', ['XLMRobertaForTokenClassification', XLMRobertaForTokenClassification]], + ['bert', BertForTokenClassification], + ['roformer', RoFormerForTokenClassification], + ['electra', ElectraForTokenClassification], + ['esm', EsmForTokenClassification], + ['convbert', ConvBertForTokenClassification], + ['camembert', CamembertForTokenClassification], + ['deberta', DebertaForTokenClassification], + ['deberta-v2', DebertaV2ForTokenClassification], + ['mpnet', MPNetForTokenClassification], + ['distilbert', DistilBertForTokenClassification], + ['roberta', RobertaForTokenClassification], + ['xlm', XLMForTokenClassification], + ['xlm-roberta', XLMRobertaForTokenClassification], ]); const MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = new Map([ - ['t5', ['T5ForConditionalGeneration', T5ForConditionalGeneration]], - ['longt5', ['LongT5ForConditionalGeneration', LongT5ForConditionalGeneration]], - ['mt5', ['MT5ForConditionalGeneration', MT5ForConditionalGeneration]], - ['bart', ['BartForConditionalGeneration', BartForConditionalGeneration]], - ['mbart', ['MBartForConditionalGeneration', MBartForConditionalGeneration]], - ['marian', ['MarianMTModel', MarianMTModel]], - ['m2m_100', ['M2M100ForConditionalGeneration', M2M100ForConditionalGeneration]], - ['blenderbot', ['BlenderbotForConditionalGeneration', BlenderbotForConditionalGeneration]], - ['blenderbot-small', ['BlenderbotSmallForConditionalGeneration', BlenderbotSmallForConditionalGeneration]], + ['t5', T5ForConditionalGeneration], + ['longt5', LongT5ForConditionalGeneration], + ['mt5', MT5ForConditionalGeneration], + ['bart', BartForConditionalGeneration], + ['mbart', MBartForConditionalGeneration], + ['marian', MarianMTModel], + ['m2m_100', M2M100ForConditionalGeneration], + ['blenderbot', BlenderbotForConditionalGeneration], + ['blenderbot-small', BlenderbotSmallForConditionalGeneration], ]); const MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = new Map([ - ['bloom', ['BloomForCausalLM', BloomForCausalLM]], - ['gpt2', ['GPT2LMHeadModel', GPT2LMHeadModel]], - ['gptj', ['GPTJForCausalLM', GPTJForCausalLM]], - ['gpt_bigcode', ['GPTBigCodeForCausalLM', GPTBigCodeForCausalLM]], - ['gpt_neo', ['GPTNeoForCausalLM', GPTNeoForCausalLM]], - ['gpt_neox', ['GPTNeoXForCausalLM', GPTNeoXForCausalLM]], - ['codegen', ['CodeGenForCausalLM', CodeGenForCausalLM]], - ['llama', ['LlamaForCausalLM', LlamaForCausalLM]], - ['openelm', ['OpenELMForCausalLM', OpenELMForCausalLM]], - ['qwen2', ['Qwen2ForCausalLM', Qwen2ForCausalLM]], - ['phi', ['PhiForCausalLM', PhiForCausalLM]], - ['phi3', ['Phi3ForCausalLM', Phi3ForCausalLM]], - ['mpt', ['MptForCausalLM', MptForCausalLM]], - ['opt', ['OPTForCausalLM', OPTForCausalLM]], - ['mbart', ['MBartForCausalLM', MBartForCausalLM]], - ['mistral', ['MistralForCausalLM', MistralForCausalLM]], - ['starcoder2', ['Starcoder2ForCausalLM', Starcoder2ForCausalLM]], - ['falcon', ['FalconForCausalLM', FalconForCausalLM]], - ['trocr', ['TrOCRForCausalLM', TrOCRForCausalLM]], - ['stablelm', ['StableLmForCausalLM', StableLmForCausalLM]], + ['bloom', BloomForCausalLM], + ['gpt2', GPT2LMHeadModel], + ['gptj', GPTJForCausalLM], + ['gpt_bigcode', GPTBigCodeForCausalLM], + ['gpt_neo', GPTNeoForCausalLM], + ['gpt_neox', GPTNeoXForCausalLM], + ['codegen', CodeGenForCausalLM], + ['llama', LlamaForCausalLM], + ['openelm', OpenELMForCausalLM], + ['qwen2', Qwen2ForCausalLM], + ['phi', PhiForCausalLM], + ['phi3', Phi3ForCausalLM], + ['mpt', MptForCausalLM], + ['opt', OPTForCausalLM], + ['mbart', MBartForCausalLM], + ['mistral', MistralForCausalLM], + ['starcoder2', Starcoder2ForCausalLM], + ['falcon', FalconForCausalLM], + ['trocr', TrOCRForCausalLM], + ['stablelm', StableLmForCausalLM], ]); const MODEL_FOR_MASKED_LM_MAPPING_NAMES = new Map([ - ['bert', ['BertForMaskedLM', BertForMaskedLM]], - ['roformer', ['RoFormerForMaskedLM', RoFormerForMaskedLM]], - ['electra', ['ElectraForMaskedLM', ElectraForMaskedLM]], - ['esm', ['EsmForMaskedLM', EsmForMaskedLM]], - ['convbert', ['ConvBertForMaskedLM', ConvBertForMaskedLM]], - ['camembert', ['CamembertForMaskedLM', CamembertForMaskedLM]], - ['deberta', ['DebertaForMaskedLM', DebertaForMaskedLM]], - ['deberta-v2', ['DebertaV2ForMaskedLM', DebertaV2ForMaskedLM]], - ['mpnet', ['MPNetForMaskedLM', MPNetForMaskedLM]], - ['albert', ['AlbertForMaskedLM', AlbertForMaskedLM]], - ['distilbert', ['DistilBertForMaskedLM', DistilBertForMaskedLM]], - ['roberta', ['RobertaForMaskedLM', RobertaForMaskedLM]], - ['xlm', ['XLMWithLMHeadModel', XLMWithLMHeadModel]], - ['xlm-roberta', ['XLMRobertaForMaskedLM', XLMRobertaForMaskedLM]], - ['mobilebert', ['MobileBertForMaskedLM', MobileBertForMaskedLM]], - ['squeezebert', ['SqueezeBertForMaskedLM', SqueezeBertForMaskedLM]], + ['bert', BertForMaskedLM], + ['roformer', RoFormerForMaskedLM], + ['electra', ElectraForMaskedLM], + ['esm', EsmForMaskedLM], + ['convbert', ConvBertForMaskedLM], + ['camembert', CamembertForMaskedLM], + ['deberta', DebertaForMaskedLM], + ['deberta-v2', DebertaV2ForMaskedLM], + ['mpnet', MPNetForMaskedLM], + ['albert', AlbertForMaskedLM], + ['distilbert', DistilBertForMaskedLM], + ['roberta', RobertaForMaskedLM], + ['xlm', XLMWithLMHeadModel], + ['xlm-roberta', XLMRobertaForMaskedLM], + ['mobilebert', MobileBertForMaskedLM], + ['squeezebert', SqueezeBertForMaskedLM], ]); const MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = new Map([ - ['bert', ['BertForQuestionAnswering', BertForQuestionAnswering]], - ['roformer', ['RoFormerForQuestionAnswering', RoFormerForQuestionAnswering]], - ['electra', ['ElectraForQuestionAnswering', ElectraForQuestionAnswering]], - ['convbert', ['ConvBertForQuestionAnswering', ConvBertForQuestionAnswering]], - ['camembert', ['CamembertForQuestionAnswering', CamembertForQuestionAnswering]], - ['deberta', ['DebertaForQuestionAnswering', DebertaForQuestionAnswering]], - ['deberta-v2', ['DebertaV2ForQuestionAnswering', DebertaV2ForQuestionAnswering]], - ['mpnet', ['MPNetForQuestionAnswering', MPNetForQuestionAnswering]], - ['albert', ['AlbertForQuestionAnswering', AlbertForQuestionAnswering]], - ['distilbert', ['DistilBertForQuestionAnswering', DistilBertForQuestionAnswering]], - ['roberta', ['RobertaForQuestionAnswering', RobertaForQuestionAnswering]], - ['xlm', ['XLMForQuestionAnswering', XLMForQuestionAnswering]], - ['xlm-roberta', ['XLMRobertaForQuestionAnswering', XLMRobertaForQuestionAnswering]], - ['mobilebert', ['MobileBertForQuestionAnswering', MobileBertForQuestionAnswering]], - ['squeezebert', ['SqueezeBertForQuestionAnswering', SqueezeBertForQuestionAnswering]], + ['bert', BertForQuestionAnswering], + ['roformer', RoFormerForQuestionAnswering], + ['electra', ElectraForQuestionAnswering], + ['convbert', ConvBertForQuestionAnswering], + ['camembert', CamembertForQuestionAnswering], + ['deberta', DebertaForQuestionAnswering], + ['deberta-v2', DebertaV2ForQuestionAnswering], + ['mpnet', MPNetForQuestionAnswering], + ['albert', AlbertForQuestionAnswering], + ['distilbert', DistilBertForQuestionAnswering], + ['roberta', RobertaForQuestionAnswering], + ['xlm', XLMForQuestionAnswering], + ['xlm-roberta', XLMRobertaForQuestionAnswering], + ['mobilebert', MobileBertForQuestionAnswering], + ['squeezebert', SqueezeBertForQuestionAnswering], ]); const MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = new Map([ - ['vision-encoder-decoder', ['VisionEncoderDecoderModel', VisionEncoderDecoderModel]], + ['vision-encoder-decoder', VisionEncoderDecoderModel], ]); const MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = new Map([ - ['llava', ['LlavaForConditionalGeneration', LlavaForConditionalGeneration]], + ['llava', LlavaForConditionalGeneration], ]); const MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = new Map([ - ['vision-encoder-decoder', ['VisionEncoderDecoderModel', VisionEncoderDecoderModel]], + ['vision-encoder-decoder', VisionEncoderDecoderModel], ]); const MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = new Map([ - ['vit', ['ViTForImageClassification', ViTForImageClassification]], - ['mobilevit', ['MobileViTForImageClassification', MobileViTForImageClassification]], - ['beit', ['BeitForImageClassification', BeitForImageClassification]], - ['deit', ['DeiTForImageClassification', DeiTForImageClassification]], - ['convnext', ['ConvNextForImageClassification', ConvNextForImageClassification]], - ['convnextv2', ['ConvNextV2ForImageClassification', ConvNextV2ForImageClassification]], - ['dinov2', ['Dinov2ForImageClassification', Dinov2ForImageClassification]], - ['resnet', ['ResNetForImageClassification', ResNetForImageClassification]], - ['swin', ['SwinForImageClassification', SwinForImageClassification]], - ['segformer', ['SegformerForImageClassification', SegformerForImageClassification]], - ['efficientnet', ['EfficientNetForImageClassification', EfficientNetForImageClassification]], + ['vit', ViTForImageClassification], + ['mobilevit', MobileViTForImageClassification], + ['beit', BeitForImageClassification], + ['deit', DeiTForImageClassification], + ['convnext', ConvNextForImageClassification], + ['convnextv2', ConvNextV2ForImageClassification], + ['dinov2', Dinov2ForImageClassification], + ['resnet', ResNetForImageClassification], + ['swin', SwinForImageClassification], + ['segformer', SegformerForImageClassification], + ['efficientnet', EfficientNetForImageClassification], ]); const MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = new Map([ - ['detr', ['DetrForObjectDetection', DetrForObjectDetection]], - ['table-transformer', ['TableTransformerForObjectDetection', TableTransformerForObjectDetection]], - ['yolos', ['YolosForObjectDetection', YolosForObjectDetection]], + ['detr', DetrForObjectDetection], + ['table-transformer', TableTransformerForObjectDetection], + ['yolos', YolosForObjectDetection], ]); const MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = new Map([ - ['owlvit', ['OwlViTForObjectDetection', OwlViTForObjectDetection]], - ['owlv2', ['Owlv2ForObjectDetection', Owlv2ForObjectDetection]], + ['owlvit', OwlViTForObjectDetection], + ['owlv2', Owlv2ForObjectDetection], ]); const MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = new Map([ - ['detr', ['DetrForSegmentation', DetrForSegmentation]], - ['clipseg', ['CLIPSegForImageSegmentation', CLIPSegForImageSegmentation]], + ['detr', DetrForSegmentation], + ['clipseg', CLIPSegForImageSegmentation], ]); const MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = new Map([ - ['segformer', ['SegformerForSemanticSegmentation', SegformerForSemanticSegmentation]], + ['segformer', SegformerForSemanticSegmentation], ]); const MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = new Map([ - ['sam', ['SamModel', SamModel]], + ['sam', SamModel], ]); +// @ts-ignore const MODEL_FOR_CTC_MAPPING_NAMES = new Map([ - ['wav2vec2', ['Wav2Vec2ForCTC', Wav2Vec2ForCTC]], - ['wav2vec2-bert', ['Wav2Vec2BertForCTC', Wav2Vec2BertForCTC]], - ['unispeech', ['UniSpeechForCTC', UniSpeechForCTC]], - ['unispeech-sat', ['UniSpeechSatForCTC', UniSpeechSatForCTC]], - ['wavlm', ['WavLMForCTC', WavLMForCTC]], - ['hubert', ['HubertForCTC', HubertForCTC]], + ['wav2vec2', Wav2Vec2ForCTC], + ['wav2vec2-bert', Wav2Vec2BertForCTC], + ['unispeech', UniSpeechForCTC], + ['unispeech-sat', UniSpeechSatForCTC], + ['wavlm', WavLMForCTC], + ['hubert', HubertForCTC], ]); const MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = new Map([ - ['wav2vec2', ['Wav2Vec2ForSequenceClassification', Wav2Vec2ForSequenceClassification]], - ['wav2vec2-bert', ['Wav2Vec2BertForSequenceClassification', Wav2Vec2BertForSequenceClassification]], - ['unispeech', ['UniSpeechForSequenceClassification', UniSpeechForSequenceClassification]], - ['unispeech-sat', ['UniSpeechSatForSequenceClassification', UniSpeechSatForSequenceClassification]], - ['wavlm', ['WavLMForSequenceClassification', WavLMForSequenceClassification]], - ['hubert', ['HubertForSequenceClassification', HubertForSequenceClassification]], - ['audio-spectrogram-transformer', ['ASTForAudioClassification', ASTForAudioClassification]], + ['wav2vec2', Wav2Vec2ForSequenceClassification], + ['wav2vec2-bert', Wav2Vec2BertForSequenceClassification], + ['unispeech', UniSpeechForSequenceClassification], + ['unispeech-sat', UniSpeechSatForSequenceClassification], + ['wavlm', WavLMForSequenceClassification], + ['hubert', HubertForSequenceClassification], + ['audio-spectrogram-transformer', ASTForAudioClassification], ]); const MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = new Map([ - ['wavlm', ['WavLMForXVector', WavLMForXVector]], + ['wavlm', WavLMForXVector], ]); const MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = new Map([ - ['unispeech-sat', ['UniSpeechSatForAudioFrameClassification', UniSpeechSatForAudioFrameClassification]], - ['wavlm', ['WavLMForAudioFrameClassification', WavLMForAudioFrameClassification]], - ['wav2vec2', ['Wav2Vec2ForAudioFrameClassification', Wav2Vec2ForAudioFrameClassification]], + ['unispeech-sat', UniSpeechSatForAudioFrameClassification], + ['wavlm', WavLMForAudioFrameClassification], + ['wav2vec2', Wav2Vec2ForAudioFrameClassification], ]); const MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES = new Map([ - ['vitmatte', ['VitMatteForImageMatting', VitMatteForImageMatting]], + ['vitmatte', VitMatteForImageMatting], ]); const MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = new Map([ - ['swin2sr', ['Swin2SRForImageSuperResolution', Swin2SRForImageSuperResolution]], + ['swin2sr', Swin2SRForImageSuperResolution], ]) const MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = new Map([ - ['dpt', ['DPTForDepthEstimation', DPTForDepthEstimation]], - ['depth_anything', ['DepthAnythingForDepthEstimation', DepthAnythingForDepthEstimation]], - ['glpn', ['GLPNForDepthEstimation', GLPNForDepthEstimation]], + ['dpt', DPTForDepthEstimation], + ['depth_anything', DepthAnythingForDepthEstimation], + ['glpn', GLPNForDepthEstimation], ]) // NOTE: This is custom to Transformers.js, and is necessary because certain models // (e.g., CLIP) are split into vision and text components const MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES = new Map([ - ['clip', ['CLIPVisionModelWithProjection', CLIPVisionModelWithProjection]], - ['siglip', ['SiglipVisionModel', SiglipVisionModel]], + ['clip', CLIPVisionModelWithProjection], + ['siglip', SiglipVisionModel], ]) const MODEL_CLASS_TYPE_MAPPING = [ @@ -6244,31 +6259,16 @@ const MODEL_CLASS_TYPE_MAPPING = [ [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], ]; -for (const [mappings, type] of MODEL_CLASS_TYPE_MAPPING) { - // @ts-ignore - for (const [name, model] of mappings.values()) { - MODEL_TYPE_MAPPING.set(name, type); - MODEL_CLASS_TO_NAME_MAPPING.set(model, name); - MODEL_NAME_TO_CLASS_MAPPING.set(name, model); - } -} - const CUSTOM_MAPPING = [ // OVERRIDE: // TODO: Refactor to allow class to specify model - ['MusicgenForConditionalGeneration', MusicgenForConditionalGeneration, MODEL_TYPES.Musicgen], + [MusicgenForConditionalGeneration, MODEL_TYPES.Musicgen], - ['CLIPTextModelWithProjection', CLIPTextModelWithProjection, MODEL_TYPES.EncoderOnly], - ['SiglipTextModel', SiglipTextModel, MODEL_TYPES.EncoderOnly], - ['ClapTextModelWithProjection', ClapTextModelWithProjection, MODEL_TYPES.EncoderOnly], - ['ClapAudioModelWithProjection', ClapAudioModelWithProjection, MODEL_TYPES.EncoderOnly], + [CLIPTextModelWithProjection, MODEL_TYPES.EncoderOnly], + [SiglipTextModel, MODEL_TYPES.EncoderOnly], + [ClapTextModelWithProjection, MODEL_TYPES.EncoderOnly], + [ClapAudioModelWithProjection, MODEL_TYPES.EncoderOnly], ] -for (const [name, model, type] of CUSTOM_MAPPING) { - MODEL_TYPE_MAPPING.set(name, type); - MODEL_CLASS_TO_NAME_MAPPING.set(model, name); - MODEL_NAME_TO_CLASS_MAPPING.set(name, model); -} - /** * Helper class which is used to instantiate pretrained models with the `from_pretrained` function. @@ -6490,6 +6490,24 @@ export class AutoModelForImageFeatureExtraction extends PretrainedMixin { static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES]; } +for (let name in modelsExport){ + MODEL_CLASS_TO_NAME_MAPPING.set(modelsExport[name], name); + MODEL_NAME_TO_CLASS_MAPPING.set(name, modelsExport[name]); +} + +for (const [mappings, type] of MODEL_CLASS_TYPE_MAPPING) { + // @ts-ignore + for (const model of mappings.values()) { + const name = MODEL_CLASS_TO_NAME_MAPPING.get(model); + MODEL_TYPE_MAPPING.set(name, type); + } +} + +for (const [model, type] of CUSTOM_MAPPING) { + const name = MODEL_CLASS_TO_NAME_MAPPING.get(model); + MODEL_TYPE_MAPPING.set(name, type); +} + ////////////////////////////////////////////////// ////////////////////////////////////////////////// diff --git a/src/pipelines.js b/src/pipelines.js index 99662a97a..3d3538b2b 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -41,6 +41,7 @@ import { AutoModelForDepthEstimation, AutoModelForImageFeatureExtraction, PreTrainedModel, + getModelClassFromName, } from './models.js'; import { AutoProcessor, @@ -2837,7 +2838,7 @@ export class DepthEstimationPipeline extends (/** @type {new (options: ImagePipe } } -const SUPPORTED_TASKS = Object.freeze({ +const SUPPORTED_TASKS = { "text-classification": { "tokenizer": AutoTokenizer, "pipeline": TextClassificationPipeline, @@ -3121,7 +3122,7 @@ const SUPPORTED_TASKS = Object.freeze({ }, "type": "image", }, -}) +} // TODO: Add types for TASK_ALIASES @@ -3305,4 +3306,78 @@ async function loadItems(mapping, model, pretrainedOptions) { } return result; -} \ No newline at end of file +} + +/** + * Register a custom task pipeline. + * @param {string} task + * + * **Example:** Custom task: audio-feature-extraction. + * ```javascript + * import { + * Pipeline, + * read_audio, + * register_pipeline, + * pipeline + * } from '@xenova/transformers'; + * + * class AudioFeatureExtractionPipeline extends Pipeline { + * constructor(options) { + * super(options) + * } + * async _call(input, kwargs = {}) { + * input = await read_audio(input).then(input=>this.processor(input)) + * let { audio_embeds } = await this.model(input) + * return audio_embeds + * } + * } + * + * register_pipeline('audio-feature-extraction', { + * pipeline: AudioFeatureExtractionPipeline, + * model: 'ClapAudioModelWithProjection', + * processor: 'AutoProcessor', + * default_model: 'Xenova/larger_clap_music_and_speech' + * }) + * + * let pipe = await pipeline('audio-feature-extraction'); + * let out = await pipe('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav') + * console.log(out) + * ``` + */ +export function register_pipeline( + task, { + tokenizer, + pipeline: pipelineClass, + model, + processor, + default_model = '', + type = '' + } = {} +) { + if (!( + ('prototype' in pipelineClass) && + (pipelineClass.prototype instanceof Pipeline) && + ("_call" in pipelineClass.prototype) + )) { + throw Error('pipeline class must inherit from Pipeline, and contains _call') + } + + const custom = { + tokenizer: tokenizer == 'AutoTokenizer' ? AutoTokenizer : tokenizer, + pipeline: pipelineClass, + model: typeof model == 'string' ? getModelClassFromName(model) : model, + processor: processor == 'AutoProcessor' ? AutoProcessor : processor, + 'default': (!default_model ? '' : { + model: default_model + }), + type + }; + + if (task in SUPPORTED_TASKS) { + for (let key in custom) { + if (custom[key]) SUPPORTED_TASKS[task][key] = custom[key]; + } + } + else SUPPORTED_TASKS[task] = custom; + +}