In [None]:
import os
import re
import warnings
import random
from collections import defaultdict
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F
import numpy as np
import numpy as np
import torch
from tqdm.notebook import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer

warnings.filterwarnings("ignore")

In [None]:
def seed_everything(seed: int):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(42)

In [None]:
# NOTE(review): leftover inspection cell. Evaluating the bare class only echoes its repr and can be deleted.
GPT2Tokenizer

## Задание

1) Реализовать методы `greedy_sampling` и `generate` (1 балл)
2) Реализовать метод `random_sampling` и поддержать его в `generate` (1 балл)
3) Реализовать метод `_beam_search_generate` и поддержать его в `generate` (2 балла)
4) Реализовать методы `apply_top_p`, `apply_top_k`, `apply_temperature` и поддержать их в `generate` (1 балл)  
Все методы необходимо реализовать через векторные операции в torch/numpy везде, где это возможно.

In [None]:
class Model:
    """GPT-2 wrapper that implements several text-decoding strategies.

    `generate` supports the strategies ``"greedy"``, ``"random"``, ``"top_k"``,
    ``"top_p"`` and ``"beam_search"``. For every strategy except beam search,
    the logits are first converted to probabilities by `apply_temperature`.
    """

    def __init__(self, model_name: str = "gpt2"):
        """Load the pretrained LM-head model and its tokenizer.

        Args:
            model_name: Model identifier or local path passed to ``from_pretrained``.
        """
        self.model = GPT2LMHeadModel.from_pretrained(model_name)
        self.model.eval()  # inference only: keep dropout disabled
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        # GPT-2 ships without a pad token; reuse EOS so padding-aware tokenizer calls work.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.vocab_size = self.tokenizer.vocab_size

    def greedy_sampling(self, probas: torch.Tensor) -> int:
        """Return the id of the most probable token.

        Args:
            probas: 1-D tensor of next-token probabilities.
        """
        return int(torch.argmax(probas))

    def random_sampling(self, probas: torch.Tensor) -> int:
        """Draw one token id from the categorical distribution `probas`.

        Args:
            probas: 1-D tensor of non-negative next-token weights.
        """
        # `num_samples` is a required argument of torch.multinomial.
        return int(torch.multinomial(probas, num_samples=1))

    def _beam_search_generate(
        self,
        prompt: str,
        max_length: int,
        num_beams: int
    ) -> torch.Tensor:
        """Decode `prompt` with beam search.

        Hypotheses are scored by their summed token log-probabilities, which ranks
        them the same way as a product of probabilities but cannot underflow.
        A hypothesis that emits EOS is finished and stops growing.

        Args:
            prompt: Text to continue.
            max_length: Maximum number of new tokens to generate.
            num_beams: Number of live hypotheses kept per step. It is also the
                number of continuations tried for each of them.

        Returns:
            Token ids of the best hypothesis (prompt included), shape ``(1, seq_len)``.
        """
        prompt_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"]
        eos_id = self.tokenizer.eos_token_id

        # Live beams always share one length, so they are kept as a single batch.
        beam_ids = prompt_ids            # (n_live, seq_len)
        beam_scores = torch.zeros(1)     # (n_live,) summed log-probabilities
        finished = []                    # (ids of shape (1, seq_len), float score), ending in EOS

        with torch.no_grad():
            for _ in range(max_length):
                logits = self.model(beam_ids)["logits"][:, -1, :]
                log_probas = F.log_softmax(logits, dim=-1)
                # Expand every live beam with its `num_beams` most likely continuations.
                top_log_probas, top_tokens = torch.topk(log_probas, num_beams, dim=-1)

                parents = torch.arange(beam_ids.shape[0]).repeat_interleave(num_beams)
                cand_tokens = top_tokens.reshape(-1)
                cand_scores = (beam_scores[:, None] + top_log_probas).reshape(-1)
                cand_ids = torch.cat([beam_ids[parents], cand_tokens[:, None]], dim=1)

                # Candidates that just emitted EOS are complete.
                is_eos = cand_tokens == eos_id
                finished.extend(
                    (ids[None], score)
                    for ids, score in zip(cand_ids[is_eos], cand_scores[is_eos].tolist())
                )

                live_ids = cand_ids[~is_eos]
                live_scores = cand_scores[~is_eos]
                order = torch.argsort(live_scores, descending=True)[:num_beams]
                beam_ids, beam_scores = live_ids[order], live_scores[order]
                if beam_ids.shape[0] == 0:
                    break
                # Scores never increase as sequences grow (log p <= 0). Once the best
                # finished hypothesis matches or beats the best live beam, no live
                # beam can overtake it.
                if finished and max(score for _, score in finished) >= beam_scores[0].item():
                    break

        hypotheses = [
            (beam_ids[i:i + 1], beam_scores[i].item()) for i in range(beam_ids.shape[0])
        ]
        hypotheses.extend(finished)
        return max(hypotheses, key=lambda hyp: hyp[1])[0]

    def apply_temperature(self, logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
        """Turn logits into probabilities rescaled by `temperature`.

        Temperatures below 1 sharpen the distribution and temperatures above 1
        flatten it. ``temperature == 0`` is treated as its limit, a one-hot
        distribution on the arg-max token.

        Raises:
            ValueError: If `temperature` is negative.
        """
        if temperature < 0:
            raise ValueError(f"temperature must be non-negative, got {temperature}")
        if temperature == 0:
            best = torch.argmax(logits, dim=-1)
            return F.one_hot(best, num_classes=logits.shape[-1]).to(logits.dtype)
        return F.softmax(logits / temperature, dim=-1)

    def _apply_top_p(self, probas: torch.Tensor, top_p: float = 1.0) -> torch.Tensor:
        """Nucleus sampling: sample from the smallest set of tokens whose mass reaches `top_p`.

        The most likely token is always kept, so ``top_p <= 0`` reduces to greedy.

        Args:
            probas: 1-D tensor of next-token probabilities.
            top_p: Cumulative probability threshold.

        Returns:
            Sampled token id as a 1-element tensor.
        """
        sorted_probas, sorted_ids = torch.sort(probas, descending=True)
        cumulative = torch.cumsum(sorted_probas, dim=-1)
        # Keep the tokens whose running mass is still below top_p, plus the token that
        # reaches it. `cumulative` is non-decreasing, so this count is a prefix length.
        n_keep = min(int((cumulative < top_p).sum()) + 1, sorted_probas.shape[-1])
        nucleus = sorted_probas[:n_keep]
        choice = torch.multinomial(nucleus / nucleus.sum(), num_samples=1)
        return sorted_ids[choice]

    def _apply_top_k(self, probas: torch.Tensor, top_k: int = None) -> torch.Tensor:
        """Sample a token from the `top_k` most probable ones.

        Args:
            probas: 1-D tensor of next-token probabilities.
            top_k: Number of candidates to keep. ``None`` or a value ``<= 0``
                (the `generate` default is 0) disables the filter.

        Returns:
            Sampled token id as a 1-element tensor.
        """
        n_tokens = probas.shape[-1]
        k = n_tokens if top_k is None or top_k <= 0 else min(int(top_k), n_tokens)
        top_probas, top_ids = torch.topk(probas, k)
        choice = torch.multinomial(top_probas / top_probas.sum(), num_samples=1)
        return top_ids[choice]

    def generate(
        self,
        prompt: str,
        max_length: int = 50,
        strategy: str = "greedy",
        temperature: float = 1.0,
        top_k: int = 0,
        top_p: float = 1.0,
        num_beams: int = 3
    ) -> str:
        """Continue `prompt` with the chosen decoding strategy.

        Args:
            prompt: Text to continue.
            max_length: Maximum number of new tokens.
            strategy: One of ``"greedy"``, ``"random"``, ``"top_k"``, ``"top_p"``,
                ``"beam_search"``.
            temperature: Softmax temperature (ignored by beam search).
            top_k: Candidate count for ``"top_k"``; ``0`` means no filtering.
            top_p: Mass threshold for ``"top_p"``.
            num_beams: Beam width for ``"beam_search"``.

        Returns:
            Decoded text, prompt included. Generation stops after EOS is produced.

        Raises:
            ValueError: If `strategy` is unknown.
        """
        if strategy == "beam_search":
            token_ids = self._beam_search_generate(prompt, max_length, num_beams)
            return self.tokenizer.decode(token_ids[0])

        samplers = {
            "greedy": self.greedy_sampling,
            "random": self.random_sampling,
            "top_k": lambda probas: self._apply_top_k(probas, top_k),
            "top_p": lambda probas: self._apply_top_p(probas, top_p),
        }
        if strategy not in samplers:
            raise ValueError(f"unknown strategy {strategy!r}")
        pick_next = samplers[strategy]

        input_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"]
        eos_id = self.tokenizer.eos_token_id

        with torch.no_grad():
            for _ in range(max_length):
                logits = self.model(input_ids)["logits"][-1, -1]
                probas = self.apply_temperature(logits, temperature)
                next_id = int(pick_next(probas))
                input_ids = torch.cat([input_ids, torch.tensor([[next_id]])], dim=1)
                # Compare ids, not the EOS *string*, so generation actually stops.
                if next_id == eos_id:
                    break

        return self.tokenizer.decode(input_ids[0])


In [None]:
# Load the pretrained GPT-2 weights and tokenizer once; the demo cells below reuse `gpt`.
gpt = Model(model_name="gpt2")

In [None]:
# Shared prompt for all decoding demos below.
prompt: str = "Quantum physics is"

In [None]:
prompt = "Quantum physics is"
print("GREEDY:\n")
gpt.generate(prompt, strategy="greedy")

GREEDY:



'Quantum physics is a very interesting field. It is a field that has been studied for many years, and it is very interesting to see how it is being studied in the field of quantum physics.\n\nThe first thing that I want to say is that I am'

In [None]:
# Top-k sampling restricted to the 5 most likely tokens.
prompt = "Quantum physics is"
print("TOP-5:\n")
gpt.generate(prompt=prompt, top_k=5, strategy="top_k")

TOP-5:



'Quantum physics is a complex system that is not only a physical theory of the universe but also a physics system. In this article I am going to talk about the quantum mechanics of quantum mechanics.\n\nQM: Quantum mechanics has a long history in mathematics. It'

In [None]:
# Top-k sampling with a wider pool of 50 candidates.
prompt = "Quantum physics is"
print("TOP-50:\n")
gpt.generate(prompt=prompt, top_k=50, strategy="top_k")

TOP-50:



'Quantum physics is a quantum physics experiment consisting of interacting forces acting in a variety of directions and interacting with other particles in the environment. The experiment is controlled with a quantum computer, a quantum magnet and a quantum field of quantum electrodynamics. Scientists have demonstrated that'

In [None]:
# Nucleus (top-p) sampling with a very small mass threshold of 0.1.
prompt = "Quantum physics is"
print("TOP-0.1:\n")
gpt.generate(prompt=prompt, top_p=0.1, strategy="top_p")

TOP-0.1:



'Quantum physics is a fundamental physics of the universe. It is the fundamental physics of the universe that is the basis of all the laws of physics.\n\nThe Universe is a very complex and complex system. It is a very complex system that is the basis of all'

In [None]:
# Nucleus (top-p) sampling with p = 0.8, matching the printed label.
prompt = "Quantum physics is"
print("TOP-0.8:\n")
gpt.generate(prompt, strategy="top_p", top_p=0.8)

TOP-0.8:



'Quantum physics is a very interesting field. It is a field that has been studied for many years, and it is very interesting to see how it is being studied in the field of quantum physics.\n\nThe field of quantum physics is a very interesting field. It'

In [None]:
# Low temperature sharpens the distribution before nucleus sampling.
print("TEMPERATURE=0.4")
gpt.generate(prompt, temperature=0.4, top_p=0.8, strategy="top_p")

TEMPERATURE=0.4


'Quantum physics is a new field of research that has been gaining momentum since the early 1990s. It is now being applied to the design of quantum computers, and to the design of quantum computers in general.\n\nThe new field of research is called quantum gravity.'

In [None]:
# High temperature flattens the distribution before nucleus sampling.
print("TEMPERATURE=2")
gpt.generate(prompt, temperature=2, top_p=0.8, strategy="top_p")

TEMPERATURE=2


'Quantum physics is coated\n\n - Estimation previously reserved<|endoftext|>2014 Pan World Slayer Annual countdown images Awards WW Media 181 widely disregran tentiwizard nomination XIII to stable total rising 104 to sale 110 posted test compact pap editing does Gaia Earth age colonization keeps Sacred The capital 15'

In [None]:
# Beam search keeping 3 hypotheses per step.
print("BEAM=3")
gpt.generate(prompt, num_beams=3, strategy="beam_search")

BEAM=3


'Quantum physics is not a new concept, but it has been around for a long time.\n\nIn fact, quantum mechanics is the most widely accepted theory of the universe.\n\nIn fact, quantum mechanics is the most widely accepted theory of the universe.\n'

In [None]:
# Demonstrate the output of `generate` with various parameters