In [1]:
import torch
import sentencepiece as spm
from transformer_copy import Transformer
import torch.nn.functional as F

In [2]:
class InferBottle:
    def __init__(self, model_path="../model/best_model.pth", device="cuda" if torch.cuda.is_available() else "cpu", vocab_size=50000,
                 en_tokenizer_path="../data/en_spm.model", tr_tokenizer_path="../data/tr_spm.model", beam_width=5, max_len=500, sampling_method="all"):

        sp_en = spm.SentencePieceProcessor()
        sp_en.load(en_tokenizer_path)
        self.en_tokenizer = sp_en

        sp_tr = spm.SentencePieceProcessor()
        sp_tr.load(tr_tokenizer_path)
        self.tr_tokenizer = sp_tr

        self.vocab_size = vocab_size
        self.model = Transformer(vocab_size=self.vocab_size, embed_dim=512, ff_dim=2048, num_heads=8, n_encoders=5, n_decoders=5)
        self.model.load_state_dict(torch.load(model_path, map_location=device))
        self.model.to(device)
        self.model.eval()

        self.device = device
        self.sampling_method = sampling_method
        self.beam_width = beam_width
        self.max_len = max_len

        self.BOS_ID = sp_tr.bos_id()
        self.EOS_ID = sp_tr.eos_id()
        self.PAD_ID = sp_en.pad_id() if sp_en.pad_id() != -1 else 0

        assert sampling_method in ["greedy", "top_p", "beam_search", "all"], "Invalid sampling method. Choose one of 'greedy', 'top_p', 'beam_search', or 'all'."

    def translate(self, text):
        input_ids = self.en_tokenizer.encode(text, out_type=int)
        input_ids = torch.tensor(input_ids).unsqueeze(0).to(self.device)
        src_key_padding_mask = (input_ids == self.PAD_ID)

        def decode_sequence(decoded_ids):
            return self.tr_tokenizer.decode(decoded_ids)

        if self.sampling_method in ["greedy", "all"]:
            decoded_ids = self._greedy_decode(input_ids, src_key_padding_mask)
            greedy_output = decode_sequence(decoded_ids)

        if self.sampling_method in ["top_p", "all"]:
            decoded_ids = self._top_p_decode(input_ids, src_key_padding_mask)
            top_p_output = decode_sequence(decoded_ids)

        if self.sampling_method in ["beam_search", "all"]:
            decoded_ids = self.beam_search(input_ids, src_key_padding_mask)
            beam_output = decode_sequence(decoded_ids)

        if self.sampling_method == "greedy":
            return greedy_output
        elif self.sampling_method == "top_p":
            return top_p_output
        elif self.sampling_method == "beam_search":
            return {
                "text": text,
                "beam_search": beam_output
            }
        elif self.sampling_method == "all":
            return {
                "text": text,
                "greedy": greedy_output,
                "top_p": top_p_output,
                "beam_search": beam_output
            }

    def _greedy_decode(self, input_ids, src_key_padding_mask):
        decoded_ids = [self.BOS_ID]

        repetition_counter = {}
        for _ in range(self.max_len):
            tgt_input = torch.tensor(decoded_ids).unsqueeze(0).to(self.device)
            tgt_key_padding_mask = (tgt_input == 0)

            with torch.no_grad():
                logits = self.model(
                    input_ids,
                    tgt_input,
                    src_key_padding_mask=src_key_padding_mask,
                    tgt_key_padding_mask=tgt_key_padding_mask
                )[0, -1, :]

            next_token = self.greedy_sampling(logits) # simply taking the argmax over the logits

            if next_token == self.EOS_ID:
                break

            decoded_ids.append(next_token)

            repetition_counter[next_token] = repetition_counter.get(next_token, 0) + 1
            if repetition_counter[next_token] >= 3:
                break

        return decoded_ids[1:]

    def _top_p_decode(self, input_ids, src_key_padding_mask):
        decoded_ids = [self.BOS_ID]

        repetition_counter = {}
        for _ in range(self.max_len):
            tgt_input = torch.tensor(decoded_ids).unsqueeze(0).to(self.device)
            tgt_key_padding_mask = (tgt_input == 0)

            with torch.no_grad():
                logits = self.model(
                    input_ids,
                    tgt_input,
                    src_key_padding_mask=src_key_padding_mask,
                    tgt_key_padding_mask=tgt_key_padding_mask
                )[0, -1, :]

            for token in set(decoded_ids):  # to discourage repetition, can be disabled for more advanced models
                logits[token] /= 1.2

            next_token = self.top_p_sample(logits, p=0.9, temperature=0.8) # random sampling from the top-p cumulative probability mass (0.9)
            if next_token == self.EOS_ID:
                break
            decoded_ids.append(next_token)

            repetition_counter[next_token] = repetition_counter.get(next_token, 0) + 1
            if repetition_counter[next_token] >= 3:
                break

        return decoded_ids[1:]

    def beam_search(self, input_ids, src_key_padding_mask):
        beams = [([self.BOS_ID], 0.0, False)]
        for _ in range(self.max_len):
            candidates = []

            for seq, score, has_ended in beams:
                if has_ended:
                    candidates.append((seq, score, True))
                    continue

                tgt_input = torch.tensor(seq).unsqueeze(0).to(self.device)
                tgt_key_padding_mask = (tgt_input == 0)

                with torch.no_grad():
                    logits = self.model(
                        input_ids,
                        tgt_input,
                        src_key_padding_mask=src_key_padding_mask,
                        tgt_key_padding_mask=tgt_key_padding_mask
                    )[0, -1, :]

                for token in set(seq):
                    logits[token] /= 1.3  # reduce the probability of already generated tokens

                probs = F.log_softmax(logits, dim=-1) # log softmax for numerical stability
                topk_probs, topk_indices = torch.topk(probs, self.beam_width)

                for log_prob, token_id in zip(topk_probs.tolist(), topk_indices.tolist()):
                    new_seq = seq + [token_id]
                    new_score = score + log_prob
                    ended = token_id == self.EOS_ID
                    candidates.append((new_seq, new_score, ended))

            beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:self.beam_width] # keep the top-beam_width sequences

            if all(ended for _, _, ended in beams): # break if all sequences have ended
                break

        best_seq = beams[0][0]
        if best_seq[0] == self.BOS_ID:
            best_seq = best_seq[1:]
        if best_seq and best_seq[-1] == self.EOS_ID:
            best_seq = best_seq[:-1]
        return best_seq

    @staticmethod
    def greedy_sampling(logits):
        return torch.argmax(logits, dim=-1).item()

    @staticmethod
    def top_p_sample(logits, p=0.9, temperature=1.0):
        logits = logits / temperature # temperature scaling to control randomness/exploration
        probs = F.softmax(logits, dim=-1)

        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        sorted_mask = cumulative_probs > p
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone() # shift the mask to the right to include the first token where cumulative_probs >= p
        sorted_mask[..., 0] = False # to always keep the first token

        sorted_probs[sorted_mask] = 0
        sorted_probs /= sorted_probs.sum() # normalizing the probabilities

        next_token = sorted_indices[torch.multinomial(sorted_probs, 1)] # randomly sampling from the distribution
        return next_token.item()


