In [1]:
import os, sys
sys.path.append('../')

import outlines
import outlines.models as models
import outlines.text as text

import torch
import transformers

from pydantic import BaseModel, Field, constr, conlist
from enum import Enum

from utils.summarize_utils import ConstrainedResponseHST, prompt_fn

%load_ext autoreload
%autoreload 2

In [18]:
import os

from transformers import BitsAndBytesConfig

# Load an instruction-tuned Mistral-family model through outlines so that later cells can
# run constrained (JSON / choice) generation against it.

# model_name = "mistralai/Mistral-7B-Instruct-v0.2"
model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"

# Hugging Face download cache. Set HF_CACHE_DIR to relocate it instead of editing
# the cluster-specific default path below.
cache_dir = os.environ.get(
    "HF_CACHE_DIR", "/n/holystore01/LABS/iaifi_lab/Users/smsharma/hf_cache/"
)

config = transformers.AutoConfig.from_pretrained(
    model_name, trust_remote_code=True, cache_dir=cache_dir,
)

# 4-bit NF4 weights with double quantization; compute is done in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model = models.transformers(
    model_name=model_name,
    device="cuda:0",
    model_kwargs={
        "config": config,
        # All 4-bit settings live in `quantization_config`. A separate `load_in_4bit=True`
        # kwarg is redundant, and some transformers versions raise when both are given.
        "quantization_config": bnb_config,
        "trust_remote_code": True,
        # NOTE(review): depending on the outlines version, `device="cuda:0"` above may be
        # forwarded as `device_map` and override this "auto" -- confirm the intended placement.
        "device_map": "auto",
        "cache_dir": cache_dir,
    },
)

Loading checkpoint shards:   0%|          | 0/19 [00:00<?, ?it/s]

In [47]:
# Prompt template, rendered by outlines from the docstring below. It asks the model to
# propose broad HST observation concepts, grouped into fixed categories.
# Args:
#     sum: newline-joined example object/phenomenon summaries, substituted for {{sum}}.
#          (The name shadows the builtin `sum`; it is kept because it is the template variable.)
# The category keys must match the ScienceCategoriesHST field names exactly (including the
# "integalactic_medium" spelling). The per-category item count must match that schema's
# fixed list length of 15.
@outlines.prompt
def prompt_fn(sum):
    """[INST] Please produce a list of concepts characterizing prominent objects, phenomena, and science use cases of images observed by the Hubble Space Telescope.

Here are some examples of objects:

{{sum}}

Follow these instructions exactly in your answer:
- Do not under any circumstances output empty strings as elements.
- Make sure that the list covers a diverse range of astronomical concepts, with items as different from each other as possible.
- Do not give specific names of objects, to make sure you span the widest possible range of concepts (e.g., "dwarf galaxy" is allowed, but NOT "Fornax", "Terzan 5", or "NGC6440").
- Do not return terms undescriptive of observations, e.g. "sloshing", "adiabatic", "interactions".
- Only output scientifically meaningful terms. E.g., NO "Cosmic Dance".
- Do not duplicate entries. Do not reference any telescopes, observatories, or surveys.
- Do not include units, derived quantities like "angular diameter distance", or any other concepts that will not correlate with images of observations.
- Use the above example list of objects only as inspiration to infer broad classes of objects.
- Answer in JSON format.
- The JSON should have the following keys {"galaxies", "stellar_physics", "exoplanets_planet_formation", "stellar_populations", "supermassive_black_holes", "solar_system", "integalactic_medium", "large_scale_structure"}, reflecting rough observation categories.
- Each category will have a list of objects and/or astronomical concepts.
- Output exactly 15 items in each category.
[/INST]
"""

In [48]:
# with open("../data/UAT.csv") as file:
#     uat = file.read()

In [49]:
from pydantic import BaseModel, Field, constr, conlist, validator, field_validator
from enum import Enum
import re

# Every category is the same shape: a list of exactly 15 concept strings.
_HSTConceptList = conlist(str, min_length=15, max_length=15)  # type: ignore


