# Set up Environment

In [None]:
!pip install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=11.8 -c pytorch -c nvidia
!pip install tqdm==4.66.1
!pip install requests==2.31.0
!pip install importlib-metadata==4.13.0
!pip install filelock==3.0.12
!pip install scikit-learn==1.2.2
!pip install numpy==1.26.3
!pip install tokenizers==0.13.3
!pip install sentencepiece==0.1.99
#!wget https://www.cs.cmu.edu/~vijayv/stories42M.pt

# utils.py

In [1]:
from typing import Dict, List, Optional, Union, Tuple, BinaryIO
import os
import sys
import json
import tempfile
import copy
from tqdm.auto import tqdm
from functools import partial
from urllib.parse import urlparse
from pathlib import Path
import requests
from hashlib import sha256
from filelock import FileLock
import importlib_metadata
import torch
import torch.nn as nn
from torch import Tensor

# Version string embedded in the user agent built by `http_user_agent`.
__version__ = "4.0.0"
# Installed torch version, read from package metadata; also embedded in the user agent.
_torch_version = importlib_metadata.version("torch")

# Cache root: $HF_HOME if set, otherwise "<$XDG_CACHE_HOME or ~/.cache>/huggingface"
# (with "~" expanded to the user's home directory).
hf_cache_home = os.path.expanduser(os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")))
default_cache_path = os.path.join(hf_cache_home, "transformers")
# Each variable below falls back to the one above it, so the effective precedence is
# TRANSFORMERS_CACHE > PYTORCH_TRANSFORMERS_CACHE > PYTORCH_PRETRAINED_BERT_CACHE > default_cache_path.
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)

# Short mirror names accepted by `hf_bucket_url`; any other mirror value is used as the
# endpoint verbatim.
PRESET_MIRROR_DICT = {
    "tuna": "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models",
    "bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
}
# URL template filled in by `hf_bucket_url` when no mirror is requested.
HUGGINGFACE_CO_PREFIX = "https://huggingface.co/{model_id}/resolve/{revision}/{filename}"
# Presumably the default weights file name; not referenced in the visible part of this file.
WEIGHTS_NAME = "pytorch_model.bin"
# File name looked up inside a model directory / repo by `PretrainedConfig.get_config_dict`.
CONFIG_NAME = "config.json"


def is_torch_available():
  """Report whether PyTorch may be used; hard-wired to ``True`` in this file."""
  torch_usable = True
  return torch_usable


def is_tf_available():
  """Report whether TensorFlow may be used; hard-wired to ``False`` in this file."""
  tf_usable = False
  return tf_usable


def is_remote_url(url_or_filename):
  """Return ``True`` when the argument parses as a URL whose scheme is ``http`` or ``https``."""
  scheme = urlparse(url_or_filename).scheme
  return scheme == "http" or scheme == "https"


def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
  """
  Stream the body of ``url`` into ``temp_file`` while showing a progress bar.

  Args:
    url: Address to download with an HTTP GET.
    temp_file: Writable binary file object that receives the downloaded bytes.
    proxies: Passed through to ``requests.get``.
    resume_size: Number of bytes already present locally. When positive, a
      ``Range: bytes=<resume_size>-`` header is sent so only the remainder is requested.
    headers: Extra request headers. The mapping is copied, never mutated.

  Raises:
    requests.HTTPError: via ``raise_for_status`` when the server answers with an error status.
  """
  # Work on a copy so the caller's dict is left untouched.
  headers = copy.deepcopy(headers)
  if resume_size > 0:
    # Resuming without caller-supplied headers used to crash on `None["Range"]`.
    if headers is None:
      headers = {}
    headers["Range"] = "bytes=%d-" % (resume_size,)
  r = requests.get(url, stream=True, proxies=proxies, headers=headers)
  r.raise_for_status()
  content_length = r.headers.get("Content-Length")
  # Content-Length only covers the bytes still to come, so add what we already have.
  total = resume_size + int(content_length) if content_length is not None else None
  progress = tqdm(
    unit="B",
    unit_scale=True,
    total=total,
    initial=resume_size,
    desc="Downloading",
    disable=False,
  )
  try:
    for chunk in r.iter_content(chunk_size=1024):
      if chunk:  # filter out keep-alive new chunks
        progress.update(len(chunk))
        temp_file.write(chunk)
  finally:
    # Always release the progress bar, even if the transfer or the write fails midway.
    progress.close()


def url_to_filename(url: str, etag: Optional[str] = None) -> str:
  """
  Derive a cache file name from a URL and, optionally, its ETag.

  The name is the SHA-256 hex digest of the URL. A non-empty ``etag`` contributes a
  second digest joined with a period, and URLs ending in ``.h5`` keep that extension.
  """
  digests = [sha256(url.encode("utf-8")).hexdigest()]
  if etag:
    digests.append(sha256(etag.encode("utf-8")).hexdigest())

  name = ".".join(digests)
  return name + ".h5" if url.endswith(".h5") else name


def hf_bucket_url(
  model_id: str, filename: str, subfolder: Optional[str] = None, revision: Optional[str] = None, mirror=None
) -> str:
  """
  Build the download URL of ``filename`` for the model ``model_id``.

  Args:
    model_id: Model identifier; identifiers without a ``/`` use the legacy mirror layout.
    filename: Name of the file inside the model repository.
    subfolder: Optional folder prepended to ``filename``.
    revision: Revision placed in the huggingface.co URL; ``None`` means ``"main"``.
      Ignored when a mirror is used.
    mirror: Either a key of ``PRESET_MIRROR_DICT`` or a full endpoint URL.
  """
  path_in_repo = filename if subfolder is None else f"{subfolder}/{filename}"

  if not mirror:
    resolved_revision = "main" if revision is None else revision
    return HUGGINGFACE_CO_PREFIX.format(model_id=model_id, revision=resolved_revision, filename=path_in_repo)

  endpoint = PRESET_MIRROR_DICT.get(mirror, mirror)
  # Legacy ids (no namespace) are joined to the file name with a dash instead of a slash.
  joiner = "/" if "/" in model_id else "-"
  return f"{endpoint}/{model_id}{joiner}{path_in_repo}"


def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
  """
  Build the ``user-agent`` header value sent with download requests.

  The value starts with the library and Python versions, adds the framework versions
  reported as available, and ends with the caller-supplied ``user_agent``: a dict is
  rendered as ``key/value`` pairs, a string is appended unchanged.
  """
  fields = ["transformers/{}; python/{}".format(__version__, sys.version.split()[0])]
  if is_torch_available():
    fields.append(f"torch/{_torch_version}")
  if is_tf_available():
    fields.append(f"tensorflow/{_tf_version}")

  if isinstance(user_agent, dict):
    fields.append("; ".join("{}/{}".format(k, v) for k, v in user_agent.items()))
  elif isinstance(user_agent, str):
    fields.append(user_agent)
  return "; ".join(fields)


def get_from_cache(
  url: str,
  cache_dir=None,
  force_download=False,
  proxies=None,
  etag_timeout=10,
  resume_download=False,
  user_agent: Union[Dict, str, None] = None,
  use_auth_token: Union[bool, str, None] = None,
  local_files_only=False,
) -> Optional[str]:
  """
  Return the local path of the cached copy of ``url``, downloading it first if needed.

  The cache file name is derived from the URL and the ETag reported by a HEAD request.
  When no ETag can be obtained (no connection, timeout, or ``local_files_only``), the
  function falls back to a previously cached file for the same URL.

  Args:
    url: Remote address of the file.
    cache_dir: Directory holding cached files; defaults to ``TRANSFORMERS_CACHE``.
    force_download: Download again even if the cache already holds the file.
    proxies: Passed through to ``requests``.
    etag_timeout: Timeout in seconds for the HEAD request.
    resume_download: Append to a ``<cache_path>.incomplete`` file instead of starting over.
    user_agent: Extra information for the ``user-agent`` header.
    use_auth_token: A token string, or ``True`` to use the stored token.
    local_files_only: Never touch the network; only look in the cache.

  Returns:
    Path of the cached file.

  Raises:
    EnvironmentError: ``use_auth_token=True`` but no stored token was found.
    OSError: The server answered the HEAD request without any ETag.
    FileNotFoundError: ``local_files_only`` is set and nothing is cached for ``url``.
    ValueError: The server could not be reached and nothing is cached for ``url``.
    requests.HTTPError: The HEAD or GET request returned an error status.
  """
  # Imported here because the import cell of this file does not provide them; without
  # these the offline fallback and the resume path fail with NameError.
  import fnmatch
  from contextlib import contextmanager

  if cache_dir is None:
    cache_dir = TRANSFORMERS_CACHE
  if isinstance(cache_dir, Path):
    cache_dir = str(cache_dir)

  os.makedirs(cache_dir, exist_ok=True)

  headers = {"user-agent": http_user_agent(user_agent)}
  if isinstance(use_auth_token, str):
    headers["authorization"] = "Bearer {}".format(use_auth_token)
  elif use_auth_token:
    # NOTE(review): `HfFolder` is not imported in the visible part of this file (upstream
    # it comes from the third-party `huggingface_hub` package) — confirm it is in scope
    # before relying on use_auth_token=True.
    token = HfFolder.get_token()
    if token is None:
      raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
    headers["authorization"] = "Bearer {}".format(token)

  url_to_download = url
  etag = None
  if not local_files_only:
    try:
      r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
      r.raise_for_status()
      etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
      # We favor a custom header indicating the etag of the linked resource, and
      # we fallback to the regular etag header.
      # If we don't have any of those, raise an error.
      if etag is None:
        raise OSError(
          "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
        )
      # In case of a redirect,
      # save an extra redirect on the request.get call,
      # and ensure we download the exact atomic version even if it changed
      # between the HEAD and the GET (unlikely, but hey).
      if 300 <= r.status_code <= 399:
        url_to_download = r.headers["Location"]
    except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
      # etag is already None
      pass

  filename = url_to_filename(url, etag)

  # get cache path to put the file
  cache_path = os.path.join(cache_dir, filename)

  # etag is None == we don't have a connection or we passed local_files_only.
  # try to get the last downloaded one
  if etag is None:
    if os.path.exists(cache_path):
      return cache_path
    else:
      # Any cached file sharing the URL digest (whatever its ETag digest) is acceptable;
      # metadata (.json) and lock (.lock) side files are not.
      matching_files = [
        file
        for file in fnmatch.filter(os.listdir(cache_dir), filename.split(".")[0] + ".*")
        if not file.endswith(".json") and not file.endswith(".lock")
      ]
      if len(matching_files) > 0:
        return os.path.join(cache_dir, matching_files[-1])
      else:
        # If files cannot be found and local_files_only=True,
        # the models might've been found if local_files_only=False
        # Notify the user about that
        if local_files_only:
          raise FileNotFoundError(
            "Cannot find the requested files in the cached path and outgoing traffic has been"
            " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
            " to False."
          )
        else:
          raise ValueError(
            "Connection error, and we cannot find the requested files in the cached path."
            " Please try again or make sure your Internet connection is on."
          )

  # From now on, etag is not None.
  if os.path.exists(cache_path) and not force_download:
    return cache_path

  # Prevent parallel downloads of the same file with a lock.
  lock_path = cache_path + ".lock"
  with FileLock(lock_path):

    # If the download just completed while the lock was activated.
    if os.path.exists(cache_path) and not force_download:
      # Even if returning early like here, the lock will be released.
      return cache_path

    if resume_download:
      incomplete_path = cache_path + ".incomplete"

      @contextmanager
      def _resumable_file_manager() -> "io.BufferedWriter":
        # Append mode keeps the bytes fetched by an earlier, interrupted attempt.
        with open(incomplete_path, "ab") as f:
          yield f

      temp_file_manager = _resumable_file_manager
      if os.path.exists(incomplete_path):
        resume_size = os.stat(incomplete_path).st_size
      else:
        resume_size = 0
    else:
      temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False)
      resume_size = 0

    # Download to temporary file, then copy to cache dir once finished.
    # Otherwise you get corrupt cache entries if the download gets interrupted.
    with temp_file_manager() as temp_file:
      http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, headers=headers)

    os.replace(temp_file.name, cache_path)

    # Record where the cached file came from next to it.
    meta = {"url": url, "etag": etag}
    meta_path = cache_path + ".json"
    with open(meta_path, "w") as meta_file:
      json.dump(meta, meta_file)

  return cache_path


