# Assignment 4: Decoding Algorithms (Section 3: NeuroLogic)
  
## Section 0: Setup

Please run all the code blocks in this section. You don't need to implement or change anything here.

In [None]:
!pip install transformers
!pip install datasets
!pip install evaluate

In [None]:
"""set device and random seeds"""

######################################################
#  The following helper functions are given to you.
######################################################

import torch
import random

def set_seed(seed=19260817):
    """Seed Python's and PyTorch's random number generators and fix the cuDNN flags.

    :param seed: integer seed handed to every generator
    """
    # Python RNG first, then the PyTorch default and CUDA generators.
    for seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed()

### 0.1 Load dataset

In [None]:
"""load datasets"""

######################################################
#  The following helper code is given to you.
######################################################

from datasets import load_dataset

# Load the 'Ximing/ROCStories' dataset through `datasets.load_dataset`.
dataset = load_dataset('Ximing/ROCStories')
# Keep a handle on each of the three splits the rest of the notebook refers to.
train_data, dev_data, test_data = dataset['train'], dataset['validation'], dataset['test']

# Show the first training example as a quick look at the record layout.
print(train_data[0])

### 0.2 Define Evaluation models and metrics

In [None]:
"""prepare evaluation"""

######################################################
#  The following helper code is given to you.
######################################################

import torch
import torch.nn.functional as f
from evaluate import load

from transformers import RobertaForSequenceClassification, RobertaTokenizer

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'device: {device}')

# Perplexity metric from the `evaluate` library (used by compute_perplexity).
perplexity_scorer = load("perplexity", module_type="metric")
# RoBERTa sequence classifier and its tokenizer (used by compute_fluency);
# the classifier is moved to `device`, so its inputs must be moved there too.
cola_model_name = "textattack/roberta-base-CoLA"
cola_tokenizer = RobertaTokenizer.from_pretrained(cola_model_name)
cola_model = RobertaForSequenceClassification.from_pretrained(cola_model_name).to(device)

def batchify(data, batch_size):
    """Lazily split `data` into consecutive lists of at most `batch_size` items.

    Every yielded list holds exactly `batch_size` items except possibly the last
    one, which holds whatever is left over. Nothing is yielded for empty input.

    :param data: any iterable
    :param batch_size: maximum number of items per batch, must be positive
    """
    assert batch_size > 0

    pending = []
    for element in data:
        if len(pending) < batch_size:
            pending.append(element)
            continue

        # The batch is full: hand it out and start the next one with this element.
        yield pending
        pending = [element]

    # Whatever remains (a full or a partial batch) is emitted last.
    if pending:
        yield pending

In [None]:
"""set up evaluation metric"""

######################################################
#  The following helper code is given to you.
######################################################

def compute_perplexity(texts, model='gpt2', batch_size=8):
    """Return the mean perplexity that `perplexity_scorer` reports for `texts`.

    :param texts: list of strings to score
    :param model: model id forwarded to the scorer as `model_id`
    :param batch_size: batch size forwarded to the scorer
    :return: the `mean_perplexity` entry of the scorer's result
    """
    scorer_kwargs = dict(
        predictions=texts,
        add_start_token=False,
        batch_size=batch_size,
        model_id=model,
    )
    result = perplexity_scorer.compute(**scorer_kwargs)
    return result['mean_perplexity']

def compute_fluency(texts, batch_size=8):
    """Return the mean class-1 probability that `cola_model` assigns to `texts`.

    The texts are scored `batch_size` at a time: each batch is tokenized with
    `cola_tokenizer`, run through `cola_model` without gradients, and the softmax
    probability at class index 1 is collected for every text in the batch
    (presumably the "acceptable" label of the CoLA classifier; confirm against
    the model card).

    :param texts: list of strings to score
    :param batch_size: number of texts sent through the model per forward pass
    :return: average of the per-text class-1 probabilities
    :raises ZeroDivisionError: if `texts` is empty
    """
    scores = []
    for b_texts in batchify(texts, batch_size):
        # Tokenize only the current batch; passing the full `texts` here would
        # re-score the whole corpus once per batch and make batching pointless.
        inputs = cola_tokenizer(b_texts, padding=True, truncation=True, return_tensors="pt").to(device)
        with torch.no_grad():
            logits = cola_model(**inputs).logits
            probs = logits.softmax(dim=-1)
            scores.extend(probs[:, 1].tolist())
    return sum(scores) / len(scores)

def compute_diversity(texts):
    """Return the distinct-1/2/3 ratios of `texts`.

    Each text is split on single spaces. The result is the number of distinct
    unigrams, bigrams and trigrams (collected over all texts) divided by the
    total number of space-separated tokens.

    :param texts: list of strings
    :return: tuple (distinct-1, distinct-2, distinct-3)
    """
    distinct = {1: set(), 2: set(), 3: set()}
    token_count = 0
    for text in texts:
        tokens = text.split(' ')
        token_count += len(tokens)
        for order, bucket in distinct.items():
            # Slide a window of `order` tokens over the text; '_' joins the window.
            bucket.update(
                '_'.join(tokens[start:start + order])
                for start in range(len(tokens) - order + 1)
            )
    return tuple(len(distinct[order]) / token_count for order in (1, 2, 3))


# Smoke test: run each of the three metrics once on a few hand-written sentences.
test_sents = ["This restaurant is awesome", "My dog is cute and I love it.", "Today is sunny."]
print(compute_perplexity(test_sents))  # mean perplexity
print(compute_fluency(test_sents))  # mean class-1 probability from the CoLA classifier
print(compute_diversity(test_sents))  # (distinct-1, distinct-2, distinct-3)

