In [5]:
%load_ext autoreload
%autoreload 2

import argparse
import random

import torch.backends.cudnn as cudnn
from timeit import default_timer
from attrdict import AttrDict


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import os
import re
from tqdm import tqdm
import sys
# from ml_collections import config_dict
# from easydict import EasyDict as edict
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import matplotlib
import pandas as pd
from transformers import AutoModel, AutoConfig, RobertaForTokenClassification

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
from utils import count_params, count_trainable_params, get_params_to_update, \
                            set_grad_state, set_param_grad
from task_configs import get_data, get_config, get_metric, get_optimizer_scheduler, get_scheduler
from embedder import get_tgt_model, wrapper1D
import json
import time

from functools import partial
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as transforms

In [12]:
from ray import train, tune
from ray.tune.schedulers import ASHAScheduler

In [13]:
# Make the gene-orca project importable. Entries in sys.path are used
# literally (no "~" expansion), so the home directory must be expanded here.
# NOTE(review): `utils` / `task_configs` / `embedder` are imported in an earlier
# cell, so this path only matters if that cell is re-run afterwards.
GENE_ORCA_ROOT = os.path.expanduser("~/ORCA/clean/gene-orca")
if GENE_ORCA_ROOT not in sys.path:  # keep the cell idempotent on re-run
    sys.path.append(GENE_ORCA_ROOT)

# Pin this kernel to GPU 0; only effective if CUDA has not been initialised yet.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

def get_params_to_update(model):
    """Return the parameters of ``model`` that currently require gradients.

    As a side effect, prints "Params to learn:" followed by the name of every
    such parameter, each prefixed with a tab. Parameters are returned in the
    order ``model.named_parameters()`` yields them.

    Note: this local definition shadows ``utils.get_params_to_update``.
    """
    trainable = [(name, param) for name, param in model.named_parameters() if param.requires_grad]
    print("Params to learn:", "".join("\t" + name for name, _ in trainable))
    return [param for _, param in trainable]

In [14]:
# Sanity check: does this kernel see a CUDA device?
cuda_visible = torch.cuda.is_available()
print(cuda_visible)

False


In [15]:
def get_optimizer(config):
    """Return a factory that builds the optimizer described by ``config``.

    Args:
        config: mapping with ``'optimizer'`` (one of ``'SGD'``, ``'Adam'``,
            ``'AdamW'``) and ``'lr'``. ``'weight_decay'`` is optional and
            defaults to 0, because it may be absent from the config (e.g. when
            it is not part of the Tune search space).

    Returns:
        A ``functools.partial`` that takes an iterable of parameters and
        returns the configured ``torch.optim`` optimizer.

    Raises:
        ValueError: if ``config['optimizer']`` is not a supported name
            (previously this silently returned ``None``).
    """
    name = config['optimizer']
    common = dict(lr=config['lr'], weight_decay=config.get('weight_decay', 0))
    if name == 'SGD':
        return partial(torch.optim.SGD, momentum=0.99, **common)
    if name == 'Adam':
        return partial(torch.optim.Adam, betas=(0.9, 0.98), **common)
    if name == 'AdamW':
        return partial(torch.optim.AdamW, betas=(0.9, 0.98), **common)
    raise ValueError(f"Unsupported optimizer {name!r}; expected one of 'SGD', 'Adam', 'AdamW'")


