In [1]:
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader, random_split, ConcatDataset
import numpy as np

import sys
sys.path.insert(0, '../..')
from RB_ZTF.scripts.datasets import *
from RB_ZTF.scripts.rnn import *
from RB_ZTF.scripts.losses import *

import os
from torch.utils.tensorboard import SummaryWriter

# Fix the RNG seed before any dataset split or model initialisation happens.
# NOTE(review): set_random_seed comes from the RB_ZTF.scripts star-imports;
# which RNGs it seeds (python / numpy / torch) is not visible here — confirm.
set_random_seed(7)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

2024-03-19 23:41:58.939449: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
print('Creating dataset..')
# Object ids and their labels are read from the JSON export, then wrapped in a
# dataset that is pointed at the '../embeddings/' directory.
oids, labels = get_only_r_oids('../akb.ztf.snad.space.json')
data = EmbsSequenceData(
    oids,
    labels,
    label_type='long',
    path='../embeddings/',
)

# Five equal-fraction folds for cross-validation. No explicit generator is
# passed, so the split is drawn from torch's global RNG state.
folds = list(random_split(data, [0.2] * 5))
fold1, fold2, fold3, fold4, fold5 = folds

print('Making dataloaders..')
# Sequence-length boundaries later handed to BySequenceLengthSampler.
bucket_boundaries = list(range(200, 801, 200))

Creating dataset..
Making dataloaders..


In [3]:
# Experiment registry: one entry per model variant that can be trained.
#   'model' -> extra keyword arguments forwarded to RBclassifier
#   'crit'  -> loss callable handed to train_rnn
#   'optim' -> extra keyword arguments forwarded to torch.optim.Adam
_cross_entropy = nn.CrossEntropyLoss  # each variant gets its own loss instance

models_args = {
    'baseline': {
        'model': {'rnn_type': 'GRU'},
        'crit': _cross_entropy(),
        'optim': {},
    },
    'wd_2dir': {
        'model': {'rnn_type': 'GRU', 'bidirectional': True},
        'crit': _cross_entropy(),
        'optim': {'weight_decay': 1e-5},
    },
    'tversky': {
        'model': {'rnn_type': 'GRU'},
        'crit': rnn_loss_handler,
        'optim': {},
    },
    'lstm': {
        'model': {'rnn_type': 'LSTM'},
        'crit': _cross_entropy(),
        'optim': {},
    },
}

In [4]:
# Ask which configured model variant to train.
# The option list in the prompt is built from models_args so it cannot drift
# out of sync with the registry (dicts preserve insertion order, so the prompt
# text is unchanged for the current four entries).
valid_names = list(models_args)
# Strip surrounding whitespace: a stray space or newline would otherwise
# surface as an opaque KeyError on the lookup below.
name = input(f"Choose model to train ({'/'.join(valid_names)}):  ").strip()
if name not in models_args:
    raise ValueError(
        f'Unknown model {name!r}; expected one of: {", ".join(valid_names)}'
    )
args = models_args[name]

Choose model to train (baseline/wd_2dir/tversky/lstm):   wd_2dir


In [5]:
# K-fold cross-validation: for every fold k, train a fresh model on the other
# folds, evaluate on fold k, then store the weights and per-epoch results.

# Settings shared by every fold.
BATCH_SIZE = 32     # batch size handed to BySequenceLengthSampler
NUM_WORKERS = 8     # DataLoader worker processes
N_EPOCHS = 500      # train_rnn calls per fold

# Create the output directory before any training starts: a bad path then
# fails immediately rather than after a fold has finished training.
# makedirs also creates missing parent directories, which os.mkdir does not.
model_dir = f'../trained_models/rnn/{name}'
os.makedirs(model_dir, exist_ok=True)

for k, fold in enumerate(folds):
    # Every fold except the k-th forms the training set.
    concat_folds = ConcatDataset(folds[:k] + folds[k+1:])
    train_sampler = BySequenceLengthSampler(concat_folds, bucket_boundaries, BATCH_SIZE, drop_last=False, shuffle=True)
    test_sampler = BySequenceLengthSampler(fold, bucket_boundaries, BATCH_SIZE, drop_last=False, shuffle=False)

    # batch_sampler is mutually exclusive with batch_size / shuffle / sampler /
    # drop_last in DataLoader, so only the sampler is passed; batching is
    # fully controlled by the samplers above.
    train_loader = DataLoader(concat_folds,
                              batch_sampler=train_sampler,
                              num_workers=NUM_WORKERS,
                              collate_fn=collate,
                              pin_memory=False)

    test_loader = DataLoader(fold,
                             batch_sampler=test_sampler,
                             num_workers=NUM_WORKERS,
                             collate_fn=collate,
                             pin_memory=False)

    # Fresh model and optimizer per fold so no weights leak between folds.
    model = RBclassifier(hidden_size=128, latent_dim=36, out_size=2, **args['model'])
    model.train()
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, **args['optim'])
    criterion = args['crit']

    writer = SummaryWriter(f'runs/{name}_fold{k}/')
    print(f'{name} model training.. folds: {k+1}/{len(folds)}')
    # One entry per epoch: whatever train_rnn returns
    # (presumably evaluation metrics — confirm against RB_ZTF.scripts.rnn).
    res = []
    try:
        for i in tqdm(range(1, N_EPOCHS + 1)):
            res.append(
                train_rnn(
                    model=model,
                    optimizer=optimizer,
                    train_loader=train_loader,
                    test_loader=test_loader,
                    criterion=criterion,
                    epoch=i,
                    device=device,
                    writer=writer,
                )
            )
    finally:
        # close() flushes pending events and releases the event file; without
        # it each fold would leave an open writer behind, even on failure.
        writer.close()

    torch.save(model.state_dict(), f'{model_dir}/model{k}.zip')

    np.save(f'{model_dir}/result{k}.npy', res)

print(f'{name} model successfully trained\n----------------------------------------\n----------------------------------------\n\n')

wd_2dir model training.. folds: 1/5


100%|█████████████████████████████████████████| 500/500 [48:42<00:00,  5.85s/it]


wd_2dir model training.. folds: 2/5


100%|█████████████████████████████████████████| 500/500 [47:04<00:00,  5.65s/it]


wd_2dir model training.. folds: 3/5


100%|█████████████████████████████████████████| 500/500 [46:25<00:00,  5.57s/it]


wd_2dir model training.. folds: 4/5


100%|█████████████████████████████████████████| 500/500 [46:48<00:00,  5.62s/it]


wd_2dir model training.. folds: 5/5


100%|█████████████████████████████████████████| 500/500 [47:53<00:00,  5.75s/it]

wd_2dir model successfully trained
----------------------------------------
----------------------------------------





