# WCC Implementation, GPT-Neo from scratch

## Imports

In [None]:
!pip install datasets tqdm matplotlib transformers accelerate deepspeed -q

In [None]:
from datasets import load_dataset
import re
import pandas as pd
import json
from tqdm import tqdm
from transformers import Trainer, Adafactor, get_scheduler
from transformers import TrainingArguments
from transformers import GPT2Tokenizer, DataCollatorForLanguageModeling
from transformers import GPTNeoConfig, GPTNeoForCausalLM
from datasets import Dataset

## Dataset Preparation

In [None]:
# Upper bound applied to each paragraph's summed "Word Count" when paragraphs
# are filtered further down; same value as the model's max_position_embeddings.
# NOTE(review): the filter compares word counts, not token counts - confirm
# that words are the intended proxy for context length.
max_context = 2048

In [None]:
# Load the per-sentence table and keep only sentences shorter than 25 words.
sentences = pd.read_csv("sentences.csv")
# Force the text column to str so later string formatting never sees non-str cells.
sentences['Sentence'] = sentences['Sentence'].astype(str)
print(f"Total sentences before filtering: {len(sentences)}")
# .copy() makes the filtered frame independent of the original, so adding the
# 'Sentence_with_WCC' column afterwards does not raise SettingWithCopyWarning.
sentences = sentences[sentences["Word Count"] < 25].copy()
print(f"Total sentences after filtering: {len(sentences)}")
sentences.head()

In [None]:
def add_wcc_token(sentence, word_count):
    """Prefix *sentence* with its word-count control token.

    The token has the form ``<N>`` where ``N`` is ``word_count``; it is placed
    directly in front of the sentence with no separating space.
    """
    return "<{}>{}".format(word_count, sentence)

In [None]:
# Register tqdm's `progress_apply` on pandas objects so the row-wise pass
# below reports progress.
tqdm.pandas()


def _row_to_wcc_sentence(row):
    """Return the row's sentence prefixed with its word-count token."""
    return add_wcc_token(row['Sentence'], row['Word Count'])


sentences['Sentence_with_WCC'] = sentences.progress_apply(_row_to_wcc_sentence, axis=1)

In [None]:
# Collapse the sentence table to one row per paragraph: the WCC-prefixed
# sentences are joined with single spaces and the word counts are summed.
join_with_space = ' '.join
paragraph_groups = sentences.groupby('Paragraph ID')
paragraphs = paragraph_groups.agg({
    'Sentence_with_WCC': join_with_space,
    'Word Count': 'sum',
})
# Turn the 'Paragraph ID' group key back into an ordinary column.
paragraphs = paragraphs.reset_index()


In [None]:
# Persist the aggregated paragraphs (without the DataFrame index) so the
# following cells can start from paragraphs.csv.
paragraphs.to_csv('paragraphs.csv', index=False)

In [None]:
# Reload the saved paragraphs and draw a reproducible random subset of at most
# 350,000 rows. The sample size is clamped to the number of available rows
# because DataFrame.sample raises ValueError when asked for more rows than
# exist (sampling is without replacement).
paragraphs = pd.read_csv('paragraphs.csv')
paragraphs = paragraphs.sample(min(350_000, len(paragraphs)), random_state=42)

In [None]:
# Reload the raw per-sentence table; the next cell filters it again and
# derives max_len from it.
sentences = pd.read_csv('sentences.csv')

In [None]:
# Same clean-up as at load time: text column as str, sentences under 25 words.
sentences['Sentence'] = sentences['Sentence'].astype(str)
print(f"Total sentences before filtering: {len(sentences)}")
sentences = sentences[sentences["Word Count"] < 25]
print(f"Total sentences after filtering: {len(sentences)}")
# Largest word count that survives the filter. Cast to a plain int so it can
# be passed to range() even if the column was read as a float/numpy dtype.
# (The former mid-cell `sentences.head()` was dropped: its result was
# discarded because it was not the last expression of the cell.)
max_len = int(sentences["Word Count"].max())

