In [86]:
import gc

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

In [87]:
# Batch of two chat conversations fed to both the stock and the custom pipeline.
messages = [
    # Conversation 1: multi-turn chat whose last user turn switches topic to a
    # simple linear equation.
    [
        dict(role="user", content="Can you provide ways to eat combinations of bananas and dragonfruits?"),
        dict(role="assistant", content="Sure! Here are some ways to eat bananas and dragonfruits together: 1. Banana and dragonfruit smoothie: Blend bananas and dragonfruits together with some milk and honey. 2. Banana and dragonfruit salad: Mix sliced bananas and dragonfruits together with some lemon juice and honey."),
        dict(role="user", content="What about solving an 2x + 3 = 7 equation?"),
    ],
    # Conversation 2: a single user turn, identical to conversation 1's opening turn.
    [
        dict(role="user", content="Can you provide ways to eat combinations of bananas and dragonfruits?"),
    ],
]

In [88]:
gc.collect()

# Seed for reproducibility. Decoding below is greedy, so this only matters if
# sampling is switched on later.
torch.random.manual_seed(0)
model_id = "microsoft/Phi-3-mini-4k-instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cuda",
    torch_dtype="auto",
    # NOTE(review): this executes modelling code downloaded from the Hub; recent
    # transformers releases include Phi-3 natively, so it may be removable.
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Reference run with the stock HF text-generation pipeline; the custom pipeline
# defined further down is compared against this output.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

# Greedy decoding of at most 500 *new* tokens, returning only the continuation.
# `temperature` is intentionally omitted: it only affects sampling, and setting
# it alongside do_sample=False just makes transformers emit a config warning.
generation_args = {
    "max_new_tokens": 500,
    "return_full_text": False,
    "do_sample": False,
}

output = pipe(messages, **generation_args)
print(output)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

