# load pretrained model

1. Misra, Rishabh. "News Category Dataset." arXiv preprint arXiv:2209.11429 (2022).
2. Misra, Rishabh and Jigyasa Grover. "Sculpting Data for ML: The first act of Machine Learning." ISBN 9798585463570 (2021).


In [31]:
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [32]:
from transformers import LlamaForCausalLM
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("daily_tokenizer_0612")
model = LlamaForCausalLM.from_pretrained('daily_llama_0612')

model.to(device)
0

0

In [33]:
model_fake = LlamaForCausalLM.from_pretrained('fake_detect_llama')

model_fake.to(device)
model_fake.eval()
0

0

In [34]:
prompt = """Return True if the given article is fake. article: Boeing CEO says he assured Trump about Air Force One costs answer:"""

inputs = tokenizer(prompt, return_tensors="pt")
inputs.to(device)

# Generate
generate_ids = model_fake.generate(inputs.input_ids, max_length=50)
output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

print(output)

Return True if the given article is fake. article: Boeing CEO says he assured Trump about Air Force One costs answer: True if he is not planning to be 'flagrant aggression' answer: True if he is not doing to answer:


In [35]:
model_fake.eval()

prompt = """What is the topic of the collowing article? article: Boeing CEO says he assured Trump about Air Force One costs answer:"""
inputs = tokenizer(prompt, return_tensors="pt")
inputs.to(device)

# Generate
generate_ids = model_fake.generate(inputs.input_ids, max_length=30)
tokenizer.batch_decode(generate_ids, skip_special_tokens=True, 
                    clean_up_tokenization_spaces=False)[0]


'What is the topic of the collowing article? article: Boeing CEO says he assured Trump about Air Force One costs answer: True if he is not'

In [36]:
model.eval()

prompt = """\
What is the topic of the collowing article? article: Boeing CEO says he assured Trump about Air Force One costs answer:"""
inputs = tokenizer(prompt, return_tensors="pt")
inputs.to(device)

# Generate
generate_ids = model.generate(inputs.input_ids, max_length=100)
tokenizer.batch_decode(generate_ids, skip_special_tokens=True, 
                    clean_up_tokenization_spaces=False)[0]


'What is the topic of the collowing article? article: Boeing CEO says he assured Trump about Air Force One costs answer: "I think it\'s a very good thing to do." "I think it\'s a very good thing to do," he said. "I think it\'s a very good thing to do. I think it\'s a good thing to do. I think it\'s a good thing to do." The first thing that\'s going to be a good thing, but it\'s not'

## load dataset

In [37]:
from datasets import load_dataset
data = 'heegyu/news-category-balanced-top10'
dataset = load_dataset(data)

dataset

Found cached dataset json (/home/ubuntu/.cache/huggingface/datasets/heegyu___json/heegyu--news-category-balanced-top10-5f881f7cd497c7a8/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)


  0%|          | 0/1 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['link', 'headline', 'category', 'short_description', 'authors', 'date'],
        num_rows: 83878
    })
})

In [38]:
dataset['train'][0]

{'link': 'https://www.huffpost.com/entry/rei-workers-berkeley-store-union_n_6307a5f4e4b0f72c09ded80d',
 'headline': 'REI Workers At Berkeley Store Vote To Unionize In Another Win For Labor',
 'category': 'BUSINESS',
 'short_description': 'They follow in the footsteps of REI workers in New York City who formed a union earlier this year.',
 'authors': 'Dave Jamieson',
 'date': 1661385600000}

In [39]:
dataset['train'].features

{'link': Value(dtype='string', id=None),
 'headline': Value(dtype='string', id=None),
 'category': Value(dtype='string', id=None),
 'short_description': Value(dtype='string', id=None),
 'authors': Value(dtype='string', id=None),
 'date': Value(dtype='int64', id=None)}

In [40]:
df = dataset['train'].to_pandas()
df

