In [1]:
import torch
import pandas as pd
import sklearn
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from tqdm import tqdm
from torch.utils.data import Dataset,DataLoader
from utils.nn_utils import *
from models.ViT import ViT_LRP_nan_excluded
from models.LSTM import Recurrent_Classifier

# Input tables: the six L2Y1..L2Y6 pickles of the "nan" preprocessing variant.
# A "fill" variant with the same file names was also listed in this cell; to use it,
# replace 'nan' with 'fill' in both paths below
# (./preprocessed/prepared/fill/L2Y{1..6}.pkl and ./preprocessed/prepared/fill/label.pkl).
X_datapaths = [f'./preprocessed/prepared/nan/L2Y{year_no}.pkl' for year_no in range(1, 7)]
label_datapath = './preprocessed/prepared/nan/label.pkl'
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Read every input pickle into a DataFrame; reset_index() moves the index into a column.
input_datas = []  # one pandas DataFrame per entry of X_datapaths
for datapath in X_datapaths:
    temp = pd.read_pickle(datapath).reset_index()
    input_datas.append(temp)

# Sequence length == number of input tables.
seq_len = len(input_datas)

# Labels are loaded the same way and then shifted down by one.
# NOTE(review): presumably this turns 1-based grades into 0-based class indices, but the
# subtraction is applied to the whole frame, including the column created by
# reset_index() -- confirm that this is intended.
label_data = pd.read_pickle(label_datapath).reset_index() - 1

# Class index (0-8) -> grade name ('1등급' .. '9등급').
CLS2IDX = {class_idx: f'{class_idx + 1}등급' for class_idx in range(9)}

is_regression = False
X_trains, X_tests, y_train, y_test = make_splited_data(
    input_datas,
    label_data,
    is_regression=is_regression,
)

train_dataset = KELSDataSet(X_trains, y_train, is_regression=is_regression)
test_dataset = KELSDataSet(X_tests, y_test, is_regression=is_regression)

batch_size = 32
hidden_features = 100
embbed_dim = 72

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
split_list = train_dataset.split_list
# embedding_networks: list of MLPs, one per year; the number of input channels differs
# from entry to entry.

# Pull a single shuffled batch; it is split per year and handed to the model constructors.
sample_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
assert batch_size  # only fails for a zero / empty batch_size
sample, label = next(iter(sample_loader))
sample = sample.to(device)
sample_datas = batch_to_splited_datas(sample, split_list)

#embbeding_networks = make_embbeding_networks(sample_datas,batch_size=batch_size,hidden_features = hidden_features,out_features=embbed_dim)
# How to use the embedding networks:
# 1. Take a batch from train_loader. The batch is (batch, number of students, features),
#    where features is the concatenation of every year's features.
# 2. Pass the batch to batch_to_splited_datas:
#        datas = batch_to_splited_datas(batch, split_list)
#    datas is then a list with the per-year features split apart. Its length is the
#    number of years and each element is a matrix.
# 3. Pass datas to batch_to_embbedings. Each matrix goes through the matching MLP in
#    embbeding_networks and matrices of identical size are returned:
#        emb_batch_list = batch_to_embbedings(datas, embbeding_networks)
# 4. The contents of emb_batch_list are the vectors that go into the contrastive loss.
#
# 5. To feed them to the LSTM they have to be stacked along the seq_len direction:
#        emb_batched_seq = torch.stack(emb_batch_list).transpose(0,1)
#    stack them like this, feed the result to the LSTM, and it should be done..?

In [2]:
def accuracy_roughly(y_pred, y_label):
    """Return the fraction of predictions that are at most one class away from the label.

    Args:
        y_pred: sized iterable of predicted class indices (list or 1-D tensor).
        y_label: sized iterable of true class indices, same length as ``y_pred``.

    Returns:
        ``correct / total`` as a float, where a prediction counts as correct when
        ``abs(pred - label) <= 1``. Returns ``0.0`` for empty inputs. If the two
        inputs differ in length, a message is printed and ``None`` is returned.
    """
    #TODO add as loss
    if len(y_pred) != len(y_label):
        print("not available, fit size first")
        return None

    total = len(y_pred)
    if total == 0:
        # Nothing to score; avoid dividing by zero.
        return 0.0

    correct = sum(1 for pred, label in zip(y_pred, y_label) if abs(pred - label) <= 1)
    return correct / total

In [30]:
def train_net(model,train_loader,test_loader,optimizer_cls = optim.AdamW, criterion = nn.CrossEntropyLoss(),
n_iter=10,device='cpu',lr = 0.001,weight_decay = 0.01,mode = None):
    """Train ``model`` for ``n_iter`` epochs and evaluate it on ``test_loader`` after each one.

    Each batch from ``train_loader`` must unpack as ``(xx, (label_E, label_K, label_M))``;
    ``mode`` selects which of the three labels is used as the target. ``model(xx)`` is
    expected to return class scores with the classes along dimension 1.

    The loss is the sum of ``criterion`` evaluated against the true label and against the
    two neighbouring labels (label + 1 and label - 1, clamped to the valid class range).

    Args:
        model: network to train; it is switched to train mode at the start of every epoch.
        train_loader: iterable of training batches.
        test_loader: iterable of evaluation batches, passed to ``eval_net``.
        optimizer_cls: optimizer class, called as ``optimizer_cls(model.parameters(), lr=lr)``.
        criterion: loss callable taking ``(outputs, targets)``.
        n_iter: number of epochs.
        device: device the inputs and labels are moved to.
        lr: learning rate handed to the optimizer.
        weight_decay: currently NOT forwarded to the optimizer (kept for interface
            compatibility; the optimizer's own default applies).
        mode: ``'E'``, ``'K'`` or ``'M'``.

    Returns:
        dict with one list per metric, one entry per epoch: ``'train_loss'`` (mean batch
        loss), ``'train_acc'``, ``'val_acc'`` and ``'val_positive_acc'``.

    Raises:
        ValueError: if ``mode`` is not one of ``'E'``, ``'K'``, ``'M'``, or if
            ``train_loader`` yields no batches.
    """
    # Fail early with a clear message instead of an UnboundLocalError inside the loop.
    if mode not in ('E', 'K', 'M'):
        raise ValueError(f"mode must be one of 'E', 'K' or 'M', got {mode!r}")

    train_losses = []
    train_acc = []
    val_accs = []
    positive_accs = []
    #optimizer = optimizer_cls(model.parameters(),lr=lr,weight_decay=weight_decay)
    optimizer = optimizer_cls(model.parameters(), lr=lr)
    #scheduler = optim.lr_scheduler.MultiStepLR(optimizer,milestones=[25,40,60,80], gamma=0.5,last_epoch=-1)

    for epoch in range(n_iter):
        model.train()
        running_loss = 0.0
        n_batches = 0
        n = 0
        n_acc = 0
        ys = []
        ypreds = []
        for xx, (label_E, label_K, label_M) in tqdm(train_loader):
            xx = xx.to(device)
            yy = {'E': label_E, 'K': label_K, 'M': label_M}[mode].to(device)

            optimizer.zero_grad()
            outputs = model(xx)

            # Highest valid class index, derived from the model output instead of
            # being hard-coded.
            max_class = outputs.size(1) - 1
            loss_exact = criterion(outputs, yy)
            loss_upper = criterion(outputs, (yy + 1).clamp(max=max_class))
            loss_lower = criterion(outputs, (yy - 1).clamp(min=0))
            loss = loss_exact + loss_upper + loss_lower

            # Getting gradients w.r.t. parameters, then updating them
            loss.backward()
            optimizer.step()

            _, y_pred = outputs.max(1)
            ys.append(yy)
            ypreds.append(y_pred)

            running_loss += loss.item()
            n_batches += 1
            n += len(xx)
            n_acc += (yy == y_pred).float().sum().item()

        if n_batches == 0:
            raise ValueError("train_loader yielded no batches")

        #scheduler.step()
        train_losses.append(running_loss / n_batches)
        train_acc.append(n_acc / n)
        train_positive_acc = accuracy_roughly(torch.cat(ypreds), torch.cat(ys))

        acc, positive_acc = eval_net(model, test_loader, device, mode=mode)
        # Store a plain float so the history does not keep tensors alive.
        val_accs.append(float(acc))
        positive_accs.append(positive_acc)

        print(f'epoch : {epoch},train_positive_acc : {train_positive_acc} train_acc : {train_acc[-1]}, acc : {val_accs[-1]}. positive_acc : {positive_accs[-1]}',flush = True)

    return {
        'train_loss': train_losses,
        'train_acc': train_acc,
        'val_acc': val_accs,
        'val_positive_acc': positive_accs,
    }

In [31]:
def eval_net(model,data_loader,device,mode=None):
    """Evaluate ``model`` on ``data_loader`` and return exact and within-one-class accuracy.

    Each batch must unpack as ``(xx, (label_E, label_K, label_M))``; ``mode`` selects
    which label is compared against the arg-max of ``model(xx)`` along dimension 1.
    The model is put into eval mode and is left in that mode on return.

    Args:
        model: network to evaluate.
        data_loader: iterable of evaluation batches.
        device: device the inputs and labels are moved to.
        mode: ``'E'``, ``'K'`` or ``'M'``.

    Returns:
        ``(acc, positive_acc)`` where ``acc`` is a 0-dim tensor holding the fraction of
        exact matches and ``positive_acc`` is the value returned by
        ``accuracy_roughly`` (fraction of predictions within one class of the label).

    Raises:
        ValueError: if ``mode`` is not one of ``'E'``, ``'K'``, ``'M'``.
    """
    # Fail early with a clear message instead of an UnboundLocalError inside the loop.
    if mode not in ('E', 'K', 'M'):
        raise ValueError(f"mode must be one of 'E', 'K' or 'M', got {mode!r}")

    model.eval()
    ys = []
    ypreds = []
    # No gradients are needed anywhere during evaluation.
    with torch.no_grad():
        for xx, (label_E, label_K, label_M) in data_loader:
            xx = xx.to(device)
            y = {'E': label_E, 'K': label_K, 'M': label_M}[mode].to(device)

            score = model(xx)
            _, y_pred = score.max(1)
            ys.append(y)
            ypreds.append(y_pred)

    ys = torch.cat(ys)
    ypreds = torch.cat(ypreds)
    positive_acc = accuracy_roughly(ypreds, ys)
    acc = (ys == ypreds).float().sum() / len(ys)

    # For a per-class breakdown during debugging:
    # print(sklearn.metrics.confusion_matrix(ys.numpy(),ypreds.numpy()))
    # print(sklearn.metrics.classification_report(ys.numpy(),ypreds.numpy()))

    return acc, positive_acc
    #return acc.item()

In [5]:
# Recurrent classifier with 9 output classes, built from the per-year sample batch
# and moved to the selected device.
model_LSTM = Recurrent_Classifier.RecurrentClassifier(
    sample_datas,
    split_list,
    embedding_dim=32,
    hidden_dim=64,
    output_size=9,
).to(device)

# 100 epochs on the 'E' label with AdamW at lr=1e-4.
# The trailing ';' keeps the notebook from echoing the call's return value.
train_net(
    model_LSTM,
    train_loader,
    test_loader,
    optimizer_cls=optim.AdamW,
    n_iter=100,
    device=device,
    lr=0.0001,
    mode='E',
);

131it [00:02, 53.31it/s]


epoch : 0, train_acc : 0.129171668667467, acc : 0.19961611926555634. positive_acc : 0.40978886756238003


131it [00:02, 54.73it/s]


epoch : 1, train_acc : 0.18319327731092436, acc : 0.19961611926555634. positive_acc : 0.40978886756238003


131it [00:02, 56.14it/s]


epoch : 2, train_acc : 0.18127250900360145, acc : 0.19961611926555634. positive_acc : 0.40978886756238003


131it [00:02, 55.89it/s]


epoch : 3, train_acc : 0.1836734693877551, acc : 0.19961611926555634. positive_acc : 0.40978886756238003


131it [00:02, 55.67it/s]


epoch : 4, train_acc : 0.18415366146458584, acc : 0.19961611926555634. positive_acc : 0.40978886756238003


131it [00:02, 55.83it/s]


epoch : 5, train_acc : 0.1786314525810324, acc : 0.2005758136510849. positive_acc : 0.41170825335892514


131it [00:02, 56.09it/s]


