In [None]:
import pandas as pd

# Reload modules automatically

%load_ext autoreload
%autoreload 2

In [None]:
def read_abstracts_file(filename):
    abstracts = []
    abstract = {}
    inside_abstract = False  # Flag to check if we're inside an abstract
    
    with open(filename, 'r') as file:
        for line in file:
            line = line.strip()
            
            if line.startswith('-----'):
                if abstract:  # If abstract has content, append to list
                    abstracts.append(abstract)
                    abstract = {}
                    inside_abstract = False  # Reset the flag
            else:
                # Check for known properties
                property_starts = ['Prop. Type:', 'Category:', 'ID:', 'Cycle:', 'Title:', 'PI:']
                
                if any(line.startswith(prop) for prop in property_starts) and not inside_abstract:
                    if 'Prop. Type' in abstract and line.startswith('Prop. Type:'):
                        # If a new abstract starts without delimiter, assume previous one ended
                        abstracts.append(abstract)
                        abstract = {}
                    
                    if line.startswith('Prop. Type:'):
                        abstract['Prop. Type'] = line.split(':', 1)[1].strip()
                    elif line.startswith('Category:'):
                        abstract['Category'] = line.split(':', 1)[1].strip()
                    elif line.startswith('ID:'):
                        id_val = line.split(':', 1)[1].strip()
                        try:
                            abstract['ID'] = int(id_val)
                        except ValueError:
                            abstract['ID'] = id_val
                    elif line.startswith('Cycle:'):
                        cycle_val = line.split(':', 1)[1].strip()
                        try:
                            abstract['Cycle'] = int(cycle_val)
                        except ValueError:
                            abstract['Cycle'] = cycle_val
                    elif line.startswith('Title:'):
                        abstract['Title'] = line.split(':', 1)[1].strip()
                    elif line.startswith('PI:'):
                        abstract['PI'] = line.split(':', 1)[1].strip()
                else:
                    # If none of the known properties are found, we treat the line as part of the abstract
                    abstract['Abstract'] = abstract.get('Abstract', '') + ' ' + line
                    inside_abstract = True  # Set the flag indicating we're inside an abstract
                
    # After loop ends, check if there's any remaining content in the abstract dictionary
    if abstract:
        abstracts.append(abstract)

    df = pd.DataFrame(abstracts)
    
    return df                

filename = "../data/abstracts.cat"
abstracts_df = read_abstracts_file(filename)

In [None]:
abstracts_df[abstracts_df['ID'] == 13200]

In [None]:
abstracts_df = abstracts_df.dropna(subset=['Cycle'])
abstracts_df = abstracts_df[abstracts_df['Cycle'] != '']

abstracts_df['Cycle'] = abstracts_df['Cycle'].astype(int)
abstracts_df['ID'] = abstracts_df['ID'].astype(int)
abstracts_cycle_df = abstracts_df[(abstracts_df['Cycle'] >= 25) & (abstracts_df['Cycle'] <= 31)]

In [None]:
abstracts_cycle_df['Cycle'].value_counts()

In [None]:
abstracts_cycle_df

In [None]:
abstract_ids = abstracts_cycle_df['ID'].values

In [None]:
import sys
sys.path.append("../")

from tqdm import tqdm
from scripts.download_data import download_data

proposal_id = 15922
n_max_images = 10
max_resolution = 512
seed = 42


for proposal_id in tqdm(abstract_ids):
    download_data(proposal_id, n_max_images, max_resolution, seed, data_dir='../data/observations/')

In [None]:
import os

def remove_large_files(directory, size_limit=2*1024*1024):  # default size_limit is set to 2MB
    for foldername, subfolders, filenames in os.walk(directory):
        for filename in filenames:
            filepath = os.path.join(foldername, filename)
            if os.path.getsize(filepath) > size_limit:
                try:
                    os.remove(filepath)
                    print(f"Removed {filepath}")
                except Exception as e:
                    print(f"Error removing {filepath}: {e}")

