In [1]:
# -*-coding:utf-8-*-
import torch
from torch.autograd import grad
import os
import math
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm

import gc
# Allocator tuning must be in the environment before the first CUDA allocation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
# Debug aid: report the op that produced NaN/inf in backward (slows training).
torch.autograd.set_detect_anomaly(True)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Index into model.parameters() whose gradient SAL differentiates w.r.t. inputs.
grad_layer = 4

def pretty(vector):
    """Render a list, NumPy array or torch tensor as a flat "[+x.xxxx, ...]" string.

    Values are flattened and printed with an explicit sign and four decimals.
    """
    kind = type(vector)
    if kind is list:
        values = vector
    elif kind is np.ndarray:
        values = vector.reshape(-1).tolist()
    else:
        # Anything else is assumed to be a torch tensor (uses .view).
        values = vector.view(-1).tolist()

    formatted = ["{:+.4f}".format(value) for value in values]
    return "[" + ", ".join(formatted) + "]"


# Optimizer Class to maximize loss of adversarial dataset
class Adam:
    """Minimal Adam-style optimizer that performs gradient *ascent* on a tensor.

    ``StableAL.attack`` builds a new instance per call and uses it to move the
    adversarial inputs in the direction that increases its objective.
    """

    def __init__(self, learning_rate=1e-3, beta1=0.9, beta2=0.9, epsilon=1e-8):
        self.device = torch.device(device)
        self.lr, self.beta1, self.beta2, self.epsilon = learning_rate, beta1, beta2, epsilon
        # Raw and bias-corrected moment estimates; filled in by the first update().
        self.m = self.v = None
        self.m_hat = self.v_hat = None
        self.initialize = False

    def update(self, grad, iternum, theta):
        """Return ``theta`` moved one Adam step *up* the gradient.

        Args:
            grad: gradient of the objective w.r.t. ``theta`` (e.g. adv_gradient).
            iternum: 1-based step counter used for bias correction; 0 would give a
                zero denominator.
            theta: current value (e.g. adv_images); it is not modified in place.
        """
        grad_sq = grad ** 2
        if self.initialize:
            assert self.m.shape == grad.shape
            self.m = self.beta1 * self.m + (1 - self.beta1) * grad
            self.v = self.beta2 * self.v + (1 - self.beta2) * grad_sq
        else:
            # First step: the zero-initialised moments collapse to a single term.
            self.m = (1 - self.beta1) * grad
            self.v = (1 - self.beta2) * grad_sq
            self.initialize = True

        self.m_hat = self.m / (1 - self.beta1 ** iternum)
        self.v_hat = self.v / (1 - self.beta2 ** iternum)
        step = self.lr * self.m_hat / (self.epsilon + torch.sqrt(self.v_hat))
        return theta + step


class IRNet_intorch(torch.nn.Module):
    """Residual MLP regressor with stage widths 128 -> 64 -> 16 and a scalar head.

    Each stage computes ``x + ReLU(BN(Linear(x)))`` and is followed by a
    width-reducing projection. Submodules are registered in a fixed order: the
    global ``grad_layer`` index into ``parameters()`` relies on it (index 4 is
    ``fc16.weight``), so do not reorder the assignments below.
    """

    def __init__(self, input_size):
        super().__init__()
        # Square linear layers inside each residual stage.
        self.fc128 = nn.Linear(128, 128)
        self.fc64 = nn.Linear(64, 64)
        self.fc16 = nn.Linear(16, 16)
        # Batch norms applied right after those linear layers.
        self.bn128 = nn.BatchNorm1d(128)
        self.bn64 = nn.BatchNorm1d(64)
        self.bn16 = nn.BatchNorm1d(16)

        self.relu = nn.ReLU()
        self.inputlayer = nn.Linear(input_size, 128)
        # Projections between stages and the final scalar output.
        self.con128_64 = nn.Linear(128, 64)
        self.con64_16 = nn.Linear(64, 16)
        self.output16 = nn.Linear(16, 1)

    def _residual(self, x, linear, norm):
        """One residual stage: skip connection around Linear -> BN -> ReLU."""
        return x + self.relu(norm(linear(x)))

    def forward(self, x):
        """Map (batch, input_size) features to (batch, 1) predictions."""
        hidden = self.inputlayer(x)
        hidden = self.con128_64(self._residual(hidden, self.fc128, self.bn128))
        hidden = self.con64_16(self._residual(hidden, self.fc64, self.bn64))
        hidden = self._residual(hidden, self.fc16, self.bn16)
        return self.output16(hidden)

