<font color ='blue'><font size = "8">***Fine-tuning Causal LLM with LoRA***</font></font>

**This notebook loads the pretrained base GPT-2 model and fine-tunes it with Low-Rank Adaptation (LoRA), a parameter-efficient fine-tuning (PEFT) technique, to work around limited computational resources. The model is fine-tuned on a very small dataset with LoRA rank r=1, which gives the adapter very little capacity to learn new behavior and leaves little room for creativity in the generated text. Even so, the outputs follow the Question/Answer format and are coherent to a certain extent, as the examples below show.**

In [2]:
!nvidia-smi -L

GPU 0: NVIDIA GeForce 940MX (UUID: GPU-c76fe8cb-13ad-1a0a-93ba-5edd4a3a3fff)


<font size='6'>***Setting up the model***</font>

In [118]:
import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [119]:
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")


In [120]:
print(model)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)
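
The module shapes printed above are enough to recount GPT-2 small's parameters by hand. The sketch below is a back-of-the-envelope check, not part of the original notebook. It assumes that `lm_head` shares its weight with `wte` (GPT-2 ties these, and `model.parameters()` yields a shared tensor only once) and that every `Conv1D` and `LayerNorm` carries a bias; `gpt2_param_count` is a hypothetical helper name.

```python
def gpt2_param_count(vocab=50257, n_positions=1024, d_model=768, n_layer=12, d_ff=3072):
    """Hypothetical helper: count GPT-2 parameters from the layer shapes shown by print(model)."""
    def dense(d_in, d_out):
        # Conv1D / Linear layer: weight matrix plus bias vector
        return d_in * d_out + d_out

    layer_norm = 2 * d_model                      # scale (weight) and shift (bias)
    block = (
        layer_norm                                # ln_1
        + dense(d_model, 3 * d_model)             # attn.c_attn: fused Q, K, V projection (768 -> 2304)
        + dense(d_model, d_model)                 # attn.c_proj
        + layer_norm                              # ln_2
        + dense(d_model, d_ff)                    # mlp.c_fc (768 -> 3072)
        + dense(d_ff, d_model)                    # mlp.c_proj (3072 -> 768)
    )
    embeddings = vocab * d_model + n_positions * d_model   # wte + wpe
    # lm_head is tied to wte, so it contributes no extra parameters
    return embeddings + n_layer * block + layer_norm        # final term is ln_f


print(gpt2_param_count())  # 124439808
```

This figure plus the 36,864 LoRA parameters gives the 124,476,672 total reported by `net_trainable_params` further down.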


<font size='5'>***Freezing the original weights***</font>


In [121]:
for param in model.parameters():
    param.requires_grad = False     # freeze the pretrained weights; only the LoRA adapters added later are trained
    if param.ndim == 1:
        # keep 1-D params (biases, LayerNorm weights) in float32 for stable training;
        # GPT-2 already loads in float32, so this only matters for half-precision setups
        param.data = param.data.to(torch.float32)

model.gradient_checkpointing_enable()    # recompute activations in the backward pass instead of storing them (less memory, more compute)
model.enable_input_require_grads()  # make embedding outputs require grad so checkpointed blocks can still backpropagate to the adapters

<font size="5">**Configuring LoRA**</font>

In [124]:
from peft import LoraConfig, get_peft_model

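config = LoraConfig(r=1,                        # rank of the low-rank matrices A and B
                    lora_alpha=20,              # scaling factor; the LoRA update is multiplied by lora_alpha / r = 20
                    lora_dropout=0.05,          # dropout on the input of the LoRA branch only
                    target_modules=["c_attn"],  # fused Q/K/V projection; PEFT's default for GPT-2, stated explicitly
                    fan_in_fan_out=True,        # GPT-2's Conv1D stores weights as (in, out); PEFT would otherwise warn and set this itself
                    bias="none",                # keep all bias terms frozen
                    task_type="CAUSAL_LM")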

model = get_peft_model(model, config)
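
As a reference point, the standard LoRA formulation adds a trainable low-rank path next to each frozen weight matrix $W_0$, with `lora_alpha` ($\alpha$) and `r` setting its scale. The shapes below assume PEFT's default target for GPT-2, the fused query/key/value projection `c_attn` (768 inputs, 2304 outputs); `lora_dropout` is applied to $x$ on the low-rank path only. Because PEFT initialises $B$ to zero, the adapted model starts out identical to the base model.

$$
h = W_0 x + \frac{\alpha}{r}\, B A x, \qquad W_0 \in \mathbb{R}^{2304 \times 768},\quad B \in \mathbb{R}^{2304 \times r},\quad A \in \mathbb{R}^{r \times 768}
$$

$$
r = 1,\ \alpha = 20:\qquad \frac{\alpha}{r} = 20, \qquad BA = b\,a^{\top}\ \text{(rank 1)}, \qquad r\,(2304 + 768) = 3072 \ \text{trainable parameters per block}
$$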

In [125]:
def net_trainable_params(model):
    """Print the number of trainable parameters versus the model total."""
    # PEFT models also provide model.print_trainable_parameters() for the same summary
    all_params = 0
    trainable_params = 0
    for param in model.parameters():
        all_params += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()

    print(f"Trainable Parameters: {trainable_params} || Total Parameters: {all_params}, || % trainable: {100*trainable_params/all_params}")

net_trainable_params(model)

Trainable Parameters: 36864 || Total Parameters: 124476672, || % trainable: 0.029615187655402612
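
The 36,864 trainable parameters can be reproduced by hand. This sketch assumes the adapters sit only on `c_attn` (PEFT's default for GPT-2, mapping 768 inputs to 2304 outputs) and takes the frozen-weight count of 124,439,808 from the printed total minus the trainable count.

```python
n_layer = 12              # GPT-2 small has 12 transformer blocks
d_in, d_out = 768, 2304   # c_attn maps the 768-dim hidden state to fused Q, K, V (3 x 768)
r = 1                     # LoRA rank from LoraConfig

# Each adapted layer adds A (r x d_in) and B (d_out x r)
lora_params = n_layer * r * (d_in + d_out)
total_params = 124_439_808 + lora_params   # frozen GPT-2 weights + adapters
pct_trainable = 100 * lora_params / total_params

print(lora_params, total_params, pct_trainable)
```

Doubling the rank doubles the adapter size (r=2 would give 73,728 parameters), which is why r=1 is about the smallest configuration possible on this model.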


<font size="5">**Loading the custom Q&A dataset**</font>

In [126]:
import re

# Not used: a public instruction dataset from the Hugging Face Hub
# from datasets import load_dataset
# dataset = load_dataset("yahma/alpaca-cleaned")

prompt_path = '/Users/Ritwik/PythonforPractice/Project NLP/LLM/prompt_template.txt'
with open(prompt_path, "r") as file:
    text_data = file.read()

# Collapse runs of newlines for display; TextDataset in the next cell re-reads the raw file from prompt_path
text_data = re.sub(r'\n+', '\n', text_data).strip()
print(text_data)

Question: How has social media influenced communication patterns?
Answer: Social media has transformed communication by enabling instant messaging, video calls, and sharing of multimedia content, leading to faster and more accessible interactions globally.
Question: What are the psychological effects of excessive social media use?
Answer: Excessive social media use has been linked to various psychological effects such as anxiety, depression, low self-esteem, and addictive behaviors due to constant comparison and fear of missing out (FOMO).
Question: How does social media impact relationships and social interactions?
Answer: Social media can both strengthen and strain relationships. While it facilitates staying in touch with distant friends and family, excessive use can lead to decreased face-to-face interactions and misunderstandings due to misinterpretation of digital communication cues.
Question: What role does social media play in shaping public opinion and influencing political dis
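
The file follows a plain `Question: ... / Answer: ...` layout, which is also the prompt format used at inference time. If you later want to move away from `TextDataset` (for example to the 🤗 Datasets library), a small parser along these lines could turn the cleaned text into pairs. `parse_qa_pairs` is a hypothetical helper, not part of the notebook, and it assumes each answer ends where the next `Question:` line begins.

```python
import re

# Question text up to the "Answer:" line, then the answer up to the next "Question:" or end of text
QA_PATTERN = re.compile(
    r"Question:\s*(.*?)\s*\nAnswer:\s*(.*?)\s*(?=\nQuestion:|\Z)",
    re.DOTALL,
)


def parse_qa_pairs(text):
    """Hypothetical helper: split Question/Answer blocks into (question, answer) tuples."""
    return [(q.strip(), a.strip()) for q, a in QA_PATTERN.findall(text)]


sample = (
    "Question: How has social media influenced communication patterns?\n"
    "Answer: Social media has transformed communication by enabling instant messaging, video calls, "
    "and sharing of multimedia content, leading to faster and more accessible interactions globally.\n"
    "Question: What are the psychological effects of excessive social media use?\n"
    "Answer: Excessive social media use has been linked to various psychological effects such as anxiety, "
    "depression, low self-esteem, and addictive behaviors due to constant comparison and fear of missing out (FOMO)."
)

for question, answer in parse_qa_pairs(sample):
    print(question, "->", answer)
```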

In [127]:
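from transformers import TextDataset

# Tokenizes the raw file at prompt_path and slices it into consecutive 128-token blocks.
# TextDataset is deprecated in recent transformers releases in favour of the 🤗 Datasets library.
tokenized_dataset = TextDataset(tokenizer, prompt_path, block_size=128)

print(tokenized_dataset[0])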


tensor([24361,    25,  1374,   468,  1919,  2056, 12824,  6946,  7572,    30,
          198, 33706,    25,  5483,  2056,   468, 14434,  6946,   416, 15882,
         9113, 19925,    11,  2008,  3848,    11,   290,  7373,   286, 40162,
         2695,    11,  3756,   284,  5443,   290,   517,  9857, 12213, 18309,
           13,   198,   198, 24361,    25,  1867,   389,   262, 10590,  3048,
          286, 13181,  1919,  2056,   779,    30,   198, 33706,    25,  1475,
        45428,  1919,  2056,   779,   468,   587,  6692,   284,  2972, 10590,
         3048,   884,   355,  9751,    11,  8862,    11,  1877,  2116,    12,
        31869,    11,   290, 28389, 14301,  2233,   284,  6937,  7208,   290,
         3252,   286,  4814,   503,   357,    37,  2662,    46,   737,   198,
          198, 24361,    25,  1374,   857,  1919,  2056,  2928,  6958,   290,
         1919, 12213,    30,   198, 33706,    25,  5483,  2056,   460,  1111,
        12160,   290, 14022,  6958,    13,  2893,   340, 42699])
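
The block above holds exactly 128 token ids because `TextDataset` tokenizes the whole file once and slices the ids into consecutive, non-overlapping windows of `block_size`, dropping any leftover tail shorter than a full block rather than padding it. The sketch below reproduces only that slicing step in plain Python; it assumes no special tokens are added (true for GPT-2), and `chunk_token_ids` is a hypothetical name.

```python
def chunk_token_ids(token_ids, block_size=128):
    """Hypothetical helper: split token ids into consecutive, non-overlapping blocks.

    The final partial block is dropped rather than padded, so every
    training example has exactly block_size tokens.
    """
    return [
        token_ids[i : i + block_size]
        for i in range(0, len(token_ids) - block_size + 1, block_size)
    ]


# Toy input: 300 fake token ids -> two full blocks of 128; the last 44 ids are dropped
blocks = chunk_token_ids(list(range(300)), block_size=128)
print(len(blocks), [len(b) for b in blocks])
```

Because every block has the same length, the language-modeling data collator used for training can stack the examples without padding, which matters because GPT-2's tokenizer defines no pad token.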

<font size="5">**Kicking off fine-tuning**</font>

In [128]:
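from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling

output_dir = '/Users/Ritwik/PythonforPractice/ProjectNLP/LLM'
model.save_pretrained(output_dir)      # writes the adapter config and the (still untrained) adapter weights
tokenizer.save_pretrained(output_dir)  # lets the inference cell reload the tokenizer from output_dir

training_args = TrainingArguments(per_device_train_batch_size=5,
                                  gradient_accumulation_steps=2,  # effective batch: 5 x 2 = 10 sequences per optimizer step
                                  warmup_steps=10,
                                  max_steps=100,                  # takes precedence over num_train_epochs
                                  learning_rate=2e-2,
                                  num_train_epochs=1,
                                  logging_steps=1,
                                  output_dir=output_dir)

trainer = Trainer(model,
                  training_args,
                  # mlm=False selects the causal-LM objective: labels are a copy of input_ids, shifted inside the model
                  data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
                  train_dataset=tokenized_dataset)

trainer.train()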
trainer.save_model(output_dir)
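
With the default `lr_scheduler_type="linear"`, the learning rate ramps up linearly over `warmup_steps` and then decays linearly to zero at `max_steps`. The sketch below mirrors that schedule (the same formula as `transformers.get_linear_schedule_with_warmup`) in plain Python for the values used here and works out the effective batch size. The function name `linear_warmup_lr` is hypothetical, and the token count assumes every step sees full 128-token blocks.

```python
def linear_warmup_lr(step, base_lr=2e-2, warmup_steps=10, max_steps=100):
    """Hypothetical helper: learning rate at a given optimizer step under linear warmup + linear decay."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)                          # warmup: 0 -> base_lr
    return base_lr * max(0.0, (max_steps - step) / max(1, max_steps - warmup_steps))  # decay: base_lr -> 0


per_device_batch = 5
grad_accum = 2
effective_batch = per_device_batch * grad_accum   # sequences per optimizer step on a single GPU
block_size = 128
tokens_seen = effective_batch * block_size * 100  # over max_steps=100 optimizer steps

for step in (0, 5, 10, 55, 100):
    print(step, linear_warmup_lr(step))
print(effective_batch, tokens_seen)
```

Since `max_steps=100` overrides `num_train_epochs=1`, the Trainer keeps cycling through the dataset until 100 optimizer steps are done; with only a handful of 128-token blocks in the file, that amounts to many passes over the same examples at a fairly high peak learning rate.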



<font size="5">**Inference**</font>

In [212]:
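# To reload from disk instead of reusing the in-memory model (output_dir holds only the LoRA adapter):
# from peft import PeftModel
# model = PeftModel.from_pretrained(GPT2LMHeadModel.from_pretrained("gpt2"), output_dir)
tokenizer = GPT2Tokenizer.from_pretrained(output_dir)
model.eval()  # disable dropout during generation


def generate_text(sequence, max_length):
    """Sample a continuation of `sequence`; max_length counts prompt tokens plus generated tokens."""
    # input_ids and attention_mask, moved to the same device as the model
    inputs = tokenizer(sequence, return_tensors='pt').to(model.device)

    final_outputs = model.generate(
        **inputs,
        do_sample=True,                        # sample instead of greedy decoding
        max_length=max_length,
        pad_token_id=tokenizer.eos_token_id,   # GPT-2 has no pad token, so reuse EOS
        top_k=50,                              # restrict sampling to the 50 most likely tokens...
        top_p=0.95,                            # ...then to the smallest set covering 95% of the probability
        #no_repeat_ngram_size=3,
        repetition_penalty=2.0,                # down-weight tokens that already appear in the sequence
        )
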
    print(tokenizer.decode(final_outputs[0], skip_special_tokens=True))
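
The three sampling knobs in `generate_text` act on the next-token distribution at every step. The simplified, single-step sketch below shows the idea in plain Python: a CTRL-style repetition penalty (positive logits divided, negative logits multiplied, for tokens already in the sequence, prompt included), a softmax, top-k truncation, then nucleus (top-p) truncation with renormalisation. It is an illustration under these assumptions, not the exact 🤗 Transformers code path, and all function names are hypothetical.

```python
import math


def apply_repetition_penalty(logits, previous_ids, penalty):
    """Hypothetical helper: CTRL-style penalty for tokens already in the sequence."""
    adjusted = list(logits)
    for tok in set(previous_ids):
        # Dividing a positive logit or multiplying a negative one both make the token less likely
        adjusted[tok] = adjusted[tok] / penalty if adjusted[tok] > 0 else adjusted[tok] * penalty
    return adjusted


def softmax(logits):
    m = max(logits)                                   # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_top_p_filter(probs, top_k, top_p):
    """Hypothetical helper: keep the top_k tokens, then the smallest prefix of them
    whose renormalised mass reaches top_p; return a renormalised distribution."""
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    k_mass = sum(probs[i] for i in ranked)
    nucleus, cumulative = set(), 0.0
    for i in ranked:
        nucleus.add(i)
        cumulative += probs[i] / k_mass
        if cumulative >= top_p:
            break
    z = sum(probs[i] for i in nucleus)
    return [probs[i] / z if i in nucleus else 0.0 for i in range(len(probs))]


# Toy 5-token vocabulary; token 0 has already been generated
logits = [2.0, 1.5, 0.5, -1.0, -2.0]
penalised = apply_repetition_penalty(logits, previous_ids=[0], penalty=2.0)
probs = softmax(penalised)
filtered = top_k_top_p_filter(probs, top_k=3, top_p=0.95)
print([round(p, 3) for p in filtered])
```

In this toy example, `penalty=2.0` halves token 0's logit from 2.0 to 1.0, so it drops below token 1, and `top_k=3` removes the two least likely tokens before sampling.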

In [191]:
prompt1 = "Question: What are the harmful effects of internet?"
generate_text(prompt1, 50)

Question: What are the harmful effects of internet?
Answer: There is widespread adoption of harmful internet surfing harmful to consumers by consumers and the environment, with the potential for negative effects on productivity, advertising, and consumers Rodrigues have been forced to adopt


In [218]:
prompt2 = "Question: What causes a global warming?"
generate_text(prompt2, 50)

Question: What causes a global warming?
Answer: TopCause: Extreme temperatures increase carbon emissions in electricity by 30 to 40 percent, depending on the amount of carbon released in the air, and on the amount of sunlight emitted from cars and trucks?