directory_path = '../data/observations/'
remove_large_files(directory_path)

In [None]:
proposal_id = 16075

abstract = abstracts_cycle_df[abstracts_cycle_df['ID'] == proposal_id]['Abstract'].values[0]
title = abstracts_cycle_df[abstracts_cycle_df['ID'] == proposal_id]['Title'].values[0]
category = abstracts_cycle_df[abstracts_cycle_df['ID'] == proposal_id]['Category'].values[0]

# Check if category is 'None' or empty; if so, set to 'None'
if not category or category == 'None':
    category = 'None'

abstract

In [None]:
import openai
from tenacity import retry, stop_after_attempt, wait_random_exponential


In [None]:
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def completion_with_backoff(**kwargs):
    """Chat completion with exponential backoff to prevent rate limiting."""
    return openai.ChatCompletion.create(**kwargs)

def answer_question(prompt, system_prompt, api_key=None, model="gpt-3.5-turbo", max_tokens=300, temperature=0.1):
    """Return answer to query given context chunk."""

    openai.api_key = api_key

    response = completion_with_backoff(
        model=model,
        messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
        max_tokens=max_tokens,
        n=1,
        temperature=temperature,
    )
    return response["choices"][0]["message"]["content"]

In [None]:
api_key = 'sk-xxx'

system_prompt = "You are an expert astronomer. You can answer questions about what a telescope is likely to observe based on a given proposal. You never mention the proposal itself or its specifics. You are very concise and to the point. The answers of your questions will be used to instruction-tune a language model to answer questions about observations; respond accordingly and be general; mention generic classes rather than specific names or designations, never mention 'proposed observations'."

# prompt = f"Consider the following abstract from category {category}: {abstract}. The title is: {title}. Based on this abstract, what objects is the telescope likely to observe? Only answer the question."
prompt = f"Consider the following abstract from category {category}: {abstract}. The title is: {title}. Based on this abstract, what kind of science could be done with the observation? Only answer the question."

answer_question(prompt, system_prompt=system_prompt, api_key=api_key)

In [None]:
data_folder = "../data/observations/"

In [None]:
import sys
sys.path.append("../")

from data.utils import make_dataloader, create_input_iter, get_abstracts_and_images

abstracts, images, masks = get_abstracts_and_images(data_folder, abstracts_cycle_df)
train_ds = make_dataloader(abstracts, masks, images, batch_size=32, seed=42)
batches = create_input_iter(train_ds)

In [None]:
import matplotlib.pyplot as plt
plt.imshow(next(batches)[2][0][14])

In [None]:
import os
import pandas as pd
from PIL import Image
import numpy as np

def get_abstracts_and_images(data_folder, abstracts_cycle_df):
    # Lists to store results
    images_list = []
    abstracts_list = []

    # Walk through data folder
    for root, dirs, files in os.walk(data_folder):
        for file in files:
            if file.endswith(".jpg"):
                image_path = os.path.join(root, file)
                proposal_id = root.split("_")[-1]  # Extract proposal id from the directory name

                # Extract abstract using the dataframe
                abstract = abstracts_cycle_df[abstracts_cycle_df["ID"] == int(proposal_id)]["Abstract"].values[0]

                image = Image.open(image_path).convert("RGB")
                image = np.array(image)

                # Pad image to square
                h, w, c = image.shape
                max_dim = max(h, w)
                padded_image = np.ones((max_dim, max_dim, c), dtype=np.uint8) * 255

                # Calculate top and left padding
                y_offset = (max_dim - h) // 2
                x_offset = (max_dim - w) // 2

                padded_image[y_offset : y_offset + h, x_offset : x_offset + w, :] = image

                images_list.append(padded_image)
                abstracts_list.append(abstract)

    return np.array(abstracts_list), np.array(images_list)

data_folder = "../data/observations/"
abstracts, images = get_abstracts_and_images(data_folder, abstracts_cycle_df)