class StableAL():
    """Stable Adversarial Learning (SAL) driver for one ``IRNet_intorch`` regressor.

    State:
        model: the predictor.
        weights: per-covariate cost weights, shape (dim_x, 1); updated by the caller.
        weight_grad, xa_grad, theta_grad: gradient pieces the caller combines into
            the covariate-weight update.
        adversarial_data, adv_based_on, adv_again: the last attack result and the
            data it was computed from, kept for inspection/dumping.

    Relies on module-level ``device``, ``grad_layer``, ``IRNet_intorch``, ``Adam``,
    ``optim`` and ``grad`` (``torch.autograd.grad``).
    """

    def __init__(self, environment, dim_x=150):
        """Build the model and the initial covariate weights.

        Args:
            environment: not used by this class; kept for call compatibility.
            dim_x: number of input covariates (default 150, the previous hard-coded value).
        """
        self.weights = None
        self.model = None
        self.weight_grad = None
        self.xa_grad = None
        self.theta_grad = None
        self.gamma = None
        self.adversarial_data = None
        self.loss_criterion = torch.nn.MSELoss()

        self.adv_based_on = None
        self.adv_again = None

        self.X = None
        self.y = None

        self.model = IRNet_intorch(dim_x).to(device)
        # Every covariate starts with a cost weight of 100.
        self.weights = torch.full((dim_x, 1), 100.0, device=device)

    def _loss(self, outputs, targets):
        """MSE with ``targets`` reshaped to the prediction's shape.

        The model emits (B, 1) while the data loaders yield targets of shape (B,).
        Passing them to ``MSELoss`` unchanged broadcasts to (B, B) and compares
        every prediction with every target, so the targets are reshaped first.
        """
        targets = targets.float().to(outputs.device).reshape(outputs.shape)
        return self.loss_criterion(outputs, targets)

    def cost_function(self, x, x_adv):
        """Weighted squared distance between clean and adversarial inputs.

        Computes ``((x - x_adv) ** 2) @ weights`` per row and averages over rows.
        """
        cost = torch.mean(((x - x_adv) ** 2).mm(self.weights)).to(device)
        return cost

    def r(self, environments, alpha=10.0):
        """Loss summed over environments, re-weighting the worst and best ones.

        The environment with the largest loss is weighted ``alpha + 1`` and the one
        with the smallest loss ``1 - alpha``; the others are weighted 1. With a single
        environment the max branch wins, giving ``(alpha + 1) * loss``.

        Args:
            environments: iterable of ``(x_e, y_e)`` tensor pairs.
            alpha: re-weighting strength.

        Returns:
            A scalar tensor that carries a graph to the model parameters, or 0.0 if
            ``environments`` is empty.
        """
        env_losses = []
        for x_e, y_e in environments:
            x_e = x_e.to(device)
            y_e = y_e.to(torch.float32).to(device)
            env_losses.append(self._loss(self.model(x_e), y_e))
        if not env_losses:
            return 0.0

        # Pick the extremes on detached values; each forward pass runs only once.
        ranking = torch.stack([env_loss.detach() for env_loss in env_losses])
        max_index = int(torch.argmax(ranking))
        min_index = int(torch.argmin(ranking))

        result = 0.0
        for idx, env_loss in enumerate(env_losses):
            if idx == max_index:
                result += (alpha + 1) * env_loss
            elif idx == min_index:
                result += (1 - alpha) * env_loss
            else:
                result += env_loss
        return result

    def attack(self, gamma, data, step):
        """Build adversarial inputs by gradient ascent with the custom ``Adam``.

        Maximises ``MSE(model(x_adv), y) - gamma * cost(x, x_adv)`` for ``step``
        iterations starting from the clean inputs.

        Args:
            gamma: weight of the transport cost (scalar or 1-element tensor).
            data: ``(images, labels)`` tensors.
            step: number of ascent steps.

        Returns:
            ``(images_adv, labels)``: adversarial inputs on ``device`` as a leaf with
            ``requires_grad=True`` (``train_theta`` differentiates w.r.t. them), and
            float labels on ``device``.

        Side effects:
            Sets ``self.weight_grad`` (detached) and ``self.adversarial_data``.
        """
        attack_lr = 7e-3
        images, labels = data
        images = images.to(device)
        labels = labels.float().to(device)
        images_adv = images.clone().detach()

        optimizer = Adam(learning_rate=attack_lr)
        for i in range(step):
            images_adv.requires_grad_(True)
            outputs = self.model(images_adv)
            loss = self._loss(outputs, labels) - gamma * self.cost_function(images, images_adv)
            # autograd.grad (not backward) so the model parameters' .grad is untouched.
            (adv_grad,) = grad(loss, images_adv)
            images_adv = optimizer.update(adv_grad, i + 1, images_adv.detach()).detach()

        images_adv = images_adv.detach()
        self.weight_grad = (-2 * gamma * attack_lr * (images_adv - images)).detach()
        self.adversarial_data = (images_adv.clone(), labels.clone())
        return images_adv.requires_grad_(True), labels

    def _xa_jacobian(self, loss, images_adv, verbose=False):
        """Mixed second derivatives d/dx (d loss / d theta_k) for the selected parameter.

        ``theta`` is ``list(model.parameters())[grad_layer]`` flattened to P entries.
        Each entry takes one backward pass to ``images_adv``. The results are stacked
        on a new axis 1, giving a detached tensor of shape
        ``(images_adv.shape[0], P, *images_adv.shape[1:])``.
        """
        dloss_dtheta = grad(loss, self.model.parameters(), create_graph=True)[grad_layer].reshape(-1)
        if verbose:
            print(f"dloss_dtheta.shape:  {dloss_dtheta.shape[0]}")
        # retain_graph keeps the graph alive for the remaining rows and for the
        # caller's loss.backward(); no create_graph because results are detached.
        rows = [
            grad(dloss_dtheta[j], images_adv, retain_graph=True)[0].detach()
            for j in range(dloss_dtheta.shape[0])
        ]
        return torch.stack(rows, 1)

    def train_theta(self, data, epochs, epoch_attack, gamma, end_flag=False):
        """Fit the model on freshly attacked data and accumulate ``self.xa_grad``.

        Each of ``epochs`` iterations:
        1. attacks ``data``. When ``end_flag`` is set and ``i % 5 != 0``, it attacks
           the previous adversarial data instead;
        2. adds the input Jacobian of the parameter gradient to ``self.xa_grad``;
        3. takes one Adam step (lr 0.01) on the loss on the adversarial data.

        NOTE(review): ``xa_grad`` is never reset between calls. It keeps
        accumulating across batches and is scaled by -0.01 at the end of every call.
        Confirm this matches the intended SAL estimator.
        """
        optimizer = optim.Adam(self.model.parameters(), lr=0.01)
        for i in range(epochs):
            if i % 5 == 0 or not end_flag:
                images_adv, labels = self.attack(gamma, data, step=epoch_attack)
            else:
                self.adv_again = self.adversarial_data
                images_adv, labels = self.attack(gamma, self.adversarial_data, step=epoch_attack)
            self.adv_based_on = data

            optimizer.zero_grad()
            images_adv = images_adv.to(device)
            outputs = self.model(images_adv)
            loss = self._loss(outputs, labels)

            first_call = self.xa_grad is None
            if first_call:
                for name1, param in self.model.named_parameters():
                    print(f"grad: {name1}          {param.shape}")
            jacobian = self._xa_jacobian(loss, images_adv, verbose=first_call)
            if first_call:
                self.xa_grad = jacobian
            else:
                self.xa_grad += jacobian
            del jacobian
            torch.cuda.empty_cache()

            loss.backward()
            optimizer.step()

        if self.xa_grad is not None:
            self.xa_grad *= (-0.01)

