# **Import Some Packages**

In [436]:
# PyTorch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

# For data preprocess
import pandas as pd
import numpy as np
import csv
import os

from tqdm import tqdm

import math
from torch.utils.tensorboard import SummaryWriter
# For plotting
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
    
from sklearn.preprocessing import StandardScaler

# **Some Utilities**

You do not need to modify this part.

In [2]:
def get_device():
    ''' Get device (if GPU is available, use GPU) '''
    return 'cuda' if torch.cuda.is_available() else 'cpu'

def plot_learning_curve(loss_record, title=''):
    ''' Plot learning curve of your DNN (train & dev loss) '''
    total_steps = len(loss_record['mean_train_loss'])
    x_1 = range(total_steps)
    x_2 = x_1[::len(loss_record['mean_train_loss']) // len(loss_record['mean_valid_loss'])]
    figure(figsize=(6, 4))
    plt.plot(x_1, loss_record['mean_train_loss'], c='tab:red', label='train')
    plt.plot(x_2, loss_record['mean_valid_loss'], c='tab:cyan', label='dev')
    plt.ylim(0,3)
    plt.xlabel('Training steps')
    plt.ylabel('MSE loss')
    plt.title('Learning curve of {}'.format(title))
    plt.legend()
    plt.show()


def plot_pred(dv_set, model, device, lim=35., preds=None, targets=None):
    ''' Plot prediction of your DNN '''
    if preds is None or targets is None:
        model.eval()
        preds, targets = [], []
        for x, y in dv_set:
            x, y = x.to(device), y.to(device)
            with torch.no_grad():
                pred = model(x)
                preds.append(pred.detach().cpu())
                targets.append(y.detach().cpu())
        preds = torch.cat(preds, dim=0).numpy()
        targets = torch.cat(targets, dim=0).numpy()

    figure(figsize=(5, 5))
    plt.scatter(targets, preds, c='r', alpha=0.5)
    plt.plot([-0.2, lim], [-0.2, lim], c='b')
    plt.xlim(-0.2, lim)
    plt.ylim(-0.2, lim)
    plt.xlabel('ground truth value')
    plt.ylabel('predicted value')
    plt.title('Ground Truth v.s. Prediction')
    plt.show()
    


In [438]:
def same_seed(seed): 
    '''Fixes random number generator seeds for reproducibility.'''
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def train_valid_split(data_set, valid_ratio, seed):
    '''Split provided training data into training set and validation set'''
    valid_set_size = int(valid_ratio * len(data_set)) 
    train_set_size = len(data_set) - valid_set_size
    train_set, valid_set = random_split(data_set, [train_set_size, valid_set_size], generator=torch.Generator().manual_seed(seed))
    return np.array(train_set), np.array(valid_set)

def predict(test_loader, model, device):
    model.eval() # Set your model to evaluation mode.
    preds = []
    for x in tqdm(test_loader):
        x = x.to(device)                        
        with torch.no_grad():                   
            pred = model(x)                     
            preds.append(pred.detach().cpu())   
    preds = torch.cat(preds, dim=0).numpy()  
    return preds

In [439]:
attri_data = pd.read_csv('../LiterallyWikidata/files_needed/numeric_literals_ver06')

In [440]:
attri_data

Unnamed: 0,e,a,v,name_e,name_a,ent_type,std_v,minmax_v
0,Q1000056,P1082,1.103200e+04,Sušice,population,Q7841907,-0.030708,2.649215e-06
1,Q1000056,P2044,4.720000e+02,Sušice,elevation above sea level,Q7841907,0.031306,2.219876e-03
2,Q1000056,P2046,4.563000e+07,Sušice,area,Q7841907,-0.006852,7.492611e-12
3,Q1000138,P1082,1.375000e+03,Cantenac,population,Q484170,-0.031014,3.301914e-07
4,Q1000138,P2044,1.000000e+00,Cantenac,elevation above sea level,Q484170,-0.116549,1.058141e-03
...,...,...,...,...,...,...,...,...
296303,Q99987,P1333_Longtiude,9.586108e+00,Brembate di Sopra,coordinates of southernmost point,Q747074,-0.140675,5.334437e-01
296304,Q99987,P1334_Longtiude,9.595788e+00,Brembate di Sopra,coordinates of easternmost point,Q747074,-0.144653,5.264729e-01
296305,Q99987,P1335_Longtiude,9.564391e+00,Brembate di Sopra,coordinates of westernmost point,Q747074,-0.142728,5.252904e-01
296306,Q99987,P625_Longtiude,9.580647e+00,Brembate di Sopra,coordinate location(logtitude),Q747074,0.130122,5.244460e-01


In [441]:
# 用kgeemb順序
ent2idx ={}
with open('../LiterallyWikidata/files_needed/list_ent_ids.txt','r') as fr:
    for i, word in enumerate(fr.readlines()):
        ent2idx[word.strip()] = i


attri_data_std_v = attri_data[['e','a','std_v']]
attri_data_eav = attri_data[['e','a','v']]
# att2idx = {}
# #rel2idx = {v:k for k,v in enumerate(relations['label'].unique())}

# with open('../LiterallyWikidata/files_needed/attribute.txt','r') as fr:
#     for i, word in enumerate(fr.readlines()):
#         att2idx[word.strip()] = i
        
att2idx = {v:k for k,v in enumerate(attri_data['a'].unique())}

In [442]:
# attri_data['a_idx']=attri_data['a'].map(att2idx)
# attri_data['e_idx']=attri_data['e'].map(ent2idx)

In [443]:
attri_data_eav

Unnamed: 0,e,a,v
0,Q1000056,P1082,1.103200e+04
1,Q1000056,P2044,4.720000e+02
2,Q1000056,P2046,4.563000e+07
3,Q1000138,P1082,1.375000e+03
4,Q1000138,P2044,1.000000e+00
...,...,...,...
296303,Q99987,P1333_Longtiude,9.586108e+00
296304,Q99987,P1334_Longtiude,9.595788e+00
296305,Q99987,P1335_Longtiude,9.564391e+00
296306,Q99987,P625_Longtiude,9.580647e+00


In [444]:
def numeric_literal_array(data, ent2idx, att2idx):
    #'LiterallyWikidata/LitWD48K/train_attri_data'
    df_all = data

    # Resulting file
    num_lit = np.zeros([len(ent2idx), len(att2idx)], dtype=np.float32)

# Create literal wrt vocab
    for i, (s, p, lit) in enumerate(df_all.values):
        try:
            num_lit[ent2idx[s], att2idx[p]] = lit
        except KeyError:
            continue
    return num_lit


# num_lit shape (47998, 86)


In [445]:
num_lit = numeric_literal_array(attri_data_eav, ent2idx, att2idx)
print(num_lit.shape)

(47998, 86)


In [446]:
num_lit_stdv = numeric_literal_array(attri_data_std_v, ent2idx, att2idx)
print(num_lit_stdv.shape)

(47998, 86)


In [447]:
# ## constraint needed:
pop_idx = att2idx['P1082']
gdp = att2idx['P4010']
nominal_gdp = att2idx['P2131']
# nominal_gdp_per = att2idx['P2132']
gdp_per = att2idx['P2299']
# date_of_birth = att2idx['P569']
# date_of_death = att2idx['P570']
# area = ['P2046']
# # net_profit = att2idx['P2295']
# # retirement_age = att2idx['P3001']
# # age_of_majority = att2idx['P2997']
# # work_start = att2idx['P2031']
# # work_end = att2idx['P2032']

In [448]:
ent2idx['Q1000']

84

In [449]:
attri_data[attri_data['a']=='P1082']

Unnamed: 0,e,a,v,name_e,name_a,ent_type,std_v,minmax_v
0,Q1000056,P1082,11032.0,Sušice,population,Q7841907,-0.030708,2.649215e-06
3,Q1000138,P1082,1375.0,Cantenac,population,Q484170,-0.031014,3.301914e-07
6,Q100013,P1082,4109.0,Brembilla,population,Q1134686,-0.030927,9.867318e-07
9,Q100015,P1082,6009.0,Brignano Gera d'Adda,population,Q747074,-0.030867,1.442996e-06
12,Q100016,P1082,119.0,Brumano,population,Q747074,-0.031053,2.857656e-08
...,...,...,...,...,...,...,...,...
123807,Q99983,P1082,706.0,Bracca,population,Q747074,-0.031035,1.695383e-07
123810,Q99985,P1082,712.0,Branzi,population,Q747074,-0.031035,1.709791e-07
123813,Q99986,P1082,8551.0,Brembate,population,Q747074,-0.030787,2.053430e-06
123816,Q99987,P1082,7868.0,Brembate di Sopra,population,Q747074,-0.030808,1.889415e-06


In [450]:
num_lit[84][gdp_per]*num_lit[84][pop_idx]
num_lit[84][gdp]

36681910000.0

In [451]:
# x_list: ent的gdp有值，把除了gdp那個值之外的值存到inner_x_list，len(x_list)是有幾組變數有值
# normalized or non-normalized
x_list=[]

for ent in num_lit:
    if ent[gdp] == 0:
        pass
    else:
        #print(ent[gdp])
        inner_x_list=[]
        for j in range(len(ent)):
            if j != gdp :
                inner_x_list.append(ent[j])
        inner_x_list.append(ent[gdp])
        x_list.append(inner_x_list)


In [452]:
select_feature = []
for i in range(len(x_list[0])-1):
    if x_list[0][i] !=0:
        select_feature.append(i)

In [453]:
pd.DataFrame(x_list,columns=list(range(len(x_list[0]))))

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,76,77,78,79,80,81,82,83,84,85
0,325145952.0,0.0,9.826675e+12,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,71.379997,25.837132,44.815392,65.637886,-98.579498,-156.479996,-97.394829,-66.949951,-168.118286,1.948539e+13
1,83149296.0,0.0,3.574000e+11,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,54.911209,47.270027,51.272800,51.051102,10.000000,8.669200,10.178447,15.041781,5.866366,4.345631e+12
2,2025137.0,0.0,2.676670e+08,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,2.320000,-3.960050,-0.617460,-0.624444,11.500000,11.700000,11.153220,14.526600,8.708055,3.668191e+10
3,19586540.0,0.0,2.383970e+11,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,48.270000,43.623371,45.163815,46.121685,25.000000,26.700001,25.391186,29.716679,20.261787,5.221131e+11
4,42558328.0,0.0,6.036290e+11,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,52.379528,44.386417,49.260277,48.419167,32.000000,33.190605,33.777222,40.228333,22.137222,3.695664e+11
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
151,807610.0,0.0,3.839400e+10,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,28.320000,26.702021,27.281780,27.141870,90.500000,90.000000,89.770851,92.125229,88.746468,7.584701e+09
152,777859.0,0.0,2.149700e+11,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,8.530000,1.164920,1.958640,5.962424,-59.316666,-59.980000,-58.825779,-56.491150,-61.406292,6.362827e+09
153,341465152.0,0.0,0.000000e+00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,1.539088e+13
154,5260750.0,0.0,3.420000e+11,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,3.700000,0.000000,0.000000,0.000000,15.383330,17.480000,0.000000,0.000000,0.000000,2.869352e+10


In [454]:
y_list=list()
for ent2 in num_lit:
    if ent2[gdp] ==0:
        pass
    else:
        y_list.append(ent2[gdp])

In [455]:
for i in range(len(x_list)):
    inner_x = x_list[i]
    inner_x.append(y_list[i])
x_list.append(inner_x)


In [456]:
select_feature

[0,
 2,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 65,
 75,
 76,
 77,
 78,
 79,
 80,
 81,
 82,
 83,
 84]

In [457]:
attri_data[attri_data['a']=='P4010']

Unnamed: 0,e,a,v,name_e,name_a,ent_type,std_v,minmax_v
156,Q1000,P4010,3.668191e+10,Gabon,GDP (PPP),Q3624078,-0.225985,0.001880
414,Q1005,P4010,3.569106e+09,The Gambia,GDP (PPP),Q6256,-0.241117,0.000181
461,Q1006,P4010,2.857575e+10,Guinea,GDP (PPP),Q6256,-0.229690,0.001464
519,Q1007,P4010,3.171312e+09,Guinea-Bissau,GDP (PPP),Q6256,-0.241299,0.000160
588,Q1008,P4010,9.583674e+10,Ivory Coast,GDP (PPP),Q7270,-0.198953,0.004916
...,...,...,...,...,...,...,...,...
121853,Q971,P4010,2.869352e+10,Republic of the Congo,GDP (PPP),Q6256,-0.229636,0.001470
122011,Q974,P4010,7.231902e+10,Democratic Republic of the Congo,GDP (PPP),Q6256,-0.209700,0.003709
122243,Q977,P4010,2.342711e+09,Djibouti,GDP (PPP),Q6256,-0.241678,0.000118
122774,Q983,P4010,3.098132e+10,Equatorial Guinea,GDP (PPP),Q3624078,-0.228591,0.001588


In [458]:
def select_feat(train_data, valid_data, select_all=True):
    '''Selects useful features to perform regression'''
    sc = StandardScaler()
    train_data = sc.fit_transform(train_data)
    valid_data = sc.transform(valid_data)
    y_train, y_valid = train_data[:,-1], valid_data[:,-1]
    raw_x_train, raw_x_valid = train_data[:,:-1], valid_data[:,:-1]
    

    if select_all:
        feat_idx = list(range(raw_x_train.shape[1]))
    else:
        feat_idx = select_feature # TODO: Select suitable feature columns.
        
    return raw_x_train[:,feat_idx], raw_x_valid[:,feat_idx], y_train, y_valid

In [459]:
class KGMTL_Data(Dataset):
    '''
    x: Features.
    y: Targets, if none, do prediction.
    '''
    def __init__(self, x, y=None):
        if y is None:
            self.y = y
        else:
            self.y = torch.FloatTensor(y)
        self.x = torch.FloatTensor(x)
        

    def __getitem__(self, idx):
        if self.y is None:
            return self.x[idx]
        else:
            return self.x[idx], self.y[idx]

    def __len__(self):
        return len(self.x)


    

In [460]:
# Set seed for reproducibility
same_seed(config['seed'])


# train_data size: 2699 x 118 (id + 37 states + 16 features x 5 days) 
# test_data size: 1078 x 117 (without last day's positive rate)
# train_data, test_data = pd.read_csv('../LiterallyWikidata/files_needed/train_attri_data.csv',sep='\t').values, 
# pd.read_csv('../LiterallyWikidata/files_needed/test_attri_data.csv',sep='\t').values
train_data, valid_data = train_valid_split(x_list, config['valid_ratio'], config['seed'])

# Print out the data size.
print(f"""train_data size: {train_data.shape} 
valid_data size: {valid_data.shape} """)
# test_data size: {test_data.shape}""")


# Select features
x_train, x_valid, y_train, y_valid = select_feat(train_data, valid_data, config['select_all'])

# Print out the number of features.
print(f'number of features: {x_train.shape[1]}')

train_dataset, valid_dataset = KGMTL_Data(x_train, y_train), \
                                            KGMTL_Data(x_valid, y_valid)

print('train_dataset', train_dataset[0])

# Pytorch data loader loads pytorch dataset into batches.
train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=config['batch_size'], shuffle=False, pin_memory=True)
# test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, pin_memory=True)

train_data size: (142, 87) 
valid_data size: (15, 87) 
number of features: 21
train_dataset (tensor([-0.1708, -0.2967, -0.1158, -0.1861, -0.5509,  0.3183,  0.3378, -0.6285,
        -1.5861, -0.3149, -0.1587, -0.0197, -0.0379,  0.1674,  0.1237,  0.1008,
        -1.5408, -1.4524, -1.9680, -1.4826, -1.6331]), tensor(-0.2249))


In [461]:
train_dataset[5]

(tensor([-0.2469, -0.4112, -0.0890, -0.1975, -0.5994,  0.2029, -0.1439, -0.7458,
          0.5574,  0.5738,  0.3054, -0.6115, -0.5746, -0.9687, -0.9789, -0.5979,
          2.1452, -2.5163, -3.2250, -2.4723,  2.4697]),
 tensor(-0.2451))

In [462]:
class NeuralNet(nn.Module):
    ''' A simple fully-connected deep neural network '''
    def __init__(self, input_dim):
        super(NeuralNet, self).__init__()

        self.layers = nn.Sequential(
            nn.Linear(input_dim, 21),
            nn.ReLU(),
            nn.Dropout(0.5),
#             nn.Linear(20,10),
#             nn.ReLU(),
            nn.Linear(21, 1)
        )

        # Mean squared error loss
        self.criterion = nn.MSELoss(reduction='mean')

    def forward(self, x):
        ''' Given input of size (batch_size x input_dim), compute output of the network '''
        x = self.layers(x)
        x = x.squeeze(1)
        return x

    def cal_loss(self, pred, target):
        ''' Calculate loss '''
        # TODO: you may implement L1/L2 regularization here
        return self.criterion(pred, target)

# **Preprocess**

We have three kinds of datasets:
* `train`: for training
* `dev`: for validation
* `test`: for testing (w/o target value)

In [463]:
loss_record={'train': [], 'dev': [],'mean_train_loss':[],'mean_valid_loss':[]} 

def trainer(train_loader, valid_loader, model, config, device):

    criterion = nn.MSELoss(reduction='mean') # Define your loss function, do not modify this.

    # Define your optimization algorithm. 
    # TODO: Please check https://pytorch.org/docs/stable/optim.html to get more available algorithms.
    # TODO: L2 regularization (optimizer(weight decay...) or implement by your self).
    optimizer = torch.optim.SGD(model.parameters(), lr=config['learning_rate'], momentum=0.9, weight_decay=1e-6) 

    n_epochs, best_loss, step, early_stop_count = config['n_epochs'], math.inf, 0, 0

    writer = SummaryWriter() # Writer of tensoboard
    if not os.path.isdir('./models_var'):
        os.mkdir('./models_var') # Create directory of saving models.
    
    for epoch in range(n_epochs):
        model.train() # Set your model to train mode.

        # tqdm is a package to visualize your training progress.
        train_pbar = tqdm(train_loader, position=0, leave=True)

        for x, y in train_pbar:
            optimizer.zero_grad()               # Set gradient to zero.
            
            x, y = x.to(device), y.to(device)   # Move your data to device. 
            pred = model(x)   
            print(f'-------predict: {pred}, y: {y}----------') 
            #x_constraint = torch.tensor([ (y[i] - x[i][0]*x[i][18]) ** 2 for i in range(len(x))])
            x_constraint = torch.tensor([x[i][pop_idx]*x[i][gdp_per] for i in range(len(x))])
            print(x_constraint)
            x_constraint = x_constraint.to(device)          
            loss = criterion(pred, y) + criterion(pred, x_constraint)
            #loss = criterion(pred, y) 
            # criterion(pred,x_constraint)
                # ((pred-x[0]*x[18])**2) 

            loss.backward()                     # Compute gradient(backpropagation).
            optimizer.step()                    # Update parameters.
            step += 1
            loss_record["train"].append(loss.detach().item())
            
            # Display current epoch number and loss on tqdm progress bar.
            train_pbar.set_description(f'Epoch [{epoch+1}/{n_epochs}]')
            train_pbar.set_postfix({'loss': loss.detach().item()})

        mean_train_loss = sum(loss_record["train"])/len(loss_record["train"])
        writer.add_scalar('Loss/train', mean_train_loss, step)
        loss_record['mean_train_loss'].append(mean_train_loss)

        model.eval() # Set your model to evaluation mode.
        for x, y in valid_loader:
            x, y = x.to(device), y.to(device)
            with torch.no_grad():
                pred = model(x)
                # print(f'x: {x}')
                loss = criterion(pred, y)

            loss_record["dev"].append(loss.item())
            
        mean_valid_loss = sum(loss_record["dev"])/len(loss_record["dev"])
        print(f'Epoch [{epoch+1}/{n_epochs}]: Train loss: {mean_train_loss:.4f}, Valid loss: {mean_valid_loss:.4f}')
        writer.add_scalar('Loss/valid', mean_valid_loss, step)
        loss_record['mean_valid_loss'].append(mean_valid_loss)

        if mean_valid_loss < best_loss:
            best_loss = mean_valid_loss
            torch.save(model.state_dict(), config['save_path']) # Save your best model
            print('Saving model with loss {:.3f}...'.format(best_loss))
            early_stop_count = 0
        else: 
            early_stop_count += 1

        if early_stop_count >= config['early_stop']:
            print('\nModel is not improving, so we halt the training session.')
            return

