In [None]:
!pip -q install -U "transformers>=4.51.0" "datasets>=2.18.0" "accelerate>=0.33.0" "peft>=0.12.0" "bitsandbytes>=0.43.0" "safetensors>=0.4.3" "tqdm>=4.66.0"

import os, re, random, time
import torch
from tqdm.auto import tqdm
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments, Trainer, DataCollatorForLanguageModeling, set_seed
from transformers.trainer_callback import TrainerCallback
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel

# Configure third-party libraries through environment variables:
# tokenizers parallelism is switched off and W&B is marked as disabled.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB_DISABLED"] = "true"

# Seed both the transformers helper and Python's `random` with the same value.
seed = 42
set_seed(seed)
random.seed(seed)

# A CUDA device is mandatory for the rest of the cell. The assertion message is
# Russian for "GPU is not enabled: Runtime -> Change runtime type -> GPU".
print("cuda:", torch.cuda.is_available())
assert torch.cuda.is_available(), "GPU не включен: Runtime -> Change runtime type -> GPU"
print("gpu:", torch.cuda.get_device_name(0))
!nvidia-smi

# Allow TF32 in CUDA matmul and cuDNN kernels.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Location of the prefix list; the cell aborts if the file is missing
# (the message is Russian for "No file <path>").
prefix_path = "/content/prefixes.txt"
assert os.path.exists(prefix_path), f"Нет файла {prefix_path}"

def parse_prefixes(path):
    """Read the list of prefixes from a text file.

    Every non-blank line is expected to look like ``<index> <prefix text>``.
    A line that does not start with an integer index followed by whitespace
    is kept as well and is given the number of items collected so far as
    its index.

    Args:
        path: Path to a UTF-8 encoded text file. A leading byte-order mark
            is tolerated.

    Returns:
        A list of ``(index, prefix)`` tuples sorted by index.
    """
    items = []
    # "utf-8-sig" drops a BOM if the file starts with one. With plain "utf-8"
    # the BOM stays glued to the first index, the numbered-line pattern does
    # not match, and the whole first line (index included) becomes the prefix.
    with open(path, "r", encoding="utf-8-sig") as f:
        for line in f:
            s = line.strip()
            if not s:
                continue
            m = re.match(r"^\s*(\d+)\s+(.*)\s*$", s)
            if m:
                items.append((int(m.group(1)), m.group(2).strip()))
            else:
                # Unnumbered line: fall back to a positional index.
                items.append((len(items), s))
    items.sort(key=lambda item: item[0])
    return items

# (index, prefix) pairs read from the prefix file, sorted by index.
prefixes = parse_prefixes(prefix_path)

def build_prefix_regex(prefix):
    """Compile a case-insensitive pattern for texts that open with *prefix*.

    The prefix is matched literally, with these relaxations:
    any run of whitespace is accepted where the prefix has a space,
    "ё" in the prefix also accepts "е", and a dash accepts "-", "—" or "–".
    After the prefix the text must continue with a separator
    (comma, colon or dash), whitespace, or end right there.

    Args:
        prefix: The prefix text.

    Returns:
        A compiled ``re.Pattern`` meant to be used with ``.match``.
    """
    normalized = prefix.strip()
    for long_dash in ("—", "–"):
        normalized = normalized.replace(long_dash, "-")

    body = re.escape(normalized)
    # Applied in this exact order on the already escaped text.
    substitutions = (
        (r"\ ", r"\s+"),
        ("ё", "[её]"),
        ("Ё", "[ЕЁ]"),
        (r"\-", r"[-—–]"),
    )
    for needle, replacement in substitutions:
        body = body.replace(needle, replacement)

    tail = r"(?:(?:\s*[,:\-—–]\s*)|\s+|$)"
    return re.compile(r"^\s*" + body + tail, re.IGNORECASE)

# One compiled matcher per prefix index.
patterns = {idx: build_prefix_regex(pref) for idx, pref in prefixes}

def cleanup_one_line(s):
    """Flatten *s* to a single tidy line.

    Line breaks and tabs become spaces, whitespace runs are collapsed to one
    space, and spaces, commas, semicolons, colons and dashes are trimmed from
    both ends.

    Args:
        s: Text to normalise.

    Returns:
        The cleaned single-line string (possibly empty).
    """
    flat = s
    for control in ("\r", "\n", "\t"):
        flat = flat.replace(control, " ")
    squeezed = re.sub(r"\s+", " ", flat).strip()
    trimmed = squeezed.strip(" ,;:-—–")
    return re.sub(r"\s+", " ", trimmed).strip()

