# Introduction

this notebook demos example of using llm in a MPS backend (apple silicon GPU) using torch 2.x

Referece:
* torch 2.x MPS Backend: https://pytorch.org/docs/stable/notes/mps.html

In [1]:
import applyllm as ap

ap.hello()

Hello from ApplyLlm


In [1]:
import os
import torch

In [2]:
# check that MPS is availabe (Metal Performance Shaders)
if not torch.backends.mps.is_available():
    print("MPS is not available")
else:
    print("MPS is available")
    mps_device = torch.device("mps")
    print(mps_device)



MPS is available
mps


## Define global variables

In [3]:
from dataclasses import dataclass
@dataclass
class DirectorySetting:
    """set the directory for the model download"""
    home_dir: str="/home/jovyan/llm-models"
    transformers_cache_home: str="core-kind/yinwang/models"
    huggingface_token_file: str="core-kind/yinwang/.cache/huggingface/token"

    def get_cache_home(self):
        """get the cache home"""
        return f"{self.home_dir}/{self.transformers_cache_home}"
    
    def get_token_file(self):
        """get the token file"""
        return f"{self.home_dir}/{self.huggingface_token_file}"
    
dir_mode_map = {
    "kf_notebook": DirectorySetting(),
    "mac_local": DirectorySetting(home_dir="/Users/yingding", transformers_cache_home="MODELS", huggingface_token_file="MODELS/.huggingface_token"),
}

model_map = {
    "llama7B-chat":     "meta-llama/Llama-2-7b-chat-hf",
    "llama13B-chat" :   "meta-llama/Llama-2-13b-chat-hf",
    "llama70B-chat" :   "meta-llama/Llama-2-70b-chat-hf",
    # "70B" : "meta-llama/Llama-2-70b-hf"
    "mistral7B-01":     "mistralai/Mistral-7B-v0.1",
    "mistral7B-inst02": "mistralai/Mistral-7B-Instruct-v0.2",
    "mixtral8x7B-01":   "mistralai/Mixtral-8x7B-v0.1",
    "mixtral8x7B-inst01":   "mistralai/Mixtral-8x7B-Instruct-v0.1", 
}

default_model_type = "mistral7B-01"
default_dir_mode = "mac_local"

dir_setting = dir_mode_map[default_dir_mode]

os.environ["WORLD_SIZE"] = "1" 
os.environ['XDG_CACHE_HOME'] = dir_setting.get_cache_home()

In [4]:
os.environ['XDG_CACHE_HOME']

'/Users/yingding/MODELS'

In [5]:
import transformers
import torch

# from transformers import AutoModelForCausalLM, AutoTokenizer
print(transformers.__version__)
print(torch.__version__)

4.37.2
2.1.2


## Choose LLM model

In [6]:
# model_type = default_model_type
model_type = "mistral7B-inst02"
# model_type = "llama7B-chat"
# model_type = "llama13B-chat"

model_name = model_map.get(model_type, default_model_type)
print(model_name)

mistralai/Mistral-7B-Instruct-v0.2


### Fast tokenizer

* https://github.com/huggingface/transformers/issues/23889#issuecomment-1584090357

In [7]:
# MAX_POSITION_EMBEDDINGS = 3072
# MAX_LENGTH = 4096

def need_token(model_type: str, model_name_prefix: str="llama"):
    """check if the model needs token"""
    return model_type.startswith(model_name_prefix)

def get_token(dir_setting: DirectorySetting):
    """get the token from the token file"""
    token_file_path = dir_setting.get_token_file()
    with open(token_file_path, "r") as file:
        # file read add a new line to the token, remove it.
        token = file.read().replace('\n', '')
    return token

if need_token(model_type):
    # kwargs = {"use_auth_token": get_token(dir_setting)}
    token_kwargs = {
        "token": get_token(dir_setting),
        # "truncation_side": "left",
        # "return_tensors": "pt",            
                    }
    print("huggingface token loaded")
else:
    token_kwargs = {}
    print("huggingface token is NOT needed")

