### Preparing the data to train a topic extraction model

Fortunately, ATels already carry tags (called `subjects`). I will use all available ATel tags as labels for the telegram bodies and train a topic model on them.

In [146]:
import pandas as pd
import numpy as np
import json
import re
import os

In [147]:
atel_df = (
    pd.read_csv("../data/atel.csv", index_col=0)
    .reset_index()
    .rename(columns={'subjects': 'topics', 'index': 'telegram_index'})
    [['telegram_index', 'topics']]
)

In [148]:
len(atel_df)

16037

In [149]:
atel_df['telegram_index'] = atel_df['telegram_index'].apply(lambda x: str(x) + '_atel')

In [150]:
atel_df = atel_df.set_index('telegram_index')

In [151]:
clean_bodies = pd.read_csv("../word2vec/clean_bodies.csv", index_col=0)

In [152]:
atel_df = atel_df.join(clean_bodies).dropna(subset=['topics', 'body_clean'])

In [153]:
atel_df['topics'] = atel_df['topics'].apply(str.lower).apply(lambda x: re.sub(r'[^\w\s\-,]', ' ', x))
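To make the effect of the cleaning step concrete, here is a minimal sketch that applies the same lowercase-then-regex logic to single strings. The helper name `normalize_topics` and the sample inputs are made up for illustration; only the regular expression comes from the cell above.

```python
import re

def normalize_topics(raw):
    """Lowercase a topics string and replace every character that is not
    a word character, whitespace, hyphen or comma with a space."""
    return re.sub(r'[^\w\s\-,]', ' ', raw.lower())

# Made-up inputs, for illustration only
print(normalize_topics("Gamma Ray, Gamma-Ray Burst"))  # gamma ray, gamma-ray burst
print(repr(normalize_topics("A/B")))                   # 'a b'
```

Note that stripped characters are replaced by spaces rather than removed, so a label can end up with doubled or trailing spaces.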

In [154]:
atel_df.to_csv("atel_with_topics.csv", index=True)

In [155]:
atel_df  # 205 ATels were thrown out (had no topics data or missing cleaned bodies)

telegram_index,topics,body_clean
2_atel,"gamma ray, gamma-ray burst",the following message emailed evening walter l...
3_atel,"gamma ray, gamma-ray burst",in addendum atel 2 additional information forw...
4_atel,"optical, gamma ray, a comment, gamma-ray burst",the recent detection delayed gamma ray burst g...
5_atel,"optical, gamma-ray burst",the optical transient iauc 6788 grb 971214 iau...
6_atel,"optical, gamma-ray burst",grb980109 field observed ogle collaboration 1....
...,...,...
16033_atel,"gamma ray, gev, agn, blazar, quasar",the large area telescope lat one two instrumen...
16034_atel,"optical, supernovae",we report following classification spectroscop...
16035_atel,"gamma ray, gev, request for observations, agn...",the large area telescope lat one two instrumen...
16036_atel,"cataclysmic variable, nova, transient",our spectroscopic monitoring development pnv j...


### We will consider fine-tuning OpenAI's `ada` model. Prepare a JSONL file with the training annotations

In [88]:
PROMPT = """
{text}
\n\n###\n\n
"""

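# Write one JSON object per line. Each prompt ends with the "###" separator;
# each completion starts with a space and ends with the " ###" stop marker.
os.makedirs("training_data", exist_ok=True)
with open("training_data/annots.jsonl", 'w') as f:
    for _, r in atel_df.iterrows():
        prompt = PROMPT.format(text=r.body_clean)
        output = json.dumps({
            "prompt": prompt,
            "completion": f" {r.topics} ###"
        })
        f.write(f"{output}\n")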
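Before uploading, it may be worth checking that every line of the file follows the layout produced above. The following is a minimal sketch, not part of the original notebook: `check_record`, `SEPARATOR` and `STOP` are hypothetical names, and the sample record is made up.

```python
import json

SEPARATOR = "\n\n###\n\n"  # marks the end of the prompt
STOP = " ###"              # marks the end of the completion

def check_record(line):
    """Return a list of problems found in one JSONL line (empty if none)."""
    record = json.loads(line)
    if set(record) != {"prompt", "completion"}:
        return ["unexpected keys"]
    problems = []
    if SEPARATOR not in record["prompt"]:
        problems.append("prompt lacks the separator")
    if not record["completion"].startswith(" "):
        problems.append("completion does not start with a space")
    if not record["completion"].endswith(STOP):
        problems.append("completion lacks the stop marker")
    return problems

# Made-up record in the same layout as the training file
sample = json.dumps({
    "prompt": "\nsome body text\n\n\n###\n\n\n",
    "completion": " optical, supernovae ###",
})
print(check_record(sample))  # []
```

To check the real file, the same function could be applied to each line of `training_data/annots.jsonl`.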

### Set the OpenAI API key and start training. For ~15k annotations, the cheapest model costs about $8.80
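A rough way to estimate the bill in advance is tokens times price times epochs. The sketch below only shows the arithmetic: the function name is hypothetical, and the token count, per-token price and epoch count in the usage line are placeholder assumptions, not figures taken from this notebook or from a price list.

```python
def estimate_training_cost(total_tokens, price_per_1k_tokens, n_epochs):
    """Training cost = (tokens / 1000) * price per 1k tokens * number of epochs."""
    return total_tokens / 1000 * price_per_1k_tokens * n_epochs

# Placeholder inputs (assumptions): 1M training tokens, $0.0004 per 1k tokens, 4 epochs
example = estimate_training_cost(1_000_000, 0.0004, 4)
print(f"${example:.2f}")  # $1.60
```

The real token count depends on the tokenizer and on the length of the cleaned bodies, so treat any such figure as an estimate.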

In [84]:
# pip install --force-reinstall openai==0.25.0

In [85]:
# os.environ["OPENAI_API_KEY"] = # ...

In [86]:
# !openai api fine_tunes.create -t training_data/annots.jsonl -m ada

In [87]:
# !openai api fine_tunes.follow -i ft-dUHWnM1TSgx1yEbeV5cdjoun

### Load the model once the fine-tune status is `completed`

In [126]:
clean_bodies.body_clean[126]

'we analyzed second bepposax nfi observation rxte asm error box smith et al. gcn 126 grb 980703 made july 7.779-8.706 ut. preliminary analysis combined mecs 2 3 data shows variable x-ray source 1sax j2359.1+0835 galama et al. gcn 127 positionally coincident radio counterpart frail et al. gcn 128 decayed factor 5.5 +- 1.5 july 4 8. assuming power law light curve find decay index 1.33 +- 0.25 this message citeable.'

In [127]:
import openai

PROMPT = """
we analyzed second bepposax nfi observation rxte asm error box smith et al. gcn 126 grb 980703 made july 7.779-8.706 ut. 
preliminary analysis combined mecs 2 3 data shows variable x-ray source 1sax j2359.1+0835 galama et al. gcn 127 positionally 
coincident radio counterpart frail et al. gcn 128 decayed factor 5.5 +- 1.5 july 4 8. assuming power 
law light curve find decay index 1.33 +- 0.25 this message citeable.
\n\n###\n\n
"""

resp = openai.Completion.create(
    model="ada:ft-pai-2023-05-16-20-18-46",
    prompt=PROMPT)

In [128]:
print(resp['choices'][0]['text'])

 x-ray, gamma ray, gamma-ray burst, variables ###
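The raw completion includes the `###` marker and trailing newlines because no stop sequence was passed. Here is a sketch of how the request could be assembled so that generation ends at the marker. `build_request` is a hypothetical helper, and `max_tokens=32` and `temperature=0` are my assumptions; the prompt layout mirrors the training template.

```python
SEPARATOR = "\n\n###\n\n"  # prompt terminator used in the training file
STOP = " ###"              # completion terminator used in the training file

def build_request(body, model, max_tokens=32):
    """Assemble keyword arguments for a completion call (hypothetical helper)."""
    return {
        "model": model,
        "prompt": f"\n{body}\n{SEPARATOR}\n",  # same layout as the training prompts
        "max_tokens": max_tokens,              # assumption: enough for a short tag list
        "temperature": 0,                      # assumption: prefer deterministic tags
        "stop": [STOP],                        # stop before the marker is emitted
    }

request = build_request("grb 980703 x-ray afterglow",
                        model="ada:ft-pai-2023-05-16-20-18-46")
# resp = openai.Completion.create(**request)  # needs openai==0.25.0 and an API key
```

With a stop sequence set, the returned text should no longer contain `###`, so the post-processing only has to split on commas.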





In [129]:
[x.strip() for x in resp['choices'][0]['text'].replace("###", ',').split(',') if x.strip()]

['x-ray', 'gamma ray', 'gamma-ray burst', 'variables']
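The one-liner above can be wrapped in a small reusable function. This is a sketch with a hypothetical name, `parse_topics`. It differs slightly from the cell above: it discards everything after the first `###` instead of turning the marker into a comma, and it drops duplicate tags.

```python
def parse_topics(completion_text, stop="###"):
    """Turn a raw completion into a list of unique topic strings."""
    head = completion_text.split(stop, 1)[0]  # keep only text before the stop marker
    topics = []
    for part in head.split(","):
        topic = part.strip()
        if topic and topic not in topics:     # skip empty pieces and duplicates
            topics.append(topic)
    return topics

print(parse_topics(" x-ray, gamma ray, gamma-ray burst, variables ###\n\n\n\n\n"))
# ['x-ray', 'gamma ray', 'gamma-ray burst', 'variables']
```

Cutting at the marker guards against the model continuing to generate text after `###` when no stop sequence is set.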