epoch : 6, train_acc : 0.1951980792316927, acc : 0.24856045842170715. positive_acc : 0.5239923224568138


131it [00:02, 56.45it/s]


epoch : 7, train_acc : 0.2319327731092437, acc : 0.2581573724746704. positive_acc : 0.5499040307101728


131it [00:02, 55.91it/s]


epoch : 8, train_acc : 0.24849939975990396, acc : 0.29558539390563965. positive_acc : 0.5940499040307101


131it [00:02, 56.13it/s]


epoch : 9, train_acc : 0.2799519807923169, acc : 0.2994241714477539. positive_acc : 0.6065259117082533


131it [00:02, 57.42it/s]


epoch : 10, train_acc : 0.2917166866746699, acc : 0.30422264337539673. positive_acc : 0.6084452975047985


131it [00:02, 54.96it/s]


epoch : 11, train_acc : 0.3106842737094838, acc : 0.33109402656555176. positive_acc : 0.6775431861804223


131it [00:02, 57.09it/s]


epoch : 12, train_acc : 0.31020408163265306, acc : 0.3243761956691742. positive_acc : 0.6698656429942419


131it [00:02, 54.98it/s]


epoch : 13, train_acc : 0.3138055222088836, acc : 0.3320537209510803. positive_acc : 0.6756238003838771


131it [00:02, 56.04it/s]


epoch : 14, train_acc : 0.3258103241296519, acc : 0.334932804107666. positive_acc : 0.6785028790786948


131it [00:02, 56.87it/s]


epoch : 15, train_acc : 0.33085234093637456, acc : 0.34452974796295166. positive_acc : 0.6948176583493282


131it [00:02, 56.55it/s]


epoch : 16, train_acc : 0.3368547418967587, acc : 0.3483685255050659. positive_acc : 0.6938579654510557


131it [00:02, 57.87it/s]


epoch : 17, train_acc : 0.34093637454981995, acc : 0.3598848283290863. positive_acc : 0.7178502879078695


131it [00:02, 57.72it/s]


epoch : 18, train_acc : 0.3430972388955582, acc : 0.3579654395580292. positive_acc : 0.7111324376199616


131it [00:02, 57.08it/s]


epoch : 19, train_acc : 0.33733493397358943, acc : 0.34740883111953735. positive_acc : 0.6890595009596929


131it [00:02, 56.41it/s]


epoch : 20, train_acc : 0.3411764705882353, acc : 0.3598848283290863. positive_acc : 0.7140115163147792


131it [00:02, 56.79it/s]


epoch : 21, train_acc : 0.3438175270108043, acc : 0.35892513394355774. positive_acc : 0.7149712092130518


131it [00:02, 56.55it/s]


epoch : 22, train_acc : 0.33997599039615845, acc : 0.3541266620159149. positive_acc : 0.6986564299424184


131it [00:02, 57.59it/s]


epoch : 23, train_acc : 0.33997599039615845, acc : 0.36084452271461487. positive_acc : 0.710172744721689


131it [00:02, 57.07it/s]


epoch : 24, train_acc : 0.33997599039615845, acc : 0.36180421710014343. positive_acc : 0.7111324376199616


131it [00:02, 56.49it/s]


epoch : 25, train_acc : 0.34477791116446577, acc : 0.35892513394355774. positive_acc : 0.7092130518234165


131it [00:02, 57.22it/s]


epoch : 26, train_acc : 0.35486194477791116, acc : 0.36372360587120056. positive_acc : 0.7178502879078695


131it [00:02, 57.24it/s]


epoch : 27, train_acc : 0.3539015606242497, acc : 0.3598848283290863. positive_acc : 0.7063339731285988


131it [00:02, 56.65it/s]


epoch : 28, train_acc : 0.3495798319327731, acc : 0.3646833002567291. positive_acc : 0.7197696737044146


131it [00:02, 56.21it/s]


epoch : 29, train_acc : 0.3572629051620648, acc : 0.3579654395580292. positive_acc : 0.7063339731285988


131it [00:02, 57.33it/s]


epoch : 30, train_acc : 0.3558223289315726, acc : 0.36084452271461487. positive_acc : 0.7159309021113244


131it [00:02, 55.44it/s]


epoch : 31, train_acc : 0.3570228091236495, acc : 0.3598848283290863. positive_acc : 0.7140115163147792


131it [00:02, 57.76it/s]


epoch : 32, train_acc : 0.35798319327731093, acc : 0.3598848283290863. positive_acc : 0.7140115163147792


131it [00:02, 57.51it/s]


epoch : 33, train_acc : 0.35678271308523407, acc : 0.36180421710014343. positive_acc : 0.7159309021113244


131it [00:02, 57.51it/s]


epoch : 34, train_acc : 0.35558223289315727, acc : 0.37236082553863525. positive_acc : 0.727447216890595


131it [00:02, 58.25it/s]


epoch : 35, train_acc : 0.35798319327731093, acc : 0.3704414367675781. positive_acc : 0.7207293666026872


131it [00:02, 56.82it/s]


epoch : 36, train_acc : 0.35558223289315727, acc : 0.36372360587120056. positive_acc : 0.7197696737044146


131it [00:02, 56.53it/s]


epoch : 37, train_acc : 0.3589435774309724, acc : 0.36948174238204956. positive_acc : 0.7207293666026872


131it [00:02, 56.46it/s]


epoch : 38, train_acc : 0.35294117647058826, acc : 0.3646833002567291. positive_acc : 0.718809980806142


131it [00:02, 56.65it/s]


epoch : 39, train_acc : 0.36086434573829534, acc : 0.36660268902778625. positive_acc : 0.7197696737044146


131it [00:02, 56.47it/s]


epoch : 40, train_acc : 0.3577430972388956, acc : 0.3704414367675781. positive_acc : 0.7216890595009597


131it [00:02, 56.03it/s]


epoch : 41, train_acc : 0.35990396158463384, acc : 0.36084452271461487. positive_acc : 0.7140115163147792


131it [00:02, 56.72it/s]


epoch : 42, train_acc : 0.36398559423769505, acc : 0.3646833002567291. positive_acc : 0.7178502879078695


131it [00:02, 57.53it/s]


epoch : 43, train_acc : 0.36014405762304924, acc : 0.36660268902778625. positive_acc : 0.7178502879078695


131it [00:02, 56.98it/s]


epoch : 44, train_acc : 0.365906362545018, acc : 0.36948174238204956. positive_acc : 0.7178502879078695


131it [00:02, 57.10it/s]


epoch : 45, train_acc : 0.36614645858343337, acc : 0.362763911485672. positive_acc : 0.716890595009597


131it [00:02, 55.87it/s]


epoch : 46, train_acc : 0.3671068427370948, acc : 0.3733205199241638. positive_acc : 0.7264875239923224


131it [00:02, 56.32it/s]


epoch : 47, train_acc : 0.3611044417767107, acc : 0.3733205199241638. positive_acc : 0.7072936660268714


131it [00:02, 56.56it/s]


epoch : 48, train_acc : 0.36422569027611046, acc : 0.3714011311531067. positive_acc : 0.7140115163147792


131it [00:02, 56.39it/s]


epoch : 49, train_acc : 0.3663865546218487, acc : 0.37236082553863525. positive_acc : 0.7178502879078695


131it [00:02, 57.25it/s]


epoch : 50, train_acc : 0.3680672268907563, acc : 0.3790786862373352. positive_acc : 0.7178502879078695


131it [00:02, 56.76it/s]


epoch : 51, train_acc : 0.3702280912364946, acc : 0.3675623834133148. positive_acc : 0.710172744721689


131it [00:02, 56.08it/s]


epoch : 52, train_acc : 0.37262905162064824, acc : 0.3742802143096924. positive_acc : 0.7140115163147792


131it [00:02, 55.68it/s]


epoch : 53, train_acc : 0.3675870348139256, acc : 0.3656429946422577. positive_acc : 0.7063339731285988


131it [00:02, 56.64it/s]


epoch : 54, train_acc : 0.36542617046818726, acc : 0.38003838062286377. positive_acc : 0.7178502879078695


131it [00:02, 55.55it/s]


epoch : 55, train_acc : 0.3723889555822329, acc : 0.36660268902778625. positive_acc : 0.7159309021113244


131it [00:02, 56.02it/s]


epoch : 56, train_acc : 0.3752701080432173, acc : 0.36660268902778625. positive_acc : 0.7120921305182342


131it [00:02, 56.90it/s]


epoch : 57, train_acc : 0.3716686674669868, acc : 0.37236082553863525. positive_acc : 0.7140115163147792


131it [00:02, 56.76it/s]


epoch : 58, train_acc : 0.37286914765906365, acc : 0.362763911485672. positive_acc : 0.7111324376199616


131it [00:02, 55.89it/s]


epoch : 59, train_acc : 0.37070828331332534, acc : 0.3598848283290863. positive_acc : 0.7053742802303263


131it [00:02, 56.79it/s]


epoch : 60, train_acc : 0.37406962785114045, acc : 0.368522047996521. positive_acc : 0.7149712092130518


131it [00:02, 55.68it/s]


epoch : 61, train_acc : 0.3771908763505402, acc : 0.3656429946422577. positive_acc : 0.7072936660268714


131it [00:02, 56.72it/s]


epoch : 62, train_acc : 0.375750300120048, acc : 0.36084452271461487. positive_acc : 0.7092130518234165


131it [00:02, 55.92it/s]


epoch : 63, train_acc : 0.36998799519807923, acc : 0.36660268902778625. positive_acc : 0.708253358925144


131it [00:02, 56.41it/s]


epoch : 64, train_acc : 0.3786314525810324, acc : 0.3656429946422577. positive_acc : 0.7072936660268714


131it [00:02, 56.75it/s]


epoch : 65, train_acc : 0.3865546218487395, acc : 0.35604605078697205. positive_acc : 0.7005758157389635


131it [00:02, 57.42it/s]


epoch : 66, train_acc : 0.3786314525810324, acc : 0.3646833002567291. positive_acc : 0.7130518234165067


131it [00:02, 57.01it/s]


epoch : 67, train_acc : 0.3810324129651861, acc : 0.3598848283290863. positive_acc : 0.7111324376199616


131it [00:02, 55.83it/s]


epoch : 68, train_acc : 0.38079231692677074, acc : 0.3656429946422577. positive_acc : 0.7111324376199616


131it [00:02, 56.41it/s]


epoch : 69, train_acc : 0.3723889555822329, acc : 0.35604605078697205. positive_acc : 0.7015355086372361


131it [00:02, 55.30it/s]


epoch : 70, train_acc : 0.3858343337334934, acc : 0.3714011311531067. positive_acc : 0.7130518234165067


131it [00:02, 56.83it/s]


epoch : 71, train_acc : 0.37671068427370946, acc : 0.3656429946422577. positive_acc : 0.7120921305182342


131it [00:02, 57.53it/s]


epoch : 72, train_acc : 0.3903961584633854, acc : 0.362763911485672. positive_acc : 0.710172744721689


131it [00:02, 57.03it/s]


epoch : 73, train_acc : 0.3879951980792317, acc : 0.37236082553863525. positive_acc : 0.7197696737044146


131it [00:02, 57.09it/s]


epoch : 74, train_acc : 0.3872749099639856, acc : 0.3790786862373352. positive_acc : 0.7226487523992322


131it [00:02, 55.44it/s]


epoch : 75, train_acc : 0.3851140456182473, acc : 0.3646833002567291. positive_acc : 0.7092130518234165


131it [00:02, 56.35it/s]


epoch : 76, train_acc : 0.38391356542617044, acc : 0.36660268902778625. positive_acc : 0.716890595009597


131it [00:02, 56.81it/s]


epoch : 77, train_acc : 0.38679471788715486, acc : 0.3675623834133148. positive_acc : 0.710172744721689


131it [00:02, 55.58it/s]


epoch : 78, train_acc : 0.38895558223289317, acc : 0.3771592974662781. positive_acc : 0.718809980806142


131it [00:02, 56.82it/s]


epoch : 79, train_acc : 0.3879951980792317, acc : 0.3742802143096924. positive_acc : 0.7130518234165067


131it [00:02, 55.71it/s]


epoch : 80, train_acc : 0.38871548619447777, acc : 0.38867560029029846. positive_acc : 0.72552783109405


