In [None]:
import torch
import torch.nn as nn
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config, GPT2Model, GPT2PreTrainedModel
from modeling_topK_gpt2 import CustomGPT2Attention, CustomGPT2Block, CustomGPT2Model, CustomGPT2LMHeadModel
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import time
from datasets import load_dataset

In [11]:
def generate_text(model, tokenizer, input_text, max_length=100):
    """Continue ``input_text`` with ``model`` and return the decoded text.

    Args:
        model: Object exposing ``generate(input_ids, max_length=..., pad_token_id=...)``
            and a ``device`` attribute (Hugging Face ``PreTrainedModel`` provides both).
        tokenizer: Object exposing ``encode``, ``decode`` and ``eos_token_id``.
        input_text: Prompt string to tokenize and feed to the model.
        max_length: Forwarded unchanged to ``model.generate`` as ``max_length``.

    Returns:
        The first sequence returned by ``model.generate``, decoded with
        ``skip_special_tokens=True``.
    """
    # Tokenize the prompt and place it on the same device as the model.
    # The device used to be hard-coded to 'cuda', which failed whenever the
    # model lived anywhere else (CPU, a different GPU index, ...).
    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(model.device)

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        # EOS is passed explicitly as the pad token id.
        output = model.generate(
            input_ids,
            max_length=max_length,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Only the first (and, for a single prompt, only) sequence is decoded.
    return tokenizer.decode(output[0], skip_special_tokens=True)


In [12]:
# Benchmark cell: load the custom GPT-2 variant and time 100 single-token
# generation steps on a long prompt.
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
config = GPT2Config.from_pretrained(model_name)
# k_percent / layers_to_prune are forwarded to CustomGPT2LMHeadModel; their exact
# semantics are defined in modeling_topK_gpt2 (not visible here). Presumably one
# pruning ratio per listed layer -- TODO confirm against that module.
custom_model = CustomGPT2LMHeadModel.from_pretrained(
    model_name,
    config=config,
    k_percent=[0.0, 0.0, 0.2, 0.2, 0.2, 0.15, 0.2, 0.2, 0.2, 0.2, 0.0, 0.05],
    layers_to_prune=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
)
custom_model.to('cuda')
generated_text = """FineWeb Technical Report was released! Here is how the Hugging Face team created FineWeb, the best open-source dataset:

1. Collect Raw Data: Use CommonCrawl as the starting point, 96 CommonCrawl snapshots were used for FineWeb.

2. Url Filtering: Apply URL filtering using a blocklist to remove explicit and malicious content

3. Text Extraction: Use the trafilatura to extract text from raw HTML of the WARC files

4. Base Filtering (Language & Gopher): Use a fastText language classifier to keep only English text and apply quality and repetition filters from MassiveText (Gopher).

4. Deduplicating the Data: Use MinHash deduplication to deduplicate each dump individually. No cross-deduplication across dumps.

5. Additional Quality Filtering: Apply C4 filters (except Punctuation) + 3 additional new filters for Punctuation, Line duplicates, and Short Lines.

6. PII Removal: Replace Emails and IP addresses from the dataset

During the whole process the team ran hundreds of ablations to evaluate performance against other open datasets."""
# perf_counter() is the clock intended for measuring intervals: unlike time.time()
# it cannot jump when the system clock is adjusted.
start = time.perf_counter()
for i in range(100):
    # max_length is set to the current token count + 1, so each call extends the
    # text by at most one token; the decoded result becomes the next prompt.
    prompt_len = len(tokenizer(generated_text)["input_ids"])
    generated_text = generate_text(custom_model, tokenizer, generated_text, max_length=prompt_len + 1)
end = time.perf_counter()
print("Generated Text:", generated_text)
print(end-start)

Generated Text: FineWeb Technical Report was released! Here is how the Hugging Face team created FineWeb, the best open-source dataset:

1. Collect Raw Data: Use CommonCrawl as the starting point, 96 CommonCrawl snapshots were used for FineWeb.

2. Url Filtering: Apply URL filtering using a blocklist to remove explicit and malicious content

3. Text Extraction: Use the trafilatura to extract text from raw HTML of the WARC files

4. Base Filtering (Language & Gopher): Use a fastText language classifier to keep only English text and apply quality and repetition filters from MassiveText (Gopher).

4. Deduplicating the Data: Use MinHash deduplication to deduplicate each dump individually. No cross-deduplication across dumps.

5. Additional Quality Filtering: Apply C4 filters (except Punctuation) + 3 additional new filters for Punctuation, Line duplicates, and Short Lines.

6. PII Removal: Replace Emails and IP addresses from the dataset

During the whole process the team ran hundreds of ab

In [13]:
# Benchmark cell: same custom model class, but every k_percent entry is 0.0.
# Relies on `config` created in an earlier cell.
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# k_percent / layers_to_prune semantics are defined in modeling_topK_gpt2 (not
# visible here); presumably 0.0 disables pruning for that layer -- TODO confirm.
model = CustomGPT2LMHeadModel.from_pretrained(
    model_name,
    config=config,
    k_percent=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    layers_to_prune=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
)
model.to('cuda')

generated_text = """FineWeb Technical Report was released! Here is how the Hugging Face team created FineWeb, the best open-source dataset:

1. Collect Raw Data: Use CommonCrawl as the starting point, 96 CommonCrawl snapshots were used for FineWeb.

2. Url Filtering: Apply URL filtering using a blocklist to remove explicit and malicious content

3. Text Extraction: Use the trafilatura to extract text from raw HTML of the WARC files

4. Base Filtering (Language & Gopher): Use a fastText language classifier to keep only English text and apply quality and repetition filters from MassiveText (Gopher).

4. Deduplicating the Data: Use MinHash deduplication to deduplicate each dump individually. No cross-deduplication across dumps.

5. Additional Quality Filtering: Apply C4 filters (except Punctuation) + 3 additional new filters for Punctuation, Line duplicates, and Short Lines.

6. PII Removal: Replace Emails and IP addresses from the dataset

During the whole process the team ran hundreds of ablations to evaluate performance against other open datasets."""
# perf_counter() is the clock intended for measuring intervals: unlike time.time()
# it cannot jump when the system clock is adjusted.
start = time.perf_counter()
for i in range(100):
    # max_length is set to the current token count + 1, so each call extends the
    # text by at most one token; the decoded result becomes the next prompt.
    prompt_len = len(tokenizer(generated_text)["input_ids"])
    generated_text = generate_text(model, tokenizer, generated_text, max_length=prompt_len + 1)
end = time.perf_counter()
print("Generated Text:", generated_text)
print(end-start)

Generated Text: FineWeb Technical Report was released! Here is how the Hugging Face team created FineWeb, the best open-source dataset:

1. Collect Raw Data: Use CommonCrawl as the starting point, 96 CommonCrawl snapshots were used for FineWeb.

2. Url Filtering: Apply URL filtering using a blocklist to remove explicit and malicious content

3. Text Extraction: Use the trafilatura to extract text from raw HTML of the WARC files

4. Base Filtering (Language & Gopher): Use a fastText language classifier to keep only English text and apply quality and repetition filters from MassiveText (Gopher).

4. Deduplicating the Data: Use MinHash deduplication to deduplicate each dump individually. No cross-deduplication across dumps.

5. Additional Quality Filtering: Apply C4 filters (except Punctuation) + 3 additional new filters for Punctuation, Line duplicates, and Short Lines.

6. PII Removal: Replace Emails and IP addresses from the dataset

During the whole process the team ran hundreds of ab

In [14]:
# Benchmark cell: stock Hugging Face GPT2LMHeadModel on the same prompt.
# Relies on `config` created in an earlier cell; rebinds `model`.
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name, config=config)
model.to('cuda')