## 3.1 Neurologic Implementation

### Configurations: load model and tokenizer

In [None]:
"""load model and tokenizer"""

######################################################
#  The following helper code is given to you.
######################################################

from dataclasses import dataclass
from typing import List, Tuple

import copy
import torch
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GenerationConfig

model_name = 'gpt2'

# Language model used for decoding, moved to `device` and switched to eval mode.
model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
model.eval()
# The end-of-text string is registered as the pad token, and prompts are padded on the left
# (presumably so that generation continues right after each prompt's last token - confirm).
tokenizer = GPT2Tokenizer.from_pretrained(model_name, pad_token="<|endoftext|>")
tokenizer.padding_side = "left"

# Token ids read by the search code below.
pad_token_id = tokenizer.pad_token_id
eos_token_id = tokenizer.eos_token_id

### Helper classes: `Word`, `ConstrainedHypothesis` and `ConstrainedCandidate`

In [None]:
######################################################
#  The following helper code is given to you.
######################################################

class Word:
    """A single-token keyword constraint together with its satisfaction flag."""

    def __init__(self, idx: int, word_token: int):
        """
        :param idx: index id for this keyword
        :param word_token: token of the keyword (constraints are single-token in this homework)
        """
        self.idx, self.token, self.satisfied = idx, word_token, False

    def __str__(self):
        return f'word(id={self.idx}, token={self.token}, satisfy={self.satisfied})'

    def advance(self, next_token: int):
        """
        Mark the keyword as satisfied when `next_token` is its token.

        :param next_token: selected token at current time step
        """
        if not (next_token == self.token):
            return
        self.satisfied = True


class ConstrainedHypothesis:
    """Constraint state of one hypothesis: one `Word` per keyword, each with a satisfied flag."""

    def __init__(self, keyword_list: List[int]):
        """
        :param keyword_list: list of tokenized keywords
        """
        self.words = [
            Word(idx=position, word_token=token)
            for position, token in enumerate(keyword_list)
        ]

    def __str__(self) -> str:
        return '\t'.join(map(str, self.words))

    def num_satisfied_keywords(self) -> int:
        """
        :return: number of satisfied keywords
        """
        return sum(1 for word in self.words if word.satisfied)

    def get_satisfied_keywords_idx(self) -> Tuple[int]:
        """
        :return: the index ids of the satisfied keywords, from low to high
        """
        return tuple(sorted(word.idx for word in self.words if word.satisfied))

    def get_unsatisfied_words(self) -> List[int]:
        """
        :return: the token of keywords that are not satisfied
        """
        return [word.token for word in self.words if not word.satisfied]

    def advance(self, next_token: int) -> 'ConstrainedHypothesis':
        """
        :param next_token: selected token at current time step
        :return: a new ConstrainedHypothesis object with updated state based on the next token;
                 the object this is called on is left untouched
        """
        successor = copy.deepcopy(self)
        # Only keywords that are still open can change state.
        for word in successor.words:
            if not word.satisfied:
                word.advance(next_token)
        return successor


class ConstrainedCandidate:
    """
    One candidate cell of the (beam, token) score matrix.

    Two candidates are equal, and hash alike, when they point at the same
    ``(row, col)`` cell; score and hypothesis are ignored. A set of candidates
    therefore keeps at most one entry per cell.

    :param row: The row (beam index) in the scores matrix.
    :param col: The column (token id) in the scores matrix.
    :param score: the associated accumulated score.
    :param hypothesis: The ConstrainedHypothesis containing constraint information.
    """

    __slots__ = ('row', 'col', 'score', 'hypothesis')

    def __init__(self, row: int, col: int, score: float, hypothesis: 'ConstrainedHypothesis'):
        self.row = row
        self.col = col
        self.score = score
        self.hypothesis = hypothesis

    def __hash__(self):
        return hash((self.row, self.col))

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign types
        # instead of failing on a missing `row`/`col` attribute.
        if not isinstance(other, ConstrainedCandidate):
            return NotImplemented
        return self.row == other.row and self.col == other.col

    def __str__(self):
        return '[{}, {}, {}, {}]'.format(self.row, self.col, self.score, str(self.hypothesis))


def initialize_constraint(keyword_lists: List[List[int]], beam_size: int) -> List[ConstrainedHypothesis]:
    """
    Build the initial constraint state for every beam of every sample.

    :param keyword_lists: list of tokenized keyword list in a batch
    :param beam_size: beam size
    :return: list of initialized ConstrainedHypothesis objects of length batch_size * beam_size,
             laid out sample by sample (the `beam_size` entries of one sample are adjacent)
    """
    constraints_list = []
    for word_list in keyword_lists:
        template = ConstrainedHypothesis(word_list)
        # Every beam gets its own independent copy so states can diverge later.
        constraints_list.extend(copy.deepcopy(template) for _ in range(beam_size))
    return constraints_list

### Helper classes: `BeamHypothesisList` and `BeamManager`

In [None]:
######################################################
#  The following helper code is given to you.
######################################################

# eq=False keeps identity-based comparison and hashing. A generated field-wise
# `__eq__` would try to compare tensors element-wise, and a dataclass without
# declared fields would treat every instance as equal to every other one,
# which makes `list.remove(hyp)` drop the wrong element.
@dataclass(eq=False)
class BeamHypothesis:
    """A finished (or finalized) hypothesis with its score and constraint count."""

    # a single token sequence of size (seq_len,)
    input_ids: torch.LongTensor
    # a scalar score for the token sequence
    score: torch.FloatTensor
    # how many keyword constraints the sequence satisfies
    num_satisfied_keywords: int

    def __str__(self):
        return f"BeamHypothesis(input_ids: {self.input_ids}, score: {self.score}, num_satisfied_keywords: {self.num_satisfied_keywords})"


