In [2]:
import torch
import pandas as pd
import sklearn
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from tqdm import tqdm
from torch.utils.data import Dataset,DataLoader
from utils.nn_utils import *
from models.ViT import ViT_LRP_nan_excluded
from models.LSTM import Recurrent_Classifier

# Per-year feature tables L2Y1 ... L2Y6 (one sequence step per file).
# To use the other preprocessed variant, replace 'nan' with 'fill' in both
# paths below (presumably the NaN-imputed version of the same tables).
X_datapaths = [f'./preprocessed/prepared/nan/L2Y{year}.pkl' for year in range(1, 7)]
label_datapath = './preprocessed/prepared/nan/label.pkl'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load every pickled table and turn its index into an ordinary column.
# NOTE(review): pd.read_pickle can execute arbitrary code - only load trusted files.
input_datas = [pd.read_pickle(path).reset_index() for path in X_datapaths]
label_data = pd.read_pickle(label_datapath).reset_index()

# Sequence length seen by the models = number of per-year input tables.
seq_len = len(input_datas)

# Shift 1-based grade labels to 0-based class indices (see CLS2IDX below).
# NOTE(review): after reset_index() the former index is also a column, so it
# is decremented too - confirm make_splited_data expects that.
label_data = label_data - 1

# Class index -> grade name: 0 -> '1등급' (grade 1), ..., 8 -> '9등급' (grade 9).
CLS2IDX = {idx: f'{idx + 1}등급' for idx in range(9)}

is_regression = False
X_trains, X_tests, y_train, y_test = make_splited_data(input_datas, label_data, is_regression=is_regression)

train_dataset = KELSDataSet(X_trains, y_train, is_regression=is_regression)
test_dataset = KELSDataSet(X_tests, y_test, is_regression=is_regression)

batch_size = 32
hidden_features = 100
embbed_dim = 72

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# Presumably the per-year column boundaries used to slice a concatenated
# feature batch back into yearly blocks (see batch_to_splited_datas).
# embedding_networks (below) would be one MLP per year; their input widths
# differ because each year has a different number of features.
split_list = train_dataset.split_list

# One shuffled sample batch, split per year; it is handed to the model
# constructors (presumably to size their per-year input layers).
sample_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
assert batch_size
sample, label = next(iter(sample_loader))
sample = sample.to(device)
sample_datas = batch_to_splited_datas(sample, split_list)

# Intended use of the (currently disabled) embedding networks:
# embbeding_networks = make_embbeding_networks(sample_datas, batch_size=batch_size,
#                                              hidden_features=hidden_features, out_features=embbed_dim)
# 1. Take a batch from train_loader. Its shape is (batch, number of students, features),
#    where features is the concatenation of every year's features.
# 2. Split it per year: datas = batch_to_splited_datas(batch, split_list)
#    -> a list with one matrix per year (list length = number of years).
# 3. Embed each year: emb_batch_list = batch_to_embbedings(datas, embbeding_networks)
#    Each matrix passes through its own MLP and comes back at a common size.
# 4. The entries of emb_batch_list are the vectors that go into the contrastive loss.
# 5. To feed an LSTM, stack them along the seq_len axis:
#    emb_batched_seq = torch.stack(emb_batch_list).transpose(0, 1)
#    (author's note: not yet confirmed to work)

In [3]:
def accuracy_roughly(y_pred, y_label, tolerance=1):
    """Fraction of predictions within ``tolerance`` classes of the true label.

    A lenient accuracy for ordinal grade classes: with the default
    ``tolerance=1`` an off-by-one-grade prediction still counts as correct.

    Args:
        y_pred: predicted class indices (a sequence or a tensor).
        y_label: true class indices, same length as ``y_pred``.
        tolerance: largest ``|pred - label|`` still counted as correct.

    Returns:
        The fraction of "close enough" predictions as a Python float, or 0.0
        for empty input. If the lengths differ, prints a message and returns
        None (kept for backward compatibility).
    """
    # TODO: consider exposing this as a loss term.
    if len(y_pred) != len(y_label):
        print("not available, fit size first")
        return
    total = len(y_pred)
    if total == 0:
        # Avoid ZeroDivisionError on empty input.
        return 0.0
    if (isinstance(y_pred, torch.Tensor) and isinstance(y_label, torch.Tensor)
            and y_pred.shape == y_label.shape):
        # Vectorized path: one reduction instead of one device sync per element.
        correct = int(((y_pred - y_label).abs() <= tolerance).sum().item())
    else:
        correct = sum(1 for pred, label in zip(y_pred, y_label)
                      if abs(pred - label) <= tolerance)
    return correct / total

In [4]:
def train_net(model, train_loader, test_loader, optimizer_cls=optim.AdamW,
              criterion=nn.CrossEntropyLoss(), n_iter=10, device='cpu',
              lr=0.001, weight_decay=0.01, mode=None):
    """Train a classifier for ``n_iter`` epochs, evaluating after every epoch.

    Args:
        model: module mapping a batch ``xx`` to class scores (argmax over dim 1).
        train_loader: iterable of ``(xx, (label_E, label_K, label_M))`` batches.
        test_loader: same format; passed to ``eval_net`` after each epoch.
        optimizer_cls: optimizer class, called as
            ``optimizer_cls(params, lr=lr, weight_decay=weight_decay)``.
        criterion: loss called as ``criterion(outputs, labels)``.
        n_iter: number of epochs.
        device: device inputs and labels are moved to.
        lr: learning rate.
        weight_decay: weight decay passed to the optimizer (0.01 matches
            AdamW's own default).
        mode: which label to train on: ``'E'``, ``'K'`` or ``'M'``.

    Returns:
        dict with per-epoch lists ``train_losses`` (mean batch loss),
        ``train_acc``, ``val_accs`` and ``positive_accs`` (from ``eval_net``).

    Raises:
        ValueError: if ``mode`` is not one of ``'E'``, ``'K'``, ``'M'``.
    """
    # Position of each target inside the (label_E, label_K, label_M) tuple.
    label_index = {'E': 0, 'K': 1, 'M': 2}
    if mode not in label_index:
        raise ValueError(f"mode must be one of 'E', 'K', 'M'; got {mode!r}")
    target = label_index[mode]

    train_losses = []
    train_acc = []
    val_accs = []
    positive_accs = []
    optimizer = optimizer_cls(model.parameters(), lr=lr, weight_decay=weight_decay)
    #scheduler = optim.lr_scheduler.MultiStepLR(optimizer,milestones=[25,40,60,80], gamma=0.5,last_epoch=-1)

    for epoch in range(n_iter):
        # eval_net switches the model to eval mode, so re-enable training mode each epoch.
        model.train()
        running_loss = 0.0
        n_batches = 0
        n = 0
        n_acc = 0
        for xx, labels in tqdm(train_loader):
            xx = xx.to(device)
            yy = labels[target].to(device)

            optimizer.zero_grad()
            outputs = model(xx)
            loss = criterion(outputs, yy)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            n_batches += 1
            n += len(xx)
            _, y_pred = outputs.max(1)
            n_acc += (yy == y_pred).float().sum().item()
        #scheduler.step()

        # max(..., 1) keeps an empty loader from dividing by zero.
        train_losses.append(running_loss / max(n_batches, 1))
        train_acc.append(n_acc / max(n, 1))
        acc, positive_acc = eval_net(model, test_loader, device, mode=mode)
        val_accs.append(acc)
        positive_accs.append(positive_acc)

        print(f'epoch : {epoch}, train_acc : {train_acc[-1]}, acc : {val_accs[-1]}. positive_acc : {positive_accs[-1]}', flush=True)

    return {
        'train_losses': train_losses,
        'train_acc': train_acc,
        'val_accs': val_accs,
        'positive_accs': positive_accs,
    }

