<div class="alert alert-block alert-info">
<b>⚠️ ToDo:</b> <br>

<div>
    <input type="checkbox" id="todo-dataset" name="todo-dataset" />
    <label for="todo-dataset">Restructure and upload full dataset</label>
</div>

<div>
    <input type="checkbox" id="todo-wandb" name="todo-wandb" />
    <label for="todo-wandb">Add wandb parameter logging and visualization</label>
</div>

<div>
    <input type="checkbox" id="todo-train" name="todo-train" />
    <label for="todo-train">Train model on full dataset</label>
</div>

<div>
    <input type="checkbox" id="todo-docs" name="todo-docs" />
    <label for="todo-docs">Add markdown cells with comments and explanation</label>
</div>
</div>

In [None]:
!pip install -q transformers datasets accelerate bitsandbytes peft trl wandb

In [None]:
import json
import os
import re
from collections import defaultdict

import torch
import wandb
from datasets import Dataset
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient
from peft import (
    prepare_model_for_kbit_training,
    LoraConfig,
    get_peft_model
)
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer
)
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

# Authentication

In [None]:
# Fetch API tokens from Kaggle's secret store so no credential is hardcoded
# in the notebook. The labels "HF" and "wandb" must exist as Kaggle secrets.
user_secrets = UserSecretsClient()
hf_token = user_secrets.get_secret("HF")
wandb_token = user_secrets.get_secret("wandb")

# Authenticate with the Hugging Face Hub and with Weights & Biases.
login(hf_token)
wandb.login(key=wandb_token)

# Config & Hyperparameters

In [None]:
# Seed handed to torch.manual_seed and TrainingArguments by train_model.
SEED = 42
# Hub identifier used for both the tokenizer and the base model.
MODEL_NAME = "google/gemma-2b-it"

# JSON file read by load_json_data; it must contain a top-level list.
DATASET_PATH = "/kaggle/input/red-russian/data/red_russian_dataset.json"
# Author names kept by filter_and_limit_data; also passed to chunk_data as `use_authors`.
AUTHORS = ["Stalin"]
# Maximum number of items kept per author by filter_and_limit_data.
NUM_BOOKS = 1

# NOTE: used twice with different units — as the chunk size in *characters*
# (chunk_data) and as the tokenizer's max_length in *tokens* (prepare_training_data).
MAX_LENGTH = 256
# Per-device train batch size.
BATCH_SIZE = 6
GRADIENT_ACCUMULATION_STEPS = 2
LEARNING_RATE = 5e-4
EPOCHS = 1
# Passed to TrainingArguments as logging_steps.
LOGGING_STEPS = 10

# Passed to TrainingArguments as save_strategy.
SAVE_STRATEGY = "epoch"
# Directory used for TrainingArguments.output_dir and trainer.save_model.
OUTPUT_DIR = "./red_russian_gemma"
# Project and run name given to wandb.init.
WANDB_PROJECT = "red_russian_gemma" 
WANDB_RUN_NAME = "red_russian_gemma-run-1"

# Data Loading

In [None]:
def load_json_data(dataset_path):
    """Read a UTF-8 encoded JSON file and return its top-level list.

    Args:
        dataset_path: Path of the JSON file to read.

    Returns:
        The parsed JSON content (a list).

    Raises:
        ValueError: If the top-level JSON value is not a list.
    """
    with open(dataset_path, mode="r", encoding="utf-8") as handle:
        records = json.loads(handle.read())

    if isinstance(records, list):
        return records

    raise ValueError("JSON data must be a list of dictionaries")

In [None]:
def filter_and_limit_data(json_data, authors=None, num_books=None):
    """Keep items by author and cap how many items each author contributes.

    Args:
        json_data: Iterable of dicts; every dict must have an "author_name" key.
        authors: Collection of author names to keep, or None to keep all authors.
        num_books: Maximum number of items kept per author, or None for no limit.

    Returns:
        A list with the selected items, in their original order.

    Raises:
        KeyError: If an item has no "author_name" key.
    """
    filtered_list = []
    author_counts = defaultdict(int)

    for item in json_data:
        author = item["author_name"]

        if authors is not None and author not in authors:
            continue

        if num_books is not None:
            if author_counts[author] >= num_books:
                continue
            author_counts[author] += 1

        filtered_list.append(item)

    # Report once, after the loop: the summary used to be printed for every
    # input item and was never printed for empty input.
    n = len(filtered_list)
    print(f"Data successfully filtered by author and quantity, loaded {n} books")

    return filtered_list

