In [None]:
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import Tuple
from typing import Union, List, Literal, Optional
import gc
import numpy as np
import os
import random
import sys
import torch
import torch.nn.functional as F
import transformers

sys.path.append("../")
from slack_notifier import send_slack_notification


# Log the interpreter and library versions so a run can be tied to its environment.
print(f"Python Version : {sys.version}")
print(f"Torch Version : {torch.__version__}")
print(f"Transformers Version : {transformers.__version__}")

In [None]:
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import Tuple
from typing import Union, List, Optional
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

class EBD:
    """Entropy-based decoding (EBD) on top of a Hugging Face causal language model.

    At every generation step two distributions are combined:

    * an entropy-weighted ensemble of the next-token log-probabilities obtained
      from each ``context + prompt`` input (``compute_le_ens_scores``), and
    * the next-token distribution read out of the layer whose projection through
      ``lm_head`` has the highest entropy for the plain prompt
      (``compute_maximum_entropy_layer_probs``).

    The final score is ``(1 + beta) * ensemble - beta * log(max_entropy_probs)``.
    """

    def __init__(self,
                 model_name: str,
                 device: Union[int, str] = 0):
        """Load the model and tokenizer.

        Args:
            model_name: Hugging Face model identifier or local path.
            device: CUDA index (``0`` or ``"0"``) or a full device string such as
                ``"cuda:1"`` / ``"cpu"``. Falls back to CPU when CUDA is unavailable.
        """
        if not torch.cuda.is_available():
            device_map = torch.device("cpu")
        elif isinstance(device, str) and not device.isdigit():
            # Already a full device string; do not prepend "cuda:" a second time.
            device_map = torch.device(device)
        else:
            device_map = torch.device(f"cuda:{device}")

        load_kwargs = dict(device_map=device_map, use_cache=True)
        if device_map.type == "cuda":
            # Flash-attention and fp16 are only requested on GPU; on CPU the
            # library defaults are used instead.
            load_kwargs.update(attn_implementation="flash_attention_2",
                               torch_dtype=torch.float16)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
        # Left padding keeps the last position of every row a real token, which
        # is what the `[:, -1, :]` read-outs below rely on.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        self.device = device_map
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def _build_context_inputs(self,
                              prompts: List[str],
                              context_prefix: Optional[str] = None,
                              contexts: Optional[List[List[Tuple[Optional[str], str]]]] = None
                              ) -> Tuple[List[str], List[int]]:
        """Build the context-augmented inputs and the number of inputs per prompt.

        Returns:
            ``(inputs, counts)`` where ``inputs`` is the flat list of strings and
            ``counts[i]`` is how many consecutive entries belong to ``prompts[i]``.
        """
        inputs: List[str] = []
        counts: List[int] = []
        eos_token = self.tokenizer.eos_token
        for prompt_index, prompt in enumerate(prompts):
            rendered: List[str] = []
            if contexts is not None:
                for context_id, context_text in contexts[prompt_index]:
                    if not context_text:
                        continue
                    # Format a per-context copy so the template itself is never
                    # overwritten between iterations.
                    prefix = context_prefix
                    if prefix is not None and context_id is not None:
                        # Supports both "{reference_id}" and positional "{}" templates.
                        prefix = prefix.format(context_id, reference_id=context_id)
                    header = "" if prefix is None else f"{prefix}\n"
                    rendered.append(f"{header}{context_text} {eos_token} {prompt}")
            if not rendered:
                # No usable context: fall back to the bare prompt so every
                # prompt contributes at least one input row.
                rendered.append(prompt)
            inputs.extend(rendered)
            counts.append(len(rendered))
        return inputs, counts

    def construct_context_based_inputs(self,
                                       prompts: List[str],
                                       context_prefix: str = None,
                                       contexts: List[List[Tuple[Optional[str], str]]] = None) -> List[str]:
        """
        Construct input strings with contexts for the model
        Args:
            prompts: List of prompts to generate completions for
            context_prefix: Template placed before each context; a ``{reference_id}``
                (or positional ``{}``) placeholder is filled with the context ID
            contexts: One list per prompt of ``(context_id, context_text)`` tuples;
                entries with empty text are skipped
        Returns:
            Flat list of input strings, ordered by prompt. A prompt without any
            non-empty context contributes the prompt itself.
        """
        inputs_with_contexts, _ = self._build_context_inputs(prompts, context_prefix, contexts)
        return inputs_with_contexts

    @staticmethod
    def _top_k_sampling(logits: torch.Tensor, top_k: int) -> torch.Tensor:
        """Return a copy of ``logits`` with everything outside the top-k set to -inf."""
        top_k = max(1, min(int(top_k), logits.size(-1)))
        kth_value = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        return logits.masked_fill(logits < kth_value, float("-inf"))

    @staticmethod
    def _top_p_sampling(logits: torch.Tensor, top_p: float) -> torch.Tensor:
        """Return a copy of ``logits`` restricted to the nucleus of mass ``top_p``."""
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_remove = cumulative_probs > top_p
        # Shift right so the first token crossing the threshold is kept and the
        # most likely token is never removed.
        sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
        sorted_remove[..., 0] = False
        remove = torch.zeros_like(sorted_remove).scatter(-1, sorted_indices, sorted_remove)
        return logits.masked_fill(remove, float("-inf"))

    def predict_next_token(self,
                           logits: torch.Tensor,
                           decoding_strategy: str,
                           top_p: float,
                           top_k: int,
                           use_repetition_penalty: bool,
                           repetition_penalty_value: float,
                           generated_tokens: List[set] = None
                           ) -> torch.Tensor:
        """Pick the next token for every row of ``logits``.

        Args:
            logits: ``(batch, vocab)`` scores; the tensor is not modified.
            decoding_strategy: ``'top_p'``, ``'top_k'`` or ``'greedy'``.
            top_p: Nucleus mass, required for ``'top_p'``.
            top_k: Number of candidates, required for ``'top_k'``.
            use_repetition_penalty: Whether to penalise already generated tokens.
            repetition_penalty_value: Penalty factor, must be >= 1.
            generated_tokens: Per-row sets of token ids generated so far.
        Returns:
            ``(batch,)`` tensor of token ids (the batch dimension is kept even
            for a batch of one).
        Raises:
            ValueError: If ``decoding_strategy`` is not recognised.
        """
        # Repetition penalty reference:
        # https://huggingface.co/transformers/v2.11.0/_modules/transformers/modeling_utils.html#PreTrainedModel.enforce_repetition_penalty_
        if use_repetition_penalty:
            assert repetition_penalty_value >= 1.0, "Repetition penalty must be >= 1."
            penalized = torch.zeros_like(logits, dtype=torch.bool)
            for row, token_set in enumerate(generated_tokens or []):
                if token_set:
                    penalized[row, list(token_set)] = True
            # Negative scores are multiplied and positive ones divided, so the
            # penalty always lowers the score of an already generated token.
            scaled = torch.where(logits < 0,
                                 logits * repetition_penalty_value,
                                 logits / repetition_penalty_value)
            # Out-of-place on purpose: the caller's tensor stays untouched.
            logits = torch.where(penalized, scaled, logits)

        if decoding_strategy == 'top_p':
            assert top_p is not None, "top_p must be provided for top_p sampling"
            probs = F.softmax(self._top_p_sampling(logits, top_p), dim=-1)
            # squeeze(-1) rather than squeeze(): keeps shape (batch,) when batch == 1.
            return torch.multinomial(probs, num_samples=1).squeeze(-1)

        if decoding_strategy == 'top_k':
            assert top_k is not None, "top_k must be provided for top_k sampling"
            probs = F.softmax(self._top_k_sampling(logits, top_k), dim=-1)
            return torch.multinomial(probs, num_samples=1).squeeze(-1)

        if decoding_strategy == 'greedy':
            return torch.argmax(logits, dim=-1)

        raise ValueError(f"Unknown decoding_strategy: {decoding_strategy!r}")

    def plot_next_word_distribution(self, logits, top_k=10):
        """
        Plots the probability distribution of the next word from logits.

        Only the first row of the batch is plotted.

        Parameters:
        - logits: Tensor, the raw logits output by the model for the next word.
        - top_k: int, the number of most probable words to display in the plot.
        """
        if logits.dim() == 1:
            logits = logits.unsqueeze(0)
        probabilities = F.softmax(logits.float(), dim=-1)

        top_k = min(top_k, probabilities.shape[-1])
        top_k_probs, top_k_indices = torch.topk(probabilities, top_k)
        top_k_words = [self.tokenizer.decode([idx]) for idx in top_k_indices[0].tolist()]

        # Plot the distribution
        plt.figure(figsize=(10, 6))
        plt.bar(top_k_words, top_k_probs[0].tolist(), alpha=0.7)
        plt.xlabel('Words')
        plt.ylabel('Probability')
        plt.title('Next Word Probability Distribution')
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()
        plt.show()

    @torch.no_grad()
    def compute_maximum_entropy_layer_probs(self,
                                            input_ids: torch.Tensor,
                                            attention_mask: torch.Tensor,
                                            temperature: float = 1.0) -> torch.Tensor:
        """Next-token probabilities read from the highest-entropy layer.

        Every hidden state returned by the model is projected through ``lm_head``
        at the last position; for each row the layer with the largest entropy is
        selected and its probability distribution is returned.

        NOTE(review): intermediate hidden states are projected with ``lm_head``
        directly, without the model's final normalisation layer — confirm that
        this is the intended read-out.

        Args:
            input_ids: ``(batch, seq)`` token ids.
            attention_mask: ``(batch, seq)`` attention mask.
            temperature: Softmax temperature applied to the projected logits.
        Returns:
            ``(batch, vocab)`` float32 probabilities (rows sum to 1).
        """
        outputs = self.model(input_ids,
                             attention_mask=attention_mask,
                             return_dict=True,
                             output_hidden_states=True)
        # (layers, batch, hidden): last position of every layer, projected once.
        last_token_states = torch.stack([layer[:, -1, :] for layer in outputs.hidden_states])
        layer_logits = torch.matmul(last_token_states, self.model.lm_head.weight.t()).float()
        layer_log_probs = F.log_softmax(layer_logits / temperature, dim=-1)
        # Entropy from log-probabilities stays finite even when a probability underflows.
        entropies = -(layer_log_probs.exp() * layer_log_probs).sum(dim=-1)  # (layers, batch)

        best_layers = entropies.argmax(dim=0)  # (batch,)
        batch_index = torch.arange(best_layers.shape[0], device=best_layers.device)
        return layer_log_probs[best_layers, batch_index].exp()

    @torch.no_grad()
    def compute_le_ens_scores(self,
                              input_ids_with_context: torch.Tensor,
                              attention_mask_with_context: torch.Tensor,
                              context_lengths: List[int],
                              temperature: float = 1.0) -> torch.Tensor:
        """Entropy-weighted ensemble of next-token log-probabilities over contexts.

        Rows of the input are grouped per prompt according to ``context_lengths``.
        Within a group each context is weighted by a softmax over its negative
        entropy, so confident (low-entropy) contexts contribute more.

        Args:
            input_ids_with_context: ``(sum(context_lengths), seq)`` token ids.
            attention_mask_with_context: Matching attention mask.
            context_lengths: Number of consecutive rows belonging to each prompt.
            temperature: Softmax temperature.
        Returns:
            ``(len(context_lengths), vocab)`` float32 ensemble scores.
        Raises:
            ValueError: If ``context_lengths`` does not add up to the number of rows.
        """
        outputs = self.model(input_ids_with_context,
                             attention_mask=attention_mask_with_context,
                             return_dict=True)
        log_probs = F.log_softmax(outputs.logits[:, -1, :].float() / temperature, dim=-1)
        if sum(context_lengths) != log_probs.shape[0]:
            raise ValueError(
                f"context_lengths sums to {sum(context_lengths)} but the batch has {log_probs.shape[0]} rows"
            )

        le_ens_scores = []
        for group_log_probs in log_probs.split(list(context_lengths)):
            negative_entropy = (group_log_probs.exp() * group_log_probs).sum(dim=-1)  # (n_contexts,)
            weights = F.softmax(negative_entropy, dim=-1)
            le_ens_scores.append((weights.unsqueeze(1) * group_log_probs).sum(dim=0))
        return torch.stack(le_ens_scores)

    def generate(self,
                 prompts: List[str],
                 context_prefix: str = None,
                 contexts: Optional[List[List[Tuple[Optional[str], str]]]] = None,
                 beta: float = 0.5,
                 max_length: int = 256,
                 decoding_strategy: str = 'top_p',
                 top_p_value: float = 0.9,
                 top_k_value: int = 20,
                 use_repetition_penalty: bool = False,
                 repetition_penalty_value: float = 1.0,
                 plot_distributions: bool = False,
                 ) -> List[List[int]]:
        """Generate continuations with entropy-based decoding.

        Args:
            prompts: Prompts to continue.
            context_prefix: Template placed before every context (see
                ``construct_context_based_inputs``).
            contexts: One list per prompt of ``(context_id, context_text)`` tuples,
                or ``None`` to use the bare prompts.
            beta: Weight of the maximum-entropy-layer contrast term.
            max_length: Maximum number of new tokens per prompt.
            decoding_strategy: ``'top_p'``, ``'top_k'`` or ``'greedy'``.
            top_p_value: Nucleus mass for ``'top_p'``.
            top_k_value: Candidate count for ``'top_k'``.
            use_repetition_penalty: Whether to penalise already generated tokens.
            repetition_penalty_value: Penalty factor (>= 1).
            plot_distributions: Debug switch; when true, the plain model
                distribution and the EBD distribution are plotted at every step
                (costs one extra forward pass per step).
        Returns:
            For each prompt the list of generated token ids, up to and including
            the first EOS token.
        Raises:
            ValueError: If ``contexts`` is given but its length differs from ``prompts``.
        """
        if len(prompts) == 0:
            return []
        if contexts is not None and len(contexts) != len(prompts):
            raise ValueError("contexts must contain exactly one list per prompt")

        max_positions = self.model.config.max_position_embeddings
        model_device = self.model.device

        tokenized_inputs = self.tokenizer(prompts, return_tensors="pt", padding=True,
                                          truncation=True, max_length=max_positions)
        input_ids = tokenized_inputs['input_ids'].to(model_device)
        attention_mask = tokenized_inputs['attention_mask'].to(model_device)

        # The per-prompt row counts come from the same pass that builds the
        # strings, so they always agree with the rows actually fed to the model.
        inputs_with_contexts, context_lengths = self._build_context_inputs(prompts,
                                                                           context_prefix,
                                                                           contexts)
        tokenized_inputs_with_contexts = self.tokenizer(inputs_with_contexts, return_tensors="pt", padding=True,
                                                        truncation=True, max_length=max_positions)
        input_ids_with_context = tokenized_inputs_with_contexts['input_ids'].to(model_device)
        attention_mask_with_context = tokenized_inputs_with_contexts['attention_mask'].to(model_device)

        batch_size = input_ids.shape[0]
        eos_token_id = self.tokenizer.eos_token_id
        pad_token_id = self.tokenizer.pad_token_id
        repeats = torch.tensor(context_lengths, device=input_ids.device)
        unfinished = torch.ones(batch_size, dtype=torch.bool, device=input_ids.device)
        generated_tokens = [[] for _ in range(batch_size)]  # e.g., [[4132, 102, 29402], [2378, 7893, 23001]]

        with torch.no_grad(), tqdm(total=max_length, desc="EBD'ing", position=0) as pbar:
            for _ in range(max_length):
                maximum_entropy_layer_probs = self.compute_maximum_entropy_layer_probs(input_ids,
                                                                                       attention_mask)
                le_ens_scores = self.compute_le_ens_scores(input_ids_with_context,
                                                           attention_mask_with_context,
                                                           context_lengths)
                # Clamp before the log so a zero probability cannot produce -inf/NaN.
                tiny = torch.finfo(maximum_entropy_layer_probs.dtype).tiny
                maximum_entropy_log_probs = torch.log(maximum_entropy_layer_probs.clamp_min(tiny))
                next_token_logits = (1 + beta) * le_ens_scores - beta * maximum_entropy_log_probs

                if plot_distributions:
                    plain_logits = self.model(input_ids, attention_mask=attention_mask).logits[:, -1, :]
                    self.plot_next_word_distribution(plain_logits)
                    self.plot_next_word_distribution(next_token_logits)

                next_token = self.predict_next_token(logits=next_token_logits,
                                                     decoding_strategy=decoding_strategy,
                                                     top_p=top_p_value,
                                                     top_k=top_k_value,
                                                     use_repetition_penalty=use_repetition_penalty,
                                                     repetition_penalty_value=repetition_penalty_value,
                                                     generated_tokens=[set(tokens) for tokens in generated_tokens])
                if pad_token_id is not None:
                    # Finished rows only receive padding from now on.
                    next_token = torch.where(unfinished, next_token,
                                             torch.full_like(next_token, pad_token_id))

                input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
                # new_ones keeps the mask's integer dtype and device.
                attention_mask = torch.cat([attention_mask,
                                            attention_mask.new_ones((batch_size, 1))], dim=-1)
                # Every context row of a prompt is extended with that prompt's token.
                repeated_next_token = torch.repeat_interleave(next_token, repeats)
                input_ids_with_context = torch.cat([input_ids_with_context,
                                                    repeated_next_token.unsqueeze(-1)], dim=-1)
                attention_mask_with_context = torch.cat([
                    attention_mask_with_context,
                    attention_mask_with_context.new_ones((input_ids_with_context.shape[0], 1)),
                ], dim=-1)

                # Record the token for rows that were still active at this step.
                for row, (token, active) in enumerate(zip(next_token.tolist(), unfinished.tolist())):
                    if active:
                        generated_tokens[row].append(token)
                if eos_token_id is not None:
                    unfinished = unfinished & (next_token != eos_token_id)

                pbar.update(1)
                # Stop early once every row has produced an EOS token.
                if not bool(unfinished.any()):
                    break
        return generated_tokens


