In [1]:
import torch
import pandas as pd
import sklearn
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from tqdm import tqdm
from torch.utils.data import Dataset,DataLoader
from utils.nn_utils import *
from models.ViT import ViT_LRP_nan_excluded

# ---------------------------------------------------------------------------
# Data configuration and loading
# ---------------------------------------------------------------------------
# Which preprocessed variant to read. Switch to 'fill' to load the files under
# ./preprocessed/prepared/fill/ instead (same file names, different directory).
DATA_VARIANT = 'nan'
DATA_DIR = f'./preprocessed/prepared/{DATA_VARIANT}'
N_YEARS = 6  # one pickle per year: L2Y1.pkl ... L2Y6.pkl

X_datapaths = [f'{DATA_DIR}/L2Y{year}.pkl' for year in range(1, N_YEARS + 1)]
label_datapath = f'{DATA_DIR}/label.pkl'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# SECURITY: pd.read_pickle can execute arbitrary code while unpickling.
# Only load pickle files that were produced by this project's own preprocessing.

# One DataFrame per year, each with its index moved into a regular column.
input_datas = [pd.read_pickle(path).reset_index() for path in X_datapaths]

label_data = pd.read_pickle(label_datapath).reset_index()

# Sequence length seen by the model = number of yearly inputs.
seq_len = len(input_datas)

# Shift labels down by one so that grade 1 becomes class index 0 (see CLS2IDX).
# NOTE(review): reset_index() above turns the former index into a column, so this
# subtraction also shifts that column by one - confirm that make_splited_data
# ignores or drops it.
label_data = label_data - 1


# Human-readable grade label for each class index (0 -> '1등급', ..., 8 -> '9등급').
# Note: despite its name, this dictionary maps index -> label, not label -> index.
CLS2IDX = {idx: f'{idx + 1}등급' for idx in range(9)}
# --- Run configuration -------------------------------------------------------
is_regression = False
batch_size = 32
hidden_features = 100
embbed_dim = 72

# --- Train/test split and dataset wrappers -----------------------------------
X_trains, X_tests, y_train, y_test = make_splited_data(
    input_datas,
    label_data,
    is_regression=is_regression,
)

train_dataset = KELSDataSet(X_trains, y_train, is_regression=is_regression)
test_dataset = KELSDataSet(X_tests, y_test, is_regression=is_regression)

# --- Loaders: shuffled for training, fixed order for evaluation --------------
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# split_list is handed to batch_to_splited_datas to cut a concatenated batch
# back into its per-year pieces.
split_list = train_dataset.split_list
# embedding_networks: list of MLPs, one per year. The number of input channels
# differs from entry to entry depending on that year's contents.

# --- One sample batch, split per year, used later to construct the models ----
sample_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
assert batch_size
sample, label = next(iter(sample_loader))
sample_datas = batch_to_splited_datas(sample, split_list)


#embbeding_networks = make_embbeding_networks(sample_datas,batch_size=batch_size,hidden_features = hidden_features,out_features=embbed_dim)
# How to use the embedding networks:
# 1. Take a batch from train_loader. The batch is (batch, number of students, features), where features is the concatenation of every year's features.
# 2. Pass the batch to batch_to_splited_datas: datas = batch_to_splited_datas(batch, split_list)
#    datas is then a list holding the features split per year. The list length is the number of years and each element is a matrix.
# 3. Pass datas to batch_to_embbedings. Each matrix goes through the MLP stored in embbeding_networks, and matrices of equal size are returned.
#    emb_batch_list = batch_to_embbedings(datas, embbeding_networks)
# 4. The elements of emb_batch_list are the vectors that go into the contrastive loss.

# 5. To feed them to an LSTM they have to be stacked along the seq_len dimension.
# emb_batched_seq = torch.stack(emb_batch_list).transpose(0,1) stacks them; feeding that to the LSTM should finish the job (unverified).