generated_text = """FineWeb Technical Report was released! Here is how the Hugging Face team created FineWeb, the best open-source dataset:

1. Collect Raw Data: Use CommonCrawl as the starting point, 96 CommonCrawl snapshots were used for FineWeb.

2. Url Filtering: Apply URL filtering using a blocklist to remove explicit and malicious content

3. Text Extraction: Use the trafilatura to extract text from raw HTML of the WARC files

4. Base Filtering (Language & Gopher): Use a fastText language classifier to keep only English text and apply quality and repetition filters from MassiveText (Gopher).

4. Deduplicating the Data: Use MinHash deduplication to deduplicate each dump individually. No cross-deduplication across dumps.

5. Additional Quality Filtering: Apply C4 filters (except Punctuation) + 3 additional new filters for Punctuation, Line duplicates, and Short Lines.

6. PII Removal: Replace Emails and IP addresses from the dataset

During the whole process the team ran hundreds of ablations to evaluate performance against other open datasets."""
# perf_counter() is the clock intended for measuring intervals: unlike time.time()
# it cannot jump when the system clock is adjusted.
start = time.perf_counter()
for i in range(100):
    # max_length is set to the current token count + 1, so each call extends the
    # text by at most one token; the decoded result becomes the next prompt.
    prompt_len = len(tokenizer(generated_text)["input_ids"])
    generated_text = generate_text(model, tokenizer, generated_text, max_length=prompt_len + 1)