class BeamHypothesisList:
    """
    Bounded container (at most `num_beams` entries) of finished hypotheses for one sample.

    `worst_score` maps a number of satisfied keywords to the lowest score among the stored
    hypotheses with exactly that many satisfied keywords. A key exists only while its group
    holds at least one hypothesis.
    """

    def __init__(self, num_beams: int):
        """
        :param num_beams: maximum number of hypotheses kept in this list
        """
        self.beam_hypotheses: List[BeamHypothesis] = []  # list of beam_hypothesis
        self.num_beams: int = num_beams

        self.worst_score = {}  # worst beam score of hypotheses with certain num_satisfied_keywords in self.beam_hypotheses

    def _refresh_worst_score(self, num_satisfied_keywords: int) -> None:
        """Recompute the worst score of one group, dropping the entry when the group is empty."""
        group_scores = [float(hyp.score) for hyp in self.beam_hypotheses
                        if hyp.num_satisfied_keywords == num_satisfied_keywords]
        if group_scores:
            self.worst_score[num_satisfied_keywords] = min(group_scores)
        else:
            # A stale key would later be mistaken for the lowest non-empty group.
            self.worst_score.pop(num_satisfied_keywords, None)

    def _replace_worst_in_group(self, group_key: int, new_beam_hypothesis) -> None:
        """Swap the lowest-scoring hypothesis of group `group_key` for `new_beam_hypothesis`."""
        # Locate the victim by position and delete it by index: `list.remove` relies on `==`
        # and could drop a different hypothesis that merely compares equal.
        worst_position = min(
            (pos for pos, hyp in enumerate(self.beam_hypotheses) if hyp.num_satisfied_keywords == group_key),
            key=lambda pos: float(self.beam_hypotheses[pos].score),
        )
        del self.beam_hypotheses[worst_position]
        self.beam_hypotheses.append(new_beam_hypothesis)

    def add(self, new_input_ids: torch.LongTensor, sum_logprobs: float, num_satisfied_keywords: int):
        """
        :param new_input_ids: new token sequence; its last dimension is the sequence length
        :param sum_logprobs: sum of log probabilities of tokens in new_input_ids
        :param num_satisfied_keywords: number of satisfied constraints
        Given a new hypothesis (new_input_ids) and its corresponding sum_logprobs and num_satisfied_keywords,
        update self.beam_hypotheses with a finished hypothesis.
        (0) While fewer than num_beams hypotheses are stored, the new hypothesis is always added.
        (1) If self.beam_hypotheses contains hypotheses with the same number of satisfied keywords
            and the new_input_ids has higher score than the worst score within the hypothesis group with the same number of satisfied keywords
            we replace the worst scoring hypothesis within this group with the new hypothesis.
        (2) If self.beam_hypotheses doesn't contain hypotheses with the same number of satisfied keywords
            and the number of satisfied keywords is greater than the lowest number of satisfied keywords in self.beam_hypotheses,
            we replace the worst scoring hypothesis within the group with the lowest number of satisfied keywords with the new hypothesis.
        (3) Otherwise, we do not make any change to self.beam_hypotheses.
        """
        # For score, we compute average token log probability
        score = sum_logprobs / new_input_ids.size(-1)

        # (0) add the new hypothesis if we still have vacant beams
        if len(self.beam_hypotheses) < self.num_beams:
            self.beam_hypotheses.append(
                BeamHypothesis(input_ids=new_input_ids, score=score,
                               num_satisfied_keywords=num_satisfied_keywords))
            self._refresh_worst_score(num_satisfied_keywords)

        # (1) a group with the same number of satisfied keywords exists
        elif num_satisfied_keywords in self.worst_score:
            # only a strictly better score displaces the worst member of that group
            if score > self.worst_score[num_satisfied_keywords]:
                self._replace_worst_in_group(
                    num_satisfied_keywords,
                    BeamHypothesis(input_ids=new_input_ids, score=score,
                                   num_satisfied_keywords=num_satisfied_keywords))
                self._refresh_worst_score(num_satisfied_keywords)

        # (2) no such group yet, but the new hypothesis satisfies more keywords than the lowest group
        elif self.worst_score and num_satisfied_keywords > min(self.worst_score):
            lowest_group_key = min(self.worst_score)
            self._replace_worst_in_group(
                lowest_group_key,
                BeamHypothesis(input_ids=new_input_ids, score=score,
                               num_satisfied_keywords=num_satisfied_keywords))
            # the new group has exactly one member; the lowest group may now be empty
            self._refresh_worst_score(num_satisfied_keywords)
            self._refresh_worst_score(lowest_group_key)

        # sanity check
        assert len(self.beam_hypotheses) <= self.num_beams

In [None]:
######################################################
#  The following helper code is given to you.
######################################################