In [189]:
import jax
import tensorflow as tf

def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def serialize_example(string, image):
    """
    Creates a tf.train.Example message from a string and a 2D numpy array.
    """
    feature = {
        'abstract': _bytes_feature(tf.io.encode_base64(string.encode('utf-8'))),
        'image': _bytes_feature(tf.io.encode_base64(image.tobytes()))
    }
    
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()

with tf.io.TFRecordWriter('data.tfrecord') as writer:
    for abstract, image in zip(abstracts, images):
        example = serialize_example(abstract, image)
        writer.write(example)

In [191]:
from data.data import Dataset

dataset = Dataset('./', './', train_batch_size=4)
batches = next(dataset.train)

TypeError: in user code:

    File "/Users/smsharma/Projects/multimodal-data/notebooks/../data/data.py", line 72, in _parse_no_filter  *
        return _parse_image(*_parse_function(example_proto))
    File "/Users/smsharma/Projects/multimodal-data/notebooks/../data/data.py", line 66, in _parse_image  *
        image = tf.io.decode_raw(image_bytes)

    TypeError: Missing required positional argument


In [194]:

# TFRecord files
files = ['data.tfrecord']

# Parse function 
def parse_function(example_proto):

  features = {
    'image': tf.io.FixedLenFeature([], tf.string),
    'abstract': tf.io.FixedLenFeature([], tf.string)
  }

  parsed_features = tf.io.parse_single_example(example_proto, features)

  image = parsed_features['image']
  caption = parsed_features['abstract']

  # Decode raw image bytes
  image = tf.io.decode_raw(image, tf.uint8)
  image = tf.reshape(image, [128, 128, 3]) # reshape to match image shape
  
  return image, caption

# Create dataset and parse examples
dataset = tf.data.TFRecordDataset(files)
dataset = dataset.map(parse_function)

# Print example
image, caption = next(iter(dataset))
print(image.shape) # (128, 128, 3)


InvalidArgumentError: {{function_node __wrapped__IteratorGetNext_output_types_2_device_/job:localhost/replica:0/task:0/device:CPU:0}} Input to reshape is a tensor with 1048576 values, but the requested shape has 49152
	 [[{{node Reshape}}]] [Op:IteratorGetNext] name: 

In [None]:
batches

(array([b'______________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________

In [None]:


# batch_size = 32

# train_ds = tf.data.TFRecordDataset('data.tfrecord')

# def parse_tfrecord_fn(example):
#     feature_description = {
#         'string': tf.io.FixedLenFeature([], tf.string),
#         'image_data': tf.io.FixedLenFeature([], tf.string),
#     }
#     example = tf.io.parse_single_example(example, feature_description)
    
#     # Decode the image data from base64
#     image_data = tf.io.decode_base64(example['image_data'])
#     image = tf.io.decode_raw(image_data, tf.float32)  # assuming you stored float32 images; adjust if needed
#     # reshape the image here if you know its size
#     # for example: image = tf.reshape(image, (HEIGHT, WIDTH))
    
#     return example['string']._numpy(), image

# train_ds = train_ds.map(parse_tfrecord_fn) 
# train_ds = train_ds.shuffle(buffer_size=10000) 
# train_ds = train_ds.batch(32)

# train_ds = train_ds.cache()
# train_ds = train_ds.repeat()

# batch_dims = [jax.local_device_count(), batch_size // jax.device_count()]

# for _batch_size in reversed(batch_dims):
#     train_ds = train_ds.batch(_batch_size, drop_remainder=False)

# import flax
# def create_input_iter(ds):
#     """Create an input iterator that prefetches to device."""

#     def _prepare(xs):
#         def _f(x):
#             x = x._numpy()
#             return x

#         return jax.tree_util.tree_map(_f, xs)

#     it = map(_prepare, ds)
#     it = flax.jax_utils.prefetch_to_device(it, 2)
#     return it

# batches = create_input_iter(train_ds)