In [5]:
def eval_net(model, data_loader, device, mode=None):
    """Evaluate a classifier against one of the three targets.

    Args:
        model: module mapping a batch ``xx`` to class scores; the prediction
            is the argmax over dim 1.
        data_loader: iterable of ``(xx, (label_E, label_K, label_M))`` batches.
        device: device inputs and labels are moved to.
        mode: which label to evaluate against: ``'E'``, ``'K'`` or ``'M'``.

    Returns:
        ``(acc, positive_acc)``. ``acc`` is the exact-match accuracy as a 0-d
        tensor. ``positive_acc`` is the fraction of predictions within one
        class of the label, computed by ``accuracy_roughly``.

    Raises:
        ValueError: if ``mode`` is not one of ``'E'``, ``'K'``, ``'M'``.
    """
    # Position of each target inside the (label_E, label_K, label_M) tuple.
    label_index = {'E': 0, 'K': 1, 'M': 2}
    if mode not in label_index:
        raise ValueError(f"mode must be one of 'E', 'K', 'M'; got {mode!r}")
    target = label_index[mode]

    model.eval()
    ys = []
    ypreds = []
    with torch.no_grad():
        for xx, labels in data_loader:
            xx = xx.to(device)
            y = labels[target].to(device)
            score = model(xx)
            _, y_pred = score.max(1)
            ys.append(y)
            ypreds.append(y_pred)

    ys = torch.cat(ys)
    ypreds = torch.cat(ypreds)
    positive_acc = accuracy_roughly(ypreds, ys)
    acc = (ys == ypreds).float().sum() / len(ys)
    # print(sklearn.metrics.confusion_matrix(ys.numpy(),ypreds.numpy()))
    # print(sklearn.metrics.classification_report(ys.numpy(),ypreds.numpy()))
    return acc, positive_acc

In [None]:
# LSTM baseline on the E target; the class count is derived from CLS2IDX instead of hard-coded.
model_LSTM = Recurrent_Classifier.RecurrentClassifier(
    sample_datas, split_list,
    embedding_dim=32, hidden_dim=64,
    output_size=len(CLS2IDX),
).to(device)
history_LSTM = train_net(model_LSTM, train_loader, test_loader, n_iter=100, device=device,
                         mode='E', lr=0.0001, optimizer_cls=optim.AdamW)

In [24]:
# ViT on the E target. seq_len / num_classes come from the data setup so
# they cannot drift from the number of input files or grade classes.
model_E = ViT_LRP_nan_excluded.VisionTransformer(
    sample_datas, split_list,
    seq_len=seq_len, num_classes=len(CLS2IDX),
    embed_dim=32 * 3, depth=6, num_heads=8, mlp_ratio=4.,
    qkv_bias=False, mlp_head=False, drop_rate=0.2, attn_drop_rate=0.2,
).to(device)

history_E = train_net(model_E, train_loader, test_loader, n_iter=100, device=device,
                      mode='E', lr=0.0001, optimizer_cls=optim.AdamW)


131it [00:06, 19.72it/s]


epoch : 0, train_acc : 0.22593037214885955, acc : 0.27639153599739075. positive_acc : 0.5777351247600768


131it [00:07, 17.21it/s]


epoch : 1, train_acc : 0.3085234093637455, acc : 0.3272552788257599. positive_acc : 0.6525911708253359


131it [00:09, 13.11it/s]


epoch : 2, train_acc : 0.32941176470588235, acc : 0.35604605078697205. positive_acc : 0.6928982725527831


131it [00:09, 13.42it/s]


epoch : 3, train_acc : 0.34213685474189676, acc : 0.33685219287872314. positive_acc : 0.6737044145873321


131it [00:09, 13.63it/s]


epoch : 4, train_acc : 0.36206482593037215, acc : 0.3435700535774231. positive_acc : 0.6746641074856046


131it [00:09, 13.55it/s]


epoch : 5, train_acc : 0.356062424969988, acc : 0.3579654395580292. positive_acc : 0.6900191938579654


131it [00:10, 13.02it/s]


epoch : 6, train_acc : 0.35942376950780314, acc : 0.35892513394355774. positive_acc : 0.699616122840691


131it [00:10, 12.96it/s]


epoch : 7, train_acc : 0.3671068427370948, acc : 0.35028788447380066. positive_acc : 0.690978886756238


131it [00:09, 13.25it/s]


epoch : 8, train_acc : 0.3695078031212485, acc : 0.36180421710014343. positive_acc : 0.6957773512476008


131it [00:10, 12.70it/s]


epoch : 9, train_acc : 0.37767106842737097, acc : 0.3512475788593292. positive_acc : 0.6919385796545106


131it [00:10, 12.96it/s]


epoch : 10, train_acc : 0.37214885954381755, acc : 0.3435700535774231. positive_acc : 0.6900191938579654


131it [00:10, 12.65it/s]


epoch : 11, train_acc : 0.3764705882352941, acc : 0.3435700535774231. positive_acc : 0.6756238003838771


131it [00:09, 13.30it/s]


epoch : 12, train_acc : 0.37599039615846336, acc : 0.3512475788593292. positive_acc : 0.6890595009596929


131it [00:09, 13.42it/s]


epoch : 13, train_acc : 0.37671068427370946, acc : 0.3493282198905945. positive_acc : 0.6967370441458733


131it [00:09, 13.58it/s]


epoch : 14, train_acc : 0.3836734693877551, acc : 0.3512475788593292. positive_acc : 0.6871401151631478


131it [00:08, 14.94it/s]


epoch : 15, train_acc : 0.3843937575030012, acc : 0.3541266620159149. positive_acc : 0.6880998080614203


131it [00:07, 17.22it/s]


epoch : 16, train_acc : 0.3971188475390156, acc : 0.3522072732448578. positive_acc : 0.6794625719769674


131it [00:07, 18.12it/s]


epoch : 17, train_acc : 0.39399759903961584, acc : 0.3522072732448578. positive_acc : 0.6948176583493282


76it [00:04, 18.82it/s]


KeyboardInterrupt: 

In [12]:
# ViT on the K target (small embedding, deep stack). seq_len / num_classes
# come from the data setup instead of hard-coded 6 / 9.
model_K = ViT_LRP_nan_excluded.VisionTransformer(
    sample_datas, split_list,
    seq_len=seq_len, num_classes=len(CLS2IDX),
    embed_dim=16 * 3, depth=8, num_heads=6, mlp_ratio=4.,
    qkv_bias=False, mlp_head=False, drop_rate=0.2, attn_drop_rate=0.2,
).to(device)

history_K = train_net(model_K, train_loader, test_loader, n_iter=100, device=device,
                      mode='K', lr=0.0001, optimizer_cls=optim.AdamW)


131it [00:05, 21.95it/s]


epoch : 0, train_acc : 0.2261704681872749, acc : 0.28694817423820496. positive_acc : 0.6103646833013435


131it [00:05, 22.09it/s]


epoch : 1, train_acc : 0.2957983193277311, acc : 0.34165066480636597. positive_acc : 0.6833013435700576


131it [00:05, 22.09it/s]


epoch : 2, train_acc : 0.3171668667466987, acc : 0.34740883111953735. positive_acc : 0.699616122840691


131it [00:05, 22.09it/s]


epoch : 3, train_acc : 0.3238895558223289, acc : 0.3023032546043396. positive_acc : 0.6477927063339731


131it [00:05, 22.32it/s]


epoch : 4, train_acc : 0.33661464585834333, acc : 0.35316696763038635. positive_acc : 0.7197696737044146


131it [00:05, 22.16it/s]


epoch : 5, train_acc : 0.3296518607442977, acc : 0.30902111530303955. positive_acc : 0.654510556621881


131it [00:05, 22.25it/s]


epoch : 6, train_acc : 0.34357743097238896, acc : 0.362763911485672. positive_acc : 0.7341650671785028