131it [00:02, 57.06it/s]


epoch : 81, train_acc : 0.39399759903961584, acc : 0.36660268902778625. positive_acc : 0.7130518234165067


131it [00:02, 56.71it/s]


epoch : 82, train_acc : 0.38607442977190876, acc : 0.36948174238204956. positive_acc : 0.7053742802303263


131it [00:02, 56.71it/s]


epoch : 83, train_acc : 0.38751500600240096, acc : 0.36372360587120056. positive_acc : 0.7044145873320538


131it [00:02, 57.47it/s]


epoch : 84, train_acc : 0.39279711884753904, acc : 0.3656429946422577. positive_acc : 0.708253358925144


131it [00:02, 56.68it/s]


epoch : 85, train_acc : 0.38871548619447777, acc : 0.3714011311531067. positive_acc : 0.7044145873320538


131it [00:02, 56.39it/s]


epoch : 86, train_acc : 0.39663865546218485, acc : 0.3790786862373352. positive_acc : 0.716890595009597


131it [00:02, 56.83it/s]


epoch : 87, train_acc : 0.38391356542617044, acc : 0.3646833002567291. positive_acc : 0.7053742802303263


131it [00:02, 56.55it/s]


epoch : 88, train_acc : 0.39015606242497, acc : 0.36660268902778625. positive_acc : 0.708253358925144


131it [00:02, 57.21it/s]


epoch : 89, train_acc : 0.3863145258103241, acc : 0.35892513394355774. positive_acc : 0.7005758157389635


131it [00:02, 56.96it/s]


epoch : 90, train_acc : 0.39183673469387753, acc : 0.3675623834133148. positive_acc : 0.7053742802303263


131it [00:02, 56.38it/s]


epoch : 91, train_acc : 0.3908763505402161, acc : 0.3733205199241638. positive_acc : 0.710172744721689


131it [00:02, 55.99it/s]


epoch : 92, train_acc : 0.3944777911164466, acc : 0.3742802143096924. positive_acc : 0.7216890595009597


131it [00:02, 57.64it/s]


epoch : 93, train_acc : 0.39207683073229294, acc : 0.37523990869522095. positive_acc : 0.710172744721689


131it [00:02, 56.18it/s]


epoch : 94, train_acc : 0.3884753901560624, acc : 0.3733205199241638. positive_acc : 0.7197696737044146


131it [00:02, 57.18it/s]


epoch : 95, train_acc : 0.3930372148859544, acc : 0.3704414367675781. positive_acc : 0.7140115163147792


131it [00:02, 57.20it/s]


epoch : 96, train_acc : 0.3990396158463385, acc : 0.3742802143096924. positive_acc : 0.710172744721689


131it [00:02, 57.13it/s]


epoch : 97, train_acc : 0.3915966386554622, acc : 0.3714011311531067. positive_acc : 0.7130518234165067


131it [00:02, 56.62it/s]


epoch : 98, train_acc : 0.39063625450180073, acc : 0.37236082553863525. positive_acc : 0.7072936660268714


131it [00:02, 56.15it/s]


epoch : 99, train_acc : 0.39327731092436974, acc : 0.362763911485672. positive_acc : 0.7063339731285988


In [24]:
# Transformer classifier for the 'E' label: 9 classes, sequence length 6,
# embedding width 96 (32 * 3), 6 blocks with 8 heads, dropout 0.2.
model_E = ViT_LRP_nan_excluded.VisionTransformer(
    sample_datas,
    split_list,
    seq_len=6,
    num_classes=9,
    embed_dim=32 * 3,
    depth=6,
    num_heads=8,
    mlp_ratio=4.,
    qkv_bias=False,
    mlp_head=False,
    drop_rate=0.2,
    attn_drop_rate=0.2,
).to(device)

# 100 epochs with AdamW at lr=1e-4.
# The trailing ';' keeps the notebook from echoing the call's return value.
train_net(
    model_E,
    train_loader,
    test_loader,
    optimizer_cls=optim.AdamW,
    n_iter=100,
    device=device,
    lr=0.0001,
    mode='E',
);


131it [00:06, 19.72it/s]


epoch : 0, train_acc : 0.22593037214885955, acc : 0.27639153599739075. positive_acc : 0.5777351247600768


131it [00:07, 17.21it/s]


epoch : 1, train_acc : 0.3085234093637455, acc : 0.3272552788257599. positive_acc : 0.6525911708253359


131it [00:09, 13.11it/s]


epoch : 2, train_acc : 0.32941176470588235, acc : 0.35604605078697205. positive_acc : 0.6928982725527831


131it [00:09, 13.42it/s]


epoch : 3, train_acc : 0.34213685474189676, acc : 0.33685219287872314. positive_acc : 0.6737044145873321


131it [00:09, 13.63it/s]


epoch : 4, train_acc : 0.36206482593037215, acc : 0.3435700535774231. positive_acc : 0.6746641074856046


131it [00:09, 13.55it/s]


epoch : 5, train_acc : 0.356062424969988, acc : 0.3579654395580292. positive_acc : 0.6900191938579654


131it [00:10, 13.02it/s]


epoch : 6, train_acc : 0.35942376950780314, acc : 0.35892513394355774. positive_acc : 0.699616122840691


131it [00:10, 12.96it/s]


epoch : 7, train_acc : 0.3671068427370948, acc : 0.35028788447380066. positive_acc : 0.690978886756238


131it [00:09, 13.25it/s]


epoch : 8, train_acc : 0.3695078031212485, acc : 0.36180421710014343. positive_acc : 0.6957773512476008


131it [00:10, 12.70it/s]


epoch : 9, train_acc : 0.37767106842737097, acc : 0.3512475788593292. positive_acc : 0.6919385796545106


131it [00:10, 12.96it/s]


epoch : 10, train_acc : 0.37214885954381755, acc : 0.3435700535774231. positive_acc : 0.6900191938579654


131it [00:10, 12.65it/s]


epoch : 11, train_acc : 0.3764705882352941, acc : 0.3435700535774231. positive_acc : 0.6756238003838771


131it [00:09, 13.30it/s]


epoch : 12, train_acc : 0.37599039615846336, acc : 0.3512475788593292. positive_acc : 0.6890595009596929


131it [00:09, 13.42it/s]


epoch : 13, train_acc : 0.37671068427370946, acc : 0.3493282198905945. positive_acc : 0.6967370441458733


131it [00:09, 13.58it/s]


epoch : 14, train_acc : 0.3836734693877551, acc : 0.3512475788593292. positive_acc : 0.6871401151631478


131it [00:08, 14.94it/s]


epoch : 15, train_acc : 0.3843937575030012, acc : 0.3541266620159149. positive_acc : 0.6880998080614203


131it [00:07, 17.22it/s]


epoch : 16, train_acc : 0.3971188475390156, acc : 0.3522072732448578. positive_acc : 0.6794625719769674


131it [00:07, 18.12it/s]


epoch : 17, train_acc : 0.39399759903961584, acc : 0.3522072732448578. positive_acc : 0.6948176583493282


76it [00:04, 18.82it/s]


KeyboardInterrupt: 

In [32]:
# Transformer classifier for the 'K' label: 9 classes, sequence length 6,
# embedding width 48 (16 * 3), 8 blocks with 6 heads, dropout 0.2.
model_K = ViT_LRP_nan_excluded.VisionTransformer(
    sample_datas,
    split_list,
    seq_len=6,
    num_classes=9,
    embed_dim=16 * 3,
    depth=8,
    num_heads=6,
    mlp_ratio=4.,
    qkv_bias=False,
    mlp_head=False,
    drop_rate=0.2,
    attn_drop_rate=0.2,
).to(device)

# 100 epochs with AdamW at lr=1e-4.
# The trailing ';' keeps the notebook from echoing the call's return value.
train_net(
    model_K,
    train_loader,
    test_loader,
    optimizer_cls=optim.AdamW,
    n_iter=100,
    device=device,
    lr=0.0001,
    mode='K',
);


131it [00:06, 19.85it/s]


epoch : 0,train_positive_acc : 0.6014405762304922 train_acc : 0.1735894357743097, acc : 0.19577734172344208. positive_acc : 0.6861804222648752


131it [00:06, 19.93it/s]


epoch : 1,train_positive_acc : 0.7258103241296519 train_acc : 0.22713085234093638, acc : 0.2476007640361786. positive_acc : 0.7687140115163148


131it [00:06, 19.77it/s]


epoch : 2,train_positive_acc : 0.7577430972388955 train_acc : 0.2518607442977191, acc : 0.26295584440231323. positive_acc : 0.7783109404990403


131it [00:06, 20.05it/s]


epoch : 3,train_positive_acc : 0.7642256902761104 train_acc : 0.27202881152460984, acc : 0.2696737051010132. positive_acc : 0.7562380038387716


131it [00:06, 20.05it/s]


epoch : 4,train_positive_acc : 0.7702280912364946 train_acc : 0.26746698679471786, acc : 0.2850287854671478. positive_acc : 0.791746641074856


131it [00:06, 19.62it/s]


epoch : 5,train_positive_acc : 0.7726290516206482 train_acc : 0.287875150060024, acc : 0.2783109247684479. positive_acc : 0.7927063339731286


131it [00:06, 19.80it/s]


epoch : 6,train_positive_acc : 0.788235294117647 train_acc : 0.28907563025210087, acc : 0.2715930938720703. positive_acc : 0.7687140115163148


131it [00:06, 19.55it/s]


epoch : 7,train_positive_acc : 0.7726290516206482 train_acc : 0.2900360144057623, acc : 0.2907869517803192. positive_acc : 0.783109404990403


131it [00:06, 19.53it/s]


epoch : 8,train_positive_acc : 0.7947178871548619 train_acc : 0.30156062424969987, acc : 0.23416505753993988. positive_acc : 0.6871401151631478


131it [00:06, 19.42it/s]


epoch : 9,train_positive_acc : 0.7822328931572629 train_acc : 0.3006002400960384, acc : 0.29846447706222534. positive_acc : 0.7840690978886756


131it [00:06, 19.69it/s]


epoch : 10,train_positive_acc : 0.7899159663865546 train_acc : 0.30348139255702283, acc : 0.2754318416118622. positive_acc : 0.7754318618042226


131it [00:06, 19.59it/s]


epoch : 11,train_positive_acc : 0.7959183673469388 train_acc : 0.2957983193277311, acc : 0.27351248264312744. positive_acc : 0.7667946257197696


131it [00:06, 19.61it/s]


epoch : 12,train_positive_acc : 0.7930372148859544 train_acc : 0.29627851140456185, acc : 0.2725527882575989. positive_acc : 0.7303262955854126


131it [00:06, 19.73it/s]


epoch : 13,train_positive_acc : 0.7935174069627852 train_acc : 0.2979591836734694, acc : 0.27063339948654175. positive_acc : 0.800383877159309


131it [00:06, 19.58it/s]


epoch : 14,train_positive_acc : 0.7992797118847539 train_acc : 0.29603841536614645, acc : 0.2859884798526764. positive_acc : 0.7571976967370442


131it [00:06, 19.74it/s]


epoch : 15,train_positive_acc : 0.7987995198079232 train_acc : 0.30012004801920766, acc : 0.28982725739479065. positive_acc : 0.7648752399232246


131it [00:06, 19.61it/s]


epoch : 16,train_positive_acc : 0.812484993997599 train_acc : 0.3078031212484994, acc : 0.28119000792503357. positive_acc : 0.753358925143954


131it [00:06, 19.44it/s]


epoch : 17,train_positive_acc : 0.8031212484993998 train_acc : 0.3090036014405762, acc : 0.27351248264312744. positive_acc : 0.7284069097888676


131it [00:06, 19.64it/s]


epoch : 18,train_positive_acc : 0.798079231692677 train_acc : 0.2998799519807923, acc : 0.2879078686237335. positive_acc : 0.761996161228407


131it [00:06, 19.80it/s]


epoch : 19,train_positive_acc : 0.8141656662665065 train_acc : 0.3183673469387755, acc : 0.2754318416118622. positive_acc : 0.7677543186180422


131it [00:06, 19.73it/s]


epoch : 20,train_positive_acc : 0.8052821128451381 train_acc : 0.3126050420168067, acc : 0.26487523317337036. positive_acc : 0.744721689059501


131it [00:06, 19.75it/s]


