# **Homework 1: COVID-19 Cases Prediction (Regression)**

### [Homework tutorial](https://www.youtube.com/watch?v=iMzxjBDMvac)

Objectives:
* Solve a regression problem with deep neural networks (DNN).
* Understand basic DNN training tips.
* Familiarize yourself with PyTorch.

If you have any questions, please contact the TAs via TA hours, NTU COOL, or by email at mlta-2023-spring@googlegroups.com.

In [1]:
# check gpu type
!nvidia-smi

Fri Nov 10 22:58:46 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 537.13                 Driver Version: 537.13       CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                     TCC/WDDM  | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 4060 ...  WDDM  | 00000000:01:00.0  On |                  N/A |
| N/A   38C    P8               2W /  93W |     94MiB /  8188MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

# Import packages

In [2]:
# Numerical Operations
import math
import numpy as np

# Reading/Writing Data
import pandas as pd
import os
import csv

# Feature selecting
import sklearn
from sklearn.feature_selection import SelectKBest,f_regression

# For Progress Bar
from tqdm import tqdm

# Pytorch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

# For plotting learning curve
from torch.utils.tensorboard import SummaryWriter

In [3]:
# playground
# dir() lists the attributes of an object, including its methods
# dir(torch)
# dir(torch.cuda.is_available)
# help(torch.cuda.is_available)

# Download data
If the Google Drive links below do not work, you can use the dropbox link below or download data from [Kaggle](https://www.kaggle.com/competitions/ml2023spring-hw1/overview), and upload data manually to the workspace.

In [4]:
# google drive link
# !gdown --id '1BjXalPZxq9mybPKNjF3h5L3NcF7XKTS-' --output covid_train.csv
# !gdown --id '1B55t74Jg2E5FCsKCsUEkPKIuqaY7UIi1' --output covid_test.csv

# dropbox link
# wget raised an error here: it is a Linux command and cannot be called on this Windows machine; will retry on a Mac later
# It does run on Colab
# !wget -O covid_train.csv https://www.dropbox.com/s/lmy1riadzoy0ahw/covid.train.csv?dl=0
# !wget -O covid_test.csv https://www.dropbox.com/s/zalbw42lu4nmhr2/covid.test.csv?dl=0
# /kaggle/input/hw1-covid-19

File_Path_Train = './covid_train.csv'
File_Path_Test = './covid_test.csv'

# Kaggle
# File_Path_Train = '/kaggle/input/hw1-covid-19/covid_train.csv'
# File_Path_Test = '/kaggle/input/hw1-covid-19/covid_test.csv'

# Common
ds_train = pd.read_csv(File_Path_Train)
ds_test = pd.read_csv(File_Path_Test)

### Data checking
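Check the data first. According to the homework description, the raw data should contain infection figures for the past 3 days for 35 US states:

- First try to understand the data
- If any data needs padding, do the padding first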

In [5]:
# ds_train.head()

In [6]:
# ds_train.describe()

In [7]:
# List all columns
# for idx,name in enumerate(ds_train.columns):
#   print(idx,name)

In [8]:
# ds_train.shape[0]

In [9]:
# print('--------------------------- STATE ----------------------------------')
# print('id:',ds_train.columns.get_loc('id'))
# print('AL:',ds_train.columns.get_loc('AL'))
# print('WI:',ds_train.columns.get_loc('WI'))
# print('--------------------------- DAY 1 ----------------------------------')
# print('cli:',ds_train.columns.get_loc('cli'))
# print('tested_positive:',ds_train.columns.get_loc('tested_positive'))
# print('--------------------------- DAY 2 ----------------------------------')
# print('cli.1:',ds_train.columns.get_loc('cli.1'))
# print('tested_positive.1:',ds_train.columns.get_loc('tested_positive.1'))
# print('--------------------------- DAY 3 ----------------------------------')
# print('cli.2:',ds_train.columns.get_loc('cli.2'))
# print('tested_positive.2:',ds_train.columns.get_loc('tested_positive.2'))

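#### Does the data contain 35 states?
No, there are 34 states rather than 35. They occupy col[1] to col[34], from AL to WI.
> PS: The United States has 50 states in total: https://zh.wikipedia.org/zh-hans/%E7%BE%8E%E5%9B%BD%E5%B7%9E%E4%BB%BD

#### Does the data contain 3 days of data?
Yes, there are three days of data, with 18 columns per day.
- Day 1 spans col[35] to col[52], from cli to tested_positive
- Day 2 uses the suffix .1 and spans col[53] to col[70], from cli.1 to tested_positive.1
- Day 3 uses the suffix .2 and spans col[71] to col[88], from cli.2 to tested_positive.2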
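
The column layout described above can be captured in a small helper. This is a minimal sketch that assumes the layout found in this dataset (column 0 is `id`, then 34 state columns, then 18 columns per day). The names `N_STATES`, `FEATURES_PER_DAY` and `day_columns` are illustrative and not part of the homework code.

```python
N_STATES = 34          # one-hot state columns found in this dataset (col[1]..col[34])
FEATURES_PER_DAY = 18  # columns per day, from cli to tested_positive

def day_columns(day):
    """Return the (first, last) column index of a day, both inclusive."""
    if day not in (1, 2, 3):
        raise ValueError("day must be 1, 2 or 3")
    first = 1 + N_STATES + (day - 1) * FEATURES_PER_DAY  # skip id and the state columns
    last = first + FEATURES_PER_DAY - 1
    return first, last

for day in (1, 2, 3):
    print(day, day_columns(day))
```

Note that the last column of day 3 is the regression target, so it exists only in the training file.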

#### Column categories
|no|category|
|-----------|-----------|
|1|COVID-like illness|
|2|Behavior indicators|
|3|Belief indicators|
|4|Mental indicator|
|5|Environmental indicator|
|6|Tested Positive Cases|

- ![col-feature-desc](./hw-col-feature-desc.jpg)
- ![col-feature-desc-2](./hw-col-feature-desc-2.jpg)
- ![col-feature-desc-3](./hw-col-feature-desc-3.jpg)

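#### Column name explanations
All column names are abbreviations and hard to read, so here is a survey reference that documents what they mean:
https://cmu-delphi.github.io/delphi-epidata/api/covidcast-signals/fb-survey.html

Names that start with *w* denote percentages; the others denote counts.

A brief excerpt follows: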

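|no|col name|desc|Classification|
|-----------|-----------|-----------|-----------|
|1|cli|COVID-like illness|COVID-like illness|
|2|ili|Influenza-like illness|COVID-like illness|
|3|wnohh_cmnty_cli|Proportion of COVID-like illness recorded in the community, not including the household|COVID-like illness|
|4|wbelief_masking_effective|Believe that wearing a mask is an effective preventive measure|Belief indicators|
|5|wbelief_distancing_effective|Believe that keeping distance is an effective preventive measure|Belief indicators|
|6|wcovid_vaccinated_friends|Most of their friends and family have received a COVID-19 vaccine|Behavior indicators|
|7|wlarge_event_indoors|Percentage who attended an event with more than 10 people|Behavior indicators|
|8|wothers_masked_public|Estimated percentage of respondents who say that most or all other people wear masks in public|Environmental indicator|
|9|wothers_distanced_public|Estimated percentage of respondents who say that most or all other people keep at least 6 feet (1.8 m) away from them in public|Environmental indicator|
|10|wshop_indoors|Estimated percentage of respondents who went to an indoor market, grocery store, or pharmacy|Behavior indicators|
|11|wrestaurant_indoors|Estimated percentage of respondents who went to an indoor bar, restaurant, or cafe|Behavior indicators|
|12|wworried_catch_covid|Estimated percentage of respondents who are very or moderately worried about COVID-19|Mental indicator|
|13|hh_cmnty_cli|People reporting COVID-like illness in their local community, including their household|COVID-like illness|
|14|nohh_cmnty_cli|People reporting COVID-like illness in their local community, not including their household|COVID-like illness|
|15|wearing_mask_7d|People who wore a mask in public most or all of the time in the past 7 days|Environmental indicator|
|16|public_transit|Used public transit|Behavior indicators|
|17|worried_finances|Respondents who are very or somewhat worried about their household finances for the next month|Mental indicator|
|18|tested_positive|Cases that tested positive|Tested Positive Cases|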
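
Because days 2 and 3 reuse the same names with a `.1` or `.2` suffix, a lookup by base name is enough to classify any feature column. The sketch below encodes the classification from the table above. `CATEGORY`, `split_day` and `category_of` are hypothetical helper names, not part of the homework code.

```python
# Base column name -> category, copied from the table above.
CATEGORY = {
    "cli": "COVID-like illness",
    "ili": "COVID-like illness",
    "wnohh_cmnty_cli": "COVID-like illness",
    "hh_cmnty_cli": "COVID-like illness",
    "nohh_cmnty_cli": "COVID-like illness",
    "wbelief_masking_effective": "Belief indicators",
    "wbelief_distancing_effective": "Belief indicators",
    "wcovid_vaccinated_friends": "Behavior indicators",
    "wlarge_event_indoors": "Behavior indicators",
    "wshop_indoors": "Behavior indicators",
    "wrestaurant_indoors": "Behavior indicators",
    "public_transit": "Behavior indicators",
    "wothers_masked_public": "Environmental indicator",
    "wothers_distanced_public": "Environmental indicator",
    "wearing_mask_7d": "Environmental indicator",
    "wworried_catch_covid": "Mental indicator",
    "worried_finances": "Mental indicator",
    "tested_positive": "Tested Positive Cases",
}

def split_day(col):
    """Split 'cli.1' into ('cli', 2); a name without a suffix belongs to day 1."""
    base, dot, suffix = col.rpartition(".")
    if dot and suffix.isdigit():
        return base, int(suffix) + 1
    return col, 1

def category_of(col):
    """Look up the category of a column, ignoring its day suffix."""
    base, _ = split_day(col)
    return CATEGORY[base]

print(split_day("tested_positive.2"), category_of("wearing_mask_7d.1"))
```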

In [10]:
# print('--------------------------- COVID-like illness ----------------------------------')
# print(ds_train.loc[:0,['id','cli','ili','wnohh_cmnty_cli','hh_cmnty_cli','nohh_cmnty_cli']])
# print('--------------------------- Behavior indicators ----------------------------------')
# print(ds_train.loc[:0,['id','wcovid_vaccinated_friends','wlarge_event_indoors','wshop_indoors','wrestaurant_indoors','public_transit']])
# print('--------------------------- Belief indicators ----------------------------------')
# print(ds_train.loc[:0,['id','wbelief_masking_effective','wbelief_distancing_effective']])
# print('--------------------------- Mental indicator ----------------------------------')
# print(ds_train.loc[:0,['id','wworried_catch_covid','worried_finances']])
# print('--------------------------- Enviromental indicator ----------------------------------')
# print(ds_train.loc[:0,['id','wothers_masked_public','wothers_distanced_public','wearing_mask_7d']])
# print('--------------------------- Tested Positive Cases ----------------------------------')
# print(ds_train.loc[:0,['id','tested_positive',]])

# Configurations
`config` contains hyper-parameters for training and the path to save your model.

In [11]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 'select_features': np.r_[1:34,35:37,40:45,47:50,52,53:55,58:63,65:68,70,71:73,76:81,83:86], # Select columns with out id, Mental indicator
config = {
    'model_use_business': False,  # How layer widths are chosen: True derives them from the feature counts, False uses powers of two
    'seed': 19871201,      # Your seed number, you can pick your lucky number. :)
    'select_all': False,   # Whether to use all features.
    'select_features': [35,36,37,47,48,52,53,54,55,65,66,70,71,72,73,83,84],
    'valid_ratio': 0.2,   # validation_size = train_size * valid_ratio
    'n_epochs': 5000,     # Number of epochs.
    'batch_size': 256,
    'learning_rate': 1e-3,
    'early_stop': 600,    # If model has not improved for this many consecutive epochs, stop training.
    'weight_decay': 1e-6,
    'save_path': './models/model.ckpt'  # Your model will be saved here.
}

## Notes on some experiment results

### Epoch [1000/1000]: Train loss: 0.8504, Valid loss: 0.8970
Changed the model's layer widths to powers of two; the effect seems mediocre.
```python
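# Decouple the neuron counts from the business logic; the business logic lives in input_dim, i.e. in the feature selection.
# The Linear layer widths below are set to powers of two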
layer_1_nodes = 64
layer_2_nodes = 32
layer_3_nodes = 16

class My_Model(nn.Module):
    def __init__(self, input_dim):
        super(My_Model, self).__init__()
        # TODO: modify model's structure, be aware of dimensions.
        self.layers = nn.Sequential(
            nn.Linear(input_dim, layer_1_nodes),
            nn.ReLU(),
            nn.Linear(layer_1_nodes, layer_2_nodes),
            nn.ReLU(),
            nn.Linear(layer_2_nodes, layer_3_nodes),
            nn.ReLU(),
            nn.Linear(layer_3_nodes, 1) # final layer: reduce to a single predicted positive value
        )

    def forward(self, x):
        x = self.layers(x)
        x = x.squeeze(1) # (B, 1) -> (B)
        return x
```

### Epoch [1000/1000]: Train loss: 0.8320, Valid loss: 0.8807

Updated the feature selection to drop both the belief and mental columns. The feature list also moved into config, which requires changing the training code.
```python
config = {
    'seed': 19871201,      # Your seed number, you can pick your lucky number. :)
    'select_all': False,   # Whether to use all features.
    'select_features': np.r_[1:34,35:37,40:45,47:50,52,53:55,58:63,65:68,70,71:73,76:81,83:86], # Select columns with out id, Mental indicator
    'valid_ratio': 0.2,   # validation_size = train_size * valid_ratio
    'n_epochs': 1000,     # Number of epochs.
    'batch_size': 34*10,
    'learning_rate': 1e-3,
    'early_stop': 600,    # If model has not improved for this many consecutive epochs, stop training.
    'weight_decay': 1e-6,
    'save_path': './models/model.ckpt'  # Your model will be saved here.
}

x_train, x_valid, x_test, y_train, y_valid = select_feat(train_data, valid_data, test_data, config['select_all'], config['select_features'])
```

The optimizer is Adam, with L2 regularization applied through `weight_decay`:
```python
optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'], betas=(0.9, 0.999), eps=1e-08, weight_decay=config['weight_decay'], amsgrad=False)
```
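
In `torch.optim.Adam`, `weight_decay` acts as an L2 penalty: the term `weight_decay * w` is added to the gradient before the moment estimates are updated. The scalar sketch below walks through one such step in plain Python to show the effect. `adam_step` is an illustrative function rather than the PyTorch implementation, and the sample weight and gradient are made up.

```python
import math

def adam_step(w, grad, m, v, t, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-6):
    """One Adam update for a single scalar weight with an L2 penalty."""
    beta1, beta2 = betas
    grad = grad + weight_decay * w            # L2 regularization enters through the gradient
    m = beta1 * m + (1 - beta1) * grad        # first-moment (mean) estimate
    v = beta2 * v + (1 - beta2) * grad ** 2   # second-moment estimate
    m_hat = m / (1 - beta1 ** t)              # bias correction for step t (1-based)
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v

# Hypothetical starting point: weight 1.0 with gradient 0.5 at the first step.
w, m, v = adam_step(w=1.0, grad=0.5, m=0.0, v=0.0, t=1)
print(w)
```

With a decay as small as `1e-6`, the penalty barely changes the first step, which moves the weight by roughly `lr` against the sign of the gradient.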

### Epoch [1000/1000]: Train loss: 1.1560, Valid loss: 1.3151

```python
# define feature
state_count = 34
feature_count = 15
feature_type_count = 4 # dropped one category, the mental indicators
neuron_nodes = state_count + feature_count*3 # 34 states plus 15 useful feature columns x 3 days

nn.Sequential(
    nn.Linear(input_dim, neuron_nodes),
    nn.ReLU(),
    nn.Linear(neuron_nodes, feature_count),
    nn.ReLU(),
    nn.Linear(feature_count, feature_type_count),
    nn.ReLU(),
    nn.Linear(feature_type_count, 1) # final layer: reduce to a single predicted positive value
)

feat_idx = np.r_[1:34,35:45,47:50,52,53:63,65:68,70,71:81,83:86] # Select columns with out id, Mental indicator
```

### Epoch [1000/1000]: Train loss: 4.8335, Valid loss: 3.4793

```python
class My_Model(nn.Module):
    def __init__(self, input_dim):
        super(My_Model, self).__init__()
        # TODO: modify model's structure, be aware of dimensions.
        self.layers = nn.Sequential(
            nn.Linear(input_dim, 17),
            nn.ReLU(),
            nn.ReLU(),
            nn.Linear(17, 5),
            nn.ReLU(),
            nn.ReLU(),
            nn.Linear(5, 1)
        )

    def forward(self, x):
        x = self.layers(x)
        x = x.squeeze(1) # (B, 1) -> (B)
        return x

```

### Results of the baseline network below
After changing the batch size:


Epoch [1000/1000]: Train loss: 5.5759, Valid loss: 2.4200

Epoch [4569/5000]: Train loss: 1.2975, Valid loss: 1.9382

``` python
class My_Model(nn.Module):
    def __init__(self, input_dim):
        super(My_Model, self).__init__()
        # TODO: modify model's structure, be aware of dimensions.
        self.layers = nn.Sequential(
            nn.Linear(input_dim, 16),
            nn.ReLU(),
            nn.Linear(16, 8),
            nn.ReLU(),
            nn.Linear(8, 1)
        )

    def forward(self, x):
        x = self.layers(x)
        x = x.squeeze(1) # (B, 1) -> (B)
        return x
```

# Some Utility Functions

You do not need to modify this part.

In [12]:
def same_seed(seed):
    '''Fixes random number generator seeds for reproducibility.'''
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def train_valid_split(data_set, valid_ratio, seed):
    '''Split provided training data into training set and validation set'''
    valid_set_size = int(valid_ratio * len(data_set))
    train_set_size = len(data_set) - valid_set_size
    train_set, valid_set = random_split(data_set, [train_set_size, valid_set_size], generator=torch.Generator().manual_seed(seed))
    return np.array(train_set), np.array(valid_set)

def predict(test_loader, model, device):
    model.eval() # Set your model to evaluation mode. model.train() is used later, so the mode has to be switched back and forth, which is a bit tedious.
    preds = []
    for x in tqdm(test_loader):
        x = x.to(device)
        with torch.no_grad():
            pred = model(x)
            preds.append(pred.detach().cpu())
    preds = torch.cat(preds, dim=0).numpy()
    return preds

# Dataset

In [13]:
class COVID19Dataset(Dataset):
    '''
    x: Features.
    y: Targets, if none, do prediction.
    '''
    def __init__(self, x, y=None):
        if y is None:
            self.y = y
        else:
            self.y = torch.FloatTensor(y)
        self.x = torch.FloatTensor(x)

    def __getitem__(self, idx):
        if self.y is None:
            return self.x[idx]
        else:
            return self.x[idx], self.y[idx]

    def __len__(self):
        return len(self.x)

# Neural Network Model
Try out different model architectures by modifying the class below.

In [14]:
# @title
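# Defining an nn (neural network) usually means overriding __init__ and forward
# forward is the feed-forward pass of the network

# Decouple the neuron counts from the business logic; the business logic lives in input_dim, i.e. in the feature selection
# The Linear layer widths below are set to powers of two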
layer_1_nodes = 16
layer_2_nodes = 8
layer_3_nodes = 4

if config['model_use_business']:
  # define feature
  state_count = 34
  previous_count = 2
  feature_count = 13
  feature_type_count = 3 # dropped the mental and belief indicators
  neuron_nodes = feature_count*3 # 13 useful feature columns x 3 days

  layer_1_nodes = neuron_nodes + previous_count
  layer_2_nodes = feature_count+1 + previous_count
  layer_3_nodes = feature_type_count+ previous_count

class My_Model(nn.Module):
    def __init__(self, input_dim):
        super(My_Model, self).__init__()
        # TODO: modify model's structure, be aware of dimensions.
        self.layers = nn.Sequential(
            nn.Linear(input_dim, layer_1_nodes),
            nn.ReLU(),
            nn.Linear(layer_1_nodes, layer_2_nodes),
            nn.ReLU(),
            nn.Linear(layer_2_nodes, layer_3_nodes),
            nn.ReLU(),
            nn.Linear(layer_3_nodes, 1) # final layer: reduce to a single predicted positive value
        )

    def forward(self, x):
        x = self.layers(x)
        x = x.squeeze(1) # (B, 1) -> (B)
        return x
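
To compare the architectures tried in this notebook, it can help to count trainable parameters: each `nn.Linear(i, o)` holds `i * o` weights plus `o` biases. A rough sketch, assuming the 17 selected features and the 16/8/4 hidden widths used when `model_use_business` is `False`. `count_mlp_params` is a hypothetical helper.

```python
def count_mlp_params(layer_sizes):
    """Count weights and biases of a fully connected net given its layer widths."""
    total = 0
    for n_in, n_out in zip(layer_sizes, layer_sizes[1:]):
        total += n_in * n_out + n_out   # weight matrix plus bias vector
    return total

# Input width, three hidden layers, one output.
print(count_mlp_params([17, 16, 8, 4, 1]))
```

The same helper applies to the wider 64/32/16 variant from the experiment notes by changing the list of widths.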

# Feature Selection
Choose features you deem useful by modifying the function below.

In [15]:
# np.r_[35:37,40:45,47:50,52,53:55,58:63,65:68,70,71:73,76:81,83:86], # earlier hand-picked selection


train_data_for_select = pd.read_csv(File_Path_Train).values
# train_data_for_select, valid_data = train_valid_split(train_data_for_select, config['valid_ratio'], config['seed'])
train_data_for_select_X,train_data_for_select_Y = train_data_for_select[:,:-1],train_data_for_select[:,-1]


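# With k=17, the expectation is the positive counts of the first two days plus the top 5 most correlated factors from each of the 3 days. The columns returned are:
# cli   ili   wnohh_cmnty_cli   hh_cmnty_cli   nohh_cmnty_cli   tested_positive
# cli.1 ili.1 wnohh_cmnty_cli.1 hh_cmnty_cli.1 nohh_cmnty_cli.1 tested_positive.1
# cli.2 ili.2 wnohh_cmnty_cli.2 hh_cmnty_cli.2 nohh_cmnty_cli.2

# With k=2, only two factors are kept, and the method returns:
# tested_positive, tested_positive.1

# Judging from the results, the selection looks very good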

def select_feat_sklearn(x, y):
    k = 17
    # Fit the selector; it is returned fitted so that the chosen columns can be inspected.
    selector = SelectKBest(score_func=f_regression, k=k).fit(x, y)
    return selector

# print(train_data_for_select_X[:1,:])
# train_data_for_select_X.shape
# train_data_for_select_Y.shape
# print(train_data_for_select_Y)

features = select_feat_sklearn(train_data_for_select_X,train_data_for_select_Y)
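array = features.get_feature_names_out()
# Without input feature names the selector reports 'x35'-style names; strip the 'x' to get the column index.
for idx, name in enumerate(array):
    array[idx] = int(name[1:])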

# [35 36 37 47 48 52 53 54 55 65 66 70 71 72 73 83 84]
    
print(array)
# idx = np.argsort(features.scores_)[::1]
# print(features[:3])


[35 36 37 47 48 52 53 54 55 65 66 70 71 72 73 83 84]
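
For intuition about what `SelectKBest(score_func=f_regression)` ranks: for each feature it takes the Pearson correlation `r` with the target, turns it into an F-statistic `r^2 / (1 - r^2) * (n - 2)`, and keeps the `k` highest scores. The plain-Python sketch below mirrors that idea on made-up data. `pearson_r`, `f_score` and `top_k_features` are illustrative names, not scikit-learn functions.

```python
import math

def pearson_r(xs, ys):
    """Pearson correlation between two equally long sequences."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy)

def f_score(xs, ys):
    """Univariate F-statistic derived from the squared correlation."""
    r2 = pearson_r(xs, ys) ** 2
    return r2 / (1 - r2) * (len(xs) - 2)

def top_k_features(columns, target, k):
    """Return the sorted indices of the k columns with the highest F-score."""
    scores = [f_score(col, target) for col in columns]
    ranked = sorted(range(len(columns)), key=lambda i: scores[i], reverse=True)
    return sorted(ranked[:k])

# Made-up toy data: columns 0 and 2 follow the target, column 1 does not.
target = [2.1, 3.9, 6.2, 7.8, 10.1]
columns = [
    [1, 2, 3, 4, 5],
    [5, 1, 4, 2, 3],
    [2, 4, 5, 8, 9],
]
print(top_k_features(columns, target, k=2))
```

This also suggests why previous-day `tested_positive` columns come out on top here: they are strongly linearly related to the target.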


In [16]:
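# select_all defaults to True, which means every column is used as a feature
# If some columns turn out to be weakly related during training, they can be excluded: set select_all to False and pass the column indices manually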
def select_feat(train_data, valid_data, test_data, select_all=True, select_features=[]):
    '''Selects useful features to perform regression'''
    y_train, y_valid = train_data[:,-1], valid_data[:,-1]
    raw_x_train, raw_x_valid, raw_x_test = train_data[:,:-1], valid_data[:,:-1], test_data

    if select_all:
        feat_idx = list(range(raw_x_train.shape[1]))
    else:
        feat_idx = select_features

    return raw_x_train[:,feat_idx], raw_x_valid[:,feat_idx], raw_x_test[:,feat_idx], y_train, y_valid

# Training Loop

In [17]:
same_seed(config['seed'])

train_data, test_data = pd.read_csv(File_Path_Train).values, pd.read_csv(File_Path_Test).values
train_data, valid_data = train_valid_split(train_data, config['valid_ratio'], config['seed'])

# Print out the data size.
print(f"""train_data size: {train_data.shape}
valid_data size: {valid_data.shape}
test_data size: {test_data.shape}""")

# Select features
x_train, x_valid, x_test, y_train, y_valid = select_feat(train_data, valid_data, test_data, config['select_all'], config['select_features'])

print('x_train')
print(x_train)

# Print out the number of features.
print(f'number of features: {x_train.shape[1]}')

train_dataset, valid_dataset, test_dataset = COVID19Dataset(x_train, y_train), \
                                            COVID19Dataset(x_valid, y_valid), \
                                            COVID19Dataset(x_test)

# Pytorch data loader loads pytorch dataset into batches.
train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=config['batch_size'], shuffle=True, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, pin_memory=True)

train_data size: (2408, 89)
valid_data size: (601, 89)
test_data size: (997, 88)
x_train
[[ 1.3138593  1.2748362 15.7057702 ... 15.8342843 22.0118552 16.1409036]
 [ 1.6800285  1.7248551 16.6408925 ... 16.2795037 21.1383018 16.1498569]
 [ 0.8818757  0.9345295  9.7551765 ...  8.7525655 12.8101195  8.8501928]
 ...
 [ 1.0946756  1.1118771 11.7432324 ... 12.0424304 15.3299617 10.9991829]
 [ 3.6352241  3.7769054 26.5714705 ... 23.7412583 32.9247972 25.8964017]
 [ 3.6850859  3.7616652 33.6158757 ... 31.6886024 37.331582  31.3212202]]
number of features: 17
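
The sizes printed above follow directly from `valid_ratio` and `batch_size`. A quick arithmetic check, assuming the training CSV has 3009 rows (2408 + 601, the two split sizes shown in the output). `split_sizes` is a hypothetical helper that repeats the rounding done in `train_valid_split`.

```python
import math

def split_sizes(n_rows, valid_ratio):
    """Mirror train_valid_split: the validation size is rounded down."""
    valid = int(valid_ratio * n_rows)
    return n_rows - valid, valid

train_size, valid_size = split_sizes(3009, 0.2)
# DataLoader keeps the last partial batch by default, hence the ceiling.
batches_per_epoch = math.ceil(train_size / 256)
print(train_size, valid_size, batches_per_epoch)
```

The resulting number of batches per epoch should agree with the `10/10` counter in the progress bars of the training log.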


In [21]:
def trainer(train_loader, valid_loader, model, config, device):

    criterion = nn.MSELoss(reduction='mean') # Define your loss function, do not modify this.

    # Define your optimization algorithm.
    # TODO: Please check https://pytorch.org/docs/stable/optim.html to get more available algorithms.
    # TODO: L2 regularization (optimizer(weight decay...) or implement by your self).
    # optimizer = torch.optim.SGD(model.parameters(), lr=config['learning_rate'], momentum=0.9)
    optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'], betas=(0.9, 0.999), eps=1e-08, weight_decay=config['weight_decay'], amsgrad=False)
    writer = SummaryWriter() # Writer of TensorBoard.

    if not os.path.isdir('./models'):
        os.mkdir('./models') # Create directory of saving models.

    n_epochs, best_loss, step, early_stop_count = config['n_epochs'], math.inf, 0, 0

    for epoch in range(n_epochs):
        model.train() # Set your model to train mode.
        loss_record = []

        # tqdm is a package to visualize your training progress.
        train_pbar = tqdm(train_loader, position=0, leave=True)

        for x, y in train_pbar:
            optimizer.zero_grad()               # Set gradient to zero.
            x, y = x.to(device), y.to(device)   # Move your data to device.
            pred = model(x)
            loss = criterion(pred, y)
            loss.backward()                     # Compute gradient(backpropagation).
            optimizer.step()                    # Update parameters.
            step += 1
            loss_record.append(loss.detach().item())

            # Display current epoch number and loss on tqdm progress bar.
            train_pbar.set_description(f'Epoch [{epoch+1}/{n_epochs}]')
            train_pbar.set_postfix({'loss': loss.detach().item()})

        mean_train_loss = sum(loss_record)/len(loss_record)
        # writer.add_scalar('Loss/train', mean_train_loss, step)

        model.eval() # Set your model to evaluation mode.
        loss_record = []
        for x, y in valid_loader:
            x, y = x.to(device), y.to(device)
            with torch.no_grad():
                pred = model(x)
                loss = criterion(pred, y)

            loss_record.append(loss.item())

        mean_valid_loss = sum(loss_record)/len(loss_record)
        if(epoch%100 == 0):
            print(f'Epoch [{epoch+1}/{n_epochs}]: Train loss: {mean_train_loss:.4f}, Valid loss: {mean_valid_loss:.4f}')
        # writer.add_scalar('Loss/valid', mean_valid_loss, step)

        if mean_valid_loss < best_loss:
            best_loss = mean_valid_loss
            torch.save(model.state_dict(), config['save_path'] + str(best_loss)) # Save your best model
            print('Saving model with loss {:.3f}...'.format(best_loss))
            early_stop_count = 0
        else:
            early_stop_count += 1

        if early_stop_count >= config['early_stop']:
            print('\nModel is not improving, so we halt the training session.')
            print('Model current loss {:.3f}'.format(best_loss))
            return
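
The early-stopping rule in `trainer` can be looked at separately from the training code: the counter resets whenever the validation loss improves, and training halts once the counter reaches `early_stop`. A minimal sketch on a made-up loss sequence; `epochs_until_stop` is an illustrative name.

```python
import math

def epochs_until_stop(valid_losses, patience):
    """Return (epochs_run, best_loss) under the same rule as trainer()."""
    best_loss, stale = math.inf, 0
    for epoch, loss in enumerate(valid_losses, start=1):
        if loss < best_loss:
            best_loss, stale = loss, 0   # improvement: remember it and reset the counter
        else:
            stale += 1                   # no improvement this epoch
        if stale >= patience:
            return epoch, best_loss      # halt early
    return len(valid_losses), best_loss  # ran through every epoch

# Hypothetical validation losses, one per epoch.
print(epochs_until_stop([3.0, 2.0, 2.5, 2.1, 2.2, 1.9], patience=3))
```

In this toy sequence a patience of 3 stops before the later improvement to 1.9 is ever seen, which is the trade-off behind choosing a large value such as 600.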

# Dataloader
Read data from files and set up training, validation, and testing sets. You do not need to modify this part.

# Start training!

In [22]:
model = My_Model(input_dim=x_train.shape[1]).to(device) # put your model and data on the same computation device.
trainer(train_loader, valid_loader, model, config, device)

Epoch [1/5000]: 100%|██████████| 10/10 [00:00<00:00, 187.58it/s, loss=357]


Epoch [1/5000]: Train loss: 358.7049, Valid loss: 336.1781
Saving model with loss 336.178...


Epoch [2/5000]: 100%|██████████| 10/10 [00:00<00:00, 239.50it/s, loss=246]


Saving model with loss 275.893...


Epoch [3/5000]: 100%|██████████| 10/10 [00:00<00:00, 192.65it/s, loss=195]


Saving model with loss 221.584...


Epoch [4/5000]: 100%|██████████| 10/10 [00:00<00:00, 143.75it/s, loss=151]


Saving model with loss 137.585...


Epoch [5/5000]: 100%|██████████| 10/10 [00:00<00:00, 262.48it/s, loss=110]


Saving model with loss 71.013...


Epoch [6/5000]: 100%|██████████| 10/10 [00:00<00:00, 215.24it/s, loss=26.4]

Saving model with loss 23.814...



Epoch [7/5000]: 100%|██████████| 10/10 [00:00<00:00, 254.36it/s, loss=16]


Saving model with loss 13.466...


Epoch [8/5000]: 100%|██████████| 10/10 [00:00<00:00, 204.93it/s, loss=15.4]
Epoch [9/5000]: 100%|██████████| 10/10 [00:00<00:00, 259.65it/s, loss=13.5]
Epoch [10/5000]: 100%|██████████| 10/10 [00:00<00:00, 258.69it/s, loss=17.5]


Saving model with loss 13.265...


Epoch [11/5000]: 100%|██████████| 10/10 [00:00<00:00, 224.82it/s, loss=13]


Saving model with loss 12.812...


Epoch [12/5000]: 100%|██████████| 10/10 [00:00<00:00, 269.87it/s, loss=14]
Epoch [13/5000]: 100%|██████████| 10/10 [00:00<00:00, 249.24it/s, loss=13.6]


Saving model with loss 12.201...


Epoch [14/5000]: 100%|██████████| 10/10 [00:00<00:00, 281.75it/s, loss=15.4]
Epoch [15/5000]: 100%|██████████| 10/10 [00:00<00:00, 226.11it/s, loss=14.1]
Epoch [16/5000]: 100%|██████████| 10/10 [00:00<00:00, 247.44it/s, loss=9.38]
Epoch [17/5000]: 100%|██████████| 10/10 [00:00<00:00, 273.70it/s, loss=11.5]
Epoch [18/5000]: 100%|██████████| 10/10 [00:00<00:00, 276.24it/s, loss=12.9]


Saving model with loss 11.739...


Epoch [19/5000]: 100%|██████████| 10/10 [00:00<00:00, 271.69it/s, loss=13.3]


Saving model with loss 11.450...


Epoch [20/5000]: 100%|██████████| 10/10 [00:00<00:00, 237.35it/s, loss=10.7]
Epoch [21/5000]: 100%|██████████| 10/10 [00:00<00:00, 291.31it/s, loss=11.2]


Saving model with loss 11.380...


Epoch [22/5000]: 100%|██████████| 10/10 [00:00<00:00, 259.88it/s, loss=13]


Saving model with loss 10.661...


Epoch [23/5000]: 100%|██████████| 10/10 [00:00<00:00, 241.20it/s, loss=10.5]


Saving model with loss 10.374...


Epoch [24/5000]: 100%|██████████| 10/10 [00:00<00:00, 265.58it/s, loss=16.4]
Epoch [25/5000]: 100%|██████████| 10/10 [00:00<00:00, 290.95it/s, loss=11.4]

Saving model with loss 10.153...



Epoch [26/5000]: 100%|██████████| 10/10 [00:00<00:00, 289.33it/s, loss=9.55]


Saving model with loss 10.111...


Epoch [27/5000]: 100%|██████████| 10/10 [00:00<00:00, 296.37it/s, loss=8.74]


Saving model with loss 9.900...


Epoch [28/5000]: 100%|██████████| 10/10 [00:00<00:00, 251.37it/s, loss=12.1]


Saving model with loss 9.681...


Epoch [29/5000]: 100%|██████████| 10/10 [00:00<00:00, 241.65it/s, loss=11.5]


Saving model with loss 9.236...


Epoch [30/5000]: 100%|██████████| 10/10 [00:00<00:00, 203.22it/s, loss=9.12]


Saving model with loss 8.675...


Epoch [31/5000]: 100%|██████████| 10/10 [00:00<00:00, 281.79it/s, loss=9.12]
Epoch [32/5000]: 100%|██████████| 10/10 [00:00<00:00, 307.35it/s, loss=6.92]


Saving model with loss 7.615...


Epoch [33/5000]: 100%|██████████| 10/10 [00:00<00:00, 249.79it/s, loss=7.83]


Saving model with loss 7.085...


Epoch [34/5000]: 100%|██████████| 10/10 [00:00<00:00, 272.41it/s, loss=6.45]
Epoch [35/5000]: 100%|██████████| 10/10 [00:00<00:00, 289.35it/s, loss=6.47]


Saving model with loss 6.603...


Epoch [36/5000]: 100%|██████████| 10/10 [00:00<00:00, 222.67it/s, loss=6.89]


Saving model with loss 6.248...


Epoch [37/5000]: 100%|██████████| 10/10 [00:00<00:00, 229.15it/s, loss=6.63]


Saving model with loss 5.901...


Epoch [38/5000]: 100%|██████████| 10/10 [00:00<00:00, 223.56it/s, loss=6.18]


Saving model with loss 5.207...


Epoch [39/5000]: 100%|██████████| 10/10 [00:00<00:00, 243.77it/s, loss=4.5]


Saving model with loss 4.898...


Epoch [40/5000]: 100%|██████████| 10/10 [00:00<00:00, 231.38it/s, loss=4.51]


Saving model with loss 4.842...


Epoch [41/5000]: 100%|██████████| 10/10 [00:00<00:00, 299.99it/s, loss=4.07]


Saving model with loss 4.282...


Epoch [42/5000]: 100%|██████████| 10/10 [00:00<00:00, 307.60it/s, loss=4.52]


Saving model with loss 3.981...


Epoch [43/5000]: 100%|██████████| 10/10 [00:00<00:00, 296.18it/s, loss=4.29]


Saving model with loss 3.632...


Epoch [44/5000]: 100%|██████████| 10/10 [00:00<00:00, 244.76it/s, loss=3.55]


Saving model with loss 3.181...


Epoch [45/5000]: 100%|██████████| 10/10 [00:00<00:00, 272.80it/s, loss=3.62]


Saving model with loss 2.885...


Epoch [46/5000]: 100%|██████████| 10/10 [00:00<00:00, 294.63it/s, loss=2.41]


Saving model with loss 2.594...


Epoch [47/5000]: 100%|██████████| 10/10 [00:00<00:00, 295.83it/s, loss=2.94]


Saving model with loss 2.209...


Epoch [48/5000]: 100%|██████████| 10/10 [00:00<00:00, 288.30it/s, loss=2.12]


Saving model with loss 2.021...


Epoch [49/5000]: 100%|██████████| 10/10 [00:00<00:00, 312.23it/s, loss=2.14]


Saving model with loss 1.727...


Epoch [50/5000]: 100%|██████████| 10/10 [00:00<00:00, 273.15it/s, loss=1.95]


Saving model with loss 1.723...


Epoch [51/5000]: 100%|██████████| 10/10 [00:00<00:00, 290.04it/s, loss=1.65]


Saving model with loss 1.512...


Epoch [52/5000]: 100%|██████████| 10/10 [00:00<00:00, 257.16it/s, loss=1.81]


Saving model with loss 1.479...


Epoch [53/5000]: 100%|██████████| 10/10 [00:00<00:00, 250.07it/s, loss=0.967]


Saving model with loss 1.398...


Epoch [54/5000]: 100%|██████████| 10/10 [00:00<00:00, 237.69it/s, loss=1.37]
Epoch [55/5000]: 100%|██████████| 10/10 [00:00<00:00, 236.19it/s, loss=1.33]


Saving model with loss 1.335...


Epoch [56/5000]: 100%|██████████| 10/10 [00:00<00:00, 258.54it/s, loss=1.24]


Saving model with loss 1.294...


Epoch [57/5000]: 100%|██████████| 10/10 [00:00<00:00, 294.35it/s, loss=0.981]
Epoch [58/5000]: 100%|██████████| 10/10 [00:00<00:00, 278.87it/s, loss=1.86]
Epoch [59/5000]: 100%|██████████| 10/10 [00:00<00:00, 243.64it/s, loss=1.18]


Saving model with loss 1.237...


Epoch [60/5000]: 100%|██████████| 10/10 [00:00<00:00, 248.03it/s, loss=1.64]
Epoch [61/5000]: 100%|██████████| 10/10 [00:00<00:00, 298.66it/s, loss=0.991]
Epoch [62/5000]: 100%|██████████| 10/10 [00:00<00:00, 316.40it/s, loss=1.16]
Epoch [63/5000]: 100%|██████████| 10/10 [00:00<00:00, 249.13it/s, loss=1.18]


Saving model with loss 1.236...


Epoch [64/5000]: 100%|██████████| 10/10 [00:00<00:00, 244.59it/s, loss=1.27]
Epoch [65/5000]: 100%|██████████| 10/10 [00:00<00:00, 276.95it/s, loss=1.51]


Saving model with loss 1.229...


Epoch [66/5000]: 100%|██████████| 10/10 [00:00<00:00, 262.59it/s, loss=1.34]
Epoch [67/5000]: 100%|██████████| 10/10 [00:00<00:00, 261.50it/s, loss=1.41]
Epoch [68/5000]: 100%|██████████| 10/10 [00:00<00:00, 286.16it/s, loss=1.34]
Epoch [69/5000]: 100%|██████████| 10/10 [00:00<00:00, 285.63it/s, loss=1.24]
Epoch [70/5000]: 100%|██████████| 10/10 [00:00<00:00, 246.16it/s, loss=1.3]
Epoch [71/5000]: 100%|██████████| 10/10 [00:00<00:00, 285.79it/s, loss=0.792]
Epoch [72/5000]: 100%|██████████| 10/10 [00:00<00:00, 272.50it/s, loss=1.33]
Epoch [73/5000]: 100%|██████████| 10/10 [00:00<00:00, 255.53it/s, loss=0.987]
Epoch [74/5000]: 100%|██████████| 10/10 [00:00<00:00, 64.59it/s, loss=1.31]
Epoch [75/5000]: 100%|██████████| 10/10 [00:00<00:00, 270.29it/s, loss=0.967]
Epoch [76/5000]: 100%|██████████| 10/10 [00:00<00:00, 307.03it/s, loss=1.22]
Epoch [77/5000]: 100%|██████████| 10/10 [00:00<00:00, 277.66it/s, loss=1.3]
Epoch [78/5000]: 100%|██████████| 10/10 [00:00<00:00, 251.46it/s, loss=1.46]

Saving model with loss 1.213...


Epoch [85/5000]: 100%|██████████| 10/10 [00:00<00:00, 272.70it/s, loss=1.28]


Saving model with loss 1.199...


Epoch [86/5000]: 100%|██████████| 10/10 [00:00<00:00, 306.59it/s, loss=1.17]
Epoch [87/5000]: 100%|██████████| 10/10 [00:00<00:00, 307.46it/s, loss=1.4]
Epoch [88/5000]: 100%|██████████| 10/10 [00:00<00:00, 293.58it/s, loss=1.39]
Epoch [89/5000]: 100%|██████████| 10/10 [00:00<00:00, 293.35it/s, loss=1.43]
Epoch [90/5000]: 100%|██████████| 10/10 [00:00<00:00, 293.14it/s, loss=0.984]
Epoch [91/5000]: 100%|██████████| 10/10 [00:00<00:00, 329.40it/s, loss=1.1]


Saving model with loss 1.173...


Epoch [92/5000]: 100%|██████████| 10/10 [00:00<00:00, 338.08it/s, loss=1.01]
Epoch [93/5000]: 100%|██████████| 10/10 [00:00<00:00, 296.76it/s, loss=1.1]
Epoch [94/5000]: 100%|██████████| 10/10 [00:00<00:00, 330.41it/s, loss=1.41]
Epoch [95/5000]: 100%|██████████| 10/10 [00:00<00:00, 357.80it/s, loss=1.53]
Epoch [96/5000]: 100%|██████████| 10/10 [00:00<00:00, 299.73it/s, loss=0.902]
Epoch [97/5000]: 100%|██████████| 10/10 [00:00<00:00, 365.35it/s, loss=1.56]
Epoch [98/5000]: 100%|██████████| 10/10 [00:00<00:00, 337.15it/s, loss=1.08]
Epoch [99/5000]: 100%|██████████| 10/10 [00:00<00:00, 314.27it/s, loss=1.22]


Saving model with loss 1.152...


Epoch [100/5000]: 100%|██████████| 10/10 [00:00<00:00, 283.27it/s, loss=1.68]
Epoch [101/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.90it/s, loss=1.53]


Epoch [101/5000]: Train loss: 1.2081, Valid loss: 1.1942


Epoch [102/5000]: 100%|██████████| 10/10 [00:00<00:00, 298.09it/s, loss=1.24]
Epoch [103/5000]: 100%|██████████| 10/10 [00:00<00:00, 284.40it/s, loss=0.986]
Epoch [104/5000]: 100%|██████████| 10/10 [00:00<00:00, 251.24it/s, loss=0.992]
Epoch [105/5000]: 100%|██████████| 10/10 [00:00<00:00, 315.44it/s, loss=1.06]
Epoch [106/5000]: 100%|██████████| 10/10 [00:00<00:00, 321.13it/s, loss=1.28]
Epoch [107/5000]: 100%|██████████| 10/10 [00:00<00:00, 304.56it/s, loss=1.23]
Epoch [108/5000]: 100%|██████████| 10/10 [00:00<00:00, 307.39it/s, loss=1.18]
Epoch [109/5000]: 100%|██████████| 10/10 [00:00<00:00, 309.08it/s, loss=0.889]
Epoch [110/5000]: 100%|██████████| 10/10 [00:00<00:00, 300.71it/s, loss=1.06]


Saving model with loss 1.136...


Epoch [111/5000]: 100%|██████████| 10/10 [00:00<00:00, 309.05it/s, loss=1.12]
Epoch [112/5000]: 100%|██████████| 10/10 [00:00<00:00, 340.34it/s, loss=1.28]
Epoch [113/5000]: 100%|██████████| 10/10 [00:00<00:00, 306.68it/s, loss=1.13]
Epoch [114/5000]: 100%|██████████| 10/10 [00:00<00:00, 328.33it/s, loss=1.12]
Epoch [115/5000]: 100%|██████████| 10/10 [00:00<00:00, 301.93it/s, loss=0.951]
Epoch [116/5000]: 100%|██████████| 10/10 [00:00<00:00, 331.48it/s, loss=1.12]
Epoch [117/5000]: 100%|██████████| 10/10 [00:00<00:00, 344.04it/s, loss=1.06]
Epoch [118/5000]: 100%|██████████| 10/10 [00:00<00:00, 321.55it/s, loss=1.24]
Epoch [119/5000]: 100%|██████████| 10/10 [00:00<00:00, 320.74it/s, loss=1.2]
Epoch [120/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.72it/s, loss=1.48]
Epoch [121/5000]: 100%|██████████| 10/10 [00:00<00:00, 294.85it/s, loss=1.17]
Epoch [122/5000]: 100%|██████████| 10/10 [00:00<00:00, 305.07it/s, loss=0.893]
Epoch [123/5000]: 100%|██████████| 10/10 [00:00<00:00, 330.08it

Saving model with loss 1.100...


Epoch [142/5000]: 100%|██████████| 10/10 [00:00<00:00, 306.62it/s, loss=1.04]
Epoch [143/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.19it/s, loss=1.02]
Epoch [144/5000]: 100%|██████████| 10/10 [00:00<00:00, 312.83it/s, loss=0.965]
Epoch [145/5000]: 100%|██████████| 10/10 [00:00<00:00, 342.64it/s, loss=0.941]
Epoch [146/5000]: 100%|██████████| 10/10 [00:00<00:00, 343.82it/s, loss=0.931]
Epoch [147/5000]: 100%|██████████| 10/10 [00:00<00:00, 314.47it/s, loss=1.16]
Epoch [148/5000]: 100%|██████████| 10/10 [00:00<00:00, 333.03it/s, loss=1.1]
Epoch [149/5000]: 100%|██████████| 10/10 [00:00<00:00, 206.35it/s, loss=0.899]
Epoch [150/5000]: 100%|██████████| 10/10 [00:00<00:00, 278.31it/s, loss=1.16]
Epoch [151/5000]: 100%|██████████| 10/10 [00:00<00:00, 289.35it/s, loss=1.02]
Epoch [152/5000]: 100%|██████████| 10/10 [00:00<00:00, 257.57it/s, loss=1.22]
Epoch [153/5000]: 100%|██████████| 10/10 [00:00<00:00, 280.54it/s, loss=0.804]
Epoch [154/5000]: 100%|██████████| 10/10 [00:00<00:00, 295.2

Saving model with loss 1.079...


Epoch [168/5000]: 100%|██████████| 10/10 [00:00<00:00, 345.97it/s, loss=1.02]
Epoch [169/5000]: 100%|██████████| 10/10 [00:00<00:00, 305.75it/s, loss=1.04]
Epoch [170/5000]: 100%|██████████| 10/10 [00:00<00:00, 346.04it/s, loss=0.989]
Epoch [171/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.55it/s, loss=1.16]
Epoch [172/5000]: 100%|██████████| 10/10 [00:00<00:00, 297.07it/s, loss=0.927]
Epoch [173/5000]: 100%|██████████| 10/10 [00:00<00:00, 370.53it/s, loss=0.774]
Epoch [174/5000]: 100%|██████████| 10/10 [00:00<00:00, 313.50it/s, loss=0.861]


Saving model with loss 1.070...


Epoch [175/5000]: 100%|██████████| 10/10 [00:00<00:00, 297.02it/s, loss=0.863]
Epoch [176/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.55it/s, loss=1.03]
Epoch [177/5000]: 100%|██████████| 10/10 [00:00<00:00, 322.85it/s, loss=1.12]
Epoch [178/5000]: 100%|██████████| 10/10 [00:00<00:00, 306.75it/s, loss=1.24]
Epoch [179/5000]: 100%|██████████| 10/10 [00:00<00:00, 333.31it/s, loss=0.939]
Epoch [180/5000]: 100%|██████████| 10/10 [00:00<00:00, 308.50it/s, loss=0.911]
Epoch [181/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.63it/s, loss=0.921]
Epoch [182/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.95it/s, loss=0.841]


Saving model with loss 1.054...


Epoch [183/5000]: 100%|██████████| 10/10 [00:00<00:00, 280.97it/s, loss=0.959]
Epoch [184/5000]: 100%|██████████| 10/10 [00:00<00:00, 345.90it/s, loss=0.99]
Epoch [185/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.69it/s, loss=0.953]
Epoch [186/5000]: 100%|██████████| 10/10 [00:00<00:00, 297.80it/s, loss=1.01]
Epoch [187/5000]: 100%|██████████| 10/10 [00:00<00:00, 286.67it/s, loss=1.35]
Epoch [188/5000]: 100%|██████████| 10/10 [00:00<00:00, 265.66it/s, loss=1.62]
Epoch [189/5000]: 100%|██████████| 10/10 [00:00<00:00, 257.58it/s, loss=0.753]


Saving model with loss 1.046...


Epoch [190/5000]: 100%|██████████| 10/10 [00:00<00:00, 263.41it/s, loss=1.04]
Epoch [191/5000]: 100%|██████████| 10/10 [00:00<00:00, 281.22it/s, loss=1.19]
Epoch [192/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.95it/s, loss=1.03]
Epoch [193/5000]: 100%|██████████| 10/10 [00:00<00:00, 306.21it/s, loss=1.28]
Epoch [194/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.38it/s, loss=1.26]
Epoch [195/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.42it/s, loss=0.994]
Epoch [196/5000]: 100%|██████████| 10/10 [00:00<00:00, 296.90it/s, loss=1.13]
Epoch [197/5000]: 100%|██████████| 10/10 [00:00<00:00, 303.70it/s, loss=1.09]
Epoch [198/5000]: 100%|██████████| 10/10 [00:00<00:00, 322.93it/s, loss=1.01]
Epoch [199/5000]: 100%|██████████| 10/10 [00:00<00:00, 289.16it/s, loss=0.814]
Epoch [200/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.86it/s, loss=1.16]
Epoch [201/5000]: 100%|██████████| 10/10 [00:00<00:00, 322.72it/s, loss=1]


Epoch [201/5000]: Train loss: 1.0748, Valid loss: 1.0669


Epoch [202/5000]: 100%|██████████| 10/10 [00:00<00:00, 278.79it/s, loss=0.881]
Epoch [203/5000]: 100%|██████████| 10/10 [00:00<00:00, 278.69it/s, loss=0.963]
Epoch [204/5000]: 100%|██████████| 10/10 [00:00<00:00, 375.95it/s, loss=1.5]
Epoch [205/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.59it/s, loss=1.03]
Epoch [206/5000]: 100%|██████████| 10/10 [00:00<00:00, 302.46it/s, loss=1.07]
Epoch [207/5000]: 100%|██████████| 10/10 [00:00<00:00, 316.05it/s, loss=1.09]
Epoch [208/5000]: 100%|██████████| 10/10 [00:00<00:00, 295.44it/s, loss=0.987]
Epoch [209/5000]: 100%|██████████| 10/10 [00:00<00:00, 280.50it/s, loss=1.56]
Epoch [210/5000]: 100%|██████████| 10/10 [00:00<00:00, 313.64it/s, loss=0.979]
Epoch [211/5000]: 100%|██████████| 10/10 [00:00<00:00, 346.01it/s, loss=1.1]
Epoch [212/5000]: 100%|██████████| 10/10 [00:00<00:00, 296.91it/s, loss=1.08]
Epoch [213/5000]: 100%|██████████| 10/10 [00:00<00:00, 263.74it/s, loss=1.06]
Epoch [214/5000]: 100%|██████████| 10/10 [00:00<00:00, 246.10i

Saving model with loss 1.046...


Epoch [226/5000]: 100%|██████████| 10/10 [00:00<00:00, 313.53it/s, loss=1.03]


Saving model with loss 1.039...


Epoch [227/5000]: 100%|██████████| 10/10 [00:00<00:00, 280.34it/s, loss=1.27]


Saving model with loss 1.026...


Epoch [228/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.38it/s, loss=0.896]
Epoch [229/5000]: 100%|██████████| 10/10 [00:00<00:00, 324.14it/s, loss=1.56]
Epoch [230/5000]: 100%|██████████| 10/10 [00:00<00:00, 297.44it/s, loss=0.744]
Epoch [231/5000]: 100%|██████████| 10/10 [00:00<00:00, 286.67it/s, loss=1.31]
Epoch [232/5000]: 100%|██████████| 10/10 [00:00<00:00, 313.01it/s, loss=0.837]
Epoch [233/5000]: 100%|██████████| 10/10 [00:00<00:00, 337.58it/s, loss=1.21]
Epoch [234/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.66it/s, loss=1.11]
Epoch [235/5000]: 100%|██████████| 10/10 [00:00<00:00, 260.84it/s, loss=0.94]
Epoch [236/5000]: 100%|██████████| 10/10 [00:00<00:00, 278.71it/s, loss=1.02]
Epoch [237/5000]: 100%|██████████| 10/10 [00:00<00:00, 320.33it/s, loss=0.885]
Epoch [238/5000]: 100%|██████████| 10/10 [00:00<00:00, 325.96it/s, loss=1.18]
Epoch [239/5000]: 100%|██████████| 10/10 [00:00<00:00, 355.68it/s, loss=0.941]
Epoch [240/5000]: 100%|██████████| 10/10 [00:00<00:00, 268.

Saving model with loss 0.986...


Epoch [251/5000]: 100%|██████████| 10/10 [00:00<00:00, 303.19it/s, loss=1.08]
Epoch [252/5000]: 100%|██████████| 10/10 [00:00<00:00, 288.97it/s, loss=1.07]
Epoch [253/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.70it/s, loss=0.767]
Epoch [254/5000]: 100%|██████████| 10/10 [00:00<00:00, 315.64it/s, loss=0.971]
Epoch [255/5000]: 100%|██████████| 10/10 [00:00<00:00, 313.57it/s, loss=1.01]
Epoch [256/5000]: 100%|██████████| 10/10 [00:00<00:00, 346.05it/s, loss=0.816]
Epoch [257/5000]: 100%|██████████| 10/10 [00:00<00:00, 325.95it/s, loss=1.36]
Epoch [258/5000]: 100%|██████████| 10/10 [00:00<00:00, 304.03it/s, loss=1.18]
Epoch [259/5000]: 100%|██████████| 10/10 [00:00<00:00, 312.81it/s, loss=0.657]
Epoch [260/5000]: 100%|██████████| 10/10 [00:00<00:00, 306.54it/s, loss=0.974]
Epoch [261/5000]: 100%|██████████| 10/10 [00:00<00:00, 304.05it/s, loss=0.966]
Epoch [262/5000]: 100%|██████████| 10/10 [00:00<00:00, 333.28it/s, loss=0.915]
Epoch [263/5000]: 100%|██████████| 10/10 [00:00<00:00, 31

Saving model with loss 0.973...


Epoch [278/5000]: 100%|██████████| 10/10 [00:00<00:00, 285.43it/s, loss=0.919]
Epoch [279/5000]: 100%|██████████| 10/10 [00:00<00:00, 281.17it/s, loss=1.03]
Epoch [280/5000]: 100%|██████████| 10/10 [00:00<00:00, 191.74it/s, loss=0.945]
Epoch [281/5000]: 100%|██████████| 10/10 [00:00<00:00, 289.44it/s, loss=0.714]
Epoch [282/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.56it/s, loss=1.22]
Epoch [283/5000]: 100%|██████████| 10/10 [00:00<00:00, 333.37it/s, loss=1.18]
Epoch [284/5000]: 100%|██████████| 10/10 [00:00<00:00, 333.87it/s, loss=1.07]
Epoch [285/5000]: 100%|██████████| 10/10 [00:00<00:00, 345.87it/s, loss=1.06]
Epoch [286/5000]: 100%|██████████| 10/10 [00:00<00:00, 288.16it/s, loss=1.29]
Epoch [287/5000]: 100%|██████████| 10/10 [00:00<00:00, 324.02it/s, loss=0.989]
Epoch [288/5000]: 100%|██████████| 10/10 [00:00<00:00, 313.54it/s, loss=0.945]
Epoch [289/5000]: 100%|██████████| 10/10 [00:00<00:00, 350.30it/s, loss=0.856]
Epoch [290/5000]: 100%|██████████| 10/10 [00:00<00:00, 323

Saving model with loss 0.964...


Epoch [295/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.73it/s, loss=0.931]
Epoch [296/5000]: 100%|██████████| 10/10 [00:00<00:00, 277.81it/s, loss=1.45]
Epoch [297/5000]: 100%|██████████| 10/10 [00:00<00:00, 337.46it/s, loss=0.81]
Epoch [298/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.32it/s, loss=1.02]
Epoch [299/5000]: 100%|██████████| 10/10 [00:00<00:00, 286.16it/s, loss=1.15]
Epoch [300/5000]: 100%|██████████| 10/10 [00:00<00:00, 312.63it/s, loss=1.24]
Epoch [301/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.73it/s, loss=1.1]


Epoch [301/5000]: Train loss: 0.9954, Valid loss: 1.0490


Epoch [302/5000]: 100%|██████████| 10/10 [00:00<00:00, 315.72it/s, loss=0.819]
Epoch [303/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.51it/s, loss=1.04]
Epoch [304/5000]: 100%|██████████| 10/10 [00:00<00:00, 345.96it/s, loss=0.895]
Epoch [305/5000]: 100%|██████████| 10/10 [00:00<00:00, 336.77it/s, loss=1.11]
Epoch [306/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.71it/s, loss=0.996]
Epoch [307/5000]: 100%|██████████| 10/10 [00:00<00:00, 286.24it/s, loss=1.04]
Epoch [308/5000]: 100%|██████████| 10/10 [00:00<00:00, 327.01it/s, loss=1.12]
Epoch [309/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.46it/s, loss=0.764]
Epoch [310/5000]: 100%|██████████| 10/10 [00:00<00:00, 312.58it/s, loss=0.804]


Saving model with loss 0.956...


Epoch [311/5000]: 100%|██████████| 10/10 [00:00<00:00, 345.87it/s, loss=0.935]
Epoch [312/5000]: 100%|██████████| 10/10 [00:00<00:00, 346.12it/s, loss=1.14]
Epoch [313/5000]: 100%|██████████| 10/10 [00:00<00:00, 288.40it/s, loss=1.19]
Epoch [314/5000]: 100%|██████████| 10/10 [00:00<00:00, 284.01it/s, loss=0.951]
Epoch [315/5000]: 100%|██████████| 10/10 [00:00<00:00, 303.60it/s, loss=0.823]
Epoch [316/5000]: 100%|██████████| 10/10 [00:00<00:00, 281.87it/s, loss=0.83]
Epoch [317/5000]: 100%|██████████| 10/10 [00:00<00:00, 271.17it/s, loss=0.694]


Saving model with loss 0.938...


Epoch [318/5000]: 100%|██████████| 10/10 [00:00<00:00, 272.77it/s, loss=0.927]
Epoch [319/5000]: 100%|██████████| 10/10 [00:00<00:00, 278.66it/s, loss=0.948]
Epoch [320/5000]: 100%|██████████| 10/10 [00:00<00:00, 312.52it/s, loss=0.69]
Epoch [321/5000]: 100%|██████████| 10/10 [00:00<00:00, 313.52it/s, loss=0.928]
Epoch [322/5000]: 100%|██████████| 10/10 [00:00<00:00, 304.06it/s, loss=0.826]
Epoch [323/5000]: 100%|██████████| 10/10 [00:00<00:00, 337.43it/s, loss=0.823]
Epoch [324/5000]: 100%|██████████| 10/10 [00:00<00:00, 372.00it/s, loss=0.672]


Saving model with loss 0.927...


Epoch [325/5000]: 100%|██████████| 10/10 [00:00<00:00, 279.08it/s, loss=0.842]
Epoch [326/5000]: 100%|██████████| 10/10 [00:00<00:00, 308.23it/s, loss=1.02]
Epoch [327/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.66it/s, loss=1.16]
Epoch [328/5000]: 100%|██████████| 10/10 [00:00<00:00, 278.68it/s, loss=1.07]
Epoch [329/5000]: 100%|██████████| 10/10 [00:00<00:00, 308.60it/s, loss=0.868]
Epoch [330/5000]: 100%|██████████| 10/10 [00:00<00:00, 278.44it/s, loss=0.889]
Epoch [331/5000]: 100%|██████████| 10/10 [00:00<00:00, 288.57it/s, loss=1.02]
Epoch [332/5000]: 100%|██████████| 10/10 [00:00<00:00, 335.10it/s, loss=0.917]
Epoch [333/5000]: 100%|██████████| 10/10 [00:00<00:00, 286.58it/s, loss=1.32]
Epoch [334/5000]: 100%|██████████| 10/10 [00:00<00:00, 315.86it/s, loss=1.35]
Epoch [335/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.48it/s, loss=0.789]
Epoch [336/5000]: 100%|██████████| 10/10 [00:00<00:00, 295.14it/s, loss=0.884]
Epoch [337/5000]: 100%|██████████| 10/10 [00:00<00:00, 307

Saving model with loss 0.921...


Epoch [373/5000]: 100%|██████████| 10/10 [00:00<00:00, 303.95it/s, loss=0.805]
Epoch [374/5000]: 100%|██████████| 10/10 [00:00<00:00, 344.79it/s, loss=1.43]
Epoch [375/5000]: 100%|██████████| 10/10 [00:00<00:00, 326.73it/s, loss=0.993]
Epoch [376/5000]: 100%|██████████| 10/10 [00:00<00:00, 313.24it/s, loss=0.941]
Epoch [377/5000]: 100%|██████████| 10/10 [00:00<00:00, 333.55it/s, loss=0.862]
Epoch [378/5000]: 100%|██████████| 10/10 [00:00<00:00, 338.11it/s, loss=0.894]
Epoch [379/5000]: 100%|██████████| 10/10 [00:00<00:00, 313.47it/s, loss=0.843]
Epoch [380/5000]: 100%|██████████| 10/10 [00:00<00:00, 344.87it/s, loss=1.15]
Epoch [381/5000]: 100%|██████████| 10/10 [00:00<00:00, 333.52it/s, loss=0.833]
Epoch [382/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.83it/s, loss=1.06]
Epoch [383/5000]: 100%|██████████| 10/10 [00:00<00:00, 337.55it/s, loss=0.922]
Epoch [384/5000]: 100%|██████████| 10/10 [00:00<00:00, 358.37it/s, loss=1.32]


Saving model with loss 0.907...


Epoch [385/5000]: 100%|██████████| 10/10 [00:00<00:00, 295.06it/s, loss=0.977]
Epoch [386/5000]: 100%|██████████| 10/10 [00:00<00:00, 349.20it/s, loss=0.885]
Epoch [387/5000]: 100%|██████████| 10/10 [00:00<00:00, 346.13it/s, loss=1.11]
Epoch [388/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.35it/s, loss=0.92]
Epoch [389/5000]: 100%|██████████| 10/10 [00:00<00:00, 348.49it/s, loss=1.04]
Epoch [390/5000]: 100%|██████████| 10/10 [00:00<00:00, 358.27it/s, loss=1.07]
Epoch [391/5000]: 100%|██████████| 10/10 [00:00<00:00, 345.83it/s, loss=1.03]
Epoch [392/5000]: 100%|██████████| 10/10 [00:00<00:00, 315.83it/s, loss=1.06]
Epoch [393/5000]: 100%|██████████| 10/10 [00:00<00:00, 346.00it/s, loss=0.853]
Epoch [394/5000]: 100%|██████████| 10/10 [00:00<00:00, 346.12it/s, loss=1.09]
Epoch [395/5000]: 100%|██████████| 10/10 [00:00<00:00, 326.97it/s, loss=0.835]
Epoch [396/5000]: 100%|██████████| 10/10 [00:00<00:00, 346.01it/s, loss=0.946]
Epoch [397/5000]: 100%|██████████| 10/10 [00:00<00:00, 333.

Epoch [401/5000]: Train loss: 0.9383, Valid loss: 1.0006


Epoch [402/5000]: 100%|██████████| 10/10 [00:00<00:00, 257.63it/s, loss=0.936]


Saving model with loss 0.903...


Epoch [403/5000]: 100%|██████████| 10/10 [00:00<00:00, 281.16it/s, loss=0.993]
Epoch [404/5000]: 100%|██████████| 10/10 [00:00<00:00, 282.18it/s, loss=1.1]
Epoch [405/5000]: 100%|██████████| 10/10 [00:00<00:00, 280.59it/s, loss=0.853]
Epoch [406/5000]: 100%|██████████| 10/10 [00:00<00:00, 324.02it/s, loss=0.927]
Epoch [407/5000]: 100%|██████████| 10/10 [00:00<00:00, 345.90it/s, loss=0.745]
Epoch [408/5000]: 100%|██████████| 10/10 [00:00<00:00, 297.08it/s, loss=1]
Epoch [409/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.54it/s, loss=1.17]
Epoch [410/5000]: 100%|██████████| 10/10 [00:00<00:00, 346.54it/s, loss=0.864]
Epoch [411/5000]: 100%|██████████| 10/10 [00:00<00:00, 297.48it/s, loss=0.941]
Epoch [412/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.18it/s, loss=0.895]


Saving model with loss 0.890...


Epoch [413/5000]: 100%|██████████| 10/10 [00:00<00:00, 294.28it/s, loss=0.845]
Epoch [414/5000]: 100%|██████████| 10/10 [00:00<00:00, 286.60it/s, loss=1.09]
Epoch [415/5000]: 100%|██████████| 10/10 [00:00<00:00, 295.12it/s, loss=0.949]
Epoch [416/5000]: 100%|██████████| 10/10 [00:00<00:00, 306.19it/s, loss=1.09]
Epoch [417/5000]: 100%|██████████| 10/10 [00:00<00:00, 286.69it/s, loss=0.691]
Epoch [418/5000]: 100%|██████████| 10/10 [00:00<00:00, 322.62it/s, loss=0.718]
Epoch [419/5000]: 100%|██████████| 10/10 [00:00<00:00, 84.69it/s, loss=1.31]


Saving model with loss 0.885...


Epoch [420/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.72it/s, loss=0.994]
Epoch [421/5000]: 100%|██████████| 10/10 [00:00<00:00, 303.19it/s, loss=0.751]
Epoch [422/5000]: 100%|██████████| 10/10 [00:00<00:00, 349.76it/s, loss=1.22]
Epoch [423/5000]: 100%|██████████| 10/10 [00:00<00:00, 346.02it/s, loss=0.733]
Epoch [424/5000]: 100%|██████████| 10/10 [00:00<00:00, 312.77it/s, loss=0.772]
Epoch [425/5000]: 100%|██████████| 10/10 [00:00<00:00, 303.58it/s, loss=0.958]
Epoch [426/5000]: 100%|██████████| 10/10 [00:00<00:00, 303.69it/s, loss=0.915]
Epoch [427/5000]: 100%|██████████| 10/10 [00:00<00:00, 315.64it/s, loss=0.964]
Epoch [428/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.67it/s, loss=1.02]
Epoch [429/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.59it/s, loss=0.995]
Epoch [430/5000]: 100%|██████████| 10/10 [00:00<00:00, 338.24it/s, loss=0.989]
Epoch [431/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.42it/s, loss=1.3]
Epoch [432/5000]: 100%|██████████| 10/10 [00:00<00:00, 3

Saving model with loss 0.869...


Epoch [492/5000]: 100%|██████████| 10/10 [00:00<00:00, 295.12it/s, loss=1.22]
Epoch [493/5000]: 100%|██████████| 10/10 [00:00<00:00, 345.00it/s, loss=0.733]


Saving model with loss 0.864...


Epoch [494/5000]: 100%|██████████| 10/10 [00:00<00:00, 346.09it/s, loss=0.91]
Epoch [495/5000]: 100%|██████████| 10/10 [00:00<00:00, 286.80it/s, loss=0.968]
Epoch [496/5000]: 100%|██████████| 10/10 [00:00<00:00, 361.03it/s, loss=1.01]
Epoch [497/5000]: 100%|██████████| 10/10 [00:00<00:00, 345.93it/s, loss=0.792]
Epoch [498/5000]: 100%|██████████| 10/10 [00:00<00:00, 303.98it/s, loss=0.643]
Epoch [499/5000]: 100%|██████████| 10/10 [00:00<00:00, 325.93it/s, loss=0.806]
Epoch [500/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.32it/s, loss=0.876]
Epoch [501/5000]: 100%|██████████| 10/10 [00:00<00:00, 312.51it/s, loss=0.882]


Epoch [501/5000]: Train loss: 0.9368, Valid loss: 0.9285


Epoch [502/5000]: 100%|██████████| 10/10 [00:00<00:00, 349.81it/s, loss=0.893]
Epoch [503/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.62it/s, loss=0.846]
Epoch [504/5000]: 100%|██████████| 10/10 [00:00<00:00, 286.40it/s, loss=0.946]
Epoch [505/5000]: 100%|██████████| 10/10 [00:00<00:00, 80.15it/s, loss=0.845]
Epoch [506/5000]: 100%|██████████| 10/10 [00:00<00:00, 322.18it/s, loss=1.11]
Epoch [507/5000]: 100%|██████████| 10/10 [00:00<00:00, 312.50it/s, loss=1.13]
Epoch [508/5000]: 100%|██████████| 10/10 [00:00<00:00, 348.90it/s, loss=0.758]
Epoch [509/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.47it/s, loss=1.16]
Epoch [510/5000]: 100%|██████████| 10/10 [00:00<00:00, 306.81it/s, loss=1.16]
Epoch [511/5000]: 100%|██████████| 10/10 [00:00<00:00, 359.06it/s, loss=1.26]
Epoch [512/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.77it/s, loss=0.667]
Epoch [513/5000]: 100%|██████████| 10/10 [00:00<00:00, 296.72it/s, loss=0.922]
Epoch [514/5000]: 100%|██████████| 10/10 [00:00<00:00, 334

Saving model with loss 0.855...


Epoch [520/5000]: 100%|██████████| 10/10 [00:00<00:00, 346.27it/s, loss=1.27]
Epoch [521/5000]: 100%|██████████| 10/10 [00:00<00:00, 333.35it/s, loss=0.991]
Epoch [522/5000]: 100%|██████████| 10/10 [00:00<00:00, 306.67it/s, loss=1.25]
Epoch [523/5000]: 100%|██████████| 10/10 [00:00<00:00, 284.18it/s, loss=0.896]
Epoch [524/5000]: 100%|██████████| 10/10 [00:00<00:00, 316.15it/s, loss=0.968]
Epoch [525/5000]: 100%|██████████| 10/10 [00:00<00:00, 313.51it/s, loss=0.703]
Epoch [526/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.63it/s, loss=0.672]
Epoch [527/5000]: 100%|██████████| 10/10 [00:00<00:00, 326.11it/s, loss=0.99]
Epoch [528/5000]: 100%|██████████| 10/10 [00:00<00:00, 313.27it/s, loss=0.917]
Epoch [529/5000]: 100%|██████████| 10/10 [00:00<00:00, 312.59it/s, loss=1.39]
Epoch [530/5000]: 100%|██████████| 10/10 [00:00<00:00, 247.64it/s, loss=1.1]
Epoch [531/5000]: 100%|██████████| 10/10 [00:00<00:00, 278.46it/s, loss=1.16]
Epoch [532/5000]: 100%|██████████| 10/10 [00:00<00:00, 288.

Saving model with loss 0.854...


Epoch [539/5000]: 100%|██████████| 10/10 [00:00<00:00, 280.16it/s, loss=0.889]
Epoch [540/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.64it/s, loss=1.17]
Epoch [541/5000]: 100%|██████████| 10/10 [00:00<00:00, 278.54it/s, loss=0.827]
Epoch [542/5000]: 100%|██████████| 10/10 [00:00<00:00, 298.97it/s, loss=0.927]
Epoch [543/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.90it/s, loss=0.724]
Epoch [544/5000]: 100%|██████████| 10/10 [00:00<00:00, 294.02it/s, loss=0.596]
Epoch [545/5000]: 100%|██████████| 10/10 [00:00<00:00, 316.51it/s, loss=0.771]
Epoch [546/5000]: 100%|██████████| 10/10 [00:00<00:00, 278.83it/s, loss=0.84]
Epoch [547/5000]: 100%|██████████| 10/10 [00:00<00:00, 281.20it/s, loss=1.04]
Epoch [548/5000]: 100%|██████████| 10/10 [00:00<00:00, 322.90it/s, loss=1.04]
Epoch [549/5000]: 100%|██████████| 10/10 [00:00<00:00, 298.65it/s, loss=0.822]
Epoch [550/5000]: 100%|██████████| 10/10 [00:00<00:00, 289.45it/s, loss=0.826]
Epoch [551/5000]: 100%|██████████| 10/10 [00:00<00:00, 2

Saving model with loss 0.848...


Epoch [552/5000]: 100%|██████████| 10/10 [00:00<00:00, 348.70it/s, loss=0.946]
Epoch [553/5000]: 100%|██████████| 10/10 [00:00<00:00, 346.32it/s, loss=1.14]
Epoch [554/5000]: 100%|██████████| 10/10 [00:00<00:00, 313.53it/s, loss=1.05]
Epoch [555/5000]: 100%|██████████| 10/10 [00:00<00:00, 336.81it/s, loss=1.04]
Epoch [556/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.44it/s, loss=0.745]
Epoch [557/5000]: 100%|██████████| 10/10 [00:00<00:00, 303.85it/s, loss=1.1]
Epoch [558/5000]: 100%|██████████| 10/10 [00:00<00:00, 338.01it/s, loss=0.959]
Epoch [559/5000]: 100%|██████████| 10/10 [00:00<00:00, 346.23it/s, loss=0.587]
Epoch [560/5000]: 100%|██████████| 10/10 [00:00<00:00, 303.09it/s, loss=0.807]
Epoch [561/5000]: 100%|██████████| 10/10 [00:00<00:00, 337.41it/s, loss=0.814]
Epoch [562/5000]: 100%|██████████| 10/10 [00:00<00:00, 346.47it/s, loss=1.05]
Epoch [563/5000]: 100%|██████████| 10/10 [00:00<00:00, 286.64it/s, loss=0.813]
Epoch [564/5000]: 100%|██████████| 10/10 [00:00<00:00, 346

Saving model with loss 0.845...


Epoch [565/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.50it/s, loss=0.99]
Epoch [566/5000]: 100%|██████████| 10/10 [00:00<00:00, 307.97it/s, loss=0.931]
Epoch [567/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.42it/s, loss=0.999]
Epoch [568/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.85it/s, loss=0.915]
Epoch [569/5000]: 100%|██████████| 10/10 [00:00<00:00, 326.09it/s, loss=0.704]
Epoch [570/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.68it/s, loss=0.959]
Epoch [571/5000]: 100%|██████████| 10/10 [00:00<00:00, 345.92it/s, loss=1.24]
Epoch [572/5000]: 100%|██████████| 10/10 [00:00<00:00, 306.22it/s, loss=0.759]
Epoch [573/5000]: 100%|██████████| 10/10 [00:00<00:00, 345.92it/s, loss=1.08]
Epoch [574/5000]: 100%|██████████| 10/10 [00:00<00:00, 346.07it/s, loss=0.812]
Epoch [575/5000]: 100%|██████████| 10/10 [00:00<00:00, 307.05it/s, loss=1.01]
Epoch [576/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.46it/s, loss=0.788]
Epoch [577/5000]: 100%|██████████| 10/10 [00:00<00:00, 3

Saving model with loss 0.827...


Epoch [588/5000]: 100%|██████████| 10/10 [00:00<00:00, 344.95it/s, loss=0.755]
Epoch [589/5000]: 100%|██████████| 10/10 [00:00<00:00, 316.71it/s, loss=0.998]
Epoch [590/5000]: 100%|██████████| 10/10 [00:00<00:00, 278.67it/s, loss=0.716]
Epoch [591/5000]: 100%|██████████| 10/10 [00:00<00:00, 83.00it/s, loss=0.949]
Epoch [592/5000]: 100%|██████████| 10/10 [00:00<00:00, 297.86it/s, loss=0.782]
Epoch [593/5000]: 100%|██████████| 10/10 [00:00<00:00, 295.08it/s, loss=1.13]
Epoch [594/5000]: 100%|██████████| 10/10 [00:00<00:00, 337.90it/s, loss=0.631]
Epoch [595/5000]: 100%|██████████| 10/10 [00:00<00:00, 257.25it/s, loss=0.934]
Epoch [596/5000]: 100%|██████████| 10/10 [00:00<00:00, 255.95it/s, loss=0.848]
Epoch [597/5000]: 100%|██████████| 10/10 [00:00<00:00, 297.62it/s, loss=0.927]
Epoch [598/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.73it/s, loss=0.784]
Epoch [599/5000]: 100%|██████████| 10/10 [00:00<00:00, 296.71it/s, loss=0.843]
Epoch [600/5000]: 100%|██████████| 10/10 [00:00<00:00,

Epoch [601/5000]: Train loss: 0.9116, Valid loss: 0.8484


Epoch [602/5000]: 100%|██████████| 10/10 [00:00<00:00, 279.93it/s, loss=0.792]
Epoch [603/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.70it/s, loss=0.866]
Epoch [604/5000]: 100%|██████████| 10/10 [00:00<00:00, 358.31it/s, loss=1.15]
Epoch [605/5000]: 100%|██████████| 10/10 [00:00<00:00, 338.09it/s, loss=0.844]
Epoch [606/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.82it/s, loss=1.04]
Epoch [607/5000]: 100%|██████████| 10/10 [00:00<00:00, 333.81it/s, loss=0.648]
Epoch [608/5000]: 100%|██████████| 10/10 [00:00<00:00, 326.66it/s, loss=1.22]
Epoch [609/5000]: 100%|██████████| 10/10 [00:00<00:00, 313.50it/s, loss=0.963]
Epoch [610/5000]: 100%|██████████| 10/10 [00:00<00:00, 280.13it/s, loss=0.992]
Epoch [611/5000]: 100%|██████████| 10/10 [00:00<00:00, 313.58it/s, loss=0.971]
Epoch [612/5000]: 100%|██████████| 10/10 [00:00<00:00, 313.63it/s, loss=0.81]
Epoch [613/5000]: 100%|██████████| 10/10 [00:00<00:00, 297.11it/s, loss=1.06]
Epoch [614/5000]: 100%|██████████| 10/10 [00:00<00:00, 34

Saving model with loss 0.814...


Epoch [638/5000]: 100%|██████████| 10/10 [00:00<00:00, 278.71it/s, loss=1.17]
Epoch [639/5000]: 100%|██████████| 10/10 [00:00<00:00, 305.67it/s, loss=0.731]
Epoch [640/5000]: 100%|██████████| 10/10 [00:00<00:00, 304.36it/s, loss=0.818]
Epoch [641/5000]: 100%|██████████| 10/10 [00:00<00:00, 303.19it/s, loss=0.901]
Epoch [642/5000]: 100%|██████████| 10/10 [00:00<00:00, 327.02it/s, loss=0.869]
Epoch [643/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.67it/s, loss=0.865]
Epoch [644/5000]: 100%|██████████| 10/10 [00:00<00:00, 280.44it/s, loss=1.21]
Epoch [645/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.35it/s, loss=1.13]
Epoch [646/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.41it/s, loss=0.884]
Epoch [647/5000]: 100%|██████████| 10/10 [00:00<00:00, 279.64it/s, loss=1.06]
Epoch [648/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.52it/s, loss=0.578]
Epoch [649/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.67it/s, loss=1.17]
Epoch [650/5000]: 100%|██████████| 10/10 [00:00<00:00, 27

Epoch [701/5000]: Train loss: 0.9017, Valid loss: 0.8904


Epoch [702/5000]: 100%|██████████| 10/10 [00:00<00:00, 313.51it/s, loss=0.936]
Epoch [703/5000]: 100%|██████████| 10/10 [00:00<00:00, 304.05it/s, loss=0.755]
Epoch [704/5000]: 100%|██████████| 10/10 [00:00<00:00, 280.48it/s, loss=0.68]
Epoch [705/5000]: 100%|██████████| 10/10 [00:00<00:00, 313.49it/s, loss=1.05]
Epoch [706/5000]: 100%|██████████| 10/10 [00:00<00:00, 294.31it/s, loss=0.665]
Epoch [707/5000]: 100%|██████████| 10/10 [00:00<00:00, 297.86it/s, loss=0.732]
Epoch [708/5000]: 100%|██████████| 10/10 [00:00<00:00, 313.58it/s, loss=0.609]
Epoch [709/5000]: 100%|██████████| 10/10 [00:00<00:00, 316.48it/s, loss=0.87]
Epoch [710/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.52it/s, loss=0.999]
Epoch [711/5000]: 100%|██████████| 10/10 [00:00<00:00, 286.73it/s, loss=0.918]
Epoch [712/5000]: 100%|██████████| 10/10 [00:00<00:00, 297.91it/s, loss=1.06]
Epoch [713/5000]: 100%|██████████| 10/10 [00:00<00:00, 303.97it/s, loss=0.893]
Epoch [714/5000]: 100%|██████████| 10/10 [00:00<00:00, 2

Epoch [801/5000]: Train loss: 0.8923, Valid loss: 0.8612


Epoch [802/5000]: 100%|██████████| 10/10 [00:00<00:00, 297.06it/s, loss=0.714]
Epoch [803/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.74it/s, loss=0.805]
Epoch [804/5000]: 100%|██████████| 10/10 [00:00<00:00, 346.23it/s, loss=1.23]
Epoch [805/5000]: 100%|██████████| 10/10 [00:00<00:00, 297.03it/s, loss=1.07]
Epoch [806/5000]: 100%|██████████| 10/10 [00:00<00:00, 345.90it/s, loss=1.22]
Epoch [807/5000]: 100%|██████████| 10/10 [00:00<00:00, 371.41it/s, loss=0.834]
Epoch [808/5000]: 100%|██████████| 10/10 [00:00<00:00, 306.92it/s, loss=1.16]
Epoch [809/5000]: 100%|██████████| 10/10 [00:00<00:00, 358.37it/s, loss=0.815]
Epoch [810/5000]: 100%|██████████| 10/10 [00:00<00:00, 302.97it/s, loss=0.894]
Epoch [811/5000]: 100%|██████████| 10/10 [00:00<00:00, 306.63it/s, loss=1.17]
Epoch [812/5000]: 100%|██████████| 10/10 [00:00<00:00, 250.01it/s, loss=0.763]
Epoch [813/5000]: 100%|██████████| 10/10 [00:00<00:00, 252.19it/s, loss=0.86]
Epoch [814/5000]: 100%|██████████| 10/10 [00:00<00:00, 313

Saving model with loss 0.811...


Epoch [833/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.40it/s, loss=1.06]
Epoch [834/5000]: 100%|██████████| 10/10 [00:00<00:00, 348.78it/s, loss=1.08]
Epoch [835/5000]: 100%|██████████| 10/10 [00:00<00:00, 295.05it/s, loss=0.958]
Epoch [836/5000]: 100%|██████████| 10/10 [00:00<00:00, 358.33it/s, loss=1.04]
Epoch [837/5000]: 100%|██████████| 10/10 [00:00<00:00, 281.08it/s, loss=0.813]
Epoch [838/5000]: 100%|██████████| 10/10 [00:00<00:00, 278.69it/s, loss=0.606]
Epoch [839/5000]: 100%|██████████| 10/10 [00:00<00:00, 272.83it/s, loss=0.904]
Epoch [840/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.30it/s, loss=1.01]
Epoch [841/5000]: 100%|██████████| 10/10 [00:00<00:00, 346.28it/s, loss=0.841]
Epoch [842/5000]: 100%|██████████| 10/10 [00:00<00:00, 336.88it/s, loss=1.19]
Epoch [843/5000]: 100%|██████████| 10/10 [00:00<00:00, 333.47it/s, loss=0.92]
Epoch [844/5000]: 100%|██████████| 10/10 [00:00<00:00, 315.61it/s, loss=0.952]
Epoch [845/5000]: 100%|██████████| 10/10 [00:00<00:00, 266

Saving model with loss 0.811...


Epoch [899/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.67it/s, loss=0.945]
Epoch [900/5000]: 100%|██████████| 10/10 [00:00<00:00, 313.59it/s, loss=0.957]
Epoch [901/5000]: 100%|██████████| 10/10 [00:00<00:00, 348.64it/s, loss=0.967]


Epoch [901/5000]: Train loss: 0.8964, Valid loss: 0.8583


Epoch [902/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.43it/s, loss=1.14]
Epoch [903/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.72it/s, loss=0.821]
Epoch [904/5000]: 100%|██████████| 10/10 [00:00<00:00, 361.27it/s, loss=0.736]
Epoch [905/5000]: 100%|██████████| 10/10 [00:00<00:00, 345.96it/s, loss=1.06]
Epoch [906/5000]: 100%|██████████| 10/10 [00:00<00:00, 304.06it/s, loss=1.3]
Epoch [907/5000]: 100%|██████████| 10/10 [00:00<00:00, 315.79it/s, loss=0.892]
Epoch [908/5000]: 100%|██████████| 10/10 [00:00<00:00, 313.51it/s, loss=1.08]
Epoch [909/5000]: 100%|██████████| 10/10 [00:00<00:00, 277.99it/s, loss=0.852]
Epoch [910/5000]: 100%|██████████| 10/10 [00:00<00:00, 295.03it/s, loss=0.822]
Epoch [911/5000]: 100%|██████████| 10/10 [00:00<00:00, 286.65it/s, loss=1.1]
Epoch [912/5000]: 100%|██████████| 10/10 [00:00<00:00, 280.15it/s, loss=0.86]
Epoch [913/5000]: 100%|██████████| 10/10 [00:00<00:00, 304.04it/s, loss=1.34]
Epoch [914/5000]: 100%|██████████| 10/10 [00:00<00:00, 345.04

Saving model with loss 0.784...


Epoch [959/5000]: 100%|██████████| 10/10 [00:00<00:00, 304.01it/s, loss=0.725]
Epoch [960/5000]: 100%|██████████| 10/10 [00:00<00:00, 288.42it/s, loss=0.73]
Epoch [961/5000]: 100%|██████████| 10/10 [00:00<00:00, 270.96it/s, loss=0.863]
Epoch [962/5000]: 100%|██████████| 10/10 [00:00<00:00, 294.51it/s, loss=0.958]
Epoch [963/5000]: 100%|██████████| 10/10 [00:00<00:00, 316.48it/s, loss=0.868]
Epoch [964/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.50it/s, loss=0.794]
Epoch [965/5000]: 100%|██████████| 10/10 [00:00<00:00, 280.41it/s, loss=0.905]
Epoch [966/5000]: 100%|██████████| 10/10 [00:00<00:00, 295.18it/s, loss=0.801]
Epoch [967/5000]: 100%|██████████| 10/10 [00:00<00:00, 291.60it/s, loss=0.742]
Epoch [968/5000]: 100%|██████████| 10/10 [00:00<00:00, 252.40it/s, loss=0.956]
Epoch [969/5000]: 100%|██████████| 10/10 [00:00<00:00, 294.76it/s, loss=0.631]
Epoch [970/5000]: 100%|██████████| 10/10 [00:00<00:00, 347.13it/s, loss=0.981]
Epoch [971/5000]: 100%|██████████| 10/10 [00:00<00:00

Saving model with loss 0.775...


Epoch [993/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.60it/s, loss=0.728]
Epoch [994/5000]: 100%|██████████| 10/10 [00:00<00:00, 271.23it/s, loss=1.52]
Epoch [995/5000]: 100%|██████████| 10/10 [00:00<00:00, 281.09it/s, loss=1.06]
Epoch [996/5000]: 100%|██████████| 10/10 [00:00<00:00, 278.70it/s, loss=1.15]
Epoch [997/5000]: 100%|██████████| 10/10 [00:00<00:00, 280.42it/s, loss=1.12]
Epoch [998/5000]: 100%|██████████| 10/10 [00:00<00:00, 278.60it/s, loss=0.991]
Epoch [999/5000]: 100%|██████████| 10/10 [00:00<00:00, 285.98it/s, loss=0.95]
Epoch [1000/5000]: 100%|██████████| 10/10 [00:00<00:00, 281.19it/s, loss=0.735]
Epoch [1001/5000]: 100%|██████████| 10/10 [00:00<00:00, 295.11it/s, loss=1.15]


Epoch [1001/5000]: Train loss: 0.9413, Valid loss: 0.8132


Epoch [1002/5000]: 100%|██████████| 10/10 [00:00<00:00, 348.73it/s, loss=0.729]
Epoch [1003/5000]: 100%|██████████| 10/10 [00:00<00:00, 304.01it/s, loss=1.09]
Epoch [1004/5000]: 100%|██████████| 10/10 [00:00<00:00, 313.56it/s, loss=0.788]
Epoch [1005/5000]: 100%|██████████| 10/10 [00:00<00:00, 326.89it/s, loss=0.819]
Epoch [1006/5000]: 100%|██████████| 10/10 [00:00<00:00, 238.90it/s, loss=0.539]
Epoch [1007/5000]: 100%|██████████| 10/10 [00:00<00:00, 305.93it/s, loss=0.872]
Epoch [1008/5000]: 100%|██████████| 10/10 [00:00<00:00, 313.57it/s, loss=1.03]
Epoch [1009/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.48it/s, loss=0.64]
Epoch [1010/5000]: 100%|██████████| 10/10 [00:00<00:00, 315.38it/s, loss=1.07]
Epoch [1011/5000]: 100%|██████████| 10/10 [00:00<00:00, 264.03it/s, loss=0.933]
Epoch [1012/5000]: 100%|██████████| 10/10 [00:00<00:00, 322.74it/s, loss=1.04]
Epoch [1013/5000]: 100%|██████████| 10/10 [00:00<00:00, 306.94it/s, loss=0.819]
Epoch [1014/5000]: 100%|██████████| 10/10 [00

Epoch [1101/5000]: Train loss: 0.8905, Valid loss: 0.8122


Epoch [1102/5000]: 100%|██████████| 10/10 [00:00<00:00, 348.42it/s, loss=0.729]
Epoch [1103/5000]: 100%|██████████| 10/10 [00:00<00:00, 304.07it/s, loss=0.871]
Epoch [1104/5000]: 100%|██████████| 10/10 [00:00<00:00, 304.33it/s, loss=1.3]
Epoch [1105/5000]: 100%|██████████| 10/10 [00:00<00:00, 315.75it/s, loss=0.663]
Epoch [1106/5000]: 100%|██████████| 10/10 [00:00<00:00, 286.64it/s, loss=0.89]
Epoch [1107/5000]: 100%|██████████| 10/10 [00:00<00:00, 344.83it/s, loss=0.792]
Epoch [1108/5000]: 100%|██████████| 10/10 [00:00<00:00, 337.69it/s, loss=0.904]
Epoch [1109/5000]: 100%|██████████| 10/10 [00:00<00:00, 286.65it/s, loss=0.837]
Epoch [1110/5000]: 100%|██████████| 10/10 [00:00<00:00, 333.39it/s, loss=0.724]
Epoch [1111/5000]: 100%|██████████| 10/10 [00:00<00:00, 313.53it/s, loss=1.07]
Epoch [1112/5000]: 100%|██████████| 10/10 [00:00<00:00, 271.06it/s, loss=0.992]
Epoch [1113/5000]: 100%|██████████| 10/10 [00:00<00:00, 67.53it/s, loss=0.705]
Epoch [1114/5000]: 100%|██████████| 10/10 [00

Epoch [1201/5000]: Train loss: 0.9020, Valid loss: 0.8732


Epoch [1202/5000]: 100%|██████████| 10/10 [00:00<00:00, 286.60it/s, loss=0.836]
Epoch [1203/5000]: 100%|██████████| 10/10 [00:00<00:00, 306.54it/s, loss=1.04]
Epoch [1204/5000]: 100%|██████████| 10/10 [00:00<00:00, 313.42it/s, loss=0.929]
Epoch [1205/5000]: 100%|██████████| 10/10 [00:00<00:00, 295.05it/s, loss=0.776]
Epoch [1206/5000]: 100%|██████████| 10/10 [00:00<00:00, 337.66it/s, loss=0.959]
Epoch [1207/5000]: 100%|██████████| 10/10 [00:00<00:00, 345.81it/s, loss=0.906]
Epoch [1208/5000]: 100%|██████████| 10/10 [00:00<00:00, 312.51it/s, loss=0.794]
Epoch [1209/5000]: 100%|██████████| 10/10 [00:00<00:00, 297.87it/s, loss=0.873]
Epoch [1210/5000]: 100%|██████████| 10/10 [00:00<00:00, 304.18it/s, loss=0.633]
Epoch [1211/5000]: 100%|██████████| 10/10 [00:00<00:00, 265.50it/s, loss=0.818]
Epoch [1212/5000]: 100%|██████████| 10/10 [00:00<00:00, 303.97it/s, loss=0.826]
Epoch [1213/5000]: 100%|██████████| 10/10 [00:00<00:00, 295.08it/s, loss=0.777]
Epoch [1214/5000]: 100%|██████████| 10/10

Epoch [1301/5000]: Train loss: 0.9019, Valid loss: 0.8789


Epoch [1302/5000]: 100%|██████████| 10/10 [00:00<00:00, 326.17it/s, loss=0.994]
Epoch [1303/5000]: 100%|██████████| 10/10 [00:00<00:00, 313.63it/s, loss=0.842]
Epoch [1304/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.44it/s, loss=1.13]
Epoch [1305/5000]: 100%|██████████| 10/10 [00:00<00:00, 337.48it/s, loss=0.826]
Epoch [1306/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.35it/s, loss=1.06]
Epoch [1307/5000]: 100%|██████████| 10/10 [00:00<00:00, 294.14it/s, loss=0.746]
Epoch [1308/5000]: 100%|██████████| 10/10 [00:00<00:00, 349.50it/s, loss=0.863]
Epoch [1309/5000]: 100%|██████████| 10/10 [00:00<00:00, 334.52it/s, loss=1.16]
Epoch [1310/5000]: 100%|██████████| 10/10 [00:00<00:00, 312.65it/s, loss=0.64]
Epoch [1311/5000]: 100%|██████████| 10/10 [00:00<00:00, 333.39it/s, loss=0.725]
Epoch [1312/5000]: 100%|██████████| 10/10 [00:00<00:00, 345.84it/s, loss=0.833]
Epoch [1313/5000]: 100%|██████████| 10/10 [00:00<00:00, 333.23it/s, loss=1.04]
Epoch [1314/5000]: 100%|██████████| 10/10 [00

Epoch [1401/5000]: Train loss: 0.8830, Valid loss: 0.8552


Epoch [1402/5000]: 100%|██████████| 10/10 [00:00<00:00, 315.56it/s, loss=1.11]
Epoch [1403/5000]: 100%|██████████| 10/10 [00:00<00:00, 346.02it/s, loss=1.04]
Epoch [1404/5000]: 100%|██████████| 10/10 [00:00<00:00, 312.76it/s, loss=0.684]
Epoch [1405/5000]: 100%|██████████| 10/10 [00:00<00:00, 297.86it/s, loss=0.603]
Epoch [1406/5000]: 100%|██████████| 10/10 [00:00<00:00, 358.29it/s, loss=0.789]
Epoch [1407/5000]: 100%|██████████| 10/10 [00:00<00:00, 312.55it/s, loss=0.71]
Epoch [1408/5000]: 100%|██████████| 10/10 [00:00<00:00, 280.92it/s, loss=1.28]
Epoch [1409/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.79it/s, loss=0.715]
Epoch [1410/5000]: 100%|██████████| 10/10 [00:00<00:00, 325.84it/s, loss=0.722]
Epoch [1411/5000]: 100%|██████████| 10/10 [00:00<00:00, 304.05it/s, loss=0.57]
Epoch [1412/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.62it/s, loss=0.824]
Epoch [1413/5000]: 100%|██████████| 10/10 [00:00<00:00, 348.81it/s, loss=0.785]
Epoch [1414/5000]: 100%|██████████| 10/10 [00

Epoch [1501/5000]: Train loss: 0.8867, Valid loss: 0.8703


Epoch [1502/5000]: 100%|██████████| 10/10 [00:00<00:00, 286.62it/s, loss=1.03]
Epoch [1503/5000]: 100%|██████████| 10/10 [00:00<00:00, 265.95it/s, loss=1.21]
Epoch [1504/5000]: 100%|██████████| 10/10 [00:00<00:00, 286.73it/s, loss=0.999]
Epoch [1505/5000]: 100%|██████████| 10/10 [00:00<00:00, 273.13it/s, loss=0.973]
Epoch [1506/5000]: 100%|██████████| 10/10 [00:00<00:00, 345.99it/s, loss=0.994]
Epoch [1507/5000]: 100%|██████████| 10/10 [00:00<00:00, 358.43it/s, loss=0.774]
Epoch [1508/5000]: 100%|██████████| 10/10 [00:00<00:00, 315.39it/s, loss=0.931]
Epoch [1509/5000]: 100%|██████████| 10/10 [00:00<00:00, 346.04it/s, loss=1.25]
Epoch [1510/5000]: 100%|██████████| 10/10 [00:00<00:00, 358.47it/s, loss=0.893]
Epoch [1511/5000]: 100%|██████████| 10/10 [00:00<00:00, 297.08it/s, loss=0.92]
Epoch [1512/5000]: 100%|██████████| 10/10 [00:00<00:00, 323.79it/s, loss=0.949]
Epoch [1513/5000]: 100%|██████████| 10/10 [00:00<00:00, 345.97it/s, loss=1]
Epoch [1514/5000]: 100%|██████████| 10/10 [00:00


Model is not improving, so we halt the training session.
Model current loss 0.775
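
The "Saving model" and "not improving" messages in the log above are typical of an early-stopping check. The training loop itself is not shown in this part of the notebook, so the following is only a minimal sketch of the usual pattern, assuming a patience counter that resets whenever the validation loss improves. The names `EarlyStopper`, `patience`, and `step` are hypothetical, and the loss sequence is illustrative rather than a replay of the run.

```python
class EarlyStopper:
    """Track the best validation loss and signal when to stop (hypothetical helper)."""

    def __init__(self, patience):
        self.patience = patience        # epochs allowed without improvement
        self.best_loss = float("inf")   # lowest validation loss seen so far
        self.bad_epochs = 0             # consecutive epochs without improvement

    def step(self, valid_loss):
        """Return (improved, should_stop) for this epoch's validation loss."""
        if valid_loss < self.best_loss:
            self.best_loss = valid_loss
            self.bad_epochs = 0
            return True, False
        self.bad_epochs += 1
        return False, self.bad_epochs >= self.patience


# Illustrative validation losses; patience=3 is an assumed value.
stopper = EarlyStopper(patience=3)
history = []
for epoch, loss in enumerate([0.811, 0.784, 0.775, 0.813, 0.812, 0.873], start=1):
    improved, stop = stopper.step(loss)
    history.append((epoch, improved, stop))
    if improved:
        # A real loop would call torch.save(model.state_dict(), ...) here.
        print(f"Saving model with loss {loss:.3f}...")
    if stop:
        print("Model is not improving, so we halt the training session.")
        break
```

In this sketch a checkpoint is written only when the loss improves, so the file on disk always holds the best model seen so far, not the last one trained.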





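### Validating the saved model
The model was checkpointed when the loss was about 0.77. Reload that checkpoint and evaluate it on the validation set again to check whether the loss is still 0.77.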

In [24]:
same_seed(config['seed'])
train_data, test_data = pd.read_csv(File_Path_Train).values, pd.read_csv(File_Path_Test).values
train_data, valid_data = train_valid_split(train_data, config['valid_ratio'], config['seed'])

# Select features
x_train, x_valid, x_test, y_train, y_valid = select_feat(train_data, valid_data, test_data, config['select_all'], config['select_features'])

valid_loader = DataLoader(valid_dataset, batch_size=config['batch_size'], shuffle=True, pin_memory=True)

criterion = nn.MSELoss(reduction='mean')

model = My_Model(input_dim=x_train.shape[1]).to(device)
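model.load_state_dict(torch.load(config['save_path']))
model.eval()  # Evaluation mode: turns off dropout and uses batch-norm running statistics, if the model has any.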
loss_record = []
for x, y in valid_loader:
    x, y = x.to(device), y.to(device)
    with torch.no_grad():
        pred = model(x)
        loss = criterion(pred, y)

    loss_record.append(loss.item())

mean_valid_loss = sum(loss_record)/len(loss_record)
print(mean_valid_loss)

0.8618664741516113
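
The loop above averages the per-batch mean losses. When the last batch is smaller than the others, that average is not the same as the per-sample mean, and with `shuffle=True` it also depends on which samples land in the short batch. The sketch below shows the arithmetic with made-up per-sample losses; the helper names are hypothetical. This is one possible contributor to the gap between 0.775 and the value printed above, not an established cause.

```python
def mean_of_batch_means(losses, batch_size):
    """Average the per-batch mean losses, mirroring the loop above."""
    batches = [losses[i:i + batch_size] for i in range(0, len(losses), batch_size)]
    batch_means = [sum(b) / len(b) for b in batches]
    return sum(batch_means) / len(batch_means)


def per_sample_mean(losses):
    """Average over individual samples, independent of batching and order."""
    return sum(losses) / len(losses)


# Made-up per-sample squared errors: five samples with batch size 2,
# so the last batch holds a single sample.
losses = [1.0, 1.0, 1.0, 1.0, 6.0]

sample_mean = per_sample_mean(losses)               # (1+1+1+1+6) / 5
in_order = mean_of_batch_means(losses, 2)           # batch means 1, 1, 6
reordered = mean_of_batch_means(losses[::-1], 2)    # batch means 3.5, 1, 1

print(sample_mean, in_order, reordered)
```

Weighting each batch loss by its batch size (or evaluating with `shuffle=False`) makes the reported number independent of sample order.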


# Plot learning curves with `tensorboard` (optional)

`tensorboard` is a tool that allows you to visualize your training progress.

If this block does not display your learning curve, wait a few minutes and re-run it. The logging information can take some time to load.

In [None]:
%reload_ext tensorboard
%tensorboard --logdir=./runs/

# Testing
The predictions of your model on the testing set will be stored in `pred.csv`.

In [None]:
def save_pred(preds, file):
    ''' Save predictions to specified file '''
    # newline='' prevents the csv module from writing blank rows between records on Windows.
    with open(file, 'w', newline='') as fp:
        writer = csv.writer(fp)
        writer.writerow(['id', 'tested_positive'])
        for i, p in enumerate(preds):
            writer.writerow([i, p])

model = My_Model(input_dim=x_train.shape[1]).to(device)
model.load_state_dict(torch.load(config['save_path']))
preds = predict(test_loader, model, device)
save_pred(preds, 'pred.csv')
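
Before submitting, it can help to read the file back and confirm that it has the expected header and one numeric prediction per row. The following is a minimal sketch using only the standard library. `check_pred_file` is a hypothetical helper, and the demo writes a small stand-in file with invented values instead of touching the real `pred.csv`.

```python
import csv
import os
import tempfile


def check_pred_file(path):
    """Validate header and rows of a prediction CSV; return the number of rows."""
    with open(path, newline='') as fp:
        reader = csv.reader(fp)
        header = next(reader)
        assert header == ['id', 'tested_positive'], header
        rows = [row for row in reader if row]  # ignore blank lines, if any
    for expected_id, (row_id, value) in enumerate(rows):
        assert int(row_id) == expected_id      # ids should run 0, 1, 2, ...
        float(value)                           # raises ValueError if not numeric
    return len(rows)


# Demo on a stand-in file written in the same format as save_pred.
demo_path = os.path.join(tempfile.mkdtemp(), 'pred_demo.csv')
with open(demo_path, 'w', newline='') as fp:
    writer = csv.writer(fp)
    writer.writerow(['id', 'tested_positive'])
    for i, p in enumerate([12.5, 8.25, 20.0]):  # invented predictions
        writer.writerow([i, p])

n_rows = check_pred_file(demo_path)
print(n_rows)
```

On the real output, `check_pred_file('pred.csv')` should return the number of rows in the test set.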

# Download

On Google Colab, uncomment and run this block to download `pred.csv` automatically.

In [None]:
# from google.colab import files
# files.download('pred.csv')

# Reference
This notebook uses code written by Heng-Jui Chang @ NTUEE (https://github.com/ga642381/ML2021-Spring/blob/main/HW01/HW01.ipynb)