In [1]:
# pip install accelerate transformers tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, logging
import torch
from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings('ignore')

# Disable transformers warnings
logging.set_verbosity_error()

In [25]:
class DebateTemplate:
  template = \
  f"""This is a debate. Below is the series of messages associated with this debate.

  Judge: I will now ask you a question. I want you to respond to the question with your opinion, keeping it concise, then the opposing debator will respond with their rebuttal. You will go back and forth, and at the end, I will reward the best model.
  """
  
  def __init__(self):
    pass

  def start_debate(self, question: str):
    self.template += "\n"
    self.template += f"Judge: {question} Debator 1, your turn."

  def add_message(self, message: str, role: str):
    self.template += "\n"
    self.template += f"{role}: {message}"

  def to_message(self) -> list[dict]:
    return [{"role": "user", "content": self.template}]

  @classmethod
  def to_dict(cls, message, role):
    return {"role": role, "content": message}

In [26]:
from transformers import TextStreamer, LlamaForCausalLM

class Model:
  tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    use_fast=False,  # Ensures compatibility with all models
  )
  model: LlamaForCausalLM = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    device_map="auto",
    torch_dtype=torch.bfloat16,
  )
  streamer = TextStreamer(tokenizer)

  def __init__(self, context=[]):
    self.context = context


  def prepare_text(self, messages: list[dict]) -> str:
    messages = [
      *self.context,
      *messages,
    ]
    text = self.tokenizer.apply_chat_template(
      messages,
      tokenize=False,
      add_generation_prompts=False
    )
    return text

  @classmethod
  def generate(cls, formatted_text: str):
    input_ids = cls.tokenizer(formatted_text, return_tensors="pt", add_special_tokens=False).to("cuda")
    input_length = len(input_ids['input_ids'][0])
    outputs = cls.model.generate(
      **input_ids,
      max_length=312,          # Increase maximum length
      min_length=15,           # Set minimum length if desired
      pad_token_id=cls.tokenizer.eos_token_id,  # Proper padding
      temperature=0.7,        # Control randomness (0.0 to 1.0)
      no_repeat_ngram_size=2,  # Avoid repetition)
    )
    generated_text = cls.tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
    return generated_text
    
  @classmethod
  def generate_stream(cls, formatted_text: str):
    input_ids = cls.tokenizer(formatted_text, return_tensors="pt", add_special_tokens=False).to("cuda")
    outputs = cls.model.generate(
      **input_ids,
      max_length=312,          # Increase maximum length
      min_length=30,           # Set minimum length if desired
      num_return_sequences=1,  # Number of outputs to generate
      pad_token_id=cls.tokenizer.eos_token_id,  # Proper padding
      streamer=cls.streamer,
      do_sample=True,         # Enable sampling
      temperature=0.7,        # Control randomness (0.0 to 1.0)
      no_repeat_ngram_size=2  # Avoid repetition)
    )
    for output in outputs:
      yield cls.tokenizer.decode(output, skip_special_tokens=True)

  @classmethod
  def prefix_allowed_tokens_fn(cls, batch_id, input_ids):
    # Skip the input tokens
    return input_ids[len(input_ids):]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [27]:
context_1 = [
  DebateTemplate.to_dict("You are going to be engaged in a debate with an opponent. You are Debator ID=1.","system"),
  DebateTemplate.to_dict("<Debator 1>I am Debator 1</Debator 1>","assistant")
]

context_2 = [
  DebateTemplate.to_dict("You are going to be engaged in a debate with an opponent. You are Debator ID=2.","system"),
  DebateTemplate.to_dict("<Debator 2>I am Debator 2</Debator 2>","assistant")
]

model_1 = Model(context_1)
model_2 = Model(context_2)

template = DebateTemplate()

template.start_debate("Is Universal Basic Income good?")


In [30]:
text = model_1.prepare_text(template.to_message())

model_1.generate(text)

''

In [31]:
text

'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYou are going to be engaged in a debate with an opponent. You are Debator ID=1.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n<Debator 1>I am Debator 1</Debator 1><|eot_id|><|start_header_id|>user<|end_header_id|>\n\nThis is a debate. Below is the series of messages associated with this debate.\n\n  Judge: I will now ask you a question. I want you to respond to the question with your opinion, keeping it concise, then the opposing debator will respond with their rebuttal. You will go back and forth, and at the end, I will reward the best model.\n  \nJudge: Is Universal Basic Income good? Debator 1, your turn.<|eot_id|>'

In [36]:
model = Model()

testing = "What do you think about UBI?"

text = model.prepare_text([{"role":"user","content":testing}])
model.generate("You suck")

"assistant\n\n\nUniversal Basic Income (UBI) is a concept where every citizen or resident of a country receives a regular, unconditional sum of money from the government to cover their basic needs. The idea has been debated and experimented with in various forms around the world.\n\nHere are some pros and cons of UBU:\n\nPros:\n\n1. **Simplification of welfare systems**: Ubi could potentially consolidate and simplify existing welfare programs, reducing bureaucracy and administrative costs.\n2. ***Poverty reduction***: A guaranteed minimum income could help alleviate poverty and provide a safety net for the most vulnerable members of society.\n3. *Economic stimulus*: UBi could put more money in people's pockets, boosting local economies and encouraging spending.\n4. ****Freedom and autonomy****: UBio could give people the financial security to pursue their passions and interests, rather than just taking any job for a living wage.\n5. *****Social cohesion*****: By providing a basic incom