# Chunking

In [None]:
def clean_text(text):
    """Normalise raw book text by applying three regex substitutions in order.

    1. Lines consisting only of digits are emptied.
    2. Every run of whitespace (newlines included) becomes a single space.
    3. Lines that are entirely whitespace are emptied. After step 2 no
       newlines remain, so this only fires when the whole text is blank.

    Args:
        text: The string to clean.

    Returns:
        The cleaned string.
    """
    substitutions = (
        (re.compile(r"^\d+$", re.MULTILINE), ""),
        (re.compile(r"\s+"), " "),
        (re.compile(r"^\s*$", re.MULTILINE), ""),
    )

    for pattern, replacement in substitutions:
        text = pattern.sub(replacement, text)

    return text

In [None]:
def chunk_data(json_data, chunk_size=256, use_authors=None):
    """Split each item's cleaned text into fixed-size character chunks.

    Every chunk is prefixed with an ``<|author:NAME|>`` tag followed by a space.
    NAME is the item's "author_name" when `use_authors` is truthy and the key
    is present; otherwise it is "Unknown".

    Args:
        json_data: Iterable of dicts. Items without 'book_content' are skipped
            with a printed warning.
        chunk_size: Number of characters per chunk (the last chunk may be shorter).
        use_authors: Any truthy value enables per-item author tags.

    Returns:
        A `Dataset` with a single "text" column, or None when no chunk was produced.
    """
    pieces = []

    for record in json_data:
        if 'book_content' not in record:
            print(f"Warning: Skipping item due to missing 'book_content': {record}")
            continue

        tag_name = record['author_name'] if (use_authors and 'author_name' in record) else "Unknown"
        prefix = f"<|author:{tag_name}|>"
        body = clean_text(record['book_content'])

        pieces.extend(
            f"{prefix} {body[start:start + chunk_size]}"
            for start in range(0, len(body), chunk_size)
        )

    if len(pieces) == 0:
        print("Error: No valid text data found in the JSON file.")
        return None

    print("Data successfully chunked")

    result = Dataset.from_dict({"text": pieces})
    print(f"Dataset loaded: {len(result)} samples")

    return result

# Tokenization

In [None]:
def prepare_tokenizer(model_name):
    """Load the tokenizer for `model_name`, configured for right-side padding.

    If the loaded tokenizer has no pad token, its EOS token is used instead.

    Args:
        model_name: Identifier or path passed to `AutoTokenizer.from_pretrained`.

    Returns:
        The loaded tokenizer.
    """
    load_options = {"trust_remote_code": True, "padding_side": 'right'}
    loaded = AutoTokenizer.from_pretrained(model_name, **load_options)

    pad_is_missing = loaded.pad_token is None
    if pad_is_missing:
        loaded.pad_token = loaded.eos_token

    return loaded

def tokenize_function(examples, tokenizer, max_length):
    """Tokenize a batch of texts and build pre-shifted next-token labels.

    Labels are aligned so that ``labels[:, t] == input_ids[:, t + 1]``; the
    loss in this notebook's `CustomTrainer` compares logits and labels
    position by position without shifting them itself. Positions with no
    real next token are set to -100 so the loss ignores them: the final
    position, and every position whose next token is padding.

    Args:
        examples: Batch dict with a 'text' entry holding a list of strings.
        tokenizer: Callable tokenizer returning "input_ids" and "attention_mask"
            tensors when called with ``return_tensors="pt"``.
        max_length: Length every sequence is truncated / padded to.

    Returns:
        Dict with "input_ids", "attention_mask" and "labels" tensors, all of
        shape (batch, max_length).
    """
    tokenized_inputs = tokenizer(
        examples['text'],
        truncation=True,
        max_length=max_length,
        padding="max_length",
        return_tensors="pt"
    )

    input_ids = tokenized_inputs["input_ids"]
    attention_mask = tokenized_inputs["attention_mask"]

    # Start fully ignored; this also covers the last position, which has no next token.
    labels = torch.full_like(input_ids, -100)
    # Shift left by one so each position holds the id of the *next* token,
    # and mask targets that are padding (attention_mask == 0) so the model is
    # not trained to emit pad tokens.
    labels[:, :-1] = input_ids[:, 1:].masked_fill(attention_mask[:, 1:] == 0, -100)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

