In [16]:
import torch
import json
from tqdm import tqdm
import os, sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir)))
from dataset import Dataset
from utils import load_model, get_all_run_ids, get_device

In [28]:
# Directory holding the top-performing training runs; every run id found there is loaded.
models_path = "../trained/transformer-minidataset/top-runs"
run_ids = get_all_run_ids(path=models_path)
device = get_device()
models = {}
print("run_ids:", run_ids)

# Map run id -> model loaded with type="timelin" and moved to `device`.
models.update(
    {
        run_id: load_model(run_id, path=models_path, type="timelin").to(device)
        for run_id in run_ids
    }
)



run_ids: ['7jek9blg', 'xt1b5ncs', 'tbimy590', 'kdfu69md', 'ov2quu9v', '21soegq8', '0lhbjjox', '1qvcwtuy', 'b4gmn89t', 'jw8cws22', 'rarx6o1s', 'qfqaf3a0', '5qndphh8', 'a9wnfeis', 'awaen0te', '8wzblws2', '2s5lgg5h', 'e7pccu76', 'u9xd4jpq', '7wvj11dh', '56lvlmuj', 'gza0sltm', 'r21hwu9p', '7f4fvk0u', 'rlx8q7wd']


In [29]:
# Preprocessed waveset-harmonics dataset, wrapped in a shuffled DataLoader (batch size 32).
# NOTE(review): load_dataset_from_pickle presumably unpickles the file, which can execute
# arbitrary code -- only point this at trusted data.
dataset_pickle_path = "../../dataset/data/0814/waveset_harmonics/processed.pkl"
dataset = Dataset().load_dataset_from_pickle(pickle_path=dataset_pickle_path)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
)

In [30]:
# Per-run encoder output range: for every run, the element-wise min/max of
# `model.forward_encoder(...)` over the whole dataloader.
#   models_range[run_id] = {"min": [...], "max": [...]}  (plain lists -> JSON-serialisable)
#   per_model_max / per_model_min: one row per run, in `run_ids` order.
# The stacked tensors are deliberately NOT called `max` / `min`: that would shadow the
# Python builtins for every later cell in the kernel.
max_rows, min_rows = [], []
models_range = {}

with torch.no_grad():
    for _id in run_ids:
        model = models[_id]
        # Disable dropout & co. so the measured range is deterministic. The mode returned
        # by load_model is not visible here; eval() is idempotent, so this is always safe.
        model.eval()

        # Running element-wise extrema across batches; None until the first batch arrives.
        _max, _min = None, None

        for train_inputs, _ in tqdm(dataloader):
            out = model.forward_encoder(train_inputs.to(device)).to("cpu")
            # Reduce over dims 0 and 1 (assumed batch and sequence -- TODO confirm),
            # keeping the remaining (latent) dimension(s).
            batch_max = out.amax(dim=(0, 1))
            batch_min = out.amin(dim=(0, 1))
            _max = batch_max if _max is None else torch.maximum(_max, batch_max)
            _min = batch_min if _min is None else torch.minimum(_min, batch_min)

        if _max is None:
            raise ValueError(f"dataloader yielded no batches; cannot compute range for model {_id}")

        torch.cuda.empty_cache()

        # Only the final, JSON-ready result is stored, so an interrupted run never leaves
        # tensors behind in models_range.
        models_range[_id] = {"min": _min.tolist(), "max": _max.tolist()}

        print("model", _id, "max", _max.tolist(), "min", _min.tolist())

        max_rows.append(_max)
        min_rows.append(_min)

# Stack once at the end instead of growing tensors with torch.cat inside the loop.
per_model_max = torch.stack(max_rows) if max_rows else torch.empty(0)
per_model_min = torch.stack(min_rows) if min_rows else torch.empty(0)
        

100%|██████████| 202/202 [00:04<00:00, 41.08it/s]


model 7jek9blg max [1.4946842193603516, 1.562696099281311, 1.4944300651550293, 1.559212327003479] min [-1.5378073453903198, -1.457781434059143, -1.6402841806411743, -1.4616296291351318]


100%|██████████| 202/202 [00:10<00:00, 18.45it/s]


model xt1b5ncs max [0.5708659291267395, 0.8368782997131348, 0.7334136366844177, 0.9146367907524109] min [-1.0481925010681152, -0.7996190190315247, -0.7589512467384338, -0.5208720564842224]


100%|██████████| 202/202 [00:13<00:00, 15.18it/s]


model tbimy590 max [1.5461868047714233, 1.5116850137710571, 1.6513432264328003, 1.62351393699646] min [-1.6370322704315186, -1.6363015174865723, -1.682663917541504, -1.468682050704956]