131it [00:05, 22.21it/s]


epoch : 7, train_acc : 0.3457382953181273, acc : 0.3483685255050659. positive_acc : 0.738003838771593


131it [00:05, 22.39it/s]


epoch : 8, train_acc : 0.33733493397358943, acc : 0.36660268902778625. positive_acc : 0.7399232245681382


131it [00:05, 22.49it/s]


epoch : 9, train_acc : 0.3433373349339736, acc : 0.3656429946422577. positive_acc : 0.7341650671785028


131it [00:05, 22.14it/s]


epoch : 10, train_acc : 0.3452581032412965, acc : 0.32053741812705994. positive_acc : 0.6880998080614203


131it [00:05, 22.60it/s]


epoch : 11, train_acc : 0.35270108043217285, acc : 0.36660268902778625. positive_acc : 0.7456813819577736


131it [00:05, 22.43it/s]


epoch : 12, train_acc : 0.35198079231692675, acc : 0.3406909704208374. positive_acc : 0.7216890595009597


131it [00:05, 22.24it/s]


epoch : 13, train_acc : 0.3551020408163265, acc : 0.35892513394355774. positive_acc : 0.7485604606525912


131it [00:05, 22.05it/s]


epoch : 14, train_acc : 0.35942376950780314, acc : 0.35316696763038635. positive_acc : 0.7245681381957774


131it [00:05, 22.34it/s]


epoch : 15, train_acc : 0.3630252100840336, acc : 0.3454894423484802. positive_acc : 0.7140115163147792


131it [00:05, 22.03it/s]


epoch : 16, train_acc : 0.3625450180072029, acc : 0.3646833002567291. positive_acc : 0.736084452975048


131it [00:05, 22.50it/s]


epoch : 17, train_acc : 0.36134453781512604, acc : 0.3598848283290863. positive_acc : 0.7293666026871402


131it [00:05, 22.39it/s]


epoch : 18, train_acc : 0.36542617046818726, acc : 0.3454894423484802. positive_acc : 0.7207293666026872


131it [00:05, 22.33it/s]


epoch : 19, train_acc : 0.36926770708283313, acc : 0.3742802143096924. positive_acc : 0.746641074856046


131it [00:05, 22.23it/s]


epoch : 20, train_acc : 0.37118847539015604, acc : 0.3704414367675781. positive_acc : 0.7504798464491362


131it [00:05, 22.35it/s]


epoch : 21, train_acc : 0.3683073229291717, acc : 0.3646833002567291. positive_acc : 0.7476007677543186


131it [00:05, 22.20it/s]


epoch : 22, train_acc : 0.3671068427370948, acc : 0.3675623834133148. positive_acc : 0.7591170825335892


131it [00:05, 22.27it/s]


epoch : 23, train_acc : 0.3824729891956783, acc : 0.3675623834133148. positive_acc : 0.7629558541266794


131it [00:05, 22.22it/s]


epoch : 24, train_acc : 0.37551020408163266, acc : 0.37523990869522095. positive_acc : 0.7514395393474088


131it [00:05, 22.43it/s]


epoch : 25, train_acc : 0.38079231692677074, acc : 0.3579654395580292. positive_acc : 0.7351247600767754


131it [00:05, 22.40it/s]


epoch : 26, train_acc : 0.37478991596638656, acc : 0.36084452271461487. positive_acc : 0.755278310940499


131it [00:05, 22.75it/s]


epoch : 27, train_acc : 0.38055222088835533, acc : 0.3579654395580292. positive_acc : 0.7504798464491362


131it [00:05, 22.54it/s]


epoch : 28, train_acc : 0.38079231692677074, acc : 0.3761996030807495. positive_acc : 0.7571976967370442


131it [00:05, 22.35it/s]


epoch : 29, train_acc : 0.3786314525810324, acc : 0.3464491367340088. positive_acc : 0.7216890595009597


131it [00:05, 22.52it/s]


epoch : 30, train_acc : 0.37815126050420167, acc : 0.36180421710014343. positive_acc : 0.7408829174664108


131it [00:05, 22.27it/s]


epoch : 31, train_acc : 0.38271308523409364, acc : 0.368522047996521. positive_acc : 0.7418426103646834


131it [00:05, 22.41it/s]


epoch : 32, train_acc : 0.3795918367346939, acc : 0.3714011311531067. positive_acc : 0.736084452975048


131it [00:05, 22.05it/s]


epoch : 33, train_acc : 0.3858343337334934, acc : 0.36372360587120056. positive_acc : 0.7591170825335892


131it [00:05, 21.87it/s]


epoch : 34, train_acc : 0.3896758703481393, acc : 0.3598848283290863. positive_acc : 0.7399232245681382


131it [00:05, 22.58it/s]


epoch : 35, train_acc : 0.3795918367346939, acc : 0.35028788447380066. positive_acc : 0.7264875239923224


131it [00:05, 22.59it/s]


epoch : 36, train_acc : 0.38391356542617044, acc : 0.3387715816497803. positive_acc : 0.7111324376199616


131it [00:05, 22.35it/s]


epoch : 37, train_acc : 0.39063625450180073, acc : 0.3512475788593292. positive_acc : 0.7245681381957774


131it [00:05, 22.30it/s]


epoch : 38, train_acc : 0.3911164465786314, acc : 0.329174667596817. positive_acc : 0.7111324376199616


131it [00:05, 22.45it/s]


epoch : 39, train_acc : 0.39879951980792316, acc : 0.35028788447380066. positive_acc : 0.708253358925144


131it [00:05, 22.23it/s]


epoch : 40, train_acc : 0.3985594237695078, acc : 0.3646833002567291. positive_acc : 0.7514395393474088


131it [00:05, 22.26it/s]


epoch : 41, train_acc : 0.3937575030012005, acc : 0.35604605078697205. positive_acc : 0.7370441458733206


131it [00:05, 22.48it/s]


epoch : 42, train_acc : 0.4069627851140456, acc : 0.3454894423484802. positive_acc : 0.718809980806142


131it [00:05, 22.18it/s]


epoch : 43, train_acc : 0.39759903961584636, acc : 0.3262955844402313. positive_acc : 0.7120921305182342


131it [00:05, 22.14it/s]


epoch : 44, train_acc : 0.4120048019207683, acc : 0.35028788447380066. positive_acc : 0.7284069097888676


131it [00:05, 22.28it/s]


epoch : 45, train_acc : 0.4084033613445378, acc : 0.3483685255050659. positive_acc : 0.7418426103646834


131it [00:05, 22.26it/s]


epoch : 46, train_acc : 0.40024009603841537, acc : 0.3454894423484802. positive_acc : 0.7216890595009597


131it [00:05, 22.38it/s]


epoch : 47, train_acc : 0.4069627851140456, acc : 0.3262955844402313. positive_acc : 0.6976967370441459


131it [00:05, 22.06it/s]


epoch : 48, train_acc : 0.4064825930372149, acc : 0.36084452271461487. positive_acc : 0.7485604606525912


131it [00:05, 22.53it/s]


epoch : 49, train_acc : 0.4069627851140456, acc : 0.368522047996521. positive_acc : 0.7514395393474088


131it [00:05, 22.58it/s]


epoch : 50, train_acc : 0.417046818727491, acc : 0.3483685255050659. positive_acc : 0.7226487523992322


131it [00:05, 22.32it/s]


epoch : 51, train_acc : 0.40864345738295316, acc : 0.34165066480636597. positive_acc : 0.7332053742802304


131it [00:05, 22.24it/s]


epoch : 52, train_acc : 0.41272509003601443, acc : 0.3387715816497803. positive_acc : 0.7418426103646834


131it [00:05, 22.05it/s]


epoch : 53, train_acc : 0.42737094837935174, acc : 0.362763911485672. positive_acc : 0.7437619961612284