def prepare_training_data(dataset, tokenizer, max_length):
    """Tokenize every row of `dataset` in batches and drop the original columns.

    Args:
        dataset: Object exposing `map(...)` and `column_names`.
        tokenizer: Tokenizer forwarded to `tokenize_function`.
        max_length: Sequence length forwarded to `tokenize_function`.

    Returns:
        The mapped dataset returned by `dataset.map`.
    """
    def _tokenize_batch(batch):
        # Bind the tokenizer and length so `map` only has to supply the batch.
        return tokenize_function(batch, tokenizer, max_length)

    columns_to_drop = dataset.column_names
    mapped = dataset.map(_tokenize_batch, batched=True, remove_columns=columns_to_drop)

    print("Data successfully tokenized")

    return mapped

# Model

In [None]:
def configure_model(model_name):
    """Load `model_name` in 4-bit precision and wrap it with LoRA adapters.

    Quantization is requested through a `BitsAndBytesConfig` passed as
    `quantization_config`; passing `load_in_4bit=True` straight to
    `from_pretrained` is deprecated. The config uses only
    `load_in_4bit=True`, so the quantization defaults are unchanged.

    Args:
        model_name: Identifier or path passed to
            `AutoModelForCausalLM.from_pretrained`.

    Returns:
        The PEFT-wrapped model returned by `get_peft_model`.
    """
    quantization_config = BitsAndBytesConfig(load_in_4bit=True)

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        quantization_config=quantization_config
    )
    model = prepare_model_for_kbit_training(model)

    # LoRA adapters on the attention projections and two of the MLP projections.
    lora_config = LoraConfig(
        r=8,
        lora_alpha=32,
        target_modules=[
            "q_proj", "o_proj",
            "k_proj", "v_proj",
            "gate_proj", "up_proj"
        ],
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM"
    )
    model = get_peft_model(model, lora_config)
    return model

# Training

In [None]:
class CustomTrainer(Trainer):
    """Trainer whose loss compares logits and labels position by position.

    No shifting happens here, so the labels in each batch must already be
    aligned with the logits they are scored against.
    """

    def compute_loss(self, model, inputs, num_items_in_batch=2, return_outputs=False):
        """Compute mean cross-entropy between the model's logits and the labels.

        Args:
            model: Model called as ``model(**inputs)``; its output must expose `.logits`.
            inputs: Batch dict; its "labels" entry is removed before the forward pass.
            num_items_in_batch: Accepted but not used by this implementation.
            return_outputs: When true, return ``(loss, outputs)`` instead of the loss.

        Returns:
            The loss tensor, or a ``(loss, outputs)`` tuple.
        """
        targets = inputs.pop("labels")
        outputs = model(**inputs)

        vocab_size = outputs.logits.shape[-1]
        # Flatten to (N, vocab) vs (N,); defaults match CrossEntropyLoss():
        # mean reduction with ignore_index=-100.
        loss = torch.nn.functional.cross_entropy(
            outputs.logits.view(-1, vocab_size),
            targets.view(-1),
        )

        if return_outputs:
            return loss, outputs
        return loss

In [None]:
def train_model(model, tokenizer, tokenized_dataset, output_dir, batch_size, gradient_accumulation_steps, learning_rate, epochs, logging_steps, save_strategy, wandb_project, wandb_run_name, seed):
    """Train `model` with `CustomTrainer`, log to W&B, and save the result.

    Args:
        model: Model to train.
        tokenizer: Tokenizer handed to the trainer.
        tokenized_dataset: Training dataset.
        output_dir: Directory for checkpoints and the final saved model.
        batch_size: Per-device train batch size.
        gradient_accumulation_steps: Gradient accumulation steps.
        learning_rate: Learning rate.
        epochs: Number of training epochs.
        logging_steps: Logging interval in steps.
        save_strategy: Checkpoint save strategy (e.g. "epoch").
        wandb_project: W&B project name.
        wandb_run_name: W&B run name.
        seed: Seed for torch and for `TrainingArguments`.

    Returns:
        None.

    Raises:
        Any exception raised while building the trainer, training or saving is
        re-raised after the W&B run has been closed with a non-zero exit code.
    """
    torch.manual_seed(seed)

    # CUDA alone does not imply bfloat16 support; requesting bf16 on a GPU
    # without it makes TrainingArguments fail, so only enable it when supported.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

    wandb.init(project=wandb_project, name=wandb_run_name)

    try:
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=epochs,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            learning_rate=learning_rate,
            logging_dir="./logs",
            logging_steps=logging_steps,
            save_strategy=save_strategy,
            fp16=False,
            bf16=use_bf16,
            gradient_checkpointing=False,
            report_to="wandb",  # Log to W&B
            seed=seed
        )

        trainer = CustomTrainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset,
            tokenizer=tokenizer
        )

        trainer.train()

        trainer.save_model(output_dir)
    except BaseException:
        # Always close the run, and mark it as failed, so a crash or an
        # interrupt does not leave a dangling W&B run in the kernel.
        wandb.finish(exit_code=1)
        raise

    wandb.finish()