In [2]:
def accuracy_roughly(y_pred, y_label):
    """Return the fraction of predictions that are within one class of the label.

    A prediction counts as a hit when ``abs(pred - label) <= 1``.

    Args:
        y_pred: sized iterable of predicted class indices.
        y_label: sized iterable of true class indices, same length as ``y_pred``.

    Returns:
        ``hits / len(y_pred)`` as a float, or ``None`` (after printing a notice)
        when the two inputs differ in length.

    Raises:
        ZeroDivisionError: if both inputs are empty.
    """
    total = len(y_pred)
    if total != len(y_label):
        print("not available, fit size first")
        return None

    # Count the pairs whose predicted and true class differ by at most one.
    hits = sum(1 for guess, truth in zip(y_pred, y_label) if abs(guess - truth) <= 1)
    return hits / total

In [9]:
def train_net(model,train_loader,test_loader,optimizer_cls = optim.AdamW, criterion = nn.CrossEntropyLoss(),
n_iter=10,device='cpu',lr = 0.001,weight_decay = 0.01,mode = None):
    """Train ``model`` on one of the three label heads and report accuracy per epoch.

    Args:
        model: network called as ``model(xx)``; its output is reduced with
            ``.max(1)`` to obtain predicted class indices. The model is NOT moved
            to ``device`` here - the caller must do that.
        train_loader: iterable yielding ``(xx, (label_E, label_K, label_M))``.
        test_loader: iterable of the same structure, passed to ``eval_net``.
        optimizer_cls: optimizer class, constructed as ``optimizer_cls(params, lr=lr)``.
        criterion: loss called as ``criterion(outputs, yy)``.
        n_iter: number of epochs.
        device: device the batches and labels are moved to.
        lr: learning rate.
        weight_decay: accepted for interface compatibility only. It is not
            forwarded to the optimizer (the forwarding call was disabled in the
            original code), so ``optimizer_cls`` uses its own default.
        mode: which label to train on - ``'E'``, ``'K'`` or ``'M'``.

    Returns:
        dict with per-epoch lists ``'train_loss'`` (mean batch loss),
        ``'train_acc'`` (exact-match accuracy) and ``'val_acc'`` (whatever
        ``eval_net`` returns for ``test_loader``).

    Raises:
        ValueError: if ``mode`` is not one of ``'E'``, ``'K'``, ``'M'``, or if
            ``train_loader`` yields no batches.
    """
    # Fail fast: previously an unknown mode fell through a no-op `assert True`
    # and only crashed later with an unbound label variable.
    if mode not in ('E', 'K', 'M'):
        raise ValueError(f"mode must be one of 'E', 'K', 'M', got {mode!r}")

    train_losses = []
    train_acc = []
    val_acc = []
    optimizer = optimizer_cls(model.parameters(), lr=lr)

    for epoch in range(n_iter):
        model.train()
        running_loss = 0.0
        n_batches = 0
        n = 0
        n_acc = 0
        for xx, (label_E, label_K, label_M) in tqdm(train_loader):
            xx = xx.to(device)
            yy = {'E': label_E, 'K': label_K, 'M': label_M}[mode].to(device)

            optimizer.zero_grad()
            outputs = model(xx)
            loss = criterion(outputs, yy)
            loss.backward()
            optimizer.step()

            # Accumulate the batch loss so the epoch mean is meaningful
            # (it was never updated before and therefore always 0).
            running_loss += loss.item()
            n_batches += 1
            n += len(xx)
            _, y_pred = outputs.max(1)
            n_acc += (yy == y_pred).float().sum().item()

        if n_batches == 0:
            raise ValueError("train_loader yielded no batches")

        train_losses.append(running_loss / n_batches)
        train_acc.append(n_acc / n)
        val_acc.append(eval_net(model, test_loader, device, mode=mode))

        print(f'epoch : {epoch}, train_acc : {train_acc[-1]}, validation_acc : {val_acc[-1]}',flush = True)

    return {'train_loss': train_losses, 'train_acc': train_acc, 'val_acc': val_acc}