100%|██████████| 202/202 [00:04<00:00, 49.55it/s]


model kdfu69md max [1.3213152885437012, 1.5722863674163818, 1.3613009452819824, 1.3980050086975098] min [-1.5317035913467407, -1.4048943519592285, -1.5199394226074219, -1.4466338157653809]


100%|██████████| 202/202 [00:01<00:00, 126.72it/s]


model ov2quu9v max [1.2373473644256592, 1.1606051921844482, 1.4942669868469238, 1.204972505569458] min [-1.3095837831497192, -1.462934970855713, -1.1023536920547485, -1.4144636392593384]


100%|██████████| 202/202 [00:04<00:00, 49.18it/s]


model 21soegq8 max [0.7738459706306458, 1.343103051185608, 0.20557770133018494, 1.0694676637649536] min [-0.7428792715072632, -0.3549509346485138, -0.9676826000213623, -1.052905797958374]


100%|██████████| 202/202 [00:09<00:00, 21.41it/s]


model 0lhbjjox max [0.7824065685272217, 1.2599328756332397, 1.1428371667861938, 1.5682355165481567] min [-0.9381613731384277, -0.9440878033638, -1.3880125284194946, -1.4200910329818726]


100%|██████████| 202/202 [00:05<00:00, 38.97it/s]


model 1qvcwtuy max [1.5136666297912598, 1.4484717845916748, 1.471276879310608, 1.0441639423370361] min [-1.396672010421753, -1.3313907384872437, -1.524397373199463, -1.4576194286346436]


100%|██████████| 202/202 [00:08<00:00, 22.99it/s]


model b4gmn89t max [1.3568365573883057, 0.948638916015625, 0.5778349041938782, 0.1359158456325531] min [-0.6629053354263306, -0.6666450500488281, -0.47497841715812683, -1.2902921438217163]


100%|██████████| 202/202 [00:11<00:00, 17.94it/s]


model jw8cws22 max [1.1412286758422852, 1.1971361637115479, 1.577427864074707, 1.54746675491333] min [-1.1534920930862427, -1.4817637205123901, -1.2611080408096313, -1.6246588230133057]


100%|██████████| 202/202 [00:03<00:00, 55.94it/s]


model rarx6o1s max [1.6492736339569092, 1.4345115423202515, 1.636875033378601, 1.2438420057296753] min [-1.352027177810669, -1.6017067432403564, -1.4655553102493286, -1.5946893692016602]


100%|██████████| 202/202 [00:08<00:00, 24.89it/s]


model qfqaf3a0 max [0.36647024750709534, 1.588884949684143, 1.5461946725845337, 1.3958901166915894] min [-1.3239696025848389, -0.56501704454422, -1.6507271528244019, -1.5292942523956299]


100%|██████████| 202/202 [00:04<00:00, 49.44it/s]


model 5qndphh8 max [1.2883944511413574, 1.5119625329971313, 1.392168402671814, 1.29911470413208] min [-1.3703184127807617, -1.243712067604065, -1.4487407207489014, -1.427443027496338]


100%|██████████| 202/202 [00:00<00:00, 204.53it/s]


model a9wnfeis max [1.2419769763946533, 1.1727889776229858, 0.7604421973228455, 1.6347312927246094] min [-1.3410789966583252, -1.201667070388794, -1.6207364797592163, -1.7132971286773682]


100%|██████████| 202/202 [00:04<00:00, 47.07it/s]


model awaen0te max [1.387220859527588, 0.06399083137512207, 1.3740767240524292, 1.1314582824707031] min [-1.1059094667434692, -1.2134615182876587, -1.057119607925415, -0.39024099707603455]


100%|██████████| 202/202 [00:08<00:00, 22.47it/s]


model 8wzblws2 max [1.3469789028167725, 1.5222971439361572, 1.4309083223342896, 0.6362711787223816] min [-1.5450645685195923, -0.31412142515182495, -1.5380229949951172, -1.508801817893982]


100%|██████████| 202/202 [00:02<00:00, 67.42it/s]


model 2s5lgg5h max [1.386570930480957, 1.4140715599060059, 1.2460122108459473, 1.6028938293457031] min [-1.589165449142456, -1.6127716302871704, -1.6057164669036865, -1.4662046432495117]


100%|██████████| 202/202 [00:02<00:00, 81.67it/s]


model e7pccu76 max [1.0113177299499512, 1.3913065195083618, 0.3409314453601837, 1.5517088174819946] min [-1.26580810546875, -0.591565728187561, -0.8417868614196777, -1.4047974348068237]


