In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from datasets import load_dataset
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

In [None]:
from unsloth import FastLanguageModel
import nltk
from nltk.corpus import wordnet
import ssl

In [None]:
def LoadModelHUB():
    """Alternative loader: Mistral-7B-Instruct straight from the Hugging Face Hub.

    Not called by ``main`` (which uses ``LoadModelUnsloth``); kept as a
    plain Transformers fallback without the 4-bit / LoRA setup.

    Returns
    -------
    (model, tokenizer)
    """
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
    tokenizer.padding_side = 'right'  # to avoid the future warning
    model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
    return model, tokenizer


def formatting_prompts_func(sentences):
    """Flatten a batch of StereoSet examples into a flat list of sentence strings.

    Used as ``formatting_func`` by ``FineTuneModel``. This used to live only
    inside a disabled string literal, so fine-tuning failed with NameError.

    Parameters
    ----------
    sentences : dict
        A batch mapping. Assumes ``sentences['sentences']`` is a list with one
        entry per example, each a dict whose ``'sentence'`` key holds a list of
        strings (NOTE(review): matches how ``datasets`` returns a sequence of
        dicts; confirm against the StereoSet schema).

    Returns
    -------
    list of str
        Every sentence of every example in the batch, in order.
    """
    return [
        text
        for sentence in sentences['sentences']
        for text in sentence['sentence']
    ]