epoch : 21,train_positive_acc : 0.8112845138055222 train_acc : 0.32220888355342137, acc : 0.27927061915397644. positive_acc : 0.8013435700575816


131it [00:06, 19.49it/s]


epoch : 22,train_positive_acc : 0.8105642256902761 train_acc : 0.31548619447779114, acc : 0.2744721472263336. positive_acc : 0.7629558541266794


131it [00:06, 19.72it/s]


epoch : 23,train_positive_acc : 0.8172869147659063 train_acc : 0.3186074429771909, acc : 0.2773512303829193. positive_acc : 0.7610364683301344


131it [00:06, 19.38it/s]


epoch : 24,train_positive_acc : 0.8129651860744298 train_acc : 0.3178871548619448, acc : 0.2994241714477539. positive_acc : 0.8023032629558541


131it [00:06, 19.65it/s]


epoch : 25,train_positive_acc : 0.8122448979591836 train_acc : 0.32869147659063624, acc : 0.31190016865730286. positive_acc : 0.7879078694817658


131it [00:06, 19.82it/s]


epoch : 26,train_positive_acc : 0.8230492196878751 train_acc : 0.3169267707082833, acc : 0.3023032546043396. positive_acc : 0.789827255278311


131it [00:06, 19.51it/s]


epoch : 27,train_positive_acc : 0.8165666266506603 train_acc : 0.3186074429771909, acc : 0.280230313539505. positive_acc : 0.7763915547024952


131it [00:06, 19.49it/s]


epoch : 28,train_positive_acc : 0.8316926770708283 train_acc : 0.32100840336134456, acc : 0.2783109247684479. positive_acc : 0.7312859884836852


131it [00:06, 19.81it/s]


epoch : 29,train_positive_acc : 0.8189675870348139 train_acc : 0.31932773109243695, acc : 0.2754318416118622. positive_acc : 0.72552783109405


131it [00:06, 19.57it/s]


epoch : 30,train_positive_acc : 0.8218487394957983 train_acc : 0.3296518607442977, acc : 0.2994241714477539. positive_acc : 0.772552783109405


131it [00:06, 19.71it/s]


epoch : 31,train_positive_acc : 0.8271308523409364 train_acc : 0.32725090036014404, acc : 0.27351248264312744. positive_acc : 0.7495201535508638


131it [00:06, 19.55it/s]


epoch : 32,train_positive_acc : 0.8216086434573829 train_acc : 0.31284513805522207, acc : 0.2850287854671478. positive_acc : 0.7859884836852208


131it [00:06, 19.36it/s]


epoch : 33,train_positive_acc : 0.8290516206482593 train_acc : 0.3346938775510204, acc : 0.26775431632995605. positive_acc : 0.7245681381957774


131it [00:06, 19.33it/s]


epoch : 34,train_positive_acc : 0.834813925570228 train_acc : 0.334453781512605, acc : 0.2562379837036133. positive_acc : 0.7140115163147792


131it [00:06, 19.50it/s]


epoch : 35,train_positive_acc : 0.8384153661464586 train_acc : 0.3289315726290516, acc : 0.2994241714477539. positive_acc : 0.7802303262955854


131it [00:06, 19.46it/s]


epoch : 36,train_positive_acc : 0.8360144057623049 train_acc : 0.33877551020408164, acc : 0.29270634055137634. positive_acc : 0.77447216890595


131it [00:06, 19.61it/s]


epoch : 37,train_positive_acc : 0.8319327731092437 train_acc : 0.33925570228091234, acc : 0.2917466461658478. positive_acc : 0.7735124760076776


131it [00:06, 19.49it/s]


epoch : 38,train_positive_acc : 0.8456182472989195 train_acc : 0.3411764705882353, acc : 0.2639155387878418. positive_acc : 0.7293666026871402


131it [00:06, 19.54it/s]


epoch : 39,train_positive_acc : 0.8331332533013205 train_acc : 0.34357743097238896, acc : 0.2917466461658478. positive_acc : 0.7543186180422264


131it [00:06, 19.56it/s]


epoch : 40,train_positive_acc : 0.8480192076830733 train_acc : 0.3430972388955582, acc : 0.27063339948654175. positive_acc : 0.7264875239923224


131it [00:06, 19.62it/s]


epoch : 41,train_positive_acc : 0.8350540216086435 train_acc : 0.33949579831932775, acc : 0.3071017265319824. positive_acc : 0.7802303262955854


131it [00:06, 19.49it/s]


epoch : 42,train_positive_acc : 0.8412965186074429 train_acc : 0.3471788715486194, acc : 0.2917466461658478. positive_acc : 0.77447216890595


131it [00:06, 19.43it/s]


epoch : 43,train_positive_acc : 0.8477791116446579 train_acc : 0.3483793517406963, acc : 0.3051823377609253. positive_acc : 0.7821497120921305


131it [00:06, 19.66it/s]


epoch : 44,train_positive_acc : 0.8549819927971188 train_acc : 0.3490996398559424, acc : 0.29846447706222534. positive_acc : 0.7687140115163148


131it [00:06, 19.48it/s]


epoch : 45,train_positive_acc : 0.8496998799519808 train_acc : 0.34357743097238896, acc : 0.30422264337539673. positive_acc : 0.7706333973128598


131it [00:06, 19.47it/s]


epoch : 46,train_positive_acc : 0.8547418967587035 train_acc : 0.3469387755102041, acc : 0.2831093966960907. positive_acc : 0.761996161228407


131it [00:06, 19.59it/s]


epoch : 47,train_positive_acc : 0.8468187274909964 train_acc : 0.35150060024009605, acc : 0.29270634055137634. positive_acc : 0.7715930902111324


131it [00:06, 19.54it/s]


epoch : 48,train_positive_acc : 0.8561824729891957 train_acc : 0.34933973589435774, acc : 0.2783109247684479. positive_acc : 0.7581573896353166


131it [00:06, 19.73it/s]


epoch : 49,train_positive_acc : 0.8528211284513806 train_acc : 0.346218487394958, acc : 0.280230313539505. positive_acc : 0.753358925143954


131it [00:06, 19.46it/s]


epoch : 50,train_positive_acc : 0.8605042016806723 train_acc : 0.3570228091236495, acc : 0.29558539390563965. positive_acc : 0.7591170825335892


131it [00:06, 19.71it/s]


epoch : 51,train_positive_acc : 0.8554621848739495 train_acc : 0.35942376950780314, acc : 0.27351248264312744. positive_acc : 0.7476007677543186


131it [00:06, 19.48it/s]


epoch : 52,train_positive_acc : 0.8677070828331332 train_acc : 0.363265306122449, acc : 0.2850287854671478. positive_acc : 0.7332053742802304


131it [00:06, 19.79it/s]


epoch : 53,train_positive_acc : 0.8631452581032413 train_acc : 0.36686674669867947, acc : 0.2696737051010132. positive_acc : 0.7149712092130518


131it [00:06, 19.68it/s]


epoch : 54,train_positive_acc : 0.8641056422569028 train_acc : 0.37334933973589435, acc : 0.280230313539505. positive_acc : 0.7495201535508638


131it [00:06, 19.43it/s]


epoch : 55,train_positive_acc : 0.870108043217287 train_acc : 0.3752701080432173, acc : 0.2783109247684479. positive_acc : 0.772552783109405


131it [00:06, 19.61it/s]


epoch : 56,train_positive_acc : 0.8629051620648259 train_acc : 0.3644657863145258, acc : 0.2879078686237335. positive_acc : 0.7495201535508638


131it [00:06, 19.68it/s]


epoch : 57,train_positive_acc : 0.8696278511404562 train_acc : 0.36470588235294116, acc : 0.2610364556312561. positive_acc : 0.7514395393474088


131it [00:06, 19.44it/s]


epoch : 58,train_positive_acc : 0.8684273709483794 train_acc : 0.36134453781512604, acc : 0.2687140107154846. positive_acc : 0.7370441458733206


131it [00:06, 19.36it/s]


epoch : 59,train_positive_acc : 0.8773109243697479 train_acc : 0.3779111644657863, acc : 0.28214970231056213. positive_acc : 0.7629558541266794


131it [00:06, 19.47it/s]


epoch : 60,train_positive_acc : 0.8715486194477791 train_acc : 0.38199279711884754, acc : 0.2936660051345825. positive_acc : 0.7543186180422264


131it [00:06, 19.45it/s]


epoch : 61,train_positive_acc : 0.8744297719087635 train_acc : 0.3815126050420168, acc : 0.2994241714477539. positive_acc : 0.7543186180422264


131it [00:06, 19.56it/s]


epoch : 62,train_positive_acc : 0.8758703481392557 train_acc : 0.35942376950780314, acc : 0.24664106965065002. positive_acc : 0.7092130518234165


131it [00:06, 19.69it/s]


epoch : 63,train_positive_acc : 0.8811524609843937 train_acc : 0.38271308523409364, acc : 0.259117066860199. positive_acc : 0.7140115163147792


131it [00:06, 19.49it/s]


epoch : 64,train_positive_acc : 0.8775510204081632 train_acc : 0.37551020408163266, acc : 0.2552782893180847. positive_acc : 0.7341650671785028


131it [00:06, 19.69it/s]


epoch : 65,train_positive_acc : 0.8741896758703481 train_acc : 0.3752701080432173, acc : 0.2562379837036133. positive_acc : 0.699616122840691


131it [00:06, 19.63it/s]


epoch : 66,train_positive_acc : 0.8845138055222089 train_acc : 0.37334933973589435, acc : 0.27639153599739075. positive_acc : 0.7476007677543186


131it [00:06, 19.45it/s]


epoch : 67,train_positive_acc : 0.8818727490996399 train_acc : 0.3735894357743097, acc : 0.2754318416118622. positive_acc : 0.753358925143954


131it [00:06, 19.53it/s]


epoch : 68,train_positive_acc : 0.8794717887154861 train_acc : 0.38823529411764707, acc : 0.2639155387878418. positive_acc : 0.6890595009596929


131it [00:06, 19.59it/s]


epoch : 69,train_positive_acc : 0.8785114045618247 train_acc : 0.37070828331332534, acc : 0.2773512303829193. positive_acc : 0.7629558541266794


131it [00:06, 19.26it/s]


epoch : 70,train_positive_acc : 0.8941176470588236 train_acc : 0.382953181272509, acc : 0.2773512303829193. positive_acc : 0.7332053742802304


131it [00:06, 19.62it/s]


epoch : 71,train_positive_acc : 0.885234093637455 train_acc : 0.3872749099639856, acc : 0.29846447706222534. positive_acc : 0.744721689059501


131it [00:06, 19.64it/s]


epoch : 72,train_positive_acc : 0.885234093637455 train_acc : 0.38871548619447777, acc : 0.29846447706222534. positive_acc : 0.7332053742802304


131it [00:06, 19.62it/s]


epoch : 73,train_positive_acc : 0.8917166866746699 train_acc : 0.38607442977190876, acc : 0.2696737051010132. positive_acc : 0.746641074856046


131it [00:06, 19.64it/s]


epoch : 74,train_positive_acc : 0.8921968787515006 train_acc : 0.4016806722689076, acc : 0.280230313539505. positive_acc : 0.744721689059501


131it [00:06, 19.85it/s]


epoch : 75,train_positive_acc : 0.8960384153661465 train_acc : 0.40672268907563025, acc : 0.2946256995201111. positive_acc : 0.7610364683301344


131it [00:06, 19.57it/s]


epoch : 76,train_positive_acc : 0.8957983193277311 train_acc : 0.39663865546218485, acc : 0.28214970231056213. positive_acc : 0.7696737044145874


131it [00:06, 19.53it/s]


epoch : 77,train_positive_acc : 0.892436974789916 train_acc : 0.3937575030012005, acc : 0.28694817423820496. positive_acc : 0.7312859884836852


131it [00:06, 19.50it/s]


epoch : 78,train_positive_acc : 0.897719087635054 train_acc : 0.39543817527010805, acc : 0.2639155387878418. positive_acc : 0.7111324376199616


131it [00:06, 19.07it/s]


epoch : 79,train_positive_acc : 0.8931572629051621 train_acc : 0.39687875150060026, acc : 0.26487523317337036. positive_acc : 0.7044145873320538


131it [00:06, 19.55it/s]