class BeamManager:
    """
    Bookkeeping for beam search over a batch.

    Keeps one `BeamHypothesisList` of finished hypotheses per sample, routes each step's top
    candidates either into that list (EOS token) or into the beams that continue (`process`),
    and picks the final best hypothesis per sample (`finalize`).
    """

    def __init__(self, batch_size: int, num_beams: int):
        """
        :param batch_size: number of samples in the batch
        :param num_beams: number of beams kept per sample
        """
        # one container of finished hypotheses per sample in the batch
        self.finished_beam_hypotheses_list = [BeamHypothesisList(num_beams) for _ in range(batch_size)]
        self.batch_size = batch_size
        self.num_beams = num_beams

    def process(self,
                input_ids: torch.LongTensor,
                top_token_scores: torch.FloatTensor,
                top_token_indices: torch.LongTensor,
                top_token_beam_indices: torch.LongTensor,
                constraint_hypotheses: List[ConstrainedHypothesis],
                ):
        """
        :param input_ids: (batch_size * num_beams, current_seq_length), the input_ids that were used to compute top_tokens
        :param top_token_scores: (batch_size, 2 * num_beams), representing the score of each top token
        :param top_token_indices: (batch_size, 2 * num_beams), representing each token's index (in vocabulary) of the top tokens
        :param top_token_beam_indices: (batch_size, 2 * num_beams), representing each token's corresponding beam index of the top tokens
        :param constraint_hypotheses: (batch_size * 2 * num_beams), the constraint state of each top token

        Note: the input arguments `top_token_*` for each sample in batch are sorted from the largest score to the smallest score.
        For example, if batch_size = 2 and num_beams = 3, then each of these values denote
        top_token_indices[1, 2]: what is the third-best next token for the second sample in the batch?
        top_token_scores[0, 1]: what is the score of the second-best next token for the first sample in the batch?
        top_token_beam_indices[0, 1]: which beam did we use to generate the second-best next token for the first sample in the batch?

        In this function, for each of the top-(2 * num_beams) tokens, we do the following:
        (1) If the top token is EOS token:
            This means that this hypothesis is done. Therefore, we save the hypothesis so-far to self.finished_beam_hypotheses_list.
        (2) If the top token is not EOS token
            We have to keep searching with this hypothesis. Therefore, we prepare the hypothesis for next time step.
        Candidates of a sample are consumed in order until `num_beams` unfinished beams have been collected.

        Returns a dictionary, where
        "unfinished_scores": size (batch_size * num_beams,), the score of the unfinished beams
        "unfinished_token_indices": size (batch_size * num_beams,), the index of the last token in the unfinished beams
        "unfinished_beam_indices": size (batch_size * num_beams,), the index (among all batch_size * num_beams beams)
            of the beam that was used to generate the new unfinished beam
        "unfinished_constraint_hypotheses": list of length batch_size * num_beams, the constraint state of the unfinished beams

        NOTE: slots that are not filled (fewer than `num_beams` non-EOS candidates for a sample) keep their
        initial value: 0 in the tensors and None in the constraint list.
        """
        device = top_token_scores.device

        # Initialize unfinished_token_*, which we will return for the next time step.
        unfinished_scores = torch.zeros((self.batch_size, self.num_beams), dtype=top_token_scores.dtype).to(device)  # score of the unfinished beams
        unfinished_token_indices = torch.zeros((self.batch_size, self.num_beams), dtype=top_token_indices.dtype).to(
            device)  # index of the last token of the unfinished beams
        unfinished_token_beam_indices = torch.zeros((self.batch_size, self.num_beams), dtype=top_token_beam_indices.dtype).to(
            device)  # index of the unfinished beam in the batch
        unfinished_constraint_hypotheses = [None] * (self.batch_size * self.num_beams)

        # Loop over the batch
        for batch_idx in range(self.batch_size):
            # get sample_beam_hypothesis_list: the finished_beam_hypothesis_list for this sample in the batch
            sample_beam_hypothesis_list: BeamHypothesisList = self.finished_beam_hypotheses_list[batch_idx]

            # get the top_token_scores, top_token_indices, top_token_beam_indices, constraint_hypotheses for this sample in the batch
            # NOTE: size of sample_top_token_*: (2 * num_beams,)
            sample_top_token_scores = top_token_scores[batch_idx]
            sample_top_token_indices = top_token_indices[batch_idx]
            sample_top_token_beam_indices = top_token_beam_indices[batch_idx]
            sample_const_hypos = constraint_hypotheses[batch_idx * (2 * self.num_beams): (batch_idx + 1) * (2 * self.num_beams)]

            # Loop over all top tokens
            # sample_beam_idx counts how many unfinished beams of this sample have been filled so far
            sample_beam_idx = 0
            for top_token_score, top_token_index, top_token_beam_index, const_hypo in zip(
                    sample_top_token_scores, sample_top_token_indices, sample_top_token_beam_indices, sample_const_hypos
            ):
                # Note that top_token_beam_indices only denotes the index of the beam in each sample.
                # We transform this into `beam_idx_in_batch`, the index of the beam among all (batch_size * num_beams) beams in the batch.
                beam_idx_in_batch = batch_idx * self.num_beams + top_token_beam_index

                # if top_token == EOS, we add the generation so-far to the beam_hypotheses_list
                if top_token_index.item() == eos_token_id:
                    # among the (batch_size * num_beams) input_ids, find the input_ids that correspond to this top_token
                    # NOTE: the size of new_input_ids: (seq_len,); the EOS token itself is not appended
                    new_input_ids = input_ids[beam_idx_in_batch]

                    # add the new beam to sample_beam_hypothesis_list
                    sample_beam_hypothesis_list.add(
                        new_input_ids,
                        top_token_score,
                        const_hypo.num_satisfied_keywords(),
                    )

                # if top_token =/= EOS, we aggregate them for next time step.
                else:
                    # store the score, token_index, beam_idx_in_batch to the unfinished_scores, unfinished_token_indices, unfinished_token_beam_indices
                    unfinished_scores[batch_idx, sample_beam_idx] = top_token_score
                    unfinished_token_indices[batch_idx, sample_beam_idx] = top_token_index
                    unfinished_token_beam_indices[batch_idx, sample_beam_idx] = beam_idx_in_batch
                    unfinished_constraint_hypotheses[batch_idx * self.num_beams + sample_beam_idx] = const_hypo

                    sample_beam_idx += 1

                # once we have `num_beams` number of new beams, we don't have to add anymore.
                if sample_beam_idx == self.num_beams:
                    break

        # return the dictionary of unfinished_scores, unfinished_token_indices, unfinished_beam_indices, unfinished_constraint_hypotheses
        # Make sure to change the size of each tensor to (batch_size * num_beams,)
        return {
            "unfinished_scores": unfinished_scores.view(-1),  # (batch_size * num_beams,)
            "unfinished_token_indices": unfinished_token_indices.view(-1),  # (batch_size * num_beams,)
            "unfinished_beam_indices": unfinished_token_beam_indices.view(-1),  # (batch_size * num_beams,)
            "unfinished_constraint_hypotheses": unfinished_constraint_hypotheses, # (batch_size * num_beams,)
        }

    def finalize(
            self,
            input_ids: torch.LongTensor,
            beam_scores: torch.FloatTensor,
            constraint_hypotheses: List[ConstrainedHypothesis],
    ) -> Tuple[List[torch.LongTensor], List[torch.FloatTensor]]:
        """
        :param input_ids: (batch_size * num_beams, max_length), input_ids of unfinished beams
        :param beam_scores: (batch_size * num_beams,), scores of unfinished beams
        :param constraint_hypotheses: (batch_size * num_beams, ), the constraint state of unfinished beams

        Get the final best beams, among
        (1) unfinished beams, for which we get the input_ids, beam_scores and constraint_hypotheses as arguments
        (2) finished beams, which we store in self.finished_beam_hypotheses_list
        The best beam of a sample is the highest-scoring hypothesis among those with the highest number of
        satisfied keywords.
        Returns a tuple of two lists, where
        - tuple[0] is the list of the input_ids of the best beams (length: batch_size)
        - tuple[1] is the list of the scores of the best beams (length: batch_size)
        """

        # 1. Add all unfinished beam hypotheses to self.finished_beam_hypotheses_list
        for batch_idx in range(self.batch_size):
            # get sample_beam_hypothesis_list: the finished_beam_hypothesis_list for this sample in the batch
            sample_beam_hypothesis_list: BeamHypothesisList = self.finished_beam_hypotheses_list[batch_idx]

            for sample_beam_idx in range(self.num_beams):
                # get beam_idx_in_batch: index of the beam in all `batch_size * num_beams` beams in the batch
                beam_idx_in_batch = batch_idx * self.num_beams + sample_beam_idx

                # get the input_id for this beam, using `beam_idx_in_batch`
                # NOTE: the size of new_input_ids: (seq_len,)
                new_input_ids = input_ids[beam_idx_in_batch]

                # get the score of this beam, using `beam_idx_in_batch`
                # NOTE: beam_score should be a scalar
                beam_score = beam_scores[beam_idx_in_batch].item()

                # get the number of satisfied keywords of this beam, using `constraint_hypotheses`
                num_satisfied_keywords = constraint_hypotheses[beam_idx_in_batch].num_satisfied_keywords()

                # add the new hypothesis to sample_beam_hypothesis_list
                # (the list itself decides whether the hypothesis is kept)
                sample_beam_hypothesis_list.add(new_input_ids, beam_score, num_satisfied_keywords)

        # 2. Select the best hypothesis from each beam_hypothesis_list
        best_input_ids = []
        best_scores = []
        for batch_idx in range(self.batch_size):
            # get sample_beam_hypothesis_list: the finished_beam_hypothesis_list for this sample in the batch
            sample_beam_hypothesis_list: BeamHypothesisList = self.finished_beam_hypotheses_list[batch_idx]

            # get the group of hypotheses with the highest number of satisfied keywords
            max_num_satisfied_keywords = max([hyp.num_satisfied_keywords for hyp in sample_beam_hypothesis_list.beam_hypotheses])
            max_satisfied_group = [hyp for hyp in sample_beam_hypothesis_list.beam_hypotheses if hyp.num_satisfied_keywords == max_num_satisfied_keywords]

            # get best_hypothesis among this group (the one with the highest score)
            best_hypothesis = max(max_satisfied_group, key=lambda hyp: hyp.score)

            # save the input_ids and score of best_hypothesis
            best_input_ids.append(best_hypothesis.input_ids)
            best_scores.append(best_hypothesis.score)

        return best_input_ids, best_scores

