In [None]:
!pip install -r requirements.txt
!pip install numpy==1.26.4 scikit-learn==1.3.2 --force-reinstall --no-cache-dir
!pip install --upgrade peft


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from accelerate import Accelerator
import json
import random
from tqdm import tqdm
import math
import re
import os
from peft import PeftModel, LoraConfig, get_peft_model
from huggingface_hub import login

# ==== CONFIGURATION ====
login(token="hf_...")
GEMMA_MODEL = "google/gemma-3-4b-pt"
CHECKPOINT_DIR = "./persistent_volume/last_checkpoint/"
BOOKS_PATH = "books.jsonl"
VAL_SPLIT = 0.0025
BATCH_SIZE = 1
MAX_CHUNK_LENGTH = 256
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32


In [None]:

# ==== LOAD MODELS ====
print("Loading Gemma tokenizer and model...")
try:
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
except Exception:
    tokenizer = AutoTokenizer.from_pretrained(GEMMA_MODEL)

tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

base_model = AutoModelForCausalLM.from_pretrained(
    GEMMA_MODEL,
    torch_dtype=DTYPE,
    device_map="auto"
)

model = PeftModel.from_pretrained(base_model, CHECKPOINT_DIR)
model.eval()
for p in model.parameters():
    p.requires_grad = False

PROJECTOR_OUT_DIM = model.get_input_embeddings().embedding_dim
print(f"Gemma embedding dim: {PROJECTOR_OUT_DIM}")

accelerator = Accelerator(mixed_precision="bf16" if DTYPE == torch.bfloat16 else "no")


In [9]:
# ==== DATA PREPARATION FUNCTIONS ====
def clean_books_texts(texts: list[str]) -> list[str]:
    def clean_text(text: str) -> str:
        # Убираем символы страниц и мусор
        text = re.sub(r'[\f\x0c]', ' ', text)  # page breaks
        text = re.sub(r'[*=_\-]{2,}', ' ', text)  # repeated chars like '====' or '***'
        text = re.sub(r'\n+', ' ', text)  # newlines
        text = re.sub(r'\s{2,}', ' ', text)  # multiple spaces
        text = text.strip()

        # Обрезаем 10% сверху и снизу
        total_len = len(text)
        cut_len = total_len // 10
        if total_len > 2 * cut_len:
            text = text[cut_len:-cut_len]
        return text.strip()

    return [clean_text(t) for t in texts]

def chunk_text(text, chunk_size=MAX_CHUNK_LENGTH, tokenizer=tokenizer):
    """Разбивает текст на отрывки по chunk_size токенов."""
    tokens = tokenizer.encode(text, truncation=False)
    chunks = [tokenizer.decode(tokens[i:i+chunk_size]) for i in range(0, len(tokens), chunk_size) if len(tokens[i:i+chunk_size]) >= 10]
    return chunks

def process_book(text, tokenizer=tokenizer):
    """Обрабатывает текст книги, создавая примеры для обучения."""
    chunks = chunk_text(text, tokenizer=tokenizer)
    examples = []
    # Убедимся, что достаточно чанков для создания хотя бы одного примера (10 history + 1 current + 1 target = 12)
    if len(chunks) < 12:
        return []
    # Шаг 6 для перекрытия
    for i in range(0, len(chunks) - 12 + 1, 6):
        examples.append({
            "user_history": chunks[i:i+10],  # 10 отрывков контекста
            "current_input": chunks[i+10],   # Промпт
            "target": chunks[i+11]           # Продолжение
        })
    return examples

# ==== DATASET CLASS ====
class StyleTransferDataset(Dataset):
    def __init__(self, texts, tokenizer):
        self.samples = []
        print("Обрабатываем данные для датасета...")
        for text in tqdm(texts, desc="Обработка книг"):
            self.samples.extend(process_book(text, tokenizer))
        print(f"Всего создано примеров: {len(self.samples)}")
        if len(self.samples) == 0:
            print("ВНИМАНИЕ: Датасет пуст! Проверьте входные данные и параметры chunking/processing.")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


