In [None]:
### 再优化一次模型结构

strong : Different model architectures and optimizers

### 开始训练

In [27]:
# Numerical Operations
import math
import numpy as np

# Reading/Writing Data
import pandas as pd
import os
import csv

# For Progress Bar
from tqdm import tqdm

# Pytorch
import torch 
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

# For plotting learning curve
from torch.utils.tensorboard import SummaryWriter


def same_seed(seed): 
    '''Fixes random number generator seeds for reproducibility.'''
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def train_valid_split(data_set, valid_ratio, seed):
    '''Split provided training data into training set and validation set'''
    valid_set_size = int(valid_ratio * len(data_set)) 
    train_set_size = len(data_set) - valid_set_size
    train_set, valid_set = random_split(data_set, [train_set_size, valid_set_size], generator=torch.Generator().manual_seed(seed))
    return np.array(train_set), np.array(valid_set)

def predict(test_loader, model, device):
    model.eval() # Set your model to evaluation mode.
    preds = []
    for x in tqdm(test_loader):
        x = x.to(device)                        
        with torch.no_grad():                   
            pred = model(x)                     
            preds.append(pred.detach().cpu())   
    preds = torch.cat(preds, dim=0).numpy()  
    return preds

class COVID19Dataset(Dataset):
    '''
    x: Features.
    y: Targets, if none, do prediction.
    '''
    def __init__(self, x, y=None):
        if y is None:
            self.y = y
        else:
            self.y = torch.FloatTensor(y)
        self.x = torch.FloatTensor(x)

    def __getitem__(self, idx):
        if self.y is None:
            return self.x[idx]
        else:
            return self.x[idx], self.y[idx]

    def __len__(self):
        return len(self.x)

class My_Model(nn.Module):
    def __init__(self, input_dim):
        super(My_Model, self).__init__()
        # TODO: modify model's structure, be aware of dimensions. 
        self.layers = nn.Sequential(
            nn.Linear(input_dim, 80),
            nn.ReLU(),
            nn.Linear(80, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, 8),
            nn.ReLU(),
            nn.Linear(8, 4),
            nn.ReLU(),
            nn.Linear(4, 1)
        )

    def forward(self, x):
        x = self.layers(x)
        x = x.squeeze(1) # (B, 1) -> (B)
        return x

def select_feat(train_data, valid_data, test_data, select_all=True):
    '''Selects useful features to perform regression'''
    y_train, y_valid = train_data[:,-1], valid_data[:,-1]
    raw_x_train, raw_x_valid, raw_x_test = train_data[:,:-1], valid_data[:,:-1], test_data

    if select_all:
        feat_idx = list(range(raw_x_train.shape[1]))
    else:
        feat_idx = list(range(38, 117, 1))
        
    return raw_x_train[:,feat_idx], raw_x_valid[:,feat_idx], raw_x_test[:,feat_idx], y_train, y_valid

def trainer(train_loader, valid_loader, model, config, device):

    criterion = nn.MSELoss(reduction='mean') # Define your loss function, do not modify this.

    # Define your optimization algorithm. 
    # TODO: Please check https://pytorch.org/docs/stable/optim.html to get more available algorithms.
    # TODO: L2 regularization (optimizer(weight decay...) or implement by your self).
    # optimizer = torch.optim.SGD(model.parameters(), lr=config['learning_rate'], momentum=0.9) 
    # optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'], weight_decay=1e-5, amsgrad=True)

    # optimizer = torch.optim.AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=1e-1)
    # optimizer = torch.optim.ASGD(model.parameters(), lr=config['learning_rate'], weight_decay=1e-2)   # nothing surprising
    # optimizer = torch.optim.Adagrad(model.parameters(), lr=config['learning_rate'])
    optimizer = torch.optim.RMSprop(model.parameters(), lr=config['learning_rate'])


    writer = SummaryWriter() # Writer of tensoboard.

    if not os.path.isdir('./models'):
        os.mkdir('./models') # Create directory of saving models.

    n_epochs, best_loss, step, early_stop_count = config['n_epochs'], math.inf, 0, 0

    for epoch in range(n_epochs):
        model.train() # Set your model to train mode.
        loss_record = []

        # tqdm is a package to visualize your training progress.
        train_pbar = tqdm(train_loader, position=0, leave=True)

        for x, y in train_pbar:
            optimizer.zero_grad()               # Set gradient to zero.
            x, y = x.to(device), y.to(device)   # Move your data to device. 
            pred = model(x)             
            loss = criterion(pred, y)
            loss.backward()                     # Compute gradient(backpropagation).
            optimizer.step()                    # Update parameters.
            step += 1
            loss_record.append(loss.detach().item())
            
            # Display current epoch number and loss on tqdm progress bar.
            train_pbar.set_description(f'Epoch [{epoch+1}/{n_epochs}]')
            train_pbar.set_postfix({'loss': loss.detach().item()})

        mean_train_loss = sum(loss_record)/len(loss_record)
        writer.add_scalar('Loss/train', mean_train_loss, step)

        model.eval() # Set your model to evaluation mode.
        loss_record = []
        for x, y in valid_loader:
            x, y = x.to(device), y.to(device)
            with torch.no_grad():
                pred = model(x)
                loss = criterion(pred, y)

            loss_record.append(loss.item())
            
        mean_valid_loss = sum(loss_record)/len(loss_record)
        # print(f'Epoch [{epoch+1}/{n_epochs}]: Train loss: {mean_train_loss:.4f}, Valid loss: {mean_valid_loss:.4f}')
        writer.add_scalar('Loss/valid', mean_valid_loss, step)

        if mean_valid_loss < best_loss:
            best_loss = mean_valid_loss
            torch.save(model.state_dict(), config['save_path']) # Save your best model
            print('Saving model with loss {:.3f}...'.format(best_loss))
            early_stop_count = 0
        else: 
            early_stop_count += 1

        if early_stop_count >= config['early_stop']:
            print('\nModel is not improving, so we halt the training session.')
            return

device = 'cuda' if torch.cuda.is_available() else 'cpu'
config = {
    'seed': 131413131,      # Your seed number, you can pick your lucky number. :)
    'select_all': False,   # Whether to use all features.
    'valid_ratio': 0.2,   # validation_size = train_size * valid_ratio
    'n_epochs': 30000,     # Number of epochs.            
    'batch_size': 256, 
    'learning_rate': 1e-2,              # 1e-2,
    'early_stop': 300,    # If model has not improved for this many consecutive epochs, stop training.     
    'save_path': './models/model_1_c.ckpt'  # Your model will be saved here.
}