end = time.perf_counter()
print("Generated Text:", generated_text)
print(end-start)

Generated Text: FineWeb Technical Report was released! Here is how the Hugging Face team created FineWeb, the best open-source dataset:

1. Collect Raw Data: Use CommonCrawl as the starting point, 96 CommonCrawl snapshots were used for FineWeb.

2. Url Filtering: Apply URL filtering using a blocklist to remove explicit and malicious content

3. Text Extraction: Use the trafilatura to extract text from raw HTML of the WARC files

4. Base Filtering (Language & Gopher): Use a fastText language classifier to keep only English text and apply quality and repetition filters from MassiveText (Gopher).

4. Deduplicating the Data: Use MinHash deduplication to deduplicate each dump individually. No cross-deduplication across dumps.

5. Additional Quality Filtering: Apply C4 filters (except Punctuation) + 3 additional new filters for Punctuation, Line duplicates, and Short Lines.

6. PII Removal: Replace Emails and IP addresses from the dataset

During the whole process the team ran hundreds of ab

In [None]:
# Load the "ag_news" dataset through the Hugging Face `datasets` library.
dataset = load_dataset("ag_news")
# Bare expression: shows the loaded dataset object as the cell output.
dataset

In [9]:
# Show the "text" field of the first training example; later cells use this
# same string as the generation prompt.
dataset["train"][0]["text"]

"Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again."

In [27]:
# Benchmark cell: 50 single-token generation steps starting from the first
# AG News training example, using whatever `model` is currently bound in the
# kernel (presumably the stock GPT2LMHeadModel -- depends on execution order).
generated_text = dataset["train"][0]["text"]
# perf_counter() is the clock intended for measuring intervals: unlike time.time()
# it cannot jump when the system clock is adjusted.
start = time.perf_counter()
for i in range(50):
    # max_length is set to the current token count + 1, so each call extends the
    # text by at most one token; the decoded result becomes the next prompt.
    prompt_len = len(tokenizer(generated_text)["input_ids"])
    generated_text = generate_text(model, tokenizer, generated_text, max_length=prompt_len + 1)
end = time.perf_counter()
print("Generated Text:", generated_text)
print(end-start)

Generated Text: Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again.


The Dow Jones industrial average closed down 0.9 percent at $1,919.50, while the S&P 500 closed down 0.9 percent at $1,919.50.


The Dow Jones industrial
6.666771173477173


In [28]:
# Benchmark cell: 50 single-token generation steps starting from the first
# AG News training example, using `custom_model` from an earlier cell.
generated_text = dataset["train"][0]["text"]
# perf_counter() is the clock intended for measuring intervals: unlike time.time()
# it cannot jump when the system clock is adjusted.
start = time.perf_counter()
for i in range(50):
    # max_length is set to the current token count + 1, so each call extends the
    # text by at most one token; the decoded result becomes the next prompt.
    prompt_len = len(tokenizer(generated_text)["input_ids"])
    generated_text = generate_text(custom_model, tokenizer, generated_text, max_length=prompt_len + 1)
end = time.perf_counter()
print("Generated Text:", generated_text)
print(end-start)

Generated Text: Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again.


The company's stock has fallen by more more than than than 10 percent in the past past year, and its stock has fallen by more than 20 percent in the past year.


The stock has been trading at $1.
5.529221534729004