Unnamed: 0,link,headline,category,short_description,authors,date
0,https://www.huffpost.com/entry/rei-workers-ber...,REI Workers At Berkeley Store Vote To Unionize...,BUSINESS,They follow in the footsteps of REI workers in...,Dave Jamieson,1661385600000
1,https://www.huffpost.com/entry/twitter-elon-mu...,Twitter Lawyer Calls Elon Musk 'Committed Enem...,BUSINESS,Delaware Chancery Judge Kathaleen McCormick de...,Marita Vlachou,1658275200000
2,https://www.huffpost.com/entry/starbucks-leave...,"Starbucks Leaving Russian Market, Shutting 130...",BUSINESS,Starbucks' move follows McDonald's exit from t...,"DEE-ANN DURBIN, AP",1653264000000
3,https://www.huffpost.com/entry/coinbase-crypto...,Crypto Crash Leaves Trading Platform Coinbase ...,BUSINESS,Cryptocurrency trading platform Coinbase has l...,"Matt Ott, AP",1652313600000
4,https://www.huffpost.com/entry/us-april-jobs-r...,"US Added 428,000 Jobs In April Despite Surging...",BUSINESS,"At 3.6%, unemployment nearly reached the lowes...","Paul Wiseman, AP",1651795200000
...,...,...,...,...,...,...
83873,https://www.huffingtonpost.com/entry/gratitude...,"Flex Your Gratitude Muscle, and Lift Stress Away",WELLNESS,"For most of us, giving comes a lot easier than...","meQuilibrium, Contributor\nPersonalized Stress...",1369353600000
83874,https://www.huffingtonpost.com/entry/diabetes-...,Don't Wait to Prevent Diabetes: Start Today Wi...,WELLNESS,"Small, reasonable changes can add up to a lot ...","Susan B. Dopart, MS, RD, CDE, Contributor\nHea...",1355443200000
83875,https://www.huffingtonpost.com/entry/dream-lif...,The Real Reason You're Not Living Your Dream L...,WELLNESS,Excuses are artificial creations that mask the...,"Alexis Sclamberg, Contributor\nCEO & Founder, ...",1346025600000
83876,https://www.huffingtonpost.com/entry/sugar-obe...,"Is Sugar Making the World Fat, Diabetic, and H...",WELLNESS,The new study in Public Health Nutrition remin...,"Ayala Laufer-Cahana, M.D., Contributor\nPhysic...",1362096000000


In [41]:
categories = df.category.unique().tolist()
categories.sort()
categories

['BUSINESS',
 'ENTERTAINMENT',
 'FOOD & DRINK',
 'HEALTHY LIVING',
 'PARENTING',
 'POLITICS',
 'QUEER VOICES',
 'STYLE & BEAUTY',
 'TRAVEL',
 'WELLNESS']

In [42]:
categories = categories[:4]

In [43]:
dataset = dataset.filter(lambda element: element['category'] in categories)
dataset

Loading cached processed dataset at /home/ubuntu/.cache/huggingface/datasets/heegyu___json/heegyu--news-category-balanced-top10-5f881f7cd497c7a8/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4/cache-31c8659ca0784ee1.arrow


DatasetDict({
    train: Dataset({
        features: ['link', 'headline', 'category', 'short_description', 'authors', 'date'],
        num_rows: 29026
    })
})

In [44]:
categories = [x.split(' ')[0].lower() for x in categories[:5]]
categories

['business', 'entertainment', 'food', 'healthy']

In [45]:
int2label = {i: categories[i] for i in range(len(categories))}
label2int = {int2label[key]:key for key in int2label}

In [46]:
def gen_label(element):
    category = element['category'].split(' ')[0].lower()
    return {'label': label2int[category], 'category': category}

dataset = dataset.map(gen_label)
dataset

Loading cached processed dataset at /home/ubuntu/.cache/huggingface/datasets/heegyu___json/heegyu--news-category-balanced-top10-5f881f7cd497c7a8/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4/cache-0dc1a2c1b8a47944.arrow


DatasetDict({
    train: Dataset({
        features: ['link', 'headline', 'category', 'short_description', 'authors', 'date', 'label'],
        num_rows: 29026
    })
})

In [47]:
from datasets import DatasetDict
import random

prompt_format1 = """Given the article, what is the topic of the article? article: %s  answer: %s"""
prompt_format2 = """Determine the topic of the news article. article: %s answer: %s"""
prompt_format3 = """What is this article about? business/entertainment/food/healthy/parenting article: %s answer: %s"""

prompts = [prompt_format1, prompt_format2, prompt_format3]

def gen_prompt(element):
    prompt_format = prompts[random.randint(0, len(prompts)-1)]
    return DatasetDict({'input': prompt_format%(element['headline'], element['category'])})


dataset = dataset.map(gen_prompt)

Map:   0%|          | 0/29026 [00:00<?, ? examples/s]

In [48]:
dataset['train'].to_pandas()

