### Subject Classifier

In [None]:
class SACLAdversary(torch.nn.Module):
    """
    Subject classifier ("adversary") applied to embedded representations.

    see  Figure 1 in arxiv.org/pdf/2007.04871.pdf for diagram

    The network is three Linear+ReLU stages of width ``embed_dim // 2``
    followed by a Linear projection to ``num_subjects`` outputs and a Sigmoid.

    Args:
        embed_dim: size of the input embedding vector.
        num_subjects: number of subject classes to score.
        dropout_rate: accepted for interface compatibility; no dropout layer
            is built from it in this class.
    """
    def __init__(self, embed_dim, num_subjects, dropout_rate=0.5):
        super().__init__()
        hidden_dim = embed_dim // 2

        # Layers are appended in the same order the state_dict expects
        # (model.0, model.2, model.4, model.6 hold the Linear weights).
        layers = []
        for in_features in (embed_dim, hidden_dim, hidden_dim):
            layers.append(torch.nn.Linear(in_features, hidden_dim))
            layers.append(torch.nn.ReLU())
        layers.append(torch.nn.Linear(hidden_dim, num_subjects))
        layers.append(torch.nn.Sigmoid())  # ADDED BY ZAC TO ADDRESS NANs IN ADVERSARIAL LOSS

        self.model = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Return the per-subject sigmoid scores for the embedding batch ``x``."""
        subject_scores = self.model(x)
        return subject_scores

In [None]:
# adversary -> subject classifier
# Build the subject classifier on the target device and, when a checkpoint path
# is given, restore its weights from that file.
adversary = SACLAdversary(embed_dim, num_subjects, dropout_rate=dropout_rate).to(device)
if former_adversary_state_dict_file is not None:
    # map_location remaps the stored tensors onto `device`, so a checkpoint
    # written on a GPU machine can still be restored on a CPU-only machine.
    adversary.load_state_dict(torch.load(former_adversary_state_dict_file, map_location=device))

### Contrastive Adversarial Loss

In [None]:
class SAContrastiveAdversarialLoss(nn.Module):
    """
    Contrastive loss on paired representations plus an adversarial term that
    penalises the encoder when the subject classifier identifies the subject.

    see Section 3.1 of arxiv.org/pdf/2007.04871.pdf

    Args:
        temperature: divisor (tau) applied to cosine similarities before exp.
        adversarial_weighting_factor: multiplier (lambda) on the adversarial term.
    """
    def __init__(self, temperature, adversarial_weighting_factor=1):
        super(SAContrastiveAdversarialLoss, self).__init__()
        self.BATCH_DIM = 0
        self.tau = temperature
        self.lam = adversarial_weighting_factor
        # Retained as a public attribute; its eps is reused by the batched
        # similarity computation below.
        self.cos_sim = torch.nn.CosineSimilarity(0)
        self.log_noise = 1e-12 # 8 # see https://stackoverflow.com/questions/40050397/deep-learning-nan-loss-reasons

    def _pairwise_cos_sims(self, z1s, z2s):
        """
        Return (sim12, sim11) where sim12[i, k] = cos(z1s[i], z2s[k]) and
        sim11[i, k] = cos(z1s[i], z1s[k]); both inputs are (batch, features).
        """
        eps = self.cos_sim.eps
        z1_unit = torch.nn.functional.normalize(z1s, p=2, dim=1, eps=eps)
        z2_unit = torch.nn.functional.normalize(z2s, p=2, dim=1, eps=eps)
        return z1_unit @ z2_unit.t(), z1_unit @ z1_unit.t()

    def forward(self, z1s, z2s, z1_c_outs, z1_subject_labels):
        """
        z1s represents the (batched) representation(s) of the t1-transformed signal(s)
        z2s represents the (batched) representation(s) of the t2-transformed signal(s)
        z1_c_outs represents the (batched) subject predictions produced by the adversary
        z1_subject_labels represents the (batched) subject labels, representing the ground truth for the adversary

        Returns a 0-dim tensor: the contrastive term summed over the t1 signals
        and divided by batch*(2*batch - 1), plus lam times the adversarial term
        summed (not averaged) over the batch.

        Only the t1-to-other pairings contribute; the t2-to-other pairings are
        skipped because z2s come from the momentum-updated network.

        Raises:
            ValueError: if the batch is empty.

        see Section 3.1 of arxiv.org/pdf/2007.04871.pdf
        """
        curr_batch_size = z1s.size(self.BATCH_DIM)
        if curr_batch_size == 0:
            raise ValueError("SAContrastiveAdversarialLoss.forward: received an empty batch")

        z1_c_outs = torch.nn.functional.normalize(z1_c_outs, p=2, dim=1) # see https://discuss.pytorch.org/t/how-to-normalize-embedding-vectors/1209

        # flatten everything after the batch dimension
        z1s = z1s.reshape(curr_batch_size, -1)
        z2s = z2s.reshape(curr_batch_size, -1)

        # All pairwise similarities in two matrix products instead of a
        # Python-level double loop over the batch.
        sim12, sim11 = self._pairwise_cos_sims(z1s, z2s)
        exp12 = torch.exp(sim12 / self.tau)
        # A t1 signal is never compared with itself: zero those entries out
        # (masking, rather than subtracting, avoids float cancellation).
        self_pairs = torch.eye(curr_batch_size, dtype=torch.bool, device=sim11.device)
        exp11 = torch.exp(sim11 / self.tau).masked_fill(self_pairs, 0.)

        numerators = exp12.diagonal()                       # t1 ith signal vs its own t2 signal
        denominators = exp12.sum(dim=1) + exp11.sum(dim=1)  # vs all t2 signals and all other t1 signals
        contrastive = -1.*torch.log(self.log_noise + numerators/denominators).sum()
        contrastive = contrastive / (curr_batch_size*(2.*curr_batch_size - 1.))

        # adversarial term: -log(1 - score assigned to the true subject), see equation 3 of arxiv.org/pdf/2007.04871.pdf
        true_subjects = torch.argmax(z1_subject_labels, dim=1)
        true_subject_scores = z1_c_outs.gather(1, true_subjects.unsqueeze(1)).squeeze(1)
        adversarial = -1.*torch.log(self.log_noise + (1. - true_subject_scores)).sum()

        return contrastive + self.lam * adversarial

    def get_number_of_correct_reps(self, z1s, z2s, z1_c_outs, z1_subject_labels):
        """
        Count the t1 representations considered "correct" in this batch.

        Representation i is correct when no other pairing (z1s[i] with any
        z2s[k], or z1s[i] with any other z1s[k]) has a strictly larger cosine
        similarity than (z1s[i], z2s[i]) AND the adversary's top prediction
        differs from the true subject label.

        Returns the count as a Python float. No gradients are tracked.
        """
        curr_batch_size = z1s.size(self.BATCH_DIM)
        if curr_batch_size == 0:
            return 0.

        with torch.no_grad():
            z1s = z1s.reshape(curr_batch_size, -1)
            z2s = z2s.reshape(curr_batch_size, -1)

            sim12, sim11 = self._pairwise_cos_sims(z1s, z2s)
            sims_of_interest = sim12.diagonal().unsqueeze(1)

            # any t2 signal more similar to t1 ith signal than its own partner?
            beaten_by_t2 = (sim12 > sims_of_interest).any(dim=1)
            # any *other* t1 signal more similar? (self-comparison is excluded)
            self_pairs = torch.eye(curr_batch_size, dtype=torch.bool, device=sim11.device)
            sim11_others = sim11.masked_fill(self_pairs, float("-inf"))
            beaten_by_t1 = (sim11_others > sims_of_interest).any(dim=1)

            # the representation only counts if the adversary got the subject wrong
            adversary_fooled = torch.argmax(z1_subject_labels, dim=1) != torch.argmax(z1_c_outs, dim=1)

            correct = (~beaten_by_t2) & (~beaten_by_t1) & adversary_fooled
            return float(correct.sum().item())

### How to calculate Adversarial Loss

In [None]:
class SAAdversarialLoss(nn.Module):
    """
    Loss used to train the subject classifier (adversary) itself.

    see Section 3.1 of arxiv.org/pdf/2007.04871.pdf
    """
    def __init__(self):
        super(SAAdversarialLoss, self).__init__()
        self.BATCH_DIM = 0
        self.log_noise = 1e-12 # 8 # see https://stackoverflow.com/questions/40050397/deep-learning-nan-loss-reasons

    def forward(self, z1_c_outs, z1_subject_labels):
        """
        z1_c_outs represents the (batched) subject predictions produced by the adversary
        z1_subject_labels represents the (batched) subject labels, representing the ground truth for the adversary

        Returns a 0-dim tensor: the sum over the batch of -log(score assigned
        to the true subject), where the scores are first L2-normalised per row
        and the true subject is the argmax of the corresponding label row.

        see Section 3.1 of arxiv.org/pdf/2007.04871.pdf
        """
        z1_c_outs = torch.nn.functional.normalize(z1_c_outs, p=2, dim=1) # see https://discuss.pytorch.org/t/how-to-normalize-embedding-vectors/1209

        # pick, for every sample, the score at the index of its true subject
        true_subjects = torch.argmax(z1_subject_labels, dim=1)
        true_subject_scores = z1_c_outs.gather(1, true_subjects.unsqueeze(1)).squeeze(1)

        # see equation 3 of arxiv.org/pdf/2007.04871.pdf
        return -1.*torch.log(self.log_noise + true_subject_scores).sum()

    def get_number_of_correct_preds(self, z1_c_outs, z1_subject_labels):
        """
        Count the samples whose top-scoring predicted subject matches the
        argmax of the label row.

        Returns the count as a Python float (same convention as
        SAContrastiveAdversarialLoss.get_number_of_correct_reps). No gradients
        are tracked.
        """
        with torch.no_grad():
            predicted_subjects = torch.argmax(z1_c_outs, dim=1)
            true_subjects = torch.argmax(z1_subject_labels, dim=1)
            return float((predicted_subjects == true_subjects).sum().item())

### How to utilize adversarial loss in the training

In [None]:
def train_SA_model(save_dir_for_model, model_file_name="final_SA_model.bin", batch_size=256, shuffle=True, # hyper parameters for training loop
                    max_epochs=100, learning_rate=5e-4, beta_vals=(0.9, 0.999), weight_decay=0.001, #num_workers=4, 
                    max_evals_after_saving=6, save_freq=20, former_state_dict_file=None, ct_dim=None, h_dim=None, 
                    channels=11, temporal_len=3000, dropout_rate=0.5, embed_dim=100, encoder_type=None, bw=5, # hyper parameters for SA Model
                    randomized_augmentation=False, num_upstream_decode_features=32, temperature=0.05, NUM_AUGMENTATIONS=2, perturb_orig_signal=True, former_adversary_state_dict_file=None, adversarial_weighting_factor=1., momentum=0.999, # hyper parameters for SA Model
                    cached_datasets_list_dir=None, total_points_val=2000, tpos_val=None, tneg_val=None, window_size=3, #hyper parameters for data loaders
                    sfreq=1000, Nc=None, Np=None, Nb=None, max_Nb_iters=None, total_points_factor=None, 
                    windowed_data_name="_Windowed_Pretext_Preprocess.npy", 
                    windowed_start_time_name="_Windowed_StartTime.npy", data_folder_name="Mouse_Training_Data", 
                    data_root_name="Windowed_Data", file_names_list="training_names.txt", train_portion=0.7, 
                    val_portion=0.2, test_portion=0.1, random_seed=0):
    """
    Train the SA model together with its momentum copy and the subject
    classifier (adversary), then write the models, the training curves and the
    hyper parameters to ``save_dir_for_model``.

    Each training batch performs three steps:
      1. update the adversary on embeddings from ``model.embed_model`` using
         SAAdversarialLoss (model and momentum model frozen);
      2. update the model using SAContrastiveAdversarialLoss (adversary and
         momentum model frozen);
      3. refresh the momentum model via ``momentum_model_parameter_update``.

    After every epoch the validation accuracy is computed; the weights from
    the epoch with the lowest validation inaccuracy are kept, and training
    stops early after ``max_evals_after_saving`` epochs without improvement.
    Every ``save_freq`` epochs intermediate state_dicts and plots are saved.

    Args:
        save_dir_for_model: directory receiving every output file (created if
            it does not exist).
        model_file_name: file name of the final model state_dict; the momentum
            model, adversary and embedder are saved under the same name with
            the prefixes "momentum_model_", "adversary_" and "embedder_".
        batch_size, shuffle: passed to both DataLoaders.
        max_epochs, learning_rate, beta_vals, weight_decay: training loop /
            Adam settings (shared by the model and adversary optimizers).
        max_evals_after_saving, save_freq: early-stopping patience and
            intermediate-save period, both in epochs.
        former_state_dict_file, former_adversary_state_dict_file: optional
            checkpoints to initialise the model / adversary from.
        ct_dim, h_dim, encoder_type: not used by the training itself; only
            recorded in the saved meta data.
        temperature, adversarial_weighting_factor, momentum: loss temperature,
            adversarial loss weight, and momentum-update coefficient.
        remaining arguments: forwarded to load_SSL_Dataset_Based_On_Subjects
            and/or SACLNet and recorded in the saved meta data.

    Returns:
        None. Results are written to disk.
    """
    # make sure there is somewhere to write checkpoints before hours of training
    if save_dir_for_model:
        os.makedirs(save_dir_for_model, exist_ok=True)

    # First, load the training, validation, and test sets
    train_set, val_set, test_set = load_SSL_Dataset_Based_On_Subjects('SA', 
                                                    cached_datasets_list_dir=cached_datasets_list_dir, 
                                                    total_points_val=total_points_val, 
                                                    tpos_val=tpos_val, 
                                                    tneg_val=tneg_val, 
                                                    window_size=window_size, 
                                                    sfreq=sfreq, 
                                                    Nc=Nc, 
                                                    Np=Np, 
                                                    Nb=Nb, # this used to be 2 not 4, but 4 would work better
                                                    max_Nb_iters=max_Nb_iters, 
                                                    total_points_factor=total_points_factor, 
                                                    bw=bw,                                              # items for SA data loading
                                                    randomized_augmentation=randomized_augmentation,    # items for SA data loading
                                                    num_channels=channels,                              # items for SA data loading
                                                    temporal_len=temporal_len,                          # items for SA data loading
                                                    NUM_AUGMENTATIONS=NUM_AUGMENTATIONS,                # items for SA data loading
                                                    perturb_orig_signal=perturb_orig_signal,            # items for SA data loading
                                                    windowed_data_name=windowed_data_name,
                                                    windowed_start_time_name=windowed_start_time_name,
                                                    data_folder_name=data_folder_name, 
                                                    data_root_name=data_root_name, 
                                                    file_names_list=file_names_list, 
                                                    train_portion=train_portion, 
                                                    val_portion=val_portion, 
                                                    test_portion=test_portion, 
                                                    random_seed=random_seed
    )

    # initialize data loaders for training
    train_loader = torch.utils.data.DataLoader(train_set, 
                                                batch_size=batch_size, 
                                                shuffle=shuffle#, num_workers=num_workers # see https://www.programmersought.com/article/93393550792/
    )
    val_loader = torch.utils.data.DataLoader(val_set, 
                                             batch_size=batch_size, 
                                             shuffle=shuffle#, num_workers=num_workers
    )

    # cuda setup if allowed
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Pytorch v0.4.0

    # initialize models - see Figure 1 of arxiv.org/pdf/2007.04871.pdf
    model = SACLNet(channels, temporal_len, dropout_rate=dropout_rate, embed_dim=embed_dim, num_upstream_decode_features=num_upstream_decode_features)
    if former_state_dict_file is not None:
        # map_location keeps GPU-saved checkpoints loadable on CPU-only hosts
        model.load_state_dict(torch.load(former_state_dict_file, map_location=device))
    momentum_model = copy.deepcopy(model) # see https://discuss.pytorch.org/t/copying-weights-from-one-net-to-another/1492 and https://www.geeksforgeeks.org/copy-python-deep-copy-shallow-copy/
    model = model.to(device)
    momentum_model = momentum_model.to(device)

    # the number of subjects is read off the label width of the first training batch
    _, _, y0 = next(iter(train_loader))
    assert len(y0.shape) == 2 
    num_subjects = y0.shape[1]
    adversary = SACLAdversary(embed_dim, num_subjects, dropout_rate=dropout_rate).to(device)
    if former_adversary_state_dict_file is not None:
        adversary.load_state_dict(torch.load(former_adversary_state_dict_file, map_location=device))

    print("train_SA_model: START OF TRAINING")
    # initialize training state
    min_val_inaccuracy = float("inf")
    num_evaluations_since_model_saved = 0
    saved_model = None
    saved_momentum_model = None
    saved_adversary = None

    loss_fn = SAContrastiveAdversarialLoss(temperature, adversarial_weighting_factor=adversarial_weighting_factor)
    optimizer = torch.optim.Adam(model.parameters(), betas=beta_vals, lr=learning_rate, weight_decay=weight_decay)

    adversarial_loss_fn = SAAdversarialLoss()
    adversarial_optimizer = torch.optim.Adam(adversary.parameters(), betas=beta_vals, lr=learning_rate, weight_decay=weight_decay)

    # Iterate over epochs
    avg_train_losses = []
    avg_train_accs = []
    avg_val_accs = []
    avg_adversary_train_losses = []
    avg_adversary_train_accs = []
    avg_adversary_val_accs = []
    for epoch in range(max_epochs):
        model.train()
        momentum_model.train()
        adversary.train()

        running_train_loss = 0
        num_correct_train_preds = 0
        total_num_train_preds = 0
        running_adversary_train_loss = 0
        num_adversary_correct_train_preds = 0
        total_num_adversary_train_preds = 0
        
        # iterate over training batches
        for x_t1, x_t2, y in train_loader:
            # transfer to GPU
            x_t1, x_t2, y = x_t1.to(device), x_t2.to(device), y.to(device)

            ##############################################################################################################################
            # UPDATE ADVERSARY - only the adversary's parameters receive gradients
            for p in model.parameters():
                p.requires_grad = False
            for p in momentum_model.parameters():
                p.requires_grad = False
            for p in adversary.parameters():
                p.requires_grad = True
            
            adversarial_optimizer.zero_grad()
            
            x_t1_initial_reps = model.embed_model(x_t1) # x_t1_initial_reps -> hidden vector
            x_t1_initial_subject_preds = adversary(x_t1_initial_reps) # Input hidden vector in subject classifier

            adversarial_loss = adversarial_loss_fn(x_t1_initial_subject_preds, y) # SAAdversarialLoss(), y is a subject-class label
            num_adversary_correct_train_preds += adversarial_loss_fn.get_number_of_correct_preds(x_t1_initial_subject_preds, y)
            total_num_adversary_train_preds += len(x_t1_initial_subject_preds)

            adversarial_loss.backward()
            adversarial_optimizer.step()

            running_adversary_train_loss += adversarial_loss.item()

            del x_t1_initial_reps
            del x_t1_initial_subject_preds
            del adversarial_loss
            torch.cuda.empty_cache()
            ##############################################################################################################################
        
            # UPDATE MODEL - references Algorithm 1 of arxiv.org/pdf/1911.05722.pdf and Figure 1 of arxiv.org/pdf/2007.04871.pdf
            for p in model.parameters():
                p.requires_grad = True
            for p in momentum_model.parameters():
                p.requires_grad = False
            for p in adversary.parameters():
                p.requires_grad = False

            # zero out any pre-existing gradients
            optimizer.zero_grad()

            # make prediction and compute resulting loss
            x1_rep = model(x_t1)
            x2_rep = momentum_model(x_t2)
            x1_embeds = model.embed_model(x_t1)
            x1_subject_preds = adversary(x1_embeds)
            loss = loss_fn(x1_rep, x2_rep, x1_subject_preds, y) # SAContrastiveAdversarialLoss.forward(z1s, z2s, z1_c_outs, z1_subject_labels)

            # compute accuracy
            num_correct_train_preds += loss_fn.get_number_of_correct_reps(x1_rep, x2_rep, x1_subject_preds, y)
            total_num_train_preds += len(x1_rep)

            # update weights
            loss.backward()
            optimizer.step()

            # track loss
            running_train_loss += loss.item()

            # UPDATE MOMENTUM MODEL
            # momentum_model.parameters = momentum*momentum_model.parameters + (1.-momentum)*model.parameters
            momentum_model = momentum_model_parameter_update(momentum, momentum_model, model)

            # free up cuda memory
            del x_t1
            del x_t2
            del x1_rep
            del x2_rep
            del x1_embeds
            del x1_subject_preds
            del loss
            torch.cuda.empty_cache()
        
        # iterate over validation batches
        num_correct_val_preds = 0
        total_num_val_preds = 0
        num_correct_adversarial_val_preds = 0
        total_num_adversarial_val_preds = 0
        with torch.no_grad():
            model.eval()
            momentum_model.eval()
            adversary.eval()

            for x_t1, x_t2, y in val_loader:
                x_t1, x_t2, y = x_t1.to(device), x_t2.to(device), y.to(device)

                # evaluate model and adversary
                x1_rep = model(x_t1)
                x2_rep = momentum_model(x_t2)
                x1_embeds = model.embed_model(x_t1)
                x1_subject_preds = adversary(x1_embeds)

                num_correct_val_preds += loss_fn.get_number_of_correct_reps(x1_rep, x2_rep, x1_subject_preds, y)
                total_num_val_preds += len(x1_rep)

                num_correct_adversarial_val_preds += adversarial_loss_fn.get_number_of_correct_preds(x1_subject_preds, y)
                total_num_adversarial_val_preds += len(x1_subject_preds)

                # free up cuda memory
                del x_t1
                del x_t2
                del x1_rep
                del x2_rep
                del x1_embeds
                del x1_subject_preds
                torch.cuda.empty_cache()
        
        # record averages
        avg_train_accs.append(num_correct_train_preds / total_num_train_preds)
        avg_val_accs.append(num_correct_val_preds / total_num_val_preds)
        avg_train_losses.append(running_train_loss / len(train_loader))
        
        avg_adversary_train_accs.append(num_adversary_correct_train_preds / total_num_adversary_train_preds)
        avg_adversary_val_accs.append(num_correct_adversarial_val_preds / total_num_adversarial_val_preds)
        avg_adversary_train_losses.append(running_adversary_train_loss / len(train_loader))
        
        # check stopping criterion / save model
        incorrect_val_percentage = 1. - (num_correct_val_preds / total_num_val_preds)
        if incorrect_val_percentage < min_val_inaccuracy:
            num_evaluations_since_model_saved = 0
            min_val_inaccuracy = incorrect_val_percentage
            # state_dict() hands back references to the live parameter tensors,
            # which later optimizer steps overwrite in place; deep-copy so the
            # snapshot really is the best epoch's weights.
            saved_model = copy.deepcopy(model.state_dict())
            saved_momentum_model = copy.deepcopy(momentum_model.state_dict())
            saved_adversary = copy.deepcopy(adversary.state_dict())
        else:
            num_evaluations_since_model_saved += 1
            if num_evaluations_since_model_saved >= max_evals_after_saving:
                print("train_SA_model: EARLY STOPPING on epoch ", epoch)
                break
        
        # save intermediate state_dicts just in case
        if epoch % save_freq == 0:
            temp_model_save_path = os.path.join(save_dir_for_model, "temp_full_SA_model_epoch"+str(epoch)+".bin")
            torch.save(model.state_dict(), temp_model_save_path)
            
            temp_model_save_path = os.path.join(save_dir_for_model, "temp_full_SA_momentum_model_epoch"+str(epoch)+".bin")
            torch.save(momentum_model.state_dict(), temp_model_save_path)
            
            temp_model_save_path = os.path.join(save_dir_for_model, "temp_full_SA_adversary_epoch"+str(epoch)+".bin")
            torch.save(adversary.state_dict(), temp_model_save_path)

            embedder_save_path = os.path.join(save_dir_for_model, "temp_embedder_epoch"+str(epoch)+".bin")
            torch.save(model.embed_model.state_dict(), embedder_save_path)

            plot_avgs(avg_train_losses, avg_train_accs, avg_val_accs, "model_epoch"+str(epoch), save_dir_for_model)
            plot_avgs(avg_adversary_train_losses, avg_adversary_train_accs, avg_adversary_val_accs, "adversary_epoch"+str(epoch), save_dir_for_model)

    print("train_SA_model: END OF TRAINING - now saving final model / other info")

    # save final model(s): restore the best snapshot of each network (if any
    # epoch ran) and write each network's OWN state_dict to its file
    if saved_model is not None:
        model.load_state_dict(saved_model)
    model_save_path = os.path.join(save_dir_for_model, model_file_name)
    torch.save(model.state_dict(), model_save_path)

    if saved_momentum_model is not None:
        momentum_model.load_state_dict(saved_momentum_model)
    model_save_path = os.path.join(save_dir_for_model, "momentum_model_"+model_file_name)
    torch.save(momentum_model.state_dict(), model_save_path)

    if saved_adversary is not None:
        adversary.load_state_dict(saved_adversary)
    model_save_path = os.path.join(save_dir_for_model, "adversary_"+model_file_name)
    torch.save(adversary.state_dict(), model_save_path)

    embedder_save_path = os.path.join(save_dir_for_model, "embedder_"+model_file_name)
    torch.save(model.embed_model.state_dict(), embedder_save_path)

    meta_data_save_path = os.path.join(save_dir_for_model, "meta_data_and_hyper_parameters.pkl")
    with open(meta_data_save_path, 'wb') as outfile:
        pkl.dump({
            "avg_train_losses": avg_train_losses, 
            "avg_train_accs": avg_train_accs, 
            "avg_val_accs": avg_val_accs, 
            "avg_adversary_train_losses": avg_adversary_train_losses, 
            "avg_adversary_train_accs": avg_adversary_train_accs, 
            "avg_adversary_val_accs": avg_adversary_val_accs, 
            "save_dir_for_model": save_dir_for_model, 
            "model_file_name": model_file_name, 
            "batch_size": batch_size, 
            "shuffle": shuffle, #"num_workers": num_workers, 
            "max_epochs": max_epochs, 
            "learning_rate": learning_rate, 
            "beta_vals": beta_vals, 
            "weight_decay": weight_decay, 
            "max_evals_after_saving": max_evals_after_saving, 
            "save_freq": save_freq, 
            "former_state_dict_file": former_state_dict_file, 
            "ct_dim": ct_dim, 
            "h_dim": h_dim, 
            "channels": channels, 
            "temporal_len": temporal_len, 
            "dropout_rate": dropout_rate, 
            "embed_dim": embed_dim,
            "encoder_type": encoder_type, 
            "bw": bw, 
            "randomized_augmentation": randomized_augmentation, 
            "num_upstream_decode_features": num_upstream_decode_features, 
            "temperature": temperature, 
            "NUM_AUGMENTATIONS": NUM_AUGMENTATIONS, 
            "perturb_orig_signal": perturb_orig_signal, 
            "former_adversary_state_dict_file": former_adversary_state_dict_file, 
            "adversarial_weighting_factor": adversarial_weighting_factor, 
            "momentum": momentum, 
            "cached_datasets_list_dir": cached_datasets_list_dir, 
            "total_points_val": total_points_val, 
            "tpos_val": tpos_val, 
            "tneg_val": tneg_val, 
            "window_size": window_size,
            "sfreq": sfreq, 
            "Nc": Nc, 
            "Np": Np, 
            "Nb": Nb,
            "max_Nb_iters": max_Nb_iters, 
            "total_points_factor": total_points_factor, 
            "windowed_data_name": windowed_data_name,
            "windowed_start_time_name": windowed_start_time_name,
            "data_folder_name": data_folder_name, 
            "data_root_name": data_root_name, 
            "file_names_list": file_names_list, 
            "train_portion": train_portion, 
            "val_portion": val_portion, 
            "test_portion": test_portion, 
            "random_seed": random_seed, 
        }, outfile)

    plot_avgs(avg_train_losses, avg_train_accs, avg_val_accs, "Final_Model", save_dir_for_model)
    plot_avgs(avg_adversary_train_losses, avg_adversary_train_accs, avg_adversary_val_accs, "Final_Adversary", save_dir_for_model)
    
    print("train_SA_model: DONE!")