# In [8]:  (Jupyter notebook cell marker left over from the export)
import re
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from transformers import AutoTokenizer, pipeline
from sklearn.model_selection import train_test_split
from rouge_score import rouge_scorer
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os


# Compute device: use CUDA when it is available, otherwise fall back to CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Using device: {device}")

# RuBERT tokenizer used to tokenize text for the LSTM.
tokenizer_lstm = AutoTokenizer.from_pretrained("DeepPavlov/rubert-base-cased")

def clean_text(text):
    """Normalise one raw line of text.

    The text is lower-cased, URLs and @-mentions are removed, every
    character that is neither a word character nor whitespace is dropped,
    and runs of whitespace are collapsed to single spaces.

    Returns the cleaned string with leading/trailing whitespace stripped.
    """
    # Plain removals that need no regex flags: URLs first, then mentions.
    removals = (
        r'http\S+|www\S+',
        r'@\w+',
    )
    cleaned = text.lower()
    for pattern in removals:
        cleaned = re.sub(pattern, '', cleaned)
    cleaned = re.sub(r'[^\w\sа-яё]', '', cleaned, flags=re.UNICODE)
    return re.sub(r'\s+', ' ', cleaned).strip()

def clean_and_tokenize(text):
    """Clean *text* with ``clean_text`` and tokenize it with ``tokenizer_lstm``."""
    cleaned = clean_text(text)
    pieces = tokenizer_lstm.tokenize(cleaned)
    return pieces

def create_training_samples(tokens):
    """Build a next-token training pair from a token sequence.

    Returns ``(inputs, targets)`` where ``targets`` is ``inputs`` shifted
    left by one position, or ``None`` when there are fewer than two tokens.
    """
    if len(tokens) >= 2:
        inputs, targets = tokens[:-1], tokens[1:]
        return inputs, targets
    return None

def process_file_and_save(input_path, output_dir='/home/assistant/text-autocomplete/data'):
    """Clean a text file, persist it as CSV and build next-token samples.

    Args:
        input_path: Path of a UTF-8 text file with one document per line.
        output_dir: Directory that receives ``raw_dataset.csv`` and
            ``dataset_processed.csv``. It is created when missing. Defaults
            to the project data directory.

    Returns:
        A list of ``{'X': tokens[:-1], 'Y': tokens[1:]}`` dicts, one per
        cleaned line that tokenizes to at least two tokens.
    """
    # Blank lines are skipped before cleaning; a line that cleans down to an
    # empty string is still kept in the raw dataset.
    with open(input_path, 'r', encoding='utf-8') as f:
        clean_texts = [clean_text(line) for line in f if line.strip()]

    os.makedirs(output_dir, exist_ok=True)
    raw_path = os.path.join(output_dir, 'raw_dataset.csv')
    pd.DataFrame({'text': clean_texts}).to_csv(raw_path, index=False, encoding='utf-8')
    print("Raw dataset saved to data/raw_dataset.csv")

    samples = []
    for text in clean_texts:
        pair = create_training_samples(tokenizer_lstm.tokenize(text))
        if pair is not None:
            X, Y = pair
            samples.append({'X': X, 'Y': Y})

    # Build the string columns directly so the frame always has both columns,
    # even when no line produced a sample (an empty DataFrame built from
    # ``samples`` would have no 'X' column to index).
    df_token = pd.DataFrame({
        'X_str': [' '.join(sample['X']) for sample in samples],
        'Y_str': [' '.join(sample['Y']) for sample in samples],
    })
    processed_path = os.path.join(output_dir, 'dataset_processed.csv')
    df_token.to_csv(processed_path, index=False, encoding='utf-8')
    print("Tokenized dataset saved to data/dataset_processed.csv")
    return samples

def tokens_to_indices(samples, vocab):
    """Map the tokens of every sample to vocabulary indices.

    Tokens missing from ``vocab`` are encoded as 0.

    Returns a pair ``(X_indices, Y_indices)`` of lists of index lists, in the
    same order as ``samples``.
    """
    def encode(tokens):
        return [vocab.get(token, 0) for token in tokens]

    # Single pass over ``samples`` so one-shot iterables behave as before.
    encoded = [(encode(sample['X']), encode(sample['Y'])) for sample in samples]
    X_indices = [x_idx for x_idx, _ in encoded]
    Y_indices = [y_idx for _, y_idx in encoded]
    return X_indices, Y_indices

