Skip to content

Commit

Permalink
Remove unused code
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhamugare committed May 17, 2024
1 parent 78609fb commit 693f04f
Showing 1 changed file with 3 additions and 17 deletions.
20 changes: 3 additions & 17 deletions syncode/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,36 +32,22 @@ def get_vocab_from_tokenizer(tokenizer):
return vocab

def load_model(model_name, device, quantize):
llama_models = ["Llama-7b", "Llama-13b", "CodeLlama-7b", "CodeLlama-7b-Python"]
if model_name == 'test':
model = AutoModelForCausalLM.from_pretrained('bigcode/tiny_starcoder_py').to(device)
elif model_name == 'test-instruct':
if model_name == 'test-instruct':
model = AutoModelForCausalLM.from_pretrained("rahuldshetty/tiny-starcoder-instruct")
elif model_name not in llama_models:
else:
if (quantize):
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True).eval().to(device)
else:
model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True).eval().to(device)
elif model_name in llama_models:
model_location = "/data/share/models/hugging_face/" + model_name
if (quantize):
model = LlamaForCausalLM.from_pretrained(model_location, torch_dtype=torch.bfloat16).eval().to(device)
else:
model = LlamaForCausalLM.from_pretrained(model_location).eval().to(device)
return model

def load_tokenizer(model_name):
llama_models = ["Llama-7b", "Llama-13b", "CodeLlama-7b", "CodeLlama-7b-Python"]
if model_name == 'test':
tokenizer = AutoTokenizer.from_pretrained('bigcode/tiny_starcoder_py')
elif model_name == 'test-instruct':
tokenizer = AutoTokenizer.from_pretrained("rahuldshetty/tiny-starcoder-instruct")
elif model_name not in llama_models:
else:
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True)
elif model_name in llama_models:
# TODO: remove this hardcoding
model_location = "/data/share/models/hugging_face/" + model_name
tokenizer = LlamaTokenizer.from_pretrained(model_location)
return tokenizer

class Logger:
Expand Down

0 comments on commit 693f04f

Please sign in to comment.