From 1078c89ee76dde5492f401bab16c972423537ca0 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Wed, 25 Sep 2024 12:09:18 -0700 Subject: [PATCH] solve llama3.2 import issue --- torchchat/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchchat/model.py b/torchchat/model.py index edb0ce3d5..844aaf977 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -31,7 +31,7 @@ ) from torch.nn import functional as F -from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder +from torchtune.models.llama3_2_vision import llama3_2_vision_decoder, llama3_2_vision_encoder from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder from torchtune.modules.model_fusion import DeepFusionModel from torchtune.models.clip import clip_vision_encoder @@ -213,7 +213,7 @@ def _llama3_1(cls): def _flamingo(cls): return cls( model_type=ModelType.Flamingo, - modules={"encoder": flamingo_vision_encoder, "decoder": flamingo_decoder}, + modules={"encoder": llama3_2_vision_encoder, "decoder": llama3_2_vision_decoder}, fusion_class=DeepFusionModel, )