131it [00:05, 22.30it/s]


epoch : 54, train_acc : 0.41320528211284513, acc : 0.35316696763038635. positive_acc : 0.7293666026871402


131it [00:05, 21.91it/s]


epoch : 55, train_acc : 0.4249699879951981, acc : 0.3435700535774231. positive_acc : 0.7284069097888676


131it [00:05, 22.28it/s]


epoch : 56, train_acc : 0.4177671068427371, acc : 0.35028788447380066. positive_acc : 0.7245681381957774


131it [00:05, 21.94it/s]


epoch : 57, train_acc : 0.424249699879952, acc : 0.31957772374153137. positive_acc : 0.6861804222648752


131it [00:05, 22.07it/s]


epoch : 58, train_acc : 0.42665066026410564, acc : 0.35028788447380066. positive_acc : 0.7351247600767754


131it [00:05, 22.25it/s]


epoch : 59, train_acc : 0.4261704681872749, acc : 0.3464491367340088. positive_acc : 0.727447216890595


131it [00:05, 22.48it/s]


epoch : 60, train_acc : 0.417046818727491, acc : 0.33685219287872314. positive_acc : 0.7140115163147792


131it [00:05, 22.08it/s]


epoch : 61, train_acc : 0.42472989195678273, acc : 0.3435700535774231. positive_acc : 0.7207293666026872


131it [00:05, 22.10it/s]


epoch : 62, train_acc : 0.42208883553421367, acc : 0.3330134153366089. positive_acc : 0.7264875239923224


131it [00:05, 22.17it/s]


epoch : 63, train_acc : 0.4261704681872749, acc : 0.33109402656555176. positive_acc : 0.7111324376199616


131it [00:05, 22.13it/s]


epoch : 64, train_acc : 0.43289315726290517, acc : 0.3272552788257599. positive_acc : 0.7245681381957774


131it [00:05, 22.57it/s]


epoch : 65, train_acc : 0.42448979591836733, acc : 0.34452974796295166. positive_acc : 0.738003838771593


131it [00:05, 22.34it/s]


epoch : 66, train_acc : 0.42929171668667465, acc : 0.32341650128364563. positive_acc : 0.7034548944337812


131it [00:05, 22.37it/s]


epoch : 67, train_acc : 0.4336134453781513, acc : 0.3512475788593292. positive_acc : 0.7437619961612284


131it [00:05, 22.14it/s]


epoch : 68, train_acc : 0.44537815126050423, acc : 0.33973127603530884. positive_acc : 0.7044145873320538


131it [00:05, 22.34it/s]


epoch : 69, train_acc : 0.44345738295318127, acc : 0.3387715816497803. positive_acc : 0.7293666026871402


131it [00:05, 22.51it/s]


epoch : 70, train_acc : 0.44201680672268906, acc : 0.32245680689811707. positive_acc : 0.718809980806142


131it [00:05, 22.50it/s]


epoch : 71, train_acc : 0.43937575030012005, acc : 0.3598848283290863. positive_acc : 0.746641074856046


131it [00:05, 22.64it/s]


epoch : 72, train_acc : 0.44729891956782714, acc : 0.31765833497047424. positive_acc : 0.6957773512476008


131it [00:05, 22.32it/s]


epoch : 73, train_acc : 0.4362545018007203, acc : 0.3579654395580292. positive_acc : 0.7428023032629558


131it [00:05, 22.02it/s]


epoch : 74, train_acc : 0.4516206482593037, acc : 0.31957772374153137. positive_acc : 0.7140115163147792


131it [00:05, 22.55it/s]


epoch : 75, train_acc : 0.44417767106842737, acc : 0.3512475788593292. positive_acc : 0.7476007677543186


131it [00:05, 22.23it/s]


epoch : 76, train_acc : 0.44201680672268906, acc : 0.3262955844402313. positive_acc : 0.7053742802303263


131it [00:05, 22.06it/s]


epoch : 77, train_acc : 0.44849939975990394, acc : 0.329174667596817. positive_acc : 0.72552783109405


131it [00:05, 22.34it/s]


epoch : 78, train_acc : 0.4468187274909964, acc : 0.31190016865730286. positive_acc : 0.6967370441458733


131it [00:05, 22.14it/s]


epoch : 79, train_acc : 0.44537815126050423, acc : 0.32053741812705994. positive_acc : 0.7197696737044146


131it [00:05, 22.32it/s]


epoch : 80, train_acc : 0.4581032412965186, acc : 0.3128598630428314. positive_acc : 0.699616122840691


131it [00:05, 22.49it/s]


epoch : 81, train_acc : 0.46002400960384154, acc : 0.2946256995201111. positive_acc : 0.6804222648752399


131it [00:05, 22.10it/s]


epoch : 82, train_acc : 0.46218487394957986, acc : 0.32821497321128845. positive_acc : 0.7178502879078695


131it [00:05, 22.42it/s]


epoch : 83, train_acc : 0.4545018007202881, acc : 0.34261035919189453. positive_acc : 0.7370441458733206


131it [00:05, 22.34it/s]


epoch : 84, train_acc : 0.46338535414165666, acc : 0.3099808096885681. positive_acc : 0.6880998080614203


131it [00:05, 22.43it/s]


epoch : 85, train_acc : 0.4631452581032413, acc : 0.3272552788257599. positive_acc : 0.7044145873320538


131it [00:05, 22.32it/s]


epoch : 86, train_acc : 0.45546218487394957, acc : 0.3109405040740967. positive_acc : 0.7111324376199616


131it [00:05, 22.12it/s]


epoch : 87, train_acc : 0.4645858343337335, acc : 0.3301343619823456. positive_acc : 0.7264875239923224


131it [00:05, 22.45it/s]


epoch : 88, train_acc : 0.46602641056422567, acc : 0.3099808096885681. positive_acc : 0.6919385796545106


131it [00:05, 22.45it/s]


epoch : 89, train_acc : 0.4657863145258103, acc : 0.29270634055137634. positive_acc : 0.6880998080614203


131it [00:05, 22.54it/s]


epoch : 90, train_acc : 0.47130852340936374, acc : 0.3186180293560028. positive_acc : 0.7120921305182342


131it [00:05, 22.21it/s]


epoch : 91, train_acc : 0.4775510204081633, acc : 0.3330134153366089. positive_acc : 0.7178502879078695


131it [00:05, 22.34it/s]


epoch : 92, train_acc : 0.47274909963985595, acc : 0.3051823377609253. positive_acc : 0.7044145873320538


131it [00:05, 22.34it/s]


epoch : 93, train_acc : 0.47322929171668665, acc : 0.30422264337539673. positive_acc : 0.7159309021113244


131it [00:05, 21.99it/s]


epoch : 94, train_acc : 0.47923169267707083, acc : 0.3262955844402313. positive_acc : 0.716890595009597


131it [00:05, 22.21it/s]


epoch : 95, train_acc : 0.48403361344537815, acc : 0.34165066480636597. positive_acc : 0.7264875239923224


131it [00:05, 22.05it/s]


epoch : 96, train_acc : 0.468187274909964, acc : 0.29270634055137634. positive_acc : 0.6852207293666027


131it [00:05, 22.19it/s]


epoch : 97, train_acc : 0.4751500600240096, acc : 0.27927061915397644. positive_acc : 0.6833013435700576


131it [00:05, 22.24it/s]


epoch : 98, train_acc : 0.4801920768307323, acc : 0.2850287854671478. positive_acc : 0.6583493282149712


131it [00:05, 21.99it/s]


epoch : 99, train_acc : 0.487875150060024, acc : 0.3214971125125885. positive_acc : 0.7197696737044146