# ==== COLLATE FUNCTION ====
def collate_fn(batch):
    user_history_batch = [x["user_history"] for x in batch]
    current_input_texts = [x["current_input"] for x in batch]
    target_texts = [x["target"] for x in batch]

    # Tokenize with Gemma tokenizer
    tokenized_batch = tokenizer(
        current_input_texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_CHUNK_LENGTH
    )

    target_tokenized = tokenizer(
        target_texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_CHUNK_LENGTH
    )

    return {
        "user_history_texts": user_history_batch,
        "current_input_texts": current_input_texts,
        "input_ids": tokenized_batch["input_ids"],
        "attention_mask": tokenized_batch["attention_mask"],
        "target_ids": target_tokenized["input_ids"],
    }

# ==== STYLE EMBEDDING WITH GEMMA ENCODER ====
def get_style_embedding(user_history_batch, current_input_batch, model, tokenizer, device, dtype):
    batch_style_embs = []

    for i in range(len(user_history_batch)):
        history_chunks = user_history_batch[i]
        current_input = current_input_batch[i]

        # Encode current input
        current_inputs = tokenizer(
            [current_input],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_CHUNK_LENGTH
        ).to(device)

        with torch.no_grad():
            outputs = model(**current_inputs, output_hidden_states=True)
            current_hidden = outputs.hidden_states[-1]  # Last layer
            # Mean pooling excluding padding
            current_mask = current_inputs['attention_mask']
            current_emb = (current_hidden * current_mask.unsqueeze(-1)).sum(dim=1) / current_mask.sum(dim=1, keepdim=True)

        # Encode history chunks
        history_embs = []
        for chunk in history_chunks:
            chunk_inputs = tokenizer(
                [chunk],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=MAX_CHUNK_LENGTH
            ).to(device)

            with torch.no_grad():
                outputs = model(**chunk_inputs, output_hidden_states=True)
                chunk_hidden = outputs.hidden_states[-1]
                chunk_mask = chunk_inputs['attention_mask']
                chunk_emb = (chunk_hidden * chunk_mask.unsqueeze(-1)).sum(dim=1) / chunk_mask.sum(dim=1, keepdim=True)
                history_embs.append(chunk_emb)

        # Compute attention weights
        history_embs_tensor = torch.cat(history_embs, dim=0)
        weights = torch.softmax(torch.matmul(history_embs_tensor, current_emb.t()).squeeze(), dim=0)

        # Weighted sum
        style_emb = torch.sum(history_embs_tensor * weights.unsqueeze(1), dim=0)
        batch_style_embs.append(style_emb)

    return torch.stack(batch_style_embs)

# ==== DATA LOADING ====
print("Загружаем данные из books.jsonl...")
try:
    with open(BOOKS_PATH, "r", encoding="utf-8") as f:
        texts = [json.loads(line)["text"] for line in f if "text" in json.loads(line)]
        texts = clean_books_texts(texts)
    print(f"Загружено {len(texts)} книг.")
except FileNotFoundError:
    print(f"Ошибка: Файл {BOOKS_PATH} не найден. Убедитесь, что он существует.")
    exit()
except json.JSONDecodeError:
    print(f"Ошибка: Некорректный формат JSON в файле {BOOKS_PATH}.")
    exit()

random.shuffle(texts)
split_idx = int(len(texts) * (1 - VAL_SPLIT))
train_texts = texts[:split_idx]
val_texts = texts[split_idx:]