# JSON schema for constrained generation of HST science concepts: one fixed-length list per
# observation category. The docstring (which pydantic also uses as the schema description)
# points at the Cycle 31 category breakdown.
class ScienceCategoriesHST(BaseModel):
    """ https://hubblesite.org/files/live/sites/hubble/files/home/_documents/hubble-cycle-31-observations-begins
    """
    galaxies: _HSTConceptList  # type: ignore
    stellar_physics: _HSTConceptList  # type: ignore
    exoplanets_planet_formation: _HSTConceptList  # type: ignore
    stellar_populations: _HSTConceptList  # type: ignore
    supermassive_black_holes: _HSTConceptList  # type: ignore
    solar_system: _HSTConceptList  # type: ignore
    integalactic_medium: _HSTConceptList  # type: ignore  (spelling must match the prompt's key list)
    large_scale_structure: _HSTConceptList  # type: ignore

    @field_validator('*')
    @classmethod
    def validate_values(cls, v):
        """Clean every item of every category list.

        Removes each run of characters that is not a word character, whitespace or hyphen,
        then trims surrounding whitespace. Items may end up as empty strings.
        """
        return [re.sub(r'[^\w\s-]+', '', entry).strip() for entry in v]

In [50]:
import pandas as pd

# Per-image summaries. The concept-generation loop reads the 'objects_phenomena' column.
summaries_filename = "../data/summary_v2.csv"
summaries_df = pd.read_csv(filepath_or_buffer=summaries_filename)

In [51]:
from outlines.generate import json

In [52]:
# prompts = [prompt_fn('\n'.join(summaries_df['objects_phenomena'].values[i_try * n_examples:(i_try + 1) * n_examples])) for i_try in range(4)]
# sequence = generator(prompts)

In [53]:
from tqdm import tqdm

# Ask the model for categorized concepts. Each try sees a different window of summaries.
n_examples = 100  # summary rows shown to the model per prompt
n_tries = 1       # number of prompts / generations (window i covers rows [i*n, (i+1)*n))

sum1 = []  # flat list of every generated concept, in category (field) order
generator = json(model, ScienceCategoriesHST)
for i_try in tqdm(range(n_tries)):
    examples = summaries_df['objects_phenomena'].values[i_try * n_examples:(i_try + 1) * n_examples]
    if len(examples) == 0:
        # Ran past the end of the summaries; an empty example list would make a degenerate prompt.
        break
    # NOTE: relies on the one-argument `prompt_fn`; a later cell redefines prompt_fn with two
    # parameters, so re-running this cell after that one raises a TypeError.
    prompt = prompt_fn('\n'.join(examples))
    sequence = generator(prompt)
    # Serialize once and append every category's items in field order.
    for category_items in sequence.model_dump().values():
        sum1.extend(category_items)

100%|██████████████████████████████████████████████████████████| 1/1 [01:21<00:00, 81.35s/it]


In [54]:
import string

# ASCII punctuation treated as leading noise. Whitespace is tested separately with
# str.isspace, so Unicode spaces count as noise too.
special_chars = set(string.punctuation)


def _strip_leading_noise(text):
    """Return `text` without its leading run of punctuation/whitespace ('' if nothing else is left)."""
    for position, char in enumerate(text):
        if char not in special_chars and not char.isspace():
            return text[position:]
    return ''


cleaned_sum1 = [_strip_leading_noise(item) for item in sum1]

cleaned_sum1;

In [55]:
# Drop empty strings and case-insensitive duplicates, keeping the first spelling seen.
# A set of lowercased keys replaces the previous per-item rescan of the whole output list
# (which was O(n^2)); the result, including order, is unchanged.
cleaned2_sum1 = []
seen_lower = set()
for s in cleaned_sum1:
    if s == '':
        continue

    key = s.lower()
    if key not in seen_lower:
        seen_lower.add(key)
        cleaned2_sum1.append(s)

cleaned2_sum1