huggingface token is NOT needed


### Load LLM Model and then Tokenizer

In [8]:
from torch import bfloat16

pipeline_kwargs = {
    "torch_dtype": torch.float16, #bfloat16 is not supported on MPS backend, float16 only on GPU accelerator
    # torch_dtype=torch.float32,
    # max_length=MAX_LENGTH,
    "max_length" : None, # remove the total length of the generated response
    "max_new_tokens" : 80,
}    



In [9]:
from transformers import AutoModelForCausalLM, AutoTokenizer
# from optimum.onnxruntime import ORTModelForCausalLM

# bnb_config = transformers.BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type='nf4',
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_compute_dtype=bfloat16
# )

# pretrained_model_name_or_path
# TODO: use model config to set the max_length
# model = ORTModelForCausalLM.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(    
  model_name,
  device_map='auto',
  # max_length= None, # remove the total length of the generated response
  # max_new_tokens=80,
  # quantization_config=bnb_config,
  # max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB',
  # torch_dtype=torch.float16
  **token_kwargs,  
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [10]:
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    # device='mps',
    #max_position_embeddings=MAX_LENGTH,
    #max_length=MAX_LENGTH,
    # device_map="auto", # put to GPU
    device="cpu", # put to CPU
    # use_auth_token=token, #transformer==4.31.0
    **token_kwargs
)

In [11]:
tokenizer

LlamaTokenizerFast(name_or_path='mistralai/Mistral-7B-Instruct-v0.2', vocab_size=32000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [12]:
print(type(tokenizer))

<class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>


### Testing token
* https://huggingface.co/docs/tokenizers/pipeline

In [13]:
inputs=["""
Q: Roger has 3 tennis balls. He buys 2 more cans of tennis balls. Each can has 4 tennis balls. How many tennis balls does he have now?
A: Roger started with 3 balls. 2 cans of 4 tennis balls each is 8 tennis balls. 3 + 8 = 11. The answer is 11.
Q: The cafeteria had 23 apples. If they used 20 to make lunch and bought 6 more, how many apples do they have?
"""]

In [14]:
input_test_encoded = tokenizer.encode(inputs[0])
print(f"{len(input_test_encoded)}")
print(input_test_encoded)

122
[1, 28705, 13, 28824, 28747, 14115, 659, 28705, 28770, 19552, 16852, 28723, 650, 957, 846, 28705, 28750, 680, 277, 509, 302, 19552, 16852, 28723, 7066, 541, 659, 28705, 28781, 19552, 16852, 28723, 1602, 1287, 19552, 16852, 1235, 400, 506, 1055, 28804, 13, 28741, 28747, 14115, 2774, 395, 28705, 28770, 16852, 28723, 28705, 28750, 277, 509, 302, 28705, 28781, 19552, 16852, 1430, 349, 28705, 28783, 19552, 16852, 28723, 28705, 28770, 648, 28705, 28783, 327, 28705, 28740, 28740, 28723, 415, 4372, 349, 28705, 28740, 28740, 28723, 13, 28824, 28747, 415, 18302, 1623, 515, 553, 28705, 28750, 28770, 979, 2815, 28723, 1047, 590, 1307, 28705, 28750, 28734, 298, 1038, 9957, 304, 7620, 28705, 28784, 680, 28725, 910, 1287, 979, 2815, 511, 590, 506, 28804, 13]


In [15]:
response_test_decoded = tokenizer.decode(input_test_encoded)
print(response_test_decoded)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


<s> 
Q: Roger has 3 tennis balls. He buys 2 more cans of tennis balls. Each can has 4 tennis balls. How many tennis balls does he have now?
A: Roger started with 3 balls. 2 cans of 4 tennis balls each is 8 tennis balls. 3 + 8 = 11. The answer is 11.
Q: The cafeteria had 23 apples. If they used 20 to make lunch and bought 6 more, how many apples do they have?



### load llm

In [16]:
# quantization_enabled = True
# bitsandbytes quantization does not work with MPS 

# quantization_enabled = False

# if quantization_enabled:
#     compression_kwargs = {
#         "load_in_8bit": True,
#         # "load_in_4bit": True,
#     }
# else:
#     compression_kwargs = {}

generator = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer, # optional
    torch_dtype=torch.float16, #bfloat16 is not supported on MPS backend
    # torch_dtype=torch.float32,
    device_map="auto",
    # max_length=MAX_LENGTH,
    max_length=None, # remove the total length of the generated response
    max_new_tokens=80, # set the size of new generated token # 200, are the token size different as the text size?
    temperature=0.01,
    **token_kwargs,
    # **compression_kwargs,
)