train_dataset = StyleTransferDataset(train_texts, tokenizer)
val_dataset = StyleTransferDataset(val_texts, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
print("Данные готовы к обучению!")


Загружаем данные из books.jsonl...
Загружено 300 книг.
Обрабатываем данные для датасета...


Обработка книг: 100%|██████████| 299/299 [03:52<00:00,  1.29it/s]


Всего создано примеров: 41120
Обрабатываем данные для датасета...


Обработка книг: 100%|██████████| 1/1 [00:00<00:00,  2.25it/s]

Всего создано примеров: 162
Данные готовы к обучению!





In [10]:

# ==== EVALUATION FUNCTION ====
def evaluate(model, dataloader, tokenizer, device, dtype):
    model.eval()
    total_loss, total_tokens = 0, 0

    for batch in tqdm(dataloader, desc="Evaluating"):
        # Get style embedding
        style_emb = get_style_embedding(
            batch["user_history_texts"],
            batch["current_input_texts"],
            model,
            tokenizer,
            device,
            dtype
        ).unsqueeze(1)  # [batch_size, 1, hidden_dim]

        # Prepare input embeddings
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        target_ids = batch["target_ids"].to(device)

        # Get base embeddings
        input_embeds = model.get_input_embeddings()(input_ids)

        # Concatenate style embedding to each token
        # style_emb: [batch_size, 1, hidden_dim]
        # input_embeds: [batch_size, seq_len, hidden_dim]
        style_emb_repeated = style_emb.repeat(1, input_embeds.size(1), 1)
        combined_embeds = input_embeds + style_emb_repeated

        # Prepare labels - only compute loss on target text
        labels = torch.full_like(input_ids, -100)
        labels = torch.cat([labels, target_ids], dim=1)

        # Prepare full input
        full_embeds = torch.cat([
            combined_embeds,
            model.get_input_embeddings()(target_ids)
        ], dim=1)

        # Create attention mask
        full_attention_mask = torch.cat([
            attention_mask,
            torch.ones_like(target_ids)
        ], dim=1)

        with torch.no_grad():
            outputs = model(
                inputs_embeds=full_embeds,
                attention_mask=full_attention_mask,
                labels=labels
            )
            loss = outputs.loss

        # Calculate number of target tokens
        num_target_tokens = (target_ids != tokenizer.pad_token_id).sum().item()
        total_loss += loss.item() * num_target_tokens
        total_tokens += num_target_tokens

    avg_loss = total_loss / total_tokens if total_tokens > 0 else 0
    ppl = math.exp(avg_loss) if avg_loss > 0 else float('inf')
    return avg_loss, ppl


# Move model to device
model = model.to(DEVICE)

print("Evaluating validation set...")
val_loss, val_ppl = evaluate(model, val_loader, tokenizer, DEVICE, DTYPE)
print(f"Validation Loss: {val_loss:.4f}, Perplexity: {val_ppl:.2f}")

# ==== TEXT GENERATION EXAMPLE ====
def generate_with_style(model, tokenizer, user_history, prompt, device, max_length=100):
    """Generate text continuation with style"""
    # Get style embedding
    style_emb = get_style_embedding(
        [user_history],
        [prompt],
        model,
        tokenizer,
        device,
        DTYPE
    ).unsqueeze(1)

    # Tokenize prompt
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # Get base embeddings
    input_embeds = model.get_input_embeddings()(input_ids)

    # Add style to each token
    style_emb_repeated = style_emb.repeat(1, input_embeds.size(1), 1)
    combined_embeds = input_embeds + style_emb_repeated

    # Generate continuation
    output = model.generate(
        inputs_embeds=combined_embeds,
        attention_mask=attention_mask,
        max_length=max_length,
        do_sample=True,
        temperature=0.8,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id
    )

    return tokenizer.decode(output[0], skip_special_tokens=True)

# Example usage
sample_history = ["Это пример истории 1", "Это пример истории 2"] * 5
sample_prompt = "Начни рассказ:"
generated = generate_with_style(model, tokenizer, sample_history, sample_prompt, DEVICE)
print("\nGenerated Text with Style:")
print(generated)

Evaluating validation set...


Evaluating: 100%|██████████| 162/162 [02:51<00:00,  1.06s/it]


Validation Loss: 11.6229, Perplexity: 111625.34


W0716 10:09:21.632000 1032 site-packages/torch/_inductor/utils.py:1137] [0/0] Not enough SMs to use max_autotune_gemm mode



Generated Text with Style:
 болж
...



	

	

	

	

	

	

	

	

	

	

	

	

	

	

	

	

	

	

	

	

	

	

	

	

	

	

	

	

	

	

	

	

	

	

	

	

	

	

	

	

	

	

	


