In [1]:
import os, sys
sys.path.append('../')

import outlines
import outlines.models as models
import outlines.text as text

import torch
import transformers

from pydantic import BaseModel, Field, constr, conlist
from enum import Enum

from utils.summarize_utils import ConstrainedResponseHST, prompt_fn

%load_ext autoreload
%autoreload 2

In [4]:
from transformers import BitsAndBytesConfig

model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"

config = transformers.AutoConfig.from_pretrained(
    model_name, trust_remote_code=True, asd=True,
)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model = models.transformers(
    
    model_name=model_name,
    model_kwargs={
        "config": config,
        "quantization_config": bnb_config,
        "trust_remote_code": True,
        "device_map": "auto",
        "load_in_4bit": True,
        "cache_dir": "/n/holystore01/LABS/iaifi_lab/Users/smsharma/hf_cache/"
    },
)

Loading checkpoint shards:   0%|          | 0/19 [00:00<?, ?it/s]

In [5]:
abstract = """
The observed optical depths to microlensing of stars in the Galactic bulge are
difficult to reconcile with our present understanding of Galactic dynamics.
The main source of uncertainty in those comparisons is now shifting from
microlensing measurements to the dynamical models of the Galactic bar. We
propose to constrain the Galactic bar models with proper motion observations
of Bulge stars that underwent microlensing by determining both the kinematic
identity of the microlensed sources and the importance of streaming motions.
The lensed stars are typically farther than randomly selected stars.
Therefore, our proper motion determinations for 36 targeted MACHO events will
provide valuable constraints on the dynamics of bulge stars as a function of
distance. The first epoch data for our proposed events is already available in
the HST archive so the project can be completed within a single HST cycle. The
exceptional spatial resolution of HST is essential for completion of the
project. Constraints on te total mass in the bulge will ultimately lead to
the determination of the amount of dark matter in inner Galaxy.
"""

In [70]:
@outlines.prompt
def prompt_fn(abstract, query):
     """[INST]
You are an expert astrophysicist, with broad expertise across observational and theoretical astrophysics.

Abstract: "{{abstract}}"
Query: "{{query}}"

The above is an abstract for a proposed observation taken by the Hubble Space Telescope (labeled "Abstract"), and an object or concept (labeled "Query").

Could the observations corresponding to the abstract contain the query? Be precise, and do not contain related concepts or objects. 

Your response should be either True or False. Only return True if the query is closely related to the abstract, and the downstream observation could be relevant to the query.
[/INST]
"""

In [71]:
# Open the file in read mode

import numpy as np
from tqdm.notebook import tqdm

type_list = ["base", "ft", "tfid"]

bool_list = np.zeros((3, 10, 10))

for idx_type, type in enumerate(type_list):

    for idx in tqdm(range(10)):

        with open(f'eval_quant/captions_{type}_{idx}.txt', 'r') as file:
            # Read the lines from the file
            captions = file.readlines()
        
        with open(f'eval_quant/queries.txt', 'r') as file:
            # Read the lines from the file
            queries = file.readlines()
        
        # Strip any trailing newline characters
        captions = [caption.strip() for caption in captions]
        
        # Print the loaded captions
        for idx_cap, caption in enumerate(captions):
            prompt = prompt_fn(caption, queries[idx])
            generator = outlines.generate.format(model, bool)
            sequence = generator(prompt)
            bool_list[idx_type, idx, idx_cap] = 1 if sequence == 'True' else 0

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

In [72]:
bool_list[0].sum() / 100, bool_list[1].sum() / 100, bool_list[2].sum() / 100

(0.4, 0.76, 0.82)

In [74]:
# prompt = prompt_fn(abstract, "bulge stars")
# generator = outlines.generate.format(model, bool)
# sequence = generator(prompt)
# sequence