epoch : 80,train_positive_acc : 0.9046818727490996 train_acc : 0.4043217286914766, acc : 0.2667946219444275. positive_acc : 0.6938579654510557


131it [00:06, 19.12it/s]


epoch : 81,train_positive_acc : 0.8953181272509003 train_acc : 0.3951980792316927, acc : 0.2994241714477539. positive_acc : 0.7456813819577736


131it [00:06, 19.72it/s]


epoch : 82,train_positive_acc : 0.8957983193277311 train_acc : 0.39183673469387753, acc : 0.280230313539505. positive_acc : 0.7284069097888676


131it [00:06, 19.59it/s]


epoch : 83,train_positive_acc : 0.9013205282112845 train_acc : 0.41032412965186077, acc : 0.27351248264312744. positive_acc : 0.7351247600767754


131it [00:06, 19.74it/s]


epoch : 84,train_positive_acc : 0.9039615846338536 train_acc : 0.4069627851140456, acc : 0.28406909108161926. positive_acc : 0.7303262955854126


131it [00:06, 19.46it/s]


epoch : 85,train_positive_acc : 0.9003601440576231 train_acc : 0.40024009603841537, acc : 0.27063339948654175. positive_acc : 0.7399232245681382


131it [00:06, 19.68it/s]


epoch : 86,train_positive_acc : 0.8981992797118847 train_acc : 0.3985594237695078, acc : 0.31190016865730286. positive_acc : 0.7408829174664108


131it [00:06, 19.57it/s]


epoch : 87,train_positive_acc : 0.9039615846338536 train_acc : 0.41152460984393757, acc : 0.27351248264312744. positive_acc : 0.7236084452975048


131it [00:06, 19.45it/s]


epoch : 88,train_positive_acc : 0.8996398559423769 train_acc : 0.40480192076830734, acc : 0.2946256995201111. positive_acc : 0.7418426103646834


131it [00:06, 19.57it/s]


epoch : 89,train_positive_acc : 0.9092436974789916 train_acc : 0.42064825930372146, acc : 0.25431862473487854. positive_acc : 0.7072936660268714


131it [00:06, 19.56it/s]


epoch : 90,train_positive_acc : 0.9087635054021609 train_acc : 0.4057623049219688, acc : 0.2831093966960907. positive_acc : 0.761996161228407


131it [00:06, 19.50it/s]


epoch : 91,train_positive_acc : 0.9097238895558223 train_acc : 0.4136854741896759, acc : 0.2610364556312561. positive_acc : 0.7562380038387716


131it [00:06, 19.55it/s]


epoch : 92,train_positive_acc : 0.9150060024009604 train_acc : 0.4256902761104442, acc : 0.2504798471927643. positive_acc : 0.6765834932821497


131it [00:06, 19.63it/s]


epoch : 93,train_positive_acc : 0.9130852340936375 train_acc : 0.4338535414165666, acc : 0.2581573724746704. positive_acc : 0.6794625719769674


131it [00:06, 19.47it/s]


epoch : 94,train_positive_acc : 0.9147659063625451 train_acc : 0.4235294117647059, acc : 0.2667946219444275. positive_acc : 0.716890595009597


131it [00:06, 19.54it/s]


epoch : 95,train_positive_acc : 0.916686674669868 train_acc : 0.4398559423769508, acc : 0.24952015280723572. positive_acc : 0.6737044145873321


131it [00:06, 19.48it/s]


epoch : 96,train_positive_acc : 0.917406962785114 train_acc : 0.424249699879952, acc : 0.3023032546043396. positive_acc : 0.7514395393474088


131it [00:06, 19.55it/s]


epoch : 97,train_positive_acc : 0.9147659063625451 train_acc : 0.42809123649459785, acc : 0.2715930938720703. positive_acc : 0.718809980806142


131it [00:06, 19.33it/s]


epoch : 98,train_positive_acc : 0.9133253301320529 train_acc : 0.42953181272509006, acc : 0.2831093966960907. positive_acc : 0.6948176583493282


131it [00:06, 19.64it/s]


epoch : 99,train_positive_acc : 0.9162064825930372 train_acc : 0.44129651860744296, acc : 0.2667946219444275. positive_acc : 0.708253358925144


In [None]:
# Experiment: VisionTransformer from the ViT_LRP_nan_excluded module with
# depth=8 and embed_dim=16*3, created and then moved onto `device`.
model_K = ViT_LRP_nan_excluded.VisionTransformer(
    sample_datas,
    split_list,
    seq_len=6,
    num_classes=9,
    embed_dim=16*3,
    depth=8,
    num_heads=6,
    mlp_ratio=4.,
    qkv_bias=False,
    mlp_head=False,
    drop_rate=0.2,
    attn_drop_rate=0.2,
).to(device)

# Train in mode 'K' with AdamW at lr=1e-4 for n_iter=100; the cell output
# below shows one accuracy line per epoch.
train_net(
    model_K,
    train_loader,
    test_loader,
    n_iter=100,
    device=device,
    mode='K',
    lr=0.0001,
    optimizer_cls=optim.AdamW,
)


131it [00:06, 20.15it/s]


epoch : 0,train_positive_acc : 0.5707082833133253 train_acc : 0.16806722689075632, acc : 0.23128598928451538. positive_acc : 0.7428023032629558


131it [00:06, 20.24it/s]


epoch : 1,train_positive_acc : 0.7205282112845138 train_acc : 0.22737094837935173, acc : 0.2418425977230072. positive_acc : 0.7418426103646834


131it [00:06, 20.05it/s]


epoch : 2,train_positive_acc : 0.7414165666266507 train_acc : 0.24201680672268908, acc : 0.2447216808795929. positive_acc : 0.736084452975048


131it [00:06, 20.19it/s]


epoch : 3,train_positive_acc : 0.7527010804321729 train_acc : 0.25618247298919566, acc : 0.2504798471927643. positive_acc : 0.7696737044145874


131it [00:06, 20.09it/s]


epoch : 4,train_positive_acc : 0.7683073229291717 train_acc : 0.2653061224489796, acc : 0.24280229210853577. positive_acc : 0.7149712092130518


131it [00:06, 19.96it/s]


epoch : 5,train_positive_acc : 0.765906362545018 train_acc : 0.27418967587034815, acc : 0.2696737051010132. positive_acc : 0.7389635316698656


131it [00:06, 20.07it/s]


epoch : 6,train_positive_acc : 0.7779111644657863 train_acc : 0.28187274909963983, acc : 0.27927061915397644. positive_acc : 0.7869481765834933


131it [00:06, 20.29it/s]


epoch : 7,train_positive_acc : 0.7939975990396159 train_acc : 0.2965186074429772, acc : 0.2744721472263336. positive_acc : 0.783109404990403


131it [00:06, 19.77it/s]


epoch : 8,train_positive_acc : 0.7865546218487395 train_acc : 0.2799519807923169, acc : 0.28406909108161926. positive_acc : 0.755278310940499


131it [00:06, 19.79it/s]


epoch : 9,train_positive_acc : 0.7870348139255702 train_acc : 0.28883553421368546, acc : 0.280230313539505. positive_acc : 0.8032629558541267


131it [00:06, 19.86it/s]


epoch : 10,train_positive_acc : 0.7915966386554621 train_acc : 0.2837935174069628, acc : 0.28694817423820496. positive_acc : 0.7802303262955854


131it [00:06, 20.02it/s]


epoch : 11,train_positive_acc : 0.792076830732293 train_acc : 0.28811524609843936, acc : 0.2687140107154846. positive_acc : 0.781190019193858


131it [00:06, 19.89it/s]


epoch : 12,train_positive_acc : 0.7987995198079232 train_acc : 0.30948379351740696, acc : 0.2859884798526764. positive_acc : 0.7840690978886756


131it [00:06, 20.04it/s]


epoch : 13,train_positive_acc : 0.7915966386554621 train_acc : 0.30156062424969987, acc : 0.2917466461658478. positive_acc : 0.7994241842610365


131it [00:06, 19.85it/s]


epoch : 14,train_positive_acc : 0.8 train_acc : 0.297719087635054, acc : 0.28694817423820496. positive_acc : 0.7715930902111324


131it [00:06, 19.97it/s]


epoch : 15,train_positive_acc : 0.7942376950780312 train_acc : 0.3078031212484994, acc : 0.2725527882575989. positive_acc : 0.7658349328214972


131it [00:06, 19.66it/s]


epoch : 16,train_positive_acc : 0.8146458583433374 train_acc : 0.3070828331332533, acc : 0.2917466461658478. positive_acc : 0.7610364683301344


131it [00:06, 19.81it/s]


epoch : 17,train_positive_acc : 0.8036014405762305 train_acc : 0.3044417767106843, acc : 0.2975047826766968. positive_acc : 0.7869481765834933


131it [00:06, 19.74it/s]


epoch : 18,train_positive_acc : 0.8115246098439376 train_acc : 0.3092436974789916, acc : 0.26775431632995605. positive_acc : 0.772552783109405


131it [00:06, 19.83it/s]


epoch : 19,train_positive_acc : 0.8105642256902761 train_acc : 0.3056422569027611, acc : 0.2754318416118622. positive_acc : 0.7610364683301344


83it [00:04, 19.36it/s]


KeyboardInterrupt: 

In [38]:
# Experiment: VisionTransformer from the ViT_LRP_nan_excluded module with
# depth=3, embed_dim=36*3 and mlp_head=True, created and then moved onto `device`.
model_K = ViT_LRP_nan_excluded.VisionTransformer(
    sample_datas,
    split_list,
    seq_len=6,
    num_classes=9,
    embed_dim=36*3,
    depth=3,
    num_heads=6,
    mlp_ratio=4.,
    qkv_bias=False,
    mlp_head=True,
    drop_rate=0.2,
    attn_drop_rate=0.2,
).to(device)

# Train in mode 'K' with AdamW at lr=1e-4 for n_iter=100; the cell output
# below shows one accuracy line per epoch.
train_net(
    model_K,
    train_loader,
    test_loader,
    n_iter=100,
    device=device,
    mode='K',
    lr=0.0001,
    optimizer_cls=optim.AdamW,
)


131it [00:08, 15.16it/s]


epoch : 0, train_acc : 0.25618247298919566, acc : 0.329174667596817. positive_acc : 0.699616122840691


131it [00:08, 15.54it/s]


epoch : 1, train_acc : 0.32148859543817526, acc : 0.3330134153366089. positive_acc : 0.7092130518234165


131it [00:08, 15.39it/s]


epoch : 2, train_acc : 0.3289315726290516, acc : 0.3378118872642517. positive_acc : 0.7514395393474088


131it [00:09, 13.57it/s]


epoch : 3, train_acc : 0.3378151260504202, acc : 0.3464491367340088. positive_acc : 0.7428023032629558


131it [00:07, 17.45it/s]


epoch : 4, train_acc : 0.341656662665066, acc : 0.3435700535774231. positive_acc : 0.710172744721689


131it [00:18,  7.12it/s]


epoch : 5, train_acc : 0.34861944777911164, acc : 0.3512475788593292. positive_acc : 0.727447216890595


131it [00:08, 14.58it/s]


epoch : 6, train_acc : 0.3426170468187275, acc : 0.36180421710014343. positive_acc : 0.7293666026871402


131it [00:17,  7.32it/s]


epoch : 7, train_acc : 0.3584633853541417, acc : 0.3301343619823456. positive_acc : 0.7072936660268714


131it [00:18,  7.01it/s]


epoch : 8, train_acc : 0.3663865546218487, acc : 0.32053741812705994. positive_acc : 0.7418426103646834


131it [00:05, 22.15it/s]


epoch : 9, train_acc : 0.3575030012004802, acc : 0.3243761956691742. positive_acc : 0.6976967370441459


131it [00:05, 24.14it/s]


epoch : 10, train_acc : 0.3697478991596639, acc : 0.36084452271461487. positive_acc : 0.7504798464491362


131it [00:05, 24.05it/s]


epoch : 11, train_acc : 0.36398559423769505, acc : 0.3243761956691742. positive_acc : 0.7312859884836852


131it [00:05, 24.06it/s]


epoch : 12, train_acc : 0.3678271308523409, acc : 0.34261035919189453. positive_acc : 0.7063339731285988


131it [00:05, 23.65it/s]