In [None]:
def LoadModelUnsloth(max_seq_length=2048):
    """Load Mistral-7B in 4-bit through Unsloth and attach LoRA adapters.

    Parameters
    ----------
    max_seq_length : int, optional
        Sequence length the Unsloth model is configured for (default 2048).
        It was previously read from a global that is never defined in this
        notebook, which raised NameError. Keep it >= the length used for
        fine-tuning.

    Returns
    -------
    (model, tokenizer)
        The PEFT-wrapped model returned by ``FastLanguageModel.get_peft_model``
        and the tokenizer returned by ``FastLanguageModel.from_pretrained``.
    """
    model, tokenizer = FastLanguageModel.from_pretrained(
        "mistralai/Mistral-7B-v0.1",
        device_map="auto",
        max_seq_length=max_seq_length,
        dtype=None,          # presumably lets Unsloth choose the dtype — confirm
        load_in_4bit=True,
    )

    model = FastLanguageModel.get_peft_model(
        model,
        r=16,                # LoRA rank
        target_modules=["q_proj", "k_proj", "v_proj",
                        "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_alpha=16,
        lora_dropout=0,      # any value works; 0 is Unsloth's optimized path
        bias="none",         # any value works; "none" is Unsloth's optimized path
        use_gradient_checkpointing=True,
        random_state=3407,
        use_rslora=False,    # rank-stabilized LoRA disabled
        loftq_config=None,   # LoftQ disabled
    )

    return model, tokenizer

In [None]:
def LoadDataset():
    """Fetch the StereoSet "intersentence" configuration, validation split.

    Returns whatever ``datasets.load_dataset`` returns for that split; the
    raw examples are consumed by ``FineTuneModel`` through its formatting
    function rather than being pre-mapped here.
    """
    repo_id = "McGill-NLP/stereoset"
    config_name = "intersentence"
    stereoset = load_dataset(repo_id, config_name, split='validation')
    return stereoset

In [None]:
def FineTuneModel(model, tokenizer, dataset, max_seq_length=512):
    """Supervised fine-tuning of ``model`` on ``dataset`` with TRL's SFTTrainer.

    Parameters
    ----------
    model :
        The causal LM to train (``main`` passes the LoRA-wrapped Unsloth model).
    tokenizer :
        Tokenizer matching ``model``. It is now forwarded to the trainer;
        previously it was accepted but silently ignored.
    dataset :
        Training dataset; batches are turned into text by
        ``formatting_prompts_func``.
    max_seq_length : int, optional
        Maximum tokenized length per training example (default 512).

    Returns
    -------
    The trainer instance, after ``train()`` has completed.

    NOTE(review): ``tokenizer=`` / ``max_seq_length=`` as SFTTrainer kwargs
    match the TRL generation that still exports
    DataCollatorForCompletionOnlyLM; newer TRL moves them to
    ``processing_class`` / ``SFTConfig``.
    """
    trainer = SFTTrainer(
        model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        formatting_func=formatting_prompts_func,
        max_seq_length=max_seq_length,
    )

    trainer.train()
    return trainer

In [None]:
def InitNLTK():
    """Download the NLTK corpora ('wordnet', 'omw-1.4') used by the synonym lookup.

    HTTPS certificate verification is switched off only for the duration of
    the downloads (presumably a workaround for environments missing root
    certificates — confirm it is still needed). The previous default HTTPS
    context factory is restored afterwards, even if a download raises. That
    way the rest of the process is not left with TLS verification disabled,
    which the original global monkeypatch did.
    """
    previous_factory = getattr(ssl, "_create_default_https_context", None)
    unverified_factory = getattr(ssl, "_create_unverified_context", None)
    # Only swap when both hooks exist (the original silently skipped otherwise).
    override = previous_factory is not None and unverified_factory is not None

    if override:
        ssl._create_default_https_context = unverified_factory
    try:
        nltk.download('wordnet')
        nltk.download('omw-1.4')
    finally:
        if override:
            ssl._create_default_https_context = previous_factory

In [None]:
def GetSynonyms(word):
    """Return the distinct WordNet lemma names of every synset containing ``word``.

    The result is sorted. The original returned ``list(set(...))``, whose
    order of strings changes between interpreter runs (hash randomization),
    so callers that pick the first synonym got a different word each run.

    Parameters
    ----------
    word : str
        Word to look up via ``wordnet.synsets``.

    Returns
    -------
    list of str
        Sorted, de-duplicated lemma names; empty if WordNet has no synsets.
        May include ``word`` itself when it is one of the lemmas.
    """
    return sorted({
        lemma.name()
        for synset in wordnet.synsets(word)
        for lemma in synset.lemmas()
    })

In [None]:
def InitBiasWords(path=None):
    """Return the list of bias words to filter out of generated text.

    Parameters
    ----------
    path : str or os.PathLike, optional
        UTF-8 text file with one bias word per line. Surrounding whitespace
        is stripped and blank lines are skipped. When omitted, the small
        built-in default list is returned.

    Returns
    -------
    list of str
        A fresh list on every call, so callers may mutate it freely.
    """
    if path is None:
        return ['bad', 'black', 'slave']

    with open(path, encoding='utf-8') as handle:
        return [line.strip() for line in handle if line.strip()]

In [None]:
def MitigateBias(text, bias_words, get_synonyms=None):
    """Replace bias words in ``text`` with a synonym that is not itself a bias word.

    Only tokens whose punctuation-stripped, lower-cased form is in
    ``bias_words`` are rewritten; all other tokens are kept verbatim. The
    original replaced *every* word that had any synonym, which garbled
    unbiased text.

    Parameters
    ----------
    text : str
        Text to clean. It is split on whitespace and re-joined with single
        spaces.
    bias_words : iterable of str or None
        Words to replace, compared case-insensitively. ``None`` or empty
        means nothing is replaced.
    get_synonyms : callable, optional
        ``word -> iterable of str`` synonym source; defaults to
        ``GetSynonyms`` (WordNet). Injectable for tests or other thesauri.

    Returns
    -------
    str
        The rewritten text. For each bias word, the alphabetically first
        acceptable synonym is used, so the output is deterministic.
        Underscores in synonyms become spaces, leading capitalisation and
        surrounding punctuation are preserved, and a bias word with no
        acceptable synonym is left unchanged.
    """
    import string  # local import: keeps this cell self-contained

    if get_synonyms is None:
        get_synonyms = GetSynonyms
    blocked = {w.lower() for w in (bias_words or ())}
    punct = string.punctuation

    new_words = []
    for word in text.split():
        core = word.strip(punct)
        key = core.lower()
        if not key or key not in blocked:
            new_words.append(word)
            continue

        candidates = sorted({
            candidate
            for candidate in (syn.replace('_', ' ') for syn in get_synonyms(key))
            if candidate.lower() not in blocked and candidate.lower() != key
        })
        if not candidates:
            new_words.append(word)
            continue

        replacement = candidates[0]
        if core[:1].isupper():
            replacement = replacement[:1].upper() + replacement[1:]
        # Re-attach punctuation that was glued to the word (e.g. "bad," -> "poor,").
        lead = word[:len(word) - len(word.lstrip(punct))]
        trail = word[len(word.rstrip(punct)):]
        new_words.append(lead + replacement + trail)

    return ' '.join(new_words)

In [None]:
def HandleQuestion(model, tokenizer, dataset, bias_words, question=None):
    """Answer one question with ``model`` and filter bias words from the answer.

    Parameters
    ----------
    model, tokenizer :
        Used to build a Transformers ``"text-generation"`` pipeline.
    dataset :
        Currently unused; kept for signature compatibility.
    bias_words : list of str or None
        Forwarded to ``MitigateBias``. When empty/None the raw generation is
        used unchanged.
    question : str, optional
        Prompt to answer. When None (default) it is read interactively with
        ``input()``; pass a string for non-interactive / reproducible runs.

    Returns
    -------
    str
        The text that was printed.

    Generation samples (``do_sample=True``) without a fixed seed, so answers
    differ between runs.
    """
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

    if question is None:
        question = input("Your question: ")
    print("finding a good answer for your question, please wait,...\n")
    answers = pipe(question, do_sample=True, max_new_tokens=100,
                   temperature=0.7, top_k=50, top_p=0.95, num_return_sequences=1)

    generated = answers[0]['generated_text']
    # MitigateBias requires the word list as its second argument; the original
    # call omitted it and raised TypeError whenever bias_words was non-empty.
    output = MitigateBias(generated, bias_words) if bias_words else generated

    print(output)
    return output

In [None]:
def main():
    """End-to-end run: fine-tune Mistral-7B on StereoSet, then answer one question.

    Steps, in order:
      1. download the NLTK data used by the synonym-based bias filter;
      2. build the bias-word list;
      3. load the 4-bit Unsloth Mistral-7B model with LoRA adapters;
      4. load the StereoSet intersentence validation split;
      5. fine-tune the model on it;
      6. answer an interactive question, filtering bias words from the reply.
    """
    InitNLTK()
    banned_terms = InitBiasWords()
    llm, llm_tokenizer = LoadModelUnsloth()
    stereoset = LoadDataset()

    FineTuneModel(llm, llm_tokenizer, stereoset)
    HandleQuestion(llm, llm_tokenizer, stereoset, banned_terms)

In [None]:
# Entry point: runs the full pipeline via main() when executed as the
# top-level module. This includes model loading, fine-tuning and the
# interactive input() prompt.
# NOTE(review): IPython/Jupyter kernels also run notebook code with
# __name__ == "__main__", so executing this cell triggers the whole run.
if __name__ == "__main__":
    main()