In [38]:
# ViT on the K target (wide embedding, shallow stack, MLP head). seq_len /
# num_classes come from the data setup instead of hard-coded 6 / 9.
model_K = ViT_LRP_nan_excluded.VisionTransformer(
    sample_datas, split_list,
    seq_len=seq_len, num_classes=len(CLS2IDX),
    embed_dim=110 * 3, depth=3, num_heads=6, mlp_ratio=4.,
    qkv_bias=False, mlp_head=True, drop_rate=0.2, attn_drop_rate=0.2,
).to(device)

history_K = train_net(model_K, train_loader, test_loader, n_iter=100, device=device,
                      mode='K', lr=0.0001, optimizer_cls=optim.AdamW)


131it [00:08, 15.16it/s]


epoch : 0, train_acc : 0.25618247298919566, acc : 0.329174667596817. positive_acc : 0.699616122840691


131it [00:08, 15.54it/s]


epoch : 1, train_acc : 0.32148859543817526, acc : 0.3330134153366089. positive_acc : 0.7092130518234165


131it [00:08, 15.39it/s]


epoch : 2, train_acc : 0.3289315726290516, acc : 0.3378118872642517. positive_acc : 0.7514395393474088


131it [00:09, 13.57it/s]


epoch : 3, train_acc : 0.3378151260504202, acc : 0.3464491367340088. positive_acc : 0.7428023032629558


131it [00:07, 17.45it/s]


epoch : 4, train_acc : 0.341656662665066, acc : 0.3435700535774231. positive_acc : 0.710172744721689


131it [00:18,  7.12it/s]


epoch : 5, train_acc : 0.34861944777911164, acc : 0.3512475788593292. positive_acc : 0.727447216890595


131it [00:08, 14.58it/s]


epoch : 6, train_acc : 0.3426170468187275, acc : 0.36180421710014343. positive_acc : 0.7293666026871402


131it [00:17,  7.32it/s]


epoch : 7, train_acc : 0.3584633853541417, acc : 0.3301343619823456. positive_acc : 0.7072936660268714


131it [00:18,  7.01it/s]


epoch : 8, train_acc : 0.3663865546218487, acc : 0.32053741812705994. positive_acc : 0.7418426103646834


131it [00:05, 22.15it/s]


epoch : 9, train_acc : 0.3575030012004802, acc : 0.3243761956691742. positive_acc : 0.6976967370441459


131it [00:05, 24.14it/s]


epoch : 10, train_acc : 0.3697478991596639, acc : 0.36084452271461487. positive_acc : 0.7504798464491362


131it [00:05, 24.05it/s]


epoch : 11, train_acc : 0.36398559423769505, acc : 0.3243761956691742. positive_acc : 0.7312859884836852


131it [00:05, 24.06it/s]


epoch : 12, train_acc : 0.3678271308523409, acc : 0.34261035919189453. positive_acc : 0.7063339731285988


131it [00:05, 23.65it/s]


epoch : 13, train_acc : 0.38343337334933975, acc : 0.3675623834133148. positive_acc : 0.7504798464491362


131it [00:05, 23.89it/s]


epoch : 14, train_acc : 0.37671068427370946, acc : 0.3598848283290863. positive_acc : 0.7667946257197696


131it [00:05, 24.16it/s]


epoch : 15, train_acc : 0.37623049219687876, acc : 0.3435700535774231. positive_acc : 0.7332053742802304


131it [00:05, 24.18it/s]


epoch : 16, train_acc : 0.4007202881152461, acc : 0.2917466461658478. positive_acc : 0.6861804222648752


131it [00:05, 23.89it/s]


epoch : 17, train_acc : 0.3903961584633854, acc : 0.280230313539505. positive_acc : 0.6708253358925144


131it [00:05, 23.67it/s]


epoch : 18, train_acc : 0.38055222088835533, acc : 0.3435700535774231. positive_acc : 0.7399232245681382


131it [00:05, 24.04it/s]


epoch : 19, train_acc : 0.40672268907563025, acc : 0.35892513394355774. positive_acc : 0.7264875239923224


131it [00:05, 24.27it/s]


epoch : 20, train_acc : 0.39735894357743096, acc : 0.3493282198905945. positive_acc : 0.7591170825335892


131it [00:05, 23.75it/s]


epoch : 21, train_acc : 0.40336134453781514, acc : 0.36084452271461487. positive_acc : 0.7543186180422264


131it [00:05, 23.94it/s]


epoch : 22, train_acc : 0.40456182472989194, acc : 0.368522047996521. positive_acc : 0.7485604606525912


131it [00:05, 23.82it/s]


epoch : 23, train_acc : 0.4297719087635054, acc : 0.3272552788257599. positive_acc : 0.7207293666026872


131it [00:05, 23.89it/s]


epoch : 24, train_acc : 0.43337334933973587, acc : 0.32245680689811707. positive_acc : 0.7293666026871402


131it [00:05, 23.74it/s]


epoch : 25, train_acc : 0.4381752701080432, acc : 0.32821497321128845. positive_acc : 0.6938579654510557


131it [00:06, 21.60it/s]


epoch : 26, train_acc : 0.4350540216086435, acc : 0.3214971125125885. positive_acc : 0.6976967370441459


131it [00:07, 16.40it/s]


epoch : 27, train_acc : 0.4530612244897959, acc : 0.33973127603530884. positive_acc : 0.7178502879078695


63it [00:05, 12.22it/s]


KeyboardInterrupt: 

In [7]:
# ViT on the K target without attention dropout. seq_len / num_classes come
# from the data setup instead of hard-coded 6 / 9.
model_K = ViT_LRP_nan_excluded.VisionTransformer(
    sample_datas, split_list,
    seq_len=seq_len, num_classes=len(CLS2IDX),
    embed_dim=32 * 3, depth=5, num_heads=6, mlp_ratio=4.,
    qkv_bias=False, mlp_head=False, drop_rate=0.2, attn_drop_rate=0,
).to(device)

history_K = train_net(model_K, train_loader, test_loader, n_iter=100, device=device,
                      mode='K', lr=0.0001, optimizer_cls=optim.AdamW)

131it [00:07, 17.01it/s]


epoch : 0, train_acc : 0.23145258103241295, acc : 0.3099808096885681. positive_acc : 0.663147792706334


131it [00:07, 17.83it/s]


epoch : 1, train_acc : 0.3018007202881152, acc : 0.2783109247684479. positive_acc : 0.6199616122840691


131it [00:04, 29.28it/s]


epoch : 2, train_acc : 0.3303721488595438, acc : 0.35604605078697205. positive_acc : 0.7226487523992322


131it [00:04, 26.21it/s]


epoch : 3, train_acc : 0.33373349339735897, acc : 0.3301343619823456. positive_acc : 0.6890595009596929


131it [00:05, 25.68it/s]


epoch : 4, train_acc : 0.3517406962785114, acc : 0.3579654395580292. positive_acc : 0.7284069097888676


131it [00:05, 25.06it/s]


epoch : 5, train_acc : 0.36422569027611046, acc : 0.3387715816497803. positive_acc : 0.7034548944337812


131it [00:05, 25.58it/s]


epoch : 6, train_acc : 0.3615846338535414, acc : 0.36084452271461487. positive_acc : 0.7370441458733206


131it [00:05, 24.96it/s]


epoch : 7, train_acc : 0.3536614645858343, acc : 0.3512475788593292. positive_acc : 0.7197696737044146


131it [00:04, 27.09it/s]


epoch : 8, train_acc : 0.37046818727491, acc : 0.33973127603530884. positive_acc : 0.708253358925144


131it [00:04, 27.10it/s]


epoch : 9, train_acc : 0.3752701080432173, acc : 0.34740883111953735. positive_acc : 0.7341650671785028


131it [00:05, 25.51it/s]


epoch : 10, train_acc : 0.35006002400960384, acc : 0.31765833497047424. positive_acc : 0.6842610364683301


131it [00:04, 28.63it/s]


