In [1]:
# import required libraries
import torch
from accelerate import Accelerator
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)

In [None]:
# Load SteloCoder: StarCoder backbone plus the experts and gating network.
# The tokenizer is taken from the base StarCoder repo, the weights from the MoE checkpoint.
# NOTE(review): the later cells call model.set_expert / model.reset_expert, which are not
# standard transformers methods -- confirm whether loading needs trust_remote_code=True.
device = 'cuda'  # prompt tensors are moved here before being fed to the model

base_model_id = "bigcode/starcoder"
model_id = "jlpan/moe_test"

tokenizer = AutoTokenizer.from_pretrained(base_model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map='auto',          # let accelerate decide where each layer lives
    torch_dtype=torch.float16,  # half precision to keep the memory footprint down
)

In [3]:
# We define our own generate function because, for some reason, model.generate does not always select the right expert during generation.
# Alternatively, the expert can be set before the input is passed to model.generate. This is not a manual choice: the model still
# decides which expert to use from the prompt, but that selection has to happen before inference starts. Both methods are shown below.

def generate(model, prompt, max_tokens=500, return_only_new=False, *,
             tokenizer=None, device=None, eos_token_id=0):
    """Greedily decode up to ``max_tokens`` tokens after ``prompt``.

    Every step re-runs ``model`` on the whole sequence (prompt + tokens generated so
    far) with no KV cache. This is slower than ``model.generate`` (quadratic in
    length), but it means the model sees the full sequence at every step, which is
    the workaround this notebook uses for expert selection.

    Args:
        model: causal LM called as ``model(input_ids)``; its output must expose
            ``.logits`` with the vocabulary on the last axis.
        prompt: text to continue.
        max_tokens: maximum number of tokens to generate.
        return_only_new: if True, return the ids of ``'<py>'`` followed by the
            generated ids (kept on CPU) instead of prompt + generated ids.
        tokenizer: tokenizer to use; defaults to the notebook-level ``tokenizer``.
        device: device for the prompt ids; defaults to the notebook-level ``device``.
        eos_token_id: decoding stops right after this id is produced. Defaults to 0,
            the id the notebook originally hard-coded (presumably StarCoder's
            ``<|endoftext|>`` -- confirm against ``tokenizer.eos_token_id``).

    Returns:
        Tensor of shape (1, length) with token ids; it ends with ``eos_token_id``
        if that token was generated.
    """
    if tokenizer is None:
        tokenizer = globals()['tokenizer']
    if device is None:
        device = globals()['device']

    new = tokenizer('<py>', return_tensors='pt')['input_ids']
    input_ids = tokenizer(prompt, return_tensors='pt')['input_ids'].to(device)

    # No autograd graph is needed for decoding, even if the caller forgot torch.no_grad().
    with torch.no_grad():
        for _ in range(max_tokens):
            logits = model(input_ids).logits
            # Greedy pick from the last position of the (single) batch row.
            token = torch.argmax(logits[0, -1], dim=-1).item()
            # Append along the sequence axis. squeeze()/unsqueeze() would turn a
            # length-1 sequence into a 0-d tensor, which torch.cat rejects.
            input_ids = torch.cat(
                (input_ids, torch.tensor([[token]], device=input_ids.device)), dim=1
            )
            if return_only_new:
                new = torch.cat((new, torch.tensor([[token]])), dim=1)
            if token == eos_token_id:
                break
    return new if return_only_new else input_ids

In [None]:
# Translate the same leap-year function from five source languages into Python
# using the custom `generate` (greedy decoding, at most 100 new tokens per prompt).
model.eval()  # inference mode (disables dropout)

prompt_cpp = '<code> #include <iostream> bool isLeapYear(int year) { return ((year % 4 == 0 && year % 100 != 0) || (year % 400 == 0)); } <py>'
prompt_php = '<code> function isLeapYear($year) { return (($year % 4 == 0 && $year % 100 != 0) || ($year % 400 == 0)); } <py>'
prompt_js = '<code> function isLeapYear(year) { return (year % 4 === 0 && year % 100 !== 0) || (year % 400 === 0); } <py>'
prompt_csharp = '<code> static bool IsLeapYear(int year) { return (year % 4 == 0 && year % 100 != 0) || (year % 400 == 0); } <py>'
prompt_java = '<code> public static boolean isLeapYear(int year) { return (year % 4 == 0 && year % 100 != 0) || (year % 400 == 0); } <py>'

with torch.no_grad():
    for prompt in (prompt_cpp, prompt_php, prompt_js, prompt_csharp, prompt_java):
        output = generate(model, prompt, max_tokens=100)
        # generate returns a (1, length) batch; decode its only row
        print(tokenizer.decode(output[0]))

In [None]:
# Alternative: have the model choose its expert from the prompt up front, then use the
# standard `model.generate`. `set_expert` / `reset_expert` are not part of the standard
# transformers API; presumably they come from the checkpoint's own model class.
model.eval()

with torch.no_grad():
    for prompt in [prompt_cpp, prompt_php, prompt_js, prompt_csharp, prompt_java]:
        # Named `inputs`, not `input`: a notebook-level `input` would shadow the
        # builtin for the rest of the kernel session.
        inputs = tokenizer(prompt, return_tensors='pt').to(device)

        model.set_expert(inputs['input_ids'])
        try:
            output = model.generate(**inputs, max_new_tokens=100)
        finally:
            # Always clear the selection so a failing generate() does not leave the
            # expert pinned for later cells.
            model.reset_expert()
        print(tokenizer.decode(output[0]))