Unnamed: 0,link,headline,category,short_description,authors,date,label,input
0,https://www.huffpost.com/entry/rei-workers-ber...,REI Workers At Berkeley Store Vote To Unionize...,business,They follow in the footsteps of REI workers in...,Dave Jamieson,1661385600000,0,"Given the article, what is the topic of the ar..."
1,https://www.huffpost.com/entry/twitter-elon-mu...,Twitter Lawyer Calls Elon Musk 'Committed Enem...,business,Delaware Chancery Judge Kathaleen McCormick de...,Marita Vlachou,1658275200000,0,What is this article about? business/entertain...
2,https://www.huffpost.com/entry/starbucks-leave...,"Starbucks Leaving Russian Market, Shutting 130...",business,Starbucks' move follows McDonald's exit from t...,"DEE-ANN DURBIN, AP",1653264000000,0,What is this article about? business/entertain...
3,https://www.huffpost.com/entry/coinbase-crypto...,Crypto Crash Leaves Trading Platform Coinbase ...,business,Cryptocurrency trading platform Coinbase has l...,"Matt Ott, AP",1652313600000,0,Determine the topic of the news article. artic...
4,https://www.huffpost.com/entry/us-april-jobs-r...,"US Added 428,000 Jobs In April Despite Surging...",business,"At 3.6%, unemployment nearly reached the lowes...","Paul Wiseman, AP",1651795200000,0,"Given the article, what is the topic of the ar..."
...,...,...,...,...,...,...,...,...
29021,https://www.huffingtonpost.com/entry/happy-hea...,Why You Need Both a 'Bouncer' and a 'Bartender...,healthy,Instead of judging whether you made the right ...,"Elizabeth Grace Saunders, ContributorFounder, ...",1397779200000,3,"Given the article, what is the topic of the ar..."
29022,https://www.huffingtonpost.com/entry/mental-il...,How Video Games Can Improve Dialogue on Mental...,healthy,While there are strong arguments for the games...,"Mona Shattell, Contributornurse researcher",1397779200000,3,Determine the topic of the news article. artic...
29023,https://www.huffingtonpost.com/entry/wake-up-c...,Wake-Up Calls Inspired My Change From Overdriv...,healthy,My wake-up call marching orders were clear: No...,"Jane Shure, ContributorLeadership Coach, Psych...",1397779200000,3,"Given the article, what is the topic of the ar..."
29024,https://www.huffingtonpost.com/entry/narcissis...,Loving a Narcissist Without Losing Yourself,healthy,It is very difficult for some people to see an...,"Nancy Colier, ContributorPsychotherapist, inte...",1397779200000,3,What is this article about? business/entertain...


In [49]:
dataset = dataset['train'].train_test_split(test_size=0.1)

In [50]:
def tokenize(element):
    tokenizer.pad_token = tokenizer.eos_token
    outputs = tokenizer(
        element['input'],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=False,
        return_length=True,
        padding=True
    )
    input_batch = []
    for inputs, input_ids in zip(element["input"], outputs["input_ids"]):
        input_batch.append(input_ids)
    return {"input_ids": input_batch}


context_length=128
tokenized_datasets = dataset.map(
    tokenize, batched=True, remove_columns=dataset['train'].column_names
)
tokenized_datasets

Map:   0%|          | 0/26123 [00:00<?, ? examples/s]

Map:   0%|          | 0/2903 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids'],
        num_rows: 26123
    })
    test: Dataset({
        features: ['input_ids'],
        num_rows: 2903
    })
})

## train

In [24]:
from transformers import DataCollatorForLanguageModeling

tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

In [25]:
out = data_collator([tokenized_datasets['train'][i] for i in range(5)])
for key in out:
    print(f"{key} shape: {out[key].shape}")

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


input_ids shape: torch.Size([5, 56])
attention_mask shape: torch.Size([5, 56])
labels shape: torch.Size([5, 56])


In [26]:
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="topic_llama",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    evaluation_strategy="steps",
    eval_steps=500,
    logging_steps=500,
    gradient_accumulation_steps=8,
    num_train_epochs=1,
    weight_decay=0.1,
    warmup_steps=500,
    lr_scheduler_type="cosine",
    learning_rate=5e-4,
    save_steps=500,
    fp16=True,
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
)

In [27]:
trainer.train()



Step,Training Loss,Validation Loss


TrainOutput(global_step=816, training_loss=2.780589683383119, metrics={'train_runtime': 141.6183, 'train_samples_per_second': 184.461, 'train_steps_per_second': 5.762, 'total_flos': 358888641269760.0, 'train_loss': 2.780589683383119, 'epoch': 1.0})

## evaluate

In [26]:
prompt = """Determine the topic of the news article. article: Bikini'd Kate Hudson Hits The Beach With Chris Martin answer:"""

inputs = tokenizer(prompt, return_tensors="pt")
inputs.to("cuda:0")

# Generate
generate_ids = model.generate(inputs.input_ids, max_length=30)
output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

print(output)

Determine the topic of the news article. article: Bikini'd Kate Hudson Hits The Beach With Chris Martin answer: entertainment answer: entertainment


In [53]:
tokenizer = AutoTokenizer.from_pretrained("daily_tokenizer_0612", padding_side='left')
prompt_format1 = """Given the article, what is the topic of the article? article: %s  answer:"""
prompt_format2 = """Determine the topic of the news article. article: %s answer:"""
prompt_format3 = """What is this article about? business/entertainment/food/healthy/parenting article: %s answer:"""