epoch : 11, train_acc : 0.37286914765906365, acc : 0.31765833497047424. positive_acc : 0.6813819577735125


131it [00:04, 28.35it/s]


epoch : 12, train_acc : 0.3709483793517407, acc : 0.3464491367340088. positive_acc : 0.7159309021113244


131it [00:04, 28.55it/s]


epoch : 13, train_acc : 0.37743097238895557, acc : 0.3512475788593292. positive_acc : 0.7264875239923224


131it [00:06, 21.40it/s]


epoch : 14, train_acc : 0.3764705882352941, acc : 0.31957772374153137. positive_acc : 0.6833013435700576


131it [00:06, 18.81it/s]


epoch : 15, train_acc : 0.38415366146458585, acc : 0.3099808096885681. positive_acc : 0.7063339731285988


131it [00:04, 27.10it/s]


epoch : 16, train_acc : 0.3903961584633854, acc : 0.32245680689811707. positive_acc : 0.7024952015355086


131it [00:04, 27.26it/s]


epoch : 17, train_acc : 0.39663865546218485, acc : 0.33685219287872314. positive_acc : 0.7178502879078695


131it [00:05, 24.52it/s]


epoch : 18, train_acc : 0.39063625450180073, acc : 0.32821497321128845. positive_acc : 0.716890595009597


131it [00:06, 19.41it/s]


epoch : 19, train_acc : 0.3870348139255702, acc : 0.35604605078697205. positive_acc : 0.7226487523992322


131it [00:05, 25.25it/s]


epoch : 20, train_acc : 0.39615846338535415, acc : 0.33397310972213745. positive_acc : 0.710172744721689


131it [00:04, 26.79it/s]


epoch : 21, train_acc : 0.39135654261704683, acc : 0.32341650128364563. positive_acc : 0.7015355086372361


131it [00:05, 26.12it/s]


epoch : 22, train_acc : 0.3990396158463385, acc : 0.3522072732448578. positive_acc : 0.7332053742802304


131it [00:05, 23.56it/s]


epoch : 23, train_acc : 0.40264105642256903, acc : 0.34165066480636597. positive_acc : 0.7111324376199616


131it [00:06, 19.03it/s]


epoch : 24, train_acc : 0.412484993997599, acc : 0.32821497321128845. positive_acc : 0.6938579654510557


131it [00:06, 21.22it/s]


epoch : 25, train_acc : 0.40456182472989194, acc : 0.30038386583328247. positive_acc : 0.6612284069097889


131it [00:05, 25.20it/s]


epoch : 26, train_acc : 0.40600240096038415, acc : 0.32053741812705994. positive_acc : 0.7092130518234165


131it [00:05, 25.86it/s]


epoch : 27, train_acc : 0.4064825930372149, acc : 0.3272552788257599. positive_acc : 0.7034548944337812


53it [00:01, 26.87it/s]

In [None]:
# ViT on the K target (embed_dim 192, depth 6). seq_len / num_classes come
# from the data setup instead of hard-coded 6 / 9.
model_K = ViT_LRP_nan_excluded.VisionTransformer(
    sample_datas, split_list,
    seq_len=seq_len, num_classes=len(CLS2IDX),
    embed_dim=64 * 3, depth=6, num_heads=6, mlp_ratio=4.,
    qkv_bias=False, mlp_head=False, drop_rate=0.2, attn_drop_rate=0.2,
).to(device)

history_K = train_net(model_K, train_loader, test_loader, n_iter=100, device=device,
                      mode='K', lr=0.0001, optimizer_cls=optim.AdamW)

In [10]:
# Debug cell: pull one training batch for interactive inspection.
# NOTE(review): leftover exploration - nothing visible below reads these
# module-level names; consider deleting before sharing the notebook.
xx,(label_E,label_K,label_M) = next(iter(train_loader))

In [30]:
class MLP(nn.Module):
    """Small fully connected baseline classifier.

    NaN entries in the input are replaced by 0 before the layers are applied; the
    caller's tensor is not modified.

    Args:
        in_features: size of the last input dimension (default 233).
        hidden_sizes: widths of the hidden layers, each followed by a ReLU
            (default (80, 30)).
        num_classes: number of output logits (default 9).
    """

    def __init__(self, in_features=233, hidden_sizes=(80, 30), num_classes=9):
        super().__init__()
        layers = []
        prev = in_features
        for width in hidden_sizes:
            layers += [nn.Linear(prev, width), nn.ReLU()]
            prev = width
        layers.append(nn.Linear(prev, num_classes))
        # Same Sequential layout as before, so state_dict keys (mlp.0, mlp.2, mlp.4) are unchanged.
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        # Out-of-place fill: assigning into x[mask] would overwrite the caller's batch tensor.
        x = x.masked_fill(torch.isnan(x), 0)
        return self.mlp(x)

In [31]:
MLP = MLP()

In [34]:
MLP = MLP.to(device)

In [35]:
train_net(MLP,train_loader,test_loader,n_iter=100,device=device,mode='K',lr=0.0001)

131it [00:02, 54.41it/s]


epoch : 0, train_acc : 0.15078031212484994, acc : 0.1593090146780014. positive_acc : 0.5873320537428023


131it [00:03, 38.61it/s]


epoch : 1, train_acc : 0.19447779111644659, acc : 0.22072936594486237. positive_acc : 0.5047984644913628


131it [00:03, 38.88it/s]


epoch : 2, train_acc : 0.21464585834333733, acc : 0.22744721174240112. positive_acc : 0.5038387715930902


131it [00:03, 36.92it/s]


epoch : 3, train_acc : 0.21824729891956782, acc : 0.23224566876888275. positive_acc : 0.5134357005758158


131it [00:03, 42.46it/s]


epoch : 4, train_acc : 0.22400960384153662, acc : 0.25335893034935. positive_acc : 0.5508637236084453


131it [00:03, 42.20it/s]


epoch : 5, train_acc : 0.24417767106842736, acc : 0.25335893034935. positive_acc : 0.5518234165067178


131it [00:02, 51.89it/s]


epoch : 6, train_acc : 0.26338535414165665, acc : 0.2879078686237335. positive_acc : 0.6218809980806143


131it [00:02, 47.82it/s]


epoch : 7, train_acc : 0.28667466986794715, acc : 0.3166986405849457. positive_acc : 0.654510556621881


131it [00:02, 45.98it/s]


epoch : 8, train_acc : 0.3085234093637455, acc : 0.3330134153366089. positive_acc : 0.6708253358925144


131it [00:02, 48.06it/s]


epoch : 9, train_acc : 0.31476590636254503, acc : 0.32821497321128845. positive_acc : 0.6765834932821497


131it [00:02, 44.28it/s]


epoch : 10, train_acc : 0.32148859543817526, acc : 0.34165066480636597. positive_acc : 0.7034548944337812


131it [00:02, 44.84it/s]


epoch : 11, train_acc : 0.33085234093637456, acc : 0.35028788447380066. positive_acc : 0.718809980806142


131it [00:02, 47.49it/s]


epoch : 12, train_acc : 0.32941176470588235, acc : 0.35892513394355774. positive_acc : 0.7207293666026872


131it [00:03, 40.41it/s]


epoch : 13, train_acc : 0.3339735894357743, acc : 0.35316696763038635. positive_acc : 0.7303262955854126


131it [00:03, 38.42it/s]


epoch : 14, train_acc : 0.3375750300120048, acc : 0.3541266620159149. positive_acc : 0.7159309021113244


131it [00:03, 40.96it/s]


epoch : 15, train_acc : 0.3397358943577431, acc : 0.3598848283290863. positive_acc : 0.7312859884836852


131it [00:03, 37.68it/s]


epoch : 16, train_acc : 0.3404561824729892, acc : 0.3550863564014435. positive_acc : 0.727447216890595