## Neurologic Search

In [None]:
def _collect_candidates(scores, best_beam_idx, best_token_idx, const_hypos, num_beams):
    """Build the set of ConstrainedCandidates of one sample, one per distinct (row, col) cell."""
    # (1) the regular top tokens
    cells = list(zip(best_beam_idx.tolist(), best_token_idx.tolist()))
    for row in range(num_beams):
        # (2) every token that would satisfy one more constraint of this beam
        cells.extend((row, col) for col in const_hypos[row].get_unsatisfied_words())
        # (3) the best next token of this beam, so beams that already satisfied constraints survive
        cells.append((row, torch.argmax(scores[row]).item()))

    candidates = set()
    seen_cells = set()
    for row, col in cells:
        # Candidates are unique per cell; skip repeats before paying for the deep copy in advance().
        if (row, col) in seen_cells:
            continue
        seen_cells.add((row, col))
        candidates.add(ConstrainedCandidate(row, col, scores[row, col], const_hypos[row].advance(col)))
    return candidates


def rerank_beam(new_scores: torch.FloatTensor,
                top_token_beam_indices: torch.LongTensor,
                top_token_indices: torch.LongTensor,
                constraint_hypotheses: List[ConstrainedHypothesis]):
    """
        Re-rank the next-step candidates of every sample so that different sets of satisfied
        keywords are all represented among the returned top candidates.

        For each sample, candidates are grouped by the index ids of the keywords they satisfy.
        Each group is sorted by score, and candidates are then picked rank by rank: first the best
        of every group (ordered by score), then the second best of every group, and so on.
        The first 2 * num_beams picks are returned.

        :param new_scores: (batch_size, num_beams, vocab_size), accumulated score from previous beam_scores and the next_token_scores
        :param top_token_beam_indices: (batch_size, 2 * num_beams), representing each token's corresponding beam index of the top tokens
        :param top_token_indices: (batch_size, 2 * num_beams), representing each token's index (in vocabulary) of the top tokens
        :param constraint_hypotheses: the list of constraint hypothesis objects. (length: (batch_size * num_beams,))
        :return: tuple (new_top_token_beam_indices, new_top_token_indices, new_top_token_scores, new_constraint_hypotheses);
                 the three tensors have size (batch_size, 2 * num_beams), the list has length batch_size * 2 * num_beams.
                 When fewer than 2 * num_beams candidates survive for a sample, the remaining slots stay 0 / None.
    """
    batch_size, num_beams, _ = new_scores.shape
    # Allocate the outputs next to the incoming scores instead of relying on a module-level device.
    out_device = new_scores.device

    # Initialize new_*, which we will return for the next time step.
    new_top_token_beam_indices = torch.zeros((batch_size, 2 * num_beams), dtype=top_token_beam_indices.dtype, device=out_device)
    new_top_token_indices = torch.zeros((batch_size, 2 * num_beams), dtype=top_token_indices.dtype, device=out_device)
    new_top_token_scores = torch.zeros((batch_size, 2 * num_beams), dtype=new_scores.dtype, device=out_device)
    new_constraint_hypotheses = [None] * (batch_size * (2 * num_beams))

    for batch_idx in range(batch_size):
        scores = new_scores[batch_idx]  # (num_beams, vocab_size)
        best_beam_idx = top_token_beam_indices[batch_idx]  # (2 * num_beams,)
        best_token_idx = top_token_indices[batch_idx]  # (2 * num_beams,)
        const_hypos = constraint_hypotheses[batch_idx * num_beams: (batch_idx + 1) * num_beams]  # (num_beams,)

        candidates = _collect_candidates(scores, best_beam_idx, best_token_idx, const_hypos, num_beams)

        # group candidates by the index ids of the satisfied keywords (i.e. get_satisfied_keywords_idx());
        # the key is computed once per candidate
        keyed_candidates = [(x, x.hypothesis.get_satisfied_keywords_idx()) for x in candidates]
        idx_combs = set([key for _, key in keyed_candidates])
        grouped_candidates = [[x for x, key in keyed_candidates if key == c] for c in idx_combs]
        # sort candidates in each group by score, from high to low
        grouped_sorted_candidates = [sorted(g, key=lambda x: x.score, reverse=True) for g in grouped_candidates]

        # NOTE: get rid of candidate with -inf score, which is resulted by padding
        grouped_sorted_candidates = [[c for c in g if c.score > -1e8] for g in grouped_sorted_candidates]

        selected_candidates = []
        # select the top candidates with highest score within each group to fill in selected_candidates
        for top_n in range(max((len(g) for g in grouped_sorted_candidates), default=0)):
            # get the top_n-th item from all the groups that are long enough
            top_n_items = [g[top_n] for g in grouped_sorted_candidates if len(g) > top_n]
            # sort top_n_items by score and add them to selected_candidates
            selected_candidates.extend(sorted(top_n_items, key=lambda x: x.score, reverse=True))

        # return the top 2 * num_beams candidates as the hypotheses for next time step
        selected_candidates = selected_candidates[:2 * num_beams]
        for j, selected_candidate in enumerate(selected_candidates):
            new_top_token_beam_indices[batch_idx, j] = selected_candidate.row
            new_top_token_indices[batch_idx, j] = selected_candidate.col
            new_top_token_scores[batch_idx, j] = selected_candidate.score
            new_constraint_hypotheses[batch_idx * (2 * num_beams) + j] = selected_candidate.hypothesis

    return new_top_token_beam_indices, new_top_token_indices, new_top_token_scores, new_constraint_hypotheses