def cached_path(
  url_or_filename,
  cache_dir=None,
  force_download=False,
  proxies=None,
  resume_download=False,
  user_agent: Union[Dict, str, None] = None,
  extract_compressed_file=False,
  force_extract=False,
  use_auth_token: Union[bool, str, None] = None,
  local_files_only=False,
) -> Optional[str]:
  """
  Resolve a URL or a local path to a local file, optionally extracting an archive.

  Remote URLs (http/https) are fetched through ``get_from_cache``; existing local
  paths are returned as they are.

  Args:
    url_or_filename: Remote URL or local path (``str`` or ``Path``).
    cache_dir: Cache directory; defaults to ``TRANSFORMERS_CACHE``.
    force_download, proxies, resume_download, user_agent, use_auth_token,
    local_files_only: Forwarded to ``get_from_cache`` for remote URLs.
    extract_compressed_file: When true and the resolved file is a zip or tar archive,
      extract it next to the file and return the extraction directory instead.
    force_extract: Extract again even if a non-empty extraction directory exists.

  Returns:
    Path of the local file, or of the extraction directory.

  Raises:
    EnvironmentError: The argument looks like a local path that does not exist, or the
      archive format could not be identified.
    ValueError: The argument is neither a supported URL nor an existing path.
  """
  # Imported here because the import cell of this file does not provide them; without
  # these the extraction branch fails with NameError.
  import shutil
  import tarfile
  from zipfile import ZipFile, is_zipfile

  if cache_dir is None:
    cache_dir = TRANSFORMERS_CACHE
  if isinstance(url_or_filename, Path):
    url_or_filename = str(url_or_filename)
  if isinstance(cache_dir, Path):
    cache_dir = str(cache_dir)

  if is_remote_url(url_or_filename):
    # URL, so get it from the cache (downloading if necessary)
    output_path = get_from_cache(
      url_or_filename,
      cache_dir=cache_dir,
      force_download=force_download,
      proxies=proxies,
      resume_download=resume_download,
      user_agent=user_agent,
      use_auth_token=use_auth_token,
      local_files_only=local_files_only,
    )
  elif os.path.exists(url_or_filename):
    # File, and it exists.
    output_path = url_or_filename
  elif urlparse(url_or_filename).scheme == "":
    # File, but it doesn't exist.
    raise EnvironmentError("file {} not found".format(url_or_filename))
  else:
    # Something unknown
    raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))

  if extract_compressed_file:
    if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
      return output_path

    # Path where we extract compressed archives
    # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
    output_dir, output_file = os.path.split(output_path)
    output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
    output_path_extracted = os.path.join(output_dir, output_extract_dir_name)

    if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
      return output_path_extracted

    # Prevent parallel extractions
    lock_path = output_path + ".lock"
    with FileLock(lock_path):
      shutil.rmtree(output_path_extracted, ignore_errors=True)
      os.makedirs(output_path_extracted)
      # SECURITY NOTE(review): extractall() trusts member paths inside the archive. For
      # archives downloaded from untrusted sources, validate member names (or use an
      # extraction filter where the Python version supports it) before extracting.
      if is_zipfile(output_path):
        with ZipFile(output_path, "r") as zip_file:
          zip_file.extractall(output_path_extracted)
      elif tarfile.is_tarfile(output_path):
        # Context manager closes the archive even if extraction fails.
        with tarfile.open(output_path) as tar_file:
          tar_file.extractall(output_path_extracted)
      else:
        raise EnvironmentError("Archive format of {} could not be identified".format(output_path))

    return output_path_extracted

  return output_path


def get_parameter_dtype(parameter: Union[nn.Module]):
  """
  Return the dtype of the first parameter of a module.

  If the module exposes no parameters, the dtype of the first tensor found among the
  instance attributes of the module tree is returned instead (kept for
  nn.DataParallel compatibility in PyTorch 1.5).
  """
  for first_param in parameter.parameters():
    return first_param.dtype

  # No registered parameters: look for plain tensor attributes instead.
  def _tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
    return [(name, value) for name, value in module.__dict__.items() if torch.is_tensor(value)]

  members = parameter._named_members(get_members_fn=_tensor_attributes)
  _, first_tensor = next(members)
  return first_tensor.dtype


def get_extended_attention_mask(attention_mask: Tensor, dtype) -> Tensor:
  """
  Turn a 2-D ``[batch_size, seq_length]`` mask into an additive attention mask.

  The result has shape ``[batch_size, 1, 1, seq_length]`` and the requested ``dtype``;
  entries that were 1 become 0 and entries that were 0 become -10000.
  """
  assert attention_mask.dim() == 2
  # Two singleton axes let the mask broadcast over heads and query positions.
  broadcastable = attention_mask.unsqueeze(1).unsqueeze(2)
  broadcastable = broadcastable.to(dtype=dtype)  # fp16 compatibility
  return (1.0 - broadcastable) * -10000.0

# config.py

In [2]:
from typing import Union, Tuple, Dict, Any, Optional
import os
import json
from collections import OrderedDict
import torch
# from utils import CONFIG_NAME, hf_bucket_url, cached_path, is_remote_url