131it [00:03, 35.07it/s]


epoch : 17, train_acc : 0.3426170468187275, acc : 0.3598848283290863. positive_acc : 0.7293666026871402


131it [00:04, 32.17it/s]


epoch : 18, train_acc : 0.3430972388955582, acc : 0.3598848283290863. positive_acc : 0.7351247600767754


131it [00:03, 33.87it/s]


epoch : 19, train_acc : 0.34861944777911164, acc : 0.36084452271461487. positive_acc : 0.7322456813819578


131it [00:03, 40.11it/s]


epoch : 20, train_acc : 0.3471788715486194, acc : 0.36180421710014343. positive_acc : 0.7408829174664108


131it [00:02, 46.06it/s]


epoch : 21, train_acc : 0.35222088835534215, acc : 0.3550863564014435. positive_acc : 0.736084452975048


131it [00:01, 121.27it/s]


epoch : 22, train_acc : 0.3543817527010804, acc : 0.36372360587120056. positive_acc : 0.7437619961612284


131it [00:01, 117.79it/s]


epoch : 23, train_acc : 0.3536614645858343, acc : 0.368522047996521. positive_acc : 0.7408829174664108


131it [00:01, 104.77it/s]


epoch : 24, train_acc : 0.35558223289315727, acc : 0.36948174238204956. positive_acc : 0.7408829174664108


131it [00:00, 134.48it/s]


epoch : 25, train_acc : 0.3543817527010804, acc : 0.362763911485672. positive_acc : 0.7389635316698656


131it [00:00, 136.98it/s]


epoch : 26, train_acc : 0.35558223289315727, acc : 0.3675623834133148. positive_acc : 0.7428023032629558


131it [00:00, 138.28it/s]


epoch : 27, train_acc : 0.35294117647058826, acc : 0.36948174238204956. positive_acc : 0.7476007677543186


131it [00:00, 135.83it/s]


epoch : 28, train_acc : 0.3577430972388956, acc : 0.368522047996521. positive_acc : 0.7456813819577736


131it [00:00, 132.53it/s]


epoch : 29, train_acc : 0.3575030012004802, acc : 0.3646833002567291. positive_acc : 0.7428023032629558


131it [00:00, 132.38it/s]


epoch : 30, train_acc : 0.35414165666266506, acc : 0.3598848283290863. positive_acc : 0.744721689059501


131it [00:00, 135.81it/s]


epoch : 31, train_acc : 0.3543817527010804, acc : 0.37236082553863525. positive_acc : 0.7418426103646834


131it [00:00, 134.75it/s]


epoch : 32, train_acc : 0.3584633853541417, acc : 0.3733205199241638. positive_acc : 0.7514395393474088


131it [00:00, 134.94it/s]


epoch : 33, train_acc : 0.363265306122449, acc : 0.3656429946422577. positive_acc : 0.736084452975048


131it [00:01, 126.96it/s]


epoch : 34, train_acc : 0.3582232893157263, acc : 0.3675623834133148. positive_acc : 0.7428023032629558


131it [00:00, 135.69it/s]


epoch : 35, train_acc : 0.36470588235294116, acc : 0.3704414367675781. positive_acc : 0.7476007677543186


131it [00:00, 132.53it/s]


epoch : 36, train_acc : 0.3618247298919568, acc : 0.3675623834133148. positive_acc : 0.7437619961612284


131it [00:00, 137.14it/s]


epoch : 37, train_acc : 0.36206482593037215, acc : 0.368522047996521. positive_acc : 0.7456813819577736


131it [00:01, 130.21it/s]


epoch : 38, train_acc : 0.36686674669867947, acc : 0.362763911485672. positive_acc : 0.753358925143954


131it [00:00, 134.41it/s]


epoch : 39, train_acc : 0.36566626650660267, acc : 0.36660268902778625. positive_acc : 0.746641074856046


131it [00:00, 134.92it/s]


epoch : 40, train_acc : 0.3644657863145258, acc : 0.3714011311531067. positive_acc : 0.7523992322456814


131it [00:01, 126.90it/s]


epoch : 41, train_acc : 0.36470588235294116, acc : 0.3733205199241638. positive_acc : 0.7571976967370442


131it [00:00, 131.59it/s]


epoch : 42, train_acc : 0.3678271308523409, acc : 0.36948174238204956. positive_acc : 0.7523992322456814


131it [00:00, 138.03it/s]


epoch : 43, train_acc : 0.3651860744297719, acc : 0.3704414367675781. positive_acc : 0.755278310940499


131it [00:00, 134.34it/s]


epoch : 44, train_acc : 0.36998799519807923, acc : 0.37236082553863525. positive_acc : 0.744721689059501


131it [00:00, 134.89it/s]


epoch : 45, train_acc : 0.373109243697479, acc : 0.3761996030807495. positive_acc : 0.753358925143954


131it [00:00, 136.62it/s]


epoch : 46, train_acc : 0.36998799519807923, acc : 0.3761996030807495. positive_acc : 0.7523992322456814


131it [00:00, 132.60it/s]


epoch : 47, train_acc : 0.37262905162064824, acc : 0.37811899185180664. positive_acc : 0.7543186180422264


131it [00:00, 134.93it/s]


epoch : 48, train_acc : 0.3745498199279712, acc : 0.37811899185180664. positive_acc : 0.7523992322456814


131it [00:00, 136.40it/s]


epoch : 49, train_acc : 0.3723889555822329, acc : 0.3771592974662781. positive_acc : 0.7495201535508638


131it [00:01, 129.80it/s]


epoch : 50, train_acc : 0.3697478991596639, acc : 0.37811899185180664. positive_acc : 0.7562380038387716


131it [00:00, 139.37it/s]


epoch : 51, train_acc : 0.3716686674669868, acc : 0.3761996030807495. positive_acc : 0.7523992322456814


131it [00:01, 129.32it/s]


epoch : 52, train_acc : 0.3764705882352941, acc : 0.37811899185180664. positive_acc : 0.755278310940499


131it [00:00, 132.60it/s]


epoch : 53, train_acc : 0.37551020408163266, acc : 0.3761996030807495. positive_acc : 0.7591170825335892


131it [00:00, 136.75it/s]


epoch : 54, train_acc : 0.3810324129651861, acc : 0.3761996030807495. positive_acc : 0.7504798464491362


131it [00:00, 132.35it/s]


epoch : 55, train_acc : 0.3743097238895558, acc : 0.37523990869522095. positive_acc : 0.7591170825335892


131it [00:00, 135.55it/s]


epoch : 56, train_acc : 0.3786314525810324, acc : 0.38291746377944946. positive_acc : 0.7543186180422264


131it [00:00, 136.16it/s]


epoch : 57, train_acc : 0.37551020408163266, acc : 0.38099807500839233. positive_acc : 0.7562380038387716


131it [00:00, 134.33it/s]


epoch : 58, train_acc : 0.37623049219687876, acc : 0.37523990869522095. positive_acc : 0.7514395393474088


131it [00:00, 133.47it/s]


epoch : 59, train_acc : 0.382953181272509, acc : 0.38099807500839233. positive_acc : 0.7591170825335892


131it [00:00, 132.67it/s]


epoch : 60, train_acc : 0.3824729891956783, acc : 0.3550863564014435. positive_acc : 0.7303262955854126


131it [00:00, 139.09it/s]


epoch : 61, train_acc : 0.37623049219687876, acc : 0.37811899185180664. positive_acc : 0.755278310940499


131it [00:00, 136.61it/s]


epoch : 62, train_acc : 0.3791116446578631, acc : 0.3704414367675781. positive_acc : 0.7485604606525912


131it [00:00, 136.88it/s]


epoch : 63, train_acc : 0.3824729891956783, acc : 0.3646833002567291. positive_acc : 0.7408829174664108