In [None]:
# Keep only paragraphs whose summed word count is strictly below max_context,
# reporting the row count before and after.
print(f"Total paragraphs before filtering: {len(paragraphs)}")
within_context = paragraphs["Word Count"] < max_context
paragraphs = paragraphs.loc[within_context]
print(f"Total paragraphs after filtering: {len(paragraphs)}")


In [None]:
def add_wcc_token_to_examples(examples):
    """Prefix every sentence in a batch with its word-count token.

    ``examples`` is a mapping with parallel ``'Sentence'`` and ``'Word Count'``
    sequences. The ``'Sentence'`` entry is replaced in place by a new list of
    WCC-prefixed sentences and the same mapping is returned.
    """
    prefixed = list(map(add_wcc_token, examples['Sentence'], examples['Word Count']))
    examples['Sentence'] = prefixed
    return examples


In [None]:
def _as_text_column(example):
    """Expose the WCC-annotated paragraph under the ``text`` key.

    Called once per example, since the map below is not batched.
    """
    return {"text": example["Sentence_with_WCC"]}


# Build the Hugging Face dataset from the paragraph table (minus its ID
# column), add the `text` column the tokenizer step reads, and drop
# `Word Count`. `Sentence_with_WCC` itself is left in place.
paragraph_frame = paragraphs.drop(columns=["Paragraph ID"])
dataset = Dataset.from_pandas(paragraph_frame)
dataset = dataset.map(_as_text_column, remove_columns=["Word Count"])

# Show the first example as the cell output.
next(iter(dataset))


## Model Configuration

In [None]:
# Architecture hyper-parameters for the GPT-Neo model built below.
config = GPTNeoConfig(
    # Presumably the base GPT-2 vocabulary size; the embeddings are resized
    # later, once the extra "<N>" tokens have been added to the tokenizer.
    vocab_size=50257,
    max_position_embeddings=2048,
    hidden_size=768,
    intermediate_size=3072,
    num_layers=12,
    num_heads=12,
    activation_function="gelu",
    # The ["global", "local"] pair repeated 6 times: one attention type per
    # layer, alternating global and local across the 12 layers.
    attention_types=[[["global", "local"], 6]],
)

# Instantiate the model from the config alone, i.e. without loading
# pretrained weights.
model = GPTNeoForCausalLM(config)


In [None]:
# Print the model's module structure for a quick visual check.
print(model)

In [None]:
from transformers import GPT2Tokenizer, DataCollatorForLanguageModeling
from transformers import GPTNeoConfig, GPTNeoForCausalLM
from datasets import Dataset
# Start from the pretrained GPT-2 tokenizer.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")


print("Initial tokenizer size:", len(tokenizer))
print("Max len: ", max_len)

# One "<N>" control token for every word count from 0 up to max_len inclusive.
additional_tokens = ["<{}>".format(count) for count in range(0, max_len + 1)]
tokenizer.add_tokens(additional_tokens)

print("Final tokenizer size:", len(tokenizer))
# Grow the model's embedding matrix to cover the newly added tokens.
model.resize_token_embeddings(len(tokenizer))
# Reuse the end-of-sequence token for padding and pad on the right.
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token

# Tokenize the dataset and define the data collator for language modeling


from transformers import DataCollatorForLanguageModeling


def tokenize_function(examples):
    """Tokenize the ``text`` field of *examples* with the module-level tokenizer.

    Each sequence is truncated to at most 1024 tokens and padded up to that
    same length, so every output row has the same length.
    """
    texts = examples["text"]
    return tokenizer(
        texts,
        max_length=1024,
        truncation=True,
        padding="max_length",
    )

# Tokenize the whole dataset; batched=True hands tokenize_function a batch
# of examples per call.
tokenized_dataset = dataset.map(function=tokenize_function, batched=True)

# Collator for causal language modeling: mlm=False turns masked-LM mode off.
collator_options = {"tokenizer": tokenizer, "mlm": False}
data_collator = DataCollatorForLanguageModeling(**collator_options)