epoch : 13, train_acc : 0.38343337334933975, acc : 0.3675623834133148. positive_acc : 0.7504798464491362


131it [00:05, 23.89it/s]


epoch : 14, train_acc : 0.37671068427370946, acc : 0.3598848283290863. positive_acc : 0.7667946257197696


131it [00:05, 24.16it/s]


epoch : 15, train_acc : 0.37623049219687876, acc : 0.3435700535774231. positive_acc : 0.7332053742802304


131it [00:05, 24.18it/s]


epoch : 16, train_acc : 0.4007202881152461, acc : 0.2917466461658478. positive_acc : 0.6861804222648752


131it [00:05, 23.89it/s]


epoch : 17, train_acc : 0.3903961584633854, acc : 0.280230313539505. positive_acc : 0.6708253358925144


131it [00:05, 23.67it/s]


epoch : 18, train_acc : 0.38055222088835533, acc : 0.3435700535774231. positive_acc : 0.7399232245681382


131it [00:05, 24.04it/s]


epoch : 19, train_acc : 0.40672268907563025, acc : 0.35892513394355774. positive_acc : 0.7264875239923224


131it [00:05, 24.27it/s]


epoch : 20, train_acc : 0.39735894357743096, acc : 0.3493282198905945. positive_acc : 0.7591170825335892


131it [00:05, 23.75it/s]


epoch : 21, train_acc : 0.40336134453781514, acc : 0.36084452271461487. positive_acc : 0.7543186180422264


131it [00:05, 23.94it/s]


epoch : 22, train_acc : 0.40456182472989194, acc : 0.368522047996521. positive_acc : 0.7485604606525912


131it [00:05, 23.82it/s]


epoch : 23, train_acc : 0.4297719087635054, acc : 0.3272552788257599. positive_acc : 0.7207293666026872


131it [00:05, 23.89it/s]


epoch : 24, train_acc : 0.43337334933973587, acc : 0.32245680689811707. positive_acc : 0.7293666026871402


131it [00:05, 23.74it/s]


epoch : 25, train_acc : 0.4381752701080432, acc : 0.32821497321128845. positive_acc : 0.6938579654510557


131it [00:06, 21.60it/s]


epoch : 26, train_acc : 0.4350540216086435, acc : 0.3214971125125885. positive_acc : 0.6976967370441459


131it [00:07, 16.40it/s]


epoch : 27, train_acc : 0.4530612244897959, acc : 0.33973127603530884. positive_acc : 0.7178502879078695


63it [00:05, 12.22it/s]


KeyboardInterrupt: 

In [6]:
# Experiment: VisionTransformer from the ViT_LRP_nan_excluded module with
# depth=5, embed_dim=32*3 and a lower attention dropout (0.1), created and
# then moved onto `device`.
model_K = ViT_LRP_nan_excluded.VisionTransformer(
    sample_datas,
    split_list,
    seq_len=6,
    num_classes=9,
    embed_dim=32*3,
    depth=5,
    num_heads=6,
    mlp_ratio=4.,
    qkv_bias=False,
    mlp_head=False,
    drop_rate=0.2,
    attn_drop_rate=0.1,
).to(device)

# Train in mode 'K' with AdamW at lr=1e-4 for n_iter=100; the cell output
# below shows one accuracy line per epoch.
train_net(
    model_K,
    train_loader,
    test_loader,
    n_iter=100,
    device=device,
    mode='K',
    lr=0.0001,
    optimizer_cls=optim.AdamW,
)

131it [00:04, 29.68it/s]


epoch : 0, train_acc : 0.23169267707082833, acc : 0.3378118872642517. positive_acc : 0.690978886756238


131it [00:04, 29.62it/s]


epoch : 1, train_acc : 0.31020408163265306, acc : 0.33397310972213745. positive_acc : 0.6957773512476008


131it [00:04, 30.03it/s]


epoch : 2, train_acc : 0.32941176470588235, acc : 0.30422264337539673. positive_acc : 0.6669865642994242


131it [00:04, 29.97it/s]


epoch : 3, train_acc : 0.33373349339735897, acc : 0.3157389461994171. positive_acc : 0.6765834932821497


131it [00:04, 30.00it/s]


epoch : 4, train_acc : 0.3382953181272509, acc : 0.33109402656555176. positive_acc : 0.7063339731285988


131it [00:04, 29.87it/s]


epoch : 5, train_acc : 0.34933973589435774, acc : 0.334932804107666. positive_acc : 0.7332053742802304


131it [00:04, 29.94it/s]


epoch : 6, train_acc : 0.3536614645858343, acc : 0.3099808096885681. positive_acc : 0.6794625719769674


131it [00:04, 30.11it/s]


epoch : 7, train_acc : 0.35126050420168065, acc : 0.3320537209510803. positive_acc : 0.7159309021113244


131it [00:04, 29.58it/s]


epoch : 8, train_acc : 0.35198079231692675, acc : 0.32245680689811707. positive_acc : 0.7120921305182342


131it [00:04, 30.23it/s]


epoch : 9, train_acc : 0.36086434573829534, acc : 0.32245680689811707. positive_acc : 0.6967370441458733


131it [00:04, 29.67it/s]


epoch : 10, train_acc : 0.3683073229291717, acc : 0.31477925181388855. positive_acc : 0.6880998080614203


131it [00:04, 29.78it/s]


epoch : 11, train_acc : 0.3611044417767107, acc : 0.34452974796295166. positive_acc : 0.7159309021113244


131it [00:04, 29.95it/s]


epoch : 12, train_acc : 0.3671068427370948, acc : 0.329174667596817. positive_acc : 0.7072936660268714


131it [00:04, 29.75it/s]


epoch : 13, train_acc : 0.3558223289315726, acc : 0.3454894423484802. positive_acc : 0.7207293666026872


131it [00:04, 29.70it/s]


epoch : 14, train_acc : 0.37262905162064824, acc : 0.35316696763038635. positive_acc : 0.7293666026871402


131it [00:04, 29.26it/s]


epoch : 15, train_acc : 0.3872749099639856, acc : 0.34165066480636597. positive_acc : 0.7226487523992322


131it [00:04, 29.91it/s]


epoch : 16, train_acc : 0.37815126050420167, acc : 0.3483685255050659. positive_acc : 0.7245681381957774


131it [00:04, 30.06it/s]


epoch : 17, train_acc : 0.38751500600240096, acc : 0.3358924984931946. positive_acc : 0.72552783109405


131it [00:04, 29.95it/s]


epoch : 18, train_acc : 0.37262905162064824, acc : 0.3358924984931946. positive_acc : 0.7159309021113244


131it [00:04, 29.76it/s]


epoch : 19, train_acc : 0.39543817527010805, acc : 0.32053741812705994. positive_acc : 0.7092130518234165


131it [00:04, 29.77it/s]


epoch : 20, train_acc : 0.3793517406962785, acc : 0.3435700535774231. positive_acc : 0.7341650671785028


131it [00:04, 29.83it/s]


epoch : 21, train_acc : 0.39471788715486195, acc : 0.32341650128364563. positive_acc : 0.7053742802303263


131it [00:04, 29.98it/s]


epoch : 22, train_acc : 0.3971188475390156, acc : 0.2831093966960907. positive_acc : 0.6497120921305183


131it [00:04, 30.07it/s]


epoch : 23, train_acc : 0.39063625450180073, acc : 0.3023032546043396. positive_acc : 0.6583493282149712


131it [00:04, 29.97it/s]


epoch : 24, train_acc : 0.40528211284513804, acc : 0.3378118872642517. positive_acc : 0.7332053742802304


131it [00:04, 29.71it/s]


epoch : 25, train_acc : 0.4050420168067227, acc : 0.3262955844402313. positive_acc : 0.6957773512476008


131it [00:04, 29.59it/s]


epoch : 26, train_acc : 0.41008403361344536, acc : 0.32053741812705994. positive_acc : 0.7159309021113244


131it [00:04, 29.28it/s]


epoch : 27, train_acc : 0.41008403361344536, acc : 0.31381955742836. positive_acc : 0.6948176583493282


131it [00:04, 29.32it/s]


epoch : 28, train_acc : 0.41584633853541414, acc : 0.3301343619823456. positive_acc : 0.7130518234165067


131it [00:04, 29.56it/s]


epoch : 29, train_acc : 0.4228091236494598, acc : 0.3214971125125885. positive_acc : 0.6967370441458733


131it [00:04, 29.88it/s]


epoch : 30, train_acc : 0.414405762304922, acc : 0.3464491367340088. positive_acc : 0.7264875239923224


131it [00:04, 29.49it/s]


epoch : 31, train_acc : 0.42208883553421367, acc : 0.3166986405849457. positive_acc : 0.7072936660268714


131it [00:04, 29.72it/s]


epoch : 32, train_acc : 0.42521008403361343, acc : 0.32053741812705994. positive_acc : 0.718809980806142


131it [00:04, 29.91it/s]


epoch : 33, train_acc : 0.4316926770708283, acc : 0.32053741812705994. positive_acc : 0.7130518234165067


131it [00:04, 29.59it/s]


epoch : 34, train_acc : 0.42521008403361343, acc : 0.334932804107666. positive_acc : 0.753358925143954


131it [00:04, 29.83it/s]


epoch : 35, train_acc : 0.43601440576230494, acc : 0.3166986405849457. positive_acc : 0.7216890595009597


131it [00:04, 29.70it/s]


epoch : 36, train_acc : 0.4388955582232893, acc : 0.30038386583328247. positive_acc : 0.6813819577735125


131it [00:04, 29.70it/s]


epoch : 37, train_acc : 0.44537815126050423, acc : 0.3387715816497803. positive_acc : 0.718809980806142


131it [00:04, 29.81it/s]


epoch : 38, train_acc : 0.4429771908763505, acc : 0.32245680689811707. positive_acc : 0.699616122840691


131it [00:04, 30.03it/s]


epoch : 39, train_acc : 0.44849939975990394, acc : 0.329174667596817. positive_acc : 0.727447216890595


131it [00:04, 29.58it/s]


epoch : 40, train_acc : 0.4509003601440576, acc : 0.32245680689811707. positive_acc : 0.7456813819577736


131it [00:04, 28.87it/s]


epoch : 41, train_acc : 0.4509003601440576, acc : 0.30422264337539673. positive_acc : 0.6900191938579654


131it [00:04, 30.05it/s]


epoch : 42, train_acc : 0.46938775510204084, acc : 0.31477925181388855. positive_acc : 0.6737044145873321


84it [00:02, 29.84it/s]


KeyboardInterrupt: 

In [None]:
# Experiment: VisionTransformer from the ViT_LRP_nan_excluded module with
# depth=6 and embed_dim=64*3, created and then moved onto `device`.
model_K = ViT_LRP_nan_excluded.VisionTransformer(
    sample_datas,
    split_list,
    seq_len=6,
    num_classes=9,
    embed_dim=64*3,
    depth=6,
    num_heads=6,
    mlp_ratio=4.,
    qkv_bias=False,
    mlp_head=False,
    drop_rate=0.2,
    attn_drop_rate=0.2,
).to(device)

# Train in mode 'K' with AdamW at lr=1e-4 for n_iter=100.
train_net(
    model_K,
    train_loader,
    test_loader,
    n_iter=100,
    device=device,
    mode='K',
    lr=0.0001,
    optimizer_cls=optim.AdamW,
)

In [10]:
# Pull a single mini-batch from the (shuffled) training loader for inspection.
# Each batch unpacks into the inputs `xx` and a triple of label tensors.
# NOTE(review): E/K/M presumably name three prediction targets, matching the
# mode='K' argument given to train_net — confirm against KELSDataSet.
xx,(label_E,label_K,label_M) = next(iter(train_loader))

