In [None]:
import torch
import math
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)
from peft import LoraConfig, get_peft_model

# Base checkpoint to fine-tune; both the tokenizer and the model weights are
# loaded from this identifier further down.
BASE_MODEL = "gpt2"

# Seed torch's RNG up front so that anything drawn from it later (random
# initialisation, sampled generation) is repeatable on "Restart & Run All".
SEED = 42
torch.manual_seed(SEED)

In [None]:
# Dataset identifier handed to `load_dataset`.
# Alternative corpus (disabled): "StephanAkkerman/crypto-stock-tweets"
ds_name = "flowfree/crypto-news-headlines"
dataset = load_dataset(ds_name)

# Pull out the two splits used below: one for training, one for evaluation.
train, val = dataset["train"], dataset["validation"]


In [None]:
# Load the tokenizer that belongs to the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Batched padding needs a pad token. Fall back to EOS only when the checkpoint
# does not define one, so pointing BASE_MODEL at a checkpoint that already has
# a pad token never overwrites it.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
def tokenize_function(examples):
    """Tokenize the "text" field of a batch of examples.

    Args:
        examples: Mapping with a "text" entry holding the raw strings to encode
            (the batch handed over by `Dataset.map(..., batched=True)`).

    Returns:
        Whatever the module-level `tokenizer` returns for those texts.
        Sequences are truncated to at most 512 tokens and left unpadded.
    """
    encode_options = {
        "truncation": True,
        "max_length": 512,  # cap sequence length to keep memory usage down
        "padding": False,   # padding is deferred to the data collator
    }
    texts = examples["text"]
    return tokenizer(texts, **encode_options)

# Tokenize both splits in batches. Every original column is dropped (not only
# "text") so that nothing but the tokenizer's outputs is handed to the data
# collator; any extra raw column the corpus may ship with (e.g. a label column
# -- not verified here) would otherwise be carried along into the batches.
train_dataset = train.map(
    tokenize_function,
    batched=True,
    remove_columns=train.column_names,
)
eval_dataset = val.map(
    tokenize_function,
    batched=True,
    remove_columns=val.column_names,
)

In [None]:
# Disabled alternative: split the tokenized 'train' split by index into train/eval subsets.
# Assumes tokenized_dataset is a DatasetDict with a 'train' split.
# train_size = int(0.8 * 700)  # Use the 'train' split

# train_dataset = tokenized_dataset["train"].select(range(train_size))  # Apply select on the train split
# eval_dataset = tokenized_dataset["train"].select(range(train_size, len(tokenized_dataset["train"])))  # Remaining for eval

# print("Train size:", len(train_dataset))
# print("Eval size:", len(eval_dataset))


In [None]:
# Attention projections of each GPT-2 block. The "attn." prefix keeps the
# match on the attention output projection only, not other layers that are
# also named c_proj.
target_mods = ["attn.c_attn", "attn.c_proj"]

# LoRA hyper-parameters: rank-8 adapters scaled by alpha=32, light dropout,
# and no bias terms trained.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=32,
    target_modules=target_mods,
    lora_dropout=0.01,
    bias="none",
    # GPT-2 implements these projections as Conv1D, which stores weights as
    # (fan_in, fan_out); PEFT's docs ask for this flag to be set in that case.
    fan_in_fan_out=True,
)

In [None]:
# Load the base model and attach LoRA adapters
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
model = get_peft_model(base_model, lora_config)

# Sanity check: report how many parameters are trainable after wrapping.
model.print_trainable_parameters()

# Disable caching since gradient checkpointing requires use_cache to be False
model.config.use_cache = False

# Ensure inputs require gradients for checkpointing to work properly
model.enable_input_require_grads()

# Make sure a pad token exists. Guarded so that it is a no-op when one is
# already set, instead of a second unconditional overwrite.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Batch collator for causal language modelling: mlm=False switches off the
# masked-language-modelling behaviour, which causal models do not use.
data_collator = DataCollatorForLanguageModeling(mlm=False, tokenizer=tokenizer)

In [None]:
# Run configuration for the Trainer.
# NOTE(review): output_dir says "distilgpt2" while BASE_MODEL is "gpt2" and the
# final model is saved under "gpt2-crypto-finetuned"; the path is left
# unchanged here so existing checkpoint locations keep working.
training_args = TrainingArguments(
    output_dir="distilgpt2-crypto-finetuned",
    overwrite_output_dir=True,
    num_train_epochs=100,
    per_device_train_batch_size=8,  # batch size per device
    gradient_accumulation_steps=2,  # effective train batch = 8 * 2 per device
    per_device_eval_batch_size=8,
    gradient_checkpointing=True,    # saves memory on activations
    # Mixed precision needs a CUDA device; fall back to full precision on
    # CPU-only machines instead of failing when the arguments are built.
    fp16=torch.cuda.is_available(),
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy` -- confirm against the installed version.
    evaluation_strategy="epoch",
    logging_steps=5,
    save_steps=25,
    save_total_limit=1,             # keep only the most recent checkpoint
    lr_scheduler_type="cosine",
    warmup_steps=10,
    learning_rate=5e-4,
    dataloader_num_workers=0,       # parallel data loading disabled
    optim="adamw_torch",            # optimizer choice
    save_safetensors=False,
    report_to="none"
)

In [None]:
# Wire the LoRA-wrapped model, the run configuration, both tokenized splits
# and the batch collator into a single Trainer.
trainer_kwargs = {
    "model": model,
    "args": training_args,
    "train_dataset": train_dataset,
    "eval_dataset": eval_dataset,
    "data_collator": data_collator,
}
trainer = Trainer(**trainer_kwargs)

In [None]:
# Run fine-tuning with the configuration assembled above.
trainer.train()

# Write the final model to its own directory and confirm completion.
final_model_dir = "gpt2-crypto-finetuned"
trainer.save_model(final_model_dir)
print("Fine-tuning complete. Model saved.")

In [None]:
# Evaluate on the evaluation split, then report the loss together with the
# perplexity derived from it (perplexity = exp(eval loss)).
eval_results = trainer.evaluate()
eval_loss = eval_results["eval_loss"]
perplexity = math.exp(eval_loss)

for report_line in (
    f"Evaluation Loss: {eval_loss:.4f}",
    f"Perplexity: {perplexity:.4f}",
):
    print(report_line)

In [None]:
# Switch to inference mode before sampling a continuation for a test prompt.
model.eval()
prompt = "tell me the top 5 cryptocurrencies"

# Put the prompt on whatever device the model lives on rather than assuming
# CUDA, so this cell also runs on CPU-only machines.
device = next(model.parameters()).device
inputs = tokenizer(prompt, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=50,
        do_sample=True,   # temperature / top_p only take effect when sampling
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,  # explicit pad id for generation
    )

print("Generated text:")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))