Commit ff4fd8a (1 parent: c4676d3). Showing 2 changed files with 196 additions and 1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
args.py
@@ -0,0 +1,196 @@
import os
import torch
import torch.nn.functional as F
import numpy as np
from transformers import GPT2Config

from transformers_quantized import TransformerQuantizedPretrained
from cichy_data import CichyData, CichyContData, CichyQuantized, CichyQuantizedGauss, CichyQuantizedAR


class Args:
    gpu = '1'  # cuda gpu index
    func = {'train': True}  # dict of functions to run from training.py

    def __init__(self):
        n = 1  # can be used to do multiple runs, e.g. over subjects

        # experiment arguments
        self.name = 'args.py'  # name of this file, don't change
        self.fix_seed = True
        self.common_dataset = False
        self.load_dataset = True  # whether to load self.dataset
        self.learning_rate = 0.0001  # learning rate for Adam
        self.max_trials = 1.0  # ratio of training data (1=max)
        self.val_max_trials = False
        self.batch_size = 2  # batch size for training and validation data
        self.epochs = 1000  # number of loops over training data
        self.val_freq = 10  # how often to validate (in epochs)
        self.print_freq = 2  # how often to print metrics (in epochs)
        self.anneal_lr = False  # whether to anneal learning rate
        self.save_curves = True  # whether to save loss curves to file
        self.load_model = False
        self.result_dir = [os.path.join(
            '/',
            'well',
            'woolrich',
            'users',
            'yaq921',
            'MEG-transfer-decoding',  # path(s) to save model and others
            'results',
            'cichy_epoched',
            'subj1',
            'cont_quantized',
            'gpt2_50hz100hz',
            'concat_output')]
        self.model = TransformerQuantizedPretrained  # class of model to use
        self.dataset = CichyQuantized  # dataset class for loading and handling data

        # wavenet arguments
        self.activation = torch.nn.Identity()  # activation function for models
        self.subjects = 0  # number of subjects used for training
        self.embedding_dim = 0  # subject embedding size
        self.p_drop = 0.0  # dropout probability
        self.ch_mult = 2  # channel multiplier for hidden channels in wavenet
        self.groups = 306
        self.kernel_size = 2  # convolutional kernel size
        self.timesteps = 1  # how many timesteps in the future to forecast
        self.sample_rate = [0, 256]  # start and end of timesteps within trials
        self.rf = 128  # receptive field of wavenet is 2*rf - 1
        rf = 128
        ks = self.kernel_size
        nl = int(np.log(rf) / np.log(ks))  # number of layers per dilation stack
        dilations = [ks**i for i in range(nl)]
        self.dilations = dilations + dilations  # dilation: 2^num_layers
        # self.dilations = [1] + [2] + [4] * 7  # custom dilations
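        # NB: with rf=128 and ks=2 there are nl = 7 layers per stack
        # (log2(128) = 7), so each stack uses dilations [1, 2, 4, 8, 16, 32, 64].
        # Assuming the standard WaveNet stacking, the receptive field of the
        # stacked dilated causal convolutions is 1 + (ks - 1) * sum(dilations)
        # over both stacks, i.e. 1 + 1 * 254 = 255 = 2*rf - 1, which matches
        # the comment on self.rf above.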

        # classifier arguments
        self.wavenet_class = None  # class of wavenet model
        self.load_conv = False  # where to load neural network
                                # dimensionality reduction from
        self.pred = False  # whether to use wavenet in prediction mode
        self.init_model = True  # whether to reinitialize classifier
        self.reg_semb = True  # whether to regularize subject embedding
        self.fixed_wavenet = False  # whether to fix weights of wavenet
        self.alpha_norm = 0.0  # regularization multiplier on weights
        self.num_classes = 119  # number of classes for classification
        self.units = [800, 300]  # hidden layer sizes of fully-connected block
        self.dim_red = 16  # number of pca components for channel reduction
        self.stft_freq = 0  # STFT frequency index for LDA_wavelet_freq model
        self.decode_peak = 0.1

        # GPT2 arguments
        n_embd = 768
        self.gpt2_config = GPT2Config(
            vocab_size=50257,
            n_positions=1024,
            n_embd=n_embd,
            n_layer=12,
            n_head=12,
            resid_pdrop=0.1,
            embd_pdrop=0.1,
            attn_pdrop=0.1,
            use_cache=False
        )
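        # NB: these hyperparameters (vocab 50257, 1024 positions, 12 layers,
        # 12 heads, n_embd=768) match the published 'gpt2' (small) checkpoint,
        # which presumably lets TransformerQuantizedPretrained load its
        # pretrained weights directly.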

        # quantized wavenet arguments
        self.skips_shift = 1
        self.mu = 255
        self.residual_channels = 128
        self.dilation_channels = 128
        self.skip_channels = 512
        self.channel_emb = n_embd
        self.class_emb = n_embd
        self.quant_emb = n_embd
        self.pos_emb = n_embd
        self.cond_channels = self.class_emb + self.embedding_dim
        self.head_channels = 256
        self.conv_bias = False

        # dataset arguments
        data_path = os.path.join('/', 'gpfs2', 'well', 'woolrich', 'projects',
                                 'cichy118_cont', 'preproc_data_osl', 'subj1')
        self.data_path = [[os.path.join(data_path, 'subj1_50hz.npy')]]  # path(s) to data directory
        self.num_channels = list(range(614))  # channel indices
        self.numpy = True  # whether data is saved in numpy format
        self.crop = 1  # cropping ratio for trials
        self.whiten = False  # pca components used in whitening
        self.group_whiten = False  # whether to perform whitening at the group level
        self.split = np.array([0, 0.1])  # validation split (start, end)
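        # NB: split = [0, 0.1] marks the first 10% of the data as the
        # validation set; the remainder is presumably used for training.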
        self.sr_data = 100  # sampling rate used for downsampling
        self.original_sr = 1000
        self.save_data = True  # whether to save the created data
        self.bypass = False
        self.subjects_data = False  # list of subject inds to use in group data
        self.save_whiten = False
        self.num_clip = 4
        self.dump_data = [os.path.join(data_path, '50hz100hz_quantized_clamp4')]  # path(s) for dumping data
        self.load_data = self.dump_data  # path(s) for loading data files

        # analysis arguments
        self.closest_chs = 20  # channel neighbourhood size for spatial PFI
        self.PFI_inverse = False  # invert which channels/timesteps to shuffle
        self.pfich_timesteps = [0, 256]  # time window for spatiotemporal PFI
        self.PFI_perms = 20  # number of PFI permutations
        self.halfwin = 5  # half window size for temporal PFI
        self.halfwin_uneven = False  # whether to use even or uneven window
        self.generate_noise = 1  # noise used for wavenet generation
        self.generate_length = self.sr_data * 1000  # generated timeseries len
        self.generate_mode = 'recursive'  # IIR or FIR mode for wavenet generation
        self.generate_input = 'data'  # input type for generation
        self.generate_sampling = 'top-p'
        self.top_p = 0.8
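        # NB: 'top-p' (nucleus) sampling draws from the smallest set of
        # tokens whose cumulative probability exceeds top_p (0.8 here).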
        self.individual = True  # whether to analyse individual kernels
        self.anal_lr = 0.001  # learning rate for input backpropagation
        self.anal_epochs = 200  # number of epochs for input backpropagation
        self.norm_coeff = 0.0001  # L2 of input for input backpropagation
        self.kernel_limit = 300  # max number of kernels to analyse

        # simulation arguments
        self.nonlinear_prenoise = True
        self.nonlinear_data = True
        self.seconds = 3000
        self.events = 8
        self.sim_num_channels = 1
        self.sim_ar_order = 2
        self.gamma_shape = 14
        self.gamma_scale = 14
        self.noise_std = 2.5
        self.lambda_exp = 0.005
        self.ar_shrink = 1.0
        self.freqs = []
        self.ar_noise_std = np.random.rand(self.events) / 5 + 0.8
        self.max_len = 1000

        # AR model arguments
        self.order = 20
        self.uni = True
        self.save_AR = True
        self.do_anal = False
        self.AR_load_path = [os.path.join(  # path(s) to save model and others
            'results',
            'cichy_epoched',
            'subj1',
            'cont_quantized',
            'AR_uni')]

        # unused
        self.num_plot = 1
        self.plot_ch = 1
        self.linear = False
        self.num_samples_CPC = 20
        self.dropout2d_bad = False
        self.k_CPC = 1
        self.conv1x1_groups = 1
        self.pos_enc_type = 'cat'
        self.pos_enc_d = 128
        self.l1_loss = False
        self.norm_alpha = self.alpha_norm
        self.num_components = 0
        self.resample = 7
        self.save_norm = True
        self.norm_path = os.path.join(data_path, 'norm_coeff')
        self.pca_path = os.path.join(data_path, 'pca128_model')
        self.load_pca = False
        self.compare_model = False
        self.channel_idx = 0
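
The setting self.mu = 255 in the quantized wavenet block suggests mu-law companding of the continuous signal into 256 discrete levels, as in the original WaveNet. Below is a minimal standalone sketch of such a codec; mu_law_encode and mu_law_decode are illustrative names, not functions from this repository, and the actual quantization lives in the CichyQuantized dataset class and may differ.

import numpy as np

def mu_law_encode(x, mu=255):
    """Map a signal in [-1, 1] to integer bins 0..mu (mu-law companding)."""
    y = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)  # compress to [-1, 1]
    return np.floor((y + 1) / 2 * mu + 0.5).astype(np.int64)  # discretize to 0..mu

def mu_law_decode(q, mu=255):
    """Invert the companding: integer bins 0..mu back to [-1, 1]."""
    y = 2 * q.astype(np.float64) / mu - 1                     # back to [-1, 1]
    return np.sign(y) * ((1 + mu) ** np.abs(y) - 1) / mu      # expand

x = np.clip(np.random.randn(1000) / 4, -1, 1)    # toy standardized signal
q = mu_law_encode(x)                              # integers in [0, 255]
x_hat = mu_law_decode(q)
print(q.min(), q.max(), np.abs(x - x_hat).max())  # reconstruction error is small

Likewise, given func = {'train': True} ("dict of functions to run from training.py"), a driver might dispatch on the dict keys; this is a hypothetical sketch, and the training module name and function signatures are assumptions.

# hypothetical driver; assumes training.py exposes functions keyed by name
from args import Args
import training

args = Args()
for name, should_run in args.func.items():
    if should_run:
        getattr(training, name)(args)  # e.g. training.train(args)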