In [2]:
import os
import re
import math
import random
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel

# Single seed shared by every RNG seeded below (Python `random`, torch CPU, torch CUDA).
SEED = 42

random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    # Allow TF32 for float32 matmuls / cuDNN ops: faster, slightly lower precision.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# Experiment configuration.
MODEL_NAME = "Qwen/Qwen3-0.6B-Base"   # base checkpoint that gets LoRA adapters
DATASET_NAME = "igorktech/anekdots"   # Hugging Face dataset id (split "train" is used)
OUT_DIR = "qwen3_06b_anekdots_lora"   # Trainer checkpoints + saved adapter and tokenizer
MAX_LEN = 384                         # truncation length (tokens) for training examples
TRAIN_SIZE = 80000                    # max number of shuffled dataset rows to use

In [3]:
def get_compute_dtype():
    """Choose the floating-point dtype for 4-bit compute and unquantized weights.

    Returns:
        torch.float32 when CUDA is unavailable; otherwise torch.bfloat16 if
        device 0 reports compute capability major version >= 8, else torch.float16.
    """
    if torch.cuda.is_available():
        capability_major = torch.cuda.get_device_capability(0)[0]
        return torch.bfloat16 if capability_major >= 8 else torch.float16
    return torch.float32

compute_dtype = get_compute_dtype()

# 4-bit NF4 weights with double quantization; compute happens in `compute_dtype`.
_quant_kwargs = dict(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=compute_dtype,
)
bnb_config = BitsAndBytesConfig(**_quant_kwargs)

# Training tokenizer. If the checkpoint has no pad token, EOS doubles as padding.
# Padding goes on the right here; the inference cell reloads it with left padding.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

In [4]:
# Load the base model quantized to 4-bit. `dtype` replaces the `torch_dtype` keyword,
# which current transformers reports as deprecated.
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    quantization_config=bnb_config,
    dtype=compute_dtype,
)

# The generation KV cache is not needed while training.
base_model.config.use_cache = False
# PEFT helper that readies a k-bit quantized model for training (see PEFT docs for details).
base_model = prepare_model_for_kbit_training(base_model)

# LoRA adapters on the modules named q/k/v/o_proj (attention) and gate/up/down_proj (MLP).
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)

model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()

config.json:   0%|          | 0.00/727 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/1.19G [00:00<?, ?B/s]

KeyboardInterrupt: 

In [None]:
def normalize_text(s):
    """Canonicalize whitespace in a raw joke text.

    Steps:
      - carriage returns become newlines;
      - any run of non-newline whitespace (spaces, tabs, NBSP, ...) becomes one space;
      - spaces touching a newline are removed, so whitespace-only lines count as blank;
      - consecutive newlines (blank lines) collapse to a single newline;
      - leading/trailing whitespace is stripped.

    Args:
        s: raw text; ``None`` is treated as an empty string.

    Returns:
        The normalized string.
    """
    if s is None:
        return ""
    s = s.replace("\r", "\n")
    s = re.sub(r"[^\S\n]+", " ", s)
    # Without this, a line holding only spaces sits between two "\n" and the
    # blank-line collapse below never sees them as adjacent.
    s = re.sub(r" *\n *", "\n", s)
    s = re.sub(r"\n{2,}", "\n", s)
    return s.strip()

def build_prompt(prefix):
    """Build the instruction prompt for a joke opening.

    The prompt has three lines (instruction, "Затравка: <prefix>", "Анекдот:") and
    ends with the response marker, after which the continuation is expected.
    """
    template = "Продолжи анекдот на русском языке.\nЗатравка: {prefix}\nАнекдот:"
    return template.format(prefix=prefix)

def find_sublist(haystack, needle):
    """Return the first index where `needle` occurs contiguously in `haystack`.

    Both arguments are sliceable sequences compared with ``==`` (e.g. lists of
    token ids). Returns -1 if `needle` is empty or does not occur.
    """
    width = len(needle)
    if not width:
        return -1
    last_start = len(haystack) - width
    return next(
        (pos for pos in range(last_start + 1) if haystack[pos:pos + width] == needle),
        -1,
    )

# Token ids of the marker that ends every prompt. The completion collator masks
# every position up to and including the first occurrence of these ids.
response_marker = "Анекдот:"
response_ids = tokenizer(response_marker, add_special_tokens=False).input_ids

# Fail fast if the marker tokenizes differently inside a real training string
# (prompt + " " + continuation). Otherwise the collator would never find it and
# would silently mask every example, leaving no training signal.
_probe_ids = tokenizer(build_prompt("Проверка") + " текст", add_special_tokens=False).input_ids
assert find_sublist(_probe_ids, response_ids) != -1, (
    "response marker tokens not found inside a tokenized prompt; "
    "the completion-only collator would mask every example"
)

In [None]:
# Deterministically shuffle the train split and keep at most TRAIN_SIZE rows.
ds = load_dataset(DATASET_NAME, split="train")
n_rows = min(TRAIN_SIZE, len(ds))
ds = ds.shuffle(seed=42).select(range(n_rows))