##### Install autopep8 or black extension in VSCode
`shift + opt + F` to auto format python code

In [17]:
from util.accelerator_utils import AcceleratorStatus

gpu_status = AcceleratorStatus.create_accelerator_status()
gpu_status.gpu_usage()

--------------------
Allocated memory : 28.008636 GB
--------------------


In [18]:
import pydantic, time
pydantic.__version__

'1.10.13'

In [19]:
def chat_gen(
    generator: transformers.pipelines.text_generation.TextGenerationPipeline, 
    tokenizer: transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast,
    gpu_status: AcceleratorStatus
):    
    def local(input_prompts: list=[], temperature: float=0.1, max_new_tokens: int=200, verbose: bool=True) -> list:
        """
        do_sample, top_k, num_return_sequences, eos_token_id are the settings 
        the TextGenerationPipeline
        
        Reference:
        https://huggingface.co/docs/transformers/generation_strategies#customize-text-generation
        """
        start = time.time()
        sequences = generator(
            input_prompts,
            do_sample=True,
            top_k=5,
            top_p=0.95,
            num_return_sequences=1,
            # pad_token_id=tokenizer.eos_token_id, # for mistral
            eos_token_id=tokenizer.eos_token_id,
            # max_length=200,
            max_new_tokens= max_new_tokens, # 200 # max number of tokens to generate in the output
            temperature=temperature,
            repetition_penalty=1.1  # without this output begins repeating
        )
        # for seq in sequences:
        #     print(f"Result: \n{seq['generated_text']}")
        
        batch_result = []
        for prompt_result in sequences: # passed a list of prompt
            result = []
            for seq in prompt_result: # 
                result.append(f"Result: \n{seq['generated_text']}")
            batch_result.append(result)
            
        end = time.time()
        duration = end - start
        
        if verbose == True:
            for prompt_result in batch_result:
                for result in prompt_result:
                    print("promt-response")
                    print(result)
            print("-"*20)
            print(f"walltime: {duration} in secs.")
            gpu_status.gpu_usage()
            
        return batch_result   
    return local
    
chat = chat_gen(generator, tokenizer, gpu_status)

In [20]:
class PromptHelper():
    """
    mistral instruction example:
    https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2
    llama2 instruction examples:
    https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF
    """
    meta = "meta-llama"
    mistral = "mistralai"
    INST_MSG_MAP = {
        mistral: """<s>[INST] You are a helpful, respectful and honest assistant.
Always answer as helpfully as possible using the context text provided.
Your answers should only answer the question once and not have any text after the answer is done.\n
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct.
If you don't know the answer to a question, please don't share false information. Just return \"</s>\"
""",
        meta: """[INST]<<SYS>>You are a helpful, respectful and honest assistant.
Always answer as helpfully as possible using the context text provided.
Your answers should only answer the question once and not have any text after the answer is done.\n
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct.
If you don't know the answer to a question, please don't share false information.<</SYS>>
"""
    }

    def __init__(self, model_name):
        self.model_name = model_name

    
    def gen_prompt(self, query: str) -> str:      
        if model_name.startswith(self.meta):
            inst_msg = self.INST_MSG_MAP[self.meta]
        elif model_name.startswith(self.mistral):
            inst_msg = self.INST_MSG_MAP[self.mistral]
        else:
            inst_msg = ""

        prompt = f"""{inst_msg}\n{query}\n[/INST]""" if query is not None or len(query) > 0 else f"""{inst_msg}\n[/INST]"""
        return prompt
    
    def get_inst_msg(self) -> str:
        return self.gen_prompt("")

