In [None]:
import sys

# Make this repo and its sibling repositories importable without installing
# them (alternatively, install each package into the environment).
for extra_path in (
    '../',
    '../../fuzzy-torch',
    '../../fuzzy-tools',
    '../../astro-lightcurves-handler',
):
    sys.path.append(extra_path)

In [None]:
from fuzzytools.files import search_for_filedirs
from lchandler import _C

# Collect the split-lightcurve dataset files (extension _C.EXT_SPLIT_LIGHTCURVE)
# found under the surveys root folder.
surveys_rootdir = '../../surveys-save/'
split_ext = _C.EXT_SPLIT_LIGHTCURVE
filedirs = search_for_filedirs(surveys_rootdir, fext=split_ext)

In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
from fuzzytools.files import load_pickle, save_pickle
from fuzzytools.files import get_dict_from_filedir

method = 'spm-mcmc-estw'
filedir = f'../../surveys-save/survey=alerceZTFv7.1~bands=gr~mode=onlySNe~method={method}.splcds'
kf = '3'

filedict = get_dict_from_filedir(filedir)
root_folder = filedict['_rootdir']
cfilename = filedict['_cfilename']
survey = filedict['survey']
lcdataset = load_pickle(filedir)
lcdataset.only_keep_kf(kf) # saves ram
print(lcdataset)

In [None]:
# Min / median / max of the 'obs' values in the training set of fold kf,
# followed by its 100 largest values (descending).
lcset = lcdataset[f'{kf}@train']
values = lcset.get_all_values('obs')
obs_stats = (np.min(values), np.percentile(values, 50), np.max(values))
print(*obs_stats)
largest_first = np.sort(values)[::-1]
print(largest_first[:100])

In [None]:
# Same summary for the `{method}` variant of the training set of fold kf:
# min / median / max of 'obs', then its 100 largest values (descending).
lcset = lcdataset[f'{kf}@train.{method}']
values = lcset.get_all_values('obs')
obs_stats = (np.min(values), np.percentile(values, 50), np.max(values))
print(*obs_stats)
largest_first = np.sort(values)[::-1]
print(largest_first[:100])

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
from lchandler.plots.distrs import plot_values_distribution

kf = 0
set_name = f'{kf}@train'
lcdataset[set_name].set_diff_parallel('days')
plot_values_distribution(lcdataset, set_name, 'd_days')
plot_values_distribution(lcdataset, set_name, 'obs')
plot_values_distribution(lcdataset, set_name, 'obse')

In [None]:
%load_ext autoreload
%autoreload 2
from lcclassifier.datasets import CustomDataset

dataset_kwargs = {
    'max_day':100.,
    #attrs':['days','obs', 'obse']
    'in_attrs':['obs', 'obse'],
    #'attrs':['d_days','obs', 'obse']
    'rec_attr':'obs',
}
repeats = 5
device = 'cpu'
#lcset_name = f'{kf}@train.{method}'
lcset_name = f'{kf}@train'
s_train_dataset_da = CustomDataset(lcset_name, copy(lcdataset[lcset_name]), device,
    balanced_repeats=repeats,
    precomputed_copies=2, # 1 8 16
    uses_daugm=True,
    uses_dynamic_balance=True,
    ds_mode={'random':.75, 'left':.0, 'none':.25,},
    **dataset_kwargs,
    )

In [None]:
%load_ext autoreload
%autoreload 2
s_train_dataset_da.calcule_precomputed(verbose=1, n_jobs=1)

In [None]:
assert False  # deliberate stop: halts "Run All" before the cells below

In [None]:
assert False  # deliberate stop: halts "Run All" before the cells below

In [None]:
%load_ext autoreload
%autoreload 2
from lcclassifier.datasets import CustomDataset


lcset_name = f'{main_args.kf}@train.{main_args.method}'
s_train_dataset_da = CustomDataset(lcset_name, copy(lcdataset[lcset_name]), device,
    balanced_repeats=repeats,
    precomputed_copies=8, # 1 8 16
    uses_daugm=True,
    uses_dynamic_balance=True,
    ds_mode={'random':.75, 'left':.0, 'none':.25,},
    **dataset_kwargs,
    )

device = 'cpu' # cpu
train_dataset = CustomDataset(f'{kf}@train.{method}', copy(lcdataset[f'{kf}@train.{method}']), device, **dataset_kwargs)
val_dataset = CustomDataset(f'{kf}@val', copy(lcdataset[f'{kf}@val']), device, **dataset_kwargs)
train_dataset.transfer_scalers_to(val_dataset) # transfer metadata to val/test
print('train_dataset:', train_dataset)
print('val_dataset:', val_dataset)