def make_pair(example, idx):
    """Turn one raw joke into a single prompt+continuation training string.

    The normalized text is tokenized and cut after a pseudo-random number of
    tokens in [8, min(32, n_tokens - 24)]. The RNG is seeded from `idx`, so the
    cut depends only on the row index. The decoded head becomes the prefix in
    `build_prompt`, the decoded tail follows the prompt after one space, and the
    EOS token text is appended.

    Returns:
        {"text": str} on success, or {"text": None} when the text is shorter than
        40 characters or 60 tokens, or a decoded side is too short (prefix < 8
        chars, continuation < 20 chars).
    """
    skip = {"text": None}

    text = normalize_text(example["text"])
    if len(text) < 40:
        return skip
    ids = tokenizer(text, add_special_tokens=False).input_ids
    if len(ids) < 60:
        return skip

    # With >= 60 tokens this bound is always 32; the guard only matters if the
    # minimum length above is ever lowered.
    upper = min(32, len(ids) - 24)
    if upper <= 8:
        return skip
    cut = random.Random(1000003 + idx).randint(8, upper)

    head_ids, tail_ids = ids[:cut], ids[cut:]
    prefix = tokenizer.decode(head_ids, skip_special_tokens=True).strip()
    cont = tokenizer.decode(tail_ids, skip_special_tokens=True).strip()
    if len(prefix) < 8 or len(cont) < 20:
        return skip

    return {"text": build_prompt(prefix) + " " + cont + tokenizer.eos_token}

n_workers = os.cpu_count() or 2
ds_pairs = ds.map(
    make_pair,
    with_indices=True,
    remove_columns=ds.column_names,
    num_proc=n_workers,
)
# Drop the rows make_pair rejected (it returns text=None for them).
ds_pairs = ds_pairs.filter(lambda row: row["text"] is not None)
ds_pairs

In [None]:
def tokenize_fn(batch):
    """Tokenize a batch of training strings, truncating each to MAX_LEN tokens.

    Special tokens are not added; the training strings already carry the EOS
    token text appended by make_pair.
    """
    texts = batch["text"]
    return tokenizer(
        texts,
        max_length=MAX_LEN,
        truncation=True,
        add_special_tokens=False,
    )

# Replace the raw text column with token ids / attention masks.
tok_ds = ds_pairs.map(
    tokenize_fn,
    batched=True,
    num_proc=os.cpu_count() or 2,
    remove_columns=["text"],
)
tok_ds

In [None]:
class CompletionCollator:
    """Pad a batch and build labels so the loss covers only the completion.

    Labels start as a copy of ``input_ids``. Then three sets of positions are set
    to -100, the ignore index of Hugging Face causal-LM losses:
      - every padding position (``attention_mask == 0``);
      - every position up to and including the first occurrence of `response_ids`;
      - the whole row, if that marker is not found.
    """

    def __init__(self, tokenizer, response_ids):
        # tokenizer: used only for `.pad(...)`; response_ids: list of marker token ids.
        self.tokenizer = tokenizer
        self.response_ids = response_ids

    def __call__(self, features):
        batch = self.tokenizer.pad(features, return_tensors="pt")
        labels = batch["input_ids"].clone()
        # Padding must never be a target. pad_token may equal eos_token, so the ids
        # alone cannot separate padding from the real trailing EOS; use the mask.
        if "attention_mask" in batch:
            labels[batch["attention_mask"] == 0] = -100
        for i in range(labels.size(0)):
            # Search the untouched input ids, not the partially masked labels.
            start = find_sublist(batch["input_ids"][i].tolist(), self.response_ids)
            if start == -1:
                labels[i, :] = -100
            else:
                labels[i, : start + len(self.response_ids)] = -100
        batch["labels"] = labels
        return batch

# Completion-only collator keyed on the prompt's response marker ids.
collator = CompletionCollator(tokenizer=tokenizer, response_ids=response_ids)

In [None]:
# The mixed-precision mode follows the compute dtype chosen for the 4-bit model.
on_gpu = torch.cuda.is_available()
bf16 = on_gpu and compute_dtype == torch.bfloat16
fp16 = on_gpu and compute_dtype == torch.float16

# Per-device effective batch: 16 sequences x 4 accumulation steps = 64.
args = TrainingArguments(
    output_dir=OUT_DIR,
    report_to="none",
    # schedule
    num_train_epochs=1,
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    # batching
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    # optimizer / regularization
    optim="paged_adamw_8bit",
    weight_decay=0.01,
    max_grad_norm=1.0,
    # precision
    bf16=bf16,
    fp16=fp16,
    # logging and checkpoints
    logging_steps=20,
    save_steps=500,
    save_total_limit=2,
)

trainer = Trainer(
    model=model,
    args=args,
    data_collator=collator,
    train_dataset=tok_ds,
)

trainer.train()

In [None]:
# Save the trained PEFT-wrapped model and the tokenizer side by side in OUT_DIR;
# the inference cell reloads both from there.
peft_model = trainer.model
peft_model.save_pretrained(OUT_DIR)
tokenizer.save_pretrained(OUT_DIR)