In [12]:
def eval_net(model,data_loader,device,mode=None):
    """Evaluate ``model`` on ``data_loader`` and return its within-one-class accuracy.

    A prediction counts as correct when it differs from the label by at most
    one class index. This is the same criterion as ``accuracy_roughly``, computed
    on whole tensors instead of element by element, and it is more lenient than
    exact-match accuracy.

    Args:
        model: network called as ``model(xx)``; predictions are ``score.max(1)`` indices.
            The model is switched to eval mode and left there.
        data_loader: iterable yielding ``(xx, (label_E, label_K, label_M))``.
        device: device the inputs and labels are moved to.
        mode: which label to score against - ``'E'``, ``'K'`` or ``'M'``.

    Returns:
        float: number of predictions with ``|pred - label| <= 1`` divided by the
        number of evaluated samples.

    Raises:
        ValueError: if ``mode`` is not one of ``'E'``, ``'K'``, ``'M'``, or if
            ``data_loader`` yields no batches.
    """
    # Fail fast: previously an unknown mode fell through a no-op `assert True`
    # and only crashed later with an unbound label variable.
    if mode not in ('E', 'K', 'M'):
        raise ValueError(f"mode must be one of 'E', 'K', 'M', got {mode!r}")

    model.eval()
    ys = []
    ypreds = []
    with torch.no_grad():
        for xx, (label_E, label_K, label_M) in data_loader:
            xx = xx.to(device)
            y = {'E': label_E, 'K': label_K, 'M': label_M}[mode].to(device)

            score = model(xx)
            _, y_pred = score.max(1)
            ys.append(y)
            ypreds.append(y_pred)

    if not ys:
        raise ValueError("data_loader yielded no batches")

    ys = torch.cat(ys)
    ypreds = torch.cat(ypreds)

    # Single vectorised count instead of a Python loop over tensor elements.
    n_close = ((ypreds - ys).abs() <= 1).sum().item()
    return n_close / len(ys)

In [None]:
# Build the transformer for the 'E' label and place it on the training device.
model_E = ViT_LRP_nan_excluded.VisionTransformer(
    sample_datas,
    split_list,
    seq_len=6,
    num_classes=9,
    embed_dim=16*3,
    depth=8,
    num_heads=6,
    mlp_ratio=4.,
    qkv_bias=False,
    mlp_head=False,
    drop_rate=0.2,
    attn_drop_rate=0.2,
).to(device)

# Train for 100 epochs against the 'E' labels.
train_net(
    model_E,
    train_loader,
    test_loader,
    n_iter=100,
    device=device,
    mode='E',
    lr=0.0001,
    optimizer_cls=optim.AdamW,
)


In [13]:
# Build the transformer for the 'K' label and place it on the training device.
model_K = ViT_LRP_nan_excluded.VisionTransformer(
    sample_datas,
    split_list,
    seq_len=6,
    num_classes=9,
    embed_dim=16*3,
    depth=8,
    num_heads=6,
    mlp_ratio=4.,
    qkv_bias=False,
    mlp_head=False,
    drop_rate=0.2,
    attn_drop_rate=0.2,
).to(device)

# Train for 100 epochs against the 'K' labels.
train_net(
    model_K,
    train_loader,
    test_loader,
    n_iter=100,
    device=device,
    mode='K',
    lr=0.0001,
    optimizer_cls=optim.AdamW,
)


131it [00:09, 14.07it/s]


epoch : 0, train_acc : 0.21272509003601442, validation_acc : 0.527831094049904


131it [00:09, 14.13it/s]


epoch : 1, train_acc : 0.25618247298919566, validation_acc : 0.6564299424184261


131it [00:09, 14.12it/s]


epoch : 2, train_acc : 0.3106842737094838, validation_acc : 0.6813819577735125


131it [00:09, 14.14it/s]


epoch : 3, train_acc : 0.3176470588235294, validation_acc : 0.6919385796545106


131it [00:09, 13.95it/s]


epoch : 4, train_acc : 0.32821128451380555, validation_acc : 0.7428023032629558


131it [00:09, 13.87it/s]


epoch : 5, train_acc : 0.33733493397358943, validation_acc : 0.7197696737044146


131it [00:09, 14.16it/s]


epoch : 6, train_acc : 0.3404561824729892, validation_acc : 0.6986564299424184


131it [00:09, 14.25it/s]


epoch : 7, train_acc : 0.33997599039615845, validation_acc : 0.7303262955854126


131it [00:09, 14.19it/s]