# Set seed for reproducibility
same_seed(config['seed'])


# train_data size: 2699 x 118 (id + 37 states + 16 features x 5 days) 
# test_data size: 1078 x 117 (without last day's positive rate)
train_data, test_data = pd.read_csv('./covid.train.csv').values, pd.read_csv('./covid.test.csv').values
train_data, valid_data = train_valid_split(train_data, config['valid_ratio'], config['seed'])

# Print out the data size.
print(f"""train_data size: {train_data.shape} 
valid_data size: {valid_data.shape} 
test_data size: {test_data.shape}""")

# Select features
x_train, x_valid, x_test, y_train, y_valid = select_feat(train_data, valid_data, test_data, config['select_all'])

# Print out the number of features.
print(f'number of features: {x_train.shape[1]}')

train_dataset, valid_dataset, test_dataset = COVID19Dataset(x_train, y_train), \
                                            COVID19Dataset(x_valid, y_valid), \
                                            COVID19Dataset(x_test)

# Pytorch data loader loads pytorch dataset into batches.
train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=config['batch_size'], shuffle=True, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, pin_memory=True)



model = My_Model(input_dim=x_train.shape[1]).to(device) # put your model and data on the same computation device.
trainer(train_loader, valid_loader, model, config, device)

train_data size: (2160, 118) 
valid_data size: (539, 118) 
test_data size: (1078, 117)
number of features: 79


Epoch [1/30000]: 100%|██████████| 9/9 [00:00<00:00, 222.44it/s, loss=16.7]


Saving model with loss 50.380...


Epoch [2/30000]: 100%|██████████| 9/9 [00:00<00:00, 237.33it/s, loss=25.3]


Saving model with loss 7.670...


Epoch [3/30000]: 100%|██████████| 9/9 [00:00<00:00, 253.04it/s, loss=6.93]


Saving model with loss 5.298...


Epoch [4/30000]: 100%|██████████| 9/9 [00:00<00:00, 283.86it/s, loss=8.4]
Epoch [5/30000]: 100%|██████████| 9/9 [00:00<00:00, 274.92it/s, loss=3.95]
Epoch [6/30000]: 100%|██████████| 9/9 [00:00<00:00, 255.62it/s, loss=7.84]
Epoch [7/30000]: 100%|██████████| 9/9 [00:00<00:00, 286.29it/s, loss=6.03]


Saving model with loss 5.148...


Epoch [8/30000]: 100%|██████████| 9/9 [00:00<00:00, 291.32it/s, loss=6.82]
Epoch [9/30000]: 100%|██████████| 9/9 [00:00<00:00, 293.42it/s, loss=6.77]
Epoch [10/30000]: 100%|██████████| 9/9 [00:00<00:00, 298.73it/s, loss=4.7]
Epoch [11/30000]: 100%|██████████| 9/9 [00:00<00:00, 282.71it/s, loss=6.2]
Epoch [12/30000]: 100%|██████████| 9/9 [00:00<00:00, 285.80it/s, loss=9.65]
Epoch [13/30000]: 100%|██████████| 9/9 [00:00<00:00, 211.64it/s, loss=6.26]
Epoch [14/30000]: 100%|██████████| 9/9 [00:00<00:00, 270.44it/s, loss=6.23]
Epoch [15/30000]: 100%|██████████| 9/9 [00:00<00:00, 281.75it/s, loss=6.14]
Epoch [16/30000]: 100%|██████████| 9/9 [00:00<00:00, 260.02it/s, loss=5.15]
Epoch [17/30000]: 100%|██████████| 9/9 [00:00<00:00, 281.05it/s, loss=4.3]
Epoch [18/30000]: 100%|██████████| 9/9 [00:00<00:00, 289.49it/s, loss=6.27]
Epoch [19/30000]: 100%|██████████| 9/9 [00:00<00:00, 288.63it/s, loss=6.61]


Saving model with loss 4.982...


Epoch [20/30000]: 100%|██████████| 9/9 [00:00<00:00, 289.11it/s, loss=8.37]
Epoch [21/30000]: 100%|██████████| 9/9 [00:00<00:00, 291.20it/s, loss=4.22]


Saving model with loss 4.885...


Epoch [22/30000]: 100%|██████████| 9/9 [00:00<00:00, 272.53it/s, loss=5.38]
Epoch [23/30000]: 100%|██████████| 9/9 [00:00<00:00, 297.58it/s, loss=5.21]


Saving model with loss 4.618...


Epoch [24/30000]: 100%|██████████| 9/9 [00:00<00:00, 290.41it/s, loss=6.96]
Epoch [25/30000]: 100%|██████████| 9/9 [00:00<00:00, 287.18it/s, loss=8.66]
Epoch [26/30000]: 100%|██████████| 9/9 [00:00<00:00, 279.16it/s, loss=5.5]
Epoch [27/30000]: 100%|██████████| 9/9 [00:00<00:00, 277.86it/s, loss=6.49]
Epoch [28/30000]: 100%|██████████| 9/9 [00:00<00:00, 259.04it/s, loss=7.28]
Epoch [29/30000]: 100%|██████████| 9/9 [00:00<00:00, 251.26it/s, loss=4.89]
Epoch [30/30000]: 100%|██████████| 9/9 [00:00<00:00, 275.30it/s, loss=9.95]
Epoch [31/30000]: 100%|██████████| 9/9 [00:00<00:00, 265.03it/s, loss=3.91]


Saving model with loss 4.065...


Epoch [32/30000]: 100%|██████████| 9/9 [00:00<00:00, 269.37it/s, loss=3.6]


Saving model with loss 3.940...