In [3]:
infer = InferBottle(model_path="../model/best_model.pth", sampling_method="beam_search")

In [4]:
outputs = infer.translate("I’m tired, but I can’t sleep.")
for method, text in outputs.items():
    print(f"{method}: {text}")
print("\n")

outputs = infer.translate("She left the window open, and now the whole room is freezing.")
for method, text in outputs.items():
    print(f"{method}: {text}")
print("\n")

outputs = infer.translate("We don’t always get what we want, but we learn to live with it.")
for method, text in outputs.items():
    print(f"{method}: {text}")
print("\n")

outputs = infer.translate("Even after all these years, he still remembered the smell of her perfume.")
for method, text in outputs.items():
    print(f"{method}: {text}")
print("\n")

outputs = infer.translate("The scientist paused, uncertain whether to publish results that might change everything.")
for method, text in outputs.items():
    print(f"{method}: {text}")
print("\n")

outputs = infer.translate("They told him it was impossible, but he did it anyway — and now the world knows his name.")
for method, text in outputs.items():
    print(f"{method}: {text}")
print("\n")

outputs = infer.translate("In the silence that followed, no one dared to speak.")
for method, text in outputs.items():
    print(f"{method}: {text}")
print("\n")

outputs = infer.translate("They couldn't believe their eyes. What they saw was unimaginable.")
for method, text in outputs.items():
    print(f"{method}: {text}")