def pad_sequences_torch(sequences, max_len):
    """Truncate/right-pad integer sequences into one ``torch.long`` tensor.

    Args:
        sequences: Iterable of integer sequences (lists, tuples, ...).
        max_len: Target length; longer sequences are cut, shorter ones are
            padded on the right with 0.

    Returns:
        A tensor of shape ``(len(sequences), max_len)``. An empty input
        yields a ``(0, max_len)`` tensor instead of raising.
    """
    sequences = list(sequences)
    # Start from an all-zero (i.e. all-padding) matrix and copy each kept
    # prefix in; this also covers the empty-input case naturally.
    padded = torch.zeros((len(sequences), max_len), dtype=torch.long)
    for row, seq in enumerate(sequences):
        kept = list(seq[:max_len])
        if kept:
            padded[row, :len(kept)] = torch.tensor(kept, dtype=torch.long)
    return padded

def compute_rouge(reference, prediction):
    """Score *prediction* against *reference* with ROUGE-1 and ROUGE-2.

    Returns the pair ``(rouge1_fmeasure, rouge2_fmeasure)``.
    """
    # NOTE(review): stemming is enabled and the scorer's default tokenizer is
    # used; both are presumably English/ASCII oriented — confirm the scores
    # are meaningful for Cyrillic text.
    rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2'], use_stemmer=True)
    result = rouge.score(reference, prediction)
    rouge1_f = result['rouge1'].fmeasure
    rouge2_f = result['rouge2'].fmeasure
    return rouge1_f, rouge2_f

# Source corpus: one text per line.
input_file = '/home/assistant/text-autocomplete/data/tweets.txt'

samples = process_file_and_save(input_file)

# Vocabulary over every token seen in inputs and targets. Ids start at 1;
# id 0 is what tokens_to_indices / pad_sequences_torch use for unknown
# tokens and padding.
all_tokens = [token for key in ('X', 'Y') for sample in samples for token in sample[key]]
vocab = {token: idx for idx, token in enumerate(sorted(set(all_tokens)), start=1)}
vocab_size = 1 + len(vocab)

# Encode the samples and pad everything to the longest input sequence.
X_indices, Y_indices = tokens_to_indices(samples, vocab)
max_len = max(map(len, X_indices))
X_pad = pad_sequences_torch(X_indices, max_len)
Y_pad = pad_sequences_torch(Y_indices, max_len)

class TextDataset(Dataset):
    """Dataset of aligned (input, target) items looked up by position."""

    def __init__(self, X, Y):
        # Both containers are stored as given; no copy is made.
        self.X, self.Y = X, Y

    def __len__(self):
        """Return the number of items, taken from ``X``."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(X[idx], Y[idx])`` pair."""
        item = (self.X[idx], self.Y[idx])
        return item

# Hold out 20 % of the padded samples, then split that hold-out evenly into
# validation and test parts. A fixed random_state is used for both splits.
X_train, X_temp, Y_train, Y_temp = train_test_split(
    X_pad,
    Y_pad,
    test_size=0.2,
    random_state=42,
)
X_val, X_test, Y_val, Y_test = train_test_split(
    X_temp,
    Y_temp,
    test_size=0.5,
    random_state=42,
)

# Make sure the directory for the CSV exports exists.
os.makedirs('/home/assistant/text-autocomplete/data', exist_ok=True)

def tensor_to_str_list(tensor):
    """Render every row of *tensor* as a space-separated string of its values.

    Each row must provide ``tolist()``. Returns one string per row.
    """
    rows = []
    for seq in tensor:
        values = seq.tolist()
        rows.append(' '.join(str(value) for value in values))
    return rows