In [None]:
# End-to-end run: load -> filter -> chunk -> tokenize -> build model -> train.
json_data = load_json_data(DATASET_PATH)
json_data = filter_and_limit_data(json_data, authors=AUTHORS, num_books=NUM_BOOKS)
dataset = chunk_data(json_data, chunk_size=MAX_LENGTH, use_authors=AUTHORS)

# chunk_data returns None when it produced no chunks; in that case nothing is
# trained and `tokenizer` / `model` are not defined for the cells below.
if dataset is not None:
    tokenizer = prepare_tokenizer(MODEL_NAME)
    tokenized_dataset = prepare_training_data(dataset, tokenizer, MAX_LENGTH)

    model = configure_model(MODEL_NAME)

    train_model(
        model=model,
        tokenizer=tokenizer,
        tokenized_dataset=tokenized_dataset,
        output_dir=OUTPUT_DIR,
        batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        learning_rate=LEARNING_RATE,
        epochs=EPOCHS,
        logging_steps=LOGGING_STEPS,
        save_strategy=SAVE_STRATEGY,
        wandb_project=WANDB_PROJECT,
        wandb_run_name=WANDB_RUN_NAME,
        seed=SEED,
    )
else:
    print("Dataset loading failed. Aborting.")

# Generation

In [None]:
def generate_text(model, tokenizer, prompt, max_length=100):
    """Sample a continuation of `prompt` and return only the newly generated text.

    Args:
        model: Causal LM exposing `parameters()`, `eval()` and `generate(...)`.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        prompt: Text to continue.
        max_length: Upper bound on the total sequence length (prompt plus
            generated tokens), forwarded to `generate` as `max_length`.

    Returns:
        The decoded continuation, without the prompt and without special tokens.
    """
    # Run on whatever device the model already lives on. Calling model.to(...)
    # is unnecessary for a model placed via device_map and may be rejected for
    # bitsandbytes-quantized weights.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # Training leaves the model in train mode; switch dropout off for inference.
    model.eval()

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs,
        max_length=max_length,
        num_return_sequences=1,
        do_sample=True,  # temperature only takes effect when sampling is enabled
        temperature=0.7
    )

    # `return_full_text` is a text-generation *pipeline* option, not a
    # generate() argument; strip the prompt tokens manually instead.
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

In [None]:
# Quick smoke test of the fine-tuned model on a handful of prompts.
print("\n--- Inference Examples ---")

# NOTE(review): the last two prompts use a "### Name:" header, whereas the
# training chunks built by chunk_data are prefixed with "<|author:Name|>".
prompts = [
    "Мы пойдём другим путём?",
    "Каждая кухарка должна научиться управлять государством",
    "### Stalin: \nРасскажи мне о Ленине",
    "### Lenin:\nКакова цель революции?"
]

# Relies on the `model` and `tokenizer` globals created by the training cell.
for prompt in prompts:
    generated_text = generate_text(model, tokenizer, prompt)
    print(f"Prompt: {prompt}\nGenerated Text: {generated_text}\n")

In [None]:
# Same prompts as above, but the last two carry the "<|author:Name|>" prefix
# that chunk_data puts in front of every training chunk.
# Rebinding `prompts` replaces the list defined in the previous cell.
prompts = [
    "Мы пойдём другим путём?",
    "Каждая кухарка должна научиться управлять государством",
    "<|author:Stalin|> Расскажи мне о Ленине", 
    "<|author:Lenin|> Какова цель революции?" 
]

# Relies on the `model` and `tokenizer` globals created by the training cell.
for prompt in prompts:
    generated_text = generate_text(model, tokenizer, prompt)

    # If the returned text starts with an author tag followed by a space,
    # report the tagged author and strip the tag from the displayed text.
    match = re.match(r"<\|author:(.*?)\|> ", generated_text)
    if match:
        author = match.group(1)
        generated_text = generated_text[match.end():].strip()
        print(f"Generated Author: {author}")

    print(f"Prompt: {prompt}")
    print(f"Generated Text:\n{generated_text}\n---")