['spiral galaxies',
 'elliptical galaxies',
 'irregular galaxies',
 'dwarf galaxies',
 'starburst galaxies',
 'interacting galaxies',
 'high-redshift galaxies',
 'luminous galaxies',
 'ultra-diffuse galaxies',
 'protogalaxies',
 'cluster galaxies',
 'galaxy clusters',
 'filamentary structures',
 'cosmic web',
 'radio galaxies',
 'supernovae',
 'neutron stars',
 'pulsars',
 'magnetars',
 'stellar winds',
 'stellar atmospheres',
 'stellar evolution',
 'red giants',
 'white dwarfs',
 'brown dwarfs',
 'main-sequence stars',
 'degeneracy',
 'evolutionary sequences',
 'massive stars',
 'stellar remnants',
 'exoplanets',
 'planetary systems',
 'circumstellar disks',
 'protoplanetary disks',
 'planet formation',
 'planet migration',
 'planet-star interactions',
 'debris disks',
 'exomoons',
 'habitability',
 'transit method',
 'radial velocity method',
 'direct imaging',
 'gravitational microlensing',
 'exoplanetary atmospheres',
 'globular clusters',
 'open clusters',
 'stellar associations',

In [56]:
# model = models.awq("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ")

In [57]:
from utils.abstract_utils import read_abstracts_file

from tqdm.notebook import tqdm

filename = "../data/abstracts.cat"

# Load the proposal abstracts and keep only rows with a Cycle value (neither NaN nor the
# empty string). Then cast Cycle and ID to int. Written as one chain that builds a new frame.
abstracts_df = (
    read_abstracts_file(filename)
    .dropna(subset=['Cycle'])
    .loc[lambda frame: frame['Cycle'] != '']
    .astype({'Cycle': int, 'ID': int})
)

In [65]:
# Pick one abstract (79th from the end) to try the classifier on; -77 was presumably another
# candidate index. NOTE(review): `abs` shadows the Python builtin; it is kept because the
# classification cell below passes `abs` to prompt_fn.
abs = abstracts_df['Abstract'].iloc[-79]  # -77
abs

' New and fundamental constraints on the evolutionary state of high redshift clusters will be made by obtaining deep, multiband images {SDSS r, i, z} over the central 1.5 Mpc regions of seven distant clusters in the range 0.76 < z < 1.27. The ACS data will allow us to {1} definitively establish the morphological composition and star formation rates as functions of clustercentric radius, local density, x-ray luminosity {obtained from accompanying Chandra, and XMM data}, {2} explore the relationship between substructure, kinematics, and morphology, {3} strongly constrain the galaxy merger frequency and the origins of elliptical and S0 galaxies, {4} measure the mass distribution independently from the light {via gravitational lensing} enabling comparisons with kinematically derived masses, and {5} study the evolution of the structure of the brightest cluster members. The clusters selected for this program already have extensive spectroscopic observations and NIR imaging is either in hand 

In [66]:
# Prompt template, rendered by outlines from the docstring below. It asks the model to assign
# one HST proposal abstract to a single concept from a supplied list.
# Args:
#     abs:  proposal abstract text, substituted for {{abs}} (shadows the builtin `abs`).
#     cats: comma-separated candidate concepts, substituted for {{cats}}.
# NOTE: this redefines the one-argument `prompt_fn` used for concept generation; that cell
# fails if re-run after this one.
# No literal "<s>" prefix: the Hugging Face tokenizer normally prepends BOS itself, so "<s>"
# here would yield a doubled BOS. The concept-generation prompt likewise starts at "[INST]".
# TODO confirm against the tokenizer's add_special_tokens behaviour.
# "calibration or instrumention" is spelled exactly like the candidate label it must match.
@outlines.prompt
def prompt_fn(abs, cats):
    """[INST] The following is a successful proposal abstract for the Hubble Space Telescope: "{{abs}}"

The following is a list of categories (astronomical concepts) that this abstract could correspond to.

{{cats}}

Which of these listed concepts best describes this proposal, based on the objects and phenomena mentioned in the abstract?

- For example, "The locations of supernovae {SNe} in the local stellar and gaseous environment in galaxies, as measured in high spatial resolution WFPC2 and ACS images, contain important clues to their progenitor stars." should return "supernovae".
- If the abstract centers on calibration and/or instrumentation efforts, return "calibration or instrumention".

If no concept makes sense, return "None". [/INST]
"""

In [67]:
from outlines.generate import choice

In [68]:
# Classify the example abstract. The prompt lists every generated concept plus the calibration
# bucket; the constrained choice additionally permits "None".
calibration_label = "calibration or instrumention"
prompt = prompt_fn(abs, ', '.join([*cleaned2_sum1, calibration_label]))
choice(model, [*cleaned2_sum1, "None", calibration_label])(prompt)

'cluster galaxies'

In [None]:
# - For example, "The locations of supernovae {SNe} in the local stellar and gaseous environment in galaxies, as measured in high spatial resolution WFPC2 and ACS images, contain important clues to their progenitor stars." should return "supernova".