def neurologic_search(prompts: List[str], keyword_lists: List[List[str]], num_beams: int, max_length: int) -> List[str]:
    """
    Run keyword-constrained beam search over a batch of prompts, reranking the beam candidates
    with rerank_beam at every decoding step.

    :param prompts: list of prompt strings
    :param keyword_lists: list of keyword list (one list per prompt); each keyword is reduced to the
        first token id of its encoding with a leading space
    :param num_beams: number of beams
    :param max_length: max total sequence length in tokens (padded prompt + generation)
    :return: list of generation, including both the original prompt and generation
    :raises ValueError: if the padded prompt is longer than max_length
    """
    # encode the prompts using tokenizer (padding=True), to get input_ids and attention_mask
    # Note: the encoded text is pushed to device.
    input_encoding = tokenizer(prompts, padding=True, return_tensors="pt").to(device)
    input_ids, attention_mask = input_encoding["input_ids"], input_encoding["attention_mask"]

    if input_ids.size(-1) > max_length:
        raise ValueError("Input ID is larger than max_length.")
    if input_ids.size(-1) == max_length:
        # The padded prompt already fills the whole budget, so there is nothing to generate.
        # Without this early exit the loop below would append one token, overshoot max_length
        # and never hit its exit condition.
        return tokenizer.batch_decode(input_ids, skip_special_tokens=True)

    # only the first token of " <keyword>" is kept as the constraint for that keyword
    tokenized_keywords = [[tokenizer.encode(f' {keyword}')[0] for keyword in k_list] for k_list in keyword_lists]
    constraints = initialize_constraint(tokenized_keywords, num_beams)

    # --- Do not change below --- #
    batch_size = input_ids.size(0)
    vocab_size = len(tokenizer)

    # initialize model_kwargs
    model_kwargs = {'attention_mask': attention_mask}

    # interleave input_ids according to num_beams.
    # For example, input_ids for ["Hi", "good"] with num_beams=3 becomes ["Hi", "Hi", "Hi", "good", "good", "good"]
    input_ids, model_kwargs = model._expand_inputs_for_generation(
        input_ids=input_ids,
        expand_size=num_beams,
        is_encoder_decoder=False,
        **model_kwargs,
    )
    # input_ids: tensor of size (batch_size * num_beams, seq_len)
    # model_kwargs: a dictionary with single element 'attention_mask', sized (batch_size * num_beams, seq_len)
    # --- Do not change above --- #

    # initialize beam_manager
    beam_manager = BeamManager(batch_size=batch_size, num_beams=num_beams)

    # initialize unfinished_beam_scores, a tensor of size (batch_size, num_beams) with all elements = 0
    unfinished_beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=device)

    # For each sample in the batch, set all initial beam_score to -1e9, except for the first beam.
    # All beams of a sample start out identical, so this keeps the first step from selecting duplicates.
    unfinished_beam_scores[:, 1:] = -1e9
    unfinished_beam_scores = unfinished_beam_scores.view(-1)  # (batch_size * num_beams,)

    while True:
        # --- Do not change below --- #
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
        # --- Do not change above --- #

        # run model forward pass with model_inputs as the input
        # NOTE: we should set return_dict=True, output_attentions=False and output_hidden_states=False
        # Decoding is inference only, so gradient tracking is disabled for the forward pass.
        with torch.no_grad():
            model_outputs = model(
                **model_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False
            )

        # compute log_probs for next tokens given the prompt
        # NOTE: size of next_token_scores: (batch_size * num_beams, vocab_size)
        next_token_logits = model_outputs.logits[:, -1, :]
        next_token_scores = F.log_softmax(next_token_logits, dim=-1)

        # add previous beam_scores to the next_token_scores
        # NOTE: size of new_scores: (batch_size * num_beams, vocab_size)
        new_scores = next_token_scores + unfinished_beam_scores.unsqueeze(1)

        # retrieve top-(2 * num_beams) next tokens for each sample in the batch
        # NOTE: size of `top_token_scores` and `top_token_indices` is (batch_size, 2 * num_beams)
        # NOTE: both are sorted from the largest score to the smallest score (for each sample in batch)
        # NOTE: new_scores is viewed as (batch_size, num_beams * vocab_size) prior to the topk operation.
        top_token_scores, top_token_indices = torch.topk(
            new_scores.view(batch_size, num_beams * vocab_size), 2 * num_beams,
            dim=1, largest=True, sorted=True
        )

        # top_token_indices index into num_beams * vocab_size: floor-divide by vocab_size to get the
        # beam index, and take the remainder to get the vocabulary index
        top_token_beam_indices = torch.div(top_token_indices, vocab_size,
                                           rounding_mode="floor")  # from which beam the top-token was retrieved from
        top_token_indices = top_token_indices % vocab_size  # the index of top-token in the vocabulary

        # rerank the beam candidates based on neurologic algorithm
        # (the scores returned here replace the ones produced by topk above)
        rerank_outputs = rerank_beam(new_scores=torch.reshape(new_scores, [batch_size, num_beams, -1]),
                                     top_token_beam_indices=top_token_beam_indices,
                                     top_token_indices=top_token_indices,
                                     constraint_hypotheses=constraints)
        top_token_beam_indices, top_token_indices, top_token_scores, constraints = rerank_outputs

        # --- Run beam_manager.process and save the results in unfinished_beam_scores, unfinished_token_indices and unfinished_beam_indices --- #
        unfinished_beam_outputs = beam_manager.process(
            input_ids,
            top_token_scores,
            top_token_indices,
            top_token_beam_indices,
            constraints,
        )
        unfinished_beam_scores = unfinished_beam_outputs["unfinished_scores"]
        unfinished_token_indices = unfinished_beam_outputs["unfinished_token_indices"]
        unfinished_beam_indices = unfinished_beam_outputs["unfinished_beam_indices"]
        constraints = unfinished_beam_outputs["unfinished_constraint_hypotheses"]

        # --- Prepare input_ids for next time step --- #
        # index input_ids with the unfinished beam indices
        # NOTE: input_ids should be (batch_size * num_beams, seq_len)
        input_ids = input_ids[unfinished_beam_indices]

        # concatenate the unfinished token index to the corresponding input_ids
        input_ids = torch.cat([input_ids, unfinished_token_indices.unsqueeze(-1)], dim=-1)

        # --- Do not change below --- #
        # update the model_kwargs according to the concatenated input_ids
        model_kwargs = model._update_model_kwargs_for_generation(
            model_outputs, model_kwargs, is_encoder_decoder=False
        )
        if model_kwargs["past_key_values"] is not None:
            model_kwargs["past_key_values"] = model._temporary_reorder_cache(
                model_kwargs["past_key_values"], unfinished_beam_indices,
            )
        # --- Do not change above --- #

        # if unfinished input_ids reach the max seq length, exit the loop
        # (>= rather than == so the loop can never run past the budget)
        if input_ids.size(-1) >= max_length:
            break

    # --- Run beam_manager.finalize to get the best_input_ids and best_scores, among all finished / unfinished beams --- #
    best_input_ids, best_scores = beam_manager.finalize(input_ids, unfinished_beam_scores, constraints)

    # if len(best_input_ids) < max_length, pad them to the max length
    for batch_idx, sample_input_ids in enumerate(best_input_ids):
        missing = max_length - sample_input_ids.size(-1)
        if missing > 0:
            pad_tensor = torch.full((missing,), pad_token_id, dtype=torch.long, device=device)

            # pad best_input_ids with pad_tensor
            best_input_ids[batch_idx] = torch.cat([sample_input_ids, pad_tensor], dim=-1)

    # transform best_input_ids (which is currently a list of tensors) into a single tensor with one row per sample
    best_input_ids = torch.stack(best_input_ids, dim=0)

    return tokenizer.batch_decode(best_input_ids, skip_special_tokens=True)