prompts = [prompt_format1, prompt_format2, prompt_format3]

def gen_valid_prompt(element):
    prompt_format = prompts[random.randint(0, len(prompts)-1)]
    return DatasetDict({'input': prompt_format%(element['headline'])})




valid_dataset = dataset['test'].select(range(100)).map(gen_valid_prompt)
valid_dataset[0]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

{'link': 'https://www.huffpost.com/entry/sarah-michelle-gellar-buffy-the-vampire-slayer-throwback_n_5e726321c5b6f5b7c53cff96',
 'headline': "Sarah Michelle Gellar Goes Full-On 'Buffy' In Coronavirus Battle",
 'category': 'entertainment',
 'short_description': 'The actor-turned-lifestyle guru found the perfect moment to reference her "Vampire Slayer" past.',
 'authors': 'Curtis M. Wong',
 'date': 1584489600000,
 'label': 1,
 'input': "Given the article, what is the topic of the article? article: Sarah Michelle Gellar Goes Full-On 'Buffy' In Coronavirus Battle  answer:"}

In [54]:
valid_dataset.column_names

['link',
 'headline',
 'category',
 'short_description',
 'authors',
 'date',
 'label',
 'input']

In [55]:
valid_dataset = valid_dataset.map(
    tokenize, batched=True, remove_columns=['link', 'headline', 'category', 'short_description', 'authors', 'date', 'input']
)
valid_dataset

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Dataset({
    features: ['label', 'input_ids'],
    num_rows: 100
})

In [56]:
from torch.utils.data import DataLoader

batch_size=4
val_ds = valid_dataset
val_ds.set_format(type='torch')
val_dl = DataLoader(val_ds, batch_size=batch_size)

In [57]:
import re
import torch
from tqdm import tqdm

def acc(pred,label):
    return torch.sum(torch.tensor(pred) == label.squeeze()).item()


In [58]:
model_orig = LlamaForCausalLM.from_pretrained('daily_llama_0612')
model_orig.to(device)
model_orig.eval()

val_losses = []
val_acc = 0

for step, batch in enumerate(tqdm(val_dl)):
    label = batch['label']
    input_id= batch['input_ids'].to(device)

    pred = model_orig.generate(input_id, max_length=150)
    decoded_pred = tokenizer.batch_decode(pred, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    decoded_pred = [re.findall("answer: ([a-z]+)", x)[0] if re.findall("answer: ([a-z]+)", x) else 'none' for x in decoded_pred]
    decoded_pred = [label2int[x] if x in label2int else -1 for x in decoded_pred]

    val_acc += acc(decoded_pred, label)
    

print("val acc: ", val_acc/len(val_dl.dataset))

100%|███████████████████████████████████████████| 25/25 [00:13<00:00,  1.82it/s]

val acc:  0.0





In [59]:
tokenizer.batch_decode(pred, skip_special_tokens=True, clean_up_tokenization_spaces=False)

['Given the article, what is the topic of the article? article: Timothée Chalamet Takes Shot At Warner Bros. With Sweatshirt Statement On \'SNL\'  answer: "The Dark Knight Rises" is a big deal. The film is a big-budget film, which is a big-budget film, and the film is a big-budget film. The film is a "The Dark Knight" and "The Dark Knight" is a "The Dark Knight" and "The Dark Knight" in the film. The film is a "The Dark Knight" and "The Dark Knight" in the film. The film is a "The Dark Knight" and "The Dark Knight" and "',
 'Determine the topic of the news article. article: Britney Spears\' Mother Makes Legal Claim, Heightening Family Drama answer: . "I\'m not a good friend," she said. "I\'m not a fan of the show. I\'m not a fan of the show. I\'m a little bit more than I\'m. I\'m not going to be a good guy. I\'m not going to be a good guy. I\'m not a fan of the show. I\'m a little bit more than I\'m. I\'m not going to be a good guy. I\'m not a fan of the show. I\'m a little bit more th

In [42]:
model.eval()
val_losses = []
val_acc = 0

for step, batch in enumerate(tqdm(val_dl)):
    label = batch['label']
    input_id= batch['input_ids'].to(device)

    pred = model.generate(input_id, max_length=65)
    decoded_pred = tokenizer.batch_decode(pred, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    decoded_pred = [re.findall("answer: ([a-z]+)", x)[0] if re.findall("answer: ([a-z]+)", x) else 'none' for x in decoded_pred]
    decoded_pred = [label2int[x] if x in label2int else -1 for x in decoded_pred]

    val_acc += acc(decoded_pred, label)
    

print("val acc: ", val_acc/len(val_dl.dataset))

100%|███████████████████████████████████████████| 25/25 [00:00<00:00, 35.11it/s]

val acc:  0.84





In [48]:
model.save_pretrained('topic_llama_0618')