In [2]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

# Step 1: Load the CSV file

# All splits live in one directory; point DATA_DIR elsewhere to relocate them.
DATA_DIR = '../dataset'


def load_split(name):
    """Read ``<DATA_DIR>/<name>.csv`` (no index column) into a DataFrame."""
    return pd.read_csv(f'{DATA_DIR}/{name}.csv', index_col=None)


Rsplt_testset = load_split('Rsplt_testset')
Xshft_testset = load_split('Xshft_testset')
pizeo_testset = load_split('pizeo_testset')
statY_testset = load_split('statY_testset')
infoY_testset = load_split('infoY_testset')
final_train = load_split('final_trainset')

Rsplt_testset1 = load_split('Rsplt_testset1')
Rsplt_testset2 = load_split('Rsplt_testset2')
Rsplt_testset3 = load_split('Rsplt_testset3')
Rsplt_testset4 = load_split('Rsplt_testset4')
Rsplt_testset5 = load_split('Rsplt_testset5')



# Step 3: Define a custom PyTorch Dataset class
class MyDataset(Dataset):
    """Tabular dataset: target ``delta_e``; every other column except ``pretty_comp`` is a feature."""

    def __init__(self, df):
        feature_frame = df.drop(columns=['delta_e', 'pretty_comp'])
        self.inputs = feature_frame.values
        self.labels = df['delta_e'].values

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        """Return ``(features, target)`` as float32 tensors; the target is 0-dim."""
        features = torch.tensor(self.inputs[index].tolist(), dtype=torch.float32)
        target = torch.tensor(self.labels[index].tolist(), dtype=torch.float32)
        return features, target

    def getSALdata(self):
        """Return ``(inputs, labels)`` as freshly built NumPy arrays."""
        return np.array(self.inputs.tolist()), np.array(self.labels.tolist())
        