In [23]:
class MLP(nn.Module):
    """Fully connected baseline classifier that treats NaN inputs as zero.

    The network is Linear -> ReLU -> Linear -> ReLU -> Linear. The default
    sizes (233 -> 80 -> 30 -> 9) reproduce the original hard-coded layout, so
    ``MLP()`` builds exactly the same architecture as before.

    Args:
        in_features: size of the last dimension of the input tensor.
        hidden1: width of the first hidden layer.
        hidden2: width of the second hidden layer.
        num_classes: number of output units (one score per class).
    """

    def __init__(self, in_features=233, hidden1=80, hidden2=30, num_classes=9):
        super().__init__()
        # The attribute name and layer positions (0, 2, 4) are kept as-is so
        # parameter names in state_dict do not change.
        self.mlp = nn.Sequential(
            nn.Linear(in_features, hidden1),
            nn.ReLU(),
            nn.Linear(hidden1, hidden2),
            nn.ReLU(),
            nn.Linear(hidden2, num_classes),
        )

    def forward(self, x):
        """Return class scores for ``x`` with every NaN entry replaced by 0.

        Args:
            x: tensor whose last dimension equals ``in_features``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``num_classes``.
        """
        # Out-of-place replacement: the original `x[torch.isnan(x)] = 0`
        # overwrote the caller's batch tensor, silently destroying the NaN
        # markers for any code that reuses the same tensor afterwards.
        x = x.masked_fill(torch.isnan(x), 0)
        return self.mlp(x)

In [24]:
# Instantiate the MLP baseline classifier defined in the previous cell.
mlp = MLP()

In [25]:
# Move the model to `device` (cuda:0 when available, otherwise CPU).
mlp = mlp.to(device)

In [26]:
# Train the MLP baseline in mode 'K' with lr=1e-4 and n_iter=100, using the
# same loaders as the ViT runs. No optimizer_cls is passed here, so whatever
# default train_net defines is used.
# NOTE(review): the ViT runs pass optim.AdamW explicitly — confirm the default
# optimizer before comparing these accuracies with theirs.
train_net(mlp,train_loader,test_loader,n_iter=100,device=device,mode='K',lr=0.0001)

131it [00:00, 179.46it/s]


epoch : 0,train_positive_acc : 0.4391356542617047 train_acc : 0.11692677070828332, acc : 0.17946256697177887. positive_acc : 0.5422264875239923


131it [00:00, 171.24it/s]


epoch : 1,train_positive_acc : 0.5181272509003602 train_acc : 0.17575030012004803, acc : 0.1650671809911728. positive_acc : 0.5873320537428023


131it [00:00, 175.49it/s]


epoch : 2,train_positive_acc : 0.5454981992797119 train_acc : 0.15990396158463385, acc : 0.16602686047554016. positive_acc : 0.5988483685220729


131it [00:00, 177.16it/s]


epoch : 3,train_positive_acc : 0.5831932773109244 train_acc : 0.16710684273709484, acc : 0.1871401071548462. positive_acc : 0.6487523992322457


131it [00:00, 178.56it/s]


epoch : 4,train_positive_acc : 0.6782713085234093 train_acc : 0.1992797118847539, acc : 0.22456812858581543. positive_acc : 0.7245681381957774


131it [00:00, 186.76it/s]


epoch : 5,train_positive_acc : 0.729171668667467 train_acc : 0.21680672268907564, acc : 0.22936660051345825. positive_acc : 0.7341650671785028


131it [00:00, 172.32it/s]


epoch : 6,train_positive_acc : 0.7529411764705882 train_acc : 0.22713085234093638, acc : 0.2389635294675827. positive_acc : 0.753358925143954


131it [00:00, 178.66it/s]


epoch : 7,train_positive_acc : 0.7565426170468187 train_acc : 0.23433373349339737, acc : 0.2504798471927643. positive_acc : 0.7581573896353166


131it [00:00, 165.32it/s]


epoch : 8,train_positive_acc : 0.7642256902761104 train_acc : 0.234093637454982, acc : 0.23992322385311127. positive_acc : 0.761996161228407


131it [00:00, 198.77it/s]


epoch : 9,train_positive_acc : 0.7690276110444177 train_acc : 0.2388955582232893, acc : 0.25431862473487854. positive_acc : 0.7792706333973128


131it [00:00, 193.62it/s]


epoch : 10,train_positive_acc : 0.7779111644657863 train_acc : 0.25474189675870346, acc : 0.25335893034935. positive_acc : 0.7715930902111324


131it [00:00, 194.30it/s]


epoch : 11,train_positive_acc : 0.775750300120048 train_acc : 0.2566626650660264, acc : 0.25335893034935. positive_acc : 0.7754318618042226


131it [00:00, 196.23it/s]


epoch : 12,train_positive_acc : 0.7855942376950781 train_acc : 0.25570228091236497, acc : 0.26007676124572754. positive_acc : 0.7667946257197696


131it [00:00, 195.27it/s]


epoch : 13,train_positive_acc : 0.7903961584633854 train_acc : 0.2657863145258103, acc : 0.26775431632995605. positive_acc : 0.7763915547024952


131it [00:00, 203.43it/s]


epoch : 14,train_positive_acc : 0.792076830732293 train_acc : 0.27010804321728693, acc : 0.2667946219444275. positive_acc : 0.7850287907869482


131it [00:00, 196.54it/s]


epoch : 15,train_positive_acc : 0.7963985594237695 train_acc : 0.27274909963985594, acc : 0.2687140107154846. positive_acc : 0.7754318618042226


131it [00:00, 197.24it/s]


epoch : 16,train_positive_acc : 0.797358943577431 train_acc : 0.2729891956782713, acc : 0.26199615001678467. positive_acc : 0.7792706333973128


131it [00:00, 191.59it/s]


epoch : 17,train_positive_acc : 0.7956782713085234 train_acc : 0.28211284513805523, acc : 0.2667946219444275. positive_acc : 0.7859884836852208


131it [00:00, 194.93it/s]


epoch : 18,train_positive_acc : 0.7975990396158463 train_acc : 0.2785114045618247, acc : 0.2725527882575989. positive_acc : 0.7792706333973128


131it [00:00, 197.83it/s]


epoch : 19,train_positive_acc : 0.7995198079231692 train_acc : 0.2804321728691477, acc : 0.2658349275588989. positive_acc : 0.7802303262955854


131it [00:00, 198.94it/s]


epoch : 20,train_positive_acc : 0.8014405762304923 train_acc : 0.2816326530612245, acc : 0.2667946219444275. positive_acc : 0.781190019193858


131it [00:00, 199.73it/s]


epoch : 21,train_positive_acc : 0.8036014405762305 train_acc : 0.28667466986794715, acc : 0.280230313539505. positive_acc : 0.7907869481765835


131it [00:00, 192.03it/s]


epoch : 22,train_positive_acc : 0.8012004801920768 train_acc : 0.28907563025210087, acc : 0.2773512303829193. positive_acc : 0.7936660268714012


131it [00:00, 194.84it/s]


epoch : 23,train_positive_acc : 0.8055222088835534 train_acc : 0.2830732292917167, acc : 0.2725527882575989. positive_acc : 0.7850287907869482


131it [00:00, 198.73it/s]


epoch : 24,train_positive_acc : 0.8045618247298919 train_acc : 0.29267707082833133, acc : 0.2773512303829193. positive_acc : 0.7907869481765835


131it [00:00, 193.25it/s]


epoch : 25,train_positive_acc : 0.8050420168067227 train_acc : 0.2885954381752701, acc : 0.27351248264312744. positive_acc : 0.7879078694817658


131it [00:00, 198.78it/s]


epoch : 26,train_positive_acc : 0.8057623049219688 train_acc : 0.29027611044417767, acc : 0.2754318416118622. positive_acc : 0.7869481765834933


131it [00:00, 197.06it/s]


epoch : 27,train_positive_acc : 0.8088835534213685 train_acc : 0.2893157262905162, acc : 0.2744721472263336. positive_acc : 0.7975047984644914


131it [00:00, 197.94it/s]


epoch : 28,train_positive_acc : 0.8074429771908763 train_acc : 0.29195678271308523, acc : 0.2715930938720703. positive_acc : 0.781190019193858


131it [00:00, 197.62it/s]


epoch : 29,train_positive_acc : 0.8115246098439376 train_acc : 0.295078031212485, acc : 0.27927061915397644. positive_acc : 0.7955854126679462


131it [00:00, 197.22it/s]


epoch : 30,train_positive_acc : 0.8108043217286914 train_acc : 0.29747899159663865, acc : 0.27351248264312744. positive_acc : 0.791746641074856


131it [00:00, 192.40it/s]


epoch : 31,train_positive_acc : 0.8105642256902761 train_acc : 0.3010804321728692, acc : 0.27639153599739075. positive_acc : 0.7907869481765835


131it [00:00, 199.24it/s]


epoch : 32,train_positive_acc : 0.8156062424969988 train_acc : 0.30036014405762307, acc : 0.2696737051010132. positive_acc : 0.7869481765834933


131it [00:00, 195.82it/s]


epoch : 33,train_positive_acc : 0.812484993997599 train_acc : 0.2986794717887155, acc : 0.27927061915397644. positive_acc : 0.7869481765834933


131it [00:00, 196.90it/s]


epoch : 34,train_positive_acc : 0.8120048019207683 train_acc : 0.30612244897959184, acc : 0.27063339948654175. positive_acc : 0.791746641074856


131it [00:00, 197.50it/s]


epoch : 35,train_positive_acc : 0.8151260504201681 train_acc : 0.2979591836734694, acc : 0.2696737051010132. positive_acc : 0.7850287907869482


131it [00:00, 194.66it/s]


epoch : 36,train_positive_acc : 0.8132052821128452 train_acc : 0.3013205282112845, acc : 0.2715930938720703. positive_acc : 0.7869481765834933


131it [00:00, 194.96it/s]


epoch : 37,train_positive_acc : 0.8148859543817527 train_acc : 0.30660264105642254, acc : 0.259117066860199. positive_acc : 0.7840690978886756


131it [00:00, 192.72it/s]


epoch : 38,train_positive_acc : 0.8115246098439376 train_acc : 0.3070828331332533, acc : 0.27639153599739075. positive_acc : 0.7946257197696737


131it [00:00, 196.25it/s]


epoch : 39,train_positive_acc : 0.8175270108043218 train_acc : 0.30612244897959184, acc : 0.2696737051010132. positive_acc : 0.7792706333973128


131it [00:00, 194.90it/s]


epoch : 40,train_positive_acc : 0.8175270108043218 train_acc : 0.30948379351740696, acc : 0.27927061915397644. positive_acc : 0.7955854126679462


131it [00:00, 203.28it/s]


epoch : 41,train_positive_acc : 0.8180072028811525 train_acc : 0.3085234093637455, acc : 0.27063339948654175. positive_acc : 0.7936660268714012


131it [00:00, 198.88it/s]


epoch : 42,train_positive_acc : 0.8184873949579832 train_acc : 0.3111644657863145, acc : 0.28214970231056213. positive_acc : 0.789827255278311


131it [00:00, 193.22it/s]


epoch : 43,train_positive_acc : 0.8172869147659063 train_acc : 0.3070828331332533, acc : 0.2888675630092621. positive_acc : 0.7907869481765835


131it [00:00, 195.81it/s]


epoch : 44,train_positive_acc : 0.8156062424969988 train_acc : 0.304921968787515, acc : 0.2859884798526764. positive_acc : 0.7984644913627639


131it [00:00, 196.52it/s]


epoch : 45,train_positive_acc : 0.817046818727491 train_acc : 0.3135654261704682, acc : 0.2744721472263336. positive_acc : 0.7879078694817658


131it [00:00, 200.36it/s]


epoch : 46,train_positive_acc : 0.81968787515006 train_acc : 0.3123649459783914, acc : 0.2754318416118622. positive_acc : 0.7859884836852208


131it [00:00, 197.37it/s]


epoch : 47,train_positive_acc : 0.8180072028811525 train_acc : 0.31092436974789917, acc : 0.27639153599739075. positive_acc : 0.7994241842610365


131it [00:00, 198.11it/s]


epoch : 48,train_positive_acc : 0.8218487394957983 train_acc : 0.3138055222088836, acc : 0.280230313539505. positive_acc : 0.7907869481765835


131it [00:00, 194.19it/s]


epoch : 49,train_positive_acc : 0.8187274909963985 train_acc : 0.31548619447779114, acc : 0.26487523317337036. positive_acc : 0.7763915547024952


131it [00:00, 194.04it/s]


epoch : 50,train_positive_acc : 0.8232893157262905 train_acc : 0.30612244897959184, acc : 0.2725527882575989. positive_acc : 0.7850287907869482


131it [00:00, 202.41it/s]


epoch : 51,train_positive_acc : 0.8201680672268907 train_acc : 0.3138055222088836, acc : 0.2783109247684479. positive_acc : 0.7859884836852208