print("\n")

text: I’m tired, but I can’t sleep.
beam_search: Ben yorgun, ama uyuya Uyamıyorum.


text: She left the window open, and now the whole room is freezing.
beam_search: Pencere açık bıraktı ve şimdi bütün oda donma. O, o da onu terk


text: We don’t always get what we want, but we learn to live with it.
beam_search: Her zaman istediğimizi almıyoruz, ama onunla yaşamayı öğreniyoruz.


text: Even after all these years, he still remembered the smell of her perfume.
beam_search: Tüm bu yıl sonra, hala onun parfüm kokusunu hatırdı.


text: The scientist paused, uncertain whether to publish results that might change everything.
beam_search: bilim adamı, her şeyi değiştirebilecek sonuçların yayınlanıp yayımlanmaya durakladı.


text: They told him it was impossible, but he did it anyway — and now the world knows his name.
beam_search: Ona imkansız olduğunu söylediler, ama yine de bunu yaptı ve şimdi dünya adını biliyor.


text: In the silence that followed, no one dared to speak.
beam_search: Tak

In [5]:
outputs = infer.translate("I wanted to leave, but I stayed.")
for method, text in outputs.items():
    print(f"{method}: {text}")
print("\n")

outputs = infer.translate("After the meeting ended, everyone left the room.")
for method, text in outputs.items():
    print(f"{method}: {text}")
print("\n")

outputs = infer.translate("If you had told me earlier, I would have helped.")
for method, text in outputs.items():
    print(f"{method}: {text}")
print("\n")

outputs = infer.translate("The results were announced yesterday.")
for method, text in outputs.items():
    print(f"{method}: {text}")
print("\n")

outputs = infer.translate("She turned a blind eye to the whole situation.")
for method, text in outputs.items():
    print(f"{method}: {text}")
print("\n")

outputs = infer.translate("The book that he recommended was surprisingly good.")
for method, text in outputs.items():
    print(f"{method}: {text}")
print("\n")

outputs = infer.translate("They said he might be coming, but they weren’t sure.")
for method, text in outputs.items():
    print(f"{method}: {text}")
print("\n")

outputs = infer.translate("It’s hard to say whether what he did was right or wrong.")
for method, text in outputs.items():
    print(f"{method}: {text}")
print("\n")

text: I wanted to leave, but I stayed.
beam_search: Ben ayrılmak istedim, ama ben kaldım Kaldım.


text: After the meeting ended, everyone left the room.
beam_search: toplantı sona erdi, herkes odadan ayrıldıktan sonra gitti. Toplantı bittiğinde herkesi odadan ayrıldı


text: If you had told me earlier, I would have helped.
beam_search: Bana bana daha önce bana daha erken söylersen, yardım ederdim. Yardım olurdu.


text: The results were announced yesterday.
beam_search: sonuçları dün açıklandı. Sonuçlar duyuruldu.


text: She turned a blind eye to the whole situation.
beam_search: O bütün duruma göz kör bir göze döndü. O, o Gözü tamamen gitti


text: The book that he recommended was surprisingly good.
beam_search: Onun tavsiye ettiği kitap şaşırtıcı bir şekilde iyiydi. Tavsiye eden o şiddetle tavsiye ederim


text: They said he might be coming, but they weren’t sure.
beam_search: Onlar geliyor olabileceğini söyledi, ama emin değil dediler.


text: It’s hard to say whether what he did 