class PretrainedConfig(object):
  """
  Base class holding the configuration attributes shared by models in this file.

  Known keyword arguments are stored with the defaults listed in ``__init__``; every
  other keyword argument becomes an attribute of the same name. Configurations can be
  built from a dict (``from_dict``) or loaded from a JSON file, a directory, a URL or
  a model identifier (``from_pretrained``).
  """
  model_type: str = ""
  is_composition: bool = False

  def __init__(self, **kwargs):
    # Attributes with defaults
    self.return_dict = kwargs.pop("return_dict", True)
    self.output_hidden_states = kwargs.pop("output_hidden_states", False)
    self.output_attentions = kwargs.pop("output_attentions", False)
    self.torchscript = kwargs.pop("torchscript", False)  # Only used by PyTorch models
    self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
    self.pruned_heads = kwargs.pop("pruned_heads", {})
    self.tie_word_embeddings = kwargs.pop(
      "tie_word_embeddings", True
    )  # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models.

    # Is decoder is used in encoder-decoder models to differentiate encoder from decoder
    self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
    self.is_decoder = kwargs.pop("is_decoder", False)
    self.add_cross_attention = kwargs.pop("add_cross_attention", False)
    self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)

    # Parameters for sequence generation
    self.max_length = kwargs.pop("max_length", 20)
    self.min_length = kwargs.pop("min_length", 0)
    self.do_sample = kwargs.pop("do_sample", False)
    self.early_stopping = kwargs.pop("early_stopping", False)
    self.num_beams = kwargs.pop("num_beams", 1)
    self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
    self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
    self.temperature = kwargs.pop("temperature", 1.0)
    self.top_k = kwargs.pop("top_k", 50)
    self.top_p = kwargs.pop("top_p", 1.0)
    self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
    self.length_penalty = kwargs.pop("length_penalty", 1.0)
    self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
    self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
    self.bad_words_ids = kwargs.pop("bad_words_ids", None)
    self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
    self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
    self.output_scores = kwargs.pop("output_scores", False)
    self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
    self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
    self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)

    # Fine-tuning task arguments
    self.architectures = kwargs.pop("architectures", None)
    self.finetuning_task = kwargs.pop("finetuning_task", None)
    self.id2label = kwargs.pop("id2label", None)
    self.label2id = kwargs.pop("label2id", None)
    if self.id2label is not None:
      # An explicit label map wins over any `num_labels` argument.
      kwargs.pop("num_labels", None)
      self.id2label = dict((int(key), value) for key, value in self.id2label.items())
      # Keys are always strings in JSON so convert ids to int here.
      # Keep `num_labels` defined in this branch too, so reading it never raises
      # AttributeError; it is derived from the label map.
      self.num_labels = len(self.id2label)
    else:
      self.num_labels = kwargs.pop("num_labels", 2)

    # Tokenizer arguments
    self.tokenizer_class = kwargs.pop("tokenizer_class", None)
    self.prefix = kwargs.pop("prefix", None)
    self.bos_token_id = kwargs.pop("bos_token_id", None)
    self.pad_token_id = kwargs.pop("pad_token_id", None)
    self.eos_token_id = kwargs.pop("eos_token_id", None)
    self.sep_token_id = kwargs.pop("sep_token_id", None)

    self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)

    # task specific arguments
    self.task_specific_params = kwargs.pop("task_specific_params", None)

    # TPU arguments
    self.xla_device = kwargs.pop("xla_device", None)

    # Name or path to the pretrained checkpoint
    self._name_or_path = str(kwargs.pop("name_or_path", ""))

    # Drop the transformers version info
    kwargs.pop("transformers_version", None)

    # Additional attributes without default values; an AttributeError from setattr
    # propagates to the caller.
    for key, value in kwargs.items():
      setattr(self, key, value)

  @classmethod
  def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
    """Load the configuration dict for the given name or path and build a config from it."""
    config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
    return cls.from_dict(config_dict, **kwargs)

  @classmethod
  def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
    """Read a UTF-8 JSON file and return the parsed object."""
    with open(json_file, "r", encoding="utf-8") as reader:
      text = reader.read()
    return json.loads(text)

  @classmethod
  def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
    """
    Build a config from ``config_dict`` and override existing attributes with ``kwargs``.

    Keyword arguments that do not match an attribute of the new config are left alone;
    pass ``return_unused_kwargs=True`` to get them back as ``(config, unused_kwargs)``.
    """
    return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

    config = cls(**config_dict)

    if hasattr(config, "pruned_heads"):
      # JSON object keys are strings; layer indices are stored as ints.
      config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())

    # Update config with kwargs if needed
    to_remove = []
    for key, value in kwargs.items():
      if hasattr(config, key):
        setattr(config, key, value)
        to_remove.append(key)
    for key in to_remove:
      kwargs.pop(key, None)

    if return_unused_kwargs:
      return config, kwargs
    else:
      return config

  @classmethod
  def get_config_dict(
    cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
  ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Locate and parse the JSON configuration for ``pretrained_model_name_or_path``.

    The argument may be a directory containing ``CONFIG_NAME``, a path to a JSON file,
    a remote URL, or a model identifier resolved through ``hf_bucket_url``.
    Download-related keyword arguments (``cache_dir``, ``force_download``,
    ``resume_download``, ``proxies``, ``use_auth_token``, ``local_files_only``,
    ``revision``) are consumed here.

    Returns:
      ``(config_dict, remaining_kwargs)``.

    Raises:
      EnvironmentError: The file could not be resolved or is not valid JSON. The
        original exception is attached as ``__cause__``.
    """
    cache_dir = kwargs.pop("cache_dir", None)
    force_download = kwargs.pop("force_download", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    use_auth_token = kwargs.pop("use_auth_token", None)
    local_files_only = kwargs.pop("local_files_only", False)
    revision = kwargs.pop("revision", None)

    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    if os.path.isdir(pretrained_model_name_or_path):
      config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
    elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
      config_file = pretrained_model_name_or_path
    else:
      config_file = hf_bucket_url(
        pretrained_model_name_or_path, filename=CONFIG_NAME, revision=revision, mirror=None
      )

    try:
      # Load from URL or cache if already cached
      resolved_config_file = cached_path(
        config_file,
        cache_dir=cache_dir,
        force_download=force_download,
        proxies=proxies,
        resume_download=resume_download,
        local_files_only=local_files_only,
        use_auth_token=use_auth_token,
      )
      # Load config dict
      config_dict = cls._dict_from_json_file(resolved_config_file)

    except EnvironmentError as err:
      msg = (
        f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
        f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
        f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
      )
      raise EnvironmentError(msg) from err

    except json.JSONDecodeError as err:
      # Only the JSON parsing step can raise this, so resolved_config_file is bound here.
      msg = (
        "Couldn't reach server at '{}' to download configuration file or "
        "configuration file is not a valid JSON file. "
        "Please check network or file content here: {}.".format(config_file, resolved_config_file)
      )
      raise EnvironmentError(msg) from err

    return config_dict, kwargs

class LlamaConfig(PretrainedConfig):
  """
  Configuration of the Llama model defined in this file.

  Args:
    vocab_size: Stored as ``self.vocab_size``; presumably the tokenizer vocabulary size.
    dim: Model (residual stream) width; divided by ``n_heads`` to get the head size.
    dropout: Dropout probability passed to the ``nn.Dropout`` layers.
    n_layers: Stored as ``self.n_layers``; presumably the number of transformer layers.
    n_heads: Number of query heads.
    n_kv_heads: Number of key/value heads; ``None`` means "same as ``n_heads``".
    max_seq_len: Maximum sequence length supported by the model.
    layer_norm_eps: Epsilon handed to the RMSNorm layers.
    multiple_of: The automatically sized feed-forward width is rounded up to a
      multiple of this value.
    hidden_dim: Feed-forward width; ``None`` lets the feed-forward block derive it
      from ``dim``.
    position_embedding_type: Stored as given; not read in the visible part of the file.
    use_cache: Stored as given; not read in the visible part of the file.
    **kwargs: Forwarded to ``PretrainedConfig.__init__``.
  """
  model_type = "llama"
  def __init__(
    self,
    vocab_size: int = 32000,
    dim: int = 512,
    dropout: float = 0.0,  # a probability, so float (was annotated as int)
    n_layers: int = 8,
    n_heads: int = 8,
    n_kv_heads: Optional[int] = 8,
    max_seq_len: int = 1024,
    layer_norm_eps: float = 1e-5,
    multiple_of: int = 32,
    hidden_dim: Optional[int] = None,
    position_embedding_type: str = "rotary",
    use_cache: bool = True,
    **kwargs
  ):
    super().__init__(**kwargs)

    self.vocab_size = vocab_size
    self.dim = dim
    self.dropout = dropout
    self.n_layers = n_layers
    self.n_heads = n_heads
    self.max_seq_len = max_seq_len
    self.n_kv_heads = n_kv_heads
    self.layer_norm_eps = layer_norm_eps
    self.multiple_of = multiple_of
    self.hidden_dim = hidden_dim
    self.position_embedding_type = position_embedding_type
    self.use_cache = use_cache

# base_llama.py

In [3]:
from dataclasses import dataclass

import re
from torch import dtype
# from config import LlamaConfig
# from utils import *

class LlamaPreTrainedModel(nn.Module):
  """Common base for Llama modules: keeps the config and provides weight initialisation."""
  config_class = LlamaConfig
  base_model_prefix = "llama"

  def __init__(self, config: LlamaConfig):
    super().__init__()
    self.config = config
    self.vocab_size = config.vocab_size
    self.n_layers = config.n_layers

  def init_weights(self):
    """Run ``_init_weights`` on this module and every submodule."""
    self.apply(self._init_weights)

  def _init_weights(self, module):
    """Initialise Linear/Embedding weights from N(0, 0.02^2) and zero any Linear bias."""
    if not isinstance(module, (nn.Linear, nn.Embedding)):
      return
    torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    if isinstance(module, nn.Linear) and module.bias is not None:
      torch.nn.init.zeros_(module.bias)

  @property
  def dtype(self) -> dtype:
    """dtype of the module, as reported by ``get_parameter_dtype``."""
    return get_parameter_dtype(self)

# rope.py

In [4]:
from typing import Tuple
import torch

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    Helper function to reshape frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    Args:
        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.

    Returns:
        torch.Tensor: Reshaped frequency tensor.

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(shape)

def apply_rotary_emb(
    query: torch.Tensor,
    key: torch.Tensor,
    head_dim: int,
    max_seq_len: int,
    theta: float = 10000.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary position embeddings (RoPE) to the query and key tensors.

    Consecutive feature pairs of every head vector are treated as the real and
    imaginary parts of a complex number and rotated by an angle that depends on the
    token position and on the pair index (Section 3 of https://arxiv.org/abs/2104.09864).
    The rotation is computed in float32 and the results are cast back to the dtype of
    the corresponding input, so half-precision inputs yield half-precision outputs.

    Args:
        query (torch.Tensor): Query tensor to apply rotary embeddings.
                              Shape: (batch_size, seqlen, n_local_heads, head_dim)
        key (torch.Tensor): Key tensor to apply rotary embeddings.
                              Shape: (batch_size, seqlen, n_local_kv_heads, head_dim)
        head_dim (int): Dimension of each attention head (must be even).
        max_seq_len (int): Maximum sequence length supported by model. Currently unused;
                           angles are computed for the actual ``seqlen`` only.
        theta (float): Base of the geometric progression of rotation frequencies.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors, with the same
        shapes and dtypes as the inputs.
    """

    _, seqlen, _, _ = query.shape
    device = query.device

    # Split each head vector into (real, imag) pairs: with one-indexing, *_real holds
    # x_1, x_3, x_5, ... and *_imag holds x_2, x_4, x_6, ...
    query_real, query_imag = query.float().reshape(query.shape[:-1] + (-1, 2)).unbind(-1)
    key_real, key_imag = key.float().reshape(key.shape[:-1] + (-1, 2)).unbind(-1)

    # One rotation frequency per feature pair: theta^(-2i / head_dim).
    # (Kept in a separate name so the `theta` argument is not overwritten.)
    inv_freq = 1.0 / torch.pow(theta, torch.arange(0, head_dim, 2).float() / head_dim)
    # angles[m, i] = m * inv_freq[i] for every position m in the sequence.
    angles = torch.outer(torch.arange(seqlen).float(), inv_freq).to(device)
    cos = reshape_for_broadcast(torch.cos(angles), query_real)
    sin = reshape_for_broadcast(torch.sin(angles), query_real)

    # Complex multiplication (a + ib)(cos + i sin), written out on real tensors. The
    # same cos/sin broadcast over the head axis, so they serve query and key alike.
    query_rot_real = query_real * cos - query_imag * sin
    query_rot_imag = query_real * sin + query_imag * cos
    key_rot_real = key_real * cos - key_imag * sin
    key_rot_imag = key_real * sin + key_imag * cos

    # Interleave the pairs back into the original feature layout.
    query_out = torch.stack((query_rot_real, query_rot_imag), dim=-1).reshape(query.shape)
    key_out = torch.stack((key_rot_real, key_rot_imag), dim=-1).reshape(key.shape)

    # Restore the input dtypes; without this, fp16/bf16 callers receive float32 tensors
    # that no longer match their value tensor.
    return query_out.type_as(query), key_out.type_as(key)

# llama.py

In [5]:
from contextlib import nullcontext
from typing import Optional, Tuple
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

#from base_llama import LlamaPreTrainedModel, LlamaConfig
#from rope import apply_rotary_emb
#from utils import *

# Root Mean Square Layer Normalization (https://arxiv.org/abs/1910.07467)
# borrowed from the official Llama implementation:
# https://github.com/facebookresearch/llama/blob/main/llama/model.py
class RMSNorm(torch.nn.Module):
    """Root Mean Square layer normalisation with a learnable per-feature scale."""

    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Set up the normalisation layer.

        Args:
            dim (int): Size of the last dimension of the inputs.
            eps (float, optional): Value added inside the square root for numerical
                stability. Default is 1e-6.

        Attributes:
            eps (float): The stability constant.
            weight (nn.Parameter): Learnable scale, initialised to ones.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        Divide ``x`` by its root mean square over the last dimension
        (Equation 4 of https://arxiv.org/abs/1910.07467), with ``self.eps`` added
        to the mean square before the square root is taken.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.
        """
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x / torch.sqrt(mean_square + self.eps)

    def forward(self, x):
        """
        Normalise ``x`` in float32, cast back to the input dtype and apply the scale.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        normalized = self._norm(x.float())
        return normalized.type_as(x) * self.weight

class Attention(nn.Module):
    """Multi-head self-attention with grouped key/value heads and rotary embeddings."""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        kv_heads = config.n_kv_heads
        self.n_kv_heads = config.n_heads if kv_heads is None else kv_heads
        # Every key/value head must serve a whole number of query heads.
        assert config.n_heads % self.n_kv_heads == 0
        model_parallel_size = 1
        self.n_local_heads = config.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = config.dim // config.n_heads
        self.max_seq_len = config.max_seq_len

        query_width = config.n_heads * self.head_dim
        kv_width = self.n_kv_heads * self.head_dim
        self.compute_query = nn.Linear(config.dim, query_width, bias=False)
        self.compute_key = nn.Linear(config.dim, kv_width, bias=False)
        self.compute_value = nn.Linear(config.dim, kv_width, bias=False)
        self.compute_output = nn.Linear(query_width, config.dim, bias=False)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.dropout = config.dropout

    def compute_query_key_value_scores(self,
                                       query: torch.Tensor,
                                       key: torch.Tensor,
                                       value: torch.Tensor) -> torch.Tensor:
        '''
        Scaled dot-product attention (Section 3.2.1 of https://arxiv.org/abs/1706.03762)
        computed for all heads at once. The query, key, and value tensors each have
        shape (bs, n_local_heads, seqlen, head_dim); the result has the same shape.

        Attention dropout (self.attn_dropout) is applied to the attention matrix before
        it is multiplied with the values. No mask is applied here, so every position
        attends to every position of the sequence.
        '''
        scale = math.sqrt(query.shape[-1])
        logits = torch.matmul(query, key.transpose(-2, -1)) / scale
        weights = self.attn_dropout(F.softmax(logits, dim=-1))
        return torch.matmul(weights, value)


    def forward(
        self,
        x: torch.Tensor
    ):
        '''
        Self-attention over ``x`` of shape (batch_size, seqlen, dim).

        Llama2 uses Grouped-Query Attention: there may be fewer key/value heads than
        query heads, and each key/value head is repeated so that it is shared by a
        group of query heads. See Section 2.2 in https://arxiv.org/abs/2305.13245 or
        https://ai.plainenglish.io/understanding-llama2-kv-cache-grouped-query-attention-rotary-embedding-and-more-c17e5f49a6d7
        for details.
        '''
        batch_size, seqlen, _ = x.shape

        def split_heads(projected: torch.Tensor, n_heads: int) -> torch.Tensor:
            # (bs, seqlen, n_heads * head_dim) -> (bs, seqlen, n_heads, head_dim)
            return projected.view(batch_size, seqlen, n_heads, self.head_dim)

        query = split_heads(self.compute_query(x), self.n_local_heads)
        key = split_heads(self.compute_key(x), self.n_local_kv_heads)
        value = split_heads(self.compute_value(x), self.n_local_kv_heads)

        # RoPE relative positional embeddings
        query, key = apply_rotary_emb(query, key, self.head_dim, self.max_seq_len)

        # Grouped multiquery attention: repeat keys and values along the head axis so
        # both become (bs, seqlen, n_local_heads, head_dim).
        key = key.repeat_interleave(self.n_rep, dim=2)
        value = value.repeat_interleave(self.n_rep, dim=2)

        # Heads become a batch-like axis: (bs, n_local_heads, seqlen, head_dim).
        attended = self.compute_query_key_value_scores(
            query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
        )

        # Back to (bs, seqlen, n_local_heads * head_dim): heads are concatenated.
        merged = attended.transpose(1, 2).contiguous().view(batch_size, seqlen, -1)

        # final projection into the residual stream
        return self.resid_dropout(self.compute_output(merged))


class FeedForward(nn.Module):
    '''
    Gated feed-forward sub-layer: `w2(silu(w1(x)) * w3(x))` followed by dropout.
    '''
    def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
        '''
        Args:
            dim: input and output feature size.
            hidden_dim: width of the inner projections. When None it is
                derived from `dim`: two thirds of `4 * dim`, rounded up to
                the next multiple of `multiple_of`.
            multiple_of: rounding granularity used only when `hidden_dim`
                is None.
            dropout: dropout probability applied to the layer output.
        '''
        super().__init__()
        if hidden_dim is None:
            two_thirds = int(2 * (4 * dim) / 3)
            n_chunks = (two_thirds + multiple_of - 1) // multiple_of  # ceil division
            hidden_dim = multiple_of * n_chunks
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def SwiGLU(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Compute the SwiGLU activation (see Section 2 in
        https://arxiv.org/abs/2204.02311): the SiLU of the `w1` projection
        gates the `w3` projection element-wise.
        '''
        gate = F.silu(self.w1(x))
        return gate * self.w3(x)

    def forward(self, x):
        '''Apply SwiGLU, project back to `dim` with `w2`, then apply dropout.'''
        projected = self.w2(self.SwiGLU(x))
        return self.dropout(projected)


class LlamaLayer(nn.Module):
    '''
    One pre-norm transformer block: RMS-normalised self-attention and
    RMS-normalised feed-forward, each wrapped in a residual connection.
    '''
    def __init__(self, layer_id: int, config: LlamaConfig):
        '''
        Args:
            layer_id: index of this block, stored on the instance.
            config: model configuration; `dim`, `n_heads`, `hidden_dim`,
                `multiple_of`, `dropout` and `layer_norm_eps` are read here.
        '''
        super().__init__()
        self.layer_id = layer_id
        self.dim = config.dim
        self.n_heads = config.n_heads
        self.head_dim = config.dim // config.n_heads
        # Sub-modules are created in a fixed order: attention, feed-forward,
        # then the two norms.
        self.attention = Attention(config)
        self.feed_forward = FeedForward(
            dim=config.dim,
            hidden_dim=config.hidden_dim,
            multiple_of=config.multiple_of,
            dropout=config.dropout,
        )
        self.attention_norm = RMSNorm(config.dim, eps=config.layer_norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.layer_norm_eps)

    def forward(self, x):
        '''
        Forward pass of the basic transformer building block, a modernized
        version of the block on the left of Figure 1 in
        https://arxiv.org/pdf/1706.03762.pdf.

        Steps:
        1) RMS-normalise the input.
        2) Run self-attention on the normalised input.
        3) Add the input back (first residual connection).
        4) RMS-normalise the result of step 3.
        5) Run the feed-forward network on the normalised result.
        6) Add the un-normalised result of step 3 back (second residual
           connection).
        '''
        attn_out = self.attention(self.attention_norm(x))
        h = x + attn_out
        ffn_out = self.feed_forward(self.ffn_norm(h))
        return h + ffn_out

class Llama(LlamaPreTrainedModel):
    '''
    Llama language model: token embedding, a stack of `LlamaLayer` blocks,
    a final RMSNorm and an output projection whose weight is shared with
    the embedding.
    '''
    def __init__(self, config: LlamaConfig):
        '''
        Build a freshly initialised model from `config`.

        You will probably never need to call this directly unless you
        pretrain a Llama model from scratch; `load_pretrained` calls it and
        then overwrites the weights from a checkpoint.
        '''
        super().__init__(config)
        self.params = config
        self.vocab_size = config.vocab_size
        self.n_layers = config.n_layers

        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.dropout = nn.Dropout(config.dropout)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(config.n_layers):
            self.layers.append(LlamaLayer(layer_id, config))
        self.norm = RMSNorm(config.dim, eps=config.layer_norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        # share the unembedding parameters with the embedding parameters
        self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying

        # init all weights (normal with std 0.02, zero biases; see _init_weights)
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        # NOTE(review): FeedForward projects back to `dim` with `w2`, not `w3`;
        # confirm that `w3.weight` is the intended target of this scaled init.
        for pn, p in self.named_parameters():
            if pn.endswith('w3.weight') or pn.endswith('compute_output.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layers))

    def _init_weights(self, module):
        '''Initialise Linear and Embedding weights from N(0, 0.02) and zero any Linear bias.'''
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> tuple[torch.Tensor, torch.Tensor]:
        '''
        Embed `tokens`, run every layer and the final norm, and project to
        vocabulary logits.

        Args:
            tokens: 2-D tensor of token ids, unpacked as (batch, seqlen).
            targets: only tested against None. When given, logits are
                computed for every position; when None, only for the last
                position (the time axis is kept with length 1). No loss is
                computed here.

        Returns:
            A pair `(logits, h)` where `h` is the normalised hidden state
            for every position.
        '''
        _batch_size, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        h = self.dropout(h)

        for layer in self.layers:
            h = layer(h)
        h = self.norm(h)

        if targets is not None:
            # targets supplied: produce logits for every position
            logits = self.output(h)
        else:
            # inference-time mini-optimization: only forward the output on the very last position
            logits = self.output(h[:, [-1], :]) # note: using list [-1] to preserve the time dim

        return logits, h

    @torch.inference_mode()
    def generate(self, idx, max_new_tokens, temperature=1.0):
        """
        Extend a conditioning sequence by `max_new_tokens` tokens.

        `idx` is a tensor of token ids of shape (b, t). At every step the
        model is run on the (possibly cropped) sequence, one token per row
        is chosen from the logits of the last position, and that token is
        appended to `idx`. With `temperature == 0.0` the most likely token
        is taken (greedy decoding); otherwise the logits are divided by
        `temperature`, soft-maxed and sampled from. Neither top-k nor
        nucleus (top-p) filtering is applied.

        Most likely you'll want to be in model.eval() mode for this. This is
        an inefficient implementation with no key/value cache: the whole
        context is re-processed for every new token.

        Returns:
            `idx` with the generated tokens appended along dim 1.
        """
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it to max_seq_len
            idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] # crop to just the final time step

            if temperature == 0.0:
                # select the single most likely index
                idx_next = torch.argmax(logits, dim = -1, keepdim = True)
            else:
                '''
                Perform temperature sampling:
                1) identify  the logits at the final step.
                2) scale (divide) these probabilities by the given temperature.
                3) normalize the scaled logits with a softmax to obtain scaled probabilities.
                4) sample from the scaled probability distribution.

                Note that we are not using top-k sampling/nucleus sampling in this procedure.
                '''
                scaled_logits = logits / temperature
                probs = F.softmax(scaled_logits, dim = -1)
                idx_next = torch.multinomial(probs, num_samples = 1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)


        return idx

def load_pretrained(checkpoint):
  '''
  Build a Llama model from a checkpoint file and return it.

  The checkpoint is expected to be a dict with a 'model_args' entry (keyword
  arguments for LlamaConfig) and a 'model' entry (the state dict). Keys that
  start with '_orig_mod.' have that prefix stripped before loading.

  Side effects: enables TF32 for CUDA matmul and cuDNN globally.

  Args:
    checkpoint: path (or file-like object) accepted by `torch.load`.

  Returns:
    The Llama model with the checkpoint weights loaded (`strict=False`, so
    missing or unexpected keys are not reported as errors).
  '''
  # Tensors in the checkpoint are mapped to the GPU when one is available.
  device = 'cuda' if torch.cuda.is_available() else 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.

  torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
  torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn

  # NOTE: torch.load unpickles the file; only load checkpoints you trust.
  checkpoint_dict = torch.load(checkpoint, map_location=device)
  config = LlamaConfig(**checkpoint_dict['model_args'])
  model = Llama(config)
  state_dict = checkpoint_dict['model']
  unwanted_prefix = '_orig_mod.'
  # Iterate over a snapshot of the keys because the dict is modified in the loop.
  for k in list(state_dict.keys()):
      if k.startswith(unwanted_prefix):
          state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
  model.load_state_dict(state_dict, strict=False)
  return model


# tokenizer.py

In [6]:
# Taken from llama code and modified by Andrej Karpathy.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
import struct
import argparse
from typing import List

from sentencepiece import SentencePieceProcessor

TOKENIZER_MODEL = "tokenizer.model" # the llama sentencepiece tokenizer model

class Tokenizer:
    '''
    Thin wrapper around a SentencePiece model that adds optional truncation,
    BOS/EOS handling and a binary export of the vocabulary.
    '''
    def __init__(self, max_len=None, tokenizer_model=None):
        '''
        Args:
            max_len: if not None, encoded sequences are truncated to this
                many tokens (before BOS/EOS are added).
            tokenizer_model: path to the SentencePiece model file; falls
                back to TOKENIZER_MODEL when empty or None.
        '''
        model_path = tokenizer_model if tokenizer_model else TOKENIZER_MODEL
        assert os.path.isfile(model_path), model_path
        self.sp_model = SentencePieceProcessor(model_file=model_path)
        self.model_path = model_path
        self.max_len = max_len

        # BOS / EOS token IDs
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        # Overwrite the default of pad_id=-1, which is problematic.
        self.pad_id: int = self.sp_model.piece_to_id("<0x00>")
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
        '''
        Encode `s` to token ids.

        The ids are truncated to `max_len` first (when set); the BOS and EOS
        ids are added afterwards, so the result can be up to two tokens
        longer than `max_len`.
        '''
        assert type(s) is str
        t = self.sp_model.encode(s)
        if self.max_len is not None and len(t) > self.max_len:
            t = t[:self.max_len]
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        '''Decode a list of token ids back to a string.'''
        return self.sp_model.decode(t)

    def export(self):
        '''
        Write the vocabulary to a binary file next to the model file.

        File layout (native byte order): one unsigned int holding the
        longest token length in bytes, then for every token a float score,
        an unsigned int byte length and the UTF-8 bytes of the token.

        The output path is the model path with a trailing '.model'
        extension replaced by '.bin'; for any other path '.bin' is appended,
        so the model file itself is never overwritten.
        '''
        # get all the tokens (postprocessed) and their scores as floats
        tokens, scores = [], []
        for i in range(self.n_words):

            # decode the token and light postprocessing
            t = self.sp_model.id_to_piece(i)
            s = self.sp_model.get_score(i)
            if i == self.bos_id:
                t = '\n<s>\n'
            elif i == self.eos_id:
                t = '\n</s>\n'
            t = t.replace('▁', ' ') # sentencepiece uses this character as whitespace
            tokens.append(t.encode('utf-8')) # bytes of this token, utf-8 encoded
            scores.append(s)

        # record the max token length
        max_token_length = max(len(t) for t in tokens)

        # Only swap the file extension; a plain str.replace would also touch
        # directory names that happen to contain '.model'.
        root, ext = os.path.splitext(self.model_path)
        tokenizer_bin = (root if ext == '.model' else self.model_path) + '.bin'

        with open(tokenizer_bin, 'wb') as f:
            f.write(struct.pack("I", max_token_length))
            for token_bytes, score in zip(tokens, scores):
                f.write(struct.pack("fI", score, len(token_bytes)))
                f.write(token_bytes)

#if __name__ == "__main__":
#    parser = argparse.ArgumentParser()
#    parser.add_argument("-t", "--tokenizer-model", type=str, help="optional path to custom tokenizer ")
#    args = parser.parse_args()

#    t = Tokenizer(args.tokenizer_model)
#    t.export()


# classifier.py

In [7]:
import torch
import torch.nn.functional as F

# change it with respect to the original model
# from config import LlamaConfig
# from llama import load_pretrained
# from tokenizer import Tokenizer

class LlamaZeroShotClassifier(torch.nn.Module):
	'''
	Zero-shot classifier that scores each label by the log-probability the
	frozen Llama model assigns to the label's tokens as the next token.
	'''
	def __init__(self, config: LlamaConfig, tokenizer: Tokenizer, label_names: list[str]):
		'''
		Args:
			config: provides `num_labels` and `pretrained_model_path`.
			tokenizer: used to turn every label name into token ids.
			label_names: one string per class; its length must equal
				`config.num_labels`.
		'''
		super().__init__()
		self.num_labels = config.num_labels
		self.llama = load_pretrained(config.pretrained_model_path)
		# Zero-shot classification does not require updating llama parameters.
		self.llama.requires_grad_(False)
		assert len(label_names) == self.num_labels
		self.tokenizer = tokenizer
		self.label_name_ids = [
			tokenizer.encode(name, bos=False, eos=False) for name in label_names
		]


	def forward(self, input_ids):
		'''
		Return a (batch, num_labels) tensor of label scores.

		The score of a label is the sum, over the label's token ids, of the
		next-token log-probabilities at the first position of the logits
		returned by the model (the model returns only the last input
		position when called without targets).
		'''
		logits, _ = self.llama(input_ids)
		log_probs = F.log_softmax(logits, dim=-1)
		n_rows = log_probs.shape[0]
		scores = torch.zeros((n_rows, self.num_labels), device=log_probs.device)
		for column, token_ids in enumerate(self.label_name_ids):
			summed = log_probs[:, :, token_ids].sum(dim=-1)
			scores[:, column] = summed[:, 0]
		return scores

class LlamaEmbeddingClassifier(torch.nn.Module):
	'''
	Sentence classifier: a linear head on top of the Llama hidden state at
	the last sequence position.
	'''
	def __init__(self, config):
		'''
		Args:
			config: provides `num_labels`, `pretrained_model_path`,
				`option` and `hidden_dropout_prob`. With option 'pretrain'
				the Llama parameters are frozen, with 'finetune' they are
				trainable; any other value leaves them untouched.
		'''
		super().__init__()
		self.num_labels = config.num_labels
		self.llama = load_pretrained(config.pretrained_model_path)

		if config.option == 'pretrain':
			trainable = False
		elif config.option == 'finetune':
			trainable = True
		else:
			trainable = None
		if trainable is not None:
			for weight in self.llama.parameters():
				weight.requires_grad = trainable

		self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
		self.classifier_head = torch.nn.Linear(self.llama.config.dim, self.num_labels)

	def forward(self, input_ids):
		'''
		1) Take the hidden state at the final position of the input sequence.
		2) Apply dropout (self.dropout); it is only active in training mode.
		3) Pass the result through self.classifier_head to get logits over
		   all classes.
		4) Return the log-softmax of the logits (log-probabilities).
		'''
		# NOTE(review): LlamaDataset.pad_data pads on the right, so for the
		# shorter sentences of a batch the final position holds a pad token -
		# confirm this is intended.
		_, hidden_states = self.llama(input_ids)
		final_state = hidden_states[:, -1, :]
		class_logits = self.classifier_head(self.dropout(final_state))
		return F.log_softmax(class_logits, dim=-1)


# optimizer.py

In [8]:
from typing import Callable, Iterable, Tuple

import torch
from torch.optim import Optimizer


class AdamW(Optimizer):
    '''
    Adam with decoupled weight decay.

    Each step updates the running first and second moments of the gradient,
    applies the (optionally bias-corrected) Adam update, and then shrinks the
    parameter by `lr * weight_decay * p`.
    '''
    def __init__(
            self,
            params: Iterable[torch.nn.parameter.Parameter],
            lr: float = 1e-3,
            betas: Tuple[float, float] = (0.9, 0.999),
            eps: float = 1e-6,
            weight_decay: float = 0.0,
            correct_bias: bool = True,
    ):
        '''
        Args:
            params: parameters (or parameter groups) to optimise.
            lr: learning rate, must be >= 0.
            betas: decay rates of the first and second moment, each in [0, 1).
            eps: term added to the denominator, must be >= 0.
            weight_decay: decoupled weight-decay coefficient.
            correct_bias: whether to apply Adam's bias correction.

        Raises:
            ValueError: if `lr`, `betas` or `eps` is out of range.
        '''
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
        super().__init__(params, defaults)

    def step(self, closure: Callable = None):
        '''
        Perform one optimisation step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; it is called once before the update.

        Returns:
            The value returned by `closure`, or None.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        '''
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            beta1 = group["betas"][0]
            beta2 = group["betas"][1]

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

                # Per-parameter state: step count "t" and moment estimates
                # "m" and "v". Each entry is created once, on first use,
                # instead of allocating throwaway zero tensors on every step.
                state = self.state[p]
                if "t" not in state:
                    state["t"] = 0
                if "m" not in state:
                    state["m"] = torch.zeros_like(p.data)
                if "v" not in state:
                    state["v"] = torch.zeros_like(p.data)

                # Update first and second moments of the gradients
                t = state["t"] + 1
                m = beta1 * state["m"] + (1 - beta1) * grad
                v = beta2 * state["v"] + (1 - beta2) * grad**2
                state["t"] = t
                state["m"] = m
                state["v"] = v

                # Bias correction, using the "efficient version" given in
                # https://arxiv.org/abs/1412.6980: the correction is folded
                # into the step size instead of into m and v.
                if group["correct_bias"]:
                    alpha = lr * (1 - beta2 ** t) ** 0.5 / (1 - beta1 ** t)
                else:
                    alpha = lr

                # Update parameters
                p.data -= alpha * m / (v**0.5 + group["eps"])

                # Weight decay is applied after the gradient-based update,
                # on the already updated parameter, scaled by the plain
                # learning rate (not the bias-corrected step size).
                p.data -= lr * group["weight_decay"] * p.data

        return loss

# run_llama.py

In [14]:
from contextlib import nullcontext
import json
import time, random, numpy as np, argparse, sys, re, os
from types import SimpleNamespace

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import classification_report, f1_score, recall_score, accuracy_score

# change it with respect to the original model
# from classifier import LlamaZeroShotClassifier, LlamaEmbeddingClassifier
# from llama import Llama, load_pretrained
# from optimizer import AdamW
# from tokenizer import Tokenizer
from tqdm import tqdm
from typing import Optional


TQDM_DISABLE=False
# fix the random seed
def seed_everything(seed=11711):
	'''
	Seed Python's `random`, NumPy and PyTorch (CPU and all CUDA devices) with
	`seed`, and put cuDNN into deterministic, non-benchmarking mode.
	'''
	seeders = (
		random.seed,
		np.random.seed,
		torch.manual_seed,
		torch.cuda.manual_seed,
		torch.cuda.manual_seed_all,
	)
	for apply_seed in seeders:
		apply_seed(seed)
	cudnn = torch.backends.cudnn
	cudnn.benchmark = False
	cudnn.deterministic = True

# create a custom Dataset Class to be used for the dataloader
class LlamaDataset(Dataset):
	'''
	Dataset over a list of examples whose first two fields are
	(sentence, label). Tokenisation and padding happen per batch in
	`collate_fn`.
	'''
	def __init__(self, dataset, args, eos=False):
		'''
		Args:
			dataset: indexable collection of examples.
			args: run arguments; `args.max_sentence_len` is passed to the
				Tokenizer as its truncation length.
			eos: whether an EOS token is appended when encoding.
		'''
		self.dataset = dataset
		self.p = args
		self.tokenizer = Tokenizer(max_len=args.max_sentence_len)
		self.eos = eos

	def __len__(self):
		return len(self.dataset)

	def __getitem__(self, idx):
		return self.dataset[idx]

	def pad_data(self, data):
		'''
		Encode the sentences of a batch (with BOS) and right-pad them with
		`tokenizer.pad_id` to the length of the longest one.

		Returns:
			(token_ids, labels, sents): two LongTensors and the list of raw
			sentences.
		'''
		sents, labels = [], []
		for example in data:
			sents.append(example[0])
			labels.append(example[1])

		encoding = [self.tokenizer.encode(text, bos=True, eos=self.eos) for text in sents]
		longest = max(len(ids) for ids in encoding)
		pad = self.tokenizer.pad_id
		padded = [ids + [pad] * (longest - len(ids)) for ids in encoding]

		return torch.LongTensor(padded), torch.LongTensor(labels), sents

	def collate_fn(self, all_data):
		'''Collate a list of examples into a dict with 'token_ids', 'labels' and 'sents'.'''
		token_ids, labels, sents = self.pad_data(all_data)
		return dict(token_ids=token_ids, labels=labels, sents=sents)


# create the data which is a list of (sentence, label, token for the labels)
def create_data(filename, tokenizer: Tokenizer, flag: str ='train', lower: bool = False, eos: bool = True, prompt_suffix: Optional[str]=None):
	'''
	Read a `label ||| sentence` file into a list of examples.

	Args:
		filename: UTF-8 text file with one `<int label> ||| <sentence>` per line.
		tokenizer: used to encode every sentence (with BOS, and EOS if `eos`).
		flag: with 'train' the number of distinct labels is returned as well.
		lower: lower-case the sentence before use.
		eos: whether to append the EOS token when encoding.
		prompt_suffix: optional text appended to every sentence after a space.

	Returns:
		A list of (sentence, label, token_ids) tuples; for flag == 'train' a
		pair (that list, number of distinct labels).

	Raises:
		ValueError: if a non-blank line has no ' ||| ' separator or the label
			is not an integer.
	'''
	num_labels = {}
	data = []

	with open(filename, "r", encoding="utf-8") as fp:
		for line in fp:
			# Tolerate blank lines, e.g. a trailing newline at the end of the file.
			if not line.strip():
				continue
			# Split on the first separator only, so a sentence that itself
			# contains ' ||| ' is kept intact.
			label, org_sent = line.split(' ||| ', 1)
			if lower:
				org_sent = org_sent.lower()
			sent = org_sent.strip()
			if prompt_suffix is not None:
				sent = f"{sent} {prompt_suffix}"
			tokens = tokenizer.encode(sent, bos=True, eos=eos)
			label = int(label.strip())
			if label not in num_labels:
				num_labels[label] = len(num_labels)
			data.append((sent, label, tokens))
	print(f"load {len(data)} data from {filename}")
	if flag == 'train':
		return data, len(num_labels)
	else:
		return data

# perform model evaluation in terms of the accuracy and f1 score.
def model_eval(dataloader, model, device):
	'''
	Evaluate `model` on every batch of `dataloader`.

	The model is switched to eval mode (and left in it) and run under
	`torch.no_grad()`, so no autograd graph is built even when the caller is
	in the middle of training.

	Args:
		dataloader: yields dicts with 'token_ids', 'labels' and 'sents'.
		model: callable returning per-class scores of shape (batch, classes).
		device: device the token ids are moved to before the forward pass.

	Returns:
		(accuracy, macro F1, predicted labels, true labels, sentences).
	'''
	model.eval() # switch to eval mode, turns off randomness like dropout
	y_true = []
	y_pred = []
	sents = []
	with torch.no_grad():
		for batch in tqdm(dataloader, desc='eval', disable=TQDM_DISABLE):
			b_ids, b_labels, b_sents = batch['token_ids'], batch['labels'], batch['sents']

			b_ids = b_ids.to(device)

			logits = model(b_ids)
			logits = logits.detach().cpu().numpy()
			preds = np.argmax(logits, axis=1).flatten()

			# Collect labels as NumPy scalars, like the predictions, rather
			# than as 0-dim tensors.
			y_true.extend(b_labels.flatten().cpu().numpy())
			y_pred.extend(preds)
			sents.extend(b_sents)

	f1 = f1_score(y_true, y_pred, average='macro')
	acc = accuracy_score(y_true, y_pred)

	return acc, f1, y_pred, y_true, sents

def save_model(model, optimizer, args, config, filepath):
	'''
	Save a training checkpoint to `filepath` with `torch.save`.

	The checkpoint holds the model and optimizer state dicts, the model
	config, and the current RNG states of `random`, NumPy and PyTorch.
	`args` is accepted but not stored.
	'''
	checkpoint = dict(
		model=model.state_dict(),
		optim=optimizer.state_dict(),
		model_config=config,
	)
	checkpoint.update(
		system_rng=random.getstate(),
		numpy_rng=np.random.get_state(),
		torch_rng=torch.random.get_rng_state(),
	)

	torch.save(checkpoint, filepath)
	print(f"save the model to {filepath}")

def train(args):
	'''
	Train a LlamaEmbeddingClassifier on `args.train` and keep the checkpoint
	with the best dev accuracy.

	After every epoch the model is evaluated on the train and dev sets; it is
	saved to `args.filepath` whenever the dev accuracy improves on the best
	value so far (which starts at 0).
	'''
	device = torch.device('cuda') if args.use_gpu else torch.device('cpu')

	#### Load data
	tokenizer = Tokenizer(args.max_sentence_len)
	train_data, num_labels = create_data(args.train, tokenizer, 'train')
	dev_data = create_data(args.dev, tokenizer, 'valid')

	train_dataset = LlamaDataset(train_data, args)
	dev_dataset = LlamaDataset(dev_data, args)

	train_dataloader = DataLoader(
		train_dataset,
		shuffle=True,
		batch_size=args.batch_size,
		collate_fn=train_dataset.collate_fn,
	)
	dev_dataloader = DataLoader(
		dev_dataset,
		shuffle=False,
		batch_size=args.batch_size,
		collate_fn=dev_dataset.collate_fn,
	)

	#### Init model
	config = SimpleNamespace(
		hidden_dropout_prob=args.hidden_dropout_prob,
		pretrained_model_path=args.pretrained_model_path,
		num_labels=num_labels,
		data_dir='.',
		option=args.option,
	)

	# initialize the sentence classification model
	model = LlamaEmbeddingClassifier(config).to(device)

	optimizer = AdamW(model.parameters(), lr=args.lr)
	best_dev_acc = 0

	## run for the specified number of epochs
	for epoch in tqdm(range(args.epochs)):
		model.train()
		batch_losses = []
		for batch in tqdm(train_dataloader, desc=f'train-{epoch}', disable=TQDM_DISABLE):
			b_ids = batch['token_ids'].to(device)
			b_labels = batch['labels'].to(device)

			optimizer.zero_grad()
			logits = model(b_ids)
			# Summed NLL divided by the configured batch size (not by the
			# actual size of the batch).
			loss = F.nll_loss(logits, b_labels.view(-1), reduction='sum') / args.batch_size

			loss.backward()
			optimizer.step()

			batch_losses.append(loss.item())

		train_loss = sum(batch_losses) / len(batch_losses)

		train_acc, train_f1, *_ = model_eval(train_dataloader, model, device)
		dev_acc, dev_f1, *_ = model_eval(dev_dataloader, model, device)

		if dev_acc > best_dev_acc:
			best_dev_acc = dev_acc
			save_model(model, optimizer, args, config, args.filepath)

		print(f"epoch {epoch}: train loss :: {train_loss :.3f}, train acc :: {train_acc :.3f}, dev acc :: {dev_acc :.3f}")

def generate_sentence(args, prefix, outfile, max_new_tokens = 75, temperature = 0.0):
	'''
	Continue `prefix` with the pretrained Llama model, print the result and
	write it to `outfile`.

	Args:
		args: run arguments; `use_gpu`, `pretrained_model_path` and
			`max_sentence_len` are read.
		prefix: text the generation is conditioned on.
		outfile: path of the text file the generated sentence is written to
			(UTF-8, overwritten).
		max_new_tokens: number of tokens to generate.
		temperature: 0.0 selects greedy decoding; other values sample.
	'''
	with torch.no_grad():
		device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
		ctx = torch.amp.autocast(device_type="cuda", dtype=torch.float32) if args.use_gpu else nullcontext()
		llama = load_pretrained(args.pretrained_model_path)
		llama = llama.to(device)
		print(f"load model from {args.pretrained_model_path}")
		enc = Tokenizer(args.max_sentence_len)

		start_ids = enc.encode(prefix, bos=True, eos=False)
		# Add a leading batch axis: shape (1, len(start_ids)).
		x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])

		# run generation
		with ctx:
			y = llama.generate(x, max_new_tokens, temperature=temperature)
			sentence = enc.decode(y[0].tolist())
			print(f"Temperature is {temperature}")
			print(sentence)
			print('---------------')
			# The context manager closes the file even if the write fails;
			# UTF-8 matches write_predictions_to_file and avoids
			# locale-dependent encoding errors on generated text.
			with open(outfile, 'w', encoding='utf-8') as writer:
				writer.write(sentence)
			print(f"Wrote generated sentence to {outfile}.")

def write_predictions_to_file(split: str, outfile: str, acc: float, pred: list[str], sents: list[str]):
	'''
	Print the accuracy for `split` and write one `<prediction> ||| <sentence>`
	line per example to `outfile` (UTF-8, overwritten).
	'''
	with open(outfile, "w+", encoding="utf-8") as f:
		print(f"{split} acc :: {acc :.3f}")
		f.writelines(f"{p} ||| {s}\n" for s, p in zip(sents, pred))

def test_with_prompting(args):
	'''
	Evaluate zero-shot prompting on the dev and test sets and write the
	predictions to `args.dev_out` and `args.test_out`.

	Every sentence is extended with a prompt of the form
	"Is this movie <labels>? This movie is " and the label whose tokens get
	the highest next-token log-probability is predicted.

	Raises:
		AssertionError: if the output file names do not end with
			"dev-prompting-output.txt" / "test-prompting-output.txt".
	'''
	assert args.dev_out.endswith("dev-prompting-output.txt"), 'For saving prompting results, please set the dev_out argument as "<dataset>-dev-prompting-output.txt"'
	assert args.test_out.endswith("test-prompting-output.txt"), 'For saving prompting results, please set the test_out argument as "<dataset>-test-prompting-output.txt"'

	with torch.no_grad():

		device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
		#### Load data
		tokenizer = Tokenizer(args.max_sentence_len)
		# Close the label file deterministically instead of leaking the handle.
		with open(args.label_names, 'r') as label_file:
			label_names = json.load(label_file)
		# The training file is read only to count the distinct labels.
		_, num_labels = create_data(args.train, tokenizer, 'train')

		#### Init model
		config = {'pretrained_model_path': args.pretrained_model_path,
				'label_names': label_names,
				'num_labels': num_labels,
				'data_dir': '.',
				'option': args.option}

		config = SimpleNamespace(**config)

		# "a or b" for two labels, "a, b, or c" for more.
		if len(label_names) == 2:
			label_name_str = " or ".join(label_names)
		else:
			label_name_str = ", ".join(label_names[:-1]) + ", or " + label_names[-1]
		prompt_suffix=f"Is this movie {label_name_str}? This movie is "
		model = LlamaZeroShotClassifier(config, tokenizer, label_names)
		model = model.to(device)

		# No EOS token: the model must predict what follows the prompt.
		dev_data = create_data(args.dev, tokenizer, 'valid', eos=False, prompt_suffix=prompt_suffix)
		dev_dataset = LlamaDataset(dev_data, args, eos=False)
		dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size, collate_fn=dev_dataset.collate_fn)

		test_data = create_data(args.test, tokenizer, 'test', eos=False, prompt_suffix=prompt_suffix)
		test_dataset = LlamaDataset(test_data, args, eos=False)
		test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size, collate_fn=test_dataset.collate_fn)

		dev_acc, dev_f1, dev_pred, dev_true, dev_sents = model_eval(dev_dataloader, model, device)
		test_acc, test_f1, test_pred, test_true, test_sents = model_eval(test_dataloader, model, device)

		write_predictions_to_file("dev", args.dev_out, dev_acc, dev_pred, dev_sents)
		write_predictions_to_file("test", args.test_out, test_acc, test_pred, test_sents)

def test(args):
	'''
	Load the fine-tuned classifier saved at `args.filepath`, evaluate it on
	the dev and test sets and write the predictions to `args.dev_out` and
	`args.test_out`.

	Raises:
		AssertionError: if the output file names do not end with
			"dev-finetuning-output.txt" / "test-finetuning-output.txt".
	'''
	assert args.dev_out.endswith("dev-finetuning-output.txt"), 'For saving finetuning results, please set the dev_out argument as "<dataset>-dev-finetuning-output.txt"'
	assert args.test_out.endswith("test-finetuning-output.txt"), 'For saving finetuning results, please set the test_out argument as "<dataset>-test-finetuning-output.txt"'
	with torch.no_grad():
		device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
		# map_location lets a checkpoint that was saved on a GPU be loaded in
		# a CPU-only run (and the other way round).
		saved = torch.load(args.filepath, map_location=device)
		config = saved['model_config']
		model = LlamaEmbeddingClassifier(config)
		model.load_state_dict(saved['model'])
		model = model.to(device)
		print(f"load model from {args.filepath}")
		tokenizer = Tokenizer(args.max_sentence_len)
		dev_data = create_data(args.dev, tokenizer, 'valid')
		dev_dataset = LlamaDataset(dev_data, args)
		dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size, collate_fn=dev_dataset.collate_fn)

		test_data = create_data(args.test, tokenizer, 'test')
		test_dataset = LlamaDataset(test_data, args)
		test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size, collate_fn=test_dataset.collate_fn)

		dev_acc, dev_f1, dev_pred, dev_true, dev_sents = model_eval(dev_dataloader, model, device)
		test_acc, test_f1, test_pred, test_true, test_sents = model_eval(test_dataloader, model, device)

		write_predictions_to_file("dev", args.dev_out, dev_acc, dev_pred, dev_sents)
		write_predictions_to_file("test", args.test_out, test_acc, test_pred, test_sents)

def get_args():
    '''
    Build, print and return the default run arguments (CFIMDB data,
    option "generate"). Replaces argparse, which is awkward in a notebook.
    '''
    class Args:
        '''Plain attribute container for the run parameters.'''
        def __init__(self):
            settings = {
                "train": "data/cfimdb-train.txt",
                "dev": "data/cfimdb-dev.txt",
                "test": "data/cfimdb-test.txt",
                "label_names": "data/cfimdb-label-mapping.json",
                "pretrained_model_path": "stories42M.pt",
                "max_sentence_len": None,
                "seed": 1337,
                "epochs": 5,
                # one of 'generate', 'prompt', 'finetune'
                # (prompt: Llama parameters are frozen; finetune: they are updated)
                "option": "generate",
                "use_gpu": False,  # set to True to run on the GPU
                "generated_sentence_low_temp_out": "generated-sentence-temp-0.txt",
                "generated_sentence_high_temp_out": "generated-sentence-temp-1.txt",
                "dev_out": "cfimdb-dev-prompting-output.txt",
                "test_out": "cfimdb-test-prompting-output.txt",
                "batch_size": 8,  # sst: 64, cfimdb: 8 can fit a 12GB GPU
                "hidden_dropout_prob": 0.3,
                "lr": 2e-5,  # suggested: 1e-3 for 'pretrain', 1e-5 for 'finetune'
            }
            # Assign in declaration order so vars(args) prints in that order.
            for key, value in settings.items():
                setattr(self, key, value)

    args = Args()
    print(f"args: {vars(args)}")
    return args

# Entry point of the run_llama cell: build the arguments, derive the
# checkpoint path, seed all RNGs, then dispatch on args.option.
if __name__ == "__main__":
    args = get_args()
    # Checkpoint file name encodes option, epoch count and learning rate.
    args.filepath = f'{args.option}-{args.epochs}-{args.lr}.pt'  # save path
    seed_everything(args.seed)  # fix the seed for reproducibility

    if args.option == "generate":
        # Step 1
        # Complete this sentence to test your implementation!
        prefix = "I have wanted to see this thriller for a while, and it didn't disappoint. Keanu Reeves, playing the hero John Wick, is"
        # Once with greedy decoding (temperature 0) and once with sampling (temperature 1).
        generate_sentence(args, prefix, args.generated_sentence_low_temp_out, max_new_tokens=75, temperature=0.0)
        generate_sentence(args, prefix, args.generated_sentence_high_temp_out, max_new_tokens=75, temperature=1.0)
    elif args.option == "prompt":
        # Step 2
        # Solve this task with prompted language modeling
        test_with_prompting(args)
    elif args.option == "finetune":
        # Step 3
        # Finetune a classification model
        train(args)

        # Step 4
        # Evaluate your model on the dev and test sets
        # (test() loads the checkpoint that train() saved to args.filepath)
        test(args)
    else:
        raise ValueError(f"Invalid option: {args.option}")

args: {'train': 'data/cfimdb-train.txt', 'dev': 'data/cfimdb-dev.txt', 'test': 'data/cfimdb-test.txt', 'label_names': 'data/cfimdb-label-mapping.json', 'pretrained_model_path': 'stories42M.pt', 'max_sentence_len': None, 'seed': 1337, 'epochs': 5, 'option': 'generate', 'use_gpu': False, 'generated_sentence_low_temp_out': 'generated-sentence-temp-0.txt', 'generated_sentence_high_temp_out': 'generated-sentence-temp-1.txt', 'dev_out': 'cfimdb-dev-prompting-output.txt', 'test_out': 'cfimdb-test-prompting-output.txt', 'batch_size': 8, 'hidden_dropout_prob': 0.3, 'lr': 2e-05}


  checkpoint_dict = torch.load(checkpoint, map_location=device)


load model from stories42M.pt
Temperature is 0.0
I have wanted to see this thriller for a while, and it didn't disappoint. Keanu Reeves, playing the hero John Wick, is this day. He was playing with his toy car, driving it around the living room. Suddenly, he heard a loud crash. He had broken the car and was very sad.
John was angry and he shouted at his little brother. He was only three years old and he was only three. He was only three years old. He was very ups
---------------
Wrote generated sentence to generated-sentence-temp-0.txt.
load model from stories42M.pt
Temperature is 1.0
I have wanted to see this thriller for a while, and it didn't disappoint. Keanu Reeves, playing the hero John Wick, is it!" As John toddled up to the sweet aroma, he held his mum's hand tight, hoping he would finally reach the top step.
But as he stepped, he felt so heavy, like an apple. He had made a mistake! His mum apologisedDen handlely leaving, so he started to cry.
His m
---------------
Wrote genera

# Zero-Shot Prompting for SST

In [None]:
def get_args_zero_shot():
    """Build, print and return the settings for zero-shot prompting on SST.

    The returned object is a plain attribute holder used in place of an
    argparse namespace; the caller attaches further attributes (such as
    ``filepath``) afterwards.

    Returns:
        Args: an instance carrying one attribute per setting listed below.
    """
    class Args:
        """Attribute container standing in for parsed command-line arguments."""

        def __init__(self):
            # A dict literal keeps insertion order, so vars(self) lists the
            # fields in exactly the order written here.
            vars(self).update({
                # Dataset files and the label-name mapping.
                "train": "data/sst-train.txt",
                "dev": "data/sst-dev.txt",
                "test": "data/sst-test.txt",
                "label_names": "data/sst-label-mapping.json",
                # Pretrained checkpoint to load.
                "pretrained_model_path": "stories42M.pt",
                "max_sentence_len": None,
                "seed": 1337,
                "epochs": 10,
                # One of 'generate', 'prompt', 'finetune'.
                # prompt: the Llama parameters are frozen;
                # finetune: the Llama parameters are updated.
                "option": "prompt",
                # Set to True if you want to use GPU.
                "use_gpu": False,
                # Output files for the 'generate' mode.
                "generated_sentence_low_temp_out": "generated-sentence-temp-0.txt",
                "generated_sentence_high_temp_out": "generated-sentence-temp-1.txt",
                # Prediction output files for the dev and test splits.
                "dev_out": "sst-dev-prompting-output.txt",
                "test_out": "sst-test-prompting-output.txt",
                # Author's sizing note: sst: 64, cfimdb: 8 can fit a 12GB GPU.
                "batch_size": 10,
                "hidden_dropout_prob": 0.3,
                # Author's note: default lr for 'pretrain' is 1e-3, for 'finetune' 1e-5.
                "lr": 2e-5,
            })

    args = Args()
    print(f"args: {vars(args)}")
    return args

if __name__ == "__main__":
    # Assemble the configuration, derive the checkpoint file name from the
    # mode, epoch count and learning rate, then seed before any work starts.
    args = get_args_zero_shot()
    args.filepath = f'{args.option}-{args.epochs}-{args.lr}.pt'  # save path
    seed_everything(args.seed)  # fix the seed for reproducibility

    # Anything other than the three known modes is rejected here.
    if args.option not in ("generate", "prompt", "finetune"):
        raise ValueError(f"Invalid option: {args.option}")

    if args.option == "prompt":
        # Step 2: solve this task with prompted language modeling.
        # This is the branch taken with the settings from get_args_zero_shot().
        test_with_prompting(args)
    elif args.option == "finetune":
        # Step 3: finetune a classification model.
        train(args)
        # Step 4: evaluate the model on the dev and test sets.
        test(args)
    else:
        # Step 1 ("generate"): complete this sentence to test the
        # implementation, once at temperature 0.0 and once at 1.0.
        prefix = "I have wanted to see this thriller for a while, and it didn't disappoint. Keanu Reeves, playing the hero John Wick, is"
        generate_sentence(
            args,
            prefix,
            args.generated_sentence_low_temp_out,
            max_new_tokens=75,
            temperature=0.0,
        )
        generate_sentence(
            args,
            prefix,
            args.generated_sentence_high_temp_out,
            max_new_tokens=75,
            temperature=1.0,
        )

args: {'train': 'data/sst-train.txt', 'dev': 'data/sst-dev.txt', 'test': 'data/sst-test.txt', 'label_names': 'data/sst-label-mapping.json', 'pretrained_model_path': 'stories42M.pt', 'max_sentence_len': None, 'seed': 1337, 'epochs': 10, 'option': 'prompt', 'use_gpu': False, 'generated_sentence_low_temp_out': 'generated-sentence-temp-0.txt', 'generated_sentence_high_temp_out': 'generated-sentence-temp-1.txt', 'dev_out': 'sst-dev-prompting-output.txt', 'test_out': 'sst-test-prompting-output.txt', 'batch_size': 10, 'hidden_dropout_prob': 0.3, 'lr': 2e-05}
load 8544 data from data/sst-train.txt


  checkpoint_dict = torch.load(checkpoint, map_location=device)


load 1101 data from data/sst-dev.txt
load 2210 data from data/sst-test.txt


eval: 100%|██████████| 111/111 [01:32<00:00,  1.20it/s]
eval: 100%|██████████| 221/221 [02:56<00:00,  1.25it/s]


dev acc :: 0.213
test acc :: 0.224


# Zero-Shot Prompting for CFIMDB

In [None]:
def get_args_zero_shot():
    """Build, print and return the settings for zero-shot prompting on CFIMDB.

    The returned object is a plain attribute holder used in place of an
    argparse namespace; the caller attaches further attributes (such as
    ``filepath``) afterwards.

    Returns:
        Args: an instance carrying one attribute per setting listed below.
    """
    class Args:
        """Attribute container standing in for parsed command-line arguments."""

        def __init__(self):
            # A dict literal keeps insertion order, so vars(self) lists the
            # fields in exactly the order written here.
            vars(self).update({
                # Dataset files and the label-name mapping.
                "train": "data/cfimdb-train.txt",
                "dev": "data/cfimdb-dev.txt",
                "test": "data/cfimdb-test.txt",
                "label_names": "data/cfimdb-label-mapping.json",
                # Pretrained checkpoint to load.
                "pretrained_model_path": "stories42M.pt",
                "max_sentence_len": None,
                "seed": 1337,
                "epochs": 10,
                # One of 'generate', 'prompt', 'finetune'.
                # prompt: the Llama parameters are frozen;
                # finetune: the Llama parameters are updated.
                "option": "prompt",
                # Set to True if you want to use GPU.
                "use_gpu": False,
                # Output files for the 'generate' mode.
                "generated_sentence_low_temp_out": "generated-sentence-temp-0.txt",
                "generated_sentence_high_temp_out": "generated-sentence-temp-1.txt",
                # Prediction output files for the dev and test splits.
                "dev_out": "cfimdb-dev-prompting-output.txt",
                "test_out": "cfimdb-test-prompting-output.txt",
                # Author's sizing note: sst: 64, cfimdb: 8 can fit a 12GB GPU.
                "batch_size": 10,
                "hidden_dropout_prob": 0.3,
                # Author's note: default lr for 'pretrain' is 1e-3, for 'finetune' 1e-5.
                "lr": 2e-5,
            })

    args = Args()
    print(f"args: {vars(args)}")
    return args

if __name__ == "__main__":
    # Assemble the configuration, derive the checkpoint file name from the
    # mode, epoch count and learning rate, then seed before any work starts.
    args = get_args_zero_shot()
    args.filepath = f'{args.option}-{args.epochs}-{args.lr}.pt'  # save path
    seed_everything(args.seed)  # fix the seed for reproducibility

    # Anything other than the three known modes is rejected here.
    if args.option not in ("generate", "prompt", "finetune"):
        raise ValueError(f"Invalid option: {args.option}")

    if args.option == "prompt":
        # Step 2: solve this task with prompted language modeling.
        # This is the branch taken with the settings from get_args_zero_shot().
        test_with_prompting(args)
    elif args.option == "finetune":
        # Step 3: finetune a classification model.
        train(args)
        # Step 4: evaluate the model on the dev and test sets.
        test(args)
    else:
        # Step 1 ("generate"): complete this sentence to test the
        # implementation, once at temperature 0.0 and once at 1.0.
        prefix = "I have wanted to see this thriller for a while, and it didn't disappoint. Keanu Reeves, playing the hero John Wick, is"
        generate_sentence(
            args,
            prefix,
            args.generated_sentence_low_temp_out,
            max_new_tokens=75,
            temperature=0.0,
        )
        generate_sentence(
            args,
            prefix,
            args.generated_sentence_high_temp_out,
            max_new_tokens=75,
            temperature=1.0,
        )

args: {'train': 'data/cfimdb-train.txt', 'dev': 'data/cfimdb-dev.txt', 'test': 'data/cfimdb-test.txt', 'label_names': 'data/cfimdb-label-mapping.json', 'pretrained_model_path': 'stories42M.pt', 'max_sentence_len': None, 'seed': 1337, 'epochs': 10, 'option': 'prompt', 'use_gpu': False, 'generated_sentence_low_temp_out': 'generated-sentence-temp-0.txt', 'generated_sentence_high_temp_out': 'generated-sentence-temp-1.txt', 'dev_out': 'cfimdb-dev-prompting-output.txt', 'test_out': 'cfimdb-test-prompting-output.txt', 'batch_size': 10, 'hidden_dropout_prob': 0.3, 'lr': 2e-05}
load 1707 data from data/cfimdb-train.txt


  checkpoint_dict = torch.load(checkpoint, map_location=device)


load 245 data from data/cfimdb-dev.txt
load 488 data from data/cfimdb-test.txt


eval: 100%|██████████| 25/25 [02:47<00:00,  6.70s/it]
eval: 100%|██████████| 49/49 [05:25<00:00,  6.65s/it]


dev acc :: 0.502
test acc :: 0.213


# Finetuning for SST

In [None]:
def get_args():
    """Build, print and return the settings for finetuning on SST.

    The returned object is a plain attribute holder used in place of an
    argparse namespace; the caller attaches further attributes (such as
    ``filepath``) afterwards.

    Returns:
        Args: an instance carrying one attribute per setting listed below.
    """
    class Args:
        """Attribute container standing in for parsed command-line arguments."""

        def __init__(self):
            # A dict literal keeps insertion order, so vars(self) lists the
            # fields in exactly the order written here.
            vars(self).update({
                # Dataset files and the label-name mapping.
                "train": "data/sst-train.txt",
                "dev": "data/sst-dev.txt",
                "test": "data/sst-test.txt",
                "label_names": "data/sst-label-mapping.json",
                # Pretrained checkpoint to load.
                "pretrained_model_path": "stories42M.pt",
                "max_sentence_len": None,
                "seed": 1337,
                "epochs": 5,
                # One of 'generate', 'prompt', 'finetune'.
                # prompt: the Llama parameters are frozen;
                # finetune: the Llama parameters are updated.
                "option": "finetune",
                # Set to True if you want to use GPU.
                "use_gpu": False,
                # Output files for the 'generate' mode.
                "generated_sentence_low_temp_out": "generated-sentence-temp-0.txt",
                "generated_sentence_high_temp_out": "generated-sentence-temp-1.txt",
                # Prediction output files for the dev and test splits.
                "dev_out": "sst-dev-finetuning-output.txt",
                "test_out": "sst-test-finetuning-output.txt",
                # Author's sizing note: sst: 64, cfimdb: 8 can fit a 12GB GPU.
                "batch_size": 80,
                "hidden_dropout_prob": 0.3,
                # Author's note: default lr for 'pretrain' is 1e-3, for 'finetune' 1e-5.
                "lr": 2e-5,
            })

    args = Args()
    print(f"args: {vars(args)}")
    return args

if __name__ == "__main__":
    # Assemble the configuration, derive the checkpoint file name from the
    # mode, epoch count and learning rate, then seed before any work starts.
    # NOTE: the file name has no dataset component, so two finetuning runs
    # with the same epochs and lr write to the same path.
    args = get_args()
    args.filepath = f'{args.option}-{args.epochs}-{args.lr}.pt'  # save path
    seed_everything(args.seed)  # fix the seed for reproducibility

    # Anything other than the three known modes is rejected here.
    if args.option not in ("generate", "prompt", "finetune"):
        raise ValueError(f"Invalid option: {args.option}")

    if args.option == "finetune":
        # Step 3: finetune a classification model.
        # This is the branch taken with the settings from get_args().
        train(args)
        # Step 4: evaluate the model on the dev and test sets.
        test(args)
    elif args.option == "prompt":
        # Step 2: solve this task with prompted language modeling.
        test_with_prompting(args)
    else:
        # Step 1 ("generate"): complete this sentence to test the
        # implementation, once at temperature 0.0 and once at 1.0.
        prefix = "I have wanted to see this thriller for a while, and it didn't disappoint. Keanu Reeves, playing the hero John Wick, is"
        generate_sentence(
            args,
            prefix,
            args.generated_sentence_low_temp_out,
            max_new_tokens=75,
            temperature=0.0,
        )
        generate_sentence(
            args,
            prefix,
            args.generated_sentence_high_temp_out,
            max_new_tokens=75,
            temperature=1.0,
        )

args: {'train': 'data/sst-train.txt', 'dev': 'data/sst-dev.txt', 'test': 'data/sst-test.txt', 'label_names': 'data/sst-label-mapping.json', 'pretrained_model_path': 'stories42M.pt', 'max_sentence_len': None, 'seed': 1337, 'epochs': 5, 'option': 'finetune', 'use_gpu': False, 'generated_sentence_low_temp_out': 'generated-sentence-temp-0.txt', 'generated_sentence_high_temp_out': 'generated-sentence-temp-1.txt', 'dev_out': 'sst-dev-finetuning-output.txt', 'test_out': 'sst-test-finetuning-output.txt', 'batch_size': 80, 'hidden_dropout_prob': 0.3, 'lr': 2e-05}
load 8544 data from data/sst-train.txt
load 1101 data from data/sst-dev.txt


  checkpoint_dict = torch.load(checkpoint, map_location=device)
  0%|          | 0/5 [00:00<?, ?it/s]
train-0:   0%|          | 0/107 [00:00<?, ?it/s][A
train-0:   1%|          | 1/107 [00:21<37:30, 21.23s/it][A
train-0:   2%|▏         | 2/107 [00:41<36:05, 20.63s/it][A
train-0:   3%|▎         | 3/107 [00:57<32:19, 18.65s/it][A
train-0:   4%|▎         | 4/107 [01:13<29:56, 17.44s/it][A
train-0:   5%|▍         | 5/107 [01:28<28:04, 16.52s/it][A
train-0:   6%|▌         | 6/107 [01:42<26:41, 15.86s/it][A
train-0:   7%|▋         | 7/107 [01:59<26:55, 16.15s/it][A
train-0:   7%|▋         | 8/107 [02:14<26:10, 15.86s/it][A
train-0:   8%|▊         | 9/107 [02:32<26:56, 16.50s/it][A
train-0:   9%|▉         | 10/107 [02:49<26:38, 16.48s/it][A
train-0:  10%|█         | 11/107 [03:07<27:05, 16.93s/it][A
train-0:  11%|█         | 12/107 [03:22<26:07, 16.50s/it][A
train-0:  12%|█▏        | 13/107 [03:39<25:50, 16.49s/it][A
train-0:  13%|█▎        | 14/107 [03:56<26:06, 16.85s/it][A
t

save the model to finetune-5-2e-05.pt
epoch 0: train loss :: 1.882, train acc :: 0.261, dev acc :: 0.262



train-1:   0%|          | 0/107 [00:00<?, ?it/s][A
train-1:   1%|          | 1/107 [00:19<33:37, 19.03s/it][A
train-1:   2%|▏         | 2/107 [00:40<35:21, 20.20s/it][A
train-1:   3%|▎         | 3/107 [00:59<34:18, 19.79s/it][A
train-1:   4%|▎         | 4/107 [01:15<31:41, 18.46s/it][A
train-1:   5%|▍         | 5/107 [01:31<29:42, 17.48s/it][A
train-1:   6%|▌         | 6/107 [01:47<28:46, 17.09s/it][A
train-1:   7%|▋         | 7/107 [02:04<28:06, 16.86s/it][A
train-1:   7%|▋         | 8/107 [02:19<26:49, 16.26s/it][A
train-1:   8%|▊         | 9/107 [02:35<26:42, 16.35s/it][A
train-1:   9%|▉         | 10/107 [02:51<26:04, 16.13s/it][A
train-1:  10%|█         | 11/107 [03:08<26:02, 16.27s/it][A
train-1:  11%|█         | 12/107 [03:25<26:07, 16.50s/it][A
train-1:  12%|█▏        | 13/107 [03:43<26:50, 17.13s/it][A
train-1:  13%|█▎        | 14/107 [03:59<25:52, 16.69s/it][A
train-1:  14%|█▍        | 15/107 [04:13<24:14, 15.81s/it][A
train-1:  15%|█▍        | 16/107 [04:31<2

epoch 1: train loss :: 1.654, train acc :: 0.273, dev acc :: 0.253



train-2:   0%|          | 0/107 [00:00<?, ?it/s][A
train-2:   1%|          | 1/107 [00:16<29:51, 16.90s/it][A
train-2:   2%|▏         | 2/107 [00:31<27:32, 15.74s/it][A
train-2:   3%|▎         | 3/107 [00:47<27:32, 15.89s/it][A
train-2:   4%|▎         | 4/107 [01:05<28:08, 16.39s/it][A
train-2:   5%|▍         | 5/107 [01:21<27:48, 16.36s/it][A
train-2:   6%|▌         | 6/107 [01:36<26:34, 15.78s/it][A
train-2:   7%|▋         | 7/107 [01:54<27:29, 16.50s/it][A
train-2:   7%|▋         | 8/107 [02:07<25:38, 15.54s/it][A
train-2:   8%|▊         | 9/107 [02:25<26:48, 16.41s/it][A
train-2:   9%|▉         | 10/107 [02:41<26:22, 16.32s/it][A
train-2:  10%|█         | 11/107 [02:57<25:55, 16.21s/it][A
train-2:  11%|█         | 12/107 [03:15<26:29, 16.73s/it][A
train-2:  12%|█▏        | 13/107 [03:31<25:48, 16.47s/it][A
train-2:  13%|█▎        | 14/107 [03:52<27:21, 17.65s/it][A
train-2:  14%|█▍        | 15/107 [04:07<26:09, 17.06s/it][A
train-2:  15%|█▍        | 16/107 [04:23<2

save the model to finetune-5-2e-05.pt
epoch 2: train loss :: 1.558, train acc :: 0.401, dev acc :: 0.361



train-3:   0%|          | 0/107 [00:00<?, ?it/s][A
train-3:   1%|          | 1/107 [00:17<30:15, 17.12s/it][A
train-3:   2%|▏         | 2/107 [00:32<28:00, 16.00s/it][A
train-3:   3%|▎         | 3/107 [00:49<28:57, 16.70s/it][A
train-3:   4%|▎         | 4/107 [01:07<29:05, 16.95s/it][A
train-3:   5%|▍         | 5/107 [01:25<29:39, 17.45s/it][A
train-3:   6%|▌         | 6/107 [01:41<28:34, 16.98s/it][A
train-3:   7%|▋         | 7/107 [01:55<26:47, 16.08s/it][A
train-3:   7%|▋         | 8/107 [02:12<26:53, 16.30s/it][A
train-3:   8%|▊         | 9/107 [02:33<29:07, 17.83s/it][A
train-3:   9%|▉         | 10/107 [02:47<26:54, 16.64s/it][A
train-3:  10%|█         | 11/107 [03:07<27:58, 17.49s/it][A
train-3:  11%|█         | 12/107 [03:21<26:03, 16.46s/it][A
train-3:  12%|█▏        | 13/107 [03:39<26:27, 16.89s/it][A
train-3:  13%|█▎        | 14/107 [03:55<26:05, 16.83s/it][A
train-3:  14%|█▍        | 15/107 [04:15<27:09, 17.71s/it][A
train-3:  15%|█▍        | 16/107 [04:30<2

save the model to finetune-5-2e-05.pt
epoch 3: train loss :: 1.326, train acc :: 0.517, dev acc :: 0.392



train-4:   0%|          | 0/107 [00:00<?, ?it/s][A
train-4:   1%|          | 1/107 [00:19<34:19, 19.43s/it][A
train-4:   2%|▏         | 2/107 [00:37<32:27, 18.55s/it][A
train-4:   3%|▎         | 3/107 [00:52<29:14, 16.87s/it][A
train-4:   4%|▎         | 4/107 [01:09<29:14, 17.04s/it][A
train-4:   5%|▍         | 5/107 [01:23<27:06, 15.94s/it][A
train-4:   6%|▌         | 6/107 [01:41<28:03, 16.67s/it][A
train-4:   7%|▋         | 7/107 [01:58<27:53, 16.73s/it][A
train-4:   7%|▋         | 8/107 [02:15<27:56, 16.93s/it][A
train-4:   8%|▊         | 9/107 [02:31<26:50, 16.43s/it][A
train-4:   9%|▉         | 10/107 [02:45<25:34, 15.82s/it][A
train-4:  10%|█         | 11/107 [03:01<25:06, 15.69s/it][A
train-4:  11%|█         | 12/107 [03:16<24:53, 15.72s/it][A
train-4:  12%|█▏        | 13/107 [03:31<24:13, 15.46s/it][A
train-4:  13%|█▎        | 14/107 [03:48<24:34, 15.86s/it][A
train-4:  14%|█▍        | 15/107 [04:04<24:21, 15.88s/it][A
train-4:  15%|█▍        | 16/107 [04:23<2

save the model to finetune-5-2e-05.pt
epoch 4: train loss :: 1.074, train acc :: 0.688, dev acc :: 0.414



  saved = torch.load(args.filepath)
  checkpoint_dict = torch.load(checkpoint, map_location=device)


load model from finetune-5-2e-05.pt
load 1101 data from data/sst-dev.txt
load 2210 data from data/sst-test.txt


eval: 100%|██████████| 14/14 [01:15<00:00,  5.37s/it]
eval: 100%|██████████| 28/28 [02:30<00:00,  5.36s/it]

dev acc :: 0.414
test acc :: 0.418





# Finetuning for CFIMDB

In [16]:
def get_args():
    """Build, print and return the settings for finetuning on CFIMDB.

    A small attribute-holder class is used instead of argparse. The caller
    attaches further attributes (such as ``filepath``) to the returned object.

    ``use_gpu`` is derived from ``torch.cuda.is_available()`` rather than
    being hard-coded to ``True``: on a CUDA machine the value is unchanged,
    while on a CPU-only machine the cell falls back to CPU instead of asking
    for a device that does not exist.

    Returns:
        Args: an instance carrying one attribute per setting listed below.
    """
    class Args:
        """Attribute container standing in for parsed command-line arguments."""

        def __init__(self):
            # Dataset files and the label-name mapping.
            self.train = "data/cfimdb-train.txt"
            self.dev = "data/cfimdb-dev.txt"
            self.test = "data/cfimdb-test.txt"
            self.label_names = "data/cfimdb-label-mapping.json"
            # Pretrained checkpoint to load.
            self.pretrained_model_path = "stories42M.pt"
            self.max_sentence_len = None
            self.seed = 1337
            self.epochs = 5
            # One of 'generate', 'prompt', 'finetune'.
            # prompt: the Llama parameters are frozen;
            # finetune: the Llama parameters are updated.
            self.option = "finetune"
            # Request the GPU only when CUDA is actually usable, so the
            # notebook still runs (slowly) on a CPU-only machine.
            self.use_gpu = torch.cuda.is_available()
            # Output files for the 'generate' mode.
            self.generated_sentence_low_temp_out = "generated-sentence-temp-0.txt"
            self.generated_sentence_high_temp_out = "generated-sentence-temp-1.txt"
            # Prediction output files for the dev and test splits.
            self.dev_out = "cfimdb-dev-finetuning-output.txt"
            self.test_out = "cfimdb-test-finetuning-output.txt"
            # Author's sizing note: sst: 64, cfimdb: 8 can fit a 12GB GPU.
            self.batch_size = 10
            self.hidden_dropout_prob = 0.3
            # Author's note: default lr for 'pretrain' is 1e-3, for 'finetune' 1e-5.
            self.lr = 2e-5

    args = Args()
    print(f"args: {vars(args)}")
    return args

if __name__ == "__main__":
    # Assemble the configuration, derive the checkpoint file name from the
    # mode, epoch count and learning rate, then seed before any work starts.
    # NOTE: the file name has no dataset component, so two finetuning runs
    # with the same epochs and lr write to the same path.
    args = get_args()
    args.filepath = f'{args.option}-{args.epochs}-{args.lr}.pt'  # save path
    seed_everything(args.seed)  # fix the seed for reproducibility

    # Anything other than the three known modes is rejected here.
    if args.option not in ("generate", "prompt", "finetune"):
        raise ValueError(f"Invalid option: {args.option}")

    if args.option == "finetune":
        # Step 3: finetune a classification model.
        # This is the branch taken with the settings from get_args().
        train(args)
        # Step 4: evaluate the model on the dev and test sets.
        test(args)
    elif args.option == "prompt":
        # Step 2: solve this task with prompted language modeling.
        test_with_prompting(args)
    else:
        # Step 1 ("generate"): complete this sentence to test the
        # implementation, once at temperature 0.0 and once at 1.0.
        prefix = "I have wanted to see this thriller for a while, and it didn't disappoint. Keanu Reeves, playing the hero John Wick, is"
        generate_sentence(
            args,
            prefix,
            args.generated_sentence_low_temp_out,
            max_new_tokens=75,
            temperature=0.0,
        )
        generate_sentence(
            args,
            prefix,
            args.generated_sentence_high_temp_out,
            max_new_tokens=75,
            temperature=1.0,
        )

args: {'train': 'data/cfimdb-train.txt', 'dev': 'data/cfimdb-dev.txt', 'test': 'data/cfimdb-test.txt', 'label_names': 'data/cfimdb-label-mapping.json', 'pretrained_model_path': 'stories42M.pt', 'max_sentence_len': None, 'seed': 1337, 'epochs': 5, 'option': 'finetune', 'use_gpu': True, 'generated_sentence_low_temp_out': 'generated-sentence-temp-0.txt', 'generated_sentence_high_temp_out': 'generated-sentence-temp-1.txt', 'dev_out': 'cfimdb-dev-finetuning-output.txt', 'test_out': 'cfimdb-test-finetuning-output.txt', 'batch_size': 10, 'hidden_dropout_prob': 0.3, 'lr': 2e-05}
load 1707 data from data/cfimdb-train.txt
load 245 data from data/cfimdb-dev.txt


  checkpoint_dict = torch.load(checkpoint, map_location=device)
  0%|          | 0/5 [00:00<?, ?it/s]
train-0:   0%|          | 0/171 [00:00<?, ?it/s][A
train-0:   1%|          | 1/171 [00:01<04:45,  1.68s/it][A
train-0:   1%|          | 2/171 [00:01<02:24,  1.17it/s][A
train-0:   2%|▏         | 3/171 [00:02<01:39,  1.69it/s][A
train-0:   2%|▏         | 4/171 [00:02<01:21,  2.04it/s][A
train-0:   3%|▎         | 5/171 [00:02<01:05,  2.52it/s][A
train-0:   4%|▎         | 6/171 [00:03<00:56,  2.92it/s][A
train-0:   4%|▍         | 7/171 [00:03<00:51,  3.19it/s][A
train-0:   5%|▍         | 8/171 [00:03<00:47,  3.40it/s][A
train-0:   5%|▌         | 9/171 [00:03<00:45,  3.57it/s][A
train-0:   6%|▌         | 10/171 [00:04<00:42,  3.76it/s][A
train-0:   6%|▋         | 11/171 [00:04<00:45,  3.49it/s][A
train-0:   7%|▋         | 12/171 [00:04<00:45,  3.52it/s][A
train-0:   8%|▊         | 13/171 [00:05<00:49,  3.21it/s][A
train-0:   8%|▊         | 14/171 [00:05<00:47,  3.30it/s][A
t

save the model to finetune-5-2e-05.pt
epoch 0: train loss :: 0.972, train acc :: 0.501, dev acc :: 0.502



train-1:   0%|          | 0/171 [00:00<?, ?it/s][A
train-1:   1%|          | 1/171 [00:00<00:56,  2.98it/s][A
train-1:   1%|          | 2/171 [00:00<00:52,  3.22it/s][A
train-1:   2%|▏         | 3/171 [00:00<00:51,  3.24it/s][A
train-1:   2%|▏         | 4/171 [00:01<00:45,  3.69it/s][A
train-1:   3%|▎         | 5/171 [00:01<00:45,  3.62it/s][A
train-1:   4%|▎         | 6/171 [00:01<00:44,  3.69it/s][A
train-1:   4%|▍         | 7/171 [00:02<00:46,  3.52it/s][A
train-1:   5%|▍         | 8/171 [00:02<00:47,  3.45it/s][A
train-1:   5%|▌         | 9/171 [00:02<00:54,  2.99it/s][A
train-1:   6%|▌         | 10/171 [00:03<00:55,  2.91it/s][A
train-1:   6%|▋         | 11/171 [00:03<00:50,  3.18it/s][A
train-1:   7%|▋         | 12/171 [00:03<00:47,  3.36it/s][A
train-1:   8%|▊         | 13/171 [00:03<00:46,  3.39it/s][A
train-1:   8%|▊         | 14/171 [00:04<00:46,  3.41it/s][A
train-1:   9%|▉         | 15/171 [00:04<00:48,  3.24it/s][A
train-1:   9%|▉         | 16/171 [00:04<0

epoch 1: train loss :: 0.816, train acc :: 0.503, dev acc :: 0.502



train-2:   0%|          | 0/171 [00:00<?, ?it/s][A
train-2:   1%|          | 1/171 [00:00<00:50,  3.40it/s][A
train-2:   1%|          | 2/171 [00:00<00:42,  3.96it/s][A
train-2:   2%|▏         | 3/171 [00:00<00:40,  4.17it/s][A
train-2:   2%|▏         | 4/171 [00:01<00:44,  3.76it/s][A
train-2:   3%|▎         | 5/171 [00:01<00:41,  4.03it/s][A
train-2:   4%|▎         | 6/171 [00:01<00:42,  3.91it/s][A
train-2:   4%|▍         | 7/171 [00:01<00:46,  3.52it/s][A
train-2:   5%|▍         | 8/171 [00:02<00:49,  3.30it/s][A
train-2:   5%|▌         | 9/171 [00:02<00:52,  3.10it/s][A
train-2:   6%|▌         | 10/171 [00:02<00:52,  3.08it/s][A
train-2:   6%|▋         | 11/171 [00:03<00:50,  3.18it/s][A
train-2:   7%|▋         | 12/171 [00:03<00:53,  2.99it/s][A
train-2:   8%|▊         | 13/171 [00:03<00:55,  2.85it/s][A
train-2:   8%|▊         | 14/171 [00:04<00:53,  2.92it/s][A
train-2:   9%|▉         | 15/171 [00:04<00:51,  3.02it/s][A
train-2:   9%|▉         | 16/171 [00:04<0

save the model to finetune-5-2e-05.pt
epoch 2: train loss :: 0.747, train acc :: 0.543, dev acc :: 0.522



train-3:   0%|          | 0/171 [00:00<?, ?it/s][A
train-3:   1%|          | 1/171 [00:00<01:04,  2.63it/s][A
train-3:   1%|          | 2/171 [00:00<01:01,  2.73it/s][A
train-3:   2%|▏         | 3/171 [00:00<00:49,  3.40it/s][A
train-3:   2%|▏         | 4/171 [00:01<00:49,  3.35it/s][A
train-3:   3%|▎         | 5/171 [00:01<00:49,  3.37it/s][A
train-3:   4%|▎         | 6/171 [00:01<00:51,  3.22it/s][A
train-3:   4%|▍         | 7/171 [00:02<00:50,  3.22it/s][A
train-3:   5%|▍         | 8/171 [00:02<00:50,  3.22it/s][A
train-3:   5%|▌         | 9/171 [00:02<00:46,  3.45it/s][A
train-3:   6%|▌         | 10/171 [00:03<00:45,  3.55it/s][A
train-3:   6%|▋         | 11/171 [00:03<00:46,  3.44it/s][A
train-3:   7%|▋         | 12/171 [00:03<00:45,  3.47it/s][A
train-3:   8%|▊         | 13/171 [00:03<00:46,  3.37it/s][A
train-3:   8%|▊         | 14/171 [00:04<00:44,  3.50it/s][A
train-3:   9%|▉         | 15/171 [00:04<00:43,  3.62it/s][A
train-3:   9%|▉         | 16/171 [00:04<0

save the model to finetune-5-2e-05.pt
epoch 3: train loss :: 0.597, train acc :: 0.877, dev acc :: 0.857



train-4:   0%|          | 0/171 [00:00<?, ?it/s][A
train-4:   1%|          | 1/171 [00:00<00:44,  3.85it/s][A
train-4:   1%|          | 2/171 [00:00<00:41,  4.03it/s][A
train-4:   2%|▏         | 3/171 [00:00<00:41,  4.09it/s][A
train-4:   2%|▏         | 4/171 [00:00<00:38,  4.30it/s][A
train-4:   3%|▎         | 5/171 [00:01<00:42,  3.88it/s][A
train-4:   4%|▎         | 6/171 [00:01<00:45,  3.59it/s][A
train-4:   4%|▍         | 7/171 [00:01<00:44,  3.69it/s][A
train-4:   5%|▍         | 8/171 [00:02<00:47,  3.46it/s][A
train-4:   5%|▌         | 9/171 [00:02<00:48,  3.31it/s][A
train-4:   6%|▌         | 10/171 [00:02<00:48,  3.35it/s][A
train-4:   6%|▋         | 11/171 [00:03<00:46,  3.41it/s][A
train-4:   7%|▋         | 12/171 [00:03<00:42,  3.74it/s][A
train-4:   8%|▊         | 13/171 [00:03<00:44,  3.55it/s][A
train-4:   8%|▊         | 14/171 [00:03<00:44,  3.52it/s][A
train-4:   9%|▉         | 15/171 [00:04<00:47,  3.28it/s][A
train-4:   9%|▉         | 16/171 [00:04<0

epoch 4: train loss :: 0.355, train acc :: 0.907, dev acc :: 0.845


  checkpoint_dict = torch.load(checkpoint, map_location=device)


load model from finetune-5-2e-05.pt
load 245 data from data/cfimdb-dev.txt
load 488 data from data/cfimdb-test.txt


eval: 100%|██████████| 25/25 [00:02<00:00,  9.93it/s]
eval: 100%|██████████| 49/49 [00:04<00:00, 10.13it/s]

dev acc :: 0.857
test acc :: 0.469