100%|██████████| 202/202 [00:04<00:00, 43.55it/s]


model u9xd4jpq max [1.6526445150375366, 1.3393950462341309, 1.6458903551101685, 1.6825149059295654] min [-1.4653071165084839, -1.1881810426712036, -1.5952774286270142, -1.6234898567199707]


100%|██████████| 202/202 [00:03<00:00, 64.14it/s]


model 7wvj11dh max [1.6292576789855957, 1.0490270853042603, 1.1840529441833496, 0.7666370272636414] min [-1.88172447681427, -1.079048752784729, -0.2107587307691574, -1.4214043617248535]


100%|██████████| 202/202 [00:06<00:00, 29.95it/s]


model 56lvlmuj max [1.6150938272476196, 1.5431747436523438, 1.396875262260437, 1.5256397724151611] min [-1.5801997184753418, -1.7285032272338867, -0.6197175979614258, -0.9361215233802795]


100%|██████████| 202/202 [00:02<00:00, 68.43it/s]


model gza0sltm max [1.1571756601333618, 1.5852386951446533, 0.8634416460990906, 1.061753273010254] min [-1.640497088432312, -1.6628888845443726, -1.4151654243469238, -1.4069663286209106]


100%|██████████| 202/202 [00:06<00:00, 32.96it/s]


model r21hwu9p max [1.6200934648513794, 1.1105892658233643, 1.121202826499939, 1.2507604360580444] min [-1.6600890159606934, -1.3960494995117188, -1.0675920248031616, -1.5820281505584717]


100%|██████████| 202/202 [00:05<00:00, 38.45it/s]


model 7f4fvk0u max [1.5808073282241821, 1.2837378978729248, 1.6427041292190552, 1.2102341651916504] min [-1.5337498188018799, -1.15310800075531, -1.4271023273468018, -1.6008830070495605]


100%|██████████| 202/202 [00:05<00:00, 37.09it/s]

model rlx8q7wd max [0.3706403076648712, 0.831834614276886, 0.9259507656097412, 1.4498611688613892] min [-0.8492431640625, -1.372462511062622, -1.1744530200958252, -0.6216397881507874]





In [31]:
# Display the collected per-run encoder ranges ({run_id: {"min": [...], "max": [...]}}).
models_range

{'7jek9blg': {'min': [-1.5378073453903198,
   -1.457781434059143,
   -1.6402841806411743,
   -1.4616296291351318],
  'max': [1.4946842193603516,
   1.562696099281311,
   1.4944300651550293,
   1.559212327003479]},
 'xt1b5ncs': {'min': [-1.0481925010681152,
   -0.7996190190315247,
   -0.7589512467384338,
   -0.5208720564842224],
  'max': [0.5708659291267395,
   0.8368782997131348,
   0.7334136366844177,
   0.9146367907524109]},
 'tbimy590': {'min': [-1.6370322704315186,
   -1.6363015174865723,
   -1.682663917541504,
   -1.468682050704956],
  'max': [1.5461868047714233,
   1.5116850137710571,
   1.6513432264328003,
   1.62351393699646]},
 'kdfu69md': {'min': [-1.5317035913467407,
   -1.4048943519592285,
   -1.5199394226074219,
   -1.4466338157653809],
  'max': [1.3213152885437012,
   1.5722863674163818,
   1.3613009452819824,
   1.3980050086975098]},
 'ov2quu9v': {'min': [-1.3095837831497192,
   -1.462934970855713,
   -1.1023536920547485,
   -1.4144636392593384],
  'max': [1.237347364425

In [32]:
# Sanity-check models_range before it is serialised: every entry must be a
# {"min": list, "max": list} dict. Anything else (e.g. tensors left over from an
# interrupted or edited range-computation cell) would make json.dump fail later.
assert isinstance(models_range, dict), f"expected dict, got {type(models_range).__name__}"
_not_json_ready = [
    run_id
    for run_id, rng in models_range.items()
    if not (
        isinstance(rng, dict)
        and set(rng) == {"min", "max"}
        and all(isinstance(v, list) for v in rng.values())
    )
]
assert not _not_json_ready, f"models_range entries are not JSON-ready: {_not_json_ready}"
type(models_range)

dict

In [34]:
# Persist the per-run encoder ranges next to the checkpoints they describe.
# Reuse `models_path` (the directory the models were loaded from) instead of a second,
# separately spelled relative path: the previous '../../models/trained/...' only pointed
# at the same directory when the notebook ran from a subfolder of `models/`.
models_range_path = os.path.join(models_path, "models_range.json")
with open(models_range_path, "w", encoding="utf-8") as f:
    json.dump(models_range, f)