In [1]:
#%pip install minio

Defaulting to user installation because normal site-packages is not writeable
You should consider upgrading via the '/usr/local/bin/python3 -m pip install --upgrade pip' command.[0m


In [1]:
import pandas as pd
import numpy as np
from minio import Minio
from sktime.base import load
from scipy.interpolate import Akima1DInterpolator

import warnings
warnings.filterwarnings("ignore")
from tqdm import tqdm
import joblib


def get_data(chunk, scaler, n_grid, verbose=0) -> tuple:
    """Turn a chunk of raw light-curve records into scaled model features.

    Every row of ``chunk`` is read positionally: column 5 is taken as the
    object id and column 6 as a string of flattened 4-value tuples
    (``"[(a,b,c,d),(a,b,c,d),...]"``). Of each tuple the 2nd, 3rd and 4th
    values are used as time, magnitude and magnitude error.

    For each record the minimum magnitude is subtracted from the magnitudes,
    which are then resampled onto ``n_grid`` evenly spaced time points with
    ``Akima1DInterpolator``. The resulting feature row is
    ``[mag_min, mag_0 .. mag_{n_grid-1}, magerr_mean, magerr_std]``.

    Records whose time and magnitude sequences differ in length are skipped.

    Parameters
    ----------
    chunk : pandas.DataFrame
        Raw records; only columns 5 and 6 (by position) are read.
    scaler : object
        Already fitted transformer exposing ``transform``; applied to the
        feature table (all columns except ``oid``).
    n_grid : int
        Number of interpolation points per record.
    verbose : int, default 0
        When 1, a message is printed for every skipped record.

    Returns
    -------
    tuple
        ``(x, oid)`` where ``x`` is ``scaler.transform(features)`` with a
        singleton axis inserted at position 1 and ``oid`` is the ``'oid'``
        column (a ``pandas.Series``) of the records that were kept.
    """
    columns = (
        ['oid', 'mag_min']
        + [f'mag_{i}' for i in range(n_grid)]
        + ['magerr_mean', 'magerr_std']
    )

    rows = []
    # Iterate over the two used columns directly: this avoids building a
    # Series per row (iterrows) and the deprecated integer-key lookup on a
    # label-indexed Series.
    for oid_value, raw_points in zip(chunk.iloc[:, 5], chunk.iloc[:, 6]):
        # Flatten "[(a,b,c,d),(a,b,c,d)]" into a plain list of floats.
        tokens = (
            raw_points
            .replace('[(', '#')
            .replace(')]', '#')
            .replace('),(', '#')
            .replace(',', '#')
            .split('#')
        )
        values = [float(token) for token in tokens if token]
        times = np.array(values[1::4])
        mags = np.array(values[2::4])
        errors = np.array(values[3::4])

        # Validate before touching the data: a truncated record must be
        # skipped regardless of ``verbose`` (previously it was skipped only
        # when verbose == 1 and crashed the interpolator otherwise).
        if len(times) != len(mags):
            if verbose == 1:
                print(
                    f'len time points={len(times)} not equl len mag points={len(mags)}'
                    f' in data for obs_id={oid_value} object'
                )
            continue

        # Subtract the minimum magnitude to normalise; it is kept as a feature.
        mag_min = mags.min()
        interpolator = Akima1DInterpolator(times, mags - mag_min)
        grid = np.linspace(times.min(), times.max(), n_grid)

        rows.append(
            [oid_value, mag_min, *interpolator(grid), errors.mean(), errors.std()]
        )

    data = pd.DataFrame(data=rows, columns=columns)

    features = scaler.transform(data.drop(['oid'], axis=1))
    # Insert a singleton middle axis -> (n_records, 1, n_features).
    # NOTE(review): presumably the (instances, channels, timepoints) layout
    # the downstream model expects - confirm against the model's input spec.
    return features[:, np.newaxis, :], data['oid']
    

import os

# Credentials are read from the environment so that secrets are never stored
# in the notebook. Any key pair that was ever committed in plain text should
# be treated as leaked and rotated. A missing variable raises KeyError here,
# before any network call is attempted.
MINIO_ACCESS_KEY_ID = os.environ['MINIO_ACCESS_KEY_ID']
MINIO_SECRET_ACCESS_KEY = os.environ['MINIO_SECRET_ACCESS_KEY']

client = Minio(
    's3.lpc.snad.space',
    access_key=MINIO_ACCESS_KEY_ID,
    secret_key=MINIO_SECRET_ACCESS_KEY,
)

# Listing the buckets doubles as a connectivity / credentials check.
buckets = client.list_buckets()

print('Test connection to S3 storage')
for bucket in buckets:
    print(bucket.name, bucket.creation_date)

# Handle to the remote CSV; it is consumed later by pd.read_csv in chunks.
obj = client.get_object(
    "ztf-high-cadence-data",
    "dr14_high_cadence.csv",
)

# Pre-trained artefacts loaded from the working directory.
rocket_model = load('Rocket')
# NOTE(review): joblib.load unpickles the file and can execute arbitrary
# code - only load scaler files from a trusted source.
scaler = joblib.load('scaler.save')

# Object ids for which the model predicted class 1.
result = []

data = pd.read_csv(obj, chunksize=10_000)
try:
    # ``total`` is only an estimate of the number of chunks, used by the
    # progress bar; the loop itself runs until the CSV is exhausted.
    for chunk in tqdm(data, total=10_000, desc='Data progress'):
        x, oid = get_data(chunk, scaler, n_grid=100)
        y_pred = np.array(rocket_model.predict(x))
        oid = np.array(oid)

        oid_flares = oid[y_pred == 1]
        # extend() appends in place; ``result = result + [...]`` re-copied
        # the whole (multi-million element) list on every chunk.
        result.extend(oid_flares)
finally:
    # Always hand the HTTP connection back, even if a chunk fails.
    obj.close()
    obj.release_conn()

# One object id per line. The ``with`` block closes (and flushes) the file on
# exit, so no explicit close() is needed afterwards.
with open('flare_rocket.csv', 'w') as f:
    f.writelines(f'{str(items)}\n' for items in result)

# Report success only after the file has actually been closed.
print("File written successfully")

print(f"Amount rocket's flares {len(result)}")

Data progress:  97%|█████████▋| 9729/10000 [21:01:35<35:08,  7.78s/it]


Test connection to S3 storage
flare-classifier 2023-04-21 23:01:45.001000+00:00
ztf-high-cadence-data 2023-07-11 13:15:35.676000+00:00
File written successfully
Amount rocket's flares 5322986


In [21]:
# Interactive sanity check: display the type of `result`.
type(result)

list

In [29]:
import pandas as pd
import numpy as np
from minio import Minio

import warnings
warnings.filterwarnings("ignore")
from tqdm import tqdm
import joblib


import os

# Credentials are read from the environment so that secrets are never stored
# in the notebook. Any key pair that was ever committed in plain text should
# be treated as leaked and rotated. A missing variable raises KeyError.
MINIO_ACCESS_KEY_ID = os.environ['MINIO_ACCESS_KEY_ID']
MINIO_SECRET_ACCESS_KEY = os.environ['MINIO_SECRET_ACCESS_KEY']

def search_oids(search_set):
    """Report which object ids from ``search_set`` occur in the remote CSV.

    Connects to the S3 storage, streams ``dr14_high_cadence.csv`` from the
    ``ztf-high-cadence-data`` bucket in chunks of 10,000 rows and compares
    the values of column 5 (by position) against ``search_set``. Streaming
    stops early once every requested id has been found.

    Side effects: prints the bucket list and the result table, and writes
    the table to ``search_oids_result.csv`` (index ``oid``, boolean column
    ``in_target_set``).

    Parameters
    ----------
    search_set : set
        Object ids to look for. The set is not modified.

    Returns
    -------
    None
    """
    client = Minio(
        's3.lpc.snad.space',
        access_key=MINIO_ACCESS_KEY_ID,
        secret_key=MINIO_SECRET_ACCESS_KEY,
    )

    # Listing the buckets doubles as a connectivity / credentials check.
    buckets = client.list_buckets()

    print('Test connection to S3 storage')
    for bucket in buckets:
        print(bucket.name, bucket.creation_date)

    obj = client.get_object(
        "ztf-high-cadence-data",
        "dr14_high_cadence.csv",
    )

    result = pd.DataFrame({'oid': list(search_set)})
    result['in_target_set'] = False
    result = result.set_index('oid')

    # Ids still to be found; a copy, so the caller's set is left untouched.
    remaining = set(search_set)

    try:
        data = pd.read_csv(obj, chunksize=10_000)
        # ``total`` is only an estimate of the chunk count for the progress bar.
        for chunk in tqdm(data, total=10_000, desc='Oid search progress'):
            if not remaining:
                break
            # One set intersection per chunk instead of rebuilding the set
            # (and a row Series) for every single record.
            found = remaining.intersection(chunk.iloc[:, 5].tolist())
            if found:
                result.loc[list(found), 'in_target_set'] = True
                remaining -= found
    finally:
        # Always hand the HTTP connection back, including on early exit.
        obj.close()
        obj.release_conn()

    result.to_csv('search_oids_result.csv')
    print('Results search oids in target file')
    print(result.head(len(search_set)))


# Collect the object ids labelled as flares in the validation file and check
# which of them are present in the remote data set.
validation_df = pd.read_csv('raw_real_flares_test.csv')
validation_oids = [
    int(oid)
    for oid in validation_df.loc[validation_df['is_flare'] == 1, 'oid']
]
search_oids(set(validation_oids))
    


Test connection to S3 storage
flare-classifier 2023-04-21 23:01:45.001000+00:00
ztf-high-cadence-data 2023-07-11 13:15:35.676000+00:00
Results search oids in target file
                 in_target_set
oid                           
832210400037888           True
718201300005383          False
660207200039946           True
771216100033044           True
283211100006940           True
...                        ...
733207400019437           True
461216200033263           True
771211400031727           True
642215300060146           True
685205100007414           True

[102 rows x 1 columns]


Oid search progress:  97%|█████████▋| 9729/10000 [4:08:32<06:55,  1.53s/it]