### Sanity check for debugging

In [None]:
# Smoke test for neurologic_search: three short prompts, each paired by position
# with the keyword list passed as its constraints.
sents = [
    "The soccer game was tied 3 to 3 and there was a minute left to play.",
    "Molly loves popcorn.",
    "Tim rented a car to visit his ill mother.",
]

# keywords[i] holds the constraint words for sents[i]
keywords = [
    ["Julie", "goal"],
    ["Molly", "mom"],
    ["Tim", "mother"],
]

# The call is the last expression of the cell, so the notebook displays its return value.
neurologic_search(sents, keywords, num_beams=5, max_length=35)

# expected output: ['The soccer game was tied 3 to 3 and there was a minute left to play. The goal was scored by the goalkeeper and the ball was headed for the net. Julie was', 'Molly loves popcorn.\n\n"I love popcorn," Molly said. "I love popcorn mommy."\n', 'Tim rented a car to visit his ill mother.\n\n"She was very upset," Tim said. "She said mother, I\'m']

## 3.2 Evaluate

In [None]:
# Given a list of generations, evaluate their perplexity / fluency / diversity and report the result.

def evaluate(generations):
    """
    Score a list of generated texts and print the results.

    Empty strings are dropped before scoring. The perplexity, fluency and
    diversity values are printed to stdout; nothing is returned.

    :param generations: list of generated strings
    """
    non_empty = [text for text in generations if text != '']
    print("Computing perplexity...")
    perplexity = compute_perplexity(non_empty)
    print("Computing fluency...")
    fluency = compute_fluency(non_empty)
    print("Compute diversity...")
    diversity = compute_diversity(non_empty)
    print("Neurologic Search")
    print(f'perplexity = {perplexity:.2f}')
    print(f'fluency = {fluency:.2f}')
    # the first three diversity values are reported, in order
    print(f'diversity = {diversity[0]:.2f}, {diversity[1]:.2f}, {diversity[2]:.2f}')
    print()

