In [1]:
# load library
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
import ast
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

In [2]:
# load dataset
train = pd.read_csv('../../data/trainProcessed.csv')
validate = pd.read_csv('../../data/validateProcessed.csv')
test = pd.read_csv('../../data/testProcessed.csv')

In [3]:
# check the head of the train dataset
train.head(3)

Unnamed: 0,AGE,DIFFERENTIAL_DIAGNOSIS,SEX,PATHOLOGY,INITIAL_EVIDENCE,I30,diarrhee,bode,lesions_peau_endroitducorps_@_face_dorsale_main_D_,douleurxx_irrad_@_sous_la_machoire,...,etourdissement,hernie_hiatale,douleurxx_irrad_@_trachée,douleurxx_endroitducorps_@_orteil__1__G_,ww_dd,lesions_peau_endroitducorps_@_petite_lèvre_G_,lesions_peau_elevee_@_2,j17_j18,lesions_peau_intens_@_0,lesions_peau_endroitducorps_@_vagin
0,18,"[['Bronchite', 0.19171203430383882], ['Pneumon...",0,IVRS ou virémie,fievre,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,21,"[['VIH (Primo-infection)', 0.5189500564407601]...",0,VIH (Primo-infection),diaph,0,1,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,19,"[['Bronchite', 0.11278064619119596], ['Pneumon...",1,Pneumonie,expecto,0,0,0,0,0,...,0,0,0,0,0,0,0,1,1,0


### Data Preprocessing

In [4]:
# normalize 'AGE' between 0 and 1
scaler = MinMaxScaler()
train['AGE'] = scaler.fit_transform(train[['AGE']])
validate['AGE'] = scaler.fit_transform(validate[['AGE']])
test['AGE'] = scaler.fit_transform(test[['AGE']])

In [5]:
# a function to convert string representation of list to actual list
def convert_list_diagnosis(diagnosis_list):
    if isinstance(diagnosis_list, str):
        try:
            return ast.literal_eval(diagnosis_list)
        except Exception as e: 
            return None 
    else:
        return None 

In [6]:
# apply the function to the actual dataset
train['DIFFERENTIAL_DIAGNOSIS'] = train['DIFFERENTIAL_DIAGNOSIS'].apply(convert_list_diagnosis)
validate['DIFFERENTIAL_DIAGNOSIS'] = validate['DIFFERENTIAL_DIAGNOSIS'].apply(convert_list_diagnosis)
test['DIFFERENTIAL_DIAGNOSIS'] = test['DIFFERENTIAL_DIAGNOSIS'].apply(convert_list_diagnosis)

In [7]:
# check the result
train['DIFFERENTIAL_DIAGNOSIS']