In [None]:
%load_ext autoreload
%autoreload 2
from fuzzytorch.utils import print_tdict
import cProfile

p = cProfile.Profile(); p.enable()
tdict = train_dataset.get_item(train_dataset.get_lcobj_names()[0])
print_tdict(tdict)
p.disable(); p.dump_stats('prof.prof')
print(tdict['target']['balanced_w'])
pass

In [None]:
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
from lchandler.plots.lc import plot_lightcurve
from lchandler import C_ as C_
%matplotlib inline

dataset = train_dataset
lcobj_name = dataset.get_random_stratified_lcobj_names()[0]


tdict, lcobj = dataset.get_item(lcobj_name, uses_len_clip=False, uses_daugm=False, return_lcobjs=True)
print(lcobj)


minput = tdict['input']
target = tdict['target']

figsize = (10,13)
fig, axs = plt.subplots(5+1, 1, figsize=figsize)

ax = axs[0]
for kb,b in enumerate(dataset.band_names):
    plot_lightcurve(ax, lcobj, b, label=f'{b} obs', max_day=dataset.max_day)
ax.set_ylabel('observation')

b = 'r'
len_lcobj = minput[f'onehot.{b}'].sum()
ax = axs[1]
time = minput[f'rtime.{b}'][...,0]
for ka,in_attr in enumerate(dataset.in_attrs):
    ax.plot(time[:len_lcobj], minput[f'x.{b}'][:len_lcobj,ka], '-o', label=f'{C_.SHORT_NAME_DICT[in_attr]} (norm)')
ax.set_ylabel(f'x.{b}')

ax = axs[2]
ax.plot(time[:len_lcobj], minput[f'onehot.{b}'][:len_lcobj], 'o')
ax.set_ylabel(f'onehot.{b}')

ax = axs[3]
ax.plot(time[:len_lcobj], minput[f'rtime.{b}'][:len_lcobj], '-o')
ax.set_ylabel(f'time.{b}')

ax = axs[4]
ax.plot(time[:len_lcobj], minput[f'dtime.{b}'][:len_lcobj], '-o')
ax.set_ylabel(f'dtime.{b}')

ax = axs[5]
ax.plot(time[:len_lcobj], target[f'rec_x.{b}'][:len_lcobj], '-o')
ax.set_ylabel(f'rec_x.{b}')

class_name = dataset.class_names[target['y']]
title = ''
title += f'training light curve sample & model inputs & onehot & temporal encoding \n'
title += f'survey: {dataset.lcset.survey} - set: {dataset.lcset_name}'
title += f' - class: {class_name} - max_day: {dataset.max_day:.2f} - max_len: {dataset.max_len}'
#title += f' - training: {dataset.training}'
for ax in axs:
    #ax.legend(prop={'size':14})
    ax.legend(loc='upper right')
    ax.grid(alpha=0.5)
axs[0].set_title(title)
axs[-1].set_xlabel('days')
plt.show()

In [None]:
%load_ext autoreload
%autoreload 2
from lcclassifier.dataloaders import CustomDataLoader
from fuzzytorch.utils import print_tdict

loader_kwargs = {
    'batch_size':2,
}
random_subcrops = 3
s_train_loader = CustomDataLoader(train_dataset, shuffle=False, **loader_kwargs)
s_train_loader.eval()
dataset.set_max_day(40)

for k,tdict in enumerate(s_train_loader):
    target = tdict['target']
    print_tdict(tdict)
    print(tdict['input']['rtime.*'][0,:,0])
    break

In [None]:
%load_ext autoreload
%autoreload 2
from lcclassifier.dataloaders import CustomDataLoader
from fuzzytorch.utils import print_tdict

loader_kwargs = {
    'batch_size':1,
    #'num_workers':1, # bug?
}
random_subcrops = 3
s_train_loader = CustomDataLoader(train_dataset, shuffle=True, random_subcrops=random_subcrops, **loader_kwargs)
s_train_loader.train()

for k,tdict in enumerate(s_train_loader):
    model_input = tdict['input']
    target = tdict['target']
    print_tdict(tdict)
    for idx in range(len(model_input['x'])):
        print(model_input['x'][idx,:,0])
        print(model_input['onehot'][idx].sum(-1))
    assert 0