In [21]:
from functools import partial

prompt_helper = PromptHelper(model_name)

def get_inputs_by_model(idx, inputs, prompt_helper):
    print(prompt_helper.model_name)
    # generate a model dependent prompt with appropriate sys instruction message
    return prompt_helper.gen_prompt(inputs[idx])

get_inputs = partial(get_inputs_by_model, inputs=inputs, prompt_helper=prompt_helper)
print(get_inputs(0))

mistralai/Mistral-7B-Instruct-v0.2
<s>[INST] You are a helpful, respectful and honest assistant.
Always answer as helpfully as possible using the context text provided.
Your answers should only answer the question once and not have any text after the answer is done.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct.
If you don't know the answer to a question, please don't share false information. Just return "</s>"


Q: Roger has 3 tennis balls. He buys 2 more cans of tennis balls. Each can has 4 tennis balls. How many tennis balls does he have now?
A: Roger started with 3 balls. 2 cans of 4 tennis balls each is 8 tennis balls. 3 + 8 = 11. The answer is 11.
Q: The cafeteria had 23 apples. If they used 20 to make lunch and bought 6 more, how many apples do they have?

[/INST]


In [22]:
verbose = True
batch_answers = chat(inputs, temperature=0.001, max_new_tokens = 80, verbose=verbose)

if not verbose:
    prompt_0_results = batch_answers[0]
    print(prompt_0_results[0])

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


promt-response
Result: 

Q: Roger has 3 tennis balls. He buys 2 more cans of tennis balls. Each can has 4 tennis balls. How many tennis balls does he have now?
A: Roger started with 3 balls. 2 cans of 4 tennis balls each is 8 tennis balls. 3 + 8 = 11. The answer is 11.
Q: The cafeteria had 23 apples. If they used 20 to make lunch and bought 6 more, how many apples do they have?
A: The cafeteria originally had 23 apples. They used 20 for lunch, so that leaves 3 apples remaining. Then they bought 6 more, making the total number of apples 3 + 6 = 9. However, we made an error in our calculation. We should have subtracted the 20 apples used from the original
--------------------
walltime: 13.999051094055176 in secs.
--------------------
Allocated memory : 32.060165 GB
--------------------


### mlflow autologging langchain
* https://mlflow.org/docs/latest/llms/langchain/guide/index.html+
* https://github.com/mlflow/mlflow/issues/9237#issuecomment-1667549626

#### Issue
* HuggingFacePipeline is not callable from mlflow run: https://github.com/langchain-ai/langchain/issues/8858

#### LangChain Callback Handler
* https://python.langchain.com/docs/integrations/providers/aim_tracking
* https://python.langchain.com/docs/integrations/providers/mlflow_tracking
* https://python.langchain.com/docs/integrations/providers/mlflow_ai_gateway
* https://python.langchain.com/docs/integrations/providers/mlflow
* https://api.python.langchain.com/en/latest/_modules/langchain_community/callbacks/mlflow_callback.html

In [23]:
import os
# os.environ["MLFLOW_TRACKING_URI"] = "./mlruns"

In [24]:
import mlflow
import logging
import time
from pprint import pprint

logging.getLogger("mlflow").setLevel(logging.DEBUG)

from langchain import PromptTemplate, LLMChain, HuggingFaceHub
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
# from langchain.callbacks import MlflowCallbackHandler

# Set the run name to time string
run_name = time.strftime("%Y-%m-%d_%H-%M-%S")
experiment_name = "langchain"
search_pattern = f"name = '{experiment_name}'"
experiments = mlflow.search_experiments(filter_string=search_pattern)

if len(experiments) < 1:
    experiment_id = mlflow.create_experiment(name=experiment_name)
    print(f"experiment with string id {experiment_id} is created.")