131it [00:00, 134.20it/s]


epoch : 64, train_acc : 0.38343337334933975, acc : 0.3714011311531067. positive_acc : 0.7485604606525912


131it [00:01, 127.11it/s]


epoch : 65, train_acc : 0.3865546218487395, acc : 0.37236082553863525. positive_acc : 0.7456813819577736


131it [00:00, 133.91it/s]


epoch : 66, train_acc : 0.3865546218487395, acc : 0.3761996030807495. positive_acc : 0.7610364683301344


131it [00:01, 130.73it/s]


epoch : 67, train_acc : 0.3843937575030012, acc : 0.3714011311531067. positive_acc : 0.7495201535508638


131it [00:00, 139.07it/s]


epoch : 68, train_acc : 0.3858343337334934, acc : 0.3704414367675781. positive_acc : 0.7571976967370442


131it [00:01, 127.52it/s]


epoch : 69, train_acc : 0.38559423769507806, acc : 0.3819577693939209. positive_acc : 0.7600767754318618


131it [00:00, 135.82it/s]


epoch : 70, train_acc : 0.3896758703481393, acc : 0.38003838062286377. positive_acc : 0.7600767754318618


131it [00:00, 131.60it/s]


epoch : 71, train_acc : 0.38679471788715486, acc : 0.38291746377944946. positive_acc : 0.7571976967370442


131it [00:00, 134.16it/s]


epoch : 72, train_acc : 0.38871548619447777, acc : 0.36660268902778625. positive_acc : 0.7495201535508638


131it [00:00, 135.45it/s]


epoch : 73, train_acc : 0.38943577430972387, acc : 0.3733205199241638. positive_acc : 0.7476007677543186


131it [00:00, 132.01it/s]


epoch : 74, train_acc : 0.3865546218487395, acc : 0.38579654693603516. positive_acc : 0.7571976967370442


131it [00:00, 133.90it/s]


epoch : 75, train_acc : 0.3877551020408163, acc : 0.3771592974662781. positive_acc : 0.7543186180422264


131it [00:00, 133.37it/s]


epoch : 76, train_acc : 0.3884753901560624, acc : 0.35028788447380066. positive_acc : 0.7389635316698656


131it [00:00, 131.75it/s]


epoch : 77, train_acc : 0.38943577430972387, acc : 0.3570057451725006. positive_acc : 0.7351247600767754


131it [00:00, 134.91it/s]


epoch : 78, train_acc : 0.3896758703481393, acc : 0.36660268902778625. positive_acc : 0.7629558541266794


131it [00:01, 127.20it/s]


epoch : 79, train_acc : 0.3944777911164466, acc : 0.3714011311531067. positive_acc : 0.7591170825335892


131it [00:00, 137.54it/s]


epoch : 80, train_acc : 0.39207683073229294, acc : 0.3656429946422577. positive_acc : 0.7408829174664108


131it [00:00, 132.01it/s]


epoch : 81, train_acc : 0.39471788715486195, acc : 0.3704414367675781. positive_acc : 0.7523992322456814


131it [00:00, 136.76it/s]


epoch : 82, train_acc : 0.39207683073229294, acc : 0.34740883111953735. positive_acc : 0.7351247600767754


131it [00:00, 140.69it/s]


epoch : 83, train_acc : 0.39471788715486195, acc : 0.368522047996521. positive_acc : 0.7428023032629558


131it [00:00, 134.83it/s]


epoch : 84, train_acc : 0.38943577430972387, acc : 0.3733205199241638. positive_acc : 0.7523992322456814


131it [00:00, 133.80it/s]


epoch : 85, train_acc : 0.39327731092436974, acc : 0.3771592974662781. positive_acc : 0.7523992322456814


131it [00:01, 129.51it/s]


epoch : 86, train_acc : 0.39663865546218485, acc : 0.3550863564014435. positive_acc : 0.7293666026871402


131it [00:00, 132.22it/s]


epoch : 87, train_acc : 0.3990396158463385, acc : 0.3675623834133148. positive_acc : 0.7341650671785028


131it [00:00, 135.28it/s]


epoch : 88, train_acc : 0.39663865546218485, acc : 0.3541266620159149. positive_acc : 0.7341650671785028


131it [00:00, 131.50it/s]


epoch : 89, train_acc : 0.40024009603841537, acc : 0.36180421710014343. positive_acc : 0.7437619961612284


131it [00:00, 132.68it/s]


epoch : 90, train_acc : 0.39879951980792316, acc : 0.35892513394355774. positive_acc : 0.7571976967370442


131it [00:00, 136.76it/s]


epoch : 91, train_acc : 0.39951980792316927, acc : 0.3675623834133148. positive_acc : 0.7437619961612284


131it [00:00, 134.26it/s]


epoch : 92, train_acc : 0.3997599039615846, acc : 0.36660268902778625. positive_acc : 0.7514395393474088


131it [00:00, 136.58it/s]


epoch : 93, train_acc : 0.39951980792316927, acc : 0.36372360587120056. positive_acc : 0.7523992322456814


131it [00:00, 135.94it/s]


epoch : 94, train_acc : 0.4031212484993998, acc : 0.35604605078697205. positive_acc : 0.7332053742802304


131it [00:00, 132.32it/s]


epoch : 95, train_acc : 0.39879951980792316, acc : 0.3579654395580292. positive_acc : 0.7504798464491362


131it [00:01, 129.25it/s]


epoch : 96, train_acc : 0.40336134453781514, acc : 0.34452974796295166. positive_acc : 0.727447216890595


131it [00:01, 92.25it/s] 


epoch : 97, train_acc : 0.40528211284513804, acc : 0.3656429946422577. positive_acc : 0.7476007677543186


131it [00:01, 67.89it/s]


epoch : 98, train_acc : 0.40480192076830734, acc : 0.36180421710014343. positive_acc : 0.7332053742802304


131it [00:02, 56.65it/s]


KeyboardInterrupt: 

In [None]:
# Draft joint E/K/M pass: one ViT per target, all fed the same embedded sequence.
# NOTE(review): `ViT_LRP_copy` is not imported by the visible imports cell (only
# `ViT_LRP_nan_excluded` is) — unless it comes from `from utils.nn_utils import *`,
# this raises NameError on a fresh kernel. TODO confirm the intended module.
# NOTE(review): `split_list`, `embbeding_networks` and `make_attn_mask` are likewise not
# defined in any visible cell — TODO confirm where they come from.
model_E, model_K, model_M = (
    ViT_LRP_copy.VisionTransformer(seq_len=seq_len, num_classes=len(CLS2IDX), embed_dim=embbed_dim, depth=8,
                                   num_heads=6, mlp_ratio=4., qkv_bias=False, mlp_head=False,
                                   drop_rate=0.1, attn_drop_rate=0.1)
    for _ in range(3)
)

criterion = nn.CrossEntropyLoss()

for concated_data, (label_E, label_K, label_M) in train_loader:
    datas = batch_to_splited_datas(concated_data, split_list)
    emb_batch_list = batch_to_embbedings(datas, embbeding_networks)  # can be used for contrastive loss
    # emb_batch_list: list of embedding tensors; stack them into a batch x seq x feature
    # tensor (stack puts the list index first, transpose moves it to dim 1).
    emb_batched_seq = torch.stack(emb_batch_list).transpose(0, 1)
    attn_mask = make_attn_mask(emb_batched_seq)
    E_score = model_E(emb_batched_seq, attn_mask)
    K_score = model_K(emb_batched_seq, attn_mask)
    M_score = model_M(emb_batched_seq, attn_mask)
    # Score every batch; after the loop loss_* hold the last batch's values.
    # TODO: no backward()/optimizer step is taken yet.
    loss_E = criterion(E_score, label_E)
    loss_K = criterion(K_score, label_K)
    loss_M = criterion(M_score, label_M)