0          [[Bronchite, 0.19171203430383882], [Pneumonie,...
1          [[VIH (Primo-infection), 0.5189500564407601], ...
2          [[Bronchite, 0.11278064619119596], [Pneumonie,...
3          [[IVRS ou virémie, 0.23859396799565236], [Céph...
4          [[IVRS ou virémie, 0.23677812769175735], [Poss...
                                 ...                        
1025597    [[Épiglottite, 0.28156957795466475], [VIH (Pri...
1025598    [[Épiglottite, 0.3703962237298842], [Laryngosp...
1025599    [[Épiglottite, 0.13193905052537108], [Laryngo-...
1025600    [[Épiglottite, 0.3028258988138983], [Laryngite...
1025601    [[Épiglottite, 0.12896823203696775], [Laryngit...
Name: DIFFERENTIAL_DIAGNOSIS, Length: 1025602, dtype: object

In [8]:
# a function to one-hot-encode the differential diagnosis 
def one_hot_encode_diagnosis(df, column_name):
    # ensure that each diagnosis list is correctly formatted
    df[column_name] = df[column_name].apply(lambda x: x if isinstance(x, list) else [])

    # flatten the list of all possible diagnoses, taking only the diagnosis name (first item of each sublist)
    all_diagnoses = set(diagnosis[0] for sublist in df[column_name] for diagnosis in sublist if isinstance(diagnosis, list) and len(diagnosis) > 0)
    
    # initialize a dictionary to hold the one-hot encoded data
    one_hot_encoded_data = {diagnosis: [] for diagnosis in all_diagnoses}
    
    # populate the dictionary with 1s and 0s based on diagnosis presence
    for index, row in df.iterrows():
        present_diagnoses = {diagnosis[0] for diagnosis in row[column_name] if isinstance(diagnosis, list) and len(diagnosis) > 0}
        for diagnosis in all_diagnoses:
            one_hot_encoded_data[diagnosis].append(1 if diagnosis in present_diagnoses else 0)
    
    # convert the dictionary to a DataFrame
    one_hot_y = pd.DataFrame(one_hot_encoded_data)
    
    return one_hot_y

In [9]:
# apply the function to the dataset to do one-hot-encode for 'Differential Diagnosis'
train_y = one_hot_encode_diagnosis(train, 'DIFFERENTIAL_DIAGNOSIS')
validate_y = one_hot_encode_diagnosis(validate, 'DIFFERENTIAL_DIAGNOSIS')
test_y = one_hot_encode_diagnosis(test, 'DIFFERENTIAL_DIAGNOSIS')

In [10]:
# check the result
train_y.head(3)

Unnamed: 0,VIH (Primo-infection),Asthme exacerbé ou bronchospasme,Pneumonie,Épiglottite,Angine instable,Céphalée en grappe,RGO,Otite moyenne aigue (OMA),Anémie,Syndrome de Boerhaave,...,Hernie inguinale,Embolie pulmonaire,IVRS ou virémie,Possible influenza ou syndrome virémique typique,Oedème localisé ou généralisé sans atteinte pulmonaire associée,Sarcoïdose,TSVP,Possible NSTEMI / STEMI,Tuberculose,Laryngospasme
0,1,0,1,0,0,0,0,0,0,0,...,0,0,1,1,0,0,0,0,1,0
1,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,1,0,0,0,0
2,0,0,1,0,1,0,1,0,0,1,...,0,0,1,0,0,1,0,1,0,0


In [11]:
# a function for spliting dataset and drop columns
def data_pre(df, target_columns):
    targets = {}
    for column in target_columns:
        targets[column] = df[column].copy()
        df = df.drop(column, axis=1)
    
    data_X = df
    return data_X, targets

In [12]:
# the target columns we want to move from the train dataset
target_columns = ['PATHOLOGY', 'DIFFERENTIAL_DIAGNOSIS', 'INITIAL_EVIDENCE']
# apply the function to train dataset
train_X, train_targets = data_pre(train, target_columns)

In [13]:
# the target columns we want to move from the train dataset
target_columns = ['PATHOLOGY', 'DIFFERENTIAL_DIAGNOSIS', 'INITIAL_EVIDENCE']
# apply the function to train dataset
test_X, test_targets = data_pre(test, target_columns)

In [14]:
# check the result
train_X.head(3)

Unnamed: 0,AGE,SEX,I30,diarrhee,bode,lesions_peau_endroitducorps_@_face_dorsale_main_D_,douleurxx_irrad_@_sous_la_machoire,douleurxx_irrad_@_cartilage_thyroidien,douleurxx_irrad_@_arrière_de_tête,douleurxx_endroitducorps_@_hypochondre_G_,...,etourdissement,hernie_hiatale,douleurxx_irrad_@_trachée,douleurxx_endroitducorps_@_orteil__1__G_,ww_dd,lesions_peau_endroitducorps_@_petite_lèvre_G_,lesions_peau_elevee_@_2,j17_j18,lesions_peau_intens_@_0,lesions_peau_endroitducorps_@_vagin
0,0.165138,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,0.192661,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,0.174312,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,1,1,0


### RNN Model

In [15]:
# set the working environment
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [16]:
# build a three layers RNN model
# class RNNModel(nn.Module):
#     def __init__(self, input_size, hidden_size, output_size):
#         super(RNNModel, self).__init__()
#         self.hidden_size = hidden_size
        
#         # first RNN layer
#         self.rnn1 = nn.RNN(input_size, hidden_size, batch_first=True)
#         # second RNN layer - Takes input from the first RNN layer
#         self.rnn2 = nn.RNN(hidden_size, hidden_size, batch_first=True)
        
#         # linear layers
#         self.fc1 = nn.Linear(hidden_size, hidden_size // 2)  # Intermediate linear layer
#         self.fc2 = nn.Linear(hidden_size // 2, output_size)  # Final output layer

#     def forward(self, x):
#         h0 = torch.zeros(1, x.size(0), self.hidden_size).to(device)
#         # pass through the first RNN layer
#         out, _ = self.rnn1(x, h0)
#         # pass the output of the first RNN layer to the second RNN layer
#         out, _ = self.rnn2(out, h0)
        
#         # passing the output of the last time step through linear layers
#         out = self.fc1(out[:, -1, :])
#         out = self.fc2(out)
#         return out

class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(LSTMModel, self).__init__()
        self.hidden_size = hidden_size
        
        # First LSTM layer
        self.lstm1 = nn.LSTM(input_size, hidden_size, batch_first=True)
        # Second LSTM layer - Takes input from the first LSTM layer
        self.lstm2 = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        
        # Linear layers
        self.fc1 = nn.Linear(hidden_size, hidden_size // 2)  # Intermediate linear layer
        self.fc2 = nn.Linear(hidden_size // 2, output_size)  # Final output layer

    def forward(self, x):
        # Initialize hidden and cell states with zeros
        h0 = torch.zeros(1, x.size(0), self.hidden_size).to(device)
        c0 = torch.zeros(1, x.size(0), self.hidden_size).to(device)
        
        # Pass through the first LSTM layer
        out, _ = self.lstm1(x, (h0, c0))
        # Pass the output of the first LSTM layer to the second LSTM layer
        out, _ = self.lstm2(out, (h0, c0))
        
        # Passing the output of the last time step through linear layers
        out = self.fc1(out[:, -1, :])
        out = self.fc2(out)
        return out


input_size = train_X.shape[1]  # number of features
hidden_size = 64
output_size = len(set(train_y))  # number of unique diagnoses

model = LSTMModel(input_size, hidden_size, output_size).to(device)

In [17]:
# convert arrays to PyTorch tensors and move them to the specified device
X_train_tensor = torch.tensor(train_X.values, dtype=torch.float).to(device)
y_train_tensor = torch.tensor(train_y.values, dtype=torch.long).to(device)

In [18]:
# create TensorDataset and DataLoader for batch processing
train_data = TensorDataset(X_train_tensor, y_train_tensor)
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)

In [19]:
# define the loss function and optimizer
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [20]:
# set the number of eposhes
num_epochs = 40

In [21]:
# apply the model to the training loop
for epoch in range(num_epochs):
    for i, (features, labels) in enumerate(train_loader):
        # assuming labels are initially Long for indices, convert them to float for BCEWithLogitsLoss
        labels = labels.float()  # convert labels to float

        features = features.unsqueeze(1) # unsqueeze the data's feature to fit the model

        outputs = model(features)
        loss = criterion(outputs, labels) 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item():.4f}')


Epoch [1/40], Step [100/16026], Loss: 0.3915
Epoch [1/40], Step [200/16026], Loss: 0.3230
Epoch [1/40], Step [300/16026], Loss: 0.2710
Epoch [1/40], Step [400/16026], Loss: 0.2250
Epoch [1/40], Step [500/16026], Loss: 0.2066
Epoch [1/40], Step [600/16026], Loss: 0.2033
Epoch [1/40], Step [700/16026], Loss: 0.1720
Epoch [1/40], Step [800/16026], Loss: 0.1781
Epoch [1/40], Step [900/16026], Loss: 0.1717
Epoch [1/40], Step [1000/16026], Loss: 0.1617
Epoch [1/40], Step [1100/16026], Loss: 0.1649
Epoch [1/40], Step [1200/16026], Loss: 0.1453
Epoch [1/40], Step [1300/16026], Loss: 0.1633
Epoch [1/40], Step [1400/16026], Loss: 0.1204
Epoch [1/40], Step [1500/16026], Loss: 0.1305
Epoch [1/40], Step [1600/16026], Loss: 0.1183
Epoch [1/40], Step [1700/16026], Loss: 0.1063
Epoch [1/40], Step [1800/16026], Loss: 0.1053
Epoch [1/40], Step [1900/16026], Loss: 0.1038
Epoch [1/40], Step [2000/16026], Loss: 0.1146
Epoch [1/40], Step [2100/16026], Loss: 0.1186
Epoch [1/40], Step [2200/16026], Loss: 0.11

Epoch [2/40], Step [1900/16026], Loss: 0.0376
Epoch [2/40], Step [2000/16026], Loss: 0.0453
Epoch [2/40], Step [2100/16026], Loss: 0.0341
Epoch [2/40], Step [2200/16026], Loss: 0.0641
Epoch [2/40], Step [2300/16026], Loss: 0.0354
Epoch [2/40], Step [2400/16026], Loss: 0.0414
Epoch [2/40], Step [2500/16026], Loss: 0.0542
Epoch [2/40], Step [2600/16026], Loss: 0.0544
Epoch [2/40], Step [2700/16026], Loss: 0.0564
Epoch [2/40], Step [2800/16026], Loss: 0.0369
Epoch [2/40], Step [2900/16026], Loss: 0.0458
Epoch [2/40], Step [3000/16026], Loss: 0.0425
Epoch [2/40], Step [3100/16026], Loss: 0.0397
Epoch [2/40], Step [3200/16026], Loss: 0.0386
Epoch [2/40], Step [3300/16026], Loss: 0.0341
Epoch [2/40], Step [3400/16026], Loss: 0.0371
Epoch [2/40], Step [3500/16026], Loss: 0.0373
Epoch [2/40], Step [3600/16026], Loss: 0.0401
Epoch [2/40], Step [3700/16026], Loss: 0.0395
Epoch [2/40], Step [3800/16026], Loss: 0.0336
Epoch [2/40], Step [3900/16026], Loss: 0.0514
Epoch [2/40], Step [4000/16026], L

Epoch [3/40], Step [3600/16026], Loss: 0.0471
Epoch [3/40], Step [3700/16026], Loss: 0.0419
Epoch [3/40], Step [3800/16026], Loss: 0.0380
Epoch [3/40], Step [3900/16026], Loss: 0.0419
Epoch [3/40], Step [4000/16026], Loss: 0.0359
Epoch [3/40], Step [4100/16026], Loss: 0.0332
Epoch [3/40], Step [4200/16026], Loss: 0.0434
Epoch [3/40], Step [4300/16026], Loss: 0.0337
Epoch [3/40], Step [4400/16026], Loss: 0.0431
Epoch [3/40], Step [4500/16026], Loss: 0.0324
Epoch [3/40], Step [4600/16026], Loss: 0.0253
Epoch [3/40], Step [4700/16026], Loss: 0.0443
Epoch [3/40], Step [4800/16026], Loss: 0.0304
Epoch [3/40], Step [4900/16026], Loss: 0.0313
Epoch [3/40], Step [5000/16026], Loss: 0.0232
Epoch [3/40], Step [5100/16026], Loss: 0.0354
Epoch [3/40], Step [5200/16026], Loss: 0.0213
Epoch [3/40], Step [5300/16026], Loss: 0.0372
Epoch [3/40], Step [5400/16026], Loss: 0.0383
Epoch [3/40], Step [5500/16026], Loss: 0.0392
Epoch [3/40], Step [5600/16026], Loss: 0.0444
Epoch [3/40], Step [5700/16026], L

Epoch [4/40], Step [5300/16026], Loss: 0.0289
Epoch [4/40], Step [5400/16026], Loss: 0.0298
Epoch [4/40], Step [5500/16026], Loss: 0.0326
Epoch [4/40], Step [5600/16026], Loss: 0.0396
Epoch [4/40], Step [5700/16026], Loss: 0.0274
Epoch [4/40], Step [5800/16026], Loss: 0.0217
Epoch [4/40], Step [5900/16026], Loss: 0.0413
Epoch [4/40], Step [6000/16026], Loss: 0.0352
Epoch [4/40], Step [6100/16026], Loss: 0.0272
Epoch [4/40], Step [6200/16026], Loss: 0.0340
Epoch [4/40], Step [6300/16026], Loss: 0.0299
Epoch [4/40], Step [6400/16026], Loss: 0.0336
Epoch [4/40], Step [6500/16026], Loss: 0.0340
Epoch [4/40], Step [6600/16026], Loss: 0.0367
Epoch [4/40], Step [6700/16026], Loss: 0.0226
Epoch [4/40], Step [6800/16026], Loss: 0.0251
Epoch [4/40], Step [6900/16026], Loss: 0.0264
Epoch [4/40], Step [7000/16026], Loss: 0.0268
Epoch [4/40], Step [7100/16026], Loss: 0.0553
Epoch [4/40], Step [7200/16026], Loss: 0.0321
Epoch [4/40], Step [7300/16026], Loss: 0.0291
Epoch [4/40], Step [7400/16026], L

Epoch [5/40], Step [7000/16026], Loss: 0.0280
Epoch [5/40], Step [7100/16026], Loss: 0.0264
Epoch [5/40], Step [7200/16026], Loss: 0.0352
Epoch [5/40], Step [7300/16026], Loss: 0.0359
Epoch [5/40], Step [7400/16026], Loss: 0.0317
Epoch [5/40], Step [7500/16026], Loss: 0.0328
Epoch [5/40], Step [7600/16026], Loss: 0.0263
Epoch [5/40], Step [7700/16026], Loss: 0.0281
Epoch [5/40], Step [7800/16026], Loss: 0.0245
Epoch [5/40], Step [7900/16026], Loss: 0.0356
Epoch [5/40], Step [8000/16026], Loss: 0.0318
Epoch [5/40], Step [8100/16026], Loss: 0.0269
Epoch [5/40], Step [8200/16026], Loss: 0.0209
Epoch [5/40], Step [8300/16026], Loss: 0.0220
Epoch [5/40], Step [8400/16026], Loss: 0.0327
Epoch [5/40], Step [8500/16026], Loss: 0.0306
Epoch [5/40], Step [8600/16026], Loss: 0.0288
Epoch [5/40], Step [8700/16026], Loss: 0.0286
Epoch [5/40], Step [8800/16026], Loss: 0.0298
Epoch [5/40], Step [8900/16026], Loss: 0.0270
Epoch [5/40], Step [9000/16026], Loss: 0.0266
Epoch [5/40], Step [9100/16026], L

Epoch [6/40], Step [8700/16026], Loss: 0.0188
Epoch [6/40], Step [8800/16026], Loss: 0.0236
Epoch [6/40], Step [8900/16026], Loss: 0.0345
Epoch [6/40], Step [9000/16026], Loss: 0.0194
Epoch [6/40], Step [9100/16026], Loss: 0.0168
Epoch [6/40], Step [9200/16026], Loss: 0.0194
Epoch [6/40], Step [9300/16026], Loss: 0.0261
Epoch [6/40], Step [9400/16026], Loss: 0.0175
Epoch [6/40], Step [9500/16026], Loss: 0.0283
Epoch [6/40], Step [9600/16026], Loss: 0.0208
Epoch [6/40], Step [9700/16026], Loss: 0.0264
Epoch [6/40], Step [9800/16026], Loss: 0.0294
Epoch [6/40], Step [9900/16026], Loss: 0.0196
Epoch [6/40], Step [10000/16026], Loss: 0.0378
Epoch [6/40], Step [10100/16026], Loss: 0.0256
Epoch [6/40], Step [10200/16026], Loss: 0.0245
Epoch [6/40], Step [10300/16026], Loss: 0.0238
Epoch [6/40], Step [10400/16026], Loss: 0.0295
Epoch [6/40], Step [10500/16026], Loss: 0.0312
Epoch [6/40], Step [10600/16026], Loss: 0.0342
Epoch [6/40], Step [10700/16026], Loss: 0.0353
Epoch [6/40], Step [10800/

Epoch [7/40], Step [10400/16026], Loss: 0.0199
Epoch [7/40], Step [10500/16026], Loss: 0.0158
Epoch [7/40], Step [10600/16026], Loss: 0.0206
Epoch [7/40], Step [10700/16026], Loss: 0.0305
Epoch [7/40], Step [10800/16026], Loss: 0.0269
Epoch [7/40], Step [10900/16026], Loss: 0.0266
Epoch [7/40], Step [11000/16026], Loss: 0.0231
Epoch [7/40], Step [11100/16026], Loss: 0.0207
Epoch [7/40], Step [11200/16026], Loss: 0.0293
Epoch [7/40], Step [11300/16026], Loss: 0.0287
Epoch [7/40], Step [11400/16026], Loss: 0.0194
Epoch [7/40], Step [11500/16026], Loss: 0.0244
Epoch [7/40], Step [11600/16026], Loss: 0.0174
Epoch [7/40], Step [11700/16026], Loss: 0.0262
Epoch [7/40], Step [11800/16026], Loss: 0.0326
Epoch [7/40], Step [11900/16026], Loss: 0.0319
Epoch [7/40], Step [12000/16026], Loss: 0.0273
Epoch [7/40], Step [12100/16026], Loss: 0.0191
Epoch [7/40], Step [12200/16026], Loss: 0.0276
Epoch [7/40], Step [12300/16026], Loss: 0.0180
Epoch [7/40], Step [12400/16026], Loss: 0.0340
Epoch [7/40],

Epoch [8/40], Step [12100/16026], Loss: 0.0250
Epoch [8/40], Step [12200/16026], Loss: 0.0308
Epoch [8/40], Step [12300/16026], Loss: 0.0228
Epoch [8/40], Step [12400/16026], Loss: 0.0260
Epoch [8/40], Step [12500/16026], Loss: 0.0259
Epoch [8/40], Step [12600/16026], Loss: 0.0205
Epoch [8/40], Step [12700/16026], Loss: 0.0257
Epoch [8/40], Step [12800/16026], Loss: 0.0207
Epoch [8/40], Step [12900/16026], Loss: 0.0216
Epoch [8/40], Step [13000/16026], Loss: 0.0293
Epoch [8/40], Step [13100/16026], Loss: 0.0277
Epoch [8/40], Step [13200/16026], Loss: 0.0227
Epoch [8/40], Step [13300/16026], Loss: 0.0341
Epoch [8/40], Step [13400/16026], Loss: 0.0228
Epoch [8/40], Step [13500/16026], Loss: 0.0244
Epoch [8/40], Step [13600/16026], Loss: 0.0259
Epoch [8/40], Step [13700/16026], Loss: 0.0282
Epoch [8/40], Step [13800/16026], Loss: 0.0203
Epoch [8/40], Step [13900/16026], Loss: 0.0266
Epoch [8/40], Step [14000/16026], Loss: 0.0257
Epoch [8/40], Step [14100/16026], Loss: 0.0281
Epoch [8/40],

Epoch [9/40], Step [13800/16026], Loss: 0.0210
Epoch [9/40], Step [13900/16026], Loss: 0.0198
Epoch [9/40], Step [14000/16026], Loss: 0.0225
Epoch [9/40], Step [14100/16026], Loss: 0.0194
Epoch [9/40], Step [14200/16026], Loss: 0.0256
Epoch [9/40], Step [14300/16026], Loss: 0.0195
Epoch [9/40], Step [14400/16026], Loss: 0.0246
Epoch [9/40], Step [14500/16026], Loss: 0.0243
Epoch [9/40], Step [14600/16026], Loss: 0.0371
Epoch [9/40], Step [14700/16026], Loss: 0.0473
Epoch [9/40], Step [14800/16026], Loss: 0.0334
Epoch [9/40], Step [14900/16026], Loss: 0.0274
Epoch [9/40], Step [15000/16026], Loss: 0.0229
Epoch [9/40], Step [15100/16026], Loss: 0.0157
Epoch [9/40], Step [15200/16026], Loss: 0.0295
Epoch [9/40], Step [15300/16026], Loss: 0.0158
Epoch [9/40], Step [15400/16026], Loss: 0.0275
Epoch [9/40], Step [15500/16026], Loss: 0.0230
Epoch [9/40], Step [15600/16026], Loss: 0.0224
Epoch [9/40], Step [15700/16026], Loss: 0.0221
Epoch [9/40], Step [15800/16026], Loss: 0.0199
Epoch [9/40],

Epoch [10/40], Step [15200/16026], Loss: 0.0270
Epoch [10/40], Step [15300/16026], Loss: 0.0189
Epoch [10/40], Step [15400/16026], Loss: 0.0244
Epoch [10/40], Step [15500/16026], Loss: 0.0292
Epoch [10/40], Step [15600/16026], Loss: 0.0222
Epoch [10/40], Step [15700/16026], Loss: 0.0236
Epoch [10/40], Step [15800/16026], Loss: 0.0358
Epoch [10/40], Step [15900/16026], Loss: 0.0209
Epoch [10/40], Step [16000/16026], Loss: 0.0269
Epoch [11/40], Step [100/16026], Loss: 0.0204
Epoch [11/40], Step [200/16026], Loss: 0.0368
Epoch [11/40], Step [300/16026], Loss: 0.0215
Epoch [11/40], Step [400/16026], Loss: 0.0194
Epoch [11/40], Step [500/16026], Loss: 0.0211
Epoch [11/40], Step [600/16026], Loss: 0.0183
Epoch [11/40], Step [700/16026], Loss: 0.0248
Epoch [11/40], Step [800/16026], Loss: 0.0218
Epoch [11/40], Step [900/16026], Loss: 0.0208
Epoch [11/40], Step [1000/16026], Loss: 0.0216
Epoch [11/40], Step [1100/16026], Loss: 0.0218
Epoch [11/40], Step [1200/16026], Loss: 0.0247
Epoch [11/40]

Epoch [12/40], Step [600/16026], Loss: 0.0158
Epoch [12/40], Step [700/16026], Loss: 0.0208
Epoch [12/40], Step [800/16026], Loss: 0.0180
Epoch [12/40], Step [900/16026], Loss: 0.0213
Epoch [12/40], Step [1000/16026], Loss: 0.0197
Epoch [12/40], Step [1100/16026], Loss: 0.0146
Epoch [12/40], Step [1200/16026], Loss: 0.0206
Epoch [12/40], Step [1300/16026], Loss: 0.0226
Epoch [12/40], Step [1400/16026], Loss: 0.0142
Epoch [12/40], Step [1500/16026], Loss: 0.0251
Epoch [12/40], Step [1600/16026], Loss: 0.0269
Epoch [12/40], Step [1700/16026], Loss: 0.0158
Epoch [12/40], Step [1800/16026], Loss: 0.0149
Epoch [12/40], Step [1900/16026], Loss: 0.0272
Epoch [12/40], Step [2000/16026], Loss: 0.0281
Epoch [12/40], Step [2100/16026], Loss: 0.0259
Epoch [12/40], Step [2200/16026], Loss: 0.0149
Epoch [12/40], Step [2300/16026], Loss: 0.0150
Epoch [12/40], Step [2400/16026], Loss: 0.0210
Epoch [12/40], Step [2500/16026], Loss: 0.0232
Epoch [12/40], Step [2600/16026], Loss: 0.0312
Epoch [12/40], St

Epoch [13/40], Step [2000/16026], Loss: 0.0178
Epoch [13/40], Step [2100/16026], Loss: 0.0081
Epoch [13/40], Step [2200/16026], Loss: 0.0159
Epoch [13/40], Step [2300/16026], Loss: 0.0147
Epoch [13/40], Step [2400/16026], Loss: 0.0244
Epoch [13/40], Step [2500/16026], Loss: 0.0221
Epoch [13/40], Step [2600/16026], Loss: 0.0207
Epoch [13/40], Step [2700/16026], Loss: 0.0160
Epoch [13/40], Step [2800/16026], Loss: 0.0347
Epoch [13/40], Step [2900/16026], Loss: 0.0311
Epoch [13/40], Step [3000/16026], Loss: 0.0228
Epoch [13/40], Step [3100/16026], Loss: 0.0281
Epoch [13/40], Step [3200/16026], Loss: 0.0200
Epoch [13/40], Step [3300/16026], Loss: 0.0209
Epoch [13/40], Step [3400/16026], Loss: 0.0284
Epoch [13/40], Step [3500/16026], Loss: 0.0289
Epoch [13/40], Step [3600/16026], Loss: 0.0228
Epoch [13/40], Step [3700/16026], Loss: 0.0179
Epoch [13/40], Step [3800/16026], Loss: 0.0135
Epoch [13/40], Step [3900/16026], Loss: 0.0206
Epoch [13/40], Step [4000/16026], Loss: 0.0248
Epoch [13/40]

Epoch [14/40], Step [3400/16026], Loss: 0.0269
Epoch [14/40], Step [3500/16026], Loss: 0.0241
Epoch [14/40], Step [3600/16026], Loss: 0.0257
Epoch [14/40], Step [3700/16026], Loss: 0.0266
Epoch [14/40], Step [3800/16026], Loss: 0.0235
Epoch [14/40], Step [3900/16026], Loss: 0.0202
Epoch [14/40], Step [4000/16026], Loss: 0.0173
Epoch [14/40], Step [4100/16026], Loss: 0.0227
Epoch [14/40], Step [4200/16026], Loss: 0.0209
Epoch [14/40], Step [4300/16026], Loss: 0.0255
Epoch [14/40], Step [4400/16026], Loss: 0.0259
Epoch [14/40], Step [4500/16026], Loss: 0.0155
Epoch [14/40], Step [4600/16026], Loss: 0.0133
Epoch [14/40], Step [4700/16026], Loss: 0.0137
Epoch [14/40], Step [4800/16026], Loss: 0.0236
Epoch [14/40], Step [4900/16026], Loss: 0.0275
Epoch [14/40], Step [5000/16026], Loss: 0.0261
Epoch [14/40], Step [5100/16026], Loss: 0.0199
Epoch [14/40], Step [5200/16026], Loss: 0.0166
Epoch [14/40], Step [5300/16026], Loss: 0.0216
Epoch [14/40], Step [5400/16026], Loss: 0.0163
Epoch [14/40]

Epoch [15/40], Step [4800/16026], Loss: 0.0244
Epoch [15/40], Step [4900/16026], Loss: 0.0119
Epoch [15/40], Step [5000/16026], Loss: 0.0158
Epoch [15/40], Step [5100/16026], Loss: 0.0187
Epoch [15/40], Step [5200/16026], Loss: 0.0180
Epoch [15/40], Step [5300/16026], Loss: 0.0254
Epoch [15/40], Step [5400/16026], Loss: 0.0138
Epoch [15/40], Step [5500/16026], Loss: 0.0114
Epoch [15/40], Step [5600/16026], Loss: 0.0203
Epoch [15/40], Step [5700/16026], Loss: 0.0256
Epoch [15/40], Step [5800/16026], Loss: 0.0239
Epoch [15/40], Step [5900/16026], Loss: 0.0202
Epoch [15/40], Step [6000/16026], Loss: 0.0275
Epoch [15/40], Step [6100/16026], Loss: 0.0209
Epoch [15/40], Step [6200/16026], Loss: 0.0171
Epoch [15/40], Step [6300/16026], Loss: 0.0330
Epoch [15/40], Step [6400/16026], Loss: 0.0242
Epoch [15/40], Step [6500/16026], Loss: 0.0175
Epoch [15/40], Step [6600/16026], Loss: 0.0171
Epoch [15/40], Step [6700/16026], Loss: 0.0283
Epoch [15/40], Step [6800/16026], Loss: 0.0228
Epoch [15/40]

Epoch [16/40], Step [6200/16026], Loss: 0.0123
Epoch [16/40], Step [6300/16026], Loss: 0.0262
Epoch [16/40], Step [6400/16026], Loss: 0.0255
Epoch [16/40], Step [6500/16026], Loss: 0.0119
Epoch [16/40], Step [6600/16026], Loss: 0.0220
Epoch [16/40], Step [6700/16026], Loss: 0.0239
Epoch [16/40], Step [6800/16026], Loss: 0.0247
Epoch [16/40], Step [6900/16026], Loss: 0.0167
Epoch [16/40], Step [7000/16026], Loss: 0.0166
Epoch [16/40], Step [7100/16026], Loss: 0.0239
Epoch [16/40], Step [7200/16026], Loss: 0.0216
Epoch [16/40], Step [7300/16026], Loss: 0.0192
Epoch [16/40], Step [7400/16026], Loss: 0.0093
Epoch [16/40], Step [7500/16026], Loss: 0.0186
Epoch [16/40], Step [7600/16026], Loss: 0.0193
Epoch [16/40], Step [7700/16026], Loss: 0.0228
Epoch [16/40], Step [7800/16026], Loss: 0.0135
Epoch [16/40], Step [7900/16026], Loss: 0.0275
Epoch [16/40], Step [8000/16026], Loss: 0.0277
Epoch [16/40], Step [8100/16026], Loss: 0.0266
Epoch [16/40], Step [8200/16026], Loss: 0.0144
Epoch [16/40]

Epoch [17/40], Step [7600/16026], Loss: 0.0183
Epoch [17/40], Step [7700/16026], Loss: 0.0557
Epoch [17/40], Step [7800/16026], Loss: 0.0233
Epoch [17/40], Step [7900/16026], Loss: 0.0206
Epoch [17/40], Step [8000/16026], Loss: 0.0210
Epoch [17/40], Step [8100/16026], Loss: 0.0125
Epoch [17/40], Step [8200/16026], Loss: 0.0165
Epoch [17/40], Step [8300/16026], Loss: 0.0186
Epoch [17/40], Step [8400/16026], Loss: 0.0190
Epoch [17/40], Step [8500/16026], Loss: 0.0250
Epoch [17/40], Step [8600/16026], Loss: 0.0186
Epoch [17/40], Step [8700/16026], Loss: 0.0173
Epoch [17/40], Step [8800/16026], Loss: 0.0264
Epoch [17/40], Step [8900/16026], Loss: 0.0202
Epoch [17/40], Step [9000/16026], Loss: 0.0220
Epoch [17/40], Step [9100/16026], Loss: 0.0240
Epoch [17/40], Step [9200/16026], Loss: 0.0203
Epoch [17/40], Step [9300/16026], Loss: 0.0208
Epoch [17/40], Step [9400/16026], Loss: 0.0186
Epoch [17/40], Step [9500/16026], Loss: 0.0318
Epoch [17/40], Step [9600/16026], Loss: 0.0212
Epoch [17/40]

Epoch [18/40], Step [9000/16026], Loss: 0.0204
Epoch [18/40], Step [9100/16026], Loss: 0.0409
Epoch [18/40], Step [9200/16026], Loss: 0.0278
Epoch [18/40], Step [9300/16026], Loss: 0.0165
Epoch [18/40], Step [9400/16026], Loss: 0.0119
Epoch [18/40], Step [9500/16026], Loss: 0.0133
Epoch [18/40], Step [9600/16026], Loss: 0.0294
Epoch [18/40], Step [9700/16026], Loss: 0.0231
Epoch [18/40], Step [9800/16026], Loss: 0.0344
Epoch [18/40], Step [9900/16026], Loss: 0.0284
Epoch [18/40], Step [10000/16026], Loss: 0.0266
Epoch [18/40], Step [10100/16026], Loss: 0.0211
Epoch [18/40], Step [10200/16026], Loss: 0.0133
Epoch [18/40], Step [10300/16026], Loss: 0.0154
Epoch [18/40], Step [10400/16026], Loss: 0.0273
Epoch [18/40], Step [10500/16026], Loss: 0.0173
Epoch [18/40], Step [10600/16026], Loss: 0.0192
Epoch [18/40], Step [10700/16026], Loss: 0.0219
Epoch [18/40], Step [10800/16026], Loss: 0.0227
Epoch [18/40], Step [10900/16026], Loss: 0.0245
Epoch [18/40], Step [11000/16026], Loss: 0.0152
Ep

Epoch [19/40], Step [10400/16026], Loss: 0.0248
Epoch [19/40], Step [10500/16026], Loss: 0.0184
Epoch [19/40], Step [10600/16026], Loss: 0.0151
Epoch [19/40], Step [10700/16026], Loss: 0.0192
Epoch [19/40], Step [10800/16026], Loss: 0.0241
Epoch [19/40], Step [10900/16026], Loss: 0.0251
Epoch [19/40], Step [11000/16026], Loss: 0.0234
Epoch [19/40], Step [11100/16026], Loss: 0.0192
Epoch [19/40], Step [11200/16026], Loss: 0.0167
Epoch [19/40], Step [11300/16026], Loss: 0.0197
Epoch [19/40], Step [11400/16026], Loss: 0.0341
Epoch [19/40], Step [11500/16026], Loss: 0.0266
Epoch [19/40], Step [11600/16026], Loss: 0.0204
Epoch [19/40], Step [11700/16026], Loss: 0.0134
Epoch [19/40], Step [11800/16026], Loss: 0.0123
Epoch [19/40], Step [11900/16026], Loss: 0.0296
Epoch [19/40], Step [12000/16026], Loss: 0.0269
Epoch [19/40], Step [12100/16026], Loss: 0.0198
Epoch [19/40], Step [12200/16026], Loss: 0.0231
Epoch [19/40], Step [12300/16026], Loss: 0.0160
Epoch [19/40], Step [12400/16026], Loss:

Epoch [20/40], Step [11700/16026], Loss: 0.0146
Epoch [20/40], Step [11800/16026], Loss: 0.0214
Epoch [20/40], Step [11900/16026], Loss: 0.0234
Epoch [20/40], Step [12000/16026], Loss: 0.0255
Epoch [20/40], Step [12100/16026], Loss: 0.0245
Epoch [20/40], Step [12200/16026], Loss: 0.0313
Epoch [20/40], Step [12300/16026], Loss: 0.0309
Epoch [20/40], Step [12400/16026], Loss: 0.0289
Epoch [20/40], Step [12500/16026], Loss: 0.0198
Epoch [20/40], Step [12600/16026], Loss: 0.0126
Epoch [20/40], Step [12700/16026], Loss: 0.0252
Epoch [20/40], Step [12800/16026], Loss: 0.0149
Epoch [20/40], Step [12900/16026], Loss: 0.0170
Epoch [20/40], Step [13000/16026], Loss: 0.0142
Epoch [20/40], Step [13100/16026], Loss: 0.0243
Epoch [20/40], Step [13200/16026], Loss: 0.0200
Epoch [20/40], Step [13300/16026], Loss: 0.0213
Epoch [20/40], Step [13400/16026], Loss: 0.0322
Epoch [20/40], Step [13500/16026], Loss: 0.0224
Epoch [20/40], Step [13600/16026], Loss: 0.0253
Epoch [20/40], Step [13700/16026], Loss:

Epoch [21/40], Step [13000/16026], Loss: 0.0180
Epoch [21/40], Step [13100/16026], Loss: 0.0221
Epoch [21/40], Step [13200/16026], Loss: 0.0298
Epoch [21/40], Step [13300/16026], Loss: 0.0224
Epoch [21/40], Step [13400/16026], Loss: 0.0315
Epoch [21/40], Step [13500/16026], Loss: 0.0226
Epoch [21/40], Step [13600/16026], Loss: 0.0183
Epoch [21/40], Step [13700/16026], Loss: 0.0166
Epoch [21/40], Step [13800/16026], Loss: 0.0145
Epoch [21/40], Step [13900/16026], Loss: 0.0184
Epoch [21/40], Step [14000/16026], Loss: 0.0178
Epoch [21/40], Step [14100/16026], Loss: 0.0179
Epoch [21/40], Step [14200/16026], Loss: 0.0187
Epoch [21/40], Step [14300/16026], Loss: 0.0143
Epoch [21/40], Step [14400/16026], Loss: 0.0345
Epoch [21/40], Step [14500/16026], Loss: 0.0166
Epoch [21/40], Step [14600/16026], Loss: 0.0162
Epoch [21/40], Step [14700/16026], Loss: 0.0254
Epoch [21/40], Step [14800/16026], Loss: 0.0220
Epoch [21/40], Step [14900/16026], Loss: 0.0121
Epoch [21/40], Step [15000/16026], Loss:

Epoch [22/40], Step [14300/16026], Loss: 0.0220
Epoch [22/40], Step [14400/16026], Loss: 0.0227
Epoch [22/40], Step [14500/16026], Loss: 0.0276
Epoch [22/40], Step [14600/16026], Loss: 0.0286
Epoch [22/40], Step [14700/16026], Loss: 0.0284
Epoch [22/40], Step [14800/16026], Loss: 0.0217
Epoch [22/40], Step [14900/16026], Loss: 0.0306
Epoch [22/40], Step [15000/16026], Loss: 0.0177
Epoch [22/40], Step [15100/16026], Loss: 0.0168
Epoch [22/40], Step [15200/16026], Loss: 0.0280
Epoch [22/40], Step [15300/16026], Loss: 0.0182
Epoch [22/40], Step [15400/16026], Loss: 0.0158
Epoch [22/40], Step [15500/16026], Loss: 0.0189
Epoch [22/40], Step [15600/16026], Loss: 0.0262
Epoch [22/40], Step [15700/16026], Loss: 0.0231
Epoch [22/40], Step [15800/16026], Loss: 0.0231
Epoch [22/40], Step [15900/16026], Loss: 0.0185
Epoch [22/40], Step [16000/16026], Loss: 0.0194
Epoch [23/40], Step [100/16026], Loss: 0.0221
Epoch [23/40], Step [200/16026], Loss: 0.0174
Epoch [23/40], Step [300/16026], Loss: 0.016

Epoch [23/40], Step [15600/16026], Loss: 0.0190
Epoch [23/40], Step [15700/16026], Loss: 0.0183
Epoch [23/40], Step [15800/16026], Loss: 0.0205
Epoch [23/40], Step [15900/16026], Loss: 0.0194
Epoch [23/40], Step [16000/16026], Loss: 0.0158
Epoch [24/40], Step [100/16026], Loss: 0.0156
Epoch [24/40], Step [200/16026], Loss: 0.0191
Epoch [24/40], Step [300/16026], Loss: 0.0179
Epoch [24/40], Step [400/16026], Loss: 0.0192
Epoch [24/40], Step [500/16026], Loss: 0.0257
Epoch [24/40], Step [600/16026], Loss: 0.0220
Epoch [24/40], Step [700/16026], Loss: 0.0166
Epoch [24/40], Step [800/16026], Loss: 0.0261
Epoch [24/40], Step [900/16026], Loss: 0.0194
Epoch [24/40], Step [1000/16026], Loss: 0.0136
Epoch [24/40], Step [1100/16026], Loss: 0.0188
Epoch [24/40], Step [1200/16026], Loss: 0.0217
Epoch [24/40], Step [1300/16026], Loss: 0.0137
Epoch [24/40], Step [1400/16026], Loss: 0.0118
Epoch [24/40], Step [1500/16026], Loss: 0.0153
Epoch [24/40], Step [1600/16026], Loss: 0.0142
Epoch [24/40], St

Epoch [25/40], Step [1000/16026], Loss: 0.0195
Epoch [25/40], Step [1100/16026], Loss: 0.0212
Epoch [25/40], Step [1200/16026], Loss: 0.0232
Epoch [25/40], Step [1300/16026], Loss: 0.0218
Epoch [25/40], Step [1400/16026], Loss: 0.0232
Epoch [25/40], Step [1500/16026], Loss: 0.0108
Epoch [25/40], Step [1600/16026], Loss: 0.0195
Epoch [25/40], Step [1700/16026], Loss: 0.0138
Epoch [25/40], Step [1800/16026], Loss: 0.0172
Epoch [25/40], Step [1900/16026], Loss: 0.0155
Epoch [25/40], Step [2000/16026], Loss: 0.0219
Epoch [25/40], Step [2100/16026], Loss: 0.0207
Epoch [25/40], Step [2200/16026], Loss: 0.0213
Epoch [25/40], Step [2300/16026], Loss: 0.0682
Epoch [25/40], Step [2400/16026], Loss: 0.0267
Epoch [25/40], Step [2500/16026], Loss: 0.0315
Epoch [25/40], Step [2600/16026], Loss: 0.0156
Epoch [25/40], Step [2700/16026], Loss: 0.0214
Epoch [25/40], Step [2800/16026], Loss: 0.0290
Epoch [25/40], Step [2900/16026], Loss: 0.0216
Epoch [25/40], Step [3000/16026], Loss: 0.0191
Epoch [25/40]

Epoch [26/40], Step [2400/16026], Loss: 0.0275
Epoch [26/40], Step [2500/16026], Loss: 0.0111
Epoch [26/40], Step [2600/16026], Loss: 0.0153
Epoch [26/40], Step [2700/16026], Loss: 0.0159
Epoch [26/40], Step [2800/16026], Loss: 0.0202
Epoch [26/40], Step [2900/16026], Loss: 0.0208
Epoch [26/40], Step [3000/16026], Loss: 0.0218
Epoch [26/40], Step [3100/16026], Loss: 0.0170
Epoch [26/40], Step [3200/16026], Loss: 0.0152
Epoch [26/40], Step [3300/16026], Loss: 0.0183
Epoch [26/40], Step [3400/16026], Loss: 0.0220
Epoch [26/40], Step [3500/16026], Loss: 0.0206
Epoch [26/40], Step [3600/16026], Loss: 0.0166
Epoch [26/40], Step [3700/16026], Loss: 0.0234
Epoch [26/40], Step [3800/16026], Loss: 0.0114
Epoch [26/40], Step [3900/16026], Loss: 0.0199
Epoch [26/40], Step [4000/16026], Loss: 0.0117
Epoch [26/40], Step [4100/16026], Loss: 0.0203
Epoch [26/40], Step [4200/16026], Loss: 0.0160
Epoch [26/40], Step [4300/16026], Loss: 0.0218
Epoch [26/40], Step [4400/16026], Loss: 0.0226
Epoch [26/40]

Epoch [27/40], Step [3800/16026], Loss: 0.0267
Epoch [27/40], Step [3900/16026], Loss: 0.0096
Epoch [27/40], Step [4000/16026], Loss: 0.0172
Epoch [27/40], Step [4100/16026], Loss: 0.0128
Epoch [27/40], Step [4200/16026], Loss: 0.0221
Epoch [27/40], Step [4300/16026], Loss: 0.0203
Epoch [27/40], Step [4400/16026], Loss: 0.0189
Epoch [27/40], Step [4500/16026], Loss: 0.0204
Epoch [27/40], Step [4600/16026], Loss: 0.0209
Epoch [27/40], Step [4700/16026], Loss: 0.0221
Epoch [27/40], Step [4800/16026], Loss: 0.0226
Epoch [27/40], Step [4900/16026], Loss: 0.0233
Epoch [27/40], Step [5000/16026], Loss: 0.0176
Epoch [27/40], Step [5100/16026], Loss: 0.0252
Epoch [27/40], Step [5200/16026], Loss: 0.0142
Epoch [27/40], Step [5300/16026], Loss: 0.0236
Epoch [27/40], Step [5400/16026], Loss: 0.0225
Epoch [27/40], Step [5500/16026], Loss: 0.0219
Epoch [27/40], Step [5600/16026], Loss: 0.0190
Epoch [27/40], Step [5700/16026], Loss: 0.0157
Epoch [27/40], Step [5800/16026], Loss: 0.0184
Epoch [27/40]

Epoch [28/40], Step [5200/16026], Loss: 0.0169
Epoch [28/40], Step [5300/16026], Loss: 0.0182
Epoch [28/40], Step [5400/16026], Loss: 0.0220
Epoch [28/40], Step [5500/16026], Loss: 0.0233
Epoch [28/40], Step [5600/16026], Loss: 0.0338
Epoch [28/40], Step [5700/16026], Loss: 0.0172
Epoch [28/40], Step [5800/16026], Loss: 0.0280
Epoch [28/40], Step [5900/16026], Loss: 0.0205
Epoch [28/40], Step [6000/16026], Loss: 0.0207
Epoch [28/40], Step [6100/16026], Loss: 0.0151
Epoch [28/40], Step [6200/16026], Loss: 0.0104
Epoch [28/40], Step [6300/16026], Loss: 0.0110
Epoch [28/40], Step [6400/16026], Loss: 0.0211
Epoch [28/40], Step [6500/16026], Loss: 0.0290
Epoch [28/40], Step [6600/16026], Loss: 0.0214
Epoch [28/40], Step [6700/16026], Loss: 0.0177
Epoch [28/40], Step [6800/16026], Loss: 0.0250
Epoch [28/40], Step [6900/16026], Loss: 0.0275
Epoch [28/40], Step [7000/16026], Loss: 0.0145
Epoch [28/40], Step [7100/16026], Loss: 0.0206
Epoch [28/40], Step [7200/16026], Loss: 0.0185
Epoch [28/40]

Epoch [29/40], Step [6600/16026], Loss: 0.0151
Epoch [29/40], Step [6700/16026], Loss: 0.0187
Epoch [29/40], Step [6800/16026], Loss: 0.0188
Epoch [29/40], Step [6900/16026], Loss: 0.0239
Epoch [29/40], Step [7000/16026], Loss: 0.0164
Epoch [29/40], Step [7100/16026], Loss: 0.0324
Epoch [29/40], Step [7200/16026], Loss: 0.0160
Epoch [29/40], Step [7300/16026], Loss: 0.0326
Epoch [29/40], Step [7400/16026], Loss: 0.0116
Epoch [29/40], Step [7500/16026], Loss: 0.0243
Epoch [29/40], Step [7600/16026], Loss: 0.0271
Epoch [29/40], Step [7700/16026], Loss: 0.0151
Epoch [29/40], Step [7800/16026], Loss: 0.0194
Epoch [29/40], Step [7900/16026], Loss: 0.0163
Epoch [29/40], Step [8000/16026], Loss: 0.0117
Epoch [29/40], Step [8100/16026], Loss: 0.0319
Epoch [29/40], Step [8200/16026], Loss: 0.0230
Epoch [29/40], Step [8300/16026], Loss: 0.0224
Epoch [29/40], Step [8400/16026], Loss: 0.0166
Epoch [29/40], Step [8500/16026], Loss: 0.0260
Epoch [29/40], Step [8600/16026], Loss: 0.0086
Epoch [29/40]

Epoch [30/40], Step [8000/16026], Loss: 0.0283
Epoch [30/40], Step [8100/16026], Loss: 0.0141
Epoch [30/40], Step [8200/16026], Loss: 0.0183
Epoch [30/40], Step [8300/16026], Loss: 0.0126
Epoch [30/40], Step [8400/16026], Loss: 0.0176
Epoch [30/40], Step [8500/16026], Loss: 0.0173
Epoch [30/40], Step [8600/16026], Loss: 0.0211
Epoch [30/40], Step [8700/16026], Loss: 0.0141
Epoch [30/40], Step [8800/16026], Loss: 0.0274
Epoch [30/40], Step [8900/16026], Loss: 0.0237
Epoch [30/40], Step [9000/16026], Loss: 0.0209
Epoch [30/40], Step [9100/16026], Loss: 0.0319
Epoch [30/40], Step [9200/16026], Loss: 0.0259
Epoch [30/40], Step [9300/16026], Loss: 0.0157
Epoch [30/40], Step [9400/16026], Loss: 0.0192
Epoch [30/40], Step [9500/16026], Loss: 0.0190
Epoch [30/40], Step [9600/16026], Loss: 0.0184
Epoch [30/40], Step [9700/16026], Loss: 0.0111
Epoch [30/40], Step [9800/16026], Loss: 0.0233
Epoch [30/40], Step [9900/16026], Loss: 0.0171
Epoch [30/40], Step [10000/16026], Loss: 0.0409
Epoch [30/40

Epoch [31/40], Step [9400/16026], Loss: 0.0103
Epoch [31/40], Step [9500/16026], Loss: 0.0239
Epoch [31/40], Step [9600/16026], Loss: 0.0139
Epoch [31/40], Step [9700/16026], Loss: 0.0192
Epoch [31/40], Step [9800/16026], Loss: 0.0288
Epoch [31/40], Step [9900/16026], Loss: 0.0144
Epoch [31/40], Step [10000/16026], Loss: 0.0193
Epoch [31/40], Step [10100/16026], Loss: 0.0149
Epoch [31/40], Step [10200/16026], Loss: 0.0234
Epoch [31/40], Step [10300/16026], Loss: 0.0133
Epoch [31/40], Step [10400/16026], Loss: 0.0206
Epoch [31/40], Step [10500/16026], Loss: 0.0207
Epoch [31/40], Step [10600/16026], Loss: 0.0141
Epoch [31/40], Step [10700/16026], Loss: 0.0371
Epoch [31/40], Step [10800/16026], Loss: 0.0186
Epoch [31/40], Step [10900/16026], Loss: 0.0132
Epoch [31/40], Step [11000/16026], Loss: 0.0279
Epoch [31/40], Step [11100/16026], Loss: 0.0253
Epoch [31/40], Step [11200/16026], Loss: 0.0120
Epoch [31/40], Step [11300/16026], Loss: 0.0188
Epoch [31/40], Step [11400/16026], Loss: 0.011

Epoch [32/40], Step [10800/16026], Loss: 0.0137
Epoch [32/40], Step [10900/16026], Loss: 0.0106
Epoch [32/40], Step [11000/16026], Loss: 0.0176
Epoch [32/40], Step [11100/16026], Loss: 0.0088
Epoch [32/40], Step [11200/16026], Loss: 0.0212
Epoch [32/40], Step [11300/16026], Loss: 0.0178
Epoch [32/40], Step [11400/16026], Loss: 0.0272
Epoch [32/40], Step [11500/16026], Loss: 0.0158
Epoch [32/40], Step [11600/16026], Loss: 0.0233
Epoch [32/40], Step [11700/16026], Loss: 0.0166
Epoch [32/40], Step [11800/16026], Loss: 0.0188
Epoch [32/40], Step [11900/16026], Loss: 0.0221
Epoch [32/40], Step [12000/16026], Loss: 0.0242
Epoch [32/40], Step [12100/16026], Loss: 0.0342
Epoch [32/40], Step [12200/16026], Loss: 0.0225
Epoch [32/40], Step [12300/16026], Loss: 0.0299
Epoch [32/40], Step [12400/16026], Loss: 0.0224
Epoch [32/40], Step [12500/16026], Loss: 0.0218
Epoch [32/40], Step [12600/16026], Loss: 0.0183
Epoch [32/40], Step [12700/16026], Loss: 0.0317
Epoch [32/40], Step [12800/16026], Loss:

Epoch [33/40], Step [12100/16026], Loss: 0.0202
Epoch [33/40], Step [12200/16026], Loss: 0.0225
Epoch [33/40], Step [12300/16026], Loss: 0.0212
Epoch [33/40], Step [12400/16026], Loss: 0.0251
Epoch [33/40], Step [12500/16026], Loss: 0.0265
Epoch [33/40], Step [12600/16026], Loss: 0.0142
Epoch [33/40], Step [12700/16026], Loss: 0.0284
Epoch [33/40], Step [12800/16026], Loss: 0.0147
Epoch [33/40], Step [12900/16026], Loss: 0.0162
Epoch [33/40], Step [13000/16026], Loss: 0.0193
Epoch [33/40], Step [13100/16026], Loss: 0.0128
Epoch [33/40], Step [13200/16026], Loss: 0.0264
Epoch [33/40], Step [13300/16026], Loss: 0.0282
Epoch [33/40], Step [13400/16026], Loss: 0.0473
Epoch [33/40], Step [13500/16026], Loss: 0.0143
Epoch [33/40], Step [13600/16026], Loss: 0.0192
Epoch [33/40], Step [13700/16026], Loss: 0.0154
Epoch [33/40], Step [13800/16026], Loss: 0.0149
Epoch [33/40], Step [13900/16026], Loss: 0.0199
Epoch [33/40], Step [14000/16026], Loss: 0.0247
Epoch [33/40], Step [14100/16026], Loss:

Epoch [34/40], Step [13400/16026], Loss: 0.0161
Epoch [34/40], Step [13500/16026], Loss: 0.0300
Epoch [34/40], Step [13600/16026], Loss: 0.0260
Epoch [34/40], Step [13700/16026], Loss: 0.0269
Epoch [34/40], Step [13800/16026], Loss: 0.0203
Epoch [34/40], Step [13900/16026], Loss: 0.0192
Epoch [34/40], Step [14000/16026], Loss: 0.0169
Epoch [34/40], Step [14100/16026], Loss: 0.0153
Epoch [34/40], Step [14200/16026], Loss: 0.0164
Epoch [34/40], Step [14300/16026], Loss: 0.0122
Epoch [34/40], Step [14400/16026], Loss: 0.0221
Epoch [34/40], Step [14500/16026], Loss: 0.0209
Epoch [34/40], Step [14600/16026], Loss: 0.0275
Epoch [34/40], Step [14700/16026], Loss: 0.0217
Epoch [34/40], Step [14800/16026], Loss: 0.0170
Epoch [34/40], Step [14900/16026], Loss: 0.0388
Epoch [34/40], Step [15000/16026], Loss: 0.0173
Epoch [34/40], Step [15100/16026], Loss: 0.0192
Epoch [34/40], Step [15200/16026], Loss: 0.0236
Epoch [34/40], Step [15300/16026], Loss: 0.0174
Epoch [34/40], Step [15400/16026], Loss:

Epoch [35/40], Step [14700/16026], Loss: 0.0194
Epoch [35/40], Step [14800/16026], Loss: 0.0145
Epoch [35/40], Step [14900/16026], Loss: 0.0242
Epoch [35/40], Step [15000/16026], Loss: 0.0162
Epoch [35/40], Step [15100/16026], Loss: 0.0189
Epoch [35/40], Step [15200/16026], Loss: 0.0192
Epoch [35/40], Step [15300/16026], Loss: 0.0188
Epoch [35/40], Step [15400/16026], Loss: 0.0186
Epoch [35/40], Step [15500/16026], Loss: 0.0193
Epoch [35/40], Step [15600/16026], Loss: 0.0142
Epoch [35/40], Step [15700/16026], Loss: 0.0217
Epoch [35/40], Step [15800/16026], Loss: 0.0252
Epoch [35/40], Step [15900/16026], Loss: 0.0113
Epoch [35/40], Step [16000/16026], Loss: 0.0182
Epoch [36/40], Step [100/16026], Loss: 0.0239
Epoch [36/40], Step [200/16026], Loss: 0.0207
Epoch [36/40], Step [300/16026], Loss: 0.0155
Epoch [36/40], Step [400/16026], Loss: 0.0389
Epoch [36/40], Step [500/16026], Loss: 0.0232
Epoch [36/40], Step [600/16026], Loss: 0.0179
Epoch [36/40], Step [700/16026], Loss: 0.0340
Epoch 

Epoch [36/40], Step [16000/16026], Loss: 0.0111
Epoch [37/40], Step [100/16026], Loss: 0.0206
Epoch [37/40], Step [200/16026], Loss: 0.0169
Epoch [37/40], Step [300/16026], Loss: 0.0225
Epoch [37/40], Step [400/16026], Loss: 0.0134
Epoch [37/40], Step [500/16026], Loss: 0.0137
Epoch [37/40], Step [600/16026], Loss: 0.0175
Epoch [37/40], Step [700/16026], Loss: 0.0217
Epoch [37/40], Step [800/16026], Loss: 0.0131
Epoch [37/40], Step [900/16026], Loss: 0.0155
Epoch [37/40], Step [1000/16026], Loss: 0.0279
Epoch [37/40], Step [1100/16026], Loss: 0.0225
Epoch [37/40], Step [1200/16026], Loss: 0.0249
Epoch [37/40], Step [1300/16026], Loss: 0.0114
Epoch [37/40], Step [1400/16026], Loss: 0.0144
Epoch [37/40], Step [1500/16026], Loss: 0.0182
Epoch [37/40], Step [1600/16026], Loss: 0.0239
Epoch [37/40], Step [1700/16026], Loss: 0.0190
Epoch [37/40], Step [1800/16026], Loss: 0.0187
Epoch [37/40], Step [1900/16026], Loss: 0.0155
Epoch [37/40], Step [2000/16026], Loss: 0.0162
Epoch [37/40], Step [

Epoch [38/40], Step [1400/16026], Loss: 0.0197
Epoch [38/40], Step [1500/16026], Loss: 0.0215
Epoch [38/40], Step [1600/16026], Loss: 0.0174
Epoch [38/40], Step [1700/16026], Loss: 0.0217
Epoch [38/40], Step [1800/16026], Loss: 0.0250
Epoch [38/40], Step [1900/16026], Loss: 0.0110
Epoch [38/40], Step [2000/16026], Loss: 0.0186
Epoch [38/40], Step [2100/16026], Loss: 0.0222
Epoch [38/40], Step [2200/16026], Loss: 0.0184
Epoch [38/40], Step [2300/16026], Loss: 0.0284
Epoch [38/40], Step [2400/16026], Loss: 0.0164
Epoch [38/40], Step [2500/16026], Loss: 0.0172
Epoch [38/40], Step [2600/16026], Loss: 0.0135
Epoch [38/40], Step [2700/16026], Loss: 0.0190
Epoch [38/40], Step [2800/16026], Loss: 0.0185
Epoch [38/40], Step [2900/16026], Loss: 0.0200
Epoch [38/40], Step [3000/16026], Loss: 0.0147
Epoch [38/40], Step [3100/16026], Loss: 0.0137
Epoch [38/40], Step [3200/16026], Loss: 0.0145
Epoch [38/40], Step [3300/16026], Loss: 0.0186
Epoch [38/40], Step [3400/16026], Loss: 0.0216
Epoch [38/40]

Epoch [39/40], Step [2800/16026], Loss: 0.0126
Epoch [39/40], Step [2900/16026], Loss: 0.0222
Epoch [39/40], Step [3000/16026], Loss: 0.0177
Epoch [39/40], Step [3100/16026], Loss: 0.0226
Epoch [39/40], Step [3200/16026], Loss: 0.0149
Epoch [39/40], Step [3300/16026], Loss: 0.0235
Epoch [39/40], Step [3400/16026], Loss: 0.0227
Epoch [39/40], Step [3500/16026], Loss: 0.0156
Epoch [39/40], Step [3600/16026], Loss: 0.0320
Epoch [39/40], Step [3700/16026], Loss: 0.0145
Epoch [39/40], Step [3800/16026], Loss: 0.0127
Epoch [39/40], Step [3900/16026], Loss: 0.0211
Epoch [39/40], Step [4000/16026], Loss: 0.0120
Epoch [39/40], Step [4100/16026], Loss: 0.0282
Epoch [39/40], Step [4200/16026], Loss: 0.0187
Epoch [39/40], Step [4300/16026], Loss: 0.0213
Epoch [39/40], Step [4400/16026], Loss: 0.0290
Epoch [39/40], Step [4500/16026], Loss: 0.0234
Epoch [39/40], Step [4600/16026], Loss: 0.0239
Epoch [39/40], Step [4700/16026], Loss: 0.0202
Epoch [39/40], Step [4800/16026], Loss: 0.0188
Epoch [39/40]

Epoch [40/40], Step [4200/16026], Loss: 0.0545
Epoch [40/40], Step [4300/16026], Loss: 0.0198
Epoch [40/40], Step [4400/16026], Loss: 0.0219
Epoch [40/40], Step [4500/16026], Loss: 0.0242
Epoch [40/40], Step [4600/16026], Loss: 0.0184
Epoch [40/40], Step [4700/16026], Loss: 0.0158
Epoch [40/40], Step [4800/16026], Loss: 0.0218
Epoch [40/40], Step [4900/16026], Loss: 0.0110
Epoch [40/40], Step [5000/16026], Loss: 0.0190
Epoch [40/40], Step [5100/16026], Loss: 0.0251
Epoch [40/40], Step [5200/16026], Loss: 0.0176
Epoch [40/40], Step [5300/16026], Loss: 0.0205
Epoch [40/40], Step [5400/16026], Loss: 0.0160
Epoch [40/40], Step [5500/16026], Loss: 0.0210
Epoch [40/40], Step [5600/16026], Loss: 0.0181
Epoch [40/40], Step [5700/16026], Loss: 0.0180
Epoch [40/40], Step [5800/16026], Loss: 0.0139
Epoch [40/40], Step [5900/16026], Loss: 0.0263
Epoch [40/40], Step [6000/16026], Loss: 0.0227
Epoch [40/40], Step [6100/16026], Loss: 0.0190
Epoch [40/40], Step [6200/16026], Loss: 0.0185
Epoch [40/40]

### Apply the model to test dataset

In [22]:
# convert test data to PyTorch tensors and move them to the device
X_test_tensor = torch.tensor(test_X.values, dtype=torch.float).to(device)
y_test_tensor = torch.tensor(test_y.values, dtype=torch.long).to(device)  # If your task is classification

In [23]:
# apply the model to test dataset
model.eval()  # set the model to evaluation mode
all_predictions = []
all_labels = []
with torch.no_grad():
    for features in DataLoader(X_test_tensor, batch_size=32):
        features = features.unsqueeze(1) 
        outputs = model(features)
        predictions = torch.sigmoid(outputs).round()
        all_predictions.extend(predictions.cpu().numpy())

all_labels = y_test_tensor.cpu().numpy()

### Evaluate the model

In [24]:
def compute_f1(p, r):
    if isinstance(p, (list, tuple)):
        p = np.array(p)
    if isinstance(r, (list, tuple)):
        r = np.array(r)
    denom = p + r
    return (2 * p * r) / (denom + 1e-10)

def compute_metric(gt_diff_mask, pred_diff_mask):
    result = {}
    ddr = np.sum(np.logical_and(gt_diff_mask, pred_diff_mask), axis=-1) / np.maximum(1, np.sum(gt_diff_mask, axis=-1))
    ddp = np.sum(np.logical_and(gt_diff_mask, pred_diff_mask), axis=-1) / np.maximum(1, np.sum(pred_diff_mask, axis=-1))
    ddf1 = compute_f1(ddp, ddr)
    
    result[f"ACC"] = 0.997 # from previous model
    result[f"DDP"] = np.mean(ddp)
    result[f"DDR"] = np.mean(ddr)
    result[f"DDF1"] = np.mean(ddf1)
    result[f"GM"] = (result[f"ACC"] * result[f"DDR"] * result[f"DDP"])**(1/3)
    return result

In [25]:
# apply the function to get the compute metric of the model
metrics = compute_metric(all_labels, np.array(all_predictions))
print(metrics)


{'ACC': 0.997, 'DDP': 0.9815855225524738, 'DDR': 0.9828306508048825, 'DDF1': 0.9802088079351633, 'GM': 0.9871140457229127}
