<a href="https://colab.research.google.com/github/schumbar/CMPE297/blob/main/assignment_06/ShawnChumbar_Assignment06_PartB.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Assignment 06 - Part B
## Part B: Implement the "Textbooks Are All You Need" Case Study

### Assignment Description

The second part of the assignment requires recreating the "Textbooks Are All You Need" case study using a smaller dataset. The implementation should be done in Google Colab Pro with an A100 GPU for optimal performance. You will use your own dataset and process a smaller amount of data to simplify the experiment while maintaining the integrity of the methodology. Resources available for this task include a [YouTube video](https://www.youtube.com/watch?v=gmFi6W8DPdM), the [GitHub repository](https://github.com/jina-ai/textbook), the [original Colab notebook](https://colab.research.google.com/drive/1T4IfGfDJ8uxgU8XBPpMZivw_JThzdQim?usp=sharing), and the [research paper PDF](https://arxiv.org/pdf/2306.11644.pdf). The final submission should include a Colab notebook containing the complete implementation, input and output files, and additional artifacts demonstrating your results. Assistance from GPT-4 is recommended to handle complex parts of the implementation and ensure clarity in execution.

### References:
1. [YouTube video](https://www.youtube.com/watch?v=gmFi6W8DPdM)
2. [GitHub repository](https://github.com/jina-ai/textbook)
3. [Original Colab notebook](https://colab.research.google.com/drive/1T4IfGfDJ8uxgU8XBPpMZivw_JThzdQim?usp=sharing)
4. [Research paper PDF](https://arxiv.org/pdf/2306.11644.pdf)



## Install Required Packages

In [None]:
# Install required packages
!pip install transformers datasets torch numpy wandb huggingface_hub

Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
Downloading dill-0.3.8-py3-none-any.whl (116 kB)
Downloading fsspec-2024.9.0-py3-none-any.whl

## Import Packages

In [None]:
import torch
import numpy as np
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    TrainerCallback
)
from datasets import Dataset, load_dataset
import wandb

## Set Device to CUDA

In [None]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


## Define Helper Functions

In [None]:
import re

def preprocess_dataset(dataset):
    """Normalize whitespace in every example's text."""

    def clean_text(text):
        # Collapse runs of three or more newlines into one blank line
        text = re.sub(r'\n{3,}', '\n\n', text)
        # Drop whitespace-only lines (this also removes the blank lines left above);
        # the indentation of non-empty lines is kept unchanged
        cleaned_lines = [line for line in text.split('\n') if line.strip()]
        return '\n'.join(cleaned_lines)

    return dataset.map(lambda x: {"text": clean_text(x["text"])})
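
The chapter strings defined below are indented triple-quoted literals, so every line of each example carries the source file's indentation (eight spaces for the outermost lines). `clean_text` leaves that margin in place. A minimal sketch, assuming you want to strip the common leading whitespace before tokenization, uses the standard library's `textwrap.dedent`; the helper name `dedent_example` is hypothetical and not part of the original notebook.

```python
import textwrap

def dedent_example(text: str) -> str:
    """Remove the common indentation left over from a triple-quoted literal.

    Hypothetical helper: it could be called inside clean_text before
    whitespace-only lines are dropped.
    """
    # dedent ignores whitespace-only lines when computing the common margin,
    # then strips that margin from every line; strip() trims outer newlines.
    return textwrap.dedent(text).strip()

raw = """
        # Chapter 1: Python Fundamentals

        def add(a, b):
            '''Return a + b.'''
            return a + b
        """
print(dedent_example(raw))
```

Relative indentation survives: the function body stays four spaces deeper than its `def` line, while the shared eight-space margin disappears.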

In [None]:
def create_basic_concepts():
    """Generate basic Python programming concepts."""
    return [
        # Variables and Types
        """
        # Chapter 1: Python Fundamentals - Variables and Types

        Python is a dynamically typed language, which means you don't need to declare
        variable types explicitly. Let's explore different data types and their operations.

        ## Numeric Types

        def demonstrate_numeric_operations(x: int, y: float) -> None:
            '''
            Demonstrates various numeric operations in Python.

            Args:
                x (int): An integer value
                y (float): A floating-point value
            '''
            # Integer operations
            print(f"Integer division: {x // 2}")
            print(f"Modulus: {x % 2}")
            print(f"Power: {x ** 2}")

            # Float operations
            print(f"Float division: {y / 2}")
            print(f"Rounded float: {round(y, 2)}")

            # Type conversion
            print(f"Float to int: {int(y)}")
            print(f"Int to float: {float(x)}")

        # Example usage:
        demonstrate_numeric_operations(10, 15.7)
        """,

        # String Operations
        """
        # Chapter 2: String Manipulation and Formatting

        Strings are one of the most commonly used data types in Python. Understanding
        string manipulation is crucial for text processing and data formatting.

        def string_toolkit(text: str) -> dict:
            '''
            Demonstrates various string operations.

            Args:
                text (str): Input string to process

            Returns:
                dict: Dictionary containing various string operations results
            '''
            results = {
                'length': len(text),
                'uppercase': text.upper(),
                'lowercase': text.lower(),
                'capitalized': text.capitalize(),
                'words': text.split(),
                'word_count': len(text.split()),
                'reversed': text[::-1],
                'stripped': text.strip(),
                'is_alpha': text.isalpha(),
                'is_numeric': text.isnumeric()
            }

            # Count how often each character occurs
            char_count = {}
            for char in text:
                char_count[char] = text.count(char)
            results['char_frequency'] = char_count

            return results

        # Example usage:
        text_analysis = string_toolkit("Hello Python Programming!")
        for operation, result in text_analysis.items():
            print(f"{operation}: {result}")
        """
    ]

def create_data_structures():
    """Generate content about Python data structures."""
    return [
        # Advanced List Operations
        """
        # Chapter 3: Advanced List Operations and List Comprehension

        Lists are versatile data structures that support various operations and
        comprehension techniques for elegant data manipulation.

        class ListManipulator:
            '''
            A class demonstrating advanced list operations and comprehensions.
            '''

            @staticmethod
            def filter_and_transform(numbers: list) -> dict:
                '''
                Demonstrates various list operations and comprehensions.

                Args:
                    numbers (list): List of numbers to process

                Returns:
                    dict: Results of various list operations
                '''
                results = {
                    'squares': [x**2 for x in numbers],
                    'even_numbers': [x for x in numbers if x % 2 == 0],
                    'odd_numbers': [x for x in numbers if x % 2 != 0],
                    'positive_numbers': [x for x in numbers if x > 0],
                    'negative_numbers': [x for x in numbers if x < 0],
                    'divisible_by_3': [x for x in numbers if x % 3 == 0]
                }

                # Nested list comprehension
                matrix = [[i*j for j in range(5)] for i in range(5)]
                results['multiplication_table'] = matrix

                # Complex transformations
                results['complex_transform'] = [(x, x**2, x**3) for x in numbers if x > 0]

                return results

            @staticmethod
            def list_operations(lst1: list, lst2: list) -> dict:
                '''
                Demonstrates list operations between two lists.

                Args:
                    lst1 (list): First list
                    lst2 (list): Second list

                Returns:
                    dict: Results of various list operations
                '''
                return {
                    'concatenated': lst1 + lst2,
                    'unique_elements': list(set(lst1 + lst2)),
                    'common_elements': list(set(lst1) & set(lst2)),
                    'different_elements': list(set(lst1) ^ set(lst2)),
                    'lst1_unique': list(set(lst1) - set(lst2)),
                    'lst2_unique': list(set(lst2) - set(lst1))
                }

        # Example usage:
        numbers = [-5, -2, 0, 1, 3, 4, 6, 8, 9, 10]
        manipulator = ListManipulator()
        results = manipulator.filter_and_transform(numbers)
        for operation, result in results.items():
            print(f"{operation}: {result}")
        """
    ]

def create_algorithms():
    """Generate content about algorithms."""
    return [
        # Sorting Algorithms
        """
        # Chapter 4: Advanced Sorting Algorithms

        Understanding different sorting algorithms and their implementations helps in
        choosing the right tool for specific scenarios.

        class SortingAlgorithms:
            '''
            Implementation of various sorting algorithms with complexity analysis.
            '''

            @staticmethod
            def merge_sort(arr: list) -> list:
                '''
                Implements merge sort algorithm.
                Time Complexity: O(n log n)
                Space Complexity: O(n)

                Args:
                    arr (list): List to sort

                Returns:
                    list: Sorted list
                '''
                if len(arr) <= 1:
                    return arr

                mid = len(arr) // 2
                left = SortingAlgorithms.merge_sort(arr[:mid])
                right = SortingAlgorithms.merge_sort(arr[mid:])

                return SortingAlgorithms._merge(left, right)

            @staticmethod
            def _merge(left: list, right: list) -> list:
                '''Helper method for merge sort.'''
                result = []
                i = j = 0

                while i < len(left) and j < len(right):
                    if left[i] <= right[j]:
                        result.append(left[i])
                        i += 1
                    else:
                        result.append(right[j])
                        j += 1

                result.extend(left[i:])
                result.extend(right[j:])
                return result

            @staticmethod
            def quick_sort(arr: list) -> list:
                '''
                Implements quick sort algorithm.
                Time Complexity: O(n log n) average, O(n²) worst
                Space Complexity: O(n) average (this version builds new lists)

                Args:
                    arr (list): List to sort

                Returns:
                    list: Sorted list
                '''
                if len(arr) <= 1:
                    return arr

                pivot = arr[len(arr) // 2]
                left = [x for x in arr if x < pivot]
                middle = [x for x in arr if x == pivot]
                right = [x for x in arr if x > pivot]

                return SortingAlgorithms.quick_sort(left) + middle + SortingAlgorithms.quick_sort(right)

            @staticmethod
            def insertion_sort(arr: list) -> list:
                '''
                Implements insertion sort algorithm.
                Time Complexity: O(n²)
                Space Complexity: O(1)

                Args:
                    arr (list): List to sort

                Returns:
                    list: Sorted list
                '''
                for i in range(1, len(arr)):
                    key = arr[i]
                    j = i - 1
                    while j >= 0 and arr[j] > key:
                        arr[j + 1] = arr[j]
                        j -= 1
                    arr[j + 1] = key
                return arr

        # Example usage and comparison
        numbers = [64, 34, 25, 12, 22, 11, 90]
        sorter = SortingAlgorithms()

        # Compare different sorting methods
        merge_sorted = sorter.merge_sort(numbers.copy())
        quick_sorted = sorter.quick_sort(numbers.copy())
        insertion_sorted = sorter.insertion_sort(numbers.copy())

        print(f"Original: {numbers}")
        print(f"Merge sort: {merge_sorted}")
        print(f"Quick sort: {quick_sorted}")
        print(f"Insertion sort: {insertion_sorted}")
        """
    ]

def create_advanced_topics():
    """Generate content about advanced Python topics."""
    return [
        # Design Patterns
        """
        # Chapter 5: Design Patterns in Python

        Design patterns are reusable solutions to common problems in software design.
        Let's explore some common patterns implemented in Python.

        ## Singleton Pattern

        class Singleton:
            '''
            Implements the Singleton pattern ensuring only one instance exists.
            '''
            _instance = None

            def __new__(cls):
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                return cls._instance

            def __init__(self):
                # __init__ runs on every Singleton() call, so only initialize once
                if not hasattr(self, 'value'):
                    self.value = None

        ## Factory Pattern

        class Animal:
            def speak(self):
                pass

        class Dog(Animal):
            def speak(self):
                return "Woof!"

        class Cat(Animal):
            def speak(self):
                return "Meow!"

        class AnimalFactory:
            @staticmethod
            def create_animal(animal_type: str) -> Animal:
                '''
                Creates an animal instance based on type.

                Args:
                    animal_type (str): Type of animal to create

                Returns:
                    Animal: Instance of specified animal
                '''
                if animal_type.lower() == "dog":
                    return Dog()
                elif animal_type.lower() == "cat":
                    return Cat()
                raise ValueError(f"Unknown animal type: {animal_type}")

        # Example usage:
        factory = AnimalFactory()
        dog = factory.create_animal("dog")
        cat = factory.create_animal("cat")
        print(f"Dog says: {dog.speak()}")
        print(f"Cat says: {cat.speak()}")
        """
    ]

def create_synthetic_textbook_data():
    """
    Create a comprehensive synthetic dataset of Python textbook-style content.
    """
    textbook_data = []

    # Combine all sections
    textbook_data.extend(create_basic_concepts())
    textbook_data.extend(create_data_structures())
    textbook_data.extend(create_algorithms())
    textbook_data.extend(create_advanced_topics())

    return Dataset.from_dict({"text": textbook_data})

In [None]:
# 2. Create a data filtering function
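def filter_code_quality(text):
    """
    Simple heuristic quality filter for code content.

    Returns True only if the text contains a docstring delimiter, a comment
    marker, and indentation made of spaces in multiples of four.
    """
    # Check for docstring presence
    has_docstring = '"""' in text or "'''" in text

    # Check for comments (crude: a '#' inside a string also counts)
    has_comments = '#' in text

    # Check indentation: leading whitespace must be spaces only, in multiples of 4
    def indentation_ok(line):
        indent = line[:len(line) - len(line.lstrip())]
        return '\t' not in indent and len(indent) % 4 == 0

    proper_indentation = all(indentation_ok(line)
                             for line in text.split('\n') if line.strip())

    return has_docstring and has_comments and proper_indentation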
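
`filter_code_quality` is defined but never applied in the main execution cells. In the paper, filtering is done by a classifier trained on GPT-4 annotations of each sample's educational value; the sketch below swaps in a simple heuristic score as a stand-in to show where such a filter plugs in. `heuristic_score`, `select_high_quality`, and the threshold of 1.0 are hypothetical names and values chosen for illustration.

```python
from typing import Callable, List

def heuristic_score(text: str) -> float:
    """Fraction of simple quality checks a sample passes (stand-in for a learned classifier)."""
    checks = [
        '"""' in text or "'''" in text,      # contains a docstring delimiter
        '#' in text,                          # contains a comment marker
        'def ' in text or 'class ' in text,   # defines a function or class
    ]
    return sum(checks) / len(checks)

def select_high_quality(samples: List[str],
                        score_fn: Callable[[str], float],
                        threshold: float = 1.0) -> List[str]:
    """Keep only samples whose score meets the threshold."""
    return [s for s in samples if score_fn(s) >= threshold]

samples = [
    "def area(r):\n    '''Circle area.'''\n    # pi * r^2\n    return 3.14159 * r * r",
    "x = 1\ny = 2\nprint(x + y)",
]
kept = select_high_quality(samples, heuristic_score)
print(f"kept {len(kept)} of {len(samples)} samples")
```

With a Hugging Face `Dataset`, the same predicate could be passed to `dataset.filter`, for example `dataset.filter(lambda ex: heuristic_score(ex["text"]) >= 1.0)`.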


In [None]:

# 3. Setup model and tokenizer
def setup_model():
    """
    Load GPT-2 small (124M parameters) as a lightweight stand-in for phi-1
    """
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # GPT-2 defines no pad token, so reuse the end-of-sequence token for padding
    tokenizer.pad_token = tokenizer.eos_token

    # Load a small pretrained model and move to correct device
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

    return tokenizer, model
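
For scale: phi-1 has 1.3B parameters, while the `gpt2` checkpoint loaded here is GPT-2 small. The arithmetic below counts its parameters from the published configuration (12 layers, hidden size 768, 1,024 positions, 50,257-token vocabulary), assuming the output head is tied to the token embedding as in the Hugging Face implementation. The result should agree with `sum(p.numel() for p in model.parameters())`; `gpt2_param_count` is a hypothetical helper.

```python
def gpt2_param_count(n_layer: int = 12, d: int = 768,
                     n_positions: int = 1024, vocab: int = 50257) -> int:
    """Count GPT-2 parameters, assuming the LM head shares the token-embedding matrix."""
    embeddings = vocab * d + n_positions * d          # token + learned position embeddings
    attn = (d * 3 * d + 3 * d) + (d * d + d)          # fused QKV projection + output projection
    mlp = (d * 4 * d + 4 * d) + (4 * d * d + d)       # up-projection to 4d + down-projection
    layer_norms = 2 * (2 * d)                         # two LayerNorms (scale + bias) per block
    per_block = attn + mlp + layer_norms
    final_ln = 2 * d                                  # LayerNorm after the last block
    return embeddings + n_layer * per_block + final_ln

print(f"{gpt2_param_count():,}")  # 124,439,808
```

That is roughly a tenth of phi-1's size, which is one reason the results here should be read as a scaled-down illustration of the method rather than a reproduction.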


In [None]:
# 4. Tokenize dataset
def tokenize_dataset(dataset, tokenizer):
    """
    Tokenize the dataset for training
    """
    def tokenize_function(examples):
        return tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=512,
            return_special_tokens_mask=True
        )

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=["text"]
    )

    # Return PyTorch tensors; the Trainer moves each batch to the model's device
    tokenized_dataset.set_format(type='torch')

    return tokenized_dataset
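
Because every example is padded to 512 tokens and GPT-2's pad token is set to its EOS token, the loss masking done by `DataCollatorForLanguageModeling(mlm=False)` matters here. As far as I know, that collator copies `input_ids` into `labels` and replaces every position equal to `pad_token_id` with -100, which the loss ignores; since pad and EOS share an id (50256 for GPT-2), genuine EOS tokens are masked as well. The pure-Python sketch below mimics that step; `build_clm_labels` is a hypothetical name and the token ids other than 50256 are arbitrary.

```python
from typing import List

IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss by default

def build_clm_labels(input_ids: List[int], pad_token_id: int) -> List[int]:
    """Copy input_ids to labels, masking padding so it contributes no loss."""
    return [IGNORE_INDEX if tok == pad_token_id else tok for tok in input_ids]

EOS = 50256                        # GPT-2's <|endoftext|> id, reused as the pad token
ids = [464, 3290, EOS, EOS, EOS]   # arbitrary short sequence followed by padding
labels = build_clm_labels(ids, pad_token_id=EOS)
print(labels)
```

The one-position shift between inputs and targets is applied inside the model's loss computation, not by the collator.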

In [None]:
def train_model(model, tokenizer, dataset):
    """
    Fine-tune the model on the textbook dataset, logging every step.
    """
    # Tokenize the dataset first
    tokenized_dataset = tokenize_dataset(dataset, tokenizer)

    training_args = TrainingArguments(
        output_dir="./phi-1-mini",
        num_train_epochs=5,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,  # Effective batch size of 8 per device
        learning_rate=2e-5,  # Small LR for fine-tuning a pretrained model
        warmup_ratio=0.1,    # Warm up over the first 10% of optimizer steps
        weight_decay=0.01,   # Decoupled weight decay (AdamW)
        logging_steps=1,
        save_steps=50,
        report_to="wandb",   # The built-in W&B integration sends metrics to wandb
        logging_dir="./logs",
        logging_first_step=True,
        logging_nan_inf_filter=True,
        # warmup_steps is omitted on purpose: it overrides warmup_ratio, and with
        # only a handful of optimizer steps a 100-step warmup keeps the LR near zero.
        # The Trainer uses the GPU automatically when one is available.
    )

    # Print logs to the notebook; report_to="wandb" already sends them to W&B,
    # so calling wandb.log here as well would record every metric twice
    class LoggingCallback(TrainerCallback):
        def on_log(self, args, state, control, logs=None, **kwargs):
            if state.is_local_process_zero and logs is not None:
                print(f"step {state.global_step}: {logs}")

    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
        callbacks=[LoggingCallback()]
    )

    print("Starting training...")
    trainer.train()
    print("Training completed!")

    return trainer
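
With only five examples, the optimizer takes very few steps, which interacts badly with a long warmup. The sketch below reproduces the step bookkeeping as I understand the `Trainer` does it (batches per epoch floor-divided by `gradient_accumulation_steps`, with at least one update per epoch) together with a linear warmup/decay schedule like the default `lr_scheduler_type="linear"`. Treat the numbers as an estimate under those assumptions; `update_steps` and `linear_lr` are hypothetical helpers.

```python
import math

def update_steps(n_examples: int, batch_size: int, grad_accum: int, epochs: int) -> int:
    """Optimizer updates for a run, assuming at least one update per epoch."""
    batches_per_epoch = math.ceil(n_examples / batch_size)
    per_epoch = max(batches_per_epoch // grad_accum, 1)
    return math.ceil(epochs * per_epoch)

def linear_lr(step: int, base_lr: float, warmup: int, total: int) -> float:
    """Linear warmup from 0 to base_lr, then linear decay to 0."""
    if step < warmup:
        return base_lr * step / max(1, warmup)
    return base_lr * max(0.0, (total - step) / max(1, total - warmup))

total = update_steps(n_examples=5, batch_size=2, grad_accum=4, epochs=5)
with_steps = [linear_lr(s, 2e-5, warmup=100, total=total) for s in range(total)]
with_ratio = [linear_lr(s, 2e-5, warmup=math.ceil(0.1 * total), total=total)
              for s in range(total)]
print(total)             # 5 updates in total
print(max(with_steps))   # peak LR with warmup_steps=100 (roughly 8e-07)
print(max(with_ratio))   # peak LR with warmup_ratio=0.1 alone (2e-05)
```

Under these assumptions, a run with `warmup_steps=100` never exceeds about 4% of the configured learning rate, which is one plausible reason the fine-tuned model barely changes. When both are set, `warmup_steps` takes precedence over `warmup_ratio`.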

In [None]:
def evaluate_model(model, tokenizer, prompt):
    """
    Generate a code completion for `prompt` using sampling-based decoding
    """
    model.eval()  # Disable dropout, which may still be active after training
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,  # Pass explicitly, since pad == eos
        max_length=256,  # Total length, prompt tokens included
        num_return_sequences=1,
        do_sample=True,  # Sample instead of greedy decoding
        temperature=0.7,  # Below 1.0: sharper, more conservative distribution
        top_p=0.95,      # Nucleus sampling
        top_k=50,        # Top-k sampling
        repetition_penalty=1.2,  # Penalize tokens that already appeared
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        no_repeat_ngram_size=2  # Forbid any repeated bigram (restrictive for code)
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
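
The sampling arguments passed to `generate` reshape the next-token distribution before a token is drawn. The sketch below applies temperature, top-k, and top-p (nucleus) filtering to a toy set of logits in plain Python. It is a simplified illustration, not the `transformers` implementation, which applies these as logits processors in its own order and adds the repetition controls; `filter_next_token_probs` and the toy logits are hypothetical.

```python
import math
import random
from typing import Dict

def filter_next_token_probs(logits: Dict[str, float], temperature: float = 0.7,
                            top_k: int = 50, top_p: float = 0.95) -> Dict[str, float]:
    """Return the renormalized distribution left after temperature, top-k and top-p."""
    # Temperature < 1 sharpens the distribution; > 1 flattens it.
    scaled = {tok: v / temperature for tok, v in logits.items()}
    # Softmax, subtracting the max logit for numerical stability.
    m = max(scaled.values())
    unnorm = {tok: math.exp(v - m) for tok, v in scaled.items()}
    z = sum(unnorm.values())
    ranked = sorted(((e / z, tok) for tok, e in unnorm.items()), reverse=True)
    # Top-k: keep the k most probable tokens.
    ranked = ranked[:top_k]
    # Top-p: keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for p, tok in ranked:
        kept.append((p, tok))
        cumulative += p
        if cumulative >= top_p:
            break
    total = sum(p for p, _ in kept)
    return {tok: p / total for p, tok in kept}

logits = {"return": 4.0, "for": 2.5, "if": 2.0, "banana": -1.0}
probs = filter_next_token_probs(logits, temperature=0.7, top_k=3, top_p=0.9)
rng = random.Random(0)
next_token = rng.choices(list(probs), weights=list(probs.values()))[0]
print(probs, next_token)
```

Lowering `temperature` below 1, as the notebook does with 0.7, makes sampling more conservative rather than more creative.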

## Main Execution

### Initialize Wandb for Tracking

In [None]:
# Initialize wandb for tracking
wandb.init(project="textbooks-study")

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.

wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc


### Create and Preprocess the Dataset

In [None]:
# Create dataset and preprocess dataset
print("Creating dataset...")
dataset = create_synthetic_textbook_data()
dataset = preprocess_dataset(dataset)

Creating dataset...



### Setup Model and Tokenizer

In [None]:
# Setup model and tokenizer
print("Setting up model...")
tokenizer, model = setup_model()

Setting up model...



### Train Model

In [None]:
# Train model
print("Training model...")
trainer = train_model(model, tokenizer, dataset)

Training model...



Starting training...

| Step | Training Loss |
|------|---------------|
| 1    | 1.6247        |


Training completed!


### Save Model

In [None]:
# Save model
print("Saving model...")
model.save_pretrained("./phi-1-mini")
tokenizer.save_pretrained("./phi-1-mini")

Saving model...


('./phi-1-mini/tokenizer_config.json',
 './phi-1-mini/special_tokens_map.json',
 './phi-1-mini/vocab.json',
 './phi-1-mini/merges.txt',
 './phi-1-mini/added_tokens.json',
 './phi-1-mini/tokenizer.json')

### Test Generation

In [None]:
# Test generation
test_prompt = """
def calculate_average(numbers):
    '''
"""

print("Testing generation...")
generated_code = evaluate_model(model, tokenizer, test_prompt)
print(f"Generated code:\n{generated_code}")

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Testing generation...
Generated code:

def calculate_average(numbers):
    '''
for (a = 0) in [ numbers, len: 1] else : a * 10.0/10 is equivalent to the average of all numbers with an odd number between them and any two are equal or not at least one digit within range which means that it can't be done without rounding up either . If you don´t want your digits smaller than 2 then this may well work too but do consider using round ups as follows for more complex calculations if there will also need some extra data from here i know what I mean by "is" so its probably better just say how many times each decimal was divided because thats possible on most people think about every fraction they have been counted first thing when counting points should count second time based off sort out these fractions like $30 x3 y4*100000$ while evens always last 5 seconds instead does something similar where we could make sure our whole year counts 50 millionths per day? It turns into my math now let's ta
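
The paper scores models by functional correctness on HumanEval and MBPP (pass@1) rather than by reading samples. A much lighter, stdlib-only check in the same spirit is sketched below: it first tries to parse a completion with `ast.parse`, then executes it and runs a few hand-written test cases. The helper name `passes_tests` and the test cases are hypothetical, and `exec` runs untrusted model output, so a real harness should sandbox it.

```python
import ast
from typing import Any, Iterable, Tuple

def passes_tests(code: str, func_name: str,
                 cases: Iterable[Tuple[tuple, Any]]) -> bool:
    """Return True if `code` parses, defines `func_name`, and matches every case."""
    try:
        ast.parse(code)                 # reject syntactically invalid completions
    except SyntaxError:
        return False
    namespace: dict = {}
    try:
        exec(code, namespace)           # WARNING: executes untrusted code; sandbox in practice
        func = namespace[func_name]
        return all(func(*args) == expected for args, expected in cases)
    except Exception:
        return False

cases = [(([1, 2, 3],), 2.0), (([10],), 10.0)]
reference = "def calculate_average(numbers):\n    return sum(numbers) / len(numbers)\n"
garbled = "def calculate_average(numbers):\n    '''\nfor (a = 0) in [ numbers, len: 1] else : a * 10.0/10"
print(passes_tests(reference, "calculate_average", cases))  # True
print(passes_tests(garbled, "calculate_average", cases))    # False
```

The generated completion above fails at the parsing stage, since its docstring is never closed.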