### Imports & Setup

Create a separate conda environment from the environment.yaml file.

`conda env create -f environment.yaml`

In [None]:
import os

# huggingface_hub reads HF_HUB_ENABLE_HF_TRANSFER once, at import time, and
# transformers/datasets import huggingface_hub. The flag therefore has to be
# set before those imports, otherwise it is silently ignored.
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'

import torch
import accelerate
import numpy as np
from transformers import (AutoTokenizer,
                          AutoModelForCausalLM,
                          BitsAndBytesConfig,
                          pipeline
                          )
import textwrap
from datasets import Dataset
import pandas as pd
import requests

# Create new tensors on the GPU unless a device is given explicitly.
torch.set_default_device('cuda')

## Model Class and LLM Checking

`LLMObj` is a HF wrapper that contains the LLM model, tokenizer, and text generation wrapper.

Below the class code, several LLMs that are available on HF are initialized.

For some models, such as Llama, you need to authenticate with your HF account: add your [HF access token](https://huggingface.co/docs/hub/security-tokens) to the notebook's secrets as `HF_TOKEN`.

In [None]:
# LLMObj and parse_args come from the project-local `utils` module.
from utils import LLMObj, parse_args

# Parse the run configuration and echo it so the settings are visible in the
# notebook output.
# NOTE(review): if parse_args reads sys.argv, the arguments seen inside a
# notebook kernel are the kernel's own - confirm this behaves as intended.
args = parse_args()
print(args)

### LLM Inference

### Llama

In [None]:
# Llama-3 8B Instruct, loaded with a 4-bit bitsandbytes quantization config.
# The checkpoint id is kept in `model` (a plain string) and reused below so
# the name is defined in exactly one place.
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model = "meta-llama/Meta-Llama-3-8B-Instruct"
model_kwargs = {"torch_dtype": torch.bfloat16,
                "quantization_config": quantization_config,
                "low_cpu_mem_usage": True}

Llama = LLMObj(model=model, model_kwargs=model_kwargs)

In [None]:
# Smoke test: ask the Llama wrapper for a free-form analogy. The optional
# system prompt is left disabled; max_length is passed through to generate().
Llama.generate('Write a detailed analogy between mathematics and a lighthouse.',
        #  system_prompt="You are a helpful assistant called Llama-3. Write out your reasoning step-by-step to be sure you get the right answers!",
         max_length=1024)

### Starling

In [None]:
# Starling-LM 7B alpha: a 4-bit quantization config and torch_dtype=bfloat16
# are handed to the wrapper through model_kwargs.
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model_kwargs = dict(
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
)

starling = LLMObj(
    model="berkeley-nest/Starling-LM-7B-alpha",
    model_kwargs=model_kwargs,
)

### Phi-3

In [None]:
# Phi-3 mini 128k instruct checkpoint: same 4-bit / bfloat16 settings as the
# other models; the tokenizer id is passed explicitly alongside the model id.
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model_kwargs = dict(
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
)

phi_3 = LLMObj(
    model="microsoft/Phi-3-mini-128k-instruct",
    model_kwargs=model_kwargs,
    tokenizer="microsoft/Phi-3-mini-128k-instruct",
)

## Dataloader and analogy template sentences

### Analogy sentences

This is the template used for constructing the analogy sentence. Here we have the analogy sentence with the missing word that needs to be inferred, along with the full analogy sentence that is provided as an example.

In [None]:
# Complete analogy sentence used for few-shot examples; ScanDataLoader fills
# the four slots with (target, source, targ_word, src_word) of an example.
ANALOGY_TEMPLATE_SIMPLE_FULL = "If {} is like {}, then {} is like {}."
# Open-ended analogy sentence the model has to complete; ScanDataLoader fills
# the three slots with the first three columns of a dataset row.
ANALOGY_TEMPLATE_SIMPLE_INFERENCE = "If {} is like {}, then {} is like..."

### Download the datasets
- Downloading SCAN dataset and examples

In [None]:
SCAN_DATASET_FILEPATH = "scan.csv"
SCAN_EXAMPLES_FILEPATH = "scan_examples.txt"

# raw.githubusercontent.com serves files at /<owner>/<repo>/<ref>/<path>.
# The "/blob/" segment only exists in github.com page URLs and must not be
# part of a raw URL.
SCAN_RAW_BASE_URL = "https://raw.githubusercontent.com/prundeanualin/ATCS-project/main/data/SCAN"


def _download_text_file(url, filepath, description, timeout=30):
    """Download a text file over HTTP and store it locally as UTF-8.

    Args:
        url: Address of the file to fetch.
        filepath: Local path the response body is written to.
        description: Human-readable name used in the status/error messages.
        timeout: Seconds to wait for the server before giving up.

    Raises:
        RuntimeError: If the server answers with a status code other than 200.
    """
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        raise RuntimeError(
            f"Failed to download {description} file. Status code: {response.status_code}"
        )

    # The files are read back with encoding='utf8', so write them the same way
    # instead of relying on the platform default encoding.
    with open(filepath, "w", encoding="utf8") as out_file:
        out_file.write(response.text)

    print(f"{description} file downloaded successfully.")


url_scan_dataset = f"{SCAN_RAW_BASE_URL}/SCAN_dataset.csv"
_download_text_file(url_scan_dataset, SCAN_DATASET_FILEPATH, "SCAN dataset")

url_scan_examples = f"{SCAN_RAW_BASE_URL}/SCAN_examples.txt"
_download_text_file(url_scan_examples, SCAN_EXAMPLES_FILEPATH, "SCAN examples")

### Load the dataset

In [None]:
def get_list_alternatives(alternatives):
    """Turn a ', '-separated alternatives string into a list of strings.

    The literal string 'nan' stands for "no alternatives" and yields an
    empty list; any other string is split on ', '.
    """
    return [] if alternatives == 'nan' else alternatives.split(', ')


class ScanDataLoader():
    """Index-based loader over the SCAN analogy dataset with few-shot examples.

    On construction the dataset CSV (``SCAN_DATASET_FILEPATH``) and the
    examples file (``SCAN_EXAMPLES_FILEPATH``) are read from disk. The rows
    that correspond to the currently selected few-shot examples are hidden
    from indexing and iteration, so a model is never queried on an analogy it
    was shown as an example.

    Iteration works through the sequence protocol: ``__getitem__`` raises
    ``IndexError`` once the visible rows are exhausted.
    """

    def __init__(self,
                 shot_nr=1,
                 examples_start_idx=0,
                 analogy_sentence_infer=ANALOGY_TEMPLATE_SIMPLE_INFERENCE,
                 analogy_sentence_full=ANALOGY_TEMPLATE_SIMPLE_FULL):
        """
        - shot_nr: number of examples to select per analogy type.
        - examples_start_idx: needs to be within the range of available examples per analogy type.
        - analogy_sentence_infer: template with three slots used for the sentence to complete.
        - analogy_sentence_full: template with four slots used for the example sentences.
        """
        self.shot_nr = shot_nr
        self.examples_start_idx = examples_start_idx
        self.analogy_sentence_inference = analogy_sentence_infer
        self.analogy_sentence_example = analogy_sentence_full

        self.df = self._load_dataset()
        self.examples = self._load_examples()

        self.current_examples = None
        self.df_remapped_indices = None
        self.set_current_examples_and_exclude_from_dataset(self.examples_start_idx, self.shot_nr)

    @staticmethod
    def _key_mask(df, record):
        """Boolean mask of the rows in df sharing source/target/targ_word with record."""
        return ((df['source'] == record['source'])
                & (df['target'] == record['target'])
                & (df['targ_word'] == record['targ_word']))

    def _load_dataset(self):
        """Read the dataset CSV, parse alternatives and merge duplicated analogies."""
        with open(SCAN_DATASET_FILEPATH, 'r', encoding='utf8') as f:
            df = pd.read_csv(f, sep=',', index_col=False)

        # Transform alternatives into list of strings. If none is provided, then use an empty list
        df['alternatives'] = df['alternatives'].astype(str).map(get_list_alternatives)

        # Some rows share the same source/target/targ_word but differ in src_word.
        # Only the last such row is kept; the src_word and alternatives of the
        # dropped rows are folded into the kept row's alternatives.
        dropped = df[df.duplicated(subset=['source', 'target', 'targ_word'], keep='last')]
        # Reset the index so that index labels and positions coincide again;
        # the rest of the class addresses rows by position (iloc).
        df = df.drop(dropped.index, axis=0).reset_index(drop=True)

        for _, row in dropped.iterrows():
            kept = df.index[self._key_mask(df, row)][0]
            df.at[kept, 'alternatives'] = (
                df.at[kept, 'alternatives'] + row['alternatives'] + [row['src_word']]
            )
        return df

    def _load_examples(self):
        """Read the examples file into a dict of example lists keyed by analogy type."""
        examples = {
            'science': [],
            'metaphor': []
        }

        with open(SCAN_EXAMPLES_FILEPATH, 'r', encoding='utf8') as f:
            lines = [line.rstrip() for line in f]

        # Every example occupies three lines: a comma-separated record, the
        # detailed chain-of-thought text, and a third line that is ignored.
        nr_examples = round(len(lines) / 3)
        for idx in range(nr_examples):
            record = lines[idx * 3: (idx * 3) + 2]
            target, source, targ_word, src_word, alternatives, analogy_type = tuple(record[0].split(','))
            examples[analogy_type].append({
                'target': target,
                'source': source,
                'targ_word': targ_word,
                'src_word': src_word,
                'detailed_cot': record[1],
                'simple': self.analogy_sentence_example.format(target, source, targ_word, src_word),
                'analogy_type': analogy_type
            })
        return examples

    def set_current_examples_and_exclude_from_dataset(self, examples_start_idx, nr_examples):
        """Select the few-shot examples and hide their rows from the dataset.

        For every analogy type, ``nr_examples`` examples starting at
        ``examples_start_idx`` are selected. Dataset rows matching a selected
        example (same source, target and targ_word) are left out of
        ``df_remapped_indices`` and are therefore never returned by
        ``__getitem__``. Examples without a matching dataset row are still
        used as examples; there is simply nothing to hide for them.
        """
        self.examples_start_idx = examples_start_idx
        self.shot_nr = nr_examples

        # Get the nr_examples of examples starting from the start_idx, for each analogy type
        stop_idx = examples_start_idx + nr_examples
        self.current_examples = []
        for examples_of_type in self.examples.values():
            self.current_examples.extend(examples_of_type[examples_start_idx:stop_idx])

        # Collect all positions to hide first and filter once, so that removing
        # one row cannot shift the position of another row still to be removed.
        excluded_positions = set()
        for ex in self.current_examples:
            matches = self.df.index[self._key_mask(self.df, ex)]
            if len(matches) > 0:
                excluded_positions.add(int(matches[0]))

        self.df_remapped_indices = [
            position for position in range(len(self.df))
            if position not in excluded_positions
        ]

    def __len__(self):
        """Number of rows reachable through indexing (example rows excluded)."""
        return len(self.df_remapped_indices)

    def __getitem__(self, idx):
        """Return the sample at visible position idx.

        The result is a dict with the formatted 'inference' sentence (built
        from the row's first three columns), the current 'examples', the
        'label' (fourth column), the 'alternatives' list (fifth column) and
        the 'analogy_type' (sixth column).

        Raises:
            IndexError: If idx is outside the range of visible rows.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        idx = self.df_remapped_indices[idx]
        analogy_sent = self.df.iloc[idx, :3].values.tolist()
        label = self.df.iloc[idx, 3]
        alternatives = self.df.iloc[idx, 4]
        analogy_type = self.df.iloc[idx, 5]

        return {
            'inference': self.analogy_sentence_inference.format(*analogy_sent),
            'examples': self.current_examples,
            'label': label,
            'alternatives': alternatives,
            'analogy_type': analogy_type
        }

In [None]:
# class SCANDataloader:
#   def __init__(self):
#     url = "https://raw.githubusercontent.com/taczin/SCAN_analogies/main/data/SCAN_dataset.csv"
#     response = requests.get(url)
#     if response.status_code == 200:
#         csv_content = response.text
#         with open("scan.csv", "w") as csv_file:
#             csv_file.write(csv_content)

#         print("CSV file downloaded successfully.")
#     else:
#         print("Failed to download CSV file. Status code:", response.status_code)

#     self.data = pd.read_csv('scan.csv')
#     self.data.fillna("NaN", inplace=True)
#     self.shuffled_idx = np.arange(len(self.data))
#     # np.random.shuffle(self.shuffled_idx)

#   def __getitem__(self, idx):
#     row = self.data.iloc[self.shuffled_idx[idx]]
#     prompt = f"If {row['source']} is like {row['target']}, then {row['src_word']} is like..."
#     return {"row": row.to_dict(), "prompt": prompt}

#   def __call__(self):
#     for idx in self.shuffled_idx:
#       row = self.data.iloc[self.shuffled_idx[idx]]
#       prompt = f"If {row['source']} is like {row['target']}, then {row['src_word']} is like..."
#       yield {"row": row.to_dict(), "prompt": prompt}


## Inference Pipeline

In [None]:
# quantization_config = BitsAndBytesConfig(load_in_4bit=True)
# model = "meta-llama/Meta-Llama-3-8B-Instruct",
# model_kwargs = {"torch_dtype": torch.bfloat16,
#             "quantization_config": quantization_config,
#             "low_cpu_mem_usage": True}

# Llama = LLMObj(model="meta-llama/Meta-Llama-3-8B-Instruct", model_kwargs=model_kwargs)

from tqdm import tqdm

SCAN_loader = ScanDataLoader()
# ds = Dataset.from_generator(SCAN_loader)
generated_prompts = []

# tqdm wraps the loader itself (not the enumerate object) so it can read the
# loader's length and display a total next to the running count.
for i, sample in enumerate(tqdm(SCAN_loader)):
    output = phi_3.generate(sample['inference'])
    generated_prompts.append([sample, output])
    # if i > 5:
    #   break

In [None]:
def process_line(item, llm):
    """Query llm with the analogy prompt built from item and attach the answer.

    item must provide 'source', 'target' and 'src_word'; llm must expose
    generate(prompt). Returns a new dict holding every entry of item plus the
    model output under 'response'; item itself is not modified.
    """
    answer = llm.generate(
        f"If {item['source']} is like {item['target']}, then {item['src_word']} is like..."
    )
    enriched = dict(item)
    enriched['response'] = answer
    return enriched

import csv

# Model used to answer each analogy prompt; any object exposing
# generate(prompt) can be assigned here instead.
llm = phi_3

# newline='' lets the csv module control line endings itself (otherwise extra
# blank lines can appear on some platforms); utf8 matches how the dataset is read.
with open('scan_analogy_responses.csv', 'w', newline='', encoding='utf8') as f:
    headers = ['target', 'source', 'targ_word', 'src_word', 'alternatives', 'analogy_type', 'response']
    writer = csv.DictWriter(f, fieldnames=headers)
    writer.writeheader()

    # Read the dataset one row at a time and write each result immediately,
    # so finished rows are already in the output file while later ones run.
    for chunk in pd.read_csv(SCAN_DATASET_FILEPATH, chunksize=1):
        for _, row in chunk.iterrows():
            result = process_line(row, llm)
            writer.writerow(result)

print("Data with generated responses has been saved")