else:
    experiment_id = experiments[0].experiment_id
    # experiment_id = experiments.experiment_id[0]
    print(f"experiment with string id {experiment_id} is reused.")

# mlflow_callback = MlflowCallbackHandler(experiment=experiment_name, name=run_name)

mlflow.end_run()
mlflow.set_experiment(experiment_id=experiment_id)
mlflow.start_run(run_name=run_name)


llm = HuggingFacePipeline(
    pipeline=generator 
)

template = prompt_helper.gen_prompt("{input}")
prompt = PromptTemplate(template=template, input_variables=["input"])

mlflow.log_param("system_prompt", template)

llm_chain = LLMChain(prompt=prompt, llm=llm)
# llm_chain = LLMChain(prompt=prompt, llm=llm, callbacks=[mlflow_callback])

# mlflow_callback.flush_tracker(llm_chain)

# print(llm_chain.invoke({"input": inputs[0]}))
# format the output of print with multiple lines of 60 max line length
response = llm_chain.run(inputs[0])
mlflow.log_param("response", response)

# Evaluate the model on some example questions
import pandas as pd
eval_data = pd.DataFrame(
    {
        "input": [
            "What is MLflow?",
            "What is Spark?",
        ],
        "ground_truth": [
            "MLflow is an open-source platform for managing the end-to-end machine learning (ML) " +
            "lifecycle. It was developed by Databricks, a company that specializes in big data and " +
            "machine learning solutions. MLflow is designed to address the challenges that data " +
            "scientists and machine learning engineers face when developing, training, and deploying " +
            "machine learning models.",
            "Apache Spark is an open-source, distributed computing system designed for big data " +
            "processing and analytics. It was developed in response to limitations of the Hadoop " +
            "MapReduce computing model, offering improvements in speed and ease of use. Spark " +
            "provides libraries for various tasks such as data ingestion, processing, and analysis " +
            "through its components like Spark SQL for structured data, Spark Streaming for " +
            "real-time data processing, and MLlib for machine learning tasks",
        ],
    }
)

print(eval_data)

class LocalHfpModel():
    """local huggingface pipeline model"""
    def __init__(self, llm_chain):
        self.llm_chain = llm_chain
    

    def __call__(self, data):
        # single call returns string
        # response = self.llm_chain.run(data["input"].tolist())
        # self.results.append(response)
        # GPU batch
        response = self.llm_chain.batch(data["input"].tolist())
        # print(type(response))
        # print(response)
        return [ _dict["text"] for _dict in response]

# load the LocalHfpModel() to mlflow.evaluate
results = mlflow.evaluate(
    model=LocalHfpModel(llm_chain),
    model_type="question-answering",
    targets="ground_truth",
    data=eval_data,
)
print(f"See aggregated evaluation results below: \n{results.metrics}")

# Evaluation result for each data record is available in `results.tables`.
eval_table = results.tables["eval_results_table"]
print(f"See evaluation table below: \n{eval_table}")

mlflow.end_run()

pprint(response, indent=0, width=100)