[[{'generated_text': ' To solve the equation 2x + 3 = 7, follow these steps:\n\n1. Subtract 3 from both sides of the equation: 2x + 3 - 3 = 7 - 3\n2. Simplify: 2x = 4\n3. Divide both sides by 2: 2x/2 = 4/2\n4. Simplify: x = 2\n\nThe solution to the equation 2x + 3 = 7 is x = 2.'}], [{'generated_text': ' Certainly! Bananas and dragonfruits can be combined in various delicious ways. Here are some creative ideas for incorporating these fruits into your diet:\n\n1. Smoothie: Blend together a ripe banana, a few slices of dragon fruit, a handful of spinach or kale, a tablespoon of chia seeds, and a splash of almond milk or coconut water. Add a sweetener like honey or agave syrup if desired.\n\n2. Fruit salad: Slice a banana and a dragon fruit into bite-sized pieces. Toss them together with other fruits like strawberries, blueberries, and kiwi. Drizzle with a squeeze of lime juice and a sprinkle of fresh mint leaves for added flavor.\n\n3. Tropical fruit bowl: Arrange slices of banana and dra

In [89]:
class CustomedPipeline:
    """Hand-rolled stand-in for the HF ``text-generation`` pipeline.

    Pairs the Phi-3 tokenizer with ``CustomedPhi3ForCausalLM`` (defined
    elsewhere in this notebook) and splits generation into the usual
    ``preprocess`` -> ``forward`` -> ``postprocess`` stages so each stage can be
    run and inspected on its own.
    """

    def __init__(
            self,
            config=None,
            model_id="microsoft/Phi-3-mini-4k-instruct",
            device="cuda",
        ):
        """Load the tokenizer and build the custom model.

        Args:
            config: Forwarded unchanged to ``CustomedPhi3ForCausalLM``. Defaults
                to ``None`` so the pipeline can be constructed with no
                arguments. NOTE(review): confirm the custom model accepts
                ``None`` here.
            model_id: Hugging Face Hub id the tokenizer is loaded from.
            device: Device that ``preprocess`` moves the tokenized inputs to.
        """
        self.config = config
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = CustomedPhi3ForCausalLM(self.tokenizer, self.config)

    def preprocess(
            self,
            prompt_text,
            prefix="",
            handle_long_generation=None,
            add_special_tokens=None
            ):
        """Apply the chat template to the conversation(s) and tokenize them.

        Args:
            prompt_text: A conversation (list of ``{"role", "content"}`` dicts)
                or a list of such conversations. A batch is padded to a common
                length.
            prefix, handle_long_generation, add_special_tokens: Accepted for
                signature compatibility with the HF pipeline; currently unused.

        Returns:
            The tokenizer's dict-like encoding (``input_ids``,
            ``attention_mask``) moved to ``self.device``, plus ``"prompts"``:
            the padded prompt length as a plain int. Generated tokens start at
            this column.
        """
        # NOTE(review): batched decoder-only generation expects left padding;
        # otherwise shorter prompts get pad tokens between prompt and
        # continuation. This relies on the tokenizer's configured padding_side.
        inputs = self.tokenizer.apply_chat_template(
            prompt_text,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
            padding=True,
        ).to(self.device)
        inputs["prompts"] = inputs["input_ids"].shape[-1]
        return inputs

    def forward(self, model_inputs, max_length=500, max_new_tokens=None):
        """Run ``self.model.generate`` on the output of ``preprocess``.

        Args:
            model_inputs: Dict-like returned by ``preprocess``.
            max_length: Forwarded to ``generate`` as ``max_length`` (in HF
                ``generate`` semantics, a cap on prompt + generated tokens).
            max_new_tokens: Optional cap on generated tokens only, like the HF
                pipeline's ``max_new_tokens``. When given, it overrides
                ``max_length`` with ``prompt_len + max_new_tokens``.

        Returns:
            dict with ``"generated_sequence"`` (the value returned by
            ``generate``, prompt included) and ``"prompt_len"``.
        """
        input_ids = model_inputs["input_ids"]
        attention_mask = model_inputs["attention_mask"]
        prompt_len = model_inputs["prompts"]

        if max_new_tokens is not None:
            max_length = prompt_len + max_new_tokens

        generated_sequence = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
        )
        return {"generated_sequence": generated_sequence, "prompt_len": prompt_len}

    def postprocess(self, model_outputs, clean_up_tokenization_spaces=True):
        """Strip the prompt and anything from the first EOS on, then decode.

        Args:
            model_outputs: dict returned by ``forward``.
            clean_up_tokenization_spaces: Accepted for HF-pipeline signature
                compatibility; currently not forwarded to ``decode``.

        Returns:
            One ``[{"generated": text}]`` list per sequence, in batch order.
        """
        generated_sequence = model_outputs["generated_sequence"]
        prompt_len = model_outputs["prompt_len"]
        eos_id = self.tokenizer.eos_token_id

        result = []
        for sequence in generated_sequence:
            # Search only the generated part: the prompt itself may contain EOS
            # ids (e.g. left padding, if the pad token equals EOS). The search
            # starts *at* prompt_len so an EOS emitted as the very first
            # generated token is found too.
            eos_hits = (sequence[prompt_len:] == eos_id).nonzero(as_tuple=True)[0]
            if len(eos_hits) > 0:
                first_eos = prompt_len + eos_hits[0].item()
                # Drop the EOS and the single token right before it (presumably
                # Phi-3's end-of-turn marker -- TODO confirm). max() stops the
                # slice end from falling before its start.
                tokens = sequence[prompt_len:max(first_eos - 1, prompt_len)]
            else:
                tokens = sequence[prompt_len:]

            decoded_text = self.tokenizer.decode(tokens)
            result.append([{'generated': decoded_text}])

        return result

In [90]:
# Release the stock HF model/pipeline from the reference cell before loading the
# custom model. gc.collect() alone cannot free objects still bound to global
# names, so two copies of the weights would otherwise stay on the GPU.
# `pop(..., None)` keeps this cell safe to re-run.
globals().pop("pipe", None)
globals().pop("model", None)
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

pipe = CustomedPipeline()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [91]:
# Run the three stages explicitly so the intermediate tensors stay inspectable.
inputs = pipe.preprocess(messages)

# Match the reference run's budget. The HF pipeline's `max_new_tokens` counts
# only generated tokens, while `forward` passes `max_length` to generate, where
# (in HF `generate` semantics) it also counts the padded prompt.
outputs = pipe.forward(
    inputs,
    max_length=inputs["prompts"] + generation_args["max_new_tokens"],
)

result = pipe.postprocess(outputs)
print(result)

[[{'generated': 'To solve the equation 2x + 3 = 7, follow these steps:\n\n1. Subtract 3 from both sides of the equation:\n   2x + 3 - 3 = 7 - 3\n   2x = 4\n\n2. Divide both sides of the equation by 2:\n   2x/2 = 4/2\n   x = 2\n\nSo, the solution to the equation 2x + 3 = 7 is x = 2.'}], [{'generated': 'Certainly! Bananas and dragonfruits can be combined in various delicious ways. Here are some creative ideas for incorporating these fruits into your diet:\n\n1. Smoothie: Blend together a ripe banana, a few slices of dragon fruit, a handful of spinach or kale, a tablespoon of chia seeds, and a splash of almond milk or coconut water. Add a sweetener like honey or agave syrup if desired.\n\n2. Fruit salad: Slice a banana and a dragon fruit into bite-sized pieces and mix them together with other fruits like strawberries, blueberries, and kiwi. Drizzle with a squeeze of lime juice and a sprinkle of fresh mint leaves for added flavor.\n\n3. Tropical fruit bowl: Arrange slices of banana and dra