## Step 0: Mounting Google Drive and Importing Libraries

In [None]:
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/multimodal-xray-agent
!ls

In [48]:
import torch
import json
from huggingface_hub import login
from datasets import load_dataset, DatasetDict, load_from_disk, Dataset
from transformers import AutoTokenizer

In [None]:
login()

## Step 1: Verifying GPU and Environment

In [4]:
if torch.cuda.is_available():
    device_name = torch.cuda.get_device_name(0)
    device = torch.device("cuda")
    print(f"GPU detected: {device_name}")
else:
    device = torch.device("cpu")
    print("GPU not detected. Falling back to CPU.")

print(f"Running on device: {device}")

GPU detected: Tesla T4
Running on device: cuda


## Step 2: Loading QA Dataset

In [5]:
# Copy file from GDrive to Colab local runtime
!cp /content/drive/MyDrive/multimodal-xray-agent/data/qapairs/top_700_qa_pairs.jsonl /content/top_700_qa_pairs.jsonl

In [10]:
# Load the data manually
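with open("/content/top_700_qa_pairs.jsonl", "r", encoding="utf-8") as f:
    # One JSON object per line; skip blank lines such as a trailing newline
    data = [json.loads(line) for line in f if line.strip()]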

In [13]:
# Convert to Hugging Face Dataset
dataset = Dataset.from_list(data)

In [18]:
len(dataset)

700
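
Before relying on the loaded records, it can help to confirm that every line carries the fields used later. The sketch below is one way to do that. The helper name `find_bad_records` and the sample records are illustrative assumptions; the required keys (`uuid`, `question`, `answer`) are the column names referenced in Step 3.

```python
REQUIRED_KEYS = ("uuid", "question", "answer")  # column names used later in the notebook

def find_bad_records(records, required=REQUIRED_KEYS):
    """Return indices of records that miss a required key or hold an empty value."""
    bad = []
    for i, rec in enumerate(records):
        # Treat a missing key, None, or whitespace-only text as a problem
        if any(not str(rec.get(key) or "").strip() for key in required):
            bad.append(i)
    return bad

# Hypothetical sample records, for illustration only
sample = [
    {"uuid": "a1", "question": "What is cardiomegaly?", "answer": "An enlarged heart."},
    {"uuid": "a2", "question": "", "answer": "Missing question text."},
    {"uuid": "a3", "question": "What is a pleural effusion?"},
]
print(find_bad_records(sample))
```

In the notebook, the same check could be run as `find_bad_records(data)` on the list loaded above.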

## Step 3: Formatting Dataset for Supervised Fine-Tuning

In this step, we convert each QA pair into the instruction-following format the model will see during training. The `format_example()` function combines the question and answer into a single string in which each part is labeled with a markdown-style heading (`### Question:` and `### Answer:`) and the two parts are separated by a blank line. Using the same explicit markers in every example gives the model a consistent cue for where the question ends and the answer begins.

We then use `.map()` to apply this transformation to the entire dataset, removing the original `uuid`, `question`, and `answer` columns and keeping only the unified `text` field for training.

This is the final format that will be tokenized and fed into the model for fine-tuning.
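
To make the target format concrete, the sketch below applies the `format_example()` template to a single QA pair and prints the resulting `text` field. The record contents are invented for illustration and are not taken from the dataset.

```python
def format_example(example):
    # Same template as the notebook: labeled question, blank line, labeled answer
    return {
        "text": f"### Question:\n{example['question']}\n\n### Answer:\n{example['answer']}"
    }

# Hypothetical record, not taken from the dataset
record = {
    "uuid": "demo-0",
    "question": "What does cardiomegaly mean?",
    "answer": "An enlarged cardiac silhouette.",
}

formatted = format_example(record)
print(formatted["text"])
```

The returned dictionary contains only the `text` key; the original columns disappear from the dataset because `.map()` is called with `remove_columns`.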

In [19]:
# Format each QA pair into an instruction-following prompt
def format_example(example):
    return {
        "text": f"### Question:\n{example['question']}\n\n### Answer:\n{example['answer']}"
    }

In [None]:
# Apply formatting to all samples in the dataset
formatted_dataset = dataset.map(format_example, remove_columns=["uuid", "question", "answer"])

In [None]:
formatted_dataset[0]

In [None]:
formatted_dataset.features

## Step 4: Loading the Tokenizer

In [None]:
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct", padding_side="left")

In [27]:
tokenizer.pad_token = tokenizer.eos_token

## Step 5: Tokenizing the Dataset for Causal Language Modeling

We now tokenize the dataset using the LLaMA tokenizer. Each example is formatted in the style:

```
### Question:
{question text}

### Answer:
{answer text}
```

During tokenization, each example is converted into three key fields:

- **input_ids**: Token IDs representing the full prompt (`question + answer`) to be fed into the model.
- **attention_mask**: Binary vector indicating which tokens are real (1) vs. padding (0).
- **labels**: Target tokens that the model should try to predict during training.

---

#### Why Label Masking?

In causal language modeling (CLM), the model learns by **predicting the next token**, one step at a time. To compute the loss **only on the answer** (not on the question or the prompt template), we **mask the prompt portion** of the labels with `-100`. This tells the loss function to **ignore these positions**; the prompt tokens are still fed to the model as context.

The logic is:

```python
labels = [-100] * len(prompt_ids) + result["input_ids"][len(prompt_ids):]
```

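- `-100` is the default `ignore_index` of PyTorch's `CrossEntropyLoss`, so positions carrying this label do not contribute to the loss.
- The answer portion (after the prompt) remains unmasked and is used for learning.
- The label sequence is truncated or padded to `max_length = 512` so that it has the same length as `input_ids`.
- The slice above assumes that the prompt starts at index 0. With `padding_side="left"` (set in Step 4), padding tokens come first, so the mask has to be shifted by the number of padding tokens, and the padding positions themselves should also be set to `-100`.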
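
The masking rule can be sketched on toy token IDs without loading a tokenizer. This is an illustration under stated assumptions rather than the notebook's exact code: the helper name `build_labels` and the toy IDs are hypothetical, and the attention mask is used to find the first real token so that the result is correct for both left and right padding.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def build_labels(input_ids, attention_mask, prompt_len):
    """Mask padding and the first `prompt_len` real tokens; keep the rest as targets."""
    # Index of the first real (non-padding) token; works for left or right padding
    first_real = attention_mask.index(1) if 1 in attention_mask else len(attention_mask)
    answer_start = first_real + prompt_len
    return [
        tok if (mask == 1 and pos >= answer_start) else IGNORE_INDEX
        for pos, (tok, mask) in enumerate(zip(input_ids, attention_mask))
    ]

# Toy IDs: 0 = padding, 11-13 = prompt tokens, 21-22 = answer tokens
left_padded = build_labels([0, 0, 11, 12, 13, 21, 22], [0, 0, 1, 1, 1, 1, 1], prompt_len=3)
right_padded = build_labels([11, 12, 13, 21, 22, 0, 0], [1, 1, 1, 1, 1, 0, 0], prompt_len=3)
print(left_padded)   # [-100, -100, -100, -100, -100, 21, 22]
print(right_padded)  # [-100, -100, -100, 21, 22, -100, -100]
```

In both cases only the two answer tokens keep their IDs, which is the behavior the loss needs regardless of where the padding sits.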

---

#### Tokenization Config

```python
tokenizer(
    example["text"],
    truncation=True,         # Cut off long sequences safely
    padding="max_length",    # Pad all to uniform length
    max_length=512           # Max allowed length (safe for 3B models)
)
```

We chose `max_length = 512` to leave headroom for longer prompts and answers (e.g., definitions, expanded context) while keeping GPU memory usage manageable on the T4. Examples longer than 512 tokens are truncated, so the limit should comfortably exceed the length of a typical formatted QA pair.
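
One way to check whether a given `max_length` is adequate is to look at the distribution of token counts. The sketch below summarizes a list of lengths against the limit. The helper name `length_report` and the sample lengths are hypothetical; the notebook does not report the real distribution.

```python
import math

def length_report(token_lengths, max_length=512):
    """Summarize sequence lengths against a truncation limit (expects a non-empty list)."""
    ordered = sorted(token_lengths)
    # Nearest-rank 95th percentile
    p95 = ordered[max(0, math.ceil(0.95 * len(ordered)) - 1)]
    truncated = sum(1 for n in ordered if n > max_length)
    return {
        "max": ordered[-1],
        "p95": p95,
        "truncated": truncated,
        "truncated_fraction": truncated / len(ordered),
    }

# Hypothetical lengths; in the notebook they could be computed as
# [len(tokenizer(ex["text"])["input_ids"]) for ex in formatted_dataset]
lengths = [48, 60, 75, 90, 120, 130, 180, 240, 300, 600]
report = length_report(lengths)
print(report)
```

A non-zero `truncated` count would mean that some answers are being cut off at the chosen limit.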


In [28]:
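MAX_LENGTH = 512

def tokenize(example):
    # Tokenize the prompt alone (up to and including "### Answer:\n")
    # to learn how many leading real tokens must be masked
    prompt = example["text"].split("### Answer:\n")[0] + "### Answer:\n"
    prompt_len = len(tokenizer(prompt)["input_ids"])

    result = tokenizer(
        example["text"],
        truncation=True,
        padding="max_length",
        max_length=MAX_LENGTH,
        return_tensors=None,
    )

    # The tokenizer pads on the left (see Step 4), so the prompt does not start
    # at index 0; use the attention mask to find the first real token
    mask = result["attention_mask"]
    first_real = mask.index(1) if 1 in mask else len(mask)
    answer_start = first_real + prompt_len

    # Keep answer tokens as targets; mask padding and prompt positions with -100
    result["labels"] = [
        tok if (m == 1 and i >= answer_start) else -100
        for i, (tok, m) in enumerate(zip(result["input_ids"], mask))
    ]
    return result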

In [None]:
tokenized_dataset = formatted_dataset.map(tokenize, batched=False, remove_columns=["text"])

In [None]:
tokenized_dataset[0]

## Step 6: Splitting the Tokenized Dataset into Train and Validation Sets

In [32]:
split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)

In [33]:
dataset_dict = DatasetDict({
    "train": split_dataset["train"],
    "validation": split_dataset["test"]
})

In [34]:
print("Training examples:", len(dataset_dict["train"]))
print("Validation examples:", len(dataset_dict["validation"]))

Training examples: 630
Validation examples: 70


In [36]:
# Sanity check
print(tokenizer.decode(tokenized_dataset[0]["input_ids"]))

<|eot_id|><|eot_id|><|eot_id|>... (output truncated; the sequence begins with a long run of padding tokens)

In [37]:
print(tokenizer.decode([t if t != -100 else tokenizer.pad_token_id for t in tokenized_dataset[0]["labels"]]))

<|eot_id|><|eot_id|><|eot_id|>... (output truncated; the sequence begins with a long run of padding tokens)
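
Replacing `-100` with the padding token makes the decoded labels hard to read, because masked positions and padding look identical. A possibly more readable check is to drop the masked positions before decoding. The sketch below shows the filtering step on toy labels; the helper name `supervised_token_ids` is hypothetical.

```python
def supervised_token_ids(labels, ignore_index=-100):
    """Return only the label IDs that contribute to the loss."""
    return [tok for tok in labels if tok != ignore_index]

# Toy labels: masked padding and prompt positions, then two answer tokens
toy_labels = [-100, -100, -100, -100, 21, 22]
kept = supervised_token_ids(toy_labels)
print(kept)
print(f"{len(kept)} of {len(toy_labels)} positions are supervised")
```

In the notebook, `tokenizer.decode(supervised_token_ids(tokenized_dataset[0]["labels"]))` should then print only the answer text if the masking is aligned with the prompt.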

## Step 7: Saving the Tokenized Dataset to Disk

In [None]:
save_path = "./data/tokenized_dataset"

dataset_dict.save_to_disk(save_path)

## Step 8: Verifying the Saved Dataset

In [54]:
# Path to the saved tokenized dataset
load_path = "file://./data/tokenized_dataset"

In [55]:
# Load the dataset from disk
loaded_dataset = load_from_disk(load_path)

In [56]:
# Sanity check: view one example
print(loaded_dataset["train"][0])

{'input_ids': [128009, 128009, 128009, ... (output truncated)

In [60]:
print(len(loaded_dataset["train"]))
print(len(loaded_dataset["validation"]))

630
70
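
Beyond counting rows, a structural check on each reloaded example can catch problems such as mismatched lengths or fully masked labels. The sketch below is one possible check, written against plain dictionaries; the helper name `check_example` is hypothetical, and the field names are those described in Step 5.

```python
EXPECTED_FIELDS = ("input_ids", "attention_mask", "labels")

def check_example(example, max_length=512):
    """Return a list of problems found in one tokenized example (empty if none)."""
    problems = []
    for field in EXPECTED_FIELDS:
        if field not in example:
            problems.append(f"missing field: {field}")
        elif len(example[field]) != max_length:
            problems.append(f"{field} has length {len(example[field])}, expected {max_length}")
    # An example whose labels are all -100 contributes nothing to the loss
    if not problems and all(tok == -100 for tok in example["labels"]):
        problems.append("all labels are masked")
    return problems

# Toy examples with max_length=4 for readability
good = {"input_ids": [0, 5, 6, 7], "attention_mask": [0, 1, 1, 1], "labels": [-100, -100, 6, 7]}
bad = {"input_ids": [0, 5, 6, 7], "attention_mask": [0, 1, 1, 1], "labels": [-100, -100, -100]}
print(check_example(good, max_length=4))
print(check_example(bad, max_length=4))
```

Applied to the reloaded data, this could be run as `all(not check_example(ex) for ex in loaded_dataset["train"])`.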