class RecurrentDataset(Dataset):
    """Dataset over a frame whose ``data`` column holds one feature array per row
    and whose ``label`` column holds the target."""

    def __init__(self, df):
        self.inputs, self.labels = df['data'].values, df['label'].values

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        """Return ``(features, target)`` as float32 tensors."""
        row, target = self.inputs[index], self.labels[index]
        return (torch.tensor(row.tolist(), dtype=torch.float32),
                torch.tensor(target.tolist(), dtype=torch.float32))


# Step 4: Use DataLoader to create batches
batch_size = 128

# Training data: shuffled mini-batches; the incomplete final batch is dropped.
train_dataset = MyDataset(final_train)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

# Held-out splits wrapped as datasets.
(Rsplt_test_dataset, Xshft_test_dataset, pizeo_test_dataset,
 statY_test_dataset, infoY_test_dataset) = (
    MyDataset(frame)
    for frame in (Rsplt_testset, Xshft_testset, pizeo_testset, statY_testset, infoY_testset)
)
(Rsplt_testset1_dataset, Rsplt_testset2_dataset, Rsplt_testset3_dataset,
 Rsplt_testset4_dataset, Rsplt_testset5_dataset) = (
    MyDataset(frame)
    for frame in (Rsplt_testset1, Rsplt_testset2, Rsplt_testset3, Rsplt_testset4, Rsplt_testset5)
)


def _whole_split_loader(dataset, frame):
    """Loader that yields the entire split as a single batch."""
    return DataLoader(dataset, batch_size=len(frame))


Rsplt_testset1_loader = _whole_split_loader(Rsplt_testset1_dataset, Rsplt_testset1)
Rsplt_testset2_loader = _whole_split_loader(Rsplt_testset2_dataset, Rsplt_testset2)
Rsplt_testset3_loader = _whole_split_loader(Rsplt_testset3_dataset, Rsplt_testset3)
Rsplt_testset4_loader = _whole_split_loader(Rsplt_testset4_dataset, Rsplt_testset4)
Rsplt_testset5_loader = _whole_split_loader(Rsplt_testset5_dataset, Rsplt_testset5)

Rsplt_test_loader = _whole_split_loader(Rsplt_test_dataset, Rsplt_testset)
Xshft_test_loader = _whole_split_loader(Xshft_test_dataset, Xshft_testset)
pizeo_test_loader = _whole_split_loader(pizeo_test_dataset, pizeo_testset)
statY_test_loader = _whole_split_loader(statY_test_dataset, statY_testset)
infoY_test_loader = _whole_split_loader(infoY_test_dataset, infoY_testset)

# Whole training set as NumPy (features, targets).
data = train_dataset.getSALdata()