In [None]:
# If your implementation is efficient enough, the following code will run in no longer than 3 minutes with device = gpu.

from itertools import islice
from pprint import pprint

from tqdm import tqdm

SUBSET = 10
NUM_KEYWORDS = 2

# Read only the first SUBSET test examples once, instead of walking the whole split
# twice and then slicing the result.
subset_items = list(islice(test_data, SUBSET))
prompts = [item['prompt'] for item in subset_items]
keywords = [item['constraint_words'][:NUM_KEYWORDS] for item in subset_items]

MAX_LEN = 50
NUM_BEAMS = 5
BATCH_SIZE = 5

# Decode the prompts batch by batch and collect only the generated continuations.
generations = []
for batch_start_idx in tqdm(range(0, len(prompts), BATCH_SIZE)):
    batched_prompts = prompts[batch_start_idx: batch_start_idx + BATCH_SIZE]
    batched_keywords = keywords[batch_start_idx: batch_start_idx + BATCH_SIZE]
    batched_generations = neurologic_search(batched_prompts, batched_keywords, NUM_BEAMS, MAX_LEN)

    # remove prompt from generation
    # NOTE(review): assumes each decoded generation starts with its prompt verbatim -- TODO confirm
    batched_generations = [generation[len(prompt):] for prompt, generation in zip(batched_prompts, batched_generations)]

    generations.extend(batched_generations)

evaluate(generations)

In [None]:
# Show the first 10 prompts together with their keywords and generated continuations.

sampled_prompts = prompts[:10]
sampled_generations = generations[:10]
sampled_keywords = keywords[:10]

samples = zip(sampled_prompts, sampled_generations, sampled_keywords)
for idx, (prompt, generation, keyword) in enumerate(samples):
    # One field per output line; the trailing empty field yields the blank separator line.
    report = (
        f"Prompt {idx}",
        prompt,
        f"Keyword {idx}",
        keyword,
        f"Generation {idx}",
        generation,
        "---------",
        "",
    )
    print(*report, sep="\n")