In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader, Subset
import torchaudio.datasets as datasets
from sklearn.model_selection import train_test_split

from transformer import TransformerLanguageModel

from tokenizer.BPE import tokenize, tokenizer
import pickle as pkl
import os
import glob

### configs

In [9]:
current_dir = os.getcwd()

merges_path = os.path.join(current_dir, "tokenizer", "merges.pkl")
vocab_path = os.path.join(current_dir, "tokenizer", "vocabulary.pkl")

# Загрузка merges.pkl
with open(merges_path, "rb") as f:
    merges = pkl.load(f)
    print("Загрузка merges.pkl успешна")

# Загрузка vocab.pkl
with open(vocab_path, "rb") as f:
    vocab = pkl.load(f)
    print("Загрузка vocabulary.pkl успешна")

text = 'HELLO MY NAME IS BILL'
tokens = [vocab[0]] + tokenize(text, merges) + [vocab[1]]

token_to_id = {vocab[i]: i for i in range(len(vocab))}
id_to_token = {i: vocab[i] for i in range(len(vocab))}
PAD_ID = 2

config = {
    'dim_feedforward': 64,
    'num_heads': 8,
    'num_layers': 8,
    'learning_rate': 0.001,
    'batch_size': 64,
    'epochs': 256,
    'embedding_dim': 64,
    'dataset': "LibriSpeech dev-clean",
    'vocab_size': len(vocab),
}

vocab_size = config['vocab_size']
embedding_dim = config['embedding_dim']
dim_feedforward = config['dim_feedforward']
num_heads = config['num_heads']
num_layers = config['num_layers']
num_epochs = config['epochs']

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = TransformerLanguageModel(vocab_size, embedding_dim, num_heads, dim_feedforward, num_layers).to(device)

Загрузка merges.pkl успешна
Загрузка vocabulary.pkl успешна


### Dataset

In [11]:
data = datasets.LIBRISPEECH("../data", url="dev-clean", )
corpus = []
for i in range(2800):
    try:
        corpus.append(list(map(lambda x: token_to_id[x], [vocab[0]] + tokenize(data.__getitem__(i)[2], merges) + [vocab[1]])))
    except IndexError as err:
        break

max_length = max(len(seq) for seq in corpus)
print(max_length)
class TextDataset(Dataset):
    def __init__(self, data, max_len):
        self.data = data
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = torch.tensor(self.data[idx], dtype=torch.int16)
        sample = sample[:self.max_len]
        length = sample.shape[-1]
        padding = torch.ones((self.max_len - sample.shape[-1])) * 2
        sample = torch.cat((sample, padding), dim=0)
        return torch.tensor(sample, dtype=torch.float), length

dataset = TextDataset(corpus, max_length)

train_indices, val_indices = train_test_split(list(range(len(dataset))), test_size=0.2)

# Создание тренировочного и валидационного датасетов
train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)

# Создание DataLoader-ов для тренировочного и валидационного датасетов
train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=config["batch_size"])

294


### Coding

In [6]:
def length_to_mask(inputs, lengths, dtype=None):
    batch_size = lengths.size(0)
    seq_len = inputs.size(1) if isinstance(inputs, torch.Tensor) else inputs
    # Causal mask [seq_len, seq_len]
    tgt_mask = torch.triu(torch.ones((seq_len, seq_len), device=device)).transpose(0, 1)
    tgt_mask = tgt_mask.float().masked_fill(tgt_mask == 0, float('-inf')).masked_fill(tgt_mask == 1, float(0.0))

    # Padding mask [batch_size, seq_len]
    key_padding_mask = (torch.arange(seq_len, device=device).expand(batch_size, seq_len) >= lengths.unsqueeze(1))

    if dtype is not None:
        key_padding_mask = key_padding_mask.to(dtype=dtype)

    return tgt_mask, key_padding_mask

In [24]:
def beam_search_decode(model, input_seq, lengths, beam_width=5, max_len=100, device='cuda'):
    model.eval()
    with torch.no_grad():
        input_seq = input_seq.unsqueeze(0).to(device)
        lengths = (torch.tensor(lengths).to(device) - 1).unsqueeze(0)

        # Начальная последовательность (BOS/start token)
        sequences = [[(input_seq[0][:1].tolist(), 0.0)]]  # [(tokens, score)]

        for _ in range(max_len):
            all_candidates = []
            for seq, score in sequences[-1]:
                if seq[-1] == token_to_id["<|endoftext|>"]:  # завершённая гипотеза
                    all_candidates.append((seq, score))
                    continue

                seq_tensor = torch.tensor(seq, dtype=torch.long).unsqueeze(0).to(device)
                lengths_tensor = torch.tensor([len(seq)]).to(device)

                tgt_mask, tgt_key_padding_mask = length_to_mask(len(seq), lengths_tensor)
                tgt_mask, tgt_key_padding_mask = tgt_mask.to(device), tgt_key_padding_mask.to(device)

                outputs = model(seq_tensor, tgt_mask=tgt_mask, lengths=lengths_tensor, tgt_key_padding_mask=tgt_key_padding_mask)
                next_token_logits = outputs[0, -1, :]  # последний токен

                probs = F.log_softmax(next_token_logits, dim=-1)
                topk_probs, topk_indices = probs.topk(beam_width)

                for i in range(beam_width):
                    next_seq = seq + [topk_indices[i].item()]
                    next_score = score + topk_probs[i].item()
                    all_candidates.append((next_seq, next_score))

            # Выбираем beam_width лучших гипотез
            ordered = sorted(all_candidates, key=lambda tup: tup[1], reverse=True)
            sequences.append(ordered[:beam_width])

            # Если все последовательности закончены
            if all(seq[-1] == token_to_id["<|endoftext|>"] for seq, _ in sequences[-1]):
                break

        final_seq = sequences[-1][0][0]
        return final_seq

In [25]:
i = 5
input_data = train_dataset[i][0]
lengths = train_dataset[i][1]

model_pattern = os.path.join("../best_models/transformer", "model_244_*")
model_files = glob.glob(model_pattern)

if not model_files:
    raise FileNotFoundError(f"Файл не найден по шаблону: {model_pattern}")

model_path = model_files[0]

checkpoint = torch.load(model_path, map_location=device)

# 3. Применение весов
model.load_state_dict(checkpoint)

  return torch.tensor(sample, dtype=torch.float), length


<All keys matched successfully>

In [43]:

# Декодируем
output_ids = beam_search_decode(model, input_data, lengths, beam_width=5, max_len=100, device=device)

# Преобразуем в текст
tokens = [id_to_token[tok] for tok in output_ids]
text = ''.join(tokens[1:-1]).replace("Ġ", " ")
print(text)

I DON'T KNOW