In [None]:
training_args = TrainingArguments(
    # Output location and reporting cadence.
    # NOTE(review): save_steps=0 presumably disables periodic checkpoints -
    # confirm against the installed transformers version.
    output_dir="./gpt-neo-125m-from-scratch",
    save_steps=0,
    logging_steps=100,
    # Batching: 4 samples per device per step, gradients accumulated over
    # 4 steps before each optimizer update.
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    # Optimizer hyper-parameters.
    learning_rate=3e-4,
    weight_decay=0.01,
    adam_beta1=0.9,
    adam_beta2=0.98,
    adam_epsilon=1e-08,
    # Mixed-precision training, with DeepSpeed configured from ds_config.json.
    fp16=True,
    deepspeed="ds_config.json",
)


## Training

In [None]:
# Wire the model, tokenized data, collator and training arguments together.
trainer = Trainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)

# Run the training loop configured by training_args.
trainer.train()

In [None]:
# Write the trained model to its final output directory via the Trainer.
trainer.save_model("./gpt-neo-125m-final-350k")

In [None]:
# Sample ten continuations for a prompt that consists only of the "<5>"
# word-count control token.
prompt = "<5>"

# Inference mode: disables training-only behaviour such as dropout.
model.eval()

# Tokenize once and move the tensors to whichever device the model lives on,
# instead of hard-coding "cuda". Passing the whole encoding also supplies the
# attention mask to generate().
prompt_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

generated_text = model.generate(
    **prompt_inputs,
    max_length=100,
    temperature=0.75,
    num_return_sequences=10,
    pad_token_id=tokenizer.eos_token_id,
    do_sample=True,
)

# Decode and print the generated text
for i, sequence in enumerate(generated_text):
    print(f"Generated sequence {i+1}: {tokenizer.decode(sequence, skip_special_tokens=True)}")
    print()

In [None]:
import re

import numpy as np

# Control tokens look like "<12>"; they are markers, not words.
_WCC_TOKEN_RE = re.compile(r"<\d+>")
# A sentence ends at ".", "!" or "?" followed by a space.
_SENTENCE_END_RE = re.compile(r"[.!?] ")


def count_first_sentence_words(text):
    """Return the number of words in the first sentence of *text*.

    The first sentence is everything before the first ".", "!" or "?" that is
    followed by a space (or the whole text if there is none). ``<N>`` control
    tokens are removed before counting, and words are split on any run of
    whitespace, so an empty sentence counts as 0 words.
    """
    first_sentence = _SENTENCE_END_RE.split(text, maxsplit=1)[0]
    return len(_WCC_TOKEN_RE.sub(" ", first_sentence).split())


# Per requested length: mean squared error of the produced word count, and the
# fraction of samples whose first sentence does not exceed the requested length.
mse_values = []
accuracy_values = []

# Inference mode: disables training-only behaviour such as dropout.
model.eval()

for length in tqdm(range(1, 25)):
    prompt = f"<{length}>"
    # Follow the model's device rather than assuming "cuda"; passing the full
    # encoding also hands the attention mask to generate().
    prompt_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outp = model.generate(
        **prompt_inputs,
        max_length=100,
        temperature=0.5,
        num_return_sequences=100,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
    )

    # Word count of the first sentence of each sampled continuation.
    word_counts = np.array([
        count_first_sentence_words(tokenizer.decode(sequence, skip_special_tokens=True))
        for sequence in outp
    ])

    mse = np.mean((word_counts - length) ** 2)
    # "Accurate" here means at most the requested number of words.
    accuracy = np.mean(word_counts <= length)

    mse_values.append(mse)
    accuracy_values.append(accuracy)

print("MSE values:", mse_values)
print("Accuracy values:", accuracy_values)

In [None]:
# matplotlib is installed by the notebook's pip cell, but pyplot was never
# imported, so `plt` was undefined here.
import matplotlib.pyplot as plt

# The requested sentence lengths the metrics were computed for (1..24).
lengths = range(1, 25)

# Two stacked panels: MSE on top, accuracy below.
fig, axs = plt.subplots(2)

panels = [
    (axs[0], mse_values, 'MSE values', 'MSE'),
    (axs[1], accuracy_values, 'Accuracy values', 'Accuracy'),
]
for ax, values, title, ylabel in panels:
    ax.plot(lengths, values, marker='o')
    ax.set_title(title)
    ax.set_xlabel('Sentence length')
    ax.set_ylabel(ylabel)

plt.tight_layout()
plt.show()