In [20]:
def train_one_epoch(args, model, optimizer, scheduler, loader, loss, temp, label_smoothing_factor=None, decoder=None, transform=None, config=None):
    """Train ``model`` on at most ``temp`` batches of ``loader``; return the mean loss.

    Gradients are accumulated over ``accum`` consecutive batches. Each batch
    loss is divided by ``accum`` so the accumulated gradient is a mean, as it
    would be for a single batch ``accum`` times larger. Gradient-norm clipping
    (when ``clip > 0``) is applied once to the fully accumulated gradient,
    right before each optimizer step.

    Args:
        args: run settings. Reads ``device`` and ``lr_sched_iter``, and also
            ``clip`` / ``accum`` when ``config`` does not provide them.
        model: module to train; switched to train mode.
        optimizer: torch optimizer over the model's trainable parameters.
        scheduler: LR scheduler. It is stepped after every batch when
            ``args.lr_sched_iter`` is truthy, otherwise once at epoch end.
        loader: iterable of ``(x, y)`` batches.
        loss: callable ``loss(out, y)`` returning a scalar tensor.
        temp: maximum number of batches to consume this epoch.
        label_smoothing_factor, decoder, transform: accepted for call
            compatibility; unused.
        config: optional mapping whose ``'clip'`` / ``'accum'`` entries
            override those on ``args``. If neither source provides them,
            clipping is disabled and ``accum`` is 1. This replaces the old
            read of an undefined global ``config``.

    Returns:
        Mean (unscaled) loss over the batches actually processed, or NaN if
        no batch was processed. The old code divided by ``temp`` even when
        fewer batches were available.
    """
    from itertools import islice

    def _setting(key, default):
        # Explicit config first, then args, then the default.
        if config is not None and key in config:
            return config[key]
        return getattr(args, key, default)

    clip = _setting('clip', -1)
    accum = max(int(_setting('accum', 1)), 1)

    def _optimizer_step():
        # Clip the accumulated gradient, apply it, then reset it.
        if clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        optimizer.zero_grad()

    model.train()
    optimizer.zero_grad()

    total_loss, n_batches, pending = 0.0, 0, 0
    # islice stops after `temp` batches without fetching an extra one.
    for x, y in islice(loader, max(int(temp), 0)):
        x, y = x.to(args.device), y.to(args.device)
        batch_loss = loss(model(x), y)
        (batch_loss / accum).backward()
        pending += 1

        if pending == accum:
            _optimizer_step()
            pending = 0

        if args.lr_sched_iter:
            scheduler.step()

        total_loss += batch_loss.item()
        n_batches += 1

    # Apply gradients from a final, incomplete accumulation group instead of
    # letting the next epoch's zero_grad() silently discard them.
    if pending:
        _optimizer_step()

    if not args.lr_sched_iter:
        scheduler.step()

    return total_loss / n_batches if n_batches else float('nan')


def evaluate(args, model, loader, loss, metric, n_eval, decoder=None, transform=None, fsd_epoch=None):
    """Evaluate ``model`` on ``loader``; return ``(mean_loss, mean_score)``.

    Predictions are buffered until at least ``args.eval_batch_size`` samples
    have been collected. ``loss`` and ``metric`` are then computed on the
    concatenated chunk, and the per-chunk values are averaged. A final
    partial chunk is always scored.

    Args:
        args: reads ``device`` and ``eval_batch_size``.
        model: module to evaluate; switched to eval mode (not restored).
        loader: iterable of ``(x, y)`` batches. ``len(loader)`` is not required.
        loss: callable ``loss(out, y)`` returning a scalar tensor.
        metric: callable ``metric(out, y)`` returning a scalar tensor.
        n_eval: accepted for call compatibility; ignored, because the number
            of chunks is counted here.
        decoder, transform, fsd_epoch: accepted for call compatibility; unused.

    Raises:
        ValueError: if ``loader`` yields no batches (previously a
            ZeroDivisionError).
    """
    model.eval()

    def _chunk_metrics(chunk_outs, chunk_ys):
        out_cat, y_cat = torch.cat(chunk_outs, 0), torch.cat(chunk_ys, 0)
        return loss(out_cat, y_cat).item(), metric(out_cat, y_cat).item()

    chunk_results = []
    outs, ys, n_buffered = [], [], 0

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(args.device), y.to(args.device)
            outs.append(model(x))
            ys.append(y)
            n_buffered += x.shape[0]

            if n_buffered >= args.eval_batch_size:
                chunk_results.append(_chunk_metrics(outs, ys))
                outs, ys, n_buffered = [], [], 0

        if outs:  # last, partially filled chunk
            chunk_results.append(_chunk_metrics(outs, ys))

    if not chunk_results:
        raise ValueError("evaluate(): loader produced no batches")

    n_chunks = len(chunk_results)
    eval_loss = sum(chunk_loss for chunk_loss, _ in chunk_results) / n_chunks
    eval_score = sum(chunk_score for _, chunk_score in chunk_results) / n_chunks
    return eval_loss, eval_score

In [22]:

# args = np.load('/home/wenduoc/ORCA/clean/gene-orca/results/' + dataset + '/all_' + exp_id + '/0/hparams.npy', allow_pickle=True).item()
# print(args)
# args= AttrDict(args)
# # Reference hparams from an earlier run. The W&B API key that was pasted here has been redacted; it was exposed and should be rotated. Read it from the WANDB_API_KEY environment variable instead of hard-coding it.
# # args=AttrDict({'dataset': 'H4', 'embedder_dataset': 'text', 'objective': 'MMD', 'weight': 'roberta-large', 'maxsamples': 256, 'target_seq_len': 128, 'experiment_id': 8, 'seed': 0, 'epochs': 20, 'embedder_epochs': 60, 'pretrain_epochs': 0, 'predictor_epochs': 0, 'joint_optim': True, 'alpha': 1, 'beta': 1, 'finetune_method': 'all', 'one_hot': False, 'lora_r': 12, 'lora_alpha': 32, 'lora_dropout': 0.1, 'lora_target_modules': ['q_proj', 'v_proj'], 'drop_out': 0, 'label_smoothing_factor': 0, 'activation': None, 'rc_aug': True, 'shift_aug': True, 'use_wandb': True, 'wandb_key': '<redacted>', 'data_parallel': False, 'quantize': False, 'embedder_type': 'unet', 'embedder_init': 'random', 'batch_size': 64, 'eval_batch_size': 1000, 'accum': 1, 'clip': 1, 'validation_freq': 1, 'optimizer': {'name': 'AdamW', 'params': {'lr': 5e-06, 'betas': [0.9, 0.98], 'weight_decay': 1e-06, 'momentum': 0.99}}, 'scheduler': {'name': 'WarmupLR', 'params': {'warmup_epochs': 5, 'decay_epochs': 20, 'sched': [30, 60, 90], 'base': 0.2}}, 'no_warmup_scheduler': {'name': 'StepLR', 'params': {'warmup_epochs': 10, 'decay_epochs': 100, 'sched': [40, 60, 80], 'base': 0.2}}, 'num_workers': 2, 'reproducibility': False, 'valid_split': False, 'device': 'cuda', 'infer_label': False, 'lr_sched_iter': True})


In [19]:
   