In [3]:
method = StableAL(environment=[train_dataset.getSALdata()])  # single training environment

In [4]:
import warnings
import csv

# Silence only deprecation chatter. A blanket filterwarnings("ignore") also hid
# torch's "target size ... different to the input size" MSELoss warning, which
# flags (B, 1)-vs-(B,) broadcasting in the training code.
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Adversarial cost scale: gamma = 1 / min(covariate weight); re-derived every epoch.
min_weight = torch.min(method.weights)
attack_gamma = (1.0 / min_weight).data
criterion = nn.MSELoss()
deltaall = 20        # largest absolute change applied to any covariate weight per update
alpha = 0.5          # worst/best environment re-weighting for StableAL.r (scaled by 1/sqrt(epoch+1))

epoch = 0
num_epochs = 1000

zero_list = []       # covariate indices whose weight went negative; their updates are zeroed
epoch_theta = 10     # StableAL.train_theta iterations per SAL batch
epoch_attack = 2     # gradient-ascent steps per attack

end_flag = False     # False: train_theta always attacks the clean batch

# Fresh model for this run (replaces the one built in StableAL.__init__).
method.model = IRNet_intorch(150).to(device)
method.model.optimizer = optim.Adam(method.model.parameters(), lr=0.001)

partialbatch = 5     # after warm-up, each epoch stops after this batch index


# Best loss seen so far for every tracked split (lower is better).
train_best_loss = float('inf')
partial_train_best_loss = float('inf')
Rsplt_test_best_loss = float('inf')
Xshft_test_best_loss = float('inf')
pizeo_test_best_loss = float('inf')
statY_test_best_loss = float('inf')
infoY_test_best_loss = float('inf')

Rsplt_testset1_best_loss = float('inf')
Rsplt_testset2_best_loss = float('inf')
Rsplt_testset3_best_loss = float('inf')
Rsplt_testset4_best_loss = float('inf')
Rsplt_testset5_best_loss = float('inf')

# Per-epoch log; the column order must match `entry` in the training loop.
loss_df = pd.DataFrame(columns=[
    'epoch',
    'train',
    'partialtrain',
    'Rsplt1',
    'Rsplt2',
    'Rsplt3',
    'Rsplt4',
    'Rsplt5',
    'RspltAVE',
    'Xshft',
    'pizeo',
    'statY',
    'infoY',
    'BLANK',
    'best_train',
    'best_partialtrain',
    'best_Rsplt1',
    'best_Rsplt2',
    'best_Rsplt3',
    'best_Rsplt4',
    'best_Rsplt5',
    'best_Rsplt_AVE',
    'best_Xshft',
    'best_pizeo',
    'best_statY',
    'best_infoY',
    'save',
    'attack_gamma',
])

SORT_CHUNK = 4096  # rows per forward pass when scoring every training sample


def _mse(outputs, target):
    """Element-wise MSE between model outputs and targets.

    ``method.model`` returns shape (B, 1) while the loaders yield targets of shape
    (B,). Passed unchanged to ``nn.MSELoss`` these broadcast to (B, B), comparing
    every prediction with every target. Flattening both restores per-sample MSE.
    """
    return criterion(outputs.reshape(-1), target.float().to(outputs.device).reshape(-1))


def _mean_loader_loss(loader):
    """Average per-batch MSE of ``method.model`` over ``loader``.

    Runs in eval mode under ``torch.no_grad()`` so no autograd graph is built. The
    training pass switches back to train mode at the start of each epoch.
    """
    method.model.eval()
    total = 0.0
    with torch.no_grad():
        for batch_x, batch_y in loader:
            total += _mse(method.model(batch_x.to(device)), batch_y).item()
    return total / len(loader)