def safe_continuation(full, prefix):
    """Return the part of *full* that follows *prefix*, cleaned to one line.

    The prefix is looked up case-insensitively: first at the start of the
    text, then anywhere inside it. When it is not found at all, the whole
    text is treated as the continuation.

    Args:
        full: Complete text, expected to contain the prefix.
        prefix: The prefix to cut off.

    Returns:
        The continuation after ``cleanup_one_line`` has been applied.
    """
    text = full.strip()
    head = prefix.strip()
    text_lc = text.lower()
    head_lc = head.lower()

    if text_lc.startswith(head_lc):
        start = len(head)
    else:
        found_at = text_lc.find(head_lc)
        # Not found: keep everything.
        start = found_at + len(head) if found_at >= 0 else 0
    return cleanup_one_line(text[start:])

def score_candidate(s):
    """Heuristic quality score for a generated continuation (higher is better).

    Empty, very short (< 25 chars) and very long (> 280 chars) texts get
    strongly negative scores. Otherwise the score rewards a high share of
    distinct words and a sentence-final punctuation mark, and penalises
    immediately repeated three-word sequences and distance from a length of
    140 characters.

    Args:
        s: Candidate text.

    Returns:
        A number; only the ordering between candidates is meaningful.
    """
    if not s:
        return -1e9
    length = len(s)
    if length < 25:
        return -200 + length
    if length > 280:
        return -50 - (length - 280)

    words = re.findall(r"[A-Za-zА-Яа-яЁё0-9]+", s)
    uniq = len({w.lower() for w in words}) / max(1, len(words))

    rep_pen = 0
    if len(words) >= 10:
        # A window needs six words, so valid starts are 0 .. len(words) - 6
        # inclusive. The previous bound of len(words) - 6 skipped the last
        # window and missed a repetition at the very end of the text.
        rep_pen = sum(
            1
            for i in range(len(words) - 5)
            if words[i:i + 3] == words[i + 3:i + 6]
        )

    end_bonus = 10 if re.search(r"[.!?…]$", s) else 0
    return 60 * uniq + end_bonus - 25 * rep_pen - 0.03 * abs(length - 140)

# Collection limits: texts kept per prefix, size of the general pool, and the
# maximum number of streamed records to process.
per_prefix_real_limit = 60
general_limit = 6000
max_scan = 90000

counts = {idx: 0 for idx, _ in prefixes}
active = set(counts)  # prefix indices that still need more matching texts
real_texts = []       # texts that start with one of the prefixes
general_texts = []    # other texts that pass the rating filter below

print("Stage: streaming/filtering dataset")
ds_stream = load_dataset("igorktech/anekdots", split="train", streaming=True)

pbar = tqdm(total=max_scan, desc="scan", unit="ex")
scanned = 0
for ex in ds_stream:
    # Check the budget BEFORE counting the record. Counting first and then
    # breaking dropped the max_scan-th record while still reporting it.
    if scanned >= max_scan:
        break
    scanned += 1
    pbar.update(1)

    t = ex.get("text", None)
    if not isinstance(t, str):
        continue
    t = t.strip()
    if len(t) < 60 or len(t) > 900:
        continue
    if "\u0000" in t:
        continue
    # NOTE(review): "total_mark" is presumably a rating field of the dataset;
    # confirm its meaning and scale against the dataset card.
    mark = ex.get("total_mark", None)
    if isinstance(mark, int) and mark < 3:
        continue

    matched = False
    # Iterate over a snapshot because `active` is modified inside the loop.
    for idx in list(active):
        if counts[idx] >= per_prefix_real_limit:
            active.discard(idx)
            continue
        if patterns[idx].match(t):
            real_texts.append(t)
            counts[idx] += 1
            matched = True
            if counts[idx] >= per_prefix_real_limit:
                active.discard(idx)
            # A text is attributed to at most one prefix.
            break

    if (not matched) and (len(general_texts) < general_limit):
        if isinstance(mark, int) and mark >= 10:
            general_texts.append(t)

    if scanned % 2000 == 0:
        pbar.set_postfix({"real": len(real_texts), "gen": len(general_texts), "active": len(active)})

    # Stop early once both pools are full.
    if (len(general_texts) >= general_limit) and (not active):
        break

pbar.close()
print("Stage done:", {"scanned": scanned, "real": len(real_texts), "gen": len(general_texts), "active": len(active)})

train_texts = real_texts + general_texts
random.shuffle(train_texts)

# Directory where the LoRA adapter is written and later looked up.
adapter_dir = "/content/qwen3_0p6b_ru_jokes_lora"
os.makedirs(adapter_dir, exist_ok=True)

base_model_id = "Qwen/Qwen3-0.6B-Base"
print("Model:", base_model_id)