In [None]:
def read_prefixes(path="prefixes.txt"):
    """Parse a UTF-8 prefixes file into a list of (index, prefix) pairs.

    Lines are stripped and blank lines are skipped. A line of the form
    "<digits><whitespace><text>" yields (int(digits), text). Any other line
    yields (number of pairs collected so far, line).
    """
    indexed_line = re.compile(r"^(\d+)\s+(.*)$")
    pairs = []
    with open(path, "r", encoding="utf-8") as handle:
        for raw in handle:
            line = raw.strip()
            if line == "":
                continue
            match = indexed_line.match(line)
            if match is None:
                pairs.append((len(pairs), line))
                continue
            pairs.append((int(match.group(1)), match.group(2).strip()))
    return pairs

prefixes = read_prefixes(path="prefixes.txt")
# Quick look: how many prefixes were parsed and the first few (index, text) pairs.
(len(prefixes), prefixes[:5])

In [None]:
# Reload the 4-bit base model and attach the trained LoRA adapter saved in OUT_DIR.
# `dtype` replaces the `torch_dtype` keyword, which current transformers reports as deprecated.
base = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    quantization_config=bnb_config,
    dtype=compute_dtype,
)
model = PeftModel.from_pretrained(base, OUT_DIR)
model.eval()

# Tokenizer saved next to the adapter. Left padding keeps every prompt flush
# against the generation start in batched generate() calls.
tokenizer = AutoTokenizer.from_pretrained(OUT_DIR, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

In [None]:
# Prompt-format markers and <think> tags that must not leak into a submitted continuation.
bad_substrings = ["Затравка:", "Анекдот:", "<think>", "</think>"]

def clean_continuation(s, prefix, markers=None):
    """Post-process one generated continuation for the submission file.

    Every marker substring is replaced with a space, and all whitespace runs
    collapse to single spaces. If the text then starts with the (equally
    whitespace-normalized) prefix, that echo is removed along with any joining
    punctuation.

    Args:
        s: decoded model output.
        prefix: the joke opening the model was prompted with.
        markers: substrings to blank out; defaults to the module-level `bad_substrings`.

    Returns:
        The cleaned continuation, possibly empty.
    """
    if markers is None:
        markers = bad_substrings
    for marker in markers:
        s = s.replace(marker, " ")
    # \s already covers "\r" and "\n", so carriage returns need no separate pass.
    s = re.sub(r"\s+", " ", s).strip()
    # Normalize the prefix the same way as `s`. Otherwise a prefix containing double
    # spaces, tabs or newlines can never match its echo.
    norm_prefix = re.sub(r"\s+", " ", prefix).strip()
    # An empty prefix would "match" any text and strip a legitimate leading dialogue dash.
    if norm_prefix and s.startswith(norm_prefix):
        s = s[len(norm_prefix):].lstrip(" .,!?:;—-")
    return s.strip()

def generate_continuations(prefix, n=5, max_new_tokens=96, temperature=0.9, top_p=0.92, top_k=50, repetition_penalty=1.08):
    """Sample up to `n` cleaned, single-line continuations for one joke prefix.

    The same prompt is repeated `n` times in one batch and sampled with the given
    decoding parameters. Each output passes through `clean_continuation`, is
    flattened to one line, and is kept only if it is at least 8 characters long.

    Returns:
        A list of at most `n` strings (empty if `n <= 0` or nothing survived).
    """
    if n <= 0:
        return []
    prompt = build_prompt(prefix)
    enc = tokenizer([prompt] * n, return_tensors="pt", padding=True).to(model.device)
    # New tokens start right after the padded prompt width. With left padding every
    # row shares those leading columns, so slicing at each row's unpadded length
    # (sum of its attention mask) would be wrong whenever prompt lengths differ.
    prompt_width = enc["input_ids"].shape[1]
    with torch.no_grad():
        out = model.generate(
            **enc,
            do_sample=True,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=3,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    results = []
    for row in out:
        txt = tokenizer.decode(row[prompt_width:], skip_special_tokens=True)
        txt = clean_continuation(txt, prefix)
        # One continuation per submission line: collapse any remaining whitespace.
        txt = re.sub(r"\s+", " ", txt).strip()
        if len(txt) >= 8:
            results.append(txt)
    return results

In [None]:
N_PER_PREFIX = 5
all_lines = []

for idx, pref in prefixes:
    # Placeholder so every prefix index appears in the output at least once.
    conts = generate_continuations(pref, n=N_PER_PREFIX) or ["..."]
    all_lines.extend(f"{idx} {c}" for c in conts)

out_path = "anekdots_submission.txt"
with open(out_path, "w", encoding="utf-8") as f:
    f.writelines(line.strip() + "\n" for line in all_lines)

out_path, len(all_lines), all_lines[:10]

In [None]:
# Print the first 25 lines of the submission (blank lines if the file is shorter).
with open("anekdots_submission.txt", "r", encoding="utf-8") as submission:
    for _line_no in range(25):
        line = submission.readline()
        print(line.rstrip())