epoch : 8, train_acc : 0.34549819927971187, validation_acc : 0.7226487523992322


131it [00:09, 14.24it/s]


epoch : 9, train_acc : 0.3498199279711885, validation_acc : 0.7216890595009597


131it [00:09, 14.24it/s]


epoch : 10, train_acc : 0.356062424969988, validation_acc : 0.72552783109405


131it [00:09, 14.23it/s]


epoch : 11, train_acc : 0.3596638655462185, validation_acc : 0.7389635316698656


131it [00:09, 14.24it/s]


epoch : 12, train_acc : 0.36134453781512604, validation_acc : 0.7428023032629558


131it [00:09, 14.28it/s]


epoch : 13, train_acc : 0.3551020408163265, validation_acc : 0.6813819577735125


131it [00:09, 14.24it/s]


epoch : 14, train_acc : 0.35222088835534215, validation_acc : 0.7303262955854126


131it [00:09, 14.25it/s]


epoch : 15, train_acc : 0.36062424969987994, validation_acc : 0.7245681381957774


131it [00:09, 14.25it/s]


epoch : 16, train_acc : 0.36494597839135656, validation_acc : 0.7418426103646834


131it [00:09, 14.10it/s]


epoch : 17, train_acc : 0.37623049219687876, validation_acc : 0.7293666026871402


131it [00:09, 14.18it/s]


epoch : 18, train_acc : 0.37478991596638656, validation_acc : 0.7408829174664108


131it [00:09, 14.22it/s]


epoch : 19, train_acc : 0.36542617046818726, validation_acc : 0.7543186180422264


131it [00:09, 14.20it/s]


epoch : 20, train_acc : 0.37671068427370946, validation_acc : 0.7437619961612284


131it [00:09, 14.17it/s]


epoch : 21, train_acc : 0.3702280912364946, validation_acc : 0.7399232245681382


131it [00:09, 14.21it/s]


epoch : 22, train_acc : 0.37695078031212487, validation_acc : 0.7571976967370442


131it [00:09, 14.19it/s]


epoch : 23, train_acc : 0.37286914765906365, validation_acc : 0.6900191938579654


131it [00:09, 14.19it/s]


epoch : 24, train_acc : 0.3831932773109244, validation_acc : 0.7428023032629558


131it [00:09, 14.18it/s]


epoch : 25, train_acc : 0.37478991596638656, validation_acc : 0.755278310940499


131it [00:09, 14.19it/s]


epoch : 26, train_acc : 0.3798319327731092, validation_acc : 0.7456813819577736


131it [00:09, 14.18it/s]


epoch : 27, train_acc : 0.37551020408163266, validation_acc : 0.7514395393474088


131it [00:09, 14.03it/s]


epoch : 28, train_acc : 0.37743097238895557, validation_acc : 0.736084452975048


131it [00:09, 14.19it/s]


epoch : 29, train_acc : 0.3788715486194478, validation_acc : 0.708253358925144


131it [00:09, 14.20it/s]


epoch : 30, train_acc : 0.38127250900360143, validation_acc : 0.736084452975048


131it [00:09, 14.19it/s]


epoch : 31, train_acc : 0.3899159663865546, validation_acc : 0.7456813819577736


131it [00:09, 14.21it/s]


epoch : 32, train_acc : 0.3870348139255702, validation_acc : 0.7178502879078695


131it [00:09, 14.21it/s]


epoch : 33, train_acc : 0.38559423769507806, validation_acc : 0.7629558541266794


131it [00:09, 14.19it/s]


epoch : 34, train_acc : 0.38487394957983195, validation_acc : 0.7293666026871402


131it [00:09, 14.18it/s]


epoch : 35, train_acc : 0.3884753901560624, validation_acc : 0.7264875239923224


131it [00:09, 14.21it/s]


epoch : 36, train_acc : 0.40336134453781514, validation_acc : 0.7562380038387716


131it [00:09, 14.20it/s]


epoch : 37, train_acc : 0.3870348139255702, validation_acc : 0.753358925143954


131it [00:09, 13.96it/s]


epoch : 38, train_acc : 0.3851140456182473, validation_acc : 0.7437619961612284


131it [00:09, 14.18it/s]