experiment with string id 724336309488726134 is reused.


  warn_deprecated(
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
  string_columns = trimmed_df.columns[(df.applymap(type) == str).all(0)]
  data = data.applymap(_hash_array_like_element_as_bytes)
2024/01/31 02:19:17 INFO mlflow.models.evaluation.base: Evaluating the model with the default evaluator.
2024/01/31 02:19:17 INFO mlflow.models.evaluation.default_evaluator: Computing model predictions.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


             input                                       ground_truth
0  What is MLflow?  MLflow is an open-source platform for managing...
1   What is Spark?  Apache Spark is an open-source, distributed co...


2024/01/31 02:19:30 INFO mlflow.models.evaluation.default_evaluator: Testing metrics on first row...
Using default facebook/roberta-hate-speech-dynabench-r4-target checkpoint
2024/01/31 02:19:34 INFO mlflow.models.evaluation.default_evaluator: Evaluating builtin metrics: token_count
2024/01/31 02:19:34 INFO mlflow.models.evaluation.default_evaluator: Evaluating builtin metrics: toxicity
2024/01/31 02:19:34 INFO mlflow.models.evaluation.default_evaluator: Evaluating builtin metrics: flesch_kincaid_grade_level
2024/01/31 02:19:34 INFO mlflow.models.evaluation.default_evaluator: Evaluating builtin metrics: ari_grade_level
2024/01/31 02:19:34 INFO mlflow.models.evaluation.default_evaluator: Evaluating builtin metrics: exact_match


See aggregated evaluation results below: 
{'toxicity/v1/mean': 0.00014677690342068672, 'toxicity/v1/variance': 2.3893283430098113e-11, 'toxicity/v1/p90': 0.00015068736393004657, 'toxicity/v1/ratio': 0.0, 'flesch_kincaid_grade_level/v1/mean': 13.399999999999999, 'flesch_kincaid_grade_level/v1/variance': 7.839999999999999, 'flesch_kincaid_grade_level/v1/p90': 15.639999999999999, 'ari_grade_level/v1/mean': 17.2, 'ari_grade_level/v1/variance': 12.25, 'ari_grade_level/v1/p90': 20.0, 'exact_match/v1': 0.0}


Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

See evaluation table below: 
             input                                       ground_truth  \
0  What is MLflow?  MLflow is an open-source platform for managing...   
1   What is Spark?  Apache Spark is an open-source, distributed co...   

                                             outputs  token_count  \
0   MLflow is an open-source platform for the com...           27   
1   Apologies for any potential confusion, but yo...           73   

   toxicity/v1/score  flesch_kincaid_grade_level/v1/score  \
0           0.000142                                 16.2   
1           0.000152                                 10.6   

   ari_grade_level/v1/score  
0                      20.7  
1                      13.7  
(' A: The cafeteria started with 23 apples. They used 20, so they had 3 apples left. Then they '
 'bought 6 more, so they have 9 apples in total. The answer is 9.')


In [25]:
# # Set the run name to time string
# run_name = time.strftime("%Y-%m-%d_%H-%M-%S")
# experiment_name = "local_llm_test"
# search_pattern = f"name = '{experiment_name}'"
# experiments = mlflow.search_experiments(filter_string=search_pattern)

# if len(experiments) < 1:
#     experiment_id = mlflow.create_experiment(name=experiment_name)
#     print(f"experiment with string id {experiment_id} is created.")
# else:
#     experiment_id = experiments[0].experiment_id
#     # experiment_id = experiments.experiment_id[0]
#     print(f"experiment with string id {experiment_id} is reused.")

    
# try:
#     with mlflow.start_run(experiment_id=experiment_id, run_name=run_name) as run:
#         logged_model = mlflow.langchain.log_model(
#             lc_model=llm_chain,
#             artifact_path="models")
        
#     # Load the logged model using MLflow's Python function flavor
#     loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)

#     # Predict using the loaded model, with defined input schema from prompt template
#     print(loaded_model.predict([{"input": inputs[0]}]))
# except Exception as e:
#     print(e)
#     mlflow.end_run()


In [26]:
# We automatically log the model and trace related artifacts
# A model with name `lc_model` is registered, we can load it back as a PyFunc model
# model_name = "lc_model"
# model_version = 1
# loaded_model = mlflow.pyfunc.load_model(f"models:/{model_name}/{model_version}")
# print(loaded_model.predict(inputs))

In [27]:
import gc
def clear_mps_memory(tokenizer, generator):
    """clear the MPS memory"""
    if tokenizer is not None:
        del tokenizer
    if generator is not None:
        # need to move the model to cpu before delete.
        generator.model.cpu()
        del generator
    gc.collect()
    torch.mps.empty_cache()
    # report the GPU usage
    gpu_status.gpu_usage()


In [28]:
CLEAR_MEMORY = False
# CLEAR_MEMORY = True

if CLEAR_MEMORY:
    clear_mps_memory(tokenizer=tokenizer, generator=generator)

In [29]:
gpu_status.gpu_usage()

--------------------
Allocated memory : 33.490677 GB
--------------------