131it [00:00, 194.70it/s]


epoch : 52,train_positive_acc : 0.8201680672268907 train_acc : 0.31740696278511404, acc : 0.2744721472263336. positive_acc : 0.7840690978886756


131it [00:00, 197.91it/s]


epoch : 53,train_positive_acc : 0.8228091236494598 train_acc : 0.3118847539015606, acc : 0.28694817423820496. positive_acc : 0.7821497120921305


131it [00:00, 198.88it/s]


epoch : 54,train_positive_acc : 0.8199279711884754 train_acc : 0.3114045618247299, acc : 0.27639153599739075. positive_acc : 0.7879078694817658


131it [00:00, 202.05it/s]


epoch : 55,train_positive_acc : 0.8228091236494598 train_acc : 0.31212484993997597, acc : 0.2783109247684479. positive_acc : 0.7821497120921305


131it [00:00, 202.26it/s]


epoch : 56,train_positive_acc : 0.8232893157262905 train_acc : 0.3142857142857143, acc : 0.27927061915397644. positive_acc : 0.7888675623800384


131it [00:00, 193.38it/s]


epoch : 57,train_positive_acc : 0.8232893157262905 train_acc : 0.3133253301320528, acc : 0.280230313539505. positive_acc : 0.791746641074856


131it [00:00, 194.79it/s]


epoch : 58,train_positive_acc : 0.8254501800720289 train_acc : 0.31524609843937573, acc : 0.28406909108161926. positive_acc : 0.7850287907869482


131it [00:00, 199.40it/s]


epoch : 59,train_positive_acc : 0.8271308523409364 train_acc : 0.31596638655462184, acc : 0.2850287854671478. positive_acc : 0.7859884836852208


131it [00:00, 197.68it/s]


epoch : 60,train_positive_acc : 0.8254501800720289 train_acc : 0.3138055222088836, acc : 0.28694817423820496. positive_acc : 0.7946257197696737


131it [00:00, 202.40it/s]


epoch : 61,train_positive_acc : 0.8273709483793518 train_acc : 0.32148859543817526, acc : 0.2907869517803192. positive_acc : 0.7879078694817658


131it [00:00, 193.85it/s]


epoch : 62,train_positive_acc : 0.824249699879952 train_acc : 0.32004801920768305, acc : 0.2639155387878418. positive_acc : 0.7600767754318618


131it [00:00, 196.28it/s]


epoch : 63,train_positive_acc : 0.8280912364945978 train_acc : 0.32100840336134456, acc : 0.2831093966960907. positive_acc : 0.7783109404990403


131it [00:00, 191.23it/s]


epoch : 64,train_positive_acc : 0.8273709483793518 train_acc : 0.32004801920768305, acc : 0.28214970231056213. positive_acc : 0.7888675623800384


131it [00:00, 202.09it/s]


epoch : 65,train_positive_acc : 0.828811524609844 train_acc : 0.3224489795918367, acc : 0.28694817423820496. positive_acc : 0.7792706333973128


131it [00:00, 200.47it/s]


epoch : 66,train_positive_acc : 0.8292917166866747 train_acc : 0.3303721488595438, acc : 0.28214970231056213. positive_acc : 0.783109404990403


131it [00:00, 198.26it/s]


epoch : 67,train_positive_acc : 0.8290516206482593 train_acc : 0.3190876350540216, acc : 0.2667946219444275. positive_acc : 0.7677543186180422


131it [00:00, 201.44it/s]


epoch : 68,train_positive_acc : 0.828811524609844 train_acc : 0.3217286914765906, acc : 0.2888675630092621. positive_acc : 0.7850287907869482


131it [00:00, 194.95it/s]


epoch : 69,train_positive_acc : 0.8319327731092437 train_acc : 0.3284513805522209, acc : 0.26775431632995605. positive_acc : 0.7687140115163148


131it [00:00, 201.22it/s]


epoch : 70,train_positive_acc : 0.8273709483793518 train_acc : 0.32869147659063624, acc : 0.2783109247684479. positive_acc : 0.781190019193858


131it [00:00, 198.49it/s]


epoch : 71,train_positive_acc : 0.8316926770708283 train_acc : 0.32653061224489793, acc : 0.2725527882575989. positive_acc : 0.7783109404990403


131it [00:00, 192.65it/s]


epoch : 72,train_positive_acc : 0.8256902761104442 train_acc : 0.3241296518607443, acc : 0.2888675630092621. positive_acc : 0.781190019193858


131it [00:00, 196.54it/s]


epoch : 73,train_positive_acc : 0.8280912364945978 train_acc : 0.33157262905162066, acc : 0.27927061915397644. positive_acc : 0.7677543186180422


131it [00:00, 193.86it/s]


epoch : 74,train_positive_acc : 0.8324129651860744 train_acc : 0.3217286914765906, acc : 0.28406909108161926. positive_acc : 0.7850287907869482


131it [00:00, 193.82it/s]


epoch : 75,train_positive_acc : 0.8326530612244898 train_acc : 0.32148859543817526, acc : 0.27063339948654175. positive_acc : 0.7610364683301344


131it [00:00, 195.99it/s]


epoch : 76,train_positive_acc : 0.8316926770708283 train_acc : 0.32677070828331334, acc : 0.28214970231056213. positive_acc : 0.7821497120921305


131it [00:00, 195.32it/s]


epoch : 77,train_positive_acc : 0.8331332533013205 train_acc : 0.3258103241296519, acc : 0.2725527882575989. positive_acc : 0.7754318618042226


131it [00:00, 192.47it/s]


epoch : 78,train_positive_acc : 0.8264105642256903 train_acc : 0.3231692677070828, acc : 0.27351248264312744. positive_acc : 0.7706333973128598


131it [00:00, 198.62it/s]


epoch : 79,train_positive_acc : 0.8360144057623049 train_acc : 0.33349339735894357, acc : 0.2754318416118622. positive_acc : 0.7802303262955854


131it [00:00, 202.71it/s]


epoch : 80,train_positive_acc : 0.834813925570228 train_acc : 0.3346938775510204, acc : 0.2859884798526764. positive_acc : 0.7754318618042226


131it [00:00, 196.51it/s]


epoch : 81,train_positive_acc : 0.834813925570228 train_acc : 0.3270108043217287, acc : 0.2715930938720703. positive_acc : 0.7735124760076776


131it [00:00, 198.69it/s]


epoch : 82,train_positive_acc : 0.8338535414165666 train_acc : 0.3224489795918367, acc : 0.2879078686237335. positive_acc : 0.7600767754318618


131it [00:00, 203.94it/s]


epoch : 83,train_positive_acc : 0.8345738295318127 train_acc : 0.3298919567827131, acc : 0.27351248264312744. positive_acc : 0.7610364683301344


131it [00:00, 197.99it/s]


epoch : 84,train_positive_acc : 0.8360144057623049 train_acc : 0.3284513805522209, acc : 0.27927061915397644. positive_acc : 0.7792706333973128


131it [00:00, 196.70it/s]


epoch : 85,train_positive_acc : 0.833373349339736 train_acc : 0.3303721488595438, acc : 0.27351248264312744. positive_acc : 0.7792706333973128


131it [00:00, 197.68it/s]


epoch : 86,train_positive_acc : 0.838655462184874 train_acc : 0.3238895558223289, acc : 0.2754318416118622. positive_acc : 0.7610364683301344


131it [00:00, 193.78it/s]


epoch : 87,train_positive_acc : 0.8355342136854742 train_acc : 0.3248499399759904, acc : 0.27639153599739075. positive_acc : 0.7773512476007678


131it [00:00, 196.93it/s]


epoch : 88,train_positive_acc : 0.834093637454982 train_acc : 0.3296518607442977, acc : 0.2859884798526764. positive_acc : 0.7821497120921305


131it [00:00, 197.51it/s]


epoch : 89,train_positive_acc : 0.8324129651860744 train_acc : 0.3310924369747899, acc : 0.28982725739479065. positive_acc : 0.783109404990403


131it [00:00, 189.66it/s]


epoch : 90,train_positive_acc : 0.8362545018007203 train_acc : 0.3325330132052821, acc : 0.2879078686237335. positive_acc : 0.7773512476007678


131it [00:00, 190.60it/s]


epoch : 91,train_positive_acc : 0.8398559423769508 train_acc : 0.3296518607442977, acc : 0.27351248264312744. positive_acc : 0.7735124760076776


131it [00:00, 197.73it/s]


epoch : 92,train_positive_acc : 0.8331332533013205 train_acc : 0.3332533013205282, acc : 0.2907869517803192. positive_acc : 0.7821497120921305


131it [00:00, 192.46it/s]


epoch : 93,train_positive_acc : 0.8362545018007203 train_acc : 0.3243697478991597, acc : 0.26775431632995605. positive_acc : 0.761996161228407


131it [00:00, 196.12it/s]


epoch : 94,train_positive_acc : 0.8381752701080433 train_acc : 0.33013205282112845, acc : 0.2859884798526764. positive_acc : 0.783109404990403


131it [00:00, 194.56it/s]


epoch : 95,train_positive_acc : 0.8336134453781513 train_acc : 0.33229291716686676, acc : 0.2773512303829193. positive_acc : 0.7696737044145874


131it [00:00, 202.28it/s]


epoch : 96,train_positive_acc : 0.8388955582232893 train_acc : 0.3339735894357743, acc : 0.2744721472263336. positive_acc : 0.7696737044145874


131it [00:00, 196.55it/s]


epoch : 97,train_positive_acc : 0.8405762304921969 train_acc : 0.32869147659063624, acc : 0.2783109247684479. positive_acc : 0.761996161228407


131it [00:00, 192.29it/s]


epoch : 98,train_positive_acc : 0.8379351740696278 train_acc : 0.33085234093637456, acc : 0.2773512303829193. positive_acc : 0.7802303262955854


131it [00:00, 199.65it/s]


epoch : 99,train_positive_acc : 0.8403361344537815 train_acc : 0.33301320528211287, acc : 0.2744721472263336. positive_acc : 0.7677543186180422


In [None]:
# Draft multi-task cell: three transformer classifiers (E, K, M) are run on the
# same embedded sequence and each is scored against its own label.
#
# NOTE(review): `ViT_LRP_copy` is not among the imports visible at the top of the
# notebook -- confirm it is imported before this cell runs on a fresh kernel.

# The three heads use identical hyperparameters, so they are declared once.
vit_head_kwargs = dict(seq_len=6, num_classes=9, embed_dim=72, depth=8,
                       num_heads=6, mlp_ratio=4., qkv_bias=False, mlp_head=False,
                       drop_rate=0.1, attn_drop_rate=0.1)
model_E = ViT_LRP_copy.VisionTransformer(**vit_head_kwargs)
model_K = ViT_LRP_copy.VisionTransformer(**vit_head_kwargs)
model_M = ViT_LRP_copy.VisionTransformer(**vit_head_kwargs)

# Created before the loop so that every batch can be scored, not only the last one.
criterion = nn.CrossEntropyLoss()

for concated_data,(label_E,label_K,label_M) in train_loader:
    datas = batch_to_splited_datas(concated_data,split_list)
    emb_batch_list= batch_to_embbedings(datas,embbeding_networks) # can be used for contrastive loss
    # emb_batch_list: list of embedding vectors. Stack them along a new leading
    # axis and swap the first two axes; the intended layout is
    # batch x seq x feature (assumes each list item is batch x feature -- TODO confirm).
    emb_batched_seq = torch.stack(emb_batch_list).transpose(0,1)
    attn_mask = make_attn_mask(emb_batched_seq)

    # Forward passes are kept in E, K, M order.
    E_score = model_E(emb_batched_seq,attn_mask)
    K_score = model_K(emb_batched_seq,attn_mask)
    M_score = model_M(emb_batched_seq,attn_mask)

    # Per-batch losses. Previously these lines sat after the loop, so only the
    # final batch's scores were ever scored. After the loop, loss_E / loss_K /
    # loss_M still hold the last batch's values, exactly as before.
    loss_E = criterion(E_score,label_E)
    loss_K = criterion(K_score,label_K)
    loss_M = criterion(M_score,label_M)
    # TODO: no backward pass / optimizer step exists yet; add it here when this
    # cell becomes an actual training loop.