tokenizer = AutoTokenizer.from_pretrained(base_model_id, use_fast=True, trust_remote_code=True)
# Fall back to EOS as the padding token when the tokenizer defines none.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token

# Use bfloat16 when the GPU reports support for it, float16 otherwise.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
compute_dtype = torch.bfloat16 if use_bf16 else torch.float16

# 4-bit NF4 quantization with double quantization; compute runs in
# `compute_dtype`.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=compute_dtype
)

def load_base():
    """Load a fresh copy of the base causal LM.

    Uses the module-level ``base_model_id``, ``bnb_config`` and
    ``compute_dtype`` and lets the library place the weights
    (``device_map="auto"``).

    Returns:
        The model returned by ``AutoModelForCausalLM.from_pretrained``.
    """
    load_kwargs = {
        "quantization_config": bnb_config,
        "torch_dtype": compute_dtype,
        "device_map": "auto",
        "trust_remote_code": True,
    }
    return AutoModelForCausalLM.from_pretrained(base_model_id, **load_kwargs)

def generate_candidates(m, tok, prefix, n, max_new_tokens=90, temperature=1.0, top_p=0.9, rep_pen=1.12):
    """Sample *n* texts from model *m* for a single prompt.

    The prompt is tokenized once and repeated *n* times along the batch
    dimension, so all samples are drawn in one ``generate`` call.

    Args:
        m: Model exposing ``device`` and ``generate``.
        tok: Tokenizer used for encoding and decoding.
        prefix: Prompt text.
        n: Number of samples.
        max_new_tokens: Upper bound on generated tokens per sample.
        temperature: Sampling temperature.
        top_p: Nucleus sampling threshold.
        rep_pen: Value passed as ``repetition_penalty``.

    Returns:
        A list of *n* decoded strings with special tokens removed
        (presumably including the prompt text; confirm against the
        ``generate`` output for this model).
    """
    encoded = tok(prefix, return_tensors="pt")
    batch_ids = encoded["input_ids"].to(m.device).repeat(n, 1)
    batch_mask = encoded["attention_mask"].to(m.device).repeat(n, 1)

    sampling = {
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "repetition_penalty": rep_pen,
        "max_new_tokens": max_new_tokens,
        # EOS doubles as the padding token during generation.
        "pad_token_id": tok.eos_token_id,
        "eos_token_id": tok.eos_token_id,
    }
    with torch.no_grad():
        sequences = m.generate(input_ids=batch_ids, attention_mask=batch_mask, **sampling)

    return [tok.decode(seq, skip_special_tokens=True) for seq in sequences]

import gc  # local import: only used to release GPU memory between model loads

# Training is skipped when an adapter has already been saved to adapter_dir.
adapter_exists = os.path.exists(os.path.join(adapter_dir, "adapter_config.json"))