## **Validation**

In [464]:
def dev(dv_set, model, device):
    model.eval()                                # set model to evalutation mode
    total_loss = 0
    for x, y in dv_set:                         # iterate through the dataloader
        x, y = x.to(device), y.to(device)       # move data to device (cpu/cuda)
        with torch.no_grad():                   # disable gradient calculation
            pred = model(x)# forward pass (compute output)
            #print(pred)
            mse_loss = model.cal_loss(pred, y)  # compute loss
            print(mse_loss)
        total_loss += mse_loss.detach().cpu().item() * len(x)  # accumulate loss
    total_loss = total_loss / len(dv_set.dataset)              # compute averaged loss

    return total_loss

## **Testing**

In [465]:
device = get_device()                 # get the current available device ('cpu' or 'cuda')
#os.makedirs('models', exist_ok=True)  # The trained model will be saved to ./models/

# TODO: How to tune these hyper-parameters to improve your model's performance?
config = {
    'seed': 80215,      # Your seed number, you can pick your lucky number. :)
    'select_all': False,   # Whether to use all features.
    'n_epochs': 50,                # maximum number of epochs
    'batch_size': 141,               # mini-batch size for dataloader
    'learning_rate':1e-3,
    'early_stop': 10,               # early stopping epochs (the number epochs since your model's last improvement)
    'save_path': './models_var/model_gdp_no_cons.pt' , # your model will be saved here
    'valid_ratio': 0.1,   # validation_size = train_size * valid_ratio
}


# **Load data and model**

In [466]:
model = NeuralNet(input_dim=x_train.shape[1]).to(device)  # Construct model and move to device
print(model)

NeuralNet(
  (layers): Sequential(
    (0): Linear(in_features=21, out_features=21, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=21, out_features=1, bias=True)
  )
  (criterion): MSELoss()
)


# **Start Training!**

In [467]:
trainer(train_loader, valid_loader, model, config, device)

Epoch [1/50]: 100%|██████████████████████████████████████████████████████| 2/2 [00:00<00:00, 32.01it/s, loss=0.464]