def _seed_everything(seed=0):
    """Seed Python, NumPy and torch (CPU and all GPUs), and force deterministic cuDNN."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Direct attribute assignment; the previous eval('setattr(...)') did the same thing indirectly.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _build_args():
    """Return the fixed (non-tuned) experiment settings shared by every trial."""
    # The module-level `from ml_collections import config_dict` in the first
    # import cell is commented out. Import it here instead of relying on
    # leftover kernel state.
    from ml_collections import config_dict

    args = config_dict.ConfigDict()
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    args.dataset = 'H4'
    args.weight = 'roberta'
    args.eval_batch_size = 1000
    args.maxsamples = 256
    # Was `args.target_seq_len: 128`, a bare annotation that assigned nothing.
    args.target_seq_len = 128

    args.validation_freq = 1
    args.finetune_method = 'all'

    # args.lora = config_dict.ConfigDict()
    # args.lora.target_modules = 'query value key dense reduction'
    # args.lora.layer_indices = False
    # args.lora.layers_pattern = False
    # args.lora.bias = 'none'
    # args.lora.rank = 8
    # args.lora.alpha = 16

    args.num_workers = 4
    args.valid_split = False
    args.epochs = 30

    args.scheduler = config_dict.ConfigDict()
    args.scheduler.name = 'WarmupLR'
    args.scheduler.params = config_dict.ConfigDict()
    args.scheduler.params.warmup_epochs = 5
    args.scheduler.params.decay_epochs = 60
    args.scheduler.params.sched = [20, 40, 60]
    args.scheduler.params.base = 0.2

    # These fields used to be written to args.scheduler by mistake, which
    # silently replaced the WarmupLR settings above with StepLR ones.
    args.no_warmup_scheduler = config_dict.ConfigDict()
    args.no_warmup_scheduler.name = 'StepLR'
    args.no_warmup_scheduler.params = config_dict.ConfigDict()
    args.no_warmup_scheduler.params.warmup_epochs = 5
    args.no_warmup_scheduler.params.decay_epochs = 30
    args.no_warmup_scheduler.params.sched = [30, 60, 90]
    args.no_warmup_scheduler.params.base = 0.2

    return args


def main(config):
    """Run one Ray Tune trial: fine-tune the wrapped model on ``args.dataset``.

    Args:
        config: hyper-parameters sampled by Tune. Reads ``'batch_size'``,
            ``'drop_out'``, the keys ``get_optimizer`` uses (``'optimizer'``,
            ``'lr'``, ``'weight_decay'``), and optionally ``'clip'`` /
            ``'accum'``.

    Reports ``{'val_score': ...}`` to Ray Tune after every validation epoch.
    """
    _seed_everything(0)
    args = _build_args()

    # NOTE(review): absolute, user-specific path; consider making it configurable.
    root = '/home/wenduoc/ORCA/clean/gene-orca/datasets'

    print('torch.cuda.is_available():', torch.cuda.is_available())
    print('device:', args.device)

    dims, sample_shape, num_classes, loss, args = get_config(root, args)

    # Gradient clipping / accumulation: take them from the trial config when
    # tuned; otherwise keep any value get_config() set, falling back to 1.
    args.clip = config.get('clip', getattr(args, 'clip', 1))
    args.accum = config.get('accum', getattr(args, 'accum', 1))

    args.embedder_epochs = 0

    # NOTE(review): args.activation and args.lora are not set in this notebook;
    # they are assumed to come from get_config(). Confirm.
    model = wrapper1D(sample_shape, num_classes, weight=args.weight,
                      train_epoch=args.embedder_epochs, activation=args.activation,
                      target_seq_len=args.target_seq_len, drop_out=config['drop_out'], lora=args.lora)
    model.output_raw = False
    model = model.to(args.device).train()
    print(model)

    train_loader, val_loader, test_loader, n_train, n_val, n_test, data_kwargs = get_data(
        root, args.dataset, config['batch_size'], args.valid_split,
        quantize=False, rc_aug=True, shift_aug=True, one_hot=False)
    metric, compare_metrics = get_metric(root, args.dataset)
    decoder = None
    transform = None

    # Make the whole model trainable, then apply the fine-tuning policy.
    set_grad_state(model, True)
    set_param_grad(args, model, args.finetune_method)

    optimizer = get_optimizer(config)(get_params_to_update(model))
    lr_lambda, args.lr_sched_iter = get_scheduler(args.scheduler.name, args.scheduler.params, args.epochs, n_train)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    train_losses, val_scores = [], []
    for ep in range(args.epochs):
        # Signature: train_one_epoch(args, model, optimizer, scheduler, loader, loss, temp, ...).
        # It returns a single mean loss. The old call passed `config` in the
        # `model` slot and unpacked three values. `temp` caps batches per
        # epoch; n_train is presumably the per-epoch batch count (it is also
        # passed to get_scheduler above).
        train_loss = train_one_epoch(args, model, optimizer, scheduler, train_loader, loss, n_train,
                                     decoder=decoder, transform=transform)

        if ep % args.validation_freq == 0 or ep == args.epochs - 1:
            val_loss, val_score = evaluate(args, model, val_loader, loss, metric, n_val, decoder, transform,
                                           fsd_epoch=ep if args.dataset == 'FSD' else None)

            train_losses.append(train_loss)
            val_scores.append(val_score)

            train.report({'val_score': val_score})

            print("[train full", ep, "]",
                  "\ttrain loss:", "%.4f" % train_loss, "\tval loss:", "%.4f" % val_loss,
                  "\tval score:", "%.4f" % val_score, "\tbest val score:", "%.4f" % compare_metrics(val_scores))

            

In [23]:
# Hyper-parameters sampled independently for each Tune trial; every other
# setting is fixed inside main().
search_space = {
    "lr": tune.choice([5e-3, 5e-4, 5e-5, 5e-6]),
    # "weight_decay": tune.choice([0, 1e-2, 1e-4]),
    # Not tuned, but get_optimizer() reads config['weight_decay'], so it must be present.
    "weight_decay": 0.0,
    'batch_size': tune.choice([16, 32, 64]),
    # 'accum': tune.choice([1,2]),
    # 'clip': tune.choice([-1, 1]),
    'drop_out': tune.choice([0, 0.05]),
    'optimizer': tune.choice(['Adam', 'AdamW']),
}

NUM_SAMPLES = 20  # number of trials drawn from search_space

# Request a GPU per trial only when this machine exposes one. A trial that
# asks for a non-existent GPU cannot be scheduled and never starts.
# NOTE(review): when running against a remote cluster (ray.init(address=...)),
# base this on the cluster's resources rather than the driver's.
TRIAL_RESOURCES = {"cpu": 8, "gpu": 1 if torch.cuda.is_available() else 0}

# Uncomment this to enable distributed execution
# `ray.init(address="auto")`

main_with_gpu = tune.with_resources(main, TRIAL_RESOURCES)
tuner = tune.Tuner(
    main_with_gpu,
    tune_config=tune.TuneConfig(
        num_samples=NUM_SAMPLES,
        scheduler=ASHAScheduler(metric="val_score", mode="min"),
    ),
    param_space=search_space,
)
results = tuner.fit()

best_result = results.get_best_result("val_score", "min")

print("Best trial config: {}".format(best_result.config))

# Per-trial metric histories from this tuner.fit() run, keyed by each trial's result path.
dfs = {result.path: result.metrics_dataframe for result in results}


KeyboardInterrupt



In [None]:
# Display the ResultGrid returned by tuner.fit() in the Tune cell.
results