In [None]:
def set_seed(random_seed):
    """Seed Python, NumPy and PyTorch (CPU and all GPUs) and make cuDNN deterministic.

    Args:
        random_seed: Integer seed applied to every random number generator.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers the multi-GPU case
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(random_seed)

    # Trade cuDNN autotuning speed for reproducible kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(1002)

In [None]:
# Build the EBD wrapper around Mistral-7B-Instruct. `device=2` is forwarded to
# EBD.__init__, which turns it into the "cuda:2" device when CUDA is available.
# NOTE(review): the GPU index is hard-coded — confirm it exists on the target machine.
ebd_model = EBD(model_name="mistralai/Mistral-7B-Instruct-v0.3",
                device=2)

## Experiment 1: Compare generation with and without context-aware decoding

In [None]:
# Instruction template passed to `generate` as `context_prefix`; it carries a
# `{reference_id}` placeholder for the id of the accompanying reference case.
context_prefix = """
Here is a reference case, which you can use and must explicitly mention its reference id which is {reference_id}.
"""

# Retrieved reference passages: one inner list per prompt, each entry a
# (reference id, passage text) tuple, which construct_context_based_inputs
# unpacks as (context_id, context_text). The three passages below all carry
# the reference id "114 F.3d 596".
contexts = [[
(
    "114 F.3d 596",
    """
    Relations Act, 29 U.S.C. § 185. The parties filed cross-motions for summary judgment, and the district court enforced the award. The Beacon Journal filed this timely appeal. II. This court reviews the district court’s grant of summary judgment de novo. Rowley v. United States, 76 F.3d 796, 799 (6th Cir.1996). Nevertheless, our scope of review, like the review of the district court, is extremely limited. The Supreme Court has made clear in the Steelworkers’ Trilogy and its progeny that courts must accord an arbitrator’s decision substantial deference because it is the arbitrator’s construction of the agreement, not the court’s construction, to which the parties have agreed. See United Paperworkers Int’l Union v. Misco, 484 U.S. 29, 37-8, 108 S.Ct. 364, 371, 98 L.Ed.2d 286 (1987) (“Because the parties have contracted to have disputes settled by an arbitrator chosen by them rather than by a judge, it is the arbitrator’s view of the facts and of the meaning of the contract that they have agreed to accept.”). Hence, our review is extremely limited. We review the arbitrator’s decision only to determine whether the arbitrator was “arguably construing or applying the contract and acting within the scope of his authority.” Id. at 38, 108 S.Ct. at 371. If the arbitrator’s award “draws its essence from the collective bargaining agreement,” and is not merely the arbitrator’s “own brand of industrial justice,” the award is legitimate. United Steelworkers of Am. v. Enterprise Wheel & Car Co., 363 U.S. 593, 597, 80 S.Ct. 1358, 1361, 4 L.Ed.2d 1424 (1960). Courts will not weigh the merits of the claim or determine whether the claim is supported by language in the written instrument; otherwise, the policy of settling labor disputes through arbitration would be undermined. Misco, 484 U.S. at 36, 108 S.Ct. at 369-70; see also Unite
    """
),
(
    "114 F.3d 596",
    """
    any evidence that a member had “to modify or change his/her vacation plans due to the management’s ‘new interpretation of its rights under the vacation and management rights clauses of the labor agreement.” Arbitrator’s Decision, Slip op. at 6. In contrast, management was “vague on the specifics of not being able to meet the necessities of the supervisors and the production needs of the newspaper.” Id. The arbitrator made no further findings, but instead found that the Union’s grievance was justified. He then crafted his own solution, whereby the four new supervisors and the Union employees were thrown into a “seniority pool” for vacation selection purposes. He also provided for a grievance procedure through the Union for employees that believed they were adversely affected by the new procedure. The Beacon Journal refused to comply with the arbitration award and instead instituted this lawsuit under section 801 of the Labor Management Relations Act, 29 U.S.C. § 185. The parties filed cross-motions for summary judgment, and the district court enforced the award. The Beacon Journal filed this timely appeal. II. This court reviews the district court’s grant of summary judgment de novo. Rowley v. United States, 76 F.3d 796, 799 (6th Cir.1996). Nevertheless, our scope of review, like the review of the district court, is extremely limited. The Supreme Court has made clear in the Steelworkers’ Trilogy and its progeny that courts must accord an arbitrator’s decision substantial deference because it is the arbitrator’s construction of the agreement, not the court’s construction, to which the parties have agreed. See United Paperworkers Int’l Union v. Misco, 484 U.S. 29, 37-8, 108 S.Ct. 364, 371, 98 L.Ed.2d 286 (1987) (“Because the parties have contracted to have disputes settled by an arbitrator chosen by them rather than by a judge, it is the arbitrator’s view
    """
),
(
    "114 F.3d 596",
    """
    of the facts and of the meaning of the contract that they have agreed to accept.”). Hence, our review is extremely limited. We review the arbitrator’s decision only to determine whether the arbitrator was “arguably construing or applying the contract and acting within the scope of his authority.” Id. at 38, 108 S.Ct. at 371. If the arbitrator’s award “draws its essence from the collective bargaining agreement,” and is not merely the arbitrator’s “own brand of industrial justice,” the award is legitimate. United Steelworkers of Am. v. Enterprise Wheel & Car Co., 363 U.S. 593, 597, 80 S.Ct. 1358, 1361, 4 L.Ed.2d 1424 (1960). Courts will not weigh the merits of the claim or determine whether the claim is supported by language in the written instrument; otherwise, the policy of settling labor disputes through arbitration would be undermined. Misco, 484 U.S. at 36, 108 S.Ct. at 369-70; see also United Steelworkers of Am. v. American Mfg. Co., 363 U.S. 564, 568, 80 S.Ct. 1343, 1346, 4 L.Ed.2d 1403 (1960) (“[C]ourts, therefore, have no business weighing the merits of the grievance, considering whether there is equity in a particular claim, or determining whether there is particular language in the written instrument which will support the claim.”). Despite the great amount of deference accorded an arbitrator’s decision, our review is not toothless when an arbitrator’s award disregards the collective bargaining agreement and its terms. See Lattimer-Stevens Co. v. United Steelworkers, 913 F.2d 1166, 1171-72 (6th Cir.1990) (Boggs, J., dissenting) (delineating eases setting aside arbitrator’s decision). Even though arbitrators are not flawless, courts must refrain from reversing an arbitrator simply because the court disagrees with the result or believes the arbitrator made a serious legal or factual error. Misco, 484 U.S. at 38, 108 S.Ct. at 371 (“that a court is convinced [the
    """
)
]]

# A single prompt: a writing instruction followed by the opening sections of a
# court opinion, ending at the "II" heading where the model should continue.
prompts = ["""
Continue to write the following case using the style of my write up. Your answer contains from 100 to 400 words. Make your answer concise, relevant and avoid redundant language.

BEER, District Judge.
Alken-Ziegler, Incorporated, (Company) appeals from the district court’s grant of summary judgment affirming an arbitration award in favor of the International Union, United Automobile, Aerospace and Agricultural Implement Workers of America, and Local Union 985 (Union). For the following reasons, we find that, even in light of our deferential review, the arbitrator disregarded the provisions of the labor contract. Therefore, we reverse the district court’s decision and vacate the arbitration award.
I
The Company and the Union were parties to a labor contract effective December 15, 1999. In March, 2001, the Company notified the Union that it would be closing its Novi plant and that it would be necessary to terminate all of the employees at the facility. As a result of the plant closing on October 17, 2001, all but one employee was terminated during the calendar year, 2001. The Company refused to pay vacationpay benefits to employees who did not work for the Company on January 1, 2002. The Union filed a grievance.
Article 16 (61) of the labor agreement sets forth the eligibility requirement for payment of vacation benefits:
(a) Employees shall be eligible for vacations, time off and vacation pay as set forth below.
(b) For purposes of eligibility, the vacation year will be considered the calendar year period from January 1st to December 31.
(c) An employee covered by the agreement who is actually working on January 1st of any year and who has at least six (6) months seniority and has' worked at least eight hundred (800) hours from and after January 1st of the previous year shall be paid the equivalent of two-and-one half (2-1/2) days vacation pay.
ijs ifc tjc %
(f) Employees with twelve (12) months or more of seniority who have worked more than eight hundred (800) hours, but less than sixteen hundred (1600) hours, during the vacation year, shall receive a pro-rated vacation pay on the basis of the ratio of their actual hours to sixteen hundred (1600) hours, but not to exceed the full vacation pay to which they were entitled by reason of their seniority and hours worked as set forth above.
(g) Vacation pay will be computed on a straight time forty (40) hour basis including applicable shift premium. The employee’s hour basis including applicable shift premium. The employee’s hourly rate in effect when vacation is taken will be used to compute vacation pay. If an employee is laid off after six (6) months service, their vacation pay will be pro-rated same as above.
Pursuant to Article 5 of the labor contract, the parties arbitrated the grievance. At the arbitration the Union asserted that because it was not the employees’ fault that they were unable to work the full year, the employees were entitled to their vacation pay. The arbitrator granted the grievance, allowing all plaintiffs, who, but for being laid off, would have been able to continue employment and thereby qualify for vacation benefits. The arbitrator reasoned that “[i]t would be unreasonable to cause such forfeitures particularly where an employee has no control over the situation.”
The Company filed a complaint in the district court asserting that the arbitrator’s award contradicted the clear, mandatory commands of the labor contract, which required that an employee be “actually working” for the Company as of January 1, 2002, to receive vacation pay. The district court granted the Union’s motion for summary judgment and upheld the arbitrator’s award. The Company appealed.
II
"""]

# Generation settings for this run.
max_length = 4
decoding_strategy = 'greedy'
use_repetition_penalty = True
repetition_penalty_value = 1.2
method = 'ebd'

outputs = ebd_model.generate(
                            prompts=prompts,
                            contexts=contexts,
                            max_length=max_length,
                            decoding_strategy=decoding_strategy,
                            context_prefix=context_prefix,
                            beta=0.5,
                            use_repetition_penalty=use_repetition_penalty,
                            repetition_penalty_value=repetition_penalty_value,
                            )

# Decode once and reuse the text for both the console and the saved file.
answer_text = ebd_model.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
print("Answer:")
print(answer_text)

# NOTE(review): `rep_extension` already starts with "_rep_" and the filename adds
# another "_rep_", so saved names contain "_rep__rep_". The pattern is kept
# unchanged so existing output files keep their names — confirm before changing.
rep_extension = f'_rep_{f"{use_repetition_penalty}_rep_value_{repetition_penalty_value}" if use_repetition_penalty else use_repetition_penalty}'
filename = f"../basement/cad_generations/output_{method}_{decoding_strategy}_rep_{rep_extension}_{max_length}.txt"
os.makedirs(os.path.dirname(filename), exist_ok=True)
# The prompt contains non-ASCII characters (curly quotes, "§"), so the encoding
# is set explicitly instead of relying on the platform default.
with open(filename, 'w', encoding="utf-8") as file:
    file.write("Prompt:\n")
    file.write(prompts[0])
    file.write("\n\nAnswer:\n")
    file.write(answer_text)

In [None]:
# Decode every generated sequence in the batch (special tokens stripped) and
# print each one with its batch index.
decoded_output = ebd_model.tokenizer.batch_decode(outputs, skip_special_tokens=True)
for i, output in enumerate(decoded_output):
    print(f"Output {i}: {output}")