if not adapter_exists:
    # --- Augmentation: let the base model continue every prefix once and add
    # sufficiently long results to the training texts.
    model_for_aug = load_base()
    model_for_aug.eval()
    aug_per_prefix = 1
    aug_texts = []
    for _, pref in tqdm(prefixes, desc="augment", unit="pref"):
        fulls = generate_candidates(model_for_aug, tokenizer, pref, n=aug_per_prefix, max_new_tokens=80, temperature=1.02, top_p=0.9, rep_pen=1.10)
        for ft in fulls:
            ft = cleanup_one_line(ft)
            # Keep only texts that add at least 20 characters to the prefix.
            if len(ft) >= len(pref) + 20:
                aug_texts.append(ft)
    train_texts = train_texts + aug_texts
    random.shuffle(train_texts)

    # Release the augmentation model before a second copy of the base model is
    # loaded for training; otherwise both stay resident on the GPU.
    del model_for_aug
    gc.collect()
    torch.cuda.empty_cache()

    # --- LoRA fine-tuning on top of the quantized base model.
    model = load_base()
    model = prepare_model_for_kbit_training(model)
    model.config.use_cache = False

    lora_cfg = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]
    )
    model = get_peft_model(model, lora_cfg)

    train_ds = Dataset.from_dict({"text": train_texts})

    max_len = 192
    def tok_fn(batch):
        """Tokenize a batch of texts, truncating to max_len; padding is left to the collator."""
        return tokenizer(batch["text"], truncation=True, max_length=max_len, padding=False)

    tok_ds = train_ds.map(tok_fn, batched=True, remove_columns=["text"], desc="tokenize")
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    class TrainTimeCallback(TrainerCallback):
        """Print one compact progress line per log event, with elapsed time and ETA."""

        def __init__(self):
            self.t0 = None  # wall-clock time at which training started

        def on_train_begin(self, args, state, control, **kwargs):
            """Record the start time and print the main training settings."""
            self.t0 = time.time()
            print(f"train_begin: max_steps={state.max_steps} epochs={args.num_train_epochs} bs={args.per_device_train_batch_size} ga={args.gradient_accumulation_steps} bf16={args.bf16} fp16={args.fp16}")

        def on_log(self, args, state, control, logs=None, **kwargs):
            """Print step, epoch, loss, learning rate, elapsed time and a linear ETA."""
            if logs is None:
                return
            now = time.time()
            elapsed = now - (self.t0 or now)
            step = state.global_step
            max_steps = state.max_steps or 0
            eta = None
            if max_steps and step:
                # Linear extrapolation from the average time per step so far.
                eta = elapsed * (max_steps - step) / max(1, step)
            loss = logs.get("loss", None)
            lr = logs.get("learning_rate", None)
            ep = state.epoch
            s = f"step={step}/{max_steps} epoch={ep:.3f}" if ep is not None else f"step={step}/{max_steps}"
            if loss is not None:
                s += f" loss={loss:.4f}"
            if lr is not None:
                s += f" lr={lr:.2e}"
            s += f" elapsed={elapsed/60:.1f}m"
            if eta is not None:
                s += f" eta={eta/60:.1f}m"
            print(s)

        def on_train_end(self, args, state, control, **kwargs):
            """Print the total number of steps and the training time."""
            if self.t0 is None:
                return
            elapsed = time.time() - self.t0
            print(f"train_end: steps={state.global_step} time={elapsed/60:.1f}m")

    args = TrainingArguments(
        output_dir=adapter_dir,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        num_train_epochs=1,
        learning_rate=2e-4,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        logging_strategy="steps",
        logging_steps=10,
        save_strategy="no",
        bf16=use_bf16,
        fp16=not use_bf16,
        optim="paged_adamw_8bit",
        report_to="none",
        dataloader_num_workers=2,
        dataloader_pin_memory=True,
        group_by_length=True,
        disable_tqdm=False
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=tok_ds,
        data_collator=data_collator,
        callbacks=[TrainTimeCallback()]
    )

    print("Stage: training")
    trainer.train()
    model.save_pretrained(adapter_dir)
    tokenizer.save_pretrained(adapter_dir)

    # The adapter is on disk now. Drop the training objects so the inference
    # model loaded below does not share the GPU with them.
    del trainer, model
    gc.collect()
    torch.cuda.empty_cache()

# --- Inference model: a fresh base model with the saved adapter attached.
base_model = load_base()
model = PeftModel.from_pretrained(base_model, adapter_dir)
model.eval()

def generate_batch(prefix, n, max_new_tokens=110, temperature=1.03, top_p=0.9, rep_pen=1.12):
    """Sample *n* texts for *prefix* from the module-level fine-tuned model.

    Thin wrapper around ``generate_candidates`` bound to the global ``model``
    and ``tokenizer``. The sampling logic previously lived here as a
    line-for-line copy of that helper.

    Args:
        prefix: Prompt text.
        n: Number of samples.
        max_new_tokens: Upper bound on generated tokens per sample.
        temperature: Sampling temperature.
        top_p: Nucleus sampling threshold.
        rep_pen: Value passed as ``repetition_penalty``.

    Returns:
        A list of *n* decoded strings with special tokens removed.
    """
    return generate_candidates(
        model,
        tokenizer,
        prefix,
        n,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        rep_pen=rep_pen,
    )

# Output file and selection settings: how many candidates are sampled per
# prefix and how many of the best ones are written.
submission_path = "/content/submission.txt"
num_lines_per_prefix = 3
candidates_per_prefix = 12

print("Stage: generating submission")
with open(submission_path, "w", encoding="utf-8") as f:
    for idx, pref in tqdm(prefixes, desc="prefixes", unit="pref"):
        generated = generate_batch(pref, candidates_per_prefix)
        # Strip the prefix from every sample and drop empty continuations.
        continuations = [c for c in (safe_continuation(ft, pref) for ft in generated) if c]
        # dict.fromkeys removes duplicates while keeping first-seen order.
        ranked = sorted(dict.fromkeys(continuations), key=score_candidate, reverse=True)
        # Placeholder line when nothing usable was generated.
        picked = ranked[:num_lines_per_prefix] if ranked else ["..."]
        f.writelines(f"{idx} {cont}\n" for cont in picked)

print("saved:", submission_path)
# Preview: print at most the first 15 lines of the written file.
with open(submission_path, "r", encoding="utf-8") as f:
    for line_no, line in enumerate(f):
        if line_no >= 15:
            break
        print(line.rstrip("\n"))

# Author's note: "вырезал с префиксом 0" (Russian: "cut out the entry with prefix 0").