-------predict: tensor([-0.2672,  0.1972,  0.2223,  0.2213,  0.4988, -0.2635,  0.1044, -0.0908,
         0.2254,  0.2885,  0.4322,  0.7681, -0.0205,  0.6254,  0.2754, -0.3920,
         0.1301,  0.6112,  0.0498, -0.3820,  0.1285,  0.4366,  0.0991,  0.3344,
         0.2212,  0.2757, -0.2075,  0.1163,  0.5275,  0.4677,  0.2288,  1.0191,
         0.3684, -0.1631,  0.0926,  0.3659, -0.0996,  0.4232,  0.1880,  0.3865,
         0.1612,  0.4054,  0.2070, -0.1229,  0.4490,  0.2506,  0.3447,  0.5494,
         0.2211,  0.2592,  0.1252,  0.2614,  0.0551,  0.5209,  0.0999,  0.1336,
         0.0840,  0.1445,  0.4245,  1.5251,  0.3617,  0.8709,  0.1748,  0.6865,
         0.2957,  0.0406, -0.4904, -0.1823,  0.1664,  0.4909,  0.1708,  0.2483,
         0.4104,  0.2096,  0.3806,  0.4941,  0.1236,  0.3923,  0.4880,  0.2490,
         0.5151,  0.3653, -0.0429, -0.2162, -0.1947,  0.7083, -0.1541,  0.2042,
        -0.2021,  0.2649,  0.3253,  0.6084,  2.1907,  0.3737,  1.0690,  0.4265,
         0.1090,  0.3641

  0%|                                                                                        | 0/2 [00:00<?, ?it/s]

-------predict: tensor([ 0.4726,  0.4325,  0.0763,  0.6669,  0.0965,  0.3843,  0.2821,  0.3977,
         0.2448,  0.0825,  0.2800,  0.1647,  0.2558,  0.3317,  0.2566, -0.0866,
         0.0931,  0.1724,  0.7773,  0.1392,  0.4585,  0.3394,  0.1685,  0.3403,
         2.5054,  0.3206,  0.2442,  0.0444, -0.0882, -0.1145,  0.1186,  0.3894,
        -0.1489,  0.0173,  0.0297,  0.1438,  0.1203,  0.2252,  0.1416,  0.4494,
         0.4780,  0.1004, -0.1861,  0.2277, -0.3130,  0.2130, -0.0344,  0.2966,
         0.4489,  0.5486,  0.2995,  0.1532,  0.0626,  0.3162,  0.2137,  0.0319,
         0.3111, -0.0535,  0.0313,  0.3362,  0.2111, -0.1591,  0.3250,  0.3203,
        -0.3703,  0.1770,  0.1222,  0.1205,  0.0715,  0.4472,  0.3242, -0.0166,
         0.0909,  0.2868,  0.1873,  0.3871,  0.2077, -0.3332,  0.1718, -0.1222,
        -0.3405,  0.2964,  1.0411,  0.8570,  0.7908,  0.6713,  0.0178,  0.4656,
         0.3058,  0.1810,  0.0225,  0.1197,  0.5689,  0.4451,  0.4565,  0.0509,
        -0.0253, -0.3904

Epoch [2/50]: 100%|█████████████████████████████████████████████████████| 2/2 [00:00<00:00, 46.49it/s, loss=0.0253]


tensor([-3.5336e-02, -1.1954e-02,  5.6202e-02, -1.8274e-02, -7.0731e-02,
        -4.1969e-01,  5.6118e-02, -1.3818e-03,  2.0601e-02, -3.0378e-04,
         4.2126e-02,  1.2040e-02,  5.8750e-03, -1.6254e-02,  2.4295e-02,
        -1.6353e-01,  3.8809e-03,  2.8987e-02,  6.2135e-01,  5.4263e-02,
        -2.4948e-01,  2.7419e-01,  3.3819e-02, -2.2636e-02, -6.0509e+00,
        -1.1386e-01, -1.9864e-01,  1.1882e-02, -1.3127e-02,  2.2802e-01,
         1.8101e-01, -9.4473e-02,  2.2287e-02, -5.3402e-01, -1.3853e-01,
         1.9640e-01,  6.0025e-02, -1.1993e-01,  2.6465e-01, -3.1619e-03,
        -7.5340e-02,  2.2071e+00, -1.3775e-01,  6.7521e-01,  1.4527e+00,
         7.6434e-02, -7.2477e-02, -1.6644e-02,  2.0005e-01,  3.2908e-01,
        -2.1879e-02,  1.3043e+00, -1.5848e-02,  1.0088e-01, -9.2765e-02,
         8.9560e-03, -1.1672e-01, -4.3638e-01, -2.2393e-02,  3.0004e-01,
         1.1929e-02, -2.1008e-02, -5.7422e-02,  2.9132e-01, -5.2797e-01,
        -2.8738e-02,  2.2896e-01,  1.3274e-02,  4.3

Epoch [3/50]: 100%|██████████████████████████████████████████████████████| 2/2 [00:00<00:00, 50.54it/s, loss=0.239]


-------predict: tensor([-0.0318,  0.8532, -0.1336,  0.4993,  1.5893,  0.1658,  1.0500,  0.1902,
         0.2137,  0.1455, -0.3020,  0.1623,  0.0455,  0.1289,  0.2314,  0.2935,
         0.0705,  0.7003, -0.5764,  0.9443,  0.5499,  0.0979,  0.6047,  0.1300,
         0.3812,  0.3729,  0.1811,  0.6651,  0.3939,  0.1253,  0.2627, -0.4163,
        -0.1866, -0.0486,  0.5823,  0.1032,  0.2684,  0.2539,  0.0470,  0.0885,
         0.1963,  0.0878,  0.4169,  0.2779,  0.2560,  0.1173,  0.0802,  0.1836,
         0.2555,  0.3681,  0.5022,  0.2872,  0.0908,  2.7226,  0.4906,  0.2026,
         0.2704,  0.0526,  0.0245,  0.4354, -0.5173,  0.0399,  0.2401,  0.1799,
         0.4799,  0.1301,  0.5098,  0.5316,  0.3181,  0.7133,  0.2562,  0.1308,
        -0.1855,  0.2300,  0.6032,  0.7854,  0.0428,  0.6734, -0.1035, -0.6342,
         0.8266,  0.4286,  0.8283,  0.3625,  0.2675,  0.4178,  0.1276,  0.3224,
         0.2636, -0.4329,  0.2146,  0.1523,  0.6901,  0.5954,  0.9672,  0.1923,
         0.0539,  0.9295

Epoch [4/50]:   0%|                                                               | 0/2 [00:00<?, ?it/s, loss=2.18]

-------predict: tensor([ 2.6010e-01,  3.9377e-01,  2.8351e-01,  9.6520e-02,  1.8292e-01,
         9.8381e-02,  4.6660e-01,  3.0665e-01,  5.5439e-01,  2.2861e-01,
         4.0363e-01,  2.6696e-01, -7.0299e-02,  1.5094e-01, -4.6570e-02,
        -2.1496e-02,  2.2211e-01,  5.8481e-01,  7.3939e-01,  6.3207e-01,
         3.3926e-02,  3.6311e-01,  2.4756e-01, -1.2359e-01,  7.7098e-02,
         7.3693e-02,  2.0253e-01,  7.6038e-01,  3.1851e-01,  1.6111e-01,
         3.0752e-01,  4.2374e-01,  1.0440e+00,  1.5491e-01, -9.6438e-01,
         2.4391e-01,  3.2132e-01,  3.2527e-01,  3.0933e-01, -5.4641e-02,
         3.5020e-01, -1.1946e-01,  6.5557e-02,  3.7084e-01,  1.2019e-01,
         3.1394e-01,  4.6536e-01,  3.3188e-01,  1.1781e+00,  8.7761e-01,
         6.7130e-02,  2.8431e-01,  2.7254e-01,  1.0517e-01,  9.0923e-02,
         9.5801e-01,  5.0276e-01,  2.4711e-01,  2.3391e-01, -5.2274e-02,
         1.7772e-01,  1.1511e-01, -2.0988e-01,  2.8885e-01,  4.1268e-01,
         9.9703e-02,  1.8122e-01, -

Epoch [4/50]: 100%|██████████████████████████████████████████████████████| 2/2 [00:00<00:00, 52.52it/s, loss=0.152]


Epoch [4/50]: Train loss: 1.2115, Valid loss: 0.1570
Saving model with loss 0.157...


Epoch [5/50]: 100%|████████████████████████████████████████████████████| 2/2 [00:00<00:00, 57.93it/s, loss=0.00905]


-------predict: tensor([ 0.4020,  0.0382, -0.0843,  0.1083,  0.0219,  0.2732, -0.1328,  1.1189,
        -0.0197,  0.0245,  0.1956,  0.7391,  0.0908,  0.0681,  0.4995,  0.2801,
        -0.2397, -0.1712,  0.6958,  0.1028,  0.2293,  0.2082,  0.8003,  0.4296,
        -0.3379,  0.0441, -0.1177, -0.1330, -0.0984,  0.1297,  0.1174,  0.1076,
         0.6513, -0.0334,  0.9695,  0.4839, -0.2322,  0.3880, -0.6407,  0.4231,
         0.3158,  0.1083,  0.1417,  0.1593,  0.2712,  0.0806,  0.2069,  0.0441,
         0.1539,  0.0980,  0.7370,  0.3114,  0.1421,  0.3040,  0.2559, -0.1450,
        -0.4963,  0.1852,  0.6115,  0.4147,  0.4529,  0.5327, -0.0958, -0.0680,
         0.3741,  0.0526,  0.2772, -0.1636,  0.4043,  1.3144, -0.4172, -0.0624,
        -0.2151,  0.5428,  0.2925, -0.1988,  0.2318,  0.3148,  0.5974,  0.5457,
         0.0408,  0.7575, -0.1086,  0.4182, -0.0089,  0.1941,  0.1167,  0.3154,
         0.3239, -0.0613,  0.3238,  0.2470,  0.6032,  0.0785,  0.5030,  0.4329,
         0.3750, -0.0588

  0%|                                                                                        | 0/2 [00:00<?, ?it/s]

-------predict: tensor([ 0.0184, -0.1211, -0.9503,  0.2468,  0.1306, -0.0046, -0.0160,  0.3383,
         0.3558,  0.6673,  0.6929,  0.4068,  0.0783,  0.2896,  0.1589,  0.2660,
         0.1077,  0.4616, -0.0502,  0.4353,  0.0332,  0.2275,  0.0926,  0.4130,
        -0.2312,  0.0486,  0.0341,  0.0307, -0.0914, -0.3093,  0.0027,  0.3102,
         0.0810,  0.1680,  0.1322,  0.2073,  0.0708,  0.9004,  0.5378,  0.1124,
         0.6580,  0.3664,  0.0959, -0.0120,  0.1872,  0.0194,  0.2282,  0.6772,
         0.2860,  0.0252,  0.1724,  0.0847, -0.1106,  0.0757,  0.3472,  0.5095,
        -0.0815, -0.4963,  0.1445,  0.1811,  0.2239,  0.6903,  0.4454, -0.9372,
         0.2362,  0.6842, -0.0136,  0.1264,  0.0775,  0.1244,  0.3068, -0.3139,
         0.2393,  0.0641, -0.1488,  0.2028,  0.7529, -0.2480,  0.8426, -0.3385,
         0.1394,  0.8032,  0.6871,  0.6923,  0.1943,  0.0450, -0.0614,  0.3304,
         0.1756,  0.0348,  0.3496,  0.3179,  0.0596,  0.0155,  0.4724,  1.0556,
         0.7864,  0.3033

Epoch [6/50]: 100%|██████████████████████████████████████████████████████| 2/2 [00:00<00:00, 56.15it/s, loss=0.311]


-------predict: tensor([-0.2955], device='cuda:0', grad_fn=<SqueezeBackward1>), y: tensor([-0.2262], device='cuda:0')----------
tensor([0.2580])
Epoch [6/50]: Train loss: 1.2285, Valid loss: 0.1465
Saving model with loss 0.147...


  0%|                                                                                        | 0/2 [00:00<?, ?it/s]

-------predict: tensor([ 3.5321e-01,  5.1862e-01,  3.9542e-01,  6.3715e-01,  3.8044e-01,
        -9.0155e-04,  6.7425e-02,  2.2060e-01, -5.4846e-03,  3.5094e-01,
        -1.3675e-03,  1.2296e-01,  2.3467e-02,  2.5652e-02,  7.9949e-01,
        -1.2002e-01,  5.2043e-01,  6.9433e-01, -3.8556e-02,  1.9133e-01,
         2.5015e-01, -2.2779e-02,  3.7064e-01,  2.3137e-01, -1.1639e-01,
         2.2110e-01,  8.6293e-01,  4.8806e-02,  1.3897e+00,  4.5538e-01,
         5.8648e-01, -4.5405e-02, -2.0814e-01,  8.6047e-01, -4.4173e-01,
         2.5664e-01,  1.1657e+00, -2.9354e-01,  2.9649e-01,  1.4053e-01,
         1.4303e-01,  6.6830e-01, -7.7745e-02,  1.4529e-01,  1.1928e+00,
         1.2614e-01,  1.2452e-01, -7.3423e-01,  4.2634e-01,  1.1814e-01,
         4.7674e-01,  9.9287e-02, -9.5870e-02,  1.0251e-01, -4.5315e-02,
         1.8962e-01,  3.1047e-02,  3.8549e-01, -4.7213e-01,  2.4630e-01,
         3.4033e-01,  6.4071e-01,  5.1970e-01,  5.4169e-02, -6.6353e-02,
        -7.4880e-02,  2.7669e-01,  

Epoch [7/50]: 100%|███████████████████████████████████████████████████████| 2/2 [00:00<00:00, 57.23it/s, loss=1.01]


tensor([-2.5499e-01, -7.5340e-02, -9.4473e-02, -2.8738e-02, -3.1619e-03,
         8.9560e-03, -4.1969e-06,  1.3904e-02, -1.9705e-02,  1.4809e-01,
        -3.0333e-02,  1.4238e-01,  5.1115e-02, -3.0216e-02,  2.6243e-02,
         1.1882e-02,  2.6465e-01, -9.7351e-02, -4.9693e-01, -3.0378e-04,
         2.0005e-01, -2.1879e-02,  2.9132e-01,  2.7172e-01,  5.6118e-02,
         1.8101e-01,  2.2896e-01, -5.2797e-01,  3.4854e-02, -4.7574e-02,
         3.3819e-02, -1.2142e-01,  2.4615e-01,  9.3753e+00, -5.5249e-01,
         2.8582e-02, -5.8953e-01, -4.1969e-01, -7.0731e-02, -1.6644e-02,
         5.5600e-01,  2.4800e-01, -3.2542e-02,  6.0025e-02, -6.0509e+00,
        -9.4370e-02,  2.0601e-02,  5.6202e-02, -1.1672e-01,  2.5802e-01,
        -2.1008e-02,  5.5103e-03,  2.2802e-01,  1.1929e-02,  3.8809e-03,
        -2.2636e-02,  5.8750e-03,  2.7804e-01,  1.4527e+00,  2.8527e-01,
        -1.1954e-02,  2.7158e-02, -1.1386e-01,  2.2071e+00,  4.3104e-02,
        -4.5970e-02,  4.5290e-02,  3.6443e-01, -1.4

  0%|                                                                                        | 0/2 [00:00<?, ?it/s]

-------predict: tensor([ 6.0811e-01,  2.4440e-01,  4.3824e-04, -1.0032e+00,  4.5284e-01,
        -2.9315e-01,  6.2900e-02, -4.3785e-01,  1.0453e+00,  3.9240e-01,
         4.1775e-02, -2.9550e-01,  2.4099e-01,  8.2136e-02,  3.5367e-01,
         3.8944e-01,  2.3993e-01,  1.9283e-01,  3.7292e-01,  1.3614e-01,
         6.5832e-01,  1.1853e-01, -1.4112e-01, -1.2073e-01,  8.6609e-02,
         8.0347e-03, -9.7287e-02,  9.4272e-02, -2.2505e-02,  3.7347e-01,
         2.5874e-02,  5.0619e-02,  7.6640e-01,  3.8880e-01, -2.0854e-02,
         2.3351e-01,  2.5391e-01,  1.8572e-01,  8.7120e-01,  1.9516e-01,
         1.3820e+00, -1.2101e-01,  4.3951e-01,  1.3058e-01,  1.7800e-01,
         1.9434e-01,  1.5238e-02,  8.1477e-01,  6.7415e-01,  2.6526e-01,
         3.5155e-01,  7.1297e-01,  1.9813e-01,  8.7428e-01, -5.6896e-01,
         5.6479e-01, -6.2335e-02,  1.9471e-01,  3.9604e-01,  3.0548e-01,
         6.2098e-01, -5.1358e-03,  4.9902e-02,  4.4963e-02,  8.5511e-02,
         9.9452e-02, -6.4628e-02,  

Epoch [8/50]: 100%|███████████████████████████████████████████████████████| 2/2 [00:00<00:00, 58.74it/s, loss=0.16]


-------predict: tensor([0.0616], device='cuda:0', grad_fn=<SqueezeBackward1>), y: tensor([-0.2420], device='cuda:0')----------
tensor([-0.1986])
Epoch [8/50]: Train loss: 1.2646, Valid loss: 0.1359
Saving model with loss 0.136...


Epoch [9/50]:   0%|                                                               | 0/2 [00:00<?, ?it/s, loss=2.44]

-------predict: tensor([ 0.0305,  0.1158, -0.0133, -0.1583,  0.3158,  0.1569,  0.0342, -0.0040,
        -0.1463, -0.1547,  0.3846,  0.1997,  0.8582, -0.0695,  0.0374,  0.3862,
         0.7355,  0.1039,  0.1507,  0.0165,  0.5343,  0.0417,  0.5678, -0.0603,
         0.0292,  0.1954,  0.3843, -0.6166,  0.1367,  0.1225,  0.0237,  0.1006,
         0.1458,  0.1248,  0.2522, -0.0853,  0.3027,  0.5203, -0.0603,  0.0202,
        -0.0827,  0.2762, -0.7193, -0.0860,  0.4201,  0.0513,  0.4720, -0.0998,
         0.0395,  0.1546,  0.2696, -0.4906,  0.0103,  0.0364,  0.1811,  0.0887,
         0.0582,  0.2104,  0.0165,  0.0048,  0.1698,  0.0497, -0.2179,  0.2772,
        -0.1208, -0.1813, -0.0538, -0.6221,  1.1616,  0.3191, -0.6648,  0.0259,
         0.2279,  0.3797,  0.2302,  0.1147,  0.6087,  0.3391,  0.2463,  0.0122,
         0.0353, -0.2687,  0.1600, -0.3150,  0.0287,  0.4156, -0.0394, -0.1890,
         0.2982, -1.1560,  0.3752,  0.1271, -0.2779,  0.4587,  0.2027, -0.0373,
        -0.1350, -1.0241

Epoch [9/50]: 100%|██████████████████████████████████████████████████████| 2/2 [00:00<00:00, 59.34it/s, loss=0.143]


Epoch [9/50]: Train loss: 1.2673, Valid loss: 0.1303
Saving model with loss 0.130...


Epoch [10/50]: 100%|████████████████████████████████████████████████████| 2/2 [00:00<00:00, 61.06it/s, loss=0.0472]

-------predict: tensor([-8.7293e-02,  1.8451e-01, -6.2486e-01, -4.1598e-01, -2.1305e-01,
         1.4285e-01,  1.1096e-02,  1.6910e-01,  4.4360e-04,  1.2494e-01,
        -1.4922e-01,  4.3428e-02, -5.2649e-02,  1.4030e-01,  2.4308e-01,
         4.2712e-01,  6.4458e-02,  3.5532e-01,  2.7773e-01, -7.7564e-02,
        -1.3630e-01,  1.3237e-01, -9.9564e-02, -2.6469e-01,  2.9763e-01,
        -2.2439e-01,  3.6187e-01, -4.0695e-02, -8.8736e-02,  2.6897e-01,
        -8.9173e-02,  1.7109e-01,  6.7130e-02,  2.9448e-01, -4.0463e-02,
         3.0384e-01, -3.6434e-01, -5.8284e-02,  2.6114e-01,  8.5199e-02,
         3.2995e-01,  2.1819e-01,  3.6866e-01,  4.0607e-01, -4.2406e-02,
         5.2666e-01, -2.4065e-01,  1.6875e-01,  3.0734e-01, -1.1604e-01,
         1.7326e-01,  1.1393e-01,  5.9355e-02,  6.0097e-02, -1.0058e-01,
        -3.3745e-02,  2.7717e-01,  8.4597e-02,  3.2570e-01,  1.3701e-01,
        -3.8035e-03, -1.1281e-01,  3.7692e-01,  1.3379e-02,  2.8156e-01,
        -5.5887e-02,  7.8353e-02, -




Epoch [10/50]: Train loss: 1.2488, Valid loss: 0.1249
Saving model with loss 0.125...


Epoch [11/50]: 100%|██████████████████████████████████████████████████████| 2/2 [00:00<00:00, 60.88it/s, loss=1.19]

-------predict: tensor([ 0.0507,  0.6409,  0.5295,  0.4330, -0.0441,  0.1039,  0.1436,  0.9482,
         0.1819,  0.1700,  0.1217,  0.3220, -0.4998, -0.0541, -0.2247,  0.3660,
         0.0320,  0.0205, -0.0292,  0.0431, -0.0382,  0.0366, -0.0107,  0.4814,
         0.0492, -0.1303, -0.0598, -0.1603, -0.0594, -0.0752, -0.1466,  0.0613,
         0.1782,  0.0574, -0.2851,  0.4276,  0.7935, -0.2445,  0.1016,  0.4178,
         0.0642,  0.2942, -0.0509,  0.1815,  0.1059,  0.4944, -0.0285,  0.8580,
        -0.0538,  0.0401,  0.1590, -0.1119, -0.0221,  0.0634, -0.0075,  0.1163,
         0.2627,  0.0226, -0.0779,  0.2253,  0.0353, -0.0877, -0.2203, -0.0831,
         0.0576, -0.1096,  0.4837,  0.0099,  0.0314,  0.2644,  0.3387, -0.1243,
        -0.0292, -0.3218,  0.3396,  0.0719,  0.0078, -0.1630, -0.2055, -0.8436,
         0.0136,  0.2393, -0.0234,  0.0979,  0.0092,  0.1803, -0.0447, -0.1167,
        -0.0567,  0.4347,  0.2244, -0.1373, -0.5136, -0.2442,  0.2499,  0.1208,
         0.1008,  0.0209




Saving model with loss 0.120...


Epoch [12/50]: 100%|██████████████████████████████████████████████████████| 2/2 [00:00<00:00, 61.08it/s, loss=1.52]

-------predict: tensor([-0.0094,  0.9239, -0.1876,  0.5926, -0.1747, -0.1542, -0.0633, -0.5631,
         0.0842,  0.1449, -0.7664,  0.0349,  0.0176,  0.1895, -0.1175,  0.3307,
         0.0833,  0.1274,  0.3456, -0.2146, -0.0213,  0.0470, -0.1110, -0.0428,
        -0.1045, -0.0190, -0.1721,  0.1214,  0.0244,  0.3188, -0.1507,  0.1883,
         0.0785, -0.0476, -0.4320, -0.2643, -0.1199,  0.0505,  0.0162,  0.2998,
        -0.0633, -0.6266,  0.0463, -0.1664,  0.9073,  0.0792,  0.2634,  0.1076,
         0.1518,  0.1490,  0.2980,  0.3028, -0.9130,  0.3894,  0.1263,  0.0942,
        -0.0418,  0.2954,  0.3664,  0.3267,  0.3397,  0.1053,  0.2019, -0.0673,
         0.0554,  0.3680, -0.4319, -0.0098,  0.3246, -0.1252, -0.0480,  0.1768,
         0.1367, -0.0404,  0.2233, -0.0453, -0.0313, -0.1491,  0.0825, -0.0881,
         0.2059,  0.0815, -0.0735, -0.0117,  0.6605, -0.2536,  0.0451,  0.5468,
         0.5293,  0.0746, -0.8561,  0.1055,  0.1226, -0.1333, -0.3215, -0.1286,
        -0.5261,  0.0080




Saving model with loss 0.116...


Epoch [13/50]: 100%|████████████████████████████████████████████████████| 2/2 [00:00<00:00, 62.84it/s, loss=0.0312]

-------predict: tensor([ 1.7464e-01, -3.9975e-02,  4.9631e-01, -1.5982e-01,  2.6893e-01,
        -1.7591e-01,  9.1704e-03,  3.3111e-01, -3.9795e-01, -1.3796e-01,
         3.2050e-01,  8.0242e-02, -8.4139e-02, -4.3718e-01, -3.2105e-01,
        -1.7709e-02,  3.4975e-02,  1.2134e-01, -4.0938e-02,  7.6333e-01,
        -1.9380e-01, -1.9910e-03,  3.9485e-02,  7.7095e-02, -1.4453e-01,
         5.9194e-03, -3.5238e-01,  1.0217e-01,  1.2238e-01,  5.8487e-02,
        -2.6428e-02,  1.6089e-01,  1.3309e-01,  1.1612e-01, -2.9514e-01,
         9.0848e-04, -3.0921e-01,  7.2679e-02,  7.6659e-02,  1.3942e-01,
         3.9756e-03, -1.0921e-01,  7.2325e-01, -2.5035e-01,  4.5446e-01,
         3.7333e-01,  1.6263e-01, -2.3362e-01, -4.0053e-02,  1.0377e-01,
         4.0894e-02,  2.1903e-02, -3.5927e-01, -1.0097e-01, -3.8649e-01,
         5.7170e-02, -2.5114e-01,  1.6684e+00,  1.8147e-01,  1.5061e-01,
         1.8532e-01, -1.8628e-01,  5.8911e-01, -6.8036e-02,  4.5004e-02,
        -1.9017e-01,  6.9344e-02, -




Saving model with loss 0.113...


  0%|                                                                                        | 0/2 [00:00<?, ?it/s]

-------predict: tensor([ 0.2703, -0.1992,  0.0666,  0.0668,  0.1499, -0.0686, -0.2191,  0.0545,
        -0.1546,  0.2510,  0.1217, -0.0035,  0.5252,  0.2979,  0.1026, -0.2602,
         0.1713,  0.0095,  0.1175, -0.4130, -0.0476, -0.3041,  0.1492,  0.5582,
         0.1971,  0.2330, -0.0674, -0.0624,  0.1971,  0.1852,  0.1626,  0.4655,
         0.6486,  0.3099, -0.2056, -0.0402,  0.0593, -0.1282,  0.0063, -0.0100,
         0.3298,  0.0330, -0.1218,  0.5694, -0.1489,  0.3772,  0.2169,  0.4230,
         0.0596,  0.1957,  0.3214, -0.1702, -0.4292, -0.1105, -0.0695, -0.3787,
        -0.2052, -0.1908, -0.0895,  0.1205,  0.4756, -0.0560, -0.0900,  0.1022,
        -0.0773,  0.2382,  0.4315,  0.0257, -0.1174,  0.0201,  0.1242, -0.0109,
         0.2344, -1.1337,  0.8741,  0.1946,  0.0733,  0.1635, -0.0658, -0.1191,
        -0.7296, -0.0670, -0.0422,  0.4868,  0.1592,  0.4082,  0.2179, -0.5305,
         0.1242,  0.0573, -0.1759,  0.1575, -0.6123,  0.0479, -0.0579,  0.1923,
         0.1167,  0.1808

Epoch [14/50]: 100%|█████████████████████████████████████████████████████| 2/2 [00:00<00:00, 58.84it/s, loss=0.105]

tensor([-1.1386e-01, -3.1619e-03, -7.1206e-02, -1.9705e-02, -3.4175e-01,
         6.6167e-01, -2.4948e-01, -1.2142e-01,  4.2126e-02,  3.0004e-01,
         1.9640e-01,  4.7452e-02,  1.5892e-01,  8.9280e-03,  2.7419e-01,
         1.2040e-02, -5.7422e-02,  2.8527e-01,  1.3904e-02, -4.3638e-01,
         1.1890e-01, -1.6644e-02, -1.9864e-01, -4.5970e-02,  2.8987e-02,
         2.4800e-01, -1.3507e-01, -1.3818e-03, -6.3523e-02,  1.9810e-01,
         5.8750e-03,  2.2802e-01,  3.2908e-01,  1.3274e-02,  4.3104e-02,
        -3.0333e-02, -2.1008e-02, -4.1969e-01, -1.4102e-04,  2.4295e-02,
         2.2071e+00,  5.5103e-03,  7.4072e-02,  2.0005e-01, -9.4370e-02,
         2.7158e-02,  2.7804e-01, -2.3377e-01,  2.8582e-02,  2.9952e-01,
        -2.8738e-02, -3.7727e-02, -4.9693e-01, -4.1669e-02, -4.1969e-06,
        -1.3775e-01,  5.6189e-02,  3.4854e-02,  4.5290e-02,  2.7031e-01,
        -1.6353e-01, -2.1879e-02,  6.7521e-01,  4.1903e-02,  2.8127e-02,
         5.5241e-02, -7.2309e-02,  2.4615e-01,  3.3




Saving model with loss 0.110...


  0%|                                                                                        | 0/2 [00:00<?, ?it/s]

-------predict: tensor([-0.1452, -0.2047,  0.2493,  0.1727,  0.0365,  0.0944, -0.0688, -0.1408,
         0.1424,  0.2559, -0.1806, -0.0691,  0.5343,  0.0692, -0.4759, -0.1473,
         0.6323, -0.0233,  0.1679,  0.3702,  1.3131,  0.0412, -0.1856,  0.2234,
        -0.0858,  0.3724, -0.1185,  0.1168, -0.0982,  0.0791, -0.2572,  0.0292,
         0.4276, -0.0986,  0.0681,  0.3211,  0.2597,  0.3334, -0.4002,  0.4095,
        -0.0288, -0.0155,  0.1977, -0.0350,  0.1451, -0.0395, -0.1628, -0.0650,
        -0.1767,  0.0220,  0.1266, -0.0876, -0.0422, -0.0424,  0.1425,  0.2184,
         0.1987,  0.2057, -0.0384, -0.0260,  0.2617,  0.0098, -0.0382, -0.1998,
        -0.5363,  0.1235,  0.1518,  0.6263,  0.4955,  0.4383,  0.2593,  0.0209,
        -0.0302,  0.3697,  0.4627,  0.2510,  0.1926,  0.0402,  0.1746, -0.1985,
        -0.1155, -0.3188,  0.4335,  0.1324,  0.1028, -0.4710, -0.0289, -0.0293,
         0.0177,  0.1725, -0.2147,  0.1180, -0.5783, -0.0852,  1.0850,  0.1089,
         0.1032,  0.1706

Epoch [15/50]: 100%|████████████████████████████████████████████████████| 2/2 [00:00<00:00, 59.71it/s, loss=0.0114]

tensor([-1.4991e-01, -7.1206e-02,  2.7265e-01,  2.9132e-01,  3.3819e-02,
         4.1903e-02,  5.5103e-03, -1.2142e-01,  1.4238e-01, -7.5340e-02,
         4.3104e-02, -3.0737e-02, -5.2797e-01,  1.0088e-01, -2.4948e-01,
        -1.3775e-01,  2.9952e-01,  7.2377e-02,  8.9560e-03,  2.4615e-01,
        -6.0509e+00, -7.2477e-02, -4.1969e-06,  1.0193e-01,  5.6202e-02,
         6.2135e-01, -2.1879e-02,  2.7172e-01,  5.4263e-02, -3.7727e-02,
        -1.6644e-02, -3.7125e-01, -1.5848e-02,  1.4430e-02,  1.1882e-02,
        -1.3507e-01,  2.2802e-01,  5.8750e-03, -5.5249e-01, -4.1969e-01,
         2.7419e-01, -6.3523e-02,  4.7452e-02, -1.2856e-02, -4.5970e-02,
        -4.3638e-01, -7.2309e-02,  5.5600e-01,  1.9640e-01,  2.8582e-02,
         2.9925e-02, -3.0333e-02, -2.2393e-02,  1.2040e-02, -3.2542e-02,
        -3.5336e-02, -2.8738e-02, -1.3127e-02,  5.5241e-02, -2.1008e-02,
         1.3274e-02,  2.8127e-02, -1.9864e-01, -5.7422e-02, -5.8953e-01,
         2.5802e-01, -3.1619e-03,  1.5892e-01,  2.0




Saving model with loss 0.108...


Epoch [16/50]:   0%|                                                                         | 0/2 [00:00<?, ?it/s]

-------predict: tensor([-5.1327e-01, -3.6640e-01,  5.9422e-01,  1.6777e-01, -1.8629e-02,
         2.5981e-01,  2.0434e-01, -1.7085e-02, -6.5534e-02,  3.3795e-01,
        -3.3004e-01,  4.4490e-01, -2.4000e-01, -1.9114e-01, -1.2687e-01,
        -1.5377e-01,  2.0851e-01,  1.6401e-01,  1.3451e-01,  1.4785e-01,
        -1.8670e-02,  1.8285e+00,  2.5724e-02, -4.1766e-02,  1.7360e-01,
        -7.4915e-02, -1.3765e-01, -9.1141e-02,  5.7835e-02,  1.6978e-01,
        -1.1432e-01, -6.6591e-02,  6.0733e-02,  8.9775e-01,  2.8730e-02,
        -1.7196e-01,  5.4915e-01,  5.4534e-01,  2.5627e-01, -9.1477e-02,
         6.9874e-01,  3.7109e-02, -2.0171e-01, -2.4891e-02,  6.0141e-03,
         3.4043e-01,  8.8038e-02,  4.3261e-01, -1.3578e-01,  6.9873e-02,
        -1.2159e-03,  6.4769e-02, -5.7872e-02, -4.8184e-02,  2.3997e-01,
        -9.1402e-03, -6.7924e-02,  1.5246e-01, -9.3390e-02,  2.1621e-02,
        -8.1897e-02,  6.2045e-01,  8.1911e-02,  4.2709e-02,  1.4274e-01,
         4.0455e-01,  4.9449e-02, -

Epoch [16/50]: 100%|███████████████████████████████████████████████████| 2/2 [00:00<00:00, 60.58it/s, loss=0.00841]

-------predict: tensor([-0.0495], device='cuda:0', grad_fn=<SqueezeBackward1>), y: tensor([-0.0460], device='cuda:0')----------
tensor([0.0421])
Epoch [16/50]: Train loss: 1.2399, Valid loss: 0.1059





Saving model with loss 0.106...


Epoch [17/50]:   0%|                                                               | 0/2 [00:00<?, ?it/s, loss=2.2]

-------predict: tensor([ 0.2794, -0.1455,  0.1493,  0.1957, -0.0403, -0.0144,  0.1970,  0.1740,
        -0.0203, -0.0785,  0.0350, -0.5358,  0.2462,  0.0878,  0.0888,  0.4501,
        -0.1075, -0.7281,  0.1122, -0.2621, -0.0018, -0.1115,  0.0028,  0.4726,
         0.1435, -0.0338, -0.1584,  0.7002,  0.3447, -0.0031, -0.3892,  0.3323,
        -0.1727, -0.3555, -0.0861,  0.0314, -0.3070, -0.0418, -0.0901,  0.0107,
         0.6005,  0.2583, -0.0184,  0.1203,  0.4278,  0.4411,  0.0861,  0.1392,
         0.0071, -0.1397,  0.2147,  0.1244,  0.1829, -0.1127,  0.6058, -0.5562,
         0.1806, -0.1617,  0.0022,  0.5913, -0.1878, -0.0681,  0.1077,  0.1481,
        -0.6102,  0.1066,  0.1889,  0.5240,  0.1996,  0.2914,  0.3246,  0.1610,
         0.3351,  0.3004,  0.3324,  0.1420,  0.5123,  0.1202,  0.2153, -0.0463,
         0.3961,  0.3300, -0.1183,  0.2129, -0.0439, -0.6527,  0.3236,  0.0847,
        -0.0032,  0.0302,  0.3512,  0.5862, -0.0197, -0.1084, -0.0729, -0.1980,
         0.0597, -0.0293

Epoch [17/50]: 100%|█████████████████████████████████████████████████████| 2/2 [00:00<00:00, 59.74it/s, loss=0.056]

-------predict: tensor([0.0327], device='cuda:0', grad_fn=<SqueezeBackward1>), y: tensor([-0.2033], device='cuda:0')----------
tensor([0.0144])
Epoch [17/50]: Train loss: 1.2332, Valid loss: 0.1043





Saving model with loss 0.104...


Epoch [18/50]:   0%|                                                                         | 0/2 [00:00<?, ?it/s]

-------predict: tensor([-5.3501e-02, -2.4042e-01,  3.5521e-03, -8.3736e-02, -2.4637e-01,
         5.3358e-01, -5.0492e-01,  4.1861e-01,  1.2796e+00, -1.8386e-01,
         4.5914e-01, -2.1873e-01,  1.6368e-01, -1.2231e-02,  1.2609e-01,
         3.8881e-02, -5.2440e-02, -9.6085e-02,  6.9171e-02,  1.9210e-01,
        -9.1177e-02,  2.4258e-01, -4.2399e-02,  2.0974e-01, -1.2178e-01,
        -3.0384e-01,  1.0746e+00,  2.1545e-01, -3.0156e-01, -1.5315e-01,
         3.9992e-01, -4.0745e-02, -1.7351e-02,  1.0193e-02, -5.0919e-02,
         2.6608e-01,  6.5109e-02,  2.1873e-01,  9.2028e-02, -2.6413e-01,
        -1.1832e-01, -5.0736e-02,  1.6568e-01, -1.6962e-01,  2.5725e-01,
        -1.3222e-01,  1.5493e-01,  1.3164e-01,  2.0511e-01,  1.6986e-01,
         3.8427e-01, -6.8389e-01,  5.1288e-03,  7.0637e-02, -3.5459e-04,
         1.4371e-01,  1.2887e-01,  5.4105e-01,  9.7745e-02,  4.7535e-02,
         2.5753e-01,  1.5363e-01, -1.4039e-02, -6.9645e-02,  4.5967e-02,
        -7.7682e-02,  1.1377e-01,  

Epoch [18/50]: 100%|█████████████████████████████████████████████████████| 2/2 [00:00<00:00, 60.09it/s, loss=0.099]

-------predict: tensor([-0.1666], device='cuda:0', grad_fn=<SqueezeBackward1>), y: tensor([-0.1608], device='cuda:0')----------
tensor([0.1481])
Epoch [18/50]: Train loss: 1.2186, Valid loss: 0.1028





Saving model with loss 0.103...


Epoch [19/50]:   0%|                                                              | 0/2 [00:00<?, ?it/s, loss=1.91]

-------predict: tensor([ 6.9399e-03,  1.5484e-01, -9.6422e-02,  4.5389e-01,  3.0559e-01,
        -9.0133e-02, -3.9375e-02,  2.3721e-01, -2.8558e-01, -2.2490e-01,
        -1.6672e-02,  7.6473e-03,  2.1957e-01, -1.7224e-01,  3.4206e-01,
         2.4293e-01, -4.8222e-02,  3.7549e-02,  1.8765e-01,  2.0368e-01,
        -1.0341e-01,  2.7348e-01, -3.3146e-02,  3.0514e-01, -1.0999e-01,
         3.8708e-01,  1.8307e-01,  7.9743e-02, -2.5928e-01, -2.8652e-01,
         1.7697e-01,  3.6323e-01, -9.0667e-02,  5.3473e-01,  2.3314e-01,
        -4.1952e-01, -1.4329e-01, -5.1383e-02,  1.5745e-01, -6.4780e-03,
         4.6121e-01,  2.7333e-01,  1.2284e-01,  1.2405e-01,  1.7209e-01,
        -3.2113e-01,  5.2054e-01,  4.2685e-02, -4.2439e-02,  3.3534e-01,
         3.9396e-01, -1.4945e-01, -3.6496e-01,  4.3392e-01, -6.4125e-02,
         2.6182e-01,  2.1461e-02, -2.0590e-01,  4.7238e-01,  1.2510e-01,
         1.9497e-01, -1.6890e-01, -3.6701e-02, -1.3300e-02,  5.0717e-01,
         5.3428e-01,  1.0939e-01, -

Epoch [19/50]: 100%|█████████████████████████████████████████████████████| 2/2 [00:00<00:00, 58.99it/s, loss=0.324]

-------predict: tensor([0.2370], device='cuda:0', grad_fn=<SqueezeBackward1>), y: tensor([-0.2422], device='cuda:0')----------
tensor([-0.0707])
Epoch [19/50]: Train loss: 1.2131, Valid loss: 0.1014





Saving model with loss 0.101...


Epoch [20/50]:   0%|                                                              | 0/2 [00:00<?, ?it/s, loss=1.91]

-------predict: tensor([-9.9708e-02, -1.2734e-01, -1.3828e-01,  6.8401e-01,  2.0316e-01,
        -6.5051e-02,  7.0873e-02, -1.8630e-01, -1.9136e-02,  1.5849e-01,
         8.7470e-02,  3.2033e-01, -8.4016e-02,  2.6177e-01,  3.1093e-01,
        -6.1329e-01,  1.2684e-01,  7.6972e-02, -9.8186e-02,  3.5877e-01,
         5.4284e-01,  8.0579e-01,  1.9663e-01, -1.3507e-01, -4.8978e-03,
        -9.9218e-02, -1.4580e-01,  1.1718e-01,  5.9596e-01, -2.1079e-02,
         2.8949e-01, -3.1354e-02,  1.5382e-02,  3.6366e-01, -5.3129e-02,
        -7.5489e-02,  1.2199e-01, -2.8764e-01, -1.2702e-01, -1.9111e-01,
         4.3837e-01, -3.7538e-01,  1.6993e-02,  3.0749e-01,  3.8732e-02,
        -1.3606e-01,  5.5306e-01,  9.9321e-02, -9.0997e-02,  1.0910e-01,
         2.0818e-01, -3.5510e-01, -5.4502e-02,  4.5284e-02,  2.3361e+00,
        -8.3687e-02,  6.1993e-01, -7.9861e-01, -7.6312e-02,  2.4269e-01,
         4.1063e-01,  2.0578e-01, -7.2353e-02,  5.1807e-02,  1.2587e+00,
         2.8539e-01, -7.8596e-02,  

Epoch [20/50]: 100%|█████████████████████████████████████████████████████| 2/2 [00:00<00:00, 62.80it/s, loss=0.171]


-------predict: tensor([0.0438], device='cuda:0', grad_fn=<SqueezeBackward1>), y: tensor([-0.2418], device='cuda:0')----------
tensor([-0.2550])
Epoch [20/50]: Train loss: 1.2046, Valid loss: 0.1000
Saving model with loss 0.100...


  0%|                                                                                        | 0/2 [00:00<?, ?it/s]

-------predict: tensor([ 2.2126e-01,  1.0962e-01,  2.3307e-01, -1.4214e-02,  9.2167e-02,
         6.1040e-02, -8.7167e-02, -1.8690e-01,  3.1305e-01,  1.2662e-01,
        -2.0073e-01, -1.7253e-02,  2.9567e-01,  1.3089e-01,  7.3918e-02,
         6.3232e-01,  2.8061e-01, -8.8038e-02,  1.0749e+00,  1.9904e-01,
         1.7675e-01,  5.4855e-02, -7.0137e-02,  3.2679e-01, -2.4542e-01,
        -8.1313e-02,  1.8760e-02, -2.4078e-01, -9.4422e-02,  4.2938e-01,
         1.8624e-01,  5.9771e-02, -2.6638e-01,  3.2887e-01,  7.0875e-01,
         1.2638e-01,  1.4428e-01, -2.9171e-01,  9.4882e-03,  3.2338e-02,
        -1.4292e-01, -1.6265e-01,  4.4069e-01, -2.2380e-02,  1.2215e-01,
         3.2421e-01,  2.2351e-01, -1.3664e-01,  1.2570e-01,  4.4309e-01,
         2.1107e+00,  2.1491e-01,  1.5684e-01,  5.6746e-01,  5.9395e-01,
         6.3137e-01,  2.5533e-01,  6.1694e-01,  1.8937e-01, -3.0461e-01,
         1.6073e-01, -2.2398e-02,  1.2578e-01, -6.9396e-02,  7.7717e-02,
        -1.0474e+00, -2.1814e-02, -

Epoch [21/50]:   0%|                                                              | 0/2 [00:00<?, ?it/s, loss=1.85]

tensor([-1.3775e-01,  8.9560e-03,  2.2071e+00, -2.2303e-01, -4.9693e-01,
         7.4072e-02, -2.1879e-02,  6.0025e-02,  2.9952e-01,  3.8809e-03,
         5.5600e-01, -7.5340e-02, -5.7422e-02, -5.8953e-01,  2.7031e-01,
         2.6465e-01,  4.7452e-02, -3.2542e-02,  9.3753e+00, -1.9864e-01,
         1.3043e+00,  7.6434e-02,  2.5802e-01, -2.5499e-01, -3.4574e-01,
         5.4263e-02,  1.9640e-01, -3.5336e-02,  2.4615e-01,  2.0005e-01,
         2.8527e-01,  5.6202e-02, -1.2142e-01,  2.2287e-02,  1.4527e+00,
         2.7265e-01, -7.0731e-02, -5.2776e-02,  2.0601e-02,  4.1903e-02,
        -5.5249e-01, -3.1619e-03, -1.2856e-02, -3.7727e-02, -1.9864e-01,
         3.0004e-01,  2.4800e-01, -1.4102e-04,  1.4809e-01, -9.7351e-02,
        -6.0509e+00, -4.1969e-01, -2.8738e-02, -7.2309e-02,  1.1890e-01,
         1.2040e-02,  3.2908e-01,  3.6443e-01,  2.4295e-02, -4.1669e-02,
         2.7158e-02,  1.4430e-02,  1.4238e-01, -1.8274e-02, -4.5528e-03,
        -1.6254e-02,  5.6118e-02, -1.6353e-01,  5.5

Epoch [21/50]: 100%|█████████████████████████████████████████████████████| 2/2 [00:00<00:00, 62.81it/s, loss=0.102]


Epoch [21/50]: Train loss: 1.1938, Valid loss: 0.0984
Saving model with loss 0.098...


  0%|                                                                                        | 0/2 [00:00<?, ?it/s]

-------predict: tensor([ 1.6156e-01, -4.1240e-02,  7.0387e-02,  9.4809e-02,  9.5492e-02,
         1.0477e-01,  6.9870e-01,  9.0813e-02,  2.6967e-01,  4.1134e-02,
        -1.3232e-01,  1.3383e-01,  4.6711e-01,  5.1421e-02, -9.6522e-04,
         1.1749e-01, -1.2602e-01, -3.6067e-03,  3.0801e-02,  1.6048e-01,
         5.0739e-01,  5.4524e-01,  7.8447e-02, -5.2633e-03,  4.8111e-02,
         3.4614e-02, -9.3052e-03,  1.5751e-01,  1.8906e-03,  3.9293e-02,
        -5.0252e-02,  2.4898e-01, -2.1029e-01,  2.3771e-01,  3.5321e-01,
         1.3422e-01,  4.6731e-01,  9.0984e-02,  2.4190e-01, -7.7112e-01,
        -4.4254e-02,  8.6922e-02,  2.4908e-01,  3.4009e-01, -4.0665e-02,
        -1.7363e-01, -3.5010e-01,  2.3418e-03, -2.2182e-01,  9.6314e-02,
         6.5700e-02, -9.1432e-02, -3.5954e-01,  3.7114e-01,  3.9251e-02,
        -8.2300e-02,  6.5581e-02, -2.7056e-01,  3.7120e-01, -2.7243e-02,
        -3.0467e-02, -1.9384e-01,  8.1407e-01,  3.8968e-01,  5.5763e-01,
         1.0114e-01,  2.5703e-02,  

Epoch [22/50]:   0%|                                                              | 0/2 [00:00<?, ?it/s, loss=1.97]

tensor([-4.1969e-01, -1.1993e-01, -4.3638e-01,  1.8101e-01, -5.2797e-01,
         2.8527e-01,  6.2135e-01, -1.1954e-02,  1.2040e-02, -1.2122e-02,
         8.9280e-03,  7.6434e-02,  1.1882e-02,  2.8127e-02,  1.4809e-01,
         2.7804e-01, -5.2776e-02, -2.1008e-02, -2.4948e-01, -2.1879e-02,
        -3.5336e-02,  2.9952e-01,  2.4800e-01,  1.9640e-01,  2.5802e-01,
        -1.9864e-01, -3.7125e-01,  2.2802e-01,  3.8809e-03, -9.4370e-02,
        -1.4991e-01, -5.7422e-02,  1.9887e-02,  2.7265e-01, -1.3818e-03,
        -1.3775e-01, -7.2309e-02,  2.9925e-02, -6.3523e-02,  3.4854e-02,
        -7.5340e-02, -2.2393e-02,  4.5290e-02, -7.2477e-02, -1.3507e-01,
         7.2377e-02, -2.2636e-02,  2.8582e-02, -3.0378e-04,  4.2126e-02,
         2.9132e-01, -3.0333e-02, -5.4034e-01,  3.0004e-01,  5.5241e-02,
         5.6118e-02, -3.4175e-01, -1.9705e-02,  1.3274e-02,  8.9560e-03,
        -4.5528e-03,  5.4263e-02, -1.2856e-02, -1.6353e-01,  9.3753e+00,
         1.4527e+00, -3.7727e-02,  1.3043e+00,  1.1

Epoch [22/50]: 100%|█████████████████████████████████████████████████████| 2/2 [00:00<00:00, 61.80it/s, loss=0.361]


Epoch [22/50]: Train loss: 1.1926, Valid loss: 0.0968
Saving model with loss 0.097...


  0%|                                                                                        | 0/2 [00:00<?, ?it/s]

-------predict: tensor([-2.6323e-02,  1.2281e-02,  2.0302e-01, -2.8125e-02,  1.5262e-02,
        -4.7507e-02, -7.1197e-02, -1.8004e-02,  6.4298e-01, -3.8036e-02,
         3.1967e-01, -8.6255e-02, -2.2034e-02, -1.9795e-01,  2.7716e-02,
        -8.0893e-02,  1.5097e-01,  4.4241e-02,  1.5483e-01,  8.5783e-02,
         6.3017e-02, -7.7099e-02,  2.5181e-01, -1.8153e-01, -5.7704e-02,
         1.9290e-01,  5.5509e-02, -8.2249e-02, -3.5985e-03,  2.9122e-01,
         4.1039e-01,  6.0967e-04,  5.5987e-02,  1.7102e+00,  4.6573e-01,
        -1.8500e-01, -4.8258e-01, -1.4083e-01, -1.7787e-02,  1.7256e-01,
         2.5894e-01,  2.1424e-02, -7.6974e-02,  1.3395e-02,  1.9075e-01,
         1.1373e-01, -1.4040e-01, -8.9895e-02,  8.3744e-02,  2.3895e-01,
         4.4752e-01, -1.2676e-03,  1.6351e-01,  1.2032e-02, -1.0688e-01,
         3.2233e-01, -1.9308e-02,  5.7373e-01,  5.4373e-02,  1.6302e-01,
         1.2126e-01, -2.2777e-01, -2.4177e-02,  3.9130e-02,  1.0907e-01,
         1.4767e-01, -8.5629e-02,  

Epoch [23/50]:   0%|                                                            | 0/2 [00:00<?, ?it/s, loss=0.0704]

-------predict: tensor([-0.0815], device='cuda:0', grad_fn=<SqueezeBackward1>), y: tensor([0.1626], device='cuda:0')----------
tensor([0.0223])


Epoch [23/50]: 100%|████████████████████████████████████████████████████| 2/2 [00:00<00:00, 65.51it/s, loss=0.0704]


Epoch [23/50]: Train loss: 1.1879, Valid loss: 0.0952
Saving model with loss 0.095...


  0%|                                                                                        | 0/2 [00:00<?, ?it/s]

-------predict: tensor([-0.0648, -0.2344,  0.1840,  0.0561, -0.1264, -0.5571, -0.3343,  0.1531,
        -0.0639, -0.1357, -0.3094,  0.0515,  0.2214,  0.3786,  0.2830,  0.1700,
         0.1391,  0.5729, -0.8408,  0.0512,  0.1119, -0.3452, -0.1546,  1.1157,
         0.2390,  0.3609,  0.3120,  0.6055, -0.0647, -0.0482,  0.4835, -0.3605,
         0.0870,  0.0823, -0.0112, -0.0221, -0.0459, -0.0586, -0.1548,  0.1190,
        -0.0383,  0.1909, -0.0659, -0.1945,  0.6070, -0.1512, -0.0383, -0.0817,
        -0.0524, -0.0508, -0.1597, -0.0632,  0.1111,  0.1947, -0.1367, -0.1524,
         0.1465, -0.2667,  0.3089, -0.2881, -0.0456,  0.0642, -0.0564,  0.8253,
         0.0710, -0.2323, -0.1659,  0.0730,  0.0849, -0.0105,  0.2619,  0.3822,
        -0.0270, -0.1552,  0.1067,  0.0528, -0.1220, -0.2893,  0.0232, -0.1501,
         0.2176,  0.6813,  0.0254,  0.0756,  0.0999,  0.1559,  0.4305,  0.1571,
         0.3731,  0.4771, -0.0893,  0.0952, -0.1450,  0.0904,  0.4675,  0.1861,
        -0.0978, -0.1552

Epoch [24/50]: 100%|█████████████████████████████████████████████████████| 2/2 [00:00<00:00, 63.77it/s, loss=0.171]

-------predict: tensor([0.2584], device='cuda:0', grad_fn=<SqueezeBackward1>), y: tensor([-0.0837], device='cuda:0')----------
tensor([0.0262])





Epoch [24/50]: Train loss: 1.1830, Valid loss: 0.0935
Saving model with loss 0.093...


  0%|                                                                                        | 0/2 [00:00<?, ?it/s]

-------predict: tensor([-1.3401e-01,  7.3279e-02,  2.1675e-01,  3.0524e-01,  4.1770e-01,
         5.4603e-01,  1.3139e-01,  3.7821e-01, -2.2866e-01,  1.2052e-01,
        -1.1329e-01,  1.3496e-02,  1.2959e-01,  2.8610e-01,  3.1907e-01,
        -1.5266e-01, -8.4078e-02,  1.6448e-01,  1.9758e-01, -1.8702e-02,
         4.0298e-01,  4.0335e-02,  3.9135e-01,  2.8171e-02, -1.0973e-01,
        -1.2891e-01, -7.7913e-02, -2.1171e-01, -6.1948e-02, -2.0897e-01,
        -7.3532e-01,  3.0399e-01,  1.0378e-01,  4.3756e-01, -2.1398e-02,
         2.0388e-01,  9.5822e-02,  2.6181e-01,  3.6542e-01,  1.2547e-01,
        -4.2811e-01,  1.6625e-02,  7.0372e-03,  1.3773e-01,  7.6300e-01,
         1.7466e-01, -1.8550e-01,  1.7858e-01,  2.6442e-01, -1.8652e-01,
        -9.8811e-02,  4.5291e-02,  1.8553e-01, -1.5906e-01,  4.1307e-02,
        -2.3936e-01,  1.8153e-01,  1.7454e-01,  3.2048e-01,  8.9511e-05,
        -4.8985e-01,  3.8905e-01,  1.2594e-01,  9.7612e-03,  1.1621e-02,
         6.6472e-02, -1.1995e-01,  

Epoch [25/50]: 100%|████████████████████████████████████████████████████| 2/2 [00:00<00:00, 63.48it/s, loss=0.0426]


tensor([-1.3775e-01,  2.8527e-01, -9.7351e-02, -2.2303e-01, -5.2797e-01,
         8.9280e-03, -1.9864e-01, -7.5340e-02, -3.0378e-04, -3.7125e-01,
        -4.1669e-02, -2.2636e-02, -1.2122e-02,  1.3043e+00,  3.3819e-02,
         2.7804e-01, -2.4948e-01,  2.9132e-01,  2.6465e-01, -4.1969e-01,
         1.5892e-01,  7.6434e-02,  2.2071e+00,  3.8809e-03, -1.9864e-01,
         2.5802e-01,  1.9887e-02, -7.0731e-02,  2.0601e-02,  1.3274e-02,
        -5.6639e-01, -1.1386e-01,  1.4238e-01,  2.4295e-02, -2.2393e-02,
        -5.7422e-02, -3.4175e-01,  2.7419e-01, -1.3127e-02, -2.1008e-02,
        -5.3402e-01,  2.7172e-01,  5.5241e-02, -9.4473e-02,  6.2135e-01,
         2.2802e-01,  1.3904e-02,  1.0088e-01,  1.1929e-02, -1.6254e-02,
         1.0193e-01,  7.2377e-02,  2.2287e-02,  1.1890e-01,  6.6167e-01,
         1.1882e-02,  1.4527e+00,  4.7452e-02, -1.5848e-02, -4.5970e-02,
         5.6202e-02,  2.4800e-01, -5.4034e-01, -2.1879e-02, -3.2542e-02,
         2.9952e-01,  8.9560e-03, -1.1672e-01, -1.6

Epoch [26/50]:   0%|                                                              | 0/2 [00:00<?, ?it/s, loss=1.68]

-------predict: tensor([ 2.3226e-01, -2.0444e-02, -6.6876e-02,  8.0876e-01,  1.8654e-01,
        -8.3950e-02, -2.2358e-02, -6.8452e-02, -1.9728e-01, -8.7616e-02,
         1.5372e-01, -5.2658e-04,  9.2054e-02, -7.7195e-02,  7.3158e-01,
         1.3079e-02,  8.9233e-02, -4.9100e-03,  1.5841e-01, -5.5587e-01,
         3.3004e-02, -6.9305e-01, -8.0587e-02,  2.0072e-01, -1.4037e-01,
        -1.3378e-01, -5.1987e-02,  1.9536e+00, -2.7499e-02, -2.1469e-01,
        -7.9315e-02,  5.9576e-02, -4.6593e-01,  2.0499e-01,  2.7741e-01,
         5.2074e-02,  5.7300e-01,  3.9029e-01, -2.7229e-01, -1.0279e-01,
        -4.2983e-01, -1.7707e-01, -2.1621e-02,  1.2658e-01, -1.7206e-01,
        -3.0540e-01,  1.9194e-01,  3.2498e-01, -1.0145e-01, -2.1283e-01,
        -1.1488e-01, -1.2376e-01, -1.6798e-01, -2.7158e-01,  3.7915e-01,
         8.4042e-02,  9.5328e-02, -1.0752e-02, -3.9112e-02, -7.8051e-02,
        -1.7636e-03,  7.2316e-02, -3.6929e-01, -3.5922e-01,  1.7679e-02,
         6.0468e-02,  9.9996e-03,  

Epoch [26/50]: 100%|█████████████████████████████████████████████████████| 2/2 [00:00<00:00, 62.27it/s, loss=0.277]


-------predict: tensor([0.3393], device='cuda:0', grad_fn=<SqueezeBackward1>), y: tensor([-0.0837], device='cuda:0')----------
tensor([0.0262])
Epoch [26/50]: Train loss: 1.1685, Valid loss: 0.0899
Saving model with loss 0.090...


  0%|                                                                                        | 0/2 [00:00<?, ?it/s]

-------predict: tensor([ 0.2520,  0.1525,  0.5673, -0.0611,  0.1050,  0.0683, -0.2202,  0.1282,
        -0.0105, -0.0688, -0.1105,  0.2893, -0.2696,  0.0548, -0.1088, -0.1713,
         0.3251,  0.0917, -0.0228,  0.1468, -0.0964, -0.1528, -0.1694, -0.0419,
        -0.0791, -0.2204, -0.3036, -0.1354, -0.1960, -0.2157, -0.1443,  0.0377,
         0.0676, -0.3532, -0.1063,  0.0152, -0.1333, -0.1262, -0.0593,  0.1278,
        -0.3398, -0.1457,  0.3689, -0.0922,  0.0526, -0.0136, -0.1310,  0.3621,
         0.0239,  0.4537, -0.0649,  0.3161, -0.1158, -0.0951, -0.0064,  0.3785,
        -0.1101, -0.1806, -0.1438,  0.7155,  0.0985, -0.1684, -0.0688,  0.0422,
        -0.1957, -0.1759, -0.4470, -0.1715, -0.0901, -0.3404,  0.0378,  0.1079,
        -0.3144,  0.1150, -0.3260,  0.0351, -0.1795, -0.1339,  0.4987, -0.2231,
         0.0570,  0.0405, -0.0019,  0.2681,  0.0013,  0.3633,  0.0279, -0.2440,
         0.1740, -0.1922, -0.3241, -0.0843,  0.0590,  0.0967, -0.2147,  0.4106,
         0.1445, -0.0576

Epoch [27/50]: 100%|████████████████████████████████████████████████████| 2/2 [00:00<00:00, 61.88it/s, loss=0.0515]

-------predict: tensor([-0.1513], device='cuda:0', grad_fn=<SqueezeBackward1>), y: tensor([-0.2433], device='cuda:0')----------
tensor([0.0561])
Epoch [27/50]: Train loss: 1.1605, Valid loss: 0.0881





Saving model with loss 0.088...


  0%|                                                                                        | 0/2 [00:00<?, ?it/s]

-------predict: tensor([ 0.2663,  0.2947,  0.3105, -0.0789, -0.1910, -0.0583,  0.3144, -0.1071,
         0.1369,  0.1568, -0.2436, -0.0061,  0.0259,  0.0591, -0.1123,  0.1632,
         0.1457, -0.1538,  0.1262,  0.3299, -0.1449, -0.0086,  0.5221,  0.5407,
         0.3327,  0.3989,  0.0384,  0.0754,  0.1436, -0.2126,  0.1479,  0.0390,
        -0.6757, -0.1411,  0.4709, -0.1382,  0.3597,  0.1495, -0.1331, -0.0544,
         0.3273,  0.0502,  0.0907, -0.0824, -0.3607, -0.1350, -0.1238,  0.1442,
        -0.1291, -0.0136,  0.2480,  0.0106,  0.0118, -0.0696, -0.0075,  0.1111,
        -0.0292, -0.3935,  0.0652,  0.3580, -0.1818, -0.1416,  0.1883,  0.3264,
         0.0949,  0.0862,  0.2166,  0.0859, -0.0612, -0.1221, -0.1791, -0.2379,
         0.1984,  0.1557, -0.3759,  0.0122, -0.0544, -0.0172, -0.0448,  0.3103,
         0.0191,  0.5537, -0.4199,  0.4749,  0.1212, -0.1479,  0.7278,  0.0330,
        -0.1060,  0.0537, -0.4311, -0.1248,  0.1122, -0.0090, -0.3061, -0.3002,
         0.0313, -0.1872

Epoch [28/50]: 100%|████████████████████████████████████████████████████| 2/2 [00:00<00:00, 66.74it/s, loss=0.0348]

-------predict: tensor([-0.1315], device='cuda:0', grad_fn=<SqueezeBackward1>), y: tensor([-0.2272], device='cuda:0')----------
tensor([0.0286])
Epoch [28/50]: Train loss: 1.1528, Valid loss: 0.0864





Saving model with loss 0.086...


  0%|                                                                                        | 0/2 [00:00<?, ?it/s]

-------predict: tensor([ 6.0820e-03,  1.0738e-02,  3.0616e-01, -1.6311e-01, -6.0107e-02,
         5.1550e-01,  2.5231e-01,  1.2256e-01,  2.2301e-01,  8.4051e-02,
        -1.3817e-01,  4.4120e-02, -2.6611e-02, -1.0020e-01, -1.4279e-01,
        -2.2973e-01,  3.0916e-02, -1.5872e-01, -1.7213e-01, -1.4449e-02,
        -7.3013e-02,  3.5578e-01,  2.6766e-02, -1.2009e-01,  1.9965e-02,
        -1.4461e-01, -2.4423e-01, -2.1495e-01,  7.4997e-02,  2.0423e-01,
         2.1774e-01, -2.0178e-01, -2.6166e-01, -1.8477e-01, -1.2265e-01,
         5.5601e-02, -1.1085e-01,  1.1051e-01, -8.1653e-02, -5.1238e-02,
        -6.3077e-02, -3.1021e-01, -9.5208e-02, -2.3388e-02, -4.9124e-02,
         6.0428e-02,  8.4449e-02, -3.7928e-01,  5.7181e-02,  2.3714e-01,
        -1.8624e-01,  3.2370e-02,  5.8579e-02,  2.9665e-01,  3.8918e-02,
         4.6287e-02, -1.8195e-01,  1.1669e-01,  2.5459e-03, -8.8217e-02,
        -1.0575e-01,  1.4226e-01,  1.0393e-01, -7.5188e-02,  2.1861e-01,
        -1.9524e-01, -3.2557e-01,  

Epoch [29/50]: 100%|████████████████████████████████████████████████████| 2/2 [00:00<00:00, 62.10it/s, loss=0.0161]


tensor([-3.4574e-01,  7.4072e-02, -4.7574e-02,  2.8127e-02,  2.7172e-01,
        -1.6254e-02,  2.9952e-01,  5.8750e-03, -2.4948e-01,  2.7031e-01,
         7.2377e-02, -5.2797e-01,  2.2287e-02,  1.9887e-02, -1.6644e-02,
        -3.7727e-02,  2.2802e-01, -1.1993e-01, -3.0333e-02, -9.4473e-02,
         1.1882e-02, -5.4034e-01, -1.2142e-01,  1.4238e-01,  6.2135e-01,
         4.1903e-02,  4.2126e-02, -4.1969e-06,  5.5600e-01,  2.2071e+00,
        -2.2303e-01, -3.0378e-04, -1.1954e-02,  5.6189e-02,  3.8809e-03,
        -1.3127e-02, -3.7125e-01,  2.7419e-01, -1.3853e-01,  2.6465e-01,
        -2.1008e-02,  1.4527e+00, -1.4991e-01,  2.4295e-02, -1.1386e-01,
         4.7452e-02, -4.3638e-01, -7.5340e-02, -3.5336e-02, -1.8274e-02,
         5.1115e-02, -1.9705e-02,  1.0193e-01,  1.2040e-02, -4.9693e-01,
        -2.3377e-01, -2.8738e-02,  1.3274e-02, -9.7351e-02, -5.3402e-01,
        -1.2856e-02,  2.0005e-01, -4.1969e-01, -5.6639e-01, -7.2477e-02,
        -5.2776e-02,  5.5241e-02, -4.1669e-02,  6.0

Epoch [30/50]: 100%|█████████████████████████████████████████████████████| 2/2 [00:00<00:00, 64.37it/s, loss=0.085]


-------predict: tensor([-0.2098, -0.0623,  0.1315, -0.0974, -0.3014, -0.2513, -0.1928, -0.2897,
        -0.2406,  0.2779, -0.1461, -0.0764, -0.0443, -0.1484,  0.0354,  0.1380,
        -0.4235,  0.1402,  0.4010, -0.2309, -0.2577, -0.0532, -0.1850, -0.1742,
        -0.1014, -0.2014, -0.0812,  0.0281,  0.3882,  0.2406,  0.3610, -0.1358,
        -0.1149,  0.1352,  0.0438, -0.0437, -0.1438, -0.5626,  0.1689,  0.3113,
        -0.2192,  0.0263,  0.0446, -0.0894,  0.1545,  0.0074, -0.1187, -0.2762,
         0.0526,  0.0043, -0.0796,  0.8836,  0.2251, -0.3068,  0.0710,  0.2689,
         0.0400,  0.0295, -0.1490, -0.2896,  0.1953,  0.1611, -0.1922, -0.0937,
        -0.1237,  0.2009, -0.2678, -0.0381,  0.1561, -0.3082, -0.0123, -0.2780,
        -0.4688, -0.1449,  0.0288, -0.1184,  0.1188, -0.1266,  0.0285,  0.0405,
        -0.0930, -0.0981, -0.2692, -0.4255, -0.0285, -0.1184,  0.1549, -0.1582,
         0.2300, -0.0897, -0.1552, -0.2211, -0.1315, -0.2437,  0.5005,  0.0116,
         0.0067, -0.0125

  0%|                                                                                        | 0/2 [00:00<?, ?it/s]

-------predict: tensor([ 7.1214e-01, -4.2988e-01,  9.8526e-02, -2.1934e-01, -2.2098e-01,
         8.9504e-02,  5.9629e-02, -9.4281e-02, -2.7480e-02, -1.1929e-01,
         4.6384e-02,  5.1729e-02,  1.3776e-01,  8.3909e-02,  2.2776e-01,
        -4.0143e-02, -9.5595e-02,  2.3670e-02,  2.0752e-01, -4.6645e-01,
         4.9566e-02,  6.3703e-02,  3.5167e-03, -4.7557e-02, -1.3610e-01,
         2.2693e-01, -3.1042e-03, -9.9665e-02, -1.6909e-01, -1.1500e-01,
        -6.5189e-01,  1.6297e-01, -2.3729e-01,  9.1800e-02, -1.7003e-01,
        -8.1341e-02, -2.3192e-01, -2.0994e-01, -7.0478e-02, -1.6942e-01,
         2.1762e-01,  2.5654e-01,  8.5405e-03, -9.8995e-02, -1.2953e-01,
         2.7058e-01,  1.3374e-01,  2.3352e-02, -2.7159e-01,  6.2412e-01,
         1.8330e-01,  2.9417e-01, -1.6706e-01, -1.6103e-01, -1.8009e-01,
        -2.0267e-03, -1.1545e-01,  1.5565e-01, -7.8012e-03, -2.3303e-01,
        -2.0815e-01, -2.0878e-01,  2.2732e-01, -4.6783e-01, -8.4228e-02,
         4.3256e-02, -3.8721e-02,  

Epoch [31/50]: 100%|████████████████████████████████████████████████████| 2/2 [00:00<00:00, 63.36it/s, loss=0.0582]


tensor([-1.2856e-02,  5.6202e-02,  5.8750e-03, -1.2122e-02, -3.0737e-02,
        -1.6353e-01, -7.1206e-02,  2.7804e-01,  2.2071e+00,  1.4430e-02,
         2.7031e-01, -3.4574e-01, -6.0509e+00,  2.0005e-01,  1.8101e-01,
        -9.4473e-02,  2.7419e-01, -5.2776e-02, -3.4175e-01, -5.5249e-01,
         1.4809e-01,  2.5802e-01, -2.4948e-01,  4.1903e-02, -1.3853e-01,
        -5.2797e-01, -5.8953e-01,  5.6118e-02,  2.4615e-01,  2.6243e-02,
         3.6443e-01, -1.9864e-01, -1.6644e-02,  4.5290e-02, -1.9705e-02,
         2.0601e-02, -9.7351e-02, -5.3402e-01, -1.3818e-03, -1.1672e-01,
         3.3819e-02,  6.2135e-01, -1.3775e-01,  7.6434e-02, -3.2542e-02,
        -1.3127e-02,  2.7172e-01, -2.1879e-02,  7.2377e-02,  1.2040e-02,
        -7.2477e-02,  1.3043e+00, -1.4102e-04,  4.3104e-02, -3.0378e-04,
         1.9640e-01, -7.0731e-02,  1.9887e-02, -6.3523e-02, -5.7422e-02,
        -1.5848e-02,  6.7521e-01, -1.8274e-02,  5.5241e-02,  2.8987e-02,
        -4.7574e-02,  5.4263e-02, -1.1386e-01,  2.9

  0%|                                                                                        | 0/2 [00:00<?, ?it/s]

-------predict: tensor([ 7.1686e-01, -1.8556e-01, -7.4042e-02, -2.6286e-01, -3.3611e-02,
         3.2015e-02,  1.6292e-02,  5.8685e-01, -2.8625e-01, -8.4518e-02,
         9.7379e-02,  4.1439e-01, -1.0764e-01,  1.1925e-01, -5.3548e-02,
        -1.5148e-01, -4.3035e-02, -4.3801e-01, -3.4281e-01,  1.4461e-01,
        -1.9572e-01,  1.9948e-01, -1.1064e-01, -4.1847e-02,  9.4641e-02,
        -8.6880e-03, -6.8736e-02, -1.9910e-01,  2.3680e-02, -1.2288e-01,
         9.4937e-03,  3.1336e-03, -4.1438e-02,  8.3865e-02,  4.2074e-02,
         3.1254e-01,  3.5122e-01,  2.6977e-02, -1.8314e-01,  4.5739e-02,
        -6.3786e-02,  3.0414e-01, -2.8502e-01, -3.4251e-01, -1.1106e-01,
         3.0474e-02, -3.7111e-02, -1.1754e-02,  2.9514e-01, -4.9500e-02,
         3.5158e-02,  1.4211e-01, -2.0053e-01,  4.9702e-02, -6.2134e-02,
         3.2365e-01, -2.0687e-01, -1.7845e-01, -3.2954e-01,  1.6126e-01,
        -2.8777e-01,  1.3980e-01,  6.5975e-02, -2.1557e-01,  9.3005e-02,
        -7.7126e-02, -8.4803e-02,  

Epoch [32/50]: 100%|██████████████████████████████████████████████████████| 2/2 [00:00<00:00, 65.09it/s, loss=0.02]

-------predict: tensor([-0.1145], device='cuda:0', grad_fn=<SqueezeBackward1>), y: tensor([-0.1581], device='cuda:0')----------
tensor([0.0199])
Epoch [32/50]: Train loss: 1.1170, Valid loss: 0.0800





Saving model with loss 0.080...


Epoch [33/50]:   0%|                                                                         | 0/2 [00:00<?, ?it/s]

-------predict: tensor([-0.1547, -0.1815, -0.7944, -0.1401, -0.0340, -0.1053, -0.0229,  0.1463,
         0.0140, -0.1242, -0.0394, -0.2144,  0.3591, -0.1782,  0.1512,  0.0214,
         0.1248, -0.1002, -0.2627,  0.0764,  0.0159, -0.5478,  0.2975, -0.0184,
        -0.1527,  0.0929,  0.1366,  0.1323, -0.2928, -0.0225,  0.2174,  0.0419,
        -0.2432, -0.1510,  0.0955,  0.1032, -0.2056, -0.1498, -0.1909, -0.3968,
        -0.0416, -0.0672,  0.3872,  0.0591,  0.0172, -0.0648, -0.1079, -0.1730,
        -0.3616, -0.0648, -0.0399, -0.2138,  0.0924, -0.1311, -0.1863,  0.0141,
         0.0763,  0.0858, -0.0197,  0.6799, -0.1180,  0.0072, -0.0258, -0.1433,
         0.0428, -0.0267, -0.0572,  0.1373, -0.3700, -0.1964, -0.0374,  0.6305,
         0.1029, -0.3623, -0.3546, -0.0149, -0.0654, -0.1326, -0.0056, -0.0169,
        -0.0324,  0.1809,  0.5182, -0.2566,  0.1424, -0.0495,  0.0692, -0.3809,
        -0.0354, -0.0983,  2.2099,  0.0344, -0.1415,  0.3616,  0.1955,  0.3904,
         0.0195, -0.1772

Epoch [33/50]: 100%|████████████████████████████████████████████████████| 2/2 [00:00<00:00, 64.94it/s, loss=0.0198]

-------predict: tensor([-0.1194], device='cuda:0', grad_fn=<SqueezeBackward1>), y: tensor([-0.2281], device='cuda:0')----------
tensor([-0.0302])
Epoch [33/50]: Train loss: 1.1085, Valid loss: 0.0787





Saving model with loss 0.079...


  0%|                                                                                        | 0/2 [00:00<?, ?it/s]

-------predict: tensor([-6.5924e-02, -9.0205e-02,  3.6584e-02,  2.6213e-01, -2.0223e-01,
         1.4455e-01, -1.4892e-01, -2.8844e-01, -1.7441e-01, -2.3936e-01,
         5.2022e-02, -1.5614e-01, -4.0406e-02, -1.7850e-01,  1.7872e-01,
        -1.4801e-01, -9.0450e-02,  6.4444e-02, -9.5049e-02, -1.2978e-01,
        -5.7110e-02, -1.7638e-01, -1.4095e-01,  6.4265e-02,  2.2033e-01,
        -2.6778e-02,  9.7839e-03, -1.0697e-01, -4.9121e-02, -9.9176e-02,
        -4.4776e-03,  5.4202e-02, -7.4493e-02, -7.7575e-02, -2.8271e-01,
        -3.1070e-02,  1.0426e-02, -1.4187e-01, -2.3104e-01, -5.8558e-02,
         1.7795e-03,  4.3819e-02,  1.8348e-02,  6.4776e-02,  3.9384e-01,
        -9.1253e-03, -1.6457e-01,  3.4824e-02,  5.1901e-02,  5.7651e-03,
        -8.5938e-02, -1.6118e-01,  1.0630e-01, -1.2250e-01, -1.1846e-01,
        -5.1131e-02,  4.0534e-01,  2.9841e-01,  1.0731e-01, -2.0869e-01,
         2.5511e-03, -1.0602e-01, -1.2735e-02, -1.5956e-01, -2.5319e-01,
        -1.5277e-01, -6.8644e-02, -

Epoch [34/50]: 100%|██████████████████████████████████████████████████████| 2/2 [00:00<00:00, 64.62it/s, loss=2.74]


tensor([ 8.9280e-03,  2.7031e-01,  2.7172e-01,  1.5892e-01, -1.3507e-01,
         3.2908e-01, -3.0378e-04, -1.2122e-02, -2.2393e-02,  4.3104e-02,
        -3.7727e-02, -1.6644e-02, -3.5336e-02, -1.5848e-02,  6.6167e-01,
        -5.5249e-01,  1.4238e-01,  6.7521e-01,  4.5290e-02, -2.8738e-02,
        -7.2477e-02, -1.3853e-01,  5.6118e-02, -2.5499e-01,  2.7265e-01,
         2.0005e-01, -3.4574e-01,  1.9887e-02,  3.8809e-03, -1.1993e-01,
        -3.1619e-03,  5.5600e-01,  1.9640e-01,  5.1115e-02,  6.2135e-01,
        -9.4473e-02, -2.1879e-02, -1.3775e-01, -9.4370e-02,  4.7452e-02,
        -1.4102e-04,  1.1929e-02, -4.9693e-01,  5.8750e-03, -1.6254e-02,
         5.6202e-02, -7.0731e-02,  7.4072e-02,  2.7804e-01, -1.9864e-01,
        -9.2765e-02, -9.7351e-02,  3.3819e-02,  2.8582e-02,  4.2126e-02,
        -4.1969e-01,  2.2071e+00,  2.2287e-02,  3.0004e-01,  5.6189e-02,
        -1.4991e-01,  1.1890e-01, -1.2856e-02, -1.1672e-01, -2.1008e-02,
         2.8127e-02, -1.2142e-01, -5.7422e-02,  5.5

  0%|                                                                                        | 0/2 [00:00<?, ?it/s]

-------predict: tensor([-3.1910e-01,  8.1784e-02, -9.3614e-02, -3.5417e-02, -2.3557e-01,
        -9.7049e-02, -8.2532e-02, -5.4948e-02,  2.1693e-01, -1.4841e-01,
         4.3761e-01, -8.1293e-02, -3.7985e-02, -2.6404e-01, -3.0585e-01,
        -1.7424e-01, -2.2048e-01, -1.1419e-01, -2.5030e-01, -2.1309e-01,
         1.5604e-01, -1.9959e-01,  7.5057e-02, -1.5174e-01,  3.1154e-01,
        -2.0064e-02,  5.7616e-02,  6.8929e-03, -1.9055e-01, -1.4361e-01,
         1.3891e-01, -1.8919e-02, -1.5745e-01, -1.4076e-02,  3.6935e-01,
         2.2708e-02,  5.0596e-02,  9.0897e-02,  2.9105e-01, -1.5575e-01,
        -6.3904e-02, -1.0215e-02, -8.8801e-02, -4.1730e-01, -1.2967e-02,
        -2.0948e-01, -4.7765e-02, -5.8820e-02, -1.2221e-01, -1.6233e-01,
        -5.3367e-01, -8.6465e-02, -3.0783e-02, -2.6700e-01,  3.6578e-04,
        -1.4847e-01, -2.8488e-01, -7.3129e-03, -3.9981e-02, -1.1771e-01,
         4.0932e-01, -3.4229e-01,  1.5208e-01, -7.9462e-03,  3.6331e-01,
        -5.2881e-02, -2.0123e-01, -

Epoch [35/50]: 100%|████████████████████████████████████████████████████| 2/2 [00:00<00:00, 64.13it/s, loss=0.0385]



tensor([-3.7727e-02, -1.3507e-01, -1.2142e-01,  8.9560e-03,  7.4072e-02,
        -3.2542e-02,  2.2802e-01,  1.8101e-01, -3.7125e-01,  1.3043e+00,
        -1.2856e-02,  2.5802e-01, -2.3377e-01, -2.2636e-02, -1.8274e-02,
        -1.1993e-01, -4.1969e-06, -7.1206e-02, -1.9705e-02, -5.2776e-02,
         1.2040e-02,  5.5600e-01, -2.4948e-01,  1.4809e-01, -5.8953e-01,
         6.7521e-01, -4.1969e-01,  1.5892e-01,  3.8809e-03,  1.0088e-01,
         2.4615e-01, -4.9693e-01,  1.4238e-01, -7.2477e-02, -2.2303e-01,
        -7.5340e-02, -1.3127e-02, -6.3523e-02, -9.7351e-02, -3.0737e-02,
         5.6202e-02,  1.9887e-02,  5.1115e-02,  1.3274e-02,  4.7452e-02,
         2.7031e-01, -2.1879e-02,  3.3819e-02, -1.6254e-02,  4.2126e-02,
        -3.5336e-02,  2.2896e-01,  2.7419e-01, -5.5249e-01,  6.2135e-01,
        -3.0333e-02,  1.1882e-02, -3.1619e-03,  1.0193e-01, -1.6644e-02,
         3.6443e-01, -2.1008e-02,  1.9640e-01, -7.2309e-02, -1.1954e-02,
        -5.2797e-01,  5.8750e-03, -2.8738e-02, -3.

  0%|                                                                                        | 0/2 [00:00<?, ?it/s]

-------predict: tensor([-1.4329e-01, -2.6049e-01, -1.3716e-01, -1.4145e-01, -2.7629e-02,
         8.7455e-02,  1.8159e+00,  2.5859e-01,  2.0362e-01, -1.5674e-01,
        -6.4697e-02, -1.7130e-01,  8.4053e-01, -2.9005e-01,  1.4806e-01,
         2.3605e-01,  1.8935e-01, -1.9031e-01,  1.9214e-01,  1.8019e-02,
         1.6711e-01,  3.6337e-01,  9.8851e-02,  4.6619e-02, -1.2078e-01,
         2.1117e-01,  1.6273e-01,  5.8439e-01, -9.9193e-02,  2.1834e-01,
        -8.0269e-02, -1.9100e-01, -4.8522e-02, -1.0877e-02, -1.9253e-01,
         5.1282e-02, -1.6925e-01, -3.1885e-02, -1.0655e-01, -1.2309e-01,
        -4.6022e-02,  3.5401e-01,  5.8547e-01, -2.4821e-01, -1.6067e-02,
        -6.0097e-01, -2.4443e-02,  3.6607e-01,  4.5245e-04, -4.6182e-01,
         3.1639e-01,  1.7186e-02,  1.4169e-01, -7.5324e-02, -5.3796e-02,
        -8.8837e-02,  7.2053e-02, -4.3508e-02, -1.2254e-02, -1.9374e-02,
        -7.2875e-02, -2.6952e-02, -2.3321e-01,  3.3859e-01, -3.2965e-02,
        -9.2275e-02, -1.0351e-01,  

Epoch [36/50]: 100%|█████████████████████████████████████████████████████| 2/2 [00:00<00:00, 66.45it/s, loss=0.168]


tensor([ 2.8127e-02, -2.4948e-01, -4.1669e-02,  1.3904e-02,  1.4809e-01,
         4.2126e-02, -6.0509e+00, -1.5848e-02, -3.4574e-01, -5.6639e-01,
        -9.2765e-02,  5.6189e-02,  9.3753e+00,  3.3819e-02,  4.7452e-02,
        -3.5336e-02,  2.4295e-02,  6.0025e-02, -1.3127e-02, -9.4473e-02,
        -4.1969e-01,  1.1890e-01, -5.3402e-01,  5.1115e-02, -1.3775e-01,
        -1.1954e-02,  2.2802e-01,  6.2135e-01, -1.4102e-04,  1.3043e+00,
         2.7804e-01, -7.5340e-02, -4.7574e-02,  2.0005e-01,  1.9640e-01,
        -1.9864e-01,  6.7521e-01, -2.3377e-01, -5.2776e-02,  8.9280e-03,
        -1.2142e-01, -2.2303e-01,  3.6443e-01,  1.3274e-02,  1.8101e-01,
         5.6202e-02, -7.1206e-02, -7.2477e-02,  6.6167e-01,  3.4854e-02,
         1.4527e+00, -1.9864e-01,  1.9887e-02,  5.5103e-03,  2.8582e-02,
        -1.8274e-02,  1.0193e-01,  5.6118e-02,  2.7031e-01, -5.7422e-02,
         8.9560e-03,  2.6465e-01,  4.1903e-02, -9.7351e-02,  1.4430e-02,
        -3.2542e-02,  2.8987e-02, -4.5970e-02, -9.4

Epoch [37/50]:   0%|                                                              | 0/2 [00:00<?, ?it/s, loss=1.64]

-------predict: tensor([ 2.1306e-01,  3.5462e-01, -9.8520e-02, -3.7318e-02,  4.7949e-02,
        -1.9296e-01,  3.5590e-01, -1.9123e-01, -1.2962e-01,  2.4784e-01,
        -1.4300e-01,  1.0324e-01, -7.6286e-02,  1.3945e-01, -4.2961e-02,
         2.3868e-02, -1.3362e-01, -6.3869e-02,  3.8481e-01,  5.3179e-01,
        -1.1695e-01, -3.5714e-02,  1.8649e-01, -4.2072e-02,  5.3155e-02,
         1.0499e-01, -9.1010e-02,  2.6614e-01,  1.2272e-01,  6.4506e-02,
        -4.9257e-02,  6.3249e-02,  3.9590e-02, -1.9887e-01,  1.7566e-01,
        -1.1211e-01,  2.2751e-01,  9.8662e-02, -1.1279e-01, -7.3250e-03,
         5.4808e-01, -4.0516e-02,  3.3169e-01, -6.8219e-02, -3.1303e-02,
         1.9503e-01, -5.6163e-02,  8.9489e-02,  7.9493e-02,  6.9087e-02,
        -5.2583e-02, -4.8828e-02, -8.2978e-02,  7.5793e-02,  6.0522e-04,
         1.5563e-01,  4.3662e-01, -9.2271e-02,  1.9452e-01,  3.9908e-02,
        -1.5274e-01, -2.0406e-01,  8.1931e-02,  1.6562e-01,  6.4885e-02,
        -3.1918e-01, -2.3269e-02, -

Epoch [37/50]: 100%|████████████████████████████████████████████████████| 2/2 [00:00<00:00, 64.33it/s, loss=0.0463]

-------predict: tensor([-0.1567], device='cuda:0', grad_fn=<SqueezeBackward1>), y: tensor([-0.2395], device='cuda:0')----------
tensor([0.0419])
Epoch [37/50]: Train loss: 1.1244, Valid loss: 0.0746





Saving model with loss 0.075...


Epoch [38/50]:   0%|                                                              | 0/2 [00:00<?, ?it/s, loss=1.93]

-------predict: tensor([ 0.1034, -0.0614,  0.7898,  0.0415,  0.0129,  0.1203,  0.0558, -0.1046,
         0.1016, -0.0724,  0.4933,  0.0424,  0.2559, -0.1306,  0.1603, -0.1105,
         0.4145,  0.0634, -0.1416,  0.3180,  0.1329,  0.3588, -0.1222, -0.3117,
         0.1174,  0.4570, -0.0045, -0.0077,  0.3366,  0.1798, -0.2833, -0.0630,
        -0.1224,  0.1767, -0.1472, -0.3028,  0.0796,  0.1055, -0.0995,  0.3146,
         0.0209, -0.0995,  0.0589, -0.1518,  0.0386,  0.0549,  0.0804, -0.0438,
        -0.0900,  0.0328, -0.1386,  0.0632, -0.0570,  0.5203, -0.0911, -0.4461,
         0.1926, -0.1225,  0.1662, -0.2237, -0.2828, -0.0272,  0.0805,  0.1894,
        -0.0100, -0.0877,  0.1045, -0.0692,  0.1872, -0.1293,  0.1315,  0.0932,
        -0.0628, -0.0939, -0.0862, -0.0356,  0.2021, -0.3635, -0.0723, -0.0860,
         0.0374,  0.1851, -0.2457, -0.0641,  0.7196,  0.1585, -0.1435,  0.1583,
        -0.0748, -0.0755, -0.0834,  0.0353,  0.1897,  0.0201,  0.0417,  0.1146,
        -0.1444,  0.2909

Epoch [38/50]: 100%|█████████████████████████████████████████████████████| 2/2 [00:00<00:00, 65.46it/s, loss=0.161]

-------predict: tensor([-0.1076], device='cuda:0', grad_fn=<SqueezeBackward1>), y: tensor([-0.2441], device='cuda:0')----------
tensor([0.2703])
Epoch [38/50]: Train loss: 1.1223, Valid loss: 0.0738





Saving model with loss 0.074...


  0%|                                                                                        | 0/2 [00:00<?, ?it/s]

-------predict: tensor([-3.1073e-01, -9.2788e-01, -5.8578e-02, -2.2405e-01,  1.0878e-02,
        -8.4182e-02, -2.0766e-01, -3.6329e-01, -1.0508e-01, -1.1344e-01,
        -1.9681e-01, -1.2966e-01,  3.9087e-02, -2.7233e-01, -7.5158e-02,
         3.5965e-02,  6.3248e-01,  2.3658e-01,  1.6208e-04,  1.7279e-01,
        -8.5777e-02,  2.3796e-01, -7.7830e-04, -4.3022e-02, -2.4832e-02,
         3.0454e-01,  2.2918e-02,  2.9364e-01,  2.6257e-01, -1.6508e-01,
         1.5440e-01,  1.6642e-02, -2.4824e-01,  2.8701e-01,  1.8150e-01,
        -9.5343e-02, -1.5608e-01,  3.2144e-02,  2.5048e-01,  6.2832e-03,
        -1.3822e-01, -6.5009e-02,  2.0298e-01, -1.0968e-01,  1.6928e-01,
        -1.4638e-01,  5.0047e-02, -1.4225e-01, -1.6364e-01,  2.8743e-01,
        -1.6717e-01,  1.0782e-01, -1.5986e-01,  1.1638e-01,  3.4926e-01,
        -5.0183e-02,  5.8939e-01, -5.3339e-02, -4.3636e-02, -1.5260e-01,
        -6.2151e-02,  7.5933e-02, -7.9021e-03, -3.3098e-01,  4.7585e-02,
         4.3341e-02, -4.3986e-02,  

Epoch [39/50]: 100%|█████████████████████████████████████████████████████| 2/2 [00:00<00:00, 61.91it/s, loss=0.447]


tensor([-5.2776e-02, -5.6639e-01, -7.2309e-02, -2.5499e-01,  3.3819e-02,
        -1.8274e-02, -2.4948e-01,  5.6189e-02, -5.3402e-01,  5.1115e-02,
         5.5600e-01, -4.7574e-02,  8.9560e-03, -5.5249e-01, -3.0333e-02,
         2.5802e-01, -6.0509e+00, -2.1008e-02,  3.6443e-01,  4.7452e-02,
        -7.0731e-02,  2.7158e-02,  5.5103e-03, -2.1879e-02, -4.1969e-01,
        -1.3127e-02,  5.6202e-02, -1.6254e-02,  3.2908e-01, -5.2797e-01,
         2.7172e-01, -1.9864e-01,  1.3274e-02,  1.5892e-01,  1.1890e-01,
        -1.1672e-01, -3.0216e-02, -1.3507e-01,  2.2802e-01, -9.4473e-02,
         4.5290e-02, -2.2393e-02,  2.4800e-01, -1.3853e-01, -3.0737e-02,
        -7.5340e-02, -5.7422e-02, -4.5528e-03, -4.1969e-06,  1.2040e-02,
         4.1903e-02, -4.9693e-01,  1.4809e-01, -9.7351e-02, -5.8953e-01,
        -9.4370e-02,  1.4527e+00,  1.0088e-01,  5.5241e-02,  4.2126e-02,
        -1.9705e-02, -9.2765e-02, -3.2542e-02, -2.2636e-02, -3.7125e-01,
         1.0193e-01,  1.8101e-01, -4.5970e-02,  5.8

Epoch [40/50]: 100%|█████████████████████████████████████████████████████| 2/2 [00:00<00:00, 63.04it/s, loss=0.145]


-------predict: tensor([ 1.7462e-01,  1.5940e-02, -3.8523e-02, -2.8575e-01, -1.8603e-01,
        -7.6574e-02,  1.4876e-01,  6.1847e-02, -2.4657e-01, -1.2231e-01,
        -3.9661e-01, -1.4331e-02, -1.2284e-02,  2.1523e-01, -2.2339e-01,
         9.4789e-02, -2.1598e-01,  5.6291e-01,  3.7131e-01, -1.1586e-01,
         2.4530e-02, -5.1758e-02,  2.1192e-02,  5.3972e-02,  2.7852e-01,
         1.8481e-01, -3.1797e-02, -5.9717e-02, -7.8633e-02,  7.1858e-02,
        -2.5858e-01, -1.1755e-01,  4.7541e-01, -1.6620e-01,  1.3177e-01,
        -9.8763e-02,  8.8603e-02,  1.2984e-01, -4.5595e-02, -1.8478e-01,
        -2.3087e-01, -9.1080e-02,  1.7602e-01, -2.0437e-01,  4.2649e-01,
        -2.7403e-01, -4.1331e-02,  3.0454e-01, -1.1017e-01,  2.4449e-01,
         1.4940e-01,  2.4875e-02,  2.5933e+00,  2.4013e-01, -4.1299e-02,
        -5.4855e-02,  1.3407e-01, -1.5525e-01, -2.2122e-02, -2.3122e-01,
        -8.2433e-02, -1.2835e-01,  3.1897e-01, -2.1973e-01,  1.0895e-01,
         1.3123e-01, -1.6960e-01,  

Epoch [41/50]: 100%|█████████████████████████████████████████████████████| 2/2 [00:00<00:00, 65.77it/s, loss=0.246]


-------predict: tensor([ 1.0798e-01,  9.5954e-02, -2.3471e-01,  2.0332e-01, -3.7661e-01,
        -1.3421e-01,  5.3754e-01,  1.6221e-01, -7.2293e-02,  4.2057e-02,
         5.3554e-02,  2.2788e-02,  1.6107e-02,  4.6573e-01, -6.7081e-02,
        -4.7038e-02, -1.0063e-01, -9.9560e-02,  1.1802e-01, -1.1066e-01,
         1.0782e-01,  1.2988e+00, -2.1478e-01, -2.1675e-01, -2.2495e-01,
        -1.0778e-01,  8.3103e-02, -2.9837e-02, -7.3527e-02,  3.6145e-01,
        -1.1842e-01, -1.1156e-01,  2.4573e-02,  2.3975e-01,  3.9980e-02,
        -2.6312e-01,  4.6458e-01, -1.2253e-01, -1.6106e-01,  1.3158e-01,
        -1.4236e-01,  2.4841e-01,  8.6540e-02,  3.3805e-01,  1.9665e-02,
        -1.0158e-03, -2.2090e-01, -1.5842e-01,  2.5634e-01, -1.3733e-01,
        -2.3324e-01,  5.7728e-02, -9.0031e-02,  2.6473e+00,  4.3612e-01,
         4.2283e-03,  7.0170e-02, -2.7218e-01,  3.3521e-02, -6.8081e-02,
        -6.9720e-03, -3.0705e-02, -3.7397e-03,  7.2676e-02,  1.1882e-01,
        -3.9166e-01,  3.1780e-02,  

  0%|                                                                                        | 0/2 [00:00<?, ?it/s]

-------predict: tensor([-0.0415,  0.4196,  0.4365, -0.0183, -0.1059, -0.0804, -0.1191, -0.0203,
         0.1476, -0.0392,  0.2407,  0.5370, -0.2520, -0.1238,  0.0449,  0.1495,
         0.0762, -0.0378, -0.0441,  0.6541,  0.1336,  0.4850, -0.3758, -0.0021,
         0.4031,  0.0050, -0.1883, -0.1061,  0.0129, -0.0528,  0.0707, -0.0358,
        -0.0408, -0.0238, -0.0869,  0.4943,  0.0783, -0.2550, -0.4548,  0.2284,
         0.0620,  0.0811, -0.0438,  0.0480, -0.1389, -0.1667,  0.1741,  0.3466,
        -0.0681,  0.4833,  0.0672,  0.2311,  0.4575,  0.1341,  0.0409,  0.2589,
        -0.0268,  0.0055,  0.3862,  0.1117, -0.2106, -0.6025, -0.0886, -0.2388,
         0.5845, -0.2372,  0.0229,  0.1185,  0.1465,  0.1030, -0.1714, -0.0694,
        -0.1397,  1.1777, -0.0728,  0.1728,  0.0305, -0.0386,  0.0574,  0.1510,
         1.0172,  0.2348,  0.8085,  0.2359, -0.0992, -0.1729,  0.7892,  0.0952,
         0.1842,  0.6219, -0.0881,  0.2826, -0.0800,  0.3131,  0.7012,  0.2075,
        -0.0774, -0.1168

Epoch [42/50]:   0%|                                                              | 0/2 [00:00<?, ?it/s, loss=1.79]

tensor([ 2.8127e-02,  1.5892e-01, -1.1954e-02, -1.2122e-02,  6.2135e-01,
         2.8987e-02,  1.1882e-02,  2.0601e-02, -3.5336e-02,  2.7265e-01,
        -6.3523e-02,  1.3043e+00,  1.9887e-02,  2.7158e-02,  1.4809e-01,
         7.6434e-02,  6.7521e-01,  1.9810e-01, -3.0333e-02, -6.0509e+00,
         2.2287e-02, -4.5970e-02, -1.9705e-02, -9.4473e-02, -2.8738e-02,
         3.2908e-01, -1.6644e-02, -5.3402e-01,  1.8101e-01,  5.8750e-03,
         2.8582e-02, -2.1879e-02, -4.7574e-02, -2.5499e-01, -1.3853e-01,
         8.9280e-03,  2.7419e-01, -9.4370e-02,  5.5600e-01, -1.2856e-02,
        -3.4175e-01,  4.7452e-02,  2.7031e-01, -1.3775e-01, -7.0731e-02,
         4.3104e-02, -2.4948e-01, -1.3127e-02, -1.3818e-03, -1.5848e-02,
         3.3819e-02,  6.6167e-01, -9.2765e-02, -7.2309e-02, -2.2393e-02,
        -2.3377e-01, -1.9864e-01,  2.9132e-01,  1.4527e+00,  5.6189e-02,
         5.5241e-02, -5.6639e-01, -5.5249e-01,  7.2377e-02, -3.0737e-02,
        -3.2542e-02, -1.3507e-01, -1.2142e-01, -7.2

Epoch [42/50]: 100%|█████████████████████████████████████████████████████| 2/2 [00:00<00:00, 66.25it/s, loss=0.289]


Epoch [42/50]: Train loss: 1.1107, Valid loss: 0.0716
Saving model with loss 0.072...


  0%|                                                                                        | 0/2 [00:00<?, ?it/s]

-------predict: tensor([-3.4041e-02,  1.4210e-01,  1.5414e-01, -3.6198e-01,  4.6007e-02,
         8.2782e-01,  1.4782e-01,  1.4729e-01,  6.6925e-01, -7.0815e-02,
        -3.4482e-01, -1.0774e-01,  5.3127e-01, -1.2417e-01,  1.6175e-02,
         3.4629e-02, -6.2669e-02,  2.9218e-02,  2.4915e-01,  5.5861e-01,
        -7.2531e-02, -4.4060e-02, -1.2396e-01,  3.6182e-01, -1.4702e-01,
        -1.6240e-01, -1.4968e-02, -7.9598e-02,  2.5725e-01,  2.7626e-01,
         3.0339e-01,  3.2032e-01, -2.4720e-02, -8.0767e-02,  4.1495e+00,
        -1.8515e-01,  4.6163e-02, -2.0942e-01, -3.3771e-02, -2.3134e-01,
        -5.3460e-02,  5.6116e-02, -2.8877e-01,  3.1719e-01, -3.9897e-03,
        -2.7905e-01,  4.1116e-01, -8.4468e-02, -1.2433e-01,  3.0616e-01,
        -1.4877e-01, -2.1587e-02,  3.1080e-01,  4.1624e-01, -2.0847e-02,
        -2.8837e-01, -6.6971e-01, -8.8375e-02, -2.5871e-01, -2.8603e-01,
        -9.1503e-02,  1.3129e-02, -3.1518e-01,  4.7385e-02, -8.5519e-02,
         2.6679e-01, -1.7018e-01,  

Epoch [43/50]:   0%|                                                              | 0/2 [00:00<?, ?it/s, loss=1.42]

tensor([ 5.1115e-02, -1.1386e-01,  2.7265e-01, -5.5249e-01,  8.9560e-03,
        -5.8953e-01, -3.4175e-01,  2.0005e-01, -1.6254e-02, -3.7125e-01,
         5.6202e-02,  2.8527e-01,  1.1890e-01,  7.2377e-02,  1.9810e-01,
         1.3904e-02,  5.6118e-02,  2.7172e-01,  2.4800e-01,  1.4527e+00,
        -6.3523e-02,  2.2896e-01, -5.3402e-01, -5.4034e-01, -1.5848e-02,
         6.0025e-02,  4.3104e-02,  1.9640e-01, -2.5499e-01,  2.6465e-01,
        -2.4948e-01, -4.5970e-02,  2.0601e-02, -3.0378e-04,  9.3753e+00,
        -5.2776e-02, -4.1969e-01,  2.7031e-01,  1.4430e-02, -1.9705e-02,
        -1.2122e-02,  5.8750e-03,  3.3819e-02,  2.7158e-02, -9.4370e-02,
         5.5241e-02, -2.8738e-02, -1.6644e-02,  2.9925e-02, -3.1619e-03,
         1.0193e-01, -4.9693e-01, -1.4991e-01, -2.2303e-01,  7.6434e-02,
        -2.3377e-01, -5.6639e-01, -3.4574e-01, -7.1206e-02, -1.3818e-03,
         5.6189e-02, -1.1954e-02, -4.1969e-06, -4.7574e-02, -9.7351e-02,
         3.2908e-01,  2.8127e-02,  2.9952e-01,  1.2

Epoch [43/50]: 100%|██████████████████████████████████████████████████████| 2/2 [00:00<00:00, 64.22it/s, loss=1.36]


Epoch [43/50]: Train loss: 1.1171, Valid loss: 0.0710
Saving model with loss 0.071...


  0%|                                                                                        | 0/2 [00:00<?, ?it/s]

-------predict: tensor([-1.2544e-01, -5.3194e-02, -1.0275e-01,  2.2353e-01,  1.2089e-01,
         8.5102e-02, -2.8318e-01,  4.4472e-01, -1.1120e-01,  3.7559e-01,
        -1.5588e-01, -1.0553e-01, -2.2403e-01, -8.2933e-02,  1.5498e-01,
         1.8982e-02, -1.0555e-01,  1.3223e-01,  1.0606e-03,  2.2006e-02,
         2.2983e-01,  1.2064e-01, -5.9590e-02,  1.2647e-01,  2.0492e-01,
        -1.1964e-01, -5.9133e-02,  2.9283e-01, -2.9613e-02, -2.5368e-01,
         1.4951e-01,  2.5520e-01,  2.8628e-01,  5.1711e-01,  1.1730e-01,
         3.1826e-01,  9.1583e-01, -7.2773e-02, -1.3978e-03,  2.1402e-01,
        -1.2603e-01,  1.1996e-01,  9.1117e-01,  9.1634e-02, -2.4241e-01,
         6.3261e-02,  4.8106e-02,  5.3454e-02, -3.2212e-01,  3.5509e-01,
        -8.3930e-02, -2.2996e-01,  1.1534e+00,  2.2485e-01, -5.1945e-02,
         1.4866e-01, -1.2687e-01,  1.6428e-01, -1.8627e-01,  2.1968e-01,
         1.5279e-01, -3.1942e-01,  9.5564e-02, -3.4377e-01, -1.0603e-01,
        -2.9186e-01,  2.3972e-01, -

Epoch [44/50]: 100%|████████████████████████████████████████████████████| 2/2 [00:00<00:00, 67.01it/s, loss=0.0426]


tensor([ 5.5241e-02, -2.8738e-02, -2.2636e-02,  1.9810e-01, -1.3775e-01,
        -2.4948e-01,  4.5290e-02, -4.5970e-02, -3.0333e-02,  3.3819e-02,
         4.7452e-02,  2.8127e-02, -1.4991e-01, -3.2542e-02, -3.4574e-01,
        -1.3127e-02,  4.3104e-02,  2.9132e-01, -7.0731e-02,  1.4238e-01,
         2.4295e-02, -1.2856e-02, -9.2765e-02,  7.6434e-02,  2.5802e-01,
         1.0088e-01,  8.9560e-03,  2.7158e-02, -1.6353e-01, -4.5528e-03,
         2.7172e-01, -7.5340e-02, -1.1954e-02, -3.7727e-02, -2.5499e-01,
        -4.3638e-01,  2.2071e+00, -3.0216e-02,  3.2908e-01,  1.3274e-02,
        -3.0378e-04,  1.8101e-01,  1.4527e+00,  1.1929e-02, -2.1008e-02,
        -5.2797e-01,  1.4430e-02, -3.0737e-02,  3.4854e-02,  2.6465e-01,
         4.2126e-02,  4.1903e-02,  9.3753e+00,  3.0004e-01,  5.5103e-03,
        -4.7574e-02, -4.1969e-06, -7.2477e-02,  5.5600e-01,  5.8750e-03,
        -5.4034e-01, -3.1619e-03, -1.4102e-04, -5.5249e-01,  5.1115e-02,
         5.6189e-02,  2.9952e-01,  6.7521e-01, -5.8

Epoch [45/50]: 100%|████████████████████████████████████████████████████| 2/2 [00:00<00:00, 70.54it/s, loss=0.0472]


-------predict: tensor([-2.2230e-01,  1.2046e-01, -1.5615e-01, -3.4921e-01, -1.3822e-01,
         8.4638e-01, -1.7061e-01,  5.1749e-01,  2.9543e-03, -2.3458e-01,
        -1.2225e-01, -4.7590e-02,  6.0122e-02,  5.3054e-02, -1.6666e-02,
        -3.9557e-02, -5.1455e-02,  1.0479e-01,  1.7127e-01, -1.5333e-01,
        -1.0634e-01, -2.0003e-01, -8.8140e-03, -8.2424e-02,  4.6940e-01,
         2.2492e-01,  3.7880e-01, -1.3982e-01,  1.2174e-01,  9.2482e-02,
        -1.1529e-01, -1.8066e-01, -1.0224e-01, -3.6352e-02, -1.7584e-01,
         5.9814e-02, -3.8604e-02, -2.2245e-01, -6.5195e-02,  6.2644e-02,
         3.8215e-03,  8.8867e-02,  3.0828e-01, -1.8081e-01, -1.8942e-01,
         1.5682e-01, -7.5690e-02, -1.0761e-02, -1.1225e-01, -4.1283e-01,
        -1.0549e-01, -1.8834e-01, -7.4273e-02, -2.5987e-01,  2.1419e-01,
         4.7900e-02, -2.8733e-01,  1.9320e-01, -8.7739e-02, -1.5451e-01,
         3.9936e-01,  4.2101e+00, -9.5290e-02, -1.5110e-01,  4.1341e-01,
         1.1667e-01,  2.8732e-01, -

Epoch [46/50]: 100%|█████████████████████████████████████████████████████| 2/2 [00:00<00:00, 65.46it/s, loss=0.458]


-------predict: tensor([-0.0552,  0.2476, -0.3582,  0.0258, -0.0168, -0.0814,  0.0061, -0.1946,
        -0.0508, -0.0988, -0.1330, -0.0465,  0.5335,  0.0818, -0.1699,  0.1052,
        -0.0105, -0.1831,  0.2817, -0.1352,  0.2168, -0.1373, -0.0468,  0.0230,
        -0.0701, -0.0939, -0.0791, -0.0566,  0.0611, -0.2091, -0.1244, -0.0553,
        -0.0489, -0.1115,  0.1933, -0.1078,  0.5423, -0.0287, -0.0260, -0.0428,
        -0.1184, -0.0974, -0.1491, -0.0645, -0.0540,  0.2095,  0.0549,  0.1850,
         0.0092, -0.1256, -0.1386,  0.1155, -0.1051, -0.1103,  0.0415, -0.0208,
        -0.2210, -0.0799, -0.2944,  0.1296,  0.0594, -0.1744,  0.0233, -0.0968,
         0.3391,  0.0978,  0.0831,  0.0154, -0.1101,  0.1947,  0.0107,  0.1083,
         2.6538, -0.1695,  0.0298, -0.2144, -0.1015,  0.1212, -0.0306,  0.0818,
         0.2774,  0.2434, -0.1454, -0.0296, -0.3996,  0.0551,  0.1796, -0.1795,
        -0.0641, -0.0086,  0.4335, -0.1262, -0.0640, -0.1100, -0.0821,  0.3756,
         0.4341,  1.4493

  0%|                                                                                        | 0/2 [00:00<?, ?it/s]

-------predict: tensor([ 7.3959e-03,  5.8134e-01, -1.2304e-01, -2.2289e-01,  1.4106e-01,
         7.0196e-02, -4.4009e-02, -2.7160e-02,  4.1882e-01, -1.9680e-02,
         3.8464e-01, -1.9240e-02, -1.6009e-01, -1.1154e-01, -3.5466e-01,
        -2.3582e-01, -1.2891e-01,  3.4901e-01, -1.4204e-01,  7.1256e-02,
        -2.6276e-01, -1.0413e-01,  3.0998e-01,  1.7231e-02,  8.7206e-02,
         2.8792e-02, -1.1054e-01, -9.7003e-02,  1.1683e-01,  2.7226e-02,
         1.5457e-01, -4.3336e-02, -3.7395e-01,  2.8043e-01,  5.1794e-02,
         3.5432e-02,  2.0092e-03,  2.7522e-01,  1.1216e-01, -2.7809e-01,
        -2.8708e-01, -5.2156e-02, -2.5261e-02,  2.6014e+00, -1.2577e-01,
         2.1884e-01,  5.7723e-01,  8.2313e-01,  1.3713e-01, -1.6137e-01,
         2.5326e-01, -2.1668e-02, -5.5824e-02, -9.6837e-02,  5.8954e-02,
         5.8442e-02, -2.3864e-01,  2.4738e-01,  7.1265e-02, -1.1494e-01,
         1.4037e-01, -7.7737e-02, -3.1469e-02, -6.9909e-02,  1.5262e-01,
        -9.4905e-02,  4.2566e-02,  

Epoch [47/50]:   0%|                                                                         | 0/2 [00:00<?, ?it/s]

tensor([ 2.7158e-02,  2.2071e+00, -9.4370e-02,  5.1115e-02,  5.5241e-02,
         2.0005e-01,  3.8809e-03, -1.9705e-02, -1.5848e-02, -9.4473e-02,
         2.2287e-02,  2.7419e-01,  2.5802e-01,  4.1903e-02,  1.1890e-01,
         2.7031e-01,  5.6189e-02,  5.5600e-01,  1.0193e-01,  2.9132e-01,
         4.2126e-02,  4.3104e-02, -1.2856e-02, -1.3507e-01,  2.6465e-01,
        -2.2303e-01,  1.4809e-01,  2.8987e-02, -7.5340e-02,  2.7172e-01,
        -5.4034e-01,  2.7265e-01,  5.6202e-02, -2.3377e-01, -2.5499e-01,
         4.5290e-02, -1.3818e-03, -3.4175e-01, -1.9864e-01,  6.0025e-02,
        -7.2309e-02, -4.5528e-03,  1.4527e+00,  9.3753e+00, -4.3638e-01,
        -1.1386e-01,  3.4854e-02, -7.2477e-02, -9.7351e-02, -4.1969e-06,
         3.3819e-02, -2.4948e-01,  1.9810e-01, -1.1954e-02,  6.6167e-01,
        -1.4102e-04, -7.0731e-02,  3.0004e-01,  1.3043e+00,  8.9280e-03,
        -3.1619e-03, -3.2542e-02, -1.3775e-01, -1.4991e-01, -6.0509e+00,
        -1.3853e-01,  5.4263e-02, -1.3127e-02,  1.1

Epoch [47/50]: 100%|██████████████████████████████████████████████████████| 2/2 [00:00<00:00, 73.98it/s, loss=1.28]


-------predict: tensor([0.4630], device='cuda:0', grad_fn=<SqueezeBackward1>), y: tensor([-0.2450], device='cuda:0')----------
tensor([-0.4197])
Epoch [47/50]: Train loss: 1.1099, Valid loss: 0.0683
Saving model with loss 0.068...


  0%|                                                                                        | 0/2 [00:00<?, ?it/s]

-------predict: tensor([ 9.1939e-04, -2.7340e-01, -1.1878e-01, -1.8383e-01, -2.3952e-01,
         2.0436e-01, -2.7746e-01, -3.8188e-01,  3.6136e-01, -1.5725e-01,
        -1.4685e-01, -1.1899e-01, -1.0161e-01,  2.9386e-01,  5.5227e-02,
         1.9322e-01, -7.1284e-02, -3.5630e-02, -6.9264e-02,  5.5482e-01,
        -1.1885e-01, -9.8645e-02, -1.4631e-01, -1.3477e-01,  2.2636e-01,
         1.6936e-01, -8.0682e-03,  1.2778e-01, -9.8222e-02, -6.8119e-02,
         1.1582e-01, -2.8715e-01, -4.3673e-01,  1.4629e-01, -1.0185e-01,
         1.6864e-01, -1.0270e-01, -2.7710e-01, -1.8517e-02, -1.4521e-01,
        -7.4586e-03, -9.9943e-02,  4.3156e-01, -1.0484e-01,  5.0339e-02,
         1.5525e-01, -2.6399e-01, -2.0029e-01,  8.3422e-02,  8.6141e-01,
         2.8328e-01,  1.4504e-01, -8.0163e-02,  2.7214e-01, -1.5482e-01,
        -6.6267e-02, -1.4586e-02, -2.1370e-01, -2.3473e-01,  1.1769e-01,
        -1.6782e-01, -1.0316e-01, -1.2007e-01,  7.3641e-01,  3.2489e-02,
        -2.1923e-01, -2.3975e-01,  

Epoch [48/50]:   0%|                                                              | 0/2 [00:00<?, ?it/s, loss=1.59]

tensor([-3.7125e-01,  6.7521e-01, -1.3127e-02, -3.0216e-02, -1.9705e-02,
        -3.0737e-02,  2.5802e-01, -1.6644e-02,  1.4527e+00,  5.6118e-02,
        -1.1672e-01,  4.2126e-02,  2.8582e-02,  2.4295e-02,  5.5600e-01,
         2.4800e-01, -5.8953e-01, -1.9864e-01, -2.5499e-01,  1.3043e+00,
         1.2040e-02,  8.9280e-03, -2.1879e-02,  2.8527e-01,  2.6465e-01,
        -3.4175e-01, -2.1008e-02, -7.1206e-02,  2.7419e-01, -1.4991e-01,
        -3.5336e-02,  1.9887e-02, -4.1669e-02, -9.7351e-02,  7.4072e-02,
        -3.4574e-01, -1.2142e-01, -7.0731e-02, -4.5528e-03, -3.2542e-02,
        -4.9693e-01, -7.5340e-02,  2.2896e-01, -2.2393e-02,  3.2908e-01,
         2.2287e-02,  7.2377e-02, -1.8274e-02, -2.8738e-02, -7.2477e-02,
         2.6243e-02,  1.4809e-01,  5.4263e-02, -4.7574e-02,  6.6167e-01,
         1.0193e-01,  1.9810e-01,  1.4430e-02,  2.4615e-01, -1.3818e-03,
         3.8809e-03, -1.5848e-02,  4.1903e-02,  3.6443e-01, -3.7727e-02,
         1.1890e-01, -3.0333e-02, -1.9864e-01,  8.9

Epoch [48/50]: 100%|████████████████████████████████████████████████████| 2/2 [00:00<00:00, 76.22it/s, loss=0.0705]


tensor([0.0349])
Epoch [48/50]: Train loss: 1.1040, Valid loss: 0.0676
Saving model with loss 0.068...


  0%|                                                                                        | 0/2 [00:00<?, ?it/s]

-------predict: tensor([ 1.4562e-01, -2.1984e-01,  2.1724e-01,  1.1208e-02, -4.4218e-01,
         7.3631e-02, -1.9120e-01,  2.3358e-01, -1.9427e-01,  1.1846e+00,
        -1.3769e-01, -3.7397e-01, -4.9170e-02,  4.3995e-02, -2.2676e-01,
         1.0065e-01,  8.7222e-01,  2.9061e-02, -1.4560e-01,  1.0330e+00,
        -6.1205e-02, -1.9541e-01, -1.2824e-01, -1.3157e-01, -3.3697e-02,
         6.2094e-02, -8.0645e-02, -1.3313e-01, -4.5155e-02,  5.2350e-01,
        -2.3641e-01, -5.8403e-02, -9.3144e-02, -9.5209e-02, -1.1426e-02,
         2.0415e-04,  6.5604e-02, -6.4418e-02, -4.3810e-02, -2.9028e-01,
        -6.7222e-02, -1.0823e-01, -1.6234e-01,  3.1301e-01,  4.4021e-02,
        -1.8668e-01,  4.0891e-03, -1.9201e-01, -6.8787e-02,  2.0142e-01,
        -3.1223e-01, -4.1597e-02, -3.5738e-01, -9.0438e-02, -1.6115e-01,
         8.8088e-02,  5.0328e-02,  9.8567e-02,  1.4471e-01, -1.0084e-01,
         2.6072e-01, -9.8228e-03, -5.9773e-02,  7.2564e-02, -2.3342e-01,
        -1.2282e-01, -1.0859e-01, -

Epoch [49/50]:   0%|                                                               | 0/2 [00:00<?, ?it/s, loss=1.5]

tensor([ 1.9810e-01,  1.4238e-01,  1.3274e-02,  2.2287e-02,  3.6443e-01,
        -1.2122e-02, -3.0333e-02,  1.3043e+00,  7.2377e-02, -5.8953e-01,
         1.4430e-02,  1.2040e-02,  1.0088e-01, -4.9693e-01, -2.2636e-02,
         2.2896e-01,  3.4854e-02, -1.1672e-01,  5.4263e-02,  1.4527e+00,
         4.5290e-02, -4.5528e-03, -1.6644e-02, -9.7351e-02, -1.9864e-01,
         2.4615e-01,  2.8987e-02,  2.7804e-01, -9.2765e-02, -5.4034e-01,
         5.5103e-03,  1.9640e-01,  2.8127e-02, -1.3818e-03, -2.2303e-01,
        -1.4102e-04, -3.0216e-02, -2.1879e-02,  8.9560e-03,  2.6243e-02,
        -1.3853e-01, -1.9705e-02,  4.1903e-02, -7.2309e-02,  2.8582e-02,
         5.8750e-03,  1.0193e-01,  7.4072e-02, -3.4574e-01, -3.4175e-01,
        -1.8274e-02, -3.1619e-03, -1.6254e-02,  3.8809e-03, -3.2542e-02,
         6.7521e-01,  2.4295e-02,  1.1882e-02, -2.4948e-01,  1.5892e-01,
        -7.1206e-02, -4.1669e-02,  2.5802e-01,  2.9952e-01,  6.6167e-01,
        -1.2856e-02,  7.6434e-02, -7.0731e-02,  5.6

Epoch [49/50]: 100%|█████████████████████████████████████████████████████| 2/2 [00:00<00:00, 68.77it/s, loss=0.136]


-------predict: tensor([0.0514], device='cuda:0', grad_fn=<SqueezeBackward1>), y: tensor([-0.2446], device='cuda:0')----------
tensor([0.2717])
Epoch [49/50]: Train loss: 1.0982, Valid loss: 0.0669
Saving model with loss 0.067...


  0%|                                                                                        | 0/2 [00:00<?, ?it/s]

-------predict: tensor([-8.0634e-02, -1.6896e-01, -2.9192e-01,  1.2267e-01,  4.3612e-02,
         8.9554e-02, -1.0587e-01, -2.9388e-01,  5.4924e-01, -1.7043e-01,
         2.0933e-02,  1.1682e-01,  1.1625e-01,  7.3957e-02, -3.8395e-01,
        -1.0796e-01,  2.4955e-01, -1.8434e-02,  1.0674e-01, -1.1780e-01,
        -3.1818e-01, -1.3623e-01,  3.1929e-02, -1.2703e-01, -1.1217e-01,
         4.5928e-01,  2.8414e-02,  4.9261e-02, -2.0218e-01, -1.5846e-01,
        -2.0593e-01,  2.7765e-01,  2.1872e-02, -1.9232e-01, -1.3625e-01,
        -7.6961e-02,  1.0137e-01,  2.5288e-01, -2.2454e-01, -1.0670e-01,
         3.4179e-01, -4.5766e-02, -3.3145e-02, -3.5005e-02, -5.4889e-01,
        -6.0037e-02, -3.2841e-01, -5.5517e-02,  1.3283e-01, -1.2957e-01,
         2.1392e-01, -5.5754e-02,  4.5074e-02, -9.0925e-02, -2.5060e-01,
         6.4535e-01, -3.4792e-02, -1.2172e-01, -1.8165e-01, -2.3177e-01,
         2.4698e-01,  9.4691e-02, -6.4938e-02, -2.8586e-01,  2.2700e-01,
        -8.6346e-02,  4.0319e-02, -

Epoch [50/50]: 100%|█████████████████████████████████████████████████████| 2/2 [00:00<00:00, 68.64it/s, loss=0.078]


tensor([ 7.2377e-02, -3.0378e-04, -7.0731e-02, -1.5848e-02,  3.2908e-01,
         5.5600e-01,  6.2135e-01, -2.8738e-02,  8.9280e-03,  3.8809e-03,
        -1.2856e-02, -4.9693e-01,  1.5892e-01, -2.1008e-02, -3.7727e-02,
        -7.1206e-02, -1.8274e-02,  1.4809e-01,  2.7265e-01,  2.4615e-01,
        -1.2122e-02,  6.0025e-02, -2.1879e-02, -2.5499e-01, -5.2776e-02,
         1.2040e-02,  2.6465e-01, -2.3377e-01,  5.1115e-02, -1.3775e-01,
        -3.0333e-02, -1.6254e-02,  2.8527e-01,  5.6118e-02,  4.2126e-02,
         2.7804e-01,  2.2071e+00,  2.2896e-01,  5.5103e-03,  3.6443e-01,
        -1.3127e-02,  1.1929e-02, -9.2765e-02, -4.1669e-02,  5.6202e-02,
         2.2287e-02,  4.5290e-02,  2.0601e-02,  2.2802e-01,  2.8127e-02,
         1.9887e-02, -1.2142e-01,  1.3043e+00, -5.7422e-02, -3.4175e-01,
        -7.2477e-02, -5.8953e-01,  2.9132e-01,  1.1882e-02,  2.8987e-02,
        -9.4473e-02, -4.7574e-02, -5.6639e-01, -2.4948e-01, -1.4991e-01,
        -2.2303e-01,  3.0004e-01,  1.9810e-01,  1.9

In [1]:
plot_learning_curve(loss_record, title='deep model')

NameError: name 'plot_learning_curve' is not defined

In [426]:
del model
model = NeuralNet(input_dim=x_train.shape[1]).to(device)
ckpt = torch.load(config['save_path'], map_location='cpu')  # Load your best model
model.load_state_dict(ckpt)
#plot_pred(dv_set, model, device)  # Show prediction on the validation set

<All keys matched successfully>

In [430]:
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error
def eval_matrics(y_test, y_pred):

    MSE = mean_squared_error(y_test, y_pred)
    print('MSE=',MSE)
    RMSE =np.sqrt(MSE)
    print('RMSE=',RMSE)
    MAE= mean_absolute_error(y_test, y_pred)
    print('MAE=',MAE)

    R2=1-MSE/np.var(y_test)
    print("R2=", R2)

In [431]:

def test(tt_set, model, device):
    model.eval()                                # set model to evalutation mode
    preds = []; y_b=[]
    for x,y in tt_set:                            # iterate through the dataloader
        x ,y = x.to(device), y.to(device)                          # move data to device (cpu/cuda)
        with torch.no_grad():                   # disable gradient calculation
            pred = model(x)                     # forward pass (compute output)
            preds.append(pred.detach().cpu())
            y_b.append(y.detach().cpu())   # collect prediction
    preds = torch.cat(preds, dim=0).numpy().reshape(-1,1)     # concatenate all predictions and convert to a numpy array
    y_b= torch.cat(y_b,0).numpy().reshape(-1,1) 
    table  = np.concatenate((preds, y_b),axis=1)
    eval_matrics(y_b,preds)
    return table

In [469]:
preds = test(valid_loader, model, device) 

MSE= 0.028597383
RMSE= 0.1691076
MAE= 0.13527282
R2= -0.3172532320022583


# **Testing**
The predictions of your model on testing set will be stored at `pred.csv`.

In [None]:
def save_pred(preds, file):
    ''' Save predictions to specified file '''
    print('Saving results to {}'.format(file))
    with open(file, 'w') as fp:
        writer = csv.writer(fp)
        writer.writerow(['id', 'tested_positive'])
        for i, p in enumerate(preds):
            writer.writerow([i, p])

preds = test(valid_loader, model, device)  # predict COVID-19 cases with your model
print('pred')         # save prediction file to pred.csv

# **Hints**

## **Simple Baseline**
* Run sample code

## **Medium Baseline**
* Feature selection: 40 states + 2 `tested_positive` (`TODO` in dataset)

## **Strong Baseline**
* Feature selection (what other features are useful?)
* DNN architecture (layers? dimension? activation function?)
* Training (mini-batch? optimizer? learning rate?)
* L2 regularization
* There are some mistakes in the sample code, can you find them?

# **Reference**
This code is completely written by Heng-Jui Chang @ NTUEE.  
Copying or reusing this code is required to specify the original author. 

E.g.  
Source: Heng-Jui Chang @ NTUEE (https://github.com/ga642381/ML2021-Spring/blob/main/HW01/HW01.ipynb)


In [94]:
## Load pretrain embedding
emb_ent = torch.load('../LiterallyWikidata/files_needed/pretrained_kge/pretrained_complex_entemb.pt')
list_ent_ids =[]
with open('../LiterallyWikidata/files_needed/list_ent_ids.txt','r') as f:
    for line in f:
        list_ent_ids.append(line.strip())
## Preparing ent embedding
ent2idx = {e:i for i,e in enumerate(list_ent_ids)}
attri_data['ent_idx']= attri_data['e'].map(ent2idx)
embedding_e = torch.nn.Embedding.from_pretrained(emb_ent)
input_e = torch.LongTensor(attri_data['ent_idx'].to_numpy())

entity_embedding = embedding_e(input_e)
## Preparing att embedding
# att2idx = {a:i for i,a in enumerate(attri_data['a'].unique())}
attri_data['a_idx']=attri_data['a'].map(att2idx)
embedding_a = torch.nn.Embedding(len(att2idx),128,padding_idx=0)
input_a = torch.LongTensor(attri_data['a'].to_numpy())

attribute_embedding = embedding_a(input_a)
## concat two embedding
x_data = torch.cat([entity_embedding,attribute_embedding],dim=1).detach().numpy()

