Commit

update model
Nghi Bui committed Jul 24, 2023
1 parent 22ae1da commit d5db4bb
Showing 4 changed files with 79 additions and 6 deletions.
7 changes: 7 additions & 0 deletions codetf/models/__init__.py
@@ -51,6 +51,13 @@ def load_model_pipeline(model_name, model_type="base", task="sum",
 
     return model
 
+def load_model_from_path(checkpoint_path, tokenizer_path, model_name, is_eval=True, load_in_8bit=False, load_in_4bit=False):
+    model_cls = registry.get_model_class(model_name)
+    model = model_cls.from_custom(checkpoint_path=checkpoint_path, tokenizer_path=tokenizer_path, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit)
+    if is_eval:
+        model.eval()
+
+    return model
 
 class ModelZoo:
     def __init__(self, config_files):
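For context, a minimal usage sketch of the new entry point. The paths and the "codet5" registry key below are illustrative assumptions, not values from this commit:

    from codetf.models import load_model_from_path

    # Load a locally fine-tuned checkpoint instead of a zoo model; the
    # paths and the model_name registry key are hypothetical examples.
    model = load_model_from_path(
        checkpoint_path="./checkpoints/my-finetuned-model",
        tokenizer_path="./checkpoints/my-finetuned-model",
        model_name="codet5",
        is_eval=True,
    )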
12 changes: 11 additions & 1 deletion codetf/models/base_model.py
@@ -45,7 +45,17 @@ def from_pretrained(model_class, model_card, load_in_8bit=False, load_in_4bit=Fa
         Build a pretrained model from default configuration file, specified by model_type.
         """
         model_config = OmegaConf.load(get_abs_path(model_class.MODEL_DICT))[model_card]
-        model_cls = model_class.load_model_from_config(model_config=model_config, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit, weight_sharding=weight_sharding)
+        model_cls = model_class.load_huggingface_model_from_config(model_config=model_config, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit, weight_sharding=weight_sharding)
 
         return model_cls
 
+    @classmethod
+    def from_custom(model_class, checkpoint_path, tokenizer_path, load_in_8bit=False, load_in_4bit=False):
+        """
+        Build a model from a user-supplied checkpoint and tokenizer path.
+        """
+        model_cls = model_class.load_custom_model(checkpoint_path, tokenizer_path, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit)
+
+        return model_cls
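The rename separates Hub-driven loading (load_huggingface_model_from_config) from path-driven loading (load_custom_model), both dispatched through classmethods on the base class. A hedged sketch of calling from_custom directly; the concrete subclass name is hypothetical, since this diff does not show it:

    # `Seq2SeqModel` stands in for whichever concrete subclass implements
    # load_custom_model; the class name and paths here are illustrative.
    model = Seq2SeqModel.from_custom(
        checkpoint_path="./checkpoints/run-42",
        tokenizer_path="./checkpoints/run-42",
        load_in_8bit=True,  # requires bitsandbytes
    )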
31 changes: 30 additions & 1 deletion codetf/models/causal_lm_models/__init__.py
@@ -29,7 +29,7 @@ def init_tokenizer(cls, model):
         return tokenizer
 
     @classmethod
-    def load_model_from_config(model_class, model_config, load_in_8bit=False, load_in_4bit=False, weight_sharding=False):
+    def load_huggingface_model_from_config(model_class, model_config, load_in_8bit=False, load_in_4bit=False, weight_sharding=False):
         checkpoint = model_config["huggingface_url"]
 
         if load_in_8bit and load_in_4bit:
@@ -79,6 +79,35 @@ def load_model_from_config(model_class, model_config, load_in_8bit=False, load_i
             model_config=model_config,
             tokenizer=tokenizer
         )
 
+    @classmethod
+    def load_custom_model(model_class, checkpoint_path, tokenizer_path, load_in_8bit=False, load_in_4bit=False):
+
+        if load_in_8bit and load_in_4bit:
+            raise ValueError("Only one of load_in_8bit or load_in_4bit can be True. Please choose one.")
+
+        if load_in_8bit:
+            model = AutoModelForCausalLM.from_pretrained(checkpoint_path,
+                                                         load_in_8bit=load_in_8bit,
+                                                         low_cpu_mem_usage=True,
+                                                         device_map="auto")
+        elif load_in_4bit:
+            model = AutoModelForCausalLM.from_pretrained(checkpoint_path,
+                                                         load_in_4bit=load_in_4bit,
+                                                         low_cpu_mem_usage=True,
+                                                         device_map="auto")
+        else:
+            model = AutoModelForCausalLM.from_pretrained(checkpoint_path,
+                                                         low_cpu_mem_usage=True,
+                                                         device_map="auto")
+
+        tokenizer = model_class.init_tokenizer(tokenizer_path)
+
+        return model_class(
+            model=model,
+            model_config=None,  # fix: `model_config` was undefined in this scope
+            tokenizer=tokenizer
+        )
 
     def forward(self, sources, max_length=512):
         encoding = self.tokenizer(sources, return_tensors='pt').to(self.device)
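The three branches above differ only in which quantization flag they forward to transformers. A standalone sketch of the same mutual-exclusion logic, with an illustrative checkpoint path:

    from transformers import AutoModelForCausalLM

    checkpoint_path = "./checkpoints/my-causal-lm"  # hypothetical local path
    load_in_8bit, load_in_4bit = True, False

    if load_in_8bit and load_in_4bit:
        raise ValueError("Only one of load_in_8bit or load_in_4bit can be True.")

    # 8-bit/4-bit loading needs bitsandbytes; device_map="auto" needs accelerate.
    kwargs = {"low_cpu_mem_usage": True, "device_map": "auto"}
    if load_in_8bit:
        kwargs["load_in_8bit"] = True
    elif load_in_4bit:
        kwargs["load_in_4bit"] = True

    model = AutoModelForCausalLM.from_pretrained(checkpoint_path, **kwargs)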
35 changes: 31 additions & 4 deletions codetf/models/seq2seq_models/__init__.py
@@ -30,15 +30,15 @@ def init_tokenizer(cls, model):
 
 
     @classmethod
-    def load_model_from_config(model_class, model_config, load_in_8bit=False, load_in_4bit=False, weight_sharding=False):
+    def load_huggingface_model_from_config(model_class, model_config, load_in_8bit=False, load_in_4bit=False, weight_sharding=False):
 
         checkpoint = model_config["huggingface_url"]
 
         if load_in_8bit and load_in_4bit:
             raise ValueError("Only one of load_in_8bit or load_in_4bit can be True. Please choose one.")
 
         # This `device` is for CodeT5+; it will be removed in the future
-        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         if weight_sharding:
             try:
                 # Try to download and load the json index file
@@ -85,12 +85,10 @@ def load_model_from_config(model_class, model_config, load_in_8bit=False, load_i
         else:
             if model_config["device_map"]:
                 model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint,
-                                                load_in_4bit=load_in_4bit,
                                                 low_cpu_mem_usage=True,
                                                 device_map=model_config["device_map"], trust_remote_code=model_config["trust_remote_code"])
             else:
                 model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint,
-                                                load_in_4bit=load_in_4bit,
                                                 low_cpu_mem_usage=True,
                                                 trust_remote_code=model_config["trust_remote_code"]).to(device)
 
@@ -103,6 +101,35 @@ def load_model_from_config(model_class, model_config, load_in_8bit=False, load_i
             tokenizer=tokenizer
         )
 
+    @classmethod
+    def load_custom_model(model_class, checkpoint_path, tokenizer_path, load_in_8bit=False, load_in_4bit=False):
+
+        if load_in_8bit and load_in_4bit:
+            raise ValueError("Only one of load_in_8bit or load_in_4bit can be True. Please choose one.")
+
+        if load_in_8bit:
+            model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_path,
+                                                          load_in_8bit=load_in_8bit,
+                                                          low_cpu_mem_usage=True,
+                                                          device_map="auto")
+        elif load_in_4bit:
+            model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_path,
+                                                          load_in_4bit=load_in_4bit,
+                                                          low_cpu_mem_usage=True,
+                                                          device_map="auto")
+        else:
+            model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_path,
+                                                          low_cpu_mem_usage=True,
+                                                          device_map="auto")
+
+        tokenizer = model_class.init_tokenizer(tokenizer_path)
+
+        return model_class(
+            model=model,
+            model_config=None,  # fix: `model_config` was undefined in this scope
+            tokenizer=tokenizer
+        )
 
 
     def forward(self, sources, max_length=512, beam_size=5):
         encoding = self.tokenizer(sources, return_tensors='pt').to(self.model.device)
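End to end, the new pieces compose as in the sketch below, extending the earlier example with quantized loading and inference. The checkpoint path and registry key are assumptions, and the return value of forward() depends on the concrete model class:

    from codetf.models import load_model_from_path

    model = load_model_from_path(
        checkpoint_path="./checkpoints/codet5-sum",  # hypothetical path
        tokenizer_path="./checkpoints/codet5-sum",
        model_name="codet5",                         # assumed registry key
        load_in_4bit=True,                           # requires bitsandbytes
    )
    # Assuming forward() decodes generated text, as for zoo-loaded models.
    outputs = model.forward(["def add(a, b):\n    return a + b"], max_length=64)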