# Export every split as CSV with space-separated index strings per row.
for csv_path, split_x, split_y in (
    ('/home/assistant/text-autocomplete/data/train.csv', X_train, Y_train),
    ('/home/assistant/text-autocomplete/data/val.csv', X_val, Y_val),
    ('/home/assistant/text-autocomplete/data/test.csv', X_test, Y_test),
):
    split_frame = pd.DataFrame({'X': tensor_to_str_list(split_x), 'Y': tensor_to_str_list(split_y)})
    split_frame.to_csv(csv_path, index=False, encoding='utf-8')

# Wrap the splits as datasets.
train_dataset = TextDataset(X_train, Y_train)
val_dataset = TextDataset(X_val, Y_val)
test_dataset = TextDataset(X_test, Y_test)

# Only the training loader shuffles its batches.
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

class LSTMModel(nn.Module):
    """Next-token language model: embedding -> LSTM -> linear vocabulary head."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        """Create the layers.

        Args:
            vocab_size: Number of embedding rows and output logits.
            embedding_dim: Size of each token embedding.
            hidden_dim: Size of the LSTM hidden state.
        """
        super().__init__()
        # Index 0 is declared as the padding index of the embedding.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Return per-position vocabulary logits for a batch of index sequences."""
        emb = self.embedding(x)
        lstm_out, _ = self.lstm(emb)
        return self.fc(lstm_out)

    def generate(self, tokenizer, vocab, seed_text, max_length=20):
        """Greedily extend *seed_text* token by token.

        Args:
            tokenizer: Object with a ``tokenize(str) -> list[str]`` method.
            vocab: Mapping token -> index; unknown tokens are encoded as 0.
            seed_text: Prompt; it is lower-cased before tokenization.
            max_length: Upper bound on loop iterations. Generation also stops
                once the TOTAL number of tokens (seed included) reaches
                ``max_length``, so at most one token is appended when the seed
                is already that long.

        Returns:
            The seed tokens followed by the generated tokens, joined with
            single spaces. An empty seed returns ``''`` because the LSTM
            cannot be run on a zero-length sequence.
        """
        self.eval()
        tokens = tokenizer.tokenize(seed_text.lower())
        if not tokens:
            return ''

        # Invert the vocabulary once instead of scanning it for every
        # generated token. setdefault keeps the first token per index, which
        # matches a linear scan over vocab.items().
        id_to_token = {}
        for tok, idx in vocab.items():
            id_to_token.setdefault(idx, tok)

        model_device = next(self.parameters()).device
        with torch.no_grad():
            for _ in range(max_length):
                x_idx = [vocab.get(token, 0) for token in tokens]
                x_tensor = torch.tensor([x_idx], dtype=torch.long, device=model_device)
                output = self.forward(x_tensor)
                # Logits at the last position predict the next token.
                next_id = torch.argmax(output[0, -1]).item()
                next_token = id_to_token.get(next_id)
                # Stop on an index with no token (e.g. 0) or on the separator.
                if next_token is None or next_token == '[SEP]':
                    break
                tokens.append(next_token)
                if len(tokens) >= max_length:
                    break
        return ' '.join(tokens)

# Model hyper-parameters.
embedding_dim, hidden_dim = 128, 256

model = LSTMModel(vocab_size, embedding_dim, hidden_dim)
model = model.to(device)

# Targets equal to 0 are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters())

# Per-epoch loss history, filled in by the training loop.
train_losses, val_losses = [], []