epoch : 39, train_acc : 0.3983193277310924, validation_acc : 0.7341650671785028


131it [00:09, 14.20it/s]


epoch : 40, train_acc : 0.40024009603841537, validation_acc : 0.7370441458733206


131it [00:09, 14.19it/s]


epoch : 41, train_acc : 0.3911164465786314, validation_acc : 0.7236084452975048


131it [00:09, 14.17it/s]


epoch : 42, train_acc : 0.4050420168067227, validation_acc : 0.7197696737044146


131it [00:09, 14.20it/s]


epoch : 43, train_acc : 0.4007202881152461, validation_acc : 0.7504798464491362


131it [00:09, 14.22it/s]


epoch : 44, train_acc : 0.4069627851140456, validation_acc : 0.7370441458733206


131it [00:09, 14.19it/s]


epoch : 45, train_acc : 0.40888355342136856, validation_acc : 0.7322456813819578


131it [00:09, 14.19it/s]


epoch : 46, train_acc : 0.4105642256902761, validation_acc : 0.763915547024952


131it [00:09, 14.19it/s]


epoch : 47, train_acc : 0.3937575030012005, validation_acc : 0.7485604606525912


131it [00:09, 14.20it/s]


epoch : 48, train_acc : 0.40408163265306124, validation_acc : 0.7332053742802304


131it [00:09, 14.20it/s]


epoch : 49, train_acc : 0.4141656662665066, validation_acc : 0.7696737044145874


131it [00:09, 14.19it/s]


epoch : 50, train_acc : 0.42400960384153663, validation_acc : 0.7581573896353166


131it [00:09, 14.18it/s]


epoch : 51, train_acc : 0.40960384153661467, validation_acc : 0.746641074856046


131it [00:09, 14.20it/s]


epoch : 52, train_acc : 0.4156062424969988, validation_acc : 0.746641074856046


131it [00:09, 14.19it/s]


epoch : 53, train_acc : 0.40888355342136856, validation_acc : 0.736084452975048


131it [00:09, 14.01it/s]


epoch : 54, train_acc : 0.4112845138055222, validation_acc : 0.7149712092130518


131it [00:09, 14.21it/s]


epoch : 55, train_acc : 0.4283313325330132, validation_acc : 0.7264875239923224


131it [00:09, 14.18it/s]


epoch : 56, train_acc : 0.42208883553421367, validation_acc : 0.7245681381957774


131it [00:09, 14.21it/s]


epoch : 57, train_acc : 0.42016806722689076, validation_acc : 0.7399232245681382


131it [00:09, 14.18it/s]


epoch : 58, train_acc : 0.42208883553421367, validation_acc : 0.7216890595009597


131it [00:09, 14.19it/s]


epoch : 59, train_acc : 0.42545018007202884, validation_acc : 0.7571976967370442


131it [00:09, 14.21it/s]


epoch : 60, train_acc : 0.43025210084033616, validation_acc : 0.736084452975048


131it [00:09, 14.20it/s]


epoch : 61, train_acc : 0.4261704681872749, validation_acc : 0.7677543186180422


131it [00:09, 14.20it/s]


epoch : 62, train_acc : 0.42208883553421367, validation_acc : 0.7197696737044146


131it [00:09, 14.21it/s]


epoch : 63, train_acc : 0.4391356542617047, validation_acc : 0.7485604606525912


131it [00:09, 14.20it/s]


epoch : 64, train_acc : 0.43649459783913563, validation_acc : 0.7571976967370442


131it [00:09, 14.22it/s]


epoch : 65, train_acc : 0.4326530612244898, validation_acc : 0.7514395393474088


131it [00:09, 14.18it/s]


epoch : 66, train_acc : 0.4261704681872749, validation_acc : 0.7408829174664108


131it [00:09, 13.99it/s]


epoch : 67, train_acc : 0.43673469387755104, validation_acc : 0.7399232245681382


131it [00:09, 14.15it/s]


epoch : 68, train_acc : 0.42448979591836733, validation_acc : 0.7284069097888676


131it [00:09, 14.17it/s]


epoch : 69, train_acc : 0.4391356542617047, validation_acc : 0.7514395393474088