Epoch [33/30000]: 100%|██████████| 9/9 [00:00<00:00, 234.03it/s, loss=7.32]
Epoch [34/30000]: 100%|██████████| 9/9 [00:00<00:00, 233.86it/s, loss=6.61]
Epoch [35/30000]: 100%|██████████| 9/9 [00:00<00:00, 265.21it/s, loss=5.12]
Epoch [36/30000]: 100%|██████████| 9/9 [00:00<00:00, 278.95it/s, loss=14.8]
Epoch [37/30000]: 100%|██████████| 9/9 [00:00<00:00, 262.51it/s, loss=4.87]
Epoch [38/30000]: 100%|██████████| 9/9 [00:00<00:00, 249.35it/s, loss=3.49]
Epoch [39/30000]: 100%|██████████| 9/9 [00:00<00:00, 268.41it/s, loss=8.63]
Epoch [40/30000]: 100%|██████████| 9/9 [00:00<00:00, 260.51it/s, loss=9.89]
Epoch [41/30000]: 100%|██████████| 9/9 [00:00<00:00, 253.23it/s, loss=5.29]
Epoch [42/30000]: 100%|██████████| 9/9 [00:00<00:00, 242.34it/s, loss=9.91]
Epoch [43/30000]: 100%|██████████| 9/9 [00:00<00:00, 271.59it/s, loss=8.45]
Epoch [44/30000]: 100%|██████████| 9/9 [00:00<00:00, 284.97it/s, loss=9.21]
Epoch [45/30000]: 100%|██████████| 9/9 [00:00<00:00, 267.23it/s, loss=6.65]
Epoch [46/30

Saving model with loss 3.884...


Epoch [86/30000]: 100%|██████████| 9/9 [00:00<00:00, 292.99it/s, loss=5.74]
Epoch [87/30000]: 100%|██████████| 9/9 [00:00<00:00, 289.80it/s, loss=17.6]
Epoch [88/30000]: 100%|██████████| 9/9 [00:00<00:00, 265.86it/s, loss=7.78]
Epoch [89/30000]: 100%|██████████| 9/9 [00:00<00:00, 286.63it/s, loss=18.2]
Epoch [90/30000]: 100%|██████████| 9/9 [00:00<00:00, 290.21it/s, loss=3.86]
Epoch [91/30000]: 100%|██████████| 9/9 [00:00<00:00, 296.72it/s, loss=20.3]
Epoch [92/30000]: 100%|██████████| 9/9 [00:00<00:00, 267.05it/s, loss=8.65]
Epoch [93/30000]: 100%|██████████| 9/9 [00:00<00:00, 276.28it/s, loss=7.86]
Epoch [94/30000]: 100%|██████████| 9/9 [00:00<00:00, 286.68it/s, loss=6.96]
Epoch [95/30000]: 100%|██████████| 9/9 [00:00<00:00, 289.21it/s, loss=11.2]
Epoch [96/30000]: 100%|██████████| 9/9 [00:00<00:00, 293.39it/s, loss=7.44]
Epoch [97/30000]: 100%|██████████| 9/9 [00:00<00:00, 242.83it/s, loss=7.01]
Epoch [98/30000]: 100%|██████████| 9/9 [00:00<00:00, 252.94it/s, loss=11.5]
Epoch [99/30

Saving model with loss 2.600...


Epoch [138/30000]: 100%|██████████| 9/9 [00:00<00:00, 295.01it/s, loss=2.33]
Epoch [139/30000]: 100%|██████████| 9/9 [00:00<00:00, 295.47it/s, loss=15.4]
Epoch [140/30000]: 100%|██████████| 9/9 [00:00<00:00, 278.09it/s, loss=8.48]
Epoch [141/30000]: 100%|██████████| 9/9 [00:00<00:00, 246.12it/s, loss=6.28]
Epoch [142/30000]: 100%|██████████| 9/9 [00:00<00:00, 288.16it/s, loss=5.6]
Epoch [143/30000]: 100%|██████████| 9/9 [00:00<00:00, 264.43it/s, loss=24.7]
Epoch [144/30000]: 100%|██████████| 9/9 [00:00<00:00, 276.07it/s, loss=4.53]
Epoch [145/30000]: 100%|██████████| 9/9 [00:00<00:00, 158.75it/s, loss=8.75]
Epoch [146/30000]: 100%|██████████| 9/9 [00:00<00:00, 284.81it/s, loss=5.48]
Epoch [147/30000]: 100%|██████████| 9/9 [00:00<00:00, 284.04it/s, loss=7.58]
Epoch [148/30000]: 100%|██████████| 9/9 [00:00<00:00, 279.18it/s, loss=7.51]
Epoch [149/30000]: 100%|██████████| 9/9 [00:00<00:00, 277.04it/s, loss=18.2]
Epoch [150/30000]: 100%|██████████| 9/9 [00:00<00:00, 270.36it/s, loss=2.89]


Saving model with loss 2.264...


Epoch [174/30000]: 100%|██████████| 9/9 [00:00<00:00, 243.21it/s, loss=4.13]
Epoch [175/30000]: 100%|██████████| 9/9 [00:00<00:00, 112.36it/s, loss=5.6]
Epoch [176/30000]: 100%|██████████| 9/9 [00:00<00:00, 282.94it/s, loss=9.46]
Epoch [177/30000]: 100%|██████████| 9/9 [00:00<00:00, 264.12it/s, loss=3.31]
Epoch [178/30000]: 100%|██████████| 9/9 [00:00<00:00, 295.76it/s, loss=13.7]
Epoch [179/30000]: 100%|██████████| 9/9 [00:00<00:00, 298.90it/s, loss=5.1]
Epoch [180/30000]: 100%|██████████| 9/9 [00:00<00:00, 292.86it/s, loss=11.8]
Epoch [181/30000]: 100%|██████████| 9/9 [00:00<00:00, 288.91it/s, loss=5.32]
Epoch [182/30000]: 100%|██████████| 9/9 [00:00<00:00, 292.86it/s, loss=11.7]
Epoch [183/30000]: 100%|██████████| 9/9 [00:00<00:00, 280.90it/s, loss=7.97]
Epoch [184/30000]: 100%|██████████| 9/9 [00:00<00:00, 295.73it/s, loss=5.67]
Epoch [185/30000]: 100%|██████████| 9/9 [00:00<00:00, 293.72it/s, loss=3.6]
Epoch [186/30000]: 100%|██████████| 9/9 [00:00<00:00, 265.54it/s, loss=2.44]
Ep

Saving model with loss 2.239...


Epoch [212/30000]: 100%|██████████| 9/9 [00:00<00:00, 249.05it/s, loss=3.75]
Epoch [213/30000]: 100%|██████████| 9/9 [00:00<00:00, 232.04it/s, loss=4.61]
Epoch [214/30000]: 100%|██████████| 9/9 [00:00<00:00, 288.68it/s, loss=8.44]
Epoch [215/30000]: 100%|██████████| 9/9 [00:00<00:00, 294.52it/s, loss=3.09]
Epoch [216/30000]: 100%|██████████| 9/9 [00:00<00:00, 282.05it/s, loss=2.52]
Epoch [217/30000]: 100%|██████████| 9/9 [00:00<00:00, 289.03it/s, loss=5.04]
Epoch [218/30000]: 100%|██████████| 9/9 [00:00<00:00, 271.60it/s, loss=10.4]
Epoch [219/30000]: 100%|██████████| 9/9 [00:00<00:00, 278.52it/s, loss=4.3]
Epoch [220/30000]: 100%|██████████| 9/9 [00:00<00:00, 294.48it/s, loss=6.07]
Epoch [221/30000]: 100%|██████████| 9/9 [00:00<00:00, 289.37it/s, loss=5.29]
Epoch [222/30000]: 100%|██████████| 9/9 [00:00<00:00, 297.74it/s, loss=7.95]
Epoch [223/30000]: 100%|██████████| 9/9 [00:00<00:00, 266.86it/s, loss=4.61]
Epoch [224/30000]: 100%|██████████| 9/9 [00:00<00:00, 236.41it/s, loss=2.84]


Saving model with loss 2.174...


Epoch [242/30000]: 100%|██████████| 9/9 [00:00<00:00, 286.72it/s, loss=8.04]
Epoch [243/30000]: 100%|██████████| 9/9 [00:00<00:00, 308.96it/s, loss=4.38]
Epoch [244/30000]: 100%|██████████| 9/9 [00:00<00:00, 297.11it/s, loss=8.61]
Epoch [245/30000]: 100%|██████████| 9/9 [00:00<00:00, 284.49it/s, loss=2.33]


Saving model with loss 2.022...


Epoch [246/30000]: 100%|██████████| 9/9 [00:00<00:00, 290.08it/s, loss=2.86]
Epoch [247/30000]: 100%|██████████| 9/9 [00:00<00:00, 285.63it/s, loss=10]
Epoch [248/30000]: 100%|██████████| 9/9 [00:00<00:00, 290.13it/s, loss=4.22]
Epoch [249/30000]: 100%|██████████| 9/9 [00:00<00:00, 289.77it/s, loss=3.73]
Epoch [250/30000]: 100%|██████████| 9/9 [00:00<00:00, 287.15it/s, loss=3.5]
Epoch [251/30000]: 100%|██████████| 9/9 [00:00<00:00, 271.87it/s, loss=3.8]
Epoch [252/30000]: 100%|██████████| 9/9 [00:00<00:00, 244.23it/s, loss=4.57]
Epoch [253/30000]: 100%|██████████| 9/9 [00:00<00:00, 277.05it/s, loss=10.4]
Epoch [254/30000]: 100%|██████████| 9/9 [00:00<00:00, 276.41it/s, loss=2.46]
Epoch [255/30000]: 100%|██████████| 9/9 [00:00<00:00, 287.27it/s, loss=7.08]
Epoch [256/30000]: 100%|██████████| 9/9 [00:00<00:00, 294.14it/s, loss=2.64]
Epoch [257/30000]: 100%|██████████| 9/9 [00:00<00:00, 276.45it/s, loss=2.42]
Epoch [258/30000]: 100%|██████████| 9/9 [00:00<00:00, 298.95it/s, loss=4]
Epoch 

Saving model with loss 1.994...


Epoch [300/30000]: 100%|██████████| 9/9 [00:00<00:00, 271.01it/s, loss=2.2]
Epoch [301/30000]: 100%|██████████| 9/9 [00:00<00:00, 282.65it/s, loss=3.13]
Epoch [302/30000]: 100%|██████████| 9/9 [00:00<00:00, 245.15it/s, loss=2.59]


Saving model with loss 1.751...


Epoch [303/30000]: 100%|██████████| 9/9 [00:00<00:00, 257.28it/s, loss=3.49]
Epoch [304/30000]: 100%|██████████| 9/9 [00:00<00:00, 283.65it/s, loss=5.15]
Epoch [305/30000]: 100%|██████████| 9/9 [00:00<00:00, 259.91it/s, loss=5.37]
Epoch [306/30000]: 100%|██████████| 9/9 [00:00<00:00, 301.14it/s, loss=2.2]
Epoch [307/30000]: 100%|██████████| 9/9 [00:00<00:00, 281.30it/s, loss=4.19]
Epoch [308/30000]: 100%|██████████| 9/9 [00:00<00:00, 288.77it/s, loss=4.9]
Epoch [309/30000]: 100%|██████████| 9/9 [00:00<00:00, 240.15it/s, loss=11.2]
Epoch [310/30000]: 100%|██████████| 9/9 [00:00<00:00, 272.37it/s, loss=3.28]
Epoch [311/30000]: 100%|██████████| 9/9 [00:00<00:00, 290.13it/s, loss=9.38]
Epoch [312/30000]: 100%|██████████| 9/9 [00:00<00:00, 297.86it/s, loss=3.26]
Epoch [313/30000]: 100%|██████████| 9/9 [00:00<00:00, 281.32it/s, loss=5.9]
Epoch [314/30000]: 100%|██████████| 9/9 [00:00<00:00, 297.61it/s, loss=4.26]
Epoch [315/30000]: 100%|██████████| 9/9 [00:00<00:00, 296.15it/s, loss=2.97]
Ep

Saving model with loss 1.705...


Epoch [404/30000]: 100%|██████████| 9/9 [00:00<00:00, 270.64it/s, loss=1.76]
Epoch [405/30000]: 100%|██████████| 9/9 [00:00<00:00, 288.29it/s, loss=12.9]
Epoch [406/30000]: 100%|██████████| 9/9 [00:00<00:00, 292.67it/s, loss=1.91]
Epoch [407/30000]: 100%|██████████| 9/9 [00:00<00:00, 293.08it/s, loss=2.71]
Epoch [408/30000]: 100%|██████████| 9/9 [00:00<00:00, 294.64it/s, loss=4.72]
Epoch [409/30000]: 100%|██████████| 9/9 [00:00<00:00, 300.14it/s, loss=5.14]
Epoch [410/30000]: 100%|██████████| 9/9 [00:00<00:00, 263.40it/s, loss=2.22]
Epoch [411/30000]: 100%|██████████| 9/9 [00:00<00:00, 291.62it/s, loss=8.91]
Epoch [412/30000]: 100%|██████████| 9/9 [00:00<00:00, 273.21it/s, loss=2.55]
Epoch [413/30000]: 100%|██████████| 9/9 [00:00<00:00, 298.38it/s, loss=4.12]
Epoch [414/30000]: 100%|██████████| 9/9 [00:00<00:00, 283.52it/s, loss=4.49]
Epoch [415/30000]: 100%|██████████| 9/9 [00:00<00:00, 297.11it/s, loss=3.69]
Epoch [416/30000]: 100%|██████████| 9/9 [00:00<00:00, 277.27it/s, loss=3.84]

Saving model with loss 1.578...


Epoch [493/30000]: 100%|██████████| 9/9 [00:00<00:00, 305.97it/s, loss=2.79]
Epoch [494/30000]: 100%|██████████| 9/9 [00:00<00:00, 289.81it/s, loss=4.03]
Epoch [495/30000]: 100%|██████████| 9/9 [00:00<00:00, 292.83it/s, loss=3.36]
Epoch [496/30000]: 100%|██████████| 9/9 [00:00<00:00, 263.31it/s, loss=2.1]
Epoch [497/30000]: 100%|██████████| 9/9 [00:00<00:00, 274.39it/s, loss=4.01]
Epoch [498/30000]: 100%|██████████| 9/9 [00:00<00:00, 291.86it/s, loss=3.88]
Epoch [499/30000]: 100%|██████████| 9/9 [00:00<00:00, 278.76it/s, loss=4.2]
Epoch [500/30000]: 100%|██████████| 9/9 [00:00<00:00, 286.79it/s, loss=8.04]
Epoch [501/30000]: 100%|██████████| 9/9 [00:00<00:00, 302.07it/s, loss=2.91]
Epoch [502/30000]: 100%|██████████| 9/9 [00:00<00:00, 297.96it/s, loss=2.5]
Epoch [503/30000]: 100%|██████████| 9/9 [00:00<00:00, 275.63it/s, loss=2.39]
Epoch [504/30000]: 100%|██████████| 9/9 [00:00<00:00, 292.78it/s, loss=2.35]
Epoch [505/30000]: 100%|██████████| 9/9 [00:00<00:00, 279.49it/s, loss=3.37]
Ep

Saving model with loss 1.562...


Epoch [600/30000]: 100%|██████████| 9/9 [00:00<00:00, 290.80it/s, loss=5.79]
Epoch [601/30000]: 100%|██████████| 9/9 [00:00<00:00, 284.42it/s, loss=1.61]
Epoch [602/30000]: 100%|██████████| 9/9 [00:00<00:00, 265.93it/s, loss=4.87]
Epoch [603/30000]: 100%|██████████| 9/9 [00:00<00:00, 301.37it/s, loss=2.69]
Epoch [604/30000]: 100%|██████████| 9/9 [00:00<00:00, 113.48it/s, loss=4.32]
Epoch [605/30000]: 100%|██████████| 9/9 [00:00<00:00, 284.22it/s, loss=3.31]
Epoch [606/30000]: 100%|██████████| 9/9 [00:00<00:00, 267.00it/s, loss=4.48]
Epoch [607/30000]: 100%|██████████| 9/9 [00:00<00:00, 285.20it/s, loss=2.03]
Epoch [608/30000]: 100%|██████████| 9/9 [00:00<00:00, 307.78it/s, loss=2.93]
Epoch [609/30000]: 100%|██████████| 9/9 [00:00<00:00, 282.56it/s, loss=1.2]
Epoch [610/30000]: 100%|██████████| 9/9 [00:00<00:00, 296.10it/s, loss=2.93]
Epoch [611/30000]: 100%|██████████| 9/9 [00:00<00:00, 303.85it/s, loss=3.11]
Epoch [612/30000]: 100%|██████████| 9/9 [00:00<00:00, 261.38it/s, loss=1.81]


Saving model with loss 1.559...


Epoch [713/30000]: 100%|██████████| 9/9 [00:00<00:00, 283.94it/s, loss=3.62]
Epoch [714/30000]: 100%|██████████| 9/9 [00:00<00:00, 301.11it/s, loss=1.64]
Epoch [715/30000]: 100%|██████████| 9/9 [00:00<00:00, 299.20it/s, loss=2.23]
Epoch [716/30000]: 100%|██████████| 9/9 [00:00<00:00, 290.15it/s, loss=4.65]
Epoch [717/30000]: 100%|██████████| 9/9 [00:00<00:00, 285.90it/s, loss=2.75]
Epoch [718/30000]: 100%|██████████| 9/9 [00:00<00:00, 293.02it/s, loss=2.26]
Epoch [719/30000]: 100%|██████████| 9/9 [00:00<00:00, 267.68it/s, loss=2.41]
Epoch [720/30000]: 100%|██████████| 9/9 [00:00<00:00, 299.61it/s, loss=4.91]
Epoch [721/30000]: 100%|██████████| 9/9 [00:00<00:00, 282.90it/s, loss=2.97]
Epoch [722/30000]: 100%|██████████| 9/9 [00:00<00:00, 269.31it/s, loss=2.87]
Epoch [723/30000]: 100%|██████████| 9/9 [00:00<00:00, 307.04it/s, loss=2.17]
Epoch [724/30000]: 100%|██████████| 9/9 [00:00<00:00, 285.66it/s, loss=4.88]
Epoch [725/30000]: 100%|██████████| 9/9 [00:00<00:00, 275.58it/s, loss=2.11]

Saving model with loss 1.553...


Epoch [732/30000]: 100%|██████████| 9/9 [00:00<00:00, 299.55it/s, loss=3.37]
Epoch [733/30000]: 100%|██████████| 9/9 [00:00<00:00, 297.72it/s, loss=2.08]
Epoch [734/30000]: 100%|██████████| 9/9 [00:00<00:00, 264.17it/s, loss=1.8]


Saving model with loss 1.365...


Epoch [735/30000]: 100%|██████████| 9/9 [00:00<00:00, 288.58it/s, loss=2.02]
Epoch [736/30000]: 100%|██████████| 9/9 [00:00<00:00, 290.91it/s, loss=2.97]
Epoch [737/30000]: 100%|██████████| 9/9 [00:00<00:00, 290.69it/s, loss=2.73]
Epoch [738/30000]: 100%|██████████| 9/9 [00:00<00:00, 262.61it/s, loss=2.11]
Epoch [739/30000]: 100%|██████████| 9/9 [00:00<00:00, 281.92it/s, loss=1.45]
Epoch [740/30000]: 100%|██████████| 9/9 [00:00<00:00, 178.12it/s, loss=5.38]
Epoch [741/30000]: 100%|██████████| 9/9 [00:00<00:00, 270.44it/s, loss=2.55]
Epoch [742/30000]: 100%|██████████| 9/9 [00:00<00:00, 259.49it/s, loss=2.6]
Epoch [743/30000]: 100%|██████████| 9/9 [00:00<00:00, 265.99it/s, loss=2.95]
Epoch [744/30000]: 100%|██████████| 9/9 [00:00<00:00, 254.06it/s, loss=2.03]
Epoch [745/30000]: 100%|██████████| 9/9 [00:00<00:00, 266.54it/s, loss=2.2]
Epoch [746/30000]: 100%|██████████| 9/9 [00:00<00:00, 295.12it/s, loss=2.38]
Epoch [747/30000]: 100%|██████████| 9/9 [00:00<00:00, 288.89it/s, loss=1.77]
E

Saving model with loss 1.358...


Epoch [932/30000]: 100%|██████████| 9/9 [00:00<00:00, 148.11it/s, loss=1.63]
Epoch [933/30000]: 100%|██████████| 9/9 [00:00<00:00, 249.44it/s, loss=2.66]
Epoch [934/30000]: 100%|██████████| 9/9 [00:00<00:00, 255.97it/s, loss=2.32]
Epoch [935/30000]: 100%|██████████| 9/9 [00:00<00:00, 252.81it/s, loss=2.16]
Epoch [936/30000]: 100%|██████████| 9/9 [00:00<00:00, 285.37it/s, loss=2.29]
Epoch [937/30000]: 100%|██████████| 9/9 [00:00<00:00, 276.11it/s, loss=2.82]
Epoch [938/30000]: 100%|██████████| 9/9 [00:00<00:00, 267.68it/s, loss=2.95]
Epoch [939/30000]: 100%|██████████| 9/9 [00:00<00:00, 278.65it/s, loss=1.17]
Epoch [940/30000]: 100%|██████████| 9/9 [00:00<00:00, 264.82it/s, loss=1.8]
Epoch [941/30000]: 100%|██████████| 9/9 [00:00<00:00, 287.98it/s, loss=3.42]
Epoch [942/30000]: 100%|██████████| 9/9 [00:00<00:00, 255.99it/s, loss=2.52]
Epoch [943/30000]: 100%|██████████| 9/9 [00:00<00:00, 267.90it/s, loss=3.52]
Epoch [944/30000]: 100%|██████████| 9/9 [00:00<00:00, 265.01it/s, loss=1.89]


Saving model with loss 1.357...


Epoch [979/30000]: 100%|██████████| 9/9 [00:00<00:00, 277.35it/s, loss=1.3]
Epoch [980/30000]: 100%|██████████| 9/9 [00:00<00:00, 266.24it/s, loss=2.08]
Epoch [981/30000]: 100%|██████████| 9/9 [00:00<00:00, 292.93it/s, loss=2.74]
Epoch [982/30000]: 100%|██████████| 9/9 [00:00<00:00, 303.12it/s, loss=1.8]
Epoch [983/30000]: 100%|██████████| 9/9 [00:00<00:00, 287.92it/s, loss=1.76]
Epoch [984/30000]: 100%|██████████| 9/9 [00:00<00:00, 258.23it/s, loss=2.25]
Epoch [985/30000]: 100%|██████████| 9/9 [00:00<00:00, 285.86it/s, loss=2.17]


Saving model with loss 1.329...


Epoch [986/30000]: 100%|██████████| 9/9 [00:00<00:00, 166.87it/s, loss=1.59]
Epoch [987/30000]: 100%|██████████| 9/9 [00:00<00:00, 224.21it/s, loss=2.72]
Epoch [988/30000]: 100%|██████████| 9/9 [00:00<00:00, 282.32it/s, loss=2.75]
Epoch [989/30000]: 100%|██████████| 9/9 [00:00<00:00, 287.52it/s, loss=2.24]
Epoch [990/30000]: 100%|██████████| 9/9 [00:00<00:00, 269.86it/s, loss=2.48]
Epoch [991/30000]: 100%|██████████| 9/9 [00:00<00:00, 293.99it/s, loss=2.19]
Epoch [992/30000]: 100%|██████████| 9/9 [00:00<00:00, 276.76it/s, loss=4.41]
Epoch [993/30000]: 100%|██████████| 9/9 [00:00<00:00, 278.33it/s, loss=1.92]
Epoch [994/30000]: 100%|██████████| 9/9 [00:00<00:00, 293.88it/s, loss=2.46]
Epoch [995/30000]: 100%|██████████| 9/9 [00:00<00:00, 287.27it/s, loss=2.03]
Epoch [996/30000]: 100%|██████████| 9/9 [00:00<00:00, 275.45it/s, loss=3.01]
Epoch [997/30000]: 100%|██████████| 9/9 [00:00<00:00, 294.62it/s, loss=3.36]
Epoch [998/30000]: 100%|██████████| 9/9 [00:00<00:00, 296.32it/s, loss=1.8]


Saving model with loss 1.319...


Epoch [1030/30000]: 100%|██████████| 9/9 [00:00<00:00, 303.17it/s, loss=1.43]
Epoch [1031/30000]: 100%|██████████| 9/9 [00:00<00:00, 287.34it/s, loss=1.27]
Epoch [1032/30000]: 100%|██████████| 9/9 [00:00<00:00, 237.27it/s, loss=1.81]
Epoch [1033/30000]: 100%|██████████| 9/9 [00:00<00:00, 113.09it/s, loss=2.4]
Epoch [1034/30000]: 100%|██████████| 9/9 [00:00<00:00, 294.38it/s, loss=2.42]
Epoch [1035/30000]: 100%|██████████| 9/9 [00:00<00:00, 303.99it/s, loss=1.88]
Epoch [1036/30000]: 100%|██████████| 9/9 [00:00<00:00, 270.59it/s, loss=4.19]
Epoch [1037/30000]: 100%|██████████| 9/9 [00:00<00:00, 287.14it/s, loss=3]
Epoch [1038/30000]: 100%|██████████| 9/9 [00:00<00:00, 295.68it/s, loss=1.44]
Epoch [1039/30000]: 100%|██████████| 9/9 [00:00<00:00, 159.32it/s, loss=1.37]
Epoch [1040/30000]: 100%|██████████| 9/9 [00:00<00:00, 284.91it/s, loss=1.54]
Epoch [1041/30000]: 100%|██████████| 9/9 [00:00<00:00, 291.55it/s, loss=4.95]
Epoch [1042/30000]: 100%|██████████| 9/9 [00:00<00:00, 267.32it/s, l

Saving model with loss 1.311...


Epoch [1088/30000]: 100%|██████████| 9/9 [00:00<00:00, 289.52it/s, loss=3.72]
Epoch [1089/30000]: 100%|██████████| 9/9 [00:00<00:00, 279.55it/s, loss=2.17]
Epoch [1090/30000]: 100%|██████████| 9/9 [00:00<00:00, 275.13it/s, loss=2.61]
Epoch [1091/30000]: 100%|██████████| 9/9 [00:00<00:00, 293.92it/s, loss=1.99]
Epoch [1092/30000]: 100%|██████████| 9/9 [00:00<00:00, 299.44it/s, loss=4.32]
Epoch [1093/30000]: 100%|██████████| 9/9 [00:00<00:00, 268.11it/s, loss=4.89]
Epoch [1094/30000]: 100%|██████████| 9/9 [00:00<00:00, 112.54it/s, loss=2.14]
Epoch [1095/30000]: 100%|██████████| 9/9 [00:00<00:00, 160.02it/s, loss=1.76]
Epoch [1096/30000]: 100%|██████████| 9/9 [00:00<00:00, 291.03it/s, loss=2.91]
Epoch [1097/30000]: 100%|██████████| 9/9 [00:00<00:00, 271.34it/s, loss=1.09]
Epoch [1098/30000]: 100%|██████████| 9/9 [00:00<00:00, 297.73it/s, loss=1.88]
Epoch [1099/30000]: 100%|██████████| 9/9 [00:00<00:00, 299.07it/s, loss=3]
Epoch [1100/30000]: 100%|██████████| 9/9 [00:00<00:00, 286.88it/s, 

Saving model with loss 1.309...


Epoch [1158/30000]: 100%|██████████| 9/9 [00:00<00:00, 276.90it/s, loss=5.64]
Epoch [1159/30000]: 100%|██████████| 9/9 [00:00<00:00, 259.43it/s, loss=1.34]
Epoch [1160/30000]: 100%|██████████| 9/9 [00:00<00:00, 291.86it/s, loss=1.85]
Epoch [1161/30000]: 100%|██████████| 9/9 [00:00<00:00, 276.78it/s, loss=1.47]
Epoch [1162/30000]: 100%|██████████| 9/9 [00:00<00:00, 282.21it/s, loss=3.65]
Epoch [1163/30000]: 100%|██████████| 9/9 [00:00<00:00, 266.93it/s, loss=2.54]
Epoch [1164/30000]: 100%|██████████| 9/9 [00:00<00:00, 284.50it/s, loss=3.08]
Epoch [1165/30000]: 100%|██████████| 9/9 [00:00<00:00, 206.83it/s, loss=1.66]
Epoch [1166/30000]: 100%|██████████| 9/9 [00:00<00:00, 242.17it/s, loss=1.56]
Epoch [1167/30000]: 100%|██████████| 9/9 [00:00<00:00, 212.41it/s, loss=1.79]
Epoch [1168/30000]: 100%|██████████| 9/9 [00:00<00:00, 210.45it/s, loss=1.28]
Epoch [1169/30000]: 100%|██████████| 9/9 [00:00<00:00, 269.64it/s, loss=2.78]
Epoch [1170/30000]: 100%|██████████| 9/9 [00:00<00:00, 287.75it/

Saving model with loss 1.264...


Epoch [1247/30000]: 100%|██████████| 9/9 [00:00<00:00, 280.16it/s, loss=2.23]
Epoch [1248/30000]: 100%|██████████| 9/9 [00:00<00:00, 202.89it/s, loss=2.82]
Epoch [1249/30000]: 100%|██████████| 9/9 [00:00<00:00, 244.19it/s, loss=2.85]
Epoch [1250/30000]: 100%|██████████| 9/9 [00:00<00:00, 273.01it/s, loss=2.28]
Epoch [1251/30000]: 100%|██████████| 9/9 [00:00<00:00, 259.73it/s, loss=2.44]
Epoch [1252/30000]: 100%|██████████| 9/9 [00:00<00:00, 282.55it/s, loss=1.25]
Epoch [1253/30000]: 100%|██████████| 9/9 [00:00<00:00, 248.41it/s, loss=2.14]
Epoch [1254/30000]: 100%|██████████| 9/9 [00:00<00:00, 273.42it/s, loss=1.73]
Epoch [1255/30000]: 100%|██████████| 9/9 [00:00<00:00, 287.40it/s, loss=1.27]
Epoch [1256/30000]: 100%|██████████| 9/9 [00:00<00:00, 299.01it/s, loss=1.43]
Epoch [1257/30000]: 100%|██████████| 9/9 [00:00<00:00, 297.40it/s, loss=1.44]
Epoch [1258/30000]: 100%|██████████| 9/9 [00:00<00:00, 288.67it/s, loss=2.42]
Epoch [1259/30000]: 100%|██████████| 9/9 [00:00<00:00, 251.24it/

Saving model with loss 1.220...


Epoch [1300/30000]: 100%|██████████| 9/9 [00:00<00:00, 291.90it/s, loss=1.76]
Epoch [1301/30000]: 100%|██████████| 9/9 [00:00<00:00, 282.04it/s, loss=1.32]
Epoch [1302/30000]: 100%|██████████| 9/9 [00:00<00:00, 285.73it/s, loss=1.45]
Epoch [1303/30000]: 100%|██████████| 9/9 [00:00<00:00, 263.71it/s, loss=1.97]
Epoch [1304/30000]: 100%|██████████| 9/9 [00:00<00:00, 282.54it/s, loss=3.73]
Epoch [1305/30000]: 100%|██████████| 9/9 [00:00<00:00, 281.06it/s, loss=1.65]
Epoch [1306/30000]: 100%|██████████| 9/9 [00:00<00:00, 296.34it/s, loss=2.47]
Epoch [1307/30000]: 100%|██████████| 9/9 [00:00<00:00, 297.71it/s, loss=1.49]
Epoch [1308/30000]: 100%|██████████| 9/9 [00:00<00:00, 291.04it/s, loss=2.46]
Epoch [1309/30000]: 100%|██████████| 9/9 [00:00<00:00, 263.62it/s, loss=3.4]
Epoch [1310/30000]: 100%|██████████| 9/9 [00:00<00:00, 284.15it/s, loss=2]
Epoch [1311/30000]: 100%|██████████| 9/9 [00:00<00:00, 282.07it/s, loss=1.21]
Epoch [1312/30000]: 100%|██████████| 9/9 [00:00<00:00, 285.40it/s, l

Saving model with loss 1.190...


Epoch [1437/30000]: 100%|██████████| 9/9 [00:00<00:00, 204.85it/s, loss=1.43]
Epoch [1438/30000]: 100%|██████████| 9/9 [00:00<00:00, 232.52it/s, loss=2.37]
Epoch [1439/30000]: 100%|██████████| 9/9 [00:00<00:00, 289.83it/s, loss=2.62]
Epoch [1440/30000]: 100%|██████████| 9/9 [00:00<00:00, 290.76it/s, loss=1.69]
Epoch [1441/30000]: 100%|██████████| 9/9 [00:00<00:00, 280.00it/s, loss=1.71]
Epoch [1442/30000]: 100%|██████████| 9/9 [00:00<00:00, 243.58it/s, loss=2.45]
Epoch [1443/30000]: 100%|██████████| 9/9 [00:00<00:00, 262.11it/s, loss=1.4]
Epoch [1444/30000]: 100%|██████████| 9/9 [00:00<00:00, 264.33it/s, loss=2.38]
Epoch [1445/30000]: 100%|██████████| 9/9 [00:00<00:00, 283.32it/s, loss=1.88]
Epoch [1446/30000]: 100%|██████████| 9/9 [00:00<00:00, 268.30it/s, loss=1.53]
Epoch [1447/30000]: 100%|██████████| 9/9 [00:00<00:00, 287.13it/s, loss=1.77]
Epoch [1448/30000]: 100%|██████████| 9/9 [00:00<00:00, 261.19it/s, loss=1.74]
Epoch [1449/30000]: 100%|██████████| 9/9 [00:00<00:00, 286.23it/s

Saving model with loss 1.138...


Epoch [1564/30000]: 100%|██████████| 9/9 [00:00<00:00, 277.31it/s, loss=1.09]
Epoch [1565/30000]: 100%|██████████| 9/9 [00:00<00:00, 286.70it/s, loss=1.1]
Epoch [1566/30000]: 100%|██████████| 9/9 [00:00<00:00, 287.79it/s, loss=1.15]
Epoch [1567/30000]: 100%|██████████| 9/9 [00:00<00:00, 290.15it/s, loss=1.43]
Epoch [1568/30000]: 100%|██████████| 9/9 [00:00<00:00, 276.74it/s, loss=1.44]
Epoch [1569/30000]: 100%|██████████| 9/9 [00:00<00:00, 273.37it/s, loss=3.27]
Epoch [1570/30000]: 100%|██████████| 9/9 [00:00<00:00, 271.04it/s, loss=1.21]
Epoch [1571/30000]: 100%|██████████| 9/9 [00:00<00:00, 276.86it/s, loss=2.22]
Epoch [1572/30000]: 100%|██████████| 9/9 [00:00<00:00, 281.09it/s, loss=1.78]
Epoch [1573/30000]: 100%|██████████| 9/9 [00:00<00:00, 292.42it/s, loss=1.49]
Epoch [1574/30000]: 100%|██████████| 9/9 [00:00<00:00, 291.99it/s, loss=1.47]
Epoch [1575/30000]: 100%|██████████| 9/9 [00:00<00:00, 284.19it/s, loss=1.22]
Epoch [1576/30000]: 100%|██████████| 9/9 [00:00<00:00, 263.97it/s

Saving model with loss 1.021...


Epoch [1725/30000]: 100%|██████████| 9/9 [00:00<00:00, 261.96it/s, loss=1.13]
Epoch [1726/30000]: 100%|██████████| 9/9 [00:00<00:00, 253.64it/s, loss=1.35]
Epoch [1727/30000]: 100%|██████████| 9/9 [00:00<00:00, 278.10it/s, loss=2.88]
Epoch [1728/30000]: 100%|██████████| 9/9 [00:00<00:00, 277.03it/s, loss=1.74]
Epoch [1729/30000]: 100%|██████████| 9/9 [00:00<00:00, 270.10it/s, loss=1.85]
Epoch [1730/30000]: 100%|██████████| 9/9 [00:00<00:00, 294.61it/s, loss=1.55]
Epoch [1731/30000]: 100%|██████████| 9/9 [00:00<00:00, 276.57it/s, loss=1.67]
Epoch [1732/30000]: 100%|██████████| 9/9 [00:00<00:00, 291.34it/s, loss=1.08]
Epoch [1733/30000]: 100%|██████████| 9/9 [00:00<00:00, 282.06it/s, loss=2.71]
Epoch [1734/30000]: 100%|██████████| 9/9 [00:00<00:00, 291.00it/s, loss=1.72]
Epoch [1735/30000]: 100%|██████████| 9/9 [00:00<00:00, 235.00it/s, loss=1.67]
Epoch [1736/30000]: 100%|██████████| 9/9 [00:00<00:00, 248.18it/s, loss=1.32]
Epoch [1737/30000]: 100%|██████████| 9/9 [00:00<00:00, 276.37it/


Model is not improving, so we halt the training session.


### 训练结果

Epoch [1656/3000]: Train loss: 1.0902, Valid loss: 0.9599
Saving model with loss 0.960...

Epoch [1365/3000]: Train loss: 1.0947, Valid loss: 0.9278
Saving model with loss 0.928...

#### Adam

Epoch [588/30000]: Train loss: 1.3394, Valid loss: 0.9677
Saving model with loss 0.968...

Epoch [1365/30000]: Train loss: 1.0381, Valid loss: 0.9425
Saving model with loss 0.942..


Epoch [868/30000]: Train loss: 1.1504, Valid loss: 0.9203
Saving model with loss 0.920...

Epoch [1020/30000]: Train loss: 1.0783, Valid loss: 0.9449
Saving model with loss 0.945..

Epoch [878/30000]: Train loss: 1.0665, Valid loss: 0.9840
Saving model with loss 0.984...

Epoch [3030/30000]: Train loss: 1.0804, Valid loss: 0.9626
Saving model with loss 0.963...

Epoch [2686/30000]: Train loss: 1.1251, Valid loss: 0.7586
Saving model with loss 0.759...

### 新的预测

In [28]:
def save_pred(preds, file):
    ''' Save predictions to specified file '''
    with open(file, 'w') as fp:
        writer = csv.writer(fp)
        writer.writerow(['id', 'tested_positive'])
        for i, p in enumerate(preds):
            writer.writerow([i, p])

model = My_Model(input_dim=x_train.shape[1]).to(device)
model.load_state_dict(torch.load(config['save_path']))
preds = predict(test_loader, model, device) 
save_pred(preds, 'pred_1_c.csv')      

100%|██████████| 5/5 [00:00<00:00, 1530.43it/s]


### 结果不理想

### 最新的结果
Use Adam Algorithm with weight_decay, also change the random seed!!!!
![](./my_result_7_lessthan1.jpg)