def train_epoch(model, dataloader, criterion, optimizer):
    """Run one optimisation pass over *dataloader*.

    Args:
        model: Module mapping an index batch to ``(..., vocab)`` logits.
        dataloader: Iterable of ``(x_batch, y_batch)`` tensor pairs.
        criterion: Loss called as ``criterion(flat_logits, flat_targets)``.
        optimizer: Optimizer over the model's parameters.

    Returns:
        The mean batch loss as a float, or ``nan`` when the loader yields no
        batches.
    """
    model.train()
    # Move batches to wherever the model lives instead of relying on a
    # module-level device variable.
    first_param = next(model.parameters(), None)
    target_device = None if first_param is None else first_param.device

    total_loss = 0.0
    num_batches = 0
    for x_batch, y_batch in dataloader:
        if target_device is not None:
            x_batch, y_batch = x_batch.to(target_device), y_batch.to(target_device)
        optimizer.zero_grad()
        outputs = model(x_batch)
        # Flatten every position into one row; the class count is read from
        # the logits themselves rather than from a global vocabulary size.
        flat_logits = outputs.reshape(-1, outputs.size(-1))
        flat_targets = y_batch.reshape(-1)
        loss = criterion(flat_logits, flat_targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        return float('nan')
    return total_loss / num_batches

def eval_epoch(model, dataloader, criterion):
    """Compute the mean batch loss over *dataloader* without updating weights.

    Args:
        model: Module mapping an index batch to ``(..., vocab)`` logits.
        dataloader: Iterable of ``(x_batch, y_batch)`` tensor pairs.
        criterion: Loss called as ``criterion(flat_logits, flat_targets)``.

    Returns:
        The mean batch loss as a float, or ``nan`` when the loader yields no
        batches.
    """
    model.eval()
    # Move batches to wherever the model lives instead of relying on a
    # module-level device variable.
    first_param = next(model.parameters(), None)
    target_device = None if first_param is None else first_param.device

    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for x_batch, y_batch in dataloader:
            if target_device is not None:
                x_batch, y_batch = x_batch.to(target_device), y_batch.to(target_device)
            outputs = model(x_batch)
            # The class count is read from the logits themselves rather than
            # from a global vocabulary size.
            flat_logits = outputs.reshape(-1, outputs.size(-1))
            flat_targets = y_batch.reshape(-1)
            total_loss += criterion(flat_logits, flat_targets).item()
            num_batches += 1

    if num_batches == 0:
        return float('nan')
    return total_loss / num_batches

# Train for a fixed number of epochs, recording both losses after each one.
epochs = 2
for epoch in range(epochs):
    train_loss = train_epoch(model, train_loader, criterion, optimizer)
    train_losses.append(train_loss)
    val_loss = eval_epoch(model, val_loader, criterion)
    val_losses.append(val_loss)
    print(f"Epoch {epoch+1}/{epochs} - train_loss: {train_loss:.4f}, val_loss: {val_loss:.4f}")

# Persist only the weights (state dict), not the whole module.
weights_path = "model_lstm_weights.pth"
torch.save(model.state_dict(), weights_path)
print("Model weights saved to model_lstm_weights.pth")

# Loss curves, one point per epoch (epochs are numbered from 1).
epoch_axis = range(1, epochs + 1)
plt.figure(figsize=(10, 6))
for curve, curve_label in ((train_losses, 'Train Loss'), (val_losses, 'Validation Loss')):
    plt.plot(epoch_axis, curve, label=curve_label)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

# distilgpt2 baseline: generation pipeline plus its own tokenizer.
generator_distilgpt2 = pipeline("text-generation", model="distilgpt2")
tokenizer_distilgpt2 = AutoTokenizer.from_pretrained("distilgpt2")

def split_text_for_completion(text):
    """Split *text* into a prompt and the continuation to be predicted.

    The text is tokenized with ``tokenizer_distilgpt2``; the first three
    quarters of the tokens (rounded down) form the prompt and the rest the
    reference continuation. Both parts are returned as strings.
    """
    pieces = tokenizer_distilgpt2.tokenize(text)
    boundary = (len(pieces) * 3) // 4
    to_string = tokenizer_distilgpt2.convert_tokens_to_string
    prompt = to_string(pieces[:boundary])
    continuation = to_string(pieces[boundary:])
    return prompt, continuation

def clean_text_for_distilgpt2(text):
    """Normalise a raw line before it is used with distilgpt2.

    Lower-cases the text, strips URLs and @-mentions, removes characters
    that are neither word characters nor whitespace, and collapses
    whitespace runs. Returns the stripped result.
    """
    lowered = text.lower()
    without_urls = re.sub(r'http\S+|www\S+', '', lowered)
    without_mentions = re.sub(r'@\w+', '', without_urls)
    words_only = re.sub(r'[^\w\sа-яёa-z]', '', without_mentions, flags=re.UNICODE)
    squeezed = re.sub(r'\s+', ' ', words_only)
    return squeezed.strip()

# Re-read the corpus for the comparison; the context manager closes the file.
# Only lines longer than 10 characters (after stripping) are used.
with open(input_file, encoding='utf-8') as f:
    texts = [clean_text_for_distilgpt2(t) for t in f if len(t.strip()) > 10]
# NOTE(review): this split is made on a differently filtered list than the
# one the LSTM was trained on, so these texts are presumably not guaranteed
# to be unseen by the LSTM — confirm before reporting the numbers.
_, val_texts = train_test_split(texts, test_size=0.2, random_state=42)

rouge1_lstm_scores, rouge2_lstm_scores = [], []
rouge1_gpt2_scores, rouge2_gpt2_scores = [], []

print("Оцениваем модели...")

for idx, text in enumerate(val_texts[:50]):
    input_text, ref_text = split_text_for_completion(text)

    # Token counts in the LSTM tokenizer's units. generate() lower-cases the
    # seed before tokenizing, so the seed is counted the same way here.
    seed_len = len(tokenizer_lstm.tokenize(input_text.lower()))
    ref_len = len(tokenizer_lstm.tokenize(ref_text))
    # Without a prompt or a reference there is nothing to generate or score.
    if seed_len == 0 or not input_text.strip() or not ref_text.strip():
        continue

    # generate() bounds the TOTAL length (seed included) and returns the seed
    # followed by the new tokens. Ask for seed + reference length and drop
    # the seed so only the completion is scored, as for distilgpt2 below.
    lstm_full = model.generate(tokenizer_lstm, vocab, input_text, max_length=seed_len + ref_len)
    lstm_pred = tokenizer_lstm.convert_tokens_to_string(lstm_full.split(' ')[seed_len:])
    r1_lstm, r2_lstm = compute_rouge(ref_text, lstm_pred)
    rouge1_lstm_scores.append(r1_lstm)
    rouge2_lstm_scores.append(r2_lstm)

    gpt2_out = generator_distilgpt2(input_text,
                                    max_length=len(tokenizer_distilgpt2.encode(input_text + ref_text)),
                                    do_sample=True, top_k=50, num_return_sequences=1)
    # The pipeline output starts with the prompt; keep only the continuation.
    gpt2_pred = gpt2_out[0]['generated_text'][len(input_text):].strip()
    r1_gpt2, r2_gpt2 = compute_rouge(ref_text, gpt2_pred)
    rouge1_gpt2_scores.append(r1_gpt2)
    rouge2_gpt2_scores.append(r2_gpt2)

    # Show the first few examples in full.
    if idx < 5:
        print(f"\nПример #{idx+1}")
        print("Вход:", input_text)
        print("Эталон:", ref_text)
        print(f"LSTM предсказание: {lstm_pred}")
        print(f"distilgpt2 предсказание: {gpt2_pred}")
        print(f"LSTM ROUGE-1: {r1_lstm:.3f}, ROUGE-2: {r2_lstm:.3f}")
        print(f"distilgpt2 ROUGE-1: {r1_gpt2:.3f}, ROUGE-2: {r2_gpt2:.3f}")

print(f"\nСреднее LSTM ROUGE-1: {np.mean(rouge1_lstm_scores):.3f}")
print(f"Среднее LSTM ROUGE-2: {np.mean(rouge2_lstm_scores):.3f}")
print(f"Среднее distilgpt2 ROUGE-1: {np.mean(rouge1_gpt2_scores):.3f}")
print(f"Среднее distilgpt2 ROUGE-2: {np.mean(rouge2_gpt2_scores):.3f}")

print("""
Выводы:
- distilgpt2 обычно создаёт более связный и качественный текст с лучшими ROUGE метриками.
- LSTM может быть полезна при ограниченных ресурсах или для специфичных задач.
- Для большинства задач автодополнения рекомендуется использовать трансформеры как distilgpt2.
""")


# Notebook output residue: the recorded run ended with "KeyboardInterrupt:".