131it [00:09, 14.20it/s]


epoch : 70, train_acc : 0.4429771908763505, validation_acc : 0.7504798464491362


131it [00:09, 14.12it/s]


epoch : 71, train_acc : 0.4456182472989196, validation_acc : 0.7476007677543186


131it [00:09, 13.62it/s]


epoch : 72, train_acc : 0.44657863145258103, validation_acc : 0.7399232245681382


131it [00:09, 14.09it/s]


epoch : 73, train_acc : 0.44537815126050423, validation_acc : 0.7303262955854126


131it [00:09, 14.21it/s]


epoch : 74, train_acc : 0.4410564225690276, validation_acc : 0.7408829174664108


131it [00:09, 14.24it/s]


epoch : 75, train_acc : 0.4417767106842737, validation_acc : 0.746641074856046


131it [00:09, 14.17it/s]


epoch : 76, train_acc : 0.45138055222088835, validation_acc : 0.7562380038387716


131it [00:09, 14.04it/s]


epoch : 77, train_acc : 0.45546218487394957, validation_acc : 0.7408829174664108


131it [00:09, 14.25it/s]


epoch : 78, train_acc : 0.4523409363745498, validation_acc : 0.761996161228407


131it [00:09, 14.14it/s]


epoch : 79, train_acc : 0.45930372148859544, validation_acc : 0.736084452975048


131it [00:09, 14.25it/s]


epoch : 80, train_acc : 0.4451380552220888, validation_acc : 0.7456813819577736


131it [00:09, 14.26it/s]


epoch : 81, train_acc : 0.458343337334934, validation_acc : 0.7504798464491362


131it [00:09, 14.08it/s]


epoch : 82, train_acc : 0.4530612244897959, validation_acc : 0.7399232245681382


131it [00:09, 14.06it/s]


epoch : 83, train_acc : 0.4645858343337335, validation_acc : 0.7610364683301344


131it [00:09, 14.25it/s]


epoch : 84, train_acc : 0.4595438175270108, validation_acc : 0.7437619961612284


131it [00:09, 14.24it/s]


epoch : 85, train_acc : 0.4590636254501801, validation_acc : 0.7303262955854126


131it [00:09, 14.24it/s]


epoch : 86, train_acc : 0.46146458583433375, validation_acc : 0.736084452975048


131it [00:09, 14.09it/s]


epoch : 87, train_acc : 0.4696278511404562, validation_acc : 0.7437619961612284


131it [00:09, 14.26it/s]


epoch : 88, train_acc : 0.46914765906362543, validation_acc : 0.7332053742802304


131it [00:09, 14.25it/s]


epoch : 89, train_acc : 0.4696278511404562, validation_acc : 0.7389635316698656


131it [00:09, 14.26it/s]


epoch : 90, train_acc : 0.4816326530612245, validation_acc : 0.7428023032629558


131it [00:09, 14.25it/s]


epoch : 91, train_acc : 0.4801920768307323, validation_acc : 0.7562380038387716


131it [00:09, 14.26it/s]


epoch : 92, train_acc : 0.46482593037214887, validation_acc : 0.718809980806142


131it [00:09, 14.22it/s]


epoch : 93, train_acc : 0.4773109243697479, validation_acc : 0.7485604606525912


131it [00:09, 14.24it/s]


epoch : 94, train_acc : 0.4669867947178872, validation_acc : 0.7495201535508638


131it [00:09, 14.25it/s]


epoch : 95, train_acc : 0.48403361344537815, validation_acc : 0.7389635316698656


131it [00:09, 14.07it/s]


epoch : 96, train_acc : 0.48187274909963984, validation_acc : 0.7485604606525912


131it [00:09, 14.23it/s]


epoch : 97, train_acc : 0.4929171668667467, validation_acc : 0.7351247600767754


131it [00:09, 14.22it/s]


epoch : 98, train_acc : 0.49123649459783914, validation_acc : 0.7581573896353166


131it [00:09, 14.21it/s]


epoch : 99, train_acc : 0.48259303721488594, validation_acc : 0.755278310940499


