# Google – AI Assistants for Data Tasks with Gemma Using [Representation Fine-Tuning (ReFT)](https://github.com/stanfordnlp/pyreft/tree/main) and [PyTorch](https://pytorch.org/)

<div align="center">
    <img src="https://i.ibb.co/8xZNc32/Gemma.png">
</div>


In this competition, the goal is to create a notebook that demonstrates how to use the Gemma LLM to accomplish one of the following data science oriented tasks:

- Explain or teach basic data science concepts.
- Answer common questions about the Python programming language.
- Summarize Kaggle solution write-ups.
- Explain or teach concepts from Kaggle competition solution write-ups.
- Answer common questions about the Kaggle platform.

Specifically, in this notebook we fine-tune the [Gemma 2B-it](https://huggingface.co/google/gemma-2b-it/tree/main) model with ReFT so that it answers common questions about the Python programming language.

**TL;DR on ReFT**: ReFT methods operate on a frozen base model and learn task-specific interventions on hidden representations. In this notebook we use a strong instance of the ReFT family, Low-rank Linear Subspace ReFT (LoReFT). LoReFT is a drop-in replacement for existing PEFTs and learns interventions that are 10x-50x more parameter-efficient than prior state-of-the-art PEFTs. Further details are in the [paper](https://arxiv.org/abs/2404.03592).

# Install and Import Libraries 

In [1]:
!pip install git+https://github.com/stanfordnlp/pyreft.git

import os
import ast
import torch
import numpy as np
import torchvision
import pandas as pd
import transformers
import torch.nn as nn
from tqdm import tqdm
import torch.optim as optim
from torch.nn import Softplus
import torch.nn.functional as F
from IPython.display import display, Markdown
from sklearn.model_selection import train_test_split
from transformers import AutoModel, AutoModelForSeq2SeqLM, AutoModelForCausalLM
from pyreft import (get_reft_model, ReftConfig, ConsreftIntervention, TaskType, ReftTrainerForCausalLM, LoreftIntervention, ReftDataCollator, ReftSupervisedDataset)

Collecting git+https://github.com/stanfordnlp/pyreft.git
  Cloning https://github.com/stanfordnlp/pyreft.git to /tmp/pip-req-build-38oo0lzp
  Running command git clone --filter=blob:none --quiet https://github.com/stanfordnlp/pyreft.git /tmp/pip-req-build-38oo0lzp
  Resolved https://github.com/stanfordnlp/pyreft.git to commit 528d504b7a0e9d6e6931bde2a6e136f5ee90008d
  Preparing metadata (setup.py) ... done
Collecting flash-attn>=2.5.6 (from pyreft==0.0.4)
  Downloading flash_attn-2.5.7.tar.gz (2.5 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.5/2.5 MB 33.9 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Collecting pyvene>=0.1.1 (from pyreft==0.0.4)
  Downloading pyvene-0.1.1-py3-none-any.whl.metadata (17 kB)
Collecting ipywidgets>=8.1.1 (from pyreft==0.0.4)
  Downloading ipywidgets-8.1.2-py3-none-any.whl.metadata (2.4 kB)
Collecting accelerate>=0.29.1 (from pyreft==0.0.4)
  Downloa



# Tokenizer

In [2]:
MODEL_PATH = "google/gemma-2b-it"
# Read the Hugging Face access token from the environment (e.g. HF_TOKEN) instead of hard-coding it
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_PATH, token=os.environ.get("HF_TOKEN"))
stop = tokenizer.eos_token
tokenizer.add_special_tokens({"pad_token": "<pad>"})

0

# Configurations and Definitions

In [3]:
batch_size = 1
seed_value = 42
MAX_LENGTH = 512
device = torch.device("cuda")
# Keep scaled dot-product attention off the flash and memory-efficient kernels
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)

def checkpoint(model, filename):
    # Save the state dict of the wrapped model
    torch.save(model.state_dict(), filename)

def resume(model, filename):
    # Restore a state dict saved by checkpoint()
    model.load_state_dict(torch.load(filename))

def prepare_input(text, tokenizer=tokenizer):
    # Tokenize to a tensor of shape (1, MAX_LENGTH), padded/truncated to MAX_LENGTH
    inputs = tokenizer(text, max_length=MAX_LENGTH, return_attention_mask=False,
                       return_tensors="pt", padding="max_length", truncation=True)
    return inputs["input_ids"]

def prepare_labels(text, tokenizer=tokenizer):
    # Labels use the same tokenization as the inputs. Pad positions keep the pad token id
    # (not -100), so they are included in the causal-LM loss.
    return prepare_input(text, tokenizer)

class AllDataset(torch.utils.data.Dataset):
    """Map-style dataset: the question is the input, the answer is the label sequence."""
    def __init__(self, train_x, train_y):
        self.inputs = train_x.tolist()
        self.labels = train_y.tolist()

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, item):
        # Tokenize lazily, one question/answer pair at a time
        inputs_ = prepare_input(str(self.inputs[item]))
        label = prepare_labels(str(self.labels[item]))
        return inputs_, label
    

# Reproducibility
Set the random seeds so that each run produces similar results.
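The cell below seeds NumPy and PyTorch. Python's built-in `random` module is not seeded there. Below is a minimal sketch of a helper that also covers it. The name `seed_everything` is hypothetical, and only the standard library and NumPy are assumed.

```python
import random

import numpy as np


def seed_everything(seed: int) -> None:
    """Seed Python's `random` and NumPy's global RNG (hypothetical helper)."""
    random.seed(seed)
    np.random.seed(seed)
    # In the notebook, torch.manual_seed(seed) and
    # torch.cuda.manual_seed_all(seed) would be called here as well.


# Two runs with the same seed draw identical numbers.
seed_everything(42)
first = (random.random(), np.random.rand(3).tolist())
seed_everything(42)
second = (random.random(), np.random.rand(3).tolist())
print(first == second)
```

Seeding alone does not make GPU runs bit-identical, because some CUDA kernels are nondeterministic. PyTorch's `torch.use_deterministic_algorithms(True)` requests deterministic implementations where they exist.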

In [4]:
np.random.seed(seed_value)
torch.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)

# Data

No training data is provided in this competition, so any openly available dataset can be used. This notebook uses two external datasets:

- [Dataset_Python_Question_Answer](https://www.kaggle.com/datasets/chinmayadatt/dataset-python-question-answer): Questions and answers generated using Gemma. It contains more than four hundred questions about Python programming with their corresponding answers, ranging from data types, variables, and keywords to regular expressions and threading.
- [Python FAQ ChatGPT+Gemini](https://www.kaggle.com/datasets/williamalabi/python-faq-chatgpt-gemini): Questions and answers from the FAQ section of the [Python website](https://docs.python.org/3/faq/), together with question-answer pairs generated by [ChatGPT](http://chat.openai.com/) and [Gemini](http://gemini.google.com/).

**Data Format:**

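Both datasets include:
- `Questions`: Input questions (named `Question` in the first dataset and renamed below).
- `Answers`: Output answers (named `Answer` in the first dataset). This is the **target** the model is fine-tuned to generate.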
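In the first dataset, each answer is stored as the string form of a Python list of text fragments. The cell below therefore parses each answer and joins it before concatenating the two datasets. Here is a minimal sketch of that normalization on a single made-up row; the raw values are hypothetical, not taken from the CSV.

```python
import ast

# Hypothetical raw row: the answer is the string form of a Python list.
raw_row = {
    "Question": "What is a variable?",
    "Answer": "['Sure, here is the difference:', '**Variable:** a named reference.']",
}

# ast.literal_eval parses Python literals only (no arbitrary code execution);
# the fragments are then joined with single spaces, as in the cell below.
fragments = ast.literal_eval(raw_row["Answer"])
answer = " ".join(fragments)

# Rename the keys to the column names used by the second dataset.
column_map = {"Question": "Questions", "Answer": "Answers"}
normalized = {column_map[key]: value for key, value in raw_row.items()}
normalized["Answers"] = answer
print(normalized)
```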

In [5]:
df1 = pd.read_csv("/kaggle/input/dataset-python-question-answer/Dataset_Python_Question_Answer.csv")
df2 = pd.read_csv("/kaggle/input/python-faq-chatgpt-gemini/Python FAQ Dataset.csv", encoding="Latin-1")
# Answers in dataset-python-question-answer are stored as string representations of Python lists
# (square brackets); parse each one safely and join its fragments into a single string
df1["Answer"] = [" ".join(ast.literal_eval(ans)) for ans in df1["Answer"]]
print(df1["Answer"][0])
# Rename the columns to match the second dataset before pd.concat
df1 = df1.rename(columns={"Question": "Questions", "Answer": "Answers"})
df = pd.concat([df1, df2], axis=0, ignore_index=True)
print(df)
print("done")
# Split the data into train (70%) and test (30%) sets
X_train, X_test, Y_train, Y_test = train_test_split(df["Questions"].values, df["Answers"].values, test_size=0.3, random_state=seed_value)
# Create the train and test datasets (tokenization happens in AllDataset.__getitem__)
train_dataset = AllDataset(X_train, Y_train)
test_dataset = AllDataset(X_test, Y_test)
# Create the dataloaders for train and test
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

Sure, here's the difference between a variable and an object: **Variable:** * A variable is a named memory location that stores a single value. * It is a placeholder for a specific amount of data. * Variables can hold different values throughout the program. * They are declared using the `=` operator and assigned a value. **Object:** * An object is a complex data structure that contains data and methods. * It is an instance of a class. * Objects can have multiple variables and methods associated with them. * They are created using the `new` keyword and contain a copy of all the variables and methods of the class. In summary, a variable is a single piece of memory that stores a single value, while an object is a complex data structure that contains data and methods.
                                              Questions  \
0      What is the difference between a variable and...   
1      What is the difference between a built-in fun...   
2      What is the difference between the `prin

Let's examine a sample prompt. Because the answers in the dataset are written in **Markdown**, we render the sample with `Markdown()` so the formatting displays properly.

## Sample

In [6]:
def colorize_text(text):
    for word, color in zip(["Question", "Answer"],
                           ["red", "green"]):
        text = text.replace(f"{word}:", f"\n\n**<font color='{color}'>{word}:</font>**")
    return text

In [7]:
# Take a fixed sample (row 10) from the combined dataset
sample = df.iloc[10]
sample = "Question: " + str(sample['Questions']) + "Answer: " + str(sample['Answers'])
# Color the Question and Answer labels
sample = colorize_text(sample)

# Show sample in markdown
display(Markdown(str(sample)))



**<font color='red'>Question:</font>**  Explain the difference between a function and a method.

**<font color='green'>Answer:</font>** Sure, here is the difference between a function and a method. **Function:** * A function is a block of code that contains a set of instructions that perform a specific task. * A function is independent. It can be called independently from other parts of the code. * A function can return a value, but it can also perform multiple tasks without modifying its surrounding environment. * A function can be called by other parts of the code. **Method:** * A method is a block of code that is contained within a class. * A method is associated with a specific class and can only be called from within that class. * A method can be called by the class itself or by other objects that inherit from the class. * A method usually has one or more parameters, which provide data to the function. * A method typically performs a specific task, but it can also return a value or perform multiple tasks without modifying the surrounding environment. In simple terms: * **Function:** An operation or task performed on something. * **Method:** A set of instructions within a class performing a specific purpose.

# Modeling

<div align="center"><img src="https://i.ibb.co/Bqg9w3g/Gemma-Logo-no-background.png" width="300"></div>

**Gemma** is a collection of advanced open LLMs developed by **Google DeepMind** and other **Google teams**, derived from the same research and technology behind the **Gemini** models. They can be integrated into applications and run on various platforms including mobile devices and hosted services. Developers can customize Gemma models using tuning techniques to enhance their performance for specific tasks, offering more targeted and efficient generative AI solutions beyond text generation.

Gemma models are available in several sizes, so you can build generative AI solutions that fit your available computing resources, the capabilities you need, and where you want to run them.

| Parameter size  | Tuned version     | Intended platforms                 | KerasNLP preset        |
|-----------------|-------------------|------------------------------------|------------------------|
| 2B              | Pretrained        | Mobile devices and laptops         | `gemma_2b_en`          |
| 2B              | Instruction tuned | Mobile devices and laptops         | `gemma_instruct_2b_en` |
| 7B              | Pretrained        | Desktop computers and small servers| `gemma_7b_en`          |
| 7B              | Instruction tuned | Desktop computers and small servers| `gemma_instruct_7b_en` |
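A rough way to match a model size to a platform is to estimate the memory its weights need: parameters × bytes per parameter. Below is a back-of-the-envelope sketch for the 2B model. It uses the base-model parameter count that `print_trainable_parameters()` reports later in this notebook and the bfloat16 dtype used when loading the model. The estimate covers the weights only.

```python
# Base-model parameter count reported by print_trainable_parameters() ("model params").
base_params = 2_506_172_416
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2}


def weight_memory_bytes(n_params: int, dtype: str) -> int:
    """Bytes needed to hold the weights alone (no activations, KV cache, or optimizer state)."""
    return n_params * BYTES_PER_PARAM[dtype]


for dtype in BYTES_PER_PARAM:
    total = weight_memory_bytes(base_params, dtype)
    print(f"{dtype}: {total:,} bytes = {total / 1e9:.2f} GB = {total / 2**30:.2f} GiB")
```

In bfloat16 this comes to roughly 5 GB. That is consistent with the two safetensors shards (4.95 GB and 67.1 MB) downloaded when the model is loaded.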

In this notebook, we use the instruction-tuned `gemma-2b-it` checkpoint from Hugging Face to answer common questions about the Python programming language. We fine-tune it on question-answer pairs with Low-rank Linear Subspace Representation Fine-Tuning (LoReFT), aiming to improve its answers in this domain.

## Gemma Causal LM

The code below loads an end-to-end Gemma model for causal language modeling through `AutoModelForCausalLM`, which resolves to `GemmaForCausalLM` (see the printed model below). A causal language model (LM) predicts the next token from the previous tokens. This setup can be used to train the model self-supervised on plain text, or to autoregressively generate text similar to the training data. It therefore serves both for pre-training and for fine-tuning Gemma to answer common questions about the Python programming language.
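For a causal LM, the label at each position is the next token. A common recipe is therefore to feed a question-answer pair as one sequence and to restrict the loss to the answer tokens. Hugging Face causal-LM models shift the labels internally and ignore positions labelled `-100`. Below is a minimal sketch of that recipe with made-up token IDs; the helpers and all IDs are hypothetical. It differs from the notebook's `AllDataset`, which tokenizes questions and answers into separate, independently padded sequences.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch's CrossEntropyLoss by default


def build_example(prompt_ids, answer_ids, pad_id, max_length):
    """Concatenate prompt and answer; supervise only the answer tokens (hypothetical helper)."""
    input_ids = (prompt_ids + answer_ids)[:max_length]
    labels = ([IGNORE_INDEX] * len(prompt_ids) + answer_ids)[:max_length]
    attention_mask = [1] * len(input_ids)
    # Pad to a fixed length; padded positions get no loss and no attention.
    n_pad = max_length - len(input_ids)
    input_ids += [pad_id] * n_pad
    labels += [IGNORE_INDEX] * n_pad
    attention_mask += [0] * n_pad
    return input_ids, labels, attention_mask


def next_token_pairs(input_ids, labels):
    """(context, target) pairs that contribute to the loss after the one-position shift."""
    return [
        (input_ids[: t + 1], labels[t + 1])
        for t in range(len(input_ids) - 1)
        if labels[t + 1] != IGNORE_INDEX
    ]


# Hypothetical IDs: 2 = BOS, 10-12 = question, 20-21 = answer, 1 = EOS, 0 = pad.
ids, labels, mask = build_example([2, 10, 11, 12], [20, 21, 1], pad_id=0, max_length=10)
print(ids)     # [2, 10, 11, 12, 20, 21, 1, 0, 0, 0]
print(labels)  # [-100, -100, -100, -100, 20, 21, 1, -100, -100, -100]
print(next_token_pairs(ids, labels))
```

Only three predictions are supervised here: the two answer tokens and the end-of-sequence token, each conditioned on the full question and any earlier answer tokens.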

In [8]:
# Load the weights in bfloat16; the Hugging Face access token is read from the environment
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, token=os.environ.get("HF_TOKEN"), torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
model.to(device)

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.



GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm()
        (post_attention_layernorm): GemmaRMSNorm()
      )
    )
    (norm): GemmaR

# Fine-Tuning with LoReFT

We now fine-tune the base model with the Low-rank Linear Subspace ReFT (LoReFT) technique.

**Unveiling LoReFT**

LoReFT adapts a language model to a specific task by learning interventions that edit its hidden representations rather than its weights. It builds on the linear representation hypothesis, the idea that concepts are encoded in linear subspaces of a model's representations. It also builds on distributed interchange interventions (DIIs), which edit a representation only within such a subspace.
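For context, a distributed interchange intervention takes two inputs: a base representation $\mathbf{h}$ and a source representation $\mathbf{s}$ computed on a different input. It overwrites the coordinates of $\mathbf{h}$ inside the subspace spanned by the rows of $\mathbf{R}$ with those of $\mathbf{s}$:

$$\Phi_{\text{DII}}(\mathbf{h}, \mathbf{s}) = \mathbf{h} + \mathbf{R}^{\top}(\mathbf{R}\mathbf{s} - \mathbf{R}\mathbf{h})$$

When $\mathbf{R}$ has orthonormal rows ($\mathbf{R}\mathbf{R}^{\top} = \mathbf{I}$), projecting the result gives $\mathbf{R}\,\Phi_{\text{DII}}(\mathbf{h}, \mathbf{s}) = \mathbf{R}\mathbf{h} + (\mathbf{R}\mathbf{s} - \mathbf{R}\mathbf{h}) = \mathbf{R}\mathbf{s}$. Components orthogonal to the subspace are unchanged. LoReFT keeps this form but does not read $\mathbf{R}\mathbf{s}$ from a second input. Instead, it predicts the target subspace coordinates from $\mathbf{h}$ itself with a learned linear map.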

Central to LoReFT's methodology is the equation:
$$\Phi_{\text{LoReFT}}(\mathbf{h}) = \mathbf{h} + \mathbf{R}^{\top}(\mathbf{W}\mathbf{h} + \mathbf{b} - \mathbf{R}\mathbf{h})$$
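Here $\mathbf{h} \in \mathbb{R}^{d}$ is a hidden representation, and $\mathbf{R} \in \mathbb{R}^{r \times d}$ spans a low-rank ($r \ll d$) subspace. $\mathbf{W} \in \mathbb{R}^{r \times d}$ and $\mathbf{b} \in \mathbb{R}^{r}$ define a learned linear projection. Within the subspace, the intervention replaces the projection $\mathbf{R}\mathbf{h}$ with the task-specific target $\mathbf{W}\mathbf{h} + \mathbf{b}$ and leaves the rest of $\mathbf{h}$ unchanged. This steers the model toward correct task predictions while keeping the original model parameters frozen, making the adaptation both parameter-efficient and interpretable.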
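The algebra can be checked numerically. Below is a minimal NumPy sketch with toy sizes and random, untrained $\mathbf{R}$, $\mathbf{W}$, $\mathbf{b}$. These are assumptions for illustration, not pyreft's `LoreftIntervention` implementation. Here $\mathbf{R}$ gets orthonormal rows from a QR decomposition.

```python
import numpy as np

rng = np.random.default_rng(0)
d, r = 8, 2  # toy hidden size and rank (the notebook uses d = 2048 and r = 1)

# R: (r, d) with orthonormal rows, taken from the reduced QR of a random (d, r) matrix.
q, _ = np.linalg.qr(rng.standard_normal((d, r)))
R = q.T                            # R @ R.T equals the r x r identity
W = rng.standard_normal((r, d))    # learned projection (random here)
b = rng.standard_normal(r)         # learned bias (random here)


def loreft(h):
    """Phi(h) = h + R^T (W h + b - R h)."""
    return h + R.T @ (W @ h + b - R @ h)


h = rng.standard_normal(d)
h_new = loreft(h)

# Inside the subspace, the edited vector's coordinates equal the learned target W h + b.
inside_ok = np.allclose(R @ h_new, W @ h + b)
# In the orthogonal complement of R's row space, h is left untouched.
P_perp = np.eye(d) - R.T @ R
outside_ok = np.allclose(P_perp @ h_new, P_perp @ h)
print(inside_ok, outside_ok)  # True True
```

The first check confirms that the edit sets the subspace coordinates to $\mathbf{W}\mathbf{h} + \mathbf{b}$. The second confirms that everything outside the subspace passes through unchanged.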

**LoReFT's Operational Elegance**

During training, LoReFT learns interventions that push the model toward the correct task labels. It edits the representation only inside the linear subspace defined by $\mathbf{R}$, moving it toward the task-specific target $\mathbf{W}\mathbf{h} + \mathbf{b}$.

The learnable parameters of LoReFT are $\mathbf{R}$, $\mathbf{W}$, and $\mathbf{b}$. $\mathbf{R}$ is constrained to be a low-rank matrix with orthonormal rows, which confines the edit to a compact subspace. The language model's own parameters stay frozen throughout.

**Choosing LoReFT Over Weight-Based PEFTs**

Compared with weight-based parameter-efficient fine-tuning methods (PEFTs) such as LoRA, LoReFT trains far fewer parameters, which saves memory. It also leaves the model weights untouched: the learned interventions form a separate, compact component that can be inspected on its own. Finally, the intervention site (which layer and which token positions to edit) is a design choice that can be tuned for each task.

Prepare the base model for fine-tuning with a ReFT method by wrapping it, together with a ReFT configuration, using `get_reft_model`. In this notebook we use `LoreftIntervention` with rank 1, i.e. the original LoReFT described in the [paper](https://arxiv.org/abs/2404.03592). It is applied to the MLP output of decoder layer 15 (0-indexed). pyreft also provides a simpler [ConsreftIntervention](https://github.com/stanfordnlp/pyreft/blob/main/pyreft/interventions.py#L85) (Constant LoReFT Intervention).

In [9]:
# wrap the model with a rank-1 LoReFT intervention on the MLP output of decoder layer 15
reft_config = ReftConfig(representations={
    "component": "model.layers[15].mlp.output",  # string access to the model component
    "intervention": LoreftIntervention(
        embed_dim=model.config.hidden_size, low_rank_dimension=1)})

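# wrap the base model with get_reft_model; the base model's own weights are frozen
model = get_reft_model(model, reft_config)
model.print_trainable_parameters()
# Adam only updates parameters that receive gradients, i.e. the intervention parameters
optimizer = optim.Adam(model.parameters(), lr=2e-3)
n_epochs = 5

# Early stopping: stop once the validation loss has not improved for more than 2 epochs
early_stop_thresh = 2
best_loss = np.inf
best_epoch = -1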

trainable intervention params: 4,097 || trainable model params: 0
model params: 2,506,172,416 || trainable%: 0.00016347638230489566


**Notice** that the number of trainable parameters drops from ~2.5 billion to ~4.1 thousand (4,097) after wrapping with LoReFT.
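The 4,097 figure follows from the LoReFT parameterization. $\mathbf{R}$ and $\mathbf{W}$ each contribute rank × hidden_size entries, and $\mathbf{b}$ contributes rank entries. Here is a quick check, assuming the single rank-1 intervention configured above and Gemma 2B's hidden size of 2048 (the embedding width in the printed model). The helper name is hypothetical.

```python
def loreft_param_count(embed_dim: int, rank: int) -> int:
    """Trainable parameters of one LoReFT intervention: R (rank x d) + W (rank x d) + b (rank)."""
    return rank * embed_dim + rank * embed_dim + rank


hidden_size = 2048           # Gemma 2B hidden size
base_params = 2_506_172_416  # "model params" reported by print_trainable_parameters()

n = loreft_param_count(hidden_size, rank=1)
pct = 100 * n / base_params
print(n)    # 4097
print(pct)  # trainable %, about 0.000163
```

Both numbers match the `print_trainable_parameters()` output above. Doubling the rank would roughly double the count, which stays negligible next to the frozen base model.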

# Inference before Fine-Tuning

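Before fine-tuning with LoReFT, let's ask the model a couple of questions obtained from Copilot and see how it responds. The model is already wrapped at this point, so these generations pass through the untrained, randomly initialized intervention at the last prompt position.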

> Because the model has not yet been fine-tuned with LoReFT, its responses are reasonable but not exceptional. The first sample's answer also appears to be cut off by the `max_new_tokens=512` limit.

## Sample 1

In [10]:
prompt = tokenizer("Considering Python’s extensive use in scientific computing and data analysis, how might its design principles affect the way scientists approach problem-solving in their research?", return_tensors="pt").to("cuda")
base_unit_location = prompt["input_ids"].shape[-1] - 1  # last position
_, reft_response = model.generate(
    prompt, unit_locations={"sources->base": (None, [[[base_unit_location]]])},
    intervene_on_prompt=True, max_new_tokens=512, do_sample=False, 
    eos_token_id=tokenizer.eos_token_id, early_stopping=True
)
gen_text = tokenizer.decode(reft_response[0], skip_special_tokens=True)
parts = gen_text.split("\n", 1)

# First part is the question, second part is the answer
question = parts[0].strip()
answer = parts[1].strip() if len(parts) > 1 else ""

# Combine the parsed question and answer for display
sample = "Question: " + str(question) + "Answer: " + str(answer)
# Color the Question and Answer labels
sample = colorize_text(sample)

# Show sample in markdown
display(Markdown(str(sample)))





**<font color='red'>Question:</font>** Considering Python’s extensive use in scientific computing and data analysis, how might its design principles affect the way scientists approach problem-solving in their research?

**<font color='green'>Answer:</font>** Sure, here's how the design principles of Python can affect the way scientists approach problem-solving in their research:

**1. Flexibility and Extensibility:**

- Python's dynamic typing and support for modules and packages allow scientists to choose the tools and libraries they need for their specific research tasks, promoting flexibility and customization.
- This flexibility can help scientists explore different approaches and find efficient solutions that might not be readily available with other languages.

**2. Conciseness and Readability:**

- Python's clear and concise syntax makes it easier for scientists to express complex scientific ideas and algorithms, improving code readability and maintainability.
- This can lead to faster debugging and collaboration among researchers.

**3. Data Structures and Algorithms:**

- Python provides built-in data structures like lists, dictionaries, and tuples, which are efficient for various data manipulation and analysis tasks.
- Additionally, the extensive collection of scientific libraries, such as NumPy, SciPy, and Pandas, offers optimized algorithms for numerical computing, data analysis, and machine learning.

**4. Scientific Computing and Data Analysis:**

- Python's extensive use in scientific computing and data analysis has led to the development of specific libraries and modules, such as NumPy, SciPy, and Pandas, specifically designed for scientific data handling and analysis.
- These libraries provide optimized functions for numerical computations, data manipulation, statistical analysis, and visualization, streamlining the research process.

**5. Collaboration and Sharing:**

- Python's open-source nature and extensive libraries facilitate collaboration among researchers, allowing them to share code, libraries, and data easily.
- This promotes transparency and accelerates scientific progress by enabling researchers to build upon each other's work.

**6. Scalability and Performance:**

- Python's ability to handle large datasets efficiently makes it suitable for tackling complex research problems that require extensive data analysis.
- Its multithreading capabilities allow scientists to perform multiple tasks simultaneously, improving the overall research pace.

**7. Integration with Other Tools:**

- Python integrates seamlessly with other tools and technologies, such as MATLAB, R, and SQL, enabling scientists to leverage existing expertise and workflows.
- This facilitates data exchange, analysis, and visualization, streamlining the research process.

Overall, the design principles of Python have significantly influenced how scientists approach problem-solving in their research. Its flexibility, conciseness, data structures, algorithms, scientific computing libraries, collaboration features, scalability, and integration capabilities empower researchers to

## Sample 2

In [11]:
prompt = tokenizer("In what ways has Python’s ‘Zen’ philosophy shaped modern software development practices beyond language design?", return_tensors="pt").to("cuda")
base_unit_location = prompt["input_ids"].shape[-1] - 1  # last position
_, reft_response = model.generate(
    prompt, unit_locations={"sources->base": (None, [[[base_unit_location]]])},
    intervene_on_prompt=True, max_new_tokens=512, do_sample=False, 
    eos_token_id=tokenizer.eos_token_id, early_stopping=True
)
gen_text = tokenizer.decode(reft_response[0], skip_special_tokens=True)
parts = gen_text.split("\n", 1)

# First part is the question, second part is the answer
question = parts[0].strip()
answer = parts[1].strip() if len(parts) > 1 else ""

# Combine the parsed question and answer for display
sample = "Question: " + str(question) + "Answer: " + str(answer)
# Color the Question and Answer labels
sample = colorize_text(sample)

# Show sample in markdown
display(Markdown(str(sample)))



**<font color='red'>Question:</font>** In what ways has Python’s ‘Zen’ philosophy shaped modern software development practices beyond language design?

**<font color='green'>Answer:</font>** Zen philosophy emphasizes simplicity, clarity, and self-reflection. These principles have influenced the following ways in modern software development practices:

**1. Design patterns:** Zen encourages designers to focus on the core problem and avoid unnecessary complexity. This has led to the widespread adoption of design patterns in software development, such as the Singleton design pattern and the Factory design pattern.

**2. Testing:** Zen emphasizes the importance of clear and concise test cases. This has influenced the development of unit testing frameworks and the use of test-driven development (TDD) principles.

**3. Documentation:** Zen encourages clear and concise documentation that is easy to understand. This has influenced the use of natural language processing (NLP) tools for documentation generation and the adoption of documentation tools like Swagger and OpenAPI.

**4. Code quality:** Zen encourages writing clean and efficient code that is free of errors. This has influenced the use of static code analysis tools and the adoption of coding best practices.

**5. Agile development:** Zen encourages iterative and incremental development, which has influenced the agile software development (ASD) methodology.

**6. Design thinking:** Zen encourages empathy and understanding of the user. This has influenced the adoption of design thinking principles in software development, such as user research and prototyping.

**7. Continuous integration and continuous delivery (CI/CD):** Zen encourages automation and continuous flow of software development. This has influenced the use of CI/CD tools and the adoption of continuous integration and continuous delivery (CI/CD) practices.

Overall, Python's Zen philosophy has had a significant impact on modern software development practices by promoting clear, concise, and efficient code, leading to more maintainable and scalable software.

# Training

The training loop updates the LoReFT intervention parameters from the loss in each epoch, while the base model stays frozen. Each epoch runs a training pass followed by a validation pass. Early stopping with checkpointing keeps the model with the best validation loss to limit overfitting.
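The stopping rule can be replayed offline on the logged validation losses. Below is a pure-Python sketch of the same rule used in the training cell; the function name is hypothetical, and `patience` plays the role of `early_stop_thresh`.

```python
import math


def early_stopping_trace(val_losses, patience):
    """Replay the rule: stop once epoch - best_epoch > patience."""
    best_loss, best_epoch = math.inf, -1
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch  # new best: a checkpoint would be saved here
        elif epoch - best_epoch > patience:
            return epoch, best_epoch, best_loss  # stopped early
    return None, best_epoch, best_loss           # ran all epochs


# Validation losses printed by the training cell (epochs 0-4).
logged = [7.0718, 6.8118, 6.9178, 6.9608, 7.0844]
print(early_stopping_trace(logged, patience=2))  # (4, 1, 6.8118)
```

With `n_epochs = 5` and `early_stop_thresh = 2`, the rule can only fire at the final epoch, so in this run it does not shorten training. The checkpoint reloaded afterwards is the one from epoch 1.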

In [12]:
for epoch in tqdm(range(n_epochs), leave=True):
    # Set the LoReFT-wrapped model to training mode
    model.train()
    # Initialize total training loss for the epoch
    total_loss = 0.0

    # Iterate over training batches
    for inputs, targets in train_loader:
        # (batch, 1, MAX_LENGTH) -> (batch, MAX_LENGTH), then move the batch to the GPU
        inputs = inputs.squeeze(1).to(device)
        targets = targets.squeeze(1).to(device)
        # Forward pass: the second element of pyreft's output holds the intervened model's outputs
        y_pred = model({"input_ids": inputs}, labels=targets)
        loss = y_pred[1].loss
        # Clear gradients before backpropagation
        optimizer.zero_grad()
        # Backward pass: propagate gradients to the intervention parameters
        loss.backward()
        # Update the intervention parameters (the base model is frozen)
        optimizer.step()
        # Accumulate loss for the epoch
        total_loss += loss.item()

    # Calculate average training loss for the epoch
    average_train_loss = total_loss / len(train_loader)
    print(average_train_loss)

    # Validation loop
    # Set the model to evaluation mode
    model.eval()
    # Clear GPU memory cache
    torch.cuda.empty_cache()

    # Iterate over validation batches without building a computation graph
    total_val_loss = 0.0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.squeeze(1).to(device)
            targets = targets.squeeze(1).to(device)
            y_pred = model({"input_ids": inputs}, labels=targets)
            total_val_loss += y_pred[1].loss.item()
    average_val_loss = total_val_loss / len(test_loader)
    print("Epoch %d: model loss %.4f" % (epoch, average_val_loss))
    print("Best Loss So Far: %.4f" % best_loss)
    # Early stopping and checkpointing logic
    if average_val_loss < best_loss:
        best_loss = average_val_loss
        best_epoch = epoch
        checkpoint(model, "best_model.pth")
    elif epoch - best_epoch > early_stop_thresh:
        print("Early stopped training at epoch %d, with best loss of %.4f" % (epoch, best_loss))
        break  # terminate the training loop

# Reload the weights from the best checkpoint (lowest validation loss)
resume(model, "best_model.pth")

  0%|          | 0/5 [00:00<?, ?it/s]
7.506237087913604
Epoch 0: model loss 7.0718
Best Loss So Far: inf
 20%|██        | 1/5 [18:41<1:14:45, 1121.35s/it]
7.115029996846042
Epoch 1: model loss 6.8118
Best Loss So Far: 7.0718
 40%|████      | 2/5 [37:40<56:35, 1131.72s/it]
7.144760753181471
Epoch 2: model loss 6.9178
Best Loss So Far: 6.8118
 60%|██████    | 3/5 [56:11<37:24, 1122.21s/it]
7.050453360092862
Epoch 3: model loss 6.9608
Best Loss So Far: 6.8118
 80%|████████  | 4/5 [1:14:41<18:37, 1117.70s/it]
7.090340520408644
Epoch 4: model loss 7.0844
Best Loss So Far: 6.8118
Early stopped training at epoch 4, with best loss of 6.8118
 80%|████████  | 4/5 [1:33:12<23:18, 1398.24s/it]





# Inference after Fine-Tuning

Let's see how our fine-tuned model responds to the same questions we asked before fine-tuning the model.

## Sample 1

In [13]:
prompt = tokenizer("Considering Python’s extensive use in scientific computing and data analysis, how might its design principles affect the way scientists approach problem-solving in their research?", return_tensors="pt").to("cuda")
base_unit_location = prompt["input_ids"].shape[-1] - 1  # last position
_, reft_response = model.generate(
    prompt, unit_locations={"sources->base": (None, [[[base_unit_location]]])},
    intervene_on_prompt=True, max_new_tokens=512, do_sample=False, 
    eos_token_id=tokenizer.eos_token_id, early_stopping=True
)
gen_text = tokenizer.decode(reft_response[0], skip_special_tokens=True)
parts = gen_text.split("\n", 1)

# First part is the question, second part is the answer
question = parts[0].strip()
answer = parts[1].strip() if len(parts) > 1 else ""

# Combine the parsed question and answer for display
sample = "Question: " + str(question) + "Answer: " + str(answer)
# Color the Question and Answer labels
sample = colorize_text(sample)

# Show sample in markdown
display(Markdown(str(sample)))



**<font color='red'>Question:</font>** Considering Python’s extensive use in scientific computing and data analysis, how might its design principles affect the way scientists approach problem-solving in their research?

**<font color='green'>Answer:</font>** **Design principles of Python that might affect problem-solving:**

* **Immutability:** Immutable data structures prevent changes to existing data, making it easier to reason about and debug code. This can be particularly helpful for complex scientific projects involving large datasets.
* **Conciseness:** Python's syntax is often more concise than other languages, making it easier to express complex ideas and write efficient code.
* **Flexibility:** Python's diverse ecosystem of libraries and modules allows users to choose the tools that best suit their specific needs, making it a versatile tool for scientific computing.
* **Concurrency:** Python's built-in support for concurrency allows multiple tasks to run simultaneously, making it efficient for handling large datasets and complex simulations.

**How these principles might affect problem-solving:**

* **Immutability:** Immutable data structures make it easier to reason about the data and ensure that changes are handled correctly. This can lead to more robust and reliable code, especially when working with large datasets.
* **Conciseness:** Python's syntax can make it easier to express complex ideas and write efficient code, which can save time and effort.
* **Flexibility:** The vast ecosystem of libraries and modules in Python allows users to choose the tools that best suit their specific needs, making it a versatile tool for scientific computing.
* **Concurrency:** Python's built-in support for concurrency allows multiple tasks to run simultaneously, making it more efficient for handling large datasets and complex simulations.

**Overall, the design principles of Python can significantly impact how scientists approach problem-solving in their research. By providing tools that facilitate reasoning, conciseness, flexibility, and concurrency, Python can empower scientists to tackle complex scientific challenges more effectively.**
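
In both samples the intervention is applied only at the last prompt token: `base_unit_location = prompt["input_ids"].shape[-1] - 1`, wrapped as `{"sources->base": (None, [[[base_unit_location]]])}`. Our reading of the pyreft/pyvene convention (an assumption worth checking against the pyvene documentation) is that `None` stands for the source locations, which are unused here, and that the base locations are nested as intervention, then batch example, then positions. The hypothetical helper `last_position_unit_locations` below builds that structure for a batch of prompts, assuming no left padding so that the last token of an `n`-token prompt sits at index `n - 1`.

```python
from typing import Dict, List, Optional, Tuple

# Hypothetical alias: source locations (unused, so None) and base locations
# nested as [intervention][batch example][positions].
UnitLocations = Dict[str, Tuple[Optional[list], List[List[List[int]]]]]


def last_position_unit_locations(prompt_lengths: List[int], num_interventions: int = 1) -> UnitLocations:
    """Hypothetical helper: target the last prompt token of every example.

    Assumes no left padding, so a prompt with n tokens ends at index n - 1.
    """
    if any(n < 1 for n in prompt_lengths):
        raise ValueError("every prompt must contain at least one token")
    # One position list per example: the index of its last prompt token.
    per_example = [[n - 1] for n in prompt_lengths]
    # Repeat the same positions for each intervention, copying the inner lists.
    base_locations = [[list(p) for p in per_example] for _ in range(num_interventions)]
    return {"sources->base": (None, base_locations)}


# Single prompt of 30 tokens, one intervention: same shape as in the cells above.
print(last_position_unit_locations([30]))
# {'sources->base': (None, [[[29]]])}
```

With `prompt_lengths=[prompt["input_ids"].shape[-1]]` and a single intervention, the helper returns the same dictionary that the cells pass to `model.generate`; `num_interventions` only matters if the ReFT config places interventions on more than one layer.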

## Sample 2

In [14]:
question_text = "In what ways has Python’s ‘Zen’ philosophy shaped modern software development practices beyond language design?"
prompt = tokenizer(question_text, return_tensors="pt").to("cuda")

# Apply the intervention at the last prompt token
base_unit_location = prompt["input_ids"].shape[-1] - 1
_, reft_response = model.generate(
    prompt, unit_locations={"sources->base": (None, [[[base_unit_location]]])},
    intervene_on_prompt=True, max_new_tokens=512, do_sample=False,
    eos_token_id=tokenizer.eos_token_id  # early_stopping only affects beam search, so it is omitted
)

# The decoded sequence contains the prompt followed by the generated answer
gen_text = tokenizer.decode(reft_response[0], skip_special_tokens=True)
parts = gen_text.split("\n", 1)

# First line is the question, the rest is the answer
question = parts[0].strip()
answer = parts[1].strip() if len(parts) > 1 else ""

# Build the display string and colorize the "Question:" and "Answer:" labels
sample = "Question: " + question + " Answer: " + answer
sample = colorize_text(sample)

# Show sample in markdown
display(Markdown(str(sample)))



**<font color='red'>Question:</font>** In what ways has Python’s ‘Zen’ philosophy shaped modern software development practices beyond language design?

**<font color='green'>Answer:</font>** Zen philosophy emphasizes simplicity, clarity, and self-reflection. These principles have influenced the following ways in modern software development practices:

**1. Design patterns:** Zen encourages designers to focus on the core problem and avoid unnecessary complexity. This has led to the widespread adoption of design patterns in software development, such as the Singleton design pattern and the Factory design pattern.

**2. Testing:** Zen emphasizes the importance of clear and concise test cases. This has influenced the development of unit testing frameworks and the use of test-driven development (TDD) principles.

**3. Documentation:** Zen encourages clear and concise documentation that is easy to understand. This has influenced the use of natural language processing (NLP) tools for documentation generation and the adoption of documentation tools like Swagger and OpenAPI.

**4. Code quality:** Zen encourages writing clean and efficient code that is free of errors. This has influenced the use of static code analysis tools and the adoption of coding best practices.

**5. Agile development:** Zen encourages iterative and incremental development, which has influenced the agile software development (ASD) methodology.

**6. Design thinking:** Zen encourages empathy and understanding of the user. This has influenced the adoption of design thinking principles in software development, such as user research and prototyping.

**7. Continuous integration and continuous delivery (CI/CD):** Zen encourages automation and continuous flow of software development. This has influenced the use of CI/CD tools and the adoption of continuous integration and continuous delivery (CI/CD) practices.

Overall, Python's Zen philosophy has had a significant impact on modern software development practices by promoting clear, concise, and efficient code, leading to more maintainable, scalable, and user-friendly software.
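
Both cells recover the question and answer by splitting the decoded text at the first newline, which only works when the prompt itself contains no newline. Because the decoded output of a decoder-only model starts with the prompt and continues with the generated tokens, a slightly sturdier option is to strip the known prompt text when the output begins with it. The hypothetical `split_question_answer` helper below sketches this and falls back to the notebook's first-newline split when the decoded text does not reproduce the prompt exactly (for example, if decoding normalizes some characters).

```python
from typing import Optional, Tuple


def split_question_answer(gen_text: str, prompt_text: Optional[str] = None) -> Tuple[str, str]:
    """Hypothetical helper: split decoded model output into (question, answer)."""
    # Preferred path: the decoded text starts with the exact prompt string.
    if prompt_text is not None and gen_text.startswith(prompt_text):
        return prompt_text.strip(), gen_text[len(prompt_text):].strip()
    # Fallback (same as the notebook): the question occupies the first line.
    parts = gen_text.split("\n", 1)
    question = parts[0].strip()
    answer = parts[1].strip() if len(parts) > 1 else ""
    return question, answer


# Usage with a made-up decoded string (not real model output).
q_text = "What is a list comprehension?"
decoded = q_text + "\nIt is a concise way to build a list.\nExample: [x * x for x in range(3)]"
question, answer = split_question_answer(decoded, prompt_text=q_text)
print(question)
print(answer)
```

In the cells above this would be called as `split_question_answer(gen_text, prompt_text=question_text)`; passing the prompt explicitly also handles multi-line prompts, where the first-newline split would cut the question in half.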

# Conclusion

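The fine-tuned model now produces structured, on-topic answers, but they are not yet reliable: Sample 1 presents immutability and built-in concurrency as core Python design principles, and Sample 2 attributes general Zen ideas such as self-reflection and empathy to the Zen of Python. There is still ample room for improvement. Here are some tips: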

- Experiment with the larger version of **Gemma** (7B).
- Increase `MAX_LENGTH`.
- Use a learning-rate scheduler.
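
A common choice for the last tip is linear warmup followed by linear decay. The sketch below computes the learning-rate multiplier with the same formula as `transformers.get_linear_schedule_with_warmup`; the peak rate and step counts are made-up values, not the ones used to train this model. In a custom PyTorch loop the function could be wrapped with `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=...)`, and if training goes through `transformers.TrainingArguments` (as pyreft's examples do), a comparable schedule is selected with `lr_scheduler_type="linear"` plus `warmup_steps` or `warmup_ratio`.

```python
def linear_warmup_decay(step: int, warmup_steps: int, total_steps: int) -> float:
    """Learning-rate multiplier: rises linearly from 0 to 1 during warmup,
    then decays linearly to 0 at total_steps."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


# Hypothetical values for illustration only.
peak_lr = 4e-3
warmup_steps, total_steps = 10, 100
schedule = [peak_lr * linear_warmup_decay(s, warmup_steps, total_steps) for s in range(total_steps)]

for s in (0, 5, 10, 55, 99):
    print(f"step {s:>2}: lr = {schedule[s]:.2e}")
```

Warmup keeps the first updates small while the randomly initialized intervention parameters settle, and the decay lets training end with small, stable steps.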

# References
* [ReFT: Representation Fine-tuning for Language Models](https://arxiv.org/pdf/2404.03592.pdf)
* [PyReFT library](https://github.com/stanfordnlp/pyreft?tab=readme-ov-file)
* [Google - Gemma Release on Hugging Face](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b)