def _dump_adversarial_batch(epoch, batch_idx):
    """Write the original rows, the adversarial rows and (if ``method.adv_again`` is
    set) the re-attacked rows of the last SAL batch to CSV. All use the original labels."""
    ori_images, ori_labels = method.adv_based_on
    adv_images, _ = method.adversarial_data
    ori_data = ori_images.detach().cpu().numpy()
    ori_lab = ori_labels.detach().cpu().numpy()
    groups = [
        np.column_stack((ori_data, ori_lab)),
        np.column_stack((adv_images.detach().cpu().numpy(), ori_lab)),
    ]
    if method.adv_again is None:
        csv_filename = f'adv_data_labels_{epoch}_{batch_idx}.csv'
    else:
        again_images, _ = method.adv_again
        groups.append(np.column_stack((again_images.detach().cpu().numpy(), ori_lab)))
        csv_filename = f'adv_adv_data_labels_{epoch}_{batch_idx}.csv'

    header = [f'feature_{i}' for i in range(ori_data.shape[1])] + ['label']
    with open(csv_filename, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(header)
        writer.writerows(np.vstack(groups))


# Created once so Adam's moment estimates persist across epochs (re-creating it
# every epoch silently reset that state).
optimizer = optim.Adam(method.model.parameters(), lr=0.001)
counter = 0       # epochs since the training loss last improved
counter_val = 0   # epochs since the Rsplt average loss last improved

while epoch < num_epochs:
    # ---- training pass ---------------------------------------------------------
    # Evaluation below leaves the model in eval mode; BatchNorm must train on batch stats.
    method.model.train()
    partial_train_loss = 0.0
    n_batches = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        print(f"current in epoch    {epoch}      batch {batch_idx}")

        optimizer.zero_grad()
        outputs = method.model(data.to(device))
        loss = _mse(outputs, target)
        loss.backward()
        optimizer.step()
        partial_train_loss += loss.item()
        n_batches += 1

        # Warm-up: plain training only for the first five epochs.
        if epoch < 5:
            continue

        # SAL step: train on adversarial data, then estimate the covariate-weight
        # direction. The values from the last SAL batch are applied at epoch end.
        method.train_theta((data, target), epoch_theta, epoch_attack, attack_gamma, end_flag)
        rtheta = method.r([[data, target]], alpha=alpha / math.sqrt(epoch + 1))
        method.theta_grad = grad(rtheta, list(method.model.parameters()), allow_unused=True)
        dr_dx = torch.matmul(method.theta_grad[grad_layer].reshape(-1), method.xa_grad).squeeze()
        deltaw = torch.sum(dr_dx * method.weight_grad, 0)
        if zero_list:
            deltaw[zero_list] = 0.0
        max_grad = torch.max(torch.abs(deltaw))
        deltastep = deltaall
        # Scale so the largest change equals deltastep; the clamp avoids inf/NaN when deltaw == 0.
        lr_weight = (deltastep / torch.clamp_min(max_grad, 1e-12)).detach()
        print(f'RLoss: {rtheta.data}')

        if epoch % 20 == 0:
            _dump_adversarial_batch(epoch, batch_idx)

        # After warm-up only the first `partialbatch + 1` batches are used per epoch.
        if epoch > 5 and batch_idx == partialbatch:
            break

    partial_train_loss /= max(n_batches, 1)
    print("==" * 20)
    print(f"Epoch {epoch+1}/{num_epochs} - partial_train_loss: {partial_train_loss:.4f} ")

    # ---- every 5 epochs: score every training sample, re-order hardest-first -----
    if epoch % 5 == 0:
        print("sorting training set")
        method.model.eval()
        all_inputs = np.asarray(train_dataset.inputs, dtype=np.float32)
        all_labels = np.asarray(train_dataset.labels, dtype=np.float32)
        per_sample_loss = []
        with torch.no_grad():
            for start in range(0, len(all_inputs), SORT_CHUNK):
                chunk_x = torch.from_numpy(all_inputs[start:start + SORT_CHUNK]).to(device)
                chunk_y = torch.from_numpy(all_labels[start:start + SORT_CHUNK])
                preds = method.model(chunk_x).reshape(-1).cpu()
                per_sample_loss.extend(((preds - chunk_y) ** 2).tolist())
        # Built in one go (DataFrame.append was quadratic and is gone in pandas 2).
        sort_MAE = pd.DataFrame({
            'data': list(train_dataset.inputs),
            'label': list(train_dataset.labels),
            'loss': per_sample_loss,
        })
        if epoch % 50 == 0:
            sort_MAE.to_csv(f"./train_set_sorting_{epoch}.csv", index=False)
        new_train = sort_MAE.sort_values(by=['loss'], ascending=False)
        new_train_dataset = RecurrentDataset(new_train)
        train_loader = DataLoader(new_train_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

        sorted_train_loss = float(np.mean(per_sample_loss)) if per_sample_loss else 0.0
        print(f"Epoch {epoch+1}/{num_epochs} - Training loss: {sorted_train_loss:.4f} ")
        print("==" * 20)

    # ---- evaluation (eval mode, no autograd) --------------------------------------
    train_loss = _mean_loader_loss(train_loader)
    print(f'Epoch: [{(epoch + 1)}/{num_epochs}], TrainLoss: {train_loss}')

    Rsplt1_test_mse_loss = _mean_loader_loss(Rsplt_testset1_loader)
    Rsplt2_test_mse_loss = _mean_loader_loss(Rsplt_testset2_loader)
    Rsplt3_test_mse_loss = _mean_loader_loss(Rsplt_testset3_loader)
    Rsplt4_test_mse_loss = _mean_loader_loss(Rsplt_testset4_loader)
    Rsplt5_test_mse_loss = _mean_loader_loss(Rsplt_testset5_loader)

    Rsplt_test_mse_loss = _mean_loader_loss(Rsplt_test_loader)
    Xshft_test_mse_loss = _mean_loader_loss(Xshft_test_loader)
    pizeo_test_mse_loss = _mean_loader_loss(pizeo_test_loader)
    statY_test_mse_loss = _mean_loader_loss(statY_test_loader)
    infoY_test_mse_loss = _mean_loader_loss(infoY_test_loader)

    rsplt_ave = np.average([
        Rsplt1_test_mse_loss,
        Rsplt2_test_mse_loss,
        Rsplt3_test_mse_loss,
        Rsplt4_test_mse_loss,
        Rsplt5_test_mse_loss,
    ])

    # ---- best-so-far tracking ----------------------------------------------------
    save = []
    if partial_train_loss < partial_train_best_loss:
        partial_train_best_loss = partial_train_loss
        save.append("partialTrain")
    if Rsplt1_test_mse_loss < Rsplt_testset1_best_loss:
        Rsplt_testset1_best_loss = Rsplt1_test_mse_loss
        save.append("Rsplt1")
    if Rsplt2_test_mse_loss < Rsplt_testset2_best_loss:
        Rsplt_testset2_best_loss = Rsplt2_test_mse_loss
        save.append("Rsplt2")
    if Rsplt3_test_mse_loss < Rsplt_testset3_best_loss:
        Rsplt_testset3_best_loss = Rsplt3_test_mse_loss
        save.append("Rsplt3")
    if Rsplt4_test_mse_loss < Rsplt_testset4_best_loss:
        Rsplt_testset4_best_loss = Rsplt4_test_mse_loss
        save.append("Rsplt4")
    if Rsplt5_test_mse_loss < Rsplt_testset5_best_loss:
        Rsplt_testset5_best_loss = Rsplt5_test_mse_loss
        save.append("Rsplt5")

    if Xshft_test_mse_loss < Xshft_test_best_loss:
        Xshft_test_best_loss = Xshft_test_mse_loss
        save.append("Xshft")
    if pizeo_test_mse_loss < pizeo_test_best_loss:
        pizeo_test_best_loss = pizeo_test_mse_loss
        save.append("pizeo")
    if statY_test_mse_loss < statY_test_best_loss:
        statY_test_best_loss = statY_test_mse_loss
        save.append("statY")
    if infoY_test_mse_loss < infoY_test_best_loss:
        infoY_test_best_loss = infoY_test_mse_loss
        save.append("infoY")

    # Checkpoint whenever the training loss improves (this does not stop training).
    if train_loss < train_best_loss:
        train_best_loss = train_loss
        counter = 0
        torch.save(method.model, f'IR3_epoch_{epoch}.pt')
        torch.save(method.weights, f"SAL_weight_{epoch}_gamma_{float(attack_gamma)}.pt")
        save.append("Train")
    else:
        counter += 1
        print(f'training Loss has not improved for {counter} epochs.')

    # NOTE(review): early stopping / best-model selection uses the Rsplt *test* splits.
    if rsplt_ave < Rsplt_test_best_loss:
        Rsplt_test_best_loss = rsplt_ave
        counter_val = 0
        torch.save(method.model, "IR3_SAL-bset-Rsplt_test_mse_loss.pt")
        save.append("Rsplt_AVE")
    else:
        counter_val += 1
        if counter_val >= 500:
            print(f'Training stopped. Valid (rand) Loss has not improved for {500} epochs.')
            break

    entry = [
        epoch,
        f"{train_loss:.4f}",
        f"{partial_train_loss:.4f}",
        f"{Rsplt1_test_mse_loss:.4f}",
        f"{Rsplt2_test_mse_loss:.4f}",
        f"{Rsplt3_test_mse_loss:.4f}",
        f"{Rsplt4_test_mse_loss:.4f}",
        f"{Rsplt5_test_mse_loss:.4f}",
        f"{rsplt_ave:.4f}",

        f"{Xshft_test_mse_loss:.4f}",
        f"{pizeo_test_mse_loss:.4f}",
        f"{statY_test_mse_loss:.4f}",
        f"{infoY_test_mse_loss:.4f}",

        f"                         ",

        f"{train_best_loss:.4f}",
        f"{partial_train_best_loss:.4f}",
        f"{Rsplt_testset1_best_loss:.4f}",
        f"{Rsplt_testset2_best_loss:.4f}",
        f"{Rsplt_testset3_best_loss:.4f}",
        f"{Rsplt_testset4_best_loss:.4f}",
        f"{Rsplt_testset5_best_loss:.4f}",

        f"{Rsplt_test_best_loss:.4f}",
        f"{Xshft_test_best_loss:.4f}",
        f"{pizeo_test_best_loss:.4f}",
        f"{statY_test_best_loss:.4f}",
        f"{infoY_test_best_loss:.4f}",

        save,
        f'{float(attack_gamma)}',
    ]
    loss_df.loc[len(loss_df)] = entry

    loss_df.to_csv('IR3_plain-training_loss.csv', index=False)
    epoch = epoch + 1

    # ---- adjust gamma from the smallest positive weight; freeze negative weights --
    with torch.no_grad():
        flat_w = method.weights.view(-1)
        positive = flat_w[flat_w > 0.0]
        if positive.numel() > 0:
            min_weight = torch.clamp(positive.min(), max=1e8)
        else:
            min_weight = torch.tensor(1e8, device=flat_w.device)
        negative_idx = torch.nonzero(flat_w < 0.0).view(-1).tolist()
        if negative_idx:
            flat_w[negative_idx] = 1.0
            zero_list.extend(negative_idx)
    attack_gamma = (1.0 / min_weight).data

    if epoch <= 5:
        continue
    method.weights -= lr_weight * deltaw.detach().reshape(method.weights.shape)
    del rtheta, dr_dx, deltaw, max_grad, deltastep, lr_weight
    gc.collect()
    torch.cuda.empty_cache()





current in epoch    0      batch 0
current in epoch    0      batch 1
current in epoch    0      batch 2
current in epoch    0      batch 3
current in epoch    0      batch 4
current in epoch    0      batch 5
current in epoch    0      batch 6
current in epoch    0      batch 7
current in epoch    0      batch 8
current in epoch    0      batch 9
current in epoch    0      batch 10
current in epoch    0      batch 11
current in epoch    0      batch 12
current in epoch    0      batch 13
current in epoch    0      batch 14
current in epoch    0      batch 15
current in epoch    0      batch 16
current in epoch    0      batch 17
current in epoch    0      batch 18
current in epoch    0      batch 19
current in epoch    0      batch 20
current in epoch    0      batch 21
current in epoch    0      batch 22
current in epoch    0      batch 23
current in epoch    0      batch 24
current in epoch    0      batch 25
current in epoch    0      batch 26
current in epoch    0      batch 27
Ep

RuntimeError: Function 'MmBackward0' returned nan values in its 0th output.

In [None]:
# Persist the SAL covariate weights, tagged with the epoch counter reached.
whole_sal_path = f"Whole_SAL_{epoch}.pt"
torch.save(method.weights, whole_sal_path)

In [None]:
# Persist the whole model object (pickled nn.Module), tagged with the epoch counter.
model_path = f"IR3_epoch_{epoch}.pt"
torch.save(method.model, model_path)