In [10]:
# Fetch a single batch from train_loader for interactive inspection: xx is the input
# batch and label_E / label_K / label_M are the three label tensors that come with it.
xx,(label_E,label_K,label_M) = next(iter(train_loader))

In [46]:
class MLP(nn.Module):
    """Three-layer fully connected baseline classifier that treats NaN inputs as 0.

    Args:
        in_features: width of the input vectors. Defaults to 233, the value
            hard-coded originally (presumably the concatenated feature width of
            the dataset - confirm against the loader).
        num_classes: number of output logits. Defaults to 9.
    """
    def __init__(self, in_features=233, num_classes=9):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_features,80),
            nn.ReLU(),
            nn.Linear(80,30),
            nn.ReLU(),
            nn.Linear(30,num_classes)
        )
    def forward(self,x):
        """Return class logits for ``x`` with every NaN entry replaced by 0.

        The replacement is done out of place so the caller's tensor is left
        untouched (the earlier in-place assignment overwrote the input batch).
        """
        x = x.masked_fill(torch.isnan(x), 0.0)
        return self.mlp(x)

In [47]:
# Instantiate the baseline network.
# NOTE(review): this rebinds the name `MLP` from the class to an instance, so the class
# is no longer reachable by name and re-running this cell calls the instance instead of
# constructing a new model. A distinct instance name would avoid that, but the later
# train_net(MLP, ...) cell would have to be renamed together with it.
MLP = MLP()

In [11]:
# train_net moves only the batches and labels to `device`, never the model, so the MLP
# has to be placed on that device here; otherwise training on CUDA fails with a device
# mismatch. nn.Module.to() returns the module itself, so `MLP` keeps pointing at it.
train_net(MLP.to(device),train_loader,test_loader,n_iter=100,device=device,mode='E',lr=0.0001)

NameError: name 'MLP' is not defined

In [None]:
# Earlier experiment: one transformer per label head, fed with pre-computed per-year
# embeddings plus an attention mask.
# NOTE(review): `ViT_LRP_copy` is not imported in any cell shown in this notebook, and
# `embbeding_networks` is only created in a commented-out line of the first cell, so
# this cell cannot run on a fresh kernel as written.
# NOTE(review): these assignments overwrite any model_E / model_K created by earlier cells.
model_E = ViT_LRP_copy.VisionTransformer(seq_len=6, num_classes=9, embed_dim=72, depth=8,
                 num_heads=6, mlp_ratio=4., qkv_bias=False, mlp_head=False, drop_rate=0.1, attn_drop_rate=0.1)
model_K = ViT_LRP_copy.VisionTransformer(seq_len=6, num_classes=9, embed_dim=72, depth=8,
                 num_heads=6, mlp_ratio=4., qkv_bias=False, mlp_head=False, drop_rate=0.1, attn_drop_rate=0.1)
model_M = ViT_LRP_copy.VisionTransformer(seq_len=6, num_classes=9, embed_dim=72, depth=8,
                 num_heads=6, mlp_ratio=4., qkv_bias=False, mlp_head=False, drop_rate=0.1, attn_drop_rate=0.1)                 
for concated_data,(label_E,label_K,label_M) in train_loader:
    # Cut the concatenated batch back into one tensor per year.
    datas = batch_to_splited_datas(concated_data,split_list)
    emb_batch_list= batch_to_embbedings(datas,embbeding_networks) # can be used for contrastive loss
    # emb_batch_list: list of embedding vectors; they are now stacked into a batch x seq x feature tensor
    emb_batched_seq = torch.stack(emb_batch_list).transpose(0,1)
    attn_mask = make_attn_mask(emb_batched_seq)
    # Score the same embedded sequence with each of the three models.
    (E_score,K_score,M_score) = model_E(emb_batched_seq,attn_mask),model_K(emb_batched_seq,attn_mask),model_M(emb_batched_seq,attn_mask)
    




# NOTE(review): everything below runs after the loop has finished, so the losses are
# computed from the last batch only, and no backward pass or optimizer step follows.
criterion = nn.CrossEntropyLoss()

loss_E = criterion(E_score,label_E)
loss_K = criterion(K_score,label_K)
loss_M = criterion(M_score,label_M)