In [1]:
import os, sys
sys.path.append('../')

import outlines
import outlines.models as models
import outlines.text as text

import torch
import transformers

from pydantic import BaseModel, Field, constr, conlist
from enum import Enum

from utils.summarize_utils import ConstrainedResponseHST, prompt_fn

%load_ext autoreload
%autoreload 2

In [3]:
from transformers import BitsAndBytesConfig

model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"

config = transformers.AutoConfig.from_pretrained(
    model_name, trust_remote_code=True, asd=True,
)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model = models.transformers(
    
    model_name=model_name,
    model_kwargs={
        "config": config,
        "quantization_config": bnb_config,
        "trust_remote_code": True,
        "device_map": "auto",
        "load_in_4bit": True,
        "cache_dir": "/n/holystore01/LABS/iaifi_lab/Users/smsharma/hf_cache/"
    },
)

Loading checkpoint shards:   0%|          | 0/19 [00:00<?, ?it/s]

In [176]:
@outlines.prompt
def prompt_fn(sum):
    """<s>[INST] Please produce a list of 100 keywords characterizing characterizing prominent objects, phenomena, and science use cases of images taken by the Hubble Space Telescope.

{{sum}}

Above are some representative objects.

Please follow the following instructions exactly:
- Do not under any circumstances output empty strings as elements
- Answer in clear English
- Make sure that the list covers a diverse range of astronomical concepts, with elements as different from each other as possible. 
- Do not give very specific names of objects, to make sure you span the widest possible range of concepts (e.g., "dwarf galaxy"is fine, but not "Fornax")
- Do not give undescriptive terms, e.g. "sloshing", "adiabatic", "interactions"
- Start with the following: ["dark matter", "globular cluster", "supernova remnant", "starburst galaxy"]
- Answer in JSON format.

[/INST]
"""

In [177]:
# with open("../data/UAT.csv") as file:
#     uat = file.read()

In [178]:
from pydantic import BaseModel, Field, constr, conlist, validator, field_validator
from enum import Enum
import re

# Galaxies were allocated 25% of the successful
# proposals, followed by Stellar Physics (23%), Exoplanets and Planet Formation (12%), Stellar Populations
# (15%), Supermassive Black Holes (9%), Solar System (8%), and Intergalactic Medium and Large-Scale
# Structure (4% each).

class ScienceCategoriesHST(BaseModel):
    """ https://hubblesite.org/files/live/sites/hubble/files/home/_documents/hubble-cycle-31-observations-begins
    """
    galaxies: conlist(str, min_length=25, max_length=25)
    stellar_physics: conlist(str, min_length=23, max_length=23)
    exoplanets_planet_formation: conlist(str, min_length=12, max_length=12)
    stellar_populations: conlist(str, min_length=15, max_length=15)
    supermassive_black_holes: conlist(str, min_length=9, max_length=9)
    solar_system: conlist(str, min_length=8, max_length=8)
    integalactic_medium: conlist(str, min_length=4, max_length=4)
    large_scale_structure: conlist(str, min_length=4, max_length=4)

    @field_validator('*')
    def validate_values(cls, v):
        allowed_chars = [re.sub(r'[^\w\s]+', '', x).strip() for x in v]
        return allowed_chars

In [179]:
# sys.path.append('../')
# from utils.abstract_utils import read_abstracts_file

# filename = "../data/abstracts.cat"

# abstracts_df = read_abstracts_file(filename)

# # Drop rows with missing Cycle
# abstracts_df = abstracts_df.dropna(subset=['Cycle'])
# abstracts_df = abstracts_df[abstracts_df['Cycle'] != '']

# # Convert Cycle and ID to int
# abstracts_df['Cycle'] = abstracts_df['Cycle'].astype(int)
# abstracts_df['ID'] = abstracts_df['ID'].astype(int)

# # Only keep specific Cycles

# cycle_min = 27
# cycle_max = 27

# abstracts_cycle_df = abstracts_df[(abstracts_df['Cycle'] >= cycle_min) & (abstracts_df['Cycle'] <= cycle_max)]

# '\n'.join(abstracts_cycle_df['Abstract'].values[-5:])

In [180]:
import pandas as pd
summaries_filename = "../data/summary_v2.csv"
summaries_df = pd.read_csv(summaries_filename)

In [181]:
prompt = prompt_fn('\n'.join(summaries_df['objects_phenomena'].values[:100]))
generator = outlines.generate.json(model, ScienceCategoriesHST)
sequence = generator(prompt)
sequence

ScienceCategoriesHST(galaxies=['dark matter', 'spiral density waves', 'dwarf galaxies', 'irregular galaxies', 'interacting galaxies', 'highredshift galaxies', 'protocluster', 'galaxy groups', 'elliptical galaxies', 'starforming galaxies', 'poststarburst galaxies', 'spheroidal galaxies', 'clustersized overdensity', 'primeval galaxies', 'disc galaxy', 'edgeon spirals', 'lowmass galaxies', 'companions', 'gravitational lensing', 'gravitationally lensed Lymanalpha Emitters LAEs', 'strongly lensed Lymanalpha Emitters LAEs', 'lensed galaxy', 'star forming galaxy', 'luminous galaxy', 'ultraluminous IRAS galaxy'], stellar_physics=['globular cluster', 'neutron stars', 'pulsars', 'radiopulsars', 'PSR J01081431', 'Xray clusters', 'cluster galaxies', 'radial arcs', 'tangential arcs', 'optical jet candidates', 'hot stars', 'blue stars', 'young globular clusters', 'hot white dwarfs', 'cataclysmic variables', 'mainsequence stars', 'lowmass stars', 'white dwarfs', 'brown dwarfs', 'ultracool dwarfs', 'Y